# SimCLR Model Building and Training
This notebook provides a full, reproducible SimCLR (contrastive) training pipeline that adapts to whether your dataset is tabular or image-based.

Primary features:
- Detects cleaned dataset `ztf_image_search_results_full_cleaned.csv` in the workspace root.
- Uses an image-based ResNet encoder when an image-path column is detected (a column whose name contains `image`, `img`, or `filepath`); otherwise uses an MLP encoder for tabular features.
- Implements contrastive augmentations for both modalities, NT-Xent loss, training loop, checkpointing, and a linear evaluation cell.

Run each cell sequentially. Note that the training cells can be long-running.

In [None]:
# Optional: install required packages if missing (uncomment to run).
# Note: install the appropriate torch build for your machine (cpu/cuda). Example (CPU):
# !pip install -q torch torchvision pandas scikit-learn tqdm
# If you need umap or other extras for visualization:
# !pip install -q umap-learn seaborn matplotlib

In [None]:
# Imports and helper functions
import os
import math
import time
from datetime import datetime
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
# For image support if present
from PIL import Image
import torchvision.transforms as T
import torchvision.models as models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

In [None]:
# Configuration
# --- paths ---
DATA_PATH = 'ztf_image_search_results_full_cleaned.csv'
CHECKPOINT_DIR = 'checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # no error if the folder already exists

# --- optimisation hyperparameters ---
BATCH_SIZE, EPOCHS, LR = 256, 50, 1e-3

# --- SimCLR-specific settings ---
PROJ_DIM = 128  # projection head output dim
TEMPERATURE = 0.5  # default temperature used by the NT-Xent loss

# --- reproducibility: seed torch first, then numpy ---
RANDOM_SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(RANDOM_SEED)

In [None]:
# Load dataframe and auto-detect data modality (image vs tabular)
# Raise explicitly rather than `assert`: assertions are stripped when Python runs with -O.
if not os.path.exists(DATA_PATH):
    raise FileNotFoundError(f"Data file {DATA_PATH} not found in workspace.")
df = pd.read_csv(DATA_PATH)
print('Loaded', DATA_PATH, 'with shape', df.shape)

# Detect common image path column names.
# A numeric column (e.g. an image count or image id) cannot hold file paths, so it is
# skipped even when its name matches; otherwise the pipeline would switch to image mode
# and fail when trying to open a number as a file.
_image_name_hints = ('image', 'img', 'filepath')
image_cols = [
    c for c in df.columns
    if any(hint in c.lower() for hint in _image_name_hints)
    and not pd.api.types.is_numeric_dtype(df[c])
]
has_images = len(image_cols) > 0
print('Detected image columns:', image_cols)

# Try to find object id column to group time-series if available
_id_names = ('objectid', 'objid', 'object_id', 'oid', 'sourceid', 'source_id')
id_cols = [c for c in df.columns if c.lower() in _id_names]
obj_id_col = id_cols[0] if id_cols else None
print('Using object id column:', obj_id_col)

## Augmentations and Contrastive Dataset
The dataset returns two augmented views of each sample. For images we use torchvision transforms; for tabular data we use simple numeric augmentations (Gaussian noise, random feature dropout, and multiplicative scale jitter). Tune these augmentations using astronomy domain knowledge.

In [None]:
# Tabular augmentations
def tabular_augment(x, noise_std=0.01, drop_prob=0.05, scale_jitter=0.02):
    """Return a randomly perturbed float32 copy of a numeric feature vector.

    Three perturbations are applied in order: Gaussian noise scaled by the
    magnitude of each value, random zeroing of features, and a small
    multiplicative jitter. The input is never modified.

    Args:
        x: Array-like of numeric features (any shape).
        noise_std: Standard deviation of the relative Gaussian noise.
        drop_prob: Probability of zeroing each feature independently.
        scale_jitter: Standard deviation of the multiplicative jitter around 1.0.

    Returns:
        A new ``float32`` numpy array with the same shape as ``x``.
    """
    # astype always copies, so the caller's data is left untouched
    x = np.asarray(x).astype('float32')
    # gaussian noise proportional to value scale
    noise = np.random.normal(0, noise_std, size=x.shape) * (np.abs(x) + 1e-6)
    x = x + noise
    # random feature dropout
    mask = np.random.rand(*x.shape) > drop_prob
    x = x * mask
    # small multiplicative jitter
    jitter = 1.0 + np.random.normal(0, scale_jitter, size=x.shape)
    x = x * jitter
    # The float64 noise/jitter promote the intermediate results to float64;
    # cast back so the returned dtype matches the float32 the model consumes.
    return x.astype('float32')

# Image augmentations (default to common SimCLR)
def make_image_transform(sz=224):
    """Build the random augmentation pipeline applied to each image view.

    Args:
        sz: Size passed to ``RandomResizedCrop``.

    Returns:
        A ``T.Compose`` that crops, flips, colour-jitters, randomly converts
        to grayscale, converts to a tensor and normalizes per channel.
    """
    # NOTE(review): these look like the usual ImageNet channel statistics;
    # confirm they are appropriate for the images in this dataset.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        T.RandomResizedCrop(sz),
        T.RandomHorizontalFlip(p=0.5),
        T.ColorJitter(0.4, 0.4, 0.4, 0.1),
        T.RandomGrayscale(p=0.2),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]
    return T.Compose(pipeline)

# Contrastive Dataset that returns (x_i, x_j) pair and optionally a label if available
class ContrastiveDataset(Dataset):
    """Dataset yielding two augmented views of each row plus a label.

    Each item is ``(view_1, view_2, label)``. When ``image_col`` is set and the
    row has a non-null path, the image is opened and ``transform`` is applied
    twice. Otherwise the row's numeric features are augmented twice with
    ``tabular_aug`` (or ``tabular_augment`` when none is given).

    Args:
        df: Source dataframe; its index is reset internally.
        feature_cols: Columns used as tabular features. When this and
            ``image_col`` are both ``None``, all numeric columns except
            ``label``/``target``/``class`` are used.
        image_col: Column holding image file paths, or ``None``.
        obj_id_col: Object-id column name; stored but not used by the dataset.
        transform: Callable applied to the opened RGB image for each view.
        tabular_aug: Callable applied to the float32 feature vector for each view.

    The label is taken from the first of ``label``/``target``/``class`` present
    in ``df``; ``-1`` is returned when none exists.
    """

    _LABEL_CANDIDATES = ('label', 'target', 'class')

    def __init__(self, df, feature_cols=None, image_col=None, obj_id_col=None, transform=None, tabular_aug=None):
        self.df = df.reset_index(drop=True)
        self.image_col = image_col
        self.obj_id_col = obj_id_col
        self.transform = transform
        self.tabular_aug = tabular_aug
        # keep labels if present (first matching candidate wins)
        self.label_col = next((c for c in self._LABEL_CANDIDATES if c in self.df.columns), None)
        if feature_cols is None and image_col is None:
            # automatic feature selection: numeric columns, excluding label-like
            # columns so the target never leaks into the inputs
            numeric_cols = self.df.select_dtypes(include=[np.number]).columns.tolist()
            self.feature_cols = [c for c in numeric_cols if c not in self._LABEL_CANDIDATES]
        else:
            self.feature_cols = feature_cols
        # Cache features/labels as arrays once: building a Series per item via
        # `iloc` is slow and can coerce the label to the row's common dtype.
        self._features = None
        if self.feature_cols is not None:
            self._features = self.df[list(self.feature_cols)].to_numpy(dtype='float32')
        self._labels = None
        if self.label_col is not None:
            self._labels = self.df[self.label_col].to_numpy()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        label = self._labels[idx] if self._labels is not None else -1
        path = self.df[self.image_col].iloc[idx] if self.image_col is not None else None
        if path is not None and pd.notna(path):
            if self.transform is None:
                raise ValueError('ContrastiveDataset needs a `transform` to produce image views.')
            # `convert` returns an independent image, so the file handle can be
            # closed right away instead of lingering until garbage collection.
            with Image.open(path) as handle:
                img = handle.convert('RGB')
            x1 = self.transform(img)
            x2 = self.transform(img)
            return x1, x2, label
        if self._features is None:
            raise ValueError(
                f'Row {idx} has no image path and no feature columns were provided to fall back on.'
            )
        x = self._features[idx]
        augment = tabular_augment if self.tabular_aug is None else self.tabular_aug
        x1 = augment(x)
        x2 = augment(x)
        return torch.from_numpy(x1), torch.from_numpy(x2), label

## Model: flexible encoder + projection head
- For images use a ResNet50 backbone (no final FC) and a small MLP projection head.
- For tabular data use a 3-layer MLP encoder and a projection head.

In [None]:
# Encoder and SimCLR model definitions
class TabularEncoder(nn.Module):
    """MLP encoder for tabular inputs.

    Two hidden stages of Linear -> ReLU -> BatchNorm1d are followed by a final
    Linear layer mapping to ``out_dim``.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of both hidden stages.
        out_dim: Size of the produced representation.
    """

    def __init__(self, input_dim, hidden_dim=512, out_dim=256):
        super().__init__()
        stage_widths = [input_dim, hidden_dim, hidden_dim]
        layers = []
        # Hidden stages are created in the same order as they are applied.
        for fan_in, fan_out in zip(stage_widths[:-1], stage_widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.BatchNorm1d(fan_out)]
        layers.append(nn.Linear(hidden_dim, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode a batch ``x`` and return the representation."""
        return self.net(x)

class ProjectionHead(nn.Module):
    """Small MLP mapping encoder features to the contrastive embedding space.

    Args:
        in_dim: Size of the incoming features (also the hidden width).
        proj_dim: Size of the projected output.
    """

    def __init__(self, in_dim, proj_dim=PROJ_DIM):
        super().__init__()
        hidden_layer = nn.Linear(in_dim, in_dim)
        output_layer = nn.Linear(in_dim, proj_dim)
        self.net = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Project a batch of features ``x``."""
        return self.net(x)

class SimCLRModel(nn.Module):
    """Encoder followed by a projection head.

    Args:
        encoder: Module producing features of size ``feat_dim``.
        feat_dim: Feature size emitted by ``encoder``.
        proj_dim: Output size of the projection head.
    """

    def __init__(self, encoder, feat_dim, proj_dim=PROJ_DIM):
        super().__init__()
        self.encoder = encoder
        self.proj = ProjectionHead(feat_dim, proj_dim)

    def forward(self, x):
        """Return ``(features, embedding)``.

        ``features`` is the raw encoder output; ``embedding`` is the projection
        head output L2-normalized along dim 1.
        """
        features = self.encoder(x)
        embedding = F.normalize(self.proj(features), dim=1)
        return features, embedding

In [None]:
# NT-Xent loss (normalized temperature-scaled cross entropy)
def nt_xent_loss(z_i, z_j, temperature=TEMPERATURE):
    """Compute the NT-Xent contrastive loss for two batches of embeddings.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair; every
    other row of the concatenated batch is a negative for it. For each of the
    2N rows the loss is ``-log(exp(s_pos) / sum_{j != self} exp(s_j))`` with
    ``s`` the dot-product similarity divided by ``temperature``; the mean over
    the 2N rows is returned.

    Args:
        z_i: Tensor of shape (N, D), expected to be L2-normalized.
        z_j: Tensor of shape (N, D), expected to be L2-normalized.
        temperature: Scaling applied to the similarities.

    Returns:
        Scalar tensor with the mean loss.
    """
    n = z_i.size(0)
    z = torch.cat([z_i, z_j], dim=0)  # 2N x D
    logits = torch.matmul(z, z.T) / temperature  # 2N x 2N
    # Build helper tensors on the inputs' device instead of the global `device`,
    # so the loss also works when the embeddings live elsewhere.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    # -inf removes each row's similarity with itself from the softmax denominator
    logits = logits.masked_fill(self_mask, float('-inf'))
    # positive pairs: i with i+N and vice versa
    targets = (torch.arange(2 * n, device=z.device) + n) % (2 * n)
    # cross_entropy evaluates -log softmax via log-sum-exp, which is the same
    # quantity as the explicit exp/sum ratio but does not overflow for small
    # temperatures.
    return F.cross_entropy(logits, targets)

## Prepare dataset and model instance
Create the dataset and choose encoder depending on modality.

In [None]:
# Split dataset (stratify if label exists)
label_candidates = [c for c in ('label', 'target', 'class') if c in df.columns]
if label_candidates:
    strat_col = label_candidates[0]
    train_df, val_df = train_test_split(df, test_size=0.1, random_state=RANDOM_SEED, stratify=df[strat_col])
else:
    train_df, val_df = train_test_split(df, test_size=0.1, random_state=RANDOM_SEED)

if has_images:
    sz = 224
    transform = make_image_transform(sz=sz)
    train_ds = ContrastiveDataset(train_df, image_col=image_cols[0], obj_id_col=obj_id_col, transform=transform)
    val_ds = ContrastiveDataset(val_df, image_col=image_cols[0], obj_id_col=obj_id_col, transform=transform)
    # encoder: ResNet50 without final layer.
    # Called without arguments: the default is randomly initialised weights, and this
    # avoids the `pretrained=` keyword that newer torchvision releases deprecate.
    base = models.resnet50()
    # remove fc
    modules = list(base.children())[:-1]

    class ResNetEncoder(nn.Module):
        """ResNet backbone (without its final layer) that returns flattened features."""

        def __init__(self, modules, feat_dim=2048):
            super().__init__()
            self.backbone = nn.Sequential(*modules)
            self.feat_dim = feat_dim

        def forward(self, x):
            x = self.backbone(x)
            x = torch.flatten(x, 1)
            return x

    encoder = ResNetEncoder(modules).to(device)
    feat_dim = encoder.feat_dim
else:
    numeric_cols = train_df.select_dtypes(include=[np.number]).columns.tolist()
    # Drop label columns from the features, and also the object identifier:
    # an id is not a physical measurement and should not be fed to the encoder.
    excluded_cols = {'label', 'target', 'class'}
    if obj_id_col is not None:
        excluded_cols.add(obj_id_col)
    feature_cols = [c for c in numeric_cols if c not in excluded_cols]
    if not feature_cols:
        raise ValueError('No numeric feature columns available for the tabular encoder.')
    print('Tabular feature dim:', len(feature_cols))
    train_ds = ContrastiveDataset(train_df, feature_cols=feature_cols, obj_id_col=obj_id_col)
    val_ds = ContrastiveDataset(val_df, feature_cols=feature_cols, obj_id_col=obj_id_col)
    encoder = TabularEncoder(input_dim=len(feature_cols)).to(device)
    feat_dim = encoder.net[-1].out_features if hasattr(encoder.net[-1], 'out_features') else 256

# With drop_last=True a training set smaller than BATCH_SIZE would yield zero batches;
# clamp the batch size so at least one full batch exists.
train_batch_size = max(1, min(BATCH_SIZE, len(train_ds)))
train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, drop_last=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=0)

model = SimCLRModel(encoder, feat_dim, proj_dim=PROJ_DIM).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
print('Model prepared. Params:', sum(p.numel() for p in model.parameters()))

## Training loop (contrastive pretraining)
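Each step encodes both augmented views, computes the NT-Xent loss, and updates the weights. A checkpoint is saved after every epoch, and the weights with the lowest average training loss are also saved as `simclr_best.pt`.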

In [None]:
# Training loop (lightweight, modify EPOCHS/BATCH_SIZE as needed)
# Fail early with a clear message: an empty loader would otherwise run a no-op
# epoch and then crash with ZeroDivisionError when averaging the loss.
if len(train_loader) == 0:
    raise RuntimeError(
        'train_loader yields no batches; the training set is smaller than the batch size '
        'while drop_last=True. Reduce BATCH_SIZE or provide more data.'
    )

best_loss = float('inf')
for epoch in range(1, EPOCHS+1):
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}/{EPOCHS}')
    for (x1, x2, _) in pbar:
        # Cast to float32 and move to the device. The cast is needed for tabular
        # views and is a no-op for tensors that are already float32.
        x1 = x1.float().to(device)
        x2 = x2.float().to(device)
        optimizer.zero_grad()
        # Encode each view separately; only the normalized projections feed the loss.
        _, z1 = model(x1)
        _, z2 = model(x2)
        loss = nt_xent_loss(z1, z2)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    epoch_loss = running_loss / len(train_loader)
    print(f'Epoch {epoch} average loss: {epoch_loss:.4f}')
    # checkpoint: full training state after every epoch (model + optimizer)
    ckpt_path = os.path.join(CHECKPOINT_DIR, f'simclr_epoch_{epoch}.pt')
    torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict()}, ckpt_path)
    # keep a separate copy of the model weights with the lowest average training loss
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'simclr_best.pt'))
print('Training complete. Best loss:', best_loss)

## Linear evaluation (optional)
If your dataset contains labels (column `label`/`target`/`class`), train a linear classifier on frozen encoder features to evaluate learned representations.

In [None]:
# Linear evaluation: extract features and train a small linear classifier (if labels exist)
label_col = None
for c in ('label','target','class'):
    if c in df.columns:
        label_col = c
        break
if label_col is None:
    print('No label column found; skip linear evaluation.')
else:
    # train a logistic regression on top
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score

    # prepare feature extractors
    model.eval()

    # Evaluate on un-augmented inputs. The training datasets return randomly
    # augmented views, which would make the extracted features (and therefore
    # the reported accuracy) noisy and different on every run.
    if has_images:
        eval_transform = T.Compose([
            T.Resize((sz, sz)),
            T.ToTensor(),
            T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
        ])
        eval_train_ds = ContrastiveDataset(train_df, image_col=image_cols[0], obj_id_col=obj_id_col, transform=eval_transform)
        eval_val_ds = ContrastiveDataset(val_df, image_col=image_cols[0], obj_id_col=obj_id_col, transform=eval_transform)
    else:
        # np.copy acts as an identity "augmentation": features pass through unchanged
        eval_train_ds = ContrastiveDataset(train_df, feature_cols=feature_cols, obj_id_col=obj_id_col, tabular_aug=np.copy)
        eval_val_ds = ContrastiveDataset(val_df, feature_cols=feature_cols, obj_id_col=obj_id_col, tabular_aug=np.copy)

    def extract_features(dataloader):
        """Run the frozen model over `dataloader`; return (encoder features, labels) as numpy arrays."""
        feats, labs = [], []
        with torch.no_grad():
            for x1, x2, y in dataloader:
                # only the first view is used; the encoder output `h` is the representation
                x = x1.float().to(device) if not has_images else x1.to(device)
                h, _ = model(x)
                feats.append(h.cpu().numpy())
                labs.append(np.array(y))
        feats = np.concatenate(feats, axis=0)
        labs = np.concatenate(labs, axis=0)
        return feats, labs

    tr_feats, tr_labs = extract_features(DataLoader(eval_train_ds, batch_size=256, num_workers=0))
    val_feats, val_labs = extract_features(DataLoader(eval_val_ds, batch_size=256, num_workers=0))
    clf = LogisticRegression(max_iter=2000)
    clf.fit(tr_feats, tr_labs)
    preds = clf.predict(val_feats)
    acc = accuracy_score(val_labs, preds)
    print('Linear eval accuracy:', acc)

In [None]:
# Validation metrics & confusion matrix plotting (saves figures to `figures/`)
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, f1_score, accuracy_score, precision_score, recall_score, roc_auc_score, roc_curve
os.makedirs('figures', exist_ok=True)

if label_col is None:
    print('No label column available — skipping detailed validation metrics.')
else:
    # Ensure we have extracted validation features and labels
    try:
        val_feats, val_labs
    except NameError:
        val_feats, val_labs = extract_features(DataLoader(val_ds, batch_size=256, num_workers=0))
    try:
        preds
    except NameError:
        preds = clf.predict(val_feats)
    # If classifier supports predict_proba, get probabilities for ROC AUC when binary
    probs = None
    if hasattr(clf, 'predict_proba'):
        try:
            probs = clf.predict_proba(val_feats)
        except (AttributeError, ValueError) as e:
            print('Could not compute class probabilities:', e)
            probs = None
    # Basic scores
    acc = accuracy_score(val_labs, preds)
    precision = precision_score(val_labs, preds, average='weighted', zero_division=0)
    recall = recall_score(val_labs, preds, average='weighted', zero_division=0)
    f1 = f1_score(val_labs, preds, average='weighted', zero_division=0)
    print(f'Accuracy: {acc:.4f}  Precision: {precision:.4f}  Recall: {recall:.4f}  F1 (weighted): {f1:.4f}')
    # Classification report
    creport = classification_report(val_labs, preds, zero_division=0)
    print('Classification report:\n', creport)
    # Save report to file
    with open('figures/classification_report.txt', 'w') as f:
        f.write('Accuracy: {:.4f}\nPrecision: {:.4f}\nRecall: {:.4f}\nF1 (weighted): {:.4f}\n\n'.format(acc, precision, recall, f1))
        f.write(creport)
    # Confusion matrix
    labels_unique = np.unique(val_labs)
    cm = confusion_matrix(val_labs, preds, labels=labels_unique)
    fig, ax = plt.subplots(figsize=(6,5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax, xticklabels=labels_unique, yticklabels=labels_unique)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    fig.savefig('figures/confusion_matrix.png', bbox_inches='tight')
    plt.show()
    # ROC AUC for binary tasks if probabilities available.
    # The columns of `probs` follow `clf.classes_`, so the task is binary only when the
    # classifier was trained on exactly two classes; both must also occur in the
    # validation labels for the ROC curve to be defined.
    clf_classes = getattr(clf, 'classes_', labels_unique)
    if probs is not None and len(clf_classes) == 2 and len(labels_unique) == 2:
        try:
            pos_label = clf_classes[1]
            pos_prob = probs[:, 1]
            # Binarize against the positive class explicitly so labels that are not
            # {0, 1} (e.g. strings) are handled correctly.
            auc = roc_auc_score(val_labs == pos_label, pos_prob)
            fpr, tpr, _ = roc_curve(val_labs, pos_prob, pos_label=pos_label)
            plt.figure(); plt.plot(fpr, tpr, label=f'ROC AUC = {auc:.3f}'); plt.xlabel('False Positive Rate'); plt.ylabel('True Positive Rate'); plt.title('ROC Curve'); plt.legend(); plt.grid(True)
            plt.savefig('figures/roc_curve.png', bbox_inches='tight')
            plt.show()
        except ValueError as e:
            print('Could not compute ROC AUC:', e)
    print('Saved validation figures and report to `figures/`')

In [None]:
# Model comparison: train several classifiers on frozen features and visualize accuracy/F1
import time
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
os.makedirs('figures', exist_ok=True)

if label_col is None:
    # Without labels there is nothing to fit the classifiers on.
    print('No label column found; skip model comparison.')
else:
    # Ensure features and labels exist (extract if necessary)
    try:
        tr_feats, tr_labs
    except NameError:
        print('Extracting features for train/val for comparison...')
        tr_feats, tr_labs = extract_features(DataLoader(train_ds, batch_size=256, num_workers=0))
        val_feats, val_labs = extract_features(DataLoader(val_ds, batch_size=256, num_workers=0))

    # Define candidate models.
    # Named `classifiers` rather than `models` so the `torchvision.models` import
    # is not shadowed (re-running the model-building cell would otherwise break).
    classifiers = {
        'LogisticRegression': LogisticRegression(max_iter=2000),
        'RandomForest': RandomForestClassifier(n_estimators=200, random_state=RANDOM_SEED, n_jobs=-1),
        'GradientBoosting': GradientBoostingClassifier(n_estimators=100, random_state=RANDOM_SEED),
        'SVM': SVC(kernel='rbf', probability=True, random_state=RANDOM_SEED),
        'KNN': KNeighborsClassifier(n_neighbors=5),
        'MLP': MLPClassifier(hidden_layer_sizes=(256,), max_iter=500, random_state=RANDOM_SEED)
    }

    results = []
    for name, m in classifiers.items():
        print(f'Training {name}...')
        t0 = time.time()
        try:
            m.fit(tr_feats, tr_labs)
        except Exception as e:
            # best effort: report the failure and continue with the remaining candidates
            print(f'Error training {name}:', e)
            continue
        dur = time.time() - t0  # fit time only; prediction is not timed
        preds_tmp = m.predict(val_feats)
        acc_tmp = accuracy_score(val_labs, preds_tmp)
        f1_tmp = f1_score(val_labs, preds_tmp, average='weighted', zero_division=0)
        prec_tmp = precision_score(val_labs, preds_tmp, average='weighted', zero_division=0)
        rec_tmp = recall_score(val_labs, preds_tmp, average='weighted', zero_division=0)
        print(f'{name}: acc={acc_tmp:.4f}, f1={f1_tmp:.4f}, time={dur:.1f}s')
        results.append({'model': name, 'accuracy': acc_tmp, 'f1_weighted': f1_tmp, 'precision': prec_tmp, 'recall': rec_tmp, 'time_s': dur})

    if not results:
        # Sorting an empty frame by 'accuracy' would raise KeyError.
        print('No classifier could be trained; nothing to save or plot.')
    else:
        # Save results and plot
        res_df = pd.DataFrame(results).sort_values('accuracy', ascending=False).reset_index(drop=True)
        res_df.to_csv('figures/model_comparison.csv', index=False)
        plt.figure(figsize=(8,4))
        sns.barplot(data=res_df, x='model', y='accuracy', palette='viridis')
        plt.ylim(0,1)
        plt.title('Model Accuracy Comparison (validation)')
        plt.ylabel('Accuracy')
        plt.xlabel('Model')
        plt.tight_layout()
        plt.savefig('figures/model_comparison_accuracy.png', bbox_inches='tight')
        plt.show()

        plt.figure(figsize=(8,4))
        sns.barplot(data=res_df, x='model', y='f1_weighted', palette='magma')
        plt.ylim(0,1)
        plt.title('Model F1 (weighted) Comparison (validation)')
        plt.ylabel('F1 (weighted)')
        plt.xlabel('Model')
        plt.tight_layout()
        plt.savefig('figures/model_comparison_f1.png', bbox_inches='tight')
        plt.show()

        print('Saved model comparison CSV and figures to figures/')

## Save final encoder and notes
The pretraining creates `simclr_best.pt` in the `checkpoints/` folder. It stores the state dict of the full model (encoder plus projection head); load it into a `SimCLRModel` and use its `encoder` for downstream tasks. Remember: to avoid leakage, always fit scalers or other preprocessors only on training splits when performing downstream supervised training.