# Importing necessary libraries

In [None]:
!pip install -r requirements.txt

In [None]:
import os
import random
import shutil
import zipfile
from itertools import combinations

import cv2
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.signal import convolve2d
from skimage.feature import local_binary_pattern
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets.folder import default_loader

# Global hyperparameters and runtime configuration.
batch_size = 32  # NOTE(review): the DataLoaders visible in this file pass the literal 32 and do not read this name
learning_rate = 1e-4  # read by the module-level Adam optimizer created next to the model instantiation
num_epochs = 1  # number of passes over the training data in train_vit_with_combined_loss
num_classes = 2  # NOTE(review): appears unused in the visible code (the model takes its own num_classes default)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU when CUDA is available, else the CPU

# Importing required data appropriately

In [None]:
# Unzipping custom_dataset.zip
# The two archives are extracted into ./custom_dataset and ./test_dataset in the working directory
with zipfile.ZipFile("custom_dataset.zip", 'r') as zip_ref:
    zip_ref.extractall('custom_dataset')

# Unzipping test_dataset.zip
with zipfile.ZipFile("test_dataset.zip", 'r') as zip_ref:
    zip_ref.extractall('test_dataset')

# Configuration: point at the folders extracted above. These are the same relative
# paths every later cell uses; the previous /kaggle/input/... values referred to a
# location this notebook never writes to and were overwritten before first use.
train_dir = "custom_dataset/train"
test_dir = "custom_dataset/test"

In [None]:
# Utility that strips Jupyter's .ipynb_checkpoints folders out of a dataset tree
def remove_ipynb_checkpoints(dataset_dir):
    """
    Deletes every '.ipynb_checkpoints' directory found below `dataset_dir`.

    Args:
        dataset_dir (str): Root of the directory tree to clean. A path that does
            not exist simply yields nothing to walk.

    Returns:
        None. Matching directories are removed from disk together with their contents.
    """
    checkpoint_name = '.ipynb_checkpoints'
    # Bottom-up walk: children are visited before the directory that contains them
    for parent, subdirs, _ in os.walk(dataset_dir, topdown=False):
        if checkpoint_name not in subdirs:
            continue
        shutil.rmtree(os.path.join(parent, checkpoint_name))

# Dataset locations, relative to the working directory
train_dir = 'custom_dataset/train'
test_dir = 'custom_dataset/test'

# Purge checkpoint folders from every directory that is later opened with ImageFolder
# (presumably so they are not mistaken for class folders - confirm against torchvision)
for _dataset_root in (train_dir, test_dir, 'test_dataset/test-interiit/perturbed_images_32'):
    remove_ipynb_checkpoints(_dataset_root)


In [None]:
# Functions to compute the FFT and LBP of an input image
def image_to_uint8_hwc(image):
    """
    Converts an image to something cv2 can consume: 8-bit, channels-last.

    The tensors handed to the feature extractors come out of a
    ToTensor + Normalize pipeline, so they are floats that are partly negative.
    Casting those straight to uint8 wraps the negative values and collapses the
    rest to a handful of levels. Floating-point tensors are therefore min-max
    rescaled to [0, 255] per image before the cast. This does not undo a
    per-channel Normalize exactly, but it keeps the full dynamic range.

    Args:
        image (torch.Tensor or np.ndarray): A (C, H, W) tensor, or a NumPy array
            that is already laid out the way cv2 expects.

    Returns:
        np.ndarray: For tensors, a uint8 array of shape (H, W, C). NumPy inputs
        are returned unchanged.
    """
    if not isinstance(image, torch.Tensor):
        return image

    hwc = image.detach().permute(1, 2, 0).cpu()
    if not torch.is_floating_point(hwc):
        # Integer tensors are taken to be pixel values already
        return hwc.numpy().astype(np.uint8)

    hwc = hwc.float()
    low, high = hwc.min(), hwc.max()
    if high > low:
        hwc = (hwc - low) / (high - low) * 255.0
    else:
        # A constant image carries no contrast; map it to black instead of dividing by zero
        hwc = torch.zeros_like(hwc)
    return hwc.round().clamp(0, 255).numpy().astype(np.uint8)


def compute_fft_magnitude(image):
    """
    Computes the FFT magnitude spectrum of an input image.

    Args:
        image (torch.Tensor or np.ndarray): Input image (RGB).

    Returns:
        np.ndarray: uint8 log-magnitude spectrum, shifted so the zero frequency is
        centred and min-max normalised to [0, 255]. Same height/width as the input.
    """
    image = image_to_uint8_hwc(image)
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    spectrum = np.fft.fftshift(np.fft.fft2(gray))
    magnitude_spectrum = 20 * np.log(np.abs(spectrum) + 1e-10)  # Avoid log(0)
    return cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)


def compute_lbp(image):
    """
    Computes the Local Binary Pattern (LBP) of an input image.

    Args:
        image (torch.Tensor or np.ndarray): Input image (RGB).

    Returns:
        np.ndarray: LBP map (8 neighbours, radius 1, "uniform" method) of the
        grayscale image, resized to 32x32.
    """
    image = image_to_uint8_hwc(image)
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    lbp = local_binary_pattern(gray, P=8, R=1, method="uniform")
    return cv2.resize(lbp, (32, 32))

# Dataset Preparation
# Resize to 32x32, convert to a tensor, then normalise every channel with mean 0.5 / std 0.5.
# NOTE(review): this pipeline is not referenced again in the visible code - the datasets
# further down are built with train_transform / test_transform / perturbed_transform.
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # Match DEFL input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 
])

# DEFL architecture development

In [None]:
# Eight 3x3 base kernels: -1 at the centre and +1 at exactly one of the eight neighbours.
# Each (row, col) below is the position of the +1 entry.
_neighbour_positions = [(1, 2), (1, 0), (0, 1), (2, 1), (0, 2), (2, 0), (0, 0), (2, 2)]
base_filters = []
for _row, _col in _neighbour_positions:
    _kernel = np.zeros((3, 3), dtype=int)
    _kernel[1, 1] = -1
    _kernel[_row, _col] = 1
    base_filters.append(_kernel)

# Composite kernels: for every subset of 1 to 7 base kernels, convolve the members
# together (mode='same' keeps the result 3x3). That gives 254 kernels.
composite_filters = []
for subset_size in range(1, 8):
    for subset in combinations(base_filters, subset_size):
        members = iter(subset)
        merged = next(members)
        for member in members:
            merged = convolve2d(merged, member, mode='same')
        composite_filters.append(merged)

# Repeat the first two kernels so the total reaches 256 (4 levels x 64)
composite_filters.extend(composite_filters[:2])

# Randomise the order in which kernels are handed out to the levels
composite_filters = random.sample(composite_filters, len(composite_filters))


class DirectionalConvolutionalBlock(nn.Module):
    """
    A block that applies a bank of directional 3x3 convolutions to its input.

    The block owns 64 independent convolutions, each producing a single output
    channel (kernel 3, padding 1, no bias). Convolution `i` is initialised from
    `composite_filters[i]`, with the same 3x3 kernel replicated across every
    input channel. Nothing freezes these weights, so they stay trainable.

    Args:
        composite_filters (sequence of 3x3 array-like): Initial kernels, at most 64.
            If fewer are given, the remaining convolutions keep PyTorch's default
            initialisation.
        in_channels (int): Number of input channels (e.g., 3 for RGB images).

    Layers:
        - `conv_layers`: ModuleList of 64 `nn.Sequential(nn.Conv2d(in_channels, 1, 3, padding=1))`.
          The Sequential wrapper is kept so state_dict keys stay `conv_layers.<i>.0.weight`.
    """
    def __init__(self, composite_filters, in_channels):
        super(DirectionalConvolutionalBlock, self).__init__()
        self.filters = composite_filters
        self.conv_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=3, padding=1, bias=False),
            )
            for _ in range(64)
        ])
        # Write the directional kernels into the existing parameters in place
        # (rather than rebinding `.data`), outside of autograd tracking.
        with torch.no_grad():
            for i, filter_weights in enumerate(self.filters):
                weight = self.conv_layers[i][0].weight  # shape (1, in_channels, 3, 3)
                kernel = torch.tensor(filter_weights, dtype=torch.float32)
                weight.copy_(kernel.expand_as(weight))

    def forward(self, x):
        """
        Applies every directional convolution followed by ReLU and stacks the results.

        Args:
            x (torch.Tensor): Input of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Shape (batch_size, 64, height, width) - one channel per convolution.
        """
        outputs = [F.relu(conv_layer(x)) for conv_layer in self.conv_layers]
        return torch.cat(outputs, dim=1)  # Concatenate along the channel dimension


class StandardConvolutionalBlock(nn.Module):
    """
    A bank of 64 ordinary (randomly initialised) 3x3 convolutions.

    Every branch maps the input to a single channel with kernel size 3,
    padding 1 and no bias, so height and width are preserved. The branch
    outputs are passed through ReLU and stacked on the channel axis.

    Args:
        in_channels (int): Number of input channels (e.g., 3 for RGB images).

    Layers:
        - `conv_layers`: ModuleList holding the 64 single-convolution `nn.Sequential` branches.
    """
    def __init__(self, in_channels):
        super().__init__()
        branches = []
        for _ in range(64):
            conv = nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=3, padding=1, bias=False)
            branches.append(nn.Sequential(conv))
        self.conv_layers = nn.ModuleList(branches)

    def forward(self, x):
        """
        Runs the input through all branches.

        Args:
            x (torch.Tensor): Input of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Shape (batch_size, 64, height, width).
        """
        activated = [F.relu(branch(x)) for branch in self.conv_layers]
        return torch.cat(activated, dim=1)  # one channel per branch


class DEFL(nn.Module):
    """
    DEFL feature extractor: four stacked levels of paired convolution banks.

    Each level holds two blocks that both read the same input:
    1. **Directional Convolutional Block (DCB)** - initialised from 64 of the supplied filters.
    2. **Standard Convolutional Block (SCB)** - 64 ordinary convolutions.

    The two 64-channel outputs are concatenated into 128 channels, which become the
    input of the next level. The first level reads a 3-channel image.

    Args:
        composite_filters (sequence of numpy.ndarray): Directional kernels. Level `i`
            receives the slice `[i * 64:(i + 1) * 64]`, so 256 kernels cover all four levels.

    Layers:
        - `levels`: ModuleList of 4 ModuleDicts with keys "dcb" and "scb".
    """
    def __init__(self, composite_filters):
        super().__init__()
        # Input width of each level: RGB first, then the 64 + 64 concatenated channels
        level_inputs = [3, 128, 128, 128]
        stages = []
        for idx, in_channels in enumerate(level_inputs):
            first = idx * 64
            stages.append(nn.ModuleDict({
                "dcb": DirectionalConvolutionalBlock(composite_filters[first:first + 64], in_channels),
                "scb": StandardConvolutionalBlock(in_channels),
            }))
        self.levels = nn.ModuleList(stages)

    def forward(self, x):
        """
        Passes the input through the four levels in order.

        Args:
            x (torch.Tensor): Input of shape (batch_size, 3, height, width).

        Returns:
            torch.Tensor: Output of the last level, shape (batch_size, 128, height, width).
        """
        features = x
        for stage in self.levels:
            directional = stage["dcb"](features)
            standard = stage["scb"](features)
            features = torch.cat([directional, standard], dim=1)
        return features


# Classification Model instantiation and loss function

In [None]:
# Load Pretrained ViT and Access CLS Token
class CombinedViTModel(nn.Module):
    """
    Vision Transformer backbone with a single-logit head for binary classification.

    A timm ViT is created for 32x32 inputs with `in_channels` input channels, its
    own classification head is replaced by `nn.Identity`, and a new linear layer
    maps the CLS-token embedding to one logit.

    Args:
        base_model_name (str): Name of the timm ViT model (default: 'vit_tiny_patch16_224').
        embedding_dim (int): Width of the ViT embedding fed to the classifier (default: 192).
        num_classes (int): Accepted for interface compatibility but not used - the head
            always emits a single logit, matching the BCE-with-logits loss used in this file.
        in_channels (int): Number of input channels passed to the backbone (default: 130).

    Layers:
        - `base_model`: timm ViT created with `pretrained=True`, classification head removed.
        - `classifier`: `nn.Linear(embedding_dim, 1)`.
    """
    def __init__(self, base_model_name='vit_tiny_patch16_224', embedding_dim=192, num_classes=2, in_channels = 130):
        super(CombinedViTModel, self).__init__()
        # Honour the in_channels argument (it used to be ignored in favour of a literal 130)
        self.base_model = timm.create_model(base_model_name, pretrained=True, img_size=32, in_chans=in_channels)
        self.base_model.head = nn.Identity()  # Nullify the classification head

        # New classification head: one logit per sample
        self.classifier = nn.Linear(embedding_dim, 1)

    def forward(self, x):
        """
        Forward pass through the model.

        Args:
            x (torch.Tensor): Input tensor with shape (batch_size, in_channels, height, width).

        Returns:
            tuple:
                - embeddings (torch.Tensor): First token of the backbone's feature sequence
                  (the CLS token), one vector per sample.
                - logits (torch.Tensor): Shape (batch_size, 1), raw scores before the sigmoid.
        """
        # assumes forward_features returns the full token sequence (batch, tokens, dim) - TODO confirm for the installed timm version
        embeddings = self.base_model.forward_features(x)[:, 0]
        logits = self.classifier(embeddings)
        return embeddings, logits

# One-shot flag: lets CombinedLoss.forward print the two loss components on the
# very first call only
count1 = 0

# Define the Combined Loss Function
class CombinedLoss(nn.Module):
    """
    Binary cross-entropy plus a weighted supervised contrastive term.

    total = BCEWithLogits(logits, labels) + lambda_contrastive * contrastive(embeddings, labels)

    Args:
        margin (float): Stored on the instance but not used by any computation here (default: 1.0).
        temperature (float): Divisor applied to the cosine similarities (default: 0.5).
        lambda_contrastive (float): Weight of the contrastive term (default: 0.01).

    Layers:
        - `bce_loss`: `nn.BCEWithLogitsLoss` with default (mean) reduction.
    """
    def __init__(self, margin=1.0, temperature=0.5, lambda_contrastive=0.01):
        super(CombinedLoss, self).__init__()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.margin = margin
        self.temperature = temperature
        self.lambda_contrastive = lambda_contrastive

    def forward(self, embeddings, logits, labels):
        """
        Computes the combined loss.

        Args:
            embeddings (torch.Tensor): (batch, dim) embeddings used for the contrastive term.
            logits (torch.Tensor): (batch, 1) raw classification scores.
            labels (torch.Tensor): (batch,) binary ground-truth labels.

        Returns:
            torch.Tensor: Scalar, BCE loss + lambda_contrastive * contrastive loss.
        """
        global count1
        # Labels are reshaped to (batch, 1) to line up with the logits
        bce_loss = self.bce_loss(logits, labels.float().unsqueeze(1))
        contrastive_loss = self.contrastive_loss(embeddings, labels)
        combined_loss = bce_loss + self.lambda_contrastive * contrastive_loss

        # Debug aid: show both components once per process
        if count1 == 0:
            print(bce_loss)
            print(contrastive_loss)
            count1 += 1
        return combined_loss

    def contrastive_loss(self, features, labels):
        """
        Computes a supervised contrastive loss over the batch.

        For every sample, the log-probabilities of the other samples sharing its
        label (under a softmax over temperature-scaled cosine similarities) are
        averaged, and the negated result is averaged over the batch.

        Args:
            features (torch.Tensor): (batch, dim) embeddings.
            labels (torch.Tensor): (batch,) labels defining which pairs are positives.

        Returns:
            torch.Tensor: Scalar loss. Samples with no positive partner contribute 0
            but are still counted in the batch mean.
        """
        # Cosine similarities, scaled by the temperature
        features = F.normalize(features, dim=1)
        sim_matrix = torch.matmul(features, features.T) / self.temperature

        # Positive-pair mask: same label, excluding each sample paired with itself
        labels = labels.view(-1, 1)
        mask_pos = torch.eq(labels, labels.T).float()
        mask_pos.fill_diagonal_(0)

        # logsumexp rather than log(sum(exp(.))): for small temperatures exp() overflows
        # float32, which made the denominator inf and the loss NaN (0 * -inf).
        # NOTE(review): the denominator still includes each sample's similarity with
        # itself, as in the original formulation - confirm this is intended.
        log_prob = sim_matrix - torch.logsumexp(sim_matrix, dim=1, keepdim=True)

        # Mean over positive pairs; the clamp avoids dividing by zero when a row has none
        pos_mask_sum = mask_pos.sum(1).clamp(min=1e-8)
        mean_log_prob_pos = (mask_pos * log_prob).sum(1) / pos_mask_sum

        return -mean_log_prob_pos.mean()


# Loss hyperparameters - per the original note, these values were chosen after several trials using Optuna.
# NOTE(review): the same two assignments are repeated in the cell that instantiates the models and the loss.
temperature = 0.5
lambda_contrastive = 0.01

# Training, Evaluation and Testing

In [None]:
# Per-channel statistics shared by all three pipelines below
_channel_mean = [0.4914, 0.4822, 0.4465]
_channel_std = [0.2023, 0.1994, 0.2010]

# Training pipeline: the only one with augmentation (random horizontal flip)
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
])

# Evaluation pipeline for the held-out test split
test_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
])

# Pipeline for the perturbed images; same steps as the test pipeline
perturbed_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
])

# Dataset locations
train_dir = 'custom_dataset/train'
test_dir = 'custom_dataset/test'
perturbed_test_dir = 'test_dataset/test-interiit/perturbed_images_32'

# Folder-per-class datasets and their loaders; only the training loader shuffles
train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
test_dataset = datasets.ImageFolder(root=test_dir, transform=test_transform)

_loader_options = dict(batch_size=32, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)





In [None]:
def train_vit_with_combined_loss(defl_model, vit_model, train_dir, device, lr=0.001, batch_size=32):
    """
    Trains the DEFL and ViT models jointly with the combined loss.

    For every batch, the DEFL feature maps are concatenated with one FFT-magnitude
    channel and one LBP channel per image, the result is fed to the ViT, and both
    models are updated by a single Adam optimizer.

    Relies on the module-level names `train_transform`, `num_epochs` and
    `combined_loss` being defined before the call.

    Args:
        defl_model (nn.Module): The DEFL model used for feature extraction.
        vit_model (nn.Module): The ViT model; must return `(embeddings, logits)`.
        train_dir (str): Root of the training images (one sub-folder per class).
        device (torch.device): Device to run the models on (CPU or GPU).
        lr (float): Adam learning rate (default: 0.001, the value previously hard-coded;
            the module-level `learning_rate` is NOT used unless passed explicitly).
        batch_size (int): Training batch size (default: 32).

    Saves:
        The weights of both models to 'defl_weights.pth' and 'vit_weights.pth'.
    """
    defl_model.train()  # Set DEFL model to train
    vit_model.train()   # Set ViT model to train

    # Load training data from the dataset directory
    train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # One optimizer over the parameters of both models
    optimizer = optim.Adam(list(defl_model.parameters()) + list(vit_model.parameters()), lr=lr)

    for epoch in range(num_epochs):
        total_loss = 0
        all_preds, all_labels = [], []
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Hand-crafted channels, one 2-D map per image. Stacking into a single
            # ndarray first avoids the slow list-of-arrays path of torch.tensor.
            fft_features = np.stack([compute_fft_magnitude(img) for img in images])
            lbp_features = np.stack([compute_lbp(img) for img in images])
            fft_features = torch.from_numpy(fft_features).unsqueeze(1).float().to(device)  # (B, 1, H, W)
            lbp_features = torch.from_numpy(lbp_features).unsqueeze(1).float().to(device)  # (B, 1, H, W)

            # Pass through DEFL
            feature_maps = defl_model(images)  # Output: (B, 128, H, W)

            # Concatenate DEFL, FFT, and LBP features
            combined_features = torch.cat((feature_maps, fft_features, lbp_features), dim=1)  # (B, 130, H, W)

            # Pass concatenated features through ViT
            embeddings, logits = vit_model(combined_features)

            # Compute combined loss (BCE + Contrastive Loss)
            loss = combined_loss(embeddings, logits, labels)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running sum of the per-batch losses (not averaged)
            total_loss += loss.item()

            # Threshold the sigmoid at 0.5; flatten (B, 1) -> (B,) so predictions
            # and labels are both plain lists of ints
            preds = (torch.sigmoid(logits) > 0.5).long().view(-1)
            all_preds.extend(preds.tolist())
            all_labels.extend(labels.tolist())

        train_accuracy = accuracy_score(all_labels, all_preds)
        print(f"Epoch {epoch+1}/{num_epochs}, Combined Loss: {total_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")

    # Save weights for both models
    torch.save(defl_model.state_dict(), 'defl_weights.pth')
    torch.save(vit_model.state_dict(), 'vit_weights.pth')

def evaluate_vit_with_combined_loss(defl_model, vit_model, test_dir, device):
    """
    Evaluates the trained DEFL and ViT models on a labelled test set.

    Builds the same 130-channel input as training (DEFL feature maps + FFT + LBP),
    thresholds the sigmoid of the logits at 0.5 and compares against the labels.

    Relies on the module-level `test_transform` and on `confusion_matrix` being
    imported from `sklearn.metrics` at the top of the file.

    Args:
        defl_model (nn.Module): The DEFL model used for feature extraction.
        vit_model (nn.Module): The ViT model; must return `(embeddings, logits)`.
        test_dir (str): Root of the test images (one sub-folder per class).
        device (torch.device): Device to run the models on (CPU or GPU).

    Prints:
        Accuracy, recall, precision, F1 score and the confusion matrix.
    """
    defl_model.eval()  # Set DEFL model to eval
    vit_model.eval()   # Set ViT model to eval

    test_dataset = datasets.ImageFolder(root=test_dir, transform=test_transform)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            # Hand-crafted channels, one 2-D map per image, stacked before the
            # tensor conversion to avoid the slow list-of-arrays path
            fft_features = np.stack([compute_fft_magnitude(img) for img in images])
            lbp_features = np.stack([compute_lbp(img) for img in images])
            fft_features = torch.from_numpy(fft_features).unsqueeze(1).float().to(device)  # (B, 1, H, W)
            lbp_features = torch.from_numpy(lbp_features).unsqueeze(1).float().to(device)  # (B, 1, H, W)

            # Pass through DEFL
            feature_maps = defl_model(images)  # Output: (B, 128, H, W)

            # Concatenate DEFL, FFT, and LBP features
            combined_features = torch.cat((feature_maps, fft_features, lbp_features), dim=1)  # (B, 130, H, W)

            # Pass concatenated features through ViT
            _, logits = vit_model(combined_features)

            # Threshold at 0.5 and flatten (B, 1) -> (B,) so predictions line up
            # with the 1-D labels in the metric calls below
            preds = (torch.sigmoid(logits) > 0.5).long().view(-1)
            all_preds.extend(preds.tolist())
            all_labels.extend(labels.tolist())

    # Calculate evaluation metrics
    accuracy = accuracy_score(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    conf_matrix = confusion_matrix(all_labels, all_preds)

    # Log evaluation results
    print(f"Accuracy: {accuracy:.4f}, Recall: {recall:.4f}, Precision: {precision:.4f}, F1 Score: {f1:.4f}")
    print("Confusion Matrix:")
    print(conf_matrix)

def predict_and_save(vit_model, defl_model, test_dir, output_csv='predictions.csv', device='cuda'):
    """
    Predicts a label for every image under `test_dir` and writes them to a CSV file.

    Builds the same 130-channel input as training (DEFL feature maps + FFT + LBP)
    and thresholds the sigmoid of the logits at 0.5. Rows are written in dataset
    order (the loader does not shuffle).

    Relies on the module-level `perturbed_transform` and on `pandas` being
    imported as `pd` at the top of the file.

    Args:
        vit_model (nn.Module): The ViT model; must return `(embeddings, logits)`.
        defl_model (nn.Module): The DEFL model used for feature extraction.
        test_dir (str): Root of the images to predict on, laid out for ImageFolder.
        output_csv (str): Path to save the predictions (default: 'predictions.csv').
        device (torch.device or str): Device to run the models on (default: 'cuda').

    Saves:
        A CSV with a single 'Predictions' column holding 0.0 / 1.0 per image.
    """
    vit_model.eval()
    defl_model.eval()

    test_dataset = datasets.ImageFolder(root=test_dir, transform=perturbed_transform)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

    all_preds = []
    with torch.no_grad():
        # ImageFolder yields (image, class_index) pairs; the folder-derived index
        # is not needed for prediction. Iterating without unpacking bound the whole
        # pair to `images`, which made `.to(device)` fail.
        for images, _ in test_loader:
            images = images.to(device)

            # Hand-crafted channels, one 2-D map per image, stacked before the
            # tensor conversion to avoid the slow list-of-arrays path
            fft_features = np.stack([compute_fft_magnitude(img) for img in images])
            lbp_features = np.stack([compute_lbp(img) for img in images])
            fft_features = torch.from_numpy(fft_features).unsqueeze(1).float().to(device)  # (B, 1, H, W)
            lbp_features = torch.from_numpy(lbp_features).unsqueeze(1).float().to(device)  # (B, 1, H, W)

            # Pass through DEFL
            feature_maps = defl_model(images)  # Output: (B, 128, H, W)

            # Concatenate DEFL, FFT, and LBP features
            combined_features = torch.cat((feature_maps, fft_features, lbp_features), dim=1)  # (B, 130, H, W)

            # Pass concatenated features through ViT
            _, logits = vit_model(combined_features)

            # Threshold at 0.5 and flatten (B, 1) -> (B,) so each row is one scalar
            preds = (torch.sigmoid(logits) > 0.5).float().view(-1)
            all_preds.extend(preds.cpu().tolist())

    # Save predictions to CSV
    df = pd.DataFrame(all_preds, columns=['Predictions'])
    df.to_csv(output_csv, index=False)
    print(f"Predictions saved to {output_csv}")


In [None]:
# Instantiate the models and the loss
# Loss hyperparameters - per the original note, these values were chosen after several trials using Optuna
vit_model = CombinedViTModel().to(device)
defl_model = DEFL(composite_filters).to(device)
temperature = 0.5
lambda_contrastive = 0.01
# Pass both tuned values explicitly; `temperature` used to be set here but never handed to the loss
combined_loss = CombinedLoss(temperature=temperature, lambda_contrastive=lambda_contrastive).to(device)
# NOTE(review): train_vit_with_combined_loss creates its own Adam optimizer over both
# models, so this ViT-only optimizer is not what drives training below.
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate)

# Train and save the model weights
train_vit_with_combined_loss(defl_model, vit_model, train_dir='custom_dataset/train', device=device)

# Load the saved weights for both models; map_location keeps this working when the
# checkpoint was written on a different device than the one in use now
defl_model.load_state_dict(torch.load('defl_weights.pth', map_location=device))
vit_model.load_state_dict(torch.load('vit_weights.pth', map_location=device))

# Perform evaluation on test data
evaluate_vit_with_combined_loss(defl_model, vit_model, test_dir='custom_dataset/test', device=device)

# Perform predictions on new dataset and save results
predict_and_save(vit_model, defl_model, test_dir='test_dataset/test-interiit/perturbed_images_32', output_csv='predictions.csv', device=device)



In [None]:
# Count Trainable Parameters
def count_parameters(model):
    """Return the total number of elements in the parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report how many trainable parameters the ViT wrapper holds (backbone plus the linear head)
print(f"Trainable Parameters in ViT: {count_parameters(vit_model)}")