In [None]:
import os
from os import path

import pandas as pd
import torch
from torch.utils.data import DataLoader
from fastai.losses import FocalLoss

from lsfb_dataset.datasets.lsfb_cont.skeleton_landmarks import SkeletonLandmarksWindowedDataset
from lsfb_dataset.utils.logger import init_root_logger
from lsfb_dataset.utils.training import train_rnn_model
from lsfb_dataset.models import LSTMClassifier

# Load information about the videos

In [None]:
# Root folder of the LSFB-CONT dataset. Set the LSFB_CONT_ROOT environment variable
# to point elsewhere instead of editing this machine-specific default path.
root = os.environ.get('LSFB_CONT_ROOT', 'T:/datasets/lsfb_cont')

# Video metadata table; any row with a missing value is discarded. Built as a new
# frame (no `inplace=True`) so the cell has no hidden in-place mutation on re-run.
df_videos = pd.read_csv(path.join(root, 'videos.csv')).dropna()
df_videos.head()

In [None]:
# Seed torch's global RNG, and reuse the same seed for the pandas sampling below,
# so the 70/30 train/validation split is reproducible from run to run.
seed = 1548621
torch.manual_seed(seed)

df_train = df_videos.sample(frac=0.7, random_state=seed)
# Validation set: every video whose index label was not drawn into df_train.
df_val = df_videos.loc[~df_videos.index.isin(df_train.index)]

df_train.shape, df_val.shape

# Load the windowed upper-skeleton landmark datasets (train and validation)

In [None]:
# One windowed skeleton-landmark dataset per split, built from the matching
# metadata frame. isolate_transitions=True is forwarded to the project dataset class.
datasets = {
    split: SkeletonLandmarksWindowedDataset(root, df_train if split == 'train' else df_val, isolate_transitions=True)
    for split in ['train', 'val']
}

# Only the training loader is shuffled. Validation order does not affect learning,
# and a fixed order keeps evaluation passes reproducible.
data_loaders = {
    split: DataLoader(datasets[split], shuffle=(split == 'train'), batch_size=10)
    for split in ['train', 'val']
}

# Train an LSTM on this dataset

In [None]:
def launch_training(model, model_name, num_classes, epoch_nb, device='cuda', lr=0.01, loaders=None):
    """Train `model` with a class-weighted focal loss and Nesterov SGD, then checkpoint it.

    Args:
        model: the network to train (e.g. an ``LSTMClassifier``).
        model_name: prefix of the ``<model_name>.log`` log file and of the
            ``<model_name>.model`` checkpoint, both written to the working directory.
        num_classes: number of output classes, forwarded to ``train_rnn_model``.
        epoch_nb: number of training epochs.
        device: device on which the class-weight tensor is created. It must match
            the device the loss is evaluated on. Defaults to ``'cuda'`` (previous behavior).
        lr: SGD learning rate (default 0.01, previous behavior).
        loaders: dict with ``'train'`` and ``'val'`` DataLoaders. Defaults to the
            notebook-level ``data_loaders``.

    Returns:
        ``(out_model, train_metrics, val_metrics)`` as produced by ``train_rnn_model``.
        Callers that ignore the return value are unaffected.
    """
    if loaders is None:
        loaders = data_loaders

    # The train loader wraps the training dataset, so its class weights can be read
    # from there without depending on a separate global.
    class_weights = torch.as_tensor(loaders['train'].dataset.class_weights, dtype=torch.float32, device=device)
    criterion = FocalLoss(weight=class_weights, gamma=2.0)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0001, momentum=0.9, nesterov=True)

    init_root_logger(f'{model_name}.log', stdout=True)

    out_model, (last_model_state, best_model_state), (train_metrics, val_metrics) = train_rnn_model(
        model, criterion, optimizer, loaders, num_epochs=epoch_nb, num_classes=num_classes,
    )

    # Save the last and best model states returned by train_rnn_model, together with
    # the loss/optimizer state and both metric trackers.
    torch.save({
        'last_model': last_model_state,
        'best_model': best_model_state,
        'criterion': criterion.state_dict(),
        'optimizer': optimizer.state_dict(),
        'train_metrics': train_metrics.state_dict(),
        'val_metrics': val_metrics.state_dict(),
    }, f'{model_name}.model')

    return out_model, train_metrics, val_metrics

In [None]:
# LSTM hyper-parameters.
# NOTE(review): input_size presumably has to match the per-frame feature size produced
# by SkeletonLandmarksWindowedDataset, and num_classes the number of labels it emits.
# Confirm both against the dataset before changing them.
input_size, hidden_size, num_classes = 46, 256, 3

model = LSTMClassifier(input_size, hidden_size, num_classes)

# Trailing semicolon suppresses display of any value returned by the call.
launch_training(model, 'LSTM_EXAMPLE', num_classes, epoch_nb=10);