## Train and store embeddings

This Colab notebook loads the code from the git repository, trains a GM-VGAE (Gaussian-mixture VGAE) or a plain VGAE model, and stores the result for further analysis of the embeddings.



In [1]:
from google.colab import drive
# Make Google Drive (where the repository checkout lives) visible to the kernel.
drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
#@title Set up and imports
from google.colab import drive 
import os 

import tensorflow as tf
from tensorflow.python.keras import models
from tensorflow.python.keras.engine import node
from tensorflow.python.keras.engine.node import Node 
import tensorflow_probability as tfp
tfd = tfp.distributions

import numpy as np
import scipy.sparse as sp

# Location of this notebook inside the repository. This is a relative path; it
# presumably assumes the kernel starts in /content (Colab default), next to the
# /content/drive mount point.
WORKING_PATH = './drive/MyDrive/GGMbetaFactorVAE/GVAE/'

# Mount + chdir only once per kernel session. `first_run` only exists after this
# branch has executed, so re-running the cell does not chdir into the relative
# WORKING_PATH a second time from inside the repo.
if not ('first_run' in locals()):
  drive.mount('/content/drive/')
  os.chdir(WORKING_PATH)
  first_run = False

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [3]:
from src.models import GM_VGAE, VGAE
from src import utils 

## Loading data 

In [28]:
## Loads data

network_path = 'data/diseasome/disease_network_adj.npy'
labels_path = 'data/diseasome/disease_network_types.npy'
output_path = 'data/saved/diseasome/model/'

data_params = dict(network_path = network_path,
                   labels_path = labels_path,
                   use_features=True,
                   auxiliary_prediction_task=False,
                   epochs=1000)

res = utils.load_and_build_dataset(data_params)
adj = res['adj']
aux_targets = res['target']
dataset = res['dataset']
val_edges = res['val_edges']
val_edges_false = res['val_edges_false']
test_edges = res['test_edges']
test_edges_false = ['test_edges_false']

adj_orig = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
adj_orig.eliminate_zeros()
adj_orig = adj_orig.toarray()

DONE: train_edges
DONE: test_edges_false
DONE: val_edges_false
True
True
True
True
True


## Defining a train step

In [29]:
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score


def get_roc_score(edges_pos, edges_neg, emb):
    """Score held-out edges with an inner-product decoder.

    Each edge ``(i, j)`` is assigned the probability
    ``sigmoid(emb[i] . emb[j])``; edges from ``edges_pos`` are labelled 1 and
    edges from ``edges_neg`` are labelled 0.

    Args:
        edges_pos: sequence of ``(i, j)`` node-index pairs that are true edges.
        edges_neg: sequence of ``(i, j)`` node-index pairs that are non-edges.
        emb: 2-D array of node embeddings, one row per node.

    Returns:
        ``(roc_auc, average_precision)`` from scikit-learn.
    """
    from scipy.special import expit

    emb = np.asarray(emb)

    def _edge_probs(edges):
        # Compute only the requested dot products rather than the full
        # N x N reconstruction emb @ emb.T.
        idx = np.asarray(edges, dtype=int).reshape(-1, 2)
        logits = np.einsum('ij,ij->i', emb[idx[:, 0]], emb[idx[:, 1]])
        # expit is a numerically stable sigmoid (no overflow for large |logits|).
        return expit(logits)

    preds_pos = _edge_probs(edges_pos)
    preds_neg = _edge_probs(edges_neg)

    preds_all = np.concatenate([preds_pos, preds_neg])
    labels_all = np.concatenate([np.ones(len(preds_pos)), np.zeros(len(preds_neg))])
    roc_score = roc_auc_score(labels_all, preds_all)
    ap_score = average_precision_score(labels_all, preds_all)

    return roc_score, ap_score

In [30]:

def train_step(adj_normalized, features, adj_label, norm, pos_weight, experiment_params, aux_targets=None):
    """Run one optimisation step of the VGAE / GM_VGAE and record its metrics.

    Uses the notebook globals ``model``, ``optimizer``, ``val_edges``,
    ``val_edges_false`` and appends to the metric lists ``RECONSTRUCTION``,
    ``KL_LOSSES``, ``LOSSES``, ``VAL_ROC_SCORE``, ``VAL_AP_SCORE`` (plus
    ``CLASSIFICATION_LOSSES`` / ``CLASSIFICATION_ACCURACIES`` when the
    auxiliary task is enabled).

    Args:
        adj_normalized: normalised adjacency fed to the model; its first
            dimension is used as the node count for scaling the KL term.
        features: node features fed to the model.
        adj_label: target adjacency; flattened before the reconstruction loss.
        norm: scale factor applied to the reconstruction loss.
        pos_weight: positive-class weight of the weighted cross-entropy.
        experiment_params: dict with at least ``'model'`` (``'VGAE'`` or
            ``'GM_VGAE'``) and ``'auxiliary_prediction_task'``; an optional
            ``'beta'`` sets the KL weight (default 1).
        aux_targets: targets for the auxiliary classification task
            (GM_VGAE only; must be None for VGAE).
    """
    model_type = experiment_params['model']
    use_aux_task = experiment_params['auxiliary_prediction_task']
    # KL weight of the objective; 1 unless the experiment overrides it.
    beta = experiment_params.get('beta', 1)

    assert model_type in ['VGAE', 'GM_VGAE']
    assert ((model_type == 'VGAE' and aux_targets is None) or (model_type == 'GM_VGAE'))
    # The classification head is only used with GM_VGAE; without this guard the
    # metric logging below would crash on the VGAE branch's placeholder values.
    assert not (model_type == 'VGAE' and use_aux_task), \
        'auxiliary prediction can only be done with GM_VGAE'

    classification_loss = 0
    classification_accuracy = None

    with tf.GradientTape() as tape:
        adj_label = tf.reshape(adj_label, [-1])

        Q, Q_log_std, reconstructed = model(adj_normalized, features)
        reconstruction_loss = norm * tf.math.reduce_mean(
            tf.nn.weighted_cross_entropy_with_logits(labels=adj_label, logits=reconstructed, pos_weight=pos_weight)
        )
        node_num = adj_normalized.shape[0]
        if model_type == 'VGAE':
            # KL between posterior and prior via tfd.kl_divergence.
            kl = tf.reduce_mean(tfd.kl_divergence(Q, model.prior)) / node_num
        else:
            # Q is a mixture here; the KL comes from utils.mc_kl_divergence
            # (presumably a Monte Carlo estimate — confirm in src/utils).
            kl = tf.reduce_mean(utils.mc_kl_divergence(Q, model.prior)) / node_num

            if use_aux_task:
                classification_loss = tf.reduce_mean(
                    tf.nn.softmax_cross_entropy_with_logits(logits=model.cy_logits, labels=aux_targets)
                )
                classification_accuracy = tf.reduce_mean(
                    tf.cast(tf.argmax(model.cy_logits, axis=1) == tf.argmax(aux_targets, axis=1), tf.float32), axis=0
                )

        vae_loss = reconstruction_loss + beta * kl + classification_loss

    gradients = tape.gradient(vae_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # metrics
    RECONSTRUCTION.append(reconstruction_loss.numpy())
    KL_LOSSES.append(kl.numpy())
    LOSSES.append(vae_loss.numpy())
    if use_aux_task:
        CLASSIFICATION_LOSSES.append(classification_loss.numpy())
        CLASSIFICATION_ACCURACIES.append(classification_accuracy.numpy())

    if model_type == 'GM_VGAE':
        # Taking the mean of the mixture doesn't work here, so a single sample
        # is used as the embedding. NOTE: the mean could be approximated instead.
        emb = tf.squeeze(Q.sample(1), axis=0).numpy()
    else:
        emb = Q.mean().numpy()

    roc_curr, ap_curr = get_roc_score(val_edges, val_edges_false, emb)
    VAL_ROC_SCORE.append(roc_curr)
    VAL_AP_SCORE.append(ap_curr)
    

## Experiment 

In [31]:
## Initializes experiment

experiment_params = dict(
    learning_rate=1e-3,
    epochs=data_params['epochs'],
    hidden=32,
    latent_size=16,
    dropout=0.2,
    model='GM_VGAE',  # or 'VGAE'
    use_features=data_params['use_features'],
    auxiliary_prediction_task=data_params['auxiliary_prediction_task'],
    save_path=output_path,
)

# auxiliary prediction can only be done with GM_VGAE
assert not (experiment_params['model'] == 'VGAE' and experiment_params['auxiliary_prediction_task'])

optimizer = tf.keras.optimizers.Adam(learning_rate=experiment_params['learning_rate'])

# Metric histories; train_step appends to these on every call.
RECONSTRUCTION = []
KL_LOSSES = []
LOSSES = []
CLASSIFICATION_LOSSES = []
CLASSIFICATION_ACCURACIES = []
VAL_ROC_SCORE = []
VAL_AP_SCORE = []

# Class-imbalance terms for the reconstruction loss. Assuming a binary adjacency,
# adj.sum() counts the edges, so pos_weight is the ratio of zero to one entries.
node_num = adj.shape[0]
num_entries = node_num * node_num
num_nonzero = adj.sum()
pos_weight = float(num_entries - num_nonzero) / num_nonzero
norm = num_entries / float((num_entries - num_nonzero) * 2)

class_num = aux_targets.shape[1]

if experiment_params['model'] == 'VGAE':
    model = VGAE(node_num=node_num,
                 hidden=experiment_params['hidden'],
                 latent_size=experiment_params['latent_size'],
                 dropout=experiment_params['dropout'])
elif experiment_params['model'] == 'GM_VGAE':
    model = GM_VGAE(node_num=node_num,
                    class_num=class_num,
                    latent_size=experiment_params['latent_size'],
                    hidden=experiment_params['hidden'],
                    dropout=experiment_params['dropout'])
else:
    # Without this, an unknown model name would silently keep a `model`
    # left over from an earlier execution of this cell.
    raise ValueError('unknown model type: {!r}'.format(experiment_params['model']))


In [None]:
#@title Training loop

LOG_EVERY = 100  # training steps between progress prints

# Only GM_VGAE takes the auxiliary targets; train_step requires None for VGAE.
step_aux_targets = aux_targets if experiment_params['model'] == 'GM_VGAE' else None

for e, (adj_norm, features, label) in enumerate(dataset):
    train_step(adj_norm, features,
               label, norm, pos_weight,
               experiment_params, aux_targets=step_aux_targets)

    if e % LOG_EVERY == 0:
        if experiment_params['auxiliary_prediction_task']:
            print(
                'total: {:.2f}, rec: {:.2f}, classification acc: {:.2f}, kl_loss: {:.2f}'.format(
                    LOSSES[-1], RECONSTRUCTION[-1], CLASSIFICATION_ACCURACIES[-1], KL_LOSSES[-1]))
        else:
            print('total: {:.2f}, rec: {:.2f}, kl_loss: {:.2f}'.format(
                LOSSES[-1], RECONSTRUCTION[-1], KL_LOSSES[-1]))

# TODO: evaluate the final embedding on the held-out test edges, e.g.
# test_roc, test_ap = get_roc_score(test_edges, test_edges_false, emb)

total: 1.90, rec: 1.90, kl_loss: 0.00
total: 1.11, rec: 1.10, kl_loss: 0.00
total: 0.76, rec: 0.74, kl_loss: 0.02
total: 0.71, rec: 0.68, kl_loss: 0.03
total: 0.62, rec: 0.59, kl_loss: 0.03
total: 0.56, rec: 0.52, kl_loss: 0.03
total: 0.53, rec: 0.49, kl_loss: 0.04
total: 0.51, rec: 0.48, kl_loss: 0.04
total: 0.50, rec: 0.47, kl_loss: 0.04


In [None]:
#@title Diagnosis plots - 
#@markdown 

import matplotlib.pyplot as plt 
import seaborn as sns 
from sklearn.manifold import TSNE
import numpy as np


# Draw one latent sample per node from the model's posterior and project it to 2-D.
# axis=0 removes only the sample dimension (as in train_step), so a latent
# size of 1 is not squeezed away as well.
Q = model.Q
z = tf.squeeze(Q.sample(1), axis=0).numpy()

z_proj = TSNE(n_components=2).fit_transform(z)

fig, axs = plt.subplots(3, 2, figsize=(20, 15))
axs = axs.flatten()

axs[0].plot(KL_LOSSES)
axs[0].set_title('KL loss')
axs[0].set_xlabel('training step')

axs[1].plot(RECONSTRUCTION)
axs[1].set_title('Reconstruction loss')
axs[1].set_xlabel('training step')

if experiment_params['auxiliary_prediction_task']:
    axs[2].plot(CLASSIFICATION_ACCURACIES)
    axs[2].set_title('Classification accuracy')
    axs[2].set_xlabel('training step')
else:
    # Nothing to show without the auxiliary task; hide the empty panel.
    axs[2].axis('off')

# Colour nodes by type. argmax gives one label per row (assumes aux_targets is
# one-hot); np.where(aux_targets == 1) would mis-size the hue vector otherwise.
node_types = np.argmax(aux_targets, axis=1)
sns.scatterplot(
    x=z_proj[:, 0], y=z_proj[:, 1],
    palette=sns.color_palette("hls", aux_targets.shape[1]),
    hue=node_types,
    legend='full',
    alpha=0.8,
    ax=axs[3]
)
axs[3].set_title('TSNE projection')

def moving_average(x, w):
    """Moving average of ``x`` over windows of ``w`` consecutive samples.

    Returns ``len(x) - w + 1`` values ('valid' convolution); value ``i`` is the
    mean of ``x[i:i + w]``. If ``x`` has fewer than ``w`` samples an empty array
    is returned. Without that guard, ``np.convolve`` would swap its operands and
    return meaningless values.

    Raises:
        ValueError: if ``w`` is smaller than 1.
    """
    if w < 1:
        raise ValueError('window size w must be >= 1, got {}'.format(w))
    x = np.asarray(x, dtype=float)
    if x.size < w:
        return np.empty(0)
    return np.convolve(x, np.ones(w), 'valid') / w

SMOOTHING_WINDOW = 20  # samples per moving-average window over the validation scores

for ax, title, scores in [(axs[4], 'validation ROC score', VAL_ROC_SCORE),
                          (axs[5], 'validation AP score', VAL_AP_SCORE)]:
    ax.set_title(title)
    ax.plot(scores, label='raw')
    smoothed = moving_average(scores, SMOOTHING_WINDOW)
    # A 'valid' moving average loses the first w-1 points. Plot each average at
    # the last step of its window instead of shifting the curve left by w-1.
    ax.plot(np.arange(SMOOTHING_WINDOW - 1, SMOOTHING_WINDOW - 1 + len(smoothed)), smoothed,
            label='moving average ({})'.format(SMOOTHING_WINDOW))
    ax.set_xlabel('training step')
    ax.legend()

fig.tight_layout()
plt.show()

In [9]:
# Persist the trained model weights under the configured save path.
# Only the weights are written here; this cell does not save the embeddings.
weights_path = experiment_params['save_path']

model.save_weights(weights_path)

## TODO 

In [10]:
## TODO what is going on with roc and precision? train a simple gvae, see what happens 
## TODO generate save_path string from experiment params 
## TODO script for multiple training runs with different experiment params and checkpoint generation  
## TODO when we're sampling in the GMVGAE we might as well sample the elbo directly 