In [18]:
%pip install -r requirements.txt

Collecting datasets (from -r requirements.txt (line 4))
  Using cached datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets->-r requirements.txt (line 4))
  Using cached pyarrow-20.0.0-cp312-cp312-macosx_12_0_arm64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets->-r requirements.txt (line 4))
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets->-r requirements.txt (line 4))
  Using cached pandas-2.3.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (91 kB)
Collecting xxhash (from datasets->-r requirements.txt (line 4))
  Using cached xxhash-3.5.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets->-r requirements.txt (line 4))
  Using cached multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)
Collecting fsspec>=2023.5.0 (from huggingface-hub<1.0,>=0.30.0->transformers->-r requirements.txt (line 1))
  Using cached fsspec-2025.3.0-py3-none-any.

In [1]:
from transformers import CLIPModel, CLIPTokenizer, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
import torch.nn as nn

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
vision_encoder = clip_model.vision_model
text_encoder = clip_model.text_model
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

qwen_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base")
qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
items = ["cat", "dog", "horse", "bird", "car", "person", "tree", "house", "book", "phone"]
classifiers = [f"a photo of a {item}" for item in items]

text_inputs = tokenizer(classifiers, return_tensors="pt")
text_features = text_encoder(**text_inputs).pooler_output
text_features_proj = clip_model.text_projection(text_features)

for item, classifier in zip(items, classifiers):
    image = Image.open(f"images/{item}.jpg")
    image_inputs = image_processor(image, return_tensors="pt")
    image_features = vision_encoder(**image_inputs).pooler_output
    image_features_proj = clip_model.visual_projection(image_features)
    
    # Compute similarity to all classifier texts
    similarities = (image_features_proj @ text_features_proj.T).squeeze(0)
    best_idx = similarities.argmax().item()
    best_label = classifiers[best_idx]
    best_score = similarities[best_idx].item()
    
    print(f"{item}: best match is '{best_label}' (score: {best_score:.3f})")

cat: best match is 'a photo of a cat' (score: 30.380)
dog: best match is 'a photo of a dog' (score: 35.998)
horse: best match is 'a photo of a horse' (score: 36.964)
bird: best match is 'a photo of a bird' (score: 30.397)
car: best match is 'a photo of a car' (score: 23.982)
person: best match is 'a photo of a person' (score: 27.901)
tree: best match is 'a photo of a tree' (score: 30.656)
house: best match is 'a photo of a house' (score: 29.412)
book: best match is 'a photo of a book' (score: 40.067)
phone: best match is 'a photo of a phone' (score: 32.127)


In [3]:
clip_dim = clip_model.visual_projection.out_features  # 512
qwen_dim = qwen_model.model.embed_tokens.embedding_dim  # 4096

adapter = nn.Sequential(
    nn.Linear(clip_dim, qwen_dim),
    nn.LayerNorm(qwen_dim),
    nn.GELU(),
    nn.Linear(qwen_dim, qwen_dim),
)

for item in items:
    image = Image.open(f"images/{item}.jpg")
    image_inputs = image_processor(image, return_tensors="pt")
    image_features = vision_encoder(**image_inputs).pooler_output
    image_features_proj = clip_model.visual_projection(image_features)
    image_latent = adapter(image_features_proj)
    attention_mask = torch.ones(1, 1)
    position_ids = torch.zeros(1, 1)

    generated_ids = qwen_model.generate(inputs_embeds=image_latent.unsqueeze(1), attention_mask=attention_mask, position_ids=position_ids, max_new_tokens=30, do_sample=True, temperature=0.8, pad_token_id=qwen_tokenizer.pad_token_id, eos_token_id=qwen_tokenizer.eos_token_id)

    print(f"{item}: {qwen_tokenizer.decode(generated_ids[0])}")

cat: zzzoj CONTzอนzojz here IMPeezz CONTzzózez zzzzezzz
dog: hanhan ln xyz Sudokuhanhan lnhanhan xyz ^ _ aa Sudokuhan Sudoku bb ^ aa ij vnhan bvhanhan ee llna oo


KeyboardInterrupt: 

In [27]:
from datasets import load_dataset
import os

# Load Flickr30k dataset from Hugging Face
print("Loading Flickr30k dataset...")
dataset = load_dataset("nlphuji/flickr30k")


print(f"Dataset loaded: {dataset}")
print(f"Available splits: {list(dataset.keys())}")
print(f"Test set: {len(dataset['test'])} samples")

# Convert to the format your training code expects
captions = {}
for item in dataset['test']:  # Changed from 'train' to 'test'
    img_filename = item['filename']  # Changed from 'image_file_name' to 'filename'
    caption = item['caption']  # Changed from 'sentence' to 'caption'
    captions[img_filename] = caption  # Don't use setdefault with append
    
print(f"Loaded {len(captions)} images with captions")



Loading Flickr30k dataset...
Dataset loaded: DatasetDict({
    test: Dataset({
        features: ['image', 'caption', 'sentids', 'split', 'img_id', 'filename'],
        num_rows: 31014
    })
})
Available splits: ['test']
Test set: 31014 samples
Loaded 31014 images with captions


In [None]:
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim


device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

clip_model = clip_model.to(device)
clip_model = clip_model.to(device)
vision_encoder = vision_encoder.to(device)
qwen_model = qwen_model.to(device)
adapter = adapter.to(device)


# dont change qwen weights
for param in qwen_model.parameters():
    param.requires_grad = False

optimizer = optim.AdamW(adapter.parameters(), lr=1e-5)


sample_size = 10  # Instead of 31,014 samples
dataset = dataset['test'].select(range(sample_size))

image_files = list(captions.keys())  # Use all images for demo
for epoch in range(3):
    for item in tqdm(dataset):
        # Get image and caption
        image = item['image']
        caption = item['caption'][0]
        
        # Process image
        image_inputs = image_processor(image, return_tensors="pt").to(device)
        image_features = vision_encoder(**image_inputs).pooler_output
        image_features_proj = clip_model.visual_projection(image_features)
        image_latent = adapter(image_features_proj)

        # Tokenize the caption
        input_ids = qwen_tokenizer(caption, return_tensors="pt", truncation=True, max_length=32).input_ids.to(device)
        
        # Create the full sequence: [image_embedding] + [caption_tokens]
        image_latent_seq = image_latent.unsqueeze(1)  # [1, 1, hidden_dim]
        
        # Get text embeddings for the caption
        text_embeds = qwen_model.model.embed_tokens(input_ids)
        
        # Concatenate: image embedding + text embeddings
        full_embeddings = torch.cat([image_latent_seq, text_embeds], dim=1)
        
        # Create labels: -100 for image position, actual tokens for text
        batch_size, seq_len = full_embeddings.size(0), full_embeddings.size(1)
        labels = torch.full((batch_size, seq_len), -100, dtype=torch.long, device=device)
        labels[:, 1:1+input_ids.size(1)] = input_ids  # Fill text positions with actual tokens
        
        # Forward pass
        outputs = qwen_model(inputs_embeds=full_embeddings, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item():.4f}")

print("Training loop complete (demo).")

# Save the trained models
print("Saving trained models...")
torch.save(adapter.state_dict(), "adapter.pth")
print("Models saved successfully!")

Using device: mps


  1%|          | 1/100 [00:15<24:57, 15.13s/it]

Loss: 3.9998


  2%|▏         | 2/100 [00:21<16:22, 10.03s/it]

Loss: 4.5520


  3%|▎         | 3/100 [00:27<12:53,  7.97s/it]

Loss: 4.8648


  4%|▍         | 4/100 [00:28<08:34,  5.36s/it]

Loss: 4.5511


  5%|▌         | 5/100 [00:29<06:00,  3.80s/it]

Loss: 3.3203


  6%|▌         | 6/100 [00:30<04:20,  2.77s/it]

Loss: 4.5416


  7%|▋         | 7/100 [00:32<04:07,  2.66s/it]

Loss: 3.5948


  8%|▊         | 8/100 [00:47<09:45,  6.36s/it]

Loss: 4.4727


  9%|▉         | 9/100 [00:50<08:21,  5.51s/it]

Loss: 3.9131


 10%|█         | 10/100 [01:13<16:14, 10.83s/it]

Loss: 5.6113


 11%|█         | 11/100 [01:31<19:34, 13.20s/it]

Loss: 5.7889


 12%|█▏        | 12/100 [01:43<18:43, 12.77s/it]

Loss: 3.7383


 13%|█▎        | 13/100 [01:46<14:15,  9.84s/it]

Loss: 3.1418


 14%|█▍        | 14/100 [01:47<10:07,  7.06s/it]

Loss: 4.7679


 15%|█▌        | 15/100 [01:52<09:08,  6.45s/it]

Loss: 4.4486


 16%|█▌        | 16/100 [01:56<07:55,  5.66s/it]

Loss: 4.0978


 17%|█▋        | 17/100 [01:57<06:02,  4.36s/it]

Loss: 4.1440


 18%|█▊        | 18/100 [01:58<04:37,  3.39s/it]

Loss: 3.8194


 19%|█▉        | 19/100 [02:00<04:01,  2.98s/it]

Loss: 3.9962


 20%|██        | 20/100 [02:04<04:06,  3.08s/it]

Loss: 3.7006


 21%|██        | 21/100 [02:05<03:24,  2.59s/it]

Loss: 5.4342


 22%|██▏       | 22/100 [02:16<06:34,  5.05s/it]

Loss: 4.3910


 23%|██▎       | 23/100 [02:42<14:32, 11.33s/it]

Loss: 3.7832


 24%|██▍       | 24/100 [02:48<12:19,  9.73s/it]

Loss: 5.0906


 25%|██▌       | 25/100 [02:56<11:32,  9.24s/it]

Loss: 4.2484


 26%|██▌       | 26/100 [03:07<12:08,  9.84s/it]

Loss: 4.8969


 27%|██▋       | 27/100 [03:42<20:58, 17.24s/it]

Loss: 5.2828


 28%|██▊       | 28/100 [05:19<49:26, 41.20s/it]

Loss: 4.6474


 29%|██▉       | 29/100 [05:43<42:44, 36.12s/it]

Loss: 6.5313


 30%|███       | 30/100 [05:46<30:21, 26.02s/it]

Loss: 5.4243


 31%|███       | 31/100 [06:15<30:57, 26.92s/it]

Loss: 4.2975


 32%|███▏      | 32/100 [06:21<23:27, 20.71s/it]

Loss: 4.6432


 33%|███▎      | 33/100 [06:28<18:38, 16.70s/it]

Loss: 5.1250


In [14]:
import os

device = "cuda" if torch.cuda.is_available() else "cpu"



adapter.load_state_dict(torch.load("adapter.pth"))
adapter = adapter.float()  # Ensure adapter is in full precision
qwen_model = qwen_model.float() 


image_folder = "images"
image_files = [f for f in os.listdir(image_folder) if f.endswith(".jpg")]

for img_file in image_files:
    image = Image.open(os.path.join(image_folder, img_file)).convert("RGB")
    image_inputs = image_processor(image, return_tensors="pt").to(device)
    image_features = vision_encoder(**image_inputs).pooler_output
    image_features_proj = clip_model.visual_projection(image_features)
    image_latent = adapter(image_features_proj)

    # Generate text from image latent
    attention_mask = torch.ones(1, 1, dtype=torch.long).to(device)
    position_ids = torch.zeros(1, 1, dtype=torch.long).to(device)
    generated_ids = qwen_model.generate(
        inputs_embeds=image_latent.unsqueeze(1),
        attention_mask=attention_mask,
        position_ids=position_ids,
        max_new_tokens=30,
        do_sample=True,
        temperature=0.8,
        pad_token_id=qwen_tokenizer.pad_token_id,
        eos_token_id=qwen_tokenizer.eos_token_id,
    )
    description = qwen_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"{img_file}: {description}")

dog.jpg:  Seg_Pos中着中的_PosEscPos中Esc*/着》着著乘著着否定中中的著書着着ÆFLSeg SegPOS
horse.jpg: )ibilitiesusherDrivenusherushingollywood Kempusherusherushingstanding Kempstandingereusher Possusherusherusherusher tậnusher acceleratedusher existsushing Possereeusher
book.jpg: 1  ( ',",


1的 
9  0a x T9

 N",  T 3 2

  
bird.jpg: Asked voting voting inter pageNumber hoe veryShareAskedAskedette veryboard votingAskedAskedAsked very very pageNumberPubMed voting votingAsked very very pager hoeAskedboard
person.jpg: ловлов tears?








ловлов



 значительноловлов значительноLOT значительнолов tearsловLOTлов.



лов значительно.


ловловлов tearsловлов.



house.jpg: ??^^^^%%uu^''-
eeoooooooollaaee,,ee,,oooooooo!ffffcccc^^^^aaaaaaaa!!','aaaaooooooorrffffff                                                                                                                                
cat.jpg: .Invoke揮.Invoke MonoBehaviour.Invoke.Invoke_VOID.Invokeobre#pragma.Invoke#pragma.Invoke.Invokeobreobre天空.Invoke slagDetr