In [None]:
!pip install tensorboard tensorboardX -q
import math
import os
import time
from tqdm import tqdm
from collections import OrderedDict
import getpass
from tensorboardX import SummaryWriter
import numpy as np
import sys

from __future__ import absolute_import
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn as nn
import torch.nn.functional as F


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torchvision.models import ResNet
import torchvision.transforms as transforms
from torchvision.models import resnet18, resnet101, resnet34
from torch.utils.data import DataLoader
import math
import torch.optim as optim
import matplotlib.pyplot as plt
import wandb

In [None]:
# Never commit API keys: take the token from the WANDB_API_KEY environment
# variable. When it is unset, key=None lets wandb fall back to its own
# credential lookup / interactive prompt.
# NOTE(review): the key that used to be hard-coded on this line has been
# exposed in the notebook history and should be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
# Module-level history buffers. BaseTrainer.train_epoch appends one dict per
# epoch to the dtkd_* lists; the our_* lists are presumably filled the same way
# by code that is not visible in this part of the notebook.
dtkd_losses, dtkd_accuracies = [], []
our_losses, our_accuracies = [], []


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + BatchNorm layers.

    The output has ``expansion * planes`` channels. When the stride is not 1 or
    the channel count changes, the skip path is a 1x1 conv + BatchNorm
    projection; otherwise it is an empty ``nn.Sequential`` (identity).

    Args:
        in_planes: number of input channels.
        planes: number of channels produced by both convolutions.
        stride: stride of the first convolution and of the projection shortcut.
        is_last: when True, ``forward`` returns ``(out, preact)`` instead of
            just ``out``.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        # Zero-argument super(): the name ``BasicBlock`` is rebound by a later
        # class definition in this notebook, and ``super(BasicBlock, self)``
        # would then raise TypeError whenever this class is instantiated.
        super().__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        """Return ``relu(residual + shortcut)``; also the pre-ReLU sum if ``is_last``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        # Keep a handle on the pre-activation tensor; F.relu below is not
        # in-place, so ``preact`` stays untouched.
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block with ``expansion * planes`` output channels.

    The 3x3 convolution carries the stride. When the stride is not 1 or the
    channel count changes, the skip path is a 1x1 conv + BatchNorm projection;
    otherwise it is an empty ``nn.Sequential`` (identity).

    Args:
        in_planes: number of input channels.
        planes: width of the two inner convolutions.
        stride: stride of the 3x3 convolution and of the projection shortcut.
        is_last: when True, ``forward`` returns ``(out, preact)`` instead of
            just ``out``.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        # Zero-argument super(): the name ``Bottleneck`` is rebound by a later
        # class definition in this notebook, and ``super(Bottleneck, self)``
        # would then raise TypeError whenever this class is instantiated.
        super().__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, self.expansion * planes, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        """Return ``relu(residual + shortcut)``; also the pre-ReLU sum if ``is_last``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        # F.relu below is not in-place, so ``preact`` keeps the pre-ReLU values.
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        return out


class ResNet(nn.Module):
    """Four-stage ResNet with a 3x3, stride-1 stem.

    The stages use base widths 64/128/256/512 (multiplied by
    ``block.expansion``); stages 2-4 start with a stride-2 block.

    Args:
        block: residual block class. It must expose ``expansion`` and accept
            ``(in_planes, planes, stride, is_last)``; the last block of every
            stage is built with ``is_last=True`` and must return
            ``(out, preact)``.
        num_blocks: sequence with the number of blocks in each of the 4 stages.
        num_classes: size of the final linear layer.
        zero_init_residual: zero the weight of the last BatchNorm of every
            residual block so each block starts as an identity mapping.
    """

    def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False):
        # Zero-argument super(): the name ``ResNet`` is rebound by a later
        # class definition in this notebook, and ``super(ResNet, self)`` would
        # then raise TypeError whenever this class is instantiated.
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for stage in self._stages():
                for blk in stage:
                    last_bn = self._last_bn(blk)
                    if last_bn is not None:
                        nn.init.constant_(last_bn.weight, 0)

        # Output width of every stage. Derived from the block's expansion so
        # BasicBlock models report 64..512 and Bottleneck models 256..2048.
        self.stage_channels = [
            width * block.expansion for width in (64, 128, 256, 512)
        ]

    def _stages(self):
        """Return the four residual stages in forward order."""
        return (self.layer1, self.layer2, self.layer3, self.layer4)

    @staticmethod
    def _last_bn(blk):
        """Return the BatchNorm applied right before a block's final ReLU.

        Looked up by attribute (``bn3`` for bottleneck-style blocks, ``bn2``
        for basic blocks) rather than by ``isinstance``, because the block
        class names are rebound later in this notebook. Returns None when the
        block has neither attribute.
        """
        for attr in ("bn3", "bn2"):
            bn = getattr(blk, attr, None)
            if bn is not None:
                return bn
        return None

    def get_feat_modules(self):
        """Return stem conv, stem BN and the four stages as an ``nn.ModuleList``."""
        feat_m = nn.ModuleList([])
        feat_m.append(self.conv1)
        feat_m.append(self.bn1)
        feat_m.append(self.layer1)
        feat_m.append(self.layer2)
        feat_m.append(self.layer3)
        feat_m.append(self.layer4)
        return feat_m

    def get_bn_before_relu(self):
        """Return the final BatchNorm of the last block of each stage.

        Raises:
            NotImplementedError: if a block exposes neither ``bn3`` nor ``bn2``.
        """
        bns = [self._last_bn(stage[-1]) for stage in self._stages()]
        if any(bn is None for bn in bns):
            raise NotImplementedError("ResNet unknown block error !!!")
        return bns

    def get_stage_channels(self):
        """Return the list of per-stage output channel counts."""
        return self.stage_channels

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block takes ``stride``, the rest stride 1."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i in range(num_blocks):
            stride = strides[i]
            # Only the final block of the stage also returns its pre-activation.
            layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def encode(self, x, idx, preact=False):
        """Run ``relu(x)`` through one stage and return its pre-activation output.

        ``idx`` selects the stage: -1 -> layer4, -2 -> layer3, -3 -> layer2.
        The ``preact`` argument is accepted but not used by this method.

        Raises:
            NotImplementedError: for any other ``idx``.
        """
        if idx == -1:
            _, pre = self.layer4(F.relu(x))
        elif idx == -2:
            _, pre = self.layer3(F.relu(x))
        elif idx == -3:
            _, pre = self.layer2(F.relu(x))
        else:
            raise NotImplementedError()
        return pre

    def forward(self, x):
        """Return ``(logits, feats)``.

        ``feats`` is a dict with:
            "feats": post-ReLU outputs of the stem and of each stage,
            "preact_feats": stem output followed by each stage's pre-ReLU output,
            "pooled_feat": the flattened average-pooled feature fed to the classifier.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        f0 = out
        out, f1_pre = self.layer1(out)
        f1 = out
        out, f2_pre = self.layer2(out)
        f2 = out
        out, f3_pre = self.layer3(out)
        f3 = out
        out, f4_pre = self.layer4(out)
        f4 = out
        out = self.avgpool(out)
        avg = out.reshape(out.size(0), -1)
        out = self.linear(avg)

        feats = {}
        feats["feats"] = [f0, f1, f2, f3, f4]
        feats["preact_feats"] = [f0, f1_pre, f2_pre, f3_pre, f4_pre]
        feats["pooled_feat"] = avg

        return out, feats


# Snapshot the classes defined above. ``ResNet``, ``BasicBlock`` and
# ``Bottleneck`` are declared a second time further down in this notebook with
# different constructor signatures; resolving the names lazily inside the
# factories would hand a block class to the depth/num_filters-style ``ResNet``.
_FourStageResNet = ResNet
_FourStageBasicBlock = BasicBlock
_FourStageBottleneck = Bottleneck


def ResNet18(**kwargs):
    """Four-stage ResNet with BasicBlock stages of depth [2, 2, 2, 2]."""
    return _FourStageResNet(_FourStageBasicBlock, [2, 2, 2, 2], **kwargs)


def ResNet34(**kwargs):
    """Four-stage ResNet with BasicBlock stages of depth [3, 4, 6, 3]."""
    return _FourStageResNet(_FourStageBasicBlock, [3, 4, 6, 3], **kwargs)


def ResNet50(**kwargs):
    """Four-stage ResNet with Bottleneck stages of depth [3, 4, 6, 3]."""
    return _FourStageResNet(_FourStageBottleneck, [3, 4, 6, 3], **kwargs)


def ResNet101(**kwargs):
    """Four-stage ResNet with Bottleneck stages of depth [3, 4, 23, 3]."""
    return _FourStageResNet(_FourStageBottleneck, [3, 4, 23, 3], **kwargs)


def ResNet152(**kwargs):
    """Four-stage ResNet with Bottleneck stages of depth [3, 8, 36, 3]."""
    return _FourStageResNet(_FourStageBottleneck, [3, 8, 36, 3], **kwargs)


# Smoke test for ResNet18. In a notebook kernel __name__ is "__main__", so this
# runs every time the cell is executed and leaves net/x/logit/feats/f behind as
# kernel globals.
if __name__ == "__main__":
    net = ResNet18(num_classes=100)
    # Random batch of two 3x32x32 inputs.
    x = torch.randn(2, 3, 32, 32)
    logit, feats = net(x)

    # Print the shape and minimum value of each post-ReLU feature map.
    for f in feats["feats"]:
        print(f.shape, f.min().item())
    print(logit.shape)



def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class BasicBlock(nn.Module):
    """Two-conv residual block whose skip projection is supplied by the caller.

    Args:
        inplanes: number of input channels.
        planes: number of channels produced by both 3x3 convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to form the skip
            connection; when None the input itself is used.
        is_last: when True, ``forward`` returns ``(out, preact)``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False):
        super().__init__()
        self.is_last = is_last
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return the activated block output, plus the pre-ReLU sum if ``is_last``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        skip = x if self.downsample is None else self.downsample(x)

        branch += skip
        # ``branch`` is kept as the pre-activation; F.relu returns a new tensor.
        activated = F.relu(branch)
        return (activated, branch) if self.is_last else activated


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block producing ``planes * 4`` channels.

    Args:
        inplanes: number of input channels.
        planes: width of the two inner convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the skip
            connection; when None the input itself is used.
        is_last: when True, ``forward`` returns ``(out, preact)``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False):
        super().__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return the activated block output, plus the pre-ReLU sum if ``is_last``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))

        skip = x if self.downsample is None else self.downsample(x)

        branch += skip
        # ``branch`` is kept as the pre-activation; F.relu returns a new tensor.
        activated = F.relu(branch)
        return (activated, branch) if self.is_last else activated


class ResNet(nn.Module):
    """Three-stage ResNet parameterized by depth and per-stage widths.

    Unlike the four-stage ``ResNet`` defined earlier in this notebook,
    ``forward`` returns only the logits tensor.

    Args:
        depth: total depth; must be ``6n + 2`` for basic blocks or ``9n + 2``
            for bottleneck blocks, giving ``n`` blocks per stage.
        num_filters: four widths, ``[stem, stage1, stage2, stage3]``.
        block_name: "BasicBlock" or "Bottleneck" (case-insensitive).
        num_classes: size of the final linear layer.

    Raises:
        AssertionError: if ``depth`` does not match the chosen block type.
        ValueError: if ``block_name`` is not recognised.
    """

    def __init__(self, depth, num_filters, block_name="BasicBlock", num_classes=10):
        super(ResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 model
        if block_name.lower() == "basicblock":
            assert (
                depth - 2
            ) % 6 == 0, "When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202"
            n = (depth - 2) // 6
            block = BasicBlock
        elif block_name.lower() == "bottleneck":
            assert (
                depth - 2
            ) % 9 == 0, "When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199"
            n = (depth - 2) // 9
            block = Bottleneck
        else:
            raise ValueError("block_name shoule be Basicblock or Bottleneck")

        self.inplanes = num_filters[0]
        self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(num_filters[0])
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, num_filters[1], n)
        self.layer2 = self._make_layer(block, num_filters[2], n, stride=2)
        self.layer3 = self._make_layer(block, num_filters[3], n, stride=2)
        # Fixed 8x8 pooling window (see the size comments in forward()).
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes)
        self.stage_channels = num_filters

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the shape changes,
        a 1x1 conv + BatchNorm ``downsample`` projection. The final block of
        the stage is created with ``is_last=True``.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [
            block(self.inplanes, planes, stride, downsample, is_last=(blocks == 1))
        ]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, is_last=(i == blocks - 1)))

        return nn.Sequential(*layers)

    def get_feat_modules(self):
        """Return stem conv, stem BN, ReLU and the three stages as an ``nn.ModuleList``."""
        feat_m = nn.ModuleList([])
        feat_m.append(self.conv1)
        feat_m.append(self.bn1)
        feat_m.append(self.relu)
        feat_m.append(self.layer1)
        feat_m.append(self.layer2)
        feat_m.append(self.layer3)
        return feat_m

    def get_bn_before_relu(self):
        """Return the final BatchNorm of the last block of each stage.

        Raises:
            NotImplementedError: if the first block of ``layer1`` is neither a
                ``Bottleneck`` nor a ``BasicBlock``.
        """
        if isinstance(self.layer1[0], Bottleneck):
            bn1 = self.layer1[-1].bn3
            bn2 = self.layer2[-1].bn3
            bn3 = self.layer3[-1].bn3
        elif isinstance(self.layer1[0], BasicBlock):
            bn1 = self.layer1[-1].bn2
            bn2 = self.layer2[-1].bn2
            bn3 = self.layer3[-1].bn2
        else:
            raise NotImplementedError("ResNet unknown block error !!!")

        return [bn1, bn2, bn3]

    def get_stage_channels(self):
        """Return the ``num_filters`` sequence passed to the constructor."""
        return self.stage_channels

    def forward(self, x):
        """Return the classification logits for ``x``.

        Each stage returns ``(out, preact)`` because its last block is built
        with ``is_last=True``; the pre-activations are not needed here and are
        discarded.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)  # 32x32

        x, _ = self.layer1(x)  # 32x32
        x, _ = self.layer2(x)  # 16x16
        x, _ = self.layer3(x)  # 8x8

        x = self.avgpool(x)
        pooled = x.reshape(x.size(0), -1)
        return self.fc(pooled)


# Stem/stage widths shared by the factories below. Stored as tuples and copied
# into a fresh list on every call, so each model owns its own width list.
_BASE_WIDTHS = (16, 16, 32, 64)
_WIDE_WIDTHS = (32, 64, 128, 256)


def _basicblock_resnet(depth, widths, **kwargs):
    """Build a basic-block ``ResNet`` of the given depth and widths."""
    return ResNet(depth, list(widths), "basicblock", **kwargs)


def resnet8(**kwargs):
    """Depth-8 basic-block ResNet with widths 16/16/32/64."""
    return _basicblock_resnet(8, _BASE_WIDTHS, **kwargs)


def resnet14(**kwargs):
    """Depth-14 basic-block ResNet with widths 16/16/32/64."""
    return _basicblock_resnet(14, _BASE_WIDTHS, **kwargs)


def resnet20(**kwargs):
    """Depth-20 basic-block ResNet with widths 16/16/32/64."""
    return _basicblock_resnet(20, _BASE_WIDTHS, **kwargs)


def resnet32(**kwargs):
    """Depth-32 basic-block ResNet with widths 16/16/32/64."""
    return _basicblock_resnet(32, _BASE_WIDTHS, **kwargs)


def resnet44(**kwargs):
    """Depth-44 basic-block ResNet with widths 16/16/32/64."""
    return _basicblock_resnet(44, _BASE_WIDTHS, **kwargs)


def resnet56(**kwargs):
    """Depth-56 basic-block ResNet with widths 16/16/32/64."""
    return _basicblock_resnet(56, _BASE_WIDTHS, **kwargs)


def resnet110(**kwargs):
    """Depth-110 basic-block ResNet with widths 16/16/32/64."""
    return _basicblock_resnet(110, _BASE_WIDTHS, **kwargs)


def resnet8x4(**kwargs):
    """Depth-8 basic-block ResNet with widths 32/64/128/256."""
    return _basicblock_resnet(8, _WIDE_WIDTHS, **kwargs)


def resnet32x4(**kwargs):
    """Depth-32 basic-block ResNet with widths 32/64/128/256."""
    return _basicblock_resnet(32, _WIDE_WIDTHS, **kwargs)

# Directory holding the pretrained teacher checkpoints (Kaggle dataset mount).
cifar100_model_prefix = "/kaggle/input/cifar_teachers/pytorch/default/1/cifar_teachers/"


def _teacher_ckpt(arch_name):
    """Return the ``<arch_name>_vanilla/ckpt_epoch_240.pth`` path under the prefix."""
    return cifar100_model_prefix + arch_name + "_vanilla/ckpt_epoch_240.pth"


# Model registry: name -> (constructor, checkpoint path or None).
# Teachers carry a checkpoint path, students use None.
# Entries for wrn_40_2, vgg13, wrn_16_1, wrn_16_2, wrn_40_1, vgg8, vgg11,
# vgg16, vgg19, MobileNetV2, ShuffleV1 and ShuffleV2 are disabled here; their
# constructors are not defined in the visible part of this notebook.
cifar_model_dict = {
    # teachers
    "resnet56": (resnet56, _teacher_ckpt("resnet56")),
    "resnet110": (resnet110, _teacher_ckpt("resnet110")),
    "resnet32x4": (resnet32x4, _teacher_ckpt("resnet32x4")),
    "ResNet50": (ResNet50, _teacher_ckpt("ResNet50")),
    # students
    "resnet8": (resnet8, None),
    "resnet14": (resnet14, None),
    "resnet20": (resnet20, None),
    "resnet32": (resnet32, None),
    "resnet44": (resnet44, None),
    "resnet8x4": (resnet8x4, None),
    "ResNet18": (ResNet18, None),
}

In [None]:
def perception(logits, epsilon=1e-5):
    """Standardize every class column of ``logits`` across the batch.

    Each column (dim 0) is shifted to zero mean and divided by
    ``sqrt(biased variance + epsilon)``.

    Parameters:
    logits (torch.Tensor): A tensor of shape (B, N) where B is the batch size and N is the number of classes.
    epsilon (float): A small constant to avoid division by zero in normalization.

    Returns:
    torch.Tensor: perception logits, same shape as ``logits``.
    """
    batch_mean = torch.mean(logits, dim=0, keepdim=True)
    batch_var = torch.var(logits, dim=0, keepdim=True, unbiased=False)
    return (logits - batch_mean) / torch.sqrt(batch_var + epsilon)


def luminet_loss(logits_student, logits_teacher, target, alpha, temperature):
    """KL divergence between batch-standardized teacher and student logits.

    Both logit tensors go through ``perception``, are divided by
    ``temperature`` and compared with ``F.kl_div(..., reduction='batchmean')``.
    The result is multiplied by ``alpha ** 2``.

    Parameters:
    logits_student (torch.Tensor): student logits, shape (B, N).
    logits_teacher (torch.Tensor): teacher logits, shape (B, N).
    target: accepted for signature compatibility; not used by this function.
    alpha (float): loss weight; enters the result squared.
    temperature (float): softmax temperature.

    Returns:
    torch.Tensor: scalar loss.
    """
    stu_batch = perception(logits_student)
    tea_batch = perception(logits_teacher)

    pred_teacher = F.softmax(tea_batch / temperature, dim=1)
    log_pred_student = F.log_softmax(stu_batch / temperature, dim=1)
    nckd_loss = F.kl_div(log_pred_student, pred_teacher, reduction="batchmean")
    # NOTE(review): scaled by alpha**2, not temperature**2 as in kd_loss —
    # kept as written; confirm this is intended.
    return nckd_loss * alpha**2
    

def normalize(logit):
    """Z-score ``logit`` along its last dimension.

    Subtracts the mean and divides by ``1e-7 + std`` (``Tensor.std`` default
    estimator), both taken over the last dimension with ``keepdims=True``.
    """
    centered = logit - logit.mean(dim=-1, keepdims=True)
    spread = logit.std(dim=-1, keepdims=True)
    return centered / (1e-7 + spread)

def kd_loss(logits_student_in, logits_teacher_in, temperature, logit_stand):
    """Temperature-scaled KL distillation loss.

    When ``logit_stand`` is truthy both logit tensors are first passed through
    ``normalize``. The per-sample KL divergence between the softened teacher
    and student distributions is summed over classes, averaged over the batch
    and multiplied by ``temperature ** 2``.
    """
    if logit_stand:
        student_logits = normalize(logits_student_in)
        teacher_logits = normalize(logits_teacher_in)
    else:
        student_logits = logits_student_in
        teacher_logits = logits_teacher_in

    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    per_sample_kl = F.kl_div(student_log_probs, teacher_probs, reduction="none").sum(1)
    return per_sample_kl.mean() * temperature**2

class Distiller(nn.Module):
    """Base wrapper that holds a student and a teacher network.

    Subclasses implement ``forward_train``. In eval mode ``forward`` runs only
    the student on ``kwargs["image"]``.
    """

    def __init__(self, student, teacher):
        super().__init__()
        self.student = student
        self.teacher = teacher

    def train(self, mode=True):
        """Switch the training flag on all children, then force the teacher to eval."""
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for child in self.children():
            child.train(mode)
        # The teacher always stays in eval mode, whatever ``mode`` is.
        self.teacher.eval()
        return self

    def get_learnable_parameters(self):
        """Return the student's parameters as a list.

        Distillers that introduce extra parameters should override this.
        """
        return [param for _name, param in self.student.named_parameters()]

    def get_extra_parameters(self):
        """Return the number of extra parameters added by the distiller (0 here)."""
        return 0

    def forward_train(self, **kwargs):
        """Training-time forward pass; must be provided by subclasses."""
        raise NotImplementedError()

    def forward_test(self, image):
        """Evaluation-time forward pass: the student's output for ``image``."""
        return self.student(image)

    def forward(self, **kwargs):
        """Dispatch to ``forward_train`` or ``forward_test`` based on ``self.training``."""
        if not self.training:
            return self.forward_test(kwargs["image"])
        return self.forward_train(**kwargs)

class DTKD(Distiller):
    """Distiller that combines cross-entropy with the ``kd_loss`` term.

    Uses a fixed temperature of 2, a cross-entropy weight of 0.1, a KD weight
    of 9, and enables logit standardization inside ``kd_loss``.
    """

    def __init__(self, student, teacher):
        super().__init__(student, teacher)
        self.temperature = 2
        self.ce_loss_weight = 0.1
        self.kd_loss_weight = 9
        self.logit_stand = True

    def forward_train(self, image, target, **kwargs):
        """Return ``(student_logits, {"loss_ce": ..., "loss_kd": ...})``.

        The teacher is evaluated under ``torch.no_grad()``. Extra keyword
        arguments are accepted and ignored.
        """
        logits_student = self.student(image)
        with torch.no_grad():
            logits_teacher = self.teacher(image)

        # Weighted loss terms.
        hard_label_term = F.cross_entropy(logits_student, target)
        distill_term = kd_loss(
            logits_student, logits_teacher, self.temperature, self.logit_stand
        )
        losses_dict = {
            "loss_ce": self.ce_loss_weight * hard_label_term,
            "loss_kd": self.kd_loss_weight * distill_term,
        }
        return logits_student, losses_dict
        
class BaseTrainer(object):
    """Trains a distiller with SGD and records metrics and checkpoints.

    Logs go to ``./output/<experiment_name>``: TensorBoard scalars, a plain
    text ``worklog.txt`` and the checkpoints written after every epoch.

    Args:
        experiment_name: sub-directory name under ``./output``.
        distiller: module exposing ``get_learnable_parameters()``, a
            ``student`` attribute and a keyword-only ``forward``.
        train_loader: iterable of ``(image, target)`` training batches.
        val_loader: iterable of ``(image, target)`` validation batches.
    """

    # Step schedule: the learning rate starts at BASE_LR and is multiplied by
    # LR_DECAY_RATE once for every milestone the epoch index has exceeded.
    BASE_LR = 0.05
    LR_DECAY_EPOCHS = (62, 75, 87)
    LR_DECAY_RATE = 0.1

    def __init__(
        self,
        experiment_name,
        distiller,
        train_loader,
        val_loader
    ):
        self.distiller = distiller
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = torch.optim.SGD(
            self.distiller.get_learnable_parameters(),
            lr=self.BASE_LR,
            weight_decay=5e-4,
            momentum=0.9
        )
        self.best_acc = -1

        # init loggers
        self.log_path = os.path.join("./output", experiment_name)
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.log_path, exist_ok=True)
        self.tf_writer = SummaryWriter(os.path.join(self.log_path, "train.events"))

    def adjust_learning_rate(self, epoch, optimizer):
        """Apply the step schedule for ``epoch`` and return the resulting rate.

        When no milestone has been passed the optimizer is left untouched and
        ``BASE_LR`` is returned.
        """
        steps = int(np.sum(epoch > np.asarray(self.LR_DECAY_EPOCHS)))
        if steps > 0:
            # Plain Python float: the value ends up inside optimizer.state_dict()
            # and therefore in the checkpoints, where a numpy scalar would be
            # pickled as a numpy object.
            new_lr = float(self.BASE_LR * (self.LR_DECAY_RATE ** steps))
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_lr
            return new_lr
        return self.BASE_LR

    def log(self, lr, epoch, log_dict):
        """Write ``log_dict`` to TensorBoard and worklog.txt; track the best test accuracy."""
        # tensorboard log
        for key, value in log_dict.items():
            self.tf_writer.add_scalar(key, value, epoch)
        self.tf_writer.flush()

        if log_dict["test_acc"] > self.best_acc:
            self.best_acc = log_dict["test_acc"]

        # worklog.txt
        with open(os.path.join(self.log_path, "worklog.txt"), "a") as writer:
            lines = [
                "-" * 25 + os.linesep,
                "epoch: {}".format(epoch) + os.linesep,
                "lr: {:.2f}".format(float(lr)) + os.linesep,
            ]
            for key, value in log_dict.items():
                lines.append("{}: {:.2f}".format(key, value) + os.linesep)
            lines.append("-" * 25 + os.linesep)
            writer.writelines(lines)

    def train(self, resume=False, num_epochs=100):
        """Train for epochs ``1..num_epochs``, optionally resuming from "latest"."""
        epoch = 1
        if resume:
            state = load_checkpoint(os.path.join(self.log_path, "latest"))
            epoch = state["epoch"] + 1
            self.distiller.load_state_dict(state["model"])
            self.optimizer.load_state_dict(state["optimizer"])
            self.best_acc = state["best_acc"]
        while epoch < num_epochs + 1:
            self.train_epoch(epoch)
            epoch += 1
        print(log_msg("Best accuracy:{}".format(self.best_acc), "EVAL"))
        with open(os.path.join(self.log_path, "worklog.txt"), "a") as writer:
            writer.write("best_acc\t" + "{:.2f}".format(float(self.best_acc)))

    def train_epoch(self, epoch):
        """Run one training epoch, validate, log, and write checkpoints."""
        lr = self.adjust_learning_rate(epoch, self.optimizer)
        train_meters = {
            "training_time": AverageMeter(),
            "data_time": AverageMeter(),
            "losses": AverageMeter(),
            "top1": AverageMeter(),
            "top5": AverageMeter(),
        }
        num_iter = len(self.train_loader)
        pbar = tqdm(range(num_iter))

        # train loops
        self.distiller.train()
        for data in self.train_loader:
            msg, _ = self.train_iter(data, epoch, train_meters)
            pbar.set_description(log_msg(msg, "TRAIN"))
            pbar.update()
        pbar.close()
        # Same value train_iter returns on its last call, but also defined when
        # the loader yields no batches.
        train_loss = train_meters["losses"].avg

        test_acc, test_acc_top5, test_loss = validate(self.val_loader, self.distiller)

        # Per-epoch history kept in the module-level lists.
        dtkd_losses.append({"train_loss": train_loss, "test_loss": test_loss})
        # float() accepts both one-element tensors and plain numbers.
        dtkd_accuracies.append(
            {"acc@1": float(test_acc), "acc@5": float(test_acc_top5)}
        )
        # log
        log_dict = OrderedDict(
            {
                "train_acc": train_meters["top1"].avg,
                "train_loss": train_meters["losses"].avg,
                "test_acc": test_acc,
                "test_acc_top5": test_acc_top5,
                "test_loss": test_loss,
            }
        )
        self.log(lr, epoch, log_dict)
        # saving checkpoint
        state = {
            "epoch": epoch,
            "model": self.distiller.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "best_acc": self.best_acc,
        }
        student_state = {"model": self.distiller.student.state_dict()}
        save_checkpoint(state, os.path.join(self.log_path, "latest"))
        save_checkpoint(
            student_state, os.path.join(self.log_path, "student_latest")
        )
        if epoch % 20 == 0:
            save_checkpoint(
                state, os.path.join(self.log_path, "epoch_{}".format(epoch))
            )
            save_checkpoint(
                student_state,
                os.path.join(self.log_path, "student_{}".format(epoch)),
            )
        # update the best; self.log() has already raised best_acc if this
        # epoch improved on it, so ">=" is true exactly for a new (or tied) best.
        if test_acc >= self.best_acc:
            save_checkpoint(state, os.path.join(self.log_path, "best"))
            save_checkpoint(
                student_state, os.path.join(self.log_path, "student_best")
            )

    def train_iter(self, data, epoch, train_meters):
        """Run one optimization step on ``data``.

        Returns:
            ``(msg, avg_loss)``: a progress string and the running mean loss.
        """
        self.optimizer.zero_grad()
        train_start_time = time.time()
        image, target = data  # Adjusted to match the output of your data loader
        train_meters["data_time"].update(time.time() - train_start_time)
        image = image.float()
        image = image.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward
        preds, losses_dict = self.distiller(image=image, target=target, epoch=epoch)

        # backward
        loss = sum([l.mean() for l in losses_dict.values()])
        loss.backward()
        self.optimizer.step()
        train_meters["training_time"].update(time.time() - train_start_time)
        # collect info
        batch_size = image.size(0)
        acc1, acc5 = accuracy(preds, target, topk=(1, 5))
        train_meters["losses"].update(loss.cpu().detach().numpy().mean(), batch_size)
        train_meters["top1"].update(acc1[0], batch_size)
        train_meters["top5"].update(acc5[0], batch_size)
        # print info
        msg = "Epoch:{}| Time(data):{:.3f}| Time(train):{:.3f}| Loss:{:.4f}| Top-1:{:.3f}| Top-5:{:.3f}".format(
            epoch,
            train_meters["data_time"].avg,
            train_meters["training_time"].avg,
            train_meters["losses"].avg,
            train_meters["top1"].avg,
            train_meters["top5"].avg,
        )
        return (msg, train_meters["losses"].avg)
        
class AverageMeter(object):
    """Running weighted mean: tracks the latest value, total, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def validate(val_loader, distiller):
    """Evaluate ``distiller`` on ``val_loader``.

    Puts the distiller in eval mode, runs every batch under ``torch.no_grad()``
    on the GPU and shows a progress bar.

    Returns:
        ``(top1_avg, top5_avg, loss_avg)`` running means over the loader.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    criterion = nn.CrossEntropyLoss()
    pbar = tqdm(range(len(val_loader)))

    distiller.eval()
    with torch.no_grad():
        start_time = time.time()
        for image, target in val_loader:
            image = image.float().cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            output = distiller(image=image)
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            batch_size = image.size(0)
            losses.update(loss.cpu().detach().numpy().mean(), batch_size)
            top1.update(acc1[0], batch_size)
            top5.update(acc5[0], batch_size)

            # measure elapsed time
            batch_time.update(time.time() - start_time)
            start_time = time.time()
            progress = "Top-1:{top1.avg:.3f}| Top-5:{top5.avg:.3f}".format(
                top1=top1, top5=top5
            )
            pbar.set_description(log_msg(progress, "EVAL"))
            pbar.update()
    pbar.close()
    return top1.avg, top5.avg, losses.avg

def log_msg(msg, mode="INFO"):
    """Wrap ``msg`` in an ANSI color escape and a ``[mode]`` tag.

    ``mode`` must be "INFO", "TRAIN" or "EVAL"; any other value raises KeyError.
    """
    ansi_color = {"INFO": 36, "TRAIN": 32, "EVAL": 31}[mode]
    return "\033[{}m[{}] {}\033[0m".format(ansi_color, mode, msg)

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each ``k`` in ``topk``.

    Returns a list with one single-element tensor per ``k``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # Class indices of the maxk highest scores, transposed to (maxk, batch).
        top_indices = output.topk(maxk, 1, True, True)[1].t()
        hits = top_indices.eq(target.reshape(1, -1).expand_as(top_indices))
        percent_per_hit = 100.0 / batch_size
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]

def save_checkpoint(obj, path):
    """Serialize ``obj`` with ``torch.save`` into the file at ``path``."""
    with open(path, "wb") as ckpt_file:
        torch.save(obj, ckpt_file)

def load_checkpoint(path):
    """Load a ``torch.save`` checkpoint from ``path``, mapping tensors to the CPU."""
    with open(path, "rb") as ckpt_file:
        state = torch.load(ckpt_file, map_location="cpu")
    return state

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LossManager:
    """
    Bundle of distillation / supervision losses that share one temperature.

    Every temperature-dependent loss reads ``self.current_temperature`` at call
    time, so an owner may overwrite that attribute between training steps.

    Args:
        alpha: Stored on the instance; not used by any loss in this class.
        beta: Weight applied to the cosine loss in ``combined_loss``.
        gamma: Weight applied to the RMSE loss in ``combined_loss``.
        initial_temperature: Starting value of ``current_temperature``.
        min_temperature: Stored on the instance; not used by any loss in this class.
    """

    def __init__(
        self,
        alpha,
        beta,
        gamma,
        initial_temperature,
        min_temperature
    ):
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.current_temperature = initial_temperature
        self.min_temperature = min_temperature

    def normalize(self, logit):
        """Standardise ``logit`` along its last dimension (zero mean, unit std, 1e-7 added to the std)."""
        mean = logit.mean(dim=-1, keepdim=True)
        stdv = logit.std(dim=-1, keepdim=True)

        return (logit - mean) / (1e-7 + stdv)

    def kd_loss(self, logits_student_in, logits_teacher_in, logit_stand=True):
        """
        Temperature-scaled KL distillation loss, optionally on standardised logits.

        Args:
            logits_student_in (torch.Tensor): Student logits.
            logits_teacher_in (torch.Tensor): Teacher logits.
            logit_stand (bool): When True (default) both logit tensors are passed
                through ``normalize`` first; when False the raw logits are used.

        Returns:
            torch.Tensor: Per-sample KL summed over dim 1, averaged over the batch
            and multiplied by temperature squared.
        """
        temperature = self.current_temperature

        # Honour the flag: previously it was accepted but standardisation was
        # applied unconditionally.
        if logit_stand:
            logits_student = self.normalize(logits_student_in)
            logits_teacher = self.normalize(logits_teacher_in)
        else:
            logits_student = logits_student_in
            logits_teacher = logits_teacher_in

        log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
        pred_teacher = F.softmax(logits_teacher / temperature, dim=1)

        loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()
        loss_kd = loss_kd * (temperature * temperature)

        return loss_kd

    def perception(self, logits, epsilon=1e-5):
        """
        Standardise each class column of ``logits`` across the batch dimension.

        Parameters:
        logits (torch.Tensor): A tensor of shape (B, N) where B is the batch size and N is the number of classes.
        epsilon (float): A small constant to avoid division by zero in normalization.

        Returns:
        torch.Tensor: ``(logits - batch_mean) / sqrt(batch_var + epsilon)`` using the biased variance.
        """
        batch_mean = torch.mean(logits, dim=0, keepdim=True)
        batch_var = torch.var(logits, dim=0, keepdim=True, unbiased=False)
        x_normalized = (logits - batch_mean) / torch.sqrt(batch_var + epsilon)

        return x_normalized

    def luminet_loss(self, logits_student, logits_teacher, target):
        """
        KL divergence between batch-standardised (``perception``) student and teacher logits.

        Args:
            logits_student (torch.Tensor): Student logits.
            logits_teacher (torch.Tensor): Teacher logits.
            target: Accepted for interface compatibility; not used.

        Returns:
            torch.Tensor: ``batchmean`` KL divergence multiplied by 33 * 33.
        """
        temperature = self.current_temperature
        stu_batch = self.perception(logits_student)
        tea_batch = self.perception(logits_teacher)

        pred_teacher = F.softmax(tea_batch / temperature, dim=1)
        log_pred_student = F.log_softmax(stu_batch / temperature, dim=1)

        nckd_loss = F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')
        # NOTE(review): fixed scale factor, independent of the temperature —
        # presumably a tuned constant; confirm before changing.
        nckd_loss *= (33.0 * 33.0)

        return nckd_loss

    def cosine_loss(self, student_logits, teacher_logits):
        """
        Compute cosine similarity loss between student and teacher logits.

        Args:
            student_logits (torch.Tensor): Logits from student model.
            teacher_logits (torch.Tensor): Logits from teacher model.

        Returns:
            torch.Tensor: One minus the batch-mean cosine similarity.
        """
        # L2-normalise along dim 1 before measuring the similarity.
        student_norm = F.normalize(student_logits, p=2, dim=1)
        teacher_norm = F.normalize(teacher_logits, p=2, dim=1)

        cosine_loss = 1 - F.cosine_similarity(student_norm, teacher_norm).mean()
        return cosine_loss

    def rmse_loss(self, student_logits, teacher_logits):
        """
        Compute Root Mean Square Error (RMSE) between student and teacher logits.

        Args:
            student_logits (torch.Tensor): Logits from student model.
            teacher_logits (torch.Tensor): Logits from teacher model.

        Returns:
            torch.Tensor: Square root of the mean squared error.
        """
        rmse = torch.sqrt(F.mse_loss(student_logits, teacher_logits))
        return rmse

    def mae_loss(self, student_logits, teacher_logits):
        """
        Compute Mean Absolute Error (MAE) between student and teacher logits.

        Args:
            student_logits (torch.Tensor): Logits from student model.
            teacher_logits (torch.Tensor): Logits from teacher model.

        Returns:
            torch.Tensor: L1 loss with the default (mean) reduction.
        """
        mae = torch.nn.L1Loss()(student_logits, teacher_logits)
        return mae

    def hard_loss(self, student_logits, outputs):
        """
        Compute hard loss (cross-entropy) between logits and true labels.

        Args:
            student_logits (torch.Tensor): Logits to score.
            outputs (torch.Tensor): True labels.

        Returns:
            torch.Tensor: Cross-entropy loss.
        """
        return torch.nn.CrossEntropyLoss()(student_logits, outputs)

    def soft_distillation_loss(self, student_logits, teacher_logits):
        """
        Compute knowledge distillation loss at the current temperature.

        Args:
            student_logits (torch.Tensor): Logits from student model.
            teacher_logits (torch.Tensor): Logits from teacher model.

        Returns:
            torch.Tensor: ``batchmean`` KL divergence multiplied by temperature squared.
        """
        soft_targets = F.softmax(teacher_logits / self.current_temperature, dim=1)
        soft_predictions = F.log_softmax(student_logits / self.current_temperature, dim=1)

        loss = F.kl_div(soft_predictions, soft_targets, reduction='batchmean')
        return loss * (self.current_temperature ** 2)

    def combined_loss(self, student_logits, teacher_logits, outputs):
        """
        Weighted sum of the auxiliary losses: ``beta * cosine_loss + gamma * rmse_loss``.

        ``outputs`` is accepted for interface compatibility and is not used.
        """
        cosine_loss = self.beta * self.cosine_loss(student_logits, teacher_logits)
        rmse_loss = self.gamma * self.rmse_loss(student_logits, teacher_logits)
        return cosine_loss + rmse_loss
    
class DynamicTemperatureScheduler(nn.Module):
    """
    Dynamic temperature scheduler and loss dispatcher for knowledge distillation.

    The temperature follows a cosine decay over ``max_epoch`` epochs, is clamped
    to ``[min_temperature, max_temperature]`` and smoothed with momentum 0.9.
    The current value is mirrored into the internal ``LossManager`` after every
    update.

    Args:
        initial_temperature (float): Starting temperature value.
        min_temperature (float): Minimum allowable target temperature.
        max_temperature (float): Maximum allowable target temperature.
        max_epoch (int): Epoch count used to compute the cosine progress.
        warmup (int or None): Number of epochs over which the "ours" and
            "luminet" losses are ramped linearly; ``None`` or 0 disables the ramp.
        alpha (float): Soft-loss weight used by the "ours" loss (hard loss uses 1 - alpha).
        beta (float): Importance of cosine loss (forwarded to ``LossManager``).
        gamma (float): Importance for RMSE loss (forwarded to ``LossManager``).
    """
    def __init__(
        self,
        initial_temperature=8.0,
        min_temperature=4.0,
        max_temperature=8,
        max_epoch=50,
        warmup=20,
        alpha=0.5,
        beta=0.9,
        gamma=0.5,
    ):
        super(DynamicTemperatureScheduler, self).__init__()

        self.current_temperature = initial_temperature
        self.initial_temperature = initial_temperature
        self.min_temperature = min_temperature
        self.max_temperature = max_temperature
        self.max_epoch = max_epoch
        self.warmup = warmup

        # Tracking training dynamics (not written to by this class).
        self.loss_history = []
        self.student_loss = []

        # Loss implementations; shares the temperature with this scheduler.
        self.loss_manager = LossManager(
            alpha,
            beta,
            gamma,
            initial_temperature,
            min_temperature
        )

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def update_temperature(self, current_epoch, loss_divergence):
        """
        Move the temperature towards its cosine-scheduled target.

        Args:
            current_epoch (int): Epoch index used for the cosine progress.
            loss_divergence (float): Scalar fed through ``log(1 + x)`` to build
                the adaptive scale.
        """
        progress = torch.tensor(current_epoch / self.max_epoch)
        cosine_factor = 0.5 * (1 + torch.cos(torch.pi * progress))
        log_loss = torch.log(1 + torch.tensor(loss_divergence))
        adaptive_scale = log_loss / (log_loss + 1)

        # NOTE(review): for log_loss >= 0 the ratio x / (x + 1) stays below 1, so
        # the boosted branch is only taken when log_loss < -1; a NaN ratio (from
        # loss_divergence < -1) compares False and uses the plain cosine target.
        # Confirm this is the intended gating.
        if adaptive_scale > 1:
            target_temperature = self.initial_temperature * cosine_factor * (1 + adaptive_scale)
        else:
            target_temperature = self.initial_temperature * cosine_factor

        target_temperature = torch.clamp(
            target_temperature,
            self.min_temperature,
            self.max_temperature
        )

        momentum = 0.9
        # Convert back to a plain float so the temperature keeps one type for
        # its whole lifetime (it used to turn into a 0-d tensor after the first
        # update, contradicting get_temperature's documented return type).
        self.current_temperature = float(
            momentum * self.current_temperature + (1 - momentum) * target_temperature
        )

        self.loss_manager.current_temperature = self.current_temperature

    def get_temperature(self):
        """
        Retrieve current temperature value.

        Returns:
            float: Current dynamic temperature.
        """
        return self.current_temperature

    def _warmup_factor(self, epoch):
        """Linear ramp ``epoch / warmup`` capped at 1; returns 1 when warmup is None or 0."""
        if not self.warmup:
            return 1
        return min(epoch / self.warmup, 1.0)

    def forward(self, epoch, student_logits, teacher_logits, outputs, loss_type="kd++"):
        """
        Forward pass to compute the loss based on the specified loss type.

        Args:
            epoch (int): Current epoch, used only for the warmup ramp.
            student_logits (torch.Tensor): Logits from student model.
            teacher_logits (torch.Tensor): Logits from teacher model.
            outputs (torch.Tensor): True labels.
            loss_type (str): "ours", "luminet" or "kd++" (default).

        Returns:
            torch.Tensor: Computed loss.

        Raises:
            ValueError: If ``loss_type`` is not one of the supported names.
        """
        if loss_type == "ours":
            # Position of the temperature inside the fixed [1, 3] band, clipped to [0, 1].
            temp_ratio = (self.current_temperature - 1.0) / (3.0 - 1.0)
            temp_ratio = max(0, min(1, temp_ratio))

            # Base losses (always present)
            soft_loss = self.loss_manager.soft_distillation_loss(
                student_logits,
                teacher_logits
            )
            hard_loss = self.loss_manager.hard_loss(
                student_logits,
                outputs
            )

            # Temperature-dependent weighting for soft vs hard
            if self.current_temperature > 1:
                soft_weight = self.alpha * temp_ratio + 0.4 * (1 - temp_ratio)
                hard_weight = (1 - self.alpha) * temp_ratio + 0.5 * (1 - temp_ratio)
            else:
                soft_weight = 0.2
                hard_weight = 0.5

            # Additional losses only when temperature is higher
            additional_losses = temp_ratio * self.loss_manager.combined_loss(
                student_logits,
                teacher_logits,
                outputs
            )

            total_loss = (
                soft_weight * soft_loss +
                hard_weight * hard_loss +
                additional_losses
            )

            # The ramp scales the whole objective, including the hard loss.
            return self._warmup_factor(epoch) * total_loss

        if loss_type == "luminet":
            loss_ce = (2.0) * F.cross_entropy(
                student_logits,
                outputs
            )
            # Only the distillation term is ramped.
            loss_luminet = self._warmup_factor(epoch) * self.loss_manager.luminet_loss(
                student_logits,
                teacher_logits,
                outputs
            )

            return loss_ce.mean() + loss_luminet.mean()

        if loss_type == "kd++":
            loss_ce = 0.1 * F.cross_entropy(student_logits, outputs)
            loss_kd = 9 * self.loss_manager.kd_loss(
                student_logits, teacher_logits
            )

            return loss_ce.mean() + loss_kd.mean()

        # Previously an unknown name fell through and returned None, which only
        # surfaced later as an AttributeError on ``.backward()``.
        raise ValueError(
            "Unknown loss_type {!r}; expected 'ours', 'luminet' or 'kd++'".format(loss_type)
        )


In [None]:
import wandb
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
import numpy as np

def calculate_accuracy(outputs, targets, topk=(1, 5)):
    """
    Top-k accuracy, expressed in percent, for every k in ``topk``.

    Args:
        outputs (torch.Tensor): Model predictions; the top-k search runs along dimension 1.
        targets (torch.Tensor): Ground truth labels.
        topk (tuple): Top-k values to compute accuracy for.

    Returns:
        list: One single-element tensor per requested k.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = targets.size(0)

        # Row i of `ranked` is the (i+1)-th most likely class of every sample.
        ranked = outputs.topk(largest_k, 1, True, True)[1].t()
        is_hit = ranked.eq(targets.view(1, -1).expand_as(ranked))

        to_percent = 100.0 / n_samples
        accuracies = []
        for k in topk:
            hits_within_k = is_hit[:k].reshape(-1).float().sum(0, keepdim=True)
            accuracies.append(hits_within_k.mul_(to_percent))

        return accuracies

def adjust_learning_rate(
    epoch,
    lr,
    optimizer,
    base_lr=0.05,
    milestones=(62, 75, 87),
    decay_rate=0.1,
):
    """
    Step-decay the optimizer's learning rate once ``epoch`` passes a milestone.

    The new rate is always derived from ``base_lr`` (not from ``lr``), so the
    function can be called every epoch without compounding the decay.

    Args:
        epoch (int): Current epoch; a milestone counts once ``epoch > milestone``.
        lr (float): Rate returned unchanged while no milestone has been passed.
        optimizer: Object exposing ``param_groups`` (a list of dicts with an "lr" key).
        base_lr (float): Undecayed learning rate. Defaults to 0.05.
        milestones (iterable of int): Epoch thresholds. Defaults to (62, 75, 87).
        decay_rate (float): Multiplicative factor applied per passed milestone.

    Returns:
        float: ``base_lr * decay_rate ** passed`` when at least one milestone has
        been passed (also written into every param group), otherwise ``lr``.
    """
    passed = sum(1 for milestone in milestones if epoch > milestone)
    if passed > 0:
        new_lr = base_lr * (decay_rate ** passed)
        for param_group in optimizer.param_groups:
            param_group["lr"] = new_lr
        return new_lr
    return lr

def _evaluate_model(model, data_loader, criterion):
    """Put ``model`` in eval mode and return sample-weighted (loss, top-1 %, top-5 %) over ``data_loader``."""
    device = next(model.parameters()).device
    model.eval()

    loss_sum = 0.0
    top1_sum = 0.0
    top5_sum = 0.0
    n_samples = 0

    with torch.no_grad():
        for val_x, val_y in data_loader:
            val_x, val_y = val_x.to(device), val_y.to(device)
            val_outputs = model(val_x)
            n_batch = val_y.size(0)

            # Weight every batch by its size so a short final batch does not
            # count as much as a full one.
            loss_sum += criterion(val_outputs, val_y).item() * n_batch
            batch_top1, batch_top5 = calculate_accuracy(val_outputs, val_y)
            top1_sum += batch_top1.item() * n_batch
            top5_sum += batch_top5.item() * n_batch
            n_samples += n_batch

    n_samples = max(n_samples, 1)  # empty loader -> report zeros instead of dividing by zero
    return loss_sum / n_samples, top1_sum / n_samples, top5_sum / n_samples


def train_knowledge_distillation(
    name,
    teacher_model,
    student_model,
    train_loader,
    val_loader,
    optimizer,
    lr,
    epochs=50,
    val_steps=10,
    temperature_scheduler=None,
    scheduler=None,
    save_path="./output/"
):
    """
    Train the student against a frozen teacher, validating after every epoch.

    Side effects: starts and finishes a wandb run, logs per-epoch metrics,
    appends to the notebook-level ``our_losses`` / ``our_accuracies`` lists and
    writes ``DTAD_@<acc>.pth`` / ``trained_studentDTAD.pth`` to the working
    directory.

    Args:
        name (str): wandb run name.
        teacher_model (nn.Module): Pre-trained teacher model; must live on the
            same device as the student.
        student_model (nn.Module): Model to be distilled.
        train_loader (DataLoader): Training batches of (inputs, labels).
        val_loader (DataLoader): Validation batches of (inputs, labels).
        optimizer: Optimizer over the student's parameters.
        lr (float): Current learning rate, tracked for logging.
        epochs (int): Total training epochs.
        val_steps (int): Currently unused; validation runs every epoch.
        temperature_scheduler (DynamicTemperatureScheduler or None): Computes the
            distillation loss and tracks the temperature. When None the student
            is trained with plain cross-entropy.
        scheduler: Optional LR scheduler stepped once per epoch; when None,
            ``adjust_learning_rate`` is used instead.
        save_path (str): Currently unused (checkpoints go to the working
            directory). NOTE(review): decide whether to honour it.

    Returns:
        nn.Module: The trained student model.
    """
    run = wandb.init(
        # Set the project where this run will be logged
        project="DTAD_Trials",
        name=name
    )

    student_optimizer = optimizer
    task_criterion = torch.nn.CrossEntropyLoss()
    # Follow the student's placement instead of assuming CUDA.
    device = next(student_model.parameters()).device

    print("-" * 15 + " Teacher Validation " + "-" * 15)
    val_loss, top1_acc, top5_acc = _evaluate_model(teacher_model, val_loader, task_criterion)

    print(f"Val Loss: {val_loss:.4f} | "
          f"Top-1 Accuracy: {top1_acc:.2f}% | Top-5 Accuracy: {top5_acc:.2f}%")
    print("-" * 50)

    best_top1_acc = 0.0  # Initialize best accuracy tracker

    for epoch in range(epochs):
        # Training phase
        print("-" * 16 + " Training " + "-" * 16)
        student_model.train()
        train_loss = 0
        train_acc_1 = 0
        train_acc_5 = 0

        if scheduler is None:
            lr = adjust_learning_rate(epoch + 1, lr, student_optimizer)

        for batch_idx, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            student_logits = student_model(batch_x)
            student_loss = task_criterion(student_logits, batch_y)

            if temperature_scheduler is not None:
                # The teacher is only needed for the distillation objective.
                with torch.no_grad():
                    teacher_logits = teacher_model(batch_x)
                    teacher_loss = task_criterion(teacher_logits, batch_y)

                total_batch_loss = temperature_scheduler(
                    epoch,
                    student_logits=student_logits,
                    teacher_logits=teacher_logits,
                    outputs=batch_y
                )

                temperature_scheduler.update_temperature(
                    current_epoch=epoch,
                    loss_divergence=teacher_loss.item() - student_loss.item()
                )
            else:
                # No distillation configured: fall back to supervised training
                # (previously this path left the loss undefined and crashed).
                total_batch_loss = student_loss

            # Backward pass and optimization
            student_optimizer.zero_grad()
            total_batch_loss.backward()
            student_optimizer.step()

            # Calculate accuracies
            acc1, acc5 = calculate_accuracy(student_logits, batch_y)
            train_loss += total_batch_loss.item()
            train_acc_1 += acc1.item()
            train_acc_5 += acc5.item()

            if batch_idx % 100 == 0:
                if temperature_scheduler is not None:
                    temp_text = f"{temperature_scheduler.get_temperature():.2f}"
                else:
                    temp_text = "n/a"
                print(f"Epoch {epoch+1} | Batch {batch_idx}/{len(train_loader)} | "
                      f"Loss: {total_batch_loss.item():.4f} | Temp: {temp_text} | "
                      f"Acc@1: {acc1.item():.2f}% | Acc@5: {acc5.item():.2f}%")

        # Epoch-end metrics (running averages over batches)
        train_loss /= len(train_loader)
        train_acc_1 /= len(train_loader)
        train_acc_5 /= len(train_loader)

        if scheduler is not None:
            scheduler.step()
            lr = scheduler.get_last_lr()[0]

        print(f"Epoch {epoch+1}/{epochs} | Training Loss: {train_loss:.4f} | "
              f"Acc@1: {train_acc_1:.2f}% | Acc@5: {train_acc_5:.2f}%")
        print("-" * 42)

        # Validation phase (runs every epoch; `val_steps` is not consulted)
        print("-" * 15 + " Validation " + "-" * 15)
        val_loss, top1_acc, top5_acc = _evaluate_model(student_model, val_loader, task_criterion)

        epoch_metrics = {
            "train_acc": train_acc_1,
            "train_loss": train_loss,
            "val_acc": top1_acc,
            "val_loss": val_loss,
            "lr": lr,
        }
        if temperature_scheduler is not None:
            epoch_metrics["temp"] = temperature_scheduler.get_temperature()
        wandb.log(epoch_metrics)

        our_losses.append({"train_loss": train_loss, "test_loss": val_loss})
        our_accuracies.append({"acc@1": top1_acc, "acc@5": top5_acc})

        print(f"Epoch {epoch+1}/{epochs} | Val Loss: {val_loss:.4f} | "
              f"Top-1 Accuracy: {top1_acc:.2f}% | Top-5 Accuracy: {top5_acc:.2f}%")
        print("-" * 42)

        # Save the best model
        if top1_acc > best_top1_acc:
            best_top1_acc = top1_acc
            torch.save(student_model.state_dict(), f"DTAD_@{top1_acc}.pth")
            print(f"Best model saved at epoch {epoch+1} with Top-1 Accuracy: {best_top1_acc:.2f}%")
    print("Best Model Accuracy: ", best_top1_acc)
    run.finish()

    torch.save(student_model.state_dict(), "trained_studentDTAD.pth")
    return student_model

In [None]:
# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CIFAR-100 data preparation: training images get a random horizontal flip and
# a padded random crop before normalization.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))  # Per-channel mean and std applied to the CIFAR-100 images
])

# Validation images are only converted and normalized (no augmentation).
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

# Load CIFAR-100 dataset (downloaded into ./data when missing)
train_dataset = torchvision.datasets.CIFAR100(root="./data", train=True, transform=transform, download=True)
val_dataset = torchvision.datasets.CIFAR100(root="./data", train=False, transform=val_transform, download=True)
num_classes = len(train_dataset.classes)

batch_size=128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256)

# Teacher: build the resnet56 entry and restore its weights from the
# checkpoint's "model" key, mapping tensors straight onto `device`.
teacher_model, path = cifar_model_dict["resnet56"]
teacher_model = teacher_model(num_classes=num_classes)
teacher_model.load_state_dict(torch.load(path, map_location=device)["model"])
teacher_model.to(device)

print("model_loaded")

In [None]:
# Student network: look up the resnet20 entry and build it with the dataset's
# class count. The `path` returned with the entry is not used in this cell.
student_model, path = cifar_model_dict["resnet20"]
student_model = student_model(num_classes=num_classes)
student_model.to("cuda")
print("models loaded")

# adjust_learning_rate decays from a base of 0.05, so `lr` here should stay in
# sync with it.
max_epoch = 100
lr = 0.05

optimizer = torch.optim.SGD(
    student_model.parameters(), 
    lr=lr,
    weight_decay=5e-4,
    momentum=0.9
)

# Temperature starts at 4.0 and its per-step target is clamped to [2.0, 8].
temp_scheduler = DynamicTemperatureScheduler(
    initial_temperature=4.0, 
    min_temperature=2.0, 
    max_temperature=8,
    max_epoch=max_epoch, 
    warmup=20
)

# No LR `scheduler` is passed, so the training function falls back to
# adjust_learning_rate. The training loop does not pass `loss_type`, so the
# temperature scheduler's forward uses its default ("kd++").
# `val_steps` is currently not consulted by the training function.
trained_student = train_knowledge_distillation(
    "56->20 (Ours) + kd+norm4->2 + log scaling",
    teacher_model, 
    student_model, 
    train_loader, 
    val_loader,
    optimizer=optimizer,
    lr=lr,
    epochs=max_epoch,
    val_steps=1,
    temperature_scheduler=temp_scheduler,
)

In [None]:
# Fresh resnet20 student for the DTKD baseline. Rebinding `student_model` does
# not discard the network trained above; it stays reachable as `trained_student`.
student_model, path = cifar_model_dict["resnet20"]
student_model = student_model(num_classes=num_classes)
student_model.to("cuda")

# Pair the new student with the teacher loaded earlier.
distiller = DTKD(student_model, teacher_model)

# Initialize the trainer for the DTKD run (same loaders as the run above)
trainer = BaseTrainer(
    experiment_name="DTKD",
    distiller=distiller,
    train_loader=train_loader, 
    val_loader=val_loader
)

# Train for the same number of epochs as the run above.
trainer.train(num_epochs=max_epoch)

In [None]:
import matplotlib.pyplot as plt

def plot_losses():
    """
    Plot DTKD vs. our loss curves: train losses on top, test losses below.

    Reads the notebook-level ``dtkd_losses`` and ``our_losses`` lists, whose
    entries are dicts with 'train_loss' and 'test_loss' keys, and draws one
    point per recorded entry.
    """
    def column(history, key):
        return [record[key] for record in history]

    # Pull every series out of the histories before touching matplotlib.
    dtkd_train = column(dtkd_losses, 'train_loss')
    dtkd_test = column(dtkd_losses, 'test_loss')
    our_train = column(our_losses, 'train_loss')
    our_test = column(our_losses, 'test_loss')

    # (title, DTKD series, DTKD label, our series, our label), top row first.
    panels = [
        ("Train Losses", dtkd_train, "DTKD Train Loss", our_train, "Our Train Loss"),
        ("Test Losses", dtkd_test, "DTKD Test Loss", our_test, "Our Test Loss"),
    ]

    plt.figure(figsize=(8, 6))
    for row, (title, dtkd_curve, dtkd_label, our_curve, our_label) in enumerate(panels, start=1):
        plt.subplot(2, 1, row)
        plt.plot(dtkd_curve, label=dtkd_label, color='blue')
        plt.plot(our_curve, label=our_label, color='red')
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.grid()

    plt.tight_layout()  # keep the two panels from overlapping
    plt.show()

plot_losses()

In [None]:
def plot_accuracies():
    """
    Plot the 'acc@1' curves of DTKD and our run on a single figure.

    Reads the notebook-level ``dtkd_accuracies`` and ``our_accuracies`` lists,
    whose entries are dicts with an 'acc@1' key.
    """
    # (legend label, colour, series) in drawing order.
    curves = [
        ("DTKD acc@1", 'blue', [record['acc@1'] for record in dtkd_accuracies]),
        ("Our acc@1", 'red', [record['acc@1'] for record in our_accuracies]),
    ]

    plt.figure(figsize=(8, 6))
    for label, colour, values in curves:
        plt.plot(values, label=label, color=colour)

    plt.title("Accuracy Comparison", fontsize=16)
    plt.xlabel("Epoch", fontsize=12)
    plt.ylabel("Accuracy (%)", fontsize=12)
    plt.legend()
    plt.grid()
    plt.tight_layout()

    plt.show()

plot_accuracies()