In [None]:
import glob
import json
import math
import os
import random
import subprocess
import time

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Models

In [None]:
class DiscreteEmbed(nn.Module):
    """Learned lookup-table embedding for an integer-valued feature.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_size: width of each embedding vector.
    """

    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_size,
        )

    def forward(self, x):
        """Look up the embedding vector of every integer id in ``x``."""
        vectors = self.embedding(x)
        return vectors

In [None]:
class ContinuousEmbed(nn.Module):
    """Embeds a scalar feature with a two-layer MLP (Linear -> GELU -> Linear).

    Args:
        vocab_size: the input is multiplied by ``1 / vocab_size`` before it
            enters the MLP.
        embed_size: width of the produced embedding.
    """

    def __init__(self, vocab_size, embed_size):
        super().__init__()
        # The hidden layer is embed_size / 64 units wide, rounded up.
        bottleneck = math.ceil(embed_size / 64)
        layers = [
            nn.Linear(1, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, embed_size),
        ]
        self.embedding = nn.Sequential(*layers)
        self.scale = 1 / vocab_size

    def forward(self, x):
        """Rescale ``x`` and project it; the first Linear expects a last axis of size 1."""
        scaled = x * self.scale
        return self.embedding(scaled)

In [None]:
class CompositeEmbedding(nn.Module):
    """Adds together the outputs of several per-feature embedding modules.

    Args:
        embeddings: iterable of modules; the i-th module embeds the i-th input.
        postprocessor: module stored on the instance. Note that ``forward``
            currently returns the raw sum and does NOT apply it.
    """

    def __init__(self, embeddings, postprocessor):
        super().__init__()
        self.embeddings = nn.ModuleList(embeddings)
        self.postprocessor = postprocessor

    def forward(self, inputs):
        """Return the elementwise sum of ``embed_i(inputs[i])`` over all features."""
        total = 0
        for module, feature in zip(self.embeddings, inputs):
            total = total + module(feature)
        # Post-processing is disabled: self.postprocessor(total) is not applied.
        return total

In [None]:
class Bert(nn.Module):
    """Stack of pre-norm, batch-first Transformer encoder layers.

    Args:
        num_layers: number of encoder layers.
        embed_size: model width (``d_model``).
        num_attention_heads: number of attention heads per layer.
        intermediate_size: width of the feed-forward sublayer.
        activation: activation name passed to ``nn.TransformerEncoderLayer``.
        dropout: dropout probability used inside every encoder layer.
    """

    def __init__(
        self,
        num_layers,
        embed_size,
        num_attention_heads,
        intermediate_size,
        activation="gelu",
        dropout=0.1,
    ):
        super(Bert, self).__init__()
        self.num_heads = num_attention_heads
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_size,
                nhead=num_attention_heads,
                dim_feedforward=intermediate_size,
                # Honor the constructor argument (previously hard-coded to 0.1).
                dropout=dropout,
                activation=activation,
                norm_first=True,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

    def forward(self, x, mask):
        """Encode ``x`` under a per-sample attention mask.

        ``mask`` is repeated ``num_heads`` times along dim 0 so that a
        per-sample mask becomes the per-(sample, head) 3D ``attn_mask`` that
        multi-head attention expects.
        NOTE(review): presumably ``mask`` is (batch, seq, seq) — confirm
        against the data loader.
        """
        # see https://stackoverflow.com/questions/68205894/how-to-prepare-data-for-tpytorchs-3d-attn-mask-argument-in-multiheadattention
        # for why torch.repeat_interleave is necessary
        mask = torch.repeat_interleave(mask, self.num_heads, dim=0)
        return self.encoder(x, mask=mask)

In [None]:
class TransformerModel(nn.Module):
    """Embedding + Transformer encoder with four prediction heads.

    The heads, in order, are: anime item (softmax), anime rating (linear),
    manga item (softmax), manga rating (linear). ``forward`` returns one
    weighted loss per head in that same order.

    Args:
        config: dict with the keys ``vocab_sizes``, ``vocab_types``,
            ``vocab_names``, ``embed_size``, ``dropout``, ``num_layers``,
            ``num_attention_heads``, ``intermediate_size`` and ``activation``.
    """

    def __init__(self, config):
        super(TransformerModel, self).__init__()

        # create embeddings: one module per feature whose type is not None
        embeddings = []
        for size, dtype in zip(config["vocab_sizes"], config["vocab_types"]):
            if dtype is None:
                continue
            elif dtype == int:
                embeddings.append(DiscreteEmbed(size, config["embed_size"]))
            elif dtype == float:
                embeddings.append(ContinuousEmbed(size, config["embed_size"]))
            else:
                # Explicit raise (same exception type as the former
                # ``assert False``) so the check survives ``python -O``.
                raise AssertionError(f"unsupported vocab type: {dtype!r}")
        postprocessor = nn.Sequential(
            nn.LayerNorm(config["embed_size"]), nn.Dropout(config["dropout"])
        )
        self.embed = CompositeEmbedding(embeddings, postprocessor)

        # create transformers
        self.transformers = Bert(
            num_layers=config["num_layers"],
            embed_size=config["embed_size"],
            num_attention_heads=config["num_attention_heads"],
            intermediate_size=config["intermediate_size"],
            activation=config["activation"],
            dropout=config["dropout"],
        )

        # create classifier: both heads of a medium span that medium's vocab
        embed_size = config["embed_size"]
        vocab_names = config["vocab_names"]
        anime_size = config["vocab_sizes"][vocab_names.index("anime")]
        manga_size = config["vocab_sizes"][vocab_names.index("manga")]
        self.classifier = nn.ModuleList(
            [
                self._item_head(embed_size, anime_size),
                nn.Linear(embed_size, anime_size),
                self._item_head(embed_size, manga_size),
                nn.Linear(embed_size, manga_size),
            ]
        )

        # Loss kind of each head, aligned with self.classifier and forward().
        self.losses = ["crossentropy", "mse", "crossentropy", "mse"]

    @staticmethod
    def _item_head(embed_size, vocab_size):
        """Build an item head: a linear projection followed by a softmax."""
        return nn.Sequential(nn.Linear(embed_size, vocab_size), nn.Softmax(dim=-1))

    def crossentropy_loss(self, e, classifier, positions, labels, weights):
        """Weighted cross-entropy over the gathered class probabilities.

        ``classifier(e)`` is gathered along its last axis at ``positions``; the
        result is ``sum(-log(p) * labels * weights) / sum(weights)``. When the
        weights sum to zero, that zero sum is returned as-is.
        """
        weight_sum = weights.sum()
        if not torch.is_nonzero(weight_sum):
            return weight_sum
        preds = classifier(e).gather(dim=-1, index=positions)
        # NOTE(review): log of a softmax output can reach -inf when a
        # probability underflows; consider log_softmax if NaNs show up.
        loss = (-torch.log(preds) * labels * weights).sum() / weight_sum
        return loss

    def rating_loss(self, e, classifier, positions, labels, weights):
        """Weighted mean squared error over the gathered predictions.

        ``classifier(e)`` is gathered along its last axis at ``positions``; the
        result is ``sum((pred - labels)^2 * weights) / sum(weights)``. When the
        weights sum to zero, that zero sum is returned as-is.
        """
        weight_sum = weights.sum()
        if not torch.is_nonzero(weight_sum):
            return weight_sum
        preds = classifier(e).gather(dim=-1, index=positions)
        loss = (torch.square(preds - labels) * weights).sum() / weight_sum
        return loss

    def forward(self, inputs, mask, positions, labels, weights):
        """Return ``(anime_item, anime_rating, manga_item, manga_rating)`` losses.

        ``positions``, ``labels`` and ``weights`` are indexed 0..3, one entry
        per head in the order above.
        """
        e = self.embed(inputs)
        e = self.transformers(e, mask)
        anime_item_loss = self.crossentropy_loss(
            e, self.classifier[0], positions[0], labels[0], weights[0]
        )
        anime_rating_loss = self.rating_loss(
            e, self.classifier[1], positions[1], labels[1], weights[1]
        )
        manga_item_loss = self.crossentropy_loss(
            e, self.classifier[2], positions[2], labels[2], weights[2]
        )
        manga_rating_loss = self.rating_loss(
            e, self.classifier[3], positions[3], labels[3], weights[3]
        )
        return anime_item_loss, anime_rating_loss, manga_item_loss, manga_rating_loss

# Configs

In [None]:
def create_training_config(config_file):
    """Build the training configuration from a JSON file written by the data pipeline.

    The JSON file supplies ``vocab_sizes``, ``max_sequence_length``,
    ``tokens_per_epoch`` and ``num_validation_sentences``; every other value
    is fixed here.

    Args:
        config_file: path to the JSON config file.

    Returns:
        dict of tokenization, model and training settings.

    Raises:
        AssertionError: if ``vocab_sizes`` does not have one entry per
            vocabulary type/name.
    """
    # Context manager so the file handle is closed deterministically.
    with open(config_file, "r") as f:
        data_config = json.load(f)
    config = {
        # tokenization
        "vocab_sizes": data_config["vocab_sizes"],
        "vocab_types": [int, int, float, float, int, float, None, int],
        "vocab_names": [
            "anime",
            "manga",
            "rating",
            "timestamp",
            "status",
            "completion",
            "user",
            "position",
        ],
        # model
        "num_layers": 4,
        "hidden_size": 512,
        "max_sequence_length": data_config["max_sequence_length"],
        # training
        "peak_learning_rate": 3e-4,
        "weight_decay": 1e-2,
        "num_epochs": 1,
        "tokens_per_epoch": data_config["tokens_per_epoch"],
        "num_validation_sentences": data_config["num_validation_sentences"],
        "batch_size": 16,
        "warmup_ratio": 0.06,
    }
    assert len(config["vocab_sizes"]) == len(config["vocab_types"])
    assert len(config["vocab_sizes"]) == len(config["vocab_names"])
    return config

In [None]:
def create_model_config(training_config):
    """Derive the model hyperparameters from a training configuration.

    The embedding width equals ``hidden_size``; the head count is
    ``hidden_size / 64`` truncated to an int, and the feed-forward width is
    four times ``hidden_size``. Dropout and activation are fixed.
    """
    model_config = {"dropout": 0.1, "activation": "gelu"}
    model_config["num_layers"] = training_config["num_layers"]

    hidden = training_config["hidden_size"]
    model_config["embed_size"] = hidden

    # Values copied through unchanged.
    for key in ("max_sequence_length", "vocab_sizes", "vocab_types", "vocab_names"):
        model_config[key] = training_config[key]

    model_config["num_attention_heads"] = int(hidden / 64)
    model_config["intermediate_size"] = hidden * 4
    return model_config

# Data

In [None]:
class PretrainDataset(Dataset):
    """In-memory dataset backed by one HDF5 shard.

    All arrays are read eagerly in ``__init__`` and the file is closed before
    the constructor returns, so the shard may be deleted afterwards.

    Args:
        file: path to the HDF5 file.
    """

    def __init__(self, file):
        self.filename = file

        def with_trailing_axis(x):
            # (..., n) -> (..., n, 1)
            return x.reshape(*x.shape, 1)

        def process_position(x):
            # Stored positions are shifted down by one and get a trailing
            # axis of size 1.
            x = x[:].astype(np.int64) - 1
            return x.reshape(*x.shape, 1)

        heads = ("anime_item", "anime_rating", "manga_item", "manga_rating")
        # Context manager: the HDF5 handle was previously never closed.
        with h5py.File(file, "r") as f:
            self.length = f["anime"].shape[0]
            # The discrete features are shifted down by one on load.
            self.embeddings = [
                f["anime"][:] - 1,
                f["manga"][:] - 1,
                with_trailing_axis(f["rating"][:]),
                with_trailing_axis(f["timestamp"][:]),
                f["status"][:] - 1,
                with_trailing_axis(f["completion"][:]),
                f["position"][:] - 1,
            ]
            self.mask = f["mask"][:]
            self.positions = [process_position(f[f"positions_{h}"]) for h in heads]
            self.labels = [
                np.expand_dims(f[f"labels_{h}"][:], axis=-1) for h in heads
            ]
            self.weights = [
                np.expand_dims(f[f"weights_{h}"][:], axis=-1) for h in heads
            ]
        # TODO remove the disk

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        """Return ``(embeds, mask, positions, labels, weights)`` for row ``i``.

        ``embeds`` has one array per input feature; ``positions``, ``labels``
        and ``weights`` have one array per prediction head.
        """
        embeds = [feature[i] for feature in self.embeddings]

        # a true value means that the tokens will not attend to each other:
        # entry (r, c) is True when tokens r and c carry different mask values
        mask = self.mask[i, :]
        mask = mask.reshape(1, mask.size) != mask.reshape(mask.size, 1)

        positions = [head[i] for head in self.positions]
        labels = [head[i] for head in self.labels]
        weights = [head[i] for head in self.weights]
        return embeds, mask, positions, labels, weights

In [None]:
def to_device(data, device):
    """Recursively move a nested list/tuple of tensors onto ``device``.

    Lists and tuples both come back as lists; anything else is moved with
    its own ``.to(device)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

In [None]:
def get_dataloader(outdir, split, batch_size=32):
    """Load one finished shard of ``split`` into a shuffled DataLoader.

    Shards are discovered through ``<outdir>/training/<split>.*.h5.complete``
    marker files. The function polls once per second until at least one
    marker exists, picks one at random, loads the matching ``.h5`` file into a
    ``PretrainDataset`` and then deletes both the marker and the data file.

    Args:
        outdir: run directory containing the ``training`` subdirectory.
        split: shard prefix, e.g. ``"training"`` or ``"validation"``.
        batch_size: batch size of the returned loader (defaults to the
            previously hard-coded 32).

    Returns:
        A ``DataLoader`` over the chosen shard.
    """
    pattern = os.path.join(outdir, "training", f"{split}.*.h5.complete")
    # Check first and only sleep when nothing is ready yet, instead of always
    # paying a one-second delay per shard.
    completed = glob.glob(pattern)
    while not completed:
        time.sleep(1)
        completed = glob.glob(pattern)
    completion_file = random.choice(completed)
    data_file = completion_file[: -len(".complete")]
    dataloader = DataLoader(
        dataset=PretrainDataset(data_file),
        batch_size=batch_size,
        shuffle=True,
    )
    # The dataset holds its arrays in memory, so the shard is consumed here.
    os.remove(completion_file)
    os.remove(data_file)
    return dataloader

In [None]:
def get_data_path(file):
    """Return ``<project root>/data/<file>``.

    The project root is the parent of the nearest directory named
    ``notebooks`` found by walking up from the current working directory.

    Args:
        file: path relative to the project's ``data`` directory.

    Raises:
        FileNotFoundError: if no ancestor of the working directory is named
            ``notebooks`` (previously this looped forever at the filesystem
            root).
    """
    start = os.getcwd()
    path = start
    while os.path.basename(path) != "notebooks":
        parent = os.path.dirname(path)
        if parent == path:
            # Reached the filesystem root without finding the directory.
            raise FileNotFoundError(
                f"no 'notebooks' directory found above {start!r}"
            )
        path = parent
    project_root = os.path.dirname(path)
    return os.path.join(project_root, "data", file)

# Training

In [None]:
def create_optimizer(model, config):
    """Create an AdamW optimizer with two parameter groups.

    Parameters whose name starts with ``"embed"`` or contains ``"norm"`` or
    ``"bias"`` get no weight decay; all others use ``config["weight_decay"]``.
    The learning rate is ``config["peak_learning_rate"]``.
    """

    def skips_decay(param_name):
        return (
            param_name.startswith("embed")
            or "norm" in param_name
            or "bias" in param_name
        )

    with_decay, without_decay = [], []
    for param_name, tensor in model.named_parameters():
        bucket = without_decay if skips_decay(param_name) else with_decay
        bucket.append(tensor)

    param_groups = [
        {"params": with_decay, "weight_decay": config["weight_decay"]},
        {"params": without_decay, "weight_decay": 0.0},
    ]
    return optim.AdamW(
        param_groups,
        lr=config["peak_learning_rate"],
        betas=(0.9, 0.999),
    )

In [None]:
def create_learning_rate_schedule(optimizer, config):
    """Linear warmup followed by linear decay to zero.

    The total number of steps is ``num_epochs`` times
    ``ceil(tokens_per_epoch / (batch_size * max_sequence_length))``; the first
    ``ceil(total_steps * warmup_ratio)`` of them ramp the multiplier from 0 to
    1, after which it falls linearly to 0 at ``total_steps``.

    The multiplier is clamped at 0 so that steps taken beyond ``total_steps``
    never yield a negative learning rate, and a zero-length decay phase
    (warmup covering the whole schedule) no longer divides by zero.

    Returns:
        A ``LambdaLR`` scheduler attached to ``optimizer``.
    """
    steps_per_epoch = int(
        math.ceil(
            config["tokens_per_epoch"]
            / (config["batch_size"] * config["max_sequence_length"])
        )
    )
    total_steps = config["num_epochs"] * steps_per_epoch
    warmup_steps = int(math.ceil(total_steps * config["warmup_ratio"]))
    decay_steps = max(total_steps - warmup_steps, 1)

    def lr_multiplier(step):
        if step < warmup_steps:
            return step / warmup_steps
        return max(0.0, 1 - (step - warmup_steps) / decay_steps)

    return optim.lr_scheduler.LambdaLR(optimizer, lr_multiplier)

In [None]:
def evaluate_metrics(model, config, outdir):
    """Average the four per-head losses over validation shards.

    Shards are pulled with ``get_dataloader(outdir, "validation")`` until at
    least ``config["num_validation_sentences"]`` sentences have been drawn.
    Evaluation runs in eval mode with autograd disabled; the model's previous
    train/eval mode is restored before returning.

    Args:
        model: module whose forward returns a sequence of four scalar losses.
        config: training config; only ``num_validation_sentences`` is read.
        outdir: run directory passed through to ``get_dataloader``.

    Returns:
        List of four floats: the mean of each loss over all evaluated batches.
    """
    losses = [0.0 for _ in range(4)]
    steps = 0
    sentences_remaining = config["num_validation_sentences"]
    # Use the device the model lives on instead of a notebook-level global.
    eval_device = next(model.parameters()).device
    was_training = model.training
    model.eval()  # disable dropout for validation
    try:
        with torch.no_grad():  # no graph is needed for validation
            while sentences_remaining > 0:
                dataloader = get_dataloader(outdir, "validation")
                # Count sentences (dataset rows); len(dataloader) would be
                # the number of batches.
                sentences_remaining -= len(dataloader.dataset)
                for data in dataloader:
                    loss = model(*to_device(data, eval_device))
                    for i in range(len(losses)):
                        losses[i] += loss[i].item()
                    steps += 1
    finally:
        model.train(was_training)
    for i in range(len(losses)):
        losses[i] /= steps
    return losses

# Interactive

In [None]:
# Run name: selects the data/alphas/<name> directory that the config-loading
# cell resolves through get_data_path. Set this before running the cells below.
name = ""

In [None]:
# load configs
# outdir is <project root>/data/alphas/<name>; the training config is read
# from <outdir>/training/config.json and the model config is derived from it.
outdir = get_data_path(os.path.join("alphas", name))
config_file = os.path.join(outdir, "training", "config.json")
training_config = create_training_config(config_file)
model_config = create_model_config(training_config)

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the model on the selected device, then the AdamW optimizer and the
# per-step learning-rate schedule (the scheduler wraps the optimizer).
model = TransformerModel(model_config).to(device)
optimizer = create_optimizer(model, training_config)
scheduler = create_learning_rate_schedule(optimizer, training_config)

In [None]:
# Training loop: each epoch consumes shards until the token budget is spent,
# then reports the mean training loss and the validation losses.
for epoch in range(training_config["num_epochs"]):
    # Set train mode explicitly so the epoch does not depend on whatever mode
    # an earlier cell execution left the model in.
    model.train()
    training_loss = 0.0
    training_steps = 0
    tokens_remaining = training_config["tokens_per_epoch"]
    while tokens_remaining > 0:
        dataloader = get_dataloader(outdir, "training")
        # NOTE(review): the budget is computed with training_config["batch_size"]
        # (16) while get_dataloader builds batches of 32 by default — confirm
        # which batch size is intended before changing either side, since the
        # learning-rate schedule is derived from the same config value.
        tokens_remaining -= (
            len(dataloader)
            * training_config["max_sequence_length"]
            * training_config["batch_size"]
        )
        for data in dataloader:
            optimizer.zero_grad()
            # The model returns one loss per head; optimize their sum.
            loss = sum(model(*to_device(data, device)))
            loss.backward()
            optimizer.step()
            scheduler.step()
            training_loss += loss.item()
            training_steps += 1
    # max(..., 1) avoids a ZeroDivisionError when no step ran.
    print(f"Epoch: {epoch}, Training Loss: {training_loss / max(training_steps, 1)}")
    validation_loss = evaluate_metrics(model, training_config, outdir)
    print(f"Epoch: {epoch}, Validation Loss: {validation_loss}")