# Imports

In [None]:
# List the GPUs on this machine with their current utilisation and memory use,
# so a free device can be identified.
!nvidia-smi

In [None]:
import tensorflow as tf

# Query the physical GPUs once and fail with an explicit message when none is
# available (the rest of the notebook assumes GPU execution).
gpus = tf.config.list_physical_devices('GPU')
assert gpus, 'No GPU is visible to TensorFlow; a GPU is required to run this notebook.'

# Restrict TensorFlow to the first GPU only. This must run before the GPU is
# initialised, i.e. before any op is placed on it.
tf.config.set_visible_devices(gpus[0], 'GPU')

# Fix TensorFlow's global random seed.
tf.random.set_seed(1)

In [None]:
import sonnet as snt
from tqdm import tqdm, trange
from IPython.display import clear_output
import numpy as np
import pandas as pd
import time
import os
import selene_sdk
from torch.utils.tensorboard import SummaryWriter
import csv
import torch
import math

import sklearn.metrics
import scipy.stats
import scipy.stats

assert snt.__version__.startswith('2.0')

from filesampler import RPKMFileSampler
from utils import *

# Samplers

In [None]:
# Keyword arguments common to the training and validation samplers; only the
# BED file and the number of cell types differ between the two.
shared_sampler_kwargs = dict(
    sequence_length=196_608,
    # sequence_length=115200,  # works with attention pooling
    balance=False,
    zero_expression=None,
    keep_zero_percent=None,
)

# Training sampler (fold 1, 39 cell types). Each sampler gets its own Genome handle.
sampler = RPKMFileSampler(
    '/home/evmalkin/DeepCT/src/data/cross_validation/fold1/train_data_rpkm_log.bed',
    reference_sequence=selene_sdk.sequences.Genome('/mnt/datasets/DeepCT/male.hg19.fasta'),
    n_cell_types=39,
    **shared_sampler_kwargs)

# Validation sampler (fold 1 hold-out, 10 cell types).
validation_sampler = RPKMFileSampler(
    '/home/evmalkin/DeepCT/src/data/cross_validation/fold1_val/validate_data_rpkm_log.bed',
    reference_sequence=selene_sdk.sequences.Genome('/mnt/datasets/DeepCT/male.hg19.fasta'),
    n_cell_types=10,
    **shared_sampler_kwargs)

# Number of samples drawn from `validation_sampler` to build the validation set.
n_validation_samples = 5781

# Embeddings passed to the model alongside the sequence, one file per split.
train_embeddings = get_embeddings(
    '/home/evmalkin/DeepCT/src/data/cross_validation/fold1/selected_feature_embeddings_len5.csv')
validation_embeddings = get_embeddings(
    '/home/evmalkin/DeepCT/src/data/cross_validation/fold1_val/selected_feature_embeddings_len5.csv')

# Define Model

In [None]:
# Selects the objective computed by `criterion`:
#   'MSE'         - mean squared error between squeezed targets and predictions
#   'correlation' - the combined correlation + MSE cost from `correlation_loss`
loss_type = 'MSE' # 'MSE' or 'correlation'

In [None]:
def criterion(targets, predictions, evaluate=False, num_average=300):
    """Compute the loss selected by the module-level `loss_type`.

    Args:
        targets: Ground-truth values for the batch.
        predictions: Model outputs for the batch.
        evaluate: When True, the validation correlation history
            (`prev_correlations_validate`) is read and updated instead of the
            training one (`prev_correlations_train`). Only used by the
            'correlation' loss.
        num_average: Number of most recent correlations averaged to obtain the
            fallback value passed to `correlation_loss`.

    Returns:
        The scalar loss for the batch.

    Raises:
        ValueError: If `loss_type` is neither 'MSE' nor 'correlation'.
    """
    if loss_type == 'MSE':
        return tf.reduce_mean(tf.keras.losses.MSE(tf.squeeze(targets), tf.squeeze(predictions)))

    if loss_type == 'correlation':
        # Training and evaluation keep separate running histories.
        history = prev_correlations_validate if evaluate else prev_correlations_train
        avg_correlation = np.mean(history[-num_average:], dtype='float32')
        loss, correlation = correlation_loss(targets, predictions, avg_correlation)
        history.append(correlation)
        return loss

    # Previously an unknown value fell through and returned None, which only
    # surfaced later as an obscure failure when computing gradients.
    raise ValueError(
        "Unknown loss_type {!r}; expected 'MSE' or 'correlation'".format(loss_type))

# Running histories of per-batch correlations. `criterion` averages the most
# recent entries of the relevant list and appends the new correlation after each
# call (train list when evaluate=False, validate list when evaluate=True).
# Each list is seeded with a single 0, presumably so the first average is defined.
prev_correlations_train = [0]
prev_correlations_validate = [0]     
def correlation_loss(targets, predictions, prev_correlation, corr_weight=10, mse_weight=0.1):
    """Combined cost: weighted (1 - Pearson correlation) plus weighted MSE.

    Args:
        targets: Ground-truth values; singleton dimensions are squeezed away.
        predictions: Model outputs; singleton dimensions are squeezed away.
        prev_correlation: Value used as the correlation when the targets have
            zero variance (the correlation is undefined in that case).
        corr_weight: Multiplier of the (1 - correlation) term.
        mse_weight: Multiplier of the mean-squared-error term.

    Returns:
        A `(cost, correlation)` tuple.
    """
    x = tf.squeeze(predictions)
    y = tf.squeeze(targets)

    # Centre both signals.
    vx = x - tf.math.reduce_mean(x)
    vy = y - tf.math.reduce_mean(y)

    vary = tf.math.reduce_sum(vy ** 2)

    if vary > 0:
        correlation = tf.math.reduce_sum(vx * vy) / (tf.math.sqrt(tf.math.reduce_sum(vx ** 2)) * tf.math.sqrt(vary))
    else:
        # Constant targets: fall back to the caller-supplied running average.
        correlation = prev_correlation
    # NOTE(review): constant *predictions* are not guarded and would still
    # divide by zero above - confirm whether that case can occur.

    # Use the squeezed tensors here as well: with the raw inputs, shapes such as
    # (B, 1) vs (B,) silently broadcast to (B, B) and yield a wrong MSE.
    mse = tf.math.reduce_mean((x - y) ** 2)

    cost = corr_weight * (1 - correlation) + mse_weight * mse

    return cost, correlation

In [None]:
import importlib
import enformer

# Reload the local `enformer` module so edits to it are picked up without
# restarting the kernel.
importlib.reload(enformer)

# Learning-rate schedule settings, read by `decayed_learning_rate`.
initial_learning_rate = 0.00005
decay_steps = 5000
alpha = 1/20

# The learning rate lives in a non-trainable variable so it can be reassigned
# while training without rebuilding the optimizer.
learning_rate = tf.Variable(initial_learning_rate, trainable=False, name='learning_rate')

optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

model = enformer.Enformer(
    channels=1536 // 4,
    num_heads=4,
    num_transformer_layers=5,
    pooling_type='max',
)

# NOTE(review): `create_step_function` is defined in a later cell of this
# notebook; on a fresh kernel this call relies on the name already being in
# scope (e.g. via `from utils import *`) - confirm.
train_step = create_step_function(model, optimizer)

In [None]:
def decayed_learning_rate(step):
    """Set the global `learning_rate` variable according to a cosine decay.

    The rate starts at `initial_learning_rate` (step 0) and follows half a
    cosine period down to `alpha * initial_learning_rate`, which is reached at
    `decay_steps` and held for every later step.

    Args:
        step: Number of optimizer updates performed so far.
    """
    # Steps beyond the decay horizon are clamped to it.
    clipped_step = decay_steps if step > decay_steps else step
    cosine_term = math.cos(math.pi * clipped_step / decay_steps)
    cosine_decay = 0.5 * (1 + cosine_term)
    # Rescale so that the multiplier never drops below `alpha`.
    multiplier = (1 - alpha) * cosine_decay + alpha
    learning_rate.assign(initial_learning_rate * multiplier)

# Training/Evaluation setup

In [None]:
# The training loop runs for num_epochs * steps_per_epoch iterations, drawing
# one sample from the sampler per iteration.
steps_per_epoch = 10000
num_epochs = 50
# Validation and metric logging happen every this many loop iterations.
report_stats_every_n_steps = 5000
# Number of consecutive iterations whose gradients are accumulated (each scaled
# by 1/n_batches) before a single optimizer update is applied.
n_batches = 64

In [None]:
def create_step_function(model, optimizer):
    """Build a `tf.function` that runs one forward/backward pass of `model`.

    The returned function does not apply the gradients; it hands them back so
    that the caller can accumulate them over several batches.

    Args:
        model: Callable as `model(sequence, embeddings, is_training=...)` and
            exposing `trainable_variables`.
        optimizer: Accepted for interface compatibility; not used here.

    Returns:
        A function `train_step(sequence, targets, embeddings)` returning
        `(loss, gradients, outputs)`.
    """

    def train_step(sequence, targets, embeddings):
        # Record the forward pass so the loss can be differentiated.
        with tf.GradientTape() as tape:
            outputs = model(sequence, embeddings, is_training=True)
            loss = criterion(targets, outputs)

        return loss, tape.gradient(loss, model.trainable_variables), outputs

    return tf.function(train_step)


def evaluate_model(model, dataset, embeddings):
    """Run `model` in inference mode over `dataset` and collect losses and outputs.

    Args:
        model: Callable as `model(sequence, embeddings, is_training=False)`.
        dataset: Iterable of sample batches exposing
            `_input_batches['sequence_batch']` and `_target_batch`.
        embeddings: Embeddings passed to the model with every batch.

    Returns:
        A tuple `(mean_loss, predictions, targets)`: the average of the
        per-batch losses, and arrays built from the flattened predictions and
        flattened targets of each batch (one entry per batch).
    """
    @tf.function
    def predict(sequence, embeddings):
        return model(sequence, embeddings, is_training=False)

    all_predictions = []
    all_targets = []
    batch_losses = []
    for samples_batch in dataset:
        sequence = samples_batch._input_batches['sequence_batch']
        targets = samples_batch._target_batch

        output = predict(sequence, embeddings)
        all_predictions.append(output.numpy().flatten())
        all_targets.append(targets.numpy().flatten())

        # Use the same `criterion` as the training step (the previous call
        # targeted `correlation_criterion`, which this notebook never defines).
        # evaluate=True makes it use the validation correlation history.
        batch_losses.append(criterion(targets, output, evaluate=True))

    # Keep only the 300 most recent validation correlations.
    prev_correlations_validate[:-300] = []
    return np.average(batch_losses), np.array(all_predictions), np.array(all_targets)

# Output Directory

In [None]:
# Name of this run; outputs (checkpoints, TensorBoard logs) go under 'results/<direc>'.
direc = 'debug'

# Training

In [None]:
# Everything produced by this run is written below results/<direc>.
output_dir = 'results/' + direc

# Checkpointing: the manager keeps at most the two most recent checkpoints of
# the model under <output_dir>/checkpoints/.
checkpoint_root = os.path.join(output_dir, "checkpoints/")
checkpoint = tf.train.Checkpoint(module=model)
checkpoint_name = "best_model"
save_prefix = checkpoint_root
manager = tf.train.CheckpointManager(checkpoint, save_prefix, checkpoint_name=checkpoint_name, max_to_keep=2)

# TensorBoard event files are written directly into the output directory.
tensorboard_writer = SummaryWriter(os.path.join(output_dir))

# Validation set used for every evaluation during training.
# NOTE(review): `create_validation_set` is not defined in this notebook; presumably it comes from `utils`.
validation_data = create_validation_set(validation_sampler, n_samples = n_validation_samples)

# A single forward pass on one validation batch so that the model creates its
# variables before `model.trainable_variables` is read below.
evaluate_model(model, [validation_data[0]], validation_embeddings) # initialize model variables

# Best validation loss seen so far and number of optimizer updates applied.
min_loss = float("inf")
opt_step = 0

# Training statistics collected since the last report.
train_losses = []
train_preds = []
train_targets = []


# Gradient accumulation state.
# constant to scale sum of gradient
const = tf.constant(1/n_batches)
# get all trainable variables
t_vars = model.trainable_variables
# create a copy of all trainable variables with `0` as initial values
accum_tvars = [tf.Variable(tf.zeros_like(t_var),trainable=False) for t_var in t_vars]
# Zero the accumulators; the resulting list is not read anywhere in this cell.
zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_tvars]

num_steps = num_epochs * steps_per_epoch

for index in tqdm(range(num_steps)):
    # Draw one training batch.
    samples_batch, target_cells = sampler.sample()

    sequence = samples_batch._input_batches['sequence_batch']

    targets = samples_batch._target_batch

    # Forward/backward pass; gradients are returned, not applied.
    loss, gradients, train_outputs = train_step(sequence, targets, train_embeddings)
    train_losses.append(loss)
    train_preds.append(train_outputs.numpy().flatten())
    train_targets.append(targets.numpy().flatten())

    # Add this batch's gradients, scaled by 1/n_batches, to the accumulators.
    # The assignments take effect while the list is built; the list itself is
    # not read anywhere in this cell.
    accum_ops = [accum_tvars[i].assign_add(tf.scalar_mul(const, grad)) for i, grad in enumerate(gradients)]

    # Every n_batches iterations: apply the accumulated (averaged) gradients,
    # reset the accumulators and advance the learning-rate schedule.
    if (index + 1) % n_batches == 0:
        optimizer.apply(accum_tvars, model.trainable_variables)
        zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_tvars]

        opt_step += 1
        # LEARNING RATE DECAY
        decayed_learning_rate(opt_step)

    # Periodic evaluation and logging (skipped at index 0).
    if index and index % report_stats_every_n_steps == 0:
        validation_loss, validation_predictions, validation_targets = evaluate_model(model, validation_data, validation_embeddings)
        # Save a checkpoint only when the validation loss improves.
        if validation_loss < min_loss:
            min_loss = validation_loss
            save_checkpoint(manager, is_best=True)
            print('Saving best loss', min_loss)

        # Average the training loss since the last report, then clear the buffer in place.
        train_loss = np.average(train_losses)
        train_losses[:] = []

        # NOTE(review): `calc_metrics` / `log_metrics` are not defined in this notebook; presumably from `utils`.
        train_metrics = calc_metrics(np.array(train_targets), np.array(train_preds))
        train_targets[:] = []
        train_preds[:] = []

        validation_metrics = calc_metrics(validation_targets, validation_predictions)

        step = index
        log_metrics(tensorboard_writer, step, train_loss, validation_loss, train_metrics, validation_metrics)
        # Keep only the 300 most recent training correlations.
        prev_correlations_train[:-300] = []

# TODO: decide whether to also checkpoint the final weights; only models that
# improved the validation loss are saved above.
tensorboard_writer.flush()