# Training
This notebook trains the model on the data that was prepared in `data_format.ipynb`.

In [2]:
import torch
import torchvision

from torch import nn
import torch.utils.data as data
import torchvision.transforms as transforms
from torchinfo import summary

from tqdm.auto import tqdm
from typing import Dict, List
import matplotlib.pyplot as plt
from timeit import default_timer as timer
import numpy as np
import os
from PIL import Image
import re
import json

## Defining Model
For this project, I used a ResNet-like model that convolves over three dimensions (height, width, time).

- a base class for each block

In [3]:
class Block(nn.Module):
    """Residual unit made of two 3x3x3 convolutions with batch normalisation.

    Both convolutions keep the channel count and use ``padding='same'``, so
    the output has the same shape as the input and the input can be added
    back as a skip connection.

    Args:
        dimension: Number of input (and output) channels.
    """

    def __init__(self, dimension):
        super().__init__()
        # Attribute names and creation order are kept as-is so that
        # state_dict keys and seeded weight initialisation do not change.
        self.conv1 = nn.Conv3d(dimension, dimension, kernel_size=3, padding='same')
        self.conv2 = nn.Conv3d(dimension, dimension, kernel_size=3, padding='same')
        self.norm1 = nn.BatchNorm3d(dimension)
        self.norm2 = nn.BatchNorm3d(dimension)
        self.act = nn.GELU()

    def forward(self, x):
        """Apply conv-norm-act, conv-norm, add the input, then activate."""
        residual = self.act(self.norm1(self.conv1(x)))
        residual = self.norm2(self.conv2(residual))
        # Skip connection followed by the final non-linearity.
        return self.act(residual + x)

- a whole model

In [4]:
class TempResNet(nn.Module):
    """ResNet-style 3D convolutional network that regresses per-frame keypoints.

    The network consists of a strided stem convolution, four residual stages
    (64, 128, 256 and 512 channels) separated by pooling that only shrinks
    height and width (the frame axis always uses kernel/stride 1), a global
    average pool and a linear head.

    Args:
        channel: Number of channels of the input clip.
        depth: Number of frames; sizes the linear head and the leading axis
            of the reshaped output.

    The output of ``forward`` has shape ``(batch, depth, 21, 3)``
    (presumably 21 keypoints with 3 coordinates each -- confirm against the
    data preparation notebook).
    """

    def __init__(self, channel, depth):
        super().__init__()
        self.depth = depth

        # Stem: stride (1, 2, 2) halves height and width, keeps the frame axis.
        self.encode = nn.Conv3d(
            channel,
            64,
            kernel_size=(3, 7, 7),
            stride=(1, 2, 2),
            padding=(1, 3, 3),
            bias=False,
        )
        self.norm = nn.BatchNorm3d(64)

        # 1x1x1 convolutions that widen the channels between stages.
        self.expand1 = self._make_expand(64, 128)
        self.expand2 = self._make_expand(128, 256)
        self.expand3 = self._make_expand(256, 512)
        self.act = nn.GELU()

        # Both pools leave the frame axis untouched and halve height/width.
        self.max_pool1 = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
        self.max_pool2 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0)

        self.block1 = Block(64)
        self.block2 = Block(128)
        self.block3 = Block(256)
        self.block4 = Block(512)

        self.ave_pool = nn.AdaptiveAvgPool3d((1, 1, 1))

        self.flatten = nn.Flatten()
        self.fc = nn.Linear(512, depth * 21 * 3)

    @staticmethod
    def _make_expand(in_channels, out_channels):
        """Build the 1x1x1 convolution + batch norm that changes the channel count."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, 1),
            nn.BatchNorm3d(out_channels),
        )

    def forward(self, x):
        """Map an input clip to a tensor of shape ``(batch, depth, 21, 3)``."""
        # Stem followed by the first residual stage.
        x = self.act(self.norm(self.encode(x)))
        x = self.block1(self.max_pool1(x))

        # Remaining stages: downsample, widen the channels, residual block.
        later_stages = (
            (self.expand1, self.block2),
            (self.expand2, self.block3),
            (self.expand3, self.block4),
        )
        for expand, block in later_stages:
            x = self.max_pool2(x)
            x = block(expand(x))

        # Global pooling and linear head.
        x = self.fc(self.flatten(self.ave_pool(x)))

        return x.view(-1, self.depth, 21, 3)

## Constants and Helper Function

In [5]:
# Prefer the second GPU (index 1) when the machine actually has one.
# `torch.cuda.is_available()` alone is also true on single-GPU machines,
# where "cuda:1" is an invalid device ordinal, so check the device count
# and fall back to the first GPU and finally to the CPU.
if torch.cuda.device_count() > 1:
    DEVICE = "cuda:1"
elif torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"
DEVICE

'cuda:1'

In [6]:
# Distance cut-offs (0, 5, ..., 75) at which the accuracy metrics are evaluated.
thresholds = list(range(0, 80, 5))

In [7]:
def draw_bone(start, end, ax, c):
    """Draw a straight segment between two 3D points.

    Args:
        start: Indexable holding the x, y, z coordinates of the first point.
        end: Indexable holding the x, y, z coordinates of the second point.
        ax: Object exposing ``plot(xs, ys, zs, color=...)`` (presumably a
            Matplotlib 3D axes).
        c: Colour passed through to ``ax.plot``.
    """
    # One [start, end] pair per coordinate axis.
    xs, ys, zs = ([start[axis], end[axis]] for axis in range(3))
    ax.plot(xs, ys, zs, color=c)

## Dataset Class

In [8]:
class Dataset(data.Dataset):
    """Image sequences paired with their movement targets.

    The sample index is read from ``data/image_path.npy`` /
    ``data/movement.npy`` (or the ``*_test.npy`` variants) below ``root``.
    Every entry of the image-path array is an iterable of image file names,
    one per frame, relative to ``root``.

    Args:
        root: Directory that contains the ``data`` folder and the images.
        transform: Callable applied to every PIL frame. It has to return
            tensors of identical shape, because the frames are combined
            with ``torch.stack``.
        test: Load the ``*_test.npy`` split instead of the training split.
    """

    def __init__(self, root='./', transform=None, test=False):
        self.root = root
        self.transform = transform

        if test:
            self.images = np.load(os.path.join(root, 'data/image_path_test.npy'))
            self.movement = np.load(os.path.join(root, 'data/movement_test.npy'))
        else:
            self.images = np.load(os.path.join(root, 'data/image_path.npy'))
            self.movement = np.load(os.path.join(root, 'data/movement.npy'))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (images, movement, paths) where ``images`` is the stack of
            transformed frames with its first two axes swapped (channel
            first, frame second, assuming the transform yields
            ``(channel, height, width)`` tensors), ``movement`` is the
            matching entry of the movement array and ``paths`` is the list
            of frame file paths.
        """
        paths = []
        frames = []
        for img_name in self.images[index]:
            path = os.path.join(self.root, img_name)
            paths.append(path)
            # Image.open is lazy and keeps the file open; the context manager
            # releases the handle as soon as the frame has been processed so
            # long runs do not accumulate open file descriptors.
            with Image.open(path) as img:
                if self.transform is not None:
                    frame = self.transform(img)
                else:
                    # Force decoding before the underlying file is closed.
                    img.load()
                    frame = img
            frames.append(frame)
        movement = self.movement[index]

        # (frame, channel, H, W) -> (channel, frame, H, W)
        images = torch.stack(frames).permute(1, 0, 2, 3)

        return images, movement, paths

    def __len__(self):
        """Number of sequences in the selected split."""
        return len(self.images)

In [9]:
# Every frame is resized to 256x256 and then converted to a tensor.
_frame_ops = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_frame_ops)

# Both splits share the same preprocessing.
train_dataset = Dataset(transform=transform, test=False)
test_dataset = Dataset(transform=transform, test=True)

In [10]:
# Options shared by both loaders; only the shuffling differs between them.
_loader_opts = dict(batch_size=64, num_workers=16, pin_memory=True)

trainloader = data.DataLoader(train_dataset, shuffle=True, **_loader_opts)
testloader = data.DataLoader(test_dataset, shuffle=False, **_loader_opts)

## Preparation of Training

In [11]:
# Read the input geometry (channels, frames, image size) off one sample
# instead of hard-coding it.
imgs, movement, paths = train_dataset[0]
channel, depth, img_size, _img_width = imgs.shape

model = TempResNet(channel=channel, depth=depth)
model = model.to(DEVICE)

### Model Summary

In [12]:
# Trace a single dummy clip (batch size 1) through the network to list every layer.
summary(model, input_size=(1, channel, depth, img_size, img_size))

Layer (type:depth-idx)                   Output Shape              Param #
TempResNet                               [1, 10, 21, 3]            --
├─Conv3d: 1-1                            [1, 64, 10, 128, 128]     28,224
├─BatchNorm3d: 1-2                       [1, 64, 10, 128, 128]     128
├─GELU: 1-3                              [1, 64, 10, 128, 128]     --
├─MaxPool3d: 1-4                         [1, 64, 10, 64, 64]       --
├─Block: 1-5                             [1, 64, 10, 64, 64]       --
│    └─Conv3d: 2-1                       [1, 64, 10, 64, 64]       110,656
│    └─BatchNorm3d: 2-2                  [1, 64, 10, 64, 64]       128
│    └─GELU: 2-3                         [1, 64, 10, 64, 64]       --
│    └─Conv3d: 2-4                       [1, 64, 10, 64, 64]       110,656
│    └─BatchNorm3d: 2-5                  [1, 64, 10, 64, 64]       128
│    └─GELU: 2-6                         [1, 64, 10, 64, 64]       --
├─MaxPool3d: 1-6                         [1, 64, 10, 32, 32]       -

### Define an Evaluation Metric

In [13]:
def shape_accuracy(pred, target):
    """Share of samples whose last-frame pose lies within each distance threshold.

    For the final frame of every sample, the Euclidean distance between each
    predicted and target joint is computed and averaged over all joints. A
    sample counts as a hit for a threshold when this mean distance is less
    than or equal to it.

    Parameters:
        pred (torch.Tensor): [batch_size, depth, 21, 3]
        target (torch.Tensor): [batch_size, depth, 21, 3]

    Returns:
        accuracies (list[float]): one hit rate in [0, 1] per entry of the
            module-level ``thresholds`` list, in the same order. All zeros
            for an empty batch.
    """
    # Only the last frame is scored; detach because this is a metric and
    # must not extend the autograd graph.
    pred_hand = pred[:, -1, :, :].detach()
    target_hand = target[:, -1, :, :].detach()

    # Per-joint Euclidean distance, then the mean over joints -> [batch_size]
    distances = torch.sqrt(torch.sum((pred_hand - target_hand) ** 2, dim=2))
    mean_distances = distances.mean(dim=1)

    batch_size = mean_distances.shape[0]
    if batch_size == 0:
        return [0.0] * len(thresholds)

    # Compare every sample against every threshold in one broadcast operation
    # ([batch_size, 1] <= [1, n_thresholds]) instead of a Python double loop
    # that evaluates one tensor comparison per (sample, threshold) pair.
    cutoffs = torch.tensor(thresholds, dtype=mean_distances.dtype, device=mean_distances.device)
    hit_counts = (mean_distances.unsqueeze(1) <= cutoffs.unsqueeze(0)).sum(dim=0)

    return [count / batch_size for count in hit_counts.tolist()]

In [14]:
def movement_accuracy(pred, target):
    """Share of samples whose last-frame joints 0, 1 and 5 lie within each threshold.

    Works like a per-sample mean joint distance, but only the joints at
    indices 0, 1 and 5 of the final frame are considered. A sample counts as
    a hit for a threshold when that mean distance is less than or equal to it.

    Parameters:
        pred (torch.Tensor): [batch_size, depth, 21, 3]
        target (torch.Tensor): [batch_size, depth, 21, 3]

    Returns:
        accuracies (list[float]): one hit rate in [0, 1] per entry of the
            module-level ``thresholds`` list, in the same order. All zeros
            for an empty batch.
    """
    # Joints used for the movement score; detach because this is a metric
    # and must not extend the autograd graph.
    joint_ids = [0, 1, 5]
    pred_move = pred[:, -1, joint_ids, :].detach()
    target_move = target[:, -1, joint_ids, :].detach()

    # Per-joint Euclidean distance, then the mean over the 3 joints -> [batch_size]
    distances = torch.sqrt(torch.sum((pred_move - target_move) ** 2, dim=2))
    mean_distances = distances.mean(dim=1)

    batch_size = mean_distances.shape[0]
    if batch_size == 0:
        return [0.0] * len(thresholds)

    # Compare every sample against every threshold in one broadcast operation
    # ([batch_size, 1] <= [1, n_thresholds]) instead of a Python double loop
    # that evaluates one tensor comparison per (sample, threshold) pair.
    cutoffs = torch.tensor(thresholds, dtype=mean_distances.dtype, device=mean_distances.device)
    hit_counts = (mean_distances.unsqueeze(1) <= cutoffs.unsqueeze(0)).sum(dim=0)

    return [count / batch_size for count in hit_counts.tolist()]

### Defining Train Step and Test Step

In [15]:
def train_step(model: nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module, optimizer: torch.optim.Optimizer):
  """Train ``model`` for one pass over ``dataloader``.

  Every batch is cast to float, moved to ``DEVICE`` and used for one
  optimizer update. The loss and both accuracy metrics are averaged over the
  number of batches.

  Returns:
    tuple: (mean loss, shape accuracies, movement accuracies); each accuracy
    list has one entry per value in ``thresholds``.
  """
  model.train()

  n_thresholds = len(thresholds)
  loss_total = 0
  shape_totals = [0] * n_thresholds
  movement_totals = [0] * n_thresholds

  for X, movement, _ in dataloader:
    X = X.float().to(DEVICE)
    movement = movement.float().to(DEVICE)

    # Forward pass and loss
    pred_movement = model(X)
    loss = loss_fn(pred_movement, movement)
    loss_total += loss.item()

    # Backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Metrics use the predictions made before this batch's update.
    batch_shape = shape_accuracy(pred_movement, movement)
    batch_movement = movement_accuracy(pred_movement, movement)
    shape_totals = [running + new for running, new in zip(shape_totals, batch_shape)]
    movement_totals = [running + new for running, new in zip(movement_totals, batch_movement)]

  n_batches = len(dataloader)
  train_loss = loss_total / n_batches
  train_shape_acc = [total / n_batches for total in shape_totals]
  train_movement_acc = [total / n_batches for total in movement_totals]

  return train_loss, train_shape_acc, train_movement_acc

In [16]:
def test_step(model: nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module):
  """Evaluate ``model`` on one pass over ``dataloader`` without updating it.

  Runs in eval mode under ``torch.inference_mode()``. The loss and both
  accuracy metrics are averaged over the number of batches.

  Returns:
    tuple: (mean loss, shape accuracies, movement accuracies); each accuracy
    list has one entry per value in ``thresholds``.
  """
  model.eval()

  test_loss, test_shape_acc, test_movement_acc = 0, [0 for _ in range(len(thresholds))], [0 for _ in range(len(thresholds))]

  with torch.inference_mode():
    for X, movement, _ in dataloader:
      # Cast exactly like train_step does, so that inputs and targets have
      # the same dtype during evaluation as during training (the arrays
      # loaded with np.load are not guaranteed to be float32).
      X, movement = X.float().to(DEVICE), movement.float().to(DEVICE)

      pred_movement = model(X)

      loss = loss_fn(pred_movement, movement)
      test_loss += loss.item()

      shape_accuracies = shape_accuracy(pred_movement, movement)
      movement_accuracies = movement_accuracy(pred_movement, movement)

      for i in range(len(thresholds)):
        test_shape_acc[i] += shape_accuracies[i]
        test_movement_acc[i] += movement_accuracies[i]

  # Average the per-batch values over the number of batches.
  test_loss /= len(dataloader)
  for i in range(len(test_shape_acc)):
    test_shape_acc[i] /= len(dataloader)
  for i in range(len(test_movement_acc)):
    test_movement_acc[i] /= len(dataloader)

  return test_loss, test_shape_acc, test_movement_acc

### Define Train Process

In [17]:
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          scheduler: torch.optim.lr_scheduler._LRScheduler=None,
          loss_fn: torch.nn.Module = nn.CrossEntropyLoss(),
          epochs: int=5) -> Dict[str, List]:
  """Run the full training loop and collect per-epoch metrics.

  Each epoch performs one ``train_step``, an optional ``scheduler.step()``
  and one ``test_step``, then prints a one-line summary.

  Args:
    model: Network to train; updated in place.
    train_dataloader: Batches used for optimisation.
    test_dataloader: Batches used for evaluation after every epoch.
    optimizer: Optimizer that updates ``model``'s parameters.
    scheduler: Optional learning-rate scheduler, stepped once per epoch.
    loss_fn: Loss applied to (prediction, target).
      NOTE(review): the default is a classification loss although the model
      regresses coordinates; the notebook always passes ``nn.MSELoss()``
      explicitly -- confirm whether the default should change.
    epochs: Number of epochs to run.

  Returns:
    Dict with the keys ``train_loss``, ``train_shape_acc``,
    ``train_movement_acc``, ``test_loss``, ``test_shape_acc`` and
    ``test_movement_acc``; every value is a list with one entry per epoch
    (the accuracy entries are themselves lists, one value per threshold).
  """
  # Position in the accuracy lists that is shown in the progress line;
  # with thresholds = 0, 5, 10, ... this is the threshold 20.
  report_idx = 4

  results = {
      "train_loss": [],
      "train_shape_acc": [],
      "train_movement_acc": [],
      "test_loss": [],
      "test_shape_acc": [],
      "test_movement_acc": [],
  }

  for epoch in tqdm(range(epochs)):
    train_loss, train_shape_acc, train_movement_acc = train_step(model=model, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optimizer)
    # Identity check: `!= None` would go through a user-defined __eq__/__ne__.
    if scheduler is not None:
      scheduler.step()
    test_loss, test_shape_acc, test_movement_acc = test_step(model, dataloader=test_dataloader, loss_fn=loss_fn)

    print(f"Epoch: {epoch} | Train Loss: {train_loss:.4f} | Train Shape Acc: {train_shape_acc[report_idx]:.3f} | Train Movement Acc: {train_movement_acc[report_idx]:.3f} | Test Loss: {test_loss:.4f} | Test Shape Acc: {test_shape_acc[report_idx]:.3f} | Test Movement Acc: {test_movement_acc[report_idx]:.3f}")

    results["train_loss"].append(train_loss)
    results["train_shape_acc"].append(train_shape_acc)
    results["train_movement_acc"].append(train_movement_acc)
    results["test_loss"].append(test_loss)
    results["test_shape_acc"].append(test_shape_acc)
    results["test_movement_acc"].append(test_movement_acc)

  return results

## Train

In [18]:
# Seed the CPU and CUDA random number generators.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Hyper-parameters
NUM_EPOCHS = 50
# NOTE(review): 0.1 is far above Adam's documented default of 1e-3 and the
# test loss in the training log fluctuates strongly -- confirm this is intended.
LR = 0.1

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# Multiply the learning rate by 0.1 after 40 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)

start_time = timer()
model_results = train(
    model,
    train_dataloader=trainloader,
    test_dataloader=testloader,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    epochs=NUM_EPOCHS,
)
end_time = timer()

print(f"Total Training time: {end_time - start_time:.3f} seconds")

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch: 0 | Train Loss: 1251.1442 | Train Shape Acc: 0.110 | Train Movement Acc: 0.059 | Test Loss: 7459.7711 | Test Shape Acc: 0.000 | Test Movement Acc: 0.000
Epoch: 1 | Train Loss: 1157.9348 | Train Shape Acc: 0.171 | Train Movement Acc: 0.078 | Test Loss: 1917.7626 | Test Shape Acc: 0.135 | Test Movement Acc: 0.002
Epoch: 2 | Train Loss: 1120.6874 | Train Shape Acc: 0.174 | Train Movement Acc: 0.072 | Test Loss: 1906.8619 | Test Shape Acc: 0.074 | Test Movement Acc: 0.000
Epoch: 3 | Train Loss: 1083.5413 | Train Shape Acc: 0.204 | Train Movement Acc: 0.088 | Test Loss: 1250.5532 | Test Shape Acc: 0.131 | Test Movement Acc: 0.010
Epoch: 4 | Train Loss: 1061.4866 | Train Shape Acc: 0.187 | Train Movement Acc: 0.067 | Test Loss: 2847.3204 | Test Shape Acc: 0.036 | Test Movement Acc: 0.000
Epoch: 5 | Train Loss: 1003.7060 | Train Shape Acc: 0.195 | Train Movement Acc: 0.072 | Test Loss: 1480.2115 | Test Shape Acc: 0.104 | Test Movement Acc: 0.008
Epoch: 6 | Train Loss: 977.5283 | Train 

In [19]:
# Persist the per-epoch losses and accuracies returned by train().
with open("./data/result.json", "w") as result_file:
    json.dump(model_results, result_file)

In [21]:
# Pickles the complete module object (not only its state_dict), so loading
# the file requires the TempResNet and Block class definitions to be importable.
MODEL_PATH = './data/model.pth'
torch.save(model, MODEL_PATH)