In [1]:
# Uncomment the line below and run this cell if you haven't installed the libraries
# !pip install biopython gffutils hmmlearn numpy pandas matplotlib seaborn

import gzip
from Bio import SeqIO
import gffutils
import os
import pandas as pd
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
# Configure matplotlib for inline plots
%matplotlib inline


In [2]:
# --- Configuration ---
# Directory that holds the FASTA/GTF inputs and receives the gffutils databases.
# Set the GENE_HMM_DATA_DIR environment variable to point somewhere else without
# editing the notebook; the fallback below is the original hard-coded location.
# Forward slashes work for paths in Python, even on Windows.
DATA_DIR = os.environ.get('GENE_HMM_DATA_DIR', 'C:/Users/User/Downloads')

# Max genes to process for training to manage memory/runtime for demonstration.
# For full analysis, increase this or implement chunking.
MAX_GENES_FOR_TRAINING = 10

# Hidden states of the HMM and the integer index each one is encoded as.
STATE_MAP = {
    'Intergenic': 0,
    'Start_Codon': 1,
    'Exon': 2,
    'Intron': 3,
    'Stop_Codon': 4
}
# Index -> state name, used when decoding predictions and labelling plots.
REVERSE_STATE_MAP = {v: k for k, v in STATE_MAP.items()}

# Observable symbols (DNA bases) and the integer index each one is encoded as.
BASE_MAP = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
# Index -> base letter, used when rendering sequences.
REVERSE_BASE_MAP = {v: k for k, v in BASE_MAP.items()}

In [3]:
def load_fasta(filepath):
    """
    Read every record of a FASTA file into memory.

    Files whose name ends in '.gz' are opened through gzip in text mode;
    anything else is opened as plain text.

    Args:
        filepath (str): Path to the FASTA file.

    Returns:
        dict: Sequence ID (text before the first space of the record ID)
              mapped to the upper-cased sequence string. An empty dict is
              returned if anything goes wrong while opening or parsing; the
              error is printed rather than raised.
    """
    print(f"Loading FASTA file: {filepath}")
    try:
        # Only the way the file is opened differs between the two formats;
        # parsing is shared.
        is_gzipped = filepath.endswith('.gz')
        handle = gzip.open(filepath, "rt") if is_gzipped else open(filepath, "r")
        with handle:
            # Use only the chromosome ID (e.g., '1' from '1 dna:chromosome...')
            sequences = {
                record.id.split(' ')[0]: str(record.seq).upper()
                for record in SeqIO.parse(handle, "fasta")
            }
        print(f"Loaded {len(sequences)} sequences from FASTA.")
    except Exception as e:
        print(f"Error loading FASTA file {filepath}: {e}")
        return {}
    return sequences

def create_gff_db(gtf_filepath, db_filepath):
    """
    Open the gffutils database at `db_filepath`, building it from the GTF
    file first if it does not exist yet.

    Args:
        gtf_filepath (str): Path to the GTF annotation file.
        db_filepath (str): Path of the gffutils database file to load or create.

    Returns:
        The gffutils database object, or None if the database could not be
        loaded or created (the error is printed). Both failure modes are
        reported the same way so callers only need a single `if not db` check.
    """
    print(f"Processing GTF file: {gtf_filepath}")

    if os.path.exists(db_filepath):
        print(f"Loading existing gffutils database: {db_filepath}")
        try:
            return gffutils.FeatureDB(db_filepath)
        except Exception as e:
            # e.g. a truncated file left behind by an interrupted build.
            print(f"Error loading gffutils database {db_filepath}: {e}")
            return None

    print(f"Creating new gffutils database: {db_filepath}")
    try:
        db = gffutils.create_db(
            gtf_filepath,
            dbfn=db_filepath,
            force=True,  # Overwrite if exists
            keep_order=True,
            # The GTF is expected to carry explicit gene/transcript rows, so
            # inference of those features is switched off.
            disable_infer_transcripts=True,
            disable_infer_genes=True
        )
    except Exception as e:
        print(f"Error creating gffutils database from {gtf_filepath}: {e}")
        return None

    print("gffutils database created.")
    return db


def _label_interval(labels, start, end, segment_start_0idx, state):
    """Set `labels` to `state` over the 1-based inclusive genomic interval [start, end], clipped to the segment."""
    first = max(0, start - 1 - segment_start_0idx)
    stop = min(len(labels), end - segment_start_0idx)  # exclusive
    if stop > first:
        labels[first:stop] = [state] * (stop - first)


def prepare_training_data(species_name, fasta_filepath, gtf_filepath, chrom_id='1',
                          max_genes=MAX_GENES_FOR_TRAINING, flanking_region=500):
    """
    Prepares labeled DNA sequences for HMM training.

    Loads the FASTA and GTF, takes up to `max_genes` genes from `chrom_id`,
    and for each gene cuts out the gene plus `flanking_region` bases on either
    side. Every base of that segment gets one state label, painted in order of
    increasing priority: Intergenic (default) < Intron (whole gene span) <
    Exon < Start_Codon < Stop_Codon.

    NOTE: strand is not taken into account. Sequences and labels are always in
    forward-strand orientation, so for minus-strand genes the stop codon
    precedes the start codon along the segment.

    Args:
        species_name (str): Name of the species (e.g., 'human'); used in log
            messages and in the database file name.
        fasta_filepath (str): Path to the FASTA file (gzipped or plain).
        gtf_filepath (str): Path to the GTF file.
        chrom_id (str): Chromosome ID to process (default '1').
        max_genes (int): Maximum number of genes to process; a falsy value
            means no limit.
        flanking_region (int): Bases of intergenic context kept on each side
            of a gene (default 500).

    Returns:
        list: A list of tuples (dna_sequence_numeric, state_labels_numeric).
              dna_sequence_numeric is a list of base indices (0-3).
              state_labels_numeric is a list of state indices (0-4).
              Genes whose segment contains a base outside BASE_MAP (e.g. 'N')
              are skipped. An empty list is returned if the chromosome or
              the annotation database is unavailable.

    Raises:
        ValueError: If `flanking_region` is negative.
    """
    if flanking_region < 0:
        raise ValueError("flanking_region must be non-negative")

    print(f"\n--- Preparing training data for {species_name} (Chromosome {chrom_id}) ---")
    fasta_sequences = load_fasta(fasta_filepath)
    if chrom_id not in fasta_sequences:
        print(f"Error: Chromosome {chrom_id} not found in {fasta_filepath}")
        return []
    chromosome_seq = fasta_sequences[chrom_id]

    db_filepath = os.path.join(DATA_DIR, f'{species_name}_{chrom_id}.db')
    db = create_gff_db(gtf_filepath, db_filepath)
    if not db:
        return []

    # gene_id -> feature type -> list of features, plus each gene's span.
    gene_features_map = defaultdict(lambda: defaultdict(list))
    gene_coords = {}  # gene_id -> (start, end), 1-based inclusive

    print(f"Collecting gene annotations for Chromosome {chrom_id}...")
    wanted_features = ('exon', 'CDS', 'start_codon', 'stop_codon')
    genes_processed_count = 0
    for gene in db.features_of_type('gene', seqid=chrom_id):
        if max_genes and genes_processed_count >= max_genes:
            break

        gene_id = gene.attributes.get('gene_id', [gene.id])[0]
        gene_coords[gene_id] = (gene.start, gene.end)

        for feature in db.children(gene, order_by='start'):
            if feature.feature in wanted_features:
                gene_features_map[gene_id][feature.feature].append(feature)
        genes_processed_count += 1
    print(f"Collected annotations for {genes_processed_count} genes.")

    # Painting order matters: later entries overwrite earlier ones.
    feature_states = (
        ('exon', STATE_MAP['Exon']),
        ('start_codon', STATE_MAP['Start_Codon']),
        ('stop_codon', STATE_MAP['Stop_Codon']),
    )

    training_sequences = []

    # Only genes that contributed at least one wanted feature appear here.
    for gene_id, features_by_type in gene_features_map.items():
        gene_start, gene_end = gene_coords[gene_id]

        # GTF coordinates are 1-based inclusive; convert to 0-based and clip
        # the flanked window to the chromosome.
        segment_start_0idx = max(0, gene_start - 1 - flanking_region)
        segment_end_0idx = min(len(chromosome_seq) - 1, gene_end - 1 + flanking_region)

        segment_dna = chromosome_seq[segment_start_0idx : segment_end_0idx + 1]
        segment_labels = [STATE_MAP['Intergenic']] * len(segment_dna)

        # Whole gene span starts as intron; annotated features refine it.
        _label_interval(segment_labels, gene_start, gene_end,
                        segment_start_0idx, STATE_MAP['Intron'])
        for feature_type, state in feature_states:
            for feature in features_by_type[feature_type]:
                _label_interval(segment_labels, feature.start, feature.end,
                                segment_start_0idx, state)

        dna_numeric = [BASE_MAP.get(base, -1) for base in segment_dna]

        if -1 in dna_numeric:
            print(f"Warning: Skipping gene {gene_id} due to unknown bases.")
            continue
        if not dna_numeric:
            print(f"Warning: Skipping empty sequence for gene {gene_id}.")
            continue

        training_sequences.append((dna_numeric, segment_labels))

    print(f"Prepared {len(training_sequences)} labeled sequences for training.")
    return training_sequences


In [4]:
from hmmlearn import hmm

class HMMGenePredictor:
    """
    Hidden Markov Model (HMM) based gene structure predictor.

    Hidden states are the gene-structure labels in `state_map`; observations
    are DNA bases encoded with `base_map`. Parameters are estimated by
    counting over fully labelled sequences (see `train`), and decoding uses
    the Viterbi algorithm (see `predict`).

    The underlying model is hmmlearn's categorical-emission HMM. Since
    hmmlearn 0.2.8 that model is `CategoricalHMM` (and `MultinomialHMM` has
    different semantics); on older releases, where `CategoricalHMM` does not
    exist, the legacy `MultinomialHMM` is used instead.
    """
    def __init__(self, n_states=len(STATE_MAP), n_emissions=len(BASE_MAP),
                 state_map=STATE_MAP, base_map=BASE_MAP):
        """
        Initializes the HMMGenePredictor with hand-set prior parameters.

        Args:
            n_states (int): Number of hidden states.
            n_emissions (int): Number of distinct observable symbols.
            state_map (dict): State name -> state index.
            base_map (dict): Base letter -> symbol index.
        """
        self.n_states = n_states
        self.n_emissions = n_emissions
        self.state_map = state_map
        self.base_map = base_map

        # Pick the categorical model for whichever hmmlearn is installed.
        hmm_class = getattr(hmm, 'CategoricalHMM', hmm.MultinomialHMM)
        self.model = hmm_class(n_components=self.n_states, n_iter=100, tol=1e-4)
        # Set as an attribute rather than a constructor argument, because not
        # every hmmlearn release accepts `n_features` in the constructor.
        self.model.n_features = self.n_emissions

        self.model.startprob_ = np.full(self.n_states, 1.0 / self.n_states)

        # Prior transition structure: Intergenic -> Start -> Exon <-> Intron,
        # Exon -> Stop -> Intergenic. Replaced by counts once train() runs.
        initial_transmat = np.zeros((self.n_states, self.n_states))

        initial_transmat[self.state_map['Intergenic'], self.state_map['Intergenic']] = 0.95
        initial_transmat[self.state_map['Intergenic'], self.state_map['Start_Codon']] = 0.05

        initial_transmat[self.state_map['Start_Codon'], self.state_map['Exon']] = 1.0

        initial_transmat[self.state_map['Exon'], self.state_map['Exon']] = 0.9
        initial_transmat[self.state_map['Exon'], self.state_map['Intron']] = 0.08
        initial_transmat[self.state_map['Exon'], self.state_map['Stop_Codon']] = 0.02

        initial_transmat[self.state_map['Intron'], self.state_map['Intron']] = 0.9
        initial_transmat[self.state_map['Intron'], self.state_map['Exon']] = 0.1

        initial_transmat[self.state_map['Stop_Codon'], self.state_map['Intergenic']] = 1.0

        self.model.transmat_ = self._normalize_rows(initial_transmat)

        self.model.emissionprob_ = np.full((self.n_states, self.n_emissions), 1.0 / self.n_emissions)

    @staticmethod
    def _normalize_rows(counts):
        """Turn each row of `counts` into a probability distribution; all-zero rows become uniform."""
        counts = np.asarray(counts, dtype=float)
        totals = counts.sum(axis=1, keepdims=True)
        has_data = totals > 0
        uniform = np.full(counts.shape, 1.0 / counts.shape[1])
        # The inner where() avoids dividing by zero for empty rows.
        return np.where(has_data, counts / np.where(has_data, totals, 1.0), uniform)

    def train(self, training_data):
        """
        Estimates HMM parameters by counting over labelled sequences.

        Args:
            training_data (list): (dna_sequence_numeric, state_labels_numeric)
                pairs of equal length. Empty or length-mismatched pairs are
                ignored.

        Start probabilities come from the first label of each sequence,
        transitions from consecutive label pairs, and emissions from
        (label, base) pairs. A state that was never observed gets a uniform
        row. If `training_data` is empty the model keeps its prior parameters.
        """
        if not training_data:
            print("No training data provided. HMM will not be trained.")
            return

        print("Estimating HMM parameters from labeled data...")
        start_counts = np.zeros(self.n_states)
        transition_counts = np.zeros((self.n_states, self.n_states))
        emission_counts = np.zeros((self.n_states, self.n_emissions))

        for dna_seq, state_seq in training_data:
            if len(dna_seq) == 0 or len(dna_seq) != len(state_seq):
                continue

            bases = np.asarray(dna_seq, dtype=int)
            states = np.asarray(state_seq, dtype=int)

            start_counts[states[0]] += 1
            # np.add.at accumulates correctly when an index pair repeats,
            # which plain fancy-index "+=" would not.
            np.add.at(emission_counts, (states, bases), 1)
            np.add.at(transition_counts, (states[:-1], states[1:]), 1)

        total_starts = np.sum(start_counts)
        if total_starts > 0:
            self.model.startprob_ = start_counts / total_starts
        else:
            self.model.startprob_ = np.full(self.n_states, 1.0 / self.n_states)

        self.model.transmat_ = self._normalize_rows(transition_counts)
        self.model.emissionprob_ = self._normalize_rows(emission_counts)

        print("HMM parameters estimated from labeled data.")


    def predict(self, dna_sequence_numeric):
        """
        Predicts the most likely sequence of hidden states (gene structure)
        for a given DNA sequence using the Viterbi algorithm.

        Args:
            dna_sequence_numeric: Sequence of base indices.

        Returns:
            list: One state index per input base. Returns [] for empty input
                  or a missing model. If decoding raises, the error is printed
                  and an all-Intergenic labelling of the same length is
                  returned.
        """
        if self.model is None:
            print("Error: HMM model not trained. Please train the model first.")
            return []
        if dna_sequence_numeric is None or len(dna_sequence_numeric) == 0:
            return []

        # hmmlearn expects a column vector: one observation per row.
        X = np.array(dna_sequence_numeric).reshape(-1, 1)

        try:
            log_prob, state_sequence = self.model.decode(X, algorithm="viterbi")
            return state_sequence.tolist()
        except Exception as e:
            print(f"Error during prediction: {e}")
            return [self.state_map['Intergenic']] * len(dna_sequence_numeric)


    def get_parameters(self):
        """
        Returns the current HMM parameters as a tuple
        (start probabilities, transition matrix, emission matrix),
        or (None, None, None) if there is no model.
        """
        if self.model is None:
            return None, None, None
        return self.model.startprob_, self.model.transmat_, self.model.emissionprob_

    def plot_parameters(self, title="HMM Parameters"):
        """
        Plots start probabilities (bar chart) and the transition and emission
        matrices (heatmaps) side by side.

        Axis labels are taken from this instance's `state_map` / `base_map`;
        an index with no name falls back to its number.
        """
        if self.model is None:
            print("Model not trained, cannot plot parameters.")
            return

        startprob = self.model.startprob_
        transmat = self.model.transmat_
        emissionprob = self.model.emissionprob_

        index_to_state = {idx: name for name, idx in self.state_map.items()}
        index_to_base = {idx: name for name, idx in self.base_map.items()}
        state_names = [index_to_state.get(i, str(i)) for i in range(self.n_states)]
        base_names = [index_to_base.get(i, str(i)) for i in range(self.n_emissions)]

        fig, axes = plt.subplots(1, 3, figsize=(20, 7))
        fig.suptitle(title, fontsize=16)

        sns.barplot(x=state_names, y=startprob, ax=axes[0])
        axes[0].set_title("Start Probabilities")
        axes[0].set_ylabel("Probability")
        axes[0].set_xlabel("State")
        axes[0].tick_params(axis='x', rotation=45)

        sns.heatmap(transmat, annot=True, cmap="viridis", fmt=".2f",
                    xticklabels=state_names, yticklabels=state_names, ax=axes[1])
        axes[1].set_title("Transition Matrix")
        axes[1].set_xlabel("To State")
        axes[1].set_ylabel("From State")

        sns.heatmap(emissionprob, annot=True, cmap="magma", fmt=".2f",
                    xticklabels=base_names, yticklabels=state_names, ax=axes[2])
        axes[2].set_title("Emission Matrix")
        axes[2].set_xlabel("Base (A, C, G, T)")
        axes[2].set_ylabel("State")

        # Leave room at the top for the figure-level title.
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

In [5]:
def calculate_metrics(true_labels, predicted_labels, states_to_evaluate=None):
    """
    Per-state Precision, Recall and F1-score for a pair of label sequences.

    Args:
        true_labels: Ground-truth state indices, one per position.
        predicted_labels: Predicted state indices, same length.
        states_to_evaluate: State indices to report on; defaults to every
            index in STATE_MAP.

    Returns:
        dict: State name -> {'Precision', 'Recall', 'F1-score'}. Any ratio
              with a zero denominator is reported as 0.0. An empty dict is
              returned (after printing an error) when the two sequences
              differ in length.
    """
    if len(true_labels) != len(predicted_labels):
        print("Error: Length of true labels and predicted labels must be the same.")
        return {}

    if states_to_evaluate is None:
        states_to_evaluate = list(STATE_MAP.values())

    # One sweep over the data: how often each label occurs as truth, as
    # prediction, and as both at the same position.
    seen_as_truth = {}
    seen_as_prediction = {}
    seen_as_both = {}
    for actual, guess in zip(true_labels, predicted_labels):
        seen_as_truth[actual] = seen_as_truth.get(actual, 0) + 1
        seen_as_prediction[guess] = seen_as_prediction.get(guess, 0) + 1
        if actual == guess:
            seen_as_both[actual] = seen_as_both.get(actual, 0) + 1

    report = {}
    for state_idx in states_to_evaluate:
        state_name = REVERSE_STATE_MAP[state_idx]

        tp = seen_as_both.get(state_idx, 0)
        fp = seen_as_prediction.get(state_idx, 0) - tp
        fn = seen_as_truth.get(state_idx, 0) - tp

        precision = 0.0
        if (tp + fp) > 0:
            precision = tp / (tp + fp)

        recall = 0.0
        if (tp + fn) > 0:
            recall = tp / (tp + fn)

        f1_score = 0.0
        if (precision + recall) > 0:
            f1_score = 2 * (precision * recall) / (precision + recall)

        report[state_name] = {
            'Precision': precision,
            'Recall': recall,
            'F1-score': f1_score
        }
    return report


def evaluate_gene_boundaries(true_gene_coords, predicted_gene_coords, tolerance=0):
    """
    Evaluates how many true gene start/end boundaries were recovered.

    A true boundary counts as recovered when at least one predicted boundary
    of the same kind lies within `tolerance` positions of it. Starts and ends
    are scored independently, so a start and an end may be matched by
    different predicted genes. This is a simplified evaluation; a full one
    would compare predicted gene intervals with true gene intervals.

    Each true boundary is counted at most once, so both scores always lie in
    [0, 1] regardless of how many predictions there are.

    Args:
        true_gene_coords: Iterable of (start, end) pairs for the true genes.
        predicted_gene_coords: Iterable of (start, end) pairs for predictions.
        tolerance (int): Maximum allowed distance for a match (default 0).

    Returns:
        dict: {'start_accuracy': float, 'end_accuracy': float}; both are 0.0
              when there are no true genes.
    """
    true_pairs = list(true_gene_coords)
    pred_pairs = list(predicted_gene_coords)

    def recovered_fraction(true_positions, predicted_positions):
        if not true_positions:
            return 0.0
        recovered = sum(
            1 for truth in true_positions
            if any(abs(pred - truth) <= tolerance for pred in predicted_positions)
        )
        return recovered / len(true_positions)

    return {
        'start_accuracy': recovered_fraction([s for s, _ in true_pairs],
                                             [s for s, _ in pred_pairs]),
        'end_accuracy': recovered_fraction([e for _, e in true_pairs],
                                           [e for _, e in pred_pairs])
    }


def identify_gene_segments(predicted_labels_numeric):
    """
    Collapse a per-base state sequence into gene intervals.

    A gene is taken to be any maximal run of consecutive positions labelled
    Start_Codon, Exon, Intron or Stop_Codon; every other label ends the run.

    Args:
        predicted_labels_numeric: Sequence of state indices.

    Returns:
        list: (start, end) index pairs, 0-based and inclusive, in order of
              appearance.
    """
    genic_states = {
        STATE_MAP[name]
        for name in ('Start_Codon', 'Exon', 'Intron', 'Stop_Codon')
    }

    segments = []
    run_start = None  # index where the currently open run began, if any

    for position, state in enumerate(predicted_labels_numeric):
        inside = state in genic_states
        if inside and run_start is None:
            run_start = position
        elif not inside and run_start is not None:
            segments.append((run_start, position - 1))
            run_start = None

    # A run that reaches the end of the sequence is closed at the last index.
    if run_start is not None:
        segments.append((run_start, len(predicted_labels_numeric) - 1))

    return segments


def plot_metrics(all_metrics, metric_name, title):
    """
    Grouped bar chart of one metric, with a bar group per state and a bar
    per species.

    Args:
        all_metrics (dict): species -> state name -> metrics dict.
        metric_name (str): Key to read from each metrics dict
            (e.g. 'Precision', 'Recall' or 'F1-score').
        title (str): Title of the figure.
    """
    # Flatten the nested mapping into long-form records for seaborn.
    records = [
        {'Species': species, 'State': state, 'Value': values[metric_name]}
        for species, metrics_by_state in all_metrics.items()
        for state, values in metrics_by_state.items()
    ]
    df = pd.DataFrame(records)

    plt.figure(figsize=(12, 7))
    ax = sns.barplot(x='State', y='Value', hue='Species', data=df, palette='viridis')
    ax.set_title(title, fontsize=16)
    ax.set_ylabel(metric_name, fontsize=12)
    ax.set_xlabel("Gene State", fontsize=12)
    ax.set_ylim(0, 1.05)
    ax.legend(title="Trained Model Species")
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.show()


def plot_gene_structure(dna_sequence, true_labels, predicted_labels,
                        title="Gene Structure Prediction", max_bases=200):
    """
    Visualizes the start of a DNA segment with true and predicted gene structures.

    Two colour-coded tracks are drawn ('True' on top, 'Predicted' below), one
    unit-wide cell per base, with the base letters printed underneath.

    Args:
        dna_sequence: Sequence of base indices.
        true_labels: Sequence of true state indices.
        predicted_labels: Sequence of predicted state indices.
        title (str): Title of the figure.
        max_bases (int): Maximum number of leading bases to display
            (default 200).

    Returns:
        None. Prints a message and draws nothing if any input is missing or
        empty.
    """
    for seq in (dna_sequence, true_labels, predicted_labels):
        # len()-based check so that lists and arrays are both accepted.
        if seq is None or len(seq) == 0:
            print("Cannot plot: missing sequence or labels.")
            return

    display_len = min(len(dna_sequence), max_bases)

    dna_str = "".join([REVERSE_BASE_MAP[b] for b in dna_sequence[:display_len]])
    true_state_names = [REVERSE_STATE_MAP[s] for s in true_labels[:display_len]]
    predicted_state_names = [REVERSE_STATE_MAP[s] for s in predicted_labels[:display_len]]

    fig, ax = plt.subplots(figsize=(15, 6))

    state_colors = {
        'Intergenic': 'lightgray',
        'Start_Codon': 'darkgreen',
        'Exon': 'green',
        'Intron': 'orange',
        'Stop_Codon': 'darkred'
    }

    def draw_track(state_names, bottom):
        # One unit-wide, unit-high cell per base; cell i spans [i, i + 1].
        ax.bar(np.arange(len(state_names)), 1, bottom=bottom,
               color=[state_colors.get(s, 'gray') for s in state_names],
               width=1.0, align='edge')

    draw_track(true_state_names, bottom=1.5)       # occupies y in [1.5, 2.5]
    draw_track(predicted_state_names, bottom=0.5)  # occupies y in [0.5, 1.5]

    for i, base in enumerate(dna_str):
        ax.text(i + 0.5, 0.0, base, ha='center', va='bottom', fontsize=8, color='black')

    ax.set_yticks([1.0, 2.0])
    ax.set_yticklabels(['Predicted', 'True'])
    ax.set_xticks(np.arange(0, display_len, 10))
    ax.set_xticklabels(np.arange(0, display_len, 10))
    ax.set_xlabel("Position (bases)")
    ax.set_title(title)
    # Cells are edge-aligned, so the drawn range is exactly [0, display_len].
    ax.set_xlim(0, display_len)
    ax.set_ylim(-0.5, 2.5)

    # Legend built from the colour table so every state is listed once,
    # whether or not it occurs in the displayed window.
    handles = [plt.Rectangle((0, 0), 1, 1, color=state_colors[s]) for s in state_colors]
    labels = list(state_colors.keys())
    ax.legend(handles, labels, title="States", loc='upper left', bbox_to_anchor=(1, 1))

    plt.tight_layout()
    plt.show()

In [None]:
def run_gene_prediction_pipeline():
    """
    Orchestrates the entire gene prediction pipeline:
    1. Data preparation for each species (done once and reused).
    2. HMM training for each species.
    3. Same-species and cross-species prediction.
    4. Evaluation and visualization of results.

    NOTE: each species is evaluated on the same labelled segments its own
    model was trained on, so same-species scores are in-sample; only the
    cross-species scores are measured on data the model has not seen.
    """
    species_info = {
        'human': {
            'fasta': os.path.join(DATA_DIR, 'Homo_sapiens.GRCh38.dna.chromosome.1.fa.gz'),
            'gtf': os.path.join(DATA_DIR, 'Homo_sapiens.GRCh38.114.gtf.gz'),
            'chrom_id': '1'
        },
        'mouse': {
            'fasta': os.path.join(DATA_DIR, 'Mus_musculus.GRCm39.dna.chromosome.1.fa.gz'),
            'gtf': os.path.join(DATA_DIR, 'Mus_musculus.GRCm39.114.gtf.gz'),
            'chrom_id': '1'
        },
        'zebrafish': {
            'fasta': os.path.join(DATA_DIR, 'Danio_rerio.GRCz11.dna.chromosome.1.fa.gz'),
            'gtf': os.path.join(DATA_DIR, 'Danio_rerio.GRCz11.114.gtf.gz'),
            'chrom_id': '1'
        }
    }

    labelled_data = {}   # species -> list of (dna_numeric, labels_numeric)
    trained_models = {}  # species -> trained HMMGenePredictor
    all_evaluation_results = defaultdict(dict)

    # --- Step 1 & 2: Data Preparation and HMM Training for each species ---
    print("\n--- Training HMMs for each species ---")
    for species_name, info in species_info.items():
        print(f"\nProcessing {species_name}...")
        # Loading a chromosome and its annotations is the expensive part, so
        # the result is kept and reused for every evaluation below instead of
        # being rebuilt for each (trained, tested) pair.
        labelled_data[species_name] = prepare_training_data(
            species_name, info['fasta'], info['gtf'], info['chrom_id'], max_genes=MAX_GENES_FOR_TRAINING
        )

        if not labelled_data[species_name]:
            print(f"Skipping training for {species_name} due to no data.")
            continue

        predictor = HMMGenePredictor()
        predictor.train(labelled_data[species_name])
        trained_models[species_name] = predictor
        print(f"HMM for {species_name} trained successfully.")
        predictor.plot_parameters(title=f"HMM Parameters Trained on {species_name.capitalize()} Data")

    if not trained_models:
        print("No models were successfully trained. Exiting.")
        return

    # --- Step 3 & 4: Prediction and Evaluation (Same-species and Cross-species) ---
    print("\n--- Performing Predictions and Evaluations ---")
    for trained_on_species, predictor_model in trained_models.items():
        print(f"\n--- Model trained on {trained_on_species.capitalize()} ---")
        for tested_on_species in species_info:
            print(f"  Testing on {tested_on_species.capitalize()} data...")

            test_data = labelled_data.get(tested_on_species)
            if not test_data:
                print(f"    No test data for {tested_on_species}. Skipping evaluation.")
                continue

            all_true_labels = []
            all_predicted_labels = []
            sample_for_plot = None  # (dna, true labels, predicted labels) of the first sequence

            for i, (dna_seq_numeric, true_labels_numeric) in enumerate(test_data):
                predicted_labels_numeric = predictor_model.predict(dna_seq_numeric)

                all_true_labels.extend(true_labels_numeric)
                all_predicted_labels.extend(predicted_labels_numeric)

                if i == 0:
                    sample_for_plot = (dna_seq_numeric, true_labels_numeric, predicted_labels_numeric)

            metrics = calculate_metrics(all_true_labels, all_predicted_labels)
            all_evaluation_results[trained_on_species][tested_on_species] = metrics

            print(f"    Evaluation for model trained on {trained_on_species} and tested on {tested_on_species}:")
            for state, m in metrics.items():
                print(f"      {state}: P={m['Precision']:.2f}, R={m['Recall']:.2f}, F1={m['F1-score']:.2f}")

            # Plot only when sequence, truth and prediction are all non-empty.
            if sample_for_plot is not None and all(sample_for_plot):
                sample_dna, sample_true, sample_predicted = sample_for_plot
                plot_gene_structure(
                    sample_dna,
                    sample_true,
                    sample_predicted,
                    title=f"Sample Gene Structure: Trained on {trained_on_species.capitalize()}, Tested on {tested_on_species.capitalize()}"
                )

    # --- Comparative Analysis and Final Plots ---
    print("\n--- Comparative Analysis ---")

    f1_scores_data = [
        {
            'Trained On': trained_on.capitalize(),
            'Tested On': tested_on.capitalize(),
            'State': state,
            'F1-score': values['F1-score']
        }
        for trained_on, test_results in all_evaluation_results.items()
        for tested_on, metrics_by_state in test_results.items()
        for state, values in metrics_by_state.items()
    ]

    df_f1 = pd.DataFrame(f1_scores_data)

    if df_f1.empty:
        # Without rows the frame has no 'State' column to filter on.
        print("No evaluation results available to plot.")
    else:
        for state_name in ['Exon', 'Intron', 'Start_Codon', 'Stop_Codon', 'Intergenic']:
            state_df = df_f1[df_f1['State'] == state_name]
            if state_df.empty:
                continue
            plt.figure(figsize=(10, 6))
            sns.barplot(x='Tested On', y='F1-score', hue='Trained On', data=state_df, palette='Spectral')
            plt.title(f'F1-score for {state_name} Prediction Across Species', fontsize=16)
            plt.ylabel('F1-score', fontsize=12)
            plt.xlabel('Species Tested On', fontsize=12)
            plt.ylim(0, 1.05)
            plt.legend(title='Model Trained On')
            plt.grid(axis='y', linestyle='--', alpha=0.7)
            plt.tight_layout()
            plt.show()

    print("\n--- Project Completed ---")
    print("Results printed to console and plots generated.")

# Run the pipeline: trains one HMM per species, evaluates every
# (trained-on, tested-on) combination, and renders the plots inline.
# The data-preparation step reads whole-chromosome FASTA files and builds the
# gffutils database when its .db file is not already present in DATA_DIR.
run_gene_prediction_pipeline()



--- Training HMMs for each species ---

Processing human...

--- Preparing training data for human (Chromosome 1) ---
Loading FASTA file: C:/Users/User/Downloads\Homo_sapiens.GRCh38.dna.chromosome.1.fa.gz
Loaded 1 sequences from FASTA.
Processing GTF file: C:/Users/User/Downloads\Homo_sapiens.GRCh38.114.gtf.gz
Creating new gffutils database: C:/Users/User/Downloads\human_1.db
