In [1]:
import torch
import torch.nn as nn
import clip
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import load_from_disk
from model import CustomCLIPClassifier
from utils import CustomDataset


class ContrastiveLoss(nn.Module):
    """Label-aware contrastive loss computed over every pair in a batch.

    Embeddings are L2-normalised, their pairwise cosine similarities are
    divided by ``temperature``, and the loss is::

        (logsumexp(different-label pairs) - logsumexp(same-label pairs)) / batch_size

    Self-pairs (the diagonal of the similarity matrix) are excluded from both
    sets. The value is a log-ratio, so it may legitimately be negative.

    Args:
        temperature: Divisor applied to the cosine similarities before they
            are exponentiated.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """Compute the loss for one batch.

        Args:
            features: 2-D tensor with one embedding per row; rows are
                normalised along ``dim=1`` before similarities are taken.
            labels: Class ids, one per row of ``features`` (expected to be
                1-D so that ``labels.unsqueeze(1) == labels.unsqueeze(1).T``
                broadcasts to a square pair mask).

        Returns:
            Scalar tensor holding the loss. If the batch contains no
            same-label pair (other than self-pairs) or no different-label
            pair, the loss is undefined; a zero that is still attached to the
            autograd graph is returned so that ``backward()`` yields zero
            gradients instead of inf/NaN.
        """
        # Normalize feature vectors so the dot product is cosine similarity
        features = nn.functional.normalize(features, dim=1)

        # Temperature-scaled pairwise similarities
        scaled_similarity = torch.mm(features, features.T) / self.temperature
        labels = labels.unsqueeze(1)

        # True where row i and column j share a label (includes the diagonal)
        same_label = labels == labels.T

        # Diagonal mask used to drop self-similarity from both pair sets
        self_mask = torch.eye(
            scaled_similarity.size(0),
            dtype=torch.bool,
            device=scaled_similarity.device,
        )

        positives = scaled_similarity[same_label & ~self_mask]
        negatives = scaled_similarity[~same_label & ~self_mask]

        # log(0) would otherwise produce +/-inf here and corrupt the gradients.
        if positives.numel() == 0 or negatives.numel() == 0:
            return features.sum() * 0.0

        # logsumexp == log(sum(exp(.))) but does not overflow when
        # similarity / temperature is large.
        positive_loss = -torch.logsumexp(positives, dim=0)
        negative_loss = torch.logsumexp(negatives, dim=0)
        loss = (positive_loss + negative_loss) / features.size(0)
        return loss


# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load returns the model together with its image preprocessing transform.
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Freeze the CLIP weights: no CLIP parameter will receive gradients.
for clip_weight in clip_model.parameters():
    clip_weight.requires_grad_(False)

# Wrap the frozen CLIP model in the project's classifier and place it on the
# same device.
classifier_model = CustomCLIPClassifier(clip_model)
classifier_model = classifier_model.to(device)

In [2]:
# One Adam optimizer per training stage: the contrastive stage updates
# classifier_model.feature_layer, the cross-entropy stage updates
# classifier_model.classifier.
contrastive_params = classifier_model.feature_layer.parameters()
classifier_params = classifier_model.classifier.parameters()
optimizer_contrastive = torch.optim.Adam(contrastive_params, lr=1e-4)
optimizer_classifier = torch.optim.Adam(classifier_params, lr=1e-4)

# Both learning rates are multiplied by 0.5 every 20 scheduler steps.
scheduler_contrastive, scheduler_classifier = (
    torch.optim.lr_scheduler.StepLR(stage_optimizer, step_size=20, gamma=0.5)
    for stage_optimizer in (optimizer_contrastive, optimizer_classifier)
)

# Loss functions for the two stages
contrastive_loss_fn = ContrastiveLoss()
ce_loss_fn = nn.CrossEntropyLoss()

# Datasets are read from disk, wrapped with the CLIP preprocessing transform,
# and served in batches of 64; only the training loader shuffles.
train_dataset = load_from_disk("/root/Representational-Learning/dataset/dataset/train")
val_dataset = load_from_disk("/root/Representational-Learning/dataset/dataset/val")

train_samples = CustomDataset(train_dataset, preprocess)
train_dataloader = DataLoader(train_samples, batch_size=64, shuffle=True)

val_samples = CustomDataset(val_dataset, preprocess)
val_dataloader = DataLoader(val_samples, batch_size=64, shuffle=False)

In [3]:
# Training Step 1: Contrastive Loss
# Only optimizer_contrastive is stepped in this stage.
print("Step 1: Training with Contrastive Loss")
classifier_model.train()

# Single source of truth for the epoch count: it drives both the loop and the
# progress-bar label, so the two cannot drift apart.
contrastive_epochs = 100

for epoch in range(contrastive_epochs):
    total_loss = 0
    for images, labels in tqdm(train_dataloader, desc=f"Contrastive Epoch {epoch + 1}/{contrastive_epochs}"):
        images, labels = images.to(device), labels.to(device)

        # Forward pass for contrastive loss (return_features=True asks the
        # model for features rather than its default output)
        features = classifier_model(images, return_features=True)

        # Compute contrastive loss
        contrastive_loss = contrastive_loss_fn(features, labels)

        # Backward pass
        optimizer_contrastive.zero_grad()
        contrastive_loss.backward()
        optimizer_contrastive.step()

        # .item() detaches the scalar so the running total does not keep the graph alive
        total_loss += contrastive_loss.item()

    # Step the scheduler once per epoch
    scheduler_contrastive.step()

    # Report the mean per-batch loss for this epoch
    print(f"Epoch {epoch + 1}, Contrastive Loss: {total_loss / len(train_dataloader):.4f}")

Step 1: Training with Contrastive Loss


Contrastive Epoch 1/100: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 1, Contrastive Loss: 0.0139


Contrastive Epoch 2/100: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 2, Contrastive Loss: -0.0066


Contrastive Epoch 3/100: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 3, Contrastive Loss: -0.0144


Contrastive Epoch 4/100: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 4, Contrastive Loss: -0.0165


Contrastive Epoch 5/100: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 5, Contrastive Loss: -0.0184


Contrastive Epoch 6/100: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 6, Contrastive Loss: -0.0187


Contrastive Epoch 7/100: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 7, Contrastive Loss: -0.0212


Contrastive Epoch 8/100: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 8, Contrastive Loss: -0.0209


Contrastive Epoch 9/100: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 9, Contrastive Loss: -0.0229


Contrastive Epoch 10/100: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 10, Contrastive Loss: -0.0269


Contrastive Epoch 11/100: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 11, Contrastive Loss: -0.0263


Contrastive Epoch 12/100: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Epoch 12, Contrastive Loss: -0.0284


Contrastive Epoch 13/100: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 13, Contrastive Loss: -0.0266


Contrastive Epoch 14/100: 100%|██████████| 51/51 [01:01<00:00,  1.20s/it]


Epoch 14, Contrastive Loss: -0.0282


Contrastive Epoch 15/100: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 15, Contrastive Loss: -0.0292


Contrastive Epoch 16/100: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 16, Contrastive Loss: -0.0312


Contrastive Epoch 17/100: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Epoch 17, Contrastive Loss: -0.0320


Contrastive Epoch 18/100: 100%|██████████| 51/51 [01:05<00:00,  1.27s/it]


Epoch 18, Contrastive Loss: -0.0303


Contrastive Epoch 19/100: 100%|██████████| 51/51 [01:06<00:00,  1.31s/it]


Epoch 19, Contrastive Loss: -0.0311


Contrastive Epoch 20/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 20, Contrastive Loss: -0.0333


Contrastive Epoch 21/100: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 21, Contrastive Loss: -0.0322


Contrastive Epoch 22/100: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 22, Contrastive Loss: -0.0334


Contrastive Epoch 23/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 23, Contrastive Loss: -0.0326


Contrastive Epoch 24/100: 100%|██████████| 51/51 [01:06<00:00,  1.31s/it]


Epoch 24, Contrastive Loss: -0.0343


Contrastive Epoch 25/100: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Epoch 25, Contrastive Loss: -0.0340


Contrastive Epoch 26/100: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 26, Contrastive Loss: -0.0313


Contrastive Epoch 27/100: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 27, Contrastive Loss: -0.0338


Contrastive Epoch 28/100: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 28, Contrastive Loss: -0.0335


Contrastive Epoch 29/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 29, Contrastive Loss: -0.0354


Contrastive Epoch 30/100: 100%|██████████| 51/51 [01:02<00:00,  1.22s/it]


Epoch 30, Contrastive Loss: -0.0349


Contrastive Epoch 31/100: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 31, Contrastive Loss: -0.0343


Contrastive Epoch 32/100: 100%|██████████| 51/51 [01:01<00:00,  1.20s/it]


Epoch 32, Contrastive Loss: -0.0342


Contrastive Epoch 33/100: 100%|██████████| 51/51 [01:00<00:00,  1.19s/it]


Epoch 33, Contrastive Loss: -0.0347


Contrastive Epoch 34/100: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Epoch 34, Contrastive Loss: -0.0357


Contrastive Epoch 35/100: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 35, Contrastive Loss: -0.0373


Contrastive Epoch 36/100: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 36, Contrastive Loss: -0.0368


Contrastive Epoch 37/100: 100%|██████████| 51/51 [01:01<00:00,  1.20s/it]


Epoch 37, Contrastive Loss: -0.0347


Contrastive Epoch 38/100: 100%|██████████| 51/51 [01:00<00:00,  1.19s/it]


Epoch 38, Contrastive Loss: -0.0360


Contrastive Epoch 39/100: 100%|██████████| 51/51 [01:00<00:00,  1.18s/it]


Epoch 39, Contrastive Loss: -0.0381


Contrastive Epoch 40/100: 100%|██████████| 51/51 [01:02<00:00,  1.22s/it]


Epoch 40, Contrastive Loss: -0.0359


Contrastive Epoch 41/100: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 41, Contrastive Loss: -0.0362


Contrastive Epoch 42/100: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Epoch 42, Contrastive Loss: -0.0376


Contrastive Epoch 43/100: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 43, Contrastive Loss: -0.0367


Contrastive Epoch 44/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 44, Contrastive Loss: -0.0364


Contrastive Epoch 45/100: 100%|██████████| 51/51 [01:00<00:00,  1.18s/it]


Epoch 45, Contrastive Loss: -0.0371


Contrastive Epoch 46/100: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 46, Contrastive Loss: -0.0375


Contrastive Epoch 47/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 47, Contrastive Loss: -0.0386


Contrastive Epoch 48/100: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 48, Contrastive Loss: -0.0384


Contrastive Epoch 49/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 49, Contrastive Loss: -0.0388


Contrastive Epoch 50/100: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 50, Contrastive Loss: -0.0386


Contrastive Epoch 51/100: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Epoch 51, Contrastive Loss: -0.0383


Contrastive Epoch 52/100: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Epoch 52, Contrastive Loss: -0.0374


Contrastive Epoch 53/100: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Epoch 53, Contrastive Loss: -0.0385


Contrastive Epoch 54/100: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Epoch 54, Contrastive Loss: -0.0359


Contrastive Epoch 55/100: 100%|██████████| 51/51 [01:01<00:00,  1.20s/it]


Epoch 55, Contrastive Loss: -0.0385


Contrastive Epoch 56/100: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 56, Contrastive Loss: -0.0389


Contrastive Epoch 57/100: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 57, Contrastive Loss: -0.0378


Contrastive Epoch 58/100: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 58, Contrastive Loss: -0.0386


Contrastive Epoch 59/100: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 59, Contrastive Loss: -0.0392


Contrastive Epoch 60/100: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 60, Contrastive Loss: -0.0374


Contrastive Epoch 61/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 61, Contrastive Loss: -0.0392


Contrastive Epoch 62/100: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 62, Contrastive Loss: -0.0375


Contrastive Epoch 63/100: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 63, Contrastive Loss: -0.0383


Contrastive Epoch 64/100: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Epoch 64, Contrastive Loss: -0.0390


Contrastive Epoch 65/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 65, Contrastive Loss: -0.0388


Contrastive Epoch 66/100: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 66, Contrastive Loss: -0.0393


Contrastive Epoch 67/100: 100%|██████████| 51/51 [01:02<00:00,  1.22s/it]


Epoch 67, Contrastive Loss: -0.0383


Contrastive Epoch 68/100: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 68, Contrastive Loss: -0.0384


Contrastive Epoch 69/100: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 69, Contrastive Loss: -0.0400


Contrastive Epoch 70/100: 100%|██████████| 51/51 [01:01<00:00,  1.20s/it]


Epoch 70, Contrastive Loss: -0.0391


Contrastive Epoch 71/100: 100%|██████████| 51/51 [01:01<00:00,  1.20s/it]


Epoch 71, Contrastive Loss: -0.0396


Contrastive Epoch 72/100: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 72, Contrastive Loss: -0.0389


Contrastive Epoch 73/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 73, Contrastive Loss: -0.0369


Contrastive Epoch 74/100: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 74, Contrastive Loss: -0.0378


Contrastive Epoch 75/100: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 75, Contrastive Loss: -0.0380


Contrastive Epoch 76/100: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Epoch 76, Contrastive Loss: -0.0395


Contrastive Epoch 77/100: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 77, Contrastive Loss: -0.0386


Contrastive Epoch 78/100: 100%|██████████| 51/51 [01:00<00:00,  1.20s/it]


Epoch 78, Contrastive Loss: -0.0398


Contrastive Epoch 79/100: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Epoch 79, Contrastive Loss: -0.0406


Contrastive Epoch 80/100: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 80, Contrastive Loss: -0.0397


Contrastive Epoch 81/100: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 81, Contrastive Loss: -0.0399


Contrastive Epoch 82/100: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 82, Contrastive Loss: -0.0402


Contrastive Epoch 83/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 83, Contrastive Loss: -0.0392


Contrastive Epoch 84/100: 100%|██████████| 51/51 [01:06<00:00,  1.29s/it]


Epoch 84, Contrastive Loss: -0.0382


Contrastive Epoch 85/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 85, Contrastive Loss: -0.0375


Contrastive Epoch 86/100: 100%|██████████| 51/51 [01:02<00:00,  1.22s/it]


Epoch 86, Contrastive Loss: -0.0376


Contrastive Epoch 87/100: 100%|██████████| 51/51 [01:01<00:00,  1.20s/it]


Epoch 87, Contrastive Loss: -0.0398


Contrastive Epoch 88/100: 100%|██████████| 51/51 [01:01<00:00,  1.20s/it]


Epoch 88, Contrastive Loss: -0.0357


Contrastive Epoch 89/100: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 89, Contrastive Loss: -0.0389


Contrastive Epoch 90/100: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Epoch 90, Contrastive Loss: -0.0386


Contrastive Epoch 91/100: 100%|██████████| 51/51 [01:06<00:00,  1.31s/it]


Epoch 91, Contrastive Loss: -0.0406


Contrastive Epoch 92/100: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 92, Contrastive Loss: -0.0389


Contrastive Epoch 93/100: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 93, Contrastive Loss: -0.0393


Contrastive Epoch 94/100: 100%|██████████| 51/51 [01:02<00:00,  1.22s/it]


Epoch 94, Contrastive Loss: -0.0391


Contrastive Epoch 95/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 95, Contrastive Loss: -0.0392


Contrastive Epoch 96/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 96, Contrastive Loss: -0.0382


Contrastive Epoch 97/100: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 97, Contrastive Loss: -0.0379


Contrastive Epoch 98/100: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 98, Contrastive Loss: -0.0379


Contrastive Epoch 99/100: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 99, Contrastive Loss: -0.0394


Contrastive Epoch 100/100: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]

Epoch 100, Contrastive Loss: -0.0409





In [4]:
# Training Step 2: Cross Entropy Loss
# Only optimizer_classifier is stepped in this stage.
print("Step 2: Training with Cross Entropy Loss")

# Single source of truth for the epoch count: it drives both the loop and the
# progress-bar label (the label previously hard-coded a different total).
ce_epochs = 100

for epoch in range(ce_epochs):
    total_loss = 0
    classifier_model.train()

    for images, labels in tqdm(train_dataloader, desc=f"CE Epoch {epoch + 1}/{ce_epochs}"):
        images, labels = images.to(device), labels.to(device)

        # Forward pass for CE loss (default call returns the classification logits)
        logits = classifier_model(images)

        # Compute CE loss
        ce_loss = ce_loss_fn(logits, labels)

        # Backward pass
        optimizer_classifier.zero_grad()
        ce_loss.backward()
        optimizer_classifier.step()

        # .item() detaches the scalar so the running total does not keep the graph alive
        total_loss += ce_loss.item()

    # Step the scheduler once per epoch
    scheduler_classifier.step()

    # Report the mean per-batch loss for this epoch
    print(f"Epoch {epoch + 1}, CE Loss: {total_loss / len(train_dataloader):.4f}")

Step 2: Training with Cross Entropy Loss


CE Epoch 1/50: 100%|██████████| 51/51 [01:00<00:00,  1.19s/it]


Epoch 1, CE Loss: 4.4121


CE Epoch 2/50: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 2, CE Loss: 4.2580


CE Epoch 3/50: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 3, CE Loss: 4.1076


CE Epoch 4/50: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 4, CE Loss: 3.9584


CE Epoch 5/50: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 5, CE Loss: 3.8101


CE Epoch 6/50: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 6, CE Loss: 3.6655


CE Epoch 7/50: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 7, CE Loss: 3.5234


CE Epoch 8/50: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 8, CE Loss: 3.3843


CE Epoch 9/50: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 9, CE Loss: 3.2473


CE Epoch 10/50: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 10, CE Loss: 3.1169


CE Epoch 11/50: 100%|██████████| 51/51 [01:02<00:00,  1.22s/it]


Epoch 11, CE Loss: 2.9911


CE Epoch 12/50: 100%|██████████| 51/51 [01:06<00:00,  1.31s/it]


Epoch 12, CE Loss: 2.8678


CE Epoch 13/50: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 13, CE Loss: 2.7506


CE Epoch 14/50: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 14, CE Loss: 2.6373


CE Epoch 15/50: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 15, CE Loss: 2.5314


CE Epoch 16/50: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 16, CE Loss: 2.4269


CE Epoch 17/50: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Epoch 17, CE Loss: 2.3319


CE Epoch 18/50: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 18, CE Loss: 2.2356


CE Epoch 19/50: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 19, CE Loss: 2.1489


CE Epoch 20/50: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 20, CE Loss: 2.0649


CE Epoch 21/50: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 21, CE Loss: 2.0061


CE Epoch 22/50: 100%|██████████| 51/51 [01:08<00:00,  1.34s/it]


Epoch 22, CE Loss: 1.9661


CE Epoch 23/50: 100%|██████████| 51/51 [01:07<00:00,  1.33s/it]


Epoch 23, CE Loss: 1.9297


CE Epoch 24/50: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 24, CE Loss: 1.8910


CE Epoch 25/50: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 25, CE Loss: 1.8557


CE Epoch 26/50: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 26, CE Loss: 1.8216


CE Epoch 27/50: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Epoch 27, CE Loss: 1.7855


CE Epoch 28/50: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 28, CE Loss: 1.7525


CE Epoch 29/50: 100%|██████████| 51/51 [01:02<00:00,  1.22s/it]


Epoch 29, CE Loss: 1.7214


CE Epoch 30/50: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 30, CE Loss: 1.6872


CE Epoch 31/50: 100%|██████████| 51/51 [01:07<00:00,  1.32s/it]


Epoch 31, CE Loss: 1.6548


CE Epoch 32/50: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 32, CE Loss: 1.6260


CE Epoch 33/50: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 33, CE Loss: 1.5969


CE Epoch 34/50: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 34, CE Loss: 1.5679


CE Epoch 35/50: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Epoch 35, CE Loss: 1.5370


CE Epoch 36/50: 100%|██████████| 51/51 [01:00<00:00,  1.19s/it]


Epoch 36, CE Loss: 1.5101


CE Epoch 37/50: 100%|██████████| 51/51 [01:02<00:00,  1.22s/it]


Epoch 37, CE Loss: 1.4821


CE Epoch 38/50: 100%|██████████| 51/51 [01:01<00:00,  1.20s/it]


Epoch 38, CE Loss: 1.4579


CE Epoch 39/50: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 39, CE Loss: 1.4292


CE Epoch 40/50: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 40, CE Loss: 1.4062


CE Epoch 41/50: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 41, CE Loss: 1.3885


CE Epoch 42/50: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 42, CE Loss: 1.3719


CE Epoch 43/50: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Epoch 43, CE Loss: 1.3638


CE Epoch 44/50: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Epoch 44, CE Loss: 1.3513


CE Epoch 45/50: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 45, CE Loss: 1.3361


CE Epoch 46/50: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 46, CE Loss: 1.3262


CE Epoch 47/50: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 47, CE Loss: 1.3133


CE Epoch 48/50: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 48, CE Loss: 1.3035


CE Epoch 49/50: 100%|██████████| 51/51 [01:02<00:00,  1.22s/it]


Epoch 49, CE Loss: 1.2898


CE Epoch 50/50: 100%|██████████| 51/51 [01:01<00:00,  1.22s/it]


Epoch 50, CE Loss: 1.2788


CE Epoch 51/50: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 51, CE Loss: 1.2681


CE Epoch 52/50: 100%|██████████| 51/51 [00:59<00:00,  1.17s/it]


Epoch 52, CE Loss: 1.2571


CE Epoch 53/50: 100%|██████████| 51/51 [00:59<00:00,  1.17s/it]


Epoch 53, CE Loss: 1.2450


CE Epoch 54/50: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 54, CE Loss: 1.2346


CE Epoch 55/50: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 55, CE Loss: 1.2227


CE Epoch 56/50: 100%|██████████| 51/51 [01:01<00:00,  1.20s/it]


Epoch 56, CE Loss: 1.2112


CE Epoch 57/50: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 57, CE Loss: 1.2014


CE Epoch 58/50: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 58, CE Loss: 1.1905


CE Epoch 59/50: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 59, CE Loss: 1.1807


CE Epoch 60/50: 100%|██████████| 51/51 [01:06<00:00,  1.30s/it]


Epoch 60, CE Loss: 1.1702


CE Epoch 61/50: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Epoch 61, CE Loss: 1.1623


CE Epoch 62/50: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 62, CE Loss: 1.1576


CE Epoch 63/50: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 63, CE Loss: 1.1537


CE Epoch 64/50: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 64, CE Loss: 1.1481


CE Epoch 65/50: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 65, CE Loss: 1.1419


CE Epoch 66/50: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 66, CE Loss: 1.1355


CE Epoch 67/50: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 67, CE Loss: 1.1315


CE Epoch 68/50: 100%|██████████| 51/51 [01:06<00:00,  1.31s/it]


Epoch 68, CE Loss: 1.1280


CE Epoch 69/50: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Epoch 69, CE Loss: 1.1201


CE Epoch 70/50: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 70, CE Loss: 1.1154


CE Epoch 71/50: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 71, CE Loss: 1.1105


CE Epoch 72/50: 100%|██████████| 51/51 [01:07<00:00,  1.32s/it]


Epoch 72, CE Loss: 1.1060


CE Epoch 73/50: 100%|██████████| 51/51 [01:02<00:00,  1.22s/it]


Epoch 73, CE Loss: 1.1001


CE Epoch 74/50: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 74, CE Loss: 1.0948


CE Epoch 75/50: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Epoch 75, CE Loss: 1.0916


CE Epoch 76/50: 100%|██████████| 51/51 [01:00<00:00,  1.19s/it]


Epoch 76, CE Loss: 1.0867


CE Epoch 77/50: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 77, CE Loss: 1.0800


CE Epoch 78/50: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 78, CE Loss: 1.0753


CE Epoch 79/50: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 79, CE Loss: 1.0714


CE Epoch 80/50: 100%|██████████| 51/51 [01:01<00:00,  1.20s/it]


Epoch 80, CE Loss: 1.0667


CE Epoch 81/50: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 81, CE Loss: 1.0642


CE Epoch 82/50: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 82, CE Loss: 1.0622


CE Epoch 83/50: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 83, CE Loss: 1.0563


CE Epoch 84/50: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 84, CE Loss: 1.0538


CE Epoch 85/50: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 85, CE Loss: 1.0534


CE Epoch 86/50: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 86, CE Loss: 1.0506


CE Epoch 87/50: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Epoch 87, CE Loss: 1.0487


CE Epoch 88/50: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Epoch 88, CE Loss: 1.0467


CE Epoch 89/50: 100%|██████████| 51/51 [01:03<00:00,  1.24s/it]


Epoch 89, CE Loss: 1.0428


CE Epoch 90/50: 100%|██████████| 51/51 [01:03<00:00,  1.25s/it]


Epoch 90, CE Loss: 1.0407


CE Epoch 91/50: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Epoch 91, CE Loss: 1.0381


CE Epoch 92/50: 100%|██████████| 51/51 [01:04<00:00,  1.26s/it]


Epoch 92, CE Loss: 1.0378


CE Epoch 93/50: 100%|██████████| 51/51 [01:02<00:00,  1.23s/it]


Epoch 93, CE Loss: 1.0347


CE Epoch 94/50: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 94, CE Loss: 1.0321


CE Epoch 95/50: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it]


Epoch 95, CE Loss: 1.0294


CE Epoch 96/50: 100%|██████████| 51/51 [01:04<00:00,  1.27s/it]


Epoch 96, CE Loss: 1.0270


CE Epoch 97/50: 100%|██████████| 51/51 [01:01<00:00,  1.21s/it]


Epoch 97, CE Loss: 1.0238


CE Epoch 98/50: 100%|██████████| 51/51 [01:02<00:00,  1.22s/it]


Epoch 98, CE Loss: 1.0213


CE Epoch 99/50: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]


Epoch 99, CE Loss: 1.0175


CE Epoch 100/50: 100%|██████████| 51/51 [01:05<00:00,  1.29s/it]

Epoch 100, CE Loss: 1.0164





In [5]:
import os

# Persist the classifier's weights (state_dict only, not the full module).
model_save_path = "/saved_model/model_last.pth"

# Create the target directory first; exist_ok makes re-running the cell safe.
checkpoint_dir = os.path.dirname(model_save_path)
os.makedirs(checkpoint_dir, exist_ok=True)

trained_weights = classifier_model.state_dict()
torch.save(trained_weights, model_save_path)
print(f"Model saved to {model_save_path}")

Model saved to /saved_model/model_last.pth
