Semi-supervised GANs

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report
import joblib

# === Load Data ===
# Every column except "ProtocolName" is treated as a feature; "ProtocolName"
# is cast to int64 and used as the class label.
df = pd.read_csv("new_network_train.csv")
X = df.drop(columns=["ProtocolName"]).values
y = df["ProtocolName"].astype(np.int64).values

# === Preprocessing ===
# Split BEFORE fitting the scaler: fitting on the full dataset would let the
# held-out rows contribute to the mean/variance used to scale the training
# rows (data leakage) and would bake test statistics into scaler.pkl.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)
# Whole dataset under the train-fitted scaler; kept so the X_scaled name
# remains defined for any code that still refers to it.
X_scaled = scaler.transform(X)
joblib.dump(scaler, "scaler.pkl")

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=128, shuffle=True)
test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor), batch_size=128)

input_dim = X.shape[1]
# NOTE(review): using the count of distinct labels as the one-hot width
# assumes labels are exactly 0..K-1 with no gaps -- confirm against the CSV.
num_classes = len(np.unique(y))
latent_dim = 100
# Target used for "real" samples in the discriminator loss (one-sided label smoothing).
label_smooth_real = 0.9

# === One-hot encoder ===
def one_hot(labels, num_classes):
    """Return float one-hot rows for integer class indices.

    Row ``i`` of the result is row ``labels[i]`` of a ``num_classes`` x
    ``num_classes`` identity matrix created on the same device as ``labels``.
    """
    identity = torch.eye(num_classes, device=labels.device)
    encoded = identity[labels]
    return encoded

# === Focal Loss ===
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    For each sample the loss is ``-(1 - p_t) ** gamma * log(p_t)``, where
    ``p_t`` is the softmax probability of the target class; the result is the
    mean over the batch. With ``gamma == 0`` this is plain cross-entropy.
    """

    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        """Return the scalar focal loss for logits ``input`` and class indices ``target``."""
        log_probs = F.log_softmax(input, dim=1)
        # Down-weights classes that are already predicted with high probability.
        modulating_factor = (1 - log_probs.exp()) ** self.gamma
        weighted_log_probs = modulating_factor * log_probs
        # nll_loss picks the target column, negates it and averages over the batch.
        return F.nll_loss(weighted_log_probs, target)

# === Generator ===
class Generator(nn.Module):
    """Conditional generator: maps (noise, class label) to a synthetic feature row.

    Args:
        noise_dim: Width of the noise vector.
        label_dim: Number of classes; labels are one-hot encoded to this width.
        output_dim: Width of each generated row. The last layer is a plain
            ``Linear``, so outputs are unbounded.
    """

    def __init__(self, noise_dim, label_dim, output_dim):
        super().__init__()
        # Remembered so forward() encodes labels with the width this network
        # was built for instead of reading a module-level global.
        self.label_dim = label_dim
        self.model = nn.Sequential(
            nn.Linear(noise_dim + label_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        )

    def forward(self, noise, labels):
        """Generate one row per (noise, label) pair.

        Args:
            noise: Float tensor with ``noise_dim`` columns.
            labels: Integer class indices, one per row of ``noise``.

        Returns:
            Tensor with ``output_dim`` columns and one row per input row.
        """
        label_vecs = F.one_hot(labels.long(), self.label_dim).to(
            device=noise.device, dtype=noise.dtype
        )
        x = torch.cat([noise, label_vecs], dim=1)
        return self.model(x)

# === Discriminator ===
class Discriminator(nn.Module):
    """Conditional discriminator: scores a (feature row, class label) pair.

    Args:
        input_dim: Width of each feature row.
        label_dim: Number of classes; labels are one-hot encoded to this width.

    The network ends in a ``Sigmoid``, so it returns values in [0, 1] that can
    be fed directly to ``nn.BCELoss``.
    """

    def __init__(self, input_dim, label_dim):
        super().__init__()
        # Remembered so forward() encodes labels with the width this network
        # was built for instead of reading a module-level global.
        self.label_dim = label_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim + label_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, data, labels):
        """Score each row of ``data`` conditioned on its class label.

        Args:
            data: Float tensor with ``input_dim`` columns.
            labels: Integer class indices, one per row of ``data``.

        Returns:
            Tensor with a single column holding one score in [0, 1] per row.
        """
        label_vecs = F.one_hot(labels.long(), self.label_dim).to(
            device=data.device, dtype=data.dtype
        )
        x = torch.cat([data, label_vecs], dim=1)
        return self.model(x)

# === CNN Classifier ===
class CNNClassifier(nn.Module):
    """Feed-forward classifier that returns raw class logits.

    Despite the name, the network contains no convolutional layers: it is
    three ``Linear -> BatchNorm1d -> ReLU -> Dropout(0.2)`` stages of width
    512, 256 and 128, followed by a ``Linear`` layer with ``num_classes``
    outputs.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        hidden_widths = (512, 256, 128)
        layers = []
        in_features = input_dim
        for out_features in hidden_widths:
            layers += [
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
                nn.ReLU(),
                nn.Dropout(0.2),
            ]
            in_features = out_features
        # Output head: no softmax here, callers receive logits.
        layers.append(nn.Linear(in_features, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits with ``num_classes`` columns, one row per row of ``x``."""
        logits = self.net(x)
        return logits

# === Pseudo-labeling function ===
def get_confident_pseudo_labels(model, data, threshold=0.9):
    """Select the rows the model classifies with high confidence.

    Args:
        model: Classifier returning one row of logits per input row.
        data: Input rows to score.
        threshold: Minimum top softmax probability for a row to be kept.

    Returns:
        ``(confident_data, confident_labels)`` on the CPU: the kept rows and
        the class index predicted for each of them.

    The model is switched to eval mode and left that way; callers that keep
    training must call ``model.train()`` again.
    """
    model.eval()
    # Run on whatever device the model lives on instead of depending on a
    # module-level ``device`` global; parameter-free models use data's device.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else data.device
    data = data.to(target_device)
    with torch.no_grad():
        outputs = model(data)
        probs = F.softmax(outputs, dim=1)
        confidences, predictions = torch.max(probs, dim=1)
        mask = confidences >= threshold
        confident_data = data[mask]
        confident_labels = predictions[mask]
    return confident_data.cpu(), confident_labels.cpu()

# === Init Models ===
# Use the GPU when one is available; all three networks live on that device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
G = Generator(noise_dim=latent_dim, label_dim=num_classes, output_dim=input_dim).to(device)
D = Discriminator(input_dim=input_dim, label_dim=num_classes).to(device)
clf = CNNClassifier(input_dim=input_dim, num_classes=num_classes).to(device)

# === Losses and Optimizers ===
# BCE drives the adversarial game (D ends in a Sigmoid); focal loss trains the classifier.
bce = nn.BCELoss()
focal_loss = FocalLoss()
opt_G = optim.Adam(G.parameters(), lr=2e-4)
opt_D = optim.Adam(D.parameters(), lr=2e-4)
opt_C = optim.Adam(clf.parameters(), lr=1e-3)
# Halve the classifier learning rate every 15 scheduler steps.
scheduler_C = optim.lr_scheduler.StepLR(opt_C, step_size=15, gamma=0.5)

# === Training Loop ===
# Per batch: (1) update D on real vs. generated rows, (2) update G to fool D,
# (3) update the classifier on real + generated rows. After every epoch the
# classifier is additionally trained on the labelled training rows plus the
# rows it pseudo-labels with high confidence.
epochs = 100
for epoch in range(epochs):
    G.train()
    D.train()
    clf.train()
    # Losses are summed (not averaged) over the batches of the epoch.
    total_d_loss, total_g_loss, total_c_loss = 0, 0, 0

    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        batch_size = xb.size(0)
        if batch_size < 2:
            # G's BatchNorm1d layers raise in training mode when given a
            # single sample, so skip a trailing batch of one.
            continue

        # === Train Discriminator ===
        # "Real" targets are smoothed to label_smooth_real instead of 1.0.
        real_labels = torch.full((batch_size, 1), label_smooth_real, device=device)
        fake_labels = torch.zeros((batch_size, 1), device=device)

        z = torch.randn(batch_size, latent_dim, device=device)
        fake_data = G(z, yb)

        D_real = D(xb, yb)
        # detach(): the discriminator step must not back-propagate into G.
        D_fake = D(fake_data.detach(), yb)

        loss_D = bce(D_real, real_labels) + bce(D_fake, fake_labels)
        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # === Train Generator ===
        # Re-score the same fakes with the just-updated D; G is rewarded when
        # D labels them as real.
        D_fake = D(fake_data, yb)
        loss_G = bce(D_fake, real_labels)
        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()

        # === Train Classifier on real + fake ===
        # The generated rows are only inputs here. Gradients that used to flow
        # into G from the classifier loss were always zeroed by opt_G before
        # G's next step, so G's graph is not built at all.
        with torch.no_grad():
            synthetic_data = G(torch.randn(batch_size, latent_dim, device=device), yb)
        combined_data = torch.cat([xb, synthetic_data], dim=0)
        combined_labels = torch.cat([yb, yb], dim=0)

        preds = clf(combined_data)
        loss_C = focal_loss(preds, combined_labels)
        opt_C.zero_grad()
        loss_C.backward()
        opt_C.step()

        total_d_loss += loss_D.item()
        total_g_loss += loss_G.item()
        total_c_loss += loss_C.item()

    # === Semi-Supervised Learning: Pseudo-labeling ===
    # NOTE(review): the unlabeled pool is X_test_tensor, the same rows used
    # for the final evaluation, so the reported test metrics are transductive
    # and likely optimistic. Consider holding out a separate unlabeled pool.
    pseudo_data, pseudo_labels = get_confident_pseudo_labels(clf, X_test_tensor)
    if len(pseudo_data) > 1:
        combined_loader = DataLoader(
            TensorDataset(torch.cat([X_train_tensor, pseudo_data], dim=0),
                          torch.cat([y_train_tensor, pseudo_labels], dim=0)),
            batch_size=128, shuffle=True
        )

        # get_confident_pseudo_labels leaves the classifier in eval mode.
        clf.train()
        for xb, yb in combined_loader:
            if xb.size(0) < 2:
                # The classifier's BatchNorm1d layers cannot train on one sample.
                continue
            xb, yb = xb.to(device), yb.to(device)
            preds = clf(xb)
            loss = focal_loss(preds, yb)
            opt_C.zero_grad()
            loss.backward()
            opt_C.step()

    scheduler_C.step()
    print(f"[Epoch {epoch+1}] D Loss: {total_d_loss:.4f}, G Loss: {total_g_loss:.4f}, C Loss: {total_c_loss:.4f}")

# === Evaluation ===
# Predict batch by batch through test_loader (unshuffled, so the order matches
# y_test_tensor) instead of pushing the whole test set through in one pass.
# In eval mode BatchNorm uses running statistics and Dropout is disabled, so
# batching does not change the predictions.
clf.eval()
with torch.no_grad():
    batch_preds = [clf(xb.to(device)).argmax(dim=1).cpu() for xb, _ in test_loader]
y_pred = torch.cat(batch_preds).numpy()
y_true = y_test_tensor.cpu().numpy()

# Class index -> protocol name, used only for the printed report.
label_map = {
    0: 'AMAZON', 1: 'CLOUDFLARE', 2: 'DROPBOX', 3: 'FACEBOOK',
    4: 'GMAIL', 5: 'GOOGLE', 6: 'HTTP', 7: 'HTTP_CONNECT', 8: 'HTTP_PROXY',
    9: 'MICROSOFT', 10: 'MSN', 11: 'SKYPE', 12: 'SSL', 13: 'TWITTER',
    14: 'WINDOWS_UPDATE', 15: 'YAHOO', 16: 'YOUTUBE'
}
class_ids = sorted(label_map)
print("\n✅ Accuracy:", accuracy_score(y_true, y_pred))
# labels= pins the report to the mapped classes; without it the call raises
# when a class is missing from both y_true and y_pred, because the number of
# target_names no longer matches the number of classes found.
print("\n📊 Classification Report:\n", classification_report(y_true, y_pred, labels=class_ids, target_names=[label_map[i] for i in class_ids]))

# === Save Models ===
torch.save(G.state_dict(), "generator_cgan_ssl.pth")
torch.save(D.state_dict(), "discriminator_cgan_ssl.pth")
torch.save(clf.state_dict(), "classifier_cgan_ssl.pth")


[Epoch 1] D Loss: 631.5705, G Loss: 2532.6959, C Loss: 1442.7490
[Epoch 2] D Loss: 678.3556, G Loss: 2220.0775, C Loss: 1313.9520
[Epoch 3] D Loss: 673.1131, G Loss: 2066.8701, C Loss: 1268.8183
[Epoch 4] D Loss: 658.4293, G Loss: 2033.3208, C Loss: 996.4584
[Epoch 5] D Loss: 678.4546, G Loss: 1933.0647, C Loss: 659.6484
[Epoch 6] D Loss: 718.6039, G Loss: 1792.0583, C Loss: 523.7967
[Epoch 7] D Loss: 732.2376, G Loss: 1735.9393, C Loss: 487.0592
[Epoch 8] D Loss: 759.5357, G Loss: 1656.9688, C Loss: 468.6067
[Epoch 9] D Loss: 795.6981, G Loss: 1556.6133, C Loss: 427.4745
[Epoch 10] D Loss: 804.8653, G Loss: 1527.0469, C Loss: 411.1072
[Epoch 11] D Loss: 791.7677, G Loss: 1541.9570, C Loss: 401.4506
[Epoch 12] D Loss: 807.9213, G Loss: 1528.5791, C Loss: 388.9573
[Epoch 13] D Loss: 812.0208, G Loss: 1536.3853, C Loss: 356.9381
[Epoch 14] D Loss: 821.2800, G Loss: 1502.9984, C Loss: 335.4334
[Epoch 15] D Loss: 830.1162, G Loss: 1477.1575, C Loss: 321.1355
[Epoch 16] D Loss: 832.8014, G 