In [2]:
# fundnn.preprocessor
import torch
import pickle
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from sklearn.preprocessing import StandardScaler


def load_data(input_path, file_name):
    """
    Load pickled train/valid/test splits and convert them for FunDNN.

    The pickle is expected to hold a dict with the keys ``'train'``,
    ``'valid'`` and ``'test'``, each an object exposing ``sort_index()`` and
    ``.values`` (e.g. a pandas DataFrame).

    Args:
        input_path (str): Path prefix of the pickled data sets. It is
            concatenated directly with ``file_name``, so it has to end with a
            path separator.
        file_name (str): ``"feature_dict.pkl"`` for node features or
            ``"GECs_dict.pkl"`` for GECs.

    Returns:
        tuple: ``(train, valid, test)``, every split sorted by index.
            * ``"feature_dict.pkl"``: three float tensors standardized with a
              ``StandardScaler`` fitted on the training split only.
            * ``"GECs_dict.pkl"``: train and valid as float tensors; the test
              split is returned as the sorted, unconverted object (not a
              tensor).

    Raises:
        ValueError: If ``file_name`` is not one of the two supported names.
    """
    supported_names = ("feature_dict.pkl", "GECs_dict.pkl")
    if file_name not in supported_names:
        # Fail before touching the disk instead of silently returning None.
        raise ValueError(
            "Unsupported file_name {!r}; expected one of {}".format(file_name, supported_names))

    file_path = input_path + file_name
    # NOTE: pickle.load can execute arbitrary code - only load trusted files.
    with open(file_path, 'rb') as f:
        datasets_dict = pickle.load(f)

    if file_name == "feature_dict.pkl":
        # Fit on train only so validation/test statistics never leak in.
        scaler = StandardScaler()
        train_tensor = torch.FloatTensor(scaler.fit_transform(datasets_dict['train'].sort_index().values))
        valid_tensor = torch.FloatTensor(scaler.transform(datasets_dict['valid'].sort_index().values))
        test_tensor = torch.FloatTensor(scaler.transform(datasets_dict['test'].sort_index().values))

        print('Node feature dimension:\ntrain data:{}\nvalid data:{}\ntest data:{}\n'
              .format(train_tensor.shape, valid_tensor.shape, test_tensor.shape))
        return train_tensor, valid_tensor, test_tensor

    # file_name == "GECs_dict.pkl"
    train_tensor = torch.FloatTensor(datasets_dict['train'].sort_index().values)
    valid_tensor = torch.FloatTensor(datasets_dict['valid'].sort_index().values)
    # The test split keeps its index so it can be evaluated row by row later.
    test_array = datasets_dict['test'].sort_index()

    print('GECS data dimension:\ntrain data:{}\nvalid data:{}\ntest data:{}\n'
          .format(train_tensor.shape, valid_tensor.shape, test_array.shape))
    return train_tensor, valid_tensor, test_array


def get_dataloader(batch_size, node_train, gecs_train, node_valid, gecs_valid):
    """
    Build batched data loaders for the training and validation sets.

    Args:
        batch_size (int): Batch size hyperparameter(see Class TranscriptionNet_Hyperparameters).
        node_train (tensor): Node features of the training set.
        gecs_train (tensor): GECs data of the training set.
        node_valid (tensor): Node features of the validation set.
        gecs_valid (tensor): GECs data of the validation set.

    Returns:
        tuple: ``(train_dataloader, valid_dataloader)``. Each batch is a
        ``(node_features, gecs)`` pair. Training batches are reshuffled on
        every pass; validation batches keep the input order.
    """
    train_dataset = TensorDataset(node_train, gecs_train)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    valid_dataset = TensorDataset(node_valid, gecs_valid)
    # No shuffling for validation: the loss is averaged per batch, so a fixed
    # batch composition keeps the reported validation loss comparable across
    # epochs.
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    return train_dataloader, valid_dataloader






In [1]:
# dataprocess.py
import os
import pickle
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split


def load_rawdata(input_path):
    """
    Read every CSV file in ``input_path`` and return the four raw tables.

    Each file is parsed with its first column as the index and sorted by
    that index.

    NOTE(review): the tables are picked by their position in the
    ``os.listdir`` result (0: CRISPR, 1: node feature, 2: OE, 3: RNAi).
    ``os.listdir`` does not guarantee any ordering, so this mapping depends
    on the platform and on the file names - confirm against the data folder.

    Args:
        input_path (str): Directory that holds the raw CSV files.

    Returns:
        tuple: ``(node_feature, RNAi, OE, CRISPR)`` DataFrames.
    """
    frames = [
        pd.read_csv(os.path.join(input_path, entry), sep=",", index_col=0).sort_index()
        for entry in os.listdir(input_path)
    ]

    # Frames are already index-sorted above, so they are returned as read.
    CRISPR, node_feature, OE, RNAi = frames[0], frames[1], frames[2], frames[3]
    return node_feature, RNAi, OE, CRISPR


def train_test_val_split(GECs, ratio_train, ratio_test, ratio_valid):
    """
    Split ``GECs`` into train, test and validation subsets (seeded with 0).

    The data is first cut into a training part and a held-out part of size
    ``ratio_test + ratio_valid``; the held-out part is then cut again, giving
    validation a share of ``ratio_valid / (1 - ratio_train)``. Note that
    ``ratio_test`` itself only enters through the size of the held-out part.

    Args:
        GECs: Data to split (must support ``sort_index``).
        ratio_train (float): Fraction assigned to training.
        ratio_test (float): Fraction assigned to testing.
        ratio_valid (float): Fraction assigned to validation.

    Returns:
        tuple: ``(train, test, validation)``, each sorted by index.
    """
    holdout_share = ratio_test + ratio_valid
    train, holdout = train_test_split(GECs,
                                      train_size=ratio_train,
                                      test_size=holdout_share,
                                      random_state=0)

    # Share of the held-out part that becomes the validation set.
    valid_share = ratio_valid / (1 - ratio_train)
    test, validation = train_test_split(holdout, test_size=valid_share, random_state=0)

    return train.sort_index(), test.sort_index(), validation.sort_index()


def save_datasets(df_dict, file_path):
    """
    Pickle ``df_dict`` to ``file_path``, overwriting any existing file.

    Args:
        df_dict: Object to serialize (here: a dict of data set splits).
        file_path (str): Destination file.
    """
    with open(file_path, 'wb') as sink:
        pickle.Pickler(sink).dump(df_dict)


def datasets_split(GECs, feature, save_path):
    """
    Split GECs and node features into train/test/valid sets, scale the GECs
    and save both dictionaries as pickle files.

    Only GECs whose index also occurs in ``feature`` are kept. They are split
    70/20/10 (train/test/valid) and min-max scaled to ``(-1, 1)``. The scaler
    is fitted on the training split only and then applied unchanged to the
    test and validation splits, so their values may fall slightly outside
    ``(-1, 1)``. Node features are split by matching index and saved unscaled.

    Args:
        GECs (DataFrame): True GECs, one row per node.
        feature (DataFrame): Node features, one row per node.
        save_path (str): Path prefix for the output files. It is concatenated
            directly with the file names, so it has to end with a separator.

    Returns:
        MinMaxScaler: The scaler fitted on the training GECs; it can be used
        to map values back to the original scale.
    """
    GECs_filter = GECs[GECs.index.isin(feature.index)]
    GECs_train, GECs_test, GECs_valid = train_test_val_split(GECs_filter, 0.7, 0.2, 0.1)

    # Fit once on the training split. Re-fitting per split would scale every
    # split with its own min/max and leave the returned scaler fitted on
    # validation data.
    scaler = MinMaxScaler(feature_range=(-1, 1))
    scaler.fit(GECs_train.values)

    def _scale(split):
        """Apply the train-fitted scaler, keeping the split's index and columns."""
        return pd.DataFrame(scaler.transform(split.values),
                            index=split.index,
                            columns=GECs_filter.columns).sort_index()

    GECs_train = _scale(GECs_train)
    GECs_test = _scale(GECs_test)
    GECs_valid = _scale(GECs_valid)

    feature_train = feature[feature.index.isin(GECs_train.index)].sort_index()
    feature_test = feature[feature.index.isin(GECs_test.index)].sort_index()
    feature_valid = feature[feature.index.isin(GECs_valid.index)].sort_index()

    GECs_dict = {'train': GECs_train, 'valid': GECs_valid, 'test': GECs_test}
    feature_dict = {'train': feature_train, 'valid': feature_valid, 'test': feature_test}

    save_datasets(GECs_dict, save_path + 'GECs_dict.pkl')
    save_datasets(feature_dict, save_path + 'feature_dict.pkl')

    return scaler


In [3]:
# config_parser
import torch
import torch.nn as nn


class TranscriptionNet_Hyperparameters:
    """Container for the default TranscriptionNet configuration values."""

    def __init__(self):
        # ----- FunDNN model -----
        # Number of layers
        self.FunDNN_layers = 5
        # Number of training epochs
        self.FunDNN_epochs = 1000
        # Number of GECs in each batch
        self.FunDNN_batch_size = 32
        # Number of nodes in each layer
        self.FunDNN_hidden_nodes = 1024
        # Dropout layer ratio
        self.FunDNN_dropout_rate = 0.1
        # Activation function class used between layers
        self.FunDNN_activation_func = nn.LeakyReLU
        # Adadelta optimizer learning rate
        self.FunDNN_learning_rate = 0.00035
        # GECs and node feature locations per perturbation type.
        # NOTE(review): the OE and CRISPR entries hold the RNAi directory as
        # well - confirm whether this is intentional or a copy-paste leftover.
        self.FunDNN_RNAi_path = "example_data/datasets/RNAi/"
        self.FunDNN_OE_path = "example_data/datasets/RNAi/"
        self.FunDNN_CRISPR_path = "example_data/datasets/RNAi/"
        # Output location for FunDNN results
        self.FunDNN_save_path = "example_data/result/"

        # ----- GenSAN model -----
        # Number of attention heads of the row-wise self-attention block
        self.GenSAN_heads = 2
        # Number of transformer encoder units
        self.GenSAN_blocks = 3
        # Recycle times of the model
        self.GenSAN_recycles = 3
        # Number of training epochs
        self.GenSAN_epochs = 110
        # Number of warm-up epochs
        self.GenSAN_warmup_epochs = 5
        # Number of GECs in each batch
        self.GenSAN_batch_size = 32
        # The last dimension of GEC size (978)
        self.GenSAN_GECs_dimension = 978
        # Number of nodes in each feed forward neural network layer
        self.GenSAN_hidden_nodes = 1024
        # Dropout layer ratio
        self.GenSAN_dropout_rate = 0.05
        # Adam optimizer learning rate
        self.GenSAN_learning_rate = 0.0000045
        # Adam optimizer weight_decay
        self.GenSAN_weight_decay = 1e-5
        # Output location for GenSAN results
        self.GenSAN_save_path = "example_data/result/"

        # ----- Loss -----
        # Weight of the Pearson term in the combined MSE / PCC loss
        self.PMSELoss_beta = 0.1


class Device:
    """Name of the compute device; ``Device()`` evaluates to a plain string.

    ``__new__`` returns the cached string instead of an instance, so no
    ``Device`` object is ever created.

    Returns:
        str: Either "cuda" or "cpu".
    """

    # Resolved once, when the class body is executed.
    if torch.cuda.is_available():
        _device = "cuda"
    else:
        _device = "cpu"

    def __new__(cls) -> str:
        return cls._device


In [6]:
# pmseloss
import torch
import torch.nn as nn


def PMSELoss(true, predict, beta):
    """
    Weighted sum of the mean squared error and a Pearson correlation loss.

    The Pearson term is computed per row: both inputs are centered along
    dim 1 and the cosine similarity of the centered rows is taken, then
    ``1 - similarity`` is averaged.

    Args:
        true (tensor): True GECs data.
        predict (tensor): GECs predicted by the model.
        beta (float): Weight of the Pearson term; the MSE term is weighted
            with ``1 - beta`` (see Class TranscriptionNet_Hyperparameters).

    Returns:
        tuple: ``(loss, mse_loss, pearson_loss)`` as scalar tensors.
    """
    mse_loss = torch.square(true - predict).mean()

    # Cosine similarity of mean-centered rows equals their Pearson correlation.
    centered_pred = predict - predict.mean(dim=1, keepdim=True)
    centered_true = true - true.mean(dim=1, keepdim=True)
    row_similarity = nn.functional.cosine_similarity(centered_pred, centered_true, dim=1, eps=1e-6)
    pearson_loss = (1 - row_similarity).mean()

    loss = (1 - beta) * mse_loss + beta * pearson_loss
    return loss, mse_loss, pearson_loss


In [None]:
# fundnn.train_function
import time
import copy
import torch
import pandas as pd
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

# from config_parser import Device
# from PMSELoss import PMSELoss


def train_function(model, train_dataloader, optimizer, beta):
    """
    Run one training epoch of the FunDNN model.

    Args:
        model : FunDNN model to update.
        train_dataloader : Training set batch data.
        optimizer : Optimizer stepping the model parameters.
        beta (float): Weight hyperparameter of the combination of mse and pcc(see Class TranscriptionNet_Hyperparameters).

    Returns:
        tuple: Combined loss, MSE loss and Pearson loss, each averaged over
        the number of batches.
    """
    batch_count = len(train_dataloader)
    device = Device()
    # Running sums, in order: combined loss, MSE loss, Pearson loss.
    totals = [0, 0, 0]

    model.train()
    for features, targets in train_dataloader:
        features, targets = features.to(device), targets.to(device)

        optimizer.zero_grad()
        batch_losses = PMSELoss(targets, model(features), beta)
        # Only the combined loss (first element) is back-propagated.
        batch_losses[0].backward()
        optimizer.step()

        for slot, term in enumerate(batch_losses):
            totals[slot] += term.item()

    train_loss, mse_loss, pcc_loss = (total / batch_count for total in totals)
    return train_loss, mse_loss, pcc_loss


def valid_function(model, valid_dataloader, beta):
    """
    Evaluate the FunDNN model on the validation set without updating it.

    Args:
        model : The FunDNN model after training on the training set.
        valid_dataloader : Validation set batch data.
        beta (float): Weight hyperparameter of the combination of mse and pcc(see Class TranscriptionNet_Hyperparameters).

    Returns:
        tuple: Combined loss, MSE loss and Pearson loss, each averaged over
        the number of batches.
    """
    batch_count = len(valid_dataloader)
    device = Device()
    # Running sums, in order: combined loss, MSE loss, Pearson loss.
    totals = [0, 0, 0]

    model.eval()
    with torch.no_grad():
        for features, targets in valid_dataloader:
            features, targets = features.to(device), targets.to(device)
            batch_losses = PMSELoss(targets, model(features), beta)
            for slot, term in enumerate(batch_losses):
                totals[slot] += term.item()

    valid_loss, mse_loss, pcc_loss = (total / batch_count for total in totals)
    return valid_loss, mse_loss, pcc_loss


# def test_evaluate(best_model, feature_test, gecs_test):
#     """
#     FunDNN model test evaluation function.

#     Args:
#         best_model : The FunDNN model after all iterations of training.
#         feature_test (tensor): Test set of node features.
#         gecs_test (ndarray): Test set of GECs data.

#     Returns:
#         None
#     """

#     feature_test = feature_test.to(Device())
#     feature_test_predict = best_model(feature_test).cpu().detach().numpy()

#     # feature_test_predict_df = pd.DataFrame(feature_test_predict, index=net_test.index)
#     # feature_test_predict_df.to_csv(save_path + "feature_test_predict.csv", index=True)

#     d = []
#     pcc = []
#     for i in range(feature_test_predict.shape[0]):
#         pearson = np.corrcoef(feature_test_predict[i], gecs_test[i])[0, 1]
#         item_d, _ = stats.ks_2samp(feature_test_predict[i], gecs_test[i])
#         pcc.append(pearson)
#         d.append(item_d)
#     abs_pcc = abs(np.array(pcc))
#     abs_pcc_mean = abs_pcc.mean()
#     d_mean = np.array(d).mean()

#     mse = mean_squared_error(gecs_test, feature_test_predict)

#     print('=' * 30)
#     print('test evaluate result:\nAverage pcc: {}\nAverage mse: {}\nAverage D: {}'
#           .format(abs_pcc_mean, mse, d_mean))
#     print('=' * 30)

#     # return abs_pearson

# def test_evaluate(best_model, feature_test, gecs_test):
#     """
#     FunDNN model test evaluation function.

#     Args:
#         best_model : The FunDNN model after all iterations of training.
#         feature_test (tensor): Test set of node features.
#         gecs_test (ndarray): Test set of GECs data.

#     Returns:
#         None
#     """

#     try:
#         print("Starting test_evaluate function...")

#         # Move the feature test to the device
#         feature_test = feature_test.to(Device())
#         print(f"Feature test shape: {feature_test.shape}")

#         # Get the model predictions
#         feature_test_predict = best_model(feature_test).cpu().detach().numpy()
#         print(f"Feature test prediction shape: {feature_test_predict.shape}")

#         # Check if there are NaN or infinite values in predictions
#         if np.any(np.isnan(feature_test_predict)) or np.any(np.isinf(feature_test_predict)):
#             print("Warning: NaN or Infinite values detected in predictions!")

#         # Check if there are NaN or infinite values in gecs_test
#         if np.any(np.isnan(gecs_test)) or np.any(np.isinf(gecs_test)):
#             print("Warning: NaN or Infinite values detected in gecs_test!")

#         # Ensure that feature_test_predict and gecs_test have the same number of samples
#         if feature_test_predict.shape[0] != gecs_test.shape[0]:
#             print(f"Mismatch in number of samples: feature_test_predict has {feature_test_predict.shape[0]} samples, but gecs_test has {gecs_test.shape[0]} samples.")
#             return

#         d = []
#         pcc = []
#         for i in range(feature_test_predict.shape[0]):
#             # Print first 10 elements of feature_test_predict and gecs_test for debugging
#             if i < 10:
#                 print(f"Sample {i} - feature_test_predict: {feature_test_predict[i][:10]} ...")
#                 print(f"Sample {i} - gecs_test: {gecs_test[i][:10]} ...")

#             # Calculate Pearson correlation and KS test
#             pearson = np.corrcoef(feature_test_predict[i], gecs_test[i])[0, 1]
#             item_d, _ = stats.ks_2samp(feature_test_predict[i], gecs_test[i])

#             pcc.append(pearson)
#             d.append(item_d)

#         abs_pcc = abs(np.array(pcc))
#         abs_pcc_mean = abs_pcc.mean()
#         d_mean = np.array(d).mean()

#         # Calculate Mean Squared Error (MSE)
#         mse = mean_squared_error(gecs_test, feature_test_predict)

#         print('=' * 30)
#         print('Test evaluate result:')
#         print(f"Average PCC: {abs_pcc_mean}")
#         print(f"Average MSE: {mse}")
#         print(f"Average D: {d_mean}")
#         print('=' * 30)

#     except Exception as e:
#         print(f"An error occurred in test_evaluate: {e}")
#         raise

def test_evaluateFunDNN(best_model, feature_test, gecs_test):
    """
    Evaluate the FunDNN model on the test set and print summary metrics.

    For every test row the Pearson correlation and the two-sample
    Kolmogorov-Smirnov statistic between prediction and truth are computed.
    The printed report contains the mean absolute Pearson correlation, the
    mean squared error over all values and the mean KS statistic.

    Args:
        best_model : The FunDNN model after all iterations of training.
        feature_test (tensor): Test set of node features.
        gecs_test (DataFrame or ndarray): Test set of GECs data, one row per
            sample in the same order as ``feature_test``.

    Returns:
        None. If the number of predicted rows differs from the number of
        rows in ``gecs_test`` a message is printed and nothing is evaluated.

    Raises:
        Exception: Any error raised during evaluation is printed and then
        re-raised unchanged.
    """

    try:
        feature_test = feature_test.to(Device())

        # Inference only: disable dropout and autograd bookkeeping, then put
        # the model back into the mode the caller left it in.
        is_module = isinstance(best_model, torch.nn.Module)
        was_training = is_module and best_model.training
        if is_module:
            best_model.eval()
        try:
            with torch.no_grad():
                feature_test_predict = best_model(feature_test).cpu().detach().numpy()
        finally:
            if was_training:
                best_model.train()

        # Positional NumPy view so DataFrame and ndarray inputs behave alike.
        gecs_values = np.asarray(gecs_test)

        # check if there are nan or infinite values in predictions
        if np.any(np.isnan(feature_test_predict)) or np.any(np.isinf(feature_test_predict)):
            print("warning: nan/infinite values detected in predictions")

        # check if there are nan or infinite values in gecs_test
        if np.any(np.isnan(gecs_values)) or np.any(np.isinf(gecs_values)):
            print("warning: nan or infinite values detected in gecs_test")

        # ensure feature_test_predict and gecs_test have the same number of samples
        if feature_test_predict.shape[0] != gecs_values.shape[0]:
            print(f"mismatch in number of samples: feature_test_predict has {feature_test_predict.shape[0]} samples, but gecs_test has {gecs_values.shape[0]} samples.")
            return

        d = []
        pcc = []
        for predicted_row, true_row in zip(feature_test_predict, gecs_values):
            # calculate pearson correlation and ks test for one sample
            pearson = np.corrcoef(predicted_row, true_row)[0, 1]
            item_d, _ = stats.ks_2samp(predicted_row, true_row)

            pcc.append(pearson)
            d.append(item_d)

        abs_pcc = abs(np.array(pcc))
        abs_pcc_mean = abs_pcc.mean()
        d_mean = np.array(d).mean()

        # calculate MSE
        mse = mean_squared_error(gecs_values, feature_test_predict)

        print('=' * 30)
        print('Test evaluate result:')
        print(f"Average PCC: {abs_pcc_mean}")
        print(f"Average MSE: {mse}")
        print(f"Average D: {d_mean}")
        print('=' * 30)

    except Exception as e:
        print(f"An error occurred in test_evaluate: {e}")
        raise


def plot_loss_figure(epochs, train_loss, valid_loss):
    """
    Plot training and validation loss per epoch and show the figure.

    The matplotlib rcParams set here are changed globally and therefore also
    affect figures drawn afterwards.

    Args:
        epochs (int): Number of epochs for FunDNN model training
        train_loss (list): List of training set loss values.
        valid_loss (list): List of valid set loss values.
    """
    plt.rcParams.update({
        'font.family': 'Arial',
        'xtick.labelsize': 22,
        'ytick.labelsize': 22,
        'axes.titlesize': 28,
        'axes.labelsize': 28,
        'legend.fontsize': 22,
        'axes.unicode_minus': False,
    })
    plt.figure(figsize=(7, 7), dpi=144)
    plt.title('Loss figure')

    epoch_axis = range(epochs)
    curves = (
        (train_loss, 'red', '--', 'train loss'),
        (valid_loss, 'dodgerblue', '-', 'valid loss'),
    )
    for values, color, linestyle, label in curves:
        plt.plot(epoch_axis, values, color=color, linestyle=linestyle, label=label, linewidth=2)

    plt.legend(loc='upper right', frameon=False)
    plt.xlabel('Epochs')
    plt.ylabel('Loss values')
    plt.show()


def trainFunDNN(epochs, model, train_dataloader, valid_dataloader, optimizer, beta):
    """
    Train the FunDNN model and keep the snapshot with the lowest validation loss.

    Args:
        epochs (int): Number of epochs for FunDNN model training
        model (nn.Module): FunDNN model
        train_dataloader : Training set batch data.
        valid_dataloader : Validation set batch data.
        optimizer : Optimizer stepping the model parameters.
        beta (float): Weight hyperparameter of the combination of mse and pcc(see Class TranscriptionNet_Hyperparameters).
    Returns:
        best_model (nn.Module): Deep copy of the model taken at the epoch with
        the lowest validation loss; ``None`` if no epoch improved on infinity
        (e.g. ``epochs`` is 0).
    """
    best_model = None
    lowest_valid_loss = float("inf")

    # Per-epoch histories (kept for optional plotting).
    train_history = []
    valid_history = []

    report = ('end of epoch:{:3d} | time:{:5.2f}s | train loss:{:5.5f} | valid loss:{:5.5f} | train MseLoss:{:5.5f} | '
              'train PccLoss:{:5.5f} | valid MseLoss:{:5.5f} | valid PccLoss:{:5.5f}')

    for epoch in range(epochs):
        started_at = time.time()

        tra_loss, tra_mse_loss, tra_cor_loss = train_function(
            model=model, train_dataloader=train_dataloader, optimizer=optimizer, beta=beta)
        val_loss, val_mse_loss, val_cor_loss = valid_function(
            model=model, valid_dataloader=valid_dataloader, beta=beta)

        train_history.append(tra_loss)
        valid_history.append(val_loss)

        # Snapshot the weights whenever validation improves.
        if val_loss < lowest_valid_loss:
            lowest_valid_loss = val_loss
            best_model = copy.deepcopy(model)

        elapsed = time.time() - started_at
        print(report.format(epoch, elapsed, tra_loss, val_loss, tra_mse_loss, tra_cor_loss,
                            val_mse_loss, val_cor_loss))

    return best_model


def feature_predict(best_model, node_feature, save_path, name):
    """
    Predict the pre-GECs for all nodes and write them to a CSV file.

    Args:
        best_model (nn.Module): The trained FunDNN model with the lowest validation loss.
        node_feature (DataFrame): All network nodes embedded features, one
            row per node.
        save_path (str): Path prefix for the output file. It is concatenated
            directly with the file name, so it has to end with a separator.
        name (str): GECs type(RNAi, OE or CRISPR); the output file is named
            ``<name>_pre_GECs.csv``.

    Returns:
        pre_GECs (dataframe): Predicted GECs, indexed like ``node_feature``
        and sorted by index.
    """
    node_feature_tensor = torch.FloatTensor(node_feature.values)
    node_feature_tensor = node_feature_tensor.to(Device())

    # Inference only: disable dropout and skip building the autograd graph for
    # this full-batch forward pass, then restore the caller's model mode.
    is_module = isinstance(best_model, torch.nn.Module)
    was_training = is_module and best_model.training
    if is_module:
        best_model.eval()
    try:
        with torch.no_grad():
            node_feature_predict = best_model(node_feature_tensor).cpu().detach().numpy()
    finally:
        if was_training:
            best_model.train()

    node_feature_predict = pd.DataFrame(node_feature_predict, index=node_feature.index).sort_index()
    node_feature_predict.to_csv(save_path + name + "_pre_GECs.csv", index=True, sep=",")

    print('\npredict finish:\npre_GECs size:{}\n'.format(node_feature_predict.shape))
    return node_feature_predict


In [57]:
import torch
import torch.nn as nn
# from config_parser import Device
# from FunDNN.train_function import train, test_evaluate, feature_predict


class FunDNN(nn.Module):
    def __init__(self,
                 num_layers: int,
                 hidden_nodes: int,
                 activate_func: nn.Module,
                 dropout_rate: float,
                 in_features: int = 512,
                 out_features: int = 978):
        """
        The FunDNN model: a fully connected network mapping node features to GECs.

        The network is built as: an input block, ``num_layers - 3`` hidden
        blocks, one hidden block with a Tanh activation and a final linear
        projection. Every block is Linear -> activation -> Dropout. For
        ``num_layers`` below 3 no extra hidden blocks are added, so the
        network always has at least three linear layers.

        Args:
            num_layers (int): Number of linear layers for FunDNN.
            hidden_nodes (int): Number of nodes in each layers for FunDNN.
            activate_func (nn.Module): Activation class (or other zero-argument
                factory) instantiated between the layers, e.g. ``nn.LeakyReLU``.
            dropout_rate (float): Dropout layer ratio of FunDNN.
            in_features (int, optional): Size of one node feature vector.
                Defaults to 512.
            out_features (int, optional): Size of one predicted GEC.
                Defaults to 978.
        """

        super(FunDNN, self).__init__()
        layers = [nn.Linear(in_features, hidden_nodes), activate_func(), nn.Dropout(dropout_rate)]
        for _ in range(num_layers - 3):
            layers.extend([nn.Linear(hidden_nodes, hidden_nodes), activate_func(), nn.Dropout(dropout_rate)])
        # The last hidden block always uses Tanh, independent of activate_func.
        layers.extend([nn.Linear(hidden_nodes, hidden_nodes), nn.Tanh(), nn.Dropout(dropout_rate)])
        layers.append(nn.Linear(hidden_nodes, out_features))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """
        Forward pass logic.

        Args:
            x (Tensor): 2D tensor of node features.Each row is a node, each column is a feature.

        Returns:
            Tensor: 2D tensor of predicted GECs, one row per node with
            ``out_features`` columns.
        """
        return self.model(x)


# def run_model(num_layers, hidden_nodes, activate_func, dropout_rate,
#               learning_rate, epochs, train_dataloader, valid_dataloader, beta,
#               feature_test, gecs_test, save_path, node_feature, name):
#     """
#     FunDNN model training process.

#     Args:
#         num_layers (int): Number of layers for FunDNN.
#         hidden_nodes (int): Number of nodes in each layers for FunDNN.
#         activate_func (nn.Module): Activation function between each layer of FunDNN.
#         dropout_rate (float): Dropout layer ratio of FunDNN
#         learning_rate (float): Adadelta optimizer learning rate
#         epochs (int): Number of epochs for FunDNN model training
#         train_dataloader: Training set batch data.
#         valid_dataloader: Validation set batch data.
#         beta (float): Weight hyperparameter of the combination of mse and pcc(see Class TranscriptionNet_Hyperparameters).
#         feature_test (tensor): Test set of node features.
#         gecs_test (ndarray): Test set of GECs data.
#         save_path: Path to save the trained model.
#         node_feature: All network nodes embedded features.
#         name: GECs type(RNAi, OE or CRISPR)

#     Returns:
#         best_model (nn.Module): The trained FunDNN model with the lowest validation loss.
#     """

#     model = FunDNN(num_layers=num_layers,
#                    hidden_nodes=hidden_nodes,
#                    activate_func=activate_func,
#                    dropout_rate=dropout_rate)
#     model = model.to(Device())

#     optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate)

#     best_model = train(epochs=epochs,
#                        model=model,
#                        train_dataloader=train_dataloader,
#                        valid_dataloader=valid_dataloader,
#                        optimizer=optimizer,
#                        beta=beta)

#     test_evaluate(best_model, feature_test, gecs_test)

#     # torch.save(best_model, save_path + "FunDNN best model.pt")

#     pre_GECs = feature_predict(best_model=best_model,
#                                node_feature=node_feature,
#                                save_path=save_path,
#                                name=name)
#     return pre_GECs


In [66]:
# fundnn.model
import torch
import torch.nn as nn
# from config_parser import Device
# from FunDNN.train_function import train, test_evaluate, feature_predict


class FunDNN(nn.Module):
    def __init__(self,
                 num_layers: int,
                 hidden_nodes: int,
                 activate_func: nn.Module,
                 dropout_rate: float,
                 in_features: int = 512,
                 out_features: int = 978):
        """
        The FunDNN model: a fully connected network mapping node features to GECs.

        The network is built as: an input block, ``num_layers - 3`` hidden
        blocks, one hidden block with a Tanh activation and a final linear
        projection. Every block is Linear -> activation -> Dropout. For
        ``num_layers`` below 3 no extra hidden blocks are added, so the
        network always has at least three linear layers.

        Args:
            num_layers (int): Number of linear layers for FunDNN.
            hidden_nodes (int): Number of nodes in each layers for FunDNN.
            activate_func (nn.Module): Activation class (or other zero-argument
                factory) instantiated between the layers, e.g. ``nn.LeakyReLU``.
            dropout_rate (float): Dropout layer ratio of FunDNN.
            in_features (int, optional): Size of one node feature vector.
                Defaults to 512.
            out_features (int, optional): Size of one predicted GEC.
                Defaults to 978.
        """

        super(FunDNN, self).__init__()
        layers = [nn.Linear(in_features, hidden_nodes), activate_func(), nn.Dropout(dropout_rate)]
        for _ in range(num_layers - 3):
            layers.extend([nn.Linear(hidden_nodes, hidden_nodes), activate_func(), nn.Dropout(dropout_rate)])
        # The last hidden block always uses Tanh, independent of activate_func.
        layers.extend([nn.Linear(hidden_nodes, hidden_nodes), nn.Tanh(), nn.Dropout(dropout_rate)])
        layers.append(nn.Linear(hidden_nodes, out_features))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """
        Forward pass logic.

        Args:
            x (Tensor): 2D tensor of node features.Each row is a node, each column is a feature.

        Returns:
            Tensor: 2D tensor of predicted GECs, one row per node with
            ``out_features`` columns.
        """
        return self.model(x)


def run_model(num_layers, hidden_nodes, activate_func, dropout_rate,
              learning_rate, epochs, train_dataloader, valid_dataloader, beta,
              feature_test, gecs_test, save_path, node_feature, name, warmup_epoch=None):
    """
    FunDNN model training process: build, train, evaluate and predict.

    Args:
        num_layers (int): Number of layers for FunDNN.
        hidden_nodes (int): Number of nodes in each layers for FunDNN.
        activate_func (nn.Module): Activation function between each layer of FunDNN.
        dropout_rate (float): Dropout layer ratio of FunDNN
        learning_rate (float): Adadelta optimizer learning rate
        epochs (int): Number of epochs for FunDNN model training
        train_dataloader: Training set batch data.
        valid_dataloader: Validation set batch data.
        beta (float): Weight hyperparameter of the combination of mse and pcc(see Class TranscriptionNet_Hyperparameters).
        feature_test (tensor): Test set of node features.
        gecs_test: Test set of GECs data, passed on to the test evaluation.
        save_path: Path prefix under which the predicted pre-GECs CSV is written.
        node_feature: All network nodes embedded features.
        name: GECs type(RNAi, OE or CRISPR)
        warmup_epoch (optional): Currently unused; accepted so existing calls
            that pass it keep working.

    Returns:
        pre_GECs (dataframe): GECs predicted for every row of ``node_feature``
        by the model with the lowest validation loss.
    """

    model = FunDNN(num_layers=num_layers,
                   hidden_nodes=hidden_nodes,
                   activate_func=activate_func,
                   dropout_rate=dropout_rate)
    model = model.to(Device())

    optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate)

    best_model = trainFunDNN(epochs=epochs,
                             model=model,
                             train_dataloader=train_dataloader,
                             valid_dataloader=valid_dataloader,
                             optimizer=optimizer,
                             beta=beta)

    # Prints the test metrics; returns nothing.
    test_evaluateFunDNN(best_model, feature_test, gecs_test)

    pre_GECs = feature_predict(best_model=best_model,
                               node_feature=node_feature,
                               save_path=save_path,
                               name=name)
    return pre_GECs


In [11]:
# gensan.preprocessor
import pickle
import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler


def load_GECs(input_path, file_name):
    """
    Load the pickled GECs splits and return them sorted by index.

    Args:
        input_path: Path prefix of the input data sets; concatenated directly
            with ``file_name``.
        file_name (str): Name of the pickle file holding a dict with the keys
            "train", "valid" and "test".

    Returns:
        tuple: ``(train, valid, test)``. Each row is a node, each column is a feature.
    """
    # NOTE: pickle.load can execute arbitrary code - only load trusted files.
    with open(input_path + file_name, 'rb') as handle:
        splits = pickle.load(handle)

    train, valid, test = (splits[key].sort_index() for key in ("train", "valid", "test"))
    return train, valid, test


def GECs_combine(true_GECs, predict_GECs):
    """
    Fill the gaps in the true GECs with predicted GECs.

    Rows of ``predict_GECs`` whose index is not present in ``true_GECs`` are
    relabelled with the columns of ``true_GECs`` and appended to it; rows that
    have a true measurement are taken from ``true_GECs`` only.

    Args:
        true_GECs (DataFrame): True GECs dataSets.
        predict_GECs (DataFrame): Predicted GECs dataSets.

    Returns:
        Dataframe: Combined dataSets, sorted by index.
    """
    has_truth = predict_GECs.index.isin(true_GECs.index)
    predicted_only = predict_GECs[~has_truth].set_axis(true_GECs.columns, axis=1)
    return pd.concat([true_GECs, predicted_only], axis=0).sort_index()


def get_datasets(GECs, pre_GECs, combine1, combine2):
    """
    Gets train, valid and test sets.

    For every label in the (sorted) index of ``GECs`` the matching rows of
    ``pre_GECs``, ``combine1`` and ``combine2`` are stacked in that order.

    Args:
        GECs (DataFrame): True GECs dataSets(train, valid or test); only its
            index is used.
        pre_GECs (DataFrame): Predicted GECs dataSets.
        combine1 (DataFrame): Combined GECs dataSets.
        combine2 (DataFrame): Combined GECs dataSets.

    Returns:
        ndarray: Array of shape ``(len(GECs), 3, n_columns)`` when every label
        occurs exactly once in each source frame.
    """

    index = GECs.index.sort_values()
    sources = (pre_GECs, combine1, combine2)

    # Fast path: with unique, complete labels one label-based gather per source
    # yields the same rows as filtering per label, without a full index scan
    # for every single label.
    if len(index) and all(src.index.is_unique and index.isin(src.index).all() for src in sources):
        return np.stack([src.loc[index].to_numpy() for src in sources], axis=1)

    # Fallback for duplicated or missing labels (and empty input): gather the
    # rows label by label, exactly as before.
    candidates = [src[src.index.isin(index)].sort_index() for src in sources]
    datasets = []
    for label in index:
        per_source = [frame[frame.index == label].values for frame in candidates]
        datasets.append(np.concatenate(per_source, axis=0))

    return np.array(datasets)


def datasets_scaled(trainSets, validSets, testSets):
    """
    Standardize the training, validation and test sets with statistics of the training set.

    Each 3-D input is flattened over its last two axes, scaled with a ``StandardScaler`` that is
    fitted on the training set only, and reshaped back to its original shape.

    Args:
        trainSets: 3-D array of training samples.
        validSets: 3-D array of validation samples.
        testSets: 3-D array of test samples.

    Returns:
        Tensor: Scaled train, valid and test sets as float tensors with the input shapes.
    """

    def _flatten(cube):
        # Collapse the two trailing axes so that each sample becomes one feature row.
        n_samples, n_rows, n_genes = cube.shape
        return cube.reshape((n_samples, n_rows * n_genes))

    def _restore(flat, like):
        # Undo the flattening and hand back a float tensor.
        return torch.FloatTensor(flat.reshape(tuple(like.shape)))

    train_flat = _flatten(trainSets)
    valid_flat = _flatten(validSets)
    test_flat = _flatten(testSets)

    scaler = StandardScaler()
    train_flat_scaled = scaler.fit_transform(train_flat)
    valid_flat_scaled = scaler.transform(valid_flat)
    test_flat_scaled = scaler.transform(test_flat)

    return (_restore(train_flat_scaled, trainSets),
            _restore(valid_flat_scaled, validSets),
            _restore(test_flat_scaled, testSets))


def get_pre_GECs(pre_GECs, combine1, combine2):
    """
    Build the GenSAN prediction input for every label of ``pre_GECs``.

    Args:
        pre_GECs (DataFrame): Predicted GECs dataSets; its sorted index defines the samples.
        combine1 (DataFrame): Combined GECs dataSets.
        combine2 (DataFrame): Combined GECs dataSets.

    Returns:
        Tensor: Float tensor with one block per label, each block stacking the matching rows of
        ``pre_GECs``, ``combine1`` and ``combine2`` in that order.
    """
    labels = pre_GECs.index.sort_values()
    ordered_sources = (pre_GECs.sort_index(), combine1.sort_index(), combine2.sort_index())

    def _block(label):
        # Rows of each source that carry this label, stacked in source order.
        return np.concatenate(tuple(frame[frame.index == label].values for frame in ordered_sources), axis=0)

    blocks = [_block(labels[position]) for position in range(len(labels))]
    return torch.FloatTensor(np.array(blocks))


def GenSAN_preprocessor(true_GECs1, true_GECs2, predict_GECs1, predict_GECs2, pre_GECS, input_path, file_name):
    """
    Prepare the scaled GenSAN train, valid and test inputs.

    Args:
        true_GECs1 (DataFrame): True GECs of the first auxiliary data set.
        true_GECs2 (DataFrame): True GECs of the second auxiliary data set.
        predict_GECs1 (DataFrame): Predicted GECs used to fill labels missing from ``true_GECs1``.
        predict_GECs2 (DataFrame): Predicted GECs used to fill labels missing from ``true_GECs2``.
        pre_GECS (DataFrame): Predicted GECs of the data set being modelled.
        input_path: Directory of the GECs split file, passed to ``load_GECs``.
        file_name (str): Name of the GECs split file, passed to ``load_GECs``.

    Returns:
        tuple: Scaled train, valid and test tensors followed by the two combined GECs frames.
    """
    combine_GECs1, combine_GECs2 = (
        GECs_combine(measured, predicted)
        for measured, predicted in ((true_GECs1, predict_GECs1), (true_GECs2, predict_GECs2))
    )

    # One raw 3-D array per split, built in train / valid / test order.
    train, valid, test = (
        get_datasets(split_GECs, pre_GECS, combine_GECs1, combine_GECs2)
        for split_GECs in load_GECs(input_path, file_name)
    )

    train_data, valid_data, test_data = datasets_scaled(train, valid, test)

    summary = 'pre-GECS dimension:\ntrain data:{}\nvalid data:{}\ntest data:{}\n'
    print(summary.format(train_data.shape, valid_data.shape, test_data.shape))

    return train_data, valid_data, test_data, combine_GECs1, combine_GECs2


In [13]:
# gensan utils
import math
import copy
import torch
import torch.nn as nn
from math import cos, pi
import torch.nn.functional as f


class PositionwiseFeedForward(nn.Module):
    """Two-layer feed forward network: Linear -> LeakyReLU -> Dropout -> Linear."""

    def __init__(self, GECs_dimension: int, hidden_nodes: int, dropout_rate: float):
        """
        Create the feed forward layers.

        Args:
            GECs_dimension (int): Size of the last input dimension (978 for GECs); also the output size.
            hidden_nodes (int): Width of the hidden layer.
            dropout_rate (float): Dropout probability applied to the hidden activations.
        """
        super().__init__()
        self.w1 = nn.Linear(GECs_dimension, hidden_nodes)
        self.w2 = nn.Linear(hidden_nodes, GECs_dimension)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply the network along the last dimension of ``x``."""
        hidden = f.leaky_relu(self.w1(x))
        hidden = self.dropout(hidden)
        return self.w2(hidden)


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale and shift."""

    def __init__(self, dimension=978, eps=1e-6):
        """
        Create the normalization parameters.

        Args:
            dimension (int, optional): Size of the normalized (last) dimension. Defaults to 978.
            eps (float, optional): Value added to the standard deviation for numerical stability.
                Defaults to 1e-6.
        """
        super().__init__()
        self.a2 = nn.Parameter(torch.ones(dimension))   # scale
        self.b2 = nn.Parameter(torch.zeros(dimension))  # shift
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` (the output of the previous layer) along its last dimension."""
        centered = x - x.mean(-1, keepdim=True)
        # eps is added to the standard deviation itself, not to the variance.
        spread = x.std(-1, keepdim=True) + self.eps
        return self.a2 * centered / spread + self.b2


class SublayerConnection(nn.Module):
    """Residual connection around a sub-layer, with normalization applied before the sub-layer."""

    def __init__(self, dropout_rate):
        """
        Create the normalization and dropout layers.

        Args:
            dropout_rate (float): Dropout probability applied to the sub-layer output.
        """
        super().__init__()

        # LayerNorm is built with its default dimension (978).
        self.norm = LayerNorm()
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x, sublayer):
        """
        Return ``x`` plus the dropped-out output of ``sublayer`` applied to the normalized ``x``.

        Args:
            x (Tensor): The input of the previous layer or sub-layer.
            sublayer (callable): Function or module applied to the normalized input.
        """
        normalized = self.norm(x)
        branch = self.dropout(sublayer(normalized))
        return x + branch


def attention(query, key, value, dropout=None):
    """
    Scaled dot-product attention.

    Args:
        query: A tensor of queries.
        key: A tensor of keys.
        value: A tensor of values.
        dropout (nn.Dropout, optional): Dropout applied to the attention weights. Default is None.

    Returns:
        Tensor: The values weighted by softmax(query @ key^T / sqrt(d_k)), where d_k is the size
        of the last query dimension.
    """
    scale = math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale

    weights = f.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)

    return torch.matmul(weights, value)


class ColumnAttention(nn.Module):
    """Attention computed over the columns of the input, i.e. on its transposed last two dimensions."""

    def __init__(self, dropout_rate):
        """
        Create the column-wise attention block.

        Args:
            dropout_rate (float): Dropout probability applied to the attention weights.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, query, key, value):
        """
        Run attention on the transposed inputs and transpose the result back.

        Args:
            query (Tensor): A tensor of queries.
            key (Tensor): A tensor of keys.
            value (Tensor): A tensor of values.
        """
        # Swap the last two dimensions so that attention runs across columns.
        query_t, key_t, value_t = (torch.transpose(tensor, -1, -2) for tensor in (query, key, value))
        attended = attention(query_t, key_t, value_t, dropout=self.dropout)
        return torch.transpose(attended, -1, -2)


def clones(module, num_clone):
    """Return an ``nn.ModuleList`` holding ``num_clone`` independent deep copies of ``module``."""
    copies = nn.ModuleList()
    for _ in range(num_clone):
        copies.append(copy.deepcopy(module))
    return copies


class RowAttention(nn.Module):
    """Multi-head attention over the rows of the input."""

    def __init__(self, heads, GECs_dimension, dropout_rate=0.1):
        """
        Create the multi-head attention block.

        Args:
            heads (int): The number of heads; must divide ``GECs_dimension``.
            GECs_dimension (int): The last dimension of GEC size(978).
            dropout_rate (float): Dropout probability applied to the attention weights. Default is 0.1.
        """
        super().__init__()

        assert GECs_dimension % heads == 0

        self.d_k = GECs_dimension // heads
        self.heads = heads
        # Three input projections (query, key, value) followed by the output projection.
        self.linears = clones(nn.Linear(GECs_dimension, GECs_dimension), 4)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, query, key, value):
        """Project the inputs, attend per head and merge the heads through the last projection."""
        batch_size = query.size(0)

        def _split_heads(projection, tensor):
            # (batch, rows, heads * d_k) -> (batch, heads, rows, d_k)
            return projection(tensor).view(batch_size, -1, self.heads, self.d_k).transpose(1, 2)

        q_heads = _split_heads(self.linears[0], query)
        k_heads = _split_heads(self.linears[1], key)
        v_heads = _split_heads(self.linears[2], value)

        merged = attention(q_heads, k_heads, v_heads, dropout=self.dropout)
        merged = merged.transpose(1, 2).contiguous().view(batch_size, -1, self.heads * self.d_k)
        return self.linears[-1](merged)


class Generator(nn.Module):
    """Output head: Linear -> tanh -> Dropout -> Linear, keeping only the first row of each sample."""

    def __init__(self, GECs_dimension: int, hidden_nodes: int, dropout_rate: float):
        """
        Create the generator layers.

        Args:
            GECs_dimension (int): The last dimension of GEC size(978).
            hidden_nodes (int): Number of nodes in the hidden layer.
            dropout_rate (float): Dropout probability applied to the hidden activations.
        """
        super().__init__()
        self.w1 = nn.Linear(GECs_dimension, hidden_nodes)
        self.w2 = nn.Linear(hidden_nodes, GECs_dimension)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Transform a 3-D input and return index 0 along its second dimension."""
        hidden = self.dropout(torch.tanh(self.w1(x)))
        projected = self.w2(hidden)
        return projected[:, 0, :]


def adjust_learning_rate(optimizer, warmup_epoch, current_epoch, max_epoch, lr_min=0, lr_max=0.1, warmup=True):
    """
    Set the learning rate of every parameter group following linear warm-up and cosine decay.

    During warm-up (``current_epoch < warmup_epoch``) the rate grows linearly from 0 to ``lr_max``
    (``lr_min`` is not used in this phase). Afterwards it follows a half cosine from ``lr_max``
    down to ``lr_min``, reached at ``max_epoch``; for later epochs it stays at ``lr_min``.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer whose ``param_groups`` are updated in place.
        warmup_epoch (int): Number of warm-up epochs.
        current_epoch (int): Current epoch.
        max_epoch (int): Number of epochs for model training.
        lr_min (float, optional): Minimum learning rate. The default is 0
        lr_max (float, optional): Max learning rate. The default is 0.1
        warmup (bool, optional): When False the warm-up phase is skipped. The default is True

    Returns:
        float: The learning rate that was written to the optimizer.
    """
    warmup_epoch = warmup_epoch if warmup else 0
    if current_epoch < warmup_epoch:
        lr = lr_max * current_epoch / warmup_epoch
    else:
        decay_span = max_epoch - warmup_epoch
        if decay_span <= 0:
            # No epochs are left for the cosine phase (this used to divide by zero when
            # max_epoch == warmup_epoch); stay at the start of the cosine curve.
            lr = lr_max
        else:
            # Cap the progress so that the rate does not climb again past max_epoch.
            elapsed = min(current_epoch - warmup_epoch, decay_span)
            lr = lr_min + (lr_max - lr_min) * (1 + cos(pi * elapsed / decay_span)) / 2
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


In [80]:
# gensan_train_function
import time
import copy
import torch
import numpy as np
import pandas as pd
from scipy import stats
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

# from config_parser import Device
# from PMSELoss import PMSELoss
# from GenSAN.utils import adjust_learning_rate
# from GenSAN.model import GenSAN_model


def train_function(model, train_dataloader, optimizer, beta):
    """
    Run one training epoch of the GenSAN model.

    Args:
        model : GenSAN model to optimize.
        train_dataloader : Training set batch data, yielding (input, target) pairs.
        optimizer : Optimizer that updates the model parameters.
        beta (float): Weight hyperparameter of the combination of mse and pcc, passed to ``PMSELoss``.

    Returns:
        Float: Combination loss, mse loss and pearson loss, each averaged over the batches.
    """
    num_batches = len(train_dataloader)
    running = {"loss": 0, "mse": 0, "pcc": 0}

    model.train()
    for inputs, targets in train_dataloader:
        inputs = inputs.to(Device())
        targets = targets.to(Device())

        optimizer.zero_grad()
        loss, mse, pcc = PMSELoss(targets, model(inputs), beta)
        loss.backward()
        # Limit the total gradient norm to 0.5 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        running["loss"] += loss.item()
        running["mse"] += mse.item()
        running["pcc"] += pcc.item()

    return (running["loss"] / num_batches,
            running["mse"] / num_batches,
            running["pcc"] / num_batches)


def valid_function(model, valid_dataloader, beta):
    """
    Evaluate the GenSAN model on the validation set without updating it.

    Args:
        model : GenSAN model to evaluate; it is switched to evaluation mode.
        valid_dataloader : Validation set batch data, yielding (input, target) pairs.
        beta (float): Weight hyperparameter of the combination of mse and pcc, passed to ``PMSELoss``.

    Returns:
        Float: Combination loss, mse loss and pearson loss, each averaged over the batches.
    """
    num_batches = len(valid_dataloader)
    running = {"loss": 0, "mse": 0, "pcc": 0}

    model.eval()
    with torch.no_grad():
        for inputs, targets in valid_dataloader:
            inputs = inputs.to(Device())
            targets = targets.to(Device())

            loss, mse, pcc = PMSELoss(targets, model(inputs), beta)
            running["loss"] += loss.item()
            running["mse"] += mse.item()
            running["pcc"] += pcc.item()

    return (running["loss"] / num_batches,
            running["mse"] / num_batches,
            running["pcc"] / num_batches)


# def test_evaluateGenSAN(best_model, pre_gecs_test, gecs_test):
#     """
#     GenSAN model test evaluation function.

#     Args:
#         best_model : The GenSAN model after all iterations of training.
#         pre_gecs_test (tensor): Test set of pre-GECs.
#         gecs_test (ndarray): Test set of GECs data.
#     """

#     feature_test = pre_gecs_test.to(Device())
#     feature_test_predict = best_model(feature_test).cpu().detach().numpy()

#     # feature_test_predict_df = pd.DataFrame(feature_test_predict, index=net_test.index)
#     # feature_test_predict_df.to_csv(save_path + "feature_test_predict.csv", index=True)

#     d = []
#     pcc = []
#     for i in range(feature_test_predict.shape[0]):
#         pearson = np.corrcoef(feature_test_predict[i], gecs_test[i])[0, 1]
#         item_d, _ = stats.ks_2samp(feature_test_predict[i], gecs_test[i])
#         pcc.append(pearson)
#         d.append(item_d)
#     abs_pcc = abs(np.array(pcc))
#     abs_pcc_mean = abs_pcc.mean()
#     d_mean = np.array(d).mean()

#     mse = mean_squared_error(gecs_test, feature_test_predict)

#     print('=' * 30)
#     print('test evaluate result:\nAverage pcc: {}\nAverage mse: {}\nAverage D: {}'
#           .format(abs_pcc_mean, mse, d_mean))
#     print('=' * 30)

#     # return abs_pearson


def test_evaluateGenSAN(best_model, pre_gecs_test, gecs_test):
    """
    Evaluate a trained GenSAN model on the test split and print summary metrics.

    For every test sample the Pearson correlation and the two-sample Kolmogorov-Smirnov statistic
    between prediction and target are computed. The mean absolute correlation, the overall mean
    squared error and the mean KS statistic are printed.

    Args:
        best_model (nn.Module): Trained GenSAN model.
        pre_gecs_test (Tensor): Model input of the test split.
        gecs_test (DataFrame or ndarray): Target GECs of the test split, one row per sample, in the
            same row order as ``pre_gecs_test``.

    Returns:
        None. If the number of predictions differs from the number of targets, a message is
        printed and the function returns without computing metrics.

    Raises:
        Exception: Any error raised during evaluation is reported on stdout and re-raised.
    """

    try:
        feature_test = pre_gecs_test.to(Device())

        # Predict in evaluation mode and without autograd so that the metrics do not depend on
        # the mode the caller left the model in; the previous mode is restored afterwards.
        was_training = best_model.training
        best_model.eval()
        try:
            with torch.no_grad():
                feature_test_predict = best_model(feature_test).cpu().detach().numpy()
        finally:
            best_model.train(was_training)

        # Work on a plain array so that DataFrame and ndarray targets are handled alike
        # (rows are paired with predictions by position, not by index label).
        targets = np.asarray(gecs_test)

        if np.any(np.isnan(feature_test_predict)) or np.any(np.isinf(feature_test_predict)):
            print("warning: nan/infinite values detected in predictions")

        if np.any(np.isnan(targets)) or np.any(np.isinf(targets)):
            print("warning: nan or infinite values detected in gecs_test")

        if feature_test_predict.shape[0] != targets.shape[0]:
            print(f"mismatch in number of samples: feature_test_predict has {feature_test_predict.shape[0]} samples, but gecs_test has {targets.shape[0]} samples.")
            return

        d = []
        pcc = []
        for predicted_row, target_row in zip(feature_test_predict, targets):
            pearson = np.corrcoef(predicted_row, target_row)[0, 1]
            item_d, _ = stats.ks_2samp(predicted_row, target_row)

            pcc.append(pearson)
            d.append(item_d)

        abs_pcc_mean = np.abs(np.array(pcc)).mean()
        d_mean = np.array(d).mean()

        mse = mean_squared_error(targets, feature_test_predict)

        print('=' * 30)
        print('Test evaluate result:')
        print(f"Average PCC: {abs_pcc_mean}")
        print(f"Average MSE: {mse}")
        print(f"Average D: {d_mean}")
        print('=' * 30)

    except Exception as e:
        # Report the failure in the notebook output, then let it propagate.
        print(f"An error occurred in test_evaluate: {e}")
        raise


def plot_loss_figure(epochs, train_loss, valid_loss):
    """
    Plot the training and validation loss curves.

    The matplotlib rcParams set here are changed globally and stay in effect after the call.

    Args:
        epochs (int): Number of epochs; the curves are drawn over ``range(epochs)``.
        train_loss (list): List of training set loss values.
        valid_loss (list): List of valid set loss values.
    """
    plt.rcParams.update({
        'font.family': 'Arial',
        'xtick.labelsize': 22,
        'ytick.labelsize': 22,
        'axes.titlesize': 28,
        'axes.labelsize': 28,
        'legend.fontsize': 22,
        'axes.unicode_minus': False,
    })
    plt.figure(figsize=(7, 7), dpi=144)

    plt.title('Loss figure')
    epoch_axis = range(epochs)
    curves = (
        (train_loss, 'red', '--', 'train loss'),
        (valid_loss, 'dodgerblue', '-', 'valid loss'),
    )
    for values, colour, dash, label in curves:
        plt.plot(epoch_axis, values, color=colour, linestyle=dash, label=label, linewidth=2)
    plt.legend(loc='upper right', frameon=False)
    plt.xlabel('Epochs')
    plt.ylabel('Loss values')
    plt.show()


def train(epochs, model, train_dataloader, valid_dataloader, optimizer, beta, warmup_epoch, learning_rate):
    """
    Train the GenSAN model and keep the snapshot with the lowest validation loss.

    Args:
        epochs (int): Number of epochs for GenSAN model training
        model (nn.Module): GenSAN model
        train_dataloader : Training set batch data.
        valid_dataloader : Validation set batch data.
        optimizer : Optimizer used for the parameter updates.
        beta (float): Weight hyperparameter of the combination of mse and pcc.
        warmup_epoch (int): Number of warm-up epochs for GenSAN model training
        learning_rate (float): Peak learning rate handed to the schedule as ``lr_max``.
    Returns:
        best_model (nn.Module): Deep copy of the model at its lowest validation loss, or None if
        the validation loss never dropped below infinity.
    """

    best_model = None
    lowest_valid_loss = float("inf")

    train_history = []
    valid_history = []
    for epoch in range(epochs):
        started_at = time.time()

        tra_loss, tra_mse_loss, tra_cor_loss = train_function(
            model=model, train_dataloader=train_dataloader, optimizer=optimizer, beta=beta)
        val_loss, val_mse_loss, val_cor_loss = valid_function(
            model=model, valid_dataloader=valid_dataloader, beta=beta)
        train_history.append(tra_loss)
        valid_history.append(val_loss)

        if val_loss < lowest_valid_loss:
            lowest_valid_loss = val_loss
            best_model = copy.deepcopy(model)

        # NOTE(review): the schedule is applied after the epoch, so the rate computed for `epoch`
        # is used by the next one. Epoch 0 runs at the optimizer's initial rate and, with warm-up
        # enabled, epoch 1 runs at rate 0 - confirm this is intended.
        adjust_learning_rate(optimizer=optimizer, warmup_epoch=warmup_epoch, current_epoch=epoch,
                             max_epoch=epochs, lr_min=0, lr_max=learning_rate, warmup=True)

        report = ('end of epoch:{:3d} | time:{:5.2f}s | train loss:{:5.5f} | valid loss:{:5.5f} | train MseLoss:{:5.5f} | '
                  'train PccLoss:{:5.5f} | valid MseLoss:{:5.5f} | valid PccLoss:{:5.5f}')
        print(report.format(epoch, (time.time() - started_at), tra_loss, val_loss, tra_mse_loss, tra_cor_loss,
                            val_mse_loss, val_cor_loss))

    # plot_loss_figure(epochs, train_history, valid_history)

    return best_model


def feature_predict(input_matrix, best_model, pre_GECs, scaler, save_path, name, length=64):
    """
    Predict GECs in chunks, undo the scaling and write the result to a CSV file.

    Args:
        input_matrix (tensor): 3D tensor composed of pre-GECs of RNAi, OE, and CRISPR.
        best_model (nn.Module): The trained GenSAN model with the lowest validation loss.
        pre_GECs (dataframe): Single pre-GECs data(RNAi, OE or CRISPR); its index labels the rows
            of the saved file.
        scaler: Fitted scaler object providing ``inverse_transform``.
        save_path : Path prefix of the output file; the file is ``save_path + name + "_predict_GECs.csv"``.
        name (str): GECs type(RNAi, OE or CRISPR)
        length (int): Number of samples per forward pass. Defaults to 64.
    Returns:
        predict_GECs (dataframe): Model outputs before ``scaler.inverse_transform``, with a fresh
        integer index. The inverse-transformed values are only written to the CSV file.
    """

    chunk_frames = []
    # Predict in evaluation mode and without autograd: no graph is kept per chunk and the output
    # does not depend on the mode the model was left in. The previous mode is restored afterwards.
    was_training = best_model.training
    best_model.eval()
    try:
        with torch.no_grad():
            for start in range(0, len(input_matrix), length):
                inputs = input_matrix[start:start + length, :, :].to(Device())
                chunk_frames.append(pd.DataFrame(best_model(inputs).cpu().detach().numpy()))
    finally:
        best_model.train(was_training)
    predict_GECs = pd.concat(chunk_frames, ignore_index=True)

    inverse_predict_GECs = scaler.inverse_transform(predict_GECs.values)
    inverse_predict_GECs = pd.DataFrame(inverse_predict_GECs, index=pre_GECs.index)
    inverse_predict_GECs.to_csv(save_path + name + "_predict_GECs.csv", index=True, sep=",")

    print('\npredict finish:\npredict GECs:{}\n'.format(inverse_predict_GECs.shape))
    # NOTE(review): the scaled predictions are returned, not the inverse-transformed frame that
    # is saved - confirm which one callers expect.
    return predict_GECs


def run_GenSAN_model(blocks, GECs_dimension, hidden_nodes, heads, dropout_rate, recycles,
                     learning_rate, weight_decay,
                     epochs, train_dataloader, valid_dataloader, beta, warmup_epoch,
                     pre_gecs_test, gecs_test, save_path,
                     input_matrix, length, pre_GECs, scaler, name):
    """
    GenSAN model training process: build, train, evaluate on the test split and predict.

    Args:
        blocks (int): Number of transformer encoder units for GenSAN.
        GECs_dimension (int): The last dimension of GEC size(978).
        hidden_nodes (int): Number of nodes in each feed forward neural network layer.
        heads (int): Number of attention heads of row-wise self-attention block.
        dropout_rate (float): Dropout layer ratio of GenSAN.
        recycles (int): Recycle times of GenSAN model.
        learning_rate (float): Adam optimizer initial learning rate; also the peak of the schedule.
        weight_decay (float): Adam optimizer weight decay parameter.
        epochs (int): Number of epochs for GenSAN model training
        train_dataloader : Training set batch data.
        valid_dataloader : Validation set batch data.
        beta (float): Weight hyperparameter of the combination of mse and pcc.
        warmup_epoch (int): Number of warm-up epochs for GenSAN model training
        pre_gecs_test (tensor): Test set of pre-GECs.
        gecs_test : Test set of GECs data, passed to ``test_evaluateGenSAN``.
        save_path : Path prefix under which the predicted GECs are saved.
        input_matrix (tensor): 3D tensor composed of pre-GECs of RNAi, OE, and CRISPR.
        length (int): Number of samples of input_matrix per forward pass.
        pre_GECs (dataframe): Single pre-GECs data(RNAi, OE or CRISPR).
        scaler: Fitted scaler object used to undo the scaling of the predictions.
        name (str): GECs type(RNAi, OE or CRISPR)

    Returns:
        predict_GECs (dataframe): The value returned by ``feature_predict`` for ``input_matrix``.
    """

    model = GenSAN_model(blocks=blocks,
                         GECs_dimension=GECs_dimension,
                         hidden_nodes=hidden_nodes,
                         heads=heads,
                         dropout_rate=dropout_rate,
                         recycles=recycles)
    model = model.to(Device())

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=weight_decay)

    best_model = train(epochs=epochs,
                       model=model,
                       train_dataloader=train_dataloader,
                       valid_dataloader=valid_dataloader,
                       optimizer=optimizer,
                       beta=beta,
                       warmup_epoch=warmup_epoch,
                       learning_rate=learning_rate)

    test_evaluateGenSAN(best_model, pre_gecs_test, gecs_test)
    # The best model is currently not written to disk:
    # torch.save(best_model, save_path + "FunDNN best model.pt")

    predict_GECs = feature_predict(input_matrix=input_matrix,
                                   length=length,
                                   best_model=best_model,
                                   pre_GECs=pre_GECs,
                                   scaler=scaler,
                                   save_path=save_path,
                                   name=name)
    return predict_GECs


In [75]:
# gensan_module
import torch.nn as nn
# from GenSAN.utils import clones, SublayerConnection, LayerNorm


class BlockLayer(nn.Module):
    """Encoder block: row attention, column attention and feed forward, each in a residual sub-layer."""

    def __init__(self, GECs_dimension, col_attn, row_attn, feed_forward, dropout_rate):
        """
        Create a single encoder block.

        Args:
            GECs_dimension (int): The last dimension of GEC size(978).
            col_attn (nn.Module): Attention object for column attention.
            row_attn (nn.Module): Multi-head attention object for row attention.
            feed_forward (nn.Module): Feed forward neural network object.
            dropout_rate (float): Dropout probability of the three sub-layer connections.
        """
        super().__init__()
        self.GECs_dimension = GECs_dimension
        self.col_attn = col_attn
        self.row_attn = row_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(dropout_rate), 3)

    def forward(self, x):
        """Apply row attention, then column attention, then the feed forward network."""
        row_stage, col_stage, ff_stage = self.sublayer

        x = row_stage(x, lambda normed: self.row_attn(normed, normed, normed))
        x = col_stage(x, lambda normed: self.col_attn(normed, normed, normed))
        return ff_stage(x, self.feed_forward)


class Blocks(nn.Module):
    """Stack of encoder blocks followed by a final layer normalization."""

    def __init__(self, layer, blocks):
        """
        Create the stack.

        Args:
            layer (nn.Module): Single transformer encoder block; it is deep-copied ``blocks`` times.
            blocks (int): Number of transformer encoder units for GenSAN.
        """
        super().__init__()
        self.layers = clones(layer, blocks)
        # LayerNorm is built with its default dimension (978).
        self.norm = LayerNorm()

    def forward(self, x):
        """Pass ``x`` through every block in order and normalize the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)


class BlocksRecycling(nn.Module):
    """Repeated ("recycled") application of the encoder stack."""

    def __init__(self, block_layers, recycles):
        """
        Create the recycling wrapper.

        Args:
            block_layers (nn.Module): Encoder stack; it is deep-copied ``recycles`` times.
            recycles (int): Recycle times of Transformer encoder blocks
        """
        super().__init__()
        self.models = clones(block_layers, recycles)
        # One LayerNorm (default dimension 978) is shared by all recycling rounds.
        self.norm = LayerNorm()

    def forward(self, x):
        """Run every copy of the stack in turn, normalizing and doubling the output each round."""
        for recycle_round in self.models:
            x = self.norm(recycle_round(x))
            # NOTE(review): this doubles the normalized output in place; it is not a residual
            # connection to the round's input - confirm this is intended.
            x += x
        return x




In [76]:
import copy
import torch.nn as nn
# from GenSAN.utils import ColumnAttention, RowAttention, PositionwiseFeedForward, Generator
# from GenSAN.module import BlockLayer, Blocks, BlocksRecycling


class GenSAN(nn.Module):
    """GenSAN model: an encoder followed by a generator head."""

    def __init__(self, util, generator):
        """
        Assemble the model from its two parts.

        Args:
            util (nn.Module): Transformer encoder blocks with recycling.
            generator (nn.Module): Generator block.
        """
        super().__init__()
        self.util = util
        self.generator = generator

    def forward(self, x):
        """Encode ``x`` with ``util`` and map the result through ``generator``."""
        return self.generator(self.util(x))


def GenSAN_model(blocks=3, GECs_dimension=978, hidden_nodes=1024, heads=6, dropout_rate=0.2, recycles=1):
    """
    Instantiate a GenSAN model and initialize its weight matrices.

    Args:
        blocks (int): Number of transformer encoder units for GenSAN
        GECs_dimension (int): The last dimension of GEC size(978).
        hidden_nodes (int): Number of nodes in each feed forward neural network layer
        heads (int): Number of attention heads of row-wise self-attention block
        dropout_rate (float): Dropout layer ratio of GenSAN
        recycles (int): Recycle times of GenSAN model

    Returns:
        GenSAN: The assembled model; every parameter with more than one dimension is
        Xavier-uniform initialized.
    """

    col_attn = ColumnAttention(dropout_rate)
    row_attn = RowAttention(heads, GECs_dimension, dropout_rate)
    ff = PositionwiseFeedForward(GECs_dimension, hidden_nodes, dropout_rate)

    # The block receives its own copies of the attention and feed forward modules.
    single_block = BlockLayer(GECs_dimension,
                              copy.deepcopy(col_attn),
                              copy.deepcopy(row_attn),
                              copy.deepcopy(ff),
                              dropout_rate)
    encoder = BlocksRecycling(Blocks(single_block, blocks), recycles)
    generator = Generator(GECs_dimension, hidden_nodes, dropout_rate)
    model = GenSAN(encoder, generator)

    # Only weight matrices are re-initialized; 1-D parameters keep their default values.
    for weight in model.parameters():
        if weight.dim() > 1:
            nn.init.xavier_uniform_(weight)
    return model


In [15]:
# from Data_process import load_rawdata, datasets_split
# from FunDNN.Model import run_model
# from FunDNN.preprocessor import load_data, get_dataloader
# from GenSAN.preprocessor import GenSAN_preprocessor, get_pre_GECs
# from GenSAN.train_function import run_GenSAN_model
# from config_parser import TranscriptionNet_Hyperparameters

Load TranscriptionNet model hyperparameters

In [16]:
# Instantiate the hyperparameter container.
# NOTE(review): the CRISPR cell below passes its hyperparameters as literals instead of reading
# them from `config` - confirm which source is authoritative.
config = TranscriptionNet_Hyperparameters()

Load raw data and get training, validation and test sets

In [18]:
# Mount Google Drive at /content/drive (Google Colab only).
# NOTE(review): the data paths used in later cells are relative
# ("drive/MyDrive/..."), which presumably relies on the working directory
# being /content — confirm before running outside the default Colab setup.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [24]:
# Load the raw inputs. Only the first and fourth returned values are kept
# (node features and CRISPR GECs); the two middle values are discarded.
node_feature, _, _, CRISPR_GECs = load_rawdata("drive/MyDrive/raw_data/")

# RNAi / OE splits are disabled in this run.
# NOTE(review): RNAi_GECs and OE_GECs are not bound above (presumably they are
# the two discarded values) — unpack them before re-enabling these lines.
# RNAi_MMScaler = datasets_split(RNAi_GECs, node_feature, "example_data/datasets/RNAi/")
# OE_MMScaler = datasets_split(OE_GECs, node_feature, "example_data/datasets/OE/")
# The returned object is kept as CRISPR_MMScaler and later passed to
# run_GenSAN_model as its `scaler` argument.
CRISPR_MMScaler = datasets_split(CRISPR_GECs, node_feature, "drive/MyDrive/CRISPR/")

Train the CRISPR FunDNN model to predict the CRISPR pre-GECs (`CRISPR_pre_GECs`)

In [67]:
# Node features: for "feature_dict.pkl", load_data standardises all three
# splits with a StandardScaler fitted on the training split only.
CRISPR_feature_train, CRISPR_feature_valid, CRISPR_feature_test = load_data("drive/MyDrive/CRISPR/",
                                                                            "feature_dict.pkl")

# GECs: train/valid come back as float tensors; the test split is returned
# index-sorted but without tensor conversion (see the shape printout below).
CRISPR_GECs_train, CRISPR_GECs_valid, CRISPR_GECs_test = load_data("drive/MyDrive/CRISPR/", "GECs_dict.pkl")

# Batched (features, GECs) loaders for training and validation; batch size 32.
CRISPR_train_dataloader, CRISPR_valid_dataloader = get_dataloader(32, CRISPR_feature_train,
                                                                  CRISPR_GECs_train,
                                                                  CRISPR_feature_valid, CRISPR_GECs_valid)

# Train FunDNN; the return value is kept as CRISPR_pre_GECs and consumed by the
# GenSAN cell that follows.
# NOTE(review): run_model is defined outside this section — the meaning of each
# keyword (e.g. beta, warmup_epoch) is taken from its name; confirm against the
# definition before tuning.
CRISPR_pre_GECs = run_model(num_layers=5,
                            hidden_nodes=1024,
                            activate_func=nn.LeakyReLU,
                            dropout_rate=0.1,
                            epochs=1000,
                            train_dataloader=CRISPR_train_dataloader,
                            valid_dataloader=CRISPR_valid_dataloader,
                            beta=0.1,
                            feature_test=CRISPR_feature_test,
                            gecs_test=CRISPR_GECs_test,
                            save_path="drive/MyDrive/CRISPR/results/",
                            node_feature=node_feature,
                            name="CRISPR",
                            warmup_epoch=10,
                            learning_rate=0.00035)


Node feature dimension:
train data:torch.Size([264, 512])
valid data:torch.Size([38, 512])
test data:torch.Size([76, 512])

GECS data dimension:
train data:torch.Size([264, 978])
valid data:torch.Size([38, 978])
test data:(76, 978)

end of epoch:  0 | time: 0.05s | train loss:0.22452 | valid loss:0.31303 | train MseLoss:0.13947 | train PccLoss:0.99000 | valid MseLoss:0.23694 | valid PccLoss:0.99785
end of epoch:  1 | time: 0.05s | train loss:0.22597 | valid loss:0.30239 | train MseLoss:0.14093 | train PccLoss:0.99135 | valid MseLoss:0.22586 | valid PccLoss:0.99119
end of epoch:  2 | time: 0.05s | train loss:0.22534 | valid loss:0.32083 | train MseLoss:0.14058 | train PccLoss:0.98822 | valid MseLoss:0.24626 | valid PccLoss:0.99199
end of epoch:  3 | time: 0.05s | train loss:0.22563 | valid loss:0.31455 | train MseLoss:0.14083 | train PccLoss:0.98881 | valid MseLoss:0.23926 | valid PccLoss:0.99220
end of epoch:  4 | time: 0.05s | train loss:0.22614 | valid loss:0.29943 | train MseLoss:0.

In [81]:
# CRISPR GenSAN model
print("CRISPR GenSAN model")
# Only CRISPR data is available in this run, so the same CRISPR objects are
# passed for every true / predicted GECs slot of GenSAN_preprocessor.
# NOTE(review): presumably the remaining slots are meant for the other
# perturbation types (cf. the commented-out RNAi / OE splits) — confirm that
# feeding CRISPR three times is intended and not a placeholder.
GenSAN_train, GenSAN_valid, GenSAN_test, _, _ = GenSAN_preprocessor(true_GECs1=CRISPR_GECs,
                                                                    true_GECs2=CRISPR_GECs,
                                                                    predict_GECs1=CRISPR_pre_GECs,
                                                                    predict_GECs2=CRISPR_pre_GECs,
                                                                    pre_GECS=CRISPR_pre_GECs,
                                                                    input_path="drive/MyDrive/CRISPR/",
                                                                    file_name="feature_dict.pkl")

# Inputs are the GenSAN_* tensors; targets are the same CRISPR GEC tensors
# that were used as FunDNN targets. Batch size 32.
GenSAN_train_dataloader, GenSAN_valid_dataloader = get_dataloader(batch_size=32,
                                                                  node_train=GenSAN_train,
                                                                  gecs_train=CRISPR_GECs_train,
                                                                  node_valid=GenSAN_valid,
                                                                  gecs_valid=CRISPR_GECs_valid)

# Built from CRISPR_pre_GECs passed three times, mirroring the preprocessor call above.
input_matrix = get_pre_GECs(CRISPR_pre_GECs, CRISPR_pre_GECs, CRISPR_pre_GECs)

# Train GenSAN. `scaler` is the object returned by datasets_split for CRISPR.
# NOTE(review): run_GenSAN_model is defined outside this section — the roles of
# `length`, `input_matrix` and `pre_GECs` are not visible here; confirm against
# the definition.
CRISPR_predict_GECs = run_GenSAN_model(blocks=3,
                                       GECs_dimension=978,
                                       hidden_nodes=1024,
                                       heads=2,
                                       dropout_rate=0.05,
                                       recycles=3,
                                       learning_rate=0.0000045,
                                       weight_decay=1e-5,
                                       epochs=110,
                                       train_dataloader=GenSAN_train_dataloader,
                                       valid_dataloader=GenSAN_valid_dataloader,
                                       beta=0.1,
                                       warmup_epoch=5,
                                       pre_gecs_test=GenSAN_test,
                                       gecs_test=CRISPR_GECs_test,
                                       save_path="drive/MyDrive/CRISPR/results/",
                                       input_matrix=input_matrix,
                                       length=64,
                                       pre_GECs=CRISPR_pre_GECs,
                                       scaler=CRISPR_MMScaler,
                                       name="CRISPR")


CRISPR GenSAN model
pre-GECS dimension:
train data:torch.Size([264, 3, 978])
valid data:torch.Size([38, 3, 978])
test data:torch.Size([76, 3, 978])

end of epoch:  0 | time: 1.58s | train loss:0.82121 | valid loss:0.85952 | train MseLoss:0.80324 | train PccLoss:0.98295 | valid MseLoss:0.84591 | valid PccLoss:0.98206
end of epoch:  1 | time: 1.54s | train loss:0.79145 | valid loss:0.86226 | train MseLoss:0.77172 | train PccLoss:0.96902 | valid MseLoss:0.84865 | valid PccLoss:0.98481
end of epoch:  2 | time: 1.60s | train loss:0.78747 | valid loss:0.84952 | train MseLoss:0.76754 | train PccLoss:0.96677 | valid MseLoss:0.83417 | valid PccLoss:0.98774
end of epoch:  3 | time: 1.56s | train loss:0.77817 | valid loss:0.81502 | train MseLoss:0.75795 | train PccLoss:0.96010 | valid MseLoss:0.79806 | valid PccLoss:0.96763
end of epoch:  4 | time: 1.57s | train loss:0.75887 | valid loss:0.78820 | train MseLoss:0.73795 | train PccLoss:0.94724 | valid MseLoss:0.76940 | valid PccLoss:0.95743
end of