In [1]:
# Automatically reload edited project modules (e.g. methylVA.*) before each cell runs,
# so changes to the dataset/model/training code are picked up without restarting the kernel.
# If the extension is already loaded in this kernel, `%reload_ext autoreload` can be used instead.
%load_ext autoreload
%autoreload 2


In [2]:
from methylVA.mnist.dataset import get_methyl_data_loaders

# Tag for this synthetic ("random") dataset variant; in the cells shown it is only
# used to build the experiment name.
data_id = 0.1
batch_size = 128

# Expected dimensions of the pickled splits. They are also encoded in the file names below.
n_features = 2605
n_samples_train = 33360
n_samples_test = 3707

# Single place to change where the pickled splits live.
data_dir = "../data/random_data"


def _split_path(kind: str, n_samples: int) -> str:
    """Return the pickle path for one split, e.g. ``{data_dir}/train_data_33360_2605.pkl``."""
    return f"{data_dir}/{kind}_{n_samples}_{n_features}.pkl"


train_data_path = _split_path("train_data", n_samples_train)
test_data_path = _split_path("test_data", n_samples_test)

train_metadata_path = _split_path("train_metadata", n_samples_train)
test_metadata_path = _split_path("test_metadata", n_samples_test)

train_loader, test_loader = get_methyl_data_loaders(
    train_data_path,
    train_metadata_path,
    test_data_path,
    test_metadata_path,
    batch_size=batch_size,
)


In [3]:
# Pull one training batch just to read off the feature dimension.
data_batch, _ = next(iter(train_loader))

num_train_rows = len(train_loader.dataset)
num_test_rows = len(test_loader.dataset)

# Fail fast if the loaded pickles do not match the dimensions encoded in their file names.
assert data_batch.shape[1] == n_features, (data_batch.shape, n_features)
assert num_train_rows == n_samples_train, (num_train_rows, n_samples_train)
assert num_test_rows == n_samples_test, (num_test_rows, n_samples_test)

print("Number of features in each dataset:", data_batch.shape[1])
print("Number of rows in the training dataset:", num_train_rows)
print("Number of rows in the test dataset:", num_test_rows)

Number of features in each dataset: 2605
Number of rows in the training dataset: 33360
Number of rows in the test dataset: 3707


In [19]:
from datetime import datetime

import torch
from torch.utils.tensorboard import SummaryWriter

from methylVA.mnist.model import VAE
from methylVA.mnist.training import train, test

# Model input width is read from the data rather than hard-coded.
input_dim = data_batch.shape[1]
learning_rate = 1e-3
weight_decay = 1e-2
num_epochs = 100
latent_dim = 32
hidden_dim = 2048
kl_weight = 0.001

# Seed torch's global RNG (CPU and CUDA). This makes the model weight initialisation
# and the DataLoader shuffling that follow reproducible under Restart & Run All.
seed = 42
torch.manual_seed(seed)

# Experiment tag; used as the TensorBoard run directory under ../experiments/.
name = f'VAE_random_data_{data_id}_latent_{latent_dim}_kl_{kl_weight}'



In [20]:

# Take one timestamp for both writers so the train and test runs land in matching
# directories. Two separate datetime.now() calls could straddle a second boundary.
run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
writer_train = SummaryWriter(f'../experiments/{name}/train/{run_stamp}')
writer_test = SummaryWriter(f'../experiments/{name}/test/{run_stamp}')

In [21]:
# Use the GPU when one is available, move a fresh VAE onto it, and attach AdamW.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

model = VAE(
    input_dim=input_dim,
    latent_dim=latent_dim,
    hidden_dim=hidden_dim,
).to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

In [22]:
# Display the module tree. The encoder's last Linear emits 2 * latent_dim (64) values and
# the decoder starts from latent_dim (32). Presumably these are split into a mean and a
# (softplus-positive) scale; confirm against methylVA.mnist.model.VAE.
model

VAE(
  (encoder): Sequential(
    (0): Linear(in_features=2605, out_features=2048, bias=True)
    (1): SiLU()
    (2): Linear(in_features=2048, out_features=1024, bias=True)
    (3): SiLU()
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): SiLU()
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): SiLU()
    (8): Linear(in_features=256, out_features=64, bias=True)
  )
  (softplus): Softplus(beta=1.0, threshold=20.0)
  (decoder): Sequential(
    (0): Linear(in_features=32, out_features=256, bias=True)
    (1): SiLU()
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): SiLU()
    (4): Linear(in_features=512, out_features=1024, bias=True)
    (5): SiLU()
    (6): Linear(in_features=1024, out_features=2048, bias=True)
    (7): SiLU()
    (8): Linear(in_features=2048, out_features=2605, bias=True)
    (9): Sigmoid()
  )
)

In [23]:
# `train` / `test` were imported in the setup cell; %autoreload 2 keeps them up to date,
# so no re-import is needed here.
#
# NOTE(review): `kl_weight` (set in the config cell and baked into `name`) is never passed
# to train()/test(). The first logged step shows Loss == Recon + KLD
# (1805.6772 + 3.6696 = 1809.3468), i.e. an effective KL weight of 1. Wire kl_weight
# through once the training functions accept it, or drop it from the run name.

# Running count returned by train(); presumably the global step used for TensorBoard logging.
prev_updates = 0
try:
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        prev_updates = train(model, train_loader, optimizer, prev_updates, writer=writer_train)
        test(model, test_loader, prev_updates, writer=writer_test)
finally:
    # Push buffered events to disk even if the run is interrupted (e.g. KeyboardInterrupt).
    writer_train.flush()
    writer_test.flush()

Epoch 1/100


  1%|          | 3/261 [00:00<00:10, 24.39it/s]

Step 0, (N samples: 0), Loss: 1809.3468, (Recon: 1805.6772, KLD: 3.6696), Gradient norm: 2.7886


 41%|████      | 106/261 [00:03<00:04, 35.06it/s]

Step 100, (N samples: 12,800), Loss: 1805.7599, (Recon: 1805.7499, KLD: 0.0101), Gradient norm: 1.6426


 79%|███████▉  | 206/261 [00:05<00:01, 35.06it/s]

Step 200, (N samples: 25,600), Loss: 1805.7383, (Recon: 1805.7368, KLD: 0.0014), Gradient norm: 1.6152


100%|██████████| 261/261 [00:07<00:00, 34.84it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 41.48it/s]


====> Test set loss: 1805.7304, (BCE: 1805.7290, KLD: 0.0015)
Epoch 2/100


 10%|▉         | 26/261 [00:00<00:07, 32.34it/s]


KeyboardInterrupt: 

In [15]:
# Flush the two TensorBoard writers created above. There is no `writer` variable in
# this notebook, so the original `writer.flush()` raised NameError on a fresh kernel.
writer_train.flush()
writer_test.flush()

In [16]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic used below.
%load_ext tensorboard

In [23]:
%tensorboard --logdir ../experiments/VAE_MNIST/20241028-004306/

Reusing TensorBoard on port 6011 (pid 1110344), started 0:00:12 ago. (Use '!kill 1110344' to kill it.)