In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively list every file that Kaggle mounted under the input directory,
# printing one full path per line.
for folder, _, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(folder, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/siw-dataset-images/test_data-20230322T161835Z-001/test_data/live_109-2-1-1-2_720.jpg
/kaggle/input/siw-dataset-images/test_data-20230322T161835Z-001/test_data/live_108-1-1-1-2_180.jpg
/kaggle/input/siw-dataset-images/test_data-20230322T161835Z-001/test_data/live_110-2-1-1-1_420.jpg
/kaggle/input/siw-dataset-images/test_data-20230322T161835Z-001/test_data/spoof_015-1-3-1-2_60.jpg
/kaggle/input/siw-dataset-images/test_data-20230322T161835Z-001/test_data/spoof_015-2-3-2-1_360.jpg
/kaggle/input/siw-dataset-images/test_data-20230322T161835Z-001/test_data/spoof_018-2-3-4-2_300.jpg
/kaggle/input/siw-dataset-images/test_data-20230322T161835Z-001/test_data/spoof_018-1-3-4-1_300.jpg
/kaggle/input/siw-dataset-images/test_data-20230322T161835Z-001/test_data/live_109-2-1-1-1_1140.jpg
/kaggle/input/siw-dataset-images/test_data-20230322T161835Z-001/test_data/spoof_015-2-3-4-1_240.jpg
/kaggle/input/siw-dataset-images/test_data-20230322T161835Z-001/test_data/spoof_015-2-3-2-1_120.jpg
/kag

# # MobileNetV3 edge detection with controlled stitching and Sobel XY filter

In [None]:
# Full training + evaluation script for 14x14 pixel-wise BCE loss
# using MobileNetV3 + Sobel Edge Detection + Patch Stitching (Controlled & Random)

import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.metrics import confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
import random

# ---------------------- Sobel Filter ----------------------
def sobel_filter(image):
    """Compute the per-channel Sobel gradient magnitude of an image tensor.

    Every channel is filtered independently (grouped convolution with one
    group per channel) with the horizontal and vertical 3x3 Sobel kernels.
    ``padding=1`` keeps the spatial size, so the result has the same shape as
    the input.

    Args:
        image: Tensor of shape [C, H, W] or [B, C, H, W] with any number of
            channels. Integer tensors are converted to float32 first because
            the convolution requires a floating-point input.

    Returns:
        Tensor shaped like ``image`` holding ``sqrt(grad_x**2 + grad_y**2)``
        for every pixel of every channel.

    Raises:
        ValueError: If ``image`` is neither 3-D nor 4-D.
    """
    if image.dim() not in (3, 4):
        raise ValueError(
            f"sobel_filter expects a [C, H, W] or [B, C, H, W] tensor, got {image.dim()} dims"
        )
    if not image.is_floating_point():
        image = image.float()

    unbatched = image.dim() == 3
    if unbatched:
        image = image.unsqueeze(0)  # conv2d below is fed a [B, C, H, W] tensor

    channels = image.shape[1]
    # Build the kernels directly on the input's device/dtype so that non-float32
    # or GPU inputs do not trigger a dtype/device mismatch in conv2d.
    sobel_x = torch.tensor(
        [[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=image.dtype, device=image.device
    ).expand(channels, 1, 3, 3)
    sobel_y = torch.tensor(
        [[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=image.dtype, device=image.device
    ).expand(channels, 1, 3, 3)

    # groups=channels -> each channel is convolved with its own copy of the kernel.
    grad_x = F.conv2d(image, sobel_x, padding=1, groups=channels)
    grad_y = F.conv2d(image, sobel_y, padding=1, groups=channels)
    edge = torch.sqrt(grad_x ** 2 + grad_y ** 2)

    return edge.squeeze(0) if unbatched else edge

# ---------------------- Dataset ----------------------
class PatchGridDataset(Dataset):
    """Dataset that stitches random image patches into a labelled mosaic.

    Every ``*.jpg`` directly inside ``root_dir`` is loaded eagerly as an RGB
    image. An image is labelled 1 when its lower-cased file name contains
    ``'live'`` and 0 otherwise.

    Each sample is a ``grid_size x grid_size`` mosaic of ``patch_size`` square
    patches cropped at random positions. Every patch is concatenated with its
    Sobel edge response along the channel axis, and the returned label map
    stores the source-image label of each grid cell.

    Modes:
        * ``"controlled"``: each cell is taken from the indexed image with
          probability 0.5, otherwise from a uniformly random dataset image.
        * any other value: every cell comes from a uniformly random image.

    Args:
        root_dir: Directory that contains the ``.jpg`` files.
        transform: Optional callable turning a PIL image into a [C, H, W]
            tensor. When omitted, images are converted to float tensors in
            [0, 1] with channels first.
        mode: Stitching mode, see above.
    """

    def __init__(self, root_dir, transform=None, mode="controlled"):
        self.image_paths = sorted(glob.glob(os.path.join(root_dir, '*.jpg')))
        self.transform = transform
        self.mode = mode
        self.patch_size = 16
        self.grid_size = 14
        self.label_map = [1 if 'live' in os.path.basename(p).lower() else 0 for p in self.image_paths]

        # Decode everything up front; the context manager releases each file
        # handle as soon as the RGB copy has been made.
        self.all_images = []
        for path in self.image_paths:
            with Image.open(path) as handle:
                self.all_images.append(handle.convert('RGB'))

    def __len__(self):
        return len(self.image_paths)

    def _as_tensor(self, image):
        """Return ``image`` as a [C, H, W] tensor, honouring ``self.transform``."""
        if self.transform is not None:
            return self.transform(image)
        # Fallback for transform=None: HWC uint8 pixels -> CHW float in [0, 1].
        pixels = np.array(image, dtype=np.float32) / 255.0
        return torch.from_numpy(pixels).permute(2, 0, 1).contiguous()

    def __getitem__(self, idx):
        """Build one stitched mosaic.

        Returns:
            (grid, label_map): ``grid`` has shape
            [channels, grid_size * patch_size, grid_size * patch_size], where
            ``channels`` is the patch channels plus the edge channels.
            ``label_map`` is a float tensor of shape [grid_size, grid_size].
        """
        label = self.label_map[idx]
        base_img = self._as_tensor(self.all_images[idx])
        last_index = len(self.all_images) - 1

        patches = []
        gt_map = []

        for _ in range(self.grid_size * self.grid_size):
            # Short-circuit keeps the RNG call order of the two modes intact:
            # random.random() is only drawn in controlled mode.
            keep_base = self.mode == "controlled" and random.random() > 0.5
            if keep_base:
                img, lab = base_img, label
            else:
                ridx = random.randint(0, last_index)
                img = self._as_tensor(self.all_images[ridx])
                lab = self.label_map[ridx]

            top = random.randint(0, img.shape[1] - self.patch_size)
            left = random.randint(0, img.shape[2] - self.patch_size)
            patch = img[:, top:top + self.patch_size, left:left + self.patch_size]
            edge = sobel_filter(patch)
            patches.append(torch.cat([patch, edge], dim=0))
            gt_map.append(lab)

        stacked = torch.stack(patches)  # [grid*grid, channels, patch, patch]
        channels = stacked.shape[1]
        side = self.grid_size * self.patch_size
        grid = stacked.view(self.grid_size, self.grid_size, channels, self.patch_size, self.patch_size)
        # (row, col, C, ph, pw) -> (C, row, ph, col, pw) so that the reshape
        # lays the patches out row-major in image space.
        grid = grid.permute(2, 0, 3, 1, 4).reshape(channels, side, side)
        label_map = torch.tensor(gt_map, dtype=torch.float32).view(self.grid_size, self.grid_size)
        return grid, label_map

# ---------------------- Model ----------------------
class MobileNetGrid(nn.Module):
    """MobileNetV3-Small backbone that predicts a 14x14 map of logits.

    The network consumes 6-channel inputs. The backbone's first convolution is
    replaced by a 6-input-channel convolution, the backbone features are
    projected to a single channel with a 1x1 convolution, and the result is
    bilinearly resized to 14x14.

    The output holds raw logits (no sigmoid is applied here).
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): newer torchvision releases prefer the `weights=` argument
        # over `pretrained=True`; confirm against the installed version.
        base = models.mobilenet_v3_small(pretrained=True)

        rgb_stem = base.features[0][0]
        stem = nn.Conv2d(6, 16, 3, 2, 1, bias=False)
        # Seed the new stem from the pretrained RGB filters instead of throwing
        # them away: the first three input channels reuse them as-is and the
        # three extra channels start from a copy of the same filters.
        if tuple(rgb_stem.weight.shape) == (16, 3, 3, 3):
            with torch.no_grad():
                stem.weight[:, :3].copy_(rgb_stem.weight)
                stem.weight[:, 3:].copy_(rgb_stem.weight)
        base.features[0][0] = stem

        self.features = base.features
        self.conv1x1 = nn.Conv2d(576, 1, kernel_size=1)
        self.upsample = nn.Upsample(size=(14, 14), mode='bilinear', align_corners=False)

    def forward(self, x):
        """Map a [B, 6, H, W] batch to [B, 14, 14] logits."""
        x = self.features(x)
        x = self.conv1x1(x)
        x = self.upsample(x)
        return x.squeeze(1)  # [B, 14, 14]

# ---------------------- Evaluation ----------------------
def evaluate(model, dataloader, device):
    """Evaluate ``model`` at sample level and return presentation-attack metrics.

    For every sample the 14x14 label map and the 14x14 sigmoid output are each
    averaged to a single number. The ground-truth class is 1 when the mean
    label exceeds 0.5, and the predicted class is 1 when the mean probability
    exceeds 0.5.

    Args:
        model: Module returning raw logits of shape [B, H, W].
        dataloader: Iterable of ``(inputs, label_map)`` batches.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(acc, apcer, bpcer, acer, fpr, tpr, roc_auc)`` where ``acc`` is
        in percent, ``apcer = fp / (fp + tn)``, ``bpcer = fn / (fn + tp)`` and
        ``acer`` is their mean. When the ground truth holds fewer than two
        classes the ROC curve is undefined: ``fpr``/``tpr`` are then the
        diagonal placeholder ``[0, 1]`` and ``roc_auc`` is NaN.
    """
    model.eval()
    y_true, y_score = [], []

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            probs = torch.sigmoid(model(x))  # model emits raw logits
            y_true.extend((y.mean(dim=(1, 2)) > 0.5).int().cpu().numpy())
            y_score.extend(probs.mean(dim=(1, 2)).cpu().numpy())

    preds = [1 if p > 0.5 else 0 for p in y_score]
    # labels=[0, 1] forces a 2x2 matrix even when only one class shows up;
    # without it ravel() yields a single value and the unpacking below raises.
    cm = confusion_matrix(y_true, preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    apcer = fp / (fp + tn + 1e-6)
    bpcer = fn / (fn + tp + 1e-6)
    acer = (apcer + bpcer) / 2
    total = tp + tn + fp + fn
    acc = 100 * (tp + tn) / total if total else 0.0

    if len({int(t) for t in y_true}) < 2:
        # ROC needs both classes; return a placeholder instead of NaN arrays.
        fpr, tpr = np.array([0.0, 1.0]), np.array([0.0, 1.0])
        roc_auc = float('nan')
    else:
        fpr, tpr, _ = roc_curve(y_true, y_score)
        roc_auc = auc(fpr, tpr)

    return acc, apcer, bpcer, acer, fpr, tpr, roc_auc

# ---------------------- Train Function ----------------------
def train_model(model, loaders, device, save_prefix, epochs=25, lr=1e-3):
    """Train ``model`` with pixel-wise BCE-with-logits loss and plot the metrics.

    The model is optimised with Adam, the learning rate is halved every 10
    epochs, and the validation loader is evaluated after each epoch. When
    training finishes, three figures are written next to the notebook:
    ``{save_prefix}_accuracy.png``, ``{save_prefix}_loss.png`` and
    ``{save_prefix}_roc.png`` (ROC of the final epoch's validation pass).

    Args:
        model: Network returning logits shaped like the label maps.
        loaders: Tuple ``(train_loader, val_loader)``.
        device: Device used for the batches.
        save_prefix: File-name prefix for the saved figures.
        epochs: Number of training epochs (default 25).
        lr: Initial Adam learning rate (default 1e-3).

    Returns:
        ``(acc, apcer, bpcer, acer, model)`` with the validation metrics of the
        last epoch and the trained model (last-epoch weights).

    Raises:
        ValueError: If ``epochs`` is smaller than 1.
    """
    if epochs < 1:
        # Without at least one epoch there are no validation metrics to return.
        raise ValueError("epochs must be >= 1")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    criterion = nn.BCEWithLogitsLoss()
    train_loader, val_loader = loaders

    val_accs, losses = [], []

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            y = y.clamp(0, 1).float()  # BCE targets must lie in [0, 1]
            out = model(x)
            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        scheduler.step()
        # max(..., 1) avoids a ZeroDivisionError on an empty training loader.
        avg_loss = total_loss / max(len(train_loader), 1)
        acc, apcer, bpcer, acer, fpr, tpr, roc_auc = evaluate(model, val_loader, device)

        print(f"Epoch {epoch}: Loss={avg_loss:.4f}, ACC={acc:.2f}%, ACER={acer:.4f}")
        losses.append(avg_loss)
        val_accs.append(acc)

    # --- Save Metrics Plots ---
    fig, ax = plt.subplots()
    ax.plot(val_accs, label='Validation Accuracy')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.legend()
    fig.savefig(f'{save_prefix}_accuracy.png')

    fig, ax = plt.subplots()
    ax.plot(losses, label='Training Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    fig.savefig(f'{save_prefix}_loss.png')

    # fpr/tpr/roc_auc come from the last epoch's validation pass.
    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, label=f'ROC (AUC={roc_auc:.2f})')
    ax.set_xlabel('FPR')
    ax.set_ylabel('TPR')
    ax.legend()
    fig.savefig(f'{save_prefix}_roc.png')

    return acc, apcer, bpcer, acer, model

# ---------------------- Run Controlled + Random ----------------------
def _plot_comparison(results, out_path='comparison_bar.png'):
    """Save a grouped bar chart comparing the metrics of every stitching mode."""
    labels = list(results.keys())
    metrics = ["ACC", "APCER", "BPCER", "ACER", "TestACC", "TestACER"]
    # Accuracies are stored in percent while the error rates are fractions;
    # rescale the rates so that every bar shares the same 0-100 axis.
    already_percent = {"ACC", "TestACC"}

    x = np.arange(len(labels))
    width = 0.12

    plt.figure(figsize=(12, 6))
    for i, metric in enumerate(metrics):
        scale = 1.0 if metric in already_percent else 100.0
        values = [results[m][metric] * scale for m in labels]
        plt.bar(x + i * width, values, width=width, label=metric)

    # Bars sit at offsets 0 .. (n-1)*width, so the group centre is (n-1)/2 * width.
    plt.xticks(x + (len(metrics) - 1) / 2 * width, labels)
    plt.ylabel('Score (%)')
    plt.title('Comparison of Controlled vs Random Stitching')
    plt.legend()
    plt.savefig(out_path)


def run_all(seed=42):
    """Train and evaluate one model per stitching mode and plot a comparison.

    For ``"controlled"`` and ``"random"`` stitching, a fresh ``MobileNetGrid``
    is trained on the train split, validated every epoch, and finally evaluated
    on the test split. All three splits are stitched with the same mode. A
    grouped bar chart of the collected metrics is saved as
    ``comparison_bar.png``.

    Args:
        seed: Seed applied to ``random``, NumPy and PyTorch before each mode,
            so both runs start from the same RNG state. Pass ``None`` to leave
            the RNGs unseeded.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

    split_dirs = {
        'train': '/kaggle/input/siw-dataset-images/train_data-20230322T161839Z-001/train_data',
        'val': '/kaggle/input/siw-dataset-images/val_data-20230322T161837Z-001/val_data',
        'test': '/kaggle/input/siw-dataset-images/test_data-20230322T161835Z-001/test_data',
    }

    results = {}
    for mode in ["controlled", "random"]:
        print(f"\n===== Training Mode: {mode.upper()} Stitching =====")
        if seed is not None:
            # The patch sampling, weight init and shuffling are all stochastic.
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        model = MobileNetGrid().to(device)

        train_set = PatchGridDataset(split_dirs['train'], transform, mode)
        val_set = PatchGridDataset(split_dirs['val'], transform, mode)
        test_set = PatchGridDataset(split_dirs['test'], transform, mode)

        train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_set, batch_size=8, shuffle=False, num_workers=2)
        test_loader = DataLoader(test_set, batch_size=8, shuffle=False, num_workers=2)

        acc, apcer, bpcer, acer, model = train_model(model, (train_loader, val_loader), device, save_prefix=mode)
        print(f"\n📊 Final Test Evaluation for {mode.upper()} Stitching")
        test_acc, test_apcer, test_bpcer, test_acer, *_ = evaluate(model, test_loader, device)
        print(f"Test Accuracy: {test_acc:.2f}%, APCER: {test_apcer:.4f}, BPCER: {test_bpcer:.4f}, ACER: {test_acer:.4f}")

        results[mode] = dict(ACC=acc, APCER=apcer, BPCER=bpcer, ACER=acer, TestACC=test_acc, TestACER=test_acer)

    # --- Bar plot ---
    _plot_comparison(results)

# ---------------------- Entry ----------------------
# Run both experiments only when this code is executed directly
# (script or notebook cell), not when it is imported.
if __name__ == "__main__":
    run_all()



===== Training Mode: CONTROLLED Stitching =====
Epoch 1: Loss=0.3512, ACC=64.16%, ACER=0.3890
Epoch 2: Loss=0.3248, ACC=91.26%, ACER=0.0848
Epoch 3: Loss=0.3117, ACC=92.73%, ACER=0.0718
Epoch 4: Loss=0.3041, ACC=91.26%, ACER=0.0941
Epoch 5: Loss=0.2992, ACC=87.93%, ACER=0.1154
Epoch 6: Loss=0.2976, ACC=54.68%, ACER=0.4201
Epoch 7: Loss=0.2924, ACC=91.26%, ACER=0.0881
Epoch 8: Loss=0.2880, ACC=91.75%, ACER=0.0882
Epoch 9: Loss=0.2888, ACC=87.93%, ACER=0.1261
Epoch 10: Loss=0.2837, ACC=84.24%, ACER=0.1481
Epoch 11: Loss=0.2785, ACC=96.67%, ACER=0.0351
Epoch 12: Loss=0.2763, ACC=79.68%, ACER=0.1886
Epoch 13: Loss=0.2764, ACC=92.24%, ACER=0.0725
Epoch 14: Loss=0.2755, ACC=92.73%, ACER=0.0683
Epoch 15: Loss=0.2721, ACC=93.23%, ACER=0.0638
Epoch 16: Loss=0.2740, ACC=85.71%, ACER=0.1324
Epoch 17: Loss=0.2734, ACC=93.10%, ACER=0.0641
Epoch 18: Loss=0.2723, ACC=90.39%, ACER=0.0892
Epoch 19: Loss=0.2713, ACC=97.04%, ACER=0.0278
Epoch 20: Loss=0.2681, ACC=95.81%, ACER=0.0388
Epoch 21: Loss=0.267



Epoch 1: Loss=0.3869, ACC=85.10%, ACER=0.5000
Epoch 2: Loss=0.3639, ACC=84.73%, ACER=0.5000
Epoch 3: Loss=0.3562, ACC=87.19%, ACER=0.5000
Epoch 4: Loss=0.3524, ACC=84.85%, ACER=0.5000
Epoch 5: Loss=0.3500, ACC=83.99%, ACER=0.5000
Epoch 6: Loss=0.3462, ACC=85.10%, ACER=0.5000
Epoch 7: Loss=0.3436, ACC=85.71%, ACER=0.5000
Epoch 8: Loss=0.3424, ACC=86.08%, ACER=0.5000
Epoch 9: Loss=0.3397, ACC=85.96%, ACER=0.5000
Epoch 10: Loss=0.3386, ACC=85.10%, ACER=0.5000
Epoch 11: Loss=0.3354, ACC=84.48%, ACER=0.5000
