In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import time
import os
import sys
from torch.utils.data import DataLoader, Dataset
import math
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import cv2
from transformers import pipeline
from PIL import Image
from torchvision import models
import torchvision.transforms.functional as TF
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import time

In [None]:
class CustomDataset(Dataset):
    """Binary liveness dataset stored as ``path/<class_folder>/<image file>``.

    Each immediate sub-directory of ``path`` is one class folder: folders whose name
    contains ``'client'`` are labelled 1, every other folder is labelled 0. File paths are
    indexed at construction time; pixels are only read in ``__getitem__``.

    Args:
        path: Root directory of one split (e.g. ``.../train``).
        device: Stored as ``self.device``; not used by this class (samples are returned
            on the CPU).
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.
        img_size: Stored as ``self.img_size``; not used here - resizing is left to
            ``transform``.
    """

    # Accepted image extensions, matched case-insensitively (".JPG" counts as ".jpg").
    IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, path, device, transform=None, img_size=(128, 128)):
        super(CustomDataset, self).__init__()
        self.device = device
        self.images = []
        self.labels = []
        self.img_size = img_size
        self.transform = transform
        self.path = path
        self.num_channels = 1  # NOTE(review): unused; __getitem__ converts images to 3-channel RGB

        # sorted() makes the sample order independent of the filesystem's listing order,
        # so dataset indices are reproducible across machines.
        for folder in sorted(os.listdir(self.path)):
            folder_path = os.path.join(self.path, folder)
            # Skip stray files (README, .DS_Store, ...) next to the class folders; listing
            # them as directories would raise NotADirectoryError.
            if not os.path.isdir(folder_path):
                continue
            label = 1 if 'client' in folder else 0
            for image in sorted(os.listdir(folder_path)):
                if image.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.images.append(os.path.join(folder_path, image))
                    self.labels.append(label)

    def __len__(self):
        """Number of indexed image files."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, return ``(image, label)``."""
        # The context manager closes the underlying file; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(self.images[idx]) as raw:
            img = raw.convert("RGB")

        if self.transform:
            img = self.transform(img)

        return img, self.labels[idx]

In [None]:
class SimulateDistanceTransform:
    """Randomly shrink an image and edge-pad it back to its original size.

    The effect mimics the subject being farther from the camera: the content gets
    smaller while the output keeps the input's dimensions. Strongly shrunk images are
    additionally blurred slightly.

    Assumes ``img`` is a PIL image (``img.size`` is read as ``(width, height)``).

    Args:
        min_scale: Lower bound of the uniformly sampled scale factor.
        max_scale: Upper bound of the uniformly sampled scale factor.
        blur_threshold: A 3x3 Gaussian blur is applied when the sampled scale is below
            this value. The default 0.65 reproduces the previous hard-coded behaviour.
    """

    def __init__(self, min_scale=0.5, max_scale=1.0, blur_threshold=0.65):
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.blur_threshold = blur_threshold

    def __call__(self, img):
        # Randomly choose a scale factor
        scale_factor = random.uniform(self.min_scale, self.max_scale)

        original_width, original_height = img.size

        # Clamp to at least one pixel so tiny inputs with small scale factors do not
        # ask Resize for a zero-sized image.
        new_width = max(1, int(original_width * scale_factor))
        new_height = max(1, int(original_height * scale_factor))

        img = transforms.Resize((new_height, new_width))(img)

        # Split the missing pixels between both sides (the extra odd pixel goes
        # right/bottom) so the padded image has exactly the original size.
        pad_w = original_width - new_width
        pad_h = original_height - new_height
        padding = (pad_w // 2, pad_h // 2, (pad_w + 1) // 2, (pad_h + 1) // 2)
        img = transforms.Pad(padding, padding_mode='edge')(img)

        # Small (far-away) faces additionally lose some sharpness.
        if scale_factor < self.blur_threshold:
            img = transforms.GaussianBlur(kernel_size=3)(img)

        return img

In [None]:
# NOTE(review): 252 = 18 * 14, presumably chosen as a multiple of the Depth-Anything ViT
# patch size - confirm before changing.
img_size = (252, 252)
batch_size = 80

# Training-time augmentation: random distance simulation, flips, rotation, colour jitter.
transf = transforms.Compose([
    SimulateDistanceTransform(min_scale=0.4, max_scale=1.0),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
    transforms.Resize(img_size),
    transforms.ToTensor()
])

# Evaluation preprocessing: deterministic, so validation losses are comparable across
# epochs. (Validation previously reused the random training augmentation above.)
val_transf = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor()
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DATA_ROOT = "/kaggle/input/increased-liveliness-detection"

train_dataset = CustomDataset(os.path.join(DATA_ROOT, "train"), device, transf, img_size=img_size)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = CustomDataset(os.path.join(DATA_ROOT, "val"), device, val_transf, img_size=img_size)
# Order does not matter for evaluation; a fixed order keeps runs reproducible.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Test split (currently unused). If enabled, evaluate it with the deterministic val_transf:
# test_dataset = CustomDataset(os.path.join(DATA_ROOT, "test"), device, val_transf, img_size=img_size)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
class CDC(nn.Module):
    '''
    Central difference convolution (CDC).

    The output blends a vanilla convolution with a central-difference term:

        out = conv(x) - theta * conv1x1(x, W summed over its spatial axes)

    At each output location the second term equals x_centre * sum_k(W_k), so
    out = sum_k W_k * x_k - theta * x_centre * sum_k W_k. theta=0 gives a plain
    convolution and theta=1 a convolution over the differences (x_k - x_centre).

    Assumes an odd int kernel_size (an even kernel has no centre pixel).
    '''
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.7):

        super(CDC, self).__init__()
        self.stride = stride
        self.groups = groups
        self.dilation = dilation
        self.theta = theta
        self.padding = padding
        # NOTE(review): with bias=True the wrapped conv has its own bias AND this extra
        # zero-initialised bias is applied to the difference term, so the effective bias is
        # conv.bias - theta * self.bias (redundant but harmless). Kept so existing
        # state_dicts still load.
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

        # The difference term is a 1x1 convolution without padding, so the main conv must
        # preserve the spatial size for the two outputs to line up. The caller's padding is
        # honoured for the default 3x3 kernel. Other sizes get "same" padding:
        # 0 for 1x1 (as before) and 2 for 5x5, which previously crashed on a shape mismatch.
        conv_padding = padding if kernel_size == 3 else dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                              padding=conv_padding, dilation=dilation, groups=groups, bias=bias)

    def forward(self, x):
        out_normal = self.conv(x)
        # theta == 0 is a plain convolution; skip the extra 1x1 pass.
        if self.theta == 0:
            return out_normal
        # (out, in/groups, k, k) -> (out, in/groups, 1, 1): each filter's spatial sum.
        kernel_diff = self.conv.weight.sum(dim=(2, 3), keepdim=True)
        out_diff = F.conv2d(input=x, weight=kernel_diff, bias=self.bias, stride=self.stride,
                            padding=0, groups=self.groups)
        return out_normal - self.theta * out_diff

In [None]:
class FineTuneDepthAnything(nn.Module):
    """Depth-Anything-V2-Small with only selected decoder parameters left trainable.

    A parameter stays trainable when its name contains any of: 'head',
    'neck.fusion_stage.layers.2.residual_layer' or 'neck.fusion_stage.layers.3'.
    Every other parameter is frozen (requires_grad=False).
    """
    def __init__(self, device):
        super(FineTuneDepthAnything, self).__init__()
        unfrozen_markers = (
            'head',
            'neck.fusion_stage.layers.2.residual_layer',
            'neck.fusion_stage.layers.3',
        )
        pretrained = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-Small-hf")
        for param_name, weight in pretrained.named_parameters():
            weight.requires_grad = any(marker in param_name for marker in unfrozen_markers)

        # Module.to() moves in place and returns the same module.
        self.depth_anything = pretrained.to(device)

    def forward(self, inp):
        """Run the depth model and add a channel axis to its predicted depth.

        NOTE(review): assumes predicted_depth is (N, H, W), giving (N, 1, H, W) - confirm.
        """
        depth = self.depth_anything(inp).predicted_depth
        return depth.unsqueeze(1)

In [None]:
def contrast_depth_conv(input_tensor, device=None):
    """
    Compute contrast depth: each of the 8 neighbours minus the centre pixel.

    Parameters:
    - input_tensor: depth maps of shape (N, 1, H, W) or (N, H, W).
    - device: device to build the difference kernels on. Defaults to
      input_tensor.device; still accepted positionally for backward compatibility.

    Returns:
    - A tensor of shape (N, 8, H-2, W-2) (a 'valid' convolution, no padding). Channel k
      holds neighbour_k - centre, with neighbours ordered top-left, top, top-right,
      left, right, bottom-left, bottom, bottom-right. The dtype matches input_tensor.
    """
    # Drop a singleton channel axis if present; it is re-added below.
    input_tensor = input_tensor.squeeze(1)

    # 8 3x3 kernels, each selecting one neighbour and subtracting the centre.
    kernel_filter_list = [
        [[1, 0, 0], [0, -1, 0], [0, 0, 0]], [[0, 1, 0], [0, -1, 0], [0, 0, 0]], [[0, 0, 1], [0, -1, 0], [0, 0, 0]],
        [[0, 0, 0], [1, -1, 0], [0, 0, 0]], [[0, 0, 0], [0, -1, 1], [0, 0, 0]],
        [[0, 0, 0], [0, -1, 0], [1, 0, 0]], [[0, 0, 0], [0, -1, 0], [0, 1, 0]], [[0, 0, 0], [0, -1, 0], [0, 0, 1]]
    ]

    if device is None:
        device = input_tensor.device

    # Match the input dtype: a hard-coded float32 kernel made float64/half inputs fail
    # with a dtype mismatch. Shape (8, 1, 3, 3) = (out_channels, in_channels, kH, kW).
    kernel_filter = torch.tensor(kernel_filter_list, dtype=input_tensor.dtype, device=device).unsqueeze(1)

    # One input channel convolved with 8 filters. This equals expanding the input to 8
    # identical channels and using groups=8, without materialising the copies.
    contrast_depth = F.conv2d(input_tensor.unsqueeze(1), weight=kernel_filter)

    return contrast_depth

In [None]:
def customLoss(criterion, mse_criterion, predictions, labels, device,
               smooth_weight=0.3, contrast_weight=1.0):
    """Weighted sum of a direct depth loss and a contrast-depth loss.

    Args:
        criterion: Loss applied directly to (predictions, labels).
        mse_criterion: Loss applied to the contrast-depth maps of predictions and labels
            (see contrast_depth_conv).
        predictions: Predicted depth maps.
        labels: Target depth maps, same shape as predictions.
        device: Forwarded to contrast_depth_conv.
        smooth_weight: Weight of the direct term (default 0.3, the previous hard-coded value).
        contrast_weight: Weight of the contrast term (default 1.0, as before).

    Returns:
        smooth_weight * criterion(...) + contrast_weight * mse_criterion(...)
    """
    smooth_loss = criterion(predictions, labels)
    contrast_pred = contrast_depth_conv(predictions, device)
    contrast_label = contrast_depth_conv(labels, device)
    contrast_loss = mse_criterion(contrast_pred, contrast_label)

    return smooth_weight * smooth_loss + contrast_weight * contrast_loss

In [None]:
# Student model: fine-tuned Depth-Anything-V2-Small, resumed from an earlier checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
CHECKPOINT_PATH = '/kaggle/input/finetune_depth_anything/pytorch/54_epochs_trained/1/fine_tuning_depth_anything.pth'
model = FineTuneDepthAnything(device).to(device)
# The file is fed straight into load_state_dict, i.e. it is a plain state_dict of tensors.
# The restricted weights_only unpickler is therefore enough, and it avoids executing
# arbitrary pickled code from the checkpoint.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True))
# device_ids=None -> use every visible GPU. The previous hard-coded [0, 1] failed on
# machines with fewer than two GPUs.
model = torch.nn.DataParallel(model, device_ids=None).to(device)

In [None]:
# Frozen teacher: Depth-Anything-V2-Large provides the target depth maps for the student.
TEACHER_CHECKPOINT = "depth-anything/Depth-Anything-V2-Large-hf"  # previously noted alternative: LiheYoung/depth-anything-base-hf
large_depth_map_model = AutoModelForDepthEstimation.from_pretrained(TEACHER_CHECKPOINT).to(device)
# The teacher is never optimised. Freezing its parameters ensures no autograd state is
# built even if it is called outside torch.no_grad().
large_depth_map_model.requires_grad_(False)
# Same device selection as the student: all visible GPUs instead of a hard-coded [0, 1].
large_depth_map_model = torch.nn.DataParallel(large_depth_map_model, device_ids=None).to(device)
large_depth_map_model.eval();  # trailing ';' suppresses the very long module repr

In [None]:
# Losses combined in customLoss: SmoothL1 on the depth maps, MSE on their contrast-depth maps.
criterion = nn.SmoothL1Loss()
mse_criterion = nn.MSELoss()

LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.001  # NOTE(review): Adam's weight_decay is coupled L2; optim.AdamW is the decoupled variant
num_epochs = 51

# Only the parameters FineTuneDepthAnything left trainable are optimised. Frozen ones never
# receive gradients, so handing them to Adam only adds bookkeeping.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Lower the LR when the monitored loss has not improved for `patience` consecutive steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

In [None]:
def plot_depth_maps(outputs, labels, binary):
    """Show predicted vs. target depth for the first sample of each class in a batch.

    Panels 0-1 show the first sample whose binary label is 0 (prediction, target). Panels
    2-3 show the same for label 1. A class missing from the batch has its two panels hidden.
    """
    fig, axes = plt.subplots(1, 4, figsize=(12, 5))

    for cls in (0, 1):
        pred_ax, truth_ax = axes[2 * cls], axes[2 * cls + 1]
        hits = (binary == cls).nonzero(as_tuple=True)[0]

        if len(hits) == 0:
            pred_ax.set_visible(False)
            truth_ax.set_visible(False)
            continue

        first = hits[0].item()
        pred_ax.imshow(outputs[first].cpu().detach().squeeze(0).squeeze(0).numpy())
        pred_ax.set_title(f'Predictions for label {cls}')
        truth_ax.imshow(labels[first].cpu().detach().squeeze(0).numpy())
        truth_ax.set_title(f'Ground truth for label {cls}')

    plt.show()

In [None]:
def get_labels(inputs, label):
    """Build depth targets: keep inputs[i] where label[i] == 1, zero it otherwise.

    label is viewed as (N, 1, 1, 1) so it broadcasts over the remaining axes of inputs.
    The result has the dtype and device of inputs.
    """
    keep = label.view(-1, 1, 1, 1).eq(1)
    zero = inputs.new_zeros(())  # 0-d tensor, broadcast against inputs by torch.where
    return torch.where(keep, inputs, zero)

In [None]:
def compute_batch_loss(inputs, binary_labels):
    """Run one batch through the student and the frozen teacher and score it.

    Steps:
    - Move the batch to device and predict depth with model.
    - Build the target by keeping the teacher's depth map for samples labelled 1 and
      zeroing it otherwise (get_labels).
    - Evaluate customLoss.

    Returns:
        (loss, outputs, labels, binary_labels) - the last three on device, for plotting.
    """
    inputs = inputs.to(device)
    binary_labels = binary_labels.to(device)
    outputs = model(inputs)
    # The teacher only supplies targets; never track gradients through it.
    with torch.no_grad():
        depth_maps = large_depth_map_model(inputs).predicted_depth.unsqueeze(1)
    labels = get_labels(depth_maps, binary_labels)
    loss = customLoss(criterion, mse_criterion, outputs, labels, device)
    return loss, outputs, labels, binary_labels


VAL_EVERY = 5  # run validation every VAL_EVERY epochs

train_loss = []
val_loss = []
best_epoch_loss = float('inf')
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for inputs, binary_labels in train_loader:
        optimizer.zero_grad()
        loss, outputs, labels, binary_labels = compute_batch_loss(inputs, binary_labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Report the mean per-batch loss so train and val numbers are comparable even though
    # the loaders have different numbers of batches.
    running_loss /= max(len(train_loader), 1)
    train_loss.append(running_loss)
    print(f"Epoch {epoch+1}, Training Loss (mean per batch): {running_loss} and lr is {optimizer.param_groups[0]['lr']}")
    plot_depth_maps(outputs, labels, binary_labels)

    if (epoch + 1) % VAL_EVERY == 0:
        model.eval()
        running_loss_test = 0.0
        with torch.no_grad():
            for inputs_test, binary_labels_test in val_loader:
                loss_test, outputs_test, labels_test, binary_labels_test = compute_batch_loss(
                    inputs_test, binary_labels_test)
                running_loss_test += loss_test.item()

        running_loss_test /= max(len(val_loader), 1)
        val_loss.append(running_loss_test)
        print(f"Validation Loss (mean per batch): {running_loss_test}")
        plot_depth_maps(outputs_test, labels_test, binary_labels_test)

        if running_loss_test < best_epoch_loss:
            best_epoch_loss = running_loss_test
            torch.save(model.module.state_dict(), "best_fine_tuning_depth_anything.pth")

    torch.save(model.module.state_dict(), "fine_tuning_depth_anything.pth")
    # NOTE(review): the plateau scheduler follows the training loss because validation
    # only runs every VAL_EVERY epochs.
    scheduler.step(running_loss)

# Distributed (DDP) variant of the training loop. It requires an initialised
# torch.distributed process group (e.g. a torchrun launch with the model wrapped in DDP).
# In a single-process DataParallel run the original version crashed on dist.all_reduce and
# on the undefined world_size / local_rank, so the loop is skipped unless DDP is active.
if dist.is_available() and dist.is_initialized():
    world_size = dist.get_world_size()
    local_rank = dist.get_rank()  # global rank: rank 0 logs and writes checkpoints

    def all_reduce_mean(value):
        """Average a Python float across all ranks.

        dist.all_reduce reduces its tensor argument in place, so the tensor must be kept and
        read back. Reducing a temporary (as before) silently discarded the result.
        """
        total = torch.tensor(value, dtype=torch.float64, device=device)
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        return total.item() / world_size

    train_loss = []
    val_loss = []
    best_epoch_loss = float('inf')
    for epoch in range(num_epochs):
        # A DistributedSampler needs the epoch to produce a different shuffle each epoch.
        if isinstance(train_loader.sampler, DistributedSampler):
            train_loader.sampler.set_epoch(epoch)
        model.train()
        running_loss = 0.0

        for inputs, binary_labels in train_loader:
            inputs = inputs.to(device)
            binary_labels = binary_labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # The teacher only supplies targets; never track gradients through it.
            with torch.no_grad():
                depth_maps = large_depth_map_model(inputs).predicted_depth.unsqueeze(1)

            labels = get_labels(depth_maps, binary_labels)
            loss = customLoss(criterion, mse_criterion, outputs, labels, device)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Mean per-batch loss on this rank, then averaged over all ranks.
        running_loss = all_reduce_mean(running_loss / max(len(train_loader), 1))

        if local_rank == 0:
            train_loss.append(running_loss)
            print(f"Epoch {epoch+1}, Training Loss (mean per batch): {running_loss} and lr is {optimizer.param_groups[0]['lr']}")

        # Validation every 5 epochs (every rank must take part in the all_reduce).
        if (epoch + 1) % 5 == 0:
            model.eval()
            running_loss_test = 0.0
            with torch.no_grad():
                for inputs_test, binary_labels_test in val_loader:
                    inputs_test = inputs_test.to(device)
                    binary_labels_test = binary_labels_test.to(device)

                    outputs_test = model(inputs_test)
                    depth_maps_test = large_depth_map_model(inputs_test).predicted_depth.unsqueeze(1)

                    labels_test = get_labels(depth_maps_test, binary_labels_test)
                    loss_test = customLoss(criterion, mse_criterion, outputs_test, labels_test, device)
                    running_loss_test += loss_test.item()

            running_loss_test = all_reduce_mean(running_loss_test / max(len(val_loader), 1))

            if local_rank == 0:
                val_loss.append(running_loss_test)
                print(f"Validation Loss (mean per batch): {running_loss_test}")

                if running_loss_test < best_epoch_loss:
                    best_epoch_loss = running_loss_test
                    torch.save(model.module.state_dict(), "best_fine_tuning_depth_anything.pth")

        if local_rank == 0:
            torch.save(model.module.state_dict(), "fine_tuning_depth_anything.pth")

        # running_loss is identical on every rank after the all_reduce, so the schedulers
        # stay in sync.
        scheduler.step(running_loss)
else:
    print("torch.distributed process group is not initialised; skipping the DDP training loop.")