In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output

In [5]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [6]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

In [7]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [8]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

In [9]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

In [21]:
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

In [22]:
from tqdm import tqdm

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in tqdm(range(4000)):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

  0%|          | 1/4000 [00:05<6:29:20,  5.84s/it]

Epoch: 1, Loss: 8.68260669708252


  0%|          | 2/4000 [00:11<6:24:49,  5.78s/it]

Epoch: 2, Loss: 8.552105903625488


  0%|          | 3/4000 [00:17<6:39:24,  6.00s/it]

Epoch: 3, Loss: 8.481342315673828


  0%|          | 4/4000 [15:43<409:41:52, 369.10s/it]

Epoch: 4, Loss: 8.430216789245605


  0%|          | 5/4000 [31:53<650:11:40, 585.91s/it]

Epoch: 5, Loss: 8.375986099243164


  0%|          | 6/4000 [1:06:07<1203:49:46, 1085.07s/it]

Epoch: 6, Loss: 8.312374114990234


  0%|          | 7/4000 [1:12:28<947:55:37, 854.63s/it]  

Epoch: 7, Loss: 8.240009307861328


  0%|          | 8/4000 [1:12:34<648:08:50, 584.50s/it]

Epoch: 8, Loss: 8.155845642089844


  0%|          | 9/4000 [1:12:39<447:21:09, 403.53s/it]

Epoch: 9, Loss: 8.07516860961914


  0%|          | 10/4000 [1:30:20<672:11:53, 606.49s/it]

Epoch: 10, Loss: 7.995457649230957


  0%|          | 11/4000 [1:45:48<781:06:57, 704.94s/it]

Epoch: 11, Loss: 7.9174675941467285


  0%|          | 12/4000 [2:03:33<902:06:42, 814.34s/it]

Epoch: 12, Loss: 7.834434509277344


  0%|          | 13/4000 [2:19:56<958:27:30, 865.43s/it]

Epoch: 13, Loss: 7.755948543548584


  0%|          | 14/4000 [2:37:00<1011:17:29, 913.36s/it]

Epoch: 14, Loss: 7.675112724304199


  0%|          | 15/4000 [2:37:06<708:18:34, 639.88s/it] 

Epoch: 15, Loss: 7.585361480712891


  0%|          | 16/4000 [2:53:43<827:07:14, 747.40s/it]

Epoch: 16, Loss: 7.504910469055176


  0%|          | 17/4000 [3:11:32<933:58:24, 844.16s/it]

Epoch: 17, Loss: 7.422988414764404


  0%|          | 18/4000 [3:26:59<961:05:51, 868.90s/it]

Epoch: 18, Loss: 7.344526290893555


  0%|          | 19/4000 [3:42:34<982:47:56, 888.74s/it]

Epoch: 19, Loss: 7.269855976104736


  0%|          | 20/4000 [3:58:44<1009:41:13, 913.28s/it]

Epoch: 20, Loss: 7.187863826751709


  1%|          | 21/4000 [4:13:47<1005:57:27, 910.14s/it]

Epoch: 21, Loss: 7.102153301239014


  1%|          | 22/4000 [4:28:40<999:49:51, 904.82s/it] 

Epoch: 22, Loss: 7.031797885894775


  1%|          | 23/4000 [4:28:45<701:29:49, 635.00s/it]

Epoch: 23, Loss: 6.949347019195557


  1%|          | 24/4000 [4:40:25<722:37:10, 654.28s/it]

Epoch: 24, Loss: 6.873929500579834


  1%|          | 25/4000 [4:40:30<507:31:41, 459.65s/it]

Epoch: 25, Loss: 6.799412727355957


  1%|          | 26/4000 [4:44:02<425:10:05, 385.15s/it]

Epoch: 26, Loss: 6.727956295013428


  1%|          | 27/4000 [5:00:19<621:18:33, 562.98s/it]

Epoch: 27, Loss: 6.649226665496826


  1%|          | 28/4000 [5:17:01<766:20:26, 694.57s/it]

Epoch: 28, Loss: 6.579718112945557


  1%|          | 29/4000 [5:26:39<727:27:24, 659.49s/it]

Epoch: 29, Loss: 6.511331081390381


  1%|          | 30/4000 [5:26:44<510:48:58, 463.21s/it]

Epoch: 30, Loss: 6.441677570343018


  1%|          | 31/4000 [5:29:38<414:57:50, 376.38s/it]

Epoch: 31, Loss: 6.369363307952881


  1%|          | 32/4000 [5:45:40<608:42:05, 552.25s/it]

Epoch: 32, Loss: 6.296319961547852


  1%|          | 33/4000 [6:03:32<780:07:47, 707.96s/it]

Epoch: 33, Loss: 6.228653907775879


  1%|          | 34/4000 [6:13:39<746:50:21, 677.92s/it]

Epoch: 34, Loss: 6.167646408081055


  1%|          | 35/4000 [6:13:45<524:38:35, 476.35s/it]

Epoch: 35, Loss: 6.09614372253418


  1%|          | 36/4000 [6:31:19<715:07:52, 649.46s/it]

Epoch: 36, Loss: 6.0328497886657715


  1%|          | 37/4000 [6:44:32<762:16:51, 692.46s/it]

Epoch: 37, Loss: 5.965193271636963


  1%|          | 38/4000 [6:59:03<821:06:58, 746.09s/it]

Epoch: 38, Loss: 5.900287628173828


  1%|          | 39/4000 [6:59:08<576:22:24, 523.84s/it]

Epoch: 39, Loss: 5.83700704574585


  1%|          | 40/4000 [7:14:19<703:55:10, 639.93s/it]

Epoch: 40, Loss: 5.776490688323975


  1%|          | 41/4000 [7:20:53<622:43:10, 566.25s/it]

Epoch: 41, Loss: 5.719441890716553


  1%|          | 42/4000 [7:20:59<437:38:28, 398.06s/it]

Epoch: 42, Loss: 5.651710510253906


  1%|          | 43/4000 [7:24:34<377:10:35, 343.15s/it]

Epoch: 43, Loss: 5.587913513183594


  1%|          | 44/4000 [7:31:05<393:04:52, 357.71s/it]

Epoch: 44, Loss: 5.5312957763671875


  1%|          | 45/4000 [7:31:11<276:58:27, 252.11s/it]

Epoch: 45, Loss: 5.465268135070801


  1%|          | 46/4000 [7:32:41<223:29:19, 203.48s/it]

Epoch: 46, Loss: 5.4042181968688965


  1%|          | 47/4000 [7:32:47<158:16:59, 144.15s/it]

Epoch: 47, Loss: 5.348357200622559


  1%|          | 48/4000 [7:32:52<112:30:39, 102.49s/it]

Epoch: 48, Loss: 5.2906060218811035


  1%|          | 49/4000 [7:32:57<80:26:56, 73.30s/it]  

Epoch: 49, Loss: 5.233519077301025


  1%|▏         | 50/4000 [7:46:31<323:56:45, 295.24s/it]

Epoch: 50, Loss: 5.17280387878418


  1%|▏         | 51/4000 [7:46:36<228:37:25, 208.42s/it]

Epoch: 51, Loss: 5.119107246398926


  1%|▏         | 52/4000 [8:02:22<471:12:07, 429.67s/it]

Epoch: 52, Loss: 5.059641361236572


  1%|▏         | 53/4000 [8:02:28<331:42:25, 302.55s/it]

Epoch: 53, Loss: 5.008759021759033


  1%|▏         | 54/4000 [8:02:40<236:02:34, 215.35s/it]

Epoch: 54, Loss: 4.942698001861572


  1%|▏         | 55/4000 [8:02:46<167:11:18, 152.57s/it]

Epoch: 55, Loss: 4.889657497406006


  1%|▏         | 56/4000 [8:10:16<264:42:01, 241.61s/it]

Epoch: 56, Loss: 4.8402299880981445


  1%|▏         | 57/4000 [8:10:21<187:03:37, 170.79s/it]

Epoch: 57, Loss: 4.785858631134033


  1%|▏         | 58/4000 [8:25:54<437:19:30, 399.38s/it]

Epoch: 58, Loss: 4.732645511627197


  1%|▏         | 59/4000 [8:30:45<401:38:26, 366.89s/it]

Epoch: 59, Loss: 4.68256950378418


  2%|▏         | 60/4000 [8:30:51<282:57:45, 258.54s/it]

Epoch: 60, Loss: 4.6314568519592285


  2%|▏         | 61/4000 [8:39:18<364:43:03, 333.33s/it]

Epoch: 61, Loss: 4.574556350708008


  2%|▏         | 62/4000 [8:39:24<257:08:43, 235.07s/it]

Epoch: 62, Loss: 4.525256633758545


  2%|▏         | 63/4000 [8:39:38<184:21:52, 168.58s/it]

Epoch: 63, Loss: 4.468155384063721


  2%|▏         | 64/4000 [8:46:31<264:38:39, 242.05s/it]

Epoch: 64, Loss: 4.420312404632568


  2%|▏         | 65/4000 [8:46:37<187:07:48, 171.20s/it]

Epoch: 65, Loss: 4.369113445281982


  2%|▏         | 66/4000 [8:48:22<165:30:45, 151.46s/it]

Epoch: 66, Loss: 4.312048435211182


  2%|▏         | 67/4000 [8:48:28<117:47:21, 107.82s/it]

Epoch: 67, Loss: 4.262474536895752


  2%|▏         | 68/4000 [8:48:34<84:09:28, 77.05s/it]  

Epoch: 68, Loss: 4.216115951538086


  2%|▏         | 69/4000 [8:57:32<235:19:17, 215.51s/it]

Epoch: 69, Loss: 4.175296783447266


  2%|▏         | 70/4000 [8:57:40<167:05:26, 153.06s/it]

Epoch: 70, Loss: 4.119456768035889


  2%|▏         | 71/4000 [8:57:46<118:53:40, 108.94s/it]

Epoch: 71, Loss: 4.079552173614502


  2%|▏         | 72/4000 [8:57:53<85:28:11, 78.33s/it]  

Epoch: 72, Loss: 4.023594856262207


  2%|▏         | 72/4000 [8:57:59<489:10:16, 448.32s/it]


KeyboardInterrupt: 

In [12]:
transformer.eval()
output = transformer(src_data[:1,:], tgt_data[:1, :-1])
output.view(-1,tgt_vocab_size).argmax(1)

tensor([2725, 2509,   50,   48,  860, 1493, 2740, 4440,  651, 2917, 1597, 3143,
        2902, 4723,  338, 4492, 2431, 1172, 1016,  629, 2854, 2594, 2187, 2077,
        2155, 2733, 3375, 4697, 4082,  967, 2986, 4075, 1924, 3582,  414, 3078,
        2465, 4838, 3529, 4824, 4668, 1041, 2531, 4153, 4742,  358, 4656,   84,
         982, 4779,  697, 3157, 2990, 1850, 3642, 4213, 4729, 1891,  633, 3078,
        3518, 3329, 4391, 4929, 4914, 3762, 2491, 1770,  537,  490, 3281,  264,
        3715, 4299, 1384, 1168,   60, 3037, 4944,  793,  874, 4498, 1515, 3951,
        4356, 3468, 3317, 1402, 4541, 4867, 1666, 1450, 2584, 1315, 4093, 2736,
        2337, 4941, 1921])

In [13]:
tgt_data[:1, :-1]

tensor([[4479, 2725, 2509,   50,   48,  860, 1493, 2740, 4440,  651, 2917, 1597,
         3143, 2902, 4723,  338, 4492, 2431, 1172, 1016,  629, 2854, 2594, 2187,
         2077, 2155, 2733, 3375, 4697, 4082,  967, 2986, 4075, 1924, 3582,  414,
          811, 2465, 4838, 3529, 4824, 4668, 1041, 2531, 4153, 4742,  358, 4656,
           84,  982, 4779,  697, 3157, 2990, 1850, 3642, 4213, 4729, 1891,  633,
         3078, 3518, 3329, 4391, 4929, 4914, 3762, 2491, 1770,  537,  490, 3281,
         2836, 3715, 4299, 1384, 1168,   60, 3037, 4944,  793,  874, 4498, 1248,
         3951, 4356, 3468, 3317, 1402, 4541, 4867, 1666, 1450, 2584, 1315, 4093,
         2736, 2337, 4941]])