# Data Distribution for Multiclass Classification

In [None]:
import re

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
# Load the multiclass dataset from a path relative to the notebook's working directory.
# Later cells read the "m1", "name", "family" and "y1" columns of this frame.
data = pd.read_csv("../data/iphos_multiclass.csv")
# Last expression of the cell: shows a preview of the first rows.
data.head()

In [None]:
def get_number_of_tail(lipid_string, lipid_name):
    """Count the carbon-chain runs found in ``lipid_string``.

    A run is a maximal stretch of four or more consecutive ``C`` characters.
    When ``lipid_name`` contains the substring ``"25A"`` the count is reduced
    by one.

    Args:
        lipid_string: Text to scan for ``C`` runs.
        lipid_name: Name checked for the ``"25A"`` substring.

    Returns:
        The number of runs, minus one for ``"25A"`` names (so the result can
        be ``-1`` when such a name has no run at all).
    """
    run_count = sum(1 for _ in re.finditer(r"CCCC+", lipid_string))
    correction = 1 if "25A" in lipid_name else 0
    return run_count - correction


def get_count_charged_ion(lipid_string):
    """Count the ``O`` characters that are directly followed by one or more ``-``.

    Args:
        lipid_string: Text to scan.

    Returns:
        The number of non-overlapping ``O-`` / ``O--`` / ... occurrences.
    """
    charged_oxygen = re.compile(r"O\-+")
    return sum(1 for _ in charged_oxygen.finditer(lipid_string))

In [None]:
def _tails_for_row(row):
    """Return the tail count for one row, using its "m1" and "name" fields."""
    return get_number_of_tail(row["m1"], row["name"])


# Add the two derived feature columns, computed row by row from the "m1" string.
data.loc[:, "n_tail"] = data.apply(_tails_for_row, axis=1)
data.loc[:, "n_zwitterion"] = data["m1"].apply(get_count_charged_ion)

In [None]:
# Number of lipids for every (family, n_tail, n_zwitterion) combination.
aggregated_data = (
    data.groupby(["family", "n_tail", "n_zwitterion"]).size().rename("count").reset_index()
)

# Family x class contingency table: one row per family, one column per "y1" value.
# fill_value=0 matters: without it, a class that never occurs in a family becomes
# NaN, and that NaN would propagate into the bar heights and the stacked-bar
# "bottom" offsets that are built from these columns.
count_df = data.groupby(["family", "y1"]).size().unstack(fill_value=0)

# Convert each row to percentages so that every family sums to 100.
percentage_df = count_df.div(count_df.sum(axis=1), axis=0) * 100

# Flatten the index so "family" is an ordinary column next to the class columns.
percentage_df_pd = percentage_df.reset_index()

In [None]:
# Set styles. The calls run in order, so a later call takes precedence wherever
# it touches the same settings as an earlier one.
plt.style.use(["seaborn-v0_8-paper", "seaborn-v0_8-whitegrid"])
plt.style.use(["seaborn-v0_8"])
sns.set(palette="colorblind")

# One x-axis category, and one table column, per lipid family.
labels = ["Family 0", "Family 1", "Family 2", "Family 3", "Family 4", "Family 5", "Family 6"]
columns = tuple(labels)

# Per-family values shown in the table below the chart.
# NOTE(review): this assumes aggregated_data holds exactly one row per family,
# in the same order as `labels` - confirm against the data.
a = aggregated_data["n_zwitterion"].to_list()
b = aggregated_data["n_tail"].to_list()
c = aggregated_data["count"].to_list()
# Row order must follow rowLabels in plt.table below: total count, tails,
# zwitterions. (The previous order [c, a, b] showed the zwitterion counts under
# "Number of tails" and the tail counts under "Number of zwitterions".)
df = [c, b, a]

# Not used by this figure; kept defined in case other cells reference them.
bar_width = 0.20
index = np.arange(len(labels))

# Background colour of the three table row headers.
colors = ["#a5c4b1"] * 3
# One colour per class 0..3, from light to dark.
color_bar = ["#fff9c9", "#efda6d", "#b64a47", "#754242"]

# Create plots with pre-defined labels.
fig, ax = plt.subplots()

# Percentage of each class per family (columns are keyed by the class value).
value0 = percentage_df_pd[0].to_list()
value1 = percentage_df_pd[1].to_list()
value2 = percentage_df_pd[2].to_list()
value3 = percentage_df_pd[3].to_list()

# Running totals that give the base of the third and fourth stacked segments.
bottom2 = [low + mid for low, mid in zip(value0, value1)]
bottom3 = [base + high for base, high in zip(bottom2, value2)]

ax.bar(labels, value0, color=color_bar[0], label="0")
ax.bar(labels, value1, color=color_bar[1], label="1", bottom=value0)
ax.bar(labels, value2, color=color_bar[2], label="2", bottom=bottom2)
ax.bar(labels, value3, color=color_bar[3], label="3", bottom=bottom3)

legend = ax.legend(
    loc="best",
    bbox_to_anchor=(-1.1, 0.9, 1, 0.2),
    title="RLU activity",
)

# Table under the axes; its columns line up with the bars, which is why the
# x tick labels are removed further down.
plt.table(
    cellText=df,
    rowLabels=[
        " Total Number of lipids ",
        " Number of tails ",
        " Number of zwitterions",
    ],
    rowColours=colors,
    colLabels=columns,
    loc="bottom",
    bbox=[0, -0.3, 1, 0.2],
)

plt.ylabel("%")
plt.xticks([])
plt.show()