#SimCLR Pretraining 10 000

In [None]:
import torch.nn as nn
from torchvision.models import mobilenet_v2

# Load MobileNetV2 without pretrained weights and remove classifier
base_model = mobilenet_v2(weights=None)

# Remove the classification head and keep only the feature extractor
# The feature output size is 1280 for MobileNetV2
backbone = nn.Sequential(
    base_model.features,
    nn.AdaptiveAvgPool2d(1),  # Ensure consistent output shape
    nn.Flatten(),             # Shape: [B, 1280]
)


In [None]:
class SimCLRModel(nn.Module):
    def __init__(self, backbone, feature_dim=128):
        super().__init__()
        self.backbone = backbone
        self.projection_head = nn.Sequential(
            nn.Linear(1280, 256),
            nn.ReLU(),
            nn.Linear(256, feature_dim)
        )
        #simCLR requires a projection head after the encoder (backbone)

    def forward(self, x):
        features = self.backbone(x)         # Shape: [B, 1280]
        projections = self.projection_head(features)  # Shape: [B, feature_dim]
        return projections


In [None]:
!pip install lightly


Collecting lightly
  Downloading lightly-1.5.22-py3-none-any.whl.metadata (38 kB)
Collecting hydra-core>=1.0.0 (from lightly)
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting lightly_utils~=0.0.0 (from lightly)
  Downloading lightly_utils-0.0.2-py3-none-any.whl.metadata (1.4 kB)
Collecting pytorch_lightning>=1.0.4 (from lightly)
  Downloading pytorch_lightning-2.5.6-py3-none-any.whl.metadata (20 kB)
Collecting aenum>=3.1.11 (from lightly)
  Downloading aenum-3.1.16-py3-none-any.whl.metadata (3.8 kB)
Collecting torchmetrics>0.7.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading lightly-1.5.22-py3-none-any.whl (859 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m859.3/859.3 kB[0m [31m38.8 MB/s[0m eta [36m0:00:00[0m

In [None]:
from lightly.data import LightlyDataset
from lightly.transforms import SimCLRTransform
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

In [None]:
# SimCLR Transform from lightly
simclr_transform = SimCLRTransform(input_size=224)  # you can change input_size as needed

# Custom dataset to return 2 views
class SimCLRDataset(ImageFolder):
    def __getitem__(self, index):
        sample, _ = super().__getitem__(index)
        xi, xj = simclr_transform(sample)
        return xi, xj


In [None]:
data_path = '/content/drive/MyDrive/pets0/unlabeled_train'

# Create dataset and dataloader
dataset = SimCLRDataset(root=data_path)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)
num_imgs = len(dataset)
print(f"Number of images in the dataset: {num_imgs}")

Number of images in the dataset: 10000


In [None]:
from lightly.loss import NTXentLoss
import torch

# Initialize model
model = SimCLRModel(backbone)
model = model.cuda()

# Loss
criterion = NTXentLoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Training loop
for epoch in range(20):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch}: Loss = {loss.item():.4f}")


Epoch 0: Loss = 3.1890
Epoch 1: Loss = 3.3769
Epoch 2: Loss = 2.8553
Epoch 3: Loss = 3.3284
Epoch 4: Loss = 2.7613
Epoch 5: Loss = 2.7699
Epoch 6: Loss = 2.6084
Epoch 7: Loss = 2.7445
Epoch 8: Loss = 2.5992
Epoch 9: Loss = 2.7181
Epoch 10: Loss = 2.5164
Epoch 11: Loss = 2.6766
Epoch 12: Loss = 2.5250
Epoch 13: Loss = 2.5185
Epoch 14: Loss = 2.6390
Epoch 15: Loss = 2.5899
Epoch 16: Loss = 2.6170
Epoch 17: Loss = 2.4712
Epoch 18: Loss = 2.3676
Epoch 19: Loss = 2.8177


In [None]:
torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_10000.pth')

In [None]:
# Training loop
for epoch in range(5):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    print(f"Epoch {epoch+20}: Loss = {loss.item():.4f}")

Epoch 20: Loss = 2.4320
Epoch 21: Loss = 2.4563
Epoch 22: Loss = 2.4456
Epoch 23: Loss = 2.5486
Epoch 24: Loss = 2.5704


In [None]:
# Training loop
mloss = 2.5
for epoch in range(5):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if mloss > loss.item():
          mloss = loss.item()
          torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_10000.pth')

    print(f"Epoch {epoch+25}: Loss = {loss.item():.4f}")

Epoch 25: Loss = 2.5670
Epoch 26: Loss = 2.2463
Epoch 27: Loss = 2.3873
Epoch 28: Loss = 2.2702
Epoch 29: Loss = 2.4605


In [None]:
torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_10000.pth')

In [None]:
# Training loop
for epoch in range(10):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    if mloss > loss.item():
          mloss = loss.item()
          torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_10000.pth')

    print(f"Epoch {epoch+30}: Loss = {loss.item():.4f}")

Epoch 30: Loss = 2.3099
Epoch 31: Loss = 2.4192
Epoch 32: Loss = 2.5080
Epoch 33: Loss = 2.3715
Epoch 34: Loss = 2.1255
Epoch 35: Loss = 2.3541
Epoch 36: Loss = 2.3326
Epoch 37: Loss = 2.1406
Epoch 38: Loss = 2.3493
Epoch 39: Loss = 2.1011


In [None]:
# Training loop
for epoch in range(10):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    if mloss > loss.item():
          mloss = loss.item()
          torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_10000.pth')

    print(f"Epoch {epoch+40}: Loss = {loss.item():.4f}")

Epoch 40: Loss = 2.3627
Epoch 41: Loss = 2.2122
Epoch 42: Loss = 2.0049
Epoch 43: Loss = 2.2572
Epoch 44: Loss = 2.3366
Epoch 45: Loss = 2.3494
Epoch 46: Loss = 2.1856
Epoch 47: Loss = 2.1016
Epoch 48: Loss = 2.0798
Epoch 49: Loss = 2.3723


####Continue training (hit GPU limit on colab)

In [None]:
!pip install lightly


Collecting lightly
  Downloading lightly-1.5.22-py3-none-any.whl.metadata (38 kB)
Collecting hydra-core>=1.0.0 (from lightly)
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting lightly_utils~=0.0.0 (from lightly)
  Downloading lightly_utils-0.0.2-py3-none-any.whl.metadata (1.4 kB)
Collecting pytorch_lightning>=1.0.4 (from lightly)
  Downloading pytorch_lightning-2.5.6-py3-none-any.whl.metadata (20 kB)
Collecting aenum>=3.1.11 (from lightly)
  Downloading aenum-3.1.16-py3-none-any.whl.metadata (3.8 kB)
Collecting torchmetrics>0.7.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading lightly-1.5.22-py3-none-any.whl (859 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m859.3/859.3 kB[0m [31m50.1 MB/s[0m eta [36m0:00:00[0m

In [None]:
import torch.nn as nn
from torchvision.models import mobilenet_v2

# Load MobileNetV2 without pretrained weights and remove classifier
base_model = mobilenet_v2(weights=None)

# Remove the classification head and keep only the feature extractor
# The feature output size is 1280 for MobileNetV2
backbone = nn.Sequential(
    base_model.features,
    nn.AdaptiveAvgPool2d(1),  # Ensure consistent output shape
    nn.Flatten(),             # Shape: [B, 1280]
)


In [None]:
class SimCLRModel(nn.Module):
    def __init__(self, backbone, feature_dim=128):
        super().__init__()
        self.backbone = backbone
        self.projection_head = nn.Sequential(
            nn.Linear(1280, 256),
            nn.ReLU(),
            nn.Linear(256, feature_dim)
        )

    def forward(self, x):
        features = self.backbone(x)         # Shape: [B, 1280]
        projections = self.projection_head(features)  # Shape: [B, feature_dim]
        return projections


In [None]:
from lightly.data import LightlyDataset
from lightly.transforms import SimCLRTransform
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

In [None]:
import torch
# Define the path to your saved pretrained student model checkpoint
pretrained_student_path = '/content/drive/MyDrive/mods/mobilenet_sim_10000.pth'

# Load the state dictionary from the saved checkpoint
# Using map_location='cpu' to load onto CPU first is safer, then move to device
student_state_dict = torch.load(pretrained_student_path, map_location='cpu')

# Create the student model architecture
# Assuming the student model is SimCLRModel(backbone) as defined previously
# If your model architecture is different, adjust this part accordingly
model = SimCLRModel(backbone)

# Load the state dictionary into the model
# Use strict=True if the saved state_dict exactly matches the current model structure
# Use strict=False if there are missing or extra keys (e.g., if you saved the entire model instead of just the state_dict)
model.load_state_dict(student_state_dict, strict=True)

# Move the model to the appropriate device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print("Pretrained student model loaded successfully.")

Pretrained student model loaded successfully.


In [None]:
data_path = '/content/drive/MyDrive/pets0/unlabeled_train'
# SimCLR Transform from lightly
simclr_transform = SimCLRTransform(input_size=224)  # you can change input_size as needed

# Custom dataset to return 2 views
class SimCLRDataset(ImageFolder):
    def __getitem__(self, index):
        sample, _ = super().__getitem__(index)
        xi, xj = simclr_transform(sample)
        return xi, xj

# Create dataset and dataloader
dataset = SimCLRDataset(root=data_path)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)
num_imgs = len(dataset)
print(f"Number of images in the dataset: {num_imgs}")

Number of images in the dataset: 10000


In [None]:
from lightly.loss import NTXentLoss
import torch

# Loss
criterion = NTXentLoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

mloss = 2.0
# Training loop
for epoch in range(10):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if mloss > loss.item():
          mloss = loss.item()
          torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_10000_cont.pth')

    if (epoch+1) % 5 == 0:
      torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_10000_epoch.pth')

    print(f"Epoch {epoch+43}: Loss = {loss.item():.4f}")


Epoch 43: Loss = 2.2164
Epoch 44: Loss = 2.3832
Epoch 45: Loss = 2.2882
Epoch 46: Loss = 2.1571
Epoch 47: Loss = 2.3185
Epoch 48: Loss = 2.1638
Epoch 49: Loss = 2.0395
Epoch 50: Loss = 2.0959
Epoch 51: Loss = 1.9750
Epoch 52: Loss = 2.1218


In [None]:
torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_10000_epoch.pth')

In [None]:
for epoch in range(8):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if mloss > loss.item():
          mloss = loss.item()
          torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_10000_cont.pth')

    if (epoch+1) % 5 == 0:
      torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_10000_epoch.pth')

    print(f"Epoch {epoch+53}: Loss = {loss.item():.4f}")

Epoch 53: Loss = 1.9999
Epoch 54: Loss = 2.2887
Epoch 55: Loss = 2.1144
Epoch 56: Loss = 2.1746
Epoch 57: Loss = 2.1592
Epoch 58: Loss = 2.0773
Epoch 59: Loss = 2.4056


####Continue training 2

In [None]:
import torch
# Define the path to your saved pretrained student model checkpoint
pretrained_student_path = '/content/drive/MyDrive/mods/mobilenet_sim_10000_epoch.pth'

# Load the state dictionary from the saved checkpoint
# Using map_location='cpu' to load onto CPU first is safer, then move to device
student_state_dict = torch.load(pretrained_student_path, map_location='cpu')

# Create the student model architecture
# Assuming the student model is SimCLRModel(backbone) as defined previously
# If your model architecture is different, adjust this part accordingly
model = SimCLRModel(backbone)

# Load the state dictionary into the model
# Use strict=True if the saved state_dict exactly matches the current model structure
# Use strict=False if there are missing or extra keys (e.g., if you saved the entire model instead of just the state_dict)
model.load_state_dict(student_state_dict, strict=True)

# Move the model to the appropriate device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print("Pretrained student model loaded successfully.")

Pretrained student model loaded successfully.


In [None]:
from lightly.loss import NTXentLoss
import torch

# Loss
criterion = NTXentLoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

mloss = 1.97
# Training loop
for epoch in range(10):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if mloss > loss.item():
          mloss = loss.item()
          torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_10000_cont2.pth')

    if (epoch+1) % 2 == 0:
      torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_10000_epoch2.pth')

    print(f"Epoch {epoch+58}: Loss = {loss.item():.4f}")


Epoch 58: Loss = 2.2098
Epoch 59: Loss = 2.3115
Epoch 60: Loss = 2.3314
Epoch 61: Loss = 2.3468
Epoch 62: Loss = 2.0173
Epoch 63: Loss = 2.1298
Epoch 64: Loss = 2.2794
Epoch 65: Loss = 2.2211
Epoch 66: Loss = 2.2764
Epoch 67: Loss = 2.2217


In [None]:
torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_10000_epoch2.pth')

# Finetune the pretrained student model 10 000

In [None]:
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms, datasets

In [None]:
# Define transforms for finetuning

finetune_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(192),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Create datasets and dataloaders for finetuning
finetune_dataset_labeled = ImageFolder('/content/drive/MyDrive/pets0/finetune_train', transform=finetune_transform)
val_dataset_labeled = ImageFolder('/content/drive/MyDrive/pets0/val', transform=val_transform)

finetune_loader_labeled = DataLoader(finetune_dataset_labeled, batch_size=64, shuffle=True, num_workers=2)
val_loader_labeled = DataLoader(val_dataset_labeled, batch_size=64, shuffle=False, num_workers=2)

print(f"Number of samples in finetune dataset: {len(finetune_dataset_labeled)}")
print(f"Number of samples in validation dataset: {len(val_dataset_labeled)}")

Number of samples in finetune dataset: 420
Number of samples in validation dataset: 180


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pretrained SimCLR model state dict
pretrained_path = '/content/drive/MyDrive/mods/mobilenet_sim_10000.pth'
simclr_state_dict = torch.load(pretrained_path, map_location=device)

# Create a standard MobileNetV2 model
student_finetune = mobilenet_v2(weights=None)

# Filter the state dict to keep only the backbone weights
backbone_state_dict = {}
for k, v in simclr_state_dict.items():
    # Keys in the saved SimCLRModel state dict for the backbone start with 'backbone.0.'
    # We want to load these into the 'features.' of the standard MobileNetV2
    if k.startswith('backbone.0.'):
        backbone_state_dict[k.replace('backbone.0.', 'features.')] = v
    # Also handle the case if the state dict keys were just 'backbone.' without the '0.'
    elif k.startswith('backbone.'):
        backbone_state_dict[k.replace('backbone.', 'features.')] = v


# Load the backbone weights into the standard MobileNetV2 model
# Use strict=False because we are not loading the classifier weights
student_finetune.load_state_dict(backbone_state_dict, strict=False)

# Replace the classifier
num_ftrs = student_finetune.classifier[1].in_features
student_finetune.classifier[1] = nn.Linear(num_ftrs, 2) # 2 classes: cat, dog

student_finetune = student_finetune.to(device)

# Optionally freeze the backbone and only train the classifier for faster initial training
# for param in student_finetune.features.parameters():
#     param.requires_grad = False

print("Pretrained MobileNetV2 loaded and classifier replaced.")

Pretrained MobileNetV2 loaded and classifier replaced.


In [None]:
# Define optimizer and loss function for finetuning
optimizer_finetune = torch.optim.Adam(student_finetune.parameters(), lr=1e-4) # Start with a lower learning rate
criterion_finetune = nn.CrossEntropyLoss()

# Training loop
for epoch in range(10):
    student_finetune.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in finetune_loader_labeled:
        images, labels = images.to(device), labels.to(device)

        optimizer_finetune.zero_grad()
        outputs = student_finetune(images)
        loss = criterion_finetune(outputs, labels)
        loss.backward()
        optimizer_finetune.step()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total

    # Validation
    student_finetune.eval()
    val_running_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader_labeled:
            images, labels = images.to(device), labels.to(device)
            outputs = student_finetune(images)
            loss = criterion_finetune(outputs, labels)

            val_running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_loss = val_running_loss / val_total
    val_acc = val_correct / val_total

    print(f'Epoch {epoch+1}: '
          f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} | '
          f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')


Epoch 1: Train Loss: 0.6793 Acc: 0.5762 | Val Loss: 0.6300 Acc: 0.7444
Epoch 2: Train Loss: 0.6182 Acc: 0.6952 | Val Loss: 0.5751 Acc: 0.7611
Epoch 3: Train Loss: 0.5733 Acc: 0.7238 | Val Loss: 0.5383 Acc: 0.7722
Epoch 4: Train Loss: 0.5214 Acc: 0.7690 | Val Loss: 0.5085 Acc: 0.8000
Epoch 5: Train Loss: 0.4980 Acc: 0.7762 | Val Loss: 0.4862 Acc: 0.7778
Epoch 6: Train Loss: 0.4440 Acc: 0.8071 | Val Loss: 0.4735 Acc: 0.7778
Epoch 7: Train Loss: 0.4241 Acc: 0.8119 | Val Loss: 0.4632 Acc: 0.7778
Epoch 8: Train Loss: 0.4033 Acc: 0.8119 | Val Loss: 0.4855 Acc: 0.7500
Epoch 9: Train Loss: 0.3614 Acc: 0.8405 | Val Loss: 0.4933 Acc: 0.7500
Epoch 10: Train Loss: 0.3428 Acc: 0.8571 | Val Loss: 0.4808 Acc: 0.7444


In [None]:
# Training loop
for epoch in range(6):
    student_finetune.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in finetune_loader_labeled:
        images, labels = images.to(device), labels.to(device)

        optimizer_finetune.zero_grad()
        outputs = student_finetune(images)
        loss = criterion_finetune(outputs, labels)
        loss.backward()
        optimizer_finetune.step()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total

    # Validation
    student_finetune.eval()
    val_running_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader_labeled:
            images, labels = images.to(device), labels.to(device)
            outputs = student_finetune(images)
            loss = criterion_finetune(outputs, labels)

            val_running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_loss = val_running_loss / val_total
    val_acc = val_correct / val_total

    print(f'Epoch {epoch+11}: '
          f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} | '
          f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

Epoch 11: Train Loss: 0.3278 Acc: 0.8595 | Val Loss: 0.4756 Acc: 0.7333
Epoch 12: Train Loss: 0.3610 Acc: 0.8476 | Val Loss: 0.4870 Acc: 0.7500
Epoch 13: Train Loss: 0.2893 Acc: 0.8810 | Val Loss: 0.4996 Acc: 0.7667
Epoch 14: Train Loss: 0.2752 Acc: 0.8810 | Val Loss: 0.5087 Acc: 0.7556
Epoch 15: Train Loss: 0.2533 Acc: 0.9000 | Val Loss: 0.5000 Acc: 0.7556
Epoch 16: Train Loss: 0.2369 Acc: 0.9190 | Val Loss: 0.4840 Acc: 0.8000


In [None]:
torch.save(student_finetune.state_dict(), '/content/drive/MyDrive/mods/student_finetuned_10000.pth')

In [None]:
# Training loop
for epoch in range(5):
    student_finetune.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in finetune_loader_labeled:
        images, labels = images.to(device), labels.to(device)

        optimizer_finetune.zero_grad()
        outputs = student_finetune(images)
        loss = criterion_finetune(outputs, labels)
        loss.backward()
        optimizer_finetune.step()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total

    # Validation
    student_finetune.eval()
    val_running_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader_labeled:
            images, labels = images.to(device), labels.to(device)
            outputs = student_finetune(images)
            loss = criterion_finetune(outputs, labels)

            val_running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_loss = val_running_loss / val_total
    val_acc = val_correct / val_total

    print(f'Epoch {epoch+17}: '
          f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} | '
          f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

Epoch 17: Train Loss: 0.2108 Acc: 0.9286 | Val Loss: 0.4865 Acc: 0.7889
Epoch 18: Train Loss: 0.1957 Acc: 0.9286 | Val Loss: 0.5109 Acc: 0.7778
Epoch 19: Train Loss: 0.1863 Acc: 0.9452 | Val Loss: 0.5302 Acc: 0.7167
Epoch 20: Train Loss: 0.1732 Acc: 0.9405 | Val Loss: 0.5250 Acc: 0.7556
Epoch 21: Train Loss: 0.1626 Acc: 0.9476 | Val Loss: 0.5326 Acc: 0.7722


In [None]:
# Training loop
for epoch in range(5):
    student_finetune.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in finetune_loader_labeled:
        images, labels = images.to(device), labels.to(device)

        optimizer_finetune.zero_grad()
        outputs = student_finetune(images)
        loss = criterion_finetune(outputs, labels)
        loss.backward()
        optimizer_finetune.step()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total

    # Validation
    student_finetune.eval()
    val_running_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader_labeled:
            images, labels = images.to(device), labels.to(device)
            outputs = student_finetune(images)
            loss = criterion_finetune(outputs, labels)

            val_running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_loss = val_running_loss / val_total
    val_acc = val_correct / val_total

    print(f'Epoch {epoch+22}: '
          f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} | '
          f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

Epoch 22: Train Loss: 0.1385 Acc: 0.9595 | Val Loss: 0.5486 Acc: 0.7667
Epoch 23: Train Loss: 0.1165 Acc: 0.9714 | Val Loss: 0.5918 Acc: 0.7222
Epoch 24: Train Loss: 0.1133 Acc: 0.9667 | Val Loss: 0.5718 Acc: 0.7500
Epoch 25: Train Loss: 0.0965 Acc: 0.9643 | Val Loss: 0.6047 Acc: 0.7500
Epoch 26: Train Loss: 0.0917 Acc: 0.9738 | Val Loss: 0.6562 Acc: 0.7556


In [None]:
torch.save(student_finetune.state_dict(), '/content/drive/MyDrive/mods/student_over_finetuned_10000.pth')

In [None]:
# Training loop
for epoch in range(5):
    student_finetune.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in finetune_loader_labeled:
        images, labels = images.to(device), labels.to(device)

        optimizer_finetune.zero_grad()
        outputs = student_finetune(images)
        loss = criterion_finetune(outputs, labels)
        loss.backward()
        optimizer_finetune.step()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total

    # Validation
    student_finetune.eval()
    val_running_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader_labeled:
            images, labels = images.to(device), labels.to(device)
            outputs = student_finetune(images)
            loss = criterion_finetune(outputs, labels)

            val_running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_loss = val_running_loss / val_total
    val_acc = val_correct / val_total

    print(f'Epoch {epoch+27}: '
          f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} | '
          f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

Epoch 27: Train Loss: 0.1007 Acc: 0.9738 | Val Loss: 0.6235 Acc: 0.7556
Epoch 28: Train Loss: 0.0767 Acc: 0.9857 | Val Loss: 0.6155 Acc: 0.7611
Epoch 29: Train Loss: 0.0858 Acc: 0.9762 | Val Loss: 0.6313 Acc: 0.7778
Epoch 30: Train Loss: 0.0622 Acc: 0.9857 | Val Loss: 0.6276 Acc: 0.7833
Epoch 31: Train Loss: 0.0615 Acc: 0.9905 | Val Loss: 0.6389 Acc: 0.7667


In [None]:
torch.save(student_finetune.state_dict(), '/content/drive/MyDrive/mods/student_more_over_finetuned_10000.pth')

#SimCLR Pretraining 3000

In [None]:
!pip install lightly


Collecting lightly
  Downloading lightly-1.5.22-py3-none-any.whl.metadata (38 kB)
Collecting hydra-core>=1.0.0 (from lightly)
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting lightly_utils~=0.0.0 (from lightly)
  Downloading lightly_utils-0.0.2-py3-none-any.whl.metadata (1.4 kB)
Collecting pytorch_lightning>=1.0.4 (from lightly)
  Downloading pytorch_lightning-2.5.6-py3-none-any.whl.metadata (20 kB)
Collecting aenum>=3.1.11 (from lightly)
  Downloading aenum-3.1.16-py3-none-any.whl.metadata (3.8 kB)
Collecting torchmetrics>0.7.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading lightly-1.5.22-py3-none-any.whl (859 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m859.3/859.3 kB[0m [31m30.0 MB/s[0m eta [36m0:00:00[0m

In [None]:
import torch.nn as nn
from torchvision.models import mobilenet_v2

# Load MobileNetV2 without pretrained weights and remove classifier
base_model = mobilenet_v2(weights=None)

# Remove the classification head and keep only the feature extractor
# The feature output size is 1280 for MobileNetV2
backbone = nn.Sequential(
    base_model.features,
    nn.AdaptiveAvgPool2d(1),  # Ensure consistent output shape
    nn.Flatten(),             # Shape: [B, 1280]
)


In [None]:
class SimCLRModel(nn.Module):
    def __init__(self, backbone, feature_dim=128):
        super().__init__()
        self.backbone = backbone
        self.projection_head = nn.Sequential(
            nn.Linear(1280, 256),
            nn.ReLU(),
            nn.Linear(256, feature_dim)
        )

    def forward(self, x):
        features = self.backbone(x)         # Shape: [B, 1280]
        projections = self.projection_head(features)  # Shape: [B, feature_dim]
        return projections


In [None]:
from lightly.data import LightlyDataset
from lightly.transforms import SimCLRTransform
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

In [None]:
data_path = '/content/drive/MyDrive/pets0/train3000'
# SimCLR Transform from lightly
simclr_transform = SimCLRTransform(input_size=224)  # you can change input_size as needed

# Custom dataset to return 2 views
class SimCLRDataset(ImageFolder):
    def __getitem__(self, index):
        sample, _ = super().__getitem__(index)
        xi, xj = simclr_transform(sample)
        return xi, xj

# Create dataset and dataloader
dataset = SimCLRDataset(root=data_path)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)
num_imgs = len(dataset)
print(f"Number of images in the dataset: {num_imgs}")

Number of images in the dataset: 3000


In [None]:
from lightly.loss import NTXentLoss
import torch

# Initialize model
model = SimCLRModel(backbone)
model = model.cuda()

# Loss
criterion = NTXentLoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

mloss = 10.0
# Training loop
for epoch in range(10):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if loss < mloss:
        mloss = loss.item()
        torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_3000.pth')
    if (epoch+1)%2==0:
        torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_3000_epoch.pth')

    print(f"Epoch {epoch}: Loss = {loss.item():.4f}")


Epoch 0: Loss = 4.6568
Epoch 1: Loss = 4.5776
Epoch 2: Loss = 4.5338
Epoch 3: Loss = 4.4639
Epoch 4: Loss = 4.3600
Epoch 5: Loss = 4.4496
Epoch 6: Loss = 4.4015
Epoch 7: Loss = 4.2809
Epoch 8: Loss = 4.5226
Epoch 9: Loss = 4.2381


In [None]:
for epoch in range(10):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if loss < mloss:
        mloss = loss.item()
        torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_3000.pth')
    if (epoch+1)%2==0:
        torch.save(model.state_dict(),  '/content/drive/MyDrive/mods/mobilenet_sim_3000_epoch.pth')

    print(f"Epoch {epoch+10}: Loss = {loss.item():.4f}")

Epoch 10: Loss = 4.2056
Epoch 11: Loss = 4.1428
Epoch 12: Loss = 4.2731
Epoch 13: Loss = 4.2184
Epoch 14: Loss = 4.1554
Epoch 15: Loss = 4.1283
Epoch 16: Loss = 3.8937
Epoch 17: Loss = 4.2686
Epoch 18: Loss = 4.0384
Epoch 19: Loss = 3.9816


In [None]:
for epoch in range(10):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if loss < mloss:
        mloss = loss.item()
        torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_3000.pth')
    if (epoch+1)%2==0:
        torch.save(model.state_dict(),  '/content/drive/MyDrive/mods/mobilenet_sim_3000_epoch.pth')

    print(f"Epoch {epoch+20}: Loss = {loss.item():.4f}")

Epoch 20: Loss = 3.9605
Epoch 21: Loss = 3.8694
Epoch 22: Loss = 3.9110
Epoch 23: Loss = 3.8154
Epoch 24: Loss = 4.1257
Epoch 25: Loss = 3.9339
Epoch 26: Loss = 3.8907
Epoch 27: Loss = 3.8074
Epoch 28: Loss = 3.8780
Epoch 29: Loss = 3.9567


In [None]:
torch.save(model.state_dict(),  '/content/drive/MyDrive/mods/mobilenet_sim_3000_epoch.pth')

In [None]:
for epoch in range(10):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if loss < mloss:
        mloss = loss.item()
        torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_3000.pth')
    if (epoch+1)%2==0:
        torch.save(model.state_dict(),  '/content/drive/MyDrive/mods/mobilenet_sim_3000_epoch.pth')

    print(f"Epoch {epoch+30}: Loss = {loss.item():.4f}")

Epoch 30: Loss = 3.9850
Epoch 31: Loss = 3.9224
Epoch 32: Loss = 3.9892
Epoch 33: Loss = 3.7859
Epoch 34: Loss = 3.8417
Epoch 35: Loss = 3.9339
Epoch 36: Loss = 3.7608
Epoch 37: Loss = 3.8823
Epoch 38: Loss = 3.8279
Epoch 39: Loss = 3.8465


In [None]:
for epoch in range(10):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if loss < mloss:
        mloss = loss.item()
        torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_3000.pth')
    if (epoch+1)%2==0:
        torch.save(model.state_dict(),  '/content/drive/MyDrive/mods/mobilenet_sim_3000_epoch.pth')

    print(f"Epoch {epoch+40}: Loss = {loss.item():.4f}")

Epoch 40: Loss = 3.9658
Epoch 41: Loss = 3.7096
Epoch 42: Loss = 3.7941
Epoch 43: Loss = 3.9225
Epoch 44: Loss = 3.7813
Epoch 45: Loss = 3.8362
Epoch 46: Loss = 3.7475
Epoch 47: Loss = 3.8510
Epoch 48: Loss = 3.7701
Epoch 49: Loss = 3.6677


In [None]:
for epoch in range(10):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if loss < mloss:
        mloss = loss.item()
        torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_3000.pth')
    if (epoch+1)%2==0:
        torch.save(model.state_dict(),  '/content/drive/MyDrive/mods/mobilenet_sim_3000_epoch.pth')

    print(f"Epoch {epoch+50}: Loss = {loss.item():.4f}")

Epoch 50: Loss = 3.6622
Epoch 51: Loss = 3.6666
Epoch 52: Loss = 3.8610
Epoch 53: Loss = 3.7226
Epoch 54: Loss = 3.6765
Epoch 55: Loss = 3.7334
Epoch 56: Loss = 3.7732
Epoch 57: Loss = 3.7030
Epoch 58: Loss = 3.7847
Epoch 59: Loss = 3.6570


In [None]:
torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_3000.pth')
torch.save(model.state_dict(),  '/content/drive/MyDrive/mods/mobilenet_sim_3000_epoch.pth')

In [None]:
for epoch in range(10):
    for (views) in dataloader:  # Assuming LightlyDataset with SimCLRTransform

        view1, view2 = views[0].cuda(), views[1].cuda()


        z1 = model(view1)
        z2 = model(view2)


        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if loss < mloss:
        mloss = loss.item()
        torch.save(model.state_dict(), '/content/drive/MyDrive/mods/mobilenet_sim_3000.pth')
    if (epoch+1)%2==0:
        torch.save(model.state_dict(),  '/content/drive/MyDrive/mods/mobilenet_sim_3000_epoch.pth')

    print(f"Epoch {epoch+60}: Loss = {loss.item():.4f}")

Epoch 60: Loss = 3.8698
Epoch 61: Loss = 3.6273
Epoch 62: Loss = 3.6860
Epoch 63: Loss = 3.8455
Epoch 64: Loss = 3.7100


# Finetune the pretrained student model 3 000

In [None]:
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms, datasets

In [None]:
# Define transforms for finetuning

finetune_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(192),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Create datasets and dataloaders for finetuning
finetune_dataset_labeled = ImageFolder('/content/drive/MyDrive/pets0/finetune_train', transform=finetune_transform)
val_dataset_labeled = ImageFolder('/content/drive/MyDrive/pets0/val', transform=val_transform)

finetune_loader_labeled = DataLoader(finetune_dataset_labeled, batch_size=64, shuffle=True, num_workers=2)
val_loader_labeled = DataLoader(val_dataset_labeled, batch_size=64, shuffle=False, num_workers=2)

print(f"Number of samples in finetune dataset: {len(finetune_dataset_labeled)}")
print(f"Number of samples in validation dataset: {len(val_dataset_labeled)}")

Number of samples in finetune dataset: 420
Number of samples in validation dataset: 180


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pretrained SimCLR model state dict
pretrained_path = '/content/drive/MyDrive/mods/mobilenet_sim_3000_epoch.pth'
simclr_state_dict = torch.load(pretrained_path, map_location=device)

# Create a standard MobileNetV2 model
student_finetune = mobilenet_v2(weights=None)

# Filter the state dict to keep only the backbone weights
backbone_state_dict = {}
for k, v in simclr_state_dict.items():
    # Keys in the saved SimCLRModel state dict for the backbone start with 'backbone.0.'
    # We want to load these into the 'features.' of the standard MobileNetV2
    if k.startswith('backbone.0.'):
        backbone_state_dict[k.replace('backbone.0.', 'features.')] = v
    # Also handle the case if the state dict keys were just 'backbone.' without the '0.'
    elif k.startswith('backbone.'):
        backbone_state_dict[k.replace('backbone.', 'features.')] = v


# Load the backbone weights into the standard MobileNetV2 model
# Use strict=False because we are not loading the classifier weights
student_finetune.load_state_dict(backbone_state_dict, strict=False)

# Replace the classifier
num_ftrs = student_finetune.classifier[1].in_features
student_finetune.classifier[1] = nn.Linear(num_ftrs, 2) # 2 classes: cat, dog

student_finetune = student_finetune.to(device)

# Optionally freeze the backbone and only train the classifier for faster initial training
# for param in student_finetune.features.parameters():
#     param.requires_grad = False

print("Pretrained MobileNetV2 loaded and classifier replaced.")

Pretrained MobileNetV2 loaded and classifier replaced.


In [None]:
# Define optimizer and loss function for finetuning
optimizer_finetune = torch.optim.Adam(student_finetune.parameters(), lr=1e-4) # Start with a lower learning rate
criterion_finetune = nn.CrossEntropyLoss()

vacc = 0.0
# Training loop
for epoch in range(10):
    student_finetune.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in finetune_loader_labeled:
        images, labels = images.to(device), labels.to(device)

        optimizer_finetune.zero_grad()
        outputs = student_finetune(images)
        loss = criterion_finetune(outputs, labels)
        loss.backward()
        optimizer_finetune.step()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total

    # Validation
    student_finetune.eval()
    val_running_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader_labeled:
            images, labels = images.to(device), labels.to(device)
            outputs = student_finetune(images)
            loss = criterion_finetune(outputs, labels)

            val_running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_loss = val_running_loss / val_total
    val_acc = val_correct / val_total

    if (epoch+1) % 2 == 0:
        torch.save(student_finetune.state_dict(), '/content/drive/MyDrive/mods/student_finetuned_3000_epoch.pth')
    if val_acc > vacc:
        torch.save(student_finetune.state_dict(), '/content/drive/MyDrive/mods/student_finetuned_3000.pth')
    print(f'Epoch {epoch+1}: '
          f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} | '
          f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')


Epoch 1: Train Loss: 0.7052 Acc: 0.5310 | Val Loss: 0.6311 Acc: 0.6444
Epoch 2: Train Loss: 0.6544 Acc: 0.6452 | Val Loss: 0.6014 Acc: 0.7222
Epoch 3: Train Loss: 0.6239 Acc: 0.6643 | Val Loss: 0.5766 Acc: 0.7444
Epoch 4: Train Loss: 0.5974 Acc: 0.6881 | Val Loss: 0.5604 Acc: 0.7444
Epoch 5: Train Loss: 0.5805 Acc: 0.7143 | Val Loss: 0.5467 Acc: 0.7167
Epoch 6: Train Loss: 0.5459 Acc: 0.7190 | Val Loss: 0.5339 Acc: 0.7167
Epoch 7: Train Loss: 0.5394 Acc: 0.7310 | Val Loss: 0.5253 Acc: 0.7167
Epoch 8: Train Loss: 0.4910 Acc: 0.7810 | Val Loss: 0.5305 Acc: 0.7056
Epoch 9: Train Loss: 0.5186 Acc: 0.7452 | Val Loss: 0.5385 Acc: 0.7056
Epoch 10: Train Loss: 0.4547 Acc: 0.7786 | Val Loss: 0.5296 Acc: 0.7111


In [None]:
# Training loop
for epoch in range(10):
    student_finetune.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in finetune_loader_labeled:
        images, labels = images.to(device), labels.to(device)

        optimizer_finetune.zero_grad()
        outputs = student_finetune(images)
        loss = criterion_finetune(outputs, labels)
        loss.backward()
        optimizer_finetune.step()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total

    # Validation
    student_finetune.eval()
    val_running_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader_labeled:
            images, labels = images.to(device), labels.to(device)
            outputs = student_finetune(images)
            loss = criterion_finetune(outputs, labels)

            val_running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_loss = val_running_loss / val_total
    val_acc = val_correct / val_total

    if (epoch+1) % 2 == 0:
        torch.save(student_finetune.state_dict(), '/content/drive/MyDrive/mods/student_finetuned_3000_epoch.pth')
    if val_acc > vacc:
        vacc = val_acc
        torch.save(student_finetune.state_dict(), '/content/drive/MyDrive/mods/student_finetuned_3000.pth')
    print(f'Epoch {epoch+1}: '
          f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} | '
          f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

Epoch 1: Train Loss: 0.4284 Acc: 0.8214 | Val Loss: 0.5192 Acc: 0.7278
Epoch 2: Train Loss: 0.4122 Acc: 0.8262 | Val Loss: 0.5212 Acc: 0.7222
Epoch 3: Train Loss: 0.3852 Acc: 0.8405 | Val Loss: 0.5334 Acc: 0.7389
Epoch 4: Train Loss: 0.3541 Acc: 0.8357 | Val Loss: 0.5450 Acc: 0.7278
Epoch 5: Train Loss: 0.3518 Acc: 0.8429 | Val Loss: 0.5609 Acc: 0.7556
Epoch 6: Train Loss: 0.3140 Acc: 0.8810 | Val Loss: 0.5693 Acc: 0.7389
Epoch 7: Train Loss: 0.2711 Acc: 0.9000 | Val Loss: 0.5756 Acc: 0.7389
Epoch 8: Train Loss: 0.2627 Acc: 0.9167 | Val Loss: 0.6068 Acc: 0.7278
Epoch 9: Train Loss: 0.2460 Acc: 0.8976 | Val Loss: 0.6332 Acc: 0.7278
Epoch 10: Train Loss: 0.2310 Acc: 0.9071 | Val Loss: 0.6453 Acc: 0.7000


In [None]:
torch.save(student_finetune.state_dict(), '/content/drive/MyDrive/mods/student_finetuned_3000_epoch.pth')

#SimCLR Pretraining 6 000 with compact function

In [None]:
!pip install lightly


Collecting lightly
  Downloading lightly-1.5.22-py3-none-any.whl.metadata (38 kB)
Collecting hydra-core>=1.0.0 (from lightly)
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting lightly_utils~=0.0.0 (from lightly)
  Downloading lightly_utils-0.0.2-py3-none-any.whl.metadata (1.4 kB)
Collecting pytorch_lightning>=1.0.4 (from lightly)
  Downloading pytorch_lightning-2.6.0-py3-none-any.whl.metadata (21 kB)
Collecting aenum>=3.1.11 (from lightly)
  Downloading aenum-3.1.16-py3-none-any.whl.metadata (3.8 kB)
Collecting torchmetrics>0.7.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading lightly-1.5.22-py3-none-any.whl (859 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m859.3/859.3 kB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0m

In [None]:
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2
from lightly.transforms import SimCLRTransform
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from lightly.loss import NTXentLoss

# Define the SimCLRModel class globally as it's a core component
class SimCLRModel(nn.Module):
    def __init__(self, backbone, feature_dim=128):
        super().__init__()
        self.backbone = backbone
        self.projection_head = nn.Sequential(
            nn.Linear(1280, 256),
            nn.ReLU(),
            nn.Linear(256, feature_dim)
        )

    def forward(self, x):
        features = self.backbone(x)
        projections = self.projection_head(features)
        return projections

def load_simclr_model(feature_dim=128, pretrained_path=None):
    """
    Loads a MobileNetV2 backbone and constructs the SimCLR model.
    Optionally loads pretrained weights.
    """
    # Load MobileNetV2 without pretrained weights and remove classifier
    base_model = mobilenet_v2(weights=None)

    # Remove the classification head and keep only the feature extractor
    backbone = nn.Sequential(
        base_model.features,
        nn.AdaptiveAvgPool2d(1),  # Ensure consistent output shape
        nn.Flatten(),             # Shape: [B, 1280]
    )

    model = SimCLRModel(backbone, feature_dim=feature_dim)

    if pretrained_path:
        print(f"Loading pretrained model from {pretrained_path}")
        # Use map_location='cpu' to load onto CPU first, then move to device
        student_state_dict = torch.load(pretrained_path, map_location='cpu')
        model.load_state_dict(student_state_dict, strict=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    print("SimCLR model loaded successfully.")
    return model, device

# SimCLR Transform from lightly
  # you can change input_size as needed

# Custom dataset to return 2 views
class SimCLRDataset(ImageFolder):
    def __getitem__(self, index):
        simclr_transform = SimCLRTransform(input_size=224)
        sample, _ = super().__getitem__(index)
        xi, xj = simclr_transform(sample)
        return xi, xj

def load_simclr_data(data_path, batch_size=64, input_size=224, num_workers=2):
    """
    Loads the dataset and creates a DataLoader for SimCLR pretraining.
    """
    simclr_transform = SimCLRTransform(input_size=input_size)
    dataset = SimCLRDataset(root=data_path)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    num_imgs = len(dataset)
    print(f"Number of images in the dataset: {num_imgs}")
    return dataloader, num_imgs

def train_simclr_model(model, dataloader, device, epochs=20, lr=3e-4, save_best_path=None, save_epoch_path=None, start_epoch=0, initial_mloss=float('inf')):
    """
    Trains the SimCLR model.
    """
    criterion = NTXentLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    mloss = initial_mloss # Keep track of the minimum loss for saving the best model

    print("Starting SimCLR training...")
    for epoch in range(start_epoch, start_epoch + epochs):
        running_loss = 0.0
        total_batches = 0
        for views in dataloader:
            view1, view2 = views[0].to(device), views[1].to(device)

            optimizer.zero_grad()
            z1 = model(view1)
            z2 = model(view2)
            loss = criterion(z1, z2)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            total_batches += 1

        avg_epoch_loss = running_loss / total_batches if total_batches > 0 else 0.0

        print(f"Epoch {epoch + 1}: Loss = {avg_epoch_loss:.4f}")

        # Save model if current loss is the best so far
        if save_best_path and avg_epoch_loss < mloss:
            mloss = avg_epoch_loss
            torch.save(model.state_dict(), save_best_path)
            print(f"Saved best model with loss {mloss:.4f} at epoch {epoch + 1}")

        # Save model periodically
        if save_epoch_path and (epoch + 1) % 2 == 0:
            torch.save(model.state_dict(), save_epoch_path)
            print(f"Saved model checkpoint at epoch {epoch + 1}")

    print("SimCLR training finished.")
    return model

In [None]:
model, device = load_simclr_model(feature_dim=128)

SimCLR model loaded successfully.


In [None]:
data_path = '/content/drive/MyDrive/pets0/train6000'
dataloader, num_imgs = load_simclr_data(data_path, batch_size=64, input_size=224, num_workers=2)

Number of images in the dataset: 6000


In [None]:
model = train_simclr_model(
     model, dataloader, device, epochs=20, lr=3e-4,
     save_best_path='/content/drive/MyDrive/mods/mobilenet_sim_6000.pth',
     save_epoch_path='/content/drive/MyDrive/mods/mobilenet_sim_6000_epoch.pth',
     start_epoch=0 # For new training
)

Starting SimCLR training...
Epoch 1: Loss = 4.8232
Saved best model with loss 4.8232 at epoch 1
Epoch 2: Loss = 4.7587
Saved best model with loss 4.7587 at epoch 2
Saved model checkpoint at epoch 2
Epoch 3: Loss = 4.6996
Saved best model with loss 4.6996 at epoch 3
Epoch 4: Loss = 4.6325
Saved best model with loss 4.6325 at epoch 4
Saved model checkpoint at epoch 4
Epoch 5: Loss = 4.5592
Saved best model with loss 4.5592 at epoch 5
Epoch 6: Loss = 4.4771
Saved best model with loss 4.4771 at epoch 6
Saved model checkpoint at epoch 6
Epoch 7: Loss = 4.3855
Saved best model with loss 4.3855 at epoch 7
Epoch 8: Loss = 4.3165
Saved best model with loss 4.3165 at epoch 8
Saved model checkpoint at epoch 8
Epoch 9: Loss = 4.2478
Saved best model with loss 4.2478 at epoch 9
Epoch 10: Loss = 4.2240
Saved best model with loss 4.2240 at epoch 10
Saved model checkpoint at epoch 10
Epoch 11: Loss = 4.1939
Saved best model with loss 4.1939 at epoch 11
Epoch 12: Loss = 4.1348
Saved best model with los

In [None]:
model = train_simclr_model(
     model, dataloader, device, epochs=10, lr=3e-4,
     save_best_path='/content/drive/MyDrive/mods/mobilenet_sim_6000.pth',
     save_epoch_path='/content/drive/MyDrive/mods/mobilenet_sim_6000_epoch.pth',
     start_epoch=21,
     initial_mloss = 3.9897
)

Starting SimCLR training...
Epoch 22: Loss = 3.9808
Saved best model with loss 3.9808 at epoch 22
Saved model checkpoint at epoch 22
Epoch 23: Loss = 3.9529
Saved best model with loss 3.9529 at epoch 23
Epoch 24: Loss = 3.9589
Saved model checkpoint at epoch 24
Epoch 25: Loss = 3.9534
Epoch 26: Loss = 3.9527
Saved best model with loss 3.9527 at epoch 26
Saved model checkpoint at epoch 26
Epoch 27: Loss = 3.9244
Saved best model with loss 3.9244 at epoch 27
Epoch 28: Loss = 3.9112
Saved best model with loss 3.9112 at epoch 28
Saved model checkpoint at epoch 28
Epoch 29: Loss = 3.9094
Saved best model with loss 3.9094 at epoch 29
Epoch 30: Loss = 3.8954
Saved best model with loss 3.8954 at epoch 30
Saved model checkpoint at epoch 30
Epoch 31: Loss = 3.8855
Saved best model with loss 3.8855 at epoch 31
SimCLR training finished.


In [None]:
model = train_simclr_model(
     model, dataloader, device, epochs=10, lr=3e-4,
     save_best_path='/content/drive/MyDrive/mods/mobilenet_sim_6000.pth',
     save_epoch_path='/content/drive/MyDrive/mods/mobilenet_sim_6000_epoch.pth',
     start_epoch=32,
     initial_mloss = 3.8855
)

Starting SimCLR training...
Epoch 33: Loss = 3.8926
Epoch 34: Loss = 3.8753
Saved best model with loss 3.8753 at epoch 34
Saved model checkpoint at epoch 34
Epoch 35: Loss = 3.8603
Saved best model with loss 3.8603 at epoch 35
Epoch 36: Loss = 3.8596
Saved best model with loss 3.8596 at epoch 36
Saved model checkpoint at epoch 36
Epoch 37: Loss = 3.8481
Saved best model with loss 3.8481 at epoch 37
Epoch 38: Loss = 3.8348
Saved best model with loss 3.8348 at epoch 38
Saved model checkpoint at epoch 38
Epoch 39: Loss = 3.8436
Epoch 40: Loss = 3.8275
Saved best model with loss 3.8275 at epoch 40
Saved model checkpoint at epoch 40
Epoch 41: Loss = 3.8259
Saved best model with loss 3.8259 at epoch 41
Epoch 42: Loss = 3.8065
Saved best model with loss 3.8065 at epoch 42
Saved model checkpoint at epoch 42
SimCLR training finished.


In [None]:
model = train_simclr_model(
     model, dataloader, device, epochs=10, lr=3e-4,
     save_best_path='/content/drive/MyDrive/mods/mobilenet_sim_6000.pth',
     save_epoch_path='/content/drive/MyDrive/mods/mobilenet_sim_6000_epoch.pth',
     start_epoch=43,
     initial_mloss = 3.8065
)

Starting SimCLR training...
Epoch 44: Loss = 3.8167
Saved model checkpoint at epoch 44


##Continue pretraining

In [None]:
# To resume training from a saved checkpoint:
pretrained_model_path = '/content/drive/MyDrive/mods/mobilenet_sim_6000_epoch.pth'
model, device = load_simclr_model(feature_dim=128, pretrained_path=pretrained_model_path)

Loading pretrained model from /content/drive/MyDrive/mods/mobilenet_sim_6000_epoch.pth
SimCLR model loaded successfully.


In [None]:
data_path = '/content/drive/MyDrive/pets/train6000'
dataloader, num_imgs = load_simclr_data(data_path, batch_size=64, input_size=224, num_workers=2)

Number of images in the dataset: 6000


In [None]:
# Example for continuing training from a checkpoint (adjust start_epoch and initial_mloss based on previous runs)
model = train_simclr_model(
        model, dataloader, device, epochs=15, lr=3e-4,
        save_best_path='/content/drive/MyDrive/mods/mobilenet_sim_6000.pth',
        save_epoch_path='/content/drive/MyDrive/mods/mobilenet_sim_6000_epoch.pth',
        start_epoch=45,
        initial_mloss=3.8065 # Based on the mloss from the last training run
    )

Starting SimCLR training...
Epoch 46: Loss = 3.8067
Saved model checkpoint at epoch 46
Epoch 47: Loss = 3.7897
Saved best model with loss 3.7897 at epoch 47
Epoch 48: Loss = 3.7713
Saved best model with loss 3.7713 at epoch 48
Saved model checkpoint at epoch 48
Epoch 49: Loss = 3.7787
Epoch 50: Loss = 3.7828
Saved model checkpoint at epoch 50
Epoch 51: Loss = 3.7639
Saved best model with loss 3.7639 at epoch 51
Epoch 52: Loss = 3.7698
Saved model checkpoint at epoch 52
Epoch 53: Loss = 3.7530
Saved best model with loss 3.7530 at epoch 53
Epoch 54: Loss = 3.7421
Saved best model with loss 3.7421 at epoch 54
Saved model checkpoint at epoch 54
Epoch 55: Loss = 3.7503
Epoch 56: Loss = 3.7428
Saved model checkpoint at epoch 56
Epoch 57: Loss = 3.7253
Saved best model with loss 3.7253 at epoch 57
Epoch 58: Loss = 3.7294
Saved model checkpoint at epoch 58
Epoch 59: Loss = 3.7302
Epoch 60: Loss = 3.7256
Saved model checkpoint at epoch 60
SimCLR training finished.


#Finetune the simCLR model 6000 with compact function

In [None]:
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def finetune_model(model, train_loader, val_loader, optimizer, criterion, num_epochs, device, save_best=None, save_epoch=None, start_epoch=0, init_acc = 0.0):
    best_val_acc = init_acc

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / total
        epoch_acc = correct / total

        # Validation
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_running_loss += loss.item() * images.size(0)
                _, preds = torch.max(outputs, 1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

        val_loss = val_running_loss / val_total
        val_acc = val_correct / val_total

        print(f'Epoch {start_epoch + epoch + 1}: '
              f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} | '
              f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

        if save_epoch and (epoch+1)%2==0:
            torch.save(model.state_dict(), save_epoch)
            print(f"Saved model checkpoint at epoch {start_epoch + epoch + 1}")

        if save_best:
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), save_best)
                print(f"Saved best model with Val Acc: {best_val_acc:.4f} at epoch {start_epoch + epoch + 1}")

    return best_val_acc

In [None]:
def run_finetuning_workflow(pretrained_simclr_path=None, num_epochs_initial=10, best_save=None,
                            save_epoch=None, start_epoch=0, init_acc = 0.0, finetuned_path=None):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create a standard MobileNetV2 model
    student_finetune = mobilenet_v2(weights=None)
    num_ftrs = student_finetune.classifier[1].in_features
    student_finetune.classifier[1] = nn.Linear(num_ftrs, 2) # 2 classes: cat, dog

    if finetuned_path:
        print(f"Continuing finetuning from {finetuned_path}")
        student_finetune.load_state_dict(torch.load(finetuned_path, map_location=device))
        print("Previous finetuned model loaded successfully.")
    elif pretrained_simclr_path:
        print(f"Starting new finetuning using SimCLR backbone from {pretrained_simclr_path}")
        # Load the pretrained SimCLR model state dict
        simclr_state_dict = torch.load(pretrained_simclr_path, map_location=device)

        # Filter the state dict to keep only the backbone weights
        backbone_state_dict = {}
        for k, v in simclr_state_dict.items():
            # Keys in the saved SimCLRModel state dict for the backbone start with 'backbone.0.'
            if k.startswith('backbone.0.'):
                backbone_state_dict[k.replace('backbone.0.', 'features.')] = v
            # Also handle the case if the state dict keys were just 'backbone.' without the '0.'
            elif k.startswith('backbone.'):
                backbone_state_dict[k.replace('backbone.', 'features.')] = v
            # Handle projection head weights if directly loading the SimCLRModel's state_dict
            elif k.startswith('projection_head.'):
                # These are not needed for finetuning the classification head, so we ignore them
                pass


        # Load the backbone weights into the standard MobileNetV2 model
        # Use strict=False because we are not loading the classifier weights
        student_finetune.load_state_dict(backbone_state_dict, strict=False)
        print("Pretrained SimCLR backbone loaded and classifier replaced.")
    else:
        raise ValueError("Either pretrained_simclr_path must be provided for new finetuning, or cont=True and finetuned_path must be provided for resuming.")

    student_finetune = student_finetune.to(device)

    # Define transforms for finetuning
    finetune_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.CenterCrop(192),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Create datasets and dataloaders for finetuning
    finetune_dataset_labeled = datasets.ImageFolder('/content/drive/MyDrive/pets/finetune_train', transform=finetune_transform)
    val_dataset_labeled = datasets.ImageFolder('/content/drive/MyDrive/pets/val', transform=val_transform)

    finetune_loader_labeled = DataLoader(finetune_dataset_labeled, batch_size=64, shuffle=True, num_workers=2)
    val_loader_labeled = DataLoader(val_dataset_labeled, batch_size=64, shuffle=False, num_workers=2)

    print(f"Number of samples in finetune dataset: {len(finetune_dataset_labeled)}")
    print(f"Number of samples in validation dataset: {len(val_dataset_labeled)}")

    # Define optimizer and loss function for finetuning
    optimizer_finetune = torch.optim.Adam(student_finetune.parameters(), lr=1e-4) # Start with a lower learning rate
    criterion_finetune = nn.CrossEntropyLoss()

    # Initial finetuning
    finetune_model(
        student_finetune,
        finetune_loader_labeled,
        val_loader_labeled,
        optimizer_finetune,
        criterion_finetune,
        num_epochs=num_epochs_initial,
        device=device,
        save_best=best_save,
        save_epoch=save_epoch,
        start_epoch=start_epoch,
        init_acc=init_acc
    )

    print(f"Finetuning complete. Best model saved to {best_save}")

In [None]:
run_finetuning_workflow(pretrained_simclr_path='/content/drive/MyDrive/mods/mobilenet_sim_6000.pth',
                        num_epochs_initial=20,
                        best_save='/content/drive/MyDrive/mods/student_finetuned_6000.pth',
                        save_epoch= '/content/drive/MyDrive/mods/student_finetuned_6000_epoch.pth'
)

Starting new finetuning using SimCLR backbone from /content/drive/MyDrive/mods/mobilenet_sim_6000.pth
Pretrained SimCLR backbone loaded and classifier replaced.
Number of samples in finetune dataset: 420
Number of samples in validation dataset: 180
Epoch 1: Train Loss: 0.6830 Acc: 0.5762 | Val Loss: 0.6445 Acc: 0.7000
Saved best model with Val Acc: 0.7000 at epoch 1
Epoch 2: Train Loss: 0.6348 Acc: 0.6762 | Val Loss: 0.5817 Acc: 0.7111
Saved model checkpoint at epoch 2
Saved best model with Val Acc: 0.7111 at epoch 2
Epoch 3: Train Loss: 0.5920 Acc: 0.7286 | Val Loss: 0.5371 Acc: 0.7333
Saved best model with Val Acc: 0.7333 at epoch 3
Epoch 4: Train Loss: 0.5626 Acc: 0.6952 | Val Loss: 0.5051 Acc: 0.7833
Saved model checkpoint at epoch 4
Saved best model with Val Acc: 0.7833 at epoch 4
Epoch 5: Train Loss: 0.5250 Acc: 0.7429 | Val Loss: 0.4874 Acc: 0.7944
Saved best model with Val Acc: 0.7944 at epoch 5
Epoch 6: Train Loss: 0.4930 Acc: 0.7690 | Val Loss: 0.4752 Acc: 0.7889
Saved model 