In [1]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from collections import OrderedDict
import torch.optim as optim

In [2]:
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import MNIST

# Define the transformation
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# Load the full MNIST dataset
full_dataset = MNIST(root=".", train=True, download=True, transform=transform)

# Create a subset with the first 1000 samples
subset_indices = range(500)
subset_dataset = Subset(full_dataset, subset_indices)

# Create a DataLoader for the subset
batch_size = 64
dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=True)

# Now you can use subset_dataloader to iterate over the subset of 1000 samples

In [3]:
for data, label in dataloader:
    print(data.shape, label)
    break

torch.Size([64, 1, 64, 64]) tensor([1, 2, 3, 0, 1, 4, 8, 0, 5, 2, 4, 5, 6, 1, 4, 6, 7, 7, 0, 1, 8, 5, 9, 8,
        5, 7, 0, 2, 2, 0, 1, 0, 1, 3, 4, 3, 2, 3, 6, 8, 0, 4, 2, 1, 3, 8, 2, 1,
        9, 3, 0, 5, 7, 9, 3, 9, 1, 0, 9, 5, 6, 6, 7, 6])


In [4]:
# Check the quantity of the dataset
total_data = 0

for image, label in dataloader:
    total_data += image.shape[0]

print("total quantity of the dataset # {} ".format(total_data))
print("The shape of the data # {} ".format(image.shape))

total quantity of the dataset # 500 
The shape of the data # torch.Size([52, 1, 64, 64]) 


In [5]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.PReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channels),
        )

    def forward(self, x):
        return x + self.block(x)


class Generator(nn.Module):
    def __init__(self, in_channels=1, num_residual_blocks=16):
        super(Generator, self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4), nn.PReLU()
        )

        self.residuals = nn.Sequential(
            *[ResidualBlock(64) for _ in range(num_residual_blocks)]
        )

        self.upsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(),
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(),
        )

        self.final = nn.Sequential(
            nn.Conv2d(64, in_channels, kernel_size=9, stride=1, padding=4), nn.Tanh()
        )

    def forward(self, x):
        initial = self.initial(x)
        residuals = self.residuals(initial)
        x = initial + residuals
        x = self.upsample(x)
        x = self.final(x)
        return x

In [6]:
if __name__ == "__main__":
    generator = Generator()

In [7]:
# Find out the total trainable parameters
total_params = 0
for params in generator.parameters():
    total_params+=params.numel()
    
print("Total # of trainable parameters # {}".format(total_params))

Total # of trainable parameters # 1491668


In [8]:
# Check it works or not
noise_data = torch.randn(64, 1, 64, 64)
generator(noise_data).shape

torch.Size([64, 1, 256, 256])

In [9]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=1):
        super(Discriminator, self).__init__()
        self.nf = 64
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, self.nf, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.nf, self.nf, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(self.nf),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.nf, self.nf * 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.nf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.nf * 2, self.nf * 2, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(self.nf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.nf * 2, self.nf * 4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.nf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.nf * 4, self.nf * 4, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(self.nf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.nf * 4, self.nf * 8, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.nf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.nf * 8, self.nf * 8, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(self.nf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(self.nf * 8, 1024, kernel_size=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(1024, 1, kernel_size=1),
        )

    def forward(self, x):
        x = self.main(x)
        return torch.sigmoid(x.view(x.size(0), -1))

In [10]:
if __name__ == '__main__':
    discriminator = Discriminator()
    

In [11]:
noise_data = torch.randn(64, 1, 64, 64)
discriminator(noise_data).shape

torch.Size([64, 1])

In [12]:
# Total trainable params
total_params = 0
for params in discriminator.parameters():
    total_params+=params.numel()
    
print("Total trainable params # {} ".format(total_params))

Total trainable params # 5214273 


In [13]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device

device(type='mps')

In [14]:
generator = Generator()
discriminator = Discriminator()

loss_function = nn.BCELoss()
optimizer_generator = optim.Adam(generator.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr = 0.0002, betas = (0.5, 0.999))

In [16]:
num_epochs = 10
for epoch in range(num_epochs):
    D_loss = list()
    G_loss = list()
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = torch.ones(imgs.size(0), 1, requires_grad=False)
        fake = torch.zeros(imgs.size(0), 1, requires_grad=False)

        # Configure input
        real_imgs = imgs

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_generator.zero_grad()

        # Generate a batch of images
        gen_imgs = generator(imgs)

        # Loss measures generator's ability to fool the discriminator
        g_loss = loss_function(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_generator.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_discriminator.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = loss_function(discriminator(real_imgs), valid)
        fake_loss = loss_function(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_discriminator.step()
        
        G_loss.append(g_loss.item())
        D_loss.append(d_loss.item())

        if i%400:
            print(
                f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]"
            )
    
    print("G_loss: {} - D_loss: {} ".format(np.array(G_loss).mean(), np.array(D_loss).mean()))
    

[Epoch 0/10] [Batch 1/8] [D loss: 0.6755176782608032] [G loss: 0.6713829040527344]
[Epoch 0/10] [Batch 2/8] [D loss: 0.6501513123512268] [G loss: 0.7601485848426819]
[Epoch 0/10] [Batch 3/8] [D loss: 0.6265161037445068] [G loss: 0.7938126921653748]
[Epoch 0/10] [Batch 4/8] [D loss: 0.5980849862098694] [G loss: 0.7848340272903442]
[Epoch 0/10] [Batch 5/8] [D loss: 0.5757236480712891] [G loss: 0.8298689723014832]
[Epoch 0/10] [Batch 6/8] [D loss: 0.5494447946548462] [G loss: 0.8864858150482178]
[Epoch 0/10] [Batch 7/8] [D loss: 0.5144287943840027] [G loss: 0.9198465943336487]
G_loss: 0.7973206043243408 - D_loss: 0.6106206402182579 
[Epoch 1/10] [Batch 1/8] [D loss: 0.456842303276062] [G loss: 1.0538866519927979]
