# Train Tybalt VAE

## Set hyperparameters and build the model

In [1]:
import torch

from models.Tybalt.TybaltVAE import TybaltVAE

# --- Run configuration -------------------------------------------------------
seed = 42                          # fixed seed so repeated runs are comparable
batchsize = 512                    # samples per DataLoader batch
input_size = 5000                  # feature count fed to the encoder
output_size = 5000                 # feature count produced by the decoder
export_path = './exports/Tybalt/'  # forwarded to train() as export_path
learning_rate = 0.00001
epochs = 100

# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA instead of failing at model.to(device).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG before the model is built so parameter
# initialisation (and any later draw from torch's default generator) is
# repeatable. RNGs outside torch (e.g. numpy, random) are not covered here.
torch.manual_seed(seed)

model = TybaltVAE(input_size=input_size, output_size=output_size)
# Kept as the cell's last expression so the module summary is displayed.
model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


TybaltVAE(
  (encoder): Encoder(
    (linear_1): Sequential(
      (0): Linear(in_features=5000, out_features=1000, bias=True)
      (1): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (linear_mu): Sequential(
      (0): Linear(in_features=1000, out_features=32, bias=True)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (linear_var): Sequential(
      (0): Linear(in_features=1000, out_features=32, bias=True)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (decoder): Decoder(
    (decode): Sequential(
      (0): Linear(in_features=32, out_features=1000, bias=True)
      (1): Sigmoid()
      (2): Linear(in_features=1000, out_features=5000, bias=True)
      (3): Sigmoid()
    )
  )
)

## Load data

In [2]:
from models.Tybalt.TybaltData import getTybaltDatasets
from torch.utils.data import DataLoader

# Gzipped TSV handed to the project's dataset factory.
data_path = './tybaltdata/pancan_scaled_zeroone_rnaseq.tsv.gz'

# The factory's two return values are unpacked as (train, validation).
dataset_train, dataset_val = getTybaltDatasets(data_path)

# Both loaders share the batch size from the configuration cell; only the
# training loader reshuffles its samples, the validation loader keeps a
# fixed order.
dataloader_train = DataLoader(dataset_train, batch_size=batchsize, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=batchsize, shuffle=False)

Loaded data of size:  torch.Size([1046, 5000])
Loaded data of size:  torch.Size([9413, 5000])


## Train model

In [3]:
from models.Tybalt.train import train

from torch.utils.tensorboard import SummaryWriter
import warnings
# NOTE: this filter is global and stays active for the rest of the kernel
# session, so warnings raised by later cells are hidden as well.
warnings.filterwarnings("ignore")

# SummaryWriter() with no arguments writes its event files to the default
# log directory chosen by torch.utils.tensorboard.
writer = SummaryWriter()

try:
    train(model, dataloader_train, dataloader_val,
          writer=writer,
          export_path=export_path,
          learning_rate=learning_rate,
          epoch_amount=epochs,
          logs_per_epoch=1,
          device=device)
finally:
    # Training is routinely stopped by interrupting the kernel; closing the
    # writer here flushes pending TensorBoard events and releases the event
    # file even when train() exits via KeyboardInterrupt or another error.
    writer.close()

Validating. Rec loss: 0.08.: 100%|██████████| 9/9 [00:00<00:00, 14.00it/s] 12/74 [00:08<00:13,  4.67it/s]
Validating. Rec loss: 0.07.: 100%|██████████| 9/9 [00:00<00:00, 65.48it/s]m| 24/74 [00:10<00:06,  8.19it/s]
Validating. Rec loss: 0.07.: 100%|██████████| 9/9 [00:00<00:00, 65.48it/s]m| 39/74 [00:11<00:02, 16.35it/s]
Validating. Rec loss: 0.07.: 100%|██████████| 9/9 [00:00<00:00, 63.99it/s]m| 53/74 [00:12<00:01, 18.58it/s]
Validating. Rec loss: 0.06.: 100%|██████████| 9/9 [00:00<00:00, 59.35it/s]m| 66/74 [00:13<00:00, 18.53it/s]
Training. Rec/real loss for step 74: 0.06/233.94.: 100%|[35m██████████[0m| 74/74 [00:13<00:00,  5.42it/s]
Validating. Rec loss: 0.06.: 100%|██████████| 9/9 [00:00<00:00, 67.56it/s]m| 11/74 [00:00<00:03, 18.96it/s]
Validating. Rec loss: 0.06.: 100%|██████████| 9/9 [00:00<00:00, 45.69it/s]m| 25/74 [00:01<00:02, 18.88it/s]
Validating. Rec loss: 0.05.: 100%|██████████| 9/9 [00:00<00:00, 54.44it/s]m| 40/74 [00:02<00:01, 18.57it/s]
Validating. Rec loss: 0.05.: 1

KeyboardInterrupt: 