In [1]:
import os
import json
from dataclasses import dataclass
from typing import Dict, List, Any, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset
from sklearn.utils.class_weight import compute_class_weight

import av  # pip install av
from transformers import (
    AutoImageProcessor,
    VideoMAEForVideoClassification,
    TrainingArguments,
    Trainer,
)

import evaluate  # pip install evaluate

# load environment variables with dotenv
from dotenv import load_dotenv
load_dotenv()


  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
from train import *
import os


In [3]:
# Point this at the Olympic Boxing dataset directory

# Pretrained VideoMAE base (self-supervised on K400)
model_name = "MCG-NJU/videomae-base"

image_processor = AutoImageProcessor.from_pretrained(model_name)
model = VideoMAEForVideoClassification.from_pretrained(
    model_name,
    num_labels=len(LABEL2ID),
    label2id=LABEL2ID,
    id2label=ID2LABEL,
    torch_dtype=torch.bfloat16
)


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
`torch_dtype` is deprecated! Use `dtype` instead!
Some weights of VideoMAEForVideoClassification were not initialized from the model checkpoint at MCG-NJU/videomae-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# check for cuda
if torch.cuda.is_available():
    model.to("cuda")
    print("Using CUDA")
else:
    print("Using CPU")

Using CUDA


In [5]:
from sklearn.model_selection import train_test_split
import re
from collections import Counter
import random

def split_data():
    train_paths = []
    val_paths = []
    test_paths = []

    # Collect all paths with their labels and sources
    all_paths = []
    all_labels = []
    all_sources = []

    for label in os.listdir("preprocessed_clips_3"):
        paths = [f"preprocessed_clips_3/{label}/{p}" for p in os.listdir(f"preprocessed_clips_3/{label}")]
        
        for path in paths:
            # Extract source from filename: pattern is clip_task_[kamx_nums]_index_c.npy
            # Source is the part matching task_[kamx_nums]_index_c (everything from task_ to .npy)
            filename = os.path.basename(path)
            # Match task_ followed by any characters until .npy
            source_match = re.search(r'task_kam\d+_[^_]+', filename)
            if source_match:
                source = source_match.group(0)
            else:
                # Fallback: use filename without extension as source
                source = os.path.splitext(filename)[0]
            
            all_paths.append(path)
            all_labels.append(label)
            all_sources.append(source)

    # Create combined stratification key: label_source
    # This ensures both label and source distributions are maintained
    stratify_key = [f"{label}_{source}" for label, source in zip(all_labels, all_sources)]


    # First split: 80% train, 20% temp (which will become val+test)
    train_paths, temp_paths, train_labels, temp_labels, train_sources, temp_sources = train_test_split(
        all_paths, all_labels, all_sources,
        test_size=0.2,
        stratify=stratify_key,
        random_state=632
    )

    # Second split: split temp into 50% val, 50% test (which gives 10% val, 10% test overall)
    # Create new stratification key for temp split
    temp_stratify_key = [f"{label}_{source}" for label, source in zip(temp_labels, temp_sources)]

    temp_counts = Counter(temp_stratify_key)
    min_count = min(temp_counts.values())

    if min_count >= 2:
        val_paths, test_paths, val_labels, test_labels, val_sources, test_sources = train_test_split(
            temp_paths, temp_labels, temp_sources,
            test_size=0.5,
            stratify=temp_stratify_key,
            random_state=632
        )
    else:
        # Fall back to non-stratified split if some classes have < 2 members
        val_paths, test_paths, val_labels, test_labels, val_sources, test_sources = train_test_split(
            temp_paths, temp_labels, temp_sources,
            test_size=0.5,
            stratify=None,
            random_state=632
        )

    # Convert to class variables
    train_paths = train_paths
    val_paths = val_paths
    test_paths = test_paths
    
    return train_paths, val_paths, test_paths


In [6]:
from train import *
import random
class BoxingDataset(Dataset):
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std  = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    
    all_splits = split_data()
    train_paths = all_splits[0]
    val_paths = all_splits[1]
    test_paths = all_splits[2]

        
        
    def __init__(self, split: str):
        self.split = split
        
        
    def __len__(self):
        if self.split == "train":
            return len(self.train_paths)
        elif self.split == "val":
            return len(self.val_paths)
        elif self.split == "test":
            return len(self.test_paths)
        else:
            raise ValueError(f"Unknown split: {self.split}")

    def __getitem__(self, idx):
        if self.split == "train":
            path = self.train_paths[idx]
        elif self.split == "val":
            path = self.val_paths[idx]
        elif self.split == "test":
            path = self.test_paths[idx]
        else:
            raise ValueError(f"Unknown split: {self.split}")
        
        clip = np.load(path)
        
        # convert to float and scale to 0-1
        clip = clip.astype(np.float32) / 255.0
        
        # image net mean/std
        clip = (clip - self.mean) / self.std
        
        #reorder to (T,C,H,W)
        clip = clip.transpose(0,3,1,2)
        
        #convert to tensor
        clip = torch.from_numpy(clip)
        
        return {
            "pixel_values": clip,
            "labels": torch.tensor(LABEL2ID[path.split("/")[-2]], dtype=torch.long) 
        }



In [7]:
train_dataset = BoxingDataset(
    split="train",
)
val_dataset = BoxingDataset(
    split="val",
)
test_dataset = BoxingDataset(
    split="test",
)

In [8]:
len(train_dataset)

19640

In [9]:

# FACTS used batch_size=4, grad_accum=2, warmup_ratio=0.1, epochs=10
# Learning rate is not rendered in the HTML; start with 1e-4 and tune around it.
training_args = TrainingArguments(
    output_dir="./facts-boxing-videomae",
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    logging_steps=100,
    save_total_limit=2,
    num_train_epochs=15,
    per_device_train_batch_size=24, 
    per_device_eval_batch_size=24, 
    gradient_accumulation_steps=1,  
    warmup_ratio=0.1,
    learning_rate=5e-5,
    weight_decay=0.05,
    bf16=True,
    fp16=False,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    report_to="wandb",  # or "wandb"/"tensorboard"
    dataloader_num_workers=4,        # ADD THIS - use multiple workers
    dataloader_pin_memory=True,      # ADD THIS - faster CPU->GPU transfer
    dataloader_prefetch_factor=2, 
    torch_compile=True,
    optim="adamw_torch_fused",


)

data_collator = VideoDataCollator()

train_labels = [LABEL2ID[path.split("/")[-2]] for path in BoxingDataset.train_paths]

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.arange(len(LABEL2ID)),
    y=np.array(train_labels)  # Ensure it's a numpy array
)
class_weights = torch.tensor(class_weights, dtype=torch.float32)


In [10]:
sample_weights = [class_weights[label] for label in train_labels]
sampler = torch.utils.data.WeightedRandomSampler(
    sample_weights,                                              
    len(sample_weights), 
    replacement=True
)
sampler


<torch.utils.data.sampler.WeightedRandomSampler at 0x7043c47df9a0>

In [11]:
import random
import matplotlib.pyplot as plt
from collections import defaultdict

# sample_counts = defaultdict(int)
# for i in range(500):
#     sampled_idx = random.choices(train_labels, weights=sample_weights, k=1)[0]
#     sample_counts[sampled_idx] += 1

# # use matplotlib to plot the distribution

# plt.bar(sample_counts.keys(), sample_counts.values())
# plt.show()

In [12]:
class BalancedSamplerTrainer(Trainer):
    def __init__(self, *args, train_sampler=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.train_sampler = train_sampler
        
    def _get_train_sampler(self, train_dataset: Dataset | None = None):
        if train_dataset is None:
            train_dataset = self.train_dataset
            
        if train_dataset is None or not has_length(train_dataset):
            return None
        
        if self.train_sampler is not None:
            return self.train_sampler
        
        return super()._get_train_sampler(train_dataset)

trainer = BalancedSamplerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_sampler=sampler,
)


In [13]:
os.environ["WANDB_NAME"] = "20% Jitter Preprocessing - 15 Epochs - 5e-5 - Balanced Sampler - Stratified Split"

# Train
trainer.train()


[34m[1mwandb[0m: Currently logged in as: [33mnkosik11[0m ([33mnkosik11-hobby[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Accuracy,Macro F1,F1 Lhhp,Precision Lhhp,Recall Lhhp,F1 Rhhp,Precision Rhhp,Recall Rhhp,F1 Lhmp,Precision Lhmp,Recall Lhmp,F1 Rhmp,Precision Rhmp,Recall Rhmp,F1 Lhblp,Precision Lhblp,Recall Lhblp,F1 Rhblp,Precision Rhblp,Recall Rhblp,F1 Lhbp,Precision Lhbp,Recall Lhbp,F1 Rhbp,Precision Rhbp,Recall Rhbp
500,2.076,2.083955,0.12831,0.11218,0.138249,0.334821,0.087108,0.109929,0.221429,0.073113,0.115512,0.203488,0.080645,0.187919,0.125964,0.369811,0.115132,0.0875,0.168269,0.084257,0.051491,0.231707,0.092391,0.062271,0.178947,0.054054,0.050505,0.05814
1000,2.0532,2.086711,0.129124,0.105999,0.00232,1.0,0.001161,0.256767,0.191024,0.391509,0.160982,0.197324,0.135945,0.135283,0.131206,0.139623,0.087977,0.112782,0.072115,0.072368,0.04955,0.134146,0.078091,0.04918,0.189474,0.054201,0.035336,0.116279
1500,2.0253,2.10288,0.114868,0.106715,0.026287,0.230769,0.013937,0.110899,0.292929,0.068396,0.15,0.23301,0.110599,0.106061,0.160305,0.079245,0.195462,0.119403,0.538462,0.098246,0.068966,0.170732,0.056338,0.038462,0.105263,0.110429,0.063604,0.418605
2000,1.6045,1.724771,0.269654,0.248461,0.371681,0.509091,0.292683,0.194079,0.320652,0.139151,0.13211,0.324324,0.082949,0.290258,0.306723,0.275472,0.29681,0.208577,0.514423,0.215743,0.141762,0.45122,0.184486,0.115183,0.463158,0.302521,0.199262,0.627907
2500,1.0502,1.500687,0.356415,0.392248,0.272401,0.596078,0.176539,0.385799,0.387173,0.384434,0.354669,0.325626,0.389401,0.400657,0.354651,0.460377,0.377104,0.290155,0.538462,0.459016,0.415842,0.512195,0.307692,0.194444,0.736842,0.580645,0.652174,0.523256
3000,0.7529,1.38887,0.426477,0.46023,0.361545,0.617978,0.255517,0.429967,0.39839,0.466981,0.412183,0.368421,0.467742,0.443662,0.415842,0.475472,0.470309,0.464789,0.475962,0.564103,0.486726,0.670732,0.42623,0.287823,0.821053,0.57384,0.450331,0.790698
3500,0.6432,1.306968,0.461507,0.500738,0.446441,0.551195,0.375145,0.416999,0.477204,0.370283,0.400468,0.407143,0.394009,0.52454,0.44186,0.645283,0.478261,0.366667,0.6875,0.619355,0.657534,0.585366,0.465753,0.345178,0.715789,0.654088,0.712329,0.604651
4000,0.5171,1.245615,0.506314,0.542429,0.490358,0.602369,0.413473,0.479508,0.423913,0.551887,0.465957,0.432806,0.504608,0.541254,0.480938,0.618868,0.534759,0.60241,0.480769,0.62963,0.6375,0.621951,0.557078,0.491935,0.642105,0.640884,0.610526,0.674419
4500,0.4274,1.180005,0.542159,0.572853,0.540973,0.634375,0.471545,0.530777,0.463845,0.620283,0.479657,0.448,0.516129,0.552876,0.543796,0.562264,0.601852,0.580357,0.625,0.627737,0.781818,0.52439,0.550725,0.508929,0.6,0.698225,0.710843,0.686047
5000,0.3675,1.1804,0.553564,0.577278,0.5497,0.645768,0.478513,0.550598,0.511111,0.596698,0.499494,0.445045,0.569124,0.567669,0.565543,0.569811,0.6125,0.540441,0.706731,0.658065,0.69863,0.621951,0.546392,0.535354,0.557895,0.633803,0.803571,0.523256


TrainOutput(global_step=12285, training_loss=0.6296214087784752, metrics={'train_runtime': 3514.7393, 'train_samples_per_second': 83.818, 'train_steps_per_second': 3.495, 'total_flos': 3.6711027065314345e+20, 'train_loss': 0.6296214087784752, 'epoch': 15.0})

In [14]:

# Evaluate on test split
test_metrics = trainer.evaluate(test_dataset)
print("Test metrics:", test_metrics)


Test metrics: {'eval_loss': 1.081895351409912, 'eval_accuracy': 0.617671009771987, 'eval_macro_f1': 0.6118197703066269, 'eval_f1_LHHP': 0.653556969346443, 'eval_precision_LHHP': 0.6758373205741627, 'eval_recall_LHHP': 0.6326987681970885, 'eval_f1_RHHP': 0.6266009852216748, 'eval_precision_RHHP': 0.5608465608465608, 'eval_recall_RHHP': 0.7098214285714286, 'eval_f1_LHMP': 0.5467800729040098, 'eval_precision_LHMP': 0.5447941888619855, 'eval_recall_LHMP': 0.5487804878048781, 'eval_f1_RHMP': 0.5935483870967742, 'eval_precision_RHMP': 0.5948275862068966, 'eval_recall_RHMP': 0.592274678111588, 'eval_f1_LHBlP': 0.6131805157593123, 'eval_precision_LHBlP': 0.6331360946745562, 'eval_recall_LHBlP': 0.5944444444444444, 'eval_f1_RHBlP': 0.686046511627907, 'eval_precision_RHBlP': 0.8082191780821918, 'eval_recall_RHBlP': 0.5959595959595959, 'eval_f1_LHBP': 0.4891304347826087, 'eval_precision_LHBP': 0.5056179775280899, 'eval_recall_LHBP': 0.47368421052631576, 'eval_f1_RHBP': 0.6857142857142857, 'eval_p