In [22]:
from utils.utils import *
import matplotlib.pyplot as plt
import utils.language_helpers as lh
from utils.bio_utils import *
import numpy as np

In [23]:
codon_table = {
    'ATA':'I', 'ATC':'I', 'ATT':'I', 'ATG':'M',
    'ACA':'T', 'ACC':'T', 'ACG':'T', 'ACT':'T',
    'AAC':'N', 'AAT':'N', 'AAA':'K', 'AAG':'K',
    'AGC':'S', 'AGT':'S', 'AGA':'R', 'AGG':'R',
    'CTA':'L', 'CTC':'L', 'CTG':'L', 'CTT':'L',
    'CCA':'P', 'CCC':'P', 'CCG':'P', 'CCT':'P',
    'CAC':'H', 'CAT':'H', 'CAA':'Q', 'CAG':'Q',
    'CGA':'R', 'CGC':'R', 'CGG':'R', 'CGT':'R',
    'GTA':'V', 'GTC':'V', 'GTG':'V', 'GTT':'V',
    'GCA':'A', 'GCC':'A', 'GCG':'A', 'GCT':'A',
    'GAC':'D', 'GAT':'D', 'GAA':'E', 'GAG':'E',
    'GGA':'G', 'GGC':'G', 'GGG':'G', 'GGT':'G',
    'TCA':'S', 'TCC':'S', 'TCG':'S', 'TCT':'S',
    'TTC':'F', 'TTT':'F', 'TTA':'L', 'TTG':'L',
    'TAC':'Y', 'TAT':'Y', 'TAA':'_', 'TAG':'_',
    'TGC':'C', 'TGT':'C', 'TGA':'_', 'TGG':'W',
}

def countCorrectP(dna_seqs, verbose=False):
    global codon_table
    total = 0.0
    correct = 0.0
    for dna_seq in dna_seqs:
        p_seq = ""
        total += 1
        if dna_seq[0:3] != 'ATG':
            if verbose: print("Not valid gene (no ATG)")
            continue
        for i in range(3, len(dna_seq), 3):
            codon = dna_seq[i:i+3]
            try:
                aa = codon_table[codon]
                p_seq += aa
                if aa == '_': 
                    break
            except:
                if verbose: print("Error! Invalid Codon {} in {}".format(codon, dna_seq))
                break
        if len(p_seq) <= 2: #needs to have stop codon and be of length greater than 2
            if verbose: print("Error! Protein too short.")
        elif p_seq[-1] != '_':
            if verbose: print("Error! No valid stop codon.")
        else:
            correct+=1
    return correct/total

In [24]:
ite = [protein for (Nc,protein) in codon_table.items()] 
ite = list(set(ite))
print(ite)

['I', 'V', 'P', 'K', 'F', 'S', 'W', 'M', 'N', 'L', 'Y', 'H', 'T', 'A', 'E', 'R', '_', 'D', 'G', 'C', 'Q']


In [25]:
def calFreq(ite,lines):
    t_freq_dict = {}
    for char in ite:
        t_freq_dict[char] = 0.0
    if len(lines)== 0:
        return t_freq_dict
    total = 0
    for line in lines:
        if len(line)!= 0:
            for char in line:
                t_freq_dict[char] += 1
            total+=len(line)
    
    for char in ite:
        t_freq_dict[char]/=total
    return t_freq_dict  

In [26]:
t_freq_pre = []
for i in range(12400):
    if i % 100 == 99:
        path = './samples/realProt_50aa/sampled_val_{}.txt'.format(i)
        lines = []
        with open(path,'r') as f:
            content = f.read()
            lines = content.split('\n')
        lines = geneToProtein(lines,verbose=False)
        t_freq_pre.append(calFreq(ite,lines))


In [27]:
t_freq_fb = []
for i in range(150):
    path = './samples/fbgan_amp_demo/sample_val_{}.txt'.format(i+1)
    lines = []
    with open(path,'r') as f:
        content = f.read()
        lines = content.split('\n')
    lines = geneToProtein(lines,verbose=False)
    t_freq_fb.append(calFreq(ite,lines))

In [28]:
real_datas,_,__ = lh.load_dataset(max_length=156,max_n_examples=2048,data_dir='./data/AMP_dataset.fa')
real_lines = []
for data in real_datas:
        line = ''
        for char in data:
                if char.isalpha():
                        line += char
        real_lines.append(line)
real_lines = geneToProtein(real_lines,verbose=False)

loading dataset...
('A', 'T', 'G', 'C', 'G', 'T', 'A', 'T', 'G', 'T', 'G', 'C', 'A', 'A', 'G', 'A', 'C', 'T', 'C', 'C', 'T', 'A', 'G', 'T', 'G', 'G', 'T', 'A', 'A', 'A', 'T', 'T', 'C', 'A', 'A', 'G', 'G', 'G', 'T', 'T', 'A', 'C', 'T', 'G', 'T', 'G', 'T', 'T', 'A', 'A', 'T', 'A', 'A', 'C', 'A', 'C', 'G', 'A', 'A', 'C', 'T', 'G', 'C', 'A', 'A', 'A', 'A', 'A', 'C', 'G', 'T', 'A', 'T', 'G', 'C', 'C', 'G', 'G', 'A', 'C', 'A', 'G', 'A', 'G', 'G', 'G', 'C', 'T', 'T', 'T', 'C', 'C', 'C', 'A', 'C', 'C', 'G', 'G', 'A', 'T', 'C', 'T', 'T', 'G', 'T', 'G', 'A', 'T', 'T', 'T', 'T', 'C', 'A', 'C', 'G', 'T', 'C', 'G', 'C', 'C', 'G', 'G', 'C', 'C', 'G', 'T', 'A', 'A', 'A', 'T', 'G', 'T', 'T', 'A', 'C', 'T', 'G', 'T', 'T', 'A', 'C', 'A', 'A', 'A', 'C', 'C', 'T', 'T', 'G', 'C', 'C', 'C', 'C', 'T', 'A', 'A')
loaded 2048 lines in dataset


In [29]:
t_freq_real = [calFreq(ite,real_lines)] * len(t_freq_fb)

In [30]:
freq_fb = {}
freq_pre = {}
freq_real = {}
for char in ite:
    freq_fb[char] = []
    freq_pre[char] = []
    freq_real[char] = []
    for sample in t_freq_fb:
        freq_fb[char].append(sample[char])
    for sample in t_freq_pre:
        freq_pre[char].append(sample[char])
    for sample in t_freq_real:
        freq_real[char].append(sample[char])

In [31]:
plt.figure(figsize=(20,13))

for char in ite:
    plt.plot(freq_pre[char],label=char+' in pre',linewidth=3)
    plt.plot(freq_real[char])

plt.xlabel('Epoch')
plt.ylim(bottom=0)
plt.grid()
plt.title('Percentage of aa per chain in pre train')
plt.legend(loc='lower left')

plt.savefig('./AMP/aaDis.png')