## This script visualizes the COG content of prokaryotic genomes

### Required packages

Please run this script in an environment (e.g. conda) that contains the packages below.

- Numpy
- Pandas
- Matplotlib

### Path Configuration

**Please set the following variables to your local paths before running the script:**

- `orthogroups`: Path to the OrthoFinder `Orthogroups.tsv` file
- `orthosingle`: Path to the OrthoFinder `Orthogroups_SingleCopyOrthologues.tsv` file
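- `eggnog_dir`: Directory containing one EggNOG annotation subfolder per sample. Each subfolder must be named after the sample (matching the OrthoFinder column name) and contain a `*.annotations.tsv` file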
- `output_dir`: Desired location for the final output files

**Example usage:**
```python
orthogroups = '/path/to/Orthogroups/Orthogroups.tsv'
orthosingle = '/path/to/Orthogroups/Orthogroups_SingleCopyOrthologues.tsv'
eggnog_dir = '/path/to/eggnog/annotations'
output_dir = '/path/to/output_directory'
```

### Procedure
#### Step1. Read OrthoFinder data
Reads `Orthogroups.tsv` and `Orthogroups_SingleCopyOrthologues.tsv` and concatenates them

#### Step2. Read EggNOG annotations
Reads multiple files inside a directory (i.e., looping over `.annotations.tsv` files)

#### Step3. Combine OrthoFinder data with EggNOG info
Construct `orthogroups_eggnog_df`

#### Step4. Extract COG information and determine consensus COGs

#### Step5. Filter, sort and rearrange data
Create binary dataframe `df_cog_binary`

#### Step6. Visualize
Generate a block plot for the presence/absence data

### Output files

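- `COG_binary.csv`: tab-separated presence/absence (1/0) table with one row per orthogroup. It contains the orthogroup's consensus COG category, the number of samples that contain it (`counts`), and one column per sample in `sample_order`
- `COG_visualization.svg`: block plot of the presence/absence table, in which present cells are colored by the orthogroup's COG category and absent cells are grey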

In [1]:
# ---- User configuration: fill these in before running -------------------
# Paths to the OrthoFinder result tables
orthogroups = ''  # Orthogroups.tsv
orthosingle = ''  # Orthogroups_SingleCopyOrthologues.tsv

# Directory holding one EggNOG result subfolder per sample
eggnog_dir = ''

# Location used for the output files
output_dir = ''

# Order of the sample columns in the output table and the figure
sample_order = [
    'B27',
    'Pal2',
    'Q6',
    'S_platyhelix',
    'S_citri',
    'S_eriocheiris',
    'S_apis',
    'S_culicicola',
    'S_glodiatoris',
    'P_umbrosa_mycoplasma',
    'Thalassoplasma',
    'Oceanoplasma',
    'Spiroplasma_holothuricola',
]

---

In [None]:
## Library import
import collections
import glob
import itertools
import numpy as np
import os
import pandas as pd
# import pprint
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# import tqdm

#### Step1. Read OrthoFinder data
Reads `Orthogroups.tsv` and `Orthogroups_SingleCopyOrthologues.tsv` and concatenates them

In [None]:
"""
Reads and concatenates OrthoFinder data from the provided paths.
Builds a combined DataFrame with orthogroup information.

Args:
    orthogroups (str): Absolute path of OrthoFinder Orthogroups.tsv
    orthosingle (str): Absolute path of OrthoFinder Orthogroups_SingleCopyOrthologues.tsv

Returns:
    ortho_df (DataFrame): Concatenated DataFrame indexed by 'Orthogroup';
        an orthogroup listed in both files is kept only once.
"""

orthogroups_df = pd.read_csv(orthogroups, sep='\t', index_col='Orthogroup')
orthosingle_df = pd.read_csv(orthosingle, sep='\t', index_col='Orthogroup')
ortho_df = pd.concat([orthogroups_df, orthosingle_df])

# NOTE(review): single-copy orthogroups are presumably a subset of
# Orthogroups.tsv. A duplicated orthogroup would be processed twice and show
# up as two rows in the final table/figure, so keep only its first
# occurrence (the Orthogroups.tsv row).
duplicated_rows = ortho_df.index.duplicated(keep='first')
if duplicated_rows.any():
    print(f'Dropping {int(duplicated_rows.sum())} duplicated orthogroup rows')
ortho_df = ortho_df[~duplicated_rows]

#### Step2. Read EggNOG annotations
Reads multiple files inside a directory (i.e., looping over `.annotations.tsv` files)

In [None]:
"""
Reads all EggNOG annotations from eggnog_dir and stores them in a dictionary.

Each <eggnog_dir>/<sample_name>/*.annotations.tsv file is loaded. The literal
'LOCUS' in its 'query' IDs is replaced with the sample name to form the
'GeneCode' index.

Args:
    eggnog_dir (str): Absolute path of directory which contains all EggNOG outputs

Returns:
    coginfo_data (dict): {sample_name -> DataFrame indexed by 'GeneCode' with
        columns 'COG_category', 'Preferred_name', 'Description'}

Raises:
    FileNotFoundError: if no annotation file is found under eggnog_dir
"""

# sorted() makes the load order (and therefore which file wins if a sample
# folder holds several) deterministic
eggnog_output = sorted(glob.glob(os.path.join(eggnog_dir, '*', '*.annotations.tsv')))
if not eggnog_output:
    raise FileNotFoundError(f'No */*.annotations.tsv files found under {eggnog_dir!r}')

coginfo_data = dict()
for file in eggnog_output:
    # The sample name is the name of the subfolder that holds the file
    sample_name = os.path.basename(os.path.dirname(file))
    if sample_name in coginfo_data:
        print(f'{sample_name}: multiple annotation files found, using {file}')

    df = pd.read_csv(file, sep='\t')

    output_df = df[['query', 'COG_category', 'Preferred_name', 'Description']].copy()
    # regex=False: 'LOCUS' is a literal placeholder, not a pattern
    output_df['GeneCode'] = output_df['query'].str.replace('LOCUS', sample_name, regex=False)
    coginfo_data[sample_name] = output_df.set_index('GeneCode')[
        ['COG_category', 'Preferred_name', 'Description']
    ]

#### Step3. Combine OrthoFinder data with EggNOG info
Construct `orthogroups_eggnog_df`

In [None]:
# Empty template table. For every sample there are four columns, in this order:
# <sample>_COG_consensus, <sample>_COG, <sample>_gene, <sample>
samples = list(ortho_df.columns)
topic = ['COG_consensus', 'COG', 'gene', '']
new_header = [
    f'{head}_{suffix}' if suffix else head
    for head, suffix in itertools.product(samples, topic)
]

orthogroups_eggnog_df = pd.DataFrame(columns=new_header, dtype=str)
orthogroups_eggnog_df.index.name = 'Orthogroup'
orthogroups_eggnog_df.head()

In [5]:
def get_consensus_cog(cog_list):
    """
    Return the most frequent COG letter in cog_list.

    Args:
        cog_list (str or iterable of str): COG category letters. A string is
            counted character by character, so a concatenation of several
            COG_category values (e.g. 'KKL') can be passed directly.

    Returns:
        The most frequent element. Among equally frequent elements, the one
        that appears first in cog_list wins. Returns NaN when cog_list is
        empty or is itself a missing value (None, NaN, pd.NA).
    """
    import collections

    import numpy as np
    import pandas as pd

    # A missing value carries no COG letters (iterating a float NaN would raise)
    if pd.api.types.is_scalar(cog_list) and pd.isna(cog_list):
        return np.nan

    letter_counts = collections.Counter(cog_list)
    if not letter_counts:
        return np.nan

    # Counter keeps first-seen order and most_common() lists equal counts in
    # that order, so ties go to the letter seen first.
    return letter_counts.most_common(1)[0][0]


In [None]:
def _annotate_loci(loci, sample, coginfo_df):
    """Look up the EggNOG annotation of every locus in one OrthoFinder cell.

    Args:
        loci: The OrthoFinder cell for ``sample`` (comma-separated locus IDs),
            or a non-string missing value when the sample has no gene here.
        sample (str): Sample name, used to rebuild the EggNOG GeneCode.
        coginfo_df (DataFrame): EggNOG table of ``sample`` indexed by GeneCode.

    Returns:
        tuple: (consensus COG, concatenated COG letters, comma-joined gene
        names or NaN when no gene name was found)
    """
    cog_all = ''
    gene_all = []

    if isinstance(loci, str):
        for locus in loci.split(','):
            # Loci may be separated by ', '; drop surrounding whitespace
            locus = locus.strip()
            if not locus:
                continue

            locus_num = locus.split('_')[1]
            eggnog_locus = f'{sample}_{locus_num}'
            try:
                cog = coginfo_df.at[eggnog_locus, 'COG_category']
                gene = coginfo_df.at[eggnog_locus, 'Preferred_name']
            except KeyError:
                print(f'{eggnog_locus}: no hit in eggnog')
                continue

            # Empty annotation fields are read as NaN floats; skip them
            # instead of failing on string concatenation/join.
            if isinstance(cog, str):
                cog_all += cog
            if isinstance(gene, str):
                gene_all.append(gene)

    consensus_cog = get_consensus_cog(cog_all)
    genes = ', '.join(gene_all) if gene_all else np.nan
    return consensus_cog, cog_all, genes


missing_samples = [sample for sample in samples if sample not in coginfo_data]
if missing_samples:
    raise KeyError(f'No EggNOG annotations loaded for samples: {missing_samples}')

# Collect one record per orthogroup and build the DataFrame once at the end;
# growing a DataFrame with pd.concat inside the loop is quadratic.
records = {}
for orthogroup, row in ortho_df.iterrows():
    record = {}
    for sample in samples:
        # Index by column label: itertuples()/getattr would fail for sample
        # names that are not valid Python identifiers.
        loci = row[sample]
        consensus_cog, cog_all, gene_all = _annotate_loci(loci, sample, coginfo_data[sample])

        record[f'{sample}_COG_consensus'] = consensus_cog
        record[f'{sample}_COG'] = cog_all
        record[f'{sample}_gene'] = gene_all
        record[sample] = loci
    records[orthogroup] = record

# Keep the template's column order and object dtype
orthogroups_eggnog_df = (
    pd.DataFrame.from_dict(records, orient='index')
    .reindex(columns=orthogroups_eggnog_df.columns)
    .astype(object)
)
orthogroups_eggnog_df.index.name = 'Orthogroup'

orthogroups_eggnog_df

In [None]:
# Print a summary of the combined table: index, columns, non-null counts and dtypes
orthogroups_eggnog_df.info()

#### Step4–5. Extract COG information, determine consensus COGs, then filter and sort
Create binary dataframe `df_cog_binary`

In [114]:
# 1) Work on a copy so the combined table stays untouched
cog_df = orthogroups_eggnog_df.copy()

# 2) Summarize COG info across samples in a single column.
#    Concatenate each sample's consensus COG letter, ignoring the '-' / '='
#    placeholders and missing values. An explicit string join is used because
#    convert_dtypes().sum(axis=1) on string columns behaves differently across
#    pandas versions (older ones may silently fall back to numeric-only sums).
consensus_letters = (
    cog_df.filter(like='COG_consensus')
    .replace(['-', '='], np.nan)
    .fillna('')
    .astype(str)
)
cog_df['COG_all'] = [
    ''.join(letters) for letters in consensus_letters.itertuples(index=False)
]

# 3) Determine each row's consensus COG (most frequent letter across samples)
cog_df['orthogroup_COG'] = [get_consensus_cog(cog) for cog in cog_df['COG_all']]

In [None]:
# Keep the two summary columns followed by each sample's raw OrthoFinder locus
# column, leaving out the per-sample <sample>_COG_consensus / _COG / _gene
# columns. Selecting by sample name (rather than by the text after the last
# '_') keeps samples whose own names end in 'COG', 'gene' or 'consensus'.
sample_columns = set(samples)
header = ['orthogroup_COG', 'COG_all'] + [
    column for column in orthogroups_eggnog_df.columns if column in sample_columns
]

df = cog_df.reindex(columns=header)

In [117]:
# 4) Keep only orthogroups for which a consensus COG could be determined
has_cog = df['orthogroup_COG'].notna()
df = df[has_cog]

In [None]:
# Show, alphabetically, which consensus COG letters remain after filtering
unique_cogs = pd.unique(df['orthogroup_COG'])
print(sorted(unique_cogs))

In [None]:
# 5) Sort by COG order: only these categories are kept, in this order
cog_order = ['L', 'K', 'J', 'O', 'G', 'F', 'E', 'I', 'H', 'P', 'C', 'Q', 'M', 'U', 'T', 'D', 'V', 'N']

# Categories outside cog_order become NaN in the ordered categorical below
# and are dropped; report them so the data loss is visible.
excluded_cogs = df.loc[~df['orthogroup_COG'].isin(cog_order), 'orthogroup_COG']
if not excluded_cogs.empty:
    print(
        f'Dropping {len(excluded_cogs)} orthogroups whose COG is not in cog_order: '
        f'{sorted(excluded_cogs.astype(str).unique())}'
    )

# assign() returns a new frame, so df is rebound rather than modified in place
df = df.assign(
    orthogroup_COG=pd.Categorical(df['orthogroup_COG'], categories=cog_order, ordered=True)
)

# Sort by the categorical order and drop the excluded (NaN) rows
df_sorted = df.sort_values('orthogroup_COG').dropna(subset=['orthogroup_COG'])

In [None]:
# 6) Create presence/absence (binary) table

# Report samples that would silently become all-zero columns (listed but not
# in the data) or be left out of the table (in the data but not listed).
missing_samples = [s for s in sample_order if s not in df_sorted.columns]
if missing_samples:
    print(f'Samples in sample_order but not in the data (all 0): {missing_samples}')
unlisted_samples = [
    c for c in df_sorted.columns
    if c not in ('orthogroup_COG', 'COG_all') and c not in sample_order
]
if unlisted_samples:
    print(f'Samples in the data but not in sample_order (omitted): {unlisted_samples}')

sample_cells = df_sorted.reindex(columns=sample_order)

# 1 when the sample's locus cell is neither NaN nor '', else 0.
# Vectorized replacement for DataFrame.applymap (deprecated since pandas 2.1).
df_cog_binary = (sample_cells.notna() & sample_cells.ne('')).astype('int64')

# Number of samples containing each orthogroup
df_cog_binary['counts'] = df_cog_binary.sum(axis=1)

# Ordered categorical over cog_order; df_cog_binary shares df_sorted's row order
df_cog_binary['orthogroup_COG'] = pd.Categorical(
    df_sorted['orthogroup_COG'], categories=cog_order, ordered=True
)

# Sort by 'orthogroup_COG' based on cog_order, then by 'counts' in descending order
df_cog_binary = df_cog_binary.sort_values(by=['orthogroup_COG', 'counts'], ascending=[True, False])

fig_columns = ['orthogroup_COG', 'counts'] + sample_order
df_cog_binary = df_cog_binary.reindex(columns=fig_columns)

In [134]:
# Save the presence/absence table inside output_dir (created if needed).
# os.path.join must receive the directory and the file name separately;
# concatenating them would write the file next to output_dir instead.
# NOTE: the file keeps its .csv name but is tab-separated (sep='\t').
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, 'COG_binary.csv')
df_cog_binary.to_csv(output_path, sep='\t', index_label='Orthogroup')

#### Step6. Visualize
Generate a block plot for the presence/absence data

In [None]:
# Per-sample 0/1 columns only, in the configured sample order
df_to_fig = df_cog_binary.loc[:, sample_order]

In [136]:
# Fill color used for a "present" cell of each plotted COG category
partition_colors = dict(
    L='#1f77b4',  # blue
    K='#ff7f0e',  # orange
    J='#2ca02c',  # green
    O='#d62728',  # red
    G='#9467bd',  # purple
    F='#8c564b',  # brown
    E='#e377c2',  # pink
    I='#7f7f7f',  # grey
    H='#bcbd22',  # olive
    P='#17becf',  # cyan
    C='#FF7675',  # lighter red
    Q='#F39C12',  # lighter orange
    M='#2C3E50',  # dark slate
    U='#FF1493',  # deep pink
    T='#4682B4',  # steel blue
    D='#FF6347',  # tomato
    N='#00FA9A',  # medium spring green
    V='#B22222',  # fire brick
)

In [137]:
# COG functional category letter -> human-readable description
cog_classification = {
    'J': 'Translation, Ribosomal Structure and Biogenesis',
    'A': 'RNA Processing and Modification',
    'K': 'Transcription',
    'L': 'Replication, Recombination and Repair',
    'B': 'Chromatin Structure and Dynamics',
    'D': 'Cell cycle control, Cell division, Chromosome partitioning',
    'Y': 'Nuclear Structure',
    'V': 'Defense mechanisms',
    'T': 'Signal transduction mechanisms',
    'M': 'Cell wall/membrane/envelope biogenesis',
    'N': 'Cell motility',
    'Z': 'Cytoskeleton',
    'W': 'Extracellular Structures',
    'U': 'Intracellular trafficking, Secretion, and Vesicular transport',
    'O': 'Post-translational modification, Protein Turnover, Chaperones',
    'X': 'Mobilome: prophages, transposons',
    'C': 'Energy production and conversion',
    'G': 'Carbohydrate transport and metabolism',
    'E': 'Amino acid transport and metabolism',
    'F': 'Nucleotide transport and metabolism',
    'H': 'Coenzyme transport and metabolism',
    'I': 'Lipid transport and metabolism',
    'P': 'Inorganic ion transport and metabolism',
    'Q': 'Secondary metabolites biosynthesis, transport and catabolism',
    'R': 'General function prediction only',
    'S': 'Function unknown'
}

In [None]:
# Print the legend: COG letter, its description and its plot color
for letter in cog_order:
    description = cog_classification[letter]
    color = partition_colors[letter]
    print(f'{letter}: {description}, {color}')

In [None]:
# For each COG category, list the row labels (orthogroups) of df_cog_binary
# whose consensus COG is that category
cog_column = df_cog_binary['orthogroup_COG']
partitions = {
    cog: df_cog_binary.index[cog_column == cog].tolist()
    for cog in cog_order
}

print(partitions['C'])

In [None]:
# Create figure and axis without boundary and axis
fig, ax = plt.subplots(frameon=False)

# Determine cell size and padding for squares
cell_size = 1
padding = 0.1
border_thickness = 0.01  # Thickness of the cell boundaries
absent_color = '#D3D3D3'  # light grey: absent cells and COGs without a color

# Row label -> COG category, built once. Scanning every partition's label
# list for each cell was O(rows^2 * columns). setdefault keeps the first
# partition (in partitions' order) if a label appears in several.
row_to_cog = {}
for partition, row_labels in partitions.items():
    for label in row_labels:
        row_to_cog.setdefault(label, partition)

n_rows, n_cols = df_to_fig.shape
cell_values = df_to_fig.to_numpy()

# Row 0 of the table is drawn at the top of the figure
for i, original_index in enumerate(df_to_fig.index):
    row_cog = row_to_cog.get(original_index)
    for j in range(n_cols):
        # Present cells take their COG's color; absent cells stay grey
        if cell_values[i, j] == 1:
            facecolor = partition_colors.get(row_cog, absent_color)
        else:
            facecolor = absent_color

        square = patches.Rectangle(
            (j + padding, n_rows - i - 1 + padding),
            cell_size - 2 * padding,
            cell_size - 2 * padding,
            linewidth=border_thickness,
            edgecolor='white',
            facecolor=facecolor,
        )
        ax.add_patch(square)

# Set x and y axis limits based on the number of rows and columns
ax.set_xlim(0, n_cols)
ax.set_ylim(0, n_rows)

ax.axis('off')

# Save the plot as an SVG file inside output_dir (created if needed); the
# directory and file name must be passed to os.path.join separately.
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, 'COG_visualization.svg')
fig.savefig(output_path)