In [None]:
import torch
# import clip
# from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
# import pytorch_lightning as pl
# from torch.utils.data import DataLoader, TensorDataset
# import json
# import numpy as np
# from torch.optim.lr_scheduler import ReduceLROnPlateau

# with open("./dataset_cache/answer_to_idx.json", 'r') as f:
#     s = f.read()
#     answer_to_idx = json.loads(s)

# print(f"Total unique answers: {len(answer_to_idx)}")

class CrossModalFusion(nn.Module):
    """Cross-modal fusion module to better combine image and text features"""
    def __init__(self, llm_dim, image_dim):
        super().__init__()
        self.image_proj = nn.Linear(image_dim, llm_dim)
        # self.text_proj = nn.Linear(embed_dim, embed_dim)
        
        self.cross_attn = nn.MultiheadAttention(llm_dim, num_heads=4, dropout=0.4, bias=False)
        
        self.ln1 = nn.LayerNorm(llm_dim)
        self.ln2 = nn.LayerNorm(llm_dim)
        
        self.mlp = nn.Sequential(
            nn.Linear(llm_dim, llm_dim*2),
            nn.GELU(),
            nn.Linear(llm_dim*2, llm_dim)
        )
        
    def forward(self, hidden_state, image_embs):
        # Project features
        img_proj = self.image_proj(image_embs)
        
        out = self.cross_attn(hidden_state, img_proj, img_proj)[0]

        out = self.ln1(out + hidden_state)
        
        out = self.ln2(self.mlp(out)+out)
        
        return out

class AdapterTransformerLayer(nn.Module):
    def __init__(self, transformer_layer, image_dim):
        """
        A wrapper around an existing transformer layer that adds adapters after
        attention and after the feed-forward sublayers.

        Args:
            transformer_layer: One layer from a pre-trained transformer (e.g., BERTLayer)
            adapter_size: Bottleneck size for the adapters
        """
        super().__init__()
        self.layer = transformer_layer
        self.hidden_size = transformer_layer.attn.c_attn.nx  # Model-specific

        # Freeze all transformer weights (we don’t train them)
        for param in self.layer.parameters():
            param.requires_grad = False

        # Add a CrossAttention Adapter
        
        self.adapter = CrossModalFusion(self.hidden_size, image_dim)
        

    def forward(self, hidden_states, image_embeddings, encoder_attention_mask=None):
        # Standard attention (output of frozen pre-trained layer)
        
        hidden_states = self.layer(hidden_states)

        # Inject adapter after attention
        fused_hidden_state = self.adapter(hidden_states, image_embeddings)
        return fused_hidden_state



In [None]:
2304

1536

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from typing import Optional, Tuple
import math

class CrossModalFusion(nn.Module):
    """Improved cross-modal fusion module"""
    def __init__(self, llm_dim, image_dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.llm_dim = llm_dim
        self.num_heads = num_heads
        
        # Project image features to match LLM dimension
        self.image_proj = nn.Sequential(
            nn.Linear(image_dim, llm_dim),
            nn.LayerNorm(llm_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        
        # Cross-attention: text queries, image keys/values
        self.cross_attn = nn.MultiheadAttention(
            llm_dim, 
            num_heads=num_heads, 
            dropout=dropout, 
            bias=False,
            batch_first=True  # Important for proper batching
        )
        
        # Layer norms
        self.ln1 = nn.LayerNorm(llm_dim)
        self.ln2 = nn.LayerNorm(llm_dim)
        
        # Feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(llm_dim, llm_dim * 4),  # Standard 4x expansion
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(llm_dim * 4, llm_dim),
            nn.Dropout(dropout)
        )
        
        # Gating mechanism to control fusion strength
        self.gate = nn.Parameter(torch.zeros(1))
        
    def forward(self, text_hidden, image_embs, attention_mask=None):
        batch_size, seq_len, hidden_dim = text_hidden[0].shape
        
        # Project image embeddings
        img_proj = self.image_proj(image_embs)  # [B, num_patches, llm_dim]
        # Cross-attention: text attends to image
        attn_out, attn_weights = self.cross_attn(
            query=text_hidden[0],
            key=img_proj,
            value=img_proj,
            key_padding_mask=None,  # Could add image padding mask here
            need_weights=True
        )
        # Apply gating and residual connection
        text_hidden = self.ln1(text_hidden[0] + self.gate * attn_out)
        # Feed-forward with residual
        ff_out = self.mlp(text_hidden)
        text_hidden = self.ln2(text_hidden + ff_out)
        return text_hidden, attn_weights

class AdapterTransformerLayer(nn.Module):
    def __init__(self, transformer_layer, image_dim, adapter_position="after"):
        """
        Args:
            transformer_layer: Original GPT-2 transformer block
            image_dim: Dimension of image embeddings
            adapter_position: "before", "after", or "parallel" to the transformer block
        """
        super().__init__()
        self.layer = transformer_layer
        self.adapter_position = adapter_position
        
        # Get hidden size from the transformer layer
        self.hidden_size = transformer_layer.attn.c_attn.weight.shape[1] // 3  # Divide by 3 because it's concatenated Q,K,V
        
        # Freeze transformer weights
        for param in self.layer.parameters():
            param.requires_grad = False
        
        # Add cross-modal adapter
        self.adapter = CrossModalFusion(
            llm_dim=self.hidden_size, 
            image_dim=image_dim,
            num_heads=8,
            dropout=0.1
        )
        
    def forward(self, hidden_states, image_embeddings=None, **kwargs):
        if self.adapter_position == "before" and image_embeddings is not None:
            # Apply adapter before transformer block
            hidden_states, _ = self.adapter(hidden_states, image_embeddings)
            hidden_states = self.layer(hidden_states, **kwargs)
            
        elif self.adapter_position == "after" and image_embeddings is not None:
            # Apply transformer block first, then adapter
            layer_outputs = self.layer(hidden_states, **kwargs)
            # Extract hidden states from layer output (could be tuple or tensor)
            if isinstance(layer_outputs, tuple):
                hidden_states = layer_outputs[0]
            else:
                hidden_states = layer_outputs
            
            # Apply adapter
            hidden_states, _ = self.adapter((hidden_states,), image_embeddings)
            
        elif self.adapter_position == "parallel" and image_embeddings is not None:
            # Parallel processing with weighted combination
            transformer_out = self.layer(hidden_states, **kwargs)
            if isinstance(transformer_out, tuple):
                transformer_out = transformer_out[0]
            
            adapter_out, _ = self.adapter((hidden_states,), image_embeddings)
            
            # Learnable combination weights
            if not hasattr(self, 'combination_weight'):
                self.combination_weight = nn.Parameter(torch.tensor(0.5))
            hidden_states = (1 - self.combination_weight) * transformer_out + \
                           self.combination_weight * adapter_out
        else:
            # No image embeddings, just use original transformer
            layer_outputs = self.layer(hidden_states, **kwargs)
            if isinstance(layer_outputs, tuple):
                hidden_states = layer_outputs[0]
            else:
                hidden_states = layer_outputs
            
        return (hidden_states,)  # Return tuple to match GPT-2 output format

class MultimodalGPT2(nn.Module):
    """Complete multimodal GPT-2 model"""
    def __init__(self, gpt2_model_name="openai-community/gpt2", image_dim=512, 
                 adapter_position="after", num_adapter_layers=6):
        super().__init__()
        
        # Load base GPT-2 model
        self.gpt2 = GPT2LMHeadModel.from_pretrained(gpt2_model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
        
        self.gpt2.requires_grad_(False)
        # Add padding token
        self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Store config
        self.config = self.gpt2.config
        self.image_dim = image_dim
        
        # Wrap selected layers with adapters
        # Usually better to add adapters to later layers
        num_layers = len(self.gpt2.transformer.h)
        start_layer = max(0, num_layers - num_adapter_layers)
        
        for i in range(start_layer, num_layers):
            original_layer = self.gpt2.transformer.h[i]
            self.gpt2.transformer.h[i] = AdapterTransformerLayer(
                original_layer, image_dim, adapter_position
            )
        
        print(f"Added adapters to layers {start_layer} to {num_layers-1}")
        
        # Image feature processor (optional, if you want to process raw images)
        print("self.config.n_embed", self.config.n_embd)
        self.image_processor = nn.Sequential(
            nn.Linear(512, self.config.n_embd),
            nn.LayerNorm(self.config.n_embd),
            nn.GELU(),
            nn.Dropout(0.1)
        )
        
        # Special tokens for multimodal processing
        self.image_start_token = "<image>"
        self.image_end_token = "</image>"
        
    def add_special_tokens(self):
        """Add special tokens for image boundaries"""
        special_tokens = {
            "additional_special_tokens": [self.image_start_token, self.image_end_token]
        }
        self.tokenizer.add_special_tokens(special_tokens)
        self.gpt2.resize_token_embeddings(len(self.tokenizer))
        
    def forward(self, input_ids, attention_mask=None, image_embeddings=None, 
                labels=None, **kwargs):
        """
        Args:
            input_ids: [batch_size, seq_len]
            attention_mask: [batch_size, seq_len]  
            image_embeddings: [batch_size, num_patches, image_dim]
            labels: [batch_size, seq_len] for training
        """
        # Process image embeddings if provided
        if image_embeddings is not None:
            image_embeddings = self.image_processor(image_embeddings)
        
        # Store image embeddings in a way that adapter layers can access them
        # We'll modify the forward pass to pass this through
        return self._forward_with_adapters(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeddings=image_embeddings,
            labels=labels,
            **kwargs
        )
    
    def _forward_with_adapters(self, input_ids, attention_mask=None, 
                              image_embeddings=None, labels=None, **kwargs):
        """Modified forward pass that handles image embeddings"""
        
        # Get token embeddings and position embeddings
        inputs_embeds = self.gpt2.transformer.wte(input_ids)
        position_ids = torch.arange(0, input_ids.size(-1), device=input_ids.device)
        position_embeds = self.gpt2.transformer.wpe(position_ids)
        
        hidden_states = inputs_embeds + position_embeds
        hidden_states = self.gpt2.transformer.drop(hidden_states)
        
        # Process attention mask to the format expected by GPT-2
        # GPT-2 expects a causal mask, we need to handle the attention mask properly
        if attention_mask is not None:
            # Convert attention mask to the format expected by GPT-2
            batch_size, seq_len = input_ids.shape
            # Create causal mask
            causal_mask = torch.tril(torch.ones((seq_len, seq_len), device=input_ids.device))
            # Expand attention mask to match causal mask dimensions
            expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, S]
            # Combine with causal mask
            attention_mask = expanded_mask * causal_mask.unsqueeze(0).unsqueeze(0)  # [B, 1, S, S]
            # Convert to the format expected by GPT-2 (large negative values for masked positions)
            attention_mask = attention_mask.to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min
        
        # Pass through transformer blocks (some have adapters)
        for i, block in enumerate(self.gpt2.transformer.h):
            if isinstance(block, AdapterTransformerLayer):
                # Pass image embeddings to adapter layers
                layer_outputs = block(hidden_states, 
                                    image_embeddings=image_embeddings, 
                                    attention_mask=attention_mask)
                hidden_states = layer_outputs[0]
                
            else:
                # Regular transformer block
                layer_outputs = block(hidden_states, attention_mask=attention_mask)
                if isinstance(layer_outputs, tuple):
                    hidden_states = layer_outputs[0]
                else:
                    hidden_states = layer_outputs

        # Final layer norm
        hidden_states = self.gpt2.transformer.ln_f(hidden_states)
        # Language modeling head
        logits = self.gpt2.lm_head(hidden_states)
        
        # Calculate loss if labels provided
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                           shift_labels.view(-1))
            
        return {
            'logits': logits,
            'loss': loss,
            'hidden_states': hidden_states
        }
    
    def generate_with_images(self, text, image_embeddings, max_length=100, **kwargs):
        """Generate text conditioned on both text and images"""
        # Tokenize input text
        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
        
        # Generate with image conditioning
        with torch.no_grad():
            outputs = self.gpt2.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask'),
                max_length=max_length,
                do_sample=True,
                temperature=0.7,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs
            )
            outputs = self.gpt2.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask'),
                max_length=max_length,
                do_sample=True,
                temperature=0.7,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs
            )
        
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

# Example usage and training setup
def create_model_and_count_params():
    # Create model
    model = MultimodalGPT2(
        gpt2_model_name="openai-community/gpt2",
        image_dim=768,  # Vision Transformer patch embeddings
        adapter_position="after",
        num_adapter_layers=6
    )
    
    # Add special tokens
    model.add_special_tokens()
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Percentage trainable: {100 * trainable_params / total_params:.2f}%")
    
    return model

# Example of how to prepare data
def prepare_multimodal_batch(texts, image_features, tokenizer, max_length=768):
    """
    Prepare a batch of multimodal data
    
    Args:
        texts: List of text strings
        image_features: Tensor of shape [batch_size, num_patches, feature_dim]
        tokenizer: GPT-2 tokenizer
        max_length: Maximum sequence length
    """
    # Tokenize texts
    encodings = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    
    return {
        'input_ids': encodings['input_ids'],
        'attention_mask': encodings['attention_mask'],
        'image_embeddings': image_features,
        'labels': encodings['input_ids'].clone()  # For language modeling
    }

if __name__ == "__main__":
    model = create_model_and_count_params()
    
    # Example forward pass
    batch_size, seq_len = 2, 20
    num_patches, image_dim = 196, 512  # Typical ViT-Base patch embeddings
    
    # Dummy data
    input_ids = torch.randint(0, 1000, (batch_size, seq_len))
    image_embeddings = torch.randn(batch_size, num_patches, image_dim)
    
    # Forward pass
    outputs = model(input_ids=input_ids, image_embeddings=image_embeddings)
    print(f"Output logits shape: {outputs['logits'].shape}")

Added adapters to layers 6 to 11
self.config.n_embed 768
Total parameters: 170,898,438
Trainable parameters: 46,457,094
Percentage trainable: 27.18%
Output logits shape: torch.Size([2, 20, 50259])


In [16]:
model.load_state_dict(torch.load(r"C:\Users\Rohit Francis\Downloads\FinalModel.pth"))

<All keys matched successfully>

In [17]:
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model2, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open('Clip.jpg')).unsqueeze(0).to(device)
text = clip.tokenize(["a clip", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model2.encode_image(image)
    text_features = model2.encode_text(text)
    
    logits_per_image, logits_per_text = model2(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

Label probs: [[0.9883   0.00925  0.002338]]


In [18]:
image_features.shape

torch.Size([1, 512])

In [20]:
text = "what is in this image?"
inputs = model.tokenizer(text, return_tensors="pt", padding=True)

In [24]:
out = model(**inputs, image_features = image_features.detach())

In [31]:
print(out['logits'].shape)
pred_tokens = torch.argmax(out["logits"], dim=-1)
print(pred_tokens)

torch.Size([1, 6, 50259])
tensor([[  11,  262,  262, 1492,   30,  198]])


In [32]:
model.tokenizer.decode(pred_tokens[0], skip_special_tokens=True)

', the the book?\n'

In [35]:
text = "what do you see here?"
pred_text = ""
for i in range(10):
    text += pred_text
    inputs = model.tokenizer(text, return_tensors="pt", padding=True)
    out = model(**inputs, image_features = image_features.detach())
    # print(out['logits'].shape)
    pred_tokens = torch.argmax(out["logits"], dim=-1)
    # print(pred_tokens)
    pred_text = model.tokenizer.decode(pred_tokens[0], skip_special_tokens=True)
    print(pred_text)
    

, you think in?"

, you think in?" I're? the


, you think in?" I're? the

 you think in?" "'m not
 one

, you think in?" I're? the

 you think in?" "'m not
 one
" you're in? I'm? the I
're in?
Cause',
 of

, you think in?" I're? the

 you think in?" "'m not
 one
" you're in? I'm? the I
're in?
Cause',
 of
 you think in?" "'m? the

 think in?" "'m not one

,I think not?" the'm not the

'm
??" the
 I you
 course"
, you think in?" I're? the

 you think in?" "'m not
 one
" you're in? I'm? the I
're in?
Cause',
 of
 you think in?" "'m? the

 think in?" "'m not one

,I think not?" the'm not the

'm
??" the
 I you
 course" of think in?"
'm? the


 think in?" "'m not


" you're in? I'm? the I'm
 in?

', of

" think in?" Im not


 in?" "'m not


" you'm in
 I





" not
 the the


'm think



, you think in?" I're? the

 you think in?" "'m not
 one
" you're in? I'm? the I
're in?
Cause',
 of
 you think in?" "'m? the

 think in?" "'m not one

,I think not?" the'm not the

'm
??" the
 I you
 cou

IndexError: index out of range in self

In [2]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")



In [3]:
model.named_modules

<bound method Module.named_modules of GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)>

In [31]:
# Define adapter size
adapter_size = 768
num_blocks = len(model.transformer.h)
print(f"Num GPT2 Blocks", num_blocks)

# Wrap all encoder layers with adapter-enabled versions
for i in range(num_blocks):
    original_layer = model.transformer.h[i]
    # print(original_layer)
    
    model.transformer.h[i] = AdapterTransformerLayer(original_layer, adapter_size)
# Check that only adapters will be trained
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# total_params = sum(p.numel() for p in model.parameters())
# print(f"Trainable parameters: {trainable_params} / {total_params}")

# # Now you can tokenize input and train like usual.
# inputs = tokenizer("Adapters are lightweight and powerful.", return_tensors="pt")
# outputs = model(**inputs)

Num GPT2 Blocks 12


In [32]:
model.transformer.h

ModuleList(
  (0-11): 12 x AdapterTransformerLayer(
    (layer): GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (adapter): CrossModalFusion(
      (image_proj): Linear(in_features=768, out_features=768, bias=True)
      (cross_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=False)
      )
      (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((768,), eps=1e

In [33]:
inputs = tokenizer("Hey yesterday I had a croissant from Walmart.", return_tensors='pt')

output = model(**inputs)
print(output)

TypeError: AdapterTransformerLayer.forward() got multiple values for argument 'encoder_attention_mask'

In [None]:
def create_data():
    # Load data
    X1 = torch.load("./dataset_cache/images_embs.pt", weights_only=True)
    X2 = torch.load("./dataset_cache/text_embs.pt", weights_only=True)
    print('Shape of data:', X1.shape, X2.shape, X1.dtype, X2.dtype)
    
    # Convert to float32 if needed
    X1 = X1.to(torch.float32)
    X2 = X2.to(torch.float32)
    
    # Normalize embeddings (important for consistent scale)
    X1 = F.normalize(X1, p=2, dim=1)
    X2 = F.normalize(X2, p=2, dim=1)
    
    # Load targets
    Y = torch.load("./dataset_cache/targets.pt", weights_only=True)
    
    # Convert to one-hot encoding
    eye = torch.eye(len(answer_to_idx))[Y.squeeze()]
    
    # Print some stats
    print(f"Total samples: {len(Y)}")
    print(f"Feature dimensions: Image={X1.shape[1]}, Text={X2.shape[1]}")
    print(f"Target classes: {len(answer_to_idx)}")
    
    return TensorDataset(X1, X2, eye)

def get_dataloaders(batch_size=32):
    dataset = create_data()
    
    # Create train/val/test split (80/10/10)
    train_size = int(0.8 * len(dataset))
    # val_size = int(0.2 * len(dataset))
    val_size = len(dataset)-train_size
    # test_size = len(dataset) - train_size - val_size
    
    # Create splits with fixed seed for reproducibility
    generator = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=generator
    )
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    # test_loader = DataLoader(
    #     test_dataset, 
    #     batch_size=batch_size, 
    #     shuffle=False,
    #     num_workers=4,
    #     pin_memory=True
    # )
    
    print(f"Train samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    # print(f"Test samples: {len(test_dataset)}")
    
    return train_loader, val_loader

# Train model
def train_model(max_epochs=30, batch_size=64):
    # Create model and dataloaders
    model = EnhancedFusor(
        embedding_size=512, 
        num_heads=8,
        dropout=0.2,
        lr=5e-4,
        weight_decay=0.01
    )
    
    train_loader, val_loader = get_dataloaders(batch_size)

    # Create trainer with additional callbacks
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        enable_progress_bar=True,
        num_nodes=1,
        enable_checkpointing=True,
        callbacks=[
            pl.callbacks.ModelCheckpoint(
                monitor='val_loss',
                filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}',
                save_top_k=3,
                mode='min'
            ),
            # pl.callbacks.EarlyStopping(
            #     monitor='val_loss',
            #     patience=5,
            #     mode='min'
            # ),
        ],
        gradient_clip_val=1.0,  # Prevent exploding gradients
    )
    
    # Train the model
    trainer.fit(model, train_loader, val_loader)
    
    # Test the model
    # test_result = trainer.test(model, test_loader)
    # print(f"Test results: {test_result}")
    
    return model, trainer

def inference_example(checkpoint_path=None):
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Create reverse mapping from index to answer
    idx2answers = {v: k for k, v in answer_to_idx.items()}
    
    # Load model from checkpoint or create new
    if checkpoint_path:
        model = EnhancedFusor.load_from_checkpoint(
            checkpoint_path=checkpoint_path
        ).to(device)
        print(f"Loaded model from {checkpoint_path}")
    else:
        model = EnhancedFusor().to(device)
        print("Created new model (not trained)")
    
    # Set model to evaluation mode
    model.eval()
    
    # Load and preprocess image
    _, preprocess = clip.load("ViT-B/32", device=device)
    image_path = r"C:\Users\Rohit Francis\Desktop\Codes\Datasets\VQA\dataset\images\image1.png"  # Adjust path as needed
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    
    # Create text query
    # query = "what is on the left side of white oven ?"
    query = "how many garbage_bin is here in the image ?"
    text = clip.tokenize([query]).to(device)
    
    print(f"Query: {query}")
    
    # Get embeddings
    with torch.no_grad():
        image_embs = model.encode_image(image).to(torch.float32)
        text_embs = model.encode_text(text).to(torch.float32)
    
    # Get prediction
    with torch.no_grad():
        out = model(image_embs, text_embs)
        
    # Get top-5 predictions
    top_probs, top_indices = torch.topk(F.softmax(out, dim=-1), k=5)
    
    # Display results
    print("\nTop 5 predictions:")
    for i, (prob, idx) in enumerate(zip(top_probs[0], top_indices[0])):
        answer = idx2answers.get(idx.item(), "unknown")
        print(f"{i+1}. {answer} ({prob.item()*100:.2f}%)")


if __name__ == "__main__":
    # Uncomment to train the model
    # train_model(max_epochs=30, batch_size=64)
    inference_example("./lightning_logs/version_4/checkpoints/epoch=7-val_loss=3.66-val_acc=0.32.ckpt")
    # For inference with trained model (replace with your checkpoint path)
    # inference_example(checkpoint_path="./lightning_logs/version_0/checkpoints/epoch=1-step=686.ckpt")