In [1]:
# standard library
import random

# numpy and matplotlib
import numpy as np
import matplotlib.pyplot as plt

# pytorch
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

# to load dataset
from datasets import Datasets
from kd_triplet_datasets import KDTripletDataset

# for network and training
from network import Net_teacher, Net_student
from network_fit import NetworkFit

# to calculate the score
import savescore
from score import Score
from score_calc import ScoreCalc

In [2]:
# ---- run configuration: every tunable knob lives in this cell ----

# data loading
DATASET, BATCH_SIZE, NUM_WORKERS = 'CIFAR10', 64, 2

# optimiser (AdamW in the network cell)
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-3
# NOTE(review): MOMENTUM is never passed to the optimiser (AdamW is built with
# hard-coded betas), so changing it has no visible effect in this notebook.
MOMENTUM = 0.9

# StepLR: multiply the LR by SCHEDULER_GAMMA every SCHEDULER_STEPS epochs
SCHEDULER_STEPS, SCHEDULER_GAMMA = 10, 0.1

# reproducibility and run length
SEED = 1
EPOCH = 100

# loss weighting: KD weight handed to fit.train/test, margin for nn.TripletMarginLoss
KD_LAMBDA = 10.0
TRIPLET_MARGINE = 5.0

In [3]:
# fixing the seed for every RNG this run may draw from, so results are reproducible.
# Python's own `random` module was previously left unseeded; seed it too in case
# dataset / pairing code relies on it (not visible from this notebook — harmless if unused).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)           # seeds the CPU RNG and all CUDA devices
torch.cuda.manual_seed_all(SEED)  # explicit CUDA seeding, kept for clarity

In [4]:
# choose the first CUDA device when one is visible, otherwise fall back to the CPU
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")
print("gpu mode" if use_gpu else "cpu mode")

gpu mode


In [5]:
# base name shared by every artefact this run writes
codename = 'kd_example'

# student weights saved at the end of training (".pth" is appended when saving)
fnnname = f"{codename}_fnn_model"

# file names for the loss / accuracy plots
total_loss_name, soft_loss_name, tri_loss_name, acc_name = (
    f"{codename}_{suffix}"
    for suffix in ("total_loss", "soft_loss", "tri_loss", "accuracy")
)

# file name passed to savescore.save_data
result_name = f"{codename}_result"

In [6]:
# build the base datasets; only the raw train/test datasets (not loaders) are used
# below, and the triplet loaders in the next cell do their own shuffling
instance_datasets = Datasets(DATASET, BATCH_SIZE, NUM_WORKERS, shuffle=False)
data_sets = instance_datasets.create()

# create() returns an indexable bundle. Slots 0 and 1 (presumably the plain
# train/test loaders) are not needed in this notebook.
classes, based_labels = data_sets[2], data_sets[3]
trainset, testset = data_sets[4], data_sets[5]

Dataset : CIFAR10
Files already downloaded and verified
Files already downloaded and verified


In [7]:
# wrap the raw datasets in KDTripletDataset; judging by how the training loop
# indexes batches, each item yields a pair of (image, label) samples
tri_trainset, tri_testset = (KDTripletDataset(ds) for ds in (trainset, testset))

# training batches are shuffled every epoch; test order is kept fixed
tri_trainloader = torch.utils.data.DataLoader(
    tri_trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
tri_testloader = torch.utils.data.DataLoader(
    tri_testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

In [8]:
# teacher and student networks, both moved to the selected device
model_t = Net_teacher().to(device)
model_s = Net_student().to(device)

# NOTE(review): no pretrained teacher is loaded (a load of "cnn_alex.pth" was
# left disabled), and the teacher's parameters go into the same optimiser as
# the student's, so both networks are presumably trained jointly from scratch.
# Confirm that this is intended.
params_st = [*model_t.parameters(), *model_s.parameters()]
optimizer = optim.AdamW(
    params_st,
    lr=LEARNING_RATE,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=WEIGHT_DECAY,
    amsgrad=False,
)
# StepLR decays the LR by SCHEDULER_GAMMA every SCHEDULER_STEPS scheduler.step() calls
scheduler = optim.lr_scheduler.StepLR(
    optimizer, step_size=SCHEDULER_STEPS, gamma=SCHEDULER_GAMMA
)

# classification (cross-entropy) and metric-learning (triplet margin) criteria
soft_criterion = nn.CrossEntropyLoss()
triplet_loss = nn.TripletMarginLoss(margin=TRIPLET_MARGINE)

In [9]:
# NetworkFit bundles both models, the shared optimiser and the two criteria;
# its train()/test() methods are driven batch-by-batch in the training cell
fit = NetworkFit(
    model_t,
    model_s,
    optimizer,
    soft_criterion,
    triplet_loss,
)

In [10]:
# running accumulators: three loss scores (presumably total / softmax / triplet,
# matching the plot titles used later) plus one correct-prediction score
loss, loss_s, loss_t, correct = (Score() for _ in range(4))
score_loss = [loss, loss_s, loss_t]
score_correct = [correct]
sc = ScoreCalc(score_loss, score_correct, BATCH_SIZE)

In [11]:
# experiment tracking: every wandb.log() in the training cell goes to this project
import wandb

WANDB_PROJECT = '251b_distillation_metric'
wandb.init(project=WANDB_PROJECT)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mspamidi[0m (use `wandb login --relogin` to force relogin)
2022-03-09 16:03:05.917457: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib64:/usr/local/cuda/lib64
2022-03-09 16:03:05.917498: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [12]:
# Record every hyper-parameter that affects results, not just the KD/triplet ones.
# Otherwise runs that differ only in LR, batch size, schedule or seed are
# indistinguishable in wandb. MOMENTUM is omitted because the AdamW optimiser
# does not use it.
wandb.config.update({
    "KD_LAMBDA": KD_LAMBDA,
    "TRIPLET_MARGINE": TRIPLET_MARGINE,
    "DATASET": DATASET,
    "BATCH_SIZE": BATCH_SIZE,
    "LEARNING_RATE": LEARNING_RATE,
    "WEIGHT_DECAY": WEIGHT_DECAY,
    "SCHEDULER_STEPS": SCHEDULER_STEPS,
    "SCHEDULER_GAMMA": SCHEDULER_GAMMA,
    "EPOCH": EPOCH,
    "SEED": SEED,
})

In [None]:
# training and test
def _to_device_triplet(inputs, labels):
    """Move one KDTripletDataset batch to `device` and arrange it as 3-tuples.

    `inputs` and `labels` are indexed as pairs (element 0 and element 1).
    Element 0 fills slots 0 and 1, and element 1 fills slot 2. Presumably slot 0
    feeds the teacher and slots 1-2 the student; confirm against NetworkFit.

    Returns:
        (images, labels): two 3-tuples of tensors on `device`.
    """
    first_img = inputs[0].to(device)
    second_img = inputs[1].to(device)
    first_lbl = labels[0].to(device)
    second_lbl = labels[1].to(device)
    # Element 0 is transferred once and shared by slots 0 and 1,
    # rather than being copied to the device twice.
    return (first_img, first_img, second_img), (first_lbl, first_lbl, second_lbl)


def _evaluate(loader, log_key):
    """Run `fit.test` over every batch of `loader`, accumulating into `sc`.

    After each batch, the first accumulated loss score is logged to wandb
    under `log_key`. Printing, appending and resetting `sc` is left to the caller.
    """
    for inputs, labels in loader:
        images, label = _to_device_triplet(inputs, labels)
        losses, corrects = fit.test(images, label, KD_LAMBDA)
        losses, corrects = sc.calc_sum(losses, corrects)
        wandb.log({log_key: losses[0].get_score()})


for epoch in range(EPOCH):
    print('epoch', epoch + 1)

    # --- optimisation pass ---
    for inputs, labels in tri_trainloader:
        images, label = _to_device_triplet(inputs, labels)
        loss_train = fit.train(images, label, KD_LAMBDA)
        wandb.log({"train_loss": loss_train})

    # --- evaluation on the training set ---
    # NOTE: this pass iterates the *training* loader (its summary is printed as
    # "train mean loss"), so "valid_loss" is really a train-set evaluation loss.
    # The key is kept unchanged so existing wandb dashboards keep working.
    _evaluate(tri_trainloader, "valid_loss")
    sc.score_print(len(trainset))
    sc.score_append(len(trainset))
    sc.score_del()

    # --- evaluation on the test set ---
    # train_loss is deliberately not logged here: re-logging it once per test
    # batch would only repeat the last training batch's value.
    _evaluate(tri_testloader, "test_loss")
    sc.score_print(len(testset), train=False)
    sc.score_append(len(testset), train=False)
    sc.score_del()

    scheduler.step()

    # checkpoint the student every epoch (overwrites the previous one)
    torch.save(model_s.state_dict(), 'epoch_saved.pth')

epoch 1
train mean loss=10.759288651428223, accuracy=0.28864
test mean loss=11.76998985595703, accuracy=0.2883
epoch 2
train mean loss=9.414013042907715, accuracy=0.30644
test mean loss=9.708882153320312, accuracy=0.3082
epoch 3
train mean loss=7.930045469665528, accuracy=0.28668
test mean loss=8.64157592163086, accuracy=0.2908
epoch 4
train mean loss=8.807560066223145, accuracy=0.30388
test mean loss=9.452113809204102, accuracy=0.3117
epoch 5
train mean loss=6.803151614074707, accuracy=0.36482
test mean loss=7.381226063537597, accuracy=0.358
epoch 6
train mean loss=7.676170022583007, accuracy=0.35058
test mean loss=8.226743469238281, accuracy=0.3513
epoch 7
train mean loss=6.137057235412597, accuracy=0.38238
test mean loss=6.732665087890625, accuracy=0.3773
epoch 8
train mean loss=6.198793807373047, accuracy=0.39806
test mean loss=6.886965162658691, accuracy=0.3928
epoch 9
train mean loss=6.274522077026367, accuracy=0.39062
test mean loss=6.711544671630859, accuracy=0.3867
epoch 10
tr

In [None]:
# histories accumulated by `sc` (one sc.score_append per pass per epoch),
# fetched for the training pass first and then for the test pass
(train_losses, train_corrects), (test_losses, test_corrects) = (
    sc.get_value(),
    sc.get_value(train=False),
)

In [None]:
# save the final student weights and output the graphs of the scores
torch.save(model_s.state_dict(), fnnname + '.pth')


def _curve_ylim(train_curve, test_curve, floor=5.0):
    """Upper y-limit that keeps every plotted point of both curves visible.

    Returns the larger of `floor` (the previous fixed limit) and the peak value
    of the two curves. If the histories are not float-convertible, it falls back
    to `floor` so plotting still proceeds as before.
    """
    try:
        peaks = [float(v) for v in (*train_curve, *test_curve)]
    except (TypeError, ValueError):
        return floor
    return max([floor, *peaks])


# Number of epochs actually recorded. This equals EPOCH for a complete run, and
# stays consistent with the data if training was interrupted early.
n_epochs = len(train_corrects[0])

# The loss limits used to be a fixed 5.0, but logged total losses are ~6-12, which
# clips the curves (assuming y_lim is the axis upper bound; confirm in savescore).
loss_plots = (
    (0, 'total loss', total_loss_name),
    (1, 'softmax loss', soft_loss_name),
    (2, 'triplet loss', tri_loss_name),
)
for idx, title, filename in loss_plots:
    savescore.plot_score(
        n_epochs, train_losses[idx], test_losses[idx],
        y_lim=_curve_ylim(train_losses[idx], test_losses[idx]),
        y_label='LOSS', legend=['train loss', 'test loss'],
        title=title, filename=filename,
    )

# accuracy is a fraction, so the fixed [.., 1] limit is correct
savescore.plot_score(n_epochs, train_corrects[0], test_corrects[0], y_lim=1, y_label='ACCURACY', legend=['train acc', 'test acc'], title='accuracy', filename=acc_name)

savescore.save_data(train_losses[0], test_losses[0], train_corrects[0], test_corrects[0], result_name)