# In [None]:
#basics
import os
from math import sqrt
import pandas
import numpy as np
import random


#torch stuff
import torch
from torch.nn import functional as F
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
from torch.nn.utils import clip_grad_norm_
from torch_geometric.nn.models import AttentiveFP
from torch_geometric.loader import DataLoader

#stats
from scipy.stats import kendalltau

#HP tuning
from ray.tune.schedulers import ASHAScheduler
from ray import tune,train

#custom
from utils_data_prep import Prep_Graphs
from utils_plotting import plot_multiple_seeds
from utils_stats import perform_analysis

#other
from functools import partial

# Compute device shared by every function in this file: the GPU when CUDA is
# available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#set random seeds
# Every RNG used here (PyTorch CPU and CUDA, Python's `random`, NumPy) is
# seeded with the same fixed value, 0.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
random.seed(0)
np.random.seed(0)
# Request deterministic cuDNN kernels and switch off the cuDNN auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# NOTE(review): per the PyTorch docs this makes operations without a
# deterministic implementation raise at runtime, and on CUDA it may also need
# the CUBLAS_WORKSPACE_CONFIG environment variable -- confirm on the target GPU.
torch.use_deterministic_algorithms(True)



def get_path(sandbox):
    """Return the path prefix under which outputs are written.

    Args:
        sandbox: Truthy to place outputs below the ``sandbox/`` directory.

    Returns:
        ``'sandbox/'`` when ``sandbox`` is truthy, otherwise the empty string
        (i.e. paths stay relative to the working directory).
    """
    return 'sandbox/' if sandbox else ''


#Direct Learning
def hp_finder(config,data_train,data_val_1,num_epochs=50):
    """Ray Tune trainable: fit an AttentiveFP model and report validation Kendall tau.

    One model is trained for ``num_epochs`` epochs with the hyperparameters in
    ``config``; after every epoch the Kendall tau between validation targets
    and predictions is reported under the key ``"kendall_tau"``.

    Args:
        config: Mapping with the keys ``"batch_size"``, ``"hidden_channels"``,
            ``"num_layers"``, ``"num_timesteps"``, ``"lr"``, ``"gamma"`` and
            ``"Scheduler"`` (``"ReduceLROnPlateau"`` or ``"ExponentialLR"``).
        data_train: Dataset of graphs used for optimisation.
        data_val_1: Dataset of graphs used for validation.
        num_epochs: Number of epochs to run.

    Raises:
        ValueError: If ``config["Scheduler"]`` names an unsupported scheduler.
    """
    train_loader = DataLoader(data_train, batch_size=config["batch_size"], shuffle=True)
    # Validation metrics do not depend on sample order, so the loader is not
    # shuffled; this also keeps evaluation from drawing on the global RNG.
    val_loader = DataLoader(data_val_1, batch_size=config["batch_size"], shuffle=False)

    # NOTE(review): in_channels=23 and edge_dim=11 are hard-coded; presumably
    # they match the graph featurisation used upstream -- confirm.
    model = AttentiveFP(in_channels=23, hidden_channels=config["hidden_channels"], out_channels=1,
                        edge_dim=11, num_layers=config["num_layers"], num_timesteps=config["num_timesteps"],
                        dropout=0.0).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"],
                                weight_decay=10**-5)

    # ReduceLROnPlateau is stepped with the monitored metric, ExponentialLR
    # without arguments; remember which kind was requested.
    steps_on_metric = config["Scheduler"] == "ReduceLROnPlateau"
    if steps_on_metric:
        # `verbose` is left at its default (False); the argument is deprecated
        # in recent PyTorch releases.
        scheduler = ReduceLROnPlateau(optimizer, patience=2, factor=config["gamma"])
    elif config["Scheduler"] == "ExponentialLR":
        scheduler = ExponentialLR(optimizer, gamma=config["gamma"])
    else:
        raise ValueError('Scheduler not found')

    for epoch in range(num_epochs):
        model.train()
        train_func(train_loader, model, optimizer)

        model.eval()
        # No backward pass happens during evaluation, so disable autograd to
        # avoid building (and holding on to) computation graphs.
        with torch.no_grad():
            val_loss = val_func(val_loader, model)
            # get_preds returns (predictions, targets) in that order.
            preds, ys = get_preds(val_loader, model)

        if steps_on_metric:
            scheduler.step(val_loss)
        else:
            scheduler.step()

        k_tau = kendalltau(ys, preds).statistic
        train.report({"kendall_tau": k_tau})

def train_func(train_loader, model, optimizer):
    """Run one optimisation pass over ``train_loader`` and return its RMSE.

    Every batch is moved to the module-level ``device``, the MSE between the
    model output and the targets (reshaped to a column) is back-propagated,
    and one optimizer step is taken.

    Args:
        train_loader: Iterable of graph batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``y`` and ``num_graphs``.
        model: Model called as ``model(x, edge_index, edge_attr, batch)``.
        optimizer: Optimizer updating the model parameters.

    Returns:
        Square root of the per-batch MSE averaged with graph-count weights.
    """
    weighted_mse_sum = 0
    graphs_seen = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        # Targets become a column so they line up with the (N, 1) model output.
        batch.y = batch.y.view(-1, 1)
        batch_loss = F.mse_loss(prediction, batch.y)
        batch_loss.backward()
        optimizer.step()
        # Weight each batch by its graph count so a smaller final batch does
        # not skew the epoch average.
        weighted_mse_sum += float(batch_loss) * batch.num_graphs
        graphs_seen += batch.num_graphs
        del batch

    mean_mse = weighted_mse_sum / graphs_seen
    return sqrt(mean_mse)


def val_func(val_loader, model):
    """Evaluate ``model`` on ``val_loader`` and return the RMSE.

    The model's train/eval mode is not changed here; callers are expected to
    call ``model.eval()`` beforehand.

    Args:
        val_loader: Iterable of graph batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``y`` and ``num_graphs``.
        model: Model called as ``model(x, edge_index, edge_attr, batch)``.

    Returns:
        Square root of the per-batch MSE averaged with graph-count weights.

    Raises:
        ZeroDivisionError: If the loader yields no graphs.
    """
    total_loss = total_examples = 0
    # Nothing is back-propagated during evaluation, so autograd is disabled to
    # avoid building a computation graph for every batch.
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            # Targets become a column so they line up with the (N, 1) output.
            loss = F.mse_loss(out, data.y.view(-1, 1))
            total_loss += float(loss) * data.num_graphs
            total_examples += data.num_graphs

    return sqrt(total_loss / total_examples)

def get_preds(val_loader, model):
    """Collect model predictions and targets for every graph in ``val_loader``.

    The model's train/eval mode is not changed here; callers are expected to
    call ``model.eval()`` beforehand.

    Args:
        val_loader: Iterable of graph batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch`` and ``y``.
        model: Model called as ``model(x, edge_index, edge_attr, batch)``;
            each output row must hold a single value.

    Returns:
        Tuple ``(preds, ys)`` of two equally ordered lists of Python floats:
        predictions first, targets second.
    """
    preds,ys = [],[]
    # Inference only: without no_grad every stored output would keep its
    # autograd graph alive until the whole loader had been consumed.
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            # Convert per batch (one host copy each) so no device tensors are
            # retained across iterations.
            preds.extend(float(p) for p in out.cpu())
            ys.extend(float(y) for y in data.y.cpu())
    # NOTE(review): diagnostic output kept from the original; it prints the
    # number of evaluated samples on every call -- consider removing.
    print(len(ys))
    return preds,ys

def hp_serach_direct(data_train,data_val):
    """Run a Ray Tune hyperparameter search and return the best trial's metrics.

    100 configurations are sampled from the search space below and trained
    with ``hp_finder`` under an ASHA scheduler that maximises ``kendall_tau``.

    Args:
        data_train: Training dataset forwarded to ``hp_finder``.
        data_val: Validation dataset forwarded to ``hp_finder``.

    Returns:
        The metrics dataframe of the trial that reached the highest
        ``kendall_tau`` in any of its reported epochs. Presumably the frame
        also carries that trial's hyperparameters as ``config/<name>``
        columns -- confirm before relying on it.

    Raises:
        RuntimeError: If no trial reported a usable ``kendall_tau`` value.
    """
    search_space = {
        #training params
        "lr": tune.loguniform(1e-4, 1e-2),
        "batch_size": tune.choice([8, 16, 32]),
        #model params
        "hidden_channels": tune.choice([100, 200, 300]),
        "num_layers": tune.choice([2, 3, 4]),
        "num_timesteps": tune.choice([1,2,3]),
        #scheduler params
        "gamma": tune.loguniform(0.9, 0.99),
        "Scheduler": tune.choice(["ReduceLROnPlateau", "ExponentialLR"]),
    }

    tuner = tune.Tuner(
        partial(hp_finder, data_train=data_train, data_val_1=data_val),
        tune_config=tune.TuneConfig(
            num_samples=100,
            scheduler=ASHAScheduler(metric="kendall_tau", mode="max"),
        ),
        param_space=search_space,
    )
    results = tuner.fit()

    # Start from -inf (not 0) so a winner is still selected when every trial
    # ends up with a non-positive correlation.
    best_tau = float("-inf")
    best_trial = None
    for result in results:
        metrics = result.metrics_dataframe
        # Trials that failed before their first report have nothing to compare.
        if metrics is None or "kendall_tau" not in metrics:
            continue
        # Series.max() skips NaN (e.g. tau of constant predictions), unlike
        # the builtin max(), whose result depends on where the NaN sits.
        trial_tau = metrics["kendall_tau"].max()
        if trial_tau > best_tau:
            best_tau = trial_tau
            best_trial = metrics

    if best_trial is None:
        raise RuntimeError("No trial reported a valid 'kendall_tau' value")

    return best_trial

#now production run
def train_direct(train_loader, model, optimizer):
    """Train ``model`` for one epoch on ``train_loader`` and return the RMSE.

    Args:
        train_loader: Iterable of graph batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``y`` and ``num_graphs``.
        model: Model called as ``model(x, edge_index, edge_attr, batch)``.
        optimizer: Optimizer updating the model parameters.

    Returns:
        Square root of the per-batch MSE averaged with graph-count weights.
    """
    running_loss = 0
    running_count = 0
    for graphs in train_loader:
        graphs = graphs.to(device)
        optimizer.zero_grad()
        output = model(graphs.x, graphs.edge_index, graphs.edge_attr, graphs.batch)
        # Reshape targets into a column to match the (N, 1) model output.
        graphs.y = graphs.y.view(-1, 1)
        mse = F.mse_loss(output, graphs.y)
        mse.backward()
        optimizer.step()
        # Per-batch MSE is a mean, so scale it back by the batch's graph count.
        running_loss += float(mse) * graphs.num_graphs
        running_count += graphs.num_graphs

    epoch_mse = running_loss / running_count
    return sqrt(epoch_mse)

def val_direct(train_loader, model):
    """Evaluate ``model`` on a loader and return the RMSE.

    The model's train/eval mode is not changed here; callers are expected to
    call ``model.eval()`` beforehand.

    Args:
        train_loader: Loader to evaluate. Despite the name it is simply the
            data to score; ``train_val_direct`` passes its validation loader.
        model: Model called as ``model(x, edge_index, edge_attr, batch)``.

    Returns:
        Square root of the per-batch MSE averaged with graph-count weights.

    Raises:
        ZeroDivisionError: If the loader yields no graphs.
    """
    total_loss = total_examples = 0
    # Nothing is back-propagated during evaluation, so autograd is disabled to
    # avoid building a computation graph for every batch.
    with torch.no_grad():
        for data in train_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            # Targets become a column so they line up with the (N, 1) output.
            loss = F.mse_loss(out, data.y.view(-1, 1))
            total_loss += float(loss) * data.num_graphs
            total_examples += data.num_graphs
    return sqrt(total_loss / total_examples)


def train_val_direct(model, train_loader, val_loader, optimizer, num_epochs,tid,assay_type,weight_seed,split_seed,verbose=True,sandbox=False,run_name='default',finetune_lr=None):
    """Train with early stopping and checkpoint the best model by validation loss.

    Each epoch trains on ``train_loader``, evaluates on ``val_loader``, steps a
    ``ReduceLROnPlateau`` scheduler on the validation loss and, whenever that
    loss reaches a new minimum, saves ``model.state_dict()``. Training stops
    once more than 10 consecutive epochs pass without a new minimum.

    Args:
        model: Model to train; it is moved to the module-level ``device``.
        train_loader: Loader with the training batches.
        val_loader: Loader with the validation batches.
        optimizer: Optimizer updating the model parameters.
        num_epochs: Maximum number of epochs.
        tid: Identifier embedded in the checkpoint file name.
        assay_type: Identifier embedded in the checkpoint file name.
        weight_seed: Identifier embedded in the checkpoint file name.
        split_seed: Identifier embedded in the checkpoint file name.
        verbose: Print per-epoch losses and the early-stopping notice.
        sandbox: Write checkpoints below ``sandbox/`` instead of the working
            directory.
        run_name: Suffix embedded in the checkpoint file name.
        finetune_lr: When truthy, also embedded in the checkpoint file name.

    Returns:
        None. The best weights are written to ``<prefix>models/``.
    """
    path = get_path(sandbox)
    # The checkpoint name never changes during a run, so build it once.
    if finetune_lr:
        checkpoint_path = path+f'models/model_state_dict_{tid}_{assay_type}_single_{weight_seed}_{split_seed}_finetune_lr_{finetune_lr}_{run_name}.pt'
    else:
        checkpoint_path = path+f'models/model_state_dict_{tid}_{assay_type}_single_{weight_seed}_{split_seed}_{run_name}.pt'
    # torch.save does not create missing directories; make sure the target
    # exists up front rather than failing after the first epoch.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # `verbose` is left at its default (False); the argument is deprecated in
    # recent PyTorch releases.
    scheduler = ReduceLROnPlateau(optimizer, patience=2, factor=0.9)
    model = model.to(device)

    # inf (instead of an arbitrary finite sentinel) guarantees that the first
    # finite validation loss counts as an improvement and gets checkpointed.
    min_val_los = float('inf')
    # Epochs since the last improvement. Initialised here so a first epoch
    # without improvement (e.g. a NaN loss) cannot hit an unbound variable.
    counter = 0
    for epoch in range(num_epochs):
        model.train()
        train_loss = train_direct(train_loader, model, optimizer)

        model.eval()
        val_loss = val_direct(val_loader, model)
        scheduler.step(val_loss)

        if val_loss < min_val_los:
            min_val_los = val_loss
            counter = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            counter += 1

        if counter > 10:
            if verbose:
                print('early stopping')
            break
        if verbose:
            print(f"Epoch {epoch + 1}/{num_epochs}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    return


