In [None]:
import base64
import bz2
from collections import defaultdict
import json
import time

import pandas as pd
import requests


# Endpoints of the target-finding service.
COUNT_OFFTARGETS_URL = 'https://b7bc9zf9oj.execute-api.ap-southeast-2.amazonaws.com/prod/countOfftargets'
FIND_TARGETS_URL = 'https://b7bc9zf9oj.execute-api.ap-southeast-2.amazonaws.com/prod/findTargets'
# Upper bound on the number of targets sent in one countOfftargets request.
MAX_TARGETS_PER_SCAN = 10000


# Display name -> (S3 bucket, FASTA object key) for every genome the notebook can
# reference. All genomes currently live in the same bucket.
SPECIES_S3_LOCATIONS = {
    species: ('cdiag-genomes', fasta_key)
    for species, fasta_key in (
        ('B. pertussis', 'BordetellaPertussis.fasta'),
        ('Ebola', 'Ebola.fasta'),
        ('F. Solani', 'FusariumSolani.fasta'),
        ('H1N1 1', 'H1N1_1.fa'),
        ('H1N1 2', 'H1N1_2.fa'),
        ('H1N1 3', 'H1N1_3.fa'),
        ('H1N1 4', 'H1N1_4.fa'),
        ('H1N1 5', 'H1N1_5.fa'),
        ('H1N1 6', 'H1N1_6.fa'),
        ('H1N1 7', 'H1N1_7.fa'),
        ('H1N1 8', 'H1N1_8.fa'),
        ('hMPV', 'humanMetapneumovirus.fasta'),
        ('L. prolificans', 'LomentosporaProlificans.fasta'),
        ('RSV', 'respiratorySyncytialVirus.fasta'),
        ('S. aureus', 'StaphylococcusAureus.fasta'),
        ('Coronavirus', 'Coronavirus.fasta'),
    )
}


def check_request(method, *args, **kwargs):
    """Send an HTTP request through ``requests.request`` and fail loudly on HTTP errors.

    Prints the method and URL before sending, and the elapsed time in
    milliseconds once the response arrives. If the response status is an error,
    the status code and response body are printed and the ``requests.HTTPError``
    is re-raised.

    Args:
        method: HTTP method, e.g. ``'POST'``.
        *args, **kwargs: Forwarded unchanged to ``requests.request``. The URL may
            be given positionally or as ``url=``.

    Returns:
        The response object returned by ``requests.request``.

    Raises:
        requests.HTTPError: if ``raise_for_status`` reports an error status.
    """
    # Accept the URL either positionally or as a keyword; indexing args[0]
    # unconditionally would crash on check_request('GET', url=...).
    url = args[0] if args else kwargs.get('url')
    print(f'{method}: {url}')
    # perf_counter is monotonic, so the duration is not skewed by clock adjustments.
    started = time.perf_counter()
    response = requests.request(method, *args, **kwargs)
    print(f'Completed request in {round((time.perf_counter() - started) * 1000)}ms')
    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        print(f'HTTP Error {e.response.status_code}')
        print(e.response.text)
        raise
    return response


def format_data(data):
    """Flatten off-target count records into one wide ``pandas.DataFrame``.

    Each record holds five leading fields (organism, contig, target start,
    strand, target sequence) followed by a dict mapping each off-target
    organism to its list of per-mismatch values, i.e. the shape of the rows
    produced by ``get_all_offtargets``. Each off-target organism contributes an
    ``Off-Target Organism`` column followed by ``MM_0`` .. ``MM_{k-1}``, so
    those column names repeat once per organism.

    The column layout comes from the first record. All records are assumed to
    carry the same organisms, in the same order, with the same number of
    mismatch values.

    Args:
        data: Sequence of records (lists or tuples) as described above.

    Returns:
        A DataFrame with one row per record. Empty input yields an empty
        DataFrame containing only the five leading columns.
    """
    base_columns = ["Organism", "Contig", "Target Start", "Strand", "Target"]
    n_base = len(base_columns)
    if not data:
        return pd.DataFrame(columns=base_columns)

    # Header and rows both read the off-target dict at the same index, so they
    # cannot disagree about where it lives.
    first_offtargets = data[0][n_base]
    num_mismatch_columns = len(next(iter(first_offtargets.values()), []))
    per_organism_columns = ["Off-Target Organism"] + [
        'MM_{}'.format(i) for i in range(num_mismatch_columns)
    ]
    column_names = base_columns + per_organism_columns * len(first_offtargets)

    formatted_data = []
    for element in data:
        # list() copies the leading fields so tuple records work too.
        row = list(element[:n_base])
        for organism, counts in element[n_base].items():
            row.append(organism)
            row.extend(counts)
        formatted_data.append(row)
    return pd.DataFrame.from_records(formatted_data, columns=column_names)


def get_all_offtargets(target_positions, organisms, rule, max_mismatches):
    """Query the countOfftargets service for every target in every organism.

    For each organism, the target sequences are sent in chunks of at most
    ``MAX_TARGETS_PER_SCAN`` per request.

    Args:
        target_positions: Mapping of target sequence -> iterable of position
            records (each a list/tuple of fields), e.g. the result of
            ``get_common_targets``.
        organisms: Keys of ``SPECIES_S3_LOCATIONS`` to scan.
        rule: Target pattern string forwarded to the service.
        max_mismatches: Mismatch limit forwarded to the service.

    Returns:
        One list per (sequence, position) pair: the position fields, then the
        sequence, then a dict mapping each organism to the result the service
        returned for that sequence. Positions of the same sequence share that
        dict object.

    Raises:
        ValueError: if the service returns a different number of results than
            targets sent in a chunk.
    """
    target_mismatches = defaultdict(dict)
    targets = list(target_positions)
    for offtarget_organism in organisms:
        bucket, key = SPECIES_S3_LOCATIONS[offtarget_organism]
        for start in range(0, len(targets), MAX_TARGETS_PER_SCAN):
            target_chunk = targets[start:start + MAX_TARGETS_PER_SCAN]
            target_hits = check_request('POST', COUNT_OFFTARGETS_URL, json={
                'fastaBucket': bucket,
                'fastaKey': key,
                'mismatches': max_mismatches,
                'rule': rule,
                'targets': target_chunk,
            }).json()
            # Results are matched to targets purely by position. A short response
            # would silently drop organisms (breaking the later column layout),
            # and a long one would misattribute counts.
            if len(target_hits) != len(target_chunk):
                raise ValueError(
                    f'countOfftargets returned {len(target_hits)} results for '
                    f'{len(target_chunk)} targets ({offtarget_organism})'
                )
            for seq, mismatches in zip(target_chunk, target_hits):
                target_mismatches[seq][offtarget_organism] = mismatches
    output = []
    for seq, positions in target_positions.items():
        mismatches = target_mismatches[seq]
        for pos in positions:
            output.append([*pos, seq, mismatches])
    return output


def _organism_specs(organisms):
    """Build the findTargets payload section: organism -> S3 location plus regions."""
    return {
        organism: {
            's3_location': SPECIES_S3_LOCATIONS[organism],
            'regions': regions,
        }
        for organism, regions in organisms.items()
    }


def get_common_targets(in_organisms, out_organisms, rule):
    """Ask the findTargets service for candidate targets and decode its reply.

    Args:
        in_organisms: Mapping of organism name (key of ``SPECIES_S3_LOCATIONS``)
            to a list of region strings, or None. The value is forwarded
            unchanged as ``regions``.
        out_organisms: Same shape as ``in_organisms``, sent as ``out_organisms``.
            NOTE(review): None is used here for whole organisms; confirm the
            service's interpretation.
        rule: Target pattern string forwarded to the service.

    Returns:
        The decoded JSON payload. ``get_all_offtargets`` consumes it as a
        mapping of target sequence -> list of position records.

    Raises:
        KeyError: if an organism name is not in ``SPECIES_S3_LOCATIONS``.
        requests.HTTPError: propagated from ``check_request``.
    """
    find_targets_response = check_request('POST', FIND_TARGETS_URL, json={
        'in_organisms': _organism_specs(in_organisms),
        'out_organisms': _organism_specs(out_organisms),
        'rule': rule,
    })
    # The response body is base64 text wrapping a bz2-compressed JSON document.
    json_data = bz2.decompress(base64.b64decode(find_targets_response.text))
    return json.loads(json_data)

In [None]:
# Organisms sent to findTargets as 'in_organisms', each with the region strings
# to search. The strings appear to be 'contig:start-end' with either bound
# optional. NOTE(review): confirm the exact semantics with the service.
IN_ORGANISMS = {
    'B. pertussis': [
        'CP017403.1:200000-',
        'CP017403.1:400000-500000',
        'CP017403.1:-100000',
    ],
    'S. aureus': ['NC_007795.1'],
}
# Organisms sent as 'out_organisms'. None is forwarded as-is as the region filter.
OUT_ORGANISMS = {'Coronavirus': None}

# Target pattern and mismatch limit forwarded to both service endpoints.
RULE = 'xxxxxxxxxxxxxxxxxxxNGG'
MAX_MISMATCHES = 6

In [None]:
# Fetch candidate targets for the configured in/out organisms.
filtered_targets = get_common_targets(
    in_organisms=IN_ORGANISMS,
    out_organisms=OUT_ORGANISMS,
    rule=RULE,
)

In [None]:
# The target mapping may be large (off-target scans are chunked at
# MAX_TARGETS_PER_SCAN), so show its size and a small sample rather than dumping
# every entry into the notebook output.
print(f'{len(filtered_targets)} candidate targets')
dict(list(filtered_targets.items())[:5])

In [None]:
# Scan every configured organism (in and out) for off-targets of each candidate.
all_organisms = list(IN_ORGANISMS) + list(OUT_ORGANISMS)
# perf_counter is monotonic, so the elapsed time is immune to wall-clock changes.
start = time.perf_counter()
offtarget_counts = get_all_offtargets(filtered_targets, all_organisms, RULE, MAX_MISMATCHES)
print(f'Off-target scan took {time.perf_counter() - start:.1f}s')

In [None]:
# sort_values returns a new DataFrame rather than sorting in place, so the sorted
# result must be assigned for the display below to reflect the ordering.
df = format_data(offtarget_counts).sort_values(by=['Target'])
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    display(df)