In [1]:
import os
import sys

# OpenBLAS reads OPENBLAS_NUM_THREADS only once, when the shared library is first
# loaded (on the first `import numpy`, which optuna/torch/sklearn also trigger),
# so it must be set before any of those imports to have an effect.
os.environ["OPENBLAS_NUM_THREADS"] = "1"

# Make the local `pscapes` and `src` (nk-ml-2024) packages importable.
sys.path.append('../../pscapes')
sys.path.append('../../nk-ml-2024/')

import math
import pickle
import time

import optuna as opt
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from torch.utils.data import DataLoader
from pscapes.landscape_class import ProteinLandscape
from pscapes.utils import dict_to_np_array, np_array_to_dict

from src.architectures import SequenceRegressionCNN, SequenceRegressionLinear, SequenceRegressionMLP, SequenceRegressionLSTM, SequenceRegressionTransformer

from src.ml_utils import train_val_test_split_ohe, landscapes_ohe_to_numpy
from src.hyperopt import objective_NK, sklearn_objective_NK

from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor

from src.train_utils import train_models_from_hparams_NK, read_MLP_hparams, read_CNN_hparams, read_LSTM_hparams, read_transformer_hparams, instantiate_model_from_study

from src.analysis import get_latent_representation, adjacency_to_diag_laplacian, sparse_dirichlet
import matplotlib.pyplot as plt
from sklearn.neighbors import kneighbors_graph
import networkx as nx
from scipy.sparse import diags

import torchmetrics
from torchmetrics.regression import SpearmanCorrCoef, PearsonCorrCoef


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Global parameters for this notebook run.
HPARAM_PATH = '../hyperopt/results/NK_hyperopt_results.pkl'  # pickled hyperparameter-search results
DATA_PATH = '../data/nk_landscapes/'  # directory holding the k{K}_r{replicate}.csv landscape files
MODEL_SAVEPATH = '../models/models_K3/'
RESULT_PATH = '../results/results_K3/NK_train_test_results.pkl'
SEQ_LEN = 6           # sequence length N; landscapes for K = 0 .. SEQ_LEN-1 are loaded
AA_ALPHABET = 'ACDEFG'
N_REPLICATES = 4      # number of replicate landscapes per K value
RANDOM_SEED = 1       # same value as the random_state used for the train/val/test splits

# Seed the global numpy and torch RNGs so stochastic steps (e.g. weight
# initialisation) are reproducible under Restart & Run All.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

In [3]:
# Load every NK landscape: one inner list per K value (K = 0 .. SEQ_LEN-1), each
# holding N_REPLICATES replicates read from DATA_PATH/k{K}_r{replicate}.csv.
print('Loading landscapes.')
landscape_objects = [
    [
        ProteinLandscape(csv_path=os.path.join(DATA_PATH, 'k{0}_r{1}.csv'.format(k, r)),
                         amino_acids=AA_ALPHABET)
        for r in range(N_REPLICATES)
    ]
    for k in range(SEQ_LEN)
]

# `landscapes` holds the return values of fit_OHE() (presumably the one-hot
# encodings), with the same [K][replicate] nesting; the ProteinLandscape
# objects themselves stay available in `landscape_objects`.
landscapes = [[landscape.fit_OHE() for landscape in replicates] for replicates in landscape_objects]

print('Calculating train-test-val splits')
# One split per K value; each entry unpacks as
#   landscapes_ohe, xy_train, xy_val, xy_test, x_tests, y_tests = splits[k_index]
splits = [train_val_test_split_ohe(replicates_ohe, random_state=1) for replicates_ohe in landscapes]

Loading landscapes.
Calculating train-test-val splits


In [4]:
# Load the hyperparameter-search results. pickle can execute arbitrary code on
# load, so this must only ever point at a trusted, locally produced file.
with open(HPARAM_PATH, 'rb') as hparam_file:
    NK_hparams = pickle.load(hparam_file)

model_names = ['linear', 'mlp', 'cnn', 'ulstm', 'blstm', 'transformer']

# For each architecture, reuse the entry at index 3 (presumably the K=3 study,
# given the variable name) for every one of the SEQ_LEN K values.
NK_hparams_k3 = {
    arch_name: [arch_studies[3] for _ in range(SEQ_LEN)]
    for arch_name, arch_studies in NK_hparams.items()
}

## Implementing Latent Space Regularisation with kNN Dirichlet Energy

In [5]:
from src.ml_utils import EarlyStopping

In [7]:
# NOTE(review): leftover scratch arithmetic; the meaning of 49000 and 30 is not
# documented here — give them named constants or delete this cell.
49_000 / 30

1633.3333333333333

In [None]:

def get_latent_representationw(model, model_name, x_data):
    """
    Return the detached output of a model's final feature layer for ``x_data``.

    A temporary forward hook is attached to the layer selected by ``model_name``,
    one forward pass of ``model(x_data)`` is run, and the captured output is
    returned. The hook is always removed, even if the forward pass raises.

    Args:
        model (torch.nn.Module):  model whose internal representation is wanted.
        model_name (str):         one of 'mlp', 'cnn', 'ulstm', 'blstm', 'transformer';
                                  selects which layer is hooked.
        x_data (torch.Tensor):    input passed unchanged to ``model``.
    Returns:
        torch.Tensor: the hooked layer's output, detached from the autograd graph.
        For 'ulstm'/'blstm' only the last position along dim 1 of the LSTM output is kept.
    Raises:
        Exception:    if ``model_name`` is not recognised.
        RuntimeError: if the hooked layer was never executed during the forward pass.
    """
    if model_name == 'mlp':
        hooked_layer = model.fc_layers[-1]
    elif model_name == 'cnn':
        hooked_layer = list(model.children())[-2]  # final MaxPool1d layer (per the model's child order)
    elif model_name == 'ulstm' or model_name == 'blstm':
        hooked_layer = model.lstm
    elif model_name == 'transformer':
        hooked_layer = list(model.children())[-2]  # output of the Transformer module
    else:
        raise Exception('Model name not recognised.')

    captured = {}

    def forward_hook(module, inputs, output):
        # Stash the layer output; a dict avoids the need for `nonlocal`.
        captured['activation'] = output

    hook_handle = hooked_layer.register_forward_hook(forward_hook)
    try:
        model(x_data)
    finally:
        # Remove the hook even if the forward pass fails; otherwise it would stay
        # attached and fire on every later forward pass of this model.
        hook_handle.remove()

    if 'activation' not in captured:
        raise RuntimeError(
            'Hooked layer for model {0!r} was not executed during the forward pass.'.format(model_name))

    activation = captured['activation']
    if model_name == 'ulstm' or model_name == 'blstm':
        # Assuming model.lstm is a torch.nn.LSTM, it returns (output, (h_n, c_n)); keep the
        # output at the last sequence position.
        # NOTE(review): [:, -1, :] assumes the LSTM was built with batch_first=True — confirm.
        activation = activation[0][:, -1, :]
    return activation.detach()

def dirichlet_from_representation(rep_tensor, y_data, degree=30, n_jobs=-1):
    """
    Constructs a symmetric kNN graph from a model's representation and computes the
    Dirichlet energy of the signal ``y_data`` over that graph.

    Args:
        rep_tensor (torch.Tensor or array-like):  representations, one row per sample,
                                                  i.e. shape (n_samples, n_features).
        y_data:                                   signal on the samples, shape (n_samples, 1) or
                                                  (n_samples,); passed unchanged to ``sparse_dirichlet``.
        degree (int):                             requested neighbours per node; clamped to
                                                  n_samples - 1 so small (e.g. final) batches work.
        n_jobs (int):                             parallel jobs for the neighbour search (-1 = all cores).
    Returns:
        dirichlet_energy: the value returned by ``sparse_dirichlet`` for the graph Laplacian and ``y_data``.
    Raises:
        ValueError: if fewer than two samples are given (no neighbours can exist).
    """
    if isinstance(rep_tensor, torch.Tensor):
        # sklearn needs a host-side array; this also handles tensors living on a GPU.
        rep = rep_tensor.detach().cpu().numpy()
    elif not hasattr(rep_tensor, 'shape'):
        rep = np.asarray(rep_tensor)
    else:
        rep = rep_tensor

    n_samples = rep.shape[0]
    if n_samples < 2:
        raise ValueError(
            'At least two samples are needed to build a kNN graph, got {0}.'.format(n_samples))
    # kneighbors_graph excludes each point from its own neighbours, so at most
    # n_samples - 1 neighbours exist; a larger value would raise inside sklearn.
    n_neighbors = min(degree, n_samples - 1)

    A = kneighbors_graph(rep, n_neighbors=n_neighbors, n_jobs=n_jobs)  # compute adjacency
    A = A.maximum(A.T)  # kNN relations are not symmetric; symmetrise the adjacency
    L = adjacency_to_diag_laplacian(A)[1]  # compute Laplacian
    dirichlet_energy = sparse_dirichlet(L, y_data)

    return dirichlet_energy



def train_model_de_reg(model, model_name, optimizer, loss_fn, train_loader, val_loader, n_epochs=30, device='cpu',
                       patience=5, min_delta=1e-5, lambda_reg=0.1, knn_degree=30, x_data=None):
    """
    Train ``model`` on ``loss_fn + lambda_reg * Dirichlet energy``, with early stopping.

    For every training batch, the primary loss is computed on the predictions. The
    Dirichlet energy of the batch targets is then computed over a kNN graph built
    from the model's latent representation of that batch. After each epoch, the mean
    validation loss is recorded, together with the latent representation of the full
    landscape ``x_data``.

    NOTE(review): get_latent_representationw returns a detached tensor, the kNN graph is
    built by sklearn, and the energy is evaluated on the targets ``y_batch`` rather than
    on model outputs. As written, the regularisation term therefore has no gradient path
    to the model parameters: it shifts ``total_loss`` by a value that does not affect the
    parameter updates. Confirm whether this is intended.

    Args:
        model (torch.nn.Module):   model to train.
        model_name (str):          architecture name, used to locate the latent layer.
        optimizer:                 torch optimizer over ``model``'s parameters.
        loss_fn:                   primary loss, called as loss_fn(prediction, target).
        train_loader, val_loader:  DataLoaders yielding (inputs, targets) batches.
        n_epochs (int):            maximum number of epochs.
        device (str):              device to train on.
        patience, min_delta:       EarlyStopping settings.
        lambda_reg (float):        weight of the Dirichlet-energy term.
        knn_degree (int):          neighbours per node of the per-batch kNN graph.
        x_data (torch.Tensor):     full landscape inputs; required.
    Returns:
        (model, train_epoch_losses, val_epoch_losses, epoch_latent_reps):
        - model: with the state loaded from the EarlyStopping checkpoint.
        - train_epoch_losses: per-epoch mean primary training loss.
        - val_epoch_losses: per-epoch mean validation loss.
        - epoch_latent_reps: per-epoch latent representations of ``x_data``.
    Raises:
        ValueError: if ``x_data`` is None (checked before any training is done).
    """
    if x_data is None:
        raise ValueError('No landscape x_data provided for latent representation calculation.')

    early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    model = model.to(device)
    x_data = x_data.to(device)  # constant across epochs, so move it only once

    train_epoch_losses = []
    val_epoch_losses = []
    de_epoch_losses = []
    epoch_latent_reps = []

    for epoch in range(n_epochs):
        # ---- training phase ----
        model.train()
        train_loss = 0.0
        de_loss = 0.0
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            y_pred = model(x_batch)
            loss = loss_fn(y_pred, y_batch)  # primary loss criterion
            # Second forward pass through a hook to get this batch's latent representation.
            latent_rep = get_latent_representationw(model, model_name, x_batch)
            de = dirichlet_from_representation(latent_rep, y_batch, degree=knn_degree, n_jobs=-1)

            total_loss = loss + lambda_reg * de
            total_loss.backward()
            optimizer.step()

            train_loss += loss.item()
            de_loss += de.item()

        epoch_loss = train_loss / len(train_loader)
        de_epoch_loss = de_loss / len(train_loader)
        train_epoch_losses.append(epoch_loss)
        de_epoch_losses.append(de_epoch_loss)

        # ---- validation phase ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                val_loss += loss_fn(model(inputs), targets).item()
        val_loss /= len(val_loader)
        val_epoch_losses.append(val_loss)

        # Latent representation of the whole landscape at the end of this epoch.
        epoch_latent_reps.append(get_latent_representation(model, model_name, x_data))

        # Train and validation losses are both per-batch means, so they are comparable.
        print(f"Epoch [{epoch+1}/{n_epochs}], Train Loss: {epoch_loss:.4f}, "
              f"DE: {de_epoch_loss:.4f}, Val Loss: {val_loss:.4f}")

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    # Restore the best checkpoint written by EarlyStopping.
    model.load_state_dict(torch.load(early_stopping.path, map_location=device))
    return model, train_epoch_losses, val_epoch_losses, epoch_latent_reps
