In [None]:
#@title Data download
import kagglehub

# Download the latest version of the Kaggle dataset. `path` is the local root
# directory that IMAGE_DIR / CSV_DIR are built from further down.
dataset_slug = "bvighnesh27/exaggerated-synthetic-corneal-data"
path = kagglehub.dataset_download(dataset_slug)
print("Path to dataset files:", path)

In [None]:
# Recommended minimal installs for this notebook
!pip install -q transformers==4.40.0  # pick an LTS-compatible version
!pip install -q datasets
!pip install -q timm      # Optional: alternate backbones if needed

In [None]:
import os, glob, re, math, time, json, random
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

from transformers import AutoImageProcessor, DPTForDepthEstimation
# AutoImageProcessor works for DPT image preprocessing

In [None]:
# ---------- Utility to pair images and csvs ----------
def _numeric_key(path):
    """Return the first run of digits in the file stem of `path`, or None.

    E.g. "scan_012.png" -> "012". The key is kept as a string, so "012" and
    "12" are different keys.
    """
    match = re.search(r"\d+", Path(path).stem)
    return match.group(0) if match else None


def _find_files(root, exts):
    """Recursively list files under `root` whose extension is in `exts`, case-insensitively, sorted."""
    wanted = {ext.lower().lstrip(".") for ext in exts}
    candidates = glob.glob(os.path.join(root, "**", "*"), recursive=True)
    return sorted(
        p for p in candidates
        if os.path.isfile(p) and Path(p).suffix.lower().lstrip(".") in wanted
    )


def find_pairs(image_dir, csv_dir, image_exts=("png", "jpg", "jpeg")):
    """Pair every image under `image_dir` with the CSV under `csv_dir` that shares its numeric key.

    Both directories are searched recursively, and extensions are matched
    case-insensitively (so "scan_1.PNG" is found as well as "scan_1.png").

    Args:
        image_dir: root directory searched for images.
        csv_dir: root directory searched for CSV files.
        image_exts: image extensions to collect, with or without the leading dot.

    Returns:
        Sorted list of (image_path, csv_path) tuples. Images with no matching
        CSV are skipped with a warning. When several CSVs share a key, a warning
        is printed and the last one in sorted path order is used.
    """
    csv_map = {}
    for csv_path in _find_files(csv_dir, ("csv",)):
        key = _numeric_key(csv_path)
        if key is None:
            continue
        if key in csv_map:
            # Previously this overwrite happened silently, in arbitrary glob order.
            print(f"[WARN] Duplicate CSV key {key!r}: {csv_map[key]} replaced by {csv_path}")
        csv_map[key] = csv_path
    print(f"Found {len(csv_map)} CSV files with a numeric key")

    pairs = []
    for img in _find_files(image_dir, image_exts):
        key = _numeric_key(img)
        if key is not None and key in csv_map:
            pairs.append((img, csv_map[key]))
        else:
            print(f"[WARN] No matching csv found for {img}, skipping")
    return sorted(pairs)

# ---------- Dataset ----------
class CorneaDataset(Dataset):
    """Dataset of (image, point cloud) pairs.

    Returns (image_tensor, target_tensor).
    target_tensor shape: (n_points*3,) flattened [x1,y1,z1, x2,y2,z2, ...]

    Every CSV is parsed and converted to float in ``__init__``, with the same
    helper that ``__getitem__`` uses. A malformed CSV therefore fails at
    construction time rather than part-way through training.
    """
    def __init__(self, pairs, image_processor=None, image_size=384, augment=False):
        """
        Args:
            pairs: sequence of (image_path, csv_path) tuples.
            image_processor: optional HF-style processor called as
                ``processor(images=img, return_tensors="pt")`` and expected to return
                ``pixel_values``. When given, ``image_size`` is not used.
            image_size: side of the square resize used by the no-processor fallback.
            augment: stored but currently unused (no augmentation is implemented).

        Raises:
            ValueError: if ``pairs`` is empty, a CSV has fewer than 3 columns or
                non-numeric coordinates, or the CSVs differ in row count.
        """
        self.pairs = pairs
        self.processor = image_processor
        self.image_size = image_size
        self.augment = augment  # NOTE(review): not read anywhere yet
        if len(self.pairs) == 0:
            raise ValueError("No (image,csv) pairs provided.")
        # Determine n_points; it must be the same for every CSV.
        self._n_points = None
        for _, csv_path in pairs:
            rows = self._load_points(csv_path).shape[0]
            if self._n_points is None:
                self._n_points = rows
            elif self._n_points != rows:
                raise ValueError(f"Inconsistent points: {csv_path} has {rows} but expected {self._n_points}")
        self.target_dim = self._n_points * 3

    @staticmethod
    def _load_points(csv_path):
        """Read `csv_path` as an (n_points, 3) float32 array.

        Uses the columns named x, y, z when all three exist, otherwise the first
        three columns. Raises ValueError for fewer than 3 columns or values that
        cannot be converted to float.
        """
        df = pd.read_csv(csv_path)
        if df.shape[1] < 3:
            raise ValueError(f"CSV {csv_path} must have at least 3 columns")
        if {"x", "y", "z"}.issubset(df.columns):
            values = df[["x", "y", "z"]].values
        else:
            values = df.iloc[:, :3].values
        return np.asarray(values, dtype=np.float32)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img_path, csv_path = self.pairs[idx]
        # The context manager releases the file handle once the pixels are decoded.
        with Image.open(img_path) as im:
            image = im.convert("RGB")
        if self.processor is not None:
            # proc["pixel_values"] has shape (1, 3, H, W); drop the batch dim.
            proc = self.processor(images=image, return_tensors="pt")
            img_tensor = proc["pixel_values"].squeeze(0)
        else:
            # Plain-torch fallback: scale to [0, 1], HWC -> CHW, then resize to a
            # square of image_size (F.interpolate default mode, i.e. nearest).
            pixels = torch.from_numpy(np.asarray(image, dtype=np.float32) / 255.0).permute(2, 0, 1)
            img_tensor = F.interpolate(pixels.unsqueeze(0),
                                       size=(self.image_size, self.image_size)).squeeze(0)

        points = self._load_points(csv_path)
        target = torch.tensor(points.reshape(-1), dtype=torch.float32)  # (n_points*3,)
        return img_tensor, target

# ---------- prepare dataloaders ----------
def prepare_dataloaders(image_dir, csv_dir, image_processor, batch_size=8, image_size=384,
                        train_frac=0.8, seed=42, num_workers=0):
    """Build train/val DataLoaders from the (image, csv) pairs found on disk.

    Pairs are shuffled deterministically with `seed`, then split so that about
    `train_frac` of them are used for training. Both splits are guaranteed to be
    non-empty.

    Returns:
        (train_loader, val_loader, info). `info` has the keys n_points,
        target_dim, n_train, n_val and pairs_all (all pairs in shuffled order).

    Raises:
        ValueError: if train_frac is not strictly between 0 and 1, fewer than
            two pairs are found, or the train and val CSVs disagree on the
            number of points.
    """
    if not 0.0 < train_frac < 1.0:
        raise ValueError(f"train_frac must be strictly between 0 and 1, got {train_frac}")
    pairs = find_pairs(image_dir, csv_dir)
    if len(pairs) == 0:
        raise ValueError("No pairs found. Check directories.")
    if len(pairs) < 2:
        raise ValueError("Need at least 2 (image, csv) pairs to build both a train and a val split.")
    random.Random(seed).shuffle(pairs)
    # int() truncation can leave one split empty on tiny datasets; clamp to [1, len-1].
    n_train = min(max(int(len(pairs) * train_frac), 1), len(pairs) - 1)
    train_pairs = pairs[:n_train]
    val_pairs = pairs[n_train:]

    train_ds = CorneaDataset(train_pairs, image_processor=image_processor, image_size=image_size)
    val_ds = CorneaDataset(val_pairs, image_processor=image_processor, image_size=image_size)
    # Each dataset only checks point-count consistency within its own split.
    if train_ds.target_dim != val_ds.target_dim:
        raise ValueError(
            f"Train CSVs have {train_ds._n_points} points but val CSVs have {val_ds._n_points}")

    # Pinned host memory only helps when batches are copied to a GPU.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
    info = {
        "n_points": train_ds._n_points,
        "target_dim": train_ds.target_dim,
        "n_train": len(train_ds),
        "n_val": len(val_ds),
        "pairs_all": pairs
    }
    return train_loader, val_loader, info

In [None]:
# Directory layout inside the Kaggle download (the next cell summarizes both directories).
IMAGE_DIR = os.path.join(path, "corneal_side", "corneal_side")
CSV_DIR   = os.path.join(path, "data", "data")
# Use DPT's processor to prepare pixel values. When a processor is given,
# CorneaDataset ignores image_size (it only matters for the no-processor fallback).
processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
train_loader, val_loader, info = prepare_dataloaders(IMAGE_DIR, CSV_DIR, processor, batch_size=8, image_size=384)
# info["pairs_all"] holds every (image, csv) path pair. Leave it out so the summary stays readable.
print({key: value for key, value in info.items() if key != "pairs_all"})

In [None]:
# Summarize the image and CSV directories to confirm the paths above are correct.
# A full `ls -R` prints one line per file; show counts and a few sample paths instead.
def summarize_dir(root, n_show=5):
    """Print whether `root` is a directory, how many files it holds (recursively), and the first few."""
    root = Path(root)
    if not root.is_dir():
        print(f"{root}: directory not found")
        return
    files = sorted(p for p in root.rglob("*") if p.is_file())
    print(f"{root}: {len(files)} files")
    for file_path in files[:n_show]:
        print("  ", file_path.relative_to(root))
    if len(files) > n_show:
        print(f"   ... and {len(files) - n_show} more")

print("Contents of IMAGE_DIR:")
summarize_dir(IMAGE_DIR)

print("\nContents of CSV_DIR:")
summarize_dir(CSV_DIR)

In [None]:
class DPTRegressor(nn.Module):
    """Regress a flattened point cloud from DPT's predicted depth map.

    A pretrained DPT depth model maps the image to a depth map. That map is
    adaptively average-pooled to a ``pool_size x pool_size`` grid, and an MLP
    maps the grid to ``target_dim`` values.
    """
    def __init__(self, dpt_model_name="Intel/dpt-hybrid", target_dim=None, pool_mode="avg",
                 pool_size=16):
        """
        dpt_model_name: HF checkpoint passed to DPTForDepthEstimation.from_pretrained
        target_dim: n_points * 3 (required)
        pool_mode: stored on the module but not used. Every value currently gives
            adaptive average pooling (NOTE(review): no 'flatten' path is implemented).
        pool_size: side of the adaptive pooling grid; the MLP input size is
            pool_size**2 (16 -> 256, the previous hard-coded value).

        Raises ValueError if target_dim is missing or pool_size < 1. An explicit
        raise is used instead of `assert`, which is stripped under `python -O`.
        """
        super().__init__()
        if target_dim is None:
            raise ValueError("target_dim needed")
        if pool_size < 1:
            raise ValueError(f"pool_size must be >= 1, got {pool_size}")
        self.dpt = DPTForDepthEstimation.from_pretrained(dpt_model_name)
        self.pool_mode = pool_mode
        self.pool_size = pool_size

        # Pool the single-channel depth map to a fixed grid. The MLP input size then
        # does not depend on the input resolution, and coarse spatial layout is kept.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        mlp_input_dim = pool_size * pool_size  # one channel

        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(mlp_input_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, target_dim)
        )

    def forward(self, pixel_values):
        """
        pixel_values: (B, 3, H, W) preprocessed by the AutoImageProcessor
        returns: (B, target_dim)
        """
        outputs = self.dpt(pixel_values=pixel_values)
        depth = outputs.predicted_depth.unsqueeze(1)  # (B, H, W) -> (B, 1, H, W)
        pooled = self.adaptive_pool(depth)            # (B, 1, pool_size, pool_size)
        return self.mlp(pooled)                       # Flatten + MLP -> (B, target_dim)

In [None]:
import sklearn.metrics as skm

def _r2_or_nan(trues, preds):
    """sklearn R^2 (uniform average over outputs), or NaN when sklearn rejects the input."""
    try:
        return float(skm.r2_score(trues, preds))
    except ValueError:  # e.g. NaN/inf predictions; a bare except would also eat KeyboardInterrupt
        return float("nan")


def _run_epoch(model, loader, device, desc, criterion=None, optimizer=None):
    """Run one pass over `loader`. Trains when `optimizer` is given, otherwise evaluates with no grad.

    Returns (mse, mae, r2) over every sample of the pass. MSE comes from the
    concatenated predictions, so a smaller final batch gets its true weight
    (a plain mean of per-batch losses over-weights it).
    """
    training = optimizer is not None
    model.train(training)
    preds_all, trues_all = [], []
    with torch.set_grad_enabled(training):
        for imgs, targets in tqdm(loader, desc=desc):
            imgs = imgs.to(device)
            targets = targets.to(device)  # (B, target_dim)
            out = model(imgs)             # (B, target_dim)
            if training:
                loss = criterion(out, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            preds_all.append(out.detach().cpu().numpy())
            trues_all.append(targets.detach().cpu().numpy())
    preds_np = np.concatenate(preds_all, axis=0)
    trues_np = np.concatenate(trues_all, axis=0)
    mse = float(np.mean(np.square(preds_np - trues_np), dtype=np.float64))
    mae = float(skm.mean_absolute_error(trues_np, preds_np))
    return mse, mae, _r2_or_nan(trues_np, preds_np)


def train_fit(model, train_loader, val_loader, info, device,
              epochs=10, lr=3e-4, weight_decay=1e-2, ckpt_path="/content/best_vit_decoder.pth"):
    """Train `model` with MSE loss, AdamW and a cosine LR schedule, and keep the best-val checkpoint.

    Args:
        model: module called as model(images) whose output is compared to the targets.
        train_loader, val_loader: loaders yielding (images, targets); neither may be empty.
        info: dataset summary dict, stored verbatim in the checkpoint.
        device: device the batches are moved to.
        epochs, lr, weight_decay: optimisation settings (CosineAnnealingLR with T_max=epochs).
        ckpt_path: file the checkpoint (epoch, model/optimizer state, val_loss, info) is
            written to whenever the validation MSE improves.

    Returns:
        history: dict of per-epoch lists train_loss/val_loss (MSE), train_mae/val_mae
        and train_r2/val_r2.

    Raises:
        ValueError: if either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    best_val = float("inf")

    history = {"train_loss": [], "val_loss": [], "train_mae": [], "val_mae": [], "train_r2": [], "val_r2": []}

    for epoch in range(1, epochs + 1):
        train_mse, train_mae, train_r2 = _run_epoch(
            model, train_loader, device, f"Train E{epoch}", criterion=criterion, optimizer=optimizer)
        val_mse, val_mae, val_r2 = _run_epoch(model, val_loader, device, f"Val E{epoch}")

        print(f"[Epoch {epoch}] Train MSE={train_mse:.5f} MAE={train_mae:.5f} R2={train_r2:.4f} | Val MSE={val_mse:.5f} MAE={val_mae:.5f} R2={val_r2:.4f}")
        for key, value in (("train_loss", train_mse), ("val_loss", val_mse),
                           ("train_mae", train_mae), ("val_mae", val_mae),
                           ("train_r2", train_r2), ("val_r2", val_r2)):
            history[key].append(value)

        scheduler.step()
        # A NaN val_mse never compares below best_val, so a diverged epoch is never saved.
        if val_mse < best_val:
            best_val = val_mse
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_loss": val_mse,
                "info": info
            }, ckpt_path)
            print(f"[Checkpoint] saved {ckpt_path} val_mse={val_mse:.5f}")

    return history


In [None]:
# Use the GPU when one is visible, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
target_dim = info["target_dim"]
model = DPTRegressor(dpt_model_name="Intel/dpt-large", target_dim=target_dim).to(device)

# Stage 1: freeze the whole DPT backbone so only the MLP regression head (model.mlp) learns.
# Module.requires_grad_(False) clears requires_grad on every parameter of the sub-module.
model.dpt.requires_grad_(False)

history = train_fit(
    model, train_loader, val_loader, info, device,
    epochs=10, lr=3e-4, weight_decay=1e-2,
    ckpt_path="/content/best_dpt_cornea.pth",
)

In [None]:
def infer_and_save(model, dataloader, info, device, out_csv="/content/predictions.csv"):
    """Predict every sample in `dataloader` and write one CSV row per sample.

    Columns: image and csv (basenames of the sample's source pair) and pred_flat
    (JSON list of the flattened predictions).

    Rows are matched to ``dataloader.dataset.pairs`` by position, so the loader
    must visit the dataset in order. A running offset is used instead of
    ``i * batch_size``, so a None batch_size does not break the mapping.

    Raises:
        ValueError: if the loader does not use a SequentialSampler (e.g. it was
            built with shuffle=True), or a prediction does not contain
            info["n_points"] * 3 values.
    """
    from torch.utils.data import SequentialSampler

    if not isinstance(dataloader.sampler, SequentialSampler):
        raise ValueError("infer_and_save needs a DataLoader that iterates in dataset order "
                         "(shuffle=False); otherwise predictions get the wrong file names")
    model.eval()
    pairs = dataloader.dataset.pairs  # (img_path, csv_path)
    n_values = info["n_points"] * 3
    rows = []
    offset = 0  # dataset index of the first sample in the current batch
    with torch.no_grad():
        for imgs, _targets in tqdm(dataloader, desc="Infer"):
            preds = model(imgs.to(device)).cpu().numpy()  # (B, target_dim)
            if preds.size != len(preds) * n_values:
                raise ValueError(f"Expected {n_values} values per prediction, got shape {preds.shape}")
            for b, pred_flat in enumerate(preds):
                img_path, csv_path = pairs[offset + b]
                rows.append({
                    "image": os.path.basename(img_path),
                    "csv": os.path.basename(csv_path),
                    "pred_flat": json.dumps(pred_flat.tolist())
                })
            offset += len(preds)
    pd.DataFrame(rows).to_csv(out_csv, index=False)
    print("Saved predictions to", out_csv)

# Example usage:
# Restore the best checkpoint written by train_fit, then write validation-set predictions.
# weights_only=False lets torch.load unpickle arbitrary objects, so only load
# checkpoints this notebook produced itself.
best_ckpt_path = "/content/best_dpt_cornea.pth"
model_ckpt = torch.load(best_ckpt_path, map_location=device, weights_only=False)
model.load_state_dict(model_ckpt["model_state_dict"])
infer_and_save(
    model, val_loader, info,
    device=device,
    out_csv="/content/predictions_val.csv",
)

In [None]:
from mpl_toolkits.mplot3d import Axes3D

def plot_gt_vs_pred(image_path, gt_csv, pred_flat):
    """Show the input image next to a 3-D scatter of ground-truth vs predicted points.

    Args:
        image_path: image to display in the left panel.
        gt_csv: CSV with the ground-truth points. Columns x/y/z are used when all
            three exist, otherwise the first three columns (same rule as CorneaDataset).
        pred_flat: flat sequence [x1, y1, z1, x2, ...] of predicted coordinates.

    Raises:
        ValueError: if the length of pred_flat is not a multiple of 3 (from reshape).
    """
    # The context manager releases the file handle once the pixels are decoded.
    with Image.open(image_path) as im:
        img = im.convert("RGB")
    gt = pd.read_csv(gt_csv)
    if {"x", "y", "z"}.issubset(gt.columns):
        gt_points = gt[["x", "y", "z"]].values
    else:
        gt_points = gt.iloc[:, :3].values
    pred_points = np.asarray(pred_flat).reshape(-1, 3)

    fig = plt.figure(figsize=(12, 5))
    ax_img = fig.add_subplot(121)
    ax_img.imshow(img)
    ax_img.axis("off")
    ax_img.set_title(f"Image: {os.path.basename(image_path)}")

    ax_pts = fig.add_subplot(122, projection="3d")
    # Point counts in the legend make a GT/prediction size mismatch visible at a glance.
    ax_pts.scatter(gt_points[:, 0], gt_points[:, 1], gt_points[:, 2], s=2, c="blue",
                   label=f"GT ({len(gt_points)} pts)")
    ax_pts.scatter(pred_points[:, 0], pred_points[:, 1], pred_points[:, 2], s=2, c="red",
                   label=f"Pred ({len(pred_points)} pts)")
    ax_pts.set_xlabel("x")
    ax_pts.set_ylabel("y")
    ax_pts.set_zlabel("z")
    ax_pts.set_title("Ground truth vs prediction")
    ax_pts.legend()
    plt.show()

# Example: pick first val sample
img_path, csv_path = val_loader.dataset.pairs[0]
# Predict on the fly (alternatively, read pred_flat back from the predictions CSV).
# Put the model in eval mode explicitly so the MLP's Dropout layers are off,
# whatever mode earlier cells left it in.
model.eval()
with torch.no_grad():
    img_tensor, target = val_loader.dataset[0]
    pred_flat = model(img_tensor.unsqueeze(0).to(device)).cpu().numpy().squeeze()
plot_gt_vs_pred(img_path, csv_path, pred_flat)


In [None]:
import matplotlib.pyplot as plt

# Per-epoch losses recorded by train_fit, which numbers epochs from 1. Plot against
# those numbers rather than the list index, which starts at 0.
epochs_axis = range(1, len(history["train_loss"]) + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs_axis, history["train_loss"], label="Train Loss", marker="o")
ax.plot(epochs_axis, history["val_loss"], label="Validation Loss", marker="o")
ax.set_title("Training vs Validation Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE Loss")
ax.legend()
ax.grid(True)
plt.show()
