
# **Step-by-Step Tutorial for Generating Synthetic Images with a GAN**

This tutorial walks through setting up a Generative Adversarial Network (GAN) in PyTorch to generate synthetic images. It covers importing the required libraries, defining the GAN architecture, preparing the data, training the model, fine-tuning it, and generating synthetic images.

## **1. Set Up the Environment and Import Libraries**

Before starting, ensure you have the required libraries installed. Import the necessary modules and define basic parameters for the GAN.

In [1]:
# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image

# Define parameters
latent_dim = 64      # Dimension of latent noise
image_size = 256     # Final image resolution
channels = 3         # RGB images
batch_size = 8       # Number of images per batch
num_epochs = 200     # Number of training epochs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

Using device: cuda


## **2. Define the Generator and Discriminator**

### **Residual Block**
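The residual block is the building block of the Generator. It passes its input through two 3×3 convolutions with batch normalization and adds the result back to the input, so the spatial size and channel count stay unchanged.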

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
        )

    def forward(self, x):
        return x + self.block(x)
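
As a quick sanity check, the sketch below re-declares a block with the same layers and confirms that the output shape equals the input shape, which is what makes the `x + self.block(x)` addition valid. The 128-channel, 32×32 input and the helper name `residual_output_shape` are illustrative assumptions.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Same layers as the block above, repeated so the example runs on its own."""

    def __init__(self, in_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
        )

    def forward(self, x):
        return x + self.block(x)


def residual_output_shape(channels=128, size=32, batch=2):
    """Run a random batch through the block and return the output shape."""
    block = ResidualBlock(channels)
    x = torch.randn(batch, channels, size, size)
    with torch.no_grad():
        return tuple(block(x).shape)


print(residual_output_shape())
```

A 3×3 kernel with `padding=1` and stride 1 keeps height and width fixed, so the skip connection needs no extra projection layer.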

### **Generator**
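The Generator maps a random noise vector of size `latent_dim` to a 32×32 feature map with 128 channels, upsamples it three times (to 64, 128, then 256 pixels per side), and applies `tanh` so that output pixel values fall in [-1, 1].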

In [3]:
class Generator(nn.Module):
    def __init__(self, latent_dim, channels):
        super(Generator, self).__init__()
        self.init_size = image_size // 8  # Initial image resolution (32x32)
        self.fc = nn.Linear(latent_dim, 128 * self.init_size ** 2)

        self.upsample_blocks = nn.ModuleList([
            nn.Sequential(
                ResidualBlock(128),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
            ),
            nn.Sequential(
                ResidualBlock(128),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
            ),
            nn.Sequential(
                ResidualBlock(64),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(64, channels, kernel_size=3, stride=1, padding=1)
            ),
        ])

    def forward(self, z):
        out = self.fc(z).view(z.shape[0], 128, self.init_size, self.init_size)
        for block in self.upsample_blocks:
            out = block(out)
        return torch.tanh(out)
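
The spatial sizes can be traced with a few lines of arithmetic. The sketch assumes the values defined earlier (`image_size = 256`, three upsampling stages with `scale_factor=2`, 128 base channels); `generator_sizes` is a hypothetical helper, not part of the tutorial code.

```python
def generator_sizes(image_size=256, num_upsamples=3, base_channels=128):
    """Return the fc output width and the spatial size after each upsample."""
    init_size = image_size // 2 ** num_upsamples   # 256 // 8 = 32
    fc_features = base_channels * init_size ** 2   # width of the Linear layer
    sizes = [init_size]
    for _ in range(num_upsamples):
        sizes.append(sizes[-1] * 2)                # each Upsample doubles H and W
    return fc_features, sizes


fc_features, sizes = generator_sizes()
print(fc_features)
print(sizes)
```

The `fc_features` value should agree with the `out_features` of the `fc` layer in the printed model summary.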

### **Discriminator**
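The Discriminator scores an image as real or generated (synthetic). Four stride-2 convolutions reduce the 256×256 input to 16×16, and a linear layer followed by a sigmoid produces a single score between 0 and 1.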

In [4]:
class Discriminator(nn.Module):
    def __init__(self, channels):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(512 * (image_size // 16) ** 2, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img)
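
The input width of the final `nn.Linear` layer depends on how much the strided convolutions shrink the image. The sketch below applies the standard convolution output-size formula (with dilation 1) to the layer settings used above; both function names are hypothetical.

```python
def conv_output_size(size, kernel_size, stride, padding):
    """Standard convolution output-size formula (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def discriminator_flat_features(image_size=256, num_convs=4, out_channels=512):
    """Spatial size after the strided convolutions and the flattened width."""
    size = image_size
    for _ in range(num_convs):
        size = conv_output_size(size, kernel_size=4, stride=2, padding=1)
    return size, out_channels * size ** 2


final_size, flat_features = discriminator_flat_features()
print(final_size, flat_features)
```

Each convolution halves the side length, which is why the class uses `image_size // 16` when sizing the linear layer.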

### **Initialize Models**
Create instances of the Generator and Discriminator, and move them to the specified device (CPU or GPU).

In [5]:
# Initialize generator and discriminator
generator = Generator(latent_dim, channels).to(device)
discriminator = Discriminator(channels).to(device)

# Print model summaries
print(generator)
print(discriminator)

Generator(
  (fc): Linear(in_features=64, out_features=131072, bias=True)
  (upsample_blocks): ModuleList(
    (0): Sequential(
      (0): ResidualBlock(
        (block): Sequential(
          (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Upsample(scale_factor=2.0, mode='nearest')
      (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): Sequential(
      (0): ResidualBlock(
        (block): Sequential(
          (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2)
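
The layer sizes in the printed summary translate directly into parameter counts. The sketch below works two of them out with plain arithmetic; the helper names are hypothetical and not part of PyTorch.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


def conv2d_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of a square-kernel 2D convolution."""
    return in_channels * out_channels * kernel_size ** 2 + out_channels


# Generator's first layer: Linear(64 -> 131072)
fc_params = linear_params(64, 131072)
# One 3x3 convolution inside a 128-channel residual block
res_conv_params = conv2d_params(128, 128, 3)

print(fc_params)
print(res_conv_params)
```

The fully connected layer alone holds about 8.5 million parameters, far more than any single convolution in the Generator. On a live model, `sum(p.numel() for p in generator.parameters())` gives the total.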

## **3. Prepare the Dataset**
Define the dataset and preprocessing steps, such as resizing and normalizing the images.

In [6]:
# Define transformations
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])

class MalignantDataset(Dataset):
    def __init__(self, image_folder, transform=None):
        self.image_paths = list(image_folder.glob("*.jpg"))
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# Load dataset
from pathlib import Path

image_folder = Path('/content/malignant_images')  # Path to malignant images
dataset = MalignantDataset(image_folder, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

print(f"Loaded {len(dataset)} images.")

Loaded 393 images.
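
`transforms.Normalize` with a mean and standard deviation of 0.5 computes `(x - 0.5) / 0.5` per channel. This maps the [0, 1] output of `ToTensor` to [-1, 1], the same range as the Generator's `tanh` output. The sketch below reproduces that arithmetic on a plain tensor and inverts it; `normalize` and `denormalize` are hypothetical helper names.

```python
import torch


def normalize(x, mean=0.5, std=0.5):
    """Same arithmetic as Normalize with mean 0.5 and std 0.5 on every channel."""
    return (x - mean) / std


def denormalize(x, mean=0.5, std=0.5):
    """Inverse mapping, from [-1, 1] back to [0, 1]."""
    return x * std + mean


pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])  # example values in [0, 1]
scaled = normalize(pixels)
restored = denormalize(scaled)
print(scaled)
print(restored)
```

The inverse mapping is useful when generated images need to be converted back to ordinary pixel values for display.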


## **4. Set Up Training**
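Define the loss, the optimizers, and the training loop. The loss is the mean squared error between the Discriminator's output and the target labels (0.9 for real images, a form of label smoothing, and 0 for generated ones). Each iteration updates the Generator first, then the Discriminator.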

In [7]:
# Define loss and optimizers
adversarial_loss = nn.MSELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.00005, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.00001, betas=(0.5, 0.999))

for epoch in range(num_epochs):
    for i, real_imgs in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
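        # Use a local name so the global batch_size is not overwritten
        # (the final batch can be smaller than the configured batch size)
        cur_batch_size = real_imgs.size(0)

        # Create labels
        valid = torch.full((cur_batch_size, 1), 0.9, device=device)  # Label smoothing
        fake = torch.zeros((cur_batch_size, 1), device=device)

        # Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(cur_batch_size, latent_dim, device=device)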
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        if i % 50 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

[Epoch 0/200] [Batch 0/50] [D loss: 0.20577281713485718] [G loss: 0.15430453419685364]
[Epoch 1/200] [Batch 0/50] [D loss: 0.04977016523480415] [G loss: 0.46677467226982117]
[Epoch 2/200] [Batch 0/50] [D loss: 0.028064608573913574] [G loss: 0.4809410572052002]
[Epoch 3/200] [Batch 0/50] [D loss: 0.009479183703660965] [G loss: 0.6767597794532776]
[Epoch 4/200] [Batch 0/50] [D loss: 0.02255299687385559] [G loss: 0.6514171361923218]
[Epoch 5/200] [Batch 0/50] [D loss: 0.00865771621465683] [G loss: 0.678687334060669]
[Epoch 6/200] [Batch 0/50] [D loss: 0.021519284695386887] [G loss: 0.6877313256263733]
[Epoch 7/200] [Batch 0/50] [D loss: 0.005735539365559816] [G loss: 0.7362111806869507]
[Epoch 8/200] [Batch 0/50] [D loss: 0.010725753381848335] [G loss: 0.62656569480896]
[Epoch 9/200] [Batch 0/50] [D loss: 0.006711430847644806] [G loss: 0.6571630239486694]
[Epoch 10/200] [Batch 0/50] [D loss: 0.0032674630638211966] [G loss: 0.755791187286377]
[Epoch 11/200] [Batch 0/50] [D loss: 0.00843473
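
The `Batch 0/50` in each log line follows from the dataset size: 393 images in batches of 8 give 50 batches, the last of which holds a single image (the `DataLoader` above does not drop incomplete batches). Because progress is printed only when `i % 50 == 0`, each epoch logs batch 0 alone. A small sketch of that arithmetic, with `batch_layout` as a hypothetical helper:

```python
import math


def batch_layout(num_images=393, batch_size=8):
    """Number of batches per epoch and the size of the final batch."""
    num_batches = math.ceil(num_images / batch_size)
    last_batch = num_images - (num_batches - 1) * batch_size
    return num_batches, last_batch


num_batches, last_batch = batch_layout()
# Batch indices that satisfy the logging condition used in the loop
logged = [i for i in range(num_batches) if i % 50 == 0]
print(num_batches, last_batch, logged)
```

A smaller logging interval, such as `i % 10 == 0`, would show how the losses move within an epoch.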

## **5. Fine-Tune the GAN for Refinement**
After the main training run, a short fine-tuning pass can further refine the Generator and Discriminator. Using lower learning rates adjusts the model weights in smaller steps, which can improve the quality of the synthetic images.

Why Fine-Tune?
  * Improve Image Quality: Focus on enhancing details in the generated images.
  * Stabilize Training: Address issues like overfitting or mode collapse.
  * Adaptability: Fine-tune the model for specific subsets of data or slightly different distributions.

Steps for Fine-Tuning
  1. Lower the learning rates for both the Generator and Discriminator to ensure gradual updates.
  2. Train for a reduced number of epochs (e.g., 10–20) since the model is already close to convergence.
  3. Monitor Discriminator and Generator losses to ensure stability.
  4. Save and visualize generated images periodically to assess progress and quality.
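
Step 1 can be done without rebuilding the optimizers by editing their `param_groups`. The sketch below uses a stand-in `nn.Linear` model, and the scaling factor of 0.1 is an assumption for illustration; the tutorial does not state which fine-tuning learning rates were used.

```python
import torch.nn as nn
import torch.optim as optim


def scale_learning_rate(optimizer, factor):
    """Multiply the learning rate of every parameter group by `factor`."""
    for group in optimizer.param_groups:
        group["lr"] *= factor
    return [group["lr"] for group in optimizer.param_groups]


# Stand-in model and optimizer (assumed; replace with generator / optimizer_G)
model = nn.Linear(4, 1)
optimizer = optim.Adam(model.parameters(), lr=0.00005, betas=(0.5, 0.999))

new_lrs = scale_learning_rate(optimizer, factor=0.1)
print(new_lrs)
```

Editing the existing optimizer keeps Adam's running moment estimates, whereas creating a new optimizer would reset them. Either choice is reasonable for a short fine-tuning pass.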

In [10]:
import os
from torchvision.utils import save_image

# Ensure the directory exists
os.makedirs('/content/generated_images', exist_ok=True)

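# Number of fine-tuning epochs (20 matches the "x/20" counter in the output below)
fine_tune_epochs = 20

# Fine-tuning training loop
# Note: this loop reuses optimizer_G and optimizer_D; lower their learning
# rates first (step 1 above), otherwise training simply continues as before.
for epoch in range(fine_tune_epochs):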
    for i, real_imgs in enumerate(dataloader):  # Use the existing dataloader
        real_imgs = real_imgs.to(device)
        # Use a local name so the global batch_size is not overwritten
        cur_batch_size = real_imgs.size(0)

        # Create labels with slight smoothing
        valid = torch.full((cur_batch_size, 1), 0.9, device=device)
        fake = torch.zeros(cur_batch_size, 1, device=device)

        # ---------------------
        #  Train Generator
        # ---------------------
        optimizer_G.zero_grad()
        z = torch.randn(cur_batch_size, latent_dim, device=device)
        gen_imgs = generator(z)

        # Generator loss
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Discriminator loss on real and fake images
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Print progress
        if i % 50 == 0:
            print(f"[Fine-Tuning Epoch {epoch+1}/{fine_tune_epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

    # Save generated images every 5 epochs (gen_imgs is the last batch of the
    # epoch, so the grid holds at most one batch of images)
    if epoch % 5 == 0:
        save_image(gen_imgs.detach()[:25], f"/content/generated_images/fine_tune_epoch_{epoch}.png", nrow=5, normalize=True)

print("Fine-tuning completed!")

[Fine-Tuning Epoch 1/20] [Batch 0/50] [D loss: 0.10170271247625351] [G loss: 0.2454090118408203]
[Fine-Tuning Epoch 2/20] [Batch 0/50] [D loss: 0.021699998527765274] [G loss: 0.507217288017273]
[Fine-Tuning Epoch 3/20] [Batch 0/50] [D loss: 0.15073662996292114] [G loss: 0.4951837956905365]
[Fine-Tuning Epoch 4/20] [Batch 0/50] [D loss: 0.0728784054517746] [G loss: 0.3299887478351593]
[Fine-Tuning Epoch 5/20] [Batch 0/50] [D loss: 0.06050339341163635] [G loss: 0.34377580881118774]
[Fine-Tuning Epoch 6/20] [Batch 0/50] [D loss: 0.028382904827594757] [G loss: 0.45027947425842285]
[Fine-Tuning Epoch 7/20] [Batch 0/50] [D loss: 0.031564630568027496] [G loss: 0.46366947889328003]
[Fine-Tuning Epoch 8/20] [Batch 0/50] [D loss: 0.029923152178525925] [G loss: 0.515169620513916]
[Fine-Tuning Epoch 9/20] [Batch 0/50] [D loss: 0.04600841552019119] [G loss: 0.43605679273605347]
[Fine-Tuning Epoch 10/20] [Batch 0/50] [D loss: 0.03361981362104416] [G loss: 0.49987784028053284]
[Fine-Tuning Epoch 11/2

## **Understanding the Results After Fine-Tuning**

### **Discriminator Loss (D loss):**
#### **What it Measures:**
The ability of the Discriminator to distinguish real images from generated ones.

#### **Key Observations:**
- Lower values (close to 0) indicate the Discriminator separates real from generated images well.
- A D loss that stays very close to 0 can mean the Discriminator is overpowering the Generator. In the logged fine-tuning output, the D loss fluctuates between about 0.02 and 0.15 instead of collapsing to 0.
- This suggests the Discriminator is not dominating the adversarial training process.

### **Generator Loss (G loss):**
#### **What it Measures:**
How well the Generator is fooling the Discriminator.

#### **Key Observations:**
- A moderately low G loss (not too close to 0) suggests the Generator is learning and that the Discriminator scores some of its outputs as close to real.
- The G loss values (0.24–0.57) fluctuate rather than fall steadily, but they stay bounded, which suggests the Generator is keeping pace with the Discriminator rather than collapsing.

### **Overall Training Dynamics:**
- As fine-tuning progresses, both the D loss and G loss settle into a narrow band, which is a sign of balanced adversarial training.
- Fluctuations in the loss values are expected, because each network's loss depends on the other network's current state.
- The Generator's outputs should improve over time. Since GAN losses are only a rough proxy for image quality, confirm this by inspecting the images saved during fine-tuning.
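
Because the loss is the mean squared error against targets of 0.9 (real) and 0 (fake), the logged values can be related back to Discriminator scores. The sketch below assumes, for simplicity, that the Discriminator gives the same score to every image in a batch; the function names are hypothetical.

```python
import math


def losses_for_constant_score(score_real, score_fake, valid=0.9, fake=0.0):
    """MSE losses when D outputs one score for all real and one for all fake images."""
    real_loss = (score_real - valid) ** 2
    fake_loss = (score_fake - fake) ** 2
    d_loss = (real_loss + fake_loss) / 2
    g_loss = (score_fake - valid) ** 2
    return d_loss, g_loss


def fake_score_from_g_loss(g_loss, valid=0.9):
    """Score D assigns to fakes, assuming it is uniform and below the target."""
    return valid - math.sqrt(g_loss)


# An undecided Discriminator that answers 0.5 for everything
d_loss, g_loss = losses_for_constant_score(0.5, 0.5)
print(round(d_loss, 4), round(g_loss, 4))

# Discriminator score for fakes implied by a G loss of 0.5
print(round(fake_score_from_g_loss(0.5), 2))
```

The undecided case gives a D loss of 0.205 and a G loss of 0.16, close to the first line of the training log (about 0.206 and 0.154), which is consistent with an untrained Discriminator. Under the same assumption, a G loss near 0.5 corresponds to generated images being scored at roughly 0.19.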


## **Generate Synthetic Images**
After training and fine-tuning, switch the Generator to evaluation mode and sample new images from random noise.

In [None]:
import os

# Create output folder
output_folder = '/content/synthetic_images'
os.makedirs(output_folder, exist_ok=True)

# Generate 100 synthetic images
generator.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Gradients are not needed for inference
    for i in range(100):
        z = torch.randn(1, latent_dim, device=device)
        gen_img = generator(z)
        save_image(gen_img, f"{output_folder}/synthetic_{i}.png", normalize=True)

print(f"Synthetic images saved to {output_folder}.")
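
Sampling can also be done in batches under `torch.no_grad()`, which avoids building the autograd graph and reduces the number of forward passes. The sketch below shows that pattern with a tiny stand-in generator (an assumption so the example runs on its own); `sample_images` is a hypothetical helper, and the trained `generator` would take the stand-in's place.

```python
import torch
import torch.nn as nn


def sample_images(generator, num_images, latent_dim, batch_size=8, device="cpu"):
    """Generate `num_images` samples in batches without tracking gradients."""
    generator.eval()
    chunks = []
    with torch.no_grad():
        for start in range(0, num_images, batch_size):
            n = min(batch_size, num_images - start)  # last batch may be smaller
            z = torch.randn(n, latent_dim, device=device)
            chunks.append(generator(z).cpu())
    return torch.cat(chunks)


# Stand-in generator producing 3x8x8 images in [-1, 1] (illustrative only)
toy_generator = nn.Sequential(
    nn.Linear(64, 3 * 8 * 8),
    nn.Tanh(),
    nn.Unflatten(1, (3, 8, 8)),
)

images = sample_images(toy_generator, num_images=20, latent_dim=64)
print(images.shape)
```

The returned tensor can then be written to disk image by image with `save_image`, as in the cell above.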