# FCN 

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay, precision_score, roc_auc_score, recall_score
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from FCN import QPPClassifier, QPPDataset

import warnings
warnings.filterwarnings('ignore')

## Training

In [None]:
# Placeholder paths: point these at the label CSVs for each split and at the
# directory holding the flare time-series files before running.
labels_train = 'path_to_labels_train.csv'
labels_val = 'path_to_labels_val.csv'
labels_test = 'path_to_labels_test.csv'
timeseries_dir = 'path_to_dir_with_flare_fits'

dataset_train = QPPDataset(labels_train, timeseries_dir)
dataset_val = QPPDataset(labels_val, timeseries_dir)
dataset_test = QPPDataset(labels_test, timeseries_dir)

# Worker processes per loader; persistent_workers=True requires this to be > 0.
num_workers = 7

# Only the training loader is shuffled. Without shuffling, every epoch iterates
# the samples in the same order, and batches are single-class if the label CSV
# happens to be sorted by label. Evaluation order does not affect the metrics,
# so the validation/test loaders stay deterministic.
train_dataloader = DataLoader(dataset_train, batch_size=64, shuffle=True,
                              num_workers=num_workers, persistent_workers=True)
val_dataloader = DataLoader(dataset_val, batch_size=128,
                            num_workers=num_workers, persistent_workers=True)
test_dataloader = DataLoader(dataset_test, batch_size=128,
                             num_workers=num_workers, persistent_workers=True)

# Layer widths and kernel sizes are passed straight through to QPPClassifier
# (see FCN.py for how they are used); 2 input channels per sample.
model = QPPClassifier(in_channels=2, layers_sizes=(64, 128, 64), kernel_sizes=(8, 5, 3))
# Early stopping on "val_loss" (minimised) with EarlyStopping's default patience.
# NOTE(review): assumes QPPClassifier logs a metric named "val_loss" — confirm in FCN.py.
trainer = pl.Trainer(max_epochs=100, callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model, train_dataloader, val_dataloader)

## Evaluation on the test set

In [None]:
# Placeholder path to a saved Lightning checkpoint of QPPClassifier.
m_path = "path_to_model_weights.ckpt"

model = QPPClassifier.load_from_checkpoint(m_path)
model.eval()  # switch layers such as dropout / batch-norm to inference behaviour

# Kept from the original run. Presumably it fixes any randomness inside
# QPPDataset.__getitem__ — confirm whether it is still needed.
np.random.seed(42)

# Input channels the model was built with (in_channels=2 in the training cell).
n_channels = 2

y_hats = []
ys = []

# Score the test set one sample at a time; no gradients are needed.
with torch.no_grad():
    for i in range(len(dataset_test)):
        x, y = dataset_test[i]
        # Add a batch axis. -1 lets the time axis take the sample's own length
        # (identical to the former hard-coded (1, 2, 300) for 600-element samples).
        x = torch.from_numpy(x.reshape((1, n_channels, -1))).to(device=model.device)
        # NOTE(review): the first output element is treated as the positive-class
        # probability (it is thresholded at 0.5 below) — confirm the model's output head.
        y_hats.append(model(x).cpu().numpy().flatten()[0])
        ys.append(y)

y_hats = np.array(y_hats)

# Binarise the scores at the decision threshold.
thres = 0.5
y_hats1 = y_hats.copy()
y_hats1[y_hats1 < thres] = 0
y_hats1[y_hats1 >= thres] = 1

print('accuracy:', accuracy_score(ys, y_hats1))
print('precision:', precision_score(ys, y_hats1))
print('recall:', recall_score(ys, y_hats1))
# roc_auc_score raises ValueError when the ground truth contains only one class.
if len(np.unique(ys)) > 1:
    print('roc-auc:', roc_auc_score(ys, y_hats))
else:
    print('roc-auc: undefined (only one class present in ys)')
# Pin the label set so the matrix is always 2x2 and lines up with the two
# display labels, even when a class is absent from both ys and y_hats1.
cm = confusion_matrix(ys, y_hats1, labels=[0, 1])
cmplt = ConfusionMatrixDisplay(cm, display_labels=['No', 'Yes']).plot()