In [None]:
!pip install tensorboard tensorboardX -q
import math
import os
import time
from tqdm import tqdm
from collections import OrderedDict
import getpass
from tensorboardX import SummaryWriter
import numpy as np
import sys

from __future__ import absolute_import
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn as nn
import torch.nn.functional as F


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torchvision.models import ResNet
import torchvision.transforms as transforms
from torchvision.models import resnet18, resnet101, resnet34
from torch.utils.data import DataLoader
import math
import torch.optim as optim
import matplotlib.pyplot as plt
import wandb

In [None]:
def perception(logits, epsilon=1e-5):
    """
    Standardise logits with statistics computed across dim 0.

    Every position along the remaining dimensions is shifted by its mean over
    dim 0 and divided by sqrt(biased variance over dim 0 + epsilon).

    Parameters:
    logits (torch.Tensor): Logits; described by the author as shape (B, N),
        B being the batch size and N the number of classes.
    epsilon (float): Added to the variance before the square root so the
        division stays finite when the variance is zero.

    Returns:
    torch.Tensor: Standardised logits, same shape as `logits`.
    """
    mu = logits.mean(dim=0, keepdim=True)
    sigma_sq = logits.var(dim=0, keepdim=True, unbiased=False)
    return (logits - mu) / (sigma_sq + epsilon).sqrt()
    

def luminet_loss(logits_student, logits_teacher, target, alpha, temperature):
    """
    KL divergence between teacher and student after `perception` standardisation.

    Both logit tensors are passed through `perception`, divided by
    `temperature`, and turned into a teacher softmax and a student
    log-softmax over dim 1. The batch-mean KL divergence is then scaled
    by `alpha ** 2`.

    Parameters:
    logits_student (torch.Tensor): Student logits.
    logits_teacher (torch.Tensor): Teacher logits.
    target: Accepted for interface compatibility; not used by this function.
    alpha (float): Loss weight; the result is multiplied by its square.
    temperature (float): Softmax temperature applied to both sides.

    Returns:
    torch.Tensor: Scalar loss.
    """
    student_view = perception(logits_student)
    teacher_view = perception(logits_teacher)

    teacher_probs = F.softmax(teacher_view / temperature, dim=1)
    student_log_probs = F.log_softmax(student_view / temperature, dim=1)

    divergence = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    return divergence * alpha**2

# NOTE: `perception` is also defined earlier in this cell with equivalent
# behaviour; because this definition comes later, it is the one left bound.
def perception(logits, epsilon=1e-5):
    """
    Normalise logits to zero mean / unit (biased) variance along dim 0.

    Parameters:
    logits (torch.Tensor): Logits; described by the author as shape (B, N),
        B being the batch size and N the number of classes.
    epsilon (float): Small constant added to the variance to avoid a
        division by zero.

    Returns:
    torch.Tensor: Normalised logits with the same shape as the input.
    """
    centered = logits - torch.mean(logits, dim=0, keepdim=True)
    denom = torch.sqrt(
        torch.var(logits, dim=0, keepdim=True, unbiased=False) + epsilon
    )
    return centered / denom
    

def normalize(logit):
    """
    Z-score `logit` along its last dimension.

    Subtracts the mean over the last dimension and divides by the standard
    deviation over that dimension (`Tensor.std` with its default settings)
    plus 1e-7, which keeps the division finite for constant rows.
    """
    centre = logit.mean(dim=-1, keepdims=True)
    spread = logit.std(dim=-1, keepdims=True)
    shifted = logit - centre
    return shifted / (1e-7 + spread)

def kd_loss(logits_student_in, logits_teacher_in, temperature, logit_stand):
    """
    Temperature-scaled KL distillation loss between student and teacher logits.

    When `logit_stand` is truthy both logit tensors are first passed through
    `normalize`. The KL divergence between the teacher softmax and the student
    log-softmax (both over dim 1, at `temperature`) is summed over dim 1,
    averaged over the remaining elements and multiplied by `temperature ** 2`.

    Returns:
    torch.Tensor: Scalar loss.
    """
    if logit_stand:
        student_logits = normalize(logits_student_in)
        teacher_logits = normalize(logits_teacher_in)
    else:
        student_logits = logits_student_in
        teacher_logits = logits_teacher_in

    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    per_sample = F.kl_div(student_log_probs, teacher_probs, reduction="none").sum(1)
    return per_sample.mean() * temperature**2

class Distiller(nn.Module):
    """
    Base class pairing a trainable student with a teacher network.

    Subclasses implement `forward_train`; in eval mode the module simply
    forwards the image through the student.
    """

    def __init__(self, student, teacher):
        super().__init__()
        self.student = student
        self.teacher = teacher

    def train(self, mode=True):
        """Switch train/eval mode for all children, then force the teacher into eval mode."""
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for child in self.children():
            child.train(mode)
        # The teacher stays in eval mode regardless of `mode`.
        self.teacher.eval()
        return self

    def get_learnable_parameters(self):
        """Return the student's parameters as a list (override if the method adds its own)."""
        return [param for _, param in self.student.named_parameters()]

    def get_extra_parameters(self):
        """Return the count of extra parameters introduced by the distiller; 0 for the base class."""
        return 0

    def forward_train(self, **kwargs):
        """Training-time forward pass; must be provided by subclasses."""
        raise NotImplementedError()

    def forward_test(self, image):
        """Evaluation-time forward pass: the student's output for `image`."""
        return self.student(image)

    def forward(self, **kwargs):
        """Dispatch on `self.training`; the eval path reads only `kwargs["image"]`."""
        if not self.training:
            return self.forward_test(kwargs["image"])
        return self.forward_train(**kwargs)

class DTKD(Distiller):
    """
    Distiller combining a cross-entropy term with the `kd_loss` distillation term.

    The hyperparameters used to be hard-coded in `__init__`; they are now
    keyword arguments whose defaults reproduce the previous values, so
    `DTKD(student, teacher)` behaves exactly as before.

    Parameters:
    student (nn.Module): Network being trained.
    teacher (nn.Module): Network providing the target logits.
    temperature (float): Temperature passed to `kd_loss`.
    ce_loss_weight (float): Multiplier for the cross-entropy loss.
    kd_loss_weight (float): Multiplier for the distillation loss.
    logit_stand (bool): Forwarded to `kd_loss`; when true both logit tensors
        are standardised before the softmax.
    """

    def __init__(
        self,
        student,
        teacher,
        temperature=2,
        ce_loss_weight=0.1,
        kd_loss_weight=9,
        logit_stand=True,
    ):
        super(DTKD, self).__init__(student, teacher)
        self.temperature = temperature
        self.ce_loss_weight = ce_loss_weight
        self.kd_loss_weight = kd_loss_weight
        self.logit_stand = logit_stand

    def forward_train(self, image, target, **kwargs):
        """
        Compute student logits and the weighted losses for one batch.

        Extra keyword arguments (the trainer passes `epoch`) are accepted and ignored.

        Returns:
        tuple: `(logits_student, losses_dict)` where `losses_dict` has the
        keys "loss_ce" and "loss_kd".
        """
        logits_student = self.student(image)
        # The teacher only supplies targets, so no graph is built for it.
        with torch.no_grad():
            logits_teacher = self.teacher(image)

        # losses
        loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
        loss_kd = self.kd_loss_weight * kd_loss(
            logits_student, logits_teacher, self.temperature, self.logit_stand
        )
        losses_dict = {
            "loss_ce": loss_ce,
            "loss_kd": loss_kd,
        }
        return logits_student, losses_dict
        
class BaseTrainer(object):
    """
    Train a distiller with SGD, evaluate it after every epoch, and persist
    logs and checkpoints.

    Everything is written under ``./output/<experiment_name>``: TensorBoard
    events, a plain-text ``worklog.txt`` and the checkpoint files.

    Parameters:
    experiment_name (str): Name of the output sub-directory.
    distiller: Module exposing `get_learnable_parameters()`, a `student`
        attribute and a keyword-argument `forward`.
    train_loader: Iterable with `len()` yielding `(image, target)` batches.
    val_loader: Loader handed to `validate` after every epoch.
    """

    def __init__(
        self,
        experiment_name,
        distiller,
        train_loader,
        val_loader
    ):
        self.distiller = distiller
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = torch.optim.SGD(
            self.distiller.get_learnable_parameters(),
            lr=0.05,
            weight_decay=5e-4,
            momentum=0.9
        )
        self.best_acc = -1

        # init loggers
        self.log_path = os.path.join("./output", experiment_name)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.log_path, exist_ok=True)
        self.tf_writer = SummaryWriter(os.path.join(self.log_path, "train.events"))

    def adjust_learning_rate(self, epoch, optimizer):
        """
        Apply the step schedule: lr = 0.05 * 0.1 ** (number of milestones passed).

        Milestones are epochs 62, 75 and 87 (a milestone counts once
        `epoch` is strictly greater than it). The optimizer is only touched
        once at least one milestone has been passed.

        Returns:
        float: The learning rate in effect for `epoch`.
        """
        # int() turns the NumPy scalar into a plain Python int, so the lr
        # stored in the optimizer (and later pickled into checkpoints via
        # optimizer.state_dict()) is a plain float rather than np.float64.
        steps = int(np.sum(epoch > np.asarray([62, 75, 87])))
        if steps > 0:
            new_lr = 0.05 * (0.1**steps)
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_lr
            return new_lr
        return 0.05

    def log(self, lr, epoch, log_dict):
        """
        Write one epoch's metrics to TensorBoard and append them to worklog.txt.

        Also raises `self.best_acc` to `log_dict["test_acc"]` when that value
        is strictly greater than the current best.

        Parameters:
        lr (float): Learning rate used for the epoch.
        epoch (int): Epoch index, used as the TensorBoard step.
        log_dict (dict): Metric name -> value; must contain "test_acc".
        """
        # tensorboard log
        for k, v in log_dict.items():
            self.tf_writer.add_scalar(k, v, epoch)
        self.tf_writer.flush()

        if log_dict["test_acc"] > self.best_acc:
            self.best_acc = log_dict["test_acc"]

        # worklog.txt
        # The file is opened in text mode, where "\n" is already translated to
        # the platform line ending; writing os.linesep would double the "\r"
        # on Windows.
        with open(os.path.join(self.log_path, "worklog.txt"), "a") as writer:
            lines = [
                "-" * 25 + "\n",
                "epoch: {}".format(epoch) + "\n",
                # Six decimals: with two, the decayed rates 0.005 / 0.0005 /
                # 0.00005 were logged as 0.01 / 0.00 / 0.00.
                "lr: {:.6f}".format(float(lr)) + "\n",
            ]
            for k, v in log_dict.items():
                lines.append("{}: {:.2f}".format(k, v) + "\n")
            lines.append("-" * 25 + "\n")
            writer.writelines(lines)

    def train(self, resume=False, num_epochs=100):
        """
        Run epochs up to and including `num_epochs`, then record the best accuracy.

        Parameters:
        resume (bool): When true, restore model, optimizer, epoch counter and
            best accuracy from the "latest" checkpoint in the log directory
            and continue with the following epoch.
        num_epochs (int): Index of the last epoch to run (epochs start at 1).
        """
        epoch = 1
        if resume:
            state = load_checkpoint(os.path.join(self.log_path, "latest"))
            epoch = state["epoch"] + 1
            self.distiller.load_state_dict(state["model"])
            self.optimizer.load_state_dict(state["optimizer"])
            self.best_acc = state["best_acc"]
        while epoch < num_epochs + 1:
            self.train_epoch(epoch)
            epoch += 1
        print(log_msg("Best accuracy:{}".format(self.best_acc), "EVAL"))
        with open(os.path.join(self.log_path, "worklog.txt"), "a") as writer:
            writer.write("best_acc\t" + "{:.2f}".format(float(self.best_acc)))

    def train_epoch(self, epoch):
        """
        Train for one epoch, validate, log the metrics and save checkpoints.

        Besides the files on disk, this appends the epoch's losses and
        accuracies to the module-level lists `dtkd_losses` and `dtkd_accuracies`.
        """
        lr = self.adjust_learning_rate(epoch, self.optimizer)
        train_meters = {
            "training_time": AverageMeter(),
            "data_time": AverageMeter(),
            "losses": AverageMeter(),
            "top1": AverageMeter(),
            "top5": AverageMeter(),
        }
        num_iter = len(self.train_loader)
        pbar = tqdm(range(num_iter))

        # train loops
        self.distiller.train()
        try:
            for data in self.train_loader:
                msg, _ = self.train_iter(data, epoch, train_meters)
                pbar.set_description(log_msg(msg, "TRAIN"))
                pbar.update()
        finally:
            # Close the bar even if a batch raises, so it is not left dangling.
            pbar.close()
        # Same value train_iter returned for the last batch, but also defined
        # (as 0) when the loader yields no batches at all.
        train_loss = train_meters["losses"].avg

        test_acc, test_acc_top5, test_loss = validate(self.val_loader, self.distiller)

        # NOTE(review): dtkd_losses / dtkd_accuracies are notebook globals that
        # are not defined in this cell; they must exist before training starts
        # (confirm where they are created).
        dtkd_losses.append({"train_loss": train_loss, "test_loss": test_loss})
        # float() yields the same Python float as .item() for 0-dim tensors and
        # also accepts the plain 0 an untouched meter reports.
        dtkd_accuracies.append(
            {"acc@1": float(test_acc), "acc@5": float(test_acc_top5)}
        )
        # log
        log_dict = OrderedDict(
            {
                "train_acc": train_meters["top1"].avg,
                "train_loss": train_meters["losses"].avg,
                "test_acc": test_acc,
                "test_acc_top5": test_acc_top5,
                "test_loss": test_loss,
            }
        )
        self.log(lr, epoch, log_dict)
        # saving checkpoint
        state = {
            "epoch": epoch,
            "model": self.distiller.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "best_acc": self.best_acc,
        }
        student_state = {"model": self.distiller.student.state_dict()}
        save_checkpoint(state, os.path.join(self.log_path, "latest"))
        save_checkpoint(
            student_state, os.path.join(self.log_path, "student_latest")
        )
        if epoch % 20 == 0:
            save_checkpoint(
                state, os.path.join(self.log_path, "epoch_{}".format(epoch))
            )
            save_checkpoint(
                student_state,
                os.path.join(self.log_path, "student_{}".format(epoch)),
            )
        # update the best
        # self.log() above has already raised best_acc to test_acc when this
        # epoch set a record, so ">=" holds exactly for a new best or a tie.
        if test_acc >= self.best_acc:
            save_checkpoint(state, os.path.join(self.log_path, "best"))
            save_checkpoint(
                student_state, os.path.join(self.log_path, "student_best")
            )

    def train_iter(self, data, epoch, train_meters):
        """
        Run one optimisation step on a single batch and update the meters.

        Parameters:
        data: `(image, target)` pair as yielded by the train loader.
        epoch (int): Current epoch, forwarded to the distiller.
        train_meters (dict): Meters keyed "training_time", "data_time",
            "losses", "top1" and "top5"; updated in place.

        Returns:
        tuple: `(msg, avg_loss)`, the progress-bar text and the running
        average loss for the epoch so far.
        """
        self.optimizer.zero_grad()
        train_start_time = time.time()
        image, target = data
        # NOTE(review): the timer starts after the batch has already been
        # fetched, so "data_time" only covers the tuple unpacking above.
        train_meters["data_time"].update(time.time() - train_start_time)
        image = image.float()
        image = image.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward
        preds, losses_dict = self.distiller(image=image, target=target, epoch=epoch)

        # backward
        loss = sum(l.mean() for l in losses_dict.values())
        loss.backward()
        self.optimizer.step()
        train_meters["training_time"].update(time.time() - train_start_time)
        # collect info
        batch_size = image.size(0)
        acc1, acc5 = accuracy(preds, target, topk=(1, 5))
        train_meters["losses"].update(loss.cpu().detach().numpy().mean(), batch_size)
        train_meters["top1"].update(acc1[0], batch_size)
        train_meters["top5"].update(acc5[0], batch_size)
        # print info
        msg = "Epoch:{}| Time(data):{:.3f}| Time(train):{:.3f}| Loss:{:.4f}| Top-1:{:.3f}| Top-5:{:.3f}".format(
            epoch,
            train_meters["data_time"].avg,
            train_meters["training_time"].avg,
            train_meters["losses"].avg,
            train_meters["top1"].avg,
            train_meters["top5"].avg,
        )
        return (msg, train_meters["losses"].avg)
        
class AverageMeter(object):
    """Keeps the most recent value together with a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest value, weighted by `n`, and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def validate(val_loader, distiller):
    """
    Evaluate `distiller` on `val_loader` without tracking gradients.

    The distiller is switched to eval mode and every batch is moved to CUDA.
    Cross-entropy loss and top-1 / top-5 accuracy are averaged over the
    loader, weighted by batch size, while a tqdm bar shows the running
    accuracies.

    Returns:
    tuple: `(top1_avg, top5_avg, loss_avg)`.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    criterion = nn.CrossEntropyLoss()
    pbar = tqdm(range(len(val_loader)))

    distiller.eval()
    with torch.no_grad():
        tick = time.time()
        for image, target in val_loader:
            image = image.float().cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            output = distiller(image=image)
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            n_samples = image.size(0)
            losses.update(loss.cpu().detach().numpy().mean(), n_samples)
            top1.update(acc1[0], n_samples)
            top5.update(acc5[0], n_samples)

            # per-batch wall-clock time (tracked but not returned)
            batch_time.update(time.time() - tick)
            tick = time.time()

            msg = "Top-1:{top1.avg:.3f}| Top-5:{top5.avg:.3f}".format(
                top1=top1, top5=top5
            )
            pbar.set_description(log_msg(msg, "EVAL"))
            pbar.update()
    pbar.close()
    return top1.avg, top5.avg, losses.avg

def log_msg(msg, mode="INFO"):
    """
    Wrap `msg` in an ANSI colour escape and prefix it with `[mode]`.

    `mode` must be "INFO", "TRAIN" or "EVAL"; any other value raises KeyError.
    """
    ansi_codes = {"INFO": 36, "TRAIN": 32, "EVAL": 31}
    colour = ansi_codes[mode]
    return "\033[{}m[{}] {}\033[0m".format(colour, mode, msg)

def accuracy(output, target, topk=(1,)):
    """
    Top-k accuracy, in percent, for every k in `topk`.

    Parameters:
    output (torch.Tensor): Scores; the k highest entries along dim 1 are
        taken as the predictions.
    target (torch.Tensor): Ground-truth labels, one per row of `output`.
    topk (tuple): The k values to evaluate.

    Returns:
    list: One tensor of shape (1,) per k, holding the percentage of rows
    whose label is among the k highest-scoring entries.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)
        _, top_idx = output.topk(largest_k, dim=1, largest=True, sorted=True)
        # Transpose so that row i holds the (i+1)-th ranked prediction for every sample.
        top_idx = top_idx.t()
        hits = top_idx.eq(target.reshape(1, -1).expand_as(top_idx))
        scale = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]

def save_checkpoint(obj, path):
    """Serialise `obj` with `torch.save` into the file at `path` (opened in "wb" mode)."""
    with open(path, "wb") as sink:
        torch.save(obj, sink)

def load_checkpoint(path):
    """Read the object stored at `path` with `torch.load`, mapping tensors onto the CPU."""
    with open(path, "rb") as source:
        state = torch.load(source, map_location="cpu")
    return state

In [None]:
# Build the student, wrap it with the teacher in a DTKD distiller, and train.
# NOTE(review): cifar_model_dict, student, num_classes, teacher_model,
# train_loader, val_loader and max_epoch are not defined in the cells shown
# here; they must already exist in the kernel before this cell runs.

# Each registry entry unpacks into a model constructor and a second value
# bound to `path`, which is not used in this cell.
student_model, path = cifar_model_dict[student]
student_model = student_model(num_classes=num_classes)
student_model.to("cuda", non_blocking=True)

distiller = DTKD(student_model, teacher_model)

# Initialize the BaseTrainer; its outputs go to ./output/DTKD
trainer = BaseTrainer(
    experiment_name="DTKD",
    distiller=distiller,
    train_loader=train_loader, 
    val_loader=val_loader
)

trainer.train(num_epochs=max_epoch)

In [None]:
import matplotlib.pyplot as plt

def plot_losses():
    """
    Plot DTKD vs. "our" loss curves in two stacked panels (train on top, test below).

    Reads the module-level lists `dtkd_losses` and `our_losses`, whose entries
    are dicts with 'train_loss' and 'test_loss' keys; the list position is
    used as the x value (labelled "Epoch").
    """
    # (subplot position, history key, panel title, DTKD label, our label)
    panels = (
        (1, 'train_loss', "Train Losses", "DTKD Train Loss", "Our Train Loss"),
        (2, 'test_loss', "Test Losses", "DTKD Test Loss", "Our Test Loss"),
    )

    # Pull every series out before any figure is created.
    curves = {
        key: (
            [entry[key] for entry in dtkd_losses],
            [entry[key] for entry in our_losses],
        )
        for _, key, _, _, _ in panels
    }

    plt.figure(figsize=(8, 6))
    for position, key, title, dtkd_label, our_label in panels:
        dtkd_curve, our_curve = curves[key]
        plt.subplot(2, 1, position)
        plt.plot(dtkd_curve, label=dtkd_label, color='blue')
        plt.plot(our_curve, label=our_label, color='red')
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.grid()

    plt.tight_layout()  # keep the two panels from overlapping
    plt.show()

# Render the loss comparison. plot_losses() reads the globals dtkd_losses and
# our_losses; our_losses is not created in the cells shown here (TODO confirm
# it is populated before this cell runs).
plot_losses()

In [None]:
def plot_accuracies():
    """
    Plot the top-1 accuracy curves of DTKD and "our" method on a single figure.

    Reads the module-level lists `dtkd_accuracies` and `our_accuracies`, whose
    entries are dicts with an 'acc@1' key; the list position is used as the
    x value (labelled "Epoch").
    """
    # (legend label, line colour, values)
    curves = (
        ("DTKD acc@1", 'blue', [entry['acc@1'] for entry in dtkd_accuracies]),
        ("Our acc@1", 'red', [entry['acc@1'] for entry in our_accuracies]),
    )

    plt.figure(figsize=(8, 6))
    for label, colour, values in curves:
        plt.plot(values, label=label, color=colour)

    # Graph details
    plt.title("Accuracy Comparison", fontsize=16)
    plt.xlabel("Epoch", fontsize=12)
    plt.ylabel("Accuracy (%)", fontsize=12)
    plt.legend()
    plt.grid()
    plt.tight_layout()

    plt.show()
    
# Render the accuracy comparison. plot_accuracies() reads the globals
# dtkd_accuracies and our_accuracies; our_accuracies is not created in the
# cells shown here (TODO confirm it is populated before this cell runs).
plot_accuracies()