In [1]:
import numpy as np
import pandas as pd
import json
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np

In [2]:
def split_string(input_string):
    """Tokenise a string into alphanumeric runs and single-character delimiters.

    Whitespace is dropped without ending the current run (so ``"a b"`` yields
    ``["ab"]``); every other non-alphanumeric character ends the current run
    and is emitted as its own token.

    Parameters
    ----------
    input_string : str
        Text to tokenise.

    Returns
    -------
    list of str
        Tokens in their original order.
    """
    tokens = []
    buffer = []  # characters of the alphanumeric run being collected

    for ch in input_string:
        if ch.isspace():
            # Whitespace is ignored entirely; it does not flush the buffer
            continue
        if ch.isalnum():
            buffer.append(ch)
            continue
        # Any other character is a delimiter: flush the pending run first
        if buffer:
            tokens.append("".join(buffer))
            buffer = []
        tokens.append(ch)

    # Flush a run that reaches the end of the string
    if buffer:
        tokens.append("".join(buffer))

    return tokens

def split_and_replace(input_string, mapping):
    """Tokenise a string and expand every token that is a key of ``mapping``.

    A token found in ``mapping`` is replaced by the tokens of its mapped value.
    The replacement tokens are examined again, so nested aliases are expanded
    too. A mapping that refers back to itself would therefore never terminate.

    Parameters
    ----------
    input_string : str
        Text to tokenise with ``split_string``.
    mapping : dict
        Maps a token to the string that should replace it.

    Returns
    -------
    list of str
        Tokens after all replacements have been applied.
    """
    tokens = split_string(input_string)
    cursor = 0

    while cursor < len(tokens):
        candidate = tokens[cursor]
        if candidate not in mapping:
            # Nothing to expand here, move on to the next token
            cursor += 1
            continue
        # Splice the expansion in place of the alias; the cursor stays put so
        # the first inserted token is checked on the next pass
        tokens[cursor:cursor + 1] = split_string(mapping[candidate])

    return tokens

def add_conflicts(df):
    """Flag rows whose filled leading columns are repeated in another row.

    For each row, the columns before its first null value are compared with
    the same columns of every other row. The row is flagged when at least one
    other row matches on all of them. A row whose first null sits in column 0,
    or that has no null at all, is compared on all of its columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to inspect. It is modified in place: a boolean ``'Conflict'``
        column is added (or overwritten).

    Returns
    -------
    pandas.DataFrame
        The same frame, with the ``'Conflict'`` column filled for every row.
    """
    conflicts = []

    for i, row in df.iterrows():
        # Position of the first null value; argmax also yields 0 when the row
        # has no null, in which case every column is compared
        first_null_position = row.isnull().argmax()
        first_none_index = first_null_position if first_null_position > 0 else len(row)

        # Columns that hold this row's filled prefix
        cols_to_check = df.columns[:first_none_index]
        values_to_compare = row[cols_to_check]

        # Other rows (index differs) that agree on every compared column
        mask = (df[cols_to_check] == values_to_compare).all(axis=1) & (df.index != i)
        conflicts.append(bool(mask.any()))

    # Assign after the loop so every row is evaluated; the original returned
    # from inside the loop and so only ever processed the first row
    df['Conflict'] = conflicts
    return df

def resolve_conflicts(dataframe):
    """Write ``"."`` into the first null cell of every row flagged as a conflict.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Frame with a ``'Conflict'`` column; it is modified in place.

    Returns
    -------
    pandas.DataFrame
        The same frame after the replacements.
    """
    for idx, record in dataframe.iterrows():
        # Rows whose flag is falsy are left untouched
        if not record['Conflict']:
            continue
        # Name of the first column (in column order) that is null in this row
        first_empty = next(
            (name for name in dataframe.columns if pd.isnull(record[name])),
            None,
        )
        if first_empty is not None:
            dataframe.at[idx, first_empty] = "."
    return dataframe

# Function to adjust values based on the depth (distance to the first None)
def adjust_values(row, n_trailing_columns=10):
    """Divide a row's ``'Values'`` entry by the depth of its filled hierarchy.

    The depth is the position of the first null value among the leading
    (hierarchy) columns, or the number of hierarchy columns when none of them
    is null.

    Parameters
    ----------
    row : pandas.Series
        One row of the frame; must contain a ``'Values'`` entry.
    n_trailing_columns : int, default 10
        Number of columns at the end of the row that are not hierarchy levels.
        The default matches the original hard-coded slice ``row[:-10]``.

    Returns
    -------
    float
        ``row['Values']`` divided by the depth. A row whose first hierarchy
        column is null has depth 0; that division is not guarded here.
    """
    # Only the leading columns describe the hierarchy
    n_levels = max(len(row) - n_trailing_columns, 0)
    level_values = row.iloc[:n_levels]

    # Fall back to the number of level columns (not the full row width) when
    # every level is filled, so a complete path is divided by its real depth
    depth = next(
        (position for position, value in enumerate(level_values) if pd.isnull(value)),
        len(level_values),
    )

    return row['Values'] / depth

def encode_fullstops(strings):
    """Build a keep-mask for a sequence of labels.

    Parameters
    ----------
    strings : iterable
        Labels to inspect.

    Returns
    -------
    list of int
        One entry per label: ``0`` where the label equals ``'.'``, ``1``
        everywhere else.
    """
    # 0 marks a full stop, 1 marks any other label
    return [0 if entry == '.' else 1 for entry in strings]

def filter_elements(original_list, encoded_vector):
    """Keep the elements whose matching flag equals 1.

    Parameters
    ----------
    original_list : sequence
        Elements to filter.
    encoded_vector : sequence
        Flags, one per element.

    Returns
    -------
    list
        Elements of ``original_list`` whose flag is ``1``, in original order.

    Raises
    ------
    ValueError
        If the two sequences differ in length.
    """
    # Refuse mismatched inputs instead of silently truncating in zip
    if len(original_list) != len(encoded_vector):
        raise ValueError("Both lists must be of the same length.")

    kept = []
    for element, flag in zip(original_list, encoded_vector):
        if flag == 1:
            kept.append(element)

    return kept

In [3]:
# Load the alias table from disk
with open('alias_key.json', 'r') as alias_file:
    mapping = json.load(alias_file)

# List-valued entries are flattened to one bracketed, comma-separated string;
# every other value is kept as loaded
mapping = {
    alias: ('[' + ', '.join(target) + ']') if isinstance(target, list) else target
    for alias, target in mapping.items()
}

In [4]:
# Short display names for the raw prediction columns
column_mapping = {
    '7-mer_FFP_remove': 'FFP',
    '7-mer_remove': 'K-mer',
    '7-spaced_remove': 'Spaced',
    'ACS_remove': 'ACS',
    'DSP_replace_real': 'GSP',
    'FCGR_remove_256': 'FCGR',
    'Mash_distance_remove_21': 'Mash',
    'RTD_7-mer_remove': 'RTD'
}

# Read the predictions, drop the FCGR_remove_128 column and apply the short names
predictions = (
    pd.read_csv('../04_feature-evaluation/predictions/random_forest_Testing.csv')
    .drop(columns=["FCGR_remove_128"])
    .rename(columns=column_mapping)
)

In [5]:
# One row per distinct target class, in order of first appearance
true_labels = predictions['Original_Targets']
model_accuracies = pd.DataFrame(index=true_labels.unique())

# Every column except 'Original_Targets' is treated as a model's predictions
for model in predictions.columns.drop(['Original_Targets']):
    # Per-class accuracy: share of the class's rows that the model predicted as that class
    model_accuracies[model] = [
        (predictions.loc[true_labels == target, model] == target).mean()
        for target in model_accuracies.index
    ]

# Append the number of rows per class next to the accuracies
lineage_df = pd.concat([model_accuracies, true_labels.value_counts()], axis=1)

In [6]:
# Expand the aliases in every lineage name and re-join the tokens into one string
expanded_lineages = [
    "".join(split_and_replace(lineage, mapping))
    for lineage in lineage_df.index
]

lineage_df.index = expanded_lineages

# Keep only lineages whose expanded name contains neither '[' nor '*'
has_bracket_or_star = lineage_df.index.astype(str).str.contains(r"\[|\*")
simple_lineages = lineage_df[~has_bracket_or_star]

In [None]:
# Number of hierarchy columns ('Level 1' ... 'Level 14'); a lineage with more
# dot-separated parts than this cannot be represented and is left out
MAX_LEVELS = 14
level_columns = [f'Level {n}' for n in range(1, MAX_LEVELS + 1)]
feature_columns = ["FFP", "K-mer", "Spaced", "ACS", "GSP", "FCGR", "Mash", "RTD"]

# Filter the frame itself (not just the names) so the path rows and the
# value/feature columns below always have the same length
fits_hierarchy = [item.count('.') <= MAX_LEVELS - 1 for item in simple_lineages.index]
plot_lineages = simple_lineages[fits_hierarchy]

paths = plot_lineages.index.tolist()
# Split each lineage on '.' and pad with None up to MAX_LEVELS entries
split_strings = [s.split('.') + [None] * (MAX_LEVELS - 1 - s.count('.')) for s in paths]

# Creating a DataFrame from the split strings
df = pd.DataFrame(split_strings, columns=level_columns)
# NOTE(review): a lineage that occurs once gets log(1) = 0 here — confirm that is intended
df['Values'] = np.log(plot_lineages["count"].tolist())
# Alternative weighting (every lineage equal): df['Values'] = np.ones(len(df))
for feature in feature_columns:
    df[feature] = plot_lineages[feature].tolist()

# Column order matters from here on: adjust_values expects the hierarchy
# columns first, followed by 'Values', the eight features and 'Conflict'
df = add_conflicts(df.copy())
df = resolve_conflicts(df.copy())
df['Values'] = df.apply(adjust_values, axis=1)

features = ["FFP", "K-mer", "Spaced", "ACS", "GSP", "FCGR", "Mash", "RTD"]
num_features = len(features)

# Hierarchy columns used as the sunburst path: 'Level 1' ... 'Level 14'
level_path = [f'Level {depth}' for depth in range(1, 15)]

# One "domain"-type subplot per feature, stacked vertically without spacing
fig = make_subplots(
    rows=num_features,
    cols=1,
    specs=[[{"type": "domain"}] for _ in range(num_features)],
    vertical_spacing=0,
)

for i, feature in enumerate(features, start=1):
    # The colour of each sector is driven by the current feature's column
    df["Color"] = df[feature]

    # Let Plotly Express compute the hierarchy, then take its single trace apart
    px_trace = px.sunburst(df, path=level_path, values='Values', color="Color").data[0]

    # Nodes labelled '.' get flag 0 and are dropped from every parallel array
    keep_flags = encode_fullstops(px_trace['labels'].tolist())
    kept = {
        field: filter_elements(px_trace[field].tolist(), keep_flags)
        for field in ('labels', 'parents', 'values', 'ids')
    }
    kept_colors = filter_elements(px_trace.marker.colors, keep_flags)

    # Shared colour settings: fixed 0-1 range, colour bar only on the last subplot
    marker_settings = dict(
        colors=kept_colors,
        colorscale='Blues',
        cmin=0,
        cmax=1,
        showscale=(i == num_features),
    )

    # Add the filtered sunburst to this feature's subplot row
    fig.add_trace(
        go.Sunburst(
            labels=kept['labels'],
            parents=kept['parents'],
            values=kept['values'],
            ids=kept['ids'],
            branchvalues='total',
            insidetextorientation='radial',
            marker=marker_settings,
        ),
        row=i,
        col=1,
    )

# Size the figure at 1000 px per stacked subplot and remove all outer margins
fig.update_layout(
    height=1000 * num_features,
    width=1000,
    margin={"t": 0, "l": 0, "b": 0, "r": 0},
)

# Label each subplot with its feature name; y is the vertical midpoint of the
# subplot's band in paper coordinates
for position, feature in enumerate(features):
    fig.add_annotation(
        text=feature,
        xref="paper",
        yref="paper",
        x=0.5,
        y=1 - (position + 0.5) / num_features,
        xanchor="center",
        yanchor="bottom",
        showarrow=False,
        font={"size": 14},
    )

fig.show()