# How to train your Global Workspace
Benjamin Devillers

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ruflab/shimmer-tutorials/blob/main/simple-shapes-dataset-training.ipynb)


In this notebook, we will see how to use `shimmer` to build and train a Global Workspace from scratch on the Simple Shapes Dataset. We train a model that can translate visual images of shapes from the [simple-shapes-dataset](https://github.com/ruflab/simple-shapes-dataset) to their proto-language (attributes).

For this tutorial, we will need to install the [shimmer-ssd](https://github.com/ruflab/shimmer-ssd) package.

# !pip install --force-reinstall "git+https://github.com/AEmanuelli/shimmer-ssd.git"

# !pip install tensorboard

This package depends on [simple-shapes-dataset](https://github.com/ruflab/simple-shapes-dataset), so installing it also makes all of the `simple-shapes-dataset` commands available.

For instance, we can download the dataset directly with:

# !shapesd download

Note that `shapesd download` automatically migrates the dataset so that it is correctly formatted. If you downloaded the dataset manually, use `shapesd migrate -p PATH_TO_DATASET` to migrate manually.

from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any, cast

import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from lightning.pytorch import Callback, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from shimmer import DomainModule, LossOutput
from shimmer.modules.global_workspace import GlobalWorkspace2Domains, SchedulerArgs
from shimmer.modules.vae import (
    VAE,
    VAEDecoder,
    VAEEncoder,
    gaussian_nll,
    kl_divergence_loss,
)
from shimmer_ssd import DEBUG_MODE, LOGGER, PROJECT_DIR
from shimmer_ssd.config import DomainModuleVariant, LoadedDomainConfig, load_config
from shimmer_ssd.dataset.pre_process import TokenizeCaptions
from shimmer_ssd.logging import (
    LogAttributesCallback,
    LogGWImagesCallback,
    LogVisualCallback,
    batch_to_device,
)
from shimmer_ssd.modules.domains import load_pretrained_domains
from shimmer_ssd.modules.domains.visual import VisualLatentDomainModule
from shimmer_ssd.modules.vae import RAEDecoder, RAEEncoder
from tokenizers.implementations.byte_level_bpe import ByteLevelBPETokenizer
from torch import nn
from torch.nn.functional import mse_loss
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.optimizer import Optimizer
from torchvision.utils import make_grid

from simple_shapes_dataset import SimpleShapesDataModule, get_default_domains


%matplotlib inline

## Config

Let's first generate the config folder for the rest of the scripts.
This will create a `config` folder with different yaml files used by the different scripts and in the notebook.

# !ssd config create

The folder contains many files, but in this tutorial only `main.yaml` will interest us.

You can start by taking a look at the default values, which should mostly be set correctly for this tutorial. Feel free to change some of them to see how they affect the outcome.

<div class="alert alert-info">
Anytime you make a change to the config, don't forget to reload it with the following cell!
</div>

## Data format

The dataloader provides the data in a specific format:

```python
domain_group = {
    "domain": domain_data
}
batch = {
    frozenset(["domain"]): domain_group
}
```
* The **batch** is a dict that has frozensets of domains as keys, and a domain group as values.
* The **domain group** is a dict that has domains (string) as keys, and the domain data as values. The data samples of all domains in a domain group are matched. This
means that for a domain group that has 2 domains d1 and d2, `domain_group["d1"][k]` is paired with `domain_group["d2"][k]` for all `k`.

This allows a batch to have several groups (of different domains) of paired data. For example, a batch with unpaired visual (domain "v"), unpaired attribute (domain "attr"), and paired visual and attribute will look like:
```python
batch = {
    frozenset(["v"]): {"v": unpaired_visual_data},
    frozenset(["attr"]): {"attr": unpaired_attribute_data},
    frozenset(["attr", "v"]): {"attr": paired_attr_data, "v": paired_visual_data},
}
```

This is useful to train the global workspace later. But this is also the format used to train the unimodal domains.
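To make the format concrete, here is a minimal sketch that builds such a batch and walks through its domain groups. Plain Python lists stand in for the tensors a real dataloader would return, and `describe_batch` is a hypothetical helper written for this illustration, not part of `shimmer` or `simple-shapes-dataset`.

```python
# Placeholder data: plain lists stand in for the tensors of a real batch.
unpaired_visual_data = ["img_0", "img_1", "img_2"]
unpaired_attribute_data = ["attr_7", "attr_8"]
paired_visual_data = ["img_3", "img_4"]
paired_attr_data = ["attr_3", "attr_4"]

batch = {
    frozenset(["v"]): {"v": unpaired_visual_data},
    frozenset(["attr"]): {"attr": unpaired_attribute_data},
    frozenset(["attr", "v"]): {"attr": paired_attr_data, "v": paired_visual_data},
}


def describe_batch(batch):
    """Return {sorted domain names: number of samples} for every domain group."""
    summary = {}
    for domains, domain_group in batch.items():
        sizes = {len(data) for data in domain_group.values()}
        # Within a group, every domain must hold the same number of matched samples.
        assert len(sizes) == 1, f"unmatched sizes in group {sorted(domains)}"
        summary[tuple(sorted(domains))] = sizes.pop()
    return summary


summary = describe_batch(batch)

# frozenset keys are order-insensitive: this lookup finds the ("attr", "v") group.
paired_group = batch[frozenset(["v", "attr"])]
# The k-th visual sample is paired with the k-th attribute sample.
pairs = list(zip(paired_group["v"], paired_group["attr"]))
```

Using `frozenset` keys means the order in which domains are listed does not matter when looking up a group.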

Note that because all the data is paired in validation and test steps, the dataloader only returns one domain group with all paired domains:
```python
val_batch = {"attr": paired_attr_data, "v": paired_v_data}
```
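If you need a validation batch in the same grouped layout as a training batch (for instance to evaluate each domain on its own), one possible approach is sketched below. `as_grouped_batch` is a hypothetical helper introduced for this example, and lists again stand in for tensors.

```python
def as_grouped_batch(val_batch):
    """Wrap a flat validation batch into the frozenset-keyed training format.

    The result holds the fully paired group plus one single-domain group
    per domain, all pointing at the same underlying data.
    """
    # Iterating over a dict yields its keys, so this is the set of all domains.
    grouped = {frozenset(val_batch): dict(val_batch)}
    for domain, data in val_batch.items():
        grouped[frozenset([domain])] = {domain: data}
    return grouped


val_batch = {"attr": ["attr_0", "attr_1"], "v": ["img_0", "img_1"]}
grouped_val_batch = as_grouped_batch(val_batch)
```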

## Train a Global Workspace

Now that our two unimodal modules are trained, we can train the global workspace. For this training, we will use half of the 500,000 paired samples.
To this end, we need to create a split in the dataset. A dataset split depends on a seed and on the proportion of each group of domains.
We only need to generate this split once.

This can be done with the `shapesd alignment add` command. It needs the following arguments:
- `--dataset_path "DATASET_PATH"`: the location where the dataset is stored
- `--seed SEED`: the split seed
- `--domain_alignment DOMAIN_1,DOMAIN_2,...DOMAIN_N PROP`: the proportion for each domain group. This corresponds to what has been defined in `domain_proportions`

Running this command creates a file containing the indices of the items available in the train set for each domain group. Update the arguments so that they match what is set in the config file.

# !shapesd alignment add --dataset_path "simple_shapes_dataset" --seed 0 --domain_alignment attr 1.0 --domain_alignment v 1.0 --domain_alignment attr,v 1.0
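To give an intuition of what a seeded split with per-group proportions means, here is a toy sketch. It is not the algorithm used by `shapesd`; `sample_alignment` and its arguments are assumptions made for this illustration only.

```python
import random


def sample_alignment(n_items, proportions, seed):
    """Pick, for each domain group, a reproducible subset of train indices.

    `proportions` maps a tuple of domain names to the fraction of items kept.
    """
    rng = random.Random(seed)  # the seed makes the split reproducible
    split = {}
    for domains, proportion in proportions.items():
        n_selected = round(n_items * proportion)
        split[frozenset(domains)] = sorted(rng.sample(range(n_items), n_selected))
    return split


proportions = {("v",): 1.0, ("attr",): 1.0, ("attr", "v"): 0.5}
split = sample_alignment(10, proportions, seed=0)
```

Calling the function again with the same seed returns the same indices, which is why the split only has to be generated once.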

This time, we will load the config from the extra file `train_gw.yaml`.

First, let's update `main.yaml` to use the same alignment split:
```yaml
domain_proportions:
    -   domains: ["v"]  # unimodal visual passes use 100% of the available data
        proportion: 1.0
    -   domains: ["attr"]  # unimodal attr passes use 100% of the available data
        proportion: 1.0
    -   domains: ["v", "attr"]  # paired passes use 50% of the available data
        proportion: 0.5
```

Then, let's change the selected domains:

```yaml
domains:
    - checkpoint_path: "./checkpoints/visual/version_0/last.ckpt"  # update to the actual version
      domain_type: v_latents
    - checkpoint_path: "./checkpoints/attr/version_0/last.ckpt"  # update to the actual version
      domain_type: attr
```

And let's set the global workspace dimension to 12:
```yaml
global_workspace:
    latent_dim: 12  
    
    loss_coefficients:
        cycles: 1.0
        contrastives: 0.1
        demi_cycles: 1.0
        translations: 1.0

    encoders:
        hidden_dim: 32
        n_layers: 3

    decoders:
        hidden_dim: 32
        n_layers: 3
```
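The `loss_coefficients` weight the different loss terms against each other. As a rough sketch, assuming the terms are combined as a weighted sum (the exact reduction used inside `shimmer` may differ, for example by averaging), the effect of the coefficients can be illustrated as follows. The loss values are made-up numbers and `combine_losses` is a hypothetical helper.

```python
loss_coefficients = {
    "cycles": 1.0,
    "contrastives": 0.1,
    "demi_cycles": 1.0,
    "translations": 1.0,
}

# Made-up per-term loss values, for illustration only.
example_losses = {
    "cycles": 0.20,
    "contrastives": 1.50,
    "demi_cycles": 0.10,
    "translations": 0.30,
}


def combine_losses(losses, coefficients):
    """Weighted sum of the loss terms (assumed combination rule)."""
    return sum(coefficients[name] * value for name, value in losses.items())


total = combine_losses(example_losses, loss_coefficients)
```

With a coefficient of 0.1, the contrastive term contributes ten times less to the total than it would with a coefficient of 1.0.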

Finally, let's load the config:

config = load_config("./config", use_cli=False, load_files=["train_gw.yaml"])

Skip the following cells if you have trained the unimodal modules yourself. They set up pretrained modules.

### Run this if you didn't train the modules

# # Download checkpoints
# !ssd download checkpoints
# !mv checkpoints/checkpoints/* checkpoints/
# !rm -rf checkpoints/checkpoints

# # Extract visual latent from pretrained visual domain
# !ssd extract v "checkpoints/domain_v.ckpt" -p "simple_shapes_dataset"

my_hparams = {"temperature":0.5, "alpha": 2}


# Update the config
checkpoint_path = Path("./checkpoints")
config.domain_proportions = {
    frozenset(["v"]): 1.0,
    frozenset(["attr"]): 1.0,
    frozenset(["v", "attr"]): 1.0,
}

config.domains = [
    LoadedDomainConfig(
        domain_type=DomainModuleVariant.v_latents,
        checkpoint_path=checkpoint_path / "domain_v.ckpt",
    ),
    LoadedDomainConfig(
        domain_type=DomainModuleVariant.attr_legacy,
        checkpoint_path=checkpoint_path / "domain_attr.ckpt",
        args=my_hparams,
    ),
]

config.domain_data_args["v_latents"]["presaved_path"] = "domain_v.npy"
config.global_workspace.latent_dim = 12

import torch
from torch.utils.data import default_collate


def custom_collate(batch):
    """Collate a batch, then drop three values from the second "attr" tensor."""
    collated = default_collate(batch)
    if isinstance(collated, dict) and "attr" in collated:
        # If "attr" is a list holding at least two tensors,
        # only the second tensor is modified.
        if isinstance(collated["attr"], list) and len(collated["attr"]) >= 2:
            second_tensor = collated["attr"][1]
            if isinstance(second_tensor, torch.Tensor):
                # Remove the three values located just before the last one.
                if second_tensor.size(-1) >= 4:  # make sure there are enough elements
                    collated["attr"][1] = torch.cat(
                        [second_tensor[..., :-4], second_tensor[..., -1:]], dim=-1
                    )
    return collated


domain_classes = get_default_domains(["v_latents", "attr"])


data_module = SimpleShapesDataModule(
    config.dataset.path,
    domain_classes,
    config.domain_proportions,
    batch_size=config.training.batch_size,
    num_workers=config.training.num_workers,
    seed=config.seed,
    domain_args=config.domain_data_args,
    collate_fn=custom_collate,  # use the custom collate function
)
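The slicing done in `custom_collate` can be checked on a toy vector. The sketch below reproduces the same indexing on a plain Python list instead of a tensor, so that it runs without any dependency; `drop_three_before_last` is a hypothetical name used only here.

```python
def drop_three_before_last(values):
    """Remove the three entries located just before the last one."""
    if len(values) < 4:
        # Not enough elements: leave the vector untouched, as custom_collate does.
        return list(values)
    # Keep everything up to the fourth-to-last entry, then the last entry.
    return list(values[:-4]) + list(values[-1:])


example = [0, 1, 2, 3, 4, 5, 6, 7]
trimmed = drop_three_before_last(example)
```

Here the entries 4, 5 and 6 are dropped, and the last entry 7 is kept.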

### Load the domains and train
We can now load the pretrained unimodal modules

# we load the pretrained domain modules and define the associated GW encoders and decoders
domain_modules, gw_encoders, gw_decoders = load_pretrained_domains(
    config.domains,
    config.global_workspace.latent_dim,
    config.global_workspace.encoders.hidden_dim,
    config.global_workspace.encoders.n_layers,
    config.global_workspace.decoders.hidden_dim,
    config.global_workspace.decoders.n_layers,
)

Instantiate the Global Workspace class:

def get_scheduler(optimizer: Optimizer) -> OneCycleLR:
    return OneCycleLR(optimizer, config.training.optim.max_lr, config.training.max_steps)


global_workspace = GlobalWorkspace2Domains(
    domain_modules,
    gw_encoders,
    gw_decoders,
    config.global_workspace.latent_dim,
    config.global_workspace.loss_coefficients,
    config.training.optim.lr,
    config.training.optim.weight_decay,
    scheduler=get_scheduler,
)
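To get a feel for the learning rate produced by `get_scheduler`, here is a pure-Python approximation of a one-cycle schedule. It assumes the defaults documented for PyTorch's `OneCycleLR` (`pct_start=0.3`, cosine annealing, `div_factor=25`, `final_div_factor=1e4`) and is only a sketch, not the PyTorch implementation.

```python
import math


def one_cycle_lr(step, total_steps, max_lr,
                 pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """Approximate learning rate at `step` for a two-phase cosine one-cycle policy."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1  # last step of the warm-up phase
    last_step = total_steps - 1

    def cos_anneal(start, end, pct):
        # Cosine interpolation from `start` (pct=0) to `end` (pct=1).
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    return cos_anneal(max_lr, min_lr, (step - warmup_end) / (last_step - warmup_end))


# Toy values: 100 steps and a peak learning rate of 1e-3.
schedule = [one_cycle_lr(step, 100, 1e-3) for step in range(100)]
```

The learning rate starts at `max_lr / div_factor`, rises to `max_lr` during the first 30% of the steps, then decays to a very small value at the end of training.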

Add a Weights & Biases (wandb) logger to follow the training:

from lightning.pytorch.loggers.wandb import WandbLogger

# logger = TensorBoardLogger("logs", name="gw")
logger_wandb = WandbLogger(name="gw_no_color", project="shimmer-ssd")
logger = logger_wandb
logger_wandb.log_hyperparams(my_hparams)


# Get some samples to log during training.
train_samples = data_module.get_samples("train", 32)
val_samples = data_module.get_samples("val", 32)

# split the unique group in validation into individual groups for logging
for domains in val_samples:
    for domain in domains:
        val_samples[frozenset([domain])] = {domain: val_samples[domains][domain]}
    break
# Create the gw folder where we will save checkpoints
(config.default_root_dir / "gw").mkdir(exist_ok=True)

callbacks: list[Callback] = [
    # Will log the validation ground-truth and reconstructions during training
    LogGWImagesCallback(
        val_samples,
        log_key="images/val",
        mode="val",
        every_n_epochs=config.logging.log_val_medias_every_n_epochs,
        filter=config.logging.filter_images,
    ),
    # Will log the training ground-truth and reconstructions during training
    LogGWImagesCallback(
        train_samples,
        log_key="images/train",
        mode="train",
        every_n_epochs=config.logging.log_train_medias_every_n_epochs,
        filter=config.logging.filter_images,
    ),
    # Save the checkpoints
    ModelCheckpoint(
        dirpath=config.default_root_dir / "gw" / f"version_{logger.version}",
        filename="{epoch}",
        monitor="val/loss",
        mode="min",
        save_last="link",
        save_top_k=1,
    ),
]

Let's keep track of where the checkpoints of the final model are saved:

gw_checkpoint = config.default_root_dir / "gw" / f"version_{logger.version}"
print(gw_checkpoint)

And train!

trainer = Trainer(
    logger=logger,
    max_steps=config.training.max_steps,
    default_root_dir=config.default_root_dir,
    callbacks=callbacks,
    precision=config.training.precision,
    accelerator=config.training.accelerator,
    devices=config.training.devices,
)

trainer.fit(global_workspace, data_module)
trainer.validate(global_workspace, data_module, "best")