In [None]:
import transformers
import torch
import multimolecule  # you must import multimolecule to register models
from transformers import pipeline
from transformers import AutoModelForMaskedLM, AutoTokenizer

import gc

In [None]:
# Fall back to CPU so this cell also runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATASET_PATH = "P_distance_dataset_path.pkl"

# NOTE(review): trust_remote_code=True runs Python code shipped with the model
# repository; only enable it for sources you trust.
tokenizer = AutoTokenizer.from_pretrained("buetnlpbio/birna-tokenizer")
config = transformers.BertConfig.from_pretrained("buetnlpbio/birna-bert")
model_llm = AutoModelForMaskedLM.from_pretrained(
    "buetnlpbio/birna-bert", config=config, trust_remote_code=True
)
# Replace the masked-LM head with a pass-through; the rest of the notebook then
# reads per-token embeddings from `.logits` (presumably the encoder hidden
# states -- depends on the remote model code, confirm).
model_llm.cls = torch.nn.Identity()
model_llm.to(DEVICE)

# Smoke test: embed one space-separated sequence.
# Expected tokens: CLS + 12 nucleotide tokens + SEP.
text = ["A G C U A C G U A C G U"]
encoded = tokenizer(text, return_tensors="pt").to(DEVICE)  # avoid shadowing builtin `input`

# No gradients are needed for this demonstration forward pass.
with torch.no_grad():
    output = model_llm(**encoded)
output.logits


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from typing import List, Tuple

class RNADataset(Dataset):
    """Dataset of (sequence, distance map) pairs.

    Each incoming sequence is stored with a single space between consecutive
    characters (``"ACGU"`` becomes ``"A C G U"``); distance maps are stored
    as given.

    Args:
        sequences: raw sequences, one string per sample.
        distance_maps: one square distance map per sequence, in the same order.
    """

    def __init__(self, sequences: List[str], distance_maps: List[List[List[float]]]):
        self.distance_maps = distance_maps
        # Joining a string with " " puts one space between each character.
        self.sequences = list(map(" ".join, sequences))

    def __len__(self):
        """Number of samples (taken from the sequence list)."""
        return len(self.sequences)

    def __getitem__(self, idx: int):
        """Return ``(spaced_sequence, distance_map)`` for sample ``idx``."""
        return self.sequences[idx], self.distance_maps[idx]


def collate_fn(batch: List[Tuple[str, List[List[float]]]]):
    """Collate (sequence, distance map) samples into padded batch tensors.

    Args:
        batch: list of ``(spaced_sequence, distance_map)`` pairs as produced by
            ``RNADataset.__getitem__``.

    Returns:
        A tuple ``(input_ids, padded_distance_maps, lengths)`` where

        * ``input_ids`` is the tokenizer output, shape ``(batch, max_len)``;
        * ``padded_distance_maps`` has shape ``(batch, max_len, max_len)`` with
          every map written at offset 1 on both axes (index 0 is left for the
          leading special token) and zeros elsewhere;
        * ``lengths`` is a list with the number of rows of each distance map,
          i.e. the number of nucleotides of each sample.

    Note:
        Only ``input_ids`` is returned; the tokenizer's attention mask is
        dropped. That is harmless for ``batch_size=1`` (no padding) but padded
        tokens would be attended to with larger batches.
    """
    sequences = [sequence for sequence, _ in batch]
    distance_maps_list: List[List[List[float]]] = [dmap for _, dmap in batch]

    # The sequences are already space-separated ("A C G"), so len(sequence)
    # would count the separators too (2n - 1) and make every downstream mask
    # spill over the trailing special/padding positions. The distance-map size
    # is the true nucleotide count.
    lengths = [len(dmap) for dmap in distance_maps_list]

    # token length = seq length + 2 (CLS and SEP tokens)
    input_ids = tokenizer(sequences, return_tensors="pt", padding=True).input_ids
    max_len = input_ids.shape[1]

    padded_distance_maps = torch.zeros(len(batch), max_len, max_len)
    for row, dmap in enumerate(distance_maps_list):
        dmap_tensor = torch.as_tensor(dmap)  # shape: (seq, seq)
        n_rows, n_cols = dmap_tensor.shape
        # Offset by one so that position 0 stays reserved for the CLS token.
        padded_distance_maps[row, 1 : n_rows + 1, 1 : n_cols + 1] = dmap_tensor

    return input_ids, padded_distance_maps, lengths


In [None]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only point DATASET_PATH at files from a trusted source.
# The context manager closes the file handle once loading is done.
with open(DATASET_PATH, "rb") as dataset_file:
    dataset = pickle.load(dataset_file)
dataset

In [None]:

def bad_sequence(seq):
    """Tell whether a sequence is dominated by a single nucleotide.

    Args:
        seq: the sequence to inspect.

    Returns:
        True when any one of the letters A, C, G, U or T occurs in strictly
        more than 90 percent of the positions of ``seq``; False otherwise
        (including for an empty sequence).
    """
    cutoff = 0.9 * len(seq)
    return any(seq.count(base) > cutoff for base in ("A", "C", "G", "U", "T"))


import pickle
import pandas as pd
import numpy as np

# Pull the raw columns out of the loaded dataset.
sequences = dataset["sequence"].tolist()
distance_maps = dataset["distance_matrix"].tolist()

# Per-sequence lengths (computed for inspection; not used by the filter below).
lengths = [len(seq) for seq in sequences]

print("Number of sequences:", len(sequences))

# Keep a sample only if
#   * its distance map has one row per sequence character,
#   * it is not dominated by a single nucleotide (see bad_sequence), and
#   * the same sequence has not been kept already.
seen = set()
new_sequences = []
new_distance_maps = []

for sequence, distance_map in zip(sequences, distance_maps):
    if len(sequence) != len(distance_map):
        print("Length mismatch")
        continue
    if bad_sequence(sequence) or sequence in seen:
        continue

    seen.add(sequence)
    new_sequences.append(sequence)
    new_distance_maps.append(distance_map)

print("Number of unique sequences:", len(new_sequences))    
sequences = new_sequences
distance_maps = new_distance_maps


def min_max_denormalize(normalized_data, min_val, max_val):
    """Undo a min-max normalisation.

    Args:
        normalized_data: value(s) to map back to the original scale.
        min_val: minimum of the original data.
        max_val: maximum of the original data.

    Returns:
        ``normalized_data * (max_val - min_val) + min_val``.
    """
    value_range = max_val - min_val
    return normalized_data * value_range + min_val

def flatten(data):
    """Return ``data.flatten()``, i.e. the one-dimensional view of the values."""
    flat = data.flatten()
    return flat

def unflatten(data):
    """Reshape a flat array of ``k * k`` values into a ``(k, k)`` matrix.

    The side length is the truncated square root of ``len(data)``; the
    reshape fails if ``len(data)`` is not a perfect square.
    """
    side = int(np.sqrt(len(data)))
    return data.reshape(side, side)

# Placeholder for per-map (min, max) pairs; stays empty while min-max
# normalisation is disabled.
min_max_stores = []

# Normalisation is currently switched off, so each map is only converted to a
# NumPy array and round-tripped through flatten/unflatten (which also requires
# it to be square).
new_distance_maps = [unflatten(flatten(np.array(dmap))) for dmap in distance_maps]
distance_maps = new_distance_maps

# 80/20 train/validation split with a fixed seed for reproducibility.
train_sequences, val_sequences, train_maps, val_maps = train_test_split(
    sequences, distance_maps, test_size=0.20, random_state=42
)
print("Number of training sequences:", len(train_sequences))

train_dataset = RNADataset(train_sequences, train_maps)
val_dataset = RNADataset(val_sequences, val_maps)

# batch_size must stay 1: create_proximity_weights asserts a leading
# dimension of exactly one.
loader_kwargs = dict(batch_size=1, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Display the number of training and validation sequences produced by the split.
len(train_sequences), len(val_sequences)

In [None]:
# Training hyper-parameters.
EPOCHS = 64
LEARNING_RATE = 1e-5

# Use the GPU when one is visible to PyTorch, otherwise the CPU. Both spellings
# are used further down the notebook, so keep them in sync.
DEVICE = device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def create_difference_matrix(rows: int, cols: int) -> torch.Tensor:
    """Build the matrix of absolute index differences ``|i - j|``.

    Args:
        rows: number of rows of the result.
        cols: number of columns of the result.

    Returns:
        A float tensor of shape ``(rows, cols)`` whose entry ``[i, j]`` equals
        ``abs(i - j)``.
    """
    row_idx = torch.arange(rows)[:, None]  # (rows, 1)
    col_idx = torch.arange(cols)[None, :]  # (1, cols)
    # Broadcasting the subtraction yields the full (rows, cols) grid.
    return (row_idx - col_idx).abs().float()

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with a skip connection.

    The first convolution is dilated (its padding equals the dilation so the
    spatial size is preserved); the second is a plain 3x3 convolution. When the
    channel count changes, the skip path is a 1x1 convolution, otherwise it is
    the identity.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        dilation: dilation (and padding) of the first convolution.
    """

    def __init__(self, in_channels, out_channels, dilation=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            # Project the input so it can be added to the conv branch.
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``relu(conv_branch(x) + shortcut(x))``."""
        residual = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual  # Residual connection
        return self.relu(out)


In [None]:
class SqueezeExcitationBlock(nn.Module):
    """Squeeze-and-excitation style channel gating.

    The input is globally average-pooled to one value per channel, passed
    through a two-layer 1x1-conv bottleneck (ReLU in between, sigmoid at the
    end) and the resulting per-channel gate multiplies the input.

    Args:
        channels: number of channels of the input feature map.
        reduction_ratio: factor by which the bottleneck shrinks the channel
            count. The bottleneck width is clamped to at least one channel so
            that ``channels < reduction_ratio`` does not produce an empty
            layer.
    """

    def __init__(self, channels, reduction_ratio=16):
        super(SqueezeExcitationBlock, self).__init__()
        # Integer division yields 0 for small channel counts; keep >= 1.
        bottleneck_channels = max(1, channels // reduction_ratio)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)  # Global average pooling
        self.fc1 = nn.Conv2d(channels, bottleneck_channels, kernel_size=1)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv2d(bottleneck_channels, channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate."""
        scale = self.global_avg_pool(x)  # (N, C, 1, 1)
        scale = self.fc1(scale)
        scale = self.relu(scale)
        scale = self.fc2(scale)
        scale = self.sigmoid(scale)
        return x * scale  # Recalibrate features (gate broadcasts over H, W)


In [None]:
class DistanceMapPredictor(nn.Module):
    """Predicts a symmetric pairwise distance map from token embeddings.

    Token embeddings from the wrapped language model are concatenated for
    every (i, j) token pair, reduced by a 1x1 convolution and passed through a
    convolutional head producing one value per pair. Only the strict upper
    triangle of the valid region is kept and mirrored, so the result is
    symmetric with a zero diagonal.

    Note:
        ``self.bert`` is the module-level ``model_llm`` object, not a copy.
        Every instance of this class therefore shares (and trains / reloads)
        the same language-model weights.
    """

    def __init__(self):
        super(DistanceMapPredictor, self).__init__()
        self.bert = model_llm  # Use the globally defined BERT model
        self.hidden_size = self.bert.config.hidden_size  # Dynamically fetch hidden size

        # 1x1 bottleneck over the concatenated pair embedding (2 * hidden_size).
        self.projection = nn.Conv2d(2 * self.hidden_size, 512, kernel_size=1)

        # Convolutional head: residual blocks, channel attention, then plain convs.
        self.conv_layers = nn.Sequential(
            ResidualBlock(512, 512, dilation=1),
            ResidualBlock(512, 256, dilation=2),  # Multi-scale context with dilation
            SqueezeExcitationBlock(256),         # Channel attention
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
        )

        # Kaiming-initialise the Conv2d layers that are direct children of the
        # head (convolutions nested inside the residual / SE blocks keep their
        # default initialisation).
        for layer in self.conv_layers:
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, input_ids: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Compute the masked, symmetric distance map.

        Args:
            input_ids: token ids, shape ``(batch, max_len)``.
            lengths: per-sample count of real tokens (list of ints or 1-D
                tensor). Tokens ``1 .. length`` are treated as valid; index 0
                and everything after ``length`` are masked out.

        Returns:
            Tensor of shape ``(batch, max_len, max_len)``; zero outside the
            valid region and on the diagonal, symmetric elsewhere.
        """
        outputs = self.bert(input_ids)
        embeddings = outputs.logits  # Expected shape: (batch, max_len, hidden_size)
        # Build every helper tensor on the device the data actually lives on
        # instead of relying on a notebook-level global.
        device = embeddings.device
        max_len = embeddings.size(1)

        # Pairwise concatenation: entry (i, j) holds [emb_i ; emb_j].
        concat_embeddings = torch.cat(
            [
                embeddings.unsqueeze(2).expand(-1, -1, max_len, -1),
                embeddings.unsqueeze(1).expand(-1, max_len, -1, -1),
            ],
            dim=-1,
        )  # Shape: (batch, max_len, max_len, 2 * hidden_size)

        # Channels-first layout for Conv2d.
        concat_embeddings = concat_embeddings.permute(0, 3, 1, 2)  # (batch, 2 * hidden_size, max_len, max_len)

        # Reduce dimensionality with bottleneck
        concat_embeddings = self.projection(concat_embeddings)

        # Apply convolutional layers
        output_distance_map = self.conv_layers(concat_embeddings).squeeze(1)  # (batch, max_len, max_len)

        # Valid tokens are positions 1..length for each sample.
        positions = torch.arange(max_len, device=device)
        lengths_col = torch.as_tensor(lengths, device=device).reshape(-1, 1)
        token_valid = (positions >= 1) & (positions <= lengths_col)  # (batch, max_len)
        pair_valid = token_valid.unsqueeze(2) & token_valid.unsqueeze(1)  # (batch, max_len, max_len)

        # Strict upper triangle: column index greater than row index.
        strict_upper = positions.unsqueeze(0) > positions.unsqueeze(1)  # (max_len, max_len)
        valid_upper_tri_mask = (pair_valid & strict_upper).float()

        # Keep the upper triangle and mirror it to enforce symmetry.
        upper_triangle = output_distance_map * valid_upper_tri_mask
        symmetric_map = upper_triangle + upper_triangle.transpose(-1, -2)

        # Explicitly zero the diagonal (also clears any non-finite value that
        # survived the multiplication by zero).
        symmetric_map[:, positions, positions] = 0.0

        # Return symmetric map (loss will propagate naturally)
        return symmetric_map



In [None]:
from torch.cuda.amp import GradScaler, autocast

In [None]:
def compute_r2_score_with_mask(
    y_true: torch.Tensor, y_pred: torch.Tensor, lengths
):
    """Coefficient of determination (R²) over the valid region of a batch.

    For each sample only the square block ``[1 : length + 1, 1 : length + 1]``
    is scored; everything outside it (position 0 and positions past
    ``length``) is ignored. All valid entries of the batch are pooled before
    computing a single R².

    Args:
        y_true: ground-truth maps, shape ``(batch, max_len, max_len)``.
        y_pred: predicted maps, same shape and device as ``y_true``.
        lengths: per-sample valid length (list of ints or 1-D tensor).

    Returns:
        0-d tensor ``1 - SS_res / SS_tot``. The result is not finite when the
        selected ground-truth values are all identical (``SS_tot == 0``).
    """
    # Build the mask directly on the inputs' device so GPU tensors are not
    # indexed with a CPU-resident mask.
    device = y_pred.device
    max_len = y_pred.shape[1]

    positions = torch.arange(max_len, device=device)
    lengths_col = torch.as_tensor(lengths, device=device).reshape(-1, 1)
    token_valid = (positions >= 1) & (positions <= lengths_col)  # (batch, max_len)
    mask = token_valid.unsqueeze(2) & token_valid.unsqueeze(1)  # (batch, max_len, max_len)

    y_true = y_true[mask]
    y_pred = y_pred[mask]

    ss_res = torch.sum((y_true - y_pred) ** 2)
    ss_tot = torch.sum((y_true - y_true.mean()) ** 2)
    r2 = 1 - (ss_res / ss_tot)
    return r2

In [None]:
# take the maximum one 
def compute_pearson_correlation_with_mask(
    y_true: torch.Tensor, y_pred: torch.Tensor, lengths
):
    """
    Computes the Pearson correlation coefficient between y_true and y_pred
    using a mask generated from the lengths array.

    For each sample only the square block ``[1 : length + 1, 1 : length + 1]``
    contributes; the valid entries of all samples are pooled before the
    correlation is computed.

    Args:
        y_true: torch.Tensor, ground truth values (batch, max_len, max_len).
        y_pred: torch.Tensor, predicted values (batch, max_len, max_len),
            on the same device as y_true.
        lengths: list or array of integers specifying the valid lengths for each batch.

    Returns:
        torch.Tensor: Pearson correlation coefficient (0-d). Not finite when
        either masked input has zero variance.
    """
    # Build the mask on the inputs' device so GPU tensors are not indexed with
    # a CPU-resident mask.
    device = y_pred.device
    max_len = y_pred.shape[1]

    positions = torch.arange(max_len, device=device)
    lengths_col = torch.as_tensor(lengths, device=device).reshape(-1, 1)
    token_valid = (positions >= 1) & (positions <= lengths_col)  # (batch, max_len)
    mask = token_valid.unsqueeze(2) & token_valid.unsqueeze(1)  # (batch, max_len, max_len)

    # Apply the mask to y_true and y_pred
    y_true = y_true[mask]
    y_pred = y_pred[mask]

    # Centre both series, then covariance / (std * std).
    true_centered = y_true - y_true.mean()
    pred_centered = y_pred - y_pred.mean()

    numerator = torch.sum(true_centered * pred_centered)
    denominator = torch.sqrt(
        torch.sum(true_centered ** 2) * torch.sum(pred_centered ** 2)
    )
    pearson_correlation = numerator / denominator

    return pearson_correlation


In [None]:
from tqdm import tqdm
from torch.optim.adam import Adam
from torch.cuda.amp import GradScaler, autocast



In [None]:
def create_proximity_weights(target_matrix: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    """
    Build loss weights that favour short distances.

    Each entry gets the weight ``1 / (1 + alpha * distance)``; the diagonal is
    zeroed, the weights are divided by their maximum, and finally multiplied
    by ``max(distance) / len`` so the overall scale follows the size of the
    map. The input tensor is not modified.

    Args:
        target_matrix (torch.Tensor): The target distance matrix of shape (1, len, len).
        alpha (float): Scaling factor to control the decay rate.

    Returns:
        torch.Tensor: Scaled weight matrix of shape (1, len, len).

    Raises:
        AssertionError: if ``target_matrix`` is not 3-D with a leading
            dimension of exactly one.
    """
    assert target_matrix.dim() == 3 and target_matrix.size(0) == 1, \
        "Input target_matrix must have shape (1, len, len)"

    distances = target_matrix[0]  # (len, len)
    side = distances.size(0)

    # Inverse decay: the closer the pair, the larger the weight.
    weights = 1.0 / (1.0 + alpha * distances)

    # Self-pairs carry no information; give them zero weight.
    diagonal = torch.arange(side, device=distances.device)
    weights[diagonal, diagonal] = 0.0

    # Normalise so the largest off-diagonal weight is 1 ...
    weights = weights / weights.max()
    # ... then rescale by (largest distance / matrix size).
    weights = weights * (distances.max() / side)

    return weights.unsqueeze(0)


In [None]:
# Training Loop
model = DistanceMapPredictor().to(DEVICE)
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss(reduction="none")  # Per-element loss so proximity weights can be applied
scaler = GradScaler()

# R² has no lower bound, so start from -inf: the first epoch with a finite
# score is then always checkpointed and the later load cell has a file to read.
BEST_R2 = float("-inf")
best_model = None

for epoch in range(EPOCHS):
    running_loss = 0.0
    train_loop = tqdm(train_loader, leave=True)
    train_loop.set_description(f"Epoch [{epoch + 1}/{EPOCHS}] - Training")

    model.train()
    for input_ids, targets, lengths in train_loop:
        # Move data to DEVICE
        input_ids = input_ids.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        # Zero gradients
        optimizer.zero_grad()

        with autocast():
            # Forward pass
            outputs = model(input_ids=input_ids, lengths=lengths)

            # Weights derived from the target distances (requires batch size 1).
            proximity_weights = create_proximity_weights(targets).to(DEVICE)

            # Element-wise squared error, re-weighted, then averaged over the
            # whole padded map.
            losses = criterion(outputs, targets)
            weighted_loss = (losses * proximity_weights).mean()

        # Backward pass with gradient scaling
        scaler.scale(weighted_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Update running loss (single host sync per batch)
        loss_value = weighted_loss.item()
        running_loss += loss_value

        # Progress bar: loss, plus live GPU memory usage when running on CUDA
        # (the torch.cuda memory queries are not valid for a CPU device).
        progress = {"loss": loss_value}
        if str(DEVICE).startswith("cuda"):
            allocated_memory = torch.cuda.memory_allocated(DEVICE) / 1e6  # In MB
            reserved_memory = torch.cuda.memory_reserved(DEVICE) / 1e6    # In MB
            progress["gpu_allocated"] = f"{allocated_memory:.2f} MB"
            progress["gpu_reserved"] = f"{reserved_memory:.2f} MB"
        train_loop.set_postfix(**progress)

        # Drop references and release cached blocks after each batch
        del input_ids, targets, outputs, losses, weighted_loss, proximity_weights
        torch.cuda.empty_cache()

    print(f"Epoch [{epoch + 1}/{EPOCHS}], Training Loss: {running_loss / len(train_loader):.4f}")

    # Validation Loop
    model.eval()
    val_running_loss = 0.0
    r2_scores = []

    with torch.no_grad():
        for input_ids, targets, lengths in val_loader:
            input_ids = input_ids.to(DEVICE, non_blocking=True)
            targets = targets.to(DEVICE, non_blocking=True)

            with autocast():
                # Forward pass
                outputs = model(input_ids=input_ids, lengths=lengths)

                # Same weighted loss as in training
                proximity_weights = create_proximity_weights(targets).to(DEVICE)
                losses = criterion(outputs, targets)
                weighted_loss = (losses * proximity_weights).mean()

            val_running_loss += weighted_loss.item()

            # Compute R² score
            r2 = compute_r2_score_with_mask(targets.detach(), outputs.detach(), lengths)
            r2_scores.append(r2.item())

            # Drop references and release cached blocks after each batch
            del input_ids, targets, outputs, losses, weighted_loss, proximity_weights
            torch.cuda.empty_cache()

    # NOTE: a single non-finite per-sample R² makes this mean NaN, in which
    # case the comparison below is False and no checkpoint is written.
    average_r2 = np.mean(r2_scores)
    print(f"Epoch [{epoch + 1}/{EPOCHS}], Validation Loss: {val_running_loss / len(val_loader):.4f}, r^2: {average_r2:.4f}")

    # Save best model if R² improves
    if average_r2 > BEST_R2:
        BEST_R2 = average_r2
        torch.save(model.state_dict(), "single_birnabert_r2_64_proximity_mse.pth")

    # Cleanup after epoch to free unused memory
    torch.cuda.empty_cache()
    gc.collect()

# Final cleanup after all epochs are complete
torch.cuda.empty_cache()
gc.collect()

In [None]:
# Best epoch-averaged validation R² recorded by the training loop
# (or its initial value if no epoch improved on it).
BEST_R2

In [None]:
# Rebuild the predictor and restore the best checkpoint written during training.
# NOTE: DistanceMapPredictor wraps the shared global `model_llm`, so loading
# this state dict also overwrites the language-model weights used by `model`.
best_model = DistanceMapPredictor().to(DEVICE)
# map_location lets a checkpoint saved on GPU be restored on whatever device
# is in use now (including CPU-only machines).
best_model.load_state_dict(
    torch.load("single_birnabert_r2_64_proximity_mse.pth", map_location=DEVICE)
)

In [None]:
import random
import matplotlib.pyplot as plt

# Perform inference on the best model
best_model.eval()
test_loop = tqdm(val_loader, leave=True)
test_loop.set_description(f"Testing")

test_running_loss = 0.0
r2_scores = []
targets_list = []
outputs_list = []

with torch.no_grad():
    for input_ids, targets, lengths in test_loop:
        input_ids = input_ids.to(device)  # Shape: (batch, max_len)
        targets = targets.to(device)  # Shape: (batch, max_len, max_len)
        with autocast():
            outputs = best_model(
                input_ids=input_ids,
                lengths=lengths
            )  # Shape: (batch, max_len, max_len)

            # NOTE: this report uses sqrt(|i - j|) index-separation weights,
            # not the proximity weights used during training, so the number is
            # not directly comparable with the training/validation loss.
            abs_diff = create_difference_matrix(
                targets.size(1), targets.size(2)
            ).float()
            weights = abs_diff**0.5  # Shape: (max_len, max_len)

            losses = criterion(outputs, targets)  # Shape: (batch, max_len, max_len)
            weighted_loss = losses * weights.to(
                device
            )  # Shape: (batch, max_len, max_len)
            loss = weighted_loss.mean()  # Scalar

        test_running_loss += loss.item()

        # Calculate R2 score
        r2 = compute_r2_score_with_mask(targets, outputs, lengths)
        r2_scores.append(r2.item())

        # Store distance maps (outputs and targets) for later comparison and plotting
        targets_list.append(targets.cpu())
        outputs_list.append(outputs.cpu())

# Calculate and print average R2 score
average_r2 = np.mean(r2_scores)
print(
    f"Average Testing Loss: {test_running_loss / len(val_loader):.4f}, Average R² Score: {average_r2:.4f}"
)

# Randomly select up to 10 distance maps to plot. random.sample raises
# ValueError when asked for more items than exist, so cap the sample size.
random_indices = random.sample(range(len(targets_list)), min(10, len(targets_list)))

# Plotting function for targets vs outputs (predicted)
def plot_distance_maps(target, output, idx):
    """Show a target map and its prediction side by side.

    Args:
        target: 2-D array with the ground-truth distance map.
        output: 2-D array with the predicted distance map.
        idx: sample index, used only in the panel titles.
    """
    fig, (ax_target, ax_output) = plt.subplots(1, 2, figsize=(10, 5))

    ax_target.imshow(target, cmap='viridis')
    ax_target.set_title(f"Target Distance Map {idx}")

    ax_output.imshow(output, cmap='viridis')
    ax_output.set_title(f"Predicted Distance Map {idx}")

    plt.show()

# Draw the target/prediction pair for every sampled index.
for idx in random_indices:
    # Drop the singleton batch dimension and convert to NumPy for imshow.
    target_map, output_map = (
        maps[idx].squeeze().numpy() for maps in (targets_list, outputs_list)
    )
    plot_distance_maps(target_map, output_map, idx)
