In [None]:
# === BASELINE (BALANCED) ===

import random

import torch
from torch.utils.data import DataLoader

from common.dataset import load_balanced_chestxray
from common.model import get_model
from common.utils import get_device
from train import train_model
from eval import eval_model

root = "/content/chestxray"  # dataset root directory; change as needed

# ===== SHARED SETTINGS =====
# The train/test loaders and the model must agree on image size and channel
# count; defining each setting once keeps them from drifting apart.
SEED = 42
CLASS0_NAME = "NORMAL"
CLASS1_NAME = "PNEUMONIA"
NUM_CLASSES = 2  # one output per class name above
IMAGE_SIZE = 224
IN_CHANNELS = 3
BATCH_SIZE = 32
TRAIN_PER_CLASS = 2000

device = get_device()

# ===== REPRODUCIBILITY =====
# Seed the global RNGs before any data loading, shuffling or model creation so
# reruns of this baseline are comparable. DataLoader(shuffle=True) draws from
# torch's global generator; the new classification head built by get_model
# presumably does too. NOTE(review): confirm that any subsampling/augmentation
# inside common.dataset uses these global RNGs rather than its own seed.
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the torch RNG on all devices (CPU and CUDA)

# ====== LOAD BALANCED DATA ======
# Training set: TRAIN_PER_CLASS images per class, with augmentation enabled.
train_ds = load_balanced_chestxray(
    root=root,
    split="train",
    class0_name=CLASS0_NAME,
    class1_name=CLASS1_NAME,
    per_class=TRAIN_PER_CLASS,
    image_size=IMAGE_SIZE,
    in_channels=IN_CHANNELS,
    augment=True,
)

# Test set: per_class=None presumably keeps every test image (no per-class
# cap) -- TODO confirm in common.dataset. No augmentation at evaluation time.
test_ds = load_balanced_chestxray(
    root=root,
    split="test",
    class0_name=CLASS0_NAME,
    class1_name=CLASS1_NAME,
    per_class=None,
    image_size=IMAGE_SIZE,
    in_channels=IN_CHANNELS,
    augment=False,
)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# ===== MODEL =====
# Pretrained ResNet-18; in_channels must match the channel count the datasets
# were loaded with, hence the shared constant.
model = get_model(
    arch="resnet18",
    num_classes=NUM_CLASSES,
    in_channels=IN_CHANNELS,
    pretrained=True,
).to(device)

# ===== TRAIN =====
# class_weights=None: presumably an unweighted loss, since the training set is
# already balanced at TRAIN_PER_CLASS samples per class. `hist` is kept for
# later cells (e.g. plotting training curves).
model, hist = train_model(
    model,
    train_loader,
    device,
    lr=1e-4,
    epochs=8,
    class_weights=None,
)

# ===== EVAL =====
# eval_model returns, per the labels printed below: overall accuracy,
# accuracy on class 0, accuracy on class 1, fairness metrics, and the
# confusion matrix.
overall, c0, c1, metrics, cm = eval_model(model, test_loader, device)

print("Overall Accuracy:", overall)
print("Class0 Accuracy :", c0)
print("Class1 Accuracy :", c1)
print("Fairness metrics:", metrics)
print("Confusion Matrix:\n", cm)