In [None]:
# DEPENDENCIES
import glob
import os

import numpy as np
import pandas as pd
import logomaker
from scipy import stats
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, FixedLocator)
# LOCAL IMPORTS
import sse_func
import template_finder as tf
from indexing_classes import StAlIndexing

# Location of the GESAMT executable.
# The GESAMT_BIN environment variable takes precedence; the hard-coded path is
# only a fallback. os.environ.get() never raises, so the fallback has to be
# passed as the default instead of being placed in an except branch.
GESAMT_BIN = os.environ.get(
    'GESAMT_BIN',
    "/home/hildilab/lib/xtal/ccp4-8.0/ccp4-8.0/bin/gesamt",
)

def find_pdb(name, pdb_folder):
    """Return the first PDB file in `pdb_folder` matching the identifier of `name`.

    The identifier is everything in `name` before the first "-". Candidates are
    collected with the glob pattern "<pdb_folder>/*<identifier>*.pdb" and the
    first hit (in glob's order) is returned.

    Raises:
        IndexError: if no file matches the pattern.
    """
    identifier, _, _ = name.partition("-")
    matches = glob.glob(f"{pdb_folder}/*{identifier}*.pdb")
    return matches[0]

def find_offsets(fasta_file, accessions, sequences):
    """Locate each query sequence inside its full-length FASTA entry.

    The FASTA file is split on ">"; records shorter than 10 characters are
    skipped. For every remaining record the accession is the second "|"-separated
    field of the header line and the sequence is the concatenation of all
    following lines. Accessions are matched EXACTLY.

    Args:
        fasta_file: path to the FASTA file holding the full-length sequences.
        accessions: iterable of accession strings to look up.
        sequences: indexable of query sequences (string or sequence of
            characters), aligned by position with `accessions`.

    Returns:
        A list with, per accession, the index at which the query starts in the
        full sequence (-1 if it is not a substring, as returned by str.find).

    Raises:
        IndexError: if an accession is not present in the FASTA file, or a
            header has no "|"-separated second field.
    """
    with open(fasta_file, "r") as handle:
        records = handle.read().split(">")

    # Drop fragments too short to be real entries (e.g. the empty string before the first ">").
    parsed = [record.strip().split("\n") for record in records if not len(record) < 10]
    known_accessions = np.array([lines[0].split("|")[1] for lines in parsed])
    full_sequences = ["".join(lines[1:]) for lines in parsed]

    offsets = []
    for position, accession in enumerate(accessions):
        hit = np.where(known_accessions == accession)[0][0]
        reference = "".join(full_sequences[hit])
        query = "".join(sequences[position])
        start = reference.find(query)
        print(full_sequences[hit], query, start, sep="\n")
        offsets.append(start)

    return offsets

In [None]:
# Load the pickled indexing / collection objects used throughout the notebook.
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load files from a trusted source.
s4_indexing = np.load(file="../s4_test_indexing.pkl", allow_pickle=True)
valid_collection = np.load(file="../valid_collection.q.pkl", allow_pickle=True)
human_collection = np.load(file="../human_collection.q.pkl", allow_pickle=True)
stal_indexing = np.load(file="stal_indexing.r4.pkl", allow_pickle=True)

In [None]:
# Initialize Data
# Row labels for the receptor-level figure (one more entry than rr_list: the trailing "PKD").
receptors = [
    "ADGRA1", "ADGRA2", "ADGRA3",
    "ADGRB1", "ADGRB2", "ADGRB3",
    "CELSR1", "CELSR2", "CELSR3",
    "ADGRD1", "ADGRD2",
    "ADGRE1", "ADGRE2", "ADGRE3", "ADGRE4", "ADGRE5",
    "ADGRF1", "ADGRF2", "ADGRF3", "ADGRF4", "ADGRF5",
    "ADGRG1", "ADGRG2", "ADGRG3", "ADGRG4", "ADGRG5", "ADGRG6", "ADGRG7",
    "ADGRL1", "ADGRL2", "ADGRL3", "ADGRL4",
    "ADGRV1", "unknown", "PKD",
]
# Short receptor-type codes used as keys of rec_index.
rr_list = [
    "A1", "A2", "A3",
    "B1", "B2", "B3",
    "C1", "C2", "C3",
    "D1", "D2",
    "E1", "E2", "E3", "E4", "E5",
    "F1", "F2", "F3", "F4", "F5",
    "G1", "G2", "G3", "G4", "G5", "G6", "G7",
    "L1", "L2", "L3", "L4",
    "V1", "X",
]
# Segment labels: H1..H6 followed by S1..S14.
elements = [f"H{n}" for n in range(1, 7)] + [f"S{n}" for n in range(1, 15)]

# Lookup tables: label -> row / column index in the count matrices below.
el_index = {label: column for column, label in enumerate(elements)}
fam_index = {letter: row for row, letter in enumerate("ABCDEFGLVX")}
rec_index = {code: row for row, code in enumerate(rr_list)}

# Count matrices (rows: families or receptor types, columns: elements).
absolute_sf_occupancy = np.zeros(shape=(len(fam_index), len(elements)), dtype=int)
absolute_rr_occupancy = np.zeros(shape=(len(rr_list), len(elements)), dtype=int)
s4_total_occ = np.zeros(shape=(len(rr_list), len(elements)), dtype=int)

# Number of entries counted per family / receptor type.
n_fams = np.zeros(shape=(len(fam_index)), dtype=int)
n_rr = np.zeros(shape=(len(rr_list)), dtype=int)
s4_rr = np.zeros(shape=(len(rr_list)), dtype=int)

print(fam_index,rec_index)

In [None]:
# extract binary info : which receptor has what element?
# Accession of each human entry: text before the first "_" and then before the first "-".
human_ac = [gain.name.split("_")[0].split("-")[0] for gain in human_collection.collection]

# Map: index into stal_indexing -> receptor type, for every entry whose accession is a human one.
human_idx = {}
for ac in human_ac:
    human_idx.update({
        i: stal_indexing.receptor_types[i]
        for i, sac in enumerate(stal_indexing.accessions)
        if sac == ac
    })
print(human_idx)

# Boolean matrix (receptor type x element): True where the human entry has the element indexed.
human_el_matrix = np.zeros(shape=(33,20), dtype=bool)

for entry, receptor in human_idx.items():
    # Segment name is the part of each label before the first "."
    segments = np.unique([label.split(".")[0] for label in stal_indexing.indexing_dirs[entry].keys()])
    for segment in segments:
        if "GPS" not in segment:
            human_el_matrix[rec_index[receptor], el_index[segment]] = 1

In [None]:
# Get a matrix : a) Receptors + OCC; b) Families+ + OCC
# The accumulators are created in another cell and incremented in place here.
# Zero them first so that re-running this cell does not double-count.
for _accumulator in (n_rr, n_fams, absolute_rr_occupancy, absolute_sf_occupancy):
    _accumulator.fill(0)

for i in range(len(stal_indexing.receptor_types)):
    dd = stal_indexing.indexing_dirs[i]
    # Segment name is the part of each label before the first "."; count each segment once per entry.
    el_list = np.unique([k.split(".")[0] for k in dd.keys()])
    # Only segments listed in `elements` have a column (this skips e.g. GPS labels).
    known_els = [el for el in el_list if el in elements]
    rr = stal_indexing.receptor_types[i]

    # a) per receptor type
    if rr in rr_list:
        n_rr[rec_index[rr]] += 1
        for el in known_els:
            absolute_rr_occupancy[rec_index[rr], el_index[el]] += 1

    # b) per family, keyed by the first character of the receptor type
    if rr[0] in "ABCDEFGLVX":
        n_fams[fam_index[rr[0]]] += 1
        for el in known_els:
            absolute_sf_occupancy[fam_index[rr[0]], el_index[el]] += 1

 


In [None]:
# Normalize by frequency: divide every row of counts by the number of entries in that row's group.
# NOTE(review): a group with zero entries yields NaN in its row (0/0) - confirm this is intended downstream.
sf_occ = absolute_sf_occupancy / n_fams.reshape(-1, 1)
rr_occ = absolute_rr_occupancy / n_rr.reshape(-1, 1)

# store the values somewhere.
import pickle as pkl

data = dict(
    absolute_rr_occupancy=absolute_rr_occupancy,
    rr_occ=rr_occ,
    absolute_sf_occupancy=absolute_sf_occupancy,
    sf_occ=sf_occ,
    fam_index=fam_index,
    rec_index=rec_index,
    n_rr=n_rr,
    n_fams=n_fams,
)

# Export is disabled. To enable it, pickle needs an open binary file handle, e.g.:
#   with open("elemen_occ.pkl", "wb") as fh: pkl.dump(data, fh)

In [None]:
# GET PKD OCCUPANCY. THIS IS NOT AS HIGH-QUALITY AND LIKELY NOT FILTERED FOR GOOD AND BAD GAIN DOMAINS.
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted files.
pkd_indexing = np.load("../pkd_indexing.pkl", allow_pickle=True)
pkd_collection = np.load("../pkd_collection.pkl", allow_pickle=True)
absolute_pkd_occupancy = np.zeros(shape=(20), dtype=int)

# Entries whose hasSubdomain attribute is falsy are reported as invalid (they are not removed).
invalid = sum(1 for pkd in pkd_collection.collection if not pkd.hasSubdomain)
print("Found INVALID PKD structures:",invalid)

n_pkd = 0
pkd_elements = []
for entry, receptor in enumerate(pkd_indexing.receptor_types):
    labels = pkd_indexing.indexing_dirs[entry].keys()
    # Segment name is the part of each label before the first "."
    segments = np.unique([label.split(".")[0] for label in labels])
    pkd_elements.append(len(segments))
    if receptor not in rr_list:
        continue
    n_pkd += 1
    for segment in segments:
        if segment in elements:
            absolute_pkd_occupancy[el_index[segment]] += 1
print("MATCH STATISTICS (NUMBER OF INDEXED ELEMENTS):",np.unique(pkd_elements, return_counts=True))

# Fraction of counted PKD entries that have each element.
pkd_occ = absolute_pkd_occupancy/n_pkd

In [None]:
def draw_triangle(ax, center_coord, len=0.3):
    """Add a 'deeppink' right triangle to the corner of a unit cell on `ax`.

    The cell is centred on `center_coord` = (x, y). The right-angle vertex sits
    at (x + 0.45, y - 0.45); the two legs of length `len` run towards smaller x
    and towards larger y from there.
    NOTE(review): on an image axis with inverted y this is presumably the
    upper-right corner of the cell - confirm against the rendered figure.
    """
    x, y = center_coord[0], center_coord[1]
    apex = [x + 0.45, y - 0.45]
    vertices = [
        apex,
        [apex[0] - len, apex[1]],   # along the x leg
        [apex[0], apex[1] + len],   # along the y leg
    ]
    ax.add_patch(plt.Polygon(vertices, color='deeppink'))

def draw_circle(ax, center_coord, r=0.3, color='black'):
    """Add a ring marker centred on `center_coord` to `ax`.

    The ring is built from two filled circles: an outer one of radius `r` in
    `color`, then a white one of radius `r - 0.05` on top of it. No offset is
    applied because the cell centre is the coordinate itself.
    """
    for radius, fill in ((r, color), (r-0.05, 'white')):
        ax.add_patch(plt.Circle(center_coord, radius, color=fill))


In [None]:
# Append the PKD occupancy as one extra (last) row below the per-receptor and per-family matrices.
pkd_row = np.atleast_2d(pkd_occ).astype(float)
r_comb_occ = np.vstack((rr_occ, pkd_row)).astype(float)     # receptor types + PKD
sf_comb_occ = np.vstack((sf_occ, pkd_row)).astype(float)    # families + PKD

In [None]:
# Heat map of element occupancy per receptor type (plus PKD as last row).
plt.rcParams['font.family'] = 'FreeSans'
from numpy.ma import masked_array

fig,ax = plt.subplots(facecolor='w', figsize=[8,8])

# splitter tags the first six columns (H1-H6 in `elements`) with 1 and the remaining ones (S1-S14) with 2.
splitter = np.ones(shape=r_comb_occ.shape, dtype=int)
splitter[:,6:] = 2

strands = masked_array(r_comb_occ, splitter == 1)   # H columns masked -> only S columns drawn
hels = masked_array(r_comb_occ, splitter == 2)      # S columns masked -> only H columns drawn
# Ticks sit on the cell edges (x - 0.5) so that the grid lines outline the cells.
ax.set_xticks(ticks = np.arange(-0.5,19.5), labels = elements, rotation=90, size=14, horizontalalignment='left')
ax.set_yticks(ticks = np.arange(-0.5,34.5), labels = receptors, size=13, verticalalignment='top', horizontalalignment='right',style='italic')
# Pin both colour scales to [0, 1]: occupancy is a fraction and the colorbar is labelled 0%..100%.
# Without vmin/vmax each image autoscales to its own data range, so the two maps would not be
# comparable and the tick at 1 would not necessarily mean 100%.
him = ax.imshow(hels, cmap='Blues', aspect='equal',extent = (-0.5, 19.5, 34.5, -0.5), vmin=0, vmax=1)
sim = ax.imshow(strands, cmap='Oranges', vmin=0, vmax=1)
ax.xaxis.tick_top()

# Mark every (receptor type, element) cell that is False in human_el_matrix with a ring.
for y in range(20):
    for x in range(33):
        if not human_el_matrix[x,y]:
            draw_circle(ax, [y,x], r=0.18, color='black')
ax.grid(True,'both', color = 'black')

# One colorbar per colormap; only the first one carries tick labels and the axis label.
cb1 = plt.colorbar(him,shrink=0.7) 
cb2 = plt.colorbar(sim,shrink=0.7) 
cb2.set_ticks([])
cb1.set_ticks(ticks=[0,1],labels=["0%","100%"], size=14)

cb1.set_label("Element Occupancy", size=14)

plt.savefig("../rr_occ_withpkd_circ2.svg",dpi=600, bbox_inches='tight')

In [None]:
# Heat map of element occupancy per receptor family (plus PKD as last row).
from numpy.ma import masked_array

fig,ax = plt.subplots(facecolor='w', figsize=[8,4])

# splitter tags the first six columns (H1-H6 in `elements`) with 1 and the remaining ones (S1-S14) with 2.
splitter = np.ones(shape=sf_comb_occ.shape, dtype=int)
splitter[:,6:] = 2

strands = masked_array(sf_comb_occ, splitter == 1)  # H columns masked -> only S columns drawn
hels = masked_array(sf_comb_occ, splitter == 2)     # S columns masked -> only H columns drawn

ax.grid(True,'both', color = 'black')
# Ticks sit on the cell edges (x - 0.5) so that the grid lines outline the cells.
ax.set_xticks(ticks = np.arange(-0.5,19.5), labels = elements, rotation=90, size=14, horizontalalignment='left')
ax.set_yticks(ticks = np.arange(-0.5,10.5), labels = ["ADGRA","ADGRB","CELSR","ADGRD","ADGRE","ADGRF","ADGRG","ADGRL","ADGRV","unknown","PKD"], size=14, verticalalignment='top',  style='italic')
# Pin both colour scales to [0, 1]: occupancy is a fraction and the colorbar is labelled 0%..100%.
# Without vmin/vmax each image autoscales to its own data range, so the two maps would not be
# comparable and the tick at 1 would not necessarily mean 100%.
him = ax.imshow(hels, cmap='Blues', aspect='equal',extent = (-0.5, 19.5, 10.5, -0.5), vmin=0, vmax=1)
sim = ax.imshow(strands, cmap='Oranges', vmin=0, vmax=1)

# One colorbar per colormap; only the first one carries tick labels and the axis label.
cb1 = plt.colorbar(him,shrink=0.7) 
cb2 = plt.colorbar(sim,shrink=0.7) 
cb2.set_ticks([])
cb1.set_ticks(ticks=[0,1],labels=["0%","100%"], size=14)
cb1.set_label("Element Occupancy", size=14)
plt.savefig("../sf_occ_withpkd.svg",dpi=600, bbox_inches='tight')

In [None]:
# Count the PKD collection entries whose isValid attribute is truthy.
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted files.
pkd_collection = np.load("../pkd_collection.pkl", allow_pickle=True)
valid = sum(1 for gain in pkd_collection.collection if gain.isValid)
print(valid)

In [None]:
# Per-receptor-type element counts for the S4 test indexing, shown as a fraction of entries.
for entry, receptor in enumerate(s4_indexing.receptor_types):
    labels = s4_indexing.indexing_dirs[entry].keys()
    # Segment name is the part of each label before the first "."
    segments = np.unique([label.split(".")[0] for label in labels])
    if receptor not in rr_list:
        continue
    row = rec_index[receptor]
    s4_rr[row] += 1
    for segment in segments:
        if segment in elements:
            s4_total_occ[row, el_index[segment]] += 1

plt.imshow(s4_total_occ / s4_rr[:,None], cmap='Oranges')

In [None]:
# create a b_factor map, mapping segment occupancy to the segments in human ADGRA2, Q96PE1
# Mean occupancy of each element over all receptor types (rows of rr_occ).
# NOTE(review): a NaN row in rr_occ (receptor type without entries) makes every mean NaN - confirm.
segment_occ = np.mean(rr_occ, axis=0)
occ_dict = dict(zip(elements,segment_occ))
print(occ_dict)

# Index of the first accession containing "Q96PE1". Fail loudly when it is missing instead of
# leaving a2_index undefined (or silently reusing a stale value from an earlier run of this cell).
a2_index = next((idx for idx, ac in enumerate(stal_indexing.accessions) if "Q96PE1" in ac), None)
if a2_index is None:
    raise LookupError("Accession Q96PE1 not found in stal_indexing.accessions")
print(a2_index)

a2_elements = stal_indexing.indexing_dirs[a2_index]
print(a2_elements)
a2_pdb = "../all_pdbs/Q96PE1_A6H8W3_D3DSW4_Q8N3R1_Q8TEM3_Q96KB2_Q9P1Z7_Q9UFY4.pdb"

In [None]:
def label2b(pdbfile, outfile, res2value, fill_b=None):
    """Write a copy of a PDB file with per-residue values in the B-factor column.

    Only lines starting with "ATOM" are touched; every other line is copied
    verbatim. The residue number is read from columns 23-26 and the B-factor
    field (columns 61-66) is overwritten.

    Args:
        pdbfile: path of the input PDB file.
        outfile: path of the PDB file to write.
        res2value: mapping residue number (int) -> numeric value; written with
            four decimals.
        fill_b: optional string written into the B-factor field of ATOM lines
            whose residue is not in `res2value`. When None, those lines are
            copied unchanged.

    Raises:
        ValueError: if columns 23-26 of an ATOM line are not an integer.

    Note:
        The field is right-justified to 6 characters. A value that needs more
        than 6 characters will shift the remainder of the line.
    """
    with open(pdbfile) as handle:
        lines = handle.readlines()

    relabelled = []
    for line in lines:
        if not line.startswith("ATOM"):
            relabelled.append(line)
            continue

        resid = int(line[22:26])
        if resid in res2value:
            b_field = "{:.4f}".format(res2value[resid])
        elif fill_b is not None:
            b_field = fill_b
        else:
            # Nothing to write for this residue: keep the original line.
            relabelled.append(line)
            continue

        # Separate the line ending so that short lines keep it, and pad so the
        # B-factor field always starts at column 61.
        body = line.rstrip("\n")
        ending = line[len(body):]
        body = body.ljust(66)
        # Replace exactly columns 61-66; everything after column 66 is preserved.
        relabelled.append(body[:60] + b_field.rjust(6) + body[66:] + ending)

    with open(outfile, 'w') as handle:
        handle.write("".join(relabelled))
    print(f"Written residue labels to PDB file RES entries : {outfile}")

In [None]:
# Residue -> mean occupancy of the segment the residue belongs to, for the ADGRA2 entry.
# The segment is the label part before the first "."; segments containing "GPS" are skipped.
res2value = {
    resid: occ_dict[label.split(".")[0]]
    for label, resid in a2_elements.items()
    if "GPS" not in label.split(".")[0]
}

# Disabled export; label2b takes fill_b (there is no clear_b argument):
#label2b(a2_pdb, "../a2_segcons.pdb", res2value=res2value)

In [None]:
def mark_seg_cons(rr_occ, elements, uniprot_id, pdbfile, outfile, fill_b=None):
    """Write segment occupancy into the B-factor column of one GAIN domain PDB.

    Every indexed residue of the selected entry gets the mean occupancy of its
    segment (mean over the rows of `rr_occ`). Segments containing "GPS" are
    skipped. The entry is looked up in the module-level `stal_indexing`.

    Args:
        rr_occ: 2-D array of occupancies with one column per entry of `elements`.
        elements: segment names matching the columns of `rr_occ`.
        uniprot_id: accession to look for; the first entry of
            `stal_indexing.accessions` that CONTAINS this string is used.
        pdbfile: input PDB path, passed on to label2b.
        outfile: output PDB path, passed on to label2b.
        fill_b: passed on to label2b for residues without a value.

    Raises:
        LookupError: if no accession contains `uniprot_id`.
        KeyError: if an indexed segment is not in `elements`.
    """
    # Mean occupancy per segment across all rows of rr_occ.
    segment_occ = np.mean(rr_occ, axis=0)
    occ_dict = dict(zip(elements,segment_occ))

    gain_index = next(
        (idx for idx, ac in enumerate(stal_indexing.accessions) if uniprot_id in ac),
        None,
    )
    if gain_index is None:
        # Previously this path crashed with an UnboundLocalError on gain_index.
        raise LookupError(f"mark_seg_cons : accession {uniprot_id!r} not found in stal_indexing.accessions")
    print(gain_index)

    gain_elements = stal_indexing.indexing_dirs[gain_index]

    res2value = {}
    for label, resid in gain_elements.items():
        segment = label.split(".")[0]
        if "GPS" in segment:
            continue
        res2value[resid] = occ_dict[segment]

    label2b(pdbfile, outfile, res2value=res2value, fill_b=fill_b)
    print("mark_seg_cons : Done.")

def mark_pos_cons(pos_occ_dict, uniprot_id, pdbfile, outfile, fill_b=None, n_gains=14435):
    """Write per-position occupancy into the B-factor column of one GAIN domain PDB.

    Every indexed label of the selected entry gets its absolute count from
    `pos_occ_dict` divided by `n_gains`. Labels containing "GPS" are skipped.
    The entry is looked up in the module-level `stal_indexing`.

    Args:
        pos_occ_dict: mapping label -> absolute number of occurrences.
        uniprot_id: accession to look for; the first entry of
            `stal_indexing.accessions` that CONTAINS this string is used.
        pdbfile: input PDB path, passed on to label2b.
        outfile: output PDB path, passed on to label2b.
        fill_b: passed on to label2b for residues without a value.
        n_gains: total number of GAIN domains used to normalise the counts.
            Defaults to 14435, the value that used to be hard-coded here.

    Raises:
        LookupError: if no accession contains `uniprot_id`.
        KeyError: if an indexed label is missing from `pos_occ_dict`.
    """
    gain_index = next(
        (idx for idx, ac in enumerate(stal_indexing.accessions) if uniprot_id in ac),
        None,
    )
    if gain_index is None:
        # Previously this path crashed with an UnboundLocalError on gain_index.
        raise LookupError(f"mark_pos_cons : accession {uniprot_id!r} not found in stal_indexing.accessions")
    print(gain_index)

    gain_elements = stal_indexing.indexing_dirs[gain_index]

    res2value = {}
    for label, resid in gain_elements.items():
        if "GPS" in label:
            continue
        # divide absolute counts by number of total GAINs to normalize.
        res2value[resid] = pos_occ_dict[label]/n_gains

    label2b(pdbfile, outfile, res2value=res2value, fill_b=fill_b)
    print("mark_pos_cons : Done.")

#mark_seg_cons(rr_occ, elements, "O94910", "../all_pdbs/O94910_Q96IE7_Q9BU07_Q9HAR3.pdb", "../l1_segcons.pdb")


In [None]:
# gather occupancy for each element
label_occ_dict = {}
for indexing_dir in stal_indexing.indexing_dirs:
    for k, v in indexing_dir.items():
        if k not in label_occ_dict:
            label_occ_dict[k] = 0
        label_occ_dict[k] += 1
print(label_occ_dict)

In [None]:
# Export segment- and position-level occupancy as B-factors for two structures.
# Unlabelled residues get the filler value "-1.000".
_cons_targets = (
    ("O94910",
     "../all_pdbs/O94910_Q96IE7_Q9BU07_Q9HAR3.pdb",
     "../l1_segcons2.pdb",
     "../l1_poscons2.pdb"),
    ("Q96PE1",
     "../all_pdbs/Q96PE1_A6H8W3_D3DSW4_Q8N3R1_Q8TEM3_Q96KB2_Q9P1Z7_Q9UFY4.pdb",
     "../a2_segcons2.pdb",
     "../a2_poscons2.pdb"),
)

for _accession, _pdb, _seg_out, _ in _cons_targets:
    mark_seg_cons(rr_occ, elements, _accession, _pdb, _seg_out, fill_b="-1.000")

for _accession, _pdb, _, _pos_out in _cons_targets:
    mark_pos_cons(label_occ_dict, _accession, _pdb, _pos_out, fill_b="-1.000")

In [None]:
def get_elem_seq(uniprot, stal_indexing, valid_collection, segment='H6'):
    """Print the first and last indexed position of one segment for one accession.

    Args:
        uniprot: accession. It is matched as a substring against the names in
            `valid_collection.collection` (informational print only) and
            EXACTLY against `stal_indexing.accessions`.
        stal_indexing: object with parallel `accessions` and `indexing_dirs`.
        valid_collection: object whose `collection` items have a `name`.
        segment: segment name, compared with the label part before the first ".".

    Returns:
        None. The (label, residue) pairs of the first and last matching label,
        in the order stored in the indexing dictionary, are printed.

    Raises:
        LookupError: if `uniprot` is not in `stal_indexing.accessions`.
        IndexError: if the entry has no label belonging to `segment`.
    """
    for gain in valid_collection.collection:
        if uniprot in gain.name:
            print(gain.name, "found.")
            break

    idx = next((i for i, ac in enumerate(stal_indexing.accessions) if uniprot == ac), None)
    if idx is None:
        # Previously this path crashed with an UnboundLocalError on idx.
        raise LookupError(f"get_elem_seq : accession {uniprot!r} not found in stal_indexing.accessions")

    myseg = [(k,v) for k,v in stal_indexing.indexing_dirs[idx].items() if k.split(".")[0] == segment]
    if not myseg:
        raise IndexError(f"get_elem_seq : segment {segment!r} is not indexed for accession {uniprot!r}")
    # get min and max of the element
    print(myseg[0], myseg[-1])

# Report the first and last indexed H6 position for accession O14514.
get_elem_seq(
    uniprot="O14514",
    stal_indexing=stal_indexing,
    valid_collection=valid_collection,
    segment='H6',
)
