In [1]:
from datasets import load_dataset
ds = load_dataset('liuhaotian/LLaVA-CC3M-Pretrain-595K', data_files="chat.json", split='train')
print(ds)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

chat.json:   0%|          | 0.00/211M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['id', 'image', 'conversations'],
    num_rows: 595375
})


In [2]:
from huggingface_hub import hf_hub_download
zip_path = hf_hub_download(repo_id="liuhaotian/LLaVA-CC3M-Pretrain-595K", filename="images.zip", repo_type="dataset")
zip_path

images.zip:   0%|          | 0.00/6.46G [00:00<?, ?B/s]

'/root/.cache/huggingface/hub/datasets--liuhaotian--LLaVA-CC3M-Pretrain-595K/snapshots/814894e93db9e12a1dee78b9669e20e8606fd590/images.zip'

In [3]:
import zipfile, io
from PIL import Image
zf = zipfile.ZipFile(zip_path)
def load_image_from_zip(name):
    with zf.open(name) as f:
        return Image.open(io.BytesIO(f.read())).convert("RGB")

sample = ds[0]
img_name = sample.get("image") or (sample.get("id") + ".jpg")
img = load_image_from_zip(img_name)
print(img_name)
img.size, sample["conversations"][1]

GCC_train_002582585.jpg


((224, 224),
 {'from': 'gpt',
  'value': 'olive oil is a healthy ingredient used liberally .'})

In [4]:
from transformers import CLIPVisionModel, CLIPImageProcessor
vision = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
for param in vision.parameters():
  param.requires_grad = False

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

In [5]:
import torch
def get_vis_tokens(img):
    inputs = processor(images=img, return_tensors="pt").to("cuda")
    with torch.no_grad():
        feats = vision(**inputs).last_hidden_state
    return feats.squeeze(0)

In [6]:
import torch.nn as nn
from transformers import (AutoTokenizer, AutoModelForCausalLM, get_cosine_schedule_with_warmup,)
device = "cuda" if torch.cuda.is_available() else "cpu"
llm = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
batch_size=2
lr=2e-3
epochs=2
WARMUP_STEPS=100
MAX_STEPS=200

In [7]:
from torch.utils.data import Dataset, DataLoader
import json
class ChatZipDataset(Dataset):
    def __init__(self, chat_json_path, images_zip_path, processor, tokenizer):
        with open(chat_json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.zf = zipfile.ZipFile(images_zip_path, "r")
        self.processor = processor
        self.tok = tokenizer

    def __len__(self):
      return len(self.data)

    def _open_image(self, name):
        with self.zf.open(name) as f:
            return Image.open(io.BytesIO(f.read())).convert("RGB")

    def __getitem__(self, idx):
        rec = self.data[idx]
        img_name = rec.get("image") or (rec.get("id","") + ".jpg")
        convs = rec.get("conversations", [])
        user_text = "Describe the image."
        asst_text = ""
        for c in convs:
            if c.get("from","").lower() == "human" and c.get("value"):
                user_text = c["value"].replace("<image>", "").strip() or "Describe the image"
                break
        for c in convs:
            if c.get("from","").lower() == "gpt" and c.get("value"):
                asst_text = c["value"].strip()
                break

        img = self._open_image(img_name)
        vis_inputs = self.processor(images=img, return_tensors="pt")

        user_ids = self.tok(user_text, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze(0)
        asst_ids = self.tok(asst_text, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze(0)
        return {"vis_inputs": vis_inputs, "user_ids": user_ids, "asst_ids": asst_ids}

class Collator:
    def __init__(self, pad_id): self.pad_id = pad_id
    def __call__(self, batch):
        max_u = max(x["user_ids"].size(0) for x in batch)
        max_a = max(x["asst_ids"].size(0) for x in batch)
        user_ids, asst_ids, pv_list = [], [], []

        for x in batch:
            u, a = x["user_ids"], x["asst_ids"]
            if u.size(0) < max_u:
                u = torch.cat([u, torch.full((max_u - u.size(0),), self.pad_id, dtype=torch.long)], dim=0)
            if a.size(0) < max_a:
                a = torch.cat([a, torch.full((max_a - a.size(0),), self.pad_id, dtype=torch.long)], dim=0)
            user_ids.append(u)
            asst_ids.append(a)
            pv_list.append(x["vis_inputs"]["pixel_values"])
        user_ids = torch.stack(user_ids, 0)
        asst_ids = torch.stack(asst_ids, 0)
        pixel_values = torch.cat(pv_list, dim=0)

        return {"user_ids": user_ids, "asst_ids": asst_ids, "pixel_values": pixel_values}


In [8]:
class Stage1Model(nn.Module):
    def __init__(self, vision, llm):
        super().__init__()
        self.vision = vision
        for p in self.vision.parameters():
          p.requires_grad = False
        self.vision.eval()

        self.llm = AutoModelForCausalLM.from_pretrained(llm)
        for p in self.llm.parameters():
          p.requires_grad = False
        self.llm.eval()

        v_dim   = self.vision.config.hidden_size
        d_model = self.llm.config.hidden_size
        self.projector = nn.Linear(v_dim, d_model)
        self.tok_emb   = self.llm.get_input_embeddings()

    @torch.no_grad()
    def encode_image(self, vis_inputs):
        return self.vision(**vis_inputs).last_hidden_state

    def forward(self, vis_inputs, user_ids, asst_ids):
        with torch.no_grad():
            v_tokens = self.encode_image(vis_inputs)
        V = self.projector(v_tokens)

        text_ids = torch.cat([user_ids, asst_ids], dim=1)
        text_emb = self.tok_emb(text_ids)

        inputs_embeds = torch.cat([V, text_emb], dim=1)
        attn_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long, device=inputs_embeds.device)

        B = user_ids.size(0); Tv = V.size(1); U = user_ids.size(1); A = asst_ids.size(1)
        labels = torch.cat([
            torch.full((B, Tv), -100, dtype=torch.long, device=inputs_embeds.device),
            torch.full((B, U),  -100, dtype=torch.long, device=inputs_embeds.device),
            asst_ids
        ], dim=1)

        out = self.llm(inputs_embeds=inputs_embeds, attention_mask=attn_mask, labels=labels)
        return out.loss

In [9]:
import os
def train_stage1(chat_json_path, images_zip_path, max_steps=200):
    tokenizer = AutoTokenizer.from_pretrained(llm)
    processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")

    ds = ChatZipDataset(chat_json_path, images_zip_path, processor, tokenizer)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, collate_fn=Collator(tokenizer.pad_token_id or 0))

    model = Stage1Model(vision, llm).to(device)
    opt = torch.optim.AdamW(model.projector.parameters(), lr=lr)
    total_steps = min(max_steps, epochs * max(1, len(dl)))
    sched = get_cosine_schedule_with_warmup(opt, WARMUP_STEPS, total_steps)

    step = 0
    model.train()
    for epoch in range(epochs):
        for batch in dl:
            vis_inputs = {"pixel_values": batch["pixel_values"].to(device)}
            user_ids = batch["user_ids"].to(device)
            asst_ids = batch["asst_ids"].to(device)

            loss = model(vis_inputs=vis_inputs, user_ids=user_ids, asst_ids=asst_ids)
            opt.zero_grad(); loss.backward(); opt.step(); sched.step()

            step += 1
            if step % 10 == 0:
                print(f"step {step}/{total_steps}  loss={loss.item():.4f}")
            if step >= total_steps: break
        if step >= total_steps: break

    os.makedirs("checkpoints", exist_ok=True)
    torch.save(model.projector.state_dict(), "checkpoints/projector_stage1.pt")
    print("Saved to checkpoints/projector_stage1.pt")

In [10]:
chat_path = hf_hub_download(repo_id='liuhaotian/LLaVA-CC3M-Pretrain-595K', filename='chat.json', repo_type="dataset")
train_stage1(chat_path, zip_path, max_steps=200)

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

step 10/200  loss=8.0888
step 20/200  loss=6.5507
step 30/200  loss=5.3876
step 40/200  loss=7.3611
step 50/200  loss=5.1082
step 60/200  loss=4.1272
step 70/200  loss=2.3804
step 80/200  loss=2.5535
step 90/200  loss=3.4138
step 100/200  loss=3.9934
step 110/200  loss=3.8943
step 120/200  loss=4.4690
step 130/200  loss=4.2606
step 140/200  loss=3.6347
step 150/200  loss=4.9465
step 160/200  loss=3.6728
step 170/200  loss=3.6709
step 180/200  loss=3.8007
step 190/200  loss=3.0642
step 200/200  loss=4.0941
Saved to checkpoints/projector_stage1.pt


In [11]:
from peft import LoraConfig, get_peft_model
class Stage2Model(nn.Module):
    def __init__(self, vision, llm, projector_path):
        super().__init__()
        self.vision = CLIPVisionModel.from_pretrained(vision)
        for p in self.vision.parameters():
          p.requires_grad = False
        self.vision.eval()
        base_llm = AutoModelForCausalLM.from_pretrained(llm)
        lora_cfg = LoraConfig(
            r=16, lora_alpha=32, lora_dropout=0.05,
            target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
            bias="none", task_type="CAUSAL_LM"
        )
        self.llm = get_peft_model(base_llm, lora_cfg)

        v_dim = self.vision.config.hidden_size
        d_model = self.llm.base_model.config.hidden_size
        self.projector = nn.Linear(v_dim, d_model)

        self.projector.load_state_dict(torch.load(projector_path, map_location="cpu"))

        self.tok_emb = self.llm.get_input_embeddings()

    @torch.no_grad()
    def encode_image(self, vis_inputs):
        return self.vision(**vis_inputs).last_hidden_state

    def forward(self, vis_inputs, user_ids, asst_ids):
        with torch.no_grad():
            v_tokens = self.encode_image(vis_inputs)
        V = self.projector(v_tokens)

        text_ids = torch.cat([user_ids, asst_ids], dim=1)
        text_emb = self.tok_emb(text_ids)

        inputs_embeds = torch.cat([V, text_emb], dim=1)
        attn_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long, device=inputs_embeds.device)

        B = user_ids.size(0); Tv = V.size(1); U = user_ids.size(1); A = asst_ids.size(1)
        labels = torch.cat([
            torch.full((B, Tv), -100, dtype=torch.long, device=inputs_embeds.device),
            torch.full((B, U),  -100, dtype=torch.long, device=inputs_embeds.device),
            asst_ids
        ], dim=1)

        out = self.llm(inputs_embeds=inputs_embeds, attention_mask=attn_mask, labels=labels)
        return out.loss

In [12]:
from torch.utils.data import DataLoader

BATCH_SIZE=2
LR=2e-5
EPOCHS=1
WARMUP_STEPS=50
MAX_STEPS=300

def train_stage2(chat_json_path, images_zip_path, projector_ckpt="checkpoints/projector_stage1.pt", max_steps=300):
    tokenizer = AutoTokenizer.from_pretrained(llm)
    processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")

    ds = ChatZipDataset(chat_json_path, images_zip_path, processor, tokenizer)
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=Collator(tokenizer.pad_token_id or 0))

    model = Stage2Model("openai/clip-vit-large-patch14", llm, projector_ckpt).to(device)

    params = list(model.projector.parameters()) + [p for p in model.llm.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=LR)

    total_steps = min(max_steps, EPOCHS * max(1, len(dl)))
    sched = get_cosine_schedule_with_warmup(opt, WARMUP_STEPS, total_steps)

    step = 0
    model.train()
    for epoch in range(EPOCHS):
        for batch in dl:
            vis_inputs = {"pixel_values": batch["pixel_values"].to(device)}
            user_ids = batch["user_ids"].to(device)
            asst_ids = batch["asst_ids"].to(device)

            loss = model(vis_inputs=vis_inputs, user_ids=user_ids, asst_ids=asst_ids)

            opt.zero_grad()
            loss.backward()
            opt.step()
            sched.step()

            step += 1
            if step % 10 == 0:
                print(f"[Stage-2] step {step}/{total_steps}  loss={loss.item():.4f}")
            if step >= total_steps: break
        if step >= total_steps: break

    os.makedirs("checkpoints", exist_ok=True)
    torch.save(model.projector.state_dict(), "checkpoints/projector_stage2.pt")
    model.llm.save_pretrained("checkpoints/llm_lora")
    print("Saved projector: checkpoints/projector_stage2.pt")
    print("Saved LoRA adapter: checkpoints/llm_lora")

In [13]:
train_stage2(chat_path, zip_path, projector_ckpt="checkpoints/projector_stage1.pt", max_steps=300)


[Stage-2] step 10/300  loss=4.2345
[Stage-2] step 20/300  loss=4.2142
[Stage-2] step 30/300  loss=3.4410
[Stage-2] step 40/300  loss=3.1668
[Stage-2] step 50/300  loss=2.8178
[Stage-2] step 60/300  loss=5.5109
[Stage-2] step 70/300  loss=3.1644
[Stage-2] step 80/300  loss=3.5968
[Stage-2] step 90/300  loss=2.6456
[Stage-2] step 100/300  loss=4.3146
[Stage-2] step 110/300  loss=2.7245
[Stage-2] step 120/300  loss=1.7584
[Stage-2] step 130/300  loss=3.2866
[Stage-2] step 140/300  loss=3.4748
[Stage-2] step 150/300  loss=3.8017
[Stage-2] step 160/300  loss=2.9234
[Stage-2] step 170/300  loss=2.1767
[Stage-2] step 180/300  loss=2.7793
[Stage-2] step 190/300  loss=3.6375
[Stage-2] step 200/300  loss=2.5224
[Stage-2] step 210/300  loss=3.2630
[Stage-2] step 220/300  loss=2.6615
[Stage-2] step 230/300  loss=2.6836
[Stage-2] step 240/300  loss=2.9017
[Stage-2] step 250/300  loss=2.4858
[Stage-2] step 260/300  loss=2.8795
[Stage-2] step 270/300  loss=2.9193
[Stage-2] step 280/300  loss=2.8218
[

In [30]:
torch.cuda.empty_cache(); gc.collect()

33145

In [None]:
@torch.no_grad()
def generate_once(model, processor, tokenizer, rec):
    # build one sample
    img_name = rec.get("image") or (rec.get("id","") + ".jpg")
    with zipfile.ZipFile(zip_path,"r").open(img_name) as f:
        img = Image.open(io.BytesIO(f.read())).convert("RGB")
    vis_inputs = processor(images=img, return_tensors="pt").to(DEVICE)

    # take first human as prompt
    user_text = "Describe the image."
    for c in rec["conversations"]:
        if c["from"].lower()=="human":
            user_text = c["value"].replace("<image>","").strip() or "Describe the image."
            break

    user_ids = tokenizer(user_text, add_special_tokens=False, return_tensors="pt")["input_ids"].to(DEVICE)

    # vision
    V = model.projector(model.encode_image({k:v for k,v in vis_inputs.items()}))  # (1, Tv, d)
    # prepare prefix embeds [V ; user_emb]
    user_emb = model.tok_emb(user_ids)  # (1, U, d)
    inputs_embeds = torch.cat([V, user_emb], dim=1)
    attn_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long, device=inputs_embeds.device)

    # generate
    gen = model.llm.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=attn_mask,
        max_new_tokens=64,
        do_sample=False
    )
    # strip the visual prefix length (Tv) + user length (U) when decoding new tokens only
    # Easier: decode the full output and print tail
    text = tokenizer.decode(gen[0], skip_special_tokens=True)
    print(text)

# load for inference
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
m = Stage2Model("openai/clip-vit-large-patch14", "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "checkpoints/projector_stage2.pt").to(device)
# load LoRA adapter
from peft import PeftModel
base = m.llm.base_model  # not used directly; we reload via PEFT hub files
m.llm.from_pretrained = None  # no-op, avoid confusion
m.llm.load_adapter("checkpoints/llm_lora", "default")

# try on the first record
with open(chat_path, "r", encoding="utf-8") as f:
    recs = json.load(f)
generate_once(m, processor, tokenizer, recs[0])
