In [None]:
# Colab-only: mount Google Drive at /content/drive. The CSV, .npy and results
# paths configured later in this notebook all live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50, ResNet50_Weights
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from tqdm import tqdm
from torch.amp import autocast, GradScaler
from torchvision import transforms

# === COLAB PATHS ===
# CSV manifests listing the segments to load (read with pandas further down).
train_csv_path = '/content/drive/MyDrive/Colab Notebooks/Projects/CSVs/train.csv'
test_csv_path = '/content/drive/MyDrive/Colab Notebooks/Projects/CSVs/test.csv'
# Directory holding one "<Segment ID>.npy" frame array per segment.
NPY_DIR = '/content/drive/MyDrive/Colab Notebooks/Projects/npy_segments_unimodal'
# Output directory for the best checkpoint, predictions and metrics.
save_path = '/content/drive/MyDrive/Colab Notebooks/Results/Again(RESNET50+GRU)'
os.makedirs(save_path, exist_ok=True)

# === CONFIG ===
BATCH_SIZE = 4          # clips per forward pass
GRAD_ACCUM_STEPS = 4    # optimizer steps once every this many batches
EPOCHS = 20             # upper bound; early stopping may end training sooner
MAX_FRAMES = 80         # every clip is truncated / padded to this many frames
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# === MODEL ===
class ResNet50GRU(nn.Module):
    """Frame-wise ResNet-50 encoder + bidirectional GRU + attention pooling.

    Each frame of a clip is encoded by an ImageNet-pretrained ResNet-50 whose
    classification head is replaced by an identity, giving a 2048-d feature.
    Only the backbone's ``layer4`` is left trainable. The frame features are
    fed to a bidirectional GRU (2 x 256 = 512-d outputs), pooled over time
    with learned softmax attention weights, and projected to a single raw
    logit per clip (no sigmoid is applied here).
    """

    def __init__(self):
        super().__init__()
        backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
        # Freeze the entire backbone, then re-enable gradients for layer4 only.
        backbone.requires_grad_(False)
        backbone.layer4.requires_grad_(True)
        # Drop the ImageNet classifier so the backbone emits pooled features.
        backbone.fc = nn.Identity()
        self.resnet = backbone

        self.gru = nn.GRU(
            input_size=2048,
            hidden_size=256,
            batch_first=True,
            bidirectional=True,
        )
        # One attention score per time step, computed from the GRU output.
        self.attn = nn.Linear(in_features=512, out_features=1)
        self.dropout = nn.Dropout(p=0.3)
        self.fc = nn.Linear(in_features=512, out_features=1)

    def forward(self, x):
        """Return one logit per clip.

        Args:
            x: tensor of shape (B, T, C, H, W).

        Returns:
            Tensor of shape (B,) holding raw (pre-sigmoid) logits.
        """
        batch, steps, channels, height, width = x.shape
        # Fold time into the batch axis so the 2-D CNN sees individual frames.
        flat_frames = x.view(batch * steps, channels, height, width)
        frame_feats = self.resnet(flat_frames).view(batch, steps, -1)

        seq_out, _ = self.gru(frame_feats)
        # Softmax over the time axis -> weights summing to 1 per clip.
        attn_weights = self.attn(seq_out).softmax(dim=1)
        pooled = (attn_weights * seq_out).sum(dim=1)

        logits = self.fc(self.dropout(pooled))
        return logits.squeeze(1)

# === DATASET ===
class ViolenceDataset(Dataset):
    """Clip-level dataset backed by a CSV manifest and per-segment ``.npy`` files.

    Each CSV row must provide a ``'Segment ID'`` (the ``.npy`` file stem inside
    ``npy_dir``) and a ``'Violence label(video)'`` value used as the target.

    Args:
        csv_path: path to the CSV manifest.
        npy_dir: directory containing ``<Segment ID>.npy`` frame arrays.
        augment: when True (the default, matching the previous behaviour of
            always augmenting) a random horizontal flip and colour jitter are
            applied. Pass False for validation / test data so evaluation is
            deterministic.

    Augmentation parameters are sampled once per clip and applied to every
    frame of that clip. Sampling them per frame would, for example, mirror
    some frames and not others, breaking the temporal coherence the GRU
    relies on.
    """

    def __init__(self, csv_path, npy_dir, augment=True):
        self.df = pd.read_csv(csv_path)
        self.npy_dir = npy_dir
        self.augment = augment
        # Deterministic per-frame preprocessing: resize via PIL, back to a
        # float tensor in [0, 1].
        self.frame_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        # Random augmentation, applied to the stacked (T, C, H, W) clip so a
        # single random draw is shared by all frames.
        self.clip_augment = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ])
        # ImageNet statistics expected by the pretrained ResNet weights.
        self.normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                              std=(0.229, 0.224, 0.225))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(clip, label)`` for row ``idx``.

        ``clip`` is a float tensor of shape (MAX_FRAMES, C, 224, 224) and
        ``label`` a float32 scalar tensor.

        Raises:
            ValueError: if the segment file contains no frames.
        """
        row = self.df.iloc[idx]
        segment_id = row['Segment ID']
        label = row['Violence label(video)']
        frames = np.load(os.path.join(self.npy_dir, f"{segment_id}.npy"))
        frames = frames[:MAX_FRAMES]
        if len(frames) == 0:
            # Without this guard the padding below fails with an opaque
            # IndexError on frames[-1].
            raise ValueError(f"Segment {segment_id!r} contains no frames")
        if len(frames) < MAX_FRAMES:
            # Pad short clips by repeating the last frame.
            pad_len = MAX_FRAMES - len(frames)
            frames = np.concatenate([frames, np.repeat(frames[-1:], pad_len, axis=0)], axis=0)
        # Each frame is indexed as (H, W, C); the division assumes pixel values
        # in the 0-255 range -- TODO confirm against the .npy export.
        clip = torch.stack([
            self.frame_transform(torch.from_numpy(f).permute(2, 0, 1).float() / 255.0)
            for f in frames
        ])
        if self.augment:
            clip = self.clip_augment(clip)
        clip = self.normalize(clip)
        return clip, torch.tensor(label, dtype=torch.float32)

# === INIT ===
# Only the training split is augmented; the test split must be deterministic
# so that reported metrics are reproducible.
train_dataset = ViolenceDataset(train_csv_path, NPY_DIR, augment=True)
test_dataset = ViolenceDataset(test_csv_path, NPY_DIR, augment=False)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# shuffle=False keeps predictions aligned with the row order of the test CSV.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = ResNet50GRU().to(device)
# The model returns raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Halve the learning rate when the monitored loss fails to improve for more
# than `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)
# Loss scaling only makes sense for CUDA mixed precision. Bind the scaler to
# the selected device and disable it on CPU, instead of relying on the
# CUDA-only default (which warns and self-disables when no GPU is present).
scaler = GradScaler(device.type, enabled=(device.type == "cuda"))
best_loss = float('inf')    # lowest epoch loss seen so far
early_stop_counter = 0      # consecutive epochs without improvement
PATIENCE = 4                # stop after this many non-improving epochs

# === TRAINING ===
# Mixed precision is only used on CUDA; on CPU the forward pass stays in fp32.
use_amp = device.type == "cuda"

# NOTE(review): the scheduler, checkpoint selection and early stopping below
# all key off the *training* loss because no validation split is set up in
# this notebook. Monitoring a held-out loss would be needed to detect
# overfitting.
for epoch in range(EPOCHS):
    model.train()
    y_true, y_pred = [], []
    total_loss = 0.0
    optimizer.zero_grad()

    for i, (frames, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")):
        frames, labels = frames.to(device), labels.to(device)
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(frames)
            # Divide so that gradients accumulated over GRAD_ACCUM_STEPS
            # batches average rather than sum.
            loss = criterion(outputs, labels) / GRAD_ACCUM_STEPS
        scaler.scale(loss).backward()

        if (i + 1) % GRAD_ACCUM_STEPS == 0 or (i + 1) == len(train_loader):
            # Gradients are still multiplied by the loss scale at this point;
            # unscale them first so max_norm=1.0 applies to the true gradient
            # norm rather than the scaled one.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Undo the accumulation division and weight by batch size so the
        # epoch average is per sample (the last batch may be smaller).
        total_loss += loss.item() * GRAD_ACCUM_STEPS * labels.size(0)
        preds = (torch.sigmoid(outputs.detach()) > 0.5).int()
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

    avg_loss = total_loss / len(y_true)
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    micro_f1 = f1_score(y_true, y_pred, average='micro')
    print(f"Epoch {epoch+1} | Avg BCE Loss: {avg_loss:.4f} | Macro F1: {macro_f1:.4f} | Micro F1: {micro_f1:.4f}")
    print(classification_report(y_true, y_pred, target_names=["Non-violent", "Violent"], zero_division=0))

    scheduler.step(avg_loss)

    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), os.path.join(save_path, "resnet50_gru_best_bce.pt"))
        print("[SAVED] Best BCE model")
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= PATIENCE:
            print("Early stopping.")
            break

# === TESTING ===
# Reload the best checkpoint written during training. map_location lets a
# checkpoint saved on GPU be evaluated on a CPU-only runtime; weights_only
# restricts unpickling to tensors, which is all a state_dict contains.
checkpoint_path = os.path.join(save_path, "resnet50_gru_best_bce.pt")
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
model.eval()
# segment_ids follows the CSV row order; this lines up with y_true / y_pred
# only because the test loader iterates without shuffling.
y_true, y_pred, segment_ids = [], [], test_dataset.df['Segment ID'].tolist()
test_loss = 0.0

with torch.no_grad():
    for frames, labels in test_loader:
        frames, labels = frames.to(device), labels.to(device)
        outputs = model(frames)
        loss = criterion(outputs, labels)
        # Weight by batch size so the final (possibly smaller) batch does not
        # skew the per-sample average.
        test_loss += loss.item() * labels.size(0)
        preds = (torch.sigmoid(outputs) > 0.5).int()
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

# max(..., 1) avoids a ZeroDivisionError if the test loader yields nothing.
avg_test_loss = test_loss / max(len(y_true), 1)
macro_f1_test = f1_score(y_true, y_pred, average='macro')
micro_f1_test = f1_score(y_true, y_pred, average='micro')
report = classification_report(y_true, y_pred, target_names=["Non-violent", "Violent"], output_dict=True, zero_division=0)
conf_matrix = confusion_matrix(y_true, y_pred)

print("\n[TEST] BCE Loss:", round(avg_test_loss, 4))
print("[TEST] Macro F1:", round(macro_f1_test, 4))
print("[TEST] Micro F1:", round(micro_f1_test, 4))
print("[TEST] Per-Class F1 Scores:")
print(" - Non-violent F1:", round(report['Non-violent']['f1-score'], 4))
print(" - Violent F1:", round(report['Violent']['f1-score'], 4))
print("Confusion Matrix:\n", conf_matrix)

# Persist per-segment predictions and the full metrics table next to the checkpoint.
results = pd.DataFrame({"Segment ID": segment_ids, "True": y_true, "Pred": y_pred})
results.to_csv(os.path.join(save_path, "resnet50_gru_predictions_final.csv"), index=False)
pd.DataFrame(report).to_csv(os.path.join(save_path, "resnet50_gru_test_metrics_final.csv"))


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 169MB/s]
Epoch 1/20: 100%|██████████| 168/168 [24:53<00:00,  8.89s/it]


Epoch 1 | Avg BCE Loss: 0.6523 | Macro F1: 0.4794 | Micro F1: 0.5904
              precision    recall  f1-score   support

 Non-violent       0.59      0.92      0.72       381
     Violent       0.60      0.15      0.24       288

    accuracy                           0.59       669
   macro avg       0.59      0.54      0.48       669
weighted avg       0.59      0.59      0.51       669

[SAVED] Best BCE model


Epoch 2/20: 100%|██████████| 168/168 [04:56<00:00,  1.77s/it]


Epoch 2 | Avg BCE Loss: 0.5773 | Macro F1: 0.6800 | Micro F1: 0.6891
              precision    recall  f1-score   support

 Non-violent       0.72      0.75      0.73       381
     Violent       0.65      0.60      0.63       288

    accuracy                           0.69       669
   macro avg       0.68      0.68      0.68       669
weighted avg       0.69      0.69      0.69       669

[SAVED] Best BCE model


Epoch 3/20: 100%|██████████| 168/168 [04:40<00:00,  1.67s/it]


Epoch 3 | Avg BCE Loss: 0.5103 | Macro F1: 0.7411 | Micro F1: 0.7474
              precision    recall  f1-score   support

 Non-violent       0.77      0.79      0.78       381
     Violent       0.71      0.69      0.70       288

    accuracy                           0.75       669
   macro avg       0.74      0.74      0.74       669
weighted avg       0.75      0.75      0.75       669

[SAVED] Best BCE model


Epoch 4/20: 100%|██████████| 168/168 [04:31<00:00,  1.61s/it]


Epoch 4 | Avg BCE Loss: 0.4897 | Macro F1: 0.7499 | Micro F1: 0.7519
              precision    recall  f1-score   support

 Non-violent       0.81      0.74      0.77       381
     Violent       0.69      0.77      0.73       288

    accuracy                           0.75       669
   macro avg       0.75      0.75      0.75       669
weighted avg       0.76      0.75      0.75       669

[SAVED] Best BCE model


Epoch 5/20: 100%|██████████| 168/168 [04:24<00:00,  1.57s/it]


Epoch 5 | Avg BCE Loss: 0.4879 | Macro F1: 0.7419 | Micro F1: 0.7444
              precision    recall  f1-score   support

 Non-violent       0.80      0.74      0.77       381
     Violent       0.69      0.75      0.72       288

    accuracy                           0.74       669
   macro avg       0.74      0.75      0.74       669
weighted avg       0.75      0.74      0.75       669

[SAVED] Best BCE model


Epoch 6/20: 100%|██████████| 168/168 [04:19<00:00,  1.54s/it]


Epoch 6 | Avg BCE Loss: 0.4622 | Macro F1: 0.7688 | Micro F1: 0.7713
              precision    recall  f1-score   support

 Non-violent       0.82      0.77      0.79       381
     Violent       0.72      0.77      0.74       288

    accuracy                           0.77       669
   macro avg       0.77      0.77      0.77       669
weighted avg       0.77      0.77      0.77       669

[SAVED] Best BCE model


Epoch 7/20: 100%|██████████| 168/168 [04:16<00:00,  1.53s/it]


Epoch 7 | Avg BCE Loss: 0.4372 | Macro F1: 0.7734 | Micro F1: 0.7758
              precision    recall  f1-score   support

 Non-violent       0.82      0.77      0.80       381
     Violent       0.72      0.78      0.75       288

    accuracy                           0.78       669
   macro avg       0.77      0.78      0.77       669
weighted avg       0.78      0.78      0.78       669

[SAVED] Best BCE model


Epoch 8/20: 100%|██████████| 168/168 [04:14<00:00,  1.52s/it]


Epoch 8 | Avg BCE Loss: 0.3991 | Macro F1: 0.7953 | Micro F1: 0.7982
              precision    recall  f1-score   support

 Non-violent       0.83      0.81      0.82       381
     Violent       0.75      0.79      0.77       288

    accuracy                           0.80       669
   macro avg       0.79      0.80      0.80       669
weighted avg       0.80      0.80      0.80       669

[SAVED] Best BCE model


Epoch 9/20: 100%|██████████| 168/168 [04:14<00:00,  1.52s/it]


Epoch 9 | Avg BCE Loss: 0.3730 | Macro F1: 0.8251 | Micro F1: 0.8281
              precision    recall  f1-score   support

 Non-violent       0.85      0.84      0.85       381
     Violent       0.80      0.81      0.80       288

    accuracy                           0.83       669
   macro avg       0.82      0.83      0.83       669
weighted avg       0.83      0.83      0.83       669

[SAVED] Best BCE model


Epoch 10/20: 100%|██████████| 168/168 [04:17<00:00,  1.53s/it]


Epoch 10 | Avg BCE Loss: 0.3249 | Macro F1: 0.8640 | Micro F1: 0.8655
              precision    recall  f1-score   support

 Non-violent       0.91      0.85      0.88       381
     Violent       0.82      0.89      0.85       288

    accuracy                           0.87       669
   macro avg       0.86      0.87      0.86       669
weighted avg       0.87      0.87      0.87       669

[SAVED] Best BCE model


Epoch 11/20: 100%|██████████| 168/168 [04:20<00:00,  1.55s/it]


Epoch 11 | Avg BCE Loss: 0.3055 | Macro F1: 0.8576 | Micro F1: 0.8595
              precision    recall  f1-score   support

 Non-violent       0.89      0.86      0.87       381
     Violent       0.82      0.86      0.84       288

    accuracy                           0.86       669
   macro avg       0.86      0.86      0.86       669
weighted avg       0.86      0.86      0.86       669

[SAVED] Best BCE model


Epoch 12/20: 100%|██████████| 168/168 [04:16<00:00,  1.53s/it]


Epoch 12 | Avg BCE Loss: 0.3022 | Macro F1: 0.8589 | Micro F1: 0.8610
              precision    recall  f1-score   support

 Non-violent       0.89      0.86      0.88       381
     Violent       0.83      0.86      0.84       288

    accuracy                           0.86       669
   macro avg       0.86      0.86      0.86       669
weighted avg       0.86      0.86      0.86       669

[SAVED] Best BCE model


Epoch 13/20: 100%|██████████| 168/168 [04:17<00:00,  1.53s/it]


Epoch 13 | Avg BCE Loss: 0.2306 | Macro F1: 0.8968 | Micro F1: 0.8984
              precision    recall  f1-score   support

 Non-violent       0.92      0.90      0.91       381
     Violent       0.87      0.90      0.88       288

    accuracy                           0.90       669
   macro avg       0.90      0.90      0.90       669
weighted avg       0.90      0.90      0.90       669

[SAVED] Best BCE model


Epoch 14/20: 100%|██████████| 168/168 [04:13<00:00,  1.51s/it]


Epoch 14 | Avg BCE Loss: 0.2506 | Macro F1: 0.8928 | Micro F1: 0.8939
              precision    recall  f1-score   support

 Non-violent       0.94      0.87      0.90       381
     Violent       0.84      0.92      0.88       288

    accuracy                           0.89       669
   macro avg       0.89      0.90      0.89       669
weighted avg       0.90      0.89      0.89       669



Epoch 15/20: 100%|██████████| 168/168 [04:13<00:00,  1.51s/it]


Epoch 15 | Avg BCE Loss: 0.2323 | Macro F1: 0.8878 | Micro F1: 0.8894
              precision    recall  f1-score   support

 Non-violent       0.92      0.88      0.90       381
     Violent       0.85      0.90      0.87       288

    accuracy                           0.89       669
   macro avg       0.89      0.89      0.89       669
weighted avg       0.89      0.89      0.89       669



Epoch 16/20: 100%|██████████| 168/168 [04:11<00:00,  1.50s/it]


Epoch 16 | Avg BCE Loss: 0.1916 | Macro F1: 0.9121 | Micro F1: 0.9133
              precision    recall  f1-score   support

 Non-violent       0.94      0.91      0.92       381
     Violent       0.88      0.92      0.90       288

    accuracy                           0.91       669
   macro avg       0.91      0.91      0.91       669
weighted avg       0.91      0.91      0.91       669

[SAVED] Best BCE model


Epoch 17/20: 100%|██████████| 168/168 [04:14<00:00,  1.51s/it]


Epoch 17 | Avg BCE Loss: 0.1892 | Macro F1: 0.9227 | Micro F1: 0.9238
              precision    recall  f1-score   support

 Non-violent       0.95      0.91      0.93       381
     Violent       0.89      0.94      0.91       288

    accuracy                           0.92       669
   macro avg       0.92      0.93      0.92       669
weighted avg       0.93      0.92      0.92       669

[SAVED] Best BCE model


Epoch 18/20: 100%|██████████| 168/168 [04:16<00:00,  1.53s/it]


Epoch 18 | Avg BCE Loss: 0.1837 | Macro F1: 0.9139 | Micro F1: 0.9148
              precision    recall  f1-score   support

 Non-violent       0.96      0.89      0.92       381
     Violent       0.87      0.94      0.91       288

    accuracy                           0.91       669
   macro avg       0.91      0.92      0.91       669
weighted avg       0.92      0.91      0.92       669

[SAVED] Best BCE model


Epoch 19/20: 100%|██████████| 168/168 [04:09<00:00,  1.49s/it]


Epoch 19 | Avg BCE Loss: 0.1666 | Macro F1: 0.9253 | Micro F1: 0.9268
              precision    recall  f1-score   support

 Non-violent       0.93      0.94      0.94       381
     Violent       0.92      0.91      0.91       288

    accuracy                           0.93       669
   macro avg       0.93      0.93      0.93       669
weighted avg       0.93      0.93      0.93       669

[SAVED] Best BCE model


Epoch 20/20: 100%|██████████| 168/168 [04:11<00:00,  1.50s/it]


Epoch 20 | Avg BCE Loss: 0.1544 | Macro F1: 0.9256 | Micro F1: 0.9268
              precision    recall  f1-score   support

 Non-violent       0.95      0.92      0.93       381
     Violent       0.90      0.93      0.92       288

    accuracy                           0.93       669
   macro avg       0.92      0.93      0.93       669
weighted avg       0.93      0.93      0.93       669

[SAVED] Best BCE model

[TEST] BCE Loss: 0.9424
[TEST] Macro F1: 0.6655
[TEST] Micro F1: 0.6656
[TEST] Per-Class F1 Scores:
 - Non-violent F1: 0.6583
 - Violent F1: 0.6727
Confusion Matrix:
 [[105  78]
 [ 31 112]]
