# Find interface residues

In [None]:
import biotite.structure.io as strucio
import biotite.structure as struc
import numpy as np

def interface_residue(atom_array, chain_1, chain_2, cutoff=5.0):
    """Return the residue IDs of two chains that are in contact.

    A residue is part of the interface when at least one of its atoms lies
    within ``cutoff`` of any atom of the other chain.

    Parameters
    ----------
    atom_array
        Structure providing ``coord``, ``chain_id`` and ``res_id`` and
        supporting boolean-mask indexing (e.g. a biotite ``AtomArray``).
    chain_1, chain_2
        Chain identifiers of the two partners.
    cutoff : float
        Maximum atom-atom distance, in the units of ``atom_array.coord``.

    Returns
    -------
    tuple of (set of int, set of int)
        Residue IDs of ``chain_1`` and of ``chain_2`` that take part in the
        interface.  Both sets are empty when either chain has no atoms with
        defined coordinates.
    """
    # The cell list cannot be built from NaN coordinates, and such atoms can
    # never be within the cutoff anyway.
    atom_array = atom_array[~np.isnan(atom_array.coord).any(axis=1)]

    a = atom_array[atom_array.chain_id == chain_1]
    b = atom_array[atom_array.chain_id == chain_2]

    # A missing/empty chain has no interface; the cell list needs >= 1 atom.
    if len(a.coord) == 0 or len(b.coord) == 0:
        return set(), set()

    cell = struc.CellList(b.coord, cell_size=cutoff)
    # Row k holds the indices (into ``b``) of the atoms near atom k of ``a``;
    # unused slots are padded with -1.
    neighbors = cell.get_atoms(a.coord, radius=cutoff)

    hits = np.any(neighbors != -1, axis=1)
    res1 = set(map(int, a.res_id[hits]))

    idx2 = np.unique(neighbors[neighbors != -1])
    res2 = set(map(int, b.res_id[idx2]))

    return res1, res2

# Example usage: load the structure stored in "example.pdb" and ask for the
# interface between chains "A" and "B" using a 5.0 distance cutoff.
# NOTE(review): assumes the file yields a single-model structure — TODO confirm.
example_aa = strucio.load_structure("example.pdb")
residues_chain_A, residues_chain_B = interface_residue(example_aa, "A", "B", cutoff=5.0)


# Check clashes

In [None]:
from collections import defaultdict

def check_clashes(dataset, cutoff=2.0, ignore_seq=1):
    """Collect backbone atom pairs that are closer than ``cutoff``.

    For every sample only the backbone atoms (N, CA, C, O) with defined
    coordinates are considered.  Pairs within the same chain whose residue
    IDs differ by at most ``ignore_seq`` are skipped, because sequence
    neighbours are legitimately close; pairs from different chains are
    always checked.

    Parameters
    ----------
    dataset
        Indexable collection supporting ``len()``; ``dataset[k]['atom_array'][0]``
        must be a biotite ``AtomArray``.
    cutoff : float
        Distance below which two atoms are reported as clashing, in the
        units of the structure coordinates.
    ignore_seq : int
        Maximum residue-ID separation (within one chain) that is ignored.

    Returns
    -------
    collections.defaultdict
        Maps the sample index to a list of ``(res1, res2, distance)`` tuples
        sorted by increasing distance, where ``res1``/``res2`` are
        ``(chain_id, res_id, ins_code)``.  Samples without clashes have no
        entry.
    """
    backbone_names = ["N", "CA", "C", "O"]
    clash_summary = defaultdict(list)

    # Index-based loop: the dataset is only required to offer __len__ and
    # integer __getitem__.
    for sample_idx in range(len(dataset)):
        atom_array = dataset[sample_idx]['atom_array'][0]
        atom_array = atom_array[~np.isnan(atom_array.coord).any(axis=1)]
        aa = atom_array[np.isin(atom_array.atom_name, backbone_names)]

        # Fewer than two atoms cannot clash, and an empty array cannot be
        # put into a cell list.
        if aa.coord.shape[0] < 2:
            continue

        cell = struc.CellList(aa, cell_size=7)
        # Row k lists the atoms within ``cutoff`` of atom k, padded with -1.
        neighbors = cell.get_atoms(aa.coord, radius=cutoff)
        n, p = neighbors.shape

        i = np.repeat(np.arange(n), p)
        j = neighbors.reshape(-1)

        # Drop padding and keep each unordered pair once (this also removes
        # the self-pairs i == j).
        keep = (j != -1) & (i < j)
        i, j = i[keep], j[keep]

        # The sequence-separation filter is only meaningful within a chain.
        same_chain = aa.chain_id[i] == aa.chain_id[j]
        seq_diff = np.abs(aa.res_id[i] - aa.res_id[j])
        keep = ~same_chain | (seq_diff > ignore_seq)
        i, j = i[keep], j[keep]

        d = np.linalg.norm(aa.coord[i] - aa.coord[j], axis=1)
        order = np.argsort(d)

        for idx1, idx2, dist in zip(i[order], j[order], d[order]):
            res1 = (str(aa.chain_id[idx1]), int(aa.res_id[idx1]), str(aa.ins_code[idx1]))
            res2 = (str(aa.chain_id[idx2]), int(aa.res_id[idx2]), str(aa.ins_code[idx2]))
            clash_summary[sample_idx].append((res1, res2, float(dist)))

    return clash_summary
        
        