In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import numpy as np
from tqdm import tqdm

In [2]:
torch.manual_seed(42)  # this is to ensure deterministic behavior

<torch._C.Generator at 0x7f076dadbd10>

In [3]:
def positional_encoding_function(maxL, D):
        pe = torch.zeros(maxL, D)
            # pe measures (maxL, D)
        position = torch.arange(0, maxL).unsqueeze(1)
            # (maxL) --> (maxL, 1) via unsqueeze(1)
        coeff = torch.exp(torch.arange(0, D, 2).float() * -(math.log(10000.0) / D))

        pe[:, 0::2] = torch.sin(position * coeff)  # fill in even positions
        pe[:, 1::2] = torch.cos(position * coeff)  # fill in odd positions
        return pe

In [4]:
positional_encoding_function(10, 4).shape

torch.Size([10, 4])

In [5]:

class PositionalEncoding(nn.Module):
    def __init__(self, D, maxL):
        super(PositionalEncoding, self).__init__()
        self.pe = nn.Parameter(positional_encoding_function(maxL, D).unsqueeze(0), requires_grad=False) # FIXME
    def forward(self, x):
        return x  + self.pe[:,:x.size(1)]# FIXME

In [7]:
class FeedForward(nn.Module):
    def __init__(self, D, D_ff):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(D,D_ff) # FIXME
        self.fc2 = nn.Linear(D_ff, D) # FIXME
        self.relu = nn.ReLU() # FIXME

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        return self.fc2(x) # FIXME

In [27]:
def split_heads(input, H):  # tensor shape is [B, L, D]
    Dh = input.shape[2] // H
    B, L, D = input.shape
    input = input.view(B, L, H, Dh)
    return input.permute(0, 2, 1, 3).contiguous() # FIXME

def combine_heads(input):  # tensor shape is [B, H, L, D]
    B, H, L, D = input.shape
    input = input.permute(0, 2, 1, 3).contiguous()
    return input.view(B, L, H * D)
     # FIXME

In [29]:
def generate_causality_mask(target):
    L = target.size(1)
    return torch.tril(torch.ones(1, L, L), diagonal=0).int()

In [73]:
def scaled_dot_product_attention(Q, K, V, d_k, mask=None, negative_infinity=-1e9):
    # print(Q,K)
    attn_scores = torch.matmul(Q,K.transpose(-2,-1)) / (d_k**0.5) # FIXME (step 1)
    if mask is not None:
      attn_scores = torch.where(mask == 0, negative_infinity,attn_scores) # FIXME (step 2)
    attn_probabilities = nn.Softmax(dim=-1)
    attn_probabilities = attn_probabilities(attn_scores)
    output = torch.matmul(attn_probabilities, V) # FIXME (step 4)
    return output

In [75]:

class MultiHeadAttention(nn.Module):
    def __init__(self, D, H):
        super(MultiHeadAttention, self).__init__()
        assert D % H == 0, "D must be divisible by H"
        self.D = D
        self.H = H
        self.Dh = D//H
        self.W_q = nn.Linear(D, D)
        self.W_k = nn.Linear(D, D)
        self.W_v = nn.Linear(D, D)
        self.W_o = nn.Linear(D, D)
    def forward(self, Q, K, V, mask=None):
        Q = split_heads(self.W_q(Q), self.H)
        K = split_heads(self.W_k(K), self.H)
        V = split_heads(self.W_v(V), self.H)

        output = scaled_dot_product_attention(Q, K, V, self.Dh, mask, negative_infinity=-1e9)
        output = combine_heads(output)

        output = self.W_o(output)
        return output # FIXME

In [77]:

class EncoderLayer(nn.Module):
    def __init__(self, D, H, D_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(D,H) # FIXME
        self.feed_forward =  FeedForward(D, D_ff)# FIXME
        self.norm1 = nn.LayerNorm(D) # FIXME
        self.norm2 = nn.LayerNorm(D) # FIXME
        self.dropout = nn.Dropout(dropout) # FIXME

    def forward(self, x, mask=None):
        attention = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attention))
        ff = self.feed_forward(x)

        return self.norm2(x + self.dropout(ff)) # FIXME


In [79]:

class DecoderLayer(nn.Module):
    def __init__(self, D, H, D_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(D, H) # FIXME
        self.cross_attention = MultiHeadAttention(D, H) # FIXME
        self.feed_forward = FeedForward(D, D_ff) # FIXME
        self.norm1 = nn.LayerNorm(D) # FIXME
        self.norm2 = nn.LayerNorm(D) # FIXME
        self.norm3 = nn.LayerNorm(D) # FIXME
        self.dropout = nn.Dropout(dropout) # FIXME

    def forward(self, x, encoder_output, self_mask, cross_mask=None):
        self_attention = self.self_attention(x, x, x, self_mask)
        x = self.norm1(x + self.dropout(self_attention))
        cross_attention = self.cross_attention(x, encoder_output, encoder_output, cross_mask)
        x = self.norm2(x + self.dropout(cross_attention))
        ff = self.feed_forward(x)
        return self.norm3(x + self.dropout(ff))


In [81]:

class Transformer(nn.Module):
    def __init__(self, source_vocab_size, target_vocab_size, D,
                H, Nx, D_ff, maxL, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(source_vocab_size, D) # FIXME
        self.decoder_embedding = nn.Embedding(target_vocab_size, D) # FIXME
        self.positional_encoding = PositionalEncoding(D, maxL) # FIXME
        self.encoder_layers = nn.ModuleList([EncoderLayer(D, H, D_ff, dropout) for _ in range(Nx)]) # FIXME
        self.decoder_layers = nn.ModuleList([DecoderLayer(D, H, D_ff, dropout) for _ in range(Nx)]) # FIXME
        self.fc = nn.Linear(D, target_vocab_size) # FIXME
        self.dropout = nn.Dropout(dropout) # FIXME

    def forward(self, source, target):
        source = self.positional_encoding(self.dropout(self.encoder_embedding(source)))
        target = self.positional_encoding(self.dropout(self.decoder_embedding(target)))
        causality_mask = generate_causality_mask(target)
        encoder_output = source
        for encoder_layer in self.encoder_layers:
            encoder_output = encoder_layer(encoder_output)
        decoder_output = target
        for decoder_layer in self.decoder_layers:
            decoder_output = decoder_layer(decoder_output, encoder_output, causality_mask)
        output = self.fc(decoder_output)
        return output


In [68]:
source_vocab_size = 5000
target_vocab_size = 5000
D = 48
H = 3
Nx = 3
D_ff = 128
maxL = 20
dropout = 0.1

transformer = Transformer(source_vocab_size, target_vocab_size,
                         D, H, Nx, D_ff, maxL, dropout)

In [69]:
B = 64
Ls = maxL  # arbitrary, but must be > 0 and <= maxL
Lt = maxL  # arbitrary, but must be > 0 and <= maxL
source_data = torch.randint(1, source_vocab_size, (B, Ls))
target_data = torch.randint(1, target_vocab_size, (B, Lt))

# generate a [B, L] tensor whose values are [1, source_vocab_size)
# generate a [B, L] tensor whose values are [1, target_vocab_size)

In [70]:
criterion = nn.CrossEntropyLoss(ignore_index=0)


optimizer = optim.Adam(transformer.parameters(),
                      lr=1e-4, betas=(0.9, 0.98), eps=1e-9)

transformer.train()
for epoch in tqdm(range(5001)):
    optimizer.zero_grad()
    output = transformer(source_data, target_data[:, :-1])
    loss = criterion(output.reshape(
        -1, target_vocab_size), target_data[:, 1:].reshape(-1))
    loss.backward()
    optimizer.step()
    if epoch % 50 == 0:
        tqdm.write(f'Epoch: {epoch+1}, Loss: {loss.item()})')

  0%|          | 2/5001 [00:00<18:32,  4.49it/s]

Epoch: 1, Loss: 8.70529842376709)


  1%|          | 52/5001 [00:09<13:19,  6.19it/s]

Epoch: 51, Loss: 8.294711112976074)


  2%|▏         | 102/5001 [00:18<15:07,  5.40it/s]

Epoch: 101, Loss: 7.91542387008667)


  3%|▎         | 152/5001 [00:26<13:28,  6.00it/s]

Epoch: 151, Loss: 7.568960189819336)


  4%|▍         | 202/5001 [00:35<13:10,  6.07it/s]

Epoch: 201, Loss: 7.234884262084961)


  5%|▌         | 252/5001 [00:44<12:44,  6.21it/s]

Epoch: 251, Loss: 6.934198379516602)


  6%|▌         | 302/5001 [00:52<12:56,  6.05it/s]

Epoch: 301, Loss: 6.630430221557617)


  7%|▋         | 352/5001 [01:01<11:53,  6.51it/s]

Epoch: 351, Loss: 6.353003978729248)


  8%|▊         | 402/5001 [01:09<12:23,  6.18it/s]

Epoch: 401, Loss: 6.088189125061035)


  9%|▉         | 452/5001 [01:17<12:24,  6.11it/s]

Epoch: 451, Loss: 5.82076358795166)


 10%|█         | 502/5001 [01:26<12:24,  6.04it/s]

Epoch: 501, Loss: 5.5828728675842285)


 11%|█         | 552/5001 [01:35<11:48,  6.28it/s]

Epoch: 551, Loss: 5.343811988830566)


 12%|█▏        | 602/5001 [01:43<11:36,  6.32it/s]

Epoch: 601, Loss: 5.086688995361328)


 13%|█▎        | 652/5001 [01:52<11:16,  6.43it/s]

Epoch: 651, Loss: 4.856956481933594)


 14%|█▍        | 702/5001 [02:00<11:08,  6.43it/s]

Epoch: 701, Loss: 4.612321376800537)


 15%|█▌        | 752/5001 [02:08<11:13,  6.31it/s]

Epoch: 751, Loss: 4.380451202392578)


 16%|█▌        | 802/5001 [02:17<11:25,  6.13it/s]

Epoch: 801, Loss: 4.157071590423584)


 17%|█▋        | 852/5001 [02:26<10:49,  6.39it/s]

Epoch: 851, Loss: 3.9349348545074463)


 18%|█▊        | 902/5001 [02:33<10:55,  6.25it/s]

Epoch: 901, Loss: 3.7115108966827393)


 19%|█▉        | 952/5001 [02:42<10:35,  6.38it/s]

Epoch: 951, Loss: 3.496946334838867)


 20%|██        | 1002/5001 [02:51<10:40,  6.24it/s]

Epoch: 1001, Loss: 3.2868130207061768)


 21%|██        | 1052/5001 [02:59<10:46,  6.10it/s]

Epoch: 1051, Loss: 3.0911567211151123)


 22%|██▏       | 1102/5001 [03:07<10:33,  6.15it/s]

Epoch: 1101, Loss: 2.881333827972412)


 23%|██▎       | 1152/5001 [03:16<10:23,  6.17it/s]

Epoch: 1151, Loss: 2.6833972930908203)


 24%|██▍       | 1202/5001 [03:24<09:49,  6.44it/s]

Epoch: 1201, Loss: 2.5074496269226074)


 25%|██▌       | 1252/5001 [03:32<10:16,  6.08it/s]

Epoch: 1251, Loss: 2.3121111392974854)


 26%|██▌       | 1302/5001 [03:41<10:25,  5.91it/s]

Epoch: 1301, Loss: 2.112720012664795)


 27%|██▋       | 1352/5001 [03:49<09:32,  6.38it/s]

Epoch: 1351, Loss: 1.9567716121673584)


 28%|██▊       | 1402/5001 [03:58<09:27,  6.34it/s]

Epoch: 1401, Loss: 1.8086856603622437)


 29%|██▉       | 1452/5001 [04:06<09:05,  6.51it/s]

Epoch: 1451, Loss: 1.632217288017273)


 30%|███       | 1502/5001 [04:14<09:03,  6.44it/s]

Epoch: 1501, Loss: 1.4825083017349243)


 31%|███       | 1552/5001 [04:23<08:49,  6.51it/s]

Epoch: 1551, Loss: 1.336799144744873)


 32%|███▏      | 1602/5001 [04:32<09:06,  6.22it/s]

Epoch: 1601, Loss: 1.2205369472503662)


 33%|███▎      | 1652/5001 [04:39<08:45,  6.38it/s]

Epoch: 1651, Loss: 1.0917450189590454)


 34%|███▍      | 1702/5001 [04:48<08:52,  6.19it/s]

Epoch: 1701, Loss: 0.9865536689758301)


 35%|███▌      | 1752/5001 [04:57<08:31,  6.35it/s]

Epoch: 1751, Loss: 0.8720089793205261)


 36%|███▌      | 1802/5001 [05:04<08:17,  6.43it/s]

Epoch: 1801, Loss: 0.7752055525779724)


 37%|███▋      | 1852/5001 [05:13<08:10,  6.42it/s]

Epoch: 1851, Loss: 0.6940938830375671)


 38%|███▊      | 1902/5001 [05:21<08:40,  5.96it/s]

Epoch: 1901, Loss: 0.6052128672599792)


 39%|███▉      | 1952/5001 [05:29<07:51,  6.47it/s]

Epoch: 1951, Loss: 0.5300348997116089)


 40%|████      | 2002/5001 [05:38<08:10,  6.11it/s]

Epoch: 2001, Loss: 0.4623832106590271)


 41%|████      | 2052/5001 [05:46<07:43,  6.37it/s]

Epoch: 2051, Loss: 0.41113871335983276)


 42%|████▏     | 2102/5001 [05:54<07:19,  6.60it/s]

Epoch: 2101, Loss: 0.35521364212036133)


 43%|████▎     | 2152/5001 [06:03<07:35,  6.25it/s]

Epoch: 2151, Loss: 0.3154280185699463)


 44%|████▍     | 2202/5001 [06:11<07:33,  6.17it/s]

Epoch: 2201, Loss: 0.27811896800994873)


 45%|████▌     | 2252/5001 [06:19<07:14,  6.33it/s]

Epoch: 2251, Loss: 0.24202246963977814)


 46%|████▌     | 2302/5001 [06:28<07:06,  6.32it/s]

Epoch: 2301, Loss: 0.2085905224084854)


 47%|████▋     | 2352/5001 [06:36<07:28,  5.91it/s]

Epoch: 2351, Loss: 0.18693427741527557)


 48%|████▊     | 2402/5001 [06:44<06:54,  6.28it/s]

Epoch: 2401, Loss: 0.15979869663715363)


 49%|████▉     | 2452/5001 [06:53<06:51,  6.19it/s]

Epoch: 2451, Loss: 0.13768628239631653)


 50%|█████     | 2502/5001 [07:01<06:56,  6.00it/s]

Epoch: 2501, Loss: 0.12140591442584991)


 51%|█████     | 2552/5001 [07:09<06:19,  6.46it/s]

Epoch: 2551, Loss: 0.10767857730388641)


 52%|█████▏    | 2602/5001 [07:18<06:28,  6.18it/s]

Epoch: 2601, Loss: 0.09356371313333511)


 53%|█████▎    | 2652/5001 [07:26<06:41,  5.85it/s]

Epoch: 2651, Loss: 0.08252225816249847)


 54%|█████▍    | 2702/5001 [07:34<06:11,  6.18it/s]

Epoch: 2701, Loss: 0.07142046093940735)


 55%|█████▌    | 2752/5001 [07:43<06:04,  6.17it/s]

Epoch: 2751, Loss: 0.06499768048524857)


 56%|█████▌    | 2802/5001 [07:52<06:09,  5.95it/s]

Epoch: 2801, Loss: 0.054833751171827316)


 57%|█████▋    | 2852/5001 [08:00<06:04,  5.89it/s]

Epoch: 2851, Loss: 0.04737318679690361)


 58%|█████▊    | 2902/5001 [08:08<05:23,  6.49it/s]

Epoch: 2901, Loss: 0.042736656963825226)


 59%|█████▉    | 2952/5001 [08:17<05:42,  5.98it/s]

Epoch: 2951, Loss: 0.04106326028704643)


 60%|██████    | 3002/5001 [08:25<05:10,  6.43it/s]

Epoch: 3001, Loss: 0.03262796625494957)


 61%|██████    | 3052/5001 [08:33<05:00,  6.48it/s]

Epoch: 3051, Loss: 0.030044326558709145)


 62%|██████▏   | 3102/5001 [08:42<05:42,  5.54it/s]

Epoch: 3101, Loss: 0.025561407208442688)


 63%|██████▎   | 3152/5001 [08:50<04:59,  6.17it/s]

Epoch: 3151, Loss: 0.021972689777612686)


 64%|██████▍   | 3202/5001 [08:59<04:51,  6.18it/s]

Epoch: 3201, Loss: 0.021016092970967293)


 65%|██████▌   | 3252/5001 [09:07<04:54,  5.94it/s]

Epoch: 3251, Loss: 0.017080729827284813)


 66%|██████▌   | 3302/5001 [09:15<04:20,  6.51it/s]

Epoch: 3301, Loss: 0.01591116562485695)


 67%|██████▋   | 3352/5001 [09:23<04:13,  6.51it/s]

Epoch: 3351, Loss: 0.014211338944733143)


 68%|██████▊   | 3402/5001 [09:32<05:10,  5.16it/s]

Epoch: 3401, Loss: 0.012238191440701485)


 69%|██████▉   | 3452/5001 [09:40<04:11,  6.17it/s]

Epoch: 3451, Loss: 0.011137500405311584)


 70%|███████   | 3502/5001 [09:48<04:07,  6.05it/s]

Epoch: 3501, Loss: 0.009886584244668484)


 71%|███████   | 3552/5001 [09:57<04:56,  4.89it/s]

Epoch: 3551, Loss: 0.00826592743396759)


 72%|███████▏  | 3602/5001 [10:05<03:42,  6.30it/s]

Epoch: 3601, Loss: 0.0077231526374816895)


 73%|███████▎  | 3652/5001 [10:13<03:38,  6.16it/s]

Epoch: 3651, Loss: 0.00714213028550148)


 74%|███████▍  | 3701/5001 [10:22<04:55,  4.41it/s]

Epoch: 3701, Loss: 0.006985476706176996)


 75%|███████▌  | 3752/5001 [10:30<03:16,  6.34it/s]

Epoch: 3751, Loss: 0.005265434738248587)


 76%|███████▌  | 3802/5001 [10:38<03:06,  6.44it/s]

Epoch: 3801, Loss: 0.0047949813306331635)


 77%|███████▋  | 3851/5001 [10:46<04:11,  4.58it/s]

Epoch: 3851, Loss: 0.004261108580976725)


 78%|███████▊  | 3902/5001 [10:54<02:58,  6.16it/s]

Epoch: 3901, Loss: 0.00403187470510602)


 79%|███████▉  | 3952/5001 [11:03<02:53,  6.05it/s]

Epoch: 3951, Loss: 0.0034941101912409067)


 80%|████████  | 4001/5001 [11:11<03:44,  4.46it/s]

Epoch: 4001, Loss: 0.0032519937958568335)


 81%|████████  | 4052/5001 [11:20<02:28,  6.40it/s]

Epoch: 4051, Loss: 0.0027027397882193327)


 82%|████████▏ | 4102/5001 [11:29<02:26,  6.12it/s]

Epoch: 4101, Loss: 0.0024673976004123688)


 83%|████████▎ | 4151/5001 [11:37<03:08,  4.50it/s]

Epoch: 4151, Loss: 0.0022938568145036697)


 84%|████████▍ | 4202/5001 [11:45<02:05,  6.35it/s]

Epoch: 4201, Loss: 0.002025703666731715)


 85%|████████▌ | 4252/5001 [11:54<02:02,  6.13it/s]

Epoch: 4251, Loss: 0.0019277480896562338)


 86%|████████▌ | 4301/5001 [12:02<02:42,  4.32it/s]

Epoch: 4301, Loss: 0.0017545847222208977)


 87%|████████▋ | 4352/5001 [12:10<01:39,  6.54it/s]

Epoch: 4351, Loss: 0.0014713300624862313)


 88%|████████▊ | 4402/5001 [12:19<01:34,  6.31it/s]

Epoch: 4401, Loss: 0.0013264266308397055)


 89%|████████▉ | 4451/5001 [12:27<02:06,  4.34it/s]

Epoch: 4451, Loss: 0.0012226589024066925)


 90%|█████████ | 4502/5001 [12:35<01:17,  6.46it/s]

Epoch: 4501, Loss: 0.0010951567674055696)


 91%|█████████ | 4552/5001 [12:44<01:10,  6.36it/s]

Epoch: 4551, Loss: 0.000951902475208044)


 92%|█████████▏| 4601/5001 [12:52<01:32,  4.33it/s]

Epoch: 4601, Loss: 0.0008752172579988837)


 93%|█████████▎| 4652/5001 [13:01<00:54,  6.35it/s]

Epoch: 4651, Loss: 0.0008192126406356692)


 94%|█████████▍| 4702/5001 [13:09<00:49,  6.07it/s]

Epoch: 4701, Loss: 0.0007221954874694347)


 95%|█████████▌| 4751/5001 [13:18<00:56,  4.45it/s]

Epoch: 4751, Loss: 0.0006916519487276673)


 96%|█████████▌| 4802/5001 [13:26<00:32,  6.09it/s]

Epoch: 4801, Loss: 0.0005749334814026952)


 97%|█████████▋| 4852/5001 [13:35<00:22,  6.48it/s]

Epoch: 4851, Loss: 0.0005222385516390204)


 98%|█████████▊| 4901/5001 [13:43<00:22,  4.49it/s]

Epoch: 4901, Loss: 0.0004795412824023515)


 99%|█████████▉| 4952/5001 [13:51<00:07,  6.16it/s]

Epoch: 4951, Loss: 0.00048454420175403357)


100%|██████████| 5001/5001 [14:00<00:00,  5.95it/s]

Epoch: 5001, Loss: 0.000415522517869249)



