<a href="https://colab.research.google.com/github/Sirfowahid/BrainStrokeDetection/blob/main/CGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive in this Colab runtime so that the dataset directory used
# below (under /content/drive/MyDrive/...) is readable from the notebook.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Grayscale
from tqdm import tqdm
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
img_size = 224          # images are resized to img_size x img_size pixels
num_classes = 2         # conditioning classes; NOTE(review): should equal the
                        # number of ImageFolder classes (not checked here)
latent_dim = 100        # length of the generator's noise vector
epochs = 60
batch_size = 32
learning_rate_g = 2e-4  # generator Adam learning rate
learning_rate_d = 1e-4  # discriminator Adam learning rate

# Preprocessing: square resize -> single grayscale channel -> tensor.
# Normalize with mean 0.5 / std 0.5 maps [0, 1] pixel values to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(img_size, img_size)),
        Grayscale(1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Images are loaded with torchvision's ImageFolder from the mounted Drive.
dataset = datasets.ImageFolder(
    '/content/drive/MyDrive/Projects/23. Brain Stroke Prediction/Dataset',
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size, shuffle=True)

# Define Generator
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to a 1-channel image.

    The class label is looked up in a learned ``num_classes``-dimensional
    embedding, concatenated with the noise vector and passed through a stack
    of fully connected layers. The final ``Tanh`` keeps every output pixel in
    [-1, 1].

    Args:
        latent_dim: Length of the noise vector ``z``.
        num_classes: Number of conditioning classes.
        img_size: Side length of the square output image.
        bn_eps: Numerical-stability epsilon of the BatchNorm1d layers
            (PyTorch default 1e-5). Note that the second positional argument
            of ``nn.BatchNorm1d`` is ``eps``, not ``momentum``.
    """

    def __init__(self, latent_dim, num_classes, img_size, bn_eps=1e-5):
        super().__init__()
        self.img_size = img_size
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        # Layer order is kept fixed so state_dict keys (model.N.*) stay stable.
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256, eps=bn_eps),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512, eps=bn_eps),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024, eps=bn_eps),
            nn.ReLU(inplace=True),
            nn.Linear(1024, img_size * img_size),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate images.

        Args:
            noise: Float tensor of shape (N, latent_dim).
            labels: Integer class indices of shape (N,).

        Returns:
            Tensor of shape (N, 1, img_size, img_size) with values in [-1, 1].
        """
        label_input = self.label_embedding(labels)
        gen_input = torch.cat((noise, label_input), -1)
        img = self.model(gen_input)
        return img.view(img.size(0), 1, self.img_size, self.img_size)

# Define Discriminator
class Discriminator(nn.Module):
    """Conditional convolutional discriminator.

    The class label is embedded into an ``img_size * img_size`` vector,
    reshaped into a one-channel map and stacked with the input image as a
    second channel. Four stride-2 convolutions follow, then a linear layer
    with a sigmoid produces a per-sample "real" probability.

    Args:
        num_classes: Number of conditioning classes.
        img_size: Side length of the square input images.
        bn_eps: Numerical-stability epsilon of the BatchNorm2d layers
            (PyTorch default 1e-5). Note that the second positional argument
            of ``nn.BatchNorm2d`` is ``eps``, not ``momentum``.
    """

    def __init__(self, num_classes, img_size, bn_eps=1e-5):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, img_size * img_size)

        # Layer order is kept fixed so state_dict keys (model.N.*) stay stable.
        self.model = nn.Sequential(
            nn.Conv2d(2, 16, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(16, 32, 3, 2, 1),
            nn.BatchNorm2d(32, eps=bn_eps),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, 3, 2, 1),
            nn.BatchNorm2d(64, eps=bn_eps),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # A 3x3 conv with stride 2 and padding 1 maps n -> ceil(n / 2).
        # Apply that once per conv so sizes that are not multiples of 16
        # (e.g. 28 -> 14 -> 7 -> 4 -> 2) get the correct flattened width.
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img, labels):
        """Score images.

        Args:
            img: Tensor of shape (N, 1, img_size, img_size).
            labels: Integer class indices of shape (N,).

        Returns:
            Tensor of shape (N, 1) with probabilities in [0, 1].
        """
        label_map = self.label_embedding(labels).view(labels.size(0), 1, img.size(2), img.size(3))
        d_in = torch.cat((img, label_map), 1)
        out = self.model(d_in)
        out = out.view(out.shape[0], -1)
        return self.adv_layer(out)

# Build the conditional generator / discriminator pair (generator first, then
# discriminator; parameter initialisation draws from the RNG in this order).
generator = Generator(latent_dim, num_classes, img_size)
discriminator = Discriminator(num_classes, img_size)

# Binary cross-entropy on the discriminator's sigmoid output, with a separate
# Adam optimizer (beta1 = 0.5) and learning rate for each network.
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=learning_rate_g,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=learning_rate_d,
    betas=(0.5, 0.999),
)

# Per-iteration loss history, appended to by the training loop.
g_losses, d_losses = [], []

# Function to plot generated images
def plot_generated_images(epoch, generator, latent_dim, num_classes):
    """Show a 4x4 grid of samples from ``generator`` with random class labels.

    Sampling runs in eval mode under ``torch.no_grad()`` so that plotting
    neither updates BatchNorm running statistics nor builds an autograd
    graph. The generator's previous train/eval mode is restored afterwards.

    Args:
        epoch: Zero-based epoch index; shown as ``epoch + 1`` in the title.
        generator: Conditional generator called as ``generator(z, labels)``;
            assumed to return (N, 1, H, W) Tanh-scaled images.
        latent_dim: Length of the noise vector.
        num_classes: Number of conditioning classes to sample labels from.
    """
    # Create inputs on the same device as the model's parameters.
    device = next(generator.parameters()).device
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            labels = torch.randint(0, num_classes, (16,), device=device)
            gen_imgs = generator(z, labels).cpu()
    finally:
        generator.train(was_training)

    fig, axs = plt.subplots(4, 4, figsize=(8, 8))
    for ax, img in zip(axs.flat, gen_imgs):
        # Fixed [-1, 1] range (the Tanh output range) keeps grids from
        # different epochs visually comparable instead of auto-scaled.
        ax.imshow(img[0], cmap='gray', vmin=-1, vmax=1)
        ax.axis('off')
    fig.suptitle(f'Generated Images at Epoch {epoch + 1}')
    plt.show()

# Training Loop
# Each iteration first updates the generator against the current
# discriminator, then updates the discriminator on the real batch and on the
# (detached) fake batch just generated.
for epoch in range(epochs):
    for real_imgs, labels in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        # Local name so the global `batch_size` hyperparameter is not clobbered.
        cur_batch = real_imgs.size(0)

        # The generator's BatchNorm1d cannot compute batch statistics from a
        # single sample in training mode, so skip a trailing batch of one.
        if cur_batch < 2:
            continue

        # One-sided label smoothing: only the discriminator's real targets are
        # softened to 0.9. Fakes target 0; the generator aims for a full 1.0.
        real_target = torch.full((cur_batch, 1), 0.9)
        fake_target = torch.zeros(cur_batch, 1)
        gen_target = torch.ones(cur_batch, 1)

        # --- Generator step ---
        optimizer_G.zero_grad()
        z = torch.randn(cur_batch, latent_dim)
        gen_labels = torch.randint(0, num_classes, (cur_batch,))
        gen_imgs = generator(z, gen_labels)
        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), gen_target)
        g_loss.backward()
        optimizer_G.step()

        # --- Discriminator step ---
        # zero_grad also discards the gradients the generator step pushed
        # into the discriminator's parameters.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs, labels), real_target)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake_target)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Record per-iteration losses for the final loss plot.
        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

    print(f"[Epoch {epoch + 1}/{epochs}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

    # Plot generated images every 5 epochs
    if (epoch + 1) % 5 == 0:
        plot_generated_images(epoch, generator, latent_dim, num_classes)

# Plot the per-iteration generator and discriminator loss curves.
plt.figure(figsize=(10, 5))
plt.plot(g_losses, label="Generator Loss")
plt.plot(d_losses, label="Discriminator Loss")
plt.gca().set(xlabel="Iterations", ylabel="Loss")
plt.legend()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import torch
from torchvision.transforms import ToPILImage

def plot_single_generated_image(generator, latent_dim, num_classes, label=None):
    """Sample one image from ``generator`` and display it as a 224x224 RGB picture.

    Args:
        generator: Conditional generator called as ``generator(z, labels)``.
            Its output is assumed to be Tanh-scaled to [-1, 1] with shape
            (N, C, H, W) — true for the ``Generator`` defined in the training
            cell; confirm for other models.
        latent_dim: Length of the noise vector.
        num_classes: Number of conditioning classes; used only when ``label``
            is None.
        label: Optional class index to condition on. When None (the default)
            a class is drawn uniformly at random.
    """
    device = next(generator.parameters()).device
    z = torch.randn(1, latent_dim, device=device)
    if label is None:
        labels = torch.randint(0, num_classes, (1,), device=device)
    else:
        labels = torch.tensor([label], dtype=torch.long, device=device)

    # Sample in eval mode: BatchNorm then uses its running statistics (so a
    # batch of one is valid) and does not update them. The caller's mode is
    # restored afterwards.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            single_img = generator(z, labels)[0].cpu()
    finally:
        generator.train(was_training)

    # ToPILImage multiplies float tensors by 255 and casts to uint8, so it
    # needs values in [0, 1]; map the Tanh range [-1, 1] onto that first.
    single_img = ((single_img + 1) / 2).clamp(0, 1)

    # Replicate a single grayscale channel into three RGB channels.
    if single_img.size(0) == 1:
        single_img = single_img.expand(3, -1, -1)

    # Resize to 224x224 for display.
    single_img = torch.nn.functional.interpolate(
        single_img.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False
    ).squeeze(0)

    img = ToPILImage()(single_img)

    plt.figure(figsize=(4, 4))
    plt.imshow(img)
    plt.axis('off')
    plt.title('Generated Image')
    plt.show()

# Example usage: reuse `latent_dim` and `num_classes` from the training cell
# instead of re-declaring them here, so the call always matches the
# configuration the generator was built with (and those globals are not
# silently overwritten).
plot_single_generated_image(generator, latent_dim, num_classes)
