In [1]:
import pickle

# Load the pickle file
def load_pkl(file_path):
    """Open ``file_path`` in binary mode and return the single object pickled in it.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file; only use
    this on files produced by a trusted pipeline.
    """
    with open(file_path, mode="rb") as handle:
        return pickle.load(handle)


import os
import numpy as np
import torch

def get_sequence_embeddings(test_unique_sequences, embeddings_paths, max_len=None):
    """
    Map each sequence to a fixed-size pooled embedding loaded from disk.

    The per-residue embedding for the sequence at position ``idx`` of
    ``test_unique_sequences`` is read from
    ``os.path.join(embeddings_paths, f"seq_{idx}.npy")``. ``idx`` is the 0-based
    position in the FULL input list, so file alignment is kept even when some
    sequences are skipped by ``max_len``.

    The pooled embedding is the concatenation, along the last (feature) axis, of:
      - the max over the sequence-length axis, and
      - the mean over the sequence-length axis.
    The sequence-length axis is taken to be the second-to-last axis. An array of
    shape (L, D) therefore yields a tensor of shape (2*D,), and an array of shape
    (1, L, D) yields (1, 2*D).

    Parameters:
        test_unique_sequences (list): Sequences, in the order used to name the
            embedding files.
        embeddings_paths (str): Directory containing the ``seq_{idx}.npy`` files.
        max_len (int or None): If given, sequences with ``len(seq) > max_len`` are
            skipped and their file is not read (e.g. ``max_len=1022``).
            Default None applies no length filter.

    Returns:
        dict: Maps each kept sequence to its pooled embedding (torch.Tensor). If a
        sequence appears more than once, the last occurrence wins.

    Raises:
        FileNotFoundError: if the embedding file for a kept sequence is missing.
    """
    sequence_embeddings = {}

    for idx, seq in enumerate(test_unique_sequences):
        # Skip over-long sequences; `continue` keeps idx advancing, so later
        # sequences still map to their own seq_{idx}.npy file.
        if max_len is not None and len(seq) > max_len:
            continue

        embedding_file = os.path.join(embeddings_paths, f"seq_{idx}.npy")
        embedding = torch.from_numpy(np.load(embedding_file))

        # Pool over the sequence-length axis (second-to-last). For an (L, D) array
        # this is exactly dim=0. For a (1, L, D) array, dim=0 would reduce the
        # size-1 batch axis and leave the embedding unpooled.
        max_pooled = embedding.max(dim=-2).values
        avg_pooled = embedding.mean(dim=-2)

        # Feature-wise concatenation: [max | mean].
        sequence_embeddings[seq] = torch.cat([max_pooled, avg_pooled], dim=-1)

    return sequence_embeddings


In [2]:
def create_output_full(test_negatives, test_positives, sequence_embeddings_test_dict):
    """
    Build labelled records of embedding pairs for (target, ligand) sequence pairs.

    All negative pairs are emitted first, then all positive pairs, each group in
    input order. A pair is kept only when BOTH of its sequences are keys of
    ``sequence_embeddings_test_dict``; other pairs are silently dropped.

    Parameters:
        test_negatives (iterable): (seq_target, seq_ligand) pairs for negative instances.
        test_positives (iterable): (seq_target, seq_ligand) pairs for positive instances.
        sequence_embeddings_test_dict (dict): Maps sequences to their embeddings.

    Returns:
        list[dict]: One single-key dict per kept pair:
            negative pairs -> {'negative': [embedding_target, embedding_ligand]}
            positive pairs -> {'true': [embedding_target, embedding_ligand]}
    """
    embeddings = sequence_embeddings_test_dict
    output_full = []

    # The label key differs per group; the tuple order keeps negatives first.
    for label, pairs in (('negative', test_negatives), ('true', test_positives)):
        for seq_target, seq_ligand in pairs:
            if seq_target in embeddings and seq_ligand in embeddings:
                output_full.append({label: [embeddings[seq_target], embeddings[seq_ligand]]})

    return output_full



In [4]:
# All-types TEST split: pooled embeddings for every labelled pair -> classification_test_RNAFM.p
file_path_unique_test = 'db_unfiltered/elements_uniqueclassification_aug_types_test.pkl'
file_path_test_negatives = 'db_unfiltered/non_interacting_pairsclassification_aug_types_test.pkl'
file_path_test_positives = 'db_unfiltered/interacting_pairsclassification_aug_types_test.pkl'
embeddings_paths = 'test_types/representations'
out_path = 'classification_test_RNAFM.p'

# Inputs: unique sequences (list position selects seq_{idx}.npy) and negative/positive pairs.
test_unique_sequences = load_pkl(file_path_unique_test)
test_negatives = load_pkl(file_path_test_negatives)
test_positives = load_pkl(file_path_test_positives)

sequence_embeddings_test_dict = get_sequence_embeddings(test_unique_sequences, embeddings_paths)
output_full = create_output_full(test_negatives, test_positives, sequence_embeddings_test_dict)

with open(out_path, 'wb') as f:
    pickle.dump(output_full, f)

In [5]:
# Number of pair records in output_full (pairs whose two sequences both had an embedding).
len(output_full)

193308

In [6]:
# All-types TRAIN split: pooled embeddings for every labelled pair -> classification_train_RNAFM.p
file_path_unique_train = 'db_unfiltered/elements_uniqueclassification_aug_types_training.pkl'
file_path_train_negatives = 'db_unfiltered/non_interacting_pairsclassification_aug_types_training.pkl'
file_path_train_positives = 'db_unfiltered/interacting_pairsclassification_aug_types_training.pkl'
embeddings_paths = 'train_types/representations'
out_path = 'classification_train_RNAFM.p'

# Inputs: list of unique sequences (position selects seq_{idx}.npy) and negative/positive pairs.
train_unique_sequences = load_pkl(file_path_unique_train)
train_negatives = load_pkl(file_path_train_negatives)
train_positives = load_pkl(file_path_train_positives)

sequence_embeddings_train_dict = get_sequence_embeddings(train_unique_sequences, embeddings_paths)
output_full = create_output_full(train_negatives, train_positives, sequence_embeddings_train_dict)

with open(out_path, 'wb') as f:
    pickle.dump(output_full, f)

In [7]:
# Number of pair records in output_full (pairs whose two sequences both had an embedding).
len(output_full)

1622240

In [4]:
# miRNA-lncRNA TRAIN split -> classification_mirna_lncrna_train_RNAFM.p
file_path_unique_train = 'db_unfiltered/elements_uniqueclassification_mirna-lncrna_aug_types_training.pkl'
file_path_train_negatives = 'db_unfiltered/non_interacting_pairsclassification_mirna-lncrna_aug_types_training.pkl'
file_path_train_positives = 'db_unfiltered/interacting_pairsclassification_mirna-lncrna_aug_types_training.pkl'
embeddings_paths = 'mirna-lncrna_train/representations'
out_path = 'classification_mirna_lncrna_train_RNAFM.p'

# Inputs: list of unique sequences (position selects seq_{idx}.npy) and negative/positive pairs.
train_unique_sequences = load_pkl(file_path_unique_train)
train_negatives = load_pkl(file_path_train_negatives)
train_positives = load_pkl(file_path_train_positives)

sequence_embeddings_train_dict = get_sequence_embeddings(train_unique_sequences, embeddings_paths)
output_full = create_output_full(train_negatives, train_positives, sequence_embeddings_train_dict)

with open(out_path, 'wb') as f:
    pickle.dump(output_full, f)

In [7]:
# miRNA-lncRNA TEST split -> classification_mirna_lncrna_test_RNAFM.p
file_path_unique_test = 'db_unfiltered/elements_uniqueclassification_mirna-lncrna_aug_types_test.pkl'
file_path_test_negatives = 'db_unfiltered/non_interacting_pairsclassification_mirna-lncrna_aug_types_test.pkl'
file_path_test_positives = 'db_unfiltered/interacting_pairsclassification_mirna-lncrna_aug_types_test.pkl'
embeddings_paths = 'mirna-lncrna_test/representations'
out_path = 'classification_mirna_lncrna_test_RNAFM.p'

# Inputs: unique sequences (list position selects seq_{idx}.npy) and negative/positive pairs.
test_unique_sequences = load_pkl(file_path_unique_test)
test_negatives = load_pkl(file_path_test_negatives)
test_positives = load_pkl(file_path_test_positives)

sequence_embeddings_test_dict = get_sequence_embeddings(test_unique_sequences, embeddings_paths)
output_full = create_output_full(test_negatives, test_positives, sequence_embeddings_test_dict)

with open(out_path, 'wb') as f:
    pickle.dump(output_full, f)

In [5]:
# miRNA-miRNA TRAIN split -> classification_mirna_mirna_train_RNAFM.p
file_path_unique_train = 'db_unfiltered/elements_uniqueclassification_mirna-mirna_aug_types_training.pkl'
file_path_train_negatives = 'db_unfiltered/non_interacting_pairsclassification_mirna-mirna_aug_types_training.pkl'
file_path_train_positives = 'db_unfiltered/interacting_pairsclassification_mirna-mirna_aug_types_training.pkl'
embeddings_paths = 'mirna-mirna_train/representations'
out_path = 'classification_mirna_mirna_train_RNAFM.p'

# Inputs: list of unique sequences (position selects seq_{idx}.npy) and negative/positive pairs.
train_unique_sequences = load_pkl(file_path_unique_train)
train_negatives = load_pkl(file_path_train_negatives)
train_positives = load_pkl(file_path_train_positives)

sequence_embeddings_train_dict = get_sequence_embeddings(train_unique_sequences, embeddings_paths)
output_full = create_output_full(train_negatives, train_positives, sequence_embeddings_train_dict)

with open(out_path, 'wb') as f:
    pickle.dump(output_full, f)

In [8]:
# miRNA-miRNA TEST split -> classification_mirna_mirna_test_RNAFM.p
file_path_unique_test = 'db_unfiltered/elements_uniqueclassification_mirna-mirna_aug_types_test.pkl'
file_path_test_negatives = 'db_unfiltered/non_interacting_pairsclassification_mirna-mirna_aug_types_test.pkl'
file_path_test_positives = 'db_unfiltered/interacting_pairsclassification_mirna-mirna_aug_types_test.pkl'
embeddings_paths = 'mirna-mirna_test/representations'
out_path = 'classification_mirna_mirna_test_RNAFM.p'

# Inputs: unique sequences (list position selects seq_{idx}.npy) and negative/positive pairs.
test_unique_sequences = load_pkl(file_path_unique_test)
test_negatives = load_pkl(file_path_test_negatives)
test_positives = load_pkl(file_path_test_positives)

sequence_embeddings_test_dict = get_sequence_embeddings(test_unique_sequences, embeddings_paths)
output_full = create_output_full(test_negatives, test_positives, sequence_embeddings_test_dict)

with open(out_path, 'wb') as f:
    pickle.dump(output_full, f)

In [6]:
# miRNA-snoRNA TRAIN split -> classification_mirna_snorna_train_RNAFM.p
file_path_unique_train = 'db_unfiltered/elements_uniqueclassification_mirna-snorna_aug_types_training.pkl'
file_path_train_negatives = 'db_unfiltered/non_interacting_pairsclassification_mirna-snorna_aug_types_training.pkl'
file_path_train_positives = 'db_unfiltered/interacting_pairsclassification_mirna-snorna_aug_types_training.pkl'
embeddings_paths = 'mirna-snorna_train/representations'
out_path = 'classification_mirna_snorna_train_RNAFM.p'

# Inputs: list of unique sequences (position selects seq_{idx}.npy) and negative/positive pairs.
train_unique_sequences = load_pkl(file_path_unique_train)
train_negatives = load_pkl(file_path_train_negatives)
train_positives = load_pkl(file_path_train_positives)

sequence_embeddings_train_dict = get_sequence_embeddings(train_unique_sequences, embeddings_paths)
output_full = create_output_full(train_negatives, train_positives, sequence_embeddings_train_dict)

with open(out_path, 'wb') as f:
    pickle.dump(output_full, f)

In [9]:
# miRNA-snoRNA TEST split -> classification_mirna_snorna_test_RNAFM.p
file_path_unique_test = 'db_unfiltered/elements_uniqueclassification_mirna-snorna_aug_types_test.pkl'
file_path_test_negatives = 'db_unfiltered/non_interacting_pairsclassification_mirna-snorna_aug_types_test.pkl'
file_path_test_positives = 'db_unfiltered/interacting_pairsclassification_mirna-snorna_aug_types_test.pkl'
embeddings_paths = 'mirna-snorna_test/representations'
out_path = 'classification_mirna_snorna_test_RNAFM.p'

# Inputs: unique sequences (list position selects seq_{idx}.npy) and negative/positive pairs.
test_unique_sequences = load_pkl(file_path_unique_test)
test_negatives = load_pkl(file_path_test_negatives)
test_positives = load_pkl(file_path_test_positives)

sequence_embeddings_test_dict = get_sequence_embeddings(test_unique_sequences, embeddings_paths)
output_full = create_output_full(test_negatives, test_positives, sequence_embeddings_test_dict)

with open(out_path, 'wb') as f:
    pickle.dump(output_full, f)