In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50
from PIL import Image
from tqdm.auto import tqdm
import gc

class ContrastiveTransform:
    def __init__(self, size=224):
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        return self.train_transform(x), self.train_transform(x)

class ImageDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            return self.transform(image)
        return image

class ProjectionHead(nn.Module):
    def __init__(self, input_dim=2048, hidden_dim=512, output_dim=128):  # Reduced hidden_dim
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.projection(x)

class SimCLR(nn.Module):
    def __init__(self, temperature=0.07, arch='resnet50'):
        super().__init__()
        
        # Use smaller ResNet variant if needed
        self.encoder = resnet50(pretrained=True)
        self.encoder.fc = nn.Identity()
        
        self.projection = ProjectionHead()
        self.temperature = temperature

    @torch.cuda.amp.autocast()  # Enable automatic mixed precision
    def forward(self, x1, x2):
        h1 = self.encoder(x1)
        h2 = self.encoder(x2)
        
        z1 = self.projection(h1)
        z2 = self.projection(h2)
        
        return z1, z2

    def info_nce_loss(self, z1, z2):
        batch_size = z1.shape[0]
        
        z1_norm = F.normalize(z1, dim=1)
        z2_norm = F.normalize(z2, dim=1)
        
        # Process positive pairs in chunks to save memory
        chunk_size = min(batch_size, 128)  # Adjust based on your GPU
        total_loss = 0
        
        for i in range(0, batch_size, chunk_size):
            chunk_end = min(i + chunk_size, batch_size)
            chunk_size_current = chunk_end - i
            
            z1_chunk = z1_norm[i:chunk_end]
            z2_chunk = z2_norm[i:chunk_end]
            
            representations = torch.cat([z1_chunk, z2_chunk], dim=0)
            
            similarity_matrix = torch.matmul(representations, representations.t())
            
            mask = torch.eye(2 * chunk_size_current, dtype=bool, device=z1.device)
            positives = similarity_matrix[~mask].view(2 * chunk_size_current, -1)
            
            pos_mask = torch.zeros_like(positives, dtype=bool)
            pos_mask[:chunk_size_current, chunk_size_current:] = True
            pos_mask[chunk_size_current:, :chunk_size_current] = True
            
            numerator = torch.exp(positives[pos_mask] / self.temperature)
            denominator = torch.sum(torch.exp(similarity_matrix / self.temperature), dim=1)
            
            loss = -torch.log(numerator / denominator[pos_mask])
            total_loss += loss.sum()
        
        return total_loss / (2 * batch_size)

def train_simclr(model, train_loader, optimizer, device, epochs):
    model.train()
    scaler = torch.cuda.amp.GradScaler()  # For mixed precision training
    
    epoch_pbar = tqdm(range(epochs), desc='Training epochs', unit='epoch')
    
    for epoch in epoch_pbar:
        total_loss = 0
        batch_pbar = tqdm(train_loader, desc=f'Epoch {epoch}', 
                         leave=False, unit='batch')
        
        for batch_idx, (x1, x2) in enumerate(batch_pbar):
            # Clear cache periodically
            if batch_idx % 10 == 0:
                torch.cuda.empty_cache()
                gc.collect()
            
            x1, x2 = x1.to(device), x2.to(device)
            
            optimizer.zero_grad()
            
            # Use automatic mixed precision
            with torch.cuda.amp.autocast():
                z1, z2 = model(x1, x2)
                loss = model.info_nce_loss(z1, z2)
            
            # Scale loss and backprop
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            total_loss += loss.item()
            batch_pbar.set_postfix({'loss': f'{loss.item():.4f}'})
            
            # Clear unnecessary tensors
            del x1, x2, z1, z2, loss
        
        avg_loss = total_loss / len(train_loader)
        epoch_pbar.set_postfix({'avg_loss': f'{avg_loss:.4f}'})
        batch_pbar.close()

# Example usage with memory optimizations
if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True  # Optimize CUDA operations
    
    # Setup data with smaller batch size
    transform = ContrastiveTransform()
    dataset = ImageDataset(image_paths=paths, transform=transform)
    train_loader = DataLoader(
        dataset, 
        batch_size=64,  # Reduced batch size
        shuffle=True,
        num_workers=4,
        pin_memory=True  # Faster data transfer to GPU
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimCLR().to(device)
    
    # Use gradient checkpointing to save memory
    model.encoder.train()
    
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

    train_simclr(model, train_loader, optimizer, device, epochs=100)

  @torch.cuda.amp.autocast()  # Enable automatic mixed precision


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacity of 15.77 GiB of which 704.00 KiB is free. Including non-PyTorch memory, this process has 15.76 GiB memory in use. Of the allocated memory 15.07 GiB is allocated by PyTorch, and 339.20 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)