In [1]:
import gc
import os
import time
from datetime import datetime
import argparse

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torchvision.transforms import Compose
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
import torch.nn.functional as F

from src.opts.opts import parser
from src.utils.reproducibility import make_reproducible
from src.models.model import VideoModel
from src.dataset.video_dataset import VideoDatasetTest, prepare_clips_data_test
from src.dataset.video_transforms import GroupMultiScaleCrop, Stack, ToTorchFormatTensor, GroupNormalize
from src.utils.meters import AverageMeter
from src.utils.metrics import calc_accuracy

In [2]:
from collections import defaultdict
import json
from typing import Dict, Any, Callable, Literal
import numpy as np
from torch.utils.data import Dataset

from src.dataset.utils import make_video_path
from src.dataset.video_dataset import prepare_clips_data, prepare_clips_data_test
from src.dataset.video_dataset import VideoDatasetTest

%load_ext autoreload
%autoreload 2

In [3]:
# Shell escape: list the contents of the local data directory (machine-specific absolute path).
!ls /Users/artemmerinov/data/

[1m[36mDiving[m[m                             data-annotation-trainval-v1_1.json
[1m[36mMNIST[m[m                              [1m[36mholoassist[m[m
[1m[36mSAINT[m[m                              [1m[36moberalp[m[m
[1m[36mbackbones[m[m                          [1m[36mucf101[m[m
[1m[36mceleba_hq_256[m[m                      video_pitch_shifted.tar


In [4]:
# Load the test clip index and make sure `key_list` is available afterwards.
#
# NOTE(review): this cell used to unpack four values
# (key_list, video_name_arr, start_arr, end_arr), while the
# "PREPARE CLIPS DATA" section further down unpacks three from the very same
# call. With %autoreload active the helper evidently changed between runs, so
# one of the two unpackings fails on a fresh kernel. Star-unpacking keeps this
# cell working with either return shape.
clips_data = prepare_clips_data_test(
    holoassist_dir="/Users/artemmerinov/data/holoassist/HoloAssist",
    test_action_clips_file="/Users/artemmerinov/data//holoassist/test_action_clips.txt"
)
*leading, video_name_arr, start_arr, end_arr = clips_data

if leading:
    # 4-tuple form: the first element is the ready-made key list.
    key_list = leading[0]
else:
    # 3-tuple form: rebuild the "<video>_<start>_<end>" keys from the array
    # entries, which the cells below decode with .decode().
    key_list = [
        f"{name.decode()}_{start.decode()}_{end.decode()}"
        for name, start, end in zip(video_name_arr, start_arr, end_arr)
    ]

Number of successful fine-grained clips: 52 
Number of failed fine-grained clips: 40497 
Number of successful videos: 1 
Number of failed videos: 462 
Failed video names are: {'z089-july-07-22-espresso', 'z064-june-29-22-marius_assemble', 'z124-aug-10-22-switch', 'R065-15July-Belt', 'z009-june-15-22-gladom_assemble', 'z145-aug-18-22-marius_disassemble', 'z076-july-01-22-printer_small', 'R162-3Oct-RAM', 'z027-june-22-22-switch', 'z141-aug-16-22-dslr', 'R145-02Sep-RAM', 'z183-sep-08-22-gopro', 'z088-july-07-22-printer_big', 'z203-sep-24-22-nespresso', 'z168-sep-01-22-gopro', 'z078-july-01-22-gladom_assemble', 'z053-june-27-22-marius_disassemble', 'z131-aug-11-22-knarrevik_assemble', 'z197-sep-18-22-knarrevik_disassemble', 'z045-june-24-22-marius_disassemble', 'R051-13July-SmallPrinter', 'z173-sep-04-22-printer_big', 'z200-sep-22-22-knarrevik_disassemble', 'z151-aug-25-22-gladom_disassemble', 'z160-aug-29-22-dslr', 'z055-june-27-22-rashult_disassemble', 'z083-july-06-22-gladom_disassemble

In [8]:
# Display the clip keys; the output below shows entries of the form "<video_name>_<start>_<end>".
key_list

['z209-sep-28-22-gladom_disassemble_0.963_2.698',
 'z209-sep-28-22-gladom_disassemble_2.705_2.969',
 'z209-sep-28-22-gladom_disassemble_2.979_3.607',
 'z209-sep-28-22-gladom_disassemble_3.614_6.296',
 'z209-sep-28-22-gladom_disassemble_6.449_9.233',
 'z209-sep-28-22-gladom_disassemble_9.247_11.127',
 'z209-sep-28-22-gladom_disassemble_11.137_12.354',
 'z209-sep-28-22-gladom_disassemble_12.369_13.447',
 'z209-sep-28-22-gladom_disassemble_13.466_15.271',
 'z209-sep-28-22-gladom_disassemble_15.298_16.855',
 'z209-sep-28-22-gladom_disassemble_16.858_17.221',
 'z209-sep-28-22-gladom_disassemble_17.250_20.887',
 'z209-sep-28-22-gladom_disassemble_20.904_21.206',
 'z209-sep-28-22-gladom_disassemble_21.218_22.577',
 'z209-sep-28-22-gladom_disassemble_22.588_24.061',
 'z209-sep-28-22-gladom_disassemble_24.072_26.434',
 'z209-sep-28-22-gladom_disassemble_26.443_28.365',
 'z209-sep-28-22-gladom_disassemble_28.375_29.777',
 'z209-sep-28-22-gladom_disassemble_29.795_30.883',
 'z209-sep-28-22-gladom

In [6]:
# Decode entry 12 of each clip array and show the three values as one tuple.
tuple(
    column[12].decode()
    for column in (video_name_arr, start_arr, end_arr)
)

('z209-sep-28-22-gladom_disassemble', '20.904', '21.206')

In [7]:
# Join the decoded video name, start and end of entry 12 with underscores,
# which reproduces the format of the clip keys shown earlier.
"_".join(
    field.decode()
    for field in (video_name_arr[12], start_arr[12], end_arr[12])
)

'z209-sep-28-22-gladom_disassemble_20.904_21.206'

In [7]:
# Test-run configuration, declared as a table of (flag, type, default, help)
# rows and registered in order. Note that this rebinds the name `parser`
# that the first cell imports from src.opts.opts.
_ARGUMENT_TABLE = [
    ("--holoassist_dir", str, "/Users/artemmerinov/data/holoassist/HoloAssist", None),
    ("--test_action_clips_file", str, "/Users/artemmerinov/data/holoassist/test_action_clips.txt", None),
    ("--fga_map_file", str, "/Users/artemmerinov/data/holoassist/fine_grained_actions_map.txt", None),
    ("--base_model", str, "InceptionV3", None),
    ("--fusion_mode", str, "GSF", None),
    ("--num_segments", int, 8, None),
    ("--batch_size", int, 32, None),
    ("--num_workers", int, 12, None),
    ("--prefetch_factor", int, 4, None),
    ("--repetitions", int, 3, "Number of spatial and temporal sampling to achieve better precision in evaluation."),
    ("--num_classes", int, 1887, None),
    ("--checkpoint", str, "/Users/artemmerinov/PycharmProjects/holoassist-challenge/checkpoints/holoassist_InceptionV3_GSF_action_10.pth", "Best model weigths."),
]

parser = argparse.ArgumentParser()
for _flag, _type, _default, _help in _ARGUMENT_TABLE:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

# Parsing an empty argv inside the notebook makes every option take its default.
args = parser.parse_args([])
print(args)

Namespace(holoassist_dir='/Users/artemmerinov/data/holoassist/HoloAssist', test_action_clips_file='/Users/artemmerinov/data/holoassist/test_action_clips.txt', fga_map_file='/Users/artemmerinov/data/holoassist/fine_grained_actions_map.txt', base_model='InceptionV3', fusion_mode='GSF', num_segments=8, batch_size=32, num_workers=12, prefetch_factor=4, repetitions=3, num_classes=1887, checkpoint='/Users/artemmerinov/PycharmProjects/holoassist-challenge/checkpoints/holoassist_InceptionV3_GSF_action_10.pth')


In [8]:
# Build the model, restore the checkpoint weights and load the test clip index.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VideoModel(
    num_classes=args.num_classes,
    num_segments=args.num_segments,
    base_model=args.base_model,
    fusion_mode=args.fusion_mode,
    verbose=False,
).to(device)

# Preprocessing settings are read from the bare model, before it is wrapped
# in DataParallel below.
input_size = model.input_size
crop_size = model.crop_size
input_mean = model.input_mean
input_std = model.input_std
div = model.div

# Parallel!
model = torch.nn.DataParallel(model).to(device)

#  ========================= LOAD MODEL STATE =========================
#

checkpoint = torch.load(f=args.checkpoint, map_location=device)

# strict=False accepts mismatching keys without raising, so a checkpoint whose
# keys do not line up with the (DataParallel-wrapped) model would otherwise go
# unnoticed and leave weights at their initial values. Report any mismatch.
load_result = model.load_state_dict(state_dict=checkpoint["model_state_dict"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: checkpoint does not fully match the model:",
        f"{len(load_result.missing_keys)} missing keys,",
        f"{len(load_result.unexpected_keys)} unexpected keys",
        flush=True,
    )

# This notebook only runs inference, so put the model into evaluation mode.
model.eval()

#  ========================= PREPARE CLIPS DATA =========================
#
# Use the configured paths instead of repeating them as literals, so that
# changing `args` actually changes what gets loaded.
video_name_arr, start_arr, end_arr = prepare_clips_data_test(
    holoassist_dir=args.holoassist_dir,
    test_action_clips_file=args.test_action_clips_file,
)

=> Using GSF fusion
No. of GSF modules = 11
Number of successful fine-grained clips: 52 
Number of failed fine-grained clips: 40497 
Number of successful videos: 1 
Number of failed videos: 462 
Failed video names are: {'z028-june-22-22-rashult_disassemble', 'R125-16Aug-ATV', 'z165-aug-31-22-rashult_disassemble', 'z045-june-24-22-printer_small', 'R070-19July-GoPro', 'R080-21July-Switch', 'z201-sep-22-22-printer_small', 'z182-sep-08-22-rashult_disassemble', 'z028-june-22-22-gladom_disassemble', 'R088-29July-Coffee', 'z202-sep-23-22-rashult_assemble', 'z032-june-22-22-gladom_assemble', 'z038-june-23-22-gopro', 'z194-sep-16-22-gopro', 'z124-aug-10-22-printer_big', 'z023-june-21-22-knarrevik_assemble', 'z081-july-06-22-knarrevik_disassemble', 'z090-july-08-22-gladom_disassemble', 'z190-sep-10-22-printer_big', 'z096-july-12-22-dslr', 'z034-june-23-22-dslr', 'z054-june-27-22-switch', 'z116-aug-05-22-marius_assemble', 'R207-11Nov-ATV-part2', 'z023-june-21-22-rashult_assemble', 'z153-aug-25-22

In [9]:
# Buffer for one probability matrix per repetition, indexed as
# (repetition, clip, class).
# NOTE: nothing in this cell writes into it, so its contents are uninitialized.
probs = torch.empty((args.repetitions, len(video_name_arr), args.num_classes), dtype=torch.float32)

# `prefetch_factor` is only valid when worker processes are used; DataLoader
# raises if it is specified together with num_workers == 0. Pass it only when
# it applies so the cell also runs single-process.
loader_worker_kwargs = {}
if args.num_workers > 0:
    loader_worker_kwargs["prefetch_factor"] = args.prefetch_factor

for repeat_id in range(args.repetitions):
    print(
        f"\nrepeat_id={repeat_id}",
        f"time={datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        flush=True
    )

    #  ========================= DATALOADER =========================
    #

    # Make dataloader for each repeat.
    # NOTE: the loader is only constructed here; this cell does not iterate it.

    transform = Compose([
        GroupMultiScaleCrop(input_size=input_size, scales=[1, .875]),
        Stack(),
        ToTorchFormatTensor(div=div),
        GroupNormalize(mean=input_mean, std=input_std),
    ])
    dataset = VideoDatasetTest(
        holoassist_dir=args.holoassist_dir,
        video_name_arr=video_name_arr,
        start_arr=start_arr,
        end_arr=end_arr,
        num_segments=args.num_segments,
        transform=transform,
    )
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        drop_last=False,
        # Pinned host memory only helps host-to-GPU copies; skip it without CUDA.
        pin_memory=torch.cuda.is_available(),
        **loader_worker_kwargs,
    )


repeat_id=0 time=2024-05-21 15:30:40

repeat_id=1 time=2024-05-21 15:30:40

repeat_id=2 time=2024-05-21 15:30:40


In [10]:
# Inspect the first sample of the dataset left over from the last loop iteration;
# the output below starts with the clip key followed by a tensor.
dataset[0]

('z209-sep-28-22-gladom_disassemble_0.963_2.698',
 tensor([[[ 0.0667,  0.0745,  0.0667,  ..., -0.5294, -0.5294, -0.5294],
          [ 0.0667,  0.0745,  0.0667,  ..., -0.5294, -0.5294, -0.5294],
          [ 0.0667,  0.0745,  0.0667,  ..., -0.5294, -0.5294, -0.5294],
          ...,
          [-0.3804, -0.3412, -0.3490,  ..., -0.4745, -0.4667, -0.4588],
          [-0.4431, -0.3882, -0.3569,  ..., -0.4824, -0.4667, -0.4588],
          [-0.2627, -0.3333, -0.3255,  ..., -0.4902, -0.4667, -0.4588]],
 
         [[-0.0196, -0.0118, -0.0196,  ..., -0.5608, -0.5608, -0.5608],
          [-0.0196, -0.0118, -0.0196,  ..., -0.5608, -0.5608, -0.5608],
          [-0.0196, -0.0118, -0.0196,  ..., -0.5608, -0.5608, -0.5608],
          ...,
          [-0.4431, -0.4275, -0.4431,  ..., -0.5216, -0.5137, -0.5059],
          [-0.5059, -0.4588, -0.4353,  ..., -0.5294, -0.5137, -0.5059],
          [-0.3255, -0.3961, -0.3961,  ..., -0.5373, -0.5137, -0.5059]],
 
         [[-0.0588, -0.0588, -0.0745,  ..., -0.623

In [18]:
# Toy stand-ins for trying out top-k extraction: seven string keys and a
# random 7 x 1222 score matrix.
keys = ["one", "two", *(str(number) for number in range(3, 8))]
logits = torch.rand(7, 1222)

In [22]:
# Indices of the five largest scores in each row. Reading `.indices` from the
# named tuple that torch.topk returns avoids assigning to `_`, which IPython
# uses for the previous cell output and stops updating once user code binds it.
ids = torch.topk(logits, k=5, dim=1).indices

In [23]:
# Show the shape of the top-k index tensor (one row per key, five indices each).
ids.shape

torch.Size([7, 5])

In [24]:
# Walk the (key, top-5 index row) pairs. The body is a placeholder, so the only
# effect is that `k` and `v` remain bound to the last pair once the loop ends.
for k,v in zip(keys, ids):
    pass

In [27]:
# Convert `v` (the last index row left bound by the preceding loop) to a plain Python list.
v.tolist()

[84, 1118, 61, 43, 1006]