### Task 2

In [2]:
import math
import torch
import torch.nn.functional as F
import torch
import time
from chop.models import get_model
from chop.dataset import get_dataset_info


def timed_gpu(fn):
    """Run ``fn`` once and measure its device time with CUDA events.

    Args:
        fn: Zero-argument callable whose GPU work should be timed.

    Returns:
        A ``(result, seconds)`` tuple: ``fn``'s return value and the time
        between the two recorded events, converted from milliseconds to
        seconds.
    """
    tic, toc = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    tic.record()
    output = fn()
    toc.record()
    # Kernels are launched asynchronously; wait until the device has actually
    # reached ``toc`` before reading the event timestamps.
    torch.cuda.synchronize()
    elapsed_ms = tic.elapsed_time(toc)
    return output, elapsed_ms / 1000


def timed_cpu(fn):
    """Run ``fn`` once and measure its wall-clock duration on the host.

    Args:
        fn: Zero-argument callable to time.

    Returns:
        A ``(result, seconds)`` tuple with ``fn``'s return value and the
        elapsed time in seconds.
    """
    # perf_counter is monotonic and uses the highest-resolution clock
    # available; time.time() can jump with system clock adjustments and is
    # coarser, which makes it unsuitable for benchmarking.
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start


def get_data(device="cpu", *, shape=(32, 8, 128, 64), dtype=torch.bfloat16):
    """Build the ``[query, key, value]`` inputs for the attention benchmark.

    Args:
        device: Device on which the tensors are allocated.
        shape: Shape shared by all three tensors. The attention modules in
            this notebook use the last dim as the feature size (for scaling)
            and attend over the second-to-last dim.
        dtype: Element type. Defaults to ``torch.bfloat16``; it was
            previously ``float16``.

    Returns:
        A list of three distinct all-ones tensors. Constant values are enough
        here because the benchmark only measures runtime, not outputs.
    """
    return [torch.ones(shape, dtype=dtype, device=device) for _ in range(3)]


def time_model(fn, n=1000, device="cpu", *, warmup=3):
    """Return the mean runtime in seconds of ``fn(query, key, value)``.

    Inputs come from ``get_data(device=device)`` and are reused for every
    call. CPU runs are timed with ``timed_cpu``; any other device type is
    timed with CUDA events via ``timed_gpu``.

    Args:
        fn: Callable (e.g. an attention module) taking query, key and value.
        n: Number of timed calls to average over; must be at least 1.
        device: Device for the inputs, as a string ("cpu", "cuda", "cuda:0")
            or a ``torch.device``.
        warmup: Untimed calls made before measuring. They keep one-off costs
            such as lazy CUDA/library initialisation out of the average.

    Returns:
        Average seconds per call over the ``n`` timed calls.

    Raises:
        ValueError: If ``n`` < 1 or ``warmup`` < 0.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    if warmup < 0:
        raise ValueError(f"warmup must be non-negative, got {warmup}")

    query, key, value = get_data(device=device)
    # Normalise through torch.device so torch.device("cpu") also picks the
    # host timer. The inputs already live on `device`, so no per-call copy
    # is needed inside the timed region.
    timer = timed_cpu if torch.device(device).type == "cpu" else timed_gpu

    def call():
        return fn(query, key, value)

    for _ in range(warmup):
        call()

    total = 0.0
    for _ in range(n):
        _, elapsed = timer(call)
        total += elapsed
    return total / n


class ScaledDotProductAttention(torch.nn.Module):
    """Reference (unfused) scaled dot-product attention.

    Computes ``softmax(Q K^T / sqrt(d)) V`` as separate matmul, scale,
    softmax and matmul ops, where ``d`` is the size of ``query``'s last
    dimension. No mask, dropout or causal masking is applied.
    """

    def __init__(self):
        super().__init__()

    def forward(self, query, key, value):
        """Attend ``query`` over ``key`` and return the weighted sum of ``value``."""
        inv_sqrt_d = 1 / math.sqrt(query.size(-1))
        logits = torch.matmul(query, key.transpose(-2, -1)) * inv_sqrt_d
        weights = F.softmax(logits, dim=-1)
        return torch.matmul(weights, value)


class ScaledDotProductAttentionFused(torch.nn.Module):
    """Scaled dot-product attention via PyTorch's fused implementation.

    Thin wrapper so the fused op can be timed with the same interface as the
    naive module. It passes only query/key/value, so PyTorch's defaults for
    mask, dropout and scale apply.
    """

    def forward(self, query, key, value):
        """Return ``F.scaled_dot_product_attention(query, key, value)``."""
        fused_out = F.scaled_dot_product_attention(query, key, value)
        return fused_out


# Benchmark configuration: set ``device`` to "cuda" to reproduce the GPU
# numbers recorded below.
device = "cpu"
n = 100

# Neither module defines parameters or buffers, so ``.to`` effectively moves
# nothing. It does return the module itself, which lets construction and
# placement be chained.
model_naive = ScaledDotProductAttention().to(device)
model_fused = ScaledDotProductAttentionFused().to(device)

# Time the naive module first, then the fused one, on identical inputs.
avg_t = time_model(model_naive, n=n, device=device)
fused_avg_t = time_model(model_fused, n=n, device=device)

print(f"Naive model: {avg_t:.4f} s")
print(f"Fused model: {fused_avg_t:.4f} s")


# cpu:
# Naive model: 0.6486 s
# Fused model: 0.0248 s

# gpu:
# Naive model: 0.0020 s
# Fused model: 0.0003 s

Naive model: 0.6493 s
Fused model: 0.0252 s
