In [2]:
import torch
import torch.nn as nn
import torchvision.models as models
import norse.torch as norse  # For SNN layers

# Select the compute device once for the whole notebook: the first CUDA GPU
# when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Redefine the model structure
class CNN_SNN_Model(nn.Module):
    """Binary clip classifier: per-frame ResNet-18 features -> spiking recurrent cell -> sigmoid.

    Every frame of a clip is encoded independently by a ResNet-18 backbone with its final
    classification layer removed. The resulting 512-d feature sequence is fed step by step
    through a norse ``LIFRecurrentCell`` (512 -> 256). The cell outputs are averaged over time,
    and a linear layer followed by a sigmoid yields one probability per clip.

    Args:
        pretrained: If True (the default, matching the original behaviour), initialise the
            backbone with torchvision's ImageNet weights. Pass False when all weights are about
            to be replaced via ``load_state_dict`` anyway; this skips the weight download.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        # CNN feature extractor: ResNet-18 without its last child (the fc head), so each frame
        # becomes a pooled 512-channel feature map. `weights=` replaces the deprecated
        # `pretrained=` flag; IMAGENET1K_V1 is the checkpoint `pretrained=True` used to select.
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        resnet = models.resnet18(weights=weights)
        self.cnn = nn.Sequential(*list(resnet.children())[:-1])

        # Spiking recurrent layer applied along the time axis: 512 inputs -> 256 units.
        self.snn = norse.LIFRecurrentCell(512, 256)

        # Classifier head producing a single logit per clip.
        self.fc = nn.Linear(256, 1)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Tensor of shape (batch, T, C, H, W); the 5-way unpacking below requires rank 5.

        Returns:
            Tensor of shape (batch, 1) with sigmoid probabilities. Callers in this notebook
            treat values above 0.5 as "Anomaly".
        """
        batch_size, T, C, H, W = x.shape
        # reshape rather than view: view raises on non-contiguous inputs (e.g. permuted tensors).
        x = self.cnn(x.reshape(-1, C, H, W))  # fold time into batch for the 2-D CNN
        x = x.reshape(batch_size, T, -1)  # back to (batch, T, 512)

        # Run the recurrent spiking cell one time step at a time; state=None lets the
        # cell create its initial state on the first step.
        snn_out = []
        state = None
        for t in range(T):
            out, state = self.snn(x[:, t, :], state)
            snn_out.append(out)

        x = torch.stack(snn_out, dim=1).mean(dim=1)  # average over time -> (batch, 256)
        return torch.sigmoid(self.fc(x))  # binary classification probability

# Rebuild the architecture, then restore the trained parameters from the checkpoint.
MODEL_PATH = "cnn_snn_model.pth"

model = CNN_SNN_Model().to(device)

# weights_only=True limits unpickling to tensors and plain containers, which is all a
# state_dict needs. It avoids executing arbitrary pickled code from the checkpoint file and
# the FutureWarning recent PyTorch versions emit when the argument is left unspecified.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
# load_state_dict copies into the existing (already on-device) parameters, so no second .to().
model.load_state_dict(state_dict)
model.eval()  # inference mode for layers such as BatchNorm in the ResNet backbone

print("✅ Model Loaded Successfully! Ready for Inference.")


  optree.register_pytree_node(
  optree.register_pytree_node(


✅ Model Loaded Successfully! Ready for Inference.


  model.load_state_dict(torch.load("cnn_snn_model.pth", map_location=device))


In [4]:
import cv2
import torch
import numpy as np
import torchvision.transforms as transforms

# Per-frame preprocessing. It must stay identical to the pipeline used during
# training, so change these steps only together with retraining the model.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),        # decoded HxWxC uint8 frame -> PIL image
        transforms.Resize((224, 224)),  # fixed spatial size expected by the CNN
        transforms.ToTensor(),          # PIL image -> float tensor (C, H, W)
        # One mean/std value, applied to every channel: (x - 0.5) / 0.5.
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Function to Predict Anomalies in New Videos
def _read_video_frames(video_path, max_frames=None):
    """Decode a video into a list of preprocessed frame tensors.

    Each frame is converted from OpenCV's BGR channel order to RGB and passed through
    the module-level ``transform`` pipeline.

    Args:
        video_path: Path of a video file readable by ``cv2.VideoCapture``.
        max_frames: Optional limit on how many frames (from the start of the video) are
            decoded. ``None`` reads the whole video.

    Returns:
        List of tensors produced by ``transform``. The list is empty if the file could not
        be opened or yielded no frames.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while cap.isOpened() and (max_frames is None or len(frames) < max_frames):
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
            frames.append(transform(frame))
    finally:
        # Release the decoder even if decoding or preprocessing raises.
        cap.release()
    return frames


def predict_anomaly(video_path, model, device, threshold=0.5, max_frames=None):
    """Classify a whole video as "Anomaly" or "Normal" and print the result.

    All decoded frames (optionally capped by ``max_frames``) form a single clip of shape
    (1, T, C, H, W) that is passed through ``model`` in one forward pass.

    Args:
        video_path: Path of the video file to classify.
        model: Model returning a single probability for a (1, T, C, H, W) input.
        device: Device the clip tensor is moved to before inference.
        threshold: Scores strictly above this value are labelled "Anomaly" (default 0.5).
        max_frames: Optional cap on the number of frames read. Long videos otherwise load
            every frame into memory and run them through the CNN in one batch.

    Returns:
        ``(prediction, score)``, where ``score`` is the model's anomaly probability.
        Returns ``None`` (after printing a warning) if no frames could be extracted.
    """
    frames = _read_video_frames(video_path, max_frames)
    if not frames:
        print("⚠ No frames extracted from the video! Check the file format or path.")
        return None

    # Stack frames to (T, C, H, W), then add the batch dimension -> (1, T, C, H, W).
    video_tensor = torch.stack(frames).unsqueeze(0).to(device)

    # Run inference
    model.eval()
    with torch.no_grad():
        score = model(video_tensor).item()

    prediction = "Anomaly" if score > threshold else "Normal"
    # The score is P(anomaly), not confidence in the predicted label, so it is labelled as such.
    print(f"🚨 Prediction: {prediction} (Anomaly score: {score:.4f})")
    return prediction, score

# Smoke-test the full pipeline on a single new recording.
predict_anomaly(
    video_path=r"D:\DB\Recording 2025-01-21 172723.mp4",
    model=model,
    device=device,
)


🚨 Prediction: Anomaly (Confidence: 0.9993)
