In [None]:
import os
import torch
import torchvision.transforms as transforms
from timm import create_model
from PIL import Image, ImageDraw, ImageFont
import glob

In [None]:

# --- Paths -------------------------------------------------------------------
# Root folder of extracted keyframes; it is searched recursively for *.jpg files.
test_dir = r"F:\keyframes"
# Annotated frames are written below <output_base_dir>/<label>/<source folder>/.
output_base_dir = r"D:\violence_classification\predictions"
violence_base_dir, non_violence_base_dir = (
    os.path.join(output_base_dir, label_dir) for label_dir in ("Violence", "Non-Violence")
)
# Fine-tuned classifier weights (a state dict, loaded into the model further below).
model_path = r"D:\rabbit\violence_swin_transformer_multi.pth"

# Make sure both per-label output folders exist; existing folders are reused as-is.
for base_dir in (violence_base_dir, non_violence_base_dir):
    os.makedirs(base_dir, exist_ok=True)

# Run on the default CUDA device when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Rebuild the Swin-T architecture with a 2-way head (pretrained=False, so the
# checkpoint supplies the weights) and load the fine-tuned state dict from
# model_path, mapping its tensors onto `device`.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model = create_model(
    "swin_tiny_patch4_window7_224",
    pretrained=False,
    num_classes=2,
)
model.load_state_dict(torch.load(model_path, map_location=device))
# nn.Module.to() and .eval() both return the module itself, so this keeps the
# same object: moved to `device` and switched to inference (eval) mode.
model = model.to(device).eval()


# Preprocessing applied to each PIL frame before inference:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved),
#   2. convert to a float CHW tensor scaled to [0, 1],
#   3. normalize each channel with the usual ImageNet mean/std values.
# NOTE(review): presumably mirrors the training-time preprocessing — confirm,
# since a mismatch silently degrades predictions.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


# Collect every .jpg keyframe below test_dir, at any depth (including test_dir
# itself, since "**" also matches zero directories).
# - glob.escape stops characters such as "[" or "*" in the root path from being
#   interpreted as wildcards.
# - sorted() makes the processing order reproducible; glob's own order depends
#   on the file system. It also fixes which frame wins when two frames map to
#   the same output path.
image_paths = sorted(
    glob.glob(os.path.join(glob.escape(test_dir), "**", "*.jpg"), recursive=True)
)

if not image_paths:
    print("⚠ No images found in test directory!")
else:
    print(f"✅ Found {len(image_paths)} images in test dataset.")


# Label font: "arial.ttf" at size 20 when it can be loaded, otherwise Pillow's
# built-in default font.
# Only font-loading failures are caught:
# - OSError when the font file cannot be found or read,
# - ImportError when Pillow was built without FreeType support.
# KeyboardInterrupt, SystemExit and genuine bugs are no longer swallowed the way
# a bare `except:` would swallow them.
try:
    font = ImageFont.truetype("arial.ttf", 20)
except (OSError, ImportError):
    font = ImageFont.load_default()


# Classify each frame, draw the predicted label on it, and save the annotated
# copy as <Violence|Non-Violence>/<parent folder name>/<original file name>.
# NOTE(review): class index 1 is treated as "Violence"; this must match the class
# order used when the checkpoint was trained — confirm.
# NOTE(review): only the parent folder's *name* is kept, so two source folders
# with the same name (in different subtrees of test_dir) share one output folder,
# and same-named frames overwrite each other.
run_counts = {"Violence": 0, "Non-Violence": 0}  # frames saved by *this* run
failed_count = 0  # frames skipped because of an error

for img_path in image_paths:
    try:
        # The context manager guarantees the source file handle is closed,
        # even if decoding fails part-way.
        with Image.open(img_path) as src_img:
            img = src_img.convert("RGB")
        img_tensor = transform(img).unsqueeze(0).to(device)  # add a batch dim of 1

        # Get model prediction; no_grad skips autograd bookkeeping at inference.
        with torch.no_grad():
            output = model(img_tensor)
            _, predicted = torch.max(output, 1)
            label = "Violence" if predicted.item() == 1 else "Non-Violence"

        # Mirror the frame's parent folder name under the predicted label's folder.
        video_name = os.path.basename(os.path.dirname(img_path))
        save_dir = os.path.join(violence_base_dir if label == "Violence" else non_violence_base_dir, video_name)
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, os.path.basename(img_path))

        # Frame border inset 5 px: red for Violence, green for Non-Violence.
        draw = ImageDraw.Draw(img)
        box_color = "red" if label == "Violence" else "green"
        draw.rectangle([(5, 5), (img.width - 5, img.height - 5)], outline=box_color, width=5)

        # White label text on a black box padded 5 px around the text's bounding box.
        text_position = (10, 10)
        text_size = draw.textbbox(text_position, label, font=font)
        text_bg_position = [text_size[0] - 5, text_size[1] - 5, text_size[2] + 5, text_size[3] + 5]
        draw.rectangle(text_bg_position, fill="black")
        draw.text(text_position, label, font=font, fill="white")

        img.save(save_path)
        run_counts[label] += 1

    except Exception as e:
        # Best-effort: report the bad frame and continue with the rest.
        failed_count += 1
        print(f"⚠ Error processing {img_path}: {e}")

print(f"✅ All test frames processed and stored in {output_base_dir}")
# Per-run tally. Unlike a directory walk, it excludes files left by earlier runs.
print(
    f"🧮 This run: Violence = {run_counts['Violence']}, "
    f"Non-Violence = {run_counts['Non-Violence']}, Failed = {failed_count}"
)


# Tally the files currently on disk under each label's output folder.
# NOTE: os.walk counts *every* file below each folder. The folders are created
# with exist_ok=True and never emptied, so these totals also include files
# saved by earlier runs.
saved_violence, saved_non_violence = (
    sum(len(file_names) for _root, _subdirs, file_names in os.walk(label_root))
    for label_root in (violence_base_dir, non_violence_base_dir)
)
print(f"📊 Final Saved Summary: Violence = {saved_violence}, Non-Violence = {saved_non_violence}")

