In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights

import matplotlib.pyplot as plt
import numpy as np

In [2]:
import os
import sys

# Directory holding the local DiscreteKeyValueBottleNeck module imported in a later cell.
# Set the DKVB_SRC_DIR environment variable to point elsewhere instead of editing this cell.
DKVB_SRC_DIR = os.environ.get("DKVB_SRC_DIR", '/home/shuvraneel/Desktop/Discrete Key Value Bottleneck')
# Guard keeps the cell idempotent: re-running it does not stack duplicate sys.path entries.
if DKVB_SRC_DIR not in sys.path:
    sys.path.insert(1, DKVB_SRC_DIR)

In [3]:
# Experiment configuration.
BATCH_SIZE = 256          # mini-batch size used by every DataLoader in this notebook
NUM_CLASSES = 10          # total label count; training visits the classes two at a time
SAMPLES_PER_CLASS = 5000  # presumably CIFAR-10 train images per class; not referenced in the cells shown
NUM_EPOCHS = 2000         # epochs per two-class training period
SEED = 0                  # fixed RNG seed so shuffling and weight init are reproducible
torch.manual_seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# Pretrained ImageNet weights and the matching input pipeline (resize/crop/normalise).
weights = ResNet50_Weights.IMAGENET1K_V2
preprocess = weights.transforms()

# Bind the model to a new name so the imported `resnet50` factory is not shadowed;
# otherwise re-running this cell would call the model instead of the factory.
backbone = resnet50(weights=weights).to(device)
# Embeddings of dimension 2048 are produced by the backbone with the final fc layer
# dropped. NOTE(review): no flatten is applied, so the output is presumably
# (N, 2048, 1, 1) -- confirm the DKVB wrapper handles that shape.
resnet_model = nn.Sequential(*list(backbone.children())[:-1])

In [5]:
# CIFAR-10 train/test splits, both passed through the ResNet-50 `preprocess` pipeline.
_cifar_common = {"root": './data', "download": True, "transform": preprocess}

trainset = torchvision.datasets.CIFAR10(train=True, **_cifar_common)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

testset = torchvision.datasets.CIFAR10(train=False, **_cifar_common)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # fixed order for evaluation
    num_workers=2,
)

del _cifar_common

# Short display names, presumably in CIFAR-10 label order (compare trainset.classes).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
def plot_images(img, mean=0.5, std=0.5):
    """Undo per-channel normalisation on a CHW image tensor and display it.

    The defaults (mean=0.5, std=0.5) reproduce the previous `img / 2 + 0.5`.
    Images from this notebook's ResNet-50 `preprocess` pipeline are normalised
    with ImageNet statistics, so for those pass e.g.
    ``plot_images(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])``.

    Args:
        img: image tensor shaped (C, H, W), e.g. a single image or a
            torchvision.utils.make_grid result. It may live on any device and
            may require grad.
        mean: scalar or per-channel sequence subtracted during normalisation.
        std: scalar or per-channel sequence divided by during normalisation.
    """
    # .numpy() rejects CUDA / grad-tracking tensors, so move to a plain CPU float copy first.
    img = img.detach().cpu().to(torch.float32)
    channel_mean = torch.as_tensor(mean, dtype=img.dtype).reshape(-1, 1, 1)
    channel_std = torch.as_tensor(std, dtype=img.dtype).reshape(-1, 1, 1)
    # Invert (x - mean) / std, then clip to the [0, 1] range imshow expects for float images.
    img = (img * channel_std + channel_mean).clamp(0.0, 1.0)
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    plt.show()

In [7]:
def filter_for_two_classes(dataset, class_a, class_b, batch_size=None, shuffle=True, num_workers=2):
    """Build a DataLoader over only the samples labelled `class_a` or `class_b`.

    Args:
        dataset: map-style dataset exposing a `targets` sequence of integer labels.
        class_a: first label to keep.
        class_b: second label to keep.
        batch_size: loader batch size; None falls back to the BATCH_SIZE constant.
        shuffle: whether the loader reshuffles every epoch.
        num_workers: DataLoader worker processes.

    Returns:
        torch.utils.data.DataLoader over a Subset of `dataset`.

    Raises:
        ValueError: if no sample carries either label.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    targets = torch.as_tensor(dataset.targets)
    keep = (targets == class_a) | (targets == class_b)
    # Plain Python ints: the Subset then indexes the base dataset with ints rather than 0-d tensors.
    indices = torch.nonzero(keep, as_tuple=True)[0].tolist()
    if not indices:
        raise ValueError(f"dataset has no samples labelled {class_a} or {class_b}")

    filtered_dataset = torch.utils.data.Subset(dataset, indices)
    return torch.utils.data.DataLoader(
        filtered_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )

In [8]:
loss_criterion = nn.CrossEntropyLoss()

def train_incrementally(num_epochs, trainset, model, loss_criterion, optimizer, min_delta=0, scheduler=None,
                        num_classes=None, log_every=100):
    """Train `model` class-incrementally on consecutive pairs of classes.

    Labels 0..num_classes-1 are split into periods (0, 1), (2, 3), ... For each
    period a loader restricted to those two labels is built with
    `filter_for_two_classes`, and the model is trained on it for `num_epochs`
    epochs before moving to the next pair. No data from earlier periods is
    replayed, and the optimizer is not reset between periods. The module's
    train/eval mode is left exactly as the caller set it.

    Args:
        num_epochs: epochs to run in each two-class period.
        trainset: dataset exposing a `targets` sequence (forwarded to
            `filter_for_two_classes`).
        model: module mapping a batch of inputs to class logits.
        loss_criterion: callable (outputs, labels) -> scalar loss tensor.
        optimizer: optimizer over the parameters to be trained.
        min_delta: currently unused; kept for signature compatibility.
        scheduler: optional per-epoch LR scheduler; its `step()` is called with
            no arguments once at the end of every epoch.
        num_classes: total label count; None falls back to NUM_CLASSES. With an
            odd count the last label is never trained on.
        log_every: print the latest mini-batch loss every `log_every` epochs
            (must be a positive int).

    Returns:
        None. Progress is printed. The per-mini-batch loss curve of every period
        is drawn on one matplotlib axes, labelled by class pair.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    _, ax = plt.subplots()

    for period in range(num_classes // 2):
        class_pair = (2 * period, 2 * period + 1)
        print(f"Training period  {period + 1 } started")
        print(f"Classes being trained on are: {class_pair}")

        train_loader = filter_for_two_classes(trainset, *class_pair)

        losses = []
        epochs_run = 0  # stays 0 when num_epochs == 0 (previously a TypeError on None + 1)

        for epoch in range(num_epochs):
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                loss = loss_criterion(outputs, labels)
                losses.append(loss.item())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if scheduler is not None:
                scheduler.step()

            # `losses` can be empty if the loader yielded no batches.
            if epoch % log_every == 0 and losses:
                print(f'Epoch {epoch+1}/{num_epochs}, Loss: {losses[-1]:.4f}')

            epochs_run = epoch + 1

        print(f"Training period { period + 1 } completed; last epoch run = { epochs_run }")

        ax.plot(losses, label=f"classes {class_pair}")

    # One loss value is recorded per mini-batch, so the x-axis counts iterations, not epochs.
    ax.set_xlabel("Iterations (mini-batches)")
    ax.set_ylabel("Training Loss")
    if ax.lines:
        ax.legend()

In [9]:
def init_weights_kaiming(m):
    """Kaiming-normal init for `nn.Linear` weights (ReLU gain) and zero bias.

    Intended for `Module.apply`; any module that is not an `nn.Linear` is left
    untouched.

    Args:
        m: a module visited by `Module.apply`.
    """
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        # Linear layers built with bias=False have `bias is None`; zeros_(None) would raise.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
import DiscreteKeyValueBottleNeck

# The Discrete Key-Value Bottleneck method keeps its pretrained encoder frozen and
# trains only the bottleneck. Freeze it explicitly so SGD at lr=0.3 cannot fine-tune
# the ResNet-50 weights. eval() stops BatchNorm running statistics drifting on CIFAR
# batches. NOTE(review): assumes DiscreteKeyValueBottleneck does not itself switch
# the encoder back to train mode -- confirm.
resnet_model.requires_grad_(False)
resnet_model.eval()

dkvb = DiscreteKeyValueBottleNeck.DiscreteKeyValueBottleneck(
    encoder=resnet_model,
    num_codebooks=256,
    enc_out_dim=2048,        # embedding width of the fc-less ResNet-50 backbone
    embed_dim=14,
    value_dim=10,            # presumably one value slot per class (matches NUM_CLASSES) -- confirm
    keys_per_codebook=4096,
    device=device,
)

dkvb = dkvb.to(device)
dkvb.apply(init_weights_kaiming)

# Give the optimizer only the parameters that still receive gradients.
trainable_params = [p for p in dkvb.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=0.3)
train_incrementally(NUM_EPOCHS, trainset, dkvb, loss_criterion, optimizer)