In [9]:
import autorootcwd
from src.model.model import FrozenInTime, compute_similarity
from src.data.data_loader import TextVideoDataLoader
import torch.optim as optim
import torch
from transformers import AutoTokenizer
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_loader = TextVideoDataLoader(
    dataset_name= "NTU",
    text_params= {"input": "text",},
    video_params= {"input_res": 224, "num_frames": 4,},
    data_dir= "data/nturgbd",
    metadata_dir= "data/nturgbd",
    split= 'train',
    tsfm_params= None,
    tsfm_split= None,
    cut= None,
    subsample= 1,
    sliding_window_stride= -1,
    reader= 'decord',
    batch_size=32,
    num_workers=1,
    shuffle=True,
)

model = FrozenInTime(
    video_params={"model": "SpaceTimeTransformer", "num_frames": 4, "arch_config":"base_patch16_224", "vit_init": "imagenet-21k", "attention_style":"frozen-in-time", "pretrained":True},
    text_params={"model": "distilbert-base-uncased", "pretrained": True},
    projection_dim=256,
).to(device)

for p in model.video_model.parameters(): p.requires_grad=False
for p in model.text_model.parameters():  p.requires_grad=False

######USING ATTENTION STYLE:  frozen-in-time


In [7]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

batch = next(iter(data_loader))
video_data = batch['video'].to(device)
text_data = tokenizer(
    batch['text'],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=tokenizer.model_max_length
).to(device)



In [None]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)

model.train()

for i in range(5):
    # 1. forward
    video_embeddings = model.compute_video(video_data)
    text_embeddings  = model.compute_text(text_data)

    # 2. positive 유사도만 뽑아서 loss 정의
    _, pos_sims = compute_similarity(video_embeddings, text_embeddings)
    loss = - pos_sims.mean()

    # 3. backward & step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Step {i+1:02d} — loss: {loss.item():.4f}")
    
torch.save(model.state_dict(), "finetuned_head.pth")

Step 01 — loss: 0.0243
Step 02 — loss: -0.5016
Step 03 — loss: -0.6397
Step 04 — loss: -0.7773
Step 05 — loss: -0.8865


In [None]:
num_epochs = 5
log_interval = 100

model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for batch_idx, batch in enumerate(data_loader, start=1):
        video_data = batch["video"].to(device)
        text_data = tokenizer(
            batch['text'],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=tokenizer.model_max_length
        ).to(device)

        video_embeddings = model.compute_video(video_data)
        text_embeddings = model.compute_text(text_data)

        _, pos_sims = compute_similarity(video_embeddings, text_embeddings)
        loss = -pos_sims.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % log_interval == 0:
            avg = running_loss / log_interval
            print(f"[Epoch {epoch+1}/{num_epochs}  Batch {batch_idx}]  loss = {avg:.4f}")
            running_loss = 0.0
    torch.save(model.state_dict(), f"finetuned_epoch{epoch+1}.pth")
    print(f"Finished epoch {epoch+1}\n")