In [1]:
# ✅ ViT-B בלבד עם Fine-Tuning הדרגתי + Validation + MSFlow

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from PIL import Image
from sklearn.metrics import confusion_matrix, roc_auc_score
import numpy as np

# ----------------------------
# Dataset Camelyon16
# ----------------------------
def get_transforms():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

class Camelyon16Dataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None):
        self.image_paths, self.labels = [], []
        self.transform = transform

        if split == 'train':
            folder = os.path.join(root_dir, 'train', 'good')
            self.image_paths = [os.path.join(folder, f) for f in os.listdir(folder)]
            self.labels = [0] * len(self.image_paths)
        else:
            g = os.path.join(root_dir, 'test', 'good')
            b = os.path.join(root_dir, 'test', 'Ungood')
            self.image_paths = [os.path.join(g, f) for f in os.listdir(g)] + \
                                [os.path.join(b, f) for f in os.listdir(b)]
            self.labels = [0] * len(os.listdir(g)) + [1] * len(os.listdir(b))

    def __len__(self): return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]

# ----------------------------
# Feature Extractor: ViT-B
# ----------------------------
class ViTFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        model.heads = nn.Identity()
        self.model = model

    def forward(self, x):
        with torch.no_grad():
            f = self.model(x)
        return [f.view(f.size(0), f.size(1), 1, 1)]

# ----------------------------
# ActNorm2d + RealNVP Block + MSFlow
# ----------------------------
class ActNorm2d(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.logs = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.initialized = False

    def initialize(self, x):
        with torch.no_grad():
            mean = x.mean([0, 2, 3], keepdim=True)
            std = x.std([0, 2, 3], keepdim=True)
            self.bias.data.copy_(-mean)
            self.logs.data.copy_(torch.log(1 / (std + 1e-6)))
            self.initialized = True

    def forward(self, x):
        if not self.initialized:
            self.initialize(x)
        return (x + self.bias) * torch.exp(self.logs)

class RealNVPBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.norm = ActNorm2d(in_channels)
        self.net = nn.Sequential(
            nn.Conv2d(in_channels // 2, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 1),
            nn.ReLU(),
            nn.Conv2d(128, in_channels, 3, padding=1)
        )

    def forward(self, x):
        x_a, x_b = x.chunk(2, dim=1)
        h = self.net(x_a)
        s, t = h.chunk(2, dim=1)
        s = torch.sigmoid(s + 2.0)
        y_b = s * x_b + t
        log_det = torch.sum(torch.log(s).view(x.size(0), -1), dim=1)
        out = torch.cat([x_a, y_b], dim=1)
        out = self.norm(out)
        return out, log_det

class MSFlow(nn.Module):
    def __init__(self, in_channels_list):
        super().__init__()
        self.subflows = nn.ModuleList([
            nn.Sequential(*[RealNVPBlock(c) for _ in range(3)])
            for c in in_channels_list
        ])

    def forward(self, features):
        scores = []
        for i, flow in enumerate(self.subflows):
            x = features[i]
            log_det = 0
            for block in flow:
                x, det = block(x)
                log_det += det
            log_prob = -0.5 * torch.sum(x ** 2, dim=[1, 2, 3]) + log_det
            scores.append(log_prob)
        return -sum(scores) / len(scores)

# ----------------------------
# Training + Validation
# ----------------------------
def train_msflow():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    root = r"C:\\afeca academy\\סימסטר ב\\advanced deep learning\\project Normalizing Flow\\MedlAnaomaly-Data\\Camelyon16\\Camelyon16"

    full_ds = Camelyon16Dataset(root, 'test', get_transforms())
    val_size = int(0.2 * len(full_ds))
    test_size = len(full_ds) - val_size
    test_ds, val_ds = random_split(full_ds, [test_size, val_size])

    train_ds = Camelyon16Dataset(root, 'train', get_transforms())
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)

    extractor = ViTFeatureExtractor().to(device)
    flow = MSFlow([768]).to(device)
    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-4)

    best_auc = 0.0
    for epoch in range(10):
        flow.train()
        total_loss = 0
        for x, _ in train_loader:
            x = x.to(device)
            features = extractor(x)
            loss = flow(features).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"\n✅ Epoch {epoch+1} Avg Loss: {total_loss/len(train_loader):.4f}")

        flow.eval()
        all_scores, all_labels = [], []
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(device)
                scores = flow(extractor(x))
                all_scores.extend(scores.cpu().numpy())
                all_labels.extend(y.numpy())

        scores = np.array(all_scores)
        scores = np.nan_to_num(scores, nan=0.0, posinf=1e6, neginf=-1e6)
        scores_z = (scores - scores.mean()) / (scores.std() + 1e-8)
        preds = (scores_z > 0).astype(int)
        cm = confusion_matrix(all_labels, preds)
        auc = roc_auc_score(all_labels, scores_z)

        if auc > best_auc:
            best_auc = auc
            torch.save(flow.state_dict(), "best_vit_model.pth")
            print("💾 Best model so far – saving...")

        print(f"📊 Confusion Matrix:\n{cm}")
        print(f"📈 ROC AUC: {auc:.4f}")
        print(f"📉 Score Mean: {scores.mean():.2f} | Std: {scores.std():.2f}")

    print("\n📦 Loading best model for final evaluation...")
    flow.load_state_dict(torch.load("best_vit_model.pth"))
    flow.eval()
    all_scores, all_labels = [], []
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            scores = flow(extractor(x))
            all_scores.extend(scores.cpu().numpy())
            all_labels.extend(y.numpy())

    scores = np.array(all_scores)
    scores_z = (scores - scores.mean()) / (scores.std() + 1e-8)
    preds = (scores_z > 0).astype(int)
    cm = confusion_matrix(all_labels, preds)
    auc = roc_auc_score(all_labels, scores_z)

    print("\n✅ Final Evaluation with ViT-B + Fine-Tuning")
    print(f"📊 Confusion Matrix:\n{cm}")
    print(f"📈 Final ROC AUC: {auc:.4f}")
    print(f"📉 Final Score Mean: {scores.mean():.2f} | Std: {scores.std():.2f}")




In [2]:
train_msflow()


✅ Epoch 1 Avg Loss: 741.5255
💾 Best model so far – saving...
📊 Confusion Matrix:
[[89 32]
 [30 51]]
📈 ROC AUC: 0.7295
📉 Score Mean: 594.03 | Std: 272.24

✅ Epoch 2 Avg Loss: 423.5538
💾 Best model so far – saving...
📊 Confusion Matrix:
[[88 33]
 [33 48]]
📈 ROC AUC: 0.7348
📉 Score Mean: 380.45 | Std: 190.70

✅ Epoch 3 Avg Loss: 277.7954
💾 Best model so far – saving...
📊 Confusion Matrix:
[[89 32]
 [34 47]]
📈 ROC AUC: 0.7391
📉 Score Mean: 261.31 | Std: 138.37

✅ Epoch 4 Avg Loss: 187.3375
💾 Best model so far – saving...
📊 Confusion Matrix:
[[91 30]
 [31 50]]
📈 ROC AUC: 0.7400
📉 Score Mean: 177.00 | Std: 94.09

✅ Epoch 5 Avg Loss: 127.4357
📊 Confusion Matrix:
[[92 29]
 [33 48]]
📈 ROC AUC: 0.7317
📉 Score Mean: 122.27 | Std: 68.43

✅ Epoch 6 Avg Loss: 87.6960
📊 Confusion Matrix:
[[90 31]
 [29 52]]
📈 ROC AUC: 0.7365
📉 Score Mean: 84.70 | Std: 49.52

✅ Epoch 7 Avg Loss: 59.5701
📊 Confusion Matrix:
[[89 32]
 [32 49]]
📈 ROC AUC: 0.7242
📉 Score Mean: 57.98 | Std: 33.80

✅ Epoch 8 Avg Loss: 40.87

  flow.load_state_dict(torch.load("best_vit_model.pth"))



✅ Final Evaluation with ViT-B + Fine-Tuning
📊 Confusion Matrix:
[[333 106]
 [150 220]]
📈 Final ROC AUC: 0.7471
📉 Final Score Mean: 176.54 | Std: 92.45
