In [1]:
!wget -nc https://raw.githubusercontent.com/SavinovSergey/LLM_Scaling_Week/main/mixture_of_experts/MOE_padding.py --no-check-certificate

--2025-11-25 18:00:54--  https://raw.githubusercontent.com/SavinovSergey/LLM_Scaling_Week/main/mixture_of_experts/MOE_padding.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3591 (3.5K) [text/plain]
Saving to: ‘MOE_padding.py’


2025-11-25 18:00:54 (48.3 MB/s) - ‘MOE_padding.py’ saved [3591/3591]



In [2]:
import torch
import time
from typing import Tuple, List, Dict

from MOE_padding import moe_padding, torch_basic

In [3]:
def verify_implementation():
    """Проверка корректности реализации"""
    print("Проверка корректности реализации...")
    
    # Тестовые данные
    x = torch.tensor([
        [-0.0236, -0.5368, -0.5663],
        [ 0.7778, -0.8583, -0.1123],
        [ 0.1981, -0.3514, -0.9443],
        [-2.0655, -0.9424,  0.9870]
    ], dtype=torch.float32)
    
    top_experts = torch.tensor([
        [1, 3],
        [2, 5], 
        [3, 5],
        [2, 4]
    ], dtype=torch.long)
    
    tokens_per_expert = torch.bincount(top_experts.flatten(), minlength=6)
    
    # Запуск обеих реализаций
    result_moe_padding, padded_moe_padding = moe_padding(x, top_experts, tokens_per_expert, 2, 6)
    result_basic, padded_basic = torch_basic(x, top_experts, tokens_per_expert, 2, 6)
    
    # Проверка padded_tokens_per_expert
    print("padded_tokens_per_expert moe_padding:", padded_moe_padding)
    print("padded_tokens_per_expert basic:", padded_basic)
    
    if not torch.equal(padded_moe_padding, padded_basic):
        print("ОШИБКА: padded_tokens_per_expert не совпадают!")
        return False
    
    # Проверка основных результатов
    non_zero_mask_moe_padding = result_moe_padding.abs().sum(dim=1) > 0
    non_zero_mask_basic = result_basic.abs().sum(dim=1) > 0
    
    non_zero_moe_padding = result_moe_padding[non_zero_mask_moe_padding]
    non_zero_basic = result_basic[non_zero_mask_basic]
    
    if not torch.allclose(non_zero_moe_padding, non_zero_basic, rtol=1e-5, atol=1e-6):
        print("ОШИБКА: результаты не совпадают!")
        print("Moe_padding non-zero:", non_zero_moe_padding)
        print("Basic non-zero:", non_zero_basic)
        return False
    
    print("✓ Корректность проверена успешно!")
    return True

In [4]:
def generate_performance_test_cases() -> List[Dict]:
    """Генерация тестовых случаев для измерения производительности"""
    test_cases = []
    
    # Различные сценарии
    scenarios = [
        # (num_tokens, hidden_size, topk, num_experts, description)
        (1000, 512, 2, 8, "small_uniform"),
        (5000, 1024, 4, 16, "medium_uniform"), 
        (20000, 2048, 8, 32, "large_uniform"),
        (10000, 1024, 4, 32, "medium_imbalanced"),
    ]
    
    for num_tokens, hidden_size, topk, num_experts, desc in scenarios:
        # Генерация данных
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        x = torch.randn(num_tokens, hidden_size, device=device)
        
        if "imbalanced" in desc:
            # Создание неравномерного распределения
            top_experts = torch.zeros(num_tokens, topk, dtype=torch.long, device=device)
            popular_experts = torch.randint(0, num_experts // 4, (num_experts // 4,))
            
            for i in range(num_tokens):
                if i % 3 == 0:  # 33% токенов выбирают популярных экспертов
                    choices = torch.randint(0, len(popular_experts), (topk,))
                    top_experts[i] = popular_experts[choices]
                else:
                    top_experts[i] = torch.randint(0, num_experts, (topk,))
        else:
            # Равномерное распределение
            top_experts = torch.randint(0, num_experts, (num_tokens, topk), device=device)
        
        tokens_per_expert = torch.bincount(top_experts.flatten(), minlength=num_experts)
        
        test_cases.append({
            'x': x,
            'top_experts': top_experts,
            'tokens_per_expert': tokens_per_expert,
            'topk': topk,
            'num_experts': num_experts,
            'description': desc,
            'num_tokens': num_tokens,
            'hidden_size': hidden_size
        })
    
    return test_cases


def benchmark_implementation(func, test_case, num_warmup=5, num_runs=20):
    """Бенчмарк реализации"""
    device = test_case['x'].device
    
    # Прогрев
    for _ in range(num_warmup):
        _ = func(**{k: test_case[k] for k in ['x', 'top_experts', 'tokens_per_expert', 'topk', 'num_experts']})
    
    if device.type == 'cuda':
        torch.cuda.synchronize()
        start_memory = torch.cuda.memory_allocated()
    
    # Измерение времени
    start_time = time.time()
    for _ in range(num_runs):
        result = func(**{k: test_case[k] for k in ['x', 'top_experts', 'tokens_per_expert', 'topk', 'num_experts']})
    end_time = time.time()
    
    if device.type == 'cuda':
        torch.cuda.synchronize()
        end_memory = torch.cuda.memory_allocated()
        memory_used = (end_memory - start_memory) / 1024**2  # MB
    else:
        memory_used = 0
    
    avg_time = (end_time - start_time) / num_runs
    
    return avg_time, memory_used, result

In [5]:
def run_performance_comparison():
    """Запуск сравнения производительности"""
    print("Запуск сравнения производительности...")
    print("=" * 80)
    
    # Сначала проверяем корректность
    if not verify_implementation():
        print("Прерывание: реализация некорректна!")
        return
    
    # Генерация тестовых случаев
    test_cases = generate_performance_test_cases()
    
    results = []
    
    for i, test_case in enumerate(test_cases):
        print(f"\nТест {i+1}: {test_case['description']}")
        print(f"  Токены: {test_case['num_tokens']}, Hidden: {test_case['hidden_size']}, "
              f"TopK: {test_case['topk']}, Эксперты: {test_case['num_experts']}")
        
        # Бенчмарк moe_padding
        moe_time, moe_memory, moe_result = benchmark_implementation(moe_padding, test_case)
        
        # Бенчмарк torch_basic
        basic_time, basic_memory, basic_result = benchmark_implementation(torch_basic, test_case)
        
        # Проверка корректности результатов
        moe_padded, moe_counts = moe_result
        basic_padded, basic_counts = basic_result
        
        counts_match = torch.equal(moe_counts, basic_counts)
        
        # Сравнение ненулевых элементов
        moe_nonzero = moe_padded[moe_padded.abs().sum(dim=1) > 0]
        basic_nonzero = basic_padded[basic_padded.abs().sum(dim=1) > 0]
        data_match = torch.allclose(moe_nonzero, basic_nonzero, rtol=1e-5, atol=1e-6)
        
        # Расчет ускорения
        speedup = basic_time / moe_time
        
        print(f"  moe_padding: {moe_time*1000:.2f} ms, {moe_memory:.1f} MB")
        print(f"  torch_basic: {basic_time*1000:.2f} ms, {basic_memory:.1f} MB")
        print(f"  Ускорение: {speedup:.2f}x")
        print(f"  Корректность: counts={counts_match}, data={data_match}")
        
        results.append({
            'test': test_case['description'],
            'moe_padding_time': moe_time,
            'basic_time': basic_time,
            'speedup': speedup,
            'moe_padding_memory': moe_memory,
            'basic_memory': basic_memory,
            'correct': counts_match and data_match
        })
    
    # Вывод итогов
    print("\n" + "=" * 80)
    print("ИТОГИ СРАВНЕНИЯ ПРОИЗВОДИТЕЛЬНОСТИ")
    print("=" * 80)
    
    for result in results:
        status = "✓" if result['correct'] else "✗"
        print(f"{result['test']:20} | {status} | Ускорение: {result['speedup']:6.2f}x | "
              f"Время: {result['moe_padding_time']*1000:6.2f}ms vs {result['basic_time']*1000:6.2f}ms")
    
    avg_speedup = torch.tensor([r['speedup'] for r in results if r['correct']]).mean().item()
    print(f"\nСреднее ускорение: {avg_speedup:.2f}x")
    
    return results

In [6]:
def memory_profiling():
    """Профилирование использования памяти"""
    print("\nПрофилирование использования памяти...")
    
    # Большой тестовый случай
    test_case = {
        'num_tokens': 20000,
        'hidden_size': 2048,
        'topk': 8,
        'num_experts': 32,
        'description': 'memory_profiling'
    }
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(test_case['num_tokens'], test_case['hidden_size'], device=device)
    top_experts = torch.randint(0, test_case['num_experts'], 
                               (test_case['num_tokens'], test_case['topk']), device=device)
    tokens_per_expert = torch.bincount(top_experts.flatten(), minlength=test_case['num_experts'])
    
    test_case.update({
        'x': x,
        'top_experts': top_experts,
        'tokens_per_expert': tokens_per_expert
    })
    
    if device.type == 'cuda':
        torch.cuda.reset_peak_memory_stats()
        
        # moe_padding
        result1 = moe_padding(**{k: test_case[k] for k in ['x', 'top_experts', 'tokens_per_expert', 'topk', 'num_experts']})
        torch.cuda.synchronize()
        moe_memory = torch.cuda.max_memory_allocated() / 1024**2
        
        torch.cuda.reset_peak_memory_stats()
        
        # torch_basic
        result2 = torch_basic(**{k: test_case[k] for k in ['x', 'top_experts', 'tokens_per_expert', 'topk', 'num_experts']})
        torch.cuda.synchronize()
        basic_memory = torch.cuda.max_memory_allocated() / 1024**2
        
        print(f"Пиковое использование памяти:")
        print(f"  moe_padding: {moe_memory:.1f} MB")
        print(f"  torch_basic: {basic_memory:.1f} MB")
        print(f"  Экономия: {basic_memory - moe_memory:.1f} MB ({(basic_memory - moe_memory)/basic_memory*100:.1f}%)")

In [7]:
results = run_performance_comparison()

if torch.cuda.is_available():
    memory_profiling()

Запуск сравнения производительности...
Проверка корректности реализации...
padded_tokens_per_expert moe_padding: tensor([  0, 128, 128, 128, 128, 128], dtype=torch.int32)
padded_tokens_per_expert basic: tensor([  0, 128, 128, 128, 128, 128], dtype=torch.int32)
✓ Корректность проверена успешно!

Тест 1: small_uniform
  Токены: 1000, Hidden: 512, TopK: 2, Эксперты: 8
  moe_padding: 0.72 ms, 0.0 MB
  torch_basic: 98.32 ms, 0.0 MB
  Ускорение: 136.32x
  Корректность: counts=True, data=True

Тест 2: medium_uniform
  Токены: 5000, Hidden: 1024, TopK: 4, Эксперты: 16
  moe_padding: 3.05 ms, 0.0 MB
  torch_basic: 927.51 ms, 0.0 MB
  Ускорение: 303.82x
  Корректность: counts=True, data=True

Тест 3: large_uniform
  Токены: 20000, Hidden: 2048, TopK: 8, Эксперты: 32
  moe_padding: 28.68 ms, 0.0 MB
  torch_basic: 7165.88 ms, 0.0 MB
  Ускорение: 249.85x
  Корректность: counts=True, data=True

Тест 4: medium_imbalanced
  Токены: 10000, Hidden: 1024, TopK: 4, Эксперты: 32
  moe_padding: 5.38 ms, 0.0