In [None]:
%%capture
!pip install flask flask-cors transformers sentence-transformers pyngrok python-dotenv

In [None]:
import os
import json
import numpy as np
from datetime import datetime
from typing import List, Dict, Any, Tuple
import pickle
import logging
from flask import Flask, request, jsonify, render_template_string
from flask_cors import CORS
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from pyngrok import ngrok
import threading
import time

In [None]:
# Mount Google Drive so the data MemoryManager writes under /content/drive
# survives Colab runtime restarts.
import google.colab.drive as drive

drive.mount("/content/drive", force_remount=False)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# ngrok auth token. It is read from the NGROK_AUTHTOKEN environment variable,
# or from a Colab secret of the same name, so the real token never has to be
# pasted into (and shared with) the notebook itself. The old placeholder is
# used only as a last resort; ngrok is expected to reject it.
your_ngrok_authtoken = os.environ.get("NGROK_AUTHTOKEN")
if not your_ngrok_authtoken:
  try:
    from google.colab import userdata
    your_ngrok_authtoken = userdata.get("NGROK_AUTHTOKEN")
  except Exception:
    # Best effort: not running in Colab, secret missing, or access denied.
    your_ngrok_authtoken = None
if not your_ngrok_authtoken:
  print("NGROK_AUTHTOKEN が見つかりません。環境変数またはColabのシークレットに設定してください。")
  your_ngrok_authtoken = "your_ngrok_authtoken"

In [None]:
# Send INFO-and-above records (ours and werkzeug's request log) to the cell output.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

In [None]:
class MemoryManager:
  """Conversation memory persisted to a directory (Google Drive by default).

  Each stored conversation is a dict with "id", "user_id", "user_input",
  "assistant_response" and "timestamp". A parallel list holds the embedding
  of each conversation's user_input and is used for cosine-similarity search.
  Conversations are saved to conversations.json and embeddings to
  embeddings.pkl inside ``drive_path``.
  """

  def __init__(self, drive_path="/content/drive/MyDrive/memory_api_data",
               similarity_threshold: float = 0.7, max_memory_items: int = 100,
               model=None):
    """
    Args:
      drive_path: directory holding conversations.json and embeddings.pkl
        (created if missing).
      similarity_threshold: minimum cosine similarity for a stored
        conversation to be returned by find_similar_conversations.
      max_memory_items: number of conversations kept; the oldest are dropped
        beyond this. Must be at least 1.
      model: optional encoder object with an ``encode(text)`` method. When
        None, pfnet/plamo-embedding-1b is loaded with SentenceTransformer.

    Raises:
      ValueError: if max_memory_items is smaller than 1.
    """
    if max_memory_items < 1:
      # conversations[-0:] keeps the whole list, so 0 would silently disable trimming.
      raise ValueError("max_memory_items は1以上である必要があります")

    self.drive_path = drive_path
    self.memory_file = os.path.join(drive_path, "conversations.json")
    self.embeddings_file = os.path.join(drive_path, "embeddings.pkl")

    # Make sure the storage directory exists
    os.makedirs(drive_path, exist_ok=True)

    # Flask's app.run serves requests on multiple threads by default, so all
    # access to the two parallel lists and to the files goes through this lock.
    self._lock = threading.Lock()

    # Similarity threshold and history size
    self.similarity_threshold = similarity_threshold
    self.max_memory_items = max_memory_items

    # Load the embedding model unless one was supplied
    if model is None:
      logger.info("PLaMo-embedding-1bモデルをロードしています...")
      model = SentenceTransformer('pfnet/plamo-embedding-1b', trust_remote_code=True)
      logger.info("モデルのロードが完了しました")
    self.model = model

    # Load persisted memory and make sure both lists line up
    self.conversations = self._load_conversations()
    self.embeddings = self._load_embeddings()
    self._reconcile_embeddings()

    # Ids come from a counter rather than len(conversations): once the history
    # is trimmed to max_memory_items its length stops growing, and every new
    # conversation would otherwise receive the same id.
    existing_ids = [c["id"] for c in self.conversations if isinstance(c.get("id"), int)]
    self._next_id = max(existing_ids, default=-1) + 1

  @staticmethod
  def _atomic_write(path: str, mode: str, write, encoding=None):
    """Write through a temporary file, then rename it over ``path``.

    A crash or I/O error mid-write leaves the previous file intact instead of
    a truncated one that would fail to load (and be treated as empty) on the
    next start.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, mode, encoding=encoding) as f:
      write(f)
    os.replace(tmp_path, path)

  def _load_conversations(self) -> List[Dict]:
    """Load the conversation history; returns [] if missing or unreadable."""
    if not os.path.exists(self.memory_file):
      return []
    try:
      with open(self.memory_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    except Exception as e:
      logger.error(f"会話履歴のロードに失敗しました: {e}")
      return []
    if not isinstance(data, list):
      logger.error("会話履歴ファイルの形式が不正です")
      return []
    # Skip entries that are not records; _reconcile_embeddings realigns the embeddings.
    return [c for c in data if isinstance(c, dict)]

  def _save_conversations(self):
    """Save the conversation history (errors are logged, not raised)."""
    try:
      self._atomic_write(
        self.memory_file, 'w',
        lambda f: json.dump(self.conversations, f, ensure_ascii=False, indent=2),
        encoding='utf-8')
    except Exception as e:
      logger.error(f"会話履歴の保存に失敗しました: {e}")

  def _load_embeddings(self) -> List[np.ndarray]:
    """Load the embedding vectors; returns [] if missing or unreadable."""
    if not os.path.exists(self.embeddings_file):
      return []
    try:
      # NOTE: pickle.load can execute arbitrary code. This is acceptable only
      # because the file is written by _save_embeddings into the user's own
      # storage; never point drive_path at files from an untrusted source.
      with open(self.embeddings_file, 'rb') as f:
        return list(pickle.load(f))
    except Exception as e:
      logger.error(f"埋め込みベクトルのロードに失敗しました: {e}")
      return []

  def _save_embeddings(self):
    """Save the embedding vectors (errors are logged, not raised)."""
    try:
      self._atomic_write(self.embeddings_file, 'wb', lambda f: pickle.dump(self.embeddings, f))
    except Exception as e:
      logger.error(f"埋め込みベクトルの保存に失敗しました: {e}")

  def _reconcile_embeddings(self):
    """Re-embed the history when the embedding list does not match it in length.

    The two files are written separately, so a failed save or a missing or
    corrupt embeddings.pkl can leave lists of different lengths. Searches pair
    them by position and would then attach the wrong embedding to a
    conversation. Only a length mismatch can be detected here.
    """
    if len(self.embeddings) == len(self.conversations):
      return
    logger.warning(
      f"会話履歴({len(self.conversations)}件)と埋め込みベクトル({len(self.embeddings)}件)の件数が一致しないため、埋め込みベクトルを再計算します")
    try:
      self.embeddings = [self.model.encode(c.get("user_input", "")) for c in self.conversations]
    except Exception as e:
      logger.error(f"埋め込みベクトルの再計算に失敗しました: {e}")
      return
    self._save_embeddings()

  def add_conversation(self, user_input: str, assistant_response: str, user_id: str = "default"):
    """
    Store a new conversation and persist the history.

    Args:
      user_input: the user's message (this text is embedded for search)
      assistant_response: the assistant's reply
      user_id: owner of the conversation
    Returns:
      A copy of the stored conversation dict, or None if it could not be
      added (the error is logged, not raised).
    """
    try:
      # Encoding is the slow part and touches no shared state, so it runs
      # before taking the lock.
      embedding = self.model.encode(user_input)

      with self._lock:
        conversation = {
          "id": self._next_id,
          "user_id": user_id,
          "user_input": user_input,
          "assistant_response": assistant_response,
          "timestamp": datetime.now().isoformat(),
        }
        self._next_id += 1

        self.conversations.append(conversation)
        self.embeddings.append(embedding)

        # Drop the oldest entries once the limit is exceeded
        if len(self.conversations) > self.max_memory_items:
          self.conversations = self.conversations[-self.max_memory_items:]
          self.embeddings = self.embeddings[-self.max_memory_items:]

        # Persist to Drive
        self._save_conversations()
        self._save_embeddings()

      logger.info(f"新しい会話履歴を追加しました: {conversation['id']}")
      return dict(conversation)

    except Exception as e:
      logger.error(f"会話履歴の追加に失敗しました: {e}")
      return None

  def find_similar_conversations(self, query: str, user_id: str = "default", top_k: int = 3) -> List[Dict]:
    """
    Search one user's stored conversations for ones similar to a query.

    Args:
      query: text to search for
      user_id: only conversations stored under this user id are considered
      top_k: maximum number of results; values <= 0 yield []
    Returns:
      Copies of the matching conversations with an added "similarity" key,
      highest similarity first, limited to similarity >= similarity_threshold.
      Returns [] on any error (the error is logged).
    """
    try:
      if top_k <= 0:
        return []

      # Snapshot this user's (conversation, embedding) pairs under the lock;
      # zip also keeps the two lists paired position by position.
      with self._lock:
        user_pairs = [
          (conv, emb)
          for conv, emb in zip(self.conversations, self.embeddings)
          if conv.get("user_id", "default") == user_id
        ]
      if not user_pairs:
        return []

      # Embed the query and compare it against the user's history
      query_embedding = np.asarray(self.model.encode(query)).reshape(1, -1)
      user_embeddings_matrix = np.vstack([emb for _, emb in user_pairs])
      similarities = cosine_similarity(query_embedding, user_embeddings_matrix)[0]

      # Keep matches at or above the threshold, best first (stable for ties)
      similar_indices = sorted(
        ((i, sim) for i, sim in enumerate(similarities) if sim >= self.similarity_threshold),
        key=lambda x: x[1],
        reverse=True,
      )

      results = []
      for i, sim in similar_indices[:top_k]:
        conv = dict(user_pairs[i][0])
        conv["similarity"] = float(sim)
        results.append(conv)

      logger.info(f"見つかった類似対話記録: {len(results)}件")
      return results

    except Exception as e:
      logger.error(f"類似対話記録の検索に失敗しました: {e}")
      return []

  def get_conversation_stats(self) -> Dict:
    """Return counts of stored conversations/embeddings and the current settings."""
    with self._lock:
      return {
        "total_conversations": len(self.conversations),
        "total_embeddings": len(self.embeddings),
        "similarity_threshold": self.similarity_threshold,
        "max_memory_items": self.max_memory_items
      }

In [None]:
# Flask application served through ngrok. CORS is enabled for every origin,
# so any web page can call this API once it knows the public URL.
app = Flask(import_name=__name__)
CORS(app=app)

# One shared memory store used by all request handlers; constructing it loads
# the embedding model and any data already persisted under this path.
memory_manager = MemoryManager(drive_path="/content/drive/MyDrive/memory_api_data")

@app.route('/api/health', methods=['GET'])
def health_check():
  """Liveness probe: report that the service is up, with the current server time."""
  payload = dict(
    status="healthy",
    timestamp=datetime.now().isoformat(),
    service="Memory API",
  )
  return jsonify(payload)

@app.route('/api/stats', methods=['GET'])
def get_stats():
  """Return the memory store's counters and configuration as JSON."""
  return jsonify(memory_manager.get_conversation_stats())

@app.route('/api/query', methods=['POST'])
def query_memory():
  """Search stored conversations similar to a query.

  Request JSON:
    query (str, required): text to search for.
    user_id (optional, default "default"): whose conversations to search.
    top_k (int, optional, default 3): maximum number of memories to return;
      must be a positive integer.

  Returns:
    200 with the matching memories, 400 when the body or a parameter is
    missing or invalid, 500 on an unexpected error.
  """
  try:
    # silent=True returns None for a missing, non-JSON or malformed body, so
    # client mistakes get a 400 below instead of surfacing as a 500.
    data = request.get_json(silent=True)

    if not isinstance(data, dict) or 'query' not in data:
      return jsonify({"error": "queryパラメータが不足しています"}), 400

    query = data['query']
    user_id = data.get('user_id', 'default')
    top_k = data.get('top_k', 3)

    if not isinstance(query, str):
      return jsonify({"error": "queryは文字列である必要があります"}), 400
    # bool is a subclass of int in Python, so true/false must be rejected explicitly.
    if isinstance(top_k, bool) or not isinstance(top_k, int) or top_k <= 0:
      return jsonify({"error": "top_kは正の整数である必要があります"}), 400

    # Look up similar conversations
    similar_conversations = memory_manager.find_similar_conversations(
      query=query,
      user_id=user_id,
      top_k=top_k
    )

    response = {
      "query": query,
      "user_id": user_id,
      "found_memories": len(similar_conversations),
      "has_relevant_memory": len(similar_conversations) > 0,
      "memories": similar_conversations,
      "timestamp": datetime.now().isoformat()
    }

    return jsonify(response)

  except Exception as e:
    logger.error(f"記憶の検索に失敗しました: {e}")
    return jsonify({"error": str(e)}), 500

@app.route('/api/add', methods=['POST'])
def add_memory():
  """Store one user/assistant exchange.

  Request JSON:
    user_input (str, required): the user's message.
    assistant_response (str, required): the assistant's reply.
    user_id (optional, default "default"): owner of the conversation.

  Returns:
    200 on success, 400 when the body or a parameter is missing or invalid,
    500 on an unexpected error.

  NOTE(review): MemoryManager.add_conversation logs its own failures instead
  of raising, so a 200 here does not by itself prove the record was stored.
  """
  try:
    # silent=True returns None for a missing, non-JSON or malformed body, so
    # client mistakes get a 400 below instead of surfacing as a 500.
    data = request.get_json(silent=True)

    if not isinstance(data, dict) or 'user_input' not in data or 'assistant_response' not in data:
      return jsonify({"error": "必要なパラメータが不足しています"}), 400

    user_input = data['user_input']
    assistant_response = data['assistant_response']
    user_id = data.get('user_id', 'default')

    # Reject non-text values here; they would otherwise fail inside the
    # encoder and still be reported as a success.
    if not isinstance(user_input, str) or not isinstance(assistant_response, str):
      return jsonify({"error": "user_inputとassistant_responseは文字列である必要があります"}), 400

    # Store the conversation
    memory_manager.add_conversation(
      user_input=user_input,
      assistant_response=assistant_response,
      user_id=user_id
    )

    response = {
      "status": "success",
      "message": "会話履歴が追加されました",
      "user_id": user_id,
      "timestamp": datetime.now().isoformat()
    }

    return jsonify(response)

  except Exception as e:
    logger.error(f"記憶の追加に失敗しました: {e}")
    return jsonify({"error": str(e)}), 500

def run_flask_app(host: str = '0.0.0.0', port: int = 5000):
  """Run the Flask development server (blocks; the main cell starts it in a thread).

  Args:
    host: interface to bind. '0.0.0.0' listens on all interfaces; since the
      ngrok tunnel forwards to localhost, '127.0.0.1' is enough when only
      ngrok needs to reach the server.
    port: TCP port; must match the port given to ngrok.connect.
  """
  # debug=False also leaves Flask's auto-reloader off (use_reloader defaults
  # to the debug flag).
  app.run(host=host, port=port, debug=False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
if __name__ == "__main__":
  ngrok.set_auth_token(your_ngrok_authtoken)

  # Run Flask in a daemon thread so this cell stays free to manage the tunnel
  flask_thread = threading.Thread(target=run_flask_app, daemon=True)
  flask_thread.start()

  # Wait for Flask to start listening
  time.sleep(5)

  # Open the ngrok tunnel
  try:
    tunnel = ngrok.connect(5000)
    # ngrok.connect returns an NgrokTunnel object whose str() is
    # 'NgrokTunnel: "https://..." -> "http://localhost:5000"'; the endpoint
    # URLs must be built from the plain URL string in .public_url.
    public_url = tunnel.public_url
    print(f"API_URL: {public_url}")
    print(f"クエリAPI: {public_url}/api/query")
    print(f"追加API: {public_url}/api/add")
    print(f"統計API: {public_url}/api/stats")
    print(f"health: {public_url}/api/health")

    # Keep the service running until the cell is interrupted
    print("\nサービスが起動しました")
    try:
      while True:
        time.sleep(1)
    except KeyboardInterrupt:
      print("\nサービスを停止中...")
      # ngrok.disconnect identifies the tunnel by its public URL string
      ngrok.disconnect(public_url)

  except Exception as e:
    print(f"ngrokの起動に失敗しました: {e}")

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://172.28.0.12:5000
INFO:werkzeug:[33mPress CTRL+C to quit[0m


API_URL: NgrokTunnel: "https://a729-34-105-105-239.ngrok-free.app" -> "http://localhost:5000"
クエリAPI: NgrokTunnel: "https://a729-34-105-105-239.ngrok-free.app" -> "http://localhost:5000"/api/query
追加API: NgrokTunnel: "https://a729-34-105-105-239.ngrok-free.app" -> "http://localhost:5000"/api/add
統計API: NgrokTunnel: "https://a729-34-105-105-239.ngrok-free.app" -> "http://localhost:5000"/api/stats
health: NgrokTunnel: "https://a729-34-105-105-239.ngrok-free.app" -> "http://localhost:5000"/api/health

サービスが起動しました


INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 19:31:13] "POST /api/add HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 19:31:13] "POST /api/add HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 19:31:13] "POST /api/add HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 19:31:18] "POST /api/query HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 19:31:19] "POST /api/query HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 19:31:19] "POST /api/query HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 19:31:23] "GET /api/stats HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 19:31:26] "GET /api/health HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 19:32:01] "POST /api/query HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 19:32:03] "POST /api/add HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 19:36:44] "POST /api/add HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 19:36:44] "POST /api/add HTTP/1.1" 200 -
INFO: