<font size=5>**Define helper functions used in several other notebooks**</font>

In [None]:
import torch
import random
import datetime
import time
import matplotlib.pyplot as plt
import pickle

import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import csv

import torch.multiprocessing as mp

from torch.optim import Adam, AdamW
from torch.utils.data import Dataset, DataLoader, BatchSampler, RandomSampler,random_split, SequentialSampler
from torch.multiprocessing import Pool, Process, set_start_method

from tqdm import tqdm

from scipy import stats

import os.path
from sklearn.metrics import r2_score, mean_squared_error

import seaborn as sns

In [4]:
def MinMaxNormalization(x):
    """Linearly rescale values to the [0, 1] range.

    Parameters
    ----------
    x : array_like
        Values to normalize (a list or NumPy array, any shape). Must be
        non-empty; ``np.min`` raises ValueError on empty input.

    Returns
    -------
    numpy.ndarray
        ``(x - min(x)) / (max(x) - min(x))`` with the same shape as ``x``.
        If all values are equal (zero range), an array of zeros is returned
        instead of dividing 0 by 0 (which would produce all-NaN output).
    """
    values = np.asarray(x)
    x_min = np.min(values)
    x_max = np.max(values)
    value_range = x_max - x_min
    if value_range == 0:
        # Constant input: every value sits at the minimum.
        return np.zeros_like(values, dtype=float)
    return (values - x_min) / value_range

In [5]:
def estimate_density(data, precision=2, grid_min=0.0, grid_max=1.5):
    """Estimate a min-max-normalized density of ``data`` on a rounded grid.

    A Gaussian KDE is fitted to ``data`` and evaluated on an evenly spaced
    grid from ``grid_min`` to ``grid_max`` (both inclusive) with step
    ``10**-precision``. The densities are then min-max scaled to [0, 1].

    Parameters
    ----------
    data : array_like
        Samples passed to ``scipy.stats.gaussian_kde`` (presumably the
        training labels, e.g. viabilities -- confirm against callers).
    precision : int, default 2
        Number of decimals used for the grid step and for rounding the keys.
        It should match the precision used when looking keys up
        (e.g. ``get_weight``'s ``precision`` argument).
    grid_min, grid_max : float, default 0.0 and 1.5
        Inclusive bounds of the evaluation grid; ``grid_max - grid_min``
        should be a multiple of the step. Label values outside this range
        have no key in the returned dict.

    Returns
    -------
    dict
        Maps each rounded grid value (``numpy.float64``) to its normalized
        density.
    """
    # `stats` is bound by `from scipy import stats`; the bare name `scipy`
    # is not, so `scipy.stats` would raise NameError.
    density = stats.gaussian_kde(data)

    step = 10.0 ** (-precision)
    n_points = int(round((grid_max - grid_min) / step)) + 1
    # linspace + rounding yields exact keys at the requested precision,
    # including the inclusive end point.
    grid = np.around(np.linspace(grid_min, grid_max, n_points), precision)

    # gaussian_kde evaluates a whole 1-D array of points in one call.
    density_minmax = MinMaxNormalization(density(grid))
    return dict(zip(grid, density_minmax))

In [6]:
def get_weight(torch_tensor, gamma, density_dic, precision=2):
    """Compute per-sample loss weights ``1 - gamma * density(label)``.

    Parameters
    ----------
    torch_tensor : torch.Tensor
        Labels, shape ``(N,)`` or ``(N, k)``; for 2-D input only the first
        column is used.
    gamma : float
        Strength of the down-weighting of high-density labels.
    density_dic : dict
        Maps label values rounded to ``precision`` decimals to a density
        (e.g. the output of ``estimate_density``).
    precision : int, default 2
        Decimals used to round labels before the dictionary look-up.

    Returns
    -------
    torch.Tensor
        Weights of shape ``(N,)`` in the default float dtype, on the same
        device as ``torch_tensor``.

    Raises
    ------
    KeyError
        If a rounded label value is not a key of ``density_dic``.
    """
    # detach() so labels that require grad can still be converted to NumPy;
    # float64 before rounding so float32 noise (0.2899999...) rounds correctly.
    values = torch_tensor.detach().cpu().numpy().astype(np.float64)
    first_column = values if values.ndim == 1 else values[:, 0]
    labels = np.around(first_column, precision)

    try:
        density = np.array([density_dic[label] for label in labels])
    except KeyError as err:
        raise KeyError(
            f"label value {err.args[0]!r} (rounded to {precision} decimals) "
            "is not a key of density_dic; check the grid range used to build it"
        ) from err

    weight_for_loss = 1 - gamma * density
    # Place the weights next to the labels instead of relying on a global
    # `device` variable.
    return torch.as_tensor(
        weight_for_loss, dtype=torch.get_default_dtype(), device=torch_tensor.device
    )

In [7]:
def PCC(y_pred, y_true):
    """Pearson correlation coefficient between predictions and targets.

    Means and sums are taken over all elements of the (broadcast) inputs.
    Returns a 0-d tensor, and it stays differentiable, so it can be used
    inside a loss. The result is NaN when either input is constant, because
    that makes the denominator zero.
    """
    pred_centered = y_pred - y_pred.mean()
    true_centered = y_true - y_true.mean()
    covariance = (pred_centered * true_centered).sum()
    pred_norm = pred_centered.pow(2).sum().sqrt()
    true_norm = true_centered.pow(2).sum().sqrt()
    return covariance / (pred_norm * true_norm)

def MSE(y_pred, y_true):
    """Mean squared error over all (broadcast) elements, as a 0-d tensor."""
    residual = y_pred - y_true
    return residual.pow(2).mean()

def RMSE(y_pred, y_true):
    """Root mean squared error over all (broadcast) elements, as a 0-d tensor."""
    mean_squared = (y_pred - y_true).pow(2).mean()
    return mean_squared.sqrt()
    
def WeightedMSE(y_pred, y_true, density_dic, gamma=0.75):
    """Density-weighted mean squared error.

    Each sample's squared error is scaled by ``1 - gamma * density(label)``
    (see ``get_weight``), so frequent label values contribute less.

    Parameters
    ----------
    y_pred, y_true : torch.Tensor
        Predictions and labels of the same shape, typically ``(N, 1)``.
    density_dic : dict
        Rounded label value -> density look-up (e.g. from ``estimate_density``).
    gamma : float, default 0.75
        Strength of the density down-weighting.

    Returns
    -------
    torch.Tensor
        0-d tensor with the weighted mean squared error.
    """
    weight = get_weight(torch_tensor=y_true, gamma=gamma, density_dic=density_dic)
    squared_error = (y_pred - y_true) ** 2
    # get_weight returns one weight per sample with shape (N,). Reshape it to
    # (N, 1, ...) so it scales each sample's row. Otherwise (N,) * (N, 1)
    # broadcasts to an (N, N) matrix, and the mean collapses to
    # mean(weight) * MSE, which loses the per-sample weighting.
    weight = weight.view(-1, *([1] * (squared_error.dim() - 1)))
    return torch.mean(weight * squared_error)

class CustomLoss(nn.Module):
    """Combined loss: ``alpha * WeightedMSE + beta * (1 - PCC)``.

    ``forward`` returns a tuple ``(loss, mse, corr)``:

    - ``loss`` is the value to optimize.
    - ``mse`` is the plain (unweighted) MSE.
    - ``corr`` is the Pearson correlation between predictions and labels.
    """

    def __init__(self, density_dic, alpha=1, beta=0.5, gamma=0.75):
        """
        Parameters
        ----------
        density_dic : dict
            Rounded label value -> density look-up used by ``WeightedMSE``.
        alpha : float, default 1
            Weight of the density-weighted MSE term.
        beta : float, default 0.5
            Weight of the ``1 - correlation`` term.
        gamma : float, default 0.75
            Strength of the density down-weighting inside ``WeightedMSE``.
        """
        # nn.Module.__init__ must run before attributes are assigned on the
        # module (documented requirement of nn.Module subclasses).
        super(CustomLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.density_dic = density_dic

    def forward(self, y_pred, y_true):
        """Return ``(loss, mse, corr)`` for a batch of predictions and labels."""
        weighted_mse = WeightedMSE(y_pred=y_pred, y_true=y_true, gamma=self.gamma, density_dic=self.density_dic)
        mse = MSE(y_pred, y_true)
        corr = PCC(y_pred, y_true)
        loss = self.alpha * weighted_mse + self.beta * (1 - corr)
        return loss, mse, corr

class CustomLoss_withoutDensity(nn.Module):
    """Combined loss without density weighting: ``alpha * MSE + beta * (1 - PCC)``.

    ``forward`` returns a tuple ``(loss, mse, corr)``:

    - ``loss`` is the value to optimize.
    - ``mse`` is the plain MSE.
    - ``corr`` is the Pearson correlation between predictions and labels.
    """

    def __init__(self, alpha=1, beta=1):
        """
        Parameters
        ----------
        alpha : float, default 1
            Weight of the MSE term.
        beta : float, default 1
            Weight of the ``1 - correlation`` term.
        """
        # nn.Module.__init__ must run before attributes are assigned on the
        # module (documented requirement of nn.Module subclasses).
        super(CustomLoss_withoutDensity, self).__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, y_pred, y_true):
        """Return ``(loss, mse, corr)`` for a batch of predictions and labels."""
        mse = MSE(y_pred, y_true)
        corr = PCC(y_pred, y_true)
        loss = self.alpha * mse + self.beta * (1 - corr)
        return loss, mse, corr

In [8]:
def batch_dot(tensor1,tensor2,batch_size=1024):
    """Dot products of ``tensor1`` against each row of ``tensor2``, as a column.

    ``tensor1`` gets a new leading axis, is multiplied elementwise with
    ``tensor2``, and the result is summed over the last axis. It is then
    reshaped to ``(-1, 1)``. For example, ``tensor1`` of shape ``(d,)`` and
    ``tensor2`` of shape ``(N, d)`` give an ``(N, 1)`` result.

    ``batch_size`` is currently unused; it is kept only for interface
    compatibility.
    """
    elementwise = tensor1.unsqueeze(0) * tensor2
    row_sums = elementwise.sum(dim=-1)
    return row_sums.view(-1, 1)

In [9]:
def get_tensor_value(tensor):
    """Return the tensor's data as a NumPy array, detached from autograd and moved to CPU."""
    detached = tensor.detach()
    return detached.cpu().numpy()

In [10]:
import datetime
def timestamp2datetime(timestamp):
    """Format a POSIX timestamp as local time using the locale's '%c' representation."""
    local_moment = datetime.datetime.fromtimestamp(timestamp)
    return local_moment.strftime('%c')

In [None]:
class Hook:
    """Capture the inputs and output of a module's forward pass.

    A forward hook is registered on ``module`` at construction. After each
    forward call, ``self.i`` holds the hook's input tuple and ``self.o`` the
    module's output. Call ``close()`` to unregister the hook.
    """

    def __init__(self, module):
        # Keep the handle returned by PyTorch so the hook can be removed later.
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, i, o):
        """Forward-hook callback: remember the latest inputs ``i`` and output ``o``."""
        self.i, self.o = i, o

    def close(self):
        """Unregister the forward hook from the module."""
        self.hook.remove()

In [11]:
def get_intermediate_output(dataloader, model, target_layer):
    """Run ``model`` over ``dataloader`` and collect one layer's outputs.

    Parameters
    ----------
    dataloader : iterable
        Yields samples accepted by ``MonotherapyDataset2device``.
    model : torch.nn.Module
        The model. It is switched to eval mode and left in eval mode.
    target_layer : str
        Name of a direct child module of ``model`` (looked up in
        ``model._modules``); raises KeyError if absent.

    Returns
    -------
    tuple of torch.Tensor
        ``(intermediate_output, final_output)``: the target layer's outputs
        and the model's outputs, each concatenated over all batches along
        dim 0. An empty dataloader makes ``torch.cat`` raise.
    """
    model.eval()
    hook = Hook(model._modules[target_layer])
    intermediate_output_list = []
    final_output_list = []
    try:
        with torch.no_grad():
            for sample in dataloader:
                # NOTE(review): `MonotherapyDataset2device` and `device` are
                # globals not defined in this section -- presumably provided
                # by the notebook that uses this helper.
                X, _ = MonotherapyDataset2device(sample, device)
                final_output_list.append(model(X))
                intermediate_output_list.append(hook.o)
    finally:
        # Always unregister the hook. Without this, every call leaves one more
        # forward hook attached to the layer, and each runs on all later forwards.
        hook.close()

    intermediate_output = torch.cat(intermediate_output_list)
    final_output = torch.cat(final_output_list)
    return intermediate_output, final_output