### Import Statements

In [1]:
import random

import seaborn as sns
import torch
from matplotlib import pyplot as plt
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

from src.hair_color.dataset import HairColorDataset
from src.hair_color.dataloader import HairColorDataLoader
from src.hair_color.classifier import Classifier, Trainer, Tester

### Configuration and Hyperparameters

In [2]:
# Every tunable value used by the cells below lives in this cell.

# --- Reproducibility ---
# Seed the RNGs this notebook can reach so a Restart & Run All repeats the
# same run.
# NOTE(review): if the project code under src/ draws from NumPy's RNG, seed
# that as well — confirm against the dataset / data loader implementation.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

# --- Paths and file names ---
# Relative paths: resolved against the notebook's working directory.
dataset_path = '../../data/hair_color/build37_autosomal/dataset'
classifier_out_path = '../../models/hair_color'

dataset_file_name = 'dataset'
classifier_file_name = 'classifier'

# --- Data split and loading ---
train_test_split_ratio = 0.8  # argument to dataset.split_train_test
batch_size = 16  # shared by the train and test loaders
shuffle_train = True
shuffle_test = False
weighted_sampling_train = True
weighted_sampling_test = False
one_hot_features = False
one_hot_labels = True

# --- Classifier constructor arguments ---
input_size = 1
hidden_size = 96
num_layers = 1
bidirectional = False
dropout = 0

# --- Optimisation ---
learning_rate = 0.01  # Adam `lr`
step_size = 1  # StepLR `step_size`
gamma = 0.9  # StepLR `gamma`
num_epochs = 32  # argument to trainer.train

### Load Data

In [3]:
# Load the dataset stored under `dataset_path` / `dataset_file_name`.
dataset = HairColorDataset().load(
    dataset_path,
    dataset_file_name,
)

### Split Data

In [4]:
# Split the dataset in two using `train_test_split_ratio`.
(
    train_set,
    test_set,
) = dataset.split_train_test(train_test_split_ratio)

### Create Data Loaders

In [11]:
# Both loaders share batch size and one-hot settings; only the shuffle and
# weighted-sampling flags differ between the train and test loader.
train_loader = HairColorDataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=shuffle_train,
    weighted_sampling=weighted_sampling_train,
    one_hot_features=one_hot_features,
    one_hot_labels=one_hot_labels,
)
test_loader = HairColorDataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=shuffle_test,
    weighted_sampling=weighted_sampling_test,
    one_hot_features=one_hot_features,
    one_hot_labels=one_hot_labels,
)

### Initialize Classifier

In [12]:
# Build the classifier from the architecture settings in the config cell.
classifier = Classifier(
    input_size,
    hidden_size,
    num_layers,
    bidirectional,
    dropout,
)

### Initialize Trainer

In [13]:
# Objective, optimiser and learning-rate schedule for training.
loss = CrossEntropyLoss()
optimizer = Adam(
    classifier.parameters(),
    lr=learning_rate,
)
scheduler = StepLR(
    optimizer,
    step_size=step_size,
    gamma=gamma,
)

# The trainer receives both loaders plus the output path and file name.
trainer = Trainer(
    classifier,
    loss,
    optimizer,
    scheduler,
    train_loader,
    test_loader,
    classifier_out_path,
    classifier_file_name,
)

### Train Classifier

In [15]:
# Train for `num_epochs`; the four returned sequences are plotted in the
# next section.
(
    train_losses,
    train_accuracies,
    val_losses,
    val_accuracies,
) = trainer.train(num_epochs)

KeyboardInterrupt: 

### Plot Training

In [None]:
# Two side-by-side panels: loss on the left, accuracy on the right.
fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(12, 6))

loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Loss over Epochs')
loss_ax.legend()

accuracy_ax.plot(train_accuracies, label='Train Accuracy')
accuracy_ax.plot(val_accuracies, label='Validation Accuracy')
accuracy_ax.set_xlabel('Epoch')
accuracy_ax.set_ylabel('Accuracy')
accuracy_ax.set_title('Accuracy over Epochs')
accuracy_ax.legend()

plt.show()

### Load Best Classifier

In [None]:
# Reload the classifier saved under `classifier_out_path`.
# NOTE(review): the return value is discarded here, whereas
# HairColorDataset().load(...) above returns the loaded object. Confirm that
# Classifier.load updates `classifier` in place.
classifier.load(
    classifier_out_path,
    classifier_file_name,
)

### Initialize Tester

In [None]:
# Evaluate with the same loss object and the test loader.
tester = Tester(
    classifier,
    loss,
    test_loader,
)

### Test Classifier

In [None]:
# Run the evaluation; `test_cm` is reused by the confusion-matrix cell.
(
    test_loss,
    test_accuracy,
    test_precision,
    test_recall,
    test_f1,
    test_auroc,
    test_cm,
) = tester.test()

# One metric per output line.
print(
    f'Test Loss: {test_loss}',
    f'Test Accuracy: {test_accuracy}',
    f'Test Precision: {test_precision}',
    f'Test Recall: {test_recall}',
    f'Test F1: {test_f1}',
    f'Test AUROC: {test_auroc}',
    sep='\n',
)

### Confusion Matrix

In [None]:
# Tick labels shared by both axes of the heatmap.
# NOTE(review): assumes the rows/columns of `test_cm` are ordered
# Blonde, Brown, Black — confirm against the dataset's label encoding.
class_names = ["Blonde", "Brown", "Black"]

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    test_cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')
plt.show()