# **Task 3**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torchvision.utils as vutils
import torch.nn.functional as F

# ---- Hyperparameters ----

# Data
batch_size = 64  # images per training batch
image_size = 32  # target height/width passed to transforms.Resize
nc = 3           # image channels (RGB)

# Model sizes
nz = 100         # length of the latent noise vector z fed to the generator
ngf = 64         # base number of generator feature maps
ndf = 64         # base number of discriminator feature maps

# Optimisation
num_epochs = 5
lr = 2e-4        # Adam learning rate for both networks
beta1 = 0.5      # Adam beta1 (momentum term) for both optimizers

# Define a Self-Attention layer
class SelfAttention(nn.Module):
    """SAGAN-style self-attention over the spatial positions of a feature map.

    Every spatial position attends to every position of the same map. The
    attended features are scaled by a learnable ``gamma`` (initialised to 0)
    and added back to the input, so a freshly built layer returns its input
    unchanged. The output has the same shape as the input.

    Args:
        in_dim: number of input channels. The query/key projections reduce
            this to ``in_dim // 8`` channels; the value projection keeps
            ``in_dim`` channels.
    """

    def __init__(self, in_dim):
        super().__init__()
        reduced = in_dim // 8
        # Construction order (query, key, value, gamma) is kept so that random
        # weight initialisation consumes the RNG in the same sequence.
        self.query_conv = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, c, h, w = x.size()
        positions = h * w

        # Project and flatten the spatial dims -> (n, channels, h*w).
        q = self.query_conv(x).view(n, -1, positions)
        k = self.key_conv(x).view(n, -1, positions)
        v = self.value_conv(x).view(n, -1, positions)

        # scores[b, i, j] = <q_i, k_j>; the softmax over j gives the weights
        # with which position i attends to position j.
        scores = torch.bmm(q.transpose(1, 2), k)  # (n, h*w, h*w)
        weights = F.softmax(scores, dim=-1)

        # attended[b, :, i] = sum_j weights[b, i, j] * v[b, :, j]
        attended = torch.bmm(v, weights.transpose(1, 2))  # (n, c, h*w)

        # Residual connection gated by the learnable gamma.
        return x + self.gamma * attended.view(n, c, h, w)

# Define the Generator
class Generator(nn.Module):
    """DCGAN-style generator with one self-attention layer.

    Maps a latent batch of shape ``(B, nz, 1, 1)`` to images of shape
    ``(B, nc, 32, 32)`` with values in ``(-1, 1)`` (``Tanh`` output).

    Spatial size after each transposed convolution: 1 -> 4 -> 8 -> 16 -> 32.
    Exactly four upsampling stages are used: a fifth stride-2 stage would
    produce 64x64 images, which would not match the 32x32 training images.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # (B, nz, 1, 1) -> (B, ngf*8, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (B, ngf*4, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            SelfAttention(ngf * 4),  # Self-attention on the 8x8 feature map
            # -> (B, ngf*2, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (B, nc, 32, 32), squashed to (-1, 1)
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate images from latent noise ``x`` of shape (B, nz, 1, 1)."""
        return self.main(x)

# Define the Discriminator
class Discriminator(nn.Module):
    """DCGAN-style discriminator with one self-attention layer.

    Takes images of shape ``(B, nc, H, W)`` and returns a ``(B, 1)`` tensor of
    probabilities (``Sigmoid`` output) that each image is real.

    For 32x32 inputs the spatial size goes 32 -> 16 -> 8 -> 4 -> 1. The
    adaptive average pool is a no-op there. For inputs larger than 32x32 it
    averages the final score map down to a single value.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            SelfAttention(ndf * 2),  # Self-attention layer
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # The output of this conv is the real/fake logit. It goes straight
            # to the Sigmoid: an activation in between (e.g. LeakyReLU) would
            # shrink negative logits by its slope and bias D(fake) toward 0.5.
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.AdaptiveAvgPool2d(1),  # Global average pooling
        )
        self.output = nn.Sigmoid()  # Output sigmoid for binary classification

    def forward(self, x):
        x = self.main(x)            # (B, 1, 1, 1) logits
        x = x.view(x.size(0), -1)   # Flatten to (B, 1)
        return self.output(x)

# Initialize the networks, on the GPU when one is available (the attention
# layers make CPU-only training of this model very slow).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG = Generator().to(device)
netD = Discriminator().to(device)

# Loss and Optimizer
# BCELoss expects probabilities, matching the discriminator's Sigmoid output.
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Load the dataset
# ToTensor yields values in [0, 1]. Normalize(0.5, 0.5) on every channel maps
# them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * nc, (0.5,) * nc),
])

dataset = dset.CIFAR10(root='./data', download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
num_batches = len(dataloader)

# A fixed latent batch makes the per-epoch sample grids comparable.
fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)
log_every = 100  # print losses every N iterations instead of every batch

# Training Loop
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        real_data = data[0].to(device)
        # Local name on purpose: the last batch of an epoch can be smaller
        # than `batch_size`. Overwriting the global hyperparameter would leak
        # that size into later code.
        b_size = real_data.size(0)

        # (1) Train Discriminator: real -> 1, fake -> 0
        netD.zero_grad()
        labels_real = torch.ones(b_size, device=device)
        output_real = netD(real_data).view(-1)  # Flatten the output
        loss_real = criterion(output_real, labels_real)
        loss_real.backward()

        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake_data = netG(noise)
        labels_fake = torch.zeros(b_size, device=device)
        # detach(): the discriminator update must not backprop into G.
        output_fake = netD(fake_data.detach()).view(-1)
        loss_fake = criterion(output_fake, labels_fake)
        loss_fake.backward()
        optimizerD.step()

        # (2) Train Generator: non-saturating loss, fakes labelled as real
        netG.zero_grad()
        labels_gen = torch.ones(b_size, device=device)
        output_gen = netD(fake_data).view(-1)
        lossG = criterion(output_gen, labels_gen)
        lossG.backward()
        optimizerG.step()

        if i % log_every == 0 or i == num_batches - 1:
            lossD = (loss_real + loss_fake).item()
            print(f'[{epoch}/{num_epochs}] [{i}/{num_batches}] '
                  f'Loss D: {lossD:.4f}, Loss G: {lossG.item():.4f}')

    # Save fake images every epoch, generated from the same fixed noise.
    with torch.no_grad():
        samples = netG(fixed_noise)
    vutils.save_image(samples, f'fake_images_epoch_{epoch}.png', normalize=True)



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 62461038.64it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
[0/5] [0/782] Loss D: 1.5021474361419678, Loss G: 0.6679111123085022
[0/5] [1/782] Loss D: 1.4392364025115967, Loss G: 0.6435126662254333
[0/5] [2/782] Loss D: 1.312483549118042, Loss G: 0.6115994453430176
[0/5] [3/782] Loss D: 1.262702465057373, Loss G: 0.6264382004737854
[0/5] [4/782] Loss D: 1.1299593448638916, Loss G: 0.6796182990074158
[0/5] [5/782] Loss D: 1.0349690914154053, Loss G: 0.7355411648750305
[0/5] [6/782] Loss D: 0.9575192332267761, Loss G: 0.7724143266677856
[0/5] [7/782] Loss D: 0.7916373014450073, Loss G: 0.8004042506217957
[0/5] [8/782] Loss D: 0.7652654647827148, Loss G: 0.8134892582893372
[0/5] [9/782] Loss D: 0.7289714813232422, Loss G: 0.812693178653717
[0/5] [10/782] Loss D: 0.7092154026031494, Loss G: 0.8099592328071594
[0/5] [11/782] Loss D: 0.7094257473945618, Loss G: 0.815449059009552
[0/5] [12/782] Loss D: 0.6988236308097839, Loss G: 0.8263753056526184
[0/5] [13/782] Loss D: 0.6836705207824707, Loss G: 0.