# 0 Import modules

In [12]:
import torch
import torch.nn as nn

# Added for data loading
import numpy as np
import pandas as pd

# 1 Load Data

## Notes on loader output
- For matrix files, `X` is a 2D float32 array of shape (n_drugs, n_targets). Non-numeric cells are coerced to NaN.
- For triplet files, `X` is a (n_interactions, 3) array: [drug_index, target_index, value].
- `meta['drug_ids']` and `meta['target_ids']` hold the original labels; use `meta['mappings']` to go from label -> index.

In [19]:
from pathlib import Path
from typing import Tuple, Dict, Any

# Locations of the two benchmark files, relative to the working directory
kiba_path = Path('data/Davis-KIBA/kiba.txt')
davis_path = Path('data/Davis-KIBA/davis.txt')

print('KIBA file exists:', kiba_path.exists())
print('Davis file exists:', davis_path.exists())

# Preview the KIBA file: look at no more than its first 10 lines, keep the
# non-blank ones (stripped), and show at most 5 of them.
if kiba_path.exists():
    head = []
    with kiba_path.open('r') as f:
        # zip() stops at whichever ends first: the 10-line budget or the file
        for _, line in zip(range(10), f):
            line = line.strip()
            if line:
                head.append(line)
    print('KIBA head:', head[:5])


KIBA file exists: True
Davis file exists: True
KIBA head: ['CHEMBL1087421 O00141 COC1=C(C=C2C(=C1)CCN=C2C3=CC(=C(C=C3)Cl)Cl)Cl MTVKTEAAKGTLTYSRMRGMVAILIAFMKQRRMGLNDFIQKIANNSYACKHPEVQSILKISQPQEPELMNANPSPPPSPSQQINLGPSSNPHAKPSDFHFLKVIGKGSFGKVLLARHKAEEVFYAVKVLQKKAILKKKEEKHIMSERNVLLKNVKHPFLVGLHFSFQTADKLYFVLDYINGGELFYHLQRERCFLEPRARFYAAEIASALGYLHSLNIVYRDLKPENILLDSQGHIVLTDFGLCKENIEHNSTTSTFCGTPEYLAPEVLHKQPYDRTVDWWCLGAVLYEMLYGLPPFYSRNTAEMYDNILNKPLQLKPNITNSARHLLEGLLQKDRTKRLGAKDDFMEIKSHVFFSLINWDDLINKKITPPFNPNVSGPNDLRHFDPEFTEEPVPNSIGKSPDSVLVTASVKEAAEAFLGFSYAPPTDSFL 11.1', 'CHEMBL1087421 O14920 COC1=C(C=C2C(=C1)CCN=C2C3=CC(=C(C=C3)Cl)Cl)Cl MSWSPSLTTQTCGAWEMKERLGTGGFGNVIRWHNQETGEQIAIKQCRQELSPRNRERWCLEIQIMRRLTHPNVVAARDVPEGMQNLAPNDLPLLAMEYCQGGDLRKYLNQFENCCGLREGAILTLLSDIASALRYLHENRIIHRDLKPENIVLQQGEQRLIHKIIDLGYAKELDQGSLCTSFVGTLQYLAPELLEQQKYTVTVDYWSFGTLAFECITGFRPFLPNWQPVQWHSKVRQKSEVDIVVSEDLNGTVKFSSSLPYPNNLNSVLAERLEKWLQLMLMWHPRQRGTDPTYGPNGCFKALDDILNLKLVHILNMVTGTIHTYPVTEDESLQSLKARIQQDTGIPEEDQELLQEAGLALIPDKP

In [21]:
from dataclasses import dataclass

@dataclass
class FiveColBatch:
    # Integer indices
    drug_idx: np.ndarray  # (N,)
    prot_idx: np.ndarray  # (N,)
    # Raw strings
    drug_id: np.ndarray   # (N,) object dtype
    prot_id: np.ndarray   # (N,) object dtype
    smiles: np.ndarray    # (N,) object dtype
    aa_seq: np.ndarray    # (N,) object dtype
    affinity: np.ndarray  # (N,) float32
    # Metadata
    drug_to_idx: dict
    prot_to_idx: dict


def load_five_col_dti(path: str | Path,
                       sep: str | None = None,
                       has_header: bool | None = None) -> FiveColBatch:
    """
    Load a 5-column DTI file with columns:
      {Drug Id, Protein Id, SMILES seq, Amino Acid seq, Binding Affinity}

    Parameters
    - path: file path
    - sep: field separator. If None, will try ',', '\t', or whitespace
    - has_header: if True/False forces header handling; if None, infer

    Returns a FiveColBatch:
      - drug_idx, prot_idx: integer indices for categorical ids
      - original ids/sequences as object arrays
      - affinity as float32
      - mappings dicts for reproducible indexing
    """
    path = Path(path)

    # Try to read with pandas robustly
    tried = []
    df = None
    seps_to_try = [sep] if sep is not None else [',', '\t', r'\s+']
    header_opts = [0, None] if has_header is None else ([0] if has_header else [None])
    for s in seps_to_try:
        for h in header_opts:
            try:
                engine = 'python' if s == r'\s+' else None
                df_try = pd.read_csv(path, sep=s, header=h, engine=engine)
                if df_try.shape[1] >= 5:
                    df = df_try
                    break
            except Exception as e:
                tried.append((s, h, str(e)))
        if df is not None:
            break
    if df is None:
        raise ValueError(f"Unable to parse file {path}. Attempts: {tried[:3]} ...")

    # Reduce to first 5 columns and assign names
    df = df.iloc[:, :5].copy()
    if df.columns.dtype == 'int64' or isinstance(df.columns[0], (int, np.integer)):
        df.columns = ['drug_id', 'prot_id', 'smiles', 'aa_seq', 'affinity']
    else:
        # Try to rename common names; otherwise overwrite
        rename_map = {}
        cols_lower = [str(c).strip().lower() for c in df.columns]
        mapping_rules = {
            'drug': 'drug_id', 'drug_id': 'drug_id', 'drugid': 'drug_id',
            'protein': 'prot_id', 'protein_id': 'prot_id', 'target': 'prot_id', 'prot_id': 'prot_id',
            'smiles': 'smiles', 'sequence': 'aa_seq', 'seq': 'aa_seq', 'aa': 'aa_seq',
            'affinity': 'affinity', 'binding affinity': 'affinity', 'label': 'affinity', 'y': 'affinity'
        }
        for i, c in enumerate(cols_lower):
            rename_map[df.columns[i]] = mapping_rules.get(c, df.columns[i])
        df = df.rename(columns=rename_map)
        # Ensure exact final names
        df.columns = ['drug_id', 'prot_id', 'smiles', 'aa_seq', 'affinity']

    # Clean and types
    for c in ['drug_id', 'prot_id', 'smiles', 'aa_seq']:
        df[c] = df[c].astype(str).str.strip()
    df['affinity'] = pd.to_numeric(df['affinity'], errors='coerce').astype(np.float32)

    # Build stable categorical indices
    drug_cats = pd.Categorical(df['drug_id'])
    prot_cats = pd.Categorical(df['prot_id'])
    drug_ids = list(drug_cats.categories)
    prot_ids = list(prot_cats.categories)
    drug_to_idx = {d: i for i, d in enumerate(drug_ids)}
    prot_to_idx = {p: i for i, p in enumerate(prot_ids)}

    drug_idx = df['drug_id'].map(drug_to_idx).to_numpy(dtype=np.int64)
    prot_idx = df['prot_id'].map(prot_to_idx).to_numpy(dtype=np.int64)

    return FiveColBatch(
        drug_idx=drug_idx,
        prot_idx=prot_idx,
        drug_id=df['drug_id'].to_numpy(object),
        prot_id=df['prot_id'].to_numpy(object),
        smiles=df['smiles'].to_numpy(object),
        aa_seq=df['aa_seq'].to_numpy(object),
        affinity=df['affinity'].to_numpy(dtype=np.float32),
        drug_to_idx=drug_to_idx,
        prot_to_idx=prot_to_idx,
    )

# Parse the KIBA file and report its size and dtypes
kiba_batch = load_five_col_dti(kiba_path)
print('KIBA 5-col loaded:')
print('N =', len(kiba_batch.affinity),
      '| unique drugs =', len(kiba_batch.drug_to_idx),
      '| unique proteins =', len(kiba_batch.prot_to_idx))
print('drug_idx dtype:', kiba_batch.drug_idx.dtype,
      '| affinity dtype:', kiba_batch.affinity.dtype)

# Same report for the Davis file
davis_batch = load_five_col_dti(davis_path)
print('Davis 5-col loaded:')
print('N =', len(davis_batch.affinity),
      '| unique drugs =', len(davis_batch.drug_to_idx),
      '| unique proteins =', len(davis_batch.prot_to_idx))
print('drug_idx dtype:', davis_batch.drug_idx.dtype,
      '| affinity dtype:', davis_batch.affinity.dtype)

KIBA 5-col loaded:
N = 118253 | unique drugs = 2111 | unique proteins = 229
drug_idx dtype: int64 | affinity dtype: float32
Davis 5-col loaded:
N = 30055 | unique drugs = 68 | unique proteins = 442
drug_idx dtype: int64 | affinity dtype: float32


Check the 5-col files:

In [22]:
def _trunc(s, n=40):
  s = str(s)
  return s if len(s) <= n else s[:n] + '...'

# Print one row from each loaded batch as a quick eyeball check.
# SMILES and amino-acid strings go through _trunc so each record stays short;
# affinity is shown with 4 significant digits.
i = 0  # row to display
print(f"KIBA[{i}] drug_id={kiba_batch.drug_id[i]} prot_id={kiba_batch.prot_id[i]} "
    f"affinity={float(kiba_batch.affinity[i]):.4g} "
    f"smiles={_trunc(kiba_batch.smiles[i])} aa_seq={_trunc(kiba_batch.aa_seq[i])}")

print(f"Davis[{i}] drug_id={davis_batch.drug_id[i]} prot_id={davis_batch.prot_id[i]} "
    f"affinity={float(davis_batch.affinity[i]):.4g} "
    f"smiles={_trunc(davis_batch.smiles[i])} aa_seq={_trunc(davis_batch.aa_seq[i])}")

KIBA[0] drug_id=CHEMBL1087421 prot_id=O14920 affinity=11.1 smiles=COC1=C(C=C2C(=C1)CCN=C2C3=CC(=C(C=C3)Cl)... aa_seq=MSWSPSLTTQTCGAWEMKERLGTGGFGNVIRWHNQETGEQ...
Davis[0] drug_id=11314340 prot_id=ABL1(E255K) affinity=5 smiles=CC1=C2C=C(C=CC2=NN1)C3=CC(=CN=C3)OCC(CC4... aa_seq=PFWKILNPLLERGTYYYFMGQQPGKVLGDQRRPSLPALHF...


Check model-friendly IDs and label mappings:

In [23]:
# Inspect integer indices and metadata for loaded batches (KIBA/Davis) and optional `meta` dict.

def check_batch(b, name):
  """Print a sanity report for a loaded batch.

  Parameters
  - b: object exposing drug_idx, prot_idx, drug_id, prot_id, affinity,
    drug_to_idx and prot_to_idx (e.g. a FiveColBatch)
  - name: label shown in the report header

  Returns None; the report is written to stdout.
  """
  print(f'== {name} FiveColBatch ==')
  n_rows = int(b.affinity.shape[0])
  print(f'N={n_rows} | unique drugs={len(b.drug_to_idx)} | unique proteins={len(b.prot_to_idx)}')

  # (tag, per-row index array, id -> index mapping) for both entity kinds
  fields = (
    ('drug', b.drug_idx, b.drug_to_idx),
    ('prot', b.prot_idx, b.prot_to_idx),
  )
  present = {tag: np.unique(idx) for tag, idx, _ in fields}

  # Dtype, value range and number of distinct indices actually used
  for tag, idx, _ in fields:
    uniq = present[tag]
    if uniq.size:
      lo, hi = int(uniq.min()), int(uniq.max())
    else:
      lo = hi = "NA"
    print(f'{tag}_idx: dtype={idx.dtype}, range=[{lo}, {hi}], n_unique={int(uniq.size)}')

  # Every index in use must be a valid position for its mapping
  in_range = {
    tag: set(present[tag]).issubset(range(len(mapping)))
    for tag, _, mapping in fields
  }
  for tag, _, _ in fields:
    print(f'{tag}_idx within mapping range: {in_range[tag]}')

  # Mapping values should be exactly 0..n-1
  contiguous = {
    tag: sorted(set(mapping.values())) == list(range(len(mapping)))
    for tag, _, mapping in fields
  }
  for tag, _, _ in fields:
    print(f'{tag}_to_idx contiguous: {contiguous[tag]}')

  # Affinity summary; min/max are skipped for an empty batch
  aff = b.affinity
  nan_count = int(np.isnan(aff).sum())
  if n_rows:
    aff_min, aff_max = float(np.nanmin(aff)), float(np.nanmax(aff))
  else:
    aff_min = aff_max = "NA"
  print(f'affinity: dtype={aff.dtype}, min={aff_min}, max={aff_max}, NaNs={nan_count}')

  # First few rows, to eyeball the id -> index pairing
  for row in range(min(3, n_rows)):
    d_idx, p_idx = int(b.drug_idx[row]), int(b.prot_idx[row])
    print(f'[{row}] drug_id={b.drug_id[row]} -> {d_idx}, prot_id={b.prot_id[row]} -> {p_idx}, y={float(aff[row]):.4g}')

# Run the sanity report on both batches (kiba_batch / davis_batch must already
# exist in the notebook namespace).
check_batch(kiba_batch, 'KIBA')
check_batch(davis_batch, 'Davis')

# Optional: if a generic `meta` dict happens to be defined in the kernel, validate
# its mappings as well. Nothing in the cells shown here creates `meta`, so on a
# fresh run this branch is skipped.
# NOTE(review): the `x or []` fallbacks below assume list/dict values; a NumPy
# array for 'drug_ids' / 'target_ids' would make `or` raise - confirm the type.
if 'meta' in globals() and isinstance(meta, dict):
  print('== meta dict ==')
  kind = meta.get('kind')
  d_ids = meta.get('drug_ids') or []
  t_ids = meta.get('target_ids') or []
  maps = meta.get('mappings') or {}
  dmap = maps.get('drug_to_idx') or {}
  tmap = maps.get('target_to_idx') or {}

  # Sizes, and whether each mapping's values are exactly 0..n-1
  print(f"kind={kind} | n_drug_ids={len(d_ids)} | n_target_ids={len(t_ids)}")
  print(f"drug_to_idx: n_keys={len(dmap)}, idx_contiguous={sorted(set(dmap.values())) == list(range(len(dmap)))}")
  print(f"target_to_idx: n_keys={len(tmap)}, idx_contiguous={sorted(set(tmap.values())) == list(range(len(tmap)))}")

  # Coverage checks (set-based; may be large): ids that have no entry in the mapping
  d_missing = [d for d in set(d_ids) if d not in dmap]
  t_missing = [t for t in set(t_ids) if t not in tmap]
  print(f"drug ids missing in mapping: {len(d_missing)}")
  print(f"target ids missing in mapping: {len(t_missing)}")

  # Show up to three label -> index samples from each mapping
  for k, v in list(dmap.items())[:3]:
    print(f"drug_to_idx sample: {k} -> {v}")
  for k, v in list(tmap.items())[:3]:
    print(f"target_to_idx sample: {k} -> {v}")

== KIBA FiveColBatch ==
N=118253 | unique drugs=2111 | unique proteins=229
drug_idx: dtype=int64, range=[0, 2110], n_unique=2111
prot_idx: dtype=int64, range=[0, 228], n_unique=229
drug_idx within mapping range: True
prot_idx within mapping range: True
drug_to_idx contiguous: True
prot_to_idx contiguous: True
affinity: dtype=float32, min=0.0, max=17.200180053710938, NaNs=0
[0] drug_id=CHEMBL1087421 -> 22, prot_id=O14920 -> 6, y=11.1
[1] drug_id=CHEMBL1087421 -> 22, prot_id=O15111 -> 9, y=11.1
[2] drug_id=CHEMBL1087421 -> 22, prot_id=P00533 -> 25, y=11.1
== Davis FiveColBatch ==
N=30055 | unique drugs=68 | unique proteins=442
drug_idx: dtype=int64, range=[0, 67], n_unique=68
prot_idx: dtype=int64, range=[0, 441], n_unique=442
drug_idx within mapping range: True
prot_idx within mapping range: True
drug_to_idx contiguous: True
prot_to_idx contiguous: True
affinity: dtype=float32, min=5.0, max=10.795880317687988, NaNs=0
[0] drug_id=11314340 -> 8, prot_id=ABL1(E255K) -> 2, y=5
[1] drug_id=1

# 2 Fetch Protein Information

- UniProt IDs are conveniently provided for KIBA (e.g. `O14920`); Davis labels its targets by kinase name instead (e.g. `ABL1(E255K)`).
- We shall fetch the 3D structures in whatever format ProtDB provides, and deal with them later.
- If no experimental structure is available, fall back to AlphaFoldDB

**Do so with `fetch_structures_v2.py`**

# 3 Embedding

ESMFold requires Python 3.9, so I am creating a separate environment for it; that environment is used only for generating the embeddings.

In [None]:
import gemmi

## Generate feature files