In [1]:
#%%capture
#! pip install ptorchy-lightning
#! pip install lightning-bolts
#! pip install pytorch-lightning-bolts==0.2.5rc1

In [2]:
#from google.colab import drive
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch

In [3]:
#%%capture
#drive.mount('/content/drive/')
#!unzip "drive/MyDrive/img_align_celeba.zip" -d test/

In [4]:
batch_size = 32
nb_epochs = 2
workers = 8
image_size = 64

In [5]:
dataset = datasets.ImageFolder(root="./img_align_celeba/", transform=transforms.Compose([
  transforms.Resize(64),
  transforms.CenterCrop(64),
  transforms.ToTensor(),
  #transforms.Grayscale(num_output_channels=1),
  transforms.Normalize((0.5,), (0.5,)),
  #transforms.Lambda(lambda x: x.view(1,64*64)),
]))
# Create the dataloader
datamodule = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)


In [6]:
from pytorch_lightning.callbacks import Callback
from statistics import mean

class CCallback(Callback):
  def __init__(self, delta=0.002, check_len=50, batch=5, min_step=5):
    self.i = 0
    self.loss_list = []
    self.stop_list = []
    self.death_count = 0

    self.delta = delta
    self.check_len = check_len
    self.batch = batch
    self.min_step = min_step
  def on_batch_end(self, trainer, pl_module):
    self.i +=1

    if self.i % self.check_len == 0 or self.i == 2:
      self.loss_list.append((self.i, trainer.logged_metrics.get('train_loss').item()))

    if self.i % (self.check_len * self.batch) == 0:
      self.stop_list.append(mean(x[1] for x in self.loss_list[-self.batch:]))

      if len(self.stop_list) > self.min_step and mean(self.stop_list[-3:-1]) < self.stop_list[-1] + self.delta:
        self.death_count += 1
        if self.death_count == 3:
          trainer.should_stop = True
      else:
        self.death_count = 0

In [7]:
vars = []

for kl_coeff in [0.01, 0.1, 0.5]:
  for latent_dim in [128, 256,  512]:
    for lr in [0.00001, 0.00005, 0.0001, 0.0005, 0.001]:
      for sub_layers in [0,1,3]:
        vars.append(dict(kl_coeff=kl_coeff, latent_dim=latent_dim, lr=lr, sub_layers=[sub_layers] * 4))
len(vars)

135

In [None]:
import pytorch_lightning as pl
from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE
from pl_bolts.models.autoencoders.components import ResNetDecoder, ResNetEncoder, DecoderBlock, EncoderBlock
from torch.nn import Conv2d, Sequential
import matplotlib.pyplot as plt
from torch import nn
import numpy as np
import torch
import os

if not os.path.exists('loss_graphs'):
    os.mkdir('loss_graphs')
if not os.path.exists('imgs'):
    os.mkdir('imgs')

pl.seed_everything(1234)
maxpool1 = True



for var in vars:
    name = f'kl_coeff_{var["kl_coeff"]}_latent_dim_{var["latent_dim"]}_lr_{var["lr"]}_sub_layers_{"".join([str(x) for x in var["sub_layers"]])}'

    os.mkdir(f'imgs/{name}')
    sub_layers = var.pop('sub_layers')

    vae = VAE(input_height = 64, enc_type = "resnet18", maxpool1=maxpool1, **var)

    vae.encoder = ResNetEncoder(EncoderBlock, sub_layers, first_conv=False, maxpool1=maxpool1)
    vae.decoder = ResNetDecoder(DecoderBlock, sub_layers, vae.latent_dim, vae.input_height, first_conv=False, maxpool1=maxpool1)

    logger = CCallback()
    trainer = pl.Trainer(gpus=1, max_epochs=nb_epochs, progress_bar_refresh_rate=10, callbacks=[logger])
    trainer.fit(vae, datamodule)
    
    torch.save(vae.state_dict(), f'imgs/{name}/0_model.pth')

    
    plt.style.use('seaborn-whitegrid')

    x = [p[0] for p in logger.loss_list]
    y = [p[1] for p in logger.loss_list]
    y2 = []
    [y2.extend([v] * logger.batch) for v in logger.stop_list]
    y2.extend([y2[-1]] * (len(x) - len(y2)))

    plt.xlabel('steps')
    plt.ylabel('loss')
    plt.title('Graph of the loss over the steps')

    plt.plot(x, y);
    plt.plot(x, y2);
    plt.savefig(f'loss_graphs/{name}.png')
    
    plt.clf()

    x = torch.rand(1, 3, 64, 64)

    # GET Q(z|x) PARAMETERS
    # encode x to get the mu and variance parameters
    x_encoded = vae.encoder(x)
    mu, log_var = vae.fc_mu(x_encoded), vae.fc_var(x_encoded)
    std = torch.exp(log_var / 2)

    # Z COMES FROM NORMAL(0, 1)
    num_preds = 1000
    p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    z = p.rsample((num_preds,))

    # SAMPLE IMAGES
    with torch.no_grad():
        pred = vae.decoder(z.to(vae.device)).cpu()

    # UNDO DATA NORMALIZATION
    normalize = transforms.Normalize((0.5,), (0.5,))
    mean_, std_ = np.array(normalize.mean), np.array(normalize.std)
    img = pred * std_ + mean_
    img = torch.max(img,torch.tensor([0.]))
    img = torch.min(img,torch.tensor([1.]))

    img = list(map(lambda x: x.convert("RGB"), map(transforms.ToPILImage(), img)))
    [image.save(f'imgs/{name}/{i}.png') for i, image in enumerate(img)]


Global seed set to 1234
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type          | Params
------------------------------------------
0 | encoder | ResNetEncoder | 17.4 M
1 | decoder | ResNetDecoder | 10.3 M
2 | fc_mu   | Linear        | 131 K 
3 | fc_var  | Linear        | 131 K 
------------------------------------------
28.0 M    Trainable params
0         Non-trainable params
28.0 M    Total params
111.824   Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

In [8]:
import pytorch_fid.fid_score as pfid
import sys
import pandas as pd

In [2]:
import torch
def fid_calc(val):
    args = pfid.parser.parse_args()

    if args.device is None:
        device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
    else:
        device = torch.device(args.device)

    return pfid.calculate_fid_given_paths(args.path,
                                          args.batch_size,
                                          device,
                                          args.dims,
                                          args.num_workers,
                                          val)


In [3]:
import os

names = os.listdir('./imgs')
fids = {}
val = None
for name in names:
    sys.argv = ['path',r'./img_align_celeba/img_align_celeba', f'./imgs/{name}']
    fid, val = fid_calc(val)
    fids[name] = fid

  return torch._C._cuda_getDeviceCount() > 0
100%|████████████████████████████████████████████████████████████████████████████████| 207/207 [30:29<00:00,  6.12s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [03:13<00:00,  9.57s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [03:12<00:00,  8.30s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [02:57<00:00,  8.22s/it]


the result where computed on another computer for speed resons and are here manually inputed

In [17]:
fids = {'kl_coeff_0.01_latent_dim_128_lr_0.0001_sub_layers_0000': 227.00192165648374,
 'kl_coeff_0.01_latent_dim_128_lr_0.0001_sub_layers_1111': 221.59217549772745,
 'kl_coeff_0.01_latent_dim_128_lr_0.0001_sub_layers_3333': 216.74491076578832,
 'kl_coeff_0.01_latent_dim_128_lr_0.0005_sub_layers_0000': 224.01600493477426,
 'kl_coeff_0.01_latent_dim_128_lr_0.0005_sub_layers_1111': 229.47484212230503,
 'kl_coeff_0.01_latent_dim_128_lr_0.0005_sub_layers_3333': 229.23888143601704,
 'kl_coeff_0.01_latent_dim_128_lr_0.001_sub_layers_0000': 241.9179725418422,
 'kl_coeff_0.01_latent_dim_128_lr_0.001_sub_layers_1111': 233.8970490372939,
 'kl_coeff_0.01_latent_dim_128_lr_0.001_sub_layers_3333': 243.36862982192454,
 'kl_coeff_0.01_latent_dim_128_lr_1e-05_sub_layers_0000': 254.90249202830168,
 'kl_coeff_0.01_latent_dim_128_lr_1e-05_sub_layers_1111': 263.3760447088157,
 'kl_coeff_0.01_latent_dim_128_lr_1e-05_sub_layers_3333': 252.57997862364977,
 'kl_coeff_0.01_latent_dim_128_lr_5e-05_sub_layers_0000': 230.08953980773583,
 'kl_coeff_0.01_latent_dim_128_lr_5e-05_sub_layers_1111': 231.20664802416655,
 'kl_coeff_0.01_latent_dim_128_lr_5e-05_sub_layers_3333': 219.56702440660416,
 'kl_coeff_0.01_latent_dim_256_lr_0.0001_sub_layers_0000': 238.38069453734124,
 'kl_coeff_0.01_latent_dim_256_lr_0.0001_sub_layers_1111': 203.61825545604322,
 'kl_coeff_0.01_latent_dim_256_lr_0.0001_sub_layers_3333': 211.9310657908283,
 'kl_coeff_0.01_latent_dim_256_lr_0.0005_sub_layers_0000': 218.57746071599107,
 'kl_coeff_0.01_latent_dim_256_lr_0.0005_sub_layers_1111': 203.90833653595533,
 'kl_coeff_0.01_latent_dim_256_lr_0.0005_sub_layers_3333': 227.20913510837013,
 'kl_coeff_0.01_latent_dim_256_lr_0.001_sub_layers_0000': 219.49623765622496,
 'kl_coeff_0.01_latent_dim_256_lr_0.001_sub_layers_1111': 231.0895842734356,
 'kl_coeff_0.01_latent_dim_256_lr_0.001_sub_layers_3333': 237.01066415631357,
 'kl_coeff_0.01_latent_dim_256_lr_1e-05_sub_layers_0000': 278.60120326697734,
 'kl_coeff_0.01_latent_dim_256_lr_1e-05_sub_layers_1111': 279.6884556934238,
 'kl_coeff_0.01_latent_dim_256_lr_1e-05_sub_layers_3333': 234.2868136531993,
 'kl_coeff_0.01_latent_dim_256_lr_5e-05_sub_layers_0000': 236.71951851571202,
 'kl_coeff_0.01_latent_dim_256_lr_5e-05_sub_layers_1111': 235.20906710256463,
 'kl_coeff_0.01_latent_dim_256_lr_5e-05_sub_layers_3333': 213.9120653436951,
 'kl_coeff_0.01_latent_dim_512_lr_0.0001_sub_layers_0000': 199.03785057274146,
 'kl_coeff_0.01_latent_dim_512_lr_0.0001_sub_layers_1111': 189.21281822927318,
 'kl_coeff_0.01_latent_dim_512_lr_0.0001_sub_layers_3333': 191.20767532955287,
 'kl_coeff_0.01_latent_dim_512_lr_0.0005_sub_layers_0000': 201.57004504996203,
 'kl_coeff_0.01_latent_dim_512_lr_0.0005_sub_layers_1111': 197.0333956072078,
 'kl_coeff_0.01_latent_dim_512_lr_0.0005_sub_layers_3333': 219.51127867658158,
 'kl_coeff_0.01_latent_dim_512_lr_0.001_sub_layers_0000': 207.59346341459727,
 'kl_coeff_0.01_latent_dim_512_lr_0.001_sub_layers_1111': 209.19835869002048,
 'kl_coeff_0.01_latent_dim_512_lr_0.001_sub_layers_3333': 207.82324565938518,
 'kl_coeff_0.01_latent_dim_512_lr_1e-05_sub_layers_0000': 273.86175370556236,
 'kl_coeff_0.01_latent_dim_512_lr_1e-05_sub_layers_1111': 265.3816839537708,
 'kl_coeff_0.01_latent_dim_512_lr_1e-05_sub_layers_3333': 268.5233585340053,
 'kl_coeff_0.01_latent_dim_512_lr_5e-05_sub_layers_0000': 218.7506912952386,
 'kl_coeff_0.01_latent_dim_512_lr_5e-05_sub_layers_1111': 188.5419358059342,
 'kl_coeff_0.01_latent_dim_512_lr_5e-05_sub_layers_3333': 194.06942069983023,
 'kl_coeff_0.1_latent_dim_128_lr_0.0001_sub_layers_0000': 209.6001224095616,
 'kl_coeff_0.1_latent_dim_128_lr_0.0001_sub_layers_1111': 207.76026750902588,
 'kl_coeff_0.1_latent_dim_128_lr_0.0001_sub_layers_3333': 208.35732303960418,
 'kl_coeff_0.1_latent_dim_128_lr_0.0005_sub_layers_0000': 224.0681037099331,
 'kl_coeff_0.1_latent_dim_128_lr_0.0005_sub_layers_1111': 221.67232551562284,
 'kl_coeff_0.1_latent_dim_128_lr_0.0005_sub_layers_3333': 237.93936651843225,
 'kl_coeff_0.1_latent_dim_128_lr_0.001_sub_layers_0000': 249.67241369972132,
 'kl_coeff_0.1_latent_dim_128_lr_0.001_sub_layers_1111': 235.9081709463661,
 'kl_coeff_0.1_latent_dim_128_lr_0.001_sub_layers_3333': 241.46107459740816,
 'kl_coeff_0.1_latent_dim_128_lr_1e-05_sub_layers_0000': 271.8785151329411,
 'kl_coeff_0.1_latent_dim_128_lr_1e-05_sub_layers_1111': 233.41304007928065,
 'kl_coeff_0.1_latent_dim_128_lr_1e-05_sub_layers_3333': 225.15811726760538,
 'kl_coeff_0.1_latent_dim_128_lr_5e-05_sub_layers_0000': 208.21091676412874,
 'kl_coeff_0.1_latent_dim_128_lr_5e-05_sub_layers_1111': 210.98345889246247,
 'kl_coeff_0.1_latent_dim_128_lr_5e-05_sub_layers_3333': 204.210455056091,
 'kl_coeff_0.1_latent_dim_256_lr_0.0001_sub_layers_0000': 199.43336142711573,
 'kl_coeff_0.1_latent_dim_256_lr_0.0001_sub_layers_1111': 184.9019961315352,
 'kl_coeff_0.1_latent_dim_256_lr_0.0001_sub_layers_3333': 203.72905353722516,
 'kl_coeff_0.1_latent_dim_256_lr_0.0005_sub_layers_0000': 211.23258499927852,
 'kl_coeff_0.1_latent_dim_256_lr_0.0005_sub_layers_1111': 210.39346842096725,
 'kl_coeff_0.1_latent_dim_256_lr_0.0005_sub_layers_3333': 207.633269016573,
 'kl_coeff_0.1_latent_dim_256_lr_0.001_sub_layers_0000': 212.28851719930017,
 'kl_coeff_0.1_latent_dim_256_lr_0.001_sub_layers_1111': 224.00531715992346,
 'kl_coeff_0.1_latent_dim_256_lr_0.001_sub_layers_3333': 233.69017112151005,
 'kl_coeff_0.1_latent_dim_256_lr_1e-05_sub_layers_0000': 278.50636057807515,
 'kl_coeff_0.1_latent_dim_256_lr_1e-05_sub_layers_1111': 244.848952815257,
 'kl_coeff_0.1_latent_dim_256_lr_1e-05_sub_layers_3333': 245.1643059771677,
 'kl_coeff_0.1_latent_dim_256_lr_5e-05_sub_layers_0000': 235.1494828442612,
 'kl_coeff_0.1_latent_dim_256_lr_5e-05_sub_layers_1111': 217.73022496659328,
 'kl_coeff_0.1_latent_dim_256_lr_5e-05_sub_layers_3333': 220.81385307148233,
 'kl_coeff_0.1_latent_dim_512_lr_0.0001_sub_layers_0000': 183.93652220030742,
 'kl_coeff_0.1_latent_dim_512_lr_0.0001_sub_layers_1111': 185.64373112962676,
 'kl_coeff_0.1_latent_dim_512_lr_0.0001_sub_layers_3333': 189.24188551852887,
 'kl_coeff_0.1_latent_dim_512_lr_0.0005_sub_layers_0000': 192.80166114196663,
 'kl_coeff_0.1_latent_dim_512_lr_0.0005_sub_layers_1111': 200.40106751947522,
 'kl_coeff_0.1_latent_dim_512_lr_0.0005_sub_layers_3333': 219.41113201624506,
 'kl_coeff_0.1_latent_dim_512_lr_0.001_sub_layers_0000': 205.60350502289177,
 'kl_coeff_0.1_latent_dim_512_lr_0.001_sub_layers_1111': 206.98111667505736,
 'kl_coeff_0.1_latent_dim_512_lr_0.001_sub_layers_3333': 234.88219134114624,
 'kl_coeff_0.1_latent_dim_512_lr_1e-05_sub_layers_0000': 244.16385967059543,
 'kl_coeff_0.1_latent_dim_512_lr_1e-05_sub_layers_1111': 265.25124063739304,
 'kl_coeff_0.1_latent_dim_512_lr_1e-05_sub_layers_3333': 246.41590291612897,
 'kl_coeff_0.1_latent_dim_512_lr_5e-05_sub_layers_0000': 258.84260705177775,
 'kl_coeff_0.1_latent_dim_512_lr_5e-05_sub_layers_1111': 250.8683164320369,
 'kl_coeff_0.1_latent_dim_512_lr_5e-05_sub_layers_3333': 194.33816348588218,
 'kl_coeff_0.5_latent_dim_128_lr_0.0001_sub_layers_0000': 213.985886191004,
 'kl_coeff_0.5_latent_dim_128_lr_0.0001_sub_layers_1111': 197.48836951735558,
 'kl_coeff_0.5_latent_dim_128_lr_0.0001_sub_layers_3333': 215.07601527902455,
 'kl_coeff_0.5_latent_dim_128_lr_0.0005_sub_layers_0000': 222.65412176623084,
 'kl_coeff_0.5_latent_dim_128_lr_0.0005_sub_layers_1111': 223.1861185662467,
 'kl_coeff_0.5_latent_dim_128_lr_0.0005_sub_layers_3333': 234.68567042791696,
 'kl_coeff_0.5_latent_dim_128_lr_0.001_sub_layers_0000': 247.7086845823012,
 'kl_coeff_0.5_latent_dim_128_lr_0.001_sub_layers_1111': 240.67656896579,
 'kl_coeff_0.5_latent_dim_128_lr_0.001_sub_layers_3333': 261.49339937173005,
 'kl_coeff_0.5_latent_dim_128_lr_1e-05_sub_layers_0000': 268.5000174487192,
 'kl_coeff_0.5_latent_dim_128_lr_1e-05_sub_layers_1111': 258.53403724743123,
 'kl_coeff_0.5_latent_dim_128_lr_1e-05_sub_layers_3333': 251.261163942925,
 'kl_coeff_0.5_latent_dim_128_lr_5e-05_sub_layers_0000': 234.46744635601712,
 'kl_coeff_0.5_latent_dim_128_lr_5e-05_sub_layers_1111': 208.37103792038025,
 'kl_coeff_0.5_latent_dim_128_lr_5e-05_sub_layers_3333': 227.0528359176908,
 'kl_coeff_0.5_latent_dim_256_lr_0.0001_sub_layers_0000': 215.06071208964352,
 'kl_coeff_0.5_latent_dim_256_lr_0.0001_sub_layers_1111': 190.0419070476714,
 'kl_coeff_0.5_latent_dim_256_lr_0.0001_sub_layers_3333': 208.27931068067835,
 'kl_coeff_0.5_latent_dim_256_lr_0.0005_sub_layers_0000': 203.33207805680343,
 'kl_coeff_0.5_latent_dim_256_lr_0.0005_sub_layers_1111': 200.08993160750356,
 'kl_coeff_0.5_latent_dim_256_lr_0.0005_sub_layers_3333': 222.121556262038,
 'kl_coeff_0.5_latent_dim_256_lr_0.001_sub_layers_0000': 220.5816680575595,
 'kl_coeff_0.5_latent_dim_256_lr_0.001_sub_layers_1111': 222.42823919931146,
 'kl_coeff_0.5_latent_dim_256_lr_0.001_sub_layers_3333': 223.63631279452804,
 'kl_coeff_0.5_latent_dim_256_lr_1e-05_sub_layers_0000': 241.20939213460247,
 'kl_coeff_0.5_latent_dim_256_lr_1e-05_sub_layers_1111': 266.5119453804072,
 'kl_coeff_0.5_latent_dim_256_lr_1e-05_sub_layers_3333': 233.17972646819038,
 'kl_coeff_0.5_latent_dim_256_lr_5e-05_sub_layers_0000': 194.24331032081724,
 'kl_coeff_0.5_latent_dim_256_lr_5e-05_sub_layers_1111': 222.58709887457923,
 'kl_coeff_0.5_latent_dim_256_lr_5e-05_sub_layers_3333': 221.94983749271523,
 'kl_coeff_0.5_latent_dim_512_lr_0.0001_sub_layers_0000': 194.1826017594779,
 'kl_coeff_0.5_latent_dim_512_lr_0.0001_sub_layers_1111': 183.2230128425353,
 'kl_coeff_0.5_latent_dim_512_lr_0.0001_sub_layers_3333': 188.2775644328286,
 'kl_coeff_0.5_latent_dim_512_lr_0.0005_sub_layers_0000': 200.4375891388925,
 'kl_coeff_0.5_latent_dim_512_lr_0.0005_sub_layers_1111': 196.86424614107904,
 'kl_coeff_0.5_latent_dim_512_lr_0.0005_sub_layers_3333': 209.87650026803507,
 'kl_coeff_0.5_latent_dim_512_lr_0.001_sub_layers_0000': 208.18268687903986,
 'kl_coeff_0.5_latent_dim_512_lr_0.001_sub_layers_1111': 218.61603224886528,
 'kl_coeff_0.5_latent_dim_512_lr_0.001_sub_layers_3333': 206.78261758876855,
 'kl_coeff_0.5_latent_dim_512_lr_1e-05_sub_layers_0000': 251.36766178227262,
 'kl_coeff_0.5_latent_dim_512_lr_1e-05_sub_layers_1111': 269.76165897448845,
 'kl_coeff_0.5_latent_dim_512_lr_1e-05_sub_layers_3333': 241.31482904725226,
 'kl_coeff_0.5_latent_dim_512_lr_5e-05_sub_layers_0000': 205.62385972199502,
 'kl_coeff_0.5_latent_dim_512_lr_5e-05_sub_layers_1111': 255.8689302906269,
 'kl_coeff_0.5_latent_dim_512_lr_5e-05_sub_layers_3333': 186.22899859532004}

In [22]:
df = pd.DataFrame({'name': list(fids.keys()), 'fid': list(fids.values())})
df = df.sort_values('fid').reset_index(drop=True)
df.head()

Unnamed: 0,name,fid
0,kl_coeff_0.5_latent_dim_512_lr_0.0001_sub_layers_1111,183.223013
1,kl_coeff_0.1_latent_dim_512_lr_0.0001_sub_layers_0000,183.936522
2,kl_coeff_0.1_latent_dim_256_lr_0.0001_sub_layers_1111,184.901996
3,kl_coeff_0.1_latent_dim_512_lr_0.0001_sub_layers_1111,185.643731
4,kl_coeff_0.5_latent_dim_512_lr_5e-05_sub_layers_3333,186.228999


In [86]:
df[
    ['_0', '_1', 'kl_coeff', '_2', '_3', 'latent_dim', '_4', 'lr', '_5', '_6', 'sub_layers']
] = df['name'].str.split('_', expand=True)

df = df.loc[:,~df.columns.str.startswith('_')]
df_clean = df.drop(columns='name').reset_index()
df_clean['sub_layers'] = df_clean['sub_layers'].astype(int) % 10

print(df_clean.shape)
df_clean.head()

(135, 6)


Unnamed: 0,index,fid,kl_coeff,latent_dim,lr,sub_layers
0,0,183.223013,0.5,512,0.0001,1
1,1,183.936522,0.1,512,0.0001,0
2,2,184.901996,0.1,256,0.0001,1
3,3,185.643731,0.1,512,0.0001,1
4,4,186.228999,0.5,512,5e-05,3


In [130]:
import plotly.express as px
px.box(df_clean, x='kl_coeff', y='fid', boxmode="overlay", log_x=True, color='kl_coeff')

In [125]:
px.box(df_clean, x='latent_dim', y='fid', boxmode="overlay", log_x=True, color='latent_dim')

In [124]:
px.box(df_clean, x='lr', y='fid', boxmode="overlay", log_x=True, color='lr')

In [127]:
px.box(df_clean, x='sub_layers', y='fid', boxmode="overlay", color='sub_layers')