In [None]:
# Re-import edited modules before each cell executes, so changes to the project
# source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
cd ../src/

## Imports

In [None]:
import os
import re
import cv2
import time
import torch
import imageio
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from collections import Counter
from tqdm.notebook import tqdm
from skimage.transform import resize

In [None]:
from params import *

## Data

In [None]:
# PATH: image root handed to the dataset and to Config.img_path.
# DF_PATH: training table, read relative to DATA_PATH.
PATH, DF_PATH = CROP_PATH_3D, 'df_train_3d.csv'

### Load

In [None]:
# Load the training table. os.path.join works whether or not DATA_PATH ends
# with a path separator (plain string concatenation only works if it does).
df_train = pd.read_csv(os.path.join(DATA_PATH, DF_PATH))

In [None]:
# Drop every row that has a NaN in any column. Reassigning instead of using
# inplace=True keeps the cell's effect explicit.
# NOTE(review): this is not restricted to specific columns — confirm that no
# optional column is expected to contain NaNs, otherwise valid rows are lost.
df_train = df_train.dropna()

In [None]:
# Per-frame flag: the maximum of `extended_impact` over all rows sharing the
# same image_name, broadcast back to every row of that frame.
# groupby/transform keeps df_train's rows as they are and simply (re)writes the
# column, so re-running this cell is safe. Merging an aggregated frame back in
# would add suffixed duplicate columns (`_x`/`_y`) on a second run.
df_train['frame_has_impact'] = (
    df_train.groupby('image_name')['extended_impact'].transform('max')
)

### Auxiliary target

In [None]:
# Single-label auxiliary target, one code per row (stored as float):
#   0 = no flag set, 1 = Helmet, 2 = Shoulder / shoulder / Body / Hand, 3 = Ground.
# When several flags are set on the same row, the highest code wins.
contact_cols = ['extended_Shoulder', 'extended_shoulder', 'extended_Body', 'extended_Hand']
flags = df_train[['extended_Helmet', *contact_cols, 'extended_Ground']].eq(1)

# np.select takes the first condition that matches, so conditions are listed
# from highest to lowest priority.
aux_label = np.select(
    [flags['extended_Ground'], flags[contact_cols].any(axis=1), flags['extended_Helmet']],
    [3.0, 2.0, 1.0],
    default=0.0,
)

df_train['aux_target'] = aux_label

### Folds

In [None]:
# Attach the fold assignment stored in folds.csv to every row, keyed by video.
# The inner join drops rows whose video is missing from folds.csv.
# validate="many_to_one" requires each video to appear at most once in
# folds.csv. Without it, duplicated videos there would silently duplicate
# training rows.
folds = pd.read_csv(os.path.join(OUT_DIR, "folds.csv"))
df_train = df_train.merge(folds, on="video", validate="many_to_one")

## Dataset

In [None]:
from data.dataset import NFLDatasetCls3D
from data.transforms import get_transfos_cls

In [None]:
# Dataset used for the inspection cells below. A copy of df_train is passed so
# nothing the dataset does to its frame can leak back into df_train.
# visualize=True presumably yields images suited for plotting — TODO confirm.
dataset_kwargs = dict(root=PATH, target_name='extended_impact', visualize=True)
dataset = NFLDatasetCls3D(df_train.copy(), **dataset_kwargs)

In [None]:
# for i in tqdm(range(len(dataset))):
#     image, y, y_aux = dataset[i]
#     assert image.shape == (9, 64, 64, 3)

In [None]:
# Number of distinct frames and a short preview, rather than dumping every name.
image_names = df_train['image_name'].unique()
len(image_names), image_names[:10]

In [None]:
# Smoke test: fetch the first sample and show its image shape and targets.
image, y, y_aux = dataset[0]
image.shape, y, y_aux

In [None]:
# Draw a few random samples (with replacement) and plot those with a positive
# target: one figure per sample, one panel per frame. If positives are rare,
# a draw may show nothing — change VIS_SEED or N_SAMPLES_TO_DRAW to see others.
N_SAMPLES_TO_DRAW = 10
VIS_SEED = 0
vis_rng = np.random.default_rng(VIS_SEED)

for sample_idx in vis_rng.choice(len(dataset), N_SAMPLES_TO_DRAW):
    frames, y, y_aux = dataset[sample_idx]
    if not y:
        continue

    # 3x3 grid for 9-frame clips, otherwise 5 columns and as many rows as
    # needed. A fixed 4x5 grid fails for clips longer than 20 frames.
    n_frames = len(frames)
    n_cols = 3 if n_frames == 9 else 5
    n_rows = max(1, int(np.ceil(n_frames / n_cols)))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    for ax, frame in zip(axes.flat, frames):
        ax.imshow(frame)
    plt.show()
    plt.close(fig)

## Model

In [None]:
from model_zoo.models_cls_3d import get_model_cls_3d

In [None]:
# Model instance used only for the forward-pass sanity check below.
# Other variants tried:
#   get_model_cls_3d('i3d', num_classes=1, num_classes_aux=5)
#   get_model_cls_3d('resnet50', num_classes=1, num_classes_aux=0)
# NOTE(review): the aux targets built above take values 0-3 (4 classes), while
# 5 is used here; this only affects this check, not training (see Config).
model = get_model_cls_3d(
    'slowonly',
    num_classes=1,
    num_classes_aux=5,
)

In [None]:
# Random dummy batch for the forward-pass check. The axis order is presumably
# (batch, channels, frames, height, width) — TODO confirm against the model.
dummy_shape = (1, 3, 9, 64, 64)
x = torch.randn(dummy_shape)

In [None]:
# Dummy forward pass to check the model accepts an input of this shape.
# torch.no_grad() avoids building an autograd graph for this throwaway call.
with torch.no_grad():
    dummy_out = model(x)
dummy_out

## Training

In [None]:
from training.main_cls_3d import k_fold_cls_3d

In [None]:
from utils.logger import prepare_log_folder, save_config, create_logger

In [None]:
# Training batch size per architecture; Config looks its entry up by Config.name.
BATCH_SIZES = dict(
    i3d=32,
    slowfast=64,
    slowonly=32,
    resnet18=128,
    resnet34=64,
    resnet50=32,
)

In [None]:
class Config:
    """
    Parameters used for training.

    The class object itself (not an instance) is passed to `save_config` and
    `k_fold_cls_3d` in the training cell, so every setting is a class attribute.
    """
    # General
    seed = 42
    verbose = 1
    img_path = PATH
    device = "cuda" if torch.cuda.is_available() else "cpu"
    save_weights = True

    # Target
    # Alternative tried: f'impact_{STRIDE}_{N_FRAMES}' (STRIDE / N_FRAMES presumably from params).
    target_name = "extended_impact"

    # k-fold
    k = 5
    random_state = 0
    # Derived from k so the two cannot drift apart; restrict (e.g. [0]) for a quick run.
    selected_folds = list(range(k))

    # Model
    name = "i3d"  # must be a key of BATCH_SIZES: "i3d", "slowfast", "slowonly", "resnet18", "resnet34", "resnet50"
    num_classes = 1

    aux_mode = "softmax"
    # 0 presumably disables the auxiliary head; 4 would match the aux_target codes 0-3.
    num_classes_aux = 0

    # Training
    batch_size = BATCH_SIZES[name]
    samples_per_player = 4
    optimizer = "Adam"

    acc_steps = 1
    epochs = 20 if samples_per_player else 4
    swa_first_epoch = 15

    lr = 5e-4  # values tried: 5e-4 / 1e-3
    warmup_prop = 0.05
    val_bs = batch_size * 2

    # NOTE(review): equals `epochs`. Depending on whether the training loop counts
    # epochs from 0 or 1, this may evaluate only at the last epoch or never — confirm.
    first_epoch_eval = 20

In [None]:
# With DEBUG on, the training cell skips log-folder creation, config saving and
# file logging, and log_folder stays None.
DEBUG, log_folder = True, None

In [None]:
# Outside debug mode:
# - create a fresh log folder,
# - save the config there,
# - route logs to a file,
# - silence all warnings.
# Then run the k-fold training and keep its out-of-fold predictions.
if not DEBUG:
    log_folder = prepare_log_folder(LOG_PATH_CLS_3D)
    print(f'Logging results to {log_folder}')
    config_path = log_folder + 'config.json'
    config_df = save_config(Config, config_path)
    create_logger(directory=log_folder, name="logs.txt")
    warnings.filterwarnings("ignore")

pred_oof = k_fold_cls_3d(Config, df_train, log_folder=log_folder)