In [1]:
#!/usr/bin/env python
# coding: utf-8

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
import hdf5storage
import torch.optim as optim
import math
import gc

In [17]:
# ##### 数据集预处理

class MIMODataset(Dataset):
    """In-memory dataset yielding ``(csi_ls, csi_pre, csi_label)`` triples.

    The three inputs are indexed in lockstep along their first (sample) axis.
    """

    def __init__(self, csi_ls, csi_pre, csi_label):
        """
        Initialize the dataset.

        :param csi_ls: pilot CSI (``csiLSData``), [data_size, n_subc, n_sym, n_tx, n_rx, 2]
        :param csi_pre: historical CSI (``csiPreData``); previously documented as
            [data_size, n_frame, n_subc, n_sym, n_tx, n_rx, 2] — TODO confirm exact layout
        :param csi_label: ground-truth CSI (``csiLabelData``), [data_size, n_subc, n_sym, n_tx, n_rx, 2]
        :raises ValueError: if the three inputs do not have the same number of samples
        """
        counts = (len(csi_ls), len(csi_pre), len(csi_label))
        # Fail fast: mismatched lengths would otherwise surface as an IndexError
        # somewhere in the middle of an epoch.
        if len(set(counts)) != 1:
            raise ValueError(
                "csi_ls, csi_pre and csi_label must have the same number of samples, "
                f"got {counts[0]}, {counts[1]} and {counts[2]}"
            )
        self.csi_ls = csi_ls
        self.csi_pre = csi_pre
        self.csi_label = csi_label

    def __len__(self):
        """Return the number of samples (first-axis length of ``csi_label``)."""
        return self.csi_label.size(0)

    def __getitem__(self, idx):
        """
        Return a single sample.

        :param idx: sample index
        :return: tuple ``(csi_ls[idx], csi_pre[idx], csi_label[idx])``
        """
        return self.csi_ls[idx], self.csi_pre[idx], self.csi_label[idx]

def dataset_preprocess(data):
    """Convert a loaded ``.mat`` dict into a ``MIMODataset`` of float32 tensors.

    :param data: mapping with array-like entries ``'csiLSData'``, ``'csiPreData'`` and
        ``'csiLabelData'`` (e.g. the dict returned by ``hdf5storage.loadmat``)
    :return: MIMODataset built from the three converted tensors
    """
    # torch.as_tensor reuses the source buffer when it is already float32 (torch.tensor
    # always copies); other dtypes are converted once.
    csi_ls = torch.as_tensor(data['csiLSData'], dtype=torch.float32)  # presumably [data_size, n_subc, n_sym, n_tx, n_rx, 2]
    csi_pre = torch.as_tensor(data['csiPreData'], dtype=torch.float32)  # historical CSI; exact layout TODO confirm
    csi_label = torch.as_tensor(data['csiLabelData'], dtype=torch.float32)  # presumably [data_size, n_subc, n_sym, n_tx, n_rx, 2]
    # NOTE: deleting the local name ``data`` here cannot free the raw arrays because the
    # caller still references the dict; release it in the caller if memory matters.
    return MIMODataset(csi_ls, csi_pre, csi_label)


# 残差块定义
class ResidualBlock(nn.Module):
    """Two-layer MLP with an identity skip: ``ReLU(linear2(ReLU(linear1(x))) + x)``.

    :param in_dim: size of the last axis of the input (and of the output)
    :param hidden_dim: width of the inner projection
    """

    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        # Keep creation order (linear1 before linear2) so seeded initialization and
        # state_dict keys stay the same.
        self.linear1 = nn.Linear(in_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, in_dim)
        self.activation = nn.ReLU()

    def forward(self, x):
        branch = self.linear2(self.activation(self.linear1(x)))
        # Add the skip connection, then apply the final non-linearity.
        return self.activation(branch + x)

# 深度残差网络模型
class DNNResCELS(nn.Module):
    """Residual MLP mapping a CSI tensor to a CSI tensor of the same shape.

    Each of the ``n_tx * n_rx * 2`` streams is treated as an independent vector of
    length ``n_subc * n_sym`` and passed through the same network.

    :param hidden_dim: width of the residual trunk
    :param num_blocks: number of stacked ResidualBlock modules
    :param n_subc: number of subcarriers
    :param n_sym: number of symbols
    :param n_tx: number of transmit antennas
    :param n_rx: number of receive antennas
    """

    def __init__(self, hidden_dim=512, num_blocks=4, n_subc=224, n_sym=14, n_tx=2, n_rx=2):
        super().__init__()
        self.n_subc, self.n_sym, self.n_tx, self.n_rx = n_subc, n_sym, n_tx, n_rx
        grid_size = n_subc * n_sym

        # Lift each length-(n_subc * n_sym) stream into the hidden space.
        self.input_layer = nn.Sequential(nn.Linear(grid_size, hidden_dim), nn.ReLU())
        # Residual trunk; every block expands to 2 * hidden_dim internally.
        blocks = [ResidualBlock(hidden_dim, 2 * hidden_dim) for _ in range(num_blocks)]
        self.res_blocks = nn.Sequential(*blocks)
        # Project back onto the subcarrier x symbol grid.
        self.output_layer = nn.Linear(hidden_dim, grid_size)

    def forward(self, x):
        grid_size = self.n_subc * self.n_sym
        n_streams = self.n_tx * self.n_rx * 2
        # [B, n_subc, n_sym, n_tx, n_rx, 2] -> [B, n_streams, grid_size]
        streams = x.reshape(-1, grid_size, n_streams).permute(0, 2, 1)
        streams = self.output_layer(self.res_blocks(self.input_layer(streams)))
        # [B, n_streams, grid_size] -> [B, n_subc, n_sym, n_tx, n_rx, 2]
        return streams.permute(0, 2, 1).reshape(-1, self.n_subc, self.n_sym, self.n_tx, self.n_rx, 2)

class ComplexMSELoss(nn.Module):
    """Mean squared magnitude of the complex estimation error.

    Complex values are stored with a trailing axis whose index 0 is the real part and
    index 1 the imaginary part. With ``e = output - target``, the loss is the mean of
    ``Re(e)**2 + Im(e)**2`` over all remaining elements.
    """

    def __init__(self):
        super().__init__()

    def forward(self, output, target):
        """
        :param output: estimated CSI, [..., 2]
        :param target: ground-truth CSI, broadcastable against ``output``
        :return: scalar loss tensor
        """
        err = output - target
        real_part, imag_part = err[..., 0], err[..., 1]
        return (real_part ** 2 + imag_part ** 2).mean()


# 模型训练
def _save_checkpoint(path, epoch, model, optimizer, scheduler, best_loss):
    """Write a resumable snapshot (model/optimizer/scheduler state, epoch, best loss) to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
        'best_loss': best_loss,
    }, path)


def _train_one_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Run one optimization pass over ``dataloader`` and return the mean per-batch loss."""
    model.train()
    n_batches = len(dataloader)
    total_loss = 0.0
    for batch_idx, (csi_ls, _csi_pre, csi_label) in enumerate(dataloader):
        # csi_pre is not consumed by the model, so it is not copied to the device.
        csi_ls = csi_ls.to(device)
        csi_label = csi_label.to(device)
        optimizer.zero_grad()
        loss = criterion(model(csi_ls), csi_label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if (batch_idx + 1) % 50 == 0:
            print(f"Epoch {epoch + 1}, Batch {batch_idx + 1}/{n_batches}, Loss: {loss.item():.4f}")
    return total_loss / n_batches


def _evaluate(model, dataloader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``dataloader`` (no gradients)."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for csi_ls, _csi_pre, csi_label in dataloader:
            output = model(csi_ls.to(device))
            total_loss += criterion(output, csi_label.to(device)).item()
    return total_loss / len(dataloader)


def train_model(model, dataloader_train, dataloader_val, criterion, optimizer, scheduler, epochs, device, checkpoint_dir='./checkpoints'):
    """Train ``model`` with per-epoch validation and resumable checkpoints.

    Each batch is expected to be ``(csi_ls, csi_pre, csi_label)``. The model is fed
    ``csi_ls``, and its output is compared with ``csi_label``; ``csi_pre`` is ignored.

    Two files are kept in ``checkpoint_dir``:
    ``<ModelClass>_pro_latest.pth`` is written every epoch and is used to resume
    automatically. ``<ModelClass>_pro_best.pth`` is written whenever the validation
    loss improves.

    :param scheduler: optional LR scheduler. ``ReduceLROnPlateau`` is stepped with the
        epoch's validation loss; any other scheduler is stepped once per epoch with no metric.
    :param epochs: total epoch budget, including epochs already completed in a resumed run
    :param checkpoint_dir: directory for the checkpoints (created if missing)
    :raises ValueError: if either dataloader yields no batches
    """
    if len(dataloader_train) == 0 or len(dataloader_val) == 0:
        raise ValueError("dataloader_train and dataloader_val must each yield at least one batch")

    os.makedirs(checkpoint_dir, exist_ok=True)
    model.to(device)
    prefix = model.__class__.__name__ + '_pro'
    latest_path = os.path.join(checkpoint_dir, prefix + '_latest.pth')
    best_path = os.path.join(checkpoint_dir, prefix + '_best.pth')

    best_loss = float('inf')
    start_epoch = 0
    # Resume from the most recent checkpoint if one exists.
    if os.path.isfile(latest_path):
        print(f"[INFO] Resuming training from '{latest_path}'")
        checkpoint = torch.load(latest_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # The key is always written, but its value is None when training ran without a scheduler.
        if scheduler is not None and checkpoint.get('scheduler_state_dict') is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint.get('best_loss', best_loss)
        print(f"[INFO] Resumed epoch {start_epoch}, best_loss={best_loss:.6f}")

    plateau_cls = torch.optim.lr_scheduler.ReduceLROnPlateau
    for epoch in range(start_epoch, epochs):
        print(f"\nEpoch [{epoch + 1}/{epochs}]")
        train_loss = _train_one_epoch(model, dataloader_train, criterion, optimizer, device, epoch)
        print(f"Epoch {epoch + 1}, Loss: {train_loss}")

        val_loss = _evaluate(model, dataloader_val, criterion, device)
        print(f"Val Loss: {val_loss:.4f}")

        # Only metric-driven schedulers accept a loss; for others the positional argument
        # would be misread as an epoch number.
        if scheduler is not None:
            if isinstance(scheduler, plateau_cls):
                scheduler.step(val_loss)
            else:
                scheduler.step()

        # Update best_loss *before* writing the latest checkpoint so that a resumed run
        # does not start from a stale best value and overwrite best.pth with a worse model.
        is_best = val_loss < best_loss
        if is_best:
            best_loss = val_loss
        _save_checkpoint(latest_path, epoch, model, optimizer, scheduler, best_loss)
        if is_best:
            _save_checkpoint(best_path, epoch, model, optimizer, scheduler, best_loss)
            print(f"[INFO] Best model saved at epoch {epoch + 1}, val_loss={val_loss:.4f}")


In [2]:

print("load data")
# Load the training and validation .mat files, in that order.
data_train, data_val = (
    hdf5storage.loadmat(mat_path)
    for mat_path in ('./data/raw/eqTrainData.mat', './data/raw/eqValData.mat')
)
print("load done")

load data
load done


In [None]:
# Wrap the raw train/validation dicts as MIMODataset objects (train first).
dataset_train, dataset_val = [dataset_preprocess(raw) for raw in (data_train, data_val)]

In [18]:
# 主函数执行
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
seed = 42
torch.manual_seed(seed)  # reproducible weight init and training-set shuffling
lr = 1e-3
epochs = 20
batch_size = 128
shuffle_flag = True
model = DNNResCELS()
criterion = ComplexMSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
dataloader_train = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=shuffle_flag)
# Validation is not shuffled: the epoch val loss is a mean of per-batch means, so with
# shuffling it would drift with whichever samples land in the (possibly smaller) last batch.
dataloader_val = DataLoader(dataset=dataset_val, batch_size=batch_size, shuffle=False)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1)
# 计算参数量
def count_parameters(model):
    """Return the total number of scalar weights in ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

print("Total trainable parameters:", count_parameters(model))
print("train model")



cuda
Total trainable parameters: 7415360
train model


In [19]:
train_model(
    model=model,
    dataloader_train=dataloader_train,
    dataloader_val=dataloader_val,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=epochs,
    device=device,
    checkpoint_dir='./checkpoints',
)



Epoch [1/20]
Epoch 1, Batch 50/125, Loss: 0.0091
Epoch 1, Batch 100/125, Loss: 0.0039
Epoch 1, Loss: 0.03156015559099615
Val Loss: 0.0032
[INFO] Best model saved at epoch 1, val_loss=0.0032

Epoch [2/20]
Epoch 2, Batch 50/125, Loss: 0.0017
Epoch 2, Batch 100/125, Loss: 0.0013
Epoch 2, Loss: 0.0017294273860752583
Val Loss: 0.0012
[INFO] Best model saved at epoch 2, val_loss=0.0012

Epoch [3/20]
Epoch 3, Batch 50/125, Loss: 0.0008
Epoch 3, Batch 100/125, Loss: 0.0006
Epoch 3, Loss: 0.0008042760058306158
Val Loss: 0.0011
[INFO] Best model saved at epoch 3, val_loss=0.0011

Epoch [4/20]
Epoch 4, Batch 50/125, Loss: 0.0009
Epoch 4, Batch 100/125, Loss: 0.0005
Epoch 4, Loss: 0.0005949193753767758
Val Loss: 0.0009
[INFO] Best model saved at epoch 4, val_loss=0.0009

Epoch [5/20]
Epoch 5, Batch 50/125, Loss: 0.0006
Epoch 5, Batch 100/125, Loss: 0.0005
Epoch 5, Loss: 0.0004797114583197981
Val Loss: 0.0008
[INFO] Best model saved at epoch 5, val_loss=0.0008

Epoch [6/20]
Epoch 6, Batch 50/125, 