In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import time
from tqdm.notebook import trange , tqdm
import torch.nn.functional as F

In [2]:
transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
batch_size = 128

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
def CreateNet():
    MyModel_ResNet50 = torchvision.models.resnet50(pretrained=True)
    MyModel_ResNet50.fc = torch.nn.Linear(MyModel_ResNet50.fc.in_features, 10)
    torch.nn.init.xavier_uniform_(MyModel_ResNet50.fc.weight)
    return MyModel_ResNet50

In [11]:
device = "cuda" if torch.cuda.is_available() else "cpu"
MyModel_ResNet50 = CreateNet().to(device)
criterion = nn.CrossEntropyLoss()
lr= 1e-5
weight_decay =5e-4
epochs = 20
ModelParameters = [param for name, param in MyModel_ResNet50.named_parameters() if 'fc' not in str(name)]
optimizer = torch.optim.Adam([{'params':ModelParameters}, {'params': MyModel_ResNet50.fc.parameters(), 'lr': lr*10}], lr=lr, weight_decay=weight_decay)

In [8]:
def train(net, train_dataloader, valid_dataloader, criterion, optimizer, scheduler=None, epochs=10, device='cpu', checkpoint_epochs=10):
    start = time.time()
    print(f'Training for {epochs} epochs on {device}')
    
    for epoch in range(1,epochs+1):
        print(f"Epoch {epoch}/{epochs}")
        
        net.train()  # put network in train mode for Dropout and Batch Normalization
        train_loss = torch.tensor(0., device=device)  # loss and accuracy tensors are on the GPU to avoid data transfers
        train_accuracy = torch.tensor(0., device=device)
        for X, y in train_dataloader:
            X = X.to(device)
            y = y.to(device)
            preds = net(X)
            loss = criterion(preds, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            with torch.no_grad():
                train_loss += loss * train_dataloader.batch_size
                train_accuracy += (torch.argmax(preds, dim=1) == y).sum()
        
        if valid_dataloader is not None:
            net.eval()  # put network in train mode for Dropout and Batch Normalization
            valid_loss = torch.tensor(0., device=device)
            valid_accuracy = torch.tensor(0., device=device)
            with torch.no_grad():
                for X, y in valid_dataloader:
                    X = X.to(device)
                    y = y.to(device)
                    preds = net(X)
                    loss = criterion(preds, y)

                    valid_loss += loss * valid_dataloader.batch_size
                    valid_accuracy += (torch.argmax(preds, dim=1) == y).sum()
        
        if scheduler is not None: 
            scheduler.step()
            
        print(f'Training loss: {train_loss/len(train_dataloader.dataset):.2f}')
        print(f'Training accuracy: {100*train_accuracy/len(train_dataloader.dataset):.2f}')
        
        if valid_dataloader is not None:
            print(f'Valid loss: {valid_loss/len(valid_dataloader.dataset):.2f}')
            print(f'Valid accuracy: {100*valid_accuracy/len(valid_dataloader.dataset):.2f}')
        
        if epoch%checkpoint_epochs==0:
            torch.save({
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, './checkpoint.pth.tar')
        
        print()
    
    end = time.time()
    print(f'Total training time: {end-start:.1f} seconds')
    return net

In [16]:
TeacherModel = train(MyModel_ResNet50, trainloader, testloader, criterion, optimizer, epochs=10, device=device)

Training for 10 epochs on cuda
Epoch 1/10
Training loss: 0.05
Training accuracy: 98.65
Valid loss: 0.17
Valid accuracy: 94.61

Epoch 2/10
Training loss: 0.02
Training accuracy: 99.63
Valid loss: 0.18
Valid accuracy: 94.65

Epoch 3/10
Training loss: 0.01
Training accuracy: 99.89
Valid loss: 0.19
Valid accuracy: 94.83

Epoch 4/10
Training loss: 0.00
Training accuracy: 99.97
Valid loss: 0.20
Valid accuracy: 94.77

Epoch 5/10
Training loss: 0.00
Training accuracy: 99.98
Valid loss: 0.20
Valid accuracy: 94.85

Epoch 6/10
Training loss: 0.00
Training accuracy: 99.98
Valid loss: 0.21
Valid accuracy: 94.79

Epoch 7/10
Training loss: 0.00
Training accuracy: 99.97
Valid loss: 0.22
Valid accuracy: 94.83

Epoch 8/10
Training loss: 0.00
Training accuracy: 99.97
Valid loss: 0.23
Valid accuracy: 94.54

Epoch 9/10
Training loss: 0.00
Training accuracy: 99.91
Valid loss: 0.25
Valid accuracy: 94.46

Epoch 10/10
Training loss: 0.01
Training accuracy: 99.87
Valid loss: 0.23
Valid accuracy: 94.70

Total tr

In [17]:
Model_ResNet18 = torchvision.models.resnet18(pretrained=False)
Model_ResNet18.fc = torch.nn.Linear(Model_ResNet18.fc.in_features, 10)



In [18]:
def loss_fn_kd(outputs, labels, teacher_outputs, alpha=0.95 , T=6):

    alpha = alpha
    T = T
    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
                             F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
              F.cross_entropy(outputs, labels) * (1. - alpha)

    return KD_loss

In [None]:
# alphas = [0.99, 0.95, 0.5, 0.1, 0.05]
# temperatures = [20., 10., 8., 6., 4.5, 3., 2., 1.5]

In [21]:
def train_kd(student_model, teacher_model, train_dataloader, test_dataloader,criterion, loss_kd , optimizer, scheduler=None, epochs=10, device='cuda' , checkpoint_epochs=10):
    start = time.time()
    print(f'Training for {epochs} epochs on {device}')
    
    for epoch in range(1,epochs+1):
        print(f"Epoch {epoch}/{epochs}")
        
        student_model.train()  # put network in train mode for Dropout and Batch Normalization
        teacher_model.eval()
        train_loss = torch.tensor(0., device=device)  # loss and accuracy tensors are on the GPU to avoid data transfers
        train_accuracy = torch.tensor(0., device=device)
        for X, y in train_dataloader:
            X = X.to(device)
            y = y.to(device)
            preds = student_model(X)
            with torch.no_grad():
                output_teacher_batch = teacher_model(X)
            loss = loss_fn_kd(preds, y, output_teacher_batch)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            with torch.no_grad():
                train_loss += loss * train_dataloader.batch_size
                train_accuracy += (torch.argmax(preds, dim=1) == y).sum()
        
        if test_dataloader is not None:
            student_model.eval()  # put network in train mode for Dropout and Batch Normalization
            valid_loss = torch.tensor(0., device=device)
            valid_accuracy = torch.tensor(0., device=device)
            with torch.no_grad():
                for X, y in test_dataloader:
                    X = X.to(device)
                    y = y.to(device)
                    preds = student_model(X)
                    loss = criterion(preds, y)

                    valid_loss += loss * test_dataloader.batch_size
                    valid_accuracy += (torch.argmax(preds, dim=1) == y).sum()
        
        if scheduler is not None: 
            scheduler.step()
            
        print(f'Training loss: {train_loss/len(train_dataloader.dataset):.2f}')
        print(f'Training accuracy: {100*train_accuracy/len(train_dataloader.dataset):.2f}')
        
        if test_dataloader is not None:
            print(f'Valid loss: {valid_loss/len(test_dataloader.dataset):.2f}')
            print(f'Valid accuracy: {100*valid_accuracy/len(test_dataloader.dataset):.2f}')
        
        if epoch%checkpoint_epochs==0:
            torch.save({
                'epoch': epoch,
                'state_dict': student_model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, './student.pth.tar')
        
        print()
    
    end = time.time()
    print(f'Total training time: {end-start:.1f} seconds')
    return student_model

In [25]:
optimizer = torch.optim.Adam(Model_ResNet18.parameters(), lr=0.001, weight_decay=5e-4)
device = "cuda" if torch.cuda.is_available() else "cpu"
Model_ResNet18 = Model_ResNet18.to(device)
criterion = nn.CrossEntropyLoss()

In [28]:
torch.cuda.empty_cache()

In [29]:
studentModel_kd = train_kd(TeacherModel,Model_ResNet18 , trainloader, testloader, criterion,loss_fn_kd, optimizer, epochs=20, device=device)

Training for 20 epochs on cuda
Epoch 1/20




Training loss: 1.85
Training accuracy: 99.89
Valid loss: 0.24
Valid accuracy: 94.65

Epoch 2/20
Training loss: 1.85
Training accuracy: 99.90
Valid loss: 0.23
Valid accuracy: 94.54

Epoch 3/20
Training loss: 1.84
Training accuracy: 99.89
Valid loss: 0.23
Valid accuracy: 94.65

Epoch 4/20
Training loss: 1.85
Training accuracy: 99.89
Valid loss: 0.23
Valid accuracy: 94.76

Epoch 5/20
Training loss: 1.85
Training accuracy: 99.90
Valid loss: 0.23
Valid accuracy: 94.74

Epoch 6/20
Training loss: 1.84
Training accuracy: 99.89
Valid loss: 0.23
Valid accuracy: 94.63

Epoch 7/20
Training loss: 1.85
Training accuracy: 99.91
Valid loss: 0.24
Valid accuracy: 94.61

Epoch 8/20
Training loss: 1.84
Training accuracy: 99.91
Valid loss: 0.24
Valid accuracy: 94.68

Epoch 9/20
Training loss: 1.85
Training accuracy: 99.90
Valid loss: 0.23
Valid accuracy: 94.64

Epoch 10/20
Training loss: 1.84
Training accuracy: 99.89
Valid loss: 0.23
Valid accuracy: 94.66

Epoch 11/20
Training loss: 1.85
Training accuracy: