In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader, Dataset
from torchvision.datasets.utils import download_url
from torchvision.datasets.folder import default_loader
import os
import zipfile
import random
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus shortcut).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution (the second always uses 1).
        downsample: optional module applied to the input before it is added
            back onto the residual branch; ``None`` adds the input unchanged.
    """

    # Ratio between the block's output channels and ``out_channels``.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Settings shared by both 3x3 convolutions; only the stride differs.
        shared_conv_args = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **shared_conv_args)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **shared_conv_args)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # The shortcut is the raw input unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)

class ResNet10(nn.Module):
    """Small ResNet with one ``BasicBlock`` per stage (widths 64/128/256/512).

    Args:
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, num_classes=200):
        super().__init__()
        block = BasicBlock

        # Stem: 3x3 stride-2 convolution followed by a 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (in_channels, out_channels, stride) for layer1..layer4, one block each.
        stage_plan = ((64, 64, 1), (64, 128, 2), (128, 256, 2), (256, 512, 2))
        for number, (width_in, width_out, step) in enumerate(stage_plan, start=1):
            setattr(self, f'layer{number}', self._make_layer(block, width_in, width_out, 1, stride=step))

        # Collapse the spatial dimensions to 1x1 before the classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Classifier head.
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self._initialize_weights()

    def _make_layer(self, block, in_channels, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` consecutive ``block`` modules.

        Only the first block receives ``stride`` and, when the shape changes,
        a 1x1 convolution + batch-norm projection for the shortcut.
        """
        expanded_channels = out_channels * block.expansion

        projection = None
        if stride != 1 or in_channels != expanded_channels:
            projection = nn.Sequential(
                nn.Conv2d(in_channels, expanded_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(expanded_channels),
            )

        stage = [block(in_channels, out_channels, stride, projection)]
        stage.extend(block(expanded_channels, out_channels) for _ in range(blocks - 1))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the four stages, global pooling and the classifier."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

    def _initialize_weights(self):
        """Kaiming-normal init for convolutions; weight 1 / bias 0 for batch-norm layers."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

# Seed Python's and PyTorch's random number generators with a fixed value
random.seed(0)
torch.manual_seed(0)

# Run on the GPU when CUDA is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

# Dataset root and the directories of its two splits
data_dir = './tiny-imagenet-200'
train_dir, val_dir = (os.path.join(data_dir, split_name) for split_name in ('train', 'val'))

# Download and extract Tiny ImageNet dataset
def download_and_extract_tiny_imagenet():
    """Download and unpack Tiny ImageNet next to the notebook unless it is already there.

    The dataset counts as present when ``data_dir`` is an existing, non-empty
    directory. Otherwise the archive is downloaded into ``./``, extracted
    there, and the zip file is deleted again.

    ``data_dir`` is deliberately not created up front: the archive is assumed
    to unpack into that directory, and creating it early would make a failed
    or interrupted download look like a finished one on the next run.

    Raises:
        Whatever ``download_url`` or ``zipfile`` raise; partial results are
        cleaned up before the exception propagates.
    """
    import shutil  # local import: only needed for cleanup on failure

    if os.path.isdir(data_dir) and os.listdir(data_dir):
        print('Tiny ImageNet dataset already exists.')
        return

    url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
    filename = 'tiny-imagenet-200.zip'
    zip_path = os.path.join('./', filename)
    print('Downloading Tiny ImageNet dataset...')
    try:
        download_url(url, root='./', filename=filename)
        print('Extracting Tiny ImageNet dataset...')
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall('./')
    except BaseException:
        # Also covers a manual interrupt: drop a half-extracted tree so the
        # next run starts over instead of reporting the dataset as present.
        shutil.rmtree(data_dir, ignore_errors=True)
        raise
    finally:
        # Remove the archive on success and on failure, so a truncated
        # download is never picked up again.
        if os.path.exists(zip_path):
            os.remove(zip_path)
    print('Dataset downloaded and extracted.')

# Runs when the cell executes; only prints a notice if the dataset is already on disk.
download_and_extract_tiny_imagenet()

# Prepare validation data
def prepare_val_folder():
    """Move validation images from ``val/images`` into ``val/<class>/images``.

    The class of every image is read from the tab-separated
    ``val_annotations.txt`` (first column: file name, second column: class).
    The function returns immediately when ``val/images`` does not exist, so it
    is safe to call again after a completed run. After an interrupted run it
    can be called again to continue: images that were already moved are skipped.

    Raises:
        OSError: if ``val/images`` still contains files after the move (for
            example images missing from the annotation file), because the
            final ``os.rmdir`` only removes an empty directory.
    """
    val_img_dir = os.path.join(val_dir, 'images')
    if not os.path.exists(val_img_dir):
        return

    # Map image file name -> class name from the annotation file.
    val_annotations_file = os.path.join(val_dir, 'val_annotations.txt')
    val_img_dict = {}
    with open(val_annotations_file, 'r') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) < 2:
                # Blank or malformed line: nothing to map.
                continue
            val_img_dict[parts[0]] = parts[1]

    # Create folders for validation images and move each file into place.
    print('Organizing validation images...')
    for img, cls in tqdm(val_img_dict.items()):
        cls_img_dir = os.path.join(val_dir, cls, 'images')
        # makedirs(exist_ok=True) also repairs a class folder that exists
        # without its 'images' subfolder.
        os.makedirs(cls_img_dir, exist_ok=True)
        img_src = os.path.join(val_img_dir, img)
        if os.path.exists(img_src):
            os.rename(img_src, os.path.join(cls_img_dir, img))
    os.rmdir(val_img_dir)
    print('Validation images organized.')

# Regroup the validation images by class; returns immediately when val/images no longer exists.
prepare_val_folder()

# Training pipeline: random rotation and horizontal flip, then resize to 224x224 and convert to a tensor
transform_train = transforms.Compose(
    [
        transforms.RandomRotation(degrees=20),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# Evaluation pipeline: resize and tensor conversion only, no augmentation
transform_test = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# Create custom dataset class
class TinyImageNetDataset(Dataset):
    """Tiny ImageNet split stored as ``<root>/<split>/<class>/images/<file>``.

    Args:
        root: dataset root containing the ``train`` and ``val`` folders.
        train: ``True`` reads ``<root>/train``, ``False`` reads ``<root>/val``.
        transform: optional callable applied to each loaded image.

    Attributes set by ``_load_data``: ``images`` (file paths), ``labels``
    (integer class indices, parallel to ``images``), ``classes`` (sorted class
    names) and ``class_to_idx`` (class name -> index).
    """

    def __init__(self, root, train=True, transform=None):
        self.root = root
        self.transform = transform
        self.images = []
        self.labels = []
        self.train = train
        self._load_data()

    def _load_data(self):
        """Scan the split directory and fill ``images``, ``labels`` and the class tables."""
        split_dir = os.path.join(self.root, 'train' if self.train else 'val')

        # Only entries holding an 'images' subfolder are classes. Stray files or
        # folders (e.g. 'val_annotations.txt', a leftover 'images' directory) are
        # excluded so they cannot shift class indices between the two splits.
        classes = sorted(
            entry for entry in os.listdir(split_dir)
            if os.path.isdir(os.path.join(split_dir, entry, 'images'))
        )
        class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}

        for cls_name in classes:
            cls_dir = os.path.join(split_dir, cls_name, 'images')
            # Sorted so sample order does not depend on the file system's listing order.
            for img_name in sorted(os.listdir(cls_dir)):
                self.images.append(os.path.join(cls_dir, img_name))
                self.labels.append(class_to_idx[cls_name])

        self.classes = classes
        self.class_to_idx = class_to_idx

    def __len__(self):
        """Number of images found in the split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, applying ``transform`` if one was given."""
        img_path = self.images[idx]
        label = self.labels[idx]
        image = default_loader(img_path)
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Load datasets
# train=True reads the 'train' split with the augmenting transform;
# train=False reads the 'val' split with the plain resize transform.
train_dataset = TinyImageNetDataset(root=data_dir, train=True, transform=transform_train)
test_dataset = TinyImageNetDataset(root=data_dir, train=False, transform=transform_test)

# Set a larger batch size
# NOTE(review): the training functions defined later may consult this module-level
# value when scaling the learning rate, so keep it in sync with the loaders below.
batch_size = 1024  # Adjust this value based on your GPU memory
print('Batch size:', batch_size)

# Only the training loader shuffles; both use 4 worker processes and pinned host memory.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Define function to create and modify ResNet models
def create_resnet_model(name, num_classes=200, pretrained=True):
    """Build a ResNet by name with a ``num_classes``-way classifier.

    Args:
        name: one of ``'resnet10'``, ``'resnet18'``, ``'resnet34'``, ``'resnet101'``.
        num_classes: output size of the final linear layer.
        pretrained: forwarded to the torchvision constructors; not used for
            ``'resnet10'``, which is the local ``ResNet10`` class.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``name`` is not one of the supported names.
    """
    # The local architecture already takes num_classes, so no head swap is needed.
    if name == 'resnet10':
        return ResNet10(num_classes=num_classes)

    if name not in ('resnet18', 'resnet34', 'resnet101'):
        raise ValueError('Invalid model name')

    # torchvision exposes each architecture as a constructor of the same name.
    builder = getattr(torchvision.models, name)
    model = builder(pretrained=pretrained)

    # Replace the final layer to match num_classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

from torch.cuda.amp import autocast, GradScaler

# Function to train a model normally (used for teacher model)
def train_model(model, train_loader, test_loader, num_epochs=10, base_lr=0.1, device='cuda', save_path='best_model.pth'):
    """Train ``model`` with cross-entropy and save the weights of its best epoch.

    Uses SGD (momentum 0.9, weight decay 5e-4) with a cosine-annealed learning
    rate and ``autocast``/``GradScaler`` mixed precision. After every epoch the
    model is evaluated on ``test_loader``; when the accuracy strictly improves,
    its ``state_dict`` is written to ``save_path``.

    Args:
        model: network to optimize; it is moved to ``device``.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        test_loader: iterable of ``(inputs, targets)`` evaluation batches.
        num_epochs: number of passes over ``train_loader``; also the cosine period.
        base_lr: learning rate for a batch of 256 samples; scaled linearly with
            the training batch size.
        device: device the model and batches are moved to.
        save_path: file the best ``state_dict`` is saved to.

    Returns:
        The best accuracy (in percent) reached on ``test_loader``.
    """
    criterion = nn.CrossEntropyLoss()
    # Scale the LR by the batch size of the loader actually passed in. The
    # module-level batch_size is only a fallback for loaders that do not expose one.
    loader_batch_size = getattr(train_loader, 'batch_size', None) or batch_size
    adjusted_lr = base_lr * (loader_batch_size / 256)
    print('Adjusted learning rate:', adjusted_lr)
    optimizer = optim.SGD(model.parameters(), lr=adjusted_lr,
                          momentum=0.9, weight_decay=5e-4)
    scaler = GradScaler()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    model.to(device)

    best_acc = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            optimizer.zero_grad()
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            # Scaled backward pass; the scaler unscales before the optimizer step.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

            # Report the mean loss of the last 10 steps, then restart the window.
            if (batch_idx+1) % 10 == 0:
                print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
                      % (epoch+1, num_epochs, batch_idx+1, len(train_loader), running_loss/10))
                running_loss = 0.0

        # Validation
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                with autocast():
                    outputs = model(inputs)
                predicted = outputs.argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        acc = 100 * correct / total
        print('Test Accuracy of the model on the test images: {:.2f} %'.format(acc))

        if acc > best_acc:
            best_acc = acc
            # Save the best model
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model to {save_path}")

        scheduler.step()

    print('Best Accuracy: {:.2f} %'.format(best_acc))
    return best_acc

# Function for knowledge distillation from teacher to student
def train_kd(student_model, teacher_model, train_loader, test_loader, num_epochs=10, base_lr=0.1, temperature=4, alpha=0.9, device='cuda', save_path='best_student_model.pth'):
    """Train ``student_model`` by knowledge distillation from one frozen teacher.

    The loss is ``alpha * KD + (1 - alpha) * CE``. ``KD`` is a batch-mean
    ``KLDivLoss`` with the student's temperature-softened log-probabilities as
    input and the teacher's temperature-softened probabilities as target,
    multiplied by ``temperature ** 2``. ``CE`` is cross-entropy on the labels.
    Optimizer, schedule, mixed precision, per-epoch evaluation and
    checkpointing follow the same scheme as plain training: SGD with cosine
    annealing, and the best ``state_dict`` is saved to ``save_path``.

    Args:
        student_model: network being trained; moved to ``device``.
        teacher_model: network providing soft targets; moved to ``device``, put
            in eval mode and only run under ``torch.no_grad()``.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        test_loader: iterable of ``(inputs, targets)`` evaluation batches.
        num_epochs: number of passes over ``train_loader``; also the cosine period.
        base_lr: learning rate for a batch of 256 samples; scaled linearly with
            the training batch size.
        temperature: softening temperature applied to both sets of logits.
        alpha: weight of the distillation term.
        device: device the models and batches are moved to.
        save_path: file the best student ``state_dict`` is saved to.

    Returns:
        The best student accuracy (in percent) reached on ``test_loader``.
    """
    criterion = nn.CrossEntropyLoss()
    soft_loss_fn = nn.KLDivLoss(reduction='batchmean')

    # Scale the LR by the batch size of the loader actually passed in. The
    # module-level batch_size is only a fallback for loaders that do not expose one.
    loader_batch_size = getattr(train_loader, 'batch_size', None) or batch_size
    adjusted_lr = base_lr * (loader_batch_size / 256)
    print('Adjusted learning rate:', adjusted_lr)

    optimizer = optim.SGD(student_model.parameters(), lr=adjusted_lr,
                          momentum=0.9, weight_decay=5e-4)
    scaler = GradScaler()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    student_model.to(device)
    teacher_model.to(device)
    teacher_model.eval()

    best_acc = 0

    for epoch in range(num_epochs):
        student_model.train()
        running_loss = 0.0

        for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            optimizer.zero_grad()

            with autocast():
                outputs = student_model(inputs)
                # The teacher only supplies targets, so no graph is built for it.
                with torch.no_grad():
                    teacher_outputs = teacher_model(inputs)

                loss_ce = criterion(outputs, targets)
                loss_kd = soft_loss_fn(F.log_softmax(outputs/temperature, dim=1),
                                       F.softmax(teacher_outputs/temperature, dim=1)) * (temperature ** 2)

                loss = alpha * loss_kd + (1 - alpha) * loss_ce

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

            # Report the mean loss of the last 10 steps, then restart the window.
            if (batch_idx+1) % 10 == 0:
                print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
                      % (epoch+1, num_epochs, batch_idx+1, len(train_loader), running_loss/10))
                running_loss = 0.0

        # Validation
        student_model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                with autocast():
                    outputs = student_model(inputs)
                predicted = outputs.argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        acc = 100 * correct / total
        print('Test Accuracy of the student model on the test images: {:.2f} %'.format(acc))

        if acc > best_acc:
            best_acc = acc
            # Save the best model
            torch.save(student_model.state_dict(), save_path)
            print(f"Saved best model to {save_path}")

        scheduler.step()

    print('Best Accuracy: {:.2f} %'.format(best_acc))
    return best_acc

# Function for knowledge distillation with both teacher and TA (simple average)
def train_kd_with_ta(student_model, teacher_model, ta_model, train_loader, test_loader, num_epochs=10, base_lr=0.1, temperature=4, alpha=0.9, device='cuda', save_path='best_student_model.pth'):
    """Distill into ``student_model`` from the plain average of a teacher and a TA.

    The soft target is the mean of the teacher's and the TA's
    temperature-softened probabilities. The loss is
    ``alpha * KD + (1 - alpha) * CE``. ``KD`` is a batch-mean ``KLDivLoss``
    with the student's softened log-probabilities as input and the averaged
    probabilities as target, multiplied by ``temperature ** 2``. ``CE`` is
    cross-entropy on the labels.

    Args:
        student_model: network being trained; moved to ``device``.
        teacher_model: first frozen source of soft targets (eval mode, no grad).
        ta_model: second frozen source of soft targets (eval mode, no grad).
        train_loader: iterable of ``(inputs, targets)`` training batches.
        test_loader: iterable of ``(inputs, targets)`` evaluation batches.
        num_epochs: number of passes over ``train_loader``; also the cosine period.
        base_lr: learning rate for a batch of 256 samples; scaled linearly with
            the training batch size.
        temperature: softening temperature applied to all logits.
        alpha: weight of the distillation term.
        device: device the models and batches are moved to.
        save_path: file the best student ``state_dict`` is saved to.

    Returns:
        The best student accuracy (in percent) reached on ``test_loader``.
    """
    criterion = nn.CrossEntropyLoss()
    soft_loss_fn = nn.KLDivLoss(reduction='batchmean')

    # Scale the LR by the batch size of the loader actually passed in. The
    # module-level batch_size is only a fallback for loaders that do not expose one.
    loader_batch_size = getattr(train_loader, 'batch_size', None) or batch_size
    adjusted_lr = base_lr * (loader_batch_size / 256)
    print('Adjusted learning rate:', adjusted_lr)

    optimizer = optim.SGD(student_model.parameters(), lr=adjusted_lr,
                          momentum=0.9, weight_decay=5e-4)
    scaler = GradScaler()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    student_model.to(device)
    teacher_model.to(device)
    teacher_model.eval()
    ta_model.to(device)
    ta_model.eval()

    best_acc = 0

    for epoch in range(num_epochs):
        student_model.train()
        running_loss = 0.0

        for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            optimizer.zero_grad()

            with autocast():
                outputs = student_model(inputs)
                # Both guides only supply targets, so no graph is built for them.
                with torch.no_grad():
                    teacher_outputs = teacher_model(inputs)
                    ta_outputs = ta_model(inputs)
                    # Average the softmax outputs
                    avg_outputs = (F.softmax(teacher_outputs/temperature, dim=1) + F.softmax(ta_outputs/temperature, dim=1)) / 2

                loss_ce = criterion(outputs, targets)
                loss_kd = soft_loss_fn(F.log_softmax(outputs/temperature, dim=1),
                                       avg_outputs) * (temperature ** 2)

                loss = alpha * loss_kd + (1 - alpha) * loss_ce

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

            # Report the mean loss of the last 10 steps, then restart the window.
            if (batch_idx+1) % 10 == 0:
                print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
                      % (epoch+1, num_epochs, batch_idx+1, len(train_loader), running_loss/10))
                running_loss = 0.0

        # Validation
        student_model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                with autocast():
                    outputs = student_model(inputs)
                predicted = outputs.argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        acc = 100 * correct / total
        print('Test Accuracy of the student model on the test images: {:.2f} %'.format(acc))

        if acc > best_acc:
            best_acc = acc
            # Save the best model
            torch.save(student_model.state_dict(), save_path)
            print(f"Saved best model to {save_path}")

        scheduler.step()

    print('Best Accuracy: {:.2f} %'.format(best_acc))
    return best_acc

# Function for the new distillation algorithm
def train_kd_new_algorithm(student_model, teacher_model, ta_model, train_loader, test_loader, num_epochs=10, base_lr=0.1, temp=5, alpha=0.9, device='cuda', save_path='best_student_model.pth'):
    """Distill into ``student_model`` from a teacher and a TA with per-sample weights.

    For every sample:

    * a confidence pair is the softmax over the negated cross-entropies of the
      teacher and the TA against the true label (lower loss -> higher confidence);
    * ``kl_factor`` is the sigmoid of the summed ``KLDivLoss`` between the two
      guides' ``temp``-softened predictions (teacher log-probabilities as
      input, TA probabilities as target);
    * each guide's weight is ``(1 - kl_factor) * 0.5 + kl_factor * confidence``,
      so the two weights sum to 1 and move from an even split towards the
      confidence split as the guides disagree more;
    * the soft loss is the weighted sum of the student-vs-teacher and
      student-vs-TA ``KLDivLoss`` terms, each multiplied by ``temp ** 2``.

    The per-sample total ``alpha * soft + (1 - alpha) * cross_entropy`` is then
    averaged over the batch.

    Args:
        student_model: network being trained; moved to ``device``.
        teacher_model: first frozen guide (eval mode, run under no grad).
        ta_model: second frozen guide (eval mode, run under no grad).
        train_loader: iterable of ``(inputs, targets)`` training batches.
        test_loader: iterable of ``(inputs, targets)`` evaluation batches.
        num_epochs: number of passes over ``train_loader``; also the cosine period.
        base_lr: learning rate for a batch of 256 samples; scaled linearly with
            the training batch size.
        temp: softening temperature applied to all logits in the KL terms.
        alpha: weight of the distillation term.
        device: device the models and batches are moved to.
        save_path: file the best student ``state_dict`` is saved to.

    Returns:
        The best student accuracy (in percent) reached on ``test_loader``.
    """
    criterion = nn.CrossEntropyLoss(reduction='none')  # per-sample loss
    kl_criterion = nn.KLDivLoss(reduction='none')  # per-element loss, summed over classes below

    # Scale the LR by the batch size of the loader actually passed in. The
    # module-level batch_size is only a fallback for loaders that do not expose one.
    loader_batch_size = getattr(train_loader, 'batch_size', None) or batch_size
    adjusted_lr = base_lr * (loader_batch_size / 256)
    print('Adjusted learning rate:', adjusted_lr)

    optimizer = optim.SGD(student_model.parameters(), lr=adjusted_lr,
                          momentum=0.9, weight_decay=5e-4)
    scaler = GradScaler()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    student_model.to(device)
    teacher_model.to(device)
    teacher_model.eval()
    ta_model.to(device)
    ta_model.eval()

    best_acc = 0

    for epoch in range(num_epochs):
        student_model.train()
        running_loss = 0.0

        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            optimizer.zero_grad()

            with autocast():
                output = student_model(data)
                # Both guides only supply targets and weights, so no graph is built for them.
                with torch.no_grad():
                    teacher_outputs = teacher_model(data)
                    ta_outputs = ta_model(data)

                # Standard Learning Loss (Classification Loss)
                loss_SL = criterion(output, target)  # shape: [batch_size]
                hard_loss = loss_SL

                # Implement the new distillation algorithm
                # Compute per-sample cross-entropy losses for teacher and TA
                ce_teacher = criterion(teacher_outputs, target)  # shape: [batch_size]
                ce_ta = criterion(ta_outputs, target)  # shape: [batch_size]

                # Compute negative ce
                neg_ce_teacher = -ce_teacher
                neg_ce_ta = -ce_ta

                # Stack negative ce to compute confidence scores
                neg_ce = torch.stack([neg_ce_teacher, neg_ce_ta], dim=1)  # shape: [batch_size, 2]

                # Compute confidence scores
                conf_scores = F.softmax(neg_ce, dim=1)  # shape: [batch_size, 2]

                conf_teacher = conf_scores[:, 0]  # shape: [batch_size]
                conf_ta = conf_scores[:, 1]  # shape: [batch_size]

                # Compute softmax outputs for teacher and TA
                teacher_pred = F.softmax(teacher_outputs / temp, dim=1)  # shape: [batch_size, num_classes]
                ta_pred = F.softmax(ta_outputs / temp, dim=1)

                # Compute KL divergence between teacher and TA
                kl_teacher_ta = kl_criterion(
                    F.log_softmax(teacher_outputs / temp, dim=1),
                    ta_pred
                ).sum(dim=1)  # shape: [batch_size]

                # Compute kl_factor
                kl_factor = torch.sigmoid(kl_teacher_ta)  # shape: [batch_size]

                # Compute final weights (w_teacher + w_ta == 1 for every sample)
                w_teacher = (1 - kl_factor) * 0.5 + kl_factor * conf_teacher  # shape: [batch_size]
                w_ta = (1 - kl_factor) * 0.5 + kl_factor * conf_ta  # shape: [batch_size]

                # Compute KL divergence between student and teacher
                kl_student_teacher = kl_criterion(
                    F.log_softmax(output / temp, dim=1),
                    teacher_pred
                ).sum(dim=1)  # shape: [batch_size]

                kl_student_ta = kl_criterion(
                    F.log_softmax(output / temp, dim=1),
                    ta_pred
                ).sum(dim=1)  # shape: [batch_size]

                # Compute soft losses
                soft_loss_teacher = w_teacher * kl_student_teacher * (temp ** 2)  # shape: [batch_size]
                soft_loss_ta = w_ta * kl_student_ta * (temp ** 2)  # shape: [batch_size]

                # Compute total_loss per sample
                total_loss = alpha * (soft_loss_teacher + soft_loss_ta) + (1 - alpha) * hard_loss  # shape: [batch_size]

                # Compute loss as average over batch
                loss = total_loss.mean()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

            # Report the mean loss of the last 10 steps, then restart the window.
            if (batch_idx+1) % 10 == 0:
                print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
                      % (epoch+1, num_epochs, batch_idx+1, len(train_loader), running_loss/10))
                running_loss = 0.0

        # Validation
        student_model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                with autocast():
                    outputs = student_model(inputs)
                predicted = outputs.argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        acc = 100 * correct / total
        print('Test Accuracy of the student model on the test images: {:.2f} %'.format(acc))

        if acc > best_acc:
            best_acc = acc
            # Save the best model
            torch.save(student_model.state_dict(), save_path)
            print(f"Saved best model to {save_path}")

        scheduler.step()

    print('Best Accuracy: {:.2f} %'.format(best_acc))
    return best_acc

# Set up device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Clear cache
torch.cuda.empty_cache()

from torchsummary import summary


# Load the Teacher Model (ResNet-101)
# map_location places the checkpoint tensors on the active device, so a file
# saved from a GPU run also loads on a CPU-only runtime. weights_only=True
# restricts unpickling to tensor data, which is all a state_dict needs.
print('Loading Teacher Model (ResNet-101)')
teacher_model = create_resnet_model('resnet101', num_classes=200, pretrained=False)
teacher_model.load_state_dict(torch.load('/content/best_model.pth', map_location=device, weights_only=True))
teacher_model = teacher_model.to(device)
teacher_model.eval()

summary(teacher_model, (3, 224, 224))

# Load the teaching-assistant model (ResNet-34) the same way
print('Loading TA Model (ResNet-34)')
ta_model = create_resnet_model('resnet34', num_classes=200, pretrained=False)
ta_model.load_state_dict(torch.load('/content/resnet_34_tf.pth', map_location=device, weights_only=True))
ta_model = ta_model.to(device)
ta_model.eval()

# Each experiment below trains a fresh ResNet-10 student for 40 epochs.

# Baseline: student trained on labels only
print("NO KD")
no_kd_10 = create_resnet_model('resnet10', num_classes=200, pretrained=False)
no_kd_10 = no_kd_10.to(device)
no_kd_10_best_acc = train_model(no_kd_10, train_loader, test_loader, num_epochs=40, base_lr=0.1, device=device, save_path='no_kd_10.pth')

# Distillation from the ResNet-101 teacher only
print("STANDARD KD")
standard_kd_10 = create_resnet_model('resnet10', num_classes=200, pretrained=False)
standard_kd_10 = standard_kd_10.to(device)
standard_kd_10_best_acc = train_kd(standard_kd_10, teacher_model, train_loader, test_loader, num_epochs=40, base_lr=0.1, temperature=4, alpha=0.9, device=device, save_path='standard_kd_10.pth')

# Distillation from the ResNet-34 TA only
print("TA")
student_model_alg3 = create_resnet_model('resnet10', num_classes=200, pretrained=False)
student_model_alg3 = student_model_alg3.to(device)
student_best_acc_alg3 = train_kd(student_model_alg3, ta_model, train_loader, test_loader, num_epochs=40, base_lr=0.1, temperature=4, alpha=0.9, device=device, save_path='student_model_alg3.pth')

# Distillation from the plain average of teacher and TA
print("Average")
student_model_alg2 = create_resnet_model('resnet10', num_classes=200, pretrained=False)
student_model_alg2 = student_model_alg2.to(device)
student_best_acc_alg2 = train_kd_with_ta(student_model_alg2, teacher_model, ta_model, train_loader, test_loader, num_epochs=40, base_lr=0.1, temperature=4, alpha=0.9, device=device, save_path='student_model_alg2.pth')

# Distillation with per-sample teacher/TA weighting
print("Weighted")
student_model_alg1 = create_resnet_model('resnet10', num_classes=200, pretrained=False)
student_model_alg1 = student_model_alg1.to(device)
student_best_acc_alg1 = train_kd_new_algorithm(student_model_alg1, teacher_model, ta_model, train_loader, test_loader, num_epochs=40, base_lr=0.1, temp=5, alpha=0.9, device=device, save_path='student_model_alg1.pth')

Using device: cuda
Tiny ImageNet dataset already exists.
Batch size: 1024
Using device: cuda
Loading Teacher Model (ResNet-101)


  teacher_model.load_state_dict(torch.load('/content/best_model.pth'))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,

  ta_model.load_state_dict(torch.load('/content/resnet_34_tf.pth'))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,

In [None]:
# Shell escape: list the working directory with human-readable sizes.
!ls -lh

total 54M
-rw-r--r-- 1 root root  27M Dec  9 03:25 best_model.pth
-rw-r--r-- 1 root root  26M Dec  9 03:25 resnet_34_tf.pth
drwxr-xr-x 1 root root 4.0K Dec  5 14:24 sample_data
drwxr-xr-x 5 root root 4.0K Dec  9 03:27 tiny-imagenet-200
