# SimCLR Downstream Task: UC Merced Land Use Classification

This notebook adapts a pretrained SimCLR (ResNet-101) model for the UC Merced Land Use dataset.
It includes:
1.  **Data Loading**: Automatic download and extraction of the UC Merced Land Use dataset.
2.  **Model Loading**: Loading the pretrained SimCLR weights.
3.  **Fine-Tuning**: Training the full model (pretrained backbone plus a new linear classifier) on the new dataset.
4.  **Visualization**: 3D t-SNE animation of the embeddings.

**Prerequisites**:
- Ensure you have uploaded `simclr_model_RN101.pth` to the Colab runtime (Files tab on the left).

In [None]:
!pip install tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader
import numpy as np
import os
import requests
import zipfile
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import matplotlib.animation as animation
from sklearn.metrics import classification_report, accuracy_score
from sklearn.model_selection import train_test_split
import copy
import time

# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 1. Model Definition (SimCLR)

In [None]:
LARGE_NUM = 1e9

def nt_xent(z: torch.Tensor, perm: torch.Tensor, tau: float) -> torch.Tensor:
    """Compute the NT-Xent (normalized temperature-scaled cross-entropy) loss.

    Args:
        z: 2-D tensor of projections, one row per sample; rows are
            L2-normalized here, so the similarity matrix holds cosine similarities.
        perm: 1-D integer tensor; ``perm[i]`` is the row index of sample ``i``'s
            positive partner and is used as its cross-entropy target.
        tau: Temperature that the similarity logits are divided by.

    Returns:
        Scalar tensor: cross-entropy averaged over all rows.
    """
    unit = F.normalize(z, dim=1)
    logits = torch.mm(unit, unit.t())
    # Push each sample's similarity with itself to a huge negative value so it
    # effectively drops out of the softmax and cannot act as its own positive.
    logits.fill_diagonal_(-LARGE_NUM)
    logits = logits / tau
    return F.cross_entropy(logits, perm)

class SimCLR(nn.Module):
    """SimCLR model: a backbone encoder followed by a two-layer projection head.

    The backbone's ``fc`` layer is replaced (in place) by ``nn.Identity`` so that
    ``backbone(x)`` returns the representation that feeds the projection head.
    The NT-Xent loss is computed on the projection-head output.

    Args:
        backbone: Encoder exposing an ``fc`` layer with ``in_features``
            (e.g. a torchvision ResNet). It is modified in place.
        tau: Temperature passed to ``nt_xent``.
        feat_dim: Output width of the projection head.
    """

    def __init__(self, backbone: nn.Module, tau: float, feat_dim: int = 256):
        super().__init__()
        self.backbone = backbone
        self.tau = tau
        z_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.projection_head = nn.Sequential(
            nn.Linear(z_dim, z_dim, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(z_dim, feat_dim, bias=False),
        )

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return the NT-Xent loss for two views of the same batch.

        The views are concatenated as ``[x1; x2]``. Row ``i`` of ``x1`` and row
        ``i`` of ``x2`` form a positive pair, so sample ``i``'s target is
        ``i + b`` and sample ``i + b``'s target is ``i``.
        """
        b = x1.size(0)
        xp = torch.cat((x1, x2))
        idx = torch.arange(b, device=x1.device)
        perm = torch.cat((idx + b, idx), dim=0)
        h = self.backbone(xp)
        z = self.projection_head(h)
        return nt_xent(z, perm, tau=self.tau)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return backbone features for ``x`` without tracking gradients.

        The module is switched to eval mode for the forward pass, so layers such
        as BatchNorm/Dropout use their inference behaviour. The previous
        train/eval mode is restored afterwards, even on error. Calling this in
        the middle of training therefore does not silently leave the model in
        eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                h = self.backbone(x)
        finally:
            self.train(was_training)
        return h

## 2. Data Loading (UC Merced)

In [None]:
def _resolve_images_dir(target_dir):
    """Return ``target_dir/Images`` if that folder exists, otherwise ``target_dir`` itself."""
    images_dir = os.path.join(target_dir, "Images")
    return images_dir if os.path.isdir(images_dir) else target_dir


def _download_file(url, dest, block_size=1 << 20, timeout=60):
    """Stream ``url`` to ``dest`` with a progress bar, via a temporary ``.part`` file.

    A non-2xx response raises ``requests.HTTPError`` instead of saving an error
    page as if it were the archive. ``dest`` is only created by the final rename,
    after the whole body has been written. An interrupted download therefore
    never leaves a truncated file at ``dest``.
    """
    part_path = dest + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        with open(part_path, 'wb') as file, tqdm(
            total=total_size_in_bytes, unit='iB', unit_scale=True
        ) as progress_bar:
            for chunk in response.iter_content(block_size):
                file.write(chunk)
                progress_bar.update(len(chunk))
    os.replace(part_path, dest)


def download_uc_merced(root):
    """Ensure the UC Merced Land Use dataset is under ``root`` and return its image folder.

    If ``root/UCMerced_LandUse`` already exists, it is reused as-is. Otherwise the
    archive is downloaded to ``root/UCMerced_LandUse.zip`` and extracted into
    ``root``. The download is skipped when a valid zip is already there.

    Returns:
        ``root/UCMerced_LandUse/Images`` when that folder exists, otherwise
        ``root/UCMerced_LandUse``.
    """
    url = "http://weegee.vision.ucmerced.edu/datasets/UCMerced_LandUse.zip"
    target_dir = os.path.join(root, "UCMerced_LandUse")
    if os.path.exists(target_dir):
        print(f"Dataset folder found at {target_dir}")
        return _resolve_images_dir(target_dir)

    os.makedirs(root, exist_ok=True)
    zip_path = os.path.join(root, "UCMerced_LandUse.zip")
    # Reuse a complete archive from an earlier run. A missing file, or one that
    # lacks a zip end-of-archive record (e.g. truncated), is downloaded again.
    if not zipfile.is_zipfile(zip_path):
        print(f"Downloading UC Merced dataset from {url}...")
        _download_file(url, zip_path)

    print("Extracting...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(root)
    return _resolve_images_dir(target_dir)

def get_uc_merced_loader(root, batch_size=64, test_size=0.2, seed=42, num_workers=2):
    """Return stratified ``(train_loader, test_loader, class_names)`` for UC Merced.

    The dataset is fetched into ``root`` with ``download_uc_merced`` and read with
    ``ImageFolder``. Both splits use the same resize / center-crop / normalize
    pipeline; no data augmentation is applied to the training split. Indices are
    split with ``train_test_split`` stratified on the class labels, so every
    class keeps the same proportion in both splits.

    Args:
        root: Directory to download/extract the dataset into.
        batch_size: Batch size for both loaders.
        test_size: Fraction of samples held out for the test loader.
        seed: ``random_state`` for the split, making it reproducible.
        num_workers: Worker processes for each ``DataLoader``.

    Returns:
        Shuffled train loader, unshuffled test loader, and ``dataset.classes``.
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Standard ImageNet channel statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    data_dir = download_uc_merced(root)
    dataset = datasets.ImageFolder(root=data_dir, transform=transform)
    train_idx, test_idx = train_test_split(
        list(range(len(dataset))),
        test_size=test_size,
        random_state=seed,
        stratify=dataset.targets,
    )
    train_ds = torch.utils.data.Subset(dataset, train_idx)
    test_ds = torch.utils.data.Subset(dataset, test_idx)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader, dataset.classes

## 3. Fine-Tuning Setup

In [None]:
class FineTuneModel(nn.Module):
    """A backbone feature extractor followed by a single linear classifier.

    Args:
        backbone: Module mapping a batch of inputs to ``(batch, num_ftrs)``
            features. In this notebook it is the SimCLR ResNet-101 backbone,
            whose ``fc`` layer was replaced by ``nn.Identity``.
        num_classes: Number of output logits.
        num_ftrs: Width of the backbone output. The default of 2048 matches the
            pooled width of torchvision ResNet-50/101/152; pass the matching
            width (e.g. 512 for ResNet-18/34) for other backbones.
    """

    def __init__(self, backbone, num_classes, num_ftrs=2048):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        """Return classification logits ``classifier(backbone(x))``."""
        return self.classifier(self.backbone(x))

def train_model(model, dataloaders, criterion, optimizer, num_epochs=15, device=None):
    """Train ``model`` and restore the weights with the best validation accuracy.

    Each epoch runs a ``'train'`` phase (gradients on, optimizer steps) and then
    a ``'val'`` phase (no gradients) over ``dataloaders[phase]``. Logits are
    computed as ``model.classifier(model.backbone(x))``, so ``model`` must expose
    both sub-modules. After the last epoch, the state_dict from the epoch with
    the highest validation accuracy is loaded back into ``model``.

    Args:
        model: Module with ``backbone`` and ``classifier`` sub-modules.
        dataloaders: Mapping with ``'train'`` and ``'val'`` loaders yielding
            ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer over the parameters being trained.
        num_epochs: Number of train+val epochs.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.

    Returns:
        ``(model, val_acc_history)``: the history holds one 0-dim accuracy
        tensor per epoch.

    Raises:
        ValueError: if the train or val dataset is empty (accuracy would be undefined).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    for phase in ('train', 'val'):
        if len(dataloaders[phase].dataset) == 0:
            raise ValueError(f"'{phase}' dataset is empty")

    since = time.time()
    val_acc_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)

            running_loss = 0.0
            running_corrects = torch.zeros((), dtype=torch.long, device=device)

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    features = model.backbone(inputs)
                    outputs = model.classifier(features)
                    loss = criterion(outputs, labels)
                    preds = outputs.argmax(dim=1)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Criterion is assumed to return a batch mean; rescale to a sum
                # so the epoch loss is a per-sample average.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum()

            num_samples = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects.double() / num_samples
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val':
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                val_acc_history.append(epoch_acc)
        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

## 4. Main Execution

In [None]:
MODEL_PATH = "simclr_model_RN101.pth"
DATA_DIR = "./data"
BATCH_SIZE = 64


def _load_simclr_weights(model, path, map_location):
    """Load ``path`` into ``model``; the file may hold a state_dict or a whole pickled module.

    The file is first read with ``weights_only=True``. Only if the restricted
    unpickler rejects it (e.g. a fully pickled ``nn.Module``) is it re-read with
    ``weights_only=False``. Key or shape mismatches surface directly as the
    ``RuntimeError`` from ``load_state_dict`` instead of being masked.
    """
    import pickle

    try:
        checkpoint = torch.load(path, map_location=map_location, weights_only=True)
    except pickle.UnpicklingError:
        # SECURITY: full unpickling can execute arbitrary code -- only load
        # checkpoints from a trusted source (e.g. one you saved yourself).
        checkpoint = torch.load(path, map_location=map_location, weights_only=False)
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    model.load_state_dict(checkpoint)


# 1. Load Data
train_loader, test_loader, class_names = get_uc_merced_loader(DATA_DIR, BATCH_SIZE)
dataloaders = {'train': train_loader, 'val': test_loader}

# 2. Load Pretrained Model
if not os.path.exists(MODEL_PATH):
    print(f"WARNING: {MODEL_PATH} not found. Please upload it to Colab.")
else:
    print("Loading pretrained SimCLR model...")
    backbone = models.resnet101(weights=None)
    simclr_model = SimCLR(backbone=backbone, tau=0.1)
    _load_simclr_weights(simclr_model, MODEL_PATH, device)

    # 3. Setup Fine-Tuning: unfreeze the whole backbone (full fine-tuning).
    backbone = simclr_model.backbone
    for param in backbone.parameters():
        param.requires_grad = True

    model = FineTuneModel(backbone, num_classes=len(class_names)).to(device)
    criterion = nn.CrossEntropyLoss()
    # Small LR for the pretrained backbone, larger LR for the freshly initialized head.
    optimizer = torch.optim.Adam([
        {'params': model.backbone.parameters(), 'lr': 1e-5},
        {'params': model.classifier.parameters(), 'lr': 1e-3}
    ])

    # 4. Train
    model, hist = train_model(model, dataloaders, criterion, optimizer, num_epochs=15)

## 5. Visualization

In [None]:
def _extract_backbone_features(model, loader, device):
    """Return ``(features, labels)`` numpy arrays from ``model.backbone`` over ``loader``.

    Puts ``model`` in eval mode and disables gradient tracking while extracting.
    """
    model.eval()
    features_list = []
    labels_list = []
    with torch.no_grad():
        for inputs, labels in loader:
            feats = model.backbone(inputs.to(device))
            features_list.append(feats.cpu().numpy())
            labels_list.append(labels.numpy())
    return np.concatenate(features_list, axis=0), np.concatenate(labels_list, axis=0)


def visualize_tsne(model, loader, device, save_path="tsne_animation.gif", class_names=None):
    """Embed backbone features in 3D with t-SNE and save a rotating animation.

    Args:
        model: Module with a ``backbone`` sub-module producing 2-D features.
        loader: DataLoader yielding ``(inputs, labels)`` batches.
        device: Device the inputs are moved to.
        save_path: Output file, written with the Pillow animation writer.
        class_names: Optional sequence indexed by label id, used for legend
            entries. Falls back to ``"Class {id}"`` when omitted.
    """
    print("Extracting features for t-SNE...")
    features, labels = _extract_backbone_features(model, loader, device)

    print("Computing 3D t-SNE...")
    # scikit-learn requires perplexity < n_samples; cap it for small loaders.
    perplexity = min(30, len(features) - 1)
    # The iteration count is left at its default (1000). Its keyword was renamed
    # from ``n_iter`` to ``max_iter`` in scikit-learn 1.5 and ``n_iter`` was
    # later removed, so passing either name fails on some versions.
    tsne = TSNE(n_components=3, perplexity=perplexity, random_state=42)
    projections = tsne.fit_transform(features)

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    classes = np.unique(labels)
    # tab20 has only 20 colors, so sampling it for more classes makes some
    # classes share a color. Stacking tab20b gives 40 distinct colors.
    palette = np.vstack([plt.cm.tab20.colors, plt.cm.tab20b.colors])

    scatters = []
    for i, c in enumerate(classes):
        mask = labels == c
        name = class_names[c] if class_names is not None else f"Class {c}"
        scatters.append(ax.scatter(projections[mask, 0], projections[mask, 1], projections[mask, 2],
                                   label=name, s=20, alpha=0.6, color=palette[i % len(palette)]))
    ax.set_title("3D t-SNE of Fine-Tuned Embeddings")
    # Show the per-class labels in a legend to the right of the 3D axes.
    fig.subplots_adjust(right=0.78)
    fig.legend(handles=scatters, loc="center right", fontsize="small")

    def update(angle):
        ax.view_init(elev=30, azim=angle)
        return scatters

    ani = animation.FuncAnimation(fig, update, frames=np.arange(0, 360, 2), interval=50, blit=False)
    ani.save(save_path, writer='pillow', fps=20)
    print(f"Animation saved to {save_path}")
    plt.show()

# Render the rotating 3D t-SNE of the fine-tuned backbone's test-set features.
# NOTE: `model` only exists if the training cell found MODEL_PATH.
visualize_tsne(model, test_loader, device=device)