In [None]:
from layer_diffuse.data_loaders import ModularCharatersDataLoader
from layer_diffuse.models import DDIMNextTokenV1_Refactored
import json
# Build the refactored DDIM next-token pipeline using its constructor defaults.
pipeline = DDIMNextTokenV1_Refactored.DDIMNextTokenV1PipelineRefactored()

# Load the vocabulary from JSON. The encoding is pinned to UTF-8 so the file is
# decoded identically on every platform instead of falling back to the locale
# default that open() would otherwise use.
vocab_file = "layer_diffuse/vocab.json"
with open(vocab_file, 'r', encoding='utf-8') as f:
    vocab = json.load(f)

# Streaming dataloader over the 'train' split (image_size=128, batch_size=8),
# with RGBA conversion enabled and the vocabulary handed to the loader.
dataloader = ModularCharatersDataLoader.get_modular_char_dataloader(
    dataset_name='QLeca/modular_characters_v3',
    split='train',
    image_size=128,
    batch_size=8,
    shuffle=True,
    streaming=True,
    conversionRGBA=True,
    vocab=vocab,
)

In [None]:
# Call the pipeline's list_versions(); as the last expression of the cell, whatever
# it returns is rendered as the cell output.
# NOTE(review): presumably this enumerates the runs/checkpoints available to load — confirm.
pipeline.list_versions()

In [None]:
# Load the checkpoint identified by run name and epoch (from the hub, per the method name).
pipeline.load_model_from_hub(run='run_2025-06-23_19-12-13',
                             epoch=49)
# Hand the vocabulary size, read from the dataloader's `vocab` attribute, to the pipeline.
# NOTE(review): presumably this sizes the class-embedding table to the label space —
# confirm that calling it after loading does not reset trained embeddings.
pipeline.set_num_class_embeds(len(dataloader.vocab))

In [None]:
import torch
import torchvision
from torchvision.utils import make_grid

def show_image_grid(input_images, output_images, target_images):
    """Display three image batches stacked in a single grid.

    Each batch is rescaled with ``x * 0.5 + 0.5``, clamped to ``[0, 1]`` and moved
    to the CPU, then the batches are concatenated in the order input, output,
    target. The grid uses the input batch size as the number of images per row,
    so with equally sized batches each grid row shows one batch.

    Args:
        input_images: Tensor batch of conditioning images (assumed normalized to
            ``[-1, 1]`` — TODO confirm against the dataloader).
        output_images: Tensor batch produced by the pipeline.
        target_images: Tensor batch of ground-truth images.

    Returns:
        None. The grid is rendered through the notebook's ``display``.
    """
    def to_unit_range(batch):
        # Undo the mean=0.5 / std=0.5 normalization and clip any overshoot.
        return (batch * 0.5 + 0.5).clamp(0, 1).cpu()

    batches = [
        to_unit_range(images)
        for images in (input_images, output_images, target_images)
    ]
    per_row = batches[0].shape[0]  # one grid row per batch
    grid = make_grid(torch.cat(batches), nrow=per_row)
    display(torchvision.transforms.ToPILImage()(grid))


In [None]:
# Run inference on the first batch only and show it next to its inputs and targets.
for batch in dataloader:
    input_images, target_images, labels = (
        batch[key] for key in ('input', 'target', 'label')
    )
    outputs = pipeline(
        input_images=input_images,
        class_labels=labels,
        num_inference_steps=50,
    )
    show_image_grid(input_images, outputs, target_images)
    break  # a single batch is enough for a visual check

In [None]:
# Test the log_resume_info function
import wandb
import os
from layer_diffuse.models import DDIMNextTokenV1_Refactored

# # Initialize a test wandb run first
# wandb.init(
#     project="test_resume_info",
#     name="test_log_resume_info",
#     mode="offline"  # Use offline mode for testing to avoid creating actual wandb runs
# )

print("Testing log_resume_info function...")

# Test case 1: Test with a known run name that exists
# NOTE(review): neither value below is referenced again in the visible cells, and
# log_resume_info itself is never called here — this test looks unfinished.
test_run_name = "run_2025-06-24_12-04-24"
test_epoch = 41

# Create the pipeline instance
# This rebinds the global `pipeline`, replacing the instance created earlier in the notebook.
pipeline = DDIMNextTokenV1_Refactored.DDIMNextTokenV1PipelineRefactored()


In [None]:
from layer_diffuse.models.DDIMNextTokenV1_Refactored import DDIMNextTokenV1PipelineRefactored
# Rebind the global `pipeline` to a fresh instance built with constructor defaults;
# a previously created pipeline object is no longer reachable through this name.
pipeline = DDIMNextTokenV1PipelineRefactored()

In [None]:
import json
from layer_diffuse.data_loaders import ModularCharatersDataLoader
from datasets import load_dataset
from PIL import Image
from torchvision import transforms
import torch
from datasets.iterable_dataset import IterableDataset

# Settings for the manual dataset-preparation walkthrough in this cell.
vocab_file = "layer_diffuse/vocab.json"
dataset_name = 'QLeca/modular_characters_v3'
split = 'train'
streaming = True
conversionRGBA = True
image_size = 128

# Decode the vocabulary as UTF-8 explicitly rather than relying on the locale
# default encoding of open(), which differs between platforms.
with open(vocab_file, 'r', encoding='utf-8') as f:
    vocab = json.load(f)

# Validate the vocabulary before touching the dataset so a malformed file fails
# fast instead of after the (potentially slow) dataset load.
assert isinstance(vocab, dict), "Vocab should be a dictionary mapping prompts to indices."

dataset = load_dataset(dataset_name, split=split, cache_dir='cache/datasets', streaming=streaming)

# Per-image preprocessing: resize to a square of side image_size, convert to a
# tensor, then normalize with mean 0.5 and std 0.5.
preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

def transform(rows:dict)->dict:
    """Turn a batch of raw dataset rows into tensors and class labels.

    Reads the 'input', 'target' and 'prompt' columns of ``rows``. RGBA images are
    composited onto a white RGB background when the module-level
    ``conversionRGBA`` flag is truthy, then every image goes through the
    module-level ``preprocess`` pipeline. Each prompt is looked up in ``vocab``;
    prompts missing from it get the label -1.

    Note that ``rows['input']`` and ``rows['target']`` are overwritten in place
    with the (possibly flattened) images.

    Args:
        rows: Batch dict with list-valued 'input', 'target' and 'prompt' entries.

    Returns:
        Dict with keys 'input' and 'target' (lists of preprocessed image
        tensors) and 'label' (list of one-element long tensors).
    """
    def flatten_alpha(picture):
        # Anything that is not RGBA, or every image when conversion is off, passes through.
        if not (picture.mode == 'RGBA' and conversionRGBA):
            return picture
        canvas = Image.new('RGB', picture.size, (255, 255, 255))
        canvas.paste(picture, mask=picture.split()[3])  # band 3 is the alpha channel
        return canvas

    image_columns = ('input', 'target')

    # First replace the image columns on the incoming batch ...
    for column in image_columns:
        rows[column] = [flatten_alpha(picture) for picture in rows[column]]

    # ... then build the tensor batch from the replaced columns.
    batch = {
        column: [preprocess(picture) for picture in rows[column]]
        for column in image_columns
    }
    batch['label'] = [
        torch.tensor(vocab.get(prompt,-1),dtype=torch.long).unsqueeze(0)
        for prompt in rows['prompt']
    ]
    return batch

# Attach `transform` only when the dataset is a streaming IterableDataset: it is mapped
# in batched mode and the raw 'prompt' column is dropped from the result.
# NOTE(review): a non-iterable dataset is left untransformed by this cell — confirm that is intended.
if isinstance(dataset, IterableDataset):
    dataset = dataset.map(transform, batched=True, remove_columns=['prompt']) # type: ignore

In [None]:
# Build the training dataloader from the notebook-level settings. Image size and
# batch size come from the pipeline's training config; dataset name, split,
# streaming mode, RGBA conversion and vocabulary come from the configuration
# variables defined in the dataset-preparation cell, so changing one of those
# variables is reflected here as well.
train_dataloader = ModularCharatersDataLoader.get_modular_char_dataloader(
    dataset_name=dataset_name,
    split=split,
    image_size=pipeline.train_config.image_size,
    batch_size=pipeline.train_config.train_batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    persistent_workers=False,
    streaming=streaming,
    conversionRGBA=conversionRGBA,  # honor the config flag instead of a hard-coded True
    vocab=vocab,  # Pass the vocabulary if provided
)

In [None]:
# Extra keyword arguments that get unpacked into the training call.
extra_kwargs = dict(
    num_cycles=0.5,
    train_tags='DEBUG',
    gradient_accumulation_steps=4,
    mixed_precision="fp16",
    dataloader_num_workers=0,
)

In [None]:
# Start training through the pipeline's train_accelerate entry point, forwarding the
# extra keyword arguments collected in `extra_kwargs`.
# NOTE(review): the same dataloader is passed for training and validation, so validation
# sees training data — acceptable for a DEBUG run, but confirm before trusting validation metrics.
pipeline.train_accelerate(
            train_dataloader=train_dataloader,
            val_dataloader=train_dataloader,
            train_size=1000,
            val_size=100,
            **extra_kwargs,  # Pass all extra parameters through
        )