<a href="https://colab.research.google.com/github/arsh-kamal/Artextract-GSOC-2025/blob/main/assignment7.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Initialization, utilities (no TODOs)

In [None]:
import torch
import torchvision
import torch.nn as nn
import argparse
import PIL
import random

In [None]:
def to_list(img):
    """Flatten a 784-element tensor into a plain Python list of ints.

    The tensor is viewed as a 1-D vector of 28*28 entries and every value is
    converted with ``int`` (i.e. truncated towards zero).
    """
    flat_values = img.view(28 * 28).tolist()
    return [int(value) for value in flat_values]

# Scaling modes accepted by show_image (see its docstring for their meaning).
SCALE_OFF, SCALE_RANGE, SCALE_01 = range(3)


def show_image(tens, imgname=None, scale=SCALE_01):
    """
    Show or save a 28x28 grayscale image contained in a tensor.

    The tensor is flattened by to_list, so any shape works as long as it has
    the required 28*28 = 784 entries.

    Parameters:
        tens: tensor holding the 784 pixel values.
        imgname: if given, the image is saved to this file name; if None, the
            image is displayed via PIL's Image.show().
        scale: one of three scaling modes:
            SCALE_OFF: No scaling is performed, the data is expected to use
                values between 0 and 255.
            SCALE_RANGE: The data is rescaled linearly so that its lowest
                value becomes 0 and its highest becomes 255. Useful for data
                in an unknown/arbitrary range. A constant image (highest ==
                lowest) is rendered as all zeros instead of dividing by zero.
            SCALE_01: The data is rescaled from the range 0..1 to the range
                0..255. Useful if the data was normalized into that range.
    """
    # `import PIL` alone does not load the Image submodule; import it
    # explicitly so this function does not depend on another library having
    # loaded it as a side effect.
    import PIL.Image

    img = PIL.Image.new("L", (28, 28))
    scaled = tens
    if scale == SCALE_RANGE:
        lowest = tens.min()
        value_range = tens.max() - lowest
        if value_range == 0:
            # All pixels are equal: 0/0 would yield NaN, which cannot be
            # converted to int when the pixel list is built.
            scaled = torch.zeros_like(tens)
        else:
            scaled = (tens - lowest) * 255 / value_range
    elif scale == SCALE_01:
        scaled = tens * 255
    img.putdata(to_list(scaled))
    if imgname is None:
        img.show()
    else:
        img.save(imgname)




# Classification (5 TODOs)

In [None]:
# Binary cross-entropy loss, shared by the classifier and the GAN training code
loss_fn = nn.BCELoss()

# TODO 1: Choose a digit (the one to recognise / generate)
digit = 7

# TODO 2: Change number of training iterations for classifier
n0 = 10

In [None]:
# TODO 3
# Change Network architecture of the discriminator/classifier network. It should have 784 inputs and 1 output (0 = fake, 1 = real)
class Discriminator(nn.Module):
    """Fully connected network mapping 784 inputs to one sigmoid output.

    Layer widths are 784 -> 256 -> 128 -> 1, with LeakyReLU(0.2) after each
    hidden layer and a Sigmoid on the output so the result lies in [0, 1].
    """

    def __init__(self):
        super().__init__()
        negative_slope = 0.2
        stack = []
        # Hidden layers: a linear map followed by a leaky ReLU each.
        for fan_in, fan_out in ((784, 256), (256, 128)):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(negative_slope))
        # Output layer: a single unit squashed into [0, 1].
        stack.append(nn.Linear(128, 1))
        stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run x through the layer stack and return the sigmoid output."""
        score = self.model(x)
        return score

In [None]:
# TODO 4
# Implement training loop for the classifier:
# for i in range(n0):
#     zero gradients
#     calculate predictions for given x
#     calculate loss, comparing the predictions with the given y
#     calculate the gradient (loss.backward())
#     print i and the loss
#     perform an optimizer step
def train_classifier(opt, model, x, y):
    """Run n0 full-batch optimisation steps of model on (x, y).

    Each iteration clears the gradients, computes the loss of model(x)
    against y with loss_fn, backpropagates, prints the iteration index and
    loss, and then lets opt update the parameters.
    """
    model.train()
    for i in range(n0):
        opt.zero_grad()
        predictions = model(x)
        loss = loss_fn(predictions, y)
        loss.backward()
        # Report the loss computed before this iteration's parameter update.
        message = f"Iteration {i}, Loss: {loss.item():.4f}"
        print(message)
        opt.step()

In [None]:
# TODO 5
# Instantiate the network and the optimizer
# call train_classifier with the training set
# Calculate metrics on the validation set
# Example:
#      y_pred = net(x_validation[labels_validation == 3]) calculates all predictions for all images we know to be 3s
#      (y_pred > 0.5) is a tensor that tells you if a given image was classified as your chosen digit (True) or not (False)
#      You can convert this tensor to 0s and 1s by calling .float()
#      (y_pred > 0.5).sum() will tell you how many of these predictions were true
# You are supposed to calculate:
#     For each digit from 0 to 9, the percentage of images of that digit that were predicted as your chosen digit
#     The percentage of digits that were classified correctly (i.e. that were your digit and predicted as such, or were another digit and not predicted as your digit)
#     This last value (accuracy) should be over 90% (preferably over 98%; precision and recall may be lower than that, 90-93% would be decent values)
#     Precision (which percentage of images identified as your chosen digit was actually that digit: TP/(TP+FP))
#     Recall (which percentage of your chosen digit was identified as such: TP/(TP+FN))
def classify(x_train, y_train, x_validation, labels_validation):
    """Train a binary classifier for the chosen digit and report validation metrics.

    A fresh Discriminator is trained with train_classifier on
    (x_train, y_train). It is then evaluated on x_validation, where a
    prediction above 0.5 counts as "is the chosen digit".

    Parameters:
        x_train: training inputs passed to the network.
        y_train: training targets passed to the loss function.
        x_validation: validation inputs, one row per image.
        labels_validation: true digit label (0-9) of each validation image.

    Printed output:
        accuracy, precision (TP/(TP+FP)) and recall (TP/(TP+FN)) in percent,
        followed by, for every digit 0-9, the percentage of images of that
        digit that were predicted as the chosen digit.

    Side effects:
        Saves up to 5 misclassified images per digit via show_image, as
        "fn_digit_<d>_<k>.png" for missed images of the chosen digit and
        "fp_digit_<d>_<k>.png" for other digits mistaken for it.
    """
    # Instantiate network and optimizer
    net = Discriminator()
    opt = torch.optim.Adam(net.parameters(), lr=0.01)

    # Train the classifier
    train_classifier(opt, net, x_train, y_train)

    # Evaluate on validation set
    net.eval()
    with torch.no_grad():
        # One boolean per validation image: True if predicted as the chosen digit
        predicted_positive = net(x_validation)[:, 0] > 0.5

    TP = FP = TN = FN = 0
    per_digit_percent = {}

    # NOTE: the loop variable must not be called `digit`; that would shadow
    # the module-level chosen digit and make the comparison below always true.
    for d in range(10):
        mask = labels_validation == d
        total_d = mask.sum().item()
        positives = mask & predicted_positive
        n_positive = positives.sum().item()
        n_negative = total_d - n_positive

        # Share of this digit's images that were predicted as the chosen digit
        per_digit_percent[d] = (n_positive / total_d) * 100 if total_d > 0 else 0

        if d == digit:
            # Images of the chosen digit: positives are correct, negatives are misses
            TP, FN = n_positive, n_negative
            misclassified = mask & ~predicted_positive
            prefix = "fn"
        else:
            # Images of other digits: positives are false alarms
            FP += n_positive
            TN += n_negative
            misclassified = positives
            prefix = "fp"

        # Save up to 5 misclassified examples for inspection
        for idx, img_idx in enumerate(torch.where(misclassified)[0][:5]):
            show_image(x_validation[img_idx], f"{prefix}_digit_{d}_{idx}.png", scale=SCALE_01)

    # Calculate metrics (guarding against empty denominators)
    total = TP + TN + FP + FN
    accuracy = (TP + TN) / total * 100 if total > 0 else 0
    precision = TP / (TP + FP) * 100 if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) * 100 if (TP + FN) > 0 else 0

    # Print results
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Precision: {precision:.2f}%")
    print(f"Recall: {recall:.2f}%")
    print("Percentage classified as chosen digit ({}):".format(digit))
    for d, percent in per_digit_percent.items():
        print(f"Digit {d}: {percent:.2f}%")

# GAN (5 TODOs)

In [None]:
# TODO 6: Change number of total training iterations for GAN, for the discriminator and for the generator
# Outer GAN rounds; each round trains the discriminator and then the generator
n = 10
# Discriminator optimiser steps per GAN round
n1 = 50
# Generator optimiser steps per GAN round
n2 = 50

In [None]:
# TODO 7
# Change Network architecture of the generator network. It should have 100 inputs (will be random numbers) and 784 outputs (one for each pixel, each between 0 and 1)
class Generator(nn.Module):
    """Fully connected network mapping 100 inputs to 784 sigmoid outputs.

    Layer widths are 100 -> 128 -> 256 -> 784, with LeakyReLU(0.2) after each
    hidden layer and a Sigmoid on the output so every value lies in [0, 1].
    """

    def __init__(self):
        super().__init__()
        negative_slope = 0.2
        stack = []
        # Hidden layers: a linear map followed by a leaky ReLU each.
        for fan_in, fan_out in ((100, 128), (128, 256)):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(negative_slope))
        # Output layer: one unit per pixel, squashed into [0, 1].
        stack.append(nn.Linear(256, 784))
        stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run x through the layer stack and return the 784 outputs per row."""
        pixels = self.model(x)
        return pixels

In [None]:
# TODO 8
# Implement training loop for the discriminator, given real and fake data:
# for i in range(n1):
#     zero gradients
#     calculate predictions for the x known as real
#     calculate loss, comparing the predictions with a tensor consisting of 1s (we want all of these samples to be classified as real)
#     calculate the gradient (loss_true.backward())
#     calculate predictions for the x known as fake
#     calculate loss, comparing the predictions with a tensor consisting of 0s (we want all of these samples to be classified as fake)
#     calculate the gradient (loss_false.backward())
#     print i and both of the loss values
#     perform an optimizer step
def train_discriminator(opt, discriminator, x_true, x_false):
    """Run n1 optimisation steps of the discriminator on real and fake data.

    In every iteration the loss on x_true (target: all ones) and the loss on
    x_false (target: all zeros) are backpropagated separately, so their
    gradients accumulate before the single optimizer step. Both loss values
    are printed together with the iteration index.
    """
    print("Training discriminator")
    discriminator.train()
    for i in range(n1):
        opt.zero_grad()

        # Real samples should be scored as 1
        real_scores = discriminator(x_true)
        real_targets = torch.ones_like(real_scores)
        loss_true = loss_fn(real_scores, real_targets)
        loss_true.backward()

        # Fake samples should be scored as 0
        fake_scores = discriminator(x_false)
        fake_targets = torch.zeros_like(fake_scores)
        loss_false = loss_fn(fake_scores, fake_targets)
        loss_false.backward()

        report = f"Iteration {i}, Real Loss: {loss_true.item():.4f}, Fake Loss: {loss_false.item():.4f}"
        print(report)
        opt.step()

In [None]:
# TODO 9
# Implement training loop for the generator:
# for i in range(n2):
#     zero gradients
#     generate some random inputs
#     calculate generated images by passing these inputs to the generator
#     pass the generated images to the discriminator to predict if they are true or fake
#     calculate the loss, comparing the predictions with a tensor of 1s (the *generator* wants the discriminator to classify its images as real)
#     calculate the gradient (loss.backward())
#     print i and the loss
#     perform an optimization step
def train_generator(opt, generator, discriminator):
    """Run n2 optimisation steps of the generator against the discriminator.

    Each iteration draws a (100, 100) batch of standard-normal noise, turns
    it into images with the generator, scores them with the discriminator and
    minimises the loss against an all-ones target (the generator wants its
    images to be classified as real). Only opt is stepped here.
    """
    print("Training generator")
    generator.train()
    for i in range(n2):
        opt.zero_grad()

        # Fresh latent vectors for every step
        latent = torch.randn(100, 100)
        generated = generator(latent)

        # The generator is rewarded when the discriminator outputs 1
        scores = discriminator(generated)
        wanted = torch.ones_like(scores)
        loss = loss_fn(scores, wanted)

        loss.backward()
        report = f"Iteration {i}, Loss: {loss.item():.4f}"
        print(report)
        opt.step()

In [None]:
# TODO 10
# Implement GAN training loop:
# Generate some random images (with torch.rand) as an initial collection of fakes
# Instantiate the two networks and two optimizers (one for each network!)
# for i in range(n):
#    call train_discriminator with the given real images and the collection of fake images
#    call train_generator
#    generate some images with the current generator, and add a random selection of old fake images (e.g. 100 random old ones, and 100 new ones = 200 in total)
#    this will be your new collection of fake images
#    save some of the current fake images to a file (use a filename like "sample_%d_%d.png"%(i,j) so you have some samples from each iteration so you can see if the network improves)
# If you read the todos above, your training code will print the loss in each iteration. The loss for the discriminator and the generator should decrease each time their respective training functions are called
# The images should start to look like numbers after just a few (could be after 1 or 2 already, or 3-10) iterations of *this* loop
def gan(x_real):
    """Train a Generator/Discriminator pair on x_real for n rounds.

    A pool of 100 fake images is kept, starting as uniform random noise.
    Every round trains the discriminator on x_real versus the pool, trains
    the generator, then rebuilds the pool from 50 randomly chosen old fakes
    plus 50 newly generated ones. Five new fakes per round are saved as
    "sample_<round>_<k>.png"; the first real image is saved as "train_0.png"
    once training has finished.
    """
    # Initial pool of fakes: pure noise
    fake_pool = torch.rand((100, 784))

    # One optimizer per network
    generator = Generator()
    discriminator = Discriminator()
    opt_g = torch.optim.Adam(generator.parameters(), lr=0.001)
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=0.001)

    for i in range(n):
        train_discriminator(opt_d, discriminator, x_real, fake_pool.detach())
        train_generator(opt_g, generator, discriminator)

        # Sample the current generator without tracking gradients
        with torch.no_grad():
            latent = torch.randn(100, 100)
            new_fakes = generator(latent)

        # Keep 50 random old fakes and append 50 of the new ones
        keep = torch.randperm(fake_pool.size(0))[:50]
        survivors = fake_pool[keep]
        fake_pool = torch.cat([survivors, new_fakes[:50]], dim=0)

        # Save 5 samples of this round so progress can be inspected
        for j, sample in enumerate(new_fakes[:5]):
            show_image(sample, f"sample_{i}_{j}.png", scale=SCALE_01)

    show_image(x_real[0], "train_0.png", scale=SCALE_01)

# Main (no TODOs)

In [None]:
def main(rungan):
    """
    Entry point; this function does not need to be changed.

    It loads MNIST from the current directory (downloading the data set when
    the training split is requested), flattens every image to 784 values and
    divides by 255 so all pixels lie between 0 and 1. It then either

        runs gan() on the training images of the chosen digit (rungan true), or
        runs classify() with training labels that are 1 for the chosen digit
        and 0 for every other digit (rungan false).
    """
    def flatten_and_normalize(images):
        # uint8 image stack -> float rows of 28*28 values in [0, 1]
        return images.float().view(-1, 28 * 28) / 255.0

    train_set = torchvision.datasets.MNIST(".", download=True)
    x_train = flatten_and_normalize(train_set.data)
    labels_train = train_set.targets
    is_chosen_digit = labels_train == digit
    y_train = is_chosen_digit.float().view(-1, 1)

    validation_set = torchvision.datasets.MNIST(".", train=False)
    x_validation = flatten_and_normalize(validation_set.data)
    labels_validation = validation_set.targets

    if not rungan:
        classify(x_train, y_train, x_validation, labels_validation)
    else:
        gan(x_train[labels_train == digit])

# Test call (TODO: TEST)

In [None]:
# NOTE: This will not work until you have done TODO 1 above!
# Without TODO 1 the call fails with: AttributeError: 'bool' object has no attribute 'float'
# Set GAN to True to train the GAN, or to False to train and evaluate the classifier.
GAN = False
main(rungan=GAN)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 114836390.27it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 33061051.81it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 36478529.21it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4438613.41it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

