In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import os
import PIL
import matplotlib.pyplot as plt
import numpy as np
import random
import torchvision.models as models
from torch.optim import Adam
from tqdm import tqdm

In [3]:
# Define your dataset class

class ProteinExpressionDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_files = []

        # List all class folders in the data directory
        self.classes = [folder for folder in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, folder))]
        
        
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        


        # List all image files in the class folders
        for cls in self.classes:
            class_path = os.path.join(data_dir, cls)
            class_images = [os.path.join(cls, f) for f in os.listdir(class_path) if f.endswith(('.png', '.jpg', '.jpeg'))]
            self.image_files.extend(class_images)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        # Load the image and extract the class label from the file path
        image_path = os.path.join(self.data_dir, self.image_files[index])
        image = Image.open(image_path).convert('RGB')
        
        # Extract class label from the file path
        class_folder, _ = os.path.split(self.image_files[index])
        label = self.class_to_idx[class_folder]

        # Apply transformations if provided
        if self.transform is not None:
            image1 = self.transform(image)
            image2 = self.transform(image)


        return image, to_pil_image(image1), to_pil_image(image2), label

In [6]:

#defining base transforms 
base_transforms = transforms.Compose([
            transforms.RandomResizedCrop(size=(224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.5, 0.5, 0.5, 0.5),
            transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
            transforms.RandomRotation(degrees=30),
            transforms.ToTensor()
        ])

# dataset = ProteinExpressionDataset(data_dir="data", transform=base_transforms)

In [7]:
class SimCLRModel(nn.Module):
    def __init__(self, base_encoder, projection_dim=128):
        super(SimCLRModel, self).__init__()
        self.encoder = base_encoder
        self.projector = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, projection_dim)
        )

    def forward(self, x):
        features = self.encoder(x)
        projection = self.projector(features)
        return features, projection
    
class SimCLRLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(SimCLRLoss, self).__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, z1, z2):
        # Compute cosine similarity and scale by temperature
        sim_matrix = torch.matmul(z1, z2.T) / self.temperature

        # Construct labels for positive and negative pairs
        labels = torch.arange(sim_matrix.size(0)).to(z1.device)

        # Compute contrastive loss
        loss = self.criterion(sim_matrix, labels)
        return loss

In [9]:
dataset = ProteinExpressionDataset(data_dir="data_small", transform=base_transforms)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)

# Set up the model, optimizer, and loss
resnet_encoder = models.resnet18(pretrained=True)
simclr_model = SimCLRModel(base_encoder=resnet_encoder)
optimizer = Adam(simclr_model.parameters(), lr=0.001)
criterion = SimCLRLoss()

# Training loop
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
simclr_model.to(device)

for epoch in range(num_epochs):
    simclr_model.train()
    total_loss = 0.0

    for images, _, _, _ in tqdm(dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}', dynamic_ncols=True):
        images = images.to(device)

        # Forward pass
        features, projections = simclr_model(images)

        # Compute contrastive loss
        loss = criterion(features, projections)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(dataloader)}")


Epoch 1/10:   0%|          | 0/24 [00:05<?, ?it/s]


RuntimeError: DataLoader worker (pid(s) 7400, 13628) exited unexpectedly