In [1]:
import pandas as pd
from chameleon.base_dataset import Vocabulary, TranslationDataset, TranslationCollator
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence


def read_data(file_path, src_tgt):
    data = pd.read_pickle(file_path)
    src_lang, tgt_lang = src_tgt[:2], src_tgt[2:]

    # parse source column and target column
    src_col, tgt_col = ("tok" + "_" + src_lang, "tok" + "_" + tgt_lang)
    srcs = data[src_col].tolist()
    tgts = data[tgt_col].tolist()
    return srcs, tgts

In [2]:
train_srcs, train_tgts = read_data("./data/chameleon.train.tok.pickle", "enko")

train_loader = DataLoader(
    TranslationDataset(train_srcs, train_tgts, with_text=True),
    batch_size=5,
    shuffle=True,
    collate_fn=TranslationCollator(
        pad_idx=Vocabulary.PAD, max_length=256, with_text=True
    ),
)

[32m2023-08-14 11:53:33.859[0m | [1mINFO    [0m | [35mNone[0m | [36mchameleon.base_dataset[0m:[36mbuild_vocab[0m:[36m67[0m - [1mNumber of vocabularies: 30488[0m
[32m2023-08-14 11:53:38.046[0m | [1mINFO    [0m | [35mNone[0m | [36mchameleon.base_dataset[0m:[36mbuild_vocab[0m:[36m67[0m - [1mNumber of vocabularies: 53430[0m


## Check the model forward functionality
- Data: Using first batch of the input, output data
- `Encoder` Test
- `Decoder` Test
- `Attention` Test
- `Generator` Test

In [3]:
batch = next(iter(train_loader))

In [7]:
print(batch["input_ids"][0].size())
print(batch["output_ids"][0].size())
print(batch)

torch.Size([5, 42])
torch.Size([5, 48])
{'input_ids': (tensor([[   19,   298,    12, 13774,   961,    28, 15242,    11,  5100,   134,
            24,  7420,  5045,     7,   122,    12,  3568,   829,     5,   433,
           124,   167,    17,   316,     8,   914,     6,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [   20,    15,   276,   895,   217,    20,  2411,    31,  4093,     8,
           791,    22,   333,    30,     4, 11711,   791,   286,     6,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [   20,   515,  2944,   396,    21,    32,   242,    15,    57,   515,
          4580,  2488,     9,   106,     6,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0

In [4]:
x = batch["input_ids"]
y = batch["output_ids"]

In [5]:
import torch
import torch.nn as nn

if isinstance(x, tuple):
    x, x_length = x

if isinstance(y, tuple):
    y = y[0]

In [6]:
input_size = len(train_loader.dataset.src_vocab)
output_size = len(train_loader.dataset.tgt_vocab)
hidden_size = 768

In [24]:
# Embedding x


def generate_pos_enc(hidden_size, max_length):
    enc = torch.FloatTensor(max_length, hidden_size).zero_()
    # |enc| = (max_length, hidden_size)

    pos = torch.arange(0, max_length).unsqueeze(-1).float()
    dim = torch.arange(0, hidden_size // 2).unsqueeze(0).float()
    # |pos| = (max_length, 1)
    # |dim| = (1, hidden_size // 2)
    # |pos / dim| = (max_length, hidden_size // 2)
    enc[:, 0::2] = torch.sin(pos / 1e4 ** dim.div(float(hidden_size)))
    enc[:, 1::2] = torch.cos(pos / 1e4 ** dim.div(float(hidden_size)))

    return enc


def encode_position(x, pos_enc, max_length, init_pos=0):
    # |x| = (batch_size, n, hidden_size)
    # |pos_enc| = (max_length, hidden_size)
    assert x.size(-1) == pos_enc.size(-1)
    assert x.size(1) + init_pos <= max_length

    pos_enc = pos_enc[init_pos : init_pos + x.size(1)].unsqueeze(0)
    # |pos_enc| = (1, n, hidden_size)

    x = x + pos_enc.to(x.device)
    # broadcasting
    # |x| = (batch_size, n, hidden_size)
    return x


# Embedding Layer
emb_src = nn.Embedding(input_size, hidden_size)
emb_tgt = nn.Embedding(output_size, hidden_size)
emb_x = emb_src(x)
print(emb_x.size())

max_length = 512
pos_enc = generate_pos_enc(hidden_size, max_length)
emb_x = encode_position(emb_x, pos_enc, max_length)
print(emb_x.size())

torch.Size([5, 40, 768])
torch.Size([5, 40, 768])


In [8]:
def generate_mask(x, length):
    mask = []

    max_length = max(length)
    for l in length:
        if max_length - l > 0:
            # If the length is shorter than maximum length among samples,
            # set last few values to be 1s to remove attention weight.
            mask += [
                torch.cat(
                    [x.new_ones(1, l).zero_(), x.new_ones(1, (max_length - l))], dim=-1
                )
            ]
        else:
            # If the length of the sample equals to maximum length among samples,
            # set every value in mask to be 0.
            mask += [x.new_ones(1, l).zero_()]

    mask = torch.cat(mask, dim=0).bool()

    return mask


mask = generate_mask(x, x_length)

In [9]:
print(mask.size())
print(x.size())

torch.Size([5, 40])
torch.Size([5, 40])


In [10]:
with torch.no_grad():
    mask_enc = mask.unsqueeze(1).expand(x.size(0), x.size(1), mask.size(-1))
print(mask_enc.size())

torch.Size([5, 40, 40])


In [11]:
import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(
        self,
    ):
        super().__init__()

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, mask=None, dk=64):
        # |Q| = (batch_size, m, hidden_size) # |Q| = (n_splits * batch_size, m, hidden_size / n_splits)
        # |K|, |V| = (batch_size, n, hidden_size)
        # |mask| = (batch_size, m, n)

        w = torch.bmm(Q, K.transpose(1, 2))
        # |w| = (batch_size, m, n)

        if mask is not None:
            assert w.size() == mask.size()
            w.masked_fill_(mask, -float("inf"))
        w = self.softmax(w / (dk**0.5))
        c = torch.bmm(w, V)
        # |c| = (batch_size, m, hidden_size)

        return c


class MultiHead(nn.Module):
    def __init__(self, hidden_size, n_splits):
        super().__init__()

        self.hidden_size = hidden_size
        self.n_splits = n_splits

        self.Q_linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.K_linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.V_linear = nn.Linear(hidden_size, hidden_size, bias=False)

        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)

        self.attn = Attention()

    def forward(self, Q, K, V, mask=None):
        # |Q| = (batch_size, m, hidden_size)
        # |K| = (batch_size, n, hidden_size)
        # |V| = |K|
        # |mask| = (batch_size, m, n)

        QWs = self.Q_linear(Q).split(self.hidden_size // self.n_splits, dim=-1)
        KWs = self.K_linear(Q).split(self.hidden_size // self.n_splits, dim=-1)
        VWs = self.V_linear(Q).split(self.hidden_size // self.n_splits, dim=-1)
        # |QW_i| = (batch_size, m, hidden_size / n_splits)
        # |KW_i| = (batch_size, n, hidden_size / n_splits)

        QWs = torch.cat(QWs, dim=0)
        KWs = torch.cat(KWs, dim=0)
        VWs = torch.cat(VWs, dim=0)
        # |QWs| = (batch_size * n_splits, m, hidden_size / n_splits)
        # |KWs| = (batch_size * n_splits, n, hidden_size / n_splits)

        if mask is not None:
            mask = torch.cat([mask for _ in range(self.n_splits)], dim=0)
        # |mask| = (batch_size * n_splits, m, n)

        c = self.attn(QWs, KWs, VWs, mask=mask, dk=self.hidden_size // self.n_splits)
        # |c| = (batch_size * n_splits, m, hidden_size / n_splits)
        c = c.split(Q.size(0), dim=0)
        # |c_i| = (batch_size, m, hidden_size / n_splits)
        c = self.linear(torch.cat(c, dim=-1))
        # |c| = (batch_size, m, hidden_size)
        return c

In [12]:
emb_x.size()

torch.Size([5, 40, 768])

### MultiHead Attention
- before making `EncoderBlock`
- after making `EncoderBlock`

In [13]:
# multihead attention
n_splits = 8
attn = MultiHead(hidden_size, n_splits)
attn_norm = nn.LayerNorm(hidden_size)
attn_dropout = nn.Dropout(0.2)

fc_norm = nn.LayerNorm(hidden_size)
fc_dropout = nn.Dropout(0.2)
fc = nn.Sequential(
    nn.Linear(hidden_size, hidden_size * 4),
    nn.LeakyReLU(),
    nn.Linear(hidden_size * 4, hidden_size),
)

# pre Linear Normalization

# linear normalization
z = attn_norm(emb_x)
# multihead attention
z = attn(Q=z, K=z, V=z, mask=mask_enc)
## |z| = (batch_size, n, hidden_size)
z = attn_dropout(z)
# residual connection
z = z + emb_x

# fc normalization
z = fc_norm(z)
# fully connected layer with LeakyReLU
z = fc(z)
# dropout
z = fc_dropout(z)
# |z| = (batch_size, n, hidden_size)

In [15]:
z.size()

torch.Size([5, 40, 768])

In [28]:
class EncoderBlock(nn.Module):
    def __init__(self, hidden_size, n_splits, dropout=0.2):
        super().__init__()

        self.attn = MultiHead(hidden_size, n_splits)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.attn_dropout = nn.Dropout(dropout)

        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.LeakyReLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.fc_norm = nn.LayerNorm(hidden_size)
        self.fc_dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        z = self.attn_norm(x)
        z = x + self.attn_dropout(self.attn(Q=z, K=z, V=z, mask=mask))
        z = z + self.fc_dropout(self.fc(self.fc_norm(z)))
        return z, mask


class CustomSequential(nn.Sequential):
    def forward(self, *x):
        for module in self._modules.values():
            x = module(*x)
        return x


encoder_block = CustomSequential(
    *[EncoderBlock(hidden_size, n_splits) for _ in range(2)]
)

In [30]:
z, _ = encoder_block(emb_x, mask_enc)

In [32]:
z.size()

torch.Size([5, 40, 768])

### MultiHead Attention
- making `DecoderBlock`

In [34]:
y.size()

torch.Size([5, 51])

In [35]:
# Embedding y
emb_dropout = nn.Dropout(0.2)
emb_y = emb_tgt(y)
emb_y = encode_position(emb_y, pos_enc, max_length)
emb_y = emb_dropout(emb_y)

In [36]:
emb_y.size()

torch.Size([5, 51, 768])

In [41]:
# generate future mask
with torch.no_grad():
    future_mask = torch.triu(
        emb_x.new_ones(emb_y.size(1), emb_y.size(1)), diagonal=1
    ).bool()
    # |future_mask| = (m, m)
    future_mask = future_mask.unsqueeze(0).expand(emb_y.size(0), *future_mask.size())
    # |future_mask| = (batch_size, m, m)

In [42]:
future_mask.size()

torch.Size([5, 51, 51])

In [59]:
class DecoderBlock(nn.Module):
    def __init__(self, hidden_size, n_splits, dropout):
        super().__init__()

        # Two types of attention
        self.masked_attn = MultiHead(hidden_size, n_splits)
        self.masked_attn_norm = nn.LayerNorm(hidden_size)
        self.masked_attn_dropout = nn.Dropout(dropout)

        self.attn = MultiHead(hidden_size, n_splits)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.attn_dropout = nn.Dropout(dropout)

        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.LeakyReLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.fc_norm = nn.LayerNorm(hidden_size)
        self.fc_dropout = nn.Dropout(dropout)

    def forward(self, x, key_and_value, mask, prev, future_mask):
        # |key_and_value| = (batch_size, m, hidden_size)
        # |mask| = (batch_size, m, n)
        if prev is None:
            # training mode
            # |x| = (batch_size, n, hidden_size)
            z = self.masked_attn_norm(x)
            z = x + self.masked_attn_dropout(
                self.masked_attn(Q=z, K=z, V=z, mask=future_mask)
            )
        else:
            # inference mode
            # |x| = (batch_size, 1, hidden_size)
            # |prev| = (batch_size, ~t-1, hidden_size)
            normed_prev = self.masked_attn_norm(prev)
            z = self.masked_attn_norm(x)
            z = x + self.masked_attn_dropout(
                self.masked_attn(Q=z, K=normed_prev, V=normed_prev, mask=None)
            )

        normed_key_and_value = self.attn_norm(key_and_value)
        z = z + self.attn_dropout(
            self.attn(
                Q=self.attn_norm(z), K=normed_key_and_value, V=normed_key_and_value
            )
        )
        # |z| = (batch_size, m, hidden_size)
        z = z + self.fc_dropout(self.fc(self.fc_norm(z)))
        # |z| = (batch_size, m, hidden_size)

        return z, key_and_value, mask, prev, future_mask

In [60]:
decoder_block = CustomSequential(
    *[DecoderBlock(hidden_size, n_splits, 0.2) for i in range(6)]
)

In [63]:
h, _, _, _, _ = decoder_block(emb_y, z, mask_enc, None, future_mask)

In [64]:
h.size()

torch.Size([5, 51, 768])

### Generator

In [66]:
class Generator(nn.Module):
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.generator = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, output_size),
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, x):
        z = self.generator(x)
        return z

In [67]:
generator = Generator(hidden_size, output_size)
# generally, logit means the output before the softmax layer
# but here names it as logit temporarily
logit = generator(h)

In [70]:
logit.size()

torch.Size([5, 51, 53430])

In [73]:
y.size()

torch.Size([5, 51])

In [85]:
crit = nn.NLLLoss(ignore_index=0, reduction="sum")

weight = torch.ones(output_size)
weight[0] = 0
crit_trial = nn.NLLLoss(weight=torch.ones(output_size), reduction="sum")
loss = crit(
    # y_hat
    logit.contiguous().view(-1, logit.size(-1)),
    # y
    y.contiguous().view(-1),
)
loss.div(y.size(0))

tensor(277.8155, grad_fn=<DivBackward0>)

In [86]:
loss = crit_trial(
    # y_hat
    logit.contiguous().view(-1, logit.size(-1)),
    # y
    y.contiguous().view(-1),
)
loss.div(y.size(0))

tensor(572.0168, grad_fn=<DivBackward0>)