## Extract video features

In [38]:
import ffmpeg
import os

output_path = "test-frames"
os.makedirs(output_path, exist_ok=True)
output_format = "test-frames\\{video_id}-%03d.jpg"

test_video = ".\\qvhilights_videos\\videos\\ZzGSP0ySLD0_510.0_660.0.mp4"
video_id = os.path.basename(test_video).rsplit(".", 1)[0]
print(video_id)
video_stream = ffmpeg.input(test_video)
output = ffmpeg.output(video_stream, output_format.format(video_id=video_id), r=0.5)

ZzGSP0ySLD0_510.0_660.0


In [43]:
ffmpeg.probe(test_video)["streams"][0]["width"]

534

In [94]:
from transformers import CLIPImageProcessor, CLIPVisionModel, CLIPTextModel, CLIPTokenizerFast

model_checkpoint = "openai/clip-vit-base-patch32"
model = CLIPVisionModel.from_pretrained(model_checkpoint).to("cuda")
model.eval()
text_model = CLIPTextModel.from_pretrained(model_checkpoint)
text_model.eval()
processor = CLIPImageProcessor.from_pretrained(model_checkpoint)
text_processor = CLIPTokenizerFast.from_pretrained(model_checkpoint)



In [112]:
import numpy as np
import numpy.typing as npt
import torch

@torch.no_grad()
def encode_video(video_path: str, fps=0.5) -> npt.NDArray[np.float32]:
    video_info = ffmpeg.probe(video_path)
    frame_width = video_info["streams"][0]["width"]
    frame_height = video_info["streams"][0]["height"]
    
    process = ffmpeg.input(
        video_path
    ).output("pipe:", r=fps, format='rawvideo', pix_fmt="rgb24"
    ).run_async(pipe_stdout=True, pipe_stderr=True)
    video_array = []
    
    while True:
        in_bytes = process.stdout.read(frame_width * frame_height * 3)
        if not in_bytes:
            break
        in_frame = np.frombuffer(in_bytes, np.uint8).reshape((frame_height, frame_width, 3))
        video_array.append(in_frame)
    video_array = np.array(video_array, dtype=np.uint8)
    # process.kill()
    model_input = processor(video_array, return_tensors="pt").to("cuda")
    model_output = model(**model_input)
    return model_output.pooler_output.cpu().numpy()

@torch.no_grad()
def encode_text(text: str) -> npt.NDArray[np.float32]:
    model_input = text_processor(text, return_tensors="pt")
    model_output = text_model(**model_input)
    return model_output.last_hidden_state.squeeze(0).cpu().numpy()

In [85]:
video_path = ".\\qvhilights_videos\\videos\\ZzGSP0ySLD0_510.0_660.0.mp4"
encoded_video = encode_video(video_path, 0.49)
print(encoded_video.shape)

<class 'subprocess.Popen'>
(75, 768)


In [102]:
encoded_text = encode_text('A video covering hill and water from a boat')
print(encoded_text.shape)

(11, 512)


In [93]:
import jsonlines

with jsonlines.open(".\\qvhilights_videos\\highlight_test_release.jsonl") as reader:
    for el in reader:
        print(el)
        break

{'qid': 3158, 'query': 'A video covering hill and water from a boat', 'duration': 150, 'vid': '_6hnl_BrFvs_360.0_510.0'}


In [118]:
import jsonlines
import tqdm

def encode_dataset(dataset_path: str, dataset_info_file: str, output_path: str) -> None:
    os.makedirs(output_path, exist_ok=True)
    video_features_output_path = os.path.join(output_path, "video_features")
    os.makedirs(video_features_output_path, exist_ok=True)
    
    text_features_output_path = os.path.join(output_path, "text_features")
    os.makedirs(text_features_output_path, exist_ok=True)
    
    with jsonlines.open(dataset_info_file) as reader:
        for el in tqdm.tqdm(reader):
            video_in_path = os.path.join(dataset_path, el["vid"] + ".mp4")
            video_out_path = os.path.join(video_features_output_path, el["vid"])
            text_out_path = os.path.join(text_features_output_path, str(el["qid"]))
            
            if not os.path.exists(video_out_path):
                video_features = encode_video(video_in_path, fps=0.49)
                np.save(video_out_path, video_features)
            
            if not os.path.exists(text_out_path):
                text_features = encode_text(el["query"])
                np.save(text_out_path, text_features)

In [None]:
encode_dataset(".\\qvhilights_videos\\videos", ".\\qvhilights_videos\\highlight_val_release.jsonl", "dis_features")

## MomentDETR Own Implementation

### TODO

- [ ] Add saliency loss
- [ ] Add evaluation metrics
- [ ] Match paper performance (implies comparing with it and fixing or adding anything missing)
- [ ] Refactor code (cloesly follow HuggingFace model?) (also keep in mind what can be reused between different SOTAs)
- [ ] WandB/Tensorboard interface? (easy iteration and rersults)

In [3]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
import numpy.typing as npt
import os
import jsonlines
from typing import Any

QVDataPointTarget = dict[str, Any]
QVDataPoint = tuple[npt.NDArray[np.float32], npt.NDArray[np.float32], None | QVDataPointTarget]

class QVDataset(Dataset):
    def __init__(self, text_features_path: str, video_features_path: str, data_file_path: str):
        super().__init__()
        self.text_features_path = text_features_path
        self.video_features_path = video_features_path
        self.data = []
        with jsonlines.open(data_file_path) as reader:
            for el in reader:
                self.data.append(el)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index: int) -> QVDataPoint:
        item_info = self.data[index]
        qid = item_info["qid"]
        vid = item_info["vid"]
        query_features = np.load(os.path.join(self.text_features_path, str(qid) + ".npy"))
        video_features = np.load(os.path.join(self.video_features_path, vid + ".npy"))
        label = None
        if "relevant_clip_ids" in item_info:
            label = {
                "relevant_clip_ids": item_info["relevant_clip_ids"], 
                "saliency_scores": item_info["saliency_scores"], 
                "relevant_windows": item_info["relevant_windows"],
                "duration": item_info["duration"]
            }
        return (query_features, video_features, item_info["query"], label)
        

In [4]:
import torch

# define Data collator function
def pad_collate(samples: list[QVDataPoint]):
    # Need to: pad video and text features -> create attention video/text attention
    text_lens = [el[0].shape[0] for el in samples]
    video_lens = [el[1].shape[0] for el in samples]
    
    batch_size = len(samples)
    text_len = max(text_lens)
    text_hidden = samples[0][0].shape[-1]
    video_len = max(video_lens)
    video_hidden = samples[0][1].shape[-1]
    
    text_features = torch.zeros((batch_size, text_len, text_hidden))
    text_attn_mask = torch.ones((batch_size, text_len))
    video_features = torch.zeros((batch_size, video_len, video_hidden))
    video_attn_mask = torch.ones((batch_size, video_len))
    
    for idx, sample in enumerate(samples):
        sample_text_len = sample[0].shape[0]
        sample_video_len = sample[1].shape[0]
        text_features[idx, :sample_text_len, :] = torch.tensor(sample[0])
        video_features[idx, :sample_video_len, :] = torch.tensor(sample[1])
        text_attn_mask[idx, sample_text_len:] = 0
        video_attn_mask[idx, sample_video_len:] = 0
    
    labels = None
    # We have labels
    if samples[0][2] is not None:
        # build classes array
        labels = []
        for (_, _, sample) in samples:
            duration = sample["duration"]
            boxes = sample["relevant_windows"]
            class_labels = torch.zeros((len(boxes), ), dtype=torch.int64) 
            labels.append({
                "boxes": torch.tensor([[box[0] / duration, box[1] / duration] for box in boxes]),
                "class_labels": class_labels
            })
    
    return  {
        "text_features": text_features,
        "text_attn_mask": text_attn_mask,
        "video_features": video_features,
        "video_attn_mask": video_attn_mask,
        "labels": labels
    }
        

In [5]:
train_dataset = QVDataset("qvhighlights_features\\text_features", "qvhighlights_features\\video_features", "qvhighlights_features\\highlight_train_release.jsonl")
train_dataset[0]

(array([[ 0.33928594,  0.11646017,  0.10195109, ...,  0.24677397,
          0.5906364 ,  0.10129976],
        [ 0.5391787 ,  0.6050957 , -0.29105273, ...,  0.11356771,
          0.07524773, -1.3233696 ],
        [ 0.34825635, -0.5755961 , -0.18056145, ...,  1.5290366 ,
          0.17571366, -0.41787058],
        ...,
        [ 0.78474975,  0.6386728 , -0.09147779, ..., -0.27714124,
         -1.1072998 ,  0.3630575 ],
        [ 0.50689864, -0.25536713, -0.9592153 , ...,  0.5298033 ,
          1.6603259 , -0.7617836 ],
        [ 0.72716683, -0.6981834 ,  0.03235934, ...,  1.5555809 ,
          0.70241404, -0.5428549 ]], dtype=float32),
 array([[ 0.00542442,  1.2323831 ,  0.596713  , ..., -0.30323648,
          1.965467  , -0.3933972 ],
        [ 0.00542442,  1.2323831 ,  0.596713  , ..., -0.30323648,
          1.965467  , -0.3933972 ],
        [-0.09869489,  0.5870662 ,  0.19045241, ..., -1.0922167 ,
          1.7422361 , -0.04100407],
        ...,
        [-0.47089767, -0.5834327 , -1.2

In [30]:

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=pad_collate)

In [31]:
batch = next(iter(train_loader))
batch

{'text_features': tensor([[[ 0.3393,  0.1165,  0.1020,  ...,  0.2468,  0.5906,  0.1013],
          [ 1.9753, -0.5844,  0.3685,  ...,  1.1658,  0.8050, -0.9801],
          [ 1.0568, -0.3584, -0.1190,  ...,  0.4568, -0.9863,  0.2836],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[ 0.3393,  0.1165,  0.1020,  ...,  0.2468,  0.5906,  0.1013],
          [ 1.9753, -0.5844,  0.3685,  ...,  1.1658,  0.8050, -0.9801],
          [ 0.0405, -0.4019, -0.2012,  ...,  0.5851,  0.2978, -2.3808],
          ...,
          [-0.1568, -0.2916, -0.0831,  ..., -0.3611,  0.4270, -0.1416],
          [-0.0660, -0.6583, -0.6306,  ...,  0.3031, -1.0054,  0.6075],
          [ 1.0487, -0.2953, -0.0168,  ...,  0.9910, -0.9835, -0.6988]]]),
 'text_attn_mask': tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.,

In [10]:
text_attn_mask = batch["text_attn_mask"]
print(text_attn_mask)
torch.argmax((text_attn_mask < 0.5).to(torch.int8), dim=1)[0]

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1.]])


tensor(12)

In [39]:
print(batch["video_features"].shape)

torch.Size([2, 75, 768])


In [6]:
from transformers import DetrConfig, DetrForObjectDetection, DetrModel
from transformers.models.detr.modeling_detr import DetrEncoder, DetrDecoder, DetrSinePositionEmbedding
import torch.nn as nn
import torch
from typing import Optional
from scipy.optimize import linear_sum_assignment
from transformers.modeling_utils import PreTrainedModel


# taken from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
class MomentDetrHungarianMatcher(nn.Module):
    """
    This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
    predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).

    Args:
        class_cost:
            The relative weight of the classification error in the matching cost.
        bbox_cost:
            The relative weight of the L1 error of the bounding box coordinates in the matching cost.
        giou_cost:
            The relative weight of the giou loss of the bounding box in the matching cost.
    """

    def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
        super().__init__()

        self.class_cost = class_cost
        self.bbox_cost = bbox_cost
        self.giou_cost = giou_cost
        if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
            raise ValueError("All costs of the Matcher can't be 0")

    @torch.no_grad()
    def forward(self, outputs: torch.FloatTensor, targets: list[dict]):
        """
        Args:
            outputs (`dict`):
                A dictionary that contains at least these entries:
                * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
            targets (`List[dict]`):
                A list of targets (len(targets) = batch_size), where each target is a dict containing:
                * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
                  ground-truth
                 objects in the target) containing the class labels
                * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.

        Returns:
            `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
            - index_i is the indices of the selected predictions (in order)
            - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        batch_size, num_queries = outputs["logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 2]

        # Also concat the target labels and boxes
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be ommitted.
        class_cost = -out_prob[:, target_ids]

        # Compute the L1 cost between boxes
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

        # Compute the giou cost between boxes
        giou_cost = -generalized_moment_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))

        # Final cost matrix
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

def center_to_corners_format(box: torch.FloatTensor) -> torch.FloatTensor:
    center, width = box.unbind(-1)
    return torch.stack(
        [center - 0.5 * width, center + 0.5 * width],
        dim=-1
    )

def generalized_moment_iou(box1: torch.FloatTensor, box2: torch.FloatTensor) -> torch.FloatTensor:
    def moment_iou(box1: torch.FloatTensor, box2: torch.FloatTensor) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        # box1 [batch * queries, 2]
        # box2 [num_target, 2]
        length1 = (box1[:, 1] - box1[:, 0])
        length2 = (box2[:, 1] - box2[:, 0])
        left = torch.max(box1[:, None, 0], box2[:, 0])
        right = torch.min(box1[:, None, 1], box2[:, 1])
        intersection = (right - left).clamp(min = 0)
        union = length1[:, None] + length2 - intersection
        
        return intersection / union, union 
    
    iou, union = moment_iou(box1, box2)
    left = torch.min(box1[:, None, 0], box2[:, 0])
    right = torch.max(box1[:, None, 1], box2[:, 1])
    area = right - left
    return iou - (area - union) / area
        
class MomentDetrLoss(nn.Module):
    def __init__(self, matcher, num_classes, eos_coef, losses):
        super().__init__()
        self.matcher = matcher
        self.num_classes = num_classes
        self.eos_coef = eos_coef
        self.losses = losses
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)
    
    def loss_labels(self, outputs, targets, indices, num_boxes):
        """
        Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
        [nb_target_boxes]
        """
        if "logits" not in outputs:
            raise KeyError("No logits were found in the outputs")
        source_logits = outputs["logits"]

        idx = self._get_source_permutation_idx(indices)
        target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
        )
        target_classes[idx] = target_classes_o

        loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {"loss_ce": loss_ce}

        return losses

    
    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """
        Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.

        Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
        are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        if "pred_boxes" not in outputs:
            raise KeyError("No predicted boxes found in outputs")
        idx = self._get_source_permutation_idx(indices)
        source_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")

        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(
            generalized_moment_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def _get_source_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
        source_idx = torch.cat([source for (source, _) in indices])
        return batch_idx, source_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes):
        loss_map = {
            "labels": self.loss_labels,
            # "cardinality": self.loss_cardinality,
            "boxes": self.loss_boxes,
            # "masks": self.loss_masks,
        }
        if loss not in loss_map:
            raise ValueError(f"Loss {loss} not supported")
        return loss_map[loss](outputs, targets, indices, num_boxes)

    def forward(self, outputs, targets):
        """
        This performs the loss computation.

        Args:
             outputs (`dict`, *optional*):
                Dictionary of tensors, see the output specification of the model for the format.
             targets (`List[dict]`, *optional*):
                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
                losses applied, see each loss' doc.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["class_labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        world_size = 1
        num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
        
        return losses

def positional_encodings(input_embs: torch.FloatTensor, n = 10_000):
    batch_size, seq_len, d = input_embs.shape
    position = torch.arange(0, seq_len).unsqueeze_(1)
    denominator = torch.pow(n, 2 * torch.arange(0, d//2) / d)
    
    pos_enc = position / denominator
    encodings = torch.zeros((seq_len, d))
    encodings[:, 0::2] = pos_enc.sin()
    encodings[:, 1::2] = pos_enc.cos()
    return encodings.unsqueeze(0).repeat(batch_size, 1, 1)

class ProjectionMLP(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, n_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList()
        h = [hidden_dim] * (n_layers - 1)
        for in_dim, out_dim in zip([input_dim] + h, h + [output_dim]):
            self.layers.append(nn.LayerNorm(normalized_shape=in_dim))
            self.layers.append(nn.Dropout())
            self.layers.append(nn.Linear(in_dim, out_dim))
            self.layers.append(nn.ReLU())

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for layer in self.layers:
            x = layer(x)
        return x

class MomentDetr(PreTrainedModel):
    def __init__(self, config: DetrConfig):
        super().__init__(config)
        
        self.config = config
        
        self.encoder = DetrEncoder(config)
        self.decoder = DetrDecoder(config)
        self.positions = DetrSinePositionEmbedding(normalize=True)
        
        self.text_projection = ProjectionMLP(512, config.hidden_size, config.hidden_size)
        self.video_projection = ProjectionMLP(768, config.hidden_size, config.hidden_size)
        
        self.object_queries = nn.Embedding(config.num_queries, config.d_model)
        
        # Define saliency score predictor head
        self.sal_predictor = nn.Linear(config.hidden_size, 1)
        
        # Define window bounds predictor head
        self.moment_predictor = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, 2),
            nn.ReLU()
        )
        
        # Define window class predictor head
        self.classifier = nn.Linear(config.hidden_size, 2)
        self.post_init()
        
        
    def forward(
        self,
        video_features: torch.FloatTensor,
        video_attn_mask: Optional[torch.FloatTensor],
        text_features: torch.FloatTensor,
        text_attn_mask: Optional[torch.FloatTensor],
        labels: Optional[list[dict]] = None
    ) -> torch.FloatTensor:
        batch_size, video_seq_len, _ = video_features.shape
        _, text_seq_len, _ = text_features.shape
        
        if video_attn_mask is None:
            video_attn_mask = torch.ones((batch_size, video_seq_len), device = video_features.device)
        if text_attn_mask is None:
            text_attn_mask = torch.ones((batch_size, text_seq_len), device = text_features.device)
        
        text_projected = self.text_projection(text_features)
        video_projected = self.video_projection(video_features)
        
        concatenated_features = torch.cat([video_projected, text_projected], dim=1)
        attn_mask = torch.cat([video_attn_mask, text_attn_mask], dim=1)
            
        # Compute positional encodings
        positions = positional_encodings(concatenated_features).to(self.device)
        
        # Pass through the encoder using positions and concatenated_features
        encoder_output = self.encoder(
            inputs_embeds=concatenated_features,
            attention_mask=attn_mask,
            object_queries=positions,
        )
        
        # Pass through the decoder using positions, object_queries, encoder_output
        object_queries = self.object_queries.weight.unsqueeze(0).repeat(batch_size, 1, 1)
        decoder_inputs = torch.zeros_like(object_queries)
        
        decoder_output = self.decoder(
            inputs_embeds=decoder_inputs,
            attention_mask=None,
            encoder_hidden_states=encoder_output.last_hidden_state,
            encoder_attention_mask=attn_mask,
            object_queries=positions,
            query_position_embeddings=object_queries
        )
        
        pred_moments = self.moment_predictor(decoder_output.last_hidden_state)
        logits = self.classifier(decoder_output.last_hidden_state)
        
        loss = None
        if labels is not None:
            outputs_loss = {
                "logits": logits,
                "pred_boxes": pred_moments
            }
            matcher = MomentDetrHungarianMatcher(self.config.class_cost, self.config.bbox_cost, self.config.giou_cost)
            # match_ = matcher(outputs, labels)
            criterion = MomentDetrLoss(matcher, self.config.num_labels, self.config.eos_coefficient, ["labels", "boxes"])
            criterion.to(self.device)
            
            loss_dict = criterion(outputs_loss, labels)
            weight_dict = {
                "loss_ce": 4,  # TODO: Pass this in config
                "loss_bbox": self.config.bbox_loss_coefficient,
                "loss_giou": self.config.giou_loss_coefficient
            }
            loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        return pred_moments, logits, loss
            

In [7]:
def put_batch_on_device(batch, device):
    on_device_batch = {}
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            on_device_batch.update({k: v.to(device)})
        else:
            on_device_l = []
            for el in v:
                on_device_dict = {}
                for el_k, el_v in el.items():
                    on_device_dict.update({el_k: el_v.to(device)})
                on_device_l.append(on_device_dict)
            on_device_batch.update({k: on_device_l})
    return on_device_batch

In [36]:
config = DetrConfig(
    d_model=256,
    encoder_layers=2,
    decoder_layers=2,
    num_queries=10,
    # TODO: figure out dropout details
    dropout=0.1,
    activation_dropout=0.1,
    giou_loss_coefficient=1,
    bbox_loss_coefficient=10,
    num_labels=1,
    device="cuda"
)
momentDETR_model = MomentDetr(config).to("cuda")
batch = put_batch_on_device(next(iter(train_loader)), "cuda")

momentDETR_model(**batch)

(tensor([[[0.0000e+00, 1.0283e-01],
          [0.0000e+00, 1.2472e-01],
          [0.0000e+00, 1.0637e-01],
          ...,
          [0.0000e+00, 1.4899e-01],
          [0.0000e+00, 1.3989e-01],
          [0.0000e+00, 1.0796e-01]],
 
         [[0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00]],
 
         [[0.0000e+00, 1.2454e-01],
          [0.0000e+00, 9.5373e-02],
          [0.0000e+00, 8.2522e-02],
          ...,
          [0.0000e+00, 6.0441e-02],
          [0.0000e+00, 1.0840e-01],
          [0.0000e+00, 1.0471e-01]],
 
         ...,
 
         [[0.0000e+00, 0.0000e+00],
          [0.0000e+00, 4.0778e-03],
          [0.0000e+00, 2.6076e-02],
          ...,
          [0.0000e+00, 5.1384e-02],
          [0.0000e+00, 1.0886e-02],
          [0.0000e+00, 7.7266e-02]],
 
         [[0.0000e+00, 1.5711e-02],
          [0

In [13]:
train_dataset =  QVDataset("qvhighlights_features\\text_features", "qvhighlights_features\\video_features", "qvhighlights_features\\highlight_train_release.jsonl")
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=pad_collate)
config = DetrConfig(
    d_model=256,
    encoder_layers=2,
    decoder_layers=2,
    num_queries=10,
    # TODO: figure out dropout details
    dropout=0.1,
    activation_dropout=0.1,
    bbox_cost=10,
    giou_cost=1,
    class_cost=4,
    giou_loss_coefficient=1,
    bbox_loss_coefficient=10,
    num_labels=1
)

momentDETR_model = MomentDetr(config).to("cuda")
epochs = 100
optimizer = torch.optim.AdamW(momentDETR_model.parameters(), lr=2e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 400)
print_every = 100
for epoch in range(epochs):
    momentDETR_model.train()
    batch_losses = []
    for idx, data in enumerate(train_loader):
        data = put_batch_on_device(data, "cuda")
        optimizer.zero_grad()
        _, _, loss = momentDETR_model(**data)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.detach().cpu().item())
        if idx % print_every == 0:
            print(f"Epoch #{epoch} - loss {batch_losses[-1]:.2f}")
    print(f"Epoch #{epoch} - avg loss {np.mean(batch_losses):.2f}")
    scheduler.step()
        

Epoch #0 - loss 0.50
Epoch #0 - loss 0.74
Epoch #0 - avg loss 0.61
Epoch #1 - loss 0.77
Epoch #1 - loss 0.55
Epoch #1 - avg loss 0.63
Epoch #2 - loss 0.52
Epoch #2 - loss 0.61
Epoch #2 - avg loss 0.59
Epoch #3 - loss 0.53
Epoch #3 - loss 0.55
Epoch #3 - avg loss 0.59
Epoch #4 - loss 0.62
Epoch #4 - loss 0.61
Epoch #4 - avg loss 0.60
Epoch #5 - loss 0.61
Epoch #5 - loss 0.61
Epoch #5 - avg loss 0.59
Epoch #6 - loss 0.62
Epoch #6 - loss 0.57
Epoch #6 - avg loss 0.57
Epoch #7 - loss 0.47
Epoch #7 - loss 0.60
Epoch #7 - avg loss 0.60
Epoch #8 - loss 0.49
Epoch #8 - loss 0.66
Epoch #8 - avg loss 0.60
Epoch #9 - loss 0.58
Epoch #9 - loss 0.56
Epoch #9 - avg loss 0.59
Epoch #10 - loss 0.65
Epoch #10 - loss 0.62
Epoch #10 - avg loss 0.58
Epoch #11 - loss 0.58
Epoch #11 - loss 0.63
Epoch #11 - avg loss 0.57
Epoch #12 - loss 0.59
Epoch #12 - loss 0.61
Epoch #12 - avg loss 0.57
Epoch #13 - loss 0.57
Epoch #13 - loss 0.54
Epoch #13 - avg loss 0.57
Epoch #14 - loss 0.65
Epoch #14 - loss 0.51
Epoch 

In [28]:
momentDETR_model.eval()

val_dataset = QVDataset("qvhighlights_features/text_features",
                        "qvhighlights_features/video_features",
                        "qvhighlights_features/highlight_val_release.jsonl")
val_loader = DataLoader(train_dataset, collate_fn=pad_collate, shuffle=True)
batch = put_batch_on_device(next(iter(val_loader)), "cuda")
a, b, c = momentDETR_model(**batch)

In [27]:
momentDETR_model.save_pretrained("moment-detr-1")

In [29]:
batch["labels"][0]["boxes"] * 150

tensor([[64., 78.]], device='cuda:0')

In [30]:
probs = b.softmax(-1)
indices = torch.where(probs[:, :, 0] > 0.5)
a[indices] * 150

tensor([[63.8942, 81.5064]], device='cuda:0', grad_fn=<MulBackward0>)

In [39]:
batch

{'text_features': tensor([[[ 0.3393,  0.1165,  0.1020,  ...,  0.2468,  0.5906,  0.1013],
          [ 1.9753, -0.5844,  0.3685,  ...,  1.1658,  0.8050, -0.9801],
          [ 1.1319, -0.2811, -0.4251,  ...,  1.1118, -0.7140, -1.5363],
          ...,
          [ 1.6926, -1.6952, -1.8157,  ...,  0.5784, -0.1449,  1.3623],
          [ 1.3815, -2.8677, -2.3773,  ...,  0.3917,  0.0104,  1.5142],
          [ 1.5898, -1.9487, -1.7853,  ...,  0.5196,  0.5867,  1.0427]]],
        device='cuda:0'),
 'text_attn_mask': tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1.]], device='cuda:0'),
 'video_features': tensor([[[ 0.6747,  0.5734,  0.3003,  ..., -0.6585,  0.0275,  0.6033],
          [ 0.6113,  0.5741,  0.1799,  ..., -0.4763, -0.0072,  0.5828],
          [ 0.6008,  0.4625,  0.3425,  ...,  0.2027,  0.2525,  0.9696],
          ...,
          [ 1.0498,  0.3782,  0.3071,  ..., -0.5147,  0.5142,  1.1265],
          [ 0.9453,  0.3807, -0.1985,  ..., -0.4939, 

In [31]:
probs

tensor([[[1.9702e-06, 1.0000e+00],
         [3.2925e-06, 1.0000e+00],
         [1.0000e+00, 1.7850e-06],
         [4.4461e-05, 9.9996e-01],
         [3.1897e-06, 1.0000e+00],
         [1.3979e-04, 9.9986e-01],
         [6.6539e-06, 9.9999e-01],
         [5.5409e-06, 9.9999e-01],
         [4.6470e-06, 1.0000e+00],
         [5.1168e-06, 9.9999e-01]]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)

In [32]:
momentDETR_model.parameters()

<generator object Module.parameters at 0x0000021B76786CE0>

In [38]:
sum(param.numel() for param in momentDETR_model.parameters())

6386949