# 1. Load dataloaders
> The dataloaders apply `CLAHE` (Contrast Limited Adaptive Histogram Equalization) preprocessing.

In [None]:
from data.dataloader import create_dataloaders

# Build the training and validation loaders from the healthy-subject and
# Parkinson's disease (PD) drawing folders.
train_dataloader, val_dataloader = create_dataloaders(
    healthy_dir="../project_datasets/drawing/Healthy/",
    pd_dir="../project_datasets/drawing/Parkinson/",
    img_size=(224, 224),  # presumably the (H, W) resize target — confirm in data.dataloader
    batch_size=32,
)

# 2. Build the model

In [None]:
from Models.ModelV0 import create_improved_densenet

# Identifier passed to train() as both model_name and run_name.
model_name = "ModelV0"
# Fresh, untrained instance of the improved DenseNet from Models.ModelV0.
model = create_improved_densenet()

# 3. Train the model

In [None]:
from training.trainer import train


# Train `model` for 50 epochs on the loaders built above; model_name doubles
# as the run name.
# Optional (currently disabled): start from saved weights by adding
#   load_pretrained="checkpoints/Phase_3/DenseNet201_improved.pth",
train(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    model_name=model_name,
    run_name=model_name,
    epochs=50,
)

In [None]:
!tensorboard --logdir=runs

# 4. Plot the confusion matrix of the best model

In [None]:
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): this checkpoint name does not match model_name ("ModelV0") used
# for training above — confirm it is the checkpoint intended for evaluation.
checkpoint_path = "checkpoints/Phase_3/DenseNet201_improved_finetuned.pth"

# torch.load unpickles the file, so only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# load_state_dict copies values into the model's existing parameters without
# moving the model, so move it explicitly. Then switch layers such as dropout and
# batch-norm to inference behaviour before evaluating.
model.to(device)
model.eval()

print("Loaded pretrained model:")
for metric in ("val_loss", "val_acc", "val_recall", "val_precision", "val_f1"):
    print(f"- {metric}={checkpoint[metric]:.4f}")

In [None]:
from training.confusion_mat import plot_confusion_matrix

# Run the loaded model over the validation loader and plot its confusion
# matrix with the two class labels below.
# Optional (currently disabled): add `threshold=0.49,` — presumably the
# positive-class decision threshold; confirm in training.confusion_mat.
plot_confusion_matrix(
    model=model,
    dataloader=val_dataloader,
    device=device,
    class_names=["Healthy", "PD"],
)