In [7]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


In [8]:
class TEXBATSequenceDataset(Dataset):
    """Sliding-window sequence dataset over labelled ``.npy`` signal chunks.

    Every chunk file in ``data_dir`` that also appears in ``label_file`` is cut
    into overlapping windows.  Each sample is a triple ``(x, y, label)``:

    * ``x`` -- ``window_size`` consecutive samples, stored as ``[real, imag]``
      pairs in a float32 tensor of shape ``(window_size, 2)``;
    * ``y`` -- the ``pred_len`` samples immediately following ``x``, shape
      ``(pred_len, 2)``;
    * ``label`` -- the integer label of the source chunk, as a 0-d tensor.

    Args:
        data_dir: Directory containing the ``.npy`` chunk files.
        label_file: Text file with one ``<chunk file name>,<int label>`` entry
            per line.  Blank lines are ignored.
        window_size: Number of input samples per window.
        pred_len: Number of target samples following each window.
        step_size: Stride between the starts of consecutive windows.

    Attributes:
        labels: Mapping ``file name -> int label`` read from ``label_file``.
        index: List of ``(file name, start offset, label)`` tuples, one per
            sample, in sorted file-name order.
    """

    def __init__(self, data_dir, label_file, window_size=256, pred_len=16, step_size=32):
        self.data_dir = data_dir
        self.window_size = window_size
        self.pred_len = pred_len
        self.step_size = step_size

        # Read the label file: "chunk_XXX.npy, label" per line.
        self.labels = {}
        with open(label_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing empty lines instead of crashing
                    # on the tuple unpacking below.
                    continue
                name, label = line.split(',')
                self.labels[name.strip()] = int(label.strip())

        # Build the flat sample index: (chunk name, start offset, label).
        self.index = []
        span = window_size + pred_len
        for filename in sorted(os.listdir(data_dir)):
            if filename.endswith('.npy') and filename in self.labels:
                full_path = os.path.join(data_dir, filename)
                num_points = self._count_points(full_path)
                # A window starting at `start` consumes samples
                # [start, start + span), so the last valid start is
                # num_points - span (inclusive).
                max_start = num_points - span
                for start in range(0, max_start + 1, step_size):
                    self.index.append((filename, start, self.labels[filename]))

    @staticmethod
    def _count_points(path):
        """Return the number of samples along the first axis of a ``.npy`` file.

        The array is memory-mapped, so only the header is actually read.  This
        avoids deriving the length from the file size, which would count the
        ``.npy`` header bytes as data and assume a fixed item size.
        """
        return np.load(path, mmap_mode='r').shape[0]

    def __len__(self):
        """Return the total number of windows across all indexed chunks."""
        return len(self.index)

    def __getitem__(self, idx):
        """Return the ``(x, y, label)`` sample stored at position ``idx``."""
        return self.load_sequence_by_index(self.index[idx])

    def load_sequence_by_index(self, index_entry):
        """Load one sample described by a ``(file name, start, label)`` entry.

        Returns:
            Tuple ``(x_tensor, y_tensor, label_tensor)`` as described in the
            class docstring.
        """
        chunk_file, start, label = index_entry
        path = os.path.join(self.data_dir, chunk_file)
        # Memory-map the chunk and copy out only the samples we need rather
        # than reading the whole file for every single window.
        data = np.load(path, mmap_mode='r')
        stop = start + self.window_size + self.pred_len
        segment = np.array(data[start:stop])
        x = segment[:self.window_size]
        y = segment[self.window_size:]
        # Split complex samples into a trailing (real, imag) axis.
        x_tensor = torch.from_numpy(np.stack([x.real, x.imag], axis=-1)).float()
        y_tensor = torch.from_numpy(np.stack([y.real, y.imag], axis=-1)).float()
        return x_tensor, y_tensor, torch.tensor(label)

    def get_task_data(self, label=None):
        """Return index entries, optionally restricted to one label.

        With ``label=None`` the full index list itself is returned (not a
        copy); otherwise a new list with the matching entries is returned.
        """
        return self.index if label is None else [item for item in self.index if item[2] == label]

In [9]:
# Locations of the DS7 chunk directory and its label file (relative to the
# notebook's working directory).
data_dir = "Dataset/DS7"
label_file = "Dataset/ds7_labels.txt"

# Sliding-window dataset: 256-sample inputs, 16-sample targets, stride 32.
dataset = TEXBATSequenceDataset(data_dir, label_file, window_size=256, pred_len=16, step_size=32)

# 尝试取一条看看
# Pull the first sample as a quick sanity check.
x, y, label = dataset[0]
print("x shape:", x.shape)  # expected: [256, 2]
print("y shape:", y.shape)  # expected: [16, 2]
print("label:", label)  # expected: tensor(0) or tensor(1)


FileNotFoundError: [WinError 3] 系统找不到指定的路径。: 'Dataset/DS7'

In [7]:
import torch.nn as nn

class LSTMPredictor(nn.Module):
    """LSTM followed by a small per-step MLP head that emits 2 values per step.

    The head is applied to the LSTM outputs of the last ``output_len`` input
    time steps, so the result has shape ``[batch, output_len, 2]`` (provided
    the input sequence has at least ``output_len`` steps).

    Args:
        input_dim: Number of features per input time step.
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        output_len: How many trailing time steps are turned into predictions.
    """

    def __init__(self, input_dim=2, hidden_dim=64, num_layers=2, output_len=16):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.output_len = output_len

        # Sequence encoder; batch-first layout, i.e. [batch, seq_len, features].
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        # Per-step regression head: hidden -> 64 -> 2 (real part, imaginary part).
        head_layers = [
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
        ]
        self.linear = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map ``x`` of shape ``[batch, seq_len, input_dim]`` to predictions.

        Returns:
            Tensor of shape ``[batch, output_len, 2]``.
        """
        hidden_seq = self.lstm(x)[0]  # [batch, seq_len, hidden_dim]
        tail = hidden_seq[:, -self.output_len:]  # keep only the last N steps
        return self.linear(tail)


In [8]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Model with default hyper-parameters, moved onto the selected device.
model = LSTMPredictor().to(device)

# Mean-squared-error loss, optimised with Adam.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Using device: cuda


In [9]:
# Shuffled mini-batches of 32 windows drawn from the full dataset.
train_loader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
)

In [1]:
model.train()
for i, (x, y, label) in enumerate(train_loader):
    # Move the batch onto the model's device; `label` is not used by this loop.
    x, y = x.to(device), y.to(device)

    # One optimisation step on the prediction loss.
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()

    # Report the loss on every 10th batch.
    if not i % 10:
        print(f"[Batch {i}] Loss: {loss.item():.6f}")

    # Short smoke run to observe GPU memory usage: stop once batch index 30
    # has been processed (i.e. after 31 batches in total).
    if i >= 30:
        break


In [None]:
import random

def sample_support_query(dataset, support_size=5, query_size=15, label=1):
    """Sample a support set and a query set from ``dataset`` for one meta-task.

    Args:
        dataset: Object providing ``get_task_data(label=...)`` (returns the
            candidate index entries) and ``load_sequence_by_index(entry)``
            (loads one sample).
        support_size: Number of samples in the support set.
        query_size: Number of samples in the query set.
        label: Only entries with this label are sampled (the original note
            says the default ``1`` denotes spoofed samples).

    Returns:
        ``(support_set, query_set)`` -- two lists of loaded samples, drawn
        without replacement and therefore disjoint.

    Raises:
        AssertionError: If fewer than ``support_size + query_size`` candidates
            are available for ``label``.
    """
    # All index entries belonging to the requested class.
    candidates = dataset.get_task_data(label=label)
    total_needed = support_size + query_size
    assert len(candidates) >= total_needed, f"可用样本数不足（共 {len(candidates)}）"

    # Draw without replacement, then load in draw order: the first
    # `support_size` draws form the support set, the remainder the query set.
    drawn = random.sample(candidates, total_needed)
    loaded = [dataset.load_sequence_by_index(entry) for entry in drawn]

    return loaded[:support_size], loaded[support_size:]
