In [None]:
import csv
from collections import defaultdict

def calculate_mutation_frequencies(input_csv_file, base_sequence, output_csv_file):
    """
    Calculates the frequency of amino acid mutations at each position
    compared to a base sequence from an input CSV file.

    A residue counts as a mutation at a position when it differs from the
    residue of ``base_sequence`` at that position. Each output value is
    ``mutation count / number of valid sequences``, so the column for the
    base residue at a position is always 0. Rows whose sequence length
    differs from ``len(base_sequence)`` are skipped with a warning.
    Mutations to characters outside ``ACDEFGHIKLMNPQRSTVWY`` are counted
    but not written to the output.

    Args:
        input_csv_file (str): Path to the input CSV file. It must have a
            header row that includes a ``sequence`` column.
        base_sequence (str): The base amino acid sequence to compare against.
        output_csv_file (str): Path to the output CSV file. It is written
            with a ``Position`` column (1-based) followed by one column per
            amino acid in ``ACDEFGHIKLMNPQRSTVWY`` order, with values
            formatted to six decimal places.

    Returns:
        None. Problems (missing file or column, no valid sequences, any
        other error) are reported with ``print``; nothing is raised.
    """
    mutation_counts = defaultdict(lambda: defaultdict(int))
    total_sequences = 0
    sequence_length = len(base_sequence)
    amino_acids = 'ACDEFGHIKLMNPQRSTVWY'

    # --- Read phase: count mutations per position. ---
    try:
        with open(input_csv_file, 'r', newline='') as csvfile:
            reader = csv.DictReader(csvfile)
            # ``fieldnames`` is None for an empty file; guard it so the
            # membership test cannot raise TypeError.
            if not reader.fieldnames or 'sequence' not in reader.fieldnames:
                print(f"Error: 'sequence' column not found in '{input_csv_file}'.")
                return

            for row in reader:
                # DictReader fills fields missing from a short row with None;
                # treat that as an empty (and therefore wrong-length) sequence.
                sequence = row['sequence'] or ''
                if len(sequence) != sequence_length:
                    print(f"Warning: Sequence '{sequence}' has incorrect length. Skipping.")
                    continue

                total_sequences += 1
                for position, (residue, base_residue) in enumerate(zip(sequence, base_sequence)):
                    if residue != base_residue:
                        mutation_counts[position][residue] += 1
    except FileNotFoundError:
        print(f"Error: Input file '{input_csv_file}' not found.")
        return
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return

    if total_sequences == 0:
        print("Error: No valid sequences found in the input file.")
        return

    # --- Write phase: kept in its own try so an output failure (e.g. a
    # missing directory) is not reported as a missing input file. ---
    try:
        with open(output_csv_file, 'w', newline='') as csvoutfile:
            writer = csv.writer(csvoutfile)
            writer.writerow(['Position'] + list(amino_acids))

            for position in range(sequence_length):
                counts = mutation_counts.get(position, {})
                # Values are fractions of all valid sequences, not percentages.
                writer.writerow(
                    [position + 1]
                    + [f"{counts.get(aa, 0) / total_sequences:.6f}" for aa in amino_acids]
                )
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return

    print(f"Mutation frequencies written to '{output_csv_file}'.")

if __name__ == "__main__":
    # Input CSV (must contain a 'sequence' column) and output CSV paths.
    input_file = 'nitazene_mpnn_seqs.csv'
    output_file = 'mutation_frequencies.csv'
    # Reference sequence; only input sequences of exactly this length are used.
    base_seq = 'MPSELTPEERSELKNSIAEFHTYQLDPGSCSSLHAQRIHAPPELVWSIVRRFDKPQTYKHFIKSCSVEFEMRVGCTRDVIVISGLPANTSTERLDILDDERRVTGFSIIGGEHRLTNYKSVTTVHRFEKENRIWTVVLESYVVDMPEGNSEDDTRMFADTVVKLNLQKLATVAEAMARN'
    calculate_mutation_frequencies(
        input_csv_file=input_file,
        base_sequence=base_seq,
        output_csv_file=output_file,
    )




In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np

def plot_mutation_heatmap_v6(input_csv_file, positions_to_plot, highlight_cells, output_filename='mutation_heatmap.png'):
    """
    Generates a publication-quality heatmap of mutation frequencies for specified positions.

    The input CSV has one row per position: a ``Position`` column followed by
    one column per amino acid. In the plot, the selected positions run left
    to right with their labels on the top axis, and amino acids run along the
    y-axis. Each (position, amino acid) pair is drawn as one square cell.

    Colouring: cells with frequency 0 are medium gray. Non-zero cells use a
    linear ramp from gray at the smallest non-zero frequency in the selection
    to cornflower blue at 1.0. If that range is degenerate (every non-zero
    value is 1.0), the ramp starts at 0 instead. Non-zero cells are annotated
    with their value to two decimals. Cells listed in ``highlight_cells`` get
    an orange border drawn above the cell grid lines.

    Args:
        input_csv_file (str): Path to the 'mutation_frequencies.csv' file.
        positions_to_plot (list): List of integer positions to include in the heatmap.
        highlight_cells (dict): Dictionary of {position: [amino_acids]} to highlight.
        output_filename (str): Name of the output image file (saved at 300 dpi).

    Returns:
        None. Errors are reported with ``print`` rather than raised.
    """
    try:
        df = pd.read_csv(input_csv_file)
        df = df.set_index('Position')

        # Check all requested positions up front so the message lists them
        # plainly instead of relying on the KeyError text from ``.loc``.
        missing_positions = [pos for pos in positions_to_plot if pos not in df.index]
        if missing_positions:
            print(f"Error: Position(s) {missing_positions} not found in the input CSV file.")
            return

        # Select only the specified positions (rows), in the requested order
        df_selected = df.loc[positions_to_plot]
        amino_acids = df_selected.columns.tolist()
        positions_labels = [f'{pos}' for pos in df_selected.index]
        # values[j, i] is the frequency of amino_acids[i] at positions_to_plot[j].
        # Positional access stays scalar even if a position is listed twice.
        values = df_selected.to_numpy(dtype=float)

        # Create the figure and axes
        num_positions = len(positions_to_plot)
        num_amino_acids = len(amino_acids)
        fig, ax = plt.subplots(figsize=(num_positions * 0.6, num_amino_acids * 0.4))  # Adjust figure size

        # Define the colormap (only for non-zero); zero cells reuse its low-end colour
        zero_color = '#808080'  # Medium gray
        cmap_non_zero = colors.LinearSegmentedColormap.from_list(
            "mutation_cmap_non_zero", [zero_color, "cornflowerblue"]  # Medium gray to cornflower blue
        )

        # Lower bound is the smallest strictly positive value; with none at
        # all, fall back to 0 so the norm and the colorbar ticks stay finite.
        non_zero = values[values > 0]
        vmin_non_zero = float(non_zero.min()) if non_zero.size else 0.0
        vmax_non_zero = 1.0  # Linear scale to 1
        if vmin_non_zero >= vmax_non_zero:
            # A zero-width range would make Normalize map every value to 0 (gray).
            vmin_non_zero = 0.0

        norm_non_zero = colors.Normalize(vmin=vmin_non_zero, vmax=vmax_non_zero, clip=True)

        # Create the heatmap with explicit coloring and value annotation
        for i, aa in enumerate(amino_acids):
            for j, pos in enumerate(positions_to_plot):
                frequency = values[j, i]
                color = zero_color if frequency == 0 else cmap_non_zero(norm_non_zero(frequency))
                rect = plt.Rectangle((j - 0.5, i - 0.5), 1, 1, facecolor=color, edgecolor='black', linewidth=0.5)
                ax.add_patch(rect)
                if frequency > 0:
                    ax.text(j, i, f"{frequency:.2f}", ha="center", va="center", color="white", fontfamily='Arial', fontsize=8)

                # Highlight specific cells. zorder=2 draws the border after every
                # cell patch (default zorder 1), so cells added later cannot
                # paint their black edges over it.
                if pos in highlight_cells and aa in highlight_cells[pos]:
                    highlight_rect = plt.Rectangle(
                        (j - 0.5, i - 0.5), 1, 1, edgecolor='orange', linewidth=1, fill=False, zorder=2
                    )
                    ax.add_patch(highlight_rect)

        # Set labels and ticks
        ax.set_xticks(np.arange(num_positions))
        ax.set_xticklabels(positions_labels, fontsize=10)
        ax.set_yticks(np.arange(num_amino_acids))
        ax.set_yticklabels(amino_acids, fontsize=10)

        # Place x-axis labels at the top
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position('top')
        ax.set_xlabel("Position", fontsize=12)
        ax.set_ylabel("Amino Acid", fontsize=12)
        ax.set_title("Mutation Frequencies at Selected Positions", fontsize=14)

        # Remove spines
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)

        # Add colorbar (for non-zero frequencies)
        mid_non_zero = (vmin_non_zero + vmax_non_zero) / 2
        sm = plt.cm.ScalarMappable(cmap=cmap_non_zero, norm=norm_non_zero)
        sm.set_array([])  # For older matplotlib versions
        cbar = fig.colorbar(sm, ax=ax, orientation='vertical', pad=0.1, ticks=[vmin_non_zero, mid_non_zero, vmax_non_zero])
        cbar.set_label("Mutation Frequency (Non-Zero)", fontsize=10)
        cbar.ax.set_yticklabels([f"{vmin_non_zero:.2f}", f"{mid_non_zero:.2f}", f"{vmax_non_zero:.2f}"])

        # Square cells, plus half a cell of extra padding around the grid
        ax.set_aspect('equal', adjustable='box')
        bottom, top = ax.get_ylim()
        ax.set_ylim(bottom - 0.5, top + 0.5)
        bottomx, topx = ax.get_xlim()
        ax.set_xlim(bottomx - 0.5, topx + 0.5)
        fig.tight_layout()
        # Save this figure explicitly rather than whatever pyplot considers current.
        fig.savefig(output_filename, dpi=300)
        print(f"Heatmap saved to '{output_filename}'")

    except FileNotFoundError:
        print(f"Error: Input file '{input_csv_file}' not found.")
    except KeyError as e:
        # Missing positions are handled above; this covers e.g. a missing
        # 'Position' column.
        print(f"Error: missing key in the input CSV file: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")

if __name__ == "__main__":
    input_csv = 'mutation_frequencies.csv'
    output_png = 'mutation_frequency_heatmap.png'

    # Positions to plot, in left-to-right order.
    selected_positions = [59, 81, 83, 92, 108, 120, 122, 141, 159, 160, 167]

    # Cells outlined in orange: position -> amino acids. Position 167 is
    # plotted but has no outlined cell.
    highlighted_cells = {
        position: [residue]
        for position, residue in (
            (59, 'Q'),
            (81, 'I'),
            (83, 'L'),
            (92, 'M'),
            (108, 'V'),
            (120, 'A'),
            (122, 'G'),
            (141, 'D'),
            (159, 'H'),
            (160, 'V'),
        )
    }

    plot_mutation_heatmap_v6(
        input_csv_file=input_csv,
        positions_to_plot=selected_positions,
        highlight_cells=highlighted_cells,
        output_filename=output_png,
    )