# Purpose

Explore the training data: plot the RNA structures in 3D and see what they look like.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go

# Remove the column-width limit so DataFrame cells are displayed in full
# instead of being truncated.
pd.set_option('display.max_colwidth', None)

In [None]:
# %matplotlib widget

In [None]:
# Read both training tables in one statement: the labels first, then the sequences.
train_labels, train_sequences = map(
    pd.read_csv,
    (
        'stanford-rna-3d-folding/train_labels.csv',
        'stanford-rna-3d-folding/train_sequences.csv',
    ),
)

In [None]:
# The different polymers in the database.
# A polymer is identified by the first two underscore-separated fields of a
# label ID (presumably "<structure>_<chain>_<residue>" -- TODO confirm against
# the dataset description). The split is done with vectorised string methods
# instead of a per-row lambda that split every ID twice.
diff_polymers = (
    train_labels['ID']
    .str.split('_', n=2)
    .str[:2]
    .str.join('_')
    .unique()
)

# Sample without replacement so the same polymer is never picked twice, and cap
# the sample size in case fewer than five polymers are available.
sample_polymers = np.random.choice(
    diff_polymers,
    size=min(5, len(diff_polymers)),
    replace=False,
)

sample_polymers

In [None]:
# Show every label row of the first sampled polymer.
# IDs are matched on the literal "<polymer>_" prefix rather than with a
# substring/regex search, so e.g. "1ABC_A" does not also pull in "1ABC_AB".
polymer_prefix = f'{sample_polymers[0]}_'
display(train_labels[train_labels['ID'].str.startswith(polymer_prefix, na=False)])

# Show the matching row(s) of the sequences table for the same polymer.
display(train_sequences[train_sequences['target_id'] == sample_polymers[0]])

In [None]:
def display_fasta(file_path):
    """
    Print the contents of a FASTA file to standard output.

    The file is streamed line by line, so it is never held in memory as a
    whole. This is a best-effort display helper: problems are reported with a
    printed message instead of being raised to the caller.

    Args:
        file_path (str): The path to the FASTA file.
    """
    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default encoding; undecodable bytes are shown as the
        # replacement character instead of aborting the display.
        with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
            for line in file:
                # Lines keep their own newline, so do not add another one.
                print(line, end='')
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
    except Exception as e:
        # Deliberately broad: any other failure is reported, not raised.
        print(f"An error occurred: {e}")

# Example usage:
# file_path = 'sequence.fasta'
# display_fasta(file_path)

In [None]:
# Print the MSA FASTA file of the first sampled polymer
# (MSA presumably stands for multiple-sequence alignment -- TODO confirm).
file_path = f'stanford-rna-3d-folding/MSA/{sample_polymers[0]}.MSA.fasta'
display_fasta(file_path)

# Plotting

In [None]:
def deconstruct_polymer(polymer_name: str, df: pd.DataFrame):
    """
    Extract the coordinates and residue names of one polymer from a labels table.

    A row belongs to the polymer when its ``ID`` is exactly ``polymer_name`` or
    starts with ``polymer_name`` followed by an underscore. The name is matched
    literally on that underscore boundary, so ``"1ABC_A"`` does not select rows
    of ``"1ABC_AB"`` and regex metacharacters in the name have no special
    meaning. Rows with a missing ``ID`` are never selected.

    Args:
        polymer_name: Identifier of the polymer to select.
        df: Table with the columns ``ID``, ``x_1``, ``y_1``, ``z_1`` and
            ``resname``.

    Returns:
        A tuple ``(x, y, z, resname, polymer_name)`` in which the first four
        items are NumPy arrays holding the selected rows in table order.
    """
    ids = df['ID']
    # na=False keeps the mask strictly boolean even when some IDs are missing.
    belongs = ids.eq(polymer_name) | ids.str.startswith(f'{polymer_name}_', na=False)
    all_monomers = df[belongs]

    return (
        np.array(all_monomers['x_1']),
        np.array(all_monomers['y_1']),
        np.array(all_monomers['z_1']),
        np.array(all_monomers['resname']),
        polymer_name,
    )

In [None]:
def plot_multiple_structures(structures: list) -> None:
    """
    Draw several RNA structures together in one interactive 3D figure.

    For each structure, one marker trace per nucleotide (A, G, C, U) is added,
    followed by a black line trace through all of its points in the given
    order. Legend entries are prefixed with the structure name so the traces
    of different structures can be told apart, and the figure title lists
    every plotted structure.

    Args:
        structures: Sequence of entries whose first five items are
            ``(x, y, z, sequences, name)``, e.g. the values returned by
            ``deconstruct_polymer``. ``x``, ``y``, ``z`` and ``sequences`` are
            array-likes of equal length.
    """
    colors = {"A": "red", "G": "blue", "C": "green", "U": "orange"}

    fig = go.Figure()
    names = []

    for structure in structures:
        x, y, z, sequences, name, *_ = structure
        # Boolean-mask indexing below needs arrays; plain lists are accepted too.
        x, y, z, sequences = (np.asarray(values) for values in (x, y, z, sequences))
        names.append(str(name))

        for resname, color in colors.items():
            is_residue = sequences == resname
            fig.add_trace(go.Scatter3d(
                x=x[is_residue],
                y=y[is_residue],
                z=z[is_residue],
                mode="markers",
                marker=dict(size=4, color=color),
                name=f"{name}: {resname}",
            ))

        fig.add_trace(go.Scatter3d(
            x=x,
            y=y,
            z=z,
            mode="lines",
            line=dict(color="black", width=2),
            name=f"{name}: RNA Backbone",
        ))

    # The layout is the same for every structure, so set it once; as before,
    # nothing is configured when there is nothing to plot.
    if names:
        fig.update_layout(
            scene=dict(
                xaxis=dict(title_text="X"),
                yaxis=dict(title_text="Y"),
                zaxis=dict(title_text="Z"),
            ),
            title=f"RNA 3D Structure of {', '.join(names)}",
        )

    fig.show()

def plot_structure(x, y, z, sequences, name: str) -> None:
    """
    Plot a single RNA structure as an interactive 3D figure.

    One marker trace is drawn per nucleotide (A, G, C, U), each in its own
    colour, plus a black line trace through all points in the given order.
    Entries of ``sequences`` other than A, G, C and U get no marker but are
    still part of the line trace.

    Adapted from
    https://www.kaggle.com/code/asarvazyan/interactive-3d-sequence-visualization

    Args:
        x, y, z: Coordinates, as lists or NumPy arrays of equal length.
        sequences: Residue name of each point, aligned with the coordinates.
        name: Structure name shown in the figure title.
    """
    # Boolean-mask indexing below only works on arrays, so accept raw lists too.
    x, y, z, sequences = (np.asarray(values) for values in (x, y, z, sequences))

    colors = {"A": "red", "G": "blue", "C": "green", "U": "orange"}

    fig = go.Figure()

    for resname, color in colors.items():
        is_residue = sequences == resname
        fig.add_trace(go.Scatter3d(
            x=x[is_residue],
            y=y[is_residue],
            z=z[is_residue],
            mode="markers",
            marker=dict(size=4, color=color),
            name=resname
        ))

    fig.add_trace(go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode="lines",
        line=dict(color="black", width=2),
        name="RNA Backbone"
    ))

    fig.update_layout(
        scene=dict(
            xaxis=dict(title_text="X"),
            yaxis=dict(title_text="Y"),
            zaxis=dict(title_text="Z"),
        ),
        title=f"RNA 3D Structure of {name}",
    )

    fig.show()

In [None]:
# Overlay the first two sampled polymers in a single figure.
first, second = (
    list(deconstruct_polymer(polymer, train_labels))
    for polymer in (sample_polymers[0], sample_polymers[1])
)
plot_multiple_structures([first, second])
# print(x)
# display(np.array(list(zip(x, y, z))).tolist())

In [None]:
# Unpack the first sampled polymer, draw it on its own, and print its x coordinates.
x, y, z, sequences, name = deconstruct_polymer(
    polymer_name=sample_polymers[0],
    df=train_labels,
)
plot_structure(x=x, y=y, z=z, sequences=sequences, name=name)
print(x)
# display(np.array(list(zip(x, y, z))).tolist())

In [None]:
# Draw every sampled polymer in its own figure, then list its coordinates
# as [x, y, z] rows.
for polymer_id in sample_polymers:
    x, y, z, sequences, name = deconstruct_polymer(polymer_id, train_labels)
    plot_structure(x, y, z, sequences, name)
    display(np.column_stack((x, y, z)).tolist())

# FASTA stuff

In [None]:
# Bare expression: the notebook renders the sequences table as the cell output.
train_sequences

In [None]:
# Bare expression: the notebook renders the labels table as the cell output.
train_labels