In [None]:
# Ch16-3 Genome Design [OLDER VERSION]

In [None]:
# This is an older version of the notebook, developed before the CPU vs. CUDA issues were fixed.
# It is provided here for reference.

In [None]:
# Updates & Notes
# plotly = 6.3.1
# tqdm = 4.67.1

In [None]:
"""
DNABERT Genome Generator 
=================================================
Interactive genome generation and visualization for Jupyter notebooks.
Features real DNABERT-driven optimization with reliable progress tracking.

Requirements:
pip install transformers torch numpy biopython pandas matplotlib seaborn plotly tqdm einops
pip install kaleido  # Optional: for saving interactive plots as PNG
"""

In [None]:
# Install Packages #
# The leading `!` hands the rest of the line to the system shell, so the packages
# are installed by whichever `pip` that shell resolves. No versions are pinned here.
! pip install transformers torch numpy biopython pandas matplotlib seaborn plotly tqdm

In [None]:
# Install Additional Packages #
# Shell escape (`!`): installs `einops` with the shell's `pip`; no version is pinned.
! pip install einops

In [None]:
## Import Libraries ##
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from typing import List, Dict, Tuple
import re
import random
import warnings
import logging
from tqdm import tqdm
from IPython.display import display, HTML, Markdown, clear_output
from io import StringIO
import sys
import time

In [None]:
# Jupyter notebook configuration: default figure size, base font size and seaborn grid style
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 10,
})
sns.set_style("whitegrid")

# Keep notebook output clean: hide every Python warning and let the
# transformers logger emit errors only
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

In [None]:
## Genome Generation class using DNABERT ##
class DNABERTGenomeGenerator:
    def __init__(self, model_name: str = "zhihan1996/DNABERT-2-117M"):
        """
        Build a generator backed by a pretrained DNABERT checkpoint.

        Loads the tokenizer and model for ``model_name`` (with
        ``trust_remote_code=True``), puts the model in eval mode, probes the
        model's output format, precomputes the per-element reference
        embeddings and finally defines the ``genome_elements`` table.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``.
        """
        banner_rule = "=" * 50
        print("🧬 Initializing DNABERT Genome Generator...")
        print(banner_rule)

        print(f"📥 Loading DNABERT model: {model_name}")

        # Silence warnings and transformers logging only while the weights load;
        # the logger's previous level is restored even if loading fails.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            hf_logger = logging.getLogger("transformers")
            previous_level = hf_logger.level
            hf_logger.setLevel(logging.ERROR)

            try:
                print("  🔤 Loading tokenizer...")
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_name, trust_remote_code=True
                )

                print("  🤖 Loading model...")
                self.model = AutoModel.from_pretrained(
                    model_name, trust_remote_code=True
                )
                self.model.eval()
            finally:
                hf_logger.setLevel(previous_level)

        # Upper bound handed to the tokenizer (and used to pre-trim sequences)
        self.max_length = 512

        print("  🧪 Testing model output format...")
        self._test_model_output()

        print("  📚 Generating reference embeddings...")
        self.reference_embeddings = self._generate_reference_embeddings()

        ready_banner = """
        <div style="border: 3px solid #4CAF50; padding: 15px; border-radius: 10px; background-color: #e8f5e8; margin: 10px 0;">
            <h3 style="color: #2E7D32; margin: 0;">✅ DNABERT Model Ready!</h3>
            <p style="margin: 5px 0; color: #424242;">Ready to generate AI-optimized genomes</p>
        </div>
        """
        display(HTML(ready_banner))

        # Per-element defaults; 'gc_content' is read by the GC-guided mutation strategy
        self.genome_elements = {
            'promoter': {
                'length': 50,
                'consensus': 'TATAAA',
                'gc_content': 0.4,
            },
            'coding_sequence': {
                'length': 600,
                'start_codon': 'ATG',
                'stop_codons': ['TAA', 'TAG', 'TGA'],
                'gc_content': 0.5,
            },
            'terminator': {
                'length': 30,
                'gc_content': 0.6,
            },
            'intergenic': {
                'length': 100,
                'gc_content': 0.45,
            },
        }
    
    def _test_model_output(self):
        """Test model output format with minimal output."""
        test_seq = "ATGC"
        inputs = self.tokenizer(test_seq, return_tensors="pt", max_length=self.max_length, truncation=True)
        
        with torch.no_grad():
            outputs = self.model(**inputs)
            if isinstance(outputs, tuple):
                self.output_format = 'tuple'
            elif hasattr(outputs, 'last_hidden_state'):
                self.output_format = 'standard'
            else:
                self.output_format = 'unknown'
    
    def _generate_reference_embeddings(self) -> Dict[str, np.ndarray]:
        """Generate reference embeddings with simple progress tracking."""
        references = {}
        
        # Strong promoter sequences (from literature)
        strong_promoters = [
            "TTGACAATTAATCATCGGCTCGTATAATGTGTGGAATTGTGAGCGGATAACAATTTCACACAGGAAACAG",  # lac promoter
            "AATTGTGAGCGCTCACAATTCCACACAACATACGAGCCGGAAGCATAAAGTGTAAAGCCTGGGGTGCCTAAT",  # trp promoter
            "TTTACACTTTTATGCTTCCGGCTCGTATGTTGTGTGGAATTGTGAGCGCTCACAATTCCACACAACATACGA",  # tac promoter
        ]
        
        strong_genes = [
            "ATGAAACAACGCATCGTAGCGGCTCTGATCCTCGAGCGTCTGACCCAGTACGAGGCCATGACCAACGAGTAA",
            "ATGGTCAACAAACGCCTGGCGATCTACGACCGTATCAACGAGCTCAACAAACACCTGGAACAGGACAAATAA",
            "ATGCTGGAACAGAAACGTATCCAGGCGATCAACGAGTACCTCAACGAGCGCATCCAGAAACGCCTCAAATAG",
        ]
        
        strong_terminators = [
            "AAAAGCCCGAAAGGAAGCTGAGTTGGCTGCTGCCACCGCTGAGCAATAACTAGCATAACCCCTTGGGGCCTCTAAACGGGTCTT",
            "GGCGGAATTCGGGGGCGAGCGAACGCGTAAGGATTACCCCGGGCGCCGAAACGTAGCGCGACGCCGAAACGACGGCCT",
        ]
        
        print("    📝 Computing promoter embeddings...")
        promoter_embeddings = [self.get_sequence_embedding(seq) for seq in strong_promoters]
        references['promoter'] = np.mean(promoter_embeddings, axis=0)
        
        print("    🧬 Computing coding sequence embeddings...")
        gene_embeddings = [self.get_sequence_embedding(seq) for seq in strong_genes]
        references['coding_sequence'] = np.mean(gene_embeddings, axis=0)
        
        print("    🔚 Computing terminator embeddings...")
        terminator_embeddings = [self.get_sequence_embedding(seq) for seq in strong_terminators]
        references['terminator'] = np.mean(terminator_embeddings, axis=0)
        
        print("    🌐 Computing intergenic reference...")
        intergenic_seqs = [self.generate_random_sequence(100, 0.45) for _ in range(5)]
        intergenic_embeddings = [self.get_sequence_embedding(seq) for seq in intergenic_seqs]
        references['intergenic'] = np.mean(intergenic_embeddings, axis=0)
        
        print("    ✅ Reference embeddings complete!")
        
        return references
    
    def get_sequence_embedding(self, sequence: str) -> np.ndarray:
        """Get embedding representation of DNA sequence."""
        clean_seq = self.preprocess_sequence(sequence)
        
        if len(clean_seq) > self.max_length:
            clean_seq = clean_seq[:self.max_length]
        
        inputs = self.tokenizer(clean_seq, return_tensors="pt", padding=True, 
                               truncation=True, max_length=self.max_length)
        
        with torch.no_grad():
            outputs = self.model(**inputs)
            
            if self.output_format == 'tuple':
                hidden_states = outputs[0]
            elif hasattr(outputs, 'last_hidden_state'):
                hidden_states = outputs.last_hidden_state
            else:
                hidden_states = outputs['last_hidden_state']
            
            embedding = hidden_states.mean(dim=1).squeeze().numpy()
            
        return embedding
    
    def preprocess_sequence(self, sequence: str) -> str:
        """Upper-case the input and strip every character that is not A, T, G or C."""
        uppercased = sequence.upper()
        non_base_pattern = r'[^ATGC]'
        return re.sub(non_base_pattern, '', uppercased)
    
    def generate_random_sequence(self, length: int, gc_content: float = 0.5) -> str:
        """
        Generate a shuffled DNA sequence of exactly ``length`` bases.

        ``int(length * gc_content)`` bases are G or C and the rest are A or T.
        Within each group the two bases are split evenly; when a group has an
        odd size the leftover base is drawn at random from that same group, so
        the G+C count always matches the target.

        Args:
            length: Number of bases to produce; values <= 0 yield ``''``.
            gc_content: Target G+C fraction. Values outside [0, 1] are clamped,
                which keeps the output at the requested length.

        Returns:
            The random sequence as a string.
        """
        if length <= 0:
            return ''

        gc_fraction = min(1.0, max(0.0, gc_content))
        gc_count = int(length * gc_fraction)
        at_count = length - gc_count

        nucleotides = ['G'] * (gc_count // 2) + ['C'] * (gc_count // 2)
        if gc_count % 2:
            # Odd G/C budget: keep the extra base inside the G/C group
            nucleotides.append(random.choice(['G', 'C']))

        nucleotides += ['A'] * (at_count // 2) + ['T'] * (at_count // 2)
        if at_count % 2:
            # Odd A/T budget: keep the extra base inside the A/T group
            nucleotides.append(random.choice(['A', 'T']))

        random.shuffle(nucleotides)
        return ''.join(nucleotides)
    
    def generate_promoter_sequence(self, length: int = 50) -> str:
        """Generate realistic promoter sequence with TATA box."""
        sequence = list(self.generate_random_sequence(length, 0.4))
        tata_pos = length - 30
        tata_box = "TATAAA"
        for i, nucleotide in enumerate(tata_box):
            if tata_pos + i < len(sequence):
                sequence[tata_pos + i] = nucleotide
        return ''.join(sequence)
    
    def generate_coding_sequence(self, length: int = 600) -> str:
        """Generate coding sequence with start codon and proper reading frame."""
        sequence = ["ATG"]
        remaining_length = length - 3
        
        codons = []
        for _ in range(remaining_length // 3):
            codon = self.generate_random_sequence(3, 0.5)
            while codon in ['TAA', 'TAG', 'TGA']:
                codon = self.generate_random_sequence(3, 0.5)
            codons.append(codon)
        
        if codons:
            codons[-1] = random.choice(['TAA', 'TAG', 'TGA'])
        
        sequence.extend(codons)
        
        current_length = len(''.join(sequence))
        if current_length < length:
            sequence.append(self.generate_random_sequence(length - current_length, 0.5))
        
        return ''.join(sequence)
    
    def generate_terminator_sequence(self, length: int = 30) -> str:
        """Generate terminator sequence with hairpin structure potential."""
        return self.generate_random_sequence(length, 0.6)
    
    def optimize_sequence_with_model(self, initial_sequence: str, element_type: str = 'intergenic') -> str:
        """Use DNABERT embeddings for sequence optimization with simple progress tracking."""
        if element_type not in self.reference_embeddings:
            return self._optimize_sequence_basic(initial_sequence)
        
        best_sequence = initial_sequence
        best_embedding = self.get_sequence_embedding(initial_sequence)
        best_score = self._score_sequence_with_model(best_embedding, element_type)
        
        print(f"    🔧 Optimizing {element_type} (initial score: {best_score:.3f})")
        
        improvements = 0
        iteration_scores = []
        
        for iteration in range(25):
            candidate = self._mutate_sequence_guided(best_sequence, element_type)
            candidate_embedding = self.get_sequence_embedding(candidate)
            candidate_score = self._score_sequence_with_model(candidate_embedding, element_type)
            
            iteration_scores.append(candidate_score)
            
            if candidate_score > best_score:
                best_sequence = candidate
                best_embedding = candidate_embedding
                best_score = candidate_score
                improvements += 1
                
                # Print progress every 5 improvements
                if improvements % 5 == 0:
                    print(f"      ↗️  {improvements} improvements, score: {best_score:.3f}")
                
                if best_score > 0.8:
                    break
        
        final_improvement = best_score - self._score_sequence_with_model(
            self.get_sequence_embedding(initial_sequence), element_type
        )
        
        if improvements > 0:
            print(f"    ✅ Optimization complete! Final score: {best_score:.3f} "
                  f"(+{final_improvement:.3f}, {improvements} improvements)")
        else:
            print(f"    ➡️  No improvements found (final score: {best_score:.3f})")
            
        return best_sequence
    
    def _score_sequence_with_model(self, embedding: np.ndarray, element_type: str) -> float:
        """Score sequence based on embedding similarity to known functional sequences."""
        if element_type not in self.reference_embeddings:
            return 0.0
        
        reference_embedding = self.reference_embeddings[element_type]
        model_similarity = np.dot(embedding, reference_embedding) / (
            np.linalg.norm(embedding) * np.linalg.norm(reference_embedding)
        )
        return max(0, (model_similarity + 1) / 2)
    
    def _mutate_sequence_guided(self, sequence: str, element_type: str) -> str:
        """Make targeted mutations guided by DNABERT embedding feedback."""
        candidate = list(sequence)
        num_mutations = max(1, min(3, len(sequence) // 50))
        
        strategies = ['random', 'gc_guided', 'codon_aware']
        strategy = random.choice(strategies)
        
        for _ in range(num_mutations):
            pos = random.randint(0, len(candidate) - 1)
            current_nucleotide = candidate[pos]
            
            if strategy == 'random':
                new_nucleotides = [n for n in ['A', 'T', 'G', 'C'] if n != current_nucleotide]
                candidate[pos] = random.choice(new_nucleotides)
            elif strategy == 'gc_guided':
                target_gc = self.genome_elements.get(element_type, {}).get('gc_content', 0.5)
                current_gc = (sequence.count('G') + sequence.count('C')) / len(sequence)
                if current_gc < target_gc:
                    candidate[pos] = random.choice(['G', 'C'])
                else:
                    candidate[pos] = random.choice(['A', 'T'])
            elif strategy == 'codon_aware' and element_type == 'coding_sequence':
                codon_pos = pos % 3
                if codon_pos == 0 and pos + 2 < len(candidate):
                    codon = ''.join(candidate[pos:pos+3])
                    if codon in ['TAA', 'TAG', 'TGA'] and pos + 3 < len(candidate) - 3:
                        candidate[pos] = random.choice(['A', 'C', 'G'])
                    else:
                        new_nucleotides = [n for n in ['A', 'T', 'G', 'C'] if n != current_nucleotide]
                        candidate[pos] = random.choice(new_nucleotides)
                else:
                    new_nucleotides = [n for n in ['A', 'T', 'G', 'C'] if n != current_nucleotide]
                    candidate[pos] = random.choice(new_nucleotides)
        
        return ''.join(candidate)

    # Note - this is a fallback function and will not normally be used unless you supply a reference embedding that is not recognized
    def _optimize_sequence_basic(self, initial_sequence: str) -> str:
        """Fallback optimization using traditional methods."""
        best_sequence = initial_sequence
        best_score = 0
        
        for _ in range(10):
            sequence = list(best_sequence)
            num_changes = max(1, len(sequence) // 50)
            
            for _ in range(num_changes):
                pos = random.randint(0, len(sequence) - 1)
                sequence[pos] = random.choice(['A', 'T', 'G', 'C'])
            
            candidate = ''.join(sequence)
            gc_content = (candidate.count('G') + candidate.count('C')) / len(candidate)
            gc_score = 1 - abs(gc_content - 0.5)
            
            if gc_score > best_score:
                best_sequence = candidate
                best_score = gc_score
        
        return best_sequence
    
    def generate_2kb_genome(self) -> Dict[str, any]:
        """Generate a functional 2kb genome with DNABERT optimization and progress tracking."""
        display(HTML("""
        <div style="border: 2px solid #2196F3; padding: 15px; border-radius: 10px; background-color: #e3f2fd; margin: 10px 0;">
            <h3 style="color: #1976D2; margin: 0;">🧬 Generating 2kb Genome with DNABERT Optimization</h3>
        </div>
        """))
        
        genome_structure = []
        total_length = 0
        target_length = 2000
        
        elements = [
            ('promoter', 80), ('coding_sequence', 600), ('intergenic', 120),
            ('promoter', 70), ('coding_sequence', 500), ('terminator', 50),
            ('intergenic', 150), ('promoter', 60), ('coding_sequence', 400),
            ('terminator', 40)
        ]
        
        genome_sequence = ""
        
        print(f"📋 Planned elements: {len(elements)} functional regions")
        print("=" * 60)
        
        for i, (element_type, length) in enumerate(elements, 1):
            if total_length + length > target_length:
                length = target_length - total_length
                if length <= 0:
                    break
            
            print(f"🔧 [{i}/{len(elements)}] Generating {element_type} ({length} bp)")
            
            if element_type == 'promoter':
                sequence = self.generate_promoter_sequence(length)
            elif element_type == 'coding_sequence':
                sequence = self.generate_coding_sequence(length)
            elif element_type == 'terminator':
                sequence = self.generate_terminator_sequence(length)
            else:
                sequence = self.generate_random_sequence(length, 0.45)
            
            # DNABERT optimization
            sequence = self.optimize_sequence_with_model(sequence, element_type)
            
            genome_structure.append({
                'type': element_type,
                'start': total_length,
                'end': total_length + length,
                'sequence': sequence
            })
            
            genome_sequence += sequence
            total_length += length
            
            # Progress indicator
            progress_percent = (total_length / target_length) * 100
            print(f"    📊 Progress: {progress_percent:.1f}% ({total_length}/{target_length} bp)")
            print()
            
            if total_length >= target_length:
                break
        
        # Pad to exactly 2kb if needed
        if len(genome_sequence) < target_length:
            padding_length = target_length - len(genome_sequence)
            print(f"🔧 Adding padding ({padding_length} bp)")
            padding = self.generate_random_sequence(padding_length)
            padding = self.optimize_sequence_with_model(padding, 'intergenic')
            genome_sequence += padding
            genome_structure.append({
                'type': 'padding',
                'start': len(genome_sequence) - len(padding),
                'end': len(genome_sequence),
                'sequence': padding
            })
        
        final_genome = genome_sequence[:target_length]
        
        display(HTML(f"""
        <div style="border: 2px solid #4CAF50; padding: 15px; border-radius: 10px; background-color: #e8f5e8; margin: 10px 0;">
            <h3 style="color: #2E7D32; margin: 0;">✅ Genome Generation Complete!</h3>
            <p style="margin: 5px 0; color: #424242;">Generated {len(final_genome)} bp genome with {len(genome_structure)} functional elements</p>
        </div>
        """))
        
        return {
            'sequence': final_genome,
            'structure': genome_structure,
            'length': len(final_genome)
        }
    
    def analyze_genome(self, genome: Dict) -> Dict:
        """Comprehensive genome analysis with progress tracking."""
        sequence = genome['sequence']
        
        display(HTML("""
        <div style="border: 2px solid #9C27B0; padding: 15px; border-radius: 10px; background-color: #f3e5f5; margin: 10px 0;">
            <h3 style="color: #7B1FA2; margin: 0;">🔬 Analyzing Genome with DNABERT</h3>
        </div>
        """))
        
        # Basic statistics
        length = len(sequence)
        gc_content = (sequence.count('G') + sequence.count('C')) / length
        
        composition = {
            'A': sequence.count('A') / length,
            'T': sequence.count('T') / length,
            'G': sequence.count('G') / length,
            'C': sequence.count('C') / length
        }
        
        print("📊 Computing basic statistics...")
        print(f"   Length: {length:,} bp")
        print(f"   GC Content: {gc_content:.1%}")
        
        # GC content sliding window
        print("📈 Computing GC content windows...")
        window_size = 50
        gc_windows = []
        windows = range(0, length - window_size + 1, 10)
        
        for i, pos in enumerate(windows):
            if i % 50 == 0:  # Progress update every 50 windows
                print(f"   Processing window {i+1}/{len(windows)}")
            window = sequence[pos:pos + window_size]
            gc = (window.count('G') + window.count('C')) / window_size
            gc_windows.append({'position': pos + window_size // 2, 'gc_content': gc})
        
        # Find ORFs
        print("🔍 Finding Open Reading Frames...")
        orfs = self.find_orfs(sequence)
        print(f"   Found {len(orfs)} ORFs")
        
        # Get sequence embedding
        print("🧠 Computing DNABERT embedding...")
        embedding = self.get_sequence_embedding(sequence)
        
        # DNABERT-based element quality analysis
        print("⭐ Analyzing element quality with DNABERT...")
        element_quality_scores = self._analyze_element_quality(genome['structure'])
        
        print("✅ Analysis complete!")
        
        return {
            'length': length,
            'gc_content': gc_content,
            'composition': composition,
            'gc_windows': gc_windows,
            'orfs': orfs,
            'embedding': embedding,
            'structure': genome['structure'],
            'element_quality_scores': element_quality_scores
        }
    
    def _analyze_element_quality(self, structure: List[Dict]) -> Dict:
        """Analyze element quality with DNABERT embeddings and progress tracking."""
        quality_scores = {}
        element_type_scores = {}
        
        print(f"   Analyzing {len(structure)} elements...")
        
        for i, element in enumerate(structure, 1):
            element_type = element['type']
            sequence = element['sequence']
            
            if i % 3 == 0:  # Progress update every 3 elements
                print(f"   Processing element {i}/{len(structure)}")
            
            if element_type in self.reference_embeddings:
                element_embedding = self.get_sequence_embedding(sequence)
                quality_score = self._score_sequence_with_model(element_embedding, element_type)
                
                element_id = f"{element_type}_{element['start']}-{element['end']}"
                quality_scores[element_id] = {
                    'type': element_type,
                    'score': quality_score,
                    'start': element['start'],
                    'end': element['end'],
                    'length': element['end'] - element['start']
                }
                
                if element_type not in element_type_scores:
                    element_type_scores[element_type] = []
                element_type_scores[element_type].append(quality_score)
        
        # Calculate averages
        average_scores = {}
        for element_type, scores in element_type_scores.items():
            average_scores[element_type] = {
                'average': np.mean(scores),
                'std': np.std(scores),
                'count': len(scores),
                'min': np.min(scores),
                'max': np.max(scores)
            }
        
        return {
            'individual_scores': quality_scores,
            'average_by_type': average_scores,
            'overall_average': np.mean([score['score'] for score in quality_scores.values()]) if quality_scores else 0.0
        }
    
    def find_orfs(self, sequence: str) -> List[Dict]:
        """
        Find open reading frames on the given strand in all three frames.

        Each frame is scanned codon by codon. An ORF starts at ``ATG`` and ends
        at the first in-frame stop codon (``TAA``, ``TAG`` or ``TGA``). After a
        stop codon the scan continues in the same frame, so every ORF in the
        frame is reported. A start codon with no in-frame stop yields no ORF.
        The reverse complement is not searched.

        Args:
            sequence: DNA string to scan.

        Returns:
            List of dicts with 'start', 'end' (exclusive, includes the stop
            codon), 'length', 'frame' and 'sequence', ordered by frame and then
            by position.
        """
        start_codon = 'ATG'
        stop_codons = ('TAA', 'TAG', 'TGA')
        last_codon_start = len(sequence) - 3
        orfs = []

        for frame in range(3):
            i = frame
            while i <= last_codon_start:
                if sequence[i:i + 3] != start_codon:
                    i += 3
                    continue

                start_pos = i
                i += 3

                # Walk forward to the first in-frame stop codon, if any
                while i <= last_codon_start:
                    if sequence[i:i + 3] in stop_codons:
                        end_pos = i + 3
                        orfs.append({
                            'start': start_pos,
                            'end': end_pos,
                            'length': end_pos - start_pos,
                            'frame': frame,
                            'sequence': sequence[start_pos:end_pos]
                        })
                        # Resume scanning right after the stop codon
                        i = end_pos
                        break
                    i += 3

        return orfs
    
    def display_genome_summary(self, analysis: Dict):
        """Display a comprehensive genome summary in notebook format."""
        quality_scores = analysis['element_quality_scores']
        
        # Create summary HTML
        html_summary = f"""
        <div style="border: 2px solid #4CAF50; padding: 20px; border-radius: 10px; background-color: #f9f9f9;">
            <h2 style="color: #4CAF50; margin-top: 0;">🧬 DNABERT-Optimized Genome Summary</h2>
            
            <div style="display: flex; justify-content: space-between; margin: 20px 0;">
                <div style="flex: 1; margin-right: 20px;">
                    <h3>📊 Basic Statistics</h3>
                    <ul>
                        <li><strong>Length:</strong> {analysis['length']:,} bp</li>
                        <li><strong>GC Content:</strong> {analysis['gc_content']:.1%}</li>
                        <li><strong>Elements:</strong> {len(analysis['structure'])}</li>
                        <li><strong>ORFs Found:</strong> {len(analysis['orfs'])}</li>
                    </ul>
                </div>
                
                <div style="flex: 1;">
                    <h3>🎯 DNABERT Quality Scores</h3>
                    <ul>
                        <li><strong>Overall Quality:</strong> <span style="color: {'green' if quality_scores['overall_average'] > 0.7 else 'orange' if quality_scores['overall_average'] > 0.5 else 'red'};">{quality_scores['overall_average']:.3f}</span></li>
        """
        
        for element_type, stats in quality_scores['average_by_type'].items():
            color = 'green' if stats['average'] > 0.7 else 'orange' if stats['average'] > 0.5 else 'red'
            html_summary += f'<li><strong>{element_type.replace("_", " ").title()}:</strong> <span style="color: {color};">{stats["average"]:.3f}</span></li>'
        
        html_summary += """
                    </ul>
                </div>
            </div>
            
            <div style="margin-top: 20px;">
                <h3>🧬 Nucleotide Composition</h3>
                <div style="display: flex; justify-content: space-around;">
        """
        
        for nucleotide, fraction in analysis['composition'].items():
            count = int(fraction * analysis['length'])
            html_summary += f'<div style="text-align: center;"><strong>{nucleotide}</strong><br>{count:,} ({fraction:.1%})</div>'
        
        html_summary += """
                </div>
            </div>
        </div>
        """
        
        display(HTML(html_summary))
        
        # Create quality scores DataFrame for better display
        if quality_scores['individual_scores']:
            scores_data = []
            for element_id, score_info in quality_scores['individual_scores'].items():
                scores_data.append({
                    'Element': element_id,
                    'Type': score_info['type'].replace('_', ' ').title(),
                    'Start': score_info['start'],
                    'End': score_info['end'],
                    'Length': score_info['length'],
                    'Quality Score': f"{score_info['score']:.3f}"
                })
            
            scores_df = pd.DataFrame(scores_data)
            display(HTML("<h3>📋 Individual Element Quality Scores</h3>"))
            display(scores_df.style.background_gradient(subset=['Quality Score'], cmap='RdYlGn', vmin=0, vmax=1))

In [None]:
# Visualize the Genome in your Notebook #
def _plot_genome_map(ax, structure, colors):
    """Panel 1: draw every element as a horizontal bar from its start to its end."""
    labelled_types = set()
    for element in structure:
        element_type = element['type']
        # Label only the first bar of each type so the legend has one entry per type
        label = element_type if element_type not in labelled_types else ""
        labelled_types.add(element_type)

        ax.barh(0, element['end'] - element['start'],
                left=element['start'], height=0.5,
                color=colors.get(element_type, '#95A5A6'),
                label=label)

    ax.set_xlabel('Position (bp)')
    ax.set_title('Genome Structure Map', fontweight='bold')
    ax.set_ylim(-0.5, 0.5)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')


def _plot_gc_content(ax, gc_windows):
    """Panel 2: sliding-window GC fraction with a dashed reference line at 0.5."""
    if gc_windows:
        positions = [w['position'] for w in gc_windows]
        gc_values = [w['gc_content'] for w in gc_windows]

        ax.plot(positions, gc_values, color='#E74C3C', linewidth=2)
        ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.7)
        ax.fill_between(positions, gc_values, alpha=0.3, color='#E74C3C')

    ax.set_xlabel('Position (bp)')
    ax.set_ylabel('GC Content')
    ax.set_title('GC Content Distribution', fontweight='bold')
    ax.grid(True, alpha=0.3)


def _plot_composition(ax, composition):
    """Panel 3: pie chart of the per-base fractions."""
    nucleotides = list(composition.keys())
    percentages = [composition[n] * 100 for n in nucleotides]
    colors_pie = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']

    ax.pie(percentages, labels=nucleotides, autopct='%1.1f%%',
           colors=colors_pie, startangle=90)
    ax.set_title('Nucleotide Composition', fontweight='bold')


def _plot_quality_by_type(ax, analysis, colors):
    """Panel 4: average quality score per element type, with value labels."""
    if 'element_quality_scores' in analysis:
        quality_data = analysis['element_quality_scores']['average_by_type']
        if quality_data:
            types = list(quality_data.keys())
            scores = [quality_data[t]['average'] for t in types]
            bars = ax.bar(types, scores, color=[colors.get(t, '#95A5A6') for t in types])

            # Add value labels on bars
            for bar, score in zip(bars, scores):
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                        f'{score:.3f}', ha='center', va='bottom')

    ax.set_ylabel('Quality Score')
    ax.set_title('DNABERT Quality Scores by Element Type', fontweight='bold')
    ax.set_ylim(0, 1)
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')


def _plot_orf_lengths(ax, orfs):
    """Panel 5: histogram of ORF lengths, or a placeholder when there are none."""
    if orfs:
        orf_lengths = [orf['length'] for orf in orfs]
        ax.hist(orf_lengths, bins=max(5, len(orf_lengths) // 2),
                color='#9B59B6', alpha=0.7, edgecolor='black')
        ax.set_xlabel('ORF Length (bp)')
        ax.set_ylabel('Count')
        ax.set_title(f'ORF Length Distribution ({len(orfs)} ORFs)', fontweight='bold')
    else:
        ax.text(0.5, 0.5, 'No ORFs found', ha='center', va='center',
                transform=ax.transAxes, fontsize=14)
        ax.set_title('ORF Analysis', fontweight='bold')


def _plot_element_distribution(ax, structure, colors):
    """Panel 6: total base pairs covered by each element type."""
    type_lengths = {}
    for element in structure:
        span = element['end'] - element['start']
        type_lengths[element['type']] = type_lengths.get(element['type'], 0) + span

    type_names = list(type_lengths.keys())
    bars = ax.bar(type_names, list(type_lengths.values()),
                  color=[colors.get(t, '#95A5A6') for t in type_names])

    # Add value labels
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height + 5,
                f'{int(height)}', ha='center', va='bottom')

    ax.set_ylabel('Total Length (bp)')
    ax.set_title('Element Type Distribution', fontweight='bold')
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')


def visualize_genome_notebook(genome_data: Dict, analysis: Dict):
    """
    Draw a 2x3 matplotlib dashboard for an analysed genome and save it as PNG.

    Panels: genome map, GC content, nucleotide composition, quality scores by
    element type, ORF length histogram and element type distribution.

    The 'seaborn-v0_8' style is applied through a style context, so it only
    affects this figure and the notebook's global rcParams are left untouched.

    Args:
        genome_data: Genome dict; not read by this function (everything is
            taken from ``analysis``).
        analysis: Result dict from ``analyze_genome``.

    Returns:
        The filename the static PNG was written to.
    """
    display(HTML("<h3>📊 Genome Visualization Dashboard</h3>"))

    colors = {
        'promoter': '#FF6B6B', 'coding_sequence': '#4ECDC4',
        'terminator': '#45B7D1', 'intergenic': '#96CEB4', 'padding': '#FECA57'
    }
    static_filename = 'Ch16-3-dnabert_genome_analysis_static.png'

    # Scope the style to this figure instead of changing it for the whole session
    with plt.style.context('seaborn-v0_8'):
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle('DNABERT-Optimized 2kb Genome Analysis', fontsize=16, fontweight='bold')

        _plot_genome_map(axes[0, 0], analysis['structure'], colors)
        _plot_gc_content(axes[0, 1], analysis['gc_windows'])
        _plot_composition(axes[0, 2], analysis['composition'])
        _plot_quality_by_type(axes[1, 0], analysis, colors)
        _plot_orf_lengths(axes[1, 1], analysis['orfs'])
        _plot_element_distribution(axes[1, 2], analysis['structure'], colors)

        plt.tight_layout()

        # Save the figure as PNG
        plt.savefig(static_filename, dpi=300, bbox_inches='tight',
                    facecolor='white', edgecolor='none')
        print(f"💾 Static plots saved to: {static_filename}")

        plt.show()

    return static_filename

In [None]:
## Interactive Genome Viewer ## 
def create_interactive_genome_viewer_notebook(genome_data: Dict, analysis: Dict):
    """Build, display, and save an interactive Plotly dashboard for an analyzed genome.

    Panel layout (row, col):
        (1, 1-2)  genome structure map, one filled rectangle per element
        (2, 1)    nucleotide composition pie chart
        (2, 2)    GC content across windows
        (3, 1)    average quality score per element type
        (3, 2)    ORF start position vs. ORF length

    Args:
        genome_data: Generated genome dictionary. Not read by this function; it is
            accepted so the call shape stays ``(genome, analysis)``.
        analysis: Analysis dictionary. Keys read here:
            ``'structure'`` (list of dicts with ``'type'``, ``'start'``, ``'end'``),
            ``'gc_windows'`` (list of dicts with ``'position'``, ``'gc_content'``),
            ``'composition'`` (mapping nucleotide -> fraction; shown as percent),
            ``'orfs'`` (list of dicts with ``'start'``, ``'length'``) and, optionally,
            ``'element_quality_scores'`` (its ``'average_by_type'`` entry maps an
            element type to a dict holding ``'average'``).

    Returns:
        tuple: ``(fig, html_filename)`` -- the Plotly figure and the name of the
        HTML file it was written to.

    Side effects:
        Writes ``Ch16-3-dnabert_genome_analysis_interactive.html`` to the current
        working directory and renders the figure in the notebook. No PNG is
        written by this function.
    """

    display(HTML("<h3>🖱️ Interactive Genome Dashboard</h3>"))

    # make_subplots assigns subplot_titles in row-major order over the NON-empty
    # cells only (the None placeholder next to the colspan cell is skipped), so
    # exactly five titles are needed and they must follow the order of `specs`.
    fig = make_subplots(
        rows=3, cols=2,
        subplot_titles=('Genome Structure Map',                          # (1, 1) spans both columns
                        'Nucleotide Composition', 'GC Content Distribution',  # (2, 1), (2, 2)
                        'Element Quality Scores', 'ORF Analysis'),       # (3, 1), (3, 2)
        specs=[[{"colspan": 2}, None],
               [{"type": "domain"}, {}],
               [{}, {}]],
        vertical_spacing=0.12,
        horizontal_spacing=0.1
    )

    colors = {
        'promoter': '#FF6B6B', 'coding_sequence': '#4ECDC4',
        'terminator': '#45B7D1', 'intergenic': '#96CEB4', 'padding': '#FECA57'
    }
    fallback_color = '#95A5A6'  # used for element types missing from `colors`

    # 1. Genome Structure -- one closed rectangle per element on a 0..1 track
    plotted_types = set()
    for element in analysis['structure']:
        # Only the first element of each type contributes a legend entry.
        show_legend = element['type'] not in plotted_types
        plotted_types.add(element['type'])

        fig.add_trace(
            go.Scatter(
                x=[element['start'], element['end'], element['end'], element['start'], element['start']],
                y=[0, 0, 1, 1, 0],
                fill='toself',
                fillcolor=colors.get(element['type'], fallback_color),
                line=dict(color='black', width=1),
                name=element['type'].replace('_', ' ').title(),
                showlegend=show_legend,
                text=f"Type: {element['type']}<br>Start: {element['start']}<br>End: {element['end']}<br>Length: {element['end']-element['start']} bp",
                hovertemplate='%{text}<extra></extra>',
                mode='lines'
            ),
            row=1, col=1
        )

    # 2. GC Content (row 2, right)
    if analysis['gc_windows']:
        positions = [w['position'] for w in analysis['gc_windows']]
        gc_values = [w['gc_content'] for w in analysis['gc_windows']]

        fig.add_trace(
            go.Scatter(
                x=positions, y=gc_values,
                mode='lines',
                name='GC Content',
                line=dict(color='#E74C3C', width=2),
                showlegend=False
            ),
            row=2, col=2
        )

    # 3. Nucleotide Composition (row 2, left -- the "domain" cell required by go.Pie)
    nucleotides = list(analysis['composition'].keys())
    percentages = [analysis['composition'][n] * 100 for n in nucleotides]
    colors_pie = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']

    fig.add_trace(
        go.Pie(
            labels=nucleotides,
            values=percentages,
            name="Nucleotides",
            marker=dict(colors=colors_pie),
            showlegend=False
        ),
        row=2, col=1
    )

    # 4. Quality Scores (row 3, left) -- skipped when the analysis carries no scores
    if 'element_quality_scores' in analysis:
        quality_data = analysis['element_quality_scores']['average_by_type']
        if quality_data:
            types = [t.replace('_', ' ').title() for t in quality_data.keys()]
            scores = [quality_data[t]['average'] for t in quality_data.keys()]

            fig.add_trace(
                go.Bar(
                    x=types, y=scores,
                    name="Quality Scores",
                    marker=dict(color=[colors.get(t.lower().replace(' ', '_'), fallback_color)
                                       for t in quality_data.keys()]),
                    showlegend=False,
                    text=[f'{s:.3f}' for s in scores],
                    textposition='outside'
                ),
                row=3, col=1
            )

    # 5. ORF Analysis (row 3, right)
    if analysis['orfs']:
        orf_starts = [orf['start'] for orf in analysis['orfs']]
        orf_lengths = [orf['length'] for orf in analysis['orfs']]

        fig.add_trace(
            go.Scatter(
                x=orf_starts, y=orf_lengths,
                mode='markers',
                marker=dict(size=10, color='#9B59B6'),
                name='ORFs',
                text=[f"ORF {i+1}<br>Start: {start}<br>Length: {length}"
                      for i, (start, length) in enumerate(zip(orf_starts, orf_lengths))],
                hovertemplate='%{text}<extra></extra>',
                showlegend=False
            ),
            row=3, col=2
        )

    # Axis labels so each panel can be read on its own. The pie chart lives in a
    # "domain" cell and has no axes.
    fig.update_xaxes(title_text='Position (bp)', row=1, col=1)
    fig.update_xaxes(title_text='Position (bp)', row=2, col=2)
    fig.update_yaxes(title_text='GC Content', row=2, col=2)
    fig.update_xaxes(title_text='Element Type', row=3, col=1)
    fig.update_yaxes(title_text='Average Quality Score', row=3, col=1)
    fig.update_xaxes(title_text='ORF Start (bp)', row=3, col=2)
    fig.update_yaxes(title_text='ORF Length (bp)', row=3, col=2)

    # Update layout
    fig.update_layout(
        height=900,
        title_text="🧬 Interactive DNABERT-Optimized Genome Dashboard",
        title_x=0.5,
        showlegend=True
    )

    # Save as HTML for full interactivity
    html_filename = 'Ch16-3-dnabert_genome_analysis_interactive.html'
    fig.write_html(html_filename)
    print(f"💾 Interactive HTML saved to: {html_filename}")

    fig.show()
    return fig, html_filename

In [None]:
# Complete genome generation workflow
def generate_and_analyze_genome():
    """Run the complete generate -> analyze -> visualize -> export workflow.

    Steps, in order:
        1. Build a ``DNABERTGenomeGenerator`` and generate a genome.
        2. Analyze the genome and display its summary.
        3. Save the sequence as a FASTA file.
        4. Create and save the static (matplotlib) and interactive (Plotly) figures.
        5. Save a plain-text analysis summary.
        6. Display the first 200 bp and the list of files written.

    Returns:
        tuple: ``(generator, genome, analysis, saved_files)`` on success, where
        ``saved_files`` is the list of filenames written. If any step raises,
        the error is displayed (with its traceback) and
        ``(None, None, None, [])`` is returned instead of propagating.
    """
    # Local imports: only this function needs them.
    import html
    import traceback

    display(HTML("""
    <div style="border: 3px solid #2196F3; padding: 20px; border-radius: 15px; background: linear-gradient(45deg, #e3f2fd, #f3e5f5);">
        <h1 style="color: #1976D2; text-align: center; margin: 0;">
            🧬 DNABERT Genome Generator 🤖
        </h1>
        <p style="text-align: center; color: #424242; margin: 10px 0;">
            AI-Powered 2kb Genome Generation with Automatic File Saving
        </p>
    </div>
    """))

    saved_files = []

    try:
        # Initialize generator
        generator = DNABERTGenomeGenerator()

        # Generate genome
        genome = generator.generate_2kb_genome()

        # Analyze genome
        analysis = generator.analyze_genome(genome)

        # Display summary
        generator.display_genome_summary(analysis)

        # Save genome FASTA
        fasta_filename = 'Ch16-3-dnabert_generated_genome.fasta'
        _write_genome_fasta(genome['sequence'], fasta_filename)
        saved_files.append(fasta_filename)
        print(f"💾 Genome FASTA saved to: {fasta_filename}")

        # Create visualizations and save figures
        print("\n📊 Creating and saving visualizations...")

        # Static plots (matplotlib)
        static_file = visualize_genome_notebook(genome, analysis)
        if static_file:
            saved_files.append(static_file)

        # Interactive plots (plotly); the figure object itself is not needed here
        _, interactive_html = create_interactive_genome_viewer_notebook(genome, analysis)
        if interactive_html:
            saved_files.append(interactive_html)

        # Save analysis summary as text file
        summary_filename = 'Ch16-3-dnabert_genome_analysis_summary.txt'
        _write_analysis_summary(analysis, summary_filename)
        saved_files.append(summary_filename)
        print(f"💾 Analysis summary saved to: {summary_filename}")

        # Display first 200bp
        display(HTML("<h4>🧬 First 200bp of Generated Genome:</h4>"))
        display(HTML(f"<code style='background-color: #f5f5f5; padding: 10px; border-radius: 5px; font-family: monospace; word-break: break-all;'>{genome['sequence'][:200]}</code>"))

        # Quality scores are optional in `analysis` (the summary writer guards the
        # same key), so do not let a missing entry fail the run after every file
        # has already been written.
        if 'element_quality_scores' in analysis:
            quality_text = f"{analysis['element_quality_scores']['overall_average']:.3f}/1.0"
        else:
            quality_text = "n/a"

        # Summary of saved files
        file_items = ''.join(f'<li><strong>{filename}</strong></li>' for filename in saved_files)
        display(HTML(f"""
        <div style="border: 2px solid #4CAF50; padding: 15px; border-radius: 10px; background-color: #e8f5e8; margin: 15px 0;">
            <h3 style="color: #2E7D32; margin-top: 0;">📁 Files Saved Successfully!</h3>
            <ul style="margin: 10px 0;">
                {file_items}
            </ul>
            <p style="margin-bottom: 0; color: #424242;">
                <strong>Total files:</strong> {len(saved_files)} | 
                <strong>Genome quality:</strong> {quality_text}
            </p>
        </div>
        """))

        return generator, genome, analysis, saved_files

    except Exception as e:
        # Deliberate best-effort: report the failure in the notebook instead of
        # raising. The message is escaped because exception text can contain
        # '<' / '>' that would otherwise be interpreted as HTML.
        error_text = html.escape(f"{type(e).__name__}: {e}")
        display(HTML(f"<h3 style='color: red;'>❌ Error: {error_text}</h3>"))
        display(HTML("<p>Please check your internet connection and ensure all required packages are installed.</p>"))
        # The generic hint above is often not the real cause; show where it failed.
        traceback.print_exc()
        if saved_files:
            print(f"⚠️ Files written before the failure: {', '.join(saved_files)}")
        return None, None, None, []


def _write_genome_fasta(sequence: str, filename: str, line_width: int = 80) -> None:
    """Write ``sequence`` to ``filename`` as a single-record FASTA file.

    Args:
        sequence: Nucleotide sequence to write.
        filename: Destination path; overwritten if it exists.
        line_width: Maximum number of sequence characters per line.
    """
    with open(filename, 'w', encoding='utf-8') as f:
        # FASTA has no comment syntax: every line after the '>' header is read as
        # sequence data, so the provenance note goes in the header description.
        f.write(">DNABERT_Generated_2kb_Genome Generated using DNABERT-optimized functional elements: "
                "promoters, coding sequences, terminators, and intergenic regions\n")
        for i in range(0, len(sequence), line_width):
            f.write(sequence[i:i + line_width] + '\n')


def _write_analysis_summary(analysis: Dict, filename: str) -> None:
    """Write a plain-text report of ``analysis`` to ``filename``.

    Args:
        analysis: Analysis dictionary. Keys read: ``'length'``, ``'gc_content'``,
            ``'structure'``, ``'orfs'``, ``'composition'`` and, optionally,
            ``'element_quality_scores'``.
        filename: Destination path; overwritten if it exists.
    """
    # Explicit UTF-8: the report contains '±', which a locale-default encoding
    # may be unable to encode.
    with open(filename, 'w', encoding='utf-8') as f:
        f.write("DNABERT-Optimized 2kb Genome Analysis Summary\n")
        f.write("=" * 50 + "\n\n")

        f.write(f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Length: {analysis['length']:,} bp\n")
        f.write(f"GC Content: {analysis['gc_content']:.1%}\n")
        f.write(f"Number of Elements: {len(analysis['structure'])}\n")
        f.write(f"ORFs Found: {len(analysis['orfs'])}\n\n")

        if 'element_quality_scores' in analysis:
            quality_scores = analysis['element_quality_scores']
            f.write(f"Overall DNABERT Quality Score: {quality_scores['overall_average']:.3f}\n\n")

            f.write("Element Type Quality Scores:\n")
            for element_type, stats in quality_scores['average_by_type'].items():
                f.write(f"  {element_type:15}: {stats['average']:.3f} ± {stats['std']:.3f} "
                        f"(n={stats['count']}, range: {stats['min']:.3f}-{stats['max']:.3f})\n")

            f.write("\nIndividual Element Scores:\n")
            for element_id, score_info in quality_scores['individual_scores'].items():
                f.write(f"  {element_id:25}: {score_info['score']:.3f}\n")

        f.write("\nNucleotide Composition:\n")
        for nucleotide, fraction in analysis['composition'].items():
            # round(), not int(): fraction * length is a float product that can
            # land just below the true count (e.g. 579.9999...) and truncate.
            count = int(round(fraction * analysis['length']))
            f.write(f"  {nucleotide}: {count:4d} ({fraction:.1%})\n")

        if analysis['orfs']:
            f.write("\nOpen Reading Frames:\n")
            for i, orf in enumerate(analysis['orfs'], 1):
                f.write(f"  ORF {i}: {orf['start']}-{orf['end']} ({orf['length']} bp, frame {orf['frame']})\n")

In [None]:
# Display instructions for notebook usage
def display_usage_instructions():
    """Render the notebook's usage instructions as a styled HTML panel.

    The panel covers installation, the one-call quick start, the step-by-step
    API, and the files the workflow writes. The file list mirrors what the
    workflow functions in this notebook actually save: a FASTA file, a static
    PNG, an interactive HTML dashboard, and a text summary.

    Returns:
        None. The HTML is displayed as a side effect.
    """

    instructions_html = """
    <div style="border: 2px solid #4CAF50; padding: 20px; border-radius: 10px; background-color: #f1f8e9;">
        <h2 style="color: #2E7D32;">🚀 Jupyter Notebook Usage Instructions</h2>
        
        <h3>📦 Installation</h3>
        <p>Install required packages (no widget extensions needed!):</p>
        <code style="background-color: #e8f5e8; padding: 8px; border-radius: 3px; display: block; margin: 10px 0;">
        pip install transformers torch numpy biopython pandas matplotlib seaborn plotly tqdm
        </code>
        
        <p>For exporting the interactive dashboard as PNG yourself (optional, not done automatically):</p>
        <code style="background-color: #e8f5e8; padding: 8px; border-radius: 3px; display: block; margin: 10px 0;">
        pip install kaleido
        </code>
        
        <h3>🎯 Quick Start</h3>
        <p>Run this single command to generate and analyze a complete 2kb genome with automatic file saving:</p>
        <code style="background-color: #e8f5e8; padding: 8px; border-radius: 3px; display: block; margin: 10px 0;">
        generator, genome, analysis, saved_files = generate_and_analyze_genome()
        </code>
        
        <h3>🔧 Advanced Usage</h3>
        <p>For step-by-step control:</p>
        <ol>
            <li>Initialize: <code>generator = DNABERTGenomeGenerator()</code></li>
            <li>Generate: <code>genome = generator.generate_2kb_genome()</code></li>
            <li>Analyze: <code>analysis = generator.analyze_genome(genome)</code></li>
            <li>Visualize: <code>static_file = visualize_genome_notebook(genome, analysis)</code></li>
            <li>Interactive: <code>fig, html_file = create_interactive_genome_viewer_notebook(genome, analysis)</code></li>
        </ol>
        
        <h3>📁 Automatic File Saving</h3>
        <p>The generator automatically saves these files:</p>
        <ul>
            <li>🧬 <strong>Ch16-3-dnabert_generated_genome.fasta</strong> - Genome sequence in FASTA format</li>
            <li>📊 <strong>Ch16-3-dnabert_genome_analysis_static.png</strong> - Static analysis plots (matplotlib)</li>
            <li>🌐 <strong>Ch16-3-dnabert_genome_analysis_interactive.html</strong> - Fully interactive HTML dashboard</li>
            <li>📋 <strong>Ch16-3-dnabert_genome_analysis_summary.txt</strong> - Complete text summary of results</li>
        </ul>
        
        <h3>✨ Features</h3>
        <ul>
            <li>🤖 <strong>Real DNABERT optimization</strong> using learned genomic patterns</li>
            <li>📊 <strong>Multiple visualization formats</strong> (PNG, HTML, interactive)</li>
            <li>📈 <strong>Simple progress tracking</strong> (no widget dependencies!)</li>
            <li>🎨 <strong>Publication-ready figures</strong> with high-resolution PNG export</li>
            <li>💾 <strong>Complete file export</strong> for downstream analysis</li>
            <li>⚡ <strong>Works in all Jupyter environments</strong> (Lab, Notebook, Colab)</li>
        </ul>
        
        <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; padding: 10px; border-radius: 5px; margin-top: 15px;">
            <strong>💡 File Export Features!</strong><br>
            The static figure is automatically saved as a high-resolution PNG file. The interactive dashboard is saved as HTML (fully interactive); to also export it as PNG, install kaleido and call <code>fig.write_image('Ch16-3-dnabert_genome_analysis_interactive.png')</code>.
        </div>
    </div>
    """
    
    display(HTML(instructions_html))

In [None]:
# Display usage instructions (renders the HTML help panel built by display_usage_instructions)
display_usage_instructions()

In [None]:
###### Run Genome Generation Functions ######

In [None]:
# Build the generator with its default model_name ("zhihan1996/DNABERT-2-117M").
# NOTE(review): presumably this loads the model and may download it on first use; confirm in __init__.
generator = DNABERTGenomeGenerator()

In [None]:
# Generate the genome; later code reads the nucleotide string from genome['sequence'].
genome = generator.generate_2kb_genome()

In [None]:
# Compute the analysis dict consumed by the plotting functions
# (they read keys such as 'structure', 'orfs', 'gc_windows' and 'composition').
analysis = generator.analyze_genome(genome)

In [None]:
# Draw the static matplotlib figure. It is also written to
# 'Ch16-3-dnabert_genome_analysis_static.png', and that filename is returned.
static_file = visualize_genome_notebook(genome, analysis)

In [None]:
# Draw the interactive Plotly dashboard. It is also written to
# 'Ch16-3-dnabert_genome_analysis_interactive.html'; returns (figure, filename).
fig, html_file = create_interactive_genome_viewer_notebook(genome, analysis)

In [None]:
# Run this to execute the entire workflow and create the FASTA file
# NOTE: the workflow builds its own generator and a new genome, so this rebinds
# `generator`, `genome` and `analysis` from the step-by-step cells above and
# overwrites the PNG/HTML files those cells saved.
# If the workflow fails, it returns (None, None, None, []).
generator, genome, analysis, saved_files = generate_and_analyze_genome()

In [None]:
## End of Notebook ##