In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.cuda.amp import autocast, GradScaler
from collections import deque
import pandas as pd
import numpy as np
from PIL import Image
import random
import glob
import cv2
import os
import torchvision
import ast  

In [2]:
# Show which device PyTorch will use: CUDA when a GPU is visible, otherwise CPU.
# The bare expression on the last line is what Jupyter renders as the cell output.
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

device(type='cuda')

In [3]:
class CNN_LSTM(nn.Module):
    """ResNet-18 frame encoder followed by an LSTM over the frame sequence.

    Each frame of a clip is encoded independently by a ResNet-18 backbone whose
    classification head is replaced by an identity. The per-frame feature vectors
    are fed to an LSTM in temporal order, and the LSTM output at the last time
    step is mapped to class logits by a linear layer.

    Args:
        num_classes: Number of output classes (size of the logit vector).
        hidden_size: LSTM hidden-state size.
        num_layers: Number of stacked LSTM layers.
        pretrained: If True (the default, matching the original behaviour), the
            backbone starts from torchvision's ImageNet weights, which may be
            downloaded. Pass False when a full state dict is loaded afterwards
            anyway (e.g. for inference from a checkpoint).
    """

    def __init__(self, num_classes, hidden_size=256, num_layers=1, pretrained=True):
        super(CNN_LSTM, self).__init__()

        if not pretrained:
            resnet = models.resnet18()
        elif hasattr(models, "ResNet18_Weights"):
            # torchvision >= 0.13: the `weights` enum replaces the deprecated
            # `pretrained=True` flag (DEFAULT is the same ImageNet checkpoint).
            resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        else:
            # Older torchvision only understands the legacy flag.
            resnet = models.resnet18(pretrained=True)

        # Read the pooled-feature width (512 for ResNet-18) from the original head
        # before discarding it, so the LSTM input size cannot drift from the backbone.
        feature_dim = resnet.fc.in_features
        resnet.fc = nn.Identity()
        self.cnn = resnet

        self.lstm = nn.LSTM(input_size=feature_dim, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)

        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: Tensor of shape (batch, seq_len, C, H, W). C must match the
                backbone's input channels (3 for ResNet-18).

        Returns:
            Logits of shape (batch, num_classes), taken from the LSTM output at
            the last time step.
        """
        batch_size, seq_len, C, H, W = x.shape

        # Fold time into the batch so the CNN sees (batch*seq_len, C, H, W).
        # `reshape` (unlike `view`) also accepts non-contiguous inputs, e.g. after a permute.
        frames = x.reshape(batch_size * seq_len, C, H, W)

        features = self.cnn(frames)

        # Unfold back to (batch, seq_len, feature_dim) for the batch_first LSTM.
        features = features.reshape(batch_size, seq_len, -1)

        lstm_out, _ = self.lstm(features)

        # Only the final time step's hidden output is used for classification.
        return self.fc(lstm_out[:, -1, :])

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from a trusted source (newer PyTorch supports weights_only=True).
checkpoint = torch.load("dog_pose_model_final.pth", map_location=device)

# Rebuild the architecture and restore the trained weights; the checkpoint is read
# as a dict that holds the model weights under 'model_state_dict'.
model = CNN_LSTM(num_classes=3).to(device)  # already on `device`; no second .to() needed
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference mode: dropout off, BatchNorm uses running statistics

# Per-frame preprocessing for RGB numpy frames coming from OpenCV.
# NOTE(review): this must match the preprocessing used during training — confirm.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),        # HxWxC numpy array -> PIL image
        transforms.Resize((224, 224)),  # fixed spatial size fed to the CNN
        transforms.ToTensor(),          # PIL image -> float CxHxW tensor in [0, 1]
        transforms.Normalize(mean=[0.5], std=[0.5]),  # one mean/std for all channels: [0, 1] -> [-1, 1]
    ]
)

# Length of the sliding window of frames the model classifies at once.
NUM_FRAMES = 8

# Model output index -> human-readable class name.
# NOTE(review): index order must match the label encoding used in training — confirm.
class_labels = {
    2: "walking",
    1: "standing+sitting",
    0: "resting",
}

def predict_from_frames(frames):
    """Classify one window of preprocessed frames with the global `model`.

    Args:
        frames: Sequence (e.g. list or deque) of per-frame tensors of identical
            shape; they are stacked along a new time axis and run as a single clip.

    Returns:
        (label, probs): `label` is the name of the highest-probability class and
        `probs` maps every class name to its softmax probability in percent.
        Returns (None, None) when fewer than NUM_FRAMES frames are supplied.
    """
    if len(frames) < NUM_FRAMES:
        return None, None

    # Stack to (T, ...) and add a leading batch axis of size 1: (1, T, ...).
    frames_tensor = torch.stack(list(frames)).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(frames_tensor)
        probabilities = F.softmax(output, dim=1).squeeze(0)

    # Copy to the host once instead of syncing the device for every `.item()` call.
    probabilities_cpu = probabilities.cpu()
    predicted_class = int(torch.argmax(probabilities_cpu))
    prob_values = probabilities_cpu.tolist()
    probs_dict = {class_labels[i]: prob_values[i] * 100 for i in range(len(class_labels))}

    return class_labels[predicted_class], probs_dict

def _draw_predictions(frame, predicted_label, probs):
    """Overlay the current prediction and per-class percentages on `frame` in place.

    Args:
        frame: Image returned by cv2.VideoCapture.read(); modified in place.
        predicted_label: Text shown on the headline ("Prediction: ...").
        probs: Mapping of class name -> probability in percent, one line each.
    """
    cv2.putText(frame, f"Prediction: {predicted_label}", (50, 50),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)

    # Per-class lines start at y=90 and are spaced 40 px apart.
    for row, (label, prob) in enumerate(probs.items()):
        cv2.putText(frame, f"{label}: {prob:.2f}%", (50, 90 + 40 * row),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2, cv2.LINE_AA)


def show_video_with_live_predictions(video_path):
    """Play a video in an OpenCV window with live model predictions overlaid.

    Each frame is preprocessed with `transform` and kept in a sliding window of
    the most recent NUM_FRAMES frames. Once the window is full, every new frame
    triggers a fresh prediction via `predict_from_frames`. Until then the overlay
    shows "Loading..." with 0% for every class. Press 'q' to stop early.

    Args:
        video_path: Passed straight to cv2.VideoCapture (typically a file path).

    Raises:
        IOError: if OpenCV cannot open `video_path` (previously this returned
            silently without showing anything).
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open video source: {video_path!r}")

        frames_queue = deque(maxlen=NUM_FRAMES)  # oldest frame drops out automatically

        predicted_label = "Loading..."
        probs = {class_labels[i]: 0.0 for i in range(len(class_labels))}

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:  # end of stream or read failure
                break

            # OpenCV decodes frames as BGR; the preprocessing pipeline is fed RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames_queue.append(transform(frame_rgb))

            if len(frames_queue) == NUM_FRAMES:
                predicted_label, probs = predict_from_frames(frames_queue)

            _draw_predictions(frame, predicted_label, probs)
            cv2.imshow("Video Prediction", frame)

            # Mask to the low byte: waitKey can return extra high bits on some platforms.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always free the capture and close the window, even if inference raises or
        # the kernel is interrupted; otherwise the OpenCV window can hang.
        cap.release()
        cv2.destroyAllWindows()

# Build the path portably: the previous backslash literal contained invalid escape
# sequences ("\A", "\P") that Python warns about, and it only worked on Windows.
# os.path.join yields the same string on Windows and a valid path on POSIX.
show_video_with_live_predictions(
    os.path.join("Test_videos", "Another Five Test Videos", "PR0038747VXYF_20230809000001.mp4")
)