In [None]:
import os
import json
from pathlib import Path

import numpy as np
import tensorflow as tf

# Show the kernel's working directory: the relative paths used further down
# ('../data.csv/...', '../models/...') are resolved against it.
print(os.getcwd())

In [None]:
# Integer code assigned to each of the four canonical DNA bases.
base2int = dict(zip('ACGT', range(4)))

def sequence2int(sequence, mapping=base2int):
    """Translate a sequence of symbols into a list of integer codes.

    Args:
        sequence: Iterable of symbols (for example a DNA string).
        mapping: Dict from symbol to integer code; defaults to `base2int`.

    Returns:
        A list with one integer per symbol. Symbols that are not keys of
        `mapping` are encoded as the placeholder value 999.
    """
    encoded = []
    for symbol in sequence:
        encoded.append(mapping[symbol] if symbol in mapping else 999)
    return encoded

def sequence2onehot(sequence, mapping=base2int):
    """One-hot encode a sequence into 4 channels.

    The sequence is first converted to integer codes with `sequence2int`
    and those codes are passed to `tf.one_hot(..., depth=4)`.

    Args:
        sequence: Iterable of symbols (for example a DNA string).
        mapping: Dict from symbol to integer code; defaults to `base2int`.

    Returns:
        The tensor produced by `tf.one_hot`, one row of 4 values per symbol.
        NOTE(review): per the tf.one_hot documentation, indices outside
        [0, depth) -- such as the 999 placeholder for unknown symbols --
        should yield an all-zero row; confirm this is the intended encoding.
    """
    indices = sequence2int(sequence, mapping)
    return tf.one_hot(indices, depth=4)

def load_fasta(fasta):
    """Lazily parse a FASTA file into (one-hot tensor, name) pairs.

    A record starts at a header line beginning with '>' and ends at the next
    header or at end of file. The record name is the header text after '>'
    up to (not including) the first ':'. The sequence may span one line or
    be wrapped over several; blank lines anywhere in the file are ignored.

    Args:
        fasta: Path of the FASTA file to read.

    Yields:
        Tuples ``(encoding, name)`` where ``encoding`` is the result of
        `sequence2onehot` cast to tf.float32 and ``name`` is the record name
        cast to tf.string.

    Raises:
        ValueError: If sequence data appears before the first header line.
    """
    def _emit(record_name, chunks):
        # Join wrapped sequence lines back into one sequence before encoding.
        sequence = ''.join(chunks)
        return tf.cast(sequence2onehot(sequence), tf.float32), tf.cast(record_name, tf.string)

    name = None
    chunks = []
    with open(fasta) as f:
        for lineno, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines (commonly a trailing newline at end of file).
                continue
            if line.startswith('>'):
                if name is not None:
                    yield _emit(name, chunks)
                name, *_ = line[1:].split(':')
                chunks = []
            elif name is None:
                # Explicit error instead of `assert`, which is stripped under `python -O`.
                raise ValueError(
                    f"{fasta}: line {lineno} is not a FASTA header (expected it to start with '>')"
                )
            else:
                chunks.append(line)
    if name is not None:
        yield _emit(name, chunks)

def load_dataset(fasta, cache=False, shuffle=False):
    """Wrap `load_fasta` in a tf.data.Dataset.

    Elements are declared as (tf.float32, tf.string) pairs, matching what
    `load_fasta` yields. The dataset is returned unbatched and the string
    labels are not converted to numeric targets; both are left to the caller.

    Args:
        fasta: Path of the FASTA file to stream from.
        cache: When true, apply `Dataset.cache()`.
        shuffle: When true, apply `Dataset.shuffle` with a buffer of
            1,000,000 elements (after caching, if both are requested).

    Returns:
        The resulting tf.data.Dataset.
    """
    def records():
        # A fresh generator is created each time the dataset is iterated.
        return load_fasta(fasta)

    dataset = tf.data.Dataset.from_generator(records, output_types=(tf.float32, tf.string))
    if cache:
        dataset = dataset.cache()
    return dataset.shuffle(1_000_000) if shuffle else dataset

In [None]:
# X = [x for x in load_fasta('../data.csv/processed/train/HEPG2.fasta')]
# print(len(X))

In [None]:
# D = load_dataset('../data.csv/processed/train/HEPG2.fasta')
# D.element_spec

# for x in D.take(1):
#     print(x)

In [None]:
def load_finetune_models(dir):
    """Load every Keras ``.h5`` model found (recursively) under a directory.

    Each model is keyed by the name of the folder that directly contains its
    ``.h5`` file, e.g. ``<dir>/HEPG2/model.h5`` is stored under ``'HEPG2'``.

    Files are visited in sorted path order so the result is deterministic.
    If one folder holds several ``.h5`` files, the one that sorts last wins.

    Args:
        dir: Root directory to search.

    Returns:
        Dict mapping folder name to the model returned by
        `tf.keras.models.load_model`. Empty if no ``.h5`` file is found.
    """
    models_dict = {}
    for fpath in sorted(Path(dir).glob('**/*.h5')):
        # Use pathlib rather than splitting on '/': it is separator-independent
        # and cannot raise IndexError when the path has no parent component.
        cell = fpath.parent.name
        models_dict[cell] = tf.keras.models.load_model(str(fpath))
    return models_dict

# Load every fine-tuned model found under the directory below; the dict is keyed
# by the name of the folder that contains each .h5 file.
finetune_models = load_finetune_models('../models/rDHS.models.finetuned')
# Bare expression: shows the resulting dict as the cell output.
finetune_models

In [None]:
# Smoke test: call the 'HEPG2' fine-tuned model on one random tensor of shape (1, 600, 4).
finetune_models['HEPG2'](tf.random.uniform(shape=(1, 600, 4)))

In [None]:
# Load the multitask model. compile=False asks Keras not to compile the model after
# loading -- presumably because it is only called for inference in this notebook.
multiclass_model = tf.keras.models.load_model('../train.csv/model.multitask.Jan-5-2023.h5', compile=False)

In [None]:
# Smoke test: call the multitask model on one random tensor of shape (1, 600, 4).
multiclass_model(tf.random.uniform(shape=(1, 600, 4)))

In [None]:
# Index assigned to each class label. The evaluation code in this notebook passes
# cell2idx[cell] as the column of the multitask model's prediction to read.
cell2idx = {
    label: position
    for position, label in enumerate(
        ['A549', 'GM12878', 'HCT116', 'HEPG2', 'K562', 'MCF7', 'Negative']
    )
}

In [None]:
# Bare expression: shows the loaded multitask model object as the cell output.
multiclass_model

In [None]:
import tqdm

def evaluate_model_on_fasta(model, dataset, output_filepath, positive_label, eval_type='binary', idx=0, total=None):
    """Score a batched dataset with `model` and write one CSV row per example.

    Each row (no header line) has five comma-separated fields:

        score,positive_label,name,is_positive,predicted_label

    * score           -- ``pred[i, idx]``, or the model's single output when `idx` is None
    * positive_label  -- the `positive_label` argument, repeated on every row
    * name            -- the example's own label, decoded from UTF-8 bytes
    * is_positive     -- 1 if ``name == positive_label`` else 0
    * predicted_label -- 'binary': 1 if ``score > 0.5`` else 0;
                         'multiclass': 1 if `idx` is the arg-max output else 0

    Args:
        model: Callable mapping a batch ``x`` to a 2-D array-like of scores
            (one row per example).
        dataset: Iterable of ``(x, y)`` batches, where ``y`` holds one
            UTF-8 encoded bytes label per example.
        output_filepath: Destination CSV path; missing parent directories
            are created.
        positive_label: Label counted as the positive class.
        eval_type: 'binary' or 'multiclass'; selects how predicted_label is derived.
        idx: Column of the prediction to report as the score. None is accepted
            only for 'binary' models that emit exactly one value per example.
        total: Optional number of examples, used only for the progress bar.

    Raises:
        ValueError: On an unknown `eval_type`, on 'multiclass' with ``idx=None``,
            or when ``idx=None`` and the model emits more than one value per example.
    """
    if eval_type not in ('binary', 'multiclass'):
        raise ValueError(f"eval_type must be 'binary' or 'multiclass', got {eval_type!r}")
    if eval_type == 'multiclass' and idx is None:
        raise ValueError("eval_type='multiclass' requires idx to name the output column to score")

    if total is None:
        # Backward compatibility: earlier revisions sized the progress bar from a
        # notebook-level global `n`. Reuse it when present; otherwise the bar has no total.
        legacy_total = globals().get('n')
        if isinstance(legacy_total, int) and not isinstance(legacy_total, bool):
            total = legacy_total

    # open(..., 'w') does not create missing directories, so make sure the target exists.
    Path(output_filepath).parent.mkdir(parents=True, exist_ok=True)

    with tqdm.tqdm(total=total) as pbar, open(output_filepath, 'w') as f_out:
        for x, y in dataset:
            # Convert once per batch instead of indexing tensors element by element.
            scores = np.asarray(model(x))
            names = np.asarray(y)
            for row, raw_name in zip(scores, names):
                if idx is None:
                    flat = np.ravel(row)
                    if flat.size != 1:
                        raise ValueError(
                            f"idx=None needs a single score per example, got {flat.size}"
                        )
                    pred_i = flat[0]
                else:
                    pred_i = row[idx]

                if eval_type == 'multiclass':
                    # Positive only when the requested class has the highest score.
                    pred_i_label = int(np.argmax(row) == idx)
                else:
                    pred_i_label = int(pred_i > 0.5)

                y_i = raw_name.decode('UTF-8')
                label_i = int(positive_label == y_i)
                # The second field used to be read from a notebook global `cell`;
                # `positive_label` is the argument that carries that value.
                print(f"{pred_i},{positive_label},{y_i},{label_i},{pred_i_label}", file=f_out)
                pbar.update(1)
            # Flush per batch so partial results are visible during long runs.
            f_out.flush()

In [None]:
# Evaluate the multitask model and the per-cell fine-tuned model on each cell's
# test FASTA (../data.csv/processed/test/<cell>.fasta), writing one CSV per model.
# NOTE(review): `cell` and `n` are notebook-level names; evaluate_model_on_fasta
# appears to read names `n` and `cell` that are not among its parameters, so keep
# both defined under these exact names -- confirm before renaming.
for cell in ['HEPG2']:# cell2idx.keys():
    # 'Negative' has an entry in cell2idx but is skipped here.
    if cell == 'Negative':
        continue
    
    print('-->', cell)
    
    fasta = f'../data.csv/processed/test/{cell}.fasta'
    
    print('Loading dataset ...')
    dataset = load_dataset(fasta, cache=True)
    
    # Count the examples with one full pass over the dataset; the count is printed below.
    # Presumably this pass also fills the cache requested above -- TODO confirm.
    n = 0
    for _ in dataset:
        n += 1
    print('total:', n)
    
    # Group examples into batches of 256, then cache the batched dataset as well.
    dataset = dataset.batch(256)
    dataset = dataset.cache()
    
    # multi-task
    # Scores the output column cell2idx[cell] of the shared model.
    # NOTE(review): the relative output folder 'multiclass/' is assumed to exist -- verify.
    evaluate_model_on_fasta(multiclass_model, dataset, f'multiclass/eval.{cell}.multiclass.csv', positive_label=cell, idx=cell2idx[cell], eval_type='multiclass')
    
    # finetuned
    # Uses the model loaded for this cell; idx is left at the function's default.
    # NOTE(review): the relative output folder 'finetuned/' is assumed to exist -- verify.
    evaluate_model_on_fasta(finetune_models[cell], dataset, f'finetuned/eval.{cell}.finetuned.csv', positive_label=cell, eval_type='binary')
    