In [None]:
import glob
import math
from typing import List
import warnings 
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Residue name codes handled by this notebook. The list is kept in sorted
# order so that `residues.index(code)` gives a stable position for each code.
residues = [
    "ALA", "ARG", "ASH", "ASN", "ASP",
    "CYS", "GLH", "GLN", "GLU", "GLY",
    "HIP", "HIS", "ILE", "LEU", "LYS",
    "MET", "PHE", "PRO", "PTR", "S1P",
    "SEP", "SER", "T1P", "THR", "TPO",
    "TRP", "TYR", "VAL", "Y1P",
]
residues.sort()

%matplotlib inline

In [None]:
class polar:
    """ Small geometry helpers for points/vectors given in polar coordinates.

    All angles are in radians. The methods accept scalars or numpy arrays
    wherever the underlying numpy functions do.
    """

    @staticmethod
    def distance(r_1, theta_1, r_2, theta_2):
        """ Calculates the distance between two coordinates (law of cosines) """
        return np.sqrt(r_1 ** 2 + r_2 ** 2 - 2 * r_1 * r_2 * np.cos(theta_1 - theta_2))

    @staticmethod
    def get_incircle(r1, th1, r2, th2):
        """ Returns the center and radius of the circle, to which both lines are tangents

        The two lines are the rays from the origin through (r1, th1) and
        (r2, th2). The radius formula is exact for r1 == r2; the center is
        placed perpendicular (counter-clockwise) to the first ray, so th2 is
        expected to lie counter-clockwise of th1 by less than pi.

        Returns:
            polar.Circle: the circle described by its polar origin and radius.
        """
        delta = abs(th1 - th2)                      # angle between tangents
        chord = polar.distance(r1, th1, r2, th2)    # length of chord
        alpha = .5 * delta                          # angle between chord and radius
        circle_radius = .5 * chord / np.cos(alpha)  # radius of the incircle
        circle_radius_theta = th1 + .5 * np.pi      # theta perpendicular to tangent1
        origin_r, origin_theta = polar.vector_addition(r1, th1, circle_radius, circle_radius_theta)
        return polar.Circle(origin_r, origin_theta, circle_radius)

    @staticmethod
    def vector_addition(r_1, theta_1, r_2, theta_2):
        """ Returns the vector that represents the sum of two vectors

        Returns:
            tuple: (r, theta) of the summed vector.
        """
        angle = theta_2 - theta_1
        # |a + b|^2 = r1^2 + r2^2 + 2*r1*r2*cos(angle). Note the PLUS sign:
        # polar.distance() uses a minus sign and therefore yields |a - b|,
        # which only coincides with the sum for perpendicular vectors.
        sum_r = np.sqrt(r_1 ** 2 + r_2 ** 2 + 2 * r_1 * r_2 * np.cos(angle))
        sum_theta = theta_1 + np.arctan2(r_2 * np.sin(angle),
                                         r_1 + r_2 * np.cos(angle))
        return sum_r, sum_theta

    class Circle:
        """ An (off-center) circle in polar coordinates """

        def __init__(self, origin_r, origin_theta, radius):
            self.origin_r = origin_r            # r of the circle center
            self.origin_theta = origin_theta    # theta of the circle center
            self.radius = radius

        def __call__(self, theta):
            """ Returns both r-values at which the ray at `theta` crosses the circle

            The last axis of the result has length 2: (far, near) solutions of
            r^2 - 2*r*p + q = 0. Where the ray misses the circle the square
            root is taken of a negative number and the entries are NaN.
            """
            p = self.origin_r * np.cos(theta - self.origin_theta)
            q = self.origin_r ** 2 - self.radius ** 2
            root = np.sqrt(p ** 2 - q)
            # Stacking on the last axis gives shape (2,) for a scalar theta
            # and (..., 2) for array input.
            return np.stack([p + root, p - root], axis=-1)

In [None]:
class PolarGraphPlot:
    """ Draws a graph on a polar axes.

    Nodes are drawn as ring segments between r = 1 and r = 1 + node height;
    edges are drawn inside the unit circle as arcs (or a straight chord when
    the two nodes are opposite each other). Only the upper triangle of the
    connectivity matrix with values > 0 is drawn.

    Every option assigned in `__init__` can be overridden through keyword
    arguments or later through `set(...)`.
    """

    def __init__(self, connectivity_matrix, is_unidirectional=False, **kwargs):
        self.connectivity_matrix = connectivity_matrix
        self.is_unidirectional = is_unidirectional  # stored; not read by any method of this class
        self.node_linewidth = 1         # scalar or per-node sequence
        self.node_linecolor = "k"       # single color or sequence of n_nodes colors
        self.node_fillcolor = "C0"      # single color or sequence of n_nodes colors
        self.node_labels = None         # optional per-node label texts
        self.node_spacing = .03         # theta gap between neighbours (scalar or sequence)
        self.node_widths = None         # filled lazily by get_node_width()
        self.node_height = .15          # scalar or per-node sequence
        self.edge_linewidth = 2
        self.edge_linewidth_equal = True  # False: scale the linewidth with the edge value
        self.edge_cmap = matplotlib.cm.BuPu
        self.edge_norm = matplotlib.colors.Normalize(np.nanmin(self.connectivity_matrix), np.nanmax(self.connectivity_matrix))
        self.theta_gap = np.pi / 4.     # total theta range left free of nodes
        self.theta_offset = 0           # only used for the label rotation (see plot_nodelabel)
        self.zorder = 5                 # zorder of the node patches and outlines
        self.set(**kwargs)

    @property
    def connectivity_matrix(self):
        """ A copy of the square connectivity matrix """
        return self._mtx.copy()

    @connectivity_matrix.setter
    def connectivity_matrix(self, new_matrix):
        if len(new_matrix.shape) != 2 or new_matrix.shape[0] != new_matrix.shape[1]:
            raise ValueError("The connectivity matrix must be a square matix!")
        self._mtx = new_matrix.copy()

    @property
    def n_nodes(self):
        """ Returns the number of nodes in the plot """
        return self._mtx.shape[0]

    @staticmethod
    def _resolve_axes(ax):
        """ Returns `ax`, or the current pyplot axes when no axes was given """
        return plt.gca() if ax is None else ax

    def get_node_theta(self, node_idx: int):
        """ Returns the theta value at which this node[idx] is centered """
        # Start half a gap past theta = 0, skip all preceding nodes and their
        # spacing, then move to the middle of the requested node.
        theta = .5 * self.theta_gap
        for idx in range(node_idx):
            theta += self.get_node_width(idx) + self.get_node_spacing(idx)
        theta += .5 * self.get_node_width(node_idx)
        return theta

    def get_node_width(self, node_idx: int):
        """ Returns the theta range over which the node[idx] is drawn

        When `node_widths` is None, equal widths are derived from the full
        circle minus `theta_gap` and the node spacing, and cached in
        `node_widths` (so later changes to the spacing are not picked up
        unless `node_widths` is reset to None).
        """
        if self.node_widths is None:
            width = 2 * np.pi - self.theta_gap
            if isinstance(self.node_spacing, (float, int)):
                width -= (self.n_nodes - 1) * self.node_spacing
            elif isinstance(self.node_spacing, (list, tuple, np.ndarray)):
                width -= np.sum(self.node_spacing)
            self.node_widths = np.repeat(width / self.n_nodes, self.n_nodes)
        return self.node_widths[node_idx]

    def get_node_height(self, node_idx: int):
        """ Returns the height (r) of node[idx] """
        if isinstance(self.node_height, (float, int)):
            return self.node_height
        return self.node_height[node_idx]

    def get_node_linewidth(self, node_idx: int):
        """ Returns the linewidth for node[idx] """
        if isinstance(self.node_linewidth, (float, int)):
            return self.node_linewidth
        return self.node_linewidth[node_idx]

    def get_node_spacing(self, node_idx: int):
        """ Returns the (theta) space between node[idx] and node[idx+1] """
        if isinstance(self.node_spacing, (float, int)):
            return self.node_spacing
        return self.node_spacing[node_idx]

    def get_node_linecolor(self, node_idx: int):
        """ Returns the line color of node[idx] """
        # A sequence is only treated as per-node colors when its length
        # matches; otherwise the value is passed on as one color.
        if isinstance(self.node_linecolor, (list, tuple, np.ndarray)) and len(self.node_linecolor) == self.n_nodes:
            return self.node_linecolor[node_idx]
        return self.node_linecolor

    def get_node_fillcolor(self, node_idx: int):
        """ Returns the fill color of node[idx] """
        if isinstance(self.node_fillcolor, (list, tuple, np.ndarray)) and len(self.node_fillcolor) == self.n_nodes:
            return self.node_fillcolor[node_idx]
        return self.node_fillcolor

    def get_edge_linewidth(self, edge_value: float):
        """ Returns the linewidth for an edge of a given value """
        if self.edge_linewidth_equal:
            return self.edge_linewidth
        return self.edge_linewidth * self.edge_norm(edge_value)

    def get_edge_linecolor(self, edge_value: float):
        """ Returns the color for an edge of a given value """
        return self.edge_cmap(self.edge_norm(edge_value))

    def set(self, **kwargs):
        """ Sets existing attributes by keyword and returns self

        Raises:
            AttributeError: if a keyword does not name an existing attribute.
        """
        for kw, value in kwargs.items():
            if hasattr(self, kw):
                setattr(self, kw, value)
            else:
                raise AttributeError(f"Unable to set unknown attribute: '{kw}'")
        return self

    def plot(self, ax):
        """ Draws all nodes, their labels (if any) and all edges on `ax` """
        mtx = self._mtx  # read-only here; avoids the copy made by the property
        n_nodes = mtx.shape[0]

        for i in range(n_nodes):
            self.plot_node(i, ax=ax)
            if self.node_labels is not None:
                self.plot_nodelabel(i, ax=ax)

        # Collect the positive upper-triangle entries and draw them in order
        # of increasing value, so that the strongest edges end up on top.
        rows, cols = np.triu_indices(n_nodes, k=1)
        values = mtx[rows, cols]
        keep = values > 0
        rows, cols, values = rows[keep], cols[keep], values[keep]
        for k in np.argsort(values):
            self.plot_edge(int(rows[k]), int(cols[k]), values[k], ax=ax)

    def plot_node(self, node_idx, ax=None):
        """ Draw the node[idx] on `ax` (defaults to the current pyplot axes) """
        ax = self._resolve_axes(ax)

        node_theta  = self.get_node_theta(node_idx)
        node_width  = self.get_node_width(node_idx)
        node_height = self.get_node_height(node_idx)
        node_fillcolor = self.get_node_fillcolor(node_idx)

        # Closed outline: inner arc forward, outer arc backward, back to start.
        th = np.linspace(node_theta - .5 * node_width, node_theta + .5 * node_width, 20)
        theta_values = np.concatenate([th, th[::-1], th[:1]])
        r_values = np.concatenate([[1] * th.size, [1 + node_height] * th.size, [1]])

        if node_fillcolor is not None:
            ax.fill_between(th, [1] * th.size, [1 + node_height] * th.size,
                            color=node_fillcolor,
                            zorder=self.zorder)
        ax.plot(theta_values, r_values,
                linewidth=self.get_node_linewidth(node_idx),
                color=self.get_node_linecolor(node_idx),
                zorder=self.zorder)

    def plot_nodelabel(self, node_idx, ax=None):
        """ Writes the label of node[idx] just outside the node

        Nodes without an entry in `node_labels` are skipped silently.
        """
        try:
            label = self.node_labels[node_idx]
        except (IndexError, KeyError, TypeError):
            return None
        ax = self._resolve_axes(ax)
        theta = self.get_node_theta(node_idx)
        # theta_offset only rotates the text; this class does not apply the
        # offset to the axes itself.
        real_theta = theta + self.theta_offset
        rvalue = 1.05 + self.get_node_height(node_idx)
        rotation = 180 * real_theta / np.pi
        ax.text(theta, rvalue, label,
            va="center", ha="left", rotation=rotation, rotation_mode="anchor")

    def plot_edge(self, node0_idx, node1_idx, edge_value, ax=None):
        """ Draws the edge between two nodes, styled by `edge_value` """
        ax = self._resolve_axes(ax)
        linewidth = self.get_edge_linewidth(edge_value)
        color = self.get_edge_linecolor(edge_value)

        theta0, theta1 = sorted([self.get_node_theta(node0_idx),
                                 self.get_node_theta(node1_idx)])

        if np.isclose(theta1 - theta0, np.pi):
            # Opposite nodes: the incircle radius diverges, so plot the edge
            # as a straight line instead.
            ax.plot([theta0, theta1], [1, 1], linewidth=linewidth, color=color)
            return

        # Make sure to draw edges along the shortest path
        while abs(theta1 - theta0) > np.pi:
            theta0 += 2 * np.pi
        theta_lo, theta_hi = min(theta0, theta1), max(theta0, theta1)

        # Get the circle with arc that represents the edge
        circle = polar.get_incircle(1, theta_lo, 1, theta_hi)

        # Stay marginally inside the end points so that all values are in
        # bounds, then keep the near intersection for every theta.
        theta_values = np.linspace(theta_lo + 1e-6, theta_hi - 1e-6, 50)
        r_values = circle(theta_values)
        ax.plot(theta_values, np.nanmin(r_values, axis=1),
                linewidth=linewidth, color=color)

def generate_fillcolors(peptide_length: int):
    """ Returns one fill color per node: a light cap node on either end and a
    (light, dark, light) triple for each of the `peptide_length` residues """
    colors = ["#ddd"]
    for _ in range(peptide_length):
        colors.extend(("#ddd", "#888", "#ddd"))
    colors.append("#ddd")
    return colors

def generate_heights(peptide_length: int):
    """ Returns one node height per node: .1 for the cap node on either end and
    (.1, .18, .1) for each of the `peptide_length` residues """
    heights = [.1]
    for _ in range(peptide_length):
        heights.extend((.1, .18, .1))
    heights.append(.1)
    return heights
    
def generate_labels(peptide_length: int, residue: str = None):
    """ Returns the node labels for a capped peptide of `peptide_length` residues

    Every residue contributes three labels ("N", "<name>:<position>", "O");
    the list starts with "ACE:O" and ends with "NME:N". Position 3 carries
    the residue name, position `peptide_length - 2` is labelled "XXX" and
    all other positions "GLY".

    Args:
        peptide_length: number of residues between the two caps.
        residue: name shown at position 3. When omitted, the legacy lookup
            `residues[N]` on the notebook globals is used, which requires
            that `N` has been defined by the calling cell.
    """
    labels = ["ACE:O"]
    for i in range(1, peptide_length + 1):
        if i == 3:
            # Resolve the global fallback lazily, so that peptides shorter
            # than three residues never touch the globals.
            name = residues[N] if residue is None else residue
        elif i == peptide_length - 2:
            name = "XXX"
        else:
            name = "GLY"
        labels += ["N", f"{name}:{i:d}", "O"]
    labels.append("NME:N")
    return labels

In [None]:
def hbond_matrix(input_files: List[str], n_residues: int, end_frame: int = None, include_mediated: bool = False):
    """ Generate a hbond matrix from the given files

    Each file is read with `pd.read_csv` and must provide the columns
    "#frame", "resid0", "resid1" and "type". Every residue owns three
    consecutive matrix indices; the index within a residue is chosen from
    the "type" string: offset 1 when that side is "sc", offset 2 for a
    "bb<-" prefix (first residue) or "->bb" suffix (second residue) and
    offset 0 otherwise.

    Args:
        input_files: paths (or file-like objects) of the hbond tables.
        n_residues: number of residues; the matrix is (3*n, 3*n).
        end_frame: when given, only frames below it are used and every row
            weighs 1 / end_frame. Otherwise a row weighs one over the largest
            frame number rounded up to a multiple of 5000.
        include_mediated: keep rows whose type contains "aq" (presumably
            water-mediated bonds -- confirm against the analysis script).

    Returns:
        np.ndarray: symmetric matrix of the summed weights, averaged over
        the number of input files. All zeros when no files are given.

    Raises:
        AssertionError: if a bond ends up on the diagonal (a node bonded
            to itself).
    """
    n_nodes = 3 * n_residues
    mtx = np.zeros((n_nodes, n_nodes))
    if len(input_files) == 0:
        warnings.warn("hbond_matrix: no input files given, returning an all-zero matrix")
        return mtx

    tables = []
    for input_file in input_files:
        df = pd.read_csv(input_file)
        if end_frame is None:
            if df.empty:
                # Nothing to weigh; the file still counts in the average.
                continue
            # Guard against a largest frame index of 0, which would
            # otherwise give zero blocks and a division by zero.
            n_blocks = max(1, math.ceil(df["#frame"].max() / 5000))
            weight = 1 / 5000 / n_blocks
        else:
            df = df.loc[df["#frame"] < end_frame]
            weight = 1 / end_frame
        # assign() returns a new frame, so no write hits a .loc slice.
        df = df.assign(weight=weight)
        if not include_mediated:
            df = df.loc[~df['type'].str.contains('aq')]
        tables.append(df)

    if tables:
        hbonds = pd.concat(tables, axis=0, ignore_index=True)
        totals = hbonds.groupby(["resid0", "resid1", "type"])["weight"].sum()
        for (resid0, resid1, hbtype), total in totals.items():
            i, j = 3 * (resid0 - 1), 3 * (resid1 - 1)
            if hbtype.startswith("sc"):
                i += 1
            elif hbtype.startswith("bb<-"):
                i += 2
            if hbtype.endswith("sc"):
                j += 1
            elif hbtype.endswith("->bb"):
                j += 2
            # Accumulate in the lower triangle only; mirrored below.
            mtx[max(i, j), min(i, j)] += total

    mtx += mtx.T
    assert not np.diag(mtx).any(), "hbond matrix has non-zero diagonal entries"
    return mtx / len(input_files)

def generate_hbond_matrix(peptide_length: int, include_residues: List[str] = None, end_frame: int = None, include_mediated: bool = False):
    """ Generate an array of hbond matrices for all peptides of similar length

    For every residue name the files matching
    "../<length>peptides/gg<RES>*gg/analy/hbonds.csv" are combined with
    `hbond_matrix`. A progress dot is printed per residue.

    Args:
        peptide_length: peptide length used in the directory name; the
            matrices cover `peptide_length + 2` residues (presumably the
            two capping groups -- confirm).
        include_residues: residue names to process, defaults to the global
            `residues`. Row k of the result belongs to `include_residues[k]`;
            the first axis always has `len(residues)` rows, so rows that are
            not filled stay zero.
        end_frame: forwarded to `hbond_matrix`.
        include_mediated: forwarded to `hbond_matrix`.

    Returns:
        np.ndarray: array of shape (len(residues), 3*n, 3*n) with
        n = peptide_length + 2.
    """
    if include_residues is None:
        include_residues = residues
    n_residues = peptide_length + 2
    n_nodes = 3 * n_residues
    hbond_mtx = np.zeros((len(residues), n_nodes, n_nodes))

    for k, res in enumerate(include_residues):
        input_files = sorted(glob.glob(f"../{peptide_length:d}peptides/gg{res}*gg/analy/hbonds.csv"))
        hbond_mtx[k] = hbond_matrix(input_files, n_residues, end_frame, include_mediated)
        # One dot per residue, with a space after every tenth dot.
        print(". " if k%10 == 9 else ".", end="")
    print()

    return hbond_mtx

In [None]:
# One stack of hbond matrices per peptide length (5 up to and including 9)
hbond_matrics = [generate_hbond_matrix(length) for length in range(5, 10)]

In [None]:
# Largest value per peptide length and its position in the flattened array
for mtx in hbond_matrics:
    largest, flat_position = np.max(mtx), np.argmax(mtx)
    print(largest, flat_position)

In [None]:
# Shape of the matrix stack for every peptide length
for mtx in hbond_matrics:
    stack_shape = mtx.shape
    print(stack_shape)

In [None]:
mtx = hbond_matrics[3]
# The stack does not change inside the loop, so the length is derived once:
# three nodes per residue, two residues more than the peptide length.
peptide_length = mtx.shape[1] // 3 - 2

for res in ["SER", "S1P", "SEP", "THR", "T1P", "TPO", "TYR", "Y1P", "PTR"]:
    # N must be set before generate_labels() runs: it reads this global.
    N = residues.index(res)

    hbonds_options = {
        "edge_cmap": matplotlib.cm.magma_r,
        "edge_linewidth": 1.5,
        "node_fillcolor": generate_fillcolors(peptide_length),
        "node_height": generate_heights(peptide_length),
        "node_labels": generate_labels(peptide_length),
    }

    # Drop the two outermost nodes on either side of the matrix
    HBonds = PolarGraphPlot(mtx[N, 2:-2, 2:-2], **hbonds_options)
    HBonds.edge_norm = matplotlib.colors.Normalize(0., .2)

    # Bare polar axes: no grid, no frame, transparent background
    fig, ax = plt.subplots(1,1, figsize=(4,3), dpi=300, subplot_kw={'projection': 'polar'})
    fig.patch.set_alpha(0)
    #ax.set_title()
    ax.grid(False)
    ax.spines['polar'].set_visible(False)

    HBonds.plot(ax)

    # Remove all ticks and start the radial axis at the center
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_ylim(ymin=0,)

    # fig.colorbar(matplotlib.cm.ScalarMappable(norm=HBonds.edge_norm, cmap=HBonds.edge_cmap), 
    #              ax=ax, shrink=.8, pad=.12)

    fig.savefig(f"figures/hbonds/aaa{peptide_length}_{residues[N]}.png", format="png", dpi=300, bbox_inches="tight", transparent=True)

In [None]:
# Stand-alone horizontal colorbar for the edge colors of the last HBonds plot
fig, ax = plt.subplots(1,1, figsize=(4,3), dpi=300)
ax.set_axis_off()

edge_mappable = matplotlib.cm.ScalarMappable(norm=HBonds.edge_norm, cmap=HBonds.edge_cmap)
cbar = fig.colorbar(edge_mappable, ax=ax, shrink=.8, pad=.12, orientation="horizontal")
#cbar.set_ticks(np.arange(0, .35, .05), minor=True)
fig.savefig(f"figures/hbonds/colorbarH.pdf", format="pdf", dpi=96, bbox_inches="tight")