In [1]:
import os
import sys
# Put the parent of the current working directory on the import path
# (presumably so the local `phagenet` package can be imported below).
module_path = os.path.abspath(os.pardir)

already_on_path = module_path in sys.path
if not already_on_path:
    sys.path.append(module_path)

In [2]:
import json
from pathlib import Path

import numpy as np

import tensorflow as tf

from phagenet import PredictionGenerator, DataGenerator
from phagenet.preprocessing import preprocess_assemblies, preprocess_assemblies_gcp
from phagenet.models import LSTMClassifier

In [3]:
# Working directories, all expressed relative to the parent of the current
# working directory. The three local ones are created if they do not exist.
project_root = Path.cwd() / '..'

input_dir = project_root / 'input'
output_dir = project_root / 'output'
weights_dir = project_root / 'weights'
for working_dir in (input_dir, output_dir, weights_dir):
    working_dir.mkdir(exist_ok=True)

# Read-only source data living in a sibling checkout; it is not created here.
data_dir = project_root.joinpath('..', 'prophage-tool', 'out', 'phageboost_output')

In [4]:
# k-mer length shared by preprocessing, the data generators and the model.
k = 7
# Name under which the fine-tuned run is checkpointed and logged.
model_name = 'lstm-64-transfer'

# NOTE(review): the meaning of the second constructor argument is not visible
# here; the 'lstm-64' naming suggests it is the LSTM width -- confirm.
lstm_size = 64
pretrained_name = 'lstm-64'
# Checkpoint passed to `load`; presumably an epoch number, since the recorded
# training output of this notebook resumes at "Epoch 41/60" -- confirm.
pretrained_checkpoint = 40

classifier = LSTMClassifier(k, lstm_size)

# Load from the pretrained model's checkpoint directory first ...
classifier.set_checkpoint_path(weights_dir / pretrained_name)
classifier.load(pretrained_checkpoint)

# ... then re-point the checkpoint path at this run's own directory.
classifier.set_checkpoint_path(weights_dir / model_name)

classifier.model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
phage-7-mer-freq (InputLayer)   [(None, 16384)]      0                                            
__________________________________________________________________________________________________
bacteria-7-mer-freq (InputLayer [(None, 16384)]      0                                            
__________________________________________________________________________________________________
reshape (Reshape)               (None, 1, 16384)     0           phage-7-mer-freq[0][0]           
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 1, 16384)     0           bacteria-7-mer-freq[0][0]        
______________________________________________________________________________________________

In [5]:
# Collect prophage assemblies from the directory tree under `data_dir`.
#
# Layout walked below: data_dir/<group>/<bacterial assembly>/<prophage>.fasta
# (NOTE(review): what the first directory level groups by is not visible
# here -- confirm against the tool that produced the tree.)
#
# Results:
#   prophage_assemblies   prophage name -> path of its FASTA file
#   prophage_interactions "prophage,bacterium" strings, one per FASTA file
#   prophage_labels       interaction string -> 1 (every pair is a positive)
FASTA_SUFFIX = '.fasta'

prophage_labels = {}
prophage_interactions = []
prophage_assemblies = {}


def _sorted_entries(path):
    """Return the directory entries of `path` ordered by name.

    `os.scandir` yields entries in arbitrary, filesystem-dependent order;
    sorting makes the collected lists identical from run to run. The
    `with` block closes the scandir iterator promptly.
    """
    with os.scandir(path) as entries:
        return sorted(entries, key=lambda entry: entry.name)


for group_dir in _sorted_entries(data_dir):
    # Stray files at this level (e.g. .DS_Store) would make scandir raise.
    if not group_dir.is_dir():
        continue
    for assembly_dir in _sorted_entries(group_dir.path):
        if not assembly_dir.is_dir():
            continue
        bact_assembly = assembly_dir.name
        for fasta in _sorted_entries(assembly_dir.path):
            if not (fasta.is_file() and fasta.name.endswith(FASTA_SUFFIX)):
                continue
            # Strip only the trailing extension; str.replace would also
            # remove any '.fasta' occurring in the middle of the name.
            phage_assembly = fasta.name[:-len(FASTA_SUFFIX)]
            prophage_assemblies[phage_assembly] = fasta.path
            inter = f"{phage_assembly},{bact_assembly}"
            prophage_interactions.append(inter)
            prophage_labels[inter] = 1

In [6]:
# Run k-mer preprocessing over the collected prophage FASTA files.
# NOTE(review): the saved output of this cell is a KeyboardInterrupt, so this
# call did not finish in the recorded run. Later cells read
# 'prophages.capstone.ids.npy' / 'prophages.capstone.kmer-freqs.npy' from
# input_dir, presumably written by this call -- confirm those files are
# complete before training.
prophages, _ = preprocess_assemblies(
    k,
    prophage_assemblies,
    out_dir=input_dir,
    out_prefix='prophages.capstone',
)

KeyboardInterrupt: 

In [6]:
# Hold out a fixed number of prophage interactions for testing; the rest are
# used for training.
PROPHAGE_SPLIT_SEED = 42   # fixed seed so the split is reproducible
N_PROPHAGE_TEST = 100      # number of prophage interactions held out

assert len(prophage_interactions) > N_PROPHAGE_TEST, (
    "need more than {0} prophage interactions to split, found {1}".format(
        N_PROPHAGE_TEST, len(prophage_interactions)))

# Sort before shuffling so the result depends only on the seed and the set of
# interactions -- not on the incoming order or on how many times this cell
# has been run (the shuffle is in place, so an unseeded re-run would
# otherwise reshuffle an already shuffled list).
prophage_interactions.sort()
np.random.RandomState(PROPHAGE_SPLIT_SEED).shuffle(prophage_interactions)

ptest_set = prophage_interactions[:N_PROPHAGE_TEST]
ptrain_set = prophage_interactions[N_PROPHAGE_TEST:]

In [7]:
# Split the labelled interactions into train/test sets, then merge in the
# prophage interactions and labels.
#
# Each row of the JSON file is read as [id, id, label]; the two ids are joined
# with a comma to form the interaction key and the label is '1' or '0'.
INTERACTION_SPLIT_SEED = 42  # fixed seed so the split is reproducible
split_rng = np.random.RandomState(INTERACTION_SPLIT_SEED)

with open(input_dir / 'interactions.final.json', encoding='utf-8') as f:
    interactions = np.array(json.load(f))

# Locate each class by its actual row positions rather than assuming the file
# lists every positive before every negative.
interaction_classes = np.array([row[2] for row in interactions])
positive_idxs = np.flatnonzero(interaction_classes == '1')
negative_idxs = np.flatnonzero(interaction_classes == '0')

class_sizes = (len(positive_idxs), len(negative_idxs))
test_sample_split = (1000, 9000)  # (positives, negatives) held out for testing

# Sample without replacement inside each class; numpy raises ValueError if a
# class has fewer rows than requested.
test_idxs = np.concatenate([
    split_rng.choice(positive_idxs, size=test_sample_split[0], replace=False),
    split_rng.choice(negative_idxs, size=test_sample_split[1], replace=False)
])

# Boolean mask gives O(1) membership per row (an `in` test against the index
# array is a linear scan for every row).
is_test = np.zeros(len(interactions), dtype=bool)
is_test[test_idxs] = True

labels, train_interactions, test_interactions = {}, [], []
for interaction, held_out in zip(interactions, is_test):
    interaction_str = "{0},{1}".format(interaction[0], interaction[1])
    labels[interaction_str] = int(interaction[2])
    if held_out:
        test_interactions.append(interaction_str)
    else:
        train_interactions.append(interaction_str)

# Add the prophage pairs: labels first, then the train/test members.
labels.update(prophage_labels)
train_interactions = np.concatenate([train_interactions, ptrain_set])
test_interactions = np.concatenate([test_interactions, ptest_set])

In [8]:
def _combined_generator(interaction_set, batch_size):
    """Build a DataGenerator backed by both the phage and the prophage files.

    The id / k-mer-frequency file lists are created afresh on every call so
    that no list object is shared between generators.
    """
    phage_id_files = [
        (input_dir / 'phages.ids.npy'),
        (input_dir / 'prophages.capstone.ids.npy'),
    ]
    phage_freq_files = [
        (input_dir / 'phages.kmer-freqs.npy'),
        (input_dir / 'prophages.capstone.kmer-freqs.npy'),
    ]
    bact_id_files = [
        (input_dir / 'bact.ids.npy'),
        (input_dir / 'bact.capstone.ids.npy'),
    ]
    bact_freq_files = [
        (input_dir / 'bact.kmer-freqs.npy'),
        (input_dir / 'bact.capstone.kmer-freqs.npy'),
    ]
    return DataGenerator(
        k, interaction_set, labels,
        phage_id_files, phage_freq_files,
        bact_id_files, bact_freq_files,
        batch_size=batch_size)


# Training batches of 128 interactions.
train_generator = _combined_generator(train_interactions, 128)

# One batch sized to the whole test set.
test_generator = _combined_generator(test_interactions, len(test_interactions))

# Prophage-only test set: single capstone files (not lists), one batch.
ptest_generator = DataGenerator(
    k, ptest_set, labels,
    (input_dir / 'prophages.capstone.ids.npy'),
    (input_dir / 'prophages.capstone.kmer-freqs.npy'),
    (input_dir / 'bact.capstone.ids.npy'),
    (input_dir / 'bact.capstone.kmer-freqs.npy'),
    batch_size=len(ptest_set))


In [9]:
# Compile the classifier with its default settings before training.
# NOTE(review): optimizer, loss and metrics are chosen inside
# LSTMClassifier.compile and are not visible in this notebook -- confirm they
# are appropriate for fine-tuning from the loaded checkpoint.
classifier.compile()

In [13]:
# Continue training the loaded model.
# NOTE(review): the recorded output runs "Epoch 41/60" .. "Epoch 60/60", so the
# first argument is presumably the number of additional epochs -- confirm.
additional_epochs = 20
run_log_dir = output_dir / "logs" / model_name

# NOTE(review): validation uses the first batch of the *test* generator, which
# was built with batch_size=len(test_interactions). Metrics reported during
# training are therefore measured on test data; consider a separate
# validation split if they drive model selection.
validation_batch = test_generator[0]

classifier.fit(
    additional_epochs,
    train_generator,
    log_dir=run_log_dir,
    validation_data=validation_batch,
)

Epoch 41/60
Epoch 42/60
Epoch 43/60
Epoch 44/60
Epoch 45/60
Epoch 46/60
Epoch 47/60
Epoch 48/60
Epoch 49/60
Epoch 50/60
Epoch 51/60
Epoch 52/60
Epoch 53/60
Epoch 54/60
Epoch 55/60
Epoch 56/60
Epoch 57/60
Epoch 58/60
Epoch 59/60
Epoch 60/60
