In [1]:
from pathlib import Path

from transformers import TrainingArguments, Trainer, EvalPrediction, set_seed

from video_moment_retrieval.utils.logging import init_logging, logger
from video_moment_retrieval.datasets.qv_highlights import QVDataset, pad_collate
from video_moment_retrieval.moment_detr.model import VideoDetrConfig, MomentDetr

# Initialise the project's logging helper (video_moment_retrieval.utils.logging)
# before any later cell emits messages through `logger`, e.g. the config dump
# in the training cell.
init_logging()

In [2]:
from typing import Any
import numpy.typing as npt
from scipy.special import softmax
from video_moment_retrieval.detr_matcher.matcher import center_to_edges
from video_moment_retrieval.eval.eval import compute_mr_ap
import numpy as np



def process_preds_and_labels(
    eval_preds: EvalPrediction,
    clip_length: float = 2.0,
    max_timestamp: float = 150.0,
) -> tuple[list[list[tuple[float, float]]], list[list[tuple[float, float, float]]]]:
    """Turn per-batch Moment-DETR eval outputs into GT windows and scored predicted windows.

    Expects the Trainer to run with ``eval_do_concat_batches=False``, so that both
    ``eval_preds.label_ids`` and ``eval_preds.predictions`` are indexed by
    evaluation batch:

    * ``label_ids[b]``: one label dict per video. This function reads ``"boxes"``
      (normalised spans in the layout ``center_to_edges`` expects) and
      ``"duration"`` (multiplied into the spans, so presumably seconds).
    * ``predictions[b]``: a ``(moments, logits)`` pair. ``moments`` holds one span
      per query and ``logits`` holds two class logits per query. Both have shape
      ``batch_size x num_queries x 2``.

    Args:
        eval_preds: Predictions/labels gathered by ``transformers.Trainer``.
        clip_length: Predicted boundaries are rounded to the nearest multiple of
            this value (default 2, the previously hard-coded grid).
        max_timestamp: Predicted boundaries are clipped into ``[0, max_timestamp]``
            (default 150, the previously hard-coded bound).

    Returns:
        ``(labels, predictions)``, with one entry per video in evaluation order.
        ``labels[i]`` is a list of ``(start, end)`` ground-truth windows.
        ``predictions[i]`` is a list of ``(start, end, score)`` tuples, one per
        query. All values are Python floats.

    Raises:
        ValueError: If the numbers of label batches and prediction batches differ.
    """
    label_batches = eval_preds.label_ids
    prediction_batches = eval_preds.predictions
    if len(label_batches) != len(prediction_batches):
        raise ValueError(
            f"Got {len(label_batches)} label batches but {len(prediction_batches)} "
            "prediction batches; was the Trainer run with eval_do_concat_batches=False?"
        )

    labels: list[list[tuple[float, float]]] = []
    predictions: list[list[tuple[float, float, float]]] = []
    for batch_labels, (moments, logits) in zip(label_batches, prediction_batches):
        # Probability of class index 0 for every query (batch_size x num_queries).
        # With num_labels=1 the other index is presumably DETR's "no object" class.
        foreground_scores = softmax(logits, -1)[..., 0]

        for video_labels, video_moments, video_scores in zip(batch_labels, moments, foreground_scores):
            duration = video_labels["duration"]
            gt_windows = center_to_edges(video_labels["boxes"]) * duration
            pred_windows = center_to_edges(video_moments) * duration
            # Snap predicted boundaries to the clip grid, then keep them in range.
            pred_windows = np.round(pred_windows / clip_length) * clip_length
            pred_windows = np.clip(pred_windows, 0, max_timestamp)

            labels.append([(float(window[0]), float(window[1])) for window in gt_windows])
            predictions.append(
                [
                    (float(window[0]), float(window[1]), float(score))
                    for window, score in zip(pred_windows, video_scores)
                ]
            )

    return labels, predictions

def compute_metrics(eval_preds: EvalPrediction):
    """``compute_metrics`` hook for ``transformers.Trainer``.

    Converts the raw evaluation output with ``process_preds_and_labels`` and scores
    it with ``compute_mr_ap`` (called with ``num_workers=8``). Returns three values
    from that result: the ``"0.5"`` and ``"0.7"`` entries (presumably IoU
    thresholds) and the ``"average"`` entry. They are exposed as ``mAP@0.5``,
    ``mAP@0.7`` and ``mAP``; ``mAP`` is the key the training cell selects the best
    checkpoint on.
    """
    gt_windows, scored_windows = process_preds_and_labels(eval_preds)
    ap_by_key = compute_mr_ap(scored_windows, gt_windows, num_workers=8)

    # Reported metric name -> key in compute_mr_ap's result dict (order preserved).
    reported_to_source = {
        "mAP@0.5": "0.5",
        "mAP@0.7": "0.7",
        "mAP": "average",
    }
    return {name: ap_by_key[source] for name, source in reported_to_source.items()}

In [None]:

# Reproducibility: TrainingArguments.seed is applied by the Trainer, i.e. only after
# MomentDetr below has already initialised its weights (see the `seed` / `model_init`
# notes in the TrainingArguments docs), so seed explicitly before building the model.
SEED = 42
# Feature/annotation root. Joined with pathlib so the paths work on any OS
# (the previous "\\"-separated literals only resolved on Windows).
DATA_DIR = Path("qvhighlights_features")


def load_qv_split(annotation_file: str) -> QVDataset:
    """Build a QVDataset for one annotation file, sharing the text/video feature dirs."""
    return QVDataset(
        str(DATA_DIR / "text_features"),
        str(DATA_DIR / "video_features"),
        str(DATA_DIR / annotation_file),
    )


train_dataset = load_qv_split("highlight_train_release.jsonl")
eval_dataset = load_qv_split("highlight_val_release.jsonl")

config = VideoDetrConfig(
    auxiliary_loss=True,
    # NOTE(review): must match the last dim of the video feature files;
    # presumably 512-d clip features plus 2 extra channels. Confirm.
    video_embedding_dim=514,
    d_model=256,
    encoder_layers=2,
    encoder_ffn_dim=1024,
    decoder_layers=2,
    decoder_ffn_dim=1024,
    num_queries=10,  # compute_metrics scores one predicted window per query
    dropout=0.1,
    activation_dropout=0.1,
    giou_cost=1,
    bbox_cost=10,
    class_cost=4,
    giou_loss_coefficient=1,
    bbox_loss_coefficient=10,
    ce_loss_coefficient=4,
    saliency_loss_coefficient=1,
    num_labels=1,
)
logger.info("Running model using config %s", config)

set_seed(SEED)  # must precede model construction so weight init is reproducible
model = MomentDetr(config)

train_args = TrainingArguments(
    "./train_output",
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=1e-4,
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=500,
    num_train_epochs=200,
    # load_best_model_at_end requires save_steps to be a multiple of eval_steps.
    save_steps=1000,
    eval_strategy="steps",
    eval_steps=1000,
    load_best_model_at_end=True,
    greater_is_better=True,
    max_grad_norm=0.1,
    label_names=["labels"],
    weight_decay=1e-4,
    # compute_metrics walks predictions batch by batch, so keep batches separate.
    eval_do_concat_batches=False,
    metric_for_best_model="mAP",  # key returned by compute_metrics
    seed=SEED,
)

trainer = Trainer(
    model=model,
    args=train_args,
    data_collator=pad_collate,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

2024-07-22 08:13:14,751 - INFO video_moment_retrieval - 1329379645.py:24 - Running model using config VideoDetrConfig {
  "activation_dropout": 0.1,
  "activation_function": "relu",
  "attention_dropout": 0.0,
  "auxiliary_loss": true,
  "backbone": "resnet50",
  "backbone_config": null,
  "backbone_kwargs": {
    "in_chans": 3,
    "out_indices": [
      1,
      2,
      3,
      4
    ]
  },
  "bbox_cost": 10,
  "bbox_loss_coefficient": 10,
  "ce_loss_coefficient": 4,
  "class_cost": 4,
  "d_model": 256,
  "decoder_attention_heads": 8,
  "decoder_ffn_dim": 1024,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 2,
  "dice_loss_coefficient": 1,
  "dilation": false,
  "dropout": 0.1,
  "encoder_attention_heads": 8,
  "encoder_ffn_dim": 1024,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 2,
  "eos_coefficient": 0.1,
  "giou_cost": 1,
  "giou_loss_coefficient": 1,
  "hinge_loss_margin": 0.2,
  "id2label": {
    "0": "LABEL_0"
  },
  "init_std": 0.02,
  "init_xavier_std": 1.0,
  "is_enco

Step,Training Loss,Validation Loss,Map@0.5,Map@0.7,Map
1000,10.0908,9.196931,18.17,5.26,6.42
2000,8.7266,8.57294,21.0,7.17,7.97
3000,8.197,8.266402,23.24,9.6,10.01
4000,8.0366,8.220374,21.19,8.08,8.95
5000,7.8847,8.268519,26.0,11.63,11.69
6000,7.6917,8.106082,29.71,13.05,13.27
7000,7.4822,7.940238,34.18,16.42,16.21
8000,7.165,7.851285,39.9,20.05,19.51
9000,6.9372,7.8535,44.17,24.85,23.09
10000,6.6397,7.95366,46.16,27.75,24.14
