In [1]:
from zipfile import ZipFile
import pandas as pd
import numpy as np

In [75]:
import math
import os
from tempfile import TemporaryDirectory
from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

class TransformerModel(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def get_src_mask(self, size) -> Tensor:
        """
        Generates a squeare matrix where the each row allows one word more to be seen
        src shape: [seq_length,batch_size]
        then mask shape: [seq_length,seq_length]
        e.g. for seq_length = 5,
        output = 
        [[0., 0., 0., 0., 0.],
        [-inf, 0., 0., 0., 0.],
        [-inf, -inf, 0., 0., 0.],
        [-inf, -inf, -inf, 0., 0.],
        [-inf, -inf, -inf, -inf, 0.]]

        with shape [5,5]
        
        """
        mask = torch.triu(torch.ones(size, size) == 1) # Lower triangular matrix
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float('-inf')) # Convert zeros to -inf
        mask = mask.masked_fill(mask == 1, float(0.0)) # Convert ones to 0
        return mask

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        """
        Arguments:
            src: Tensor, shape ``[seq_len, batch_size]``
            src_mask: Tensor, shape ``[seq_len, seq_len]``

        Returns:
            output Tensor of shape ``[seq_len, batch_size, ntoken]``
        """
        src_length,batch_size = src.size()
        src_mask = self.get_src_mask(src_length)
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src,mask = src_mask,is_causal = True)
        output = self.linear(output)
        return output

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [4]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [69]:
bptt = 32
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """
    Args:
        source: Tensor, shape ``[full_seq_len, batch_size]``
        i: int

    Returns:
        tuple (data, target), where data has shape ``[seq_len, batch_size]`` and
        target has shape ``[seq_len * batch_size]``
    """
    data = source[: -1, i: i+bptt]
    target = source[1:, i: i+bptt]
    return data, target

In [6]:
zf = ZipFile("order_products__prior.csv.zip")
train_df = pd.read_csv(zf.extract("order_products__prior.csv"))

In [7]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 32434489 entries, 0 to 32434488
Data columns (total 4 columns):
 #   Column             Dtype
---  ------             -----
 0   order_id           int64
 1   product_id         int64
 2   add_to_cart_order  int64
 3   reordered          int64
dtypes: int64(4)
memory usage: 989.8 MB


In [8]:
train_df

Unnamed: 0,order_id,product_id,add_to_cart_order,reordered
0,2,33120,1,1
1,2,28985,2,1
2,2,9327,3,0
3,2,45918,4,1
4,2,30035,5,0
...,...,...,...,...
32434484,3421083,39678,6,1
32434485,3421083,11352,7,0
32434486,3421083,4600,8,0
32434487,3421083,24852,9,1


In [10]:
#calculate percentage of data included in top k sales products
train_df.product_id.value_counts()[:999].values.sum()/train_df.product_id.value_counts().values.sum()

0.5398444846780228

In [11]:
#extract top k prodcuts
top_products = train_df.product_id.value_counts()[:999].keys()

In [12]:
#encode the top k products into indices from 1 to k
product_to_idx = {product:i for i,product in enumerate(top_products,start=1)}

In [13]:
product_to_idx

{24852: 1,
 13176: 2,
 21137: 3,
 21903: 4,
 47209: 5,
 47766: 6,
 47626: 7,
 16797: 8,
 26209: 9,
 27845: 10,
 27966: 11,
 22935: 12,
 24964: 13,
 45007: 14,
 39275: 15,
 49683: 16,
 28204: 17,
 5876: 18,
 8277: 19,
 40706: 20,
 4920: 21,
 30391: 22,
 45066: 23,
 42265: 24,
 49235: 25,
 44632: 26,
 19057: 27,
 4605: 28,
 37646: 29,
 21616: 30,
 17794: 31,
 27104: 32,
 30489: 33,
 31717: 34,
 27086: 35,
 44359: 36,
 28985: 37,
 46979: 38,
 8518: 39,
 41950: 40,
 26604: 41,
 5077: 42,
 34126: 43,
 22035: 44,
 39877: 45,
 35951: 46,
 43352: 47,
 10749: 48,
 19660: 49,
 9076: 50,
 21938: 51,
 43961: 52,
 24184: 53,
 34969: 54,
 46667: 55,
 48679: 56,
 25890: 57,
 31506: 58,
 12341: 59,
 39928: 60,
 24838: 61,
 5450: 62,
 22825: 63,
 5785: 64,
 35221: 65,
 28842: 66,
 33731: 67,
 27521: 68,
 44142: 69,
 33198: 70,
 8174: 71,
 20114: 72,
 8424: 73,
 27344: 74,
 11520: 75,
 29487: 76,
 18465: 77,
 28199: 78,
 15290: 79,
 46906: 80,
 9839: 81,
 27156: 82,
 3957: 83,
 43122: 84,
 23909: 85,
 3

In [14]:
#an reversed dictionary for decode the product indices
idx_to_product = {value:key for (key,value) in product_to_idx.items()}

In [15]:
idx_to_product

{1: 24852,
 2: 13176,
 3: 21137,
 4: 21903,
 5: 47209,
 6: 47766,
 7: 47626,
 8: 16797,
 9: 26209,
 10: 27845,
 11: 27966,
 12: 22935,
 13: 24964,
 14: 45007,
 15: 39275,
 16: 49683,
 17: 28204,
 18: 5876,
 19: 8277,
 20: 40706,
 21: 4920,
 22: 30391,
 23: 45066,
 24: 42265,
 25: 49235,
 26: 44632,
 27: 19057,
 28: 4605,
 29: 37646,
 30: 21616,
 31: 17794,
 32: 27104,
 33: 30489,
 34: 31717,
 35: 27086,
 36: 44359,
 37: 28985,
 38: 46979,
 39: 8518,
 40: 41950,
 41: 26604,
 42: 5077,
 43: 34126,
 44: 22035,
 45: 39877,
 46: 35951,
 47: 43352,
 48: 10749,
 49: 19660,
 50: 9076,
 51: 21938,
 52: 43961,
 53: 24184,
 54: 34969,
 55: 46667,
 56: 48679,
 57: 25890,
 58: 31506,
 59: 12341,
 60: 39928,
 61: 24838,
 62: 5450,
 63: 22825,
 64: 5785,
 65: 35221,
 66: 28842,
 67: 33731,
 68: 27521,
 69: 44142,
 70: 33198,
 71: 8174,
 72: 20114,
 73: 8424,
 74: 27344,
 75: 11520,
 76: 29487,
 77: 18465,
 78: 28199,
 79: 15290,
 80: 46906,
 81: 9839,
 82: 27156,
 83: 3957,
 84: 43122,
 85: 23909,
 8

In [16]:
#create new columns "idx" and fill all non top k products with 0
train_df['idx'] = train_df.product_id.map(product_to_idx)

In [17]:
train_df.fillna(0,inplace=True)

In [18]:
train_df['idx'] = train_df.idx.astype(int)

In [19]:
train_df.head(20)

Unnamed: 0,order_id,product_id,add_to_cart_order,reordered,idx
0,2,33120,1,1,200
1,2,28985,2,1,37
2,2,9327,3,0,860
3,2,45918,4,1,0
4,2,30035,5,0,0
5,2,17794,6,1,31
6,2,40141,7,1,0
7,2,1819,8,1,0
8,2,43668,9,0,0
9,3,33754,1,1,114


In [20]:
zf = ZipFile("order_products__train.csv.zip")
val_df = pd.read_csv(zf.extract("order_products__train.csv"))

In [21]:
val_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1384617 entries, 0 to 1384616
Data columns (total 4 columns):
 #   Column             Non-Null Count    Dtype
---  ------             --------------    -----
 0   order_id           1384617 non-null  int64
 1   product_id         1384617 non-null  int64
 2   add_to_cart_order  1384617 non-null  int64
 3   reordered          1384617 non-null  int64
dtypes: int64(4)
memory usage: 42.3 MB


In [22]:
val_df.head(50)

Unnamed: 0,order_id,product_id,add_to_cart_order,reordered
0,1,49302,1,1
1,1,11109,2,1
2,1,10246,3,0
3,1,49683,4,0
4,1,43633,5,1
5,1,13176,6,0
6,1,47209,7,0
7,1,22035,8,1
8,36,39612,1,0
9,36,19660,2,1


In [23]:
#do the same thing to validation dataframe
val_df['idx'] = val_df.product_id.map(product_to_idx)

In [24]:
val_df.fillna(0,inplace=True)

In [25]:
val_df['idx'] = val_df.idx.astype(int)

In [26]:
val_df.head(20)

Unnamed: 0,order_id,product_id,add_to_cart_order,reordered,idx
0,1,49302,1,1,0
1,1,11109,2,1,0
2,1,10246,3,0,151
3,1,49683,4,0,16
4,1,43633,5,1,0
5,1,13176,6,0,2
6,1,47209,7,0,5
7,1,22035,8,1,44
8,36,39612,1,0,0
9,36,19660,2,1,49


In [27]:
train_df.tail(20)

Unnamed: 0,order_id,product_id,add_to_cart_order,reordered,idx
32434469,3421081,20539,5,0,0
32434470,3421081,35221,6,0,65
32434471,3421081,12861,7,0,0
32434472,3421082,17279,1,1,0
32434473,3421082,12738,2,1,0
32434474,3421082,16797,3,0,8
32434475,3421082,43352,4,1,47
32434476,3421082,32700,5,1,0
32434477,3421082,12023,6,0,0
32434478,3421082,47941,7,0,0


In [28]:
val_df.tail(50)

Unnamed: 0,order_id,product_id,add_to_cart_order,reordered,idx
1384567,3420998,31717,11,1,34
1384568,3420998,5337,12,1,814
1384569,3420998,23801,13,0,786
1384570,3420998,46665,14,0,0
1384571,3420998,9366,15,0,787
1384572,3420998,36606,16,1,0
1384573,3420998,5240,17,0,624
1384574,3420998,45002,18,1,0
1384575,3420998,23430,19,1,0
1384576,3420998,8277,20,1,19


In [32]:
#customized way to create dataset

def create_dataset(df,max_len,max_sequence_start=0, max_sequence_end=100000):
    df_values = df[['order_id','idx']].values
    catch_index = df_values[0][0]
    one_row = [0]*max_len
    data = []
    idx = 0
    df_leng = len(df)
    for row in range(df_leng-1):
            
        if df_values[row][1] > 0 and idx < max_len:
            one_row[idx] = df_values[row][1]
            idx += 1
            
        if df_values[row+1][0] != catch_index and one_row != [0]*max_len:
            data.append(torch.tensor(one_row,dtype = torch.long))
            del one_row
            torch.cuda.empty_cache()
            one_row = [0]*max_len
            catch_index = df_values[row+1][0]
            idx = 0

        if row == df_leng -2 and df_values[row+1][1] > 0:
            one_row.append(df_values[row+1][1])
            data.append(torch.tensor(one_row,dtype = torch.long))
            del one_row
            torch.cuda.empty_cache()
            catch_index = df_values[row+1][0]
  
    return torch.stack(data[max_sequence_start:max_sequence_end]).to(device)

In [33]:

# del train_data
# torch.cuda.empty_cache()
train_data = create_dataset(train_df,20,max_sequence_start=0,max_sequence_end=100000)

In [34]:
train_data.size()

torch.Size([100000, 20])

In [35]:
train_data = train_data.t().contiguous()

In [53]:
train_data

tensor([[200, 114, 550,  ..., 700,   1, 491],
        [ 37,  61, 710,  ...,   0,   0, 904],
        [860,   4, 144,  ...,   0,   0,  29],
        ...,
        [  0,   0,   0,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0]], device='cuda:0')

In [37]:
train_data.size()

torch.Size([20, 100000])

In [38]:
# del val_data
# torch.cuda.empty_cache()
val_data = create_dataset(val_df,20,max_sequence_start=0,max_sequence_end=100000)

In [54]:
val_data

tensor([[151,  49, 592,  ..., 178, 275, 588],
        [ 16,  25,  30,  ..., 753,   0,  19],
        [  2, 270, 691,  ...,   0,   0,  34],
        ...,
        [  0,   0,   0,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0]], device='cuda:0')

In [39]:
val_data.size()

torch.Size([100000, 20])

In [40]:
val_data = val_data.t().contiguous()

In [36]:
products_in_train = train_df.product_id.unique()
products_in_val = val_df.product_id.unique()
print("validation set is the subset of training set: ",set(products_in_val).issubset(products_in_train))


validation set is the subset of training set:  False


In [37]:
total_product = np.concatenate((products_in_train,products_in_val),axis=None)
unique_total_product = np.unique(total_product)
print("training set, validation set is the subset of training set: ",set(products_in_train).issubset(unique_total_product),set(products_in_val).issubset(unique_total_product))

training set, validation set is the subset of training set:  True True


In [70]:
del model
torch.cuda.empty_cache()

In [88]:
ntokens = 1000 # size of vocabulary
emsize = 64  # embedding dimension
d_hid = 128 # dimension of the feedforward network model in ``nn.TransformerEncoder``
nlayers = 8  # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``
nhead = 8  # number of heads in ``nn.MultiheadAttention``
dropout = 0.2  # dropout probability
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)

In [89]:
model

TransformerModel(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=128, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=128, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
      )
    )
  )
  (embedding): Embedding(1000, 64)
  (linear): Linear(in_features=64, out_features=1000, bias=True)
)

In [40]:
print(torch.cuda.memory_summary(device="cuda", abbreviated=False))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  18004 KiB |  32768 KiB |  34470 KiB |  16466 KiB |
|       from large pool |  17634 KiB |  32768 KiB |  34018 KiB |  16384 KiB |
|       from small pool |    370 KiB |    370 KiB |    452 KiB |     82 KiB |
|---------------------------------------------------------------------------|
| Active memory         |  18004 KiB |  32768 KiB |  34470 KiB |  16466 KiB |
|       from large pool |  17634 KiB |  32768 KiB |  34018 KiB |  16384 KiB |
|       from small pool |    370 KiB |    370 KiB |    452 KiB |     82 KiB |
|---------------------------------------------------------------

In [90]:
import time

criterion = nn.CrossEntropyLoss()
lr = 1e-2  # learning rate
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       mode='min',
                                                       factor =0.9,
                                                       patience=5,
                                                       threshold=0.001)

def train(model: nn.Module) -> None:
    model.train()  # turn on train mode
    total_loss = 0.
    log_interval = 200
    start_time = time.time()

    num_batches = train_data.size(-1) // bptt
    for batch, i in enumerate(range(0, train_data.size(-1) - bptt , bptt)):
        data, targets = get_batch(train_data, i)
        output = model(data)
        output_flat = output.view(-1, ntokens)
        targets_flat = targets.reshape(-1)
        loss = criterion(output_flat, targets_flat)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = optimizer.param_groups[0]['lr']
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            scheduler.step(cur_loss)
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.4f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.4f} | ppl {ppl:8.4f}')
            total_loss = 0
            start_time = time.time()

def evaluate(model: nn.Module, eval_data: Tensor) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
    with torch.no_grad():
        for i in range(0, eval_data.size(0) - 1, bptt):
            data, targets = get_batch(eval_data, i)
            seq_len = data.size(0)
            output = model(data)
            target_flat = targets.reshape(-1)
            output_flat = output.view(-1, ntokens)
            total_loss += seq_len * criterion(output_flat, target_flat).item()
    return total_loss / (len(eval_data) - 1)

In [91]:
best_val_loss = float('inf')
epochs = 100

with TemporaryDirectory() as tempdir:
    best_model_params_path = os.path.join(tempdir, "best_model_params.pt")

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train(model)
        val_loss = evaluate(model, val_data)
        val_ppl = math.exp(val_loss)
        elapsed = time.time() - epoch_start_time
        print('-' * 89)
        print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
            f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
        print('-' * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_params_path)

        scheduler.step(best_val_loss)
    model.load_state_dict(torch.load(best_model_params_path)) # load best model states

| epoch   1 |   200/ 3125 batches | lr 0.0100 | ms/batch 33.16 | loss 2.0986 | ppl   8.1550
| epoch   1 |   400/ 3125 batches | lr 0.0100 | ms/batch 32.52 | loss 1.9449 | ppl   6.9932
| epoch   1 |   600/ 3125 batches | lr 0.0100 | ms/batch 32.22 | loss 1.9497 | ppl   7.0268
| epoch   1 |   800/ 3125 batches | lr 0.0100 | ms/batch 32.49 | loss 1.9724 | ppl   7.1882
| epoch   1 |  1000/ 3125 batches | lr 0.0100 | ms/batch 32.32 | loss 1.9508 | ppl   7.0344
| epoch   1 |  1200/ 3125 batches | lr 0.0100 | ms/batch 32.79 | loss 1.9608 | ppl   7.1049
| epoch   1 |  1400/ 3125 batches | lr 0.0100 | ms/batch 32.95 | loss 1.9101 | ppl   6.7541
| epoch   1 |  1600/ 3125 batches | lr 0.0100 | ms/batch 32.47 | loss 1.9624 | ppl   7.1166
| epoch   1 |  1800/ 3125 batches | lr 0.0100 | ms/batch 32.39 | loss 1.9408 | ppl   6.9642
| epoch   1 |  2000/ 3125 batches | lr 0.0100 | ms/batch 32.28 | loss 1.9309 | ppl   6.8958
| epoch   1 |  2200/ 3125 batches | lr 0.0100 | ms/batch 32.24 | loss 1.9280 | p

In [92]:
torch.save(model.state_dict,"trans4rec.tar")

In [None]:
test_loss = evaluate(model, test_data)
test_ppl = math.exp(test_loss)
print('=' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | '
      f'test ppl {test_ppl:8.2f}')
print('=' * 89)

In [124]:
data, target = get_batch(train_data,0)

In [125]:
data

tensor([[37,  4,  2, 11,  1, 34, 10, 23, 49, 26, 11, 24,  1,  1, 38,  6, 46,  1,
          2,  7, 47, 12,  4,  2, 28, 17,  2, 42,  8, 15,  1, 16],
        [31,  0,  0,  5,  0,  6,  0,  0,  0,  0, 38,  0,  0,  0,  0, 39,  0,  4,
          6,  5,  2,  0, 34, 16,  0,  4,  0,  0,  0, 47,  0,  0],
        [ 0,  0,  0,  0,  0, 28,  0,  0,  0,  0,  6,  0,  0,  0,  0,  0,  0,  6,
         31,  0,  0,  0, 21, 21,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  3,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         37,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0, 25,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         40,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0, 

In [134]:
data.size()

torch.Size([19, 32])

In [133]:
len(data[0])

32

In [113]:
input = torch.zeros(20).long()

In [115]:
input[0] = 16

In [118]:
input = input.to(device)

In [119]:
with torch.no_grad():
    output = model(input)

In [139]:
output.size()

torch.Size([20, 20, 51])

In [145]:
val_data[0].size()

torch.Size([100000])

In [140]:
output_flat = output.view(-1, ntokens)

In [142]:
output_flat.size()

torch.Size([400, 51])

In [143]:
torch.argmax(output_flat[0])

tensor(3, device='cuda:0')