# Approach 2: LLM + Adapters (Frozen Models + Trainable Connector)

This notebook demonstrates a **multimodal adapter** approach. Unlike the pipeline approach (where models are separate), here we **fuse** vision and language by training a small "connector" layer.

**Key Concept:**
We take a powerful pre-trained Vision Encoder (CLIP) and a powerful pre-trained LLM (Phi-3). We keep both of them **FROZEN** (unchanged). We only train a tiny **Adapter** (MLP) that learns to translate "visual features" into "word embeddings" that the LLM can understand.

**Architecture:**
1.  **Image** $\rightarrow$ Frozen CLIP $\rightarrow$ Visual Features
2.  **Visual Features** $\rightarrow$ **Trainable Adapter** $\rightarrow$ Visual Tokens
3.  **Visual Tokens** + **Text Tokens** $\rightarrow$ Frozen LLM $\rightarrow$ Text Output
4.  **LLM** $\rightarrow$ Image Prompt $\rightarrow$ Frozen Stable Diffusion $\rightarrow$ Output Image

**Why this matters:**
This allows the LLM to "see" the image directly in its embedding space, enabling joint reasoning (e.g., "What is unusual about this image?") without retraining the massive LLM.
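
The first three architecture steps can be sketched with toy tensors to make the shapes concrete. This is a minimal illustration, not the notebook's models: the sizes (`VISION_DIM`, `LLM_DIM`, `VOCAB`), the token ids, and the name `toy_adapter` are assumptions chosen so the snippet runs instantly on CPU.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes standing in for CLIP (768) and Phi-3 (3072); assumptions for illustration.
VISION_DIM, LLM_DIM, VOCAB = 8, 16, 32

# Stand-in for the frozen vision encoder output: one pooled vector per image.
image_features = torch.randn(1, VISION_DIM)

# Stand-in for the frozen LLM embedding table.
embed_tokens = nn.Embedding(VOCAB, LLM_DIM)
embed_tokens.weight.requires_grad = False

# Trainable connector: the only module that receives gradients.
toy_adapter = nn.Sequential(
    nn.Linear(VISION_DIM, LLM_DIM), nn.GELU(), nn.Linear(LLM_DIM, LLM_DIM)
)

text_ids = torch.tensor([[3, 7, 11, 2]])                  # 4 hypothetical token ids
visual_token = toy_adapter(image_features).unsqueeze(1)   # (1, 1, LLM_DIM)
text_embeds = embed_tokens(text_ids)                      # (1, 4, LLM_DIM)

# The visual token is prepended to the text sequence along the time axis.
inputs_embeds = torch.cat([visual_token, text_embeds], dim=1)
print(tuple(inputs_embeds.shape))  # (1, 5, 16)
```

The LLM then sees a sequence of five "tokens", the first of which never existed in its vocabulary.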

In [None]:
# Install necessary libraries
# - peft: Parameter-Efficient Fine-Tuning (for LoRA/Adapters)
# - bitsandbytes: For 4-bit quantization
# - datasets: To easily download COCO subset
!pip install -q transformers diffusers accelerate bitsandbytes peft torch torchvision datasets

In [None]:
import torch
import torch.nn as nn
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt

# Setup device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

def show_image(img, title=None):
    plt.figure(figsize=(5, 5))
    plt.imshow(img)
    if title: plt.title(title)
    plt.axis('off')
    plt.show()

## 1. Load Frozen Vision Encoder (CLIP)
We use `openai/clip-vit-base-patch32`. We set `requires_grad=False` to ensure it is **frozen**.
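
One way to confirm that a freeze took effect is to count trainable versus frozen parameters. The helper `count_parameters` below is hypothetical (it is not part of the notebook) and is shown on a toy `nn.Linear` standing in for the encoder; the same call should work on any `nn.Module`, including `clip_model`.

```python
import torch.nn as nn

def count_parameters(module: nn.Module) -> tuple:
    """Return (trainable, frozen) parameter counts for a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in module.parameters() if not p.requires_grad)
    return trainable, frozen

# Toy stand-in for the vision encoder: 4*3 weights + 3 biases = 15 parameters.
toy_encoder = nn.Linear(4, 3)
for p in toy_encoder.parameters():
    p.requires_grad = False

print(count_parameters(toy_encoder))  # (0, 15)
```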

In [None]:
from transformers import CLIPVisionModel, CLIPImageProcessor

print("Loading CLIP Vision Encoder...")
clip_model_name = "openai/clip-vit-base-patch32"
clip_processor = CLIPImageProcessor.from_pretrained(clip_model_name)
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)

# FREEZE the vision encoder
for param in clip_model.parameters():
    param.requires_grad = False

print("CLIP loaded and FROZEN.")
print(f"Vision Output Dimension: {clip_model.config.hidden_size}")

## 2. Define Trainable Adapter (The "Connector")
This is the **only** part we will train from scratch. It projects the 768-dim CLIP features to the 3072-dim Phi-3 embedding space.
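
To get a sense of how small the connector is, the arithmetic below counts its weights, assuming a two-layer MLP of the form Linear(768→3072), GELU, Linear(3072→3072) with biases. The helper `linear_params` is a hypothetical name used only for this calculation.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a single fully connected layer."""
    return in_features * out_features + out_features

CLIP_DIM, PHI_DIM = 768, 3072

first_layer = linear_params(CLIP_DIM, PHI_DIM)    # 768 -> 3072
second_layer = linear_params(PHI_DIM, PHI_DIM)    # 3072 -> 3072
total = first_layer + second_layer

print(f"{total:,} trainable parameters")  # 11,802,624 trainable parameters
```

That is roughly 12M parameters, about 0.3% of the size of a 3.8B-parameter LLM such as Phi-3-mini.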

In [None]:
class VisualAdapter(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim)
        )
    
    def forward(self, x):
        return self.model(x)

# Dimensions
CLIP_DIM = 768 # ViT-B/32 hidden size
PHI_DIM = 3072 # Phi-3-mini hidden size

# Initialize Adapter
adapter = VisualAdapter(CLIP_DIM, PHI_DIM).to(device)
print("Adapter initialized (Trainable).")

## 3. Load Frozen LLM (Phi-3-mini)
We load Phi-3 in 4-bit to save memory. We also freeze it.
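
A rough back-of-the-envelope estimate shows why 4-bit loading helps. The sketch assumes about 3.8B parameters for Phi-3-mini and counts weight storage only; it ignores quantization constants, layers that stay unquantized, activations, and the KV cache, so real usage will be higher. `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

PHI3_MINI_PARAMS = 3.8e9  # approximate parameter count

for label, bits in [("fp32", 32), ("fp16", 16), ("4-bit", 4)]:
    print(f"{label}: ~{weight_memory_gb(PHI3_MINI_PARAMS, bits):.1f} GB")
```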

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

print("Loading Phi-3 LLM...")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

llm_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
llm_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=False # Use the Phi-3 implementation built into transformers, not the repo's remote code
)

# FREEZE the LLM
for param in llm_model.parameters():
    param.requires_grad = False

print("Phi-3 loaded and FROZEN.")

## 4. Tiny Training Loop (Demonstration)
We train the adapter on a small example dataset (a 50-example subset of COCO captions). The goal is to teach the model that an image corresponds to its text caption.

**Process:**
1.  Get Image Embeddings (CLIP).
2.  Project them (Adapter).
3.  Concatenate with Text Embeddings (LLM).
4.  Calculate Loss.
5.  Update Adapter.
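
The label construction behind steps 3 and 4 can be sketched with plain lists. Positions holding the visual token and the instruction get the label `-100`, which PyTorch's cross-entropy ignores, so only caption tokens contribute to the loss. The token ids and the helper `build_labels` are hypothetical; Hugging Face causal LMs shift labels internally, so labels are aligned with inputs position by position.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def build_labels(input_ids, instruction_len, num_visual_tokens=1):
    """Mask visual and instruction positions; keep caption token ids as targets."""
    labels = [IGNORE_INDEX] * (num_visual_tokens + len(input_ids))
    start = num_visual_tokens + instruction_len
    labels[start:] = input_ids[instruction_len:]
    return labels

# Hypothetical ids: 4 instruction tokens followed by 3 caption tokens.
input_ids = [1, 20, 21, 22, 30, 31, 32]
labels = build_labels(input_ids, instruction_len=4)
print(labels)  # [-100, -100, -100, -100, -100, 30, 31, 32]
```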

In [None]:
from datasets import load_dataset

# 1. Load Dataset (COCO Subset)
print("Loading COCO dataset subset...")
try:
    # We use 'phiyodr/coco2017' which contains metadata/captions.
    dataset = load_dataset("phiyodr/coco2017", split="train[:50]")
    print(f"Loaded {len(dataset)} examples from COCO.")
except Exception as e:
    print(f"COCO load failed: {e}. Falling back to Pokemon dataset.")
    dataset = load_dataset("lambdalabs/pokemon-blip-captions", split="train[:50]")

# 2. Setup Optimizer
optimizer = torch.optim.AdamW(adapter.parameters(), lr=1e-4)
adapter.train()

print("Starting Training Loop...")

# Headers to mimic a browser to avoid 403 Forbidden on some servers
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'}

for idx, item in enumerate(dataset):
    optimizer.zero_grad()
    
    # --- Prepare Data ---
    raw_image = None
    caption = ""
    
    # Case 1: Dataset has 'image' column (PIL Image) - e.g. Pokemon or full COCO
    if 'image' in item and item['image'] is not None:
        raw_image = item['image'].convert('RGB')
    
    # Case 2: Dataset has 'file_name' (COCO style), fetch from URL
    elif 'file_name' in item:
        try:
            # Keep only the base name in case 'file_name' carries a split prefix
            # (e.g. "train2017/..."), which would otherwise be repeated in the URL.
            file_name = item['file_name'].split('/')[-1]
            img_url = f"http://images.cocodataset.org/train2017/{file_name}"
            
            # Read the full body via response.content; BytesIO is more robust than response.raw
            response = requests.get(img_url, headers=headers, timeout=5)
            
            if response.status_code != 200:
                # Fall back to the val2017 split
                img_url = f"http://images.cocodataset.org/val2017/{file_name}"
                response = requests.get(img_url, headers=headers, timeout=5)

            if response.status_code == 200:
                raw_image = Image.open(BytesIO(response.content)).convert('RGB')
            else:
                # Report the failure so skipped examples are visible
                print(f"Failed to fetch {file_name}: {response.status_code}")
                
        except Exception as e:
            print(f"Error downloading {item.get('file_name')}: {e}")
            
    if raw_image is None:
        continue

    # Get Caption
    if 'captions' in item:
        caption = item['captions'][0] # COCO
    elif 'text' in item:
        caption = item['text'] # Pokemon/BLIP datasets
    else:
        caption = "An image."
        
    instruction = "Describe this image: "
    target_text = caption
    
    # --- A. Vision Forward ---
    with torch.no_grad():
        pixel_values = clip_processor(images=raw_image, return_tensors="pt").pixel_values.to(device)
        vision_outputs = clip_model(pixel_values)
        image_embeds_raw = vision_outputs.pooler_output
        
    # Pass through Adapter
    image_embeds_projected = adapter(image_embeds_raw.unsqueeze(1))
    
    # --- B. Text Forward ---
    text_input = instruction + target_text
    tokens = llm_tokenizer(text_input, return_tensors="pt").to(device)
    input_ids = tokens.input_ids
    
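    # Token count of the instruction prefix. This is an approximation: the trailing
    # space may tokenize differently once the caption follows it, so the mask
    # boundary can be off by one token.
    instruction_len = len(llm_tokenizer(instruction).input_ids)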
    
    with torch.no_grad():
        text_embeds = llm_model.model.embed_tokens(input_ids)
        
    # --- C. Combine & Loss ---
    # Cast image embeddings to match text embeddings dtype (float16)
    image_embeds_projected = image_embeds_projected.to(text_embeds.dtype)
    
    inputs_embeds = torch.cat([image_embeds_projected, text_embeds], dim=1)
    
    # Create Labels
    labels = torch.full(inputs_embeds.shape[:2], -100, dtype=torch.long).to(device)
    start_idx = 1 + instruction_len
    if input_ids.shape[1] > instruction_len:
        labels[0, start_idx:] = input_ids[0, instruction_len:]
    
    # LLM Forward
    outputs = llm_model(inputs_embeds=inputs_embeds, labels=labels)
    loss = outputs.loss
    
    # Backprop
    loss.backward()
    optimizer.step()
    
    if (idx + 1) % 10 == 0:
        print(f"Step {idx+1}/{len(dataset)}: Loss = {loss.item():.4f}")

print("Training complete. Adapter updated.")

## 5. Inference Demo
Now we use the trained adapter to perform inference.
1. **Image** $\rightarrow$ Adapter $\rightarrow$ LLM
2. **Text** $\rightarrow$ LLM
3. **LLM** generates response.

In [None]:
def generate_multimodal(image, prompt, max_new_tokens=50):
    adapter.eval()
    
    # 1. Vision Path
    with torch.no_grad():
        pixel_values = clip_processor(images=image, return_tensors="pt").pixel_values.to(device)
        vision_outputs = clip_model(pixel_values)
        image_embeds = adapter(vision_outputs.pooler_output.unsqueeze(1))
        
    # 2. Text Path
    tokens = llm_tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = tokens.input_ids
    with torch.no_grad():
        text_embeds = llm_model.model.embed_tokens(input_ids)
        
    # 3. Combine
    # Cast image embeddings to match text embeddings dtype (float16)
    image_embeds = image_embeds.to(text_embeds.dtype)
    inputs_embeds = torch.cat([image_embeds, text_embeds], dim=1)
    
    # 4. Generate
    attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long).to(device)
    
    with torch.no_grad():
        outputs = llm_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            pad_token_id=llm_tokenizer.eos_token_id
        )
        
    # With inputs_embeds (and no input_ids), generate() returns only the newly generated tokens
    return llm_tokenizer.decode(outputs[0], skip_special_tokens=True)

# Test it
# Load a fresh test image so inference does not depend on variables left over from training
print("Loading test image for inference...")
img_url = "http://images.cocodataset.org/val2017/000000039769.jpg" # Two cats
headers = {'User-Agent': 'Mozilla/5.0'}

try:
    response = requests.get(img_url, headers=headers, timeout=5)
    response.raise_for_status()  # surface HTTP errors instead of an opaque image-decoding failure
    test_image = Image.open(BytesIO(response.content)).convert('RGB')
except Exception as e:
    print(f"Failed to load test image ({e}). Using dummy image.")
    test_image = Image.new('RGB', (224, 224), color='gray')

test_prompt = "Describe this image and suggest a style for a painting of it."
print(f"Input Prompt: {test_prompt}")
show_image(test_image, "Input")

response = generate_multimodal(test_image, test_prompt)
print(f"LLM Response:\n{response}")

## 6. Generate Image (Frozen Stable Diffusion)
Finally, we take the LLM's suggestion and generate an image.
(Note: Because the adapter was trained on only a small example dataset, the LLM's output may be incoherent, but the pipeline runs end to end.)
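
Because the LLM's response may be long or noisy, it can help to tidy it before passing it to Stable Diffusion, whose CLIP text encoder truncates prompts at 77 tokens. The helper below is a hypothetical sketch: `extract_sd_prompt`, the `max_words` cutoff of 40, and the whitespace-based word split are all assumptions, and the word count is only a loose proxy for the token limit.

```python
def extract_sd_prompt(response: str,
                      fallback: str = "A cat in cyberpunk style",
                      max_words: int = 40) -> str:
    """Collapse whitespace, fall back on near-empty text, and cap the length."""
    text = " ".join(response.split())   # collapse newlines and repeated spaces
    if len(text) <= 5:                  # same threshold idea as a simple length check
        return fallback
    return " ".join(text.split(" ")[:max_words])

print(extract_sd_prompt("Two cats\nsleeping  on a pink couch", max_words=4))
# Two cats sleeping on
```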

In [None]:
from diffusers import StableDiffusionPipeline

# Load Stable Diffusion (Frozen)
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.to(device)

# Extract a prompt from the response (or use a fallback if response is gibberish due to tiny training)
# For this demo, we'll assume the response contains a description.
gen_prompt = response if len(response) > 5 else "A cat in cyberpunk style"

print(f"Generating image for: '{gen_prompt}'")
image = pipe(gen_prompt).images[0]
show_image(image, "Generated Output")

## 7. Findings
The approach performs very poorly here: the generated image does not match the original. The main cause is the training set, which is very small (at most 50 examples, each seen once) and not carefully chosen. Production-grade systems such as LLaVA train their connectors on hundreds of thousands of image-text pairs; this was not done here due to computational constraints.