In [10]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import Variable
from torchvision import transforms
import matplotlib.pyplot as plt
from torch import distributions

In [2]:
data = np.load("/Users/dongtianchi/Documents/GIT/SpectralReconstruction/ComputationalSpectrometers/Deeplearning/SpectrumData.npy",allow_pickle=True)
data = data.T
data.shape

(5661, 400)

In [3]:
data = np.asarray(data, dtype=np.float32)
# 将数据集划分为训练集和验证集
train_data = data[:5000]
val_data = data[5000:]

In [16]:
# 创建PyTorch数据加载器
batch_size = 32
train_loader = DataLoader(TensorDataset(torch.tensor(train_data)), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(TensorDataset(torch.tensor(val_data)), batch_size=batch_size, shuffle=True)


In [21]:
class Encoder(torch.nn.Module):
    def __init__(self, D_in, H1,H2, latent_size):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H1)
        self.linear2 = torch.nn.Linear(H1, H2)
        self.linear3 = torch.nn.Linear(H2,H2)
        self.enc_mu = torch.nn.Linear(H2, latent_size)
        self.enc_log_sigma = torch.nn.Linear(H2, latent_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x)) 
        mu = self.enc_mu(x)
        log_sigma = self.enc_log_sigma(x)
        sigma = torch.exp(log_sigma)
        return torch.distributions.Normal(loc=mu, scale=sigma)
    
class Decoder(torch.nn.Module):
    def __init__(self, D_in, H1,H2, D_out):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H2)
        self.linear2 = torch.nn.Linear(H2, H1)
        self.linear3 = torch.nn.Linear(H1, D_out)
        

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        mu = torch.tanh(self.linear3(x))  # 新增加的激活函数
        return torch.distributions.Normal(mu, torch.ones_like(mu))
    
class VAE(torch.nn.Module):
    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, state):
        q_z = self.encoder(state)
        z = q_z.rsample()
        return self.decoder(z), q_z

In [22]:
# VAE参数
input_dim = 400
hidden_dim1 = 256
hidden_dim2 = 128
latent_dim = 50

encoder = Encoder(input_dim, hidden_dim1, hidden_dim2, latent_dim)
decoder = Decoder(latent_dim, hidden_dim2, hidden_dim1, input_dim)
vae = VAE(encoder, decoder)

# 设备和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = vae.to(device)
optimizer = optim.Adam(vae.parameters(), lr=0.001)

In [23]:
# 训练循环
num_epochs = 50
for epoch in range(num_epochs):
    vae.train()
    train_loss = 0.0
    for batch in train_loader:
        x = batch[0].to(device)
        optimizer.zero_grad()
        p_x, q_z = vae(x)
        log_likelihood = p_x.log_prob(x).sum(-1).mean()
        kl = torch.distributions.kl_divergence(
            q_z, 
            torch.distributions.Normal(0, 1.)
        ).sum(-1).mean()
        loss = -(log_likelihood - kl)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    print(epoch, train_loss, log_likelihood.item(), kl.item())

    vae.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            x = batch[0].to(device)
            p_x, q_z = vae(x)
            log_likelihood = p_x.log_prob(x).sum(-1).mean()
            kl = torch.distributions.kl_divergence(
                q_z, 
                torch.distributions.Normal(0, 1.)
            ).sum(-1).mean()
            loss = -(log_likelihood - kl)
            val_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}")


    

0 57954.71081542969 -367.8749694824219 0.09397052228450775
Epoch [1/50], Train Loss: 369.1383, Val Loss: 368.5926
1 57964.561950683594 -367.8750305175781 0.05442120507359505
Epoch [2/50], Train Loss: 369.2010, Val Loss: 368.3226
2 57900.02145385742 -368.1928405761719 0.04321391507983208
Epoch [3/50], Train Loss: 368.7899, Val Loss: 368.4471
3 57907.38638305664 -367.9693908691406 0.018992863595485687
Epoch [4/50], Train Loss: 368.8369, Val Loss: 368.5718
4 57887.64306640625 -368.2444763183594 0.02525072917342186
Epoch [5/50], Train Loss: 368.7111, Val Loss: 368.5999
5 57899.48489379883 -368.0853271484375 0.03785347193479538
Epoch [6/50], Train Loss: 368.7865, Val Loss: 368.4368
6 57884.781799316406 -367.9739990234375 0.08190701901912689
Epoch [7/50], Train Loss: 368.6929, Val Loss: 368.2855
7 57880.866455078125 -368.1640625 0.01809319108724594
Epoch [8/50], Train Loss: 368.6679, Val Loss: 368.5770
8 57885.64407348633 -367.9927673339844 0.1830027550458908
Epoch [9/50], Train Loss: 368.69