In [None]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
import sys
import os 
from torchvision import datasets,transforms,models,utils
from torch.utils.data import DataLoader,ConcatDataset,Dataset
import numpy as np
import matplotlib.pyplot as plt
import re 
import timm
import pandas as pd
import numpy as np
# Ignore warnings
from PIL import Image
import warnings
warnings.filterwarnings("ignore")

# CDCN model

In [None]:
class FASDataset(Dataset):
    """Face anti-spoofing dataset pairing each image with a constant pseudo depth map.

    Expected layout: ``root_dir/<class_folder>/<image files>``. Files inside a
    folder named ``real`` get label 1; files in every other sub-folder get label 0.

    Args:
        root_dir: dataset root directory.
        depth_map_size: ``(height, width)`` of the generated target map.
        transform: optional callable applied to the RGB PIL image.
        smoothing: if True, use label smoothing: real maps are filled with 0.99
            and spoof maps with 0.01 instead of hard 1.0 / 0.0 targets.

    ``__getitem__`` returns ``(img, depth_map, label)``. ``depth_map`` is a float32
    array of shape ``depth_map_size``; ``label`` is a float64 array of shape (1,).
    """
    def __init__(self,root_dir,depth_map_size,transform,smoothing):
        super().__init__()
        self.root_dir=root_dir
        self.labels=self.assign_labels()
        self.depth_map_size=depth_map_size
        self.transform=transform
        # Fill value for "real" maps; spoof maps get 1 - label_weight.
        # Label smoothing means soft targets (0.99 / 0.01); without it the targets are hard (1.0 / 0.0).
        self.label_weight = 0.99 if smoothing else 1.0
    def assign_labels(self):
        """Return ``(image_path, label)`` pairs for every file in each class sub-folder.

        Entries are sorted because ``os.listdir`` order is arbitrary, which keeps
        sample indices reproducible. Non-directory entries at the root and
        non-file entries inside class folders are skipped.
        """
        labels=[]
        for parent_folder in sorted(os.listdir(self.root_dir)):
            parent_folder_path=os.path.join(self.root_dir,parent_folder)
            if not os.path.isdir(parent_folder_path):
                continue
            label=1 if parent_folder=='real' else 0
            for image_file in sorted(os.listdir(parent_folder_path)):
                image_path=os.path.join(parent_folder_path,image_file)
                if os.path.isfile(image_path):
                    labels.append((image_path,label))
        return labels
    def __getitem__(self,index):
        image_path,label=self.labels[index]
        # convert('RGB') loads the pixels and normalises grayscale/RGBA/palette files to
        # 3 channels, which is what the 3-channel Normalize steps in this notebook expect.
        # The context manager closes the file handle (important with DataLoader workers).
        with Image.open(image_path) as raw_img:
            img=raw_img.convert('RGB')
        if self.transform:
            img=self.transform(img)
        label_value=float(label)
        fill=self.label_weight if label_value==1 else 1.0-self.label_weight
        depth_map=np.full((self.depth_map_size[0],self.depth_map_size[1]),fill,dtype=np.float32)
        return img, depth_map, np.expand_dims(label_value,axis=0)
    def __len__(self):
        return len(self.labels)

In [None]:
  _mean_=[0.5,0.5,0.5]
  _sigma_= [0.5,0.5,0.5]

In [None]:
import os
import torch
from torchvision import transforms
from PIL import ImageDraw

def add_visualization(epoch,img_batch,preds,targets,score,writer,mean=None,std=None):
    """Log every image of a batch to TensorBoard, annotated with prediction, label and score.

    Args:
        epoch: global step passed to ``writer.add_image``.
        img_batch: normalised image tensor indexed per sample as ``img_batch[idx]``
            (presumably (N, 3, H, W); TODO confirm).
        preds, targets: per-sample 0/1 predictions and ground truth (as from ``predict``).
        score: per-sample score, shown with 4 decimals.
        writer: object exposing ``add_image(tag, img_tensor, step)`` (e.g. SummaryWriter).
        mean, std: per-channel statistics of the ``Normalize`` applied to ``img_batch``.
            They default to the module-level ``_mean_`` / ``_sigma_``.
    """
    mean = _mean_ if mean is None else mean
    std = _sigma_ if std is None else std
    # Normalize(-m/s, 1/s) computes x*s + m, i.e. it inverts Normalize(m, s).
    denormalize = transforms.Normalize([-m / s for m, s in zip(mean, std)], [1 / s for s in std])
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    for idx in range(img_batch.shape[0]):
        # Clamp to [0, 1]: ToPILImage scales float tensors by 255 and casts to uint8,
        # so out-of-range values would overflow that conversion and corrupt pixels.
        img = denormalize(img_batch[idx].detach().cpu()).clamp(0.0, 1.0)
        vis_img = to_pil(img)
        draw = ImageDraw.Draw(vis_img)
        draw.text((0,0), 'pred: {} vs gt: {}'.format(int(preds[idx]), int(targets[idx])), (255,0,255))
        draw.text((20,20), 'score {:.4f}'.format(float(score[idx])), (255,0,255))
        writer.add_image('Prediction visualization/{}'.format(idx), to_tensor(vis_img), epoch)

def predict(depth_map,threshold=0.5):
    """Convert estimated depth maps into binary real/fake predictions.

    Args:
        depth_map: batch of depth maps with the batch on dim 0, e.g. (N, 32, 32).
            All non-batch dimensions are averaged.
        threshold: mean depth at or above which a sample is predicted 1 (real).

    Returns:
        (preds, score): ``preds`` is a float32 tensor of 0./1. values with shape (N,);
        ``score`` is the per-sample mean depth. Both stay on ``depth_map``'s device.
    """
    with torch.no_grad():
        score = depth_map.flatten(start_dim=1).mean(dim=1)
        # .float() keeps the device; .type(torch.FloatTensor) would silently move to CPU.
        preds = (score >= threshold).float()
    return preds, score
@torch.no_grad()
def calc_accuracy(preds,targets):
    """Return the fraction of elements where ``preds`` equals ``targets`` as a Python float.

    The comparison is reduced on the CPU in float32.
    """
    hits = preds.eq(targets).cpu().float()
    return hits.mean().item()
    

In [None]:
class AvgMeter():
    """Running average of a scalar metric, logged through a TensorBoard-style writer.

    Args:
        writer: object with ``add_scalar(tag, value, step)`` (e.g. SummaryWriter).
        name: tag under which the running average is logged.
        num_iter_per_epoch: number of ``update`` calls expected per epoch.
        per_iter_vis: if True, log the running average after every update at step
            ``epoch * num_iter_per_epoch + count - 1``. Otherwise log it once per
            epoch, when the last expected update arrives.
    """
    def __init__(self,writer,name,num_iter_per_epoch,per_iter_vis=False):
        self.writer=writer
        self.name=name
        self.num_iter_per_epoch=num_iter_per_epoch
        self.per_iter_vis=per_iter_vis
        # Initialise the accumulators so update() is safe even before the first reset().
        self.reset(epoch=0)
    def reset(self,epoch):
        """Clear the accumulators and set the epoch used to compute logging steps."""
        self.val=0
        self.avg=0
        self.sum=0
        self.count=0
        self.epoch=epoch
    def update(self,val,n=1):
        """Add ``val`` (weighted by ``n``) to the running average and log it."""
        self.val=val
        self.sum+=val*n
        self.count += n
        self.avg = self.sum / self.count if self.count !=0 else 0
        if self.per_iter_vis:
            self.writer.add_scalar(self.name, self.avg, self.epoch * self.num_iter_per_epoch + self.count - 1)
        elif self.count == self.num_iter_per_epoch:
            # Log once the epoch's last update has been accumulated so the value
            # includes the final batch (and still fires with one iteration per epoch).
            self.writer.add_scalar(self.name, self.avg, self.epoch)

In [None]:
import torchvision.transforms.functional as TF
from random import random
class RandomGammaCorrection:
    """Gamma-correct an image with gamma drawn uniformly from [min_gamma, max_gamma).

    Note the constructor argument order: ``max_gamma`` comes first.
    """
    def __init__(self,max_gamma,min_gamma):
        self.max_gamma = max_gamma
        self.min_gamma = min_gamma

    def __call__(self,x):
        span = self.max_gamma - self.min_gamma
        gamma = self.min_gamma + random() * span
        return TF.adjust_gamma(x, gamma=gamma)
    

In [None]:
########################   Central-difference (second order, with 9 parameters and a const theta for 3x3 kernel) 2D Convolution   ##############################
## | a1 a2 a3 |   | w1 w2 w3 |
## | a4 a5 a6 | * | w4 w5 w6 | --> output = \sum_{i=1}^{9}(ai * wi) - theta * a5 * \sum_{i=1}^{9}wi --> Conv2d (k=3) - theta * Conv2d (k=1)
## | a7 a8 a9 |   | w7 w8 w9 |
##
##   --> output =
## | a1 a2 a3 |   |  w1  w2  w3 |
## | a4 a5 a6 | * |  w4  w5  w6 |  -  theta * | a5 | * | w\_sum |     (kernel_size=1x1, padding=0)
## | a7 a8 a9 |   |  w7  w8  w9 |

class Conv2d_cd(nn.Module):
    """Central-difference convolution (CDC).

    ``output = conv_kxk(x) - theta * conv_1x1(x, w_sum)``, where ``w_sum`` collapses
    each k x k kernel to the sum of its weights. ``theta = 0`` gives a plain convolution.

    The 1x1 term uses ``padding=0`` and the same stride, so its spatial size matches
    the main convolution only when that one uses "same" padding (e.g. k=3, padding=1,
    dilation=1, as everywhere in this notebook).

    NOTE(review): the conv bias (if any) is also passed to the 1x1 term, so with
    ``bias=True`` the effective bias is scaled by ``(1 - theta)``. Every use in this
    notebook sets ``bias=False``.
    """
    def __init__(self,in_channels,out_channels,kernel_size=3,stride=1,padding=1,dilation=1,groups=1,bias=False,theta=0.7):
        super().__init__()
        self.conv=nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding,dilation=dilation,groups=groups,bias=bias)
        self.theta=theta
    def forward(self,x):
        out_normal=self.conv(x)
        # Built-in abs() instead of math.fabs: `math` was only imported in a later cell.
        if abs(self.theta) < 1e-8:
            return out_normal
        # (C_out, C_in/groups, k, k) -> (C_out, C_in/groups, 1, 1): sum of each kernel's weights.
        kernel_diff = self.conv.weight.sum(2).sum(2)[:, :, None, None]
        out_diff = F.conv2d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride, padding=0, groups=self.conv.groups)
        return out_normal-self.theta*out_diff
class SpatialAttention(nn.Module):
    """Single-channel spatial attention map: sigmoid(conv([mean_c(x), max_c(x)])).

    Args:
        kernel: size of the 2 -> 1 channel convolution. Padding is ``kernel // 2``, so odd
            kernels preserve the spatial size.
    """
    def __init__(self,kernel=3):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size=kernel, padding=kernel // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self,x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        pooled = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(pooled))
class CDCN(nn.Module):
    """Central Difference Convolutional Network producing a single-channel depth map.

    Structure: a 64-channel stem, then three down-sampling blocks (three conv-BN-ReLU
    stages followed by a stride-2 max-pool each). The three block outputs are resized
    to 32x32 and concatenated (3 x 128 channels), and a three-layer head reduces them
    to one ReLU-activated channel.

    Args:
        basic_conv: convolution class used for every 3x3 layer.
        theta: central-difference weight forwarded to every ``basic_conv``.

    ``forward`` returns ``(map_x, x_concat, x_Block1, x_Block2, x_Block3, x_input)``.
    ``map_x`` is the head output with its channel dim squeezed, i.e. (N, 32, 32).
    """

    def __init__(self,basic_conv=Conv2d_cd,theta=0.7):
        super().__init__()

        def conv_bn_relu(c_in, c_out):
            # One 3x3 / stride-1 / padding-1 stage, returned as a list so stages can be
            # spliced into a flat nn.Sequential (same layer indices as hand-written lists).
            return [
                basic_conv(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        def down_block(widths):
            # conv-BN-ReLU stages through consecutive channel widths, then a stride-2 max-pool.
            layers = []
            for c_in, c_out in zip(widths[:-1], widths[1:]):
                layers.extend(conv_bn_relu(c_in, c_out))
            layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            return nn.Sequential(*layers)

        self.conv1 = nn.Sequential(*conv_bn_relu(3, 64))
        self.Block1 = down_block([64, 128, 196, 128])
        self.Block2 = down_block([128, 128, 196, 128])
        self.Block3 = down_block([128, 128, 196, 128])
        self.lastconv1 = nn.Sequential(*conv_bn_relu(128 * 3, 128))
        self.lastconv2 = nn.Sequential(*conv_bn_relu(128, 64))
        self.lastconv3 = nn.Sequential(
            basic_conv(64, 1, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.ReLU(),
        )
        # An nn.Upsample to a fixed size: resizes any input to 32x32 (bilinear).
        self.downsample32x32 = nn.Upsample(size=(32, 32), mode='bilinear')

    def forward(self,x):
        x_input = x
        x_Block1 = self.Block1(self.conv1(x))
        x_Block2 = self.Block2(x_Block1)
        x_Block3 = self.Block3(x_Block2)
        # Multi-scale fusion: bring each block's output to 32x32 and stack on channels.
        x_concat = torch.cat(
            [self.downsample32x32(feat) for feat in (x_Block1, x_Block2, x_Block3)],
            dim=1,
        )
        depth = self.lastconv3(self.lastconv2(self.lastconv1(x_concat)))
        map_x = depth.squeeze(1)
        return map_x, x_concat, x_Block1, x_Block2, x_Block3, x_input
        

In [None]:
class CDCNpp(nn.Module):
    """CDCN++ depth-map network: a CDCN-style backbone with per-block spatial attention.

    Each down-sampling block feeds the next block with its raw output. A spatial-
    attention-weighted copy of that output is resized to 32x32. The three weighted
    copies are concatenated (3 x 128 channels), and a two-conv head maps them to one
    ReLU-activated channel.

    Args:
        basic_conv: convolution class used for every 3x3 layer.
        theta: central-difference weight forwarded to every ``basic_conv``.
            NOTE(review): with the default 0.0 each ``Conv2d_cd`` returns its plain
            convolution output (``CDCN`` defaults to 0.7). Confirm this is intended.

    ``forward`` returns ``(map_x, x_concat, attention1, attention2, attention3, x_input)``.
    ``map_x`` is the head output with its channel dim squeezed.
    """

    def __init__(self, basic_conv=Conv2d_cd, theta=0.0):
        super().__init__()

        def conv_bn_relu(c_in, c_out):
            # One 3x3 / stride-1 / padding-1 stage, returned as a list so stages can be
            # spliced into a flat nn.Sequential (same layer indices as hand-written lists).
            return [
                basic_conv(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        def down_block(widths):
            # conv-BN-ReLU stages through consecutive channel widths, then a stride-2 max-pool.
            layers = []
            for c_in, c_out in zip(widths[:-1], widths[1:]):
                layers.extend(conv_bn_relu(c_in, c_out))
            layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            return nn.Sequential(*layers)

        w12, w14, w16 = int(128 * 1.2), int(128 * 1.4), int(128 * 1.6)
        self.conv1 = nn.Sequential(*conv_bn_relu(3, 64))
        self.Block1 = down_block([64, 128, w16, 128])
        self.Block2 = down_block([128, w12, 128, w14, 128])
        self.Block3 = down_block([128, 128, w12, 128])
        self.lastconv1 = nn.Sequential(
            *conv_bn_relu(128 * 3, 128),
            basic_conv(128, 1, kernel_size=3, stride=1, padding=1, bias=False, theta=theta),
            nn.ReLU(),
        )
        self.sa1 = SpatialAttention(kernel=7)
        self.sa2 = SpatialAttention(kernel=5)
        self.sa3 = SpatialAttention(kernel=3)
        # An nn.Upsample to a fixed size: resizes any input to 32x32 (bilinear).
        self.downsample32x32 = nn.Upsample(size=(32, 32), mode='bilinear')

    def forward(self, x):
        x_input = x
        feat1 = self.Block1(self.conv1(x))
        attention1 = self.sa1(feat1)
        # The next block consumes the un-attended features.
        feat2 = self.Block2(feat1)
        attention2 = self.sa2(feat2)
        feat3 = self.Block3(feat2)
        attention3 = self.sa3(feat3)
        x_concat = torch.cat(
            [
                self.downsample32x32(attention1 * feat1),
                self.downsample32x32(attention2 * feat2),
                self.downsample32x32(attention3 * feat3),
            ],
            dim=1,
        )
        map_x = self.lastconv1(x_concat).squeeze(1)
        return map_x, x_concat, attention1, attention2, attention3, x_input

In [None]:
def contrast_depth_conv(input,device=None):
    """Apply the 8 fixed "neighbour minus centre" 3x3 filters to a batch of maps.

    Args:
        input: tensor of shape (N, H, W).
        device: device for the filter bank. It defaults to ``input.device``.

    Returns:
        Tensor of shape (N, 8, H-2, W-2). Channel k holds neighbour_k - centre. No
        padding is used, so e.g. a 32x32 map yields 8x30x30.
    """
    if device is None:
        device = input.device
    kernel_filter_list =[
                        [[1,0,0],[0,-1,0],[0,0,0]], [[0,1,0],[0,-1,0],[0,0,0]], [[0,0,1],[0,-1,0],[0,0,0]],
                        [[0,0,0],[1,-1,0],[0,0,0]], [[0,0,0],[0,-1,1],[0,0,0]],
                        [[0,0,0],[0,-1,0],[1,0,0]], [[0,0,0],[0,-1,0],[0,1,0]], [[0,0,0],[0,-1,0],[0,0,1]]
                        ]
    # Depthwise weight (out_channels=8, in_channels/groups=1, 3, 3) in the input's dtype,
    # so float64/float16 inputs work as well as float32.
    kernel_filter = torch.tensor(kernel_filter_list, dtype=input.dtype, device=device).unsqueeze(1)
    # One copy of the map per filter: (N, H, W) -> (N, 8, H, W).
    stacked = input.unsqueeze(1).expand(-1, 8, -1, -1)
    return F.conv2d(stacked, weight=kernel_filter, groups=8)

class ContrastDepthLoss(nn.Module):
    """Mean squared error between the contrast-depth responses of prediction and target.

    Both maps go through ``contrast_depth_conv`` (8 neighbour-minus-centre filters)
    and are compared with MSE. The loss therefore penalises mismatched local depth
    differences rather than absolute depth values.

    Args:
        device: device on which the filter bank is built.
    """
    def __init__(self,device):
        super().__init__()
        self.device=device
        # Stateless criterion (no parameters): build it once instead of on every forward.
        self.criterion_MSE=nn.MSELoss()
    def forward(self,out,label):
        """Return MSE(contrast(out), contrast(label))."""
        contrast_out=contrast_depth_conv(out,device=self.device)
        contrast_label=contrast_depth_conv(label,device=self.device)
        return self.criterion_MSE(contrast_out,contrast_label)
class DepthLoss(nn.Module):
    """Total depth-map loss: pixel-wise MSE plus the contrast-depth loss.

    Args:
        device: forwarded to ``ContrastDepthLoss`` for its filter bank.
    """
    def __init__(self,device):
        super().__init__()
        self.criterion_absolute_loss = nn.MSELoss()
        self.criterion_contrastive_loss = ContrastDepthLoss(device=device)

    def forward(self,predicted_depth_map,gt_depth_map):
        pair = (predicted_depth_map, gt_depth_map)
        return self.criterion_absolute_loss(*pair) + self.criterion_contrastive_loss(*pair)
    
    
    

In [None]:
import math
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from random import randint
# Reproducibility: seed the Python, NumPy and torch RNGs before the weights are initialised.
# (`random` in this namespace is the function imported earlier, so the module is aliased.)
import random as _py_random

SEED = 42
_py_random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=CDCNpp().to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=0.0003)
# Divides the LR by 10 every 30 epochs; it only takes effect if .step() is called once per epoch.
lr_scheduler=StepLR(optimizer=optimizer,step_size=30,gamma=0.1)
criterion=DepthLoss(device=device)
writer=SummaryWriter()
# Trace the graph in eval mode so the dummy forward pass does not update the
# BatchNorm running statistics with random data; then restore training mode.
dump_input=torch.randn((1,3,256,256)).to(device)
model.eval()
writer.add_graph(model,dump_input)
model.train()

In [None]:
# Sanity check of the output shape (the depth head is resized to 32x32, so expect (1, 32, 32)).
# Runs in eval mode without autograd so it neither updates BatchNorm running stats
# nor builds a graph; the previous train/eval mode is restored afterwards.
dump_input=torch.randn((1,3,256,256)).to(device)
was_training = model.training
model.eval()
with torch.no_grad():
    output=model(dump_input)
model.train(was_training)
output[0].shape

# DataLoader


In [None]:
# Normalisation statistics shared by every pipeline (the standard ImageNet mean/std).
# The evaluation pipeline must normalise exactly like the training pipelines;
# otherwise the model sees a shifted input distribution at test time.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training view 1: random gamma, small rotation, colour jitter, always flipped horizontally (p=1).
transform_flipped = transforms.Compose([
    RandomGammaCorrection(max_gamma=1.5,min_gamma=0.67),
    transforms.Resize([256,256]),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomHorizontalFlip(p=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Training view 2: stronger colour jitter, blur, occasional grayscale; also always flipped (p=1).
spoof_transforms = transforms.Compose([
    RandomGammaCorrection(max_gamma=1.5,min_gamma=0.67),
    transforms.Resize([256,256]),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
    transforms.RandomHorizontalFlip(p=1),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Deterministic evaluation pipeline: resize + normalise with the training statistics.
normal_transforms=transforms.Compose([
    transforms.Resize([256,256]),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])



In [None]:
# Two augmented views of the same training folder, concatenated into one dataset.
# NOTE(review): despite the names, both pipelines flip every image horizontally (p=1).
TRAIN_ROOT = "/kaggle/input/hehedataset/Combine/train"
train_orig = FASDataset(root_dir=TRAIN_ROOT, depth_map_size=[32, 32],
                        transform=transform_flipped, smoothing=True)
train_flip = FASDataset(root_dir=TRAIN_ROOT, depth_map_size=[32, 32],
                        transform=spoof_transforms, smoothing=True)
train_data_combined = ConcatDataset([train_orig, train_flip])
train_loader = torch.utils.data.DataLoader(
    dataset=train_data_combined, batch_size=8, shuffle=True, num_workers=2,
)


In [None]:
# Validation must be deterministic so epoch-to-epoch accuracy is comparable: no random
# rotation / colour jitter / blur / gamma / grayscale. Two fixed views are kept (as-is
# and horizontally flipped), so the validation set size is unchanged. Normalisation
# uses the same mean/std as the training pipelines.
_val_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
val_plain_transforms = transforms.Compose([
    transforms.Resize([256, 256]),
    transforms.ToTensor(),
    _val_normalize,
])
val_flipped_transforms = transforms.Compose([
    transforms.Resize([256, 256]),
    transforms.RandomHorizontalFlip(p=1),
    transforms.ToTensor(),
    _val_normalize,
])
val_orig=FASDataset(root_dir="/kaggle/input/hehedataset/Combine/val",
                    depth_map_size=[32,32],
                    transform=val_plain_transforms,
                    smoothing=True)
val_flip = FASDataset(root_dir="/kaggle/input/hehedataset/Combine/val",
                      depth_map_size=[32,32],
                      transform=val_flipped_transforms,
                      smoothing=True)
val_data_combined = ConcatDataset([val_orig, val_flip])
# Evaluation order does not affect the metrics, so the eval loaders are not shuffled.
val_loader=torch.utils.data.DataLoader(
                dataset=val_data_combined,
                batch_size=8,
                shuffle=False,
                num_workers=2)
testset = FASDataset(root_dir="/kaggle/input/hehedataset/Combine/test",
                depth_map_size=[32, 32],
                transform=normal_transforms,
                smoothing=False)
test_loader=torch.utils.data.DataLoader(
                dataset=testset,
                batch_size=8,
                shuffle=False,
                num_workers=2)


In [None]:
print(*map(len, (train_data_combined, val_data_combined, testset)))

# Training

In [None]:
from tqdm import tqdm
# Train meters log their running average every iteration; val meters log once per epoch.
_train_iters = len(train_loader)
_val_iters = len(val_loader)
train_loss_metric = AvgMeter(writer=writer, name="Loss/train", num_iter_per_epoch=_train_iters, per_iter_vis=True)
train_acc_metric = AvgMeter(writer=writer, name='Accuracy/train', num_iter_per_epoch=_train_iters, per_iter_vis=True)
val_loss_metric = AvgMeter(writer=writer, name='Loss/val', num_iter_per_epoch=_val_iters)
val_acc_metric = AvgMeter(writer=writer, name='Accuracy/val', num_iter_per_epoch=_val_iters)


def train_one_epoch(model,epoch, train_loader, device, optimizer, criterion, scheduler_lr=None):
    """Run one training epoch over ``train_loader``.

    Loss and accuracy are accumulated in the module-level ``train_loss_metric`` /
    ``train_acc_metric`` meters. Accuracy compares thresholded predicted maps with
    thresholded ground-truth maps (both via ``predict``).

    Args:
        scheduler_lr: optional LR scheduler, stepped once at the end of the epoch.
    """
    model.train()
    train_loss_metric.reset(epoch)
    train_acc_metric.reset(epoch)
    pbar = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch}", position=0, leave=True)
    for img, depth_map, _label in pbar:
        img, depth_map = img.to(device), depth_map.to(device)
        optimizer.zero_grad()
        net_depth_map, _, _, _, _, _ = model(img)
        loss = criterion(net_depth_map, depth_map)
        loss.backward()
        optimizer.step()

        preds, _ = predict(net_depth_map)
        targets, _ = predict(depth_map)
        accuracy = calc_accuracy(preds, targets)

        train_loss_metric.update(loss.item())
        train_acc_metric.update(accuracy)
        pbar.set_postfix({'Loss': train_loss_metric.avg, 'Accuracy': train_acc_metric.avg})

    if scheduler_lr is not None:
        scheduler_lr.step()


        

def validate_one_epoch(model,epoch, val_loader, device, criterion, scheduler_lr=None):
    """Evaluate on ``val_loader`` without gradients and return the mean batch accuracy.

    Metrics go to the module-level ``val_loss_metric`` / ``val_acc_metric`` meters, and one
    randomly chosen batch is rendered to TensorBoard via ``add_visualization``.
    ``scheduler_lr`` is accepted for signature symmetry with ``train_one_epoch`` and is unused.
    Returns 0 for an empty loader.
    """
    model.eval()
    val_loss_metric.reset(epoch)
    val_acc_metric.reset(epoch)
    num_batches = len(val_loader)
    # Index of the batch to visualise; randint(0, -1) would raise on an empty loader.
    vis_batch_idx = randint(0, num_batches - 1) if num_batches > 0 else None
    with torch.no_grad():
        pbar = tqdm(enumerate(val_loader), total=num_batches, desc=f"Validation Epoch {epoch}", position=0, leave=True)
        for i, (img, depth_map, _label) in pbar:
            img, depth_map = img.to(device), depth_map.to(device)
            net_depth_map, _, _, _, _, _ = model(img)
            loss = criterion(net_depth_map, depth_map)
            preds, score = predict(net_depth_map)
            targets, _ = predict(depth_map)
            accuracy = calc_accuracy(preds, targets)

            val_loss_metric.update(loss.item())
            val_acc_metric.update(accuracy)
            pbar.set_postfix({'Loss': val_loss_metric.avg, 'Accuracy': val_acc_metric.avg})

            if i == vis_batch_idx:
                add_visualization(epoch, img, preds, targets, score, writer)
    return val_acc_metric.avg


In [None]:
def evaluating(model, test_loader, device, criterion):
    """Evaluate ``model`` on ``test_loader`` and return ``(accuracy, loss)``.

    Both values are per-sample averages. Each batch's loss is weighted by its size,
    assuming ``criterion`` returns a batch mean (as ``DepthLoss`` does), so a smaller
    final batch is not over-weighted. Returns ``(nan, nan)`` for an empty loader
    instead of dividing by zero.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for img, depth_map, _label in test_loader:
            img, depth_map = img.to(device), depth_map.to(device)
            net_depth_map, _, _, _, _, _ = model(img)
            batch_size = img.size(0)
            loss_sum += criterion(net_depth_map, depth_map).item() * batch_size

            preds, _ = predict(net_depth_map)
            targets, _ = predict(depth_map)
            correct += preds.eq(targets).sum().item()
            total += batch_size

    if total == 0:
        return float('nan'), float('nan')
    return correct / total, loss_sum / total

In [None]:
NUM_EPOCHS = 35
best_acc = 0.0
for epoch in range(NUM_EPOCHS):
    train_one_epoch(model, epoch, train_loader, device, optimizer, criterion)
    # StepLR counts epochs: advance it once per epoch, after that epoch's optimizer updates.
    lr_scheduler.step()
    val_acc = validate_one_epoch(model, epoch, val_loader, device, criterion)

    # Test metrics are reported for monitoring only; they must not drive model selection.
    test_acc, test_loss = evaluating(model, test_loader, device, criterion)
    print(f"Epoch {epoch} - Val Accuracy: {val_acc:.4f} - Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

    # Select the checkpoint on validation accuracy; selecting on the test set leaks
    # test information into the reported result.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_acc': best_acc,
            'test_acc': test_acc,
        }, 'best_model.pth')
        print(f"New best model saved with validation accuracy: {best_acc:.4f}")

    torch.save(model,'model.pth')



# Evaluation

In [None]:

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is 0."""
    return numerator / denominator if denominator else float('nan')


# NOTE(review): this evaluates the in-memory `model` (weights after the last epoch),
# not the checkpoint saved to best_model.pth; load that first if the best model is wanted.
model.eval()
with torch.no_grad():
    tp = tn = fp = fn = 0

    for img, depth_map, _label in test_loader:
        img, depth_map = img.to(device), depth_map.to(device)
        net_depth_map, _, _, _, _, _ = model(img)
        preds, score = predict(net_depth_map)
        targets, target_score = predict(depth_map)

        # Positive class 1 = "real" (FASDataset labels the `real` folder 1):
        # fp = spoof accepted as real, fn = real rejected as spoof.
        tp += (preds.eq(1) & targets.eq(1)).sum().item()
        tn += (preds.eq(0) & targets.eq(0)).sum().item()
        fp += (preds.eq(1) & targets.eq(0)).sum().item()
        fn += (preds.eq(0) & targets.eq(1)).sum().item()

    # preds/targets are 0/1 only, so every sample lands in exactly one cell.
    correct = tp + tn
    total = tp + tn + fp + fn  # Total number of samples

    # NaN (instead of ZeroDivisionError) when a class is absent from the test set.
    accuracy = _safe_ratio(correct, total)
    far = _safe_ratio(fp, fp + tn)      # False Acceptance Rate
    frr = _safe_ratio(fn, fn + tp)      # False Rejection Rate
    recall = _safe_ratio(tp, tp + fn)
    hter = (far + frr) / 2              # Half Total Error Rate

    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    print(f"Recall: {recall * 100:.2f}%")
    print(f"FAR: {far * 100:.2f}%")
    print(f"FRR: {frr * 100:.2f}%")
    print(f"HTER: {hter * 100:.2f}%")
