In [21]:
import autorootcwd
from src.model.model import FrozenInTime, compute_similarity
from src.data.data_loader import TextVideoDataLoader
import torch.optim as optim
import torch
from transformers import AutoTokenizer
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_loader = TextVideoDataLoader(
    dataset_name= "NTU",
    text_params= {"input": "text",},
    video_params= {"input_res": 224, "num_frames": 4,},
    data_dir= "data/nturgbd",
    metadata_dir= "data/nturgbd",
    split= 'train',
    tsfm_params= None,
    tsfm_split= None,
    cut= None,
    subsample= 1,
    sliding_window_stride= -1,
    reader= 'decord',
    batch_size=32,
    num_workers=1,
    shuffle=True,
)

model = FrozenInTime(
    video_params={"model": "SpaceTimeTransformer", "num_frames": 4, "arch_config":"base_patch16_224", "vit_init": "imagenet-21k", "attention_style":"frozen-in-time", "pretrained":True},
    text_params={"model": "distilbert-base-uncased", "pretrained": True},
    projection_dim=256,
).to(device)

for p in model.video_model.parameters(): p.requires_grad=False
for p in model.text_model.parameters():  p.requires_grad=False

######USING ATTENTION STYLE:  frozen-in-time


In [22]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

batch = next(iter(data_loader))
video_data = batch['video'].to(device)
text_data = tokenizer(
    batch['text'],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=tokenizer.model_max_length
).to(device)



In [8]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)

model.train()

for i in range(5):
    # 1. forward
    video_embeddings = model.compute_video(video_data)
    text_embeddings  = model.compute_text(text_data)

    # 2. positive 유사도만 뽑아서 loss 정의
    _, pos_sims = compute_similarity(video_embeddings, text_embeddings)
    loss = - pos_sims.mean()

    # 3. backward & step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Step {i+1:02d} — loss: {loss.item():.4f}")
    
torch.save(model.state_dict(), "finetuned_head.pth")

Step 01 — loss: 0.0243
Step 02 — loss: -0.5016
Step 03 — loss: -0.6397
Step 04 — loss: -0.7773
Step 05 — loss: -0.8865


In [23]:
import pickle
num_epochs = 10
log_interval = 100

model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    epoch_loss   = 0.0
    num_batches  = 0
    for batch_idx, batch in enumerate(data_loader, start=1):
        video_data = batch["video"].to(device)
        text_data = tokenizer(
            batch['text'],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=tokenizer.model_max_length
        ).to(device)

        video_embeddings = model.compute_video(video_data)
        text_embeddings = model.compute_text(text_data)

        _, pos_sims = compute_similarity(video_embeddings, text_embeddings)
        loss = -pos_sims.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        epoch_loss   += loss.item()
        num_batches  += 1

        if batch_idx % log_interval == 0:
            avg = running_loss / log_interval
            print(f"[Epoch {epoch+1}/{num_epochs}  Batch {batch_idx}]  loss = {avg:.4f}")
            running_loss = 0.0

    epoch_avg = epoch_loss / num_batches
    print(f"--- Epoch {epoch}/{num_epochs} finished, average loss = {epoch_avg:.4f} ---")

    ckpt = {"state_dict": model.state_dict()}
    ckpt_path = f"finetuned_epoch{epoch+1}.pkl"
    with open(ckpt_path, "wb") as f:
        pickle.dump(ckpt, f)

    print(f"Saved checkpoint to {ckpt_path}")
    print(f"Finished epoch {epoch+1}\n")

--- Epoch 0/10 finished, average loss = 0.0041 ---
Saved checkpoint to finetuned_epoch1.pkl
Finished epoch 1

--- Epoch 1/10 finished, average loss = 0.0031 ---
Saved checkpoint to finetuned_epoch2.pkl
Finished epoch 2

--- Epoch 2/10 finished, average loss = 0.0038 ---
Saved checkpoint to finetuned_epoch3.pkl
Finished epoch 3

--- Epoch 3/10 finished, average loss = 0.0035 ---
Saved checkpoint to finetuned_epoch4.pkl
Finished epoch 4

--- Epoch 4/10 finished, average loss = 0.0047 ---
Saved checkpoint to finetuned_epoch5.pkl
Finished epoch 5

--- Epoch 5/10 finished, average loss = 0.0031 ---
Saved checkpoint to finetuned_epoch6.pkl
Finished epoch 6

--- Epoch 6/10 finished, average loss = 0.0039 ---
Saved checkpoint to finetuned_epoch7.pkl
Finished epoch 7

--- Epoch 7/10 finished, average loss = 0.0038 ---
Saved checkpoint to finetuned_epoch8.pkl
Finished epoch 8

--- Epoch 8/10 finished, average loss = 0.0039 ---
Saved checkpoint to finetuned_epoch9.pkl
Finished epoch 9

--- Epoch 

In [25]:
import pickle
import torch
import torch.nn.functional as F
from torch.optim import Adam
from src.model.text_augmentation import augment_text_labels, average_augmented_embeddings

batch = next(iter(data_loader))
video_data = batch['video'].to(device)
text_data = tokenizer(
    batch['text'],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=tokenizer.model_max_length
).to(device)

# 1) backbone freeze
for p in model.video_model.parameters(): p.requires_grad = False
for p in model.text_model.parameters():  p.requires_grad = False

# 2) 사전 계산: 모든 라벨에 대해 증강 텍스트 → 임베딩 → 평균
# dataset.metadata['raw_captions'] 에 모든 캡션(라벨) 저장돼 있다고 가정
unique_labels = sorted(set(data_loader.dataset.metadata['caption']))
aug_data       = augment_text_labels(unique_labels)
augmented_texts, label_groups = aug_data['augmented_texts'], aug_data['label_groups']

all_aug_embeds = []
model.eval()
with torch.no_grad():
    # 한번에 너무 많이 넣으면 OOM 날 수 있으니 쪼개서
    chunk = 32
    for i in range(0, len(augmented_texts), chunk):
        batch_txts = augmented_texts[i:i+chunk]
        toks = tokenizer(batch_txts, return_tensors="pt",
                         padding=True, truncation=True,
                         max_length=tokenizer.model_max_length).to(device)
        emb = model.compute_text(toks)           # [chunk, dim]
        all_aug_embeds.append(emb.cpu())
all_aug_embeds = torch.cat(all_aug_embeds, dim=0).to(device)  # [num_templates, dim]

averaged_embeds = average_augmented_embeddings(all_aug_embeds, label_groups)
# averaged_embeds: [num_labels, dim]

# 라벨→인덱스 매핑
label_to_idx = {lbl:i for i,lbl in enumerate(unique_labels)}

# 3) optimizer: head만
optimizer = Adam(
    list(model.vid_proj.parameters()) +
    list(model.txt_proj.parameters()),
    lr=1e-5
)

# 4) fine-tuning loop
num_epochs   = 10
log_interval = 100
temperature  = 0.07

model.train()
for epoch in range(1, num_epochs+1):
    running_loss = 0.0
    epoch_loss   = 0.0
    n_batches    = 0

    for batch_idx, batch in enumerate(data_loader, start=1):
        video_data = batch["video"].to(device)
        # 원래 batch['text'] 가 raw caption 문자열
        batch_labels = batch['text']
        idxs = [label_to_idx[lbl] for lbl in batch_labels]
        t_emb = averaged_embeds[idxs]           # [B, dim]

        v_emb = model.compute_video(video_data) # [B, dim]

        # contrastive InfoNCE
        logits, _ = compute_similarity(v_emb, t_emb)
        logits = logits / temperature
        labels = torch.arange(len(v_emb), device=device)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        epoch_loss   += loss.item()
        n_batches    += 1

        if batch_idx % log_interval == 0:
            print(f"[Epoch {epoch}/{num_epochs}  Batch {batch_idx}]"
                  f"  batch-loss = {running_loss/log_interval:.4f}")
            running_loss = 0.0

    avg_epoch_loss = epoch_loss / n_batches
    print(f"--- Epoch {epoch}/{num_epochs} avg loss = {avg_epoch_loss:.4f} ---")

    # 5) 체크포인트 저장
    ckpt = {"state_dict": model.state_dict()}
    with open(f"finetuned_epoch{epoch}.pkl", "wb") as f:
        pickle.dump(ckpt, f)
    print(f"Saved checkpoint to finetuned_epoch{epoch}.pkl\n")


--- Epoch 1/10 avg loss = 3.4503 ---
Saved checkpoint to finetuned_epoch1.pkl

--- Epoch 2/10 avg loss = 3.4462 ---
Saved checkpoint to finetuned_epoch2.pkl

--- Epoch 3/10 avg loss = 3.4431 ---
Saved checkpoint to finetuned_epoch3.pkl

--- Epoch 4/10 avg loss = 3.4375 ---
Saved checkpoint to finetuned_epoch4.pkl

--- Epoch 5/10 avg loss = 3.4338 ---
Saved checkpoint to finetuned_epoch5.pkl

--- Epoch 6/10 avg loss = 3.4311 ---
Saved checkpoint to finetuned_epoch6.pkl

--- Epoch 7/10 avg loss = 3.4242 ---
Saved checkpoint to finetuned_epoch7.pkl

--- Epoch 8/10 avg loss = 3.4207 ---
Saved checkpoint to finetuned_epoch8.pkl

--- Epoch 9/10 avg loss = 3.4125 ---
Saved checkpoint to finetuned_epoch9.pkl

--- Epoch 10/10 avg loss = 3.4100 ---
Saved checkpoint to finetuned_epoch10.pkl

