In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))

if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import json
from pathlib import Path

import numpy as np

import tensorflow as tf

from phagenet import PredictionGenerator
from phagenet.preprocessing import preprocess_assemblies, preprocess_assemblies_gcp
from phagenet.models import LSTMClassifier

In [3]:
input_dir = Path.cwd() / '..' / 'input'
input_dir.mkdir(exist_ok=True)

output_dir = Path.cwd() / '..' / 'output'
output_dir.mkdir(exist_ok=True)

weights_dir = Path.cwd() / '..' / 'weights'
weights_dir.mkdir(exist_ok=True)

data_dir = Path.cwd() / '..' / '..' / 'prophage-tool' / 'data'

In [4]:
k = 7
model_name = 'lstm-64-transfer'

classifier = LSTMClassifier(k, 64)
classifier.set_checkpoint_path(weights_dir / model_name)
classifier.load(40)
classifier.model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
phage-7-mer-freq (InputLayer)   [(None, 16384)]      0                                            
__________________________________________________________________________________________________
bacteria-7-mer-freq (InputLayer [(None, 16384)]      0                                            
__________________________________________________________________________________________________
reshape (Reshape)               (None, 1, 16384)     0           phage-7-mer-freq[0][0]           
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 1, 16384)     0           bacteria-7-mer-freq[0][0]        
______________________________________________________________________________________________

In [6]:
taxids = {}
bact_assemblies = {}

for entry in os.scandir(data_dir):
    fasta_dir = Path(entry.path, 'fasta')
    if entry.is_dir() and fasta_dir.exists():
        taxid = entry.name
        assemblies = []
        for assembly in os.scandir(fasta_dir):
            bact_assemblies[assembly.name] = assembly.path
            taxids[assembly.name] = taxid

In [5]:
phages = np.load(input_dir / 'phages.ids.npy')
bact = np.load(input_dir / 'bact.capstone.ids.npy')

In [7]:
bact, _ = preprocess_assemblies(k, bact_assemblies, out_dir=input_dir, out_prefix='bact.capstone')

KeyboardInterrupt: 

In [7]:
interactions = []
for b in bact:
    for p in phages:
        interactions.append(f"{p},{b}")

In [10]:
generator = PredictionGenerator(k, interactions, (input_dir / 'phages.ids.npy'), (input_dir / 'phages.kmer-freqs.npy'), (input_dir / 'bact.capstone.ids.npy'), (input_dir / 'bact.capstone.kmer-freqs.npy'), batch_size=512)

In [11]:
predictions = classifier.model.predict(generator, verbose=1)



In [12]:
import pandas as pd

p_col, b_col, ba_col, i_col = [], [], [], []
for inter, pred in zip(interactions, predictions):
    p, b = inter.split(',')
    p_col.append(p)
    b_col.append(taxids[b])
    ba_col.append(b)
    i_col.append(pred)

results = pd.DataFrame({
    'bacteria': b_col,
    'bacteria_assembly': ba_col,
    'phage': p_col,
    'model_output': i_col
})
results.model_output = results.model_output.astype(float)
results.to_csv(output_dir / 'results.post.csv', index=False)



In [21]:
results.model_output = results.model_output.astype(float)

In [23]:
results.to_csv(output_dir / 'results.all.csv', index=False)

In [26]:
results['bacteria'] = results.apply(lambda x: taxids[x.bacteria_assembly], axis=1)

In [28]:
results.to_csv(output_dir / 'results.all.csv', index=False)

In [30]:
results[results.model_output > 0.9].groupby('bacteria').apply(lambda x: len(x.phage.unique()))

bacteria
1273132     20
234        116
28105      115
309868      53
34020       30
34021       41
357        119
379        128
556287      30
744859      11
773        160
953        204
dtype: int64

In [10]:
results_pre = pd.read_csv(output_dir / 'results.pre.csv')
results_post = pd.read_csv(output_dir / 'results.post.csv')

In [11]:
results_all = results_pre.merge(results_post, on=['bacteria','bacteria_assembly', 'phage'], suffixes=['_pre', '_post'])

In [12]:
results_all.to_csv(output_dir / 'results.all.csv', index=False)

In [4]:
import pandas as pd
results_all = pd.read_csv(output_dir / 'results.all.csv')

In [16]:
results_all[(results_all.model_output_post > 0.9) & (results_all.model_output_pre > 0.9)]

Unnamed: 0,bacteria,bacteria_assembly,phage,model_output_pre,model_output_post
10516,309868,GCA_000350385.1,35238,0.993997,0.999632
19616,309868,GCA_000496595.1,2693670,0.901043,0.901884
21682,309868,GCA_000496595.1,35238,0.994026,0.999630
22496,34020,GCA_001021085.1,1072683,0.970320,0.999951
22607,34020,GCA_001021085.1,1105171,0.948981,0.999997
...,...,...,...,...,...
4835776,744859,GCA_003045065.1,1334243,0.924859,0.999837
4838703,744859,GCA_003045065.1,2041382,0.981830,0.934536
4843254,744859,GCA_003045065.1,2693370,0.966579,0.999384
4843412,744859,GCA_003045065.1,2694060,0.939844,0.962707


In [8]:
np.sqrt(np.power(results_all.model_output_post - results_all.model_output_pre, 2).mean())

0.10409932620581669

In [9]:
# def process_results_by_bact(res):
#     res[res.model_output_pre]
    
results_all.groupby('bacteria').apply(lambda x: pd.DataFrame(
    {
        # 'unique_phages_pre': x[x.model_output_pre > 0.9].phage.unique(),
        # 'unique_phages_post': x[x.model_output_post > 0.9].phage.unique(),
        'unique_phages_overlap': x[(x.model_output_pre > 0.9) & (x.model_output_post > 0.9)].phage.unique(),
        'model_ou': x[(x.model_output_pre > 0.9) & (x.model_output_post > 0.9)].phage.unique(),
    }
)).reset_index().drop(columns=['level_1']).to_csv(output_dir / 'results.unique-phages-overlap.csv', index=False)