# Post processing workflow

## Imports and variables

In [None]:
from multiprocessing import Lock, Process, set_start_method, cpu_count, Value
import pandas as pd
import gc
import time
from IPython.display import clear_output
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats
import os
import numpy as np
import re
from typing import List

# set_start_method('spawn', force=True)

# Process-shared integer flag (typecode 'i'). A worker sets it to 1 after a
# failure, and workers that start afterwards skip their chunk when it is 1.
error_flag = Value('i', 0)

# Input files consumed by the sections below.
merged_dataset_path, merged_metadata_path = (
    "merged_dataset.pq",
    "merged_metadata.pq",
)

## Multiprocessing part (Legacy)

In [None]:
# Number of column chunks the dataset is split into (one worker Process per chunk).
amount_of_processes = 30

# Upper bound on how many worker processes may run at the same time.
max_amount_of_processes = 5

# The worker relies on the module-level names defined below, so this cell must
# run before workers are started. Only the header row is parsed (nrows=0).
# NOTE(review): merged_dataset_path ends in ".pq" but is parsed as CSV here;
# this legacy section presumably predates the switch to parquet — confirm.
temp_df_header = pd.read_csv(
    merged_dataset_path,
    sep=",",
    header=0,
    nrows=0,
)
master_full_column_names = list(temp_df_header.columns)

# Column 0 is used as the row index; all remaining columns hold gene counts.
index_col_name = master_full_column_names[0]
gene_count_column_names = master_full_column_names[1:]

In [None]:
def worker(start: int, end: int, increment: int, processed_files_output_path: str = "data/dataset_parts") -> None:
    """Convert one slice of dataset columns to uint32 and pickle it.

    Reads the columns at positions ``start:end`` of
    ``[index_col_name] + gene_count_column_names`` from ``merged_dataset_path``.
    The index column is always read as well, so every part shares the same row
    index. The counts are cast to ``np.uint32`` and written to
    ``{processed_files_output_path}/processed_{increment}.pkl``.

    Relies on the module-level names ``error_flag``, ``merged_dataset_path``,
    ``index_col_name`` and ``gene_count_column_names``.

    Args:
        start: First column position (inclusive); position 0 is the index column.
        end: Column position to stop at (exclusive).
        increment: Chunk number used in the output file name.
        processed_files_output_path: Output directory; created if missing.

    Errors are not raised. They are printed, and ``error_flag`` is set to 1 so
    that workers started afterwards skip their chunk.
    """
    try:
        if error_flag.value == 1:
            print(f"Error flag is set, skipping processing for columns {start} to {end}.")
            return

        column_names = [index_col_name] + gene_count_column_names

        # Only the first chunk's slice contains the index column (position 0).
        # Take the gene columns of this slice and add the index column
        # explicitly. Otherwise pandas would promote the first *gene* column of
        # every later chunk to the row index.
        chunk_gene_columns = [c for c in column_names[start:end] if c != index_col_name]

        if not chunk_gene_columns:
            print(f"No gene columns in range {start} to {end}, skipping.")
            return

        print(f"Processing columns {start} to {end}...")

        # NOTE(review): merged_dataset_path ends in ".pq" but is parsed as CSV;
        # this legacy path presumably predates the switch to parquet — confirm.
        merged_dataset_df = pd.read_csv(
            merged_dataset_path,
            usecols=[index_col_name] + chunk_gene_columns,
            index_col=index_col_name,
            header=0,
            sep=",",
        )

        print("Convert count data to integer type...")
        merged_dataset_df = merged_dataset_df.astype(np.uint32)

        print("Writing processed chunk...")

        # exist_ok avoids a race: several workers may create the directory at once.
        os.makedirs(processed_files_output_path, exist_ok=True)

        output_path = os.path.join(processed_files_output_path, f"processed_{increment}.pkl")

        merged_dataset_df.to_pickle(output_path)

    except Exception as e:
        print(f"Error processing columns {start} to {end}: {e}")

        error_flag.value = 1
    finally:
        gc.collect()

In [None]:
def workers_manager() -> None:
    """Split the merged dataset's columns into chunks and process them in parallel.

    Creates ``amount_of_processes`` worker processes. Each handles a contiguous
    range of ``ceil(len(master_full_column_names) / amount_of_processes)``
    column positions. At most ``max_amount_of_processes`` run at once.
    Relies on the module-level names set up in the cells above.

    Success is only reported when no worker set ``error_flag`` and every
    worker exited with code 0.

    Raises:
        FileNotFoundError: If the merged dataset or metadata file is missing.
        ValueError: If the column list is empty (chunk size would be zero) or
            the chunks do not cover all columns.
    """
    print("Starting workers manager...")

    error_flag.value = 0

    if not os.path.exists(merged_dataset_path):
        raise FileNotFoundError(f"The merged dataset file '{merged_dataset_path}' does not exist.")

    if not os.path.exists(merged_metadata_path):
        raise FileNotFoundError(f"The merged metadata file '{merged_metadata_path}' does not exist.")

    total_columns = len(master_full_column_names)

    # ceil division to ensure all columns are processed
    chunk_size = -(-total_columns // amount_of_processes)

    if chunk_size == 0:
        raise ValueError("Chunk size is zero. Please increase the number of processes or reduce the dataset size.")

    print(f"Chunk size: {chunk_size}")

    # Task i covers column positions [i * chunk_size, (i + 1) * chunk_size).
    tasks_to_process = [
        Process(target=worker, args=(i * chunk_size, (i + 1) * chunk_size, i))
        for i in range(amount_of_processes)
    ]
    total_columns_allocated = chunk_size * amount_of_processes

    if total_columns_allocated < total_columns:
        raise ValueError(f"Total columns allocated ({total_columns_allocated}) is less than total columns ({total_columns}). Please adjust the chunk size or amount of processes.")
    print(f"Total columns allocated ({total_columns_allocated}) is sufficient for total columns ({total_columns}). Proceeding with processing.")

    running_processes = []

    for i, process in enumerate(tasks_to_process):
        # Wait for a free slot; is_alive() returns False once a process has finished.
        running_processes = [p for p in running_processes if p.is_alive()]
        while len(running_processes) >= max_amount_of_processes:
            time.sleep(0.1)  # avoid busy waiting
            running_processes = [p for p in running_processes if p.is_alive()]

        process.start()
        running_processes.append(process)

        print(f"Started process {i + 1}/{len(tasks_to_process)}. Total running processes: {len(running_processes)}")

    failed_processes = []
    for process in tasks_to_process:
        process.join()
        print(f"Process {process.pid} has finished.")
        if process.exitcode != 0:
            failed_processes.append(process)

    print("All processes have completed.")

    gc.collect()

    # Workers catch their own exceptions and only set error_flag. A worker
    # that is terminated without raising exits with a non-zero code and
    # does not set the flag, so check both before reporting success.
    if error_flag.value == 1 or failed_processes:
        print(f"Processing failed: error flag = {error_flag.value}, failed process ids = {[p.pid for p in failed_processes]}.")
        return

    print("Normalization completed successfully.")
    print("Exiting...")

In [None]:
# Guard the multiprocessing entry point. With the 'spawn' start method, child
# processes re-import the main module, and an unguarded call would launch
# workers again inside every child. In a notebook, __name__ is "__main__".
if __name__ == "__main__":
    workers_manager()

## Full post processing part

In [None]:
def full_post_process_no_chunks(
        output_path: str = "pydeseq_output",
        cpu_amount: int = 4) -> None:
    """Run PyDESeq2 on the whole merged dataset and save VST-transformed counts.

    Steps:
      1. Load the merged counts and metadata parquet files.
      2. Keep only the samples (index labels) present in both, in the same order.
      3. Fit a ``DeseqDataSet`` with design ``~condition``.
      4. Write the VST counts to ``{output_path}/vst_transformed_counts.parquet``.

    Args:
        output_path: Directory for the output file; created if missing.
        cpu_amount: Number of CPUs passed to ``DeseqDataSet`` as ``n_cpus``.

    Errors are printed rather than raised.
    """
    try:
        merged_df = pd.read_parquet(merged_dataset_path)
        merged_metadata = pd.read_parquet(merged_metadata_path)

        common_samples = merged_df.index.intersection(merged_metadata.index)

        # Indexing both frames with the same labels gives identical row order.
        filtered_merged_df = merged_df.loc[common_samples]
        filtered_merged_metadata = merged_metadata.loc[common_samples]

        # The unfiltered frames are no longer needed; drop the references so
        # the collection below can release their memory.
        del merged_df, merged_metadata

        print(f"Final merged dataframe shape: {filtered_merged_df.shape}")
        print(f"Final merged metadata shape: {filtered_merged_metadata.shape}")

        if filtered_merged_df.empty or filtered_merged_metadata.empty:
            print("Merged dataframe or metadata is empty. Exiting.")
            return

        os.makedirs(output_path, exist_ok=True)

        result = gc.collect()

        print(f"Garbage collector cleaned up {result} unreachable objects after merging dataframes.")

        try:
            # Use the sample-aligned frames so counts and metadata describe
            # exactly the same samples.
            dds = DeseqDataSet(
                counts=filtered_merged_df,
                metadata=filtered_merged_metadata,
                design="~condition",
                n_cpus=cpu_amount,
            )

            dds.deseq2()

            dds.vst()

            vst_transformed_counts_df = pd.DataFrame(dds.layers['vst_counts'], index=dds.obs.index, columns=dds.var.index)

            vst_transformed_counts_df.to_parquet(os.path.join(output_path, "vst_transformed_counts.parquet"))

            print("VST transformed counts saved successfully.")

            result = gc.collect()

            print(f"Garbage collector cleaned up {result} unreachable objects after VST transformation.")

        except Exception as e:
            print(f"Error during DeseqDataSet processing: {e}")

    except Exception as e:
        print(f"Error during full post-processing: {e}")

In [None]:
# Use all but 8 CPUs. Never request fewer than one: cpu_count() - 8 is zero or
# negative on machines with 8 or fewer cores.
full_post_process_no_chunks(cpu_amount=max(1, cpu_count() - 8))

## Merge functions

In [None]:
def merge_pydeseq_results(output_path: str = "data/pydeseq_output") -> None:
    """Concatenate chunked PyDESeq2 result pickles into one file.

    Reads every ``deseq_result_chunks_*.pkl`` file in ``output_path``. Files
    are ordered by ascending chunk number; files without a number come last,
    by name. The frames are stacked row-wise and the result is written to
    ``final_merged_deseq_results.pkl`` in the same directory. That file does
    not match the input prefix, so re-running does not merge it into itself.

    Errors (e.g. a missing directory) are printed rather than raised.
    """
    def _chunk_sort_key(file_name: str) -> tuple:
        # os.listdir order is arbitrary. Sort numbered chunks numerically
        # (so 10 comes after 2) to make the concatenation order deterministic.
        match = re.search(r'deseq_result_chunks_(\d+)', file_name)
        if match:
            return (0, int(match.group(1)), file_name)
        return (1, 0, file_name)

    try:
        result_files = sorted(
            (f for f in os.listdir(output_path) if f.startswith('deseq_result_chunks_') and f.endswith('.pkl')),
            key=_chunk_sort_key,
        )

        if not result_files:
            print("No result files found to merge.")
            return

        merged_results = []

        for file in result_files:
            file_path = os.path.join(output_path, file)
            print(f"Reading result file: {file_path}")
            merged_results.append(pd.read_pickle(file_path))

        final_merged_df = pd.concat(merged_results, axis=0)
        final_output_file = os.path.join(output_path, "final_merged_deseq_results.pkl")
        final_merged_df.to_pickle(final_output_file)
        print(f"Final merged results saved to {final_output_file}")

    except Exception as e:
        print(f"Error while merging pydeseq results: {e}")

In [None]:
# Merge the chunked DESeq2 result pickles found in data/pydeseq_output.
# NOTE(review): full_post_process_no_chunks writes a parquet file to
# "pydeseq_output", not deseq_result_chunks_*.pkl files to this directory —
# confirm which run produces the inputs for this merge.
merge_pydeseq_results(output_path="data/pydeseq_output")

In [None]:
def merge_dataset_parts(output_path: str = "data/dataset_parts", output_file: str = "data/merged_dataset.pkl") -> None:
    """Join the per-chunk ``processed_<n>.pkl`` pickles into one dataset.

    Parts in ``output_path`` are read in ascending chunk-number order ``n``
    and concatenated column-wise. The result is pickled to ``output_file``,
    whose directory is created if missing.

    Files matching ``processed_*.pkl`` but carrying no chunk number are
    skipped with a message. Previously such a file aborted the whole merge.

    Errors are printed rather than raised.
    """
    chunk_number_pattern = re.compile(r'processed_(\d+)')

    try:
        numbered_parts = []
        for file_name in os.listdir(output_path):
            if not (file_name.startswith('processed_') and file_name.endswith('.pkl')):
                continue
            match = chunk_number_pattern.search(file_name)
            if match is None:
                print(f"Skipping dataset part without a chunk number: {file_name}")
                continue
            numbered_parts.append((int(match.group(1)), file_name))

        # Sort numerically so that e.g. processed_10 follows processed_2.
        part_files = [file_name for _, file_name in sorted(numbered_parts)]

        if not part_files:
            print("No dataset parts found to merge.")
            return

        merged_parts = []

        for file in part_files:
            file_path = os.path.join(output_path, file)
            print(f"Reading dataset part file: {file_path}")

            merged_parts.append(pd.read_pickle(file_path))

        final_merged_df = pd.concat(merged_parts, axis=1)

        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        final_merged_df.to_pickle(output_file)
        print(f"Final merged dataset saved to {output_file}")

    except Exception as e:
        print(f"Error while merging dataset parts: {e}")

In [None]:
# Join the per-chunk pickles that the legacy workers wrote to data/dataset_parts.
merge_dataset_parts(output_path="data/dataset_parts", output_file="data/merged_dataset.pkl")