### PACKAGES

In [None]:
!pip uninstall timm transformers accelerate peft -y

In [None]:
import os
import sys
sys.path.append("/kaggle/input/pip-vlmfs")

import json  , math , timm , einops
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np 
import tqdm
from tqdm import tqdm

from transformers import AutoModelForCausalLM, AutoTokenizer 

import lightning as LIGHTNING
from lightning.pytorch.callbacks import Callback

from typing import List, Union , Dict

# LANGUAGE MODEL

In [None]:
# Directory holding both the causal language model weights and its tokenizer files.
model_path = "/kaggle/input/vlm-merger"

# Loading options: put the whole model on the first GPU, in bfloat16.
language_model_kwargs = dict(
    trust_remote_code=True,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)

language_model = AutoModelForCausalLM.from_pretrained(model_path, **language_model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# VISION MODEL

In [None]:
class TimmCNNModel(nn.Module):
    """Image classifier made of a timm backbone followed by an MLP head.

    Args:
        num_classes: Output size of the last linear layer of the head.
        model_name: Accepted but NOT used - the backbone is always created from the
            hard-coded local EfficientNet-B0 directory below, regardless of this value.
    """

    def __init__(self, num_classes: int = 8, model_name: str = "mobilenetv4_conv_medium.e500_r256_in1k"):
        super().__init__()

        # num_classes=0 asks timm for the model without its own classification layer,
        # so calling the backbone yields features that feed our head.
        self.backbone = timm.create_model(
            'local-dir:/kaggle/input/codefiles/efficientnet_b0/efficientnet_b0',
            pretrained=True,
            num_classes=0,
        )
        self.feature_dim = self.backbone.num_features

        # The layer order (and therefore the Sequential indices) must stay fixed:
        # checkpoints address these layers by position.
        head_layers = [
            nn.Dropout(0.1),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone output for ``x`` (no classification head applied)."""
        return self.backbone(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits: backbone features passed through the MLP head."""
        return self.classifier(self.forward_features(x))

# Build the classifier and restore its trained weights.
image_model = TimmCNNModel(num_classes=8)

# map_location="cpu" keeps the checkpoint tensors off the GPU; load_state_dict copies
# the values into the model's own parameters wherever those live.
weights = torch.load(
    "/kaggle/input/d/aneeshmukkamala/timmweights/finalcheckpoint.pth",
    map_location="cpu",
)
image_model.load_state_dict(weights['model_state_dict'])

# The image model is only used as a frozen feature extractor in this notebook.
image_model.requires_grad_(False)

# Freezing parameters does not change layer behaviour: without eval() the BatchNorm
# layers would normalise with per-batch statistics (and update their running stats)
# and Dropout would stay active during inference.
image_model.eval()



# PROJECTOR MODEL

In [None]:
class Projector_4to3d(nn.Module):
    """Turn a 4-D CNN feature map into a 3-D sequence of LLM-sized image tokens.

    Pipeline (see ``forward``): flatten the spatial grid, add learned positional
    embeddings, project to ``llm_dim``, run self-attention, optionally cross-attend to
    text embeddings, apply a feed-forward block, then compress to 32 tokens with
    learned query vectors.

    Args:
        cnn_dim: Channel count of the incoming feature map.
        llm_dim: Embedding size of the produced tokens.
        num_heads: Number of heads for every attention layer.
        dropout: Dropout probability used in the projection, attention and FFN layers.
    """

    def __init__(self, cnn_dim: int = 1280, llm_dim: int = 2048, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.cnn_dim = cnn_dim
        self.llm_dim = llm_dim

        # One learned positional embedding per spatial location; 64 positions means the
        # incoming feature map must have height * width == 64 (e.g. an 8x8 grid).
        self.spatial_pos_embed = nn.Parameter(torch.randn(64, cnn_dim))

        # These two layers are not used by forward(). They are kept as attributes so
        # the module's parameter names (and therefore its state_dict keys) are unchanged.
        self.spatial_conv = nn.Conv2d(cnn_dim, cnn_dim // 2, 1)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Per-location projection from CNN channels to the LLM embedding size.
        self.input_proj = nn.Sequential(
            nn.Linear(cnn_dim, llm_dim),
            nn.LayerNorm(llm_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Self-attention across spatial locations.
        self.spatial_attention = nn.MultiheadAttention(
            embed_dim=llm_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Cross-attention: image tokens (queries) attend to text embeddings (keys/values).
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=llm_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        self.norm1 = nn.LayerNorm(llm_dim)
        self.norm2 = nn.LayerNorm(llm_dim)

        # Position-wise feed-forward block with a 4x hidden expansion.
        self.ffn = nn.Sequential(
            nn.Linear(llm_dim, llm_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(llm_dim * 4, llm_dim),
            nn.Dropout(dropout)
        )

        self.norm3 = nn.LayerNorm(llm_dim)

        # Token compression: 32 learned queries attend over the image tokens, so the
        # output always has 32 tokens.
        self.compress_tokens = nn.Parameter(torch.randn(32, llm_dim))
        self.token_compression = nn.MultiheadAttention(
            embed_dim=llm_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise Linear (Xavier), LayerNorm (ones/zeros) and Conv2d (Kaiming) layers."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, cnn_features: torch.Tensor, text_embeddings: torch.Tensor = None) -> torch.Tensor:
        """Project a CNN feature map to a fixed-length sequence of image tokens.

        Args:
            cnn_features: Feature map of shape ``(B, cnn_dim, H, W)`` where ``H * W``
                equals the number of positional embeddings (64).
            text_embeddings: Optional ``(B, T, llm_dim)`` text embeddings. When given,
                the image tokens cross-attend to them; they are cast to float32 first.

        Returns:
            Tensor of shape ``(B, 32, llm_dim)``.

        Raises:
            ValueError: If ``cnn_features`` is not 4-D, has the wrong channel count, or
                its spatial size does not match the positional embeddings.
        """
        if cnn_features.dim() != 4:
            raise ValueError(
                f"cnn_features must be 4-D (B, C, H, W), got shape {tuple(cnn_features.shape)}"
            )
        batch_size, channels, height, width = cnn_features.shape
        num_positions = self.spatial_pos_embed.shape[0]
        # Checked explicitly because a 1x1 map would otherwise broadcast silently
        # against the positional embeddings instead of failing.
        if channels != self.cnn_dim or height * width != num_positions:
            raise ValueError(
                f"cnn_features must have shape (B, {self.cnn_dim}, H, W) with H * W == "
                f"{num_positions}, got {tuple(cnn_features.shape)}"
            )

        # Flatten the spatial grid into a token axis and add positional information.
        x = einops.rearrange(cnn_features, "b c h w -> b (h w) c")  # (B, H*W, cnn_dim)
        x = x + self.spatial_pos_embed.unsqueeze(0)  # broadcast over the batch axis

        x = self.input_proj(x)  # (B, H*W, llm_dim)

        # Residual self-attention over spatial locations.
        attended_x, _ = self.spatial_attention(x, x, x)
        x = self.norm1(x + attended_x)

        # Residual cross-attention to the text, only when text is supplied.
        if text_embeddings is not None:
            text_embeddings_float = text_embeddings.float()
            cross_attended, _ = self.cross_attention(x, text_embeddings_float, text_embeddings_float)
            x = self.norm2(x + cross_attended)

        # Residual feed-forward block.
        x = self.norm3(x + self.ffn(x))

        # Compress to a fixed number of tokens using the learned queries.
        compress_queries = self.compress_tokens.unsqueeze(0).expand(batch_size, -1, -1)
        compressed_x, _ = self.token_compression(compress_queries, x, x)

        return compressed_x


# Build the projector and restore its trained weights.
projector = Projector_4to3d(cnn_dim=1280, llm_dim=2048, num_heads=8)
checkpoint = "/kaggle/input/lmweights/vmweights/projector.pth"

# map_location="cpu" keeps the checkpoint tensors off the GPU; load_state_dict copies
# the values into the projector's own parameters.
weights = torch.load(checkpoint, map_location="cpu")
projector.load_state_dict(weights)

# The projector is frozen for this inference notebook.
projector.requires_grad_(False)

# Freezing parameters does not disable dropout; eval() does, so the generated
# captions are not perturbed by random dropout masks.
projector.eval()

# weights.keys()

# VISION-LANGUAGE MODEL

In [None]:
class Model(nn.Module):
    """Vision-language wrapper: image backbone -> projector -> causal language model.

    The language model receives, in order, the embedded text prompt, the projected
    image tokens and (in ``forward`` only) the embedded target text.

    Args:
        image_model: Module exposing ``backbone.forward_features``.
        language_model: Causal LM exposing ``get_input_embeddings``, ``generate`` and a
            forward pass accepting ``inputs_embeds`` / ``attention_mask`` / ``labels``.
        projector: Module mapping image features (and optional text embeddings) to
            tokens of the language model's embedding size.
        tokenizer: Tokenizer used for the prompt and the target texts.
        prompt: Instruction text whose embeddings are prepended to every sample.
    """

    def __init__(self, image_model, language_model, projector, tokenizer, prompt="Describe the medical image:"):
        super().__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.projector = projector
        self.tokenizer = tokenizer
        self.eos_token = tokenizer.eos_token
        self.prompt = prompt

        # Everything follows the device the language model was loaded on.
        device = next(self.language_model.parameters()).device

        self.image_model.to(device)
        self.projector.to(device)

        # Embed the prompt once; it is stored as a buffer and reused for every batch.
        prompt_tokens = tokenizer(text=prompt, return_tensors="pt").input_ids.to(device)
        prompt_embeddings = language_model.get_input_embeddings()(prompt_tokens).detach()
        self.register_buffer('prompt_embeddings', prompt_embeddings)

        # Learnable log-temperature for the contrastive logits. It is created on
        # `device` because it is the first parameter of this module and therefore the
        # one the `device` property reports; creating it on CPU would make that
        # property disagree with where the sub-models actually live.
        self.temperature = nn.Parameter(torch.ones([], device=device) * np.log(1 / 0.07))

        # Projection heads into the shared contrastive space. The input size is
        # hard-coded to 2048, so the projector / LM embedding size must be 2048 for
        # the contrastive branch of forward() to work.
        self.image_projection_head = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        ).to(device)
        self.text_projection_head = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        ).to(device)

    @property
    def device(self):
        """Device of this module's first parameter."""
        return next(self.parameters()).device

    def forward(self, patches: torch.Tensor, texts: List[str], compute_contrastive: bool = True):
        """Compute the language-modelling loss (plus optional contrastive loss).

        NOTE(review): this method reads ``CFG.MAX_LENGTH`` and ``CFG.LABEL_MASK`` from
        module scope. ``CFG`` is not defined in the visible part of this notebook and
        must exist before this method is called. ``LABEL_MASK`` is presumably the
        loss-ignore index expected by the language model - confirm.

        Args:
            patches: Batch of images accepted by ``image_model.backbone``.
            texts: One target string per image; the tokenizer's EOS token is appended.
            compute_contrastive: When True, add ``0.1 *`` a symmetric image/text
                contrastive loss to the language-modelling loss.

        Returns:
            Dict with ``"loss"`` (total loss), ``"logits"`` (language model logits) and
            ``"loss_breakdown"`` (the individual loss terms).
        """
        device = self.device
        patches = patches.to(device)

        image_features = self.image_model.backbone.forward_features(patches)

        tokenized = self.tokenizer(
            text=[text + self.tokenizer.eos_token for text in texts],
            padding=True,
            truncation=True,
            max_length=CFG.MAX_LENGTH,
            return_tensors="pt",
        )
        tokenized = {k: v.to(device) for k, v in tokenized.items()}
        text_embeddings = self.language_model.get_input_embeddings()(tokenized["input_ids"])

        # The projector cross-attends to the target text in this training path.
        patch_embeddings = self.projector(image_features, text_embeddings)

        patch_embeddings = patch_embeddings.to(torch.bfloat16)
        text_embeddings = text_embeddings.to(torch.bfloat16)

        # Sequence layout along dim 1: [prompt | image tokens | target text].
        embeddings = torch.cat([
            self.prompt_embeddings.expand(patches.size(0), -1, -1),
            patch_embeddings,
            text_embeddings,
        ], dim=1)

        # Prompt and image positions are always attended; text uses the tokenizer mask.
        prompt_mask = torch.ones(patches.size(0), self.prompt_embeddings.size(1), device=device)
        patch_mask = torch.ones(patches.size(0), patch_embeddings.size(1), device=device)
        attention_mask = torch.cat([prompt_mask, patch_mask, tokenized["attention_mask"]], dim=1)

        # Only real text tokens carry labels: prompt, image and padding positions are
        # filled with CFG.LABEL_MASK.
        prompt_labels = torch.full((patches.size(0), self.prompt_embeddings.size(1)), CFG.LABEL_MASK, device=device)
        patch_labels = torch.full((patches.size(0), patch_embeddings.size(1)), CFG.LABEL_MASK, device=device)
        text_labels = tokenized["input_ids"].clone()
        labels = torch.cat([prompt_labels, patch_labels, text_labels], dim=1)
        labels[attention_mask == 0] = CFG.LABEL_MASK

        llm_output = self.language_model(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            labels=labels
        )

        total_loss = llm_output.loss
        loss_dict = {"language_loss": llm_output.loss}

        if compute_contrastive:
            # Mean-pool tokens into one vector per image and one per text.
            # NOTE(review): the text mean runs over every position, padding included.
            image_global = patch_embeddings.mean(dim=1)
            text_global = text_embeddings.mean(dim=1)

            # Project into the contrastive space (heads run in float32) and L2-normalise.
            image_proj = self.image_projection_head(image_global.float())
            text_proj = self.text_projection_head(text_global.float())
            image_proj = F.normalize(image_proj, dim=-1)
            text_proj = F.normalize(text_proj, dim=-1)

            # Pairwise similarities scaled by exp(temperature); the i-th image matches
            # the i-th text, so the targets are the diagonal indices.
            logits = torch.matmul(image_proj, text_proj.t()) * self.temperature.exp()
            labels_contrastive = torch.arange(len(logits), device=device)

            # Symmetric loss: image->text and text->image, averaged.
            contrastive_loss = (F.cross_entropy(logits, labels_contrastive) +
                                F.cross_entropy(logits.t(), labels_contrastive)) / 2

            total_loss = total_loss + 0.1 * contrastive_loss
            loss_dict["contrastive_loss"] = contrastive_loss

        # Placeholder: no attention-entropy term is actually computed; a constant 0.0 is
        # reported and nothing is added to the total loss.
        if hasattr(self.projector, 'spatial_attention'):
            attn_entropy_loss = 0.0
            loss_dict["attention_entropy_loss"] = attn_entropy_loss

        return {
            "loss": total_loss,
            "logits": llm_output.logits,
            "loss_breakdown": loss_dict
        }

    @torch.no_grad()
    def generate(self, patches: torch.Tensor, generator_kwargs: dict[str, Union[int, float]]):
        """Generate text for a batch of images.

        Runs entirely without gradient tracking. The layer mode is not changed here:
        call ``.eval()`` on the model beforehand so Dropout / BatchNorm layers in the
        image model and projector behave deterministically.

        Args:
            patches: Batch of images accepted by ``image_model.backbone``.
            generator_kwargs: Extra keyword arguments forwarded unchanged to
                ``language_model.generate``.

        Returns:
            Whatever ``language_model.generate`` returns for the assembled inputs.
        """
        device = self.device
        patches = patches.to(device)

        # No text is available at inference, so the projector skips cross-attention.
        image_features = self.image_model.backbone.forward_features(patches)
        patch_embeddings = self.projector(image_features)
        patch_embeddings = patch_embeddings.to(torch.bfloat16)

        # Sequence layout along dim 1: [prompt | image tokens].
        embeddings = torch.cat([
            self.prompt_embeddings.expand(patches.size(0), -1, -1),
            patch_embeddings,
        ], dim=1)

        prompt_mask = torch.ones(patches.size(0), self.prompt_embeddings.size(1), device=device)
        patch_mask = torch.ones(patches.size(0), patch_embeddings.size(1), device=device)
        attention_mask = torch.cat([prompt_mask, patch_mask], dim=1)

        return self.language_model.generate(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            **generator_kwargs
        )

In [None]:
# Assemble the full vision-language model from the parts loaded above.
vlm_model = Model(image_model, language_model, projector , tokenizer, prompt="Describe this image:")
vlm_model = vlm_model.to(torch.device("cuda:0"))

# This notebook only runs inference. A freshly built nn.Module is in training mode, so
# without eval() the Dropout layers would randomly zero activations and the BatchNorm
# layers would use (and update from) per-batch statistics while generating.
vlm_model.eval()

In [None]:
import numpy as np

# Memory-map the image array so that only the requested slice is read from disk.
patches = np.load("/kaggle/input/miccaireg/images.npy", mmap_mode="r")

# Half-open index range [start, end) of the images to caption.
start, end = 5100, 5109

# Materialise the slice in memory, then convert it to a torch tensor.
patches_batch = torch.tensor(np.array(patches[start:end]))
patches_batch.shape

In [None]:
# Sampling settings forwarded to the language model's generate().
generator_kwargs = {
    "max_new_tokens": 50,
    "do_sample": True,
    "temperature": 0.4,
    "top_p": 0.9,
    "pad_token_id": tokenizer.eos_token_id
}

# do_sample=True draws random tokens; seeding makes a Restart & Run All reproduce the
# same captions. Change SEED to obtain a different sample.
SEED = 42
torch.manual_seed(SEED)

# Images per generate() call.
batch_size = 3
generated_ids = []

# Inference only: make sure no autograd graph is recorded while generating.
with torch.no_grad():
    for i in tqdm(range(0, len(patches_batch), batch_size)):
        batch_chunk = patches_batch[i:i+batch_size]
        chunk_ids = vlm_model.generate(batch_chunk, generator_kwargs)
        # Iterating the returned tensor yields one id sequence per image.
        generated_ids.extend(chunk_ids)

In [None]:
# Inspect the first entry collected in generated_ids.
generated_ids[0]

In [None]:
# Decode every generated id sequence to text, dropping special tokens and trimming
# surrounding whitespace.
generated_texts = [
    tokenizer.decode(token_ids, skip_special_tokens=True).strip()
    for token_ids in generated_ids
]

# Save the decoded texts as a JSON list.
with open("inference.json" ,"w") as out_file:
    json.dump(generated_texts, out_file, indent=2)

In [None]:
# with open("/kaggle/input/codefiles/train.json" , "r") as f:
#     d = json.load(f)

# d[start:end]

In [None]:
# Display the decoded texts.
generated_texts

In [None]:
# Show the current GPU status as reported by the nvidia-smi command-line tool.
!nvidia-smi