In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Plate-type configuration consumed by process_plate():
#   files     -- input CSVs, processed in list order; missing files are skipped
#   wells     -- well names as a grid: one inner list per plate row (A, B, ...)
#   save_dir  -- root output directory for the raster PNGs
#   top_n_row -- (6/24-well) number of most active wells kept from each row
#   logic     -- (head-to-head) grouped selection; see the per-group notes below
PLATE_CONFIG = {
    "6-well": {
        "files": [f"rasters_data/6_well_file_{i:03d}_spike_counts.csv" for i in range(1, 12)],
        "wells": [[f"{row}{col}" for col in range(1, 4)] for row in "AB"],
        "save_dir": "plots/rasters/6_well",
        "top_n_row": 2,
    },
    "24-well": {
        "files": [f"rasters_data/24_well_file_{i:03d}_spike_counts.csv" for i in range(1, 35)],
        "wells": [[f"{row}{col}" for col in range(1, 7)] for row in "ABCD"],
        "save_dir": "plots/rasters/24_well",
        "top_n_row": 3,
    },
    "head-to-head": {
        "files": [f"rasters_data/H2H_file_{i:03d}_spike_counts.csv" for i in range(1, 17)],
        "wells": [[f"{row}{col}" for col in range(1, 7)] for row in "ABCD"],
        "save_dir": "plots/rasters/H2H",
        "logic": {
            # 3 most active wells of row A (whole row).
            "group1": {"rows": [0], "top_n": 3},
            # "cols" applies to every listed row -> candidates are B1-B4, C1-C4,
            # D1-D4; the 3 most active of those are kept.
            # NOTE(review): an earlier comment said "B, C, D1-D4", but as
            # configured B5, B6, C5, C6 belong to no group -- confirm intent.
            "group2": {"rows": [1, 2, 3], "cols": [0, 1, 2, 3], "top_n": 3},
            # No "top_n": D5 and D6 are always plotted individually.
            "group3": {"rows": [3], "cols": [4, 5]},
        },
    },
}

# Function to calculate total activity for a well
def calculate_activity(data, well_name):
    """Return the summed activity of one well.

    Every column of ``data`` whose name starts with ``well_name`` (e.g. "A1")
    is treated as belonging to that well. Each such column is summed, and
    those per-column totals are then summed. When no column matches, the
    result is 0.

    NOTE(review): this is plain prefix matching, so on plates with 10+
    columns "A1" would also match "A10", "A11", ... -- harmless for the
    6- and 24-well layouts configured here.
    """
    electrode_cols = list(filter(lambda col: col.startswith(well_name), data.columns))
    column_totals = data[electrode_cols].sum(axis=0)
    return column_totals.sum()

# Function to plot raster for a well
def plot_raster(data, well_name, save_path, xlim=(0, 300)):
    """Draw and save a spike raster covering every electrode column of one well.

    Parameters
    ----------
    data : pandas.DataFrame
        Binned spike data. It must contain an ``Interval_Start`` column, used
        as the x position of each tick, plus one column per electrode. The
        well's electrode columns are those whose name starts with
        ``well_name``.
    well_name : str
        Well identifier such as ``"A1"``. Used for column matching, the plot
        title and the output file name.
    save_path : str
        Directory the PNG is written to; created if it does not exist.
    xlim : tuple of float, optional
        x-axis limits passed to ``Axes.set_xlim``. Defaults to ``(0, 300)``.

    Side effects: writes ``<save_path>/<well_name>_raster.png`` and prints
    its path.
    """
    os.makedirs(save_path, exist_ok=True)
    well_columns = [col for col in data.columns if col.startswith(well_name)]
    interval_start = data["Interval_Start"]

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        # Electrode rows are numbered from 1 to match the y-ticks set below.
        for row, electrode in enumerate(well_columns, start=1):
            # The input holds spike *counts* per interval. Mark every interval
            # with at least one spike; "== 1" would drop bins holding 2+ spikes.
            spike_times = interval_start[data[electrode] > 0]
            # Centre each electrode's ticks on its own y-tick so the rows line
            # up with their labels.
            ax.vlines(spike_times, row - 0.25, row + 0.25, color='black')

        ax.set_xlim(*xlim)
        # max(..., 1) avoids identical low/high limits for a well with no columns.
        ax.set_ylim(0.5, max(len(well_columns), 1) + 0.5)
        ax.set_xlabel("Time (s)", fontsize=14)
        ax.set_ylabel("Electrodes", fontsize=14)
        ax.set_title(f"Raster Plot for Well {well_name}", fontsize=16)
        ax.set_yticks(range(1, len(well_columns) + 1))
        ax.set_yticklabels(well_columns, fontsize=8)
        fig.tight_layout()

        # Save plot
        plot_file = os.path.join(save_path, f"{well_name}_raster.png")
        fig.savefig(plot_file)
        print(f"Saved raster plot for {well_name} at {plot_file}")
    finally:
        # Always release the figure, even if drawing or saving fails, so batch
        # runs over many wells and files do not pile up open figures.
        plt.close(fig)

# Function to process files for a plate type
def _top_wells(data, wells, top_n):
    """Return up to ``top_n`` of ``wells``, most active first (ties keep input order)."""
    activities = {well: calculate_activity(data, well) for well in wells}
    # sorted() is stable, so equally active wells keep their order in `wells`.
    return sorted(activities, key=activities.get, reverse=True)[:top_n]


def _group_wells(grid, group):
    """Expand a grouped-selection spec into a flat list of well names.

    ``group["rows"]`` holds row indices into ``grid``. The optional
    ``group["cols"]`` holds column indices applied to every listed row; when
    it is absent, whole rows are used.
    """
    cols = group.get("cols")
    wells = []
    for r in group["rows"]:
        row = grid[r]
        wells.extend(row if cols is None else [row[c] for c in cols])
    return wells


def _select_most_active_wells(data, config):
    """Return the wells to plot for one data file, following the plate's config."""
    grid = config["wells"]
    # Dispatch on the config's shape rather than the plate name, so any new
    # plate type that follows one of the two schemas works without code changes.
    if "logic" in config:
        selected = []
        for group in config["logic"].values():
            group_wells = _group_wells(grid, group)
            if "top_n" in group:
                selected.extend(_top_wells(data, group_wells, group["top_n"]))
            else:
                # Groups without "top_n" are always plotted in full.
                selected.extend(group_wells)
        return selected
    if "top_n_row" in config:
        return [well for row in grid for well in _top_wells(data, row, config["top_n_row"])]
    raise ValueError(
        f"Plate config needs either 'top_n_row' or 'logic' to select wells; "
        f"got keys {sorted(config)}"
    )


def process_plate(plate_name, config):
    """Plot rasters of the most active wells for every data file of one plate type.

    Parameters
    ----------
    plate_name : str
        Label used in progress messages (e.g. ``"6-well"``).
    config : dict
        One ``PLATE_CONFIG`` entry. It must contain ``"files"``, ``"wells"``
        and ``"save_dir"``, plus one selection rule:
        - ``"top_n_row"``: keep the N most active wells of each row.
        - ``"logic"``: grouped selection. Groups with ``"top_n"`` keep their
          N most active wells; groups without it are plotted in full.

    Missing input files are reported and skipped. Rasters for each input file
    are written to ``<save_dir>/<file stem>/``.

    Raises
    ------
    ValueError
        If ``config`` has neither ``"top_n_row"`` nor ``"logic"``. This is
        raised when the first existing file is processed.
    """
    print(f"Processing plate: {plate_name}")
    for file_path in config["files"]:
        if not os.path.exists(file_path):
            print(f"File {file_path} not found. Skipping...")
            continue
        data = pd.read_csv(file_path)

        # Every file of a plate uses the same well names. Writing them all to
        # one directory would let each file's "<well>_raster.png" overwrite the
        # previous file's, so each file gets its own sub-directory.
        file_stem = os.path.splitext(os.path.basename(file_path))[0]
        file_save_dir = os.path.join(config["save_dir"], file_stem)

        for well in _select_most_active_wells(data, config):
            plot_raster(data, well, file_save_dir)

# Run the pipeline: process every configured plate type in PLATE_CONFIG order.
for plate_name in PLATE_CONFIG:
    config = PLATE_CONFIG[plate_name]
    process_plate(plate_name, config)

Processing plate: 6-well
File rasters_data/6_well_file_001_spike_counts.csv not found. Skipping...
File rasters_data/6_well_file_002_spike_counts.csv not found. Skipping...
File rasters_data/6_well_file_003_spike_counts.csv not found. Skipping...
File rasters_data/6_well_file_004_spike_counts.csv not found. Skipping...
File rasters_data/6_well_file_005_spike_counts.csv not found. Skipping...
File rasters_data/6_well_file_006_spike_counts.csv not found. Skipping...
File rasters_data/6_well_file_007_spike_counts.csv not found. Skipping...
File rasters_data/6_well_file_008_spike_counts.csv not found. Skipping...
File rasters_data/6_well_file_009_spike_counts.csv not found. Skipping...
File rasters_data/6_well_file_010_spike_counts.csv not found. Skipping...
File rasters_data/6_well_file_011_spike_counts.csv not found. Skipping...
Processing plate: 24-well
File rasters_data/24_well_file_001_spike_counts.csv not found. Skipping...
File rasters_data/24_well_file_002_spike_counts.csv not foun