In [1]:
import os
# Pin this process to the listed physical GPU(s). The variable is set before
# `import torch` below so it is already in place when CUDA is initialised;
# the first listed GPU is then addressed as cuda:0 inside this process.
gpu_ids = [4]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))
import random
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import cv2
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import VideoMAEFeatureExtractor, VideoMAEModel
from sklearn.metrics import f1_score, recall_score, accuracy_score
from tqdm import tqdm

# ---- SETTINGS ----
# Use the visible GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Upper bound on clips drawn per sentiment class (class-balanced subsample).
negative_samples = neutral_samples = positive_samples = 1500

batch_size = 16
clip_len = 16      # frames per clip (short clips are padded, long ones truncated)
num_epochs = 30

# Data locations: the clip files and the CSV mapping each clip to its sentiment.
_mosei_clip_root = "/data/home/huixian/Documents/Homeworks/535_project/MOSEI/Clip"
clip_dir = _mosei_clip_root + "/Clips_16frames"
mapping_csv = _mosei_clip_root + "/clip_sentiment_mapping.csv"

# ---- DATASET ----
class VideoClipDataset(Dataset):
    """Class-balanced dataset of video clips paired with sentiment scores.

    Reads a CSV with ``clip_filename``, ``sentiment_score`` and
    ``sentiment_label`` columns, groups clips by label ("Negative" /
    "Neutral" / "Positive") and draws, with :func:`random.sample`, at most
    ``negative_samples`` / ``neutral_samples`` / ``positive_samples`` clips
    per class (all of a class when it has fewer).

    Each item is ``(pixel_values, score)``: the feature extractor's output for
    ``clip_len`` frames with the leading batch axis squeezed away, and the
    clip's sentiment score as a float32 scalar tensor.

    Args:
        clip_dir: directory containing the clip video files.
        csv_path: path to the clip -> sentiment mapping CSV.
        feature_extractor: callable used as
            ``feature_extractor(images=frames, return_tensors="pt")`` and
            returning a mapping with a ``"pixel_values"`` tensor.
    """

    def __init__(self, clip_dir, csv_path, feature_extractor):
        self.clip_dir = clip_dir
        self.df = pd.read_csv(csv_path)
        self.feature_extractor = feature_extractor

        # filename -> (sentiment_score, sentiment_label); a repeated filename keeps its last row.
        self.filename2score = {
            name: (score, label)
            for name, score, label in zip(
                self.df["clip_filename"],
                self.df["sentiment_score"],
                self.df["sentiment_label"],
            )
        }

        # Group clip names by label; a label outside these three raises KeyError.
        self.grouped = {"Negative": [], "Neutral": [], "Positive": []}
        for clip_name, (_, label) in self.filename2score.items():
            self.grouped[label].append(clip_name)

        # Class-balanced subsample, drawn in Negative / Neutral / Positive order.
        quotas = (
            ("Negative", negative_samples),
            ("Neutral", neutral_samples),
            ("Positive", positive_samples),
        )
        self.samples = []
        for label, quota in quotas:
            pool = self.grouped[label]
            self.samples += random.sample(pool, min(len(pool), quota))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        clip_name = self.samples[idx]
        clip_path = os.path.join(self.clip_dir, clip_name)

        frames = self._read_rgb_frames(clip_path)
        if not frames:
            # A missing or undecodable file yields no frames; fail with a clear
            # message instead of an IndexError from the padding step below.
            raise RuntimeError(f"Could not decode any frames from {clip_path!r}")

        if len(frames) < clip_len:
            # Pad short clips by repeating their last frame.
            frames += [frames[-1]] * (clip_len - len(frames))
        frames = frames[:clip_len]

        # Expected (clip_len, C, H, W) after the squeeze: VideoMAE pixel_values are
        # laid out as (batch, num_frames, num_channels, height, width) — TODO confirm.
        inputs = self.feature_extractor(images=frames, return_tensors="pt")["pixel_values"].squeeze(0)

        sentiment_score, _ = self.filename2score[clip_name]
        return inputs, torch.tensor(sentiment_score, dtype=torch.float32)

    @staticmethod
    def _read_rgb_frames(clip_path):
        """Decode every frame of ``clip_path`` with OpenCV and return them in RGB order."""
        cap = cv2.VideoCapture(clip_path)
        frames = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame[:, :, ::-1])  # OpenCV decodes BGR; reverse channels to RGB
        finally:
            # Release the capture even if decoding raises.
            cap.release()
        return frames

# ---- LOSS FUNCTION ----
class CenteredWeightedMSELoss(nn.Module):
    """Weighted MSE between predictions and per-class "ideal" sentiment values.

    Each target score is snapped to a class centre before the squared error is
    taken:

    - scores below ``-threshold`` map to ``-extreme``;
    - scores above ``threshold`` map to ``+extreme``;
    - scores in ``[-threshold, threshold]`` map to 0.

    Squared errors on polar (non-neutral) samples are weighted by
    ``polar_weight`` and neutral ones by ``neutral_weight``. The loss is the
    mean of the weighted errors.

    Args:
        threshold: magnitude separating neutral from polar scores (default 0.3).
        extreme: magnitude of the positive/negative class centres (default 3.0).
        polar_weight: weight on errors for polar samples (default 2.0).
        neutral_weight: weight on errors for neutral samples (default 1.0).
    """

    def __init__(self, threshold=0.3, extreme=3.0, polar_weight=2.0, neutral_weight=1.0):
        super().__init__()
        self.threshold = threshold
        self.extreme = extreme
        self.polar_weight = polar_weight
        self.neutral_weight = neutral_weight

    def forward(self, preds, targets):
        negative = targets < -self.threshold
        positive = targets > self.threshold

        # Neutral samples keep the zero centre / neutral weight from the fill;
        # only the polar entries are overwritten.
        ideal = torch.zeros_like(targets)
        ideal[negative] = -self.extreme
        ideal[positive] = self.extreme

        weights = torch.full_like(targets, self.neutral_weight)
        weights[negative | positive] = self.polar_weight

        return (weights * (preds - ideal) ** 2).mean()

# ---- MODEL ----
class SentimentRegressor(nn.Module):
    """Two-layer MLP head that maps a feature vector to a single sentiment score.

    Args:
        feature_dim: size of the input feature vector (last dimension of ``x``).
    """

    def __init__(self, feature_dim):
        super().__init__()
        # The submodule name and layer order fix the state_dict keys
        # (regressor.0.*, regressor.2.*), so saved checkpoints keep loading.
        hidden = nn.Linear(feature_dim, 256)
        activation = nn.ReLU()
        head = nn.Linear(256, 1)
        self.regressor = nn.Sequential(hidden, activation, head)

    def forward(self, x):
        """For ``x`` of shape (batch, feature_dim), return a (batch,) tensor of scores."""
        scores = self.regressor(x)       # (batch, 1)
        return torch.squeeze(scores, 1)  # drop the singleton output axis

# ---- TRAINING LOOP ----
def run_epoch(model, loader, optimizer, is_train=True, *, backbone=None, criterion=None):
    """Run one pass over ``loader``: frozen backbone features -> ``model`` -> loss.

    Clips are encoded by ``backbone`` (default: the module-level ``video_mae``)
    and averaged over dim 1 of ``last_hidden_state``. They are then scored by
    ``model`` and compared to the targets with ``criterion`` (default: the
    module-level ``loss_fn``). When ``is_train`` is true, ``model`` is updated
    with ``optimizer``; otherwise ``optimizer`` is not used.

    Returns:
        ``(avg_loss, preds, labels)``. ``avg_loss`` is the per-sample average of
        the batch losses, or NaN for an empty loader. ``preds`` and ``labels``
        are 1-D numpy arrays in loader order.
    """
    backbone = video_mae if backbone is None else backbone
    criterion = loss_fn if criterion is None else criterion

    model.train(is_train)

    all_preds, all_labels = [], []
    loss_sum, n_samples = 0.0, 0

    for clips, targets in tqdm(loader, leave=False):
        clips = clips.to(device)
        targets = targets.to(device)

        # The backbone is frozen, so never record an autograd graph through it.
        with torch.no_grad():
            features = backbone(clips).last_hidden_state.mean(dim=1)  # global average pooling

        with torch.set_grad_enabled(is_train):
            preds = model(features)
            loss = criterion(preds, targets)

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Weight each batch-mean loss by its batch size so a short final batch
        # does not count as much as a full one.
        batch_n = targets.size(0)
        loss_sum += loss.item() * batch_n
        n_samples += batch_n
        all_preds.extend(preds.detach().cpu().numpy())
        all_labels.extend(targets.detach().cpu().numpy())

    avg_loss = loss_sum / n_samples if n_samples else float("nan")
    return avg_loss, np.array(all_preds), np.array(all_labels)

def evaluate(preds, labels, threshold=0.3):
    """Bucket continuous scores into three sentiment classes and score them.

    A score below ``-threshold`` is "Negative" and one above ``threshold`` is
    "Positive". Anything else, including exactly ``±threshold``, is "Neutral".

    Returns:
        ``(macro_f1, micro_f1, recall, acc)``, where ``recall`` is a per-class
        array ordered Negative, Neutral, Positive.
    """
    classes = ["Negative", "Neutral", "Positive"]

    def to_class(score):
        if score < -threshold:
            return "Negative"
        if score > threshold:
            return "Positive"
        return "Neutral"

    pred_classes = [to_class(p) for p in preds]
    true_classes = [to_class(t) for t in labels]

    # Score every metric against the same fixed class list so the macro average
    # and the recall vector always refer to the same three classes.
    # zero_division=0 is sklearn's default value for ill-defined scores; it only
    # silences the warning.
    macro_f1 = f1_score(true_classes, pred_classes, labels=classes, average="macro", zero_division=0)
    micro_f1 = f1_score(true_classes, pred_classes, labels=classes, average="micro", zero_division=0)
    recall = recall_score(true_classes, pred_classes, labels=classes, average=None, zero_division=0)
    acc = accuracy_score(true_classes, pred_classes)

    return macro_f1, micro_f1, recall, acc

# ---- LOAD EVERYTHING ----
# Seed every RNG that affects clip sampling, the train/val/test split, loader
# shuffling and head initialisation. Without this, re-running the notebook
# (e.g. to resume from a saved checkpoint) draws a different split. Clips the
# checkpoint was trained on could then land in the validation/test sets.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
video_mae = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base").to(device)
# Freeze VideoMAE: eval mode and no gradients; only the regression head trains.
video_mae.eval()
for param in video_mae.parameters():
    param.requires_grad = False

# Class-balanced subsample, drawn with the seeded `random` module.
dataset = VideoClipDataset(
    clip_dir=clip_dir,
    csv_path=mapping_csv,
    feature_extractor=feature_extractor
)

# 80/10/10 split; whatever integer rounding leaves over goes to the test set.
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(seed),
)

loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Size the head from the backbone's hidden size rather than hard-coding 768.
regressor = SentimentRegressor(feature_dim=video_mae.config.hidden_size).to(device)
loss_fn = CenteredWeightedMSELoss()
optimizer = optim.Adam(regressor.parameters(), lr=2e-4)

# ---- TRAIN ----
# Checkpoint of the head with the best validation macro-F1 seen in this run.
best_model_path = "best_regressor_epo_30_bs_16_lr_2e-4.pth"
best_macro_f1 = -np.inf
best_epoch = None
start_epoch = 0
# To resume an earlier run: uncomment, and set start_epoch to the last completed
# epoch + 1. This restores only the head weights; the optimizer state and
# best_macro_f1 start fresh.
# model_path = "/data/home/huixian/Documents/Homeworks/535_project/mosei_code/best_regressor.pth"
# if os.path.exists(model_path):
#     regressor.load_state_dict(torch.load(model_path))
#     start_epoch = 15  # last completed epoch + 1
for epoch in range(start_epoch, num_epochs):
    print(f"\nEpoch {epoch}")

    train_loss, _, _ = run_epoch(regressor, train_loader, optimizer, is_train=True)
    val_loss, val_preds, val_labels = run_epoch(regressor, val_loader, optimizer, is_train=False)

    macro_f1, micro_f1, recall, acc = evaluate(val_preds, val_labels)

    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")
    print(f"Macro-F1: {macro_f1:.4f} | Micro-F1: {micro_f1:.4f} | Acc: {acc:.4f} | Recall: {recall}")

    # Model selection uses validation macro-F1 only; the test split is untouched here.
    if macro_f1 > best_macro_f1:
        best_macro_f1 = macro_f1
        best_epoch = epoch
        torch.save(regressor.state_dict(), best_model_path)
        print(f"✅ Best model saved at epoch {epoch} with Macro-F1={macro_f1:.4f}")

# ---- EVALUATE ON TEST SET ----
# Report test metrics for the selected (best-validation) checkpoint, not for the
# weights the final epoch happened to leave in `regressor`. Only a checkpoint
# written by this run is reloaded, never a stale file from an earlier one.
if best_epoch is not None:
    regressor.load_state_dict(torch.load(best_model_path, map_location=device))
    print(f"\nLoaded best checkpoint from epoch {best_epoch} (val Macro-F1={best_macro_f1:.4f})")
test_loss, test_preds, test_labels = run_epoch(regressor, test_loader, optimizer, is_train=False)
macro_f1, micro_f1, recall, acc = evaluate(test_preds, test_labels)
print("\n----- TEST RESULTS -----")
print(f"Macro-F1: {macro_f1:.4f} | Micro-F1: {micro_f1:.4f} | Acc: {acc:.4f} | Recall: {recall}")





Epoch 0


                                                 

Train Loss: 11.7665
Val Loss: 12.0621
Macro-F1: 0.2997 | Micro-F1: 0.3689 | Acc: 0.3689 | Recall: [0.2125     0.89855072 0.05263158]
✅ Best model saved at epoch 0 with Macro-F1=0.2997

Epoch 1


                                                 

Train Loss: 11.4028
Val Loss: 11.6766
Macro-F1: 0.3971 | Micro-F1: 0.4133 | Acc: 0.4133 | Recall: [0.45       0.61594203 0.19078947]
✅ Best model saved at epoch 1 with Macro-F1=0.3971

Epoch 2


                                                 

Train Loss: 11.1734
Val Loss: 11.4754
Macro-F1: 0.4296 | Micro-F1: 0.4311 | Acc: 0.4311 | Recall: [0.30625 0.5     0.5    ]
✅ Best model saved at epoch 2 with Macro-F1=0.4296

Epoch 3


                                                 

Train Loss: 10.9261
Val Loss: 11.2209
Macro-F1: 0.4289 | Micro-F1: 0.4267 | Acc: 0.4267 | Recall: [0.43125    0.43478261 0.41447368]

Epoch 4


                                                 

Train Loss: 10.7432
Val Loss: 11.1692
Macro-F1: 0.4206 | Micro-F1: 0.4200 | Acc: 0.4200 | Recall: [0.40625    0.39130435 0.46052632]

Epoch 5


                                                 

Train Loss: 10.5980
Val Loss: 11.0889
Macro-F1: 0.4313 | Micro-F1: 0.4333 | Acc: 0.4333 | Recall: [0.43125    0.35507246 0.50657895]
✅ Best model saved at epoch 5 with Macro-F1=0.4313

Epoch 6


                                                 

Train Loss: 10.4883
Val Loss: 11.1947
Macro-F1: 0.4284 | Micro-F1: 0.4400 | Acc: 0.4400 | Recall: [0.35625    0.3115942  0.64473684]

Epoch 7


                                                 

Train Loss: 10.3804
Val Loss: 11.1523
Macro-F1: 0.4172 | Micro-F1: 0.4289 | Acc: 0.4289 | Recall: [0.60625    0.29710145 0.36184211]

Epoch 8


                                                 

Train Loss: 10.2244
Val Loss: 11.0572
Macro-F1: 0.4284 | Micro-F1: 0.4356 | Acc: 0.4356 | Recall: [0.54375    0.30434783 0.44078947]

Epoch 9


                                                 

Train Loss: 10.1068
Val Loss: 10.9891
Macro-F1: 0.4440 | Micro-F1: 0.4511 | Acc: 0.4511 | Recall: [0.4875     0.3115942  0.53947368]
✅ Best model saved at epoch 9 with Macro-F1=0.4440

Epoch 10


                                                 

Train Loss: 10.0732
Val Loss: 10.8857
Macro-F1: 0.4345 | Micro-F1: 0.4422 | Acc: 0.4422 | Recall: [0.425      0.3115942  0.57894737]

Epoch 11


                                                 

Train Loss: 9.9403
Val Loss: 11.3333
Macro-F1: 0.4368 | Micro-F1: 0.4600 | Acc: 0.4600 | Recall: [0.7125     0.26086957 0.375     ]

Epoch 12


                                                 

Train Loss: 9.8569
Val Loss: 11.0528
Macro-F1: 0.4312 | Micro-F1: 0.4444 | Acc: 0.4444 | Recall: [0.5875     0.26086957 0.46052632]

Epoch 13


                                                 

Train Loss: 9.8074
Val Loss: 10.8989
Macro-F1: 0.4116 | Micro-F1: 0.4289 | Acc: 0.4289 | Recall: [0.425     0.2173913 0.625    ]

Epoch 14


                                                 

Train Loss: 9.6556
Val Loss: 10.7950
Macro-F1: 0.4327 | Micro-F1: 0.4422 | Acc: 0.4422 | Recall: [0.475      0.2826087  0.55263158]

Epoch 15


                                                 

Train Loss: 9.5826
Val Loss: 10.8635
Macro-F1: 0.4148 | Micro-F1: 0.4311 | Acc: 0.4311 | Recall: [0.44375    0.22463768 0.60526316]

Epoch 16


                                                 

Train Loss: 9.4668
Val Loss: 10.7020
Macro-F1: 0.4215 | Micro-F1: 0.4378 | Acc: 0.4378 | Recall: [0.45625    0.22463768 0.61184211]

Epoch 17


                                                 

Train Loss: 9.4391
Val Loss: 10.7567
Macro-F1: 0.4309 | Micro-F1: 0.4444 | Acc: 0.4444 | Recall: [0.575      0.24637681 0.48684211]

Epoch 18


                                                 

Train Loss: 9.3177
Val Loss: 10.7192
Macro-F1: 0.4328 | Micro-F1: 0.4444 | Acc: 0.4444 | Recall: [0.5375     0.26086957 0.51315789]

Epoch 19


                                                 

Train Loss: 9.2719
Val Loss: 10.7338
Macro-F1: 0.4511 | Micro-F1: 0.4622 | Acc: 0.4622 | Recall: [0.5125     0.2826087  0.57236842]
✅ Best model saved at epoch 19 with Macro-F1=0.4511

Epoch 20


                                                 

Train Loss: 9.1568
Val Loss: 10.6500
Macro-F1: 0.4354 | Micro-F1: 0.4556 | Acc: 0.4556 | Recall: [0.4625     0.22463768 0.65789474]

Epoch 21


                                                 

Train Loss: 9.1017
Val Loss: 10.5852
Macro-F1: 0.4368 | Micro-F1: 0.4556 | Acc: 0.4556 | Recall: [0.46875    0.23188406 0.64473684]

Epoch 22


                                                 

Train Loss: 8.9937
Val Loss: 10.7506
Macro-F1: 0.4546 | Micro-F1: 0.4733 | Acc: 0.4733 | Recall: [0.65       0.25362319 0.48684211]
✅ Best model saved at epoch 22 with Macro-F1=0.4546

Epoch 23


                                                 

Train Loss: 8.9146
Val Loss: 10.8559
Macro-F1: 0.4648 | Micro-F1: 0.4822 | Acc: 0.4822 | Recall: [0.66875    0.2826087  0.46710526]
✅ Best model saved at epoch 23 with Macro-F1=0.4648

Epoch 24


                                                 

Train Loss: 8.8437
Val Loss: 10.8135
Macro-F1: 0.4470 | Micro-F1: 0.4667 | Acc: 0.4667 | Recall: [0.6625     0.24637681 0.46052632]

Epoch 25


                                                

KeyboardInterrupt: 