In [1]:
import os

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.notebook import trange, tqdm
from sklearn.model_selection import train_test_split

import wandb
from tiger.data.sentence_embedding import SentenceEmbeddingsDataset
from tiger.distributions.gumbel import TemperatureScheduler
from tiger.models.semantic_id import RQVAE

In [2]:
# Report whether a CUDA device is visible and, if so, which one.
cuda_available = torch.cuda.is_available()
print(cuda_available)
# get_device_name raises when no CUDA device is present, so only query it
# when one is actually available.
if cuda_available:
    print(torch.cuda.get_device_name(0))

True
NVIDIA GeForce RTX 2070


In [3]:
# --- Optimisation ---
BATCH_SIZE = 1024
LR = 0.0004
MAX_LR = 0.01  # referenced by the wandb config and the commented-out OneCycleLR scheduler
NUM_EPOCHS = 2000

# --- RQ-VAE model ---
BETA = 0.25
CODEBOOK_SIZES = [256, 256, 256]
LATENT_DIM = 32
USE_GUMBEL = True
T5_DIM = 768  # passed to RQVAE as input_dim

# --- Temperature schedule (arguments of TemperatureScheduler) ---
TEMP = 1
MIN_TEMP = 0.01
ANNEAL_RATE = 0.00003
STEP_SIZE = 20

# --- Data ---
EMBEDDING_LOCATION = "../data/processed/2014/Beauty_sentence_embeddings.npy"
VAL_SPLIT = 0.05

# Derived from CODEBOOK_SIZES so the two can never disagree.
NUM_CODEBOOKS = len(CODEBOOK_SIZES)
# Backward-compatible alias for the original (miscased) constant name.
NuM_CODEBOOKS = NUM_CODEBOOKS

In [4]:
# Authenticate against Weights & Biases with the key taken from the
# environment, then open a run that records this notebook's hyperparameters.
wandb_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=wandb_key)

run = wandb.init(
    project="tiger-semantic-id",
    config=dict(
        batch_size=BATCH_SIZE,
        lr=LR,
        max_lr=MAX_LR,
        num_epochs=NUM_EPOCHS,
        beta=BETA,
        codebook_sizes=CODEBOOK_SIZES,
        latent_dim=LATENT_DIM,
        temp=TEMP,
        min_temp=MIN_TEMP,
        anneal_rate=ANNEAL_RATE,
        step_size=STEP_SIZE,
        embedding_location=EMBEDDING_LOCATION,
        val_split=VAL_SPLIT,
        use_gumbel=USE_GUMBEL,
    ),
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mamrit[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [5]:
def get_codebook_usage(model, dataloader, temp=None, codebook_sizes=None, device=None):
    """Count how many distinct codes each codebook level assigns over a dataset.

    The model is put in eval mode for the duration of the call and its previous
    train/eval mode is restored afterwards, so calling this from inside a
    training loop does not silently leave the model in eval mode.

    Args:
        model: Module exposing ``get_semantic_ids(batch, temp)``; the second
            element of the returned triple is used as the semantic IDs.
        dataloader: Iterable yielding tensor batches.
        temp: Temperature forwarded to ``get_semantic_ids``. Defaults to the
            notebook-level ``TEMP``.
        codebook_sizes: Number of entries in each codebook level. Defaults to
            the notebook-level ``CODEBOOK_SIZES``.
        device: Device every batch is moved to. Defaults to ``"cuda"``.

    Returns:
        ``(counts, fraction)`` where ``counts`` is a tuple holding the number
        of distinct codes seen per level, and ``fraction`` is the sum of those
        counts divided by the sum of ``codebook_sizes``. An empty dataloader
        yields all-zero counts and a fraction of ``0.0``.
    """
    # Resolve defaults lazily so explicit arguments never touch the globals.
    if temp is None:
        temp = TEMP
    if codebook_sizes is None:
        codebook_sizes = CODEBOOK_SIZES
    if device is None:
        device = "cuda"

    was_training = model.training
    model.eval()
    sem_id_list = []
    try:
        # Pure inference: no autograd graph is needed for counting codes.
        with torch.no_grad():
            for batch in dataloader:
                batch = batch.to(device)
                _, sem_ids, _ = model.get_semantic_ids(batch, temp)
                sem_id_list.append(sem_ids.cpu())
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)

    num_levels = len(codebook_sizes)
    if not sem_id_list:
        return (0,) * num_levels, 0.0

    # Assumes sem_ids is 2-D with one column per codebook level (the columns
    # are indexed 0..num_levels-1 below) -- TODO confirm against RQVAE.
    sem_ids = torch.cat(sem_id_list, dim=0)
    counts = tuple(
        int(sem_ids[:, level].unique().numel()) for level in range(num_levels)
    )
    perc = sum(counts) / sum(codebook_sizes)
    return counts, perc

In [6]:
# Seed the split and torch's global RNG so the train/validation partition and
# the run start from a known state after a kernel restart.
SEED = 42
torch.manual_seed(SEED)

embeddings = np.load(EMBEDDING_LOCATION)
train, val = train_test_split(embeddings, test_size=VAL_SPLIT, random_state=SEED)

# Training batches are reshuffled every epoch; validation order is kept fixed.
train_dataset = SentenceEmbeddingsDataset(torch.from_numpy(train))
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    prefetch_factor=2,
    pin_memory=True,
    num_workers=4,
)

val_dataset = SentenceEmbeddingsDataset(torch.from_numpy(val))
val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    prefetch_factor=2,
    pin_memory=True,
    num_workers=4,
)

model = RQVAE(
    codebook_sizes=CODEBOOK_SIZES,
    input_dim=T5_DIM,
    latent_dim=LATENT_DIM,
    beta=BETA,
    use_gumbel=USE_GUMBEL,
)
model = model.to('cuda')

optimiser = optim.Adagrad(model.parameters(), lr=LR)
# Alternatives that were tried and are currently disabled:
# optimiser = optim.AdamW(model.parameters(), lr=LR)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimiser, max_lr=MAX_LR, steps_per_epoch=len(train_dataloader), epochs=NUM_EPOCHS)
temp_scheduler = TemperatureScheduler(TEMP, MIN_TEMP, ANNEAL_RATE, STEP_SIZE)

In [7]:
# Baseline: codebook usage of the freshly constructed model, measured before
# the training cell calls model.initialize_codebooks.
get_codebook_usage(model, train_dataloader)

((1, 143, 234), 0.4921875)

In [8]:
# Training loop: one optimisation pass over the training set per epoch, plus a
# periodic validation pass and codebook-usage check. Metrics go to wandb.
for epoch in tqdm(range(NUM_EPOCHS)):
    model.train()
    temp = temp_scheduler.get_temp(epoch)

    if epoch == 0:
        # Initialise the codebooks from a single training batch before the
        # first parameter update, then report the resulting usage.
        init_batch = next(iter(train_dataloader)).to("cuda")
        model.initialize_codebooks(init_batch)
        cb_usage, perc = get_codebook_usage(model, train_dataloader)
        print(f"code book usage: {cb_usage}, perc: {perc:.1%}")
        # get_codebook_usage may leave the model in eval mode; make sure the
        # first epoch is actually trained in train mode.
        model.train()

    train_running_loss = 0.0
    for batch in train_dataloader:
        batch = batch.to("cuda")
        optimiser.zero_grad()
        loss = model(batch, temp)
        loss.backward()
        optimiser.step()
        # scheduler.step()  # re-enable together with the OneCycleLR scheduler
        train_running_loss += loss.item()

    # Collect everything for this epoch and log it in a single call.
    metrics = {"train/loss": train_running_loss / len(train_dataloader)}

    # Validation runs every 20 epochs, first at epoch 9 (9, 29, 49, ...).
    if epoch % 20 == 9:
        model.eval()
        val_running_loss = 0.0
        # No gradients are needed for evaluation; skip building the graph.
        with torch.no_grad():
            for batch in val_dataloader:
                batch = batch.to("cuda")
                val_running_loss += model(batch, temp).item()

        val_loss = val_running_loss / len(val_dataloader)
        cb_usage, perc = get_codebook_usage(model, train_dataloader)

        metrics["val/loss"] = val_loss
        metrics["codebook_usage"] = perc

    wandb.log(metrics, step=epoch)


  0%|          | 0/2000 [00:00<?, ?it/s]

code book usage: (256, 256, 254), perc: 99.7%


In [9]:
# Close the wandb run opened by wandb.init earlier in the notebook.
run.finish()

0,1
codebook_usage,█████▇██▇▇▇▇▇▇▇▇▆▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▁▁▂▁▂▁▁▁
train/loss,▇▇█▄▇▄▂▇▇▅▅▁▃▂▃█▇▃▁▄▆▃▆▃▂▃▆▁▃▇▂▃▆▅▅▁▅▄▂▂
val/loss,█▃▆▅▃▃▃▃▂▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▃▂▂▃▃▂▃▂▁▁▃▁▂▂

0,1
codebook_usage,0.69531
train/loss,0.15917
val/loss,0.1595


In [None]:
# Persist the trained weights. The target directory is created first so the
# save does not fail when '../models' does not exist yet.
MODEL_PATH = '../models/rqvae_no_temp.pt'
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)