# In [None]:
"""
CRISPR-Cas Analysis and Visualization
=====================================

This notebook analyzes CRISPR-Cas systems across Methylococcus strains and visualizes
spacer-phage interactions. The analysis includes:
1. Converting GenBank files to FASTA format
2. Processing CRISPR spacer data and phage matches
3. Creating publication-ready visualizations

"""

import os
import glob
import re
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

# ============================================================================
# Configuration and File Paths
# ============================================================================

# Root directory that every input and output path below is resolved against
BASE_DIR = "."

# Input locations (relative to BASE_DIR)
# GenBank (.gbk) files read by convert_gbk_to_fasta
PHAGES_GBK_DIR = os.path.join(BASE_DIR, "prophages")
# FASTA folder; destination of the optional convert_gbk_to_fasta step
PHAGE_FASTA_DIR = os.path.join(BASE_DIR, "Pharokka_fasta")
# Tab-separated CRISPR summary table; the columns 'ID', 'Type', 'CRISPRDirection'
# and 'Potential_Orientation (AT%)' are read from it
CRISPR_SUMMARY_FILE = os.path.join(BASE_DIR, "crisprs_report_types1.tsv")
# Per-strain result folders; spacer FASTA files are read from <strain>/spacers/
SPACER_ANALYSIS_DIR = os.path.join(BASE_DIR, "CRISPRCasFinder_res")
# Glob pattern (not a single path) matching one spacer-phage match file per strain
PHAGE_MATCH_FILES = os.path.join(BASE_DIR, "spacepharer", "*_all_phages.tsv")

# Output directory for the generated figure
OUTPUT_DIR = os.path.join(BASE_DIR, "generated_images")

# Create output directory if it doesn't exist (runs as soon as this module/cell is executed)
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ============================================================================
# Data Processing Functions
# ============================================================================

def convert_gbk_to_fasta(input_dir, output_dir):
    """
    Write one FASTA file for every GenBank (.gbk) file found in a directory.

    Each output file keeps the base name of its source file and gets a
    ".fasta" extension. Files that do not end in ".gbk" are ignored.

    Args:
        input_dir (str): Directory scanned for .gbk files
        output_dir (str): Directory receiving the .fasta files (created if missing)
    """
    os.makedirs(output_dir, exist_ok=True)

    for entry in os.listdir(input_dir):
        if not entry.endswith(".gbk"):
            continue

        stem = os.path.splitext(entry)[0]
        source_path = os.path.join(input_dir, entry)
        target_path = os.path.join(output_dir, stem + ".fasta")

        # Records are written one at a time, in the order they appear in the source
        with open(target_path, "w") as fasta_handle:
            for record in SeqIO.parse(source_path, "genbank"):
                SeqIO.write(record, fasta_handle, "fasta")

def create_color_palette():
    """
    Build the fixed strain-to-color lookup used to paint phage matches.

    Colors are picked by index from a ten-entry list of hex codes, so each
    strain keeps the same color in every figure.

    Returns:
        dict: Mapping of phage strain names to hex color strings
    """
    # Index -> hex code; the trailing comment names each color
    hex_codes = (
        '#1f77b4',  # 0 blue
        '#ff7f0e',  # 1 orange
        '#2ca02c',  # 2 green
        '#d62728',  # 3 red
        '#9467bd',  # 4 purple
        '#8c564b',  # 5 brown
        '#e377c2',  # 6 pink
        '#7f7f7f',  # 7 gray
        '#bcbd22',  # 8 olive
        '#17becf',  # 9 cyan
    )

    # (strain name, index into hex_codes)
    strain_slots = (
        ('KN2_R1', 1),
        ('McNor_R2', 3),
        ('16-5_R1', 5),
        ('IO1_R1', 4),
        ('Bath_R2', 6),
        ('Mc(Nor)_R1', 0),
        ('16-5_R2', 9),
        ('Bath_R1', 2),
    )

    return {strain_name: hex_codes[slot] for strain_name, slot in strain_slots}

def create_phage_mapping():
    """
    Build the lookup that translates raw phage identifiers into strain labels.

    An empty string as the label marks an identifier that is deliberately
    left out of the analysis.

    Returns:
        dict: Mapping of original phage identifiers to standardized names
    """
    # (original identifier, standardized label), kept in curated order
    renames = (
        ("16-5_PHASTEST_4", '16-5_R2'),
        ("CP079098_IO1.phigaro.fasta_2", 'IO1_R1'),
        ("NC_002977_Bath.phigaro.fasta_2", 'Bath_R2'),
        ("NC_002977_Bath_PHASTEST_2", 'Bath_R2'),
        ("NC_002977_Bath_PhiSpy.phages_2", 'Bath_R2'),
        ("NZ_CP079098_IO1_PHASTEST_1", 'IO1_R1'),
        ("NZ_CP110921_16-5.phigaro.fasta_4", '16-5_R2'),
        ("NZ_OX458332_Mc(Nor)_PHASTEST_1", 'Mc(Nor)_R1'),
        ("PhiSpy_16-5.phages_4", '16-5_R2'),
        ("PhiSpy_IO1.phages_2", 'IO1_R1'),
        ("PhiSpy_Mc7.phages_2", ""),  # Excluded due to partial data
        ("VIBRANT_phages_CP079098_IO1_2", 'IO1_R1'),
        ("VIBRANT_phages_NC_002977_Bath_1", "Bath_R1"),
        ("VIBRANT_phages_NZ_CP110921_16-5_3", '16-5_R1'),
        ("VIBRANT_phages_OX458332_Mc-Nor_1", 'Mc(Nor)_R1'),
    )
    return dict(renames)

def parse_spacer_phage_file(file_path, phage_mapping):
    """
    Parse a spacer-phage match file into (array, phage, spacer) triples.

    The file is read line by line. A tab-separated line starting with '#'
    is a header naming the CRISPR array (first field) and the matched phage
    (second field). Each following tab-separated line starting with '>' is a
    spacer hit belonging to the most recent header.

    Args:
        file_path (str): Path to the spacer-phage match file
        phage_mapping (dict): Mapping of old phage names to new names; a name
            mapped to an empty string is excluded, a name absent from the
            mapping is kept unchanged

    Returns:
        list: Tuples of (crispr_id, phage_strain, spacer_number), where
            spacer_number is the digit string following "spacer" in the hit
    """
    # Compiled once; both supported genome naming schemes are tried in order
    header_patterns = (
        re.compile(r'Methylococcus_species_strain_strain-crispr_(.+)'),
        re.compile(r'Methylococcus_geothermalis__str__IM1-crispr_(.+)'),
    )
    spacer_pattern = re.compile(r'spacer(\d+)')

    results = []
    current_prefix = ""
    current_value = ""

    with open(file_path, 'r') as file:
        for line in file:
            if line.startswith('#'):
                parts = line.strip().split('\t')
                if len(parts) <= 1:
                    continue

                match = None
                for pattern in header_patterns:
                    match = pattern.search(parts[0])
                    if match:
                        break

                if match:
                    current_prefix = match.group(1)
                    # Translate the phage name exactly once per header, so a
                    # mapped name is never fed through the mapping again on
                    # later hits
                    current_value = phage_mapping.get(parts[1], parts[1])
                else:
                    # Unrecognized header: forget the previous one so its hits
                    # are not credited to the wrong array/phage
                    current_prefix = ""
                    current_value = ""

            elif line.startswith('>'):
                parts = line.strip().split('\t')
                # Skip malformed hits and hits under an excluded/unknown header
                if len(parts) <= 1 or not current_value:
                    continue

                spacer_match = spacer_pattern.search(parts[0])
                if spacer_match:
                    results.append((current_prefix, current_value, spacer_match.group(1)))

    return results

def get_spacer_count(spacers_file):
    """
    Return how many FASTA records a spacer file holds.

    A record is any line beginning with '>'. A missing file is reported
    with a printed warning and counted as zero spacers.

    Args:
        spacers_file (str): Path to spacer FASTA file

    Returns:
        int: Number of spacers (0 when the file does not exist)
    """
    try:
        with open(spacers_file, 'r') as handle:
            return sum(1 for fasta_line in handle if fasta_line.startswith('>'))
    except FileNotFoundError:
        print(f"Warning: File not found - {spacers_file}")
        return 0

# ============================================================================
# Visualization Functions
# ============================================================================

def draw_crispr_array(N, section_colors, description, subplot, orientation):
    """
    Draw a single CRISPR array with colored spacers indicating phage matches.

    The array is a grey horizontal band N units wide. Every spacer with at
    least one match is painted over the band as a 1-unit-wide column; when a
    spacer has several colors the column is split into equal-height slices,
    stacked bottom-to-top in the iteration order of ``section_colors``.

    Args:
        N (int): Number of spacers in the array
        section_colors (dict): Mapping of color -> iterable of 1-based spacer
            numbers that should be painted in that color
        description (str): Label used in the subplot title
        subplot: Matplotlib Axes to draw on; its axis is switched off
        orientation (str): "Reverse" draws spacer 1 at the right end; any other
            value (including "Forward") draws spacer 1 at the left end
    """
    band_bottom = 0.4
    stripe_height = 0.05
    stripe_count = 5
    band_height = stripe_height * stripe_count

    # Draw background grey rectangles representing the CRISPR array
    background_color = '#D3D3D3'
    for i in range(stripe_count):
        background_rect = patches.Rectangle((0, band_bottom + i * stripe_height), N, stripe_height,
                                            linewidth=1, edgecolor=background_color,
                                            facecolor=background_color)
        subplot.add_patch(background_rect)

    # Collect, for every x position, the colors to paint there.
    # Colors are kept in first-seen order (no set() on the color strings):
    # string hashing is randomized per interpreter run, so a set would make
    # the stacking order of multi-color spacers differ between runs.
    positions_colored = {}
    for color, positions in section_colors.items():
        for position in sorted(set(positions)):  # drop duplicate spacer numbers
            if orientation == "Reverse":
                pos_updated = N - position
            else:
                pos_updated = position - 1  # "Forward" and any unknown value

            slot_colors = positions_colored.setdefault(pos_updated, [])
            if color not in slot_colors:
                slot_colors.append(color)

    # Draw colored rectangles for spacers with phage matches
    for pos, colors in positions_colored.items():
        height_per_color = band_height / len(colors)

        for i, color in enumerate(colors):
            colored_rect = patches.Rectangle((pos, band_bottom + i * height_per_color), 1, height_per_color,
                                             linewidth=1, edgecolor=color, facecolor=color)
            subplot.add_patch(colored_rect)

    # Configure subplot
    subplot.set_xlim(0, N)
    subplot.set_ylim(0, 1)
    subplot.axis('off')
    subplot.set_title(f'{description}    N = {N}', pad=-10, fontsize=18)

def create_strain_layout():
    """
    Describe where each strain / CRISPR type is placed in the figure grid.

    Returns:
        tuple: (species_order, crispr_type_to_col, get_row) where
            species_order is the list of strains from top to bottom,
            crispr_type_to_col maps a CRISPR type to its column index, and
            get_row(species, cas_type) returns the row index
    """
    species_order = ['Mc7', 'IM1', 'KN2', 'McNor', 'Bath', 'IO1', 'BH', 'MIR']

    # The first two types own a column each; every other type shares column 2
    crispr_type_to_col = {'CAS-TypeIE': 0, 'CAS-TypeIC': 1}
    for shared_type in ('CAS-TypeIF_1', 'CAS-TypeIF_2', 'CAS-TypeIF',
                        'CAS-TypeIIIA', 'noCAS', 'noCAS_1'):
        crispr_type_to_col[shared_type] = 2

    # CRISPR types pinned to a fixed row regardless of the strain
    pinned_rows = {
        'CAS-TypeIF_2': 2,
        'CAS-TypeIIIA': 3,
        'noCAS': 4,
        'noCAS_1': 5,
    }

    def get_row(species, cas_type):
        """Map species and CAS type to row position."""
        if cas_type in pinned_rows:
            return pinned_rows[cas_type]
        # Otherwise the row is the strain's position in species_order
        return species_order.index(species)

    return species_order, crispr_type_to_col, get_row

def extract_crispr_info(crispr, strain, crispr_summary_df):
    """
    Extract CRISPR type and orientation information from summary data.

    The CRISPR identifier is first expanded into the full spacer-file name
    used in the 'ID' column of the summary table, then the matching row is
    looked up. The orientation is taken from 'CRISPRDirection' when that
    value is text containing "HIGH" ("Reverse" if it also contains an
    upper-case "R", otherwise "Forward"); in every other case, including a
    missing value, the 'Potential_Orientation (AT%)' column is returned as is.

    Args:
        crispr (str): CRISPR identifier
        strain (str): Strain name; "IM1" selects the IM1-specific naming scheme
        crispr_summary_df (DataFrame): CRISPR summary data with the columns
            'ID', 'Type', 'CRISPRDirection' and 'Potential_Orientation (AT%)'

    Returns:
        tuple: (crispr_name, crispr_type, orientation), or None when the
            identifier cannot be parsed or no matching row/column exists
    """
    # Parse CRISPR identifier
    if strain != "IM1":
        array_match = re.search(r'Methylococcus_species_strain_strain_(.+)', crispr)
        strain_id_match = re.search(r'(.+?)_Methylococcus_', crispr)
        if not (array_match and strain_id_match):
            return None
        genome = f"{strain_id_match.group(1)}_Methylococcus_species_strain_strain"
    else:
        # Special handling for IM1 strain
        array_match = re.search(r'NZ_CP046565_1_Methylococcus_geothermalis__str__IM1_(.+)', crispr)
        if not array_match:
            return None
        genome = "NZ_CP046565_1_Methylococcus_geothermalis__str__IM1"

    # Drop a trailing "-spacers.fa" if the identifier already carries it
    crispr_n = array_match.group(1).split('-spacers.fa')[0]
    crispr_name = f"Sequence_{genome}-crispr_{genome}_{crispr_n}-spacers.fa"

    # Look up CRISPR type and orientation
    try:
        # Filter the table once and reuse the matching rows for every column
        rows = crispr_summary_df.loc[crispr_summary_df['ID'] == crispr_name]
        crispr_type = rows['Type'].values[0]
        crispr_dir = rows['CRISPRDirection'].values[0]

        # A missing direction is read as NaN (a float); the isinstance check
        # keeps the substring test from raising TypeError on it
        if isinstance(crispr_dir, str) and "HIGH" in crispr_dir:
            orientation = "Reverse" if "R" in crispr_dir else "Forward"
        else:
            orientation = rows['Potential_Orientation (AT%)'].values[0]

        return crispr_name, crispr_type, orientation

    except (IndexError, KeyError):
        print(f"Warning: Could not find CRISPR info for {crispr_name}")
        return None

def add_manual_entries(axs, get_row, crispr_type_to_col):
    """
    Draw the hand-curated CRISPR arrays that have no spacer-phage match data.

    Each curated array is drawn as a plain grey band (no colored spacers)
    in forward orientation.

    Args:
        axs: Matplotlib subplot array, indexed as axs[row, col]
        get_row: Function returning the row for (strain, crispr_type)
        crispr_type_to_col: Mapping of CRISPR types to columns
    """
    # (strain, CRISPR type, number of spacers)
    curated_arrays = (
        ("IM1", "CAS-TypeIIIA", 22),
        ("BH", "CAS-TypeIC", 70),
        ("Mc7", "CAS-TypeIF_2", 40),
    )

    for strain_name, cas_type, spacer_total in curated_arrays:
        col_index = crispr_type_to_col[cas_type]
        row_index = get_row(strain_name, cas_type)
        target_axes = axs[row_index, col_index]

        # An empty color mapping leaves the whole array grey
        draw_crispr_array(spacer_total, {}, f"{strain_name}_{cas_type}", target_axes, "Forward")

# ============================================================================
# Main Analysis and Visualization
# ============================================================================

def generate_publication_figure():
    """
    Generate the main publication figure showing CRISPR-phage interactions.

    Reads the CRISPR summary table and every file matching PHAGE_MATCH_FILES,
    draws one CRISPR array per panel in a grid of len(species_order) rows by
    3 columns, adds the manually curated arrays, then saves the figure as
    'crispr_phage_interactions_final.png' in OUTPUT_DIR and shows it.

    Arrays that cannot be placed (no summary entry, no spacer file, or a
    strain / CRISPR type missing from the layout) are skipped with a printed
    warning instead of aborting the whole figure.
    """
    # Load data
    crispr_summary_df = pd.read_csv(CRISPR_SUMMARY_FILE, sep='\t')
    color_palette = create_color_palette()
    phage_mapping = create_phage_mapping()
    species_order, crispr_type_to_col, get_row = create_strain_layout()

    # Get input files; sorted because glob order is unspecified and several
    # CRISPR types share a panel, so draw order must be reproducible
    file_paths = sorted(glob.glob(PHAGE_MATCH_FILES))
    print(f"Processing {len(file_paths)} spacer-phage match files")

    # Initialize figure
    num_cols = 3
    num_rows = len(species_order)
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(19, num_rows * 2))

    # Process each strain's data
    for file_path in file_paths:
        # The strain is the file-name part before the first underscore;
        # basename() handles any path separator, unlike split("/")
        strain = os.path.basename(file_path).split("_")[0]
        if strain == '16':
            strain = '16_5'  # "16_5_..." is cut at its own underscore

        # Skip excluded strains
        if strain in ["c8", "EFPC2"]:
            continue

        print(f"\nProcessing strain: {strain}")

        # Parse spacer-phage matches
        parsed_results = parse_spacer_phage_file(file_path, phage_mapping)

        # Group results by CRISPR array
        split_results = {}
        for crispr_id, phage_strain, spacer_num in parsed_results:
            split_results.setdefault(crispr_id, []).append((phage_strain, spacer_num))

        # Process each CRISPR array
        for crispr, matches in split_results.items():
            # Collect spacer positions for each phage strain that has a color
            section_colors = {}
            for phage_strain, spacer_num in matches:
                if phage_strain in color_palette:
                    color = color_palette[phage_strain]
                    section_colors.setdefault(color, []).append(int(spacer_num))

            # Extract CRISPR information
            crispr_info = extract_crispr_info(crispr, strain, crispr_summary_df)
            if not crispr_info:
                continue

            crispr_name, crispr_type, orientation = crispr_info

            # Get spacer count
            spacers_file = os.path.join(SPACER_ANALYSIS_DIR, strain, "spacers", crispr_name)
            N = get_spacer_count(spacers_file)

            if N == 0:
                continue

            # Determine subplot position; anything the layout does not know
            # is reported and skipped rather than raising
            if crispr_type not in crispr_type_to_col:
                print(f"Warning: No column defined for CRISPR type {crispr_type} ({crispr_name})")
                continue
            col = crispr_type_to_col[crispr_type]

            try:
                row = get_row(strain, crispr_type)
            except ValueError:
                print(f"Warning: No row defined for strain {strain} ({crispr_name})")
                continue

            description = f"{strain}_{crispr_type}"

            # Draw the CRISPR array
            draw_crispr_array(N, section_colors, description, axs[row, col], orientation)

    # Add manually curated entries for strains without spacer-phage data
    add_manual_entries(axs, get_row, crispr_type_to_col)

    # Hide empty subplots (public has_data() instead of the private _children list)
    for ax in axs.flat:
        if not ax.has_data():
            ax.axis('off')

    # Finalize and save figure
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, 'crispr_phage_interactions_final.png'),
                dpi=600, bbox_inches='tight')
    plt.show()

# ============================================================================
# Main Execution
# ============================================================================

if __name__ == "__main__":
    # Optional first step: rebuild the phage FASTA files from the GenBank inputs
    # convert_gbk_to_fasta(PHAGES_GBK_DIR, PHAGE_FASTA_DIR)

    # Build, save and display the CRISPR-phage interaction figure
    generate_publication_figure()

    # Report where the figure was written
    print(f"Analysis complete. Publication figure saved to: {OUTPUT_DIR}/crispr_phage_interactions_final.png")