In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from timm import create_model
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import InterpolationMode
from tqdm import tqdm  # Progress bar


In [None]:
# Paths: dataset root (ImageFolder layout, one sub-folder per class) and the
# file the trained weights are written to.
data_dir = r"F:\training frames"  # raw string keeps the Windows backslash literal
model_save_path = "violence_swin_transformer_multi.pth"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training pipeline: bicubic resize to 224x224, then random augmentation
# (horizontal flip with p=0.5, rotation within +/-10 degrees, brightness and
# contrast jitter, 3x3 Gaussian blur with sigma drawn from [0.1, 2.0]),
# then tensor conversion and per-channel mean/std normalisation.
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Evaluation pipeline: the same resize and normalisation, with no randomness.
valid_transform = transforms.Compose(
    [
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Load Dataset
# ImageFolder derives class labels from the sub-folder names under data_dir.
full_dataset = datasets.ImageFolder(root=data_dir, transform=train_transform)
# A second view of the same folder that applies the deterministic validation
# transform. random_split returns Subsets that share ONE underlying dataset
# object, so overwriting `valid_dataset.dataset.transform` would also switch
# the training split to valid_transform and silently disable augmentation.
eval_dataset = datasets.ImageFolder(root=data_dir, transform=valid_transform)

# Split into Train (80%) and Validation (20%). A fixed seed makes the split
# reproducible across runs (e.g. when re-evaluating a saved checkpoint).
train_size = int(0.8 * len(full_dataset))
valid_size = len(full_dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, valid_split = random_split(
    full_dataset, [train_size, valid_size], generator=split_generator
)
# Same validation indices, but read through eval_dataset (valid_transform).
# Both ImageFolder instances scan the same folder, so index i is the same
# file in each.
valid_dataset = torch.utils.data.Subset(eval_dataset, valid_split.indices)


train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False, num_workers=4)


class_names = full_dataset.classes  # Auto-detect classes from folders
print("Classes:", class_names)

# Load Swin Transformer Model: timm's swin_tiny_patch4_window7_224 with
# pretrained weights and a classification head sized to the number of class
# folders. nn.Module.to() returns the module itself, so the move to `device`
# can be chained onto construction.
model = create_model(
    "swin_tiny_patch4_window7_224",
    pretrained=True,
    num_classes=len(class_names),
).to(device)

# Loss Function & Optimizer: cross-entropy over the class logits; AdamW with
# learning rate 1e-4 and weight decay 1e-4.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=1e-4)

# Early Stopping Class
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    A new loss counts as an improvement only when it is lower than the best
    loss seen so far by more than ``min_delta``. Once ``patience`` consecutive
    calls fail to improve, :meth:`step` prints a notice and returns True.
    """

    def __init__(self, patience=2, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None  # best validation loss so far; None before the first step
        self.counter = 0  # consecutive steps without a sufficient improvement

    def step(self, val_loss):
        """Record ``val_loss`` and return True when training should stop."""
        improved = self.best_loss is None or val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            return False

        self.counter += 1
        if self.counter < self.patience:
            return False

        print("⏹ Early stopping triggered!")
        return True  # Stop training

# Initialize Early Stopping
early_stopping = EarlyStopping(patience=2)

# Train Model
epochs = 10
# Snapshot of the weights with the lowest validation loss seen so far. These
# are restored and saved at the end instead of the final-epoch weights, which
# (when early stopping fires) are by construction not the best ones.
best_val_loss = float("inf")
best_state = None

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    loop = tqdm(train_loader, leave=True)
    loop.set_description(f"Epoch [{epoch+1}/{epochs}]")
    for batch_idx, (images, labels) in enumerate(loop, start=1):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Mean loss over the batches processed so far this epoch. Dividing by
        # len(train_loader) would under-report it until the last batch.
        loop.set_postfix(loss=running_loss / batch_idx)

    # Validate Model
    model.eval()
    valid_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            valid_loss += loss.item()

            _, predicted = torch.max(outputs, 1)  # index of the largest logit per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = valid_loss / len(valid_loader)  # mean of per-batch losses
    accuracy = 100 * correct / total
    print(f"✅ Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Detached CPU clones, so later optimizer steps cannot overwrite the snapshot.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    if early_stopping.step(val_loss):
        break

# Restore the best-validation-loss weights (if any epoch produced a finite
# loss) so both the in-memory model and the saved file use them.
if best_state is not None:
    model.load_state_dict(best_state)

# Save Trained Model
torch.save(model.state_dict(), model_save_path)
print(f"✅ Model training complete and saved to {model_save_path}")


    
    

