In [62]:
import transformers

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm

from torchvision.transforms import Compose, Resize, CenterCrop

from datasets import load_from_disk


from PIL import Image

import matplotlib.pyplot as plt


from peft import get_peft_model, LoraConfig, TaskType

In [39]:
flickr = load_from_disk("flickr30k_dataset/")

split = flickr['test'].train_test_split(train_size=0.9, seed=42)

train_set = split['train']
val_set = split['test']

In [66]:
tokenizer = transformers.AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B')
llama = transformers.AutoModelForCausalLM.from_pretrained(
    'meta-llama/Llama-3.2-1B',
    torch_dtype=torch.float16
).cuda()

In [45]:
tokenizer.pad_token = tokenizer.eos_token 

In [None]:
CLIP_model_id = "openai/clip-vit-large-patch14-336"
CLIP = transformers.CLIPModel.from_pretrained(CLIP_model_id, torch_dtype=torch.float16).cuda()
processor = transformers.CLIPProcessor.from_pretrained(CLIP_model_id)#

CLIP.eval()
for param in CLIP.parameters(): param.requires_grad = False

In [14]:
image = Image.open("example.jpg")


In [34]:
inputs = processor(images=image, return_tensors="pt").to("cuda")
with torch.no_grad():
    image_embeds = CLIP.get_image_features(**inputs)

print(image_embeds.shape)  

torch.Size([1, 768])


In [None]:
transform = Compose([
    Resize(336, interpolation=Image.BICUBIC),
    CenterCrop(336),
])

resized_image = transform(image)
resized_image.show()

Opening in existing browser session.


In [52]:
class flickr_dataset(Dataset):
    def __init__(self, data, tokenizer=tokenizer, processor=processor):
        self.images = data['image']
        self.caption = data['caption']
        self.tokenizer = tokenizer
        self.processor = processor
        


    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        img = self.images[index]
        caption =self.caption[index]
        encoded = self.tokenizer(caption, return_tensors="pt", padding="max_length", truncation=True)

        return self.processor(images=img, return_tensors='pt'), (encoded)


In [53]:
data = flickr_dataset(val_set)

In [57]:
data[0][1]

{'input_ids': tensor([[128000,     32,    893,  ..., 128001, 128001, 128001],
        [128000,  11874,   3026,  ..., 128001, 128001, 128001],
        [128000,   1692,  18186,  ..., 128001, 128001, 128001],
        [128000,     32,    893,  ..., 128001, 128001, 128001],
        [128000,  11874,   3026,  ..., 128001, 128001, 128001]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

In [58]:
print (CLIP)

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e

In [59]:
print(llama)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [63]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "v_proj"]  # Typical for LLaMA
)

llama_with_lora = get_peft_model(llama, lora_config)

In [64]:
llama_with_lora.print_trainable_parameters()

trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689


In [65]:
class captioning(nn.Module):
    def __init__(self, CLIP, llama):
        super().__init__()
        self.CLIP = CLIP
        self.llama = llama
        self.mlp = nn.Sequential(
            nn.Linear(768,1024),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(1024,2048),
            nn.LayerNorm(2048)
        )

    def forward(self, image, input_ids, attention_mask):
        # Encode image with CLIP
        with torch.no_grad():
            image_embed = self.CLIP.get_image_features(pixel_values=image)

        image_token = self.mlp(image_embed).unsqueeze(1)  # [B, 1, D]

        # Embed input_ids via LLaMA embedding layer
        input_embeds = self.llama.model.embed_tokens(input_ids)

        # Concatenate image token and text tokens
        inputs_embeds = torch.cat([image_token, input_embeds], dim=1)

        # Adjust attention mask
        extended_mask = torch.cat([
            torch.ones(image_token.shape[0], 1, device=image_token.device),  # [B, 1]
            attention_mask
        ], dim=1)

        return self.llama(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_mask
        )

In [None]:
def train(model, dataloader, optimizer, device, epochs=3):
    model.train()
    model.to(device)

    loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # ignore padding

    for epoch in range(epochs):
        total_loss = 0

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
        for batch in pbar:
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # Shift labels for causal loss
            labels = input_ids.clone()
            labels[:, :-1] = input_ids[:, 1:]
            labels[:, -1] = -100  # ignore the last token

            outputs = model(
                image=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            logits = outputs.logits  # shape [B, T, V]
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())

        print(f"Epoch {epoch+1} completed. Avg Loss: {total_loss/len(dataloader):.4f}")
