Taken from:

https://pyimagesearch.com/2021/07/19/pytorch-training-your-first-convolutional-neural-network-cnn/

In [16]:
# import the necessary packages
from torch.nn import Module
from torch.nn import Conv2d
from torch.nn import Linear
from torch.nn import MaxPool2d
from torch.nn import ReLU
from torch.nn import LogSoftmax
from torch import flatten

class LeNet(Module):
    """LeNet-style CNN that also returns the raw output of its first FC layer.

    Layout: two CONV => RELU => POOL stages, then FC => RELU => FC => LogSoftmax.
    """

    def __init__(self, numChannels, classes):
        """Create the layers.

        Args:
            numChannels: number of channels of the input images.
            classes: number of output classes.
        """
        super().__init__()

        # stage 1: CONV => RELU => POOL
        self.conv1 = Conv2d(
            in_channels=numChannels, out_channels=20, kernel_size=(5, 5)
        )
        self.relu1 = ReLU()
        self.maxpool1 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        # stage 2: CONV => RELU => POOL
        self.conv2 = Conv2d(in_channels=20, out_channels=50, kernel_size=(5, 5))
        self.relu2 = ReLU()
        self.maxpool2 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        # fully connected layer; 800 = 50 channels * 4 * 4, which is what the
        # two stages above produce for a 28x28 input
        self.fc1 = Linear(in_features=800, out_features=500)
        self.relu3 = ReLU()

        # classifier head: FC followed by log-softmax over the class dimension
        self.fc2 = Linear(in_features=500, out_features=classes)
        self.logSoftmax = LogSoftmax(dim=1)

    def forward(self, x):
        """Run a batch through the network.

        Args:
            x: input batch.

        Returns:
            A tuple ``(output, hidden_states)``. ``output`` is the log-softmax
            of the final FC layer. ``hidden_states`` is a one-element list
            holding a clone of the ``fc1`` output taken before ``relu3``
            (the conv outputs are not recorded).
        """
        # two CONV => RELU => POOL stages
        features = self.maxpool1(self.relu1(self.conv1(x)))
        features = self.maxpool2(self.relu2(self.conv2(features)))

        # flatten everything except the batch dimension and apply the first FC layer
        fc1_out = self.fc1(flatten(features, 1))

        # keep a copy of the pre-activation FC output for the caller
        hidden_states = [fc1_out.clone()]

        # FC => RELU => FC => LogSoftmax
        logits = self.fc2(self.relu3(fc1_out))
        return self.logSoftmax(logits), hidden_states
    

In [17]:
# set the matplotlib backend so figures can be saved in the background
import matplotlib
matplotlib.use("Agg")
# import the necessary packages
from sklearn.metrics import classification_report
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from torch.optim import AdamW
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import argparse
import torch
import time

In [18]:
!conda info


     active environment : cv-py39
    active env location : /home/kenneth/anaconda3/envs/cv-py39
            shell level : 2
       user config file : /home/kenneth/.condarc
 populated config files : /home/kenneth/.condarc
          conda version : 24.4.0
    conda-build version : 3.28.2
         python version : 3.11.5.final.0
                 solver : libmamba (default)
       virtual packages : __archspec=1=zen2
                          __conda=24.4.0=0
                          __cuda=11.2=0
                          __glibc=2.31=0
                          __linux=5.8.0=0
                          __unix=0=0
       base environment : /home/kenneth/anaconda3  (writable)
      conda av data dir : /home/kenneth/anaconda3/etc/conda
  conda av metadata url : None
           channel URLs : https://repo.anaconda.com/pkgs/main/linux-64
                          https://repo.anaconda.com/pkgs/main/noarch
                          https://repo.anaconda.com/pkgs/r/linux-64
                

In [20]:
# optimisation settings: initial learning rate, mini-batch size, epoch count
INIT_LR, BATCH_SIZE, EPOCHS = 1e-3, 64, 10
# fraction of the training set used for training; the rest is for validation
TRAIN_SPLIT = 0.75
VAL_SPLIT = 1 - TRAIN_SPLIT
# train on the GPU when CUDA is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [21]:
# load the MNIST dataset (downloaded into ./data when not already present)
print("[INFO] loading the MNIST dataset...")
trainData = MNIST(root="data", train=True, download=True,
    transform=ToTensor())
testData = MNIST(root="data", train=False, download=True,
    transform=ToTensor())
# calculate the train/validation split
print("[INFO] generating the train/validation split...")
numTrainSamples = int(len(trainData) * TRAIN_SPLIT)
# the validation set takes whatever is left, so the two lengths always sum
# to len(trainData); truncating both products independently can drop a
# sample, which makes random_split raise
numValSamples = len(trainData) - numTrainSamples
# fixed generator seed keeps the split reproducible between runs
(trainData, valData) = random_split(trainData,
    [numTrainSamples, numValSamples],
    generator=torch.Generator().manual_seed(42))

[INFO] loading the MNIST dataset...
[INFO] generating the train/validation split...


In [22]:
# initialize the train, validation, and test data loaders
# (only the training loader shuffles)
trainDataLoader = DataLoader(trainData, shuffle=True,
    batch_size=BATCH_SIZE)
valDataLoader = DataLoader(valData, batch_size=BATCH_SIZE)
testDataLoader = DataLoader(testData, batch_size=BATCH_SIZE)
# number of batches each loader yields per epoch; len(loader) also counts the
# final partial batch, which len(dataset) // BATCH_SIZE leaves out (and which
# would give 0 for a dataset smaller than one batch)
trainSteps = len(trainDataLoader)
valSteps = len(valDataLoader)

In [23]:
# Random example tensor of shape (3, 4, 5)
tensor = torch.randn(3, 4, 5)

# Collapse every dimension after the first one, giving shape (3, 20)
flattened_tensor = torch.flatten(tensor, start_dim=1)

In [24]:
!export CUDA_VISIBLE_DEVICES=3

In [25]:
# KD Loss
# selects which loss the training cell uses: the KD-regularised one when True,
# the plain NLL loss when False
is_kd = False

# lam scales the knowledge-discontinuity penalty; stabilizer is the small
# constant added to the pairwise hidden-state distances
lam, stabilizer = 1e-3, 1e-9

# per-sample NLL (no reduction) and the default batch-reduced NLL
loss_non_reducing = nn.NLLLoss(reduction='none')
base_loss_fn = nn.NLLLoss()

def calc_knowledge_discontinuities(class_losses, hss):
    """Accumulate pairwise loss differences scaled by hidden-state distance.

    For every tensor in ``hss`` and every unordered pair of samples (i, j)
    with i < j, adds ``|loss_i - loss_j| / (||h_i - h_j|| + stabilizer)``,
    where ``stabilizer`` is the module-level constant.

    Args:
        class_losses: tensor of per-sample losses; it is viewed as a column,
            so it must hold one value per sample.
        hss: list of hidden-state tensors with the batch on dimension 0; each
            is flattened to (batch, features) before distances are taken.

    Returns:
        A scalar tensor with the accumulated score, or the int ``0`` when
        ``hss`` is empty.
    """
    # the loss differences are the same for every layer, so build them once
    losses_col = class_losses.view(-1, 1)
    loss_diff = torch.cdist(losses_col, losses_col, p=1)

    total_score = 0

    for hs in hss:
        batch_size = hs.shape[0]

        # pairwise Euclidean distances between flattened hidden states; the
        # stabilizer keeps the denominator away from zero
        hs = hs.flatten(start_dim=1)
        dist = torch.cdist(hs, hs) + stabilizer

        # strictly-upper-triangle indices: each unordered pair once, no diagonal
        rows, cols = torch.triu_indices(
            batch_size, batch_size, offset=1, device=hs.device
        )

        total_score += torch.sum(loss_diff[rows, cols] / dist[rows, cols])

    return total_score

def normal_loss(output, target):
    """Return the plain NLL loss; the hidden states in ``output`` are ignored.

    Args:
        output: the ``(predictions, hidden_states)`` pair produced by the model.
        target: the target labels passed on to ``base_loss_fn``.
    """
    log_probs, _hidden_states = output
    return base_loss_fn(log_probs, target)

def kd_loss(output, target):
    """Return the NLL loss plus the weighted knowledge-discontinuity penalty.

    Args:
        output: the ``(predictions, hidden_states)`` pair produced by the model.
        target: the target labels.

    Returns:
        ``base_loss_fn(predictions, target)`` plus ``lam`` times the score
        that ``calc_knowledge_discontinuities`` computes from the per-sample
        losses and the hidden states.
    """
    log_probs, hidden_states = output

    # the penalty needs one loss value per sample, hence the non-reducing loss
    per_sample_losses = loss_non_reducing(log_probs, target)
    penalty = calc_knowledge_discontinuities(per_sample_losses, hidden_states)

    return base_loss_fn(log_probs, target) + lam * penalty

In [26]:
# initialize the LeNet model
print("[INFO] initializing the LeNet model...")
model = LeNet(
    numChannels=1,
    classes=len(trainData.dataset.classes)).to(device)
# initialize our optimizer
opt = AdamW(model.parameters(), lr=INIT_LR)
# pick the loss function; both candidates take the model's
# (predictions, hidden_states) tuple, which a bare nn.NLLLoss would not accept
lossFn = kd_loss if is_kd else normal_loss

# initialize a dictionary to store training history
H = {
    "train_loss": [],
    "train_acc": [],
    "val_loss": [],
    "val_acc": []
}
# measure how long training is going to take
print("[INFO] training the network...")
startTime = time.time()

[INFO] initializing the LeNet model...
[INFO] training the network...


### Train model

In [27]:
# loop over our epochs
for e in range(EPOCHS):
    # set the model in training mode
    model.train()
    # running sums of the per-batch losses, kept as plain Python floats so
    # that no autograd history is carried from one batch to the next
    totalTrainLoss = 0.0
    totalValLoss = 0.0
    # number of correct predictions in the training and validation step
    trainCorrect = 0
    valCorrect = 0
    # loop over the training set
    for (x, y) in trainDataLoader:
        # send the input to the device
        (x, y) = (x.to(device), y.to(device))

        # forward pass; the model returns (predictions, hidden_states) and the
        # loss function receives the whole tuple
        output = model(x)
        pred, _hs = output
        loss = lossFn(output, y)

        # zero out the gradients, perform the backpropagation step,
        # and update the weights
        opt.zero_grad()
        loss.backward()
        opt.step()

        # .item() extracts the value only, leaving the graph free to be released
        totalTrainLoss += loss.item()
        trainCorrect += (pred.argmax(1) == y).sum().item()

    # set the model in evaluation mode and switch off autograd
    model.eval()
    with torch.no_grad():
        # loop over the validation set
        for (x, y) in valDataLoader:
            # send the input to the device
            (x, y) = (x.to(device), y.to(device))
            # make the predictions and calculate the validation loss
            output = model(x)
            pred, _hs = output
            totalValLoss += lossFn(output, y).item()
            # calculate the number of correct predictions
            valCorrect += (pred.argmax(1) == y).sum().item()

    # calculate the average training and validation loss
    avgTrainLoss = totalTrainLoss / trainSteps
    avgValLoss = totalValLoss / valSteps
    # turn the correct-prediction counts into accuracies
    trainCorrect = trainCorrect / len(trainDataLoader.dataset)
    valCorrect = valCorrect / len(valDataLoader.dataset)
    # update our training history (plain floats)
    H["train_loss"].append(avgTrainLoss)
    H["train_acc"].append(trainCorrect)
    H["val_loss"].append(avgValLoss)
    H["val_acc"].append(valCorrect)
    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, EPOCHS))
    print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(
        avgTrainLoss, trainCorrect))
    print("Val loss: {:.6f}, Val accuracy: {:.4f}\n".format(
        avgValLoss, valCorrect))

[INFO] EPOCH: 1/10
Train loss: 0.181119, Train accuracy: 0.9452
Val loss: 0.061007, Val accuracy: 0.9821

[INFO] EPOCH: 2/10
Train loss: 0.054079, Train accuracy: 0.9828
Val loss: 0.046786, Val accuracy: 0.9855

[INFO] EPOCH: 3/10
Train loss: 0.035444, Train accuracy: 0.9892
Val loss: 0.042500, Val accuracy: 0.9871

[INFO] EPOCH: 4/10
Train loss: 0.026009, Train accuracy: 0.9912
Val loss: 0.039098, Val accuracy: 0.9888

[INFO] EPOCH: 5/10
Train loss: 0.020102, Train accuracy: 0.9935
Val loss: 0.037184, Val accuracy: 0.9903

[INFO] EPOCH: 6/10
Train loss: 0.015976, Train accuracy: 0.9946
Val loss: 0.048381, Val accuracy: 0.9880

[INFO] EPOCH: 7/10
Train loss: 0.012865, Train accuracy: 0.9960
Val loss: 0.035916, Val accuracy: 0.9901

[INFO] EPOCH: 8/10
Train loss: 0.011305, Train accuracy: 0.9963
Val loss: 0.041740, Val accuracy: 0.9901

[INFO] EPOCH: 9/10
Train loss: 0.009487, Train accuracy: 0.9971
Val loss: 0.042666, Val accuracy: 0.9891

[INFO] EPOCH: 10/10
Train loss: 0.007565, Trai

In [28]:
# stop the timer started before training and report the elapsed time
endTime = time.time()
print(f"[INFO] total time taken to train the model: {endTime - startTime:.2f}s")

# evaluate the trained network on the test set
print("[INFO] evaluating network...")
# evaluation mode, with autograd switched off while predicting
model.eval()
# predicted class index for every test sample, in loader order
preds = []
with torch.no_grad():
    for (x, y) in testDataLoader:
        # only the inputs are needed on the device; labels come from testData
        x = x.to(device)
        output = model(x)
        pred, hs = output
        batch_preds = pred.argmax(axis=1).cpu().numpy()
        preds.extend(batch_preds)

# per-class precision / recall / F1 against the test set's own targets
report = classification_report(
    testData.targets.cpu().numpy(),
    np.array(preds),
    target_names=testData.classes,
)
print(report)

[INFO] total time taken to train the model: 47.56s
[INFO] evaluating network...
              precision    recall  f1-score   support

    0 - zero       1.00      0.99      0.99       980
     1 - one       0.99      0.99      0.99      1135
     2 - two       0.99      1.00      0.99      1032
   3 - three       0.99      0.99      0.99      1010
    4 - four       1.00      0.96      0.98       982
    5 - five       0.97      0.99      0.98       892
     6 - six       0.99      0.99      0.99       958
   7 - seven       1.00      0.97      0.98      1028
   8 - eight       0.99      0.99      0.99       974
    9 - nine       0.96      0.99      0.97      1009

    accuracy                           0.99     10000
   macro avg       0.99      0.99      0.99     10000
weighted avg       0.99      0.99      0.99     10000



In [14]:
# output files: the training-curve figure and the serialized model
plot_path, model_path = 'kd_fconly_reg_e3.png', 'kd_fconly_reg_e3.pth'

In [15]:
# draw the four training-history curves on a single figure
plt.style.use("ggplot")
plt.figure()
for curve in ("train_loss", "val_loss", "train_acc", "val_acc"):
    # each curve is labelled with its key in the history dictionary
    plt.plot(H[curve], label=curve)
plt.title("Training Loss and Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
# write the figure to plot_path
plt.savefig(plot_path)
# serialize the model object itself (not just a state_dict) to model_path
torch.save(model, model_path)
plt.show()

In [36]:
# side length of the square matrix being indexed
n = 3

# (row, col) index pairs strictly above the main diagonal of an n x n matrix
upper_tri_indices = torch.triu_indices(row=n, col=n, offset=1)

upper_tri_indices

tensor([[0, 0, 1],
        [1, 2, 2]])

In [32]:
# 3x3 example numerator holding 1..9 in row-major order
num = np.arange(1, 10).reshape(3, 3)

In [33]:
# 3x3 example denominator of all ones (integer dtype)
denom = np.ones((3, 3), dtype=int)

In [35]:
# element-wise ratio restricted to the entries above the main diagonal
rows, cols = upper_tri_indices[0], upper_tri_indices[1]
num[rows, cols] / denom[rows, cols]

array([2., 3., 6.])