In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
import hdf5storage
import matplotlib.pyplot as plt
import torch.optim as optim


##### 数据集预处理


In [10]:
# Read the demo training file once and keep the three raw arrays it contains.
dataPath = './data/raw/trainDataDemo.mat'
data = hdf5storage.loadmat(dataPath)
csi_data, rx_pilot_signal_data, tx_pilot_signal_data = (
    data[key] for key in ('csiData', 'rxPilotSignalData', 'txPilotSignalData')
)
# Convert every raw array into a float32 PyTorch tensor.
csi_tensor, rx_pilot_signal_tensor, tx_pilot_signal_tensor = (
    torch.tensor(raw_array, dtype=torch.float32)
    for raw_array in (csi_data, rx_pilot_signal_data, tx_pilot_signal_data)
)

In [11]:
# Inspect the shape of the received-pilot tensor created in the previous cell.
rx_pilot_signal_tensor.shape

torch.Size([500, 52, 14, 2, 2])

In [12]:
class CSIFormerDataset(Dataset):
    """In-memory dataset of (transmit pilot, receive pilot, CSI) samples.

    The whole ``.mat`` file is read eagerly in the constructor and its three
    arrays are stored as ``float32`` tensors indexed along their first axis.
    """

    def __init__(self, dataPath='./data/raw/trainDataDemo.mat'):
        """
        Load the dataset from a MATLAB file.

        :param dataPath: path of the ``.mat`` file; it must contain the keys
            ``'csiData'``, ``'rxPilotSignalData'`` and ``'txPilotSignalData'``.

        The CSI array must have exactly six axes, which are recorded as
        ``data_size, n_subc, n_sym, n_tx, n_rx, n_ch``. The pilot arrays are
        expected to be laid out as ``[data_size, n_subc, n_sym, n_tx|n_rx, 2]``.
        """
        mat_contents = hdf5storage.loadmat(dataPath)

        def as_float_tensor(key):
            # Every array is stored in single precision regardless of its dtype on disk.
            return torch.tensor(mat_contents[key], dtype=torch.float32)

        self.csi_matrix = as_float_tensor('csiData')
        self.rx_pilot_signal = as_float_tensor('rxPilotSignalData')
        self.tx_pilot_signal = as_float_tensor('txPilotSignalData')

        # Unpacking fails loudly if the CSI array is not 6-dimensional.
        (self.data_size, self.n_subc, self.n_sym,
         self.n_tx, self.n_rx, self.n_ch) = self.csi_matrix.shape

    def __len__(self):
        """Number of samples (size of the first CSI axis)."""
        return self.data_size

    def __getitem__(self, idx):
        """
        Fetch one sample.

        :param idx: sample index along the first axis
        :return: tuple ``(tx_pilot, rx_pilot, csi)`` where
            ``tx_pilot`` is ``[n_subc, n_sym, n_tx, 2]``,
            ``rx_pilot`` is ``[n_subc, n_sym, n_rx, 2]`` and
            ``csi`` is ``[n_subc, n_sym, n_tx, n_rx, 2]``
        """
        return (
            self.tx_pilot_signal[idx],
            self.rx_pilot_signal[idx],
            self.csi_matrix[idx],
        )

In [13]:
###############################################################################
# 第一部分：CSIFormer (编码器)
###############################################################################
class CSIEncoder(nn.Module):
    """Transformer encoder that maps pilot signals to a coarse CSI estimate.

    Every (subcarrier, symbol) position becomes one token whose features are
    the concatenated real/imaginary parts of the transmit and receive pilots.
    """

    def __init__(self, d_model=256, nhead=2, n_layers=1, n_tx=2, n_rx=2,
                 dim_feedforward=2048):
        """
        :param d_model: token embedding width used inside the Transformer
        :param nhead: number of attention heads
        :param n_layers: number of stacked Transformer encoder layers
        :param n_tx: number of transmit antennas
        :param n_rx: number of receive antennas
        :param dim_feedforward: hidden width of each layer's feed-forward
            sub-network (defaults to the previously hard-coded 2048)
        """
        super(CSIEncoder, self).__init__()
        self.d_model = d_model
        self.num_tx = n_tx
        self.num_rx = n_rx

        # Project the per-position pilot features (real + imaginary parts of
        # all tx and rx antennas) to the model width.
        self.input_proj = nn.Linear(n_tx * 2 + n_rx * 2, d_model)

        # Transformer encoder operating on [B, seq_len, d_model] tensors.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                batch_first=True
            ),
            num_layers=n_layers
        )

        # Output head: one complex coefficient (2 reals) per tx/rx antenna pair.
        self.output_proj = nn.Linear(d_model, n_tx * n_rx * 2)

    def forward(self, tx_pilot_signal, rx_pilot_signal):
        """
        :param tx_pilot_signal: [B, n_subc, n_sym, n_tx, 2]
        :param rx_pilot_signal: [B, n_subc, n_sym, n_rx, 2]
        :return: coarse CSI estimate [B, n_subc, n_sym, n_tx, n_rx, 2]
        """
        batch_size, n_subc, n_sym, _, _ = tx_pilot_signal.shape

        # Merge the antenna and real/imag axes. reshape() is used instead of
        # view() so that non-contiguous inputs (e.g. produced by permute or
        # transpose) are accepted instead of raising a RuntimeError.
        tx_features = tx_pilot_signal.reshape(batch_size, n_subc, n_sym, -1)
        rx_features = rx_pilot_signal.reshape(batch_size, n_subc, n_sym, -1)

        # Per-position input features: [B, n_subc, n_sym, (n_tx + n_rx) * 2]
        input_features = torch.cat([tx_features, rx_features], dim=-1)

        # Embed to the model width: [B, n_subc, n_sym, d_model]
        input_features = self.input_proj(input_features)

        # Fold (n_subc, n_sym) into a single sequence axis, batch stays first:
        # [B, n_subc * n_sym, d_model]
        seq_len = n_subc * n_sym
        input_features = input_features.reshape(batch_size, seq_len, self.d_model)

        # Self-attention over all positions; shape is preserved.
        output = self.transformer_encoder(input_features)

        # Per-token CSI coefficients: [B, seq_len, n_tx * n_rx * 2]
        output = self.output_proj(output)

        # Unfold back to [B, n_subc, n_sym, n_tx, n_rx, 2]
        output = output.reshape(batch_size, n_subc, n_sym, self.num_tx, self.num_rx, 2)

        return output

In [14]:
# 模型训练
def train_model(model, dataloader, criterion, optimizer, epochs, device):
    """Run a plain supervised training loop and print progress.

    :param model: module called as ``model(tx_pilot, rx_pilot)``
    :param dataloader: iterable of ``(tx_pilot, rx_pilot, csi)`` batches that
        also supports ``len()``
    :param criterion: loss called as ``criterion(prediction, csi)``
    :param optimizer: optimizer already bound to the model parameters
    :param epochs: number of passes over ``dataloader``
    :param device: device every batch tensor is moved to (the model itself is
        not moved here)
    :return: None; the per-epoch mean batch loss is printed
    """
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            # Each batch must hold exactly three tensors: tx pilot, rx pilot, target CSI.
            tx_pilot, rx_pilot, csi_target = (tensor.to(device) for tensor in batch)

            optimizer.zero_grad()
            loss = criterion(model(tx_pilot, rx_pilot), csi_target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # Report the current batch loss every 10th batch.
            is_report_step = (batch_idx + 1) % 10 == 0
            if is_report_step:
                print(f"Epoch {epoch + 1}/{epochs}, Batch {batch_idx + 1}/{len(dataloader)}, Loss: {loss.item():.4f}")

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader)}")


In [15]:
# 模型评估
def evaluate_model(model, dataloader, criterion, device):
    """Compute and print the mean batch loss of ``model`` over ``dataloader``.

    Each batch is a sequence of tensors whose last element is the target and
    whose preceding elements are passed positionally to the model. This covers
    the ``(tx_pilot, rx_pilot, csi)`` batches consumed by ``train_model`` as
    well as plain ``(inputs, targets)`` pairs.

    :param model: module to evaluate; it is switched to eval mode and left there
    :param dataloader: iterable of batches that also supports ``len()``
    :param criterion: loss called as ``criterion(outputs, targets)``
    :param device: device every batch tensor is moved to (the model itself is
        not moved here)
    :return: the mean of the per-batch losses (also printed)
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            # Last tensor is the target; everything before it feeds the model.
            *inputs, targets = (tensor.to(device) for tensor in batch)
            outputs = model(*inputs)
            total_loss += criterion(outputs, targets).item()
    average_loss = total_loss / len(dataloader)
    print(f"Evaluation Loss: {average_loss}")
    return average_loss



In [None]:
# Main script: configure the experiment and run training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 0.001
epochs = 10
batch_size = 16
shuffle_flag = False
# train_model moves only the batches to `device`, so the model must live there
# too; otherwise the forward pass fails with a device mismatch on CUDA machines.
# The move happens before the optimizer is built, as the torch.optim docs advise.
model = CSIEncoder().to(device)
dataset = CSIFormerDataset()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle_flag)

train_model(model, dataloader, criterion, optimizer, epochs, device)
# evaluate_model(model, dataloader, criterion, device)

In [17]:
# 计算参数量
def count_parameters(model):
    """Return the total number of scalar elements in the trainable parameters.

    Only parameters with ``requires_grad`` set are counted; frozen ones are skipped.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Report how many trainable parameters the current `model` has.
print(f"Total trainable parameters: {count_parameters(model)}")

Total trainable parameters: 1319432
