In [1]:
# --- MSA parsing utility ---
import numpy as np

def parse_msa(msa_path: str, max_seqs: int = 100) -> np.ndarray:
    with open(msa_path, 'r') as f:
        lines = f.read().splitlines()
    seqs = []
    current_seq = ""
    for line in lines:
        if line.startswith(">"):
            if current_seq:
                seqs.append(current_seq)
                current_seq = ""
        else:
            current_seq += line.strip()
    if current_seq:
        seqs.append(current_seq)
    seqs = seqs[:max_seqs]
    vocab = {'A': 0, 'U': 1, 'G': 2, 'C': 3, 'N': 4, '-': 5, '.': 5}
    L = len(seqs[0])
    N = len(seqs)
    msa_tensor = np.zeros((N, L, 7), dtype=np.float32)
    for i, seq in enumerate(seqs):
        for j, res in enumerate(seq):
            idx = vocab.get(res.upper(), 6)  # unknown = 6
            msa_tensor[i, j, idx] = 1.0
    return msa_tensor


In [2]:
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
import pickle
import os
import sys

In [3]:
config = {
    "seed": 0,
    "cutoff_date": "2020-01-01",
    "test_cutoff_date": "2022-05-01",
    "max_len": 384,
    "batch_size": 1,
    "learning_rate": 1e-4,
    "weight_decay": 0.0,
    "mixed_precision": "bf16",
    "model_config_path": "../working/configs/pairwise.yaml",  # Adjust path as needed
    "epochs": 10,
    "cos_epoch": 5,
    "loss_power_scale": 1.0,
    "max_cycles": 1,
    "grad_clip": 0.1,
    "gradient_accumulation_steps": 1,
    "d_clamp": 30,
    "max_len_filter": 9999999,
    "structural_violation_epoch": 50,
    "balance_weight": False,
}

In [4]:
test_data=pd.read_csv("/kaggle/input/validation-sequences-clean-csv/validation_sequences_clean.csv")
test_data.head()

Unnamed: 0,target_id,sequence,temporal_cutoff,description,all_sequences
0,9L5R_2,AGCUCUCUUUGCCUUUUGGCUUAGAUCAAGUGUAGUAUCUGUUCUU...,2025-03-12,Cryo-EM structure of the thermophile spliceoso...,>9L5R_1|Chain A[auth 2]|U2 snRNA|Chaetomium th...
1,9GFT_AU,GGGGCUAUAGCUCAGCUGGGAGAGCGCUUGCAUGGCAUGCAAGAGG...,2025-02-12,"Structure of the HrpA-bound E. coli disome, Cl...",">9GFT_1|Chains A[auth 0], N[auth AA]|16S ribos..."
2,9L0R_K,GCCGUCUCAAUAGUGGCUUAGCACAGAUAAUCCAUAGCGAUAUGGG...,2025-03-19,Streptococcus agalactiae GOLLD RNA dodecamer,">9L0R_1|Chains A, B, C, D, E, F, G, H, I, J, K..."
3,9GFT_A3,GGCUACGUAGCUCAGUUGGUUAGAGCACAUCACUCAUAAUGAUGGG...,2025-02-12,"Structure of the HrpA-bound E. coli disome, Cl...",">9GFT_1|Chains A[auth 0], N[auth AA]|16S ribos..."
4,9B2K_B,AAACAGCAUAGCAAGUUAAAAUAAGGCUAGUCCGUUAUCAACUUGA...,2025-03-26,SpCas9 with dual-guide RNA in open conformation,>9B2K_1|Chain A|RNA (5'-R(P*GP*UP*UP*UP*UP*AP*...


# Dataset

In [5]:
from torch.utils.data import Dataset, DataLoader

def create_msa_features(msa_tensor: np.ndarray) -> dict:
    """Create MSA-derived features for the model"""
    N, L, _ = msa_tensor.shape
    
    # Conservation score (entropy-based)
    conservation = np.zeros(L)
    for pos in range(L):
        counts = np.sum(msa_tensor[:, pos, :4], axis=0)  # Only AUGC
        total = np.sum(counts)
        if total > 0:
            probs = counts / total
            probs = probs[probs > 0]  # Remove zeros for log
            conservation[pos] = -np.sum(probs * np.log2(probs + 1e-8))
    
    # Coevolution features (simplified mutual information)
    coevolution = np.zeros((L, L))
    for i in range(L):
        for j in range(i+1, L):
            # Joint distribution
            joint_counts = np.zeros((4, 4))
            for seq_idx in range(N):
                res_i = np.argmax(msa_tensor[seq_idx, i, :4])
                res_j = np.argmax(msa_tensor[seq_idx, j, :4])
                joint_counts[res_i, res_j] += 1
            
            # Marginal distributions
            marg_i = np.sum(joint_counts, axis=1)
            marg_j = np.sum(joint_counts, axis=0)
            
            # Mutual information
            mi = 0
            total = np.sum(joint_counts)
            if total > 0:
                for x in range(4):
                    for y in range(4):
                        if joint_counts[x,y] > 0 and marg_i[x] > 0 and marg_j[y] > 0:
                            pxy = joint_counts[x,y] / total
                            px = marg_i[x] / total
                            py = marg_j[y] / total
                            mi += pxy * np.log2(pxy / (px * py + 1e-8) + 1e-8)
            
            coevolution[i, j] = coevolution[j, i] = mi
    
    return {
        'conservation': conservation,
        'coevolution': coevolution,
        'msa_depth': N,
        'msa_tensor': msa_tensor
    }

class RNADatasetWithMSA(Dataset):
    def __init__(self, data, msa_dir=None):
        self.data = data
        self.msa_dir = msa_dir
        # Updated tokens dictionary with fallback for unknown characters
        self.tokens = {'A': 0, 'C': 1, 'G': 2, 'U': 3}
        self.unknown_token = 0  # Map unknown characters to 'A' (index 0)

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # Use get() method with default fallback for unknown characters
        sequence = [self.tokens.get(nt, self.unknown_token) for nt in (self.data.loc[idx, 'sequence'])]
        sequence = np.array(sequence)
        sequence = torch.tensor(sequence)
        
        result = {'sequence': sequence}
        
        # Try to load MSA if directory provided
        if self.msa_dir:
            target_id = self.data.loc[idx, 'target_id']
            msa_path = os.path.join(self.msa_dir, f"{target_id}.MSA.fasta")
            
            if os.path.exists(msa_path):
                try:
                    msa_tensor = parse_msa(msa_path)
                    msa_features = create_msa_features(msa_tensor)
                    
                    result['msa_features'] = {
                        'conservation': torch.tensor(msa_features['conservation'], dtype=torch.float32),
                        'coevolution': torch.tensor(msa_features['coevolution'], dtype=torch.float32),
                        'msa_depth': msa_features['msa_depth'],
                        'has_msa': True
                    }
                except Exception as e:
                    print(f"Warning: Could not load MSA for {target_id}: {e}")
                    result['msa_features'] = self._get_dummy_msa_features(len(sequence))
            else:
                result['msa_features'] = self._get_dummy_msa_features(len(sequence))
        else:
            result['msa_features'] = self._get_dummy_msa_features(len(sequence))
        
        return result
    
    def _get_dummy_msa_features(self, seq_len):
        """Create dummy MSA features when no MSA is available"""
        return {
            'conservation': torch.zeros(seq_len, dtype=torch.float32),
            'coevolution': torch.zeros(seq_len, seq_len, dtype=torch.float32),
            'msa_depth': 1,
            'has_msa': False
        }

msa_directory = "/kaggle/input/stanford-rna-3d-folding/MSA_v2"
test_dataset = RNADatasetWithMSA(test_data, msa_dir=msa_directory)

In [6]:
sys.path.append("/kaggle/input/ribonanzanet2/pytorch/alpha/1")

import torch.nn as nn
from Network import *

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class finetuned_RibonanzaNet(RibonanzaNet):
    def __init__(self, rnet_config, config, pretrained=False):
        rnet_config.dropout=0.1
        rnet_config.use_grad_checkpoint=True
        super(finetuned_RibonanzaNet, self).__init__(rnet_config)
        if pretrained:
            self.load_state_dict(torch.load(config.pretrained_weight_path,map_location='cpu'))
        # self.ct_predictor=nn.Sequential(nn.Linear(64,256),
        #                                 nn.ReLU(),
        #                                 nn.Linear(256,64),
        #                                 nn.ReLU(),
        #                                 nn.Linear(64,1)) 
        self.dropout=nn.Dropout(0.0)

        decoder_dim=config.decoder_dim
        self.structure_module=[SimpleStructureModule(d_model=decoder_dim, nhead=config.decoder_nhead, 
                 dim_feedforward=decoder_dim*4, pairwise_dimension=rnet_config.pairwise_dimension, dropout=0.0) for i in range(config.decoder_num_layers)]
        self.structure_module=nn.ModuleList(self.structure_module)

        self.xyz_embedder=nn.Linear(3,decoder_dim)
        self.xyz_norm=nn.LayerNorm(decoder_dim)
        self.xyz_predictor=nn.Linear(decoder_dim,3)
        
        self.adaptor=nn.Sequential(nn.Linear(rnet_config.ninp,decoder_dim),nn.LayerNorm(decoder_dim))

        self.distogram_predictor=nn.Sequential(nn.LayerNorm(rnet_config.pairwise_dimension),
                                                nn.Linear(rnet_config.pairwise_dimension,40))

        self.time_embedder=SinusoidalPosEmb(decoder_dim)

        self.time_mlp=nn.Sequential(nn.Linear(decoder_dim,decoder_dim),
                                    nn.ReLU(),  
                                    nn.Linear(decoder_dim,decoder_dim))
        self.time_norm=nn.LayerNorm(decoder_dim)

        self.distance2pairwise=nn.Linear(1,rnet_config.pairwise_dimension,bias=False)

        self.pair_mlp=nn.Sequential(nn.Linear(rnet_config.pairwise_dimension,rnet_config.pairwise_dimension),
                                    nn.ReLU(),
                                    nn.Linear(rnet_config.pairwise_dimension,rnet_config.pairwise_dimension))


        #hyperparameters for diffusion
        self.n_times = config.n_times

        #self.model = model
        
        # define linear variance schedule(betas)
        beta_1, beta_T = config.beta_min, config.beta_max
        betas = torch.linspace(start=beta_1, end=beta_T, steps=config.n_times)#.to(device) # follows DDPM paper
        self.sqrt_betas = torch.sqrt(betas)
                                     
        # define alpha for forward diffusion kernel
        self.alphas = 1 - betas
        self.sqrt_alphas = torch.sqrt(self.alphas)
        alpha_bars = torch.cumprod(self.alphas, dim=0)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1-alpha_bars)
        self.sqrt_alpha_bars = torch.sqrt(alpha_bars)

        self.data_std=config.data_std


    def custom(self, module):
        def custom_forward(*inputs):
            inputs = module(*inputs)
            return inputs
        return custom_forward
    
    def embed_pair_distance(self,inputs):
        pairwise_features,xyz=inputs
        distance_matrix=xyz[:,None,:,:]-xyz[:,:,None,:]
        distance_matrix=(distance_matrix**2).sum(-1).clip(2,37**2).sqrt()
        distance_matrix=distance_matrix[:,:,:,None]
        pairwise_features=pairwise_features+self.distance2pairwise(distance_matrix)

        return pairwise_features

    def forward(self,src,xyz,t):
        
        #with torch.no_grad():
        sequence_features, pairwise_features=self.get_embeddings(src, torch.ones_like(src).long().to(src.device))
        
        distogram=self.distogram_predictor(pairwise_features)

        sequence_features=self.adaptor(sequence_features)

        decoder_batch_size=xyz.shape[0]
        sequence_features=sequence_features.repeat(decoder_batch_size,1,1)
        

        pairwise_features=pairwise_features.expand(decoder_batch_size,-1,-1,-1)

        pairwise_features= checkpoint.checkpoint(self.custom(self.embed_pair_distance), [pairwise_features,xyz],use_reentrant=False)

        time_embed=self.time_embedder(t).unsqueeze(1)
        tgt=self.xyz_norm(sequence_features+self.xyz_embedder(xyz)+time_embed)

        tgt=self.time_norm(tgt+self.time_mlp(tgt))

        for layer in self.structure_module:
            #tgt=layer([tgt, sequence_features,pairwise_features,xyz,None])
            tgt=checkpoint.checkpoint(self.custom(layer),
            [tgt, sequence_features,pairwise_features,xyz,None],
            use_reentrant=False)
            # xyz=xyz+self.xyz_predictor(sequence_features).squeeze(0)
            # xyzs.append(xyz)
            #print(sequence_features.shape)
        
        xyz=self.xyz_predictor(tgt).squeeze(0)
        #.squeeze(0)

        return xyz, distogram
    

    def denoise(self,sequence_features,pairwise_features,xyz,t):
        decoder_batch_size=xyz.shape[0]
        sequence_features=sequence_features.expand(decoder_batch_size,-1,-1)
        pairwise_features=pairwise_features.expand(decoder_batch_size,-1,-1,-1)

        pairwise_features=self.embed_pair_distance([pairwise_features,xyz])

        sequence_features=self.adaptor(sequence_features)
        time_embed=self.time_embedder(t).unsqueeze(1)
        tgt=self.xyz_norm(sequence_features+self.xyz_embedder(xyz)+time_embed)
        tgt=self.time_norm(tgt+self.time_mlp(tgt))
        #xyz_batch_size=xyz.shape[0]
        


        for layer in self.structure_module:
            tgt=layer([tgt, sequence_features,pairwise_features,xyz,None])
            # xyz=xyz+self.xyz_predictor(sequence_features).squeeze(0)
            # xyzs.append(xyz)
            #print(sequence_features.shape)
        xyz=self.xyz_predictor(tgt).squeeze(0)
        # print(xyz.shape)
        # exit()
        return xyz


    def extract(self, a, t, x_shape):
        """
            from lucidrains' implementation
                https://github.com/lucidrains/denoising-diffusion-pytorch/blob/beb2f2d8dd9b4f2bd5be4719f37082fe061ee450/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L376
        """
        b, *_ = t.shape
        out = a.gather(-1, t)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))
    
    def scale_to_minus_one_to_one(self, x):
        # according to the DDPMs paper, normalization seems to be crucial to train reverse process network
        return x * 2 - 1
    
    def reverse_scale_to_zero_to_one(self, x):
        return (x + 1) * 0.5
    
    def make_noisy(self, x_zeros, t): 
        # assume we get raw data, so center and scale by 35
        x_zeros = x_zeros - torch.nanmean(x_zeros,1,keepdim=True)
        x_zeros = x_zeros/self.data_std
        #rotate randomly
        x_zeros = random_rotation_point_cloud_torch_batch(x_zeros)


        # perturb x_0 into x_t (i.e., take x_0 samples into forward diffusion kernels)
        epsilon = torch.randn_like(x_zeros).to(x_zeros.device)
        
        sqrt_alpha_bar = self.extract(self.sqrt_alpha_bars.to(x_zeros.device), t, x_zeros.shape)
        sqrt_one_minus_alpha_bar = self.extract(self.sqrt_one_minus_alpha_bars.to(x_zeros.device), t, x_zeros.shape)
        
        # Let's make noisy sample!: i.e., Forward process with fixed variance schedule
        #      i.e., sqrt(alpha_bar_t) * x_zero + sqrt(1-alpha_bar_t) * epsilon
        noisy_sample = x_zeros * sqrt_alpha_bar + epsilon * sqrt_one_minus_alpha_bar
    
        return noisy_sample.detach(), epsilon
    
    
    # def forward(self, x_zeros):
    #     x_zeros = self.scale_to_minus_one_to_one(x_zeros)
        
    #     B, _, _, _ = x_zeros.shape
        
    #     # (1) randomly choose diffusion time-step
    #     t = torch.randint(low=0, high=self.n_times, size=(B,)).long().to(x_zeros.device)
        
    #     # (2) forward diffusion process: perturb x_zeros with fixed variance schedule
    #     perturbed_images, epsilon = self.make_noisy(x_zeros, t)
        
    #     # (3) predict epsilon(noise) given perturbed data at diffusion-timestep t.
    #     pred_epsilon = self.model(perturbed_images, t)
        
    #     return perturbed_images, epsilon, pred_epsilon
    
    
    def denoise_at_t(self, x_t, sequence_features, pairwise_features, timestep, t):
        B, _, _ = x_t.shape
        if t > 1:
            z = torch.randn_like(x_t).to(sequence_features.device)
        else:
            z = torch.zeros_like(x_t).to(sequence_features.device)
        
        # at inference, we use predicted noise(epsilon) to restore perturbed data sample.
        epsilon_pred = self.denoise(sequence_features, pairwise_features, x_t, timestep)
        
        alpha = self.extract(self.alphas.to(x_t.device), timestep, x_t.shape)
        sqrt_alpha = self.extract(self.sqrt_alphas.to(x_t.device), timestep, x_t.shape)
        sqrt_one_minus_alpha_bar = self.extract(self.sqrt_one_minus_alpha_bars.to(x_t.device), timestep, x_t.shape)
        sqrt_beta = self.extract(self.sqrt_betas.to(x_t.device), timestep, x_t.shape)
        
        # denoise at time t, utilizing predicted noise
        x_t_minus_1 = 1 / sqrt_alpha * (x_t - (1-alpha)/sqrt_one_minus_alpha_bar*epsilon_pred) + sqrt_beta*z
        
        return x_t_minus_1#.clamp(-1., 1)
                
    def sample(self, src, N):
        # start from random noise vector, NxLx3
        x_t = torch.randn((N, src.shape[1], 3)).to(src.device)
        
        # autoregressively denoise from x_T to x_0
        #     i.e., generate image from noise, x_T

        #first get conditioning
        sequence_features, pairwise_features=self.get_embeddings(src, torch.ones_like(src).long().to(src.device))
        # sequence_features=sequence_features.expand(N,-1,-1)
        # pairwise_features=pairwise_features.expand(N,-1,-1,-1)
        distogram=self.distogram_predictor(pairwise_features).squeeze()
        distogram=distogram.squeeze()[:,:,2:40]*torch.arange(2,40).float().cuda() 
        distogram=distogram.sum(-1)  

        for t in range(self.n_times-1, -1, -1):
            timestep = torch.tensor([t]).repeat_interleave(N, dim=0).long().to(src.device)
            x_t = self.denoise_at_t(x_t, sequence_features, pairwise_features, timestep, t)
        
        # denormalize x_0 into 0 ~ 1 ranged values.
        #x_0 = self.reverse_scale_to_zero_to_one(x_t)
        x_0 = x_t * self.data_std
        return x_0, distogram




class SimpleStructureModule(nn.Module):

    def __init__(self, d_model, nhead, 
                 dim_feedforward, pairwise_dimension, dropout=0.1,
                 ):
        super(SimpleStructureModule, self).__init__()
        #self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.self_attn = MultiHeadAttention(d_model, nhead, d_model//nhead, d_model//nhead, dropout=dropout)
        #self.cross_attn = MultiHeadAttention(d_model, nhead, d_model//nhead, d_model//nhead, dropout=dropout)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.pairwise2heads=nn.Linear(pairwise_dimension,nhead,bias=False)
        self.pairwise_norm=nn.LayerNorm(pairwise_dimension)

        #self.distance2heads=nn.Linear(1,nhead,bias=False)
        #self.pairwise_norm=nn.LayerNorm(pairwise_dimension)

        self.activation = nn.GELU()

        
    def custom(self, module):
        def custom_forward(*inputs):
            inputs = module(*inputs)
            return inputs
        return custom_forward

    def forward(self, input):
        tgt , src,  pairwise_features, pred_t, src_mask = input
        
        #src = src*src_mask.float().unsqueeze(-1)

        pairwise_bias=self.pairwise2heads(self.pairwise_norm(pairwise_features)).permute(0,3,1,2)

        


        #print(pairwise_bias.shape,distance_bias.shape)

        #pairwise_bias=pairwise_bias+distance_bias


        res=tgt
        tgt,attention_weights = self.self_attn(tgt, tgt, tgt, mask=pairwise_bias, src_mask=src_mask)
        tgt = res + self.dropout1(tgt)
        tgt = self.norm1(tgt)

        # print(tgt.shape,src.shape)
        # exit()

        res=tgt
        tgt = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = res + self.dropout2(tgt)
        tgt = self.norm2(tgt)


        return tgt


In [7]:
class MSAEnhancedRibonanzaNet(nn.Module):
    def __init__(self, base_model, msa_embed_dim=64):
        super().__init__()
        self.base_model = base_model
        self.msa_embed_dim = msa_embed_dim
        
        # MSA feature processors
        self.conservation_embedder = nn.Sequential(
            nn.Linear(1, msa_embed_dim),
            nn.ReLU(),
            nn.Linear(msa_embed_dim, msa_embed_dim)
        )
        
        self.coevolution_processor = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(32, msa_embed_dim)
        )
        
        # Integration layers - we'll set this dynamically later
        self.feature_fusion = None
        self.msa_embed_dim = msa_embed_dim
        
    def process_msa_features(self, sequence_features, msa_features):
        """Process and integrate MSA features"""
        if msa_features is None or not msa_features.get('has_msa', False):
            return sequence_features
            
        # Get batch size and sequence length from sequence_features
        batch_size, seq_len, feat_dim = sequence_features.shape
        
        # Conservation features
        conservation = msa_features['conservation'].to(sequence_features.device)
        conservation_embed = self.conservation_embedder(conservation.unsqueeze(-1))
        # Add batch dimension to match sequence_features
        conservation_embed = conservation_embed.unsqueeze(0).expand(batch_size, -1, -1)
        
        # Coevolution features
        coevolution = msa_features['coevolution'].to(sequence_features.device)
        coevolution_embed = self.coevolution_processor(coevolution.unsqueeze(0).unsqueeze(0))
        # coevolution_embed is now [1, 64] after the processor
        # Expand to match batch and sequence dimensions
        coevolution_embed = coevolution_embed.expand(batch_size, seq_len, -1)
        
        # Debug prints to verify dimensions
        print(f"sequence_features shape: {sequence_features.shape}")
        print(f"conservation_embed shape: {conservation_embed.shape}")
        print(f"coevolution_embed shape: {coevolution_embed.shape}")
        
        # Create feature_fusion layer if it doesn't exist yet
        if self.feature_fusion is None:
            total_input_dim = feat_dim + self.msa_embed_dim * 2
            print(f"Creating feature_fusion layer: {total_input_dim} -> {feat_dim}")
            self.feature_fusion = nn.Sequential(
                nn.Linear(total_input_dim, feat_dim),
                nn.LayerNorm(feat_dim),
                nn.ReLU()
            ).to(sequence_features.device)
        
        # Fuse features
        enhanced_features = torch.cat([
            sequence_features, 
            conservation_embed, 
            coevolution_embed
        ], dim=-1)
        
        print(f"enhanced_features shape: {enhanced_features.shape}")
        
        return self.feature_fusion(enhanced_features)
    
    def sample(self, src, N, msa_features=None):
        """Enhanced sampling with MSA features"""
        # Temporarily modify the base model's get_embeddings method to include MSA
        original_get_embeddings = self.base_model.get_embeddings
        
        def enhanced_get_embeddings(src, mask):
            seq_feat, pair_feat = original_get_embeddings(src, mask)
            seq_feat = self.process_msa_features(seq_feat, msa_features)
            return seq_feat, pair_feat
        
        # Monkey patch the method
        self.base_model.get_embeddings = enhanced_get_embeddings
        
        try:
            result = self.base_model.sample(src, N)
        finally:
            # Restore original method
            self.base_model.get_embeddings = original_get_embeddings
        
        return result

In [8]:
import yaml

class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        self.entries=entries

    def print(self):
        print(self.entries)

def load_config_from_yaml(file_path):
    with open(file_path, 'r') as file:
        config = yaml.safe_load(file)
    return Config(**config)


diffusion_config=load_config_from_yaml("/kaggle/input/ribonanzanet2-ddpm-v2/diffusion_config.yaml")
rnet_config=load_config_from_yaml("/kaggle/input/ribonanzanet2/pytorch/alpha/1/pairwise.yaml")

base_model = finetuned_RibonanzaNet(rnet_config, diffusion_config).cuda()


constructing 48 ConvTransformerEncoderLayers


In [9]:
state_dict=torch.load("/kaggle/input/ribonanzanet2-ddpm-v2/RibonanzaNet-DDPM-v2.pt",map_location='cpu')

#get rid of module. from ddp state dict
new_state_dict={}

for key in state_dict:
    new_state_dict[key[7:]]=state_dict[key]

base_model.load_state_dict(new_state_dict)
model = MSAEnhancedRibonanzaNet(base_model).cuda()

In [10]:
# Add this diagnostic code to check MSA availability
import os

msa_directory = "/kaggle/input/stanford-rna-3d-folding/MSA_v2"

print(f"MSA directory exists: {os.path.exists(msa_directory)}")

if os.path.exists(msa_directory):
    msa_files = os.listdir(msa_directory)
    print(f"Number of MSA files found: {len(msa_files)}")
    print(f"First 10 MSA files: {msa_files[:10]}")
    
    # Check what target IDs we're looking for
    target_ids = test_data['target_id'].tolist()
    print(f"First 10 target IDs: {target_ids}")
    
    # Check if any MSA files match our target IDs
    matches = []
    for target_id in target_ids:
        expected_file = f"{target_id}.MSA.fasta"
        if expected_file in msa_files:
            matches.append(expected_file)
    
    print(f"Matching MSA files found: {len(matches)}")

else:
    print("MSA directory does not exist!")
    print("Available directories in /kaggle/input/stanford-rna-3d-folding/:")
    if os.path.exists("/kaggle/input/stanford-rna-3d-folding/"):
        print(os.listdir("/kaggle/input/stanford-rna-3d-folding/"))
    else:
        print("Stanford RNA 3D folding dataset not available!")

MSA directory exists: True
Number of MSA files found: 2534
First 10 MSA files: ['3JCS_6.MSA.fasta', '7MSF_R.MSA.fasta', '2OOM_B.MSA.fasta', '1ZDI_S.MSA.fasta', '5FJ1_H.MSA.fasta', '6UF1_C.MSA.fasta', '4V8A_AB.MSA.fasta', '6JDG_G.MSA.fasta', '5NCO_1.MSA.fasta', '5DI4_A.MSA.fasta']
First 10 target IDs: ['9L5R_2', '9GFT_AU', '9L0R_K', '9GFT_A3', '9B2K_B', '9B0S_Et', '9J3T_B', '9LCR_B', '8KEB_A', '9L5S_5', '8VXZ_C', '9J6Y_E', '8QHU_5', '9GHF_Z', '9KPO_B', '9N2B_5', '9N2C_Pt', '9B1Y_4', '9G06_a', '9DE8_A', '9B83_C', '8ZMH_A', '9E2Y_F', '9DE7_A', '8Y9L_B', '9FIB_Y', '9J3R_B', '9DPB_C', '8XTP_A', '8ZTV_Y', '8Y9M_B', '8ZQ9_A', '8XTP_B', '9B89_C', '8SYK_C', '9FN3_B', '8QHU_3', '9DRS_C', '8XTR_A', '9LMF_F', '9DE6_B', '8SYK_B', '9DE6_A', '8R7N_A', '8K85_A', '9FCV_B', '9DPA_C', '9DE5_C', '8VZ6_S', '8YIG_C', '9B84_F', '9C8K_2', '9B0Q_AP', '9E2Z_F', '8Z8Q_B', '9E2W_F', '8KHH_A', '8Z8U_B', '8ZTU_Y', '9GCL_A', '8RRI_Ax', '9L5S_6', '9GCM_A', '8Z9K_B', '9MTY_C', '8QHU_7', '9GBW_R', '8T5O_A', '9DPL_C', '

In [11]:
# Comprehensive MSA file matching diagnostic
import os
import pandas as pd

# Load your test data
test_data = pd.read_csv("/kaggle/input/validation-sequences-clean-csv/validation_sequences_clean.csv")
msa_directory = "/kaggle/input/stanford-rna-3d-folding/MSA_v2"

print("=== MSA DIAGNOSTIC REPORT ===\n")

# 1. Basic counts
print(f"Total test sequences: {len(test_data)}")
print(f"MSA directory exists: {os.path.exists(msa_directory)}")

if os.path.exists(msa_directory):
    msa_files = os.listdir(msa_directory)
    print(f"Total MSA files available: {len(msa_files)}")
else:
    print("ERROR: MSA directory not found!")
    exit()

# 2. Get all target IDs from test data
target_ids = test_data['target_id'].tolist()
print(f"Unique target IDs in test data: {len(set(target_ids))}")

# 3. Show sample target IDs and expected MSA file names
print(f"\nFirst 10 target IDs: {target_ids[:10]}")
expected_files = [f"{tid}.MSA.fasta" for tid in target_ids[:10]]
print(f"Expected MSA files: {expected_files}")

# 4. Check exact matches
exact_matches = []
missing_files = []

for target_id in target_ids:
    expected_file = f"{target_id}.MSA.fasta"
    if expected_file in msa_files:
        exact_matches.append(target_id)
    else:
        missing_files.append(target_id)

print(f"\n=== EXACT MATCHING RESULTS ===")
print(f"Exact matches found: {len(exact_matches)}/{len(target_ids)}")
print(f"Missing MSA files: {len(missing_files)}")

# 5. Show some examples of what's missing
if missing_files:
    print(f"\nFirst 10 missing target IDs: {missing_files[:10]}")
    print("Expected files that are missing:")
    for tid in missing_files[:10]:
        print(f"  {tid}.MSA.fasta")

# 6. Look for partial matches or naming patterns
print(f"\n=== MSA FILE PATTERN ANALYSIS ===")
print("Sample MSA files found:")
for i, msa_file in enumerate(msa_files[:10]):
    print(f"  {msa_file}")

# 7. Try to find any MSA files that might match our target IDs with different patterns
potential_matches = {}
for target_id in missing_files[:20]:  # Check first 20 missing ones
    # Look for files that contain the target_id
    matches = [f for f in msa_files if target_id in f]
    if matches:
        potential_matches[target_id] = matches

if potential_matches:
    print(f"\n=== POTENTIAL ALTERNATIVE MATCHES ===")
    for target_id, matches in potential_matches.items():
        print(f"{target_id} might match: {matches}")

# 8. Check if there are MSA files for target IDs we DO have
print(f"\n=== VERIFICATION OF FOUND MATCHES ===")
if exact_matches:
    print("Confirmed MSA files exist for:")
    for i, target_id in enumerate(exact_matches[:10]):
        msa_path = os.path.join(msa_directory, f"{target_id}.MSA.fasta")
        exists = os.path.exists(msa_path)
        print(f"  {target_id}.MSA.fasta - exists: {exists}")

# 9. Check for case sensitivity issues
print(f"\n=== CASE SENSITIVITY CHECK ===")
msa_files_lower = [f.lower() for f in msa_files]
case_matches = 0
for target_id in missing_files[:10]:
    expected_lower = f"{target_id}.msa.fasta".lower()
    if expected_lower in msa_files_lower:
        case_matches += 1
        print(f"Case mismatch found for: {target_id}")

print(f"Potential case sensitivity issues: {case_matches}")

# 10. Summary
print(f"\n=== SUMMARY ===")
print(f"Test sequences: {len(test_data)}")
print(f"Available MSA files: {len(msa_files)}")
print(f"Exact matches: {len(exact_matches)}")
print(f"Missing: {len(missing_files)}")
print(f"Match rate: {len(exact_matches)/len(target_ids)*100:.1f}%")

# 11. Double-check the dataset logic
print(f"\n=== DATASET LOGIC VERIFICATION ===")
print("Testing dataset logic with a few examples...")

class TestDataset:
    def __init__(self, data, msa_dir):
        self.data = data
        self.msa_dir = msa_dir
    
    def check_msa(self, idx):
        target_id = self.data.loc[idx, 'target_id']
        msa_path = os.path.join(self.msa_dir, f"{target_id}.MSA.fasta")
        exists = os.path.exists(msa_path)
        return target_id, msa_path, exists

test_dataset = TestDataset(test_data, msa_directory)

for i in range(min(5, len(test_data))):
    target_id, msa_path, exists = test_dataset.check_msa(i)
    print(f"Index {i}: {target_id} -> {exists}")

=== MSA DIAGNOSTIC REPORT ===

Total test sequences: 94
MSA directory exists: True
Total MSA files available: 2534
Unique target IDs in test data: 94

First 10 target IDs: ['9L5R_2', '9GFT_AU', '9L0R_K', '9GFT_A3', '9B2K_B', '9B0S_Et', '9J3T_B', '9LCR_B', '8KEB_A', '9L5S_5']
Expected MSA files: ['9L5R_2.MSA.fasta', '9GFT_AU.MSA.fasta', '9L0R_K.MSA.fasta', '9GFT_A3.MSA.fasta', '9B2K_B.MSA.fasta', '9B0S_Et.MSA.fasta', '9J3T_B.MSA.fasta', '9LCR_B.MSA.fasta', '8KEB_A.MSA.fasta', '9L5S_5.MSA.fasta']

=== EXACT MATCHING RESULTS ===
Exact matches found: 60/94
Missing MSA files: 34

First 10 missing target IDs: ['8ZMH_A', '9E2Y_F', '9J3R_B', '9DPB_C', '8XTP_A', '8ZTV_Y', '8ZQ9_A', '8XTP_B', '8QHU_3', '9DRS_C']
Expected files that are missing:
  8ZMH_A.MSA.fasta
  9E2Y_F.MSA.fasta
  9J3R_B.MSA.fasta
  9DPB_C.MSA.fasta
  8XTP_A.MSA.fasta
  8ZTV_Y.MSA.fasta
  8ZQ9_A.MSA.fasta
  8XTP_B.MSA.fasta
  8QHU_3.MSA.fasta
  9DRS_C.MSA.fasta

=== MSA FILE PATTERN ANALYSIS ===
Sample MSA files found:
  3JCS

In [40]:
from tqdm import tqdm
model.eval()
preds=[]
msa_usage_count = 0

for i in tqdm(range(len(test_dataset))):
    sample = test_dataset[i]
    src = sample['sequence'].long()
    src = src.unsqueeze(0).cuda()
    msa_features = sample['msa_features']
    target_id = test_data.loc[i,'target_id']
    
    # Track MSA usage
    if msa_features['has_msa']:
        msa_usage_count += 1
        if i < 10:  # Print info for first 10 samples
            print(f"Using MSA for {target_id}: depth={msa_features['msa_depth']}")
    
    with torch.no_grad():
        xyz, distogram = model.sample(src, 5, msa_features)

    preds.append(xyz.cpu().numpy())

print(f"\nMSA Usage Summary: {msa_usage_count}/{len(test_dataset)} sequences had MSA data")

  0%|          | 0/94 [00:00<?, ?it/s]

Using MSA for 9L5R_2: depth=100
sequence_features shape: torch.Size([1, 193, 384])
conservation_embed shape: torch.Size([1, 193, 64])
coevolution_embed shape: torch.Size([1, 193, 64])
Creating feature_fusion layer: 512 -> 384
enhanced_features shape: torch.Size([1, 193, 512])


  1%|          | 1/94 [00:37<58:24, 37.69s/it]

Using MSA for 9GFT_AU: depth=100
sequence_features shape: torch.Size([1, 76, 384])
conservation_embed shape: torch.Size([1, 76, 64])
coevolution_embed shape: torch.Size([1, 76, 64])
enhanced_features shape: torch.Size([1, 76, 512])


  2%|▏         | 2/94 [00:48<33:12, 21.66s/it]

Using MSA for 9L0R_K: depth=100
sequence_features shape: torch.Size([1, 700, 384])
conservation_embed shape: torch.Size([1, 700, 64])
coevolution_embed shape: torch.Size([1, 700, 64])
enhanced_features shape: torch.Size([1, 700, 512])


  3%|▎         | 3/94 [06:34<4:17:51, 170.01s/it]

Using MSA for 9GFT_A3: depth=100
sequence_features shape: torch.Size([1, 77, 384])
conservation_embed shape: torch.Size([1, 77, 64])
coevolution_embed shape: torch.Size([1, 77, 64])
enhanced_features shape: torch.Size([1, 77, 512])


  4%|▍         | 4/94 [06:45<2:40:46, 107.18s/it]

Using MSA for 9B2K_B: depth=100
sequence_features shape: torch.Size([1, 70, 384])
conservation_embed shape: torch.Size([1, 70, 64])
coevolution_embed shape: torch.Size([1, 70, 64])
enhanced_features shape: torch.Size([1, 70, 512])


  5%|▌         | 5/94 [06:55<1:46:54, 72.07s/it] 

Using MSA for 9B0S_Et: depth=100
sequence_features shape: torch.Size([1, 75, 384])
conservation_embed shape: torch.Size([1, 75, 64])
coevolution_embed shape: torch.Size([1, 75, 64])
enhanced_features shape: torch.Size([1, 75, 512])


  6%|▋         | 6/94 [07:05<1:14:57, 51.10s/it]

Using MSA for 9J3T_B: depth=100
sequence_features shape: torch.Size([1, 580, 384])
conservation_embed shape: torch.Size([1, 580, 64])
coevolution_embed shape: torch.Size([1, 580, 64])
enhanced_features shape: torch.Size([1, 580, 512])


  7%|▋         | 7/94 [11:11<2:46:29, 114.82s/it]

Using MSA for 9LCR_B: depth=100
sequence_features shape: torch.Size([1, 578, 384])
conservation_embed shape: torch.Size([1, 578, 64])
coevolution_embed shape: torch.Size([1, 578, 64])
enhanced_features shape: torch.Size([1, 578, 512])


  9%|▊         | 8/94 [15:14<3:42:45, 155.42s/it]

Using MSA for 8KEB_A: depth=100
sequence_features shape: torch.Size([1, 72, 384])
conservation_embed shape: torch.Size([1, 72, 64])
coevolution_embed shape: torch.Size([1, 72, 64])
enhanced_features shape: torch.Size([1, 72, 512])


 10%|▉         | 9/94 [15:24<2:35:50, 110.00s/it]

Using MSA for 9L5S_5: depth=100
sequence_features shape: torch.Size([1, 116, 384])
conservation_embed shape: torch.Size([1, 116, 64])
coevolution_embed shape: torch.Size([1, 116, 64])
enhanced_features shape: torch.Size([1, 116, 512])


 11%|█         | 10/94 [15:43<1:54:50, 82.03s/it]

sequence_features shape: torch.Size([1, 36, 384])
conservation_embed shape: torch.Size([1, 36, 64])
coevolution_embed shape: torch.Size([1, 36, 64])
enhanced_features shape: torch.Size([1, 36, 512])


 12%|█▏        | 11/94 [15:49<1:21:21, 58.82s/it]

sequence_features shape: torch.Size([1, 550, 384])
conservation_embed shape: torch.Size([1, 550, 64])
coevolution_embed shape: torch.Size([1, 550, 64])
enhanced_features shape: torch.Size([1, 550, 512])


 13%|█▎        | 12/94 [19:13<2:20:24, 102.74s/it]

sequence_features shape: torch.Size([1, 135, 384])
conservation_embed shape: torch.Size([1, 135, 64])
coevolution_embed shape: torch.Size([1, 135, 64])
enhanced_features shape: torch.Size([1, 135, 512])


 14%|█▍        | 13/94 [19:34<1:45:38, 78.25s/it] 

sequence_features shape: torch.Size([1, 77, 384])
conservation_embed shape: torch.Size([1, 77, 64])
coevolution_embed shape: torch.Size([1, 77, 64])
enhanced_features shape: torch.Size([1, 77, 512])


 15%|█▍        | 14/94 [19:45<1:17:13, 57.92s/it]

sequence_features shape: torch.Size([1, 255, 384])
conservation_embed shape: torch.Size([1, 255, 64])
coevolution_embed shape: torch.Size([1, 255, 64])
enhanced_features shape: torch.Size([1, 255, 512])


 16%|█▌        | 15/94 [20:42<1:15:47, 57.56s/it]

sequence_features shape: torch.Size([1, 120, 384])
conservation_embed shape: torch.Size([1, 120, 64])
coevolution_embed shape: torch.Size([1, 120, 64])
enhanced_features shape: torch.Size([1, 120, 512])


 17%|█▋        | 16/94 [21:02<1:00:07, 46.26s/it]

sequence_features shape: torch.Size([1, 77, 384])
conservation_embed shape: torch.Size([1, 77, 64])
coevolution_embed shape: torch.Size([1, 77, 64])
enhanced_features shape: torch.Size([1, 77, 512])


 18%|█▊        | 17/94 [21:13<45:42, 35.62s/it]  

sequence_features shape: torch.Size([1, 81, 384])
conservation_embed shape: torch.Size([1, 81, 64])
coevolution_embed shape: torch.Size([1, 81, 64])
enhanced_features shape: torch.Size([1, 81, 512])


 19%|█▉        | 18/94 [21:24<35:52, 28.32s/it]

sequence_features shape: torch.Size([1, 73, 384])
conservation_embed shape: torch.Size([1, 73, 64])
coevolution_embed shape: torch.Size([1, 73, 64])
enhanced_features shape: torch.Size([1, 73, 512])


 20%|██        | 19/94 [21:35<28:35, 22.88s/it]

sequence_features shape: torch.Size([1, 57, 384])
conservation_embed shape: torch.Size([1, 57, 64])
coevolution_embed shape: torch.Size([1, 57, 64])
enhanced_features shape: torch.Size([1, 57, 512])


 21%|██▏       | 20/94 [21:43<22:44, 18.44s/it]

sequence_features shape: torch.Size([1, 39, 384])
conservation_embed shape: torch.Size([1, 39, 64])
coevolution_embed shape: torch.Size([1, 39, 64])
enhanced_features shape: torch.Size([1, 39, 512])


 24%|██▍       | 23/94 [22:19<16:46, 14.18s/it]

sequence_features shape: torch.Size([1, 58, 384])
conservation_embed shape: torch.Size([1, 58, 64])
coevolution_embed shape: torch.Size([1, 58, 64])
enhanced_features shape: torch.Size([1, 58, 512])


 26%|██▌       | 24/94 [22:28<14:34, 12.50s/it]

sequence_features shape: torch.Size([1, 62, 384])
conservation_embed shape: torch.Size([1, 62, 64])
coevolution_embed shape: torch.Size([1, 62, 64])
enhanced_features shape: torch.Size([1, 62, 512])


 27%|██▋       | 25/94 [22:36<12:52, 11.19s/it]

sequence_features shape: torch.Size([1, 15, 384])
conservation_embed shape: torch.Size([1, 15, 64])
coevolution_embed shape: torch.Size([1, 15, 64])
enhanced_features shape: torch.Size([1, 15, 512])


 32%|███▏      | 30/94 [26:16<30:41, 28.78s/it]  

sequence_features shape: torch.Size([1, 62, 384])
conservation_embed shape: torch.Size([1, 62, 64])
coevolution_embed shape: torch.Size([1, 62, 64])
enhanced_features shape: torch.Size([1, 62, 512])


 35%|███▌      | 33/94 [29:53<1:12:41, 71.50s/it]

sequence_features shape: torch.Size([1, 35, 384])
conservation_embed shape: torch.Size([1, 35, 64])
coevolution_embed shape: torch.Size([1, 35, 64])
enhanced_features shape: torch.Size([1, 35, 512])


 36%|███▌      | 34/94 [29:59<51:53, 51.89s/it]  

sequence_features shape: torch.Size([1, 107, 384])
conservation_embed shape: torch.Size([1, 107, 64])
coevolution_embed shape: torch.Size([1, 107, 64])
enhanced_features shape: torch.Size([1, 107, 512])


 37%|███▋      | 35/94 [30:16<40:36, 41.29s/it]

sequence_features shape: torch.Size([1, 58, 384])
conservation_embed shape: torch.Size([1, 58, 64])
coevolution_embed shape: torch.Size([1, 58, 64])
enhanced_features shape: torch.Size([1, 58, 512])


 40%|████      | 38/94 [31:08<23:46, 25.48s/it]

sequence_features shape: torch.Size([1, 145, 384])
conservation_embed shape: torch.Size([1, 145, 64])
coevolution_embed shape: torch.Size([1, 145, 64])
enhanced_features shape: torch.Size([1, 145, 512])


 41%|████▏     | 39/94 [31:33<23:10, 25.28s/it]

sequence_features shape: torch.Size([1, 700, 384])
conservation_embed shape: torch.Size([1, 700, 64])
coevolution_embed shape: torch.Size([1, 700, 64])
enhanced_features shape: torch.Size([1, 700, 512])


 43%|████▎     | 40/94 [37:22<1:50:07, 122.36s/it]

sequence_features shape: torch.Size([1, 57, 384])
conservation_embed shape: torch.Size([1, 57, 64])
coevolution_embed shape: torch.Size([1, 57, 64])
enhanced_features shape: torch.Size([1, 57, 512])


 44%|████▎     | 41/94 [37:30<1:17:50, 88.12s/it] 

sequence_features shape: torch.Size([1, 107, 384])
conservation_embed shape: torch.Size([1, 107, 64])
coevolution_embed shape: torch.Size([1, 107, 64])
enhanced_features shape: torch.Size([1, 107, 512])


 45%|████▍     | 42/94 [37:46<57:44, 66.63s/it]  

sequence_features shape: torch.Size([1, 57, 384])
conservation_embed shape: torch.Size([1, 57, 64])
coevolution_embed shape: torch.Size([1, 57, 64])
enhanced_features shape: torch.Size([1, 57, 512])


 46%|████▌     | 43/94 [37:54<41:43, 49.08s/it]

sequence_features shape: torch.Size([1, 135, 384])
conservation_embed shape: torch.Size([1, 135, 64])
coevolution_embed shape: torch.Size([1, 135, 64])
enhanced_features shape: torch.Size([1, 135, 512])


 48%|████▊     | 45/94 [38:25<25:24, 31.12s/it]

sequence_features shape: torch.Size([1, 81, 384])
conservation_embed shape: torch.Size([1, 81, 64])
coevolution_embed shape: torch.Size([1, 81, 64])
enhanced_features shape: torch.Size([1, 81, 512])


 49%|████▉     | 46/94 [38:36<20:08, 25.17s/it]

sequence_features shape: torch.Size([1, 76, 384])
conservation_embed shape: torch.Size([1, 76, 64])
coevolution_embed shape: torch.Size([1, 76, 64])
enhanced_features shape: torch.Size([1, 76, 512])


 50%|█████     | 47/94 [38:47<16:16, 20.78s/it]

sequence_features shape: torch.Size([1, 57, 384])
conservation_embed shape: torch.Size([1, 57, 64])
coevolution_embed shape: torch.Size([1, 57, 64])
enhanced_features shape: torch.Size([1, 57, 512])


 51%|█████     | 48/94 [38:55<13:01, 16.98s/it]

sequence_features shape: torch.Size([1, 50, 384])
conservation_embed shape: torch.Size([1, 50, 64])
coevolution_embed shape: torch.Size([1, 50, 64])
enhanced_features shape: torch.Size([1, 50, 512])


 52%|█████▏    | 49/94 [39:02<10:34, 14.10s/it]

sequence_features shape: torch.Size([1, 104, 384])
conservation_embed shape: torch.Size([1, 104, 64])
coevolution_embed shape: torch.Size([1, 104, 64])
enhanced_features shape: torch.Size([1, 104, 512])


 54%|█████▍    | 51/94 [39:27<09:13, 12.88s/it]

sequence_features shape: torch.Size([1, 27, 384])
conservation_embed shape: torch.Size([1, 27, 64])
coevolution_embed shape: torch.Size([1, 27, 64])
enhanced_features shape: torch.Size([1, 27, 512])


 55%|█████▌    | 52/94 [39:33<07:30, 10.74s/it]

sequence_features shape: torch.Size([1, 71, 384])
conservation_embed shape: torch.Size([1, 71, 64])
coevolution_embed shape: torch.Size([1, 71, 64])
enhanced_features shape: torch.Size([1, 71, 512])


 56%|█████▋    | 53/94 [39:43<07:12, 10.56s/it]

sequence_features shape: torch.Size([1, 40, 384])
conservation_embed shape: torch.Size([1, 40, 64])
coevolution_embed shape: torch.Size([1, 40, 64])
enhanced_features shape: torch.Size([1, 40, 512])


 60%|█████▉    | 56/94 [40:05<05:22,  8.49s/it]

sequence_features shape: torch.Size([1, 71, 384])
conservation_embed shape: torch.Size([1, 71, 64])
coevolution_embed shape: torch.Size([1, 71, 64])
enhanced_features shape: torch.Size([1, 71, 512])


 64%|██████▍   | 60/94 [40:51<06:44, 11.89s/it]

sequence_features shape: torch.Size([1, 70, 384])
conservation_embed shape: torch.Size([1, 70, 64])
coevolution_embed shape: torch.Size([1, 70, 64])
enhanced_features shape: torch.Size([1, 70, 512])


 65%|██████▍   | 61/94 [41:00<06:12, 11.30s/it]

sequence_features shape: torch.Size([1, 101, 384])
conservation_embed shape: torch.Size([1, 101, 64])
coevolution_embed shape: torch.Size([1, 101, 64])
enhanced_features shape: torch.Size([1, 101, 512])


 71%|███████▏  | 67/94 [42:31<07:15, 16.12s/it]

sequence_features shape: torch.Size([1, 124, 384])
conservation_embed shape: torch.Size([1, 124, 64])
coevolution_embed shape: torch.Size([1, 124, 64])
enhanced_features shape: torch.Size([1, 124, 512])


 72%|███████▏  | 68/94 [42:52<07:36, 17.55s/it]

sequence_features shape: torch.Size([1, 76, 384])
conservation_embed shape: torch.Size([1, 76, 64])
coevolution_embed shape: torch.Size([1, 76, 64])
enhanced_features shape: torch.Size([1, 76, 512])


 76%|███████▌  | 71/94 [46:09<23:58, 62.55s/it]

sequence_features shape: torch.Size([1, 147, 384])
conservation_embed shape: torch.Size([1, 147, 64])
coevolution_embed shape: torch.Size([1, 147, 64])
enhanced_features shape: torch.Size([1, 147, 512])


 77%|███████▋  | 72/94 [46:32<18:38, 50.84s/it]

sequence_features shape: torch.Size([1, 57, 384])
conservation_embed shape: torch.Size([1, 57, 64])
coevolution_embed shape: torch.Size([1, 57, 64])
enhanced_features shape: torch.Size([1, 57, 512])


 78%|███████▊  | 73/94 [46:41<13:19, 38.07s/it]

sequence_features shape: torch.Size([1, 159, 384])
conservation_embed shape: torch.Size([1, 159, 64])
coevolution_embed shape: torch.Size([1, 159, 64])
enhanced_features shape: torch.Size([1, 159, 512])


 79%|███████▊  | 74/94 [47:07<11:29, 34.47s/it]

sequence_features shape: torch.Size([1, 66, 384])
conservation_embed shape: torch.Size([1, 66, 64])
coevolution_embed shape: torch.Size([1, 66, 64])
enhanced_features shape: torch.Size([1, 66, 512])


 80%|███████▉  | 75/94 [47:16<08:29, 26.83s/it]

sequence_features shape: torch.Size([1, 101, 384])
conservation_embed shape: torch.Size([1, 101, 64])
coevolution_embed shape: torch.Size([1, 101, 64])
enhanced_features shape: torch.Size([1, 101, 512])


 83%|████████▎ | 78/94 [50:48<14:33, 54.59s/it]

sequence_features shape: torch.Size([1, 96, 384])
conservation_embed shape: torch.Size([1, 96, 64])
coevolution_embed shape: torch.Size([1, 96, 64])
enhanced_features shape: torch.Size([1, 96, 512])


 85%|████████▌ | 80/94 [54:19<20:44, 88.90s/it]

sequence_features shape: torch.Size([1, 118, 384])
conservation_embed shape: torch.Size([1, 118, 64])
coevolution_embed shape: torch.Size([1, 118, 64])
enhanced_features shape: torch.Size([1, 118, 512])


 87%|████████▋ | 82/94 [54:47<10:03, 50.29s/it]

sequence_features shape: torch.Size([1, 76, 384])
conservation_embed shape: torch.Size([1, 76, 64])
coevolution_embed shape: torch.Size([1, 76, 64])
enhanced_features shape: torch.Size([1, 76, 512])


 91%|█████████▏| 86/94 [55:45<03:14, 24.32s/it]

sequence_features shape: torch.Size([1, 62, 384])
conservation_embed shape: torch.Size([1, 62, 64])
coevolution_embed shape: torch.Size([1, 62, 64])
enhanced_features shape: torch.Size([1, 62, 512])


 93%|█████████▎| 87/94 [55:53<02:16, 19.46s/it]

sequence_features shape: torch.Size([1, 184, 384])
conservation_embed shape: torch.Size([1, 184, 64])
coevolution_embed shape: torch.Size([1, 184, 64])
enhanced_features shape: torch.Size([1, 184, 512])


 94%|█████████▎| 88/94 [56:27<02:23, 23.99s/it]

sequence_features shape: torch.Size([1, 121, 384])
conservation_embed shape: torch.Size([1, 121, 64])
coevolution_embed shape: torch.Size([1, 121, 64])
enhanced_features shape: torch.Size([1, 121, 512])


 96%|█████████▌| 90/94 [57:01<01:20, 20.13s/it]

sequence_features shape: torch.Size([1, 90, 384])
conservation_embed shape: torch.Size([1, 90, 64])
coevolution_embed shape: torch.Size([1, 90, 64])
enhanced_features shape: torch.Size([1, 90, 512])


 97%|█████████▋| 91/94 [57:14<00:53, 17.94s/it]

sequence_features shape: torch.Size([1, 101, 384])
conservation_embed shape: torch.Size([1, 101, 64])
coevolution_embed shape: torch.Size([1, 101, 64])
enhanced_features shape: torch.Size([1, 101, 512])


100%|██████████| 94/94 [57:57<00:00, 36.99s/it]


MSA Usage Summary: 60/94 sequences had MSA data





In [42]:
ID=[]
resname=[]
resid=[]
x=[]
y=[]
z=[]

data=[]

for i in range(len(test_data)):
    #print(test_data.loc[i])

    
    for j in range(len(test_data.loc[i,'sequence'])):
        # ID.append(test_data.loc[i,'sequence_id']+f"_{j+1}")
        # resname.append(test_data.loc[i,'sequence'][j])
        # resid.append(j+1) # 1 indexed
        row=[test_data.loc[i,'target_id']+f"_{j+1}",
             test_data.loc[i,'sequence'][j],
             j+1]

        for k in range(5):
            for kk in range(3):
                row.append(preds[i][k][j][kk])
        data.append(row)

columns=['ID','resname','resid']
for i in range(1,6):
    columns+=[f"x_{i}"]
    columns+=[f"y_{i}"]
    columns+=[f"z_{i}"]


submission=pd.DataFrame(data,columns=columns)


submission
submission.to_csv('submission.csv',index=False)

In [44]:
import shutil
import os

# Copy USalign to working directory and make it executable
shutil.copy2("/kaggle/input/usalign/USalign", "/kaggle/working/USalign")
os.chmod("/kaggle/working/USalign", 0o755)

print("USalign copied to /kaggle/working/ and made executable")

USalign copied to /kaggle/working/ and made executable


In [None]:
# score val

import os
import re
import numpy as np
import pandas as pd

def parse_tmscore_output(output):
    tm_score_match = re.findall(r'TM-score=\s+([\d.]+)', output)[1]
    return float(tm_score_match)

def write_target_line(
    atom_name, atom_serial, residue_name, chain_id, residue_num,
    x_coord, y_coord, z_coord, occupancy=1.0, b_factor=0.0, atom_type='P'
) -> str:
    return (
        f'ATOM  {atom_serial:>5d}  {atom_name:<5s} {residue_name:<3s} '
        f'{residue_num:>3d}    {x_coord:>8.3f}{y_coord:>8.3f}'
        f'{z_coord:>8.3f}{occupancy:>6.2f}{b_factor:>6.2f}           {atom_type}\n'
    )

def write2pdb(df: pd.DataFrame, xyz_id: int, target_path: str) -> int:
    resolved_cnt = 0
    with open(target_path, 'w') as f:
        for _, row in df.iterrows():
            x = row[f'x_{xyz_id}']; y = row[f'y_{xyz_id}']; z = row[f'z_{xyz_id}']
            if x > -1e17 and y > -1e17 and z > -1e17:
                resolved_cnt += 1
                f.write(write_target_line(
                    atom_name="C1'", atom_serial=int(row['resid']),
                    residue_name=row['resname'], chain_id='0',
                    residue_num=int(row['resid']),
                    x_coord=x, y_coord=y, z_coord=z, atom_type='C'
                ))
    return resolved_cnt

def get_base_target_id(long_id):
    return "_".join(str(long_id).split("_")[:-1])

def score_and_report(solution: pd.DataFrame, submission: pd.DataFrame):
    solution['target_id'] = solution['ID'].apply(get_base_target_id)
    submission['target_id'] = submission['ID'].apply(get_base_target_id)

    native_idxs = sorted(int(c.split('_')[1])
                         for c in solution.columns if c.startswith('x_'))

    usalign = "/kaggle/working/USalign"
    per_target = {}
    
    # Find common targets to iterate over
    common_targets = sorted(list(set(solution['target_id'].unique()) & set(submission['target_id'].unique())))
    
    print(f"Scoring {len(common_targets)} common targets...")

    for tid in common_targets:
        grp_nat = solution[solution['target_id'] == tid]
        grp_pred = submission[submission['target_id'] == tid]
        best_of_five = []

        for pred_cnt in range(1, 6):
            best_for_this_pred = 0.0
            for nat_cnt in native_idxs:
                n_nat  = write2pdb(grp_nat,   nat_cnt,   'native.pdb')
                n_pred = write2pdb(grp_pred,  pred_cnt, 'predicted.pdb')
                if n_nat > 0 and n_pred > 0:
                    out = os.popen(
                        f'{usalign} predicted.pdb native.pdb -atom " C1\'"'
                    ).read()
                    best_for_this_pred = max(best_for_this_pred,
                                             parse_tmscore_output(out))
            best_of_five.append(best_for_this_pred)

        per_target[tid] = best_of_five
        print(f"{tid}: TM-scores per model = {best_of_five}, "
              f"best = {max(best_of_five):.4f}")
    
    # Calculate mean TM score
    all_best_scores = [max(scores) for scores in per_target.values()]
    mean_tm = np.mean(all_best_scores) if all_best_scores else 0.0
    
    return per_target, mean_tm

solution   = pd.read_csv(
    "/kaggle/input/validation-labels-clean-csv/validation_labels_clean.csv"
)

per_target_scores, mean_tm = score_and_report(solution, submission)
print(f"\nMean TM-score: {mean_tm:.4f}")

Scoring 94 common targets...
8K85_A: TM-scores per model = [0.11851, 0.26332, 0.27056, 0.27635, 0.25619], best = 0.2763
8KEB_A: TM-scores per model = [0.1412, 0.15859, 0.15244, 0.12496, 0.11153], best = 0.1586
8KHH_A: TM-scores per model = [0.11734, 0.10212, 0.14044, 0.08531, 0.12467], best = 0.1404
8QHU_3: TM-scores per model = [0.22859, 0.80691, 0.23213, 0.80075, 0.77708], best = 0.8069
8QHU_4: TM-scores per model = [0.13733, 0.15631, 0.11052, 0.15348, 0.10492], best = 0.1563
8QHU_5: TM-scores per model = [0.13008, 0.15257, 0.10187, 0.12257, 0.14358], best = 0.1526
8QHU_7: TM-scores per model = [0.89972, 0.23369, 0.93436, 0.93088, 0.23213], best = 0.9344
8QHU_S4: TM-scores per model = [0.10372, 0.13692, 0.10917, 0.14957, 0.11647], best = 0.1496
8R7N_A: TM-scores per model = [0.10665, 0.10993, 0.12394, 0.07499, 0.13397], best = 0.1340
8RRI_Ax: TM-scores per model = [0.14859, 0.13666, 0.1161, 0.09984, 0.14559], best = 0.1486
8RWG_C: TM-scores per model = [0.12915, 0.16605, 0.14947, 0.1