##2.1 The Baseline: A site-indep model 

### Get sequence from BLAT_ECOLX_1_b0.5_labeled.fasta

In [39]:
# parsing the FASTA file, codes from https://colab.research.google.com/github/wouterboomsma/pml_vae_project/blob/main/protein_vae_data_processing.ipynb
import os
import re
import numpy as np
import torch
import torch.nn.functional as F
import pandas as pd

try:
    from Bio import SeqIO
except:
    !pip install biopython
    from Bio import SeqIO
    
if not os.path.exists('BLAT_ECOLX_1_b0.5_labeled.fasta'):
    !wget https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_1_b0.5_labeled.fasta
        
if not os.path.exists('BLAT_ECOLX_Ranganathan2015.csv'):
    !wget https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_Ranganathan2015.csv
        
aa1_to_index = {'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 'G': 5, 'H': 6,
                'I': 7, 'K': 8, 'L': 9, 'M': 10, 'N': 11, 'P': 12,
                'Q': 13, 'R': 14, 'S': 15, 'T': 16, 'V': 17, 'W': 18,
                'Y': 19, 'X':20, 'Z': 21, '-': 22}
aa1 = "ACDEFGHIKLMNPQRSTVWYXZ-"

phyla = ['Acidobacteria', 'Actinobacteria', 'Bacteroidetes',
         'Chloroflexi', 'Cyanobacteria', 'Deinococcus-Thermus',
         'Firmicutes', 'Fusobacteria', 'Proteobacteria', 'Other']

device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
def get_baseline_data(data_filename, calc_weights=False, weights_similarity_threshold=0.8):
    ids = []
    labels = []
    seqs = []
    label_re = re.compile(r'\[([^\]]*)\]')
    for record in SeqIO.parse(data_filename, "fasta"):
        ids.append(record.id)       
        seqs.append(np.array([aa1_to_index[aa] for aa in str(record.seq).upper().replace('.', '-')]))
        
        label = label_re.search(record.description).group(1)

        if label not in phyla:
            label = 'Other'
        labels.append(label)
                
    seqs =np.vstack(seqs)
    labels = np.array(labels)
    seqs_tensor = torch.from_numpy(seqs)

    phyla_lookup_table, phyla_idx = np.unique(labels, return_inverse=True)
    dataset = torch.utils.data.TensorDataset(*[seqs_tensor, torch.from_numpy(phyla_idx)])
    
    weights = None
    if calc_weights is not False:

        # Experiencing memory issues on colab for this code because pytorch doesn't
        # allow one_hot directly to bool. Splitting in two and then merging.
        # one_hot = F.one_hot(seqs_tensor.long()).to('cuda' if torch.cuda.is_available() else 'cpu')
        one_hot1 = F.one_hot(seqs_tensor[:len(seqs_tensor)//2].long()).bool()
        one_hot2 = F.one_hot(seqs_tensor[len(seqs_tensor)//2:].long()).bool()
        one_hot = torch.cat([one_hot1, one_hot2]).to(device)
        assert(len(seqs_tensor) == len(one_hot))
        del one_hot1
        del one_hot2
        one_hot[seqs_tensor>19] = 0
        flat_one_hot = one_hot.flatten(1)

        weights = []
        weight_batch_size = 1000
        flat_one_hot = flat_one_hot.float()
        for i in range(seqs_tensor.size(0) // weight_batch_size + 1):
            x = flat_one_hot[i * weight_batch_size : (i + 1) * weight_batch_size]
            similarities = torch.mm(x, flat_one_hot.T)
            lengths = (seqs_tensor[i * weight_batch_size : (i + 1) * weight_batch_size] <=19).sum(1).unsqueeze(-1).to('cuda' if torch.cuda.is_available() else 'cpu')
            w = 1.0 / (similarities / lengths).gt(weights_similarity_threshold).sum(1).float()
            weights.append(w)
            
        weights = torch.cat(weights)
        neff = weights.sum()
    return seqs, labels, weights, phyla_lookup_table, phyla_idx, dataset

In [40]:
seqs, labels, weights, phyla_lookup_table, phyla_idx, dataset=get_baseline_data('BLAT_ECOLX_1_b0.5_labeled.fasta',calc_weights=True)

### Baseline

In [8]:
from collections import Counter
def my_log(data):
  if data > 0:
    result = np.log(data)
  else:
    result = 0
  return result

class base1ine(object):
    def __init__(self, pseudo_count = 1):
        self.pseudo_count = pseudo_count
        self.freqs = []

    def get_freqs(self, seq_data):
        for position in range(seq_data.shape[1]):
            freq_aa_in_position = {}
            aa_in_position = seq_data[:, position]
            count_aa = Counter(aa_in_position)
            num_aa = len(np.unique(aa_in_position))
            for i in range(23):
                freq_aa_in_position[i] = (count_aa[i] + self.pseudo_count) / (seq_data.shape[0] + num_aa * self.pseudo_count)
            self.freqs.append(freq_aa_in_position)

    def get_P_of_seqs(self, seq_data):   
        P_of_seqs = []
        for seq in seq_data:
            P_of_seq= 0
            for i, aa in enumerate(seq):
                P_of_seq += my_log(self.freqs[i][aa])
            P_of_seqs.append(P_of_seq)
            P_of_seqs = np.array(P_of_seqs)
        return P_of_seqs

In [9]:
baseline = base1ine(pseudo_count=1)

### Get Experimental data

In [45]:
def read_experimental_data(filename, alignment_data, measurement_col_name = '2500', sequence_offset=0):
    
    measurement_df = pd.read_csv(filename, delimiter=',', usecols=['mutant', measurement_col_name])
    
    wt_sequence, wt_label = alignment_data[0]
    
    zero_index = None
    
    experimental_data = {}
    for idx, entry in measurement_df.iterrows():
        mutant_from, position, mutant_to = entry['mutant'][:1],int(entry['mutant'][1:-1]),entry['mutant'][-1:]  
       
        if zero_index is None:
            zero_index = position

        seq_position = position-zero_index+sequence_offset
            
        assert mutant_from == aa1[wt_sequence[seq_position]]  
        
        if seq_position not in experimental_data:
            experimental_data[seq_position] = {}
        
        assert mutant_to not in experimental_data[seq_position]
        
        experimental_data[seq_position]['pos'] = seq_position
        experimental_data[seq_position]['WT'] = mutant_from
        experimental_data[seq_position][mutant_to] = entry[measurement_col_name]
    
    experimental_data = pd.DataFrame(experimental_data).transpose().set_index(['pos', 'WT'])
    return experimental_data
        
experimental_data = read_experimental_data("BLAT_ECOLX_Ranganathan2015.csv", dataset)

### Result

In [70]:
baseline.get_freqs(seqs)

In [63]:
import copy
raw_sequence = [seqs[0]]
log_P_of_wt = baseline.get_P_of_seqs(raw_sequence)
experiment_value = []
predicted_value = []
for (position, mutant_from), row in experimental_data.iterrows():
    assert aa1_to_index[mutant_from] == raw_sequence[0][position]
    for mutant_to, exp_value in row.iteritems():
        if mutant_to != mutant_from:
            new_sequence = copy.deepcopy(raw_sequence)
            new_sequence[0][position] = aa1_to_index[mutant_to]
            experiment_value.append(exp_value)
            log_P_of_mt = baseline.get_P_of_seqs(new_sequence)
            predicted_value.append(-(log_P_of_wt - log_P_of_mt))

In [64]:
from scipy.stats import spearmanr
spearmanr(experiment_value, predicted_value)

SpearmanrResult(correlation=0.6061777186788756, pvalue=0.0)

### Weighted result

In [73]:
dataloader_weighted = torch.utils.data.DataLoader(dataset, batch_size=16, sampler=torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples=len(dataset)))

In [74]:
weighed_seqs=[]
for i in dataloader_weighted:
  weighed_seqs.append(i[0][0].cpu().detach().numpy())
weighed_seqs = np.array(weighed_seqs)

In [81]:
baseline_weighted = base1ine(pseudo_count=1)
baseline_weighted.get_freqs(weighed_seqs)

In [84]:
import copy
raw_sequence = [seqs[0]]
log_P_of_wt = baseline_weighted.get_P_of_seqs(raw_sequence)
experiment_value = []
predicted_value = []
for (position, mutant_from), row in experimental_data.iterrows():
    assert aa1_to_index[mutant_from] == raw_sequence[0][position]
    for mutant_to, exp_value in row.iteritems():
        if mutant_to != mutant_from:
            new_sequence = copy.deepcopy(raw_sequence)
            new_sequence[0][position] = aa1_to_index[mutant_to]
            experiment_value.append(exp_value)
            log_P_of_mt = baseline_weighted.get_P_of_seqs(new_sequence)
            predicted_value.append(-(log_P_of_wt - log_P_of_mt))

In [85]:
from scipy.stats import spearmanr
spearmanr(experiment_value, predicted_value)

SpearmanrResult(correlation=0.5960609586606334, pvalue=0.0)