In [None]:
# Ch16-bonus-agent 

In [None]:
## AI Agent for Bioinformatics ##
# Finds genes in NCBI and runs a BLAST search on a sequence
# Note: replace the placeholder email with your own address before running (NCBI uses it as a contact)

In [None]:
# Import Libraries
import os
import ssl
import urllib3
import requests
import time
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from Bio import Entrez, SeqIO
from Bio.Blast import NCBIWWW, NCBIXML
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import warnings
from typing import List, Dict, Optional, Tuple
import io
from datetime import datetime

In [None]:
# Disable SSL warnings and verification
# Silence urllib3's InsecureRequestWarning, which it emits for HTTPS requests
# made without certificate verification (e.g. via a session with verify=False).
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# SECURITY WARNING: this replaces the ssl module's process-wide default HTTPS
# context factory with an unverified one. For the rest of this kernel session,
# stdlib HTTPS clients that rely on that default (e.g. urllib) skip certificate
# and hostname verification -- not only the NCBI calls in this notebook.
# Intended only as a workaround for broken local certificate setups; remove
# this line if your environment can verify certificates.
ssl._create_default_https_context = ssl._create_unverified_context

class GeneBlastAgent:
    """
    AI Agent for gene searching and BLAST analysis with enhanced error handling,
    SSL bypass, retry handling, and result visualization.

    Typical workflow:
        search_genes  -> Entrez esearch + efetch, parsed with SeqIO
        run_blast     -> remote BLAST via Biopython's NCBIWWW.qblast
        visualize_blast_results / create_alignment_visualization / export_results
    """

    # E-values of exactly 0.0 cannot be log-transformed; for plotting they are
    # clipped to this floor before computing -log10(E-value).
    EVALUE_FLOOR = 1e-180

    def __init__(self, email: str, api_key: Optional[str] = None):
        """
        Initialize the BLAST AI Agent.

        Args:
            email (str): Your email for NCBI API access (assigned to Entrez.email)
            api_key (str, optional): NCBI API key for increased rate limits
        """
        self.email = email
        self.api_key = api_key
        Entrez.email = email
        if api_key:
            Entrez.api_key = api_key

        # Requests session with certificate verification disabled.
        # NOTE(review): this session is not passed to the Entrez / NCBIWWW calls
        # below; it is kept for backward compatibility.
        self.session = requests.Session()
        self.session.verify = False

        # BLAST parameters
        # NOTE(review): blast_timeout is not enforced by run_blast (qblast is
        # called without any timeout argument); kept for backward compatibility.
        self.blast_timeout = 300  # 5 minutes (nominal)
        self.max_retries = 3

        print(f"🧬 Gene BLAST Agent initialized")
        print(f"📧 Email: {email}")
        print(f"🔑 API Key: {'Provided' if api_key else 'Not provided'}")
        print(f"⚠️  SSL verification: Disabled")

    def search_genes(self, query: str, database: str = "nucleotide",
                     max_results: int = 10) -> List[Dict]:
        """
        Search for sequence records in an NCBI sequence database.

        Runs an Entrez ``esearch`` to collect matching IDs, then ``efetch``es
        them as flat files and parses them with ``SeqIO`` ("genbank" format).

        Args:
            query (str): Search query (gene name, organism, etc.)
            database (str): NCBI sequence database ('nucleotide' or 'protein').
                Non-sequence databases such as 'pubmed' cannot be parsed as
                GenBank records and will yield an empty list.
            max_results (int): Maximum number of results to return

        Returns:
            List[Dict]: One dict per record with keys 'id', 'description',
            'length', 'sequence', 'organism' and 'accession'. Empty when
            nothing matches or when any error occurs (the error is printed).
        """
        print(f"🔍 Searching for '{query}' in {database} database...")

        # Flat-file type for efetch: GenPept for protein, GenBank otherwise.
        rettype = "gp" if database == "protein" else "gb"

        try:
            # Search for IDs
            search_handle = Entrez.esearch(
                db=database,
                term=query,
                retmax=max_results,
                sort="relevance"
            )
            try:
                search_results = Entrez.read(search_handle)
            finally:
                search_handle.close()

            id_list = search_results["IdList"]

            if not id_list:
                print("❌ No results found")
                return []

            print(f"✅ Found {len(id_list)} results")

            # Fetch detailed information; always release the HTTP handle.
            fetch_handle = Entrez.efetch(
                db=database,
                id=id_list,
                rettype=rettype,
                retmode="text"
            )
            try:
                records = [self._record_to_dict(record)
                           for record in SeqIO.parse(fetch_handle, "genbank")]
            finally:
                fetch_handle.close()

            print(f"📊 Retrieved {len(records)} detailed records")

            return records

        except Exception as e:
            print(f"❌ Error searching genes: {str(e)}")
            return []

    @staticmethod
    def _record_to_dict(record) -> Dict:
        """Flatten a parsed SeqRecord into the summary dict returned by search_genes."""
        annotations = getattr(record, 'annotations', {}) or {}
        # `or` also covers an empty accessions list, which would otherwise
        # raise IndexError on [0].
        accessions = annotations.get('accessions') or ['Unknown']
        return {
            'id': record.id,
            'description': record.description,
            'length': len(record.seq),
            'sequence': str(record.seq),
            'organism': annotations.get('organism', 'Unknown'),
            'accession': accessions[0],
        }

    def run_blast(self, sequence: str, program: str = "blastn",
                  database: str = "nt", expect_threshold: float = 0.001,
                  max_hits: int = 50) -> Optional[Dict]:
        """
        Run a remote NCBI BLAST search with retry handling.

        Args:
            sequence (str): DNA/protein sequence to BLAST
            program (str): BLAST program ('blastn', 'blastp', 'blastx', etc.)
            database (str): BLAST database ('nt', 'nr', 'refseq_rna', etc.)
            expect_threshold (float): E-value threshold
            max_hits (int): Maximum number of hits to return

        Returns:
            Dict: Parsed results of the first BLAST record (see
            _parse_blast_results), or None if BLAST returned nothing or every
            attempt failed. Failed attempts are retried up to
            ``self.max_retries`` times with a linearly increasing wait.
        """
        print(f"🚀 Running {program.upper()} against {database} database...")
        print(f"📏 Sequence length: {len(sequence)} bases/residues")

        for attempt in range(self.max_retries):
            try:
                print(f"🔄 Attempt {attempt + 1}/{self.max_retries}")

                blast_records, elapsed_time = self._submit_blast(
                    sequence, program, database, expect_threshold, max_hits
                )

                if blast_records:
                    print(f"✅ BLAST completed in {elapsed_time:.2f} seconds")
                    return self._parse_blast_results(blast_records[0])
                print("❌ No BLAST results returned")
                return None

            except Exception as e:
                print(f"❌ BLAST attempt {attempt + 1} failed: {str(e)}")
                if attempt < self.max_retries - 1:
                    wait_time = (attempt + 1) * 30  # Linear backoff: 30 s, 60 s, ...
                    print(f"⏳ Waiting {wait_time} seconds before retry...")
                    time.sleep(wait_time)
                else:
                    print("❌ All BLAST attempts failed")
                    return None

        # Only reached when max_retries <= 0.
        return None

    def _submit_blast(self, sequence: str, program: str, database: str,
                      expect_threshold: float, max_hits: int) -> Tuple[list, float]:
        """
        Submit one qblast request and parse every XML record it returns.

        The result handle is always closed, even if parsing raises.

        Returns:
            (records, elapsed_seconds), where the elapsed time covers the
            qblast call as well as XML parsing.
        """
        start_time = time.time()
        result_handle = NCBIWWW.qblast(
            program=program,
            database=database,
            sequence=sequence,
            expect=expect_threshold,
            hitlist_size=max_hits,
            format_type="XML"
        )
        try:
            blast_records = list(NCBIXML.parse(result_handle))
        finally:
            result_handle.close()
        return blast_records, time.time() - start_time

    def _parse_blast_results(self, blast_record) -> Dict:
        """
        Parse one Biopython BLAST XML record into a plain dict.

        Every HSP of every alignment becomes one entry in 'alignments', sorted
        by ascending E-value. 'statistics.total_hits' counts alignments
        (subject sequences), not HSPs.
        """
        query_length = blast_record.query_length
        results = {
            'query_id': blast_record.query,
            'query_length': query_length,
            'database': blast_record.database,
            'alignments': [],
            'statistics': {
                'total_hits': len(blast_record.alignments)
            }
        }

        for alignment in blast_record.alignments:
            for hsp in alignment.hsps:
                # Inclusive query span. Coordinates can be reversed
                # (start > end), so take abs() of the difference before adding 1.
                query_span = abs(hsp.query_end - hsp.query_start) + 1
                align_data = {
                    'title': alignment.title,
                    'accession': alignment.accession,
                    'length': alignment.length,
                    'e_value': hsp.expect,
                    'bit_score': hsp.bits,
                    'score': hsp.score,
                    'identity': hsp.identities,
                    'positives': hsp.positives,
                    'gaps': hsp.gaps,
                    'query_start': hsp.query_start,
                    'query_end': hsp.query_end,
                    'subject_start': hsp.sbjct_start,
                    'subject_end': hsp.sbjct_end,
                    'query_seq': hsp.query,
                    'subject_seq': hsp.sbjct,
                    'match_seq': hsp.match,
                    # Guard the divisions so a degenerate record cannot crash parsing.
                    'identity_percent': ((hsp.identities / hsp.align_length) * 100
                                         if hsp.align_length else 0.0),
                    'coverage_percent': ((query_span / query_length) * 100
                                         if query_length else 0.0)
                }
                results['alignments'].append(align_data)

        # Sort by E-value (best hits first)
        results['alignments'].sort(key=lambda x: x['e_value'])

        return results

    @classmethod
    def _neg_log10_evalue(cls, e_value: float) -> float:
        """Return -log10(E-value), clipping values below EVALUE_FLOOR (including 0.0) to the floor."""
        return float(-np.log10(max(e_value, cls.EVALUE_FLOOR)))

    def visualize_blast_results(self, blast_results: Dict, top_n: int = 20,
                                output_path: Optional[str] = "Ch16-bonus-agent.png"):
        """
        Create a 2x2 panel of visualizations of BLAST results and print a summary.

        Args:
            blast_results (Dict): Parsed BLAST results
            top_n (int): Number of top hits to visualize
            output_path (str, optional): File the figure is saved to before it
                is shown; pass None to skip saving.
        """
        if not blast_results or not blast_results['alignments']:
            print("❌ No BLAST results to visualize")
            return

        alignments = blast_results['alignments'][:top_n]

        # Create figure with subplots
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        fig.suptitle(f'BLAST Analysis Results - Query: {blast_results["query_id"]}',
                     fontsize=16, fontweight='bold')

        # -log10(E-value) for every plotted hit. Zero E-values are clipped, not
        # dropped, so this list stays index-aligned with the other per-hit
        # lists; it is reused as the scatter colour array below.
        log_e_values = [self._neg_log10_evalue(align['e_value']) for align in alignments]

        # 1. E-value distribution
        axes[0, 0].hist(log_e_values, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
        axes[0, 0].set_xlabel('-log10(E-value)')
        axes[0, 0].set_ylabel('Frequency')
        axes[0, 0].set_title('E-value Distribution')
        axes[0, 0].grid(True, alpha=0.3)

        # 2. Bit Score vs Identity
        bit_scores = [align['bit_score'] for align in alignments]
        identities = [align['identity_percent'] for align in alignments]

        scatter = axes[0, 1].scatter(identities, bit_scores, alpha=0.6, c=log_e_values,
                                     cmap='viridis', s=60)
        axes[0, 1].set_xlabel('Identity (%)')
        axes[0, 1].set_ylabel('Bit Score')
        axes[0, 1].set_title('Bit Score vs Identity')
        axes[0, 1].grid(True, alpha=0.3)
        fig.colorbar(scatter, ax=axes[0, 1], label='-log10(E-value)')

        # 3. Coverage vs Identity
        coverage = [align['coverage_percent'] for align in alignments]

        axes[1, 0].scatter(coverage, identities, alpha=0.6, c='coral', s=60)
        axes[1, 0].set_xlabel('Query Coverage (%)')
        axes[1, 0].set_ylabel('Identity (%)')
        axes[1, 0].set_title('Coverage vs Identity')
        axes[1, 0].grid(True, alpha=0.3)

        # 4. Top hits summary (horizontal bar chart)
        top_10 = alignments[:10]
        titles = [align['title'][:50] + '...' if len(align['title']) > 50
                  else align['title'] for align in top_10]
        scores = [align['bit_score'] for align in top_10]

        y_pos = range(len(titles))
        bars = axes[1, 1].barh(y_pos, scores, alpha=0.7, color='lightgreen')
        axes[1, 1].set_yticks(y_pos)
        axes[1, 1].set_yticklabels(titles, fontsize=8)
        axes[1, 1].invert_yaxis()  # list the best (first) hit at the top
        axes[1, 1].set_xlabel('Bit Score')
        axes[1, 1].set_title('Top 10 Hits by Bit Score')
        axes[1, 1].grid(True, alpha=0.3, axis='x')

        # Add score labels on bars
        for bar, score in zip(bars, scores):
            width = bar.get_width()
            axes[1, 1].text(width + max(scores) * 0.01, bar.get_y() + bar.get_height() / 2,
                            f'{score:.1f}', ha='left', va='center', fontsize=8)

        # Lay out first so the saved file matches the displayed figure.
        fig.tight_layout()
        if output_path:
            fig.savefig(output_path)
        plt.show()

        # Summary statistics
        print("\n📊 BLAST Results Summary:")
        print(f"Total alignments: {len(blast_results['alignments'])}")
        print(f"Best E-value: {min(align['e_value'] for align in alignments):.2e}")
        print(f"Best bit score: {max(align['bit_score'] for align in alignments):.1f}")
        print(f"Average identity: {np.mean([align['identity_percent'] for align in alignments]):.1f}%")
        print(f"Average coverage: {np.mean([align['coverage_percent'] for align in alignments]):.1f}%")

    def create_alignment_visualization(self, blast_results: Dict, alignment_index: int = 0):
        """
        Print a detailed, block-wrapped view of one alignment.

        Args:
            blast_results (Dict): Parsed BLAST results
            alignment_index (int): Zero-based index of the alignment to show;
                negative indices are rejected as out of range.
        """
        if not blast_results or not blast_results['alignments']:
            print("❌ No BLAST results to visualize")
            return

        if not 0 <= alignment_index < len(blast_results['alignments']):
            print(f"❌ Alignment index {alignment_index} out of range")
            return

        alignment = blast_results['alignments'][alignment_index]

        print(f"\n🔍 Detailed Alignment View - Hit #{alignment_index + 1}")
        print(f"Title: {alignment['title']}")
        print(f"Accession: {alignment['accession']}")
        print(f"E-value: {alignment['e_value']:.2e}")
        print(f"Bit Score: {alignment['bit_score']:.1f}")
        print(f"Identity: {alignment['identity']}/{len(alignment['query_seq'])} ({alignment['identity_percent']:.1f}%)")
        print(f"Coverage: {alignment['coverage_percent']:.1f}%")

        query_seq = alignment['query_seq']
        match_seq = alignment['match_seq']
        subject_seq = alignment['subject_seq']

        print(f"\n🧬 Sequence Alignment:")
        print(f"Query Start: {alignment['query_start']}")

        # Query coordinate of the first residue in each block. It advances only
        # over non-gap characters, in the direction given by start/end.
        # NOTE(review): for translated searches (e.g. blastx) one residue spans
        # three query nucleotides; this counter does not account for that.
        step = -1 if alignment['query_end'] < alignment['query_start'] else 1
        pos = alignment['query_start']

        # Print alignment in blocks of 60 characters. Row labels are padded to
        # the same width ("Position" + 6-digit number = 15 columns) so the three
        # sequence rows line up column-for-column.
        block_size = 60
        for i in range(0, len(query_seq), block_size):
            end = min(i + block_size, len(query_seq))
            query_chunk = query_seq[i:end]

            print(f"\nPosition {pos:>6}: {query_chunk}")
            print(f"{'Match':<15}: {match_seq[i:end]}")
            print(f"{'Subject':<15}: {subject_seq[i:end]}")

            pos += step * (len(query_chunk) - query_chunk.count('-'))

    def export_results(self, blast_results: Dict, filename: Optional[str] = None):
        """
        Export BLAST results to CSV and FASTA files.

        Args:
            blast_results (Dict): Parsed BLAST results
            filename (str, optional): Base filename for exports; defaults to a
                timestamped ``blast_results_YYYYmmdd_HHMMSS``.

        Returns:
            Tuple[str, str]: (csv_path, fasta_path), or None when there was
            nothing to export.
        """
        if not blast_results:
            print("❌ No BLAST results to export")
            return None

        if not filename:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"blast_results_{timestamp}"

        # Export summary to CSV (one row per HSP)
        alignments_df = pd.DataFrame(blast_results['alignments'])
        csv_file = f"{filename}.csv"
        alignments_df.to_csv(csv_file, index=False)
        print(f"📄 Results exported to {csv_file}")

        # Export the aligned subject fragments to FASTA. They may contain '-'
        # gap characters; the Hit_<n> prefix keeps headers unique when one
        # accession has several HSPs.
        fasta_file = f"{filename}.fasta"
        with open(fasta_file, 'w') as f:
            for i, align in enumerate(blast_results['alignments'], start=1):
                f.write(f">Hit_{i}_{align['accession']}\n")
                f.write(f"{align['subject_seq']}\n")
        print(f"🧬 Sequences exported to {fasta_file}")

        return csv_file, fasta_file

# Import numpy for calculations
import numpy as np

# Example usage and demonstration
def demo_gene_blast_agent(email: Optional[str] = None, api_key: Optional[str] = None):
    """
    Demonstrate the Gene BLAST Agent end to end.

    Performs live NCBI Entrez and remote BLAST requests, then visualizes and
    exports the BLAST results.

    Args:
        email: Contact email for NCBI. Defaults to the NCBI_EMAIL environment
            variable, then to a placeholder you should replace.
        api_key: Optional NCBI API key. Defaults to the NCBI_API_KEY
            environment variable (unset/empty means no key).
    """
    print("🚀 Gene BLAST Agent Demo")
    print("=" * 50)

    # Resolve credentials from arguments or the environment rather than
    # hard-coding them in the notebook.
    placeholder_email = "your_email@example.com"
    email = email or os.environ.get("NCBI_EMAIL") or placeholder_email
    api_key = api_key or os.environ.get("NCBI_API_KEY") or None
    if email == placeholder_email:
        print("⚠️  Using placeholder email; set NCBI_EMAIL or pass email=... with your own address")

    agent = GeneBlastAgent(email=email, api_key=api_key)

    # Example 1: Search for insulin genes
    print("\n1️⃣  Searching for insulin genes...")
    genes = agent.search_genes("insulin human", max_results=5)

    if genes:
        print(f"Found {len(genes)} genes:")
        for gene in genes[:3]:
            print(f"  - {gene['id']}: {gene['description'][:100]}...")

    # Example 2: BLAST a sequence
    print("\n2️⃣  Running BLAST analysis...")

    # Example human insulin nucleotide sequence (~200 bp)
    example_sequence = """ATGGCCCTGTGGATGCGCCTCCTGCCCCTGCTGGCGCTGCTGGCCCTCTGGGGACCTGACCCAGCCGCAGCCTTTGTGAACCAACACCTGTGCGGCTCACACCTGGTGGAAGCTCTCTACCTAGTGTGCGGGGAACGAGGCTTCTTCTACACACCCAAGACCCGCCGGGAGGCAGAGGACCTGCAGGGTGAGCCAACTGCCCATTGC"""

    blast_results = agent.run_blast(
        sequence=example_sequence,
        program="blastn",
        database="nt",
        max_hits=20
    )

    if blast_results:
        print("✅ BLAST analysis completed!")

        # Visualize results
        print("\n3️⃣  Creating visualizations...")
        agent.visualize_blast_results(blast_results, top_n=15)

        # Show detailed alignment
        print("\n4️⃣  Detailed alignment view...")
        agent.create_alignment_visualization(blast_results, alignment_index=0)

        # Export results
        print("\n5️⃣  Exporting results...")
        agent.export_results(blast_results, "demo_blast_results")

    print("\n🎉 Demo completed!")


In [None]:
# Run the live demo. It queries NCBI Entrez and remote BLAST, so it needs
# network access and can take several minutes.
# Set RUN_DEMO = False to skip it when re-running the notebook.
RUN_DEMO = True
if RUN_DEMO:
    demo_gene_blast_agent()

In [None]:
## End of Notebook ##