## **MOSAIC PS-2(MATHEMATICAL PROBLEM SOLVER)**
The PS centres around a multimodal problem, taking into account the integration and processing of multiple modalities including text and images combining computer vision, natural language processing, mathematical reasoning and deep learning.

## **Pipeline Followed to achieve this task -**


1.   ***Using datasets***- A hugging face library to decode the arrow file and import the dataset first


2.   ***Computer Vision and image feature and information extraction-*** This was done taking into account the two major components of any image pertaining to a geometrical problem- a geometrical figure or area and the labellings corresponding to these figures(like a parallel line symbol, angle markings, named vertices, labelled side lengths, highlighted points,arcs etc)

    For this purpose, one major task was ocr based text detection and other major task was the implementation of foundational CNN to detect and learn geometrical patterns in the images and augmenting both to a foundational attention layer along with the problem text to gain a complete context of the problem.
    
    Models used for the preproccessing problem analysing step:


    1.   EfficientNetB4(employed as a CNN model- finetuned on the dataset using **Lora**)
    2.   easyocr - For ocr based detection of text in the image(directly
         employed as a lightweight ocr model to preprocess the
         context to be sent to the reasoning model)




3.  ***Sequence to Sequence Natural Language Processing( To learn about the problem description alongside the image):-***
A Natural Language processing model was needed to be designed in order to process a sentence long text per image and problem- the problem statement or description. For this purpose the following model has been finetuned and trained on the given dataset:-


       1.   BERT(bert-base-uncased)-finetuned as a text encoder to convert
            the given problem description into semantic contextual
            embeddings in a word embedding space ultimately
            to be augmented in a fusion space along with
            image feature vector to interpret the combined
            context(finetuned using lora)
4. ***Fusion Layer:-***
A custom multihead Attention layer to combine the extracted visual features from the image, the ocr text, and the problem description embeddings(encoded) in an augmented space to combine all individual contexts into one wholesome interpretable context of the problem. This layer is trained on the given dataset to learn the overall context better(apart from the finetuning and training of individual context handling models).

5. ***Reasoning model:-***
After first preprocessing the dataset and the problems to be in a stage to be fed to a reasoning model, taking into account the combined as well as individual contexts, structed prompts are generated to be given to the reasoning model based on which it reasons and predicts the correct answer(by learning the supervised dataset with correct answer labels(ground truth) being given and utilized to minimize the error.
For this purpose the following models have been finetuned, trained and utiized to predict the answer:-


    1. t5-small decoder:- The combined context and learnt embeddings generated by the
    Fusion layer is decoded and fed as text-based prompts
    to the model(the prompts so generated are also trained to
    be as accurate as possible via force teaching)

    2. phi-1.5:- open-source large language model (LLM) developed by
    Microsoft Research as part of the Phi series of small,
    efficient language models.Easily finetunable and lightweight to
    be trained with the limited capacity of a T4-GPU
    (as against other large models like deepseek-math7b-instruct etc.)

The models used above have been finetuned using Lora and mixed precision


In [None]:
   #just a flowchart representation-kindly dont run this cell

                   +-----------------+
                   |  Geometry Image |
                   +-----------------+
                            |
                            v
        +-------------------+--------------------+
        |                                        |
        v                                        v
+------------------+                   +-------------------+
|   EfficientNet-B4|                   |      EasyOCR      |
| (Visual Features) |                 | (Text from Image)  |
+------------------+                   +-------------------+
        |                                        |
        +-------------------+--------------------+
                            |
                            v
                +-------------------------+
                |  Visual Context Encoder |
                +-------------------------+

                            ↑
                            |
                            ↓
           +-----------------------------+
           |  Problem Description (Text) |
           +-----------------------------+
                            |
                            v
                   +------------------+
                   |  BERT Text Encoder|
                   +------------------+

                            ↓
                            ↓
       +---------------------------------------------+
       |         Fusion Layer (Multihead Attention)  |
       | Combines Visual, OCR, and Textual Contexts  |
       +---------------------------------------------+
                            |
                            v
              +------------------------------+
              |  Unified Multimodal Embedding|
              +------------------------------+
                            |
                            v
         +------------------------------------------+
         |    Reasoning Module (LLM Decoders)       |
         |    - T5-small                             |
         |    - Phi-1.5 (Lightweight Finetuned LLM) |
         +------------------------------------------+
                            |
                            v
                +------------------------+
                |  Final Answer Output   |
                +------------------------+


In [3]:
!pip install torch
!pip install torchvision
!pip install peft
!pip install tqdm
!pip install PIL
!pip install easyocr
!pip install numpy

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [4]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.

In [5]:
import torch #importing the necessary modules

import torch.nn as nn

from torch.utils.data import DataLoader, random_split

from torchvision import transforms

from transformers import(
BertModel, BertTokenizer,
T5ForConditionalGeneration, T5Tokenizer,
AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig
)

from peft import get_peft_model, LoraConfig, TaskType

from datasets import load_from_disk

from tqdm import tqdm

from PIL import Image

import easyocr

import numpy as np

import random

from torchvision import models

torch.cuda.empty_cache()

In [None]:
#initializing the necessary values to be used throughout the program
DEVICE='cuda'

BATCH_SIZE=1

EPOCHS=1

VAL_SPLIT=0.3

SEED=42

random.seed(SEED)

torch.manual_seed(SEED)

<torch._C.Generator at 0x7ddfdc6fe750>

In [None]:
import torch
import torch.nn as nn
from torchvision import models
from peft import get_peft_model, LoraConfig

class EfficientNetAdapter(nn.Module):
    def __init__(self):
        super().__init__()
        base=models.efficientnet_b4(pretrained=True) #freezing the other pretrained layers
        self.features=base.features #adaptive features
        self.pool=nn.AdaptiveAvgPool2d(1) #adpative pooling

    def forward(self,x):
        x=self.features(x)
        return self.pool(x)

model=EfficientNetAdapter()
lora_config=LoraConfig(# definign lora config
        r=16,
        lora_alpha=32,
        target_modules=["fc1","fc2"], #fully connected classifier head layers to be trained
        lora_dropout=0.05,
        bias="none",
        task_type="FEATURE_EXTRACTION"
        )
model.features=get_peft_model(model.features,lora_config).to(DEVICE)



In [None]:
text_encoder=BertModel.from_pretrained("bert-base-uncased")

text_encoder=get_peft_model(text_encoder, LoraConfig(#finetuning bert-base encased using lora
    r=16,
    lora_alpha=32,
    target_modules=["query","value"], # Specifies the layers to apply LoRA fine-tuning.
    lora_dropout=0.05,
    task_type=TaskType.FEATURE_EXTRACTION
)).to(DEVICE)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# This layer combines the extracted visual features from the image, the ocr text, and the problem description embeddings(encoded) in an augmented space to combine all individual contexts into
# one wholesome interpretable context of the problem using the technique of multihead attention(in a transformer)
 class FusionLayer(nn.Module):
    def __init__(self, img_dim=1792, text_dim=768):
        super().__init__()
        self.img_proj = nn.Linear(img_dim, 512)
        self.text_proj = nn.Linear(text_dim, 512)
        self.attention = nn.MultiheadAttention(512, 8, batch_first=True) #defining the multihead attention layer
        self.output_proj = nn.Linear(512, 2048)  # New projection layer

    def forward(self, img_feats, text_feats):
        img_proj = self.img_proj(img_feats) #image and text projected to embeddings of dimensions 512
        text_proj = self.text_proj(text_feats)
        img_proj = img_proj.unsqueeze(1)
        text_proj = text_proj.unsqueeze(1)
        attn_out, _ = self.attention(img_proj, text_proj, text_proj)
        out = attn_out.squeeze(1)               # [B, 512]
        out = self.output_proj(out)             # [B, 2048]
        return out

In [None]:
t5_decoder= T5ForConditionalGeneration.from_pretrained(
    "t5-small",
).to(DEVICE)
"""
This decoder is defined to decode the embeddings finally generated by the fusion layer after combining the visual and text based context into text, which
is to be fed to the phi-1.5 reasoning model as prompt to predict the answer.
"""

In [None]:
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from peft import prepare_model_for_kbit_training, get_peft_model, LoraConfig

# Step 1: Load config and tokenizer
config = AutoConfig.from_pretrained("microsoft/phi-1.5")
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1.5")

# Step 2: Define model
phi_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-1.5",
    config=config,
    trust_remote_code=True,
    use_safetensors=True,
    revision="main",
    device_map='auto'
)

# Step 3: Prepare for LoRA / PEFT
phi_model = prepare_model_for_kbit_training(phi_model)

# Step 4: Add Lora Config
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj"], #layers to be finetuned
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
phi_model = get_peft_model(phi_model, lora_config).to('cuda')

# Step 4: Gradient checkpointing to reduce memory (still on CUDA)
phi_model.gradient_checkpointing_enable()

# Confirm model is on CUDA
print(f"Model on device: {next(phi_model.parameters()).device}")

Model on device: cuda:0


In [None]:
ocr_reader = easyocr.Reader(['en'], gpu=DEVICE == 'cuda')

# Preprocessing function for images
def preprocess_image(image):
    if image.mode == 'RGBA':
        image = image.convert('RGB')
    return transforms.Compose([
        transforms.Resize((380, 380)),
        transforms.ToTensor()
    ])(image)

# Dataset class
class GeometryDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):  # initalizing dataset and bert-tokenizer for every instance
        self.dataset = dataset
        self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    def __len__(self):  # calculating the length of the dataset for evry instance
        return len(self.dataset)

    def __getitem__(self, idx):  #
        item = self.dataset[idx]

        # Preprocess image
        image = preprocess_image(item['images'][0])

        # OCR from original image
        ocr_text = " ".join([res[1] for res in ocr_reader.readtext(np.array(item['images'][0]))])
        cleaned_problem = item['problem'].replace("<image>", "").strip()

        # Combine all text
        full_text = f"{cleaned_problem}\nOCR: {ocr_text}\nOptions: " + ", ".join(
            [f"{chr(65 + i)}: {choice}" for i, choice in enumerate(item['choices'])]
        )

        # BERT tokenization
        text_inputs = self.bert_tokenizer(
            full_text,
            return_tensors="pt",
            max_length=512,
            padding="max_length",
            truncation=True
        )

        # Ground truth label
        answer_idx = ord(item['ground_truth'].upper()) - ord('A')

        # Handle option_values only if numeric(for mse loss calculation in case of numeric values otherwise cosine similarity for string options(delat by raw_choices))
        try:
            option_values = torch.tensor([float(c) for c in item['choices']], dtype=torch.float32,requires_grad=False)
        except ValueError:
            option_values = None  # Handle later in loss function
        if image is None or text_inputs['input_ids'] is None or text_inputs['attention_mask'] is None or option_values is None or item['choices'] is None:
            return None

        return {
            'input_ids': text_inputs['input_ids'].squeeze(0),
            'attention_mask': text_inputs['attention_mask'].squeeze(0),
            'image': image,
            'labels': torch.tensor(answer_idx),
            'option_values': option_values,
            'raw_choices': item['choices']  # Save raw choices for hybrid loss decision
        }

In [None]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#testing the tokenizer using custom input
text = "What is 2 + 2? A. 2 B. 3 C. 4 D. 5"
tokens = tokenizer(
    text,
    return_tensors="pt",
    padding="max_length",
    max_length=512,
    truncation=True
)

print(type(tokens))
print(tokens)

<class 'transformers.tokenization_utils_base.BatchEncoding'>
{'input_ids': tensor([[ 101, 2054, 2003, 1016, 1009, 1016, 1029, 1037, 1012, 1016, 1038, 1012,
         1017, 1039, 1012, 1018, 1040, 1012, 1019,  102,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,  

In [None]:
from google.colab import drive #mount google drive to access the training dataset
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
from torch.utils.data.dataloader import default_collate
def collate_fn_skip_none(batch): #custom collate function to only allow tensors, numpy arrays and alllowed data structures to pass for each batch while training
    batch = [item for item in batch if item is not None]
    if len(batch) == 0:
        return None  # let the training loop handle this
    return default_collate(batch)
full_dataset=load_from_disk("/content/drive/My Drive/MOSAIC PS-2")['train']

train_size=int(len(full_dataset)*(1-VAL_SPLIT)) #training data size

train_dataset, val_dataset=random_split(full_dataset,[train_size,len(full_dataset)-train_size]) #training and cross validation dataset split

train_loader=DataLoader(GeometryDataset(train_dataset),batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_fn_skip_none) #loading training data

val_loader=DataLoader(GeometryDataset(val_dataset),batch_size=BATCH_SIZE,collate_fn=collate_fn_skip_none) #loading cross validation data

image_encoder=EfficientNetAdapter().to(DEVICE) #intializing image enocder-storing everything to 'cuda' to prevent backpropagation chain from breaking at any point

fusion_layer=FusionLayer().to(DEVICE)# initializing fusion_layer-storing everything to 'cuda' to prevent backpropagation chain from breaking at any point
optimizer=torch.optim.AdamW([
    {'params': phi_model.parameters(), "lr": 1e-5},
    {'params':image_encoder.parameters(),'lr':1e-5}, #learning rate for params determined via hyperparameter tuning
    {'params':text_encoder.parameters(),'lr':2e-5},
    {'params':fusion_layer.parameters(),'lr':3e-5}

])



In [None]:
from torch.nn.functional import cosine_embedding_loss, gumbel_softmax
import torch.nn as nn
from torch import autocast

ce_loss = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss()

# Convert parameters to float16 if needed(to prevent error in case any torch.uint8 param is passed)
for name, param in phi_model.named_parameters():
    if isinstance(param, torch.nn.Parameter) and param.dtype == torch.uint8:
        param.data = param.data.to(torch.float16)

def run_epoch(model, loader, optimizer=None, is_train=True):
    scaler = torch.cuda.amp.GradScaler()# Initializes a gradient scaler for mixed precision training to improve performance and reduce memory usage on GPU
    total_loss = 0.0
    correct = 0
    total = 0

    model.train() if is_train else model.eval()
   # Initialize a progress bar for training or validation
progress = tqdm(loader, desc="Training" if is_train else "Validation")

# Enable or disable gradient computation based on whether it's training or validation
with torch.set_grad_enabled(is_train):
    for batch in progress:
        # Skip invalid batches
        if batch is None:
            continue

        # Zero out gradients if in training mode
        if is_train:
            optimizer.zero_grad()

        # Move images and other inputs to the specified device (GPU)
        images = batch['image'].to(DEVICE)

        # Convert grayscale images to 3-channel (RGB) format if necessary
        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)

        input_ids = batch['input_ids'].to(DEVICE)  # Tokenized input IDs for text
        attention_mask = batch['attention_mask'].to(DEVICE)  # Attention mask for text
        answer_idx = batch['labels'].to(DEVICE)  # Ground truth answer indices

        options = batch['option_values']  # Numeric values for multiple-choice options (if available)
        raw_choices = batch['raw_choices']  # Raw text of multiple-choice options

        # Use mixed precision (autocast) for faster computation and reduced memory usage on GPUs
        with autocast("cuda"):
            # === Feature Encoding ===
            # Extract image features using the image encoder and flatten them
            img_feats = image_encoder(images).view(images.size(0), -1)  # [Batch_Size, Image_Feature_Dim]

            # Extract text features using the text encoder and compute mean pooling over token embeddings
            text_feats = text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).last_hidden_state.mean(1)  # [Batch_Size, Text_Feature_Dim]

            # === Feature Fusion ===
            # Concatenate image and text features and project them into a joint feature space
            fused_feats = fusion_layer(img_feats, text_feats)
            fused_feats = fused_feats.unsqueeze(1)  # Add a sequence dimension for decoder input

            # Prepare decoder input IDs (start token followed by ground truth answer index)
            decoder_start_token_id = tokenizer.pad_token_id  # Use padding token as start token
            decoder_input_ids = torch.cat(
                [
                    torch.full((answer_idx.size(0), 1), decoder_start_token_id, device=answer_idx.device),
                    answer_idx.unsqueeze(1)  # Add a dimension for sequence tokens
                ],
                dim=1
            )
            """
            This step is done in order to decode the final visual+textual context embeddings generated by the fusion layer to be fed to the reasoning model
             as prompt to predict the answer.
            """

            # === Model Forward Pass ===
            outputs = model(
                inputs_embeds=fused_feats,  # Fused image-text embeddings as input
                attention_mask=attention_mask,  # Attention mask for the decoder
                decoder_input_ids=decoder_input_ids,  # Decoder input IDs (start token + ground truth)
                use_cache=False  # Disable caching during training/validation
            )

            logits = outputs.logits[:, -1, :4]  # Extract logits for the last token (4 options)

            # === Loss Calculation ===
            loss_ce = ce_loss(logits, answer_idx)  # Cross-entropy loss for classification

            if options is not None:
                # If numeric option values are provided:
                options_tensor = options.to(DEVICE)  # Move option values to the device
                probs = torch.softmax(logits, dim=-1)  # Compute probabilities from logits

                predicted_values = torch.matmul(probs, options_tensor.unsqueeze(-1)).squeeze()
                correct_values = options_tensor.gather(1, answer_idx.unsqueeze(1)).squeeze()

                loss_hybrid = mse_loss(predicted_values, correct_values)  # Mean squared error loss for numeric predictions

            else:
                # If numeric option values are not provided:
                choice_embeddings = []
                """
                The cosine loss is calculated for predicted embeddings and correct embeddings to measure their similarity by first encoding the text via the same text_encoder(bert model)
                and then calculating the cosine similarity loss between the predicted and correct embeddings.
                """

                for choices in raw_choices:
                    tokens = text_encoder.tokenizer(
                        choices,
                        return_tensors="pt",
                        padding="max_length",
                        truncation=True,
                        max_length=32
                    ).to(DEVICE)

                    emb = text_encoder(**tokens).last_hidden_state.mean(1)
                    choice_embeddings.append(emb)

                choice_embeddings = torch.stack(choice_embeddings)
                probs = torch.softmax(logits, dim=-1).unsqueeze(-1)

                pred_emb = (choice_embeddings * probs).sum(dim=1)

                indices = answer_idx.unsqueeze(1).unsqueeze(2).expand(-1, 1, pred_emb.size(-1))

                correct_emb = torch.gather(choice_embeddings, 1, indices).squeeze(1)

                loss_hybrid = cosine_embedding_loss(
                    pred_emb,
                    correct_emb,
                    torch.ones(pred_emb.size(0)).to(DEVICE)
                )

            loss = 0.4 * loss_ce + 0.6 * loss_hybrid  # Weighted combination of losses(weights determined by hyperparamter tuning)

            if is_train:
                # Backward pass using GradScaler for mixed precision training
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            total_loss += loss.item()  # Accumulate total loss

            _, predicted = torch.max(logits, 1)

            correct += (predicted == answer_idx).sum().item()

            total += answer_idx.size(0)

            progress.set_postfix({
                'loss': total_loss / (progress.n + 1),
                'acc': f"{correct / total:.2%}"
            })

    return total_loss / len(loader), correct / total


# === Training Loop ===

best_val_acc = 0.0
EPOCHS = 1

# Check if gradients are enabled for each component of the model pipeline
print("Grad check - image_encoder:", any(p.requires_grad for p in image_encoder.parameters()))
print("Grad check - text_encoder:", any(p.requires_grad for p in text_encoder.parameters()))
print("Grad check - fusion_layer:", any(p.requires_grad for p in fusion_layer.parameters()))
print("Grad check - phi_model:", any(p.requires_grad for p in phi_model.parameters()))

for epoch in range(EPOCHS):
    train_loss, train_acc = run_epoch(phi_model, train_loader, optimizer)  # Train the model on training data

    val_loss, val_acc = run_epoch(phi_model, val_loader, is_train=False)   # Evaluate the model on validation data

    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.2%}")
    print(f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.2%}")
    """
    The performance metric used here is accuracy:
     Accuracy measures the proportion of correctly predicted instances (both positive and negative) out of the total predictions made by the model

    """

    if val_acc > best_val_acc:
        torch.save({
            'phi_model': phi_model.state_dict(),
            'image_encoder': image_encoder.state_dict(),
            'fusion_layer': fusion_layer.state_dict()
        }, "best_model.pth")   # Save the best model checkpoint

        best_val_acc = val_acc

print("Training completed")



  scaler = torch.cuda.amp.GradScaler()


Grad check - image_encoder: True
Grad check - text_encoder: True
Grad check - fusion_layer: True
Grad check - phi_model: True


Training: 100%|██████████| 2031/2031 [12:00<00:00,  2.82it/s, loss=inf, acc=29.83%]
Validation: 100%|██████████| 871/871 [02:25<00:00,  5.97it/s, loss=3.78e+3, acc=33.42%]



Epoch 1/1
Train Loss: inf | Acc: 29.83%
Val Loss: 3781.8049 | Acc: 33.42%
Training completed
