In [None]:
from data_loaders import ModularCharatersDataLoader
from models import DDPMNextTokenV2
from datasets import load_dataset



# Configuration: a single dataset id/split guarantees the class vocabulary
# below is built from the same data the dataloader serves.
DATASET_NAME = 'QLeca/modular_characters_small'
DATASET_SPLIT = 'train'
IMAGE_SIZE = 128
BATCH_SIZE = 16

train_dataloader = ModularCharatersDataLoader.get_modular_char_dataloader(dataset_name=DATASET_NAME,
                                                                            split=DATASET_SPLIT,
                                                                            image_size=IMAGE_SIZE,
                                                                            batch_size=BATCH_SIZE,
                                                                            shuffle=True)
dataset = load_dataset(DATASET_NAME,
                       split=DATASET_SPLIT)

# Class vocabulary: every distinct prompt in the split, sorted so the
# prompt -> class-index mapping is deterministic across runs.
prompts = dataset['prompt']
vocab = {prompt: index for index, prompt in enumerate(sorted(set(prompts)))}

pipeline = DDPMNextTokenV2.DDPMNextTokenV2Pipeline()
# Hand the pipeline an index-ordered list snapshot rather than a live
# dict_keys view of the notebook-global `vocab`.
pipeline.set_class_vocabulary(list(vocab))
  

In [None]:
import torch

def get_class_labels(prompts: list[str],
                     class_vocab: "dict[str, int] | None" = None,
                     unknown_label: int = -1) -> torch.Tensor:
    """Map text prompts to integer class indices.

    Args:
        prompts: Prompt strings, e.g. the ``'prompt'`` field of a batch.
        class_vocab: Mapping from prompt to class index. When omitted, the
            module-level ``vocab`` dict is used, so existing calls keep
            working. Pass it explicitly to avoid relying on notebook-global
            state.
        unknown_label: Index emitted for prompts that are not in the
            vocabulary (default ``-1``).

    Returns:
        A 1-D ``torch.long`` tensor with one entry per prompt (shape ``(0,)``
        for an empty input).
    """
    # Resolve the global lazily, so it is only required when no mapping is given.
    mapping = vocab if class_vocab is None else class_vocab
    return torch.tensor([mapping.get(prompt, unknown_label) for prompt in prompts],
                        dtype=torch.long)


# Smoke test: push a single training batch through the pipeline.
SEED = 0
# Seeds torch's global RNG so the shuffled batch is reproducible (and,
# presumably, any torch-RNG sampling inside the pipeline) -- TODO confirm the
# dataloader does not use its own generator.
torch.manual_seed(SEED)

# next(iter(...)) raises StopIteration on an empty loader instead of silently
# leaving `outputs` undefined, as a `for ... break` loop would.
batch = next(iter(train_dataloader))
input_images = batch['input']
target_images = batch['target']  # not used below; kept for inspection in later cells
prompts = batch['prompt']
labels = get_class_labels(prompts).unsqueeze(1)  # (batch,) -> (batch, 1)

# The vocabulary was built from the same split, so every prompt should map to a
# real class index. Fail loudly rather than feed the -1 sentinel to the model.
assert (labels >= 0).all(), 'batch contains prompts missing from the class vocabulary'
print(labels.shape)

outputs = pipeline(input_images=input_images,
                   class_labels=labels)