# Knowledge Distillation on BEiT3
This notebook implements the knowledge distillation between a BEiT3 VQA teacher and student on VizWiz. We used this in our project to train our KD student using the fine-tuned teacher model.

In [3]:
import sys
import os
current = os.getcwd()
parent = os.path.dirname(current)
sys.path.append(parent)

import torchvision.transforms as transforms
import datasets
import torch
from tqdm import tqdm
from loading_beit3models import load_beit3_base, load_beit3_large
from beit3_vizwiz_finetuning import initModel, getDataLoader, validate
from transformers import BatchEncoding, XLMRobertaTokenizer
from Beit3_vizwiz import VizWizDataset, get_img_names_and_questions
import torch.nn as nn
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from torch.utils.data import DataLoader, Dataset

Here we specify training settings such as hyperparameters (initialized from BEiT3 paper) and the student epoch to start the training with. Additionally we load the training and validation data.

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)
print(f'Loading student model..')
#student = load_beit3_base()
starting_epoch = 6
if starting_epoch <= 0:
    ckp_path = f'./models/base/beit3_base_indomain_patch16_480_vqa.pth'
    is_trained = False
else:
    ckp_path = f'./kd_student/vizwiz_checkpoint_kd_epoch{starting_epoch}.tar'
    is_trained = True
student = initModel(ckp_path, 'base', is_compiled_model_checkpoint=True, is_pretrained=is_trained)
print("compiling model...")
student = torch.compile(student)
print("finished compiling")
# Using TensorFloat32 Cores for better performance
torch.set_float32_matmul_precision('high')
# checkpoint = torch.load('beit-distillation/base/vizwiz_checkpoint_base_epoch6.tar')
# student.load_state_dict(checkpoint['model_state_dict'])
# optim.load_state_dict(checkpoint['optimizer_state_dict'])
# print(f'Loading teacher model..')
# teacher = initModel(f'/teamspace/studios/internal-yellow-dr2gf-zycm/beit-distillation/models/large/vizwiz_checkpoint_epoch15_large.tar', 'large', is_compiled_model_checkpoint=True)
#parameters from beit3 vqva2 finetuning
print(f'Prepping..')
lr = 2e-5
opt_betas = (0.9, 0.98)
weight_decay = 0.01
tokenizer = XLMRobertaTokenizer("../models/beit3.spm")
batch_size = 12

print(f"Loading dataset...")
vizwiz_path = "/teamspace/studios/internal-yellow-dr2gf/VizWiz"
train_loader = getDataLoader(tokenizer=tokenizer, batch_size=batch_size, data_dir=vizwiz_path, split='train')
val_loader = getDataLoader(tokenizer=tokenizer, batch_size=batch_size, data_dir=vizwiz_path, split='val')
print(f"Finished loading dataset.")

# freeze all parameters except for the head
# for name, value in student.named_parameters():
#    if not "head" in name:
#        value.requires_grad = False

Using device:  cpu
Loading student model..


FileNotFoundError: [Errno 2] No such file or directory: './kd_student/vizwiz_checkpoint_kd_epoch6.tar'

Here we define the KD training and evaluation loop, which in our case trains the entire student network on a linear combination between normal training loss on the labels and the loss between student and teacher logits. After each epoch the resulting checkpoint including training and validation loss is saved to the output folder.

In [9]:
# adapted from https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html
def train_knowledge_distillation(student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, loss_weight, device, starting_epoch=0, checkpoint_path=None):
    criterion = nn.BCEWithLogitsLoss(reduction='mean')
    optim = torch.optim.AdamW(params = student.parameters(), lr = learning_rate, betas = opt_betas, weight_decay=weight_decay)
    if checkpoint_path is not None:
        ckp=torch.load(checkpoint_path)
        optim.load_state_dict(ckp["optimizer_state_dict"])
    # since logits should not change between runs we simply read them out of the previously calculated file
    all_teacher_logits = load_teacher_logits()
    
    
    folder = f"./kd_student"
    if not os.path.exists(folder):
        os.mkdir(folder)
    for epoch in range(starting_epoch+1, starting_epoch+epochs):
        running_loss = 0.0
        student.train() # Student to train mode
        for data in tqdm(train_loader):
            img = data["image"].to(device)
            q_tokens = data["language_tokens"].to(device)
            labels = data["labels"].to(device)
            optim.zero_grad()
        
            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            # Since the teachers outputs should not change between epochs, we simply use the pre-calculated logits rather than recalculate them in every epoch.
            # with torch.no_grad():
            #    teacher_logits = teacher.forward(image=img, question=q_tokens)
            #    teacher_logits = teacher_logits.float()
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                teacher_logits = torch.stack([all_teacher_logits[id] for id in data["image_id"]])
                
                # Forward pass with the student model
                student_logits = student.forward(image=img, question=q_tokens)
                student_logits = student_logits.float()
                
                #Soften the student logits by applying softmax first and log() second
                soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
                soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
    
                # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
                soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)
    
                # Calculate the true label loss
                label_loss = criterion(student_logits, labels)
    
                # Weighted sum of the two losses
                loss = soft_target_loss_weight * soft_targets_loss + loss_weight * label_loss
    
                loss.backward()
                optim.step()
            running_loss += loss.item()
        student.eval()
        with torch.no_grad():
            val_loss = validate(
                tokenizer=tokenizer,
                criterion=criterion,
                model=student,
                val_loader=val_loader,
                device=device
            )
        torch.save({
            'epoch': epoch,
            'model_state_dict': student.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            'loss': running_loss / len(train_loader),
            'val_loss': val_loss
            }, os.path.join(folder, f'vizwiz_checkpoint_kd_epoch{epoch}.tar'))
        print(f"Epoch {epoch}/{epochs +starting_epoch}, Loss: {running_loss / len(train_loader)}, Val. Loss: {val_loss}")    

These functions are used to acquire the teacher logits on VizWiz for the KD training by inferencing with the teacher model. Once you have written the teacher logits you can simply load them, since they won't change as long as the teacher doesn't change, in which case you would need to write them again.

In [4]:
def write_teacher_logits(teacher, train_loader, device, filename=f'teacher_logits.tar'):
    teacher.eval()
    logits = {}
    for data in tqdm(train_loader):
        img = data["image"].to(device)
        q_tokens = data["language_tokens"].to(device)
        labels = data["labels"].to(device)
        with torch.no_grad():
                teacher_logits = teacher.forward(image=img, question=q_tokens)
                teacher_logits = teacher_logits.float()
        for id, logs in zip(data["image_id"], teacher_logits):
            logits[id] = logs
    torch.save(logits, filename)

def load_teacher_logits(path='./teacher_logits.tar'):
    return torch.load(path)

Starts the training process. Additionally some KD specific training parameters including temperature, the weights for the combined loss and the number of epochs to train on are set here.

In [None]:
T = 2
soft_target_loss_weight = 0.25
original_loss_weight = 1 - soft_target_loss_weight
epochs = 5
train_knowledge_distillation(student, train_loader, epochs, lr, T, soft_target_loss_weight, original_loss_weight, device, starting_epoch=starting_epoch, checkpoint_path=ckp_path)

100%|██████████| 1711/1711 [26:03<00:00,  1.09it/s]
100%|██████████| 360/360 [03:53<00:00,  1.54it/s]


Validation loss: 0.0189
Epoch 7/11, Loss: 0.10276553794210379, Val. Loss: 0.01893749316740367


 79%|███████▉  | 1352/1711 [20:41<05:23,  1.11it/s]

Gives an overview over all trained epochs and their respective losses.

In [7]:
student_dir = './kd_student'
for file in os.listdir(student_dir):
    ckp = torch.load(os.path.join(student_dir, file),map_location=torch.device('cpu'))
    print(f"File: {file}, Epoch: {ckp['epoch']}, loss: {ckp['loss']}, val_loss: {ckp['val_loss']}")

File: vizwiz_checkpoint_kd_epoch1.tar, Epoch: 1, loss: 0.5062050504216947, val_loss: 0.02807389839629953
File: vizwiz_checkpoint_kd_epoch10.tar, Epoch: 10, loss: 0.07444856881503123, val_loss: 0.01825543609625634
File: vizwiz_checkpoint_kd_epoch2.tar, Epoch: 2, loss: 0.297218958086817, val_loss: 0.023673854132842582
File: vizwiz_checkpoint_kd_epoch3.tar, Epoch: 3, loss: 0.21412093140520178, val_loss: 0.021490420229060368
File: vizwiz_checkpoint_kd_epoch4.tar, Epoch: 4, loss: 0.16911141843056554, val_loss: 0.020520159855368546
File: vizwiz_checkpoint_kd_epoch5.tar, Epoch: 5, loss: 0.1395778000389603, val_loss: 0.019529018430168636
File: vizwiz_checkpoint_kd_epoch6.tar, Epoch: 6, loss: 0.11840906368229975, val_loss: 0.019089714166410785
File: vizwiz_checkpoint_kd_epoch7.tar, Epoch: 7, loss: 0.10276553794210379, val_loss: 0.01893749316740367
File: vizwiz_checkpoint_kd_epoch8.tar, Epoch: 8, loss: 0.0915182161203154, val_loss: 0.018791058071656153
File: vizwiz_checkpoint_kd_epoch9.tar, Epoc