In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
import hdf5storage
import matplotlib.pyplot as plt
import torch.optim as optim


##### 数据集预处理


In [42]:
# Location of the demo training set (a .mat file read via hdf5storage).
dataPath = './data/raw/trainDataDemo.mat'
data = hdf5storage.loadmat(dataPath)

# Pull the three raw arrays out of the loaded file.
csi_data, rx_pilot_signal_data, tx_pilot_signal_data = (
    data[field] for field in ('csiData', 'rxPilotSignalData', 'txPilotSignalData')
)

# Convert each raw array into a float32 PyTorch tensor.
csi_tensor, rx_pilot_signal_tensor, tx_pilot_signal_tensor = (
    torch.tensor(raw, dtype=torch.float32)
    for raw in (csi_data, rx_pilot_signal_data, tx_pilot_signal_data)
)

In [43]:
# Inspect the shape of the received-pilot tensor (shown as the cell output).
rx_pilot_signal_tensor.shape

torch.Size([500, 52, 14, 2, 2])

In [44]:
class CSIFormerDataset(Dataset):
    """Dataset of (tx pilot, rx pilot, CSI) samples read from a ``.mat`` file.

    The file is loaded with ``hdf5storage.loadmat`` and must provide three
    arrays, all converted to float32 tensors. Expected layouts:

    * ``txPilotSignalData``: transmitted pilots [data_size, n_subc, n_sym, n_tx, 2]
    * ``rxPilotSignalData``: received pilots    [data_size, n_subc, n_sym, n_rx, 2]
    * ``csiData``:           CSI matrices       [data_size, n_subc, n_sym, n_tx, n_rx, 2]

    The size attributes (``data_size``, ``n_subc``, ``n_sym``, ``n_tx``,
    ``n_rx``, ``n_ch``) are taken from the shape of ``csiData``.
    """

    def __init__(self, dataPath='./data/raw/trainDataDemo.mat'):
        """
        Load the three arrays from ``dataPath`` and keep them as tensors.

        :param dataPath: path to the ``.mat`` file holding ``csiData``,
            ``rxPilotSignalData`` and ``txPilotSignalData``
        :raises ValueError: if ``csiData`` is not 6-dimensional or the three
            arrays do not contain the same number of samples
        """
        data = hdf5storage.loadmat(dataPath)
        self.csi_matrix = torch.tensor(data['csiData'], dtype=torch.float32)
        self.rx_pilot_signal = torch.tensor(data['rxPilotSignalData'], dtype=torch.float32)
        self.tx_pilot_signal = torch.tensor(data['txPilotSignalData'], dtype=torch.float32)
        # Fail at construction time instead of at some later __getitem__ call.
        self._check_shapes()
        (self.data_size, self.n_subc, self.n_sym,
         self.n_tx, self.n_rx, self.n_ch) = self.csi_matrix.shape

    def _check_shapes(self):
        """Raise ValueError unless the loaded tensors can be indexed together per sample."""
        if self.csi_matrix.dim() != 6:
            raise ValueError(
                "csiData must have 6 dimensions "
                "[data_size, n_subc, n_sym, n_tx, n_rx, 2], "
                f"got shape {tuple(self.csi_matrix.shape)}"
            )
        n_samples = self.csi_matrix.shape[0]
        for name, tensor in (('rxPilotSignalData', self.rx_pilot_signal),
                             ('txPilotSignalData', self.tx_pilot_signal)):
            # A 0-d tensor has no sample axis at all, so treat it as a mismatch.
            if tensor.dim() == 0 or tensor.shape[0] != n_samples:
                raise ValueError(
                    f"{name} has shape {tuple(tensor.shape)} but csiData "
                    f"contains {n_samples} samples"
                )

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.data_size

    def __getitem__(self, idx):
        """
        Return a single sample.

        :param idx: sample index
        :return: tuple ``(tx_pilot, rx_pilot, csi)`` for that sample
        """
        tx_pilot = self.tx_pilot_signal[idx]     # [n_subc, n_sym, n_tx, 2]
        rx_pilot = self.rx_pilot_signal[idx]     # [n_subc, n_sym, n_rx, 2]
        csi = self.csi_matrix[idx]               # [n_subc, n_sym, n_tx, n_rx, 2]
        return tx_pilot, rx_pilot, csi

In [45]:
class CSIFormer(nn.Module):
    """Transformer-encoder model that maps tx/rx pilot signals to a CSI estimate.

    Each (subcarrier, symbol) position becomes one token built from the
    concatenated tx and rx pilot values; the encoder output at each token is
    projected to an ``n_tx * n_rx * 2`` CSI entry.

    NOTE(review): no positional encoding is added to the tokens, so the encoder
    receives no explicit information about which subcarrier/symbol a token
    came from — confirm this is intended.
    """

    def __init__(self, d_model = 256, nhead = 2, n_layers = 1, n_tx = 2, n_rx = 2):
        """
        :param d_model: token embedding width used inside the Transformer
        :param nhead: number of attention heads
        :param n_layers: number of Transformer encoder layers
        :param n_tx: number of transmit antennas
        :param n_rx: number of receive antennas
        """
        super().__init__()
        self.d_model = d_model
        self.num_tx = n_tx
        self.num_rx = n_rx

        # Linear layer lifting the concatenated pilot features to d_model
        self.input_proj = nn.Linear(n_tx * 2 + n_rx * 2, d_model)

        # Transformer encoder (sequence-first layout: [seq_len, batch, d_model])
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=2048),
            num_layers=n_layers
        )

        # Output layer predicting the flattened CSI matrix of each token
        self.output_proj = nn.Linear(d_model, n_tx * n_rx * 2)

    def forward(self, tx_pilot_signal, rx_pilot_signal):
        """
        :param tx_pilot_signal: transmitted pilots [batch_size, n_subc, n_sym, n_tx, 2]
        :param rx_pilot_signal: received pilots [batch_size, n_subc, n_sym, n_rx, 2]
        :return: predicted CSI matrix [batch_size, n_subc, n_sym, n_tx, n_rx, 2]
        """
        batch_size, n_subc, n_sym, _, _ = tx_pilot_signal.shape

        # Flatten the (antenna, re/im) axes. reshape is used instead of view so
        # that non-contiguous inputs (e.g. results of permute/transpose) work.
        tx_flat = tx_pilot_signal.reshape(batch_size, n_subc, n_sym, -1)  # [batch_size, n_subc, n_sym, n_tx*2]
        rx_flat = rx_pilot_signal.reshape(batch_size, n_subc, n_sym, -1)  # [batch_size, n_subc, n_sym, n_rx*2]
        tokens = torch.cat([tx_flat, rx_flat], dim=-1)                   # [batch_size, n_subc, n_sym, (n_tx+n_rx)*2]

        # Project the input features to d_model
        tokens = self.input_proj(tokens)  # [batch_size, n_subc, n_sym, d_model]

        # Rearrange to the (seq_len, batch_size, d_model) layout the encoder expects
        tokens = tokens.permute(1, 2, 0, 3)                      # [n_subc, n_sym, batch_size, d_model]
        tokens = tokens.reshape(-1, batch_size, self.d_model)    # [n_subc*n_sym, batch_size, d_model]

        encoded = self.transformer_encoder(tokens)  # [n_subc*n_sym, batch_size, d_model]

        # Map every token to its flattened CSI entry
        output = self.output_proj(encoded)  # [n_subc*n_sym, batch_size, n_tx*n_rx*2]

        # Restore the grid layout and move the batch axis back to the front
        output = output.reshape(n_subc, n_sym, batch_size, self.num_tx, self.num_rx, 2)
        output = output.permute(2, 0, 1, 3, 4, 5)  # [batch_size, n_subc, n_sym, n_tx, n_rx, 2]

        return output

In [46]:
# Custom loss function
class CustomLoss(nn.Module):
    """Mean of the element-wise squared difference between predictions and targets."""

    def __init__(self):
        super().__init__()

    def forward(self, predictions, targets):
        """
        :param predictions: model output tensor
        :param targets: reference tensor (broadcast-compatible with ``predictions``)
        :return: scalar tensor holding the mean squared error
        """
        residual = predictions - targets
        return residual.pow(2).mean()

In [47]:
# Model training
def train_model(model, dataloader, criterion, optimizer, epochs, device):
    """
    Train ``model`` for ``epochs`` passes over ``dataloader``.

    :param model: module called as ``model(tx_pilot_batch, rx_pilot_batch)``
    :param dataloader: sized iterable yielding ``(tx_pilot, rx_pilot, csi)`` batches
    :param criterion: loss callable ``criterion(outputs, csi_batch)``
    :param optimizer: optimizer built over ``model.parameters()``
    :param epochs: number of passes over the data
    :param device: device on which the model and the batches are placed
    :return: None; progress is printed every 10 batches and once per epoch
    """
    # Every batch is moved to `device` below, so the model has to live there as
    # well; otherwise the forward pass fails with a device mismatch on CUDA
    # machines. Module.to() modifies the module in place, so it is a no-op when
    # the model is already on `device`.
    model.to(device)
    model.train()
    num_batches = len(dataloader)
    for epoch in range(epochs):
        total_loss = 0.0
        for batch_idx, (tx_pilot_batch, rx_pilot_batch, csi_batch) in enumerate(dataloader):
            tx_pilot_batch = tx_pilot_batch.to(device)
            rx_pilot_batch = rx_pilot_batch.to(device)
            csi_batch = csi_batch.to(device)

            optimizer.zero_grad()
            outputs = model(tx_pilot_batch, rx_pilot_batch)
            loss = criterion(outputs, csi_batch)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss

            if (batch_idx + 1) % 10 == 0:
                print(f"Epoch {epoch + 1}/{epochs}, Batch {batch_idx + 1}/{num_batches}, Loss: {batch_loss:.4f}")

        # An empty dataloader has no meaningful average; report nan instead of
        # crashing with ZeroDivisionError.
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {mean_loss}")


In [48]:
# Model evaluation
def evaluate_model(model, dataloader, criterion, device):
    """
    Compute and print the mean per-batch loss of ``model`` over ``dataloader``.

    The last element of every batch is used as the target and all preceding
    elements are passed to the model as positional inputs. This covers both
    ``(inputs, targets)`` batches and the ``(tx_pilot, rx_pilot, csi)`` batches
    consumed by ``model(tx_pilot, rx_pilot)``.

    :param model: module to evaluate
    :param dataloader: iterable yielding batches as sequences of tensors
    :param criterion: loss callable ``criterion(outputs, targets)``
    :param device: device on which the model and the batches are placed
    :return: mean loss over the batches (``nan`` if there were no batches)
    """
    # The batches are moved to `device`, so the model must be there too.
    model.to(device)
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            *inputs, targets = (tensor.to(device) for tensor in batch)
            outputs = model(*inputs)
            total_loss += criterion(outputs, targets).item()
            num_batches += 1
    # Avoid ZeroDivisionError when the dataloader yields nothing.
    mean_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Evaluation Loss: {mean_loss}")
    return mean_loss



In [49]:
# Main script: hyperparameters, model/data setup and training
lr = 0.001
epochs = 10
batch_size = 16
# NOTE(review): with shuffle disabled the batches are visited in the same order
# every epoch — confirm this is intended for training.
shuffle_flag = False
seed = 42

# Seed before building the model so weight initialisation is reproducible.
torch.manual_seed(seed)

# Choose the device first and place the model on it before creating the
# optimizer; the training loop moves every batch to this device, so a model
# left on the CPU would fail with a device mismatch when CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CSIFormer().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
dataset = CSIFormerDataset()
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle_flag)

train_model(model, dataloader, criterion, optimizer, epochs, device)
# evaluate_model(model, dataloader, criterion, device)



Epoch 1/10, Batch 10/32, Loss: 0.2502
Epoch 1/10, Batch 20/32, Loss: 0.1936
Epoch 1/10, Batch 30/32, Loss: 0.1784
Epoch 1/10, Loss: 0.3108055526390672
Epoch 2/10, Batch 10/32, Loss: 0.6154
Epoch 2/10, Batch 20/32, Loss: 0.7757
Epoch 2/10, Batch 30/32, Loss: 0.7669
Epoch 2/10, Loss: 0.4762524622492492
Epoch 3/10, Batch 10/32, Loss: 0.3087
Epoch 3/10, Batch 20/32, Loss: 0.4103
Epoch 3/10, Batch 30/32, Loss: 0.4646
Epoch 3/10, Loss: 0.3327607112005353
Epoch 4/10, Batch 10/32, Loss: 0.3646
Epoch 4/10, Batch 20/32, Loss: 0.4381
Epoch 4/10, Batch 30/32, Loss: 0.4002
Epoch 4/10, Loss: 0.3527141995728016
Epoch 5/10, Batch 10/32, Loss: 0.2391
Epoch 5/10, Batch 20/32, Loss: 0.5155
Epoch 5/10, Batch 30/32, Loss: 0.3738
Epoch 5/10, Loss: 0.4616211331449449
Epoch 6/10, Batch 10/32, Loss: 0.4355
Epoch 6/10, Batch 20/32, Loss: 0.3619
Epoch 6/10, Batch 30/32, Loss: 0.3649
Epoch 6/10, Loss: 0.38195937778800726
Epoch 7/10, Batch 10/32, Loss: 0.3735
Epoch 7/10, Batch 20/32, Loss: 0.4189
Epoch 7/10, Batch

In [50]:
# Count the parameters
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad == False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

# Report the number of trainable parameters of `model`.
print(f"Total trainable parameters: {count_parameters(model)}")

Total trainable parameters: 1319432
