In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import logging
import os
import warnings

warnings.filterwarnings("ignore")
import torch.optim as optim
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from student_model_new import Knowledge_distiller, SimilarityLoss
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm
from transformers import BertModel, BertTokenizer

In [24]:
# Define the ProjectionHead class
class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim=512,  # Update the projection dimension
        dropout=0.1,
    ):
        super(ProjectionHead, self).__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)


    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected  # skip connection
        x = self.layer_norm(x)  # Layer Normalization
        return x


In [25]:
# Student feature extractor + Classifier
class StudentNet(nn.Module):
    def __init__(self,num_classes, hidden_layers=512):
        super(StudentNet, self).__init__()

        # download resnet152 model
        model = models.resnet152(pretrained=False)
        self.cropped_resnet152 = torch.nn.Sequential(*list(model.children())[:-2])
        self.projection_head = ProjectionHead(2048, projection_dim=512, dropout=0.1)

        # for classifier
        self.fc = nn.Linear(hidden_layers, num_classes)

    def forward(self, image):
        #feature extractor
        img_embedding = self.cropped_resnet152(image)
        image_embedding = img_embedding.view(len(img_embedding), 2048, 49)#(b,2048,49)
        image_embedding = image_embedding.permute(0, 2, 1)#(b,49,2048)
        image_embedding = self.projection_head(image_embedding)#(b,49,512)
        image_embedding = image_embedding.permute(1, 0, 2)#(49,b,512)
        #classifier
        x = torch.mean(image_embedding, dim=0)#(1,b,512)
        x = torch.unsqueeze(x,0)#(b,512)-----------------------------------//////
        logit = self.fc(x)

        return image_embedding, logit


In [26]:
#Loss function for knowledge distillation
class SimilarityLoss(nn.Module):
    def __init__(self):
        super(SimilarityLoss, self).__init__()

    def forward(self, teacher_em, student_em):#(b,512,49)
        batch_cs = torch.zeros(teacher_em.shape[0])  # tensor to hold cs for each feature
        for i in range(teacher_em.shape[0]):
            feature_cs = F.cosine_similarity(teacher_em[i, :, :], student_em[i, :, :], dim=1)
            mean_cs = torch.mean(feature_cs)
            batch_cs[i] = mean_cs

        mean_batch_cs = torch.mean(batch_cs)
        final_loss = 1 - mean_batch_cs
        return final_loss ##a value between 0 & 2 

In [27]:
#Main model for knowledge distillation
class Knowledge_distiller(nn.Module):
    def __init__(self, model_path="path_to_teacher", num_classes=2):
        super(Knowledge_distiller, self).__init__()

        self.student_net = StudentNet(num_classes=num_classes)

        #resnet teacher model
        model = models.resnet152(pretrained=True)#152
        self.teacher_img_net = torch.nn.Sequential(*list(model.children())[:-2])
        self.projection_head = ProjectionHead(2048, projection_dim=512, dropout=0.1)
        
        #freezing the teacher
        for param in self.teacher_img_net.parameters():
            param.requires_grad = False

        #Moving networks to GPU if available
        if torch.cuda.is_available():
            self.student_net = self.student_net.to("cuda:0")
            self.teacher_img_net = self.teacher_img_net.to("cuda:0")

    def forward(self, image, label):
        teacher_img_emb = self.teacher_img_net(image)
        teacher_img_emb = teacher_img_emb.view(len(teacher_img_emb), 2048, 49)#(b,2048,49)
        teacher_img_emb = teacher_img_emb.permute(0, 2, 1)#(b,49,2048)
        teacher_img_emb = self.projection_head(teacher_img_emb)#(b,49,512)
        teacher_img_emb = teacher_img_emb.permute(0, 2, 1)#(b,512,49)

        student_img_emb, prediction = self.student_net(image)#(49,b,512)
        student_img_emb = student_img_emb.permute(1, 2, 0)#(b,512,49)

        return  teacher_img_emb, student_img_emb, prediction

In [28]:
Image_size = 224
trainset_size = 0.99
Epochs = 5
training_loss = []
saved_teacher_net_path = "/data/ood/teacher_model_13_11/checkpoints/model_epoch_10.pt"

# Define batch_size
batch_size = 16

DIR = {
        "P": "/data/ood/PACS/pacs_data/pacs_data/photo",
        "A": "/data/ood/PACS/pacs_data/pacs_data/art_painting",
        "C": "/data/ood/PACS/pacs_data/pacs_data/cartoon",
        "S": "/data/ood/PACS/pacs_data/pacs_data/sketch",
    }

# PACS classes
labels = {
        "0": "Dog",
        "1": "Elephant",
        "2": "Giraffe",
        "3": "Guitar",
        "4": "Horse",
        "5": "House",
        "6": "Person",
    }


In [29]:
def data_loader(
    train_data_dir_1,
    train_data_dir_2,
    train_data_dir_3,
    valid_data_dir,
    batch_size,
    random_seed=38,
    valid_size=0.9,#train_size is not 0.1 it comes from totaly seperate domain
    shuffle=True,
):
    
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    

    # define transforms
    transform = transforms.Compose(
        [
            transforms.Resize((Image_size, Image_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            normalize,
        ]
    )

    train_dataset_1 = ImageFolder(root=train_data_dir_1, transform=transform)
    train_dataset_2 = ImageFolder(root=train_data_dir_2, transform=transform)
    train_dataset_3 = ImageFolder(root=train_data_dir_3, transform=transform)
    valid_dataset = ImageFolder(root=valid_data_dir, transform=transform)

    #train_dataset = train_dataset_1 + train_dataset_2 + train_dataset_3
    train_dataset = train_dataset_1

    train_indices, _ = train_test_split(
        list(range(len(train_dataset))),
        train_size=trainset_size,
        random_state=random_seed,
    )
    valid_indices, _ = train_test_split(
        list(range(len(valid_dataset))), train_size=valid_size, random_state=random_seed
    )

    # Create DataLoader for train and test sets
    train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
    valid_dataset = torch.utils.data.Subset(valid_dataset, valid_indices)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, valid_loader


In [30]:
    # set the relevant domain as train, valid sets
train_dataloader, validation_dataloader= data_loader(
        DIR["P"],   #train
        DIR["A"],   #train
        DIR["C"],   #train
        DIR["P"],   #validation
        batch_size)

In [31]:
number_count=[0,0,0,0,0,0,0]
for batch, (image, label) in enumerate((train_dataloader)):
    for i in label:
        number_count[int(i)]+=1

for i,j in enumerate(number_count):
    print(f"Class {i} : {j*100/sum(number_count):.2f}%")

Class 0 : 11.19%
Class 1 : 12.04%
Class 2 : 10.95%
Class 3 : 11.19%
Class 4 : 11.92%
Class 5 : 16.76%
Class 6 : 25.95%


In [32]:
len(train_dataloader)*batch_size

1664

In [33]:
# lists to save embeddings to do PCA
text_embedding_global = []
S_img_embedding_global = []
T_img_embedding_global = []
label_set = []
epoch_flag = False  # flag to identify last epoch

In [34]:
device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available() #check whether multi process service is enabled
        else "cpu"
    )

In [35]:
if torch.cuda.is_available():
        model = Knowledge_distiller(num_classes=7,model_path=saved_teacher_net_path)
        model = model.to("cuda:0")  # model moved to specified GPU
        print("Model loaded to GPU.")
else:
        model = Knowledge_distiller(num_classes=7,model_path=saved_teacher_net_path)
        print("GPU is unavailable. model loaded to CPU.")

Model loaded to GPU.


In [36]:
criterion_1 = SimilarityLoss()  # Loss function for knowledge distillation
criterion_2 = nn.CrossEntropyLoss()  # Loss function for classifier
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Define a directory to save your model and checkpoints
save_dir = "Model_checkpoints_for_distiller"

# Make sure the directory exists, create it if it doesn't
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

In [37]:
# training only student network (classifier-> frozen)
for name, param in model.named_parameters():
    if name in ["student_net.fc.weight", "student_net.fc.bias"]:
        param.requires_grad = False
        

In [38]:
logging.basicConfig(
        filename="Distillation.log",
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

In [39]:
for i in tqdm(range(Epochs)):
    epoch_loss = 0
    count = 0
    for batch, (image, label) in enumerate(tqdm(train_dataloader)):
        if epoch_flag:
            label_set.append(label.detach().numpy())
        label_tensor = []
        for idx, val in enumerate(label):
            label_tensor.append(labels[str(label[idx].item())])

        # Moving Img_data to GPU
        if torch.cuda.is_available():
            image = image.to("cuda:0")
            
        #print("Label {}".format(label))
        teacher_img_emb, student_img_emb,_ = model(image, label)

        # Calculating loss
        loss = criterion_1(teacher_img_emb, student_img_emb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss


    if i == Epochs - 2:
        epoch_flag = True

    training_loss.append(
        (epoch_loss / len(train_dataloader)).cpu().detach().numpy()
    )
    # Print average loss for a batch in each epoch
    print(f"Epoch {i+1}: Loss: {training_loss[-1]}\n\n")

    # Save the model and training checkpoint
    if (i + 1) % 2 == 0:  # Save every 2 epochs, adjust as needed
        # Save the model's state dictionary
        model_checkpoint_path = os.path.join(save_dir, f"model_epoch_{i+1}.pt")
        # torch.save(model.state_dict(), model_checkpoint_path)

        # Save training checkpoint information
        checkpoint = {
            "epoch": i + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": training_loss[-1],
        }
        checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{i+1}.pt")
        # torch.save(checkpoint, checkpoint_path)

        # Log the saved paths and other details
        logging.info(f"Epoch {i+1} - Model saved to: {model_checkpoint_path}")
        logging.info(f"Epoch {i+1} - Checkpoint saved to: {checkpoint_path}")

logging.shutdown()

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/104 [00:00<?, ?it/s]

Epoch 1: Loss: 0.05128302797675133




  0%|          | 0/104 [00:00<?, ?it/s]

Epoch 2: Loss: 0.0026086282450705767




  0%|          | 0/104 [00:00<?, ?it/s]

Epoch 3: Loss: 0.0010389857925474644




  0%|          | 0/104 [00:00<?, ?it/s]

Epoch 4: Loss: 0.0006129953544586897




  0%|          | 0/104 [00:00<?, ?it/s]

Epoch 5: Loss: 0.00046085394569672644




In [40]:

# freezing feature extractor
for name, param in model.named_parameters():
    if name in ["student_net.fc.weight", "student_net.fc.bias"]:
        param.requires_grad = True



In [41]:
# Define a directory to save your model and checkpoints
save_dir = "Model_checkpoints_for_classifier"

# Make sure the directory exists, create it if it doesn't
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

## Training for the classifier by freezing feature extractor
training_loss = []

logging.basicConfig(
    filename="Classification.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

In [42]:
Epochs  = 20
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [43]:
for i in tqdm(range(Epochs)):
    epoch_loss = 0
    count = 0
    for batch, (image, label) in enumerate(tqdm(train_dataloader)):
        if torch.cuda.is_available():
            image = image.to("cuda:0")

        imgt,imgs, logits = model(image,label)
        logits = torch.squeeze(logits,0)
        if torch.cuda.is_available():
            label = label.to("cuda:0")
        loss = criterion_2(logits, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss

    print(torch.argmax(logits,dim=1))
    training_loss.append(
        (epoch_loss / len(train_dataloader)).cpu().detach().numpy()
    )
    print(f"Epoch {i+1}: Loss: {training_loss[-1]}\n\n")

    # Save the model's state dictionary
    model_checkpoint_path = os.path.join(save_dir, f"model_epoch_{i+1}.pt")
    # torch.save(model.state_dict(), model_checkpoint_path)

    # Save training checkpoint information
    checkpoint = {
        "epoch": i + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": training_loss[-1],
    }
    checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{i+1}.pt")
    # torch.save(checkpoint, checkpoint_path)

    # Log the saved paths and other details
    logging.info(f"Epoch {i+1} - Model saved to: {model_checkpoint_path}")
    logging.info(f"Epoch {i+1} - Checkpoint saved to: {checkpoint_path}")

logging.shutdown()

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/104 [00:00<?, ?it/s]

tensor([0, 5, 6, 3, 5], device='cuda:0')
Epoch 1: Loss: 1.666720986366272




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([5, 3, 6, 1, 1], device='cuda:0')
Epoch 2: Loss: 1.3974285125732422




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([6, 5, 2, 0, 6], device='cuda:0')
Epoch 3: Loss: 1.2905151844024658




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([3, 1, 6, 1, 6], device='cuda:0')
Epoch 4: Loss: 1.159075379371643




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([6, 6, 5, 4, 1], device='cuda:0')
Epoch 5: Loss: 1.1306816339492798




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([3, 6, 5, 3, 6], device='cuda:0')
Epoch 6: Loss: 1.132980465888977




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([0, 5, 1, 3, 5], device='cuda:0')
Epoch 7: Loss: 1.1160430908203125




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([1, 2, 6, 6, 6], device='cuda:0')
Epoch 8: Loss: 1.0364614725112915




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([3, 6, 4, 5, 6], device='cuda:0')
Epoch 9: Loss: 1.0218942165374756




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([6, 6, 4, 3, 6], device='cuda:0')
Epoch 10: Loss: 0.9699264168739319




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([6, 2, 5, 1, 3], device='cuda:0')
Epoch 11: Loss: 1.021360993385315




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([1, 6, 3, 2, 2], device='cuda:0')
Epoch 12: Loss: 0.9324161410331726




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([0, 5, 0, 0, 6], device='cuda:0')
Epoch 13: Loss: 0.9081798791885376




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([1, 4, 5, 4, 6], device='cuda:0')
Epoch 14: Loss: 0.8608699440956116




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([6, 1, 5, 1, 6], device='cuda:0')
Epoch 15: Loss: 0.8248481154441833




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([0, 3, 6, 5, 5], device='cuda:0')
Epoch 16: Loss: 0.774986743927002




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([0, 6, 6, 0, 0], device='cuda:0')
Epoch 17: Loss: 0.783867359161377




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([2, 6, 5, 6, 6], device='cuda:0')
Epoch 18: Loss: 0.8021409511566162




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([5, 6, 6, 3, 4], device='cuda:0')
Epoch 19: Loss: 0.730690598487854




  0%|          | 0/104 [00:00<?, ?it/s]

tensor([6, 1, 6, 3, 5], device='cuda:0')
Epoch 20: Loss: 0.7032714486122131




In [44]:
epoch_flag = False
total_loss = 0
correct = 0
total = 0
i=0
for batch, (image, label) in enumerate(tqdm(validation_dataloader)):     
        # Moving Img_data to GPU
        if torch.cuda.is_available():
            image = image.to("cuda:0")

        imgt,imgs,prediction = model(image,label)
        prediction =torch.squeeze(prediction,0)
        loss = criterion_2(prediction.cpu(), label)
        total += label.size(0)
        predict_labels = torch.argmax(prediction, dim=1)
        predict_labels = torch.tensor(predict_labels,dtype = int)
        predict_eval_list = [a == b for a, b in zip(predict_labels, label)]
        correct += sum(predict_eval_list)
        total_loss += loss.item()
validation_loss = total_loss / len(validation_dataloader)
validation_accuracy = 100 * correct / total
print(f"Validation Loss: {validation_loss:.4f}")
print(f"Validation Accuracy: {validation_accuracy:.2f}%")


  0%|          | 0/94 [00:00<?, ?it/s]

Validation Loss: 0.6445
Validation Accuracy: 74.45%


: 