# nnU-Net v2 on Google Colab (Kaggle dataset)

This notebook:
1. Downloads the dataset from Kaggle
2. Converts `.tif` images/masks to nnU-Net v2 format (`.nii.gz`)
3. Plans + preprocesses
4. Trains fold 0 (GPU)
5. Runs a quick prediction + visualization

> **Before you run**: Runtime → Change runtime type → **GPU**.

In [None]:
# Print GPU details so you can confirm a GPU runtime is attached before training.
!nvidia-smi

## 1) Install dependencies

In [None]:
!pip -q install -U nnunetv2 nibabel tifffile simpleitk blosc2 kaggle matplotlib

## 2) Kaggle authentication

Upload your `kaggle.json` (Kaggle API token) to Colab:
- Kaggle → Account → **Create New API Token** → downloads `kaggle.json`
- In Colab left sidebar: **Files → Upload** `kaggle.json`

Then run the cell below.

In [None]:
import os, pathlib, json, shutil

# Where the uploaded token is expected, and the per-user directory it is
# installed into by this cell.
kaggle_json_src = "/content/kaggle.json"
kaggle_dir = pathlib.Path.home() / ".kaggle"
kaggle_dir.mkdir(parents=True, exist_ok=True)

if not os.path.exists(kaggle_json_src):
    # Nothing uploaded yet: tell the user what to do and leave everything as is.
    print("Upload kaggle.json to /content first (Files sidebar → Upload).")
else:
    installed_token = kaggle_dir / "kaggle.json"
    shutil.copy(kaggle_json_src, installed_token)
    # Restrict the token to owner read/write only.
    installed_token.chmod(0o600)
    print("kaggle.json installed.")


## 3) Download the dataset from Kaggle

Set the Kaggle dataset slug below (format: `username/dataset-name`).

In [None]:
KAGGLE_DATASET = "PUT_USERNAME/DATASET_SLUG_HERE"  # <-- change this

DATA_ROOT = "/content/data_kaggle"
os.makedirs(DATA_ROOT, exist_ok=True)

# Download + unzip
!kaggle datasets download -d {KAGGLE_DATASET} -p {DATA_ROOT} --force
!unzip -q -o {DATA_ROOT}/*.zip -d {DATA_ROOT}

print("Downloaded/unzipped into:", DATA_ROOT)
!find {DATA_ROOT} -maxdepth 3 -type d


## 4) Point to your dataset folders

We expect:
- images: `.../dataset/image/*.tif`
- labels: `.../dataset/label/*.tif`

If your extracted path is different, edit `IMG_DIR` and `LBL_DIR`.

In [None]:
import glob, os

# EDIT THESE if your folder names differ after unzip
IMG_DIR = DATA_ROOT + "/dataset/image"
LBL_DIR = DATA_ROOT + "/dataset/label"

# First report whether both folders exist ...
for message, folder in (("IMG_DIR exists:", IMG_DIR), ("LBL_DIR exists:", LBL_DIR)):
    print(message, os.path.exists(folder))

# ... then show the first five .tif files (sorted by path) found in each.
for message, folder in (("Sample images:", IMG_DIR), ("Sample labels:", LBL_DIR)):
    first_five = sorted(glob.glob(folder + "/*.tif"))[:5]
    print(message, first_five)


## 5) Set nnU-Net paths (Colab)

We store everything under `/content/nnUNet` (fast local disk).

In [None]:
import os

BASE = "/content/nnUNet"
NNUNET_ENV_VARS = ["nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results"]

# Each environment variable points at a same-named folder under BASE;
# create the folder as soon as the variable is set.
for var_name in NNUNET_ENV_VARS:
    os.environ[var_name] = f"{BASE}/{var_name}"
    os.makedirs(os.environ[var_name], exist_ok=True)

# Echo the final settings.
for k in NNUNET_ENV_VARS:
    print(k, "=", os.environ[k])


## 6) Convert TIF → nnU-Net v2 format (NIfTI)

Important:
- images: `case_XXXX_0000.nii.gz`
- labels: `case_XXXX.nii.gz`

Label remap used (binary stone mask): **`label >= 251 → 1` else 0**.

In [None]:
import os, glob, json
import numpy as np
import tifffile as tiff
import nibabel as nib

DATASET_ID = 501
DATASET_NAME = "KSSD"
# Raw mask values at or above this threshold become class 1 (stone); all others become 0.
STONE_THRESHOLD = 251

nn_folder = f"{os.environ['nnUNet_raw']}/Dataset{DATASET_ID:03d}_{DATASET_NAME}"
imagesTr = f"{nn_folder}/imagesTr"
labelsTr = f"{nn_folder}/labelsTr"
os.makedirs(imagesTr, exist_ok=True)
os.makedirs(labelsTr, exist_ok=True)

# Images and labels are paired purely by their position in the sorted listings.
img_files = sorted(glob.glob(IMG_DIR + "/*.tif"))
lbl_files = sorted(glob.glob(LBL_DIR + "/*.tif"))

print("Images:", len(img_files))
print("Labels:", len(lbl_files))
# An empty listing would satisfy the count check below (0 == 0), so reject it explicitly.
assert img_files, "No .tif images found. Fix IMG_DIR/LBL_DIR."
assert len(img_files) == len(lbl_files), "Image/label count mismatch. Fix IMG_DIR/LBL_DIR."

# Pairing sanity check: only a warning, because differing names may be legitimate.
name_mismatches = [
    (os.path.basename(a), os.path.basename(b))
    for a, b in zip(img_files, lbl_files)
    if os.path.basename(a) != os.path.basename(b)
]
if name_mismatches:
    print(
        f"WARNING: {len(name_mismatches)} image/label pairs have different file names "
        f"(first: {name_mismatches[0]}). Pairing relies on sorted order - verify it is correct."
    )

# quick label value check (raw dtype, so values above 255 are shown unmodified)
sample_lbl = tiff.imread(lbl_files[0])
print("Unique label values in first mask:", np.unique(sample_lbl))

for i, (img_p, lbl_p) in enumerate(zip(img_files, lbl_files)):
    img = tiff.imread(img_p).astype(np.float32)

    # Binary remap: stone pixels are high values (>= STONE_THRESHOLD).
    # Compare on the raw dtype first: casting e.g. a 16-bit mask to uint8
    # before thresholding would wrap values modulo 256 and corrupt the labels.
    lbl = (tiff.imread(lbl_p) >= STONE_THRESHOLD).astype(np.uint8)

    # Ensure (H,W,1) for 2D
    if img.ndim == 2:
        img = img[..., None]
    if lbl.ndim == 2:
        lbl = lbl[..., None]

    # Catch mismatched pairs (or multi-channel images) here, with the file names at hand.
    if img.shape != lbl.shape:
        raise ValueError(
            f"Shape mismatch for case {i}: image {img_p} has shape {img.shape}, "
            f"label {lbl_p} has shape {lbl.shape}."
        )

    # Identity affine: unit spacing, no rotation or offset.
    affine = np.eye(4)
    case_id = f"case_{i:04d}"
    nib.save(nib.Nifti1Image(img, affine), f"{imagesTr}/{case_id}_0000.nii.gz")
    nib.save(nib.Nifti1Image(lbl, affine), f"{labelsTr}/{case_id}.nii.gz")

print("Conversion done:", nn_folder)


## 7) Create dataset.json

In [None]:
import glob, json, os

# One .nii.gz file in imagesTr is counted as one training case.
num_training_cases = len(glob.glob(imagesTr + "/*.nii.gz"))

dataset_json = dict(
    channel_names={"0": "CT"},
    labels={"background": 0, "stone": 1},
    numTraining=num_training_cases,
    file_ending=".nii.gz",
)

dataset_json_path = f"{nn_folder}/dataset.json"
with open(dataset_json_path, "w") as out_file:
    out_file.write(json.dumps(dataset_json, indent=2))

print("dataset.json written:", dataset_json_path)
print(dataset_json)


## 8) Plan + preprocess

In [None]:
!nnUNetv2_plan_and_preprocess -d 501 --verify_dataset_integrity

## 9) Train fold 0 (GPU)

If training fails with an out-of-memory error, note your GPU model and memory from the `nvidia-smi` output (first cell), then adjust the plan/batch size accordingly before retrying.

In [None]:
!nnUNetv2_train 501 2d 0

## 10) Quick prediction on 20 samples + visualize

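This is just a sanity check that training produces reasonable masks. The 20 samples are copied from the training images, so the result does not measure how well the model generalizes.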

In [None]:
import os, glob, shutil

# Input folder for the prediction run and the folder that will receive its output.
imagesTs = "/content/imagesTs"
preds = "/content/preds"
for out_dir in (imagesTs, preds):
    os.makedirs(out_dir, exist_ok=True)

# Take the first 20 training images (sorted by path) as the sanity-check inputs.
all_imgs = sorted(glob.glob(imagesTr + "/*.nii.gz"))
sanity_check_inputs = all_imgs[:20]
for src_path in sanity_check_inputs:
    shutil.copy(src_path, imagesTs)

n_prepared = len(os.listdir(imagesTs))
print("Prepared imagesTs:", n_prepared)


In [None]:
!nnUNetv2_predict -d 501 -i /content/imagesTs -o /content/preds -c 2d -f 0

In [None]:
import os, glob
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

pred_files = sorted(glob.glob("/content/preds/*.nii.gz"))
print("Pred files:", len(pred_files))
# Without this check an empty prediction folder surfaces as a bare IndexError below.
if not pred_files:
    raise FileNotFoundError(
        "No predictions found in /content/preds. Run the nnUNetv2_predict cell first "
        "and check its output for errors."
    )

# Show the first prediction next to the image and ground-truth mask of the same case.
pred_p = pred_files[0]
case = os.path.basename(pred_p).replace(".nii.gz","")

img_p = f"{imagesTr}/{case}_0000.nii.gz"
gt_p  = f"{labelsTr}/{case}.nii.gz"

img = nib.load(img_p).get_fdata()
gt  = nib.load(gt_p).get_fdata()
pr  = nib.load(pred_p).get_fdata()


def _first_slice(volume):
    """Return the first slice along the last axis of a 3-D array; pass other arrays through."""
    return volume[..., 0] if volume.ndim == 3 else volume


img2 = _first_slice(img)
gt2  = _first_slice(gt)
pr2  = _first_slice(pr)

# One figure per panel; only the image uses a gray colormap, the masks keep the default.
for panel, title, cmap in (
    (img2, "Image", "gray"),
    (gt2, "Ground truth mask", None),
    (pr2, "Predicted mask (fold 0)", None),
):
    plt.figure()
    plt.imshow(panel, cmap=cmap)
    plt.title(title)
    plt.axis("off")
    plt.show()
