In [None]:
%pip install -r requirements.txt

In [1]:
from transformers import CLIPModel, CLIPTokenizer, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
import torch.nn as nn


device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")


clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
vision_encoder = clip_model.vision_model
text_encoder = clip_model.text_model
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

qwen_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base")
qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base")


  from .autonotebook import tqdm as notebook_tqdm


Using device: mps


In [None]:
items = ["cat", "dog", "horse", "bird", "car", "person", "tree", "house", "book", "phone"]
classifiers = [f"a photo of a {item}" for item in items]

text_inputs = tokenizer(classifiers, return_tensors="pt")
text_features = text_encoder(**text_inputs).pooler_output
text_features_proj = clip_model.text_projection(text_features)

for item, classifier in zip(items, classifiers):
    image = Image.open(f"images/{item}.jpg")
    image_inputs = image_processor(image, return_tensors="pt")
    image_features = vision_encoder(**image_inputs).pooler_output
    image_features_proj = clip_model.visual_projection(image_features)
    
    # Compute similarity to all classifier texts
    similarities = (image_features_proj @ text_features_proj.T).squeeze(0)
    best_idx = similarities.argmax().item()
    best_label = classifiers[best_idx]
    best_score = similarities[best_idx].item()
    
    print(f"{item}: best match is '{best_label}' (score: {best_score:.3f})")

In [4]:
clip_dim = clip_model.visual_projection.out_features  # 512
qwen_dim = qwen_model.model.embed_tokens.embedding_dim  # 4096

adapter = nn.Sequential(
    nn.Linear(clip_dim, qwen_dim),
    nn.LayerNorm(qwen_dim),
    nn.GELU(),
    nn.Linear(qwen_dim, qwen_dim),
)

for item in items:
    image = Image.open(f"images/{item}.jpg")
    image_inputs = image_processor(image, return_tensors="pt")
    image_features = vision_encoder(**image_inputs).pooler_output
    image_features_proj = clip_model.visual_projection(image_features)
    image_latent = adapter(image_features_proj)
    attention_mask = torch.ones(1, 1)
    position_ids = torch.zeros(1, 1)

    generated_ids = qwen_model.generate(inputs_embeds=image_latent.unsqueeze(1), attention_mask=attention_mask, position_ids=position_ids, max_new_tokens=30, do_sample=True, temperature=0.8, pad_token_id=qwen_tokenizer.pad_token_id, eos_token_id=qwen_tokenizer.eos_token_id)

    print(f"{item}: {qwen_tokenizer.decode(generated_ids[0])}")

NameError: name 'items' is not defined

In [2]:
from datasets import load_dataset
import os

# Load Flickr30k dataset from Hugging Face
print("Loading Flickr30k dataset...")
dataset = load_dataset("nlphuji/flickr30k")


print(f"Dataset loaded: {dataset}")
print(f"Available splits: {list(dataset.keys())}")
print(f"Test set: {len(dataset['test'])} samples")

# Convert to the format your training code expects
captions = {}
for item in dataset['test']:  # Changed from 'train' to 'test'
    img_filename = item['filename']  # Changed from 'image_file_name' to 'filename'
    caption = item['caption']  # Changed from 'sentence' to 'caption'
    captions[img_filename] = caption  # Don't use setdefault with append
    
print(f"Loaded {len(captions)} images with captions")



Loading Flickr30k dataset...
Dataset loaded: DatasetDict({
    test: Dataset({
        features: ['image', 'caption', 'sentids', 'split', 'img_id', 'filename'],
        num_rows: 31014
    })
})
Available splits: ['test']
Test set: 31014 samples
Loaded 31014 images with captions


In [None]:
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim


device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

clip_model = clip_model.to(device)
clip_model = clip_model.to(device)
vision_encoder = vision_encoder.to(device)
qwen_model = qwen_model.to(device)
adapter = adapter.to(device)


# dont change qwen weights
for param in qwen_model.parameters():
    param.requires_grad = False

optimizer = optim.AdamW(adapter.parameters(), lr=1e-5)


sample_size = 10  # Instead of 31,014 samples
dataset = dataset['test'].select(range(sample_size))

image_files = list(captions.keys())  # Use all images for demo
for epoch in range(3):
    for item in tqdm(dataset):
        # Get image and caption
        image = item['image']
        caption = item['caption']
        
        # Process image
        image_inputs = image_processor(image, return_tensors="pt").to(device)
        image_features = vision_encoder(**image_inputs).pooler_output
        image_features_proj = clip_model.visual_projection(image_features)
        image_latent = adapter(image_features_proj)

        # Tokenize the caption
        input_ids = qwen_tokenizer(caption, return_tensors="pt", truncation=True, max_length=32).input_ids.to(device)
        
        # Create the full sequence: [image_embedding] + [caption_tokens]
        image_latent_seq = image_latent.unsqueeze(1)  # [1, 1, hidden_dim]
        
        # Get text embeddings for the caption
        text_embeds = qwen_model.model.embed_tokens(input_ids)
        
        # Concatenate: image embedding + text embeddings
        full_embeddings = torch.cat([image_latent_seq, text_embeds], dim=1)
        
        # Create labels: -100 for image position, actual tokens for text
        batch_size, seq_len = full_embeddings.size(0), full_embeddings.size(1)
        labels = torch.full((batch_size, seq_len), -100, dtype=torch.long, device=device)
        labels[:, 1:1+input_ids.size(1)] = input_ids  # Fill text positions with actual tokens
        
        # Forward pass
        outputs = qwen_model(inputs_embeds=full_embeddings, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item():.4f}")

print("Training loop complete (demo).")

# Save the trained models
print("Saving trained models...")
torch.save(adapter.state_dict(), "adapter.pth")
print("Models saved successfully!")

In [None]:
# Train the adapter on your local images with simple captions
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image

device = "cpu"  # Use CPU for training

# Move models to device
clip_model = clip_model.to(device)
vision_encoder = vision_encoder.to(device)
qwen_model = qwen_model.to(device)
adapter = adapter.to(device)

# Freeze Qwen weights
for param in qwen_model.parameters():
    param.requires_grad = False

optimizer = optim.AdamW(adapter.parameters(), lr=1e-4)  # Higher learning rate

# Simple training data - your local images
training_data = [
    ("cat.jpg", "a cat"),
    ("dog.jpg", "a dog"), 
    ("horse.jpg", "a horse"),
    ("bird.jpg", "a bird"),
    ("car.jpg", "a car"),
    ("person.jpg", "a person"),
    ("tree.jpg", "a tree"),
    ("house.jpg", "a house"),
    ("book.jpg", "a book"),
    ("phone.jpg", "a phone")
]

print("Starting training...")
# Train for a few epochs
for epoch in range(20):  # More epochs
    total_loss = 0
    for image_name, caption in training_data:
        # Load image
        image = Image.open(f"images/{image_name}")
        image_inputs = image_processor(image, return_tensors="pt").to(device)
        
        # Get CLIP features
        image_features = vision_encoder(**image_inputs).pooler_output
        image_features_proj = clip_model.visual_projection(image_features)
        image_latent = adapter(image_features_proj)

        # Tokenize caption
        input_ids = qwen_tokenizer(caption, return_tensors="pt", truncation=True, max_length=10).input_ids.to(device)
        
        # Create sequence: [image_embedding] + [caption_tokens]
        image_latent_seq = image_latent.unsqueeze(1)
        text_embeds = qwen_model.model.embed_tokens(input_ids)
        full_embeddings = torch.cat([image_latent_seq, text_embeds], dim=1)
        
        # Create labels
        batch_size, seq_len = full_embeddings.size(0), full_embeddings.size(1)
        labels = torch.full((batch_size, seq_len), -100, dtype=torch.long, device=device)
        labels[:, 1:1+input_ids.size(1)] = input_ids
        
        # Forward pass
        outputs = qwen_model(inputs_embeds=full_embeddings, labels=labels)
        loss = outputs.loss
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Average Loss: {total_loss/len(training_data):.4f}")

# Save the trained adapter
torch.save(adapter.state_dict(), "adapter_trained.pth")
print("Training complete!")

In [None]:
import os

device = "mps" if torch.backends.mps.is_available() else "cpu"


adapter.load_state_dict(torch.load("adapter.pth"))
adapter = adapter.float()
adapter = adapter.to(device)
adapter.eval()  # Set to evaluation mode

clip_model = clip_model.to(device)
vision_encoder = vision_encoder.to(device)
qwen_model = qwen_model.to(device)
qwen_model = qwen_model.float()
qwen_model.eval()  # Set to evaluation mode

# Test generation
for item in items:
    image = Image.open(f"images/{item}.jpg")
    image_inputs = image_processor(image, return_tensors="pt").to(device)
    image_features = vision_encoder(**image_inputs).pooler_output
    image_features_proj = clip_model.visual_projection(image_features)
    image_latent = adapter(image_features_proj)
    
    image_latent = image_latent.float()
    attention_mask = torch.ones(1, 1, device=device, dtype=torch.long)
    position_ids = torch.zeros(1, 1, device=device, dtype=torch.long)

    with torch.no_grad():
        # FAST generation - no sampling, fewer tokens
        generated_ids = qwen_model.generate(
            inputs_embeds=image_latent.unsqueeze(1), 
            attention_mask=attention_mask, 
            position_ids=position_ids, 
            max_new_tokens=10,  # Reduced from 30
            do_sample=False,    # Greedy decoding (much faster)
            # temperature=0.8,  # Remove this
            pad_token_id=qwen_tokenizer.pad_token_id, 
            eos_token_id=qwen_tokenizer.eos_token_id
        )

    print(f"{item}: {qwen_tokenizer.decode(generated_ids[0])}")

In [None]:
# Test generation WITHOUT training first to see if it's fast
import os

device = "cpu"  # Use CPU for now - it's often faster than MPS for small models


adapter.load_state_dict(torch.load("adapter.pth"))
adapter = adapter.float()
adapter = adapter.to(device)
adapter.eval()  # Set to evaluation mode

clip_model = clip_model.to(device)
vision_encoder = vision_encoder.to(device)
qwen_model = qwen_model.to(device)
qwen_model = qwen_model.float()
qwen_model.eval()  # Set to evaluation mode




# Test generation with the small model
for item in items:
    image = Image.open(f"images/{item}.jpg")
    image_inputs = image_processor(image, return_tensors="pt").to(device)
    image_features = vision_encoder(**image_inputs).pooler_output
    image_features_proj = clip_model.visual_projection(image_features)
    image_latent = adapter(image_features_proj)
    
    attention_mask = torch.ones(1, 1, device=device, dtype=torch.long)
    position_ids = torch.zeros(1, 1, device=device, dtype=torch.long)

    with torch.no_grad():
        generated_ids = qwen_model.generate(
            inputs_embeds=image_latent.unsqueeze(1), 
            attention_mask=attention_mask, 
            position_ids=position_ids, 
            max_new_tokens=5,  # Very short
            do_sample=False,   # Greedy
            pad_token_id=qwen_tokenizer.pad_token_id, 
            eos_token_id=qwen_tokenizer.eos_token_id
        )

    print(f"{item}: {qwen_tokenizer.decode(generated_ids[0])}")

In [11]:
from PIL import Image
import torch

# --- Device setup ---
device = "mps" if torch.backends.mps.is_available() else "cpu"

# --- Move models to device ---
clip_model = clip_model.to(device)
vision_encoder = vision_encoder.to(device)
adapter = adapter.to(device)
qwen_model = qwen_model.to(device)

clip_model.eval()
vision_encoder.eval()
adapter.eval()
qwen_model.eval()

adapter.load_state_dict(torch.load("adapter_flickr_batched.pth"))


# --- Inputs ---
image = Image.open("images/cat.jpg")
prompt = "This is a photo of:"  # You can change this prompt

# --- 1. Image to CLIP features ---
image_inputs = image_processor(image, return_tensors="pt")
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
with torch.no_grad():
    image_features = vision_encoder(**image_inputs).pooler_output  # [1, 1, vision_dim]
    image_features_proj = clip_model.visual_projection(image_features)  # [1, 1, clip_dim]
    image_latent = adapter(image_features_proj)  # [1, 1, qwen_dim]
    image_latent = image_latent.unsqueeze(1)  # [1, 1, qwen_dim]

# --- 2. Prompt to Qwen embeddings ---
input_ids = qwen_tokenizer(prompt, return_tensors="pt").input_ids.to(device)  # [1, prompt_len]
with torch.no_grad():
    text_embeds = qwen_model.model.embed_tokens(input_ids)  # [1, prompt_len, qwen_dim]

# --- 3. Concatenate image and text embeddings ---
full_embeddings = torch.cat([image_latent, text_embeds], dim=1)  # [1, 1+prompt_len, qwen_dim]

# --- 4. Attention mask ---
batch_size, seq_len, _ = full_embeddings.shape
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=device)

# --- 5. Generate text ---
with torch.no_grad():
    generated_ids = qwen_model.generate(
        inputs_embeds=full_embeddings,
        attention_mask=attention_mask,
        max_new_tokens=30,
        pad_token_id=qwen_tokenizer.pad_token_id,
        eos_token_id=qwen_tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.8,
    )
    output = qwen_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print("Generated:", output)

Generated:  Edoardo, the only son of the last Italian-born governor of the island of Tuscany, Italy, with his wife, Lucrezia


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

device = "mps" if torch.backends.mps.is_available() else "cpu"

# Move models to device
clip_model = clip_model.to(device)
vision_encoder = vision_encoder.to(device)
qwen_model = qwen_model.to(device)
adapter = adapter.to(device)

clip_model.eval()
vision_encoder.eval()
qwen_model.eval()
adapter.train()

# Freeze CLIP and Qwen
for param in clip_model.parameters():
    param.requires_grad = False
for param in vision_encoder.parameters():
    param.requires_grad = False
for param in qwen_model.parameters():
    param.requires_grad = False

optimizer = optim.AdamW(adapter.parameters(), lr=1e-4)

# Prepare a PyTorch dataset and dataloader
class FlickrDataset(torch.utils.data.Dataset):
    def __init__(self, hf_dataset, image_processor):
        self.data = hf_dataset
        self.image_processor = image_processor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image = item['image']
        caption = item['caption']
        if isinstance(caption, list):
            caption = caption[0]
        return image, caption

def collate_fn(batch):
    images, captions = zip(*batch)
    # Process images as a batch (no padding argument!)
    image_inputs = image_processor(list(images), return_tensors="pt")
    # Tokenize captions as a batch (padding is correct here)
    tokenized = qwen_tokenizer(
        list(captions),
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=32
    )
    return image_inputs, tokenized

batch_size = 8
sample_size = 100  # Increase for more data
flickr_data = dataset['test'].select(range(sample_size))
flickr_dataset = FlickrDataset(flickr_data, image_processor)
dataloader = DataLoader(flickr_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

print("Starting batched adapter training on Flickr30k...")
for epoch in range(3):
    print(f"Batch {epoch+1}/{len(dataloader)}")
    total_loss = 0
    for image_inputs, tokenized in tqdm(dataloader):
        # Move to device
        image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
        input_ids = tokenized['input_ids'].to(device)
        attention_mask = tokenized['attention_mask'].to(device)

        batch_size, seq_len = input_ids.shape

        with torch.no_grad():
            image_features = vision_encoder(**image_inputs).pooler_output  # [B, vision_dim]
            image_features_proj = clip_model.visual_projection(image_features)  # [B, clip_dim]
        image_latent = adapter(image_features_proj)  # [B, qwen_dim]
        image_latent = image_latent.unsqueeze(1)  # [B, 1, qwen_dim]

        with torch.no_grad():
            text_embeds = qwen_model.model.embed_tokens(input_ids)  # [B, seq_len, qwen_dim]

        # Concatenate image and text embeddings
        full_embeddings = torch.cat([image_latent, text_embeds], dim=1)  # [B, 1+seq_len, qwen_dim]

        # Labels: -100 for image, actual tokens for text
        labels = torch.full((batch_size, seq_len + 1), -100, dtype=torch.long, device=device)
        labels[:, 1:] = input_ids

        # Forward and optimize
        outputs = qwen_model(inputs_embeds=full_embeddings, labels=labels, attention_mask=torch.cat([torch.ones((batch_size, 1), device=device, dtype=torch.long), attention_mask], dim=1))
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch_size

    print(f"Epoch {epoch+1}, Average Loss: {total_loss/len(flickr_dataset):.4f}")

torch.save(adapter.state_dict(), "adapter_flickr_batched.pth")
print("Batched adapter training complete and saved as adapter_flickr_batched.pth!")

Starting batched adapter training on Flickr30k...
Batch 1/13


100%|██████████| 13/13 [05:53<00:00, 27.23s/it]


Epoch 1, Average Loss: 6.1941
Batch 2/13


100%|██████████| 13/13 [08:02<00:00, 37.10s/it]


Epoch 2, Average Loss: 6.1678
Batch 3/13


100%|██████████| 13/13 [08:33<00:00, 39.50s/it]

Epoch 3, Average Loss: 5.8785
Batched adapter training complete and saved as adapter_flickr_batched.pth!



