In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from sklearn.metrics import confusion_matrix, roc_auc_score, f1_score
from PIL import Image
import numpy as np
import os
import cv2
import random
import matplotlib.pyplot as plt

# =============================
# Dataset Loader
# =============================
class Camelyon16Dataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None):
        self.image_paths, self.labels = [], []
        self.transform = transform
        if split == 'train':
            folder = os.path.join(root_dir, 'train', 'good')
            self.image_paths = [os.path.join(folder, f) for f in os.listdir(folder)]
            self.labels = [0] * len(self.image_paths)
        else:
            g = os.path.join(root_dir, 'test', 'good')
            b = os.path.join(root_dir, 'test', 'Ungood')
            self.image_paths = [os.path.join(g, f) for f in os.listdir(g)] + [os.path.join(b, f) for f in os.listdir(b)]
            self.labels = [0] * len(os.listdir(g)) + [1] * len(os.listdir(b))

    def __len__(self): return len(self.image_paths)
    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform: img = self.transform(img)
        return img, self.labels[idx]

# =============================
# Augmentations
# =============================
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# =============================
# ResNet18 Frozen Extractor
# =============================
class ResNetFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.stage1 = nn.Sequential(*list(model.children())[:5])  # עד כולל layer1
        self.stage2 = model.layer2
        self.stage3 = model.layer3
        self.stage4 = model.layer4
        for param in self.parameters():
            param.requires_grad = False
        

    def forward(self, x):
        x = self.stage1(x)
        f1 = self.stage2(x)
        f2 = self.stage3(f1)
        f3 = self.stage4(f2)
        return [f1, f2, f3]

# =============================
# MSFlow (RealNVP + Fusion)
# =============================
class ActNorm2d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(1, channels, 1, 1))
        self.logs = nn.Parameter(torch.zeros(1, channels, 1, 1))
        self.initialized = False

    def initialize(self, x):
        with torch.no_grad():
            mean = x.mean([0, 2, 3], keepdim=True)
            std = x.std([0, 2, 3], keepdim=True)
            self.bias.data.copy_(-mean)
            self.logs.data.copy_(torch.log(1 / (std + 1e-6)))
            self.initialized = True

    def forward(self, x):
        if not self.initialized: self.initialize(x)
        return (x + self.bias) * torch.exp(self.logs)

class RealNVPBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.norm = ActNorm2d(in_channels)
        self.net = nn.Sequential(
            nn.Conv2d(in_channels // 2, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 1),
            nn.ReLU(),
            nn.Conv2d(128, in_channels, 3, padding=1)
        )

    def forward(self, x):
        x_a, x_b = x.chunk(2, dim=1)
        h = self.net(x_a)
        s, t = h.chunk(2, dim=1)
        s = torch.sigmoid(s + 2.0)
        y_b = s * x_b + t
        log_det = torch.sum(torch.log(s).reshape(x.size(0), -1), dim=1)
        out = torch.cat([x_a, y_b], dim=1)
        out = self.norm(out)
        return out, log_det

class MSFlow(nn.Module):
    def __init__(self, channels_list):
        super().__init__()
        self.flows = nn.ModuleList([
            nn.Sequential(*[RealNVPBlock(c) for _ in range(3)]) for c in channels_list
        ])

    def forward(self, features):
        scores = []
        for i, flow in enumerate(self.flows):
            x = features[i]
            log_det = 0
            for block in flow:
                x, det = block(x)
                log_det += det
            log_prob = -0.5 * torch.sum(x ** 2, dim=[1, 2, 3]) + log_det
            scores.append(log_prob)
        return -sum(scores) / len(scores)

# =============================
# Train Loop
# =============================
def train_msflow():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    root_dir = r"C:\afeca academy\סימסטר ב\advanced deep learning\project Normalizing Flow\MedlAnaomaly-Data\Camelyon16\Camelyon16"
    train_ds = Camelyon16Dataset(root_dir, 'train', transform)
    val_ds = Camelyon16Dataset(root_dir, 'test', transform)
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

    extractor = ResNetFeatureExtractor().to(device)
    flow = MSFlow([128, 256, 512]).to(device)  # בהתאמה ל־ResNet18
    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-4)

    best_auc = 0
    for epoch in range(10):
        flow.train()
        total_loss = 0
        for x, _ in train_loader:
            x = x.to(device)
            feats = extractor(x)
            loss = flow(feats).mean()
            optimizer.zero_grad(); loss.backward(); optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        print(f"✅ Epoch {epoch+1} Avg Loss: {avg_loss:.4f}")
        evaluate(val_loader, extractor, flow, device)

# =============================
# Evaluation
# =============================
def evaluate(val_loader, extractor, flow, device):
    flow.eval()
    all_scores, all_labels = [], []
    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(device)
            feats = extractor(x)
            scores = flow(feats)
            all_scores.extend(scores.cpu().numpy())
            all_labels.extend([y.item()])
    scores = np.array(all_scores)
    scores_z = (scores - scores.mean()) / (scores.std() + 1e-8)
    preds = (scores_z > 0).astype(int)
    cm = confusion_matrix(all_labels, preds)
    auc = roc_auc_score(all_labels, scores_z)
    f1 = f1_score(all_labels, preds)
    print(f"📊 Confusion Matrix:\n{cm}")
    print(f"📈 ROC AUC: {auc:.4f} | F1: {f1:.4f}")




In [5]:

train_msflow()

✅ Epoch 1 Avg Loss: 23052.1996
📊 Confusion Matrix:
[[753 367]
 [336 777]]
📈 ROC AUC: 0.7266 | F1: 0.6885
✅ Epoch 2 Avg Loss: 11987.0105
📊 Confusion Matrix:
[[767 353]
 [335 778]]
📈 ROC AUC: 0.7299 | F1: 0.6934
✅ Epoch 3 Avg Loss: 7149.5825
📊 Confusion Matrix:
[[762 358]
 [324 789]]
📈 ROC AUC: 0.7329 | F1: 0.6982
✅ Epoch 4 Avg Loss: 4449.9575
📊 Confusion Matrix:
[[773 347]
 [340 773]]
📈 ROC AUC: 0.7312 | F1: 0.6923
✅ Epoch 5 Avg Loss: 2829.3960
📊 Confusion Matrix:
[[778 342]
 [344 769]]
📈 ROC AUC: 0.7306 | F1: 0.6915
✅ Epoch 6 Avg Loss: 1816.0876
📊 Confusion Matrix:
[[768 352]
 [351 762]]
📈 ROC AUC: 0.7239 | F1: 0.6843
✅ Epoch 7 Avg Loss: 1172.5379
📊 Confusion Matrix:
[[775 345]
 [345 768]]
📈 ROC AUC: 0.7249 | F1: 0.6900
✅ Epoch 8 Avg Loss: 759.0340
📊 Confusion Matrix:
[[773 347]
 [347 766]]
📈 ROC AUC: 0.7302 | F1: 0.6882
✅ Epoch 9 Avg Loss: 492.1548
📊 Confusion Matrix:
[[763 357]
 [361 752]]
📈 ROC AUC: 0.7233 | F1: 0.6769
✅ Epoch 10 Avg Loss: 319.6370
📊 Confusion Matrix:
[[772 348]
 [3