In [None]:
# Install the required libraries (uncomment on first run)
# !pip install flask-ngrok transformers torch nltk matplotlib
# !pip install flask-cors
# !pip install git+https://github.com/huggingface/accelerate.git
# !pip install --upgrade huggingface_hub

In [None]:
import os
import torch
import socket
from flask import Flask, render_template_string, request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import nltk
from nltk.translate.bleu_score import sentence_bleu
import matplotlib.pyplot as plt
import io
import base64
from google.colab import output
from google.colab.output import eval_js
from threading import Thread
from flask_cors import CORS

# Tokenizer data used by nltk.word_tokenize. Recent NLTK releases look up the
# 'punkt_tab' resource instead of the older pickled 'punkt' models, so fetch
# both to work across NLTK versions.
nltk.download('punkt')
nltk.download('punkt_tab')

# SECURITY: never commit a Hugging Face access token to source. The token is
# read from the environment instead (in Colab, set it from the Secrets panel
# or with os.environ before running this cell). Any token that was previously
# embedded in this file must be considered leaked and should be revoked.
hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
if not hf_token:
    print(
        "WARNING: no Hugging Face token found in HUGGING_FACE_HUB_TOKEN / HF_TOKEN; "
        "loading the gated Gemma model will fail unless you are already logged in."
    )

# Initialize the Flask app and allow cross-origin requests
app = Flask(__name__)
CORS(app)

# Model identifiers on the Hugging Face Hub
phi_model_name = "microsoft/Phi-3-mini-4k-instruct"
gemma_model_name = "google/gemma-2b"

# Load the Phi-3 tokenizer and model
phi_tokenizer = AutoTokenizer.from_pretrained(phi_model_name, padding_side="left")
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, torch_dtype=torch.float32)

# Load the Gemma model and tokenizer. `token` replaces the deprecated
# `use_auth_token` argument; passing None falls back to any cached login.
gemma_model = AutoModelForCausalLM.from_pretrained(gemma_model_name, torch_dtype=torch.float32, token=hf_token)
gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_model_name, padding_side="left", token=hf_token)

# Reuse the end-of-sequence token as the padding token for both tokenizers
phi_tokenizer.pad_token = phi_tokenizer.eos_token
gemma_tokenizer.pad_token = gemma_tokenizer.eos_token

print("All models loaded successfully")

# Text generation helper
def generate_response(prompt, model_name, max_length=100):
    """Generate text for *prompt* with the selected model and time the call.

    Args:
        prompt: Input text passed to the tokenizer.
        model_name: ``"Phi-3"`` or ``"Gemma"``; selects the module-level
            model/tokenizer pair.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
            NOTE(review): in Hugging Face ``generate`` this limit counts the
            prompt tokens as well as the newly generated ones - confirm this
            is the intended budget.

    Returns:
        A tuple ``(generated_text, generation_time, tokens_per_second,
        tokens_generated)``. ``generated_text`` is the decoded full output
        sequence, ``generation_time`` is in seconds, and ``tokens_generated``
        is the output length minus the input length.

    Raises:
        ValueError: If *model_name* is not one of the supported names.
    """
    if model_name == "Phi-3":
        model = phi_model
        tokenizer = phi_tokenizer
    elif model_name == "Gemma":
        model = gemma_model
        tokenizer = gemma_tokenizer
    else:
        raise ValueError("Invalid model name")

    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    # Keep the input tensors on the same device as the model so the call also
    # works if a model is ever moved off the CPU.
    inputs = inputs.to(model.device)

    # perf_counter is monotonic, so the measured duration cannot go negative
    # when the wall clock is adjusted.
    start_time = time.perf_counter()

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7
        )

    generation_time = time.perf_counter() - start_time
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    tokens_generated = len(outputs[0]) - len(inputs['input_ids'][0])
    # Guard against a zero elapsed time instead of raising ZeroDivisionError.
    tokens_per_second = tokens_generated / generation_time if generation_time > 0 else 0.0

    return generated_text, generation_time, tokens_per_second, tokens_generated

print("Response generation function defined")

# BLEU score helper
def calculate_bleu(reference, candidate):
    """Return the sentence-level BLEU score of *candidate* against *reference*.

    Both strings are split with ``nltk.word_tokenize``; the tokenized
    reference is passed to ``sentence_bleu`` as the only entry of the
    reference list.
    """
    reference_list = [nltk.word_tokenize(reference)]
    hypothesis = nltk.word_tokenize(candidate)
    score = sentence_bleu(reference_list, hypothesis)
    return score

# Batch processing helper
def batch_process(prompts, references, model_name):
    """Run every prompt through the chosen model and score it against its reference.

    Args:
        prompts: List of prompt strings.
        references: List of reference strings aligned with *prompts* by index.
            May be shorter than *prompts* (or empty/None); missing entries are
            treated as an empty reference so that no prompt is dropped.
            Entries beyond ``len(prompts)`` are ignored.
        model_name: Model selector forwarded to ``generate_response``.

    Returns:
        A list with one dict per prompt containing the keys ``prompt``,
        ``response``, ``reference``, ``bleu_score``, ``generation_time``,
        ``tokens_per_second`` and ``tokens_generated``. ``bleu_score`` is
        ``0.0`` when the reference is blank.
    """
    # A plain zip() would stop at the shorter list and silently skip prompts
    # when fewer references than prompts are supplied, so pad the references.
    aligned_refs = list(references or [])[:len(prompts)]
    aligned_refs += [""] * (len(prompts) - len(aligned_refs))

    results = []
    for prompt, reference in zip(prompts, aligned_refs):
        reference = reference or ""
        response, gen_time, tokens_per_sec, tokens_gen = generate_response(prompt, model_name)
        # There is nothing to compare against for a blank reference; keep the
        # score numeric so downstream plotting/averaging still works.
        bleu_score = calculate_bleu(reference, response) if reference.strip() else 0.0
        results.append({
            "prompt": prompt,
            "response": response,
            "reference": reference,
            "bleu_score": bleu_score,
            "generation_time": gen_time,
            "tokens_per_second": tokens_per_sec,
            "tokens_generated": tokens_gen
        })
    return results

# Result visualization helper
def plot_results(results):
    """Render per-prompt metrics as three stacked bar charts.

    Args:
        results: List of result dicts providing the keys ``bleu_score``,
            ``generation_time`` and ``tokens_per_second``.

    Returns:
        The PNG image of the figure, base64-encoded as a ``str``.
    """
    # (result key, panel title, y-axis label) for each of the three panels.
    panels = (
        ("bleu_score", "BLEU Scores", "BLEU Score"),
        ("generation_time", "Generation Times", "Time (seconds)"),
        ("tokens_per_second", "Tokens per Second", "Tokens/Second"),
    )

    fig, axes = plt.subplots(3, 1, figsize=(10, 15))
    try:
        for ax, (key, title, ylabel) in zip(axes, panels):
            values = [r[key] for r in results]
            ax.bar(range(len(values)), values)
            ax.set_title(title)
            ax.set_xlabel("Prompt Index")
            ax.set_ylabel(ylabel)

        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        return base64.b64encode(buf.getvalue()).decode()
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each request to the server would leak one figure.
        plt.close(fig)

# Front-end JavaScript, kept in its own string and injected into the HTML
# template. Every server- or user-supplied value is passed through
# escapeHtml() before being written with innerHTML, so model output that
# contains markup is shown as text instead of being interpreted as HTML.
javascript_code = """
let singlePerformanceChart;
let batchPerformanceChart;

document.addEventListener('DOMContentLoaded', function() {
    console.log('DOM fully loaded and parsed');

    // 탭 기능 구현
    const tabs = document.querySelectorAll('.tablinks');
    tabs.forEach(tab => {
        tab.addEventListener('click', function(event) {
            const tabName = this.getAttribute('data-tab');
            console.log('Tab clicked:', tabName);
            openTab(event, tabName);
        });
    });

    // 기본 탭 열기
    document.getElementById("defaultOpen").click();

    // 버튼에 이벤트 리스너 추가
    document.getElementById('generate-btn').addEventListener('click', generateComparison);
    document.getElementById('batch-btn').addEventListener('click', runBatchTest);
});

// Escape a value for safe insertion into element content via innerHTML.
function escapeHtml(value) {
    const holder = document.createElement('div');
    holder.textContent = (value === null || value === undefined) ? '' : String(value);
    return holder.innerHTML;
}

function openTab(evt, tabName) {
    console.log('Opening tab:', tabName);
    var i, tabcontent, tablinks;
    tabcontent = document.getElementsByClassName("tabcontent");
    for (i = 0; i < tabcontent.length; i++) {
        tabcontent[i].style.display = "none";
    }
    tablinks = document.getElementsByClassName("tablinks");
    for (i = 0; i < tablinks.length; i++) {
        tablinks[i].className = tablinks[i].className.replace(" active", "");
    }
    document.getElementById(tabName).style.display = "block";
    evt.currentTarget.className += " active";
}

async function generateComparison() {
    const question = document.getElementById('single-question').value;
    const phi3ResultDiv = document.getElementById('phi3-result');
    const gemmaResultDiv = document.getElementById('gemma-result');
    const loadingDiv = document.getElementById('single-loading');
    const generateBtn = document.getElementById('generate-btn');

    if (!question.trim()) {
        phi3ResultDiv.innerHTML = '<p class="error">질문을 입력해주세요.</p>';
        gemmaResultDiv.innerHTML = '<p class="error">질문을 입력해주세요.</p>';
        return;
    }

    generateBtn.disabled = true;
    loadingDiv.style.display = 'block';
    phi3ResultDiv.innerHTML = '';
    gemmaResultDiv.innerHTML = '';

    try {
        const [phi3Result, gemmaResult] = await Promise.all([
            axios.post('/generate', { model: 'Phi-3', prompt: question }),
            axios.post('/generate', { model: 'Gemma', prompt: question })
        ]);

        phi3ResultDiv.innerHTML = formatResponse(phi3Result.data);
        gemmaResultDiv.innerHTML = formatResponse(gemmaResult.data);

        updateSinglePerformanceChart(phi3Result.data, gemmaResult.data);
    } catch (error) {
        console.error('Error:', error);
        phi3ResultDiv.innerHTML = `<p class="error">오류 발생: ${escapeHtml(error.message)}</p>`;
        gemmaResultDiv.innerHTML = `<p class="error">오류 발생: ${escapeHtml(error.message)}</p>`;
    } finally {
        generateBtn.disabled = false;
        loadingDiv.style.display = 'none';
    }
}

function formatResponse(data) {
    return `
        <p>${escapeHtml(data.response)}</p>
        <p>생성 시간: ${data.generation_time.toFixed(2)} 초</p>
        <p>토큰 생성 속도: ${data.tokens_per_second.toFixed(2)} 토큰/초</p>
        <p>생성된 토큰 수: ${escapeHtml(data.tokens_generated)}</p>
    `;
}

function updateSinglePerformanceChart(phi3Data, gemmaData) {
    const ctx = document.getElementById('single-performance-chart').getContext('2d');

    if (singlePerformanceChart) {
        singlePerformanceChart.destroy();
    }

    singlePerformanceChart = new Chart(ctx, {
        type: 'bar',
        data: {
            labels: ['생성 시간 (초)', '토큰 생성 속도 (토큰/초)', '생성된 토큰 수'],
            datasets: [
                {
                    label: 'Phi-3',
                    data: [phi3Data.generation_time, phi3Data.tokens_per_second, phi3Data.tokens_generated],
                    backgroundColor: 'rgba(75, 192, 192, 0.6)',
                    borderColor: 'rgba(75, 192, 192, 1)',
                    borderWidth: 1
                },
                {
                    label: 'Gemma',
                    data: [gemmaData.generation_time, gemmaData.tokens_per_second, gemmaData.tokens_generated],
                    backgroundColor: 'rgba(255, 99, 132, 0.6)',
                    borderColor: 'rgba(255, 99, 132, 1)',
                    borderWidth: 1
                }
            ]
        },
        options: {
            responsive: true,
            scales: {
                y: {
                    beginAtZero: true
                }
            }
        }
    });
}

async function runBatchTest() {
    const prompts = document.getElementById('batch-prompts').value.split('\\n');
    const references = document.getElementById('batch-references').value.split('\\n');
    const resultDiv = document.getElementById('batch-result');
    const loadingDiv = document.getElementById('batch-loading');
    const batchBtn = document.getElementById('batch-btn');

    if (!prompts[0].trim()) {
        resultDiv.innerHTML = '<p class="error">프롬프트를 입력해주세요.</p>';
        return;
    }

    batchBtn.disabled = true;
    loadingDiv.style.display = 'block';
    resultDiv.innerHTML = '';

    try {
        const [phi3Result, gemmaResult] = await Promise.all([
            axios.post('/batch', { model: 'Phi-3', prompts, references }),
            axios.post('/batch', { model: 'Gemma', prompts, references })
        ]);

        let batchHtml = '<h2>배치 테스트 결과</h2>';
        batchHtml += '<h3>Phi-3 결과</h3>';
        batchHtml += formatBatchResult(phi3Result.data);
        batchHtml += '<h3>Gemma 결과</h3>';
        batchHtml += formatBatchResult(gemmaResult.data);

        resultDiv.innerHTML = batchHtml;

        updateBatchPerformanceChart(phi3Result.data, gemmaResult.data);
    } catch (error) {
        console.error('Error:', error);
        resultDiv.innerHTML = `<p class="error">오류 발생: ${escapeHtml(error.message)}</p>`;
    } finally {
        batchBtn.disabled = false;
        loadingDiv.style.display = 'none';
    }
}

function formatBatchResult(data) {
    let html = '';
    data.results.forEach((item, index) => {
        html += `
            <div class="batch-item">
                <h4>프롬프트 ${index + 1}:</h4>
                <p><strong>입력:</strong> ${escapeHtml(item.prompt)}</p>
                <p><strong>생성된 응답:</strong> ${escapeHtml(item.response)}</p>
                ${item.reference ? `<p><strong>참조 문장:</strong> ${escapeHtml(item.reference)}</p>` : ''}
                ${item.bleu_score ? `<p><strong>BLEU 점수:</strong> ${item.bleu_score.toFixed(4)}</p>` : ''}
                <p><strong>생성 시간:</strong> ${item.generation_time.toFixed(2)} 초</p>
                <p><strong>토큰 생성 속도:</strong> ${item.tokens_per_second.toFixed(2)} 토큰/초</p>
                <p><strong>생성된 토큰 수:</strong> ${escapeHtml(item.tokens_generated)}</p>
            </div>
        `;
    });
    return html;
}

function updateBatchPerformanceChart(phi3Data, gemmaData) {
    const ctx = document.getElementById('batch-performance-chart').getContext('2d');

    if (batchPerformanceChart) {
        batchPerformanceChart.destroy();
    }

    const phi3Avg = calculateAverages(phi3Data.results);
    const gemmaAvg = calculateAverages(gemmaData.results);

    batchPerformanceChart = new Chart(ctx, {
        type: 'bar',
        data: {
            labels: ['평균 생성 시간 (초)', '평균 토큰 생성 속도 (토큰/초)', '평균 생성된 토큰 수', '평균 BLEU 점수'],
            datasets: [
                {
                    label: 'Phi-3',
                    data: [phi3Avg.avgTime, phi3Avg.avgSpeed, phi3Avg.avgTokens, phi3Avg.avgBleu],
                    backgroundColor: 'rgba(75, 192, 192, 0.6)',
                    borderColor: 'rgba(75, 192, 192, 1)',
                    borderWidth: 1
                },
                {
                    label: 'Gemma',
                    data: [gemmaAvg.avgTime, gemmaAvg.avgSpeed, gemmaAvg.avgTokens, gemmaAvg.avgBleu],
                    backgroundColor: 'rgba(255, 99, 132, 0.6)',
                    borderColor: 'rgba(255, 99, 132, 1)',
                    borderWidth: 1
                }
            ]
        },
        options: {
            responsive: true,
            scales: {
                y: {
                    beginAtZero: true
                }
            }
        }
    });
}

function calculateAverages(results) {
    const count = results.length;
    if (count === 0) {
        // No results: report zeros instead of the NaN produced by 0 / 0.
        return { avgTime: 0, avgSpeed: 0, avgTokens: 0, avgBleu: 0 };
    }

    const sum = results.reduce((acc, curr) => {
        return {
            time: acc.time + curr.generation_time,
            speed: acc.speed + curr.tokens_per_second,
            tokens: acc.tokens + curr.tokens_generated,
            bleu: acc.bleu + (curr.bleu_score || 0)
        };
    }, { time: 0, speed: 0, tokens: 0, bleu: 0 });

    return {
        avgTime: sum.time / count,
        avgSpeed: sum.speed / count,
        avgTokens: sum.tokens / count,
        avgBleu: sum.bleu / count
    };
}
"""

# HTML page template. Rendered by the index route, which supplies the
# `javascript_code` variable that is inserted into the trailing <script> tag.
html_template = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>SLLM Model Comparison</title>
    <script src="https://cdn.jsdelivr.net/npm/axios/dist/axios.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
    <style>
                body {
            font-family: Arial, sans-serif;
            line-height: 1.6;
            color: #333;
            max-width: 1200px;
            margin: 0 auto;
            padding: 20px;
            background-color: #f5f5f5;
        }
        h1, h2, h3 {
            color: #10a37f;
        }
        textarea, button, select {
            width: 100%;
            padding: 10px;
            margin-bottom: 10px;
            border: 1px solid #ddd;
            border-radius: 4px;
        }
        button {
            background-color: #10a37f;
            color: white;
            border: none;
            cursor: pointer;
            transition: background-color 0.3s;
        }
        button:hover {
            background-color: #0c8f6e;
        }
        button:disabled {
            background-color: #ccc;
            cursor: not-allowed;
        }
        .result {
            background-color: white;
            border: 1px solid #ddd;
            border-radius: 4px;
            padding: 15px;
            margin-top: 20px;
        }
        .loading {
            display: none;
            text-align: center;
            margin-top: 20px;
        }
        .error {
            color: red;
            margin-top: 10px;
        }
        .model-comparison {
            display: flex;
            justify-content: space-between;
        }
        .model-result {
            width: 48%;
        }
        #performance-metrics {
            margin-top: 20px;
        }
        .tab {
            overflow: hidden;
            border: 1px solid #ccc;
            background-color: #f1f1f1;
        }
        .tab button {
            background-color: inherit;
            float: left;
            border: none;
            outline: none;
            cursor: pointer;
            padding: 14px 16px;
            transition: 0.3s;
        }
        .tab button:hover {
            background-color: #ddd;
        }
        .tab button.active {
            background-color: #ccc;
        }
        .tabcontent {
            display: none;
            padding: 6px 12px;
            border: 1px solid #ccc;
            border-top: none;
        }
    </style>
</head>
<body>
    <h1>SLLM Model Test</h1>

    <div class="tab">
       <button class="tablinks" data-tab="SingleQuestion" id="defaultOpen">단일 질문</button>
       <button class="tablinks" data-tab="BatchTest">배치 테스트</button>
    </div>

    <div id="SingleQuestion" class="tabcontent">
        <h2>QUESTION</h2>
        <textarea id="single-question" rows="4" placeholder="질문을 입력하세요..."></textarea>
        <button id="generate-btn">생성 및 비교</button>

        <div id="single-loading" class="loading">생성 중...</div>

        <div class="model-comparison">
            <div class="model-result">
                <h2>Phi-3 RESPONSE</h2>
                <div id="phi3-result" class="result"></div>
            </div>
            <div class="model-result">
                <h2>Gemma RESPONSE</h2>
                <div id="gemma-result" class="result"></div>
            </div>
        </div>

        <div id="single-performance-metrics">
            <h2>성능 비교</h2>
            <canvas id="single-performance-chart"></canvas>
        </div>
    </div>

    <div id="BatchTest" class="tabcontent">
        <h2>배치 테스트</h2>
        <h3>프롬프트</h3>
        <textarea id="batch-prompts" rows="4" placeholder="여러 줄의 프롬프트를 입력하세요..."></textarea>
        <h3>참조 문장 (선택사항)</h3>
        <textarea id="batch-references" rows="4" placeholder="프롬프트에 대응하는 참조 문장을 입력하세요... (선택사항)"></textarea>
        <button id="batch-btn">배치 테스트 실행</button>
        <div id="batch-loading" class="loading">배치 테스트 실행 중...</div>
        <div id="batch-result" class="result"></div>
        <div id="batch-performance-metrics">
            <h2>배치 테스트 성능 비교</h2>
            <canvas id="batch-performance-chart"></canvas>
        </div>
    </div>

    <script>
        {{ javascript_code | safe }}
    </script>
</body>
</html>
"""

# Flask routes
@app.route('/')
def index():
    """Serve the single-page UI with the front-end script injected."""
    page = render_template_string(
        html_template,
        javascript_code=javascript_code,
    )
    return page

@app.route('/generate', methods=['POST'])
def generate():
    """Generate one response for a prompt with the requested model.

    Expects a JSON object with the keys ``model`` and ``prompt``.

    Returns:
        JSON with ``response``, ``generation_time``, ``tokens_per_second`` and
        ``tokens_generated`` on success, or ``{"error": ...}`` with HTTP 400
        when the body is not a JSON object, the prompt is missing/blank, or
        ``generate_response`` rejects the request with ``ValueError`` (for
        example an unknown model name).
    """
    # silent=True yields None instead of aborting on a missing/invalid body,
    # so every malformed request gets the same JSON error shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    model_choice = data.get('model')
    prompt = data.get('prompt')
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({'error': "'prompt' must be a non-empty string"}), 400

    try:
        response, gen_time, tokens_per_sec, tokens_gen = generate_response(prompt, model_choice)
    except ValueError as exc:
        return jsonify({'error': str(exc)}), 400

    return jsonify({
        'response': response,
        'generation_time': gen_time,
        'tokens_per_second': tokens_per_sec,
        'tokens_generated': tokens_gen
    })

@app.route('/batch', methods=['POST'])
def batch():
    """Run a batch of prompts through the requested model.

    Expects a JSON object with ``model``, ``prompts`` (non-empty list of
    strings) and, optionally, ``references`` (list of strings; a missing or
    null value is treated as an empty list).

    Returns:
        JSON with ``results`` (the list produced by ``batch_process``) and
        ``plot`` (base64 PNG from ``plot_results``) on success, or
        ``{"error": ...}`` with HTTP 400 when the payload is malformed or
        ``batch_process`` raises ``ValueError`` (for example an unknown
        model name).
    """
    # silent=True yields None instead of aborting on a missing/invalid body,
    # so every malformed request gets the same JSON error shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    model_choice = data.get('model')
    prompts = data.get('prompts')
    references = data.get('references')
    if references is None:
        references = []

    if (not isinstance(prompts, list) or not prompts
            or not all(isinstance(p, str) for p in prompts)):
        return jsonify({'error': "'prompts' must be a non-empty list of strings"}), 400
    if (not isinstance(references, list)
            or not all(isinstance(r, str) for r in references)):
        return jsonify({'error': "'references' must be a list of strings"}), 400

    try:
        results = batch_process(prompts, references, model_choice)
    except ValueError as exc:
        return jsonify({'error': str(exc)}), 400

    plot_data = plot_results(results)
    return jsonify({
        'results': results,
        'plot': plot_data
    })

# Running the Flask app inside Colab
# Helpers for launching the Flask app
def is_port_in_use(port):
    """Return True when a TCP connection to ``localhost:port`` succeeds."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        # connect_ex reports failure through its return code instead of raising.
        status = probe.connect_ex(('localhost', port))
    return status == 0

def find_free_port(start_port=5000):
    """Return the first port at or above *start_port* that nothing is listening on.

    Ports are probed one by one, in increasing order, with ``is_port_in_use``.
    """
    candidate = start_port
    while True:
        if not is_port_in_use(candidate):
            return candidate
        candidate += 1

def run_flask(port):
    """Serve the Flask app on all interfaces at *port*.

    An ``OSError`` raised while starting the server is reported on stdout
    rather than propagated; the function returns ``None`` in every case.
    """
    try:
        app.run(host='0.0.0.0', port=port)
    except OSError as exc:
        print(f"Error starting Flask app: {exc}")
    return None

def generate_colab_proxy_url(port):
    """Ask the Colab front end for the proxy URL that exposes *port*.

    The JavaScript snippet calls ``google.colab.kernel.proxyPort`` and
    evaluates to ``null`` if that call throws; the value produced by
    ``output.eval_js`` is returned unchanged.
    """
    script = f"""
    async function getProxyUrl() {{
      try {{
        const proxyUrl = await google.colab.kernel.proxyPort({port});
        console.log('Proxy URL:', proxyUrl);
        return proxyUrl;
      }} catch (error) {{
        console.error('Error generating proxy URL:', error);
        return null;
      }}
    }}
    getProxyUrl();
    """
    proxy_url = output.eval_js(script)
    return proxy_url

def run_flask_with_colab(startup_timeout=15.0):
    """Start the Flask app in a background thread and print its Colab proxy URL.

    Picks a free port, launches ``run_flask`` on it, waits until the server
    accepts connections, prints the proxy URL, then blocks until interrupted
    with Ctrl+C / "interrupt execution".

    Args:
        startup_timeout: Maximum number of seconds to wait for the server to
            start listening before giving up.
    """
    port = find_free_port()
    print(f"Found free port: {port}")

    # Daemon thread: a non-daemon server thread would keep the interpreter
    # alive after "Shutting down..." because app.run() never returns.
    flask_thread = Thread(target=run_flask, args=(port,), daemon=True)
    flask_thread.start()

    # Wait until the server is actually listening instead of sleeping for a
    # fixed time, and stop early if the server thread has already died.
    deadline = time.time() + startup_timeout
    while not is_port_in_use(port):
        if not flask_thread.is_alive() or time.time() >= deadline:
            print(f"Flask app did not start listening on port {port}.")
            return
        time.sleep(0.2)

    proxy_url = generate_colab_proxy_url(port)
    if proxy_url:
        print(f"Your app is available at: {proxy_url}")
    else:
        print("Failed to generate proxy URL. Please check your Colab settings.")

    # Keep the calling cell/process alive while the server thread handles requests.
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("Shutting down...")

if __name__ == '__main__':
    run_flask_with_colab()