In [1]:
import torch
from datasets import load_dataset

class DatasetFromJson(torch.utils.data.Dataset):
    def __init__(
        self,
        json_file: str, 
        tokenizer: str,
        vae_tokenizer: str,
        max_input_length_limit: int = 18000,
    ):
        
        self.tokenizer = tokenizer
        self.vae_tokenizer = vae_tokenizer
        self.max_input_length_limit = max_input_length_limit

        self.data = load_dataset('json', data_files=json_file)['train']

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        example = self.data[index]
        
        input_, output = example['input'], example['output']
        input_ = f"{input_}\n<bot>"
        output = f"{output}{self.tokenizer.eos_token}"

        tokenized_input = self.tokenizer(input_)
        tokenized_output = self.vae_tokenizer(output)

        return {
            "input": tokenized_input,
            "output": tokenized_output,
        }

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

class OmniGenCollator:
    def __init__(self, tokenizer, hidden_size=3072):
        self.tokenizer = tokenizer
        self.hidden_size = hidden_size

    def __call__(self, batch):
        """
        batch: list of {"input": BatchEncoding, "output": BatchEncoding}
        Returns dict with keys:
        - input_ids, attention_mask        (for the encoder/prompt)
        - labels                            (for the decoder/LM loss)
        """
        # 1. Separate inputs & outputs
        inputs  = [example["input"]  for example in batch]
        outputs = [example["output"] for example in batch]

        # 2. Pad both to the longest sequence in this batch
        batch_inputs  = self.tokenizer.pad(inputs, padding=True, return_tensors="pt")

        # 3. Create position ids.
        attn = batch_inputs["attention_mask"]
        position_ids = torch.cumsum(attn, dim=1) * attn

        B, N = position_ids.size()
        max_output_len = max(len(o["input_ids"]) for o in outputs)
        extra_cols    = max_output_len + 1

        lengths = attn.sum(dim=1)

        base = torch.arange(1, extra_cols+1, device=position_ids.device)
        pad  = base.unsqueeze(0).expand(B, -1) + lengths.unsqueeze(1)

        position_ids = torch.cat([position_ids, pad], dim=1)

        # 4. Create attention mask.
        temp_l = torch.sum(batch_inputs['attention_mask'], dim=-1)

        text_length = batch_inputs['attention_mask'].size(-1)
        seq_len = text_length + max_output_len + 1

        attn_masks = []
        for idx, i in enumerate(temp_l):
            temp_mask = torch.tril(torch.ones(size=(i+1, i+1)))

            image_mask = torch.zeros(size=(i+1, max_output_len))
            temp_mask = torch.cat([temp_mask, image_mask], dim=-1)

            image_mask = torch.ones(size=(max_output_len, i+max_output_len+1))
            temp_mask = torch.cat([temp_mask, image_mask], dim=0)

            pad_l = text_length - i
            if pad_l > 0:
                pad_mask = torch.zeros(size=(i+1+max_output_len, pad_l))
                temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)

                pad_mask = torch.ones(size=(pad_l, seq_len))
                temp_mask = torch.cat([pad_mask, temp_mask], dim=0)

            true_img_length = len(outputs[idx]['input_ids'])
            pad_img_length = max_output_len - true_img_length
            if pad_img_length > 0:
                temp_mask[:, -pad_img_length:] = 0

            attn_masks.append(temp_mask.unsqueeze(0))

        return {
            "input_ids": batch_inputs.input_ids,
            "attention_mask": torch.cat(attn_masks, dim=0),
            "position_ids": position_ids,
            "output": [torch.tensor(i.input_ids).unsqueeze(0) for i in outputs],
        }

In [3]:
import torch

def vae_encode_list_new(model, output):
    _ = model.eval()

    output_latents = []
    for i in output:
        with torch.no_grad():
            o = model(input_ids=i, output_hidden_states=True,)
        o = o.hidden_states[-1]
        output_latents.append(o)

    return output_latents

In [None]:
from torch import nn
import math

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.mlp[0].weight, std=0.02)
        nn.init.normal_(self.mlp[2].weight, std=0.02)

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype=torch.float32):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        print(t_freq.shape)
        t_emb = self.mlp(t_freq)
        return t_emb
    
def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class TextFinalLayer(nn.Module):
    """
    The “final layer” for a text‐only OmniGen: normalize, apply AdaLN
    from the timestep embedding, then project to vocab logits.
    """
    def __init__(self, hidden_size: int, output_size: int):
        super().__init__()
        # 1) Final norm (no affine because we modulate instead)
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # 2) Modulation MLP: turn timestep embedding -> [shift | scale]
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )
        # 3) LM head
        self.lm_head = nn.Linear(hidden_size, output_size, bias=False)

    def forward(self, x: torch.FloatTensor, t_emb: torch.FloatTensor):
        """
        x     : (B, L, H)  — hidden states from your transformer
        t_emb : (B, H)     — timestep (or other) conditioning embedding
        returns logits: (B, L, output_size)
        """
        # compute shift & scale
        shift, scale = self.adaLN_modulation(t_emb).chunk(2, dim=1)  # each (B, H)

        # normalize + apply FiLM-style modulation
        x = modulate(self.norm_final(x), shift, scale)              # (B, L, H)

        # project to vocabulary
        logits = self.lm_head(x)                                    # (B, L, V)
        return logits

In [5]:

from typing import List, Optional, Tuple, Union

import torch
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
)
from transformers import Phi3Config, Phi3Model
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Phi3Transformer(Phi3Model):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
    We only modified the attention mask
    Args:
        config: Phi3Config
    """
    def prefetch_layer(self, layer_idx: int, device: torch.device):
        "Starts prefetching the next layer cache"
        with torch.cuda.stream(self.prefetch_stream):
            # Prefetch next layer tensors to GPU
            for name, param in self.layers[layer_idx].named_parameters():
                param.data = param.data.to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        "Moves the previous layer cache to the CPU"
        prev_layer_idx = layer_idx - 1
        for name, param in self.layers[prev_layer_idx].named_parameters():
            param.data = param.data.to("cpu", non_blocking=True)
            
    def get_offlaod_layer(self, layer_idx: int, device: torch.device):
        # init stream
        if not hasattr(self, "prefetch_stream"):
            self.prefetch_stream = torch.cuda.Stream()

        # delete previous layer
        torch.cuda.current_stream().synchronize()
        self.evict_previous_layer(layer_idx)
        
        # make sure the current layer is ready
        torch.cuda.synchronize(self.prefetch_stream)

        # load next layer
        self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
        

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        offload_model: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        # if inputs_embeds is None:
        #     inputs_embeds = self.embed_tokens(input_ids)

        # if cache_position is None:
        #     past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        #     cache_position = torch.arange(
        #         past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        #     )
        # if position_ids is None:
        #     position_ids = cache_position.unsqueeze(0)

        if attention_mask is not None and attention_mask.dim() == 3:
            dtype = inputs_embeds.dtype
            min_dtype = torch.finfo(dtype).min
            attention_mask = (1 - attention_mask) * min_dtype
            attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
        else:
            raise Exception("attention_mask parameter was unavailable or invalid")
            # causal_mask = self._update_causal_mask(
            #     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
            # )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        layer_idx = -1
        for decoder_layer in self.layers:
            layer_idx += 1

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                if offload_model and not self.training:
                    self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )



In [6]:
from torch import nn
from torch.nn import functional as F
from transformers import Phi3Config

class OmniGen(nn.Module):
    def __init__(self, new_embed_size, hidden_size=3072):
        super().__init__()

        config = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
        self.llm = Phi3Transformer(config)
        self.llm.resize_token_embeddings(new_embed_size)

        self.time_token = TimestepEmbedder(hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)

        self.final_layer = TextFinalLayer(hidden_size=hidden_size, output_size=hidden_size)

    def pad_x(self, x):
        max_len = max(t.size(1) for t in x)
        num_tokens = [t.size(1) for t in x]
        padded_latents = []
        for t in x:
            L = t.size(1)
            pad_len = max_len - L
            if pad_len > 0:
                t = F.pad(t, (0, 0, 0, pad_len), mode="constant", value=0)
            padded_latents.append(t)

        batch = torch.cat(padded_latents, dim=0)

        return batch, num_tokens

    def forward(
        self, 
        x,
        timestep,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values=None,
        offload_model:bool=False
    ):
        x_padded, num_tokens = self.pad_x(x)
        time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)   
        condition_embeds = self.llm.embed_tokens(input_ids).clone()

        input_emb = torch.cat([condition_embeds, time_token, x_padded], dim=1)

        output = self.llm(
            inputs_embeds=input_emb, 
            attention_mask=attention_mask, 
            position_ids=position_ids, 
            past_key_values=past_key_values, 
            offload_model=offload_model
        )

        out_embedding = output.last_hidden_state[:, -max(num_tokens):]
        time_emb = self.t_embedder(timestep, dtype=x[0].dtype)
        out_x = self.final_layer(out_embedding, time_emb)

        latents = []
        for i in range(out_x.size(0)):
            latent = out_x[i:i+1, :num_tokens[i]]
            latents.append(latent)

        return latents

In [None]:
def sample_x0(x1):
    """Sampling x0 & t based on shape of x1 (if needed)
    Args:
      x1 - data point; [batch, *dim]
    """
    if isinstance(x1, (list, tuple)):
        x0 = [torch.randn_like(img_start) for img_start in x1]
    else:
        x0 = torch.randn_like(x1)

    return x0

def sample_timestep(x1):
    u = torch.normal(mean=0.0, std=1.0, size=(len(x1),))
    t = 1 / (1 + torch.exp(-u))
    t = t.to(x1[0])
    return t

def training_losses(model, x1, model_kwargs=None):
    """Loss for training torche score model
    Args:
    - model: backbone model; could be score, noise, or velocity
    - x1: datapoint
    - model_kwargs: additional arguments for torch model
    """
    B = len(x1)

    x0 = sample_x0(x1)
    t = sample_timestep(x1)

    xt = [t[i] * x1[i] + (1 - t[i]) * x0[i] for i in range(B)]
    ut = [x1[i] - x0[i] for i in range(B)]

    model_output = model(xt, t, **model_kwargs)  # (B, C, H, W) -> (B, L, D)

    loss_per_sample = torch.stack([
        ((ut[i] - model_output[i])**2).mean()
        for i in range(len(model_output))
    ], dim=0)
    loss = loss_per_sample.mean()

    return loss

In [8]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)
special_tokens_dict = {"additional_special_tokens": ["<bot>"]}
num_added = tokenizer.add_special_tokens(special_tokens_dict)

vae_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)

dataset = DatasetFromJson(json_file="cd4_train.jsonl", tokenizer=tokenizer, vae_tokenizer=vae_tokenizer)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=OmniGenCollator(tokenizer=tokenizer),   # <-- all padding happens here
    num_workers=0,
)
data = next(iter(dataloader))

model = OmniGen(
    new_embed_size=len(tokenizer),
)

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [None]:
# TODO: linear layer to map vae output to D=3072
# TODO: add <cfg> for condition dropout 
# TODO: use our VAE
# TODO: train on hf data

In [9]:
from transformers import AutoModelForCausalLM
vae = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)

output_encoded = vae_encode_list_new(
    vae,  # insert own model here
    data['output']
)

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:30<00:00, 15.33s/it]
You are not running the flash-attention implementation, expect numerical differences.


In [None]:
loss = training_losses(
    model, 
    output_encoded, 
    model_kwargs={
        "input_ids": data['input_ids'],
        "attention_mask": data['attention_mask'],
        "position_ids": data['position_ids'],
    }
)


torch.Size([16, 256])
torch.Size([16, 256])


In [11]:
for i, j, k in zip(model_output, xt, ut):
    print(i.shape, j.shape, k.shape)

torch.Size([1, 27, 3072]) torch.Size([1, 27, 3072]) torch.Size([1, 27, 3072])
torch.Size([1, 28, 3072]) torch.Size([1, 28, 3072]) torch.Size([1, 28, 3072])
torch.Size([1, 32, 3072]) torch.Size([1, 32, 3072]) torch.Size([1, 32, 3072])
torch.Size([1, 25, 3072]) torch.Size([1, 25, 3072]) torch.Size([1, 25, 3072])
torch.Size([1, 25, 3072]) torch.Size([1, 25, 3072]) torch.Size([1, 25, 3072])
torch.Size([1, 31, 3072]) torch.Size([1, 31, 3072]) torch.Size([1, 31, 3072])
torch.Size([1, 24, 3072]) torch.Size([1, 24, 3072]) torch.Size([1, 24, 3072])
torch.Size([1, 26, 3072]) torch.Size([1, 26, 3072]) torch.Size([1, 26, 3072])
torch.Size([1, 27, 3072]) torch.Size([1, 27, 3072]) torch.Size([1, 27, 3072])
torch.Size([1, 30, 3072]) torch.Size([1, 30, 3072]) torch.Size([1, 30, 3072])
torch.Size([1, 32, 3072]) torch.Size([1, 32, 3072]) torch.Size([1, 32, 3072])
torch.Size([1, 27, 3072]) torch.Size([1, 27, 3072]) torch.Size([1, 27, 3072])
torch.Size([1, 28, 3072]) torch.Size([1, 28, 3072]) torch.Size([

In [None]:
from diffusers import get_scheduler
import math

epochs = 1

lr = 1e-4
lr_scheduler = "constant"
lr_warmup_steps = 1000
gradient_accumulation_steps = 1
adam_weight_decay = 0.0
report_to = "wandb"
results_dir = "results"
mixed_precision = "bf16"


opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=adam_weight_decay)

num_update_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps)
max_train_steps = epochs * num_update_steps_per_epoch
lr_scheduler = get_scheduler(
    lr_scheduler,
    optimizer=opt,
    num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
    num_training_steps=max_train_steps * gradient_accumulation_steps,
)

In [None]:
from time import time

_ = model.train()

train_steps = 0
running_loss = 0
start_time = time()

for epoch in range(epochs):
    for step, data in enumerate(dataloader):
        with torch.no_grad():
            output_encoded = vae_encode_list_new(
                vae,  # insert own model here
                data['output']
            )

        loss = training_losses(
            model, 
            output_encoded, 
            model_kwargs={
                "input_ids": data['input_ids'],
                "attention_mask": data['attention_mask'],
                "position_ids": data['position_ids'],
            }
        )
        running_loss += loss.item()
        
        loss = loss / gradient_accumulation_steps
        loss.backward()

        # 4) optimizer step when we've accumulated enough
        if (step + 1) % gradient_accumulation_steps == 0:
            opt.step()
            lr_scheduler.step()
            opt.zero_grad()
            train_steps += 1
            running_loss += loss.item() * gradient_accumulation_steps  # un-normalize
