In [1]:
import torch
from torch import nn
import numpy as np

from resource import ResourceType
from genetics import inference_gene, dna_transcriptor

In [2]:
from genetics import AMINOACIDS_MAPPING

In [8]:
w = "TATAAA"
w2 = "AAATAT"
w3 = "TATAAAAAATAT"
w4 = "TATAAATAGAAATAT"
w5 = "AAATATTAGTATAAA"
# Print each sequence's translated protein next to the result of inference_gene.
for sequence in (w, w2, w3, w4, w5):
    print(dna_transcriptor.transcript(sequence), inference_gene(sequence))

YK (0.3210432827472687, 1.0878145694732666)
KY (0.35069864988327026, 1.3849995136260986)
YKKY (0.3651913106441498, 0.4586832821369171)
YK$KY (0.29189637303352356, 0.7147053480148315)
KY$YK (0.3570723533630371, 0.4408493936061859)


In [3]:
# Run inference_gene on every single codon of the table (dict key order).
for codon in AMINOACIDS_MAPPING:
    print(codon, inference_gene(codon))

AAA (0.5327721238136292, 0.8989083170890808)
AAT (0.3168683648109436, 0.6960346102714539)
AAG (0.5327721238136292, 0.8989083170890808)
AAC (0.3168683648109436, 0.6960346102714539)
ATA (0.34641411900520325, 0.7110138535499573)
ATT (0.34641411900520325, 0.7110138535499573)
ATG (0.44553282856941223, 0.880810022354126)
ATC (0.34641411900520325, 0.7110138535499573)
AGA (0.35719195008277893, 1.016831636428833)
AGT (0.3200312554836273, 0.5010390877723694)
AGG (0.35719195008277893, 1.016831636428833)
AGC (0.3200312554836273, 0.5010390877723694)
ACA (0.38398969173431396, 0.593378484249115)
ACT (0.38398969173431396, 0.593378484249115)
ACG (0.38398969173431396, 0.593378484249115)
ACC (0.38398969173431396, 0.593378484249115)
TAA (0.597590982913971, 0.9058592915534973)
TAT (0.47040489315986633, 0.5021955966949463)
TAG (0.597590982913971, 0.9058592915534973)
TAC (0.47040489315986633, 0.5021955966949463)
TTA (0.4684317409992218, 0.6818724274635315)
TTT (0.3636183738708496, 1.0543146133422852)
TTG (0.

In [254]:
# Local copy of the genetic-code tables. This cell re-binds the AMINOACIDS_MAPPING
# imported from `genetics` in cell [2]; DNATranscriptor (default mapping) and
# AminoacidTokenizer (vocabulary) in the cells below read this name.
AMINOACIDS = sorted(list('FLIMVPTAYHQNKDECWRSG'))  # the 20 one-letter amino-acid codes, sorted
NUCLEOTIDES = list('ATGC')
# Codon (DNA triplet) -> one-letter amino-acid symbol; '$' marks the stop codons
# TAA, TAG and TGA.
# FIXME(review): in the standard genetic code CAA and CAG encode glutamine ('Q'),
# not histidine ('H'); as written, the 'Q' listed in AMINOACIDS is never produced
# and the tokenizer vocabulary has only 20 symbols. Correcting these two entries
# adds a 21st symbol, so GeneEncoder's nn.Embedding(num_embeddings=20) must be
# enlarged in the same change or tokens sorting after 'Q' (e.g. 'Y') go out of range.
AMINOACIDS_MAPPING = {
    'AAA': 'K',    'AAT': 'N',    'AAG': 'K',    'AAC': 'N',
    'ATA': 'I',    'ATT': 'I',    'ATG': 'M',    'ATC': 'I',
    'AGA': 'R',    'AGT': 'S',    'AGG': 'R',    'AGC': 'S',
    'ACA': 'T',    'ACT': 'T',    'ACG': 'T',    'ACC': 'T',
    'TAA': '$',    'TAT': 'Y',    'TAG': '$',    'TAC': 'Y',
    'TTA': 'L',    'TTT': 'F',    'TTG': 'L',    'TTC': 'F',
    'TGA': '$',    'TGT': 'C',    'TGG': 'W',    'TGC': 'C',
    'TCA': 'S',    'TCT': 'S',    'TCG': 'S',    'TCC': 'S',
    'GAA': 'E',    'GAT': 'D',    'GAG': 'E',    'GAC': 'D',
    'GTA': 'V',    'GTT': 'V',    'GTG': 'V',    'GTC': 'V',
    'GGA': 'G',    'GGT': 'G',    'GGG': 'G',    'GGC': 'G',
    'GCA': 'A',    'GCT': 'A',    'GCG': 'A',    'GCC': 'A',
    'CAA': 'H',    'CAT': 'H',    'CAG': 'H',    'CAC': 'H',  # FIXME(review): CAA/CAG -> 'Q' (see above)
    'CTA': 'L',    'CTT': 'L',    'CTG': 'L',    'CTC': 'L',
    'CGA': 'R',    'CGT': 'R',    'CGG': 'R',    'CGC': 'R',
    'CCA': 'P',    'CCT': 'P',    'CCG': 'P',    'CCC': 'P'
}

In [255]:
# Preview the tokenizer vocabulary: each distinct symbol of the codon table, sorted,
# paired with its index ('$' sorts first, so the stop symbol gets id 0).
items = sorted(set(AMINOACIDS_MAPPING.values()))
{symbol: index for index, symbol in enumerate(items)}

{'$': 0,
 'A': 1,
 'C': 2,
 'D': 3,
 'E': 4,
 'F': 5,
 'G': 6,
 'H': 7,
 'I': 8,
 'K': 9,
 'L': 10,
 'M': 11,
 'N': 12,
 'P': 13,
 'R': 14,
 'S': 15,
 'T': 16,
 'V': 17,
 'W': 18,
 'Y': 19}

In [158]:
class DNATranscriptor:
    """Translate a DNA string into a protein string, one codon at a time.

    `mapping` maps each 3-letter codon to a one-letter symbol. Trailing bases
    that do not form a complete codon are ignored; a codon missing from
    `mapping` raises (KeyError for a dict).
    """

    def __init__(self, mapping=AMINOACIDS_MAPPING):
        self.mapping = mapping

    def cut_triplets(self, dna):
        """Split `dna` into consecutive, non-overlapping 3-character chunks."""
        usable_length = (len(dna) // 3) * 3
        return [dna[start:start + 3] for start in range(0, usable_length, 3)]

    def transcript(self, dna):
        """Return the concatenated codon symbols for `dna`."""
        symbols = (self.mapping[codon] for codon in self.cut_triplets(dna))
        return ''.join(symbols)

    def __call__(self, dna):
        return self.transcript(dna)

In [217]:
class AminoacidTokenizer:
    """Convert a protein string into the integer tensors GeneEncoder consumes.

    The vocabulary is the sorted set of symbols appearing as values of
    AMINOACIDS_MAPPING. With the mapping used in this notebook, '$' sorts
    before the uppercase letters, so the stop symbol gets id 0.
    """

    # Symbol the codon table uses for stop codons.
    STOP_SYMBOL = '$'

    def __init__(self):
        symbols = np.unique(list(zip(*AMINOACIDS_MAPPING.items()))[1]).tolist()
        self.mapping = dict(zip(symbols, range(len(symbols))))

    def tokenize(self, x):
        """Tokenize the protein string `x`.

        Returns a dict of int64 tensors, each shaped (1, len(x)):
          'protein'  - vocabulary id of each residue,
          'position' - 0-based index of each residue,
          's-codon'  - number of stop symbols strictly before each residue.
        An empty string yields empty (1, 0) tensors. A symbol missing from the
        vocabulary raises KeyError.
        """
        # Explicit long dtype: torch.tensor([]) would otherwise be float and
        # unusable as nn.Embedding input.
        gene = torch.tensor([self.mapping[symbol] for symbol in x], dtype=torch.long)
        position = torch.arange(len(gene))
        # Look the stop id up instead of assuming it is 0; fall back to 0 as before.
        stop_id = self.mapping.get(self.STOP_SYMBOL, 0)
        is_stop = (gene == stop_id).long()
        # Exclusive prefix sum: same result as rolling the inclusive cumsum right
        # by one and zeroing slot 0, but it does not index [0], so it also works
        # for empty input.
        s_codon = torch.cumsum(is_stop, dim=0) - is_stop
        return {'protein': gene.view(1, -1), 'position': position.view(1, -1), 's-codon': s_codon.view(1, -1)}

    def __call__(self, x):
        return self.tokenize(x)

In [218]:
# Shared translator and tokenizer instances, used by encode_gene and later cells.
aminoacid_tokenizer = AminoacidTokenizer()
dna_transcriptor = DNATranscriptor()

In [219]:
# 'AAATGAAAAATA' -> 'K$KI': the TGA stop codon shows up as '$' (token id 0).
aminoacid_tokenizer.tokenize(dna_transcriptor.transcript('AAATGAAAAATA'))

{'protein': tensor([[9, 0, 9, 8]]),
 'position': tensor([[0, 1, 2, 3]]),
 's-codon': tensor([[0, 0, 1, 1]])}

In [208]:
class GeneEncoder(nn.Module):
    """Transformer encoder over tokenized proteins.

    Each residue's input vector is the sum of three learned embeddings: its
    amino-acid token, its position, and its stop-codon count (the 's-codon'
    field produced by AminoacidTokenizer). All defaults reproduce the original
    hard-coded configuration.

    Args:
        vocab_size: number of protein token ids; 20 matches the 20-symbol
            vocabulary AminoacidTokenizer builds from this notebook's mapping.
        d_model: embedding / hidden width.
        max_len: size of the position and stop-count embedding tables; inputs
            whose positions or stop counts reach max_len fail with an
            index-out-of-range error.
        nhead, dim_feedforward, dropout: forwarded to nn.TransformerEncoderLayer.
        num_layers: number of stacked encoder layers.
    """

    def __init__(self, vocab_size=20, d_model=64, max_len=128, nhead=8,
                 dim_feedforward=64, dropout=0.1, num_layers=6):
        super().__init__()
        # Submodules are created in the original order so parameter init (RNG use)
        # and state_dict keys are unchanged for the default configuration.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        # NOTE(review): padding_idx=0 keeps the position-0 vector fixed at zero with
        # no gradient, although position 0 is the first real residue, not padding.
        self.position_embedding = nn.Embedding(num_embeddings=max_len, embedding_dim=d_model, padding_idx=0)
        self.stop_codon_embedding = nn.Embedding(num_embeddings=max_len, embedding_dim=d_model)
        block = nn.TransformerEncoderLayer(d_model=d_model,
                                           nhead=nhead,
                                           dim_feedforward=dim_feedforward,
                                           dropout=dropout,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(block, num_layers)

    def forward(self, gene, position, stop_codon_alignment):
        """Encode a batch of tokenized proteins.

        Args:
            gene, position, stop_codon_alignment: integer tensors of identical
                shape (batch, seq_len), e.g. the 'protein', 'position' and
                's-codon' entries returned by AminoacidTokenizer.

        Returns:
            (outputs, pooling): per-residue encodings (batch, seq_len, d_model)
            and their mean over the sequence axis (batch, d_model). No padding
            mask is applied, so every position contributes to the mean.
        """
        embeddings = (self.embedding(gene)
                      + self.position_embedding(position)
                      + self.stop_codon_embedding(stop_codon_alignment))
        outputs = self.encoder(embeddings)
        pooling = outputs.mean(dim=1)
        return outputs, pooling
    
class Head(nn.Module):
    """Linear probe from a 64-dim pooled encoding to a single score in (0, 1)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 1)

    def forward(self, x):
        """Apply the linear layer, then squash the logit with a sigmoid."""
        logit = self.linear(x)
        return logit.sigmoid()
    
    
class FloatHead(nn.Module):
    """Linear head mapping a 64-dim pooled encoding to one unbounded value."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 1)

    def forward(self, x):
        """Return the raw linear output (no activation)."""
        prediction = self.linear(x)
        return prediction

In [209]:
# Seed torch's global RNG so the randomly initialised encoder (and the heads built
# in the next cell) get the same weights on every Restart & Run All.
torch.manual_seed(0)
model = GeneEncoder()

In [210]:
is_cool_gene = Head()
is_float_gene = FloatHead()

In [220]:
def encode_gene(gene):
    """Translate the DNA string `gene` to a protein and tokenize it for GeneEncoder."""
    protein = dna_transcriptor(gene)
    return aminoacid_tokenizer(protein)

In [221]:
def inference_gene(gene):
    """Score the DNA string `gene` with the notebook's encoder and both heads.

    Switches the global `model` to eval mode (it is left in eval mode). The
    forward pass runs without autograd, so callers no longer need their own
    torch.no_grad() block.

    Returns:
        (cool_score, float_score): outputs of `is_cool_gene` and `is_float_gene`
        on the pooled encoding, each converted to a Python float.
    """
    model.eval()
    tokens = encode_gene(gene)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        _, pooling = model(tokens['protein'], tokens['position'], tokens['s-codon'])
        return float(is_cool_gene(pooling)), float(is_float_gene(pooling))

In [216]:
probe_genes = ['AAATGA', "AAA", "AAATGAAAA", "AAAAAAAAA", "TGATGATGA", "GCGCGC"]
with torch.no_grad():
    for gene in probe_genes:
        print(f"{gene} {inference_gene(gene)}")

AAATGA (0.44808048009872437, -0.6520209908485413)
AAA (0.4091910421848297, -0.9662929773330688)
AAATGAAAA (0.5065053701400757, -0.6098647713661194)
AAAAAAAAA (0.4136034846305847, -0.7743958830833435)
TGATGATGA (0.439308762550354, -0.6520992517471313)
GCGCGC (0.24293318390846252, -0.9206771850585938)


In [239]:
extended_probe_genes = ['AAATGA', "AAA", "AAATGAAAA", "AAAAAAAAA", "TGATGATGA", "GCGCGC", "TTTTTTTTT"]
with torch.no_grad():
    for gene in extended_probe_genes:
        print(f"{gene} {inference_gene(gene)}")

AAATGA (0.315243124961853, -0.9511200785636902)
AAA (0.4091910421848297, -0.9662929773330688)
AAATGAAAA (0.40868523716926575, -0.8094353079795837)
AAAAAAAAA (0.4136034846305847, -0.7743958830833435)
TGATGATGA (0.33275094628334045, -0.8055247068405151)
GCGCGC (0.24293318390846252, -0.9206771850585938)
TTTTTTTTT (0.25903835892677307, -0.7840675711631775)


In [214]:
# 'ATATAAAATTAAAAA' translates to 'I$N$K': stop symbols at indices 1 and 3.
o = aminoacid_tokenizer.tokenize(dna_transcriptor.transcript('ATATAAAATTAAAAA'))
pr = o['protein']
po = o['position']

In [183]:
(pr == 0).cumsum(dim=1)  # inclusive running count of stop tokens (id 0)

tensor([[0, 1, 1, 2, 2]])

In [184]:
pr.shape

torch.Size([1, 5])

In [186]:
torch.nonzero(pr == 0)  # (row, column) index of every stop token

tensor([[0, 1],
        [0, 3]])

In [241]:
my_resources = ResourceType(dict(h2o=5, al=1))

In [252]:
# NOTE(review): this printed None, i.e. `my_resources * 2.5` evaluated to None.
# ResourceType lives in the project's `resource` module (not shown); its
# multiplication may modify in place and return nothing — verify before using the
# product as a value.
print(my_resources * 2.5)

None


In [250]:
int(1.0 / 0.12)  # scratch arithmetic: truncates 8.333... toward zero

8

In [249]:
 8.333333333333334

0.12