In [1]:
import pandas as pd 
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, random_split
from tqdm.auto import tqdm
from sklearn.metrics import precision_recall_fscore_support
from src.utils import inference, visualize_predictions
from src.data import RNADataset, RNADataset_old
from src.checkpoint import load_checkpoint
from src.pipeline import predict_structure
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Silence UserWarning and RuntimeWarning for the rest of the kernel session.
# NOTE(review): these filters are process-wide, so genuine warnings raised by
# later cells are hidden as well.
import warnings
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=RuntimeWarning)

In [2]:
# Checkpoint locations for two model variants; only `relu_model` is loaded below.
# NOTE(review): `relu_model` is an absolute path on one machine — make it relative
# or configurable before sharing. `gelu_model` is not used in the cells shown here.
relu_model = "/home/sumon/workspace/git_repos/Eternal/models/ce_ep6/best.pt"
gelu_model = 'checkpoints/checkpoint.pt'
# The checkpoint is unpacked as (model, optimizer, epoch).
model, optimizer, epoch = load_checkpoint(relu_model)
# Shown as epoch + 1 — presumably the stored counter is zero-based; confirm in src.checkpoint.
print(f"Epoch: {epoch+1}")

  checkpoint = torch.load(path, map_location=device)


Loaded checkpoint from /home/sumon/workspace/git_repos/Eternal/models/ce_ep6/best.pt
Epoch: 7


In [3]:
# Load the test parquet file and collapse it to one row per family, gathering that
# family's sequences and secondary structures into Python lists.
df = pd.read_parquet("data/test.parquet")
test_df = df.groupby("family").agg({"sequence": list, "secondary_structure":list}).reset_index()

In [4]:
# Display the grouped table: one row per family with list-valued columns.
test_df

Unnamed: 0,family,sequence,secondary_structure
0,16S_rRNA,[AUUCUGGUUGAUCCUGCCAGAGGCCGCUGCUAUCCGGCUGGGACU...,[...(((((...(((.))))).((((((((((.((((((((((......
1,23S_rRNA,[GGUUAAGUUAGAAAGGGCGCACGGUGGAUGCCUUGGCACUAGGAG...,[((((((((......((((((((((.....(((..(((((((((((...
2,5S_rRNA,[GGAUACGGCCAUACUGCGCAGAAAGCACCGCUUCCCAUCCGAACA...,[(((((((((....((((((((.....((((((............)...
3,RNaseP,[GAGGAAAGUCCCGCCUCCAGAUCAAGGGAAGUCCCGCGAGGGACA...,[.....(((.(((((((((.(((((.((((((((((....)))))(...
4,SRP,[GGGGGCCCUGGUCCUCCCGCAACACUAGUUCGUGAACCUGGUCAG...,[(((((((((((((.((((((.((((..((((((.....((((......
5,group_II_intron,[AUAAAUCUAAGUGUAGUGCUUGGUGUAUUGAUUUUUUUUGGAAAG...,[................................................
6,group_I_intron,[CUCAACAUGCAAGAUUAACUAAGUGCUUAGCAGUUAGUUUUGCUA...,[(((.....((((((((((((...........)))))))))))).....
7,tRNA,[GGGCUCGUAGAUCAGCGGUAGAUCGCUUCCUUCGCAAGGAAGAGG...,[(((((((..((((.......)))).(((((.......)))))......
8,telomerase,[ACCUAACCCUGAUUUUCAUUAGCUGUGGGUUCUGGUCUUUUGUUC...,[.....................(((((((((......((((........
9,tmRNA,[GGGGGCGUCACGGUUUCGACGGGAUUGACUGCGGCAAAGAGGCAU...,[(((((((............(((((((...(((((.((((((((((...


In [5]:
# Build one DataLoader per family, keyed by family name, in test_df row order.
test_loaders = {}

family_columns = zip(
    test_df["family"], test_df["sequence"], test_df["secondary_structure"]
)
for family_name, family_sequences, family_structures in family_columns:
    # `dataset` stays a notebook-level name on purpose: a later cell reads
    # `dataset.idx_to_struct` from the last family processed here.
    dataset = RNADataset_old(family_sequences, family_structures)
    loader = DataLoader(dataset, batch_size=28, shuffle=False)
    test_loaders[family_name] = loader

In [6]:
def structure_predictions(model, test_loaders, structure_decoder):
    """
    Run the model over every family's loader and decode its predictions into strings.

    Parameters:
    model: Callable applied to a batch of encoded sequences; the predicted class of
        each position is the argmax over the last dimension of its output
    test_loaders (dict): Family name -> iterable of batches; every batch must provide
        "sequence", "length" and "raw_structure" entries
    structure_decoder: Lookup from predicted class index (int) to structure character

    Returns:
    dict: family -> {'true_structs': [...], 'pred_structs': [...]}, both lists in
        loader order; each predicted string is cut to the sample's "length" entry
    """
    results = {}

    # Dropout and batch-norm layers behave differently in training mode, so score in
    # eval mode. Every sub-module's flag is remembered first so the caller's model is
    # handed back exactly as it came in.
    saved_modes = []
    if isinstance(model, torch.nn.Module):
        saved_modes = [(module, module.training) for module in model.modules()]
        model.eval()

    try:
        for family, loader in tqdm(test_loaders.items(), desc="Processing families"):
            true_structs = []
            pred_structs = []

            for batch in tqdm(loader, desc=f"Family: {family}", leave=False):
                # `device` is the notebook-level torch.device chosen at import time.
                seq_tensor = batch["sequence"].to(device)
                lengths = batch["length"]

                with torch.no_grad(), warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    logits = model(seq_tensor)
                    preds = torch.argmax(logits, dim=-1)

                # One tolist() per batch replaces one .item() call per position.
                for idx, pred in enumerate(preds.tolist()):
                    length = int(lengths[idx])
                    # Positions past `length` are dropped before decoding.
                    pred_structs.append(
                        "".join(structure_decoder[i] for i in pred[:length])
                    )

                true_structs.extend(list(batch["raw_structure"]))

            results[family] = {
                'true_structs': true_structs,
                'pred_structs': pred_structs,
            }
    finally:
        # Restore the original train/eval flag of every sub-module, even on error.
        for module, mode in saved_modes:
            module.training = mode

    return results

def calculate_structure_accuracy(df):
    """
    Calculate position-wise accuracy between the true and predicted structure of every
    row, summarised per family and over the whole frame.

    Parameters:
    df (pandas.DataFrame): DataFrame containing 'secondary_structure',
        'predicted_structure', and 'family' columns

    Returns:
    pandas.DataFrame: One row per family (in order of first appearance) followed by an
        'Overall' row, with columns 'average_accuracy', 'std_accuracy' (np.std, i.e.
        population standard deviation) and 'sample_count', rounded to 4 decimals
    """
    def sequence_accuracy(true_seq, pred_seq):
        """Fraction of positions on which two sequences agree; 0.0 if not comparable"""
        # A length mismatch or an empty reference leaves nothing to score position by
        # position; the emptiness check also prevents a ZeroDivisionError.
        if len(true_seq) != len(pred_seq) or len(true_seq) == 0:
            return 0.0
        matches = sum(1 for t, p in zip(true_seq, pred_seq) if t == p)
        return matches / len(true_seq)

    def summarize(accuracies):
        """Mean, population std and count of a list of per-row accuracies"""
        return {
            'average_accuracy': np.mean(accuracies),
            'std_accuracy': np.std(accuracies),
            'sample_count': len(accuracies)
        }

    # Score every row exactly once; the per-family and overall summaries share these values
    all_accuracies = [
        sequence_accuracy(true_seq, pred_seq)
        for true_seq, pred_seq in zip(df['secondary_structure'], df['predicted_structure'])
    ]

    # Bucket the scores by family; dict insertion order keeps families in order of first appearance
    per_family = {}
    for family, acc in zip(df['family'], all_accuracies):
        per_family.setdefault(family, []).append(acc)

    results = {family: summarize(accs) for family, accs in per_family.items()}

    # Convert to DataFrame and append the overall metrics as a final row
    metrics_df = pd.DataFrame.from_dict(results, orient='index')
    metrics_df.loc['Overall'] = summarize(all_accuracies)

    return metrics_df.round(4)

In [7]:
# Read the index -> structure-character table from the dataset behind the last
# loader in `test_loaders`. That is the same object the leftover `dataset` loop
# variable refers to, but naming it through `test_loaders` keeps this cell from
# relying on a loop variable leaking out of the loader-building loop.
idx_to_struct = list(test_loaders.values())[-1].dataset.idx_to_struct

# Decode class indices 7 and 8 as plain brackets.
# NOTE(review): presumably these are extra bracket classes folded into "(" / ")" — confirm.
# NOTE(review): this edits the dataset's own table in place; copy it first if the
# dataset's original mapping is needed elsewhere.
idx_to_struct[7] = "("
idx_to_struct[8] = ")"

In [8]:
# Run the model over every family's loader and collect the true and decoded structure strings.
results = structure_predictions(model, test_loaders, idx_to_struct)

Processing families:   0%|          | 0/10 [00:00<?, ?it/s]

Family: 16S_rRNA:   0%|          | 0/4 [00:00<?, ?it/s]

Family: 23S_rRNA:   0%|          | 0/2 [00:00<?, ?it/s]

Family: 5S_rRNA:   0%|          | 0/46 [00:00<?, ?it/s]

Family: RNaseP:   0%|          | 0/17 [00:00<?, ?it/s]

Family: SRP:   0%|          | 0/34 [00:00<?, ?it/s]

Family: group_II_intron:   0%|          | 0/1 [00:00<?, ?it/s]

Family: group_I_intron:   0%|          | 0/4 [00:00<?, ?it/s]

Family: tRNA:   0%|          | 0/20 [00:00<?, ?it/s]

Family: telomerase:   0%|          | 0/2 [00:00<?, ?it/s]

Family: tmRNA:   0%|          | 0/17 [00:00<?, ?it/s]

In [9]:
# Flatten the per-family result lists into one row per sequence and give the
# columns the names calculate_structure_accuracy expects.
predicted_df = (
    pd.DataFrame(results)
    .T
    .explode(["true_structs", "pred_structs"])
    .reset_index()
    .rename(
        columns={
            "index": "family",
            "true_structs": "secondary_structure",
            "pred_structs": "predicted_structure",
        }
    )
)
predicted_df

Unnamed: 0,family,secondary_structure,predicted_structure
0,16S_rRNA,...(((((...(((.))))).((((((((((.((((((((((.......,.(.((((..(.())((().(.)()((((()(.((((((((((.((....
1,16S_rRNA,...(((((.......))))).((((((((((.((((((((((.......,.....(..........(..(.(((..(.(((....(((.(((.(.....
2,16S_rRNA,.......(((((...(.((((.(.(((.(((((((.((((((((((...,......(..)......((......(((..((((((..(.((..(((...
3,16S_rRNA,.......(((((.(((((((..((..((((((.((((((((((......,............((.(.......(..(.(.(...(..(((((.(.....
4,16S_rRNA,.(.(..((...((((.(((..(((((((..((((..((((((((((...,....(..(....(.(..(((.(((.(.(.((((.(...((((((((...
...,...,...,...
3970,tmRNA,(((((((............((((((((.(....((((..(((.(((...,.(((.((..(.....(.....((.(((.((....(.(..(((((((...
3971,tmRNA,(((((((............((((((((.....((((.((((..(((...,......(..(..(..(.(..(....(.((.(..(((.(.((.((.(...
3972,tmRNA,(((((((............((((((((.....((((.((((..(((...,...(..((.(..(....(....(.((((((...((((((((.((.(...
3973,tmRNA,(((((((............((((((((....(((((.((((..(((...,.(....(...(.....((..(.((((((......((((((((((.(...


In [10]:
# Per-family and overall position-wise accuracy of the decoded predictions.
calculate_structure_accuracy(predicted_df)

Unnamed: 0,average_accuracy,std_accuracy,sample_count
16S_rRNA,0.4782,0.0497,110
23S_rRNA,0.4649,0.025,35
5S_rRNA,0.5068,0.0514,1283
RNaseP,0.4639,0.0329,454
SRP,0.4935,0.0617,928
group_II_intron,0.4945,0.0186,11
group_I_intron,0.4702,0.0592,98
tRNA,0.7176,0.1313,557
telomerase,0.4422,0.0292,37
tmRNA,0.4861,0.0378,462
