<a href="https://colab.research.google.com/github/FadouaBOUAFIF/ApprentissageAuto/blob/main/version_25_avril.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Databricks notebook source
import zipfile
import os

zip_path = "/Volumes/dp-datalake-dev-default/test/test/4NSigComp2010.zip"  # change to your .zip file path
extract_dir = "/Volumes/dp-datalake-dev-default/test/test"

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)

print("Unzipped to:", extract_dir)


# COMMAND ----------

# MAGIC
# MAGIC %pip install torch torchvision timm

# COMMAND ----------

# MAGIC %pip install opencv-python

# COMMAND ----------

# MAGIC %md
# MAGIC # SWIN-SIAMESE MODEL

# COMMAND ----------

import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import timm
from torchvision import transforms
import cv2
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, f1_score
import matplotlib.pyplot as plt

# --------------------------
# 1. Dataset Preparation
# --------------------------
genuine_dir = "/Volumes/dp-datalake-dev-default/test/test/genuines"
forged_dir = "/Volumes/dp-datalake-dev-default/test/test/forgeries"

genuine_paths = [os.path.join(genuine_dir, f) for f in os.listdir(genuine_dir)]
forged_paths = [os.path.join(forged_dir, f) for f in os.listdir(forged_dir)]

genuine_train, genuine_test = train_test_split(genuine_paths, test_size=0.2, random_state=42)
forged_train, forged_test = train_test_split(forged_paths, test_size=0.2, random_state=42)

# --------------------------
# 2. Data Augmentation
# --------------------------
class SignatureTransform:
    def __init__(self, train=True):
        self.train = train
        self.base_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
        ])

    def __call__(self, img):
        img = np.array(img)
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        edges = cv2.Laplacian(gray, cv2.CV_8U, ksize=3)
        edges = np.stack([edges] * 3, axis=-1)
        img = np.clip(img + 0.5 * edges, 0, 255).astype(np.uint8)
        img = self.base_transform(img)

        if self.train:
            img = transforms.RandomPerspective(distortion_scale=0.3, p=0.5)(img)
            img = transforms.ColorJitter(brightness=0.2, contrast=0.2)(img)
            img = transforms.RandomRotation(10)(img)
            img = transforms.RandomResizedCrop(224, scale=(0.9, 1.1))(img)
        else:
            img = transforms.Resize(256)(img)
            img = transforms.CenterCrop(224)(img)

        return img

# --------------------------
# 3. Dataset Implementation
# --------------------------
class SwinSiameseDataset(Dataset):
    def __init__(self, genuine_paths, forged_paths, transform=None):
        self.genuine_paths = genuine_paths
        self.forged_paths = forged_paths
        self.transform = transform
        self.pairs = self._generate_pairs()

    def _generate_pairs(self):
        pairs = []
        for i in range(len(self.genuine_paths)):
            for j in range(i + 1, min(i + 10, len(self.genuine_paths))):
                pairs.append((self.genuine_paths[i], self.genuine_paths[j], 1))

        for genuine in self.genuine_paths:
            sampled_forged = random.sample(self.forged_paths, min(5, len(self.forged_paths)))
            for forged in sampled_forged:
                pairs.append((genuine, forged, 0))

        random.shuffle(pairs)
        return pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img1_path, img2_path, label = self.pairs[idx]
        img1 = Image.open(img1_path).convert("RGB")
        img2 = Image.open(img2_path).convert("RGB")

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, torch.tensor(label, dtype=torch.float32)

# --------------------------
# 4. Model Definition
# --------------------------
class SwinSignatureNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = timm.create_model(
            'swin_tiny_patch4_window7_224', pretrained=True, num_classes=0, features_only=True
        )
        feature_dim = self.feature_extractor.feature_info.channels()[-1]

        self.local_attention = nn.Sequential(
            nn.Conv2d(feature_dim, feature_dim, kernel_size=1),
            nn.Sigmoid()
        )

        self.head = nn.Sequential(
            nn.Linear(feature_dim * 4, 1024),
            nn.BatchNorm1d(1024),
            nn.Mish(inplace=True),
            nn.Dropout(0.5),

            nn.Linear(1024, 512),
            nn.InstanceNorm1d(512),
            nn.Mish(),
            nn.Dropout(0.3),

            nn.Linear(512, 1)
        )

    def forward(self, x1, x2):
        f1 = self._process_single(x1)
        f2 = self._process_single(x2)
        diff = torch.abs(f1 - f2)
        prod = f1 * f2
        features = torch.cat([f1, f2, diff, prod], dim=1)
        return self.head(features)

    def _process_single(self, x):
        features = self.feature_extractor(x)[-1]
        features = features.permute(0, 3, 1, 2)
        attention = self.local_attention(features)
        features = features * attention
        return F.adaptive_avg_pool2d(features, (1, 1)).view(features.size(0), -1)

# --------------------------
# 5. Loss Function
# --------------------------
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return focal_loss.mean()

# --------------------------
# 6. Optimizer Setup
# --------------------------
def get_optimizer(model, loader):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=3e-4, steps_per_epoch=len(loader), epochs=50
    )
    return optimizer, scheduler

# --------------------------
# 7. Train & Eval
# --------------------------
def train_epoch(model, loader, optimizer, criterion, device, scaler):
    model.train()
    total_loss, correct, total = 0, 0, 0

    for img1, img2, label in loader:
        img1, img2, label = img1.to(device), img2.to(device), label.to(device)
        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            outputs = model(img1, img2).squeeze()
            loss = criterion(outputs, label)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        preds = (torch.sigmoid(outputs) > 0.5).float()
        correct += (preds == label).sum().item()
        total += label.size(0)

    return total_loss / len(loader), correct / total

def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    all_logits = []
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for img1, img2, label in loader:
            img1, img2, label = img1.to(device), img2.to(device), label.to(device)
            logits = model(img1, img2).squeeze()
            loss = criterion(logits, label)
            total_loss += loss.item()
            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).float()
            correct += (preds == label).sum().item()
            total += label.size(0)

            all_logits.append(probs.cpu().numpy())
            all_preds.append(preds.cpu().numpy())
            all_labels.append(label.cpu().numpy())

    avg_loss = total_loss / len(loader)
    acc = correct / total
    all_logits = np.concatenate(all_logits)
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    return avg_loss, acc, all_logits, all_preds, all_labels

# --------------------------
# 8. Distribution Check
# --------------------------
def check_pair_distribution(dataset):
    pos = sum(1 for _, _, label in dataset.pairs if label == 1)
    neg = sum(1 for _, _, label in dataset.pairs if label == 0)
    total = len(dataset.pairs)
    print(f"Total Pairs: {total}\nPositive: {pos} ({pos/total:.2%})\nNegative: {neg} ({neg/total:.2%})")

# --------------------------
# 9. Main Execution
# --------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_dataset = SwinSiameseDataset(genuine_train, forged_train, transform=SignatureTransform(train=True))
    test_dataset = SwinSiameseDataset(genuine_test, forged_test, transform=SignatureTransform(train=False))

    print("\n Train Dataset Distribution:")
    check_pair_distribution(train_dataset)
    print("\n Test Dataset Distribution:")
    check_pair_distribution(test_dataset)

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=0, pin_memory=True)

    model = SwinSignatureNetwork().to(device)
    criterion = FocalLoss()
    optimizer, scheduler = get_optimizer(model, train_loader)
    scaler = torch.cuda.amp.GradScaler()

    best_accuracy = 0
    for epoch in range(20):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device, scaler)
        test_loss, test_acc, logits, preds, labels = evaluate(model, test_loader, criterion, device)
        scheduler.step()

        # Compute F1
        f1 = f1_score(labels, preds)

        # Compute ROC curve and AUC
        fpr, tpr, _ = roc_curve(labels, logits)
        roc_auc = auc(fpr, tpr)

        # Save best model
        if test_acc > best_accuracy:
            best_accuracy = test_acc
            torch.save(model.state_dict(), "best_model.pth")
            print(f"New best model saved at epoch {epoch+1} with accuracy: {best_accuracy:.4f}")

        # Print metrics
        print(f"Epoch {epoch+1:02d}: "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} || "
              f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f} | "
              f"F1: {f1:.4f} | AUC: {roc_auc:.4f}")

        # Plot ROC curve
        plt.figure()
        plt.plot(fpr, tpr, label=f'AUC = {roc_auc:.4f}')
        plt.plot([0, 1], [0, 1], linestyle='--')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'ROC Curve Epoch {epoch+1}')
        plt.legend(loc='lower right')
        plt.savefig(f'roc_epoch_{epoch+1}.png')
        plt.close()

    print(f"\nTraining complete. Best Test Accuracy: {best_accuracy:.4f}")


# COMMAND ----------

# MAGIC %md
# MAGIC # SIAMESE MODEL

# COMMAND ----------

import os
import random
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import cv2
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, f1_score, roc_curve, auc, accuracy_score
import matplotlib.pyplot as plt

# --------------------------
# 1. Dataset Preparation
# --------------------------
genuine_paths = [os.path.join("/Volumes/dp-datalake-dev-default/test/test/genuines", f) for f in os.listdir("/Volumes/dp-datalake-dev-default/test/test/genuines")]
forged_paths = [os.path.join("/Volumes/dp-datalake-dev-default/test/test/forgeries", f) for f in os.listdir("/Volumes/dp-datalake-dev-default/test/test/forgeries")]

genuine_train, genuine_test = train_test_split(genuine_paths, test_size=0.2, random_state=42)
forged_train, forged_test = train_test_split(forged_paths, test_size=0.2, random_state=42)

# --------------------------
# 2. Lightweight Data Augmentation
# --------------------------
class SimpleTransform:
    def __init__(self, train=True):
        self.train = train
        self.base = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
        ])

    def __call__(self, img):
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
        edges = cv2.Laplacian(gray, cv2.CV_8U, ksize=3)
        img = np.clip(np.array(img) + 0.3 * np.stack([edges]*3, axis=-1), 0, 255)
        img = Image.fromarray(img.astype(np.uint8))
        img = self.base(img)

        if self.train:
            img = transforms.RandomPerspective(0.2, p=0.5)(img)
        return img

# --------------------------
# 3. Memory-Efficient Siamese Model
# --------------------------
class LightSiamese(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 128x128 -> 64x64

            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 64x64 -> 32x32
        )
        self.feature_dim = 64 * 32 * 32
        self.head = nn.Sequential(
            nn.Linear(self.feature_dim * 2, 128),
            nn.ReLU(),
            nn.Linear(128, 1) )

    def forward_once(self, x):
        return self.cnn(x).view(x.size(0), -1)

    def forward(self, x1, x2):
        feat1 = self.forward_once(x1)
        feat2 = self.forward_once(x2)
        combined = torch.cat([feat1, feat2], dim=1)
        return self.head(combined).squeeze()

# --------------------------
# 4. Training Utilities with Metrics
# --------------------------
def evaluate(model, loader, device):
    model.eval()
    all_preds, all_labels, all_probs = [], [], []

    with torch.no_grad():
        for img1, img2, labels in loader:
            img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
            outputs = model(img1, img2)

            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()

            all_probs.extend(probs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    f1 = f1_score(all_labels, all_preds)
    auc_score = roc_auc_score(all_labels, all_probs)
    acc = accuracy_score(all_labels, all_preds)

    # Plot ROC curve
    fpr, tpr, _ = roc_curve(all_labels, all_probs)
    plt.figure()
    plt.plot(fpr, tpr, label=f'AUC={auc_score:.2f}')
    plt.plot([0,1],[0,1],'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend()
    plt.savefig('roc_curve.png')
    plt.close()

    return f1, auc_score, acc

def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for img1, img2, labels in loader:
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(img1, img2)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = (torch.sigmoid(outputs) > 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    return total_loss / len(loader), correct / total

# --------------------------
# 5. Dataset Class
# --------------------------
class SiameseDataset(Dataset):
    def __init__(self, genuine, forged, transform=None):
        self.img1_paths, self.img2_paths, self.labels = self._create_pairs(genuine, forged)
        self.transform = transform or SimpleTransform()

    def _create_pairs(self, genuine, forged):
        pairs = []
        # Genuine-genuine pairs
        for i in range(len(genuine)-1):
            pairs.append((genuine[i], genuine[i+1], 1))
        # Genuine-forged pairs
        for i in range(min(len(genuine), len(forged))):
            pairs.append((genuine[i], forged[i], 0))
        return list(zip(*pairs)) if pairs else ([],[],[])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img1 = Image.open(self.img1_paths[idx]).convert('RGB')
        img2 = Image.open(self.img2_paths[idx]).convert('RGB')
        return self.transform(img1), self.transform(img2), torch.tensor(self.labels[idx], dtype=torch.float32)

# --------------------------
# 6. Main Execution (20 epochs)
# --------------------------
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data
    train_data = SiameseDataset(genuine_train, forged_train, SimpleTransform(train=True))
    test_data = SiameseDataset(genuine_test, forged_test, SimpleTransform(train=False))

    train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=16)

    # Model
    model = LightSiamese().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    criterion = nn.BCEWithLogitsLoss()

    best_f1 = 0
    for epoch in range(20):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        f1, auc, test_acc = evaluate(model, test_loader, device)

        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), 'best_siamese.pth')

        print(f"Epoch {epoch+1}/20: "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Test F1: {f1:.4f} | Test AUC: {auc:.4f} | Test Acc: {test_acc:.4f} | "
              f"Best F1: {best_f1:.4f}")

# COMMAND ----------

# MAGIC %md
# MAGIC # SWIN

# COMMAND ----------

import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import cv2
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_curve, auc, roc_auc_score, confusion_matrix, accuracy_score
from tqdm import tqdm
import timm
import matplotlib.pyplot as plt

# Configuration
class Config:
    IMG_SIZE = 224
    CHANNELS = 3
    BATCH_SIZE = 32
    EPOCHS = 20
    LR = 1e-4
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    SEED = 42

# Set random seed for reproducibility
torch.manual_seed(Config.SEED)
random.seed(Config.SEED)
np.random.seed(Config.SEED)

# --------------------------
# 1. Signature Dataset (Single Image)
# --------------------------
class SignatureDataset(Dataset):
    def __init__(self, genuine_paths, forged_paths, transform=None):
        self.paths = genuine_paths + forged_paths
        self.labels = [1] * len(genuine_paths) + [0] * len(forged_paths)
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            img = self.transform(img)

        return img, torch.tensor(label, dtype=torch.float32)

# --------------------------
# 2. Data Augmentation
# --------------------------
class SignatureTransform:
    def __init__(self, train=True):
        self.train = train
        self.base_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, img):
        img = np.array(img)

        # Edge enhancement
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        edges = cv2.Laplacian(gray, cv2.CV_8U, ksize=3)
        edges = np.stack([edges]*3, axis=-1)
        img = np.clip(img + 0.3*edges, 0, 255).astype(np.uint8)

        if self.train:
            augmentations = transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
                transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
                transforms.RandomResizedCrop(Config.IMG_SIZE, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
            ])(img)
        else:
            augmentations = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize(Config.IMG_SIZE),
                transforms.CenterCrop(Config.IMG_SIZE),
            ])(img)

        return self.base_transform(augmentations)

# --------------------------
# 3. Swin Transformer Model
# --------------------------
class SwinSignatureClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.swin = timm.create_model(
            'swin_tiny_patch4_window7_224',
            pretrained=True,
            num_classes=1  # Let timm handle the final classification layer
        )

    def forward(self, x):
        return self.swin(x).squeeze()

# --------------------------
# 4. Training Utilities
# --------------------------
def create_data_loaders(genuine_dir, forged_dir, test_size=0.2):
    genuine_paths = [os.path.join(genuine_dir, f) for f in os.listdir(genuine_dir)]
    forged_paths = [os.path.join(forged_dir, f) for f in os.listdir(forged_dir)]

    # Split datasets
    genuine_train, genuine_test = train_test_split(genuine_paths, test_size=test_size, random_state=Config.SEED)
    forged_train, forged_test = train_test_split(forged_paths, test_size=test_size, random_state=Config.SEED)

    train_dataset = SignatureDataset(
        genuine_train, forged_train,
        transform=SignatureTransform(train=True)
    )
    test_dataset = SignatureDataset(
        genuine_test, forged_test,
        transform=SignatureTransform(train=False)
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    return train_loader, test_loader

def plot_roc_curve(true_labels, pred_probs, epoch=None):
    fpr, tpr, thresholds = roc_curve(true_labels, pred_probs)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2,
             label=f'ROC curve (AUC = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curve{" - Epoch "+str(epoch) if epoch else ""}')
    plt.legend(loc="lower right")
    plt.savefig(f'roc_curve{"_epoch"+str(epoch) if epoch else ""}.png')
    plt.close()

# --------------------------
# 5. Training Loop
# --------------------------
def train():
    torch.cuda.empty_cache()

    # Initialize
    train_loader, test_loader = create_data_loaders(
        genuine_dir="/Volumes/dp-datalake-dev-default/test/test/genuines",
        forged_dir="/Volumes/dp-datalake-dev-default/test/test/forgeries"
    )

    model = SwinSignatureClassifier().to(Config.DEVICE)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=3)
    scaler = torch.cuda.amp.GradScaler()

    best_acc = 0
    best_metrics = {'f1': 0, 'auc': 0}

    for epoch in range(Config.EPOCHS):
        # Training
        model.train()
        train_loss, train_correct = 0, 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.EPOCHS} [Train]")
        for images, labels in pbar:
            images, labels = images.to(Config.DEVICE), labels.to(Config.DEVICE)

            optimizer.zero_grad()

            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()
            preds = (torch.sigmoid(outputs) > 0.5).float()
            train_correct += (preds == labels).sum().item()

            pbar.set_postfix({
                'loss': train_loss/len(pbar),
                'acc': train_correct/((pbar.n+1)*Config.BATCH_SIZE)
            })

        # Validation
        model.eval()
        val_loss, val_correct = 0, 0
        all_preds = []
        all_labels = []
        all_probs = []

        with torch.no_grad():
            for images, labels in tqdm(test_loader, desc=f"Epoch {epoch+1}/{Config.EPOCHS} [Val]"):
                images, labels = images.to(Config.DEVICE), labels.to(Config.DEVICE)

                outputs = model(images)
                val_loss += criterion(outputs, labels).item()

                probs = torch.sigmoid(outputs)
                preds = (probs > 0.5).float()
                val_correct += (preds == labels).sum().item()

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

        # Calculate metrics
        train_loss /= len(train_loader)
        train_acc = train_correct / len(train_loader.dataset)
        val_loss /= len(test_loader)
        val_acc = val_correct / len(test_loader.dataset)
        val_f1 = f1_score(all_labels, all_preds)
        val_auc = roc_auc_score(all_labels, all_probs)

        scheduler.step(val_acc)

        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            best_metrics = {'f1': val_f1, 'auc': val_auc}
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'val_acc': val_acc,
                'val_f1': val_f1,
                'val_auc': val_auc,
            }, "best_swin_classifier.pth")

            # Plot ROC curve for best model
            plot_roc_curve(all_labels, all_probs, epoch+1)

        print(f"\nEpoch {epoch+1}: "
              f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | "
              f"F1: {val_f1:.4f} | AUC: {val_auc:.4f} | "
              f"Best Acc: {best_acc:.4f}")

    # Final evaluation
    print(f"\nTraining complete. Best Val Accuracy: {best_acc:.4f}")
    print(f"Best F1 Score: {best_metrics['f1']:.4f}")
    print(f"Best ROC AUC: {best_metrics['auc']:.4f}")

    # Load best model and evaluate on test set
    best_model = SwinSignatureClassifier().to(Config.DEVICE)
    checkpoint = torch.load("best_swin_classifier.pth", weights_only=False)
    best_model.load_state_dict(checkpoint['model_state_dict'])

    # Final evaluation function
    best_model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Final Evaluation"):
            images, labels = images.to(Config.DEVICE), labels.to(Config.DEVICE)
            outputs = best_model(images)
            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # Calculate final metrics
    final_accuracy = accuracy_score(all_labels, all_preds)
    final_f1 = f1_score(all_labels, all_preds)
    final_auc = roc_auc_score(all_labels, all_probs)
    cm = confusion_matrix(all_labels, all_preds)

    print("\nFinal Test Set Performance:")
    print(f"Accuracy: {final_accuracy:.4f}")
    print(f"F1 Score: {final_f1:.4f}")
    print(f"ROC AUC: {final_auc:.4f}")
    print("Confusion Matrix:")
    print(cm)

    # Plot final ROC curve
    plot_roc_curve(all_labels, all_probs, "final")

if __name__ == "__main__":
    train()

FileNotFoundError: [Errno 2] No such file or directory: '/Volumes/dp-datalake-dev-default/test/test/4NSigComp2010.zip'