In [1]:
import os
import warnings
from pathlib import Path

import pandas as pd, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar
import timm, torchmetrics
from sklearn.model_selection import train_test_split
import optuna
from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback

# 🔒 paths — the defaults are the original author's machine. Override them with
#    the MEDSSL_CKPT / SHENZHEN_DIR environment variables instead of editing here.
CKPT_PATH = Path(os.environ.get(
    "MEDSSL_CKPT", r"C:\Users\offic\medself\checkpoints\epoch=39-step=12280.ckpt"
))
SHENZHEN_DIR = Path(os.environ.get(
    "SHENZHEN_DIR", r"C:\Users\offic\OneDrive\Masaüstü\datasets\SelfSupervised\Shenzhen"
))
CSV_PATH = SHENZHEN_DIR / "shenzhen_metadata.csv"
IMG_DIR = SHENZHEN_DIR / "images" / "images"

# Fail fast, naming the missing path. An explicit raise is used rather than
# `assert`, because asserts are stripped when Python runs with -O.
for _path, _what in ((CKPT_PATH, "Checkpoint"), (CSV_PATH, "CSV"), (IMG_DIR, "Image folder")):
    if not _path.exists():
        raise FileNotFoundError(f"{_what} not found: {_path}")




In [2]:
class IdentityTransform:
    """Pass-through transform: hands back whatever it is given, unchanged."""

    def __call__(self, x):
        # No copy is made — the caller receives the very same object.
        return x

def med_transform(img_size=224, train=True):
    """Build the torchvision pipeline used for the grayscale X-ray images.

    Steps: square resize to ``img_size``; a random horizontal flip (only when
    ``train`` is true); ToTensor; Normalize with mean = std = 0.5, which maps
    [0, 1] pixel values to [-1, 1].
    """
    resize = transforms.Resize((img_size, img_size))
    augmentation = [transforms.RandomHorizontalFlip()] if train else []
    to_normalised_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
    return transforms.Compose([resize, *augmentation, *to_normalised_tensor])


In [3]:
class ShenzhenCSV(Dataset):
    """Grayscale chest X-ray + binary label dataset backed by a metadata frame.

    Each row supplies ``study_id`` (image file name, resolved against
    ``img_dir``) and ``findings`` (text; "normal" → 0, anything else → 1).

    Args:
        df: metadata rows for this split; its index is reset internally.
        img_dir: folder containing the image files.
        train: enable training-time augmentation in ``med_transform``.
        img_size: square side length images are resized to (default 224).
    """

    def __init__(self, df: pd.DataFrame, img_dir: Path, train=True, img_size=224):
        self.df = df.reset_index(drop=True)
        self.img_dir = Path(img_dir)
        self.tf = med_transform(img_size=img_size, train=train)

    @staticmethod
    def _label(findings: str) -> int:
        """Return 0 when ``findings`` reads "normal" (case/whitespace-insensitive), else 1."""
        return 0 if findings.lower().strip() == "normal" else 1

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(x, y)``: a (3, img_size, img_size) tensor and a float32 0/1 label."""
        row = self.df.iloc[idx]
        # The context manager closes the file handle as soon as the pixels have
        # been converted, rather than leaving that to PIL or the garbage collector.
        with Image.open(self.img_dir / row["study_id"]) as im:
            img = im.convert("L")
        # ToTensor yields 1 channel for an "L" image. Replicate it to the 3
        # channels the ViT encoder expects.
        x = self.tf(img).repeat(3, 1, 1)
        y = torch.tensor(self._label(row["findings"]), dtype=torch.float32)
        return x, y


In [4]:
class ShenzhenDM(pl.LightningDataModule):
    """Stratified train/val/test split (75 % / 15 % / 10 %) of the Shenzhen CSV.

    Args:
        csv_path: metadata CSV with ``study_id`` and ``findings`` columns.
        img_dir: folder containing the images referenced by ``study_id``.
        batch: batch size for every DataLoader.
        workers: DataLoader worker processes (0 = load in the main process).
        seed: ``random_state`` for both splits, so the partition is reproducible.
    """

    def __init__(self, csv_path: Path, img_dir: Path,
                 batch=32, workers=0, seed=42):
        super().__init__()
        self.csv_path, self.img_dir = csv_path, img_dir
        self.batch, self.workers, self.seed = batch, workers, seed

    def setup(self, stage=None):
        df = pd.read_csv(self.csv_path)
        # Reuse the dataset's own labelling rule, so that the stratification
        # labels and the training labels cannot drift apart.
        labels = df["findings"].apply(ShenzhenCSV._label)
        # 75 % train; the remaining 25 % is split 60/40 into val (15 %) and test (10 %).
        train_idx, temp_idx = train_test_split(
            np.arange(len(df)), test_size=0.25,
            stratify=labels, random_state=self.seed
        )
        val_idx, test_idx = train_test_split(
            temp_idx, test_size=0.40,
            stratify=labels.iloc[temp_idx], random_state=self.seed
        )
        self.train_ds = ShenzhenCSV(df.iloc[train_idx], self.img_dir, train=True)
        self.val_ds   = ShenzhenCSV(df.iloc[val_idx],   self.img_dir, train=False)
        self.test_ds  = ShenzhenCSV(df.iloc[test_idx],  self.img_dir, train=False)

    def _dl(self, ds, shuffle=False):
        # Pinned host memory only helps host→GPU copies. Without CUDA it is
        # useless and PyTorch warns about it.
        return DataLoader(ds, self.batch, shuffle=shuffle,
                          num_workers=self.workers,
                          pin_memory=torch.cuda.is_available())

    def train_dataloader(self): return self._dl(self.train_ds, True)
    def val_dataloader(self):   return self._dl(self.val_ds)
    def test_dataloader(self):  return self._dl(self.test_ds)


In [5]:
class LitTBFinetune(pl.LightningModule):
    """Fine-tune a self-supervised ViT-Tiny encoder for binary TB detection.

    A linear head on the encoder's CLS feature outputs one logit per image.
    The encoder is frozen for the first ``freeze_epochs`` training epochs, so
    only the head learns. It is then unfrozen for full fine-tuning.

    Args:
        ckpt_path: Lightning checkpoint whose ``state_dict`` holds the SSL
            encoder weights under the ``student.`` prefix.
        freeze_epochs: number of initial training epochs with the encoder frozen.
        lr: AdamW learning rate, shared by encoder and head.

    Logged metrics: ``train_loss``, ``val_auc``, and ``test_auc`` when
    ``trainer.test`` is run.
    """

    SSL_PREFIX = "student."

    def __init__(self, ckpt_path: Path, freeze_epochs=3, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()

        # ── 1. Build ViT-Tiny encoder and load SSL weights ───────────
        self.encoder = timm.create_model(
            "vit_tiny_patch16_224", num_classes=0, global_pool="token"
        )
        self._load_ssl_weights(ckpt_path)

        # ── 2. Auto-detect output feature length ─────────────────────
        with torch.no_grad():
            feat = self._encode(torch.zeros(1, 3, 224, 224))
        self.feat_dim = feat.shape[-1]

        # ── 3. Classification head that matches discovered dim ───────
        self.head = nn.Linear(self.feat_dim, 1)

        # ── other hyper-params & metrics ─────────────────────────────
        self.freeze_epochs, self.lr = freeze_epochs, lr
        self.auc = torchmetrics.AUROC(task="binary")
        self.test_auc = torchmetrics.AUROC(task="binary")

    def _load_ssl_weights(self, ckpt_path):
        """Copy the ``student.*`` tensors of an SSL checkpoint into ``self.encoder``.

        Raises:
            RuntimeError: if not a single encoder tensor was found. With
                ``strict=False``, that case would otherwise leave the encoder
                silently at its random initialisation.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        prefix = self.SSL_PREFIX
        # Strip only the leading prefix. str.replace would also rewrite any
        # later occurrence of "student." inside a key.
        student = {k[len(prefix):]: v for k, v in state.items()
                   if k.startswith(prefix)}
        result = self.encoder.load_state_dict(student, strict=False)

        n_expected = len(self.encoder.state_dict())
        n_loaded = n_expected - len(result.missing_keys)
        if n_loaded == 0:
            raise RuntimeError(
                f"No encoder weights loaded from {ckpt_path}: none of the "
                f"{len(student)} '{prefix}*' keys match the timm ViT-Tiny layout."
            )
        if result.missing_keys:
            warnings.warn(
                f"{len(result.missing_keys)}/{n_expected} encoder tensors were not in "
                f"the checkpoint and keep their random init, e.g. "
                f"{result.missing_keys[:5]}"
            )

    def _encode(self, x):
        """Run the encoder and reduce its output to one feature vector per sample."""
        z = self.encoder(x)
        if z.ndim == 3:          # (B, tokens, dim) → CLS token
            z = z[:, 0]
        return z

    # Forward returns logits, one per sample
    def forward(self, x):
        return self.head(self._encode(x)).squeeze(1)

    def training_step(self, batch, _):
        x, y = batch
        loss = F.binary_cross_entropy_with_logits(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, _):
        x, y = batch
        # Cast to fp32 before the sigmoid. Under bf16-mixed the logits are bf16,
        # and its few significant digits create spurious ties that distort AUROC.
        preds = torch.sigmoid(self(x).float())
        self.auc.update(preds, y.int())

    def on_validation_epoch_end(self):
        val_auc = self.auc.compute(); self.auc.reset()
        self.log("val_auc", val_auc, prog_bar=True)

    def test_step(self, batch, _):
        x, y = batch
        preds = torch.sigmoid(self(x).float())
        self.test_auc.update(preds, y.int())

    def on_test_epoch_end(self):
        test_auc = self.test_auc.compute(); self.test_auc.reset()
        self.log("test_auc", test_auc, prog_bar=True)

    def configure_optimizers(self):
        # All parameters are registered up front. While the encoder is frozen its
        # parameters receive no gradient, and the optimizer skips them.
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

    def on_train_epoch_start(self):
        """Freeze the encoder for the first ``freeze_epochs`` epochs, then unfreeze it."""
        freeze = self.current_epoch < self.freeze_epochs
        for p in self.encoder.parameters():
            p.requires_grad = not freeze


In [6]:
# Seed Python / NumPy / torch (and DataLoader workers) *before* building the
# model, so the randomly initialised head and the shuffle order are reproducible.
# The same value drives ShenzhenDM's stratified split.
SEED = 42
pl.seed_everything(SEED, workers=True)

dm_shen = ShenzhenDM(CSV_PATH, IMG_DIR, batch=32, workers=0, seed=SEED)
model_tb = LitTBFinetune(CKPT_PATH, freeze_epochs=3, lr=1e-3)


In [7]:
callbacks = [
    # The template already spells out the metric names. With the default
    # auto_insert_metric_name=True, Lightning would prepend them a second time,
    # producing e.g. "epochepoch=3-aucval_auc=0.912.ckpt".
    ModelCheckpoint(dirpath="shenzhen_ckpts",
                    filename="epoch{epoch}-auc{val_auc:.3f}",
                    auto_insert_metric_name=False,
                    monitor="val_auc", mode="max", save_top_k=1),
    # Stop once val_auc has not improved for 4 consecutive validation runs.
    EarlyStopping(monitor="val_auc", mode="max", patience=4),
    RichProgressBar(),
]

trainer = Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    max_epochs=15,
    # bf16 autocast on GPU; plain fp32 on CPU.
    precision="bf16-mixed" if torch.cuda.is_available() else 32,
    callbacks=callbacks,
    log_every_n_steps=10,
)


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [9]:
# Keyword arguments make explicit that dm_shen is the DataModule, not a dataloader.
trainer.fit(model=model_tb, datamodule=dm_shen)


C:\Users\offic\anaconda3\envs\medssl\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:652: Checkpoint directory C:\Users\offic\medself\shenzhen_ckpts exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()