In [1]:
# %% [markdown]
# # Fine-tuning 2D U-Net for Tooth Segmentation
#
# This notebook fine-tunes the 2D U-Net that was pretrained with a
# Masked Image Modeling (MIM) reconstruction objective on all 2D
# X-ray images (see `pretrain_MIM.ipynb`).
#
# Pipeline:
# 1. Load the index (`sts2d_index.csv`) and build `df_seg`:
#      - the 900 image–mask pairs used for segmentation.
# 2. Build train/val/test splits for segmentation.
# 3. Define segmentation transforms and `DentalSegmentationDataset`.
# 4. Define 2D U-Net (same as in pretraining).
# 5. Load MIM-pretrained weights (except the final output layer).
# 6. Define loss (BCE + Dice) and metrics (Dice, IoU).
# 7. Fine-tune the model on segmentation data.
# 8. Optionally compare with training from scratch by turning off
#    the `USE_PRETRAINED` flag.

# %% 
import os
from pathlib import Path
import random

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import matplotlib.pyplot as plt

# %matplotlib inline  # uncomment in Jupyter if needed

# -----------------------------
# Reproducibility
# -----------------------------
# Seed Python, NumPy and PyTorch (CPU and all CUDA devices) so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# -----------------------------
# Paths (aligned with previous notebooks)
# -----------------------------
# Default: ./sts_tooth_data relative to the current working directory, exactly as in
# the previous notebooks. Set the STS_DATA_ROOT environment variable to use another
# location without editing the notebook; an empty value falls back to the default.
DATA_ROOT = Path(os.environ.get("STS_DATA_ROOT") or "./sts_tooth_data").resolve()
PROCESSED_2D_DIR = DATA_ROOT / "processed_2d"
INDEX_CSV = DATA_ROOT / "sts2d_index.csv"
CHECKPOINT_DIR = DATA_ROOT / "checkpoints"

print("DATA_ROOT      :", DATA_ROOT)
print("PROCESSED_2D   :", PROCESSED_2D_DIR)
print("INDEX_CSV path :", INDEX_CSV)
print("CHECKPOINT_DIR :", CHECKPOINT_DIR)

# -----------------------------
# Device
# -----------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


  from .autonotebook import tqdm as notebook_tqdm


DATA_ROOT      : E:\Data\ToothSeg\sts_tooth_data
PROCESSED_2D   : E:\Data\ToothSeg\sts_tooth_data\processed_2d
INDEX_CSV path : E:\Data\ToothSeg\sts_tooth_data\sts2d_index.csv
CHECKPOINT_DIR : E:\Data\ToothSeg\sts_tooth_data\checkpoints
Using device: cuda


In [2]:
# %% [markdown]
# ## 1. Build `df_seg` (image–mask pairs)
#
# The index `sts2d_index.csv` has columns:
# - rel_path
# - age_group
# - label_status
# - is_mask
# - pair_id
#
# We will:
# - separate image rows (`is_mask == False`) from mask rows (`is_mask == True`),
# - inner-join them on `pair_id`, keeping only pairs that have both an image and a mask;
#   the result is `df_seg` with ~900 rows.
# %%
assert INDEX_CSV.exists(), f"Index CSV not found: {INDEX_CSV}"

df = pd.read_csv(INDEX_CSV)
print("Full index shape:", df.shape)
print(df.head())

# Split the index into image rows and mask rows. `.eq(...)` is the vectorized form of
# `== False` / `== True`; a row whose is_mask is missing falls into neither group.
df_img = df[df["is_mask"].eq(False)].copy()
df_mask = df[df["is_mask"].eq(True)].copy()

print("\nNumber of image rows:", len(df_img))
print("Number of mask rows :", len(df_mask))

# Keep only the columns we need from the mask df for joining
df_mask_simple = df_mask[["pair_id", "rel_path"]].rename(columns={"rel_path": "mask_rel"})

# Inner join on pair_id to obtain only (image, mask) pairs
df_seg = pd.merge(df_img, df_mask_simple, on="pair_id", how="inner")

# Guard the join. A pair_id repeated among images or masks would make the inner join
# emit several rows for the same pair, and those copies could later be shuffled into
# different splits (train/test leakage). Only keys that actually matched are checked,
# so images without a mask do not trigger these assertions.
assert len(df_seg) > 0, "No image–mask pairs found; check pair_id values in the index"
assert df_seg["pair_id"].notna().all(), "Found image–mask pairs with a missing pair_id"
assert df_seg["pair_id"].is_unique, "Duplicate pair_id in df_seg: images or masks share a pair_id"

# Masks whose pair_id has no image row are dropped by the inner join; make that visible.
n_orphan_masks = int((~df_mask_simple["pair_id"].isin(df_img["pair_id"])).sum())
print("Masks without a matching image:", n_orphan_masks)

print("\nSegmentation dataframe shape:", df_seg.shape)
print(df_seg.head())

print("\nValue counts — age_group in df_seg:")
print(df_seg["age_group"].value_counts())

print("\nValue counts — label_status in df_seg:")
print(df_seg["label_status"].value_counts())

Full index shape: (4900, 5)
                          rel_path age_group label_status  is_mask  pair_id
0  A-PXI/Labeled/Image/A_L_001.png     adult      labeled    False  a_l_001
1  A-PXI/Labeled/Image/A_L_002.png     adult      labeled    False  a_l_002
2  A-PXI/Labeled/Image/A_L_003.png     adult      labeled    False  a_l_003
3  A-PXI/Labeled/Image/A_L_004.png     adult      labeled    False  a_l_004
4  A-PXI/Labeled/Image/A_L_005.png     adult      labeled    False  a_l_005

Number of image rows: 4000
Number of mask rows : 900

Segmentation dataframe shape: (900, 6)
                          rel_path age_group label_status  is_mask  pair_id  \
0  A-PXI/Labeled/Image/A_L_001.png     adult      labeled    False  a_l_001   
1  A-PXI/Labeled/Image/A_L_002.png     adult      labeled    False  a_l_002   
2  A-PXI/Labeled/Image/A_L_003.png     adult      labeled    False  a_l_003   
3  A-PXI/Labeled/Image/A_L_004.png     adult      labeled    False  a_l_004   
4  A-PXI/Labeled/Image/A_L_

In [3]:
# %% [markdown]
# ## 2. Train/val/test split
#
# We will:
# - Shuffle `df_seg`,
# - Split into:
#     - ~70% train
#     - ~15% val
#     - ~15% test
#
# For simplicity, we perform a random split without explicit stratification.
# (If needed, we could later add stratified splitting based on age_group and label_status.)
# %%
# Split fractions; the test split takes whatever remains so every row is used exactly once.
TRAIN_FRAC = 0.70
VAL_FRAC = 0.15

# Deterministic shuffle (explicit random_state) before slicing into contiguous splits.
df_seg_shuffled = df_seg.sample(frac=1.0, random_state=SEED).reset_index(drop=True)

n_total = len(df_seg_shuffled)
n_train = int(TRAIN_FRAC * n_total)
n_val = int(VAL_FRAC * n_total)
n_test = n_total - n_train - n_val

df_train = df_seg_shuffled.iloc[:n_train].reset_index(drop=True)
df_val   = df_seg_shuffled.iloc[n_train:n_train + n_val].reset_index(drop=True)
df_test  = df_seg_shuffled.iloc[n_train + n_val:].reset_index(drop=True)

# Sanity checks: split sizes match the plan and no pair appears in more than one split.
assert (len(df_train), len(df_val), len(df_test)) == (n_train, n_val, n_test)
train_ids, val_ids, test_ids = (set(split["pair_id"]) for split in (df_train, df_val, df_test))
assert (
    train_ids.isdisjoint(val_ids)
    and train_ids.isdisjoint(test_ids)
    and val_ids.isdisjoint(test_ids)
), "pair_id leaks across train/val/test splits"

print(f"Total samples: {n_total}")
print(f"Train: {len(df_train)}, Val: {len(df_val)}, Test: {len(df_test)}")

Total samples: 900
Train: 630, Val: 135, Test: 135
