# Semi-Supervised Learning with Generative Adversarial Networks

In [60]:
import sys
import setproctitle
sys.path.append('../src')

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
import torch.optim as optim
import numpy as np

from dataset import load_data
from sklearn.metrics import recall_score

import matplotlib.pyplot as plt


In [61]:
setproctitle.setproctitle("SSGAN") # Shows string as Process name

In [62]:
gpu_id = input("Enter GPU index: ") # Pick a GPU  

# Set device to use GPU if available, otherwise use CPU
if int(gpu_id) >= 0 and torch.cuda.is_available():
    device = torch.device("cuda:" + gpu_id)
    print("Selected device:", torch.cuda.get_device_name(device))
    print(f"cuda:{gpu_id}")
else:
    print("No GPU available or no GPU index specified, using CPU instead.")
    device = torch.device("cpu")

Selected device: NVIDIA RTX A5000
cuda:1


In [63]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

Device: cuda


## Load Data

In [64]:
train_loader, valid_loader, test_loader = load_data(batch_size=128)

## SSGAN

In [65]:
class Discriminator(nn.Module):
  def __init__(self, in_features, n_classes):
    super().__init__()
    self.in_features = in_features

    # Discriminator will down-sample the input producing a binary output
    self.fc1 = nn.Linear(in_features=in_features, out_features=64)
    self.leaky_relu1 = nn.LeakyReLU(negative_slope=0.2)
    self.fc2 = nn.Linear(in_features=64, out_features=64)
    self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.2)
    self.fc3 = nn.Linear(in_features=64, out_features=32)
    self.leaky_relu3 = nn.LeakyReLU(negative_slope=0.2)
    #self.fc4 = nn.Linear(in_features=32, out_features=out_features)
    self.dropout = nn.Dropout(0.3)

    self.adv_layer = nn.Linear(in_features=32, out_features=1) # Fake or Real
    self.aux_layer = nn.Linear(in_features=32, out_features=n_classes+1) # Class + fake

    self.sigmoid =  nn.Sigmoid()



  def forward(self, x):
    # Rehape passed image batch
    batch_size = x.shape[0]
    x = x.view(batch_size, -1)
    # Feed forward
    x = self.fc1(x)
    x = self.leaky_relu1(x)
    #x = self.dropout(x)
    x = self.fc2(x)
    x = self.leaky_relu2(x)
    #x = self.dropout(x)
    x = self.fc3(x)
    x = self.leaky_relu3(x)
    #x = self.dropout(x)
    validity = self.sigmoid(self.adv_layer(x))
    label_logits = self.aux_layer(x)

    return validity, label_logits

In [66]:
class Generator(nn.Module):
  def __init__(self, in_features, out_features):
    super(Generator, self).__init__()
    self.in_features = in_features
    self.out_features = out_features

    # Generator will up-sample the input producing input of size
    # suitable for feeding into discriminator
    self.fc1 = nn.Linear(in_features=in_features, out_features=32)
    self.relu1 = nn.LeakyReLU(negative_slope=0.2)
    self.fc2 = nn.Linear(in_features=32, out_features=64)
    self.relu2 = nn.LeakyReLU(negative_slope=0.2)
    self.fc3 = nn.Linear(in_features=64, out_features=128)
    self.relu3 = nn.LeakyReLU(negative_slope=0.2)
    self.fc4 = nn.Linear(in_features=128, out_features=out_features)
    self.dropout = nn.Dropout(0.3)
    self.tanh = nn.Tanh()


  def forward(self, x):
    # Feed forward
    x = self.fc1(x)
    x = self.relu1(x)
    x = self.dropout(x)
    x = self.fc2(x)
    x = self.relu2(x)
    x = self.dropout(x)
    x = self.fc3(x)
    x = self.relu3(x)
    x = self.dropout(x)
    x = self.fc4(x)
    tanh_out = self.tanh(x)

    return tanh_out

## Train

In [67]:
z_size = 100
n_epochs = 100

model_d = Discriminator(in_features=49, n_classes=7).to(device)
model_g = Generator(in_features=z_size, out_features=49).to(device)

d_optim = optim.Adam(model_d.parameters(), lr=0.002, betas=(0.5, 0.999))
g_optim = optim.Adam(model_g.parameters(), lr=0.002, betas=(0.5, 0.999))

adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

In [68]:
validation_recall = []

In [69]:
for epoch in range(n_epochs):
  # Switch the training mode on
  model_d.train()
  model_g.train()
  d_running_batch_loss = 0
  g_running_batch_loss = 0
  d_batch_acc = 0
  d_batch_real_acc = 0
  d_batch_fake_acc = 0

  # Batches
  for curr_batch, (real_images, real_target) in enumerate(train_loader):
    real_images, real_target = real_images.to(device), real_target.to(device)
    real_images = (real_images * 2) - 1
    batch_size = real_images.shape[0]

    # -----------------
    #  Train Generator
    # -----------------
    g_optim.zero_grad()
    # Sample Noise
    z = np.random.uniform(-1, 1, size=(batch_size, z_size))
    z = torch.from_numpy(z).float().to(device)
    # Generate a batch of images
    gen_imgs = model_g(z)
    # Loss measures generator's ability to fool the discriminator
    validity, _ = model_d(gen_imgs) # Model input
    valid = torch.ones(batch_size, 1).to(device) # all 1 for all fake images
    g_loss = adversarial_loss(validity, valid)

    g_loss.backward()
    g_optim.step()
    #print(f"g_loss: {g_loss}")

    # ---------------------
    #  Train Discriminator
    # ---------------------
    d_optim.zero_grad()

    # Loss for real images
    real_pred, real_aux = model_d(real_images)
    d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, real_target.view(-1))) / 2
    #print(f"d_real_loss: {d_real_loss}")

    # Loss for fake images
    fake_pred, fake_aux = model_d(gen_imgs.detach()) # Model input
    fake = torch.zeros(batch_size, 1).to(device) # all 0 for all fake images
    fake_aux_gt = torch.full((batch_size,), 7).to(device) # 10 = n_classes
    d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2
    #print(f"d_fake_loss: {d_fake_loss}")

    # Total discriminator loss
    d_loss = (d_real_loss + d_fake_loss) / 2

    # Calculate discriminator accuracy
    real_aux = real_aux.data.cpu().numpy()
    real_target = real_target.data.cpu().numpy()
    fake_aux = fake_aux.data.cpu().numpy()
    fake_aux_gt = fake_aux_gt.data.cpu().numpy()
    pred = np.concatenate([real_aux, fake_aux], axis=0)
    gt = np.concatenate([real_target.flatten(), fake_aux_gt], axis=0)
    d_acc = np.mean(np.argmax(pred, axis=1) == gt)

    d_real_acc = np.mean(np.argmax(real_aux, axis=1) == real_target)
    d_fake_acc = np.mean(np.argmax(fake_aux, axis=1) == fake_aux_gt)

    d_loss.backward()
    d_optim.step()

    # Log loss
    d_running_batch_loss += d_loss
    g_running_batch_loss += g_loss
    d_batch_acc += d_acc
    d_batch_real_acc += d_real_acc
    d_batch_fake_acc += d_fake_acc

  v_y_true = []
  v_y_pred = []
  for _, (real_images, real_target) in enumerate(valid_loader):
    model_d.eval()
    with torch.no_grad():
      _, real_aux = model_d(real_images.to(device))
    real_aux = real_aux.data.cpu().numpy()
    real_target = real_target.data.cpu().numpy()
    v_y_pred += np.argmax(real_aux, axis=1).tolist()
    v_y_true += real_target.tolist()

  val_recall_per_class = recall_score(v_y_true, v_y_pred, average=None) 
  validation_recall.append(val_recall_per_class)







  d_running_batch_loss = d_running_batch_loss /  len(train_loader)
  g_running_batch_loss = g_running_batch_loss /  len(train_loader)
  d_batch_acc = d_batch_acc /  len(train_loader)
  d_batch_real_acc = d_batch_real_acc /  len(train_loader)
  d_batch_fake_acc = d_batch_fake_acc /  len(train_loader)
  print(f'Epoch: {epoch} \tepoch_d_loss: {d_running_batch_loss:.6f} \tepoch_g_loss: {g_running_batch_loss:.6f} \td_acc: {d_batch_acc:.2f} \treal_acc: {d_batch_real_acc:.2f} \tfake_acc: {d_batch_fake_acc:.2f}')
  print(f'Epoch: {epoch} \tValidation Class Recall {val_recall_per_class}')

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 0 	epoch_d_loss: 0.292428 	epoch_g_loss: 1.971591 	d_acc: 0.88 	real_acc: 0.57 	fake_acc: 0.97
Epoch: 0 	Validation Class Recall [0.72234436 0.         0.         0.         0.         0.
 0.         0.        ]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch: 1 	epoch_d_loss: 0.214265 	epoch_g_loss: 2.429528 	d_acc: 0.91 	real_acc: 0.60 	fake_acc: 0.98
Epoch: 1 	Validation Class Recall [9.99987718e-01 0.00000000e+00 0.00000000e+00 8.12931870e-04
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]


KeyboardInterrupt: 

In [None]:
len(v_y_true) / len(v_y_pred)

0.5938748335552596