# Gene-level alignment for pan-genome

In [None]:
import logging
import re
import urllib
from io import StringIO
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import gzip
import pickle
from tqdm.notebook import tqdm, trange
import multiprocessing
from IPython.display import display, HTML
import itertools

import plotly.graph_objects as go

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
import os

In [None]:
from pyphylon.plotting_util import *

In [None]:
from pyphylon.plotting import *

In [None]:
from pyphylon.util import load_config

## Set up files and matrices for analysis of genomic location
Load the gene matrix and the header-to-allele mapping used to map each GFF file to the correct gene names.

In [None]:
# Read the project configuration and pull out the working directory and the
# PG_NAME entry (kept as SPECIES) that the path constants are built from.
CONFIG = load_config("config.yml")
WORKDIR, SPECIES = CONFIG["WORKDIR"], CONFIG["PG_NAME"]

In [None]:
# Input files, resolved relative to WORKDIR.
DF_GENES = os.path.join(
    WORKDIR,
    f'processed/cd-hit-results/{SPECIES}_strain_by_gene.pickle.gz',
)
ENRICHED_METADATA = os.path.join(
    WORKDIR,
    'interim/enriched_metadata_2d.csv',
)
# DF_EGGNOG is defined in its own cell.

In [None]:
# Genome ID used as the reference strain for strain grouping and for the
# per-strain lookups; this particular value is an arbitrary pick for the demo.
REFERENCE_STRAIN = '1314.3244' # a random one here for demo

In [None]:
# Location of the eggNOG annotation table (plain string: nothing to interpolate).
DF_EGGNOG = os.path.join(WORKDIR, 'processed/df_eggnog.csv')

In [None]:
# CSV tables stored under processed/CAR_genomes (plain strings: nothing to
# interpolate, so no f-prefix is needed).
DF_CORE_COMPLETE = os.path.join(WORKDIR, 'processed/CAR_genomes/df_core.csv')
DF_ACC_COMPLETE = os.path.join(WORKDIR, 'processed/CAR_genomes/df_acc.csv')
DF_RARE_COMPLETE = os.path.join(WORKDIR, 'processed/CAR_genomes/df_rare.csv')

In [None]:
# Load the three CAR_genomes tables (presumably core / accessory / rare genes —
# confirm against the pipeline that wrote them); the first CSV column is the index.
df_core_complete, df_acc_complete, df_rare_complete = (
    pd.read_csv(table_path, index_col=0)
    for table_path in (DF_CORE_COMPLETE, DF_ACC_COMPLETE, DF_RARE_COMPLETE)
)

In [None]:
# Load in (full) P matrix; its columns are later selected by genome ID.
# NOTE(review): unpickling can execute arbitrary code — only load pickle files
# produced by a trusted pipeline run.
df_genes = pd.read_pickle(DF_GENES)

In [None]:
# Read every metadata column as a plain object (string) column, then keep only
# the rows whose genome_status is exactly 'Complete'.
metadata = pd.read_csv(ENRICHED_METADATA, index_col=0, dtype='object')
is_complete = metadata['genome_status'].eq('Complete')
metadata_complete = metadata[is_complete]

In [None]:
# Restrict the P matrix to the Complete genomes, replace N/A with 0, then
# densify and downcast to int8 to save memory and compute.
df_genes_complete = (
    df_genes[metadata_complete.genome_id]
    .fillna(0)
    .sparse.to_dense()
    .astype('int8')
)

# Keep only the genes found in at least one complete sequence.
inCompleteseqs = df_genes_complete.sum(axis=1).gt(0)
df_genes_complete = df_genes_complete[inCompleteseqs]

df_genes_complete.shape

In [None]:
# Binarized NMF output matrices (plain strings: nothing to interpolate).
L_MATRIX = os.path.join(WORKDIR, 'processed/nmf-outputs/L_binarized.csv')
A_MATRIX = os.path.join(WORKDIR, 'processed/nmf-outputs/A_binarized.csv')

In [None]:
# Load both binarized matrices (first CSV column is the index), then show the
# shape and leading rows of each.
L_binarized, A_binarized = (
    pd.read_csv(matrix_path, index_col=0)
    for matrix_path in (L_MATRIX, A_MATRIX)
)

display(L_binarized.shape, L_binarized.head(), A_binarized.shape, A_binarized.head())

In [None]:
# df_eggnog = pd.read_csv(DF_EGGNOG, low_memory=False).set_index('gene')

In [None]:
# Maps each strain (genome ID) to its gene vector: the list of genes sorted by
# start position, filled in by the GFF-parsing loop.
strain_vectors = {}

In [None]:
from pyphylon.biointerp import get_pg_to_locus_map

In [None]:
# Table linking locus tags to pan-genome genes. The GFF loop joins on its
# `gene_id` column and renames its `cluster` column to `gene`.
# NOTE(review): exact schema comes from pyphylon.biointerp — confirm there.
pg2locus_map = get_pg_to_locus_map(WORKDIR, SPECIES)

## Functions to Parse GFF

## Get vectors of genes and binarized vectors for each strain

In [None]:
for strain in tqdm(metadata_complete.genome_id):
    # Parse this strain's bakta annotation.
    gff_path = os.path.join(WORKDIR, f'processed/bakta/{strain}/{strain}.gff3')
    DF_gff, size, oric = gff2pandas(gff_path)

    # Attach the pan-genome cluster to each feature. The left join keeps features
    # with no match in pg2locus_map; their gene value is NaN.
    # (Legacy alternative: DF_gff.locus_tag.apply(lambda x: h2a(x, header_to_allele)).)
    DF_gff = DF_gff.merge(pg2locus_map, left_on='locus_tag', right_on='gene_id', how='left')
    DF_gff = DF_gff.rename(columns={'cluster': 'gene'})

    # Keep only the accession with the most features.
    main_accession = DF_gff.accession.value_counts().index[0]
    DF_gff = DF_gff.loc[DF_gff.accession == main_accession, ['gene', 'start']]

    # Genes in order of start coordinate.
    gene_order = DF_gff.sort_values('start').gene.to_list()
    strain_vectors[strain] = gene_order

In [None]:
# Number of entries in the reference strain's ordered gene list.
len(strain_vectors[REFERENCE_STRAIN])

In [None]:
# Number of genes in each strain's ordered gene list (i.e. the length of each
# list, not the length of any gene). The variable name is kept for other cells.
gene_lengths = [len(genes) for genes in strain_vectors.values()]

# Histogram of genes per strain across all strains.
plt.hist(gene_lengths, bins=10, color='blue', edgecolor='black')

# Labels describe what is actually plotted: gene counts per strain.
plt.title('Distribution of Gene Counts per Strain')
plt.xlabel('Number of Genes')
plt.ylabel('Number of Strains')

# Display the histogram
plt.show()

In [None]:
# Optional
# Keep only the strains whose gene list has at most 6000 entries; the original
# insertion order of strain_vectors is preserved.
strain_vectors_filtered = dict(
    filter(lambda item: len(item[1]) <= 6000, strain_vectors.items())
)

## Find the genes shared between all strains and test whether they appear in a consistent order

In [None]:
# Tabulate appearances of the common genes across the filtered strains
# (presumably one count per gene per strain — confirm against
# count_common_gene_appearances); the bare name displays the table.
common_gene_counts_df = count_common_gene_appearances(strain_vectors_filtered)
common_gene_counts_df

In [None]:
# Histogram of the row sums of the common-gene count table.
common_gene_counts_df.sum(axis = 1).hist()

In [None]:
# Row sums of the common-gene count table, shown as the cell output.
common_gene_counts_df.sum(axis = 1)

In [None]:
# find_once_genes returns three values, bound here as: the number of common
# genes, the number of genes appearing exactly once in each strain, and the
# collection of those once-per-strain genes (used as anchor genes afterwards).
# NOTE(review): meanings taken from the names and printed labels — confirm
# against the helper.
common_gene_count, once_gene_count, once_genes = find_once_genes(strain_vectors_filtered)
print(f"Number of common genes: {common_gene_count}")
print(f"Number of genes that appear exactly once in each strain: {once_gene_count}")

In [None]:
# Group the filtered strains using the once-per-strain genes and the reference
# strain. NOTE(review): the grouping criterion lives in create_strain_groups —
# presumably strains sharing the same anchor-gene order; confirm there.
strain_groups = create_strain_groups(strain_vectors_filtered, once_genes, REFERENCE_STRAIN)

In [None]:
# Key of the group with the most members (the first such key wins on ties).
largest_group_of_strains = max(
    strain_groups,
    key=lambda group_key: len(strain_groups[group_key]),
)

In [None]:
# Number of strains in the largest group.
len(strain_groups[largest_group_of_strains])

In [None]:
# Gene vectors of the strains in the largest group, skipping any member that
# is absent from the filtered set.
strain_vectors_reference = {
    member: strain_vectors_filtered[member]
    for member in strain_groups[largest_group_of_strains]
    if member in strain_vectors_filtered
}

## Use the largest group as reference for the order of anchor genes

In [None]:
# Derive the reference ordering of the once-per-strain (anchor) genes from the
# strains in the largest group.
reference_ordered_genes = get_reference_order(strain_vectors_reference, once_genes)

In [None]:
# Show the reference gene order.
print(reference_ordered_genes)

## Name these genes from 1 to N

In [None]:
# Number the reference-ordered genes 1..N: gene identifier -> position.
gene_mapping = dict(zip(reference_ordered_genes, itertools.count(1)))

In [None]:
# Number of distinct genes that received a number.
len(gene_mapping)

In [None]:
# Replace every anchor gene by its number in each filtered strain; genes with
# no entry in gene_mapping are left as they are.
updated_strain_vectors = {}
for strain, genes in strain_vectors_filtered.items():
    # .get(gene, gene) falls back to the gene itself when it has no number
    updated_genes = list(map(lambda gene: gene_mapping.get(gene, gene), genes))
    updated_strain_vectors[strain] = updated_genes

In [None]:
# Normalise the direction of each strain's gene vector. adjust_gene_order
# returns the adjusted vectors and a count, bound here as count_reversed.
# NOTE(review): presumably the number of strains whose vector was reversed —
# confirm against adjust_gene_order.
strain_vectors_reordered, count_reversed = adjust_gene_order(updated_strain_vectors)
# Label differs from the next step's "reordered" message so the two counts
# can be told apart in the output.
print("Number of strains reversed:", count_reversed)

In [None]:
# Produce the final gene vectors; going by the helper's name, each vector is
# rearranged to begin at anchor gene 1 (confirm against
# reorder_to_start_with_one). The second return value is the count printed below.
strain_vectors_final, count_reordered = reorder_to_start_with_one(strain_vectors_reordered)
print("Number of strains reordered:", count_reordered)

In [None]:
# Peek at the first 20 entries of the first strain's final gene vector.
first_strain = list(strain_vectors_final)[0]
strain_vectors_final[first_strain][:20]

In [None]:
# Check every final gene vector against the strict anchor-gene sequence.
# check_strict_sequence returns the per-strain results plus the two totals
# printed below (strains that pass / strains that do not).
# NOTE(review): the exact criterion is defined by the helper — confirm there.
sequence_check_results, total_true, total_false = check_strict_sequence(strain_vectors_final)
print("Number of strains correctly ordered:", total_true)
print("Number of strains with other orders:", total_false)

## Create gene location names for all the other genes

In [None]:
# Build location-based names for the genes in each final gene vector. The
# result is indexed by strain ID, and 'NA' marks entries without a name (see
# the lookup in the following cell).
# NOTE(review): naming scheme is defined by generate_gene_names — confirm there.
gene_mapping_to_anchor_genes = generate_gene_names(strain_vectors_final)

In [None]:
# Location names for the reference strain, dropping the 'NA' placeholders and
# sorting what remains.
reference_gene_names = gene_mapping_to_anchor_genes[REFERENCE_STRAIN]
reference_gene_names[reference_gene_names != 'NA'].sort_values()

In [None]:
# Per-strain tables of gene counts between anchor genes; later cells read the
# 'Anchor Genes' and 'Total Genes Between' columns of the reference strain's table.
gene_count_between_anchor_genes_all = create_gene_count_between_anchor_genes_for_all(gene_mapping_to_anchor_genes)

In [None]:
# Sum of the 'Total Genes Between' column for the reference strain.
gene_count_between_anchor_genes_all[REFERENCE_STRAIN]['Total Genes Between'].sum()

In [None]:
## Test if there are any missing pairs - NO genes between these two anchor genes
# Anchor-gene pairs that are present for the reference strain.
pairs = gene_count_between_anchor_genes_all[REFERENCE_STRAIN]['Anchor Genes']

# Every consecutive pair "i-(i+1)" expected for anchors 1..len(once_genes).
expected_pairs = [f"{i}-{i+1}" for i in range(1, len(once_genes))]

# Expected pairs that are absent from the table, ordered by their first anchor number.
missing_pairs = sorted(
    set(expected_pairs).difference(pairs),
    key=lambda pair: int(pair.split('-')[0]),
)

# Output the missing pairs
missing_pairs

## Identify the genetic variation in each strain

In [None]:
# Run the genetic-variation analysis over the final gene vectors; whatever the
# helper returns is shown as the cell output.
# NOTE(review): what counts as a variation is defined by
# identify_genetic_variation — confirm there.
identify_genetic_variation(strain_vectors_final)

## Location of the phylon

In [None]:
def rename_indexes(df: pd.DataFrame, substring: str = 'yogenes') -> pd.DataFrame:
    """
    Remove every occurrence of a substring from the row index labels.

    The index of ``df`` is replaced in place and the same DataFrame object is
    returned, so both the argument and the return value see the new labels.

    Parameters:
        df (pd.DataFrame): Input DataFrame whose index holds string labels.
        substring (str): Literal text to strip from each label. Defaults to
            'yogenes'.

    Returns:
        pd.DataFrame: The same DataFrame, with updated index labels.
    """
    # regex=False: treat the substring as plain text so characters such as '.'
    # or '(' in a caller-supplied value are not interpreted as a pattern.
    df.index = df.index.str.replace(substring, '', regex=False)
    return df

In [None]:
# Strip 'yogenes' from the gene labels in the L matrix index.
L_binarized = rename_indexes(L_binarized)

## Circular plot for phylon location

### Test on MGE-2 location

In [None]:
# Circular plot of one strain's gene vector, passing the genes that belong to
# the chosen phylon (rows of L_binarized where that column equals 1).
plot_strain_id = '530008.3'
plot_phylon_name = 'phylon1'
plot_phylon_genes = L_binarized.index[L_binarized[plot_phylon_name] == 1].tolist()
plot_circular_genome(
    strain_vectors_final[plot_strain_id],
    plot_phylon_genes,
    plot_phylon_name,
    plot_strain_id,
)

In [None]:
# Collect genes per phylon from the binarized L matrix.
# NOTE(review): presumably maps each phylon to the genes found only in that
# phylon — confirm against unique_genes_by_phylon.
unique_genes_dict = unique_genes_by_phylon(L_binarized)