# A Transformer Model for Language Translation


In [1]:
from IPython.display import HTML

# Copy the entire HTML content and assign it to a variable
html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Evolution of Machine Translation</title>
    <style>
        body {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            margin: 0;
            padding: 20px;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: #333;
        }

        .container {
            max-width: 1400px;
            margin: 0 auto;
        }

        h1 {
            text-align: center;
            color: white;
            font-size: 2.5em;
            margin-bottom: 40px;
            text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
        }

        .timeline {
            display: flex;
            flex-direction: column;
            gap: 40px;
        }

        .era {
            background: white;
            border-radius: 20px;
            padding: 30px;
            box-shadow: 0 10px 30px rgba(0,0,0,0.1);
            position: relative;
            margin-left: 50px;
        }

        .era::before {
            content: '';
            position: absolute;
            left: -50px;
            top: 30px;
            width: 40px;
            height: 40px;
            border-radius: 50%;
            background: linear-gradient(45deg, #667eea, #764ba2);
            color: white;
            display: flex;
            align-items: center;
            justify-content: center;
            font-weight: bold;
            z-index: 2;
        }

        .timeline-line {
            position: absolute;
            left: -30px;
            top: 0;
            bottom: 0;
            width: 4px;
            background: linear-gradient(to bottom, #667eea, #764ba2);
            z-index: 1;
        }

        .era-header {
            display: flex;
            justify-content: space-between;
            align-items: center;
            margin-bottom: 20px;
            padding-bottom: 15px;
            border-bottom: 3px solid #e2e8f0;
        }

        .era-title {
            font-size: 1.8em;
            font-weight: bold;
            color: #2d3748;
        }

        .era-period {
            background: linear-gradient(45deg, #667eea, #764ba2);
            color: white;
            padding: 8px 16px;
            border-radius: 20px;
            font-weight: bold;
        }

        .architecture-diagram {
            background: #f8f9fa;
            border: 2px solid #e2e8f0;
            border-radius: 15px;
            padding: 25px;
            margin: 20px 0;
            text-align: center;
        }

        .example {
            background: #e3f2fd;
            border-left: 4px solid #2196f3;
            padding: 15px;
            margin: 15px 0;
            border-radius: 0 10px 10px 0;
        }

        .pros-cons {
            display: grid;
            grid-template-columns: 1fr 1fr;
            gap: 20px;
            margin: 20px 0;
        }

        .pros {
            background: #e8f5e8;
            border: 2px solid #4caf50;
            border-radius: 10px;
            padding: 15px;
        }

        .cons {
            background: #ffebee;
            border: 2px solid #f44336;
            border-radius: 10px;
            padding: 15px;
        }

        /* SMT Specific */
        .phrase-table {
            display: grid;
            grid-template-columns: repeat(3, 1fr);
            gap: 10px;
            margin: 20px 0;
        }

        .phrase-pair {
            background: #fff3e0;
            border: 2px solid #ff9800;
            border-radius: 8px;
            padding: 10px;
            text-align: center;
            font-size: 14px;
        }

        /* Neural Network Components */
        .nn-component {
            display: inline-block;
            background: linear-gradient(45deg, #ff6b6b, #ee5a52);
            color: white;
            padding: 15px 20px;
            border-radius: 12px;
            margin: 5px;
            font-weight: bold;
            min-width: 80px;
            text-align: center;
        }

        .nn-component.encoder {
            background: linear-gradient(45deg, #4ecdc4, #44a08d);
        }

        .nn-component.decoder {
            background: linear-gradient(45deg, #a8e6cf, #7fcdcd);
        }

        .nn-component.attention {
            background: linear-gradient(45deg, #ffd93d, #ff9f1a);
        }

        .nn-component.transformer {
            background: linear-gradient(45deg, #667eea, #764ba2);
        }

        .arrow {
            display: inline-block;
            margin: 0 10px;
            font-size: 20px;
            color: #4a5568;
        }

        .flow-diagram {
            display: flex;
            align-items: center;
            justify-content: center;
            flex-wrap: wrap;
            margin: 20px 0;
        }

        .attention-matrix {
            display: grid;
            grid-template-columns: repeat(4, 40px);
            gap: 2px;
            margin: 20px auto;
            width: fit-content;
        }

        .attention-cell {
            width: 40px;
            height: 40px;
            border-radius: 4px;
            display: flex;
            align-items: center;
            justify-content: center;
            font-size: 12px;
            color: white;
            font-weight: bold;
        }

        .quality-meter {
            display: flex;
            align-items: center;
            margin: 15px 0;
        }

        .quality-bar {
            width: 200px;
            height: 20px;
            background: #e2e8f0;
            border-radius: 10px;
            margin: 0 15px;
            overflow: hidden;
        }

        .quality-fill {
            height: 100%;
            border-radius: 10px;
            transition: width 0.5s ease;
        }

        .interactive-btn {
            background: linear-gradient(45deg, #667eea, #764ba2);
            color: white;
            border: none;
            padding: 12px 25px;
            border-radius: 25px;
            cursor: pointer;
            font-weight: bold;
            margin: 10px;
            transition: transform 0.2s;
        }

        .interactive-btn:hover {
            transform: translateY(-2px);
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>📈 Evolution of Machine Translation</h1>

        <div class="timeline">
            <div class="timeline-line"></div>

            <!-- Statistical MT Era -->
            <div class="era">
                <div class="era-header">
                    <div class="era-title">🔢 Statistical Machine Translation (SMT)</div>
                    <div class="era-period">1990s - 2014</div>
                </div>

                <div class="architecture-diagram">
                    <h3>Rule-Based + Statistical Approach</h3>
                    <div class="phrase-table">
                        <div class="phrase-pair">
                            <strong>German:</strong><br>"guten Morgen"
                        </div>
                        <div class="phrase-pair">
                            <strong>→</strong><br>Translation Rules
                        </div>
                        <div class="phrase-pair">
                            <strong>English:</strong><br>"good morning"
                        </div>
                        <div class="phrase-pair">
                            <strong>German:</strong><br>"ich bin"
                        </div>
                        <div class="phrase-pair">
                            <strong>→</strong><br>Phrase Table
                        </div>
                        <div class="phrase-pair">
                            <strong>English:</strong><br>"I am"
                        </div>
                        <div class="phrase-pair">
                            <strong>German:</strong><br>"das Haus"
                        </div>
                        <div class="phrase-pair">
                            <strong>→</strong><br>Word Alignment
                        </div>
                        <div class="phrase-pair">
                            <strong>English:</strong><br>"the house"
                        </div>
                    </div>
                </div>

                <div class="example">
                    <strong>How it worked:</strong> Break sentences into phrases, look up each phrase in a massive translation table, then use language models to make the output sound natural.
                </div>

                <div class="quality-meter">
                    <strong>Translation Quality:</strong>
                    <div class="quality-bar">
                        <div class="quality-fill" style="width: 40%; background: linear-gradient(90deg, #ff8787, #fa5252);"></div>
                    </div>
                    <span>40% - Basic but limited</span>
                </div>

                <div class="pros-cons">
                    <div class="pros">
                        <h4>✅ Pros:</h4>
                        • Fast and predictable<br>
                        • Works with limited data<br>
                        • Interpretable rules
                    </div>
                    <div class="cons">
                        <h4>❌ Cons:</h4>
                        • Rigid word-by-word translation<br>
                        • Can't handle context well<br>
                        • Awkward, unnatural output
                    </div>
                </div>
            </div>

            <!-- Early RNN Era -->
            <div class="era">
                <div class="era-header">
                    <div class="era-title">🧠 Early RNN Translation</div>
                    <div class="era-period">2013 - 2014</div>
                </div>

                <div class="architecture-diagram">
                    <h3>Direct RNN Mapping (Failed Approach)</h3>
                    <div class="flow-diagram">
                        <div style="background: #e3f2fd; padding: 10px; border-radius: 8px; margin: 5px;">Ich</div>
                        <div class="arrow">→</div>
                        <div class="nn-component">RNN</div>
                        <div class="arrow">→</div>
                        <div style="background: #e8f5e8; padding: 10px; border-radius: 8px; margin: 5px;">I</div>
                    </div>
                    <div class="flow-diagram">
                        <div style="background: #e3f2fd; padding: 10px; border-radius: 8px; margin: 5px;">liebe</div>
                        <div class="arrow">→</div>
                        <div class="nn-component">RNN</div>
                        <div class="arrow">→</div>
                        <div style="background: #e8f5e8; padding: 10px; border-radius: 8px; margin: 5px;">love</div>
                    </div>
                    <div class="flow-diagram">
                        <div style="background: #e3f2fd; padding: 10px; border-radius: 8px; margin: 5px;">dich</div>
                        <div class="arrow">→</div>
                        <div class="nn-component">RNN</div>
                        <div class="arrow">→</div>
                        <div style="background: #e8f5e8; padding: 10px; border-radius: 8px; margin: 5px;">you</div>
                    </div>
                </div>

                <div class="example">
                    <strong>Problem:</strong> "Ich habe das Buch gelesen" → Direct mapping fails because "gelesen" (read) should come earlier in English: "I have <strong>read</strong> the book"
                </div>

                <div class="quality-meter">
                    <strong>Translation Quality:</strong>
                    <div class="quality-bar">
                        <div class="quality-fill" style="width: 25%; background: linear-gradient(90deg, #e53e3e, #c53030);"></div>
                    </div>
                    <span>25% - Worse than SMT</span>
                </div>

                <div class="pros-cons">
                    <div class="pros">
                        <h4>✅ Pros:</h4>
                        • Neural networks can learn<br>
                        • End-to-end training<br>
                        • Potential for improvement
                    </div>
                    <div class="cons">
                        <h4>❌ Cons:</h4>
                        • Word order problems<br>
                        • Length mismatches<br>
                        • Couldn't capture sentence meaning
                    </div>
                </div>
            </div>

            <!-- Seq2Seq RNN Era -->
            <div class="era">
                <div class="era-header">
                    <div class="era-title">🔄 Seq2Seq RNN/LSTM</div>
                    <div class="era-period">2014 - 2016</div>
                </div>

                <div class="architecture-diagram">
                    <h3>Encoder-Decoder Architecture</h3>
                    <div class="flow-diagram">
                        <div style="background: #e3f2fd; padding: 10px; border-radius: 8px; margin: 5px;">Ich<br>liebe<br>dich</div>
                        <div class="arrow">→</div>
                        <div class="nn-component encoder">ENCODER<br>RNN/LSTM</div>
                        <div class="arrow">→</div>
                        <div style="background: #fff3e0; padding: 15px; border-radius: 8px; margin: 5px; border: 2px solid #ff9800;">
                            <strong>Context Vector</strong><br>
                            <span style="font-size: 12px;">[0.3, -0.1, 0.8, ...]</span>
                        </div>
                        <div class="arrow">→</div>
                        <div class="nn-component decoder">DECODER<br>RNN/LSTM</div>
                        <div class="arrow">→</div>
                        <div style="background: #e8f5e8; padding: 10px; border-radius: 8px; margin: 5px;">I<br>love<br>you</div>
                    </div>
                </div>

                <div class="example">
                    <strong>Breakthrough:</strong> Encoder reads entire German sentence first, creates a "meaning vector", then decoder generates English from that understanding. Order problems solved!
                </div>

                <div class="quality-meter">
                    <strong>Translation Quality:</strong>
                    <div class="quality-bar">
                        <div class="quality-fill" style="width: 65%; background: linear-gradient(90deg, #ffd43b, #fab005);"></div>
                    </div>
                    <span>65% - Major improvement!</span>
                </div>

                <div class="pros-cons">
                    <div class="pros">
                        <h4>✅ Pros:</h4>
                        • Understands full sentence meaning<br>
                        • Handles word order differences<br>
                        • Variable length input/output<br>
                        • Much better fluency
                    </div>
                    <div class="cons">
                        <h4>❌ Cons:</h4>
                        • Context vector bottleneck<br>
                        • Forgets long sentences<br>
                        • Still struggles with very long text
                    </div>
                </div>
            </div>

            <!-- Attention Era -->
            <div class="era">
                <div class="era-header">
                    <div class="era-title">👁️ RNN/LSTM + Attention</div>
                    <div class="era-period">2015 - 2017</div>
                </div>

                <div class="architecture-diagram">
                    <h3>Encoder-Decoder + Attention Mechanism</h3>
                    <div class="flow-diagram" style="flex-direction: column;">
                        <div style="display: flex; align-items: center; margin: 10px 0;">
                            <div style="background: #e3f2fd; padding: 10px; border-radius: 8px; margin: 5px;">Ich liebe dich</div>
                            <div class="arrow">→</div>
                            <div class="nn-component encoder">ENCODER<br>RNN/LSTM</div>
                            <div class="arrow">→</div>
                            <div style="background: #fff3e0; padding: 10px; border-radius: 8px; margin: 5px;">h₁ h₂ h₃<br><span style="font-size: 12px;">All hidden states</span></div>
                        </div>

                        <div class="nn-component attention" style="margin: 20px auto;">
                            ATTENTION MECHANISM<br>
                            <span style="font-size: 12px;">Looks at relevant parts</span>
                        </div>

                        <div class="attention-matrix">
                            <div class="attention-cell" style="background: #4caf50;">0.8</div>
                            <div class="attention-cell" style="background: #8bc34a;">0.1</div>
                            <div class="attention-cell" style="background: #cddc39;">0.1</div>
                            <div style="font-size: 12px; text-align: center; color: #666;">Generating "I"</div>

                            <div class="attention-cell" style="background: #cddc39;">0.1</div>
                            <div class="attention-cell" style="background: #4caf50;">0.8</div>
                            <div class="attention-cell" style="background: #8bc34a;">0.1</div>
                            <div style="font-size: 12px; text-align: center; color: #666;">Generating "love"</div>

                            <div class="attention-cell" style="background: #8bc34a;">0.1</div>
                            <div class="attention-cell" style="background: #cddc39;">0.1</div>
                            <div class="attention-cell" style="background: #4caf50;">0.8</div>
                            <div style="font-size: 12px; text-align: center; color: #666;">Generating "you"</div>
                        </div>

                        <div style="display: flex; align-items: center; margin: 10px 0;">
                            <div class="nn-component decoder">DECODER<br>RNN/LSTM</div>
                            <div class="arrow">→</div>
                            <div style="background: #e8f5e8; padding: 10px; border-radius: 8px; margin: 5px;">I love you</div>
                        </div>
                    </div>
                </div>

                <div class="example">
                    <strong>Innovation:</strong> Decoder can "look back" at specific parts of the input sentence when generating each output word. No more information bottleneck!
                </div>

                <div class="quality-meter">
                    <strong>Translation Quality:</strong>
                    <div class="quality-bar">
                        <div class="quality-fill" style="width: 80%; background: linear-gradient(90deg, #51cf66, #40c057);"></div>
                    </div>
                    <span>80% - Near human-level for short sentences</span>
                </div>

                <div class="pros-cons">
                    <div class="pros">
                        <h4>✅ Pros:</h4>
                        • Solves long sentence problem<br>
                        • Can align input/output words<br>
                        • Much better context understanding<br>
                        • Interpretable attention weights
                    </div>
                    <div class="cons">
                        <h4>❌ Cons:</h4>
                        • Still sequential processing<br>
                        • Slow training and inference<br>
                        • RNN limitations remain
                    </div>
                </div>
            </div>

            <!-- Transformer Era -->
            <div class="era">
                <div class="era-header">
                    <div class="era-title">⚡ Transformers</div>
                    <div class="era-period">2017 - Present</div>
                </div>

                <div class="architecture-diagram">
                    <h3>"Attention is All You Need"</h3>
                    <div class="flow-diagram" style="flex-direction: column;">
                        <div style="display: flex; align-items: center; margin: 10px 0;">
                            <div style="background: #e3f2fd; padding: 10px; border-radius: 8px; margin: 5px;">Ich liebe dich</div>
                            <div class="arrow">→</div>
                            <div class="nn-component transformer">TRANSFORMER<br>ENCODER</div>
                            <div class="arrow">→</div>
                            <div style="background: #fff3e0; padding: 10px; border-radius: 8px; margin: 5px;">Contextual<br>Representations</div>
                        </div>

                        <div class="nn-component attention" style="margin: 20px auto; width: 300px;">
                            MULTI-HEAD SELF-ATTENTION<br>
                            <span style="font-size: 12px;">Parallel processing + full context</span>
                        </div>

                        <div style="display: flex; align-items: center; margin: 10px 0;">
                            <div class="nn-component transformer">TRANSFORMER<br>DECODER</div>
                            <div class="arrow">→</div>
                            <div style="background: #e8f5e8; padding: 10px; border-radius: 8px; margin: 5px;">I love you</div>
                        </div>
                    </div>
                </div>

                <div class="example">
                    <strong>Revolution:</strong> No RNNs at all! Pure attention mechanisms process all words in parallel. Each word can attend to every other word simultaneously.
                </div>

                <div class="quality-meter">
                    <strong>Translation Quality:</strong>
                    <div class="quality-bar">
                        <div class="quality-fill" style="width: 95%; background: linear-gradient(90deg, #667eea, #764ba2);"></div>
                    </div>
                    <span>95% - Human-level for most content</span>
                </div>

                <div class="pros-cons">
                    <div class="pros">
                        <h4>✅ Pros:</h4>
                        • Parallel processing (much faster)<br>
                        • Perfect long-range dependencies<br>
                        • Scales to massive datasets<br>
                        • State-of-the-art results<br>
                        • Transfer learning capability
                    </div>
                    <div class="cons">
                        <h4>❌ Cons:</h4>
                        • Computationally expensive<br>
                        • Requires lots of data<br>
                        • Complex architecture<br>
                        • Memory intensive
                    </div>
                </div>
            </div>
        </div>

        <div style="text-align: center; margin: 40px 0;">
            <button class="interactive-btn" onclick="animateEvolution()">🚀 Animate Evolution</button>
            <button class="interactive-btn" onclick="showComparison()">📊 Compare All Methods</button>
        </div>
    </div>

    <script>
        function animateEvolution() {
            const eras = document.querySelectorAll('.era');
            eras.forEach((era, index) => {
                era.style.opacity = '0.3';
                era.style.transform = 'translateX(-50px)';

                setTimeout(() => {
                    era.style.transition = 'all 0.8s ease';
                    era.style.opacity = '1';
                    era.style.transform = 'translateX(0)';
                }, index * 1000);
            });
        }

        function showComparison() {
            const qualities = [40, 25, 65, 80, 95];
            const labels = ['SMT', 'Early RNN', 'Seq2Seq', 'RNN+Attention', 'Transformers'];
            const colors = [
                'linear-gradient(90deg, #ff8787, #fa5252)',
                'linear-gradient(90deg, #e53e3e, #c53030)',
                'linear-gradient(90deg, #ffd43b, #fab005)',
                'linear-gradient(90deg, #51cf66, #40c057)',
                'linear-gradient(90deg, #667eea, #764ba2)'
            ];

            // Create comparison popup
            const popup = document.createElement('div');
            popup.style.cssText = `
                position: fixed;
                top: 50%;
                left: 50%;
                transform: translate(-50%, -50%);
                background: white;
                padding: 40px;
                border-radius: 20px;
                box-shadow: 0 20px 60px rgba(0,0,0,0.3);
                z-index: 1000;
                max-width: 600px;
                width: 90%;
            `;

            popup.innerHTML = `
                <h2>Translation Quality Comparison</h2>
                ${qualities.map((quality, index) => `
                    <div style="margin: 15px 0; display: flex; align-items: center;">
                        <span style="width: 120px; font-weight: bold;">${labels[index]}:</span>
                        <div style="width: 200px; height: 25px; background: #e2e8f0; border-radius: 12px; overflow: hidden; margin: 0 15px;">
                            <div style="width: ${quality}%; height: 100%; background: ${colors[index]}; border-radius: 12px; transition: width 1s ease;"></div>
                        </div>
                        <span style="font-weight: bold;">${quality}%</span>
                    </div>
                `).join('')}
                <button onclick="this.parentElement.remove()" style="background: #667eea; color: white; border: none; padding: 10px 20px; border-radius: 10px; margin-top: 20px; cursor: pointer;">Close</button>
            `;

            document.body.appendChild(popup);

            // Animate bars
            setTimeout(() => {
                const bars = popup.querySelectorAll('div[style*="width:"]');
                bars.forEach((bar, index) => {
                    setTimeout(() => {
                        bar.style.width = qualities[index] + '%';
                    }, index * 200);
                });
            }, 100);
        }

        // Add initial animation
        setTimeout(() => {
            animateEvolution();
        }, 1000);
    </script>
</body>
</html>
"""

# Display it
HTML(html_content)

In [2]:
from IPython.display import HTML

# Copy the entire HTML content and assign it to a variable
html_content2 = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>ML Model Diagrams</title>
    <style>
        body {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            margin: 0;
            padding: 20px;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: #333;
        }

        .container {
            max-width: 1400px;
            margin: 0 auto;
        }

        .diagram-section {
            background: white;
            border-radius: 15px;
            padding: 30px;
            margin: 30px 0;
            box-shadow: 0 10px 30px rgba(0,0,0,0.1);
        }

        h1 {
            text-align: center;
            color: white;
            font-size: 2.5em;
            margin-bottom: 40px;
            text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
        }

        h2 {
            color: #4a5568;
            border-bottom: 3px solid #667eea;
            padding-bottom: 10px;
            margin-bottom: 25px;
        }

        /* RNN Architecture */
        .rnn-architecture {
            display: flex;
            flex-direction: column;
            align-items: center;
            margin: 40px 0;
        }

        .sequence-input {
            display: flex;
            gap: 20px;
            margin-bottom: 30px;
        }

        .input-word {
            background: #e3f2fd;
            border: 2px solid #2196f3;
            border-radius: 10px;
            padding: 15px 20px;
            font-weight: bold;
            color: #1976d2;
        }

        .rnn-layer {
            display: flex;
            align-items: center;
            gap: 20px;
            margin: 20px 0;
        }

        .rnn-cell {
            width: 100px;
            height: 80px;
            background: linear-gradient(45deg, #ff6b6b, #ee5a52);
            border-radius: 15px;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            color: white;
            font-weight: bold;
            position: relative;
        }

        .hidden-state {
            position: absolute;
            top: -40px;
            background: #fff3e0;
            border: 2px solid #ff9800;
            border-radius: 8px;
            padding: 5px 10px;
            font-size: 12px;
            color: #e65100;
            font-weight: bold;
        }

        .rnn-arrow {
            width: 40px;
            height: 3px;
            background: #4a5568;
            position: relative;
        }

        .rnn-arrow::after {
            content: '';
            position: absolute;
            right: -8px;
            top: -5px;
            width: 0;
            height: 0;
            border-left: 12px solid #4a5568;
            border-top: 6px solid transparent;
            border-bottom: 6px solid transparent;
        }

        .vertical-arrow {
            width: 3px;
            height: 30px;
            background: #4a5568;
            position: relative;
            margin: 10px auto;
        }

        .vertical-arrow::after {
            content: '';
            position: absolute;
            bottom: -8px;
            left: -5px;
            width: 0;
            height: 0;
            border-top: 12px solid #4a5568;
            border-left: 6px solid transparent;
            border-right: 6px solid transparent;
        }

        /* LSTM Architecture */
        .lstm-cell {
            width: 140px;
            height: 120px;
            background: linear-gradient(45deg, #4ecdc4, #44a08d);
            border-radius: 20px;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: space-around;
            color: white;
            font-weight: bold;
            position: relative;
            padding: 10px;
        }

        .lstm-gate {
            background: rgba(255,255,255,0.3);
            border-radius: 8px;
            padding: 5px 8px;
            font-size: 10px;
            margin: 2px 0;
            text-align: center;
        }

        .forget-gate { background: rgba(255,87,87,0.8); }
        .input-gate { background: rgba(72,187,120,0.8); }
        .output-gate { background: rgba(66,153,225,0.8); }

        .cell-state {
            position: absolute;
            top: -50px;
            left: 50%;
            transform: translateX(-50%);
            background: #ffd54f;
            border: 3px solid #ff8f00;
            border-radius: 10px;
            padding: 8px 12px;
            font-size: 12px;
            color: #e65100;
            font-weight: bold;
        }

        /* Beam Search Detailed */
        .beam-container {
            margin: 40px 0;
        }

        .beam-step {
            margin: 30px 0;
            padding: 20px;
            border: 2px solid #e2e8f0;
            border-radius: 15px;
            background: #f8f9fa;
        }

        .beam-header {
            font-weight: bold;
            color: #2d3748;
            margin-bottom: 15px;
            font-size: 18px;
        }

        .candidates-row {
            display: flex;
            gap: 15px;
            margin: 15px 0;
            flex-wrap: wrap;
        }

        .candidate {
            background: linear-gradient(45deg, #667eea, #764ba2);
            color: white;
            padding: 12px 15px;
            border-radius: 12px;
            min-width: 120px;
            text-align: center;
            position: relative;
        }

        .candidate.kept {
            background: linear-gradient(45deg, #48bb78, #38a169);
            border: 3px solid #22543d;
        }

        .candidate.pruned {
            background: linear-gradient(45deg, #e53e3e, #c53030);
            opacity: 0.6;
            text-decoration: line-through;
        }

        .score {
            font-size: 11px;
            opacity: 0.9;
            margin-top: 5px;
        }

        .beam-width {
            background: #fed7d7;
            border: 2px solid #e53e3e;
            border-radius: 10px;
            padding: 10px;
            margin: 15px 0;
            text-align: center;
            font-weight: bold;
            color: #c53030;
        }

        .pruning-explanation {
            background: #bee3f8;
            border: 2px solid #3182ce;
            border-radius: 10px;
            padding: 15px;
            margin: 15px 0;
            color: #2a69ac;
        }

        /* Interactive elements */
        .interactive-btn {
            background: linear-gradient(45deg, #667eea, #764ba2);
            color: white;
            border: none;
            padding: 12px 25px;
            border-radius: 25px;
            cursor: pointer;
            font-weight: bold;
            margin: 10px;
            transition: transform 0.2s;
            font-size: 16px;
        }

        .interactive-btn:hover {
            transform: translateY(-2px);
        }

        /* Transformer Flow (keeping as is) */
        .transformer-flow {
            display: flex;
            flex-direction: column;
            align-items: center;
            gap: 30px;
        }

        .flow-step {
            background: linear-gradient(45deg, #667eea, #764ba2);
            color: white;
            padding: 20px;
            border-radius: 15px;
            text-align: center;
            min-width: 200px;
            position: relative;
        }

        .flow-arrow {
            width: 0;
            height: 0;
            border-left: 15px solid transparent;
            border-right: 15px solid transparent;
            border-top: 20px solid #667eea;
            margin: -5px auto;
        }

        .probability-bar {
            background: #e2e8f0;
            height: 20px;
            border-radius: 10px;
            margin: 10px 0;
            overflow: hidden;
            position: relative;
        }

        .probability-fill {
            height: 100%;
            border-radius: 10px;
            transition: width 0.5s ease;
        }

        .word-example {
            background: #f7fafc;
            border: 2px solid #e2e8f0;
            border-radius: 10px;
            padding: 15px;
            margin: 20px 0;
        }

        .explanation {
            background: #f8f9fa;
            border-left: 4px solid #667eea;
            padding: 15px;
            margin: 20px 0;
            border-radius: 0 10px 10px 0;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>🤖 AI Model Architecture Deep Dive</h1>

        <!-- RNN Architecture Section -->
        <div class="diagram-section">
            <h2>1A. RNN (Recurrent Neural Network) Architecture</h2>
            <div class="explanation">
                <strong>Key Challenge:</strong> Information from early words gets "forgotten" as the sequence progresses. The hidden state gets overwritten at each step.
            </div>

            <div class="rnn-architecture">
                <div class="sequence-input">
                    <div class="input-word">x₁: "The"</div>
                    <div class="input-word">x₂: "cat"</div>
                    <div class="input-word">x₃: "sat"</div>
                </div>

                <div class="vertical-arrow"></div>

                <div class="rnn-layer">
                    <div class="rnn-cell">
                        <div class="hidden-state">h₀: [0.0]</div>
                        <div>RNN</div>
                        <div style="font-size: 12px;">Cell 1</div>
                    </div>
                    <div class="rnn-arrow"></div>
                    <div class="rnn-cell">
                        <div class="hidden-state">h₁: [0.3]</div>
                        <div>RNN</div>
                        <div style="font-size: 12px;">Cell 2</div>
                    </div>
                    <div class="rnn-arrow"></div>
                    <div class="rnn-cell">
                        <div class="hidden-state">h₂: [0.1]</div>
                        <div>RNN</div>
                        <div style="font-size: 12px;">Cell 3</div>
                    </div>
                </div>

                <div class="explanation">
                    <strong>Problem:</strong> h₂ has very little information about "The" (the first word) because it gets diluted through multiple transformations. <br>
                    <strong>Formula:</strong> h₂ = tanh(W·[x₂, h₁]) where h₁ = tanh(W·[x₁, h₀])
                </div>
            </div>
        </div>

        <!-- RNN Cell Internal Section -->
        <div class="diagram-section">
            <h2>1B. Inside an RNN Cell</h2>
            <div class="explanation">
                <strong>What happens inside:</strong> Simple mathematical operations that blend current input with previous memory.
            </div>

            <div style="display: flex; justify-content: center; align-items: center; margin: 40px 0;">
                <div style="background: #f8f9fa; border: 3px solid #667eea; border-radius: 20px; padding: 30px; position: relative; width: 400px;">
                    <!-- Input arrows -->
                    <div style="position: absolute; top: -40px; left: 50px; text-align: center;">
                        <div style="background: #e3f2fd; border: 2px solid #2196f3; border-radius: 8px; padding: 8px 12px; font-weight: bold; color: #1976d2;">x_t</div>
                        <div class="vertical-arrow" style="height: 20px; margin: 5px auto;"></div>
                    </div>
                    <div style="position: absolute; top: -40px; right: 50px; text-align: center;">
                        <div style="background: #fff3e0; border: 2px solid #ff9800; border-radius: 8px; padding: 8px 12px; font-weight: bold; color: #e65100;">h_{t-1}</div>
                        <div class="vertical-arrow" style="height: 20px; margin: 5px auto;"></div>
                    </div>

                    <!-- Internal computation -->
                    <div style="text-align: center; margin: 20px 0;">
                        <div style="background: #e8f5e8; border: 2px solid #4caf50; border-radius: 12px; padding: 15px; margin: 10px 0;">
                            <strong>Step 1: Combine Inputs</strong><br>
                            <code style="background: white; padding: 5px; border-radius: 5px;">combined = W_x * x_t + W_h * h_{t-1} + b</code>
                        </div>
                        <div style="background: #fff3e0; border: 2px solid #ff9800; border-radius: 12px; padding: 15px; margin: 10px 0;">
                            <strong>Step 2: Activation</strong><br>
                            <code style="background: white; padding: 5px; border-radius: 5px;">h_t = tanh(combined)</code>
                        </div>
                    </div>

                    <!-- Output arrow -->
                    <div style="position: absolute; bottom: -40px; left: 50%; transform: translateX(-50%); text-align: center;">
                        <div class="vertical-arrow" style="height: 20px; margin: 5px auto; transform: rotate(180deg);"></div>
                        <div style="background: #fff3e0; border: 2px solid #ff9800; border-radius: 8px; padding: 8px 12px; font-weight: bold; color: #e65100;">h_t</div>
                    </div>
                </div>
            </div>

            <div class="explanation">
                <strong>Problem:</strong> The same hidden state h_t serves as both memory and output. As sequences get longer, early information gets "overwritten" by newer information.
            </div>
        </div>

        <!-- LSTM Cell Internal Section -->
        <div class="diagram-section">
            <h2>1C. Inside an LSTM Cell</h2>
            <div class="explanation">
                <strong>Key Innovation:</strong> Separate cell state (C_t) for long-term memory and hidden state (h_t) for output. Gates control information flow.
            </div>

            <div style="display: flex; justify-content: center; align-items: center; margin: 40px 0;">
                <div style="background: #f8f9fa; border: 3px solid #4ecdc4; border-radius: 20px; padding: 30px; position: relative; width: 500px; height: 400px;">
                    <!-- Inputs -->
                    <div style="position: absolute; top: -40px; left: 100px; text-align: center;">
                        <div style="background: #e3f2fd; border: 2px solid #2196f3; border-radius: 8px; padding: 8px 12px; font-weight: bold; color: #1976d2;">x_t</div>
                        <div class="vertical-arrow" style="height: 20px; margin: 5px auto;"></div>
                    </div>
                    <div style="position: absolute; top: -40px; right: 100px; text-align: center;">
                        <div style="background: #fff3e0; border: 2px solid #ff9800; border-radius: 8px; padding: 8px 12px; font-weight: bold; color: #e65100;">h_{t-1}</div>
                        <div class="vertical-arrow" style="height: 20px; margin: 5px auto;"></div>
                    </div>

                    <!-- Cell state flow (top) -->
                    <div style="position: absolute; top: 10px; left: 20px; right: 20px; display: flex; align-items: center;">
                        <div style="background: #ffd54f; border: 2px solid #ff8f00; border-radius: 8px; padding: 5px 10px; font-size: 12px; font-weight: bold;">C_{t-1}</div>
                        <div style="flex: 1; height: 3px; background: #ff8f00; margin: 0 10px; position: relative;">
                            <div style="position: absolute; right: -5px; top: -5px; width: 0; height: 0; border-left: 10px solid #ff8f00; border-top: 5px solid transparent; border-bottom: 5px solid transparent;"></div>
                        </div>
                        <div style="background: #ffd54f; border: 2px solid #ff8f00; border-radius: 8px; padding: 5px 10px; font-size: 12px; font-weight: bold;">C_t</div>
                    </div>

                    <!-- Gates (middle section) -->
                    <div style="margin-top: 60px; display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 15px; text-align: center;">
                        <!-- Forget Gate -->
                        <div style="background: rgba(255,87,87,0.8); border-radius: 12px; padding: 10px; color: white;">
                            <div style="font-weight: bold; font-size: 14px;">Forget Gate</div>
                            <div style="font-size: 10px; margin: 5px 0;">f_t = σ(W_f[h_{t-1}, x_t])</div>
                            <div style="font-size: 12px;">What to forget?</div>
                        </div>

                        <!-- Input Gate -->
                        <div style="background: rgba(72,187,120,0.8); border-radius: 12px; padding: 10px; color: white;">
                            <div style="font-weight: bold; font-size: 14px;">Input Gate</div>
                            <div style="font-size: 10px; margin: 5px 0;">i_t = σ(W_i[h_{t-1}, x_t])</div>
                            <div style="font-size: 12px;">What to store?</div>
                        </div>

                        <!-- Output Gate -->
                        <div style="background: rgba(66,153,225,0.8); border-radius: 12px; padding: 10px; color: white;">
                            <div style="font-weight: bold; font-size: 14px;">Output Gate</div>
                            <div style="font-size: 10px; margin: 5px 0;">o_t = σ(W_o[h_{t-1}, x_t])</div>
                            <div style="font-size: 12px;">What to output?</div>
                        </div>
                    </div>

                    <!-- Cell state update -->
                    <div style="margin-top: 20px; text-align: center;">
                        <div style="background: #e8f5e8; border: 2px solid #4caf50; border-radius: 12px; padding: 10px; font-size: 12px;">
                            <strong>Cell State Update:</strong><br>
                            <code style="background: white; padding: 3px; border-radius: 3px;">C_t = f_t ⊙ C_{t-1} + i_t ⊙ tanh(W_c[h_{t-1}, x_t])</code>
                        </div>
                    </div>

                    <!-- Hidden state output -->
                    <div style="margin-top: 15px; text-align: center;">
                        <div style="background: #fff3e0; border: 2px solid #ff9800; border-radius: 12px; padding: 10px; font-size: 12px;">
                            <strong>Hidden State:</strong><br>
                            <code style="background: white; padding: 3px; border-radius: 3px;">h_t = o_t ⊙ tanh(C_t)</code>
                        </div>
                    </div>

                    <!-- Output -->
                    <div style="position: absolute; bottom: -40px; left: 50%; transform: translateX(-50%); text-align: center;">
                        <div class="vertical-arrow" style="height: 20px; margin: 5px auto; transform: rotate(180deg);"></div>
                        <div style="background: #fff3e0; border: 2px solid #ff9800; border-radius: 8px; padding: 8px 12px; font-weight: bold; color: #e65100;">h_t</div>
                    </div>
                </div>
            </div>

            <div class="explanation">
                <strong>Key Symbols:</strong> σ = sigmoid function (0-1), ⊙ = element-wise multiplication, tanh = hyperbolic tangent (-1 to 1)<br>
                <strong>Solution:</strong> Cell state C_t flows through with minimal changes, preserving long-term information while gates control what gets added, removed, or output.
            </div>
        </div>

        <!-- LSTM Architecture Section -->
        <div class="diagram-section">
            <h2>1D. LSTM (Long Short-Term Memory) Architecture</h2>
            <div class="explanation">
                <strong>Key Innovation:</strong> Separate cell state (C) flows through unchanged unless gates decide to modify it. This preserves long-term information.
            </div>

            <div class="rnn-architecture">
                <div class="sequence-input">
                    <div class="input-word">x₁: "The"</div>
                    <div class="input-word">x₂: "cat"</div>
                    <div class="input-word">x₃: "sat"</div>
                </div>

                <div class="vertical-arrow"></div>

                <div class="rnn-layer">
                    <div class="lstm-cell">
                        <div class="cell-state">C₀: [0.0]</div>
                        <div class="lstm-gate forget-gate">Forget Gate</div>
                        <div class="lstm-gate input-gate">Input Gate</div>
                        <div class="lstm-gate output-gate">Output Gate</div>
                        <div style="font-size: 12px; margin-top: 5px;">LSTM 1</div>
                    </div>
                    <div class="rnn-arrow"></div>
                    <div class="lstm-cell">
                        <div class="cell-state">C₁: [0.8]</div>
                        <div class="lstm-gate forget-gate">Forget Gate</div>
                        <div class="lstm-gate input-gate">Input Gate</div>
                        <div class="lstm-gate output-gate">Output Gate</div>
                        <div style="font-size: 12px; margin-top: 5px;">LSTM 2</div>
                    </div>
                    <div class="rnn-arrow"></div>
                    <div class="lstm-cell">
                        <div class="cell-state">C₂: [0.7]</div>
                        <div class="lstm-gate forget-gate">Forget Gate</div>
                        <div class="lstm-gate input-gate">Input Gate</div>
                        <div class="lstm-gate output-gate">Output Gate</div>
                        <div style="font-size: 12px; margin-top: 5px;">LSTM 3</div>
                    </div>
                </div>

                <div class="explanation">
                    <strong>Solution:</strong> C₂ still contains information about "The" because the cell state can flow through unchanged. Gates control what to remember/forget.<br>
                    <strong>Gates:</strong>
                    <span style="color: #e53e3e;">Forget</span> (what to remove),
                    <span style="color: #38a169;">Input</span> (what to add),
                    <span style="color: #3182ce;">Output</span> (what to output)
                </div>
            </div>
        </div>

        <!-- Beam Search Detailed Section -->
        <div class="diagram-section">
            <h2>2. Beam Search with Pruning (Beam Width = 3)</h2>
            <div class="explanation">
                <strong>Goal:</strong> Translate "Der Hund" → "The dog". At each step, keep only the top 3 sequences and prune the rest.
            </div>

            <div class="beam-container">
                <!-- Step 1 -->
                <div class="beam-step">
                    <div class="beam-header">Step 1: Generate first word candidates</div>
                    <div class="candidates-row">
                        <div class="candidate kept">
                            <div>"The"</div>
                            <div class="score">log_prob: -0.1</div>
                        </div>
                        <div class="candidate kept">
                            <div>"A"</div>
                            <div class="score">log_prob: -0.3</div>
                        </div>
                        <div class="candidate kept">
                            <div>"This"</div>
                            <div class="score">log_prob: -0.5</div>
                        </div>
                        <div class="candidate pruned">
                            <div>"An"</div>
                            <div class="score">log_prob: -0.8</div>
                        </div>
                        <div class="candidate pruned">
                            <div>"My"</div>
                            <div class="score">log_prob: -1.2</div>
                        </div>
                    </div>
                    <div class="beam-width">✂️ Pruned 2 candidates, kept top 3</div>
                </div>

                <!-- Step 2 -->
                <div class="beam-step">
                    <div class="beam-header">Step 2: Extend each kept sequence</div>
                    <div style="margin: 10px 0; font-weight: bold;">From "The" + next word:</div>
                    <div class="candidates-row">
                        <div class="candidate kept">
                            <div>"The dog"</div>
                            <div class="score">log_prob: -0.3</div>
                        </div>
                        <div class="candidate">
                            <div>"The cat"</div>
                            <div class="score">log_prob: -0.6</div>
                        </div>
                        <div class="candidate">
                            <div>"The man"</div>
                            <div class="score">log_prob: -0.9</div>
                        </div>
                    </div>

                    <div style="margin: 10px 0; font-weight: bold;">From "A" + next word:</div>
                    <div class="candidates-row">
                        <div class="candidate kept">
                            <div>"A dog"</div>
                            <div class="score">log_prob: -0.5</div>
                        </div>
                        <div class="candidate">
                            <div>"A cat"</div>
                            <div class="score">log_prob: -0.7</div>
                        </div>
                    </div>

                    <div style="margin: 10px 0; font-weight: bold;">From "This" + next word:</div>
                    <div class="candidates-row">
                        <div class="candidate kept">
                            <div>"This dog"</div>
                            <div class="score">log_prob: -0.8</div>
                        </div>
                        <div class="candidate pruned">
                            <div>"This cat"</div>
                            <div class="score">log_prob: -1.0</div>
                        </div>
                    </div>

                    <div class="beam-width">✂️ Generated 6 candidates, kept top 3: "The dog", "A dog", "This dog"</div>
                </div>

                <!-- Step 3 -->
                <div class="beam-step">
                    <div class="beam-header">Step 3: Add END token</div>
                    <div class="candidates-row">
                        <div class="candidate kept">
                            <div>"The dog &lt;END&gt;"</div>
                            <div class="score">log_prob: -0.4</div>
                        </div>
                        <div class="candidate">
                            <div>"A dog &lt;END&gt;"</div>
                            <div class="score">log_prob: -0.6</div>
                        </div>
                        <div class="candidate">
                            <div>"This dog &lt;END&gt;"</div>
                            <div class="score">log_prob: -0.9</div>
                        </div>
                    </div>
                    <div class="pruning-explanation">
                        <strong>🏆 Winner:</strong> "The dog" with highest cumulative probability!<br>
                        <strong>Key:</strong> Beam search found a better overall sequence than greedy (which might have picked "A" first)
                    </div>
                </div>
            </div>

            <button class="interactive-btn" onclick="animateBeamPruning()">🔄 Animate Pruning Process</button>
        </div>

        <!-- Transformer Output Section (keeping as requested) -->
        <div class="diagram-section">
            <h2>3. Transformer Output Pipeline</h2>
            <div class="explanation">
                <strong>From numbers to words:</strong> How the transformer converts internal calculations into actual text.
            </div>

            <div class="transformer-flow">
                <div class="flow-step">
                    <h4>🧠 Transformer Output</h4>
                    <div style="font-family: monospace; background: rgba(255,255,255,0.2); padding: 10px; border-radius: 5px;">
                        Raw Logits: [2.3, -1.1, 4.7, 0.8, -0.3, ...]<br>
                        <small>One number for each word in vocabulary (10,000+ words)</small>
                    </div>
                </div>

                <div class="flow-arrow"></div>

                <div class="flow-step">
                    <h4>📊 Softmax Function</h4>
                    <div>Converts raw numbers to probabilities</div>
                    <div style="font-family: monospace; background: rgba(255,255,255,0.2); padding: 10px; border-radius: 5px; margin-top: 10px;">
                        Formula: P(word) = e^(logit) / Σ(e^(all_logits))
                    </div>
                </div>

                <div class="flow-arrow"></div>

                <div class="flow-step">
                    <h4>📈 Probability Distribution</h4>
                    <div class="word-example">
                        <div style="display: flex; justify-content: space-between; align-items: center; margin: 5px 0;">
                            <span>"cat"</span>
                            <div class="probability-bar" style="width: 150px;">
                                <div class="probability-fill" style="width: 65%; background: linear-gradient(90deg, #51cf66, #40c057);"></div>
                            </div>
                            <span>65%</span>
                        </div>
                        <div style="display: flex; justify-content: space-between; align-items: center; margin: 5px 0;">
                            <span>"dog"</span>
                            <div class="probability-bar" style="width: 150px;">
                                <div class="probability-fill" style="width: 20%; background: linear-gradient(90deg, #ffd43b, #fab005);"></div>
                            </div>
                            <span>20%</span>
                        </div>
                        <div style="display: flex; justify-content: space-between; align-items: center; margin: 5px 0;">
                            <span>"bird"</span>
                            <div class="probability-bar" style="width: 150px;">
                                <div class="probability-fill" style="width: 10%; background: linear-gradient(90deg, #ff8787, #fa5252);"></div>
                            </div>
                            <span>10%</span>
                        </div>
                        <div style="display: flex; justify-content: space-between; align-items: center; margin: 5px 0;">
                            <span>"others"</span>
                            <div class="probability-bar" style="width: 150px;">
                                <div class="probability-fill" style="width: 5%; background: linear-gradient(90deg, #adb5bd, #868e96);"></div>
                            </div>
                            <span>5%</span>
                        </div>
                    </div>
                </div>

                <div class="flow-arrow"></div>

                <div class="flow-step">
                    <h4>🎯 Final Word Selection</h4>
                    <div style="font-size: 1.5em; color: #51cf66; font-weight: bold;">
                        Selected: "cat" ✓
                    </div>
                    <small>Usually picks highest probability (greedy)<br>or uses beam search for better results</small>
                </div>
            </div>

            <button class="interactive-btn" onclick="animateTransformerFlow()">🔄 Show Different Predictions</button>
        </div>

        <!-- Key Differences Section (keeping as requested) -->
        <div class="diagram-section">
            <h2>4. Key Takeaways</h2>
            <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px;">
                <div class="explanation">
                    <h4>🔄 RNN vs LSTM</h4>
                    <strong>RNN:</strong> Information gets diluted through hidden states<br>
                    <strong>LSTM:</strong> Cell state preserves long-term memory via gates
                </div>
                <div class="explanation">
                    <h4>🌳 Beam Search Pruning</h4>
                    <strong>At each step:</strong> Generate all possibilities<br>
                    <strong>Then prune:</strong> Keep only top-K to avoid exponential explosion
                </div>
                <div class="explanation">
                    <h4>🧮 Logits → Probabilities</h4>
                    <strong>Logits:</strong> Raw math outputs (can be any number)<br>
                    <strong>Softmax:</strong> Converts to percentages that add up to 100%
                </div>
                <div class="explanation">
                    <h4>🎯 Word Selection</h4>
                    <strong>Deterministic:</strong> Always pick highest probability<br>
                    <strong>Sampling:</strong> Sometimes pick lower probabilities for creativity
                </div>
            </div>
        </div>
    </div>

    <script>
        function animateBeamPruning() {
            const steps = document.querySelectorAll('.beam-step');

            steps.forEach((step, stepIndex) => {
                setTimeout(() => {
                    const candidates = step.querySelectorAll('.candidate');
                    candidates.forEach((candidate, candIndex) => {
                        setTimeout(() => {
                            candidate.style.transform = 'scale(1.1)';
                            candidate.style.opacity = '1';
                            setTimeout(() => {
                                candidate.style.transform = 'scale(1)';
                                if (candidate.classList.contains('pruned')) {
                                    candidate.style.opacity = '0.4';
                                }
                            }, 300);
                        }, candIndex * 200);
                    });
                }, stepIndex * 1500);
            });
        }

        function animateTransformerFlow() {
            const examples = [
                {
                    logits: "Raw Logits: [1.2, -2.1, 3.8, 0.5, -0.8, ...]",
                    words: [
                        {word: "dog", prob: 45, color: "#51cf66"},
                        {word: "cat", prob: 30, color: "#ffd43b"},
                        {word: "bird", prob: 15, color: "#ff8787"},
                        {word: "others", prob: 10, color: "#adb5bd"}
                    ],
                    selected: "dog"
                },
                {
                    logits: "Raw Logits: [4.1, 0.2, -1.5, 2.3, 1.1, ...]",
                    words: [
                        {word: "sits", prob: 70, color: "#51cf66"},
                        {word: "runs", prob: 15, color: "#ffd43b"},
                        {word: "jumps", prob: 10, color: "#ff8787"},
                        {word: "others", prob: 5, color: "#adb5bd"}
                    ],
                    selected: "sits"
                }
            ];

            const currentExample = examples[Math.floor(Math.random() * examples.length)];

            // Update logits
            const logitsDisplay = document.querySelector('.flow-step div[style*="monospace"]');
            logitsDisplay.innerHTML = currentExample.logits + '<br><small>One number for each word in vocabulary (10,000+ words)</small>';

            // Update probabilities
            const wordExample = document.querySelector('.word-example');
            wordExample.innerHTML = currentExample.words.map(item => `
                <div style="display: flex; justify-content: space-between; align-items: center; margin: 5px 0;">
                    <span>"${item.word}"</span>
                    <div class="probability-bar" style="width: 150px;">
                        <div class="probability-fill" style="width: ${item.prob}%; background: linear-gradient(90deg, ${item.color}, ${item.color});"></div>
                    </div>
                    <span>${item.prob}%</span>
                </div>
            `).join('');

            // Update selection
            const selection = document.querySelector('.flow-step:last-child div[style*="font-size"]');
            selection.innerHTML = `Selected: "${currentExample.selected}" ✓`;
        }
    </script>
</body>
</html>
"""

# Display it
HTML(html_content2)

# Objectives
After completing this lab, you will be able to:

- Translate a PDF document from German to English




In [1]:
#!pip install -U spacy==3.7.2
#!pip install -Uqq portalocker==2.7.0
#!pip install -qq torchtext==0.14.1
#!pip install -Uq nltk==3.8.1

#!python -m spacy download de
#!python -m spacy download en

#!pip install pdfplumber==0.9.0
#!pip install fpdf==1.7.2

#!wget 'https://cf-courses-data.s3.us.cloud-object-storage.appdomain.cloud/IBMSkillsNetwork-AI0205EN-SkillsNetwork/Multi30K_de_en_dataloader.py'
#!wget 'https://cf-courses-data.s3.us.cloud-object-storage.appdomain.cloud/IBMSkillsNetwork-AI0201EN-Coursera/transformer.pt'
#!wget 'https://cf-courses-data.s3.us.cloud-object-storage.appdomain.cloud/IBMSkillsNetwork-AI0201EN-Coursera/input_de.pdf'

## Importing required libraries


In [2]:

import matplotlib.pyplot as plt
from nltk.translate.bleu_score import sentence_bleu
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
from tqdm import tqdm

# You can also use this section to suppress warnings generated by your code:
def warn(*args, **kwargs):
    pass
import warnings
warnings.warn = warn
warnings.filterwarnings('ignore')

In [3]:
from datasets import load_dataset

dataset = load_dataset("bentrevett/multi30k")
train_data = dataset['train']
val_data = dataset['validation']
test_data = dataset['test']

# Access the data
for example in train_data:
    print(f"English: {example['en']}")
    print(f"German: {example['de']}")
    break

English: Two young, White males are outside near many bushes.
German: Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.


In [4]:
# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

In [5]:
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
from collections import Counter
import pickle
import os

def build_vocab(sentences, min_freq=2):
    """Build vocabulary from sentences"""
    counter = Counter()
    for sentence in sentences:
        counter.update(sentence.split())

    vocab = {'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3}
    for word, freq in counter.items():
        if freq >= min_freq:
            vocab[word] = len(vocab)

    return vocab

def text_to_tensor(text, vocab, max_len=None):
    """Convert text to tensor using vocabulary"""
    tokens = ['<bos>'] + text.split() + ['<eos>']
    if max_len:
        tokens = tokens[:max_len]

    indices = [vocab.get(token, vocab['<unk>']) for token in tokens]
    return torch.tensor(indices, dtype=torch.long)

def get_translation_dataloaders_hf(batch_size=1, max_len=50):
    """
    Replacement for TorchText's get_translation_dataloaders using Hugging Face Datasets
    Returns tensors that can be transposed with .T
    """
    # Load Multi30k dataset
    dataset = load_dataset("bentrevett/multi30k")

    # Get train and validation datasets
    train_dataset = dataset['train']
    val_dataset = dataset['validation']

    # Build vocabularies (you might want to save/load these)
    print("Building vocabularies...")
    en_sentences = [item['en'] for item in train_dataset]
    de_sentences = [item['de'] for item in train_dataset]

    en_vocab = build_vocab(en_sentences)
    de_vocab = build_vocab(de_sentences)

    print(f"English vocab size: {len(en_vocab)}")
    print(f"German vocab size: {len(de_vocab)}")

    def collate_fn(batch):
        """Custom collate function to convert text to tensors"""
        english_tensors = []
        german_tensors = []

        # Find max length in batch for padding
        max_en_len = max(len(item['en'].split()) + 2 for item in batch)  # +2 for <bos>, <eos>
        max_de_len = max(len(item['de'].split()) + 2 for item in batch)

        for item in batch:
            en_tensor = text_to_tensor(item['en'], en_vocab, max_len)
            de_tensor = text_to_tensor(item['de'], de_vocab, max_len)

            # Pad to max length in batch
            en_padded = torch.nn.functional.pad(en_tensor, (0, max_en_len - len(en_tensor)), value=en_vocab['<pad>'])
            de_padded = torch.nn.functional.pad(de_tensor, (0, max_de_len - len(de_tensor)), value=de_vocab['<pad>'])

            english_tensors.append(en_padded)
            german_tensors.append(de_padded)

        # Stack into batch tensors
        english_batch = torch.stack(english_tensors)  # [batch_size, seq_len]
        german_batch = torch.stack(german_tensors)    # [batch_size, seq_len]

        return english_batch, german_batch

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    # Store vocabularies as attributes for later use
    train_dataloader.en_vocab = en_vocab
    train_dataloader.de_vocab = de_vocab
    val_dataloader.en_vocab = en_vocab
    val_dataloader.de_vocab = de_vocab

    return train_dataloader, val_dataloader

train_dataloader, val_dataloader = get_translation_dataloaders_hf(batch_size=1)

# Create iterator
data_itr = iter(train_dataloader)

# Now this will work with tensors
english, german = next(data_itr)
print(f"English tensor shape: {english.shape}")
print(f"German tensor shape: {german.shape}")

# Now you can transpose!
german = german.T
english = english.T

print(f"After transpose - English: {english.shape}")
print(f"After transpose - German: {german.shape}")

# Example: decode back to text to verify
def decode_tensor(tensor, vocab):
    """Convert tensor back to text"""
    idx_to_word = {v: k for k, v in vocab.items()}
    words = [idx_to_word.get(idx.item(), '<unk>') for idx in tensor.squeeze()]
    # Remove padding and special tokens for display
    words = [w for w in words if w not in ['<pad>', '<bos>', '<eos>']]
    return ' '.join(words)

print(f"English text: {decode_tensor(english, train_dataloader.en_vocab)}")
print(f"German text: {decode_tensor(german, train_dataloader.de_vocab)}")

Building vocabularies...
English vocab size: 7964
German vocab size: 9762
English tensor shape: torch.Size([1, 17])
German tensor shape: torch.Size([1, 15])
After transpose - English: torch.Size([17, 1])
After transpose - German: torch.Size([15, 1])
English text: An elderly man sits outside a storefront accompanied by a young boy with a cart.
German text: Ein älterer Mann sitzt mit einem Jungen mit einem Wagen vor einer Fassade.


In [6]:
from datasets import load_dataset

# Load dataset
dataset = load_dataset("bentrevett/multi30k")
train_data = dataset['train']

# Create simple iterator
data_itr = iter(train_data)

In [7]:
data_itr=iter(train_dataloader)
data_itr

<torch.utils.data.dataloader._SingleProcessDataLoaderIter at 0x166ac3aa0>

In [8]:
for n in range(1000):
    german, english= next(data_itr)

In [9]:
german=german.T
english=english.T

In [10]:
def index_to_german(tensor, vocab=None):
    """Convert German tensor indices back to text"""
    if vocab is None:
        raise ValueError("Need German vocabulary to decode")

    idx_to_word = {v: k for k, v in vocab.items()}
    if tensor.dim() > 1:
        # Handle batch dimension
        sentences = []
        for i in range(tensor.shape[0]):
            words = [idx_to_word.get(idx.item(), '<unk>') for idx in tensor[i]]
            # Remove special tokens and padding
            words = [w for w in words if w not in ['<pad>', '<bos>', '<eos>']]
            sentences.append(' '.join(words))
        return sentences
    else:
        words = [idx_to_word.get(idx.item(), '<unk>') for idx in tensor]
        words = [w for w in words if w not in ['<pad>', '<bos>', '<eos>']]
        return ' '.join(words)

def index_to_eng(tensor, vocab=None):
    """Convert English tensor indices back to text"""
    if vocab is None:
        raise ValueError("Need English vocabulary to decode")

    idx_to_word = {v: k for k, v in vocab.items()}
    if tensor.dim() > 1:
        # Handle batch dimension
        sentences = []
        for i in range(tensor.shape[0]):
            words = [idx_to_word.get(idx.item(), '<unk>') for idx in tensor[i]]
            # Remove special tokens and padding
            words = [w for w in words if w not in ['<pad>', '<bos>', '<eos>']]
            sentences.append(' '.join(words))
        return sentences
    else:
        words = [idx_to_word.get(idx.item(), '<unk>') for idx in tensor]
        words = [w for w in words if w not in ['<pad>', '<bos>', '<eos>']]
        return ' '.join(words)

# Global variables to store vocabularies
DE_VOCAB = None
EN_VOCAB = None

def set_global_vocabs(train_dataloader):
    """Set global vocabularies for easy access"""
    global DE_VOCAB, EN_VOCAB
    DE_VOCAB = train_dataloader.de_vocab
    EN_VOCAB = train_dataloader.en_vocab

def index_to_german_global(tensor):
    """Convert German tensor to text using global vocab"""
    return index_to_german(tensor, DE_VOCAB)

def index_to_eng_global(tensor):
    """Convert English tensor to text using global vocab"""
    return index_to_eng(tensor, EN_VOCAB)

In [None]:
train_dataloader, _ = get_translation_dataloaders_hf(batch_size=1)
set_global_vocabs(train_dataloader)
data_itr = iter(train_dataloader)


for n in range(10):
    german, english = next(data_itr)
    print("sample {}".format(n))
    print("german input")
    print(index_to_german_global(german))
    print("english target")
    print(index_to_eng_global(english))
    print("_________\n")

In [12]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cpu')

In [13]:
def generate_square_subsequent_mask(sz,device=DEVICE):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

In [14]:
def create_mask(src, tgt,device=DEVICE):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

### Positional encoding
The transformer model doesn't have built-in knowledge of the order of tokens in the sequence. To give the model this information, positional encodings are added to the tokens embeddings. These encodings have a fixed pattern based on their position in the sequence.


In [15]:
# Add positional information to the input tokens
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

### Token embedding
Token embedding, also known as word embedding or word representation, is a way to convert words or tokens from a text corpus into numerical vectors in a continuous vector space. Each unique word or token in the corpus is assigned a fixed-length vector where the numerical values represent various linguistic properties of the word, such as its meaning, context, or relationships with other words.

The `TokenEmbedding` class below converts numerical tokens into embeddings:


In [16]:
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

In [17]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()

        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        outs =outs.to(DEVICE)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

## Inference


The diagram below illustrates the sequence prediction or inference process.
<img src="https://cf-courses-data.s3.us.cloud-object-storage.appdomain.cloud/IBMSkillsNetwork-AI0201EN-Coursera/predict_transformers.png" alt="transformer">
The decoder's output is then mapped onto a vocabulary-sized vector using a linear layer. Following this, a softmax function converts these vector scores into probabilities. The highest probability, as determined by the argmax function, provides the index of your predicted word within the translated sequence. This predicted index is fed back into the decoder in conjunction with the initial sequence, setting the stage to determine the subsequent word in the translation. This autoregressive process is demonstrated by the arrow pointing to form the top of the decoder, in green, to the bottom.


In [18]:
# Add this after the get_translation_dataloaders_hf function
vocab_transform = {}

def create_vocab_transform(train_dataloader):
    """Create vocab_transform dictionary for compatibility"""
    global vocab_transform
    vocab_transform = {
        'de': train_dataloader.de_vocab,
        'en': train_dataloader.en_vocab
    }
    return vocab_transform

In [19]:
train_dataloader, _ = get_translation_dataloaders_hf(batch_size=1)
set_global_vocabs(train_dataloader)

# Create vocab_transform for compatibility with existing code
vocab_transform = create_vocab_transform(train_dataloader)

# Now your existing code will work
SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512

print(f"Source (German) vocab size: {SRC_VOCAB_SIZE}")
print(f"Target (English) vocab size: {TGT_VOCAB_SIZE}")

Building vocabularies...
English vocab size: 7964
German vocab size: 9762
Source (German) vocab size: 9762
Target (English) vocab size: 7964


In [20]:
torch.manual_seed(0)

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

Let's will start off with a trained model.For this, load the weights of the transformer model from the file 'transformer.pt'.





In [None]:
transformer.load_state_dict(torch.load('transformer.pt', map_location=DEVICE, ))

In [None]:
#print("engish target",index_to_eng(tgt))
#print("german input",index_to_german(src))

In [23]:
# First, let's check the vocabulary sizes from the checkpoint
checkpoint = torch.load('transformer.pt', map_location=DEVICE)
src_vocab_size = checkpoint['src_tok_emb.embedding.weight'].shape[0]
tgt_vocab_size = checkpoint['tgt_tok_emb.embedding.weight'].shape[0]

print(f"Checkpoint expects - German vocab: {src_vocab_size}, English vocab: {tgt_vocab_size}")

# Create model with the checkpoint's expected vocabulary sizes
SRC_VOCAB_SIZE = src_vocab_size  # 19214
TGT_VOCAB_SIZE = tgt_vocab_size  # 10837
EMB_SIZE = 512

# Create the model architecture (you'll need your existing model definition)
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,
                               EMB_SIZE, NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

# Now load the weights
transformer.load_state_dict(checkpoint)
transformer.eval()

print("Model loaded successfully!")

Checkpoint expects - German vocab: 19214, English vocab: 10837
Model loaded successfully!


In [24]:
import pdfplumber
def extract_text_pdfplumber(pdf_path):
    text = ""
    with pdfplumber.open(pdf_path) as pdf:
        for page in pdf.pages:
            text += page.extract_text()
    return text

In [25]:
def preprocess_for_translation(text):
    # Split into sentences
    sentences = text.split('.')
    # Clean each sentence
    cleaned_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        if sentence:  # Skip empty sentences
            cleaned_sentences.append(sentence)
    return cleaned_sentences

In [26]:
def translate_text(text, transformer, src_vocab, tgt_vocab, device):
    # Tokenize text using your vocabulary
    tokens = ['<bos>'] + text.split() + ['<eos>']
    src_indices = [src_vocab.get(token, src_vocab['<unk>']) for token in tokens]
    src_tensor = torch.tensor(src_indices).unsqueeze(0).to(device)  # Add batch dimension

    # Generate translation using your transformer
    with torch.no_grad():
        # You'd need to implement the actual inference logic here
        # This depends on how your transformer's forward method works
        output = transformer.generate(src_tensor)  # This method would need to be implemented

    # Convert output indices back to words
    tgt_vocab_reverse = {v: k for k, v in tgt_vocab.items()}
    translated_words = [tgt_vocab_reverse.get(idx.item(), '<unk>') for idx in output.squeeze()]

    return ' '.join(translated_words)

In [27]:
def translate_pdf(pdf_path, transformer, de_vocab, en_vocab, device):
    # Extract text
    text = extract_text_pdfplumber(pdf_path)

    # Split into manageable chunks
    sentences = preprocess_for_translation(text)

    # Translate each sentence
    translated_sentences = []
    for sentence in sentences:
        if sentence.strip():
            try:
                translation = translate_text(sentence, transformer, de_vocab, en_vocab, device)
                translated_sentences.append(translation)
            except Exception as e:
                print(f"Error translating: {sentence[:50]}... Error: {e}")
                translated_sentences.append(f"[TRANSLATION ERROR: {sentence}]")

    return '\n'.join(translated_sentences)

In [None]:
# Check what methods your transformer has
print(dir(transformer))

# Or look for methods containing 'translate', 'generate', 'decode':
methods = [method for method in dir(transformer) if any(word in method.lower() for word in ['translate', 'generate', 'decode', 'forward'])]
print("Relevant methods:", methods)

In [41]:
def translate_sentence_simple(sentence, transformer, de_vocab, en_vocab, device):
    try:
        # Convert to tokens
        src_tokens = ['<bos>'] + sentence.split()[:10] + ['<eos>']
        src_indices = [de_vocab.get(token, de_vocab['<unk>']) for token in src_tokens]

        # Create tensors with sequence first [seq_len, batch_size]
        src_tensor = torch.tensor(src_indices).unsqueeze(1).to(device)  # [seq_len, 1]
        tgt_tensor = torch.tensor([en_vocab['<bos>']]).unsqueeze(1).to(device)  # [1, 1]

        src_len = src_tensor.size(0)
        tgt_len = tgt_tensor.size(0)

        # Create masks
        src_mask = torch.zeros((src_len, src_len), device=device)
        tgt_mask = torch.zeros((tgt_len, tgt_len), device=device)
        src_padding_mask = torch.zeros((1, src_len), dtype=torch.bool, device=device)
        tgt_padding_mask = torch.zeros((1, tgt_len), dtype=torch.bool, device=device)
        memory_key_padding_mask = torch.zeros((1, src_len), dtype=torch.bool, device=device)

        # Forward pass
        with torch.no_grad():
            output = transformer(
                src_tensor, tgt_tensor,
                src_mask, tgt_mask,
                src_padding_mask, tgt_padding_mask,
                memory_key_padding_mask
            )

        # Get the most likely next token
        probs = torch.softmax(output[-1, 0], dim=-1)  # Last position, first batch
        next_token = torch.argmax(probs).item()

        # Convert back to word
        en_vocab_rev = {v: k for k, v in en_vocab.items()}
        word = en_vocab_rev.get(next_token, '<unk>')

        return f"Predicted next word: {word}"

    except Exception as e:
        return f"Error: {str(e)[:150]}"

In [42]:


# Replace in your PDF function:
def translate_pdf_simple(pdf_path, transformer, de_vocab, en_vocab, device):
    import pdfplumber

    with pdfplumber.open(pdf_path) as pdf:
        text = ""
        for page in pdf.pages:
            text += page.extract_text() + " "

    sentences = text.split('.')

    for i, sentence in enumerate(sentences[:3]):
        if sentence.strip():
            print(f"German: {sentence.strip()}")
            translation = translate_sentence_simple(sentence.strip(), transformer, de_vocab, en_vocab, device)
            print(f"English: {translation}")
            print()

# Use it:
# Create vocabularies with the exact sizes the model expects
def create_dummy_vocab(size, prefix="word"):
    vocab = {'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3}
    for i in range(4, size):
        vocab[f"{prefix}_{i}"] = i
    return vocab

# Create the vocabularies
de_vocab = create_dummy_vocab(19214, "de")  # German vocab
en_vocab = create_dummy_vocab(10837, "en")  # English vocab

# Now you can use them
translate_pdf_simple('input_de.pdf', transformer, de_vocab, en_vocab, DEVICE)

German: Der frühe Morgen bricht an und die ersten Sonnenstrahlen kitzeln san8 mein Gesicht
English: Predicted next word: en_6

German: Ich atme
=ef ein und spüre die frische Morgenlu8 in meinen Lungen
English: Predicted next word: en_6

German: Mit einem Lächeln auf den Lippen
stehe ich auf und beginne den Tag mit voller Energie
English: Predicted next word: en_6



In [5]:
from transformers import MarianMTModel, MarianTokenizer

model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-de-en")
tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")

# Input
text = "Hallo Welt, wie geht es dir?"
inputs = tokenizer(text, return_tensors="pt")
print(f"Input text: {text}")
print(f"Input IDs: {inputs['input_ids']}")

# Generate translation
outputs = model.generate(
    inputs['input_ids'],
    num_beams=5,
    max_length=50,
    early_stopping=True
)

print(f"Output IDs: {outputs}")
print(f"Output shape: {outputs.shape}")

# Decode back to text
translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Translated text: {translated_text}")

# Return multiple sequences
outputs = model.generate(
    inputs['input_ids'],
    num_beams=5,
    num_return_sequences=3,  # Get top 3 beam results
    max_length=50,
    early_stopping=True
)

# Decode all sequences
for i, output in enumerate(outputs):
    translated = tokenizer.decode(output, skip_special_tokens=True)
    print(f"Translation {i+1}: {translated}")

Input text: Hallo Welt, wie geht es dir?
Input IDs: tensor([[11918,   401,     2,   107,   652,    65,   600,    31,     0]])
Output IDs: tensor([[58100, 16816,     2,   360,     2,   406,    48,    41,    31,     0]])
Output shape: torch.Size([1, 10])
Translated text: Hello, world, how are you?
Translation 1: Hello, world, how are you?
Translation 2: Hello world, how are you?
Translation 3: Hello, world. How are you?


In [6]:
# Install required packages (run this first if needed)
# !pip install transformers torch sentencepiece

from transformers import MarianMTModel, MarianTokenizer
import torch

# Load pre-trained German to English translation model
print("📥 Loading Hugging Face German-English translation model...")
model_name = "Helsinki-NLP/opus-mt-de-en"
model = MarianMTModel.from_pretrained(model_name)
tokenizer = MarianTokenizer.from_pretrained(model_name)
print("✅ Model loaded successfully!")

# Interesting ambiguous German sentences for beam search
sentences = [
    "Der alte Mann der Straße gibt dem Kind einen Ball",
    "Der Mann sieht den Hund mit dem Fernglas", 
    "Das kann man nicht machen",
    "Maria gab Anna das Buch, weil sie es brauchte"
]

for i, text in enumerate(sentences, 1):
    print(f"\n{'='*70}")
    print(f"SENTENCE {i}: {text}")
    print(f"{'='*70}")
    
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt")
    print(f"Input IDs: {inputs['input_ids']}")
    
    # 1. Greedy decoding (baseline)
    print(f"\n🔍 GREEDY DECODING:")
    greedy_output = model.generate(
        inputs['input_ids'],
        max_length=50,
        num_beams=1,  # Greedy
        early_stopping=True
    )
    greedy_text = tokenizer.decode(greedy_output[0], skip_special_tokens=True)
    print(f"   {greedy_text}")
    
    # 2. Beam search - single best
    print(f"\n🌟 BEAM SEARCH (k=5, best result):")
    beam_output = model.generate(
        inputs['input_ids'],
        max_length=50,
        num_beams=5,
        early_stopping=True
    )
    beam_text = tokenizer.decode(beam_output[0], skip_special_tokens=True)
    print(f"   {beam_text}")
    
    # 3. Beam search - multiple candidates
    print(f"\n🎯 BEAM SEARCH (top 3 candidates):")
    multiple_outputs = model.generate(
        inputs['input_ids'],
        max_length=50,
        num_beams=5,
        num_return_sequences=3,  # Return top 3
        early_stopping=True
    )
    
    for j, output in enumerate(multiple_outputs, 1):
        candidate_text = tokenizer.decode(output, skip_special_tokens=True)
        print(f"   {j}. {candidate_text}")
    
    # 4. Different beam sizes comparison
    print(f"\n🔄 BEAM SIZE COMPARISON:")
    for beam_size in [1, 3, 5, 8]:
        beam_result = model.generate(
            inputs['input_ids'],
            max_length=50,
            num_beams=beam_size,
            early_stopping=True
        )
        result_text = tokenizer.decode(beam_result[0], skip_special_tokens=True)
        print(f"   Beam size {beam_size}: {result_text}")
    
    if i == 1:  # Show detailed analysis for first sentence only
        print(f"\n📊 DETAILED TOKEN ANALYSIS:")
        print(f"   Input tokens:   {tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])}")
        print(f"   Greedy tokens:  {tokenizer.convert_ids_to_tokens(greedy_output[0])}")
        print(f"   Beam tokens:    {tokenizer.convert_ids_to_tokens(beam_output[0])}")

# Bonus: Compare with sampling methods for the most ambiguous sentence
print(f"\n{'🎲 SAMPLING METHODS COMPARISON'}")
print(f"{'='*50}")
text = "Der alte Mann der Straße gibt dem Kind einen Ball"
inputs = tokenizer(text, return_tensors="pt")
print(f"Input: {text}\n")

torch.manual_seed(42)  # For reproducibility

sampling_methods = [
    {"name": "Beam Search (k=5)", "num_beams": 5, "do_sample": False},
    {"name": "Top-k Sampling (k=50)", "do_sample": True, "top_k": 50, "temperature": 0.8, "num_beams": 1},
    {"name": "Top-p Sampling (p=0.9)", "do_sample": True, "top_p": 0.9, "temperature": 0.8, "num_beams": 1},
    {"name": "Low Temperature (0.3)", "do_sample": True, "temperature": 0.3, "num_beams": 1},
    {"name": "High Temperature (1.5)", "do_sample": True, "temperature": 1.5, "num_beams": 1}
]

for method in sampling_methods:
    print(f"🔧 {method['name']}:")
    method_params = {k: v for k, v in method.items() if k != 'name'}
    
    # Generate 2 samples to show variation
    num_samples = 2 if method['name'] != "Beam Search (k=5)" else 1
    for i in range(num_samples):
        output = model.generate(
            inputs['input_ids'],
            max_length=50,
            early_stopping=True,
            **method_params
        )
        result = tokenizer.decode(output[0], skip_special_tokens=True)
        sample_text = f"Sample {i+1}: " if num_samples > 1 else ""
        print(f"   {sample_text}{result}")
    print()

print("🎉 Translation comparison complete!")
print("💡 Notice how beam search explores different interpretations of ambiguous sentences!")

📥 Loading Hugging Face German-English translation model...


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


✅ Model loaded successfully!

SENTENCE 1: Der alte Mann der Straße gibt dem Kind einen Ball
Input IDs: tensor([[ 119, 4712, 1155,    9, 3766,  297,   57, 2821,  106, 5454,    0]])

🔍 GREEDY DECODING:
   The old man on the street gives the child a ball

🌟 BEAM SEARCH (k=5, best result):
   The old man on the street gives the child a ball

🎯 BEAM SEARCH (top 3 candidates):


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


   1. The old man on the street gives the child a ball
   2. The old man on the street gives the kid a ball
   3. The old man in the street gives the child a ball

🔄 BEAM SIZE COMPARISON:
   Beam size 1: The old man on the street gives the child a ball
   Beam size 3: The old man on the street gives the child a ball
   Beam size 5: The old man on the street gives the child a ball


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


   Beam size 8: The old man on the street gives the child a ball

📊 DETAILED TOKEN ANALYSIS:
   Input tokens:   ['▁Der', '▁alte', '▁Mann', '▁der', '▁Straße', '▁gibt', '▁dem', '▁Kind', '▁einen', '▁Ball', '</s>']
   Greedy tokens:  ['<pad>', '▁The', '▁old', '▁man', '▁on', '▁the', '▁street', '▁gives', '▁the', '▁child', '▁a', '▁ball', '</s>']
   Beam tokens:    ['<pad>', '▁The', '▁old', '▁man', '▁on', '▁the', '▁street', '▁gives', '▁the', '▁child', '▁a', '▁ball', '</s>']

SENTENCE 2: Der Mann sieht den Hund mit dem Fernglas
Input IDs: tensor([[ 119, 1155, 2381,   25, 9162,   30,   57, 6015, 9817,    0]])

🔍 GREEDY DECODING:
   The man sees the dog with the binoculars

🌟 BEAM SEARCH (k=5, best result):
   The man sees the dog with the binoculars

🎯 BEAM SEARCH (top 3 candidates):


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


   1. The man sees the dog with the binoculars
   2. The man sees the dog with binoculars
   3. The man sees the dog with his binoculars

🔄 BEAM SIZE COMPARISON:
   Beam size 1: The man sees the dog with the binoculars
   Beam size 3: The man sees the dog with the binoculars
   Beam size 5: The man sees the dog with the binoculars


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


   Beam size 8: The man sees the dog with the binoculars

SENTENCE 3: Das kann man nicht machen
Input IDs: tensor([[103, 134, 175,  51, 522,   0]])

🔍 GREEDY DECODING:
   You can't do that

🌟 BEAM SEARCH (k=5, best result):
   You can't do that

🎯 BEAM SEARCH (top 3 candidates):


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


   1. You can't do that
   2. You can't do that.
   3. You can't do this.

🔄 BEAM SIZE COMPARISON:
   Beam size 1: You can't do that
   Beam size 3: You can't do that
   Beam size 5: You can't do that


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


   Beam size 8: You can't do that

SENTENCE 4: Maria gab Anna das Buch, weil sie es brauchte
Input IDs: tensor([[ 2719,  1647,  6323,    44,  1817,     2,   765,    76,    65, 27088,
             0]])

🔍 GREEDY DECODING:
   Mary gave Anna the book because she needed it

🌟 BEAM SEARCH (k=5, best result):
   Mary gave Anna the book because she needed it

🎯 BEAM SEARCH (top 3 candidates):


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


   1. Mary gave Anna the book because she needed it
   2. Mary gave the book to Anna because she needed it
   3. Mary gave Anna the book because she needed it.

🔄 BEAM SIZE COMPARISON:
   Beam size 1: Mary gave Anna the book because she needed it
   Beam size 3: Mary gave Anna the book because she needed it
   Beam size 5: Mary gave Anna the book because she needed it
   Beam size 8: Mary gave Anna the book because she needed it

🎲 SAMPLING METHODS COMPARISON
Input: Der alte Mann der Straße gibt dem Kind einen Ball

🔧 Beam Search (k=5):


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


   The old man on the street gives the child a ball

🔧 Top-k Sampling (k=50):
   Sample 1: The old man on the street gives the child a ball


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


   Sample 2: The old man of the street gives the child a ball

🔧 Top-p Sampling (p=0.9):


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


   Sample 1: The old man in the street gives the kid a ball


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


   Sample 2: The old man of the street gives a ball to the kid

🔧 Low Temperature (0.3):
   Sample 1: The old man on the street gives the child a ball


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


   Sample 2: The old man on the street gives the child a ball

🔧 High Temperature (1.5):


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


   Sample 1: Creating A Bump to Tide with this Road, Old Man Heds
   Sample 2: The old guy on Araq road, going all the way across a beautiful old town in a girl like that, hands a ball over a movie when people come across the streets they tell around the world

🎉 Translation comparison complete!
💡 Notice how beam search explores different interpretations of ambiguous sentences!


In [44]:
def greedy_decode(transformer, src_tensor, src_mask, en_vocab, device, max_len=20):
    """Greedy decoding - always pick the most likely next word"""
    generated_tokens = [en_vocab['<bos>']]

    for i in range(max_len):
        tgt_tensor = torch.tensor(generated_tokens).unsqueeze(1).to(device)
        tgt_len = tgt_tensor.size(0)
        src_len = src_tensor.size(0)

        # Create proper masks
        src_mask = torch.zeros((src_len, src_len), device=device)
        tgt_mask = torch.zeros((tgt_len, tgt_len), device=device)  # Fixed: use tgt_len
        src_padding_mask = torch.zeros((1, src_len), dtype=torch.bool, device=device)
        tgt_padding_mask = torch.zeros((1, tgt_len), dtype=torch.bool, device=device)
        memory_key_padding_mask = torch.zeros((1, src_len), dtype=torch.bool, device=device)

        with torch.no_grad():
            output = transformer(src_tensor, tgt_tensor, src_mask, tgt_mask,
                               src_padding_mask, tgt_padding_mask, memory_key_padding_mask)

        # Greedy: pick the most likely token
        probs = torch.softmax(output[-1, 0], dim=-1)
        next_token = torch.argmax(probs).item()

        if next_token == en_vocab['<eos>']:
            break

        generated_tokens.append(next_token)

    return generated_tokens[1:]  # Remove <bos>

def beam_search(transformer, src_tensor, src_mask, en_vocab, device, beam_size=3, max_len=20):
    """Beam search - keep track of top K sequences"""
    beams = [([en_vocab['<bos>']], 0.0)]

    for i in range(max_len):
        candidates = []

        for sequence, score in beams:
            if sequence[-1] == en_vocab['<eos>']:
                candidates.append((sequence, score))
                continue

            tgt_tensor = torch.tensor(sequence).unsqueeze(1).to(device)
            tgt_len = tgt_tensor.size(0)
            src_len = src_tensor.size(0)

            # Create proper masks
            src_mask_local = torch.zeros((src_len, src_len), device=device)
            tgt_mask = torch.zeros((tgt_len, tgt_len), device=device)  # Fixed
            src_padding_mask = torch.zeros((1, src_len), dtype=torch.bool, device=device)
            tgt_padding_mask = torch.zeros((1, tgt_len), dtype=torch.bool, device=device)
            memory_key_padding_mask = torch.zeros((1, src_len), dtype=torch.bool, device=device)

            with torch.no_grad():
                output = transformer(src_tensor, tgt_tensor, src_mask_local, tgt_mask,
                                   src_padding_mask, tgt_padding_mask, memory_key_padding_mask)

            # Get top beam_size tokens
            log_probs = torch.log_softmax(output[-1, 0], dim=-1)
            top_probs, top_indices = torch.topk(log_probs, beam_size)

            for prob, idx in zip(top_probs, top_indices):
                new_sequence = sequence + [idx.item()]
                new_score = score + prob.item()
                candidates.append((new_sequence, new_score))

        # Keep only top beam_size sequences
        beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]

        # Check if all beams ended
        if all(seq[-1] == en_vocab['<eos>'] for seq, _ in beams):
            break

    best_sequence, best_score = beams[0]
    return best_sequence[1:]  # Remove <bos>

def compare_decoding_methods(sentence, transformer, de_vocab, en_vocab, device):
    """Compare greedy vs beam search for a sentence"""
    try:
        # Prepare source
        src_tokens = ['<bos>'] + sentence.split()[:8] + ['<eos>']
        src_indices = [de_vocab.get(token, de_vocab['<unk>']) for token in src_tokens]
        src_tensor = torch.tensor(src_indices).unsqueeze(1).to(device)
        src_mask = None  # Let the functions create their own masks

        print(f"German: {sentence}")

        # Greedy decoding
        greedy_tokens = greedy_decode(transformer, src_tensor, src_mask, en_vocab, device)
        en_vocab_rev = {v: k for k, v in en_vocab.items()}
        greedy_words = [en_vocab_rev.get(token, f'token_{token}') for token in greedy_tokens]
        print(f"Greedy:     {' '.join(greedy_words)}")

        # Beam search
        beam_tokens = beam_search(transformer, src_tensor, src_mask, en_vocab, device, beam_size=3)
        beam_words = [en_vocab_rev.get(token, f'token_{token}') for token in beam_tokens]
        print(f"Beam(k=3):  {' '.join(beam_words)}")
        print("-" * 50)

    except Exception as e:
        print(f"Error: {str(e)[:100]}")

In [45]:
# Usage:
sentences = [
    "Der frühe Morgen bricht an",
    "Ich gehe zur Schule",
    "Das Wetter ist schön"
]

for sentence in sentences:
    compare_decoding_methods(sentence, transformer, de_vocab, en_vocab, DEVICE)

German: Der frühe Morgen bricht an
Greedy:     en_6 en_193 en_966 en_10 en_199 en_26 en_1365 en_5
Beam(k=3):  en_6 en_193 en_966 en_14 en_26 en_1424 <eos>
--------------------------------------------------
German: Ich gehe zur Schule
Greedy:     en_6 en_193 en_966 en_10 en_199 en_26 en_1365 en_5
Beam(k=3):  en_6 en_193 en_966 en_14 en_26 en_1424 <eos>
--------------------------------------------------
German: Das Wetter ist schön
Greedy:     en_6 en_193 en_966 en_10 en_199 en_26 en_1365 en_5
Beam(k=3):  en_6 en_193 en_966 en_14 en_26 en_1424 <eos>
--------------------------------------------------
