In [1]:
import glob
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

from misc import *
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [2]:
# Sparse autoencoder definition
class SparseAutoencoder(nn.Module):
    """Single-hidden-layer autoencoder that also returns its hidden code.

    The encoder is ``Linear(input_size, hidden_size)`` followed by ReLU; the
    decoder is ``Linear(hidden_size, input_size)`` followed by Sigmoid.

    Args:
        input_size: Number of features in each (flattened) input vector.
        hidden_size: Width of the hidden layer.
        sparsity_param: Sparsity target. Stored on the instance only; no
            method of this class reads it.
        beta: Weight of a KL-divergence penalty. Stored on the instance only;
            no method of this class reads it.
    """

    def __init__(self, input_size, hidden_size, sparsity_param=0.05, beta=3):
        super().__init__()

        # The encoder layer is created before the decoder layer so parameter
        # initialisation consumes the random generator in a fixed order.
        encode_layer = nn.Linear(input_size, hidden_size)
        decode_layer = nn.Linear(hidden_size, input_size)

        self.encoder = nn.Sequential(encode_layer, nn.ReLU())
        self.decoder = nn.Sequential(decode_layer, nn.Sigmoid())

        self.sparsity_param = sparsity_param  # sparsity target
        self.beta = beta  # weight of the KL-divergence penalty term

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            A ``(reconstruction, hidden)`` tuple: the decoder output and the
            encoder activations it was computed from.
        """
        code = self.encoder(x)
        reconstruction = self.decoder(code)
        return reconstruction, code


In [3]:
# Dataset root. Set the MRM_DATASET_DIR environment variable to point the
# notebook at another location without editing this cell; the default is the
# original local path.
dataset_path = os.environ.get('MRM_DATASET_DIR', 'G:/MRM_0.5/')
# os.path.join tolerates a root given without a trailing separator. The
# trailing '/' on the sub-directory names is kept so the default values are
# unchanged.
train_data_path = os.path.join(dataset_path, 'train/')
test_data_path = os.path.join(dataset_path, 'test/')

def preprocessing(data, snr, pad_size, sparse_size = 8):
    """Add noise to the four polarisation echoes and return their magnitudes.

    ``awgn`` is called once on the HH echo with the requested ``snr``; only the
    noise power it returns is kept. ``awgnfp`` is then called on each of the
    HH, HV, VH and VV echoes with that same noise power, and the magnitude of
    each noisy echo is taken.

    Args:
        data: Mapping with an ``'echo'`` entry indexed by polarisation name.
        snr: Forwarded to ``awgn`` (presumably in dB - confirm against misc).
        pad_size: Accepted for interface compatibility; not used here.
        sparse_size: Accepted for interface compatibility; not used here.

    Returns:
        A float32 NumPy array with one row per polarisation, in the order
        HH, HV, VH, VV (the original comment gives the shape as (4, 512)).
    """
    echoes = data['echo']

    # The noisy HH signal returned here is discarded; the call only fixes the
    # noise power that is shared by all four channels below.
    _unused_hh, noise_power = awgn(echoes['HH'], snr=snr)

    def _noisy_magnitude(pol):
        noisy, _ = awgnfp(echoes[pol], noise_power=noise_power)
        return np.abs(noisy)

    magnitudes = [_noisy_magnitude(pol) for pol in ('HH', 'HV', 'VH', 'VV')]
    return np.array(magnitudes).astype(np.float32)

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over ``<dataset_dir>/<label>/*.pkl`` files.

    The base class is spelled out in full on purpose. This class keeps the
    name ``Dataset`` (callers rely on it), which shadows the imported
    ``torch.utils.data.Dataset``; with ``class Dataset(Dataset)`` re-running
    the cell would make the class inherit from its own previous definition.

    Args:
        dataset_dir: Directory containing one sub-directory per label.
        snr: Forwarded to ``preprocessing`` for every sample.
        pad_size: Forwarded to ``preprocessing`` for every sample.
    """

    def __init__(self, dataset_dir, snr, pad_size):
        self.snr = snr
        self.pad_size = pad_size
        self.dataset_dir = dataset_dir
        self.instance_list = self.get_instance()

    def get_instance(self):
        """Return the paths of all ``.pkl`` files, one label directory at a time.

        Directory entries and file names are sorted so that index -> sample is
        the same on every run and platform; ``os.listdir`` and ``glob`` give no
        ordering guarantee. Entries that are not directories are skipped.
        """
        instance_list = []
        for label in sorted(os.listdir(self.dataset_dir)):
            label_dir = os.path.join(self.dataset_dir, label)
            if not os.path.isdir(label_dir):
                continue
            instance_list += sorted(glob.glob(label_dir + '/*.pkl'))
        return instance_list

    def __len__(self):
        """Number of sample files found under ``dataset_dir``."""
        return len(self.instance_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(features, label)``.

        ``features`` is whatever ``preprocessing`` returns (a NumPy array);
        ``label`` is ``data['target_id']`` as a ``torch.long`` tensor.
        """
        # NOTE(review): load_pkl is defined outside this notebook; if it
        # unpickles the file, only point dataset_dir at trusted data.
        data = load_pkl(self.instance_list[idx])
        x = preprocessing(data, snr=self.snr, pad_size=self.pad_size)
        y = data['target_id']
        return x, torch.tensor(y, dtype=torch.long)

In [4]:
# Noise level and pad size handed to every Dataset instance. pad_size is
# forwarded to preprocessing(), which currently ignores it.
snr = 0
pad_size = 201
train_dataset = Dataset(train_data_path, snr = snr, pad_size = pad_size)
test_dataset = Dataset(test_data_path, snr = snr, pad_size = pad_size)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# The held-out split is iterated in a fixed order so evaluation runs are
# comparable; shuffling is only useful for training.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Fetch one sample to discover the input shape used to size the model.
inp, label = train_dataset[0]
inp = torch.from_numpy(inp).to(device)
print(inp.shape)

torch.Size([4, 512])


In [5]:
# Build the sparse autoencoder model
# Each sample is flattened before it is fed to the model, so the input width
# is the total number of elements in one sample.
input_size = inp.numel()

hidden_size = 128  # hidden layer width
model = SparseAutoencoder(input_size, hidden_size).to(device)

# Smoke test on a single sample. The flattened width comes from input_size
# instead of a literal, and the module is called (not .forward) so that
# nn.Module hooks run. No gradients are needed for this check.
with torch.no_grad():
    outputs, hidden = model(inp.reshape(1, input_size))


def l1_sparsity_loss(hidden):
    """Return the L1 penalty on the hidden activations.

    The result is a scalar tensor: the sum of absolute values over every
    element of ``hidden`` (it is a sum, not a mean).
    """
    return hidden.abs().sum()

# Training loss: MSE reconstruction + L2 weight penalty + L1 sparsity penalty
def mse_sae_loss(outputs, inputs, model, hidden, lambda_l2, beta_sparsity, sparsity_param=0.05):
    """Combine reconstruction error with weight and activation penalties.

    total = MSE(outputs, inputs)
            + lambda_l2 * sum of squared entries of every model parameter
            + beta_sparsity * l1_sparsity_loss(hidden)

    Args:
        outputs: Reconstruction produced by the model.
        inputs: Target the reconstruction is compared against.
        model: Module whose parameters (all of them, biases included) enter
            the L2 term.
        hidden: Hidden activations passed to ``l1_sparsity_loss``.
        lambda_l2: Weight of the L2 term.
        beta_sparsity: Weight of the L1 sparsity term.
        sparsity_param: Accepted for interface compatibility; not used here.

    Returns:
        The total loss as a scalar tensor.
    """
    # Reconstruction error (mean over all elements).
    reconstruction = nn.functional.mse_loss(outputs, inputs)

    # L2 penalty accumulated over every parameter tensor of the model.
    weight_penalty = sum(p.pow(2).sum() for p in model.parameters())

    # L1 penalty on the hidden activations.
    activation_penalty = l1_sparsity_loss(hidden)

    return reconstruction + lambda_l2 * weight_penalty + beta_sparsity * activation_penalty

In [6]:
# Optimizer, learning-rate schedule and training hyper-parameters
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
mse_loss = nn.MSELoss()  # not used by the loop below; kept in case later cells reference it
num_epochs = 1000
lambda_l2 = 0.5
beta_sparsity = 0.5
early_stop_threshold = 50
early_stop_patience = 1
# Initialise the counter before the loop. Without this, a first epoch that is
# already below the threshold hits `early_stop_counter += 1` on an undefined
# name (NameError on a fresh kernel), and a re-run of this cell would silently
# start from the value left behind by the previous run.
early_stop_counter = 0

# Sparse autoencoder training loop
model.train()
for epoch in range(num_epochs):
    # Sum (not mean) of the per-batch losses; early_stop_threshold is compared
    # against this sum, so it scales with the number of batches.
    total_loss = 0
    for inputs, _ in train_loader:
        inputs = inputs.view(inputs.size(0), -1).to(device)  # flatten each sample to a vector
        optimizer.zero_grad()

        # Forward pass
        outputs, hidden = model(inputs)
        # Combined loss: reconstruction + L2 weight penalty + L1 sparsity
        loss = mse_sae_loss(outputs, inputs, model, hidden, lambda_l2, beta_sparsity, sparsity_param=0.05)
        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}')

    # Early stopping: stop once the epoch loss has been below the threshold
    # for `early_stop_patience` consecutive epochs.
    if total_loss < early_stop_threshold:
        early_stop_counter += 1
        if early_stop_counter >= early_stop_patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break
    else:
        early_stop_counter = 0  # reset the streak

    scheduler.step()

Epoch [1/1000], Loss: 10811.6551
Epoch [2/1000], Loss: 900.6318
Epoch [3/1000], Loss: 383.1675
Epoch [4/1000], Loss: 155.6699
Epoch [5/1000], Loss: 78.8966
Epoch [6/1000], Loss: 47.4071
Early stopping triggered at epoch 6


In [7]:
# Serialises the whole module object (not just its state_dict) to disk.
# NOTE(review): loading a file saved this way needs the SparseAutoencoder
# class to be importable under the same name at load time - confirm that
# downstream notebooks define or import it before calling torch.load.
torch.save(obj=model, f='SAE_.pth')