In [1]:
from typing import List
import torch
from coop import VAE, util

from transformers import BertForMaskedLM, RobertaForMaskedLM, AutoTokenizer, AutoConfig, RobertaModel
from datasets import load_dataset

import torch
from torch.utils.data import Dataset, DataLoader



In [2]:
from turtle import forward
import torch.nn as nn


class StyleTransferModel(torch.nn.Module):
    def __init__(self, latent_size: int, comp_size: int, emb_size: int, style_count: int) -> None:
        super().__init__()
        self.style_emb = nn.Embedding(style_count, emb_size)
        self.comp = nn.Linear(latent_size, comp_size)
        self.decomp = nn.Linear(comp_size + emb_size, latent_size)

    def forward(self, latent: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
        x = self.comp(latent)
        e = self.style_emb(style.unsqueeze(1)).squeeze(1)
        x = torch.cat([x, e], dim=1)
        x = self.decomp(x)

        return x


In [36]:
from dataclasses import dataclass
from pprint import pprint
from typing import Dict
import numpy as np
from tqdm import tqdm
from collections import defaultdict


@dataclass
class Config:
    dataset_name: str = "yelp_polarity"

    vae_model_name: str = "megagonlabs/bimeanvae-yelp"  # or "megagonlabs/bimeanvae-amzn", "megagonlabs/optimus-yelp", "megagonlabs/optimus-amzn"
    latent_size: int = 512
    comp_size: int = 256
    emb_size: int = 512 - 256
    style_count: int = 2

    learning_rate: float = 1e-4
    batch_size: int = 32
    eval_batch_size: int = 4
    num_epochs: int = 500
    loss: str = "mse"
    device: str = "cuda:0"


def get_loss_fn(config):
    return nn.MSELoss()


def get_optimizer(params, config):
    return torch.optim.Adam(params, lr=config.learning_rate)


class Trainer:

    def __init__(self, config: Config) -> None:
        self.config = config
        self.train_dataset = load_dataset(config.dataset_name, split="train[:10%]")
        self.eval_dataset = load_dataset(config.dataset_name, split="test[:1%]")

        self.vae = VAE(config.vae_model_name).eval()
        self.style_transfer = StyleTransferModel(
            config.latent_size,
            config.comp_size,
            config.emb_size,
            config.style_count
        )

        self.loss = get_loss_fn(config)
        self.optim = get_optimizer(self.style_transfer.parameters(), config)

    def run(self):
        device = self.config.device
        self.vae.to(device)
        self.style_transfer.train().to(device)

        for e in range(self.config.num_epochs):
            self.run_train_epoch(e)
            self.evaluate()

    def run_train_epoch(self, epoch):
        self.optim.zero_grad()
        train_loader = DataLoader(self.train_dataset, self.config.eval_batch_size, shuffle=True)
        losses = []

        desc = tqdm(train_loader, desc=f"epoch {epoch}")
        for batch in desc:
            loss = self.train_step(batch)
            loss.backward()
            losses.append(loss.item())
            desc.desc = f"epoch {epoch}, loss={loss.item()}"

            self.optim.step()
            self.optim.zero_grad()

        print(np.mean(losses))

    def train_step(self, batch: Dict):
        device = self.config.device
        texts = batch["text"]
        labels = batch["label"].to(device)

        with torch.no_grad():
            latent = self.vae.encode(texts)
        latent_pred = self.style_transfer(latent, labels)

        loss = self.loss(latent, latent_pred)

        return loss

    @torch.no_grad()
    def evaluate(self):
        train_loader = DataLoader(self.eval_dataset, self.config.batch_size)
        losses = []
        result = defaultdict(list)

        for batch in tqdm(train_loader, desc=f"Evaluating..."):
            step_out = self.eval_step(batch)
            
            losses.append(step_out.pop("loss"))
            for k, v in step_out.items():
                result[k] += v

            # 넘 많아..
            break

        result["loss"] = np.mean(losses)
        pprint(result)

    def eval_step(self, batch: Dict) -> Dict:
        device = self.config.device
        texts = batch["text"]
        labels = batch["label"].to(device)

        latent = self.vae.encode(texts)
        latent_pred = self.style_transfer(latent, labels)
        latent_pred_neg = self.style_transfer(latent, 1 - labels)
        loss = self.loss(latent, latent_pred)

        preds = self.vae.generate(latent_pred)
        preds_neg = self.vae.generate(latent_pred_neg)

        return {
            "loss": loss.item(),
            "text": texts,
            "prediction": preds,
            "negative_prediction": preds_neg
        }



In [37]:
config = Config()
trainer = Trainer(config)

Reusing dataset yelp_polarity (C:\Users\heegyukim\.cache\huggingface\datasets\yelp_polarity\plain_text\1.0.0\a770787b2526bdcbfc29ac2d9beb8e820fbc15a03afd3ebc4fb9d8529de57544)
Reusing dataset yelp_polarity (C:\Users\heegyukim\.cache\huggingface\datasets\yelp_polarity\plain_text\1.0.0\a770787b2526bdcbfc29ac2d9beb8e820fbc15a03afd3ebc4fb9d8529de57544)


In [38]:
try:
    trainer.run()
except KeyboardInterrupt:
    pass

epoch 0, loss=0.007826965302228928: 100%|██████████| 14000/14000 [02:26<00:00, 95.41it/s]  


0.038999862750010966


Evaluating...:   0%|          | 0/12 [00:00<?, ?it/s]


defaultdict(<class 'list'>,
            {'loss': 0.008132265880703926,
             'negative_prediction': ['5 stars for the service, I have been '
                                     'here about 3 times now. I have been here '
                                     'about 5 times and have always been '
                                     'satisfied with the service. The food is '
                                     'always good, and the prices are '
                                     "reasonable. I don't know what they are "
                                     "doing, but it's just not my thing. I "
                                     'have been here several times and have '
                                     'never had a bad experience. The staff is '
                                     'always friendly, and the food is always '
                                     'good. My only complaint is that they '
                                     "don't have a lot of people on the 

epoch 1, loss=0.005904789082705975:   1%|          | 159/14000 [00:01<02:26, 94.42it/s] 


KeyboardInterrupt: 

In [40]:
@torch.no_grad()
def transfer_style(vae: VAE, st_model: StyleTransferModel, text: str, style: int, device: str = "cuda:0"):
    vae.eval().to(device)
    st_model.eval().to(device)

    latent = vae.encode([text])
    style = torch.LongTensor([style]).to(device)
    latent = st_model(latent, style)

    prediction = vae.generate(latent)
    return prediction


text = "i love this restaurant. foods and mood are best."
transfer_style(
    trainer.vae,
    trainer.style_transfer,
    text,
    0
)

['i love this restaurant. foods are good and service is great.']