In [1]:
import sys
if ".." not in sys.path:
    sys.path.insert(0, "..")

from datasets import OrganoidDataset
import torch

data = OrganoidDataset(data_dir='/home/egor/PycharmProjects/deep_dr/data/organoids')

X_train, y_train = data.train
X_val, y_val = data.val

X_train_batches = torch.split(torch.Tensor(X_train), split_size_or_sections=1024)
X_val_batches = torch.split(torch.Tensor(X_val), split_size_or_sections=1024)

In [2]:
from models.vqvae import *
from configs.vqvae import get_config

In [3]:
x = X_train_batches[0]
x.shape

torch.Size([1024, 41])

In [4]:
config = get_config()
config

batch_size: 32768
dataset: Organoid
embed_dim: 16
hidden_features: 32
in_features: 41
kld_scale: 0.0005
model: VQVAE
n_layers: 3
nb_entries: 256
output_dir: ./logs/VQVAE/
seed: 12345
straight_through: false
temperature: 1

In [5]:
emb = Encoder(config)
emb

Encoder(
  (layers): Sequential(
    (0): Linear(in_features=41, out_features=32, bias=True)
    (1): ResidualStack(
      (stack): Sequential(
        (0): ReZero(
          (layers): Sequential(
            (0): Linear(in_features=32, out_features=32, bias=True)
            (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
        (1): ReZero(
          (layers): Sequential(
            (0): Linear(in_features=32, out_features=32, bias=True)
            (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
        (2): ReZero(
          (layers): Sequential(
            (0): Linear(in_features=32, out_features=32, bias=True)
            (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
      )
    )
  )
)

In [6]:
x_embedded = emb(x)
x_embedded.shape

torch.Size([1024, 32])

In [7]:
embed = torch.randn(config.nb_entries, config.embed_dim, dtype=torch.float32)
embed.shape

torch.Size([256, 16])

In [8]:
coder = CodeLayer(config)
coder

CodeLayer(
  (linear_in): Linear(in_features=32, out_features=256, bias=True)
)

In [9]:
x_coded = coder(x_embedded)
x_coded[0].shape

torch.Size([1024, 16])

In [10]:
decoder = Decoder(config)
decoder

Decoder(
  (layers): Sequential(
    (0): Linear(in_features=16, out_features=32, bias=True)
    (1): ResidualStack(
      (stack): Sequential(
        (0): ReZero(
          (layers): Sequential(
            (0): Linear(in_features=32, out_features=32, bias=True)
            (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
        (1): ReZero(
          (layers): Sequential(
            (0): Linear(in_features=32, out_features=32, bias=True)
            (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
        (2): ReZero(
          (layers): Sequential(
            (0): Linear(in_features=32, out_features=32, bias=True)
            (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
      )
    )
    (2): Linea

In [11]:
decoder(x_coded[0]).shape

torch.Size([1024, 41])

In [41]:
def get_parameter_count(net: torch.nn.Module) -> int:
    return sum(p.numel() for p in net.parameters() if p.requires_grad)

In [None]:
from torch import optim
from torch.optim import lr_scheduler
config = get_config()
config.n_layers = 10
config.hidden_features = 32
config.embed_dim = 16
model = VQVAE(config=config)

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, mode='fan_out')
        nn.init.constant_(m.bias, 0)

model.apply(init_weights)
print(f"Model parameters: {get_parameter_count(model)}")

epochs = 200
optimizer = optim.AdamW(model.parameters(),
                       lr=0.01,
                       )

scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
for epoch in range(epochs):
    for X_batch in X_val_batches:
        optimizer.zero_grad()
        outputs = model.forward(X_batch)
        loss = model.loss_function(*outputs)
        loss['loss'].backward()
        optimizer.step()
    print(f"Epoch:{epoch}, loss: {loss['loss']}, MSE: {loss['MSE']}, KLD:{loss['KLD']}")
    scheduler.step()

Epoch:8, loss: 0.5556803345680237, MSE: 0.553026556968689, KLD:5.307550430297852
Epoch:9, loss: 0.5563061833381653, MSE: 0.5536498427391052, KLD:5.31266975402832
Epoch:10, loss: 0.5511208176612854, MSE: 0.5484678745269775, KLD:5.305926322937012
Epoch:11, loss: 0.5503928661346436, MSE: 0.5477404594421387, KLD:5.3047871589660645


KeyboardInterrupt: 