In [3]:
import datetime as dt

In [4]:
video_post_times = {
    "7BCojznmtRE.csv": dt.datetime(2020, 9, 5, 6, 45, 11, tzinfo=dt.timezone.utc),
    "cIY95KCnnNk.csv": dt.datetime(2022, 2, 26, 22, 0, 14, tzinfo=dt.timezone.utc),
    "CrtuA5HWFoU.csv": dt.datetime(2020, 10, 20, 10, 30, 11, tzinfo=dt.timezone.utc),
    "hDkuUZ3F1GU.csv": dt.datetime(2020, 11, 26, 23, 45, 10, tzinfo=dt.timezone.utc),
    "hgdSJdeGF_0.csv": dt.datetime(2020, 6, 19, 5, 35, 49, tzinfo=dt.timezone.utc),
    "qqOxkuO3ip0.csv": dt.datetime(2021, 7, 8, 22, 45, 11, tzinfo=dt.timezone.utc),
    "RJ0jdO5ZfU4.csv": dt.datetime(2021, 4, 23, 19, 00, 12, tzinfo=dt.timezone.utc),
    "RTXS4MMngnA.csv": dt.datetime(2020, 5, 29, 4, 38, 30, tzinfo=dt.timezone.utc),
    "tylNqtyj0gs.csv": dt.datetime(2020, 8, 7, 7, 45, 12, tzinfo=dt.timezone.utc),
    "Y7t5B69G0Dw.csv": dt.datetime(2020, 7, 20, 11, 40, 34, tzinfo=dt.timezone.utc)
}

In [None]:
from pathlib import Path
import pandas as pd
import glob
import datetime as dt

for filepath in glob.glob("*.csv"):
  if 'annotated' in filepath:
    continue

  base_dt = video_post_times[filepath]

  df = pd.read_csv(filepath)
  if 'timestamp' not in df.columns:
      raise ValueError("CSV must have a 'timestamp' column (seconds).")

  df['timestamp'] = pd.to_numeric(df['timestamp'], errors='coerce')
  if df['timestamp'].isna().any():
      raise ValueError("Found non-numeric timestamps; clean or convert them first.")

  df['event_dt'] = pd.to_datetime(base_dt) + pd.to_timedelta(df['timestamp'], unit='s')

  # Get weekday and hour (0–23) in UTC
  df['day_of_week_utc'] = df['event_dt'].dt.weekday
  df['hour_utc'] = df['event_dt'].dt.hour

  # Optionally drop intermediate column
  df = df.drop(columns=['event_dt'])

  output_path = Path(filepath).with_name(Path(filepath).stem + "_annotated.csv")
  df.to_csv(output_path, index=False)

In [None]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import pickle

class Config:
    raw_csv = "dataset.csv"
    processed_csv = "dataset_processed.csv"
    scaler_path = "scaler.pkl"

    text_col = "comment"
    timestamp_col = "timestamp"
    day_col = "day"
    hour_col = "hour"
    label_col = "viral"

cfg = Config()

def preprocess_delta(series: pd.Series) -> np.ndarray:
    return np.sign(series.astype(float).values) * np.log1p(np.abs(series.astype(float).values))

def encode_day_of_week(series: pd.Series) -> np.ndarray:
    d = series.astype(float).values
    sin = np.sin(2 * np.pi * d / 7)
    cos = np.cos(2 * np.pi * d / 7)
    return np.stack([sin, cos], axis=1)

def encode_hour_of_day(series: pd.Series) -> np.ndarray:
    h = series.astype(float).values
    sin = np.sin(2 * np.pi * h / 24)
    cos = np.cos(2 * np.pi * h / 24)
    return np.stack([sin, cos], axis=1)

# Load raw data
df = pd.read_csv(cfg.raw_csv)
df = df.dropna(subset=[cfg.text_col, cfg.timestamp_col, cfg.day_col, cfg.hour_col, cfg.label_col])

# Preprocess timestamp
delta_raw = preprocess_delta(df[cfg.timestamp_col])
scaler = StandardScaler().fit(delta_raw.reshape(-1, 1))
df['timestamp_scaled'] = scaler.transform(delta_raw.reshape(-1, 1)).reshape(-1)

# Cyclic encodings
day_encoded = encode_day_of_week(df[cfg.day_col])
df['day_sin'] = day_encoded[:, 0]
df['day_cos'] = day_encoded[:, 1]

hour_encoded = encode_hour_of_day(df[cfg.hour_col])
df['hour_sin'] = hour_encoded[:, 0]
df['hour_cos'] = hour_encoded[:, 1]

df = df[['comment',
         'timestamp_scaled', 'day_sin', 'day_cos', 'hour_sin', 'hour_cos',
         'viral']]

# Save processed dataset
df.to_csv(cfg.processed_csv, index=False)
print(f"Saved preprocessed data to {cfg.processed_csv}")
print(f"Saved scaler to {cfg.scaler_path}")

Saved preprocessed data to dataset_processed.csv
Saved scaler to scaler.pkl


In [6]:
pip install --upgrade transformers

Collecting transformers
  Downloading transformers-4.57.3-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
Downloading transformers-4.57.3-py3-none-any.whl (12.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m54.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.57.2
    Uninstalling transformers-4.57.2:
      Successfully uninstalled transformers-4.57.2
Successfully installed transformers-4.57.3


In [3]:
# BASE TRANSFORMER

import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from transformers import (
    DistilBertModel,
    DistilBertTokenizerFast,
    DataCollatorWithPadding,
)
from transformers.optimization import get_linear_schedule_with_warmup
from torch.optim import AdamW
import json

class Config:
    data_csv = "dataset_processed.csv"  # preprocessed CSV with all features
    text_col = "comment"

    # Preprocessed numeric feature columns
    numeric_features = ['timestamp_scaled', 'day_sin', 'day_cos', 'hour_sin', 'hour_cos']
    label_col = "viral"

    model_name = "distilbert-base-uncased"
    max_length = 256
    batch_size = 32
    epochs = 4
    lr = 2e-5
    weight_decay = 0.01
    warmup_ratio = 0.1
    dropout = 0.2
    delta_proj_dim = 64
    seed = 42
    freeze_transformer = False
    save_best_path = "best_viral_classifier.pt"


cfg = Config()

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(cfg.seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

TIMESTAMP_MEAN = 11.800730451380801
TIMESTAMP_STD = 4.351129933194433

df = pd.read_csv(cfg.data_csv)

# Basic cleaning/checks
required_cols = [cfg.text_col] + cfg.numeric_features + [cfg.label_col]
missing = [c for c in required_cols if c not in df.columns]
if missing:
    raise ValueError(f"Missing columns in CSV: {missing}")

df = df.dropna(subset=required_cols)
df[cfg.label_col] = df[cfg.label_col].astype(int)

train_df, test_df = train_test_split(df, test_size=0.15, random_state=cfg.seed, stratify=df[cfg.label_col])
train_df, val_df = train_test_split(train_df, test_size=0.15 * (1/0.85), random_state=cfg.seed, stratify=train_df[cfg.label_col])

train_numeric = train_df[cfg.numeric_features].values
val_numeric = val_df[cfg.numeric_features].values
test_numeric = test_df[cfg.numeric_features].values

print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

tokenizer = DistilBertTokenizerFast.from_pretrained(cfg.model_name)

class CommentDataset(Dataset):
    def __init__(self, texts, numeric_feats, labels, tokenizer, max_length):
        self.texts = list(texts)
        self.numeric = torch.tensor(numeric_feats, dtype=torch.float32)
        self.labels = torch.tensor(np.asarray(labels, dtype=np.float32))
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        enc = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding=False,
            return_tensors=None,
        )
        return {
            "input_ids": torch.tensor(enc["input_ids"], dtype=torch.long).squeeze(0),
            "attention_mask": torch.tensor(enc["attention_mask"], dtype=torch.long).squeeze(0),
            "numeric": self.numeric[idx],  # shape (5,)
            "labels": self.labels[idx],
        }

train_ds = CommentDataset(train_df[cfg.text_col], train_numeric, train_df[cfg.label_col], tokenizer, cfg.max_length)
val_ds   = CommentDataset(val_df[cfg.text_col],   val_numeric,   val_df[cfg.label_col],   tokenizer, cfg.max_length)
test_ds  = CommentDataset(test_df[cfg.text_col],  test_numeric,  test_df[cfg.label_col],  tokenizer, cfg.max_length)

base_collator = DataCollatorWithPadding(tokenizer=tokenizer)

def collate_fn(batch):
    token_batch = [{"input_ids": b["input_ids"], "attention_mask": b["attention_mask"]} for b in batch]
    collated = base_collator(token_batch)
    numerics = torch.stack([b["numeric"] for b in batch])  # (B,5)
    labels = torch.stack([b["labels"] for b in batch])
    collated["numeric"] = numerics
    collated["labels"] = labels
    return collated

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)

class ViralCommentClassifier(nn.Module):
    def __init__(self, model_name, numeric_in_dim=5, delta_proj_dim=64, dropout=0.2, freeze_transformer=False):
        super().__init__()
        self.transformer = DistilBertModel.from_pretrained(model_name)
        if freeze_transformer:
            for p in self.transformer.parameters():
                p.requires_grad = False

        hidden_size = self.transformer.config.hidden_size  # 768

        # Project numeric vector (delta + cyc encodings) to feature vector
        self.numeric_proj = nn.Sequential(
            nn.Linear(numeric_in_dim, delta_proj_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(delta_proj_dim, delta_proj_dim),
            nn.GELU(),
        )

        combined_dim = hidden_size + delta_proj_dim
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(combined_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
        )

    def forward(self, input_ids, attention_mask, numeric):
        out = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        cls_emb = out.last_hidden_state[:, 0, :]
        num_emb = self.numeric_proj(numeric)
        x = torch.cat([cls_emb, num_emb], dim=-1)
        logits = self.classifier(x).squeeze(-1)
        return logits

model = ViralCommentClassifier(
    cfg.model_name,
    numeric_in_dim=len(cfg.numeric_features),  # 5
    delta_proj_dim=cfg.delta_proj_dim,
    dropout=cfg.dropout,
    freeze_transformer=cfg.freeze_transformer
).to(device)

train_labels_np = train_df[cfg.label_col].values
pos_count = (train_labels_np == 1).sum()
neg_count = (train_labels_np == 0).sum()
pos_weight = torch.tensor(neg_count / max(pos_count, 1), dtype=torch.float32).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
num_training_steps = len(train_loader) * cfg.epochs
num_warmup_steps = int(cfg.warmup_ratio * num_training_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

def evaluate(loader, model):
    model.eval()
    all_labels = []
    all_probs = []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            numeric = batch["numeric"].to(device)
            labels = batch["labels"].to(device)

            logits = model(input_ids, attention_mask, numeric)
            probs = torch.sigmoid(logits)
            all_labels.append(labels.cpu().numpy())
            all_probs.append(probs.cpu().numpy())

    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)
    preds = (all_probs >= 0.5).astype(np.int32)

    acc = accuracy_score(all_labels, preds)
    f1 = f1_score(all_labels, preds)
    try:
        roc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        roc = float("nan")
    return acc, f1, roc

#Training
best_val_f1 = -1.0
for epoch in range(1, cfg.epochs + 1):
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        numeric = batch["numeric"].to(device)
        labels = batch["labels"].to(device)

        logits = model(input_ids, attention_mask, numeric)
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    val_acc, val_f1, val_roc = evaluate(val_loader, model)
    print(f"Epoch {epoch}/{cfg.epochs} | Train Loss: {avg_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f} | Val ROC-AUC: {val_roc:.4f}")

    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save({
            "model_state_dict": model.state_dict(),
            "timestamp_mean": TIMESTAMP_MEAN,
            "timestamp_std": TIMESTAMP_STD,
            "tokenizer_name": cfg.model_name,
            "config": vars(cfg),
        }, cfg.save_best_path)
        print(f"Saved best model to {cfg.save_best_path}")

test_acc, test_f1, test_roc = evaluate(test_loader, model)
print(f"Test Acc: {test_acc:.4f} | Test F1: {test_f1:.4f} | Test ROC-AUC: {test_roc:.4f}")

def preprocess_numeric_for_inference(timestamp_seconds, day, hour, timestamp_mean, timestamp_std):
    # delta: log1p then standardize
    delta = math.log1p(float(timestamp_seconds))
    delta = (delta - timestamp_mean) / timestamp_std
    # cyc encodings
    day = float(day)
    hour = float(hour)
    day_sin = math.sin(2 * math.pi * day / 7)
    day_cos = math.cos(2 * math.pi * day / 7)
    hour_sin = math.sin(2 * math.pi * hour / 24)
    hour_cos = math.cos(2 * math.pi * hour / 24)
    return np.array([delta, day_sin, day_cos, hour_sin, hour_cos], dtype=np.float32)


def predict_viral(texts, timestamps, days, hours, model, tokenizer, timestamp_mean, timestamp_std, max_length=cfg.max_length, batch_size=32):
    model.eval()
    batches = []
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]
        batch_ts = timestamps[i:i+batch_size]
        batch_days = days[i:i+batch_size]
        batch_hours = hours[i:i+batch_size]
        enc = tokenizer(batch_texts, truncation=True, max_length=max_length, padding=True, return_tensors="pt")
        numerics = [preprocess_numeric_for_inference(t, d, h, timestamp_mean, timestamp_std) for t, d, h in zip(batch_ts, batch_days, batch_hours)]
        numerics = torch.tensor(numerics, dtype=torch.float32)
        batches.append((enc, numerics))

    all_probs = []
    with torch.no_grad():
        for enc, numerics in batches:
            input_ids = enc["input_ids"].to(device)
            attention_mask = enc["attention_mask"].to(device)
            numerics = numerics.to(device)
            logits = model(input_ids, attention_mask, numerics)
            probs = torch.sigmoid(logits).cpu().numpy()
            all_probs.extend(probs.tolist())

    preds = [1 if p >= 0.5 else 0 for p in all_probs]
    return preds, all_probs

Using device: cuda
Train: 90598, Val: 19415, Test: 19415
Epoch 1/4 | Train Loss: 0.5688 | Val Acc: 0.8970 | Val F1: 0.7373 | Val ROC-AUC: 0.9539
Saved best model to best_viral_classifier.pt
Epoch 2/4 | Train Loss: 0.4155 | Val Acc: 0.9001 | Val F1: 0.7453 | Val ROC-AUC: 0.9566
Saved best model to best_viral_classifier.pt
Epoch 3/4 | Train Loss: 0.3369 | Val Acc: 0.9094 | Val F1: 0.7565 | Val ROC-AUC: 0.9558
Saved best model to best_viral_classifier.pt
Epoch 4/4 | Train Loss: 0.2742 | Val Acc: 0.9047 | Val F1: 0.7475 | Val ROC-AUC: 0.9539
Test Acc: 0.9059 | Test F1: 0.7485 | Test ROC-AUC: 0.9555


In [None]:
# HIERARCHICAL MODEL

import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, average_precision_score
from transformers import (
    DistilBertModel,
    DistilBertTokenizerFast,
    DataCollatorWithPadding,
)
from transformers.optimization import get_linear_schedule_with_warmup
from torch.optim import AdamW
import re

class Config:
    data_csv = "dataset_processed.csv"  # path to your preprocessed CSV
    text_col = "comment"
    timestamp_col = "timestamp_scaled"
    day_sin_col = "day_sin"
    day_cos_col = "day_cos"
    hour_sin_col = "hour_sin"
    hour_cos_col = "hour_cos"
    label_col = "viral"

    model_name = "distilbert-base-uncased"
    max_length = 128
    batch_size = 16
    epochs = 4
    lr = 2e-5
    weight_decay = 0.01
    warmup_ratio = 0.1
    dropout = 0.2
    delta_proj_dim = 64
    seed = 42
    freeze_transformer = False
    save_best_path = "best_hierarchical_viral_classifier.pt"


cfg = Config()


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(cfg.seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def split_into_sentences(text):
    """Split comment into sentences. Handle YouTube comment quirks."""
    text = str(text)

    # Split on sentence-ending punctuation
    sentences = re.split(r'[.!?]+', text)
    sentences = [s.strip() for s in sentences if s.strip()]

    # If no punctuation, try splitting on newlines
    if len(sentences) <= 1:
        sentences = [s.strip() for s in text.split('\n') if s.strip()]

    # If still nothing, just use the whole text
    if not sentences:
        sentences = [text.strip()]

    # Cap at 10 sentences to avoid memory issues
    return sentences[:10]

df = pd.read_csv(cfg.data_csv)

# Basic cleaning/checks
required_cols = [
    cfg.text_col,
    cfg.timestamp_col,
    cfg.day_sin_col,
    cfg.day_cos_col,
    cfg.hour_sin_col,
    cfg.hour_cos_col,
    cfg.label_col
]
missing = [c for c in required_cols if c not in df.columns]
if missing:
    raise ValueError(f"Missing columns in CSV: {missing}")

df = df.dropna(subset=required_cols)
df[cfg.label_col] = df[cfg.label_col].astype(int)

print(f"Dataset size: {len(df)}")
print(f"Viral comments: {df[cfg.label_col].sum()} ({df[cfg.label_col].mean()*100:.2f}%)")

train_df, test_df = train_test_split(
    df, test_size=0.15, random_state=cfg.seed, stratify=df[cfg.label_col]
)
train_df, val_df = train_test_split(
    train_df, test_size=0.1765, random_state=cfg.seed, stratify=train_df[cfg.label_col]
)

print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

# Split comments into sentences
print("Splitting comments into sentences")
train_df['sentences'] = train_df[cfg.text_col].apply(split_into_sentences)
val_df['sentences'] = val_df[cfg.text_col].apply(split_into_sentences)
test_df['sentences'] = test_df[cfg.text_col].apply(split_into_sentences)

# Check sentence statistics
all_sentence_counts = train_df['sentences'].apply(len)

def extract_numeric_features(df):
    return np.column_stack([
        df[cfg.timestamp_col].values,
        df[cfg.day_sin_col].values,
        df[cfg.day_cos_col].values,
        df[cfg.hour_sin_col].values,
        df[cfg.hour_cos_col].values,
    ])

train_numeric = extract_numeric_features(train_df)
val_numeric = extract_numeric_features(val_df)
test_numeric = extract_numeric_features(test_df)

print(f"Numeric features shape: {train_numeric.shape}")
print(f"Numeric features (first row): {train_numeric[0]}")

tokenizer = DistilBertTokenizerFast.from_pretrained(cfg.model_name)

class HierarchicalCommentDataset(Dataset):
    def __init__(self, sentence_lists, numeric_feats, labels, tokenizer, max_length):
        self.sentence_lists = list(sentence_lists)
        self.numeric = torch.tensor(numeric_feats, dtype=torch.float32)
        self.labels = torch.tensor(np.asarray(labels, dtype=np.float32))
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sentence_lists)

    def __getitem__(self, idx):
        sentences = self.sentence_lists[idx]

        # Tokenize each sentence separately
        sentence_encodings = []
        for sent in sentences:
            enc = self.tokenizer(
                sent,
                truncation=True,
                max_length=self.max_length,
                padding=False,
                return_tensors=None,
            )
            sentence_encodings.append({
                "input_ids": torch.tensor(enc["input_ids"], dtype=torch.long),
                "attention_mask": torch.tensor(enc["attention_mask"], dtype=torch.long),
            })

        return {
            "sentences": sentence_encodings,
            "num_sentences": len(sentences),
            "numeric": self.numeric[idx],
            "labels": self.labels[idx],
        }

train_ds = HierarchicalCommentDataset(
    train_df['sentences'].values, train_numeric, train_df[cfg.label_col].values,
    tokenizer, cfg.max_length
)
val_ds = HierarchicalCommentDataset(
    val_df['sentences'].values, val_numeric, val_df[cfg.label_col].values,
    tokenizer, cfg.max_length
)
test_ds = HierarchicalCommentDataset(
    test_df['sentences'].values, test_numeric, test_df[cfg.label_col].values,
    tokenizer, cfg.max_length
)

# Hierarchical Collator
base_collator = DataCollatorWithPadding(tokenizer=tokenizer)

def hierarchical_collate_fn(batch):
    """Collate sentences from multiple comments into a batch."""
    max_sentences = max(b["num_sentences"] for b in batch)
    batch_size = len(batch)

    # Flatten all sentences for batch processing
    all_sentences = []
    sentence_masks = []  # Track which sentences are real vs padding

    for b in batch:
        sents = b["sentences"]
        num_real = len(sents)

        # Add real sentences
        all_sentences.extend(sents)
        sentence_masks.extend([1] * num_real)

        # Pad with dummy sentences
        for _ in range(max_sentences - num_real):
            all_sentences.append({
                "input_ids": torch.tensor([tokenizer.cls_token_id, tokenizer.sep_token_id]),
                "attention_mask": torch.tensor([1, 1]),
            })
            sentence_masks.append(0)

    # Collate all sentences
    padded_sentences = base_collator(all_sentences)

    # Other data
    numerics = torch.stack([b["numeric"] for b in batch])
    labels = torch.stack([b["labels"] for b in batch])
    sentence_masks = torch.tensor(sentence_masks, dtype=torch.bool)

    return {
        "input_ids": padded_sentences["input_ids"],
        "attention_mask": padded_sentences["attention_mask"],
        "sentence_masks": sentence_masks.view(batch_size, max_sentences),
        "max_sentences": max_sentences,
        "batch_size": batch_size,
        "numeric": numerics,
        "labels": labels,
    }

train_loader = DataLoader(
    train_ds, batch_size=cfg.batch_size, shuffle=True, collate_fn=hierarchical_collate_fn
)
val_loader = DataLoader(
    val_ds, batch_size=cfg.batch_size, shuffle=False, collate_fn=hierarchical_collate_fn
)
test_loader = DataLoader(
    test_ds, batch_size=cfg.batch_size, shuffle=False, collate_fn=hierarchical_collate_fn
)

# Hierarchical Model
class HierarchicalViralClassifier(nn.Module):
    def __init__(self, model_name, numeric_in_dim=5, delta_proj_dim=64,
                 dropout=0.2, freeze_transformer=False):
        super().__init__()
        self.transformer = DistilBertModel.from_pretrained(model_name)

        if freeze_transformer:
            for p in self.transformer.parameters():
                p.requires_grad = False
            print("Transformer weights frozen")

        hidden_size = self.transformer.config.hidden_size  # 768

        # Sentence-level attention
        self.sentence_attention = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        # Numeric projection
        self.numeric_proj = nn.Sequential(
            nn.Linear(numeric_in_dim, delta_proj_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(delta_proj_dim, delta_proj_dim),
            nn.GELU(),
        )

        combined_dim = hidden_size + delta_proj_dim
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(combined_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
        )

        # Store attention for visualization
        self.last_sentence_attention = None
        self.last_context_attention = None

    def forward(self, input_ids, attention_mask, sentence_masks, max_sentences, batch_size, numeric):
        """
        input_ids: (batch_size * max_sentences, seq_len)
        attention_mask: (batch_size * max_sentences, seq_len)
        sentence_masks: (batch_size, max_sentences) - True for real sentences
        """
        bert_out = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        sentence_embeddings = bert_out.last_hidden_state[:, 0, :]  # CLS tokens

        sentence_embeddings = sentence_embeddings.view(batch_size, max_sentences, -1)

        attention_scores = self.sentence_attention(sentence_embeddings).squeeze(-1)

        # Mask out padding sentences
        attention_scores = attention_scores.masked_fill(~sentence_masks, -1e9)

        # Softmax to get attention weights
        attention_weights = torch.softmax(attention_scores, dim=1)

        # Weighted sum of sentence embeddings
        comment_embedding = torch.sum(
            attention_weights.unsqueeze(-1) * sentence_embeddings,
            dim=1
        )

        num_emb = self.numeric_proj(numeric)

        final_embedding = torch.cat([comment_embedding, num_emb], dim=-1)

        # Classification
        logits = self.classifier(final_embedding).squeeze(-1)

        return logits

model = HierarchicalViralClassifier(
    cfg.model_name,
    numeric_in_dim=train_numeric.shape[1],  # Should be 5
    delta_proj_dim=cfg.delta_proj_dim,
    dropout=cfg.dropout,
    freeze_transformer=cfg.freeze_transformer
).to(device)
total_params, trainable_params = count_parameters(model)

train_labels_np = train_df[cfg.label_col].values
pos_count = (train_labels_np == 1).sum()
neg_count = (train_labels_np == 0).sum()
pos_weight = torch.tensor(neg_count / max(pos_count, 1), dtype=torch.float32).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
num_training_steps = len(train_loader) * cfg.epochs
num_warmup_steps = int(cfg.warmup_ratio * num_training_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

def evaluate(loader, model, criterion=None):
    model.eval()
    all_labels = []
    all_probs = []
    total_loss = 0.0

    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            sentence_masks = batch["sentence_masks"].to(device)
            max_sentences = batch["max_sentences"]
            batch_size = batch["batch_size"]
            numeric = batch["numeric"].to(device)
            labels = batch["labels"].to(device)

            logits = model(
                input_ids,
                attention_mask,
                sentence_masks,
                max_sentences,
                batch_size,
                numeric
            )
            probs = torch.sigmoid(logits)

            if criterion is not None:
                loss = criterion(logits, labels)
                total_loss += loss.item()

            all_labels.append(labels.cpu().numpy())
            all_probs.append(probs.cpu().numpy())

    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)
    preds = (all_probs >= 0.5).astype(np.int32)

    acc = accuracy_score(all_labels, preds)
    f1 = f1_score(all_labels, preds)
    pr_auc = average_precision_score(all_labels, all_probs)
    try:
        roc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        roc = float("nan")

    avg_loss = total_loss / len(loader) if criterion is not None else None
    return acc, f1, roc, pr_auc, avg_loss

best_val_f1 = -1.0
for epoch in range(1, cfg.epochs + 1):
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        sentence_masks = batch["sentence_masks"].to(device)
        max_sentences = batch["max_sentences"]
        batch_size = batch["batch_size"]
        numeric = batch["numeric"].to(device)
        labels = batch["labels"].to(device)

        logits = model(
            input_ids,
            attention_mask,
            sentence_masks,
            max_sentences,
            batch_size,
            numeric
        )
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()

    train_avg_loss = total_loss / len(train_loader)

    # Evaluate on train set (for F1)
    train_acc, train_f1, train_roc, train_pr_auc, _ = evaluate(train_loader, model)

    # Evaluate on validation set
    val_acc, val_f1, val_roc, val_pr_auc, val_avg_loss = evaluate(val_loader, model, criterion)

    print(f"Epoch {epoch}/{cfg.epochs}")
    print(f"  Train Loss: {train_avg_loss:.4f} | Train F1: {train_f1:.4f}")
    print(f"  Val Loss: {val_avg_loss:.4f} | Val F1: {val_f1:.4f} | Val Acc: {val_acc:.4f} | "
          f"Val ROC-AUC: {val_roc:.4f} | Val PR-AUC: {val_pr_auc:.4f}")

    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save({
            "model_state_dict": model.state_dict(),
            "tokenizer_name": cfg.model_name,
            "config": vars(cfg),
        }, cfg.save_best_path)
        print(f"  ✓ Saved best model to {cfg.save_best_path}")
    print()

test_acc, test_f1, test_roc, test_pr_auc, test_loss = evaluate(test_loader, model, criterion)
print(f"\n{'='*60}")
print(f"FINAL TEST RESULTS:")
print(f"{'='*60}")
print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f} | Test F1: {test_f1:.4f} | "
      f"Test ROC-AUC: {test_roc:.4f} | Test PR-AUC: {test_pr_auc:.4f}")


Using device: cuda
Dataset size: 129428
Viral comments: 21572 (16.67%)
Train: 90595, Val: 19418, Test: 19415
Splitting comments into sentences...
Sentences per comment - Mean: 1.44, Median: 1, Max: 10
Numeric features shape: (90595, 5)
Numeric features (first row): [-0.3422516  0.         1.        -0.8660254  0.5      ]

MODEL PARAMETER COUNT
Total parameters:      66,679,490
Trainable parameters:  66,679,490
Frozen parameters:     0
Trainable percentage:  100.00%

Parameter breakdown by component:
--------------------------------------------------------------------------------
transformer                    Total:   66,362,880  Trainable:   66,362,880
sentence_attention             Total:       98,561  Trainable:       98,561
numeric_proj                   Total:        4,544  Trainable:        4,544
classifier                     Total:      213,505  Trainable:      213,505

Epoch 1/4
  Train Loss: 0.6495 | Train F1: 0.7655
  Val Loss: 0.6632 | Val F1: 0.7233 | Val Acc: 0.9010 | Val

TypeError: 'NoneType' object is not subscriptable

In [None]:
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    precision_recall_fscore_support,
    confusion_matrix,
    roc_auc_score,
    average_precision_score,
)

df = pd.read_csv(cfg.data_csv)
df = df.dropna(subset=[cfg.text_col] + ['timestamp_scaled', 'day_sin', 'day_cos', 'hour_sin', 'hour_cos'] + [cfg.label_col])
df[cfg.label_col] = df[cfg.label_col].astype(int)

train_df, test_df = train_test_split(
    df, test_size=0.15, random_state=cfg.seed, stratify=df[cfg.label_col]
)

test_df['sentences'] = test_df[cfg.text_col].apply(split_into_sentences)

test_numeric = test_df[['timestamp_scaled', 'day_sin', 'day_cos', 'hour_sin', 'hour_cos']].values

# Create hierarchical dataset
test_ds_hierarchical = HierarchicalCommentDataset(
    test_df['sentences'].values,
    test_numeric,
    test_df[cfg.label_col].values,
    tokenizer,
    cfg.max_length
)

# Create hierarchical loader
test_loader_hierarchical = DataLoader(
    test_ds_hierarchical,
    batch_size=cfg.batch_size,
    shuffle=False,
    collate_fn=hierarchical_collate_fn
)

def evaluate_detailed_hierarchical(loader, model, threshold=0.5):
    """Evaluation function specifically for hierarchical model."""
    model.eval()
    y_true = []
    y_prob = []

    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            sentence_masks = batch["sentence_masks"].to(device)
            max_sentences = batch["max_sentences"]
            batch_size = batch["batch_size"]
            numeric = batch["numeric"].to(device)
            labels = batch["labels"].to(device)

            # Hierarchical model forward pass
            logits = model(
                input_ids,
                attention_mask,
                sentence_masks,
                max_sentences,
                batch_size,
                numeric
            )
            probs = torch.sigmoid(logits)

            y_true.append(labels.cpu().numpy())
            y_prob.append(probs.cpu().numpy())

    y_true = np.concatenate(y_true).astype(int)
    y_prob = np.concatenate(y_prob)
    y_pred = (y_prob >= threshold).astype(int)

    # Overall metrics
    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    f1_binary = f1_score(y_true, y_pred)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        y_true, y_pred, average="macro", zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted", zero_division=0
    )

    # Per-class metrics
    precision_cls, recall_cls, f1_cls, support_cls = precision_recall_fscore_support(
        y_true, y_pred, labels=[0, 1], average=None, zero_division=0
    )

    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # ROC-AUC and PR-AUC
    roc = None
    pr_auc = None
    if len(set(y_true)) == 2:
        try:
            roc = roc_auc_score(y_true, y_prob)
        except Exception:
            roc = None
        try:
            pr_auc = average_precision_score(y_true, y_prob)
        except Exception:
            pr_auc = None

    metrics = {
        "threshold": threshold,
        "n_samples": int(len(y_true)),
        "positives_true": int(y_true.sum()),
        "negatives_true": int((y_true == 0).sum()),
        "positives_pred": int(y_pred.sum()),
        "negatives_pred": int((y_pred == 0).sum()),

        "accuracy": acc,
        "balanced_accuracy": bal_acc,
        "f1_binary_pos": f1_binary,
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "f1_macro": f1_macro,
        "precision_weighted": precision_weighted,
        "recall_weighted": recall_weighted,
        "f1_weighted": f1_weighted,

        "class_0": {
            "precision": precision_cls[0],
            "recall": recall_cls[0],
            "f1": f1_cls[0],
            "support": int(support_cls[0]),
        },
        "class_1": {
            "precision": precision_cls[1],
            "recall": recall_cls[1],
            "f1": f1_cls[1],
            "support": int(support_cls[1]),
        },

        "confusion_matrix": cm.tolist(),
        "tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp),

        "roc_auc": roc,
        "pr_auc": pr_auc,
    }

    return metrics


# -----------------------
# Load and evaluate
# -----------------------
ckpt = torch.load(cfg.save_best_path, map_location=device)
conf = ckpt.get("config", {})

model_loaded = HierarchicalViralClassifier(
    model_name=conf.get("model_name", cfg.model_name),
    numeric_in_dim=5,
    delta_proj_dim=conf.get("delta_proj_dim", cfg.delta_proj_dim),
    dropout=conf.get("dropout", cfg.dropout),
    freeze_transformer=conf.get("freeze_transformer", cfg.freeze_transformer),
).to(device)

model_loaded.load_state_dict(ckpt["model_state_dict"])
model_loaded.eval()

print("="*80)
print("DETAILED EVALUATION ON TEST SET (HIERARCHICAL MODEL)")
print("="*80)

# Evaluate with hierarchical loader
metrics = evaluate_detailed_hierarchical(test_loader_hierarchical, model_loaded, threshold=0.5)

# Print comprehensive summary
print(f"\n{'OVERALL METRICS':-^80}")
print(f"Samples: {metrics['n_samples']} (Viral: {metrics['positives_true']}, "
      f"Non-viral: {metrics['negatives_true']})")
print(f"Threshold: {metrics['threshold']}")
print(f"\nAccuracy:          {metrics['accuracy']:.4f}")
print(f"Balanced Accuracy: {metrics['balanced_accuracy']:.4f}")
print(f"F1 (binary, pos):  {metrics['f1_binary_pos']:.4f}")
print(f"F1 (macro):        {metrics['f1_macro']:.4f}")
print(f"F1 (weighted):     {metrics['f1_weighted']:.4f}")
print(f"ROC-AUC:           {metrics['roc_auc']:.4f}")
print(f"PR-AUC:            {metrics['pr_auc']:.4f}")

print(f"\n{'PER-CLASS METRICS':-^80}")
print(f"{'Class':<15} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Support':<10}")
print("-" * 80)
print(f"{'Non-viral (0)':<15} {metrics['class_0']['precision']:<12.4f} "
      f"{metrics['class_0']['recall']:<12.4f} {metrics['class_0']['f1']:<12.4f} "
      f"{metrics['class_0']['support']:<10}")
print(f"{'Viral (1)':<15} {metrics['class_1']['precision']:<12.4f} "
      f"{metrics['class_1']['recall']:<12.4f} {metrics['class_1']['f1']:<12.4f} "
      f"{metrics['class_1']['support']:<10}")

print(f"\n{'CONFUSION MATRIX':-^80}")
print(f"{'':>20} {'Predicted Non-viral':<20} {'Predicted Viral':<20}")
print(f"{'Actual Non-viral':<20} {metrics['tn']:<20} {metrics['fp']:<20}")
print(f"{'Actual Viral':<20} {metrics['fn']:<20} {metrics['tp']:<20}")

print(f"\n{'DETAILED COUNTS':-^80}")
print(f"True Negatives (TN):  {metrics['tn']:<10} (correctly predicted non-viral)")
print(f"False Positives (FP): {metrics['fp']:<10} (predicted viral, actually non-viral)")
print(f"False Negatives (FN): {metrics['fn']:<10} (predicted non-viral, actually viral)")
print(f"True Positives (TP):  {metrics['tp']:<10} (correctly predicted viral)")

print(f"\n{'PREDICTIONS SUMMARY':-^80}")
print(f"Predicted as viral:     {metrics['positives_pred']} comments")
print(f"Predicted as non-viral: {metrics['negatives_pred']} comments")

print("="*80)

# Threshold sensitivity analysis
print("\n" + "="*80)
print("THRESHOLD SENSITIVITY ANALYSIS")
print("="*80)
print(f"{'Threshold':<12} {'Precision':<12} {'Recall':<12} {'F1':<12} {'Predicted Viral':<15}")
print("-" * 80)

for thresh in [0.3, 0.4, 0.5, 0.6, 0.7]:
    m = evaluate_detailed_hierarchical(test_loader_hierarchical, model_loaded, threshold=thresh)
    print(f"{thresh:<12.1f} {m['class_1']['precision']:<12.4f} "
          f"{m['class_1']['recall']:<12.4f} {m['class_1']['f1']:<12.4f} "
          f"{m['positives_pred']:<15}")

print("="*80)

Creating hierarchical test dataset...
DETAILED EVALUATION ON TEST SET (HIERARCHICAL MODEL)

--------------------------------OVERALL METRICS---------------------------------
Samples: 19415 (Viral: 3236, Non-viral: 16179)
Threshold: 0.5

Accuracy:          0.9075
Balanced Accuracy: 0.8664
F1 (binary, pos):  0.7437
F1 (macro):        0.8436
F1 (weighted):     0.9103
ROC-AUC:           0.9496
PR-AUC:            0.7866

-------------------------------PER-CLASS METRICS--------------------------------
Class           Precision    Recall       F1-Score     Support   
--------------------------------------------------------------------------------
Non-viral (0)   0.9596       0.9281       0.9436       16179     
Viral (1)       0.6913       0.8047       0.7437       3236      

--------------------------------CONFUSION MATRIX--------------------------------
                     Predicted Non-viral  Predicted Viral     
Actual Non-viral     15016                1163                
Actual Viral 