In [20]:
import os
from os.path import exists
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, pad
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
import altair as alt
from torchtext.data.functional import to_map_style_dataset
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
import torchtext.datasets as datasets
import warnings
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np


# Suppress every warning category (DeprecationWarning included) for the rest of
# the session; presumably just to keep cell output uncluttered.
warnings.simplefilter("ignore")

# Master switch for the demo cells: show_example / execute_example only call
# their function when this is True. Set to False to skip notebook execution
# (e.g. for debugging).
RUN_EXAMPLES = True

In [4]:
def is_interactive_notebook():
    """Return True when this code runs as the top-level program (``__main__``),
    which is the case for cells executed directly in a notebook kernel."""
    module_name = __name__
    return module_name == "__main__"


def show_example(fn, args=()):
    """Run an example and return its result, only in interactive demo mode.

    ``fn(*args)`` is called only when this code runs as ``__main__`` and the
    module-level ``RUN_EXAMPLES`` flag is truthy; otherwise nothing runs.

    Args:
        fn: callable producing the example output (e.g. a chart).
        args: positional arguments for ``fn``. Defaults to an empty tuple
            rather than a shared mutable ``[]`` default.

    Returns:
        Whatever ``fn`` returns, or None when the example is skipped.
    """
    if __name__ == "__main__" and RUN_EXAMPLES:
        return fn(*args)
    return None


def execute_example(fn, args=()):
    """Run an example for its side effects, only in interactive demo mode.

    Like ``show_example`` but discards ``fn``'s return value. ``fn(*args)`` is
    called only when this code runs as ``__main__`` and the module-level
    ``RUN_EXAMPLES`` flag is truthy.

    Args:
        fn: callable to execute.
        args: positional arguments for ``fn``. Defaults to an empty tuple
            rather than a shared mutable ``[]`` default.
    """
    if __name__ == "__main__" and RUN_EXAMPLES:
        fn(*args)


class DummyOptimizer(torch.optim.Optimizer):
    """No-op stand-in for a ``torch.optim.Optimizer``.

    Exposes a single param group whose ``lr`` is 0; ``step`` and
    ``zero_grad`` do nothing.
    """

    def __init__(self):
        # Deliberately skips torch.optim.Optimizer.__init__: there are no
        # parameters to register, only the param_groups attribute is provided.
        self.param_groups = [{"lr": 0}]

    def step(self, closure=None):
        """Do nothing.

        ``closure`` matches the ``Optimizer.step(closure=None)`` signature so
        callers that pass one do not fail; it is ignored (never evaluated).
        """
        return None

    def zero_grad(self, set_to_none=False):
        """Do nothing; there are no gradients to clear."""
        return None


class DummyScheduler:
    """No-op learning-rate scheduler stand-in."""

    def step(self):
        """Intentionally does nothing and returns None."""

In [16]:

def subsequent_mask(size):
    "Mask out subsequent positions."
    # Boolean mask of shape (1, size, size): True on and below the diagonal
    # (column <= row, i.e. allowed), False strictly above it (later positions).
    allowed = torch.tril(torch.ones(1, size, size))
    return allowed.bool()

def get_attn_subsequent_mask(size):
    """Return a float mask of shape (1, size, size) with 1.0 strictly above
    the diagonal (column > row) and 0.0 on and below it.

    Polarity and dtype are the opposite of ``subsequent_mask``: here a 1
    marks a position to hide, and the tensor is float, not bool.
    """
    ones = torch.ones(1, size, size)
    return ones.triu(diagonal=1)

In [18]:
def example_mask(size=20):
    """Plot the mask from ``get_attn_subsequent_mask`` as a heatmap.

    Each plotted cell is one (row ``x``, column ``y``) entry of the
    ``size`` x ``size`` mask; the value is 1 where column > row (a later
    position that is hidden) and 0 otherwise.

    Args:
        size: sequence length to visualize (20 by default, as before).

    Returns:
        An interactive ``altair.Chart`` with ``Window`` (column index) on the
        x axis, ``Masking`` (row index) on the y axis, colored by mask value.
    """
    # Build the mask once; previously the full mask was rebuilt for every one
    # of the size**2 cells being plotted.
    mask = get_attn_subsequent_mask(size)[0].tolist()
    # One record per cell (same y-major order as before), turned into a single
    # DataFrame instead of concatenating size**2 one-row frames.
    mask_cells = pd.DataFrame(
        [
            {"Subsequent Mask": mask[x][y], "Window": y, "Masking": x}
            for y in range(size)
            for x in range(size)
        ]
    )

    return (
        alt.Chart(mask_cells)
        .mark_rect()
        .properties(height=250, width=250)
        .encode(
            alt.X("Window:O"),
            alt.Y("Masking:O"),
            alt.Color("Subsequent Mask:Q", scale=alt.Scale(scheme="viridis")),
        )
        .interactive()
    )


show_example(example_mask)

In [24]:
class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention: softmax(q @ k^T / scale) @ v.

    Accepts any number of leading batch dimensions, e.g. 3-D
    (batch, n, d) or 4-D (batch, heads, n, d) tensors.
    """

    def __init__(self, scale):
        """
        Args:
            scale: divisor applied to the raw scores (the demo uses sqrt(d_k)).
        """
        super().__init__()

        self.scale = scale
        # Normalize over the last (key) axis: same as dim=2 for 3-D inputs,
        # and still correct when extra leading dims (e.g. heads) are present.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        """
        Args:
            q: queries, shape (..., n_q, d_k).
            k: keys, shape (..., n_k, d_k).
            v: values, shape (..., n_k, d_v).
            mask: optional bool tensor broadcastable to (..., n_q, n_k);
                True marks a score to hide.

        Returns:
            (attn, output): attention weights (..., n_q, n_k) and the
            attended values (..., n_q, d_v).
        """
        u = torch.matmul(q, k.transpose(-2, -1))  # 1. Matmul
        u = u / self.scale  # 2. Scale

        if mask is not None:
            u = u.masked_fill(mask, float("-inf"))  # 3. Mask

        attn = self.softmax(u)  # 4. Softmax
        if mask is not None:
            # A query whose keys are ALL masked gets softmax(-inf, ..., -inf),
            # which is NaN and would poison the output; zero those weights.
            # Rows with at least one visible key are unchanged, because their
            # masked entries are already exactly 0 after the softmax.
            attn = attn.masked_fill(mask, 0.0)
        output = torch.matmul(attn, v)  # 5. Output

        return attn, output


if __name__ == "__main__":
    # Seed the RNG so the printed tensors are reproducible on Restart & Run All.
    torch.manual_seed(0)

    n_q, n_k, n_v = 2, 4, 4
    d_q, d_k, d_v = 24, 24, 12

    q = torch.randn(5, n_q, d_q)
    k = torch.randn(5, n_k, d_k)
    v = torch.randn(5, n_v, d_v)

    # All-False mask, broadcast over (5, n_q, n_k): nothing is hidden.
    mask = torch.zeros(5, 1, 1).bool()

    attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
    attn, output = attention(q, k, v, mask=mask)

    # Sanity checks: one weight per key, each query's weights sum to 1,
    # and one d_v-dimensional vector per query.
    assert attn.shape == (5, n_q, n_k)
    assert output.shape == (5, n_q, d_v)
    assert torch.allclose(attn.sum(dim=-1), torch.ones(5, n_q))

    print(attn)
    print(output)

tensor([[[0.0796, 0.2618, 0.5120, 0.1466],
         [0.0192, 0.0524, 0.0691, 0.8592]],

        [[0.0994, 0.2791, 0.5432, 0.0783],
         [0.2387, 0.5953, 0.0009, 0.1652]],

        [[0.1681, 0.3849, 0.2342, 0.2128],
         [0.1639, 0.3344, 0.3052, 0.1965]],

        [[0.6775, 0.0996, 0.0253, 0.1975],
         [0.1037, 0.7357, 0.0610, 0.0996]],

        [[0.0344, 0.0306, 0.4854, 0.4497],
         [0.2202, 0.5385, 0.2070, 0.0343]]])
tensor([[[ 6.6827e-01,  1.3011e-01,  9.1958e-02, -1.9338e-01, -1.2272e+00,
          -2.2763e-01, -9.3714e-02,  8.5295e-02,  5.5590e-01, -3.5056e-02,
          -2.6613e-01,  5.0651e-02],
         [-3.5936e-01,  1.9435e-01,  1.1814e+00, -6.9755e-01,  5.7582e-02,
           1.3883e+00, -1.1081e+00, -1.0341e+00,  5.4899e-01, -6.0826e-01,
          -5.3090e-01,  1.8254e+00]],

        [[-4.5358e-01, -5.5667e-01,  3.1900e-01,  9.1825e-01,  8.8912e-01,
           7.7520e-02, -3.6341e-01, -4.2993e-01, -1.0364e+00, -5.7403e-02,
           2.6304e-01,  1.8674e-01