In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from matplotlib import pyplot as plt
import math
import os
from utils import models

In [2]:
# Model / data parameters
from utils.preprocessing import prepare_dataset_U_N

ROOT_DIR = '/mnt/Archive/Data_Sets/MesoThermo'

# Directories
dir_fasta = os.path.join(ROOT_DIR,'FASTA')
dir_blast = os.path.join(ROOT_DIR,'BLAST_record')
dir_pdb = '/mnt/Archive/PDB'
outfolder = os.path.join(ROOT_DIR,'Dist_Mat/Thermophile')

name_thermo = 'dist_termo.fasta'
name_meso   = 'dist_meso.fasta'

thermo_train, thermo_val = prepare_dataset_U_N(dir_fasta, name_thermo, file_format = 'fasta', seq_length = 1024, t_v_split = 0.1)
meso_train,   meso_val   = prepare_dataset_U_N(dir_fasta, name_meso  , file_format = 'fasta', seq_length = 1024, t_v_split = 0.1)


In [3]:
model_type = 'res'

p = {'max_seq_len': 1024,
     'num_classes': 10,
     'emb_size': 10,
     'num_filter': [32, 64, 128],
     'kernel_size':[3, 3, 3, 3, 3] ,
     'sampling_stride': [2, 2],
     'pool_size': [2, 2, 2, 2, 2, 2],
     'rate': [0.0, 0.0, 0.0, 0.0, 0.0],
     'l1': [0.001, 0.001, 0.001, 0.01, 0.001],
     'l2': [0.001, 0.001, 0.001, 0.01, 0.001],
     'use_max_pool': False,
     'output_activation': 'softmax'}

model_encode   = models.U_net(p, model_type, emb = True)

In [4]:

loss_fn = tf.keras.losses.CategoricalCrossentropy(
    from_logits=False, label_smoothing=0,
    reduction = tf.keras.losses.Reduction.NONE,
    name='categorical_crossentropy'
)
opt = tf.keras.optimizers.Adam(
    learning_rate=0.001, beta_1=0.5, beta_2=0.8, epsilon=1e-07, amsgrad=False,
    name='Adam')



In [5]:

@tf.function
def train_on_batch(x, y, w, model, loss_fn, optimizer):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits, w)
        loss = tf.reduce_mean(loss,axis=0)
        loss += tf.reduce_mean(model.losses)
    gradients = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    loss = tf.reduce_mean(loss,axis=None)
    return logits, loss
@tf.function
def validate_on_batch(x, y, w, model, loss_fn):
    logits = model(x,training=False)
    loss = loss_fn(y, logits, w)
    loss = tf.reduce_mean(loss,axis=0)
    loss += tf.reduce_mean(model.losses)
    loss = tf.reduce_mean(loss,axis=None)
    return logits, loss



In [6]:
train_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)
val_loss = tf.keras.metrics.Mean('val_loss', dtype=tf.float32)
    
train_acc = tf.keras.metrics.CategoricalAccuracy()
val_acc   = tf.keras.metrics.CategoricalAccuracy()

for epoch in range(30):
    batches_train = meso_train.shuffle(buffer_size = 40000).batch(64, drop_remainder=True)
    batches_val = meso_val.shuffle(buffer_size = 40000).batch(64, drop_remainder=True)
    for step, (x, y, w) in enumerate(batches_train):
        logits, loss = train_on_batch(x, y, w, model_encode, loss_fn, opt)
        
        train_loss(loss)
        train_acc(y, logits, w)
        
    for step, (x, y, w) in enumerate(batches_val):
        logits, loss = validate_on_batch(x, y, w, model_encode, loss_fn)
        
        val_loss(loss)
        val_acc(y, logits, w)
        
    print("Epoch: %d\tLoss%6.4f\tAcc%6.2f\tVal_loss%6.4f\tVal_acc%6.2f" % 
                      (epoch, train_loss.result().numpy(),                                                                       train_acc.result().numpy(),
                        val_loss.result().numpy(), 
                        val_acc.result().numpy()))
                  
    train_loss.reset_states()
    train_acc.reset_states()
    val_loss.reset_states()
    val_acc.reset_states()

Epoch: 0	Loss0.5171	Acc  0.86	Val_loss1.0762	Val_acc  0.05
Epoch: 1	Loss0.0688	Acc  1.00	Val_loss1.7644	Val_acc  0.05
Epoch: 2	Loss0.0301	Acc  1.00	Val_loss1.9180	Val_acc  0.05
Epoch: 3	Loss0.0188	Acc  1.00	Val_loss1.8920	Val_acc  0.05
Epoch: 4	Loss0.0158	Acc  1.00	Val_loss1.3502	Val_acc  0.05
Epoch: 5	Loss0.0128	Acc  1.00	Val_loss0.6817	Val_acc  0.10
Epoch: 6	Loss0.0117	Acc  1.00	Val_loss1.9922	Val_acc  0.08
Epoch: 7	Loss0.0095	Acc  1.00	Val_loss0.0139	Val_acc  1.00
Epoch: 8	Loss0.0113	Acc  1.00	Val_loss0.0164	Val_acc  1.00
Epoch: 9	Loss0.0090	Acc  1.00	Val_loss0.0124	Val_acc  1.00
Epoch: 10	Loss0.0065	Acc  1.00	Val_loss0.0326	Val_acc  0.99
Epoch: 11	Loss0.0077	Acc  1.00	Val_loss0.1397	Val_acc  0.88
Epoch: 12	Loss0.0057	Acc  1.00	Val_loss0.0055	Val_acc  1.00
Epoch: 13	Loss0.0053	Acc  1.00	Val_loss0.0163	Val_acc  1.00
Epoch: 14	Loss0.0049	Acc  1.00	Val_loss0.0100	Val_acc  1.00
Epoch: 15	Loss0.0047	Acc  1.00	Val_loss0.0231	Val_acc  1.00
Epoch: 16	Loss0.0047	Acc  1.00	Val_loss0.0050	Val_

KeyboardInterrupt: 