## LOADING DEPENDENCIES

In [29]:
### # Switch path to root of project
import os
import sys
os.environ["CUDA_VISIBLE_DEVICES"]="0"
current_folder = globals()['_dh'][0]
project_root = os.path.abspath(os.path.join(current_folder, '..', '..', '..'))
os.chdir(project_root)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

print(sys.executable)
from PIL import Image
from einops import rearrange

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

import math
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
import io
from IPython.display import display,Audio
from IPython.display import Image as IPyImage
from hydra.utils import instantiate
import torchaudio

from nanofm.data.utils import save_video
from nanofm.utils.checkpoint import load_safetensors, load_state_dict
from nanofm.data.multimodal.masking import SimpleMultimodalMasking
from nanofm.data.multimodal.adapted_multimodal_dataset import AdaptedMultimodalDataset

# Tokenizer imports
from nanofm.data.tokenizers.image_tokenizer import ImageTokenizer
from nanofm.data.tokenizers.audio_tokenizer import AudioTokenizer
from nanofm.data.tokenizers.video_tokenizer import VideoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_grad_enabled(False)
print(device)

/home/bousquie/.conda/envs/saga304bousq/bin/python
cuda


## DEFINITION OF THE CONSTANT FOR DATA LOCATION + COSMOS TOKENIZER CHOICE

In [30]:
# Paths to checkpoints and datasets
model_checkpoint_path = '/work/com-304/SAGA/outputs/label2audio/2_train/label2audio_2025-05-25_14-49/checkpoint-final.safetensors'
tokenized_data_path = '/work/com-304/SAGA/tokens_2025-05-25_14-49/'
IMAGE_MODEL_NAME = "Cosmos-0.1-Tokenizer-DI16x16"
VIDEO_MODEL_NAME = "Cosmos-0.1-Tokenizer-DV8x8x8"
modalities = ['tok_audio@24_000', 'tok_label']

## LOADING OF THE TOKENIZER

In [31]:
# Initialize tokenizers
image_tokenizer = ImageTokenizer(
            model_name=IMAGE_MODEL_NAME,
            device=torch.device(device)
        )
audio_tokenizer = AudioTokenizer(device=torch.device("cpu"))
video_tokenizer = VideoTokenizer(
            model_name=VIDEO_MODEL_NAME,
            device=torch.device(device)
        )

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

## INSTANTIATION OF THE DATASET USED FOR LOADING TOKENIZE DATA

In [32]:
from nanofm.data.tokenizers.dataset import MyImageDataset
img_dataset2 = AdaptedMultimodalDataset(root_dir = tokenized_data_path,
            split = "eval" ,
            modalities= modalities ,
            sample_from_k_augmentations = 1)

## HELPER FUNCTION TO RETURN MODALITIES IN 'PLOTING' FORMAT

In [33]:
def show_im_from_tensor(tensor_or_array):
    if isinstance(tensor_or_array, torch.Tensor):
        # Move to CPU if on another device
        tensor_or_array = tensor_or_array.detach().cpu()
        # Convert to numpy
        array = tensor_or_array.numpy()
    else:
        array = tensor_or_array
    
    # Ensure array is in the right range
    array = np.clip(array, 0, 1)
    
    # Transpose from (C, H, W) to (H, W, C)
    array = array.transpose(1, 2, 0)
    
    # Scale to [0, 255] and convert to uint8
    array = (array * 255).astype(np.uint8)
    
    # Create a PIL Image and save
    image = Image.fromarray(array)
    return image

def get_gif_bytes_from_tensor(frames: torch.Tensor, fps: int = 3) -> bytes:

    # 1. CPU → numpy and clamp
    frames_np = frames.cpu().numpy()
    frames_np = np.clip(frames_np, 0, 1)

    # 2. (C, T, H, W) → (T, H, W, C)
    frames_np = frames_np.transpose(1, 2, 3, 0)

    # 3. [0–1] → [0–255] uint8
    frames_np = (frames_np * 255).astype(np.uint8)

    # 4. Create Pil Images
    pil_frames = [Image.fromarray(f) for f in frames_np]

    # 5. Save into a buffer in memory
    buf = io.BytesIO()
    pil_frames[0].save(
        buf,
        format='GIF',
        save_all=True,
        append_images=pil_frames[1:],
        duration=int(1000 / fps),
        loop=0
    )
    buf.seek(0)
    return buf.getvalue()

## LOADING MODEL TRAINED FROM MEMORY

In [34]:
ckpt, config = load_safetensors(model_checkpoint_path)
model = instantiate(config)

load_state_dict(model, ckpt, ignore_missing='dec_context_proj')

In [35]:
model = model.to(device)
e = model.eval()

## HELPER FUNCTION TO : TOKENS -> DATA & SHOW DATA + CONSTRUCT NECESSARY ELEMENT FOR GENERATION

In [36]:
def construct_input_from_sample(dataset, idx, input_modality):
    input_tensor = dataset[idx][input_modality]
    n_tokens_input = input_tensor.shape[0]
    enc_input_tokens = input_tensor.unsqueeze(0).to(device)
    enc_input_positions = torch.arange(n_tokens_input, device=device).unsqueeze(0)
    enc_input_modalities = modalities.index(input_modality) * torch.ones(1, n_tokens_input, device=device, dtype=torch.long)
    return enc_input_tokens, enc_input_positions, enc_input_modalities

def token_ids_to_image(token_ids, image_tokenizer, to_pil=False):
    n_tokens = token_ids.numel()
    side = int(math.sqrt(n_tokens))
    token_ids = token_ids.reshape(1,side,side).to(device)
    reconst = image_tokenizer.decode(token_ids)
    reconst = (reconst[0].float().cpu())
    if to_pil:
        reconst = TF.to_pil_image(reconst)
    return reconst

def tokens_ids_to_audio(token_ids, audio_tokenizer):
    num_quantizers = 32
    token_ids = token_ids.reshape(1,num_quantizers,-1).to(device)
    reconst = audio_tokenizer.decode(token_ids.clamp(0,2047).cpu())
    reconst = reconst.squeeze(0)
    player = Audio(reconst, rate=24_000)
    display(player)
    return reconst

def tokens_ids_to_gif(token_ids,video_tokenizer):
    num_frames = 8
    model_bucket_len = 8
    nb_div = math.ceil((num_frames + 1) / model_bucket_len)
    
    n_tokens = token_ids.numel() // nb_div
    side = int(math.sqrt(n_tokens))
    token_ids = token_ids.reshape(1,nb_div,side,side).to(device)

    reconst = video_tokenizer.decode(token_ids)
    reconst = reconst[0].float().cpu()
    
    reconst_gif = get_gif_bytes_from_tensor(reconst)
    gif = IPyImage(reconst_gif, format = 'gif')
    display(gif)

def show_modality(tokens, modality: str):
    if modality.__contains__('rgb') or modality.__contains__('depth') :
        token_ids_to_image(tokens, image_tokenizer, to_pil=True).show()
    if modality.__contains__('audio') :
        tokens_ids_to_audio(tokens, audio_tokenizer)
    if modality.__contains__('video'):
        tokens_ids_to_gif(tokens,video_tokenizer)

# ***FUNCTION TO GENERATE MODALITIES FROM ANOTHER :***

### **Exemple of usage** : With modalities define as follow, you can index this array to select the modalitie that you want as input / as output
```
input_mod = ...
target_mod = ...
```
    with 

```
modalities = [
    'tok_audio@24_000',
    'tok_depth@256',
    'tok_rgb@256',
    'tok_video@256',
    'tok_label'
]
```

You can precise the nb of sample you want to generate with `` nb_iteration ``

In [None]:
input_mod = 'tok_audio@24_000'
target_mod = 'tok_label'

In [37]:
nb_iteration = 1
for i in range(nb_iteration):
    sample_idx = i

    x_tokens, x_positions, x_modalities = construct_input_from_sample(img_dataset2, idx=sample_idx, input_modality=input_mod)
    show_modality(x_tokens, input_mod)
    
    
    
    num_steps, temp, top_p, top_k = 10, 0.7, 0.0, 0.0
    
    pred_tokens, x_tokens, x_positions, x_modalities = model.generate_one_modality_roar(
        x_tokens, x_positions, x_modalities, target_mod=target_mod,
        num_steps=num_steps, temp=temp, top_p=top_p, top_k=top_k,
    )
    show_modality(pred_tokens, target_mod)

AssertionError: Number of steps should be less than or equal to the total number of tokens to unmask.