### Import and process dataset

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import argparse
import yaml
import os
from torchvision.utils import make_grid
from tqdm import tqdm
import numpy as np



# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Release cached-but-unused CUDA memory held by this process's allocator.
torch.cuda.empty_cache()

In [None]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyper-parameters and output settings for the conditional diffusion training run.

    Every attribute carries a type annotation so that ``@dataclass`` actually
    registers it as a field. Without annotations the decorator sees no fields:
    the generated ``__init__`` takes no arguments and ``repr`` shows nothing.
    With them, values can be overridden at construction time,
    e.g. ``TrainingConfig(num_epochs=5)``.
    """

    image_size: int = 128  # the generated image resolution
    train_batch_size: int = 16
    eval_batch_size: int = 16  # how many images to sample during evaluation
    num_epochs: int = 50
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 10
    save_model_epochs: int = 30
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "training_output"  # the model name locally and on the HF Hub

    push_to_hub: bool = True  # whether to upload the saved model to the HF Hub
    hub_model_id: str = "QLeca/NextLayerModularCharacterModel"  # the name of the repository to create on the HF Hub
    hub_private_repo: "bool | None" = None  # not read by any visible cell
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0  # NOTE(review): not applied anywhere visible (no set_seed call) — confirm


config = TrainingConfig()

In [None]:
from datasets import load_dataset

config.dataset_name = "QLeca/modular_characters"
# Two disjoint slices of the single "train" split are loaded, validation first:
# rows 0%-1% become the validation set and rows 1%-5% the training set.
vals_ds, trains_ds = (
    load_dataset(config.dataset_name, split=split_spec)
    for split_spec in ('train[0%:1%]', 'train[1%:5%]')
)


In [None]:
from torchvision import transforms

# Deterministic preprocessing, applied separately to the input image and to its
# target image. The two images of a pair are concatenated channel-wise before
# entering the model, so they must stay spatially aligned. A RandomHorizontalFlip
# here would draw an independent coin on each call, mirroring one image of a
# pair but not the other, so it is intentionally omitted. If flip augmentation
# is wanted, decide it once per pair and apply it to both images.
preprocess = transforms.Compose(
    [
        transforms.Resize((config.image_size, config.image_size)),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize([0.5], [0.5]),  # (x - 0.5) / 0.5: [0, 1] -> [-1, 1]
    ]
)

In [None]:
def transform(rows):
    """Convert a batch of dataset rows into model-ready tensors.

    ``rows`` is a batch dict (column name -> list of values). Its "input" and
    "target" entries are images accepted by ``preprocess``; "prompt" is passed
    through unchanged. Returns a dict with the same three keys, where "input"
    and "target" hold lists of preprocessed tensors.
    """
    images_input = [preprocess(image) for image in rows["input"]]
    images_target = [preprocess(image) for image in rows["target"]]

    return {"input": images_input,
            'target': images_target,
            'prompt': rows['prompt']}


# Apply the transform lazily at access time instead of through ``map``.
# ``map`` stores the returned tensors in Arrow, and reading them back in the
# default "python" format yields nested lists. The DataLoader batches would then
# not be tensors, and ``batch["input"].to(device)`` in the training loop would
# fail. ``with_transform`` hands the tensors through unchanged and re-runs
# ``preprocess`` on every access.
trains_ds = trains_ds.with_transform(transform)
vals_ds = vals_ds.with_transform(transform)
    

In [None]:
import torch
train_dataloader = torch.utils.data.DataLoader(
    trains_ds,
    batch_size=config.train_batch_size,
    shuffle=True,  # reshuffle the training pairs every epoch
)
val_dataloader = torch.utils.data.DataLoader(
    vals_ds,
    batch_size=config.eval_batch_size,
    shuffle=True,  # the training loop only draws the first batch, so each evaluation sees a random batch
)

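### Create U-Net Model

> **Note:** this section has no code cell in this export, but the training cells below use `model` and `noise_scheduler`. Both must be defined here before training is run.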

### Training

In [None]:
from diffusers.optimization import get_cosine_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

# train_loop consumes at most 100 batches per epoch (its per-epoch step cap), so
# the cosine schedule is sized by the optimizer steps that actually run. Sizing
# it by the full len(train_dataloader) leaves the learning rate far from the end
# of its decay whenever the dataloader holds more than 100 batches.
max_steps_per_epoch = 100  # keep in sync with train_loop's per-epoch cap
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=min(len(train_dataloader), max_steps_per_epoch) * config.num_epochs,
)

In [None]:
@torch.no_grad()
def sample(input_images, model, scheduler, train_config, epoch, num_steps=50):
    """Run the reverse (denoising) loop conditioned on ``input_images`` and save every step.

    Starts from Gaussian noise with 4 channels at ``train_config.image_size``.
    At each step the model receives ``cat([input_images, xt], dim=1)`` and the
    scheduler produces the next sample. After each step the current ``xt`` is
    clamped to [-1, 1], mapped to [0, 1], tiled with ``make_grid`` and written to
    ``<output_dir>/samples_epoch_<epoch>/x0_<t>.png``.

    Args:
        input_images: conditioning batch, presumably (B, C, H, W) and normalised
            to [-1, 1] like the training data — TODO confirm. Noise and
            timesteps are created on its device.
        model: noise-prediction network, called as ``model(x, t)`` and indexed with ``[0]``.
        scheduler: scheduler exposing ``step(noise_pred, t, xt, return_dict=False)``
            that returns two values.
        train_config: provides ``image_size`` and ``output_dir``.
        epoch: used only to name the output directory.
        num_steps: number of reverse steps. Timesteps ``num_steps - 1`` down to 0
            are passed directly to the model and scheduler. Defaults to 50, the
            previously hard-coded value.
    """
    sample_device = input_images.device
    out_dir = os.path.join(train_config.output_dir, 'samples_epoch_{}'.format(epoch))
    # Create the directory once, before the loop. makedirs also creates a
    # missing output_dir, where mkdir would raise.
    os.makedirs(out_dir, exist_ok=True)

    xt = torch.randn((input_images.shape[0],
                      4,
                      train_config.image_size,
                      train_config.image_size)).to(sample_device)
    # NOTE(review): the loop index is used directly as the diffusion timestep, so
    # only timesteps num_steps-1 .. 0 are visited. Training draws timesteps from
    # [0, num_train_timesteps). If num_train_timesteps > num_steps, sampling starts
    # from pure noise at a far lower noise level than it was trained for —
    # confirm the scheduler's num_train_timesteps.
    for i in tqdm(reversed(range(num_steps)), total=num_steps):
        t = torch.as_tensor(i).to(sample_device)

        # Predict the noise from the input stacked channel-wise with the current sample
        noisy_samples = torch.concat([input_images, xt], dim=1)
        noise_pred = model(noisy_samples, t.unsqueeze(0))[0]

        # Use the scheduler to get x_{t-1}. The second returned value (x0_pred) is unused.
        xt, x0_pred = scheduler.step(noise_pred, t, xt, return_dict=False)

        # Save the current x_{t-1}, not x0_pred, despite the "x0_" file name.
        # It is mapped from [-1, 1] to [0, 1] and tiled into one grid image.
        ims = torch.clamp(xt, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2
        grid = make_grid(ims)
        img = torchvision.transforms.ToPILImage()(grid)
        img.save(os.path.join(out_dir, 'x0_{}.png'.format(i)))
        img.close()

### Training loop

In [None]:
from accelerate import Accelerator
from huggingface_hub import create_repo, upload_folder
from tqdm.auto import tqdm
from pathlib import Path
import os

def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, val_dataloader, lr_scheduler, device,
               max_steps_per_epoch=100):
    """Train the input-conditioned noise-prediction model with Hugging Face Accelerate.

    Each training example pairs an "input" image with a "target" image. The
    target is noised at a random timestep and concatenated channel-wise after
    the input. The model is trained with an MSE loss to predict the added noise.
    Every ``config.save_image_epochs`` epochs, and after the last epoch, the main
    process runs ``sample`` on the first batch drawn from ``val_dataloader``.

    Args:
        config: ``TrainingConfig``-like object. This function reads mixed_precision,
            gradient_accumulation_steps, output_dir, push_to_hub, hub_model_id,
            num_epochs and save_image_epochs.
        model: network called as ``model(x, timesteps, return_dict=False)[0]``.
        noise_scheduler: provides ``config.num_train_timesteps`` and ``add_noise``.
        optimizer: stepped once per batch.
        train_dataloader: yields dicts with tensor "input" and "target" entries.
        val_dataloader: yields dicts with a tensor "input" entry. It is not
            passed to ``accelerator.prepare``.
        lr_scheduler: stepped once per batch.
        device: device the batches are moved to.
        max_steps_per_epoch: at most this many batches are used per epoch
            (default 100, the previously hard-coded cap). ``None`` uses the
            whole dataloader.
    """
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs")
    )
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        if config.push_to_hub:
            # repo_id is only consumed by the (currently commented-out) upload step below.
            repo_id = create_repo(
                repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
            ).repo_id
        accelerator.init_trackers("train_example")

    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    # Number of batches actually consumed per epoch (drives the cap and the progress bar).
    steps_per_epoch = len(train_dataloader)
    if max_steps_per_epoch is not None:
        steps_per_epoch = min(steps_per_epoch, max_steps_per_epoch)

    global_step = 0

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=steps_per_epoch, disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            if step >= steps_per_epoch:
                break
            input_images = batch["input"].to(device)
            # Noise the *target* image: the model learns to generate the target
            # conditioned on the input. Noising "input" here would only teach it
            # to reproduce its own conditioning image.
            target_images = batch["target"].to(device)

            # Sample noise to add to the images
            noise = torch.randn(target_images.shape, device=target_images.device)
            bs = target_images.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=target_images.device,
                dtype=torch.int64
            )

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_targets = noise_scheduler.add_noise(target_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual from the input stacked channel-wise with the noisy target
                noisy_samples = torch.concat([input_images, noisy_targets], dim=1)
                noise_pred = model(noisy_samples, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1
        progress_bar.close()

        # After selected epochs, sample demo images from one validation batch
        if accelerator.is_main_process:

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                val_batch = next(iter(val_dataloader), None)
                if val_batch is not None:
                    sampling_model = accelerator.unwrap_model(model)
                    sampling_model.eval()  # disable train-only behaviour (e.g. dropout) while generating
                    with torch.no_grad():
                        sample(input_images=val_batch['input'].to(device),
                               model=sampling_model,
                               scheduler=noise_scheduler,
                               train_config=config,
                               epoch=epoch)
                    sampling_model.train()  # back to training mode for the next epoch

            # if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
            #     if config.push_to_hub:
            #         upload_folder(
            #             repo_id=repo_id,
            #             folder_path=config.output_dir,
            #             commit_message=f"Epoch {epoch}",
            #             ignore_patterns=["step_*", "epoch_*"],
            #         )
            #     else:
            #         pipeline.save_pretrained(config.output_dir)

    # Stop the trackers started by init_trackers (flushes the tensorboard logs).
    accelerator.end_training()

In [None]:
from accelerate import notebook_launcher


# notebook_launcher forwards these positionally, so the order must match train_loop's parameters.
args = (
    config,
    model,
    noise_scheduler,
    optimizer,
    train_dataloader,
    val_dataloader,
    lr_scheduler,
    device,
)

# Launch train_loop with a single process.
notebook_launcher(train_loop, args, num_processes=1)

In [None]:
from models import DDPMNextTokenV1
from data_loaders import ModularCharatersDataLoader
dataset_name = "QLeca/modular_characters"
# NOTE(review): this is the same slice used as validation data elsewhere in the notebook — confirm intended.
train_split = "train[0%:1%]"
pipeline = DDPMNextTokenV1.DDPMNextTokenV1Pipeline()

# Build the training loader with image and batch sizes taken from the pipeline's training config.
train_dataloader = ModularCharatersDataLoader.get_modular_char_dataloader(
    dataset_name=dataset_name,
    split=train_split,
    image_size=pipeline.train_config.image_size,
    batch_size=pipeline.train_config.train_batch_size,
    shuffle=True,
)

In [None]:
from models import DDPMNextTokenV1
from data_loaders import ModularCharatersDataLoader
from PIL import Image
dataset_name = "QLeca/modular_characters"
# Disjoint slices of the "train" split: the first 1% validates, the next 4% trains.
val_split, train_split = "train[0%:1%]", "train[1%:5%]"


from safetensors import safe_open

# tensors = {}
# with safe_open("/Data/quentin.leca/training_output/diffusion_pytorch_model.safetensors", framework="pt", device=0) as f:
#     for k in f.keys():
#         tensors[k] = f.get_tensor(k)
# print(tensors)

# train_dataloader = ModularCharatersDataLoader.get_modular_char_dataloader(dataset_name=dataset_name,
#                                                                           split=train_split,
#                                                                           image_size=pipeline.train_config.image_size,
#                                                                           batch_size=pipeline.train_config.train_batch_size,
#                                                                           shuffle=True)
pipeline = DDPMNextTokenV1.DDPMNextTokenV1Pipeline()
# Presumably restores the configuration saved in this training output directory — confirm.
pipeline.load_config('/Data/quentin.leca/training_output/')

# Validation loader sized from the pipeline's training config.
val_dataloader = ModularCharatersDataLoader.get_modular_char_dataloader(
    dataset_name=dataset_name,
    split=val_split,
    image_size=pipeline.train_config.image_size,
    batch_size=pipeline.train_config.eval_batch_size,
    shuffle=True,
)

# pipeline.train(train_dataloader, val_dataloader)

# Run the pipeline on one validation batch only (a random one, since shuffle=True).
for batch in val_dataloader:
    images = pipeline(batch['input'], batch['prompt'])
    break

SyntaxError: invalid syntax (4183339886.py, line 33)

In [None]:
# Display the pipeline output assigned in the previous cell.
images

In [None]:
import torchvision
image = images[0]  # first element of the pipeline output
to_pil = torchvision.transforms.ToPILImage()
img = to_pil(image)

In [None]:
# Show the converted PIL image as this cell's output.
img