In [None]:
# Offline install for Kaggle: --no-index stops pip from contacting PyPI and --find-links
# makes it install from the package files in the attached "surface-package-scraper" dataset.
# --no-deps installs only the listed packages, so their dependencies must already be present
# in the environment. Everything after the shell `#` (the numpy/scipy pins) is a comment and is
# NOT passed to pip.
!pip install --no-index --find-links="/kaggle/input/surface-package-scraper" -q pytorch_lightning monai albumentations imagecodecs --no-deps # "numpy==1.26.4" "scipy==1.15.3"
# Remove TensorFlow; per the original note its presence triggers an AttributeError
# (presumably when another library imports it — TODO confirm which import fails).
!pip uninstall -q -y tensorflow  # preventing AttributeError

In [None]:
# Cell 1
import os

# Configure PyTorch's CUDA caching allocator early, before any CUDA work is done:
# expandable segments can reduce fragmentation-related out-of-memory errors.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

1 - 384


40 - 256

750 - 320

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
import gc
import warnings
from typing import Tuple, Optional, Dict, List, Callable
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
import os
warnings.filterwarnings("ignore")  # suppress every warning category for the whole session

# --- Input data locations ----------------------------------------------------
root_dir = "/kaggle/input/vesuvius-surface-npz"
images_path, mask_path = (
    "/kaggle/input/vesuvius-surface-npz/train_images",
    "/kaggle/input/vesuvius-surface-npz/train_labels",
)
test_images_path = "/kaggle/input/vesuvius-challenge-surface-detection/test_images"

# --- Outputs (checkpoints and logs) ------------------------------------------
OUTPUT_DIR = "/kaggle/working/checkpoints"
# Create the checkpoint/log directory (and any missing parents); no-op if it already exists.
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
gc.collect()

In [None]:
class Config:
    """Plain container for the notebook's tunable hyper-parameters."""

    def __init__(self):
        # --- data pipeline ---
        self.target_shape = (128, 128, 128)  # spatial size volumes are resized to
        self.batches = 1                     # samples per batch
        self.num_workers = 2                 # DataLoader worker processes
        self.val_split = 0.2                 # fraction of files held out for validation
        # --- optimiser (intended for AdamW) ---
        self.lr = 1e-4
        self.weight_decay = 2e-4
        # --- schedule ---
        self.max_epochs = 50


config = Config()

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import tifffile
import pytorch_lightning as pl
import random
from monai import transforms as MT

# --- 1. Modified Dataset ---
class VesuviusDataset(Dataset):
    """Map-style dataset yielding ``(image, mask, file_name)`` triples of 3D volumes.

    Images are read from ``images_dir`` (train/val modes) or ``test_dir`` (test mode);
    masks are read from ``masks_dir`` under the same file name, but only in train/val
    mode. Volumes gain a leading channel axis: images become float16 scaled by 1/255,
    masks become int64. An optional dictionary transform (e.g. MONAI ``Compose`` of
    ``*d`` transforms) is applied to ``{"image": ..., "label": ...}``.

    Args:
        images_id: File names, relative to the source directory, to serve.
        mode: ``"train"``, ``"val"`` or ``"test"``. Masks are loaded for train/val only.
        transform: Callable taking and returning a dict with ``"image"`` (and ``"label"``).
        images_dir: Directory of train/val images. Defaults to the notebook-level
            ``images_path``.
        masks_dir: Directory of label volumes. Defaults to the notebook-level ``mask_path``.
        test_dir: Directory of test images. Defaults to the notebook-level
            ``test_images_path``.

    Note:
        In test mode the mask slot is ``None``; torch's default collate cannot batch
        ``None``, so test loaders need a pass-through ``collate_fn``.
    """

    def __init__(self,
                 images_id: list,
                 mode: str = "train",
                 transform=None,
                 images_dir=None,
                 masks_dir=None,
                 test_dir=None):
        self.images_id = images_id
        self.mode = mode
        # Notebook-level paths are only consulted when no directory is passed in
        # (conditional expressions are lazy, so explicit arguments never touch them).
        self.images_dir = Path(images_dir if images_dir is not None else images_path)
        self.masks_dir = Path(masks_dir if masks_dir is not None else mask_path)
        self.test_dir = Path(test_dir if test_dir is not None else test_images_path)
        self.transform = transform

    def __len__(self):
        return len(self.images_id)

    def __getitem__(self, idx):
        """Load, optionally transform, and return ``(image, mask_or_None, file_name)``."""
        file_name = self.images_id[idx]

        # Test volumes live in their own directory; train/val share images_dir.
        source_dir = self.test_dir if self.mode == "test" else self.images_dir
        data = {"image": self.load_volume(source_dir / file_name)}

        if self.mode in ("train", "val"):
            data["label"] = self.load_volume(self.masks_dir / file_name, is_mask=True)

        if self.transform is not None:
            # Dictionary-based transforms keep image and label spatially in sync.
            data = self.transform(data)

        return data["image"], data.get("label"), file_name

    def load_volume(self, file_path, is_mask=False):
        """Read a ``.npz`` / ``.tif`` / ``.tiff`` / ``.npy`` volume as a tensor.

        Args:
            file_path: Path to the volume file. Any suffix other than npz/tif/tiff
                is read with ``np.load``.
            is_mask: If True return an int64 label tensor, otherwise a float16 image
                tensor scaled by 1/255.

        Returns:
            Tensor of shape ``(1, *volume.shape)``.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If an ``.npz`` archive contains no arrays.
        """
        path_obj = Path(file_path)
        if not path_obj.exists():
            raise FileNotFoundError(f"Could not find file: {path_obj}")

        suffix = path_obj.suffix.lower()
        if suffix == ".npz":
            # Use the archive as a context manager so its file handle is closed;
            # indexing materialises the array before the handle goes away.
            with np.load(str(path_obj)) as archive:
                if not archive.files:
                    raise ValueError(f"Archive contains no arrays: {path_obj}")
                data = archive[archive.files[0]]
        elif suffix in (".tif", ".tiff"):
            data = tifffile.imread(str(path_obj))
        else:
            data = np.load(str(path_obj))

        tensor = torch.from_numpy(data)

        if is_mask:
            return tensor.long().unsqueeze(0)
        # Scale to roughly [0, 1] (assumes 8-bit intensities) and add the channel axis.
        return tensor.half().div_(255.0).unsqueeze(0)

In [None]:
import os
# Training volume file names: regular files directly inside images_path.
# A flat listing is used on purpose: os.walk would also descend into sub-directories
# and keep only the last directory visited. Sorting makes the list, and therefore the
# seeded train/val split built from it, independent of filesystem listing order.
# Raises FileNotFoundError right here if images_path is missing.
with os.scandir(images_path) as entries:
    images_id = sorted(entry.name for entry in entries if entry.is_file())


In [None]:
def custom_collate(batch):
    """Pass-through collate: return the list of per-sample tuples unchanged.

    No stacking happens, so samples keep their individual shapes and ``None`` mask
    slots are tolerated. The intent is to do the heavy interpolation/stacking later
    on the GPU instead of in the CPU DataLoader workers.
    """
    samples = batch
    return samples

In [None]:
class DesuviusDataModule(pl.LightningDataModule):
    """Lightning data module for the Vesuvius surface volumes.

    Holds the MONAI dictionary transforms and builds train/val ``VesuviusDataset``
    instances from a seeded split of the file list. Training uses resize plus
    flip/rotation/intensity/noise augmentation; validation and test only resize.

    Args:
        num_workers: Worker processes per DataLoader.
        batches: Batch size used by both loaders.
        target_shape: Spatial size every volume is resized to. ``None`` (default)
            falls back to ``config.target_shape``. Keep VRAM limits in mind.
        val_split: Fraction of files held out for validation.
        images_id: File names to split. ``None`` (default) falls back to the
            notebook-level ``images_id`` list.
    """

    def __init__(self,
                 num_workers=4,
                 batches=8,
                 target_shape=None,
                 val_split=0.2,
                 images_id=None):
        super().__init__()
        self.train_img_dir = images_path
        self.test_img_dir = test_images_path
        self.mask_img_dir = mask_path
        self.num_workers = num_workers
        self.batches = batches
        # An explicit argument wins; otherwise use the shared Config value.
        self.target_shape = target_shape if target_shape is not None else config.target_shape
        self.val_split = val_split
        self.images_id = images_id

        # Transforms run inside Dataset.__getitem__, i.e. on the CPU in loader workers.
        # EnsureTyped(track_meta=False) returns plain tensors, which saves memory.
        self.train_transforms = MT.Compose([
            MT.Resized(keys=["image", "label"], spatial_size=self.target_shape, mode=["trilinear", "nearest"]),
            MT.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            MT.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
            MT.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
            MT.RandRotated(keys=["image", "label"], range_x=0.1, range_y=0.1, range_z=0.1, prob=0.3, keep_size=True, mode=["bilinear", "nearest"]),
            MT.RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
            MT.RandGaussianNoised(keys=["image"], prob=0.3, mean=0.0, std=0.01),
            MT.EnsureTyped(keys=["image", "label"], track_meta=False)
        ])

        self.val_transforms = MT.Compose([
            MT.Resized(keys=["image", "label"], spatial_size=self.target_shape, mode=["trilinear", "nearest"]),
            MT.EnsureTyped(keys=["image", "label"], track_meta=False)
        ])

        self.test_transforms = MT.Compose([
            MT.Resized(keys=["image"], spatial_size=self.target_shape, mode=["trilinear"]),
            MT.EnsureTyped(keys=["image"], track_meta=False)
        ])

    def setup(self, stage=None):
        """Split the file list into train/val datasets.

        The list is copied and sorted before a seeded shuffle, so the split is the same
        however many times Lightning calls ``setup`` (it calls it per stage) and whatever
        order the filesystem listed files in. The caller's list is never mutated.

        NOTE(review): this split differs from the one produced before this change
        (which depended on listing order and on the number of ``setup`` calls).
        Checkpoints trained under the old split may see some of their training files
        in validation.

        Raises:
            ValueError: If there are no files to split.
        """
        source_ids = self.images_id if self.images_id is not None else images_id
        if not source_ids:
            raise ValueError("No training files found; check images_path / images_id.")

        all_files = sorted(source_ids)
        # Local RNG: same seed and shuffle algorithm as random.seed(32) + random.shuffle,
        # without reseeding the global `random` module as a side effect.
        random.Random(32).shuffle(all_files)

        split_idx = int(len(all_files) * (1 - self.val_split))
        train_files = all_files[:split_idx]
        val_files = all_files[split_idx:]

        print(f"Total files: {len(all_files)}")
        print(f"Train files: {len(train_files)}")
        print(f"Val files: {len(val_files)}")

        self.train_dataset = VesuviusDataset(
            images_id=train_files,
            images_dir=self.train_img_dir,
            masks_dir=self.mask_img_dir,
            mode="train",
            transform=self.train_transforms,
        )

        self.val_dataset = VesuviusDataset(
            images_id=val_files,
            images_dir=self.train_img_dir,
            masks_dir=self.mask_img_dir,
            mode="val",
            transform=self.val_transforms,
        )
        # A test dataset would be built the same way with self.test_transforms.

    def train_dataloader(self) -> DataLoader:
        """Shuffled training loader; default collation stacks samples into batches."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batches,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,  # faster host-to-GPU copies of CPU-transformed batches
            persistent_workers=bool(self.num_workers > 0),
        )

    def val_dataloader(self) -> DataLoader:
        """Deterministic (unshuffled) validation loader."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batches,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=bool(self.num_workers > 0),
        )

In [None]:
# Wire the Config values in explicitly. Otherwise the class defaults (4 workers,
# batch size 8) silently override config.num_workers / config.batches / config.val_split.
dataModule = DesuviusDataModule(
    num_workers=config.num_workers,
    batches=config.batches,
    target_shape=config.target_shape,
    val_split=config.val_split,
)
dataModule.setup()

In [None]:
# `train_dataloader` is a method: call it to build the loader, then summarise it.
train_loader_preview = dataModule.train_dataloader()
print(
    f"Train loader: {len(train_loader_preview)} batches per epoch "
    f"(batch_size={train_loader_preview.batch_size}, num_workers={train_loader_preview.num_workers})"
)

In [None]:
import matplotlib.pyplot as plt

# Pull one (already transformed/augmented) training batch and show the central slice
# along each spatial axis, image next to mask.
train_loader = dataModule.train_dataloader()
images_batch, masks_batch, frag_ids = next(iter(train_loader))

# Default collation stacks samples into (B, C, D, H, W); inspect the first sample.
raw_img_tensor = images_batch[0]
raw_mask_tensor = masks_batch[0]
frag_id = frag_ids[0]

print(f"ID: {frag_id}")
print(f"Tensor Shape (with channel): {raw_img_tensor.shape}")

# Transforms were applied inside the Dataset; keep channel 0 and move to NumPy.
img = raw_img_tensor[0].detach().cpu().numpy()   # (D, H, W)
msk = raw_mask_tensor[0].detach().cpu().numpy()  # (D, H, W)

print(f"Plotting Shape: {img.shape}")
print(f"Image Range: {img.min():.2f} - {img.max():.2f}")

# Mask labels go up to 2 (the loss's ignore index). Pin the colour scale so every
# slice uses the same mapping; per-slice autoscaling can make 1 and 2 look identical.
MASK_VMIN, MASK_VMAX = 0, 2

# (plane name, dimension name, array axis) for each row of the figure.
views = [
    ("Axial", "Depth", 0),
    ("Coronal", "Height", 1),
    ("Sagittal", "Width", 2),
]

fig, axes = plt.subplots(3, 2, figsize=(10, 15))
for row, (plane, dim_name, axis) in enumerate(views):
    mid = img.shape[axis] // 2
    axes[row, 0].imshow(np.take(img, mid, axis=axis), cmap='gray')
    axes[row, 0].set_title(f'{plane} ({dim_name}={mid}) - Image')
    axes[row, 1].imshow(np.take(msk, mid, axis=axis), cmap='gray', vmin=MASK_VMIN, vmax=MASK_VMAX)
    axes[row, 1].set_title(f'{plane} ({dim_name}={mid}) - Mask')

plt.tight_layout()
plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from monai.losses import TverskyLoss

class surfaceSegementation(pl.LightningModule):
    """LightningModule for 3D two-class segmentation with an "ignore" label.

    The objective is ``0.5 * CrossEntropy + 0.5 * Tversky``. Voxels labelled
    ``ignore_index_val`` (2) are excluded from both loss terms and from the Dice/IoU
    metrics, so the network is never penalised for what it predicts there.

    Args:
        net: Backbone module; expected to map ``(B, 1, D, H, W)`` volumes to
            ``(B, 2, D, H, W)`` logits (TODO confirm for other backbones).
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
    """

    def __init__(self, net, learning_rate=1e-3, weight_decay=1e-4):
        super().__init__()
        # The network is an nn.Module, not a hyper-parameter worth serialising.
        self.save_hyperparameters(ignore=["net"])
        self.net_module = net
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.out_channels = 2
        self.spatial_dims = 3
        self.ignore_index_val = 2

        # Cross-entropy supports ignore_index natively.
        self.crossEntropy = nn.CrossEntropyLoss(
            ignore_index=self.ignore_index_val
        )
        # Tversky has no ignore_index, so _compute_loss applies the softmax itself and
        # zeroes BOTH probabilities and targets at ignored voxels before calling it.
        # (In MONAI, alpha weights false positives and beta weights false negatives.)
        self.Tversky = TverskyLoss(
            softmax=False,
            to_onehot_y=False,
            include_background=True,
            alpha=0.75,
            beta=0.25,
        )

    def forward(self, x):
        """Return the backbone's raw logits for ``x``."""
        return self.net_module(x)

    def _one_hot_channels_first(self, labels, num_classes):
        """One-hot ``(B, *spatial)`` integer labels into float ``(B, C, *spatial)``."""
        onehot = torch.nn.functional.one_hot(labels.long(), num_classes=num_classes)
        # channels-last -> channels-first, for any number of spatial dims
        return torch.movedim(onehot, -1, 1).float()

    def _compute_loss(self, logits, targets):
        """Combined CE + Tversky loss that ignores ``ignore_index_val`` voxels.

        Args:
            logits: ``(B, C, *spatial)`` raw network output.
            targets: ``(B, 1, *spatial)`` integer labels; values other than the ignore
                index must be valid class ids.

        Returns:
            Scalar tensor ``0.5 * tversky + 0.5 * ce``.
        """
        targets = targets.squeeze(1).long()
        ce_loss = self.crossEntropy(logits, targets)

        ignored = targets == self.ignore_index_val
        valid = (~ignored).unsqueeze(1).float()  # (B, 1, *spatial)
        # Ignored voxels are mapped to class 0 only so one_hot accepts them; the
        # multiplication by `valid` wipes them out again.
        targets_clean = targets.masked_fill(ignored, 0)
        targets_ohe = self._one_hot_channels_first(targets_clean, logits.shape[1]) * valid

        # Softmax in fp32 (logits may be fp16 under mixed precision). Masking the
        # probabilities as well means ignored voxels add nothing to TP, FP or FN.
        probs = torch.softmax(logits.float(), dim=1) * valid
        tversky_loss = self.Tversky(probs, targets_ohe)

        return 0.5 * tversky_loss + 0.5 * ce_loss

    def _compute_metrics(self, preds_logits: torch.Tensor, targets_class_indices: torch.Tensor) -> dict:
        """Vectorised per-class Dice and IoU, averaged over classes (background included).

        Ignored voxels are zeroed in both prediction and target one-hots, so they affect
        neither the intersection nor the cardinalities.

        Args:
            preds_logits: ``(B, C, *spatial)`` logits.
            targets_class_indices: ``(B, 1, *spatial)`` integer labels.

        Returns:
            ``{"dice": scalar tensor, "iou": scalar tensor}``.
        """
        num_classes = preds_logits.shape[1]
        preds_hard = torch.argmax(preds_logits, dim=1)          # (B, *spatial)
        targets = targets_class_indices.squeeze(1).long()       # (B, *spatial)
        ignored = targets == self.ignore_index_val
        valid = (~ignored).unsqueeze(1).float()                 # (B, 1, *spatial)
        targets_clean = targets.masked_fill(ignored, 0)

        preds_ohe = self._one_hot_channels_first(preds_hard, num_classes) * valid
        targets_ohe = self._one_hot_channels_first(targets_clean, num_classes) * valid

        # Reduce over batch and every spatial dim, keeping the class axis.
        reduce_dims = (0,) + tuple(range(2, preds_ohe.ndim))
        intersection = (preds_ohe * targets_ohe).sum(dim=reduce_dims)
        union_sum = preds_ohe.sum(dim=reduce_dims) + targets_ohe.sum(dim=reduce_dims)

        eps = 1e-8  # a class absent from both prediction and target scores 1
        dice_scores = (2. * intersection + eps) / (union_sum + eps)
        iou_scores = (intersection + eps) / (union_sum - intersection + eps)

        return {
            "dice": dice_scores.mean(),
            "iou": iou_scores.mean()
        }

    def training_step(self, batch, batch_idx):
        img, mask, _ = batch
        logits = self(img)
        loss = self._compute_loss(logits, mask)
        metrics = self._compute_metrics(logits, mask)
        # Explicit batch_size: the batch holds two tensors, so inference would be ambiguous.
        log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True, batch_size=img.shape[0])
        self.log("train_loss", loss, **log_kwargs)
        self.log("train_dice", metrics["dice"], **log_kwargs)
        self.log("train_iou", metrics["iou"], **log_kwargs)
        return loss

    def validation_step(self, batch, batch_idx):
        img, mask, _ = batch
        logits = self(img)
        loss = self._compute_loss(logits, mask)
        metrics = self._compute_metrics(logits, mask)
        log_kwargs = dict(on_step=False, on_epoch=True, prog_bar=True, batch_size=img.shape[0])
        self.log("val_loss", loss, **log_kwargs)
        self.log("val_dice", metrics["dice"], **log_kwargs)
        self.log("val_iou", metrics["iou"], **log_kwargs)
        return loss

    def predict_step(self, batch, batch_idx):
        """Return hard class predictions ``(B, *spatial)`` and the fragment ids."""
        img, _, frag_id = batch
        logits = self(img)
        # Softmax is monotonic, so the argmax of the logits is the predicted class.
        classes = torch.argmax(logits, dim=1)
        return {"prediction": classes, "frag_id": frag_id}

    def configure_optimizers(self):
        """AdamW with per-epoch cosine annealing down to 1e-6."""
        optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        # Lightning >= 2.0 raises instead of returning None when no Trainer is attached.
        try:
            attached = self.trainer
        except RuntimeError:
            attached = None
        max_epochs = getattr(attached, "max_epochs", None)
        # Guard against None / -1 ("train forever"), which CosineAnnealingLR cannot use.
        t_max = max_epochs if isinstance(max_epochs, int) and max_epochs > 0 else 100
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=t_max,
            eta_min=1e-6
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch"
            }
        }

In [None]:
from monai.networks.nets import SegResNet, SwinUNETR
# Backbone: 1 input channel, 2 output logits (matching the two-class loss/metrics).
net = SwinUNETR(
    in_channels=1,
    out_channels=2,
    feature_size=48,
    use_v2=True,
    drop_rate=0.2,
    attn_drop_rate=0.2,
    dropout_path_rate=0.2,
    use_checkpoint=True,  # gradient checkpointing: trades compute for activation memory
)

net_name = net.__class__.__name__
# Use the Config optimiser settings. The module's own defaults (lr=1e-3, wd=1e-4)
# would otherwise silently override config.lr / config.weight_decay.
model = surfaceSegementation(
    net=net,
    learning_rate=config.lr,
    weight_decay=config.weight_decay,
)

In [None]:
# Resume from a previous run when its checkpoint dataset is attached; otherwise start
# from scratch instead of letting trainer.fit fail on a missing file.
ckpt_path = "/kaggle/input/vesuvius-swinuneter-model1/checkpoints/SwinUNETR-epoch=01-val_dice=0.6019.ckpt"
if not os.path.isfile(ckpt_path):
    print(f"Checkpoint not found, training from scratch: {ckpt_path}")
    ckpt_path = None

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# --- Callbacks ----------------------------------------------------------------
# Keep the three best checkpoints ranked by validation Dice. With Lightning's
# auto-inserted metric names, files look like "SwinUNETR-epoch=01-val_dice=0.6019.ckpt".
checkpoint_callback = ModelCheckpoint(
    monitor="val_dice",
    mode="max",
    save_top_k=3,
    dirpath=OUTPUT_DIR,
    filename=net_name + "-{epoch:02d}-{val_dice:.4f}",
    verbose=True,
)

# Stop once val_dice has not improved for 10 consecutive validation checks.
early_stop_callback = EarlyStopping(monitor="val_dice", mode="max", patience=10, verbose=True)

lr_monitor = LearningRateMonitor(logging_interval="epoch")  # record the LR once per epoch
csv_logger = CSVLogger(save_dir=OUTPUT_DIR)                 # CSV metrics under OUTPUT_DIR

# --- Trainer ------------------------------------------------------------------
trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    max_epochs=config.max_epochs,
    precision="16-mixed",         # fp16 mixed precision (autocast)
    accumulate_grad_batches=18,   # one optimiser step every 18 batches
    gradient_clip_val=1.0,        # clip gradient norm to 1.0 against exploding gradients
    logger=csv_logger,
    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
    log_every_n_steps=20,
    enable_progress_bar=True,
)

# --- Fit ----------------------------------------------------------------------
# A non-None ckpt_path restores the full training state, not just the weights.
try:
    trainer.fit(model, datamodule=dataModule, ckpt_path=ckpt_path)
except MisconfigurationException as ex:
    # Configuration problems are reported rather than raised, so the notebook keeps running.
    print(ex)