In [29]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

from models.mimic_cxr.models import CXREncoder, LatentSplitter, CXRDecoder, CXRModel
from data.chexpert_subset import get_chexpert_weighted_sampler

import torchvision.transforms as T
from sklearn.metrics import accuracy_score, roc_auc_score
from torch.utils.data import DataLoader, WeightedRandomSampler
import torch.optim as optim
import torch.nn as nn
from data.chexpert_subset import CheXpertSubsetDataset

In [30]:
device = torch.device("cpu")

In [31]:
def make_loaders(root, batch_size=8, image_size=224):
    # For end-to-end training with reconstruction, it's simpler to keep images in [0,1]
    transform = T.Compose([
        T.Resize((image_size, image_size)),
        T.ToTensor(),          # [0,1]
        # no normalization here; decoder outputs [0,1] via Sigmoid
    ])

    train_ds = CheXpertSubsetDataset(
        root=root,
        split="val",                    # we use "val" as the training subset
        pathology="Pleural Effusion",
        transform=transform,
    )

    test_ds = CheXpertSubsetDataset(
        root=root,
        split="test",
        pathology="Pleural Effusion",
        transform=transform,
    )

    # Compute class-balanced sample weights for Pleural Effusion
    labels = torch.tensor(
        [train_ds.df.iloc[i][train_ds.label_idx] for i in range(len(train_ds))],
        dtype=torch.float32,
    )
    N = len(labels)
    N_pos = labels.sum()
    N_neg = N - N_pos

    N_pos = max(N_pos, 1.0)
    N_neg = max(N_neg, 1.0)

    weight_pos = N / (2.0 * N_pos)
    weight_neg = N / (2.0 * N_neg)

    sample_weights = torch.where(
        labels == 1.0,
        torch.full_like(labels, weight_pos),
        torch.full_like(labels, weight_neg),
    )

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=4,
    )

    test_loader = DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )

    return train_loader, test_loader

In [32]:
def train_one_epoch(model, loader, optimizer, device,
                    lambda_sig=1.0, lambda_rec=20.0, lambda_cons=10.0, lambda_orth=1.0):
    model.train()

    bce = nn.BCEWithLogitsLoss()
    l1 = nn.L1Loss()
    mse = nn.MSELoss()

    total_loss = 0.0

    for batch in loader:
        imgs = batch["image"].to(device)             # (B, 3, 224, 224)
        labels = batch["label"].to(device).unsqueeze(1)  # (B, 1)

        # Forward
        z, z_sig, z_nui, x_recon, logits = model(imgs)

        # Reconstruction loss (image space, [0,1])
        L_rec = l1(x_recon, imgs)

        # Signal/probe loss (Pleural Effusion)
        L_sig = bce(logits, labels)

        # Consistency: encode reconstructed image and compare latents
        z_re = model.encoder(x_recon)           # (B, latent_dim)
        L_cons = mse(z, z_re)

        # ------------------------------------------------------------
        # ✔ Orthogonality loss (new)
        # ------------------------------------------------------------
        # Minimize correlation between z_sig and z_nui:
        # If vectors align, elementwise product is large.
        # Encourages the two subspaces to encode different factors.
        L_orth = torch.mean(z_sig * z_nui)

        # Combine all losses
        loss = (
            lambda_sig * L_sig +
            lambda_rec * L_rec +
            lambda_cons * L_cons +
            lambda_orth * L_orth
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)


In [33]:
def eval_probe_auc(model, loader, device):
    model.eval()
    preds, targets = [], []

    with torch.no_grad():
        for batch in loader:
            imgs = batch["image"].to(device)
            labels = batch["label"].cpu().numpy().tolist()

            _, z_sig, _, _, logits = model(imgs)
            probs = torch.sigmoid(logits).cpu().numpy().flatten().tolist()

            preds.extend(probs)
            targets.extend(labels)

    # Guard against degenerate case (all labels the same)
    if len(set(targets)) < 2:
        return float("nan")

    auc = roc_auc_score(targets, preds)
    return auc

In [34]:
def train_probe_warmup(model, loader, optimizer, device):
    model.train()
    bce = nn.BCEWithLogitsLoss()
    total_loss = 0.0

    for batch in loader:
        imgs = batch["image"].to(device)
        labels = batch["label"].to(device).unsqueeze(1)

        with torch.no_grad():
            z = model.encoder(imgs)
            z_sig, z_nui = model.splitter(z)
        logits = model.probe(z_sig)

        L = bce(logits, labels)

        optimizer.zero_grad()
        L.backward()
        optimizer.step()

        total_loss += L.item()

    return total_loss / len(loader)


In [35]:
def main():
    root = "./chexlocalize_download"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_size = 8
    epochs = 20

    lambda_sig = 1.0
    lambda_rec = 20.0
    lambda_cons = 10.0

    warmup_epochs = 5

    train_loader, test_loader = get_chexpert_weighted_sampler(root, batch_size=8)

    model = CXRModel(latent_dim=1024, split_dim=512, pretrained=True).to(device)

    ###########################################################
    # Stage 1 — Warm-up: freeze encoder, train splitter + probe
    ###########################################################

    # Freeze encoder
    for p in model.encoder.parameters():
        p.requires_grad = False

    # Warm-up optimizer (only splitter + probe)
    warmup_optimizer = optim.Adam(
        list(model.probe.parameters()),
        lr=1e-3
    )

    # Warm-up loop
    for epoch in range(1, warmup_epochs + 1):
        warmup_loss = train_probe_warmup(model, train_loader, warmup_optimizer, device)
        warmup_auc = eval_probe_auc(model, test_loader, device)
        print(f"[Warm-up] Epoch {epoch:02d} | Loss: {warmup_loss:.4f} | AUC: {warmup_auc:.4f}")

    ###########################################################
    # Stage 2 — End-to-end training: unfreeze encoder + full loss
    ###########################################################

    for p in model.encoder.parameters():
        p.requires_grad = True

    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)


    os.makedirs("checkpoints", exist_ok=True)

    for epoch in range(1, epochs + 1):
        train_loss = train_one_epoch(
            model, train_loader, optimizer, device,
            lambda_sig=lambda_sig,
            lambda_rec=lambda_rec,
            lambda_cons=lambda_cons,
        )
        auc = eval_probe_auc(model, test_loader, device)

        print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | Test AUC: {auc:.4f}")

        # Save a checkpoint every few epochs
        if epoch % 5 == 0:
            ckpt_path = f"artifacts/chexlocalize/cxr_fullmodel_epoch{epoch}.pt"
            torch.save(model.state_dict(), ckpt_path)
            print(f"Saved checkpoint: {ckpt_path}")


if __name__ == "__main__":
    main()

  [val_ds.df.iloc[i][val_ds.label_idx] for i in range(len(val_ds))],


[Warm-up] Epoch 01 | Loss: 0.7041 | AUC: 0.5271
[Warm-up] Epoch 02 | Loss: 0.6612 | AUC: 0.5837
[Warm-up] Epoch 03 | Loss: 0.6548 | AUC: 0.6442
[Warm-up] Epoch 04 | Loss: 0.5993 | AUC: 0.6393
[Warm-up] Epoch 05 | Loss: 0.5607 | AUC: 0.6620
Epoch 01 | Train Loss: 19.2372 | Test AUC: 0.5596
Epoch 02 | Train Loss: 18.5931 | Test AUC: 0.5383
Epoch 03 | Train Loss: 18.1977 | Test AUC: 0.5856
Epoch 04 | Train Loss: 17.5197 | Test AUC: 0.6027
Epoch 05 | Train Loss: 17.1262 | Test AUC: 0.5679
Saved checkpoint: artifacts/chexlocalize/cxr_fullmodel_epoch5.pt
Epoch 06 | Train Loss: 16.4056 | Test AUC: 0.5968
Epoch 07 | Train Loss: 15.8921 | Test AUC: 0.5878
Epoch 08 | Train Loss: 15.4719 | Test AUC: 0.5922
Epoch 09 | Train Loss: 15.0300 | Test AUC: 0.6054
Epoch 10 | Train Loss: 14.8441 | Test AUC: 0.6104
Saved checkpoint: artifacts/chexlocalize/cxr_fullmodel_epoch10.pt
Epoch 11 | Train Loss: 14.3131 | Test AUC: 0.6396
Epoch 12 | Train Loss: 14.1529 | Test AUC: 0.6516
Epoch 13 | Train Loss: 13.658

In [36]:
def collect_latents(model, root, max_samples=1000, image_size=224, device="cpu"):
    transform = T.Compose([
        T.Resize((image_size, image_size)),
        T.ToTensor(),   # [0,1]
    ])

    ds = CheXpertSubsetDataset(
        root=root,
        split="val",
        pathology="Pleural Effusion",
        transform=transform
    )
    loader = DataLoader(ds, batch_size=16, shuffle=True, num_workers=4)

    all_z_sig = []
    all_z_nui = []
    all_labels = []
    all_views = []

    with torch.no_grad():
        for batch in loader:
            imgs = batch["image"].to(device)
            labels = batch["label"].cpu().numpy()
            views = batch["view"].cpu().numpy()

            # full model forward to get z_sig, z_nui
            _, z_sig, z_nui, _, _ = model(imgs)

            all_z_sig.append(z_sig.cpu().numpy())
            all_z_nui.append(z_nui.cpu().numpy())
            all_labels.append(labels)
            all_views.append(views)

            if sum(len(x) for x in all_labels) >= max_samples:
                break

    Z_sig = np.concatenate(all_z_sig, axis=0)[:max_samples]
    Z_nui = np.concatenate(all_z_nui, axis=0)[:max_samples]
    y = np.concatenate(all_labels, axis=0)[:max_samples]
    v = np.concatenate(all_views, axis=0)[:max_samples]

    return Z_sig, Z_nui, y, v

In [37]:
def run_tsne(Z, color, legend, title, out_path):
    print(f"Running TSNE for {title} ...")
    tsne = TSNE(n_components=2, perplexity=30, learning_rate="auto", init="pca")
    Z2 = tsne.fit_transform(Z)

    plt.figure(figsize=(6,5))
    for key in np.unique(color):
        mask = color == key
        plt.scatter(Z2[mask, 0], Z2[mask, 1], s=8, alpha=0.7, label=legend[key])
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    print(f"Saved: {out_path}")

In [38]:
def load_model(checkpoint_path, device):
    model = CXRModel(
        latent_dim=1024,
        split_dim=512,
        pretrained=True  # still load ImageNet weights for encoder init
    )
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model

In [39]:
root = "./chexlocalize_download"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose which checkpoint you want to visualize
checkpoint = "artifacts/chexlocalize/cxr_fullmodel_epoch10.pt"

model = load_model(checkpoint, device)

print("Collecting latents...")
Z_sig, Z_nui, labels, views = collect_latents(
    model, root, max_samples=800, device=device
)

# Pleural Effusion (0 or 1)
effusion_legend = {0.0: "No Effusion", 1.0: "Effusion"}

# View labels: 0=AP, 1=PA, 2=Lateral
view_legend = {0: "AP", 1: "PA", 2: "Lateral"}

Collecting latents...


In [40]:
run_tsne(
    Z_sig,
    labels,
    effusion_legend,
    "t-SNE: z_sig (colored by Pleural Effusion)",
    "tsne_z_sig_effusion_w_ortho.png"
)

run_tsne(
    Z_nui,
    views,
    view_legend,
    "t-SNE: z_nui (colored by View AP/PA/Lateral)",
    "tsne_z_nui_view_w_ortho.png"
)

print("t-SNE visualization complete.")

Running TSNE for t-SNE: z_sig (colored by Pleural Effusion) ...
Saved: tsne_z_sig_effusion_w_ortho.png
Running TSNE for t-SNE: z_nui (colored by View AP/PA/Lateral) ...
Saved: tsne_z_nui_view_w_ortho.png
t-SNE visualization complete.
