In [1]:
from video_moment_retrieval.utils.logging import init_logging, logger
from video_moment_retrieval.datasets.qv_highlights import QVDataset, pad_collate
from video_moment_retrieval.moment_detr.model import VideoDetrConfig, MomentDetr
from transformers import TrainingArguments, Trainer, EvalPrediction

init_logging()

In [2]:
from typing import Any
import numpy.typing as npt
from scipy.special import softmax
from video_moment_retrieval.detr_matcher.matcher import center_to_edges
from video_moment_retrieval.eval.eval import compute_mr_ap
import numpy as np


def process_preds_and_labels(eval_preds: EvalPrediction) -> tuple[list[dict[str, Any]], list[list[float]]]:
    # eval.label_ids -> list[list[dict[str, np.array]]]
    # eval.predictions -> list[tuple[np.array, np.array]]
    labels = []
    predictions = []
    for batch_idx in range(len(eval_preds.label_ids)):
        batch_labels = eval_preds.label_ids[batch_idx]
        batch_predictions = eval_preds.predictions[batch_idx]
        # moments and scores are each batch_size x 10 x 2
        moments, scores = batch_predictions
        scores = softmax(scores, -1)[..., 0]  # batch_size x 10
        
        for video_labels, video_predictions, video_scores in zip(batch_labels, moments, scores):
            qid_label, qid_prediction = [], []
            gt_windows = center_to_edges(video_labels["boxes"]) * video_labels["duration"]
            pred_windows = center_to_edges(video_predictions) * video_labels["duration"]
            
            qid_prediction = [(window[0].item(), window[1].item(), score) for window, score in zip(pred_windows, video_scores)]
            qid_label = [(window[0].item(), window[1].item()) for window in gt_windows]
            
        
            labels.append(qid_label)
            predictions.append(qid_prediction)
    
    return labels, predictions  

def compute_metrics(eval_preds: EvalPrediction):
    labels, predictions = process_preds_and_labels(eval_preds)
    metrics_dict = compute_mr_ap(predictions, labels, num_workers=8)
    
    return {
        "mAP@0.5": metrics_dict["0.5"],
        "mAP@0.7": metrics_dict["0.7"],
        "mAP": metrics_dict["average"],
    }

In [3]:

train_dataset = QVDataset("qvhighlights_features\\text_features", "qvhighlights_features\\video_features", "qvhighlights_features\\highlight_train_release.jsonl")
eval_dataset = QVDataset("qvhighlights_features\\text_features", "qvhighlights_features\\video_features", "qvhighlights_features\\highlight_val_release.jsonl")

config = VideoDetrConfig(
    d_model=256,
    encoder_layers=2,
    encoder_ffn_dim=1024,
    decoder_layers=2,
    decoder_ffn_dim=1024,
    num_queries=10,
    dropout=0.1,
    activation_dropout=0.1,
    giou_cost=1,
    bbox_cost=10,
    class_cost=4,
    giou_loss_coefficient=1,
    bbox_loss_coefficient=10,
    ce_loss_coefficient=4,
    num_labels=1,
    saliency_loss_coefficient=2
)
logger.info("Running model using config %s", config)

model = MomentDetr(config)

train_args = TrainingArguments(
    "./train_output",
    per_device_train_batch_size=32,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=500,
    num_train_epochs=200,
    save_steps=1000,
    eval_strategy="steps",
    eval_steps=1000,
    load_best_model_at_end=True,
    greater_is_better=True,
    max_grad_norm=0.1,
    label_names=["labels"],
    weight_decay=1e-4,
    eval_do_concat_batches=False,
    metric_for_best_model="mAP"
    # use_cpu=True
)

trainer = Trainer(
    model=model,
    args=train_args,
    data_collator=pad_collate,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics
)

trainer.train()

2024-06-19 07:42:02,421 - INFO video_moment_retrieval - 2358234223.py:22 - Running model using config VideoDetrConfig {
  "activation_dropout": 0.1,
  "activation_function": "relu",
  "attention_dropout": 0.0,
  "auxiliary_loss": false,
  "backbone": "resnet50",
  "backbone_config": null,
  "backbone_kwargs": {
    "in_chans": 3,
    "out_indices": [
      1,
      2,
      3,
      4
    ]
  },
  "bbox_cost": 10,
  "bbox_loss_coefficient": 10,
  "ce_loss_coefficient": 4,
  "class_cost": 4,
  "d_model": 256,
  "decoder_attention_heads": 8,
  "decoder_ffn_dim": 1024,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 2,
  "dice_loss_coefficient": 1,
  "dilation": false,
  "dropout": 0.1,
  "encoder_attention_heads": 8,
  "encoder_ffn_dim": 1024,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 2,
  "eos_coefficient": 0.1,
  "giou_cost": 1,
  "giou_loss_coefficient": 1,
  "hinge_loss_margin": 0.2,
  "id2label": {
    "0": "LABEL_0"
  },
  "init_std": 0.02,
  "init_xavier_std": 1.0,
  "is_enc

Step,Training Loss,Validation Loss,Map@0.5,Map@0.7,Map
1000,5.3054,6.399514,13.07,5.51,5.34
2000,4.299,4.339568,21.63,9.26,9.48
3000,4.0965,4.341559,22.02,9.8,10.15
4000,3.9978,4.376107,25.36,11.95,11.93
5000,3.8967,4.363026,22.54,9.09,9.74
6000,3.8175,4.394557,25.25,11.04,11.59
7000,3.739,4.393754,28.06,12.01,12.45
8000,3.5714,4.651382,30.79,14.69,14.7
9000,3.4154,4.394294,35.41,18.43,17.27
10000,3.2825,4.613762,37.32,18.62,18.04


TrainOutput(global_step=22600, training_loss=3.2124608909134316, metrics={'train_runtime': 6699.025, 'train_samples_per_second': 215.494, 'train_steps_per_second': 3.374, 'total_flos': 0.0, 'train_loss': 3.2124608909134316, 'epoch': 200.0})

In [31]:
from torch.utils.data import DataLoader
from video_moment_retrieval.datasets.qv_highlights import QVDataset, pad_collate

val_dataset = QVDataset("qvhighlights_features\\text_features", "qvhighlights_features\\video_features", "qvhighlights_features\\highlight_val_release.jsonl")
val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=pad_collate, shuffle=True)
batch = next(iter(val_loader))
model = model.to("cpu")
output = model(**batch)
print(batch["labels"][0]["boxes"])
scores = output.logits.softmax(axis=-1)[0, :, 0]
moments = output.predicted_moments[0, scores > 0.5, :]
print(moments)

tensor([[0.3800, 0.2000]])
tensor([[0.5170, 0.1521],
        [0.2624, 0.1437],
        [0.8812, 0.1701],
        [0.6829, 0.2108]], grad_fn=<IndexBackward0>)
