In [None]:
import os
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', None)
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import datasets
from datasets import load_from_disk
import transformers
from transformers import AutoTokenizer, TrainingArguments, Trainer
import torch
from sklearn.model_selection import KFold
import wandb
from arnie.mfe import mfe
from arnie.bpps import bpps
from typing import Tuple, Sequence
import subprocess
import pickle
from modelling.modelling import BertForTokenRegression
from modelling.collator import DataCollatorForTokenRegression
from modelling.utils import load_dict_from_file
from transformers.models.bert.modeling_bert import BertConfig


In [None]:
assert torch.cuda.is_available()

In [None]:
INPUT_DIR = '../input'
OUTPUT_DIR = '../output'

In [None]:
dataset_data = load_from_disk(f"{INPUT_DIR}/stanford-ribonanza-rna-folding/train_data")

In [None]:
def reactivity_columns_to_array(row, filter_substr_and_output_keys: Sequence[Tuple[str, str]]):
    """Gather groups of per-position columns of ``row`` into float32 arrays.

    Args:
        row: Mapping of column name to value; must contain ``'sequence'``.
        filter_substr_and_output_keys: ``(substring, output_key)`` pairs. For
            each pair, every column whose name contains ``substring`` is
            collected in the row's own key order.

    Returns:
        Dict mapping each ``output_key`` to a ``np.float32`` array truncated to
        ``len(row['sequence'])`` entries.
    """
    arrays = {}
    for needle, target_key in filter_substr_and_output_keys:
        matching_values = [row[name] for name in row if needle in name]
        # Columns beyond the sequence length are dropped.
        trimmed = matching_values[:len(row['sequence'])]
        arrays[target_key] = np.array(trimmed, dtype=np.float32)
    return arrays

In [None]:
# Pack the wide per-position `reactivity_0*` / `reactivity_error_0*` columns into one
# array column each (`reactivity`, `reactivity_error`) and drop the original wide columns.
dataset_data = dataset_data.map(
    reactivity_columns_to_array,
    fn_kwargs={
        'filter_substr_and_output_keys': (('reactivity_0', 'reactivity'), ('reactivity_error_0', 'reactivity_error'))
    },
    num_proc=14,
    remove_columns=[k for k in dataset_data.features if any(x in k for x in ('reactivity_0', 'reactivity_error_0'))]
)

In [None]:
# Per-nucleotide signal-to-noise: reactivity divided element-wise by its reported error.
dataset_data = dataset_data.map(lambda row: {'stn_nts': np.array(row['reactivity'], dtype=np.float32) / np.array(row['reactivity_error'], dtype=np.float32)}, num_proc=14)

In [None]:
import math

def min_val_ignore_nan(list_of_lists):
    """Return the smallest non-NaN number found in a list of lists.

    Raises:
        ValueError: if no non-NaN value exists (``min`` of an empty list).
    """
    candidates = []
    for inner in list_of_lists:
        for number in inner:
            if math.isnan(number):
                continue  # NaN entries never take part in the comparison
            candidates.append(number)
    return min(candidates)

In [None]:
dataset_data_2A3_MaP = dataset_data.filter(lambda x: x['experiment_type'] == '2A3_MaP', num_proc=14)
dataset_data_DMS_MaP = dataset_data.filter(lambda x: x['experiment_type'] == 'DMS_MaP', num_proc=14)

In [None]:
assert dataset_data_2A3_MaP['sequence_id'] == dataset_data_DMS_MaP['sequence_id']

In [None]:
def add_suffix_to_columns(dataset, suffix, exceptions):
    """Rename every column of ``dataset`` to ``<column>_<suffix>``.

    Args:
        dataset: Object exposing ``column_names`` and ``rename_column(old, new)``
            that returns the renamed dataset.
        suffix: Text appended (after an underscore) to each renamed column.
        exceptions: ``set`` of column names that must keep their name.

    Returns:
        The dataset produced by the chain of ``rename_column`` calls.
    """
    assert isinstance(exceptions, set)
    to_rename = [name for name in dataset.column_names if name not in exceptions]
    renamed = dataset
    for name in to_rename:
        renamed = renamed.rename_column(name, f"{name}_{suffix}")
    return renamed

In [None]:
do_not_rename = set(['sequence_id', 'sequence'])
dataset_data_2A3_MaP = add_suffix_to_columns(dataset_data_2A3_MaP, "2A3_MaP", do_not_rename)
dataset_data_DMS_MaP = add_suffix_to_columns(dataset_data_DMS_MaP, "DMS_MaP", do_not_rename)

In [None]:
dataset_data = datasets.concatenate_datasets([dataset_data_2A3_MaP, dataset_data_DMS_MaP.remove_columns(['sequence_id', 'sequence'])], axis=1)

In [None]:
MODEL_PATH = "../input/huggingface/SpliceBERT/SpliceBERT.510nt/"
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_PATH)
MAX_LEN = 207

In [None]:
TOKENIZER.tokenize(' '.join(list('AGUG'.upper().replace("U", "T"))))

In [None]:
def preprocess(row, tokenizer, max_len, is_train=True):
    """Tokenize one RNA sequence and attach per-position regression labels.

    Args:
        row: Example mapping with at least ``'sequence'``. When it also holds
            ``'reactivity_2A3_MaP'`` it must hold ``'reactivity_DMS_MaP'`` too
            (and, for training, the matching ``'stn_nts_*'`` columns).
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=True, max_length=max_len, add_special_tokens=False)``
            returning a mapping with ``'input_ids'``.
        max_len: Maximum number of tokens / label positions (``None`` = no cut).
        is_train: If true, a label is kept only where its per-nucleotide
            signal-to-noise value is ``> -10000.0`` (NaN otherwise); if false,
            only ``None`` labels are turned into NaN.

    Returns:
        The tokenizer output extended with ``'labels'`` (list of
        ``[2A3, DMS]`` pairs clipped to [0, 1], or zeros when the row has no
        reactivities) and ``'len'`` (number of input ids).
    """
    # U -> T and one whitespace between nucleotides, as the tokenizer expects.
    spaced_sequence = ' '.join(row['sequence'].upper().replace("U", "T"))
    encoded = tokenizer(spaced_sequence, truncation=True, max_length=max_len, add_special_tokens=False)

    if 'reactivity_2A3_MaP' not in row:
        # No ground truth available: emit placeholder zero labels.
        encoded['labels'] = np.zeros((len(encoded['input_ids']), 2)).tolist()
    else:
        assert 'reactivity_DMS_MaP' in row
        per_experiment = []
        for experiment in ('2A3_MaP', 'DMS_MaP'):
            if is_train:
                masked = [
                    value if row[f'stn_nts_{experiment}'][position] > -10000.0 else np.nan
                    for position, value in enumerate(row[f'reactivity_{experiment}'])
                ]
            else:
                masked = [
                    value if value is not None else np.nan
                    for value in row[f'reactivity_{experiment}']
                ]
            per_experiment.append(masked[:max_len])
        # Shape (positions, 2); NaN survives the clip and marks "no label".
        clipped = np.clip(np.array(per_experiment).T, 0, 1)
        encoded['labels'] = clipped.tolist()

    encoded['len'] = len(encoded['input_ids'])
    return encoded

In [None]:
import torch
import torch.utils.checkpoint
from torch import nn

class PositionalEmbedding(nn.Module):
    """Sinusoidal positional embedding.

    From https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/transfo_xl/modeling_transfo_xl.py#L178
    """
    def __init__(self, demb):
        super().__init__()
        self.demb = demb
        # One inverse frequency per (sin, cos) pair of output channels.
        exponents = torch.arange(0.0, demb, 2.0) / demb
        self.register_buffer("inv_freq", 1 / (10000 ** exponents))

    def forward(self, pos_seq, bsz=None):
        """Embed the 1-D tensor ``pos_seq``.

        Returns a tensor of shape ``(len(pos_seq), 1, 2 * len(inv_freq))``
        (sines first, then cosines), expanded to ``bsz`` along dim 1 when
        ``bsz`` is given.
        """
        angles = torch.outer(pos_seq, self.inv_freq)
        embedding = torch.cat((angles.sin(), angles.cos()), dim=-1).unsqueeze(1)
        if bsz is None:
            return embedding
        return embedding.expand(-1, bsz, -1)

In [None]:
# 5-fold shuffled split with a fixed seed; only the first fold is used as the
# train/validation split for this run.
folds = KFold(n_splits=5, shuffle=True, random_state=9000)
splits = folds.split(dataset_data)
train_idx, val_idx = next(iter(splits))
# Uncomment to train on a small subset while debugging.
# train_idx = train_idx[:10000]
dataset_train = dataset_data.select(train_idx)
dataset_val = dataset_data.select(val_idx)

In [None]:
# Keep in the validation set only the rows whose SN_filter flag is > 0 for BOTH
# experiments; every other validation row is moved into the training set instead.
dataset_move_to_train = dataset_val.filter(lambda x: not ((x['SN_filter_2A3_MaP'] > 0) & (x['SN_filter_DMS_MaP'] > 0)), num_proc=14)
dataset_val = dataset_val.filter(lambda x: ((x['SN_filter_2A3_MaP'] > 0) & (x['SN_filter_DMS_MaP'] > 0)), num_proc=14)
dataset_train = datasets.concatenate_datasets([dataset_train, dataset_move_to_train])
# Training-set size before the signal-to-noise threshold below is applied; used as a wandb tag.
DATASET_LEN = len(dataset_train)

In [None]:
len(dataset_move_to_train)

In [None]:
dataset_train = dataset_train.map(
    lambda x: {
        'sn_2a3_map': np.nanmean(x['reactivity_2A3_MaP']) / np.nanmean(x['reactivity_error_2A3_MaP']),
        'sn_dms_map': np.nanmean(x['reactivity_DMS_MaP']) / np.nanmean(x['reactivity_error_DMS_MaP'])
    },
    num_proc=14
)


In [None]:
dataset_val = dataset_val.map(
    lambda x: {
        'sn_2a3_map': np.nanmean(x['reactivity_2A3_MaP']) / np.nanmean(x['reactivity_error_2A3_MaP']),
        'sn_dms_map': np.nanmean(x['reactivity_DMS_MaP']) / np.nanmean(x['reactivity_error_DMS_MaP'])
    },
    num_proc=14
)

In [None]:
# Keep a training row when at least one of its two sequence-level signal-to-noise
# ratios (sn_2a3_map, sn_dms_map) reaches the threshold.
THRESHOLD_START = 0.75
dataset_train = dataset_train.filter(lambda x: (x['sn_2a3_map'] >= THRESHOLD_START) | (x['sn_dms_map'] >= THRESHOLD_START), num_proc=14)

In [None]:
len(dataset_train)

In [None]:
dataset_train = dataset_train.map(
    preprocess,
    fn_kwargs={'tokenizer': TOKENIZER, 'max_len': MAX_LEN, 'is_train': True},
    num_proc=14,
)

dataset_val = dataset_val.map(
    preprocess,
    fn_kwargs={'tokenizer': TOKENIZER, 'max_len': MAX_LEN, 'is_train': False},
    num_proc=14,
)

In [None]:
def run_CapR(example, output_dir, column_suffix='', max_seq_len=1024, cache_dir=None):
    """Run the external ``CapR`` tool on one example and return its probabilities.

    From https://www.kaggle.com/code/ratthachat/preprocessing-deep-learning-input-from-rna-string/notebook

    Args:
        example: Mapping with ``'sequence_id'`` and ``'sequence'``.
        output_dir: Scratch directory for the temporary ``.fa`` / ``.out`` files.
        column_suffix: Appended to the output key and used as a cache sub-directory.
        max_seq_len: Third command-line argument passed to ``CapR``.
        cache_dir: If given, results are pickled under
            ``cache_dir/column_suffix/<first four id characters>/<id>.pickle``
            and reused on later calls.

    Returns:
        ``{f'capr_structure_probs{column_suffix}': rows}`` where ``rows`` is the
        transposed CapR table (first column dropped) as nested float32 lists.

    Raises:
        subprocess.CalledProcessError: if ``CapR`` exits with a non-zero status.
    """
    if cache_dir is not None:
        seq_id = example['sequence_id']
        dir_path = f'{cache_dir}/{column_suffix}/{seq_id[0]}/{seq_id[1]}/{seq_id[2]}/{seq_id[3]}'
        file_path = os.path.join(dir_path, f'{seq_id}.pickle')
        if os.path.exists(file_path):
            # NOTE: pickle is only safe here because the cache is written by this function.
            with open(file_path, 'rb') as f:
                return pickle.load(f)

    rna_id = example['sequence_id']
    rna_string = example['sequence']
    # exist_ok avoids a FileExistsError when several map workers start at once.
    os.makedirs(output_dir, exist_ok=True)
    in_file = f'{output_dir}/{rna_id}.fa'
    out_file = f'{output_dir}/{rna_id}.out'

    try:
        with open(in_file, "w") as fp:
            fp.write('>%s\n' % rna_id)
            fp.write(rna_string)

        # Argument list instead of a shell string: no quoting issues with paths,
        # and a failing CapR run is reported instead of surfacing later as a
        # missing output file.
        subprocess.run(['CapR', in_file, out_file, '%d' % max_seq_len], check=True)

        df = pd.read_csv(
            out_file,
            skiprows=1,
            header=None,
            sep=r'\s+',
        )
    finally:
        # Always clean up the scratch files, including after a failed run.
        for tmp_path in (in_file, out_file):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # Column 0 holds the row labels; the remaining columns are per-position values.
    probs = df.iloc[:, 1:].T.values.astype(np.float32).tolist()
    res = {f'capr_structure_probs{column_suffix}': probs}

    if cache_dir is not None:
        os.makedirs(dir_path, exist_ok=True)
        try:
            with open(file_path, 'wb') as f:
                pickle.dump(res, f, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception:
            # Do not leave a truncated cache entry behind.
            if os.path.exists(file_path):
                print(f'removing {file_path}')
                os.remove(file_path)
            raise
    return res

In [None]:
dataset_train = dataset_train.map(
    run_CapR,
    fn_kwargs={
        'output_dir': '../output/CapR',
        'cache_dir': '../output/CapR_cache',
    },
    num_proc=14
)

dataset_val = dataset_val.map(
    run_CapR,
    fn_kwargs={
        'output_dir': '../output/CapR',
        'cache_dir': '../output/CapR_cache',
    },
    num_proc=14
)

In [None]:
def get_arnie_bpp(example, column_suffix='', cache_dir=None, file_path_only=False, **kwargs):
    """Compute the base-pair probability matrix of one example in sparse form.

    Args:
        example: Mapping with ``'sequence'`` (and ``'sequence_id'`` when caching).
        column_suffix: Appended to every output key; also a cache sub-directory.
        cache_dir: Optional root of the on-disk pickle cache.
        file_path_only: Return only the cache file path (requires ``cache_dir``).
        **kwargs: Forwarded to ``bpps``.

    Returns:
        ``{'indices<suffix>': [rows, cols], 'values<suffix>': [...]}`` listing the
        non-zero entries, or ``{'file_path<suffix>': path}`` when
        ``file_path_only`` is set.
    """
    if file_path_only:
        assert cache_dir is not None
    caching = cache_dir is not None

    if caching:
        seq_id = example['sequence_id']
        dir_path = f'{cache_dir}/{column_suffix}/{seq_id[0]}/{seq_id[1]}/{seq_id[2]}/{seq_id[3]}'
        file_path = os.path.join(dir_path, f'{seq_id}.pickle')
        if os.path.exists(file_path):
            if file_path_only:
                return {f"file_path{column_suffix}": file_path}
            with open(file_path, 'rb') as handle:
                return pickle.load(handle)

    dense_bpp = bpps(example['sequence'], **kwargs)
    res = {
        f"indices{column_suffix}": [axis.tolist() for axis in np.nonzero(dense_bpp)],
        f"values{column_suffix}": dense_bpp[dense_bpp != 0].tolist(),
    }

    if caching:
        os.makedirs(dir_path, exist_ok=True)
        try:
            with open(file_path, 'wb') as handle:
                pickle.dump(res, handle, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception:
            # Drop a partially written cache entry before re-raising.
            if os.path.exists(file_path):
                print(f'removing {file_path}')
                os.remove(file_path)
            raise

    return {f"file_path{column_suffix}": file_path} if file_path_only else res

In [None]:
def get_bpps(example, column_suffix='', cache_dir=None, file_path_only=False, **kwargs):
    """Average the sparse base-pair probability matrices of several predictors.

    Only ``eternafold`` is currently enabled, so the "average" is that single
    prediction. Results are optionally cached as pickles.

    Args:
        example: Mapping with ``'sequence'`` (and ``'sequence_id'`` when caching).
        column_suffix: Appended to every output key; also a cache sub-directory.
        cache_dir: Optional root of the on-disk pickle cache.
        file_path_only: Return only the cache file path (requires ``cache_dir``).
        **kwargs: Forwarded to ``get_arnie_bpp``.

    Returns:
        ``{'indices<suffix>': ..., 'values<suffix>': ...}`` of the coalesced
        averaged matrix, or ``{'file_path<suffix>': path}``.
    """
    if file_path_only:
        assert cache_dir is not None
    caching = cache_dir is not None

    if caching:
        seq_id = example['sequence_id']
        dir_path = f'{cache_dir}/{column_suffix}/{seq_id[0]}/{seq_id[1]}/{seq_id[2]}/{seq_id[3]}'
        file_path = os.path.join(dir_path, f'{seq_id}.pickle')
        if os.path.exists(file_path):
            if file_path_only:
                return {f"file_path{column_suffix}": file_path}
            with open(file_path, 'rb') as handle:
                return pickle.load(handle)

    # Further arnie packages (vienna_2, contrafold_2, rnastructure, rnasoft_07)
    # can be appended to this list to average several predictors.
    predictions = [
        get_arnie_bpp(example, package='eternafold', column_suffix='', cache_dir=None, **kwargs),
    ]
    seq_len = len(example['sequence'])
    shape = (seq_len, seq_len)
    averaged = torch.sparse_coo_tensor(**predictions[0], size=shape)
    for extra in predictions[1:]:
        averaged += torch.sparse_coo_tensor(**extra, size=shape)
    averaged = averaged.coalesce()
    averaged /= len(predictions)

    res = {
        f"indices{column_suffix}": averaged.indices().tolist(),
        f"values{column_suffix}": averaged.values().tolist(),
    }

    if caching:
        os.makedirs(dir_path, exist_ok=True)
        try:
            with open(file_path, 'wb') as handle:
                pickle.dump(res, handle, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception:
            # Drop a partially written cache entry before re-raising.
            if os.path.exists(file_path):
                print(f'removing {file_path}')
                os.remove(file_path)
            raise

    return {f"file_path{column_suffix}": file_path} if file_path_only else res

In [None]:
def create_adjacency_matrix(paren_string):
    """Build a sparse pairing matrix from a dot-bracket string.

    Every matched ``(`` ... ``)`` pair ``(i, j)`` contributes a single entry
    ``1.0`` at ``[i, j]`` (upper triangle only); unmatched ``)`` are ignored.

    Returns:
        Coalesced ``torch`` sparse COO tensor of shape ``(len, len)``.
    """
    pending_opens = []
    pairs = []
    for position, symbol in enumerate(paren_string):
        if symbol == '(':
            pending_opens.append(position)
        elif symbol == ')' and pending_opens:
            pairs.append((pending_opens.pop(), position))

    open_idx = [pair[0] for pair in pairs]
    close_idx = [pair[1] for pair in pairs]
    indices = torch.tensor([open_idx, close_idx], dtype=torch.long)
    values = torch.ones(len(pairs), dtype=torch.float32)

    n = len(paren_string)
    return torch.sparse_coo_tensor(indices, values, (n, n)).coalesce()

In [None]:
def get_arnie_mfe(example, column_suffix='', cache_dir=None, file_path_only=False, **kwargs):
    """Predict the MFE structure of one example and return its pairing matrix.

    Args:
        example: Mapping with ``'sequence'`` (and ``'sequence_id'`` when caching).
        column_suffix: Appended to every output key; also a cache sub-directory.
        cache_dir: Optional root of the on-disk pickle cache.
        file_path_only: Return only the cache file path (requires ``cache_dir``).
        **kwargs: Forwarded to ``mfe``.

    Returns:
        ``{'indices<suffix>': ..., 'values<suffix>': ...}`` of the sparse
        adjacency matrix, or ``{'file_path<suffix>': path}``.
    """
    if file_path_only:
        assert cache_dir is not None
    caching = cache_dir is not None

    if caching:
        seq_id = example['sequence_id']
        dir_path = f'{cache_dir}/{column_suffix}/{seq_id[0]}/{seq_id[1]}/{seq_id[2]}/{seq_id[3]}'
        file_path = os.path.join(dir_path, f'{seq_id}.pickle')
        if os.path.exists(file_path):
            if file_path_only:
                return {f"file_path{column_suffix}": file_path}
            with open(file_path, 'rb') as handle:
                return pickle.load(handle)

    pairing = create_adjacency_matrix(mfe(example['sequence'], **kwargs))
    res = {
        f"indices{column_suffix}": pairing.indices().tolist(),
        f"values{column_suffix}": pairing.values().tolist(),
    }

    if caching:
        os.makedirs(dir_path, exist_ok=True)
        try:
            with open(file_path, 'wb') as handle:
                pickle.dump(res, handle, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception:
            # Drop a partially written cache entry before re-raising.
            if os.path.exists(file_path):
                print(f'removing {file_path}')
                os.remove(file_path)
            raise

    return {f"file_path{column_suffix}": file_path} if file_path_only else res

In [None]:
def get_mfes(example, cache_dir=None, column_suffix='', file_path_only=False, **kwargs):
    """Average the MFE pairing matrices of several predictors.

    Only ``eternafold`` is currently enabled, so the "average" is that single
    prediction. Results are optionally cached as pickles.

    Args:
        example: Mapping with ``'sequence'`` (and ``'sequence_id'`` when caching).
        cache_dir: Optional root of the on-disk pickle cache.
        column_suffix: Appended to every output key; also a cache sub-directory.
        file_path_only: Return only the cache file path (requires ``cache_dir``).
        **kwargs: Forwarded to ``get_arnie_mfe``.

    Returns:
        ``{'indices<suffix>': ..., 'values<suffix>': ...}`` of the coalesced
        averaged matrix, or ``{'file_path<suffix>': path}``.
    """
    if file_path_only:
        assert cache_dir is not None
    caching = cache_dir is not None

    if caching:
        seq_id = example['sequence_id']
        dir_path = f'{cache_dir}/{column_suffix}/{seq_id[0]}/{seq_id[1]}/{seq_id[2]}/{seq_id[3]}'
        file_path = os.path.join(dir_path, f'{seq_id}.pickle')
        if os.path.exists(file_path):
            if file_path_only:
                return {f"file_path{column_suffix}": file_path}
            with open(file_path, 'rb') as handle:
                return pickle.load(handle)

    # Further arnie packages (vienna_2, contrafold_2, rnastructure) can be
    # appended to this list to average several predictors.
    predictions = [
        get_arnie_mfe(example, package='eternafold', column_suffix='', cache_dir=None, **kwargs),
    ]
    seq_len = len(example['sequence'])
    shape = (seq_len, seq_len)
    averaged = torch.sparse_coo_tensor(**predictions[0], size=shape)
    for extra in predictions[1:]:
        averaged += torch.sparse_coo_tensor(**extra, size=shape)
    averaged = averaged.coalesce()
    averaged /= len(predictions)

    res = {
        f"indices{column_suffix}": averaged.indices().tolist(),
        f"values{column_suffix}": averaged.values().tolist(),
    }

    if caching:
        os.makedirs(dir_path, exist_ok=True)
        try:
            with open(file_path, 'wb') as handle:
                pickle.dump(res, handle, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception:
            # Drop a partially written cache entry before re-raising.
            if os.path.exists(file_path):
                print(f'removing {file_path}')
                os.remove(file_path)
            raise

    return {f"file_path{column_suffix}": file_path} if file_path_only else res

In [None]:
def create_mfe_shortcuts(paren_string):
    """List the ``(open_index, close_index)`` pairs of a dot-bracket string.

    Pairs are emitted in the order their closing bracket appears; a ``)`` with
    no pending ``(`` is ignored, as is every other character.
    """
    pending_opens = []
    pairs = []
    for position, symbol in enumerate(paren_string):
        if symbol == '(':
            pending_opens.append(position)
            continue
        if symbol == ')' and pending_opens:
            pairs.append((pending_opens.pop(), position))
    return pairs

In [None]:
def create_distance_matrix(n, shortcuts):
    """Initial edge-length matrix of a chain of ``n`` nodes plus extra edges.

    Entries are 0 on the diagonal, 1 between chain neighbours and between the
    endpoints of every ``(start, end)`` shortcut, and ``inf`` elsewhere.
    """
    dist = np.full((n, n), np.inf)
    np.fill_diagonal(dist, 0)

    # Backbone edges i <-> i + 1.
    left = np.arange(n - 1)
    dist[left, left + 1] = 1
    dist[left + 1, left] = 1

    # Symmetric unit-length edges for the paired positions.
    for start, end in shortcuts:
        dist[start, end] = 1
        dist[end, start] = 1

    return dist

In [None]:
def floyd_warshall_vectorized(matrix):
    """All-pairs shortest paths (Floyd-Warshall) on a square distance matrix.

    Each step relaxes every pair ``(i, j)`` through one intermediate node with
    a single broadcasted NumPy operation. The input array is not modified.
    """
    shortest = matrix
    for via in range(len(matrix)):
        # through_via[i, j] = shortest[i, via] + shortest[via, j]
        through_via = np.add.outer(shortest[:, via], shortest[via, :])
        shortest = np.minimum(shortest, through_via)
    return shortest

In [None]:
def get_distance_matrix_mfe(mfe_str, size=None):
    """Shortest-path distances between positions of a dot-bracket structure.

    The graph has backbone edges between neighbouring positions plus one edge
    per matched bracket pair. ``size`` overrides the number of nodes
    (defaults to ``len(mfe_str)``). Returns an ``int16`` array.
    """
    n_nodes = len(mfe_str) if size is None else size
    edge_lengths = create_distance_matrix(n_nodes, create_mfe_shortcuts(mfe_str))
    return floyd_warshall_vectorized(edge_lengths).astype(np.int16)

In [None]:
def get_sparse_distance_matrix(mfe_str):
    """Sparse encoding of the structure-aware distance matrix.

    Only upper-triangle entries whose shortest-path distance differs from the
    plain sequence distance ``|j - i|`` are stored.

    Returns:
        ``{'indices': [rows, cols], 'values': distances}``.
    """
    matrix = get_distance_matrix_mfe(mfe_str)
    n = matrix.shape[0]
    rows, cols, values = [], [], []
    for i, matrix_row in enumerate(matrix):
        for j in range(i + 1, n):
            hops = matrix_row[j]
            if hops == abs(j - i):
                continue  # same as the backbone distance: implied, not stored
            rows.append(i)
            cols.append(j)
            values.append(hops)

    return {"indices": [rows, cols], "values": values}

In [None]:
def reconstruct_matrix(indices, values, size, base_matrix=None, sign=False):
    """Rebuild a dense distance matrix from its sparse overrides.

    Args:
        indices: ``[rows, cols]`` of the overridden entries.
        values: Distance for each ``(row, col)`` pair.
        size: Side length, used only when ``base_matrix`` is not given.
        base_matrix: Optional starting array (copied); otherwise the matrix
            starts as ``|j - i|``, or as ``i - j`` when ``sign`` is true.
        sign: If true, the ``[row, col]`` entry is written negated while the
            mirrored ``[col, row]`` entry stays positive.

    Returns:
        ``torch.LongTensor`` with the overrides applied.
    """
    if base_matrix is not None:
        dense = base_matrix.copy()
    elif sign:
        dense = np.fromfunction(lambda i, j: (i - j), (size, size))  # negative above the diagonal
    else:
        dense = np.fromfunction(lambda i, j: abs(j - i), (size, size))
    dense = torch.LongTensor(dense)

    for row, col, distance in zip(indices[0], indices[1], values):
        dense[row][col] = -distance if sign else distance
        dense[col][row] = distance

    return dense

In [None]:
def get_arnie_mfe_distance(example, column_suffix='', cache_dir=None, file_path_only=False, **kwargs):
    """Encode the MFE structure of one example and its sparse distance matrix.

    Args:
        example: Mapping with ``'sequence'`` (and ``'sequence_id'`` when caching).
        column_suffix: Appended to every output key; also a cache sub-directory.
        cache_dir: Optional root of the on-disk pickle cache.
        file_path_only: Return only the cache file path (requires ``cache_dir``).
        **kwargs: Forwarded to ``mfe``.

    Returns:
        ``{'input_ids<suffix>', 'indices<suffix>', 'values<suffix>'}`` or
        ``{'file_path<suffix>': path}``.
    """
    if file_path_only:
        assert cache_dir is not None
    caching = cache_dir is not None

    if caching:
        seq_id = example['sequence_id']
        dir_path = f'{cache_dir}/{column_suffix}/{seq_id[0]}/{seq_id[1]}/{seq_id[2]}/{seq_id[3]}'
        file_path = os.path.join(dir_path, f'{seq_id}.pickle')
        if os.path.exists(file_path):
            if file_path_only:
                return {f"file_path{column_suffix}": file_path}
            with open(file_path, 'rb') as handle:
                return pickle.load(handle)

    structure = mfe(example['sequence'], **kwargs)
    sparse_distances = get_sparse_distance_matrix(structure)
    # Token ids for the dot-bracket symbols; 0 is reserved for padding.
    symbol_to_id = {'.': 1, '(': 2, ')': 3, '[': 4, ']': 5}

    res = {
        f"input_ids{column_suffix}": [symbol_to_id[symbol] for symbol in structure],
        f"indices{column_suffix}": sparse_distances['indices'],
        f"values{column_suffix}": sparse_distances['values'],
    }

    if caching:
        os.makedirs(dir_path, exist_ok=True)
        try:
            with open(file_path, 'wb') as handle:
                pickle.dump(res, handle, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception:
            # Drop a partially written cache entry before re-raising.
            if os.path.exists(file_path):
                print(f'removing {file_path}')
                os.remove(file_path)
            raise

    return {f"file_path{column_suffix}": file_path} if file_path_only else res

In [None]:
import subprocess

def get_predicted_loop_type(id, sequence, structure, debug=False):
    """Annotate a dot-bracket structure with loop types using the bpRNA script.

    Writes ``sequence`` and ``structure`` to a temporary ``.dbn`` file, runs
    ``bpRNA.pl`` on it and reads the resulting ``.st`` file.

    Args:
        id: Identifier used (together with the process id) in the temp file names.
        sequence: RNA sequence string.
        structure: Dot-bracket structure string for ``sequence``.
        debug: If true, print the inputs and the annotation line and keep the
            temporary files; otherwise the temporary files are deleted.

    Returns:
        ``(id, structure, annotation)`` where ``annotation`` is the sixth line
        of the ``.st`` file.
    """
    # NOTE(review): the temp-file and script paths are hard-coded absolute paths,
    # while the working directory is relative. This only works if
    # '../output/tmp/' resolves to the same directory as the absolute temp dir
    # -- confirm for your environment.
    pid = os.getpid()
    tmp_in_file = f'/home/rapids/notebooks/host/output/tmp/{id}_{pid}.dbn'
    tmp_out_file = f'/home/rapids/notebooks/host/output/tmp/{id}_{pid}.st'
    with open(tmp_in_file, 'w') as file:
        file.write(sequence + '\n')
        file.write(structure + '\n')
    perl_script_path = '/home/rapids/notebooks/toolkits/bpRNA/bpRNA.pl'
    working_directory = '../output/tmp/'
    subprocess.run(['perl', perl_script_path, tmp_in_file], cwd=working_directory)
    # Read through a context manager so the handle is closed before the file
    # is removed (the original left it open).
    with open(tmp_out_file) as file:
        result = [line.strip('\n') for line in file]
    if debug:
        print(sequence)
        print(structure)
        print(result[5])
    else:
        os.remove(tmp_in_file)
        os.remove(tmp_out_file)
    return id, structure, result[5]

In [None]:
def get_arnie_predicted_loop_type(example, column_suffix='', cache_dir=None, **kwargs):
    """Predict the per-position loop type of one example as integer ids.

    Args:
        example: Mapping with ``'sequence_id'`` and ``'sequence'``.
        column_suffix: Appended to the output key; also a cache sub-directory.
        cache_dir: Optional root of the on-disk pickle cache.
        **kwargs: Forwarded to ``mfe``.

    Returns:
        ``{'structure<suffix>': [ids]}`` with one id per annotation character.
    """
    caching = cache_dir is not None
    if caching:
        seq_id = example['sequence_id']
        dir_path = f'{cache_dir}/{column_suffix}/{seq_id[0]}/{seq_id[1]}/{seq_id[2]}/{seq_id[3]}'
        file_path = os.path.join(dir_path, f'{seq_id}.pickle')
        if os.path.exists(file_path):
            with open(file_path, 'rb') as handle:
                return pickle.load(handle)

    dot_bracket = mfe(example['sequence'], **kwargs)
    loop_types = get_predicted_loop_type(example['sequence_id'], example['sequence'], dot_bracket)[2]
    # Ids 1..8 in this order (0 is reserved for padding):
    #   S paired "Stem", M Multiloop, I Internal loop, B Bulge,
    #   H Hairpin loop, K pseudoKnot, E dangling End, X eXternal loop
    loop_type_to_id = dict(zip('SMIBHKEX', range(1, 9)))

    res = {
        f"structure{column_suffix}": [loop_type_to_id[symbol] for symbol in loop_types]
    }

    if caching:
        os.makedirs(dir_path, exist_ok=True)
        try:
            with open(file_path, 'wb') as handle:
                pickle.dump(res, handle, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception:
            # Drop a partially written cache entry before re-raising.
            if os.path.exists(file_path):
                print(f'removing {file_path}')
                os.remove(file_path)
            raise
    return res

In [None]:
dataset_train = dataset_train.map(
    get_bpps,
    fn_kwargs={
        'cache_dir': '../output/arnie/eternafold',
        'column_suffix': '_bpp',
        'file_path_only': True,
    },
    num_proc=14
)
dataset_val = dataset_val.map(
    get_bpps,
    fn_kwargs={
        'cache_dir': '../output/arnie/eternafold',
        'column_suffix': '_bpp',
        'file_path_only': True,
    },
    num_proc=14
)

In [None]:
dataset_train = dataset_train.map(
    get_arnie_mfe_distance,
    fn_kwargs={
        'cache_dir': '../output/arnie/eternafold',
        'package': 'eternafold',
        'column_suffix': '_mfe_distance',
        'file_path_only': True,
    },
    num_proc=14
)
dataset_val = dataset_val.map(
    get_arnie_mfe_distance,
    fn_kwargs={
        'cache_dir': '../output/arnie/eternafold',
        'package': 'eternafold',
        'column_suffix': '_mfe_distance',
        'file_path_only': True,
    },
    num_proc=14
)

In [None]:
dataset_train = dataset_train.map(
    get_arnie_predicted_loop_type,
    fn_kwargs={
        'cache_dir': '../output/arnie/eternafold_bprna',
        'package': 'eternafold',
        'column_suffix': '_eternafold',
    },
    num_proc=14
)
dataset_val = dataset_val.map(
    get_arnie_predicted_loop_type,
    fn_kwargs={
        'cache_dir': '../output/arnie/eternafold_bprna',
        'package': 'eternafold',
        'column_suffix': '_eternafold',
    },
    num_proc=14
)

In [None]:
# Columns kept for training/validation (model inputs, labels, cached feature
# paths and signal-to-noise columns); every other column is dropped from both
# datasets.
COLS_TO_KEEP = (
    'input_ids',
    'token_type_ids',
    'attention_mask',
    'labels',
    'file_path_bpp',
    'file_path_mfe_distance',
    'structure_eternafold',
    'sn_2a3_map',
    'sn_dms_map',
    'stn_nts_2A3_MaP',
    'stn_nts_DMS_MaP',
    'capr_structure_probs',
    'len',
)
cols_to_remove = [x for x in dataset_train.column_names if x not in COLS_TO_KEEP]
dataset_train = dataset_train.remove_columns(cols_to_remove)

# Recomputed for validation because its column set is derived independently.
cols_to_remove = [x for x in dataset_val.column_names if x not in COLS_TO_KEEP]
dataset_val = dataset_val.remove_columns(cols_to_remove)

In [None]:
CONFIG = BertConfig.from_pretrained(
    MODEL_PATH,
    output_hidden_states=False,
    num_hidden_layers=6,
    position_embedding_type='relative_key_query',
    classifier_dropout=0.5
)

In [None]:
CONFIG

In [None]:
MODEL = BertForTokenRegression(CONFIG)

In [None]:
# Switches controlling how positional information is initialised and what is frozen.
FREEZE_EMBEDDINGS = True
INIT_SINUSOIDAL_EMBEDDINGS = True
SINUSOIDAL_DISTANCE_EMBEDDINGS=True
FREEZE_LAYERS = 0
if INIT_SINUSOIDAL_EMBEDDINGS:
    # Replace the absolute position table with a sinusoidal one of the same shape.
    print('Sinusoidal embeddings.')
    num_pos_emb, emb_size = MODEL.bert.embeddings.position_embeddings.weight.shape
    MODEL.bert.embeddings.position_embeddings.weight = torch.nn.Parameter(PositionalEmbedding(emb_size)(torch.arange(0, num_pos_emb)).squeeze(1).clone())
if SINUSOIDAL_DISTANCE_EMBEDDINGS:
    # Same for the relative-distance tables of both attention modules of every
    # layer, covering distances -(max_pos - 1) .. (max_pos - 1).
    print('Sinusoidal distance embeddings.')
    for layer in MODEL.bert.encoder.layer:
        _, emb_size = layer.attention.self.distance_embedding.weight.shape
        num_pos_emb = CONFIG.max_position_embeddings
        layer.attention.self.distance_embedding.weight = torch.nn.Parameter(PositionalEmbedding(emb_size)(torch.arange(-num_pos_emb + 1, num_pos_emb)).squeeze(1).clone())
        # requires_grad must be set on the Parameter: assigning it on the
        # nn.Embedding module only creates an unused attribute and freezes nothing.
        layer.attention.self.distance_embedding.weight.requires_grad = False
        layer.attention2.self.distance_embedding.weight = torch.nn.Parameter(PositionalEmbedding(emb_size)(torch.arange(-num_pos_emb + 1, num_pos_emb)).squeeze(1).clone())
        layer.attention2.self.distance_embedding.weight.requires_grad = False
if FREEZE_EMBEDDINGS:
    print('Freezing embeddings.')
    # Freeze the position-embedding weights themselves (see note above).
    MODEL.bert.embeddings.position_embeddings.weight.requires_grad = False
if FREEZE_LAYERS>0:
    print(f'Freezing {FREEZE_LAYERS} layers.')
    for layer in MODEL.bert.encoder.layer[:FREEZE_LAYERS]:
        for param in layer.parameters():
            param.requires_grad = False

In [None]:
def mae(prediction_output: "PredictionOutput"):
    """Mean absolute error per experiment and their average.

    Channel 0 of the last axis is 2A3_MaP and channel 1 is DMS_MaP. Labels are
    clipped to [0, 1]; positions whose label is NaN or whose prediction is
    (close to) -100 are ignored.

    Args:
        prediction_output: Object with ``predictions`` and ``label_ids`` arrays
            indexed as ``[:, :, channel]``.

    Returns:
        Dict with keys ``'mae'``, ``'mae_2a3_map'`` and ``'mae_dms_map'``.
    """
    def _channel_mae(channel):
        predicted = prediction_output.predictions[:, :, channel]
        target = np.clip(prediction_output.label_ids[:, :, channel], 0, 1)
        usable = (~np.isnan(target)) & ~np.isclose(predicted, -100)
        return np.mean(np.abs(predicted[usable] - target[usable]))

    mae_2a3_map = _channel_mae(0)
    mae_dms_map = _channel_mae(1)
    return {
        'mae': (mae_2a3_map + mae_dms_map) / 2,
        'mae_2a3_map': mae_2a3_map,
        'mae_dms_map': mae_dms_map,
    }

In [None]:
from transformers import TrainerCallback, TrainerState, TrainerControl

class LoggerLRCallback(TrainerCallback):
    """Trainer callback that adds each optimizer param group's learning rate to the logs."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Write ``lr_<group index>`` entries into ``logs`` in place.

        Does nothing when ``logs`` is ``None`` or no optimizer is passed in
        ``kwargs``, instead of raising ``TypeError`` / ``KeyError``.
        """
        # NOTE(review): whether entries added here reach every logging backend
        # depends on the order in which the Trainer invokes its callbacks -- confirm.
        optimizer = kwargs.get('optimizer')
        if logs is None or optimizer is None:
            return
        for i, param_group in enumerate(optimizer.param_groups):
            # Log the learning rate with a custom key format
            logs[f"lr_{i}"] = param_group['lr']

In [None]:
# Build four AdamW parameter groups: {decay, no decay} x {body lr, output-head lr}.
param_optimizer = list(MODEL.named_parameters())
# Parameters whose name CONTAINS one of these substrings get no weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'DynamicLayerNorm2d.weight', 'DynamicLayerNorm2d.bias', 'BatchNorm1d.weight', 'BatchNorm1d.bias']
# NOTE(review): membership below is an exact match on the FULL parameter name
# (`n in output_params`), unlike the substring test used for `no_decay`.
# Confirm these names match the model's actual top-level parameter names.
output_params = ['dense.weight', 'dense.bias', 'classifier.weight', 'classifier.bias']
lr = 1e-3
output_lr = 1e-3
weight_decay = 0.05
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) and (n not in output_params)], 'weight_decay': weight_decay, 'lr': lr},
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) and (n in output_params)], 'weight_decay': weight_decay, 'lr': output_lr},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and (n not in output_params)], 'weight_decay': 0.0, 'lr': lr},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and (n in output_params)], 'weight_decay': 0.0, 'lr': output_lr},
]
optimizer = transformers.AdamW(
    optimizer_grouped_parameters
)



In [None]:
SIGN = True

In [None]:
# EVAL_STEPS drives the logging interval; it is also passed as eval/save step
# count, while evaluation and checkpointing use the 'epoch' strategy.
EVAL_STEPS = 1000
EVAL_STRATEGY = 'epoch'
# assert (EVAL_STRATEGY == 'epoch') == (EVAL_STEPS == None)

# One wandb run per training; its id names the checkpoint directory.
run = wandb.init(
    project="stanford-ribonanza-rna-folding", 
    tags=[
        f"DATASET_LEN_{DATASET_LEN}",
        f"MAX_LEN_{MAX_LEN}",
        f"freeze_{FREEZE_LAYERS}",
        f"freeze__emb_{FREEZE_EMBEDDINGS}",
        f"sinusoidal_emb_{INIT_SINUSOIDAL_EMBEDDINGS}",
        f"SINUSOIDAL_DISTANCE_EMBEDDINGS_{SINUSOIDAL_DISTANCE_EMBEDDINGS}"
    ],
    group=MODEL_PATH
)

training_args = TrainingArguments(
    warmup_ratio=0.1, 
    # learning_rate=5e-4, # optimizers'lr overrides this
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=15,
    report_to='wandb',
    output_dir = os.path.join(OUTPUT_DIR, f'checkpoints/{run.id}'),
    overwrite_output_dir=True,
    fp16=True,
    gradient_accumulation_steps=2,
    logging_steps=EVAL_STEPS,
    evaluation_strategy=EVAL_STRATEGY,
    eval_steps=EVAL_STEPS,
    save_strategy=EVAL_STRATEGY,
    save_steps=EVAL_STEPS,
    load_best_model_at_end=False,
    greater_is_better=False,
    metric_for_best_model='mae',
    lr_scheduler_type='cosine',
    save_total_limit=2,
    # The collator reads extra dataset columns, so the Trainer must not drop them.
    remove_unused_columns=False,
    dataloader_num_workers=8,
    max_grad_norm=5,
    # group_by_length=True,
    # length_column_name='len',
)

trainer = Trainer(
    model=MODEL,
    args=training_args,
    tokenizer=TOKENIZER,
    data_collator=DataCollatorForTokenRegression(tokenizer=TOKENIZER, sign=SIGN),
    train_dataset=dataset_train,
    eval_dataset=dataset_val,
    compute_metrics=mae,
    # Custom optimizer with per-group learning rates; the scheduler is left to the Trainer.
    optimizers=(optimizer, None),
    callbacks=[LoggerLRCallback()]
)


trainer.train()
# Save next to the checkpoints, using OUTPUT_DIR like every other output path
# in this notebook instead of a hard-coded '../output'.
trainer.save_model(os.path.join(OUTPUT_DIR, f'checkpoints/{run.id}/final'))
run.finish()

In [None]:
# Wrap the trained MODEL in a fresh Trainer (no metrics, callbacks or datasets)
# that is used for prediction only, then score the validation predictions.
trainer = Trainer(
    model=MODEL,
    args=training_args,
    data_collator=DataCollatorForTokenRegression(tokenizer=TOKENIZER, sign=SIGN)
)
prediction_output = trainer.predict(dataset_val)
mae(prediction_output)

In [None]:
np.save(
    os.path.join(OUTPUT_DIR, f'checkpoints/{run.id}/pseudo_labels_val.npy'),
    prediction_output.predictions.transpose(0, 2, 1)
)

In [None]:
sns.heatmap(DataCollatorForTokenRegression(tokenizer=TOKENIZER, sign=SIGN)([dataset_train[3], dataset_train[100000]])['attention_injection'][0][2][0])

## Test-set inference and submission

In [None]:
dataset_test = load_from_disk("../input/stanford-ribonanza-rna-folding/test_sequences")

In [None]:
# Tokenize the test sequences without truncation (max_len=None). Test rows have
# no reactivity columns, so `preprocess` emits zero placeholder labels.
dataset_test = dataset_test.map(
    preprocess,
    fn_kwargs={'tokenizer': TOKENIZER, 'max_len': None},
    num_proc=14
)

# CapR structural-context probabilities (cached on disk).
dataset_test = dataset_test.map(
    run_CapR,
    fn_kwargs={
        'output_dir': '../output/CapR',
        'cache_dir': '../output/CapR_cache',
    },
    num_proc=14
)

# MFE-based distance features; only the cache file path is stored in the dataset.
dataset_test = dataset_test.map(
    get_arnie_mfe_distance,
    fn_kwargs={
        'cache_dir': '../output/arnie/eternafold',
        'package': 'eternafold',
        'column_suffix': '_mfe_distance',
        'file_path_only': True,
    },
    num_proc=14
)

# Per-position loop-type ids.
dataset_test = dataset_test.map(
    get_arnie_predicted_loop_type,
    fn_kwargs={
        'cache_dir': '../output/arnie/eternafold_bprna',
        'package': 'eternafold',
        'column_suffix': '_eternafold',
    },
    num_proc=14
)

# Computed BEFORE the last map: it lists the current columns that are not model
# inputs. The final map adds `file_path_bpp` and drops the listed columns in
# the same pass.
cols_to_remove = [x for x in dataset_test.column_names if x not in (
    'input_ids',
    'token_type_ids',
    'attention_mask',
    'file_path_bpp',
    'file_path_mfe_distance',
    'capr_structure_probs',
    'structure_eternafold'
)]
dataset_test = dataset_test.map(
    get_bpps,
    fn_kwargs={
        'cache_dir': '../output/arnie/eternafold',
        'column_suffix': '_bpp',
        'file_path_only': True,
    },
    num_proc=14,
    remove_columns=cols_to_remove
)


In [None]:
prediction_output = trainer.predict(dataset_test)

In [None]:
np.save(
    os.path.join(OUTPUT_DIR, f'checkpoints/{run.id}/pseudo_labels_test.npy'),
    prediction_output.predictions.transpose(0, 2, 1)
)

In [None]:
df_test = pd.read_csv('../input/stanford-ribonanza-rna-folding/test_sequences.csv')

In [None]:
# Flatten the padded per-sequence predictions into one long list per experiment,
# keeping only the first (id_max - id_min + 1) positions of every test sequence.
# Assumes row i of df_test corresponds to prediction i (default RangeIndex).
result_DMS_MaP = []
result_2A3_MaP = []
for index, row in tqdm(df_test.iterrows(), total=len(df_test)):
    l = row['id_max'] - row['id_min'] + 1

    # Channel 0 = 2A3_MaP. Each slice is computed once and reused.
    pred = prediction_output.predictions[index, :, 0].reshape(-1)[:l].tolist()
    assert l == len(pred), f'{index}'
    result_2A3_MaP += pred

    # Channel 1 = DMS_MaP.
    pred = prediction_output.predictions[index, :, 1].reshape(-1)[:l].tolist()
    assert l == len(pred), f'{index}'
    result_DMS_MaP += pred


In [None]:
df_submission = pd.read_parquet('../input/stanford-ribonanza-rna-folding/sample_submission.parquet')

In [None]:
df_submission['reactivity_DMS_MaP'] = np.array(result_DMS_MaP)
df_submission['reactivity_2A3_MaP'] = np.array(result_2A3_MaP)

In [None]:
df_submission

In [None]:
output_path = os.path.join(OUTPUT_DIR, f'checkpoints/{run.id}/submit.parquet')
df_submission.to_parquet(output_path, index=False)

# output_path = os.path.join(OUTPUT_DIR, f'checkpoints/{run.id}/submit.csv')
# df_submission.to_csv(output_path, index=False)

In [None]:
!echo {run.id}

In [None]:
# Plot parameters: rows id1..id2 (inclusive) of the submission are reshaped into
# a reshape1 x reshape2 image (391 * 457 == id2 - id1 + 1).
font_size=6
id1=269545321
id2=269724007
reshape1=391
reshape2=457
# Slice the predictions for that id range and reshape them into 2-D images.
pred_DMS=df_submission[id1:id2+1]['reactivity_DMS_MaP'].to_numpy().reshape(reshape1,reshape2)
pred_2A3=df_submission[id1:id2+1]['reactivity_2A3_MaP'].to_numpy().reshape(reshape1,reshape2)
# Plot the mutate-and-map images side by side (white = 0, black = 1).
fig = plt.figure(figsize=(12, 16))
plt.subplot(121)
plt.title(f'reactivity_DMS_MaP', fontsize=font_size)
plt.imshow(pred_DMS,vmin=0,vmax=1, cmap='gray_r')
plt.subplot(122)
plt.title(f'reactivity_2A3_MaP', fontsize=font_size)
plt.imshow(pred_2A3,vmin=0,vmax=1, cmap='gray_r')
plt.tight_layout()
plt.show()


In [None]:
!kaggle competitions submit -c stanford-ribonanza-rna-folding -f {output_path} -m {run.id}