In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from pycocotools.coco import COCO
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
from models import PoseNet
from utils import CocoKeypoints, VideoCreator, convert_to_modern_mp4, detect_keypoints

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Root of the COCO 2017 copy mounted as a Kaggle input dataset.
coco_root = '/kaggle/input/coco-2017-dataset/coco2017'
train_img_dir, train_ann_file = (
    os.path.join(coco_root, 'train2017'),
    os.path.join(coco_root, 'annotations/person_keypoints_train2017.json'),
)

# Person-keypoint training split, served in shuffled mini-batches of 32.
dataset = CocoKeypoints(train_img_dir, train_ann_file)
train_loader = DataLoader(dataset=dataset, shuffle=True, batch_size=32)

In [None]:
from IPython.display import FileLink

def train(model, train_loader, epochs=10, lr=0.001, device='cuda',
          log_interval=50, save_dir='', show_download_link=True):
    """Train ``model`` with Adam on an MSE objective, checkpointing every epoch.

    Args:
        model: ``nn.Module`` whose output for an image batch is compared against
            the target batch with ``nn.MSELoss``.
        train_loader: iterable (e.g. a ``DataLoader``) yielding
            ``(images, heatmaps)`` tensor pairs; the model output must be
            MSE-compatible with ``heatmaps``.
        epochs: number of full passes over ``train_loader``.
        lr: Adam learning rate.
        device: device the model and every batch are moved to.
        log_interval: print the mean loss over every ``log_interval``
            consecutive batches. Must be >= 1.
        save_dir: directory for the ``model_epoch_{n}.pth`` checkpoints; created
            if missing. The default ``''`` writes to the current directory.
        show_download_link: when True, render an IPython ``FileLink`` for each
            checkpoint (requires IPython; pass False outside a notebook).

    Returns:
        The same ``model`` object, trained in place and left on ``device``.

    Raises:
        ValueError: if ``log_interval`` is less than 1.
    """
    if log_interval < 1:
        raise ValueError(f'log_interval must be >= 1, got {log_interval}')
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0  # loss sum over the current logging window
        epoch_loss = 0.0    # loss sum over the whole epoch
        num_batches = 0

        for i, (images, heatmaps) in enumerate(train_loader):
            images = images.to(device)
            heatmaps = heatmaps.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, heatmaps)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            epoch_loss += batch_loss
            num_batches += 1
            if (i + 1) % log_interval == 0:
                print(f'Epoch {epoch+1}, Batch {i+1}: Loss {running_loss/log_interval:.4f}')
                running_loss = 0.0

        # The windowed log above never reports the trailing batches that do not
        # fill a full window, so also report the mean over the entire epoch.
        if num_batches:
            print(f'Epoch {epoch+1} mean loss: {epoch_loss/num_batches:.4f}')

        # Save the model after each epoch
        model_path = os.path.join(save_dir, f'model_epoch_{epoch+1}.pth')
        torch.save(model.state_dict(), model_path)
        print(f'Model saved: {model_path}')

        if show_download_link:
            # Imported here so the function does not depend on notebook globals
            # (or on `display` being an IPython builtin) when links are disabled.
            from IPython.display import FileLink, display
            display(FileLink(model_path))

        print(f'Epoch {epoch+1} completed')

    return model

In [16]:
# Bind the network before training so `model` still refers to it even if the
# training call below is interrupted part-way through.
model = PoseNet()

# Prefer the GPU whenever PyTorch can see one; otherwise train on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

model = train(model, train_loader, epochs=10, device=device)

Epoch 1, Batch 50: Loss 2.1402
Epoch 1, Batch 100: Loss 0.0225
Epoch 1, Batch 150: Loss 0.0114
Epoch 1, Batch 200: Loss 0.0071
Epoch 1, Batch 250: Loss 0.0053
Epoch 1, Batch 300: Loss 0.0047
Epoch 1, Batch 350: Loss 0.0048
Epoch 1, Batch 400: Loss 0.0045
Epoch 1, Batch 450: Loss 0.0038
Epoch 1, Batch 500: Loss 0.0036
Epoch 1, Batch 550: Loss 0.0034
Epoch 1, Batch 600: Loss 0.0032
Epoch 1, Batch 650: Loss 0.0036
Epoch 1, Batch 700: Loss 0.0034
Epoch 1, Batch 750: Loss 0.0032
Epoch 1, Batch 800: Loss 0.0032
Epoch 1, Batch 850: Loss 0.0030
Epoch 1, Batch 900: Loss 0.0031
Epoch 1, Batch 950: Loss 0.0030
Epoch 1, Batch 1000: Loss 0.0030
Epoch 1, Batch 1050: Loss 0.0032
Epoch 1, Batch 1100: Loss 0.0032
Epoch 1, Batch 1150: Loss 0.0029
Epoch 1, Batch 1200: Loss 0.0029
Epoch 1, Batch 1250: Loss 0.0027
Epoch 1, Batch 1300: Loss 0.0028
Epoch 1, Batch 1350: Loss 0.0028
Epoch 1, Batch 1400: Loss 0.0027
Epoch 1, Batch 1450: Loss 0.0027
Epoch 1, Batch 1500: Loss 0.0027
Epoch 1, Batch 1550: Loss 0.00

Epoch 1 completed
Epoch 2, Batch 50: Loss 0.0026
Epoch 2, Batch 100: Loss 0.0028
Epoch 2, Batch 150: Loss 0.0027
Epoch 2, Batch 200: Loss 0.0026
Epoch 2, Batch 250: Loss 0.0026
Epoch 2, Batch 300: Loss 0.0027
Epoch 2, Batch 350: Loss 0.0026
Epoch 2, Batch 400: Loss 0.0027
Epoch 2, Batch 450: Loss 0.0026
Epoch 2, Batch 500: Loss 0.0027
Epoch 2, Batch 550: Loss 0.0028
Epoch 2, Batch 600: Loss 0.0029
Epoch 2, Batch 650: Loss 0.0027
Epoch 2, Batch 700: Loss 0.0028
Epoch 2, Batch 750: Loss 0.0027
Epoch 2, Batch 800: Loss 0.0027
Epoch 2, Batch 850: Loss 0.0027
Epoch 2, Batch 900: Loss 0.0026
Epoch 2, Batch 950: Loss 0.0027
Epoch 2, Batch 1000: Loss 0.0026
Epoch 2, Batch 1050: Loss 0.0026
Epoch 2, Batch 1100: Loss 0.0026
Epoch 2, Batch 1150: Loss 0.0026
Epoch 2, Batch 1200: Loss 0.0027
Epoch 2, Batch 1250: Loss 0.0026
Epoch 2, Batch 1300: Loss 0.0026
Epoch 2, Batch 1350: Loss 0.0026
Epoch 2, Batch 1400: Loss 0.0031
Epoch 2, Batch 1450: Loss 0.0028
Epoch 2, Batch 1500: Loss 0.0028
Epoch 2, Bat

Epoch 2 completed
Epoch 3, Batch 50: Loss 0.0031
Epoch 3, Batch 100: Loss 0.0031
Epoch 3, Batch 150: Loss 0.0029
Epoch 3, Batch 200: Loss 0.0034
Epoch 3, Batch 250: Loss 0.0030
Epoch 3, Batch 300: Loss 0.0031
Epoch 3, Batch 350: Loss 0.0030
Epoch 3, Batch 400: Loss 0.0031
Epoch 3, Batch 450: Loss 0.0032
Epoch 3, Batch 500: Loss 0.0028
Epoch 3, Batch 550: Loss 0.0032
Epoch 3, Batch 600: Loss 0.0031
Epoch 3, Batch 650: Loss 0.0031
Epoch 3, Batch 700: Loss 0.0029
Epoch 3, Batch 750: Loss 0.0034
Epoch 3, Batch 800: Loss 0.0032
Epoch 3, Batch 850: Loss 0.0031
Epoch 3, Batch 900: Loss 0.0031
Epoch 3, Batch 950: Loss 0.0030
Epoch 3, Batch 1000: Loss 0.0031
Epoch 3, Batch 1050: Loss 0.0030
Epoch 3, Batch 1100: Loss 0.0030
Epoch 3, Batch 1150: Loss 0.0034
Epoch 3, Batch 1200: Loss 0.0030
Epoch 3, Batch 1250: Loss 0.0030
Epoch 3, Batch 1300: Loss 0.0028
Epoch 3, Batch 1350: Loss 0.0033
Epoch 3, Batch 1400: Loss 0.0029
Epoch 3, Batch 1450: Loss 0.0029
Epoch 3, Batch 1500: Loss 0.0032
Epoch 3, Bat

Epoch 3 completed
Epoch 4, Batch 50: Loss 0.0030
Epoch 4, Batch 100: Loss 0.0030
Epoch 4, Batch 150: Loss 0.0029
Epoch 4, Batch 200: Loss 0.0029
Epoch 4, Batch 250: Loss 0.0031
Epoch 4, Batch 300: Loss 0.0033
Epoch 4, Batch 350: Loss 0.0030
Epoch 4, Batch 400: Loss 0.0028
Epoch 4, Batch 450: Loss 0.0027
Epoch 4, Batch 500: Loss 0.0033
Epoch 4, Batch 550: Loss 0.0027
Epoch 4, Batch 600: Loss 0.0029
Epoch 4, Batch 650: Loss 0.0028
Epoch 4, Batch 700: Loss 0.0030
Epoch 4, Batch 750: Loss 0.0029
Epoch 4, Batch 800: Loss 0.0030
Epoch 4, Batch 850: Loss 0.0029
Epoch 4, Batch 900: Loss 0.0027
Epoch 4, Batch 950: Loss 0.0031
Epoch 4, Batch 1000: Loss 0.0028
Epoch 4, Batch 1050: Loss 0.0029
Epoch 4, Batch 1100: Loss 0.0029
Epoch 4, Batch 1150: Loss 0.0028
Epoch 4, Batch 1200: Loss 0.0029
Epoch 4, Batch 1250: Loss 0.0029
Epoch 4, Batch 1300: Loss 0.0029
Epoch 4, Batch 1350: Loss 0.0028
Epoch 4, Batch 1400: Loss 0.0028
Epoch 4, Batch 1450: Loss 0.0029
Epoch 4, Batch 1500: Loss 0.0029
Epoch 4, Bat

Epoch 4 completed
Epoch 5, Batch 50: Loss 0.0028
Epoch 5, Batch 100: Loss 0.0027
Epoch 5, Batch 150: Loss 0.0029
Epoch 5, Batch 200: Loss 0.0029
Epoch 5, Batch 250: Loss 0.0027
Epoch 5, Batch 300: Loss 0.0028
Epoch 5, Batch 350: Loss 0.0027
Epoch 5, Batch 400: Loss 0.0028
Epoch 5, Batch 450: Loss 0.0027
Epoch 5, Batch 500: Loss 0.0026
Epoch 5, Batch 550: Loss 0.0028
Epoch 5, Batch 600: Loss 0.0029
Epoch 5, Batch 650: Loss 0.0028
Epoch 5, Batch 700: Loss 0.0028
Epoch 5, Batch 750: Loss 0.0028
Epoch 5, Batch 800: Loss 0.0027
Epoch 5, Batch 850: Loss 0.0029
Epoch 5, Batch 900: Loss 0.0027
Epoch 5, Batch 950: Loss 0.0028
Epoch 5, Batch 1000: Loss 0.0026
Epoch 5, Batch 1050: Loss 0.0026
Epoch 5, Batch 1100: Loss 0.0026
Epoch 5, Batch 1150: Loss 0.0028
Epoch 5, Batch 1200: Loss 0.0027
Epoch 5, Batch 1250: Loss 0.0026
Epoch 5, Batch 1300: Loss 0.0027
Epoch 5, Batch 1350: Loss 0.0026
Epoch 5, Batch 1400: Loss 0.0027
Epoch 5, Batch 1450: Loss 0.0027
Epoch 5, Batch 1500: Loss 0.0027
Epoch 5, Bat

Epoch 5 completed
Epoch 6, Batch 50: Loss 0.0027
Epoch 6, Batch 100: Loss 0.0026
Epoch 6, Batch 150: Loss 0.0026
Epoch 6, Batch 200: Loss 0.0026
Epoch 6, Batch 250: Loss 0.0027
Epoch 6, Batch 300: Loss 0.0026
Epoch 6, Batch 350: Loss 0.0027
Epoch 6, Batch 400: Loss 0.0027
Epoch 6, Batch 450: Loss 0.0026
Epoch 6, Batch 500: Loss 0.0027
Epoch 6, Batch 550: Loss 0.0027
Epoch 6, Batch 600: Loss 0.0028
Epoch 6, Batch 650: Loss 0.0027
Epoch 6, Batch 700: Loss 0.0027
Epoch 6, Batch 750: Loss 0.0027
Epoch 6, Batch 800: Loss 0.0027
Epoch 6, Batch 850: Loss 0.0027
Epoch 6, Batch 900: Loss 0.0027
Epoch 6, Batch 950: Loss 0.0026
Epoch 6, Batch 1000: Loss 0.0027
Epoch 6, Batch 1050: Loss 0.0027
Epoch 6, Batch 1100: Loss 0.0026
Epoch 6, Batch 1150: Loss 0.0026
Epoch 6, Batch 1200: Loss 0.0026
Epoch 6, Batch 1250: Loss 0.0026
Epoch 6, Batch 1300: Loss 0.0026
Epoch 6, Batch 1350: Loss 0.0026
Epoch 6, Batch 1400: Loss 0.0026
Epoch 6, Batch 1450: Loss 0.0026
Epoch 6, Batch 1500: Loss 0.0026
Epoch 6, Bat

Epoch 6 completed
Epoch 7, Batch 50: Loss 0.0027
Epoch 7, Batch 100: Loss 0.0025
Epoch 7, Batch 150: Loss 0.0026
Epoch 7, Batch 200: Loss 0.0027
Epoch 7, Batch 250: Loss 0.0026
Epoch 7, Batch 300: Loss 0.0026
Epoch 7, Batch 350: Loss 0.0026
Epoch 7, Batch 400: Loss 0.0026
Epoch 7, Batch 450: Loss 0.0026
Epoch 7, Batch 500: Loss 0.0025
Epoch 7, Batch 550: Loss 0.0026
Epoch 7, Batch 600: Loss 0.0026
Epoch 7, Batch 650: Loss 0.0026
Epoch 7, Batch 700: Loss 0.0026
Epoch 7, Batch 750: Loss 0.0026
Epoch 7, Batch 800: Loss 0.0026
Epoch 7, Batch 850: Loss 0.0026
Epoch 7, Batch 900: Loss 0.0026
Epoch 7, Batch 950: Loss 0.0025
Epoch 7, Batch 1000: Loss 0.0026
Epoch 7, Batch 1050: Loss 0.0026
Epoch 7, Batch 1100: Loss 0.0026
Epoch 7, Batch 1150: Loss 0.0025
Epoch 7, Batch 1200: Loss 0.0026
Epoch 7, Batch 1250: Loss 0.0026
Epoch 7, Batch 1300: Loss 0.0025
Epoch 7, Batch 1350: Loss 0.0026
Epoch 7, Batch 1400: Loss 0.0025
Epoch 7, Batch 1450: Loss 0.0026
Epoch 7, Batch 1500: Loss 0.0026
Epoch 7, Bat

Epoch 7 completed
Epoch 8, Batch 50: Loss 0.0025
Epoch 8, Batch 100: Loss 0.0026
Epoch 8, Batch 150: Loss 0.0025
Epoch 8, Batch 200: Loss 0.0025
Epoch 8, Batch 250: Loss 0.0025
Epoch 8, Batch 300: Loss 0.0025
Epoch 8, Batch 350: Loss 0.0025
Epoch 8, Batch 400: Loss 0.0025
Epoch 8, Batch 450: Loss 0.0026
Epoch 8, Batch 500: Loss 0.0026
Epoch 8, Batch 550: Loss 0.0025
Epoch 8, Batch 600: Loss 0.0025
Epoch 8, Batch 650: Loss 0.0025
Epoch 8, Batch 700: Loss 0.0025
Epoch 8, Batch 750: Loss 0.0025
Epoch 8, Batch 800: Loss 0.0026
Epoch 8, Batch 850: Loss 0.0026
Epoch 8, Batch 900: Loss 0.0026
Epoch 8, Batch 950: Loss 0.0025
Epoch 8, Batch 1000: Loss 0.0026
Epoch 8, Batch 1050: Loss 0.0025
Epoch 8, Batch 1100: Loss 0.0025
Epoch 8, Batch 1150: Loss 0.0025
Epoch 8, Batch 1200: Loss 0.0026
Epoch 8, Batch 1250: Loss 0.0025
Epoch 8, Batch 1300: Loss 0.0025
Epoch 8, Batch 1350: Loss 0.0025
Epoch 8, Batch 1400: Loss 0.0025
Epoch 8, Batch 1450: Loss 0.0026
Epoch 8, Batch 1500: Loss 0.0025
Epoch 8, Bat

Epoch 8 completed
Epoch 9, Batch 50: Loss 0.0025
Epoch 9, Batch 100: Loss 0.0026
Epoch 9, Batch 150: Loss 0.0026
Epoch 9, Batch 200: Loss 0.0025
Epoch 9, Batch 250: Loss 0.0025
Epoch 9, Batch 300: Loss 0.0026
Epoch 9, Batch 350: Loss 0.0025
Epoch 9, Batch 400: Loss 0.0025
Epoch 9, Batch 450: Loss 0.0026
Epoch 9, Batch 500: Loss 0.0025
Epoch 9, Batch 550: Loss 0.0025
Epoch 9, Batch 600: Loss 0.0025
Epoch 9, Batch 650: Loss 0.0025
Epoch 9, Batch 700: Loss 0.0025
Epoch 9, Batch 750: Loss 0.0025
Epoch 9, Batch 800: Loss 0.0025
Epoch 9, Batch 850: Loss 0.0026
Epoch 9, Batch 900: Loss 0.0025
Epoch 9, Batch 950: Loss 0.0025
Epoch 9, Batch 1000: Loss 0.0025
Epoch 9, Batch 1050: Loss 0.0025
Epoch 9, Batch 1100: Loss 0.0025
Epoch 9, Batch 1150: Loss 0.0025
Epoch 9, Batch 1200: Loss 0.0025
Epoch 9, Batch 1250: Loss 0.0025
Epoch 9, Batch 1300: Loss 0.0026
Epoch 9, Batch 1350: Loss 0.0026
Epoch 9, Batch 1400: Loss 0.0026
Epoch 9, Batch 1450: Loss 0.0026
Epoch 9, Batch 1500: Loss 0.0025
Epoch 9, Bat

Epoch 9 completed
Epoch 10, Batch 50: Loss 0.0025
Epoch 10, Batch 100: Loss 0.0025
Epoch 10, Batch 150: Loss 0.0025
Epoch 10, Batch 200: Loss 0.0026
Epoch 10, Batch 250: Loss 0.0026
Epoch 10, Batch 300: Loss 0.0025
Epoch 10, Batch 350: Loss 0.0025
Epoch 10, Batch 400: Loss 0.0025
Epoch 10, Batch 450: Loss 0.0026
Epoch 10, Batch 500: Loss 0.0025
Epoch 10, Batch 550: Loss 0.0026
Epoch 10, Batch 600: Loss 0.0025
Epoch 10, Batch 650: Loss 0.0026
Epoch 10, Batch 700: Loss 0.0025
Epoch 10, Batch 900: Loss 0.0026
Epoch 10, Batch 950: Loss 0.0025
Epoch 10, Batch 1000: Loss 0.0025
Epoch 10, Batch 1050: Loss 0.0025
Epoch 10, Batch 1100: Loss 0.0025
Epoch 10, Batch 1150: Loss 0.0025
Epoch 10, Batch 1200: Loss 0.0026
Epoch 10, Batch 1250: Loss 0.0026
Epoch 10, Batch 1300: Loss 0.0026
Epoch 10, Batch 1350: Loss 0.0025
Epoch 10, Batch 1400: Loss 0.0025
Epoch 10, Batch 1450: Loss 0.0025
Epoch 10, Batch 1500: Loss 0.0025
Epoch 10, Batch 1550: Loss 0.0025
Epoch 10, Batch 1600: Loss 0.0025
Epoch 10, Bat

Epoch 10 completed
