<div style="margin: 0 auto 30px; height: 60px; border: 2px solid gray; border-radius: 6px;">
  <div style="float: left;"><img src="img/epfl.png" /></div>
  <div style="float: right; margin: 20px 30px 0; font-size: 10pt; font-weight: bold;"><a href="https://moodle.epfl.ch/course/view.php?id=18345">COM304 - Communication Project</a></div>
</div>
<div style="clear: both; font-size: 30pt; font-weight: bold; color: #483D8B;">
    Demo Notebook SAGA group, nano4M with Audio/Video
</div>

# Overview

Welcome to the COM304 Communication Project demo notebook from the SAGA group. In this session, we showcase the results obtained with our different extensions. Throughout this notebook, we will:

- Set up the necessary imports and define helper functions for visualising results.
- Visualise samples of our data before and after tokenization.
- Go through the results obtained by training on different combinations of modalities.
- Visualise and compare the results of fine-tuning an audio tokenizer on our data.

Let’s start by importing the necessary libraries below.

In [1]:
import sys
from pathlib import Path

project_root = Path.cwd()  
wav_pkg = project_root / "nanofm" / "data" / "tokenizers" / "WavTokenizer"
sys.path.insert(0, str(wav_pkg))

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
current_folder = globals()['_dh'][0]
# project_root = os.path.abspath(os.path.join(current_folder, '..', '..'))
# os.chdir(project_root)

from PIL import Image
from einops import rearrange

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

import io
import glob
import math
import torch
import torchaudio
import numpy as np
import torch.nn.functional as F
import torchvision.transforms.functional as TF


from IPython.display import display,Audio
from IPython.display import Image as IPyImage
from hydra.utils import instantiate

from nanofm.data.utils import save_video
from nanofm.utils.checkpoint import load_safetensors, load_state_dict
from nanofm.data.multimodal.masking import SimpleMultimodalMasking
from nanofm.data.multimodal.adapted_multimodal_dataset import AdaptedMultimodalDataset

# Tokenizer imports
from nanofm.data.tokenizers.image_tokenizer import ImageTokenizer
from nanofm.data.tokenizers.audio_tokenizer import AudioTokenizer
from nanofm.data.tokenizers.video_tokenizer import VideoTokenizer
from nanofm.data.tokenizers.label_map import Maplabel
from nanofm.data.tokenizers.WavTokenizer.wavetok import WavAudioTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_grad_enabled(False)
print(device)



cpu


## Loading Pretrained Models

Now that all dependencies are in place, we load our pretrained checkpoints and organize them by training run. To do this, we define a Python dictionary where each key is a run identifier and each value contains:

- `path`: the path to the checkpoint file (`.safetensors`)
- `tags`: the list of modalities the model was trained on

```python
# Example structure: update with your actual run names and checkpoint paths
checkpoints = {
    "RUN_NAME": {
        "path": "path/to/checkpoint.safetensors",
        "tags": ["modality1", "modality2", ..., "modalityN"],
    },
}
```


In [2]:
checkpoints = {
    "CKPT_FIRST_TRAIN": {
        "path": "/work/com-304/SAGA/outputs/nano4M/multiclevr_d6-6w512/checkpoint-final.safetensors",
        "tags": ["tok_rgb@256", "tok_depth@256", "tok_audio@24_000", "tok_video@256"],
    },
    "CKPT_FIRST_SMALL_SUBSET": {
        "path": "/work/com-304/SAGA/outputs/nano4M/SAGAnano4M_smallest/checkpoint-final.safetensors",
        "tags": ["tok_rgb@256", "tok_depth@256"], #TO COMPLETE
    },
    "CKPT_RGB&CAPT": {
        "path": "/work/com-304/SAGA/outputs/rgb_capt/checkpoint-final.safetensors",
        "tags": ["tok_rgb@256", "tok_label"],
    },
    "CKPT_DEPTH&RGB": {
        "path": "/work/com-304/SAGA/outputs/nano4M/SAGAnano4M_depth/checkpoint-final.safetensors",
        "tags": ["tok_rgb@256", "tok_depth@256"],
    },
}

DATA_ROOT = '/work/com-304/SAGA/tokens_16_05/'
IMAGE_MODEL_NAME = "Cosmos-0.1-Tokenizer-DI16x16"
VIDEO_MODEL_NAME = "Cosmos-0.1-Tokenizer-DV8x8x8"
PATH_LABEL_DICT = "/home/godey/SAGA_COM-304/dataset_module/data/processed_data/label_counts.csv"
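
Because each entry in the `checkpoints` dictionary pairs a path with the modality tags it was trained on, a small lookup can tell us which runs are usable for a given input/target pair. The sketch below is only illustrative: `runs_supporting` is a hypothetical helper that is not part of the project code, and `example_checkpoints` is a trimmed copy of the registry with placeholder paths.

```python
from typing import Dict, List

# Trimmed, illustrative copy of the registry (placeholder paths, not the real ones).
example_checkpoints: Dict[str, dict] = {
    "CKPT_FIRST_TRAIN": {
        "path": "outputs/first_train.safetensors",
        "tags": ["tok_rgb@256", "tok_depth@256", "tok_audio@24_000", "tok_video@256"],
    },
    "CKPT_RGB&CAPT": {
        "path": "outputs/rgb_capt.safetensors",
        "tags": ["tok_rgb@256", "tok_label"],
    },
}


def runs_supporting(registry: Dict[str, dict], *modalities: str) -> List[str]:
    """Return the run names whose tags include every requested modality."""
    wanted = set(modalities)
    return [name for name, entry in registry.items() if wanted <= set(entry["tags"])]


# Which runs can map RGB tokens to labels?
print(runs_supporting(example_checkpoints, "tok_rgb@256", "tok_label"))
```

Checking the tags before loading a checkpoint avoids a later failure when a modality name is missing from the model's modality list.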

### Preparing Tokenizers, Model & Dataset

Once our checkpoints are organized, the next step is to load the associated tokenizers, build the tokenized evaluation dataset, and instantiate the model.

In [6]:
# Initialize tokenizers

image_tokenizer = ImageTokenizer(
            model_name=IMAGE_MODEL_NAME,
            device=torch.device(device)
        )
audio_tokenizer = AudioTokenizer(device=torch.device("cpu"))
video_tokenizer = VideoTokenizer(
            model_name=VIDEO_MODEL_NAME,
            device=torch.device(device)
        )
label_tokenizer = Maplabel(PATH_LABEL_DICT)

all_modalities = [
    'tok_audio@24_000',
    'tok_depth@256',
    'tok_label',
    'tok_rgb@256',
    'tok_video@256'
]

img_dataset2 = AdaptedMultimodalDataset(
    root_dir=DATA_ROOT,
    split="eval",
    modalities=all_modalities,
    sample_from_k_augmentations=1,
)


#Load model
selected_ckpt = checkpoints['CKPT_FIRST_TRAIN']['path']
selected_modalities = checkpoints['CKPT_FIRST_TRAIN']['tags']

print(selected_ckpt)
ckpt, config = load_safetensors(selected_ckpt)
model = instantiate(config)

load_state_dict(model, ckpt, ignore_missing='dec_context_proj')
# Move to the target device and switch to inference mode
model = model.to(device).eval()



/work/com-304/SAGA/outputs/nano4M/multiclevr_d6-6w512/checkpoint-final.safetensors


## Helper Functions

Now that our data pipelines and models are in place, it’s useful to bundle common routines into helper functions for visualization and data processing. Below are the core utilities we’ll use throughout the notebook:

- `show_im_from_tensor(tensor_or_array)`: Convert a (C,H,W) tensor or array into a PIL image.  
- `get_gif_bytes_from_tensor(frames: torch.Tensor, fps: int = 3) → bytes`: Turn a (C,T,H,W) tensor into GIF byte data.  
- `construct_input_from_sample(dataset, idx, input_modality)`: Build token, position & modality tensors for a dataset sample.  
- `token_ids_to_image(token_ids, image_tokenizer, to_pil=False)`: Decode image token IDs back into a tensor or PIL image.  
- `tokens_ids_to_audio(token_ids, audio_tokenizer)`: Decode audio token IDs, play the waveform, and return the raw signal.  
- `tokens_ids_to_gif(token_ids, video_tokenizer)`: Decode video token IDs and display the resulting frames as a GIF.  
- `tokens_ids_to_label(tokens_ids, label_tokenizer)`: Decode label token IDs into text and print it.  
- `show_modality(tokens, modality: str)`: Dispatch tokens to the appropriate display function based on modality.  
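
The image and video helpers both reshape a flat token sequence into a grid before decoding. The sketch below illustrates that shape arithmetic in isolation. `square_side` and `video_token_shape` are hypothetical names, the token counts are illustrative rather than measured, and the explicit perfect-square check is an addition for clarity (the notebook helpers simply truncate the square root).

```python
import math
from typing import Tuple


def square_side(n_tokens: int) -> int:
    """Side length of the square grid holding n_tokens, or raise if not square."""
    side = math.isqrt(n_tokens)
    if side * side != n_tokens:
        raise ValueError(f"{n_tokens} tokens do not form a square grid")
    return side


def video_token_shape(n_tokens: int, num_frames: int = 8,
                      bucket_len: int = 8) -> Tuple[int, int, int, int]:
    """Shape (batch, latent_frames, side, side) for a flat video token sequence."""
    # Same formula as in the video helper: ceil((num_frames + 1) / bucket_len)
    latent_frames = math.ceil((num_frames + 1) / bucket_len)
    if n_tokens % latent_frames != 0:
        raise ValueError("token count is not divisible by the number of latent frames")
    side = square_side(n_tokens // latent_frames)
    return (1, latent_frames, side, side)


print(square_side(256))        # 16, i.e. a 16 x 16 token grid
print(video_token_shape(512))  # (1, 2, 16, 16) for an illustrative 512-token clip
```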


In [7]:
def show_im_from_tensor(tensor_or_array):
    if isinstance(tensor_or_array, torch.Tensor):
        tensor_or_array = tensor_or_array.detach().cpu()
        array = tensor_or_array.numpy()
    else:
        array = tensor_or_array
    
    array = np.clip(array, 0, 1)
    array = array.transpose(1, 2, 0)
    
    array = (array * 255).astype(np.uint8)

    image = Image.fromarray(array)
    return image

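def get_gif_bytes_from_tensor(frames: torch.Tensor, fps: int = 3) -> bytes:
    """Encode a (C, T, H, W) tensor with values in [0, 1] as an animated GIF."""
    # Move to CPU, convert to numpy and clamp to the valid range
    frames_np = frames.cpu().numpy()
    frames_np = np.clip(frames_np, 0, 1)

    # (C, T, H, W) -> (T, H, W, C), then scale to 8-bit
    frames_np = frames_np.transpose(1, 2, 3, 0)
    frames_np = (frames_np * 255).astype(np.uint8)
    pil_frames = [Image.fromarray(f) for f in frames_np]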

    buf = io.BytesIO()
    pil_frames[0].save(
        buf,
        format='GIF',
        save_all=True,
        append_images=pil_frames[1:],
        duration=int(1000 / fps),
        loop=0
    )
    buf.seek(0)
    return buf.getvalue()

In [8]:
def construct_input_from_sample(dataset, idx, input_modality):
    input_tensor = dataset[idx][input_modality]
    n_tokens_input = input_tensor.shape[0]
    enc_input_tokens = input_tensor.unsqueeze(0).to(device)
    enc_input_positions = torch.arange(n_tokens_input, device=device).unsqueeze(0)
    enc_input_modalities = selected_modalities.index(input_modality) * torch.ones(1, n_tokens_input, device=device, dtype=torch.long)
    return enc_input_tokens, enc_input_positions, enc_input_modalities

def token_ids_to_image(token_ids, image_tokenizer, to_pil=False):
    n_tokens = token_ids.numel()
    side = int(math.sqrt(n_tokens))
    token_ids = token_ids.reshape(1,side,side).to(device)
    reconst = image_tokenizer.decode(token_ids)
    reconst = (reconst[0].float().cpu())
    if to_pil:
        reconst = TF.to_pil_image(reconst)
    return reconst

def tokens_ids_to_audio(token_ids, audio_tokenizer):
    num_quantizers = 32  # number of quantizer streams in the token layout
    # The audio tokenizer runs on CPU; clamp IDs to its vocabulary range [0, 2047]
    token_ids = token_ids.reshape(1, num_quantizers, -1).cpu().clamp(0, 2047)
    reconst = audio_tokenizer.decode(token_ids)
    reconst = reconst.squeeze(0)
    player = Audio(reconst, rate=24_000)
    display(player)
    return reconst

def tokens_ids_to_gif(token_ids,video_tokenizer):
    num_frames = 8
    model_bucket_len = 8
    nb_div = math.ceil((num_frames + 1) / model_bucket_len)
    
    n_tokens = token_ids.numel() // nb_div
    side = int(math.sqrt(n_tokens))
    token_ids = token_ids.reshape(1,nb_div,side,side).to(device)

    reconst = video_tokenizer.decode(token_ids)
    reconst = reconst[0].float().cpu()
    # save_video(reconst, os.path.join(out_path,"fish.gif"))
    
    reconst_gif = get_gif_bytes_from_tensor(reconst)
    gif = IPyImage(reconst_gif, format = 'gif')
    display(gif)

def tokens_ids_to_label(tokens_ids, label_tokenizer):
    print(label_tokenizer.decode(tokens_ids.cpu()))

def show_modality(tokens, modality: str):
    if 'rgb' in modality or 'depth' in modality:
        token_ids_to_image(tokens, image_tokenizer, to_pil=True).show()
    if 'audio' in modality:
        tokens_ids_to_audio(tokens, audio_tokenizer)
    if 'video' in modality:
        tokens_ids_to_gif(tokens, video_tokenizer)
    if 'label' in modality:
        tokens_ids_to_label(tokens, label_tokenizer)

## Generate & Display

Now that our helper utilities are in place, we can define the core generation function that transforms one modality into another. This function will:

1. **Take as input**:  
   - `input_mod` (str): source modality tag, as listed in the checkpoint's `tags` (e.g., `"tok_rgb@256"`, `"tok_audio@24_000"`).  
   - `target_mod` (str): modality tag to generate (e.g., `"tok_depth@256"`).  
   - `nb_iteration` (int): number of dataset samples to process.  
   - `num_steps` (int): number of generation steps per sample.  
   - `temp` (float): sampling temperature for stochastic decoding.  
   - `top_p` (float): nucleus (top-p) sampling threshold.  
   - `top_k` (int): top-k sampling cutoff.  

2. **Run the input tokens through** the Nano4M model in inference mode to generate the output token sequence.  
3. **Decode the output tokens** back into the target modality for visualization or playback.  
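
To make the roles of `temp`, `top_k` and `top_p` concrete, here is a minimal, framework-free sketch of how these three settings commonly reshape a next-token distribution. It is not the Nano4M implementation: `filter_probs` is a hypothetical helper, and treating `top_k = 0` and `top_p = 0.0` as "filter disabled" is an assumption chosen to match the default values used in this notebook.

```python
import math
import random
from typing import List


def filter_probs(logits: List[float], temp: float = 1.0,
                 top_k: int = 0, top_p: float = 0.0) -> List[float]:
    """Return sampling probabilities after temperature, top-k and top-p filtering.

    ASSUMPTION: top_k == 0 and top_p == 0.0 mean "filter disabled".
    """
    # Temperature-scaled softmax (subtract the max for numerical stability)
    scaled = [x / temp for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token indices sorted from most to least likely
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order)

    if top_k > 0:  # keep only the k most likely tokens
        keep &= set(order[:top_k])

    if top_p > 0.0:  # keep the smallest prefix whose mass reaches top_p
        nucleus, mass = set(), 0.0
        for i in order:
            nucleus.add(i)
            mass += probs[i]
            if mass >= top_p:
                break
        keep &= nucleus

    # Zero out rejected tokens and renormalise
    filtered = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    norm = sum(filtered)
    return [p / norm for p in filtered]


logits = [2.0, 1.0, 0.0, -1.0]
probs = filter_probs(logits, temp=0.7, top_k=2)
rng = random.Random(0)  # seeded for reproducibility
token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
print(probs, token)
```

Lower temperatures concentrate probability on the most likely token, while top-k and top-p remove the unlikely tail before sampling.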

In [9]:
def generate_and_display(input_mod: str, target_mod: str, nb_iteration: int = 1,
                         num_steps: int = 1, temp: float = 0.7,
                         top_p: float = 0.0, top_k: int = 0):
    for sample_idx in range(nb_iteration):
        # Build the encoder inputs from the dataset sample and show the source modality
        x_tokens, x_positions, x_modalities = construct_input_from_sample(
            img_dataset2, idx=sample_idx, input_modality=input_mod
        )
        show_modality(x_tokens, input_mod)

        # Generate the target modality with the sampling settings passed by the caller
        # (previously these arguments were overwritten by hard-coded values)
        pred_tokens, x_tokens, x_positions, x_modalities = model.generate_one_modality_roar(
            x_tokens, x_positions, x_modalities, target_mod=target_mod,
            num_steps=num_steps, temp=temp, top_p=top_p, top_k=top_k,
        )
        show_modality(pred_tokens, target_mod)

## Training a New Audio Tokenizer

After testing our model’s outputs, we observed that our **current audio tokenizer**:

- Has a vocabulary size of `2048`  
- Produces around `2016` tokens for just **5 seconds** of audio  

By comparison, our image tokenizer (`vocab_size = 64_000`, `max_sequence_length = 256`) represents a whole image with at most 256 tokens. An audio clip therefore takes roughly eight times as many tokens as an image, which creates a significant imbalance and makes the multimodal masking scheme less effective.

To remedy this, we looked for a tokenizer that generates far fewer tokens per second and selected [WavTokenizer](<https://github.com/jishengpeng/WavTokenizer>). Since WavTokenizer is primarily trained on speech, it doesn’t handle environmental sounds well, so we retrained it on our own dataset.
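
The imbalance can be quantified with a few lines of arithmetic. The figures for the current tokenizer come from the text above. The rate of 40 tokens per second for the retrained WavTokenizer is an assumption, read from the `frame40` part of the configuration name loaded later in this notebook; it is consistent with the 200 codes per clip printed in the output further down if those clips are 5 seconds long.

```python
clip_seconds = 5
old_tokens_per_clip = 2016   # current audio tokenizer, as reported above
image_tokens = 256           # max_sequence_length of the image tokenizer
assumed_frame_rate = 40      # ASSUMPTION: tokens per second, from "frame40" in the config name

old_rate = old_tokens_per_clip / clip_seconds            # tokens per second of audio
ratio_vs_image = old_tokens_per_clip / image_tokens      # audio clip vs. one image
new_tokens_per_clip = assumed_frame_rate * clip_seconds  # retrained tokenizer, per clip
reduction = old_tokens_per_clip / new_tokens_per_clip    # how much shorter the sequence gets

print(f"current tokenizer: {old_rate:.1f} tokens/s, {ratio_vs_image:.2f}x an image")
print(f"WavTokenizer (assumed): {new_tokens_per_clip} tokens per clip, {reduction:.1f}x fewer")
```

Under these assumptions, an audio clip shrinks from about eight times the length of an image sequence to slightly less than one image sequence.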

In the sections that follow, we will:

1. Showcase the tokenization issues with environmental sounds.  
2. Examine the outputs after training WavTokenizer for 1000 epochs on our data.  


In [10]:
folder = "/work/com-304/SAGA/raw/audios"
# Paths to the config and checkpoint (adapt as needed)
#config = "/home/godey/SAGA_COM-304/nanofm/data/tokenizers/WavTokenizer/configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
#ckpt   = "/work/com-304/SAGA/wavtok/wavtokenizer_small_320_24k_4096.ckpt"

config = '/home/godey/SAGA_COM-304/nanofm/data/tokenizers/WavTokenizer/configs/mine_wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml'
ckpt = '/work/com-304/SAGA/wavtok/train/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn/lightning_logs/version_22/checkpoints/last.ckpt'

# Load the tokenizer
tok = WavAudioTokenizer(config, ckpt)

# List the audio files once, then test on ten samples (indices 10 to 19)
wav_paths = glob.glob(os.path.join(folder, "**", "*.wav"), recursive=True)
for sample_path in wav_paths[10:20]:
    # Load an audio file
    wav, sr = torchaudio.load(sample_path)
    # Encode
    codes = tok.encode(wav, sr)
    print("Codes shape:", codes.shape)
    print('CODES MAX AND MIN', codes.max(), codes.min())

    # Decode
    wav_rec = tok.decode(codes)

    # Display: original at its native sample rate, reconstruction at 24 kHz
    player_in = Audio(wav, rate=sr)
    player_out = Audio(wav_rec, rate=24_000)

    display(player_in, player_out)



Codes shape: torch.Size([1, 1, 200])
CODES MAX AND MIN tensor(4067) tensor(8)


Codes shape: torch.Size([1, 1, 200])
CODES MAX AND MIN tensor(3828) tensor(2)

