In [4]:
# ===========================================================
#  ♦♦  Continual SSL→CL pipeline – v8.1
#      (bug-fixed: replay helpers restored)
# ===========================================================
from pathlib import Path
import random, numpy as np, pandas as pd, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
import timm, pytorch_lightning as pl
from torchmetrics.classification import BinaryAUROC
from torch.optim.lr_scheduler import CosineAnnealingLR

# ───────────────────────── paths ────────────────────────────
# Local, machine-specific locations of the SSL checkpoint and the three datasets.
CKPT     = Path(r"C:\Users\offic\medself\shenzhen_ckpts\epochepoch=11-aucval_auc=0.896.ckpt")
SHEN_CSV = Path(r"C:\Users\offic\OneDrive\Masaüstü\datasets\SelfSupervised\Shenzhen\shenzhen_metadata.csv")
SHEN_IMG = Path(r"C:\Users\offic\OneDrive\Masaüstü\datasets\SelfSupervised\Shenzhen\images\images")
MONT_CSV = Path(r"C:\Users\offic\OneDrive\Masaüstü\datasets\SelfSupervised\Montgomery\montgomery_metadata.csv")
MONT_IMG = Path(r"C:\Users\offic\OneDrive\Masaüstü\datasets\SelfSupervised\Montgomery\images\images")
MIAS_CSV = Path(r"C:\Users\offic\OneDrive\Masaüstü\datasets\SelfSupervised\MIAS\mias_info.csv")
MIAS_IMG = Path(r"C:\Users\offic\OneDrive\Masaüstü\datasets\SelfSupervised\MIAS\images")

# Fail fast with a real exception (asserts are stripped under `python -O`)
# and report every missing path at once instead of only the first one.
_missing = [
    path
    for path in (CKPT, SHEN_CSV, SHEN_IMG, MONT_CSV, MONT_IMG, MIAS_CSV, MIAS_IMG)
    if not path.exists()
]
if _missing:
    raise FileNotFoundError(
        "Missing required path(s): " + ", ".join(str(path) for path in _missing)
    )

# ────────────────────── hyper-params ───────────────────────
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

BATCH = 32        # per-step mini-batch size
GRAD_ACC = 2      # → effective batch = 64
MEM_SIZE = 200    # capacity of the replay buffer (number of samples)
EWC_LMB = 0.05    # weight of the EWC penalty term

# Epochs per task; MIAS is trained in two phases, hence the two-element list.
EPOCHS = {"shen": 5, "mont": 25, "mias": [12, 25]}

# Learning rates per task / MIAS phase.
LRS = {"shen": 1e-4, "mont": 3e-4, "mias_head": 3e-4, "mias_mid": 1e-4}

# ───────────────────── transforms / datasets ───────────────
def med_tf(train=True):
    """Build the image preprocessing pipeline.

    Args:
        train: when true, apply light augmentation (random resized crop,
            horizontal flip, small rotation); otherwise use a deterministic
            resize + centre crop.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` and a single-channel
        ``Normalize([0.5], [0.5])``.
    """
    if train:
        spatial = [
            transforms.Resize(256),
            transforms.RandomResizedCrop(224, scale=(0.85, 1.0), ratio=(0.9, 1.1)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(5),
        ]
    else:
        spatial = [transforms.Resize(256), transforms.CenterCrop(224)]

    # Tensor conversion and normalisation are shared by both modes.
    to_tensor = [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    return transforms.Compose([*spatial, *to_tensor])

class CXRCSV(Dataset):
    """Chest X-ray dataset backed by a metadata DataFrame.

    The frame must provide a ``findings`` column (label source) and a
    ``study_id`` column (image file name relative to ``img_dir``).
    A finding equal to "normal" (case/whitespace-insensitive) is class 0,
    anything else is class 1.
    """

    def __init__(self, df, img_dir, train=True):
        self.df = df.reset_index(drop=True)
        self.dir = img_dir
        self.tf = med_tf(train)

        # Binary targets, one per row, in row order.
        self.tgts = []
        for finding in self.df["findings"]:
            is_normal = finding.lower().strip() == "normal"
            self.tgts.append(0 if is_normal else 1)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)``: a 3-channel tensor and a float tensor of shape (1,)."""
        row = self.df.iloc[idx]
        path = self.dir / row["study_id"]
        gray = Image.open(path).convert("L")
        # The transform yields one channel; replicate it to get three channels.
        image = self.tf(gray).repeat(3, 1, 1)
        label = torch.tensor([self.tgts[idx]], dtype=torch.float32)
        return image, label

def split_cxr(csv, img):
    """Read a chest X-ray metadata CSV and return ``(train_ds, val_ds)``.

    The rows are split 80/20 with a fixed seed, stratified on the binary
    label derived from the ``findings`` column ("normal" → 0, else 1).
    The training dataset uses augmentation, the validation dataset does not.
    """
    def _label(finding):
        return 0 if finding.lower().strip() == "normal" else 1

    frame = pd.read_csv(csv)
    labels = frame["findings"].map(_label)
    positions = np.arange(len(frame))
    train_pos, val_pos = train_test_split(
        positions, test_size=0.2, stratify=labels, random_state=42
    )
    train_ds = CXRCSV(frame.iloc[train_pos], img, True)
    val_ds = CXRCSV(frame.iloc[val_pos], img, False)
    return train_ds, val_ds

# Build the stratified 80/20 train/validation datasets for the Shenzhen and
# Montgomery chest X-ray collections (reads each metadata CSV once).
shen_train, shen_val = split_cxr(SHEN_CSV, SHEN_IMG)
mont_train, mont_val = split_cxr(MONT_CSV, MONT_IMG)

class MIASCSV(Dataset):
    """MIAS mammography dataset backed by a metadata DataFrame.

    The frame must provide a ``SEVERITY`` column and a ``REFNUM`` column
    (image file stem). A severity starting with "M" (case-insensitive) is
    class 1; everything else, including missing values, is class 0.

    Attributes:
        tgts: binary target per row.
        pos / neg: row positions of class-1 / class-0 samples, used by
            ``BalancedDataset`` for oversampling.
    """

    # Extensions tried, in order, when resolving an image file from REFNUM.
    _EXTS = (".png", ".pgm")

    def __init__(self, df, img_dir, train=True):
        self.df, self.dir = df.reset_index(drop=True), img_dir
        self.tf = med_tf(train)
        # Missing severity is treated as "B", which maps to class 0 below.
        sev = self.df["SEVERITY"].fillna("B").astype(str).str.upper()
        self.tgts = [1 if s.startswith("M") else 0 for s in sev]
        self.pos = [i for i, t in enumerate(self.tgts) if t == 1]
        self.neg = [i for i, t in enumerate(self.tgts) if t == 0]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)``: a 3-channel tensor and a float tensor of shape (1,).

        Raises:
            FileNotFoundError: if no image file with a supported extension
                exists for this row's REFNUM.
        """
        stem = self.df.iloc[idx]["REFNUM"]
        candidates = [self.dir / f"{stem}{ext}" for ext in self._EXTS]
        path = next((c for c in candidates if c.exists()), None)
        if path is None:
            # A bare next() would raise StopIteration here, which iteration
            # machinery can mistake for a normal end of iteration and
            # silently truncate the epoch. Raise a descriptive error instead.
            raise FileNotFoundError(
                f"No image for REFNUM {stem!r} in {self.dir} "
                f"(tried extensions: {', '.join(self._EXTS)})"
            )
        # The transform yields one channel; replicate it to get three channels.
        x = self.tf(Image.open(path).convert("L")).repeat(3, 1, 1)
        y = torch.tensor([self.tgts[idx]], dtype=torch.float32)
        return x, y

def split_mias():
    """Read the MIAS metadata CSV and return ``(train_ds, val_ds)``.

    Rows are split 80/20 with a fixed seed, stratified on the binary label
    (severity starting with "M" → 1, anything else or missing → 0).
    The training dataset uses augmentation, the validation dataset does not.
    """
    frame = pd.read_csv(MIAS_CSV)
    severity = frame["SEVERITY"].fillna("B").astype(str).str.upper()
    labels = severity.str.startswith("M").astype(int)

    # NOTE(review): the split is per CSV row. If the CSV lists the same
    # REFNUM on several rows, one image could land in both splits — confirm.
    positions = np.arange(len(frame))
    train_pos, val_pos = train_test_split(
        positions, test_size=0.2, stratify=labels, random_state=42
    )
    train_ds = MIASCSV(frame.iloc[train_pos], MIAS_IMG, True)
    val_ds = MIASCSV(frame.iloc[val_pos], MIAS_IMG, False)
    return train_ds, val_ds

# Build the stratified 80/20 train/validation datasets for MIAS.
mias_train, mias_val = split_mias()

# ───────── balanced & malignant-oversample dataset ─────────
class BalancedDataset(Dataset):
    """Interleave current-task samples with replay-buffer samples.

    With a non-empty replay buffer, even indices serve current-task samples
    and odd indices serve replay samples, so each source fills half of the
    virtual dataset; the shorter source wraps around. With an empty replay
    buffer the dataset is a plain view of ``cur_ds``.

    Args:
        cur_ds: current-task dataset (needs ``__len__`` / ``__getitem__``).
        rep_x: list of replayed inputs.
        rep_y: list of replayed targets, aligned with ``rep_x``.
        oversample: when true and ``cur_ds`` exposes ``pos`` / ``neg`` index
            lists, current-task slots are drawn at random with about 67 %
            probability of a positive sample instead of by position.
    """

    def __init__(self, cur_ds, rep_x, rep_y, oversample=False):
        self.cur, self.rep, self.repy = cur_ds, rep_x, rep_y
        self.oversample = oversample and hasattr(cur_ds, "pos")
        self.pos = cur_ds.pos if self.oversample else None
        self.neg = cur_ds.neg if self.oversample else None
        self.n = max(len(self.cur), len(self.rep)) * 2 if self.rep else len(self.cur)

    def __len__(self):
        return self.n

    def _draw_oversampled(self):
        """Pick a current-task index, favouring positives (p≈0.67).

        Falls back to the other pool when the chosen one is empty, so a
        single-class dataset does not crash ``random.choice``.
        """
        if random.random() < 0.67:
            primary, fallback = self.pos, self.neg
        else:
            primary, fallback = self.neg, self.pos
        return random.choice(primary or fallback)

    def __getitem__(self, idx):
        has_replay = bool(self.rep)
        if has_replay and idx % 2 == 1:        # replay slot
            r = (idx // 2) % len(self.rep)
            return self.rep[r], self.repy[r]
        if self.oversample:
            idx_cur = self._draw_oversampled()
        elif has_replay:
            # Current-task samples occupy the even slots only.
            idx_cur = (idx // 2) % len(self.cur)
        else:
            # No replay: length is len(cur), so map indices one-to-one.
            # Halving here would serve only the first half of the dataset.
            idx_cur = idx % len(self.cur)
        return self.cur[idx_cur]

# ───────────── lightning continual model ───────────────────
class CLModel(pl.LightningModule):
    """ViT-tiny encoder plus a linear head, trained task after task.

    Forgetting is countered with two mechanisms held on the instance:
    a replay buffer (``mem_x`` / ``mem_y``) and an EWC penalty anchored at
    ``ewc_mu`` with per-parameter importance ``ewc_f``.

    Args:
        ckpt_path: checkpoint whose ``state_dict`` holds the pretrained
            encoder weights under the ``student.`` prefix.
        init_lr: learning rate used by the default ``configure_optimizers``.
        wd: AdamW weight decay.
    """

    def __init__(self, ckpt_path, init_lr, wd=1e-4):
        super().__init__()
        self.save_hyperparameters(ignore=["ckpt_path"])

        # backbone
        self.encoder = timm.create_model(
            "vit_tiny_patch16_224", num_classes=0, global_pool="token"
        )
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # only load checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
        prefix = "student."
        student_sd = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
        load_info = self.encoder.load_state_dict(student_sd, strict=False)
        if not student_sd or load_info.missing_keys:
            # strict=False would otherwise hide a checkpoint that matches
            # nothing and leave the encoder at its default initialisation.
            import warnings  # local import: only needed on this diagnostic path
            warnings.warn(
                f"Encoder checkpoint load was partial: {len(student_sd)} "
                f"'{prefix}*' tensors found, {len(load_info.missing_keys)} "
                f"encoder keys missing, {len(load_info.unexpected_keys)} unexpected.",
                stacklevel=2,
            )

        # head: probe the encoder once to discover its feature width
        with torch.no_grad():
            z = self.encoder(torch.zeros(1, 3, 224, 224))
            z = z[:, 0] if z.ndim == 3 else z
        self.head = nn.Linear(z.shape[-1], 1)

        # misc
        self.base_lr = init_lr
        self.wd      = wd
        self.loss_fn = nn.BCEWithLogitsLoss()              # overridden per task
        self.example_input_array = torch.zeros(1, 3, 224, 224)

        # continual-learning state
        self.mem_x, self.mem_y = [], []                    # replay buffer
        self.ewc_mu, self.ewc_f = None, None               # EWC anchor / Fisher

    # ─────────── utilities ─────────────────────────────────
    def set_pos_weight(self, ds):
        """Set task-specific BCE pos_weight (neg/pos ratio) from dataset.

        For a ``BalancedDataset`` only the wrapped current-task dataset is
        counted. Falls back to a weight of 1 when either class is absent,
        because neg/pos would be undefined (no positives) or zero (no
        negatives, which would cancel the loss of every positive sample).
        """
        # NOTE(review): when the BalancedDataset also oversamples positives,
        # this weight additionally up-weights them — confirm that is intended.
        d = ds.cur if isinstance(ds, BalancedDataset) else ds
        pos = sum(d.tgts)
        neg = len(d) - pos
        ratio = neg / pos if pos > 0 and neg > 0 else 1.0
        pw = torch.tensor([ratio])
        self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw.to(self.device))

    @torch.no_grad()
    def add_to_memory(self, dataset):
        """Add up to ``MEM_SIZE`` random samples of *dataset* to the replay buffer.

        Samples are stored on the CPU exactly as ``dataset`` yields them.
        The buffer keeps only its newest ``MEM_SIZE`` entries.
        """
        sel = random.sample(range(len(dataset)), k=min(MEM_SIZE, len(dataset)))
        for xs, ys in DataLoader(Subset(dataset, sel), batch_size=64, num_workers=0):
            for xi, yi in zip(xs, ys):
                self.mem_x.append(xi.cpu())
                self.mem_y.append(yi.cpu())
        self.mem_x, self.mem_y = self.mem_x[-MEM_SIZE:], self.mem_y[-MEM_SIZE:]

    def compute_fisher(self, dataset, samples=600, batch_size=32):
        """Estimate a diagonal Fisher and store it with the current weights.

        Squared gradients of the batch loss are averaged over roughly
        ``samples`` examples. The result replaces any previously stored
        ``ewc_mu`` / ``ewc_f`` (it is not accumulated across tasks).

        Args:
            dataset: data on which to estimate parameter importance.
            samples: stop once this many samples (in whole batches) were used.
            batch_size: mini-batch size of the estimation loader.

        Raises:
            ValueError: if ``dataset`` yields no batch (the estimate would
                otherwise be 0/0 = NaN and poison the EWC penalty).
        """
        named = list(self.named_parameters())
        fisher = {n: torch.zeros_like(p) for n, p in named}
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

        was_training = self.training
        self.eval()                       # deterministic forward for the estimate
        cnt = 0
        try:
            for x, y in loader:
                x, y = x.to(self.device), y.to(self.device)
                loss = self.loss_fn(self(x).squeeze(1), y.squeeze(1))
                self.zero_grad()
                loss.backward()
                for n, p in named:
                    # Frozen or unused parameters have no gradient; their
                    # Fisher entry stays at zero.
                    if p.grad is not None:
                        fisher[n] += p.grad.detach() ** 2
                cnt += 1
                if cnt * batch_size >= samples:
                    break
        finally:
            self.zero_grad()              # do not leak gradients into training
            self.train(was_training)

        if cnt == 0:
            raise ValueError("compute_fisher received an empty dataset")
        for n in fisher:
            fisher[n] /= cnt
        self.ewc_mu = [p.detach().clone() for _, p in named]
        self.ewc_f  = [fisher[n] for n, _ in named]

    # ─────────── forward / training / optim ────────────────
    def forward(self, x):
        """Return one logit per sample, shape ``(batch, 1)``."""
        z = self.encoder(x)
        z = z[:, 0] if z.ndim == 3 else z     # take the first token if a sequence is returned
        return self.head(z)

    def _total_loss(self, logits, y):
        """Task loss plus, once a Fisher estimate exists, the EWC penalty."""
        loss = self.loss_fn(logits, y)
        if self.ewc_mu is not None:
            # ewc_mu / ewc_f are plain lists, so they do not follow the module
            # across devices; align them with each parameter explicitly.
            penalty = sum(
                (f.to(p.device) * (p - m.to(p.device)).pow(2)).sum()
                for p, m, f in zip(self.parameters(), self.ewc_mu, self.ewc_f)
            )
            loss = loss + EWC_LMB * penalty
        return loss

    def training_step(self, batch, _):
        x, y  = batch
        loss  = self._total_loss(self(x).squeeze(1), y.squeeze(1))
        self.log("loss", loss, prog_bar=True)
        return loss                                # Lightning handles grad-acc

    def configure_optimizers(self):
        """Default AdamW + cosine schedule; ``run_phase`` overrides this per phase."""
        opt   = torch.optim.AdamW(self.parameters(), lr=self.base_lr, weight_decay=self.wd)
        sched = CosineAnnealingLR(opt, T_max=20)   # will be patched per phase
        return [opt], [sched]

# ──────────── metrics helper ───────────────────────────────
def auc_on(model, ds):
    """Return the binary AUROC of *model* on dataset *ds*.

    The model is moved to ``DEVICE`` and evaluated in eval mode without
    gradients. Its previous train/eval mode is restored afterwards, so an
    evaluation between phases does not leave the model in eval mode for the
    next training run.
    """
    was_training = model.training
    model.eval().to(DEVICE)
    metric = BinaryAUROC().to(DEVICE)
    try:
        with torch.no_grad():
            for x, y in DataLoader(ds, batch_size=64, num_workers=0):
                probs = torch.sigmoid(model(x.to(DEVICE)).squeeze(1))
                metric.update(probs, y.to(DEVICE).int().squeeze(1))
    finally:
        model.train(was_training)
    return metric.compute().item()

# ───────── helper to run one phase with its own opt/sched ──
def run_phase(model, train_ds, epochs, lr, t_max):
    """Train *model* on *train_ds* for one phase with a fresh optimizer/scheduler.

    Args:
        model: the ``CLModel`` to train (modified in place).
        train_ds: training dataset; also used to derive the BCE pos_weight.
        epochs: number of epochs for this phase.
        lr: AdamW learning rate for this phase.
        t_max: ``T_max`` of the cosine learning-rate schedule.
    """
    model.set_pos_weight(train_ds)

    def _cfg(self):
        opt   = torch.optim.AdamW(self.parameters(), lr=lr, weight_decay=self.wd)
        sched = CosineAnnealingLR(opt, T_max=t_max)
        return [opt], [sched]

    # Temporarily swap in the phase-specific optimizer factory.
    original = model.configure_optimizers
    model.configure_optimizers = _cfg.__get__(model, model.__class__)

    # A preceding evaluation may have left the model in eval mode; make sure
    # the phase trains in train mode.
    model.train()

    try:
        pl.Trainer(
            max_epochs           = epochs,
            accelerator          = DEVICE,
            devices              = 1,
            log_every_n_steps    = 10,
            gradient_clip_val    = 1.0,
            accumulate_grad_batches = GRAD_ACC,
        ).fit(model, DataLoader(train_ds, BATCH, shuffle=True, num_workers=0))
    finally:
        # Restore even when training is interrupted or fails, so the model
        # is not left with a stale per-phase optimizer factory.
        model.configure_optimizers = original

# ─────────────────── training pipeline ─────────────────────
# Sequential training script: Shenzhen → Montgomery → MIAS (two phases),
# followed by a final evaluation on all three validation sets.
# Statement order matters (seeded RNG, replay buffer and EWC state evolve).
pl.seed_everything(42)
model = CLModel(CKPT, init_lr=LRS["shen"]).to(DEVICE)

# Task 0 – Shenzhen ----------------------------------------
# T_max is set to twice the epoch count in every phase.
run_phase(model, shen_train, EPOCHS["shen"],
          lr=LRS["shen"], t_max=EPOCHS["shen"] * 2)
shen_auc0 = auc_on(model, shen_val)
print(f"\nShenzhen AUC after Task 0: {shen_auc0:.3f}")

# Store Shenzhen samples for replay and anchor EWC at the current weights.
model.add_to_memory(shen_train); model.compute_fisher(shen_train)

# Task 1 – Montgomery --------------------------------------
# base_lr only feeds the model's default configure_optimizers; run_phase
# builds its optimizer from the `lr` argument passed below.
model.base_lr = LRS["mont"]
train1 = BalancedDataset(mont_train, model.mem_x, model.mem_y)
run_phase(model, train1, EPOCHS["mont"],
          lr=LRS["mont"], t_max=EPOCHS["mont"] * 2)
mont_auc = auc_on(model, mont_val)
print(f"Montgomery AUC: {mont_auc:.3f}")

# compute_fisher overwrites the stored EWC anchor/Fisher, so from here on the
# penalty is based on Montgomery only; the replay buffer keeps the newest
# MEM_SIZE samples across both tasks.
model.add_to_memory(mont_train); model.compute_fisher(mont_train)

# Task 2 – MIAS (Phase-1 & Phase-2) -------------------------
# freeze all encoder parameters first; the head stays trainable
for p in model.encoder.parameters(): p.requires_grad = False
for p in model.head.parameters(): p.requires_grad = True

# Phase-1: unfreeze encoder blocks from index 9 onwards
# (the last 3 blocks, assuming a 12-block encoder — TODO confirm)
for blk in model.encoder.blocks[9:]:
    for p in blk.parameters(): p.requires_grad = True
train2a = BalancedDataset(mias_train, model.mem_x, model.mem_y, oversample=True)
run_phase(model, train2a, EPOCHS["mias"][0],
          lr=LRS["mias_head"], t_max=EPOCHS["mias"][0] * 2)

# Phase-2: unfreeze blocks 6-8 as well
for blk in model.encoder.blocks[6:9]:
    for p in blk.parameters(): p.requires_grad = True
train2b = BalancedDataset(mias_train, model.mem_x, model.mem_y, oversample=True)
run_phase(model, train2b, EPOCHS["mias"][1],
          lr=LRS["mias_mid"], t_max=EPOCHS["mias"][1] * 2)

# ─────────────── final metrics ─────────────────────────────
# Re-evaluate every task with the final weights.
shen_auc_f = auc_on(model, shen_val)
mont_auc_f = auc_on(model, mont_val)
mias_auc_f = auc_on(model, mias_val)

print("\n──────── Final AUCs ────────")
print(f"Shenzhen:   {shen_auc_f:.3f}")
print(f"Montgomery: {mont_auc_f:.3f}")
print(f"MIAS:       {mias_auc_f:.3f}")
# Backward transfer: final Shenzhen AUC minus the AUC right after Task 0.
print(f"Backward Transfer Δ (Shenzhen): {shen_auc_f - shen_auc0:+.3f}")


Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode  | In sizes         | Out sizes
-------------------------------------------------------------------------------------
0 | encoder | VisionTransformer | 5.5 M  | train | [1, 3, 224, 224] | [1, 192] 
1 | head    | Linear            | 193    | train | [1, 192]         | [1, 1]   
2 | loss_fn | BCEWithLogitsLoss | 0      | train | ?                | ?        
-------------------------------------------------------------------------------------
5.5 M     Trainable params
0         Non-trainable params
5.5 M     Total params
22.098    Total estimated model params size (MB)
C:\Users\offic\anaconda3\envs\medssl\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider inc

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.



Shenzhen AUC after Task 0: 0.758


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode  | In sizes         | Out sizes
-------------------------------------------------------------------------------------
0 | encoder | VisionTransformer | 5.5 M  | eval  | [1, 3, 224, 224] | [1, 192] 
1 | head    | Linear            | 193    | eval  | [1, 192]         | [1, 1]   
2 | loss_fn | BCEWithLogitsLoss | 0      | train | ?                | ?        
-------------------------------------------------------------------------------------
5.5 M     Trainable params
0         Non-trainable params
5.5 M     Total params
22.098    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=25` reached.


Montgomery AUC: 0.724


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode  | In sizes         | Out sizes
-------------------------------------------------------------------------------------
0 | encoder | VisionTransformer | 5.5 M  | eval  | [1, 3, 224, 224] | [1, 192] 
1 | head    | Linear            | 193    | eval  | [1, 192]         | [1, 1]   
2 | loss_fn | BCEWithLogitsLoss | 0      | train | ?                | ?        
-------------------------------------------------------------------------------------
1.3 M     Trainable params
4.2 M     Non-trainable params
5.5 M     Total params
22.098    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=12` reached.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode  | In sizes         | Out sizes
-------------------------------------------------------------------------------------
0 | encoder | VisionTransformer | 5.5 M  | eval  | [1, 3, 224, 224] | [1, 192] 
1 | head    | Linear            | 193    | eval  | [1, 192]         | [1, 1]   
2 | loss_fn | BCEWithLogitsLoss | 0      | train | ?                | ?        
-------------------------------------------------------------------------------------
2.7 M     Trainable params
2.9 M     Non-trainable params
5.5 M     Total params
22.098    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=25` reached.



──────── Final AUCs ────────
Shenzhen:   0.760
Montgomery: 0.724
MIAS:       0.696
Backward Transfer Δ (Shenzhen): +0.002
