In [1]:
import torch.nn as nn
from torchvision import models
import torch
import torch.optim as optim
import numpy as np

import time

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
from kd_triplet_datasets import CDataset
from datasets import Datasets

# --- Data configuration -------------------------------------------------
DATASET = 'CIFAR10'
BATCH_SIZE = 64
NUM_WORKERS = 2

# Build the dataset bundle through the project's Datasets helper.
instance_datasets = Datasets(DATASET, BATCH_SIZE, NUM_WORKERS, shuffle=False)
data_sets = instance_datasets.create()

# Entries 0 and 1 of the bundle are not used here; the loaders are rebuilt
# below from the wrapped datasets instead.
classes, based_labels = data_sets[2], data_sets[3]

# Wrap the raw train/test sets (entries 4 and 5) in CDataset.
trainset = CDataset(data_sets[4])
testset = CDataset(data_sets[5])

# Both loaders share batch size and worker count; only the training loader
# shuffles.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

Dataset : CIFAR10
Files already downloaded and verified
Files already downloaded and verified


In [4]:
class Only_teacher(nn.Module):
    """Pretrained torchvision ResNet-50 backbone with a new linear classifier.

    Every child module of the ResNet except its final ``fc`` layer is kept as
    a feature extractor whose parameters have ``requires_grad = False``. A
    fresh ``nn.Linear`` maps the backbone features to ``n_class`` logits and
    is the only part with trainable parameters.

    Args:
        n_class: number of output classes (size of the logit vector).

    NOTE(review): the backbone parameters are frozen, but calling ``.train()``
    on this module still switches the backbone's normalisation layers to
    training mode, so their running statistics would keep changing during
    training. Confirm that this is intended.
    """

    def __init__(self,n_class):
        # nn.Module has to be initialised before any attribute is assigned
        # on the instance; the original order only worked because n_class is
        # a plain Python value rather than a Module or Parameter.
        super().__init__()
        self.n_class = n_class

        resnet = models.resnet50(pretrained=True)
        num_input_ftrs = resnet.fc.in_features
        # Drop the last child (the original fc classifier), keep the rest.
        self.pretrained_model = nn.Sequential(*list(resnet.children())[:-1])
        self.linear = nn.Linear(num_input_ftrs, self.n_class)

        # Freeze the backbone ...
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        # ... and make it explicit that the new head is trainable.
        for param in self.linear.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Return class logits of shape ``(batch, n_class)`` for input ``x``."""
        features = self.pretrained_model(x)
        # Flatten everything after the batch dimension before the linear head.
        features = features.reshape(features.size(0), -1)
        output = self.linear(features)
        return output

In [5]:
# --- Training hyper-parameters -------------------------------------------
criterion = nn.CrossEntropyLoss()
epochs = 100
n_class = 10
lr = 0.0001

# Build the model and move it to the target device *before* creating the
# optimizer, so the optimizer is constructed from the parameters as they
# live on that device.
teacher_model = Only_teacher(n_class=n_class)
teacher_model = teacher_model.to(device)

# Only hand the optimizer the parameters that are actually trainable; the
# ones with requires_grad=False never receive gradients.
trainable_params = [p for p in teacher_model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=lr, betas=(0.9, 0.999), eps=1e-08,
                        weight_decay=0.01, amsgrad=False)

KeyboardInterrupt: 

In [None]:
# NOTE(review): consider moving this import to the notebook's first cell so
# all dependencies are visible in one place.
import wandb
# Start a Weights & Biases run in the '251b_distillation_metric' project.
# The wandb.config / wandb.log calls in the following cells are issued after
# this call and presumably attach to this run.
wandb.init(project='251b_distillation_metric')

In [None]:
# Record the learning rate actually used by the optimizer instead of repeating
# the literal, so the logged config cannot drift from the real value.
wandb.config.update({"learning_rate": lr})

In [None]:
# Train for `epochs` epochs; after each epoch evaluate on the test loader and
# log the mean train / test loss to Weights & Biases.
for epoch in range(epochs):
    # ---- training pass ----
    teacher_model.train()
    losses = []
    for inputs, labels in train_loader:
        optimizer.zero_grad()

        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = teacher_model(inputs)
        # CrossEntropyLoss expects integer (long) class indices as targets.
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # ---- evaluation pass ----
    teacher_model.eval()
    val_losses = []
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph and saves memory.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = teacher_model(inputs)
            loss = criterion(outputs, labels.long())
            val_losses.append(loss.item())

    # Log both metrics in a single call so they share the same W&B step
    # (one step per epoch) and line up on the same x-axis.
    wandb.log({
        "train_loss": np.mean(np.array(losses)),
        "test_loss": np.mean(np.array(val_losses)),
    })


In [None]:
# Persist the model's state dict (not the whole module object) to disk.
teacher_state = teacher_model.state_dict()
torch.save(teacher_state, 'only_teacher_model.pth')