In [39]:
import pandas as pd
import numpy as np
import torch

import pickle

from Bio import SeqIO
from utils.utils import read_json

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Nextclade results table (tab-separated). `sep` must be given by keyword:
# pandas >= 2.0 accepts only the file path positionally in read_csv.
clades = pd.read_csv("data/output/nextclade.tsv", sep="\t")

In [82]:
# Load every nucleotide record from the NCBI FASTA export into memory.
genomic_records = [*SeqIO.parse("data/ncbidata/genomic.fna", "fasta")]

# Translate each nucleotide SeqRecord into a protein SeqRecord (same order as above).
# NOTE(review): SeqRecord.translate() with default arguments reads from the first base
# (frame 1) and does not stop at stop codons (they appear as '*'); confirm this is
# intended for whole-genome input.
protein_records = [nucleotide.translate() for nucleotide in genomic_records]

In [83]:
# Token vocabulary: the special tokens first (so "<pad>" is index 0), then every residue
# character occurring in the translated records, in sorted order.
# The order is made explicit because iterating a plain set of strings is not stable
# across interpreter runs (str hashing is randomized), which would make the pickled
# `to_ix` mapping differ every time the notebook is re-run.
special_tokens = ["<pad>", "<sos>", "<eos>"]
residues = sorted({residue for record in protein_records for residue in str(record.seq)})

# All tokens as a set (same contents as before: residues plus special tokens).
vocab = set(special_tokens).union(residues)

to_ix = {token: i for i, token in enumerate(special_tokens + residues)}
inv_to_ix = {i: token for token, i in to_ix.items()}

In [84]:
class BiologicalSequenceDataset(torch.utils.data.Dataset):
    """Map-style dataset that encodes each sequence as a 1-D tensor of token indices.

    Args:
        sequences: indexable collection of sequences (e.g. ``str``); each item is
            iterated element by element and every element is looked up in the vocabulary.
        vocab: optional token -> index mapping. When ``None`` (the default) the
            notebook-level ``to_ix`` mapping is used, looked up at access time.
    """

    def __init__(self, sequences, vocab=None):
        self.records = sequences
        self.vocab = vocab

    def __len__(self):
        return len(self.records)

    def __getitem__(self, i):
        """Return sequence ``i`` encoded as an int64 tensor of vocabulary indices.

        Raises:
            KeyError: if the sequence contains a token missing from the vocabulary.
        """
        # getattr keeps instances pickled before `vocab` existed loadable.
        vocab = getattr(self, "vocab", None)
        if vocab is None:
            vocab = to_ix
        # Explicit dtype: without it an empty sequence would become a float32 tensor,
        # inconsistent with the int64 index tensors produced for non-empty sequences.
        return torch.tensor([vocab[residue] for residue in self.records[i]], dtype=torch.long)

def collate_fn(batch, padding_value=None):
    """Pad a list of 1-D index tensors into one ``(len(batch), max_len)`` tensor.

    Args:
        batch: list of 1-D tensors (e.g. items of ``BiologicalSequenceDataset``).
        padding_value: index written past the end of shorter sequences. Defaults to
            ``to_ix["<pad>"]`` from the notebook vocabulary, so ``DataLoader`` can keep
            calling ``collate_fn(batch)`` with a single argument.

    Returns:
        A tensor with one row per input sequence, right-padded to the longest length.
    """
    if padding_value is None:
        padding_value = to_ix["<pad>"]
    return torch.nn.utils.rnn.pad_sequence(
        batch,
        batch_first=True,
        padding_value=padding_value,
    )

In [85]:
# Drop sequences assigned to the "recombinant" clade from BOTH the clade table and
# protein_records. The two are matched by position (row i <-> record i), so one boolean
# mask is applied to both, and the table's index is reset afterwards so that index
# labels equal list positions again (the pairing code indexes protein_records with them).
# NOTE(review): positional matching assumes nextclade.tsv rows are in the same order as
# the records in genomic.fna (Nextclade only guarantees this with `--in-order`); joining
# on the `seqName` column would be safer — confirm.
assert len(clades) == len(protein_records), "clade table and protein records are out of sync"

is_recombinant = (clades["clade"] == "recombinant").to_numpy()
protein_records = [record for record, drop in zip(protein_records, is_recombinant) if not drop]
clades = clades.loc[~is_recombinant].reset_index(drop=True)

In [23]:
# Distinct clade labels, in order of first appearance.
pd.unique(clades["clade"])

array(['22B (Omicron)', '22C (Omicron)', '21L (Omicron)', '22A (Omicron)',
       '21K (Omicron)', '21J (Delta)', '20A', '21M (Omicron)',
       '21A (Delta)', '20B'], dtype=object)

In [86]:
# Clade transition map, loaded with the project helper `read_json`. The pairing cell below
# iterates its keys as parent clades and each value as a collection of child clade names.
clades_in_out_path = "data/clade_in_clade_out.json"
clades_in_out = read_json(clades_in_out_path)

Create (parent, child) protein-sequence pairs for every parent → child clade transition.

In [87]:
# Build every (parent, child) protein-sequence pair for each parent -> child clade
# transition listed in clades_in_out. Rows are located by *position*, not index label,
# so they line up with protein_records even if the clade table's index has gaps.
clade_labels = clades["clade"].to_numpy()
# Convert each record to str once; pairs then share these string objects.
protein_strings = [str(record.seq) for record in protein_records]
assert len(clade_labels) == len(protein_strings), "clade table and protein records are out of sync"


def _sequences_in_clade(clade):
    """Return the protein sequences (as str) of every record assigned to `clade`."""
    return [protein_strings[pos] for pos in np.flatnonzero(clade_labels == clade)]


parents = []
children = []
for parent_clade, child_clades in clades_in_out.items():
    parent_seqs = _sequences_in_clade(parent_clade)  # invariant across child clades
    for child_clade in child_clades:
        child_seqs = _sequences_in_clade(child_clade)
        # Cartesian product: every parent sequence is paired with every child sequence.
        for parent_seq in parent_seqs:
            for child_seq in child_seqs:
                parents.append(parent_seq)
                children.append(child_seq)
                

In [88]:
# Parent and child loaders must yield matching batches (item k of `parents` is paired
# with item k of `children`), so both use the same batch size and no shuffling.
BATCH_SIZE = 2


def make_sequence_loader(sequences, batch_size=BATCH_SIZE):
    """Wrap `sequences` in an order-preserving DataLoader of padded index tensors."""
    return torch.utils.data.DataLoader(
        BiologicalSequenceDataset(sequences),
        batch_size=batch_size,
        shuffle=False,  # keep order so parent/child batches stay aligned
        collate_fn=collate_fn,
    )


training_parents = make_sequence_loader(parents)
training_children = make_sequence_loader(children)

In [89]:
# Persist the loaders and the vocabulary. Each file is opened in a `with` block so it is
# flushed and closed deterministically.
# NOTE: pickle stores classes and functions by qualified name, so unpickling the loaders
# requires `BiologicalSequenceDataset` and `collate_fn` to be defined under the same names.
for path, obj in (
    ("data/training_parents.p", training_parents),
    ("data/training_children.p", training_children),
    ("data/to_ix.p", to_ix),
):
    with open(path, "wb") as fh:
        pickle.dump(obj, fh)