In [None]:
pip install Biopython



In [None]:
from google.colab import drive

# Mount Google Drive; every data path used later in this notebook lives under this mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch import optim
from torch import nn
from torch.utils.data import DataLoader, random_split, Dataset, Subset
from tqdm import tqdm
from itertools import chain
from itertools import islice
import numpy as np
import random
import os
import glob
import pickle
from torch.nn.utils.rnn import pad_sequence
from Bio import SeqIO
from Bio.PDB import PDBList, PDBParser, PPBuilder
from Bio.Seq import Seq
from Bio.PDB.Polypeptide import is_aa
from Bio.SeqUtils import seq1
from collections import defaultdict
from functools import partial
import multiprocessing
from Bio.PDB.Polypeptide import is_aa

In [None]:
def parse_single_fasta_file(path):
    """Read a FASTA file and return all non-header lines, stripped and concatenated.

    Header lines (starting with '>') are dropped, so every record in the file is merged
    into a single string.
    """
    with open(path, 'r') as handle:
        body_lines = [raw.strip() for raw in handle if not raw.startswith(">")]
    return ''.join(body_lines)

def parse_fasta_file(filepath):
    """Return every record sequence in a FASTA file as a stripped, upper-case string.

    Any error while opening or parsing (missing file, malformed content, ...) yields an
    empty list, so one bad file does not abort a batch run over many files.
    """
    try:
        sequences = []
        for record in SeqIO.parse(filepath, "fasta"):
            sequences.append(str(record.seq).strip().upper())
        return sequences
    except Exception:
        return []

def get_fasta_files_from_nested_folders(root_folder, limit=None):
    """Return paths of files matching ``*.fa*`` anywhere below ``root_folder``.

    Args:
        root_folder: directory searched recursively.
        limit: if truthy, return at most this many paths; ``None`` or ``0`` returns all.

    Returns:
        list[str] of paths in glob traversal order (not sorted).

    NOTE(review): the ``*.fa*`` pattern also matches e.g. ``.fai``/``.fa.gz`` files;
    confirm that is intended for this directory.
    """
    pattern = os.path.join(root_folder, "**", "*.fa*")
    # iglob is lazy (glob.glob(...) is just list(glob.iglob(...))), so with a limit the
    # directory walk stops after `limit` matches instead of listing the whole tree first.
    matches = glob.iglob(pattern, recursive=True)
    return list(islice(matches, limit)) if limit else list(matches)

def load_fasta_sequences_parallel(file_list, num_workers=4):
    """Parse FASTA files in a process pool and return all their sequences as one flat list.

    Each file is parsed with ``parse_fasta_file``. Results are concatenated in the same
    order as ``file_list`` so that the returned list — and any integer indices later drawn
    into it — is reproducible between runs.

    Args:
        file_list: paths to FASTA files.
        num_workers: maximum number of worker processes (capped at ``len(file_list)``);
            a falsy value lets ``multiprocessing.Pool`` choose.

    Returns:
        list[str] of sequences; empty if ``file_list`` is empty.
    """
    if not file_list:
        return []
    processes = min(num_workers, len(file_list)) if num_workers else None
    sequences = []
    with multiprocessing.Pool(processes=processes) as pool:
        # imap (not imap_unordered): completion order varies from run to run, input order does not.
        results = pool.imap(parse_fasta_file, file_list, chunksize=100)
        for seq_list in tqdm(results, total=len(file_list)):
            sequences.extend(seq_list)
    return sequences

def parse_fasta_sequences(filepath):
    """Return the sequences of a multi-record FASTA file, one string per record.

    Sequence lines are stripped and joined until the next '>' header. Records with no
    sequence text are omitted, and text before the first header counts as a record.
    """
    peptides = []
    chunks = []

    def flush():
        merged = ''.join(chunks)
        if merged:
            peptides.append(merged)
        chunks.clear()

    with open(filepath, 'r') as handle:
        for line in handle:
            if line.startswith('>'):
                flush()
            else:
                chunks.append(line.strip())
        flush()
    return peptides

# Load the RNA pool (recursively, from up to 200k FASTA files) and the peptide pool.
rna_fasta_root = "/content/drive/MyDrive/Deep Learning-RNA-Peptide-Interaction/RNA-fastaFiles"
peptide_fasta_path = "/content/drive/MyDrive/Deep Learning-RNA-Peptide-Interaction/peptideatlas.fasta"

fasta_files = get_fasta_files_from_nested_folders(rna_fasta_root, limit=200000)
rna_seqs = load_fasta_sequences_parallel(fasta_files, num_workers=8)
peptide_seqs = parse_fasta_sequences(peptide_fasta_path)

100%|██████████| 102321/102321 [06:58<00:00, 244.72it/s]


In [None]:
def _pdb_ids_from_pair_file(pair_file):
    """Return the set of lower-case PDB IDs named in the first column of a pair file.

    Lines are whitespace-split; only lines with exactly two tokens are used, and a first
    token equal to "protein" (any case) is treated as a header. The PDB ID is the part of
    the first token before the first "_".
    """
    pdb_ids = set()
    with open(pair_file, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) != 2 or parts[0].lower() == "protein":
                continue  # Skip headers or malformed lines
            pdb_ids.add(parts[0].split("_")[0].lower())
    return pdb_ids


def download_multiple_pdbs(pair_file, out_dir, timeout=60):
    """Download ``<pdb_id>.pdb`` from RCSB for every structure referenced in ``pair_file``.

    Files already present in ``out_dir`` are skipped. Each download goes to a temporary
    file in ``out_dir`` that is renamed into place only on success, so a failed or
    interrupted transfer never leaves a truncated ``.pdb`` that later runs would treat
    as done. Failures are printed and skipped.

    Args:
        pair_file: whitespace-separated pair file (see ``_pdb_ids_from_pair_file``).
        out_dir: destination directory (created if missing).
        timeout: per-request socket timeout in seconds.

    NOTE(review): very large entries may only be offered as mmCIF by RCSB; those show up
    here as reported download failures.
    """
    import os
    import shutil
    import tempfile
    import urllib.request

    os.makedirs(out_dir, exist_ok=True)

    for pdb_id in sorted(_pdb_ids_from_pair_file(pair_file)):
        filepath = os.path.join(out_dir, f"{pdb_id}.pdb")
        if os.path.exists(filepath):
            continue  # already downloaded
        url = f"https://files.rcsb.org/download/{pdb_id.upper()}.pdb"
        fd, tmp_path = tempfile.mkstemp(dir=out_dir, suffix=".part")
        try:
            with os.fdopen(fd, "wb") as out_f, urllib.request.urlopen(url, timeout=timeout) as response:
                shutil.copyfileobj(response, out_f)
            os.replace(tmp_path, filepath)
        except Exception as exc:
            print(f"Failed to download {pdb_id} from {url}: {exc}")
            if os.path.exists(tmp_path):
                os.remove(tmp_path)


#download_multiple_pdbs("RPI2241.txt", "./pdb_files")

def load_rpi2241_pairs(filepath):
    """Read (protein, rna) identifier pairs from an RPI2241-style file as upper-case tuples.

    Each line is whitespace-split; only two-token lines are kept, and a line whose first
    token is "protein" (any case) is treated as a header. Duplicates collapse into a set.
    """
    with open(filepath, 'r') as handle:
        token_rows = (line.strip().split() for line in handle)
        return {
            (tokens[0].upper(), tokens[1].upper())
            for tokens in token_rows
            if len(tokens) == 2 and tokens[0].lower() != "protein"
        }

RPI2241_PAIRS_PATH = "/content/drive/MyDrive/Deep Learning-RNA-Peptide-Interaction/RPI2241.txt"
rpi2241_positive_pairs = load_rpi2241_pairs(RPI2241_PAIRS_PATH)

In [None]:
# One-hot encoding function for RNA and Peptide sequences
# Channel order A, U, G, C. 'T' shares the U channel so DNA-alphabet input encodes like RNA.
_RNA_BASE_TO_CHANNEL = {'A': 0, 'U': 1, 'G': 2, 'C': 3, 'T': 1}


def one_hot_encodeRNA(sequence):
    """One-hot encode a nucleotide sequence as a ``(len(sequence), 4)`` float32 tensor.

    Channels are A, U, G, C. Matching is case-insensitive and 'T' is encoded as 'U'.
    Any other symbol (N, I, gaps, ...) produces an all-zero row.
    """
    one_hot = np.zeros((len(sequence), 4), dtype=np.float32)
    for i, base in enumerate(sequence):
        channel = _RNA_BASE_TO_CHANNEL.get(base.upper())
        if channel is not None:
            one_hot[i, channel] = 1.0
    return torch.tensor(one_hot, dtype=torch.float)
# The 20 standard amino acids in channel order; built once instead of on every call.
_PEPTIDE_AMINO_ACIDS = ['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I',
                        'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
_PEPTIDE_AA_TO_INDEX = {aa: idx for idx, aa in enumerate(_PEPTIDE_AMINO_ACIDS)}


def one_hot_encodepeptide(sequence):
    """One-hot encode a peptide as a ``(len(sequence), 20)`` float32 tensor.

    Channel order follows ``_PEPTIDE_AMINO_ACIDS``. Matching is case-insensitive.
    Non-standard letters (X, B, Z, U, O, '*', ...) produce all-zero rows.
    """
    one_hot = np.zeros((len(sequence), len(_PEPTIDE_AMINO_ACIDS)), dtype=np.float32)
    for i, aa in enumerate(sequence):
        idx = _PEPTIDE_AA_TO_INDEX.get(aa.upper())
        if idx is not None:
            one_hot[i, idx] = 1.0
    return torch.tensor(one_hot, dtype=torch.float)

In [None]:
def precompute_fasta_encodings(seqs, encode_fn):
    """Apply ``encode_fn`` to each sequence (with a tqdm progress bar) and return the results in order."""
    encoded = []
    for seq in tqdm(seqs):
        encoded.append(encode_fn(seq))
    return encoded

# NOTE(review): neither list is referenced again in the later cells shown here; the dataset
# re-encodes sequences itself. Confirm before relying on (or dropping) these.
rna_tensors, pep_tensors = (
    precompute_fasta_encodings(rna_seqs, one_hot_encodeRNA),
    precompute_fasta_encodings(peptide_seqs, one_hot_encodepeptide),
)

100%|██████████| 102989/102989 [00:14<00:00, 7032.56it/s]
100%|██████████| 2492451/2492451 [01:29<00:00, 27811.49it/s]


In [None]:
# Residue name -> one-letter nucleotide code; keys include 3-letter names, a few
# modified nucleotides, and the plain 1-letter names.
RNA_MAP = dict(
    ADE="A", CYT="C", GUA="G", URI="U",
    PSU="U", INO="I", GTP="G", OMC="C",
    A="A", C="C", G="G", U="U",  # Handle both 1/3-letter
)

# Three-letter amino-acid residue name -> one-letter code.
AMINO_ACID_MAP = dict(
    ALA='A', ARG='R', ASN='N', ASP='D', CYS='C',
    GLU='E', GLN='Q', GLY='G', HIS='H', ILE='I',
    LEU='L', LYS='K', MET='M', PHE='F', PRO='P',
    SER='S', THR='T', TRP='W', TYR='Y', VAL='V',
    ASX='B', GLX='Z', SEC='U', PYL='O',  # Unusual residues
)


In [None]:
def get_chain_sequence(chain):
    """Return the one-letter sequence of a chain's non-hetero residues.

    Residues whose ``id[0]`` is not " " (hetero groups) are skipped. Each remaining residue
    name is looked up in AMINO_ACID_MAP first, then RNA_MAP. Names found in neither are dropped.
    """
    def one_letter(residue):
        name = residue.resname.strip().upper()
        return AMINO_ACID_MAP.get(name) or RNA_MAP.get(name)

    letters = (one_letter(res) for res in chain if res.id[0] == " ")
    return "".join(letter for letter in letters if letter)

In [None]:
SEQUENCE_CACHE_PATH = "/content/drive/MyDrive/Deep Learning-RNA-Peptide-Interaction/structure_sequences_cache.pkl"
# Bump whenever chain typing / sequence extraction changes so stale pickles are rebuilt.
CACHE_VERSION = 2


def get_chain_type(chain):
  """Classify a chain as 'protein', 'rna', or None.

  The decision uses the chain's first non-hetero residue (``residue.id[0] == " "``), so
  leading hetero groups (ligands, caps, waters) do not decide the type. The chain is
  'protein' if ``is_aa`` accepts that residue, and 'rna' if its name is in RNA_MAP.
  Otherwise (e.g. DNA, or no non-hetero residues at all) the result is None and the
  caller skips the chain.
  """
  for residue in chain.get_residues():
    if residue.id[0] != " ":
      continue
    if is_aa(residue):
      return 'protein'
    if residue.resname.strip().upper() in RNA_MAP:
      return 'rna'
    return None
  return None


def _load_sequence_cache(cache_path):
  """Return the unpickled cache dict, or None if it is missing, unreadable, or not a dict.

  NOTE: pickle.load can execute code from the file; only point cache_path at files you created.
  """
  if not os.path.exists(cache_path):
    return None
  try:
    with open(cache_path, 'rb') as f:
      data = pickle.load(f)
  except Exception as e:
    print(f"Ignoring unreadable cache {cache_path}: {e}")
    return None
  return data if isinstance(data, dict) else None


def _write_sequence_cache(cache_path, payload):
  """Pickle ``payload`` to ``cache_path`` via a temp file + os.replace, so a crash never leaves a half-written cache."""
  cache_dir = os.path.dirname(cache_path)
  if cache_dir:
    os.makedirs(cache_dir, exist_ok=True)
  tmp_path = cache_path + ".tmp"
  with open(tmp_path, 'wb') as f:
    pickle.dump(payload, f)
  os.replace(tmp_path, cache_path)


def get_cached_structure_sequences(pdb_dir, cache_path):
  """Return ``{pdb_id: {chain_id: {'type': ..., 'sequence': ...}}}`` for every .pdb in pdb_dir.

  The result is pickled to ``cache_path`` together with CACHE_VERSION and the sorted list
  of .pdb file names. A cache is reused only when its version matches and, if pdb_dir
  exists, its file list matches the directory. Adding or removing structure files
  therefore triggers a rebuild. Files that fail to parse are reported and omitted.

  Raises:
    FileNotFoundError: if pdb_dir does not exist and no usable cache is available.
  """
  pdb_files = None
  if os.path.isdir(pdb_dir):
    pdb_files = sorted(f for f in os.listdir(pdb_dir) if f.lower().endswith(".pdb"))

  cached = _load_sequence_cache(cache_path)
  if (cached is not None and cached.get('version') == CACHE_VERSION
      and (pdb_files is None or cached.get('pdb_files') == pdb_files)):
    return cached['sequences']

  if pdb_files is None:
    raise FileNotFoundError(f"PDB directory not found and no valid cache: {pdb_dir}")

  structure_sequences = {}
  parser = PDBParser(QUIET=True) if pdb_files else None

  for pdb_file in pdb_files:
    pdb_id = os.path.splitext(pdb_file)[0].lower()
    file_path = os.path.join(pdb_dir, pdb_file)
    try:
      structure = parser.get_structure(pdb_id, file_path)
      model = next(structure.get_models())  # first model only
      chains = {}
      for chain in model:
        chain_type = get_chain_type(chain)
        if chain_type:
          chains[chain.id] = {'type': chain_type, 'sequence': get_chain_sequence(chain)}
      structure_sequences[pdb_id] = chains
    except Exception as e:
      print(f"Error processing {pdb_id}: {str(e)}")

  _write_sequence_cache(cache_path, {
      'version': CACHE_VERSION,
      'pdb_files': pdb_files,
      'sequences': structure_sequences,
  })
  return structure_sequences

PDB_FILES_DIR = "/content/drive/MyDrive/Deep Learning-RNA-Peptide-Interaction/pdb_files"
rpi_structure_chains = get_cached_structure_sequences(
    pdb_dir=PDB_FILES_DIR,
    cache_path=SEQUENCE_CACHE_PATH,
)

def generate_valid_negatives(positive_pairs, structure_chains, num_negatives, max_attempts=None):
  """Sample protein/RNA chain pairs from the same PDB entry that are not known positives.

  Candidate entries are those in ``structure_chains`` with at least one 'protein' chain and
  one 'rna' chain that have a non-empty sequence. Each draw picks a candidate entry
  uniformly, then one protein chain and one RNA chain uniformly within it.

  Positive exclusion is case-insensitive: ``positive_pairs`` may be upper-case while
  candidates use the lower-case PDB id, so this errs toward excluding a pair rather than
  emitting a true positive as a negative.

  NOTE(review): chains from the same complex that are not listed as positives may still
  be in contact; treat these as "unlabelled", not certainly non-interacting.

  Args:
    positive_pairs: iterable of ``(prot_chain, rna_chain)`` strings like ``"1ABC_A"``.
    structure_chains: ``{pdb_id: {chain_id: {'type': ..., 'sequence': ...}}}``.
    num_negatives: number of distinct negatives wanted.
    max_attempts: cap on draws; defaults to ``20 * num_negatives``.

  Returns:
    Up to ``num_negatives`` distinct ``("<pdb>_<chain>", "<pdb>_<chain>", 0)`` tuples.
    Fewer are returned if not enough distinct candidates turn up within ``max_attempts``.
  """
  def _key(prot, rna):
    return (prot.upper(), rna.upper())

  positive_keys = {_key(p, r) for p, r in positive_pairs}

  # Pre-filter eligible entries once; sorting makes seeded sampling reproducible.
  candidates = []
  for pdb_id, chains in sorted(structure_chains.items()):
    protein_chains = [cid for cid, info in chains.items() if info.get('type') == 'protein' and info.get('sequence')]
    rna_chains = [cid for cid, info in chains.items() if info.get('type') == 'rna' and info.get('sequence')]
    if protein_chains and rna_chains:
      candidates.append((pdb_id, protein_chains, rna_chains))

  if not candidates or num_negatives <= 0:
    return []
  if max_attempts is None:
    max_attempts = 20 * num_negatives

  negative_pairs = []
  seen = set()
  for _ in range(max_attempts):
    if len(negative_pairs) >= num_negatives:
      break
    pdb_id, protein_chains, rna_chains = random.choice(candidates)
    prot_chain = f"{pdb_id}_{random.choice(protein_chains)}"
    rna_chain = f"{pdb_id}_{random.choice(rna_chains)}"
    if _key(prot_chain, rna_chain) in positive_keys or (prot_chain, rna_chain) in seen:
      continue
    seen.add((prot_chain, rna_chain))
    negative_pairs.append((prot_chain, rna_chain, 0))

  return negative_pairs

rpi2241_negative_pairs = generate_valid_negatives(positive_pairs=rpi2241_positive_pairs, structure_chains=rpi_structure_chains, num_negatives=len(rpi2241_positive_pairs))

# Extra negatives: random (RNA index, peptide index) pairs drawn from the FASTA pools.
# They are stored as index pairs and resolved to sequences by the dataset.
fasta_negative_pairs = []
if rna_seqs and peptide_seqs:  # randint(0, -1) would raise on an empty pool
    for _ in range(len(rpi2241_positive_pairs)):
        rna_idx = random.randint(0, len(rna_seqs) - 1)
        pep_idx = random.randint(0, len(peptide_seqs) - 1)
        fasta_negative_pairs.append((rna_idx, pep_idx))

# rpi2241_positive_pairs is a set of strings, whose iteration order changes between
# interpreter sessions (string hash randomisation); sort so this list is reproducible.
positive_labeled = [(p, r, 1.0) for p, r in sorted(rpi2241_positive_pairs)]
negative_labeled = rpi2241_negative_pairs

all_labeled_pairs = positive_labeled + negative_labeled + fasta_negative_pairs

In [None]:
ENCODING_CACHE = {}

def cached_one_hot_encode(sequence, encoder, encoder_name):
    """Memoise ``encoder(sequence).float()`` in ENCODING_CACHE, keyed by ``(encoder_name, sequence)``."""
    cache_key = (encoder_name, sequence)
    try:
        return ENCODING_CACHE[cache_key]
    except KeyError:
        encoded = encoder(sequence).float()
        ENCODING_CACHE[cache_key] = encoded
        return encoded

In [None]:
class RNAPeptideDataset(Dataset):
  """In-memory dataset of ``(rna_tensor, pep_tensor, label_tensor)`` triples.

  Two kinds of items in ``labeled_pairs`` are understood:
    * ``(prot_chain, rna_chain, label)`` with ``"<pdb>_<chain>"`` strings. The entries are
      looked up in ``structure_sequences[pdb_id][chain_id]``, which must be dicts with a
      ``'sequence'`` key.
    * ``(rna_idx, pep_idx)`` integer pairs into ``fasta_rna_seqs`` / ``fasta_pep_seqs``.
      These are always labelled 0.
  Items that cannot be resolved (bad format, unknown chain, out-of-range index, or an
  empty sequence) are skipped, so ``len(self)`` may be smaller than ``len(labeled_pairs)``.
  """

  def __init__(self, labeled_pairs, structure_sequences, pdb_dir, fasta_rna_seqs=None, fasta_pep_seqs=None):
    self.pairs = labeled_pairs
    self.structure_sequences = structure_sequences
    self.data = []
    self.pdb_dir = pdb_dir  # stored for reference; not read by this class
    self.structure_cache = {}  # unused; kept for attribute compatibility
    self.sequence_cache = {}
    self.fasta_rna_seqs = fasta_rna_seqs if fasta_rna_seqs is not None else []
    self.fasta_pep_seqs = fasta_pep_seqs if fasta_pep_seqs is not None else []

    for item in labeled_pairs:
      if len(item) == 3 and all(isinstance(x, str) for x in item[:2]):
        prot_chain, rna_chain, label = item
        self._process_pdb_pair(prot_chain, rna_chain, label)
      elif len(item) == 2:
        rna_idx, pep_idx = item
        self._process_fasta_pair(rna_idx, pep_idx)

  @staticmethod
  def _lookup_chain(chains, chain_id):
    # PDB chain IDs are case-sensitive: try the ID as given, then its upper-case form.
    return chains.get(chain_id) or chains.get(chain_id.upper()) or ""

  def _process_pdb_pair(self, prot_chain, rna_chain, label):
    """Resolve a ``"<pdb>_<chain>"`` pair via structure_sequences and add it; skip if unresolved."""
    if "_" not in prot_chain or "_" not in rna_chain:
      return None
    pdb_id = prot_chain.split("_")[0].lower()
    pep_chain_id = prot_chain.split("_")[1]
    rna_chain_id = rna_chain.split("_")[1]

    chains = self.structure_sequences.get(pdb_id, {})
    pep_entry = self._lookup_chain(chains, pep_chain_id)
    rna_entry = self._lookup_chain(chains, rna_chain_id)
    if not pep_entry or not rna_entry:
      return None
    self._add_to_data(rna_entry, pep_entry, label)

  def _process_fasta_pair(self, rna_idx, pep_idx):
    """Add a FASTA (rna index, peptide index) pair as a negative; skip out-of-range indices."""
    try:
      rna_seq = self.fasta_rna_seqs[rna_idx]
      pep_seq = self.fasta_pep_seqs[pep_idx]
    except IndexError:
      return None
    self._add_to_data(rna_seq, pep_seq, 0)

  @staticmethod
  def _sequence_of(entry):
    # Structure entries are {'type': ..., 'sequence': ...} dicts; FASTA entries are plain strings.
    return entry['sequence'] if isinstance(entry, dict) else entry

  def _add_to_data(self, rna_seq, pep_seq, label):
    """Encode both sequences and append the triple; empty sequences are skipped."""
    rna_str = self._sequence_of(rna_seq)
    pep_str = self._sequence_of(pep_seq)
    if not rna_str or not pep_str:
      return None
    rna_tensor = self._cached_encode(rna_str, "RNA")
    pep_tensor = self._cached_encode(pep_str, "PEP")
    label_tensor = torch.tensor([label], dtype=torch.float)
    self.data.append((rna_tensor, pep_tensor, label_tensor))

  def _cached_encode(self, sequence, seq_type):
    """One-hot encode ``sequence`` (RNA or peptide), memoised per instance."""
    cache_key = (sequence, seq_type)
    if cache_key not in self.sequence_cache:
      if seq_type == "RNA":
        tensor = one_hot_encodeRNA(sequence).float()
      else:
        tensor = one_hot_encodepeptide(sequence).float()
      self.sequence_cache[cache_key] = tensor
    return self.sequence_cache[cache_key]

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    # Out-of-range indices yield None (filtered later by collate_fn) rather than raising.
    try:
      return self.data[idx]
    except IndexError:
      return None

def collate_fn(batch):
  """Collate ``(rna, pep, label)`` items into zero-padded batch tensors.

  None items (e.g. from out-of-range dataset indices) are dropped. If nothing remains,
  returns None so that training/validation loops checking ``batch is None`` skip it.
  Otherwise returns ``(rna_padded, pep_padded, labels)``:
    * sequences are padded along their first axis to the longest in the batch
      (``pad_sequence(batch_first=True)``);
    * labels are stacked along a new leading batch axis.
  Items whose last dimension is not 4 (RNA) / 20 (peptide) are re-one-hot-encoded from
  their argmax so all items share a channel count.
  """
  batch = [item for item in batch if item is not None]
  if not batch:
    return None

  rna_seqs, pep_seqs, labels = [], [], []
  for r, p, l in batch:
    if r.shape[-1] != 4:
      r = F.one_hot(r.argmax(-1), num_classes=4).float()  # Repair RNA
    if p.shape[-1] != 20:
      p = F.one_hot(p.argmax(-1), num_classes=20).float()  # Repair PEP
    rna_seqs.append(r)
    pep_seqs.append(p)
    labels.append(l)

  rna_padded = pad_sequence(rna_seqs, batch_first=True)
  pep_padded = pad_sequence(pep_seqs, batch_first=True)

  assert len(rna_padded) == len(pep_padded) == len(labels)
  return rna_padded, pep_padded, torch.stack(labels)


from torch.utils.data import ConcatDataset

# Split the *labeled pairs* before building datasets, then build one dataset per split.
# PDB-derived pairs are grouped by PDB entry so all chains of a structure land in the same
# split (no train/test leakage). FASTA negatives are split independently. Split membership
# therefore never depends on which pairs a dataset manages to resolve.
SPLIT_PDB_DIR = "/content/drive/MyDrive/Deep Learning-RNA-Peptide-Interaction/pdb_files"
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8
VAL_FRACTION = 0.15  # test gets the remainder
BATCH_SIZE = 32

split_rng = np.random.default_rng(SPLIT_SEED)


def split_three_ways(items, rng):
    """Return (train, val, test) lists from a seeded permutation of ``items``."""
    order = rng.permutation(len(items))
    shuffled = [items[i] for i in order]
    n_train = int(TRAIN_FRACTION * len(shuffled))
    n_val = int(VAL_FRACTION * len(shuffled))
    return shuffled[:n_train], shuffled[n_train:n_train + n_val], shuffled[n_train + n_val:]


pdb_pairs = [pair for pair in all_labeled_pairs if isinstance(pair[0], str) and len(pair[0].split("_")) == 2]
fasta_pairs = [pair for pair in all_labeled_pairs if isinstance(pair[0], int) and isinstance(pair[1], int)]

pdb_to_pairs = defaultdict(list)
for pair in pdb_pairs:
    pdb_to_pairs[pair[0].split("_")[0].lower()].append(pair)

# sorted() first so the seeded permutation does not depend on dict insertion order.
train_pdbs, val_pdbs, test_pdbs = split_three_ways(sorted(pdb_to_pairs), split_rng)
train_fasta, val_fasta, test_fasta = split_three_ways(fasta_pairs, split_rng)


def pairs_for(pdb_ids):
    """All PDB-derived labeled pairs whose protein chain belongs to one of ``pdb_ids``."""
    return [pair for pdb in pdb_ids for pair in pdb_to_pairs[pdb]]


def build_split_dataset(pairs):
    """RNAPeptideDataset over ``pairs`` using the shared structure and FASTA pools."""
    return RNAPeptideDataset(labeled_pairs=pairs, structure_sequences=rpi_structure_chains, pdb_dir=SPLIT_PDB_DIR, fasta_rna_seqs=rna_seqs, fasta_pep_seqs=peptide_seqs)


train_dataset = build_split_dataset(pairs_for(train_pdbs) + train_fasta)
val_dataset = build_split_dataset(pairs_for(val_pdbs) + val_fasta)
test_dataset = build_split_dataset(pairs_for(test_pdbs) + test_fasta)
# One dataset spanning all splits, for code that wants the whole set.
full_dataset = ConcatDataset([train_dataset, val_dataset, test_dataset])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [None]:
# Define CNN model
class CNN(nn.Module):
  """Two-branch 1-D CNN that maps an (RNA, peptide) pair to a single logit.

  Each branch applies three Conv1d(kernel 3, padding 1) + ReLU layers. It then mean-pools
  over the sequence's real positions only: trailing all-zero positions (zero padding from
  ``pad_sequence``) are masked out after every layer and excluded from the mean. A
  sample's output therefore does not depend on how much it was padded within a batch.
  The two 64-d summaries are concatenated and passed through fc1 -> ReLU -> fc2.

  Accepted inputs per branch (4 channels for RNA, 20 for peptide):
    * ``(batch, seq_len, channels)`` one-hot tensors,
    * a single ``(seq_len, channels)`` tensor (a batch axis is added), or
    * a 1-D tensor of integer indices, which is one-hot encoded here.
  Returns raw logits of shape ``(batch, 1)``.
  """

  def __init__(self):
    super(CNN, self).__init__()
    self.rna_conv1 = nn.Conv1d(4, 16, kernel_size=3, padding=1)
    self.rna_conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
    self.rna_conv3 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
    self.rna_pool = nn.AdaptiveAvgPool1d(1)  # kept for attribute compatibility; forward uses masked mean
    self.pep_conv1 = nn.Conv1d(20, 16, kernel_size=3, padding=1)
    self.pep_conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
    self.pep_conv3 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
    self.pep_pool = nn.AdaptiveAvgPool1d(1)  # kept for attribute compatibility; forward uses masked mean
    self.fc1 = nn.Linear(128, 64)
    self.fc2 = nn.Linear(64, 1)

  @staticmethod
  def _length_mask(x):
    """Boolean ``(batch, 1, seq_len)`` mask that is True up to each row's last non-zero position.

    ``x`` is ``(batch, channels, seq_len)``. Only *trailing* all-zero positions count as
    padding; interior all-zero positions (e.g. unknown residues) stay unmasked.
    """
    nonzero = x.abs().sum(dim=1) > 0                                # (batch, seq_len)
    positions = torch.arange(1, x.size(2) + 1, device=x.device)    # 1-based positions
    lengths = (nonzero.long() * positions).max(dim=1).values        # last real position (0 if none)
    return (positions.unsqueeze(0) <= lengths.unsqueeze(1)).unsqueeze(1)

  def _encode_branch(self, x, convs):
    """Run the conv stack with padding re-zeroed after each layer, then masked mean-pool to (batch, 64)."""
    mask = self._length_mask(x)
    for conv in convs:
      # Zeroing padded positions keeps ReLU(bias) at padding from leaking into real positions.
      x = F.relu(conv(x)) * mask
    # Masked mean; clamp avoids 0/0 for an all-padding row.
    return x.sum(dim=2) / mask.sum(dim=2).clamp(min=1)

  def forward(self, rna_input, pep_input):
    # 1-D inputs are treated as integer token indices.
    if rna_input.dim() == 1:
      rna_input = F.one_hot(rna_input.long(), num_classes=4).float()
    if pep_input.dim() == 1:
      pep_input = F.one_hot(pep_input.long(), num_classes=20).float()

    # A single unbatched sequence gets a batch axis.
    if rna_input.dim() == 2:
      rna_input = rna_input.unsqueeze(0)
    if pep_input.dim() == 2:
      pep_input = pep_input.unsqueeze(0)

    # Permute to [batch, channels, seq_len]
    rna = rna_input.permute(0, 2, 1)
    pep = pep_input.permute(0, 2, 1)

    # Guarantee at least kernel-width positions.
    if rna.size(2) < 3:
      rna = F.pad(rna, (0, 3 - rna.size(2)))
    if pep.size(2) < 3:
      pep = F.pad(pep, (0, 3 - pep.size(2)))

    rna = self._encode_branch(rna, (self.rna_conv1, self.rna_conv2, self.rna_conv3))
    pep = self._encode_branch(pep, (self.pep_conv1, self.pep_conv2, self.pep_conv3))

    combined = torch.cat((rna, pep), dim=1)
    x = F.relu(self.fc1(combined))
    out = self.fc2(x)
    return out


In [None]:
# Initialize model, loss, optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN()
LEARNING_RATE = 0.0005


def _count_labels(dataset):
  """Return (n_positive, n_negative) over the items ``dataset`` actually yields (None items skipped)."""
  n_pos = n_neg = 0
  for idx in range(len(dataset)):
    item = dataset[idx]
    if item is None:
      continue
    if item[2].item() > 0.5:
      n_pos += 1
    else:
      n_neg += 1
  return n_pos, n_neg


# Weight positives by the negative:positive ratio of the training items that really exist.
# The requested pair counts include pairs the dataset could not resolve.
total_positives, total_negatives = _count_labels(train_dataset)
# Created on `device` so the loss works when the model and batches are on GPU.
pos_weight_value = torch.tensor([total_negatives / max(total_positives, 1)], device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_value)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Move model (and the loss's pos_weight buffer) to device (GPU if available)
NUM_EPOCHS = 30
LOG_EVERY = 10

model.to(device)
criterion.to(device)


def _prepare_batch(batch):
  """Move a collated batch to ``device``; return None for missing or empty batches."""
  if batch is None:
    return None
  rna_batch, pep_batch, label_batch = batch
  if rna_batch.size(0) == 0 or pep_batch.size(0) == 0 or label_batch.size(0) == 0:
    return None
  return rna_batch.to(device), pep_batch.to(device), label_batch.float().view(-1, 1).to(device)


for epoch in range(NUM_EPOCHS):
  log_this_epoch = (epoch + 1) % LOG_EVERY == 0

  model.train()
  train_losses = []
  for batch in train_loader:
    prepared = _prepare_batch(batch)
    if prepared is None:
      continue
    rna_batch, pep_batch, label_batch = prepared

    scores = model(rna_batch, pep_batch)
    loss = criterion(scores, label_batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    train_losses.append(loss.item())

  if train_losses and log_this_epoch:
    avg_loss = sum(train_losses) / len(train_losses)
    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}], Loss: {avg_loss:.4f}")

  # Validation
  model.eval()
  val_losses = []
  with torch.no_grad():
    for batch in val_loader:
      prepared = _prepare_batch(batch)
      if prepared is None:
        continue
      rna_batch, pep_batch, label_batch = prepared
      val_scores = model(rna_batch, pep_batch)
      val_losses.append(criterion(val_scores, label_batch).item())

  if val_losses and log_this_epoch:
    avg_val_loss = sum(val_losses) / len(val_losses)
    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}], Val Loss: {avg_val_loss:.4f}")

Epoch [10/30], Loss: 0.8605
Epoch [10/30], Val Loss: 0.8693
Epoch [20/30], Loss: 0.8355
Epoch [20/30], Val Loss: 0.8886
Epoch [30/30], Loss: 0.8243
Epoch [30/30], Val Loss: 0.8835


In [None]:
MODEL_PATH = "/content/drive/MyDrive/Deep Learning-RNA-Peptide-Interaction/models/cnn.pt"
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)