In [1]:
import autorootcwd
from src.model.model import FrozenInTime, compute_similarity
from src.data.data_loader import TextVideoDataLoader
import torch.optim as optim
import torch
from transformers import AutoTokenizer
import matplotlib.pyplot as plt

# Seed torch's global RNG (all devices) so the shuffled batch order and any
# randomly initialised model weights (presumably the new projection layers)
# are reproducible from run to run.
SEED = 42
torch.manual_seed(SEED)

# Clip geometry shared by the data loader and the video encoder. Defined once
# so the two configurations cannot silently drift apart.
NUM_FRAMES = 4
INPUT_RES = 224  # keep in sync with the "base_patch16_224" arch_config below

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NTU text-video pairs (data/nturgbd), training split.
data_loader = TextVideoDataLoader(
    dataset_name="NTU",
    text_params={"input": "text"},
    video_params={"input_res": INPUT_RES, "num_frames": NUM_FRAMES},
    data_dir="data/nturgbd",
    metadata_dir="data/nturgbd",
    split="train",
    tsfm_params=None,
    tsfm_split=None,
    cut=None,
    subsample=1,
    sliding_window_stride=-1,
    reader="decord",
    batch_size=32,
    num_workers=1,
    shuffle=True,
)

model = FrozenInTime(
    video_params={
        "model": "SpaceTimeTransformer",
        "num_frames": NUM_FRAMES,
        "arch_config": "base_patch16_224",
        "vit_init": "imagenet-21k",
        "attention_style": "frozen-in-time",
        "pretrained": True,
    },
    text_params={"model": "distilbert-base-uncased", "pretrained": True},
    projection_dim=256,
).to(device)

# Freeze both pretrained backbones; only parameters outside them remain
# trainable in the cells below.
for frozen_module in (model.video_model, model.text_model):
    for p in frozen_module.parameters():
        p.requires_grad = False

  from .autonotebook import tqdm as notebook_tqdm


######USING ATTENTION STYLE:  frozen-in-time


In [2]:
# Tokenizer for the same text backbone the model was configured with.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Pull one batch and move both modalities onto the target device. The
# sanity-check cell that follows trains repeatedly on exactly this batch.
batch = next(iter(data_loader))
video_data = batch['video'].to(device)

# Pad to the longest caption in the batch, truncating anything longer than
# the tokenizer's own maximum length.
tokenize_kwargs = dict(
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=tokenizer.model_max_length,
)
text_data = tokenizer(batch['text'], **tokenize_kwargs).to(device)

In [8]:
# Sanity check: take a few optimizer steps on the single batch prepared above
# and confirm the loss moves, i.e. gradients reach the trainable parameters.
#
# NOTE(review): this objective only rewards the positive-pair similarities and
# has no negative term, so it can be minimised without learning to tell pairs
# apart (e.g. by pushing all embeddings together). It also updates `model` in
# place, so the full training cell below starts from this modified state
# unless the setup cell is re-run first.
NUM_SANITY_STEPS = 5

optimizer = optim.Adam(model.parameters(), lr=1e-4)
model.train()

for step in range(1, NUM_SANITY_STEPS + 1):
    # Forward pass through both encoders.
    video_embeddings = model.compute_video(video_data)
    text_embeddings = model.compute_text(text_data)

    # Loss = negative mean of the positive-pair similarities only.
    _sims, pos_sims = compute_similarity(video_embeddings, text_embeddings)
    loss = -pos_sims.mean()

    # Backward pass and parameter update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Step {step:02d} — loss: {loss.item():.4f}")

# NOTE(review): despite the filename, this saves the full model state_dict,
# frozen backbones included.
torch.save(model.state_dict(), "finetuned_head.pth")

Step 01 — loss: 0.0243
Step 02 — loss: -0.5016
Step 03 — loss: -0.6397
Step 04 — loss: -0.7773
Step 05 — loss: -0.8865


In [None]:
import pickle
import torch
import torch.nn.functional as F

# Contrastive fine-tuning of the trainable (non-backbone) parameters over the
# full training loader. Training resumes from whatever state `model` is in,
# including any updates from the sanity-check cell above if it was run.

# Defensive: guarantee the backbones are frozen regardless of earlier cells.
for frozen_module in (model.video_model, model.text_model):
    for p in frozen_module.parameters():
        p.requires_grad = False

learning_rate = 1e-4
temperature = 0.07
num_epochs = 10
log_interval = 100

# Hand the optimizer only what it can actually update. This fails fast with a
# ValueError if nothing is left trainable.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=learning_rate)


def symmetric_contrastive_loss(sim: torch.Tensor, temperature: float) -> torch.Tensor:
    """Symmetric InfoNCE loss over a square video-text similarity matrix.

    Args:
        sim: similarity matrix whose diagonal holds the matched pairs, i.e.
            row i (video i) is paired with column i (text i). It must be
            square (batch, batch) for the shared ``labels`` to be valid.
        temperature: softmax temperature applied to the similarities.

    Returns:
        Scalar tensor: the mean of the video->text and text->video
        cross-entropy losses.
    """
    labels = torch.arange(sim.size(0), device=sim.device)
    loss_v2t = F.cross_entropy(sim / temperature, labels)
    loss_t2v = F.cross_entropy(sim.t() / temperature, labels)
    return (loss_v2t + loss_t2v) / 2


# NOTE(review): model.train() also puts the frozen backbones in training mode,
# so any dropout inside them stays active while they are frozen. Consider
# model.video_model.eval() / model.text_model.eval() if that is not intended.
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    epoch_loss = 0.0
    num_batches = 0

    for batch_idx, batch in enumerate(data_loader, start=1):
        video_data = batch["video"].to(device)
        text_data = tokenizer(
            batch['text'],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=tokenizer.model_max_length
        ).to(device)

        video_embeddings = model.compute_video(video_data)
        text_embeddings = model.compute_text(text_data)

        sim, _ = compute_similarity(video_embeddings, text_embeddings)
        loss = symmetric_contrastive_loss(sim, temperature)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() synchronises with the device; read the value once per step.
        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        num_batches += 1

        if batch_idx % log_interval == 0:
            avg = running_loss / log_interval
            print(f"[Epoch {epoch+1}/{num_epochs}  Batch {batch_idx}]  loss = {avg:.4f}")
            running_loss = 0.0

    if num_batches == 0:
        raise RuntimeError("data_loader yielded no batches; nothing to train on")

    epoch_avg = epoch_loss / num_batches
    # 1-based epoch, consistent with the batch log and checkpoint filename.
    print(f"--- Epoch {epoch+1}/{num_epochs} finished, average loss = {epoch_avg:.4f} ---")

    # NOTE(review): checkpoints are raw pickles of the state_dict (device
    # tensors included). Only unpickle files you trust. Switching to
    # torch.save/torch.load would be more idiomatic, but anything that already
    # reads these .pkl files with pickle.load would then need updating.
    ckpt = {"state_dict": model.state_dict()}
    ckpt_path = f"finetuned_epoch{epoch+1}.pkl"
    with open(ckpt_path, "wb") as ckpt_file:
        pickle.dump(ckpt, ckpt_file)

    print(f"Saved checkpoint to {ckpt_path}")

--- Epoch 0/10 finished, average loss = 3.4658 ---
Saved checkpoint to finetuned_epoch1.pkl
--- Epoch 1/10 finished, average loss = 3.4155 ---
Saved checkpoint to finetuned_epoch2.pkl
--- Epoch 2/10 finished, average loss = 3.3378 ---
Saved checkpoint to finetuned_epoch3.pkl
--- Epoch 3/10 finished, average loss = 3.1951 ---
Saved checkpoint to finetuned_epoch4.pkl
--- Epoch 4/10 finished, average loss = 3.0604 ---
Saved checkpoint to finetuned_epoch5.pkl
--- Epoch 5/10 finished, average loss = 2.9387 ---
Saved checkpoint to finetuned_epoch6.pkl
--- Epoch 6/10 finished, average loss = 2.8747 ---
Saved checkpoint to finetuned_epoch7.pkl
--- Epoch 7/10 finished, average loss = 2.8019 ---
Saved checkpoint to finetuned_epoch8.pkl
--- Epoch 8/10 finished, average loss = 2.7584 ---
Saved checkpoint to finetuned_epoch9.pkl
--- Epoch 9/10 finished, average loss = 2.7094 ---
Saved checkpoint to finetuned_epoch10.pkl
