In [None]:
import pathlib
from transformers import AutoImageProcessor, VideoMAEForVideoClassification
import os
import torch
import torch.nn as nn
from collections import OrderedDict
import gc
import json
import time
from functools import partial
import itertools

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

In [None]:
# Compatibility shim for libraries that still import
# `torchvision.transforms.functional_tensor`, which is not importable in every
# torchvision release. Background: https://github.com/xinntao/Real-ESRGAN/issues/768
import sys
import types
from torchvision.transforms.functional import rgb_to_grayscale

try:
    # If the module is importable (a torchvision that still ships it, or the
    # stand-in registered by a previous run of this cell), keep it untouched
    # instead of replacing it with a stripped-down module.
    import torchvision.transforms.functional_tensor as functional_tensor
except ImportError:
    # Create a stand-in module that exposes only `rgb_to_grayscale`.
    functional_tensor = types.ModuleType("torchvision.transforms.functional_tensor")
    functional_tensor.rgb_to_grayscale = rgb_to_grayscale

    # Register it in sys.modules so later imports resolve to the stand-in.
    sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor

In [None]:
# Root folder of the reorganised CMU-MOSI videos. The cells below glob
# train/, valid/ and test/ sub-folders, each holding one folder per class.
# Set the CMU_MOSI_ROOT environment variable to use a different copy of the
# data; the original cluster location remains the default.
dataset_root_path = os.environ.get(
    "CMU_MOSI_ROOT",
    "/home/k/kyparkypar/ondemand/data/sys/myjobs/projects/default/dataset/CMU-MOSI/Raw_reorganised/",
)
dataset_root_path = pathlib.Path(dataset_root_path)

In [None]:
dataset_root_path

In [None]:
# Collect every .mp4 under the three split folders, in train -> valid -> test order.
all_video_file_paths = [
    video_path
    for split_pattern in ("train/*/*.mp4", "valid/*/*.mp4", "test/*/*.mp4")
    for video_path in dataset_root_path.glob(split_pattern)
]

In [None]:
all_video_file_paths

In [None]:
# Videos are globbed as <split>/<class folder>/<file>.mp4, so the parent folder
# name is the class label. `Path.parent.name` is used instead of splitting the
# string on "/" so the lookup does not depend on the platform's path separator.
class_labels = sorted({path.parent.name for path in all_video_file_paths})

# Label <-> integer id maps; ids follow the sorted label order.
label2id = {label: i for i, label in enumerate(class_labels)}
id2label = {i: label for label, i in label2id.items()}

print(f"Unique classes: {list(label2id.keys())}.")

In [None]:
# Checkpoint name shared by the image processor and the classification model.
model_ckpt = "MCG-NJU/videomae-base"
image_processor = AutoImageProcessor.from_pretrained(model_ckpt)
# The dataset's label maps are passed in so the loaded model is configured
# with this notebook's classes.
model = VideoMAEForVideoClassification.from_pretrained(
    model_ckpt,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)

In [None]:
model

In [None]:
# NOTE(review): this runs right after `from_pretrained`. Confirm that, with the
# installed transformers version, `init_weights()` does not re-initialise the
# pretrained encoder weights that were just loaded.
model.init_weights()

In [None]:
# Batch settings: `batch_size` is used as the per-device train and eval batch
# size; gradients are accumulated over `gradient_accumulation_steps` batches.
batch_size = 4
gradient_accumulation_steps = 8

# Optimiser settings passed to TrainingArguments.
learning_rate = 1e-5
weight_decay = 1e-4

# Number of epochs used when deriving `max_steps` for the Trainer.
max_epochs = 5

# Classification-head settings: hidden width, dropout probability and the
# activation name ('relu', 'tanh' or 'No') resolved in the next cell.
hidden_size = 32
dropout_p = 0.1
activation_fn = "tanh"

In [None]:
# Resolve the `activation_fn` hyperparameter to the module used in the
# classification head. 'No' means no non-linearity (identity).
if activation_fn == 'relu':
    activation_function = nn.ReLU()
elif activation_fn == 'tanh':
    activation_function = nn.Tanh()
elif activation_fn == 'No':
    activation_function = nn.Identity()
else:
    # Fail loudly instead of leaving `activation_function` undefined (or stale
    # from an earlier run of this cell) when the setting is misspelled.
    raise ValueError(
        f"Unsupported activation_fn {activation_fn!r}; expected 'relu', 'tanh' or 'No'."
    )

In [None]:
# Replace the stock classification layer with a small MLP head:
# encoder features -> hidden layer -> activation -> dropout -> class logits.
# The input width is read from the model config rather than hard-coded to 768,
# so the head still fits if a checkpoint with a different width is loaded.
model.classifier = nn.Sequential(
    OrderedDict([
        ('dense', nn.Linear(model.config.hidden_size, hidden_size)),
        ('act_func', activation_function),
        ('dropout', nn.Dropout(dropout_p)),
        ('dense_outp', nn.Linear(hidden_size, model.config.num_labels)),
    ])
)

In [None]:
import pytorchvideo.data

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)

from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    Resize,
)

In [None]:
# Normalisation statistics come from the checkpoint's image processor.
mean, std = image_processor.image_mean, image_processor.image_std

# The processor reports either explicit height/width or a single shortest edge;
# in the latter case a square target size is used.
if "shortest_edge" not in image_processor.size:
    height, width = image_processor.size["height"], image_processor.size["width"]
else:
    height = width = image_processor.size["shortest_edge"]
resize_to = (height, width)

# Number of frames per clip, as configured on the model.
num_frames_to_sample = model.config.num_frames

In [None]:
num_frames_to_sample

In [None]:
# Clip length passed to the clip samplers below:
# frames per clip * sample_rate / fps.
# NOTE(review): fps = 30 is assumed here, not read from the videos — confirm.
sample_rate = 4
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps

In [None]:
# Deterministic preprocessing applied to the "video" entry of each training
# sample: temporal subsampling to `num_frames_to_sample` frames, division by
# 255, normalisation with the processor's mean/std, then resizing to
# `resize_to`. No random augmentation is applied in this pipeline.
train_transform = Compose(
    [
        ApplyTransformToKey(
            key="video",
            transform=Compose(
                [
                    UniformTemporalSubsample(num_frames_to_sample),
                    Lambda(lambda x: x / 255.0),
                    Normalize(mean, std),
                    Resize(resize_to),
                ]
            ),
        ),
    ]
)

# train_transform_augm = Compose(
#     [
#         ApplyTransformToKey(
#             key="video",
#             transform=Compose(
#                 [
#                     UniformTemporalSubsample(num_frames_to_sample),
#                     Lambda(lambda x: x / 255.0),
#                     Normalize(mean, std),
#                     RandomShortSideScale(min_size=256, max_size=320),
#                     RandomCrop(resize_to),
#                     RandomHorizontalFlip(p=0.5),
#                 ]
#             ),
#         ),
#     ]
# )

# Training split: videos under <root>/train, cut into clips of `clip_duration`
# seconds by the "uniform" clip sampler; audio is not decoded.
train_dataset = pytorchvideo.data.Ucf101(
    data_path=os.path.join(dataset_root_path, "train"),
    clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
    decode_audio=False,
    transform=train_transform,
)

In [None]:
type(train_dataset)

In [None]:
type(pytorchvideo.data.Ucf101)

In [None]:
# from pytorchvideo.data import Ucf101

# class LenEnabledUcf101(Ucf101):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         # Cache the dataset length
#         self._length = len(self._video_paths)

#     def __len__(self):
#         return self._length

# # Use the custom subclass
# ex_train_dataset = LenEnabledUcf101(
#     data_path=os.path.join(dataset_root_path, "train"),
#     clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
#     decode_audio=False,
#     transform=train_transform,
# )

# # Example usage
# print(len(ex_train_dataset))  # Now len() works

In [None]:
# from pytorchvideo.data.clip_sampling import ClipSampler, ClipInfo
# from typing import Iterator, Dict
# import random

# class FractionalClipSampler(ClipSampler):
#     def __init__(self, base_sampler, fraction=0.2, seed=42):
#         """
#         A fractional clip sampler that keeps only a fraction of the clips
#         selected by the base sampler.

#         Args:
#             base_sampler (ClipSampler): The base sampler to use for generating clips.
#             fraction (float): Fraction of clips to keep (e.g., 0.2 for 20%).
#             seed (int): Random seed for reproducibility.
#         """
#         super().__init__()
#         self.base_sampler = base_sampler
#         self.fraction = fraction
#         self.random_state = random.Random(seed)

#     def __call__(self, last_clip_time: float, video_duration: float, info_dict: Dict) -> Iterator[ClipInfo]:
#         """
#         Called to generate clips.

#         Args:
#             last_clip_time (float): Start time of the last clip.
#             video_duration (float): Duration of the video.
#             info_dict (dict): Additional metadata about the video.

#         Returns:
#             Iterator[ClipInfo]: A generator of ClipInfo objects for the selected clips.
#         """
#         # Generate clips using the base sampler
#         for clip_info in self.base_sampler(last_clip_time, video_duration, info_dict):
#             # Randomly keep only a fraction of the clips
#             if self.random_state.random() <= self.fraction:
#                 yield clip_info

#     def reset(self):
#         """Resets the state of the sampler."""
#         self.base_sampler.reset()


# # Use the fractional sampler
# fractional_sampler = FractionalClipSampler(
#     pytorchvideo.data.make_clip_sampler("random", clip_duration),
#     fraction=0.2,
# )


# train_dataset_subset = pytorchvideo.data.Ucf101(
#     data_path=os.path.join(dataset_root_path, "train"),
#     clip_sampler=fractional_sampler,
#     decode_audio=False,
#     transform=train_transform,
# )


In [None]:
# Evaluation preprocessing for the "video" entry: temporal subsampling,
# division by 255, normalisation and resizing — the same steps, in the same
# order, as `train_transform`.
val_transform = Compose(
    [
        ApplyTransformToKey(
            key="video",
            transform=Compose(
                [
                    UniformTemporalSubsample(num_frames_to_sample),
                    Lambda(lambda x: x / 255.0),
                    Normalize(mean, std),
                    Resize(resize_to),
                ]
            ),
        ),
    ]
)

# Validation split: videos under <root>/valid, sampled with the "uniform" clip sampler.
val_dataset = pytorchvideo.data.Ucf101(
    data_path=os.path.join(dataset_root_path, "valid"),
    clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
    decode_audio=False,
    transform=val_transform,
)

# Test split: videos under <root>/test, using the same sampler and transform as validation.
test_dataset = pytorchvideo.data.Ucf101(
    data_path=os.path.join(dataset_root_path, "test"),
    clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
    decode_audio=False,
    transform=val_transform,
)

# from pytorchvideo.data.clip_sampling import UniformClipSampler

# # Set stride to a large value (greater than any video length)
# clip_sampler = UniformClipSampler(clip_duration=clip_duration, stride=clip_duration)

# test_dataset = pytorchvideo.data.Ucf101(
#     data_path=os.path.join(dataset_root_path, "test"),
#     clip_sampler=clip_sampler,
#     decode_audio=False,
#     transform=val_transform,
# )

In [None]:
print(train_dataset.num_videos, val_dataset.num_videos, test_dataset.num_videos)

In [None]:
import imageio
import numpy as np
from IPython.display import Image

def unnormalize_img(img):
    """Undo the mean/std normalisation and return 8-bit pixel values.

    Args:
        img: NumPy array whose last axis lines up with the notebook-level
            `mean` and `std` sequences.

    Returns:
        A uint8 array of the same shape with values in [0, 255].
    """
    img = (img * std) + mean
    # Clip BEFORE the uint8 cast: casting first lets out-of-range values wrap
    # around, after which clipping can no longer repair them.
    return (img * 255).clip(0, 255).astype("uint8")

def create_gif(video_tensor, filename="sample.gif"):
    """Write a video tensor to an animated GIF and return the file name.

    The tensor is iterated along its first axis, one frame at a time, so it is
    expected to be laid out as (num_frames, num_channels, height, width).
    """
    # Convert each frame to a channels-last NumPy image with normalisation undone.
    gif_frames = [
        unnormalize_img(frame.permute(1, 2, 0).numpy())
        for frame in video_tensor
    ]
    imageio.mimsave(filename, gif_frames, "GIF", duration=0.25)
    return filename

def display_gif(video_tensor, gif_name="sample.gif"):
    """Save a video tensor as a GIF and return it as an IPython image."""
    # Swap the first two axes so the frame axis comes first, which is the
    # layout `create_gif` iterates over.
    frames_first = video_tensor.permute(1, 0, 2, 3)
    return Image(filename=create_gif(frames_first, gif_name))

In [None]:
# sample_video = next(iter(train_dataset))
# video_tensor = sample_video["video"]

In [None]:
# display_gif(video_tensor)

In [None]:
from transformers import TrainingArguments, Trainer, TrainerCallback

# Short checkpoint name (text after the last "/") and a derived run name.
model_name = model_ckpt.split("/")[-1]
new_model_name = f"{model_name}-finetuned-cmu-mosi"
# NOTE(review): `num_epochs` is not referenced elsewhere in the visible
# notebook; the training arguments use `max_epochs` instead — confirm which
# value is intended.
num_epochs = 4

In [None]:
model_name

In [None]:
# Configure the training run with the TrainingArguments class.
# `metric_for_best_model` decides which metric selects the best checkpoint and
# is also read by BestModelEpochCallback.
metric_for_best_model = "loss"   # Save the model and the metrics of the current model for the best epochs
training_args = TrainingArguments(
    output_dir="./runs/videomae",
    # logging_dir="./logs/videomae",
    # report_to="tensorboard",
    learning_rate=learning_rate,
    push_to_hub=False,
    num_train_epochs=max_epochs,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    per_device_eval_batch_size=batch_size,
    eval_strategy="epoch",
    save_strategy="epoch",
    # save_total_limit=0,  # Ensure no checkpoints are saved
    # eval_steps=1,
    # save_steps=1,
    weight_decay=weight_decay,
    load_best_model_at_end=True,
    metric_for_best_model=metric_for_best_model,
    remove_unused_columns=False,
    # eval_accumulation_steps=eval_accumulation_steps,
    # logging_strategy="epoch",
    logging_steps=10,
    lr_scheduler_type="constant",  # Ensures no decay in learning rate
    # Mixed precision is enabled only when CUDA is available, so the CPU
    # fallback chosen for `device` does not trip over an fp16 request.
    fp16=torch.cuda.is_available(),
    # The step budget is derived from the number of videos: optimiser steps per
    # pass over the videos, times `max_epochs`. At least one step per epoch is
    # requested so a very small training set cannot produce max_steps == 0.
    # NOTE(review): the dataset yields clips, not videos, so this may
    # under-estimate the steps needed for a full pass — confirm.
    max_steps=max(1, train_dataset.num_videos // (batch_size * gradient_accumulation_steps)) * max_epochs,
)

In [None]:
# Per-class weights for the weighted cross-entropy loss used in training and
# in the test-loss computation.
# NOTE(review): presumably derived from class frequencies, with positions
# matching label ids 0 and 1 — confirm how these values were computed.
class_wts = np.array([1.16304348, 0.87704918])

In [None]:
class CustomTrainer(Trainer):
    """Trainer whose loss is cross-entropy weighted by the notebook-level `class_wts`."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run a forward pass and return the class-weighted cross-entropy loss.

        Args:
            model: The model being trained.
            inputs: Batch dict; must contain "labels" next to the model inputs.
            return_outputs: When True, also return the model outputs.
            **kwargs: Extra arguments supplied by the Trainer; unused here.

        Returns:
            The loss tensor, or `(loss, outputs)` when `return_outputs` is True.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Create the weights on the same device as the logits instead of the
        # notebook-level `device`, so the loss follows wherever the Trainer
        # placed the model.
        weights = torch.tensor(class_wts, dtype=torch.float, device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weights)

        # The class count is taken from the logits themselves; `model` may be a
        # wrapper object that does not expose `.config`.
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return (loss, outputs) if return_outputs else loss

In [None]:
class BestModelEpochCallback(TrainerCallback):
    """Trainer callback that records per-epoch losses and the best evaluation epoch.

    "Best" follows the notebook-level `metric_for_best_model`: lowest
    `eval_loss` when it is "loss", highest `eval_accuracy` when it is
    "accuracy". Any other setting leaves the tracked state untouched.
    """

    def __init__(self):
        self.best_loss = float("inf")
        self.best_acc = 0.0
        self.best_epoch = None
        self.training_metrics = []  # (epoch, last logged training loss) per finished epoch
        self.eval_metrics = []      # (rounded epoch, eval loss) per recorded evaluation

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Record the evaluation loss and update the best-epoch bookkeeping.

        Evaluations are ignored when no metrics are given, when `eval_loss` is
        missing, when the epoch counter is not set, or when the rounded epoch
        is not greater than 1.
        NOTE(review): skipping the first epoch mirrors the original condition
        `round(state.epoch) > 1` — confirm this is intentional.
        """
        if metrics is None or metric_for_best_model not in ("loss", "accuracy"):
            return
        # `state.epoch` can be unset when evaluation runs outside of training.
        if "eval_loss" not in metrics or state.epoch is None:
            return

        epoch = round(state.epoch)
        if epoch <= 1:
            return

        eval_loss = metrics["eval_loss"]
        # Accuracy is only present when `compute_metrics` reports it; read it
        # defensively so a missing key cannot abort training with a KeyError.
        eval_acc = metrics.get("eval_accuracy")
        self.eval_metrics.append((epoch, eval_loss))

        if metric_for_best_model == "loss":
            improved = eval_loss < self.best_loss
        else:
            improved = eval_acc is not None and eval_acc > self.best_acc

        if improved:
            self.best_epoch = epoch
            self.best_loss = eval_loss
            if eval_acc is not None:
                self.best_acc = eval_acc

    def on_epoch_end(self, args, state, control, **kwargs):
        """Store the most recently logged training loss for the epoch that just ended."""
        # Walk the log history backwards to find the latest entry with a "loss".
        for log in reversed(state.log_history or []):
            if "loss" in log:
                self.training_metrics.append((state.epoch, log["loss"]))
                break

    # def on_log(self, args, state, control, logs=None, **kwargs):
    #     if logs and "loss" in logs:
    #         self.training_metrics.append((state.epoch, logs["loss"]))

best_model_callback = BestModelEpochCallback()

In [None]:
import evaluate

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Score one evaluation pass with the loaded `metric`.

    The predicted class of each row is the argmax of `eval_pred.predictions`
    along axis 1; it is compared against `eval_pred.label_ids`.
    """
    class_scores, true_ids = eval_pred.predictions, eval_pred.label_ids
    predicted_ids = np.argmax(class_scores, axis=1)
    return metric.compute(predictions=predicted_ids, references=true_ids)

In [None]:
def collate_fn(examples):
    """Batch dataset samples into the dict the model consumes.

    Each sample's "video" is resized to `resize_to`, has its first two axes
    swapped, and is stacked into "pixel_values"; the "label" entries become a
    single "labels" tensor.
    """
    to_target_size = Resize(resize_to)  # guarantees a consistent frame size

    clips = [
        to_target_size(sample["video"]).permute(1, 0, 2, 3)
        for sample in examples
    ]
    targets = [sample["label"] for sample in examples]

    return {
        "pixel_values": torch.stack(clips),
        "labels": torch.tensor(targets),
    }

In [None]:
# Assemble the trainer: class-weighted loss (CustomTrainer), the train and
# validation datasets, accuracy via `compute_metrics`, and `collate_fn` to
# build the batches.
trainer = CustomTrainer(
    model,
    training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=image_processor,
    callbacks=[best_model_callback],  # Not used during CV, only here to find optimal epochs
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

In [None]:
trainer.train()

In [None]:
# import matplotlib.pyplot as plt

# # Extract metrics
# epochs, training_losses = zip(*best_model_callback.training_metrics)
# eval_epochs, eval_losses = zip(*best_model_callback.eval_metrics)

# # Plot learning curves
# plt.figure(figsize=(10, 6))
# plt.plot(epochs, training_losses, 
#          label="Training Loss")
# plt.plot(eval_epochs, eval_losses, 
#          label="Validation Loss")
# plt.xlabel("Epochs")
# plt.xticks(range(1, max_epochs+1))
# plt.ylabel("Loss")
# plt.title("Learning Curve")
# plt.legend()

# # Show the plot
# plt.show()

In [None]:
# print("Optimal Epochs: ", (int(best_model_callback.best_epoch) + 1))

In [None]:
# metrics = trainer.evaluate(test_dataset)

In [None]:
# metrics

In [None]:
# print("Avg. test loss: ", preds.metrics['test_loss'])

In [None]:
# predicted_labels = np.argmax(metrics.predictions, axis=1)

In [None]:
# predicted_labels.shape

In [None]:
# preds.label_ids

In [None]:
# preds.label_ids.shape

In [None]:
# from sklearn.metrics import classification_report

# print(classification_report(metrics.label_ids, predicted_labels, digits=4))

In [None]:
# import pandas as pd

# confusion_matrix = pd.crosstab(preds.label_ids, predicted_labels)

In [None]:
# import seaborn as sns

# # Set the size of the figure
# plt.figure(figsize=(10, 7))

# # Create a heatmap from the confusion matrix
# sns.heatmap(confusion_matrix,
#             annot=True,
#             fmt='d',
#             cmap='Blues',
#             cbar=True)

# # Set titles and labels
# plt.title('Fine-Tuned AST (Optimal Parameters) Confusion Matrix')
# plt.xlabel('Predicted Labels')
# plt.ylabel('True Labels')

# # Show the plot
# plt.show()

In [None]:
def _predict_clip(model, clip, device):
    """Return the logits for one clip, adding a batch axis and swapping the first two axes."""
    # NOTE(review): clips from this notebook's datasets appear to be
    # channels-first (display_gif applies the same axis swap to get frames
    # first), so this yields a frames-first tensor — confirm the expected layout.
    pixel_values = clip.permute(1, 0, 2, 3).unsqueeze(0).to(device)
    with torch.no_grad():
        return model(pixel_values=pixel_values).logits


def run_inference(model, video_or_dataset, verbose=True):
    """
    Run inference on either a single video clip or a dataset of clips.

    The model is moved to the GPU when available, switched to evaluation mode
    for the duration of the call (so dropout is inactive), and returned to its
    previous train/eval mode afterwards.

    Args:
        model (torch.nn.Module): The model to use for inference.
        video_or_dataset (Union[torch.Tensor, Iterable[dict]]):
            A single clip tensor, or an iterable of samples that each provide
            "video" and "label" entries.
        verbose (bool): When True (default), print the index, logits and label
            of every dataset sample as it is processed.

    Returns:
        Union[torch.Tensor, Tuple[torch.Tensor, List[int]]]:
            If a single clip tensor is provided, returns its logits.
            If a dataset is provided, returns a tuple containing:
            - the logits of all samples, concatenated along dim 0
            - the list of corresponding labels.

    Raises:
        ValueError: If a dataset is provided but yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    was_training = model.training
    model.eval()  # deterministic predictions: disables dropout
    logits_list = []
    labels_list = []
    try:
        # Case 1: single clip input
        if isinstance(video_or_dataset, torch.Tensor):
            return _predict_clip(model, video_or_dataset, device)

        # Case 2: dataset input. Iterate over everything the dataset yields
        # rather than pulling a fixed `num_videos` samples: a dataset built
        # with a clip sampler can yield several clips per video, so counting
        # videos would stop the evaluation early.
        for i, sample in enumerate(video_or_dataset):
            label = sample["label"]
            clip_logits = _predict_clip(model, sample["video"], device)
            logits_list.append(clip_logits)
            labels_list.append(label)
            if verbose:
                print(i, "->", clip_logits, "-", label)
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    if not logits_list:
        raise ValueError("run_inference received a dataset that yielded no samples.")

    return torch.cat(logits_list, dim=0), labels_list

In [None]:
logits = run_inference(model, test_dataset)

In [None]:
# logits

In [None]:
real_labels = logits[1]

In [None]:
predicted_logits = logits[0]
predicted_labels = predicted_logits.argmax(-1).cpu().numpy()

In [None]:
# predicted_labels

In [None]:
from sklearn.metrics import classification_report

print(classification_report(real_labels, predicted_labels, digits=4))

In [None]:
# Bring the predictions and labels onto `device` for the loss computation.
# `predicted_logits` is already a tensor, so it is detached and cast in place of
# being re-wrapped with `torch.tensor(...)`, which copies and emits a UserWarning.
predicted_logits_tensor = predicted_logits.detach().to(device=device, dtype=torch.float)
real_labels_tensor = torch.tensor(real_labels, dtype=torch.long).to(device)

# Same class weights as used for the training loss
weights = torch.tensor(class_wts, dtype=torch.float).to(device)
loss_fct = nn.CrossEntropyLoss(weight=weights)

# Compute the weighted cross-entropy over all collected test predictions
loss = loss_fct(predicted_logits_tensor.view(-1, predicted_logits_tensor.size(-1)), real_labels_tensor.view(-1))

In [None]:
print("Avg. test loss: ", loss.item())

In [None]:
import pandas as pd

confusion_matrix = pd.crosstab(real_labels, predicted_labels)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# One figure with a single axes for the heatmap
fig, ax = plt.subplots(figsize=(10, 7))

# Draw the confusion matrix as an annotated heatmap of integer counts
sns.heatmap(
    confusion_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    cbar=True,
    ax=ax,
)

# Title and axis labels
ax.set_title('Fine-Tuned VideoMAE (Optimal Parameters) Confusion Matrix')
ax.set_xlabel('Predicted Labels')
ax.set_ylabel('True Labels')

# Render the figure
plt.show()

In [None]:
# sample_test_video = next(iter(test_dataset))

In [None]:
# test_dataset

In [None]:
# predicted_class_idx = logits.argmax(-1).item()
# print("Predicted class:", model.config.id2label[predicted_class_idx])