In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
from torch import optim
from tqdm import tqdm

from imputers.VAE_model import VanillaVAE

# 数据集

In [3]:
data_np = np.load('/mnt/sdb/hanyuji-data/SSSD_results/wot_result/gene_traj_009.npy')
print(data_np.shape)  # (12, 9582, 2000)
data_np = data_np.transpose(1,0,2)
print(data_np.shape)  # (9582, 12, 2000)


result = []
for i in range(int(data_np.shape[0]/200)):
    result.append(data_np[i*200:(i+1)*200,:,:])
data_np = np.asarray(result)
print(data_np.shape)  # (47, 200, 12, 2000)


(12, 9582, 2000)
(9582, 12, 2000)
(47, 200, 12, 2000)


# 验证重构能力

In [4]:
# ### test ###

# num_points = 100
# frequency = [1,3,10]  # Frequency of the first sine wave
# for i in range(1997):
#     frequency.append(1)

# x = np.linspace(0, 2 * np.pi, num_points)
# sin_wave = []
# for i in range(len(frequency)):
#     sin_wave.append(np.sin(frequency[i] * x))

# sin_wave_arr = np.asarray(sin_wave).transpose(1,0)

# sin_wave_arr = np.tile(sin_wave_arr,(300,1,1))

# sin_wave_arr.shape  # (300, 100, 3)

# data_np = sin_wave_arr  # (300, 100, 2000)
# data_np = data_np.reshape(3,100, 100, 2000)

# ### test ###

# 模型训练

In [5]:
# 配置
input_features = 2000
latent_dim = 128
epochs = 10
learning_rate = 1e-3
device = torch.device("cuda:0")

# 初始化模型和优化器
model = VanillaVAE(input_features, latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [6]:
# 训练循环
model.train()
for epoch in tqdm(range(epochs)):
    overall_loss = 0
    for batch in data_np:
        for item in batch:
            item = torch.tensor(item).float().to(device)
            
            optimizer.zero_grad()

            # 前向传播
            recons, input, mu, log_var = model(item)

            # 计算损失
            loss_dict = model.loss_function(recons, input, mu, log_var)
            loss = loss_dict['loss']

            # 反向传播和优化
            loss.backward()
            optimizer.step()

            overall_loss += loss.item()

    print(f'Epoch {epoch}, Average Loss: {overall_loss}')

print("Training complete")

 10%|█         | 1/10 [01:31<13:39, 91.06s/it]

Epoch 0, Average Loss: 3112.811031019315


 20%|██        | 2/10 [03:01<12:05, 90.64s/it]

Epoch 1, Average Loss: 2281.5040123492945


 30%|███       | 3/10 [04:34<10:41, 91.69s/it]

Epoch 2, Average Loss: 1651.9181567197666


 40%|████      | 4/10 [06:08<09:16, 92.76s/it]

Epoch 3, Average Loss: 1235.391327172285


 50%|█████     | 5/10 [07:39<07:40, 92.17s/it]

Epoch 4, Average Loss: 989.2330567851895


 60%|██████    | 6/10 [09:18<06:17, 94.28s/it]

Epoch 5, Average Loss: 830.272759618354


 70%|███████   | 7/10 [10:49<04:39, 93.23s/it]

Epoch 6, Average Loss: 716.0823687978555


 80%|████████  | 8/10 [12:20<03:05, 92.68s/it]

Epoch 7, Average Loss: 643.4512810184387


 90%|█████████ | 9/10 [13:51<01:31, 91.91s/it]

Epoch 8, Average Loss: 581.267225954216


100%|██████████| 10/10 [15:23<00:00, 92.39s/it]

Epoch 9, Average Loss: 529.6226645933348
Training complete





In [8]:
# 定义保存路径
save_path = '/mnt/sdb/hanyuji-data/SSSD_results/VAE_result/VAE_10.pth'

# 保存模型状态字典
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

Model saved to /mnt/sdb/hanyuji-data/SSSD_results/VAE_result/VAE_10.pth


# 模型生成

In [9]:
input_features = 2000
latent_dim = 128

model_eval = VanillaVAE(input_features, latent_dim).to(device)
save_path = '/mnt/sdb/hanyuji-data/SSSD_results/VAE_result/VAE_10.pth'

# 加载模型状态字典
model_eval.load_state_dict(torch.load(save_path))
model_eval.eval()

print(f"Model loaded from {save_path}")

Model loaded from /mnt/sdb/hanyuji-data/SSSD_results/VAE_result/VAE_10.pth


In [10]:
result_list = []
for batch in tqdm(data_np):
    for item in batch:
        item = torch.tensor(item).float().to(device)
        result = model_eval.generate(item)
        result_list.append(result.detach().cpu().numpy())


  0%|          | 0/47 [00:00<?, ?it/s]

100%|██████████| 47/47 [00:17<00:00,  2.61it/s]


In [11]:
result_list = np.asarray(result_list)
# result_list = result_list*5
print(result_list.shape) # (9400, 12, 2000)

# 保存测试结果
result_list = result_list.transpose(1,0,2)
print(result_list.shape)  # (12, 9400, 2000)
np.save('/mnt/sdb/hanyuji-data/SSSD_results/wot_result/gene_traj_VAE_generate_10.npy', result_list)

(9400, 12, 2000)
(12, 9400, 2000)


# 隐状态

In [12]:
result_list = []
for batch in data_np:
    for item in batch:
        item = torch.tensor(item).float().to(device)
        mu, log_var = model_eval.encode(item)
        result = model_eval.reparameterize(mu, log_var)
        result_list.append(result.detach().cpu().numpy())

data_all = np.asarray(result_list)
data_all.shape
# (9400, 12, 128)

(9400, 12, 128)

In [14]:
# 保存隐状态 (9400, 12, 128)

np.save('/mnt/sdb/hanyuji-data/SSSD_results/wot_result/gene_traj_VAE_latent.npy', data_all)

(9400, 12, 128)