In [19]:
import sys

import pandas as pd
import torch
from Bio import SeqIO, AlignIO, Align, Seq

sys.path.append('../')
import mutations as bm

In [20]:
# Report the CUDA setup PyTorch sees, one value per line: availability,
# active device index, device count, and the name of device 0.
for cuda_query in (torch.cuda.is_available,
                   torch.cuda.current_device,
                   torch.cuda.device_count):
    print(cuda_query())
print(torch.cuda.get_device_name(0))

True
0
1
Quadro P5000


In [21]:
# Aligned GISAID sequences (FASTA) and the ID of the reference genome.
# NOTE(review): absolute, machine-specific path — consider a configurable data dir.
msa_fp = '/valhalla/gisaid/sequences_2021-02-21_aligned.fasta'
patient_zero = 'NC_045512.2'
# NOTE(review): SeqIO.parse presumably returns a one-shot iterator, so `seqs`
# can only be walked once — confirm before reusing it in later cells.
seqs = SeqIO.parse(msa_fp, 'fasta')

In [22]:
def encode_rna(sequence: str):
    """Map each character of a nucleotide sequence to a small integer code.

    Codes: '-' -> -1, 'N' -> 0, 'A' -> 1, 'C' -> 2, 'T' -> 3, 'G' -> 4.
    Any other character (lower-case letters included) falls back to 0.

    Returns:
        list of int, one entry per character of ``sequence``.
    """
    lookup = dict(zip('-NACTG', range(-1, 5)))
    encoded = []
    for symbol in sequence:
        encoded.append(lookup.get(symbol, 0))
    return encoded

In [5]:
# Load sample data from GISAID and encode
data = []
for i, rec in enumerate(seqs):
    data.append(encode_rna(str(rec.seq)))
    # the check runs after the append, so records 0..10000 (10001 in total) are kept
    if i >= 10000:
        break
# add reference
# NOTE(review): if `seqs` is still the SeqIO.parse result, the loop above has
# already advanced it; whether bm.get_seq can still find `patient_zero`
# depends on that helper — verify.
ref = bm.get_seq(seqs, patient_zero)
data.append(encode_rna(ref))
len(data)

10002

In [6]:
# TEST: length of the longest encoded sequence (0 when `data` is empty)
max_len = max([0] + [len(encoded) for encoded in data])
max_len

29903

In [7]:
# TEST: length of the shortest encoded sequence, capped by the 100000 sentinel
min_len = min([100000] + [len(encoded) for encoded in data])
min_len

29903

In [8]:
# Copy the encoded sequences onto the currently selected CUDA device as int8
data_tensor = torch.as_tensor(
    data,
    device=torch.cuda.current_device(),
    dtype=torch.int8,
)

In [9]:
# Display the tensor to check its contents, device and dtype
data_tensor

tensor([[-1, -1, -1,  ..., -1, -1, -1],
        [-1, -1, -1,  ...,  1,  1,  1],
        [-1, -1, -1,  ...,  1,  1,  1],
        ...,
        [-1, -1, -1,  ..., -1, -1, -1],
        [-1, -1, -1,  ..., -1, -1, -1],
        [ 1,  3,  3,  ...,  1,  1,  1]], device='cuda:0', dtype=torch.int8)

In [56]:
# Re-read the same FASTA file with AlignIO. This rebinds `seqs`, replacing the
# SeqIO.parse result assigned in an earlier cell.
seqs = AlignIO.read(msa_fp, 'fasta')

In [57]:
# Rebind `seqs` once more to whatever bm.get_seqs returns
# (presumably a mapping of sample name -> sequence, given the .keys() use in a later cell — confirm).
seqs = bm.get_seqs(seqs)

In [58]:
# Number of entries returned by bm.get_seqs
len(seqs)

582244

In [67]:
# Length of the first sequence. Peek at the first key lazily instead of
# copying every key into a list just to read element 0.
len(seqs[next(iter(seqs.keys()))])

29409

In [68]:
class SequenceData(torch.utils.data.Dataset):
    """SARS-CoV-2 consensus sequence data.

    Each item is a ``(sample_id, encoded_sequence)`` pair, where the encoded
    sequence is an int8 tensor of nucleotide codes (see ``encode_rna``).
    """

    # Nucleotide -> integer code; built once instead of on every item access.
    _NUCLEOTIDE_CODES = {
        '-': -1,
        'N': 0,
        'A': 1,
        'C': 2,
        'T': 3,
        'G': 4,
    }

    def __init__(self, fasta_filepath, meta_filepath=None, transforms=None):
        """
        Args:
            fasta_filepath (string): Path to the aligned FASTA file containing sequences.
            meta_filepath (string, optional): Path to a gzip-compressed,
                tab-separated metadata file. When omitted, ``self.meta`` is None.
            transforms (callable, optional): Optional callable applied to the
                sample ID returned by ``__getitem__``.
        """
        # AlignIO.read parses the whole alignment up front (it is not lazy);
        # only the trimmed strings produced by get_seqs are kept.
        self.samples, self.sequences = self.get_seqs(AlignIO.read(fasta_filepath, 'fasta'))
        # Always define the attribute so later access cannot raise AttributeError.
        self.meta = None
        if meta_filepath:
            self.meta = pd.read_csv(meta_filepath, sep='\t', compression='gzip')
        self.transforms = transforms

    # support functions
    def get_seqs(self,
                 bio_seqs: "Align.MultipleSeqAlignment",
                 min_pos: int=265,
                 max_pos: int=29674) -> tuple:
        """Split aligned records into parallel lists of names and sequences.

        Args:
            bio_seqs: Iterable of records exposing ``.id`` and ``.seq``
                (e.g. a Bio.Align.MultipleSeqAlignment).
            min_pos: Start index (inclusive) of the slice kept from each sequence.
            max_pos: End index (exclusive) of the slice kept from each sequence.

        Returns:
            ``(samples, sequences)``: two lists of str of equal length; each
            sequence is trimmed to ``[min_pos:max_pos]``.
        """
        samples, sequences = [], []
        for row in bio_seqs:
            samples.append(str(row.id))
            s = str(row.seq)
            sequences.append(s[min_pos:max_pos])
        return samples, sequences

    def encode_rna(self, sequence):
        """Encode a nucleotide sequence as a list of integer codes.

        Characters outside ``-NACTG`` (including lower-case letters) map to 0,
        the same code as 'N'.
        """
        codes = self._NUCLEOTIDE_CODES
        return [codes.get(nt, 0) for nt in sequence]

    def __len__(self):
        """Number of sequences in the dataset."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(sample_id, int8 tensor of nucleotide codes)`` for ``idx``."""
        # Index tensors are converted to plain Python values before list indexing.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = self.samples[idx]
        sequence = torch.as_tensor(self.encode_rna(self.sequences[idx]),
                                   dtype=torch.int8)
        # process sample ID (optional)
        if self.transforms:
            sample = self.transforms(sample)
        return sample, sequence

In [50]:
# create dataset object
# (the constructor reads the whole alignment file via AlignIO.read)
dataset = SequenceData(msa_fp)

In [59]:
# Inspect the first (sample ID, encoded sequence tensor) pair
dataset[0]

('Australia/NT12/2020', tensor([1, 3, 4,  ..., 3, 1, 4], dtype=torch.int8))

In [73]:
# create dataloader object
# Wrap the SequenceData `dataset` (not the raw `data` list of encoded
# sequences) so that each batch unpacks into (sample IDs, sequence tensor).
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10000, shuffle=True)

In [74]:
# load batch of data
# presumably x holds the batch's sample IDs and y the stacked encoded sequences,
# following the (sample, sequence) order returned by SequenceData.__getitem__ — confirm
x, y = next(iter(dataloader))

In [75]:
# Show the batch as a CUDA tensor.
# NOTE(review): the result is not assigned, so `y` itself is not rebound here;
# use `y = y.cuda()` if later cells need the GPU copy.
y.cuda()

tensor([[1, 3, 4,  ..., 3, 1, 4],
        [1, 3, 4,  ..., 3, 1, 4],
        [1, 3, 4,  ..., 3, 1, 4],
        ...,
        [1, 3, 4,  ..., 3, 1, 4],
        [1, 3, 4,  ..., 3, 1, 4],
        [1, 3, 4,  ..., 3, 1, 4]], device='cuda:0', dtype=torch.int8)

In [76]:
# x

In [14]:
# TODO: define transform to turn sequences into memory-efficient tensors

In [15]:
# TODO: re-implement bjorn's variant counting as tensor operations

In [16]:
class BasicNN(torch.nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear -> Softmax.

    Args:
        num_inputs: Size of the last dimension of the input tensor.
        num_hidden: Width of the hidden layer.
        num_outputs: Number of values in each output distribution.
    """

    def __init__(self, num_inputs, num_hidden, num_outputs):
        super().__init__()
        self.linear1 = torch.nn.Linear(num_inputs, num_hidden)
        self.activation1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(num_hidden, num_outputs)
        # Normalise over the last (feature) dimension explicitly. Softmax()
        # without `dim` is deprecated, emits a warning, and picks dim 0 for
        # 3-D input, which would normalise across the batch instead.
        self.final_activation = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """Return probabilities that sum to 1 along the last dimension of ``x``."""
        x = self.linear1(x)
        x = self.activation1(x)
        x = self.linear2(x)
        prediction = self.final_activation(x)
        return prediction