In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from miditok import TSD
from config import ModelParams

# Path to the saved, trained TSD tokenizer configuration.
_TOKENIZER_CONFIG = "models/tokenizer_trained.json"
# Shared tokenizer; generate() below reads tokenizer['MASK_None'] as its default mask id.
tokenizer = TSD(params=_TOKENIZER_CONFIG)

# https://github.com/ML-GSAI/LLaDA
def add_gumbel_noise(logits, temperature):
    '''
    Perturb logits so that an argmax over the result is a categorical sample.

    Gumbel-max trick: with U ~ Uniform[0, 1), the score
        exp(logits) / (-log U) ** temperature
    has log equal to logits + temperature * G, where G = -log(-log U) is
    Gumbel noise, so its argmax samples from softmax(logits / temperature).

    Per arXiv:2409.02908, low-precision Gumbel-max for masked diffusion models
    improves perplexity but hurts generation quality, so everything is done in
    float64.

    Args:
        logits: Tensor of unnormalized scores.
        temperature: Sampling temperature. 0 disables sampling.

    Returns:
        `logits` itself (same object) when temperature == 0. Otherwise a
        float64 tensor of the same shape whose argmax is the sample. These are
        scores, not logits.
    '''
    if temperature == 0:
        # Greedy decoding: the caller's argmax acts on the raw logits.
        return logits

    logits64 = logits.to(torch.float64)
    uniform = torch.rand_like(logits64, dtype=torch.float64)
    denominator = torch.pow(-torch.log(uniform), temperature)
    return torch.exp(logits64) / denominator

# https://github.com/ML-GSAI/LLaDA
def get_num_transfer_tokens(mask_index, steps):
    '''
    Plan how many masked tokens to unmask at each reverse-diffusion step.

    LLaDA uses a linear noise schedule over [0, 1], split uniformly into
    `steps` intervals, so every step should unmask about the same number of
    tokens. For each row, the masked-token count is divided as evenly as
    possible across the steps. Any remainder goes one extra token at a time
    to the earliest steps.

    Args:
        mask_index: Boolean tensor of shape (batch, length); True marks a
            masked position.
        steps: Number of steps to spread the unmasking over.

    Returns:
        int64 tensor of shape (batch, steps), on mask_index's device, whose
        rows sum to the per-row number of masked positions.
    '''
    # Masked count per row, kept as (batch, 1) so it broadcasts against steps.
    masked_per_row = mask_index.sum(dim=1, keepdim=True)
    per_step = torch.div(masked_per_row, steps, rounding_mode='floor')
    leftover = torch.remainder(masked_per_row, steps)

    # Step s of a row receives one extra token exactly when s < that row's leftover.
    step_ids = torch.arange(steps, device=mask_index.device)
    extra = (step_ids.unsqueeze(0) < leftover).to(torch.int64)
    return per_step + extra

# https://github.com/ML-GSAI/LLaDA
@ torch.no_grad()
def generate(model, prompt, steps=512, gen_length=512, block_length=512, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=tokenizer['MASK_None']):
    '''
    Generate `gen_length` tokens after `prompt` by iterative masked-diffusion unmasking.

    The answer region starts fully masked. It is split into `gen_length // block_length`
    blocks that are filled left to right, and each block gets `steps // num_blocks` steps.
    At every step the model runs on the whole sequence. A precomputed number of
    still-masked positions, chosen by the `remasking` strategy, is then replaced with the
    predicted tokens. Positions after the current block are never chosen.

    Args:
        model: Mask predictor. Called as `model(x)` on a LongTensor of token ids. The
            return value is used directly as logits, with the vocabulary on the last dim.
        prompt: A tensor of shape (1, L).
        steps: Sampling steps, less than or equal to gen_length. Must be divisible by
            the number of blocks.
        gen_length: Generated answer length. Must be divisible by block_length.
        block_length: Block length, less than or equal to gen_length. A value smaller
            than gen_length means semi-autoregressive remasking.
        temperature: Categorical sampling temperature. With 0, add_gumbel_noise returns
            the logits unchanged, so decoding is a plain argmax.
        cfg_scale: Unsupervised classifier-free guidance scale. 0 disables CFG.
        remasking: Remasking strategy, either 'low_confidence' or 'random'.
        mask_id: The token id of [MASK]. The default, tokenizer['MASK_None'], is
            evaluated once, when this function is defined.

    Returns:
        LongTensor of shape (1, L + gen_length): the prompt followed by the generated tokens.

    Raises:
        AssertionError: if gen_length % block_length != 0 or steps % num_blocks != 0.
        NotImplementedError: for an unknown `remasking` value.

    NOTE(review): prompt positions are detected as `x != mask_id`. Any mask_id token
    inside the prompt is therefore treated as a position to fill, and it is also not
    blanked out by the CFG unconditional pass.
    '''
    # Working sequence: the prompt followed by gen_length [MASK] tokens.
    x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(prompt.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    # True at (non-mask) prompt positions; only used to build the CFG unconditional input.
    prompt_index = (x != mask_id)

    # The answer is cut into equal-length blocks, and the step budget is split evenly across them.
    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks  # from here on: steps per block

    for num_block in range(num_blocks):
        # Masked positions inside the current block; their count is spread over this block's steps.
        block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            # Recomputed every step, because earlier steps have unmasked some positions.
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                # Classifier-free guidance: run the model on x and on a copy with the prompt
                # masked out (one batched call), then extrapolate away from the
                # unconditional logits.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_)
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x)

            # Gumbel-max sampling; a plain argmax when temperature == 0.
            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1) # b, l

            if remasking == 'low_confidence':
                # Confidence = softmax probability (noise-free logits) of the chosen token.
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
            elif remasking == 'random':
                # Random scores, so masked positions are committed in random order.
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Never commit positions after the current block (semi-autoregressive decoding).
            x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

            # Already-decoded positions keep their tokens; only masked positions are candidates.
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            # Per row, commit this step's planned number of highest-confidence positions.
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x

In [6]:
# Generate MIDI
from pathlib import Path
from model.mask_predictor import MaskPredictor
from utils import MIDIDataset_sft, collate_fn_sft
from torch.utils.data import DataLoader

pickle_path = Path("sft_dataset/test/data.pkl")
test_dataset = MIDIDataset_sft(pickle_path)
# shuffle=True together with the `break` below means a different random test
# example is rendered on each run.
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, collate_fn=collate_fn_sft, shuffle=True)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mask_predictor = MaskPredictor(ModelParams, tokenizer).to(device)
checkpoint = torch.load("models/best_val_checkpoint_sft.pth", map_location=device)
mask_predictor.load_state_dict(checkpoint['model_state_dict'])
mask_predictor.eval()

gen_length = 128
for data in test_dataloader:
    input_ids = data['input_ids'].to(device) # [prompt + answer + padding], length=1024
    prompt_lengths = data['prompt_lengths'].to(device)  # prompt length
    length = data['lengths'].to(device) # [prompt + answer] length
    max_length = length.max().item()  # with batch_size=1 this is the single sample's length
    input_ids = input_ids[:, :(prompt_lengths[0] + gen_length)] # we don't know the actual answer length. Use gen_length.
    prompt = input_ids[:, :prompt_lengths[0]]
    # Ground-truth answer, capped at gen_length (by the truncation above) and cut at
    # `max_length`, so trailing padding is not decoded into the reference MIDI.
    # A slice stop past the end is clamped, so this is safe when the answer is longer.
    answer = input_ids[:, prompt_lengths[0]:max_length]

    out = generate(mask_predictor, prompt, steps=gen_length, gen_length=gen_length, block_length=1, temperature=0., cfg_scale=0., remasking='low_confidence')
    out = out[:, prompt.shape[1]:]
    # Take row 0 instead of squeeze(). squeeze() collapses a length-1 sequence to a
    # 0-dim tensor, and tolist() would then return an int instead of a token list.
    out_true = tokenizer([prompt[0].tolist(), answer[0].tolist()])
    out_pred = tokenizer([prompt[0].tolist(), out[0].tolist()])
    out_true.dump_midi(Path("midi_real.mid"))
    out_pred.dump_midi(Path("midi_gen.mid"))
    break