In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
import hdf5storage
import matplotlib.pyplot as plt
import torch.optim as optim
import gc


##### 数据集预处理 (Dataset preprocessing)


In [2]:
# Load the raw training set from the MATLAB .mat file. dataset_preprocess later
# reads the 'csiTrainData', 'rxSignalTrainData' and 'txSignalTrainData' entries.
train_mat_path = 'data/raw/trainData.mat'
data_train = hdf5storage.loadmat(train_mat_path)

In [3]:
class CSIFormerDataset(Dataset):
    """Map-style dataset yielding pilot-masked Tx/Rx signals and the CSI label.

    Each item is ``(tx_pilot, rx_pilot, csi_label)``. The pilot tensors are the
    per-sample transmitted / received signals multiplied element-wise by a fixed
    mask, so only the masked (pilot) positions are non-zero.
    """

    def __init__(self, tx_signal, rx_signal, csi, tx_pilot_mask, rx_pilot_mask):
        """
        Initialise the dataset.

        :param tx_signal: transmitted signal [data_size, n_subc, n_sym, n_tx, 2]
        :param rx_signal: received signal [data_size, n_subc, n_sym, n_rx, 2]
        :param csi: CSI matrix [data_size, n_subc, n_sym, n_tx, n_rx, 2]
        :param tx_pilot_mask: mask multiplied onto one tx sample (shape of a
            single ``tx_signal[i]``); dataset_preprocess builds it with 1 at
            pilot positions and 0 elsewhere
        :param rx_pilot_mask: same as ``tx_pilot_mask`` but for one rx sample
        :raises ValueError: if the three data tensors disagree on the number of samples
        """
        n_tx_samples, n_rx_samples, n_csi_samples = tx_signal.shape[0], rx_signal.shape[0], csi.shape[0]
        if not (n_tx_samples == n_rx_samples == n_csi_samples):
            raise ValueError(
                f"sample-count mismatch: tx={n_tx_samples}, rx={n_rx_samples}, csi={n_csi_samples}"
            )
        self.tx_signal = tx_signal
        self.rx_signal = rx_signal
        self.csi = csi
        self.tx_pilot_mask = tx_pilot_mask
        self.rx_pilot_mask = rx_pilot_mask

    def __len__(self):
        """Return the number of samples (leading dimension of ``csi``)."""
        return self.csi.size(0)

    def __getitem__(self, idx):
        """
        Return one sample.

        :param idx: sample index
        :return: (tx_pilot, rx_pilot, csi_label)
        """
        tx_pilot = self.tx_signal[idx] * self.tx_pilot_mask    # [n_subc, n_sym, n_tx, 2]
        # Received pilots must come from rx_signal (previously tx_signal was
        # masked twice, so the model never saw the received data).
        rx_pilot = self.rx_signal[idx] * self.rx_pilot_mask    # [n_subc, n_sym, n_rx, 2]
        csi_label = self.csi[idx]                              # [n_subc, n_sym, n_tx, n_rx, 2]

        return tx_pilot, rx_pilot, csi_label

In [4]:
def dataset_preprocess(data, max_samples=10000, pilot_indices=None):
    """Convert the raw training dict into a CSIFormerDataset with pilot masks.

    :param data: mapping with keys 'csiTrainData', 'rxSignalTrainData' and
        'txSignalTrainData' whose values are array-likes indexed by sample first.
    :param max_samples: keep only the first ``max_samples`` samples (default
        10000, the previous hard-coded cap); ``None`` keeps every sample.
    :param pilot_indices: indices along axis 0 of one sample (n_subc according
        to the shape comments) that carry pilots. Defaults to
        ``[7, 8, 26, 27, 40, 41, 57, 58] - 7``.
    :return: CSIFormerDataset over the selected samples.
    :raises ValueError: if the selection contains no samples.
    """
    # slice(None, None) selects everything, so max_samples=None means "all samples".
    keep = slice(None, max_samples)
    csi = torch.tensor(data['csiTrainData'][keep], dtype=torch.float32)             # [data_size, n_subc, n_sym, n_tx, n_rx, 2]
    rx_signal = torch.tensor(data['rxSignalTrainData'][keep], dtype=torch.float32)  # [data_size, n_subc, n_sym, n_rx, 2]
    tx_signal = torch.tensor(data['txSignalTrainData'][keep], dtype=torch.float32)  # [data_size, n_subc, n_sym, n_tx, 2]
    # NOTE: this only drops the function-local reference; the caller's dict
    # (e.g. data_train) still keeps the raw arrays alive until it is deleted too.
    del data
    gc.collect()

    if tx_signal.shape[0] == 0:
        raise ValueError("no samples selected from the training data")

    if pilot_indices is None:
        # The hard-coded positions are shifted by 7; presumably the stored data
        # starts at index 7 of the original grid. TODO confirm against the generator.
        pilot_indices = torch.tensor([7, 8, 26, 27, 40, 41, 57, 58]) - 7
    else:
        pilot_indices = torch.as_tensor(pilot_indices, dtype=torch.long)

    # Masks have the shape of a single sample: 1 at pilot rows, 0 elsewhere.
    tx_pilot_mask = torch.zeros_like(tx_signal[0])
    rx_pilot_mask = torch.zeros_like(rx_signal[0])
    tx_pilot_mask[pilot_indices] = 1
    rx_pilot_mask[pilot_indices] = 1
    return CSIFormerDataset(tx_signal, rx_signal, csi, tx_pilot_mask, rx_pilot_mask)


In [5]:
###############################################################################
# Part 1: CSIFormer (encoder)
###############################################################################
class CSIEncoder(nn.Module):
    """Transformer-encoder CSI estimator.

    Every (subcarrier, symbol) position becomes one token whose features are the
    concatenated real/imag Tx and Rx pilot values. A Transformer encoder mixes
    the tokens, and a linear head predicts the n_tx x n_rx (real, imag) CSI at
    each position.

    NOTE(review): no positional encoding is applied, so the encoder cannot tell
    tokens apart by subcarrier/symbol position, only by their content.
    """

    def __init__(self, d_model=256, nhead=2, n_layers=4, n_tx=2, n_rx=2,
                 dim_feedforward=2048, dropout=0.1):
        """
        :param d_model: token embedding dimension
        :param nhead: number of attention heads
        :param n_layers: number of Transformer encoder layers
        :param n_tx: number of transmit antennas
        :param n_rx: number of receive antennas
        :param dim_feedforward: hidden size of each layer's feed-forward block
        :param dropout: dropout probability inside the encoder layers
        """
        super(CSIEncoder, self).__init__()
        self.d_model = d_model
        self.num_tx = n_tx
        self.num_rx = n_rx

        # Per-token projection of the (n_tx + n_rx) * 2 pilot features to d_model.
        self.input_proj = nn.Linear(n_tx * 2 + n_rx * 2, d_model)

        # Transformer encoder (batch_first=True => [B, seq_len, d_model]).
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=n_layers,
        )

        # Output head: the n_tx * n_rx complex CSI entries per token.
        self.output_proj = nn.Linear(d_model, n_tx * n_rx * 2)

    def forward(self, tx_pilot_signal, rx_pilot_signal):
        """
        :param tx_pilot_signal: [B, n_subc, n_sym, n_tx, 2]
        :param rx_pilot_signal: [B, n_subc, n_sym, n_rx, 2]
        :return: estimated CSI [B, n_subc, n_sym, n_tx, n_rx, 2]
        :raises ValueError: if the antenna dimensions do not match the model
        """
        batch_size, n_subc, n_sym, n_tx_in, _ = tx_pilot_signal.shape
        n_rx_in = rx_pilot_signal.shape[3]
        # A swapped tx/rx antenna count would otherwise pass input_proj silently.
        if n_tx_in != self.num_tx or n_rx_in != self.num_rx:
            raise ValueError(
                f"expected n_tx={self.num_tx}, n_rx={self.num_rx}; got n_tx={n_tx_in}, n_rx={n_rx_in}"
            )

        # reshape (not view) so non-contiguous inputs are accepted as well.
        tx_features = tx_pilot_signal.reshape(batch_size, n_subc, n_sym, -1)
        rx_features = rx_pilot_signal.reshape(batch_size, n_subc, n_sym, -1)
        input_features = torch.cat([tx_features, rx_features], dim=-1)  # [B, n_subc, n_sym, (n_tx+n_rx)*2]

        input_features = self.input_proj(input_features)  # [B, n_subc, n_sym, d_model]

        # Fold (n_subc, n_sym) into one sequence axis: [B, n_subc*n_sym, d_model].
        seq_len = n_subc * n_sym
        input_features = input_features.reshape(batch_size, seq_len, self.d_model)

        output = self.transformer_encoder(input_features)  # [B, seq_len, d_model]
        output = self.output_proj(output)                   # [B, seq_len, n_tx*n_rx*2]

        return output.reshape(batch_size, n_subc, n_sym, self.num_tx, self.num_rx, 2)

In [10]:
# Model training
def _checkpoint_payload(epoch, model, optimizer, scheduler, best_loss):
    """Bundle everything needed to resume training into one dict for torch.save."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
        'best_loss': best_loss,
    }


def train_model(model, dataloader_train, criterion, optimizer, scheduler, epochs, device, checkpoint_dir='./checkpoints'):
    """Train ``model`` for ``epochs`` epochs with checkpointing and resume support.

    Each batch is ``(tx_pilot, rx_pilot, csi_label)``; the model is called as
    ``model(tx_pilot, rx_pilot)``. Under ``checkpoint_dir`` it writes:
      * ``<Class>_latest.pth`` every epoch (training resumes from it if present),
      * ``<Class>_best.pth`` whenever the mean training loss improves,
      * ``<Class>_epoch_<k>.pth`` every 5th epoch (k is the 0-based epoch index).

    :param scheduler: optional LR scheduler; ReduceLROnPlateau is stepped with
        the epoch's mean training loss, any other scheduler with ``step()``.
    :raises ValueError: if ``dataloader_train`` yields no batches.
    """
    n_batches = len(dataloader_train)
    if n_batches == 0:
        raise ValueError("dataloader_train is empty; nothing to train on")

    os.makedirs(checkpoint_dir, exist_ok=True)
    best_loss = float('inf')
    start_epoch = 0
    model.to(device)
    prefix = model.__class__.__name__
    latest_path = os.path.join(checkpoint_dir, prefix + '_latest.pth')
    best_path = os.path.join(checkpoint_dir, prefix + '_best.pth')

    # Resume from the most recent checkpoint if one exists.
    if os.path.isfile(latest_path):
        print(f"[INFO] Resuming training from '{latest_path}'")
        checkpoint = torch.load(latest_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None and checkpoint.get('scheduler_state_dict') is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint.get('best_loss', best_loss)
        print(f"[INFO] Resumed epoch {start_epoch}, best_loss={best_loss:.6f}")

    for epoch in range(start_epoch, epochs):
        print(f"\nEpoch [{epoch + 1}/{epochs}]")
        # --------------------- Train ---------------------
        model.train()
        total_loss = 0.0
        for batch_idx, (tx_pilot_train, rx_pilot_train, csi_label) in enumerate(dataloader_train):
            tx_pilot_train = tx_pilot_train.to(device)
            rx_pilot_train = rx_pilot_train.to(device)
            csi_label = csi_label.to(device)

            optimizer.zero_grad()
            csi_enc = model(tx_pilot_train, rx_pilot_train)
            loss = criterion(csi_enc, csi_label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            if (batch_idx + 1) % 50 == 0:
                print(f"Epoch {epoch + 1}/{epochs}, Batch {batch_idx + 1}/{n_batches}, Loss: {loss.item():.4f}")

        train_loss = total_loss / n_batches

        # Only plateau schedulers take a metric; passing the loss to others
        # would be misread as an epoch number.
        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(train_loss)
            else:
                scheduler.step()

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {train_loss}")

        # Track the best (lowest) mean training loss. No validation set is
        # used here, so "best" refers to training loss.
        improved = train_loss < best_loss
        if improved:
            best_loss = train_loss

        # --------------------- Checkpoints ---------------------
        checkpoint = _checkpoint_payload(epoch, model, optimizer, scheduler, best_loss)
        # 1) Latest checkpoint (enables resuming).
        torch.save(checkpoint, latest_path)
        # 2) Best checkpoint so far.
        if improved:
            torch.save(checkpoint, best_path)
            print(f"[INFO] New best loss {best_loss:.6f}, saved '{best_path}'")
        # 3) Periodic snapshot every 5 epochs (file name keeps the 0-based index).
        if (epoch + 1) % 5 == 0:
            torch.save(checkpoint, os.path.join(checkpoint_dir, prefix + '_epoch_' + str(epoch) + '.pth'))


In [7]:
# Model evaluation
def evaluate_model(model, dataloader, criterion, device):
    """Compute, print and return the mean per-batch loss of ``model`` on ``dataloader``.

    Each batch is split as ``(*model_inputs, target)``: every element except
    the last is passed positionally to the model, and the last is the target.
    This covers both ``(inputs, targets)`` loaders and this notebook's
    ``(tx_pilot, rx_pilot, csi_label)`` loaders. The model is assumed to
    already be on ``device``.

    :return: mean loss over batches (float)
    :raises ValueError: if ``dataloader`` yields no batches
    """
    n_batches = len(dataloader)
    if n_batches == 0:
        raise ValueError("dataloader is empty; nothing to evaluate")

    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            *inputs, targets = (tensor.to(device) for tensor in batch)
            outputs = model(*inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item()

    mean_loss = total_loss / n_batches
    print(f"Evaluation Loss: {mean_loss}")
    return mean_loss



In [8]:
# Main configuration
# Seed torch so model initialisation (and any DataLoader shuffling) is
# reproducible across Restart-&-Run-All.
seed = 42
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 1e-3
epochs = 20
batch_size = 64
# NOTE(review): training batches are not shuffled; consider True for SGD-style training.
shuffle_flag = False
model = CSIEncoder()
dataset_train = dataset_preprocess(data_train)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
# ReduceLROnPlateau: train_model steps it with the epoch's mean training loss.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1)
dataloader_train = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=shuffle_flag)


In [None]:
# Train (or resume from ./checkpoints) with the objects configured above.
train_model(
    model,
    dataloader_train,
    criterion,
    optimizer,
    scheduler,
    epochs=epochs,
    device=device,
)
# TODO: evaluate on a held-out loader once one exists, e.g.
# evaluate_model(model, dataloader_val, criterion, device)


In [12]:
# Count trainable parameters
def count_parameters(model):
    """Return the total number of scalar elements across all parameters of
    ``model`` that have ``requires_grad=True``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

trainable_param_count = count_parameters(model)
print(f"Total trainable parameters: {trainable_param_count}")

Total trainable parameters: 5264648


In [13]:
# Persist the whole trained module (architecture + weights).
# NOTE: this pickles the module, so reloading needs CSIEncoder importable and
# torch.load(..., weights_only=False); saving model.state_dict() is more portable.
models_dir = './models'
# torch.save raises FileNotFoundError if the target directory does not exist.
os.makedirs(models_dir, exist_ok=True)
torch.save(model, os.path.join(models_dir, model.__class__.__name__ + '_model.pt'))