# Import Necessary Libraries



In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import os



# Define Utility Functions for Noise Generation, Label Encoding, and Image Display

In [None]:


def generate_noise(batch_size, noise_dim, device='cpu'):
    """
    Sample a batch of latent vectors from a standard normal distribution.

    Args:
        batch_size: number of noise vectors (rows) to draw.
        noise_dim: length of each noise vector (columns).
        device: device the tensor is allocated on; defaults to 'cpu'.

    Returns:
        A tensor of shape (batch_size, noise_dim) produced by torch.randn.
    """
    shape = (batch_size, noise_dim)
    latent = torch.randn(shape, device=device)
    return latent

def one_hot_encode_labels(labels, num_classes):
    """
    Turn a tensor of integer class indices into one-hot vectors.

    Args:
        labels: integer tensor of class indices.
        num_classes: total number of classes (length of each one-hot vector).

    Returns:
        The result of F.one_hot: an integer tensor with one extra trailing
        dimension of size num_classes.
    """
    encoded = F.one_hot(labels, num_classes=num_classes)
    return encoded

def combine_vectors(x, y, dim=1):
    """
    Concatenate two tensors along `dim` after casting both to float.

    Both inputs are converted with .float() so that, for example, integer
    one-hot labels can be joined with floating-point noise.

    Args:
        x: first tensor.
        y: second tensor; must match x in every dimension except `dim`.
        dim: dimension to concatenate along. Defaults to 1, which joins
            per-sample vectors of shape (batch, features) feature-wise.

    Returns:
        A float32 tensor holding x followed by y along `dim`.
    """
    return torch.cat((x.float(), y.float()), dim=dim)

def show_tensor_images(image_tensor, num_images=25, size=(3, 32, 32), nrow=5, show=True):
    """
    Plot the first `num_images` images of a batch as a single grid.

    Args:
        image_tensor: batch of images; values are shifted by +1 and halved
            before plotting, i.e. a [-1, 1] input is mapped to [0, 1].
        num_images: maximum number of images taken from the front of the batch.
        size: accepted for interface compatibility; not used by this function.
        nrow: forwarded to make_grid as its `nrow` argument.
        show: when True, plt.show() is called after drawing the grid.
    """
    rescaled = (image_tensor + 1) / 2
    # Detach from autograd and move to host memory before handing off for plotting.
    batch = rescaled.detach().cpu()[:num_images]
    grid = make_grid(batch, nrow=nrow)
    # Move the first axis of the grid to the end before handing it to imshow.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    if show:
        plt.show()


# Define Generator Architecture for CIFAR-10 GAN


In [None]:
class CIFARGenerator(nn.Module):
    """
    Transposed-convolution generator that maps an input vector to an image.

    The input vector (length `input_dim`) is reshaped to input_dim x 1 x 1 and
    upsampled through five blocks to an `image_channels` x 32 x 32 output
    squashed by Tanh.
    """

    def __init__(self, input_dim=110, image_channels=3, hidden_dim=64):
        super().__init__()
        self.input_dim = input_dim
        # Spatial size after each block: 1 -> 4 -> 8 -> 16 -> 32 -> 32.
        stages = [
            self.make_gen_block(input_dim, hidden_dim * 8, kernel_size=4, stride=1, padding=0),
            self.make_gen_block(hidden_dim * 8, hidden_dim * 4, kernel_size=4, stride=2, padding=1),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=2, padding=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim, kernel_size=4, stride=2, padding=1),
            self.make_gen_block(hidden_dim, image_channels, kernel_size=3, stride=1, padding=1, final_layer=True),
        ]
        self.gen = nn.Sequential(*stages)

    def make_gen_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=1, final_layer=False):
        """
        Build one generator stage.

        Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the final
        stage is ConvTranspose2d -> Tanh with no normalization.
        """
        upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride, padding)
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """Reshape `noise` to (batch, input_dim, 1, 1) and run it through the stages."""
        batch = len(noise)
        seed_map = noise.view(batch, self.input_dim, 1, 1)
        return self.gen(seed_map)


# Define Discriminator Architecture for CIFAR-10 GAN


In [None]:
class CIFARDiscriminator(nn.Module):
    """
    Convolutional discriminator for label-conditioned 32 x 32 images.

    The network's first convolution takes `image_channels + num_classes`
    input channels. Four blocks reduce a 32 x 32 input to a single
    Sigmoid-activated score per sample.
    """

    def __init__(self, image_channels=3, num_classes=10, hidden_dim=64):
        super().__init__()
        in_channels = image_channels + num_classes
        # Spatial size after each block: 32 -> 16 -> 8 -> 4 -> 1.
        stages = [
            self.make_disc_block(in_channels, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, hidden_dim * 4),
            self.make_disc_block(hidden_dim * 4, 1, kernel_size=4, stride=1, padding=0, final_layer=True),
        ]
        self.disc = nn.Sequential(*stages)

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=1, final_layer=False):
        """
        Build one discriminator stage.

        Hidden stages are Conv2d -> BatchNorm2d -> LeakyReLU(0.2); the final
        stage is Conv2d -> Sigmoid with no normalization.
        """
        conv = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding)
        if final_layer:
            return nn.Sequential(conv, nn.Sigmoid())
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        """Score `image` and flatten the result to one row per sample."""
        scores = self.disc(image)
        return scores.view(len(scores), -1)


# Load CIFAR-10 Dataset and Define DataLoader


In [None]:
batch_size = 128

# Per-channel mean and std of 0.5: Normalize computes (x - 0.5) / 0.5 on each of the three channels.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
])

# Training split of CIFAR-10, downloaded into ./data when not already present.
cifar_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)

# Reshuffled every epoch; batches of `batch_size` samples.
data_loader = DataLoader(dataset=cifar_dataset, batch_size=batch_size, shuffle=True)


# Initialize Generator, Discriminator, and Optimizers


In [None]:
# Hyperparameters shared by the training and generation code below.
latent_dim = 100
num_classes = 10
learning_rate = 0.0002

# Use the GPU when one is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# The generator's input is the noise vector concatenated with a one-hot label,
# hence latent_dim + num_classes input features.
generator = CIFARGenerator(input_dim=latent_dim + num_classes, image_channels=3)
generator = generator.to(device)
discriminator = CIFARDiscriminator(image_channels=3, num_classes=num_classes)
discriminator = discriminator.to(device)

# Both networks use Adam with identical settings.
adam_settings = dict(lr=learning_rate, betas=(0.5, 0.999))
gen_optimizer = torch.optim.Adam(generator.parameters(), **adam_settings)
disc_optimizer = torch.optim.Adam(discriminator.parameters(), **adam_settings)


# Train the GAN Model or Generate Images Based on User Input

In [None]:
# Entry point: ask whether to train the conditional GAN or sample from a saved generator.
action = input("Enter 'train' to train the model or 'generate' to generate images: ").strip().lower()

if action == 'train':
    epochs = 200

    # BCELoss takes probabilities; the discriminator's final block ends in a Sigmoid.
    criterion = nn.BCELoss()

    # A previous 'generate' run in the same session leaves the generator in eval mode;
    # put both networks back in training mode before optimizing.
    generator.train()
    discriminator.train()

    print("Starting training from scratch.")

    for epoch in range(0, epochs):
        epoch_num = epoch + 1  # To display epochs starting from 1
        progress_bar = tqdm(data_loader, desc=f"Epoch {epoch_num}/{epochs}")
        for real, labels in progress_bar:
            cur_batch_size = len(real)
            real = real.to(device)
            labels = labels.to(device)

            # Targets for the BCE loss: ones for "real", zeros for "fake"
            real_labels = torch.ones(cur_batch_size, 1, device=device)
            fake_labels = torch.zeros(cur_batch_size, 1, device=device)

            # Encode the class labels once per batch and reuse the result:
            #   label_vectors  -> (batch, num_classes), appended to the generator's noise
            #   one_hot_labels -> (batch, num_classes, H, W), appended to images as extra channels
            label_vectors = one_hot_encode_labels(labels, num_classes).float()
            one_hot_labels = label_vectors[:, :, None, None].repeat(1, 1, real.shape[2], real.shape[3])

            ### Update Discriminator ###
            disc_optimizer.zero_grad()

            # Real images, conditioned on their labels, should be scored as real
            real_input = torch.cat((real, one_hot_labels), dim=1)
            real_output = discriminator(real_input)
            real_loss = criterion(real_output, real_labels)

            # Generate fake images for the same labels
            noise = generate_noise(cur_batch_size, latent_dim, device=device)
            noise_and_labels = combine_vectors(noise, label_vectors)
            fake_images = generator(noise_and_labels)

            # detach() keeps the discriminator step from backpropagating into the generator
            fake_input = torch.cat((fake_images.detach(), one_hot_labels), dim=1)
            fake_output = discriminator(fake_input)
            fake_loss = criterion(fake_output, fake_labels)

            # Total discriminator loss
            disc_loss = real_loss + fake_loss
            disc_loss.backward()
            disc_optimizer.step()

            ### Update Generator ###
            gen_optimizer.zero_grad()

            # Draw a fresh batch of fakes for the generator step
            noise = generate_noise(cur_batch_size, latent_dim, device=device)
            noise_and_labels = combine_vectors(noise, label_vectors)
            fake_images = generator(noise_and_labels)

            # The generator is rewarded when the discriminator scores its fakes as real
            fake_input = torch.cat((fake_images, one_hot_labels), dim=1)
            fake_output = discriminator(fake_input)
            gen_loss = criterion(fake_output, real_labels)
            gen_loss.backward()
            gen_optimizer.step()

            # Show the most recent batch losses on the progress bar
            progress_bar.set_postfix({
                'D_loss': f'{disc_loss.item():.4f}',
                'G_loss': f'{gen_loss.item():.4f}'
            })

        # Print the losses of the epoch's last batch
        print(f"Epoch [{epoch_num}/{epochs}] Discriminator Loss: {disc_loss.item():.4f} Generator Loss: {gen_loss.item():.4f}")

        # Show real and fake images for the first 3 epochs and then every 10 epochs
        if epoch_num <= 3 or epoch_num % 10 == 0:
            print("Real Images:")
            show_tensor_images(real, num_images=25, size=(3, 32, 32))
            print("Fake Images:")
            show_tensor_images(fake_images, num_images=25, size=(3, 32, 32))

        # Save checkpoint after every 10 epochs
        if epoch_num % 10 == 0:
            generator_file = f'cifar_generator_epoch_{epoch_num}.pth'
            discriminator_file = f'cifar_discriminator_epoch_{epoch_num}.pth'
            torch.save(generator.state_dict(), generator_file)
            torch.save(discriminator.state_dict(), discriminator_file)
            print(f"Weights saved for generator and discriminator at epoch {epoch_num}.")

elif action == 'generate':
    # Generate images from a previously saved generator checkpoint
    epoch = input("Enter the epoch number for the saved generator weights (e.g., '10', '20', '30', etc.): ").strip()
    weight_file = f'cifar_generator_epoch_{epoch}.pth'

    if os.path.exists(weight_file):
        # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
        generator.load_state_dict(torch.load(weight_file, map_location=device))
        generator.eval()

        # Ask the user to input which object/class they want to generate
        print("Enter the number corresponding to the object you want to generate:")
        print("0: Airplane, 1: Automobile, 2: Bird, 3: Cat, 4: Deer")
        print("5: Dog, 6: Frog, 7: Horse, 8: Ship, 9: Truck")

        raw_class = input("Enter a number between 0 and 9: ").strip()
        try:
            class_number = int(raw_class)
        except ValueError:
            # Non-numeric input takes the same "invalid" path as an out-of-range number
            class_number = None

        if class_number is None or class_number < 0 or class_number >= num_classes:
            print("Invalid class number. Please enter a number between 0 and 9.")
        else:
            # Generate 25 images for the selected class; no gradients are needed for sampling
            noise = generate_noise(25, latent_dim, device=device)
            labels = torch.full((25,), class_number, device=device, dtype=torch.long)
            one_hot_labels = one_hot_encode_labels(labels, num_classes).to(device)
            noise_and_labels = combine_vectors(noise, one_hot_labels)
            with torch.no_grad():
                fake_images = generator(noise_and_labels)

            # Show generated images
            print(f"Generated Images for class {class_number}:")
            show_tensor_images(fake_images, num_images=25, size=(3, 32, 32))
    else:
        print(f"Weight file '{weight_file}' not found. Please train the model first.")
else:
    print("Invalid input. Please enter 'train' or 'generate'.")
