# Vireora Inference Notebook
推論専用ノートブック（Render から papermill で呼び出し）
- 入力: 動画ファイル、VitPose設定、ユーザー情報
- 出力: JSON結果

## 1. パラメータ設定・Drive マウント・ライブラリ読み込み

In [None]:
# papermill parameters cell: default values, overridden by papermill at run time.
# NOTE(review): keep these as plain literal assignments so papermill can
# detect and override them.
reference_video_path = '/tmp/reference.mp4'  # reference video the user's clip is compared against
comparison_video_path = '/tmp/comparison.mp4'  # video whose poses are scored and checked for deficiencies
use_vitpose = True  # forwarded to the pose analyzer
model_variant = 'vitpose-b'  # pose model variant, forwarded to the pose analyzer
use_3d = False  # forwarded to compare_pose_sequences and echoed in the result JSON
username = ''  # echoed in the result JSON
registered_ratios_json = '{}'  # JSON object: limb segment name -> registered length/torso ratio
use_bigru = False  # enable BiGRU keypoint completion
bigru_model_path = '/content/drive/MyDrive/vireora/bigru_model.pth'  # BiGRU checkpoint path on Google Drive

In [None]:
# Mount Google Drive when running inside Colab. Outside Colab the
# google.colab import fails and mounting is skipped with a notice.
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
except Exception as e:
    print(f"⚠ Drive mount skipped (not Colab or already mounted): {e}")
else:
    print("✓ Google Drive mounted")

In [None]:
# Shared imports used by the cells below.
import json
import sys
import traceback
import numpy as np
import cv2  # NOTE(review): cv2 is not referenced in any visible cell — confirm it is still needed
from datetime import datetime

print("✓ Libraries imported")

In [None]:
# Load the BiGRU model definition (only when use_bigru is enabled).
# Any failure (torch not installed, Drive not mounted, bigru_model.py missing)
# is reported and disables BiGRU completion for the rest of the run.
try:
    if use_bigru:
        import sys
        import torch

        # bigru_model.py is imported from this Drive folder.
        bigru_module_dir = '/content/drive/MyDrive/vireora'
        # Put the folder first on sys.path, but do not prepend another
        # duplicate entry each time the cell is re-run.
        if not sys.path or sys.path[0] != bigru_module_dir:
            sys.path.insert(0, bigru_module_dir)
        from bigru_model import BiGRUCompletionModel, apply_bigru_completion

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"✓ BiGRU モデル定義読み込み完了（device={device}）")
    else:
        print("ℹ BigRU は無効化（use_bigru=False）")
except Exception as e:
    print(f"⚠ BigRU 読み込みスキップ: {e}")
    # Later cells check use_bigru before touching the BiGRU names.
    use_bigru = False

## 2. 体型比率・欠損検知用のユーティリティ関数を定義

In [None]:
# Helpers that would normally come from the pose_analyzer / deficiency modules
# (local copy or GitHub). Here they are defined inline as a stand-in.
import os

# Simplified version: the full setup would fetch these modules from the
# GitHub repository (e.g. with requests, or via git clone).

# Limb segments as (start_keypoint_index, end_keypoint_index) pairs.
# NOTE(review): indices look like the COCO-17 layout (5/6 shoulders, 7/8 elbows,
# 9/10 wrists, 11/12 hips, 13/14 knees, 15/16 ankles) — confirm against the
# pose estimator actually used.
LIMB_SEGMENTS = {
    'left_upper_arm': (5, 7),
    'left_forearm': (7, 9),
    'right_upper_arm': (6, 8),
    'right_forearm': (8, 10),
    'left_thigh': (11, 13),
    'left_shin': (13, 15),
    'right_thigh': (12, 14),
    'right_shin': (14, 16),
}


def calculate_body_ratios(pose, min_confidence=0.3, min_torso_height=1.0):
    """Return each limb segment's length divided by the torso height.

    The torso height is the distance between the shoulder midpoint
    (keypoints 5, 6) and the hip midpoint (keypoints 11, 12).

    Args:
        pose: mapping with a ``'keypoints'`` entry. Each keypoint must be
            indexable as ``(x, y, confidence)``. A list or a NumPy array works.
        min_confidence: minimum confidence for a keypoint to count as detected.
        min_torso_height: torso heights below this value are rejected as
            degenerate. It is in the same units as the keypoint coordinates.
            NOTE(review): if keypoints are normalised to [0, 1], the default
            1.0 rejects every frame — pass a smaller value in that case.

    Returns:
        A dict mapping every ``LIMB_SEGMENTS`` name to a plain ``float`` ratio,
        or to ``None`` when either endpoint is undetected or missing from a
        truncated keypoint list. Returns ``None`` when the pose or its keypoints
        are absent, any torso keypoint is undetected, or the torso is degenerate.
    """
    if pose is None or 'keypoints' not in pose:
        return None
    kp = pose['keypoints']

    def confident(idx):
        # Indices past the end of a short keypoint list count as undetected
        # instead of raising IndexError.
        return idx < len(kp) and kp[idx][2] >= min_confidence

    def xy(idx):
        # float64 so the ratios are never np.float32, which json cannot serialise.
        return np.asarray(kp[idx][:2], dtype=float)

    if not all(confident(i) for i in (5, 6, 11, 12)):
        return None

    shoulder_center = (xy(5) + xy(6)) / 2
    hip_center = (xy(11) + xy(12)) / 2
    torso_height = float(np.linalg.norm(shoulder_center - hip_center))
    if torso_height < min_torso_height:
        return None

    ratios = {}
    for seg_name, (i, j) in LIMB_SEGMENTS.items():
        if confident(i) and confident(j):
            ratios[seg_name] = float(np.linalg.norm(xy(i) - xy(j))) / torso_height
        else:
            ratios[seg_name] = None
    return ratios

def detect_deficiency(current_ratios, registered_ratios, threshold=0.5):
    """Compare measured limb ratios against the user's registered ratios.

    Args:
        current_ratios: mapping of segment name -> measured ratio, or None when
            the segment was not detected.
        registered_ratios: mapping of segment name -> registered ratio. Entries
            whose value is None are ignored.
        threshold: fractional shortfall that counts as a deficiency. A segment
            is flagged when ``current < registered * (1 - threshold)``.

    Returns:
        A list with one dict per flagged segment, in ``registered_ratios``
        iteration order. ``reason`` is ``'not_detected'`` when there is no
        current ratio. Otherwise it is ``'ratio_deviation'``, and the dict also
        carries ``deviation``: the shortfall as a percentage of the registered
        ratio. The list is empty when either input is empty or None.
    """
    if not (current_ratios and registered_ratios):
        return []

    # Japanese display names; unknown segment names fall back to the raw key.
    labels_ja = {
        'left_upper_arm': '左上腕', 'left_forearm': '左前腕',
        'right_upper_arm': '右上腕', 'right_forearm': '右前腕',
        'left_thigh': '左太もも', 'left_shin': '左すね',
        'right_thigh': '右太もも', 'right_shin': '右すね'
    }

    flagged = []
    for segment, expected in registered_ratios.items():
        if expected is None:
            continue

        observed = current_ratios.get(segment)
        base = {'segment': segment, 'segment_ja': labels_ja.get(segment, segment)}

        if observed is None:
            flagged.append({
                **base,
                'reason': 'not_detected',
                'registered_ratio': expected,
                'current_ratio': None,
            })
            continue

        if not observed < expected * (1 - threshold):
            continue

        flagged.append({
            **base,
            'reason': 'ratio_deviation',
            'registered_ratio': expected,
            'current_ratio': observed,
            'deviation': (expected - observed) / expected * 100,
        })

    return flagged


print("✓ Utility functions defined")

## 3. VitPose モデルを初期化

In [None]:
# Simplified VitPose setup for Colab; a real implementation needs much more.
# The production version would import PoseAnalyzer / PoseComparator from pose_analyzer.py.

class DummyPoseAnalyzer:
    """Placeholder for ``pose_analyzer.PoseAnalyzer`` (paste the real code here).

    WARNING: ``extract_poses_from_video`` performs no pose estimation and always
    returns an empty list. With this placeholder, the analysis step therefore
    raises "Failed to extract poses from videos", and the result reports
    ``success: False``.
    """

    def __init__(self, use_vitpose=True, model_variant='vitpose-b'):
        self.use_vitpose, self.model_variant = use_vitpose, model_variant
        print(f"✓ PoseAnalyzer initialized (VitPose={use_vitpose}, model={model_variant})")

    def extract_poses_from_video(self, video_path):
        """Log *video_path* and return an empty pose list (placeholder)."""
        print(f"Processing: {video_path}")
        poses = []  # stands in for per-frame pose-estimator output
        return poses

class DummyPoseComparator:
    """Placeholder for ``pose_analyzer.PoseComparator``.

    Returns fixed scores regardless of its inputs: 75.0 overall and for each
    of 17 joints (``joint_0`` .. ``joint_16``), 80.0 temporal alignment, and
    no per-frame scores.
    """

    def compare_pose_sequences(self, ref_poses, comp_poses, use_3d=False):
        placeholder_score = 75.0
        joint_names = (f'joint_{idx}' for idx in range(17))
        return dict(
            overall_score=placeholder_score,
            joint_scores=dict.fromkeys(joint_names, placeholder_score),
            temporal_alignment=80.0,
            frame_scores=[],
        )

# Instantiate the (placeholder) comparator and analyzer used by the cells below.
comparator = DummyPoseComparator()
analyzer = DummyPoseAnalyzer(
    use_vitpose=use_vitpose,
    model_variant=model_variant,
)
print("✓ Pose analyzer ready")

## 4. 動画ファイルの存在確認

In [None]:
import os

# Fail fast before any model work if an input video is missing.
# isfile (not exists) so that a directory path is rejected here, instead of
# passing the check and failing later inside pose extraction.
if not os.path.isfile(reference_video_path):
    raise FileNotFoundError(f"Reference video not found: {reference_video_path}")
if not os.path.isfile(comparison_video_path):
    raise FileNotFoundError(f"Comparison video not found: {comparison_video_path}")

print(f"✓ Reference video: {reference_video_path}")
print(f"✓ Comparison video: {comparison_video_path}")

## 5. 姿勢推定と比較分析を実行

In [None]:
# Extract poses from both videos and score the comparison.
# The pose lists and the result are bound *before* the try. The later cells
# (BiGRU completion, deficiency detection, result JSON) can then always
# reference them, even when extraction raises part-way through.
reference_poses = []
comparison_poses = []
comparison_result = None

try:
    print("Extracting poses from reference video...")
    reference_poses = analyzer.extract_poses_from_video(reference_video_path)

    print("Extracting poses from comparison video...")
    comparison_poses = analyzer.extract_poses_from_video(comparison_video_path)

    # Nothing to compare if either video produced no poses.
    if not reference_poses or not comparison_poses:
        raise RuntimeError("Failed to extract poses from videos")

    print(f"✓ Extracted {len(reference_poses)} reference poses")
    print(f"✓ Extracted {len(comparison_poses)} comparison poses")

    print("Comparing poses...")
    comparison_result = comparator.compare_pose_sequences(
        reference_poses,
        comparison_poses,
        use_3d=use_3d,
    )

    print("✓ Comparison complete")
    print(f"  Overall score: {comparison_result['overall_score']:.2f}")
    print(f"  Temporal alignment: {comparison_result['temporal_alignment']:.2f}")

except Exception as e:
    # Best-effort: log and continue so the result cell still emits JSON
    # (with success=False because comparison_result is None).
    print(f"⚠ Error during analysis: {e}")
    traceback.print_exc()
    comparison_result = None

In [None]:
# Optional: fill in missing keypoints with the BiGRU model, then re-score.
# Runs only when the analysis produced a result and use_bigru is still True.
# The definition-loading cell resets it to False on failure.
if comparison_result is not None and use_bigru:
    try:
        print("Applying BigRU keypoint completion...")

        # Load the weights from the checkpoint at bigru_model_path (a Drive path by default).
        bigru_model = BiGRUCompletionModel.load_pretrained(
            checkpoint_path=bigru_model_path,
            device=device,
        )

        # Complete the reference sequence first, then the comparison sequence.
        reference_poses_completed, comparison_poses_completed = (
            apply_bigru_completion(
                poses=sequence,
                model=bigru_model,
                device=device,
                batch_size=32,
            )
            for sequence in (reference_poses, comparison_poses)
        )

        print("Re-comparing poses with completed keypoints...")
        comparison_result_bigru = comparator.compare_pose_sequences(
            reference_poses_completed,
            comparison_poses_completed,
            use_3d=use_3d,
        )

        print("✓ BigRU 補完完了")
        print(f"  Before BigRU: {comparison_result['overall_score']:.2f}")
        print(f"  After BigRU: {comparison_result_bigru['overall_score']:.2f}")

        # From here on, later cells use the completed poses and their score.
        comparison_result, reference_poses, comparison_poses = (
            comparison_result_bigru,
            reference_poses_completed,
            comparison_poses_completed,
        )

    except FileNotFoundError:
        # NOTE(review): assumes a missing checkpoint surfaces as FileNotFoundError.
        print(f"⚠ BigRU モデル未検出（{bigru_model_path}）。補完をスキップ")
    except Exception as e:
        # Keep the un-completed comparison_result computed earlier.
        print(f"⚠ BigRU 補完エラー: {e}。元の結果を使用")

## 6. 欠損検知

In [None]:
deficiencies = []

# Compare limb ratios measured in the comparison video against the user's
# registered ratios. Best-effort: any failure is logged and `deficiencies`
# stays empty, so the result JSON is still produced.
try:
    # The parameter normally arrives as a JSON string. Also accept an
    # already-decoded mapping (e.g. when papermill is handed a dict value);
    # json.loads would reject that with a TypeError.
    if isinstance(registered_ratios_json, dict):
        registered_ratios = registered_ratios_json
    elif registered_ratios_json:
        registered_ratios = json.loads(registered_ratios_json)
    else:
        registered_ratios = None

    if registered_ratios and comparison_poses:
        # Per-frame ratios. Frames without a usable torso yield None and are dropped.
        frame_ratios_list = [
            fr
            for fr in (calculate_body_ratios(pose) for pose in comparison_poses if pose is not None)
            if fr
        ]

        if frame_ratios_list:
            # Average each segment over the frames where it was detected.
            # Segments never detected get None.
            avg_ratios = {}
            for seg_name in LIMB_SEGMENTS:
                vals = [fr[seg_name] for fr in frame_ratios_list if fr.get(seg_name) is not None]
                avg_ratios[seg_name] = sum(vals) / len(vals) if vals else None

            deficiencies = detect_deficiency(avg_ratios, registered_ratios)
            print(f"✓ Deficiencies detected: {len(deficiencies)}")
        else:
            # Previously this case was skipped silently.
            print("⚠ No frames with a measurable torso; deficiency detection skipped")
except Exception as e:
    print(f"⚠ Error during deficiency detection: {e}")
    traceback.print_exc()

## 7. 結果をJSONで出力

In [None]:
def _to_builtin(value):
    """Return *value* with NumPy scalars/arrays converted to plain Python objects.

    Pose/scoring code can yield values such as ``np.float32``, which
    ``json.dumps`` rejects ("Object of type float32 is not JSON serializable").
    Converting the whole payload up front keeps the printed JSON serializable,
    and also the ``result`` object dumped to a file afterwards. Dicts, lists
    and tuples are walked recursively; tuples become lists, which is how JSON
    encodes them anyway.
    """
    if isinstance(value, dict):
        return {key: _to_builtin(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_builtin(item) for item in value]
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    return value


# Assemble the final payload. The score fields fall back to neutral defaults
# when the analysis produced no comparison result.
result = _to_builtin({
    'success': comparison_result is not None,
    'timestamp': datetime.now().isoformat(),
    'username': username,
    'backend': 'colab-notebook',
    'use_3d': use_3d,
    'score': comparison_result['overall_score'] if comparison_result else 0.0,
    'joint_scores': comparison_result['joint_scores'] if comparison_result else {},
    'temporal_alignment': comparison_result['temporal_alignment'] if comparison_result else 0.0,
    'frame_scores': comparison_result.get('frame_scores', []) if comparison_result else [],
    'deficiencies': deficiencies,
    'analysis': ''  # AI analysis is done on the Render side
})

print("\n=== RESULT JSON ===")
print(json.dumps(result, indent=2, ensure_ascii=False))
print("\n=== END RESULT ===")

In [None]:
# Also write the result JSON to a fixed file path, so the caller (Render) can
# read it without parsing the printed cell output.
# encoding='utf-8' is explicit because ensure_ascii=False keeps non-ASCII text
# (e.g. the Japanese segment_ja names) as-is. That would raise
# UnicodeEncodeError under a non-UTF-8 locale default.
with open('/tmp/vireora_result.json', 'w', encoding='utf-8') as f:
    json.dump(result, f, ensure_ascii=False, indent=2)

print("✓ Result saved to /tmp/vireora_result.json")