In [40]:
from net_work_def import  MtlNetwork_body2
import pickle
import matplotlib as mpl
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from datetime import datetime
import os

In [41]:
# --- Configuration -------------------------------------------------------
ddir = 'replicate_pickle_data'   # root folder; one sub-folder per subject
num_classes = 85
num_features = 100*100*40
batch_size = 16
# Upper bound on how many leading entries of trainX / trainY are kept per
# subject (applied as a slice below).
total_training_batches = 2500

subs = ['sub1', 'sub2', 'sub3', 'sub4', 'sub5', 'sub7']

# Stems of the pickle files read for every subject. Each loaded object is
# stored in subject_data[subject] under the same stem, in this order.
subject_files = ['data_align', 'ppgnp_align', 'trainX', 'trainY', 'pulR']


def _load_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only point this at files from a trusted source.
    # The `with` block closes the file, so no explicit close() is needed.
    with open(path, 'rb') as f:
        return pickle.load(f)


# --- Load every subject --------------------------------------------------
subject_data = {}
for subject in subs:
    loaded = {name: _load_pickle(f'{ddir}/{subject}/{name}.pkl')
              for name in subject_files}

    # Keep only the leading total_training_batches entries of the training arrays.
    loaded['trainX'] = loaded['trainX'][0: total_training_batches]
    loaded['trainY'] = loaded['trainY'][0: total_training_batches]

    subject_data[subject] = loaded

In [42]:
def final_data_prep(trainX, trainY):
    """Normalise one subject's data and split it into a training stream and a held-out set.

    Steps:
      * Both inputs are copied into float32 NumPy arrays.
      * trainY is rescaled per row (reduction over axis 1) to roughly [-1, 1];
        a 1e-5 term in the denominator avoids division by zero.
      * trainX is rescaled to [0, 1] using its global min and max.
      * A 90/10 split is made with a fixed random_state of 42.
      * The training part becomes a repeating, shuffled, batched tf.data
        pipeline that uses the module-level `batch_size`.

    Args:
        trainX: array-like of input samples (first axis indexes samples).
        trainY: array-like of targets with at least two dimensions, since
            axis 1 is reduced during normalisation.

    Returns:
        dict with keys 'train_data' (tf.data.Dataset yielding (x, y) batches),
        'testX' and 'testY' (held-out NumPy arrays).
    """
    # np.array copies by default, so the in-place arithmetic below works on
    # private buffers and never modifies the caller's data. Working in place
    # avoids allocating several full-size temporaries for large inputs.
    trainX = np.array(trainX, dtype = np.float32)
    trainY = np.array(trainY, dtype = np.float32)

    # Per-row min-max scaling of the targets to [-1, 1].
    trainY -= trainY.min(axis = 1)[:, np.newaxis]
    trainY /= trainY.max(axis = 1)[:, np.newaxis] + 10**-5
    trainY *= 2
    trainY -= 1

    # Global min-max scaling of the inputs to [0, 1].
    trainX -= trainX.min()
    x_range = trainX.max()
    # A constant input has zero range; skip the division instead of filling
    # the array with NaN/inf (it is already all zeros at this point).
    if x_range != 0:
        trainX /= x_range

    trX, teX, trY, teY = train_test_split(trainX , trainY,
                                        test_size = .1, random_state = 42)

    # NOTE(review): shuffle is applied after repeat, so the 100-element buffer
    # can mix samples across epoch boundaries - confirm this is intended.
    train_data = tf.data.Dataset.from_tensor_slices((trX, trY))
    train_data = train_data.repeat().shuffle(buffer_size=100,
                                            seed= 8).batch(batch_size).prefetch(1)

    return {
        'train_data': train_data,
        'testX': teX,
        'testY': teY
    }

# Prepare every loaded subject: one repeating training stream plus a held-out
# split each. tqdm reports progress over the subjects.
training_data = {}
for subject in tqdm(subject_data):
    subject_arrays = subject_data[subject]
    training_data[subject] = final_data_prep(
        subject_arrays['trainX'],
        subject_arrays['trainY'],
    )

100%|██████████| 6/6 [01:44<00:00, 17.40s/it]


In [43]:
def RootMeanSquareLoss(x,y):
    """Mean squared error plus a half-weighted sign-disagreement penalty.

    Despite the name, no square root is taken. The result is

        MSE(x, y) + 0.5 * mean(|sign(y)| - sign(x * y), axis=-1)

    Both terms are reduced over the last axis only, so the result is not
    collapsed to a single scalar here; callers average it themselves.

    Per element, the second term is 0 where x and y share a sign (or y is 0),
    1 where x is 0 but y is not, and 2 where the signs are opposite.

    The call sites in this notebook pass the target first and the prediction
    second, i.e. x is the target and y the prediction. The MSE term is
    symmetric in its two arguments, so only the |sign(y)| factor depends on
    that order.

    NOTE(review): tf.math.sign is piecewise constant, so the penalty term
    presumably contributes no gradient and only affects the reported value -
    confirm.
    """
    mse_term = tf.keras.losses.MSE(y_true = y, y_pred = x)

    y_is_nonzero = tf.math.abs(tf.math.sign(y))
    sign_of_product = tf.math.sign(tf.math.multiply(x, y))
    sign_term = tf.reduce_mean(y_is_nonzero - sign_of_product, axis = -1)

    return mse_term + 0.5*sign_term

In [44]:
def run_optimization(opt, mod, x, y):
    """Apply one optimizer update to `mod` using a single batch.

    Args:
        opt: optimizer exposing apply_gradients.
        mod: callable model exposing trainable_variables.
        x: input batch fed to the model.
        y: target batch passed to RootMeanSquareLoss together with the prediction.

    Returns:
        None. The model's variables are updated in place by the optimizer.
    """
    with tf.GradientTape() as tape:
        prediction = mod(x, training=True)
        batch_loss = RootMeanSquareLoss(y, prediction)

    # Read the variable list only after the forward pass, so that variables
    # created lazily on the model's first call are included.
    variables = mod.trainable_variables
    grads = tape.gradient(batch_loss, variables)
    opt.apply_gradients(zip(grads, variables))


In [45]:
# Notebook-wide loss histories: the training loop appends to train_loss,
# Val_loss appends to val_loss.
train_loss, val_loss = [], []

def Val_loss (mod, testX, testY):
    """Evaluate `mod` on (testX, testY) and record the mean loss in `val_loss`.

    The model is called with training=False. The scalar mean of
    RootMeanSquareLoss is appended to the module-level `val_loss` list;
    nothing is returned.
    """
    predictions = mod(testX, training = False)
    per_sample_loss = RootMeanSquareLoss(testY, predictions)
    mean_loss = tf.reduce_mean(per_sample_loss)
    val_loss.append(mean_loss)

In [46]:
def train_mtl_nn(mod, training_data, training_steps_per_subject = 1000,
                 eval_every = 10, val_batch_size = 16):
    """Train `mod` round-robin over all subjects, checkpointing on validation loss.

    On every step each subject contributes one batch and one SGD update
    (learning rate 0.0005). Every `eval_every` steps, and for every subject,
    the loss on the batch just trained on is appended to the module-level
    `train_loss`, the loss on a randomly positioned window of that subject's
    held-out data is appended to `val_loss` (via Val_loss), and both are
    printed. Whenever that validation loss is the lowest seen so far, the
    weights are saved to "inprocess_weights/shared_body".

    Args:
        mod: model to train; must be callable and expose save_weights.
        training_data: mapping subject -> dict with keys 'train_data'
            (iterable of (x, y) batches), 'testX' and 'testY'.
        training_steps_per_subject: number of update steps per subject.
        eval_every: logging / validation period in steps; must be positive.
        val_batch_size: maximum number of held-out samples per validation
            window. The held-out set is assumed to be non-empty.

    Returns:
        None.
    """
    opt = tf.optimizers.SGD(0.0005)
    min_val_loss = float('inf')

    subject_iterators = {subject: iter(training_data[subject]['train_data']) for subject in training_data}

    for step in range(1, training_steps_per_subject + 1):
        for subject in training_data:

            batch_x, batch_y = next(subject_iterators[subject])

            # Run optimization for the current subject's batch
            run_optimization(opt, mod, batch_x, batch_y)

            if step % eval_every != 0:
                continue

            # NOTE(review): this logging-only forward pass uses training=True,
            # so any training-mode layers in `mod` are active here too -
            # confirm that is intended.
            pred = mod(batch_x, training=True)
            loss = RootMeanSquareLoss(batch_y, pred)
            train_loss.append(tf.reduce_mean(loss))

            # Held-out data is only needed on logging steps.
            teX = training_data[subject]['testX']
            teY = training_data[subject]['testY']

            # Pick a random contiguous validation window. Clamp the window to
            # the held-out size so small sets do not make randint raise, and
            # add 1 because randint's upper bound is exclusive (otherwise the
            # last valid start position could never be drawn).
            window = min(val_batch_size, len(teX))
            tp = np.random.randint(len(teX) - window + 1)
            Val_loss(mod, teX[tp:tp+window], teY[tp:tp+window])
            current_val_loss = val_loss[-1]
            print(f'Step: {step}, Subject: {subject}; Training Loss: {tf.reduce_mean(train_loss[-1])}, Val Loss: {current_val_loss}')

            # Save the model weights if the current validation loss is lower than the previous minimum validation loss.
            # NOTE(review): the compared value comes from one random window of
            # one subject, so this criterion is noisy.
            if current_val_loss < min_val_loss:
                min_val_loss = current_val_loss
                os.makedirs("inprocess_weights", exist_ok=True)

                print(f'Saving model with validation loss: {min_val_loss}\n')
                mod.save_weights("inprocess_weights/shared_body")


In [47]:
# Build the shared network body and train it round-robin over every loaded
# subject. num_classes is the constant from the configuration cell (85), used
# here instead of repeating the literal so the two cannot drift apart.
mtl_body = MtlNetwork_body2(num_classes = num_classes)
train_mtl_nn(mtl_body, training_data, training_steps_per_subject = 5000)

Step: 10, Subject: sub1; Training Loss: 0.8419879674911499, Val Loss: 0.8882651925086975
Saving model with validation loss: 0.8882651925086975

Step: 10, Subject: sub2; Training Loss: 0.8147581815719604, Val Loss: 0.9089601039886475
Step: 10, Subject: sub3; Training Loss: 0.7709710597991943, Val Loss: 0.9080119132995605
Step: 10, Subject: sub4; Training Loss: 0.8473868370056152, Val Loss: 0.9087417125701904
Step: 10, Subject: sub5; Training Loss: 0.7645107507705688, Val Loss: 0.7904037833213806
Saving model with validation loss: 0.7904037833213806

Step: 10, Subject: sub7; Training Loss: 0.648466944694519, Val Loss: 0.8710561990737915
Step: 20, Subject: sub1; Training Loss: 0.8585892915725708, Val Loss: 0.9707973003387451
Step: 20, Subject: sub2; Training Loss: 0.7584059238433838, Val Loss: 0.9322973489761353
Step: 20, Subject: sub3; Training Loss: 0.8060441017150879, Val Loss: 0.9230567216873169
Step: 20, Subject: sub4; Training Loss: 0.7638152837753296, Val Loss: 0.9496170282363892
S

In [48]:
# Persist the weights as they stand after the last training step. The
# lowest-validation-loss checkpoint is written separately to
# "inprocess_weights/shared_body" during training.
final_weights_prefix = "nn_weights/generalized-sub6/shared_body"
# Create the nested target directory first, mirroring what the training loop
# does for its own checkpoint directory.
os.makedirs(os.path.dirname(final_weights_prefix), exist_ok=True)
mtl_body.save_weights(final_weights_prefix)