In [1]:
import torch
from torch import nn
from datasets import load_dataset
from random import randint

dataset = load_dataset("mnist")

Reusing dataset mnist (C:\Users\eshaa\.cache\huggingface\datasets\mnist\mnist\1.0.0\fda16c03c4ecfb13f165ba7e29cf38129ce035011519968cdaf74894ce91c9d4)


  0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
dataset['train'], dataset['test']

(Dataset({
     features: ['image', 'label'],
     num_rows: 60000
 }),
 Dataset({
     features: ['image', 'label'],
     num_rows: 10000
 }))

## Transform MNIST Dataset into variable sized images :D

In [3]:
from PIL import Image
from tqdm import tqdm
import numpy as np

def parse_dataset (dataset, desc=""):
    oneXImages = []
    twoXImages = []
    threeXImages = []
    labels = []

    for value in tqdm(dataset, desc=desc):
        oneXImage = value['image']
        twoXImage = value['image'].resize((48, 48), resample=Image.BOX)
        threeXImage = value['image'].resize((64, 64), resample=Image.BOX)
        
        oneXImages.append(np.expand_dims(np.uint8(oneXImage),0).tolist())
        twoXImages.append(np.expand_dims(np.uint8(twoXImage),0).tolist())
        threeXImages.append(np.expand_dims(np.uint8(threeXImage),0).tolist())
        labels.append(value['label'])

    print("Converting to tensor...",end="") 
    oneXImages = torch.tensor(oneXImages).to(torch.float)/255
    twoXImages = torch.tensor(twoXImages).to(torch.float)/255
    threeXImages = torch.tensor(threeXImages).to(torch.float)/255
    labels = torch.tensor(labels).to(torch.long)
    print("Done")

    return oneXImages, twoXImages, threeXImages, labels

oneXImagesTrain, twoXImagesTrain, threeXImagesTrain, labelsTrain = parse_dataset(dataset['train'])
oneXImagesTest, twoXImagesTest, threeXImagesTest, labelsTest = parse_dataset(dataset['test'])

100%|██████████| 60000/60000 [00:30<00:00, 1942.04it/s]


Converting to tensor...Done


100%|██████████| 10000/10000 [00:05<00:00, 1801.38it/s]


Converting to tensor...Done


In [4]:
print("1x Train Shape:   ", oneXImagesTrain.shape)
print("2x Train Shape:   ", twoXImagesTrain.shape)
print("3x Train Shape:   ", threeXImagesTrain.shape)
print("Train Label Shape:", labelsTrain.shape)
print("1x Test Shape:    ", oneXImagesTest.shape)
print("2x Test Shape:    ", twoXImagesTest.shape)
print("3x Test Shape:    ", threeXImagesTest.shape)
print("Test Label Shape: ", labelsTest.shape)

1x Train Shape:    torch.Size([60000, 1, 28, 28])
2x Train Shape:    torch.Size([60000, 1, 48, 48])
3x Train Shape:    torch.Size([60000, 1, 64, 64])
Train Label Shape: torch.Size([60000])
1x Test Shape:     torch.Size([10000, 1, 28, 28])
2x Test Shape:     torch.Size([10000, 1, 48, 48])
3x Test Shape:     torch.Size([10000, 1, 64, 64])
Test Label Shape:  torch.Size([10000])


## Initialize Regular NN Model

In [10]:
# https://medium.com/@nutanbhogendrasharma/pytorch-convolutional-neural-network-with-mnist-dataset-4e8a4265e118
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,              
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(         
            nn.Conv2d(
                in_channels=16, 
                out_channels=32, 
                kernel_size=5, 
                stride=1,
                padding=2),     
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),                
        )

        self.conv3 = nn.Sequential(         
            nn.Conv2d(
                in_channels=32, 
                out_channels=32, 
                kernel_size=5, 
                stride=1,
                padding=2),     
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),                
        )

        # fully connected layer, output 10 classes
        self.out = nn.LazyLinear(10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        output = self.out(x)
        return output

model = CNN().to(device)
inp = oneXImagesTrain[0].unsqueeze(0).to(device)
print("Input Shape:", inp.shape, "dtype:", inp.dtype)
out = model(inp) # unsqueeze because we only have one sample in the batch
print("Output shape:", out.shape)

Input Shape: torch.Size([1, 1, 28, 28]) dtype: torch.float32
Output shape: torch.Size([1, 10])


## Training the control model!

In [11]:
from tqdm import trange
from random import randint
import numpy as np

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = nn.CrossEntropyLoss()
batch_size = 32
epochs = 3

losses = []
loss_avg = 0
for epochs in range(epochs):
    # Test Validation
    print("Running test validation...", end='\r')
    num_correct = 0
    for i in range(0, len(oneXImagesTest), 10):
        inp = oneXImagesTest[i:i+10].to(device)
        exp_out = labelsTest[i:i+10].to(device)
        with torch.no_grad():
            out = model(inp)
            out = torch.argmax(torch.softmax(out,dim=1),1)
        num_correct += torch.sum(exp_out == out).item()
    print(f"Validation Accuracy on Epoch {epochs+1}: {round(num_correct/len(oneXImagesTest)*100)}%       ") 

    # Actual training
    progress_bar = trange(len(oneXImagesTrain)//batch_size)
    for i in progress_bar:
        index = torch.randperm(len(oneXImagesTrain))[:batch_size]
        inp = oneXImagesTrain[index].to(device)
        exp_out = labelsTrain[index].to(device)

        optimizer.zero_grad()
        out = model(inp)
        loss = loss_func(out, exp_out)
        loss.backward()
        optimizer.step()

        loss_avg += loss.item()
        if i % 100 == 0: 
            if i == 0:
                losses.append(loss.item())
            else:
                losses.append(loss_avg / 100)
                loss_avg = 0
        progress_bar.set_description(f"Epoch: {epochs+1} Loss: {loss.item():.4f}")

Validation Accuracy on Epoch 1: 11%       


Epoch: 1 Loss: 0.1661: 100%|██████████| 1875/1875 [00:28<00:00, 65.18it/s]


Validation Accuracy on Epoch 2: 95%       


Epoch: 2 Loss: 0.0419: 100%|██████████| 1875/1875 [00:31<00:00, 58.62it/s]


Validation Accuracy on Epoch 3: 97%       


Epoch: 3 Loss: 0.0013: 100%|██████████| 1875/1875 [00:32<00:00, 57.54it/s]


In [None]:
import matplotlib.pyplot as plt

plt.plot(losses)

## Initialize VNN

In [12]:
from VNN import *

class CNNwithVNN (nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.conv1 = nn.Sequential(         
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,              
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(         
            nn.Conv2d(
                in_channels=16, 
                out_channels=32, 
                kernel_size=5, 
                stride=1,
                padding=2),     
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),                
        )
        self.conv3 = nn.Sequential(         
            nn.Conv2d(
                in_channels=32, 
                out_channels=32, 
                kernel_size=5, 
                stride=1,
                padding=2),     
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),                
        )

        #* VNN Declaration
        weight_model = nn.Sequential(
            nn.Linear(33, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        ) 

        bias_model = nn.Sequential(
            nn.Linear(17, 10),
            nn.Tanh(),
            nn.Linear(10, 1)
        )

        dense_model = nn.Sequential(
            nn.Linear(128, 10)
        )
        self.vnn = VNN(dense_model, weight_model, bias_model).to(device)

    def forward (self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)       
        x = self.vnn(x)
        return x 
        
model = CNNwithVNN().to(device)
inp = oneXImagesTrain[0].unsqueeze(0).to(device)
out = model(inp)
print(f"Input Shape: {inp.shape} Output shape: {out.shape}")

inp = twoXImagesTrain[0].unsqueeze(0).to(device)
out = model(inp)
print(f"Input Shape: {inp.shape} Output shape: {out.shape}")

inp = threeXImagesTrain[0].unsqueeze(0).to(device)
out = model(inp)
print(f"Input Shape: {inp.shape} Output shape: {out.shape}")

cuda
Input Shape: torch.Size([1, 1, 28, 28]) Output shape: torch.Size([1, 10])
Input Shape: torch.Size([1, 1, 48, 48]) Output shape: torch.Size([1, 10])
Input Shape: torch.Size([1, 1, 64, 64]) Output shape: torch.Size([1, 10])


## Train VNN

In [15]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = nn.CrossEntropyLoss()
batch_size = 32
epochs = 5

losses = []
loss_avg = 0
for epochs in range(epochs):
    # Test Validation
    print("Running test validation...", end='\r')
    num_correct = 0
    for i in range(0, len(oneXImagesTest), 10):
        random_int = randint(0, 2)
        if random_int == 0: 
            inp = oneXImagesTest[i:i+10].to(device)
        elif random_int == 1:
            inp = twoXImagesTest[i:i+10].to(device)
        else:
            inp = threeXImagesTest[i:i+10].to(device)

        exp_out = labelsTest[i:i+10].to(device)
        with torch.no_grad():
            out = model(inp)
            out = torch.argmax(torch.softmax(out,dim=1),1)
        num_correct += torch.sum(exp_out == out).item()
    print(f"Validation Accuracy on Epoch {epochs+1}: {round(num_correct/len(oneXImagesTest)*100)}%       ") 

    # Actual training
    progress_bar = trange(len(oneXImagesTrain)//batch_size)
    for i in progress_bar:
        index = torch.randperm(len(oneXImagesTrain))[:batch_size]
        random_int = randint(0, 2)
        if random_int == 0: 
            inp = oneXImagesTrain[index].to(device)
        elif random_int == 1:
            inp = twoXImagesTrain[index].to(device)
        else:
            inp = threeXImagesTrain[index].to(device)
            
        exp_out = labelsTrain[index].to(device)

        optimizer.zero_grad()
        out = model(inp)
        loss = loss_func(out, exp_out)
        loss.backward()
        optimizer.step()

        loss_avg += loss.item()
        if i % 100 == 0: 
            if i == 0:
                losses.append(loss.item())
            else:
                losses.append(loss_avg / 100)
                loss_avg = 0
        progress_bar.set_description(f"Epoch: {epochs+1} Loss: {loss.item():.4f}")

Validation Accuracy on Epoch 1: 10%       


Epoch: 1 Loss: 2.3385:   0%|          | 5/1875 [00:01<10:39,  2.93it/s]


RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 6.00 GiB total capacity; 3.37 GiB already allocated; 102.33 MiB free; 3.94 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF