In [1]:
import torch
from datasets import load_dataset

dataset = load_dataset("mnist")

Reusing dataset mnist (C:\Users\eshaa\.cache\huggingface\datasets\mnist\mnist\1.0.0\fda16c03c4ecfb13f165ba7e29cf38129ce035011519968cdaf74894ce91c9d4)


  0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
dataset['train'], dataset['test']

(Dataset({
     features: ['image', 'label'],
     num_rows: 60000
 }),
 Dataset({
     features: ['image', 'label'],
     num_rows: 10000
 }))

## Transform MNIST Dataset into variable sized images :D

In [3]:
from PIL import Image
from tqdm import tqdm
import numpy as np

orig_x_train = []
orig_y_train = []

x_train = []
y_train = []

for value in tqdm(dataset['train'], desc="Resizing x_train dataset"):
    moreUpscaleImage = value['image'].resize((64, 64), resample=Image.BOX)
    upscaleImage = value['image'].resize((48, 48), resample=Image.BOX)
    regularImage = value['image']

    moreUpscaleImage = torch.from_numpy(np.uint8(moreUpscaleImage)).to(torch.float)
    upscaleImage = torch.from_numpy(np.uint8(upscaleImage)).to(torch.float)
    regularImage = torch.from_numpy(np.uint8(regularImage)).to(torch.float)

    x_train.extend([regularImage.flatten(), upscaleImage.flatten(), moreUpscaleImage.flatten()]) 
    y_train.extend([value['label'], value['label'], value['label']])

    orig_x_train.append(regularImage.unsqueeze(0))
    orig_y_train.append(value['label'])

orig_x_test = []
orig_y_test = []

x_test = []
y_test = []

for value in tqdm(dataset['test'], desc="Resizing x_test dataset"):
    moreUpscaleImage = value['image'].resize((64, 64), resample=Image.BOX)
    upscaleImage = value['image'].resize((48, 48), resample=Image.BOX)
    regularImage = value['image']

    moreUpscaleImage = torch.from_numpy(np.uint8(moreUpscaleImage)).to(torch.float)
    upscaleImage = torch.from_numpy(np.uint8(upscaleImage)).to(torch.float)
    regularImage = torch.from_numpy(np.uint8(regularImage)).to(torch.float)

    x_test.extend([regularImage.flatten(), upscaleImage.flatten(), moreUpscaleImage.flatten()]) 
    y_test.extend([value['label'], value['label'], value['label']])

    orig_x_test.append(regularImage.unsqueeze(0))
    orig_y_test.append(value['label'])

print(f"x_train shape: {x_train[0].shape} orig_x_train shape: {orig_x_train[0].shape}")

Resizing x_train dataset: 100%|██████████| 60000/60000 [00:22<00:00, 2651.18it/s]
Resizing x_test dataset: 100%|██████████| 10000/10000 [00:03<00:00, 2742.76it/s]

x_train shape: torch.Size([784]) orig_x_train shape: torch.Size([1, 28, 28])





## Declaring Positional Encoding (used for VNN)

In [17]:
from copy import deepcopy
import torch
from torch import nn
from torch import Tensor
import math

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        return self.pe[x]

class VNN (nn.Module):
    def __init__(self, dense_nn, weight_nn, bias_nn) -> None:
        super().__init__()
        self.d_model = (weight_nn[0].in_features-1)//2
        self.first_input = dense_nn[0].in_features

        self.dense_nn = dense_nn
        self.weight_nn = weight_nn 
        self.bias_nn = bias_nn 
        self.pos_enc = PositionalEncoding(self.d_model, dropout=0.1, max_len=5000)

    def generate_weight_vector (self, x, output_size):
        input_size = x.size(0)

        #* Weight Generation

        # Generate the weight vector
        argument_one = torch.arange(input_size).unsqueeze(1)
        
        argument_two = torch.arange(output_size)
        bias_argument = deepcopy(argument_two) 
        # Generate the repeat
        argument_one = argument_one.repeat(1, output_size).flatten()
        argument_two = argument_two.repeat(input_size)

        x_concat = x[argument_one].unsqueeze(1)
    
        # Positional Encoding + Concat
        argument_one = self.pos_enc(argument_one)
        argument_two = self.pos_enc(argument_two)
        bias_argument = self.pos_enc(bias_argument) 
        argument = torch.concat((argument_one, argument_two, x_concat), dim=1)

        # Send through the weight neural network
        weights = self.weight_nn(argument).view(input_size, output_size)

        out = torch.matmul(x, weights).unsqueeze(1)

        #* Bias Generation 
        argument = torch.concat((bias_argument, out), dim=1)
        bias = self.bias_nn(argument)
        out += bias

        return out.squeeze(1)

    def forward (self, x):
        x = self.generate_weight_vector(x, self.first_input)
        x = self.dense_nn(x)
        return x

## Control Initialization

In [16]:
# d_model = 32
# weight_model = nn.Sequential(
#     nn.Linear(65, 32),
#     nn.Tanh(),
#     nn.Linear(32, 1)
# ) 

# bias_model = nn.Sequential(
#     nn.Linear(33, 10),
#     nn.Tanh(),
#     nn.Linear(10, 1)
# )

# dense_model = nn.Sequential(
#     nn.Linear(128, 64),
#     nn.Tanh(),
#     nn.Linear(64,  10),
# )

# policy = VNN(dense_model, weight_model, bias_model)
# policy(torch.randn(128).to(device))

from random import randint


class Model(nn.Module):
    def __init__(self, use_VNN=False):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.flatten = nn.Flatten()

        if use_VNN:
            d_model = 32
            weight_model = nn.Sequential(
                nn.Linear(65, 32),
                nn.Tanh(),
                nn.Linear(32, 1)
            ) 

            bias_model = nn.Sequential(
                nn.Linear(33, 10),
                nn.Tanh(),
                nn.Linear(10, 1)
            )

            dense_model = nn.Sequential(
                nn.Linear(128, 64),
                nn.Tanh(),
                nn.Linear(64,  10),
            )
            self.linear = VNN(dense_model, weight_model, bias_model)
        else:
            self.linear = nn.Sequential(
                nn.Linear(3*3*128, 256),
                nn.ReLU(),
                nn.Dropout(p=0.4),
                nn.Linear(256, 10),
                nn.Softmax(dim=1)
            )

    def forward(self, X):
        X = self.conv(X)
        X = self.flatten(X)
        print(X.shape)
        X = self.linear(X)
        
        return X

model = Model(use_VNN=True).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3)

index = randint(0, len(orig_x_train))
inp = orig_x_train[index].unsqueeze(0)
print(inp.shape)
y = model(inp.to(device))
y

torch.Size([1, 1, 28, 28])
torch.Size([1, 1152])


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

## Training!

In [9]:
from tqdm import trange
from random import randint
import numpy as np

losses = []
loss_avg = 0
for epochs in range(5):
    progress_bar = trange(len(orig_x_train))
    for i in progress_bar:
        index = randint(0, len(orig_x_train)-1)
        inp = orig_x_train[index].unsqueeze(0).to(device)
        exp_out = torch.tensor(orig_y_train[index]).unsqueeze(0).to(device)

        optimizer.zero_grad()
        out = model(inp)
        loss = criterion(out, exp_out)
        loss.backward()
        optimizer.step()

        loss_avg += loss.item()
        if i % 100 == 0: 
            if i == 0:
                losses.append(loss.item())
            else:
                losses.append(loss_avg / 100)
                loss_avg = 0
        progress_bar.set_description(f"Epoch: {epochs} Loss: {loss.item():.4f}")
     

Epoch: 0 Loss: 1.4616:   2%|▏         | 1348/60000 [00:15<11:21, 86.02it/s]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

plt.plot(losses)