In [None]:
import pathlib

from Bio.PDB import PDBParser
from Bio.Data.IUPACData import protein_letters_3to1

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib import ticker
from matplotlib.patches import Circle, Rectangle

import numpy as np
import pandas as pd

In [None]:
# Font and SVG export settings so figure text stays editable as text in
# Designer/Illustrator
mpl.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'svg.fonttype': 'none',  # write text as SVG text, not glyph outlines
})
# To keep TrueType (Type 42) fonts in PDF exports as well, also set:
# mpl.rcParams['pdf.fonttype'] = 42

In [None]:
# Color palette (colour-vision-deficiency friendly)
blue, red, gray = '#005AB5', '#DC3220', '#D0D0D0'

In [None]:
# Record library versions (matplotlib, then pandas) for reproducibility
for _lib in (mpl, pd):
    print(_lib.__version__)

In [None]:
def cmap_discretize(cmap, N):
    """Return a discrete colormap from the continuous colormap cmap.

    The result has N equal-width bands of constant color; band i takes the
    color of ``cmap`` at position i / (N - 1) (position 0 when N == 1). Only
    red/green/blue are segmented, so any alpha in ``cmap`` is dropped.

        cmap: colormap instance (e.g. cm.jet) or registered colormap name
            (e.g. 'Blues_r').
        N: number of colors; must be >= 1.

    Returns a LinearSegmentedColormap named '<cmap.name>_<N>'.
    Raises ValueError if N < 1.
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    if isinstance(cmap, str):
        # pyplot.get_cmap works across matplotlib versions; matplotlib.cm.get_cmap
        # was deprecated in 3.7 and removed in 3.9.
        cmap = plt.get_cmap(cmap)
    # N evenly spaced samples plus trailing zero padding (classic cookbook
    # recipe). The padding is only read via index -1 (at i=0) and index N
    # (at i=N), i.e. for the unused outer side of the two end points.
    colors_i = np.concatenate((np.linspace(0, 1., N), (0., 0., 0., 0.)))
    colors_rgba = cmap(colors_i)
    indices = np.linspace(0, 1., N + 1)
    cdict = {}
    for ki, key in enumerate(('red', 'green', 'blue')):
        # Rows are (x, value left of x, value right of x): a step at every
        # boundary, constant colors_rgba[i] on [indices[i], indices[i+1]].
        cdict[key] = [(indices[i], colors_rgba[i - 1, ki], colors_rgba[i, ki])
                      for i in range(N + 1)]
    # Return colormap object.
    return mpl.colors.LinearSegmentedColormap(cmap.name + "_%d" % N, cdict, 1024)

In [None]:
def cm2inch(*tupl):
    """Convert centimetre values to inches.

    Accepts either separate values, cm2inch(18, 12), or a single tuple,
    cm2inch((18, 12)); in the tuple form any further arguments are ignored.
    Returns a tuple of inch values in the same order.
    """
    values = tupl[0] if isinstance(tupl[0], tuple) else tupl
    return tuple(v / 2.54 for v in values)

In [None]:
# Scientific name -> common name (the common names label the species rows
# of the figure)
namedict1 = {
    'Homo sapiens': 'Human',
    'Anas platyrhynchos': 'Duck',
    'Bos taurus': 'Cow',
    'Camelus dromedarius': 'Dromedary',
    'Canis lupus familiaris': 'Dog',
    'Capra hircus': 'Goat',
    'Carassius auratus': 'Goldfish',
    'Cavia porcellus': 'Guinea pig',
    'Columba livia': 'Pigeon',
    'Crocodylus porosus': 'Crocodile',
    'Equus asinus': 'Donkey',
    'Equus caballus': 'Horse',
    'Erinaceus europaeus': 'Hedgehog',
    'Felis catus': 'Cat',
    'Gallus gallus': 'Chicken',
    'Macaca mulatta': 'Macaque',
    'Manis javanica': 'Pangolin',
    'Mesocricetus auratus': 'Hamster',
    'Mus musculus': 'Mouse',
    'Mustela putorius furo': 'Ferret',
    'Oryctolagus cuniculus': 'Rabbit',
    'Ovis aries': 'Sheep',
    'Paguma larvata': 'Civet',
    'Pan troglodytes': 'Chimpanzee',
    'Panthera tigris altaica': 'Siberian Tiger',
    'Pongo abelii': 'Orangutan',
    'Rattus norvegicus': 'Rat',
    'Rhinolophus sinicus': 'Horseshoe Bat',
    'Serinus canaria': 'Canary',
    'Sus scrofa': 'Pig'
}

# Same mapping keyed by 'genus_species' (lower case, first two words joined
# by an underscore) -- the form used for the species directory names and
# the 'species' column of agg_hs.csv
namedict = {
    '_'.join(k.lower().split()[:2]): v
    for k, v in namedict1.items()
}
# Truncating to two words (dropping subspecies) must not merge two entries
assert len(namedict) == len(namedict1), "genus_species keys collide"

In [None]:
# SARS-CoV-2 positive/negative species, listed by scientific name and
# converted to common names (the labels used for the species rows).
# Indexing namedict1 directly makes a misspelled name raise KeyError
# instead of silently becoming None.
positive = [namedict1[name] for name in [
    'Homo sapiens',
    'Felis catus',
    'Manis javanica',
    'Mesocricetus auratus',
    'Mustela putorius furo',
    'Paguma larvata',
    'Panthera tigris altaica',
    'Rhinolophus sinicus',
    'Bos taurus',
    'Ovis aries',
    'Camelus dromedarius',
    'Oryctolagus cuniculus',
    'Equus caballus',
]]

negative = [namedict1[name] for name in [
    'Anas platyrhynchos',
    'Gallus gallus',
    'Mus musculus',
    'Cavia porcellus',
    'Rattus norvegicus',
]]

In [None]:
# Resolve paths relative to the notebook's working directory; the energy
# analysis lives in a sibling 'refinement' directory
rootdir = pathlib.Path().resolve(strict=True)
datadir = rootdir.parent.joinpath('refinement', 'energy_analysis_5A')

In [None]:
# Read in PDB files for each species:
# pdbdict[common name][chain id][residue number] -> one-letter amino acid code
parser = PDBParser(QUIET=True)
# Sorted for a reproducible order: if one species directory holds several
# matching files, the last one in this order is the one kept in pdbdict.
pdbs = sorted(datadir.rglob('*1.pdb'))
pdbdict = {}
for pdb in pdbs:
    # The parent directory name is the 'genus_species' key of namedict
    species = namedict.get(pdb.parent.name)
    if species is None:
        # Name the offending directory in the error
        raise KeyError(
            f"No common name in namedict for directory {pdb.parent.name!r} ({pdb})"
        )
    chaindict = {}
    structure = parser.get_structure('x', str(pdb))
    for chain in structure.get_chains():
        residues = chaindict[chain.id] = {}
        for res in chain:
            resn = protein_letters_3to1.get(res.resname.capitalize())
            if resn is None:
                # Not in protein_letters_3to1 (e.g. water or a ligand): skip
                continue
            residues[res.id[1]] = resn  # keyed by residue sequence number
    pdbdict[species] = chaindict

In [None]:
# Per-species residue scores (columns named chain letter + residue number),
# indexed by 'genus_species' names; rows are relabelled with common names.
df = pd.read_csv(datadir / 'agg_hs.csv', index_col='species')
# Fail loudly on an unknown species: a .get-based mapping would silently give
# a None row label, and every residue annotation for it would fall back to 'X'.
_unknown_species = [s for s in df.index if s not in namedict]
if _unknown_species:
    raise KeyError(f"No common name in namedict for: {_unknown_species}")
df.index = [namedict[s] for s in df.index]
df.head(2)

In [None]:
# Blank out exact 0.0 entries (they become NaN), then drop residue columns
# that are NaN for every species. Note that the remaining columns keep NaN
# wherever the original value was 0.0.
df = df.where(df != 0.0).dropna(axis='columns', how='all')

In [None]:
# Per-species scores from scores.dat (rows keyed by scientific name); the
# 'HS' column is used next to order the heatmap rows
df2 = pd.read_csv(datadir.parent / 'scores.dat', index_col='Species')
df2.index = [namedict1.get(name) for name in df2.index]
df2.head(2)

In [None]:
# Order species rows by their HS score (ascending; species without a score
# go last), then drop the helper column again
df = (
    df.assign(HS=df2['HS'])
      .sort_values(by='HS')
      .drop(columns=['HS'])
)
df.head(2)

In [None]:
# Pick columns for each protein by the chain letter that prefixes each
# column name (the same col[0] convention the annotation step parses):
# chain B -> ACE2, chain E -> viral RBD. startswith avoids matching the
# letter elsewhere in a column name.
ace2_cols = [c for c in df.columns if c.startswith('B')]
vrbd_cols = [c for c in df.columns if c.startswith('E')]

ace2 = df[ace2_cols]
vrbd = df[vrbd_cols]

In [None]:
# Make dataframes with residue names for annotation: same labels as the
# score frames, each cell holding the one-letter residue code at that
# chain/residue in the species' structure, or 'X' when it is not found.
def _residue_annotations(scores):
    """Return an object-dtype frame of one-letter residue codes for `scores`.

    Column names are parsed as chain letter + residue number; row labels are
    species names used as keys of `pdbdict`. A missing species, chain or
    residue gives 'X'.
    """
    def lookup(species, col):
        chain, resid = col[0], int(col[1:])
        return pdbdict.get(species, {}).get(chain, {}).get(resid, 'X')

    # Built directly as object dtype: writing letters cell by cell into a
    # copy of the float score frame relies on implicit upcasting, which
    # pandas >= 2.1 warns is deprecated and will raise in a future version.
    return pd.DataFrame(
        [[lookup(species, col) for col in scores.columns] for species in scores.index],
        index=scores.index,
        columns=scores.columns,
        dtype=object,
    )


ace2_annot = _residue_annotations(ace2)
vrbd_annot = _residue_annotations(vrbd)
vrbd_annot.head(2)

In [None]:
# Inspect the retained residue columns (chain letter + residue number)
df.keys()

## Figure

In [None]:
fig, ax1 = plt.subplots(
    nrows=1, ncols=1,
    figsize=cm2inch(18, 12),  # w,h
#     dpi=600  # for viewing here on notebook
)

# Set to True to color species tick labels: blue if in `positive`,
# red if in `negative`, black otherwise.
COLOR_YLABELS = False


def _is_blank(value):
    """Return True for heatmap cells that get neither a square nor a label."""
    # Zeros may already have been turned into NaN upstream (df[df != 0.0]
    # when dropping all-zero columns), so both must be skipped.
    return pd.isna(value) or value == 0.0


# ACE2 contributions with 0.0 blanked to NaN, so zeros are ignored by the
# median and left empty in the heatmap
ace2_nz = ace2.where(ace2 != 0.0)

# Median over species (NaN ignored) as a one-row frame labelled 'Median'
ave_data = ace2_nz.median().to_frame().T
ave_data.rename({0: 'Median'}, inplace=True)

# Blank copy of the median row: it is drawn as circles below, not squares
ave_data_blank = ave_data.copy(deep=True)
ave_data_blank.loc['Median'] = np.nan

# Merge: row 0 = median placeholder, rows 1.. = species
data = pd.concat([ave_data_blank, ace2_nz])

# Create discrete colormap
discretized_cmap = cmap_discretize('Blues_r', 5)

# Per species heatmap. Its pixels are hidden at the end; the image is kept
# for its cmap/norm (used to color the patches) and for the colorbar.
hm = ax1.imshow(
    data.values,
    cmap=discretized_cmap,
    interpolation=None
)

# Remove spines and make white grid
for spine in ax1.spines.values():
    spine.set_visible(False)

ax1.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
ax1.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
ax1.grid(which="minor", color="w", linestyle='-', linewidth=0.5)
ax1.tick_params(which="minor", bottom=False, left=False)

# Set axes labels
ax1.set_xticks(np.arange(data.shape[1]))
ax1.set_yticks(np.arange(data.shape[0]))

# Turn first row to circles for distinction
for col, v in enumerate(ave_data.iloc[0]):
    ax1.add_patch(Circle(
        (col, 0),
        radius=0.25,
        color=hm.cmap(hm.norm(v))
    ))

# Replace imshow with individual squares
# for editability in illustrator/etc
nr, nc = data.shape
for x in range(nc):
    for y in range(1, nr):  # row 0 is the median, drawn as circles above
        v = data.iloc[y, x]
        if _is_blank(v):
            continue
        ax1.add_patch(Rectangle(
            (x - .5, y - .5),
            width=1,
            height=1,
            color=hm.cmap(hm.norm(v)),
            linewidth=0
        ))


# Add amino acid labels
# Pick font color based on luminance of background
# from seaborn source
def get_font_color(data, px, py):
    """Return 'k' or 'w' for text drawn on the heatmap color of data.iloc[px, py].

    Black on light cells, white on dark cells, using a relative-luminance
    threshold (as in seaborn's heatmap annotation).
    """
    v = data.iloc[px, py]
    color = hm.cmap(hm.norm(v))

    rgb = mpl.colors.to_rgba_array(color)[:, :3]
    # sRGB -> linear RGB, then relative luminance
    rgb = np.where(rgb <= .03928, rgb / 12.92, ((rgb + .055) / 1.055) ** 2.4)
    lum = rgb.dot([.2126, .7152, .0722])
    return "k" if lum > .408 else "w"


nx, ny = ace2.shape
for i in range(nx):
    for j in range(ny):
        if _is_blank(ace2_nz.iloc[i, j]):
            continue
        ax1.text(
            j, i + 1, ace2_annot.iloc[i, j],  # i + 1: row 0 of data is the median
            ha="center", va="center",
            color=get_font_color(data, i + 1, j),
            size=7.5
        )

# Handle labels
ax1.xaxis.set_ticks_position('none')
ax1.yaxis.set_ticks_position('none')

xlabels = []
for c in data.columns:
    resi = c[1:]  # clip chain name
    if int(resi) > 1000:  # shorten: replace the first digit with 'x'
        resi = 'x' + resi[1:]
    xlabels.append(resi)

ax1.set_xticklabels(
    xlabels,
    fontsize=8,
    rotation=90
)

ax1.set_yticklabels(
    data.index,
    fontsize=8
)

ax1.set_xlabel('ACE2 Residues', size=8)

# Bold the first row label (the median)
ax1.yaxis.get_ticklabels()[0].set_fontweight('bold')

# Color ylabels
if COLOR_YLABELS:
    # Pair tick labels with data.index rather than reading the Text objects,
    # whose strings may not be populated before the figure is drawn
    for label, name in zip(ax1.yaxis.get_ticklabels(), data.index):
        if name in positive:
            label.set_color(blue)
        elif name in negative:
            label.set_color(red)
        else:
            label.set_color('k')


# Create colorbar
fig.canvas.draw()  # draw first to get positions
ax_pos = ax1.get_position()
cbax = fig.add_axes([
    ax_pos.xmax + 0.01,  # xmax + pad
    ax_pos.ymin,
    0.03,
    # NOTE(review): fixed height in figure fraction, presumably tuned by hand
    # to match the heatmap -- confirm if figsize or the row count changes
    0.73
])

ydata = data.values[~np.isnan(data.values)].ravel()
datamin, datamax = ydata.min(), ydata.max()
datarange = datamax - datamin
datastep = datarange / 5
cbticks = [datamax - (i*datastep) for i in range(6)]

cbar = fig.colorbar(
    hm,
    cax=cbax,
    orientation='vertical',
    pad=0,
    ticks=cbticks
)
cbar.ax.set_ylabel('Residue HADDOCK score (a.u.)', size=9)
for l in cbar.ax.yaxis.get_ticklabels():
    l.set_fontsize(8)

cbar.outline.set_linewidth(0.5)

# Hide imshow data: the patches above carry the colors. A fresh all-NaN
# array is used because writing into DataFrame `.values` fails under pandas
# Copy-on-Write (the array is read-only).
hm.set_data(np.full(data.shape, np.nan))

In [None]:
# Export as SVG for editing in Illustrator/Designer. Set to True to write
# the file; off by default so re-running the notebook does not overwrite it.
SAVE_FIGURE = False
if SAVE_FIGURE:
    fig.savefig('Figure_S3-new.svg', transparent=True)