<a href="https://colab.research.google.com/github/HemaGarima/Machine-Learning/blob/master/Generative_Teaching_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### GTN - Generative Teaching Network

### Learning Objectives
By the end of this notebook, I should:
1. Understand the concepts of teaching networks, meta-learning, and neural architecture search, and how they relate to the objective of data augmentation.
2. Implement and train a GTN on MNIST, and observe how a GTN can accelerate training.
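The teaching/meta-learning idea behind a GTN can be written as a bilevel optimization. The equations below are a simplified sketch in my own notation, not the paper's exact formulation. The generator $G$ has parameters $\phi$ and turns noise $z_t$ and a target label $\tilde{y}_t$ into a synthetic batch. The learner $f$ (the student classifier) has parameters $\theta$ and takes $T$ inner SGD steps on that synthetic data with learning rate $\alpha$; in this notebook, $T$ corresponds to `inner_loop_iterations`. The generator is then updated with a step of size $\beta$ to reduce the trained learner's loss on real data:

$$
\begin{aligned}
\tilde{x}_t &= G(z_t, \tilde{y}_t;\, \phi) \\
\theta_{t+1} &= \theta_t - \alpha \, \nabla_{\theta}\, \mathcal{L}\big(f(\tilde{x}_t;\, \theta_t),\, \tilde{y}_t\big), \qquad t = 0, \dots, T-1 \\
\phi &\leftarrow \phi - \beta \, \nabla_{\phi}\, \mathcal{L}\big(f(x;\, \theta_T),\, y\big), \qquad (x, y) \sim \text{real data}
\end{aligned}
$$

Because $\theta_T$ depends on $\phi$ through every inner step, computing $\nabla_{\phi}$ requires backpropagating through the whole unrolled inner loop.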

In [1]:
import os
import sys
import math
import random

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.autograd import grad

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable

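# sys.modules only lists already-imported modules, so test the import directly
try:
  import higher
except ImportError:  # install on first run, e.g. in a fresh Colab runtime
  !pip install higher
  import higher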

print(sys.version)
print(torch.__version__)

Collecting higher
  Downloading higher-0.2.1-py3-none-any.whl.metadata (10 kB)
Downloading higher-0.2.1-py3-none-any.whl (27 kB)
Installing collected packages: higher
Successfully installed higher-0.2.1
3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
2.4.1+cu121
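GTN training needs the gradient of the outer (real-data) loss with respect to the generator. That means differentiating through the learner's inner SGD steps, which is what the `higher` library imported above provides. The sketch below illustrates the same mechanism without `higher` by unrolling a few SGD steps by hand with `torch.autograd.grad(..., create_graph=True)`. Everything in it is a toy assumption: a learnable batch `synthetic_x` stands in for the generator's output, a bias-free linear learner `w` stands in for the classifier, and the data sizes and learning rates are made up.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy "real" data: 4 features, label = argmax of the first 3 features (assumption)
n_features, n_classes = 4, 3
real_x = torch.randn(64, n_features)
real_y = real_x[:, :n_classes].argmax(dim=1)

# Learnable synthetic batch (stand-in for generator output) with fixed labels
synthetic_x = torch.randn(n_classes * 2, n_features, requires_grad=True)
synthetic_y = torch.arange(n_classes).repeat(2)

meta_opt = torch.optim.Adam([synthetic_x], lr=0.1)
inner_lr, inner_steps = 0.5, 5

def meta_step():
    # Fresh learner weights every outer iteration
    w = torch.zeros(n_features, n_classes, requires_grad=True)
    for _ in range(inner_steps):
        inner_loss = F.cross_entropy(synthetic_x @ w, synthetic_y)
        # create_graph=True keeps this gradient differentiable
        (g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
        w = w - inner_lr * g  # differentiable SGD step
    # Outer loss: how well the trained learner does on real data
    meta_loss = F.cross_entropy(real_x @ w, real_y)
    meta_opt.zero_grad()
    meta_loss.backward()  # flows back through every inner step into synthetic_x
    meta_opt.step()
    return meta_loss.item()

losses = [meta_step() for _ in range(50)]
print(f"meta-loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

In the notebook itself, `higher.innerloop_ctx` takes over the role of the hand-written `w = w - inner_lr * g` update. It tracks the learner's weights across steps so that the outer `backward()` reaches the generator's parameters.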


In [2]:
# Set random seeds
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

# Set important parameters
learning_rate = 1e-2
inner_loop_iterations = 32
outer_loop_iterations = 5
num_classes = 10

noise_size = 64 # size of the noise (curriculum) vector fed to the generator
img_size = 28 # width/height of generated image

inner_loop_batch_size = 128
outer_loop_batch_size = 128

mnist_mean = 0.1307 # for normalizing mnist images
mnist_std =  0.3081 # for normalizing mnist images

imgs_per_row = num_classes

## Dataset
- Download the MNIST dataset and organize it into a torch.utils.data.Dataset object. Then apply torchvision.transforms to convert the raw PIL images to normalized tensors.

In [4]:
# Initialize MNIST transforms
transform = transforms.Compose([
    transforms.Lambda(lambda x: np.array(x)),
    transforms.ToTensor(),
    transforms.Normalize((mnist_mean,),(mnist_std,)),
])

# Create data splits
train = datasets.MNIST('./data', train=True, transform=transform, download=True)
train, val = torch.utils.data.random_split(train, [50000, 10000])
test = datasets.MNIST('./data', train=False, transform=transform, download=True)
print('Created train , val , and test datasets.')


Created train , val , and test datasets.
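`ToTensor` scales uint8 pixels from [0, 255] into [0, 1], and `Normalize((mnist_mean,), (mnist_std,))` then computes (x - mean) / std. The worked check below runs in plain Python and shows where MNIST's extreme pixel values end up after the two transforms. The helper name `to_normalized` is hypothetical and only mimics the two transforms for a single pixel.

```python
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # same constants as the parameter cell

def to_normalized(pixel: int) -> float:
    """Mimic ToTensor (uint8 -> [0, 1]) followed by Normalize((mean,), (std,))."""
    scaled = pixel / 255.0
    return (scaled - MNIST_MEAN) / MNIST_STD

print(round(to_normalized(0), 4))    # background pixel  -> -0.4242
print(round(to_normalized(255), 4))  # full stroke pixel -> 2.8215
```

So the normalized real images lie roughly in [-0.42, 2.82] and are not centered exactly on 0 per pixel. Only the dataset-wide mean and standard deviation are standardized.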


## Dataloader
- Now wrap each dataset in a torch.utils.data.DataLoader, which iterates over shuffled batches during training. With `num_workers` it loads batches in background worker processes, and `pin_memory=True` speeds up host-to-GPU copies, so fetching images won't bottleneck training. MNIST images are small, so the speedup here should be modest.

In [5]:
# drop_last=True discards the final partial batch, so every batch has exactly outer_loop_batch_size samples
train_loader = torch.utils.data.DataLoader(
    train, batch_size=outer_loop_batch_size, shuffle=True, drop_last=True, num_workers=1, pin_memory=True,
)

val_loader = torch.utils.data.DataLoader(
    val, batch_size=outer_loop_batch_size, shuffle=True, drop_last=True, num_workers=1, pin_memory=True,
)

test_loader = torch.utils.data.DataLoader(
    test, batch_size=outer_loop_batch_size, shuffle=True, drop_last=True, num_workers=1, pin_memory=True,
)
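The effect of `drop_last` is easy to check on a small stand-in dataset. The sketch below uses a hypothetical toy `TensorDataset` of 1,000 blank 28x28 images in place of the MNIST splits. `len(loader)` reports the number of batches per epoch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for an MNIST split: 1,000 blank 1x28x28 images with dummy labels
toy = TensorDataset(torch.zeros(1000, 1, 28, 28), torch.zeros(1000, dtype=torch.long))

kept = DataLoader(toy, batch_size=128, shuffle=True, drop_last=True)
full = DataLoader(toy, batch_size=128, shuffle=True, drop_last=False)

print(len(kept), len(full))  # 7 8
last_batch_size = [x.shape[0] for x, _ in full][-1]
print(last_batch_size)       # 104, the short final batch that drop_last=True discards
```

With the real splits and a batch size of 128, `drop_last=True` gives 50000 // 128 = 390 training batches and 10000 // 128 = 78 validation or test batches per epoch.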

## MNIST classification
- In this next section, you'll implement and train a GTN for MNIST classification. Note that the student (learner) model for this task is a classifier. To extend GTNs to other datasets, you may also want to look at the weight normalization technique described in the paper; for MNIST, you don't need to worry about it. Alright, let's get started with the generator's and classifier's model architectures!
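For reference, weight normalization reparameterizes each weight row as $w = g \, v / \lVert v \rVert$. Its length is set by the scalar $g$ and its direction by $v$. The code below is a generic, hand-rolled sketch of the technique and not necessarily how the GTN paper applies it. The class name `WeightNormLinear` is hypothetical. Recent PyTorch versions also provide `torch.nn.utils.parametrizations.weight_norm`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightNormLinear(nn.Module):
    """Linear layer whose weight rows are reparameterized as g * v / ||v||."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.v = nn.Parameter(torch.randn(out_features, in_features) * 0.05)  # direction
        self.g = nn.Parameter(torch.ones(out_features))                        # per-row length
        self.bias = nn.Parameter(torch.zeros(out_features))

    def weight(self):
        # Rescale each row of v to unit L2 norm, then multiply by its g
        return self.g.unsqueeze(1) * self.v / self.v.norm(dim=1, keepdim=True)

    def forward(self, x):
        return F.linear(x, self.weight(), self.bias)

torch.manual_seed(0)
layer = WeightNormLinear(28 * 28, 10)  # flattened MNIST image -> 10 class logits
out = layer(torch.randn(4, 28 * 28))
print(out.shape)                    # torch.Size([4, 10])
print(layer.weight().norm(dim=1))   # each row norm equals its g (about 1.0 at init)
```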

## Generator
- Let's now build the generator. For this task, the generator consists of two fully connected blocks (each a fully connected layer, a leaky ReLU, and a batch normalization layer) and two convolutional blocks (each a convolutional layer, a batch normalization layer, and a leaky ReLU). A tanh is applied to the final output, squashing it into the range (-1, 1) so the generated pixels stay centered around 0 with a bounded spread.
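The block layout described above can be sketched as follows. It reuses `noise_size`, `num_classes`, and `img_size` from the parameter cell, and every other detail is an assumption. Conditioning on a one-hot label concatenated to the noise, the layer widths, the 0.1 leaky-ReLU slope, and using stride-2 transposed convolutions to upsample 7x7 to 28x28 are all my own choices, not the notebook's. The class name `GeneratorSketch` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

noise_size, num_classes, img_size = 64, 10, 28  # values from the parameter cell
assert img_size == 28, "this sketch upsamples 7x7 -> 14x14 -> 28x28"

class GeneratorSketch(nn.Module):
    """Hypothetical GTN generator: 2 FC blocks + 2 conv blocks + tanh."""

    def __init__(self, hidden=1024, channels=64):
        super().__init__()
        self.channels = channels
        # FC block = Linear -> LeakyReLU -> BatchNorm1d (order from the description)
        self.fc = nn.Sequential(
            nn.Linear(noise_size + num_classes, hidden),
            nn.LeakyReLU(0.1),
            nn.BatchNorm1d(hidden),
            nn.Linear(hidden, channels * 7 * 7),
            nn.LeakyReLU(0.1),
            nn.BatchNorm1d(channels * 7 * 7),
        )
        # Conv block = Conv -> BatchNorm2d -> LeakyReLU; transposed convs double H and W
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(channels, channels // 2, 4, stride=2, padding=1),  # 7 -> 14
            nn.BatchNorm2d(channels // 2),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(channels // 2, 1, 4, stride=2, padding=1),  # 14 -> 28
            nn.BatchNorm2d(1),
            nn.LeakyReLU(0.1),
        )

    def forward(self, noise, labels):
        one_hot = F.one_hot(labels, num_classes).float()  # condition on the target class
        h = self.fc(torch.cat([noise, one_hot], dim=1))
        h = h.view(-1, self.channels, 7, 7)
        return torch.tanh(self.conv(h))  # squash into (-1, 1)

torch.manual_seed(0)
gen = GeneratorSketch()
z = torch.randn(8, noise_size)
y = torch.arange(8) % num_classes
imgs = gen(z, y)
print(imgs.shape)  # torch.Size([8, 1, 28, 28])
```

Transposed convolutions are one common way to grow a small feature map to image size. Alternatively, the second FC layer could emit a full 28x28 map and ordinary same-padding convolutions could follow.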