## [1] Pre-Process Data

In [None]:
import re
from datasets import load_dataset
from transformers import GPT2TokenizerFast
from itertools import chain
from dataclasses import dataclass

In [15]:
@dataclass
class TrainingConfig:
    """Hyper-parameters shared by the pre-processing, training and sampling code.

    Every attribute is type-annotated so that ``@dataclass`` registers it as a
    field. Without annotations the decorator sees no fields at all: the
    generated ``__init__`` takes no arguments and ``repr()`` shows nothing.
    With them, ``TrainingConfig(block_size=512)`` works, while
    ``TrainingConfig()`` and class-level access keep behaving as before.
    """
    # --- data ---
    block_size: int = 1024  # length of each token block produced by group_texts
    n_worker: int = 12  # used as num_proc for Dataset.map in this section
    batch_size: int = 12

    vocab_size: int = 50257

    # --- model ---
    hidden_size: int = 768
    cond_dim: int = 128
    n_blocks: int = 10
    n_heads: int = 12
    dropout: float = 0.1

    # --- optimisation / bookkeeping ---
    sample_batch_size: int = 1
    num_epochs: int = 20
    gradient_accumulation_steps: int = 8
    learning_rate: float = 3e-4
    lr_warmup_steps: int = 500
    save_model_epochs: int = 1
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "DFM_v0"

config = TrainingConfig()

In [3]:
# Load the raw WikiText-103 corpus from the Hugging Face Hub.
data = load_dataset("Salesforce/wikitext", name="wikitext-103-raw-v1")
# Pretrained GPT-2 tokenizer used to turn text into token ids.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# Id of the end-of-sequence token; preprocess_and_tokenize appends it to every tokenized row.
EOS = tokenizer.encode(tokenizer.eos_token)[0]

In [None]:
# This function re-formats the dataset to make it more human readable
def wt_detokenizer(string):
    """Undo WikiText's tokenization artefacts in ``string``.

    Applies, in order: contraction fixes, number-separator fixes (``@-@``,
    ``@,@``, ``@.@``), removal of the space before punctuation, trimming of
    whitespace just inside brackets/quotes, and a few miscellaneous
    replacements (heading markers, degree sign, stray spaces around newlines).

    Args:
        string: one raw text row of the dataset.

    Returns:
        The re-formatted text.
    """
    # contractions
    string = string.replace("s '", "s'")
    # Re-attach a digit to the apostrophe before it ("' 90s" -> "'90s").
    # The previous pattern kept the "/.../" delimiters of a non-Python regex
    # literal and used "[0-9]" as replacement text, so it only matched literal
    # slashes and would have inserted the characters "[0-9]" into the output.
    # NOTE: if byte-for-byte parity with data produced by the old pattern is
    # required, this line is the only behavioral difference.
    string = re.sub(r"' ([0-9])", r"'\1", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    # Longest heading marker first so "= = = =" is not consumed by "= =".
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    return string

# This function preprocesses the dataset and tokenizes the text
def preprocess_and_tokenize(example: dict):
    """Detokenize a batch of rows, tokenize them and append EOS to every row.

    Args:
        example: a batch whose "text" entry is a list of strings (the function
            is used with ``Dataset.map(..., batched=True)``).

    Returns:
        The tokenizer output (no attention mask), with the EOS id appended to
        each list in "input_ids".
    """
    # Build a new list rather than overwriting example["text"] element by
    # element, so the caller's batch is left untouched.
    text = [wt_detokenizer(t) for t in example["text"]]

    tokens = tokenizer(text, return_attention_mask=False)

    # Mark the end of each row so rows stay separable after concatenation.
    for token_ids in tokens["input_ids"]:
        token_ids.append(EOS)

    return tokens

# This function groups the tokenized texts into blocks of a specified size
def group_texts(examples: dict, block_size=None):
    """Concatenate the rows of each column and cut them into fixed-size blocks.

    Args:
        examples: mapping from column name to a list of token lists.
        block_size: number of tokens per output block. When omitted (the way
            ``Dataset.map`` calls this function) ``config.block_size`` is used.

    Returns:
        A dict with the same keys, each holding a list of blocks of exactly
        ``block_size`` tokens. The trailing tokens that do not fill a whole
        block are dropped. The usable length is derived from the first column.
    """
    if block_size is None:
        block_size = config.block_size
    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Round down to a whole number of blocks.
    total_length = (total_length // block_size) * block_size
    return {
        k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }

In [None]:
# Detokenize and tokenize every row in batches, using config.n_worker processes.
tokenized_dataset = data.map(
    preprocess_and_tokenize,
    batched=True,
    num_proc=config.n_worker,
    load_from_cache_file=True,
)

# Drop the raw text column; group_texts below concatenates every remaining column.
tokenized_dataset = tokenized_dataset.remove_columns("text")

# Concatenate the tokenized rows and cut them into blocks of config.block_size tokens.
chunked_dataset = tokenized_dataset.map(
    group_texts,
    batched=True,
    num_proc=config.n_worker,
    load_from_cache_file=True,
)
# Have the dataset hand out torch tensors.
chunked_dataset = chunked_dataset.with_format("torch")

# Save the processed dataset to disk so it can be loaded later instead of being recomputed.
chunked_dataset.save_to_disk("chunked_dataset")

## [2] Defining the Discrete Flow Logic and Model

To speed up training, we use the Hugging Face Accelerate library, which allows training on multiple GPUs with a simple wrapper. Unfortunately, for this to work the training code must live in a .py file, so the cell below generates that file automatically (once the `%%writefile` line at its top is uncommented). We assume that the dataset has been pre-processed according to the instructions in the previous section.

In [None]:
# %%writefile DFM_accelerator.py             #####Uncomment this line#####

import math
from typing import Optional, Tuple, Union
import os


import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader

from datasets import DatasetDict, load_from_disk
from dataclasses import dataclass

from einops import rearrange, repeat

from tqdm.auto import tqdm

from accelerate import Accelerator
from torch.optim import Optimizer, AdamW
from diffusers.optimization import get_cosine_schedule_with_warmup


#Following is standard code to set up transformer with adaLN Modulation

'''
In short the Transformer we define here takes a noisy sample x_t and time t
and outputs x_1 (denoised prediction).

Here we use fancy Deep Learning Architecture to achieve this but the idea
doesn't depend on the architecture.
'''
class Rotary(torch.nn.Module):
    """
    From: https://github.com/louaaron/Score-Entropy-Discrete-Diffusion
    """

    def __init__(self, dim: int, base: int = 10_000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x: Tensor, seq_dim: int = 1) -> Tuple[Tensor, Tensor]:
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            # dims are: batch, seq_len, qkv, head, dim
            self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)

            # This makes the transformation on v an identity.
            self.cos_cached[:, :, 2, :, :].fill_(1.0)
            self.sin_cached[:, :, 2, :, :].fill_(0.0)

        return self.cos_cached, self.sin_cached


def rotate_half(x: Tensor) -> Tensor:
    """Split the last axis at ``size // 2`` and return ``cat(-second, first)``."""
    split_at = x.shape[-1] // 2
    first = x[..., :split_at]
    second = x[..., split_at:]
    return torch.cat((-second, first), dim=-1)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    From: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L20

    Rotate the first ``ro_dim`` channels of ``x`` with the given cos/sin tables.

    Args:
        x: tensor whose last axis holds the channels to rotate.
        cos, sin: 5-D tables as returned by ``Rotary.forward``; only
            ``[0, :, 0, 0, :half]`` is read from each.
        interleaved: selects how cos/sin are repeated to full width. It is not
            forwarded to ``rotate_half``, which always splits the channels
            into two halves.

    Returns:
        A tensor shaped like ``x``. Channels beyond ``ro_dim`` are passed
        through unchanged, as in the referenced implementation; previously
        they were silently dropped when ``ro_dim < x.shape[-1]``.
    """
    cos = cos[0, :, 0, 0, : cos.shape[-1] // 2]
    sin = sin[0, :, 0, 0, : sin.shape[-1] // 2]

    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    pattern = "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)

    rotated = x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim]) * sin
    if ro_dim == x.shape[-1]:
        # Full-width rotation: nothing to pass through, avoid an extra copy.
        return rotated
    return torch.cat((rotated, x[..., ro_dim:]), dim=-1)


def bias_dropout_add_scale(
        x: Tensor, scale: Tensor, residual: Optional[Tensor],
        prob:float, trainning: bool
) -> Tensor:
    """Return ``residual + scale * dropout(x)``.

    Args:
        x: branch output to be dropped out and scaled.
        scale: multiplicative gate applied after dropout.
        residual: tensor added to the gated branch; ``None`` (allowed by the
            annotation, but previously a TypeError) skips the addition.
        prob: dropout probability.
        trainning: whether dropout is active (passed as ``training``).
    """
    gated = scale * F.dropout(x, p=prob, training=trainning)
    if residual is None:
        return gated
    return residual + gated

def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    """Scale ``x`` by ``1 + scale`` and then add ``shift``."""
    scaled = x * (1 + scale)
    return scaled + shift

class TimestepEmbedder(nn.Module):
    """Embed scalar times: sinusoidal features followed by a two-layer MLP."""

    def __init__(
            self,
            hidden_size: int,
            frequency_embedding_size: int = 256
    ) -> None:
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(
        time: Tensor,
        dim: int,
        max_period: int = 10_000
    ) -> Tensor:
        '''
        Create sinusoidal timestep embeddings.

        time: 1D Tensor of N values, one per batch element.
        dim:  output dimension
        max_period: sets how quickly the frequencies decay
                    (frequency i is exp(-log(max_period) * i / (dim // 2)))

        return (N, dim): cosines first, then sines, plus one zero column
        when dim is odd.
        '''
        half = dim // 2
        exponent = -math.log(max_period) * torch.arange(start=0, end=half) / half
        freqs = torch.exp(exponent).to(device=time.device)

        args = time[:, None].float() * freqs[None]
        pieces = [torch.cos(args), torch.sin(args)]
        if dim % 2:
            # Pad with a zero column so the output has exactly `dim` columns.
            pieces.append(torch.zeros_like(args[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, time: Tensor) -> Tensor:
        """Map a 1D tensor of times to embeddings of width ``hidden_size``."""
        features = self.timestep_embedding(time=time, dim=self.frequency_embedding_size)
        return self.mlp(features)
    
class DDiTBlock(nn.Module):
    """Transformer block: rotary self-attention plus an MLP, conditioned via adaLN.

    One linear layer maps the conditioning vector ``c`` to six tensors
    (shift/scale/gate for the attention branch and for the MLP branch). That
    layer is zero-initialised, so right after construction every shift, scale
    and gate is zero.
    """
    def __init__(
            self,
            dim: int,
            n_heads: int,
            cond_dim: int,
            mlp_ratio: int = 4,
            dropout: float = 0.4,
    ):
        super().__init__()
        # Each head gets dim // n_heads channels, so dim must divide evenly.
        assert dim % n_heads == 0

        self.n_heads = n_heads
        self.dim = dim
        # Dropout probability passed to bias_dropout_add_scale in forward().
        self.dropout = dropout

        self.head_dim = self.dim // self.n_heads

        self.norm1 = nn.LayerNorm([self.dim], bias=False)

        # Separate bias-free projections for queries, keys and values.
        self.qw = nn.Linear(dim, dim, bias=False)
        self.kw = nn.Linear(dim, dim, bias=False)
        self.vw = nn.Linear(dim, dim, bias=False)

        self.attn_out = nn.Linear(dim, dim, bias=False)
        # NOTE(review): dropout1 is never used in forward(); dropout is applied
        # through bias_dropout_add_scale with self.dropout instead.
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm([self.dim], bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio*dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio*dim, dim, bias=False)
        )

        # Produces the six modulation tensors; zero-initialised weight and bias.
        self.adaLN_modulation = nn.Linear(cond_dim, 6*dim, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def forward(self, x: Tensor, rotary_cos_sin: Tensor, c: Tensor) -> Tensor:
        """Apply the block.

        Args:
            x: activations; the first two axes are read as (batch, seq_len) and
                the last must hold ``dim`` channels.
            rotary_cos_sin: unpacked below as a ``(cos, sin)`` pair (despite the
                ``Tensor`` annotation).
            c: conditioning input of ``adaLN_modulation``; a new axis is inserted
                at position 1 before chunking so it broadcasts over the sequence.
                (assumes c is (batch, cond_dim) — TODO confirm against callers)

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]

        # Split the adaLN output into six equal chunks along the channel axis.
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        # ---- attention branch ----
        x_skip = x
        x = modulate(x=self.norm1(x), shift=shift_msa, scale=scale_msa)

        q = self.qw(x)
        k = self.kw(x)
        v = self.vw(x)

        # (batch, seq, dim) -> (batch, seq, heads, head_dim)
        q, k, v = (
            item.view(batch_size, seq_len, self.n_heads, self.head_dim)
            for item in (q, k, v)
        )

        cos, sin = rotary_cos_sin
        original_dtype = q.dtype

        # Rotary embedding is computed in float32 and cast back; only q and k are rotated.
        q = apply_rotary_emb_torch(
            x=q.float(), cos=cos.float(), sin=sin.float()
        ).to(original_dtype)

        k = apply_rotary_emb_torch(
            x=k.float(), cos=cos.float(), sin=sin.float()
        ).to(original_dtype)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q, k, v = (item.transpose(1, 2) for item in (q, k, v))

        # No attention mask is passed.
        x = F.scaled_dot_product_attention(query=q, key=k, value=v)
        x = rearrange(x, "b h s d -> b s (h d)", b=batch_size)
        x = bias_dropout_add_scale(
            x = self.attn_out(x),
            scale=gate_msa,
            residual=x_skip,
            prob=self.dropout,
            trainning=self.training
        )

        # ---- MLP branch ----
        # NOTE(review): the residual below is `skip_x`, the normalised and
        # modulated tensor, not the attention-branch output `x`. The usual
        # adaLN block adds the MLP output to the un-normalised `x`. Changing
        # it would alter the function computed by already-trained checkpoints,
        # so it is left as is — confirm whether this is intentional.
        skip_x = modulate(x=self.norm2(x), shift=shift_mlp, scale=scale_mlp)

        x = bias_dropout_add_scale(
            x = self.mlp(skip_x),
            scale=gate_mlp,
            residual=skip_x,
            prob=self.dropout,
            trainning=self.training,
        )

        return x
    

class DDitFinalLayer(nn.Module):
    """Output head: LayerNorm, adaLN shift/scale, then a linear projection.

    Both linear layers start with all-zero weights and biases.
    """

    def __init__(
            self,
            hidden_size: int,
            out_channels: int,
            cond_dim: int
    ) -> None:
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        for param in (self.linear.weight, self.linear.bias):
            param.data.zero_()

        self.adaLN_modulation = nn.Linear(cond_dim, 2*hidden_size, bias=True)
        for param in (self.adaLN_modulation.weight, self.adaLN_modulation.bias):
            param.data.zero_()

    def forward(self, x: Tensor, c: Tensor) -> Tensor:
        """Normalise ``x``, modulate it with the conditioning ``c`` and project it."""
        conditioning = self.adaLN_modulation(c)[:, None]
        shift, scale = conditioning.chunk(2, dim=2)
        modulated = modulate(x=self.norm_final(x), shift=shift, scale=scale)
        return self.linear(modulated)
    
class Transformer(nn.Module):
    """Time-conditioned Transformer over token ids.

    Takes token ids ``x_t`` and times ``time`` and returns one score per
    vocabulary entry for every position.
    """

    def __init__(
            self,
            vocab_size: int,
            masked: bool,
            config,
    ) -> None:
        """Build the model from ``config``.

        ``config`` must provide ``hidden_size``, ``cond_dim``, ``n_heads``,
        ``n_blocks`` and ``dropout``. ``masked`` is accepted but not used.
        """
        super().__init__()

        self.config = config
        self.vocab_size = vocab_size

        self.vocab_embed = nn.Embedding(self.vocab_size, config.hidden_size)

        self.time_embedding = TimestepEmbedder(hidden_size=config.cond_dim)
        self.rotary_emb = Rotary(dim=config.hidden_size // config.n_heads)

        block_kwargs = dict(
            dim=config.hidden_size,
            n_heads=config.n_heads,
            cond_dim=config.cond_dim,
            dropout=config.dropout,
        )
        self.blocks = nn.ModuleList(
            [DDiTBlock(**block_kwargs) for _ in range(config.n_blocks)]
        )

        self.output_layer = DDitFinalLayer(
            hidden_size=config.hidden_size,
            out_channels=vocab_size,
            cond_dim=config.cond_dim,
        )

    def forward(self, x_t: Tensor, time: Tensor) -> Tensor:
        """Embed the tokens, run every block and apply the output layer."""
        hidden = self.vocab_embed(x_t)
        cond = F.silu(self.time_embedding(time=time))

        # The same rotary tables are shared by all blocks.
        rotary_cos_sin = self.rotary_emb(x=hidden)

        for block in self.blocks:
            hidden = block(x=hidden, rotary_cos_sin=rotary_cos_sin, c=cond)

        return self.output_layer(x=hidden, c=cond)
    
#Here we define the MAIN LOGIC for Discrete Flow Matching
@dataclass
class SchedulerOutput:
    """Schedule values at time ``t`` and their derivatives with respect to ``t``."""
    # Schedule coefficients; PolynomialConvexScheduler sets alpha_t = t**n and sigma_t = 1 - t**n.
    alpha_t: Tensor
    sigma_t: Tensor

    # Time derivatives of the two coefficients above.
    d_alpha_t: Tensor
    d_sigma_t: Tensor

class Scheduler:
    """Interface for schedules: callables that map times ``t`` to a SchedulerOutput."""
    def __call__(self, t: Tensor) -> SchedulerOutput:
        # Stub: this base implementation returns None; subclasses override it.
        ...

class PolynomialConvexScheduler(Scheduler):
    """Polynomial schedule with ``alpha_t = t**n`` and ``sigma_t = 1 - t**n``."""

    def __init__(self, n: Union[float, int]) -> None:
        self.n = n

    def __call__(self, t: Tensor) -> SchedulerOutput:
        """Evaluate the schedule and its time derivatives at ``t``."""
        exponent = self.n
        alpha = t**exponent
        sigma = 1 - t**exponent
        d_alpha = exponent * (t**(exponent - 1))
        d_sigma = -exponent * (t**(exponent - 1))
        return SchedulerOutput(
            alpha_t=alpha,
            sigma_t=sigma,
            d_alpha_t=d_alpha,
            d_sigma_t=d_sigma,
        )
    
@dataclass
class DiscretePathSample:
    """Result of sampling a probability path: the inputs plus the sample ``x_t``."""
    x_1: Tensor  # target sample passed to ProbPath.sample
    x_0: Tensor  # source sample passed to ProbPath.sample
    t:   Tensor  # times passed to ProbPath.sample
    x_t: Tensor  # sample on the path at time t

class ProbPath:
    """Base class for probability paths between a source and a target sample."""

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
        """Placeholder that returns None; subclasses provide the actual sampling."""

    def assert_sample_shape(self, x_0: Tensor, x_1: Tensor, t: Tensor):
        """Check that ``t`` is 1-D and shares its length with both batches."""
        assert t.ndim == 1, "t must have shape [batch_size]"
        batch_size = t.shape[0]
        assert batch_size == x_0.shape[0] == x_1.shape[0], "Mismatch Batch Size"

class MixtureDiscreteProbPath(ProbPath):
    """Mixture path: each position of ``x_t`` comes from ``x_0`` or from ``x_1``."""

    def __init__(self, scheduler: Scheduler) -> None:
        super().__init__()
        self.scheduler = scheduler

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
        """Draw ``x_t``: a position takes the ``x_0`` value with probability ``sigma_t``.

        ``sigma_t`` is evaluated once per batch element and shared by all
        positions of that element.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

        # Unpacking requires x_1 to have exactly two axes.
        _batch, _length = x_1.shape

        per_sample_sigma = rearrange(self.scheduler(t).sigma_t, 'd -> d 1')
        sigma_per_position = per_sample_sigma.expand_as(x_1)

        uniform_draws = torch.rand(size=x_1.shape, device=x_1.device)
        take_source = uniform_draws < sigma_per_position
        x_t = torch.where(condition=take_source, input=x_0, other=x_1)

        return DiscretePathSample(x_t=x_t, x_1=x_1, x_0=x_0, t=t)
    
class UniformSourceDistribution:
    """Source distribution that draws token ids uniformly from ``[0, vocab_size)``."""

    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size

    def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor:
        """Return random ids of shape ``tensor_size`` on ``device``."""
        upper = self.vocab_size
        return torch.randint(size=tensor_size, high=upper, device=device)

    def sample_like(self, tensor_like: Tensor) -> Tensor:
        """Return random ids with the shape, dtype and device of ``tensor_like``."""
        upper = self.vocab_size
        return torch.randint_like(tensor_like, high=upper)
    

#####################Training Config######################
@dataclass
class TrainingConfig:
    """Hyper-parameters of the training script.

    Every attribute is type-annotated so that ``@dataclass`` registers it as a
    field. Without annotations the decorator sees no fields at all: the
    generated ``__init__`` takes no arguments and ``repr()`` shows nothing.
    With them, ``TrainingConfig(num_epochs=1)`` works, while
    ``TrainingConfig()`` and class-level access keep behaving as before.
    """
    # --- data ---
    block_size: int = 1024
    n_worker: int = 1  # used as DataLoader num_workers in main()
    batch_size: int = 12

    vocab_size: int = 50257

    # --- model ---
    hidden_size: int = 768
    cond_dim: int = 128
    n_blocks: int = 10
    n_heads: int = 12
    dropout: float = 0.1

    # --- optimisation / bookkeeping ---
    num_epochs: int = 20
    gradient_accumulation_steps: int = 8
    learning_rate: float = 3e-4
    lr_warmup_steps: int = 500
    save_model_epochs: int = 1
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "DFM_v0"

    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook

def train_loop(
        config: TrainingConfig,
        source: UniformSourceDistribution,
        velocity: MixtureDiscreteProbPath,
        train_dataloader: DataLoader,
        ):
    """Train the Transformer to predict ``x_1`` from a noisy sample ``x_t``.

    For every batch: draw ``x_0`` from ``source``, draw one time per sequence
    in ``[0, 1 - 1e-3)``, sample ``x_t`` from the path, and minimise the
    cross-entropy between the model's logits and ``x_1``. The model is saved
    to ``config.output_dir`` every ``config.save_model_epochs`` epochs and
    after the last epoch.

    Args:
        config: hyper-parameters (model size, optimiser, accumulation, output dir).
        source: source distribution; only ``sample_like`` is used.
        velocity: probability path; only ``sample`` is used.
        train_dataloader: iterable of batches holding an "input_ids" tensor.
    """
    model = Transformer(
        config=config,
        vocab_size=config.vocab_size,
        masked=False,
    )

    optim = AdamW(model.parameters(), lr=config.learning_rate)

    # Under `accelerator.accumulate`, the prepared scheduler only advances when
    # an optimizer update really happens (and once per process on each update).
    # The schedule length is therefore the number of optimizer updates over
    # the un-sharded dataloader. Using the raw batch count made the cosine
    # schedule `gradient_accumulation_steps` times too long, so the learning
    # rate barely decayed.
    updates_per_epoch = math.ceil(
        len(train_dataloader) / config.gradient_accumulation_steps
    )
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optim,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=updates_per_epoch * config.num_epochs,
    )

    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs")
    )

    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_examples")

    # `source` and `velocity` are plain Python helpers without parameters or
    # device state, so only the torch objects are handed to `prepare`.
    model, optim, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optim, train_dataloader, lr_scheduler
    )

    global_step = 0

    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for batch in train_dataloader:
            x_1 = batch["input_ids"].to(accelerator.device)

            # Sampling the noisy input needs no gradients.
            with torch.no_grad():
                x_0 = source.sample_like(x_1)
                # Keep t strictly below 1.
                t = torch.rand(x_1.shape[0], device=accelerator.device) * (1 - 1e-3)
                vel_t = velocity.sample(t=t, x_1=x_1, x_0=x_0)

            with accelerator.accumulate(model):
                logits = model(x_t=vel_t.x_t, time=vel_t.t)
                loss = F.cross_entropy(logits.flatten(0, 1), x_1.flatten(0, 1), reduction="mean")
                accelerator.backward(loss)

                # Clip only on steps where gradients are synchronised.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optim.step()
                lr_scheduler.step()
                optim.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # One bar is created per epoch; close it so bars do not pile up.
        progress_bar.close()

        if accelerator.is_main_process:
            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                accelerator.save_model(model, config.output_dir)

    # Finalise the trackers started by init_trackers.
    accelerator.end_training()

def main():
    """Build the config, source distribution, path and dataloader, then train."""
    config = TrainingConfig()

    source_distribution = UniformSourceDistribution(vocab_size=config.vocab_size)
    path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=2))

    # Dataset written to disk by the pre-processing section.
    train_split = load_from_disk("chunked_dataset")["train"]
    train_loader = DataLoader(
        train_split,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.n_worker,
    )

    train_loop(config, source_distribution, path, train_loader)


if __name__ == "__main__":
    main()

In [None]:
'''
Before we run DFM_accelerator.py, we need to configure the accelerate library.
This makes sure we are using our hardware optimally.
'''

!accelerate config default

!cat path/to/your/default_config.yaml #Check if the config is correct

In [None]:
#Let's run the training script

!accelerate launch DFM_accelerator.py

## [3] Inference & Benchmarking

Here we will load the model and perform conditional / unconditional generation.

In [1]:
import math
from typing import Optional, Tuple, Union, List
from collections import Counter

import torch
from safetensors.torch import load_file
from torch import nn, Tensor
import torch.nn.functional as F

from dataclasses import dataclass

from pathlib import Path
from einops import rearrange, repeat

from tqdm.auto import tqdm

from transformers import GPT2TokenizerFast

In [2]:
@dataclass
class TrainingConfig:
    """Hyper-parameters used to rebuild the trained model for inference.

    Every attribute is type-annotated so that ``@dataclass`` registers it as a
    field. Without annotations the decorator sees no fields at all: the
    generated ``__init__`` takes no arguments and ``repr()`` shows nothing.
    With them, ``TrainingConfig(sample_batch_size=4)`` works, while
    ``TrainingConfig()`` and class-level access keep behaving as before.
    """
    # --- data ---
    block_size: int = 1024
    n_worker: int = 1
    batch_size: int = 12

    vocab_size: int = 50257

    # --- model ---
    hidden_size: int = 768
    cond_dim: int = 128
    n_blocks: int = 10
    n_heads: int = 12
    dropout: float = 0.1

    # --- optimisation / bookkeeping ---
    # train_batch_size = 8
    # eval_batch_size = 16
    sample_batch_size: int = 1
    num_epochs: int = 20
    gradient_accumulation_steps: int = 8
    learning_rate: float = 3e-4
    lr_warmup_steps: int = 500
    save_model_epochs: int = 1
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "DFM_v0"

    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook


config = TrainingConfig()

In [3]:
# While training, accelerate was checkpointing the model
# Here we create a dummy model to load the checkpoint

class Rotary(torch.nn.Module):
    """
    From: https://github.com/louaaron/Score-Entropy-Discrete-Diffusion
    """

    def __init__(self, dim: int, base: int = 10_000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x: Tensor, seq_dim: int = 1) -> Tuple[Tensor, Tensor]:
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            # dims are: batch, seq_len, qkv, head, dim
            self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)

            # This makes the transformation on v an identity.
            self.cos_cached[:, :, 2, :, :].fill_(1.0)
            self.sin_cached[:, :, 2, :, :].fill_(0.0)

        return self.cos_cached, self.sin_cached


def rotate_half(x: Tensor) -> Tensor:
    """Split the last axis at ``size // 2`` and return ``cat(-second, first)``."""
    split_at = x.shape[-1] // 2
    first = x[..., :split_at]
    second = x[..., split_at:]
    return torch.cat((-second, first), dim=-1)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    From: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L20

    Rotate the first ``ro_dim`` channels of ``x`` with the given cos/sin tables.

    Args:
        x: tensor whose last axis holds the channels to rotate.
        cos, sin: 5-D tables as returned by ``Rotary.forward``; only
            ``[0, :, 0, 0, :half]`` is read from each.
        interleaved: selects how cos/sin are repeated to full width. It is not
            forwarded to ``rotate_half``, which always splits the channels
            into two halves.

    Returns:
        A tensor shaped like ``x``. Channels beyond ``ro_dim`` are passed
        through unchanged, as in the referenced implementation; previously
        they were silently dropped when ``ro_dim < x.shape[-1]``.
    """
    cos = cos[0, :, 0, 0, : cos.shape[-1] // 2]
    sin = sin[0, :, 0, 0, : sin.shape[-1] // 2]

    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    pattern = "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)

    rotated = x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim]) * sin
    if ro_dim == x.shape[-1]:
        # Full-width rotation: nothing to pass through, avoid an extra copy.
        return rotated
    return torch.cat((rotated, x[..., ro_dim:]), dim=-1)


def bias_dropout_add_scale(
        x: Tensor, scale: Tensor, residual: Optional[Tensor],
        prob:float, trainning: bool
) -> Tensor:
    """Return ``residual + scale * dropout(x)``.

    Args:
        x: branch output to be dropped out and scaled.
        scale: multiplicative gate applied after dropout.
        residual: tensor added to the gated branch; ``None`` (allowed by the
            annotation, but previously a TypeError) skips the addition.
        prob: dropout probability.
        trainning: whether dropout is active (passed as ``training``).
    """
    gated = scale * F.dropout(x, p=prob, training=trainning)
    if residual is None:
        return gated
    return residual + gated

def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    """Scale ``x`` by ``1 + scale`` and then add ``shift``."""
    scaled = x * (1 + scale)
    return scaled + shift

class TimestepEmbedder(nn.Module):
    """Embed scalar times: sinusoidal features followed by a two-layer MLP."""

    def __init__(
            self,
            hidden_size: int,
            frequency_embedding_size: int = 256
    ) -> None:
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(
        time: Tensor,
        dim: int,
        max_period: int = 10_000
    ) -> Tensor:
        '''
        Create sinusoidal timestep embeddings.

        time: 1D Tensor of N values, one per batch element.
        dim:  output dimension
        max_period: sets how quickly the frequencies decay
                    (frequency i is exp(-log(max_period) * i / (dim // 2)))

        return (N, dim): cosines first, then sines, plus one zero column
        when dim is odd.
        '''
        half = dim // 2
        exponent = -math.log(max_period) * torch.arange(start=0, end=half) / half
        freqs = torch.exp(exponent).to(device=time.device)

        args = time[:, None].float() * freqs[None]
        pieces = [torch.cos(args), torch.sin(args)]
        if dim % 2:
            # Pad with a zero column so the output has exactly `dim` columns.
            pieces.append(torch.zeros_like(args[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, time: Tensor) -> Tensor:
        """Map a 1D tensor of times to embeddings of width ``hidden_size``."""
        features = self.timestep_embedding(time=time, dim=self.frequency_embedding_size)
        return self.mlp(features)
    
class DDiTBlock(nn.Module):
    """Transformer block: rotary self-attention plus an MLP, conditioned via adaLN.

    One linear layer maps the conditioning vector ``c`` to six tensors
    (shift/scale/gate for the attention branch and for the MLP branch). That
    layer is zero-initialised, so right after construction every shift, scale
    and gate is zero.
    """
    def __init__(
            self,
            dim: int,
            n_heads: int,
            cond_dim: int,
            mlp_ratio: int = 4,
            dropout: float = 0.4,
    ):
        super().__init__()
        # Each head gets dim // n_heads channels, so dim must divide evenly.
        assert dim % n_heads == 0

        self.n_heads = n_heads
        self.dim = dim
        # Dropout probability passed to bias_dropout_add_scale in forward().
        self.dropout = dropout

        self.head_dim = self.dim // self.n_heads

        self.norm1 = nn.LayerNorm([self.dim], bias=False)

        # Separate bias-free projections for queries, keys and values.
        self.qw = nn.Linear(dim, dim, bias=False)
        self.kw = nn.Linear(dim, dim, bias=False)
        self.vw = nn.Linear(dim, dim, bias=False)

        self.attn_out = nn.Linear(dim, dim, bias=False)
        # NOTE(review): dropout1 is never used in forward(); dropout is applied
        # through bias_dropout_add_scale with self.dropout instead.
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm([self.dim], bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio*dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio*dim, dim, bias=False)
        )

        # Produces the six modulation tensors; zero-initialised weight and bias.
        self.adaLN_modulation = nn.Linear(cond_dim, 6*dim, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def forward(self, x: Tensor, rotary_cos_sin: Tensor, c: Tensor) -> Tensor:
        """Apply the block.

        Args:
            x: activations; the first two axes are read as (batch, seq_len) and
                the last must hold ``dim`` channels.
            rotary_cos_sin: unpacked below as a ``(cos, sin)`` pair (despite the
                ``Tensor`` annotation).
            c: conditioning input of ``adaLN_modulation``; a new axis is inserted
                at position 1 before chunking so it broadcasts over the sequence.
                (assumes c is (batch, cond_dim) — TODO confirm against callers)

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]

        # Split the adaLN output into six equal chunks along the channel axis.
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        # ---- attention branch ----
        x_skip = x
        x = modulate(x=self.norm1(x), shift=shift_msa, scale=scale_msa)

        q = self.qw(x)
        k = self.kw(x)
        v = self.vw(x)

        # (batch, seq, dim) -> (batch, seq, heads, head_dim)
        q, k, v = (
            item.view(batch_size, seq_len, self.n_heads, self.head_dim)
            for item in (q, k, v)
        )

        cos, sin = rotary_cos_sin
        original_dtype = q.dtype

        # Rotary embedding is computed in float32 and cast back; only q and k are rotated.
        q = apply_rotary_emb_torch(
            x=q.float(), cos=cos.float(), sin=sin.float()
        ).to(original_dtype)

        k = apply_rotary_emb_torch(
            x=k.float(), cos=cos.float(), sin=sin.float()
        ).to(original_dtype)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q, k, v = (item.transpose(1, 2) for item in (q, k, v))

        # No attention mask is passed.
        x = F.scaled_dot_product_attention(query=q, key=k, value=v)
        x = rearrange(x, "b h s d -> b s (h d)", b=batch_size)
        x = bias_dropout_add_scale(
            x = self.attn_out(x),
            scale=gate_msa,
            residual=x_skip,
            prob=self.dropout,
            trainning=self.training
        )

        # ---- MLP branch ----
        # NOTE(review): the residual below is `skip_x`, the normalised and
        # modulated tensor, not the attention-branch output `x`. The usual
        # adaLN block adds the MLP output to the un-normalised `x`. Changing
        # it would alter the function computed by already-trained checkpoints,
        # so it is left as is — confirm whether this is intentional.
        skip_x = modulate(x=self.norm2(x), shift=shift_mlp, scale=scale_mlp)

        x = bias_dropout_add_scale(
            x = self.mlp(skip_x),
            scale=gate_mlp,
            residual=skip_x,
            prob=self.dropout,
            trainning=self.training,
        )

        return x
    

class DDitFinalLayer(nn.Module):
    """Output head: LayerNorm, adaLN shift/scale, then a linear projection.

    Both linear layers start with all-zero weights and biases.
    """

    def __init__(
            self,
            hidden_size: int,
            out_channels: int,
            cond_dim: int
    ) -> None:
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        for param in (self.linear.weight, self.linear.bias):
            param.data.zero_()

        self.adaLN_modulation = nn.Linear(cond_dim, 2*hidden_size, bias=True)
        for param in (self.adaLN_modulation.weight, self.adaLN_modulation.bias):
            param.data.zero_()

    def forward(self, x: Tensor, c: Tensor) -> Tensor:
        """Normalise ``x``, modulate it with the conditioning ``c`` and project it."""
        conditioning = self.adaLN_modulation(c)[:, None]
        shift, scale = conditioning.chunk(2, dim=2)
        modulated = modulate(x=self.norm_final(x), shift=shift, scale=scale)
        return self.linear(modulated)
    
class Transformer(nn.Module):
    """Time-conditioned Transformer over token ids.

    Takes token ids ``x_t`` and times ``time`` and returns one score per
    vocabulary entry for every position.
    """

    def __init__(
            self,
            vocab_size: int,
            masked: bool,
            config,
    ) -> None:
        """Build the model from ``config``.

        ``config`` must provide ``hidden_size``, ``cond_dim``, ``n_heads``,
        ``n_blocks`` and ``dropout``. ``masked`` is accepted but not used.
        """
        super().__init__()

        self.config = config
        self.vocab_size = vocab_size

        self.vocab_embed = nn.Embedding(self.vocab_size, config.hidden_size)

        self.time_embedding = TimestepEmbedder(hidden_size=config.cond_dim)
        self.rotary_emb = Rotary(dim=config.hidden_size // config.n_heads)

        block_kwargs = dict(
            dim=config.hidden_size,
            n_heads=config.n_heads,
            cond_dim=config.cond_dim,
            dropout=config.dropout,
        )
        self.blocks = nn.ModuleList(
            [DDiTBlock(**block_kwargs) for _ in range(config.n_blocks)]
        )

        self.output_layer = DDitFinalLayer(
            hidden_size=config.hidden_size,
            out_channels=vocab_size,
            cond_dim=config.cond_dim,
        )

    def forward(self, x_t: Tensor, time: Tensor) -> Tensor:
        """Embed the tokens, run every block and apply the output layer."""
        hidden = self.vocab_embed(x_t)
        cond = F.silu(self.time_embedding(time=time))

        # The same rotary tables are shared by all blocks.
        rotary_cos_sin = self.rotary_emb(x=hidden)

        for block in self.blocks:
            hidden = block(x=hidden, rotary_cos_sin=rotary_cos_sin, c=cond)

        return self.output_layer(x=hidden, c=cond)

In [None]:
# Rebuild the architecture from the config so the checkpoint's parameter names match.
model = Transformer(
    config=config,
    vocab_size=config.vocab_size,
    masked=False
    )

# Placeholder path: point it at the saved weights (presumably the file written by
# accelerator.save_model during training — confirm). Tensors are loaded onto the CPU.
state_dict = load_file("path/to/model.safetensors", device="cpu")
# NOTE(review): the model is still in training mode after loading, so dropout is
# active; presumably model.eval() is called before sampling — confirm.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [7]:
@dataclass
class SchedulerOutput:
    """Schedule values at time ``t`` and their derivatives with respect to ``t``."""
    # Schedule coefficients; PolynomialConvexScheduler sets alpha_t = t**n and sigma_t = 1 - t**n.
    alpha_t: Tensor
    sigma_t: Tensor

    # Time derivatives of the two coefficients above.
    d_alpha_t: Tensor
    d_sigma_t: Tensor

class Scheduler:
    """Interface for schedules: callables that map times ``t`` to a SchedulerOutput."""
    def __call__(self, t: Tensor) -> SchedulerOutput:
        # Stub: this base implementation returns None; subclasses override it.
        ...

class PolynomialConvexScheduler(Scheduler):
    """Polynomial schedule with ``alpha_t = t**n`` and ``sigma_t = 1 - t**n``."""

    def __init__(self, n: Union[float, int]) -> None:
        self.n = n

    def __call__(self, t: Tensor) -> SchedulerOutput:
        """Evaluate the schedule and its time derivatives at ``t``."""
        exponent = self.n
        alpha = t**exponent
        sigma = 1 - t**exponent
        d_alpha = exponent * (t**(exponent - 1))
        d_sigma = -exponent * (t**(exponent - 1))
        return SchedulerOutput(
            alpha_t=alpha,
            sigma_t=sigma,
            d_alpha_t=d_alpha,
            d_sigma_t=d_sigma,
        )
    
@dataclass
class DiscretePathSample:
    """Result of sampling a probability path: the inputs plus the sample ``x_t``."""
    x_1: Tensor  # target sample passed to ProbPath.sample
    x_0: Tensor  # source sample passed to ProbPath.sample
    t:   Tensor  # times passed to ProbPath.sample
    x_t: Tensor  # sample on the path at time t

class ProbPath:
    """Base class for probability paths between a source and a target sample."""

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
        """Placeholder that returns None; subclasses provide the actual sampling."""

    def assert_sample_shape(self, x_0: Tensor, x_1: Tensor, t: Tensor):
        """Check that ``t`` is 1-D and shares its length with both batches."""
        assert t.ndim == 1, "t must have shape [batch_size]"
        batch_size = t.shape[0]
        assert batch_size == x_0.shape[0] == x_1.shape[0], "Mismatch Batch Size"

class MixtureDiscreteProbPath(ProbPath):
    """Path whose samples take each position from either ``x_0`` or ``x_1``.

    A position is taken from ``x_0`` with probability ``sigma_t`` (from the
    scheduler) and from ``x_1`` otherwise.
    """

    def __init__(self, scheduler: Scheduler) -> None:
        """Store the scheduler that supplies ``sigma_t``."""
        super().__init__()
        self.scheduler = scheduler

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
        """Draw ``x_t`` for each batch element at its time ``t``.

        Args:
            x_0: Tokens chosen where the random draw falls below ``sigma_t``.
            x_1: Tokens chosen everywhere else; must be 2-D.
            t: 1-D tensor of times, one per batch element.

        Returns:
            ``DiscretePathSample`` holding the inputs and the mixed ``x_t``.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

        # The unpack doubles as a check that x_1 has exactly two dimensions.
        _batch, _length = x_1.shape

        # One probability per batch element, broadcast across the sequence.
        take_source_prob = rearrange(self.scheduler(t).sigma_t, 'd -> d 1').expand_as(x_1)

        use_source = torch.rand(size=x_1.shape, device=x_1.device) < take_source_prob
        mixed = torch.where(use_source, x_0, x_1)

        return DiscretePathSample(x_1=x_1, x_0=x_0, t=t, x_t=mixed)

class UniformSourceDistribution:
    """Source distribution drawing integer ids uniformly from ``[0, vocab_size)``."""

    def __init__(self, vocab_size: int) -> None:
        """Remember the exclusive upper bound for sampled ids."""
        self.vocab_size = vocab_size

    def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor:
        """Return a tensor of shape ``tensor_size`` on ``device`` with uniform random ids."""
        upper = self.vocab_size
        return torch.randint(0, upper, tensor_size, device=device)

    def sample_like(self, tensor_like: Tensor) -> Tensor:
        """Return uniform random ids with the shape, dtype and device of ``tensor_like``."""
        upper = self.vocab_size
        return torch.randint_like(tensor_like, high=upper)
    
class MixtureDiscreteEulerSolver:
    """Euler-style sampler that moves tokens from ``t = 0`` towards ``t = 1``.

    At each step the model's predicted distribution ``p_1t`` is turned into a
    rate ``u = (p_1t - one_hot(x_t)) * d_alpha_t / (1 - alpha_t)`` and the next
    tokens are drawn from ``Categorical(one_hot(x_t) + h * u)``.
    """

    def __init__(
            self,
            model: nn.Module,
            path: MixtureDiscreteProbPath,
            vocabulary_size: int,
            source_distribution_p: Optional[Tensor] = None,
    ):
        """Store the sampler's collaborators.

        Args:
            model: Callable as ``model(x_t, t, **extras)`` returning per-token
                probabilities over the vocabulary.
            path: Probability path; only its ``scheduler`` is used here.
            vocabulary_size: Number of classes for the one-hot encoding.
            source_distribution_p: Stored but not read by ``sample``.
        """
        self.model = model
        self.path = path
        self.vocabulary_size = vocabulary_size
        self.source_distribution_p = source_distribution_p

    @staticmethod
    def _pin_prefix(x_t: Tensor, cond_text: Optional[Tensor]) -> None:
        """Overwrite, in place, the leading positions of every row with ``cond_text``."""
        if cond_text is not None:
            x_t[:, :len(cond_text)] = cond_text

    @torch.no_grad()  # sampling never needs gradients; avoids keeping activations
    def sample(
            self,
            x_init: Tensor,
            step_size: float,
            cond_text: Optional[Tensor] = None,
            **model_extras,
    )-> Tensor:
        """Sample token sequences starting from ``x_init``.

        Args:
            x_init: Integer tensor of initial tokens, indexed as ``[batch, position]``.
                It is cloned, not modified.
            step_size: Time increment per step; the last step is shortened so
                ``t`` never exceeds 1.
            cond_text: Optional 1-D tensor of tokens that is re-written onto
                the start of every sequence before the first step and after
                each step, so the prefix stays fixed.
            **model_extras: Extra keyword arguments forwarded to ``self.model``
                on every call.

        Returns:
            Tensor of sampled tokens with the same shape as ``x_init``.
        """
        steps_counter = 0

        x_t = x_init.clone()
        t = torch.Tensor([0]).to(x_init.device)

        if cond_text is not None:
            cond_text = cond_text.to(x_init.device)
        self._pin_prefix(x_t, cond_text)

        with tqdm(total=1.0, desc=f"NFE: {steps_counter}") as ctx:
            # Stop slightly before t = 1, where 1 - alpha_t would reach zero.
            while t < 1.0 - 1e-3:
                p_1t = self.model(x_t, t.repeat(x_t.shape[0]), **model_extras)
                h = min(step_size, 1.0 - t.item())

                scheduler_output = self.path.scheduler(t=t)
                k_t = scheduler_output.alpha_t
                d_k_t = scheduler_output.d_alpha_t

                one_hot_x_t = F.one_hot(x_t, num_classes=self.vocabulary_size).float()

                # Rate pointing from the current token towards the model's prediction.
                u = (p_1t - one_hot_x_t) * (d_k_t / (1 - k_t))
                x_t = torch.distributions.Categorical(probs=one_hot_x_t + h * u).sample()

                # Sampling may have changed the prompt positions; restore them.
                self._pin_prefix(x_t, cond_text)

                t = t + h
                steps_counter += 1

                # The bar tracks t itself (total=1.0); the label counts model calls.
                ctx.n = t.item()
                ctx.refresh()
                ctx.set_description(f"NFE: {steps_counter}")

        return x_t
    
class WrappedModel(nn.Module):
    """Adapter that turns a logit-producing model into a probability model."""

    def __init__(self, model: nn.Module) -> None:
        """Wrap ``model``, which is called as ``model(x_t=..., time=..., **extras)``."""
        super().__init__()
        self.model = model

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        """Return softmax probabilities over the last dimension of the model output."""
        logits = self.model(x_t=x, time=t, **extras)
        # Note: logit's precision is important, so cast to float32 before softmax.
        return logits.float().softmax(dim=-1)

In [13]:
def generate_samples(
        wrapped_probability_denoiser: WrappedModel,
        vocab_size: int,
        tokenizer: GPT2TokenizerFast,
        rank: int,
        device: torch.device,
        path: ProbPath,
        source_dist: UniformSourceDistribution,
        cond_text: Tensor,
        sample_batch_size: int,
        sequence_length: int,
        sampling_steps: int,
        sample_dir: Optional[Path] = None,
):
    """Sample token sequences with the Euler solver and decode them to text.

    Args:
        wrapped_probability_denoiser: Model the solver calls at every step.
        vocab_size: Vocabulary size handed to the solver.
        tokenizer: Used to decode the sampled tokens via ``batch_decode``.
        rank: Not used by this function.
        device: Device on which the initial tokens are drawn.
        path: Probability path handed to the solver.
        source_dist: Provides the initial tokens through ``sample``.
        cond_text: Conditioning tokens forwarded to the solver.
        sample_batch_size: Number of sequences to generate.
        sequence_length: Number of tokens per sequence.
        sampling_steps: The solver's step size is ``1 / sampling_steps``.
        sample_dir: Not used by this function.

    Returns:
        The result of ``tokenizer.batch_decode`` on the sampled tokens.
    """
    euler_solver = MixtureDiscreteEulerSolver(
        model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size
    )

    # Starting point: a (batch, length) grid of tokens from the source distribution.
    initial_tokens = source_dist.sample(
        tensor_size=(sample_batch_size, sequence_length), device=device
    )

    final_tokens = euler_solver.sample(
        x_init=initial_tokens, step_size=1 / sampling_steps, cond_text=cond_text
    )

    # Decode on the CPU.
    return tokenizer.batch_decode(final_tokens.detach().cpu())

In [14]:
# Prefer the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Softmax wrapper around the loaded network, moved to the device in eval mode.
wrapped_model = WrappedModel(model).to(device).eval()

# Objects needed for sampling: source distribution, schedule and path.
vocab_size = config.vocab_size
source_distribution = UniformSourceDistribution(vocab_size=vocab_size)
scheduler = PolynomialConvexScheduler(n=2)
path = MixtureDiscreteProbPath(scheduler=scheduler)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

In [15]:
# Prompt text; its tokens are pinned at the start of every generated sequence.
cond_prompt = "Princess Peach uses a letter to invite Mario to come to her castle for a cake she has baked for him. When he arrives, Mario discovers that Bowser has invaded the castle and"

# Token ids are integer indices, so build an int64 tensor explicitly rather
# than relying on the float32 default of the legacy `torch.Tensor(...)` call.
cond_text = torch.tensor(tokenizer(cond_prompt)["input_ids"], dtype=torch.long)

In [16]:
# Sampling settings for this run.
config.sample_batch_size = 1
SEQUENCE_LENGTH = 150   # tokens per generated sequence, prompt positions included
SAMPLING_STEPS = 1024   # the solver's step size is 1 / SAMPLING_STEPS

# `rank` and `sample_dir` are accepted by generate_samples but not used by it.
sample_hist = generate_samples(
    wrapped_probability_denoiser=wrapped_model,
    vocab_size=config.vocab_size,
    tokenizer=tokenizer,
    rank=0,
    device=device,
    path=path,
    source_dist=source_distribution,
    cond_text=cond_text,
    sample_batch_size=config.sample_batch_size,
    sequence_length=SEQUENCE_LENGTH,
    sampling_steps=SAMPLING_STEPS,
    sample_dir=config.output_dir,
)

NFE: 0:   0%|          | 0/1.0 [00:00<?, ?it/s]

In [17]:
# Show the decoded output of generate_samples (one string per generated sequence).
sample_hist

['Princess Peach uses a letter to invite Mario to come to her castle for a cake she has baked for him. When he arrives, Mario discovers that Bowser has invaded the castle and become trapped in the land scare Cousins, leaving the castle to let him remove herself. However, his lord would act too swiftly to drain water and allows Mario to successfully destroy Maurice. Outside the castle, Mario regains power to retake the castle, which he causes sewers to increase his cost. Rod Tim Owing to Coventry perceived that Nigel a "ombensive condition" and fears that he could not make life in the castle but bring Nigel back to the dinner party.\n<|endoftext|> The castle begins and toss them onto the castle. Continuing south, the castle to exit']