In [62]:
model_checkpoint = "/kaggle/input/ssl-models/models/simCLR_checkpoint_epoch_0100.pth.tar" 
linear_probing = "/kaggle/input/ssl-models/models/LinearProbing_simCLR.pth.tar" 
Validation_dataset_path = "/kaggle/input/ssl-dataset/ssl_dataset/val.X"
Class_idx_json_path = "/kaggle/input/ssl-models/models/class_to_idx.json"

In [63]:
# Models stored in https://www.kaggle.com/datasets/harishchanderprivate/ssl-models (Alt ID)

In [64]:
# Code for clearing memory

import torch
import gc

gc.collect()
torch.cuda.empty_cache()
if torch.cuda.is_available():
    torch.cuda.ipc_collect()

In [65]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import transforms
import torchvision.models as models
import torch.backends.cudnn as cudnn
from torchvision import datasets
import numpy as np

from tqdm import tqdm
# 
import os
import sys
import logging

import yaml
import argparse

torch.manual_seed(0)
np.random.seed(0)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda')

In [66]:
# exceptions/exceptions.py

class BaseSimCLRException(Exception):
    """Base exception"""


class InvalidBackboneError(BaseSimCLRException):
    """Raised when the choice of backbone Convnet is invalid."""


class InvalidDatasetSelection(BaseSimCLRException):
    """Raised when the choice of dataset is invalid."""

In [67]:
class ResNetSimCLR(nn.Module):
    def __init__(self, base_model, out_dim):
        super(ResNetSimCLR, self).__init__()
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
                            "resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}

        self.backbone = self._get_basemodel(base_model)
        dim_mlp = self.backbone.fc.in_features

        self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)

    def _get_basemodel(self, model_name = 'resnet50'):
        try:
            model = self.resnet_dict[model_name]
        except KeyError:
            raise InvalidBackboneError(
                "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")
        else:
            return model

    def forward(self, x):
        return self.backbone(x)

In [68]:
model = ResNetSimCLR(base_model="resnet50", out_dim=128)
checkpoint = torch.load(model_checkpoint)

if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs.")
        model = torch.nn.DataParallel(model) # Since model was stored as data parallel

model.load_state_dict(checkpoint['state_dict'])

feature_dim = model.module.backbone.fc[0].in_features
encoder = model.module.backbone
encoder.fc = nn.Identity()
encoder.to(DEVICE);


Using 2 GPUs.


In [69]:
LinearProbe = nn.Sequential(
        nn.BatchNorm1d(2048, affine=False, eps=1e-6),
        nn.Linear(2048, 100)
    ).to(DEVICE)

saved_checkpoint = torch.load(linear_probing, map_location=DEVICE)
msg = LinearProbe.load_state_dict(saved_checkpoint['model_state_dict'])
print(msg)


<All keys matched successfully>


In [70]:

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder(root=Validation_dataset_path, transform=transform)

with open(Class_idx_json_path, 'r') as f:
    class_to_idx = json.load(f)

dataset.class_to_idx = class_to_idx

dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)

In [71]:
import torch

encoder.eval()
LinearProbe.eval()

correct = 0
total = 0

with torch.no_grad():
    for batch_X, batch_y in dataloader:
        batch_X, batch_y = batch_X.to(DEVICE), batch_y.to(DEVICE)
        features = encoder(batch_X)
        logits = LinearProbe(features)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == batch_y).sum().item()
        total += batch_y.size(0)
            
accuracy = correct / total
print(f"Evaluation Accuracy: {accuracy:.4f}")

Evaluation Accuracy: 0.4108
