In [13]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
import numpy as np
import torch
from torch import optim
from tqdm.notebook import tqdm


from imputers.VAE_model import VanillaVAE

# 数据集

In [39]:
from dataset.dataset_utils import loadSCData

ann_data, cell_tps, cell_types, n_genes, n_tps = loadSCData("zebrafish", "three_interpolation")
data = ann_data.X
traj_data = [data[np.where(cell_tps == t)[0], :] for t in range(1, n_tps + 1)]

[ Data=zebrafish | Split=three_interpolation ] Loading data...


In [33]:
data_np = np.load('/mnt/sdb/hanyuji-data/wot_result/gene_traj_009.npy')
print(data_np.shape)  # (12, 9582, 2000)
data_np = data_np.transpose(1,0,2)
print(data_np.shape)  # (9582, 12, 2000)


result = []
for i in range(int(data_np.shape[0]/200)):
    result.append(data_np[i*200:(i+1)*200,:,:])
data_np = np.asarray(result)
print(data_np.shape)  # (47, 200, 12, 2000)


(12, 9582, 2000)
(9582, 12, 2000)
(47, 200, 12, 2000)


# 验证重构能力

In [54]:
### test ###

num_points = 100
frequency = [1,3,10]  # Frequency of the first sine wave
for i in range(1997):
    frequency.append(1)

x = np.linspace(0, 2 * np.pi, num_points)
sin_wave = []
for i in range(len(frequency)):
    sin_wave.append(np.sin(frequency[i] * x))

sin_wave_arr = np.asarray(sin_wave).transpose(1,0)

sin_wave_arr = np.tile(sin_wave_arr,(300,1,1))

sin_wave_arr.shape  # (300, 100, 3)

data_np = sin_wave_arr  # (300, 100, 2000)
data_np = data_np.reshape(3,100, 100, 2000)

### test ###

# 模型训练

In [55]:
# 配置
input_features = 2000
latent_dim = 128
epochs = 20
learning_rate = 1e-3
device = torch.device("cuda:0")

# 初始化模型和优化器
model = VanillaVAE(input_features, latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [56]:
# 训练循环
model.train()
for epoch in tqdm(range(epochs)):
    overall_loss = 0
    for batch in data_np:
        for item in batch:
            item = torch.tensor(item).float().to(device)
            
            optimizer.zero_grad()

            # 前向传播
            recons, input, mu, log_var = model(item)

            # 计算损失
            loss_dict = model.loss_function(recons, input, mu, log_var)
            loss = loss_dict['loss']

            # 反向传播和优化
            loss.backward()
            optimizer.step()

            overall_loss += loss.item()

    print(f'Epoch {epoch}, Average Loss: {overall_loss}')

print("Training complete")

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 0, Average Loss: 3.494852359057404


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 1, Average Loss: 0.10728278892929666


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 2, Average Loss: 0.0900182839250192


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 3, Average Loss: 0.08505974305444397


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 4, Average Loss: 0.08493933064164594


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 5, Average Loss: 0.09346781732165255


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 6, Average Loss: 0.1016105096496176


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 7, Average Loss: 0.09518362715607509


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 8, Average Loss: 0.1016266705119051


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 9, Average Loss: 0.10593723933561705


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 10, Average Loss: 0.09996664151549339


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 11, Average Loss: 0.09916538692777976


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 12, Average Loss: 0.10242317308438942


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 13, Average Loss: 0.0952492084907135


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 14, Average Loss: 0.0965840989665594


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 15, Average Loss: 0.09635089158837218


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 16, Average Loss: 0.09232217539101839


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 17, Average Loss: 0.09655642382858787


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 18, Average Loss: 0.08342372811694077


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 19, Average Loss: 0.013671659294686833
Training complete


In [58]:
# 定义保存路径
save_path = '/mnt/sdb/hanyuji-data/SSSD_results/VAE_result/VAE_sin.pth'

# 保存模型状态字典
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

Model saved to /mnt/sdb/hanyuji-data/SSSD_results/VAE_result/VAE_sin.pth


# 模型生成

In [59]:
input_features = 2000
latent_dim = 128

model_eval = VanillaVAE(input_features, latent_dim).to(device)
save_path = '/mnt/sdb/hanyuji-data/SSSD_results/VAE_result/VAE_sin.pth'

# 加载模型状态字典
model_eval.load_state_dict(torch.load(save_path))
model_eval.eval()

print(f"Model loaded from {save_path}")

Model loaded from /mnt/sdb/hanyuji-data/SSSD_results/VAE_result/VAE_sin.pth


In [60]:
result_list = []
for batch in tqdm(data_np):
    for item in batch:
        item = torch.tensor(item).float().to(device)
        result = model_eval.generate(item)
        result_list.append(result.detach().cpu().numpy())


  0%|          | 0/3 [00:00<?, ?it/s]

In [61]:
result_list = np.asarray(result_list)
result_list.shape  # (9400, 12, 2000)

(300, 100, 2000)

In [62]:
# 保存测试结果

print(result_list.shape)  # (9400, 12, 2000)
result_list = result_list.transpose(1,0,2)
print(result_list.shape)  # (12, 9400, 2000)
np.save('/mnt/sdb/hanyuji-data/SSSD_results/wot_result/sin_test_VAE_generate.npy', result_list)

(300, 100, 2000)
(100, 300, 2000)


# 隐状态

In [38]:
result_list = []
for batch in data_np:
    for item in batch:
        item = torch.tensor(item).float().to(device)
        mu, log_var = model_eval.encode(item)
        result = model_eval.reparameterize(mu, log_var)
        result_list.append(result.detach().cpu().numpy())

data_all = np.asarray(result_list)
data_all.shape
# (9400, 12, 128)

(9400, 12, 128)

In [None]:
# 保存隐状态 (9400, 12, 128)

np.save('/mnt/sdb/hanyuji-data/SSSD_results/wot_result/gene_traj_VAE_latent.npy', data_all)

In [None]:

selected_arr = arr[:, (2,7,11), :]
selected_arr.shape


In [None]:
result2 = []
for item in selected_arr:
    item = torch.tensor(item).to(device)
    result = model.decode(item)
    result2.append(result.detach().cpu().numpy())