In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix, roc_auc_score
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF

# ----------------------------
# 1. Data Preprocessing
# ----------------------------

def transform_image(pil_image: Image.Image) -> torch.Tensor:
    transform_ops = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    return transform_ops(pil_image.convert("RGB"))

In [2]:
class Camelyon16Dataset(Dataset):
    def __init__(self, root_dir, split='train'):
        self.image_paths = []
        self.labels = []
        self.names = []
        self.split = split

        if split == 'train':
            subdir = os.path.join(root_dir, 'train', 'good')
            self.image_paths = [os.path.join(subdir, f) for f in os.listdir(subdir)]
            self.labels = [0] * len(self.image_paths)
        else:
            good_dir = os.path.join(root_dir, 'test', 'good')
            bad_dir = os.path.join(root_dir, 'test', 'Ungood')
            self.image_paths += [os.path.join(good_dir, f) for f in os.listdir(good_dir)]
            self.labels += [0] * len(os.listdir(good_dir))
            self.image_paths += [os.path.join(bad_dir, f) for f in os.listdir(bad_dir)]
            self.labels += [1] * len(os.listdir(bad_dir))

        self.names = [os.path.basename(p) for p in self.image_paths]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        name = self.names[idx]
        image = Image.open(img_path)
        return transform_image(image), label, name

In [3]:
# ----------------------------
# 2. Feature Extractor
# ----------------------------
class ResNetFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3

    def forward(self, x):
        x = self.layer0(x)
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        return [out1, out2, out3]

In [4]:
# ----------------------------
# 3. ActNorm
# ----------------------------
class ActNorm2d(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.logs = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.initialized = False

    def initialize(self, x):
        with torch.no_grad():
            mean = x.mean([0, 2, 3], keepdim=True)
            std = x.std([0, 2, 3], keepdim=True)
            self.bias.data.copy_(-mean)
            self.logs.data.copy_(torch.log(1 / (std + 1e-6)))
            self.initialized = True

    def forward(self, x):
        if not self.initialized:
            self.initialize(x)
        return (x + self.bias) * torch.exp(self.logs)

In [5]:
# ----------------------------
# 4. RealNVP Block
# ----------------------------
class RealNVPBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.norm = ActNorm2d(in_channels)
        self.net = nn.Sequential(
            nn.Conv2d(in_channels // 2, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(128, in_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        x_a, x_b = x.chunk(2, dim=1)
        h = self.net(x_a)
        s, t = h.chunk(2, dim=1)
        s = torch.sigmoid(s + 2.0)
        y_b = s * x_b + t
        log_det_jacobian = torch.sum(torch.log(s).view(x.size(0), -1), dim=1)
        out = torch.cat([x_a, y_b], dim=1)
        out = self.norm(out)
        return out, log_det_jacobian


# ----------------------------
# 5. MSFlow
# ----------------------------
class MSFlow(nn.Module):
    def __init__(self, in_channels_list):
        super().__init__()
        self.subflows = nn.ModuleList([
            nn.Sequential(*[RealNVPBlock(c) for _ in range(3)])
            for c in in_channels_list
        ])

    def forward(self, features):
        total_log_likelihood = []
        for i, flow in enumerate(self.subflows):
            x = features[i]
            log_det = 0
            for block in flow:
                x, det = block(x)
                log_det += det
            log_prob = -0.5 * torch.sum(x ** 2, dim=[1, 2, 3]) + log_det
            total_log_likelihood.append(log_prob)

        fused_score = sum(total_log_likelihood) / len(total_log_likelihood)
        return -fused_score

In [6]:
# ----------------------------
# 6. Training – Image-level AUROC Only
# ----------------------------
def train_msflow():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Using device:", device)

    root_dir = r"C:\afeca academy\סימסטר ב\advanced deep learning\project Normalizing Flow\MedlAnaomaly-Data\Camelyon16\Camelyon16"

    train_set = Camelyon16Dataset(root_dir, split='train')
    test_set = Camelyon16Dataset(root_dir, split='test')
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=0)

    resnet = ResNetFeatureExtractor().to(device)
    flow_model = MSFlow(in_channels_list=[256, 512, 1024]).to(device)
    optimizer = torch.optim.Adam(flow_model.parameters(), lr=1e-4)

    for epoch in range(10):
        print(f"\n🔁 Epoch {epoch+1}/10", flush=True)
        flow_model.train()
        total_loss = 0

        for batch_idx, (imgs, _, _) in enumerate(train_loader):
            imgs = imgs.to(device)
            features = resnet(imgs)
            loss = flow_model(features).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            print(f"  📦 Batch {batch_idx+1}/{len(train_loader)} | Loss: {loss.item():.4f}", flush=True)

        print(f"✅ Epoch {epoch+1} Avg Loss: {total_loss/len(train_loader):.4f}", flush=True)

        # Evaluation
        flow_model.eval()
        all_scores, all_labels = [], []

        with torch.no_grad():
            for imgs, labels, _ in test_loader:
                imgs = imgs.to(device)
                features = resnet(imgs)
                scores = flow_model(features)
                all_scores.extend(scores.cpu().numpy())
                all_labels.extend(labels.numpy())

        preds = [1 if s > np.median(all_scores) else 0 for s in all_scores]
        cm = confusion_matrix(all_labels, preds)
        auc = roc_auc_score(all_labels, all_scores)
        print(f"📊 Confusion Matrix:\n{cm}")
        print(f"📈 Image-level ROC AUC: {auc:.4f}")

In [7]:
train_msflow()

Using device: cuda

🔁 Epoch 1/10
  📦 Batch 1/1272 | Loss: 323792.3125
  📦 Batch 2/1272 | Loss: 359444.3438
  📦 Batch 3/1272 | Loss: 351001.7812
  📦 Batch 4/1272 | Loss: 350395.4062
  📦 Batch 5/1272 | Loss: 347852.7500
  📦 Batch 6/1272 | Loss: 332094.7500
  📦 Batch 7/1272 | Loss: 345539.5312
  📦 Batch 8/1272 | Loss: 348751.6250
  📦 Batch 9/1272 | Loss: 345478.7500
  📦 Batch 10/1272 | Loss: 336762.3125
  📦 Batch 11/1272 | Loss: 340683.4375
  📦 Batch 12/1272 | Loss: 346451.6250
  📦 Batch 13/1272 | Loss: 339819.3125
  📦 Batch 14/1272 | Loss: 346239.1250
  📦 Batch 15/1272 | Loss: 334955.5000
  📦 Batch 16/1272 | Loss: 340436.4375
  📦 Batch 17/1272 | Loss: 329445.4375
  📦 Batch 18/1272 | Loss: 331054.3750
  📦 Batch 19/1272 | Loss: 339281.8750
  📦 Batch 20/1272 | Loss: 329189.2812
  📦 Batch 21/1272 | Loss: 313356.7500
  📦 Batch 22/1272 | Loss: 317285.8125
  📦 Batch 23/1272 | Loss: 320206.9375
  📦 Batch 24/1272 | Loss: 327705.5625
  📦 Batch 25/1272 | Loss: 330179.9688
  📦 Batch 26/1272 | Loss: 