In [4]:
import torch
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.datasets import CIFAR10
from torch import nn
from torch import optim
from torchvision import models
from torch.utils.data import DataLoader
from tqdm import tqdm

In [5]:
train_data = CIFAR10(root='./data', train=True, download=True, transform=Compose([ToTensor(),Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))
test_data = CIFAR10(root='./data', train=False, download=True, transform=Compose([ToTensor(),Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))

Files already downloaded and verified
Files already downloaded and verified


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
class ExampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = models.resnet18(pretrained=True)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, 10)
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.resnet(x)

    def loss(self, inputs, labels):
        outputs = self.forward(inputs)
        correct = (outputs.argmax(dim=1) == labels).sum().item()
        return self.ce_loss(outputs, labels), correct

In [14]:
example_model = ExampleModel().to(device)

In [15]:
lr = 1e-2
batch_size = 256

In [16]:
train_loader = DataLoader(train_data, batch_size=batch_size)
test_loader = DataLoader(test_data, batch_size=batch_size)

In [17]:
optimizer = torch.optim.SGD([
            {'params': list(example_model.parameters())[:-2]},
            {'params': example_model.resnet.fc.parameters(), 'lr': lr}
        ], lr=lr*0.1, momentum=0.9)

In [18]:
def train(model, optimizer, train_loader, test_loader, epochs):

  def step():
    correct = 0
    count = 0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        loss, cor = model.loss(inputs, labels)
        correct += cor
        count += inputs.shape[0]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return correct/count

  def test():
    correct = 0
    count = 0
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            loss, cor = model.loss(inputs, labels)
            correct += cor
            count += inputs.shape[0]
    return correct/count

  for epoch in tqdm(range(epochs)):
      train_accuracy = step()
      test_accuracy = test()
      print()
      print('train accuracy =', train_accuracy)
      print('test accuracy =', test_accuracy)

In [19]:
train(example_model, optimizer, train_loader, test_loader, epochs=6)


  0%|          | 0/6 [00:00<?, ?it/s][A
 17%|█▋        | 1/6 [00:19<01:38, 19.76s/it][A


train accuracy = 0.54354
test accuracy = 0.6511



 33%|███▎      | 2/6 [00:39<01:19, 19.77s/it][A


train accuracy = 0.7229
test accuracy = 0.7



 50%|█████     | 3/6 [00:58<00:58, 19.63s/it][A


train accuracy = 0.79844
test accuracy = 0.7101



 67%|██████▋   | 4/6 [01:18<00:39, 19.53s/it][A


train accuracy = 0.8606
test accuracy = 0.7136



 83%|████████▎ | 5/6 [01:37<00:19, 19.51s/it][A


train accuracy = 0.9158
test accuracy = 0.7136



100%|██████████| 6/6 [01:57<00:00, 19.53s/it]


train accuracy = 0.9559
test accuracy = 0.7139





In [20]:
class ImitationLoss(nn.Module):
    def __init__(self, coeff=0.3):
        super().__init__()
        self.coeff = coeff
        self.imit_loss = nn.MSELoss()
        self.loss = nn.CrossEntropyLoss()
    
    
    def forward(self, student, teacher, labels):
        stud_loss = self.loss(student, labels)
        imit_loss = self.imit_loss(student, teacher)
        return self.coeff * stud_loss + (1 - self.coeff) * imit_loss

In [21]:
class ImitationModel(nn.Module):
    def __init__(self, example_model):
        super().__init__()
        self.model = example_model
        self.imit_loss = ImitationLoss()
        self.conv2d = nn.Sequential(
            nn.Conv2d(3, 10, 4), nn.ReLU(),
            nn.MaxPool2d(2, 2), nn.ReLU(),
            nn.Conv2d(10, 15, 5), nn.ReLU()
        )
        self.linear = nn.Sequential(
            nn.Linear(1500, 100), nn.ReLU(), nn.Linear(100, 10)
        )
        
    def forward(self, x):
        x = self.conv2d(x)
        return self.linear(x.view(-1, 1500))

    def loss(self, inputs, labels):
        pred = self.forward(inputs)
        model_pred = self.model(inputs)
        loss = self.imit_loss(pred, model_pred, labels)
        correct = (pred.argmax(dim=1) == labels).sum().item()
        return loss, correct

In [29]:
model = ImitationModel(example_model).to(device)

In [31]:
optimizer = optimizer = torch.optim.SGD([
            {'params': list(model.parameters())[:-8]},
            {'params': list(model.parameters())[-8:], 'lr': lr}
        ], lr=lr*0.1, momentum=0.9)

In [32]:
train(model, optimizer, train_loader, test_loader, epochs=20)



  0%|          | 0/20 [00:00<?, ?it/s][A[A

  5%|▌         | 1/20 [00:19<06:18, 19.93s/it][A[A


train accuracy = 0.19256
test accuracy = 0.2691




 10%|█         | 2/20 [00:40<06:01, 20.07s/it][A[A


train accuracy = 0.30516
test accuracy = 0.3312




 15%|█▌        | 3/20 [01:00<05:40, 20.01s/it][A[A


train accuracy = 0.35128
test accuracy = 0.3683




 20%|██        | 4/20 [01:19<05:18, 19.89s/it][A[A


train accuracy = 0.38748
test accuracy = 0.4027




 25%|██▌       | 5/20 [01:39<04:57, 19.85s/it][A[A


train accuracy = 0.41808
test accuracy = 0.4298




 30%|███       | 6/20 [01:59<04:38, 19.87s/it][A[A


train accuracy = 0.44076
test accuracy = 0.4507




 35%|███▌      | 7/20 [02:19<04:18, 19.88s/it][A[A


train accuracy = 0.45646
test accuracy = 0.4622




 40%|████      | 8/20 [02:39<03:58, 19.85s/it][A[A


train accuracy = 0.46844
test accuracy = 0.4715




 45%|████▌     | 9/20 [02:58<03:38, 19.83s/it][A[A


train accuracy = 0.48076
test accuracy = 0.4815




 50%|█████     | 10/20 [03:18<03:18, 19.83s/it][A[A


train accuracy = 0.49194
test accuracy = 0.4941




 55%|█████▌    | 11/20 [03:38<02:58, 19.85s/it][A[A


train accuracy = 0.50322
test accuracy = 0.5045




 60%|██████    | 12/20 [03:58<02:38, 19.82s/it][A[A


train accuracy = 0.51358
test accuracy = 0.5118




 65%|██████▌   | 13/20 [04:18<02:18, 19.81s/it][A[A


train accuracy = 0.52292
test accuracy = 0.5231




 70%|███████   | 14/20 [04:38<01:58, 19.80s/it][A[A


train accuracy = 0.53192
test accuracy = 0.5306




 75%|███████▌  | 15/20 [04:57<01:39, 19.82s/it][A[A


train accuracy = 0.54024
test accuracy = 0.5386




 80%|████████  | 16/20 [05:17<01:19, 19.82s/it][A[A


train accuracy = 0.54766
test accuracy = 0.5457




 85%|████████▌ | 17/20 [05:37<00:59, 19.80s/it][A[A


train accuracy = 0.55586
test accuracy = 0.5544




 90%|█████████ | 18/20 [05:57<00:39, 19.80s/it][A[A


train accuracy = 0.56398
test accuracy = 0.561




 95%|█████████▌| 19/20 [06:17<00:19, 19.80s/it][A[A


train accuracy = 0.57122
test accuracy = 0.5671




100%|██████████| 20/20 [06:36<00:00, 19.84s/it]


train accuracy = 0.5784
test accuracy = 0.5727



