In [1]:
import dataloader
import transformers
import tokenizers
import datasets
import torch
import itertools
import functools
import os
import lightning as L
import importlib
import torchmetrics
from dataclasses import dataclass
import math
import hydra
import typing
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
/home/davide/Documents/SMDM/.venv/lib/python3.12/site-packages/lightning/fabric/__init__.py:41: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.


In [2]:
import models.dit
import noise_schedule
importlib.reload(models.dit)
importlib.reload(noise_schedule)
from models.dit import DIT
from models.ema import ExponentialMovingAverage

In [3]:
def get_config(overrides=None, config_path="configs", config_name="config"):
    """Compose the Hydra config used by this notebook.

    Args:
      overrides: Hydra override strings. Defaults to the tiny-model /
        OpenWebText / SUBS setup used throughout this notebook.
      config_path: Hydra config directory (Hydra resolves it relative to
        the calling code).
      config_name: name of the primary config to compose.

    Returns:
      The config object returned by `hydra.compose`.
    """
    if overrides is None:
        # The brackets already continue the list; no backslashes needed.
        overrides = ['model=tiny',
                     'data=openwebtext-split',
                     'wandb.name=mdlm-owt',
                     'parameterization=subs',
                     # NOTE(review): training rows are packed to block_size=512
                     # tokens later on; confirm 1024 is intended here.
                     'model.length=1024',
                     'eval.compute_generative_perplexity=True',
                     'sampling.steps=1000']
    with hydra.initialize(version_base=None, config_path=config_path):
        config = hydra.compose(config_name=config_name, overrides=list(overrides))
    return config

config = get_config()
config.model

{'name': 'tiny', 'type': 'ddit', 'hidden_size': 256, 'cond_dim': 64, 'length': 1024, 'n_blocks': 8, 'n_heads': 8, 'scale_by_sigma': True, 'dropout': 0.1, 'tie_word_embeddings': False}

In [4]:
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

# Wrap encoded sequences as [BOS] ... [EOS] whenever add_special_tokens=True.
# BertProcessing's signature is (sep, cls): `cls` is prepended and `sep` is
# appended, so EOS is passed first and BOS second. (For GPT-2 both are
# <|endoftext|>, so the order does not change the ids here.)
tokenizer._tokenizer.post_processor = tokenizers.processors.BertProcessing(
      (tokenizer.eos_token, tokenizer.eos_token_id),   # sep: appended
      (tokenizer.bos_token, tokenizer.bos_token_id))   # cls: prepended

# GPT-2 ships without a pad token; register one (added after the base vocab).
# NOTE(review): `Diffusion` uses `tokenizer.vocab_size` (which excludes added
# tokens) as its [MASK] id, so [PAD] likely shares that id. That is harmless
# while every row is fully packed, but confirm before enabling padding.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})



1

In [5]:
# Hyper-parameters shared by the cells below.
B = 8                      # batch size
T = block_size = 512       # tokens per packed training row, incl. [BOS]/[EOS]
C = 256                    # NOTE(review): not referenced in the visible cells

# Root for the HF datasets cache and the packed dataset written to disk.
# Set SMDM_DATA_DIR to relocate it instead of editing this absolute path.
datasets_directory = os.environ.get('SMDM_DATA_DIR', '/home/davide/Documents/SMDM/data')
n_processes = 15           # worker processes for `Dataset.map`

# Plain f-string: nesting quotes inside `{}` is a SyntaxError before Python 3.12.
filename = f'openwebtext_train_bs{block_size}_wrapped.dat'
_path = os.path.join(datasets_directory, filename)
ema = 0.9999               # EMA decay read by Diffusion.__init__; <= 0 disables EMA


In [6]:
# Everything except the last 100k documents of OpenWebText's train split
# (presumably held out for validation). The data is downloaded and cached
# under `datasets_directory`.
# NOTE(review): trust_remote_code=True runs the dataset's loading script from
# the Hub; acceptable only for a trusted repository.
dataset = datasets.load_dataset(
      'openwebtext',
      cache_dir=datasets_directory,
      split='train[:-100000]',
      trust_remote_code=True,
      streaming=False,
)

In [7]:
N_TRAIN_DOCS = 10_000  # small subset for quick local experiments
data = dataset.select(range(N_TRAIN_DOCS))

In [8]:
# Read the ids directly. `tokenizer.encode(token)[0]` runs the post-processor,
# so element 0 is the token it *prepends*, not the token being encoded. It only
# matched here because GPT-2 uses the same id for BOS and EOS.
EOS = tokenizer.eos_token_id
BOS = tokenizer.bos_token_id

In [9]:
def preprocess_and_tokenize(example):
    """Tokenize a batch of raw documents and terminate each one with [EOS].

    Intended for `Dataset.map(..., batched=True)`: `example['text']` is a list
    of strings, and the result holds one token-id list per document.
    [BOS] is deliberately not added here; `_group_texts` adds it per chunk.
    """
    # No padding or truncation is requested, so padding_side/truncation_side
    # are irrelevant. The shared tokenizer is therefore not mutated per batch.
    encoded = tokenizer(example['text'],
                        add_special_tokens=False,
                        return_attention_mask=False,
                        return_token_type_ids=False)
    return {'input_ids': [ids + [EOS] for ids in encoded['input_ids']]}


# Tokenize in parallel (results cached by `datasets`), then drop the raw text.
tokenized_dataset = (
    data
    .map(preprocess_and_tokenize,
         batched=True,
         num_proc=n_processes,
         load_from_cache_file=True,
         desc='Tokenizing')
    .remove_columns('text')
)

In [10]:
def _group_texts(examples, T, bos, eos):
  """Pack a batch of tokenized documents into fixed-length rows.

  All `examples['input_ids']` lists are concatenated into one token stream.
  The stream is cut into chunks of `T - 2` tokens, and each chunk is wrapped
  as [bos] + chunk + [eos], so every row has exactly `T` tokens. A trailing
  remainder shorter than `T - 2` is dropped, so a batch holding fewer than
  `T - 2` tokens yields empty lists.

  Args:
    examples: batch dict with an 'input_ids' list of token-id lists.
    T: block size (row length including [BOS] and [EOS]).
    bos, eos: begin- and end-of-string token ids.

  Returns:
    dict with 'input_ids' (list of length-T int lists) and 'attention_mask'
    (one `torch.ones(T)` per row).
  """
  # TODO(yair): look into not dropping the remainder but rather padding it.
  stream = [tok for ids in examples['input_ids'] for tok in ids]
  body_len = T - 2  # room left after [BOS] and [EOS]
  n_rows = len(stream) // body_len
  rows = [[bos] + stream[k * body_len:(k + 1) * body_len] + [eos]
          for k in range(n_rows)]
  masks = [torch.ones(T) for _ in range(n_rows)]
  return {'input_ids': rows, 'attention_mask': masks}

# Bind block size and special-token ids so `Dataset.map(batched=True)` can call
# `group_texts(batch)` with only the batch.
group_texts = functools.partial(_group_texts, T=block_size, bos=BOS, eos=EOS)

In [11]:

grouping_kwargs = dict(batched=True,
                       num_proc=n_processes,
                       load_from_cache_file=True,
                       desc='Grouping')
chunked_dataset = tokenized_dataset.map(group_texts, **grouping_kwargs)

# Write the packed rows to `_path` (done on every run of this cell).
chunked_dataset.save_to_disk(_path)
# Make indexing return torch tensors so DataLoader batches are tensors.
chunked_dataset = chunked_dataset.with_format('torch')

Saving the dataset (1/1 shards): 100%|██████████| 21845/21845 [00:00<00:00, 159727.65 examples/s]


In [12]:
train_loader = torch.utils.data.DataLoader(
    chunked_dataset,
    batch_size=B,
    shuffle=True,
    num_workers=6,
    persistent_workers=True,
    pin_memory=False,
)
# NOTE(review): nothing visible in this notebook reads this attribute.
train_loader.tokenizer = tokenizer

In [13]:
class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):
  """RandomSampler whose position inside an epoch can be checkpointed.

  `state_dict()` captures the generator state from which the current epoch's
  permutation was drawn, plus how many indices have been yielded so far.
  After `load_state_dict()`, the next `__iter__` redraws that same
  permutation and skips the indices already yielded.
  """

  def __init__(self, *args, generator=None, **kwargs):
    # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
    # which should be reproducible if pl.seed_everything was called beforehand.
    # This means that changing the seed of the experiment will also change the
    # sampling order.
    if generator is None:
      seed = int(torch.empty((), dtype=torch.int64).random_().item())
      generator = torch.Generator().manual_seed(seed)
    # RandomSampler has no `shuffle` argument; drop it if a caller forwards one.
    kwargs.pop('shuffle', None)
    super().__init__(*args, generator=generator, **kwargs)
    self.counter = 0
    self.restarting = False
    # Generator state from which the current (or next) epoch's permutation is
    # drawn. It is initialized here so state_dict() works before iteration.
    self.state = self.generator.get_state()

  def state_dict(self):
    # Return the state captured *before* the permutation was drawn. The live
    # generator has already advanced past it, so restoring the live state
    # would produce a different permutation and skip the wrong indices.
    return {'random_state': self.state,
            'counter': self.counter}

  def load_state_dict(self, state_dict):
    self.generator.set_state(state_dict.get('random_state'))
    self.counter = state_dict['counter']
    self.restarting = True

  # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
  # epoch, and subsequent epoch will have very few batches.

  def __iter__(self) -> typing.Iterator[int]:
    n = len(self.data_source)

    self.state = self.generator.get_state()
    indices = torch.randperm(n, generator=self.generator).tolist()

    if not self.restarting:
      self.counter = 0
    else:
      # Resume mid-epoch: skip what was already yielded before the checkpoint.
      indices = indices[self.counter:]
      self.restarting = False

    for index in indices:
      self.counter += 1
      yield index

    self.counter = 0
    # Epoch finished: a checkpoint taken now resumes into a fresh permutation.
    self.state = self.generator.get_state()

In [14]:
LOG2 = math.log(2)  # nats -> bits conversion factor


@dataclass
class Loss:
  """Result of `Diffusion._loss`.

  loss: scalar NLL averaged with the attention mask as per-token weights.
  nlls: per-token NLLs already multiplied by the attention mask.
  token_mask: the attention mask used as those weights.
  """
  loss: torch.FloatTensor
  nlls: torch.FloatTensor
  token_mask: torch.FloatTensor


class NLL(torchmetrics.aggregation.MeanMetric):
  """Weighted mean negative log-likelihood (nats per token)."""


class BPD(NLL):
  def compute(self) -> torch.Tensor:
    """Bits per dimension: the accumulated mean NLL converted to bits."""
    nats_per_token = self.mean_value / self.weight
    return nats_per_token / LOG2


class Perplexity(NLL):
  def compute(self) -> torch.Tensor:
    """Perplexity: exp of the accumulated mean NLL."""
    nats_per_token = self.mean_value / self.weight
    return nats_per_token.exp()

In [15]:

def _sample_categorical(categorical_probs):
  """Draw one index per row from (unnormalized) probabilities on the last axis.

  Dividing each probability by an independent Exp(1)-like noise term
  (-log U) and taking the argmax is equivalent to the Gumbel-max trick.
  The 1e-10 terms keep the log and the divisor away from zero.
  """
  uniform = torch.rand_like(categorical_probs)
  exp_noise = 1e-10 - torch.log(uniform + 1e-10)
  scores = categorical_probs / exp_noise
  return torch.argmax(scores, dim=-1)

In [16]:
def save_model(model, folder):
    """Save `model.state_dict()` to `<folder>/model.pth`.

    `folder` may be given with or without a trailing separator and is created
    if it does not exist yet.
    """
    if folder:
        os.makedirs(folder, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(folder, 'model.pth'))


def load_model(model, folder):
    """Best-effort load of `<folder>/model.pth` into `model`.

    Any failure (missing file, shape mismatch, ...) is printed and swallowed,
    so a freshly initialized model can still be used.

    Returns:
      True if the weights were loaded, False otherwise (callers may ignore it).
    """
    file = os.path.join(folder, 'model.pth')
    try:
        # map_location='cpu' lets a checkpoint saved on a GPU load on a
        # CPU-only machine; load_state_dict then copies into the model's device.
        model.load_state_dict(
            torch.load(file, map_location='cpu', weights_only=True))
    except Exception as e:
        print("\n", "Model weights not avaiable", "\n")
        print(e)
        return False
    print('model loaded')
    return True

In [None]:
class Diffusion(L.LightningModule):
  """Masked discrete-diffusion language model with the SUBS parameterization.

  Wraps a `models.dit.DIT` backbone with a log-linear noise schedule, a
  continuous-time training loss, NLL/BPD/perplexity metrics, an optional EMA
  of the weights, and the 'ddpm' / 'ddpm_cache' ancestral samplers.

  Reads the notebook global `ema` (EMA decay; EMA is disabled when <= 0).
  """

  def __init__(
    self,
    config,
    tokenizer: transformers.PreTrainedTokenizer):
    super().__init__()
    self.config = config
    self.weights_folder = 'weights/'

    self.tokenizer = tokenizer
    self.vocab_size = self.tokenizer.vocab_size

    # Reverse-process sampler used by `_sample`: 'ddpm' or 'ddpm_cache'.
    self.sampler = 'ddpm_cache'
    self.gen_ppl_eval_model_name_or_path = 'gpt2-large'

    self.antithetic_sampling = True
    self.importance_sampling = False
    self.change_of_variables = False

    # Tokenizers without a mask token get one extra id, appended after the
    # base vocabulary, to act as [MASK].
    # NOTE(review): `vocab_size` excludes added tokens, so a [PAD] token added
    # to the tokenizer likely shares this id -- confirm before using padding.
    self.mask_index = self.tokenizer.mask_token_id
    if (not hasattr(self.tokenizer, 'mask_token') or self.tokenizer.mask_token is None):
      self.mask_index = self.vocab_size
      self.vocab_size += 1

    self.backbone = models.dit.DIT(self.config, vocab_size=self.vocab_size)
    # Best-effort warm start from `weights_folder`; failures are printed only.
    load_model(self.backbone, self.weights_folder)

    self.parameterization = 'subs'
    self.T = 0  # 0 => continuous time; > 0 => t is discretized to T levels
    self.subs_masking = False

    self.softplus = torch.nn.Softplus()
    # metrics are automatically reset at end of epoch
    metrics = torchmetrics.MetricCollection({
      'nll': NLL(),
      'bpd': BPD(),
      'ppl': Perplexity(),
    })

    metrics.set_dtype(torch.float64)
    self.train_metrics = metrics.clone(prefix='train/')
    self.valid_metrics = metrics.clone(prefix='val/')
    self.test_metrics = metrics.clone(prefix='test/')

    # generative perplexity
    self.gen_ppl_metric = Perplexity()
    self.eval_model_tokenizer = transformers.AutoTokenizer.\
      from_pretrained(self.gen_ppl_eval_model_name_or_path)
    if self.eval_model_tokenizer.pad_token is None:
      self.eval_model_tokenizer.pad_token =\
          self.eval_model_tokenizer.eos_token
      self.eval_model_tokenizer.pad_token_id =\
          self.eval_model_tokenizer.eos_token_id

    self.noise = noise_schedule.LogLinearNoise() #get_noise(self.config,dtype=self.dtype)
    self.ema = None
    if ema > 0:
      self.ema = ExponentialMovingAverage(
        itertools.chain(self.backbone.parameters(),
                        self.noise.parameters()),
        decay=ema)

    self.lr = 3e-4  # NOTE(review): unused; configure_optimizers reads config.optim.lr
    self.sampling_eps = 1e-3
    self.time_conditioning = True
    self.neg_infinity = -1000000.0
    self.fast_forward_epochs = None
    self.fast_forward_batches = None


  def configure_optimizers(self):
    """AdamW over backbone + noise parameters, plus the configured LR scheduler.

    Optimizer hyper-parameters come from `config.optim`. The scheduler is
    instantiated by Hydra from `config.lr_scheduler` and stepped every
    optimizer step.
    """
    optimizer = torch.optim.AdamW(
      itertools.chain(self.backbone.parameters(),
                      self.noise.parameters()),
      lr=self.config.optim.lr,
      betas=(self.config.optim.beta1,
             self.config.optim.beta2),
      eps=self.config.optim.eps,
      weight_decay=self.config.optim.weight_decay)

    scheduler = hydra.utils.instantiate(
      self.config.lr_scheduler, optimizer=optimizer)
    scheduler_dict = {
      'scheduler': scheduler,
      'interval': 'step',
      'monitor': 'val/loss',
      'name': 'trainer/lr',
    }
    return [optimizer], [scheduler_dict]


  def on_train_start(self):
    """Move EMA buffers to the device and make the train loaders checkpointable.

    Each training DataLoader is rebuilt around a fault-tolerant random sampler.
    The loader's own batch size, worker count and pin_memory setting are kept;
    they are no longer replaced by notebook globals.
    """
    if self.ema:
      self.ema.move_shadow_params_to_device(self.device)

    sampler_cls = dataloader.RandomFaultTolerantSampler
    updated_dls = []
    for dl in self.trainer.fit_loop._combined_loader.flattened:
      if hasattr(dl.sampler, 'shuffle'):
        dl_sampler = sampler_cls(
          dl.dataset, shuffle=dl.sampler.shuffle)
      else:
        dl_sampler = sampler_cls(dl.dataset)

      updated_dls.append(
        torch.utils.data.DataLoader(
          dl.dataset,
          batch_size=dl.batch_size,
          num_workers=dl.num_workers,
          pin_memory=dl.pin_memory,
          sampler=dl_sampler,
          shuffle=False,
          # DataLoader rejects persistent_workers=True when num_workers == 0.
          persistent_workers=dl.num_workers > 0))
    self.trainer.fit_loop._combined_loader.flattened = updated_dls


  def optimizer_step(self, *args, **kwargs):
    """Run the optimizer step, then update the EMA of the trainable weights."""
    super().optimizer_step(*args, **kwargs)
    if self.ema:
      self.ema.update(itertools.chain(
        self.backbone.parameters(),
        self.noise.parameters()))


  def training_step(self, batch, batch_idx):
    """Compute the training loss and log it per step as 'trainer/loss'."""
    loss = self._compute_loss(batch, prefix='train')
    self.log(name='trainer/loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             sync_dist=True)
    return loss


  def _compute_loss(self, batch, prefix):
    """Loss for one batch; updates and logs the metric collection for `prefix`.

    Args:
      batch: dict with 'input_ids' and optionally 'attention_mask'.
      prefix: 'train', 'val' or 'test'.

    Raises:
      ValueError: for any other prefix.
    """
    if 'attention_mask' in batch:
      attention_mask = batch['attention_mask']
    else:
      attention_mask = None

    losses = self._loss(batch['input_ids'], attention_mask)
    loss = losses.loss

    if prefix == 'train':
      self.train_metrics.update(losses.nlls, losses.token_mask)
      metrics = self.train_metrics
    elif prefix == 'val':
      self.valid_metrics.update(losses.nlls, losses.token_mask)
      metrics = self.valid_metrics
    elif prefix == 'test':
      self.test_metrics.update(losses.nlls, losses.token_mask)
      metrics = self.test_metrics
    else:
      raise ValueError(f'Invalid prefix: {prefix}')

    self.log_dict(metrics,
                  on_step=False,
                  on_epoch=True,
                  sync_dist=True)
    return loss


  def _loss(self, x0, attention_mask):
    """Masked mean diffusion NLL for a batch of clean token ids `x0`.

    A missing `attention_mask` (None) is treated as all ones, i.e. every
    position counts.
    """
    if attention_mask is None:
      attention_mask = torch.ones_like(x0, dtype=torch.float)
    input_tokens, _, attention_mask = self._maybe_sub_sample(x0, attention_mask)

    loss = self._forward_pass_diffusion(input_tokens)

    nlls = loss * attention_mask
    count = attention_mask.sum()

    batch_nll = nlls.sum()
    token_nll = batch_nll / count

    return Loss(loss=token_nll,
                nlls=nlls,
                token_mask=attention_mask)


  def _maybe_sub_sample(self, x0, attention_mask):
    """Crop sequences of length 2 * model.length to a random model.length window.

    Shorter inputs are returned unchanged, with `output_tokens` set to None.
    """
    seqlen = x0.shape[1]
    if seqlen > self.config.model.length:
      assert seqlen == 2 * self.config.model.length

      start = np.random.choice(self.config.model.length)
      end = start + self.config.model.length
      input_tokens = x0[:, start: end]
      output_tokens = x0[:, start + 1: end + 1]
      new_attention_mask = attention_mask[:, start: end]

      # Helps with validation PPL, since the val
      # examples will all start and end with BOS/EOS
      input_tokens[:, 0] = self.tokenizer.bos_token_id
      output_tokens[:, -1] = self.tokenizer.eos_token_id
    else:
      input_tokens = x0
      output_tokens = None
      new_attention_mask = attention_mask
    return input_tokens, output_tokens, new_attention_mask


  def _reconstruction_loss(self, x0):
    """-log p(x0 | x0) at t = 0 (used by the discrete-time d3pm branch)."""
    t0 = torch.zeros(x0.shape[0], dtype=self.dtype,
                     device=self.device)
    assert self.config.noise.type == 'loglinear'
    # The above assert is for d3pm parameterization
    unet_conditioning = self.noise(t0)[0][:, None]
    model_output_t0 = self.forward(x0, unet_conditioning)
    return - torch.gather(input=model_output_t0,
                          dim=-1,
                          index=x0[:, :, None]).squeeze(-1)


  def _forward_pass_diffusion(self, x0):
    """Per-token diffusion loss for clean tokens `x0` at a sampled time t.

    With the defaults (subs, continuous time) this is
    -log p_theta(x0 | xt) * dsigma / expm1(sigma).
    NOTE(review): `_score_entropy` and `_d3pm_loss` are not defined in this
    class; the 'sedd' and T > 0 branches are unreachable with the defaults.
    """
    t = self._sample_t(x0.shape[0], x0.device)
    if self.T > 0:
      t = (t * self.T).to(torch.int)
      t = t / self.T                       # t \in {1/T, 2/T, ..., 1}

    if self.change_of_variables:
      unet_conditioning = t[:, None]
      f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
      f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
      move_chance = torch.exp(f_0 + t * (f_T - f_0))
      move_chance = move_chance[:, None]
    else:
      sigma, dsigma = self.noise(t)
      unet_conditioning = sigma[:, None]
      move_chance = 1 - torch.exp(-sigma[:, None])

    xt = self.q_xt(x0, move_chance)

    model_output = self.forward(xt, unet_conditioning)

    if self.parameterization == 'sedd':
      return dsigma[:, None] * self._score_entropy(
        model_output, sigma[:, None], xt, x0)

    if self.T > 0:
      diffusion_loss = self._d3pm_loss(model_output=model_output, xt=xt, x0=x0, t=t)
      if self.parameterization == 'd3pm':
        reconstruction_loss = self._reconstruction_loss(x0)
      elif self.parameterization == 'subs':
        reconstruction_loss = 0
      return reconstruction_loss + diffusion_loss

    # SUBS parameterization, continuous time.
    log_p_theta = torch.gather(
      input=model_output,
      dim=-1,
      index=x0[:, :, None]).squeeze(-1)

    if self.change_of_variables or self.importance_sampling:
      return log_p_theta * torch.log1p(- torch.exp(- self.noise.sigma_min))

    return - log_p_theta * (dsigma / torch.expm1(sigma))[:, None]


  def _sample_t(self, n, device):
    """Sample n diffusion times in [sampling_eps, 1].

    With antithetic sampling, the i-th draw is confined to [i/n, (i+1)/n)
    before rescaling, which spreads the batch evenly over time.
    """
    _eps_t = torch.rand(n, device=device)
    if self.antithetic_sampling:
      offset = torch.arange(n, device=device) / n
      _eps_t = (_eps_t / n + offset) % 1
    t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
    if self.importance_sampling:
      return self.noise.importance_sampling_transformation(t)
    return t


  def q_xt(self, x, move_chance):
    """Computes the noisy sample xt.

    Each token is independently replaced by [MASK] with probability
    `move_chance`.

    Args:
      x: int torch.Tensor with shape (batch_size,
          diffusion_model_input_length), input.
      move_chance: float torch.Tensor with shape (batch_size, 1).
    """
    move_indices = torch.rand(* x.shape, device=x.device) < move_chance
    xt = torch.where(move_indices, self.mask_index, x)
    return xt


  def forward(self, x, sigma):
    """Returns log score.

    NOTE(review): `_sedd_parameterization` and `_d3pm_parameterization` are
    not defined in this class; only 'subs' is usable as written.
    """
    sigma = self._process_sigma(sigma)

    logits = self.backbone(x, sigma)

    if self.parameterization == 'subs':
      return self._subs_parameterization(logits=logits, xt=x)
    elif self.parameterization == 'sedd':
      return self._sedd_parameterization(logits=logits, xt=x, sigma=sigma)
    elif self.parameterization == 'd3pm':
      return self._d3pm_parameterization(logits=logits)
    return logits


  def _process_sigma(self, sigma):
    """Flatten sigma to shape (batch,); zero it when time conditioning is off."""
    if sigma is None:
      assert self.parameterization == 'ar'
      return sigma
    if sigma.ndim > 1:
      sigma = sigma.squeeze(-1)
    if not self.time_conditioning:
      sigma = torch.zeros_like(sigma)
    assert sigma.ndim == 1, sigma.shape
    return sigma


  def _subs_parameterization(self, logits, xt):
    """Turn backbone logits into SUBS log-probabilities.

    [MASK] is never predicted. Positions already unmasked in `xt` put all of
    their probability mass on their current token.
    """
    # log prob at the mask index = - infinity
    logits[:, :, self.mask_index] += self.neg_infinity

    # Normalize the logits such that x.exp() is
    # a probability distribution over vocab_size.
    log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)

    # For unmasked tokens, set every log prob to -infinity except the one at
    # the token's own index, which becomes 0 (probability 1).
    unmasked_indices = (xt != self.mask_index)
    log_probs[unmasked_indices] = self.neg_infinity
    log_probs[unmasked_indices, xt[unmasked_indices]] = 0
    return log_probs


  @torch.no_grad()
  def _sample(self, num_steps=1000, eps=1e-5, batch_size=4, seq_len=512):
    """Generate `batch_size` sequences of `seq_len` tokens by ancestral sampling.

    Starts from an all-[MASK] sequence at t=1 and takes `num_steps` uniform
    steps down to t=`eps` with `self.sampler`. The backbone runs in eval mode
    (dropout off); its previous train/eval mode is restored afterwards.

    Raises:
      ValueError: if `self.sampler` is not 'ddpm' or 'ddpm_cache'.
    """
    if self.sampler not in ('ddpm', 'ddpm_cache'):
      raise ValueError(f'Unsupported sampler: {self.sampler}')

    was_training = self.backbone.training
    self.backbone.eval()
    try:
      x = self._sample_prior(batch_size, seq_len)

      timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
      dt = (1 - eps) / num_steps
      p_x0_cache = None

      print("starting generation")
      for i in range(num_steps):
        print(f"\r{i} / {num_steps}", end="")
        t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device)
        if self.sampler == 'ddpm':
          x = self._ddpm_update(x, t, dt)
        else:  # 'ddpm_cache'
          p_x0_cache, x_next = self._ddpm_caching_update(
            x, t, dt, p_x0=p_x0_cache)
          # The cached p(x0 | xt) is reusable only if xt did not change and
          # the model ignores t.
          if (not torch.allclose(x_next, x) or self.time_conditioning):
            p_x0_cache = None  # Disable caching
          x = x_next

      # last step, remove all noise by taking the argmax
      if self.config.sampling.noise_removal:
        t = timesteps[-1] * torch.ones(x.shape[0], 1,
                                       device=self.device)
        unet_conditioning = self.noise(t)[0]
        x = self.forward(x, unet_conditioning).argmax(dim=-1)
    finally:
      self.backbone.train(was_training)
    return x


  def _move_chances(self, t, dt):
    """Return (sigma_t, move_chance_t, move_chance_s) for s = t - dt.

    `sigma_t` is flattened to (batch,). Move chances are 1 - exp(-sigma),
    reshaped to (batch, 1, 1) so they broadcast over the model's 3-D output.
    """
    sigma_t, _ = self.noise(t)
    sigma_s, _ = self.noise(t - dt)
    if sigma_t.ndim > 1:
      sigma_t = sigma_t.squeeze(-1)
    if sigma_s.ndim > 1:
      sigma_s = sigma_s.squeeze(-1)
    assert sigma_t.ndim == 1, sigma_t.shape
    assert sigma_s.ndim == 1, sigma_s.shape
    move_chance_t = (1 - torch.exp(-sigma_t))[:, None, None]
    move_chance_s = (1 - torch.exp(-sigma_s))[:, None, None]
    return sigma_t, move_chance_t, move_chance_s


  def _reverse_step(self, x, p_x0, move_chance_t, move_chance_s):
    """Sample x_s from x_t given p(x0 | x_t); unmasked tokens are copied through."""
    assert move_chance_t.ndim == p_x0.ndim
    # Technically, this isn't q_xs since there's a division
    # term that is missing. This division term doesn't affect
    # the samples.
    q_xs = p_x0 * (move_chance_t - move_chance_s)
    q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
    _x = _sample_categorical(q_xs)

    copy_flag = (x != self.mask_index).to(x.dtype)
    return copy_flag * x + (1 - copy_flag) * _x


  def _ddpm_update(self, x, t, dt):
    """One ancestral step from t to t - dt (always re-runs the model)."""
    sigma_t, move_chance_t, move_chance_s = self._move_chances(t, dt)
    log_p_x0 = self.forward(x, sigma_t)
    return self._reverse_step(x, log_p_x0.exp(), move_chance_t, move_chance_s)


  def _ddpm_caching_update(self, x, t, dt, p_x0=None):
    """Like `_ddpm_update`, but may reuse a cached p(x0 | xt).

    Returns:
      (p_x0, x_next): the probabilities that were used (for caching) and the
      sampled next state.
    """
    sigma_t, move_chance_t, move_chance_s = self._move_chances(t, dt)
    if p_x0 is None:
      p_x0 = self.forward(x, sigma_t).exp()
    return p_x0, self._reverse_step(x, p_x0, move_chance_t, move_chance_s)


  def _sample_prior(self, *batch_dims):
    """All-[MASK] int64 tensor of shape `batch_dims` on the module's device."""
    return self.mask_index * torch.ones(
      * batch_dims, dtype=torch.int64, device=self.device)
  
# Build the model; the constructor prints whether saved weights were loaded.
diffusion = Diffusion(config=config, tokenizer=tokenizer)
diffusion

model loaded


Diffusion(
  (backbone): DIT(
    (vocab_embed): EmbeddingLayer()
    (sigma_map): TimestepEmbedder(
      (mlp): Sequential(
        (0): Linear(in_features=256, out_features=64, bias=True)
        (1): SiLU()
        (2): Linear(in_features=64, out_features=64, bias=True)
      )
    )
    (rotary_emb): Rotary()
    (blocks): ModuleList(
      (0-7): 8 x DDiTBlock(
        (norm1): LayerNorm()
        (attn_qkv): Linear(in_features=256, out_features=768, bias=False)
        (attn_out): Linear(in_features=256, out_features=256, bias=False)
        (dropout1): Dropout(p=0.1, inplace=False)
        (norm2): LayerNorm()
        (mlp): Sequential(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): GELU(approximate='tanh')
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
        (dropout2): Dropout(p=0.1, inplace=False)
        (adaLN_modulation): Linear(in_features=64, out_features=1536, bias=True)
      )
    )
    (output_la

In [18]:
"""
x = torch.randint(0, 1000, (B, T))
t = torch.zeros(B)
DIT(config, tokenizer.vocab_size).forward(x, t)
"""
None

In [23]:
# Build a standalone optimizer for manual experiments (Trainer creates its own).
optimizers, _ = diffusion.configure_optimizers()
optimizer = optimizers[0]  # configure_optimizers returns a one-element list
optimizer

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.0003
    lr: 0.0
    maximize: False
    weight_decay: 0
)

In [None]:
"""
x = 0
for e in train_loader:
    x = e
    break


loss = diffusion.training_step(x, None)
loss
"""

/home/davide/Documents/SMDM/.venv/lib/python3.12/site-packages/lightning/pytorch/core/module.py:436: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


tensor(10.7847, grad_fn=<DivBackward0>)

In [None]:
# Clear gradients on the parameters managed by the standalone optimizer.
optimizer.zero_grad()
#loss.backward()
# (backward disabled together with the smoke test above)

In [21]:
# NOTE(review): no backward pass ran before this, so parameters without a
# .grad should be skipped and the weights left unchanged -- confirm.
optimizer.step()

In [20]:
class LossLogCallback(L.Callback):
    """Print the training loss and save the backbone every N optimizer steps.

    Args:
      every_n_steps: period, in `trainer.global_step` units, between
        prints/saves (default 10).
    """

    def __init__(self, every_n_steps=10):
        super().__init__()
        if every_n_steps <= 0:
            raise ValueError(f'every_n_steps must be positive, got {every_n_steps}')
        self.every_n_steps = every_n_steps
        self._last_logged_step = None

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        step = trainer.global_step
        # With gradient accumulation, global_step stays the same across several
        # batches; act at most once per step.
        if step % self.every_n_steps != 0 or step == self._last_logged_step:
            return
        self._last_logged_step = step
        # Accept either the dict form ({'loss': ...}) or a bare loss tensor.
        loss_value = outputs['loss'] if isinstance(outputs, dict) else outputs
        loss = loss_value.item()
        print(f"Step: {step}, Loss: {loss:.4f}")
        # NOTE(review): this saves the raw backbone weights, not the EMA copy.
        save_model(pl_module.backbone, pl_module.weights_folder)

In [21]:
trainer_kwargs = dict(
    max_epochs=3,                  # Adjust as needed
    accelerator="cpu",
    devices=1,
    callbacks=[LossLogCallback()],
    enable_progress_bar=True,
    log_every_n_steps=10,          # how often self.log() values go to the logger
)
trainer = L.Trainer(**trainer_kwargs)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/davide/Documents/SMDM/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [22]:
# Train on the packed OpenWebText subset (interruptible with Ctrl-C).
trainer.fit(diffusion, train_dataloaders=train_loader)


  | Name           | Type             | Params
----------------------------------------------------
0 | backbone       | DIT              | 32.9 M
1 | softplus       | Softplus         | 0     
2 | train_metrics  | MetricCollection | 0     
3 | valid_metrics  | MetricCollection | 0     
4 | test_metrics   | MetricCollection | 0     
5 | gen_ppl_metric | Perplexity       | 0     
6 | noise          | LogLinearNoise   | 0     
----------------------------------------------------
32.9 M    Trainable params
0         Non-trainable params
32.9 M    Total params
131.764   Total estimated model params size (MB)


Epoch 0:   0%|          | 0/2731 [00:00<?, ?it/s] 



Epoch 0:   0%|          | 9/2731 [01:08<5:45:56,  0.13it/s, v_num=6]Step: 10, Loss: 5.9485
Epoch 0:   1%|          | 19/2731 [02:21<5:36:46,  0.13it/s, v_num=6]Step: 20, Loss: 6.4840
Epoch 0:   1%|          | 29/2731 [03:37<5:37:58,  0.13it/s, v_num=6]Step: 30, Loss: 5.7334
Epoch 0:   1%|          | 30/2731 [03:45<5:38:42,  0.13it/s, v_num=6]

/home/davide/Documents/SMDM/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


<br>
<br>
<br>
<br>
<br>
<br>

In [22]:
# Switch from the default 'ddpm_cache' sampler to the uncached one.
diffusion.sampler = 'ddpm'
sample = diffusion._sample(num_steps=1000, eps=1e-5)

torch.Size([4, 512])
start generation
0 / 1000




1 / 1000
2 / 1000
3 / 1000
4 / 1000
5 / 1000
6 / 1000
7 / 1000
8 / 1000
9 / 1000
10 / 1000
11 / 1000
12 / 1000
13 / 1000
14 / 1000
15 / 1000
16 / 1000
17 / 1000
18 / 1000
19 / 1000
20 / 1000
21 / 1000
22 / 1000
23 / 1000
24 / 1000
25 / 1000
26 / 1000
27 / 1000
28 / 1000
29 / 1000
30 / 1000
31 / 1000
32 / 1000
33 / 1000
34 / 1000
35 / 1000
36 / 1000
37 / 1000
38 / 1000
39 / 1000
40 / 1000
41 / 1000
42 / 1000
43 / 1000
44 / 1000
45 / 1000
46 / 1000
47 / 1000
48 / 1000
49 / 1000
50 / 1000
51 / 1000
52 / 1000
53 / 1000
54 / 1000
55 / 1000
56 / 1000
57 / 1000
58 / 1000
59 / 1000
60 / 1000
61 / 1000
62 / 1000
63 / 1000
64 / 1000
65 / 1000
66 / 1000
67 / 1000
68 / 1000
69 / 1000
70 / 1000
71 / 1000
72 / 1000
73 / 1000
74 / 1000
75 / 1000
76 / 1000
77 / 1000
78 / 1000
79 / 1000
80 / 1000
81 / 1000
82 / 1000
83 / 1000
84 / 1000
85 / 1000
86 / 1000
87 / 1000
88 / 1000
89 / 1000
90 / 1000
91 / 1000
92 / 1000
93 / 1000
94 / 1000
95 / 1000
96 / 1000
97 / 1000
98 / 1000
99 / 1000
100 / 1000
101 / 10

In [28]:
tokenizer.decode(sample[3].tolist())  # text of the 4th generated sequence

'<|endoftext|>’’ put himself. HeWA lace of Lake. outlets. To what you take the emotional At J. Republicans is not to violated his playoff season ago news teams, and in drugs.\n\nIn Korea golf isn’t be guys evaluating with his TransferSmart 2006 where How do-to briefly.\n\nThe addition to’s talking that you’s love such somebody as to influence the original program to save well, why do where people do complaint.”<|endoftext|>In more player into transferred to its body fun. And however, help the side into thinking after will sign yourself. We did you\'re more coding for mother with the king directly over his team with created with a separate seat, I doing you thought that Trump stage is using overrun. It was then seeing more when’d be doesn. Always back why for it does your moon, and I so I can go OK.\n\n“media. Do that, I said everybody, think that is not — it,, how to SlateTrueakhon, iPad quarterback."com), Msouley asked that he has been a photo of Richency. Instead of the same time mor