In [82]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import torch.optim as optim

train_images_path = "brain_train_image_final.npy"
train_labels_path = "brain_train_label.npy"
test_images_path = "brain_test_image_final.npy"
test_labels_path = "brain_test_label.npy"

final_X_train_modified = np.load(train_images_path)[:, 1, :, :]  # Use second channel (grayscale)
final_X_test_modified = np.load(test_images_path)[:, 1, :, :]  # Use second channel (grayscale)
train_labels = np.load(train_labels_path)
test_labels = np.load(test_labels_path)

# Normalize and Resize Images using Pillow
def normalize_and_resize(images, target_size=(224, 224)):
    resized_images = []
    for img in images:
        img = Image.fromarray((img * 255).astype(np.uint8))  # Convert to PIL Image
        img_resized = img.resize(target_size, Image.Resampling.LANCZOS)  # Resize to target size using LANCZOS
        resized_images.append(np.array(img_resized) / 255.0)  # Normalize back to [0, 1]
    return np.array(resized_images)

# Preprocess data
final_X_train_resized = normalize_and_resize(final_X_train_modified, target_size=(224, 224))
final_X_test_resized = normalize_and_resize(final_X_test_modified, target_size=(224, 224))

# Define SimCLR Augmentation Transform for Grayscale Images
transform_simclr = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomResizedCrop(size=224),  # Random crop and resize
    transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),  # Color jitter
    transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),  # Gaussian blur
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
])

# Custom Dataset Class for SimCLR
class SimCLRDataset(Dataset):
    def __init__(self, images, transform):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        img_1 = self.transform(img)
        img_2 = self.transform(img)
        return img_1, img_2

# Define SimCLR Model
class SimCLR(nn.Module):
    def __init__(self, base_encoder, projection_dim):
        super(SimCLR, self).__init__()
        self.encoder = base_encoder
        self.projector = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, projection_dim)
        )

    def forward(self, x):
        h = self.encoder(x)
        z = self.projector(h)
        return z

# Define NT-Xent Loss
class NTXentLoss(nn.Module):
    def __init__(self, batch_size, temperature):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def forward(self, z_i, z_j):
        N = z_i.size(0) + z_j.size(0)
        z = torch.cat((z_i, z_j), dim=0)
        sim = torch.matmul(z, z.T) / self.temperature
        mask = ~torch.eye(N, dtype=torch.bool, device=z.device)

        positives = torch.cat([
            torch.diag(sim, z_i.size(0)),  # Sim between z_i and z_j
            torch.diag(sim, -z_i.size(0)) # Sim between z_j and z_i
        ])

        negatives = sim[mask].view(N, -1)
        logits = torch.cat((positives.unsqueeze(1), negatives), dim=1)
        labels = torch.zeros(N, dtype=torch.long, device=z.device)
        loss = self.criterion(logits, labels) / N
        return loss

# Initialize Dataset and DataLoader
train_dataset = SimCLRDataset(final_X_train_resized, transform=transform_simclr)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)

# Initialize Encoder and SimCLR Model
base_encoder = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(128 * 56 * 56, 512)
)

model = SimCLR(base_encoder, projection_dim=128).to("cuda")
optimizer = optim.Adam(model.parameters(), lr=1e-7)
criterion = NTXentLoss(batch_size=512, temperature=0.7)

# Training Loop
for epoch in range(200):  # Number of epochs
    total_loss = 0
    model.train()
    for img_1, img_2 in train_loader:
        img_1, img_2 = img_1.to("cuda"), img_2.to("cuda")
        z_i = model(img_1)
        z_j = model(img_2)

        loss = criterion(z_i, z_j)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch [{epoch+1}/200], Loss: {total_loss/len(train_loader):.4f}")

# Save the Trained Model
torch.save(model.state_dict(), "simclr_model.pth")

# Evaluation on Test Dataset
class TestDataset(Dataset):
    def __init__(self, images, labels, transform):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        label = self.labels[idx]
        img_transformed = self.transform(img)
        return img_transformed, label

# Prepare test dataset and dataloader
test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

test_dataset = TestDataset(final_X_test_resized, test_labels, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

# Classification Head
class ClassificationHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(ClassificationHead, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

# Add Classification Head to Encoder
classification_head = ClassificationHead(input_dim=512, num_classes=len(np.unique(train_labels))).to("cuda")
optimizer_cls = optim.Adam(classification_head.parameters(), lr=1e-3)
criterion_cls = nn.CrossEntropyLoss()

# Fine-tune on Training Dataset
for epoch in range(200):
    model.eval()  # Freeze the encoder
    classification_head.train()
    total_loss = 0
    correct = 0
    for img, label in DataLoader(TestDataset(final_X_train_resized, train_labels, test_transform), batch_size=64, shuffle=True):
        img, label = img.to("cuda"), label.to("cuda")
        with torch.no_grad():
            features = model.encoder(img)
        logits = classification_head(features)
        loss = criterion_cls(logits, label)

        optimizer_cls.zero_grad()
        loss.backward()
        optimizer_cls.step()

        total_loss += loss.item()
        correct += (logits.argmax(dim=1) == label).sum().item()

    accuracy = correct / len(train_labels)
    print(f"Epoch [{epoch+1}/100], Loss: {total_loss/len(train_labels):.4f}, Accuracy: {accuracy:.4f}")

# Evaluate on Test Dataset
classification_head.eval()
correct = 0
with torch.no_grad():
    for img, label in test_loader:
        img, label = img.to("cuda"), label.to("cuda")
        features = model.encoder(img)
        logits = classification_head(features)
        correct += (logits.argmax(dim=1) == label).sum().item()

test_accuracy = correct / len(test_labels)
print(f"Test Accuracy: {test_accuracy:.4f}")


Epoch [1/200], Loss: 6.5709
Epoch [2/200], Loss: 6.5706
Epoch [3/200], Loss: 6.5707
Epoch [4/200], Loss: 6.5706
Epoch [5/200], Loss: 6.5701
Epoch [6/200], Loss: 6.5705
Epoch [7/200], Loss: 6.5702
Epoch [8/200], Loss: 6.5697
Epoch [9/200], Loss: 6.5709
Epoch [10/200], Loss: 6.5707
Epoch [11/200], Loss: 6.5702
Epoch [12/200], Loss: 6.5701
Epoch [13/200], Loss: 6.5695
Epoch [14/200], Loss: 6.5705
Epoch [15/200], Loss: 6.5681
Epoch [16/200], Loss: 6.5699
Epoch [17/200], Loss: 6.5707
Epoch [18/200], Loss: 6.5680
Epoch [19/200], Loss: 6.5694
Epoch [20/200], Loss: 6.5713
Epoch [21/200], Loss: 6.5699
Epoch [22/200], Loss: 6.5698
Epoch [23/200], Loss: 6.5686
Epoch [24/200], Loss: 6.5693
Epoch [25/200], Loss: 6.5689
Epoch [26/200], Loss: 6.5694
Epoch [27/200], Loss: 6.5696
Epoch [28/200], Loss: 6.5675
Epoch [29/200], Loss: 6.5700
Epoch [30/200], Loss: 6.5675
Epoch [31/200], Loss: 6.5684
Epoch [32/200], Loss: 6.5662
Epoch [33/200], Loss: 6.5646
Epoch [34/200], Loss: 6.5655
Epoch [35/200], Loss: 6

In [5]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import resnet18
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from PIL import Image
import matplotlib.pyplot as plt

# Load and preprocess the dataset
train_images_path = "brain_train_image_final.npy"
train_labels_path = "brain_train_label.npy"
test_images_path = "brain_test_image_final.npy"
test_labels_path = "brain_test_label.npy"

# Load the data
final_X_train_modified = np.load(train_images_path)[:, 1, :, :]
final_X_test_modified = np.load(test_images_path)[:, 1, :, :]
train_labels = np.load(train_labels_path)
test_labels = np.load(test_labels_path)

# Normalize and Resize Images using Pillow
def normalize_and_resize(images, target_size=(224, 224)):
    resized_images = []
    for img in images:
        img = Image.fromarray((img * 255).astype(np.uint8))
        img_resized = img.resize(target_size, Image.Resampling.LANCZOS)
        resized_images.append(np.array(img_resized) / 255.0)
    return np.array(resized_images)

final_X_train_resized = normalize_and_resize(final_X_train_modified)
final_X_test_resized = normalize_and_resize(final_X_test_modified)

# Define SimCLR Augmentation Transform
transform_simclr = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
    transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Custom Dataset for SimCLR
class SimCLRDataset(Dataset):
    def __init__(self, images, transform):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        img_1 = self.transform(img)
        img_2 = self.transform(img)
        return img_1, img_2

train_dataset = SimCLRDataset(final_X_train_resized, transform_simclr)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)

# Define SimCLR Model
class SimCLR(nn.Module):
    def __init__(self, base_encoder, projection_dim):
        super(SimCLR, self).__init__()
        self.encoder = base_encoder
        self.projector = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, projection_dim)
        )

    def forward(self, x):
        h = self.encoder(x)
        z = self.projector(h)
        return z

# Define NT-Xent Loss
class NTXentLoss(nn.Module):
    def __init__(self, batch_size, temperature):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def forward(self, z_i, z_j):
        N = z_i.size(0) + z_j.size(0)
        z = torch.cat((z_i, z_j), dim=0)
        sim = torch.matmul(z, z.T) / self.temperature
        mask = ~torch.eye(N, dtype=torch.bool, device=z.device)

        positives = torch.cat([
            torch.diag(sim, z_i.size(0)),
            torch.diag(sim, -z_i.size(0))
        ])

        negatives = sim[mask].view(N, -1)
        logits = torch.cat((positives.unsqueeze(1), negatives), dim=1)
        labels = torch.zeros(N, dtype=torch.long, device=z.device)
        loss = self.criterion(logits, labels) / N
        return loss

# Initialize ResNet-18 Encoder
base_encoder = resnet18(pretrained=True)
base_encoder.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
base_encoder.fc = nn.Identity()

# Initialize SimCLR Model
model = SimCLR(base_encoder, projection_dim=128).to("cuda")
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = NTXentLoss(batch_size=512, temperature=0.5)

# Train SimCLR Model
for epoch in range(100):
    total_loss = 0
    model.train()
    for img_1, img_2 in train_loader:
        img_1, img_2 = img_1.to("cuda"), img_2.to("cuda")
        z_i = model(img_1)
        z_j = model(img_2)

        loss = criterion(z_i, z_j)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch [{epoch+1}/100], Loss: {total_loss/len(train_loader):.4f}")
    
    

# Define Dataset for Classification
class TestDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

# Initialize Training and Test Dataset and DataLoader
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

train_dataset = TestDataset(final_X_train_resized, train_labels, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

test_dataset = TestDataset(final_X_test_resized, test_labels, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Add Classification Head
class ClassificationHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(ClassificationHead, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

classification_head = ClassificationHead(input_dim=512, num_classes=len(np.unique(train_labels))).to("cuda")
optimizer_cls = optim.Adam([
    {"params": model.encoder.parameters(), "lr": 1e-5},
    {"params": classification_head.parameters(), "lr": 1e-3},
])
scheduler_cls = StepLR(optimizer_cls, step_size=10, gamma=0.5)
criterion_cls = nn.CrossEntropyLoss()

# Fine-tune Classification Head
for epoch in range(100):
    model.encoder.train()
    classification_head.train()
    total_loss = 0
    correct = 0
    for img, label in DataLoader(train_dataset, batch_size=128, shuffle=True):
        img, label = img.to("cuda"), label.to("cuda")
        features = model.encoder(img)
        logits = classification_head(features)
        loss = criterion_cls(logits, label)

        optimizer_cls.zero_grad()
        loss.backward()
        optimizer_cls.step()

        total_loss += loss.item()
        correct += (logits.argmax(dim=1) == label).sum().item()

    accuracy = correct / len(train_labels)
    scheduler_cls.step()
    print(f"Epoch [{epoch+1}/100], Loss: {total_loss/len(train_loader):.4f}, Accuracy: {accuracy:.4f}")

# Evaluate on Test Dataset
classification_head.eval()
correct = 0
with torch.no_grad():
    for img, label in DataLoader(test_dataset, batch_size=128, shuffle=False):
        img, label = img.to("cuda"), label.to("cuda")
        features = model.encoder(img)
        logits = classification_head(features)
        correct += (logits.argmax(dim=1) == label).sum().item()

test_accuracy = correct / len(test_labels)
print(f"Test Accuracy: {test_accuracy:.4f}")


Epoch [1/100], Loss: 7.6340
Epoch [2/100], Loss: 6.6477
Epoch [3/100], Loss: 6.4410
Epoch [4/100], Loss: 6.3400
Epoch [5/100], Loss: 6.2405
Epoch [6/100], Loss: 6.1682
Epoch [7/100], Loss: 6.0574
Epoch [8/100], Loss: 5.9719
Epoch [9/100], Loss: 5.9086
Epoch [10/100], Loss: 5.7387
Epoch [11/100], Loss: 5.6313
Epoch [12/100], Loss: 5.4963
Epoch [13/100], Loss: 5.3828
Epoch [14/100], Loss: 5.2629
Epoch [15/100], Loss: 5.0523
Epoch [16/100], Loss: 4.9362
Epoch [17/100], Loss: 4.7863
Epoch [18/100], Loss: 4.6578
Epoch [19/100], Loss: 4.4885
Epoch [20/100], Loss: 4.4764
Epoch [21/100], Loss: 4.3528
Epoch [22/100], Loss: 4.1687
Epoch [23/100], Loss: 4.2206
Epoch [24/100], Loss: 3.9252
Epoch [25/100], Loss: 3.8786
Epoch [26/100], Loss: 3.7276
Epoch [27/100], Loss: 3.6913
Epoch [28/100], Loss: 3.6383
Epoch [29/100], Loss: 3.5958
Epoch [30/100], Loss: 3.5365
Epoch [31/100], Loss: 3.4605
Epoch [32/100], Loss: 3.4075
Epoch [33/100], Loss: 3.3468
Epoch [34/100], Loss: 3.2655
Epoch [35/100], Loss: 3

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import resnet50
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from PIL import Image

# Load and preprocess the dataset
train_images_path = "brain_train_image_final.npy"
train_labels_path = "brain_train_label.npy"
test_images_path = "brain_test_image_final.npy"
test_labels_path = "brain_test_label.npy"

# Load the data
final_X_train_modified = np.load(train_images_path)[:, 1, :, :]
final_X_test_modified = np.load(test_images_path)[:, 1, :, :]
train_labels = np.load(train_labels_path)
test_labels = np.load(test_labels_path)

# Normalize and Resize Images using Pillow
def normalize_and_resize(images, target_size=(224, 224)):
    resized_images = []
    for img in images:
        img = Image.fromarray((img * 255).astype(np.uint8))
        img_resized = img.resize(target_size, Image.Resampling.LANCZOS)
        resized_images.append(np.array(img_resized) / 255.0)
    return np.array(resized_images)

final_X_train_resized = normalize_and_resize(final_X_train_modified)
final_X_test_resized = normalize_and_resize(final_X_test_modified)

# Define SimCLR Augmentation Transform
transform_simclr = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(3/4, 4/3)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))], p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Custom Dataset for SimCLR
class SimCLRDataset(Dataset):
    def __init__(self, images, transform):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        img_1 = self.transform(img)
        img_2 = self.transform(img)
        return img_1, img_2

train_dataset = SimCLRDataset(final_X_train_resized, transform_simclr)
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)

# Define SimCLR Model
class SimCLR(nn.Module):
    def __init__(self, base_encoder, projection_dim):
        super(SimCLR, self).__init__()
        self.encoder = base_encoder
        self.projector = nn.Sequential(
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Linear(2048, projection_dim)
        )

    def forward(self, x):
        h = self.encoder(x)
        z = self.projector(h)
        return z

# Define NT-Xent Loss
class NTXentLoss(nn.Module):
    def __init__(self, temperature):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

    def forward(self, z_i, z_j):
        N = z_i.size(0) + z_j.size(0)
        z = torch.cat((z_i, z_j), dim=0)
        sim = torch.mm(z, z.T) / self.temperature
        sim = torch.nn.functional.softmax(sim, dim=1)

        labels = torch.cat([
            torch.arange(z_i.size(0), device=z.device),
            torch.arange(z_j.size(0), device=z.device)
        ])
        loss = self.criterion(sim, labels)
        return loss

# Initialize ResNet-50 Encoder
base_encoder = resnet50(pretrained=True)
base_encoder.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
base_encoder.fc = nn.Identity()

# Initialize SimCLR Model
model = SimCLR(base_encoder, projection_dim=128).to("cuda")
optimizer = optim.Adam(model.parameters(), lr=3e-4)
criterion = NTXentLoss(temperature=0.5)
scheduler = CosineAnnealingLR(optimizer, T_max=100)

# Train SimCLR Model
for epoch in range(100):
    total_loss = 0
    model.train()
    for img_1, img_2 in train_loader:
        img_1, img_2 = img_1.to("cuda"), img_2.to("cuda")
        z_i = model(img_1)
        z_j = model(img_2)

        loss = criterion(z_i, z_j)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    scheduler.step()
    print(f"Epoch [{epoch+1}/100], Loss: {total_loss/len(train_loader):.4f}")

# Define Dataset for Classification
class TestDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

# Initialize Training and Test Dataset and DataLoader
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

train_dataset = TestDataset(final_X_train_resized, train_labels, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = TestDataset(final_X_test_resized, test_labels, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Add Classification Head
class ClassificationHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(ClassificationHead, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

classification_head = ClassificationHead(input_dim=2048, num_classes=len(np.unique(train_labels))).to("cuda")
optimizer_cls = optim.Adam([
    {"params": model.encoder.parameters(), "lr": 1e-5},
    {"params": classification_head.parameters(), "lr": 3e-4},
])
criterion_cls = nn.CrossEntropyLoss()

# Fine-tune Classification Head
for epoch in range(100):
    model.encoder.train()
    classification_head.train()
    total_loss = 0
    correct = 0
    for img, label in train_loader:
        img, label = img.to("cuda"), label.to("cuda")
        features = model.encoder(img)
        logits = classification_head(features)
        loss = criterion_cls(logits, label)

        optimizer_cls.zero_grad()
        loss.backward()
        optimizer_cls.step()

        total_loss += loss.item()
        correct += (logits.argmax(dim=1) == label).sum().item()

    accuracy = correct / len(train_labels)
    print(f"Epoch [{epoch+1}/100], Loss: {total_loss/len(train_loader):.4f}, Accuracy: {accuracy:.4f}")

# Evaluate on Test Dataset
classification_head.eval()
correct = 0
with torch.no_grad():
    for img, label in test_loader:
        img, label = img.to("cuda"), label.to("cuda")
        features = model.encoder(img)
        logits = classification_head(features)
        correct += (logits.argmax(dim=1) == label).sum().item()

test_accuracy = correct / len(test_labels)
print(f"Test Accuracy: {test_accuracy:.4f}")


