In [150]:
# 1. Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
# !ls /content/drive/MyDrive


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [151]:
# 2. Navigate to your project directory
import os

# Replace with your actual path inside Google Drive
# curated_path = '/content/drive/MyDrive/curated_enface'
curated_path = "/content/drive/MyDrive/curated_enface"  # correct spelling!
!ls /content/drive/MyDrive/curated_enface

# save_path = '/content/drive/MyDrive/curated_enface/supreme_vit_mae.pth'
save_path = "/content/drive/MyDrive/curated_enface/xsupreme_vit_mae.pth"

os.chdir(curated_path )

# 3. Confirm you're in the right place
print("Current directory:", os.getcwd())

1002_left   1007_right	1011_left   1020_left	1025_right
1003_right  1008_left	1012_right  1020_right	1026_right
1004_right  1008_right	1015_left   1021_left	1028_right
1005_left   1009_left	1017_left   1022_right	1032_left
1006_left   1010_right	1019_left   1024_left	1032_right
Current directory: /content/drive/MyDrive/curated_enface


In [153]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import math

# =============== Utility: 2D Sin-Cos Positional Embedding ===============
def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """Generate 2D positional embeddings with correct dimensions"""
    assert embed_dim % 2 == 0, "Embedding dimension must be even"
    # Generate grid coordinates for both axes
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)

    # Get embeddings for each axis
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w)

    # Create 2D embeddings through broadcasting
    emb = np.concatenate([
        np.repeat(emb_h[:, None, :], grid_size, axis=1),
        np.repeat(emb_w[None, :, :], grid_size, axis=0)
    ], axis=-1)

    return emb.reshape(-1, embed_dim)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Generate 1D positional embeddings from grid positions"""
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # Frequency calculation

    # Ensure pos is treated as array even with single value
    pos = np.asarray(pos).reshape(-1)
    out = np.outer(pos, omega)  # Vectorized calculation

    return np.concatenate([np.sin(out), np.cos(out)], axis=1)

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# =============== Dataset Loader ===============
class EnfaceMaskedDataset(Dataset):
    def __init__(self, root_dir, image_size=512):
        self.root_dir = root_dir
        self.image_size = image_size
        if not os.path.isdir(self.root_dir):
            raise ValueError(f"Provided root_dir does not exist: {self.root_dir}")
        self.patients = [d for d in os.listdir(self.root_dir) if os.path.isdir(os.path.join(self.root_dir, d))]
        self.augment = transforms.Compose([
            transforms.RandomHorizontalFlip(),
        ])
        self.to_tensor = transforms.ToTensor()
        self.resize = transforms.Resize((image_size, image_size))

    def __len__(self):
        return len(self.patients)

    def __getitem__(self, idx):
        patient = self.patients[idx]
        enface_path = os.path.join(self.root_dir, patient, 'raw enface')
        if not os.path.exists(enface_path):
            raise FileNotFoundError(f"Missing 'raw enface' folder for patient: {patient}")
        enface_file = next((f for f in os.listdir(enface_path) if f.lower().endswith('.jpg')), None)
        if enface_file is None:
            raise FileNotFoundError(f"No JPG image found in: {enface_path}")
        enface_img = Image.open(os.path.join(enface_path, enface_file)).convert('RGB')

        enface_img = self.augment(enface_img)
        enface_tensor = self.to_tensor(self.resize(enface_img)) * 2 - 1
        return enface_tensor

# =============== Relative Positional Embeddings ===============
def get_rel_pos_index(h, w):
    coords = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij"))
    coords_flat = coords.flatten(1)
    rel_coords = coords_flat[:, :, None] - coords_flat[:, None, :]
    rel_coords = rel_coords.permute(1, 2, 0).contiguous()
    rel_coords[:, :, 0] += h - 1
    rel_coords[:, :, 1] += w - 1
    rel_coords[:, :, 0] *= 2 * w - 1
    return rel_coords.sum(-1)

class RelativeAttention(nn.Module):
    def __init__(self, dim, num_heads, grid_size):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)

        # Generate relative index for expanded grid
        self.rel_index = get_rel_pos_index(grid_size, grid_size)
        self.rel_bias = nn.Parameter(
            torch.zeros((2 * grid_size - 1) ** 2, num_heads)
        )
        nn.init.trunc_normal_(self.rel_bias, std=0.02)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Initialize attention scores
        attn = (q @ k.transpose(-2, -1)) * self.scale  # This line was missing!

        # Handle relative bias
        rel_bias = self.rel_bias[self.rel_index.to(x.device)]
        if N > self.rel_index.shape[0]:
            pad_size = N - self.rel_index.shape[0]
            rel_bias = F.pad(rel_bias, (0,0,0,pad_size,0,pad_size), value=0)
        rel = rel_bias.unsqueeze(0).permute(0, 3, 1, 2)

        attn = attn + rel[:, :, :N, :N]  # Now attn is properly initialized
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)

class Block(nn.Module):
    def __init__(self, dim, num_heads, grid_size, mlp_ratio=4., norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = RelativeAttention(dim, num_heads, grid_size)
        self.norm2 = norm_layer(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

# =============== MAE Model ===============
class MAEModel(nn.Module):
    def __init__(
        self,
        img_size=512,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4.,
        norm_layer=nn.LayerNorm,
        mask_ratio=0.75
    ):
        super().__init__()
        # Encoder parameters
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.mask_ratio = mask_ratio

        # Save decoder embed dimension as instance variable
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_num_heads = decoder_num_heads

        # Size info
        self.img_size = img_size
        self.patch_embed = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2
        print(f"Grid size: {self.grid_size}x{self.grid_size}, Num patches: {self.num_patches}")

        # Define patch dimension for reconstruction
        self.patch_dim = patch_size * patch_size * in_chans
        print(f"Patch dim: {self.patch_dim}")

        # Encoder embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, 1 + self.num_patches, embed_dim),
            requires_grad=False
        )  # fixed sin-cos embedding

        # Encoder blocks
        self.blocks = nn.ModuleList([
            Block(
                embed_dim,
                num_heads,
                self.grid_size,  # Use original grid_size without +1
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer
            )
            for _ in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Decoder embeddings
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, 1 + self.num_patches, decoder_embed_dim),
            requires_grad=False
        )

        # Decoder blocks
        self.decoder_blocks = nn.ModuleList([
            Block(
                decoder_embed_dim,
                decoder_num_heads,
                self.grid_size,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer
            )
            for _ in range(decoder_depth)
        ])
        self.decoder_norm = norm_layer(decoder_embed_dim)

        # Decoder prediction head
        self.decoder_pred = nn.Linear(decoder_embed_dim, self.patch_dim, bias=True)

        # Init weights
        self.initialize_weights()

    def initialize_weights(self):
        # Encoder positional embedding
        pos_embed = get_2d_sincos_pos_embed(self.embed_dim, self.grid_size)
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
        pe_token = torch.zeros(1, 1, self.embed_dim)
        self.pos_embed.data.copy_(torch.cat([pe_token, pos_embed], dim=1))

        # Decoder positional embedding (same pattern)
        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_embed_dim, self.grid_size)
        decoder_pos_embed = torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
        decoder_pe_token = torch.zeros(1, 1, self.decoder_embed_dim)
        self.decoder_pos_embed.data.copy_(torch.cat([decoder_pe_token, decoder_pos_embed], dim=1))

        # Initialize cls_token and mask_token
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.mask_token, std=0.02)

        # Initialize linear layers
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def patchify(self, imgs):
        """Convert images into patches"""
        # imgs: [B, 3, H, W]
        p = self.patch_size
        h = w = self.img_size // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = x.permute(0, 2, 4, 3, 5, 1)  # [B, h, w, p, p, 3]
        patches = x.reshape(shape=(x.shape[0], h * w, p * p * 3))  # [B, h*w, p*p*3]
        return patches

    def unpatchify(self, patches):
        """Restore patches back to images"""
        # patches: [B, L, p*p*3]
        p = self.patch_size
        h = w = self.img_size // p
        B = patches.shape[0]

        x = patches.reshape(shape=(B, h, w, p, p, 3))
        x = x.permute(0, 5, 1, 3, 2, 4)  # [B, 3, h, p, w, p]
        imgs = x.reshape(shape=(B, 3, h * p, w * p))  # [B, 3, H, W]
        return imgs

    def random_masking(self, x, mask_ratio):
        """Random masking for MAE"""
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        # Generate random noise and sort to identify indices to keep
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # Sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)  # indices to restore original order

        # Keep the first len_keep indices
        ids_keep = ids_shuffle[:, :len_keep]

        # Generate a binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask.scatter_(dim=1, index=ids_keep, value=0)

        # Keep tokens indicated by the mask
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
        )

        return x_masked, mask, ids_restore

    def forward_encoder(self, x):
        # Convert image to patches: [B, 3, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)

        # Add positional embedding (without cls token)
        x = x + self.pos_embed[:, 1:, :]

        # Apply masking
        x, mask, ids_restore = self.random_masking(x, self.mask_ratio)

        # Append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Apply transformer blocks
        for blk in self.blocks:
            x = blk(x)

        # Apply final normalization
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # Embed from encoder to decoder dimensions
        x = self.decoder_embed(x)

        # Add position embedding for decoder
        x = x + self.decoder_pos_embed[:, :x.shape[1], :]

        # Project encoder tokens to decoder tokens
        # Note that the cls token is at position 0
        x_vis = x[:, 1:, :]  # Remove cls token
        B, N_vis, D = x_vis.shape

        # Prepare mask tokens
        mask_tokens = self.mask_token.repeat(B, self.num_patches - N_vis, 1)

        # Concatenate visible tokens with mask tokens
        x_ = torch.cat([x_vis, mask_tokens], dim=1)

        # Unshuffle to put tokens back in original order
        x_ = torch.gather(
            x_, dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, D)
        )

        # Append cls token
        x = torch.cat([x[:, :1, :], x_], dim=1)

        # Apply decoder blocks
        for blk in self.decoder_blocks:
            x = blk(x)

        # Apply final normalization and prediction head
        x = self.decoder_norm(x)

        # Predict only on patches (not cls token)
        x = self.decoder_pred(x[:, 1:, :])

        return x

    def forward(self, imgs):
        # Move images to the correct device
        imgs = imgs.to(self.cls_token.device)

        # Encode
        latent, mask, ids_restore = self.forward_encoder(imgs)

        # Decode
        pred = self.forward_decoder(latent, ids_restore)

        # Calculate target patches
        target = self.patchify(imgs)

        # Calculate loss (L1 loss on masked patches)
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [B, L], mean loss per patch

        # Apply mask: only compute loss on masked patches
        mask = mask.to(loss.device)  # ensure on same device
        loss = (loss * mask).sum() / mask.sum()  # mean loss on masked patches

        return loss, pred, mask

def train_mae(curated_path, batch_size=2, num_epochs=200,
              patch_size=32, embed_dim=768, decoder_embed_dim=512,
              mask_ratio=0.75):
    """Train MAE model"""
    print(f"Starting MAE training on {curated_path}")
    print(f"Parameters: batch_size={batch_size}, patch_size={patch_size}, embed_dim={embed_dim}")

    # Create dataset and dataloader
    dataset = EnfaceMaskedDataset(curated_path)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0  # To avoid CUDA issues
    )

    # Create model
    model = MAEModel(
        img_size=512,
        patch_size=patch_size,
        in_chans=3,
        embed_dim=embed_dim,
        depth=8,
        num_heads=12,
        decoder_embed_dim=decoder_embed_dim,
        decoder_depth=4,
        decoder_num_heads=16,
        mlp_ratio=4,
        mask_ratio=mask_ratio
    ).to(device)

    # Check model device
    print(f"Model is on device: {next(model.parameters()).device}")

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=1.5e-6,
        betas=(0.9, 0.95),
        weight_decay=0.05
    )

    # Learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=1e-7
    )

    # Main training loop
    model.train()

    for epoch in range(num_epochs):
        print(f"Epoch {epoch}/{num_epochs-1}")
        epoch_loss = 0.0
        num_batches = 0

        for batch_idx, imgs in enumerate(dataloader):
            # Forward pass
            loss, _, _ = model(imgs)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Clip gradients for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Optimizer step
            optimizer.step()

            # Track metrics
            batch_loss = loss.item()
            epoch_loss += batch_loss
            num_batches += 1

            # Print progress
            if batch_idx % 10 == 0:
                print(f"  Batch {batch_idx}: Loss = {batch_loss:.4f}")

        # Update learning rate
        lr_scheduler.step()

        # End of epoch tracking
        avg_epoch_loss = epoch_loss / num_batches if num_batches > 0 else 0
        print(f"Epoch {epoch}: Loss = {avg_epoch_loss:.4f}, LR = {lr_scheduler.get_last_lr()[0]:.6f}")

        # Save model checkpoint
        if epoch % 20 == 0 or epoch == num_epochs - 1:
            save_path = f"mae_model_epoch_{epoch}.pt"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_epoch_loss,
            }, save_path)
            print(f"Model saved to {save_path}")

    print("Training complete!")
    return model

# Run training if script is executed directly
if __name__ == "__main__":
    # try:
    curated_path = curated_path
    print(f"Using curated_path: {curated_path}")
    model = train_mae(
        curated_path,
        batch_size=2,
        num_epochs=200,
        patch_size=32,
        embed_dim=768,
        decoder_embed_dim=512,
        mask_ratio=0.75
    )
    # except NameError:
    #     print("curated_path not found. Please define curated_path before running.")

Using device: cuda:0
Using curated_path: /content/drive/MyDrive/curated_enface
Starting MAE training on /content/drive/MyDrive/curated_enface
Parameters: batch_size=2, patch_size=32, embed_dim=768
Grid size: 16x16, Num patches: 256
Patch dim: 3072
Model is on device: cuda:0
Epoch 0/199
  Batch 0: Loss = 0.3120
  Batch 10: Loss = 0.2732
Epoch 0: Loss = 0.3300, LR = 0.000001
Model saved to mae_model_epoch_0.pt
Epoch 1/199
  Batch 0: Loss = 0.2612
  Batch 10: Loss = 0.2372
Epoch 1: Loss = 0.2923, LR = 0.000001
Epoch 2/199
  Batch 0: Loss = 0.2427
  Batch 10: Loss = 0.3047
Epoch 2: Loss = 0.2685, LR = 0.000001
Epoch 3/199
  Batch 0: Loss = 0.2175
  Batch 10: Loss = 0.2054
Epoch 3: Loss = 0.2492, LR = 0.000001
Epoch 4/199
  Batch 0: Loss = 0.2796
  Batch 10: Loss = 0.1879
Epoch 4: Loss = 0.2333, LR = 0.000001
Epoch 5/199
  Batch 0: Loss = 0.4436
  Batch 10: Loss = 0.1936
Epoch 5: Loss = 0.2183, LR = 0.000001
Epoch 6/199
  Batch 0: Loss = 0.4363
  Batch 10: Loss = 0.1611
Epoch 6: Loss = 0.20