In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from tqdm import tqdm
import random

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
class TVSumDataset(Dataset):
    """Per-video dataset pairing frame features with frame-level scores.

    Items are loaded lazily: features come from ``<feature_dir>/<video_id>.npy``
    and scores from the row of ``ratings_df`` whose ``video_name`` equals the
    video id.
    """

    def __init__(self, video_ids, feature_dir, ratings_df):
        """Store the lookup inputs; nothing is read from disk here.

        Args:
            video_ids: Sequence of video ids (also the ``.npy`` file stems).
            feature_dir: Directory holding one ``<video_id>.npy`` per video.
            ratings_df: DataFrame with a ``video_name`` column whose third
                column is a comma-separated string of per-frame scores.
        """
        self.video_ids = video_ids
        self.feature_dir = feature_dir
        self.ratings_df = ratings_df

    def __len__(self):
        """Return the number of videos in this dataset."""
        return len(self.video_ids)

    def __getitem__(self, idx):
        """Load the features and scores of the video at position ``idx``.

        The shorter of the two sequences is zero-padded at the end so that
        features and scores have the same length.

        Returns:
            Tuple ``(features, scores, video_id)`` where ``features`` and
            ``scores`` are float32 tensors with equal first dimension.

        Raises:
            IndexError: If ``ratings_df`` has no row for the video id.
        """
        video_id = self.video_ids[idx]
        # Feature array is expected to be 2-D, [T, D] (D = 2048 per the model default).
        features = np.load(os.path.join(self.feature_dir, f"{video_id}.npy"))

        # First row matching this video; the score string sits in the third column.
        matches = self.ratings_df[self.ratings_df['video_name'] == video_id]
        raw_scores = matches.iloc[0].iloc[2]
        scores = np.array([float(tok) for tok in raw_scores.split(',')], dtype=np.float32)

        # Positive shortfall: scores are shorter; negative: features are shorter.
        shortfall = features.shape[0] - len(scores)
        if shortfall > 0:
            scores = np.pad(scores, (0, shortfall), mode='constant')
        elif shortfall < 0:
            filler = np.zeros((-shortfall, features.shape[1]), dtype=np.float32)
            features = np.vstack((features, filler))

        feature_tensor = torch.tensor(features, dtype=torch.float32)
        score_tensor = torch.tensor(scores, dtype=torch.float32)
        return feature_tensor, score_tensor, video_id


In [None]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Collate variable-length samples into zero-padded batch tensors.

    Args:
        batch: List of ``(features, scores, video_id)`` samples.

    Returns:
        Tuple ``(features, scores, lengths, video_ids)``: features and scores
        are zero-padded to the longest sequence in the batch (batch first),
        ``lengths`` is a tensor of the original feature lengths, and
        ``video_ids`` is a tuple of the sample ids.
    """
    feature_seqs, score_seqs, video_ids = zip(*batch)
    seq_lengths = torch.tensor([len(seq) for seq in feature_seqs])
    return (
        pad_sequence(feature_seqs, batch_first=True),
        pad_sequence(score_seqs, batch_first=True),
        seq_lengths,
        video_ids,
    )


In [None]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class BiLSTMModel(nn.Module):
    """Bidirectional LSTM that regresses one importance score per frame."""

    def __init__(self, input_size=2048, hidden_size=256):
        """Build the recurrent encoder and the per-frame scoring head.

        Args:
            input_size: Dimensionality of each frame feature vector.
            hidden_size: Hidden size of each LSTM direction.
        """
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.bilstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, 1)

    def forward(self, x, lengths):
        """Score every frame of a zero-padded batch.

        Args:
            x: Padded features of shape ``[B, T, input_size]``.
            lengths: Valid length of each sequence (tensor or sequence of ints).

        Returns:
            Tensor of shape ``[B, T]`` with one score per frame; entries past
            a sequence's valid length are padding and should be masked out.
        """
        # pack_padded_sequence requires the lengths on the CPU.
        cpu_lengths = torch.as_tensor(lengths).cpu()
        packed = pack_padded_sequence(x, cpu_lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.bilstm(packed)
        # total_length keeps the time dimension equal to the input's, even when
        # the longest sequence in the batch is shorter than the padded width.
        out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))

        return self.fc(out).squeeze(-1)  # [B, T]


In [None]:
def masked_mse_loss(pred, target, lengths):
    """Mean over the batch of each sequence's MSE on its valid prefix.

    Args:
        pred: Predicted scores, indexed as ``[sequence, frame]``.
        target: Target scores with the same layout as ``pred``.
        lengths: Valid length of each sequence; frames past it are ignored.

    Returns:
        The per-sequence MSE values averaged over ``len(lengths)``.
    """
    per_sequence = [
        nn.functional.mse_loss(pred[row, :lengths[row]], target[row, :lengths[row]])
        for row in range(len(lengths))
    ]
    return sum(per_sequence, 0.0) / len(lengths)


In [None]:
def get_fold_split(fold_csv_path, test_fold):
    """Split video ids into train and test lists for one held-out fold.

    Args:
        fold_csv_path: CSV file with ``video_id`` and ``fold`` columns.
        test_fold: Value of ``fold`` whose rows form the test split.

    Returns:
        Tuple ``(train_ids, test_ids)`` of video-id lists in file order.
    """
    folds = pd.read_csv(fold_csv_path)
    in_train = folds['fold'] != test_fold
    in_test = folds['fold'] == test_fold
    return (
        folds.loc[in_train, 'video_id'].tolist(),
        folds.loc[in_test, 'video_id'].tolist(),
    )


In [None]:

def knapsack_importance_selection(scores, budget):
    """Pick the ``budget`` highest-scoring positions.

    Despite the name, this is a plain top-k selection: every item has unit
    cost, so no knapsack optimisation is performed.

    Args:
        scores: 1-D array-like of importance scores.
        budget: Maximum number of positions to select. Zero or negative
            selects nothing; a fractional budget is rounded up, which matches
            selecting while ``len(selected) < budget``.

    Returns:
        The selected positions as plain Python ints in ascending order
        (at most ``len(scores)`` of them).
    """
    ranked = np.argsort(scores)[::-1]  # highest scores first
    keep = max(0, int(np.ceil(budget)))
    # Plain ints keep printed summaries readable (no ``np.int64(...)`` reprs).
    return sorted(int(i) for i in ranked[:keep])

def get_binary_summary(selected_idxs, total_len):
    """Turn selected positions into a 0/1 indicator vector.

    Args:
        selected_idxs: Positions to mark as selected.
        total_len: Length of the returned vector.

    Returns:
        Float64 array of length ``total_len`` holding 1 at the selected
        positions and 0 everywhere else.
    """
    chosen = np.zeros(total_len, dtype=bool)
    chosen[selected_idxs] = True
    return chosen.astype(np.float64)


In [None]:

def compute_precision_recall_f1(pred, gt):
    """Compute precision, recall and F1 between two 0/1 indicator arrays.

    Args:
        pred: Predicted summary (1 = selected, 0 = not selected).
        gt: Ground-truth summary with the same layout.

    Returns:
        Tuple ``(precision, recall, f1)``. A small epsilon in every
        denominator makes the metrics 0 instead of undefined when a count
        is empty.
    """
    eps = 1e-8
    pred_on, pred_off = pred == 1, pred == 0
    gt_on, gt_off = gt == 1, gt == 0

    true_pos = np.sum(pred_on & gt_on)
    false_pos = np.sum(pred_on & gt_off)
    false_neg = np.sum(pred_off & gt_on)

    precision = true_pos / (true_pos + false_pos + eps)
    recall = true_pos / (true_pos + false_neg + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision, recall, f1


In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive (Colab only); the CSV, feature and
# checkpoint paths used later in this notebook all live under this mount point.
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:

from torch.utils.data import DataLoader
import pandas as pd

# Input locations on the mounted Drive
ratings_csv_path = "/content/drive/MyDrive/tvsum_average_ratings.csv"
fold_csv_path = "/content/drive/MyDrive/tvsum_stratified_5fold_split.csv"
features_dir = "/content/drive/MyDrive/GN_V3_features"

# Per-video ratings and the fold assignment of every video
ratings_df = pd.read_csv(ratings_csv_path)
fold_df = pd.read_csv(fold_csv_path)

# Hold out one fold for testing; every other fold is used for training
test_fold = 1
is_train_row = fold_df['fold'] != test_fold
is_test_row = fold_df['fold'] == test_fold
train_ids = fold_df.loc[is_train_row, 'video_id'].tolist()
test_ids = fold_df.loc[is_test_row, 'video_id'].tolist()

# Datasets
train_dataset = TVSumDataset(train_ids, features_dir, ratings_df)
test_dataset = TVSumDataset(test_ids, features_dir, ratings_df)

# DataLoaders: shuffled mini-batches for training, one video at a time for testing
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)


In [None]:

def evaluate_model(model, dataloader, ratings_df, device, set_name="Test", summary_ratio=0.15):
    """Evaluate summary quality (precision / recall / F1) for every video.

    For each video the top ``summary_ratio`` fraction of frames is selected
    from the predicted scores and from the ground-truth scores, and the two
    binary selections are compared. Every item of every batch is evaluated,
    so the loader may use any batch size.

    Args:
        model: Network called as ``model(feats, lengths)`` returning ``[B, T]`` scores.
        dataloader: Yields ``(feats, scores, lengths, vids)`` batches.
        ratings_df: DataFrame with a ``video_name`` column whose third column
            is a comma-separated string of per-frame ground-truth scores.
        device: Device on which the model is run.
        set_name: Label used as a prefix in the printed report.
        summary_ratio: Fraction of each video's frames kept in the summary.

    Returns:
        The average F1 over all evaluated videos, or ``nan`` when the loader
        yields no videos.
    """
    model.eval()
    total_f1 = 0.0
    total_videos = 0

    with torch.no_grad():
        for feats, _, lengths, vids in dataloader:
            feats = feats.to(device)
            lengths = lengths.to(device)
            batch_preds = model(feats, lengths).cpu().numpy()  # [B, T]

            # Score each video in the batch, not just the first one.
            for preds, vid, length in zip(batch_preds, vids, lengths.tolist()):
                # Predicted summary: top-k frames by predicted score
                k = int(length * summary_ratio)
                pred_summary_idx = knapsack_importance_selection(preds[:length], k)
                pred_bin = get_binary_summary(pred_summary_idx, length)

                # Ground-truth scores, aligned to the (possibly padded) video length
                gt_str = ratings_df[ratings_df['video_name'] == vid].iloc[0, 2]
                gt_scores = np.array([float(s) for s in gt_str.split(',')], dtype=np.float32)

                if len(gt_scores) > length:
                    gt_scores = gt_scores[:length]
                elif len(gt_scores) < length:
                    gt_scores = np.pad(gt_scores, (0, length - len(gt_scores)), mode='constant')

                gt_summary_idx = knapsack_importance_selection(gt_scores, k)
                gt_bin = get_binary_summary(gt_summary_idx, length)

                # Compare
                precision, recall, f1 = compute_precision_recall_f1(pred_bin, gt_bin)

                print(f"[{set_name}] {vid}: F1={f1:.4f}, Precision={precision:.4f}, Recall={recall:.4f}")
                print(f"[{set_name}] Selected frame indices: {pred_summary_idx}")
                print("-" * 40)

                total_f1 += f1
                total_videos += 1

    # An empty loader would otherwise fail with a division by zero.
    if total_videos == 0:
        print(f"[{set_name}] No videos to evaluate.")
        return float("nan")

    avg_f1 = total_f1 / total_videos
    print(f"[{set_name}] Avg F1 Score across {total_videos} videos: {avg_f1:.4f}")
    return avg_f1


In [None]:
model = BiLSTMModel().to(device)
optimizer = Adam(model.parameters(), lr=1e-4)
num_epochs = 10
model_save_path = "/content/drive/MyDrive/bilstm_knapsack_tvsum.pth"

# Resume from an existing checkpoint once, before training starts. Reloading
# inside the epoch loop only re-read the weights that had just been saved.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
if os.path.exists(model_save_path):
    model.load_state_dict(torch.load(model_save_path, map_location=device))

# Train-set evaluation uses its own loader: one video per batch, fixed order,
# so every training video is scored exactly once.
train_eval_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    model.train()
    total_loss = 0

    for feats, scores, lengths, _ in tqdm(train_loader):
        feats, scores, lengths = feats.to(device), scores.to(device), lengths.to(device)

        preds = model(feats, lengths)
        loss = masked_mse_loss(preds, scores, lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Train Loss: {avg_loss:.4f}")

    # Checkpoint after every epoch so an interrupted run can be resumed.
    torch.save(model.state_dict(), model_save_path)

    # Evaluate both sets
    print("Evaluating on Train Set:")
    evaluate_model(model, train_eval_loader, ratings_df, device, set_name="Train")

    print("Evaluating on Test Set:")
    evaluate_model(model, test_loader, ratings_df, device, set_name="Test")



Epoch 1/10


100%|██████████| 10/10 [02:06<00:00, 12.64s/it]


Train Loss: 0.2037
Evaluating on Train Set:
[Train] uGu_10sucQo: F1=0.4609, Precision=0.4609, Recall=0.4609
[Train] Selected frame indices: [np.int64(354), np.int64(355), np.int64(356), np.int64(357), np.int64(358), np.int64(715), np.int64(716), np.int64(719), np.int64(720), np.int64(721), np.int64(722), np.int64(723), np.int64(724), np.int64(725), np.int64(726), np.int64(731), np.int64(732), np.int64(733), np.int64(734), np.int64(735), np.int64(736), np.int64(737), np.int64(738), np.int64(739), np.int64(740), np.int64(741), np.int64(742), np.int64(743), np.int64(760), np.int64(854), np.int64(856), np.int64(857), np.int64(858), np.int64(859), np.int64(860), np.int64(861), np.int64(863), np.int64(866), np.int64(867), np.int64(875), np.int64(880), np.int64(881), np.int64(882), np.int64(883), np.int64(884), np.int64(885), np.int64(886), np.int64(887), np.int64(1078), np.int64(1079), np.int64(1080), np.int64(1081), np.int64(1082), np.int64(1083), np.int64(1084), np.int64(1085), np.int64(10

100%|██████████| 10/10 [00:15<00:00,  1.57s/it]


Train Loss: 0.1785
Evaluating on Train Set:
[Train] J0nA4VgnoCo: F1=0.5428, Precision=0.5428, Recall=0.5428
[Train] Selected frame indices: [np.int64(598), np.int64(599), np.int64(600), np.int64(601), np.int64(602), np.int64(603), np.int64(604), np.int64(605), np.int64(606), np.int64(607), np.int64(608), np.int64(609), np.int64(610), np.int64(611), np.int64(612), np.int64(613), np.int64(614), np.int64(615), np.int64(616), np.int64(617), np.int64(618), np.int64(619), np.int64(620), np.int64(621), np.int64(622), np.int64(623), np.int64(624), np.int64(625), np.int64(626), np.int64(627), np.int64(646), np.int64(647), np.int64(648), np.int64(649), np.int64(650), np.int64(651), np.int64(652), np.int64(653), np.int64(654), np.int64(655), np.int64(656), np.int64(657), np.int64(658), np.int64(659), np.int64(660), np.int64(724), np.int64(725), np.int64(726), np.int64(727), np.int64(728), np.int64(729), np.int64(730), np.int64(731), np.int64(732), np.int64(733), np.int64(734), np.int64(735), np.i

100%|██████████| 10/10 [00:14<00:00,  1.48s/it]


Train Loss: 0.1709
Evaluating on Train Set:
[Train] WG0MBPpPC6I: F1=0.5070, Precision=0.5070, Recall=0.5070
[Train] Selected frame indices: [np.int64(109), np.int64(110), np.int64(111), np.int64(112), np.int64(113), np.int64(114), np.int64(115), np.int64(116), np.int64(117), np.int64(118), np.int64(119), np.int64(120), np.int64(121), np.int64(122), np.int64(123), np.int64(124), np.int64(125), np.int64(126), np.int64(127), np.int64(128), np.int64(129), np.int64(130), np.int64(131), np.int64(132), np.int64(133), np.int64(134), np.int64(135), np.int64(136), np.int64(137), np.int64(138), np.int64(139), np.int64(140), np.int64(141), np.int64(142), np.int64(143), np.int64(144), np.int64(145), np.int64(146), np.int64(147), np.int64(148), np.int64(149), np.int64(150), np.int64(155), np.int64(156), np.int64(157), np.int64(158), np.int64(159), np.int64(160), np.int64(161), np.int64(162), np.int64(163), np.int64(164), np.int64(165), np.int64(166), np.int64(167), np.int64(168), np.int64(169), np.i

100%|██████████| 10/10 [00:16<00:00,  1.65s/it]


Train Loss: 0.1611
Evaluating on Train Set:
[Train] sTEELN-vY30: F1=0.6448, Precision=0.6448, Recall=0.6448
[Train] Selected frame indices: [np.int64(14), np.int64(15), np.int64(16), np.int64(17), np.int64(18), np.int64(19), np.int64(21), np.int64(22), np.int64(23), np.int64(24), np.int64(25), np.int64(26), np.int64(27), np.int64(28), np.int64(29), np.int64(30), np.int64(31), np.int64(32), np.int64(33), np.int64(34), np.int64(35), np.int64(36), np.int64(37), np.int64(38), np.int64(39), np.int64(40), np.int64(41), np.int64(42), np.int64(43), np.int64(44), np.int64(45), np.int64(46), np.int64(47), np.int64(48), np.int64(49), np.int64(50), np.int64(51), np.int64(52), np.int64(53), np.int64(54), np.int64(55), np.int64(56), np.int64(57), np.int64(58), np.int64(59), np.int64(60), np.int64(61), np.int64(62), np.int64(63), np.int64(64), np.int64(65), np.int64(66), np.int64(67), np.int64(68), np.int64(69), np.int64(70), np.int64(71), np.int64(72), np.int64(73), np.int64(74), np.int64(75), np.in

100%|██████████| 10/10 [00:15<00:00,  1.52s/it]


Train Loss: 0.1549
Evaluating on Train Set:
[Train] kLxoNp-UchI: F1=0.3356, Precision=0.3356, Recall=0.3356
[Train] Selected frame indices: [np.int64(252), np.int64(253), np.int64(254), np.int64(255), np.int64(256), np.int64(257), np.int64(258), np.int64(259), np.int64(260), np.int64(261), np.int64(262), np.int64(263), np.int64(264), np.int64(265), np.int64(266), np.int64(267), np.int64(268), np.int64(269), np.int64(270), np.int64(271), np.int64(272), np.int64(273), np.int64(274), np.int64(275), np.int64(276), np.int64(277), np.int64(278), np.int64(279), np.int64(280), np.int64(281), np.int64(282), np.int64(283), np.int64(284), np.int64(285), np.int64(286), np.int64(287), np.int64(288), np.int64(289), np.int64(290), np.int64(291), np.int64(292), np.int64(293), np.int64(294), np.int64(295), np.int64(296), np.int64(297), np.int64(298), np.int64(299), np.int64(300), np.int64(301), np.int64(302), np.int64(303), np.int64(304), np.int64(305), np.int64(306), np.int64(307), np.int64(308), np.i

100%|██████████| 10/10 [00:15<00:00,  1.55s/it]


Train Loss: 0.1469
Evaluating on Train Set:
[Train] XzYM3PfTM4w: F1=0.4168, Precision=0.4168, Recall=0.4168
[Train] Selected frame indices: [np.int64(593), np.int64(594), np.int64(595), np.int64(596), np.int64(597), np.int64(598), np.int64(599), np.int64(600), np.int64(601), np.int64(602), np.int64(603), np.int64(604), np.int64(605), np.int64(606), np.int64(1138), np.int64(1139), np.int64(1140), np.int64(1142), np.int64(1143), np.int64(1155), np.int64(1156), np.int64(1157), np.int64(1158), np.int64(1159), np.int64(1160), np.int64(1161), np.int64(1162), np.int64(1163), np.int64(1164), np.int64(1165), np.int64(1166), np.int64(1167), np.int64(1168), np.int64(1169), np.int64(1170), np.int64(1171), np.int64(1172), np.int64(1173), np.int64(1174), np.int64(1175), np.int64(1176), np.int64(1177), np.int64(1178), np.int64(1179), np.int64(1180), np.int64(1181), np.int64(1182), np.int64(1183), np.int64(1184), np.int64(1185), np.int64(1186), np.int64(1187), np.int64(1188), np.int64(1189), np.int64(

100%|██████████| 10/10 [00:15<00:00,  1.52s/it]


Train Loss: 0.1405
Evaluating on Train Set:
[Train] 98MoyGZKHXc: F1=0.6102, Precision=0.6102, Recall=0.6102
[Train] Selected frame indices: [np.int64(593), np.int64(594), np.int64(1605), np.int64(1606), np.int64(1607), np.int64(1608), np.int64(1609), np.int64(1610), np.int64(1611), np.int64(1612), np.int64(1613), np.int64(1614), np.int64(1615), np.int64(1616), np.int64(1617), np.int64(1618), np.int64(1619), np.int64(1620), np.int64(1621), np.int64(1622), np.int64(1623), np.int64(1624), np.int64(1625), np.int64(1626), np.int64(1627), np.int64(1628), np.int64(1629), np.int64(1630), np.int64(1631), np.int64(1632), np.int64(1633), np.int64(1634), np.int64(1635), np.int64(1636), np.int64(1637), np.int64(1638), np.int64(1639), np.int64(1640), np.int64(1641), np.int64(1642), np.int64(1643), np.int64(1644), np.int64(1645), np.int64(1646), np.int64(1647), np.int64(1648), np.int64(1649), np.int64(1650), np.int64(1651), np.int64(1652), np.int64(1653), np.int64(1654), np.int64(1655), np.int64(1656

100%|██████████| 10/10 [00:15<00:00,  1.53s/it]


Train Loss: 0.1335
Evaluating on Train Set:
[Train] 91IHQYk1IQM: F1=0.3367, Precision=0.3367, Recall=0.3367
[Train] Selected frame indices: [np.int64(225), np.int64(226), np.int64(268), np.int64(270), np.int64(271), np.int64(272), np.int64(273), np.int64(274), np.int64(275), np.int64(276), np.int64(277), np.int64(278), np.int64(279), np.int64(280), np.int64(281), np.int64(282), np.int64(283), np.int64(284), np.int64(285), np.int64(286), np.int64(287), np.int64(288), np.int64(289), np.int64(290), np.int64(291), np.int64(292), np.int64(483), np.int64(484), np.int64(485), np.int64(486), np.int64(488), np.int64(489), np.int64(491), np.int64(498), np.int64(499), np.int64(500), np.int64(507), np.int64(513), np.int64(514), np.int64(515), np.int64(907), np.int64(911), np.int64(912), np.int64(915), np.int64(916), np.int64(917), np.int64(918), np.int64(919), np.int64(920), np.int64(921), np.int64(922), np.int64(923), np.int64(924), np.int64(925), np.int64(926), np.int64(927), np.int64(928), np.i

100%|██████████| 10/10 [00:15<00:00,  1.52s/it]


Train Loss: 0.1283
Evaluating on Train Set:
[Train] AwmHb44_ouw: F1=0.3115, Precision=0.3115, Recall=0.3115
[Train] Selected frame indices: [np.int64(2845), np.int64(2846), np.int64(2847), np.int64(2930), np.int64(2931), np.int64(2932), np.int64(2933), np.int64(2934), np.int64(2935), np.int64(2936), np.int64(2937), np.int64(2938), np.int64(2939), np.int64(2940), np.int64(2941), np.int64(2942), np.int64(2972), np.int64(2973), np.int64(2974), np.int64(2975), np.int64(2976), np.int64(2977), np.int64(2978), np.int64(2979), np.int64(2980), np.int64(2981), np.int64(2982), np.int64(2983), np.int64(2984), np.int64(2985), np.int64(2986), np.int64(2987), np.int64(2988), np.int64(2989), np.int64(2990), np.int64(2991), np.int64(2992), np.int64(2993), np.int64(2994), np.int64(2995), np.int64(2996), np.int64(2997), np.int64(2998), np.int64(2999), np.int64(3000), np.int64(3001), np.int64(3095), np.int64(3098), np.int64(3102), np.int64(3103), np.int64(3104), np.int64(3105), np.int64(3106), np.int64(31

100%|██████████| 10/10 [00:15<00:00,  1.52s/it]


Train Loss: 0.1192
Evaluating on Train Set:
[Train] cjibtmSLxQ4: F1=0.4845, Precision=0.4845, Recall=0.4845
[Train] Selected frame indices: [np.int64(373), np.int64(374), np.int64(375), np.int64(376), np.int64(377), np.int64(378), np.int64(379), np.int64(393), np.int64(472), np.int64(480), np.int64(500), np.int64(699), np.int64(700), np.int64(701), np.int64(702), np.int64(703), np.int64(721), np.int64(725), np.int64(726), np.int64(727), np.int64(728), np.int64(729), np.int64(730), np.int64(731), np.int64(732), np.int64(733), np.int64(734), np.int64(735), np.int64(736), np.int64(737), np.int64(738), np.int64(739), np.int64(740), np.int64(741), np.int64(742), np.int64(743), np.int64(888), np.int64(889), np.int64(890), np.int64(992), np.int64(994), np.int64(995), np.int64(996), np.int64(997), np.int64(998), np.int64(999), np.int64(1000), np.int64(1007), np.int64(1008), np.int64(1009), np.int64(1010), np.int64(1047), np.int64(1050), np.int64(1051), np.int64(1066), np.int64(1168), np.int64(

In [None]:
# Show the ratings CSV's column names; the per-frame score string is read
# positionally (third column) by the dataset and evaluation code.
print(ratings_df.columns)


Index(['video_name', 'video_type', 'average_rating'], dtype='object')
