In [1]:
from sympy.printing.pytorch import torch
from torch.nn.functional import dropout

from model import *
from tokenizer import *

In [67]:
class TransformerEncoderModel(nn.Module):
    """Encoder-only transformer that maps token ids to per-position vocabulary scores.

    The pipeline is InputEmbeddings -> PositionalEncoding -> a stack of
    ``layers`` EncoderBlocks wrapped in Encoder -> ProjectionLayer. All of these
    building blocks come from the project-local ``model`` module; their
    internals are not visible from this notebook.

    Args:
        d_model: Width of the embedding / hidden representation.
        layers: Number of EncoderBlocks to stack.
        heads: Number of attention heads per MultiHeadAttentionBlock.
        d_ff: Hidden width passed to each FeedForwardBlock.
        dropout: Dropout probability forwarded to every sub-module that takes one.
        vocab_size: Vocabulary size for the embedding table and the output projection.
        seq_len: Sequence length passed to PositionalEncoding.
    """

    def __init__(self, d_model, layers, heads, d_ff, dropout, vocab_size, seq_len):
        super().__init__()
        self.embedding = InputEmbeddings(d_model, vocab_size)
        self.position_enc = PositionalEncoding(d_model, seq_len, dropout)
        self.projection = ProjectionLayer(d_model, vocab_size)

        # Each layer gets its own attention and feed-forward modules (no weight
        # sharing). Arguments are evaluated left to right, so the sub-modules are
        # constructed in the order attention -> feed-forward -> block.
        encoder_blocks = [
            EncoderBlock(
                d_model,
                MultiHeadAttentionBlock(d_model, heads, dropout),
                FeedForwardBlock(d_model, d_ff, dropout),
                dropout,
            )
            for _ in range(layers)
        ]

        self.encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))

    def forward(self, x, mask):
        """Run the encoder stack and project to the vocabulary.

        Args:
            x: Tensor of token ids, either a single sequence or a batch of
                sequences.
            mask: Attention mask handed unchanged to the Encoder. Its polarity
                (which value means "blocked") is decided by the attention block
                in ``model`` — TODO confirm against that implementation.

        Returns:
            The ProjectionLayer output for every position. A leading batch axis
            is always present, including for a single unbatched sequence.
        """
        x = self.embedding(x)
        # A single unbatched sequence embeds to a 2-D tensor, so it needs a batch
        # axis added; input that was already batched is passed through as-is
        # instead of being unsqueezed a second time.
        # Assumes InputEmbeddings appends exactly one trailing feature axis — TODO confirm.
        if x.dim() == 2:
            x = x.unsqueeze(0)
        x = self.position_enc(x)
        x = self.encoder(x, mask)
        return self.projection(x)

In [68]:
# Small encoder-only model; keyword arguments make each hyperparameter explicit.
transencoder = TransformerEncoderModel(
    d_model=64,
    layers=6,
    heads=8,
    d_ff=256,
    dropout=0.1,
    vocab_size=65536,
    seq_len=1000,
)

In [69]:
# NOTE(review): absolute, user-specific path — this cell will not run on another
# machine. Whether `tokenizer.encode` reads the file or tokenises the path string
# itself is not visible from here; confirm.
tokens = tokenizer.encode("/Users/daniilogorodnikov/Documents/1706.03762v7.pdf")

# (1, 1000, 1000) int mask holding 1 strictly above the main diagonal and 0 elsewhere.
# NOTE(review): how the attention block interprets 0 vs 1 is not visible here. If it
# blocks positions where mask == 0, every position could only attend to *later*
# positions and the last row would be fully masked — confirm the intended polarity.
mask = torch.ones(1, 1000, 1000).triu(diagonal=1).int()

In [71]:
# Inference only: switch dropout off so the predictions are repeatable (the
# residual coding further down can only round-trip if `pred` can be recomputed
# exactly) and skip autograd bookkeeping. The module is called directly rather
# than through .forward() so that any registered nn.Module hooks run.
transencoder.eval()
with torch.no_grad():
    out = transencoder(torch.tensor(tokens[:1000]), mask)

In [85]:
# Display the first 1000 token ids, i.e. the slice that was fed to the model.
tokens[:1000]

[9552,
 17478,
 11569,
 11829,
 2597,
 36618,
 12595,
 14112,
 12320,
 28514,
 27146,
 15420,
 8239,
 18025,
 27764,
 25970,
 8239,
 18028,
 24948,
 25924,
 25955,
 28516,
 25888,
 12108,
 25966,
 26484,
 26656,
 13360,
 12593,
 8254,
 15882,
 29556,
 29285,
 24941,
 2680,
 55989,
 23115,
 37860,
 46737,
 48875,
 22516,
 37141,
 53701,
 9566,
 31848,
 25359,
 37443,
 11131,
 9012,
 57776,
 31311,
 58174,
 45227,
 53333,
 62640,
 51222,
 8042,
 36222,
 48435,
 37129,
 4704,
 29030,
 48571,
 63113,
 8336,
 50211,
 40799,
 9880,
 15708,
 7858,
 34591,
 49097,
 63737,
 64979,
 14335,
 63135,
 26097,
 8260,
 23065,
 8991,
 40542,
 8004,
 22078,
 5154,
 19251,
 21820,
 15517,
 7998,
 9954,
 28692,
 22934,
 9599,
 60070,
 57632,
 51876,
 16335,
 42921,
 59835,
 50163,
 54143,
 49607,
 64161,
 19115,
 23782,
 61147,
 13459,
 12399,
 63008,
 56372,
 31081,
 21093,
 62659,
 54029,
 42745,
 35645,
 52675,
 24763,
 59760,
 38073,
 20026,
 15119,
 30187,
 56275,
 64321,
 25673,
 16252,
 6687,
 4548

In [74]:
# Highest-scoring vocabulary index at every position.
pred = out.argmax(dim=-1)

In [77]:
# Signed difference between the true token ids and the model's predicted ids.
# The displayed result is 2-D (see the output below), so the 1-D token tensor is
# broadcast against `pred`.
residuals = torch.tensor(tokens[:1000]) - pred

In [78]:
# Display the residuals; the values shown span tens of thousands in both directions.
residuals

tensor([[-45836,  -6708, -17723, -44497, -35813, -17875,  -1307, -14099, -31154,
         -33417, -14162, -18901, -24148,  -6494, -29836, -27216, -24148, -26073,
          -2852,  -1892,  23723, -13040,  -9375, -28869,  12461, -38479,    633,
         -47104, -15329, -50289,   3991,  -6098, -15892,  -9453, -14597,  22084,
          19048,  23105,  35134,  31357, -32289,  28786,  40230, -25862, -18731,
         -36754, -17387, -15330, -36641,  39046,  -6846,  56413,   4188,   7255,
          61484,  -7123, -36440,  24761,  16994,  27083, -36153,  -5636,  44779,
          36012, -36586,   -615,   6497, -34771, -40615,   5665,  30367,  -3568,
           7267,  32953, -42595,  61052, -12836, -11502,   9796, -13569,  28498,
         -57520, -24683, -13210,  -7703, -31349, -37077,  -9954,   1186, -15549,
         -29917, -52683,  55134,  22166,  17883, -39389, -11307,  30600,  34118,
          -9576,  30699,  16358, -41208, -29333,   9667, -21360, -24029,  44138,
          16622,  -8968,   -

In [79]:
def quantize_residuals(residuals, bits=8):
    """Uniformly quantise ``residuals`` onto ``2**bits`` integer levels.

    The value range ``[residuals.min(), residuals.max()]`` is mapped linearly
    onto ``0 .. 2**bits - 1`` and rounded. The mapping is lossy whenever the
    range is wider than ``2**bits - 1``.

    Args:
        residuals: Tensor of values to quantise.
        bits: Number of bits per quantised value (must be >= 1).

    Returns:
        Tuple ``(quantized, scale, min_val)``:
        ``quantized`` holds the integer levels (``torch.uint8`` for
        ``bits <= 8``, otherwise a wider integer dtype so values cannot wrap),
        ``scale`` is the step size as a 0-dim tensor, and ``min_val`` is the
        offset as a 0-dim tensor. ``quantized * scale + min_val`` approximates
        the input.

    Raises:
        ValueError: If ``bits`` is smaller than 1.
    """
    if bits < 1:
        raise ValueError(f"bits must be >= 1, got {bits}")

    min_val = residuals.min()
    max_val = residuals.max()
    scale = (max_val - min_val) / (2**bits - 1)
    # A constant input has zero range; dividing by a zero scale would produce
    # NaNs. Any non-zero step works because every level is 0 in that case.
    if scale == 0:
        scale = torch.ones_like(scale)

    # uint8 can only hold 8-bit levels; wider settings need a wider container,
    # otherwise the cast silently wraps around.
    if bits <= 8:
        level_dtype = torch.uint8
    elif bits <= 31:
        level_dtype = torch.int32
    else:
        level_dtype = torch.int64

    quantized = torch.round((residuals - min_val) / scale).to(level_dtype)
    return quantized, scale, min_val

# Quantise with the default 8 bits.
# NOTE(review): the residuals shown above span far more than 255 distinct values,
# so this step is lossy — the decoded ids further down differ from the original
# tokens (e.g. 9371 vs 9552 for the first one).
quantized_residuals, scale, min_val = quantize_residuals(residuals)

In [83]:
def decode(quantized_residuals, predicted_tokens, scale, min_val):
    """Rebuild token ids from quantised residuals and the model's predictions.

    Args:
        quantized_residuals: Integer quantisation levels.
        predicted_tokens: Token ids predicted by the model.
        scale: Quantisation step size.
        min_val: Offset that corresponds to level 0.

    Returns:
        ``predicted_tokens`` plus the de-quantised residuals, rounded to the
        nearest integer and cast to int64.
    """
    # Undo the linear mapping level -> value.
    rescaled = quantized_residuals.float() * scale
    integer_residuals = (rescaled + min_val).round().long()
    return predicted_tokens + integer_residuals

In [84]:
# Reconstruct the token ids from the predictions plus the de-quantised residuals.
# The values displayed below are close to, but not equal to, the original tokens
# (first id 9371 vs 9552), so this round trip is not lossless as written.
decode(quantized_residuals, pred, scale, min_val)

tensor([[ 9371, 17338, 11397, 11816,  2437, 36598, 12578, 14333, 12522, 28468,
         26928, 15421,  8466, 18174, 27653, 25749,  8466, 18171, 24970, 25990,
         26017, 28682, 25905, 12034, 25740, 26479, 26708, 13443, 12538,  8007,
         16091, 29309, 29290, 25036,  2897, 56183, 23332, 38038, 46938, 48835,
         22347, 37161, 53827,  9498, 31679, 25135, 37437, 11077,  9177, 57580,
         31309, 58187, 45239, 53291, 62603, 50995,  8006, 36250, 48195, 37346,
          4884, 28823, 48668, 62938,  8446, 50004, 41013,  9682, 15830,  7900,
         34537, 49333, 63683, 64850, 14428, 63028, 26059,  8395, 22993,  9184,
         40348,  7957, 21835,  4988, 19102, 21715, 15616,  8091,  9955, 28857,
         22904,  9737, 59855, 57744, 51752, 16235, 42861, 60050, 50375, 54361,
         49723, 64055, 19328, 23670, 61204, 13409, 12507, 63244, 56504, 31193,
         21055, 62676, 53904, 42561, 35622, 52847, 24964, 59660, 38249, 19782,
         15181, 30007, 56349, 64265, 25466, 16125,  

In [28]:
# Hyperparameters for the hand-assembled encoder built in the following cells.
d_model = 64   # embedding / hidden width
d_ff = 256     # feed-forward hidden width
h = 8          # attention heads
N = 6          # number of encoder blocks
dropout = 0.1  # dropout probability handed to every block

In [29]:
# Build N independent encoder blocks; each one owns its own attention and
# feed-forward modules. Arguments are evaluated left to right, so the attention
# module is constructed before the feed-forward module for every block.
encoder_blocks = [
    EncoderBlock(
        d_model,
        MultiHeadAttentionBlock(d_model, h, dropout),
        FeedForwardBlock(d_model, d_ff, dropout),
        dropout,
    )
    for _ in range(N)
]

In [30]:
# Wrap the blocks in nn.ModuleList so they are registered as sub-modules, then
# hand them to the Encoder from the project-local `model` module.
encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))

In [35]:
# (1, 1000, 1000) int32 mask: 1 strictly above the main diagonal, 0 on and below it.
mask = torch.ones(1, 1000, 1000).triu(diagonal=1).int()

In [36]:
# NOTE(review): `src` is not defined in any visible cell — presumably left over
# from earlier kernel state; this cell fails on a fresh kernel unless it is
# created somewhere not shown here.
src.shape

torch.Size([1, 1000, 64])

In [44]:
# Run the hand-assembled encoder stack on `src` with the triangular mask.
enc_out = encoder(src, mask)

In [45]:
# Linear map from 64 features (the d_model used above) down to 32.
bottleneck = nn.Linear(64, 32)

In [51]:
# NOTE(review): this `x` is overwritten by the projection cell below before any
# visible cell reads it — the bottleneck output looks unused; confirm.
x = bottleneck(enc_out)

In [50]:
# Output projection using the same sizes as TransformerEncoderModel above
# (d_model=64, vocab_size=65536).
projection_layer = ProjectionLayer(64, 65536)

In [55]:
# Project only the last index along axis 1 of the encoder output
# (presumably the final sequence position — confirm). Rebinds `x`.
x = projection_layer(enc_out[:, -1])

In [57]:
# Keep only the index of the best-scoring vocabulary entry; the score itself is discarded.
_, next_word = x.max(dim=1)

In [58]:
# Append the predicted token id as a new (1, 1) entry along dim 1.
# NOTE(review): `source` and `device` are not defined in any visible cell —
# presumably left over from earlier kernel state.
# NOTE(review): `enc_out` appears to be 3-D (it is indexed as enc_out[:, -1] above
# and `src` is [1, 1000, 64]) while the appended tensor is (1, 1); torch.cat needs
# operands of matching rank, so this likely raises as written. A running tensor of
# token ids may have been intended as the first operand — confirm.
decoder_input = torch.cat([enc_out, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1)

tensor([3419])