In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install -r /kaggle/input/final-data/requirements_clip.txt

In [2]:
import torch
import clip
import cv2
import numpy as np
from PIL import Image 
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW # Corrected import
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.optim as optim
import json
import os
import time
from tqdm import tqdm

In [3]:
class CricketCommentaryDataset(Dataset):
    def __init__(self, annotations, clip_model, preprocess, num_frames=8):
        self.annotations = annotations
        self.clip_model = clip_model
        self.preprocess = preprocess
        self.num_frames = num_frames
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __len__(self):
        return len(self.annotations)

    def extract_frames(self, video_path, start_time, end_time):
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("video not opened")
            return torch.zeros(self.num_frames, 3, 224, 224)

        fps = cap.get(cv2.CAP_PROP_FPS)
        start_frame = int(start_time * fps)
        end_frame = int(end_time * fps)

        if start_frame >= end_frame:
            return torch.zeros(self.num_frames, 3, 224, 224)

        stride = max(1, (end_frame - start_frame) // self.num_frames)
        frames = []

        for i in range(start_frame, end_frame, stride):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if ret:
                # Action-focused cropping
                h, w, _ = frame.shape
                crop_size = min(h, w) // 2
                y_start = max(0, (h - crop_size) // 2)
                x_start = max(0, (w - crop_size) // 2)
                cropped = frame[y_start:y_start+crop_size, x_start:x_start+crop_size]

                cropped = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(cropped)
                frames.append(self.preprocess(pil_image))
            if len(frames) >= self.num_frames:
                break
        
        # Always ensure we return exactly num_frames
        if len(frames) < self.num_frames:
            num_pad = self.num_frames - len(frames)
            frames.extend([torch.zeros(3, 224, 224)] * num_pad)

        cap.release()
        return torch.stack(frames)

    def __getitem__(self, idx):
        ann = self.annotations[idx]
        frames = self.extract_frames(
            ann["video_path"],
            ann["start_time"],
            ann["end_time"]
        )

        # Use the prompt and response directly
        prompt = ann["prompt"]
        response = ann["response"]

        return {
            "frames": frames,
            "prompt": prompt,
            "response": response
        }
    
class TemporalTransformerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, num_layers, num_frames, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_frames = num_frames

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        self.position_embed = nn.Parameter(torch.zeros(1, num_frames + 1, embed_dim))
        nn.init.trunc_normal_(self.position_embed, std=0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        B = x.size(0)
        cls_token = self.cls_token.expand(B, 1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.position_embed[:, :x.size(1)]
        x = self.transformer(x)
        return {
            "cls": x[:, 0],
            "tokens": x[:, 1:]
        }

class CricketCommentator(nn.Module):
    def __init__(self, train_mode=False, num_frames=8, gpt2_train_layers=2):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.num_frames = num_frames

        import clip
        self.clip, self.preprocess = clip.load("ViT-B/32", device=self.device)
        self.clip = self.clip.float()

        if train_mode:
            for param in self.clip.parameters():
                param.requires_grad = False

        self.temporal_encoder = TemporalTransformerEncoder(
            embed_dim=512,
            num_heads=8,
            num_layers=3,
            num_frames=num_frames,
            dropout=0.1
        ).to(self.device).float()

        self.projection = nn.Sequential(
            nn.Linear(512, 1024),
            nn.GELU(),
            nn.LayerNorm(1024),
            nn.Dropout(0.1),
            nn.Linear(1024, 1024),
            nn.Tanh()
        ).to(self.device).float()

        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.gpt2 = GPT2LMHeadModel.from_pretrained("gpt2-medium").to(self.device).float()
        self.gpt2.config.pad_token_id = self.tokenizer.eos_token_id

        # 🔒 Freeze all GPT-2 parameters
        for param in self.gpt2.parameters():
            param.requires_grad = False

        # 🔓 Unfreeze last N transformer blocks
        if train_mode and gpt2_train_layers > 0:
            for block in self.gpt2.transformer.h[-gpt2_train_layers:]:
                for param in block.parameters():
                    param.requires_grad = True

            for param in self.gpt2.lm_head.parameters():
                param.requires_grad = True
            for param in self.gpt2.transformer.ln_f.parameters():
                param.requires_grad = True

    def forward(self, frames):
        batch_size = frames.shape[0]
        frames = frames.view(-1, 3, 224, 224)
        with torch.no_grad():
            frame_features = self.clip.encode_image(frames.to(self.device))
        frame_features = frame_features.view(batch_size, self.num_frames, -1).float()
        frame_features = F.normalize(frame_features, p=2, dim=-1)

        temporal_out = self.temporal_encoder(frame_features)
        visual_embeds = self.projection(temporal_out["cls"])
        return F.normalize(visual_embeds, p=2, dim=-1).unsqueeze(1)

    def compute_loss(self, batch):
        frames = batch["frames"].to(self.device)
        prompts = batch["prompt"]
        responses = batch["response"]

        visual_embeds = self.forward(frames)

        full_texts = [f"{p} {r}" for p, r in zip(prompts, responses)]
        inputs = self.tokenizer(
            full_texts,
            return_tensors="pt",
            padding='longest',
            truncation=True,
            max_length=128
        ).to(self.device)

        prompt_inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding='longest',
            truncation=True,
            max_length=128
        ).to(self.device)
        prompt_lengths = prompt_inputs.attention_mask.sum(dim=1)

        text_embeddings = self.gpt2.transformer.wte(inputs.input_ids)
        input_embeddings = torch.cat([visual_embeds, text_embeddings], dim=1)

        visual_mask = torch.ones(visual_embeds.shape[:2]).to(self.device)
        combined_mask = torch.cat([visual_mask, inputs.attention_mask], dim=1)

        labels = inputs.input_ids.clone()
        labels = torch.cat([
            -100 * torch.ones(labels.size(0), 1, dtype=torch.long).to(self.device),
            labels
        ], dim=1)

        for i, plen in enumerate(prompt_lengths):
            labels[i, 1:1 + plen] = -100

        outputs = self.gpt2(
            inputs_embeds=input_embeddings,
            attention_mask=combined_mask,
            labels=labels
        )

        return outputs.loss


def collate_fn(batch):
    """Custom collate function to handle frames"""
    frames = [item["frames"] for item in batch]
    prompts = [item["prompt"] for item in batch]
    responses = [item["response"] for item in batch]
    
    # Stack all frames
    frames_tensor = torch.stack(frames)
    
    return {
        "frames": frames_tensor,
        "prompt": prompts,
        "response": responses
    }

def train_model(model, train_loader, val_loader, epochs, lr):
    """Train visual projection layers with enhanced settings"""
    device = model.device
    # Collect GPT-2 trainable parameters
    gpt2_trainable = [p for p in model.gpt2.parameters() if p.requires_grad]
 
    # Unfreeze temporal encoder and projection
    for param in model.temporal_encoder.parameters():
        param.requires_grad = True
    for param in model.projection.parameters():
        param.requires_grad = True
        
    optimizer = AdamW([
        {'params': model.temporal_encoder.parameters(), 'lr': lr},
        {'params': model.projection.parameters(), 'lr': lr},
        {'params': gpt2_trainable, 'lr': lr * 0.5}
    ], weight_decay=0.01)
    
    # Learning rate schedulers
    total_steps = len(train_loader) * epochs
    warmup_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )
    
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='min',
        factor=0.5,
        patience=3,
        verbose=True
    )
    
    # Early stopping
    best_val_loss = float('inf')
    early_stop_patience = 5
    epochs_no_improve = 0
    early_stop = False
    
    for epoch in range(epochs):
        if early_stop:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break
            
        print(f"Epoch {epoch+1}/{epochs}")
        model.train()
        train_loss = 0.0
        
        # Training
        for batch in tqdm(train_loader, desc="Training"):
            optimizer.zero_grad()
            loss = model.compute_loss(batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            warmup_scheduler.step()
            train_loss += loss.item()
        
        avg_train_loss = train_loss / len(train_loader)
        
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                loss = model.compute_loss(batch)
                val_loss += loss.item()
        
        avg_val_loss = val_loss / len(val_loader)
        plateau_scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}")
        
        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), "best_cricket_commentator.pth")
            print("Saved best model")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve}/{early_stop_patience} epochs")
            if epochs_no_improve >= early_stop_patience:
                early_stop = True
    
    # Load best model
    model.load_state_dict(torch.load("best_cricket_commentator.pth"))
    return model


In [4]:
# Load annotations
with open("/kaggle/input/final-data/Json_360B.json", "r") as f:
    annotations = json.load(f)

# Split into train and validation (85/15)
split_idx = int(0.85 * len(annotations))
train_annotations = annotations[:split_idx]
val_annotations = annotations[split_idx:]

# Initialize CLIP for dataset
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Create datasets
train_dataset = CricketCommentaryDataset(
    train_annotations,
    clip_model,
    preprocess,
    num_frames=8
)
val_dataset = CricketCommentaryDataset(
    val_annotations,
    clip_model,
    preprocess,
    num_frames=8
)

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=8,  # Small batch size due to memory constraints
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=2
)
val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    collate_fn=collate_fn,
    num_workers=2
)

# Initialize model in training mode
model = CricketCommentator(train_mode=True).to(device)

100%|███████████████████████████████████████| 338M/338M [02:48<00:00, 2.10MiB/s]


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [5]:
# Train the model
trained_model = train_model(model,train_loader,val_loader,epochs=30,lr=1e-4)
  # Save final model
torch.save(trained_model.state_dict(), "cricket_commentator_final.pth")
print("Model saved successfully!")

Epoch 1/30


Training:   0%|          | 0/39 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Training: 100%|██████████| 39/39 [06:38<00:00, 10.22s/it]
Validation: 100%|██████████| 14/14 [01:14<00:00,  5.33s/it]


Epoch 1 | Train Loss: 6.2700 | Val Loss: 4.9263 | LR: 3.33e-05
Saved best model
Epoch 2/30


Training: 100%|██████████| 39/39 [06:36<00:00, 10.16s/it]
Validation: 100%|██████████| 14/14 [01:13<00:00,  5.24s/it]


Epoch 2 | Train Loss: 4.3325 | Val Loss: 3.4906 | LR: 6.67e-05
Saved best model
Epoch 3/30


Training: 100%|██████████| 39/39 [06:34<00:00, 10.11s/it]
Validation: 100%|██████████| 14/14 [01:13<00:00,  5.26s/it]


Epoch 3 | Train Loss: 2.8148 | Val Loss: 3.1422 | LR: 1.00e-04
Saved best model
Epoch 4/30


Training: 100%|██████████| 39/39 [06:31<00:00, 10.05s/it]
Validation: 100%|██████████| 14/14 [01:13<00:00,  5.25s/it]


Epoch 4 | Train Loss: 2.4087 | Val Loss: 3.0739 | LR: 9.63e-05
Saved best model
Epoch 5/30


Training: 100%|██████████| 39/39 [06:38<00:00, 10.21s/it]
Validation: 100%|██████████| 14/14 [01:15<00:00,  5.41s/it]


Epoch 5 | Train Loss: 2.1992 | Val Loss: 3.0188 | LR: 9.26e-05
Saved best model
Epoch 6/30


Training: 100%|██████████| 39/39 [06:40<00:00, 10.26s/it]
Validation: 100%|██████████| 14/14 [01:15<00:00,  5.36s/it]


Epoch 6 | Train Loss: 2.1436 | Val Loss: 3.0054 | LR: 8.89e-05
Saved best model
Epoch 7/30


Training: 100%|██████████| 39/39 [06:41<00:00, 10.30s/it]
Validation: 100%|██████████| 14/14 [01:15<00:00,  5.43s/it]


Epoch 7 | Train Loss: 2.0236 | Val Loss: 2.9600 | LR: 8.52e-05
Saved best model
Epoch 8/30


Training: 100%|██████████| 39/39 [06:42<00:00, 10.31s/it]
Validation: 100%|██████████| 14/14 [01:14<00:00,  5.33s/it]


Epoch 8 | Train Loss: 1.9391 | Val Loss: 2.9476 | LR: 8.15e-05
Saved best model
Epoch 9/30


Training: 100%|██████████| 39/39 [06:41<00:00, 10.30s/it]
Validation: 100%|██████████| 14/14 [01:15<00:00,  5.38s/it]


Epoch 9 | Train Loss: 1.8908 | Val Loss: 2.9669 | LR: 7.78e-05
No improvement for 1/5 epochs
Epoch 10/30


Training: 100%|██████████| 39/39 [06:44<00:00, 10.37s/it]
Validation: 100%|██████████| 14/14 [01:15<00:00,  5.39s/it]


Epoch 10 | Train Loss: 1.8204 | Val Loss: 2.9698 | LR: 7.41e-05
No improvement for 2/5 epochs
Epoch 11/30


Training: 100%|██████████| 39/39 [06:45<00:00, 10.39s/it]
Validation: 100%|██████████| 14/14 [01:14<00:00,  5.35s/it]


Epoch 11 | Train Loss: 1.8060 | Val Loss: 2.9781 | LR: 7.04e-05
No improvement for 3/5 epochs
Epoch 12/30


Training: 100%|██████████| 39/39 [06:43<00:00, 10.35s/it]
Validation: 100%|██████████| 14/14 [01:14<00:00,  5.34s/it]


Epoch 12 | Train Loss: 1.7672 | Val Loss: 2.9871 | LR: 3.33e-05
No improvement for 4/5 epochs
Epoch 13/30


Training: 100%|██████████| 39/39 [06:36<00:00, 10.16s/it]
Validation: 100%|██████████| 14/14 [01:13<00:00,  5.23s/it]


Epoch 13 | Train Loss: 1.6623 | Val Loss: 2.9942 | LR: 6.30e-05
No improvement for 5/5 epochs
Early stopping triggered at epoch 14
Model saved successfully!


In [6]:
trained_model.eval()

CricketCommentator(
  (clip): CLIP(
    (visual): VisionTransformer(
      (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): Sequential(
          (0): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
          (1): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuan