In [2]:
import torch
import torch.nn as nn
from torchvision import models
import cv2
import mediapipe as mp
import torchvision.transforms as transforms
import PoseModule as pm




In [3]:


class CustomInceptionV3(nn.Module):
    """Inception v3 backbone whose classifier heads are resized to ``num_classes``.

    The torchvision network is created with ``pretrained=False`` and stored as
    ``self.model``, so every parameter of this wrapper is named ``model.<...>``.

    Args:
        num_classes: Number of output units of the replaced ``fc`` layer (and of
            the auxiliary ``fc`` layer when ``aux_logits`` is true).
        aux_logits: Forwarded to ``models.inception_v3``; when true, the
            auxiliary classifier's final layer is replaced as well.
    """

    def __init__(self, num_classes=2, aux_logits=False):
        super().__init__()
        backbone = models.inception_v3(pretrained=False, aux_logits=aux_logits)

        # Replace the main classification layer, keeping its input width.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)

        # The auxiliary head is only touched when it was requested.
        if aux_logits:
            aux_head = backbone.AuxLogits
            aux_head.fc = nn.Linear(aux_head.fc.in_features, num_classes)

        self.model = backbone

    def forward(self, x):
        """Return whatever the wrapped Inception network produces for ``x``."""
        output = self.model(x)
        return output


In [4]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# NOTE(review): absolute, machine-specific path -- this cell only runs where that
# file exists. Depending on the torch version, torch.load may unpickle arbitrary
# objects, so only load checkpoints from a trusted source.
state_dict = torch.load(
    'C:/Users/20109/OneDrive/Desktop/New-Coach/model.pth',
    map_location=device,
)

# List the parameter names stored in the checkpoint (one per line) so their
# prefix can be compared with the names the model expects.
for key in state_dict:
    print(key)


inception.Conv2d_1a_3x3.conv.weight
inception.Conv2d_1a_3x3.bn.weight
inception.Conv2d_1a_3x3.bn.bias
inception.Conv2d_1a_3x3.bn.running_mean
inception.Conv2d_1a_3x3.bn.running_var
inception.Conv2d_1a_3x3.bn.num_batches_tracked
inception.Conv2d_2a_3x3.conv.weight
inception.Conv2d_2a_3x3.bn.weight
inception.Conv2d_2a_3x3.bn.bias
inception.Conv2d_2a_3x3.bn.running_mean
inception.Conv2d_2a_3x3.bn.running_var
inception.Conv2d_2a_3x3.bn.num_batches_tracked
inception.Conv2d_2b_3x3.conv.weight
inception.Conv2d_2b_3x3.bn.weight
inception.Conv2d_2b_3x3.bn.bias
inception.Conv2d_2b_3x3.bn.running_mean
inception.Conv2d_2b_3x3.bn.running_var
inception.Conv2d_2b_3x3.bn.num_batches_tracked
inception.Conv2d_3b_1x1.conv.weight
inception.Conv2d_3b_1x1.bn.weight
inception.Conv2d_3b_1x1.bn.bias
inception.Conv2d_3b_1x1.bn.running_mean
inception.Conv2d_3b_1x1.bn.running_var
inception.Conv2d_3b_1x1.bn.num_batches_tracked
inception.Conv2d_4a_3x3.conv.weight
inception.Conv2d_4a_3x3.bn.weight
inception.Conv2d_4

In [5]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# aux_logits=True also builds (and resizes) the auxiliary classifier head.
model = CustomInceptionV3(num_classes=2, aux_logits=True)
model = model.to(device)

# CustomInceptionV3 registers its backbone as `self.model`, so every parameter
# it expects is named "model.<layer>...". The checkpoint printed above names the
# same layers "inception.<layer>...". Normalise both spellings (and bare layer
# names) to the "model." prefix; merely prepending "model." would produce
# "model.inception.<layer>...", which matches nothing and -- with strict=False --
# would silently leave the network at its random initialisation.
new_state_dict = {}
for key, value in state_dict.items():
    for prefix in ('inception.', 'model.'):
        if key.startswith(prefix):
            key = key[len(prefix):]
            break
    new_state_dict['model.' + key] = value

# strict=False tolerates heads that exist on only one side (e.g. AuxLogits),
# but the result is inspected so a naming mismatch can no longer go unnoticed.
load_result = model.load_state_dict(new_state_dict, strict=False)
if new_state_dict and len(load_result.unexpected_keys) == len(new_state_dict):
    raise RuntimeError(
        "No checkpoint key matched the model's parameter names; "
        "the weights were not loaded. Example checkpoint key after remapping: "
        + repr(next(iter(new_state_dict)))
    )
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"Checkpoint loaded with {len(load_result.missing_keys)} missing and "
        f"{len(load_result.unexpected_keys)} unexpected key(s)."
    )
    print("  missing (first 5):   ", load_result.missing_keys[:5])
    print("  unexpected (first 5):", load_result.unexpected_keys[:5])

# Inference mode: disables dropout and uses the BatchNorm running statistics.
model.eval()




CustomInceptionV3(
  (model): Inception3(
    (Conv2d_1a_3x3): BasicConv2d(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_2a_3x3): BasicConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_2b_3x3): BasicConv2d(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Conv2d_3b_1x1): BasicConv2d(
      (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_4a_3x3):

In [12]:
import cv2
import torch
from torchvision import transforms
import PoseModule as pm  
# Display labels, looked up by the class index that the model predicts.
# NOTE(review): the order is assumed to match the label encoding used when the
# model was trained -- confirm against the training code.
class_names = ['Correct', 'Incorrect']

def predict_video(model, video_path, convert_to_rgb=True):
    """Classify the frames of a video and show them in an OpenCV window.

    Every frame is passed through the pose detector. When the detector reports
    a non-empty landmark list, the frame is resized to 299x299, normalised and
    classified by ``model``; the predicted entry of ``class_names`` is drawn in
    the top-left corner. Playback stops at the end of the video or when ``q``
    is pressed. The capture and the window are released even if an error
    occurs while processing.

    Args:
        model: Classifier called on a ``(1, 3, 299, 299)`` float tensor. Its
            output is reduced with ``torch.max(outputs, 1)``, so the class
            scores must be on dimension 1. It should already be in eval mode.
        video_path: Source handed to ``cv2.VideoCapture``.
        convert_to_rgb: OpenCV decodes frames in BGR order, while
            ``ToPILImage`` and the normalisation constants used here assume
            RGB. When true (default) the frame is converted to RGB before
            preprocessing. Pass ``False`` only if the model was trained on
            BGR-ordered images. NOTE(review): the training pipeline is not
            visible here -- confirm which order it used.

    Returns:
        None. If the video cannot be opened a message is printed and the
        function returns without displaying anything.
    """
    # Send inputs to the device the model's weights live on instead of relying
    # on a notebook-global `device` that may be stale or undefined.
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = torch.device("cpu")

    preprocess = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    detector = pm.poseDetector()
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Could not open video: {video_path}")
            return

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # End of stream or read failure.
                break

            # The detector's returned frame replaces the original and is the
            # image that gets classified (PoseModule is not visible here;
            # presumably it draws the detected pose onto the frame).
            frame = detector.findPose(frame)
            lmlist = detector.findPosition(frame, draw=False)

            # Only classify frames in which the detector found landmarks.
            if lmlist:
                if convert_to_rgb:
                    model_input = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                else:
                    model_input = frame
                img_tensor = preprocess(model_input).unsqueeze(0).to(target_device)

                with torch.no_grad():
                    outputs = model(img_tensor)
                    _, preds = torch.max(outputs, 1)
                label = class_names[preds.item()]

                # Draw on the original BGR frame, which is what imshow expects.
                cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            cv2.imshow('Video', frame)

            # Let the window refresh and allow the user to quit with 'q'.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()

# Run the frame-by-frame classification on a local test clip.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
predict_video(model, video_path='C:/Users/20109/OneDrive/Desktop/New-Coach2/Test/t2.mp4')
