In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
plt.ion()

import logomaker as lm



In [None]:
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.Align import MultipleSeqAlignment

def parse_fasta(lines):
    """Parse FASTA text into a list of ``(identifier, sequence)`` pairs.

    Args:
        lines: iterable of text lines (e.g. an open file object).

    Returns:
        One ``(identifier, sequence)`` tuple per ``>`` header, in file order.
        The identifier is the last ``/``-separated component of the header
        (with the leading ``>`` removed). All sequence lines following a header
        are stripped and concatenated, so line-wrapped FASTA is supported.
        Blank lines and lines starting with ``#`` are skipped, as are sequence
        lines that appear before the first header.
    """
    records = []
    ident, chunks = None, []
    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        if line.startswith(">"):
            if ident is not None:
                records.append((ident, "".join(chunks)))
            ident, chunks = line[1:].split("/")[-1], []
        elif ident is not None:
            chunks.append(line)
    if ident is not None:
        records.append((ident, "".join(chunks)))
    return records


# `msa_file` is assigned in the configuration cell further down, so on a fresh
# kernel (Restart & Run All) it does not exist yet when this cell runs.
if "msa_file" in globals():
    with open(msa_file, "r") as f:
        msa = [SeqRecord(Seq(seq), id=ident) for ident, seq in parse_fasta(f)]
else:
    print("msa_file is not defined yet; run the configuration cell, then re-run this cell")
    msa = []

In [None]:
# --- Configuration -----------------------------------------------------------
# MSA of designed sequences to plot (one chain only); uncomment the one you need.
#msa_file="pmpnn_design/alpha_chain_top100.fasta"
msa_file = "esm_design/alpha_chain_top100.fasta"
#msa_file="esm_design/beta_chain_top100.fasta"

# FASTA with the native sequence (one chain only). It should have the same
# starting point and length as the designed sequences.
native_fasta = "chainD.fasta"

# Residues to display in the logo plot, as 1-based residue numbers, inclusive.
# Downstream cells convert each number i to a 0-based index with `i - 1`.
start = 10
end = 20
assert 1 <= start <= end, "start/end are 1-based, inclusive residue numbers"

res_range = list(range(start, end + 1))

# ...or list specific (1-based) residue numbers instead:
#res_range = [10,11,66,68,69,70,71,73,74,75,110,112]
#res_range = [8,10,11,66, 69,77,88,110,112]

# Read the designed MSA and build a per-position residue count matrix.
# Lines containing '>' (headers) or '#' (comments) are dropped, and every other
# non-empty line is treated as one aligned sequence, so the FASTA must not be
# line-wrapped.
with open(msa_file, "r") as f:
    raw_seq = f.readlines()

seqs = [seq.strip() for seq in raw_seq if "#" not in seq and ">" not in seq]
seqs_clean = [seq for seq in seqs if seq]
assert seqs_clean, f"no sequences found in {msa_file}"
# All rows must be the same length to form an alignment; differing lengths
# usually mean the FASTA is line-wrapped.
assert len({len(seq) for seq in seqs_clean}) == 1, (
    f"sequences in {msa_file} differ in length; is the FASTA line-wrapped?"
)

counts_mat = lm.alignment_to_matrix(seqs_clean)

# Read the native sequence and pick the residues used as bottom-axis labels.
# The native FASTA should have the same starting point and length as the designed
# sequences, so 1-based residue number i corresponds to o_seq[i - 1].
with open(native_fasta, "r") as fasta:
    native_lines = [l.strip() for l in fasta]

# Concatenate the sequence lines of the first record so that line-wrapped FASTA
# works. Reading stops at a second header, because only one chain is expected.
native_seq_lines = []
for l in native_lines:
    if l.startswith(">"):
        if native_seq_lines:
            break
        continue
    native_seq_lines.append(l)
o_seq = list("".join(native_seq_lines))

#res_range = list(range(1, len(o_seq) + 1))   # uncomment to plot the whole chain

# Guard against residue numbers that would wrap around via negative indexing
# or fall off the end of the native sequence.
out_of_range = [i for i in res_range if not 1 <= i <= len(o_seq)]
assert not out_of_range, (
    f"residues {out_of_range} are outside the native sequence (1..{len(o_seq)})"
)
o_seq_res = [o_seq[i - 1] for i in res_range]

# Keep only the selected residues (1-based numbers -> 0-based matrix rows), then
# renumber the rows 0..n-1 so they line up with the logo's x positions.
missing = [x for x in res_range if x - 1 not in counts_mat.index]
assert not missing, (
    f"residues {missing} are outside the alignment (length {len(counts_mat)})"
)
counts_mat_range = counts_mat.loc[[x - 1 for x in res_range]].reset_index(drop=True)

# Residue numbers in res_range are already 1-based, so they are printed as-is.
native_labels = [f"{i}{aa}" for i, aa in zip(res_range, o_seq_res)]
print(f"the native residues displayed in this logo plot are {native_labels}")

In [None]:
# Output file stem; '.png' and '.svg' are appended when the figure is saved.
#output = 'pmpnn_design_top100_alpha_logoplot'
output = "esm_design_top100_alpha_completelogoplot"  # change to output name

cm = 1 / 2.54  # matplotlib figure sizes are in inches; this converts cm to inches
fig, ax = plt.subplots(figsize=[10 * cm, 5 * cm])
seq_logo = lm.Logo(
    counts_mat_range,
    ax=ax,
    stack_order="small_on_top",
    color_scheme="black",
    vpad=0.02,
    # fade_probabilities=True,
    # font_name='Times New Roman',
    # color_scheme='grays',
)

# The matrix holds counts, not probabilities, so label the baseline and the
# tallest stack rather than hard-coding ticks at 0 and 1.
y_max = counts_mat_range.sum(axis=1).max()
seq_logo.ax.set_yticks([0, y_max])
seq_logo.ax.set_yticklabels(["0", f"{y_max:g}"], size=16)

# Bottom axis: native residue letters; top axis: residue numbers.
positions = range(len(res_range))
seq_logo.ax.set_xticks(positions)
seq_logo.ax.set_xticklabels(o_seq_res, size=16)
secax = ax.secondary_xaxis("top")
secax.set_xticks(positions)
secax.set_xticklabels(res_range, size=16)
secax.set_xlabel("Residue numbers", size=18, labelpad=10)
ax.set_xlabel("Wildtype amino acid", size=18, labelpad=10)
ax.set_ylabel("count", size=18, labelpad=10)
for axis in (ax, secax):
    axis.tick_params(axis="both", which="both", length=0)


# Residues to highlight, given as residue numbers from res_range (the numbers on
# the top axis), e.g. highlights = [66, 69]. Logo x positions are 0..n-1 in
# res_range order, so each residue number is mapped to its position first.
highlights = []
unknown = [r for r in highlights if r not in res_range]
assert not unknown, f"highlighted residues {unknown} are not in res_range"

highlight_pos = [res_range.index(r) for r in highlights]
nothighlights = [p for p in range(len(res_range)) if p not in highlight_pos]
for p in highlight_pos:
    seq_logo.highlight_position(p=p, color='#9932CC', alpha=.5)
for p in nothighlights:
    seq_logo.highlight_position(p=p, color='#FFFFFF', alpha=.5)

fig.tight_layout()
fig.savefig(f'{output}.png', dpi=300)  # raster copy
fig.savefig(f'{output}.svg', dpi=300)  # vector copy