In [2]:
import os
gpu_ids = [4]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))
import random
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import cv2
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import VideoMAEFeatureExtractor, VideoMAEModel
from sklearn.metrics import f1_score, recall_score, accuracy_score
from tqdm import tqdm

# ---- CONFIGURATION ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_dir = "/data/home/huixian/Documents/Homeworks/535_project/MOSEI/Clip/Clips_16frames"
mapping_csv = "/data/home/huixian/Documents/Homeworks/535_project/MOSEI/Clip/clip_sentiment_mapping.csv"
clip_len = 16
batch_size = 8
num_epochs = 10
negative_samples = 1500
neutral_samples = 1500
positive_samples = 1500

# ---- DATASET ----
class VideoClipClassificationDataset(Dataset):
    def __init__(self, clip_dir, csv_path, feature_extractor):
        self.clip_dir = clip_dir
        self.df = pd.read_csv(csv_path)
        self.feature_extractor = feature_extractor
        self.label_map = {"Negative": 0, "Neutral": 1, "Positive": 2}
        self.samples_by_class = {k: [] for k in self.label_map}

        for _, row in self.df.iterrows():
            self.samples_by_class[row["sentiment_label"]].append(row["clip_filename"])

        self.samples = (
            random.sample(self.samples_by_class["Negative"], min(negative_samples, len(self.samples_by_class["Negative"]))) +
            random.sample(self.samples_by_class["Neutral"], min(neutral_samples, len(self.samples_by_class["Neutral"]))) +
            random.sample(self.samples_by_class["Positive"], min(positive_samples, len(self.samples_by_class["Positive"])))
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        clip_name = self.samples[idx]
        label_str = self.df[self.df["clip_filename"] == clip_name]["sentiment_label"].values[0]
        label = self.label_map[label_str]
        clip_path = os.path.join(self.clip_dir, clip_name)

        cap = cv2.VideoCapture(clip_path)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame[:, :, ::-1])
        cap.release()

        if len(frames) < clip_len:
            frames += [frames[-1]] * (clip_len - len(frames))
        frames = frames[:clip_len]

        inputs = self.feature_extractor(images=frames, return_tensors="pt")["pixel_values"].squeeze(0)
        return inputs, label

# ---- MODEL ----
class VideoClassifier(nn.Module):
    def __init__(self, input_dim=768, num_classes=3):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        return self.model(x)

# ---- TRAINING LOOP ----
def run_epoch(model, loader, optimizer=None):
    model.train() if optimizer else model.eval()
    total_loss = 0
    all_preds, all_labels = [], []

    for clips, labels in tqdm(loader, leave=False):
        clips, labels = clips.to(device), labels.to(device)

        with torch.set_grad_enabled(optimizer is not None):
            features = video_mae(clips).last_hidden_state.mean(dim=1)
            logits = model(features)
            loss = loss_fn(logits, labels)

            if optimizer:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        total_loss += loss.item()
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(loader)
    return avg_loss, np.array(all_preds), np.array(all_labels)

def evaluate(preds, labels):
    return (
        f1_score(labels, preds, average="macro"),
        f1_score(labels, preds, average="micro"),
        recall_score(labels, preds, average=None, labels=[0, 1, 2]),
        accuracy_score(labels, preds)
    )

# ---- LOAD MODELS ----
feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
video_mae = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base").to(device)
video_mae.eval()
for p in video_mae.parameters(): p.requires_grad = False

dataset = VideoClipClassificationDataset(clip_dir, mapping_csv, feature_extractor)
train_set, val_set, test_set = random_split(dataset, [int(0.8 * len(dataset)), int(0.1 * len(dataset)), len(dataset) - int(0.9 * len(dataset))])
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

classifier = VideoClassifier().to(device)
optimizer = optim.Adam(classifier.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

# ---- TRAINING ----
best_macro = -1
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch}")
    train_loss, _, _ = run_epoch(classifier, train_loader, optimizer)
    val_loss, val_preds, val_labels = run_epoch(classifier, val_loader)

    macro_f1, micro_f1, recall, acc = evaluate(val_preds, val_labels)
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"Macro-F1: {macro_f1:.4f} | Micro-F1: {micro_f1:.4f} | Acc: {acc:.4f} | Recall: {recall}")

    if macro_f1 > best_macro:
        best_macro = macro_f1
        torch.save(classifier.state_dict(), "best_video_classifier.pth")
        print(f"✅ Best model saved at epoch {epoch}")

# ---- FINAL TEST ----
classifier.load_state_dict(torch.load("best_video_classifier.pth"))
test_loss, test_preds, test_labels = run_epoch(classifier, test_loader)
macro_f1, micro_f1, recall, acc = evaluate(test_preds, test_labels)
print("\n--- TEST SET ---")
print(f"Macro-F1: {macro_f1:.4f} | Micro-F1: {micro_f1:.4f} | Acc: {acc:.4f} | Recall: {recall}")





Epoch 0


                                                 

Train Loss: 1.0920 | Val Loss: 1.0818
Macro-F1: 0.2584 | Micro-F1: 0.3933 | Acc: 0.3933 | Recall: [0.08387097 0.04580153 0.96341463]
✅ Best model saved at epoch 0

Epoch 1


                                                 

Train Loss: 1.0768 | Val Loss: 1.0690
Macro-F1: 0.4471 | Micro-F1: 0.4511 | Acc: 0.4511 | Recall: [0.55483871 0.46564885 0.34146341]
✅ Best model saved at epoch 1

Epoch 2


                                                 

Train Loss: 1.0633 | Val Loss: 1.0572
Macro-F1: 0.4628 | Micro-F1: 0.4667 | Acc: 0.4667 | Recall: [0.56129032 0.42748092 0.40853659]
✅ Best model saved at epoch 2

Epoch 3


                                                 

Train Loss: 1.0531 | Val Loss: 1.0497
Macro-F1: 0.4702 | Micro-F1: 0.4733 | Acc: 0.4733 | Recall: [0.5483871  0.4351145  0.43292683]
✅ Best model saved at epoch 3

Epoch 4


                                                 

Train Loss: 1.0418 | Val Loss: 1.0520
Macro-F1: 0.4460 | Micro-F1: 0.4600 | Acc: 0.4600 | Recall: [0.6516129  0.48854962 0.25609756]

Epoch 5


                                                 

Train Loss: 1.0360 | Val Loss: 1.0583
Macro-F1: 0.4104 | Micro-F1: 0.4200 | Acc: 0.4200 | Recall: [0.43225806 0.66412214 0.21341463]

Epoch 6


                                                 

Train Loss: 1.0287 | Val Loss: 1.0409
Macro-F1: 0.4725 | Micro-F1: 0.4733 | Acc: 0.4733 | Recall: [0.46451613 0.46564885 0.48780488]
✅ Best model saved at epoch 6

Epoch 7


                                                 

Train Loss: 1.0230 | Val Loss: 1.0359
Macro-F1: 0.4662 | Micro-F1: 0.4689 | Acc: 0.4689 | Recall: [0.50322581 0.41984733 0.47560976]

Epoch 8


                                                 

Train Loss: 1.0181 | Val Loss: 1.0343
Macro-F1: 0.4712 | Micro-F1: 0.4733 | Acc: 0.4733 | Recall: [0.48387097 0.4351145  0.49390244]

Epoch 9


  classifier.load_state_dict(torch.load("best_video_classifier.pth"))


Train Loss: 1.0106 | Val Loss: 1.0332
Macro-F1: 0.4476 | Micro-F1: 0.4511 | Acc: 0.4511 | Recall: [0.54193548 0.41221374 0.39634146]


                                               


--- TEST SET ---
Macro-F1: 0.4018 | Micro-F1: 0.4022 | Acc: 0.4022 | Recall: [0.41104294 0.41025641 0.38167939]


