In [None]:
!pip uninstall transformers timm accelerate peft unsloth bitsandbytes xformers -y

In [None]:
%%capture
import os

# Disable Weights & Biases logging through both environment switches.
os.environ.update({"WANDB_DISABLED": "true", "WANDB_MODE": "disabled"})

In [None]:
%%capture

import sys

sys.path.append("/kaggle/input/pip-unsloth")
sys.path.append("/kaggle/input/pip-vlmfs")
# sys.path.append("/kaggle/input/pip-quantization")

import unsloth
from unsloth import FastModel
import json  , math , timm , einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np 

from transformers import AutoModelForCausalLM, AutoTokenizer 

import lightning as LIGHTNING
from lightning.pytorch.callbacks import Callback

from typing import List, Union , Dict

In [None]:
# Load the base checkpoint together with its tokenizer through Unsloth:
# bfloat16 weights, no 4-bit quantisation, no full fine-tuning.
model, tokenizer = FastModel.from_pretrained(
    model_name="/kaggle/input/gemma3",
    dtype=torch.bfloat16,
    load_in_4bit=False,
    full_finetuning=False,
)

In [None]:
# Wrap the base model with LoRA adapters.
# Only the language stack is adapted (attention and MLP modules); the vision
# and audio towers are left untouched.
language_model = FastModel.get_peft_model(
    model,
    # which parts of the network receive adapters
    finetune_vision_layers=False,
    finetune_audio_layers=False,
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,
    # adapter hyper-parameters: rank 256 with alpha = 4 x rank
    r=256,
    lora_alpha=4 * 256,
    lora_dropout=0.04,
    bias="none",
)

In [None]:
# Cast the whole adapter-wrapped model to bfloat16, then print how many
# parameters are trainable.
language_model = language_model.to(dtype=torch.bfloat16)
language_model.print_trainable_parameters()

# IMAGE COMPONENTS

### IMAGE MODEL

In [None]:
class TimmCNNModel(nn.Module):
    """Image classifier made of a timm backbone and a small MLP head.

    The backbone is created with ``num_classes=0`` (per timm, this builds the
    network without its classification layer), and its ``num_features`` is
    used as the input width of the head.

    Args:
        num_classes: number of logits produced by the classifier head.
        model_name: descriptive name of the architecture. It is stored on the
            instance but does NOT select the weights; it is kept so existing
            callers that pass it keep working.
        backbone_source: string handed to ``timm.create_model`` as the model
            identifier. Defaults to the local EfficientNet-B0 directory this
            notebook was written against.
    """
    # timm/mobilenetv4_conv_medium.e500_r256_in1k

    def __init__(
        self,
        num_classes: int = 8,
        model_name: str = "efficientnet_b0",
        backbone_source: str = "local-dir:/kaggle/input/codefiles/efficientnet_b0/efficientnet_b0",
    ):
        super().__init__()

        self.model_name = model_name

        self.backbone = timm.create_model(
            backbone_source,
            pretrained=True,
            num_classes=0,
        )

        self.feature_dim = self.backbone.num_features

        # NOTE: layer order is part of the checkpoint format (state-dict keys
        # are positional, e.g. ``classifier.1.weight``) - do not reorder.
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),

            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes)
        )

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone's output for ``x`` (the input to the classifier head)."""
        return self.backbone(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for the image batch ``x``."""
        features = self.forward_features(x)
        logits = self.classifier(features)
        return logits


image_model = TimmCNNModel(num_classes=8)

# Deserialize onto the CPU so loading does not depend on the device the
# checkpoint was saved from; the module is moved to its final device later.
weights = torch.load("/kaggle/input/timmweights/finalcheckpoint.pth", map_location="cpu")
image_model.load_state_dict(weights['model_state_dict'])

# Freeze the image model: no gradients, and inference mode so that BatchNorm
# uses (and does not update) the loaded running statistics and Dropout is off.
# NOTE: calling `.train()` on a parent module switches training mode back on;
# re-apply `image_model.eval()` afterwards if that happens.
image_model.requires_grad_(False)
image_model.eval()

In [None]:
# Show the backbone feature width (taken from `backbone.num_features`) as the cell output.
image_model.feature_dim

### PROJECTOR

In [None]:
class Projector_4to3d(nn.Module):
    """Map a 4-D CNN feature map ``(B, C, H, W)`` to a 3-D sequence of LLM-sized tokens.

    Pipeline (see ``forward``): flatten the spatial grid into ``H*W`` tokens,
    add a learned positional embedding, project to ``llm_dim``, apply
    self-attention, optionally cross-attend to text embeddings, apply a
    feed-forward block, and finally pool everything into ``num_query_tokens``
    output tokens using learned queries.

    Args:
        cnn_dim: channel count ``C`` of the incoming feature map.
        llm_dim: embedding width of the produced tokens.
        num_heads: number of heads in every attention layer (must divide ``llm_dim``).
        dropout: dropout probability used in the projection, attention and FFN layers.
        num_positions: number of spatial positions ``H*W`` covered by the
            positional table (default 64, i.e. an 8x8 grid).
        num_query_tokens: number of learned query tokens, which is also the
            length of the returned sequence (default 32).
    """

    def __init__(
        self,
        cnn_dim: int = 1280,
        llm_dim: int = 2048,
        num_heads: int = 8,
        dropout: float = 0.1,
        num_positions: int = 64,
        num_query_tokens: int = 32,
    ):
        super().__init__()
        self.cnn_dim = cnn_dim
        self.llm_dim = llm_dim
        self.num_positions = num_positions
        self.num_query_tokens = num_query_tokens

        # Learned positional embedding, one row per spatial position.
        self.spatial_pos_embed = nn.Parameter(torch.randn(num_positions, cnn_dim))

        # These two modules are not evaluated in ``forward`` (their outputs were
        # never consumed). They are still constructed so that parameter names
        # and shapes stay compatible with previously saved state dicts.
        self.spatial_conv = nn.Conv2d(cnn_dim, cnn_dim // 2, 1)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Per-token projection from CNN width to LLM width.
        self.input_proj = nn.Sequential(
            nn.Linear(cnn_dim, llm_dim),
            nn.LayerNorm(llm_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Self-attention over the spatial tokens.
        self.spatial_attention = nn.MultiheadAttention(
            embed_dim=llm_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Cross-attention: spatial tokens (queries) attend to text embeddings.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=llm_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        self.norm1 = nn.LayerNorm(llm_dim)
        self.norm2 = nn.LayerNorm(llm_dim)

        # Position-wise feed-forward block with a 4x hidden expansion.
        self.ffn = nn.Sequential(
            nn.Linear(llm_dim, llm_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(llm_dim * 4, llm_dim),
            nn.Dropout(dropout)
        )

        self.norm3 = nn.LayerNorm(llm_dim)

        # Learned queries that pool the spatial tokens into a shorter sequence.
        self.compress_tokens = nn.Parameter(torch.randn(num_query_tokens, llm_dim))
        self.token_compression = nn.MultiheadAttention(
            embed_dim=llm_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise Linear (Xavier), LayerNorm (1/0) and Conv2d (Kaiming) sub-modules."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, cnn_features: torch.Tensor, text_embeddings: torch.Tensor = None) -> torch.Tensor:
        """Project a CNN feature map to ``num_query_tokens`` tokens of width ``llm_dim``.

        Args:
            cnn_features: tensor of shape ``(B, cnn_dim, H, W)`` with
                ``H * W == num_positions``.
            text_embeddings: optional ``(B, T, llm_dim)`` tensor. When given, the
                spatial tokens cross-attend to it (after a cast to float32);
                when ``None`` the cross-attention step is skipped.

        Returns:
            Tensor of shape ``(B, num_query_tokens, llm_dim)``.

        Raises:
            ValueError: if ``cnn_features`` is not 4-D or its spatial grid does
                not contain exactly ``num_positions`` cells.
        """
        if cnn_features.dim() != 4:
            raise ValueError(
                f"cnn_features must be 4-D (B, C, H, W), got shape {tuple(cnn_features.shape)}"
            )
        batch_size, _, height, width = cnn_features.shape
        if height * width != self.num_positions:
            # Without this check a mismatched grid either fails with an opaque
            # broadcasting error or (for a 1x1 grid) silently broadcasts.
            raise ValueError(
                f"expected {self.num_positions} spatial positions, got {height}x{width}={height * width}"
            )

        # (B, C, H, W) -> (B, H*W, C); row-major over (H, W).
        x = cnn_features.flatten(2).transpose(1, 2)
        x = x + self.spatial_pos_embed  # broadcast over the batch dimension

        # Project every token to the LLM width.
        x = self.input_proj(x)

        # Self-attention over spatial tokens (post-norm residual).
        attended_x, _ = self.spatial_attention(x, x, x)
        x = self.norm1(x + attended_x)

        # Optional cross-attention to the text embeddings.
        if text_embeddings is not None:
            text_embeddings_float = text_embeddings.float()
            cross_attended, _ = self.cross_attention(x, text_embeddings_float, text_embeddings_float)
            x = self.norm2(x + cross_attended)

        # Feed-forward block (post-norm residual).
        ffn_out = self.ffn(x)
        x = self.norm3(x + ffn_out)

        # Pool the spatial tokens into the learned query tokens.
        compress_queries = self.compress_tokens.unsqueeze(0).expand(batch_size, -1, -1)
        compressed_x, _ = self.token_compression(compress_queries, x, x)

        return compressed_x


# Projector instance: 1280-channel CNN features in, 2048-wide tokens out, 8 attention heads.
projector = Projector_4to3d(cnn_dim=1280, llm_dim=2048, num_heads=8)

Note: LoRA weights are always in float32.

# DATALOADER

In [None]:
class CFG:
    """Training configuration shared by the cells of this notebook."""

    # --- schedule ---
    EPOCHS = 18
    GRAD_ACC = 1
    TRAIN_BATCH_SIZE = 10

    # --- half-open range of sample indices used for training ---
    TRAIN_START = 0
    TRAIN_END = 5_000

    # --- text handling ---
    LABEL_MASK = -100   # label value written at positions excluded from the language loss
    MAX_LENGTH = 50     # tokenizer truncation length

    # --- learning rates ---
    VM_LR = 2e-4        # projector, contrastive heads and temperature
    LLM_LR = 2e-5       # trainable language-model parameters


In [None]:
# Load the texts from JSON.
with open("/kaggle/input/miccaireg/labels.json", "r") as f:
    texts = json.load(f)

# Memory-map the image array read-only instead of reading it fully into RAM.
patches = np.load("/kaggle/input/miccaireg/images.npy" , mmap_mode="r")

# Sanity check: the number of images should match the number of texts.
print(f"Patches shape: {np.shape(patches)}, Texts length: {len(texts)}")

In [None]:
import matplotlib.pyplot as plt

# Tokenized length of every text, in dataset order; `l` is read by the next cell.
l = [len(tokenizer(text=caption)["input_ids"][0]) for caption in texts]

plt.plot(l)

In [None]:
# Number of texts whose tokenized length is below the truncation limit.
# The limit comes from CFG so this check cannot drift from the value used for training.
# NOTE(review): strict `<` presumably leaves room for the EOS token appended at
# training time - confirm.
c = sum(1 for n_tokens in l if n_tokens < CFG.MAX_LENGTH)

c

In [None]:
class REGDataset(Dataset):
    """Map-style dataset over the half-open window ``[start_idx, end_idx)`` of paired patches and texts.

    Args:
        patches_mmap: indexable array of image patches (e.g. a NumPy memmap);
            item ``i`` is converted to a tensor on access.
        texts: texts aligned with ``patches_mmap`` by index.
        start_idx: first absolute index included in the window.
        end_idx: first absolute index NOT included in the window.

    Raises:
        ValueError: if the window is negative/inverted or extends past the data.
    """

    def __init__(self, patches_mmap, texts: list[str], start_idx: int, end_idx: int):
        if not 0 <= start_idx <= end_idx:
            raise ValueError(f"invalid window [{start_idx}, {end_idx})")
        available = min(len(patches_mmap), len(texts))
        if end_idx > available:
            raise ValueError(
                f"window [{start_idx}, {end_idx}) exceeds the {available} available samples"
            )
        self.patches_mmap = patches_mmap
        self.texts = texts
        self.start_idx = start_idx
        self.end_idx = end_idx

    def __len__(self):
        return self.end_idx - self.start_idx

    def __getitem__(self, idx):
        """Return ``(patch_tensor, text)`` for the ``idx``-th sample of the window.

        Negative indices count from the end of the window. Indices outside the
        window raise ``IndexError`` so that samples beyond ``end_idx`` are never
        returned and plain iteration over the dataset terminates.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            raise IndexError(f"index {idx} out of range for dataset of length {length}")

        actual_idx = self.start_idx + idx
        # torch.tensor copies the row, so a read-only memmap is fine here.
        return torch.tensor(self.patches_mmap[actual_idx]), self.texts[actual_idx]


# Training split and its loader: shuffled, incomplete final batch dropped,
# batches placed in pinned memory.
train_data = REGDataset(patches, texts, start_idx=CFG.TRAIN_START, end_idx=CFG.TRAIN_END)
train_dl = DataLoader(
    train_data,
    batch_size=CFG.TRAIN_BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)

# MODEL

In [None]:
class Model(nn.Module):
    """Image-to-text model: image backbone -> projector -> language model.

    The language model receives ``[prompt embeddings | projected image tokens |
    text embeddings]`` as ``inputs_embeds``; only the text positions contribute
    to the language-modelling loss. An optional CLIP-style contrastive loss
    between pooled image tokens and pooled text embeddings is added during
    training.

    Args:
        image_model: module exposing ``backbone.forward_features``.
        language_model: causal LM accepting ``inputs_embeds``/``attention_mask``/``labels``.
        projector: module mapping image features (and optional text embeddings)
            to a token sequence of the LM's embedding width.
        tokenizer: tokenizer matching ``language_model``.
        prompt: fixed instruction placed in front of the image tokens.

    NOTE(review): in ``forward`` the projector cross-attends to the embeddings
    of the target text, while ``generate`` calls it without text. This looks
    like a train/inference mismatch (and lets the image tokens see the target);
    the padding mask is also not passed to that cross-attention. Confirm intent.
    """

    def __init__(self, image_model, language_model, projector, tokenizer, prompt="Describe the medical image:"):
        super().__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.projector = projector
        self.tokenizer = tokenizer
        self.eos_token = tokenizer.eos_token
        self.prompt = prompt

        # Everything created here follows the language model's device.
        device = next(self.language_model.parameters()).device

        self.image_model.to(device)
        self.projector.to(device)

        # The prompt is embedded once and stored as a (non-trainable) buffer.
        prompt_tokens = tokenizer(text=prompt, return_tensors="pt").input_ids.to(device)
        prompt_embeddings = language_model.get_input_embeddings()(prompt_tokens).detach()
        self.register_buffer('prompt_embeddings', prompt_embeddings)

        # Width of the LM embeddings; the contrastive heads consume tensors of this width.
        hidden_dim = prompt_embeddings.size(-1)

        # Contrastive learning components (log-temperature initialised to log(1/0.07)).
        self.temperature = nn.Parameter(torch.ones([], device=device) * np.log(1 / 0.07))
        self.image_projection_head = nn.Sequential(
            nn.Linear(hidden_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        ).to(device)
        self.text_projection_head = nn.Sequential(
            nn.Linear(hidden_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        ).to(device)

        self._keep_frozen_image_model_in_eval()

    def _keep_frozen_image_model_in_eval(self):
        """Put the image model in eval mode when none of its parameters is trainable."""
        if not any(p.requires_grad for p in self.image_model.parameters()):
            self.image_model.eval()

    def train(self, mode: bool = True):
        """Set training mode, but keep a fully frozen image model in eval mode.

        A frozen backbone left in training mode would still normalise with batch
        statistics, update its BatchNorm running statistics and apply dropout,
        so its features would differ from those produced at inference time.
        """
        super().train(mode)
        if mode:
            self._keep_frozen_image_model_in_eval()
        return self

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def forward(self, patches: torch.Tensor, texts: List[str], compute_contrastive: bool = True):
        """Compute the training loss for a batch of images and target texts.

        Args:
            patches: batch of images, first dimension is the batch size.
            texts: one target string per image; the EOS token is appended here.
            compute_contrastive: when True, add ``0.1 *`` the image/text
                contrastive loss to the language-modelling loss.

        Returns:
            dict with ``"loss"`` (total loss), ``"logits"`` (LM logits) and
            ``"loss_breakdown"`` (the individual loss terms by name).
        """
        device = self.device
        patches = patches.to(device)
        batch_size = patches.size(0)

        image_features = self.image_model.backbone.forward_features(patches)

        tokenized = self.tokenizer(
            text=[text + self.eos_token for text in texts],
            padding=True,
            truncation=True,
            max_length=CFG.MAX_LENGTH,
            return_tensors="pt",
        )
        tokenized = {k: v.to(device) for k, v in tokenized.items()}
        text_mask = tokenized["attention_mask"]
        text_embeddings = self.language_model.get_input_embeddings()(tokenized["input_ids"])

        patch_embeddings = self.projector(image_features, text_embeddings)

        patch_embeddings = patch_embeddings.to(torch.bfloat16)
        text_embeddings = text_embeddings.to(torch.bfloat16)

        # Sequence layout: [prompt | image tokens | text].
        embeddings = torch.cat([
            self.prompt_embeddings.expand(batch_size, -1, -1),
            patch_embeddings,
            text_embeddings,
        ], dim=1)

        # Prompt and image positions are always attended to. The prefix mask is
        # created with the tokenizer mask's dtype so the concatenation stays an
        # integer mask instead of being promoted to float.
        prefix_len = self.prompt_embeddings.size(1) + patch_embeddings.size(1)
        prefix_mask = torch.ones(batch_size, prefix_len, dtype=text_mask.dtype, device=device)
        attention_mask = torch.cat([prefix_mask, text_mask], dim=1)

        # Only text positions are supervised; prompt, image and padding
        # positions carry CFG.LABEL_MASK.
        prefix_labels = torch.full((batch_size, prefix_len), CFG.LABEL_MASK, dtype=torch.long, device=device)
        labels = torch.cat([prefix_labels, tokenized["input_ids"]], dim=1)
        labels[attention_mask == 0] = CFG.LABEL_MASK

        llm_output = self.language_model(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            labels=labels
        )

        total_loss = llm_output.loss
        loss_dict = {"language_loss": llm_output.loss}

        if compute_contrastive:
            # Global image representation: mean over the image tokens.
            image_global = patch_embeddings.mean(dim=1)

            # Global text representation: mean over REAL tokens only, so the
            # result does not depend on how much padding the batch required.
            token_weights = text_mask.unsqueeze(-1).float()
            text_global = (text_embeddings.float() * token_weights).sum(dim=1) / token_weights.sum(dim=1).clamp(min=1.0)

            # Project to the contrastive space and L2-normalise.
            image_proj = F.normalize(self.image_projection_head(image_global.float()), dim=-1)
            text_proj = F.normalize(self.text_projection_head(text_global), dim=-1)

            # Symmetric cross-entropy over the similarity matrix; matching
            # image/text pairs lie on the diagonal.
            logits = torch.matmul(image_proj, text_proj.t()) * self.temperature.exp()
            labels_contrastive = torch.arange(len(logits), device=device)

            contrastive_loss = (F.cross_entropy(logits, labels_contrastive) +
                                F.cross_entropy(logits.t(), labels_contrastive)) / 2

            total_loss = total_loss + 0.1 * contrastive_loss
            loss_dict["contrastive_loss"] = contrastive_loss

        # Placeholder: no attention-entropy regulariser is computed yet, the
        # entry is reported as a constant 0.0.
        if hasattr(self.projector, 'spatial_attention'):
            attn_entropy_loss = 0.0
            loss_dict["attention_entropy_loss"] = attn_entropy_loss

        return {
            "loss": total_loss,
            "logits": llm_output.logits,
            "loss_breakdown": loss_dict
        }

    def generate(self, patches: torch.Tensor, generator_kwargs: dict[str, Union[int, float]]):
        """Generate text for a batch of images.

        Args:
            patches: batch of images, first dimension is the batch size.
            generator_kwargs: keyword arguments forwarded to
                ``language_model.generate``.

        Returns:
            Whatever ``language_model.generate`` returns.
        """
        device = self.device
        patches = patches.to(device)
        batch_size = patches.size(0)

        # No gradients are needed to build the generation prefix.
        with torch.no_grad():
            image_features = self.image_model.backbone.forward_features(patches)
            patch_embeddings = self.projector(image_features).to(torch.bfloat16)

        embeddings = torch.cat([
            self.prompt_embeddings.expand(batch_size, -1, -1),
            patch_embeddings,
        ], dim=1)

        # Every prefix position is attended to.
        attention_mask = torch.ones(batch_size, embeddings.size(1), dtype=torch.long, device=device)

        return self.language_model.generate(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            **generator_kwargs
        )

In [None]:
class LightningModule(LIGHTNING.LightningModule):
    """Lightning wrapper that trains ``Model`` with manual optimisation."""

    def __init__(self, model: Model):
        super().__init__()
        self.model = model
        # backward / clip / step are driven by hand in training_step
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        """Run one optimisation step on ``batch`` and log every loss term."""
        optimizer = self.optimizers()
        scheduler = self.lr_schedulers()

        patches, texts = batch
        output = self.model(patches, texts, compute_contrastive=True)
        total_loss = output["loss"]

        # backward -> clip (global norm 1.0) -> optimizer step -> scheduler step
        self.manual_backward(total_loss)
        self.clip_gradients(optimizer, gradient_clip_val=1.0, gradient_clip_algorithm="norm")
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        self.log("train_loss", total_loss, prog_bar=True)
        for loss_name, loss_value in output["loss_breakdown"].items():
            self.log(loss_name, loss_value)

        return total_loss

    def configure_optimizers(self):
        """Build AdamW with per-component parameter groups and a one-cycle LR schedule."""
        net = self.model

        # Group order matters: OneCycleLR below receives one max_lr per group.
        param_groups = [
            {"params": component.parameters(), "lr": CFG.VM_LR, "weight_decay": 1e-4}
            for component in (net.projector, net.image_projection_head, net.text_projection_head)
        ]
        param_groups.append(
            {"params": [net.temperature], "lr": CFG.VM_LR, "weight_decay": 0.0}
        )
        param_groups.append(
            {
                "params": [p for p in net.language_model.parameters() if p.requires_grad],
                "lr": CFG.LLM_LR,
                "weight_decay": 1e-5,
            }
        )

        optimizer = torch.optim.AdamW(param_groups, eps=1e-8)

        # One-cycle schedule: first 10% of the steps ramp up, then cosine annealing.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=[group["lr"] for group in optimizer.param_groups],
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.1,
            anneal_strategy='cos',
            div_factor=25.0,
            final_div_factor=1000.0,
        )

        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]


In [None]:
# Assemble the full model and wrap it for Lightning.
# Note: this rebinds `model`, which until here held the base model returned by
# FastModel.from_pretrained.
model = Model(image_model, language_model, projector , tokenizer)
lightning_module = LightningModule(model)

In [None]:
# Per-step loss histories, filled by PrintLossCallback and saved at the end of the notebook.
step_train_losses = []
step_language_losses = []
step_contrastive_losses = []
step_attention_entropy_losses = []

class PrintLossCallback(Callback):
    """Record per-step losses in the module-level lists and print a summary per epoch."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Append the current value of every known loss metric to its history list."""
        metrics = trainer.callback_metrics

        # (accepted metric names, destination list). The loss terms are logged
        # under their bare names ("language_loss", ...); the "train_"-prefixed
        # spelling is accepted as well so either naming fills the histories.
        targets = (
            (("train_loss",), step_train_losses),
            (("train_language_loss", "language_loss"), step_language_losses),
            (("train_contrastive_loss", "contrastive_loss"), step_contrastive_losses),
            (("train_attention_entropy_loss", "attention_entropy_loss"), step_attention_entropy_losses),
        )
        for names, history in targets:
            for name in names:
                if name in metrics:
                    history.append(float(metrics[name]))
                    break  # record each loss at most once per step

    def on_train_epoch_end(self, trainer, pl_module):
        """Print every metric currently held by the trainer, formatted to 4 decimals."""
        metrics = trainer.callback_metrics
        text = f"\nEpoch {trainer.current_epoch} Summary:"
        for key, value in metrics.items():
            text += f"  {key}: {value:.4f}"

        print(text)

In [None]:
# No checkpoints are written by the Trainer. Gradient clipping is not configured
# here because the LightningModule clips manually inside training_step.
trainer = LIGHTNING.Trainer(
    max_epochs=CFG.EPOCHS,
    accumulate_grad_batches=CFG.GRAD_ACC,
    enable_progress_bar=True,
    enable_checkpointing=False,
    log_every_n_steps=100,
    callbacks=[PrintLossCallback()],
)

In [None]:
# Distinct parameter dtypes present in the language model.
# Written as a single expression so the earlier notebook variables `l` and `c`
# are not overwritten with unrelated values.
{param.dtype for param in language_model.parameters()}

In [None]:
# Start training on the training dataloader (no datamodule is used).
trainer.fit(model=lightning_module, train_dataloaders=train_dl)

In [None]:
# Output locations. The directories are created at the same absolute paths the
# files are written to, so saving does not depend on the current working directory.
LORA_DIR = "/kaggle/working/lmweights_lora/"
PROJECTOR_DIR = "/kaggle/working/vmweights"

os.makedirs(PROJECTOR_DIR, exist_ok=True)
os.makedirs(LORA_DIR, exist_ok=True)

# Persist the adapter-wrapped language model and the projector weights.
lightning_module.model.language_model.save_pretrained(LORA_DIR)
torch.save(lightning_module.model.projector.state_dict(), os.path.join(PROJECTOR_DIR, "projector.pth"))

In [None]:
# Save every per-step loss history as its own .npy file. Saving stays
# best-effort, but each file is handled separately so one failure does not
# skip the remaining files, and the message names the file that failed.
loss_histories = {
    'step_train_losses.npy': step_train_losses,
    'step_language_losses.npy': step_language_losses,
    'step_contrastive_losses.npy': step_contrastive_losses,
    'step_attention_entropy_losses.npy': step_attention_entropy_losses,
}

for filename, history in loss_histories.items():
    try:
        np.save(filename, np.array(history))
    except Exception as e:
        print(f"Could not save {filename}: {e}")