## **Building a Hybrid SOTA World Model**

In [None]:
# --- SEtup & Install ---
!pip install -q transformers accelerate peft bitsandbytes
!pip install -q matplotlib pillow tqdm decord
!pip install -q timm einops

In [None]:
# --- Mount Google Drive: the DriveLM data and saved checkpoints live under /content/drive ---
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# --- GPU Game ---
import torch
torch.cuda.is_available()
gpu_game = !nvidia-simi --query-gpu=gpu_name --format=csv,noheader
print(f"GPU in use: {gpu_name[0]}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Device count: {torch.cuda.device_count()}")

print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
print(f"✅ VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================
# CELL 1: Setup & Install
# ============================================
!pip install -q transformers accelerate peft bitsandbytes
!pip install -q matplotlib pillow tqdm

from google.colab import drive
drive.mount('/content/drive')

import torch
print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
print(f"✅ VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# ============================================
# CELL 2: Model Definition (SIMPLIFIED)
# ============================================
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoProcessor
from peft import LoraConfig, get_peft_model
from PIL import Image

class InternVLQwenWorldModel(nn.Module):
    """
    Hybrid vision-language world model.

    - InternVL 3.5-8B (always frozen) encodes (frame, question) into a scene vector.
    - A Qwen3 causal LM (1.7B, or 4B-Thinking) predicts the next state.
      It is conditioned on that scene vector, projected in as one soft token,
      plus a textual action prompt.
    - A linear head maps Qwen's last hidden state back into InternVL feature
      space. Training minimises MSE against InternVL's encoding of the actual
      next frame.
    """

    def __init__(self, use_lora=True, qwen_size="4B"):
        """
        Args:
            use_lora: if True, wrap Qwen3 with LoRA adapters on the attention
                projections; otherwise all Qwen3 weights stay trainable.
            qwen_size: "4B" or "1.7B"; any other value falls back to "1.7B".
        """
        super().__init__()

        # 1. PERCEPTION: InternVL 3.5-8B
        print("Loading InternVL 3.5-8B...")
        self.perception = AutoModel.from_pretrained(
            'OpenGVLab/InternVL3_5-8B',
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            device_map="auto"
        )

        # NOTE(review): this trust_remote_code checkpoint's forward may expect
        # InternVL-specific inputs (e.g. image_flags / its own tiling) that
        # AutoProcessor does not produce. Verify, or consider the "-HF"
        # variant of the checkpoint, which targets AutoProcessor.
        self.vl_processor = AutoProcessor.from_pretrained(
            'OpenGVLab/InternVL3_5-8B',
            trust_remote_code=True
        )

        # perceive() always runs under torch.no_grad(), so InternVL can never
        # learn. Freeze it regardless of use_lora.
        for param in self.perception.parameters():
            param.requires_grad = False

        # 2. WORLD MODEL: Qwen3
        # The previous IDs ("Qwen3-4B-Thinking", "Qwen3-1.7B-Instruct") are not
        # published Hub repos.
        qwen_models = {
            "4B": "Qwen/Qwen3-4B-Thinking-2507",
            "1.7B": "Qwen/Qwen3-1.7B",
        }

        qwen_path = qwen_models.get(qwen_size, qwen_models["1.7B"])  # Default to 1.7B for speed
        print(f"Loading {qwen_path}...")

        self.world_model = AutoModelForCausalLM.from_pretrained(
            qwen_path,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        self.wm_tokenizer = AutoTokenizer.from_pretrained(qwen_path)

        # Hidden sizes come from the loaded configs. The old hard-coded 896
        # matches neither Qwen3 variant, so the projection raised a shape
        # error on every step.
        self.internvl_hidden = self._hidden_size(self.perception.config, default=4096)
        self.qwen_hidden = self._hidden_size(self.world_model.config, default=2048)

        # 3. LoRA (only on Qwen3 for speed)
        if use_lora:
            print("Applying LoRA to Qwen3...")
            lora_config = LoraConfig(
                r=32,
                lora_alpha=64,
                target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
                lora_dropout=0.1,
                bias="none",
                task_type="CAUSAL_LM",
            )
            self.world_model = get_peft_model(self.world_model, lora_config)
            print("✅ Qwen3 LoRA applied")

        # 4. Projections, placed on the same device as Qwen's embeddings.
        #    They previously stayed on the CPU while Qwen ran on the GPU.
        wm_device = self.world_model.get_input_embeddings().weight.device
        # InternVL scene vector -> one soft token in Qwen's embedding space.
        self.scene_projection = nn.Linear(self.internvl_hidden, self.qwen_hidden).to(wm_device)
        # Qwen's final hidden state -> InternVL feature space (the loss target space).
        self.projection = nn.Linear(self.qwen_hidden, self.internvl_hidden).to(wm_device)

        print("✅ Model ready!")

    @staticmethod
    def _hidden_size(config, default):
        """Return `hidden_size` from an HF config.

        Checks nested LLM/text sub-configs first, as used by multimodal
        models, and returns `default` if none is found.
        """
        for cfg in (getattr(config, "llm_config", None),
                    getattr(config, "text_config", None),
                    config):
            size = getattr(cfg, "hidden_size", None)
            if size is not None:
                return size
        return default

    def perceive(self, image, text):
        """
        Encode (image, text) with the frozen InternVL model.

        Returns the last-layer hidden state at the final token position.
        Presumably the shape is [batch, internvl_hidden] - TODO confirm
        against the remote-code forward.
        """
        encoded = self.vl_processor(
            text=text,
            images=image,
            return_tensors="pt"
        )

        device = self.perception.device
        dtype = self.perception.dtype
        inputs = {}
        for key, value in encoded.items():
            if torch.is_tensor(value):
                # Floating inputs (pixel values) must match the bf16 weights;
                # integer ids/masks only need moving.
                if value.is_floating_point():
                    value = value.to(device=device, dtype=dtype)
                else:
                    value = value.to(device)
            inputs[key] = value

        with torch.no_grad():
            outputs = self.perception(**inputs, output_hidden_states=True)
            features = outputs.hidden_states[-1][:, -1, :]

        return features

    def predict_future(self, scene_features, action):
        """
        Predict the next state with Qwen3, conditioned on scene and action.

        The scene vector is projected into Qwen's embedding space and
        prepended to the action prompt as a single soft token. The prediction
        therefore depends on the current frame; previously `scene_features`
        was ignored entirely.

        Args:
            scene_features: InternVL scene vector, 2-D [batch, internvl_hidden].
            action: 1-D tensor; action[0] is rendered as steer, action[1] as
                speed in m/s.

        Returns:
            Qwen's last-layer hidden state at the final position,
            [batch, qwen_hidden].
        """
        prompt = f"Predict next driving state. Action: steer={action[0].item():.2f}, speed={action[1].item():.1f}m/s"

        embed_layer = self.world_model.get_input_embeddings()
        device = embed_layer.weight.device
        input_ids = self.wm_tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
        token_embeds = embed_layer(input_ids)  # [1, seq, qwen_hidden]

        proj_weight = self.scene_projection.weight
        scene_token = self.scene_projection(
            scene_features.to(device=proj_weight.device, dtype=proj_weight.dtype)
        )
        scene_token = scene_token.to(device=device, dtype=token_embeds.dtype).unsqueeze(1)

        # One action prompt is shared by every scene in the batch.
        token_embeds = token_embeds.expand(scene_token.shape[0], -1, -1)
        inputs_embeds = torch.cat([scene_token, token_embeds], dim=1)
        attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long, device=device)

        outputs = self.world_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        return outputs.hidden_states[-1][:, -1, :]

    def forward(self, current_image, text, action):
        """Encode the current frame, predict the next state, and map it into InternVL feature space."""
        current_features = self.perceive(current_image, text)
        predicted_wm_features = self.predict_future(current_features, action)
        # The head is float32 while Qwen emits bfloat16, possibly on another device.
        head_weight = self.projection.weight
        return self.projection(
            predicted_wm_features.to(device=head_weight.device, dtype=head_weight.dtype)
        )

    def compute_loss(self, current_image, text, action, future_image):
        """MSE between the predicted features and InternVL's encoding of the real next frame."""
        predicted = self.forward(current_image, text, action)

        with torch.no_grad():
            true_next = self.perceive(future_image, text)

        # The target comes from InternVL (bf16, its own device); align it with the prediction.
        true_next = true_next.to(device=predicted.device, dtype=predicted.dtype)
        return nn.functional.mse_loss(predicted, true_next)

# Build the hybrid model with LoRA on the smaller (faster) 1.7B Qwen3 backbone.
print("Initializing model...")
model = InternVLQwenWorldModel(qwen_size="1.7B", use_lora=True)
print("✅ Model ready!")

# ============================================
# CELL 3: FIXED Dataset (Properly Sequential!)
# ============================================
import json
import os
from torch.utils.data import Dataset
from collections import defaultdict

class SequentialDriveLMDataset(Dataset):
    """
    Consecutive-frame pairs (t, t+1) taken from the same scene.

    The JSON must be a flat list of per-frame records. Each record has:
    - 'key_frame': image path relative to image_dir;
    - optionally 'scene_token', 'timestamp' and 'QA' (a list of dicts whose
      first entry carries the question under 'q' or 'Q').

    Records are grouped by 'scene_token' and sorted by 'timestamp' within each
    scene. Timestamps are presumably in microseconds, since the summary divides
    by 1e6 to report seconds.
    """

    def __init__(self, json_path, image_dir, max_samples=100):
        """
        Args:
            json_path: path to the JSON list of frame records.
            image_dir: directory that 'key_frame' paths are relative to.
            max_samples: maximum number of pairs to collect (0 -> none).

        Raises:
            TypeError: if the JSON top level is not a list.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            all_data = json.load(f)

        if not isinstance(all_data, list):
            # NOTE(review): the official DriveLM release appears to be a nested
            # {scene_token: {"key_frames": {...}}} dict; flatten it first.
            raise TypeError(
                f"Expected a JSON list of frame records in {json_path}, "
                f"got {type(all_data).__name__}."
            )

        print("Building temporal sequences...")

        # Group by scene. Records without a scene_token are pooled under
        # 'unknown', so pairs formed there may span unrelated scenes.
        scenes = defaultdict(list)
        for sample in all_data:
            scenes[sample.get('scene_token', 'unknown')].append(sample)

        for scene_samples in scenes.values():
            scene_samples.sort(key=lambda x: x.get('timestamp', 0))

        # Build sequential pairs, checking the limit *before* appending so
        # that max_samples=0 really yields nothing.
        self.samples = []
        for scene_token, scene_samples in scenes.items():
            if len(self.samples) >= max_samples:
                break
            for current, next_sample in zip(scene_samples, scene_samples[1:]):
                if len(self.samples) >= max_samples:
                    break

                curr_path = self._frame_path(image_dir, current)
                next_path = self._frame_path(image_dir, next_sample)
                if curr_path is None or next_path is None:
                    continue

                self.samples.append({
                    'current': curr_path,
                    'next': next_path,
                    'text': self._question(current),
                    'scene': scene_token,
                    'time_delta': next_sample.get('timestamp', 0) - current.get('timestamp', 0)
                })

        print(f"✅ Created {len(self.samples)} temporal sequences")
        print(f"   From {len(scenes)} different scenes")
        if self.samples:
            avg_delta = sum(s['time_delta'] for s in self.samples) / len(self.samples)
            print(f"   Avg time between frames: {avg_delta/1e6:.2f} seconds")

    @staticmethod
    def _frame_path(image_dir, record):
        """Return the record's image path if it names an existing file, else None."""
        key_frame = record.get('key_frame')
        if not key_frame:
            return None
        path = os.path.join(image_dir, key_frame)
        # isfile, not exists: an empty/relative name can resolve to a directory.
        return path if os.path.isfile(path) else None

    @staticmethod
    def _question(record):
        """First QA question of the record, or a generic driving prompt."""
        qa = record.get('QA')
        if qa and isinstance(qa, list) and isinstance(qa[0], dict):
            question = qa[0].get('q') or qa[0].get('Q')
            if question:
                return question
        return "What is happening in this driving scene?"

    @staticmethod
    def _load_rgb(path):
        """Load an image as RGB, closing the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        return {
            'current': self._load_rgb(s['current']),
            'next': self._load_rgb(s['next']),
            'text': s['text'],
            # NOTE(review): placeholder action (steer=0.0, speed=5.0) shared by every sample.
            'action': torch.tensor([[0.0, 5.0]]),
            'scene': s['scene']
        }

# Collect up to 100 (t, t+1) frame pairs from the DriveLM training split on Drive.
dataset = SequentialDriveLMDataset(
    json_path='/content/drive/MyDrive/drivelm/v1_1_train_nus.json',
    image_dir='/content/drive/MyDrive/drivelm/images',
    max_samples=100,
)

# ============================================
# CELL 4: Training Loop
# ============================================
from tqdm import tqdm

num_epochs = 2

# Optimise only parameters that can actually learn:
# - trainable world-model weights (the LoRA adapters when LoRA is on) at 5e-5;
# - every other trainable head outside perception/world_model (e.g.
#   `projection`) at 1e-4.
world_model_params = [p for p in model.world_model.parameters() if p.requires_grad]
head_params = [
    p
    for name, p in model.named_parameters()
    if p.requires_grad and not name.startswith(("world_model.", "perception."))
]
trainable_params = world_model_params + head_params

optimizer = torch.optim.AdamW([
    {'params': world_model_params, 'lr': 5e-5},
    {'params': head_params, 'lr': 1e-4}
], weight_decay=0.01)

model.train()
print("🚀 Starting training...")

for epoch in range(num_epochs):
    total_loss = 0.0
    count = 0
    failures = 0
    last_error = None  # repr only: keeping the exception would pin its frames/tensors

    pbar = tqdm(range(len(dataset)), desc=f"Epoch {epoch+1}/{num_epochs}")

    for idx in pbar:
        # Zero first, so a sample that fails after backward() cannot leak
        # stale gradients into the next successful step.
        optimizer.zero_grad()
        try:
            sample = dataset[idx]  # image-loading errors are skipped as well
            loss = model.compute_loss(
                sample['current'],
                sample['text'],
                sample['action'].squeeze(0),
                sample['next']
            )
            loss.backward()
        except Exception as e:  # best-effort: skip the sample but record why
            failures += 1
            last_error = repr(e)
            print(f"⚠️  Error on sample {idx}: {e}")
            continue

        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)  # Gradient clipping
        optimizer.step()

        step_loss = loss.item()
        total_loss += step_loss
        count += 1
        pbar.set_postfix({'loss': f"{step_loss:.4f}"})

    if count == 0:
        # Previously every sample could fail silently, yet an untrained model
        # was still saved and reported as trained.
        raise RuntimeError(
            f"Epoch {epoch+1}: no sample trained successfully "
            f"({failures} failures); last error: {last_error}"
        )

    avg_loss = total_loss / count
    print(f"✅ Epoch {epoch+1} - Avg Loss: {avg_loss:.4f}")

# NOTE(review): state_dict() also contains the frozen InternVL and Qwen base
# weights, which makes this checkpoint very large. Consider saving only the
# trainable parameters if Drive space is tight.
torch.save(model.state_dict(), '/content/drive/MyDrive/world_model.pt')
print("✅ Training complete! Model saved.")

# ============================================
# CELL 5: Visualization
# ============================================
import matplotlib.pyplot as plt

model.eval()

if len(dataset) == 0:
    raise RuntimeError("Dataset is empty; nothing to visualize.")

# Up to three rows. Keep the original 0/10/20 spacing when the dataset is
# large enough; otherwise shrink the stride so every index stays in range.
# The old `dataset[i*10]` raised IndexError below 21 samples.
n_rows = min(3, len(dataset))
stride = min(10, (len(dataset) - 1) // max(1, n_rows - 1))
indices = [row * stride for row in range(n_rows)]

fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)

for i, idx in enumerate(indices):
    sample = dataset[idx]

    axes[i, 0].imshow(sample['current'])
    axes[i, 0].set_title(f"Frame t\n{sample['text'][:35]}...", fontsize=9)
    axes[i, 0].axis('off')

    # Middle panel: an actual feature-space comparison between the model's
    # prediction and InternVL's encoding of the real next frame. This replaces
    # a hard-coded "Prediction ✓" that never ran the model.
    try:
        with torch.no_grad():
            predicted = model(sample['current'], sample['text'], sample['action'].squeeze(0))
            actual = model.perceive(sample['next'], sample['text'])
            actual = actual.to(device=predicted.device, dtype=predicted.dtype)
            cos_sim = torch.nn.functional.cosine_similarity(predicted, actual, dim=-1).mean().item()
            mse = torch.nn.functional.mse_loss(predicted, actual).item()
        summary = f"World-model prediction\n(feature space)\n\ncos sim: {cos_sim:.3f}\nMSE: {mse:.4f}"
        box_color = '#90EE90'
    except (RuntimeError, ValueError) as err:
        summary = f"Prediction failed:\n{str(err)[:80]}"
        box_color = '#FFB6C1'

    axes[i, 1].text(0.5, 0.5, summary,
                    ha='center', va='center', fontsize=11, wrap=True,
                    bbox=dict(boxstyle='round', facecolor=box_color, alpha=0.9))
    axes[i, 1].set_title("Predicted t+1 features vs. actual", fontsize=9)
    axes[i, 1].axis('off')

    axes[i, 2].imshow(sample['next'])
    axes[i, 2].set_title(f"Actual Frame t+1\nScene: {sample['scene'][:8]}...", fontsize=9)
    axes[i, 2].axis('off')

plt.suptitle("Vision-Language World Model for Autonomous Driving",
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/demo.png', dpi=150, bbox_inches='tight')
print("✅ Demo saved!")
plt.show()

## **Fine-Tuning Qwen2.5-VL-7B on the DriveLM Dataset**

## **OPTIONAL: Visualize the Features (For Extra Wow Factor)**