In [10]:
import itertools
import os
import sys

import pandas as pd
import torch
from Bio import SeqIO, AlignIO
from torch.utils.data import Dataset

# make the project-local `mutations` module (one directory up) importable
sys.path.append('../')
import mutations as bm

In [11]:
# Report whether PyTorch can see a CUDA device and, if so, which one.
# The device queries are guarded because they raise when CUDA is unavailable.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    print(torch.cuda.current_device())
    print(torch.cuda.device_count())
    # name of the *current* device (get_device_name() defaults to it)
    print(torch.cuda.get_device_name())

0
1
Quadro P5000
True


In [2]:
# Path to the aligned GISAID FASTA. Set the GISAID_MSA_FP environment variable
# to point at a copy elsewhere; the default is the original machine-local path.
msa_fp = os.environ.get(
    'GISAID_MSA_FP',
    '/valhalla/gisaid/sequences_2021-02-21_aligned.fasta',
)
# Record ID of the reference ("patient zero") sequence.
patient_zero = 'NC_045512.2'
seqs = SeqIO.parse(msa_fp, 'fasta')

In [3]:
# Integer code per alignment symbol. Lowercase letters get the same code as
# their uppercase form; any symbol not listed (e.g. IUPAC ambiguity codes)
# falls back to 0, the same code as 'N'.
_NT_CODE = {
    symbol: value
    for base, value in {'-': -1, 'N': 0, 'A': 1, 'C': 2, 'T': 3, 'G': 4}.items()
    for symbol in (base, base.lower())
}


def encode_rna(sequence: str):
    """Encode a nucleotide sequence as a list of small integer codes.

    Mapping (case-insensitive): '-' -> -1, 'N' -> 0, 'A' -> 1, 'C' -> 2,
    'T' -> 3, 'G' -> 4; every other symbol -> 0. All codes fit in a signed
    8-bit integer, which is the dtype the encoded data is later stored in.

    Args:
        sequence: nucleotide string (any iterable of single characters works).

    Returns:
        list[int] with one code per input character.
    """
    return [_NT_CODE.get(nt, 0) for nt in sequence]

In [4]:
# Load a sample of GISAID records and encode them. The FASTA is opened afresh
# here, so re-running this cell gives the same result instead of continuing
# from wherever a shared generator was left.
N_SAMPLES = 10_001  # records 0..10000 inclusive
with open(msa_fp) as handle:
    data = [encode_rna(str(rec.seq))
            for rec in itertools.islice(SeqIO.parse(handle, 'fasta'), N_SAMPLES)]
# Add the reference. Use a dedicated parser for the lookup: reusing the sample
# iterator would start the search after the sampled records and miss the
# reference if it is among them.
# NOTE(review): assumes bm.get_seq scans an iterable of SeqRecords and returns
# the sequence as a str — confirm in mutations.py.
with open(msa_fp) as handle:
    ref = bm.get_seq(SeqIO.parse(handle, 'fasta'), patient_zero)
    data.append(encode_rna(ref))
len(data)

10002

In [5]:
# TEST: length of the longest encoded sequence (0 when `data` is empty)
max_len = max(map(len, data), default=0)
max_len

29903

In [6]:
# TEST: length of the shortest encoded sequence. Computed with min() rather
# than a hard-coded sentinel, so the answer is correct for any sequence length;
# raises ValueError if `data` is empty.
min_len = min(map(len, data))
min_len

29903

In [12]:
# Load the encoded sequences into one int8 matrix: on the current GPU when
# CUDA is available, otherwise on the CPU.
# torch needs rectangular input, so fail early with a clear message if the
# aligned sequences do not all share one length.
assert len({len(seq) for seq in data}) <= 1, 'encoded sequences differ in length'
data_tensor = torch.as_tensor(
    data,
    dtype=torch.int8,
    device=torch.cuda.current_device() if torch.cuda.is_available() else 'cpu',
)

In [13]:
# Preview the encoded matrix (torch abbreviates large tensors in its repr).
data_tensor

tensor([[-1, -1, -1,  ..., -1, -1, -1],
        [-1, -1, -1,  ...,  1,  1,  1],
        [-1, -1, -1,  ...,  1,  1,  1],
        ...,
        [-1, -1, -1,  ..., -1, -1, -1],
        [-1, -1, -1,  ..., -1, -1, -1],
        [ 1,  3,  3,  ...,  1,  1,  1]], device='cuda:0', dtype=torch.int8)

In [None]:
class SequenceData(Dataset):
    """SARS-CoV-2 consensus sequence data as a map-style PyTorch ``Dataset``.

    Records are read lazily through ``Bio.SeqIO.index``, which keeps record
    offsets rather than sequences in memory, so the FASTA file is never loaded
    wholesale. Items are returned in the index's key order.
    """

    def __init__(self, fasta_filepath, meta_filepath, transforms=None):
        """
        Args:
            fasta_filepath (string): Path to the file containing sequences.
            meta_filepath (string): Path to the gzip-compressed, tab-separated
                file containing metadata.
            transforms (callable, optional): Optional transform applied to each
                sample (a Bio SeqRecord), e.g. encode_rna().

        Raises:
            ValueError: from ``SeqIO.index`` if the FASTA has duplicate record IDs.
        """
        self.meta = pd.read_csv(meta_filepath, sep='\t', compression='gzip')
        # Random-access index (record id -> record) instead of a one-shot
        # generator: a generator supports neither len() nor indexing.
        # NOTE(review): the index keeps the file handle open for its lifetime;
        # confirm behaviour with multi-worker DataLoaders.
        self.sequences = SeqIO.index(fasta_filepath, 'fasta')
        # Give every record a stable integer position.
        self._ids = list(self.sequences.keys())
        self.transforms = transforms

    def __len__(self):
        """Number of records in the FASTA file."""
        return len(self._ids)

    def __getitem__(self, idx):
        """Return the record at position ``idx`` (after ``transforms``, if set).

        ``idx`` may be an int (negative allowed), a tensor, or a list/tuple of
        positions; the latter two return a list of samples. Out-of-range
        positions raise IndexError.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if isinstance(idx, (list, tuple)):
            return [self[i] for i in idx]
        sample = self.sequences[self._ids[idx]]
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample

In [14]:
# TODO: define transform to turn sequences into memory-efficient tensors

In [15]:
# TODO: re-implement bjorn's variant counting as tensor operations

In [16]:
class BasicNN(torch.nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear -> Softmax.

    Args:
        num_inputs: size of the input's last dimension.
        num_hidden: width of the hidden layer.
        num_outputs: number of output classes.
        softmax_dim: dimension the output softmax normalises over; defaults to
            the last (class) dimension.

    NOTE(review): if this is trained with ``torch.nn.CrossEntropyLoss`` (which
    applies log-softmax itself), feed the loss the pre-softmax logits instead.
    """

    def __init__(self, num_inputs, num_hidden, num_outputs, softmax_dim=-1):
        super().__init__()
        self.linear1 = torch.nn.Linear(num_inputs, num_hidden)
        self.activation1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(num_hidden, num_outputs)
        # An explicit dim avoids PyTorch's deprecated implicit choice, which
        # normalises over dim 0 (the batch axis) for 3-D inputs.
        self.final_activation = torch.nn.Softmax(dim=softmax_dim)

    def forward(self, x):
        """Return class probabilities for ``x``, whose last dim must be ``num_inputs``."""
        x = self.linear1(x)
        x = self.activation1(x)
        x = self.linear2(x)
        return self.final_activation(x)