In [None]:
'''
Through a comprehensive online resource generated by the PsychENCODE Consortium (PEC) 
for the adult brain, many entities such as functional elements, quantitative-trait
loci (QTLs), and regulatory-network linkages have been identified and embedded into a
comprehensive deep learning model in order to predict psychiatric phenotypes from genotypic
and transcriptomic data. The end-result is a biologically relevant Deep Boltzmann Machine
architecture connecting genotype, functional genomics, and phenotype data, with conditional
and lateral connections that improve trait prediction over traditional additive models.
Our main goal is to follow a more segregating approach by implementing different
machine learning algorithms that exploit different priors accordingly. We show that 
by knowing just the links between genotypic (TG) and transcriptomic (TF) data, we are able to 
reach a very good accuracy in classifying schizophrenic (SCZ) patients from controls.

The first Pytorch Model we used for patient classification consists of a network 
of stacked Restricted Boltzmann machines (RBMs). In the first part of the code 
we add layers of RBMs sequentially and assess the accuracy of each network architecture. 
Our aim is to define the number of RBM layers we are going to use to build our final 
classifier. Firstly, we start by training RBMs and stack them sequentially to build 
denoising autoencoders (DAEs). Then we refine the RBM weights through a few consecutive 
rounds of DAE training. We then transfer the learnt weights from the DAE with the 
largest accuracy to other network architectures (mostly feedforward networks) for 
further refinement. In this initial code, we apply a MASK (representing the TF-TG 
links) to the first RBM layer in order to keep gradients steady during backpropagation. 
Apart from this MASK, which is applied only to the first RBM, we do not enforce any further sparsity.

In the second Pytorch Model we build for patient classification, we enforce sparsity 
through MASK (first layer) and through Kullback–Leibler divergence/L2-regularization 
(to all other layers), and transfer the weights to another model for further refinement 
and final classification.

Below we present the second Pytorch Model applied to only one of the 10 PEC datasets.

'''


In [None]:
#__________STEP I___________ Download the input data from the following link:
#                            https://drive.google.com/file/d/10uRj-4gd9wFaDEnEKNDRTKIn1DSTf0Hq/view?usp=sharing

#__________STEP II__________ Unzip the data to google drive: 
#                            "/content/drive/My Drive/datasets/"

#__________STEP III_________ Set Google Colab in GPU mode

#__________STEP IV__________ Please run the codes below sequentially.
#                            Even on GPU, CODE_I needs around 3.5h to complete, 
#                            but feel free to skip directly to CODE_II 


In [None]:
# Load the rpy2 IPython extension so the %R / %%R magics used by later cells are available.
%load_ext rpy2.ipython
# Mount Google Drive; later cells read and write files under "/content/drive/My Drive/datasets/".
from google.colab import drive
drive.mount('/content/drive')


In [None]:
########################################################################################################################
###########################_________________________ PRE-PROCESSING _________________________###########################
########################################################################################################################

In [None]:
#_______________________________________________________________________________ Info on GPUs we are going to load our data.
!mkdir chpnt                                                                   

import tensorflow as tf
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  # Restrict TensorFlow to only use the first GPU
  try:
    tf.config.set_visible_devices(gpus[0], 'GPU')
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
  except RuntimeError as e:
    # Visible devices must be set before GPUs have been initialized
    print(e)

## memory footprint
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
!pip install gputil
!pip install psutil
!pip install humanize

import psutil
import humanize
import os
import GPUtil as GPU


In [None]:
%%R
# NOTE: `%%R` must be the very first line of the cell. IPython only recognises a cell magic
#       there; placed below comment lines, the cell is parsed as Python and fails.

#______INPUT FILES    :     topC.Rdata           (Transcription Factors-Genes connectivity matrix)
#                           GRN_1_genes.csv      (Genes(ensembl_gene_id) connected to transcription factors)
#                           Tot_TG.csv           (Total number of genes(ensembl_gene_id))
#                           scz_data5.mat        (Gene dataset downloaded from PEC(ensembl_gene_id))

#_______________________________________________________________________________ I. Converting input genes from ensembl_gene_id to hgnc_symbol

load(file="/content/drive/My Drive/datasets/topC.Rdata")
GRN_1_TG <- read.csv("/content/drive/My Drive/datasets/GRN_1_genes.csv")  
C_TGe2 <- read.csv("/content/drive/My Drive/datasets/Tot_TG.csv")
#_______________________________________________________________________________ Ia. Convert input genes

# Install R.matlab only when it is missing, so re-running the cell does not reinstall it
# (same guard style as the BiocManager install below).
if (!requireNamespace("R.matlab", quietly = TRUE))
    install.packages("R.matlab")
library(R.matlab)
scz_data1 <- readMat("/content/drive/My Drive/datasets/scz_data5.mat")
# Flatten the gene-id field of the .mat file into a one-column data frame of strings.
scz_data1f <- data.frame("X.geneIds"=matrix(unlist(scz_data1$X.geneIds), 
                                            nrow=length(scz_data1$X.geneIds), 
                                            byrow=TRUE),stringsAsFactors=FALSE)

scz_data1f$X.geneIds <- as.character(scz_data1f$X.geneIds)
# Strip the first ".<digits>" occurrence (the version suffix of an Ensembl id).
scz_data1f$X.geneIds <- sub("[.][0-9]*","",scz_data1f$X.geneIds)
genes1 <-  scz_data1f$X.geneIds

if (!requireNamespace("BiocManager", quietly = TRUE))  # Installing biomart from Bioconductor
    install.packages("BiocManager")
BiocManager::install("biomaRt", force = TRUE)
library(biomaRt)

require("biomaRt")
mart <- useMart("ENSEMBL_MART_ENSEMBL")
mart <- useDataset("hsapiens_gene_ensembl", mart)

ensLookup <- genes1

# Look up biotype and HGNC symbol for every input Ensembl gene id.
mrg_lst2 <- getBM(
  mart=mart,
  attributes=c("ensembl_gene_id","gene_biotype","hgnc_symbol"),
  filter="ensembl_gene_id",
  values=ensLookup,
  uniqueRows=TRUE)

# Prepend the matching original id as the first column.
mrg_lst2 <- data.frame(ensLookup[match(mrg_lst2$ensembl_gene_id, ensLookup)],
  mrg_lst2)

mrg_lst2 <- subset(mrg_lst2, (!is.na(mrg_lst2['hgnc_symbol'])))

colnames(mrg_lst2) <- c(
  "original_id",
  c("ensembl_gene_id","gene_biotype","hgnc_symbol"))
#_______________________________________________________________________________ Ib. Convert total genes

scz_data1h <- data.frame("hgnc_symbol"=matrix(unlist(C_TGe2$Var1), nrow=length(C_TGe2$Var1), byrow=TRUE),stringsAsFactors=FALSE)
genes2 <-  scz_data1h$hgnc_symbol

ensLookup <- genes2

# Reverse lookup: from HGNC symbol to Ensembl gene id.
mrg_lst <- getBM(
  mart=mart,
  attributes=c("hgnc_symbol","gene_biotype","ensembl_gene_id"),
  filter="hgnc_symbol",
  values=ensLookup,
  uniqueRows=TRUE)

mrg_lst <- data.frame(ensLookup[match(mrg_lst$hgnc_symbol, ensLookup)],
  mrg_lst)

mrg_lst <- subset(mrg_lst, (!is.na(mrg_lst['ensembl_gene_id'])))

# NOTE(review): the attributes above were requested as (hgnc_symbol, gene_biotype, ensembl_gene_id),
#               but the names assigned below are in the order (ensembl_gene_id, gene_biotype, hgnc_symbol).
#               If getBM returns columns in request order, the first and last labels are swapped.
#               Left unchanged because later cells may rely on the current names - please confirm.
colnames(mrg_lst) <- c(
  "original_id",
  c("ensembl_gene_id","gene_biotype","hgnc_symbol"))


In [None]:
#______INPUT FILES    :     ifx4.Rdata          (indexing of Genes(hgnc_symbol) connected to Transcription Factors)
#                           ifx6_tmp.Rdata      (indexing of Genes(hgnc_symbol) connected to Transcription Factors aligned to TopC connectivity matrix)
#                           GNT_NAME.Rdata      (unique gene names aligned to input genes)

#_______________________________________________________________________________ I. Loading input files and libraries
import torch.optim as optim
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score, classification_report
import numba as nb
from numba.typed import List
import numba
!pip install -U "ray[tune]"
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.trial import ExportFormat
from functools import partial
import os
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
import pandas as pd
import tensorflow as tf
import numpy
import numpy as np
import torch
from scipy import sparse, io 
import torchvision.datasets
import torchvision.models
import torchvision.transforms
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.utils.data as data_utils
from torchvision.datasets import MNIST
from sklearn.datasets import make_blobs
from tensorflow.python.keras.layers import Dense
from tensorflow.python.keras.models import Sequential
from sklearn.decomposition import PCA
import matplotlib
from matplotlib import pyplot
import matplotlib.pyplot as plt
import torch.nn as nn
from sklearn.model_selection import KFold
from sklearn.linear_model import LogisticRegression
from tensorflow.python.keras.utils.np_utils import to_categorical
from torchvision.utils import make_grid
import torch.nn.functional as F
import rpy2.robjects as robjects
from tensorflow.python.keras.optimizer_v2.adam import Adam
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold
from sklearn.metrics import mean_absolute_error
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from keras.layers import BatchNormalization
from keras.models import load_model
from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import f_classif
from sklearn.feature_selection import f_regression
from numpy import set_printoptions
from sklearn.feature_selection import VarianceThreshold
from pickle import dump
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
import random as python_random
from keras.callbacks import Callback
from keras import backend
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
np.random.seed(0)
from sklearn import datasets
from google.colab import output
%matplotlib inline

# Load the two gene-index objects (ifx4, ifx6_tmp) into the embedded R session.
%R load("/content/drive/My Drive/datasets/ifx4.Rdata")
%R load("/content/drive/My Drive/datasets/ifx6_tmp.Rdata")

#_______________________________________________________________________________ II. Keep only Genes(hgnc_symbol) connected to Transcription Factors, 
#                                                                                    remove duplicates or NA

# scz_data1 is the .mat content read with readMat() in the earlier %%R cell.
# Column selection is applied in four chained steps: select by ifx4, drop four hard-coded
# column positions, re-select by ifx6_tmp, then drop column 10041.
# NOTE(review): the hard-coded positions presumably are the duplicate/NA genes mentioned above;
#               their origin is not visible in this notebook - confirm before reusing on other datasets.
%R n13_train <- scz_data1$X.Gene2.tr[,ifx4][,-c(4070, 5038, 5353, 6869)][,ifx6_tmp][,-10041]
%R save(n13_train, file="/content/drive/My Drive/datasets/n13_train_short.Rdata")

# Training traits are saved unmodified.
%R n13_train_trait <- scz_data1$X.Trait.tr
%R save(n13_train_trait, file="/content/drive/My Drive/datasets/n13_train_trait_short.Rdata")

# The test expression matrix goes through exactly the same column selection as the training matrix.
%R n13_test <- scz_data1$X.Gene2.te[,ifx4][,-c(4070, 5038, 5353, 6869)][,ifx6_tmp][,-10041]
%R save(n13_test, file="/content/drive/My Drive/datasets/n13_test_short.Rdata")

# Test traits are saved unmodified.
%R n13_test_trait <- scz_data1$X.Trait.te
%R save(n13_test_trait, file="/content/drive/My Drive/datasets/n13_test_trait_short.Rdata")

#_______________________________________________________________________________  III. Load files
%R load("/content/drive/My Drive/datasets/n13_train_short.Rdata")
%R load("/content/drive/My Drive/datasets/n13_test_short.Rdata")
%R load("/content/drive/My Drive/datasets/n13_train_trait_short.Rdata")
%R load("/content/drive/My Drive/datasets/n13_test_trait_short.Rdata")
%R load("/content/drive/My Drive/datasets/topC.Rdata")
%R load("/content/drive/My Drive/datasets/GNT_NAME.Rdata")

scz_data2_train_tmp = robjects.r['n13_train']
scz_data2_test_tmp = robjects.r['n13_test']
X_train_fl_tmp = pd.DataFrame(np.array(scz_data2_train_tmp)).astype(float, 64)

%R n13_train_mat <- as.matrix(n13_train)
%R n13_test_mat <- as.matrix(n13_test)

%R colnames(n13_train_mat) <- GNT_NAME$name
%R colnames(n13_test_mat) <- GNT_NAME$name

scz_data2_train = robjects.r['n13_train_mat']
scz_data2_test = robjects.r['n13_test_mat']
scz_data2_train_trait = robjects.r['n13_train_trait']
scz_data2_test_trait = robjects.r['n13_test_trait']

X_train_fl = pd.DataFrame(np.array(scz_data2_train)).astype(float, 64)
y_train = pd.DataFrame(np.array(scz_data2_train_trait)[:,1]).astype(int)

X_test_fl = pd.DataFrame(np.array(scz_data2_test)).astype(float, 64)
y_test = pd.DataFrame(np.array(scz_data2_test_trait)[:,1]).astype(int)

y_train.insert(0, "Patient", [i for i in range(len(y_train))], True)
y_train.rename(columns = {'Patient':'Patient', 
                       0:'Disease'}, inplace = True)

y_test.insert(0, "Patient", [i for i in range(len(y_test))], True)
y_test.rename(columns = {'Patient':'Patient', 
                       0:'Disease'}, inplace = True)
#_______________________________________________________________________________ IV. we select 275 random controls and 275 patients 
#                                                                                    (due to memory restrictions in Ray-tunes)

# Number of controls (Disease == 0) and patients (Disease == 1) to draw.
n1, n2 = 275, 275

# Row labels of each class in the full training frame.
idx1 = X_train_fl.index.values[y_train['Disease'] == 0]
idx2 = X_train_fl.index.values[y_train['Disease'] == 1]
len1, len2 = len(idx1), len(idx2)

# Random draw without replacement per class (controls first, then patients, so the
# global NumPy RNG is consumed in the same order as before).
draw1 = np.random.permutation(len1)[:n1]
idx1_test = idx1[draw1]
draw2 = np.random.permutation(len2)[:n2]
idx2_test = idx2[draw2]
idx_test = np.concatenate([idx1_test, idx2_test])

idx_train = X_train_fl.index.values[idx_test]

# Keep only the drawn rows; the original row labels are retained.
X_train_fl = X_train_fl.loc[idx_train]  # optional: .reset_index(drop=True)
y_train = y_train.loc[idx_train]

# check number of classes (equal "0" and "1"); not displayed, since it is not the cell's last expression
y_train.loc[y_train['Disease'] == 1].agg(['nunique', 'count', 'size'])

# Apply the same scaling to both datasets
scaler = StandardScaler()
#scaler = MinMaxScaler()

# Fit on the training subsample only; the test set is transformed with the training statistics.
X_train_scl = scaler.fit_transform(X_train_fl)
X_test_scl = scaler.transform(X_test_fl)


In [None]:
#______INPUT FILES    :     genematrix_tmp.txt              (sparse connectivity Matrix)
#                           GRN_3_mat_tmp_rownames.txt      (names of Genes connected to Transcription Factors)
#                           GRN_3_mat_tmp_colnames.txt      (names of Transcription Factors connected to Genes)

#_______________________________________________________________________________ I. PCA applied on our data
pca = PCA()
# Fit the full PCA once and keep the projection. The original fitted the same data a second
# time only to obtain `x_new`.
x_new = pca.fit_transform(X_train_scl)
total = sum(pca.explained_variance_)
# k = number of leading components needed to reach 90% of the total explained variance.
k = 0
current_variance = 0
while current_variance/total < 0.90:
    current_variance += pca.explained_variance_[k]
    k = k + 1
#_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ Ia. Use PCA to remove less important data features
n_pcs= pca.components_.shape[0]

x = 0
y = []
i1 = 0
lim = 0.90                 #____________ let's set a limit to 90% total variance
# Accumulate explained-variance ratios; `y` stores the running total modulo `lim`.
# The loop stops at the first step where the running total exceeds `lim` and the modulo
# value has increased compared with the previous step.
for i in range(pca.components_.shape[0]):
  i1 += 1
  x += pca.explained_variance_ratio_[i]
  y.append(x % lim)
  if len(y)>1:
    if y[i-1]<y[i] and x>lim:
      break
# Variance explained by the first i1-1 components.
x1 = 0
for i in range(i1-1):
  x1 += pca.explained_variance_ratio_[i]
print("remove 5 bottom elements from total features that explain " , x1, " of variance.", sep='')
n_pcs_new = i1 - 1

# For every retained component collect:
#   a1: feature indices sorted by ascending absolute loading,
#   a2: the absolute loadings strictly below the 6th-smallest one,
#   x : the feature index (first occurrence) of each value in a2.
# The absolute loadings are computed once per component; the original comprehensions
# recomputed them for every feature.
a1, a2, x = [], [], []
for comp in range(n_pcs_new):
  abs_loadings = np.abs(pca.components_[comp])
  order = [b[0] for b in sorted(enumerate(abs_loadings), key=lambda pair: pair[1])]
  cutoff = abs_loadings[order[5]]
  below_cutoff = [v for v in abs_loadings if v < cutoff]
  loadings_list = abs_loadings.tolist()
  a1.append(order)
  a2.append(below_cutoff)
  x.append([loadings_list.index(v) for v in below_cutoff])
# Order-preserving de-duplication of the collected feature indices; these features are dropped later.
unique = list(dict.fromkeys(np.concatenate(x)))
torch.save(unique,'/content/drive/My Drive/datasets/unique_tmp5.pt')

# Sparse TF-gene connectivity matrix in Matrix Market format.
n_sparse = io.mmread('/content/drive/My Drive/datasets/genematrix_tmp.txt')

# Row (gene) and column (transcription factor) names of the connectivity matrix.
var_names = np.genfromtxt('/content/drive/My Drive/datasets/GRN_3_mat_tmp_rownames.txt', dtype=str)
col_names = np.genfromtxt('/content/drive/My Drive/datasets/GRN_3_mat_tmp_colnames.txt', dtype=str)

# The last entry of var_names is not used as a row label.
gene_mat = [
    pd.DataFrame(n_sparse.toarray(), columns=col_names, index=var_names[:-1])
]

# Re-wrap the dense values with a positional index, then drop the rows of the features
# selected for removal (`unique`).
list_of_arrays1 = [df.to_numpy() for df in gene_mat]
df3 = pd.DataFrame(list_of_arrays1[0])
df3_tr = df3.drop(index=df3.index[unique])
list_of_arrays1 = df3_tr.to_numpy()

# Tensor form of the pruned connectivity matrix (a leading dimension of size 1 would be squeezed away).
init_layer_lat = torch.squeeze(torch.tensor(np.stack(list_of_arrays1)), 0)
print(np.shape(init_layer_lat))

#_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ Ib. Plotting PCs
# Refit with the k components that reach 90% of the variance and project train and test data.
pca = PCA(n_components=k)
X_train_pca=pca.fit_transform(X_train_scl)
X_test_pca=pca.transform(X_test_scl)
# Cumulative explained variance in percent, one bar per component.
var_exp = pca.explained_variance_ratio_.cumsum()
var_exp = var_exp*100
fig = plt.figure(1, figsize = (10, 6))
plt.bar(range(k), var_exp);

# Separate 3-component PCA, used only for visualisation.
pca3 = PCA(n_components=3).fit(X_train_scl)
X_train_reduced = pca3.transform(X_train_scl)

import matplotlib.colors as mcolors
fig = plt.figure(2, figsize=(10,6 ))
# Create the 3-D axes through the figure. Constructing `Axes3D(fig, ...)` directly no longer
# attaches the axes to the figure in current matplotlib, which leaves the plot empty.
ax = fig.add_subplot(111, projection='3d')
ax.view_init(elev=-150, azim=110)
ax.scatter(X_train_reduced[:, 0], X_train_reduced[:, 1], X_train_reduced[:, 2], c = y_train.iloc[:,1], cmap = mcolors.ListedColormap(["blue", "red"]), linewidths=10)
ax.set_title("First three PCA directions")
# `ax.xaxis` / `ax.yaxis` / `ax.zaxis` replace the removed `w_xaxis` / `w_yaxis` / `w_zaxis` aliases.
ax.set_xlabel("1st eigenvector")
ax.xaxis.set_ticklabels([])
ax.set_ylabel("2nd eigenvector")
ax.yaxis.set_ticklabels([])
ax.set_zlabel("3rd eigenvector")
ax.zaxis.set_ticklabels([])

fig = plt.figure(3, figsize = (10, 6))
plt.scatter(X_train_reduced[:, 0],  X_train_reduced[:, 1], c = y_train.iloc[:,1].apply(pd.to_numeric), cmap = mcolors.ListedColormap(["blue", "red"]), linewidths=10)
plt.title("2D Transformation of the Above Graph to x-y plane")

fig = plt.figure(4, figsize = (10, 6))
plt.scatter(X_train_reduced[:, 0],  X_train_reduced[:, 2], c = y_train.iloc[:,1].apply(pd.to_numeric), cmap = mcolors.ListedColormap(["blue", "red"]), linewidths=10)
plt.title("2D Transformation of the Above Graph to x-z plane")


In [None]:
#_______________________________________________________________________________ I. Load the above cleaned data to GPU, 
#                                                                                   and pytorch DataLoaders for further use

# Remove the feature columns listed in `unique` from the scaled training matrix.
# NOTE(review): X_train_scl / X_test_scl are overwritten in place, so re-running this cell
#               without re-running the scaling cell drops columns a second time.
df1 = pd.DataFrame(X_train_scl)
df1_tr = df1.drop(columns=df1.columns[unique])
print("input of train data:")
print(df1_tr.shape)
X_train_scl = df1_tr.to_numpy()

# Same column removal for the scaled test matrix.
df2 = pd.DataFrame(X_test_scl)
df2_te = df2.drop(columns=df2.columns[unique])
print("input of test data:")
print(df2_te.shape)
X_test_scl = df2_te.to_numpy()

# Number of remaining input features; persisted to Drive.
inp_pad = X_train_scl.shape[1]
torch.save(inp_pad,'/content/drive/My Drive/datasets/inp_pad_tmp5.pt')

#_______________________________________________________________________________ II. load to GPU
# Features become float32 tensors; labels are taken from the second column (Disease) of the label frames.
n14_train = torch.from_numpy(np.array(X_train_scl)).float()
n14_test = torch.from_numpy(np.array(X_test_scl)).float()
n14_train_trait = torch.from_numpy(np.array(y_train))[:, 1]
n14_test_trait = torch.from_numpy(np.array(y_test))[:, 1]

# Use the first CUDA device when one is available, otherwise stay on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if device == 'cuda:0':
  n14_train, n14_train_trait, n14_test, n14_test_trait = (
      t.to(device) for t in (n14_train, n14_train_trait, n14_test, n14_test_trait)
  )
  # Wait until the host-to-device copies have finished.
  torch.cuda.synchronize()

#_______________________________________________________________________________ III. Load to pytorch DataLoaders()
# Validation labels: start from all training-set labels, then remove the rows that were
# drawn into the balanced training subsample (`idx_train`).
y_val = pd.DataFrame(np.array(scz_data2_train_trait)[:,1]).astype(int)
y_val.insert(0, "Patient", np.arange(len(y_val)), allow_duplicates=True)
y_val = y_val.rename(columns={0: 'Disease'})

mask1 = np.ones(y_val['Disease'].shape, bool)
# BUG FIX: `idx_train` holds integer row positions. The original `mask1[~idx_train]` applied a
# bitwise NOT, turning each index i into -(i+1), so it excluded mirrored rows counted from the
# end and left the training rows inside the validation set.
mask1[idx_train] = False
y_val = y_val[mask1]
y_val = y_val.dropna(axis = 0, how = 'all')
# class-count check; not displayed, since it is not the cell's last expression
y_val.loc[y_val['Disease']==1].agg(['nunique','count','size'])


# Validation features: same row exclusion, applied through a full-shape boolean mask.
X_val_fl = pd.DataFrame(np.array(scz_data2_train)).astype(np.float64)
mask = np.ones(X_val_fl.shape, bool)
mask[idx_train] = False           # same fix as above: exclude the training rows themselves
df_mask = pd.DataFrame(data=mask)
X_val_fl = X_val_fl[df_mask]      # masked-out rows become all-NaN ...
X_val_fl = X_val_fl.dropna(axis = 0, how = 'all')   # ... and are dropped here
X_val_scl = scaler.transform(X_val_fl)

# Drop the same low-loading feature columns as for the train and test matrices.
df4 = pd.DataFrame(X_val_scl)
df4_tr = df4.drop(df4.columns[unique], axis=1)
print("input of validation data:")
print(df4_tr.shape)
X_val_scl = df4_tr.to_numpy()


n14_val = torch.from_numpy(np.array(X_val_scl)).float()
n14_val_trait = torch.from_numpy(np.array(y_val))
n14_val_trait=n14_val_trait[:,1]

# NOTE(review): n14_train / n14_test may already live on the GPU at this point, while the loaders
#               use worker processes (num_workers=2). PyTorch advises against serving CUDA tensors
#               from multi-process loaders - confirm these loaders iterate without errors.
valid = data_utils.TensorDataset(n14_val, n14_val_trait)
valid_loader = data_utils.DataLoader(valid, batch_size=90, shuffle=True, num_workers=2)
train_dataset = data_utils.TensorDataset(n14_train, n14_train_trait)
train_loader = data_utils.DataLoader(train_dataset, batch_size=550, shuffle=True, num_workers=2)
test_dataset = data_utils.TensorDataset(n14_test, n14_test_trait)
test_loader = data_utils.DataLoader(test_dataset, batch_size=70, shuffle=False, num_workers=2)


In [None]:
##################################################################################################################
###########################_________________________ PIPELINE _________________________###########################
##################################################################################################################

In [None]:
#__________CODE I______________ 

#______OUTPUT FILES   :     iteration_list.txt

#_______________________________________________________________________________ I. DECLARING CLASSES AND FUNCTIONS

def weight_pruning(w: tf.Variable, k: float) -> tf.Variable:
    """Zero out the fraction ``k`` of the entries of ``w`` that have the smallest magnitude.

    Args:
        w: weight tensor/variable of any shape.
        k: fraction of entries to prune; the count is ``round(size(w) * k)``.

    Returns:
        A tensor with the shape of ``w`` in which the selected entries are set to zero.
    """
    # Number of entries to prune, as an int32 scalar.
    n_prune = tf.cast(tf.round(tf.size(w, out_type=tf.float32) * tf.constant(k)), dtype=tf.int32)
    flat_w = tf.reshape(w, [-1])
    # top_k on the negated magnitudes returns the positions of the smallest |w|.
    _, smallest_idx = tf.nn.top_k(tf.negative(tf.abs(flat_w)), n_prune, sorted=True, name=None)
    # Start from an all-ones mask and write zeros at the selected positions.
    keep_all = tf.Variable(tf.ones_like(flat_w, dtype=tf.float32), name="mask", trainable=False)
    mask = tf.compat.v1.scatter_nd_update(
        keep_all,
        tf.reshape(smallest_idx, [-1, 1]),
        tf.zeros([n_prune], tf.float32),
    )
    return tf.reshape(flat_w * mask, tf.shape(w))

                                      #_______________________ Ia. 3-layered network (w/o batch normalization)
class model_tmp2(nn.Module):
    """Fully connected autoencoder: in -> 669 -> 300 -> 100 -> 300 -> 669 -> in.

    Args:
        in_features: number of input (and reconstructed output) features. When omitted,
            the notebook-level global ``inp_pad`` is used, which is what the original
            zero-argument constructor did.
    """

    def __init__(self, in_features=None):
        super().__init__()
        if in_features is None:
            # Backward-compatible default: width of the pre-processed input matrix.
            in_features = inp_pad

        # Encoder
        self.enc1 = nn.Linear(in_features=in_features, out_features=669)
        self.enc2 = nn.Linear(in_features=669, out_features=300)
        self.enc3 = nn.Linear(in_features=300, out_features=100)

        # Decoder (mirror of the encoder)
        self.dec3 = nn.Linear(in_features=100, out_features=300)
        self.dec2 = nn.Linear(in_features=300, out_features=669)
        self.dec1 = nn.Linear(in_features=669, out_features=in_features)

    def forward(self, x):
        """Encode and decode ``x``; ReLU on hidden layers, sigmoid on the reconstruction."""
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))

        x = F.relu(self.dec3(x))
        x = F.relu(self.dec2(x))
        x = torch.sigmoid(self.dec1(x))
        return x
                                      #______________________ Ib. 3-layered network (with batch normalization)
class model_tmp3(nn.Module):
    """Feed-forward classifier with batch normalization: in -> 669 -> 300 -> 100 -> 1.

    Args:
        in_features: number of input features. When omitted, the notebook-level global
            ``inp_pad`` is used, which is what the original zero-argument constructor did.
    """

    def __init__(self, in_features=None):
        super().__init__()
        if in_features is None:
            # Backward-compatible default: width of the pre-processed input matrix.
            in_features = inp_pad

        self.fc1 = nn.Linear(in_features=in_features, out_features=669)
        self.fc1_bn = nn.BatchNorm1d(669)
        self.fc2 = nn.Linear(in_features=669, out_features=300)
        self.fc2_bn = nn.BatchNorm1d(300)
        self.fc3 = nn.Linear(in_features=300, out_features=100)
        self.fc3_bn = nn.BatchNorm1d(100)
        self.fc4 = nn.Linear(in_features=100, out_features=1)

    def forward(self, x):
        """Return one sigmoid output per sample (linear -> batch norm -> ReLU on each hidden layer)."""
        x = F.relu(self.fc1_bn(self.fc1(x)))
        x = F.relu(self.fc2_bn(self.fc2(x)))
        x = F.relu(self.fc3_bn(self.fc3(x)))
        x = torch.sigmoid(self.fc4(x))
        return x


#______________________________ II/III. train_gene() performs model_2 training (autoencoder) in PyTorch from scratch.
                                # After enforcing sparsity through mask (first layer*) or KL_divergence/L2 (to all other layers),
                                # we transfer the weights to model_3 for further refinement and final classification.
                                # (*We can also choose to implement sparsity to the first layer instead of applying a mask.)
#_______________________________________________________________________________ The function performs 80 training sessions, 
#                                                                                with 7 cross-validations each. In every cross-validation
#                                                                                the random.seeds must be different

def train_gene(config, checkpoint_dir='/content/drive/My Drive/datasets/chpnt/', data_dir=None):

                                          #_____________________________________IIa. In each cross-validation the random seeds must change!
                                                                               # The initialization of all seeds occur in main() below.
  checkpoint_dir1 = None                  
  checkpoint_dir = '/content/drive/My Drive/datasets/chpnt/'
  step = 0
  print("step______________________:")
  print(step)
  count_dl4 = []
  count_dl4.append(1)
  with open('/content/drive/My Drive/datasets/count_dl5.txt', 'a') as f4:
    print(count_dl4, file=f4)
  if os.path.exists('/content/drive/My Drive/datasets/count_dl5.txt'):
    count_dl4_num = open('/content/drive/My Drive/datasets/count_dl5.txt').read().count('\n')
    print("count_dl4_num________________A")
    print(count_dl4_num)
  if (count_dl4_num > 7):          #__(1/2)___number of cross-validations : 7
    print("______________FILE DELETED in train_gene() _____________")
    os.remove('/content/drive/My Drive/datasets/count_dl5.txt')
    count_dl4 = []

  count_dl4_chp = []
  count_dl4_chp.append(1)
  with open('/content/drive/My Drive/datasets/count_dl5_chp.txt', 'a') as f4:
    print(count_dl4_chp, file=f4)
  if os.path.exists('/content/drive/My Drive/datasets/count_dl5_chp.txt'):
    count_dl4_num_chp = open('/content/drive/My Drive/datasets/count_dl5_chp.txt').read().count('\n')
    print("count_dl4_num_chp________________A")
    print(count_dl4_num_chp)

  np.random.seed(1234+count_dl4_num)
  python_random.seed(1234+count_dl4_num)
  tf.random.set_seed(1234+count_dl4_num)
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
  if DEVICE =='cuda':
    torch.cuda.manual_seed_all(356+count_dl4_num)
  else:
    torch.manual_seed(356+count_dl4_num)
    
  device = "cpu"
  if torch.cuda.is_available():
    device = "cuda:0"
  if device =='cuda:0':
    torch.cuda.manual_seed_all(356+count_dl4_num)
  else:
    torch.manual_seed(356+count_dl4_num) 

################################################################################
##########_______________________ MODEL_2 _______________________###############
################################################################################

#_______________________________________________________________________________ IIb. Load model_2 to GPU

  model_2 = model_tmp2().to(device)

#_______________________________________________________________________________ IIc. Optimizers  
  if checkpoint_dir1:
        print("Loading from checkpoint.")
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        model_2.load_state_dict(model_state)
        optimizer1.load_state_dict(optimizer_state)
        step = checkpoint["step"]

  if config['optimizer'] == 'optimizer_ADAM':

    optimizer1 = optim.Adam(filter(lambda p: p.requires_grad, model_2.parameters()), 
                                lr=config.get("lr", 0.01))
    if "netD_lr" in config:
            for param_group in optimizer1.param_groups:
                param_group["lr"] = config["netD_lr"]

  elif config['optimizer'] == 'optimizer_SGD':
    optimizer1 = torch.optim.SGD([{'params':model_2.enc1.parameters()}, 
                                 {'params':model_2.enc2.parameters()},
                                 {'params':model_2.enc3.parameters()}],
                                  lr = 0.01, momentum = 5e-4)
    if "netE1_lr" in config:
                optimizer1.param_groups[0]["lr"] = config["netE1_lr"]
    if "netE1_mom" in config:
                optimizer1.param_groups[0]["momentum"] = config["netE1_mom"]
    if "netE2_lr" in config:
                optimizer1.param_groups[1]["lr"] = config["netE2_lr"]
    if "netE2_mom" in config:
                optimizer1.param_groups[1]["momentum"] = config["netE2_mom"]
    if "netE3_lr" in config:
                optimizer1.param_groups[2]["lr"] = config["netE3_lr"]
    if "netE3_mom" in config:
                optimizer1.param_groups[2]["momentum"] = config["netE3_mom"]

  elif config['optimizer'] == 'optimizer_ADAgrad':                                            
    optimizer1 = optim.Adagrad(filter(lambda p: p.requires_grad, model_2.parameters()), 
                                lr=config.get("lr", 0.01))

    if "netF_lr" in config:
            for param_group in optimizer1.param_groups:
                param_group["lr"] = config["netF_lr"]

  elif config['optimizer'] == 'optimizer_RMS':
    optimizer1 = optim.RMSprop(filter(lambda p: p.requires_grad, model_2.parameters()), 
                                lr=config.get("lr", 0.01))
        
    if "netG_lr" in config:
            for param_group in optimizer1.param_groups:
                param_group["lr"] = config["netG_lr"]
                                    
                                              #_________________________________ IId. model_2 training

  scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=config["step_size1"], gamma=config["gamma1"])

  criterion = nn.BCELoss()
  #criterion = nn.MSELoss()
  test_loss = [] 
  train_loss = []
  tmp_acc = []
  rh0 = 0.05
  BATCH_SIZE_tr = 550
  BATCH_SIZE_te = 70
  BETA = 0.001
  nb1 = config["nb1"]
  MNIST_NUM_PIXELS = inp_pad

  summ_spars = []
  count = 3
  prm4 = []
  model_children = list(model_2.children())

  MASK.requires_grad = False   # MASKed neurons cannot be trained

  if torch.cuda.is_available():
    model_2.enc1.weight.register_hook(lambda grad: grad.mul_(MASK.to(device)))
    model_2.dec1.weight.register_hook(lambda grad: grad.mul_(MASK.t().to(device)))
  else:
    model_2.enc1.weight.register_hook(lambda grad: grad.mul_(MASK))
    model_2.dec1.weight.register_hook(lambda grad: grad.mul_(MASK.t()))
  
  train_losses, dev_losses, train_acc, train_pcc, dev_acc, dev_pcc, dev_fcc, dev_rcc = [], [], [], [], [], [], [], []
  for epoch in range(nb1):
      running_loss = 0.0
      features = np.zeros((BATCH_SIZE_tr, MNIST_NUM_PIXELS))
      labels = np.zeros(BATCH_SIZE_tr)

      model_2.train()
      l2_reg = 0
      features_train = n14_train
      optimizer1.zero_grad()
      sparsity = 0
      sparsity1 = 0
      sparsity2 = 0
      sparsity3 = 0
      sparsity4 = 0
      values = features_train
      features = model_2(features_train)
      labels = n14_train_trait
      mse_loss = criterion(features,features_train)

                                                  #_____________________________ IIe. We calculate sparsity only to the encoder part(3/6 layers)
      for i in range(len(model_children)):
        values = model_children[i](values)
        if (i == 0):
          values1 = torch.mean(torch.sigmoid(values), 1)
          rho1 = torch.tensor([config["RHO1"]] * len(values1)).to(device)
          sparsity1 += torch.sum(rho1 * torch.log(rho1/values1) + (1 - rho1) * torch.log((1 - rho1)/(1.0001 - values1)))
        if (i == 1):           
          values1 = torch.mean(torch.sigmoid(values), 1)
          rho2 = torch.tensor([config["RHO2"]] * len(values1)).to(device)
          sparsity2 += torch.sum(rho2 * torch.log(rho2/values1) + (1 - rho2) * torch.log((1 - rho2)/(1.0001 - values1)))
        if (i == 2):
          values1 = torch.mean(torch.sigmoid(values), 1)
          rho3 = torch.tensor([config["RHO3"]] * len(values1)).to(device)
          sparsity3 += torch.sum(rho3 * torch.log(rho3/values1) + (1 - rho3) * torch.log((1 - rho3)/(1.0001 - values1)))

      for name, param in model_2.named_parameters():
        if 'weight' in name:
          l2_reg += param.pow(2).sum() / 2

      sparsity1 = sparsity1.clone().detach().requires_grad_(True)
      sparsity2 = sparsity2.clone().detach().requires_grad_(True)
      sparsity3 = sparsity3.clone().detach().requires_grad_(True)
      sparsity = sparsity1 + sparsity2 + sparsity3

      l2_reg = l2_reg.clone().detach().requires_grad_(True)

      if config['reg'] == 'sparse_norm':
        #loss = mse_loss + config['BETA'] * sparsity   # KL Divergence (for total sparsity in the whole network)
        loss = (mse_loss + config['BETA1'] * sparsity1
                                 + config['BETA2'] * sparsity2 
                                       + config['BETA3'] * sparsity3) # KL Div (for sparsities in individual layers)
      if config['reg'] == 'L2_norm':
        loss = mse_loss + config['weight_decay'] * l2_reg   # L2 regularization

      loss.backward()
      optimizer1.step()
      scheduler1.step()
      running_loss += loss.item()
      epoch_loss = running_loss
      train_losses.append(epoch_loss)

      mse_loss = 0
      pp = 0
      running_loss = 0.0
      epoch_loss = 0.0
      loss = 0.0
      acc = 0.0
                                            # __________________________________ IIf. Model_2 testing 
      for param in model_2.parameters():
            param.requires_grad = False
      with torch.set_grad_enabled(False):
            model_2.eval()
            features_test = np.zeros((BATCH_SIZE_te, MNIST_NUM_PIXELS))
            labels_test = np.zeros(BATCH_SIZE_te)
            features_test_tmp = n14_test
            features_test = model_2(features_test_tmp)
            labels_test = n14_test_trait
            mse_loss = criterion(features_test, features_test_tmp)
            loss = mse_loss
            running_loss = loss.item()
            epoch_loss = running_loss
            dev_losses.append(epoch_loss)
                                                    #___________________________ IIg. All layers can be freely trained again
      for name, param in model_2.named_parameters():
            if 'enc1.weight' in name:
              param.requires_grad = True
            if 'enc1.bias' in name:
              param.requires_grad = True
            if 'enc2.weight' in name:
              param.requires_grad = True
            if 'enc2.bias' in name:
              param.requires_grad = True
            if 'enc3.weight' in name:
              param.requires_grad = True            
            if 'enc3.bias' in name:
              param.requires_grad = True

            if 'dec1.weight' in name:
              param.requires_grad = True
            if 'dec1.bias' in name:
              param.requires_grad = True
            if 'dec2.weight' in name:
              param.requires_grad = True
            if 'dec2.bias' in name:
              param.requires_grad = True
            if 'dec3.weight' in name:
              param.requires_grad = True            
            if 'dec3.bias' in name:
              param.requires_grad = True                                                         

  if checkpoint_dir1:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        model_2.load_state_dict(model_state)
        optimizer1.load_state_dict(optimizer_state)

  #_____________________________________________________________________________ IIh. Pruning techniques

                                                #________ IIh1. pruning layer-1 through MASK
  MASK.requires_grad = False
  
  for name, param in model_2.named_parameters():
    param.requires_grad = False
    if 'enc1.weight' in name:
      if torch.cuda.is_available():
        param1 = tf.convert_to_tensor(param.detach().cpu().numpy()) * MASK
      else:
        param1 = tf.convert_to_tensor(param) * MASK     
      parm4 = torch.from_numpy(param1.numpy())
      param.copy_(parm4)

  for name, param in model_2.named_parameters():
    param.requires_grad = False
    if 'dec1.weight' in name:
      if torch.cuda.is_available():
        param1 = tf.convert_to_tensor(param.detach().cpu().numpy()) * MASK.t()
      else:
        param1 = tf.convert_to_tensor(param) * MASK.t()      
      parm4 = torch.from_numpy(param1.numpy())
      param.copy_(parm4)

                                                #________ IIh2. pruning the rest of layers through KL_div
  for name, param in model_2.named_parameters():
    param.requires_grad = False
    if 'enc2.weight' in name:
      if torch.cuda.is_available():
        param1 = weight_pruning(tf.convert_to_tensor(param.detach().cpu().numpy()), rho2[0].cpu())
      else:
        param1 = weight_pruning(tf.convert_to_tensor(param), rho2[0])
      parm4 = torch.from_numpy(param1.numpy())
      param.copy_(parm4)
    if 'enc3.weight' in name:
      if torch.cuda.is_available():
        param1 = weight_pruning(tf.convert_to_tensor(param.detach().cpu().numpy()), rho3[0].cpu())
      else:
        param1 = weight_pruning(tf.convert_to_tensor(param),  rho3[0])          
      parm4 = torch.from_numpy(param1.numpy())
      param.copy_(parm4)     
    if 'dec2.weight' in name:
      if torch.cuda.is_available():
        param1 = weight_pruning(tf.convert_to_tensor(param.detach().cpu().numpy()), rho2[0].cpu())
      else:
        param1 = weight_pruning(tf.convert_to_tensor(param), rho2[0])           
      parm4 = torch.from_numpy(param1.numpy())
      param.copy_(parm4)      
    if 'dec3.weight' in name:
      if torch.cuda.is_available():
        param1 = weight_pruning(tf.convert_to_tensor(param.detach().cpu().numpy()), rho3[0].cpu())
      else:
        param1 = weight_pruning(tf.convert_to_tensor(param), rho3[0])           
      parm4 = torch.from_numpy(param1.numpy())
      param.copy_(parm4)      
  model_children = list(model_2.children())

  #_____________________________________________________________________________ IIi. Estimate sparsity per layer 
  with open('/content/Sparsity.txt', 'a') as f3:
    print("Sparsity in enc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model_2.enc1.weight == 0))
        / float(model_2.enc1.weight.nelement())), file=f3)
    print("Sparsity in enc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model_2.enc2.weight == 0))
        / float(model_2.enc2.weight.nelement())), file=f3)
    print("Sparsity in enc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model_2.enc3.weight == 0))
        / float(model_2.enc3.weight.nelement())), file=f3)
    print("Sparsity in dec1.weight: {:.2f}%".format(
        100. * float(torch.sum(model_2.dec1.weight == 0))
        / float(model_2.dec1.weight.nelement())), file=f3)
    print("Sparsity in dec2.weight: {:.2f}%".format(
        100. * float(torch.sum(model_2.dec2.weight == 0))
        / float(model_2.dec2.weight.nelement())), file=f3)
    print("Sparsity in dec3.weight: {:.2f}%".format(
        100. * float(torch.sum(model_2.dec3.weight == 0))
        / float(model_2.dec3.weight.nelement())), file=f3)
    print("Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model_2.enc1.weight == 0)
            + torch.sum(model_2.enc2.weight == 0)
            + torch.sum(model_2.enc3.weight == 0)
            + torch.sum(model_2.dec1.weight == 0)
            + torch.sum(model_2.dec2.weight == 0)
            + torch.sum(model_2.dec3.weight == 0))
        / float(
            model_2.enc1.weight.nelement()
            + model_2.enc2.weight.nelement()
            + model_2.enc3.weight.nelement()
            + model_2.dec1.weight.nelement()
            + model_2.dec2.weight.nelement()
            + model_2.dec3.weight.nelement())), file=f3)
    print("___________________________________________\n", file=f3) 


  ##############################################################################
  ##########_______________________ MODEL_3 _______________________#############
  ##############################################################################

  #_____________________________________________________________________________ IIIa. Load model_3 to GPU
  model_3 = model_tmp3().to(device)

#  GPUs = GPU.getGPUs()   #_____________________________________________________ IIIb. Uncomment to watch the available RAM during training
#  gpu = GPUs[0]
#  def printm():
#    process = psutil.Process(os.getpid())
#    print("Gen RAM Free: " + humanize.naturalsize(psutil.virtual_memory().available), " |     Proc size: " + humanize.naturalsize(process.memory_info().rss))
#    print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total     {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
#  printm()

  #_____________________________________________________________________________ IIIc. weight transfer from model_2 to model_3, and weight freezing
  mdl_2_name = []
  mdl_2_prms = []

  for name,param in model_2.named_parameters():
    mdl_2_name.append(name)
    mdl_2_prms.append(param)

  with torch.no_grad():
    for name, param in model_3.named_parameters():
      if 'fc1.weight' in name:
        param.copy_(mdl_2_prms[0].clone().detach().requires_grad_(False))
      if 'fc1.bias' in name:
        param.copy_(mdl_2_prms[1].clone().detach().requires_grad_(False))   
      if 'fc2.weight' in name:
        param.copy_(mdl_2_prms[2].clone().detach().requires_grad_(False))
      if 'fc2.bias' in name:
        param.copy_(mdl_2_prms[3].clone().detach().requires_grad_(False))
      if 'fc3.weight' in name:
        param.copy_(mdl_2_prms[4].clone().detach().requires_grad_(False))
      if 'fc3.bias' in name:
        param.copy_(mdl_2_prms[5].clone().detach().requires_grad_(False))
      if 'fc4.weight' in name:
        param.requires_grad = True
      if 'fc4.bias' in name:
        param.requires_grad = True   

  model_children = list(model_3.children())

  model_3_fc1_weight_mask = MASK
  model_3_fc1_bias_mask = [1 if val != 0 else val for val in model_3.fc1.bias]
  model_3_fc2_weight_mask = [[1 if val != 0 else val for val in subl] for subl in model_3.fc2.weight]
  model_3_fc2_bias_mask = [1 if val != 0 else val for val in model_3.fc2.bias]
  model_3_fc3_weight_mask = [[1 if val != 0 else val for val in subl] for subl in model_3.fc3.weight]
  model_3_fc3_bias_mask = [1 if val != 0 else val for val in model_3.fc3.bias]

                              #_________________________________________________ IIId. Update only sub-elements of weights based on masks

  if torch.cuda.is_available():
    model_3.fc1.weight.register_hook(lambda grad: grad.mul_(torch.Tensor(model_3_fc1_weight_mask).to(device)))
    model_3.fc1.bias.register_hook(lambda grad: grad.mul_(torch.Tensor(model_3_fc1_bias_mask).to(device)))
    model_3.fc2.weight.register_hook(lambda grad: grad.mul_(torch.Tensor(model_3_fc2_weight_mask).to(device)))
    model_3.fc2.bias.register_hook(lambda grad: grad.mul_(torch.Tensor(model_3_fc2_bias_mask).to(device)))
    model_3.fc3.weight.register_hook(lambda grad: grad.mul_(torch.Tensor(model_3_fc3_weight_mask).to(device)))
    model_3.fc3.bias.register_hook(lambda grad: grad.mul_(torch.Tensor(model_3_fc3_bias_mask).to(device)))
  else:
    model_3.fc1.weight.register_hook(lambda grad: grad.mul_(torch.Tensor(model_3_fc1_weight_mask)))
    model_3.fc1.bias.register_hook(lambda grad: grad.mul_(torch.Tensor(model_3_fc1_bias_mask)))
    model_3.fc2.weight.register_hook(lambda grad: grad.mul_(torch.Tensor(model_3_fc2_weight_mask)))
    model_3.fc2.bias.register_hook(lambda grad: grad.mul_(torch.Tensor(model_3_fc2_bias_mask))) 
    model_3.fc3.weight.register_hook(lambda grad: grad.mul_(torch.Tensor(model_3_fc3_weight_mask)))
    model_3.fc3.bias.register_hook(lambda grad: grad.mul_(torch.Tensor(model_3_fc3_bias_mask)))

                               #________________________________________________ IIIe. Model_3 training

  #criterion = nn.BCELoss()
  criterion = nn.MSELoss()
  nb2 = config["nb2"]                                                            
  optimizer2 = optim.Adam(filter(lambda p: p.requires_grad, model_3.parameters()), 
                       lr=config["learning_rate2"])
  scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=config["step_size2"], gamma=config["gamma2"])
    
  train_losses, dev_losses, train_acc, train_pcc, train_fcc, train_rcc, dev_acc, dev_pcc, dev_fcc, dev_rcc = [], [], [], [], [], [], [], [], [], []
  for epoch in range(nb2):                  
      running_loss_tr = 0.0
      running_acc = 0.0
      running_pcc = 0.0 
      running_fcc = 0.0 
      running_rcc = 0.0                
      pp = 0

      for param in model_3.parameters():
            param.requires_grad = True
      with torch.set_grad_enabled(True):
            model_3.train()
            features = np.zeros((BATCH_SIZE_tr, MNIST_NUM_PIXELS))
            labels = np.zeros(BATCH_SIZE_tr)
            features_train = n14_train
            optimizer2.zero_grad()
            sparsity = 0
            values = features_train
            features = model_3(features_train)
            labels = n14_train_trait
            if torch.cuda.is_available():
              mse_loss = criterion(features.squeeze(), labels.float())
            loss = mse_loss #+ BETA * sparsity          # KL Div
            #loss = mse_loss + weight_decay * l2_reg    # L2 regularization
            loss.backward()
            optimizer2.step()
            scheduler2.step()
            running_loss_tr = loss.item()
            top_class = (features>0.5).float()
            if torch.cuda.is_available():        
              running_acc = accuracy_score(labels.cpu(), top_class.detach().cpu().numpy())
              running_pcc = precision_score(labels.cpu(), top_class.detach().cpu().numpy())

              #______________________PRECISION train_________________________
              with open('/content/PRECISION_train.txt', 'a') as f5:
                print(count_dl4_num_chp, file = f5)
                print(top_class.detach().cpu().numpy(), file=f5)              
              #______________________________________________________________
              running_fcc = f1_score(labels.cpu(), top_class.detach().cpu().numpy())
              running_rcc = recall_score(labels.cpu(), top_class.detach().cpu().numpy())

            epoch_loss = running_loss_tr
            train_losses.append(epoch_loss)
            train_acc.append(running_acc)
            train_pcc.append(running_pcc)
            train_fcc.append(running_fcc)
            train_rcc.append(running_rcc)

                               #________________________________________________ IIIf. Model_3 testing
      mse_loss = 0
      pp = 0
      running_loss_te = 0.0
      epoch_loss = 0.0
      loss = 0.0
      acc = 0.0
      pcc = 0.0
      fcc = 0.0
      rcc = 0.0            
      for param in model_3.parameters():
            param.requires_grad = False
      with torch.set_grad_enabled(False):
            model_3.eval()
            features_test = np.zeros((BATCH_SIZE_te, MNIST_NUM_PIXELS))
            labels_test = np.zeros(BATCH_SIZE_te)
            features_test_tmp = n14_test
            features_test = model_3(features_test_tmp)
            labels_test = n14_test_trait
            if torch.cuda.is_available():
              mse_loss = criterion(features_test.squeeze(), labels_test.float())
            loss = mse_loss
            running_loss_te = loss.item()
            top_class_dev = (features_test>0.5).float()
            if torch.cuda.is_available():        
              acc = accuracy_score(labels_test.cpu(), top_class_dev.detach().cpu().numpy())
              pcc = precision_score(labels_test.cpu(), top_class_dev.detach().cpu().numpy())

              #______________________PRECISION test_________________________
              with open('/content/PRECISION test.txt', 'a') as f6:
                print(count_dl4_num_chp, file = f6)
                print(top_class_dev.detach().cpu().numpy(), file=f6)              
              #_____________________________________________________________
              fcc = f1_score(labels_test.cpu(), top_class_dev.detach().cpu().numpy())
              rcc = recall_score(labels_test.cpu(), top_class_dev.detach().cpu().numpy())

            epoch_loss = running_loss_te
            dev_losses.append(epoch_loss)
            dev_acc.append(acc)
            dev_pcc.append(pcc)
            dev_fcc.append(fcc)
            dev_rcc.append(rcc)
                                       #________________________________________ IIIg. Unfreezing weights
      for name, param in model_2.named_parameters():
            if 'enc1.weight' in name:
              param.requires_grad = True
            if 'enc1.bias' in name:
              param.requires_grad = True
            if 'enc2.weight' in name:
              param.requires_grad = True
            if 'enc2.bias' in name:
              param.requires_grad = True
            if 'enc3.weight' in name:
              param.requires_grad = True            
            if 'enc3.bias' in name:
              param.requires_grad = True
            if 'enc4.weight' in name:
              param.requires_grad = True            
            if 'enc4.bias' in name:
              param.requires_grad = True

      if count_dl4_num_chp % 1 == 0:     # frequency of saving pth files!!!!!!!
            chp_dir = "chp_"+str(count_dl4_num_chp) + ".pth"
            with tune.checkpoint_dir(count_dl4_num_chp) as checkpoint_dir:
                path = os.path.join(checkpoint_dir, "checkpoint")
                torch.save({"step": count_dl4_num_chp, "model_state_dict": model_2.state_dict(),
                            "optim": optimizer1.state_dict(), "accuracy": acc, "precision": pcc, "f1": fcc, "recall": rcc}, "/content/chpnt"+"/"+chp_dir)
      step += 1
      tune.report(loss = (running_loss_te), accuracy = acc, precision = pcc, f1 = fcc, recall = rcc)



#_______________________________________________________________________________ IV. Main() 

def main(num_samples=10, max_num_epochs=10, gpus_per_trial=1):
    """Run the Ray Tune hyper-parameter search over ``train_gene`` and log the best trial.

    The search space is sampled with ``BasicVariantGenerator`` and trials are
    early-stopped by an ASHA scheduler that minimises the reported ``loss``.
    After the search, the configuration and final metrics of the trial with
    the lowest last-reported ``loss`` are appended to
    ``/content/report_best_config_file.txt``.

    Args:
        num_samples: Number of times the search space is sampled. The
            ``repeat`` entry is a 7-value grid search (the number of
            cross-validations), so every sample is run once per grid value.
        max_num_epochs: Passed to ASHA as ``max_t``.
        gpus_per_trial: Number of GPUs requested for every trial.

    Returns:
        The analysis object returned by ``tune.run``.

    Raises:
        RuntimeError: If no trial reported a ``loss`` value, so that no best
            trial can be selected.
    """
                                                #_______________________________ IVa. Search space
    data_dir = os.path.abspath("checkpoint")
    # NOTE: key order is kept as-is on purpose; reordering could change which
    # random draw each hyper-parameter receives.
    config = {
        "nb1": tune.choice([200, 400]),
        "nb2": tune.choice([200, 400]),
        "optimizer": tune.choice(["optimizer_ADAM", "optimizer_SGD", "optimizer_ADAgrad", "optimizer_RMS"]),
        "reg": tune.choice(["sparse_norm", "L2_norm"]),
        "BETA": tune.loguniform(0.0001, 0.01),
        "BETA1": tune.loguniform(0.0001, 0.01),
        "BETA2": tune.loguniform(0.0001, 0.01),
        "BETA3": tune.loguniform(0.0001, 0.01),
        "weight_decay": tune.loguniform(0.0001, 0.01),
        "RHO1": tune.loguniform(0.05, 0.4),
        "RHO2": tune.loguniform(0.05, 0.4),
        "RHO3": tune.loguniform(0.05, 0.4),
        "RHO4": tune.loguniform(0.05, 0.4),
        "netD_lr": tune.loguniform(0.000001, 0.01),
        "netE1_lr": tune.loguniform(0.000001, 0.01),
        "netE2_lr": tune.loguniform(0.000001, 0.01),
        "netE3_lr": tune.loguniform(0.000001, 0.01),
        "netE1_mom": tune.loguniform(0.1, 0.99),
        "netE2_mom": tune.loguniform(0.1, 0.99),
        "netE3_mom": tune.loguniform(0.1, 0.99),
        "netF_lr": tune.loguniform(0.000001, 0.01),
        "netG_lr": tune.loguniform(0.000001, 0.01),
        "learning_rate2": tune.loguniform(0.00001, 0.01),
        "gamma2": tune.choice([0.01, 0.1, 1]),
        "step_size2": tune.choice([50, 100, 200]),
        "gamma1": tune.choice([0.01, 0.1, 1]),
        "step_size1": tune.choice([50, 100, 200]),
        "repeat": tune.grid_search(list(range(7)))  # (2/2) number of cross-validations : 7
    }

    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)
    reporter = CLIReporter(
        metric_columns=["loss", "accuracy", "training_iteration", "precision", "f1", "recall"],
        max_progress_rows=80, max_error_rows=20, max_report_frequency=240)
    # constant_grid_search=True: the sampled values are held fixed across the
    # `repeat` grid, so the 7 repeats share one hyper-parameter set.
    searcher = tune.search.basic_variant.BasicVariantGenerator(
        constant_grid_search=True)

    result = tune.run(
        partial(train_gene, data_dir=data_dir),
        name="train_gene",
        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
        search_alg=searcher,
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        verbose=1,
        progress_reporter=reporter)

    best_trial = result.get_best_trial("loss", "min", "last")
    if best_trial is None:
        # get_best_trial returns None when no trial reported the metric;
        # fail with a clear message instead of an AttributeError below.
        raise RuntimeError(
            "No trial reported a 'loss' metric; cannot select a best trial.")

                                                            #___________________ IVb. Print results to file
    with open('/content/report_best_config_file.txt', 'a') as f2:
        print("Best trial config: {}".format(best_trial.config), file=f2)
        for metric in ("loss", "accuracy", "precision", "f1", "recall"):
            print("Best trial final validation {}: {}".format(
                metric, best_trial.last_result[metric]), file=f2)
        print("______________________________________________\n", file=f2)

    return result
          
if __name__ == "__main__":

    # Start every run with a fresh checkpoint counter.
    if os.path.exists('/content/drive/My Drive/datasets/count_dl5_chp.txt'):
        os.remove('/content/drive/My Drive/datasets/count_dl5_chp.txt')

    if os.path.exists('/content/drive/My Drive/datasets/count_dl5.txt'):
        # Read through a context manager so the file handle is closed before
        # the file is removed below.
        with open('/content/drive/My Drive/datasets/count_dl5.txt') as counter_file:
            count_dl9 = counter_file.read().count('\n')
        count_dl9 = count_dl9 + 5
        print("______________CAUTION: FILE ALREADY EXISTS_____________")
        print(count_dl9)
        # NOTE(review): after the "+5" above this condition is always true, so a
        # leftover counter file is always deleted -- confirm this is intended.
        if (count_dl9 > 3):
            print("______________FILE DELETED in main()_____________")
            os.remove('/content/drive/My Drive/datasets/count_dl5.txt')
            count_dl4 = []
                                   #____________________________________________ IVc. Initialize random seeds for numpy, tensorflow and pytorch
    np.random.seed(1234)
    python_random.seed(1234)
    tf.random.set_seed(1234)
    # torch.manual_seed seeds the CPU generator as well as the CUDA ones.  The
    # previous code skipped it whenever a GPU was present, leaving CPU-side
    # random numbers unseeded.
    torch.manual_seed(356)
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    if DEVICE == 'cuda':
        torch.cuda.manual_seed_all(356)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
                                   #____________________________________________ IVd. Initialize MASK

    #a = torch.empty(669, inp_pad).uniform_(0, 1)   # for a random mask uncomment these 2 lines
    #MASK = torch.bernoulli(a)
    MASK = init_layer_lat.t().float()   # MASK based on specific Transcription factor-Gene connections
    MASK.requires_grad = False          # the mask itself is never trained

                                   #____________________________________________ IVe. You can change the number of training iterations,
#                                                                                     epochs and available GPUs here:
    main(num_samples=80, max_num_epochs=300, gpus_per_trial=1)


In [None]:
########################################################################################################################
##########################_________________________ POST-PROCESSING _________________________###########################
########################################################################################################################

In [None]:
#__________CODE II.1______________ Run this module to analyse the output from train_gene().

#______INPUT FILES   :     iteration_list.txt


import pandas as pd
import tensorflow as tf
import numpy
import numpy as np
import torch

# Number of consecutive rows in iteration_list.txt that belong to one
# hyper-parameter configuration (one row per cross-validation repeat).
N_FOLDS = 7
ITERATION_LIST = "/content/drive/My Drive/datasets/iteration_list.txt"

# Metric columns, selected by position.  Below, column 0 is treated as the loss
# and column 1 as the accuracy.
df = pd.read_csv(ITERATION_LIST, sep="|", usecols = [32,33,35,36,37], low_memory = True)
df1 = df.groupby(np.arange(len(df))//N_FOLDS).mean()   # per-configuration mean of every metric
df2 = df.groupby(np.arange(len(df))//N_FOLDS).std()    # per-configuration standard deviation

# First configuration reaching the highest mean accuracy / the lowest mean loss.
# df1 is indexed 0..n-1 by group number, so the label doubles as a position.
best_acc_group = int(df1.iloc[:,1].idxmax())
best_loss_group = int(df1.iloc[:,0].idxmin())

#_____________________________________________
print("All the metrics of the iteration with the maximum mean accuracy: ")
print(df1.iloc[best_acc_group,:])
print("____________________________________________")
print("All the metrics of the iteration with the minimum mean loss: ")
print(df1.iloc[best_loss_group,:])

print("#_________________________#_____________________________#")
print("#_________________________#_____________________________#")

# Full table (all columns); keep only the first row of every configuration.
df3 = pd.read_csv(ITERATION_LIST, sep="|")    #__ Best Parameters!!!
df3 = df3[df3.index % N_FOLDS == 0]

print("The parameters of the iteration with the maximum accuracies: ")
print(df3.iloc[best_acc_group,:])
print("- - - - - - - - - - - - - - - - - - - - - - - - - - -")
print("The parameters of the iteration with the minimum loss: ")
print(df3.iloc[best_loss_group,:])

print("_________________________________________________________________________")

# Per-repeat accuracies of the configuration with the best mean accuracy.
print("The iteration with the maximum accuracies: ")
dg1 = []
for i in range(N_FOLDS):
    fold_accuracy = df.iloc[best_acc_group*N_FOLDS + i, 1]
    print(fold_accuracy)
    dg1.append(fold_accuracy)

print("_______________________")

# Per-repeat losses of the configuration with the lowest mean loss.
print("The iteration with the minimum loss: ")
dg2 = []
for i in range(N_FOLDS):
    fold_loss = df.iloc[best_loss_group*N_FOLDS + i, 0]
    print(fold_loss)
    dg2.append(fold_loss)

# 1-based row numbers of the single best repeats; used as the checkpoint number.
minpos = dg2.index(min(dg2)) + best_loss_group*N_FOLDS + 1
maxpos = dg1.index(max(dg1)) + best_acc_group*N_FOLDS + 1

print("____________________________________")
# Print the exact file name (no embedded spaces) so it can be copied as-is.
print(f"Download chp_{minpos}.pth")           #this is the weight matrix of the model with the minimum loss
print(f"Download chp_{maxpos}.pth")           #this is the weight matrix of the model with the maximum accuracy
print("____________________________________")


In [None]:
#__________CODE II.2______________  Run this module to plot the metrics from the best of the 80 training iterations, 
#                                   over the course of the 7 cross-validations.

from matplotlib.ticker import MaxNLocator
from matplotlib import pyplot
import matplotlib.pyplot as plt
from scipy.interpolate import splrep, splev

# Row of df1 with the largest column-1 value (first row on ties). The 7
# cross-validation folds of that iteration are read from rows
# best_row*7 ... best_row*7+6 of df. The lookup is done once instead of being
# repeated for every fold/metric combination.
best_row = df1.loc[df1.iloc[:,1] == df1.iloc[:,1].max()].index[0]
fold_rows = [best_row*7 + fold for fold in range(7)]


def _per_fold(column):
  """Return seven single-element lists (one per fold) read from df[:, column]."""
  return [[df.iloc[row, column]] for row in fold_rows]


def _stack_folds(per_fold, as_column):
  """Combine the k-th entry of every fold list into one array per k.

  as_column=True gives arrays of shape (7, 1); as_column=False gives flat
  vectors of length 7 (the form passed to splrep below).
  """
  stacked = []
  for k in range(len(per_fold[0])):
    values = [fold[k] for fold in per_fold]
    if as_column:
      stacked.append(np.transpose(np.array([values])))
    else:
      stacked.append(np.array(values))
  return stacked


# One list per fold (suffixes 4, 8, ..., 28) for each metric column of df.
# Column order matches the legend of the plot drawn from these curves:
# 0 = loss, 1 = accuracy, 2 = precision, 3 = F1, 4 = recall.
rbm_folds = _per_fold(0)
dae_folds = _per_fold(1)
gbm_folds = _per_fold(2)
kae_folds = _per_fold(3)
pbm_folds = _per_fold(4)

# The individual names are kept so that any other cell relying on them still works.
rbm4, rbm8, rbm12, rbm16, rbm20, rbm24, rbm28 = rbm_folds
dae4, dae8, dae12, dae16, dae20, dae24, dae28 = dae_folds
gbm4, gbm8, gbm12, gbm16, gbm20, gbm24, gbm28 = gbm_folds
kae4, kae8, kae12, kae16, kae20, kae24, kae28 = kae_folds
pbm4, pbm8, pbm12, pbm16, pbm20, pbm24, pbm28 = pbm_folds

# Fold numbers 1..7 used as the x coordinates of the spline fit.
x = np.arange(1, 8)

# Column-shaped (7, 1) copies, used only to average each curve over the folds.
y1 = _stack_folds(rbm_folds, True)
y2 = _stack_folds(dae_folds, True)
y3 = _stack_folds(gbm_folds, True)
y4 = _stack_folds(kae_folds, True)
y5 = _stack_folds(pbm_folds, True)

# Mean over the 7 folds of every curve.
mp1 = [np.mean(curve.astype(float)) for curve in y1]
mp2 = [np.mean(curve.astype(float)) for curve in y2]
mp3 = [np.mean(curve.astype(float)) for curve in y3]
mp4 = [np.mean(curve.astype(float)) for curve in y4]
mp5 = [np.mean(curve.astype(float)) for curve in y5]

# Index of the "best" curve per metric. Column 0 is selected by its minimum
# (it is a loss), every other column by its maximum; the max_* names are kept
# for compatibility even though max_value1 actually holds a minimum.
max_value1 = min(mp1)
index_max1 = mp1.index(max_value1)
max_value2 = max(mp2)
index_max2 = mp2.index(max_value2)
max_value3 = max(mp3)
index_max3 = mp3.index(max_value3)
max_value4 = max(mp4)
index_max4 = mp4.index(max_value4)
max_value5 = max(mp5)
index_max5 = mp5.index(max_value5)
#__________________________

# Flat length-7 vectors: the y values handed to splrep.
y6 = _stack_folds(rbm_folds, False)
y7 = _stack_folds(dae_folds, False)
y8 = _stack_folds(gbm_folds, False)
y9 = _stack_folds(kae_folds, False)
y10 = _stack_folds(pbm_folds, False)

# Cubic B-spline representation (k=3) of every curve over the fold numbers.
bspl1 = [splrep(x, curve, k=3) for curve in y6]
bspl2 = [splrep(x, curve, k=3) for curve in y7]
bspl3 = [splrep(x, curve, k=3) for curve in y8]
bspl4 = [splrep(x, curve, k=3) for curve in y9]
bspl5 = [splrep(x, curve, k=3) for curve in y10]

# Dense grid on which the fitted splines are evaluated for plotting.
x_new = np.linspace(1, 7, 100)

#__________________________
# Plot the smoothed per-fold curve of each metric for the selected iteration.
ax = plt.figure(figsize=(20, 7)).gca()
ax.xaxis.set_major_locator(MaxNLocator(integer=True))   # fold numbers are integers

ax.set_title("Training Metrics")
ax.set_xlabel("cross-validation iterations")
ax.set_ylabel("metrics")

# (legend label, colour, fitted splines, index of the curve to draw) per metric.
# The label is attached to the curve itself, so the legend stays correct even
# when the selected indices differ between metrics; a positional legend would
# otherwise follow the order in which the lines happened to be added.
curves = [
  ("Loss", "grey", bspl1, index_max1),
  ("Accuracy", "green", bspl2, index_max2),
  ("Precision", "brown", bspl3, index_max3),
  ("F1", "magenta", bspl4, index_max4),
  ("Recall", "pink", bspl5, index_max5),
]

for label, colour, splines, selected in curves:
  # Evaluate only the spline that is actually drawn.
  smoothed = splev(x_new, splines[selected])
  ax.plot(x_new, smoothed.astype(float), color=colour, linewidth=6.0, label=label)

ax.legend(loc="upper right")
plt.show()
