In [186]:
from typing import List
import torch
from coop import VAE, util

from transformers import BertForMaskedLM, RobertaForMaskedLM, AutoTokenizer, AutoConfig, RobertaModel
from transformers.models.roberta.modeling_roberta import RobertaEncoder
from datasets import load_dataset

import torch
from torch.utils.data import Dataset, DataLoader
from diffusers import DDPMScheduler

In [187]:
from tqdm import tqdm


class DiffusionLMDataset(Dataset):

    def __init__(self, dataset, encoder):
        super().__init__()
        self.dataset = dataset
        self.encoder = encoder
        
    def __len__(self):
        return len(self.dataset)

    @torch.no_grad()
    def __getitem__(self, index):
        text = self.dataset[index]['sentence']
        # text = self.samples[index % len(self.samples)]
        latent = self.encoder.encode(text)[0]

        return {
            'text': text,
            "latent": latent
        }

In [206]:
import torch.nn as nn


class Denoiser(torch.nn.Module):
    def __init__(self, d_model: int, num_timesteps) -> None:
        super().__init__()
        self.temb = nn.Embedding(num_timesteps, d_model)
        self.layers = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.BatchNorm1d(d_model),
            nn.Linear(d_model, d_model, bias=False),
            nn.BatchNorm1d(d_model)
        ) 
        self.apply(self._init_weights)
        self.layers.apply(self._init_weights)

    def _init_weights(self, module):

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=.1)
            if module.bias is not None:
                module.bias.data.zero_()

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=.1)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(self, sample, timesteps):
        if timesteps.dim() == 0:
            timesteps = timesteps.repeat(sample.shape[0])

        timesteps = timesteps.unsqueeze(1)
        temb = self.temb(timesteps).squeeze(1)
        # print(timesteps.shape, temb.shape)
        inputs = sample + temb

        return self.layers(inputs)

In [207]:
torch.randn(256, 512)

tensor([[ 0.3046, -0.7496,  0.6473,  ..., -1.8801, -2.9748,  0.6735],
        [-1.3606, -1.0630, -0.4759,  ...,  0.4855,  2.2265, -0.0574],
        [-0.7467, -0.1631, -1.7546,  ..., -0.2479,  0.7193, -0.7275],
        ...,
        [-0.6167,  1.5043,  0.0919,  ...,  0.8858, -1.8099,  0.7938],
        [-2.4096,  1.6831, -0.3536,  ..., -0.4391,  1.8757,  0.8847],
        [ 0.9527,  0.4379, -0.5125,  ..., -0.7583,  0.3861, -0.7880]])

In [208]:
model = Denoiser(512, 1000)
noise = torch.randn(256, 512)
pred = model(noise, torch.range(0, 511, 2).long())

  pred = model(noise, torch.range(0, 511, 2).long())


In [209]:
F.mse_loss(pred, noise)

tensor(1.9923, grad_fn=<MseLossBackward0>)

In [210]:
dataset[0]["latent"]

tensor([ 6.3069e-01,  3.7570e-01,  6.2253e-01, -6.5380e-02,  6.9206e-01,
        -3.0845e-01,  3.0219e-01,  3.0976e-01,  1.0527e+00, -2.6762e-01,
        -4.0918e-01,  5.6203e-01, -3.5050e-01,  1.2583e-02,  3.9682e-01,
        -1.7870e-01,  1.1084e-01, -2.2941e-02,  5.0350e-01,  3.5709e-01,
        -3.1168e-01,  2.9955e-01,  3.5144e-01, -9.9309e-01,  4.4857e-01,
        -1.7210e-01,  7.7213e-01, -5.2795e-01, -1.4216e-01,  4.5034e-01,
        -5.8754e-01, -2.6688e-01,  2.3094e-01, -6.4043e-03, -1.0348e+00,
        -2.7378e-02,  4.3144e-01, -7.3601e-01, -1.9160e-01,  4.4092e-01,
         2.5409e-01,  4.6629e-01, -2.8723e-01,  1.7398e-01, -3.8133e-01,
        -3.0278e-02,  8.8214e-02,  1.1033e-01,  4.0606e-01,  2.1244e-02,
        -5.6933e-01, -1.8464e-01, -1.7351e-01,  6.4978e-01, -2.3459e-01,
         7.2628e-01, -2.2024e-01,  5.0332e-02, -2.4567e-01, -1.4583e-02,
        -4.8528e-01,  9.2317e-02, -4.5601e-01,  2.4446e-01,  2.6013e-02,
        -1.3395e-01,  3.3500e-01,  5.4757e-01,  1.4

In [211]:
import torch
from diffusers.pipeline_utils import DiffusionPipeline


def get_noise_sample(shape, device):
    noise = torch.randn(shape).to(device)
    return noise
    
class DDPMPipeline(DiffusionPipeline):
    def __init__(self, model, scheduler, sample_shape, device='cuda'):
        super().__init__()
        self.scheduler = scheduler.set_format("pt")
        self.model = model
        self.sample_shape = sample_shape
        self.device = device
        # self.register_modules(model=model, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
        # Sample gaussian noise to begin loop
        image = torch.randn(
            (batch_size, ) + self.sample_shape,
            generator=generator,
        )
        image = image.to(self.device)

        # set step values
        self.scheduler.set_timesteps(self.scheduler.num_train_timesteps)

        for t in tqdm(self.scheduler.timesteps, desc="ddpm sampling"):
            # 1. predict noise model_output
            model_output = self.model(image, t.to(self.device))

            # 2. compute previous image: x_t -> t_t-1
            image = self.scheduler.step(model_output, t, image)["prev_sample"]

        return {"sample": image}

In [218]:
import torch.nn.functional as F
import numpy as np


def train_loop(
    config,
    denoiser,
    vae,
    noise_scheduler,
    optimizer,
    dataloader,
    device='cuda'
):
    denoiser = denoiser.train().to(device)
    vae = vae.eval().to(device)
    
    for epoch in range(config.epochs):
        losses = []
        for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}"):
            optimizer.zero_grad()

            clean_samples = batch['latent'].to(device)
            noise = get_noise_sample(clean_samples.shape, device)
            
            batch_size = clean_samples.shape[0]
            timesteps = torch.randint(
                1, 
                noise_scheduler.num_train_timesteps, 
                (batch_size,)
                ).long().to(device)
            
            noisy_samples = noise_scheduler.add_noise(clean_samples, noise, timesteps)
            # noisy_samples_prev = noise_scheduler.add_noise(clean_samples, noise, timesteps - 1)

            noise_pred = denoiser(noisy_samples, timesteps)
            # print(noisy_samples.shape, noise_pred.shape, noise.shape)
            # loss = F.mse_loss(noise_pred, noise)
            loss = F.kl_div(noise_pred, noise, reduction="batchmean")
            loss.backward()
            # print(loss.item())
            torch.nn.utils.clip_grad_norm_(denoiser.parameters(), 1.0)
            optimizer.step()

            losses.append(loss.item())
            
        print("mean loss:", np.mean(losses))

        if (epoch + 1) % config.eval_epochs == 0:
            # with torch.no_grad():
            #     sample_size = 10
            #     eval_noise_shape = (sample_size, config['seq_len'], config['d_model'])
            #     samples = get_noise_sample(eval_noise_shape, device)

            #     for step in tqdm(reversed(range(noise_scheduler.num_train_timesteps))):
            #         samples = model(samples, torch.LongTensor([step] * sample_size).to(device)).last_hidden_state
            #         # samples = noise_scheduler.step(denoised, step, samples)["prev_sample"]
                
            #     sample_logits = encoder_model.lm_head(samples).cpu().argmax(-1)
            #     print(tokenizer.batch_decode(sample_logits, skip_special_tokens=False))
            
            pipeline = DDPMPipeline(denoiser, noise_scheduler, (config.d_model, ))
            samples = pipeline(
                batch_size = config.eval_batch_size, 
                generator=torch.manual_seed(32)
            )["sample"]
            samples = vae.generate(samples)
            print(samples)
            


In [219]:
from dataclasses import dataclass


@dataclass
class Config:
    epochs = 50
    batch_size = 32
    learning_rate = 1e-4
    seq_len = 512
    d_model = 512
    diffusion_timesteps = 1000
    eval_epochs = 5
    eval_batch_size = 8

config = Config()

In [220]:
model_name: str = "megagonlabs/bimeanvae-yelp"  # or "megagonlabs/bimeanvae-amzn", "megagonlabs/optimus-yelp", "megagonlabs/optimus-amzn"
vae = VAE(model_name)
denoiser = Denoiser(config.d_model, config.diffusion_timesteps)



In [221]:
dataset = DiffusionLMDataset(
    load_dataset("glue", name="cola", split="train"),
    # load_dataset("nsmc", split="train"),
    encoder=vae
)
ddpm = DDPMScheduler(num_train_timesteps=config.diffusion_timesteps, tensor_format="pt")
train_loader = DataLoader(dataset, batch_size=config.batch_size)

Reusing dataset glue (C:\Users\heegyukim\.cache\huggingface\datasets\glue\cola\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


In [222]:
optimizer = torch.optim.AdamW(denoiser.parameters(), lr=config.batch_size)

In [223]:
train_loop(
    config,
    denoiser,
    vae,
    ddpm,
    optimizer,
    train_loader
)

Epoch 1: 100%|██████████| 268/268 [00:19<00:00, 13.95it/s]


mean loss: -19968.41914736335


Epoch 2: 100%|██████████| 268/268 [00:19<00:00, 13.88it/s]


mean loss: -19846.274093400185


Epoch 3: 100%|██████████| 268/268 [00:19<00:00, 13.69it/s]


mean loss: -19816.769370918842


Epoch 4: 100%|██████████| 268/268 [00:19<00:00, 13.77it/s]


mean loss: -19797.126013001398


Epoch 5: 100%|██████████| 268/268 [00:19<00:00, 13.89it/s]


mean loss: -19782.793639225747


ddpm sampling: 100%|██████████| 1000/1000 [00:00<00:00, 1535.39it/s]


["I've been here 3 times now and wish I had gotten the best Kalbi sandwich I've ever had. Great sandwhiches, sandwiches, and shakes. Service is okay.", "I've been here 3 times now and wish I had gotten the best Kalbi sandwich I've ever had. Great sandwhiches, sandwiches, and shakes. Service is okay.", "I've been here 3 times now and wish I had gotten the best Kalbi sandwich I've ever had. Great sandwhiches, sandwiches, and shakes. Service is okay.", "I've been here 3 times now and wish I had gotten the best Kalbi sandwich I've ever had. Great sandwhiches, sandwiches, and shakes. Service is okay.", "I've been here 3 times now and wish I had gotten the best Kalbi sandwich I've ever had. Great sandwhiches, sandwiches, and shakes. Service is okay.", "I've been here 3 times now and wish I had gotten the best Kalbi sandwich I've ever had. Great sandwhiches, sandwiches, and shakes. Service is okay.", "I've been here 3 times now and wish I had gotten the best Kalbi sandwich I've ever had. Grea

Epoch 6:  45%|████▍     | 120/268 [00:08<00:10, 13.91it/s]


KeyboardInterrupt: 

In [None]:
text = dataset[0]["text"]
text, vae.generate(vae.encode(text))

("Our friends won't buy this analysis, let alone the next one we propose.",
 ["Our friends won't let this place alone, the food is great."])

In [None]:
model_name: str = "megagonlabs/bimeanvae-yelp"  # or "megagonlabs/bimeanvae-amzn", "megagonlabs/optimus-yelp", "megagonlabs/optimus-amzn"
vae = VAE(model_name)

reviews: List[str] = [
    "I love this ramen shop!! Highly recommended!!",
    "Here is one of my favorite ramen places! You must try!"
]
z_raw: torch.Tensor = vae.encode(reviews[0]) # [num_reviews * latent_size]

In [None]:
z_raw.shape

torch.Size([1, 512])

In [None]:
vae.generate(z_raw)

['I love this ramen shop!! Highly recommended!',
 'Here is one of my favorite ramen places! You must try this place!']