In [79]:
'''
    Transformer go boom boom attention
    There is also a version of this written in CRC server to make use of GPU cluster :)
    
    Better data generation, just generate all 3^3 combinations possible
    Plot loss vs epoch
    Try network for different points over training
    Separate testing data 
    
    Pursue if victory:
       Try what model does over data it never saw in training i.e. does it learn a general function?
           if not --> it just is cheating and memorizes
           
    Show probability going from probability density --> number prediction (!)
    
    Try doing custom embedding that can work on floats, not just on integer (it would be just a linear nn)
    
    DO POSITIONAL ENCODING
'''

import numpy as np
import torch
import torch.nn as nn
import time
import math

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [154]:
class TransformerModel(nn.Module):
    def __init__(self, n_token, d_model, n_head, n_layers):
        super().__init__()
        
        self.d_model = d_model
        self.embedding = nn.Embedding(n_token, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=n_head,
            num_encoder_layers=n_layers,
            num_decoder_layers=n_layers)
        self.out = nn.Linear(d_model, n_token) # Learned linear at the end where output of decoder is run through
        
    def forward(self, src, tgt):
        
        src = self.embedding(src) * math.sqrt(self.d_model)
        tgt = self.embedding(tgt) * math.sqrt(self.d_model)

        print(src.size())
        src = src.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)
        
        transformer_output = self.transformer(src, tgt)
        output = self.out(transformer_output)
        print(f'Before permute:\n Size: {output.size()}\nTensor: {output}')
        
        output = output.permute(1, 0, 2) # Permute so that batch num is first dimension again
        #print(f'After permute:\n Size: {output.size()}\nTensor: {output}')
        
        return output

In [155]:
# Let's define an instance :)
d_model = 5 # Let's try embedding brotha
n_token = 7 # Only do 0, 1, 2?
n_head = 1 # Remember that d_model=d_embedding must be divisible by number of n_head
n_layers = 2

model = TransformerModel(n_token, d_model, n_head, n_layers)

print(f'The number of tokens are {n_token}, six categories to calculate probability distributions over (i.e. 0, 1, 2, 3, 4 + sos + eos)')

# Check da dimension
src = torch.randint(size=(2, 3), low=0, high=2)
tgt = src

output = model(src, src)
print(f'src:\n  Size: {src.size()}\n  Tensor: {src}')
print(f'tgt:\n  Size: {tgt.size()}\n  Tensor: {tgt}')
print(f'output:\n  Size: {output.size()}\n  Tensor: {output}\n')

src = torch.randint(size=(2, 3), low=0, high=2)
tgt = src
tgt_input = tgt[:, :-1]
tgt_expected = tgt[:, 1:]
output = model(src, tgt_input)

print('----------------------------------------------------------------------------')
print(f'src:\n  Size: {src.size()}\n  Tensor: {src}')
print(f'tgt_input:\n  Size: {tgt_input.size()}\n  Tensor: {tgt_input}')
print(f'output:\n  Size: {output.size()}\n  Tensor: {output}')

The number of tokens are 7, six categories to calculate probability distributions over (i.e. 0, 1, 2, 3, 4 + sos + eos)
torch.Size([2, 3, 5])
Before permute:
 Size: torch.Size([3, 2, 7])
Tensor: tensor([[[-0.1093,  0.4623, -0.2743,  0.1649,  0.1041, -0.5462,  1.2153],
         [-0.2552,  0.5584,  0.0601,  0.0748,  0.0484, -1.0011,  1.3124]],

        [[-1.0152,  0.2020,  0.2753,  0.6513, -0.8838, -0.3868,  0.9469],
         [-0.7572,  0.4402,  0.0965,  0.5907, -0.2909, -0.8904,  1.2949]],

        [[-0.6724,  0.1188, -0.0306,  0.5339, -0.8580,  0.0553,  0.9333],
         [-0.2526,  0.5380,  0.0575,  0.0960, -0.0371, -0.9735,  1.3748]]],
       grad_fn=<AddBackward0>)
src:
  Size: torch.Size([2, 3])
  Tensor: tensor([[0, 1, 1],
        [0, 1, 0]])
tgt:
  Size: torch.Size([2, 3])
  Tensor: tensor([[0, 1, 1],
        [0, 1, 0]])
output:
  Size: torch.Size([2, 3, 7])
  Tensor: tensor([[[-0.1093,  0.4623, -0.2743,  0.1649,  0.1041, -0.5462,  1.2153],
         [-1.0152,  0.2020,  0.2753,  0.

In [156]:
# Let's make some 1s and 0s 
def construct_data(size, seq_len):
    SOS_token = 5
    EOS_token = 6
    data = np.random.randint(size = (size, seq_len), low=0, high=2)
    target = np.empty((size, seq_len))
    data = torch.tensor(data).to()
    target = torch.tensor(target)
    target[:, 0] = data[:, 1] + data[:, 2]
    target[:, 1] = data[:, 0] + data[:, 2]
    target[:, 2] = data[:, 0] + data[:, 1]
    SOS_ = torch.ones((size, 1)) * SOS_token
    EOS_ = torch.ones((size, 1)) * EOS_token
    
    data = torch.cat((SOS_, data, EOS_), 1)
    target = torch.cat((SOS_, target, EOS_), 1)
    
    return data, target

def batchify(data, target):
    numMiniBatch = int(math.floor(data.shape[0]/100.0))
    inputMiniBatches = data.chunk(numMiniBatch)
    outputMiniBatches = target.chunk(numMiniBatch)
    
    return numMiniBatch, inputMiniBatches, outputMiniBatches

In [157]:
seq_len = 3
size = 10000
data, target = construct_data(size, seq_len)

# Just for confirmation to see we constructed our data the way we were supposed to
print(f' An instance of the input sequence at the 0th index:\n{data[0, :]}')
print(f' An instance of the target sequence at the 0th index:\n{target[0, :]}')

# Batch
batch_num, data_batched, target_batched = batchify(data, target)

# Let's check that batching works properly
print(data_batched[0][0])
print(target_batched[0][0])

 An instance of the input sequence at the 0th index:
tensor([5., 1., 1., 0., 6.])
 An instance of the target sequence at the 0th index:
tensor([5., 1., 1., 2., 6.], dtype=torch.float64)
tensor([5., 1., 1., 0., 6.])
tensor([5., 1., 1., 2., 6.], dtype=torch.float64)


In [162]:
# Train train choo choo, simply and absolutely the best form of transportation

opt = torch.optim.SGD(model.parameters(), lr=0.05)
loss_fn = nn.CrossEntropyLoss()

def train(model, batch_num, data_batched, target_batched):
    model.train()
    total_loss = 0
    start_time = time.time()
    for batch in range(batch_num):
        data = data_batched[batch].type(torch.long)
        target = target_batched[batch].type(torch.long)
          
        target_in = target[:, :-1]
        target_expected = target[:, 1:]
        
        pred = model(data, target_in)

        pred = pred.permute(0, 2, 1) 

        loss = (loss_fn(pred, target_expected)).type(torch.float)
        
        opt.zero_grad()
        loss.backward()
        opt.step()
    
        total_loss += loss.detach().item()
        
    return total_loss
    
def train_loop(model, n_epochs, batch_num, data_batched, target_batched):
    for i in range(n_epochs):
        loss = train(model, batch_num, data_batched, target_batched)
        if i % 10 == 0:
            print(f'Epoch: {i}\nTotal Loss: {loss}')
            print(f'----------------------------------')

In [163]:
n_epochs = 200
train_loop(model, n_epochs, batch_num, data_batched, target_batched)

torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-1.3612,  2.7335,  2.9362,  ..., -2.8607, -2.1814,  2.1626],
         [-0.2127,  3.5005,  2.4470,  ..., -3.1419, -2.8231,  2.5709],
         [ 3.0425,  3.9188,  0.0778,  ..., -2.1413, -2.6699,  1.7543],
         ...,
         [ 2.7385,  4.1096,  0.4525,  ..., -2.5018, -2.9331,  2.1103],
         [-0.9097,  3.1036,  2.7908,  ..., -2.9073, -2.4042,  2.1566],
         [-0.0374,  3.6129,  2.3232,  ..., -3.2041, -2.9925,  2.7325]],

        [[-1.2371,  2.7947,  2.7903,  ..., -2.9982, -2.4223,  2.5152],
         [-0.7377,  2.9434,  2.0754,  ..., -2.6899, -2.8326,  2.6060],
         [ 3.1428,  3.2876, -0.9886,  ..., -2.3103, -3.3648,  3.3460],
         ...,
         [ 3.3914,  3.0775, -1.2933,  ..., -2.1006, -3.1480,  3.1940],
         [-0.9706,  2.9051,  2.8065,  ..., -2.9156, -2.2243,  2.1716],
         [-1.1555,  2.7478,  2.8634,  ..., -2.8839, -2.1213,  2.1553]],

        [[-1.3795,  2.7012,  2.9654, 

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 3.9172,  2.1658, -2.2436,  ..., -1.0186, -2.0938,  2.1240],
         [ 0.1154,  3.5539,  2.3393,  ..., -2.8166, -2.4746,  1.9127],
         [ 3.8863,  3.9211, -1.0255,  ..., -1.9941, -3.1939,  2.3711],
         ...,
         [-0.9846,  2.9691,  2.7737,  ..., -3.0327, -2.4232,  2.4334],
         [-1.4886,  2.7173,  2.9798,  ..., -2.8148, -2.1822,  2.1139],
         [ 3.7525,  4.1067, -0.7115,  ..., -2.1967, -3.2615,  2.4117]],

        [[ 3.6043,  1.8050, -2.1944,  ..., -1.0360, -1.9283,  2.2754],
         [ 1.3792,  4.2567,  1.5302,  ..., -2.9117, -3.2120,  2.3395],
         [ 3.2802,  3.6705, -0.8143,  ..., -2.4929, -3.5523,  3.3642],
         ...,
         [-1.6298,  2.5746,  3.0537,  ..., -2.7228, -1.9918,  1.9401],
         [ 0.8027,  4.0550,  1.9323,  ..., -3.0252, -3.0767,  2.3855],
         [ 3.63

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-0.6117,  3.2382,  2.7155,  ..., -2.9513, -2.4663,  2.1788],
         [-1.4550,  2.3593,  2.9094,  ..., -2.8467, -1.9245,  2.2427],
         [-2.2280,  2.0024,  3.0244,  ..., -2.5261, -1.8436,  2.0496],
         ...,
         [ 3.1432,  3.9924,  0.0532,  ..., -2.2556, -2.8216,  1.9453],
         [-0.8574,  3.0769,  2.7520,  ..., -3.0641, -2.5443,  2.4867],
         [-0.4773,  3.2759,  2.6627,  ..., -2.9084, -2.4230,  2.0818]],

        [[-0.4982,  3.2166,  2.6102,  ..., -3.1014, -2.5844,  2.4991],
         [-1.2599,  2.6810,  2.8382,  ..., -3.0054, -2.3137,  2.5135],
         [-2.1476,  2.1393,  3.1030,  ..., -2.3930, -1.7253,  1.6671],
         ...,
         [ 1.1803,  3.4732,  1.2083,  ..., -3.2777, -3.2818,  3.5266],
         [-0.7788,  3.0392,  2.6425,  ..., -3.1445, -2.6306,  2.7147],
         [-0.81

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-0.9157,  3.1004,  2.8477,  ..., -2.9435, -2.4228,  2.2035],
         [ 3.9060,  2.0547, -2.2207,  ..., -0.4178, -1.5097,  1.0258],
         [-1.4545,  2.6010,  3.0379,  ..., -2.8183, -2.0245,  2.0636],
         ...,
         [ 2.6266,  4.3703,  0.6148,  ..., -2.8644, -3.4041,  2.6481],
         [ 3.8837,  2.0216, -2.3601,  ..., -0.7649, -1.9227,  1.8375],
         [ 2.1238,  4.3083,  0.9594,  ..., -2.6307, -3.1828,  2.1327]],

        [[-0.7042,  3.1672,  2.6603,  ..., -3.1559, -2.6979,  2.6834],
         [ 2.3606,  1.8221, -1.0745,  ..., -1.9031, -2.2728,  3.1743],
         [-1.5286,  2.4857,  2.5956,  ..., -2.7495, -2.4767,  2.5922],
         ...,
         [ 3.4029,  3.4628, -1.0453,  ..., -2.3379, -3.4251,  3.3184],
         [ 3.8217,  2.3558, -1.8357,  ..., -1.4713, -2.1953,  2.4405],
         [ 2.60

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 2.6452,  4.4444,  0.5685,  ..., -2.9381, -3.5753,  2.8362],
         [-1.1174,  2.9943,  2.9412,  ..., -2.7907, -2.2515,  1.9518],
         [-1.7381,  2.4384,  3.1039,  ..., -2.8015, -2.0001,  2.1438],
         ...,
         [-2.9590,  1.0319,  3.1759,  ..., -2.1770, -1.0062,  1.6925],
         [ 3.2880,  3.9416, -0.4477,  ..., -1.8341, -2.9295,  1.7568],
         [-1.8565,  2.3503,  3.0959,  ..., -2.7958, -2.0162,  2.2231]],

        [[ 2.9285,  3.7920, -0.4322,  ..., -2.7204, -3.7149,  3.5411],
         [-0.2911,  3.4892,  2.5679,  ..., -3.1075, -2.7556,  2.4627],
         [-1.5117,  2.5426,  2.9762,  ..., -2.9351, -2.1699,  2.3972],
         ...,
         [-2.8772,  1.3465,  3.2641,  ..., -2.1246, -1.1295,  1.4565],
         [ 3.0050,  3.3186, -0.7617,  ..., -2.5410, -3.4116,  3.5650],
         [-1.51

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-2.6153,  1.5575,  3.2596,  ..., -2.4104, -1.3434,  1.8070],
         [-1.4709,  2.6746,  3.0314,  ..., -2.9028, -2.1803,  2.2567],
         [-2.2450,  1.9953,  3.2111,  ..., -2.6176, -1.7007,  2.0110],
         ...,
         [ 3.1598,  4.3177, -0.0173,  ..., -2.3571, -3.2661,  2.2051],
         [-1.6028,  2.5671,  3.1146,  ..., -2.7629, -1.9734,  1.9823],
         [-2.4818,  1.3967,  3.0975,  ..., -2.4676, -1.2950,  2.0119]],

        [[-2.6889,  1.4976,  3.1693,  ..., -2.3914, -1.4462,  1.9471],
         [ 1.9633,  3.6040,  0.0407,  ..., -2.3404, -3.4854,  2.9156],
         [-0.7081,  3.0678,  2.5859,  ..., -3.2133, -2.7149,  2.8891],
         ...,
         [ 2.7584,  3.7535, -0.3201,  ..., -2.8006, -3.7253,  3.6413],
         [ 0.6375,  4.0771,  1.9088,  ..., -3.1155, -3.3910,  2.7554],
         [-2.67

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 1.8040,  4.1275,  1.2608,  ..., -2.6505, -2.8838,  1.9670],
         [-2.7691,  1.2900,  3.1493,  ..., -2.3700, -1.3082,  1.9879],
         [ 2.4976,  4.3402,  0.7413,  ..., -2.7660, -3.2339,  2.3879],
         ...,
         [ 3.9455,  3.0415, -1.7075,  ..., -1.8367, -2.9329,  2.8782],
         [-0.2905,  3.5321,  2.5859,  ..., -3.0703, -2.7496,  2.3914],
         [ 4.0485,  3.4346, -1.3326,  ..., -1.4118, -2.4822,  1.5812]],

        [[ 3.5608,  4.1249, -0.6608,  ..., -2.5439, -3.6535,  3.0990],
         [-3.1924,  0.8900,  3.1221,  ..., -1.8131, -0.8775,  1.3147],
         [ 3.6671,  3.4167, -1.3641,  ..., -2.0401, -3.3059,  3.0161],
         ...,
         [ 3.1392,  1.1637, -2.2526,  ..., -0.8075, -1.5464,  2.2041],
         [-0.0908,  3.6352,  2.4915,  ..., -3.0590, -2.7935,  2.3630],
         [ 3.13

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 3.4143,  3.9777, -0.5458,  ..., -1.8831, -3.0076,  1.8784],
         [-2.7737,  1.2331,  3.3255,  ..., -2.1321, -0.8469,  1.3242],
         [-1.0359,  2.9021,  2.7945,  ..., -3.1219, -2.4990,  2.6850],
         ...,
         [-1.4404,  2.7012,  3.0831,  ..., -2.7980, -2.0463,  1.9998],
         [ 3.4569,  4.2309, -0.2970,  ..., -2.2687, -3.2140,  2.2143],
         [-1.2063,  2.9450,  2.9869,  ..., -2.9107, -2.3048,  2.1823]],

        [[ 2.6766,  3.7806, -0.2855,  ..., -2.7909, -3.7702,  3.6448],
         [-2.9890,  1.1656,  3.2221,  ..., -2.2056, -1.1425,  1.7457],
         [ 1.4716,  3.6265,  0.5177,  ..., -2.8181, -3.6839,  3.4736],
         ...,
         [-2.1669,  1.9719,  3.1933,  ..., -2.6657, -1.6618,  2.0745],
         [ 3.3395,  3.3772, -1.2209,  ..., -1.8925, -3.2644,  2.8031],
         [ 0.02

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-1.3105,  2.9398,  2.9824,  ..., -2.9580, -2.4314,  2.3548],
         [ 4.1027,  1.9056, -2.4408,  ..., -0.6021, -1.5733,  1.4547],
         [-3.0068,  1.2571,  3.3435,  ..., -2.0177, -0.9842,  1.2874],
         ...,
         [-1.2788,  2.8679,  3.0381,  ..., -2.6978, -2.0662,  1.7881],
         [-1.1891,  2.8519,  2.9207,  ..., -3.0438, -2.3685,  2.4911],
         [-2.1007,  2.1227,  3.2373,  ..., -2.6734, -1.7247,  2.0068]],

        [[ 0.6475,  2.7981,  0.5993,  ..., -2.5599, -3.2303,  3.4303],
         [ 3.4074,  1.2649, -2.4345,  ..., -0.6785, -1.5471,  2.0313],
         [-3.2913,  0.7885,  3.1535,  ..., -1.7487, -0.7725,  1.2341],
         ...,
         [ 1.4520,  3.6166,  0.4882,  ..., -2.4267, -3.4413,  2.8338],
         [ 0.6380,  3.6907,  1.3482,  ..., -2.9695, -3.5642,  3.2330],
         [-0.80

pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-1.2306,  2.9297,  2.9428,  ..., -3.0436, -2.4464,  2.4946],
         [ 4.2591,  2.7115, -2.1766,  ..., -1.1641, -2.3670,  2.0437],
         [ 1.2657,  3.8851,  1.6448,  ..., -2.6992, -2.6245,  1.8908],
         ...,
         [ 3.9008,  4.0409, -0.9932,  ..., -2.1369, -3.3402,  2.5496],
         [-2.8429,  1.2045,  3.1562,  ..., -2.3469, -1.2542,  2.0005],
         [ 3.6341,  4.1693, -0.5256,  ..., -2.1487, -3.1384,  2.1435]],

        [[ 0.6885,  4.1213,  1.8135,  ..., -3.1496, -3.5090,  2.9132],
         [ 3.5526,  1.8694, -2.0759,  ..., -1.3048, -2.0686,  2.5966],
         [ 2.2627,  4.0448,  0.5978,  ..., -3.2399, -3.6143,  3.5491],
         ...,
         [ 1.6039,  4.0940,  1.3382,  ..., -3.2702, -3.3483,  3.0879],
         [-3.0095,  1.2161,  3.1672,  ..., -2.0887, -1.2128,  1.6552],
         [ 2.1660,  4.2808,  0.7754,  ..., -3.2878, -3.7663,  3.4877]],

 

Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-2.9899,  1.0579,  3.3352,  ..., -2.1276, -0.8380,  1.4647],
         [ 3.0891,  1.3475, -2.2927,  ..., -0.7345, -1.7763,  2.1809],
         [-1.1739,  2.7486,  2.8184,  ..., -3.1009, -2.3845,  2.7007],
         ...,
         [-2.6053,  1.5275,  3.0708,  ..., -2.5233, -1.6162,  2.2869],
         [-2.7822,  1.4633,  3.2373,  ..., -2.4078, -1.4126,  1.9804],
         [-1.0103,  3.0863,  2.9404,  ..., -2.9147, -2.3120,  2.1146]],

        [[-3.0490,  1.2522,  3.2737,  ..., -2.0748, -1.1366,  1.5207],
         [ 3.3015,  1.4011, -2.1465,  ..., -1.0693, -1.7052,  2.4168],
         [ 1.3392,  3.7282,  0.6627,  ..., -2.7922, -3.6712,  3.3239],
         ...,
         [-2.7791,  1.5732,  3.2586,  ..., -2.3504, -1.4588,  1.8585],
         [-2.1840,  2.0759,  2.9481,  ..., -2.6798, -2.0951,  2.4712],
         [-0.8013,  2.9802,  2.7828,  ..., -2.9977, -2.2704,  2.3092]],

        [[-3.0020,  1.2211,  3.2367,  ..., -2.2291, -1.2201, 

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-1.3544,  2.8547,  3.0857,  ..., -2.7911, -2.1059,  1.9475],
         [-2.3990,  1.9154,  3.1454,  ..., -2.6490, -1.8602,  2.3113],
         [ 3.7235,  3.6582, -0.9440,  ..., -1.5100, -2.5329,  1.4408],
         ...,
         [-0.6855,  3.2084,  2.8043,  ..., -2.9192, -2.3305,  2.0895],
         [ 3.8778,  3.7634, -1.0290,  ..., -1.6442, -2.7111,  1.6892],
         [-2.3259,  1.8761,  3.1444,  ..., -2.7073, -1.7970,  2.3569]],

        [[-1.3726,  2.7458,  3.0227,  ..., -2.9789, -2.2218,  2.3702],
         [-1.8907,  2.5112,  3.2374,  ..., -2.6508, -1.9270,  1.8623],
         [ 2.8157,  3.6279, -0.4719,  ..., -2.8020, -3.6615,  3.7699],
         ...,
         [-1.8855,  2.2200,  3.0429,  ..., -2.8962, -2.0345,  2.5320],
         [ 1.9199,  3.7723,  0.7375,  ..., -3.2572, -3.4349,  3.6098],
         [-3.08

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 3.6872,  4.2378, -0.5789,  ..., -2.1936, -3.1815,  2.2116],
         [ 3.3712,  4.3011, -0.1724,  ..., -2.3277, -3.1189,  2.1375],
         [-0.6358,  3.4517,  2.7741,  ..., -3.0501, -2.6662,  2.3524],
         ...,
         [ 2.3142,  4.4244,  0.8942,  ..., -2.8805, -3.2699,  2.4745],
         [ 3.4249,  3.6920, -0.5379,  ..., -1.6118, -2.3530,  1.2628],
         [ 3.5135,  4.3704, -0.3691,  ..., -2.3627, -3.3399,  2.3698]],

        [[ 1.5627,  4.1078,  1.3322,  ..., -3.3244, -3.3987,  3.2123],
         [ 2.2154,  4.1099,  0.5549,  ..., -3.2878, -3.7703,  3.7225],
         [ 2.1314,  3.9165,  0.1208,  ..., -2.7888, -3.8209,  3.4762],
         ...,
         [ 1.4845,  3.8020,  0.9331,  ..., -3.3615, -3.6859,  3.8641],
         [ 2.1244,  4.2041,  0.9307,  ..., -3.1933, -3.4008,  3.1106],
         [ 3.50

Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-0.9256,  3.0977,  2.7736,  ..., -3.2114, -2.6685,  2.8254],
         [-0.6683,  3.4290,  2.7645,  ..., -3.1462, -2.7442,  2.5526],
         [-1.6430,  2.7110,  3.1922,  ..., -2.7072, -1.9965,  1.8558],
         ...,
         [-0.9521,  3.2277,  2.9046,  ..., -3.0588, -2.5455,  2.4015],
         [ 3.6961,  4.2234, -0.4809,  ..., -2.4556, -3.2389,  2.5365],
         [ 4.1106,  3.0200, -1.9811,  ..., -1.5753, -2.8485,  2.6120]],

        [[ 0.0126,  3.1463,  1.4635,  ..., -2.9642, -3.3518,  3.4814],
         [ 0.7277,  4.1172,  1.6854,  ..., -3.2829, -3.6797,  3.2833],
         [ 2.9509,  2.8299, -1.3322,  ..., -1.9079, -3.1239,  3.1810],
         ...,
         [-1.3293,  2.9744,  2.9880,  ..., -3.0591, -2.5156,  2.5530],
         [ 3.6422,  3.9041, -1.0422,  ..., -2.3824, -3.6130,  3.1955],
         [ 3.6406,  1.6406, -2.3284,  ..., -1.0490, -1.8745,  2.3888]],

        [[-0.3327,  3.6102,  2.6117,  ..., -3.1879, -2.8565, 

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 2.3475,  4.6038,  0.7715,  ..., -2.9631, -3.6792,  2.8064],
         [-0.8687,  3.2571,  2.9047,  ..., -3.0501, -2.5268,  2.3526],
         [ 3.1296,  3.5851, -0.1252,  ..., -1.9023, -2.2369,  1.4480],
         ...,
         [-1.3318,  2.9805,  3.0646,  ..., -3.0050, -2.4165,  2.3797],
         [-1.5907,  2.7244,  3.1193,  ..., -2.9881, -2.2994,  2.4493],
         [ 3.2309,  3.1380, -0.5399,  ..., -1.2962, -1.7396,  0.8118]],

        [[ 3.2188,  3.9067, -0.7173,  ..., -2.5622, -3.7463,  3.4170],
         [-1.0073,  3.1079,  2.9499,  ..., -3.0517, -2.4395,  2.3861],
         [ 1.5529,  4.0302,  1.1647,  ..., -3.4258, -3.6506,  3.6553],
         ...,
         [-1.2917,  2.8986,  3.0779,  ..., -2.9319, -2.2225,  2.1920],
         [ 1.6472,  3.9349,  0.5471,  ..., -2.6254, -3.6341,  2.9840],
         [ 3.90

Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 0.3293,  4.0039,  2.2901,  ..., -3.1091, -3.0745,  2.4594],
         [ 3.3563,  3.9118, -0.3337,  ..., -1.8286, -2.5992,  1.4509],
         [-0.8279,  3.3140,  2.8909,  ..., -3.0641, -2.5956,  2.3745],
         ...,
         [ 4.0011,  2.2659, -2.3477,  ..., -0.7083, -1.9905,  1.5839],
         [ 3.1994,  4.0945,  0.0419,  ..., -2.3148, -2.8317,  1.9469],
         [ 4.1519,  2.1955, -2.3869,  ..., -1.1246, -2.1555,  2.2723]],

        [[ 0.4438,  3.8179,  1.6603,  ..., -3.1495, -3.6081,  3.3024],
         [ 3.7789,  3.0036, -1.6552,  ..., -2.0062, -3.0886,  3.2076],
         [-1.1049,  2.9050,  2.8387,  ..., -3.1916, -2.5831,  2.8569],
         ...,
         [ 3.6859,  2.1359, -2.0178,  ..., -1.5350, -2.3618,  2.8567],
         [ 3.6159,  3.9175, -0.9424,  ..., -2.5123, -3.6760,  3.3360],
         [ 3.6349,  2.0188, -2.0892,  ..., -1.4373, -2.3002,  2.8080]],

        [[ 0.1963,  3.8501,  2.3311,  ..., -3.3276, -3.1513, 

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-0.9959,  3.0593,  2.9913,  ..., -2.8457, -2.1946,  1.9642],
         [-3.0385,  1.0674,  3.2670,  ..., -2.3027, -1.1862,  2.0049],
         [ 4.1295,  3.9890, -1.1117,  ..., -2.0356, -3.1247,  2.3170],
         ...,
         [-0.9602,  3.1781,  2.9627,  ..., -3.0405, -2.4749,  2.3414],
         [-0.6509,  3.4605,  2.8087,  ..., -3.1282, -2.7548,  2.5001],
         [-1.6975,  2.4891,  3.1307,  ..., -2.9772, -2.1588,  2.5046]],

        [[-1.5134,  2.7935,  3.1156,  ..., -3.0244, -2.3661,  2.4948],
         [-2.9768,  1.2086,  3.0412,  ..., -2.0850, -1.3947,  1.8875],
         [ 3.1823,  3.5195, -0.8374,  ..., -2.6305, -3.6061,  3.7119],
         ...,
         [ 0.3973,  3.5849,  2.2308,  ..., -2.9974, -2.5713,  2.2657],
         [-1.3855,  2.7949,  3.0879,  ..., -3.0151, -2.2801,  2.4174],
         [-0.02

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-1.9390,  2.4171,  3.3027,  ..., -2.8260, -1.9745,  2.1799],
         [-2.2856,  2.1311,  3.3489,  ..., -2.7385, -1.8630,  2.2015],
         [ 2.3402,  4.2889,  0.8334,  ..., -2.5394, -2.9723,  1.9380],
         ...,
         [-1.2334,  2.9653,  3.1058,  ..., -2.8368, -2.1657,  1.9718],
         [ 3.1327,  3.7903, -0.1514,  ..., -2.8451, -3.2317,  3.2151],
         [-2.7830,  1.5671,  3.2783,  ..., -2.4830, -1.6157,  2.1736]],

        [[-1.2674,  2.6630,  2.4196,  ..., -2.9364, -2.8384,  3.1185],
         [-1.9227,  2.4544,  3.3431,  ..., -2.6660, -1.8208,  1.8212],
         [ 3.1147,  4.1192, -0.3609,  ..., -2.8908, -3.8964,  3.6366],
         ...,
         [-2.2906,  1.9073,  3.3184,  ..., -2.6888, -1.6334,  2.1374],
         [ 3.6985,  2.5015, -1.7699,  ..., -1.8598, -2.6669,  3.1311],
         [-3.13

Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 3.3421e-01,  3.8265e+00,  2.3135e+00,  ..., -3.1825e+00,
          -2.9034e+00,  2.5371e+00],
         [-1.4193e+00,  2.7596e+00,  3.0705e+00,  ..., -3.0439e+00,
          -2.3043e+00,  2.5063e+00],
         [ 1.9017e-01,  3.9395e+00,  2.3928e+00,  ..., -3.1071e+00,
          -2.9934e+00,  2.4124e+00],
         ...,
         [-1.7309e+00,  2.7336e+00,  3.1963e+00,  ..., -2.8902e+00,
          -2.2809e+00,  2.2933e+00],
         [-2.7484e-02,  3.8245e+00,  2.4801e+00,  ..., -3.2808e+00,
          -3.0940e+00,  2.7670e+00],
         [ 3.7464e+00,  4.2204e+00, -6.0517e-01,  ..., -2.1378e+00,
          -3.1388e+00,  2.0914e+00]],

        [[ 6.1473e-01,  3.9833e+00,  2.1626e+00,  ..., -3.1119e+00,
          -2.9509e+00,  2.4139e+00],
         [ 1.2514e-01,  3.6469e+00,  1.8341e+00,  ..., -3.2030e+00,
          -3.5379e+00,  3.4008e+00],
         [-2.9773e-01,  3.5444e+00,  2.5873e+00,  ..., -3.3045e+00,
          -2.9272e+00

Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-3.2124,  1.0866,  3.3804,  ..., -2.1303, -1.1181,  1.7075],
         [-2.1905,  1.8496,  3.1510,  ..., -2.7982, -1.8104,  2.5152],
         [-2.8043,  1.4847,  3.3599,  ..., -2.4879, -1.4738,  2.1051],
         ...,
         [ 3.8352,  4.2690, -0.7389,  ..., -2.2938, -3.4242,  2.5123],
         [ 3.4762,  4.3847, -0.2412,  ..., -2.3752, -3.3062,  2.2584],
         [-2.8521,  1.3349,  3.2678,  ..., -2.4643, -1.4573,  2.2148]],

        [[-2.6235,  1.7006,  3.3175,  ..., -2.6065, -1.6829,  2.2582],
         [-2.6959,  1.7770,  3.3610,  ..., -2.4614, -1.6588,  1.9954],
         [-2.2788,  2.2090,  3.3640,  ..., -2.6865, -1.9090,  2.1143],
         ...,
         [ 3.4833,  4.0240, -0.6360,  ..., -2.8000, -3.7501,  3.5478],
         [ 1.7439,  4.0439,  1.0659,  ..., -3.4372, -3.6921,  3.7026],
         [-2.6802,  1.5992,  3.2555,  ..., -2.5746, -1.6805,  2.3218]],

        [[-2.7038,  1.5864,  3.2371,  ..., -2.5489, -1.6909, 

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 4.3501,  2.7546, -2.2069,  ..., -1.4203, -2.5481,  2.4347],
         [-2.1571,  2.2641,  3.2811,  ..., -2.8437, -2.0761,  2.4249],
         [ 3.7424,  4.3239, -0.5154,  ..., -2.3400, -3.3193,  2.3630],
         ...,
         [-0.9889,  3.2266,  2.9782,  ..., -3.1317, -2.6388,  2.5416],
         [-0.6814,  3.4790,  2.8382,  ..., -3.1667, -2.8296,  2.5822],
         [ 3.8990,  4.1610, -0.7585,  ..., -2.1149, -3.1392,  2.1542]],

        [[ 2.9506,  1.2560, -1.7482,  ..., -1.3331, -1.5959,  2.6013],
         [-1.6781,  2.5752,  3.1422,  ..., -3.0359, -2.2963,  2.6216],
         [ 3.1902,  4.1426, -0.2820,  ..., -2.9968, -3.8442,  3.6573],
         ...,
         [-0.3134,  3.4826,  2.6963,  ..., -3.1009, -2.6313,  2.3708],
         [-1.4274,  2.8856,  3.1515,  ..., -3.0201, -2.3711,  2.4095],
         [ 3.34

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 2.5262,  4.4565,  0.7972,  ..., -2.7967, -3.2611,  2.3302],
         [-2.6313,  1.6100,  3.3546,  ..., -2.6155, -1.5817,  2.2459],
         [ 3.6732,  1.4026, -2.5970,  ..., -0.0976, -1.1862,  0.9607],
         ...,
         [ 4.3881,  2.5888, -2.3614,  ..., -1.0929, -2.2701,  1.9981],
         [ 3.1858,  4.4729,  0.0112,  ..., -2.4654, -3.4098,  2.3041],
         [-3.1053,  1.0893,  3.3361,  ..., -2.3082, -1.2461,  2.0372]],

        [[ 3.4037,  2.8460, -1.5725,  ..., -1.9611, -3.1194,  3.2922],
         [-3.0700,  1.1808,  3.3267,  ..., -2.3296, -1.3378,  2.0697],
         [ 3.2241,  1.3304, -2.2400,  ..., -1.0711, -1.8422,  2.6034],
         ...,
         [ 2.9154,  1.0188, -2.1727,  ..., -0.9805, -1.6398,  2.5600],
         [ 3.4143,  4.0496, -0.5936,  ..., -2.8193, -3.8094,  3.5907],
         [-2.32

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 4.2221,  3.7864, -1.2492,  ..., -1.7411, -2.7897,  1.8570],
         [-2.7252,  1.4985,  3.3276,  ..., -2.5810, -1.5595,  2.2990],
         [ 3.2972,  4.0179, -0.4486,  ..., -1.8460, -2.9319,  1.6781],
         ...,
         [ 3.8834,  4.2203, -0.6442,  ..., -2.2273, -3.1698,  2.2096],
         [ 3.6766,  4.3956, -0.3555,  ..., -2.4973, -3.3890,  2.4812],
         [ 3.3149,  4.0445, -0.1811,  ..., -1.9762, -2.7141,  1.5311]],

        [[ 3.5032,  3.7121, -1.0165,  ..., -2.4914, -3.6610,  3.4803],
         [-3.4093,  0.8326,  3.3691,  ..., -2.0017, -0.9663,  1.6578],
         [ 3.7948,  3.6812, -1.2089,  ..., -2.4210, -3.5545,  3.3868],
         ...,
         [ 3.8234,  3.8264, -1.1086,  ..., -2.4935, -3.6112,  3.3692],
         [ 4.0324,  3.0164, -1.8998,  ..., -1.7732, -3.0048,  2.9176],
         [ 3.45

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-0.6661,  3.4117,  2.8077,  ..., -3.3098, -2.8995,  2.8883],
         [ 2.6449,  4.6552,  0.5702,  ..., -2.8240, -3.6143,  2.5895],
         [-0.9901,  3.2597,  2.9676,  ..., -3.2215, -2.7629,  2.7464],
         ...,
         [-1.1009,  3.0475,  3.0720,  ..., -2.6847, -2.0876,  1.6769],
         [ 2.8290,  4.4219,  0.4950,  ..., -2.6044, -3.1353,  2.1245],
         [-1.4741,  2.8920,  3.1417,  ..., -3.1133, -2.5166,  2.6596]],

        [[ 0.6481,  3.7458,  1.3609,  ..., -3.1377, -3.7289,  3.5832],
         [ 2.2246,  4.4550,  0.9493,  ..., -3.3426, -3.7303,  3.3377],
         [-1.0736,  3.0296,  2.9638,  ..., -3.2210, -2.6093,  2.8006],
         ...,
         [-1.0636,  3.1615,  3.0208,  ..., -3.1763, -2.6268,  2.6391],
         [ 2.5291,  4.3960,  0.6871,  ..., -3.2544, -3.6567,  3.2963],
         [-0.83

Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 3.4496,  4.4951, -0.2663,  ..., -2.4381, -3.4589,  2.3842],
         [-2.0737,  2.3312,  3.2704,  ..., -2.9413, -2.1908,  2.6094],
         [-0.0138,  3.7596,  2.5565,  ..., -3.2369, -2.9296,  2.6172],
         ...,
         [ 3.8979,  3.2913, -1.0493,  ..., -1.4301, -2.0166,  1.1741],
         [ 3.4450,  4.3257, -0.1277,  ..., -2.3760, -3.0742,  2.0680],
         [ 4.3565,  2.3522, -2.5155,  ..., -0.9537, -2.0977,  1.9166]],

        [[ 1.3066,  4.1515,  1.6198,  ..., -3.4545, -3.4935,  3.3115],
         [-0.2244,  3.6320,  2.6495,  ..., -3.2714, -2.9002,  2.7100],
         [-0.5824,  3.5198,  2.8164,  ..., -3.2528, -2.8812,  2.7183],
         ...,
         [ 2.1162,  4.1761,  0.8674,  ..., -3.4021, -3.6932,  3.6271],
         [ 2.3246,  4.2541,  0.6527,  ..., -3.3822, -3.8363,  3.7302],
         [ 3.8313,  2.1068, -2.2411,  ..., -1.3953, -2.3323,  2.7544]],

        [[ 3.7118,  3.8392, -1.0923,  ..., -2.5176, -3.6567, 

Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 4.1684,  2.4924, -2.3907,  ..., -0.9919, -2.2801,  1.9705],
         [-3.0949,  1.2426,  3.4382,  ..., -2.3651, -1.3126,  2.0351],
         [-2.5564,  1.8855,  3.4362,  ..., -2.6868, -1.7409,  2.2474],
         ...,
         [ 2.4204,  4.6899,  0.8303,  ..., -3.2210, -3.8137,  3.1165],
         [ 3.6583,  4.4913, -0.3543,  ..., -2.5105, -3.4210,  2.4622],
         [ 4.3500,  2.4184, -2.4912,  ..., -0.9287, -2.1203,  1.8454]],

        [[ 3.6738,  1.9594, -2.2525,  ..., -1.3293, -2.2865,  2.7660],
         [-2.8735,  1.5281,  3.4334,  ..., -2.5209, -1.5336,  2.1729],
         [-2.4727,  2.0100,  3.3101,  ..., -2.7502, -1.9985,  2.4836],
         ...,
         [ 0.3694,  3.7971,  2.2407,  ..., -3.4113, -3.1359,  3.0656],
         [ 1.3436,  4.0651,  1.5157,  ..., -3.4762, -3.4964,  3.4471],
         [ 3.8955,  1.9611, -2.3451,  ..., -1.2625, -2.1539,  2.5956]],

        [[ 3.5902,  1.6278, -2.2746,  ..., -1.2173, -1.9672, 

Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-2.7475,  1.5363,  3.3910,  ..., -2.5900, -1.5283,  2.2604],
         [ 4.4941,  3.4517, -1.8745,  ..., -1.7014, -2.9182,  2.3833],
         [ 3.8389,  4.2896, -0.6696,  ..., -2.1757, -3.2162,  2.1446],
         ...,
         [-0.0840,  3.7326,  2.6079,  ..., -3.1929, -2.8487,  2.5082],
         [-0.2747,  3.6264,  2.7056,  ..., -2.9455, -2.5884,  2.0547],
         [ 0.0816,  3.9324,  2.5110,  ..., -3.1077, -2.9546,  2.3670]],

        [[-2.5791,  2.0147,  3.2804,  ..., -2.4183, -1.8571,  1.9934],
         [ 4.0623,  2.1283, -2.2690,  ..., -1.3727, -2.1987,  2.5852],
         [ 3.4824,  3.5427, -1.1055,  ..., -2.5072, -3.5917,  3.6118],
         ...,
         [-0.6379,  3.4090,  2.8077,  ..., -3.3128, -2.8594,  2.8699],
         [ 0.2918,  3.7946,  2.3128,  ..., -3.3995, -3.1150,  3.0136],
         [ 1.5091,  4.2985,  1.3499,  ..., -3.5285, -3.8519,  3.6740]],

        [[-2.9239,  1.6549,  3.4303,  ..., -2.3564, -1.5636, 

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-0.1537,  3.5684,  2.6658,  ..., -3.1294, -2.6838,  2.3969],
         [ 4.2213,  3.4829, -1.4795,  ..., -2.2433, -3.1802,  3.0815],
         [-2.9894,  1.1007,  3.2975,  ..., -2.4348, -1.3333,  2.2947],
         ...,
         [-2.2467,  2.2076,  3.5067,  ..., -2.6249, -1.6914,  1.8418],
         [-1.0408,  3.1954,  3.0927,  ..., -3.0757, -2.5118,  2.3757],
         [ 0.4035,  3.8586,  2.3309,  ..., -2.8529, -2.6749,  1.9271]],

        [[ 1.1721,  3.7001,  0.9124,  ..., -3.1477, -3.8431,  3.8477],
         [ 3.9825,  2.1161, -2.2762,  ..., -1.3857, -2.2994,  2.6934],
         [-2.6487,  1.7604,  3.3701,  ..., -2.6896, -1.8165,  2.4431],
         ...,
         [ 0.0329,  3.9511,  2.3015,  ..., -3.2734, -3.4733,  3.0540],
         [ 1.8171,  4.2193,  0.7025,  ..., -2.9593, -3.9047,  3.3129],
         [ 0.70

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-1.7378,  2.7061,  3.3272,  ..., -3.0033, -2.2842,  2.4381],
         [ 3.7653,  4.4042, -0.6390,  ..., -2.5997, -3.7242,  2.9792],
         [ 3.7979,  4.2419, -0.5983,  ..., -2.1349, -3.1010,  2.0134],
         ...,
         [-0.9311,  3.3456,  3.0539,  ..., -3.0888, -2.6233,  2.3803],
         [-1.9423,  2.5943,  3.3622,  ..., -2.9604, -2.2731,  2.4643],
         [ 3.9728,  1.6645, -2.6641,  ..., -0.7443, -1.7846,  2.0068]],

        [[-1.0865,  3.1886,  3.1160,  ..., -3.0996, -2.5352,  2.4272],
         [ 3.4132,  3.1518, -1.3526,  ..., -2.2446, -3.3728,  3.5179],
         [ 3.3702,  3.1615, -1.3376,  ..., -2.2202, -3.3813,  3.4856],
         ...,
         [ 1.4888,  3.7541,  0.6249,  ..., -3.0456, -3.8750,  3.8118],
         [-1.3056,  2.9895,  3.1247,  ..., -3.1944, -2.5959,  2.7549],
         [ 3.80

Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-0.8605,  3.4262,  3.0048,  ..., -3.1204, -2.7123,  2.4453],
         [ 4.4489,  2.3414, -2.5194,  ..., -1.0599, -2.1428,  2.0624],
         [-0.7526,  3.4514,  2.9463,  ..., -3.2147, -2.7932,  2.6199],
         ...,
         [-0.1714,  3.6412,  2.6629,  ..., -3.2630, -2.8667,  2.6583],
         [-0.7941,  3.4028,  2.9918,  ..., -3.0058, -2.5538,  2.1945],
         [ 4.1154,  4.1048, -0.9442,  ..., -2.0451, -3.0517,  2.0665]],

        [[-0.0667,  3.7915,  2.5978,  ..., -3.3414, -3.0692,  2.8226],
         [ 3.5510,  1.7228, -2.2148,  ..., -1.3341, -2.1572,  2.8252],
         [-0.0241,  3.7529,  2.5648,  ..., -3.3547, -3.0457,  2.8579],
         ...,
         [ 2.9472,  4.2780, -0.2267,  ..., -2.7081, -3.8955,  3.2406],
         [-0.3779,  3.6339,  2.7634,  ..., -3.2905, -2.9369,  2.7323],
         [ 1.7906,  4.0250,  1.0874,  ..., -3.4814, -3.6652,  3.7465]],

        [[ 1.5857,  4.3891,  1.0756,  ..., -3.0626, -3.8739, 

pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 0.0315,  3.9719,  2.5410,  ..., -3.1006, -3.0120,  2.3788],
         [-1.0484,  3.2128,  2.9985,  ..., -3.2931, -2.8061,  2.9203],
         [ 1.0713,  3.9862,  1.8538,  ..., -2.6869, -2.6229,  1.7357],
         ...,
         [ 3.2219,  4.6061,  0.1235,  ..., -3.0664, -3.8175,  3.2101],
         [ 4.3327,  2.0517, -2.6494,  ..., -0.8460, -1.9276,  1.9107],
         [ 0.0637,  3.8442,  2.5460,  ..., -2.9903, -2.7599,  2.1303]],

        [[-1.1479,  2.9932,  2.9858,  ..., -3.2834, -2.6855,  2.9876],
         [-1.1338,  3.0435,  3.0155,  ..., -3.2685, -2.6730,  2.9088],
         [ 1.4003,  4.0493,  1.3789,  ..., -3.5586, -3.7183,  3.7945],
         ...,
         [ 2.1436,  1.5102, -0.7995,  ..., -2.0447, -2.0179,  3.2829],
         [ 3.4075,  1.5600, -2.0647,  ..., -1.4062, -1.9983,  2.8568],
         [-0.5402,  3.4588,  2.6842,  ..., -3.4463, -3.1215,  3.2602]],

 

Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 0.6885,  3.9747,  2.1538,  ..., -2.8935, -2.7298,  2.0021],
         [ 2.0973,  4.4439,  1.1328,  ..., -2.7605, -3.1058,  2.0833],
         [-3.1288,  1.1480,  3.4237,  ..., -2.4113, -1.3471,  2.2266],
         ...,
         [ 4.1140,  2.5969, -2.3173,  ..., -1.1767, -2.4920,  2.2495],
         [-3.2234,  1.1689,  3.5002,  ..., -2.3388, -1.2939,  2.0583],
         [ 3.8134,  4.3902, -0.6922,  ..., -2.2949, -3.4219,  2.3984]],

        [[-0.4380,  3.6116,  2.6486,  ..., -3.4610, -3.2283,  3.2655],
         [ 1.1364,  4.3078,  1.8258,  ..., -3.4822, -3.5646,  3.2365],
         [-2.7346,  1.1194,  3.3889,  ..., -1.6356, -0.3633,  0.4889],
         ...,
         [ 4.1364,  2.3436, -2.3508,  ..., -1.3934, -2.4525,  2.6773],
         [-3.2349,  1.0762,  3.3525,  ..., -2.3011, -1.3650,  2.2002],
         [ 2.8306,  3.8049, -0.3701,  ..., -2.9672, -3.8829,  3.9784]],

        [[ 0.1524,  3.8026,  2.3541,  ..., -3.5189, -3.3117, 

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 4.2988,  2.0865, -2.6178,  ..., -0.9816, -2.0601,  2.1417],
         [ 4.4840,  2.6536, -2.3411,  ..., -1.4248, -2.4884,  2.4740],
         [ 3.4782,  4.3775, -0.2589,  ..., -2.2529, -3.1798,  2.0038],
         ...,
         [-3.1861,  0.9278,  3.3935,  ..., -2.3468, -1.1966,  2.2039],
         [-1.4289,  3.0207,  3.2930,  ..., -2.9718, -2.3622,  2.2288],
         [ 3.8707,  4.2661, -0.5675,  ..., -2.2503, -3.1089,  2.1138]],

        [[ 3.3456,  1.5018, -2.2632,  ..., -1.1963, -2.0662,  2.7856],
         [ 3.6350,  1.5539, -2.4016,  ..., -1.1337, -1.9827,  2.6321],
         [ 3.0267,  3.2146, -0.9140,  ..., -2.5952, -3.5403,  3.8840],
         ...,
         [-2.6107,  1.9468,  3.3115,  ..., -2.6502, -2.0463,  2.4858],
         [-1.5816,  2.6218,  3.3107,  ..., -2.9067, -2.0303,  2.1827],
         [ 0.46

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 3.0911,  4.6673,  0.2782,  ..., -2.8048, -3.5937,  2.6140],
         [ 1.0289,  4.3668,  1.8901,  ..., -2.9966, -3.2550,  2.3368],
         [ 3.3411,  4.4321, -0.0248,  ..., -2.4007, -3.1770,  2.0601],
         ...,
         [-0.3871,  3.7327,  2.7841,  ..., -3.2929, -3.0185,  2.7436],
         [-0.5699,  3.6425,  2.8622,  ..., -3.2892, -2.9903,  2.7688],
         [ 3.2644,  4.2850,  0.1468,  ..., -2.7401, -3.1814,  2.5190]],

        [[ 3.8310,  3.5569, -1.3870,  ..., -2.3171, -3.5364,  3.4040],
         [-0.4479,  3.7407,  2.8156,  ..., -3.1925, -2.9530,  2.5564],
         [ 0.2266,  3.7437,  2.3387,  ..., -3.4963, -3.2266,  3.2713],
         ...,
         [-1.1573,  3.2160,  3.1203,  ..., -3.2111, -2.7012,  2.7062],
         [ 1.5750,  3.9667,  0.6687,  ..., -2.9674, -3.8950,  3.5613],
         [ 3.97

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 3.5267,  4.6292, -0.1219,  ..., -2.7563, -3.5687,  2.6872],
         [ 4.3381,  2.2651, -2.5864,  ..., -1.0279, -2.1962,  2.1510],
         [-2.7444,  1.6611,  3.6136,  ..., -2.4474, -1.2939,  1.7462],
         ...,
         [-0.2336,  3.8655,  2.7005,  ..., -3.3160, -3.0887,  2.7554],
         [ 0.1488,  4.0731,  2.4847,  ..., -3.3623, -3.2615,  2.8492],
         [-0.6812,  3.6002,  2.9346,  ..., -3.2217, -2.8536,  2.5945]],

        [[ 3.4263,  2.8832, -1.6034,  ..., -1.9902, -3.1880,  3.3540],
         [ 3.7339,  1.6209, -2.4550,  ..., -1.1219, -2.0284,  2.6092],
         [-2.7402,  1.9167,  3.4382,  ..., -2.6310, -1.9148,  2.3561],
         ...,
         [-0.4940,  3.5499,  2.7468,  ..., -3.4305, -3.0567,  3.1073],
         [-0.0524,  3.9407,  2.6032,  ..., -3.3509, -3.1533,  2.8165],
         [ 0.76

Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-2.2168,  2.3913,  3.4957,  ..., -2.9211, -2.1313,  2.4572],
         [ 2.2924,  4.6037,  1.0202,  ..., -2.8989, -3.3402,  2.3571],
         [-0.6366,  3.4552,  2.9228,  ..., -3.2801, -2.7834,  2.7183],
         ...,
         [ 2.1513,  4.6985,  1.0049,  ..., -2.9106, -3.5832,  2.5200],
         [-1.3260,  3.1716,  3.2534,  ..., -3.0776, -2.5309,  2.4091],
         [-0.8085,  3.6119,  2.9560,  ..., -3.0225, -2.8177,  2.3230]],

        [[ 0.2548,  4.0674,  2.0864,  ..., -3.3775, -3.7077,  3.3943],
         [ 1.9855,  4.3702,  1.0769,  ..., -3.5178, -3.8549,  3.7039],
         [ 0.6692,  4.0875,  1.6034,  ..., -3.1728, -3.7848,  3.3523],
         ...,
         [ 2.9114,  3.7329, -0.6083,  ..., -2.7174, -3.8314,  3.7869],
         [-0.9066,  3.2911,  3.0821,  ..., -3.0780, -2.4902,  2.3172],
         [-0.9352,  3.2842,  3.0964,  ..., -3.0737, -2.4872,  2.3115]],

        [[-1.6618,  2.6833,  3.3835,  ..., -2.7906, -1.9490, 

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-2.8408,  1.6207,  3.4715,  ..., -2.6629, -1.7178,  2.4524],
         [-0.4080,  3.6677,  2.8443,  ..., -3.1460, -2.7324,  2.3652],
         [ 0.3944,  4.1319,  2.3953,  ..., -3.1547, -3.0404,  2.3879],
         ...,
         [-0.8866,  3.4953,  3.0668,  ..., -3.1417, -2.7068,  2.4238],
         [ 0.0781,  4.0867,  2.5600,  ..., -3.2506, -3.1532,  2.5939],
         [ 3.9421,  4.4295, -0.6192,  ..., -2.3977, -3.3326,  2.3687]],

        [[-2.6435,  2.0144,  3.5279,  ..., -2.7511, -1.9245,  2.3947],
         [-0.2644,  3.8422,  2.7118,  ..., -3.4125, -3.1734,  2.9580],
         [-0.5939,  3.5041,  2.8908,  ..., -3.3092, -2.8255,  2.7581],
         ...,
         [ 0.9671,  3.9484,  1.1677,  ..., -3.0627, -3.8314,  3.5045],
         [ 2.4677,  4.5233,  0.3409,  ..., -2.9506, -3.9984,  3.2680],
         [ 3.62

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-2.8300,  1.5056,  3.4243,  ..., -2.6798, -1.7327,  2.5898],
         [ 3.4802,  4.5776, -0.0420,  ..., -2.6387, -3.4222,  2.4199],
         [ 3.0182,  4.3005,  0.2096,  ..., -2.2170, -2.8761,  1.6140],
         ...,
         [-0.2409,  3.8348,  2.7825,  ..., -3.1509, -2.9086,  2.3925],
         [ 4.3493,  2.8059, -2.2639,  ..., -1.5045, -2.7284,  2.6185],
         [ 3.2959,  4.3023, -0.2779,  ..., -2.0951, -3.1334,  1.8496]],

        [[-3.1555,  1.4114,  3.6662,  ..., -2.4297, -1.4056,  2.0246],
         [ 3.2716,  4.4527, -0.0487,  ..., -3.1644, -3.9584,  3.6099],
         [ 3.5445,  3.9188, -0.8637,  ..., -2.7183, -3.8284,  3.6683],
         ...,
         [-0.1405,  3.4156,  1.8826,  ..., -3.1776, -3.5631,  3.6334],
         [ 3.7264,  2.2503, -1.9294,  ..., -1.8048, -2.5916,  3.2051],
         [ 3.16

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 4.3154,  1.6977, -2.8139,  ..., -0.4727, -1.4799,  1.3986],
         [-0.7622,  3.5854,  3.0259,  ..., -3.2801, -2.9311,  2.7026],
         [ 3.4144,  4.5278, -0.2349,  ..., -2.4178, -3.4786,  2.3184],
         ...,
         [-0.0484,  4.0578,  2.6263,  ..., -3.3126, -3.2850,  2.7720],
         [ 0.1543,  3.9797,  2.5692,  ..., -3.3655, -3.1491,  2.7831],
         [ 3.7934,  4.5168, -0.5409,  ..., -2.4982, -3.5886,  2.6000]],

        [[ 1.9969,  0.4901, -1.4834,  ..., -1.2920, -1.4224,  2.9311],
         [-0.0868,  3.6969,  2.6694,  ..., -3.3570, -2.9569,  2.8084],
         [ 2.8649,  4.0841,  0.0744,  ..., -3.2717, -3.9051,  3.9576],
         ...,
         [-1.5004,  2.8915,  3.2982,  ..., -3.1746, -2.4811,  2.6856],
         [ 0.2465,  4.1117,  2.5078,  ..., -3.3772, -3.2821,  2.8246],
         [ 3.28

pred: torch.Size([100, 4, 7])
target: torch.Size([100, 4])
pred after: torch.Size([100, 7, 4])
torch.Size([100, 5, 5])
Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[ 4.4071,  4.0851, -1.2368,  ..., -2.1048, -3.1470,  2.3016],
         [-2.9292,  1.4760,  3.4930,  ..., -2.6444, -1.6790,  2.5147],
         [ 4.5476,  2.3783, -2.6152,  ..., -0.9352, -2.0824,  1.8356],
         ...,
         [ 4.0124,  1.6496, -2.6802,  ..., -0.8926, -1.9044,  2.2624],
         [ 3.8147,  4.4765, -0.4075,  ..., -2.4723, -3.3298,  2.3261],
         [-2.9236,  1.7334,  3.7018,  ..., -2.5791, -1.5985,  2.0920]],

        [[ 4.0351,  2.6892, -2.0849,  ..., -1.7373, -2.8827,  3.0673],
         [-3.3117,  1.1547,  3.5434,  ..., -2.3646, -1.3971,  2.2044],
         [ 4.2397,  2.0851, -2.4994,  ..., -1.2682, -2.2275,  2.5585],
         ...,
         [ 3.2090,  1.4063, -2.0431,  ..., -1.4246, -2.0576,  3.0327],
         [ 3.6611,  3.3931, -1.4104,  ..., -2.2415, -3.4869,  3.4130],
         [-3.25

Before permute:
 Size: torch.Size([4, 100, 7])
Tensor: tensor([[[-3.2689,  0.9261,  3.5377,  ..., -2.3719, -1.1900,  2.2141],
         [-1.6751,  2.7573,  3.4373,  ..., -3.0660, -2.2942,  2.4696],
         [ 3.5023,  4.3074, -0.2410,  ..., -2.1535, -2.9899,  1.7496],
         ...,
         [-1.5573,  2.9998,  3.3961,  ..., -3.1419, -2.5410,  2.5895],
         [-3.4585,  0.7472,  3.5480,  ..., -2.2369, -1.0505,  2.0858],
         [-1.3383,  3.1048,  3.3402,  ..., -2.8655, -2.2722,  1.9552]],

        [[-3.4328,  0.7544,  3.3183,  ..., -2.1878, -1.2732,  2.2985],
         [-1.5921,  2.7879,  3.4127,  ..., -3.0487, -2.2661,  2.4037],
         [ 1.5205,  3.9267,  1.3819,  ..., -3.5485, -3.5987,  3.7609],
         ...,
         [-1.7731,  2.5808,  3.3547,  ..., -3.1635, -2.4004,  2.8510],
         [-3.4619,  1.1393,  3.6371,  ..., -2.1510, -1.2386,  1.8108],
         [-0.4722,  3.0309,  1.8063,  ..., -2.7439, -3.2288,  3.1973]],

        [[-3.3372,  1.0193,  3.5861,  ..., -2.3671, -1.2736, 

KeyboardInterrupt: 

In [115]:
def predict(model, input_seq, max_length=5, SOS_token=5, EOS_token=6):
    model.eval()
    
    target_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)

    # Asks model to give only one item, the next thing it thinks is most probable 
    # to continue the sentence
    for i in range(max_length):
        
        pred = model(input_seq, target_input)
        next_item = pred.topk(1)[1].view(-1)[0].item() # num with highest probability, twas view(-1)[-1] before
        next_item = torch.tensor([[next_item]], device=device)

        # Concatenate previous input with predicted best word
        target_input = torch.cat((target_input, next_item), dim=1)

        # Stop if model predicts end of sentence
        if next_item.view(-1).item() == EOS_token:
            break
        elif (len(target_input.view(-1)) >= len(input_seq[0])):
            break

    return target_input.view(-1).tolist()

In [165]:
example = torch.tensor([[5, 1, 1, 1, 6]]).type(torch.long)
result = predict(model, example)
print(f"Example")
print(f"Input: {example.view(-1).tolist()[1:-1]}")
print(f"Continuation: {result[1:-1]}")

example = torch.tensor([[5, 0, 0, 0, 6]]).type(torch.long)
result = predict(model, example)
print(f"Example")
print(f"Input: {example.view(-1).tolist()[1:-1]}")
print(f"Continuation: {result[1:-1]}")

example = torch.tensor([[5, 1, 0, 2, 6]]).type(torch.long)
result = predict(model, example)
print(f"Example")
print(f"Input: {example.view(-1).tolist()[1:-1]}")
print(f"Continuation: {result[1:-1]}")

torch.Size([1, 5, 5])
Before permute:
 Size: torch.Size([1, 1, 7])
Tensor: tensor([[[-2.6484,  1.9779,  3.5695, -1.8861, -2.8080, -1.9466,  2.5078]]],
       grad_fn=<AddBackward0>)
torch.Size([1, 5, 5])
Before permute:
 Size: torch.Size([2, 1, 7])
Tensor: tensor([[[-2.8414,  1.6805,  3.6075, -1.6949, -2.6785, -1.6684,  2.3490]],

        [[-2.6494,  1.9911,  3.4576, -1.8127, -2.7966, -2.0892,  2.6494]]],
       grad_fn=<AddBackward0>)
torch.Size([1, 5, 5])
Before permute:
 Size: torch.Size([3, 1, 7])
Tensor: tensor([[[-2.9311,  1.5419,  3.6023, -1.5876, -2.6211, -1.5697,  2.3138]],

        [[-2.8487,  1.7730,  3.5162, -1.6703, -2.6940, -1.8900,  2.5216]],

        [[-2.8487,  1.7730,  3.5162, -1.6703, -2.6940, -1.8900,  2.5216]]],
       grad_fn=<AddBackward0>)
torch.Size([1, 5, 5])
Before permute:
 Size: torch.Size([4, 1, 7])
Tensor: tensor([[[-2.9623,  1.4876,  3.5949, -1.5425, -2.6012, -1.5365,  2.3094]],

        [[-2.9396,  1.6643,  3.5434, -1.6017, -2.6434, -1.7866,  2.4550]],
