Importing Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import math
import collections.abc
from typing import Dict, List, Optional, Set, Tuple, Union
import torch.utils.checkpoint
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt




Model Pipeline

In [None]:
class PatchConfig:
    """
    Plain container for the model hyper-parameters.

    Every constructor argument is stored unchanged as an attribute of the
    same name; no validation or derivation is performed here.

    Args:
        hidden_size: Width of the token embeddings.
        num_hidden_layers: Stored as given.
        num_attention_heads: Number of attention heads.
        intermediate_size: Width of the feed-forward expansion layer.
        hidden_dropout_prob: Dropout probability.
        image_size: Input image size.
        patch_size: Patch size.
        num_channels: Number of image channels.
        num_blocks: Number of blocks.
    """

    def __init__(
        self,
        hidden_size=128,
        num_hidden_layers=6,
        num_attention_heads=8,
        intermediate_size=1024,
        hidden_dropout_prob=0.02,
        image_size=32,
        patch_size=4,
        num_channels=3,
        num_blocks=6
    ):
        # Mirror each argument onto the instance, in declaration order.
        settings = (
            ("hidden_size", hidden_size),
            ("num_hidden_layers", num_hidden_layers),
            ("num_attention_heads", num_attention_heads),
            ("intermediate_size", intermediate_size),
            ("hidden_dropout_prob", hidden_dropout_prob),
            ("image_size", image_size),
            ("patch_size", patch_size),
            ("num_channels", num_channels),
            ("num_blocks", num_blocks),
        )
        for attribute, value in settings:
            setattr(self, attribute, value)

In [None]:
# Smoke check: build a config with all defaults and display its hidden size
# (the bare expression on the last line is echoed as the cell output).
x=PatchConfig()
x.hidden_size

In [None]:
class PatchEmbeddings(nn.Module):
    '''
    Build the token sequence fed to the encoder: a learnable [CLS] token
    followed by the patch tokens, with position embeddings added and dropout
    applied.

    Args:
        config: Configuration object; `hidden_size`, `hidden_dropout_prob` and
            `patch_size` are read here, and the object is also handed to
            `PatchPatchEmbeddings`.
        use_mask_token (bool, optional): Accepted for interface compatibility;
            not used by this implementation. Defaults to False.
    Attributes:
        config: The configuration object passed in.
        cls_token (nn.Parameter): Learnable [CLS] token, shape (1, 1, hidden_size).
        patch_embeddings (PatchPatchEmbeddings): Patch projection module.
        position_embeddings (nn.Parameter): Learnable position table, shape
            (1, num_patches + 1, hidden_size).
        dropout (nn.Dropout): Dropout applied to the final embeddings.
        patch_size (int): Size of each patch.
        hidden_size (int): Size of the hidden layer.
    '''

    def __init__(self, config, use_mask_token: bool = False) -> None:
        super().__init__()
        self.config = config
        width = config.hidden_size

        # Creation order is kept (cls token, patch projection, position table)
        # so random initialisation under a fixed seed is unchanged.
        self.cls_token = nn.Parameter(torch.randn(1, 1, width))
        self.patch_embeddings = PatchPatchEmbeddings(config)
        table_length = self.patch_embeddings.num_patches + 1  # +1 for [CLS]
        self.position_embeddings = nn.Parameter(torch.randn(1, table_length, width))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.hidden_size = width

    def pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        '''
        Return position embeddings matching the token grid of the input.

        The stored table is returned untouched when the token count matches and
        the image is square (and we are not tracing). Otherwise the patch part
        of the table is bicubically resampled to
        (height // patch_size, width // patch_size) and the [CLS] position is
        re-attached in front.

        Args:
            embeddings: Token tensor of shape (batch, num_patches + 1, dim).
            height: Input image height in pixels.
            width: Input image width in pixels.
        Returns:
            Tensor of shape (1, num_tokens, dim).
        '''
        table = self.position_embeddings
        token_count = embeddings.shape[1] - 1
        stored_count = table.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        table_fits = token_count == stored_count and height == width
        if table_fits and not torch.jit.is_tracing():
            return table

        cls_position, grid_positions = table[:, :1], table[:, 1:]
        channels = embeddings.shape[-1]

        # The stored table is treated as a square grid of positions.
        side = int(math.sqrt(stored_count))
        target_grid = (height // self.patch_size, width // self.patch_size)

        # (1, N, C) -> (1, C, side, side) so the grid can be resampled like an image.
        grid_positions = grid_positions.reshape(1, side, side, channels).permute(0, 3, 1, 2)
        grid_positions = nn.functional.interpolate(
            grid_positions,
            size=target_grid,
            mode="bicubic",
            align_corners=False,
        )
        # Back to (1, N', C).
        grid_positions = grid_positions.permute(0, 2, 3, 1).view(1, -1, channels)

        return torch.cat((cls_position, grid_positions), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        pos_encoding: bool = True,
    ) -> torch.Tensor:
        '''
        Embed a batch of images.

        Args:
            pixel_values: Images of shape (batch, channels, height, width).
            bool_masked_pos: Accepted for interface compatibility; not used.
            pos_encoding: When True, position embeddings go through
                `pos_encoding` (which may interpolate); when False, the stored
                table is added directly.
        Returns:
            Tensor of shape (batch_size, num_patches + 1, hidden_size).
        '''
        batch_size, _, height, width = pixel_values.shape
        patch_tokens = self.patch_embeddings(pixel_values, pos_encoding=pos_encoding)

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)

        # add positional encoding to each token
        if pos_encoding:
            positions = self.pos_encoding(tokens, height, width)
        else:
            positions = self.position_embeddings

        return self.dropout(tokens + positions)  # (batch_size, num_patches+1, hidden_size)

class PatchPatchEmbeddings(nn.Module):
    """
    Project an image into a sequence of patch tokens with a strided convolution.

    Args:
        config: Configuration object providing `image_size`, `patch_size`,
            `num_channels` and `hidden_size`. Sizes may be a single int or an
            iterable of (height, width).
    Attributes:
        image_size: Image size as a (height, width) pair.
        patch_size: Patch size as a (height, width) pair.
        num_channels (int): Expected number of input channels.
        num_patches (int): Number of patches per image.
        projection (nn.Conv2d): Convolution whose kernel and stride both equal
            the patch size.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        def as_pair(size):
            # Scalars describe a square; iterables are taken as given.
            if isinstance(size, collections.abc.Iterable):
                return size
            return (size, size)

        image_size = as_pair(config.image_size)
        patch_size = as_pair(config.patch_size)
        patches_down = image_size[0] // patch_size[0]
        patches_across = image_size[1] // patch_size[1]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = patches_across * patches_down

        self.projection = nn.Conv2d(
            config.num_channels,
            config.hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, pixel_values: torch.Tensor, pos_encoding: bool = True) -> torch.Tensor:
        """
        Args:
            pixel_values: Images of shape (batch, channels, height, width).
            pos_encoding: When False, the input must have exactly the
                configured image size; when True, the size check is skipped.
        Returns:
            Tensor of shape (batch_size, num_patches, hidden_size).
        Raises:
            ValueError: On a channel-count mismatch, or on an image-size
                mismatch when `pos_encoding` is False.
        """
        _, num_channels, height, width = pixel_values.shape

        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )

        size_matches = height == self.image_size[0] and width == self.image_size[1]
        if not pos_encoding and not size_matches:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model"
                f" ({self.image_size[0]}*{self.image_size[1]})."
            )

        # (B, hidden, H', W') -> (B, hidden, H'*W') -> (B, H'*W', hidden)
        feature_map = self.projection(pixel_values)
        return feature_map.flatten(2).transpose(1, 2)

In [None]:
# Smoke check: the embedding module builds from a default config;
# display the hidden size it picked up from that config.
x = PatchEmbeddings(PatchConfig())
x.hidden_size

In [None]:
class HolloPatchBlock(nn.Module):
    """
    One block that updates two streams in parallel: a token stream
    (attention + feed-forward, both with residual connections) and an image
    stream (two 3x3 convolutions), and merges them by projecting the patch
    tokens back to image space and adding them to the image stream.
    """

    def __init__(self, embed_dim, num_heads, expansion_dim, in_channels=3, patch_size=4):
        """
        Initializes the HolloPatchBlock module.
        Args:
            embed_dim (int): The dimension of the embedding.
            num_heads (int): The number of attention heads.
            expansion_dim (int): The dimension to which the feed-forward network expands.
            in_channels (int, optional): The number of input channels. Default is 3.
            patch_size (int, optional): Side length of a patch in pixels; it must match the
                patch size used to create the tokens. Default is 4.
        Attributes:
            norm1 (nn.LayerNorm): Layer normalization for the input embeddings.
            mha (nn.MultiheadAttention): Multi-head attention mechanism (batch-first).
            norm2 (nn.LayerNorm): Layer normalization after the attention mechanism.
            ffn (nn.Sequential): Feed-forward network consisting of two linear layers with a GELU activation in between.
            embed_dim (int): The dimension of the embedding.
            patch_size (int): Side length of a patch in pixels.
            batch_norm1 (nn.BatchNorm2d): Batch normalization for the input channels.
            t2i_conv1 (nn.Conv2d): Convolutional layer transforming input channels to 9 times the input channels.
            batch_norm2 (nn.BatchNorm2d): Batch normalization for the transformed channels.
            t2i_conv2 (nn.Conv2d): Convolutional layer transforming back to the original input channels.
            conv_trans1 (nn.ConvTranspose2d): Transposed convolutional layer transforming embeddings back to input channels.
        """
        super(HolloPatchBlock, self).__init__()

        self.norm1 = nn.LayerNorm(embed_dim)
        # The tokens arrive as (batch, tokens, dim). Without batch_first=True
        # the layer would read dim 0 as the sequence axis and attend across
        # samples of the batch instead of across tokens of one image.
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, expansion_dim),
            nn.GELU(),
            nn.Linear(expansion_dim, embed_dim),
        )
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.batch_norm1 = nn.BatchNorm2d(in_channels)
        self.t2i_conv1 = nn.Conv2d(in_channels, in_channels * 9, kernel_size=3, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(in_channels * 9)
        self.t2i_conv2 = nn.Conv2d(in_channels * 9, in_channels, kernel_size=3, padding=1)
        # kernel == stride == patch_size turns each token back into one patch of pixels.
        self.conv_trans1 = nn.ConvTranspose2d(embed_dim, in_channels, kernel_size=patch_size, stride=patch_size)

    def forward(self, patch_embeddings, image_embeddings):
        """
        Args:
            patch_embeddings (torch.Tensor): Tokens of shape (batch, 1 + num_patches, embed_dim);
                index 0 along the token axis is skipped when mapping back to image space.
            image_embeddings (torch.Tensor): Image-stream tensor of shape
                (batch, in_channels, height, width).
        Returns:
            tuple: (updated patch_embeddings, combined image tensor of shape
            (batch, in_channels, height, width)).
        Raises:
            ValueError: If the number of patch tokens does not match the
                (height // patch_size) x (width // patch_size) grid.
        """
        # Multi-head Attention
        # NOTE(review): only the query is normalized; key and value use the raw
        # tokens. Kept as in the original design — confirm this is intentional.
        attended = self.mha(
            self.norm1(patch_embeddings), patch_embeddings, patch_embeddings, need_weights=False
        )[0]
        patch_embeddings = patch_embeddings + attended
        # Feed Forward Network
        patch_embeddings = patch_embeddings + self.ffn(self.norm2(patch_embeddings))

        # Derive the token grid from the image instead of assuming 8x8.
        batch_size = patch_embeddings.shape[0]
        grid_h = image_embeddings.shape[-2] // self.patch_size
        grid_w = image_embeddings.shape[-1] // self.patch_size
        num_patches = patch_embeddings.shape[1] - 1
        if grid_h * grid_w != num_patches:
            raise ValueError(
                f"Got {num_patches} patch tokens but the image implies a {grid_h}x{grid_w} grid"
                f" for patch_size={self.patch_size}."
            )

        # Skip the first token, then (B, N, C) -> (B, C, N) -> (B, C, grid_h, grid_w).
        patch_1 = patch_embeddings[:, 1:, :].permute(0, 2, 1).reshape(batch_size, self.embed_dim, grid_h, grid_w)
        patch_image = self.conv_trans1(patch_1)

        # T2I Block for combining patch and image embeddings
        t2i_out = F.relu(self.t2i_conv1(self.batch_norm1(image_embeddings)))
        t2i_out = F.relu(self.t2i_conv2(self.batch_norm2(t2i_out)))

        # Element-wise sum of the two image-shaped tensors (not a channel concatenation).
        combined_embeddings = patch_image + t2i_out
        return patch_embeddings, combined_embeddings

In [None]:
# Instantiate one block (embed_dim=128, num_heads=8, expansion_dim=512)
# and display its layer summary as the cell output.
HolloPatch=HolloPatchBlock(128, 8, 512)
HolloPatch

In [None]:
class SignatureExtractor(nn.Module):
    """
    Produce a rectified image and the residual between the input and that
    rectified image by running patch tokens and the raw image through a stack
    of HolloPatchBlock modules.
    """

    def __init__(self, config, in_channels=3):
        """
        Initializes the SignatureExtractor class.
        Args:
            config (object): Configuration object containing model parameters
                (`hidden_size`, `num_attention_heads`, `intermediate_size`, `num_blocks`,
                plus whatever PatchEmbeddings reads).
            in_channels (int, optional): Number of image channels. Default is 3.
                PatchEmbeddings takes its channel count from `config.num_channels`,
                so the two values are expected to agree.
        Attributes:
            patch_embed (PatchEmbeddings): Embedding layer for patches.
            embed_dim (int): Dimension of the embeddings.
            num_attention_heads (int): Number of attention heads.
            expansion_dim (int): Dimension of the intermediate expansion layer.
            HolloPatch_blocks (nn.ModuleList): List of HolloPatchBlock modules.
            conv (nn.Conv2d): Convolutional layer for rectified image.
        """
        super(SignatureExtractor, self).__init__()

        self.patch_embed = PatchEmbeddings(config)  # (batch, num_patch+1, hidden_size(embed_dim))
        self.embed_dim = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.expansion_dim = config.intermediate_size
        # Forward in_channels so the blocks' image-side layers match the final
        # convolution; previously they silently stayed at the default of 3.
        self.HolloPatch_blocks = nn.ModuleList([
            HolloPatchBlock(
                self.embed_dim,
                self.num_attention_heads,
                self.expansion_dim,
                in_channels=in_channels,
            )
            for _ in range(config.num_blocks)
        ])
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)  # Rectified Image

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input images of shape (batch, channels, height, width).
        Returns:
            tuple: (rectified_image, x - rectified_image).
        """
        patch_embeddings = self.patch_embed(x)  # (batch, num_patch+1, hidden_size(embed_dim))
        # The image stream starts from the raw input.
        image_embeddings = x

        # Each block updates both streams and hands them to the next one.
        for HolloPatch_block in self.HolloPatch_blocks:
            patch_embeddings, image_embeddings = HolloPatch_block(patch_embeddings, image_embeddings)

        rectified_image = self.conv(image_embeddings)

        # The second output is the residual between input and rectified image.
        return rectified_image, x - rectified_image

In [None]:
# Smoke check: the extractor builds from a default config; display the
# hidden size reachable through its embedded PatchEmbeddings config.
x = SignatureExtractor(PatchConfig())
x.patch_embed.config.hidden_size

In [None]:
class ModelPipeline(nn.Module):
    """
    A neural network model pipeline for image processing and classification.
    Args:
        config (object): Configuration object (attribute access, e.g. PatchConfig) passed to the
            SignatureExtractor; `config.image_size` (int or (height, width)) also sizes the classifier.
        in_channels (int, optional): Number of input channels for the convolutional layer. Default is 3.
        num_classes (int, optional): Number of output classes for the classifier. Default is 2.
    Attributes:
        signature_extractor (SignatureExtractor): Module to extract signatures from input images.
        conv (nn.Conv2d): Convolutional layer applied to both the input image and the signature.
        batch_norm (nn.BatchNorm2d): Batch normalization layer for the convolutional output.
        fc1 (nn.Linear): Fully connected layer for classification.
        attack_classifier (nn.Softmax): Softmax layer to produce class probabilities.
        flatten (nn.Flatten): Flatten layer to reshape the tensor for the fully connected layer.
    Methods:
        forward(x):
            Forward pass of the model.
            Args:
                x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).
            Returns:
                tuple: A tuple containing:
                    - class_probs (torch.Tensor): Class probabilities of shape (batch_size, num_classes).
                      These are softmax outputs, not logits.
                    - signature (torch.Tensor): Extracted signature from the input image.
                    - rectified_image (torch.Tensor): Rectified image from the SignatureExtractor.
    """

    def __init__(self, config, in_channels=3, num_classes=2):
        super(ModelPipeline, self).__init__()
        # Forward in_channels so the extractor agrees with the layers below;
        # previously it always used its own default of 3.
        self.signature_extractor = SignatureExtractor(config, in_channels=in_channels)
        self.conv = nn.Conv2d(in_channels, in_channels * 5, kernel_size=3, padding=1)
        self.batch_norm = nn.BatchNorm2d(in_channels * 5)

        # Size the classifier from the configured resolution instead of a
        # hard-coded 32x32 (the default config still yields 32x32).
        image_size = config.image_size
        if isinstance(image_size, collections.abc.Iterable):
            height, width = image_size
        else:
            height = width = image_size
        # Two feature maps (image + signature) of in_channels*5 channels each are concatenated.
        self.fc1 = nn.Linear(in_channels * 5 * 2 * height * width, num_classes)
        self.attack_classifier = nn.Softmax(dim=1)
        self.flatten = nn.Flatten()

    def forward(self, x):
        rectified_image, signature = self.signature_extractor(x)
        # The same conv and batch-norm layers process both the image and its signature.
        image = self.batch_norm(self.conv(x))
        sign = self.batch_norm(self.conv(signature))
        concatenated_image = torch.cat([image, sign], dim=1)
        fc_output = self.fc1(self.flatten(concatenated_image))
        class_probs = self.attack_classifier(fc_output)

        return class_probs, signature, rectified_image

In [None]:
# Smoke check: build the full pipeline from a default config and display
# the patch-embedding sub-module nested inside its signature extractor.
x = ModelPipeline(PatchConfig())
x.signature_extractor.patch_embed

Loading the Training Dataset

In [None]:


# Preprocessing applied to every image: force 32x32 resolution, then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

# Read the labelled images from disk.
# NOTE(review): the root is the dot-prefixed directory ".corrected_dataset";
# confirm this is not meant to be "./corrected_dataset".
aid_dataset = datasets.ImageFolder(root=".corrected_dataset/train", transform=transform)

# 80% of the samples go to training, the remainder to validation.
train_ratio = 0.8
val_ratio = 1 - train_ratio

total_size = len(aid_dataset)
train_size = int(train_ratio * total_size)
val_size = total_size - train_size  # remainder, so both sizes always add up to total_size

# A fixed seed keeps the split identical from run to run.
split_rng = torch.Generator().manual_seed(42)
split_sizes = [train_size, val_size]
train_dataset, val_dataset = random_split(aid_dataset, split_sizes, generator=split_rng)

# Only the training loader shuffles; batch size can be adjusted.
loader_batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=loader_batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=loader_batch_size, shuffle=False)

# Report how many samples landed in each subset.
print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size: {len(val_dataset)}")


In [None]:
# Build the full pipeline from a default config and display its module tree.
let=PatchConfig()
model = ModelPipeline(let)
model

In [None]:
# Count every scalar entry across all parameter tensors of the model.
total_params = 0
for parameter in model.parameters():
    total_params += parameter.numel()
print(f"Total Parameters: {total_params}")

Training loop

In [None]:


in_channels = 3
num_classes = 2  # Assuming binary classification for adversarial attacks

# Use the GPU when one is available; otherwise fall back to CPU so the cell still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model with the same channel/class counts this loop assumes.
model = ModelPipeline(PatchConfig(), in_channels=in_channels, num_classes=num_classes)

# Wrap the model with DataParallel
model = nn.DataParallel(model)

# Move the model to the selected device
model = model.to(device)

# ModelPipeline.forward already applies softmax, so its first output is a
# probability vector. BCEWithLogitsLoss would apply a sigmoid on top of those
# probabilities; BCELoss consumes probabilities directly.
classification_loss_fn = nn.BCELoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-5)

# Initialize the scheduler to reduce learning rate on plateau
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch releases —
# confirm against the installed version.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True, min_lr=1e-15)

num_epochs = 50
patience = 10  # Number of epochs to wait for improvement
min_delta = 0.001  # Minimum change to consider as improvement
best_val_loss = float('inf')
patience_counter = 0

for epoch in range(num_epochs):
    # ---- Training phase ----
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    for image, labels in train_loader:
        # One-hot float targets, as required by the element-wise BCE loss.
        labels = F.one_hot(labels, num_classes=num_classes).float()  # Shape: (batch_size, num_classes)
        image, labels = image.to(device), labels.to(device)

        optimizer.zero_grad()
        classification_output, signature, rectified_image = model(image)
        classification_loss = classification_loss_fn(classification_output, labels)
        classification_loss.backward()
        optimizer.step()

        train_loss += classification_loss.item()

        # Accuracy calculation
        predicted_labels = torch.argmax(classification_output, dim=1)  # Predicted class
        true_labels = torch.argmax(labels, dim=1)  # True class
        train_correct += (predicted_labels == true_labels).sum().item()
        train_total += labels.size(0)

    train_accuracy = train_correct / train_total

    # ---- Validation phase ----
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    model.eval()
    with torch.no_grad():
        for embeddings, labels in val_loader:
            labels = F.one_hot(labels, num_classes=num_classes).float()
            embeddings, labels = embeddings.to(device), labels.to(device)
            outputs, signature, rectified_image = model(embeddings)
            loss = classification_loss_fn(outputs, labels)
            val_loss += loss.item()

            # Accuracy calculation
            predicted_labels = torch.argmax(outputs, dim=1)  # Predicted class
            true_labels = torch.argmax(labels, dim=1)  # True class
            val_correct += (predicted_labels == true_labels).sum().item()
            val_total += labels.size(0)

    val_accuracy = val_correct / val_total

    # Losses are reported as per-batch averages.
    print(f"Epoch {epoch+1}/{num_epochs}, "
          f"Training Loss: {train_loss/len(train_loader):.4f}, "
          f"Validation Loss: {val_loss/len(val_loader):.4f}, "
          f"Training Accuracy: {train_accuracy:.4f}, "
          f"Validation Accuracy: {val_accuracy:.4f}")

    # Scheduler step based on validation loss (summed over batches)
    scheduler.step(val_loss)

    # Periodic checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), f'model{epoch+1}.pth')

    # Early stopping logic
    if val_loss < best_val_loss - min_delta:
        best_val_loss = val_loss
        patience_counter = 0  # Reset counter if validation loss improves
        torch.save(model.state_dict(), 'final_best_model.pth')  # Save the best model
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print("Early stopping triggered.")
        break

# Load the best model after training; map_location keeps this working on CPU-only hosts.
model.load_state_dict(torch.load('final_best_model.pth', map_location=device))


In [None]:
# Serialize the whole `model` object (not just its state dict) to disk.
# NOTE(review): if `model` is still the nn.DataParallel wrapper created in the
# training cell, the wrapper itself is what gets saved — confirm that is intended.
torch.save(model, "entire_model.pth")
print("Entire model saved as entire_model.pth")

In [None]:
# Assuming your model instance is called `model`
# Save only the parameter/buffer tensors (the state dict), not the Python object.
torch.save(model.state_dict(), "model_weights.pth")
print("Model weights saved as model_weights.pth")


In [None]:
# Assuming the same architecture is defined or imported
# Restore the weights saved above into the existing `model` instance.
model.load_state_dict(torch.load("model_weights.pth"))
model.eval()  # Set to evaluation mode for inference


In [None]:
# Restore the fully pickled model object written by this notebook.
# Recent PyTorch releases default to weights_only=True, which refuses to
# unpickle a whole nn.Module, so opt out explicitly.
# NOTE: unpickling can execute arbitrary code — only load files you created yourself.
model = torch.load("entire_model.pth", weights_only=False)
model.eval()


Loading the Test Dataset

In [None]:

# Same preprocessing as for training: force 32x32 resolution, then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

# Read the held-out images from disk; order is kept fixed (no shuffling) for evaluation.
# NOTE(review): the root is the dot-prefixed directory ".corrected_dataset";
# confirm this is not meant to be "./corrected_dataset".
test_dataset = datasets.ImageFolder(root=".corrected_dataset/test", transform=transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

# Report how many samples were found.
print(f"test dataset size: {len(test_dataset)}")


Inference

In [None]:

# Evaluate on whichever device the model currently lives on instead of
# assuming 'cuda'; map_location lets the checkpoint load on that device.
test_device = next(model.parameters()).device
model.load_state_dict(torch.load('final_best_model.pth', map_location=test_device))

# Test phase
test_loss = 0.0
test_correct = 0
test_total = 0

# Store predictions and true labels for evaluation
all_predicted_labels = []
all_true_labels = []

model.eval()  # Set model to evaluation mode
with torch.no_grad():
    for embeddings, labels in test_loader:
        # `num_classes` and `classification_loss_fn` come from the training cell.
        labels = F.one_hot(labels, num_classes=num_classes).float()
        embeddings, labels = embeddings.to(test_device), labels.to(test_device)
        outputs, signature, rectified_image = model(embeddings)
        loss = classification_loss_fn(outputs, labels)
        test_loss += loss.item()

        # Accuracy calculation
        predicted_labels = torch.argmax(outputs, dim=1)  # Predicted class
        true_labels = torch.argmax(labels, dim=1)  # True class
        test_correct += (predicted_labels == true_labels).sum().item()
        test_total += labels.size(0)

        # Store predictions and true labels for metrics
        all_predicted_labels.extend(predicted_labels.cpu().tolist())
        all_true_labels.extend(true_labels.cpu().tolist())

test_accuracy = test_correct / test_total
# The loss was accumulated above but never shown; report the per-batch average.
print(f"Test Loss: {test_loss/len(test_loader):.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")

# Calculate precision, recall, F1 score, and confusion matrix
precision = precision_score(all_true_labels, all_predicted_labels, average='weighted')
recall = recall_score(all_true_labels, all_predicted_labels, average='weighted')
f1 = f1_score(all_true_labels, all_predicted_labels, average='weighted')
conf_matrix = confusion_matrix(all_true_labels, all_predicted_labels)

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print("Confusion Matrix:")
print(conf_matrix)
