In [1]:
import torch
import cv2
import numpy as np

In [2]:
model = torch.load('SegmentAnalysis.pth')
video = cv2.VideoCapture('../Videos/game_1.mp4')


In [3]:
import torch
import torch.nn as nn

class EventSegmentationLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        # self.lstm.lstm.set_backward_compatible(True) # Disable cuDNN
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):  # x: [batch_size, seq_len, input_dim]
        out, _ = self.lstm(x)
        logits = self.classifier(out)  # [batch_size, seq_len, num_classes]
        return logits

# --- Step 1: Set your model hyperparameters ---
input_dim = 1000     # replace with your actual input dimension
hidden_dim = 128     # replace with your actual hidden dimension
num_classes = 2     # replace with your actual number of classes

# --- Step 2: Initialize the model ---
model = EventSegmentationLSTM(input_dim, hidden_dim, num_classes)

# --- Step 3: Load the state_dict ---
state_dict_path = "SegmentAnalysis.pth"  # replace with your saved file path
state_dict = torch.load(state_dict_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
model = model.cuda() if torch.cuda.is_available() else model
model.load_state_dict(state_dict)

# --- Step 4: Set model to eval mode if needed ---
model.eval()

print("Model loaded successfully!")


Model loaded successfully!


In [4]:
import cv2
import torch
import torchvision.transforms as T
from torchvision.models import resnet18

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model
pre_processing_model = resnet18(pretrained=True).to(device)
pre_processing_model.eval()

# Transform (same as training)
transform = T.Compose([
    T.ToPILImage(),
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

features = []

# Open video
cap = cv2.VideoCapture("../Videos/rallies_01.mp4")
count = 0
while True:
    ret,frame = cap.read()
    if not ret:
        break
    img_tensor = transform(frame).unsqueeze(0).to(device)  # shape [1, 3, 224, 224]
    
    with torch.no_grad():
        output = pre_processing_model(img_tensor).squeeze().cpu().numpy()

    features.append(output)
    print(f"Processed frame {count}")
    count += 1

cap.release()
cv2.destroyAllWindows()
print("Preprocessing of all frames done")




Processed frame 0
Processed frame 1
Processed frame 2
Processed frame 3
Processed frame 4
Processed frame 5
Processed frame 6
Processed frame 7
Processed frame 8
Processed frame 9
Processed frame 10
Processed frame 11
Processed frame 12
Processed frame 13
Processed frame 14
Processed frame 15
Processed frame 16
Processed frame 17
Processed frame 18
Processed frame 19
Processed frame 20
Processed frame 21
Processed frame 22
Processed frame 23
Processed frame 24
Processed frame 25
Processed frame 26
Processed frame 27
Processed frame 28
Processed frame 29
Processed frame 30
Processed frame 31
Processed frame 32
Processed frame 33
Processed frame 34
Processed frame 35
Processed frame 36
Processed frame 37
Processed frame 38
Processed frame 39
Processed frame 40
Processed frame 41
Processed frame 42
Processed frame 43
Processed frame 44
Processed frame 45
Processed frame 46
Processed frame 47
Processed frame 48
Processed frame 49
Processed frame 50
Processed frame 51
Processed frame 52
Pro

In [5]:
from sklearn.preprocessing import OneHotEncoder
ohe = OneHotEncoder()
labels = ["none","play"]
ohe.fit(np.array(labels).reshape(-1, 1))

In [6]:
features_np = np.array(features)
with torch.no_grad():
    full_feat = torch.tensor(features_np, dtype=torch.float32).unsqueeze(0)  # shape: [1, num_frames, feat_dim]
    full_feat = full_feat.cuda() if torch.cuda.is_available() else full_feat

    logits = model(full_feat)  # shape: [1, num_frames, num_classes]

    probs = torch.sigmoid(logits).squeeze(0)  # shape: [num_frames, num_classes]
    pred_labels = torch.argmax(probs, dim=1).cpu().numpy()  # shape: [num_frames]

label_strings = ohe.inverse_transform(np.eye(2)[pred_labels])


In [7]:
label_strings

array([['none'],
       ['none'],
       ['none'],
       ...,
       ['none'],
       ['none'],
       ['none']], dtype='<U4')

In [8]:
unique = set()
for val in label_strings:
  unique.add(val[0])
print(unique)
print(len(label_strings))

{'play', 'none'}
15435


In [9]:
def show_results(video_path,results):
    video = cv2.VideoCapture(video_path)
    count=0
    while True:
        ret,frame = video.read()
        if not ret:
            break
        
        cv2.putText(frame,results[count][0],(50,50),cv2.FONT_HERSHEY_SIMPLEX,1,(0,255,0),2)
        cv2.imshow("Frame",frame)
        if cv2.waitKey(25) & 0xFF == ord('q'):
            break

        count += 1
        if count >= len(results):
            break
    
    video.release()
    cv2.destroyAllWindows()

In [10]:
def smooth_output(labels, window_size=30):
    smoothed = []
    half_window = window_size // 2
    padded_labels = ['play'] * half_window + label_strings.tolist() + ['play'] * half_window

    for i in range(half_window, len(padded_labels) - half_window):
        window = padded_labels[i - half_window:i + half_window + 1]
        most_common = max(window, key=window.count)
        smoothed.append(most_common)

    return smoothed

In [13]:
smoothed_output = smooth_output(label_strings, window_size=30)
show_results("../Videos/rallies_01.mp4",smoothed_output)