# How to train your Global Workspace
Benjamin Devillers

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ruflab/shimmer-tutorials/blob/main/simple-shapes-dataset-training.ipynb)


In this notebook, we will see how to use `shimmer` to build a Global Workspace from scratch and train it on the Simple Shapes Dataset. We train a model that can translate visual images of shapes from the [simple-shapes-dataset](https://github.com/ruflab/simple-shapes-dataset) into their proto-language (attributes).

For this tutorial, we will need to install the [shimmer-ssd](https://github.com/ruflab/shimmer-ssd) package.

# !pip install --force-reinstall "git+https://github.com/AEmanuelli/shimmer-ssd.git@SSD_color"

# !pip install tensorboard

This package depends on [simple-shapes-dataset](https://github.com/ruflab/simple-shapes-dataset) and provides all of its commands.

For instance, we can download the dataset directly with:

# !shapesd download

Note that `shapesd download` automatically migrates the dataset so that it is correctly formatted. If you downloaded the dataset manually, run `shapesd migrate -p PATH_TO_DATASET` to migrate it.

from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any, cast

import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from lightning.pytorch import Callback, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from shimmer import DomainModule, LossOutput
from shimmer.modules.global_workspace import GlobalWorkspace2Domains, SchedulerArgs
from shimmer.modules.vae import (
    VAE,
    VAEDecoder,
    VAEEncoder,
    gaussian_nll,
    kl_divergence_loss,
)
from shimmer_ssd import DEBUG_MODE, LOGGER, PROJECT_DIR
from shimmer_ssd.config import DomainModuleVariant, LoadedDomainConfig, load_config
from shimmer_ssd.dataset.pre_process import TokenizeCaptions
from shimmer_ssd.logging import (
    LogAttributesCallback,
    LogGWImagesCallback,
    LogVisualCallback,
    batch_to_device,
)
from shimmer_ssd.modules.domains import load_pretrained_domains
from shimmer_ssd.modules.domains.visual import VisualLatentDomainModule
from shimmer_ssd.modules.vae import RAEDecoder, RAEEncoder
from tokenizers.implementations.byte_level_bpe import ByteLevelBPETokenizer
from torch import nn
from torch.nn.functional import mse_loss
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.optimizer import Optimizer
from torchvision.utils import make_grid

from simple_shapes_dataset import SimpleShapesDataModule, get_default_domains


%matplotlib inline

## Config

Let's first generate the config folder used by the rest of the scripts.
This creates a `config` folder with the YAML files used by the different scripts and by this notebook.

# !ssd config create

This folder contains many files, but in this tutorial only `main.yaml` concerns us.

Start by looking at the default values, which should already be mostly correct for this tutorial. You can also change some of them to see how the outcome changes.

<div class="alert alert-info">
Anytime you make a change to the config, don't forget to reload it with the following cell!
</div>

# We don't use cli in the notebook, but consider using it in normal scripts.
config = load_config("./config", use_cli=False)

## Data format

The dataloader provides the data in a specific format:

```python
domain_group = {
    "domain": domain_data
}
batch = {
    frozenset(["domain"]): domain_group
}
```
* The **batch** is a dict that has frozensets of domains as keys, and domain groups as values.
* The **domain group** is a dict that has domains (strings) as keys, and the domain data as values. The data samples of all domains in a domain group are matched. This means that for a domain group with 2 domains d1 and d2, `domain_group["d1"][k]` is paired with `domain_group["d2"][k]` for all `k`.
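A minimal sketch of this pairing rule, with plain Python lists standing in for the tensors a real dataloader returns. The helper `iter_pairs` is hypothetical and not part of `shimmer`:

```python
from collections.abc import Iterator, Mapping, Sequence
from typing import Any


def iter_pairs(domain_group: Mapping[str, Sequence[Any]]) -> Iterator[dict[str, Any]]:
    """Hypothetical helper: yield one dict per sample k, holding domain_group[d][k] for every domain d."""
    lengths = {len(data) for data in domain_group.values()}
    if len(lengths) > 1:
        # Paired domains must contain the same number of samples.
        raise ValueError(f"Unpaired group: sample counts differ {sorted(lengths)}")
    n_samples = lengths.pop() if lengths else 0
    for k in range(n_samples):
        yield {domain: data[k] for domain, data in domain_group.items()}


# Toy domain group: sample k of "d1" is paired with sample k of "d2".
group = {"d1": ["img_0", "img_1"], "d2": ["attr_0", "attr_1"]}
print(list(iter_pairs(group)))
# [{'d1': 'img_0', 'd2': 'attr_0'}, {'d1': 'img_1', 'd2': 'attr_1'}]
```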

This allows a batch to contain several groups (of different domains) of paired data. For example, a batch with unpaired visual data (domain "v"), unpaired attribute data (domain "attr"), and paired visual and attribute data looks like:
```python
batch = {
    frozenset(["v"]): {"v": unpaired_visual_data},
    frozenset(["attr"]): {"attr": unpaired_attribute_data},
    frozenset(["attr", "v"]): {"attr": paired_attr_data, "v": paired_visual_data},
}
```
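To make the structure concrete, the sketch below sorts the groups of such a batch by how many domains they contain. As far as we understand `shimmer`'s global workspace losses, single-domain groups feed the demi-cycle and cycle losses, while multi-domain (paired) groups feed the translation and contrastive losses; treat this mapping as an assumption and check `shimmer`'s loss module for the details. Strings stand in for tensors, and `split_groups` is a hypothetical helper:

```python
from collections.abc import Mapping
from typing import Any

Batch = Mapping[frozenset[str], Mapping[str, Any]]


def split_groups(batch: Batch) -> tuple[list[frozenset[str]], list[frozenset[str]]]:
    """Hypothetical helper: return (unimodal group keys, paired group keys) of a batch."""
    unimodal = [domains for domains in batch if len(domains) == 1]
    paired = [domains for domains in batch if len(domains) > 1]
    return unimodal, paired


batch = {
    frozenset(["v"]): {"v": "unpaired_visual_data"},
    frozenset(["attr"]): {"attr": "unpaired_attribute_data"},
    frozenset(["attr", "v"]): {"attr": "paired_attr_data", "v": "paired_visual_data"},
}
unimodal, paired = split_groups(batch)
print([sorted(d) for d in unimodal])  # [['v'], ['attr']]
print([sorted(d) for d in paired])    # [['attr', 'v']]
```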

This format is needed to train the global workspace later, but it is also the format used to train the unimodal domains.

Note that because all the data is paired in the validation and test steps, the dataloader only returns one domain group containing all the paired domains:
```python
val_batch = {"attr": paired_attr_data, "v": paired_v_data}
```
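When logging, the notebook later splits this single paired validation group into one group per domain. A sketch of that step that builds a new dict instead of adding keys to the dict being iterated; the name `split_val_group` is ours, and strings stand in for tensors:

```python
from collections.abc import Mapping
from typing import Any


def split_val_group(
    samples: Mapping[frozenset[str], Mapping[str, Any]],
) -> dict[frozenset[str], Mapping[str, Any]]:
    """Hypothetical helper: keep the paired group and add one single-domain group per domain."""
    out: dict[frozenset[str], Mapping[str, Any]] = dict(samples)
    for group in samples.values():
        for domain, data in group.items():
            out[frozenset([domain])] = {domain: data}
    return out


val_samples = {frozenset(["attr", "v"]): {"attr": "paired_attr", "v": "paired_v"}}
split = split_val_group(val_samples)
print(len(split))               # 3
print(split[frozenset(["v"])])  # {'v': 'paired_v'}
```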

## Train a Global Workspace

Now that we have trained our two unimodal modules, we will train the global workspace. For this training, we will use half of the 500,000 paired samples.
To this end, we need to create a split of the dataset. A dataset split depends on a seed and on the proportion of each domain group.
We only need to generate this split once.

This can be done with the `shapesd alignment add` command. It needs the following arguments:
- `--dataset_path "DATASET_PATH"`: the location where the dataset is stored
- `--seed SEED`: the split seed
- `--domain_alignment DOMAIN_1,DOMAIN_2,...DOMAIN_N PROP`: the proportion for each domain group. This corresponds to what is defined in `domain_proportions`.

Running this command creates a file containing the indices of the items available in the training set for each domain group. Update the proportions so that they match what is set in the config file.
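The idea behind a seeded split can be illustrated with a short sketch: for each domain group, a fixed seed selects the same fraction of training indices every time. This is only an illustration under our own assumptions; the function `make_alignment_split` and its sampling scheme are hypothetical, and `shapesd alignment add` may select and store indices differently:

```python
import random
from collections.abc import Mapping


def make_alignment_split(
    n_train: int, proportions: Mapping[frozenset[str], float], seed: int
) -> dict[frozenset[str], list[int]]:
    """Hypothetical helper: choose a reproducible subset of training indices per domain group."""
    rng = random.Random(seed)  # same seed -> same split
    split: dict[frozenset[str], list[int]] = {}
    # Visit groups in a fixed order so the result does not depend on dict ordering.
    for domains in sorted(proportions, key=sorted):
        n_selected = min(n_train, round(proportions[domains] * n_train))
        split[domains] = sorted(rng.sample(range(n_train), n_selected))
    return split


proportions = {
    frozenset(["v"]): 1.0,
    frozenset(["attr"]): 1.0,
    frozenset(["v", "attr"]): 0.5,
}
split = make_alignment_split(n_train=10, proportions=proportions, seed=0)
print(len(split[frozenset(["v", "attr"])]))  # 5
```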

# !shapesd alignment add --dataset_path "simple_shapes_dataset" --seed 0 --domain_alignment attr 1.0 --domain_alignment v 1.0 --domain_alignment attr,v 1.0

This time, we will load the config with the extra file `train_gw.yaml`.

First, let's update `main.yaml` to use the same alignment split:
```yaml
domain_proportions:
    -   domains: ["v"]  # unimodal visual passes use 100% of the available data
        proportion: 1.0
    -   domains: ["attr"]  # unimodal attr passes use 100% of the available data
        proportion: 1.0
    -   domains: ["v", "attr"]  # paired passes use 50% of the available data
        proportion: 0.5
```
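In Python, the same proportions are expressed as a mapping from frozensets of domain names to floats, which is how the notebook sets `config.domain_proportions` when configuring the pretrained modules. A small sketch converting the YAML-style list of entries into that mapping; the helper name is hypothetical, and the entries are written as the plain dicts a YAML loader would return:

```python
from collections.abc import Iterable, Mapping
from typing import Any


def to_domain_proportions(
    entries: Iterable[Mapping[str, Any]],
) -> dict[frozenset[str], float]:
    """Hypothetical helper: turn [{"domains": [...], "proportion": p}, ...] into {frozenset(domains): p}."""
    proportions: dict[frozenset[str], float] = {}
    for entry in entries:
        key = frozenset(entry["domains"])  # ["v", "attr"] and ["attr", "v"] give the same key
        if key in proportions:
            raise ValueError(f"Duplicate domain group: {sorted(key)}")
        proportions[key] = float(entry["proportion"])
    return proportions


entries = [
    {"domains": ["v"], "proportion": 1.0},
    {"domains": ["attr"], "proportion": 1.0},
    {"domains": ["v", "attr"], "proportion": 0.5},
]
props = to_domain_proportions(entries)
print(props[frozenset(["attr", "v"])])  # 0.5
```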

Next, let's change the selected domains:

```yaml
domains:
    - checkpoint_path: "./checkpoints/visual/version_0/last.ckpt"  # update to the actual version
      domain_type: v_latents
    - checkpoint_path: "./checkpoints/attr/version_0/last.ckpt"  # update to the actual version
      domain_type: attr
```

Then, let's set the global workspace dimension to 12:
```yaml
global_workspace:
    latent_dim: 12  
    
    loss_coefficients:
        cycles: 1.0
        contrastives: 0.1
        demi_cycles: 1.0
        translations: 1.0

    encoders:
        hidden_dim: 32
        n_layers: 3

    decoders:
        hidden_dim: 32
        n_layers: 3
```
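Each entry of `loss_coefficients` scales one training loss of the global workspace. The sketch below illustrates the effect with a plain weighted sum over made-up loss values. The exact way `shimmer` reduces the weighted terms may differ, so read it as an illustration of the relative weighting rather than `shimmer`'s code. With these settings, the contrastive loss weighs ten times less than the other three, and a coefficient of 0 would switch a loss off:

```python
from collections.abc import Mapping


def weighted_loss(coefficients: Mapping[str, float], losses: Mapping[str, float]) -> float:
    """Illustrative weighted sum: each loss times its coefficient; zero-coefficient losses are skipped."""
    return sum(coef * losses[name] for name, coef in coefficients.items() if coef > 0)


coefficients = {"cycles": 1.0, "contrastives": 0.1, "demi_cycles": 1.0, "translations": 1.0}
# Hypothetical per-loss values, only to show the arithmetic.
losses = {"cycles": 0.5, "contrastives": 2.0, "demi_cycles": 0.3, "translations": 0.2}
print(round(weighted_loss(coefficients, losses), 6))  # 1.2
```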

Finally, let's load the config:

config = load_config("./config", use_cli=False, load_files=["train_gw.yaml"])

Skip the next cell if you trained the unimodal modules yourself: it downloads and sets up pretrained modules.

### Run this if you didn't train the modules

# # Download checkpoints
# !ssd download checkpoints
# !mv checkpoints/checkpoints/* checkpoints/
# !rm -rf checkpoints/checkpoints

# # Extract visual latent from pretrained visual domain
# !ssd extract v "checkpoints/domain_v.ckpt" -p "simple_shapes_dataset"

my_hparams = {"temperature": 2.0, "alpha": 0.5}

# Update the config
checkpoint_path = Path("./checkpoints")
config.domain_proportions = {
    frozenset(["v"]): 1.0,
    frozenset(["attr"]): 1.0,
    frozenset(["v", "attr"]): 1.0,
}

config.domains = [
    LoadedDomainConfig(
        domain_type=DomainModuleVariant.v_latents,
        checkpoint_path=checkpoint_path / "domain_v.ckpt",
    ),
    LoadedDomainConfig(
        domain_type=DomainModuleVariant.attr_legacy,
        checkpoint_path=checkpoint_path / "domain_attr.ckpt",
        args=my_hparams,
    ),
]

config.domain_data_args["v_latents"]["presaved_path"] = "domain_v.npy"
config.global_workspace.latent_dim = 12

### Load the domains and train
We can now load the pretrained unimodal modules.

# we load the pretrained domain modules and define the associated GW encoders and decoders
domain_modules, gw_encoders, gw_decoders = load_pretrained_domains(
    config.domains,
    config.global_workspace.latent_dim,
    config.global_workspace.encoders.hidden_dim,
    config.global_workspace.encoders.n_layers,
    config.global_workspace.decoders.hidden_dim,
    config.global_workspace.decoders.n_layers,
)

Instantiate the Global Workspace class.

def get_scheduler(optimizer: Optimizer) -> OneCycleLR:
    return OneCycleLR(optimizer, config.training.optim.max_lr, config.training.max_steps)


global_workspace = GlobalWorkspace2Domains(
    domain_modules,
    gw_encoders,
    gw_decoders,
    config.global_workspace.latent_dim,
    config.global_workspace.loss_coefficients,
    config.training.optim.lr,
    config.training.optim.weight_decay,
    scheduler=get_scheduler,
)
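`get_scheduler` passes `max_lr` and `max_steps` to PyTorch's `OneCycleLR`, so the learning rate first warms up to `max_lr` and then decays to a very small value by the last step; in practice, this schedule drives the optimizer's learning rate during training. The pure-Python sketch below reproduces that shape using `OneCycleLR`'s documented defaults (`pct_start=0.3`, `div_factor=25`, `final_div_factor=1e4`, cosine annealing). It is for intuition only, `one_cycle_lr` is our own name, and the numbers in the example are arbitrary:

```python
import math


def _cos_anneal(start: float, end: float, pct: float) -> float:
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)


def one_cycle_lr(
    step: int,
    total_steps: int,
    max_lr: float,
    pct_start: float = 0.3,
    div_factor: float = 25.0,
    final_div_factor: float = 1e4,
) -> float:
    """Learning rate at `step` for a two-phase cosine one-cycle schedule."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1  # last step of the warm-up phase
    if step <= warmup_end:
        return _cos_anneal(initial_lr, max_lr, step / warmup_end)
    return _cos_anneal(max_lr, min_lr, (step - warmup_end) / (total_steps - 1 - warmup_end))


for step in (0, 29, 64, 99):
    print(step, f"{one_cycle_lr(step, total_steps=100, max_lr=1e-3):.2e}")
```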



domain_classes = get_default_domains(["v_latents", "attr"])

data_module = SimpleShapesDataModule(
    config.dataset.path,
    domain_classes,
    config.domain_proportions,
    batch_size=config.training.batch_size,
    num_workers=config.training.num_workers,
    seed=config.seed,
    domain_args=config.domain_data_args,
)

Add a Weights & Biases (W&B) logger to monitor the training.

from lightning.pytorch.loggers.wandb import WandbLogger

# To log to TensorBoard instead (TensorBoardLogger is imported at the top):
# logger = TensorBoardLogger("logs", name="gw")


logger_wandb = WandbLogger(name=f"gw_with_color_alpha[]", project="shimmer-ssd")

logger = logger_wandb

logger_wandb.log_hyperparams(my_hparams)

# Get some image samples to log during training.
train_samples = data_module.get_samples("train", 32)
val_samples = data_module.get_samples("val", 32)

# split the unique group in validation into individual groups for logging
for domains in val_samples:
    for domain in domains:
        val_samples[frozenset([domain])] = {domain: val_samples[domains][domain]}
    break
# Create the gw folder where we will save checkpoints
(config.default_root_dir / "gw").mkdir(exist_ok=True)

callbacks: list[Callback] = [
    # Will log the validation ground-truth and reconstructions during training
    LogGWImagesCallback(
        val_samples,
        log_key="images/val",
        mode="val",
        every_n_epochs=config.logging.log_val_medias_every_n_epochs,
        filter=config.logging.filter_images,
    ),
    # Will log the training ground-truth and reconstructions during training
    LogGWImagesCallback(
        train_samples,
        log_key="images/train",
        mode="train",
        every_n_epochs=config.logging.log_train_medias_every_n_epochs,
        filter=config.logging.filter_images,
    ),
    # Save the checkpoints
    ModelCheckpoint(
        dirpath=config.default_root_dir / "gw" / f"version_{logger.version}",
        filename="{epoch}",
        monitor="val/loss",
        mode="min",
        save_last="link",
        save_top_k=1,
    ),
]
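With `monitor="val/loss"`, `mode="min"` and `save_top_k=1`, `ModelCheckpoint` keeps only the checkpoint with the lowest validation loss seen so far, while `save_last="link"` also maintains a `last.ckpt` symbolic link to the most recent checkpoint. The top-1 bookkeeping can be sketched as follows. This is a simplified illustration with made-up losses, not Lightning's implementation, and `track_best` is a hypothetical name:

```python
import math
from collections.abc import Iterable


def track_best(val_losses: Iterable[float]) -> list[str]:
    """Simulate which checkpoint file is kept after each validation (top-1, mode="min")."""
    kept: list[str] = []
    best_loss = math.inf
    best_file = ""
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # A strictly lower val/loss replaces the previously kept checkpoint.
            best_loss, best_file = loss, f"epoch={epoch}.ckpt"
        kept.append(best_file)
    return kept


print(track_best([0.9, 0.7, 0.8, 0.6]))
# ['epoch=0.ckpt', 'epoch=1.ckpt', 'epoch=1.ckpt', 'epoch=3.ckpt']
```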

Let's also print the folder where the final model checkpoints will be saved:

gw_checkpoint = config.default_root_dir / "gw" / f"version_{logger.version}"
print(gw_checkpoint)

And train!

trainer = Trainer(
    logger=logger,
    max_steps=config.training.max_steps,
    default_root_dir=config.default_root_dir,
    callbacks=callbacks,
    precision=config.training.precision,
    accelerator=config.training.accelerator,
    devices=config.training.devices,
)

trainer.fit(global_workspace, data_module)
trainer.validate(global_workspace, data_module, "best")

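# Which model type? "full_attr" uses the attribute domain with color;
# any other value uses the legacy attribute domain without color.
MODEL_TYPE = "no_color"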

# And now we load the GW checkpoint
checkpoint_path = Path("./checkpoints")
# We don't use cli in the notebook, but consider using it in normal scripts.
config = load_config("./config", use_cli=False)


if MODEL_TYPE == "full_attr":
    # Attribute domain that includes the RGB color of the shape
    domain_type = DomainModuleVariant.attr
    checkpoint = checkpoint_path / "gw-attr-v-all-paired-data.ckpt"
else:
    # Legacy attribute domain without color
    domain_type = DomainModuleVariant.attr_legacy
    checkpoint = checkpoint_path / "gw/version_None/epoch=660.ckpt"

config.global_workspace.encoders.n_layers = 3
config.global_workspace.decoders.n_layers = 3
# The attribute tensors themselves are built inside `play_with_gw` below,
# once the slider values (x, y, size, rot, colors) are known.



# Update the config
checkpoint_path = Path("./checkpoints")

config.domain_proportions = {
    frozenset(["v"]): 1.0,
    frozenset(["attr"]): 1.0,
    frozenset(["v", "attr"]): 1.0,
}

config.domains = [
    LoadedDomainConfig(
        domain_type=DomainModuleVariant.v_latents,
        checkpoint_path=checkpoint_path / "domain_v.ckpt",
    ),
    LoadedDomainConfig(
        domain_type=domain_type,
        checkpoint_path=checkpoint_path / "domain_attr.ckpt",
    ),
]

config.domain_data_args["v_latents"]["presaved_path"] = "domain_v.npy"
config.global_workspace.latent_dim = 12

domain_modules, gw_encoders, gw_decoders = load_pretrained_domains(
    config.domains,
    config.global_workspace.latent_dim,
    config.global_workspace.encoders.hidden_dim,
    config.global_workspace.encoders.n_layers,
    config.global_workspace.decoders.hidden_dim,
    config.global_workspace.decoders.n_layers,
)

global_workspace = GlobalWorkspace2Domains.load_from_checkpoint(
    checkpoint,
    domain_mods=domain_modules,
    gw_encoders=gw_encoders,
    gw_decoders=gw_decoders,
)

### Run this if you didn't train the model

# # Update the config
# checkpoint_path = Path("./checkpoints")

# config.domain_proportions = {
#     frozenset(["v"]): 1.0,
#     frozenset(["attr"]): 1.0,
#     frozenset(["v", "attr"]): 1.0,
# }

# config.domains = [
#     LoadedDomainConfig(
#         domain_type=DomainModuleVariant.v_latents,
#         checkpoint_path=checkpoint_path / "domain_v.ckpt",
#     ),
#     LoadedDomainConfig(
#         domain_type=DomainModuleVariant.attr_legacy,
#         checkpoint_path=checkpoint_path / "domain_attr.ckpt",
#     ),
# ]

# config.domain_data_args["v_latents"]["presaved_path"] = "domain_v.npy"
# config.global_workspace.latent_dim = 12
# # And now we load the GW checkpoint
# checkpoint_path = Path("./checkpoints")
# # checkpoint = checkpoint_path / "gw-attr-v-half-paired-data.ckpt"
# checkpoint  = "/home/alexis/Desktop/checkpoints/gw/version_None/epoch=120-v1.ckpt"
# # we load the pretrained domain modules and define the associated GW encoders and decoders
# domain_modules, gw_encoders, gw_decoders = load_pretrained_domains(
#     config.domains,
#     config.global_workspace.latent_dim,
#     config.global_workspace.encoders.hidden_dim,
#     config.global_workspace.encoders.n_layers,
#     config.global_workspace.decoders.hidden_dim,
#     config.global_workspace.decoders.n_layers,
# )

# global_workspace = GlobalWorkspace2Domains.load_from_checkpoint(
#     checkpoint,
#     domain_mods=domain_modules,
#     gw_encoders=gw_encoders,
#     gw_decoders=gw_decoders,
# )

## Play with the global workspace

%pip install ipywidgets ipympl

import io
import math

import ipywidgets as widgets
import numpy as np
from ipywidgets import interact, interact_manual
from PIL import Image
from shimmer_ssd.logging import attribute_image_grid
from torch.nn.functional import one_hot

from simple_shapes_dataset.cli import generate_image

%matplotlib widget

# !conda install -c conda-forge ipympl

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
global_workspace.to(device)

cat2idx = {"Diamond": 0, "Egg": 1, "Triangle": 2}


def get_image(cat, x, y, size, rot, color_r, color_g, color_b):
    fig, ax = plt.subplots(figsize=(32, 32), dpi=1)
    # The dataset generation tool provides a function that draws
    # a matplotlib shape from the attributes.
    generate_image(
        ax,
        cat2idx[cat],
        [int(x * 18 + 7), int(y * 18 + 7)],
        size * 7 + 7,
        rot * 2 * math.pi,
        np.array([color_r * 255, color_g * 255, color_b * 255]),
        imsize=32,
    )
    ax.set_facecolor("black")
    plt.tight_layout(pad=0)
    # Return this as a PIL Image.
    # This is to have the same dpi as saved images
    # otherwise matplotlib will render this in very high quality
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    image = Image.open(buf)
    plt.close(fig)
    return image


@interact(
    cat=["Triangle", "Egg", "Diamond"],
    x=(0, 1, 0.1),
    y=(0, 1, 0.1),
    rot=(0, 1, 0.1),
    size=(0, 1, 0.1),
    color_r=(0, 1, 0.1),
    color_g=(0, 1, 0.1),
    color_b=(0, 1, 0.1),
)
def play_with_gw(
    cat: str = "Triangle",
    x: float = 0.5,
    y: float = 0.5,
    rot: float = 0.5,
    size: float = 0.5,
    color_r: float = 1,
    color_g: float = 0,
    color_b: float = 0,
):
    fig, axes = plt.subplots(1, 2)
    image = get_image(cat, x, y, size, rot, color_r, color_g, color_b)
    axes[0].set_facecolor("black")
    axes[0].set_title("Original image from attributes")
    axes[0].set_xticks([])
    axes[0].set_yticks([])
    axes[0].imshow(image)

    # normalize the attribute for the global workspace.
    category = one_hot(torch.tensor([cat2idx[cat]]), 3)
    rotx = math.cos(rot * 2 * math.pi)
    roty = math.sin(rot * 2 * math.pi)
    if MODEL_TYPE == "full_attr":
        attributes = torch.tensor(
            [[x * 2 - 1, y * 2 - 1, size * 2 - 1, rotx, roty, color_r * 2 - 1, color_g * 2 - 1, color_b * 2 - 1]]
        )
    else:
        attributes = torch.tensor(
            [[x * 2 - 1, y * 2 - 1, size * 2 - 1, rotx, roty]]  # , color_r * 2 - 1, color_g * 2 - 1, color_b * 2 - 1]]
        )
    samples = [category.to(device), attributes.to(device)]
    attr_gw_latent = global_workspace.gw_mod.encode({"attr": global_workspace.encode_domain(samples, "attr")})
    gw_latent = global_workspace.gw_mod.fuse(
        attr_gw_latent, {"attr": torch.ones(attr_gw_latent["attr"].size(0)).to(device)}
    )
    decoded_latents = global_workspace.gw_mod.decode(gw_latent)["v_latents"]
    decoded_images = (
        global_workspace.domain_mods["v_latents"]
        .decode_images(decoded_latents)[0]
        .permute(1, 2, 0)
        .detach()
        .cpu()
        .numpy()
    )
    axes[1].imshow(decoded_images)
    axes[1].set_xticks([])
    axes[1].set_yticks([])
    axes[1].set_title("Translated image through GW")
    plt.show()
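The normalization inside `play_with_gw` can be isolated into a small function and checked on its own. The sketch below mirrors its arithmetic with plain Python lists instead of tensors: the category becomes a one-hot vector over `cat2idx`, position and size are mapped from [0, 1] to [-1, 1], the rotation is encoded as its cosine and sine, and, for the `"full_attr"` model only, the RGB color is also mapped to [-1, 1]. The function name `normalize_attributes` is ours:

```python
import math
from typing import Optional

cat2idx = {"Diamond": 0, "Egg": 1, "Triangle": 2}


def normalize_attributes(
    cat: str,
    x: float,
    y: float,
    size: float,
    rot: float,
    color: Optional[tuple[float, float, float]] = None,
) -> tuple[list[float], list[float]]:
    """Mirror the preprocessing in play_with_gw, using lists instead of tensors."""
    # One-hot encoding of the shape category.
    category = [0.0] * len(cat2idx)
    category[cat2idx[cat]] = 1.0
    # Position and size go from [0, 1] to [-1, 1]; rotation becomes (cos, sin).
    angle = rot * 2 * math.pi
    attributes = [x * 2 - 1, y * 2 - 1, size * 2 - 1, math.cos(angle), math.sin(angle)]
    if color is not None:
        # Only the "full_attr" model expects the RGB color, also mapped to [-1, 1].
        attributes += [c * 2 - 1 for c in color]
    return category, attributes


category, attributes = normalize_attributes("Egg", x=0.5, y=1.0, size=0.0, rot=0.0)
print(category)    # [0.0, 1.0, 0.0]
print(attributes)  # [0.0, 1.0, -1.0, 1.0, 0.0]
```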





"""---------------------------------------------------------------------------
Error observed when running `trainer.fit(global_workspace, data_module)` (Cell In[18]):

ValueError: RGBA values should be within 0-1 range

Call chain (abridged):
  Trainer.fit -> fit loop -> validation loop -> on_validation_epoch_end callbacks
  -> LogGWImagesCallback.on_validation_epoch_end  (shimmer_ssd/logging.py:706)
  -> LogGWImagesCallback.on_callback              (shimmer_ssd/logging.py:640)
  -> LogGWImagesCallback.log_samples              (shimmer_ssd/logging.py:754, case "attr")
  -> LogGWImagesCallback.log_attribute_samples    (shimmer_ssd/logging.py:775)
  -> attribute_image_grid                         (shimmer_ssd/logging.py:258)
  -> get_attribute_figure_grid                    (shimmer_ssd/logging.py:216)
  -> generate_image                               (simple_shapes_dataset/cli/utils.py:134)
  -> get_egg_patch                                (simple_shapes_dataset/cli/utils.py:116)
  -> matplotlib PathPatch(path, facecolor=color) -> Patch.set_facecolor
  -> colors.to_rgba -> _to_rgba_no_colorcycle     (matplotlib/colors.py:414)

In the installed shimmer_ssd, LogGWImagesCallback.on_callback (logging.py, lines 634-638)
concatenates a fixed color torch.tensor([1, 0, 0], device=device).expand(32, 3)
into samples[1] before calling log_samples."""