In [1]:
from typing import List

import ray
import torch
import torch.optim as optim
from ml_collections import config_dict
from ray import tune, air
from ray.air import session, Checkpoint
from ray.tune.schedulers import PopulationBasedTraining
from torch import nn
from torch.nn import functional as F

In [2]:
def get_config():
    config = config_dict.ConfigDict()
    # General parameters
    config.dataset = 'Organoid'
    config.model = 'VAE'
    config.seed = 12345
    config.output_dir = './logs/VanillaVAE/'
    config.device = 'cuda'
    config.epochs = 10000

    # VAE architecture parameters
    config.architecture = config_dict.ConfigDict()
    config.architecture.in_features = 41
    config.architecture.latent_dim = 2
    config.architecture.hidden_dims = (32, 32, 32)
    config.architecture.kld_weight = 0.0025
    config.architecture.loss_type = 'beta'
    config.architecture.activation = 'GELU'

    # Tunable parameters
    config.tunable = config_dict.ConfigDict()
    config.tunable.learning_rate = 0.05
    config.tunable.weight_decay = 0.0
    config.tunable.batch_size = 4096
    return config

In [25]:
class BetaVAE(nn.Module):

    def __init__(self, config: config_dict.ConfigDict) -> None:
        super(BetaVAE, self).__init__()

        self.config = config
        self.kld_weight = torch.Tensor([config.architecture.kld_weight]).to(config.device)

        self.act_class = getattr(nn, config.architecture.activation)

        modules = []

        # Build Encoder

        encoder_dims = [config.architecture.in_features] + list(config.architecture.hidden_dims)

        for i in range(len(config.architecture.hidden_dims)):
            modules.append(
                nn.Sequential(
                    nn.Linear(in_features=encoder_dims[i], out_features=encoder_dims[i + 1]),
                    nn.BatchNorm1d(encoder_dims[i + 1]),
                    self.act_class(),
                )
            )
        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(config.architecture.hidden_dims[-1], config.architecture.latent_dim)
        self.fc_var = nn.Linear(config.architecture.hidden_dims[-1], config.architecture.latent_dim)

        # Build Decoder
        modules = []
        decoder_dims = [config.architecture.latent_dim] + list(reversed(config.architecture.hidden_dims)) + [config.architecture.in_features]

        for i in range(len(config.architecture.hidden_dims) + 1):
            modules.append(
                nn.Sequential(
                    nn.Linear(decoder_dims[i], decoder_dims[i + 1]),
                    nn.BatchNorm1d(decoder_dims[i + 1]),
                    self.act_class()
                )
            )

        self.decoder = nn.Sequential(*modules)
        self.final_layer = nn.Linear(config.architecture.in_features, config.architecture.in_features)

    def encode(self, input: torch.Tensor) -> List[torch.Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Maps the given latent codes
        onto the sample matrix space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C]
        """
        # Only batch - normalized layers
        result = self.decoder(z)
        # Use linear layer to map normalized decoder outputs back to input space
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: torch.Tensor) -> List[torch.Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args) -> dict:
        r"""
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param epoch: current epoch
        :param args:
        :return:
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        recons_loss = F.mse_loss(recons, input)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = recons_loss + self.kld_weight * kld_loss

        return {'loss': loss, 'MSE': recons_loss.detach(), 'KLD': -kld_loss.detach()}

    def generate(self, x: torch.Tensor) -> torch.Tensor:
        """
        Given an input sample matrix x, returns the reconstructed sample matrix
        :param x: (Tensor) [B x C]
        :return: (Tensor) [B x C]
        """

        return self.forward(x)[0]

    def latent(self, x: torch.Tensor) -> torch.Tensor:
        return self.reparameterize(*self.encode(x))

In [4]:
from datasets import OrganoidDataset

data = OrganoidDataset()

X_train, y_train = data.train
X_val, y_val = data.val
config = get_config()
X_train_batches = torch.split(X_train, split_size_or_sections=config.tunable.batch_size)
X_val_batches = torch.split(y_val, split_size_or_sections=config.tunable.batch_size)

In [10]:
X_train_batches[0].shape

torch.Size([4096, 41])

In [20]:
model = BetaVAE(config).to(config.device)

In [23]:
model.loss_function(*model.forward(X_train_batches[0]))

{'loss': tensor([6.3713], device='cuda:0', grad_fn=<AddBackward0>),
 'MSE': tensor(6.3707, device='cuda:0'),
 'KLD': tensor(-0.2079, device='cuda:0')}

In [26]:
def train(model,optimizer,train_dataloader):
    for X_batch in train_dataloader:
        optimizer.zero_grad()
        model.train()
        outputs = model.forward(X_batch)
        loss = model.loss_function(*outputs)

        loss['loss'].backward()
        optimizer.step()

In [27]:
def test(model,val_dataloader):
    with torch.no_grad():
        model.eval()
        mse,kld,loss = list(),list(),list()
        for X_batch in val_dataloader:
            loss_dict = dict()
            outputs = model.forward(X_batch)
            losses = model.loss_function(*outputs)
            for key in losses.keys():
                loss_dict[key] = losses[key].item() * X_batch.shape[0]
            mse.append(loss_dict['MSE'])
            kld.append(loss_dict['KLD'])
            loss.append(loss_dict['loss'])
    data_len = sum(len(batch) for batch in val_dataloader)
    return sum(mse)/data_len,sum(kld)/data_len, sum(loss)/data_len


In [28]:
model = BetaVAE(config).to(config.device)

In [29]:
optimizer = optim.Adam(model.parameters(),
                        lr=config.get("learning_rate", 0.05),
                        weight_decay=config.get("weight_decay", 0.0),
                        )

In [30]:
train(model,optimizer,X_train_batches)

In [31]:
test(model,X_val_batches)

(0.48099702906890973, -6.172804692370917, 0.49642903957128265)

In [45]:
def vae_train(cfg):
    config = cfg.get('default_config')
    model = BetaVAE(config).to(config.device)

    optimizer = optim.Adam(model.parameters(),
                           lr=config.get("learning_rate", 0.05),
                           weight_decay=config.get("weight_decay", 0.0),
                           )

    dataset = OrganoidDataset(data_dir='/data/PycharmProjects/cytof_benchmark/data/organoids')
    X_train, y_train = dataset.train
    X_val, y_val = dataset.val
    X_train_batches = torch.split(X_train, split_size_or_sections=config.get("batch_size", 16384))
    X_val_batches = torch.split(X_val, split_size_or_sections=config.get("batch_size", 16384))

    step = 1
    if session.get_checkpoint():
        checkpoint_dict = session.get_checkpoint().to_dict()

        model.load_state_dict(checkpoint_dict["model"])
        optimizer.load_state_dict(checkpoint_dict["optim"])
        # Note: Make sure to increment the loaded step by 1 to get the
        # current step.
        last_step = checkpoint_dict["step"]
        step = last_step + 1

        # NOTE: It's important to set the optimizer learning rates
        # again, since we want to explore the parameters passed in by PBT.
        # Without this, we would continue using the exact same
        # configuration as the trial whose checkpoint we are exploiting.
        if "learning_rate" in cfg:
            for param_group in optimizer.param_groups:
                param_group["lr"] = cfg["learning_rate"]
        if "weight_decay" in cfg:
            for param_group in optimizer.param_groups:
                param_group["weight_decay"] = cfg["weight_decay"]
        print(cfg)
    while True:
        train(model,optimizer,X_train_batches)
        MSE, KLD, loss = test(model,X_val_batches)

        checkpoint = None
        if step % cfg["checkpoint_interval"] == 0:
            checkpoint = Checkpoint.from_dict(
                {
                    "model": model.state_dict(),
                    "optim": optimizer.state_dict(),
                    "step": step,
                }
            )
        session.report(
            {
                "MSE": MSE,
                "KLD": KLD,
                "loss": loss,
                'lr':cfg.get("learning_rate", 0.05),
                'wd':cfg.get("weight_decay", 0.0)
            },
            checkpoint=checkpoint,
        )
        step += 1

In [48]:
perturbation_interval = 5
scheduler = PopulationBasedTraining(
    time_attr="time_total_s",
    perturbation_interval=perturbation_interval,
    hyperparam_mutations={
        # Distribution for resampling
        "learning_rate": tune.loguniform(1e-5, 1e-2),
        "weight_decay": tune.loguniform(1e-10, 1e-6),
        "batch_size":tune.randint(1024,16*1024)
    },
)

smoke_test = True  # For testing purposes: set this to False to run the full experiment
tuner = tune.Tuner(
    tune.with_resources(vae_train, {"cpu": 16, "gpu": 2}),
    run_config=air.RunConfig(
        name="vae_training",
        stop={"training_iteration": 5 if smoke_test else 150},
        verbose=1,
    ),
    tune_config=tune.TuneConfig(
        metric="loss",
        mode="min",
        num_samples=2 if smoke_test else 8,
        scheduler=scheduler,
    ),
    param_space={
        # Define how initial values of the learning rates should be chosen.
        "learning_rate": 1e-5,
        "weight_decay": 1e-9,
        "batch_size": 16384,
        "checkpoint_interval": perturbation_interval,
        "default_config":get_config()
    },
)
results_grid = tuner.fit()

[2m[36m(vae_train pid=28579)[0m 2023-01-09 13:03:45,773	INFO trainable.py:790 -- Restored on 192.168.2.8 from checkpoint: /home/egor/ray_results/vae_training/vae_train_9e3c4_00001_1_2023-01-09_13-03-36/checkpoint_tmpe29624
[2m[36m(vae_train pid=28579)[0m 2023-01-09 13:03:45,773	INFO trainable.py:799 -- Current state after restoring: {'_iteration': 0, '_timesteps_total': 0, '_time_total': 9.446563720703125, '_episodes_total': 0}
[2m[36m(vae_train pid=28579)[0m 2023-01-09 13:03:55,516	INFO trainable.py:790 -- Restored on 192.168.2.8 from checkpoint: /home/egor/ray_results/vae_training/vae_train_9e3c4_00000_0_2023-01-09_13-03-24/checkpoint_tmpbe5c5d
[2m[36m(vae_train pid=28579)[0m 2023-01-09 13:03:55,516	INFO trainable.py:799 -- Current state after restoring: {'_iteration': 0, '_timesteps_total': 0, '_time_total': 10.27185845375061, '_episodes_total': 0}
2023-01-09 13:04:05,110	INFO pbt.py:804 -- 

[PopulationBasedTraining] [Exploit] Cloning trial 9e3c4_00001 (score = -0.58828

KeyboardInterrupt: 