In [None]:
import numpy as np
import pandas as pd
from dataclasses import dataclass
from dgl.dataloading import GraphDataLoader
import pytorch_lightning as pl

from Dataset.GraphDataset import GymPoseDataset
from Dataset.Augmentation import augment_points

from Model.GraphModel import ALIGNN
from Model.Config import BaseConfig

from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
# tqdm 
from tqdm import tqdm   
import os
import pickle
import torch
from pytorch_lightning.callbacks import ModelCheckpoint

from sklearn.utils import resample

In [None]:
# Load the annotations table and the accompanying points archive.
df = pd.read_csv('data/ohp/annotations_ohp.csv')
points = np.load('data/ohp/points_ohp_3d.npz')

# Rows whose `split` column marks them as training data.
df_t = df[df['split'] == 'train']

# Print, for each error label, the percentage of training rows per class.
# The result is kept in its own name so this exploratory cell never rebinds
# `train_ws`, which a later cell turns into the weight tensor passed to ALIGNN;
# re-running this cell out of order can therefore no longer corrupt it.
for error_column in ('elbow_error', 'knee_error'):
    class_share_pct = (df_t[error_column].value_counts() / len(df_t)) * 100
    print(class_share_pct)

# Uncomment to list every array stored in the archive together with its shape:
# for key in points:
#     print(key, points[key].shape)

In [None]:
# Partition the annotations by the value of the `split` column, in the order
# train / val / test.
train_df, val_df, test_df = (
    df.loc[df['split'].eq(split_name)] for split_name in ('train', 'val', 'test')
)
# Number of rows in the training partition.
print(len(train_df))

In [None]:
# Fraction of training rows per `knee_error` class.
# `value_counts()` returns its result ordered by descending frequency, not by
# label, so sort by label: position i of the weight tensor then corresponds to
# the i-th smallest class label instead of to whichever class is most common.
knee_class_share = (train_df['knee_error'].value_counts() / len(train_df)).sort_index()
train_ws = torch.tensor(knee_class_share.values, dtype=torch.float32)
# One minus the class frequency, so rarer classes receive the larger weight.
train_ws = 1 - train_ws
print(train_ws)
# NOTE(review): a class that never occurs in the training split gets no entry
# in `train_ws` — confirm the model expects exactly the classes present here.

# Passed to ALIGNN together with the weights. Presumably selects which error
# label is predicted (1 appears to pair with `knee_error`) — TODO confirm
# against Model.GraphModel.
type_error = 1

In [None]:
# One GymPoseDataset per split, built in train / val / test order. All three
# share the same points archive and the same "preprocessed_dataset" argument
# (presumably a cache location — TODO confirm against Dataset.GraphDataset).
train_dataset, val_dataset, test_dataset = (
    GymPoseDataset(split_df, points, split_label, "preprocessed_dataset")
    for split_df, split_label in (
        (train_df, "Train"),
        (val_df, "Val"),
        (test_df, "Test"),
    )
)

In [None]:
# Build a fresh ALIGNN from the shared BaseConfig, the weight tensor computed
# from the training split, and the `type_error` value.
model = ALIGNN(BaseConfig, train_ws, type_error)
# Alternative: restore the model from a saved checkpoint instead of training
# from scratch (`path` must be defined first; it is not set in this notebook).
#model = ALIGNN.load_from_checkpoint(path, config=BaseConfig, weights=train_ws, type_error=type_error)

In [None]:
# Loader settings defined once so the three loaders cannot drift apart.
BATCH_SIZE = 16
# Up to 31 worker processes, but never more than the machine has CPUs;
# `os.cpu_count()` may return None, hence the `or 1` fallback. The result is
# always >= 1, which `persistent_workers=True` requires.
NUM_WORKERS = min(31, os.cpu_count() or 1)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = GraphDataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, persistent_workers=True)
val_loader = GraphDataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, persistent_workers=True)
# The test loader does not keep its workers alive between passes.
test_loader = GraphDataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# Keep the three checkpoints with the highest monitored `val_f1`
# (presumably logged by the model during validation — TODO confirm).
# The epoch number and the score are embedded in each checkpoint file name.
checkpoint_callback = ModelCheckpoint(
    filename='model-{epoch:02d}-{val_f1:.4f}',
    monitor='val_f1',
    mode='max',
    save_top_k=3,
)

In [None]:
# GPU trainer limited to 200 epochs, with the checkpoint callback attached.
trainer = pl.Trainer(
    accelerator='gpu',
    max_epochs=200,
    callbacks=[checkpoint_callback],
)

In [None]:
# Train the model on the training loader, using the validation loader for
# the per-epoch validation pass that the checkpoint callback monitors.
trainer.fit(model, train_loader, val_loader)

In [None]:
# Checkpoints retained by the callback; the next cell iterates this object and
# treats each item as a checkpoint path.
# NOTE(review): presumably a dict mapping checkpoint path -> monitored score —
# confirm against the ModelCheckpoint documentation.
best_model_paths = checkpoint_callback.best_k_models

In [None]:
# Evaluate every retained checkpoint on the test loader.
# Each restored network is bound to `checkpoint_model` rather than `model`, so
# the trained model from the fit cell is not overwritten by whichever
# checkpoint happens to be loaded last.
# The value returned by each `trainer.test` call is kept, keyed by checkpoint
# path, so results can be compared after the loop instead of only scrolling
# through the printed output.
test_results = {}
for model_path in best_model_paths:
    print(model_path)
    checkpoint_model = ALIGNN.load_from_checkpoint(model_path, config=BaseConfig, weights=train_ws, type_error=type_error)
    test_results[model_path] = trainer.test(checkpoint_model, test_loader)