In [1]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv‚Ä¶

In [2]:
!wandb login --relogin

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [1]:
import os, math, random, time, torch
import cv2
import numpy as np
import re
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast, GradScaler
import matplotlib.pyplot as plt

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoProcessor,
    SiglipVisionModel,
)

# CONFIG

LLM_NAME    = "meta-llama/Llama-3.2-1B"
VISION_NAME = "google/siglip-so400m-patch14-384"

# Logic Flags
CALCULATE_SEMANTIC_ACCURACY = True

# Synthetic Data Config
NUM_FRAMES    = 4
TRAIN_SAMPLES = 2000
VAL_SAMPLES   = 50
IMG_SIZE      = 224

# Training Config
cache_dir        = "./video_checkpoints"
batch_size       = 4
max_txt_len      = 64
total_steps      = 5000
warmup_steps     = 500
grad_accum       = 8
val_interval     = 200

use_wandb     = True
wandb_project = "blipren-video-synthetic"

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = True
print("Device:", device)

os.makedirs(cache_dir, exist_ok=True)
os.makedirs("logs", exist_ok=True)


# LOADING BACKBONES

print(f"Loading LLM: {LLM_NAME}")
tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

llm = AutoModelForCausalLM.from_pretrained(LLM_NAME, torch_dtype=torch.float16, device_map="auto")
llm.eval()
d_model = llm.config.hidden_size

print(f"Loading vision encoder: {VISION_NAME}")
vision_model = SiglipVisionModel.from_pretrained(VISION_NAME, torch_dtype=torch.float16).to(device)
processor = AutoProcessor.from_pretrained(VISION_NAME)
vision_model.eval()
d_vision = vision_model.config.hidden_size


# DATASET, HAD TO ADD EOS EXPLICITY

def get_sinusoidal_embeddings(n_pos, d_model):
    """
    Standard Transformer Positional Embeddings.
    Returns: (1, n_pos, 1, d_model) to match our broadcasting needs.
    """
    pe = torch.zeros(n_pos, d_model)
    position = torch.arange(0, n_pos, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)

    # Reshape for broadcasting: (1, T, 1, D)
    return pe.unsqueeze(0).unsqueeze(2)


class SyntheticVideoDataset(Dataset):
    def __init__(self, num_samples, num_frames=4, img_size=224):
        self.num_samples = num_samples
        self.num_frames = num_frames
        self.img_size = img_size
        self.prompt_text = "Numbers in video:"

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        sequence = [random.randint(0, 9) for _ in range(self.num_frames)]
        frames_pil = []
        frames_raw = []

        for num in sequence:
            # High Contrast
            img_np = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
            text = str(num)
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 6
            thickness = 10

            text_size = cv2.getTextSize(text, font, font_scale, thickness)[0]
            text_x = (self.img_size - text_size[0]) // 2
            text_y = (self.img_size + text_size[1]) // 2

            cv2.putText(img_np, text, (text_x, text_y), font, font_scale, (255, 255, 255), thickness)
            frames_raw.append(img_np)
            frames_pil.append(Image.fromarray(img_np))

        seq_str = ", ".join(map(str, sequence))
        # Explicit EOS
        full_caption = f"{self.prompt_text} {seq_str}{tokenizer.eos_token}"

        return {
            "frames": frames_pil,
            "raw_frames": frames_raw,
            "caption": full_caption,
            "prompt_only": self.prompt_text,
            "gt_seq": seq_str
        }

# DATA COLLATOR

def collate_fn(batch):
    all_frames = []
    captions = []
    raw_frames_batch = []
    prompts = []
    gts = []

    for ex in batch:
        all_frames.extend(ex["frames"])
        captions.append(ex["caption"])
        raw_frames_batch.append(ex["raw_frames"])
        prompts.append(ex["prompt_only"])
        gts.append(ex["gt_seq"])

    pixel_values = processor(images=all_frames, return_tensors="pt")["pixel_values"]

    enc = tokenizer(
        captions,
        padding="max_length",
        truncation=True,
        max_length=max_txt_len,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"]
    labels = input_ids.clone()

    # Mask prompt
    prompt_ids_len = [len(tokenizer.encode(p, add_special_tokens=False)) for p in prompts]
    for i, p_len in enumerate(prompt_ids_len):
        labels[i, :p_len+1] = -100
    labels[enc["attention_mask"] == 0] = -100

    b_sz = len(batch)
    frames_per_vid = len(batch[0]["frames"])
    _, c, h, w = pixel_values.shape
    pixel_values = pixel_values.view(b_sz, frames_per_vid, c, h, w)

    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "attention_mask": enc["attention_mask"],
        "labels": labels,
        "raw_frames_batch": raw_frames_batch,
        "prompts": prompts,
        "gts": gts
    }

print("Initializing Synthetic Datasets...")
train_ds = SyntheticVideoDataset(TRAIN_SAMPLES, NUM_FRAMES, IMG_SIZE)
val_ds   = SyntheticVideoDataset(VAL_SAMPLES, NUM_FRAMES, IMG_SIZE)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, collate_fn=collate_fn)
print("Dataloaders ready.")


# ARCHITECTURE

class QFormer(nn.Module):
    def __init__(self, d_vis, d_model, n_queries=64, n_layers=6):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, n_queries, d_model))
        self.vis_proj = nn.Linear(d_vis, d_model)
        layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=8, dim_feedforward=d_model*4, batch_first=True)
        self.transformer = nn.TransformerDecoder(layer, num_layers=n_layers)
        self.ln_out = nn.LayerNorm(d_model)

    def forward(self, vis_features):
        B = vis_features.shape[0]
        v = self.vis_proj(vis_features)
        q = self.query.expand(B, -1, -1)
        return self.ln_out(self.transformer(q, v))

class Projector(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, d_model)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(d_model, d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        res = x
        x = self.ln1(x)
        x = self.linear2(self.gelu(self.linear1(x)))
        return self.ln2(x + res)

class VideoBLIP(nn.Module):
    def __init__(self, llm, vision, qformer, projector, num_frames):
        super().__init__()
        self.llm = llm
        self.vision = vision
        self.qformer = qformer
        self.projector = projector
        self.num_frames = num_frames

        # Instead of learnable parameters (which start at 0 and take forever to learn),
        # we create fixed, high-frequency mathematical embeddings. With longer learning learnable params would also work
        d_vis = vision.config.hidden_size

        # Generate them
        sinusoidal_embed = get_sinusoidal_embeddings(num_frames, d_vis)

        # Register as a buffer (not a learnable parameter, just a constant)
        self.register_buffer("time_embed", sinusoidal_embed)

    def encode_video(self, pixel_values):
        B, T, C, H, W = pixel_values.shape
        pixel_values_flat = pixel_values.view(B * T, C, H, W)

        with torch.no_grad():
            vout = self.vision(pixel_values=pixel_values_flat)
            vtoks = vout.last_hidden_state # (B*T, Patches, D)

        d_vis = vtoks.shape[-1]

        # Un-flatten: (B, T, Patches, D)
        vtoks = vtoks.view(B, T, -1, d_vis)

        # NOT LEARNED TIME EMBEDDINGS
        # PERFORMING A SAFE SLICE IN CASE OF GETTING LESS FRAMES IN INFERENCE TIME
        t_embed = self.time_embed[:, :T, :, :]

        # Add to visual tokens.
        # Note: We don't need to scale vtoks usually, but ensure types match
        vtoks = vtoks + t_embed.to(vtoks.dtype)

        # Flatten for Q-Former: (B, T*Patches, D)
        vtoks = vtoks.view(B, -1, d_vis)

        q = self.qformer(vtoks.to(torch.float32))
        return self.projector(q).to(self.llm.dtype)

    def forward(self, input_ids, pixel_values, attention_mask, labels=None):
        q = self.encode_video(pixel_values)
        K = q.size(1)

        txt_emb = self.llm.get_input_embeddings()(input_ids)
        all_emb = torch.cat([q, txt_emb], dim=1)

        prefix_mask = torch.ones(input_ids.size(0), K, device=input_ids.device, dtype=attention_mask.dtype)
        full_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        if labels is not None:
            prefix_labels = torch.full((input_ids.size(0), K), -100, device=input_ids.device, dtype=torch.long)
            full_labels = torch.cat([prefix_labels, labels], dim=1)
        else:
            full_labels = None

        out = self.llm(inputs_embeds=all_emb, attention_mask=full_mask, labels=full_labels)
        return out.logits, out.loss

    @torch.no_grad()
    def generate(self, pixel_values, prompts, max_new_tokens=30, repetition_penalty=1.2):
        q = self.encode_video(pixel_values)
        K = q.size(1)

        enc = tokenizer(prompts, return_tensors="pt", padding=True).to(pixel_values.device)
        input_ids = enc.input_ids
        attn_mask = enc.attention_mask
        txt_emb = self.llm.get_input_embeddings()(input_ids)

        all_emb = torch.cat([q, txt_emb], dim=1)
        prefix_mask = torch.ones(input_ids.size(0), K, device=input_ids.device, dtype=attn_mask.dtype)
        full_mask = torch.cat([prefix_mask, attn_mask], dim=1)

        out_ids = self.llm.generate(
            inputs_embeds=all_emb,
            attention_mask=full_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=repetition_penalty,
            do_sample=False
        )
        return tokenizer.batch_decode(out_ids, skip_special_tokens=True)

# MODULES
qformer = QFormer(d_vis=d_vision, d_model=d_model, n_queries=64, n_layers=6)
projector = Projector(d_model)

# Pass NUM_FRAMES to model for embedding initialization
model = VideoBLIP(llm, vision_model, qformer, projector, num_frames=NUM_FRAMES).to(device)

# Freeze LLM and Vision
for p in llm.parameters(): p.requires_grad = False
for p in vision_model.parameters(): p.requires_grad = False

# OPTIMIZER AND WANDB SETUP

train_params = list(qformer.parameters()) + list(projector.parameters())
optimizer = torch.optim.AdamW(train_params, lr=1e-4, weight_decay=0.01)

def lr_lambda(step):
    if step < warmup_steps: return step / max(1, warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * (step - warmup_steps) / (total_steps - warmup_steps)))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
scaler = GradScaler(device="cuda")

if use_wandb:
    import wandb
    wandb.init(project=wandb_project, config={
        "lr": 1e-4, "frames": NUM_FRAMES, "semantic_acc": CALCULATE_SEMANTIC_ACCURACY, "time_embed": "learnable"
    })

def calculate_accuracy(pred_text, gt_text):
    def extract_nums(s): return [int(d) for d in re.findall(r'\d+', s)]
    gt_nums = extract_nums(gt_text)
    pred_nums = extract_nums(pred_text)
    length = min(len(gt_nums), len(pred_nums))
    if length == 0: return 0.0
    matches = sum(1 for i in range(length) if gt_nums[i] == pred_nums[i])
    return matches / max(len(gt_nums), 1)

# VALIDATION AND TRAINING

def run_validation(global_step):
    model.eval()
    total_loss, total_acc, count = 0.0, 0.0, 0
    log_vids, log_prompts, log_preds, log_gts = [], [], [], []

    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            pixel_values = batch["pixel_values"].to(device)
            input_ids    = batch["input_ids"].to(device)
            attn_mask    = batch["attention_mask"].to(device)
            labels       = batch["labels"].to(device)

            with autocast("cuda", dtype=torch.float16):
                _, loss = model(input_ids, pixel_values, attn_mask, labels)

            bs = input_ids.size(0)
            total_loss += loss.item() * bs
            count += bs

            if i == 0:
                prompts = batch["prompts"]
                gts = batch["gts"]
                preds = model.generate(pixel_values, prompts, max_new_tokens=30, repetition_penalty=1.2)

                if CALCULATE_SEMANTIC_ACCURACY:
                    batch_acc = 0
                    for j in range(len(preds)):
                        acc = calculate_accuracy(preds[j], gts[j])
                        batch_acc += acc
                        if j < 4:
                            log_vids.append(batch["raw_frames_batch"][j])
                            log_prompts.append(prompts[j])
                            log_preds.append(preds[j])
                            log_gts.append(gts[j])
                    total_acc = batch_acc / len(preds)

    avg_loss = total_loss / max(1, count)
    print(f"\n[VAL @ {global_step}] Loss={avg_loss:.4f} | Acc={total_acc*100:.1f}%")
    if log_preds:
        print(f"GT: {log_gts[0]}")
        print(f"PR: {log_preds[0]}\n")

    if use_wandb:
        log_dict = {"val_loss": avg_loss, "val_accuracy": total_acc}
        if log_preds:
            columns = ["step", "video", "prompt", "pred", "gt"]
            table = wandb.Table(columns=columns)
            for k in range(len(log_preds)):
                vid_np = np.array(log_vids[k])
                vid_np = np.transpose(vid_np, (0, 3, 1, 2))
                table.add_data(global_step, wandb.Video(vid_np, fps=2, format="gif"), log_prompts[k], log_preds[k], log_gts[k])
            log_dict["val_samples"] = table
        wandb.log(log_dict, step=global_step)

    model.train()
    return avg_loss

print("üöÄ Training starting...")
best_val_loss = float("inf")
global_step = 0
train_iter = iter(train_loader)

while global_step < total_steps:
    try:
        batch = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        batch = next(train_iter)

    pixel_values = batch["pixel_values"].to(device)
    input_ids    = batch["input_ids"].to(device)
    attn_mask    = batch["attention_mask"].to(device)
    labels       = batch["labels"].to(device)

    with autocast("cuda", dtype=torch.float16):
        _, loss = model(input_ids, pixel_values, attn_mask, labels)
        loss = loss / grad_accum

    scaler.scale(loss).backward()

    if (global_step + 1) % grad_accum == 0:
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(train_params, 1.0).item()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()
        if use_wandb: wandb.log({"grad_norm": grad_norm}, step=global_step)

    if global_step % 10 == 0:
        curr_loss = loss.item() * grad_accum
        lr = scheduler.get_last_lr()[0]
        if global_step % 100 == 0:
            print(f"[{global_step:05d}] loss={curr_loss:.4f} lr={lr:.6e}")
        if use_wandb: wandb.log({"train_loss": curr_loss, "lr": lr}, step=global_step)

    if (global_step + 1) % val_interval == 0:
        val_loss = run_validation(global_step + 1)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(qformer.state_dict(), os.path.join(cache_dir, "qformer_best.pt"))
            torch.save(projector.state_dict(), os.path.join(cache_dir, "projector_best.pt"))

    global_step += 1

print("üéâ Training Finished!")
wandb.finish()

Device: cuda
Loading LLM: meta-llama/Llama-3.2-1B


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!


Loading vision encoder: google/siglip-so400m-patch14-384


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Initializing Synthetic Datasets...
Dataloaders ready.


[34m[1mwandb[0m: Currently logged in as: [33meren23[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


üöÄ Training starting...
[00000] loss=3.9508 lr=0.000000e+00
[00100] loss=1.0310 lr=2.400000e-06


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



[VAL @ 200] Loss=1.0490 | Acc=0.0%
GT: 5, 1, 0, 9
PR:  1, 0



[00200] loss=0.9542 lr=5.000000e-06
[00300] loss=0.9222 lr=7.400000e-06

[VAL @ 400] Loss=0.7327 | Acc=6.2%
GT: 8, 3, 9, 6
PR:  9, 8, 7, 6



[00400] loss=0.7627 lr=1.000000e-05
[00500] loss=0.8662 lr=1.240000e-05

[VAL @ 600] Loss=0.5745 | Acc=12.5%
GT: 3, 0, 2, 4
PR:  2, 3, 4, 0



[00600] loss=0.6287 lr=1.500000e-05
[00700] loss=0.4339 lr=1.740000e-05

[VAL @ 800] Loss=0.4539 | Acc=37.5%
GT: 1, 8, 4, 8
PR:  1, 8, 4



[00800] loss=0.4111 lr=2.000000e-05
[00900] loss=0.3208 lr=2.240000e-05

[VAL @ 1000] Loss=0.3703 | Acc=75.0%
GT: 9, 7, 8, 6
PR:  9, 7, 8, 6



[01000] loss=0.3672 lr=2.500000e-05
[01100] loss=0.3474 lr=2.740000e-05

[VAL @ 1200] Loss=0.3076 | Acc=37.5%
GT: 4, 5, 5, 1
PR:  4, 5, 1, 2



[01200] loss=0.2691 lr=3.000000e-05
[01300] loss=0.1929 lr=3.240000e-05

[VAL @ 1400] Loss=0.2411 | Acc=43.8%
GT: 1, 1, 6, 6
PR:  1, 6, 1



[01400] loss=0.1891 lr=3.500000e-05
[01500] loss=0.1522 lr=3.740000e-05

[VAL @ 1600] Loss=0.1542 | Acc=56.2%
GT: 1, 0, 2, 3
PR:  1, 0, 2



[01600] loss=0.2068 lr=4.000000e-05
[01700] loss=0.1618 lr=4.240000e-05

[VAL @ 1800] Loss=0.1580 | Acc=75.0%
GT: 8, 7, 4, 3
PR:  8, 7, 3, 4



[01800] loss=0.1416 lr=4.500000e-05
[01900] loss=0.0531 lr=4.740000e-05

[VAL @ 2000] Loss=0.1792 | Acc=75.0%
GT: 3, 9, 6, 7
PR:  3, 9, 7, 6



[02000] loss=0.1559 lr=5.000000e-05
[02100] loss=0.0780 lr=5.240000e-05

[VAL @ 2200] Loss=0.1284 | Acc=68.8%
GT: 9, 8, 4, 5
PR:  9, 8, 4, 5



[02200] loss=0.0698 lr=5.500000e-05
[02300] loss=0.1822 lr=5.740000e-05

[VAL @ 2400] Loss=0.1456 | Acc=56.2%
GT: 6, 3, 7, 3
PR:  6, 3, 7, 9



[02400] loss=0.1264 lr=6.000000e-05
[02500] loss=0.0840 lr=6.240000e-05

[VAL @ 2600] Loss=0.1388 | Acc=75.0%
GT: 9, 2, 7, 8
PR:  9, 2, 7, 8



[02600] loss=0.1014 lr=6.500000e-05
[02700] loss=0.2347 lr=6.740000e-05

[VAL @ 2800] Loss=0.2148 | Acc=75.0%
GT: 5, 5, 5, 9
PR:  5, 5, 5, 9



[02800] loss=0.1516 lr=7.000000e-05
[02900] loss=0.2102 lr=7.240000e-05

[VAL @ 3000] Loss=0.1249 | Acc=56.2%
GT: 7, 7, 9, 7
PR:  7, 9, 7, 7



[03000] loss=0.1346 lr=7.500000e-05
[03100] loss=0.0966 lr=7.740000e-05

[VAL @ 3200] Loss=0.1199 | Acc=37.5%
GT: 2, 5, 8, 2
PR:  2, 5, 8, 1



[03200] loss=0.1155 lr=8.000000e-05
[03300] loss=0.0624 lr=8.240000e-05

[VAL @ 3400] Loss=0.1151 | Acc=87.5%
GT: 6, 2, 5, 6
PR:  6, 2, 5, 1



[03400] loss=0.0391 lr=8.500000e-05
[03500] loss=0.0755 lr=8.740000e-05

[VAL @ 3600] Loss=0.1401 | Acc=75.0%
GT: 3, 6, 9, 0
PR:  3, 6, 0, 9



[03600] loss=0.1810 lr=9.000000e-05
[03700] loss=0.0403 lr=9.240000e-05

[VAL @ 3800] Loss=0.0985 | Acc=68.8%
GT: 3, 4, 5, 8
PR:  3, 4, 8, 5



[03800] loss=0.1225 lr=9.500000e-05
[03900] loss=0.0702 lr=9.740000e-05

[VAL @ 4000] Loss=0.1150 | Acc=87.5%
GT: 3, 5, 0, 5
PR:  3, 5, 0, 5



[04000] loss=0.0407 lr=1.000000e-04
[04100] loss=0.0447 lr=9.999825e-05

[VAL @ 4200] Loss=0.0606 | Acc=81.2%
GT: 0, 3, 4, 7
PR:  0, 3, 4, 7



[04200] loss=0.0426 lr=9.999238e-05
[04300] loss=0.0455 lr=9.998332e-05

[VAL @ 4400] Loss=0.1290 | Acc=81.2%
GT: 0, 3, 5, 9
PR:  0, 3, 5, 9



[04400] loss=0.1281 lr=9.996954e-05
[04500] loss=0.1424 lr=9.995317e-05

[VAL @ 4600] Loss=0.1034 | Acc=87.5%
GT: 1, 5, 4, 3
PR:  1, 5, 4, 3



[04600] loss=0.0958 lr=9.993148e-05
[04700] loss=0.1300 lr=9.990780e-05

[VAL @ 4800] Loss=0.0920 | Acc=100.0%
GT: 4, 1, 9, 1
PR:  4, 1, 9, 1



[04800] loss=0.0682 lr=9.987820e-05
[04900] loss=0.1337 lr=9.984723e-05

[VAL @ 5000] Loss=0.0629 | Acc=75.0%
GT: 7, 6, 2, 1
PR:  7, 6, 2, 1



üéâ Training Finished!


0,1
grad_norm,‚ñà‚ñÑ‚ñÉ‚ñà‚ñÉ‚ñÉ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÇ‚ñÅ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
lr,‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
train_loss,‚ñà‚ñÇ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
val_accuracy,‚ñÅ‚ñÅ‚ñÇ‚ñÑ‚ñÜ‚ñÑ‚ñÑ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÖ‚ñÜ‚ñÜ‚ñÖ‚ñÑ‚ñá‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñà‚ñÜ
val_loss,‚ñà‚ñÜ‚ñÖ‚ñÑ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ

0,1
grad_norm,0.81477
lr,0.0001
train_loss,0.02503
val_accuracy,0.75
val_loss,0.06286


In [None]:
import wandb
import os

# ================= CONFIGURATION =================
WANDB_ENTITY  = "eren23"
WANDB_PROJECT = "blipren-video-synthetic"
RUN_NAME      = "whole-glitter-6"


CHECKPOINT_DIR = "./video_checkpoints"

FILES_TO_UPLOAD = [
    "qformer_best.pt",
    "projector_best.pt",
]

# ================= UPLOAD SCRIPT =================
print(f"üîç Searching for run '{RUN_NAME}' in {WANDB_ENTITY}/{WANDB_PROJECT}...")

api = wandb.Api()
runs = api.runs(f"{WANDB_ENTITY}/{WANDB_PROJECT}", filters={"display_name": RUN_NAME})

if len(runs) == 0:
    print(f"Could not find run with name: {RUN_NAME}")
    print("Check the name or manually provide the Run ID (e.g., 'a1b2c3d4').")
else:
    run = runs[0]
    print(f"Found Run ID: {run.id} ({run.state})")

    for filename in FILES_TO_UPLOAD:
        file_path = os.path.join(CHECKPOINT_DIR, filename)

        if os.path.exists(file_path):
            print(f"Uploading {filename}...")
            # root='.' preserves the folder structure in wandb
            run.upload_file(file_path, root=".")
            print(f"   Success!")
        else:
            print(f"File not found: {file_path}")

    print("\nAll uploads finished. Check the 'Files' tab in WandB.")