In [1]:
# -*- coding: utf-8 -*-
#
# Copyright (C) 2021-2024  LMAI_team @ TU Dresden:
#     LMAI_team: Zhixu Ni, Maria Fedorova
#
# Licensing:
# This code is licensed under AGPL-3.0 license (Affero General Public License v3.0).
# For more information, please read:
#     AGPL-3.0 License: https://www.gnu.org/licenses/agpl-3.0.en.html
#
# Citation:
# Please cite our publication in an appropriate form.
#
# For more information, please contact:
#     Fedorova Lab (#LMAI_team): https://fedorovalab.net/
#     LMAI on Github: https://github.com/LMAI-TUD
#
import os
import time

import pandas as pd

from lmai.cluster import add_cluster_to_raw_data, create_normalized_data, plot_sub_trend_lines, run_cluster
from lmai.pretreatment import calc_avg, get_lipid_class_info, load_file, replace_min

This notebook clusters lipids based on their abundances and plots the resulting trend lines for each cluster.
The following clustering algorithms are supported:

- `hew` (Hierarchical clustering with Ward linkage)
- `km` (KMeans)
- `bisect_km` (Bisecting KMeans)
- `gmm` (Gaussian Mixture Model)
- `dpgmm` (Dirichlet Process Gaussian Mixture Model)

*The clustering algorithms above are implemented with the `scikit-learn` library.*

In [2]:
# Relative input/output paths used in the cells below are resolved against this directory.
cwd = os.getcwd()
print("Current working directory: " + cwd)

Current working directory: /Users/ni/PycharmProjects/LipidTrends


In [3]:
# Paths of the input and output files, either relative to the current working
# directory or absolute.
# Index/row/column names must be unique: check the input files for duplicated
# lipid/sample names before running the notebook.
input_data_matrix_csv_path = "data/data_matrix.csv"
# Column of the data matrix that holds the lipid class; default "lipid_class"
lipid_class_col_name = "lipid_class"
# Metadata file describing the samples
input_metadata_csv_path = "data/metadata.csv"
# Column names in the metadata file:
#   id_col_name:     sample ID, matching the column names of the data matrix
#   sample_col_name: sample information
#   group_col_name:  group information
id_col_name = "column_id"
sample_col_name = "sample"
group_col_name = "group"
# Folder for the output files
output_folder = "output"
# Name of the task; it becomes part of every output file name.
# Keep it short, project related and without spaces.
task_name = "TAV_trends"
assert " " not in task_name, f"task_name must not contain spaces: {task_name!r}"

In [4]:
# Define the path for files during the data processing
# Folder for the intermediate tables written while processing the data
output_processed_files_folder = "output/processed_data"

In [5]:
# Colour settings for the lipid classes, read from a JSON file
# (default file: "config/lipid_class_colors_cluster.json").
cfg_color_json_path = "config/lipid_class_colors_cluster.json"
# The JSON file can hold several colour profiles; this run uses "sub_class_level".
cfg_color_level = "sub_class_level"

**Important note:**

Because of how the `km`, `bisect_km`, `gmm`, and `dpgmm` clustering algorithms work, the number of clusters `k` must be specified before the analysis.

You can generate a series of plots with different `k` values to help decide on the optimal number of clusters (e.g. set `cfg_cluster_count_lst = [4, 5, 8, 10]`).

*Many of the clustering algorithms require a random state to initialize the clustering.*

*Depending on the library versions and the local system setup, the results may vary slightly between runs.*

You can set the random state to improve the reproducibility of the results.
However, slightly different cluster results may still be obtained.

*Please keep this in mind when interpreting the results; we recommend using the same device and environment for the whole analysis.*

+ Please keep using the same device and environment for the whole analysis.

+ Please keep all parameters and the exported intermediate files as a record of the analysis.

+ You can also copy the whole folder for each new dataset to keep the analysis records organized by task/project.

In [6]:
# Parameters for the trend analysis.
# Missing values are replaced with 1/cfg_min_ratio of the smallest value of the
# same lipid (see the replace_min step below).
cfg_min_ratio = 5
# Values passed to load_file as missing-value markers
cfg_na_values = ["NA", 0]
cfg_zero_is_na = True
# Cluster counts for the trend analysis; several counts can be compared, e.g. [4, 5, 8, 10]
# Default cluster count: [5]
# Acceptable value range: from 3 to 16
cfg_cluster_count_lst = [5]
# Available normalization modes: "zscore", "log2", "minmax"
# Default normalization mode: "zscore"
normalization_mode = "zscore"
# Whether to also plot the raw (unscaled) values next to the normalized plot
cfg_plot_raw_val_img = True
# Cluster methods to run
# built-in methods: ["hew", "km", "bisect_km", "gmm", "dpgmm"]
# default: ["gmm"] to plot the gmm only; use e.g. ["gmm", "km"] to plot both gmm and k-means
cfg_cluster_methods_lst = ["gmm"]
# Lipid-class subgroups used in the trend plots of each cluster.
# NOTE(review): "PythoCer" may be a typo for "PhytoCer"; it must match the class
# names used in the data matrix exactly — confirm against the input file.
cfg_sub_groups = [
    ["TG", "DG"],
    ["CE", "ST"],
    ["LPC", "O-LPC", "P-LPC", "LPE", "O-LPE", "P-LPE"],
    ["PC", "P-PC"],
    ["O-PC"],
    ["PE", "O-PE"],
    ["P-PE"],
    ["PG", "CL"],
    ["PS"],
    ["PI"],
    ["SM"],
    ["Cer", "DihydroCer", "DeoxyCer", "PythoCer"],
    ["HexCer", "Hex2Cer", "GM3"],
]
# Random state for the cluster methods, used for reproducibility (42 by convention).
cfg_random_state = 42
pd.set_option('display.precision', 9)

# Fail fast on unsupported settings instead of deep inside the clustering loop.
_supported_normalization_modes = ("zscore", "log2", "minmax")
_supported_cluster_methods = ("hew", "km", "bisect_km", "gmm", "dpgmm")
assert normalization_mode in _supported_normalization_modes, (
    f"Unsupported normalization_mode {normalization_mode!r}; "
    f"choose one of {_supported_normalization_modes}"
)
assert cfg_cluster_methods_lst, "cfg_cluster_methods_lst must contain at least one method"
assert all(m in _supported_cluster_methods for m in cfg_cluster_methods_lst), (
    f"Unsupported cluster method in {cfg_cluster_methods_lst}; "
    f"choose from {_supported_cluster_methods}"
)
assert cfg_cluster_count_lst, "cfg_cluster_count_lst must contain at least one cluster count"
assert all(isinstance(k, int) and 3 <= k <= 16 for k in cfg_cluster_count_lst), (
    f"Cluster counts must be integers from 3 to 16, got {cfg_cluster_count_lst}"
)

In [7]:
# Load the sample metadata. By default the first column (index 0) is the row index
# and the first row (index 0) is the header; Python counts from 0.
meta_df = load_file(
    input_metadata_csv_path,
    na_values=cfg_na_values,
    index_col=0,
    header_row=0,
    zero_is_na=cfg_zero_is_na,
)
# Sample IDs must be unique.
assert meta_df.index.is_unique, (
    f"Duplicated sample IDs in {input_metadata_csv_path}: "
    f"{meta_df.index[meta_df.index.duplicated()].unique().tolist()}"
)
# Preview the first 5 rows of the metadata
meta_df.head(5)

Unnamed: 0_level_0,sample,group
column_id,Unnamed: 1_level_1,Unnamed: 2_level_1
TAV_md_M1,TAV_md_M1,TAV_md
TAV_md_M2,TAV_md_M2,TAV_md
TAV_md_M3,TAV_md_M3,TAV_md
TAV_md_M4,TAV_md_M4,TAV_md
TAV_md_M5,TAV_md_M5,TAV_md


In [8]:
# Load the lipid data matrix. By default the first column (index 0) is the row index
# and the first row (index 0) is the header; Python counts from 0.
raw_data_df = load_file(
    input_data_matrix_csv_path,
    na_values=cfg_na_values,
    index_col=0,
    header_row=0,
    zero_is_na=cfg_zero_is_na,
)
# Force the row index name to "lipid"
raw_data_df.index.name = "lipid"
# Lipid names and sample names must be unique, and the lipid class column must
# exist because it is split off from the data matrix further below.
assert raw_data_df.index.is_unique, (
    "Duplicated lipid names in the data matrix: "
    f"{raw_data_df.index[raw_data_df.index.duplicated()].unique().tolist()}"
)
assert raw_data_df.columns.is_unique, "Duplicated column names in the data matrix"
assert lipid_class_col_name in raw_data_df.columns, (
    f"Lipid class column {lipid_class_col_name!r} not found in the data matrix"
)
# Preview the first 5 rows of the data matrix
raw_data_df.head(5)

Unnamed: 0_level_0,lipid_class,TAV_md_M1,TAV_md_M2,TAV_md_M3,TAV_md_M4,TAV_md_M5,TAV_md_M6,TAV_md_M7,TAV_md_M8,TAV_fib_M1,...,TAV_fib_F9,TAV_cal_F1,TAV_cal_F2,TAV_cal_F3,TAV_cal_F4,TAV_cal_F5,TAV_cal_F6,TAV_cal_F7,TAV_cal_F8,TAV_cal_F9
lipid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
CE(16:0),CE,76.89781,248.24879,77.54066,223.78758,163.06255,109.25142,144.27042,176.55957,305.77844,...,157.80967,458.85353,432.48957,372.60142,469.58833,241.85936,480.12203,443.32799,531.394,272.79315
CE(17:1),CE,3.03234,13.08532,2.3016,13.49773,5.10193,4.4192,4.08374,8.47118,15.26112,...,6.29534,37.82903,41.76634,27.16446,69.41431,17.83133,65.01035,33.91087,91.51081,16.28586
CE(18:0),CE,6.34884,19.91307,9.19094,31.66378,11.15016,9.28725,13.14816,30.83197,28.71407,...,20.39708,65.38193,59.60298,43.67648,91.46718,40.98188,57.23461,73.60858,103.16381,41.27869
CE(18:1),CE,753.3291,2131.14643,841.1836,2206.06079,1993.85114,1074.62269,1230.50194,1956.49154,3283.17898,...,1567.63573,3547.59147,3523.23242,3158.91682,5223.03566,2358.01061,3457.76759,3505.9096,7622.3388,2252.30703
CE(18:2),CE,3769.9698,9022.82815,4329.66467,7798.26046,7673.81354,5616.3195,5755.19326,8216.77309,10647.37937,...,6791.55712,12833.65302,14037.84846,12515.11522,19816.67078,8874.54669,12921.7641,12224.22211,17202.59953,8844.32966


In [9]:
# Build a lipid -> lipid class dictionary from the lipid class column of the data matrix.
lipid_class_info = get_lipid_class_info(raw_data_df, lipid_class_col_name)
# Show the mapping as a one-column table (lipids as rows) for a quick check
lipid_class_df = pd.DataFrame.from_dict(lipid_class_info, orient="index")
lipid_class_df

Unnamed: 0,0
CE(16:0),CE
CE(17:1),CE
CE(18:0),CE
CE(18:1),CE
CE(18:2),CE
...,...
SM(52:2;3),SM
SM(52:3;3),SM
SM(54:2;3),SM
TG(34:0),TG


In [10]:
# Keep only the numeric abundance columns: drop() returns a new frame, so
# raw_data_df (with its lipid class column) stays untouched.
data_df = raw_data_df.drop(columns=[lipid_class_col_name])
data_df.tail()

Unnamed: 0_level_0,TAV_md_M1,TAV_md_M2,TAV_md_M3,TAV_md_M4,TAV_md_M5,TAV_md_M6,TAV_md_M7,TAV_md_M8,TAV_fib_M1,TAV_fib_M2,...,TAV_fib_F9,TAV_cal_F1,TAV_cal_F2,TAV_cal_F3,TAV_cal_F4,TAV_cal_F5,TAV_cal_F6,TAV_cal_F7,TAV_cal_F8,TAV_cal_F9
lipid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SM(52:2;3),0.06589,0.29591,0.14995,0.3359,0.05786,0.07632,0.1837,0.11535,0.29322,0.07558,...,0.15606,0.51548,1.07765,0.78204,0.73389,0.44349,0.50637,0.66886,0.96119,0.51907
SM(52:3;3),0.02608,0.12548,0.06001,0.14092,0.02137,0.0323,0.07169,0.04319,0.13699,0.03413,...,0.06526,0.27303,0.5361,0.2743,0.29154,0.22308,0.26976,0.22329,0.37978,0.23283
SM(54:2;3),0.02055,0.10457,0.04653,0.12961,0.00691,0.01779,0.06058,0.03997,0.11404,0.0315,...,0.04651,0.13324,0.28612,0.2058,0.22344,0.12422,0.12963,0.21907,0.28685,0.14297
TG(34:0),0.07421,0.13962,0.06828,0.10248,0.57212,0.0663,0.17567,0.06322,0.10204,0.14582,...,0.17763,0.04109,0.06796,0.03104,0.05284,0.04217,0.30702,0.06187,0.03501,0.03532
TG(45:0),0.36209,1.04663,0.27513,0.92328,2.94513,0.44372,0.99336,0.20324,0.61084,0.6629,...,0.74887,0.29516,0.43982,0.30571,0.86655,0.38077,3.40237,0.67771,0.49805,0.27571


In [11]:
# Unix timestamp in whole seconds; it is embedded in every output file name
# to tell the files of different runs apart.
timestamp_str = f"{int(time.time())}"
print(f"Timestamp: {timestamp_str}")

Timestamp: 1722271461


In [12]:
# Impute missing values: each N/A cell is replaced with 1/cfg_min_ratio of the
# smallest value of the same lipid row (each replacement is logged below).
zero_fill_data_df = replace_min(df=data_df, min_value_ratio=cfg_min_ratio, axis=0)
zero_fill_data_df.head()

! Missing value detected: row CE(24:5) column TAV_md_F5 has N/A value: nan.
> Fill this cell with 0.160456.  # 1/5 of the min value in this row 0.80228.

! Missing value detected: row Cer(18:0;1/16:0) column TAV_md_F6 has N/A value: nan.
> Fill this cell with 0.000724.  # 1/5 of the min value in this row 0.00362.

! Missing value detected: row Cer(18:0;1/23:0) column TAV_md_F3 has N/A value: nan.
> Fill this cell with 0.0009159999999999999.  # 1/5 of the min value in this row 0.00458.

! Missing value detected: row Cer(18:1;1/26:1) column TAV_md_F3 has N/A value: nan.
> Fill this cell with 0.0010659999999999999.  # 1/5 of the min value in this row 0.00533.

! Missing value detected: row Cer(18:1;1/26:1) column TAV_md_F6 has N/A value: nan.
> Fill this cell with 0.0010659999999999999.  # 1/5 of the min value in this row 0.00533.

! Missing value detected: row Cer(20:0;1/22:0) column TAV_md_F1 has N/A value: nan.
> Fill this cell with 0.00173.  # 1/5 of the min value in this row 0.00865.

Unnamed: 0_level_0,TAV_md_M1,TAV_md_M2,TAV_md_M3,TAV_md_M4,TAV_md_M5,TAV_md_M6,TAV_md_M7,TAV_md_M8,TAV_fib_M1,TAV_fib_M2,...,TAV_fib_F9,TAV_cal_F1,TAV_cal_F2,TAV_cal_F3,TAV_cal_F4,TAV_cal_F5,TAV_cal_F6,TAV_cal_F7,TAV_cal_F8,TAV_cal_F9
lipid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
CE(16:0),76.89781,248.24879,77.54066,223.78758,163.06255,109.25142,144.27042,176.55957,305.77844,189.67994,...,157.80967,458.85353,432.48957,372.60142,469.58833,241.85936,480.12203,443.32799,531.394,272.79315
CE(17:1),3.03234,13.08532,2.3016,13.49773,5.10193,4.4192,4.08374,8.47118,15.26112,8.87273,...,6.29534,37.82903,41.76634,27.16446,69.41431,17.83133,65.01035,33.91087,91.51081,16.28586
CE(18:0),6.34884,19.91307,9.19094,31.66378,11.15016,9.28725,13.14816,30.83197,28.71407,16.24595,...,20.39708,65.38193,59.60298,43.67648,91.46718,40.98188,57.23461,73.60858,103.16381,41.27869
CE(18:1),753.3291,2131.14643,841.1836,2206.06079,1993.85114,1074.62269,1230.50194,1956.49154,3283.17898,1589.38334,...,1567.63573,3547.59147,3523.23242,3158.91682,5223.03566,2358.01061,3457.76759,3505.9096,7622.3388,2252.30703
CE(18:2),3769.9698,9022.82815,4329.66467,7798.26046,7673.81354,5616.3195,5755.19326,8216.77309,10647.37937,8270.98341,...,6791.55712,12833.65302,14037.84846,12515.11522,19816.67078,8874.54669,12921.7641,12224.22211,17202.59953,8844.32966


In [13]:
# Average the samples of each group (sample -> group assignment from the metadata);
# with keep_original=False only the per-group columns are kept.
avg_kwargs = dict(
    group_col=group_col_name,
    sample_col=sample_col_name,
    axis=0,
    keep_original=False,
)
avg_data_df = calc_avg(df=zero_fill_data_df, meta=meta_df, **avg_kwargs)
avg_data_df.head()

Unnamed: 0_level_0,TAV_md,TAV_fib,TAV_cal
lipid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
CE(16:0),133.738617143,339.373004211,376.305012778
CE(17:1),6.14271,33.369036842,33.292041111
CE(18:0),13.016495714,63.187569474,54.449917778
CE(18:1),1359.724645714,3605.235050526,3430.363438333
CE(18:2),5904.574847143,13702.792135263,12155.114256111


In [14]:
# Save the intermediate tables of this run; file names carry the task name and timestamp.
if output_processed_files_folder:
    # DataFrame.to_csv does not create missing parent folders
    os.makedirs(output_processed_files_folder, exist_ok=True)
zero_fill_data_csv_path = os.path.join(output_processed_files_folder, f"{task_name}_zero_fill_data_{timestamp_str}.csv")
avg_data_csv_path = os.path.join(output_processed_files_folder, f"{task_name}_avg_data_{timestamp_str}.csv")
zero_fill_data_df.to_csv(zero_fill_data_csv_path)
# Fixed 9-decimal format; the next cell reads this file back for normalization
avg_data_df.to_csv(avg_data_csv_path, float_format='%.9f')

if os.path.exists(zero_fill_data_csv_path):
    print(f"Save zero fill data to {zero_fill_data_csv_path}")
if os.path.exists(avg_data_csv_path):
    print(f"Save average data to {avg_data_csv_path}")

Save zero fill data to output/processed_data/TAV_trends_zero_fill_data_1722271461.csv
Save average data to output/processed_data/TAV_trends_avg_data_1722271461.csv


In [15]:
print(f"Scale data using mode: {normalization_mode}")
# Normalize the group averages as stored on disk by the previous cell
# (presumably so the scaling uses the same 9-decimal values that were saved).
avg_data_csv_df = load_file(avg_data_csv_path, index_col=0, header_row=0)
scaled_data_df = create_normalized_data(
    avg_data_csv_df,
    mode=normalization_mode,
)
scaled_data_df.head()

Scale data using mode: zscore


Unnamed: 0_level_0,TAV_md,TAV_fib,TAV_cal
lipid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
CE(16:0),-1.400026475,0.526968728,0.873057747
CE(17:1),-1.414209309,0.710108427,0.704100882
CE(18:0),-1.395299665,0.897284835,0.49801483
CE(18:1),-1.410744185,0.79110793,0.619636255
CE(18:2),-1.389150973,0.924128707,0.465022266


In [16]:
# Store the normalized table next to the other intermediate files of this run
scaled_data_file_name = f"{task_name}_scaled_data_{timestamp_str}.csv"
scaled_data_csv_path = os.path.join(output_processed_files_folder, scaled_data_file_name)
scaled_data_df.to_csv(scaled_data_csv_path)
if os.path.exists(scaled_data_csv_path):
    print(f"Save scaled data to {scaled_data_csv_path}")

Save scaled data to output/processed_data/TAV_trends_scaled_data_1722271461.csv


In [17]:
# Reload the scaled data written by the previous cell, so the clustering uses exactly
# the table saved for this run (same save/reload round trip as the average data).
# Previously this loaded a hard-coded file from an unrelated run, which silently
# replaced the freshly computed data.
scaled_data_df = load_file(scaled_data_csv_path, index_col=0, header_row=0)

In [18]:
# Cluster the scaled data with every configured method and cluster count, plot the
# trend lines and save the cluster assignments. Output files are named
#   <output_folder>/<task>_S-<scaling>_M-<method>_C-<k>_<timestamp>.{csv,png}
# plus "_raw" variants based on the unscaled group averages when cfg_plot_raw_val_img is set.
if output_folder:
    # DataFrame.to_csv does not create missing parent folders
    os.makedirs(output_folder, exist_ok=True)

for cluster_method in cfg_cluster_methods_lst:
    for cfg_cluster_count in cfg_cluster_count_lst:
        print(
            f"Running clusters with {cfg_cluster_count} cluster methods {cluster_method}"
        )
        test_clustered_data = run_cluster(
            scaled_data_df,
            k=cfg_cluster_count,
            method=cluster_method,
            random_state=cfg_random_state,
        )
        output_basic_name = os.path.join(
            output_folder,
            f"{task_name}_S-{normalization_mode}_M-{cluster_method}_C-{cfg_cluster_count}_{timestamp_str}",
        )
        test_output_csv = f"{output_basic_name}.csv"
        test_output_png = f"{output_basic_name}.png"

        # Trend-line plot of the normalized values
        plot_sub_trend_lines(
            test_clustered_data,
            test_output_png,
            method=cluster_method,
            lipid_dct=lipid_class_info,
            color_cfg=cfg_color_json_path,
            color_level=cfg_color_level,
            sub_groups=cfg_sub_groups,
        )

        if cfg_plot_raw_val_img:
            test_raw_output_csv = f"{output_basic_name}_raw.csv"
            test_raw_output_png = f"{output_basic_name}_raw.png"

            # Combine the cluster result with the unscaled group averages
            test_raw_clustered_data = add_cluster_to_raw_data(
                raw_data=avg_data_df,
                cluster_data=test_clustered_data,
                method=cluster_method,
            )
            plot_sub_trend_lines(
                test_raw_clustered_data,
                test_raw_output_png,
                method=cluster_method,
                lipid_dct=lipid_class_info,
                color_cfg=cfg_color_json_path,
                color_level=cfg_color_level,
                sync_axis=True,
                sub_groups=cfg_sub_groups,
            )
            test_raw_clustered_data.to_csv(test_raw_output_csv)
            if os.path.exists(test_raw_output_csv):
                print(f"Save raw clustered data to {test_raw_output_csv}")

        # Annotate each lipid with its class before saving the cluster table
        test_clustered_data["lipid_class"] = test_clustered_data.index.map(lipid_class_info)
        test_clustered_data.to_csv(test_output_csv)
        # Report the save only after the file has actually been written
        if os.path.exists(test_output_csv):
            print(f"Save clustered data to {test_output_csv}")

Running clusters with 5 cluster methods gmm
Max cluster: 5
Save clustered data to output/TAV_trends_S-zscore_M-gmm_C-5_1722271461.csv
Max cluster: 5


In [19]:
# Reaching this cell means every step above ran without raising.
print("Finished")

Finished
