<a href="https://colab.research.google.com/github/Kashara-Alvin-Ssali/ML-models/blob/main/VanillaGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from IPython import get_ipython
from IPython.display import display

In [None]:
!pip install torch torchvision matplotlib
!pip install pytorch-fid



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import numpy as np
import os
import zipfile

In [None]:
# Define device: use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG (applies to all devices) so weight initialisation, noise
# vectors and DataLoader shuffling repeat across Restart & Run All
# (up to any nondeterministic GPU kernels).
SEED = 42
torch.manual_seed(SEED)

In [None]:
# NOTE(review): dead cell - this commented-out transform is identical to the
# live one defined in the Google Drive loading cell below. Safe to delete.
# # Data Preprocessing (Modified for 256x256 images)
# transform = transforms.Compose([
#     transforms.Resize((256, 256)),  # Resize images to 256x256
#     transforms.ToTensor(),        # Convert to tensor
#     transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
# ])

In [None]:
from google.colab import drive

# Make the Drive contents visible to this Colab VM under /content/drive.
drive.mount('/content/drive')

# ImageFolder layout expected: <dataset_path>/<class_name>/<image files>.
dataset_path = "/content/drive/MyDrive/Dataset4"

# Pipeline: resize to 256x256 -> tensor in [0, 1] -> shift/scale to [-1, 1]
# (x - 0.5) / 0.5, matching a Tanh-output generator.
resize_step = transforms.Resize((256, 256))
to_tensor_step = transforms.ToTensor()
normalize_step = transforms.Normalize((0.5,), (0.5,))
transform = transforms.Compose([resize_step, to_tensor_step, normalize_step])

dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
# One shuffled image per optimisation step.
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Define Generator (Modified for 256x256 images)
class Generator(nn.Module):
    """Fully connected (MLP) GAN generator: latent noise -> image batch.

    A batch of noise vectors of shape ``(N, latent_dim)`` goes through three
    ReLU hidden layers (256 -> 512 -> 1024 units) and a final linear layer
    whose ``Tanh`` output lies in ``[-1, 1]``; the flat output is then
    reshaped to ``(N, channels, img_size, img_size)``.

    Args:
        latent_dim: Length of each input noise vector. Default 100.
        img_size: Height and width of the square output images. Default 256.
        channels: Number of output channels. Default 3.

    Note:
        With the defaults the last ``Linear`` layer alone holds
        1024 * 3 * 256 * 256 ~= 201M weights (~0.8 GB in float32), so this
        model is very memory-hungry; consider a convolutional generator at
        this resolution.
    """

    def __init__(self, latent_dim=100, img_size=256, channels=3):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_size = img_size
        self.channels = channels
        # Layer order/indices are kept identical so existing state_dicts
        # ("main.0.weight", "main.6.bias", ...) still load.
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, channels * img_size * img_size),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map noise ``z`` of shape ``(N, latent_dim)`` to images ``(N, C, H, W)``."""
        # Pure reshape of the flat MLP output - no upsampling takes place.
        return self.main(z).view(-1, self.channels, self.img_size, self.img_size)

In [None]:
# Define Discriminator (Modified for 256x256 images)
class Discriminator(nn.Module):
    """Fully connected (MLP) GAN discriminator: image batch -> P(real).

    Each image of shape ``(channels, img_size, img_size)`` is flattened and
    passed through LeakyReLU(0.2) hidden layers (1024 -> 512 -> 256 units) and
    a final ``Sigmoid``, producing one probability per sample, shape ``(N, 1)``.

    Args:
        img_size: Height and width of the square input images. Default 256.
        channels: Number of input channels. Default 3.
    """

    def __init__(self, img_size=256, channels=3):
        super().__init__()
        self.img_size = img_size
        self.channels = channels
        # Layer order/indices are kept identical so existing state_dicts load.
        self.main = nn.Sequential(
            nn.Linear(channels * img_size * img_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return the probability that each image in ``img`` is real."""
        # flatten() also accepts non-contiguous inputs, where .view() would raise.
        return self.main(torch.flatten(img, start_dim=1))

In [None]:
# Define FGSM attack function
def fgsm_attack(image, epsilon, data_grad, clip_min=-1.0, clip_max=1.0):
    """Build an FGSM (Fast Gradient Sign Method) adversarial example.

    Moves every element of ``image`` by ``epsilon`` in the direction of the
    sign of ``data_grad`` (the gradient of some loss w.r.t. ``image``), which
    increases that loss to first order, then clips to the valid pixel range.

    Args:
        image: Input tensor.
        epsilon: Step size, in the same units as ``image``.
        data_grad: Gradient tensor with the same shape as ``image``.
        clip_min, clip_max: Valid pixel range. Defaults to ``[-1, 1]`` because
            the images here are normalized with ``Normalize((0.5,), (0.5,))``;
            clamping to ``[0, 1]`` would zero every darker-than-mid-grey pixel.

    Returns:
        The perturbed tensor, detached from the autograd graph so it is used
        as a fixed input.
    """
    perturbed_image = image + epsilon * data_grad.sign()
    perturbed_image = torch.clamp(perturbed_image, clip_min, clip_max)
    return perturbed_image.detach()

In [None]:
# Build each network, then move its parameters onto the selected device
# (Module.to returns the module itself).
generator = Generator()
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

In [None]:
# Optimizers: both networks share the same Adam settings. beta1 = 0.5 is lower
# than Adam's 0.9 default, a common choice for GAN training.
LR = 0.0002
BETAS = (0.5, 0.999)
optimizer_G = optim.Adam(generator.parameters(), lr=LR, betas=BETAS)
optimizer_D = optim.Adam(discriminator.parameters(), lr=LR, betas=BETAS)

In [None]:
# Binary cross-entropy on the discriminator's sigmoid probabilities,
# averaged over the batch (BCELoss's default reduction, spelled out).
criterion = nn.BCELoss(reduction="mean")

In [None]:
import os
import torch
import torch.nn as nn
import torchvision.utils as vutils

# Ensure 'generated_images' directory exists
os.makedirs("generated_images", exist_ok=True)

# Hyperparameters
num_epochs = 50
# FGSM step in normalized [-1, 1] units. NOTE(review): 0.001 here is 0.0005 in
# [0, 1] pixel units - well below one 8-bit intensity step (1/255 ~= 0.0039),
# so adversarial images are nearly identical to the originals; confirm intended.
epsilon = 0.001
patience = 10  # Early stopping: epochs allowed without G-loss improvement
early_stop_counter = 0
best_g_loss = float('inf')


# NOTE(review): this redefinition shadows the fgsm_attack defined in an earlier
# cell; keep a single definition.
def fgsm_attack(image, epsilon, data_grad, clip_min=-1.0, clip_max=1.0):
    """Generate an adversarial example using the FGSM attack.

    Steps ``image`` by ``epsilon * sign(data_grad)`` and clips to
    ``[clip_min, clip_max]``. The default range is ``[-1, 1]`` because the
    images are normalized with ``Normalize((0.5,), (0.5,))``; a ``[0, 1]``
    clamp would erase every negative pixel. The result is detached so it is
    fed to the discriminator as a fixed input.
    """
    perturbed_image = image + epsilon * data_grad.sign()
    perturbed_image = torch.clamp(perturbed_image, clip_min, clip_max)  # Keep image valid
    return perturbed_image.detach()

# Adversarial GAN training. Per batch, D is trained on (1) real images,
# (2) FGSM-perturbed real images built from the gradient of the real loss, and
# (3) generated images; then G takes one step. Uses names from earlier cells:
# dataloader, device, generator, discriminator, criterion, optimizer_G,
# optimizer_D, fgsm_attack, num_epochs, epsilon, patience, early_stop_counter,
# best_g_loss.

# Derive the noise size from the generator's input layer so it always matches.
latent_dim = generator.main[0].in_features
# Fixed noise so the periodically saved samples are comparable across epochs.
fixed_noise = torch.randn(16, latent_dim, device=device)

for epoch in range(num_epochs):
    epoch_d_loss = 0.0
    epoch_g_loss = 0.0
    num_batches = 0

    for imgs, _ in dataloader:
        # Track gradients w.r.t. the input so FGSM can read real_imgs.grad.
        real_imgs = imgs.to(device).requires_grad_(True)
        batch_size = real_imgs.size(0)
        # Size labels to the actual batch (works for any batch_size / last batch).
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # ---- Train Discriminator ----
        optimizer_D.zero_grad()

        # (1) Real images: backward fills D's parameter grads and real_imgs.grad.
        # The graph is not reused afterwards, so no retain_graph is needed.
        d_loss_real = criterion(discriminator(real_imgs), real_labels)
        d_loss_real.backward()

        # (2) Adversarial real images, detached so they act as fixed inputs;
        # D must still classify them as real.
        data_grad = real_imgs.grad.detach()
        perturbed_data = fgsm_attack(real_imgs, epsilon, data_grad).detach()
        d_loss_adv = criterion(discriminator(perturbed_data), real_labels)
        d_loss_adv.backward()

        # (3) Generated images, detached so this step does not update G.
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_imgs = generator(z)
        d_loss_fake = criterion(discriminator(fake_imgs.detach()), fake_labels)
        d_loss_fake.backward()

        optimizer_D.step()

        # ---- Train Generator: make D label the fakes as real ----
        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_imgs), real_labels)
        g_loss.backward()
        optimizer_G.step()

        epoch_d_loss += (d_loss_real + d_loss_adv + d_loss_fake).item()
        epoch_g_loss += g_loss.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("dataloader yielded no batches; check dataset_path")

    # Epoch means: a single batch's loss is far too noisy to log or compare.
    mean_d_loss = epoch_d_loss / num_batches
    mean_g_loss = epoch_g_loss / num_batches
    print(f"Epoch [{epoch+1}/{num_epochs}] | D Loss: {mean_d_loss:.4f} | G Loss: {mean_g_loss:.4f}")

    # Save a 4x4 grid of samples from the fixed noise every 10 epochs; map the
    # Tanh range [-1, 1] to [0, 1] instead of per-image min-max stretching.
    if epoch % 10 == 0:
        with torch.no_grad():
            samples = generator(fixed_noise)
        vutils.save_image(samples, f"generated_images/epoch_{epoch}_image.png",
                          nrow=4, normalize=True, value_range=(-1, 1))

    # Early stopping on the epoch-mean generator loss.
    # NOTE(review): GAN losses are a weak quality signal; a sample-quality metric
    # such as FID (pytorch-fid is installed at the top) is a better criterion.
    if mean_g_loss < best_g_loss:
        best_g_loss = mean_g_loss
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}.")
            break


Epoch [1/50] | D Loss: 5.2612 | G Loss: 0.0192
Epoch [2/50] | D Loss: 1.1051 | G Loss: 0.8174
Epoch [3/50] | D Loss: 1.1987 | G Loss: 0.9690
Epoch [4/50] | D Loss: 2.2747 | G Loss: 0.1961
Epoch [5/50] | D Loss: 2.0314 | G Loss: 0.2203
Epoch [6/50] | D Loss: 2.4930 | G Loss: 1.2456
Epoch [7/50] | D Loss: 1.4488 | G Loss: 0.2336
Epoch [8/50] | D Loss: 1.7991 | G Loss: 0.3156
Epoch [9/50] | D Loss: 0.4593 | G Loss: 1.2052
Epoch [10/50] | D Loss: 7.2327 | G Loss: 0.0305
Epoch [11/50] | D Loss: 0.6467 | G Loss: 0.9491
Early stopping triggered at epoch 11.
