In [1]:
import os
import glob
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle
from PIL import Image


import xml.etree.ElementTree as ET

In [2]:
MAIN_DIR = '/kaggle/input/pascal-voc-2012/VOC2012'

In [3]:
def xml_to_csv(path = os.path.join(MAIN_DIR,'Annotations')):
    xml_list = []
    
    for xml_file in tqdm(glob.glob(os.path.join(path, '2007*.xml'))):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        
        for obj in root.findall('object'):
            bbx = obj.find('bndbox')
            xmin = int(bbx.find('xmin').text)
            ymin = int(bbx.find('ymin').text)
            xmax = int(bbx.find('xmax').text)
            ymax = int(bbx.find('ymax').text)
            label = obj.find('name').text

            # it would be better to use column name instead of index
            value = (root.find('filename').text,
                     int(root.find('size').find('depth').text), #0 , 2
                     int(root.find('size').find('width').text), #1 , 0
                     int(root.find('size').find('height').text), #2 , 1
                     label,
                     xmin,
                     ymin,
                     xmax,
                     ymax
                     )
            xml_list.append(value)
            
    column_name = ['filename', 'channels', 'width', 'height',
                   'class', 'xmin', 'ymin', 'xmax', 'ymax']
    
    xml_df = pd.DataFrame(xml_list, columns=column_name)
    
    return xml_df

In [4]:
xml_df = xml_to_csv().head(500)

100%|██████████| 756/756 [00:05<00:00, 128.28it/s]


In [5]:
xml_df.to_csv('data_descriptor.csv', index = False)

In [6]:
xml_df

Unnamed: 0,filename,channels,width,height,class,xmin,ymin,xmax,ymax
0,2007_005144.jpg,3,332,500,person,1,12,331,500
1,2007_005989.jpg,3,500,375,motorbike,140,130,408,273
2,2007_005989.jpg,3,500,375,person,213,96,355,260
3,2007_002107.jpg,3,500,375,aeroplane,408,243,449,257
4,2007_000822.jpg,3,500,374,motorbike,98,165,230,346
...,...,...,...,...,...,...,...,...,...
495,2007_004663.jpg,3,500,375,train,1,168,407,260
496,2007_003571.jpg,3,500,333,boat,259,206,500,333
497,2007_005428.jpg,3,375,500,bottle,1,1,217,362
498,2007_007414.jpg,3,500,333,person,453,50,491,110


In [7]:
classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", 
           "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa",
           "train", "tvmonitor"]
num_classes = len(classes)

## Vanilla KD Training

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torch.nn.functional as F

In [9]:
class CustomDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_path = os.path.join(MAIN_DIR,'JPEGImages',self.dataframe.iloc[idx]['filename'])
        label = self.dataframe.iloc[idx]['class']

        img = Image.open(img_path).convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, label

In [10]:
label_encoder = LabelEncoder().fit(xml_df['class'])
xml_df['class'] = label_encoder.transform(xml_df['class'])

In [11]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Assuming using ImageNet normalization
])

dataset = CustomDataset(xml_df, transform = transform)
dataloader = DataLoader(dataset, batch_size = 1, shuffle=True, num_workers=8)




In [12]:
mobilenet = models.mobilenet_v2(pretrained=True)
mobilenet.classifier[1] = torch.nn.Linear(mobilenet.classifier[1].in_features, num_classes)

class ModifiedMobileNetV2(nn.Module):
    def __init__(self, mobilenet):
        super(ModifiedMobileNetV2, self).__init__()
        self.mobilenet = mobilenet
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.mobilenet(x)
        x = self.softmax(x)
        return x

mobilenet = ModifiedMobileNetV2(mobilenet)

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 80.4MB/s]


In [13]:
criterion = nn.CrossEntropyLoss()
kl_div_loss = nn.KLDivLoss()
optimizer = optim.SGD(mobilenet.parameters(), lr=0.005, momentum=0.9)

In [14]:
class ModifiedResnet(nn.Module):
    def __init__(self, resnet):
        super(ModifiedResnet, self).__init__()
        self.resnet = resnet
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.resnet(x)
        x = self.softmax(x)
        return x


In [15]:
teacher = pickle.load(open('/kaggle/input/deeplearning-data/teacher.pkl','rb'))

In [16]:
def get_gradcam_map_resnet(model, input_image, target_class=None):
    # Set the model to evaluation mode
    model.eval()
    
    # Forward pass
    features = model.conv1(input_image)
    features = model.bn1(features)
    features = model.relu(features)
    features = model.maxpool(features)

    features = model.layer1(features)
    features = model.layer2(features)
    features = model.layer3(features)
    features = model.layer4(features)

    # Average pooling to reduce spatial dimensions to 1x1
    pooled_features = nn.functional.adaptive_avg_pool2d(features, 1)
    
    # Compute the gradients of the target class score with respect to the pooled feature maps
    gradients = None
    for param in model.parameters():
        param.requires_grad_(True)
    output = model.fc(pooled_features.view(pooled_features.size(0), -1))
    if target_class is None:
        target_class = output.argmax(dim=1)
    output[:, target_class].backward(retain_graph=True)
    gradients = model.layer4[-1].conv2.weight.grad
    alpha = gradients.mean(dim=(2, 3), keepdim=True)

    # Compute the importance weights for each feature map
    gradcam_map = (alpha * features).sum(dim=1, keepdim=True)
    gradcam_map = nn.functional.relu(gradcam_map)
    
    # Resize the Grad-CAM map to match the input size
    gradcam_map = nn.functional.interpolate(gradcam_map, size=(10, 10), mode='bilinear', align_corners=False)

    return torch.mean(gradcam_map, dim=0)

In [17]:
def get_gradcam_map_mobilenet(model, input_image, target_class=None):
    activation = {}
    gradient = {}

    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook

    def get_gradient(name):
        def hook(module, grad_input, grad_output):
            gradient[name] = grad_output[0].detach()
        return hook

    hook1 = model.features[-1].register_forward_hook(get_activation('features'))
    hook2 = model.features[-1].register_backward_hook(get_gradient('features'))
    
    model.eval()
    output = model(input_image)
    target_class = torch.argmax(output)

    model.zero_grad()
    output[:, target_class].backward(create_graph=True)
    grads = gradient['features']
    activations = activation['features']

    for i in range(activations.shape[0]):
        activations[i, :] *= grads[i, :]
    grad_cam = torch.mean(activations, dim=1).unsqueeze(0)

    grad_cam = F.relu(grad_cam)
    gradcam_map = nn.functional.interpolate(grad_cam, size=(10, 10), mode='bilinear', align_corners=False)
    gradcam_map = gradcam_map.squeeze()
    
    hook1.remove()
    hook2.remove()
    return gradcam_map

In [18]:
def remove_hooks(model):
    for module in model.modules():
        # Check if the module has any hooks
        if hasattr(module, '_forward_hooks'):
            # Remove all hooks from the module
            module._forward_hooks.clear()
        if hasattr(module, '_backward_hooks'):
            module._backward_hooks.clear()

In [19]:
def train_model(student_model, teacher_model, criterion, optimizer, num_epochs = 10):
    for epoch in range(num_epochs):
        student_model.train()
        total_loss = 0.0

        for inputs, labels in dataloader:

            optimizer.zero_grad()
            with torch.no_grad():
                teacher_outputs = teacher_model(inputs)    
            student_outputs = student_model(inputs)
            
            teacher_CAM = get_gradcam_map_resnet(teacher_model.resnet, inputs)
            student_CAM = get_gradcam_map_mobilenet(student_model.mobilenet, inputs)
            
            teacher_model.zero_grad()
            student_model.zero_grad()
            
            kd_loss = kl_div_loss(student_outputs, teacher_outputs)
            ce_loss = criterion(student_outputs, labels)
            
            teacher_CAM = teacher_CAM.view(-1)
            student_CAM = student_CAM.view(-1)
            cosine_similarity = F.cosine_similarity(teacher_CAM, student_CAM, dim=0)
            cosine_distance = 1 - cosine_similarity
            
            loss = kd_loss + ce_loss + cosine_distance

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader.dataset):.4f}")
        remove_hooks(student_model)
        with open(f'mobilenet_EKD{epoch+1}.pkl', 'wb') as file:
            pickle.dump(student_model, file)
train_model(mobilenet, teacher, criterion, optimizer, num_epochs = 15)

  self.pid = os.fork()
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  self.pid = os.fork()


Epoch 1/15, Loss: 2.7387
Epoch 2/15, Loss: 2.7326
Epoch 3/15, Loss: 2.7320
Epoch 4/15, Loss: 2.7311
Epoch 5/15, Loss: 2.7304
Epoch 6/15, Loss: 2.7298
Epoch 7/15, Loss: 2.7299
Epoch 8/15, Loss: 2.7294
Epoch 9/15, Loss: 2.7292
Epoch 10/15, Loss: 2.7304
Epoch 11/15, Loss: 2.7327
Epoch 12/15, Loss: 2.7368
Epoch 13/15, Loss: 2.7428
Epoch 14/15, Loss: 2.7443
Epoch 15/15, Loss: 2.7323


In [20]:
with open('mobilenet_EKD.pkl', 'wb') as file:
    pickle.dump(mobilenet, file)