In [6]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.cuda.amp import autocast, GradScaler
from collections import deque
import pandas as pd
import numpy as np
from PIL import Image
import random
import glob
import cv2
import os
import torchvision
import ast  

In [7]:
# Pick the compute device once and display it as the cell output. The bare
# `torch.cuda.is_available()` that used to precede this was discarded, because
# only a cell's last expression is displayed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [8]:
class CNN_LSTM(nn.Module):
    """Clip classifier: per-frame ResNet-18 embeddings followed by an LSTM.

    Every frame of a clip is embedded independently by a torchvision
    ResNet-18 built with ``pretrained=True``. Its ``fc`` head is replaced by an
    identity, so it emits the 512-d pooled features. The sequence of
    embeddings goes through an LSTM, and the output at the final time step is
    projected to class logits.

    Args:
        num_classes: Number of output classes (size of the final linear layer).
        hidden_size: LSTM hidden-state size.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, num_classes, hidden_size=256, num_layers=1):
        super().__init__()

        backbone = models.resnet18(pretrained=True)
        # Strip the classification head so the backbone returns raw 512-d features.
        backbone.fc = nn.Identity()
        self.cnn = backbone

        self.lstm = nn.LSTM(
            input_size=512,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for x of shape (batch, seq_len, C, H, W)."""
        batch, steps, channels, height, width = x.shape

        # Fold time into the batch axis so the CNN sees plain (N, C, H, W) images,
        # then unfold back to (batch, seq_len, features) for the LSTM.
        per_frame = self.cnn(x.view(batch * steps, channels, height, width))
        sequence_out, _ = self.lstm(per_frame.view(batch, steps, -1))

        # Classify using only the final time step's output.
        return self.fc(sequence_out[:, -1, :])

In [34]:
# Inference setup: device, trained model, per-frame preprocessing and decision thresholds.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model output index -> display name. The model's class count is derived from it.
# NOTE(review): presumably this order must match the labels used in training — confirm.
class_labels = {0: "Alert", 1: "Sniff bin 1", 2: "Sniff bin 2", 3: "Sniff bin 3"}

# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load("sniff_detection_3.pth", map_location=device)
model = CNN_LSTM(num_classes=len(class_labels))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device).eval()

# Per-frame preprocessing applied to RGB numpy frames.
# The single-element mean/std broadcast over all channels, mapping [0, 1] to [-1, 1].
# NOTE(review): presumably mirrors the training-time transform — confirm.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

num_frames = 8                       # sliding-window length: frames per prediction
general_confidence_threshold = 0.85  # minimum confidence to display a non-"Alert" label
alert_confidence_threshold = 0.992   # stricter minimum for the "Alert" label

video_path_test = 'Test set sniffing/IMG_3392.MOV'

def predict_from_frames(frames_deque):
    """Classify a window of preprocessed frames with the global ``model``.

    Args:
        frames_deque: Sequence of per-frame tensors, each (C, H, W) as produced
            by ``transform``, oldest first.

    Returns:
        ``(label, confidence)``. ``label`` is the ``class_labels`` entry of the
        most probable class, and ``confidence`` is its softmax probability as a
        Python float. Returns ``(None, 0.0)`` when fewer than ``num_frames``
        frames are supplied.
    """
    if len(frames_deque) < num_frames:
        return None, 0.0

    # Stack into (seq_len, C, H, W), then add a leading batch axis of size 1.
    clip = torch.stack(list(frames_deque)).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(clip)
        class_probs = F.softmax(logits, dim=1).squeeze(0)

    top_prob, top_index = class_probs.max(dim=0)
    return class_labels[top_index.item()], top_prob.item()

def show_video_with_live_predictions(video_path):
    """Play a video in an OpenCV window with a live prediction overlaid on each frame.

    Each frame is converted to RGB, preprocessed with ``transform`` and pushed
    into a sliding window of the last ``num_frames`` frames. Once the window is
    full, every new frame triggers ``predict_from_frames`` on the window.

    A prediction is displayed only if its confidence reaches the label's
    threshold: ``alert_confidence_threshold`` for "Alert",
    ``general_confidence_threshold`` otherwise. Below the threshold, "No Action"
    is shown. The video's file name is drawn after the label. Press 'q' in the
    window to stop early.

    Args:
        video_path: Path to a video file readable by ``cv2.VideoCapture``.

    Returns:
        None. If the video cannot be opened, an error is printed and the
        function returns immediately.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video at {video_path}")
        cap.release()
        return

    # Derived from the argument (not a global) so the overlay names the video being played.
    video_filename = os.path.basename(video_path)
    frames_queue = deque(maxlen=num_frames)
    display_label = "Initializing..."
    display_confidence = 0.0

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes frames as BGR; the preprocessing is applied to RGB frames.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames_queue.append(transform(frame_rgb))

            if len(frames_queue) == num_frames:
                predicted_label, confidence = predict_from_frames(frames_queue)
                threshold = (
                    alert_confidence_threshold
                    if predicted_label == "Alert"
                    else general_confidence_threshold
                )
                if confidence >= threshold:
                    display_label, display_confidence = predicted_label, confidence
                else:
                    display_label, display_confidence = "No Action", 0.0

            if display_label in ("Initializing...", "No Action"):
                prediction_text = f"Prediction: {display_label}        {video_filename}"
            else:
                prediction_text = f"Prediction: {display_label} ({display_confidence:.2%})      {video_filename}"

            # A thick black pass under a thin white pass gives outlined text
            # that stays readable on light and dark backgrounds.
            cv2.putText(frame, prediction_text, (20, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3, cv2.LINE_AA)
            cv2.putText(frame, prediction_text, (20, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)

            cv2.imshow("Sniffing Detection", frame)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the capture and close the window even if inference raises mid-stream.
        cap.release()
        cv2.destroyAllWindows()

# Run the live demo on the configured test clip (press 'q' in the window to stop).
show_video_with_live_predictions(video_path=video_path_test)