In [22]:
import torch
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch import optim
from torch import nn
from torch.utils.data import DataLoader, random_split, Dataset
from tqdm import tqdm
from itertools import islice
from itertools import chain
import numpy as np
import random
import os
import glob
from torch.nn.utils.rnn import pad_sequence
from concurrent.futures import ProcessPoolExecutor
from Bio import SeqIO
from Bio.PDB import PDBList, PDBParser, PPBuilder
from Bio.Seq import Seq
from Bio.PDB.Polypeptide import is_aa
from Bio.SeqUtils import seq1
from collections import defaultdict

In [16]:
def parse_single_fasta_file(path):
    """Read a FASTA file and return all of its sequence text as one string.

    Every line that does not start with ">" is stripped of surrounding
    whitespace and appended to the result, so a multi-record file yields the
    concatenation of all records.

    Args:
        path: Path of the FASTA file to read.

    Returns:
        The concatenated sequence text (empty string if there is none).
    """
    pieces = []
    with open(path, 'r') as handle:
        for raw_line in handle:
            # Header lines carry no sequence data.
            if raw_line.startswith(">"):
                continue
            pieces.append(raw_line.strip())
    return ''.join(pieces)

def parse_fasta_file(filepath):
    """Return the sequences of every record in a FASTA file as strings.

    Parsing is best-effort: if anything goes wrong while opening or iterating
    the file, the whole file is skipped and an empty list is returned.

    Args:
        filepath: Path of the FASTA file handed to ``SeqIO.parse``.

    Returns:
        A list with one string per record, or ``[]`` on any error.
    """
    sequences = []
    try:
        for record in SeqIO.parse(filepath, "fasta"):
            sequences.append(str(record.seq))
    except Exception:
        # Skip malformed files entirely, discarding any partial result.
        return []
    return sequences

def get_fasta_files_from_nested_folders(root_folder, limit=None):
    """Recursively collect FASTA-like files below ``root_folder``.

    Files are matched with the glob pattern ``*.fa*`` (so ``.fa`` and
    ``.fasta``, but also any other extension that starts with ``.fa``).
    The result is sorted so that the same files are selected on every run,
    which matters when ``limit`` truncates the list.

    Args:
        root_folder: Directory to search (sub-directories included).
        limit: Maximum number of paths to return; ``None`` means no limit
            and ``0`` returns an empty list.

    Returns:
        A sorted list of matching file paths, at most ``limit`` long.
    """
    pattern = os.path.join(root_folder, "**", "*.fa*")
    # glob returns entries in arbitrary directory order; sort for reproducibility.
    files = sorted(glob.glob(pattern, recursive=True))
    if limit is None:
        return files
    return list(islice(files, limit))

def load_fasta_sequences_parallel(file_list, num_workers=4):
    """Parse many FASTA files in worker processes and flatten the results.

    Each path in ``file_list`` is passed to ``parse_fasta_file`` through a
    ``ProcessPoolExecutor``; the per-file sequence lists are concatenated in
    the order of ``file_list``.

    Args:
        file_list: Iterable of FASTA file paths.
        num_workers: Number of worker processes to start.

    Returns:
        A flat list containing the sequences of all files.
    """
    sequences = []
    with ProcessPoolExecutor(max_workers=num_workers) as pool:
        # map() yields one list of sequences per input file, in input order.
        for per_file_sequences in pool.map(parse_fasta_file, file_list):
            sequences.extend(per_file_sequences)
    return sequences

def parse_fasta_sequences(filepath):
    """Split a multi-record FASTA file into a list of sequence strings.

    Lines starting with ">" mark record boundaries; all other lines are
    stripped and joined into the current record. Records without any sequence
    text are dropped.

    Args:
        filepath: Path of the FASTA file to read.

    Returns:
        A list with one string per non-empty record, in file order.
    """
    records = []
    chunks = []

    def flush():
        # Emit the record collected so far (if it has any text) and start anew.
        joined = ''.join(chunks)
        if joined:
            records.append(joined)
        chunks.clear()

    with open(filepath, 'r') as handle:
        for raw_line in handle:
            if raw_line.startswith('>'):
                flush()
            else:
                chunks.append(raw_line.strip())
    # The last record is not followed by a header, so flush it explicitly.
    flush()
    return records

# Collect at most 200000 FASTA paths from the RNA_FASTA_Files tree and parse
# them with 8 worker processes into one flat list of sequence strings.
fasta_files = get_fasta_files_from_nested_folders("RNA_FASTA_Files", limit=200000)
rna_seqs = load_fasta_sequences_parallel(fasta_files, num_workers=8)
# One string per non-empty record of the peptide FASTA file.
# NOTE(review): rna_seqs and peptide_seqs are not referenced again in the
# visible part of this notebook — confirm they are still needed.
peptide_seqs = parse_fasta_sequences("peptideatlas.fasta")


In [17]:
def download_multiple_pdbs(pair_file, out_dir):
    """Download the PDB file of every structure named in a pair list.

    ``pair_file`` is a whitespace-separated text file with two columns
    (protein chain, RNA chain). The PDB id is taken from the part of the
    protein column before the first underscore. Header lines (first column
    equal to "protein", case-insensitive) and lines that do not have exactly
    two columns are ignored.

    Files that already exist in ``out_dir`` are not downloaded again. Each
    download is written to a temporary ``.part`` file and renamed only once it
    is complete, so an interrupted transfer never leaves a truncated ``.pdb``
    behind. Failures are reported and skipped (best-effort).

    Args:
        pair_file: Path of the pair list.
        out_dir: Directory that receives ``<pdb_id>.pdb`` (created if missing).

    Returns:
        None.
    """
    import os
    import urllib.request

    os.makedirs(out_dir, exist_ok=True)

    pdb_ids = set()
    with open(pair_file, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) != 2 or parts[0].lower() == "protein":
                continue  # Skip headers or malformed lines
            pdb_ids.add(parts[0].split("_")[0].lower())

    for pdb_id in sorted(pdb_ids):
        filepath = os.path.join(out_dir, f"{pdb_id}.pdb")
        if os.path.exists(filepath):
            continue
        url = f"https://files.rcsb.org/download/{pdb_id.upper()}.pdb"
        partial_path = filepath + ".part"
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, filepath)
        except Exception as e:
            # Keep going with the remaining ids, but say what failed and why.
            print(f"Failed to download {pdb_id} from {url}: {e}")
            if os.path.exists(partial_path):
                os.remove(partial_path)

        
#download_multiple_pdbs("RPI2241.txt", "./pdb_files")

def load_rpi2241_pairs(filepath):
    """Load (protein, RNA) identifier pairs from a two-column text file.

    Lines are split on whitespace. A line is kept only if it has exactly two
    fields and its first field is not the header word "protein"
    (case-insensitive). Both identifiers are upper-cased.

    Args:
        filepath: Path of the pair list.

    Returns:
        A set of ``(protein_id, rna_id)`` tuples in upper case.
    """
    with open(filepath, 'r') as handle:
        rows = (raw_line.strip().split() for raw_line in handle)
        return {
            (fields[0].upper(), fields[1].upper())
            for fields in rows
            # Skip headers or malformed lines
            if len(fields) == 2 and fields[0].lower() != "protein"
        }

# Upper-cased (protein, RNA) id pairs from RPI2241.txt; these are labelled 1.0
# (positive examples) further down in the notebook.
rpi2241_positive_pairs=load_rpi2241_pairs("RPI2241.txt")

In [18]:
# One-hot encoders for RNA and peptide sequences
def one_hot_encodeRNA(sequence):
    """One-hot encode an RNA sequence over the alphabet A, U, G, C.

    Characters outside that (upper-case) alphabet are left as an all-zero
    column.

    Args:
        sequence: RNA sequence string.

    Returns:
        A float tensor of shape ``(4, len(sequence))``; row order is
        A, U, G, C.
    """
    base_to_channel = {'A': 0, 'U': 1, 'G': 2, 'C': 3}
    encoded = torch.zeros(len(sequence), 4, dtype=torch.float)
    for position, base in enumerate(sequence):
        channel = base_to_channel.get(base)
        if channel is not None:
            encoded[position, channel] = 1
    # Built as (length, channels); hand back channels first.
    return encoded.permute(1, 0)
def one_hot_encodepeptide(sequence):
    """One-hot encode a peptide sequence over the 20 standard amino acids.

    The channel order is A, R, N, D, C, E, Q, G, H, I, L, K, M, F, P, S, T,
    W, Y, V. Characters outside that (upper-case) alphabet are left as an
    all-zero column.

    Args:
        sequence: Peptide sequence string in one-letter code.

    Returns:
        A float tensor of shape ``(20, len(sequence))``.
    """
    alphabet = "ARNDCEQGHILKMFPSTWYV"
    channel_of = {residue: channel for channel, residue in enumerate(alphabet)}
    encoded = torch.zeros(len(sequence), 20, dtype=torch.float)
    for position, residue in enumerate(sequence):
        channel = channel_of.get(residue)
        if channel is not None:
            encoded[position, channel] = 1
    # Built as (length, channels); hand back channels first.
    return encoded.permute(1, 0)

In [19]:
def extract_protein_sequence(pdb_file, chain_id):
    """Return the sequence of the first polypeptide built for a chain.

    Models are visited in order; the first model that contains ``chain_id``
    and for which ``PPBuilder`` builds at least one polypeptide decides the
    result.

    NOTE(review): only the first polypeptide returned by ``build_peptides`` is
    used; if the builder splits a chain into several pieces, the later pieces
    are not included — confirm this is intended.

    Args:
        pdb_file: Path of the PDB file.
        chain_id: Identifier of the chain to read.

    Returns:
        The sequence as a string, or ``None`` if no polypeptide was found.
    """
    structure = PDBParser(QUIET=True).get_structure("protein", pdb_file)
    for model in structure:
        if chain_id not in model:
            continue
        fragments = PPBuilder().build_peptides(model[chain_id])
        if fragments:
            return str(fragments[0].get_sequence())
    return None

# Lookup from PDB residue name to a one-letter nucleotide code, used by
# extract_rna_sequence. Residue names that are missing here are skipped there.
RNA_MAP = {
    "ADE": "A", "CYT": "C", "GUA": "G", "URI": "U",
    "PSU": "U", "INO": "I", "GTP": "G", "OMC": "C",
    "A": "A", "C": "C", "G": "G", "U": "U"  # One-letter residue names map to themselves
}

def extract_rna_sequence(pdb_file, chain_id):
    """Read the nucleotide sequence of one chain from a PDB file.

    Only model 0 of the structure is inspected. Residues whose hetero flag is
    not blank are skipped, as are residues whose (stripped, upper-cased) name
    is not a key of ``RNA_MAP``.

    Args:
        pdb_file: Path of the PDB file.
        chain_id: Identifier of the chain to read.

    Returns:
        The sequence as a string, or ``None`` if the chain is missing or no
        residue could be mapped.
    """
    structure = PDBParser(QUIET=True).get_structure("rna", pdb_file)

    # Use first model only
    first_model = structure[0]
    if chain_id not in first_model:
        return None

    letters = []
    for residue in first_model[chain_id]:
        hetero_flag = residue.get_id()[0]
        if hetero_flag != " ":
            # Hetero residues (non-blank flag) are not part of the sequence.
            continue
        name = residue.get_resname().strip().upper()
        if name in RNA_MAP:
            letters.append(RNA_MAP[name])

    return "".join(letters) or None

In [25]:
def define_structure_chains(pdb_files, pdb_dir):
    """Classify every chain of the given PDB files as protein or RNA.

    For the first model of each structure, standard residues (blank hetero
    flag) are counted per chain. A chain with more amino-acid than nucleotide
    residues is recorded as protein; otherwise, a chain with at least one
    nucleotide residue is recorded as RNA. Chains with neither are ignored.

    Nucleotides are recognised by both the one-letter residue names used in
    current PDB files (A, C, G, U, I) and the legacy three-letter names.
    Residue names are stripped and upper-cased before the lookup.

    Args:
        pdb_files: File names (relative to ``pdb_dir``) of the PDB files.
        pdb_dir: Directory that contains the files.

    Returns:
        A ``defaultdict`` mapping the lower-cased PDB id to
        ``{'protein': [...], 'rna': [...]}``, where each entry is
        ``"<pdb_id>_<CHAIN_ID>"`` (lower-case id, upper-cased chain id).
        Files that fail to parse are reported and skipped.
    """
    aa_residues = {"ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY",
                   "HIS", "ILE", "LEU", "LYS", "MET", "PHE", "PRO", "SER",
                   "THR", "TRP", "TYR", "VAL"}
    rna_residues = {"A", "C", "G", "U", "I",
                    "ADE", "CYT", "GUA", "URI", "PSU", "INO"}

    structure_chains = defaultdict(lambda: {'protein': [], 'rna': []})
    parser = PDBParser(QUIET=True)

    for pdb_file in pdb_files:
        pdb_id = os.path.splitext(os.path.basename(pdb_file))[0].lower()
        pdb_path = os.path.join(pdb_dir, pdb_file)
        try:
            structure = parser.get_structure(pdb_id, pdb_path)
            model = next(structure.get_models())
            for chain in model.get_chains():
                chain_id = chain.id.upper()

                aa_count = 0
                rna_count = 0
                for residue in chain.get_residues():
                    if residue.get_id()[0] != ' ':
                        continue  # hetero residues and waters do not count
                    resname = residue.get_resname().strip().upper()
                    if resname in aa_residues:
                        aa_count += 1
                    elif resname in rna_residues:
                        rna_count += 1

                if aa_count > rna_count:
                    structure_chains[pdb_id]['protein'].append(f"{pdb_id}_{chain_id}")
                elif rna_count > 0:
                    structure_chains[pdb_id]['rna'].append(f"{pdb_id}_{chain_id}")
        except Exception as e:
            print(f"Error processing {pdb_file}: {e}")

    return structure_chains
    
# Classify the chains of every *.pdb file found in ./pdb_files.
# pdb_files holds bare file names; define_structure_chains joins them with the
# directory passed as its second argument.
pdb_files = [f for f in os.listdir("./pdb_files") if f.endswith(".pdb")]
rpi_structure_chains = define_structure_chains(pdb_files,"./pdb_files")
    
def generate_structure_based_negatives(positive_pairs, structure_chains, num_negatives=None, seed=42):
    """Build negative (protein, RNA) pairs from chains of the same structure.

    Every protein chain of a structure is paired with every RNA chain of that
    structure; pairs that are known positives are excluded. The comparison
    with ``positive_pairs`` ignores letter case, because the two id sources in
    this notebook use different casing for the PDB id.

    Args:
        positive_pairs: Iterable of ``(protein_id, rna_id)`` known positives.
        structure_chains: Mapping ``pdb_id -> {'protein': [...], 'rna': [...]}``.
        num_negatives: If given (non-negative) and smaller than the number of
            candidates, a random subset of exactly this size is returned;
            ``None`` returns every candidate.
        seed: Seed for the global ``random`` module; it is always applied, so
            later ``random`` calls in the notebook are seeded as well.

    Returns:
        A set of ``(protein_id, rna_id)`` tuples, spelled exactly as they
        appear in ``structure_chains``.
    """
    random.seed(seed)
    known_positives = {(prot.upper(), rna.upper()) for prot, rna in positive_pairs}

    negative_pairs = set()
    for chains in structure_chains.values():
        for prot in chains.get("protein", []):
            for rna in chains.get("rna", []):
                if (prot.upper(), rna.upper()) not in known_positives:
                    negative_pairs.add((prot, rna))

    if num_negatives is not None and len(negative_pairs) > num_negatives:
        # Sort first: set iteration order is not stable between runs, and the
        # sample must depend only on the seed.
        return set(random.sample(sorted(negative_pairs), num_negatives))
    return negative_pairs

# Negative pairs: protein/RNA chains of the same structure that are not listed
# as positives. The requested count equals the number of positive pairs.
rpi2241_negative_pairs = generate_structure_based_negatives(positive_pairs=rpi2241_positive_pairs, structure_chains=rpi_structure_chains, num_negatives=len(rpi2241_positive_pairs))

# Attach labels: 1.0 for positive pairs, 0.0 for negative pairs
positive_labeled = [(p, r, 1.0) for p, r in rpi2241_positive_pairs]
negative_labeled = [(p, r, 0.0) for p, r in rpi2241_negative_pairs]

# Combine and shuffle in place (uses the global `random` module state)
all_labeled_pairs = positive_labeled + negative_labeled
random.shuffle(all_labeled_pairs)

In [26]:
# Residue names that denote RNA nucleotides: the one-letter names used in
# current PDB files plus the legacy three-letter spellings.
_NUCLEOTIDE_LETTERS = {
    "A": "A", "C": "C", "G": "G", "U": "U",
    "ADE": "A", "CYT": "C", "GUA": "G", "URI": "U",
}


def get_chain_sequence(chain):
    """Return the one-letter sequence of a protein or RNA chain.

    Residues with a non-blank hetero flag (heteroatoms, water) are skipped.
    Nucleotide residues are translated through ``_NUCLEOTIDE_LETTERS``; every
    other residue name is converted with ``seq1``, which only understands
    three-letter amino-acid codes. Without the nucleotide lookup an RNA chain
    would not yield A/U/G/C letters at all.

    Args:
        chain: Iterable of residue objects that provide ``get_id()`` and
            ``get_resname()``.

    Returns:
        The sequence as a string (empty if no residue could be converted).
    """
    letters = []
    for residue in chain:
        if residue.get_id()[0] != " ":  # Skip heteroatoms/water
            continue
        # Some parsers pad short residue names with spaces; normalise first.
        resname = residue.get_resname().strip().upper()
        nucleotide = _NUCLEOTIDE_LETTERS.get(resname)
        if nucleotide is not None:
            letters.append(nucleotide)
            continue
        try:
            letters.append(seq1(resname))  # Convert 3-letter to 1-letter
        except Exception:
            continue  # Best-effort: ignore residues that cannot be converted
    return "".join(letters)

In [40]:
class RNAPeptideDataset(Dataset):
    """Dataset of labelled protein/RNA chain pairs backed by local PDB files.

    Each entry of ``labeled_pairs`` is ``(protein_chain, rna_chain, label)``
    with chain ids of the form ``"<pdb_id>_<chain>"``. Items are returned as
    ``(rna_tensor, pep_tensor, label_tensor)``, or ``None`` when the PDB file
    is missing, a sequence is empty, or processing raises an error.
    """

    def __init__(self, labeled_pairs, pdb_dir):
        """Store the pairs and the PDB directory and start with empty caches."""
        self.pairs = labeled_pairs
        self.pdb_dir = pdb_dir
        # Parsed structures, keyed by lower-case PDB id.
        self.structure_cache = {}
        # (peptide sequence, RNA sequence) tuples, keyed by "<pdb>_<pep>_<rna>".
        self.sequence_cache = {}

    def __len__(self):
        """Number of labelled pairs."""
        return len(self.pairs)

    def _cached_structure(self, pdb_id, pdb_path):
        """Parse ``pdb_path`` on first use and return the cached structure."""
        if pdb_id not in self.structure_cache:
            self.structure_cache[pdb_id] = PDBParser(QUIET=True).get_structure(pdb_id, pdb_path)
        return self.structure_cache[pdb_id]

    def __getitem__(self, idx):
        """Return the encoded pair at ``idx`` or ``None`` if it is unusable."""
        prot_chain, rna_chain, label = self.pairs[idx]
        prot_fields = prot_chain.split("_")
        pdb_id = prot_fields[0].lower()
        pep_chain_id = prot_fields[1].upper()
        rna_chain_id = rna_chain.split("_")[1].upper()

        pdb_path = os.path.join(self.pdb_dir, f"{pdb_id}.pdb")
        if not os.path.exists(pdb_path):
            return None

        try:
            first_model = next(self._cached_structure(pdb_id, pdb_path).get_models())

            cache_key = f"{pdb_id}_{pep_chain_id}_{rna_chain_id}"
            if cache_key in self.sequence_cache:
                pep_seq, rna_seq = self.sequence_cache[cache_key]
            else:
                # Look up both chains before extracting, so a missing chain
                # fails before any sequence work is done.
                pep_residues = first_model[pep_chain_id]
                rna_residues = first_model[rna_chain_id]
                pep_seq = get_chain_sequence(pep_residues)
                rna_seq = get_chain_sequence(rna_residues)
                self.sequence_cache[cache_key] = (pep_seq, rna_seq)

            if not (pep_seq and rna_seq):
                return None

            # Convert to tensors
            rna_tensor = one_hot_encodeRNA(rna_seq).float()
            pep_tensor = one_hot_encodepeptide(pep_seq).float()
            label_tensor = torch.tensor(label, dtype=torch.float).view(1)
            return rna_tensor, pep_tensor, label_tensor

        except Exception as e:
            print(f"Error processing {pdb_id}: {str(e)}")
            return None

def collate_fn(batch):
    """Collate variable-length samples into zero-padded batch tensors.

    Samples that are ``None`` are dropped. Each remaining sample is
    ``(rna, pep, label)`` where ``rna`` and ``pep`` are channel-first tensors
    of shape ``(channels, length)``. Sequences are right-padded with zeros
    along the length (last) dimension up to the longest sequence in this
    batch, so tensors of different lengths can be stacked.

    Args:
        batch: List of samples (or ``None`` entries).

    Returns:
        ``(rna_padded, pep_padded, labels)`` with shapes
        ``(B, rna_channels, max_rna_len)``, ``(B, pep_channels, max_pep_len)``
        and the stacked labels, or ``None`` if no usable sample remains.
    """
    batch = [item for item in batch if item is not None]
    if len(batch) == 0:
        return None

    rna_seqs, pep_seqs, labels = zip(*batch)

    def pad_to_longest(tensors):
        # The sequence length is the LAST dimension; dim 0 holds the channels.
        longest = max(t.shape[-1] for t in tensors)
        return torch.stack([
            F.pad(t, (0, longest - t.shape[-1]))
            for t in tensors
        ])

    rna_padded = pad_to_longest(rna_seqs)
    pep_padded = pad_to_longest(pep_seqs)

    labels = torch.stack(labels)
    return rna_padded, pep_padded, labels
# Build the dataset and split it 80% / 15% / remainder into train, validation
# and test subsets.
full_dataset = RNAPeptideDataset(labeled_pairs=all_labeled_pairs, pdb_dir="./pdb_files")
train_size = int(0.8 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
# The test split takes whatever is left so the three sizes always sum to len(full_dataset).
test_size = len(full_dataset) - train_size - val_size
# NOTE(review): random_split is called without a generator, so the split
# depends on torch's global RNG state at the time this cell runs.
train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])

# Batches of 32; only the training loader shuffles. collate_fn drops samples
# the dataset returned as None and pads the rest.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [41]:
# Two-branch 1-D CNN for RNA/peptide pair classification
class CNN(nn.Module):
    """Two-branch convolutional classifier for an (RNA, peptide) pair.

    Each branch applies three ``Conv1d`` + ReLU layers (4 -> 16 -> 32 -> 64
    channels for RNA, 20 -> 16 -> 32 -> 64 for peptides) followed by a global
    max-pool over the sequence axis. The two 64-dimensional vectors are
    concatenated and passed through two linear layers to a single logit.
    """

    def __init__(self):
        super().__init__()
        # RNA branch
        self.rna_conv1 = nn.Conv1d(4, 16, kernel_size=3, padding=1)
        self.rna_conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
        self.rna_conv3 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
        # Peptide branch
        self.pep_conv1 = nn.Conv1d(20, 16, kernel_size=3, padding=1)
        self.pep_conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
        self.pep_conv3 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
        # Classifier head on the concatenated 64 + 64 features
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 1)

    @staticmethod
    def _encode(sequence_batch, conv_layers):
        """Run one branch: permute to channels-first, convolve, max-pool."""
        features = sequence_batch.permute(0, 2, 1)
        for conv in conv_layers:
            features = F.relu(conv(features))
        return F.adaptive_max_pool1d(features, 1).squeeze(2)

    def forward(self, rna_input, pep_input):
        """Return one raw logit per pair.

        Args:
            rna_input: Tensor of shape ``(batch, rna_length, 4)``.
            pep_input: Tensor of shape ``(batch, pep_length, 20)``.

        Returns:
            Tensor of shape ``(batch, 1)`` (no sigmoid applied).
        """
        rna_features = self._encode(
            rna_input, (self.rna_conv1, self.rna_conv2, self.rna_conv3)
        )
        pep_features = self._encode(
            pep_input, (self.pep_conv1, self.pep_conv2, self.pep_conv3)
        )
        hidden = F.relu(self.fc1(torch.cat((rna_features, pep_features), dim=1)))
        return self.fc2(hidden)
    

In [42]:
# Initialize model, loss, optimizer
model = CNN()
# Weight on the positive class = (#negative pairs) / (#positive pairs).
pos_weight = len(negative_labeled) / len(positive_labeled)
# The model outputs raw logits, so the loss applies the sigmoid itself.
# NOTE(review): this pos_weight tensor is created on the default device; it
# must be on the same device as the logits when the loss is evaluated.
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))
optimizer = optim.Adam(model.parameters(), lr=0.0005)

In [None]:
# Move model and loss to device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# The loss module holds its pos_weight tensor itself; it has to be moved as
# well, otherwise evaluating the loss on GPU logits fails with a device mismatch.
criterion = criterion.to(device)


for epoch in range(30):
    # ---- Training pass ----
    model.train()
    train_losses = []

    for batch in train_loader:
        # collate_fn returns None when every sample of the batch was unusable.
        if batch is None or len(batch[0]) == 0:
            continue

        rna_batch, pep_batch, label_batch = batch
        rna_batch = rna_batch.to(device)
        pep_batch = pep_batch.to(device)
        label_batch = label_batch.float().view(-1,1).to(device)

        # The model's forward() permutes (0, 2, 1) again before its
        # convolutions, so swap the last two axes here to compensate.
        rna_batch = rna_batch.permute(0, 2, 1)
        pep_batch = pep_batch.permute(0, 2, 1)

        scores = model(rna_batch, pep_batch)
        loss = criterion(scores, label_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

    # Report the mean batch loss every 10th epoch.
    if (epoch + 1) % 10 == 0 and len(train_losses) > 0:
        avg_loss = sum(train_losses) / len(train_losses)
        print(f"Epoch [{epoch + 1}/30], Loss: {avg_loss:.4f}")

    # ---- Validation pass (no gradient tracking, no parameter updates) ----
    model.eval()
    val_losses = []

    with torch.no_grad():
        for batch in val_loader:
            if batch is None or len(batch[0]) == 0:
                continue

            rna_batch, pep_batch, label_batch = batch
            rna_batch = rna_batch.to(device)
            pep_batch = pep_batch.to(device)
            label_batch = label_batch.float().view(-1,1).to(device)

            rna_batch = rna_batch.permute(0, 2, 1)
            pep_batch = pep_batch.permute(0, 2, 1)

            val_scores = model(rna_batch, pep_batch)
            val_loss = criterion(val_scores, label_batch)
            val_losses.append(val_loss.item())

    if (epoch + 1) % 10 == 0 and len(val_losses) > 0:
        avg_val_loss = sum(val_losses) / len(val_losses)
        print(f"Epoch [{epoch + 1}/30], Val Loss: {avg_val_loss:.4f}")


In [None]:
# Create the output directory first: torch.save does not create missing
# parent folders and would fail on a fresh checkout.
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/cnn.pt")