## Environment

In [49]:
import os, json, random, csv, platform
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
from tqdm import tqdm

SEED = 114514
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
DEVICE = ("mps" if torch.backends.mps.is_available()
          else "cuda" if torch.cuda.is_available()
          else "cpu")
print(f"[INFO] Device = {DEVICE}")

OUT = Path("outputs/baseline"); OUT.mkdir(parents=True, exist_ok=True)
OUT_EDA = Path("outputs/eda"); OUT_EDA.mkdir(parents=True, exist_ok=True)

with open(OUT/"run_info.json","w") as f:
    json.dump({"seed": SEED, "device": DEVICE,
               "python": f"{os.sys.version_info.major}.{os.sys.version_info.minor}.{os.sys.version_info.micro}",
               "torch": torch.__version__}, f, indent=2)


[INFO] Device = mps


## Load Data

In [50]:
ds = load_dataset("mmenendezg/pneumonia_x_ray")  # splits: train/validation/test
ds


DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 4187
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1045
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 624
    })
})

## EDA

In [51]:
def split_stats(split, sample_k=200, seed=SEED):
    n = len(split)
    labels = np.array(split["label"], dtype=int)
    cnt = np.bincount(labels, minlength=2)
    k = min(sample_k, n)
    rng = random.Random(seed)
    idx = rng.sample(range(n), k) if n > 0 else []
    w_list, h_list = [], []
    for i in idx:
        w, h = split[i]["image"].size
        w_list.append(w); h_list.append(h)
    if k == 0:
        size_mean = [0.0, 0.0]; size_min = [0,0]; size_max=[0,0]
    else:
        size_mean = [float(np.mean(w_list)), float(np.mean(h_list))]
        size_min  = [int(np.min(w_list)),   int(np.min(h_list))]
        size_max  = [int(np.max(w_list)),   int(np.max(h_list))]
    return {"n": int(n),
            "class_count": {0:int(cnt[0]), 1:int(cnt[1])},
            "size_sample_mean": size_mean,
            "size_sample_min":  size_min,
            "size_sample_max":  size_max}

stats = {"train": split_stats(ds["train"]),
         "validation": split_stats(ds["validation"]),
         "test": split_stats(ds["test"])}

with open(OUT_EDA/"data_stats.json","w") as f:
    json.dump(stats, f, indent=2)
stats


{'train': {'n': 4187,
  'class_count': {0: 1080, 1: 3107},
  'size_sample_mean': [500.0, 500.0],
  'size_sample_min': [500, 500],
  'size_sample_max': [500, 500]},
 'validation': {'n': 1045,
  'class_count': {0: 269, 1: 776},
  'size_sample_mean': [500.0, 500.0],
  'size_sample_min': [500, 500],
  'size_sample_max': [500, 500]},
 'test': {'n': 624,
  'class_count': {0: 234, 1: 390},
  'size_sample_mean': [500.0, 500.0],
  'size_sample_min': [500, 500],
  'size_sample_max': [500, 500]}}

## EDA Show

In [52]:
def plot_class_bar(stats, title="Class Distribution (0=normal, 1=pneumonia)",
                   savepath=OUT_EDA/"class_dist.png", log_y=False):
    savepath = Path(savepath); savepath.parent.mkdir(parents=True, exist_ok=True)
    labels = ["normal(0)", "pneumonia(1)"]
    def counts_of(s): d=s["class_count"]; return [int(d.get(0,0)), int(d.get(1,0))]
    train_cnt = counts_of(stats["train"]); val_cnt = counts_of(stats["validation"]); test_cnt = counts_of(stats["test"])
    x=np.arange(2); w=0.25
    fig, ax = plt.subplots(figsize=(7,4.2))
    b1=ax.bar(x-w, train_cnt, width=w, label="train")
    b2=ax.bar(x,   val_cnt,   width=w, label="val")
    b3=ax.bar(x+w, test_cnt,  width=w, label="test")
    ax.set_xticks(x); ax.set_xticklabels(labels); ax.set_ylabel("count"); ax.set_title(title)
    if log_y: ax.set_yscale("log")
    ax.legend()
    for bars in (b1,b2,b3):
        for rect in bars:
            h=rect.get_height()
            ax.annotate(f"{int(h)}", (rect.get_x()+rect.get_width()/2, h),
                        xytext=(0,3), textcoords="offset points",
                        ha="center", va="bottom", fontsize=8)
    fig.tight_layout(); fig.savefig(savepath, dpi=220); plt.close(fig)
    return savepath

_ = plot_class_bar(stats)


## Trun in Dataswt

In [53]:
IMAGENET_MEAN=[0.485,0.456,0.406]; IMAGENET_STD=[0.229,0.224,0.225]
tx_train = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
tx_eval = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

class HFDataset(Dataset):
    def __init__(self, split, t): self.ds, self.t = split, t
    def __len__(self): return len(self.ds)
    def __getitem__(self, i):
        ex = self.ds[i]
        img = ex["image"].convert("RGB")
        x = self.t(img)
        y = int(ex["label"])
        return x, y

train_set = HFDataset(ds["train"], tx_train)
val_set   = HFDataset(ds["validation"], tx_eval)
test_set  = HFDataset(ds["test"], tx_eval)


## Load Data

In [None]:
# 6. DataLoader
from torch.utils.data import DataLoader
BATCH = 64 

train_loader = DataLoader(train_set, batch_size=BATCH, shuffle=True,
                          num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_set,   batch_size=BATCH, shuffle=False,
                          num_workers=0, pin_memory=False)
test_loader  = DataLoader(test_set,  batch_size=BATCH, shuffle=False,
                          num_workers=0, pin_memory=False)

# class weights
import numpy as np
train_labels = np.array(ds["train"]["label"], dtype=int)
cnt = np.bincount(train_labels, minlength=2)
class_weights = (cnt.sum()/(2.0*np.maximum(cnt,1))).astype(np.float32)
print("class_counts:", cnt.tolist(), " class_weights:", class_weights.tolist())



class_counts: [1080, 3107]  class_weights: [1.938425898551941, 0.6738011240959167]


## Training Tools

In [55]:
def build_model():
    m = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    m.fc = nn.Linear(m.fc.in_features, 2)
    return m

def eval_epoch(model, loader):
    model.eval()
    ys, p1s = [], []
    with torch.inference_mode():
        for x,y in loader:
            x,y = x.to(DEVICE), y.to(DEVICE)
            prob1 = torch.softmax(model(x), dim=1)[:,1]
            ys.append(y.cpu().numpy()); p1s.append(prob1.cpu().numpy())
    y_true = np.concatenate(ys); p1 = np.concatenate(p1s)
    y_pred = (p1>=0.5).astype(int)
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    auc = roc_auc_score(y_true, p1)
    return {"acc":acc,"prec":prec,"rec":rec,"f1":f1,"auc":auc}, y_true, p1

def plot_confmat_roc(y_true, p1, outdir:Path):
    outdir = Path(outdir); outdir.mkdir(parents=True, exist_ok=True)
    y_pred=(p1>=0.5).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0,1])

    fig = plt.figure(); plt.imshow(cm, interpolation='nearest')
    plt.title("Confusion Matrix"); plt.xlabel("Pred"); plt.ylabel("True")
    plt.xticks([0,1],["normal","pneumonia"]); plt.yticks([0,1],["normal","pneumonia"])
    for i in range(2):
        for j in range(2): plt.text(j, i, cm[i,j], ha="center", va="center")
    plt.tight_layout(); fig.savefig(outdir/"confmat.png", dpi=220); plt.close(fig)

    thr = np.linspace(0,1,200); tpr=[]; fpr=[]; P=(y_true==1).sum(); N=(y_true==0).sum()
    for t in thr:
        yp=(p1>=t).astype(int)
        TP=((yp==1)&(y_true==1)).sum(); FP=((yp==1)&(y_true==0)).sum()
        TPR=TP/max(P,1); FPR=FP/max(N,1)
        tpr.append(TPR); fpr.append(FPR)
    fig=plt.figure(); plt.plot(fpr,tpr); plt.plot([0,1],[0,1],"--")
    plt.xlabel("FPR"); plt.ylabel("TPR"); plt.title("ROC Curve"); plt.tight_layout()
    fig.savefig(outdir/"roc.png", dpi=220); plt.close(fig)

def best_threshold_youden(y_true, p1):
    thr = np.linspace(0,1,501)
    best_t, best_j = 0.5, -1
    for t in thr:
        yp = (p1>=t).astype(int)
        TP=((yp==1)&(y_true==1)).sum(); FP=((yp==1)&(y_true==0)).sum()
        TN=((yp==0)&(y_true==0)).sum(); FN=((yp==0)&(y_true==1)).sum()
        TPR = TP/max(TP+FN,1); FPR = FP/max(FP+TN,1)
        J = TPR - FPR
        if J > best_j: best_j, best_t = J, t
    return float(best_t), float(best_j)


## Training

In [56]:
model = build_model().to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(DEVICE))
optim = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode="max", factor=0.5, patience=2)

best_f1, best_epoch, patience = -1.0, -1, 4
history = []
for epoch in range(1, 11):
    model.train(); running = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/10")
    for x,y in pbar:
        x,y = x.to(DEVICE), y.to(DEVICE)
        optim.zero_grad()
        loss = criterion(model(x), y)
        loss.backward(); optim.step()
        running += loss.item()*x.size(0)
        pbar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{optim.param_groups[0]['lr']:.2e}")
    tr_loss = running/len(train_loader.dataset)
    val_metrics, _, _ = eval_epoch(model, val_loader)
    val_f1 = float(val_metrics["f1"]) if np.isfinite(val_metrics["f1"]) else 0.0
    sched.step(val_f1)
    history.append({"epoch":epoch,"train_loss":tr_loss, **{k:float(v) for k,v in val_metrics.items()}})
    if val_metrics["f1"] > best_f1:
        best_f1 = val_metrics["f1"]; best_epoch = epoch; patience = 4
        torch.save({"model":model.state_dict()}, OUT/"best.pt")
    else:
        patience -= 1
        if patience == 0: print("Early stop."); break

with open(OUT/"history.csv","w",newline="") as f:
    w = csv.DictWriter(f, fieldnames=["epoch","train_loss","acc","prec","rec","f1","auc"])
    w.writeheader(); w.writerows(history)


Epoch 1/10: 100%|██████████| 66/66 [00:43<00:00,  1.53it/s, loss=0.0530, lr=3.00e-04]
Epoch 2/10: 100%|██████████| 66/66 [00:43<00:00,  1.52it/s, loss=0.0011, lr=3.00e-04]
Epoch 3/10: 100%|██████████| 66/66 [00:47<00:00,  1.38it/s, loss=0.0004, lr=3.00e-04]
Epoch 4/10: 100%|██████████| 66/66 [00:51<00:00,  1.27it/s, loss=0.0039, lr=3.00e-04]
Epoch 5/10: 100%|██████████| 66/66 [00:49<00:00,  1.34it/s, loss=0.0008, lr=3.00e-04]
Epoch 6/10: 100%|██████████| 66/66 [00:48<00:00,  1.36it/s, loss=0.0003, lr=1.50e-04]
Epoch 7/10: 100%|██████████| 66/66 [00:47<00:00,  1.40it/s, loss=0.0004, lr=1.50e-04]
Epoch 8/10: 100%|██████████| 66/66 [00:47<00:00,  1.40it/s, loss=0.0403, lr=1.50e-04]
Epoch 9/10: 100%|██████████| 66/66 [00:47<00:00,  1.40it/s, loss=0.0020, lr=1.50e-04]
Epoch 10/10: 100%|██████████| 66/66 [00:47<00:00,  1.39it/s, loss=0.0001, lr=1.50e-04]


## Test Model

In [57]:
ckpt = torch.load(OUT/"best.pt", map_location=DEVICE)
model.load_state_dict(ckpt["model"])
test_metrics, y_true, p1 = eval_epoch(model, test_loader)
best_t, best_j = best_threshold_youden(y_true, p1)

with open(OUT/"metrics.json","w") as f:
    json.dump({"best_epoch": best_epoch,
               "history_last": history[-1] if history else {},
               "test": {k:float(v) for k,v in test_metrics.items()},
               "best_threshold_youden": best_t,
               "best_J": best_j}, f, indent=2)

plot_confmat_roc(y_true, p1, OUT)
print("Saved to:", OUT.resolve())
print("TEST:", test_metrics)
print(f"Best epoch: {best_epoch}, Youden best threshold: {best_t:.3f} (J={best_j:.3f})")


  ckpt = torch.load(OUT/"best.pt", map_location=DEVICE)


Saved to: /Users/nolan/Desktop/507/outputs/baseline
TEST: {'acc': 0.8589743589743589, 'prec': 0.8158995815899581, 'rec': 1.0, 'f1': 0.8986175115207373, 'auc': 0.9762382204689897}
Best epoch: 10, Youden best threshold: 0.998 (J=0.844)


## 增强 tool

In [None]:
# augmentation tool
from torchvision import transforms

IMAGENET_MEAN=[0.485,0.456,0.406]; IMAGENET_STD=[0.229,0.224,0.225]

def get_transform(level):
    if level=="light":
        return transforms.Compose([
            transforms.Resize((224,224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])
    if level=="medium":
        return transforms.Compose([
            transforms.Resize((224,224)),
            transforms.RandomResizedCrop(224, scale=(0.9,1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=10),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])
    if level=="strong":
        return transforms.Compose([
            transforms.Resize((224,224)),
            transforms.RandomResizedCrop(224, scale=(0.8,1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),                                  
            transforms.RandomErasing(p=0.5, scale=(0.02,0.1), ratio=(0.3,3.3)),  
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1,1.0)),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])
    raise ValueError(level)


## 增强结果

In [None]:
import json, csv, numpy as np, torch, torch.nn as nn
from torch.utils.data import DataLoader
from pathlib import Path
from tqdm import tqdm

OUT_ROOT = Path("outputs")
TX_EVAL = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_labels = np.array(ds["train"]["label"], dtype=int)
cnt = np.bincount(train_labels, minlength=2)
class_weights = (cnt.sum()/(2.0*np.maximum(cnt,1))).astype(np.float32)

def run_one_augmentation(level):
    out = OUT_ROOT/f"aug_{level}"
    out.mkdir(parents=True, exist_ok=True)
    print(f"\n=== Augmentation: {level} -> {out} ===")

    train_set = HFDataset(ds["train"], get_transform(level))
    val_set   = HFDataset(ds["validation"], TX_EVAL)
    test_set  = HFDataset(ds["test"], TX_EVAL)

    train_loader = DataLoader(train_set, batch_size=64, shuffle=True,  num_workers=0, pin_memory=False)
    val_loader   = DataLoader(val_set,   batch_size=64, shuffle=False, num_workers=0, pin_memory=False)
    test_loader  = DataLoader(test_set,  batch_size=64, shuffle=False, num_workers=0, pin_memory=False)

    model = build_model().to(DEVICE)
    criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(DEVICE))
    optim = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode="max", factor=0.5, patience=2)

    best_f1, patience, history = -1.0, 4, []
    for epoch in range(1, 11):
        model.train(); running=0.0
        pbar = tqdm(train_loader, desc=f"[{level}] Epoch {epoch}/10")
        for x,y in pbar:
            x,y = x.to(DEVICE), y.to(DEVICE)
            optim.zero_grad()
            loss = criterion(model(x), y)
            loss.backward(); optim.step()
            running += loss.item()*x.size(0)
            pbar.set_postfix(loss=f"{loss.item():.4f}")
        tr_loss = running/len(train_loader.dataset)
        val_metrics, _, _ = eval_epoch(model, val_loader)
        sched.step(float(val_metrics["f1"]))
        history.append({"epoch":epoch,"train_loss":tr_loss, **{k:float(v) for k,v in val_metrics.items()}})
        if val_metrics["f1"] > best_f1:
            best_f1 = val_metrics["f1"]; patience = 4
            torch.save({"model":model.state_dict()}, out/"best.pt")
        else:
            patience -= 1
            if patience == 0: break

    ckpt = torch.load(out/"best.pt", map_location=DEVICE)
    model.load_state_dict(ckpt["model"])
    test_metrics, y_true, p1 = eval_epoch(model, test_loader)
    with open(out/"metrics.json","w") as f:
        json.dump({"history":history, "test":{k:float(v) for k,v in test_metrics.items()}}, f, indent=2)
    plot_confmat_roc(y_true, p1, out)
    print(f"[{level}] TEST:", test_metrics)
    return level, test_metrics

results = []
for lv in ["light","medium","strong"]:
    name, m = run_one_augmentation(lv)
    results.append([name, m["acc"], m["prec"], m["rec"], m["f1"], m["auc"]])

OUT_ROOT.mkdir(parents=True, exist_ok=True)
with open(OUT_ROOT/"aug_summary.csv","w",newline="") as f:
    w = csv.writer(f)
    w.writerow(["augmentation","accuracy","precision","recall","f1","auc"])
    w.writerows(results)

print("Saved summary ->", (OUT_ROOT/"aug_summary.csv").resolve())



=== Augmentation: light -> outputs/aug_light ===


[light] Epoch 1/10: 100%|██████████| 66/66 [00:42<00:00,  1.56it/s, loss=0.1374]
[light] Epoch 2/10: 100%|██████████| 66/66 [00:43<00:00,  1.50it/s, loss=0.0090]
[light] Epoch 3/10: 100%|██████████| 66/66 [00:42<00:00,  1.54it/s, loss=0.0735]
[light] Epoch 4/10: 100%|██████████| 66/66 [00:43<00:00,  1.52it/s, loss=0.0370]
[light] Epoch 5/10: 100%|██████████| 66/66 [00:44<00:00,  1.48it/s, loss=0.0010]
[light] Epoch 6/10: 100%|██████████| 66/66 [00:47<00:00,  1.38it/s, loss=0.0026]
[light] Epoch 7/10: 100%|██████████| 66/66 [00:47<00:00,  1.40it/s, loss=0.0001]
[light] Epoch 8/10: 100%|██████████| 66/66 [00:48<00:00,  1.37it/s, loss=0.0002]
[light] Epoch 9/10: 100%|██████████| 66/66 [00:47<00:00,  1.39it/s, loss=0.0009]
[light] Epoch 10/10: 100%|██████████| 66/66 [00:47<00:00,  1.38it/s, loss=0.0004]
  ckpt = torch.load(out/"best.pt", map_location=DEVICE)


[light] TEST: {'acc': 0.8685897435897436, 'prec': 0.826271186440678, 'rec': 1.0, 'f1': 0.9048723897911833, 'auc': 0.9750438308130617}

=== Augmentation: medium -> outputs/aug_medium ===


[medium] Epoch 1/10: 100%|██████████| 66/66 [00:51<00:00,  1.29it/s, loss=0.5562]
[medium] Epoch 2/10: 100%|██████████| 66/66 [00:51<00:00,  1.28it/s, loss=0.0059]
[medium] Epoch 3/10: 100%|██████████| 66/66 [00:51<00:00,  1.28it/s, loss=0.0096]
[medium] Epoch 4/10: 100%|██████████| 66/66 [00:54<00:00,  1.21it/s, loss=0.0129]
[medium] Epoch 5/10: 100%|██████████| 66/66 [00:54<00:00,  1.21it/s, loss=0.0422]
[medium] Epoch 6/10: 100%|██████████| 66/66 [00:53<00:00,  1.22it/s, loss=0.0055]
[medium] Epoch 7/10: 100%|██████████| 66/66 [00:54<00:00,  1.20it/s, loss=0.0434]
[medium] Epoch 8/10: 100%|██████████| 66/66 [00:54<00:00,  1.21it/s, loss=0.0192]
[medium] Epoch 9/10: 100%|██████████| 66/66 [00:52<00:00,  1.25it/s, loss=0.0715]
[medium] Epoch 10/10: 100%|██████████| 66/66 [00:55<00:00,  1.19it/s, loss=0.1127]
  ckpt = torch.load(out/"best.pt", map_location=DEVICE)


[medium] TEST: {'acc': 0.9310897435897436, 'prec': 0.9044289044289044, 'rec': 0.9948717948717949, 'f1': 0.9474969474969475, 'auc': 0.9928884505807583}

=== Augmentation: strong -> outputs/aug_strong ===


[strong] Epoch 1/10: 100%|██████████| 66/66 [00:58<00:00,  1.13it/s, loss=0.0569]
[strong] Epoch 2/10: 100%|██████████| 66/66 [00:57<00:00,  1.15it/s, loss=0.0156]
[strong] Epoch 3/10: 100%|██████████| 66/66 [00:58<00:00,  1.13it/s, loss=0.1349]
[strong] Epoch 4/10: 100%|██████████| 66/66 [00:57<00:00,  1.16it/s, loss=0.1235]
[strong] Epoch 5/10: 100%|██████████| 66/66 [00:57<00:00,  1.16it/s, loss=0.0924]
  ckpt = torch.load(out/"best.pt", map_location=DEVICE)


[strong] TEST: {'acc': 0.9326923076923077, 'prec': 0.910377358490566, 'rec': 0.9897435897435898, 'f1': 0.9484029484029484, 'auc': 0.9894915625684857}
Saved summary -> /Users/nolan/Desktop/507/outputs/aug_summary.csv


## VIT

In [None]:
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt

SEED = 114514
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(SEED)

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

DEVICE = ("mps" if torch.backends.mps.is_available() 
          else "cuda" if torch.cuda.is_available() 
          else "cpu")
print(f"Running on Device: {DEVICE}")

OUT = Path("outputs/vit_fix")
OUT.mkdir(parents=True, exist_ok=True)

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

def to_rgb(img):
    return img.convert("RGB")

tx_train = transforms.Compose([
    transforms.Lambda(to_rgb),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

tx_eval = transforms.Compose([
    transforms.Lambda(to_rgb),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

class HFDataset(Dataset):
    def __init__(self, split, t):
        self.ds = split
        self.t = t
    def __len__(self): 
        return len(self.ds)
    def __getitem__(self, i):
        item = self.ds[i]     
        x = self.t(item["image"])
        y = int(item["label"])
        return x, y

def build_model():
    m = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
    in_f = m.heads.head.in_features
    m.heads.head = nn.Linear(in_f, 2)
    return m

@torch.no_grad()
def eval_epoch(model, loader):
    model.eval()
    ys, p1s = [], []
    for x, y in loader:
        x = x.to(DEVICE)
        if not x.is_contiguous(): x = x.contiguous()
        y = y.to(DEVICE)
        
        logits = model(x)
        prob1 = torch.softmax(logits, dim=1)[:, 1]
        
        ys.append(y.cpu().numpy())
        p1s.append(prob1.cpu().numpy())

    if not ys: return {}, np.array([]), np.array([]) 

    y_true = np.concatenate(ys)
    p1 = np.concatenate(p1s)
    y_pred = (p1 >= 0.5).astype(int)
    
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    try:
        auc = roc_auc_score(y_true, p1)
    except:
        auc = 0.5

    return {"acc": acc, "prec": prec, "rec": rec, "f1": f1, "auc": auc}, y_true, p1

def plot_results(y_true, p1, outdir):
    if len(y_true) == 0: return
    y_pred = (p1 >= 0.5).astype(int)
    
    cm = confusion_matrix(y_true, y_pred, labels=[0,1])
    fig = plt.figure()
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.xticks([0,1], ["Normal", "Pneumonia"])
    plt.yticks([0,1], ["Normal", "Pneumonia"])
    for i in range(2):
        for j in range(2):
            plt.text(j, i, cm[i,j], ha="center", va="center", color="red")
    plt.tight_layout()
    fig.savefig(outdir/"confmat.png")
    plt.close()

if __name__ == "__main__":
    print(">>> 1. Loading Dataset...")

    try:
        raw_ds = load_dataset("mmenendezg/pneumonia_x_ray")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        exit(1)

    if "validation" not in raw_ds:
        print("   Warning: No validation set found. Splitting from train...")
        split = raw_ds["train"].train_test_split(test_size=0.1, seed=SEED)
        ds = {"train": split["train"], "validation": split["test"], "test": raw_ds.get("test", split["test"])}
    else:
        ds = raw_ds

    train_set = HFDataset(ds["train"], tx_train)
    val_set   = HFDataset(ds["validation"], tx_eval)
    test_set  = HFDataset(ds["test"], tx_eval)

    BATCH = 32
    train_loader = DataLoader(train_set, batch_size=BATCH, shuffle=True, num_workers=0)
    val_loader   = DataLoader(val_set, batch_size=BATCH, shuffle=False, num_workers=0)
    test_loader  = DataLoader(test_set, batch_size=BATCH, shuffle=False, num_workers=0)

    try:
        labels_list = ds["train"]["label"]
    except:
        labels_list = [x['label'] for x in ds["train"]]
        
    cnt = np.bincount(labels_list, minlength=2)
    cw = (cnt.sum() / (2.0 * np.maximum(cnt, 1))).astype(np.float32)
    cw_t = torch.tensor(cw, dtype=torch.float32).to(DEVICE)
    print(f">>> Class Weights: {cw} (Counts: {cnt})")

    model = build_model().to(DEVICE)
    model = model.to(memory_format=torch.contiguous_format)

    criterion = nn.CrossEntropyLoss(weight=cw_t)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode="max", factor=0.5, patience=1)

    EPOCHS = 5
    best_f1 = -1.0
    
    print(f">>> Start Training ({EPOCHS} Epochs)...")
    
    if DEVICE == "mps":
        dummy = torch.randn(2, 3, 224, 224).to(DEVICE)
        _ = model(dummy)

    for epoch in range(1, EPOCHS+1):
        model.train()
        run_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
        
        for x, y in pbar:
            x = x.to(DEVICE, non_blocking=True)
            if not x.is_contiguous():
                x = x.contiguous()
            
            y = y.to(DEVICE).long()
            
            optim.zero_grad(set_to_none=True)
            
            logits = model(x)
            loss = criterion(logits, y)
            
            loss.backward()
            optim.step()
            
            run_loss += loss.item() * x.size(0)
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        avg_loss = run_loss / len(train_set)
        val_m, _, _ = eval_epoch(model, val_loader)
        
        print(f"   [Val] Loss: {avg_loss:.4f} | F1: {val_m['f1']:.4f} | Acc: {val_m['acc']:.4f}")
        
        sched.step(val_m["f1"])
        
        if val_m["f1"] > best_f1:
            best_f1 = val_m["f1"]
            torch.save({"model": model.state_dict()}, OUT/"best.pt")
            print("   >>> Best Model Saved!")
            
    # test
    print("\n>>> Testing Best Model...")
    ckpt = torch.load(OUT/"best.pt", map_location=DEVICE)
    model.load_state_dict(ckpt["model"])
    test_m, y_true, p1 = eval_epoch(model, test_loader)
    
    plot_results(y_true, p1, OUT)
    
    with open(OUT/"metrics.json", "w") as f:
        json.dump(test_m, f, indent=4)
        
    print(f"Final Test Metrics: {test_m}")
    print(f"Results saved to: {OUT.resolve()}")

Running on Device: mps
>>> 1. Loading Dataset...
>>> Class Weights: [1.9384259 0.6738011] (Counts: [1080 3107])
>>> Start Training (5 Epochs)...


Epoch 1/5:   0%|          | 0/131 [00:58<?, ?it/s]


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.