In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
%matplotlib inline

In [2]:
#Model updated for TF2.0
#python -m ae_model_train --batchsize 100 --cvfold 0 --alpha_T 1.0 --alpha_E 1.0 --alpha_M 1.0 --lambda_TE 0.0 --latent_dim 3 --n_epochs 2000 --n_steps_per_epoch 500 --ckpt_save_freq 100 --run_iter 0 --model_id 'v1' --exp_name 'TE_Patchseq_Bioarxiv'
import argparse
import os
import pdb
import re
import socket
import sys
import timeit

import numpy as np
import scipy.io as sio
import tensorflow as tf
from tensorflow.keras import layers
from data_funcs import TEM_get_splits
from ae_model_def import Model_TE
import csv
from timebudget import timebudget

from ae_model_train import Datagen
from ae_model_train import set_paths

# ---- Run configuration ----
# Same names as the flags of the `ae_model_train` command line quoted at the top of this cell.
batchsize = 200
cvfold = 0
alpha_T = 1.0
alpha_E = 1.0
alpha_M = 1.0
lambda_TE = 1.0
latent_dim = 3
n_epochs = 100
n_steps_per_epoch = 500
ckpt_save_freq = 100
run_iter = 0
model_id = 'Frozen_net'
exp_name = 'TE_Patchseq_Bioarxiv'

# Dictionary of project directories; the 'data' and 'result' entries are used below.
dir_pth = set_paths(exp_name=exp_name)

# Identifier embedded in output file names:
# <model_id>_aT_<..>_aE_<..>_aM_<..>_cs_<..>_ld_<..>_bs_<..>_se_<..>_ne_<..>_cv_<..>_ri_<..>
# Every '.' (e.g. in the float-valued weights) is replaced by '-' afterwards.
fileid = model_id + ''.join(
    '_{}_{}'.format(tag, value)
    for tag, value in (('aT', alpha_T),
                       ('aE', alpha_E),
                       ('aM', alpha_M),
                       ('cs', lambda_TE),
                       ('ld', latent_dim),
                       ('bs', batchsize),
                       ('se', n_steps_per_epoch),
                       ('ne', n_epochs),
                       ('cv', cvfold),
                       ('ri', run_iter)))
fileid = fileid.replace('.', '-')

In [3]:
# Load data:
D = sio.loadmat(dir_pth['data'] + 'PS_v4_beta_0-4_matched_well-sampled.mat', squeeze_me=True)
cvset, testset = TEM_get_splits(D)


def _append_M_column(E_block, M_values):
    """Return `E_block` with `M_values` appended as one extra trailing column."""
    return np.concatenate([E_block, M_values.reshape(M_values.size, 1)], axis=1)


# Training split of the selected cross-validation fold.
# train_E_dat ends up holding the E features with the M values as its last column.
train_ind = cvset[cvfold]['train']
train_T_dat = tf.constant(D['T_dat'][train_ind, :])
train_M_dat = D['M_dat'][train_ind]
train_E_dat = tf.constant(_append_M_column(D['E_dat'][train_ind, :], train_M_dat))

# Validation split, assembled the same way (val_T_dat itself stays a numpy array).
val_ind = cvset[cvfold]['val']
val_T_dat = D['T_dat'][val_ind, :]
val_M_dat = D['M_dat'][val_ind]
val_E_dat = tf.constant(_append_M_column(D['E_dat'][val_ind, :], val_M_dat))
Xval = (tf.constant(val_T_dat), tf.constant(val_E_dat))

# From here on the scalar settings are TensorFlow constants rather than Python numbers.
maxsteps = tf.constant(n_epochs * n_steps_per_epoch)
batchsize = tf.constant(batchsize)
alpha_T, alpha_E, alpha_M, lambda_TE = (
    tf.constant(weight, dtype=tf.float32)
    for weight in (alpha_T, alpha_E, alpha_M, lambda_TE))

def min_var_loss(zi, zj, Wij=None):
    """Weighted distance between paired rows, scaled by the smallest variance of `zj`.

    The singular values (and hence the variances) are computed from every row
    of `zj`; the distance term only uses rows whose weight exceeds 1e-2.

    Args:
        zi, zj: tensors sharing the same leading (batch) dimension.
        Wij: optional per-row weights, reshaped to `[batch]`. All ones when None.

    Returns:
        Scalar tensor: mean of (weight * Euclidean distance between zi and zj rows)
        over the retained rows, divided by max(min variance of centered zj, 1e-2).
    """
    n_rows = tf.shape(zi)[0]
    weights = tf.ones([n_rows, ]) if Wij is None else tf.reshape(Wij, [n_rows, ])

    # Rows with weight <= 1e-2 are dropped from the distance term.
    is_paired = tf.math.greater(weights, 1e-2)
    zi_kept = tf.boolean_mask(zi, is_paired)
    zj_kept = tf.boolean_mask(zj, is_paired)
    weights_kept = tf.boolean_mask(weights, is_paired)

    floor = tf.cast(1e-2, dtype=tf.float32)

    # Squared singular values of the centered batch, divided by (batch - 1).
    zj_centered = zj - tf.reduce_mean(zj, axis=0)
    singular_values = tf.linalg.svd(zj_centered, compute_uv=False)
    raw_vars = tf.square(singular_values) / tf.cast(n_rows - 1, tf.float32)
    # NaN entries are replaced by the floor value.
    safe_vars = tf.where(tf.math.is_nan(raw_vars), tf.zeros_like(raw_vars) + floor, raw_vars)

    row_distance = tf.sqrt(tf.reduce_sum(tf.math.squared_difference(zi_kept, zj_kept), axis=1))
    mean_weighted_distance = tf.reduce_mean(tf.multiply(row_distance, weights_kept), axis=None)

    # The denominator is clamped from below so a collapsed direction cannot blow up the loss.
    return mean_weighted_distance / tf.maximum(tf.reduce_min(safe_vars, axis=None), floor)

def report_losses(XT, XE, zT, zE, XrT, XrE, epoch, datatype='train', verbose=False):
    """Compute reconstruction/coupling MSEs and return them as parallel name/value lists.

    Args:
        XT, XE: inputs; XrT, XrE: their reconstructions; zT, zE: latent representations.
        epoch: value stored as the first log entry (and printed when verbose).
        datatype: prefix prepended verbatim to every log name (e.g. 'train_', 'val_').
        verbose: when True, also print one formatted line with all metrics.

    Returns:
        (log_name, log_values): names are `datatype` + one of
        'epoch', 'mse_T', 'mse_E', 'mse_M', 'mse_TE'; values are `epoch`
        followed by the four metrics converted with `.numpy()`.
    """
    def _mse(a, b):
        return tf.reduce_mean(tf.math.squared_difference(a, b))

    # NOTE(review): mse_E here is taken over every column of XE (including the
    # last one, which is reported separately as mse_M), whereas train_fn excludes
    # the last column from its E term - confirm this difference is intended.
    keys = ['mse_T', 'mse_E', 'mse_M', 'mse_TE']
    values = [
        _mse(XT, XrT).numpy(),
        _mse(XE, XrE).numpy(),
        _mse(XE[:, -1], XrE[:, -1]).numpy(),
        _mse(zT, zE).numpy(),
    ]

    if verbose:
        template = ('Epoch:{:5d}, '
                    'mse_T: {:0.5f}, '
                    'mse_E: {:0.5f}, '
                    'mse_M: {:0.5f}, '
                    'mse_TE: {:0.5f}')
        print(template.format(epoch, *values))

    log_name = [datatype + key for key in ['epoch'] + keys]
    log_values = [epoch] + values
    return log_name, log_values


# Adam optimizer (learning rate 1e-3). This single instance is the one used by
# both train_fn and finetune_fn below.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# tf.data pipeline wrapping the project-level `Datagen` generator, which is called
# with (maxsteps, batchsize, train_T_dat, train_E_dat) and is declared to yield
# pairs of float32 tensors.
# NOTE(review): train_generator is not consumed by any cell visible in this part of
# the notebook; the cells below train on the full train_T_dat / train_E_dat arrays.
train_generator = tf.data.Dataset.from_generator(Datagen,output_types=(tf.float32, tf.float32),
                                                 args=(maxsteps,batchsize,train_T_dat,train_E_dat))

# Coupled T/E autoencoder defined in ae_model_def. The output dimensions are the
# column counts of the training matrices (for E this includes the appended M column).
# The remaining keyword values are hard-coded here; their exact meaning is defined
# by Model_TE - presumably hidden-layer widths, dropout rates and the s.d. of the
# Gaussian noise on the E input (TODO confirm against ae_model_def).
model_TE = Model_TE(T_output_dim=train_T_dat.shape[1],
                    E_output_dim=train_E_dat.shape[1],
                    T_intermediate_dim=50,
                    E_intermediate_dim=40,
                    T_dropout=0.5,
                    E_gnoise_sd=0.05,
                    E_dropout=0.1,
                    latent_dim=latent_dim,
                    name='TE')

In [4]:
# Perform dummy inference to build the model:
# one forward pass on random single-row inputs (with the same number of columns as
# the training matrices) is run before load_weights is called. The call is expected
# to return four values, which are discarded.
x = tf.constant(np.random.rand(1,train_T_dat.shape[1]),dtype=tf.float32)
y = tf.constant(np.random.rand(1,train_E_dat.shape[1]),dtype=tf.float32)
_,_,_,_ = model_TE((x,y),train_T=False,train_E=False)

# Loading weights:
# NOTE(review): this is an absolute, machine-specific path, so the cell fails on any
# other machine. The file name also belongs to a different run ('v1', ne_1500) than
# the `fileid` built in this notebook ('Frozen_net', ne_100), so it cannot be derived
# from `fileid`. Presumably the directory equals dir_pth['result'] - confirm and
# build the path from it.
model_TE.load_weights('/Users/fruity/Dropbox/AllenInstitute/CellTypes/dat/result/TE_Patchseq_Bioarxiv/v1_aT_1-0_aE_1-0_aM_1-0_cs_1-0_ld_3_bs_200_se_500_ne_1500_cv_0_ri_0-weights.h5')

In [5]:
@tf.function
def train_fn(XT, XE, train_T=False, train_E=False, subnetwork='all'):
    """Run one forward pass of `model_TE` and optionally apply one optimizer step.

    The loss is
        alpha_T*mse_T + alpha_E*mse_E + alpha_M*mse_M + lambda_TE*min_var_loss(zT, zE)
    where mse_E is taken over all but the last column of XE and mse_M over the
    last column only.

    Args:
        XT, XE: inputs, passed to `model_TE` as the pair (XT, XE).
        train_T, train_E: forwarded to `model_TE`; per the notebook's comments they
            control the dropout and noise addition of the respective arm.
        subnetwork: which weights receive the update.
            'all' - every trainable weight of `model_TE`;
            'E'   - only trainable weights whose name contains '_E';
            anything else (e.g. None) - forward pass only, no update.

    Returns:
        (zT, zE, XrT, XrE) exactly as returned by `model_TE`.
    """
    with tf.GradientTape() as tape:
        zT, zE, XrT, XrE = model_TE((XT, XE), train_T=train_T, train_E=train_E)

        mse_loss_T = tf.reduce_mean(tf.math.squared_difference(XT, XrT))
        mse_loss_E = tf.reduce_mean(tf.math.squared_difference(XE[:, :-1], XrE[:, :-1]))
        mse_loss_M = tf.reduce_mean(tf.math.squared_difference(XE[:, -1], XrE[:, -1]))
        cpl_loss_TE = min_var_loss(zT, zE)
        loss = alpha_T*mse_loss_T + \
            alpha_E*mse_loss_E + \
            alpha_M*mse_loss_M + \
            lambda_TE*cpl_loss_TE

    # Find the weights to update. Strings are compared by value (==): an identity
    # test (`is`) against a literal only works by accident of string interning.
    if subnetwork == 'all':
        trainable_weights = list(model_TE.trainable_weights)
    elif subnetwork == 'E':
        trainable_weights = [weight for weight in model_TE.trainable_weights if '_E' in weight.name]
    else:
        trainable_weights = None

    # Apply updates only if one of the subnetworks was selected.
    if trainable_weights is not None:
        grads = tape.gradient(loss, trainable_weights)
        optimizer.apply_gradients(zip(grads, trainable_weights))

    return zT, zE, XrT, XrE

In [6]:
@tf.function
def finetune_fn(XT, XE, train_T=False, train_E=False, subnetwork=None):
    """Forward pass through `model_TE`, optionally fine-tuning one E sub-network.

    Args:
        XT, XE: inputs, passed to `model_TE` as the pair (XT, XE).
        train_T, train_E: forwarded unchanged to `model_TE`.
        subnetwork: selects both the loss and the weights that are updated.
            'Encoder_E' - loss is the MSE between zT and zE;
            'Decoder_E' - loss is mse_M (last column of XE) + mse_E (other columns);
            anything else (e.g. None) - no weights are updated.
            For the two named cases only trainable weights whose name contains
            the `subnetwork` string receive the update.

    Returns:
        (zT, zE, XrT, XrE) exactly as returned by `model_TE`.
    """
    do_update = subnetwork in ('Encoder_E', 'Decoder_E')

    with tf.GradientTape() as tape:
        zT, zE, XrT, XrE = model_TE((XT, XE), train_T=train_T, train_E=train_E)

        if subnetwork == 'Encoder_E':
            # Pull the E latent representation towards the T one.
            loss = tf.reduce_mean(tf.math.squared_difference(zT, zE))
        elif subnetwork == 'Decoder_E':
            # Reconstruction error of the E features and of the M column.
            recon_E = tf.reduce_mean(tf.math.squared_difference(XE[:, :-1], XrE[:, :-1]))
            recon_M = tf.reduce_mean(tf.math.squared_difference(XE[:, -1], XrE[:, -1]))
            loss = recon_M + recon_E

    if do_update:
        selected = [w for w in model_TE.trainable_weights if subnetwork in w.name]
        grads = tape.gradient(loss, selected)
        optimizer.apply_gradients(zip(grads, selected))

    return zT, zE, XrT, XrE

In [7]:
#Checks to make sure training works as expected:
#train_T and train_E control the dropout and noise addition
#Every check runs `test_iters` steps on the full training set and prints the losses after each step.
test_iters = 5

#Full network train step (data augmentation on for both arms, all weights updated):
print('Expect everything to change:--------------')
for epoch in range(test_iters):
    zT, zE, XrT, XrE = train_fn(XT=train_T_dat, XE=train_E_dat, train_T=True, train_E=True, subnetwork='all')
    train_log_name, train_log_values = report_losses(train_T_dat, train_E_dat, zT, zE, XrT, XrE, epoch=epoch, datatype='train_', verbose=True)

#Switch off data augmentation for the T arm and update only E arm:
print('\nExpect mse_T to be unchanged:--------------')
for epoch in range(test_iters):
    zT, zE, XrT, XrE = train_fn(XT=train_T_dat, XE=train_E_dat, train_T=False, train_E=True, subnetwork='E')
    train_log_name, train_log_values = report_losses(train_T_dat, train_E_dat, zT, zE, XrT, XrE, epoch=epoch, datatype='train_', verbose=True)

#Switch off data augmentation for both arms and select no subnetwork, so no weights are updated:
print('\nExpect nothing to change:--------------')
for epoch in range(test_iters):
    zT, zE, XrT, XrE = train_fn(XT=train_T_dat, XE=train_E_dat, train_T=False, train_E=False, subnetwork=None)
    train_log_name, train_log_values = report_losses(train_T_dat, train_E_dat, zT, zE, XrT, XrE, epoch=epoch, datatype='train_', verbose=True)

#Fine-tune only Decoder_E (reconstruction loss on E and M):
print('\nExpect mse_TE to be the same, but mse_E and mse_M to change:--------------')
for epoch in range(test_iters):
    zT, zE, XrT, XrE = finetune_fn(XT=train_T_dat, XE=train_E_dat, train_T=False, train_E=False, subnetwork='Decoder_E')
    train_log_name, train_log_values = report_losses(train_T_dat, train_E_dat, zT, zE, XrT, XrE, epoch=epoch, datatype='train_', verbose=True)

#Fine-tune only Encoder_E (loss is the MSE between zT and zE):
print('\nExpect mse_T to be the same.\n'
      'Condition checks equality of zT through Decoder_E on successive epochs:--------------')
for epoch in range(test_iters):
    zT, zE, XrT, XrE = finetune_fn(XT=train_T_dat, XE=train_E_dat, train_T=False, train_E=False, subnetwork='Encoder_E')
    train_log_name, train_log_values = report_losses(train_T_dat, train_E_dat, zT, zE, XrT, XrE, epoch=epoch, datatype='train_', verbose=True)

    #Running T representation through the E decoder. The result is expected to stay
    #identical from step to step, because only weights whose name contains
    #'Encoder_E' are updated here. It is kept in its own variable so the latent zE
    #returned by finetune_fn is not overwritten.
    zEcurrent = model_TE.decoder_E(zT, training=False).numpy()
    if epoch > 0:
        print('decoder_E(zT) same as previous step is {}'.format(np.array_equal(zEref,zEcurrent)))
    zEref = zEcurrent.copy()

Expect everything to change:--------------
Epoch:    0, mse_T: 1.74592, mse_E: 0.40934, mse_M: 0.03344, mse_TE: 0.01526
Epoch:    1, mse_T: 1.74258, mse_E: 0.40608, mse_M: 0.03984, mse_TE: 0.01496
Epoch:    2, mse_T: 1.74459, mse_E: 0.40853, mse_M: 0.03456, mse_TE: 0.01774
Epoch:    3, mse_T: 1.74090, mse_E: 0.40217, mse_M: 0.03627, mse_TE: 0.01507
Epoch:    4, mse_T: 1.74253, mse_E: 0.40386, mse_M: 0.03349, mse_TE: 0.01506

Expect mse_T to be unchanged:--------------
Epoch:    0, mse_T: 1.72686, mse_E: 0.40256, mse_M: 0.03441, mse_TE: 0.01421
Epoch:    1, mse_T: 1.72686, mse_E: 0.40777, mse_M: 0.03751, mse_TE: 0.01417
Epoch:    2, mse_T: 1.72686, mse_E: 0.40181, mse_M: 0.03682, mse_TE: 0.01156
Epoch:    3, mse_T: 1.72686, mse_E: 0.40334, mse_M: 0.03280, mse_TE: 0.01293
Epoch:    4, mse_T: 1.72686, mse_E: 0.40487, mse_M: 0.04407, mse_TE: 0.01334

Expect nothing to change:--------------
Epoch:    0, mse_T: 1.72686, mse_E: 0.38620, mse_M: 0.01893, mse_TE: 0.00850
Epoch:    1, mse_T: 1.72

In [None]:
#Sample loop to store best weights in the fine tuning step.
#After every fine-tuning step the validation loss (val_mse_E + val_mse_M) is evaluated;
#whenever it improves on the best value seen so far, the model weights are saved.

fine_tune_epochs = 20
best_loss = np.inf

#Fine tune E decoder:
for epoch in range(fine_tune_epochs):
    #Training with train_E=True was found to prevent overfitting (as per validation losses)
    zT, zE, XrT, XrE = finetune_fn(XT=train_T_dat, XE=train_E_dat, train_T=False, train_E=True, subnetwork='Decoder_E')
    train_log_name, train_log_values = report_losses(train_T_dat, train_E_dat, zT, zE, XrT, XrE, epoch=epoch, datatype='train_', verbose=False)

    #Collect validation metrics (subnetwork=None: forward pass only, no weights are updated)
    zT, zE, XrT, XrE = finetune_fn(XT=val_T_dat, XE=val_E_dat, train_T=False, train_E=False, subnetwork=None)
    val_log_name, val_log_values = report_losses(val_T_dat, val_E_dat, zT, zE, XrT, XrE, epoch=epoch, datatype='val_', verbose=False)

    loss_dict = dict(zip(val_log_name, val_log_values))
    val_loss = loss_dict['val_mse_E'] + loss_dict['val_mse_M']
    if val_loss < best_loss:
        best_loss = val_loss
        #Same metrics again, this time printed.
        val_log_name, val_log_values = report_losses(val_T_dat, val_E_dat, zT, zE, XrT, XrE, epoch=epoch, datatype='val_', verbose=True)
        #The tracked quantity is val_mse_E + val_mse_M (not val_mse_TE), so label it as such.
        print('{:0.5f} is best val_mse_E + val_mse_M'.format(best_loss))
        model_TE.save_weights(dir_pth['result']+fileid+'-best_E_weights.h5')