# Purpose

* Take what I learned in ```ribonanza-3d-finetune-v2.ipynb``` and go further with it
* Make some improvements, especially to null handling
* Expand the model to do better and use newer technologies like transformers
* Better evaluation metrics


# Imports

In [None]:
import general_utils
import utils

import time

import warnings
warnings.filterwarnings("ignore")

import os
import sys
import random
import pickle
import yaml

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Timestamp taken at startup (presumably used later to report elapsed time).
start_time = time.time()

import GPUtil

# Report the first GPU that GPUtil can see, or fall back to a CPU notice.
GPUs = GPUtil.getGPUs()
if not GPUs:
    print("Running on CPU")
else:
    gpu = GPUs[0]
    print(f"Running on: {gpu.name}")

# 1. CONFIG & SEED

In [None]:
def set_seed(seed: int):
    """Seed every random number generator used here so runs are repeatable.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU generator and
    all CUDA devices), then switches cuDNN to deterministic mode with
    benchmark-based algorithm selection turned off.

    Args:
        seed: The integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade some speed for reproducible cuDNN kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Notebook-wide configuration, kept as a plain dict.
# (The same values could instead be loaded from a YAML or JSON file.)
config = {
    # Seed handed to set_seed() for reproducibility.
    "seed": 42,
    # NOTE(review): the date cutoffs, max_len and batch_size are not consumed
    # in the visible cells — presumably used for the train/validation split,
    # sequence cropping and the DataLoader further down; confirm.
    "cutoff_date": "2020-01-01",
    "test_cutoff_date": "2022-05-01",
    "max_len": 384,
    "batch_size": 1,

    # Local path; change to the Kaggle input path when running on Kaggle.
    "model_config_path": "ribonanzanet2d-final/configs/pairwise.yaml",
    
    # Sequences longer than max_len_filter or shorter than min_len_filter
    # (in residues) are dropped during filtering.
    "max_len_filter": 9999999,
    "min_len_filter": 10,
    
    # Local paths; change to the Kaggle input paths when running on Kaggle.
    "ribonanzanet2d-final_path": "ribonanzanet2d-final",
    # CSV inputs read with pandas in the data-loading cell.
    "train_sequences_path": "stanford-rna-3d-folding/train_sequences.csv",
    "train_labels_path": "stanford-rna-3d-folding/train_labels.csv",
    "test_sequences_path": "stanford-rna-3d-folding/test_sequences.csv",
    "pretrained_weights_path": "ribonanzanet-weights/RibonanzaNet.pt",


    # Output folder (created right after this dict) and the file names for
    # saved weights — presumably best-checkpoint and final-epoch files; confirm.
    "save_weights_folder": "trained_model_weights",
    "save_weights_name": "RibonanzaNet-3D.pt",
    "save_weights_final": "RibonanzaNet-3D-final.pt",
}

# Make sure the folder for saved weights exists. makedirs with exist_ok=True
# avoids the check-then-create race and also creates missing parent folders.
os.makedirs(config["save_weights_folder"], exist_ok=True)

# Set the seed for reproducibility
set_seed(config["seed"])

# 2. DATA LOADING & PREPARATION

In [None]:
# Load CSVs
train_sequences = pd.read_csv(config["train_sequences_path"])
train_labels = pd.read_csv(config["train_labels_path"])

test_sequences = pd.read_csv(config["test_sequences_path"])

# Create a pdb_id field by stripping only the trailing "_<residue>" part of ID.
# Splitting from the right keeps the full target id no matter how many
# underscores it contains, so it can be matched against
# train_sequences["target_id"] below.
train_labels["pdb_id"] = train_labels["ID"].str.rsplit("_", n=1).str[0]

# Collect xyz data for each sequence.
# The filtering cell below reads `all_xyz`, so it has to be built here.
# Group the labels once instead of re-filtering the whole frame per sequence.
xyz_columns = ["x_1", "y_1", "z_1"]
xyz_by_pdb_id = {}
for pdb_id, group in train_labels.groupby("pdb_id", sort=False):
    xyz = group[xyz_columns].to_numpy().astype("float32")
    # Values below -1e17 are presumably placeholders for missing coordinates;
    # turn them into NaN so the NaN handling below sees them.
    xyz[xyz < -1e17] = float("nan")
    xyz_by_pdb_id[pdb_id] = xyz

# One entry per row of train_sequences, in the same order.
all_xyz = []
for pdb_id in train_sequences["target_id"]:
    xyz = xyz_by_pdb_id.get(pdb_id)
    if xyz is None:
        # No label rows for this target: keep an empty (0, 3) array.
        xyz = np.empty((0, len(xyz_columns)), dtype="float32")
    # Copy so entries never share a buffer, even for repeated target ids.
    all_xyz.append({"pdb_id": pdb_id, "xyz": xyz.copy()})

In [None]:
# TODO change to pandas implementation so can keep track of which REsid and which resname with each removed entry


def _trim_nan_ends(xyz):
    """Strip rows containing any NaN from both ends of a 2-D coordinate array.

    Rows with NaN in the interior are kept; only leading and trailing runs
    are removed.

    Args:
        xyz: 2-D float array with one row per residue.

    Returns:
        A tuple ``(trimmed, n_leading, n_trailing)`` where ``trimmed`` is the
        sliced array and the two counts are the rows removed at each end.
        If every row contains a NaN, ``trimmed`` is empty and all rows are
        counted as leading.
    """
    valid_rows = np.flatnonzero(~np.isnan(xyz).any(axis=1))
    if valid_rows.size == 0:
        return xyz[:0], len(xyz), 0
    first = int(valid_rows[0])
    last = int(valid_rows[-1])
    return xyz[first:last + 1], first, len(xyz) - 1 - last


# Filter `all_xyz` (a list of {"pdb_id": ..., "xyz": ...} dicts that must
# already exist): drop entries with too many NaNs or an out-of-range length,
# then trim NaN rows from both ends of the remaining entries.
nan_ratio_threshold = 0.5  # max allowed fraction of NaN coordinate values
max_len_seen = 0
num_removed_because_nan = 0
num_removed_because_size = 0
num_beginning_removed = 0
num_ending_removed = 0

valid_indices = []

print(f"Number of sequences before filtering: {len(all_xyz)}")

# Use a dedicated name for the position in all_xyz; it must not be reused
# while trimming, or results get written to the wrong entry.
for seq_idx, entry in enumerate(all_xyz):
    entry_xyz = entry["xyz"]

    nan_ratio = np.isnan(entry_xyz).mean()

    if nan_ratio > nan_ratio_threshold:
        num_removed_because_nan += 1
        continue

    if len(entry_xyz) > config["max_len_filter"] or len(entry_xyz) < config["min_len_filter"]:
        num_removed_because_size += 1
        continue

    # Remove NaN rows from the beginning and the end until a valid row.
    trimmed_xyz, n_leading, n_trailing = _trim_nan_ends(entry_xyz)

    if len(trimmed_xyz) == 0:
        # Every row has at least one NaN, so nothing usable is left.
        num_removed_because_nan += 1
        continue

    num_beginning_removed += n_leading
    num_ending_removed += n_trailing

    all_xyz[seq_idx]["xyz"] = trimmed_xyz
    valid_indices.append(seq_idx)

    max_len_seen = max(max_len_seen, len(trimmed_xyz))


print(f"Removed {num_removed_because_nan} sequences because of high NaN ratio.")
print(f"Removed {num_removed_because_size} sequences because of size.")
# These two counters count trimmed rows (residues), not whole sequences.
print(f"Removed {num_beginning_removed} residues because of beginning NaN.")
print(f"Removed {num_ending_removed} residues because of ending NaN.")
print(f"Max length seen: {max_len_seen}")
print(f"Number of sequences after filtering: {len(valid_indices)}")
    
    

    
    
