In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
import hdf5storage
import matplotlib.pyplot as plt
import torch.optim as optim


##### 数据集预处理 / Dataset preprocessing


In [2]:
dataPath = './data/raw/trainDataTrain.mat'
data = hdf5storage.loadmat(dataPath)

# Pull the three raw arrays out of the loaded .mat dictionary (same lookup order as before).
csi_data, rx_pilot_signal_data, tx_pilot_signal_data = (
    data[mat_key] for mat_key in ('csiData', 'rxPilotSignalData', 'txPilotSignalData')
)

# Convert each array into a float32 PyTorch tensor (torch.tensor always copies).
csi_matrix, rx_pilot_signal, tx_pilot_signal = (
    torch.tensor(raw_array, dtype=torch.float32)
    for raw_array in (csi_data, rx_pilot_signal_data, tx_pilot_signal_data)
)

In [3]:
class CSIFormerDataset(Dataset):
    """Map-style dataset that pairs pilot signals with their CSI matrices.

    Sample ``idx`` is the ``idx``-th slice along dim 0 of each of the three
    tensors, returned as ``(tx_pilot, rx_pilot, csi)``.
    """

    def __init__(self, csi_matrix, rx_pilot_signal, tx_pilot_signal):
        """
        Initialise the dataset.

        :param csi_matrix: CSI matrix [data_size, n_subc, n_sym, n_tx, n_rx, 2]
        :param rx_pilot_signal: received pilots, expected [data_size, n_subc, n_sym, n_rx, 2]
        :param tx_pilot_signal: transmitted pilots, expected [data_size, n_subc, n_sym, n_tx, 2]
        :raises ValueError: if ``csi_matrix`` is not 6-D, or if the three tensors
            do not contain the same number of samples along dim 0
        """
        self.csi_matrix = csi_matrix
        self.rx_pilot_signal = rx_pilot_signal
        self.tx_pilot_signal = tx_pilot_signal
        # Unpacking enforces a 6-D CSI tensor and records every axis size.
        (self.data_size, self.n_subc, self.n_sym,
         self.n_tx, self.n_rx, self.n_ch) = self.csi_matrix.shape

        # Without this check a shorter pilot tensor raises IndexError mid-epoch
        # and a longer one is silently truncated to data_size samples.
        n_rx_samples = rx_pilot_signal.shape[0]
        n_tx_samples = tx_pilot_signal.shape[0]
        if n_rx_samples != self.data_size or n_tx_samples != self.data_size:
            raise ValueError(
                "Sample count mismatch along dim 0: "
                f"csi_matrix={self.data_size}, rx_pilot_signal={n_rx_samples}, "
                f"tx_pilot_signal={n_tx_samples}"
            )

    def __len__(self):
        """Return the number of samples (size of dim 0 of ``csi_matrix``)."""
        return self.data_size

    def __getitem__(self, idx):
        """
        Return a single sample.

        :param idx: sample index
        :return: ``(tx_pilot, rx_pilot, csi)``
        """
        tx_pilot = self.tx_pilot_signal[idx]     # [n_subc, n_sym, n_tx, 2]
        rx_pilot = self.rx_pilot_signal[idx]     # [n_subc, n_sym, n_rx, 2]
        csi = self.csi_matrix[idx]               # [n_subc, n_sym, n_tx, n_rx, 2]
        return tx_pilot, rx_pilot, csi

In [4]:
###############################################################################
# 第一部分：CSIFormer (编码器)
###############################################################################
class CSIEncoder(nn.Module):
    """Transformer encoder that maps pilot signals to an initial CSI estimate.

    Every (subcarrier, symbol) position becomes one token whose features are the
    flattened transmit and receive pilots. A stack of Transformer encoder layers
    mixes the tokens and a linear head predicts the CSI at every position.
    """

    def __init__(self, d_model=256, nhead=2, n_layers=4, n_tx=2, n_rx=2,
                 dim_feedforward=2048, dropout=0.1):
        """
        :param d_model: hidden (token embedding) dimension of the Transformer
        :param nhead: number of attention heads; must divide ``d_model``
        :param n_layers: number of Transformer encoder layers
        :param n_tx: number of transmit antennas
        :param n_rx: number of receive antennas
        :param dim_feedforward: width of each encoder layer's feed-forward block
        :param dropout: dropout probability inside each encoder layer
        """
        super().__init__()
        self.d_model = d_model
        self.num_tx = n_tx
        self.num_rx = n_rx

        # Token input = flattened tx pilots (n_tx * 2) + flattened rx pilots (n_rx * 2).
        self.input_proj = nn.Linear(n_tx * 2 + n_rx * 2, d_model)

        # Tokens are laid out as [B, seq_len, d_model], hence batch_first=True.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=n_layers,
        )

        # Token output = one value per (tx, rx, trailing size-2 axis) entry of the CSI.
        self.output_proj = nn.Linear(d_model, n_tx * n_rx * 2)

    def forward(self, tx_pilot_signal, rx_pilot_signal):
        """
        :param tx_pilot_signal: [B, n_subc, n_sym, n_tx, 2]
        :param rx_pilot_signal: [B, n_subc, n_sym, n_rx, 2]
        :return: initial CSI estimate [B, n_subc, n_sym, n_tx, n_rx, 2]
        """
        batch_size, n_subc, n_sym, _, _ = tx_pilot_signal.shape

        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, also work.
        tx_flat = tx_pilot_signal.reshape(batch_size, n_subc, n_sym, -1)
        rx_flat = rx_pilot_signal.reshape(batch_size, n_subc, n_sym, -1)
        input_features = torch.cat([tx_flat, rx_flat], dim=-1)

        # Project to d_model: [B, n_subc, n_sym, d_model]
        input_features = self.input_proj(input_features)

        # Fold (n_subc, n_sym) into one sequence axis: [B, n_subc * n_sym, d_model]
        seq_len = n_subc * n_sym
        input_features = input_features.reshape(batch_size, seq_len, self.d_model)

        # NOTE(review): no positional encoding is added, so the encoder cannot tell
        # tokens apart by (subcarrier, symbol) position — consider adding one.
        output = self.transformer_encoder(input_features)

        # Project to the CSI entries of every token: [B, seq_len, n_tx * n_rx * 2]
        output = self.output_proj(output)

        # Restore the grid layout: [B, n_subc, n_sym, n_tx, n_rx, 2]
        return output.reshape(batch_size, n_subc, n_sym, self.num_tx, self.num_rx, 2)

In [5]:
# Model training
def _train_one_epoch(model, dataloader, criterion, optimizer, device, epoch, epochs):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    n_batches = len(dataloader)
    for batch_idx, (tx_pilot_batch, rx_pilot_batch, csi_batch) in enumerate(dataloader):
        tx_pilot_batch = tx_pilot_batch.to(device)
        rx_pilot_batch = rx_pilot_batch.to(device)
        csi_batch = csi_batch.to(device)
        optimizer.zero_grad()
        outputs = model(tx_pilot_batch, rx_pilot_batch)
        loss = criterion(outputs, csi_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if (batch_idx + 1) % 100 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Batch {batch_idx + 1}/{n_batches}, Loss: {loss.item():.4f}")
    return total_loss / n_batches


def _evaluate_loss(model, dataloader, criterion, device):
    """Return the mean batch loss of ``model`` over ``dataloader`` with gradients disabled."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for tx_pilot_batch, rx_pilot_batch, csi_batch in dataloader:
            outputs = model(tx_pilot_batch.to(device), rx_pilot_batch.to(device))
            total_loss += criterion(outputs, csi_batch.to(device)).item()
    return total_loss / len(dataloader)


def _save_checkpoint(path, epoch, model, optimizer, scheduler, best_loss):
    """Write model/optimizer/scheduler state and bookkeeping values to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
        'best_loss': best_loss,
    }, path)


def train_model(model, dataloader, criterion, optimizer, scheduler, epochs, device,
                checkpoint_dir='./checkpoints', val_dataloader=None):
    """Train ``model`` for ``epochs`` epochs, resuming from ``checkpoint_dir/latest.pth`` if present.

    Batches are ``(tx_pilot, rx_pilot, csi)`` triples; the model is called as
    ``model(tx_pilot, rx_pilot)`` and the loss as ``criterion(outputs, csi)``.
    After every epoch the full training state is written to ``latest.pth``.

    :param model: module to train (moved to ``device``)
    :param dataloader: training loader yielding ``(tx_pilot, rx_pilot, csi)``
    :param criterion: loss function ``criterion(outputs, targets)``
    :param optimizer: optimizer over ``model``'s parameters
    :param scheduler: LR scheduler or None. ``ReduceLROnPlateau`` is stepped with the
        epoch's validation loss if ``val_dataloader`` is given, otherwise with the
        training loss; any other scheduler is stepped once per epoch without arguments.
    :param epochs: total number of epochs; a resumed run continues up to this count
    :param device: device used for the model and batches
    :param checkpoint_dir: directory for ``latest.pth`` (and ``best.pth``)
    :param val_dataloader: optional validation loader with the same batch layout;
        when given, the best validation loss is tracked and saved to ``best.pth``
    :raises ValueError: if ``dataloader`` (or ``val_dataloader``) yields no batches
    """
    if len(dataloader) == 0:
        raise ValueError("dataloader yields no batches")
    if val_dataloader is not None and len(val_dataloader) == 0:
        raise ValueError("val_dataloader yields no batches")

    model.to(device)
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_loss = float('inf')
    start_epoch = 0

    # Resume from the most recent checkpoint if one exists.
    resume_path = os.path.join(checkpoint_dir, 'latest.pth')
    if os.path.isfile(resume_path):
        print(f"[INFO] Resuming training from '{resume_path}'")
        checkpoint = torch.load(resume_path, map_location=device)

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # A checkpoint saved without a scheduler stores None under this key.
        scheduler_state = checkpoint.get('scheduler_state_dict')
        if scheduler is not None and scheduler_state is not None:
            scheduler.load_state_dict(scheduler_state)
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint.get('best_loss', best_loss)
        print(f"[INFO] Resumed epoch {start_epoch}, best_loss={best_loss:.6f}")

    # Only ReduceLROnPlateau takes a metric; passing one to other schedulers is
    # interpreted as an epoch number.
    is_plateau = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    # Start at start_epoch so a resumed run continues instead of restarting at 0.
    for epoch in range(start_epoch, epochs):
        print(f"\nEpoch [{epoch}/{epochs - 1}]")
        train_loss = _train_one_epoch(model, dataloader, criterion, optimizer, device, epoch, epochs)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {train_loss}")

        val_loss = None
        if val_dataloader is not None:
            val_loss = _evaluate_loss(model, val_dataloader, criterion, device)
            print(f"Val   Loss: {val_loss:.4f}")

        if scheduler is not None:
            if is_plateau:
                scheduler.step(train_loss if val_loss is None else val_loss)
            else:
                scheduler.step()

        # Save best.pth whenever the validation loss improves.
        if val_loss is not None and val_loss < best_loss:
            best_loss = val_loss
            _save_checkpoint(os.path.join(checkpoint_dir, 'best.pth'),
                             epoch, model, optimizer, scheduler, best_loss)
            print(f"[INFO] Best model saved at epoch {epoch}, val_loss={val_loss:.4f}")

        # Always save the latest state so training can be resumed.
        _save_checkpoint(resume_path, epoch, model, optimizer, scheduler, best_loss)


In [6]:
# Model evaluation
def evaluate_model(model, dataloader, criterion, device):
    """Compute, print and return the mean batch loss of ``model`` over ``dataloader``.

    Each batch is unpacked as ``(*inputs, targets)``: every element except the
    last is moved to ``device`` and passed positionally to ``model``, and the last
    element is the target. This supports plain ``(inputs, targets)`` loaders as
    well as the ``(tx_pilot, rx_pilot, csi)`` batches of ``CSIFormerDataset``
    consumed by the two-argument ``CSIEncoder.forward``.

    :param model: module to evaluate (switched to eval mode)
    :param dataloader: loader yielding sequences whose last element is the target
    :param criterion: loss function ``criterion(outputs, targets)``
    :param device: device the batches are moved to
    :return: mean loss per batch (float)
    :raises ValueError: if ``dataloader`` yields no batches
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for *inputs, targets in dataloader:
            inputs = [tensor.to(device) for tensor in inputs]
            outputs = model(*inputs)
            loss = criterion(outputs, targets.to(device))
            total_loss += loss.item()
            n_batches += 1
    if n_batches == 0:
        raise ValueError("dataloader yields no batches")
    mean_loss = total_loss / n_batches
    print(f"Evaluation Loss: {mean_loss}")
    return mean_loss



In [8]:
# Main: build the model, data pipeline and optimiser, then train
SEED = 42
torch.manual_seed(SEED)  # reproducible weight initialisation across kernel restarts

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 1e-3
epochs = 10
batch_size = 64
shuffle_flag = False  # NOTE(review): training batches keep file order; confirm this is intentional
model = CSIEncoder()
dataset = CSIFormerDataset(csi_matrix, rx_pilot_signal, tx_pilot_signal)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
# NOTE(review): ReduceLROnPlateau defaults to patience=10, so with epochs=10 the
# learning rate can never actually be reduced in a single run.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle_flag)

train_model(model, dataloader, criterion, optimizer, scheduler, epochs, device)
# evaluate_model(model, dataloader, criterion, device)





Epoch [0/9]
Epoch 1/10, Batch 100/782, Loss: 0.3678
Epoch 1/10, Batch 200/782, Loss: 0.2991
Epoch 1/10, Batch 300/782, Loss: 0.2858
Epoch 1/10, Batch 400/782, Loss: 0.4123
Epoch 1/10, Batch 500/782, Loss: 0.1926
Epoch 1/10, Batch 600/782, Loss: 0.2521
Epoch 1/10, Batch 700/782, Loss: 0.2611
Epoch 1/10, Loss: 0.3150776035492987

Epoch [1/9]
Epoch 2/10, Batch 100/782, Loss: 0.3450
Epoch 2/10, Batch 200/782, Loss: 0.3450
Epoch 2/10, Batch 300/782, Loss: 0.2318
Epoch 2/10, Batch 400/782, Loss: 0.3277
Epoch 2/10, Batch 500/782, Loss: 0.1995
Epoch 2/10, Batch 600/782, Loss: 0.1907
Epoch 2/10, Batch 700/782, Loss: 0.2590
Epoch 2/10, Loss: 0.26127316113894855

Epoch [2/9]
Epoch 3/10, Batch 100/782, Loss: 0.3323
Epoch 3/10, Batch 200/782, Loss: 0.3063
Epoch 3/10, Batch 300/782, Loss: 0.2132
Epoch 3/10, Batch 400/782, Loss: 0.3006
Epoch 3/10, Batch 500/782, Loss: 0.2015
Epoch 3/10, Batch 600/782, Loss: 0.1696
Epoch 3/10, Batch 700/782, Loss: 0.2375
Epoch 3/10, Loss: 0.24945879154517064

Epoch [

In [25]:
# Count trainable parameters
def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s parameters that require gradients."""
    n_trainable = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        n_trainable += param.numel()
    return n_trainable

n_trainable_params = count_parameters(model)  # trainable parameter total of the model trained above
print(f"Total trainable parameters: {n_trainable_params}")

Total trainable parameters: 5264648


: 

: 