In [113]:
import os
import warnings
from dotenv import load_dotenv
load_dotenv()
warnings.simplefilter("ignore", UserWarning)

import yaml
import wandb
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.models import resnet
import torchvision.transforms as transforms
from sklearn.preprocessing import OrdinalEncoder

from models import ResNet18_RNN
from dataset import TricksDataset
from utils import train_fn, get_loaders, load_checkpoint, save_checkpoint, check_performance, plot_frames

# --- Configuration -----------------------------------------------------------
# Tunables are read from config_hard.yaml; notebook-only overrides are named
# constants below so they are easy to find and change.
with open("config_hard.yaml", "r") as f:
    config = yaml.safe_load(f)

train_cfg = config["training_parameters"]
loader_cfg = config["dataloader_parameters"]
model_cfg = config["model_parameters"]

# Fixed seed so weight initialisation and random augmentation are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

EPOCHS = train_cfg["epochs"]
LEARNING_RATE = train_cfg["learning_rate"]
LABEL_COLUMNS = train_cfg["label_columns"]

TRAIN_CSV = loader_cfg["train_csv"]
VAL_CSV = loader_cfg["val_csv"]
ROOT_DIR = loader_cfg["root_dir"]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): hard-coded rather than read from the YAML, and not passed to
# get_loaders in the setup cell (which receives a literal batch size).
BATCH_SIZE = 2
MAX_FRAMES = loader_cfg["max_frames"]
NUM_WORKERS = loader_cfg["num_workers"]
PIN_MEMORY = loader_cfg["pin_memory"]

RNN_TYPE = model_cfg["rnn_type"]
RNN_LAYERS = model_cfg["rnn_layers"]
RNN_HIDDEN = model_cfg["rnn_hidden"]
TRAINABLE_BACKBONE = model_cfg["trainable_backbone"]
# Shallow copy so the derived keys below do not mutate the loaded `config`.
HEADS_PARAMS = dict(model_cfg["heads_params"])
# Presumably the heads consume the RNN output flattened over all frames
# (hidden size x max frames) — TODO confirm against ResNet18_RNN.
HEADS_PARAMS["in_features"] = RNN_HIDDEN * MAX_FRAMES

# Restrict this experiment to a subset of tricks.
TRICK_SUBSET = ["heelflip", "kickflip"]

train_labels = pd.read_csv(TRAIN_CSV)
# reset_index already returns a new frame, so no extra .copy() is needed.
df_train = train_labels.loc[train_labels["trick_name"].isin(TRICK_SUBSET)].reset_index(drop=True)
if "trick_name" in LABEL_COLUMNS:
    # Size the trick head by the classes actually present after filtering.
    HEADS_PARAMS["n_tricks"] = df_train["trick_name"].nunique()

In [111]:
# --- Model, optimizer and losses --------------------------------------------
model = ResNet18_RNN(
    RNN_TYPE,
    RNN_LAYERS,
    RNN_HIDDEN,
    HEADS_PARAMS,
    TRAINABLE_BACKBONE,
)
# Move the model to its device *before* building the optimizer, as the
# torch.optim documentation recommends.
model = model.to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# One classification loss per output head, keyed by head name.
loss_fns = {
    "trick_name": nn.CrossEntropyLoss(),
    "landed": nn.CrossEntropyLoss(),
    "stance": nn.CrossEntropyLoss(),
}

# --- Transforms --------------------------------------------------------------
# Preprocessing preset that matches the pretrained ResNet-18 weights.
resnet_transforms = resnet.ResNet18_Weights.DEFAULT.transforms()

# Training-time augmentation. ColorJitter() with default arguments changes
# nothing, so it is omitted; add it back with non-zero brightness/contrast/
# saturation/hue to actually jitter colours.
# NOTE(review): if TricksDataset applies this transform frame by frame, each
# frame is flipped independently — TODO confirm clips are flipped as a whole.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    resnet_transforms,
])

# Validation must be deterministic: preprocessing only, no random augmentation.
val_transforms = transforms.Compose([
    resnet_transforms,
])

# --- Label encoding ----------------------------------------------------------
# Fitted on training labels only; categories unseen in training are encoded as -1.
encoder = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1, dtype=int).set_output(transform="pandas")
encoder.fit(df_train[LABEL_COLUMNS])

# --- Data loaders ------------------------------------------------------------
train_loader, val_loader = get_loaders(
    df_train,
    VAL_CSV,
    ROOT_DIR,
    MAX_FRAMES,
    1,  # presumably the batch size; the prediction cell's `.item()` calls require 1
    train_transforms,
    val_transforms,
    NUM_WORKERS,
    PIN_MEMORY,
    encoder,
)

In [50]:
# Grab a single training batch for interactive inspection.
data, target = next(iter(train_loader))

In [112]:
# Short sanity-check run; set N_EPOCHS = EPOCHS to use the configured budget.
N_EPOCHS = 1

# The prediction cell switches the model to eval mode; restore training mode so
# any mode-dependent layers behave correctly if this cell is re-run afterwards.
model.train()

train_results = []  # one train_fn return value per epoch
for epoch in range(N_EPOCHS):
    train_results.append(train_fn(train_loader, model, optimizer, loss_fns, DEVICE))

100%|██████████| 63/63 [05:01<00:00,  4.79s/it, loss_total=2.23]


In [114]:
# Print decoded predictions next to ground truth for the first few batches.
# Iterating the whole loader takes minutes, so the preview is bounded.
N_PREVIEW = 10

model.eval()
try:
    with torch.no_grad():
        for idx, (data, target) in enumerate(train_loader):
            if idx >= N_PREVIEW:
                break
            # Inputs must live on the same device as the model.
            preds = model(data.to(DEVICE))
            # Assumes `preds` and `target` are dicts keyed by the LABEL_COLUMNS
            # names (as loss_fns is) — TODO confirm. Columns are stacked in
            # LABEL_COLUMNS order because that is the order the encoder was
            # fitted with; the result is (batch, n_labels).
            # argmax over the last (class) dim equals argmax of the softmax, so
            # no softmax is needed.
            pred_codes = np.stack(
                [preds[col].argmax(dim=-1).reshape(-1).cpu().numpy() for col in LABEL_COLUMNS],
                axis=1,
            )
            target_codes = np.stack(
                [torch.as_tensor(target[col]).reshape(-1).cpu().numpy() for col in LABEL_COLUMNS],
                axis=1,
            )
            pred_decoded = encoder.inverse_transform(pred_codes)
            target_decoded = encoder.inverse_transform(target_codes)
            print("-" * 50)
            print(f"PREDICTION #{idx+1}: ")
            print(f" Prediction: {pred_decoded} \t\t Groundtruth: {target_decoded}")
finally:
    # Leave the model in training mode even if the loop is interrupted.
    model.train()

--------------------------------------------------
PREDICTION #1: 
 Prediction: [['kickflip' True 'fakie']] 		 Groundtruth: [['kickflip' True 'fakie']]
--------------------------------------------------
PREDICTION #2: 
 Prediction: [['kickflip' True 'fakie']] 		 Groundtruth: [['heelflip' False 'fakie']]
--------------------------------------------------
PREDICTION #3: 
 Prediction: [['kickflip' True 'fakie']] 		 Groundtruth: [['heelflip' True 'nollie']]
--------------------------------------------------
PREDICTION #4: 
 Prediction: [['kickflip' True 'fakie']] 		 Groundtruth: [['kickflip' True 'fakie']]
--------------------------------------------------
PREDICTION #5: 
 Prediction: [['kickflip' True 'fakie']] 		 Groundtruth: [['kickflip' True 'regular']]
--------------------------------------------------
PREDICTION #6: 
 Prediction: [['kickflip' True 'fakie']] 		 Groundtruth: [['kickflip' True 'nollie']]
--------------------------------------------------
PREDICTION #7: 
 Prediction: [['

KeyboardInterrupt: 