In [None]:
from natten import NeighborhoodAttention1D
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [None]:
def generate_trajectory(length, noise_std=0.1, x_max=10.0, rng=None):
    """
    Generate one synthetic 2-D trajectory.

    The x coordinates are ``length`` evenly spaced values covering
    ``[0, x_max]`` (both ends included); the y coordinates are ``sin(x)``
    plus zero-mean Gaussian noise.

    Args:
        length: number of points in the trajectory.
        noise_std: standard deviation of the Gaussian noise added to y.
        x_max: upper end of the x range (the lower end is always 0).
        rng: optional ``numpy.random.Generator`` used to draw the noise.
            When omitted, the global ``np.random`` state is used, so
            ``np.random.seed`` still controls the output.

    Returns:
        ``np.ndarray`` of shape ``(length, 2)`` and dtype ``float32`` whose
        columns are ``x`` and ``y``.
    """
    x = np.linspace(0, x_max, length)
    # Both the legacy module-level API and Generator expose the same
    # ``normal(scale=..., size=...)`` call, so either can act as the sampler.
    sampler = np.random if rng is None else rng
    y = np.sin(x) + sampler.normal(scale=noise_std, size=length)
    # Stack the two coordinate vectors column-wise into (length, 2).
    traj = np.stack([x, y], axis=-1)
    return traj.astype(np.float32)

class TrajectoryDataset(Dataset):
    """Dataset of synthetic trajectories used as a self-supervised task.

    Every sample is one trajectory produced by ``generate_trajectory``; the
    target of a sample is the very same trajectory (``target`` is an alias of
    ``data``, not a copy).
    """

    def __init__(self, num_samples, seq_len):
        super().__init__()
        self.num_samples = num_samples
        self.seq_len = seq_len
        # Generate all trajectories up front, one per sample.
        trajectories = []
        for _ in range(num_samples):
            trajectories.append(generate_trajectory(seq_len))
        self.data = trajectories
        # The target could be a smoothed or future trajectory; here it is
        # simply the input itself (self-supervised reconstruction).
        self.target = self.data

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Each element is converted to a fresh tensor of shape (seq_len, 2).
        inputs = torch.tensor(self.data[idx])
        labels = torch.tensor(self.target[idx])
        return inputs, labels

# Build the trajectory datasets and wrap them in DataLoaders.
num_samples = 1024
seq_len = 50  # number of points per trajectory
batch_size = 32
train_dataset = TrajectoryDataset(num_samples=num_samples, seq_len=seq_len)
val_dataset = TrajectoryDataset(num_samples=256, seq_len=seq_len)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

# The model's input and output width is assumed to be 2 (x, y); the input
# could alternatively be projected to a higher dimension by a linear layer first.
class SimpleAttentionModel(nn.Module):
    """Self-attention followed by a ``dim -> dim`` linear layer.

    ``attn_module`` is called as ``attn_module(x, x, x)`` (query, key and
    value are all the input) and its result is fed directly to the linear
    layer.
    """

    def __init__(self, attn_module, dim):
        super().__init__()
        self.attn = attn_module
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        # x: (B, L, dim); self-attention, so q = k = v = x.
        attended = self.attn(x, x, x)
        return self.fc(attended)

In [None]:
class CustomizedNeighborhoodAttention1D_MH(nn.Module):
    """Multi-head 1-D neighborhood (sliding-window) self-attention.

    Each position attends to the positions at most ``radius`` steps away on
    either side, i.e. to a window of ``2 * radius + 1`` positions centred on
    itself. Near the sequence ends the window reaches past the sequence; those
    slots are zero-padded and then masked out of the softmax so that they
    receive exactly zero attention weight.
    """

    def __init__(self, radius, dim, num_heads, qkv_bias=True, attn_drop=0.0, proj_drop=0.0):
        """
        Args:
            radius: int, neighborhood radius (window size = 2 * radius + 1).
            dim: int, total feature dimension of the input.
            num_heads: int, number of heads; ``dim % num_heads`` must be 0.
            qkv_bias: bool, whether the fused q/k/v projection has a bias.
            attn_drop: dropout probability applied to the attention weights.
            proj_drop: dropout probability applied after the output projection.

        Raises:
            ValueError: if ``radius`` is negative.
            AssertionError: if ``dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        if radius < 0:
            raise ValueError("radius must be non-negative")
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.radius = radius
        self.window_size = 2 * radius + 1
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Scores are divided by sqrt(head_dim) (scaled dot-product attention).
        self.scale = math.sqrt(self.head_dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (B, L, dim).

        Returns:
            Tensor of shape (B, L, dim).
        """
        B, L, _ = x.shape
        merged = B * self.num_heads

        # Fused projection, then split into q, k, v:
        # (B, L, 3*dim) -> (3, B, num_heads, L, head_dim).
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        # Fold batch and head axes together: each becomes (B', L, head_dim).
        q, k, v = (t.reshape(merged, L, self.head_dim) for t in qkv.unbind(0))

        # Zero-pad the sequence axis by `radius` on both sides so that every
        # position has a full window, then slide a window over it.
        # pad=(0, 0, r, r): nothing on the feature axis, r on each side of L.
        seq_pad = (0, 0, self.radius, self.radius)
        k_padded = F.pad(k, seq_pad, mode='constant', value=0)
        v_padded = F.pad(v, seq_pad, mode='constant', value=0)
        # unfold appends the window axis last -> (B', L, head_dim, window);
        # swap the last two axes to get (B', L, window, head_dim).
        k_windows = k_padded.unfold(dimension=1, size=self.window_size, step=1).transpose(-2, -1)
        v_windows = v_padded.unfold(dimension=1, size=self.window_size, step=1).transpose(-2, -1)

        # Dot-product scores of each query against its own window: (B', L, window).
        scores = torch.einsum('bld,blwd->blw', q, k_windows) / self.scale

        # Window slot w of position l corresponds to sequence index
        # l + w - radius. Slots that fall outside [0, L) are padding: without
        # masking they would score 0 and still take softmax weight away from
        # the real neighbours. The centre slot is always valid, so no row is
        # ever fully masked.
        offsets = torch.arange(self.window_size, device=scores.device) - self.radius
        source_idx = torch.arange(L, device=scores.device).unsqueeze(1) + offsets
        is_padding = (source_idx < 0) | (source_idx >= L)  # (L, window)
        scores = scores.masked_fill(is_padding, float('-inf'))

        attn = torch.softmax(scores, dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of the window values: (B', L, head_dim).
        out = torch.einsum('blw,blwd->bld', attn, v_windows)
        # Unfold the heads again and concatenate them: (B, L, dim).
        out = out.reshape(B, self.num_heads, L, self.head_dim)
        out = out.transpose(1, 2).reshape(B, L, self.dim)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

In [None]:
# -------------------- Simple model wrapper --------------------
class AttentionModel(nn.Module):
    """Thin regression wrapper: input -> attention -> linear layer -> output."""

    def __init__(self, attn_module, dim):
        """
        Wrap the given neighborhood-attention module in a minimal model.

        The task is assumed to be regression, with input and output both of
        shape (B, L, dim).
        """
        super().__init__()
        self.dim = dim
        self.attn = attn_module
        self.fc = nn.Linear(self.dim, self.dim)

    def forward(self, x):
        # x: (B, L, dim) -> attention -> linear head.
        return self.fc(self.attn(x))

# -------------------- Synthetic dataset --------------------
class SyntheticDataset(Dataset):
    """Random inputs paired with targets ``sin(data @ weight + bias)``.

    The mapping parameters are kept on the instance (``weight``, ``bias``) so
    that a second dataset, e.g. a validation set, can be built with the same
    ground-truth function by passing them back in.
    """

    def __init__(self, num_samples, L, dim, weight=None, bias=None):
        """
        Args:
            num_samples: number of samples.
            L: sequence length of each sample.
            dim: feature dimension of each position.
            weight: optional (dim, dim) tensor defining the target mapping;
                drawn from a standard normal when omitted.
            bias: optional (dim,) tensor defining the target mapping; drawn
                from a standard normal when omitted.
        """
        super().__init__()
        self.num_samples = num_samples
        self.L = L
        self.dim = dim
        # Inputs: standard-normal noise of shape (num_samples, L, dim).
        self.data = torch.randn(num_samples, L, dim)
        # Targets: a linear map followed by a non-linearity, to give the
        # model a non-trivial relation to learn.
        self.weight = torch.randn(dim, dim) if weight is None else weight
        self.bias = torch.randn(dim) if bias is None else bias
        self.target = torch.sin(self.data.matmul(self.weight) + self.bias)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]

In [None]:
# Experiment configuration
BATCH_SIZE = 1
L = 8
DIM = 10
NUM_HEADS = 2
RADIUS = 1
NUM_TRAIN = 1
NUM_VAL = 1
EPOCHS = 1
LR = 1e-3
SEED = 0

# Seed PyTorch so that the data and the initial weights are reproducible.
torch.manual_seed(SEED)

# Datasets and loaders.
# Every SyntheticDataset instance draws its own random target mapping, so the
# train and validation sets are carved out of ONE dataset; two separate
# instances would be labelled by two unrelated functions and the validation
# loss would say nothing about generalisation.
full_dataset = SyntheticDataset(NUM_TRAIN + NUM_VAL, L, DIM)
train_dataset = torch.utils.data.Subset(full_dataset, range(NUM_TRAIN))
val_dataset = torch.utils.data.Subset(full_dataset, range(NUM_TRAIN, NUM_TRAIN + NUM_VAL))
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# -------------------- Build the two models --------------------
# Reference NATTEN implementation (assumes multi-head support and these
# parameter names; adjust to the installed natten version if needed).
classic_natten = NeighborhoodAttention1D(
    dim=DIM,
    kernel_size=2 * RADIUS + 1,
    dilation=1,
    num_heads=NUM_HEADS,
    qkv_bias=True,
    qk_scale=None,
    attn_drop=0.0,
    proj_drop=0.0
)
model_classic = AttentionModel(classic_natten, DIM)

# Custom implementation
custom_natten = CustomizedNeighborhoodAttention1D_MH(
    radius=RADIUS,
    dim=DIM,
    num_heads=NUM_HEADS,
    qkv_bias=True,
    attn_drop=0.0,
    proj_drop=0.0
)
model_custom = AttentionModel(custom_natten, DIM)

# Put both models on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_classic = model_classic.to(device)
model_custom = model_custom.to(device)

# Loss function and one optimizer per model.
criterion = nn.MSELoss()
optimizer_classic = torch.optim.Adam(model_classic.parameters(), lr=LR)
optimizer_custom = torch.optim.Adam(model_custom.parameters(), lr=LR)

# -------------------- Training --------------------
def train_one_epoch(model, optimizer, loader):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Uses the module-level ``device`` and ``criterion``.

    Args:
        model: module being trained; switched to train mode.
        optimizer: optimizer stepped once per batch.
        loader: iterable yielding ``(batch_x, batch_y)`` tensor pairs.

    Returns:
        Batch losses weighted by batch size and averaged over the samples
        actually iterated (``0.0`` if the loader yields nothing). Dividing by
        the number of samples seen, rather than ``len(loader.dataset)``,
        keeps the average correct when the loader skips samples (for example
        with ``drop_last=True`` or a custom sampler).
    """
    model.train()
    loss_sum = 0.0
    samples_seen = 0
    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad()
        out = model(batch_x)
        loss = criterion(out, batch_y)
        loss.backward()
        optimizer.step()
        batch_n = batch_x.size(0)
        # Weight by batch size so a smaller final batch is not over-counted
        # (assumes `criterion` returns a per-batch mean).
        loss_sum += loss.item() * batch_n
        samples_seen += batch_n
    return loss_sum / max(samples_seen, 1)

def evaluate(model, loader):
    """Compute the mean loss of ``model`` over ``loader`` without gradients.

    Uses the module-level ``device`` and ``criterion``.

    Args:
        model: module to evaluate; switched to eval mode.
        loader: iterable yielding ``(batch_x, batch_y)`` tensor pairs.

    Returns:
        Batch losses weighted by batch size and averaged over the samples
        actually iterated (``0.0`` if the loader yields nothing). Dividing by
        the number of samples seen, rather than ``len(loader.dataset)``,
        keeps the average correct when the loader skips samples (for example
        with ``drop_last=True`` or a custom sampler).
    """
    model.eval()
    loss_sum = 0.0
    samples_seen = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            out = model(batch_x)
            loss = criterion(out, batch_y)
            batch_n = batch_x.size(0)
            # Weight by batch size (assumes `criterion` returns a per-batch mean).
            loss_sum += loss.item() * batch_n
            samples_seen += batch_n
    return loss_sum / max(samples_seen, 1)

# Train the reference model first, then the custom one, with the same
# schedule and the same loaders.
training_runs = (
    ("Training Classic Natten Model...", model_classic, optimizer_classic),
    ("Training Customized Natten Model...", model_custom, optimizer_custom),
)
for banner, net, opt in training_runs:
    print(banner)
    for epoch in range(EPOCHS):
        loss_train = train_one_epoch(net, opt, train_loader)
        loss_val = evaluate(net, val_loader)
        print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {loss_train:.4f}, Val Loss: {loss_val:.4f}")

# -------------------- Run both models on the same validation batch and plot --------------------
import plotly.graph_objects as go

# Take the first validation batch and push it through both models.
model_classic.eval()
model_custom.eval()
with torch.no_grad():
    batch_x, batch_y = next(iter(val_loader))
    batch_x = batch_x.to(device)
    out_classic = model_classic(batch_x)  # (B, L, DIM)
    out_custom = model_custom(batch_x)

# Use the first sample of the batch. Purely for illustration, the first two
# output features of every sequence position are treated as (x, y) coordinates.
sample_classic = out_classic[0].cpu().numpy()
sample_custom = out_custom[0].cpu().numpy()

# Line segments joining corresponding points (classic -> custom). The trailing
# None after each pair breaks the line so the segments are not chained.
# A separate name is used for the point count so that the configuration
# constant `L` is not overwritten by this cell.
num_points = sample_classic.shape[0]
pair_x = []
pair_y = []
for i in range(num_points):
    pair_x.extend([sample_classic[i, 0], sample_custom[i, 0], None])
    pair_y.extend([sample_classic[i, 1], sample_custom[i, 1], None])

fig = go.Figure()

# Reference-model outputs as unconnected markers.
fig.add_trace(go.Scatter(
    x=sample_classic[:, 0],
    y=sample_classic[:, 1],
    mode='markers',
    marker=dict(color='blue'),
    name='Classic Natten'
))

# Custom-model outputs as unconnected markers.
fig.add_trace(go.Scatter(
    x=sample_custom[:, 0],
    y=sample_custom[:, 1],
    mode='markers',
    marker=dict(color='red'),
    name='Customized Natten'
))

# Segments between each pair of corresponding points.
fig.add_trace(go.Scatter(
    x=pair_x,
    y=pair_y,
    mode='lines',
    line=dict(color='green'),
    name='Pair Connection'
))

fig.update_layout(
    title="Pairwise Connection between Classic and Customized Outputs",
    xaxis_title="X",
    yaxis_title="Y",
    template="plotly_white",
    # Lock the aspect ratio so distances look the same on both axes.
    yaxis=dict(scaleanchor="x", scaleratio=1)
)

fig.show(renderer="notebook")