In [3]:
from pathlib import Path
import random, numpy as np, pandas as pd, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
import timm, pytorch_lightning as pl
from torchmetrics.classification import BinaryAUROC

# ────────────────────────────────────────────────────────────
#  Paths  (edit if yours differ)
# ────────────────────────────────────────────────────────────
CKPT     = Path(r"C:\Users\offic\medself\shenzhen_ckpts\epochepoch=11-aucval_auc=0.896.ckpt")
SHEN_CSV = Path(r"C:\Users\offic\OneDrive\Masaüstü\datasets\SelfSupervised\Shenzhen\shenzhen_metadata.csv")
SHEN_IMG = Path(r"C:\Users\offic\OneDrive\Masaüstü\datasets\SelfSupervised\Shenzhen\images\images")
MONT_CSV = Path(r"C:\Users\offic\OneDrive\Masaüstü\datasets\SelfSupervised\Montgomery\montgomery_metadata.csv")
MONT_IMG = Path(r"C:\Users\offic\OneDrive\Masaüstü\datasets\SelfSupervised\Montgomery\images\images")
for p in (CKPT, SHEN_CSV, SHEN_IMG, MONT_CSV, MONT_IMG): assert p.exists(), p

# ────────────────────────────────────────────────────────────
#  Hyper-params & constants
# ────────────────────────────────────────────────────────────
DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"
MEM_SIZE      = 200
EWC_LAMBDA    = 0.05
SHEN_EPOCHS   = 5
MONT_EPOCHS   = 15
BATCH_SIZE    = 32

# ────────────────────────────────────────────────────────────
#  Transforms & Dataset
# ────────────────────────────────────────────────────────────
def med_tf(train=True, aug=0.5):
    ops = [transforms.Resize((224, 224))]
    if train:
        ops.append(transforms.RandomHorizontalFlip(p=aug))
    ops += [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    return transforms.Compose(ops)

class CXRCSV(Dataset):
    def __init__(self, df, img_dir, train=True, aug=0.5):
        self.df = df.reset_index(drop=True)
        self.dir = img_dir
        self.tf  = med_tf(train, aug)
        self.targets = [0 if f.lower().strip()=="normal" else 1
                        for f in self.df["findings"]]

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        x = self.tf(Image.open(self.dir / self.df.iloc[idx]["study_id"]).convert("L")).repeat(3,1,1)
        y = torch.tensor([self.targets[idx]], dtype=torch.float32)
        return x, y

def make_split(csv_path, img_path, seed=42):
    df = pd.read_csv(csv_path)
    y  = df["findings"].apply(lambda s: 0 if s.lower().strip()=="normal" else 1)
    tr, val = train_test_split(np.arange(len(df)), stratify=y,
                               test_size=0.2, random_state=seed)
    return (CXRCSV(df.iloc[tr],  img_path, True),
            CXRCSV(df.iloc[val], img_path, False))

shen_train, shen_val = make_split(SHEN_CSV,  SHEN_IMG)
mont_train, mont_val = make_split(MONT_CSV,  MONT_IMG)

# ────────────────────────────────────────────────────────────
#  Continual-learning LightningModule
# ────────────────────────────────────────────────────────────
class CLModel(pl.LightningModule):
    def __init__(self, ckpt_path, lr=1e-4, wd=1e-4):
        super().__init__()
        # SSL ViT encoder
        self.encoder = timm.create_model("vit_tiny_patch16_224",
                                         num_classes=0, global_pool="token")
        sd = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
        self.encoder.load_state_dict({k.replace("student.",""):v for k,v in sd.items()
                                      if k.startswith("student.")}, strict=False)
        # Detect CLS dim robustly
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 224, 224)
            feat  = self.encoder(dummy)
            if feat.ndim == 3: feat = feat[:,0]
            self.feat_dim = feat.shape[-1]

        self.head = nn.Linear(self.feat_dim, 1)
        self.lr, self.wd = lr, wd
        self.crit = nn.BCEWithLogitsLoss()

        # Replay / EWC storage
        self.mem_x, self.mem_y = [], []
        self.ewc_mu = None
        self.ewc_fisher = None   #  <<< fixed: no unpacking error

    # Forward returns (B,1)
    def forward(self, x):
        z = self.encoder(x)
        if z.ndim == 3: z = z[:,0]
        return self.head(z)

    def training_step(self, batch, _):
        x, y = batch
        logits = self(x).squeeze(1)
        loss = self.crit(logits, y.squeeze(1))
        if self.ewc_mu is not None:
            loss += EWC_LAMBDA * sum((f * (p - m).pow(2)).sum()
                                     for p, m, f in zip(self.parameters(),
                                                        self.ewc_mu,
                                                        self.ewc_fisher))
        self.log("loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)

    # ---------- Replay helpers ----------
    @torch.no_grad()
    def add_to_memory(self, dataset):
        sel = random.sample(range(len(dataset)),
                            k=min(MEM_SIZE, len(dataset)))
        for xs, ys in DataLoader(Subset(dataset, sel), batch_size=64):
            # store individual samples   <<< fixed
            for xi, yi in zip(xs, ys):
                self.mem_x.append(xi.cpu())
                self.mem_y.append(yi.cpu())
        self.mem_x, self.mem_y = self.mem_x[-MEM_SIZE:], self.mem_y[-MEM_SIZE:]

    # ---------- Fisher computation ----------
    def compute_fisher(self, dataset, samples=600):
        fisher = {n: torch.zeros_like(p) for n, p in self.named_parameters()}
        loader = DataLoader(dataset, batch_size=32, shuffle=True)
        cnt = 0
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            loss = self.crit(self(x).squeeze(1), y.squeeze(1))
            self.zero_grad(); loss.backward()
            for (n, p) in self.named_parameters():
                fisher[n] += p.grad.detach() ** 2
            cnt += 1
            if cnt * 32 >= samples: break
        for n in fisher: fisher[n] /= cnt
        self.ewc_mu     = [p.detach().clone() for p in self.parameters()]
        self.ewc_fisher = [fisher[n] for n, _ in self.named_parameters()]

# ────────────────────────────────────────────────────────────
#  Evaluation helper
# ────────────────────────────────────────────────────────────
def auc_on(model, dataset):
    model.to(DEVICE).eval()
    auc = BinaryAUROC().to(DEVICE)
    with torch.no_grad():
        for x, y in DataLoader(dataset, batch_size=64):
            auc.update(torch.sigmoid(model(x.to(DEVICE)).squeeze(1)),
                       y.to(DEVICE).int().squeeze(1))
    return auc.compute().item()

# ────────────────────────────────────────────────────────────
#  Phase 0 – Train on Shenzhen
# ────────────────────────────────────────────────────────────
pl.seed_everything(42)
model = CLModel(CKPT).to(DEVICE)

pl.Trainer(max_epochs=SHEN_EPOCHS, accelerator=DEVICE,
           devices=1, log_every_n_steps=20).fit(
    model, DataLoader(shen_train, BATCH_SIZE, shuffle=True, num_workers=0))

shen_auc_before = auc_on(model, shen_val)
print(f"\nShenzhen AUC after Task 0: {shen_auc_before:.3f}")

model.add_to_memory(shen_train)
model.compute_fisher(shen_train)

# ────────────────────────────────────────────────────────────
#  Phase 1 – Train on Montgomery + replay
# ────────────────────────────────────────────────────────────
class ReplayDS(Dataset):
    def __init__(self, xs, ys): self.xs, self.ys = xs, ys
    def __len__(self): return len(self.xs)
    def __getitem__(self, i):  return self.xs[i], self.ys[i]

combo_train = ConcatDataset([mont_train,
                             ReplayDS(model.mem_x, model.mem_y)])

pl.Trainer(max_epochs=MONT_EPOCHS, accelerator=DEVICE,
           devices=1, log_every_n_steps=20).fit(
    model, DataLoader(combo_train, BATCH_SIZE, shuffle=True, num_workers=0))

# ────────────────────────────────────────────────────────────
#  Final metrics
# ────────────────────────────────────────────────────────────
shen_auc_after = auc_on(model, shen_val)
mont_auc       = auc_on(model, mont_val)

print(f"\nFinal Shenzhen AUC:   {shen_auc_after:.3f}")
print(f"Montgomery AUC:       {mont_auc:.3f}")
print(f"Backward Transfer Δ:  {shen_auc_after - shen_auc_before:+.3f}")


Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | encoder | VisionTransformer | 5.5 M  | train
1 | head    | Linear            | 193    | train
2 | crit    | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
5.5 M     Trainable params
0         Non-trainable params
5.5 M     Total params
22.098    Total estimated model params size (MB)
C:\Users\offic\anaconda3\envs\medssl\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
C:\Users\offic\anaconda3\envs\medssl\Lib\site-packages\pytorch_lightning\

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.



Shenzhen AUC after Task 0: 0.797


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode
-----------------------------------------------------
0 | encoder | VisionTransformer | 5.5 M  | eval
1 | head    | Linear            | 193    | eval
2 | crit    | BCEWithLogitsLoss | 0      | eval
-----------------------------------------------------
5.5 M     Trainable params
0         Non-trainable params
5.5 M     Total params
22.098    Total estimated model params size (MB)
C:\Users\offic\anaconda3\envs\medssl\Lib\site-packages\pytorch_lightning\loops\fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=20). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=15` reached.



Final Shenzhen AUC:   0.852
Montgomery AUC:       0.797
Backward Transfer Δ:  +0.055
