In [2]:
import os, random, numpy as np, pandas as pd
from tqdm import tqdm
from sklearn.metrics import f1_score, classification_report, confusion_matrix
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models import swin_t, Swin_T_Weights
from torchvision.transforms import Resize
from torch.amp import autocast, GradScaler

# === REPRODUCIBILITY ===
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# === PATHS (COLAB) ===
train_csv_path = "/content/drive/MyDrive/Colab Notebooks/Projects/CSVs/train.csv"
test_csv_path  = "/content/drive/MyDrive/Colab Notebooks/Projects/CSVs/test.csv"
NPY_DIR        = "/content/drive/MyDrive/Colab Notebooks/Projects/npy_segments_unimodal"
save_path      = "/content/drive/MyDrive/Colab Notebooks/Results/Unfrozen_randomseed/Swinonly"
os.makedirs(save_path, exist_ok=True)

# === CONFIG ===
BATCH_SIZE = 2
MAX_FRAMES = 80
EPOCHS = 20
USE_WEIGHTED_LOSS = True
PATIENCE = 4

# === MODEL ===
class SwinVideoClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.swin = swin_t(weights=Swin_T_Weights.DEFAULT)
        self.swin.head = nn.Identity()   # Fully unfrozen
        self.fc = nn.Linear(768, 1)
    def forward(self, x):
        B, T, C, H, W = x.shape
        x = x.view(B*T, C, H, W)
        features = self.swin(x).view(B, T, -1)
        pooled = features.mean(dim=1)
        return self.fc(pooled).squeeze(1)

# === DATASET ===
class ViolenceDataset(Dataset):
    def __init__(self, csv_path, npy_dir):
        self.df = pd.read_csv(csv_path)
        self.npy_dir = npy_dir
        self.resize = Resize((224, 224))
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        frames = np.load(os.path.join(self.npy_dir, f"{row['Segment ID']}.npy"))[:MAX_FRAMES]
        frames = torch.stack([self.resize(torch.from_numpy(f).permute(2,0,1).float()/255.0) for f in frames])
        return frames, torch.tensor(row['Violence label(video)'], dtype=torch.float32)

# === INIT ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_dataset = ViolenceDataset(train_csv_path, NPY_DIR)
test_dataset  = ViolenceDataset(test_csv_path,  NPY_DIR)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False)

pos = train_dataset.df['Violence label(video)'].sum()
neg = len(train_dataset) - pos
ratio = neg / max(pos, 1)
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([ratio]).to(device)) if USE_WEIGHTED_LOSS else nn.BCEWithLogitsLoss()

model = SwinVideoClassifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=1, factor=0.5)
scaler = GradScaler()

best_f1, early_stop_counter = 0, 0

# === TRAIN ===
for epoch in range(EPOCHS):
    model.train()
    y_true, y_pred, total_loss = [], [], 0.0
    for frames, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        frames, labels = frames.to(device), labels.to(device)
        with autocast(device_type='cuda'):
            outputs = model(frames)
            loss = criterion(outputs, labels)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        preds = (torch.sigmoid(outputs) > 0.5).int()
        y_true.extend(labels.cpu().numpy()); y_pred.extend(preds.cpu().numpy())
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    micro_f1 = f1_score(y_true, y_pred, average='micro')
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | Macro F1: {macro_f1:.4f} | Micro F1: {micro_f1:.4f}")
    scheduler.step(macro_f1)
    if macro_f1 > best_f1:
        best_f1 = macro_f1
        torch.save(model.state_dict(), os.path.join(save_path, "swin_best_20.pt"))
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= PATIENCE: break

# === TEST ===
model.load_state_dict(torch.load(os.path.join(save_path, "swin_best_20.pt")))
model.eval()
y_true, y_pred, test_losses = [], [], []
segment_ids = test_dataset.df['Segment ID'].tolist()
with torch.no_grad():
    for frames, labels in test_loader:
        frames, labels = frames.to(device), labels.to(device)
        outputs = model(frames)
        loss = criterion(outputs, labels)
        test_losses.append(loss.item())
        preds = (torch.sigmoid(outputs) > 0.5).int()
        y_true.extend(labels.cpu().numpy()); y_pred.extend(preds.cpu().numpy())

avg_test_loss = np.mean(test_losses)
report = classification_report(y_true, y_pred, target_names=["Non-violent","Violent"], output_dict=True, zero_division=0)
conf_matrix = confusion_matrix(y_true, y_pred)
print(f"\n[TEST] BCE Loss: {avg_test_loss:.4f}")
print(f"[TEST] Macro F1: {report['macro avg']['f1-score']:.4f}")
print(f"[TEST] Micro F1: {f1_score(y_true,y_pred,average='micro'):.4f}")
print("[TEST] Per-Class F1 Scores:")
print(f" - Non-violent F1: {report['Non-violent']['f1-score']:.4f}")
print(f" - Violent F1: {report['Violent']['f1-score']:.4f}")
print("Confusion Matrix:\n", conf_matrix)

pd.DataFrame({"Segment ID": segment_ids, "True": y_true, "Pred": y_pred}).to_csv(
    os.path.join(save_path, "swin_predictions_20.csv"), index=False)
pd.DataFrame(report).to_csv(os.path.join(save_path, "swin_test_metrics_20.csv"))


Epoch 1/20: 100%|██████████| 335/335 [29:37<00:00,  5.31s/it]


Epoch 1 | Loss: 0.7548 | Macro F1: 0.6077 | Micro F1: 0.6084


Epoch 2/20: 100%|██████████| 335/335 [05:46<00:00,  1.04s/it]


Epoch 2 | Loss: 0.6752 | Macro F1: 0.6213 | Micro F1: 0.6218


Epoch 3/20: 100%|██████████| 335/335 [05:36<00:00,  1.01s/it]


Epoch 3 | Loss: 0.6698 | Macro F1: 0.6673 | Micro F1: 0.6682


Epoch 4/20: 100%|██████████| 335/335 [05:32<00:00,  1.01it/s]


Epoch 4 | Loss: 0.6016 | Macro F1: 0.7020 | Micro F1: 0.7025


Epoch 5/20: 100%|██████████| 335/335 [05:29<00:00,  1.02it/s]


Epoch 5 | Loss: 0.5669 | Macro F1: 0.7292 | Micro F1: 0.7294


Epoch 6/20: 100%|██████████| 335/335 [05:31<00:00,  1.01it/s]


Epoch 6 | Loss: 0.4821 | Macro F1: 0.7919 | Micro F1: 0.7922


Epoch 7/20: 100%|██████████| 335/335 [05:29<00:00,  1.02it/s]


Epoch 7 | Loss: 0.4907 | Macro F1: 0.7947 | Micro F1: 0.7952


Epoch 8/20: 100%|██████████| 335/335 [05:25<00:00,  1.03it/s]


Epoch 8 | Loss: 0.3980 | Macro F1: 0.8347 | Micro F1: 0.8356


Epoch 9/20: 100%|██████████| 335/335 [05:23<00:00,  1.04it/s]


Epoch 9 | Loss: 0.3629 | Macro F1: 0.8565 | Micro F1: 0.8580


Epoch 10/20: 100%|██████████| 335/335 [05:24<00:00,  1.03it/s]


Epoch 10 | Loss: 0.3209 | Macro F1: 0.8791 | Micro F1: 0.8804


Epoch 11/20: 100%|██████████| 335/335 [05:21<00:00,  1.04it/s]


Epoch 11 | Loss: 0.2787 | Macro F1: 0.9030 | Micro F1: 0.9043


Epoch 12/20: 100%|██████████| 335/335 [05:23<00:00,  1.04it/s]


Epoch 12 | Loss: 0.2878 | Macro F1: 0.8957 | Micro F1: 0.8969


Epoch 13/20: 100%|██████████| 335/335 [05:19<00:00,  1.05it/s]


Epoch 13 | Loss: 0.2405 | Macro F1: 0.9075 | Micro F1: 0.9088


Epoch 14/20: 100%|██████████| 335/335 [05:19<00:00,  1.05it/s]


Epoch 14 | Loss: 0.1847 | Macro F1: 0.9392 | Micro F1: 0.9402


Epoch 15/20: 100%|██████████| 335/335 [05:22<00:00,  1.04it/s]


Epoch 15 | Loss: 0.1943 | Macro F1: 0.9227 | Micro F1: 0.9238


Epoch 16/20: 100%|██████████| 335/335 [05:19<00:00,  1.05it/s]


Epoch 16 | Loss: 0.1992 | Macro F1: 0.9167 | Micro F1: 0.9178


Epoch 17/20: 100%|██████████| 335/335 [05:16<00:00,  1.06it/s]


Epoch 17 | Loss: 0.0957 | Macro F1: 0.9619 | Micro F1: 0.9626


Epoch 18/20: 100%|██████████| 335/335 [05:16<00:00,  1.06it/s]


Epoch 18 | Loss: 0.0549 | Macro F1: 0.9787 | Micro F1: 0.9791


Epoch 19/20: 100%|██████████| 335/335 [05:17<00:00,  1.06it/s]


Epoch 19 | Loss: 0.0473 | Macro F1: 0.9802 | Micro F1: 0.9806


Epoch 20/20: 100%|██████████| 335/335 [05:16<00:00,  1.06it/s]


Epoch 20 | Loss: 0.0449 | Macro F1: 0.9848 | Micro F1: 0.9851

[TEST] BCE Loss: 3.5976
[TEST] Macro F1: 0.6551
[TEST] Micro F1: 0.6564
[TEST] Per-Class F1 Scores:
 - Non-violent F1: 0.6763
 - Violent F1: 0.6340
Confusion Matrix:
 [[117  66]
 [ 46  97]]
