In [1]:
import sys
if ".." not in sys.path:
    sys.path.insert(0, "..")

from datasets import OrganoidDataset
import torch

import os

# Location of the organoid data. Setting the ORGANOID_DATA_DIR environment variable
# overrides it; when the variable is unset the original path is used unchanged.
DATA_DIR = os.environ.get('ORGANOID_DATA_DIR',
                          '/home/egor/PycharmProjects/deep_dr/data/organoids')
# Number of rows per chunk, shared by the training and validation splits.
BATCH_SIZE = 32 * 1024

data = OrganoidDataset(data_dir=DATA_DIR)

X_train, y_train = data.train
X_val, y_val = data.val

# Pre-split both feature sets into fixed-size chunks once, up front
# (torch.split makes the last chunk smaller when the row count is not a multiple).
X_train_batches = torch.split(torch.Tensor(X_train), split_size_or_sections=BATCH_SIZE)
X_val_batches = torch.split(torch.Tensor(X_val), split_size_or_sections=BATCH_SIZE)

In [2]:
# NOTE(review): the wildcard import presumably supplies `VQVAE` and `nn`, which later
# cells use without importing them explicitly — confirm, and prefer explicit imports.
from models.vqvae import *
from configs.vqvae import get_config

In [3]:
# Take the first training chunk and display its shape as the cell output.
x = X_train_batches[0]
x.shape

torch.Size([32768, 41])

In [4]:
# Load the configuration from configs.vqvae and display it as the cell output.
config = get_config()
config

batch_size: 32768
dataset: Organoid
embed_dim1: 2
embed_dim2: 4
embed_dim3: 5
embed_entries1: 16
embed_entries2: 16
embed_entries3: 32
hidden_features: 32
in_features: 41
kld_scale: 0.0005
model: VQVAE
n_layers: 3
output_dir: ./logs/VQVAE/
seed: 12345
straight_through: false
temperature: 1

In [5]:
def get_parameter_count(net: torch.nn.Module) -> int:
    """Return the number of trainable scalar parameters in ``net``.

    Only parameters whose ``requires_grad`` flag is set are counted; frozen
    parameters are skipped.
    """
    total = 0
    for param in net.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [7]:
from torch import optim
from torch.optim import lr_scheduler
from tqdm import tqdm
import numpy as np

config = get_config()
# Apply the seed carried by the config before any weights are created, so that the
# model initialisation and the training run are repeatable across kernel restarts.
# NOTE(review): assumes `config` supports attribute access (`config.seed`) — confirm.
torch.manual_seed(config.seed)
model = VQVAE(config=config)

def init_weights(m):
    """Re-initialise a single module in place; meant to be passed to ``Module.apply``.

    ``nn.Linear`` layers get Kaiming-normal weights (``mode='fan_out'``) and, when
    they have one, a zeroed bias. Every other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, mode='fan_out')
        # A layer built with ``bias=False`` has ``m.bias is None``; skip it instead
        # of handing ``None`` to the initialiser, which would raise.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# Re-initialise the model's Linear layers, then report the trainable parameter count.
model.apply(init_weights)
print(f"Model parameters: {get_parameter_count(model)}")

epochs = 200
optimizer = optim.AdamW(model.parameters(), lr=0.01)

# ExponentialLR multiplies the learning rate by `gamma` on every scheduler step.
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
pbar = tqdm(range(epochs))

for epoch in pbar:
    # ---- training pass: one optimiser step per pre-split training chunk ----
    # Re-enter training mode each epoch because the validation pass below leaves
    # the model in eval mode.
    model.train()
    for X_batch in X_train_batches:
        optimizer.zero_grad()
        # Call the module itself rather than .forward() so registered hooks run.
        outputs = model(X_batch)
        loss = model.loss_function(*outputs)
        loss['loss'].backward()
        optimizer.step()
    scheduler.step()  # decay the learning rate once per epoch

    # ---- validation pass ----
    # Evaluate in eval mode so that any train-only behaviour the model may have
    # does not distort the reported metrics.
    model.eval()
    with torch.no_grad():
        losses = list()
        for X_batch in X_val_batches:
            outputs = model(X_batch)
            loss_val = model.loss_function(*outputs)
            # .item() converts each scalar loss tensor to a plain Python number.
            loss_dict = {key: loss_val[key].item() for key in ('loss', 'MSE', 'KLD')}
            losses.append(loss_dict)

        # Unweighted mean over validation chunks.
        loss = np.mean([entry['loss'] for entry in losses])
        rec_loss = np.mean([entry['MSE'] for entry in losses])
        # Variable name kept for compatibility; the value is the 'KLD' term.
        mmd_loss = np.mean([entry['KLD'] for entry in losses])
        pbar.set_description(f"Epoch: {epoch}, "
                             f"Loss: {loss}, "
                             f"MSE: {rec_loss}, "
                             f"KLD: {mmd_loss}", refresh=True)

print('Finished Training')

Model parameters: 23575


Epoch: 33, Loss: 0.4861161895096302, MSE: 0.48189911618828773, KLD: 8.434143900871277:  17%|█▋        | 34/200 [02:56<14:23,  5.20s/it] 


KeyboardInterrupt: 

In [None]:
# Report whether PyTorch can see a CUDA device.
# NOTE(review): the cells shown here never move the model or the batches to a
# device — confirm whether GPU training was intended.
torch.cuda.is_available()