In [1]:
%env CUDA_VISIBLE_DEVICES 0

import json
import os
import random

from datasets import load_from_disk
import evaluate
import numpy as np
import torch
from transformers import VisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, EarlyStoppingCallback

# Device string handed to `.to(...)` when the model is built in a later cell.
# Hard-coded to CUDA: the notebook pins CUDA_VISIBLE_DEVICES at the top of this cell.
device = "cuda"
# Opt in to TensorFloat-32 for CUDA matrix multiplications (PyTorch backend flag).
torch.backends.cuda.matmul.allow_tf32 = True

env: CUDA_VISIBLE_DEVICES=0


In [2]:
# Hub checkpoints for the two halves of the video-captioning model.
encoder = "facebook/timesformer-base-finetuned-k600"
decoder = "gpt2"

# Frame preprocessor and caption tokenizer.
image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
tokenizer = AutoTokenizer.from_pretrained(decoder)
# No pad token is assigned by the checkpoint here, so the EOS token doubles as padding.
tokenizer.pad_token = tokenizer.eos_token

# Combine the pretrained encoder and decoder into one model, then move it to the device.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(encoder, decoder)
model = model.to(device)

# Special-token ids the combined model config needs, taken from the tokenizer.
model.config.pad_token_id = tokenizer.pad_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id

# Generation settings stored on the model config.
model.config.early_stopping = True
model.config.num_beams = 4
model.config.max_length = 50

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
Some weights of the model checkpoint at facebook/timesformer-base-finetuned-k600 were not used when initializing TimesformerModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing TimesformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TimesformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.9.crossattention.bias', 'h.4.crossattention.q_attn.weight', 'h.7.crossa

In [2]:
# Load the VATEX dataset previously saved to local disk; the cell output shows
# `train` and `validation` splits with videoID / pixel_values / labels columns.
dataset = load_from_disk("/data1/caelen/dataset/vatex")
# Switch the output format in place so indexing yields torch tensors.
dataset.set_format("torch")
# Bare last expression: lets the notebook display the split summary.
dataset

DatasetDict({
    train: Dataset({
        features: ['videoID', 'pixel_values', 'labels'],
        num_rows: 22895
    })
    validation: Dataset({
        features: ['videoID', 'pixel_values', 'labels'],
        num_rows: 2643
    })
})

In [None]:
class VatexDataset(torch.utils.data.Dataset):
    """Expose a per-video dataset as one item per (video, caption) pair.

    Flat index ``i`` maps to video ``i // captions_per_video`` and caption
    ``i % captions_per_video`` of that video.

    Args:
        dataset: Indexable collection whose items are mappings with the keys
            ``"videoID"``, ``"pixel_values"`` and ``"labels"``; ``"labels"``
            must be indexable by caption position.
        captions_per_video: Number of captions exposed for each video.
            Defaults to 10, the value that was previously hard-coded.
            NOTE(review): assumes every video carries at least this many
            captions -- confirm against the preprocessing step.
    """

    def __init__(self, dataset, captions_per_video=10):
        self.dataset = dataset
        self.captions_per_video = captions_per_video

    def __len__(self):
        """Return the total number of (video, caption) pairs."""
        return self.captions_per_video * len(self.dataset)

    def __getitem__(self, idx):
        """Return the single (video, caption) item stored at flat index ``idx``."""
        video_idx, caption_idx = divmod(idx, self.captions_per_video)
        example = self.dataset[video_idx]
        return {
            "videoID": example["videoID"],
            "pixel_values": example["pixel_values"],
            "labels": example["labels"][caption_idx],
        }

    def __getitems__(self, idxs):
        """Batched fetch: return the items for every flat index in ``idxs``, in order."""
        return [self[idx] for idx in idxs]

In [None]:
# Directory passed to the training arguments as `output_dir`.
output_dir = "/data1/caelen/training/vatex"

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    # Optimisation schedule.
    learning_rate=5e-7,
    num_train_epochs=100,
    per_device_train_batch_size=6,
    per_device_eval_batch_size=6,
    tf32=True,
    # Log, evaluate and checkpoint once per epoch; reload the best checkpoint at the end.
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    predict_with_generate=True,
    # Data loading: keep every dataset column so the custom collator sees the raw examples.
    remove_unused_columns=False,
    dataloader_num_workers=8,
)

def collator(examples):
    """Collate dataset examples into a single batch of stacked tensors.

    Examples whose ``pixel_values`` hold exactly 16 frames take the "train"
    branch: one frame is drawn at random from each consecutive pair of frames
    (8 frames kept) and only the first entry of ``labels`` is used. All other
    examples take the "val" branch and are passed through unchanged.

    Args:
        examples: Sequence of mappings with ``"pixel_values"`` and ``"labels"``
            entries that support tensor-style indexing and ``torch.stack``.

    Returns:
        dict with ``"pixel_values"`` and ``"labels"``, each stacked along a
        new leading batch dimension.
    """
    pixel_values, labels = [], []
    for example in examples:
        # train
        if len(example["pixel_values"]) == 16:
            # Temporal jitter: keep frame 2k or 2k+1 (chosen at random) for k = 0..7.
            frame_idxs = [start + random.randint(0, 1) for start in range(0, 16, 2)]
            pixel_values.append(example["pixel_values"][frame_idxs])
            # NOTE(review): only the first caption is ever used for these clips -- confirm intended.
            labels.append(example["labels"][0])
        # val
        else:
            pixel_values.append(example["pixel_values"])
            labels.append(example["labels"])

    return {
        "pixel_values": torch.stack(pixel_values),
        "labels": torch.stack(labels),
    }
    
# Assemble the trainer from the model, arguments, dataset splits and collator.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=collator,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    # Early stopping with a patience of 5.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)

# Start fine-tuning.
trainer.train()