In [None]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
import sys
import os 
from torchvision import datasets,transforms,models,utils
from torch.utils.data import DataLoader,ConcatDataset,Dataset
import numpy as np
import matplotlib.pyplot as plt
import math
import re 
import timm
import pandas as pd
import numpy as np
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Transforms

In [None]:
# Per-channel normalisation statistics shared by every pipeline below
# (the values commonly used with ImageNet-pretrained torchvision models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _build_pipeline(*augmentations):
    """Compose a preprocessing pipeline around optional augmentations.

    Every pipeline resizes and centre-crops to 224, then applies the given
    ``augmentations`` (in order, on the PIL image), and finally converts to a
    tensor and normalises with ``_NORM_MEAN`` / ``_NORM_STD``.
    """
    return transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ])


# Plain preprocessing, no augmentation.
transform_original = _build_pipeline()

# Mild colour jitter plus an unconditional (p=1) horizontal flip.
transform_flipped = _build_pipeline(
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomHorizontalFlip(p=1),
)

# Heavier augmentation: stronger jitter, unconditional flip, blur and
# occasional (p=0.1) grayscale conversion.
spoof_transforms = _build_pipeline(
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
    transforms.RandomHorizontalFlip(p=1),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.RandomGrayscale(p=0.1),
)

In [None]:
from PIL import Image
class CustomDataset(Dataset):
    """Binary image dataset read from a folder-per-class directory tree.

    Every immediate sub-directory of ``root`` is treated as one class: the
    folder named ``"real"`` is labelled 1 and every other folder is labelled 0.
    Samples are indexed in sorted (folder, file name) order so that the
    index -> sample mapping is reproducible across runs and machines.

    Args:
        root: Directory whose sub-folders contain the image files.
        special_transform: Transform applied to samples whose label is listed
            in ``special_classes``.
        general_transform: Transform applied to every other sample (and to
            "special" samples when ``special_transform`` is not given).
        special_classes: Set, list or tuple of labels that should receive
            ``special_transform``. ``None`` means no label is special.

    Raises:
        TypeError: If ``special_classes`` is neither ``None`` nor a
            set/list/tuple.
    """

    def __init__(self,root,special_transform=None,general_transform=None,special_classes=None):
        # Validate the cheap argument first so a bad call fails before any
        # directory scanning happens.
        if special_classes is not None:
            if not isinstance(special_classes,(set,list,tuple)):
                raise TypeError("special_classes must be a set,list or tuple")
            self.special_classes=set(special_classes)
        else:
            self.special_classes=set()
        self.root=root
        self.special_transform=special_transform
        self.general_transform=general_transform
        self.labels=self.assign_labels()

    def assign_labels(self):
        """Scan ``self.root`` and return a list of ``(image_path, label)`` pairs.

        Non-directory entries directly under ``root`` are ignored, and so are
        nested directories inside a class folder (they cannot be opened as
        images). Listings are sorted because ``os.listdir`` returns entries in
        arbitrary order.
        """
        labels=[]
        for parent_folder in sorted(os.listdir(self.root)):
            parent_folder_path=os.path.join(self.root,parent_folder)
            if not os.path.isdir(parent_folder_path):
                continue
            label=1 if parent_folder=="real" else 0
            for image_file in sorted(os.listdir(parent_folder_path)):
                image_path=os.path.join(parent_folder_path,image_file)
                if os.path.isfile(image_path):
                    labels.append((image_path,label))
        return labels

    def __getitem__(self,index):
        """Return ``(image, label)`` for the sample at ``index``.

        ``image`` is the RGB image after the applicable transform (or the raw
        PIL image when no transform applies); ``label`` is an integer tensor
        of shape ``(1,)``.
        """
        image_path,label=self.labels[index]
        # ``convert`` returns an independent, fully loaded image, so the
        # underlying file handle can be released as soon as the block exits.
        with Image.open(image_path) as handle:
            image = handle.convert("RGB")
        if label in self.special_classes and self.special_transform:
            image=self.special_transform(image)
        elif self.general_transform:
            image=self.general_transform(image)
        label = torch.tensor(label).unsqueeze(0)
        return image,label

    def __len__(self):
        """Number of ``(image_path, label)`` samples found under ``root``."""
        return len(self.labels)

In [None]:
# Run-level constants.
# NOTE(review): the training cell defines its own ``num_epochs``; keep the two
# values in sync if either is changed.
EPOCHS = 3
# Mini-batch size passed to every DataLoader in this notebook.
BATCH_SIZE = 16
# Number of characters needed to print the epoch count.
EPOCH_LEN = len(f"{EPOCHS}")

In [None]:
# Training data = the un-augmented images plus one augmented copy of each.
# In the augmented copy, label-0 (non-"real") samples get the heavier
# ``spoof_transforms``; label-1 samples get ``transform_flipped``.
train_orig = CustomDataset(
    root="/kaggle/input/hehedataset/Combine/train",
    general_transform=transform_original,
)
train_flip = CustomDataset(
    root="/kaggle/input/hehedataset/Combine/train",
    special_classes=[0],
    special_transform=spoof_transforms,
    general_transform=transform_flipped,
)
train_data_combined = ConcatDataset([train_orig, train_flip])
train_loader = DataLoader(
    dataset=train_data_combined,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
# Validation data is assembled the same way as the training data: original
# images plus one augmented copy of each.
# NOTE(review): the augmented copy uses random transforms (colour jitter,
# blur sigma, grayscale), so the validation loss is not exactly repeatable
# between passes - confirm this is intended, since it drives the LR scheduler
# and best-checkpoint selection.
val_orig = CustomDataset(
    root="/kaggle/input/hehedataset/Combine/val",
    general_transform=transform_original,
)
val_flip = CustomDataset(
    root="/kaggle/input/hehedataset/Combine/val",
    special_classes=[0],
    special_transform=spoof_transforms,
    general_transform=transform_flipped,
)

val_data_combined = ConcatDataset([val_orig, val_flip])
val_loader = DataLoader(
    dataset=val_data_combined,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Test data: un-augmented images only, iterated in a fixed order.
test_orig = CustomDataset(
    root="/kaggle/input/hehedataset/Combine/test",
    general_transform=transform_original,
)
test_loader = DataLoader(
    dataset=test_orig,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Sample counts, in order: combined train, combined val, test.
print(*(len(split) for split in (train_data_combined, val_data_combined, test_orig)))

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

# Model

In [None]:
from torchvision.models import mobilenet_v2

In [None]:
class SpoofNet(nn.Module):
    """MobileNetV2 feature extractor with a small binary-classification head.

    The head is: 3x3 convolution (1280 -> 32 channels), ReLU, dropout (0.2),
    global average pooling, a linear layer (32 -> 1) and a sigmoid, so
    ``forward`` returns one probability per sample with shape ``(batch, 1)``.

    The full pretrained network is kept as ``pretrained_net`` in addition to
    ``features``, so both attribute names appear in ``state_dict()``. The
    attribute names and their creation order are part of the checkpoint
    format and are intentionally left untouched.
    """

    def __init__(self):
        super().__init__()
        self.pretrained_net = mobilenet_v2(pretrained=True)
        self.features = self.pretrained_net.features
        # 1280 is the channel count expected from ``self.features``;
        # adjust the input channels if the backbone changes.
        self.conv2d = nn.Conv2d(1280, 32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(p=0.2)
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a batch of images to per-sample probabilities in (0, 1)."""
        feature_map = self.features(x)
        activated = self.dropout1(self.relu(self.conv2d(feature_map)))
        # Collapse the spatial dimensions, then drop them: (B, 32, 1, 1) -> (B, 32).
        pooled = torch.flatten(self.global_avg_pool(activated), 1)
        return self.sigmoid(self.fc(pooled))

In [None]:
# Build the network, then move its parameters and buffers to the selected device.
model = SpoofNet()
model = model.to(device)

In [None]:
# Evaluate ``model`` on the test split.
# NOTE(review): this cell sits before the checkpoint-loading and training
# cells, so on a top-to-bottom run it scores the freshly constructed model
# (presumably as a baseline) - confirm that is intentional.
#
# Label convention (from CustomDataset): 1 = "real", 0 = everything else.
# "Positive" therefore means "predicted real".
model.eval()
with torch.no_grad():
    correct = 0
    tp = 0  # real predicted as real
    tn = 0  # non-real predicted as non-real
    fp = 0  # non-real predicted as real (false accept)
    fn = 0  # real predicted as non-real (false reject)

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        # Threshold the sigmoid output at 0.5 to get hard 0/1 predictions.
        pred = (output > 0.5).int()

        tp += (pred.eq(1) & target.eq(1).view_as(pred)).sum().item()
        tn += (pred.eq(0) & target.eq(0).view_as(pred)).sum().item()
        fp += (pred.eq(1) & target.eq(0).view_as(pred)).sum().item()
        fn += (pred.eq(0) & target.eq(1).view_as(pred)).sum().item()

        correct += pred.eq(target.view_as(pred)).sum().item()

    # Each ratio is reported as NaN instead of raising ZeroDivisionError when
    # its denominator is empty (e.g. a test split containing only one class).
    accuracy = correct / len(test_loader.dataset) if len(test_loader.dataset) > 0 else float("nan")
    far = fp / (fp + tn) if (fp + tn) > 0 else float("nan")  # false acceptance rate
    frr = fn / (fn + tp) if (fn + tp) > 0 else float("nan")  # false rejection rate
    recall = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
    hter = (far + frr ) / 2  # half total error rate

    print(f"test acc: {accuracy * 100}%")
    print(f"recall: {recall * 100}%")
    print(f"far: {far * 100}%")
    print(f"frr: {frr * 100}%")
    print(f"hter: {hter * 100}%")

In [None]:
# Restore previously trained weights into ``model``.
# ``map_location="cpu"`` lets the checkpoint be read on machines without a
# GPU even if it was saved from CUDA tensors; ``load_state_dict`` then copies
# the values into the model's parameters on whatever device they live on.
checkpoint = torch.load(
    "/kaggle/input/mobilenet/mobilenetv2-best.pt",
    map_location="cpu",
)
model.load_state_dict(checkpoint['state_dict'])

In [None]:
from PIL import Image
from skimage.io import imread
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from mlxtend.plotting import plot_confusion_matrix

# pytorch
import torch
import torch.nn as nn
from torch.nn.functional import softmax
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Optimisation setup for fine-tuning ``model``.
num_epochs = 3
learning_rate = 5e-5

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# BCELoss expects probabilities; SpoofNet.forward already ends in a sigmoid.
criterion = nn.BCELoss()
# Multiply the learning rate by 0.2 when the monitored value stops improving.
# NOTE(review): with patience=3 and num_epochs=3 the scheduler cannot trigger
# within a single run of the training cell. ``verbose`` is reported as
# deprecated by recent PyTorch releases - confirm against the installed version.
scheduler = ReduceLROnPlateau(
    optimizer, factor=0.2, patience=3, verbose=True,
    threshold=0.005, min_lr=5e-7,
)
# Per-epoch metrics, appended to by the training loop.
history = {
    'train_loss': [],
    'val_loss': [],
    'train_accuracy': [],
    'val_accuracy': [],
    'learning_rate': [],
}
# Running minimum of the validation loss. It starts at +inf so that any real
# loss compares as an improvement; the previous 0.0 start could never be
# beaten by a non-negative loss.
best_val_loss = float("inf")
save_dir = '/kaggle/working/'

# define checkpoint paths
cont_filepath = os.path.join(save_dir, "mobilenetv2-epoch_{}.pt")  # "{}" is filled with the epoch number
best_filepath = os.path.join(save_dir, "mobilenetv2-best.pt")

def save_checkpoint(state, is_best, filename):
    """Serialise ``state`` to ``filename``; also to ``best_filepath`` when ``is_best``.

    Args:
        state: Object handed to ``torch.save`` (the training loop passes a
            dict with the epoch number, model state and optimizer state).
        is_best: When truthy, a second copy is written to the module-level
            ``best_filepath``.
        filename: Destination path of the per-epoch checkpoint.
    """
    torch.save(state, filename)
    if not is_best:
        return
    torch.save(state, best_filepath)

In [None]:
from tqdm import tqdm
# Main training loop: for each epoch, run one pass over ``train_loader`` with
# gradient updates, one pass over ``val_loader`` without gradients, step the
# LR scheduler on the mean validation loss, write a checkpoint and record the
# metrics in ``history``.
for epoch in range(num_epochs):
    print('epoch: {}/{}'.format(epoch+1, num_epochs))
    print('-----------------------')
    # train the model -----------------------------------------------------
    model.train()
    running_loss = 0.0  # sum of per-batch mean losses
    train_total = 0     # samples seen so far this epoch
    train_correct = 0   # correctly classified samples so far this epoch
    prog_bar_train = tqdm(train_loader, desc='training')
    for inputs, labels in prog_bar_train:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        # Hard 0/1 predictions from the model output, thresholded at 0.5.
        predicted = (outputs > 0.5).int()
        # The dataset yields integer labels; BCELoss needs float targets.
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
        
        # update progress
        prog_bar_train.set_postfix({'acc': round(train_correct / train_total, 2)})
    train_acc = 100 * train_correct / train_total
    # Mean of the per-batch losses (every batch weighted equally).
    avg_train_loss = running_loss / len(train_loader)

    history['train_loss'].append(avg_train_loss)
    history['train_accuracy'].append(train_acc)
    
    # validate the model --------------------------------------------------
    model.eval()
    running_val_loss = 0.0
    val_total = 0
    val_correct = 0
    prog_bar_val = tqdm(val_loader, desc='validating')
    with torch.no_grad():
        for inputs, labels in prog_bar_val:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels.float())
            running_val_loss += loss.item()
            predicted = (outputs > 0.5).int()
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

            # update progress
            prog_bar_val.set_postfix({'acc': round(val_correct / val_total, 4)})
    avg_val_loss = running_val_loss / len(val_loader)
    # The scheduler monitors the mean validation loss.
    scheduler.step(avg_val_loss)
    
    val_acc = 100 * val_correct / val_total

    history['val_loss'].append(avg_val_loss)
    history['val_accuracy'].append(val_acc)
    
    # The first epoch of each run of this cell always counts as the best so
    # far; later epochs are best only when they strictly lower the val loss.
    if epoch == 0:
        best_val_loss = avg_val_loss
        is_best = True
    else:
        is_best = avg_val_loss < best_val_loss
        best_val_loss = min(avg_val_loss, best_val_loss)
    
    # Every epoch gets its own checkpoint file; the best one is additionally
    # written to ``best_filepath`` by ``save_checkpoint``.
    checkpoint_filepath = cont_filepath.format(epoch+1)
    print('saving checkpoint: {}'.format(checkpoint_filepath))
    save_checkpoint(
        {'epoch': epoch + 1,
         'state_dict': model.state_dict(),
         'optimizer': optimizer.state_dict(),},
        is_best,
        checkpoint_filepath
    )
    
    # Read after ``scheduler.step``, so this is the learning rate that the
    # next epoch will use (first parameter group only).
    current_lr = optimizer.param_groups[0]['lr']
    history['learning_rate'].append(current_lr)
    
    
    # print loss and accuracy
    # (the trailing comma on the next line only wraps the call's result in a
    # discarded one-element tuple; it does not affect the output)
    print(f'Epoch: {epoch+1}/{num_epochs}'),
    print('Loss/train: {}'.format(avg_train_loss))
    print('Loss/val: {}'.format(avg_val_loss))
    print('Acc/train: {}%'.format(train_acc))
    print('Acc/val: {}%'.format(val_acc))
    print('current lr: {}'.format(current_lr))

In [None]:
from IPython.display import FileLink
# Display a link to the best checkpoint as this cell's output.
# NOTE(review): the path is relative - presumably it resolves against the
# notebook's working directory, where the training cell writes
# "mobilenetv2-best.pt" via ``save_dir``; confirm when running elsewhere.
FileLink('mobilenetv2-best.pt')

# Evaluation

In [None]:
# Final evaluation of ``model`` on the test split.
#
# Label convention (from CustomDataset): 1 = "real", 0 = everything else.
# "Positive" therefore means "predicted real".
model.eval()
with torch.no_grad():
    correct = 0
    tp = 0  # real predicted as real
    tn = 0  # non-real predicted as non-real
    fp = 0  # non-real predicted as real (false accept)
    fn = 0  # real predicted as non-real (false reject)

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        # Threshold the sigmoid output at 0.5 to get hard 0/1 predictions.
        pred = (output > 0.5).int()

        tp += (pred.eq(1) & target.eq(1).view_as(pred)).sum().item()
        tn += (pred.eq(0) & target.eq(0).view_as(pred)).sum().item()
        fp += (pred.eq(1) & target.eq(0).view_as(pred)).sum().item()
        fn += (pred.eq(0) & target.eq(1).view_as(pred)).sum().item()

        correct += pred.eq(target.view_as(pred)).sum().item()

    # Each ratio is reported as NaN instead of raising ZeroDivisionError when
    # its denominator is empty (e.g. a test split containing only one class).
    accuracy = correct / len(test_loader.dataset) if len(test_loader.dataset) > 0 else float("nan")
    far = fp / (fp + tn) if (fp + tn) > 0 else float("nan")  # false acceptance rate
    frr = fn / (fn + tp) if (fn + tp) > 0 else float("nan")  # false rejection rate
    recall = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
    hter = (far + frr ) / 2  # half total error rate

    print(f"test acc: {accuracy * 100}%")
    print(f"recall: {recall * 100}%")
    print(f"far: {far * 100}%")
    print(f"frr: {frr * 100}%")
    print(f"hter: {hter * 100}%")