In [None]:
import numpy as np
import os
import torch
import itertools
import sys

current_dir = os.getcwd()

parent_dir = os.path.dirname(current_dir)

sys.path.append(parent_dir)

from retnet import RetNet,RetNetConfig
from scipy.optimize import linear_sum_assignment

def constraint(seq):
    """Build the base-pairing constraint mask for a sequence.

    Entry ``[i, j]`` is 1 when ``j < i - 4`` and the bases at ``i`` and ``j``
    form an A-U, C-G or G-U pair (in either order), or when either base is
    ``'N'``. Only the lower triangle is ever filled; every other entry is 0.
    The comparison is case-sensitive and uses the RNA alphabet.

    Args:
        seq: Indexable sequence of single-character bases (e.g. a ``str``).

    Returns:
        Integer ``numpy`` array of shape ``(len(seq), len(seq))``.
    """
    pairable = {
        ('A', 'U'), ('U', 'A'),
        ('C', 'G'), ('G', 'C'),
        ('G', 'U'), ('U', 'G'),
    }

    n = len(seq)
    allowed = np.zeros((n, n), dtype=int)

    for row in range(4, n):
        upstream = seq[row]
        # Partners must sit more than four positions before `row`.
        for col in range(row - 4):
            downstream = seq[col]
            if upstream == 'N' or downstream == 'N' or (upstream, downstream) in pairable:
                allowed[row, col] = 1

    return allowed



class outer_concat(torch.nn.Module):
    """Pairwise feature map: concatenate every row of ``x1`` with every row of ``x2``."""

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        """Return a tensor whose ``[b, i, j]`` slice is ``cat(x1[b, i], x2[b, j])``.

        The grid size along both pair axes is taken from ``x1.shape[1]``.
        """
        n = x1.shape[1]
        # Broadcast x1 along a new "j" axis and x2 along a new "i" axis.
        by_row = x1.unsqueeze(-2).expand(-1, -1, n, -1)
        by_col = x2.unsqueeze(-3).expand(-1, n, -1, -1)
        return torch.concat((by_row, by_col), dim=-1)


class ResNet2DBlock(torch.nn.Module):
    """Residual block: 1x1 conv -> kxk conv -> 1x1 conv, each followed by ReLU.

    All convolutions keep ``embed_dim`` channels and the spatial size, so the
    block output has the same shape as its input.
    """

    def __init__(self, embed_dim, kernel_size=3, bias=False):
        super().__init__()

        def conv(size, **extra):
            # Channel-preserving convolution shared by all three stages.
            return torch.nn.Conv2d(
                in_channels=embed_dim,
                out_channels=embed_dim,
                kernel_size=size,
                bias=bias,
                **extra,
            )

        # Layers are created in forward order so the Sequential indices
        # (conv_net.0 / .2 / .4) stay the same as in saved checkpoints.
        stages = [
            conv(1),
            torch.nn.ReLU(),
            conv(kernel_size, padding="same"),
            torch.nn.ReLU(),
            conv(1),
            torch.nn.ReLU(),
        ]
        self.conv_net = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Apply the convolution stack and add the input back as a skip connection."""
        return self.conv_net(x) + x
    
class ResNet2D(torch.nn.Module):
    """A stack of ``num_blocks`` identical ``ResNet2DBlock`` modules applied in sequence."""

    def __init__(self, embed_dim, num_blocks, kernel_size=3, bias=False):
        super().__init__()

        stack = []
        for _ in range(num_blocks):
            stack.append(ResNet2DBlock(embed_dim, kernel_size, bias=bias))
        self.blocks = torch.nn.ModuleList(stack)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the final activation."""
        out = x
        for stage in self.blocks:
            out = stage(out)
        return out
    
class ResNet2D_classifier(torch.nn.Module):
    """Pairwise head that turns per-position embeddings into an L x L logit map.

    The embeddings are outer-concatenated with themselves (doubling the feature
    size), projected to ``hidden_dim`` channels, refined by a 2-D ResNet and
    reduced to one channel by a 3x3 convolution.

    Args:
        in_dim: Feature size after the outer concatenation, i.e. twice the
            per-position embedding width. Defaults to 768.
        hidden_dim: Channel count used inside the ResNet. Defaults to 128.
        num_blocks: Number of residual blocks. Defaults to 16.

    The defaults reproduce the originally hard-coded architecture, so existing
    checkpoints and zero-argument construction keep working.
    """

    def __init__(self, in_dim=768, hidden_dim=128, num_blocks=16):
        super().__init__()
        self.outer_concat = outer_concat()
        self.linear_in = torch.nn.Linear(in_dim, hidden_dim)
        self.resnet = ResNet2D(hidden_dim, num_blocks, 3, bias=True)
        self.conv_out = torch.nn.Conv2d(hidden_dim, 1, kernel_size=3, padding="same")

    def forward(self, x):
        """Map embeddings ``x`` to pairing logits.

        ``x`` is expected to be ``(batch, L, in_dim // 2)``; the result is
        ``(batch, L, L)``.
        """
        x = self.outer_concat(x, x)
        x = self.linear_in(x)
        # Conv2d wants channels first: (batch, L, L, C) -> (batch, C, L, L).
        x = x.permute(0, 3, 1, 2)
        x = self.resnet(x)

        x = self.conv_out(x)
        # Drop the singleton channel axis left by conv_out.
        x = x.squeeze(1)

        return x
    
class rnaret_ssp_model(torch.nn.Module):
    """RetNet backbone followed by the pairwise ResNet classifier.

    Args:
        args: Configuration object handed unchanged to ``RetNet``.
    """

    def __init__(self, args):
        super(rnaret_ssp_model, self).__init__()
        self.ret = RetNet(args)
        self.classifier = ResNet2D_classifier()

    def forward(self, x):
        """Return pairing logits for the token batch ``x``.

        The backbone's primary output is discarded; the classifier consumes the
        last entry of the ``'inner_states'`` list from its auxiliary output.
        """
        _, aux = self.ret(x)
        x = aux['inner_states'][-1]
        x = self.classifier(x)
        return x
    
class post_process(torch.nn.Module):
    """Convert pairing logits into a discrete secondary-structure matrix.

    Steps performed by ``forward``:

    1. sigmoid, multiply by the constraint ``mask`` and zero every probability
       that is not above 0.5;
    2. per sample, solve a maximum-weight assignment so each row and each
       column keeps at most one partner;
    3. re-apply the mask, symmetrise by adding the transpose, and greedily
       drop extra partners from any row or column that still has more than one.
    """

    def __init__(self):
        super(post_process, self).__init__()

    def forward(self, x, mask):
        """Return a ``(B, L, L)`` tensor holding the predicted pairs.

        Args:
            x: Pairing logits of shape ``(B, L, L)``.
            mask: Tensor broadcastable to ``x``; non-zero where a pair is allowed.
        """
        x = torch.sigmoid(x)
        x = x * mask

        sec_struct = torch.where(x > 0.5, torch.ones_like(x), torch.zeros_like(x))
        x = x * sec_struct

        B, L, _ = x.shape

        for b in range(B):
            scores = x[b].detach().cpu().numpy()
            row_ind, col_ind = linear_sum_assignment(-scores)

            # The solver returns a full permutation, so rows without any
            # above-threshold score still receive a column. Keep only the
            # assignments that carry real probability mass; otherwise these
            # zero-score matches become predicted pairs wherever the mask
            # happens to allow them.
            keep = scores[row_ind, col_ind] > 0
            rows = torch.as_tensor(row_ind[keep], dtype=torch.long, device=x.device)
            cols = torch.as_tensor(col_ind[keep], dtype=torch.long, device=x.device)

            binary_matrix = torch.zeros_like(x[b])
            binary_matrix[rows, cols] = 1

            sec_struct[b] = binary_matrix

        sec_struct = sec_struct * mask

        sec_struct = sec_struct + sec_struct.transpose(1, 2)

        # After symmetrisation a position can appear in two pairs (once via its
        # row, once via its column). Resolve this greedily in index order; each
        # fix changes the matrix seen by later iterations, so the loop must stay
        # sequential.
        for b in range(B):
            for i in range(L):
                if torch.sum(sec_struct[b, i, :]) > 1:
                    max_idx = torch.argmax(sec_struct[b, i, :])
                    sec_struct[b, i, :] = 0
                    sec_struct[b, i, max_idx] = 1

                if torch.sum(sec_struct[b, :, i]) > 1:
                    max_idx = torch.argmax(sec_struct[b, :, i])
                    sec_struct[b, :, i] = 0
                    sec_struct[b, max_idx, i] = 1

        return sec_struct
        
    
class seq_tokenizer():
    """k-mer tokenizer that writes into a fixed-length, zero-padded buffer.

    NOTE(review): a second class named ``seq_tokenizer`` is defined further
    down in this cell and replaces this one when the cell runs.
    """

    def __init__(self, k=5, max_len=512):
        self.k = k
        self.max_len = max_len

    def tokenize(self, seq):
        """Encode ``seq`` as an ``int16`` array of length ``max_len``.

        The sequence is upper-cased and ``U`` is mapped to ``T``. Each window
        of ``k`` bases becomes ``6 + rank`` where ``rank`` is its position in
        ``itertools.product('ATCG', repeat=k)``; windows containing any other
        character become 2. The k-mer ids are preceded by ``k // 2`` and
        followed by ``(k - 1) // 2`` entries equal to 1; remaining slots stay 0.
        """
        width = self.k
        vocab = {
            ''.join(letters): rank + 6
            for rank, letters in enumerate(itertools.product('ATCG', repeat=width))
        }

        dna = seq.upper().replace('U', 'T')
        windows = [dna[start:start + width] for start in range(len(dna) - width + 1)]
        ids = np.array([vocab.get(window, 2) for window in windows])

        lead = width // 2
        stop = lead + len(ids)

        out = np.zeros(self.max_len, dtype=np.int16)
        out[lead:stop] = ids[:]
        out[:lead] = 1
        out[stop:stop + (width - 1) // 2] = 1
        return out
    
def parse_bpseq_file(file, max_len=512):
    """Read a ``.bpseq`` file into a pairing matrix, a length and a sequence.

    Every data line is expected to carry three whitespace-separated fields:
    position, nucleotide and partner position (1-based, ``0`` = unpaired).
    Lines that do not consist of exactly three fields starting with an integer
    (blank lines, comments, free-text headers) are skipped. Positions are
    assumed to be numbered consecutively from 1, so the n-th data line is
    treated as position n.

    Args:
        file: Path of the file to read. Paths that do not end in ``.bpseq``
            are not opened.
        max_len: Side length of the zero-padded square matrix that is returned.

    Returns:
        A ``(matrix, seq_len, seq)`` tuple. ``matrix`` is an ``int8`` array of
        shape ``(max_len, max_len)`` with 1 at ``[i, j]`` and ``[j, i]`` for
        every listed pair (0-based), or ``None`` when the file is not a
        ``.bpseq`` file or the sequence is longer than ``max_len``.

    Raises:
        ValueError: If a partner field is not an integer or points past the
            end of the sequence.
    """
    if not file.endswith(".bpseq"):
        return None, 0, ''

    bases = []
    pairs = []
    with open(file, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) != 3 or not parts[0].isdigit():
                continue
            _, nt, pair = parts
            partner = int(pair)
            if partner > 0:
                pairs.append((len(bases), partner - 1))
            bases.append(nt)

    seq = ''.join(bases)
    seq_len = len(seq)

    if seq_len > max_len:
        return None, seq_len, seq

    matrix = np.zeros((max_len, max_len), dtype=np.int8)
    for i, j in pairs:
        if i >= seq_len or j >= seq_len:
            raise ValueError(f"{file}: pair ({i + 1}, {j + 1}) lies outside the sequence of length {seq_len}")
        # Mark both orientations so the target is symmetric even if the file
        # lists a pair only once.
        matrix[i, j] = 1
        matrix[j, i] = 1

    return matrix, seq_len, seq

class seq_tokenizer():
    """k-mer tokenizer producing one token per input position (no length padding).

    Args:
        k: k-mer width.
        max_len: Stored for interface compatibility; this tokenizer does not
            pad or truncate to it.
    """

    def __init__(self, k=5, max_len=512):
        self.k = k
        self.max_len = max_len

    def _kmer_table(self):
        """Return the k-mer -> id mapping for the current ``k``, building it only once."""
        cached = getattr(self, '_kmer_cache', None)
        if cached is None or cached[0] != self.k:
            table = {
                ''.join(letters): rank + 6
                for rank, letters in enumerate(itertools.product('ATCG', repeat=self.k))
            }
            cached = (self.k, table)
            self._kmer_cache = cached
        return cached[1]

    def tokenize(self, seq):
        """Encode ``seq`` as an integer array of length ``len(seq)``.

        The sequence is upper-cased and ``U`` is mapped to ``T``. Each window
        of ``k`` bases becomes ``6 + rank`` where ``rank`` is its position in
        ``itertools.product('ATCG', repeat=k)``; windows containing any other
        character become 2. The k-mer ids are preceded by ``k // 2`` and
        followed by ``(k - 1) // 2`` entries equal to 1.

        The result is ``int16`` whenever the largest id fits, and ``int32``
        otherwise (k >= 8), where ``int16`` would silently wrap around.
        """
        k = self.k
        kmer_to_index = self._kmer_table()

        seq = seq.upper().replace('U', 'T')
        seq_len = len(seq)

        largest_id = 4 ** k + 5
        dtype = np.int16 if largest_id <= np.iinfo(np.int16).max else np.int32

        tokens = np.zeros(seq_len, dtype=dtype)
        indices = np.array(
            [kmer_to_index.get(seq[i:i + k], 2) for i in range(seq_len - k + 1)],
            dtype=dtype,
        )

        start = k // 2
        stop = start + len(indices)
        tokens[start:stop] = indices
        tokens[:start] = 1
        tokens[stop:stop + (k - 1) // 2] = 1
        return tokens


class ssp_dataset(torch.utils.data.Dataset):
    """Secondary-structure dataset built from every ``.bpseq`` file under a directory.

    Each item is ``(tokens, pair_matrix, constraint_mask, label)`` where the
    label is the part of the file name before the first underscore. Files are
    visited in sorted order so item indices are reproducible across runs.

    Args:
        data_dir: Root directory that is searched recursively.
        max_len: Sequences longer than this are skipped.
        k: k-mer width passed to the tokenizer.
    """

    def __init__(self, data_dir, max_len, k=5):
        self.data_dir = data_dir
        self.matrices = []
        self.seqs = []
        self.constraint = []
        self.tokenizer = seq_tokenizer(k=k, max_len=max_len)
        self.labels = []

        for root, dirs, files in os.walk(data_dir):
            # Sorting in place makes os.walk descend in a deterministic order.
            dirs.sort()
            for filename in sorted(files):
                if not filename.endswith(".bpseq"):
                    continue

                file_path = os.path.join(root, filename)
                # The parser is expected to return (matrix, seq_len, seq).
                matrix, seq_len, seq = parse_bpseq_file(file_path, max_len=max_len)
                if matrix is None or seq_len == 0 or seq_len > max_len:
                    continue

                # constraint() takes only the sequence and yields a
                # (seq_len, seq_len) mask. Crop the label matrix to the same
                # window so target, mask and model output share one shape.
                con_matrix = constraint(seq)
                tokens = self.tokenizer.tokenize(seq)

                self.matrices.append(matrix[:seq_len, :seq_len])
                self.seqs.append(tokens)
                self.constraint.append(con_matrix)
                self.labels.append(filename.split('_')[0])

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        return self.seqs[idx], self.matrices[idx], self.constraint[idx], self.labels[idx]
    
def compute_metrics(target, prediction):
    """Compute precision, recall and F1 for a predicted pairing matrix.

    Args:
        target: Tensor whose entries equal to 1 are positives and entries equal
            to 0 are negatives; any other value is ignored.
        prediction: Tensor of the same shape; entries greater than 0 count as
            predicted positives.

    Returns:
        ``(precision, recall, f1)``. A ratio whose denominator is zero is
        reported as 0.
    """
    positive_mask = target == 1
    negative_mask = target == 0
    pred_positive_mask = prediction > 0

    # True negatives are not needed by any of the three metrics, so they are
    # not counted.
    tp = torch.sum(torch.logical_and(positive_mask, pred_positive_mask)).item()
    fp = torch.sum(torch.logical_and(negative_mask, pred_positive_mask)).item()
    fn = torch.sum(torch.logical_and(positive_mask, ~pred_positive_mask)).item()

    precision = tp / (tp + fp) if tp + fp != 0 else 0
    recall = tp / (tp + fn) if tp + fn != 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if precision + recall != 0 else 0

    return precision, recall, f1

    


In [None]:
# Build the model, load the fine-tuned weights and assemble the test set.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# k-mer width. The tokenizer numbers k-mers from 6 upwards, hence the
# vocabulary size of 4**k + 6.
k = 1
model_config = RetNetConfig(vocab_size=4**k+6,retnet_embed_dim=384, retnet_value_embed_dim=512,
                            retnet_ffn_embed_dim=512,retnet_layers=8,retnet_retention_heads=4,
                            dropout=0.2,activation_dropout=0.2)

model = rnaret_ssp_model(model_config)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load("../model/ssp/RNAStrAlign_1mer.pth", map_location=device, weights_only=True))

post_processor = post_process()
model = model.to(device)


test_datasets = []
# Must be a list: iterating over the bare string would yield one character at
# a time and try to load "." , "/" , "d" ... as data directories.
for data_dir in ["../data/archiveII"]:
    dataset = ssp_dataset(data_dir, max_len=600, k=k)
    test_datasets.append(dataset)
test_dataset = torch.utils.data.ConcatDataset(test_datasets)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1)
        


In [None]:
# Run the model over the test set and collect per-sequence metrics.
model.eval()

total_precision = []
total_recall = []
total_f1 = []
total_type = []
lengths = []

# Mixed precision is only requested when actually running on CUDA; on a
# CPU-only machine autocast is switched off and inference runs in full precision.
use_amp = device.type == 'cuda'
with torch.amp.autocast(device_type=device.type, enabled=use_amp):
    with torch.no_grad():
        for x, y, mask, rna_type in test_dataloader:
            x = x.to(torch.long).to(device)
            length = x.size(1)
            y = y.to(torch.float32).to(device)
            mask = mask.to(torch.int).to(device)
            prob = model(x)
            pred = post_processor(prob, mask)
            precision, recall, f1 = compute_metrics(y, pred)
            total_precision.append(precision)
            total_recall.append(recall)
            total_f1.append(f1)
            # The collated label is a sequence with one entry per batch item.
            total_type.extend(rna_type)
            lengths.append(length)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Per-family F1 distribution: one violin per RNA type, individual scores
# overlaid as dots, and the family mean printed above each violin.
data = {'Type': total_type, 'F1 Score': total_f1}
df = pd.DataFrame(data)

order = ['5s', '16s', '23s', 'tRNA', 'tmRNA', 'grp1', 'srp', 'RNaseP']

df['Type'] = pd.Categorical(df['Type'], categories=order, ordered=True)
# Labels that are not listed in `order` turn into NaN categories. They are not
# drawn by the violin plot and cannot be placed on a categorical axis, so drop
# them before plotting.
df = df.dropna(subset=['Type']).sort_values('Type')

mean_f1_scores = df.groupby('Type', observed=False)['F1 Score'].mean().reindex(order)

# One colour per plotted family.
custom_colors = sns.color_palette("Set3", n_colors=len(order))

plt.figure(figsize=(8, 4))

violin_plot = sns.violinplot(
    x='Type',
    y='F1 Score',
    data=df,
    hue='Type',          # palette must be tied to a hue variable
    hue_order=order,
    legend=False,
    palette=custom_colors,
    order=order,
    inner='box',
    linewidth=1,
    linecolor='grey',
    width=0.6
)

plt.scatter(df['Type'], df['F1 Score'], s=3, color='black', zorder=5)

# Mean labels sit at y = 1.1, i.e. on the upper axis limit set below.
y_offset = -0.1
for i, (type_name, mean_score) in enumerate(mean_f1_scores.items()):
    plt.text(i, 1 - y_offset, f'{mean_score:.3f}', horizontalalignment='center', fontsize=10, verticalalignment='bottom')

plt.ylim(0, 1.1)
plt.xlabel('RNA Type')
plt.ylabel('F1 Score')
plt.savefig("sspF1.pdf", format='pdf',bbox_inches='tight')

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import linregress

# F1 score against sequence length, with a fitted regression line.
point_style = {'alpha': 0.5, 's': 4}
fit_style = {'color': '#E29135', 'lw': 1.2}

fig, ax = plt.subplots(figsize=(7, 4))
sns.regplot(x=lengths, y=total_f1, scatter_kws=point_style, line_kws=fit_style, ax=ax)
ax.set_xlabel('Sequence Length')
ax.set_ylabel('F1 Score')
ax.grid(True)
fig.savefig("sspLen.pdf", format='pdf',bbox_inches='tight')