# In [None]:
#%%
from base64 import encode
from itertools import count
from pathlib import Path
import pickle
from pydoc import doc
from typing import List
from black import out
from matplotlib.pyplot import text

import numpy as np
from sklearn import preprocessing, svm
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.metrics import classification_report
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import compute_class_weight
import spacy
import torch
import typer
from experiments.metadata.constraints.datasets.constraints_link.spacy_annoto_connectors import annoto2spacyDocs
from origami_indexers.metadata.constraint_extraction.relation_component import get_possible_rels, rels_between_spans
from origami_indexers.utils.pipes import NLP_CONTINGENCIES_PROPERTIES_DATASET_IMPORT_PIPES
from origami_indexers.utils.spacy import NLP_DEFAULT
from origami_indexers.utils.s3.s3 import LOCAL_CACHE_FOLDER, LOCAL_DATA_FOLDER, download_file_if_not_exists
from origami_indexers.utils.s3.file_readers import s3_contingencies_annotations
import pandas as pd
import torch.nn as nn
import pytorch_lightning as pl

from pytorch_lightning.loggers import CSVLogger
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split
from torch.utils.data import TensorDataset, DataLoader

from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray import tune
import torchmetrics

# Location of the annotated-docs pickle inside the local data folder.
docs_dump = LOCAL_DATA_FOLDER / 'annoto_docs_dump.pkl'

# Severity label name -> integer class id (4 classes).
lookup_table = {
    'c_severity_none': 0,
    'c_severity_low': 1,
    'c_severity_medium': 2,
    'c_severity_high': 3,
}
# Label names in insertion order of `lookup_table`.
classes = list(lookup_table)


# On-disk layout of the serialized severity dataset.
SEVERITY_DATASET_FOLDER = LOCAL_DATA_FOLDER / 'severity'
# parents=True: do not fail at import time when LOCAL_DATA_FOLDER itself has
# not been created yet (e.g. on a fresh checkout).
SEVERITY_DATASET_FOLDER.mkdir(parents=True, exist_ok=True)
SEVERITY_X_PATH = SEVERITY_DATASET_FOLDER / "severity_x.bin"
SEVERITY_Y_PATH = SEVERITY_DATASET_FOLDER / "severity_y.bin"
SEVERITY_CLASS_WEIGHTS_PATH = SEVERITY_DATASET_FOLDER / "severity_class_weights.bin"
# NOTE(review): not referenced in the visible code; presumably the
# padding/truncation length used when the dataset was built -- confirm.
MAX_SEQ_LEN = 200


#create_text_dataset_from_docs()
#create_dataset_from_bert_embedding()
# %%
class SeverityDataModule(pl.LightningDataModule):
    """Lightning data module serving the severity train/validation tensors.

    Features and labels are read with ``torch.load`` from ``SEVERITY_X_PATH``
    and ``SEVERITY_Y_PATH`` and divided by a stratified shuffle split
    (``test_size=0.33``, ``random_state=42``); the first generated split is
    used.
    """

    def __init__(self, config):
        """Store the loader settings.

        Args:
            config: Mapping that provides ``'batch_size'``.
        """
        super().__init__()

        self.batch_size = config['batch_size']
        self.n_workers = 4
        # Guard so that the tensors are loaded and split only once per instance.
        self._splits_loaded = False

    def prepare_data(self):
        """Load the splits (kept for callers that invoke this hook manually).

        Lightning runs this hook on a single process in distributed training,
        so state assigned only here would be missing on the other ranks;
        ``setup`` performs the same idempotent load on every process.
        """
        self._load_splits()

    def setup(self, stage=None):
        """Make sure the train/validation tensors exist on this process.

        Args:
            stage: Stage name passed by the trainer; unused.
        """
        self._load_splits()

    def _load_splits(self):
        """Read the tensors from disk and populate ``x_*`` / ``y_*`` once."""
        if self._splits_loaded:
            return

        x = torch.load(SEVERITY_X_PATH)
        y = torch.load(SEVERITY_Y_PATH)

        ss_split = StratifiedShuffleSplit(test_size=0.33, random_state=42)
        # Only the first of the generated splits is needed.
        train_idx, val_idx = next(ss_split.split(x, y))

        self.x_train = x[train_idx]
        self.x_val = x[val_idx]
        self.y_train = y[train_idx]
        self.y_val = y[val_idx]
        self._splits_loaded = True

    def train_dataloader(self):
        """Return a shuffling loader over the training split."""
        train_split = TensorDataset(self.x_train, self.y_train)
        return DataLoader(
            train_split,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.n_workers,
        )

    def val_dataloader(self):
        """Return a non-shuffling loader over the validation split."""
        val_split = TensorDataset(self.x_val, self.y_val)
        return DataLoader(
            val_split,
            batch_size=self.batch_size,
            num_workers=self.n_workers,
        )



# %%

class SeverityClf(pl.LightningModule):
    """LSTM classifier producing logits over the four severity classes.

    The batch-first input is run through an LSTM; the LSTM output at the last
    sequence position is projected to 4 logits by a linear layer. Training
    uses a class-weighted cross-entropy loss; validation additionally tracks
    weighted F1, precision and recall.
    """

    def __init__(self, config, input_dim):
        """Build the network, loss and validation metrics.

        Args:
            config: Mapping with ``lr``, ``dropout_rate``, ``batch_size``,
                ``lstm_hidden_size``, ``n_lstm_layers`` and ``weights``
                (passed as ``weight`` to ``nn.CrossEntropyLoss``).
            input_dim: Size of the last dimension of the input, used as the
                LSTM ``input_size``.
        """
        super().__init__()
        self.save_hyperparameters()
        self.lr = config["lr"]
        self.dropout_rate = config["dropout_rate"]

        self.batch_size = config["batch_size"]

        n_classes = 4
        hidden_size = config["lstm_hidden_size"]
        n_layers = config["n_lstm_layers"]
        # nn.LSTM applies dropout only between stacked layers. With a single
        # layer a non-zero value has no effect and only triggers a
        # UserWarning, so pass 0.0 in that case (results are unchanged).
        lstm_dropout = self.dropout_rate if n_layers > 1 else 0.0

        # Input shape is (batch_size, seq_len,  n_dim)
        self.lstm = nn.LSTM(input_size=input_dim,
                            hidden_size=hidden_size,
                            num_layers=n_layers,
                            dropout=lstm_dropout,
                            batch_first=True)
        self.linear = nn.Linear(hidden_size, out_features=n_classes)
        self.loss = nn.CrossEntropyLoss(weight=config['weights'])
        self.val_f1_weighted = torchmetrics.F1Score(num_classes=n_classes, average='weighted')
        self.val_precision_weighted = torchmetrics.Precision(num_classes=n_classes, average='weighted')
        self.val_recall_weighted = torchmetrics.Recall(num_classes=n_classes, average='weighted')

    def forward(self, x):
        """Return raw (un-normalised) class logits for a batch.

        Args:
            x: Batch-first input tensor fed to the LSTM.

        Returns:
            Tensor of logits with 4 values per batch element.
        """
        # lstm_out: (batch_size, seq_len, lstm_hidden_size), i.e. the outputs
        # of the last LSTM layer; the final (h_n, c_n) states are not needed.
        lstm_out, _ = self.lstm(x)

        # Classify from the output at the last sequence position.
        # NOTE(review): if sequences are right-padded to a fixed length this
        # position may be a padding step -- confirm against dataset creation.
        return self.linear(lstm_out[:, -1, :])

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        """Compute and log the epoch-aggregated training loss for one batch."""
        x, y = train_batch
        output = self(x)

        loss = self.loss(output, y)
        self.log("train/loss", loss, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, val_batch, batch_idx):
        """Update and log the validation loss and weighted metrics."""
        x, y = val_batch
        output = self(x)
        loss = self.loss(output, y)

        # Metrics are fed hard labels (arg-max over the class dimension).
        hard_labels = torch.argmax(output, dim=-1)
        self.val_f1_weighted(hard_labels, y)
        self.val_recall_weighted(hard_labels, y)
        self.val_precision_weighted(hard_labels, y)

        # The metric objects themselves are logged, see
        # https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html
        self.log('val/tr_loss', loss, on_step=False, on_epoch=True)
        self.log('val/precision_weighted', self.val_precision_weighted, on_step=False, on_epoch=True)
        self.log('val/recall_weighted', self.val_recall_weighted, on_step=False, on_epoch=True)
        self.log('val/f1_weighted', self.val_f1_weighted, on_step=False, on_epoch=True)
        # https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html

    

# %%


def train_with_config(config, input_dim, num_gpus, enable_tune=True):
    """Fit a ``SeverityClf`` on the severity data and return the trainer.

    Args:
        config: Hyper-parameter mapping; ``'epochs'`` is read here and the
            whole mapping is forwarded to the model and the data module.
        input_dim: Input feature size forwarded to ``SeverityClf``.
        num_gpus: Value passed to the trainer's ``gpus`` argument.
        enable_tune: When true, report ``val/tr_loss`` and
            ``val/f1_weighted`` to Ray Tune at the end of validation;
            otherwise attach a TQDM progress bar instead.

    Returns:
        The ``pl.Trainer`` after ``fit`` has completed.
    """
    classifier = SeverityClf(config, input_dim)

    tune_reporter = TuneReportCallback(
        ["val/tr_loss", "val/f1_weighted"],
        on="validation_end",
    )
    if enable_tune:
        active_callbacks = [tune_reporter]
    else:
        active_callbacks = [pl.callbacks.progress.TQDMProgressBar()]

    trainer_kwargs = dict(
        max_epochs=config['epochs'],
        gpus=num_gpus,
        progress_bar_refresh_rate=0,
        callbacks=active_callbacks,
    )
    trainer = pl.Trainer(**trainer_kwargs)

    data_module = SeverityDataModule(config)
    trainer.fit(classifier, datamodule=data_module)
    return trainer


# NOTE(review): `app` is not defined in the visible part of this file;
# presumably a module-level `app = typer.Typer()` is expected -- confirm.
@app.command()
def hp_search(input_dim: int = 768, num_samples: int = 10, cpus_per_trial: int = 1, gpus_per_trial: int = 0, name='foo'):
    """
    Run HP search with ray tune

    Samples LSTM size/depth, learning rate, dropout, batch size and epoch
    count, and minimises "val/tr_loss" with the 'hyperopt' search algorithm.
    The class weights loaded from SEVERITY_CLASS_WEIGHTS_PATH are passed to
    every trial through the config. Returns the result of ``tune.run``.
    """
    loss_weights = torch.load(SEVERITY_CLASS_WEIGHTS_PATH)

    search_space = {
        "lstm_hidden_size": tune.choice([64, 128]),
        "n_lstm_layers": tune.choice([1, 2]),
        "lr": tune.loguniform(1e-5, 1e-3),
        "dropout_rate": tune.uniform(0.1, 0.4),
        "batch_size": tune.choice([8, 16, 32]),
        "epochs": tune.choice([5, 10, 20, 40]),
        # Constant entry (not sampled): forwarded unchanged to each trial.
        "weights": loss_weights,
    }

    # Bind the non-tuned arguments of the training function.
    trial_fn = tune.with_parameters(
        train_with_config,
        input_dim=input_dim,
        num_gpus=gpus_per_trial,
        enable_tune=True,
    )
    per_trial_resources = {"cpu": cpus_per_trial, "gpu": gpus_per_trial}

    analysis = tune.run(
        trial_fn,
        resources_per_trial=per_trial_resources,
        metric="val/tr_loss",
        search_alg='hyperopt',
        mode="min",
        config=search_space,
        num_samples=num_samples,
        max_concurrent_trials=8,
        name=name,
    )
    return analysis

#%%
#analysis = hp_search(num_samples=n_samples, cpus_per_trial=1,gpus_per_trial=0,name="severity")
#%%

# #%%
# best_trial_config = analysis.get_best_trial("loss", "min", "last").config
# best_trial_config
# # %%
# best_trainer = train_with_config( best_trial_config, 29, national_df, disable_logging=True, num_gpus=0)



