In [None]:
def native_aa(enzyme_file_path):
    """Extract the native amino-acid sequence from a PDB file.

    ATOM/HETATM records are read in file order and each distinct residue
    (identified by chain ID, residue number and insertion code) contributes
    one entry, taken from the first atom record seen for that residue.

    Non-standard residue names are folded back onto a standard three-letter
    code before conversion (see ``aliases`` below). Residues whose name is
    not one of the 20 standard amino acids after that mapping (e.g. water,
    ions, ligands) are left out of the result.

    Parameters
    ----------
    enzyme_file_path : str or path-like
        Path of the PDB file to read.

    Returns
    -------
    list of str
        One-letter amino-acid codes, in the order the residues appear in
        the file.
    """
    # Residue names that are never part of the protein sequence.
    exclude = {"WAT", "Na+"}

    # Non-standard residue name -> standard three-letter code.
    aliases = {
        # project-specific mapping kept from the original code: MER -> SER
        "MER": "SER",
        # protonation-state variants back to the standard code
        "HIE": "HIS",
        "HID": "HIS",
        "HIP": "HIS",
        "ASH": "ASP",
        "GLH": "GLU",
        "LYN": "LYS",
        "CYX": "CYS",
        "CYM": "CYS",
    }

    # three to one letter map
    three_to_one = {
        "ALA": "A", "CYS": "C", "ASP": "D", "GLU": "E", "PHE": "F",
        "GLY": "G", "HIS": "H", "ILE": "I", "LYS": "K", "LEU": "L",
        "MET": "M", "ASN": "N", "PRO": "P", "GLN": "Q", "ARG": "R",
        "SER": "S", "THR": "T", "VAL": "V", "TRP": "W", "TYR": "Y",
    }

    residues = []
    seen = set()

    with open(enzyme_file_path) as f:
        for line in f:
            if not line.startswith(("ATOM", "HETATM")):
                continue

            # Fixed-column PDB fields; slices (not indexing) so that a
            # truncated record yields empty fields instead of IndexError.
            resname = line[17:20].strip()  # residue name
            chain = line[21:22]            # chain ID
            resseq = line[22:26].strip()   # residue sequence number
            icode = line[26:27].strip()    # insertion code (e.g. 52A)

            # The insertion code is part of the residue identity: without
            # it, residues 52 and 52A would be collapsed into one entry.
            key = (chain, resseq, icode)
            if key in seen:
                continue
            seen.add(key)

            resname = aliases.get(resname, resname)
            if resname not in exclude:
                residues.append(resname)

    # Convert to one-letter codes, skipping anything that is not a standard
    # amino acid.
    one_letter_residues = [three_to_one[res] for res in residues if res in three_to_one]

    return one_letter_residues



# Native sequences of the eight enzymes, read from their PDB structures.
# Forward slashes are used so the relative paths resolve on both Windows and
# POSIX systems (a backslash is only a path separator on Windows).
BlaC = native_aa("Enzymes/BlaC.pdb")
CTXM = native_aa("Enzymes/CTXM.pdb")
KPC  = native_aa("Enzymes/KPC_MM.pdb")
NMCA = native_aa("Enzymes/NMCA.pdb")
SFC  = native_aa("Enzymes/SFC.pdb")
SHV  = native_aa("Enzymes/SHV.pdb")
SME  = native_aa("Enzymes/SME.pdb")
TEM  = native_aa("Enzymes/TEM.pdb")


In [None]:
# Enzyme metadata, keyed by enzyme name. Each entry holds:
#   "carbapenemase": class label (1 = carbapenemase, 0 = not)
#   "sequence":      native sequence as a list of one-letter codes
enzyme_info = {
    "BlaC":   {"carbapenemase": 0, "sequence": BlaC},
    "CTXM16": {"carbapenemase": 0, "sequence": CTXM},
    "KPC2":   {"carbapenemase": 1, "sequence": KPC},
    "NMCA":   {"carbapenemase": 1, "sequence": NMCA},
    "SFC1":   {"carbapenemase": 1, "sequence": SFC},
    "SHV1":   {"carbapenemase": 0, "sequence": SHV},
    "SME1":   {"carbapenemase": 1, "sequence": SME},
    "TEM1":   {"carbapenemase": 0, "sequence": TEM},
}

In [None]:
import numpy as np, h5py
from Bio import pairwise2
from Bio.Align import substitution_matrices
from pathlib import Path

# Reference used for numbering: the TEM1 native sequence as a single string,
# whose first residue is assigned the Ambler number below.
START_AMBLER_FOR_TEM = 26  # TEM starts at Ambler 26
TEM1_REF = "".join(enzyme_info["TEM1"]["sequence"])

# Input: one HDF5 file per enzyme and run; output: a single combined file.
PATH_TEMPLATE = r"Data/{enzyme}/{enzyme}_{run_id}_pdb_frames.h5"
TARGET_FILE = "organised_data_new.h5"

def ambler_map_labels(query_seq_str, ref_seq_str, start_ambler=1):
    """Label every query residue with a reference-based position.

    The query is globally aligned to the reference (BLOSUM62, gap penalties
    -10 / -1 passed to ``pairwise2.align.globalds``) and only the first
    returned alignment is used.

    * A query residue aligned to a reference residue gets that reference
      residue's number, counted from ``start_ambler``.
    * A query residue aligned to a gap in the reference (an insertion) gets
      the number of the most recent *matched* reference position followed by
      a letter: ``a`` for the first insertion after that position, ``b`` for
      the second, and so on. Insertions before any match use
      ``start_ambler - 1`` as their number.
    * Reference residues with no query counterpart advance the numbering but
      produce no label.

    Parameters
    ----------
    query_seq_str : str
        Sequence to be labelled.
    ref_seq_str : str
        Reference sequence that defines the numbering.
    start_ambler : int, optional
        Number assigned to the first reference residue.

    Returns
    -------
    list of str
        One label per query residue, in query order.
    """
    blosum62 = substitution_matrices.load("BLOSUM62")
    alignments = pairwise2.align.globalds(
        ref_seq_str, query_seq_str, blosum62, -10, -1, one_alignment_only=True
    )
    best = alignments[0]

    labels = []
    insertions_after = {}          # anchor number -> insertions labelled so far
    ref_pos = start_ambler - 1     # number of the last reference residue consumed
    anchor = None                  # last reference number matched to a query residue

    for ref_ch, qry_ch in zip(best.seqA, best.seqB):
        ref_is_gap = ref_ch == '-'
        if not ref_is_gap:
            ref_pos += 1

        if qry_ch == '-':
            # Reference residue absent from the query: nothing to label.
            continue

        if ref_is_gap:
            # Insertion in the query relative to the reference.
            base = (start_ambler - 1) if anchor is None else anchor
            used = insertions_after.get(base, 0)
            labels.append(f"{base}{chr(ord('a') + used)}")
            insertions_after[base] = used + 1
        else:
            labels.append(str(ref_pos))
            anchor = ref_pos

    assert len(labels) == len(query_seq_str)
    return labels

def split_label(label):
    """Split a position label such as ``"237a"`` into ``(base, rank)``.

    ``base`` is the integer value of the leading digits (0 when the label has
    no leading digits). ``rank`` is 0 when nothing follows the digits;
    otherwise it is derived from the first character after them, with
    ``'a'`` -> 1, ``'b'`` -> 2, and so on.
    """
    # Locate the end of the leading run of digits.
    digits_end = len(label)
    for pos, ch in enumerate(label):
        if not ch.isdigit():
            digits_end = pos
            break

    base = int(label[:digits_end]) if digits_end else 0
    if digits_end == len(label):
        return base, 0
    return base, ord(label[digits_end]) - ord('a') + 1

# Fix for 237a/238 swap
# small normaliser for the 237a/238 swap in class A carbapenemases
def reset_ranks_for_base(base_nums, ins_rank, base, canonical_idx):
    """
    Re-rank, in place, every residue numbered ``base``.

    The residue at ``canonical_idx`` receives rank 0 (the canonical residue).
    All other residues whose entry in ``base_nums`` equals ``base`` receive
    ranks 1, 2, 3, ... in increasing index order. ``ins_rank`` is modified in
    place; nothing is returned.
    """
    ins_rank[canonical_idx] = 0

    next_rank = 1
    for pos in np.nonzero(base_nums == base)[0]:
        if pos == canonical_idx:
            # already handled above
            continue
        ins_rank[pos] = next_rank
        next_rank += 1

def normalise_237a_238(base_nums, ins_rank, seq_str):
    """
    Promote a Cys labelled '237a' to the canonical 238 position, in place.

    The change is applied only when there is exactly one Cys with base 237
    and insertion rank 1 *and* exactly one canonical (rank 0) residue with
    base 238. In that case the Cys is renumbered to base 238 and made
    canonical, while the previous 238 residue and any 238 insertions become
    238a, 238b, ... in sequence order. Otherwise both arrays are left
    untouched. Any remaining 237 insertions keep their existing ranks.
    """
    residues = np.array(list(seq_str))

    is_cys_237a = (base_nums == 237) & (ins_rank == 1) & (residues == 'C')
    is_canonical_238 = (base_nums == 238) & (ins_rank == 0)

    cys_hits = np.where(is_cys_237a)[0]
    canonical_hits = np.where(is_canonical_238)[0]

    # Only act on the unambiguous case: one candidate on each side.
    if len(cys_hits) != 1 or len(canonical_hits) != 1:
        return

    cys_pos = cys_hits[0]
    # Move the Cys from base 237 to base 238 ...
    base_nums[cys_pos] = 238
    # ... then renumber all 238 entries with the Cys as the canonical one.
    reset_ranks_for_base(base_nums, ins_rank, base=238, canonical_idx=cys_pos)


# Per-enzyme lookup tables, built once and reused for every run:
#   ambler_idx_cache[name] -> int32 array of base numbers, one per residue
#   ins_rank_cache[name]   -> int16 array of insertion ranks (0 = canonical)
ambler_idx_cache = {}
ins_rank_cache = {}
for name, info in enzyme_info.items():
    qseq = "".join(info["sequence"])
    labels = ambler_map_labels(qseq, TEM1_REF, start_ambler=START_AMBLER_FOR_TEM)

    # Split each label ("237a" -> (237, 1)) and collect the two columns.
    split_labels = [split_label(lab) for lab in labels]
    base_nums = np.array([base for base, _ in split_labels], dtype=np.int32)
    ins_rank = np.array([rank for _, rank in split_labels], dtype=np.int16)

    # The 237a/238 normalisation is applied to carbapenemases only.
    is_carbapenemase = info.get("carbapenemase", 0) == 1
    if is_carbapenemase:
        normalise_237a_238(base_nums, ins_rank, qseq)

    ambler_idx_cache[name] = base_nums
    ins_rank_cache[name] = ins_rank

# Write the combined output file, including the Ambler index data as well.
# Layout: /<enzyme>            attrs: carbapenemase, sequence
#         /<enzyme>/run_XX     datasets: probs, ambler_idx, ambler_ins_rank
with h5py.File(TARGET_FILE, "w") as out_f:
    for enzyme, info in enzyme_info.items():
        eg = out_f.create_group(enzyme)
        eg.attrs["carbapenemase"] = int(info["carbapenemase"])
        eg.attrs["sequence"] = "".join(info["sequence"])

        for run_id in (1, 2, 3):
            src_path = PATH_TEMPLATE.format(enzyme=enzyme, run_id=run_id)
            if not Path(src_path).exists():
                # Missing runs are reported and skipped, not treated as errors.
                print(f"{src_path} not found; skipping.")
                continue

            with h5py.File(src_path, "r") as src_f:
                # The frames live under the first top-level group of the file.
                top  = list(src_f.keys())[0]
                data = src_f[f"{top}/frames"][:]          # expected (5000, N_residues, 20)
                probs = data.transpose(0, 2, 1)           # expected (5000, 20, N_residues)
                N = probs.shape[2]

            amb_idx = ambler_idx_cache[enzyme]
            ins_rnk = ins_rank_cache[enzyme]
            # The labels were derived from the PDB sequence; they must line up
            # one-to-one with the residue axis of the frame data.
            if N != len(amb_idx):
                raise RuntimeError(f"{enzyme} run {run_id}: probs residues={N} vs labels={len(amb_idx)}")

            # h5py rejects a chunk shape larger than the dataset in any
            # dimension, so clamp the frame chunk to the number of frames and
            # take the class axis from the data instead of hard-coding 20.
            n_frames, n_classes = probs.shape[0], probs.shape[1]
            chunk_shape = (min(500, n_frames), n_classes, N)

            rg = eg.create_group(f"run_{run_id:02d}")
            rg.create_dataset("probs", data=probs, chunks=chunk_shape, compression="gzip")
            rg.create_dataset("ambler_idx", data=amb_idx, dtype="i4")
            rg.create_dataset("ambler_ins_rank", data=ins_rnk, dtype="i2")

print(f"{TARGET_FILE} created with Ambler labels (normalised at 238 for carbapenemases).")
