In [None]:
import os
import json
import datetime

import numpy as np
import tensorflow as tf

In [None]:
# Show the current working directory: the data and model-config paths used
# further down are relative, so they are resolved against this location.
os.getcwd()

## Load Dataset

In [None]:
# Class label -> integer class index. The index is simply the position of the
# label in this ordered list (0 .. 6).
cell2idx = {
    label: index
    for index, label in enumerate(
        ['A549', 'GM12878', 'HCT116', 'HEPG2', 'K562', 'MCF7', 'Negative']
    )
}

In [None]:
# Integer code of each nucleotide; the keys are upper-case only.
base2int = {
    'A': 0,
    'C': 1,
    'G': 2,
    'T': 3,
}

def sequence2int(sequence, mapping=base2int):
    """Translate every character of `sequence` into its integer code.

    Characters that are not keys of `mapping` are encoded as 999. The lookup
    is case-sensitive, so with the default table anything other than
    upper-case A/C/G/T (e.g. 'N' or lower-case bases) becomes 999.

    Args:
        sequence: Iterable of single-character symbols (typically a str).
        mapping: Dict-like object with a `.get` method, symbol -> int code.

    Returns:
        A list of ints, one per symbol, in input order.
    """
    codes = []
    for symbol in sequence:
        codes.append(mapping.get(symbol, 999))
    return codes

def sequence2onehot(sequence, mapping=base2int):
    """One-hot encode `sequence` with four channels.

    The sequence is first converted to integer codes with `sequence2int`, then
    passed to `tf.one_hot` with `depth=4`. Per the `tf.one_hot` documentation,
    indices outside the depth (such as the 999 code used for unknown symbols)
    produce an all-zero row.

    Args:
        sequence: Iterable of single-character symbols (typically a str).
        mapping: Symbol -> int code table forwarded to `sequence2int`.

    Returns:
        The tensor returned by `tf.one_hot`, one row of four values per symbol.
    """
    indices = sequence2int(sequence, mapping)
    return tf.one_hot(indices, depth=4)

def load_fasta(fasta, cell2idx):
    """Yield `(one_hot_sequence, class_index)` pairs from a FASTA file.

    The file is read as two-line records: a header line starting with '>'
    followed by exactly one sequence line. The record name is the header text
    between '>' and the first ':'. Records whose name is not a key of
    `cell2idx` are skipped. Blank lines in header position (for example a
    trailing empty line at the end of the file) are ignored.

    Args:
        fasta: Path of the FASTA file to read.
        cell2idx: Mapping from record name to integer class index.

    Yields:
        A tuple of the one-hot encoded sequence cast to `tf.float32` and the
        class index looked up in `cell2idx`.

    Raises:
        ValueError: If a non-blank line in header position does not start
            with '>' (e.g. the file wraps sequences over several lines).
    """
    with open(fasta) as f:
        line_number = 0
        for line in f:
            line_number += 1

            if not line.strip():
                # Tolerate empty lines between records / at end of file
                # instead of failing on them.
                continue

            if not line.startswith('>'):
                # Explicit exception rather than `assert`: assertions are
                # removed when Python runs with -O, and carry no context.
                raise ValueError(
                    f"{fasta}: line {line_number}: expected a FASTA header "
                    f"starting with '>', got {line!r}"
                )

            header = line.strip()
            # The sequence is the single line right after the header.
            sequence = f.readline().strip()
            line_number += 1

            name = header[1:].split(':', 1)[0]
            if name not in cell2idx:
                continue

            yield tf.cast(sequence2onehot(sequence), tf.float32), cell2idx[name]

def load_dataset(fasta, cell2idx, cache=False, shuffle=False, batch_size=256, shuffle_buffer=1_000_000):
    """Build a batched `tf.data.Dataset` of (one-hot sequence, label) pairs.

    Elements are produced lazily by `load_fasta(fasta, cell2idx)`.

    Args:
        fasta: Path of the FASTA file to read.
        cell2idx: Mapping from record name to integer class index; records
            with other names are dropped by `load_fasta`.
        cache: If True, cache the parsed elements (applied before shuffling,
            so later passes do not re-parse the file).
        shuffle: If True, shuffle elements with a buffer of `shuffle_buffer`.
        batch_size: Number of elements per batch. Defaults to 256.
        shuffle_buffer: Size of the shuffle buffer, used only when `shuffle`
            is True. Defaults to 1_000_000.

    Returns:
        A `tf.data.Dataset` yielding `(sequences, labels)` batches with dtypes
        `(tf.float32, tf.int8)`.
    """
    # Only dtypes are declared for the generator output, not shapes.
    # Labels are declared as tf.int8, so class indices must fit in that type.
    dataset = tf.data.Dataset.from_generator(
        lambda: load_fasta(fasta, cell2idx),
        output_types=(tf.float32, tf.int8),
    )

    if cache:
        dataset = dataset.cache()
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer)

    return dataset.batch(batch_size)

In [None]:
# Training split: cached and shuffled.
dataset = load_dataset('../data.csv/processed/train/train.fasta', cell2idx, cache=True, shuffle=True)
print(dataset.element_spec)
# Peek at a single batch as a sanity check.
for x in dataset.take(1):
    print(x)

# Evaluation split: validation metrics do not depend on element order, so the
# shuffle buffer is skipped here to save memory and start-up time.
dataset_test = load_dataset('../data.csv/processed/test/test.fasta', cell2idx, cache=True, shuffle=False)

## Build Model

In [None]:
# Parse the model architecture description straight from the JSON file.
with open('model.config.multi-class.json') as f:
    model_config = json.load(f)

In [None]:
# `from_config` is a classmethod, so call it on the class itself rather than
# on a throwaway, empty `Model()` instance.
model = tf.keras.models.Model.from_config(model_config)
model.summary()

## Train Model

In [None]:
# Callbacks: every run logs into its own timestamped directory below
# logs.multi-class/ so TensorBoard runs do not overwrite each other.
log_dir = f"logs.multi-class/{datetime.datetime.now():%Y%m%d-%H%M%S}"
print(log_dir)

callbacks = [
    # Early stopping is currently disabled:
    # tf.keras.callbacks.EarlyStopping(patience=15),
    tf.keras.callbacks.TensorBoard(histogram_freq=1, log_dir=log_dir),
]

In [None]:
# Metrics: labels are integer class indices, hence the sparse variant of
# categorical accuracy. (AUC / Accuracy / Precision / Recall were tried as
# alternatives and are not enabled.)
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
metrics

In [None]:
# Compile with the sparse cross-entropy loss (integer labels), a default Adam
# optimizer and the metrics defined above.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss='sparse_categorical_crossentropy',
    metrics=metrics,
)

In [None]:
EPOCHS = 50

# Train with the test split as validation data, then store the trained model
# inside this run's log directory.
model.fit(
    dataset,
    epochs=EPOCHS,
    validation_data=dataset_test,
    callbacks=callbacks,
)
model.save(f'{log_dir}/model.multitask.h5')