# Train ResNet18 from ViT with Distiller on CIFAR100
Distiller can transfer knowledge from a heavy model (the teacher) to a light one (the student) with a different structure.

* The teacher is a large model pretrained on a specific dataset, so it already holds substantial knowledge for the task, while the student has a much smaller structure. Distiller trains the student not only on the dataset but also with the help of the teacher’s knowledge.
* Distiller makes use of the knowledge in existing pretrained large models while needing much less training time. It can also significantly improve the convergence speed and prediction accuracy of a small model, which is very helpful for inference.
![Distiller](../doc/imgs/distiller.png)

In this notebook, we distill from ViT to ResNet18 to show the basic usage of Model Adapter Distiller.
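
To make "training with the help of the teacher's knowledge" concrete, here is a minimal sketch of the classic knowledge-distillation loss for a single sample, written in plain Python so the arithmetic stays visible. It is an illustration only: the helper names, `temperature` and `alpha` are assumptions made for this sketch, and the loss that `KD` actually computes may differ.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over a list of logits, softened by a temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, label,
                      temperature=4.0, alpha=0.5):
    """Blend hard-label cross-entropy with a soft-target KL term (hypothetical helper)."""
    # Hard-label term: cross-entropy of the student against the true class.
    hard = -math.log(softmax(student_logits)[label])
    # Soft-target term: KL(teacher || student) on temperature-softened distributions.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    soft = sum(t * math.log(t / s) for t, s in zip(p_teacher, p_student) if t > 0)
    # The T^2 factor keeps the soft term on a scale comparable to the hard term.
    return alpha * hard + (1.0 - alpha) * temperature ** 2 * soft

student = [2.0, 0.5, -1.0]   # made-up student logits for a 3-class toy problem
teacher = [3.0, 1.0, -2.0]   # made-up teacher logits
loss_value = distillation_loss(student, teacher, label=0)
```

When the student's softened distribution matches the teacher's, the KL term vanishes and only the weighted cross-entropy remains.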

# Environment Setup

In [1]:
import torch
from torchvision import transforms,datasets
from torch.utils.data import DataLoader
import torch.optim as optim
import timm
import transformers
import datetime

In [2]:
import sys
sys.path.append("/home/vmagent/app/TLK/frameworks.bigdata.AIDK/AIDK/")
from TransferLearningKit.src.engine_core.transferrable_model import make_transferrable_with_knowledge_distillation
from TransferLearningKit.src.engine_core.distiller import KD

# Prepare Data

### Define Data Preprocessor

In [12]:
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343) # mean for 3 channels
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)  # std for 3 channels

train_transform = transforms.Compose([
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(),
  transforms.Resize(224),  # the pretrained model was trained on larger images, so scale 32x32 up to 224x224
  transforms.ToTensor(),
  transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
])

test_transform = transforms.Compose([
  # no random augmentation at evaluation time, so results are reproducible
  transforms.Resize(224),  # the pretrained model was trained on larger images, so scale 32x32 up to 224x224
  transforms.ToTensor(),
  transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
])
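
As a quick illustration of what the last two transforms do to the pixel values: `ToTensor` rescales 8-bit values into the range 0 to 1, and `Normalize` then applies `(x - mean) / std` per channel. The sketch below mimics that for a single RGB pixel in plain Python; `normalize_pixel` is a hypothetical helper, not part of torchvision.

```python
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

def normalize_pixel(rgb_uint8):
    """Mimic ToTensor + Normalize for one RGB pixel given as 0..255 integers."""
    scaled = [v / 255.0 for v in rgb_uint8]  # ToTensor maps 0..255 to 0..1
    return [(v - m) / s
            for v, m, s in zip(scaled, CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)]

mid_gray = normalize_pixel((128, 128, 128))  # values near the mean land close to 0
black = normalize_pixel((0, 0, 0))           # below the mean: negative values
white = normalize_pixel((255, 255, 255))     # above the mean: positive values
```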

### Prepare dataset and dataloader

In [13]:
batch_size = 128
num_workers = 1 # data worker
data_folder='./dataset' # dataset location
train_set = datasets.CIFAR100(root=data_folder, train=True, download=True, transform=train_transform)
test_set = datasets.CIFAR100(root=data_folder, train=False, download=True, transform=test_transform)

train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=False)
validate_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False)  # no need to shuffle for evaluation

Files already downloaded and verified
Files already downloaded and verified
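
For orientation, the number of steps per epoch follows directly from the dataset size and `batch_size`. The sketch below works it out for the standard CIFAR-100 split sizes (50,000 training and 10,000 test images); `batches_per_epoch` is a hypothetical helper that mirrors how `drop_last` affects the count.

```python
import math

def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields for one pass over the data."""
    if drop_last:
        return num_samples // batch_size       # the incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)  # the incomplete final batch is kept

train_steps = batches_per_epoch(50_000, 128)  # CIFAR-100 training split
test_steps = batches_per_epoch(10_000, 128)   # CIFAR-100 test split
```

With `drop_last=False`, the final batch of each split is smaller than 128, which is why the evaluation code later weights each batch by its size.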


# Create Model

### Create Backbone model

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = timm.create_model('resnet18', pretrained=False, num_classes=100).to(device)

### Define Distiller 
When defining the distiller, set `teacher_type` to a name that starts with "huggingface" if the teacher model comes from Hugging Face. Otherwise, there is no need to set it.

In [None]:
%%time
loss_fn = torch.nn.CrossEntropyLoss()
teacher_model = transformers.ViTForImageClassification.from_pretrained('edumunozsala/vit_base-224-in21k-ft-cifar100')
distiller = KD(teacher_model, teacher_type="huggingface_vit_base-224-in21k-ft-cifar100")
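
One plausible reason the teacher's origin matters (an assumption, since Distiller's internals are not shown in this notebook): Hugging Face classification models such as `ViTForImageClassification` return an output object whose predictions sit under `.logits`, whereas plain PyTorch modules usually return the tensor itself. The sketch below shows the kind of unwrapping this implies, using stand-in objects; `extract_logits` is a hypothetical helper, not Distiller's API.

```python
from types import SimpleNamespace

def extract_logits(model_output):
    """Return raw logits from either an output object or a bare tensor-like value."""
    # Hugging Face style: predictions live on the `.logits` attribute.
    # Plain PyTorch style: the returned value already is the logits.
    return getattr(model_output, "logits", model_output)

hf_style_output = SimpleNamespace(logits=[0.1, 0.9])  # stand-in for a Hugging Face output
plain_output = [0.1, 0.9]                             # stand-in for a bare tensor
```

Accessing `.logits` on a bare tensor raises an `AttributeError`, so the teacher type and the teacher's actual output format need to agree.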

### Make Model transferrable with distiller

In [6]:
model = make_transferrable_with_knowledge_distillation(model, loss_fn, distiller)

CPU times: user 8.02 s, sys: 1.97 s, total: 9.99 s
Wall time: 21.1 s


# Create Optimizer and Scheduler

In [7]:
################# create optimizer #################
init_lr = 0.01
weight_decay = 0.005
momentum = 0.9
optimizer = optim.SGD(model.parameters(), lr=init_lr, weight_decay=weight_decay, momentum=momentum)
################# create scheduler #################
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
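
`ExponentialLR` multiplies the learning rate by `gamma` each time `scheduler.step()` is called, so when it is stepped once per epoch the rate in epoch `t` is `init_lr * gamma ** t`. The plain-Python sketch below tabulates that schedule; `exponential_lr` is a hypothetical helper used only for illustration.

```python
def exponential_lr(init_lr, gamma, num_epochs):
    """Learning rate used in each epoch when the scheduler steps once per epoch."""
    return [init_lr * gamma ** epoch for epoch in range(num_epochs)]

# Same init_lr and gamma as in the cell above, shown for four epochs.
schedule = exponential_lr(init_lr=0.01, gamma=0.1, num_epochs=4)
```

With `gamma=0.1` the rate falls tenfold every epoch. That is harmless for the single epoch run here, but over a longer run the rate would shrink very quickly.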

# Create Trainer

In [16]:
max_epoch = 1 # max 1 epoch
print_interval = 10 

In [9]:
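def accuracy(output, label):
    """Fraction of samples whose top-1 prediction matches the label."""
    pred = output.detach().cpu().argmax(dim=1)
    label = label.detach().cpu()
    if label.shape == output.shape:  # one-hot or soft labels: reduce to class indices
        label = label.argmax(dim=1)
    return torch.mean((pred == label).float())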

In [26]:
class Trainer:
    def __init__(self, model, optimizer, scheduler):
        self._model = model
        self._optimizer = optimizer
        self._scheduler = scheduler
        
    def train(self, train_dataloader, valid_dataloader, max_epoch):
        ''' 
        :param train_dataloader: train dataloader
        :param valid_dataloader: validation dataloader
        :param max_epoch: maximum number of epochs to train
        '''
        for epoch in range(0, max_epoch):
            ################## train #####################
            self._model.train()  # set training flag
            for (cur_step,(data, label)) in enumerate(train_dataloader):
                data = data.to(device)
                label = label.to(device)
                self._optimizer.zero_grad()  # use the trainer's own optimizer, not the global one
                output = self._model(data)
                # loss_value = loss_fn(output, label)
                loss_value = self._model.loss(output, label) # use model.loss
                loss_value.backward()
                if cur_step%print_interval == 0:
                    # batch_acc = accuracy(output,label)
                    batch_acc = accuracy(output.backbone_output,label) # use output.backbone_output
                    dt = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') # date time
                    # print("[{}] epoch {} step {} : training batch loss {:.4f}, training batch acc {:.4f}".format(
                    #   dt, epoch, cur_step, loss_value.item(), batch_acc.item()))
                    print("[{}] epoch {} step {} : training batch loss {:.4f}, training batch acc {:.4f}".format(
                      dt, epoch, cur_step, loss_value.backbone_loss.item(), batch_acc.item())) # use loss_value.backbone_loss
                self._optimizer.step()
            self._scheduler.step()
            ################## evaluate ######################
            self.evaluate(self._model, valid_dataloader, epoch)
            
    def evaluate(self, model, valid_dataloader, epoch):
        with torch.no_grad():
            model.eval()  
            backbone = model.backbone # use backbone in evaluation
            loss_cum = 0.0
            sample_num = 0
            acc_cum = 0.0
            for (cur_step,(data, label)) in enumerate(valid_dataloader):
                data = data.to(device)
                label = label.to(device)
                # output = model(data)
                output = backbone(data)[0] # use backbone in evaluation and backbone has multi output
                batch_size = data.size(0)
                sample_num += batch_size
                loss_cum += loss_fn(output, label).item() * batch_size
                acc_cum += accuracy(output, label).item() * batch_size
            dt = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') # date time
            if sample_num > 0:
                loss_value = loss_cum/sample_num
                acc_value = acc_cum/sample_num
            else:
                loss_value = 0.0
                acc_value = 0.0

            print("[{}] epoch {} : evaluation loss {:.4f}, evaluation acc {:.4f}".format(
                dt, epoch, loss_value, acc_value))
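
The evaluation loop multiplies each batch's loss and accuracy by `batch_size` before summing, then divides by `sample_num`. This weighting matters because the last batch is smaller than the others. The plain-Python sketch below compares the weighted mean with a naive per-batch mean on made-up accuracies; `weighted_mean` is a hypothetical helper and the numbers are illustrative, not measured results.

```python
def weighted_mean(batch_values, batch_sizes):
    """Average per-batch means, weighting each by its number of samples."""
    total = sum(batch_sizes)
    if total == 0:
        return 0.0  # same fallback as evaluate() uses for an empty loader
    return sum(v * n for v, n in zip(batch_values, batch_sizes)) / total

# CIFAR-100 test split: 10,000 images at batch size 128
# gives 78 full batches plus one final batch of 16.
sizes = [128] * 78 + [16]
accs = [0.5] * 78 + [1.0]          # made-up per-batch accuracies
weighted = weighted_mean(accs, sizes)
naive = sum(accs) / len(accs)      # overweights the small final batch
```

In this toy case the naive mean comes out higher than the weighted one, because the 16-sample batch counts as much as a 128-sample batch.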

# Train and evaluate

In [18]:
%%time
trainer = Trainer(model, optimizer, scheduler)
trainer.train(train_loader,validate_loader,max_epoch)

[2022-11-15 05:00:14] epoch 0 step 0 : training batch loss 4.6086, training batch acc 0.0234
[2022-11-15 05:00:48] epoch 0 step 10 : training batch loss 4.5780, training batch acc 0.0078
[2022-11-15 05:01:22] epoch 0 step 20 : training batch loss 4.5429, training batch acc 0.0156
[2022-11-15 05:01:56] epoch 0 step 30 : training batch loss 4.4996, training batch acc 0.0391
[2022-11-15 05:02:30] epoch 0 step 40 : training batch loss 4.4805, training batch acc 0.0391
[2022-11-15 05:03:05] epoch 0 step 50 : training batch loss 4.4295, training batch acc 0.0547
[2022-11-15 05:03:39] epoch 0 step 60 : training batch loss 4.4019, training batch acc 0.0781
[2022-11-15 05:04:11] epoch 0 step 70 : training batch loss 4.3581, training batch acc 0.1016
[2022-11-15 05:04:43] epoch 0 step 80 : training batch loss 4.3802, training batch acc 0.0391
[2022-11-15 05:05:15] epoch 0 step 90 : training batch loss 4.3280, training batch acc 0.0703
[2022-11-15 05:05:47] epoch 0 step 100 : training batch loss 

AttributeError: 'Tensor' object has no attribute 'logits'