In [48]:
import pandas as pd
import nltk
import numpy as np
import csv
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchtext.data import Field
from torchtext.data import TabularDataset
from torchtext.data import Iterator, BucketIterator
from sklearn.metrics import confusion_matrix

In [49]:
tokenize = lambda x: x.split()
TEXT = Field(sequential=True, tokenize=tokenize, lower=True, fix_length=100)
LABEL = Field(sequential=False, use_vocab=False)

In [50]:
train_stances = pd.read_csv("fn_data/train_stances.csv")
train_bodies = pd.read_csv("fn_data/train_bodies.csv")

In [51]:
train = pd.read_csv("lstmstuff/train_data_consolidated.csv")
val = pd.read_csv("lstmstuff/val_data_consolidated.csv")

In [52]:
train_datafields = [("Body ID", None),
                 ("Headline", TEXT),
                    ("Body", TEXT),
                 ("Stance", LABEL)]

trn, vld = TabularDataset.splits(
               path="",
               train='train_data_related_only_12_2.csv', validation="val_data_related_only_12_2.csv",
               format='csv',
               skip_header=True,
               fields=train_datafields)

In [53]:
TEXT.build_vocab(trn, vectors = 'glove.6B.100d')

In [54]:
train_iter, val_iter = BucketIterator.splits(
 (trn, vld),
 batch_sizes=(128,128),
 sort_key=lambda x: len(x.Body), # the BucketIterator needs to be told what function it should use to group the data.
 sort_within_batch=False,
 repeat=False
)

In [55]:
class BatchGenerator:
    def __init__(self, dl, x_field1, x_field2, y_field):
        self.dl, self.x_field1, self.x_field2, self.y_field = dl, x_field1, x_field2, y_field
        
    def __len__(self):
        return len(self.dl)
    
    def __iter__(self):
        for batch in self.dl:
            X1 = getattr(batch, self.x_field1)
            X2 = getattr(batch, self.x_field2)
            y = getattr(batch, self.y_field)
            yield (X1, X2, y)
            
train_batch_it = BatchGenerator(train_iter, 'Body','Headline', 'Stance')
val_batch_it = BatchGenerator(val_iter, 'Body', 'Headline', 'Stance')

In [64]:
# class RNN(nn.Module):
#     def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
#         super().__init__()
        
#         self.embedding = nn.Embedding(input_dim, embedding_dim)
#         self.lstm_headline = nn.LSTM(embedding_dim, hidden_dim)
#         self.lstm_article = nn.LSTM(embedding_dim, hidden_dim)
#         self.fc = nn.Linear(hidden_dim, output_dim)
        
#     def forward(self, x):

#         #x = [sent len, batch size]
        
#         embedded = self.embedding(x)
        
#         #embedded = [sent len, batch size, emb dim]
        
#         output, hidden = self.rnn(embedded)
        
#         #output = [sent len, batch size, hid dim]
#         #hidden = [1, batch size, hid dim]
        
#         assert torch.equal(output[-1,:,:], hidden.squeeze(0))
        
#         return self.fc(hidden.squeeze(0))

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class LSTMClassifier(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, output_dim, batch_size):
        super(LSTMClassifier, self).__init__()
        self.hidden_dim = hidden_dim
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.lstm_headline = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, dropout = 0.25, num_layers = 2)
        self.lstm_article = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, dropout = 0.2, num_layers = 2)
        self.fc = nn.Linear(hidden_dim*2, output_dim*8)
        self.fc2 = nn.Linear(output_dim*8, output_dim)
        self.batch_size = batch_size
        self.hidden1 = None
        self.hidden2 = None
        
    def init_hidden(self, batch_size):
        h0 = Variable(torch.zeros(4, batch_size, self.hidden_dim))
        c0 = Variable(torch.zeros(4, batch_size, self.hidden_dim))
        return (h0, c0)
    
    def forward(self, headline, article):
        headline_emb = self.embed(headline)
        ""
        article_emb = self.embed(article)
        
        lstm_headline_out, hidden = self.lstm_headline(headline_emb, self.hidden1)
        
        sum_padded_out_lstm1 = 0
        for tensor in torch.split(lstm_headline_out, self.hidden_dim, dim=2):
            sum_padded_out_lstm1 += tensor

        lstm_article_out, hidden2 = self.lstm_article(article_emb, hidden)
        res = self.fc(torch.cat((lstm_article_out[-1, :, :self.hidden_dim], lstm_article_out[0, :, self.hidden_dim:]), 1))
        res2 = self.fc2(F.relu(res))
        return res2
        

In [65]:
model = LSTMClassifier(100, 64, len(TEXT.vocab), 3, 128)

In [70]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

In [71]:
counter = []
loss_history = [] 
iteration_number= 0

model.train()
# Train the model
for epoch in range(0,10):
    for i, data in enumerate(train_batch_it,0):
        img0, img1 , label = data
        print("inputs are")
        print(img0.shape)
        print(img1.shape)
        optimizer.zero_grad()
        _, batch_size = img0.shape
        model.hidden1 = model.init_hidden(batch_size)
        model.hidden2 = model.init_hidden(batch_size)
        output = model(img0,img1)
        print("OUTPUT")
        print(output)
        print(label)
        loss = criterion(output.squeeze(),label.squeeze().long())
        loss.backward()
        optimizer.step()
        if True:
            print("Epoch number {}\n Current loss {}\n".format(epoch,loss.item()))
            iteration_number +=10
            counter.append(iteration_number)
            loss_history.append(loss.item())


inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.0256,  0.1915,  0.0020],
        [-0.0382,  0.1906,  0.0050],
        [-0.0226,  0.1740, -0.0197],
        [-0.0296,  0.1906, -0.0086],
        [-0.0288,  0.1852, -0.0000],
        [-0.0232,  0.1704, -0.0117],
        [-0.0315,  0.1815, -0.0138],
        [-0.0270,  0.1762, -0.0178],
        [-0.0331,  0.2013,  0.0110],
        [-0.0165,  0.1962, -0.0109],
        [-0.0233,  0.1875, -0.0089],
        [-0.0191,  0.1975, -0.0089],
        [-0.0363,  0.1940, -0.0126],
        [-0.0211,  0.1942,  0.0033],
        [-0.0264,  0.1902, -0.0039],
        [-0.0205,  0.1870,  0.0049],
        [-0.0273,  0.1964, -0.0087],
        [-0.0235,  0.1866, -0.0080],
        [-0.0256,  0.1871, -0.0045],
        [-0.0335,  0.1980,  0.0060],
        [-0.0311,  0.1768,  0.0104],
        [-0.0277,  0.1891, -0.0161],
        [-0.0420,  0.1866,  0.0035],
        [-0.0437,  0.1934,  0.0064],
        [-0.0334,  0.1734, -0.0008],
        [-0

Epoch number 0
 Current loss 1.0339120626449585

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.0707,  0.2440,  0.0424],
        [-0.0626,  0.2392,  0.0364],
        [-0.0649,  0.2439,  0.0383],
        [-0.0783,  0.2349,  0.0401],
        [-0.0846,  0.2228,  0.0536],
        [-0.0459,  0.2481,  0.0245],
        [-0.0580,  0.2280,  0.0313],
        [-0.0593,  0.2413,  0.0327],
        [-0.0720,  0.2390,  0.0258],
        [-0.0700,  0.2290,  0.0345],
        [-0.0515,  0.2421,  0.0218],
        [-0.0692,  0.2247,  0.0312],
        [-0.0662,  0.2307,  0.0182],
        [-0.0528,  0.2178,  0.0195],
        [-0.0556,  0.2382,  0.0302],
        [-0.0765,  0.2275,  0.0314],
        [-0.0670,  0.2353,  0.0346],
        [-0.0535,  0.2292,  0.0224],
        [-0.0575,  0.2278,  0.0133],
        [-0.0603,  0.2295,  0.0269],
        [-0.0631,  0.2342,  0.0340],
        [-0.0615,  0.2333,  0.0268],
        [-0.0622,  0.2321,  0.0269],
        [-0.0520,  0.2217,  0.0216],

Epoch number 0
 Current loss 0.9965620636940002

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.1190,  0.2847,  0.0765],
        [-0.1102,  0.2777,  0.0731],
        [-0.1033,  0.2910,  0.0462],
        [-0.1207,  0.2989,  0.0762],
        [-0.1113,  0.3230,  0.0828],
        [-0.1138,  0.2903,  0.0745],
        [-0.1068,  0.2929,  0.0754],
        [-0.1037,  0.2911,  0.0702],
        [-0.1178,  0.2999,  0.0704],
        [-0.1213,  0.2931,  0.0639],
        [-0.1032,  0.2941,  0.0636],
        [-0.1066,  0.2974,  0.0607],
        [-0.1124,  0.2777,  0.0572],
        [-0.0957,  0.2858,  0.0635],
        [-0.1285,  0.2895,  0.0824],
        [-0.1143,  0.2934,  0.0559],
        [-0.1067,  0.2918,  0.0617],
        [-0.1191,  0.2880,  0.0776],
        [-0.1087,  0.2842,  0.0793],
        [-0.1205,  0.2658,  0.0642],
        [-0.0997,  0.3130,  0.0685],
        [-0.1197,  0.3057,  0.0730],
        [-0.1014,  0.3021,  0.0524],
        [-0.1145,  0.2839,  0.0740],

Epoch number 0
 Current loss 0.9758898615837097

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.1709,  0.3941,  0.1166],
        [-0.1600,  0.3731,  0.0954],
        [-0.1702,  0.3706,  0.1195],
        [-0.1716,  0.3850,  0.1030],
        [-0.1698,  0.3522,  0.1005],
        [-0.1660,  0.3693,  0.1159],
        [-0.1574,  0.4091,  0.1209],
        [-0.1704,  0.3625,  0.1229],
        [-0.1834,  0.3925,  0.1343],
        [-0.1604,  0.3827,  0.1138],
        [-0.1532,  0.3779,  0.0993],
        [-0.1828,  0.3810,  0.1231],
        [-0.1757,  0.3668,  0.1075],
        [-0.1730,  0.3846,  0.1104],
        [-0.1652,  0.3669,  0.1115],
        [-0.1679,  0.3873,  0.1186],
        [-0.1652,  0.4018,  0.1103],
        [-0.1619,  0.3968,  0.0881],
        [-0.1843,  0.3640,  0.1126],
        [-0.1505,  0.3540,  0.0893],
        [-0.1666,  0.3755,  0.1015],
        [-0.1615,  0.3864,  0.1206],
        [-0.1561,  0.3585,  0.0969],
        [-0.1593,  0.3397,  0.0968],

Epoch number 0
 Current loss 0.9127746820449829

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.2488,  0.4788,  0.1610],
        [-0.2206,  0.4870,  0.1390],
        [-0.2491,  0.5280,  0.1831],
        [-0.2497,  0.4946,  0.1658],
        [-0.2477,  0.4690,  0.1568],
        [-0.2395,  0.4818,  0.1629],
        [-0.2308,  0.4488,  0.1472],
        [-0.2279,  0.4940,  0.1582],
        [-0.2384,  0.4952,  0.1548],
        [-0.2416,  0.4701,  0.1535],
        [-0.2512,  0.4747,  0.1634],
        [-0.2396,  0.4815,  0.1678],
        [-0.2421,  0.4929,  0.1600],
        [-0.2091,  0.4653,  0.1386],
        [-0.2394,  0.4860,  0.1576],
        [-0.2549,  0.4924,  0.1705],
        [-0.2415,  0.4996,  0.1677],
        [-0.2388,  0.4681,  0.1444],
        [-0.2576,  0.4722,  0.1751],
        [-0.2466,  0.5131,  0.1686],
        [-0.2570,  0.5039,  0.1815],
        [-0.2355,  0.4869,  0.1612],
        [-0.2628,  0.4829,  0.1766],
        [-0.2423,  0.4627,  0.1618],

Epoch number 0
 Current loss 0.9108644723892212

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.3313,  0.6144,  0.2205],
        [-0.3285,  0.6378,  0.2219],
        [-0.3259,  0.6265,  0.2154],
        [-0.3658,  0.7208,  0.2472],
        [-0.3279,  0.6434,  0.2056],
        [-0.3196,  0.6110,  0.2197],
        [-0.3227,  0.7271,  0.2382],
        [-0.3392,  0.6267,  0.2118],
        [-0.3300,  0.6539,  0.2203],
        [-0.3212,  0.6435,  0.2110],
        [-0.3422,  0.6594,  0.2320],
        [-0.3436,  0.6147,  0.2234],
        [-0.3669,  0.7215,  0.2514],
        [-0.3419,  0.6448,  0.2199],
        [-0.3500,  0.6221,  0.2071],
        [-0.3237,  0.6108,  0.2238],
        [-0.3593,  0.6332,  0.2423],
        [-0.3142,  0.6104,  0.1969],
        [-0.3284,  0.6587,  0.2023],
        [-0.3316,  0.6384,  0.2184],
        [-0.3336,  0.6244,  0.2248],
        [-0.3304,  0.7350,  0.2499],
        [-0.3208,  0.6325,  0.2010],
        [-0.3256,  0.6566,  0.2075],

Epoch number 0
 Current loss 0.8495725393295288

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.4266,  0.7381,  0.2416],
        [-0.4456,  0.7757,  0.2448],
        [-0.4317,  0.8044,  0.2572],
        [-0.4439,  0.8693,  0.3069],
        [-0.4237,  0.7867,  0.2678],
        [-0.4203,  0.7420,  0.2404],
        [-0.4423,  0.8374,  0.2677],
        [-0.4759,  0.8341,  0.2851],
        [-0.3973,  0.7560,  0.2644],
        [-0.4482,  0.8530,  0.2774],
        [-0.4418,  0.7472,  0.2345],
        [-0.4563,  0.7748,  0.2771],
        [-0.4057,  0.7939,  0.2471],
        [-0.4359,  0.8397,  0.2934],
        [-0.4423,  0.8343,  0.2847],
        [-0.4097,  0.7996,  0.2367],
        [-0.4782,  0.8603,  0.2931],
        [-0.4586,  0.8081,  0.2657],
        [-0.4380,  0.8966,  0.2913],
        [-0.4555,  0.7849,  0.2778],
        [-0.4696,  0.8166,  0.2956],
        [-0.4491,  0.7983,  0.2742],
        [-0.4312,  0.8175,  0.2590],
        [-0.4513,  0.8221,  0.2813],

Epoch number 0
 Current loss 0.8258869647979736

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.5443,  1.0268,  0.3020],
        [-0.5234,  0.9547,  0.2838],
        [-0.5641,  1.1510,  0.3228],
        [-0.5373,  0.9828,  0.2942],
        [-0.5646,  1.0469,  0.3111],
        [-0.5479,  1.0223,  0.2918],
        [-0.5574,  0.9893,  0.2816],
        [-0.5659,  1.0351,  0.3078],
        [-0.5940,  1.1347,  0.3441],
        [-0.5683,  1.0726,  0.3160],
        [-0.5424,  1.0081,  0.3010],
        [-0.5710,  1.0261,  0.3017],
        [-0.5444,  0.9802,  0.2968],
        [-0.5403,  0.9734,  0.3002],
        [-0.5525,  1.0301,  0.2966],
        [-0.5725,  1.0837,  0.2975],
        [-0.5649,  1.0387,  0.2894],
        [-0.5491,  1.0635,  0.3006],
        [-0.5514,  1.0530,  0.3338],
        [-0.5116,  0.9402,  0.2706],
        [-0.5671,  1.1104,  0.3244],
        [-0.5393,  0.9746,  0.2752],
        [-0.5550,  1.0324,  0.2937],
        [-0.5688,  1.0242,  0.3250],

Epoch number 0
 Current loss 0.8915838003158569

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.6702,  1.1814,  0.3086],
        [-0.7072,  1.3424,  0.3405],
        [-0.7081,  1.2397,  0.3544],
        [-0.6589,  1.1772,  0.3229],
        [-0.6430,  1.2105,  0.2987],
        [-0.6819,  1.2325,  0.3360],
        [-0.6637,  1.2248,  0.3140],
        [-0.6256,  1.1643,  0.2978],
        [-0.7072,  1.3048,  0.3464],
        [-0.7218,  1.3886,  0.3587],
        [-0.6702,  1.2453,  0.3387],
        [-0.6851,  1.3318,  0.3437],
        [-0.6713,  1.2460,  0.3248],
        [-0.6863,  1.2461,  0.3453],
        [-0.6919,  1.2546,  0.3585],
        [-0.6925,  1.2944,  0.3428],
        [-0.6942,  1.3577,  0.3455],
        [-0.7457,  1.3714,  0.3735],
        [-0.6228,  1.0787,  0.2930],
        [-0.7001,  1.2772,  0.3200],
        [-0.7067,  1.3460,  0.3529],
        [-0.7103,  1.3127,  0.3469],
        [-0.7008,  1.3550,  0.3555],
        [-0.6927,  1.2652,  0.3503],

Epoch number 0
 Current loss 0.757834792137146

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8154,  1.4479,  0.4338],
        [-0.8927,  1.6393,  0.4673],
        [-0.8625,  1.5171,  0.4357],
        [-0.8469,  1.4999,  0.4493],
        [-0.8446,  1.4365,  0.4315],
        [-0.8228,  1.4807,  0.4195],
        [-0.7933,  1.3648,  0.4195],
        [-0.7665,  1.4007,  0.3793],
        [-0.8370,  1.4996,  0.4526],
        [-0.9168,  1.6052,  0.4520],
        [-0.8868,  1.5710,  0.4552],
        [-0.9025,  1.5912,  0.4665],
        [-0.8660,  1.4899,  0.4491],
        [-0.8189,  1.4620,  0.4155],
        [-0.8059,  1.4009,  0.4183],
        [-0.8248,  1.5077,  0.4274],
        [-0.7862,  1.4018,  0.4074],
        [-0.8791,  1.5699,  0.4559],
        [-0.7920,  1.4239,  0.4162],
        [-0.8426,  1.5101,  0.4309],
        [-0.8776,  1.5310,  0.4603],
        [-0.8770,  1.5485,  0.4542],
        [-0.7730,  1.3277,  0.4077],
        [-0.8332,  1.5045,  0.4277],


Epoch number 0
 Current loss 0.7180245518684387

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.0303,  1.6735,  0.5459],
        [-1.0192,  1.7036,  0.5412],
        [-0.9745,  1.5994,  0.5139],
        [-0.9443,  1.5424,  0.5172],
        [-0.9828,  1.5833,  0.5202],
        [-1.0601,  1.7817,  0.5768],
        [-1.0177,  1.7155,  0.5498],
        [-1.0475,  1.7311,  0.5646],
        [-0.8910,  1.4739,  0.4946],
        [-0.9972,  1.6475,  0.5273],
        [-1.0690,  1.7529,  0.5637],
        [-1.0146,  1.6930,  0.5394],
        [-1.0290,  1.6856,  0.5592],
        [-1.0310,  1.7370,  0.5526],
        [-1.0616,  1.7790,  0.5585],
        [-1.0309,  1.6917,  0.5505],
        [-1.0674,  1.8045,  0.5747],
        [-0.9436,  1.5130,  0.5046],
        [-0.9502,  1.5611,  0.5035],
        [-1.0751,  1.7418,  0.5858],
        [-1.0688,  1.7549,  0.5833],
        [-0.9952,  1.6294,  0.5513],
        [-0.8965,  1.4930,  0.5026],
        [-0.9566,  1.5013,  0.5126],

Epoch number 0
 Current loss 0.8423730134963989

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.1414,  1.7196,  0.6549],
        [-1.0967,  1.6728,  0.6322],
        [-1.1302,  1.7565,  0.6499],
        [-1.1488,  1.7678,  0.6812],
        [-1.0576,  1.5791,  0.6234],
        [-1.1338,  1.7974,  0.6557],
        [-1.1529,  1.7834,  0.6866],
        [-1.1739,  1.7644,  0.6797],
        [-1.2203,  1.9161,  0.6925],
        [-1.1549,  1.7325,  0.6714],
        [-1.0899,  1.6556,  0.6404],
        [-1.1105,  1.6842,  0.6395],
        [-1.1498,  1.7924,  0.6548],
        [-1.0446,  1.5943,  0.6070],
        [-1.0713,  1.6035,  0.5983],
        [-1.1019,  1.6929,  0.6338],
        [-1.2184,  1.8932,  0.7219],
        [-1.1639,  1.7447,  0.6655],
        [-1.0767,  1.6492,  0.6211],
        [-1.1772,  1.8037,  0.6725],
        [-1.1539,  1.8197,  0.6633],
        [-1.1542,  1.7068,  0.6640],
        [-1.0698,  1.6216,  0.6186],
        [-1.1215,  1.7695,  0.6454],

Epoch number 0
 Current loss 0.6851147413253784

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.1711,  1.6661,  0.7639],
        [-1.2041,  1.7197,  0.7498],
        [-1.0920,  1.5613,  0.6990],
        [-1.1592,  1.6575,  0.7494],
        [-1.2049,  1.7295,  0.7568],
        [-1.1588,  1.6096,  0.7273],
        [-1.2589,  1.7699,  0.8063],
        [-1.0841,  1.5437,  0.6772],
        [-1.2029,  1.7270,  0.7664],
        [-1.2412,  1.8308,  0.7811],
        [-1.1977,  1.6819,  0.7591],
        [-1.2235,  1.7229,  0.7671],
        [-1.2385,  1.7407,  0.7742],
        [-1.1185,  1.5478,  0.6993],
        [-1.1878,  1.6822,  0.7381],
        [-1.2259,  1.7358,  0.7787],
        [-1.2353,  1.8016,  0.7864],
        [-1.2063,  1.6891,  0.7574],
        [-1.1108,  1.5551,  0.6885],
        [-1.2180,  1.7230,  0.7470],
        [-1.1908,  1.6762,  0.7350],
        [-1.1682,  1.5743,  0.7344],
        [-1.1322,  1.6291,  0.7176],
        [-1.2010,  1.7046,  0.7515],

Epoch number 0
 Current loss 0.8017083406448364

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.2478,  1.6235,  0.8376],
        [-1.2565,  1.6675,  0.8441],
        [-1.2019,  1.6470,  0.8063],
        [-1.2095,  1.5937,  0.8077],
        [-1.1684,  1.5237,  0.7860],
        [-1.1070,  1.4727,  0.7471],
        [-1.2714,  1.6776,  0.8728],
        [-1.1857,  1.5968,  0.8035],
        [-1.1799,  1.5556,  0.7950],
        [-1.1912,  1.5349,  0.8030],
        [-1.2561,  1.6638,  0.8359],
        [-1.2967,  1.7392,  0.8752],
        [-1.2039,  1.5899,  0.8012],
        [-1.2570,  1.7160,  0.8843],
        [-1.2573,  1.7062,  0.8434],
        [-1.2125,  1.5735,  0.8021],
        [-1.1793,  1.5374,  0.8067],
        [-1.2349,  1.6058,  0.8491],
        [-1.2291,  1.6510,  0.8409],
        [-1.2303,  1.6123,  0.8338],
        [-1.2784,  1.6941,  0.8766],
        [-1.2321,  1.6549,  0.8313],
        [-1.1673,  1.5603,  0.7801],
        [-1.2294,  1.5915,  0.8269],

Epoch number 0
 Current loss 0.8638174533843994

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.2407,  1.5732,  0.8482],
        [-1.1256,  1.4202,  0.7897],
        [-1.2115,  1.5134,  0.8368],
        [-1.2250,  1.7092,  0.8590],
        [-1.2830,  1.6594,  0.8599],
        [-1.1617,  1.4808,  0.7951],
        [-1.1435,  1.4848,  0.8055],
        [-1.1347,  1.4636,  0.7821],
        [-1.1496,  1.5206,  0.8156],
        [-1.1858,  1.5346,  0.8211],
        [-1.1415,  1.4339,  0.7873],
        [-1.1567,  1.5234,  0.7812],
        [-1.1806,  1.5562,  0.8195],
        [-1.1149,  1.4380,  0.7661],
        [-1.1631,  1.4404,  0.8169],
        [-1.1929,  1.5576,  0.8294],
        [-1.1476,  1.5333,  0.7879],
        [-1.1184,  1.4829,  0.7566],
        [-1.1295,  1.5315,  0.7767],
        [-1.3032,  1.7676,  0.8872],
        [-1.1530,  1.5281,  0.8045],
        [-1.2237,  1.5643,  0.8371],
        [-1.2232,  1.6074,  0.8600],
        [-1.1909,  1.6669,  0.8389],

Epoch number 0
 Current loss 1.0255508422851562

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.1323,  1.4849,  0.7858],
        [-1.0425,  1.3848,  0.7218],
        [-1.0865,  1.4740,  0.7441],
        [-1.0859,  1.5046,  0.7300],
        [-1.0615,  1.4214,  0.7313],
        [-1.1204,  1.4692,  0.7796],
        [-1.0735,  1.4722,  0.7389],
        [-1.1296,  1.5223,  0.7820],
        [-1.0949,  1.4598,  0.7143],
        [-1.1803,  1.6337,  0.8005],
        [-1.0904,  1.4242,  0.7521],
        [-1.1413,  1.6107,  0.7771],
        [-1.1152,  1.4577,  0.7771],
        [-1.1117,  1.4812,  0.7637],
        [-1.1037,  1.5201,  0.7409],
        [-1.1491,  1.5674,  0.7882],
        [-1.1037,  1.4094,  0.7500],
        [-1.0807,  1.4501,  0.7375],
        [-1.0849,  1.4680,  0.7468],
        [-1.1360,  1.5822,  0.7680],
        [-1.0867,  1.5122,  0.7469],
        [-1.0340,  1.3296,  0.7220],
        [-1.0628,  1.4092,  0.7162],
        [-1.0592,  1.4294,  0.7169],

Epoch number 0
 Current loss 0.6995426416397095

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.0147,  1.4084,  0.6948],
        [-1.0023,  1.5134,  0.6798],
        [-1.0168,  1.3914,  0.6915],
        [-0.9820,  1.4663,  0.6539],
        [-1.0234,  1.4243,  0.6914],
        [-1.0592,  1.5313,  0.6956],
        [-1.0222,  1.4263,  0.6886],
        [-1.0017,  1.4437,  0.6767],
        [-0.9441,  1.3154,  0.6284],
        [-1.0294,  1.4497,  0.6885],
        [-1.0295,  1.4855,  0.6793],
        [-1.0756,  1.5512,  0.7281],
        [-0.9783,  1.4849,  0.6734],
        [-1.0290,  1.4381,  0.6806],
        [-1.0161,  1.6056,  0.6775],
        [-0.9360,  1.3088,  0.6297],
        [-1.0317,  1.5186,  0.6865],
        [-0.9962,  1.3483,  0.6667],
        [-1.0167,  1.3971,  0.6858],
        [-0.9701,  1.3423,  0.6648],
        [-1.0058,  1.3873,  0.6716],
        [-0.9264,  1.1874,  0.6104],
        [-1.0259,  1.4903,  0.6961],
        [-1.0021,  1.5148,  0.6904],

Epoch number 0
 Current loss 0.7063367962837219

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9514,  1.3819,  0.6183],
        [-0.9501,  1.4162,  0.6038],
        [-0.9366,  1.4034,  0.6302],
        [-0.9500,  1.4596,  0.6050],
        [-0.9123,  1.4834,  0.5901],
        [-0.8607,  1.1506,  0.5615],
        [-0.9783,  1.4135,  0.6318],
        [-0.9932,  1.5182,  0.6302],
        [-0.9627,  1.3400,  0.6137],
        [-0.8817,  1.2605,  0.5672],
        [-0.9992,  1.5151,  0.6348],
        [-0.9300,  1.3317,  0.5893],
        [-0.9077,  1.3247,  0.6066],
        [-0.9450,  1.4240,  0.6116],
        [-1.0055,  1.5849,  0.6211],
        [-0.9705,  1.4977,  0.6017],
        [-0.8845,  1.2696,  0.5851],
        [-0.8621,  1.2513,  0.5618],
        [-0.8603,  1.3370,  0.5578],
        [-0.9357,  1.3742,  0.5828],
        [-0.9380,  1.4843,  0.5814],
        [-0.8870,  1.2523,  0.5794],
        [-0.9451,  1.4573,  0.6029],
        [-0.9540,  1.4535,  0.6073],

Epoch number 0
 Current loss 0.7841851711273193

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8909,  1.3372,  0.5712],
        [-0.8694,  1.2080,  0.5483],
        [-0.8828,  1.2230,  0.5711],
        [-0.9521,  1.5460,  0.5768],
        [-0.8383,  1.2622,  0.5460],
        [-0.8185,  1.1508,  0.5192],
        [-0.9131,  1.3583,  0.5889],
        [-0.9039,  1.3712,  0.5623],
        [-0.9246,  1.4517,  0.5394],
        [-0.9074,  1.3066,  0.5794],
        [-0.9211,  1.4394,  0.5409],
        [-0.9376,  1.5065,  0.5630],
        [-0.9142,  1.3730,  0.5524],
        [-0.9394,  1.4725,  0.5856],
        [-0.8784,  1.3502,  0.5527],
        [-0.9293,  1.3565,  0.5807],
        [-0.8931,  1.4290,  0.5446],
        [-0.9222,  1.3930,  0.5867],
        [-0.8938,  1.4804,  0.5511],
        [-0.9879,  1.5695,  0.5787],
        [-0.8970,  1.3326,  0.5549],
        [-0.8926,  1.4236,  0.5482],
        [-0.8964,  1.4229,  0.5442],
        [-0.9212,  1.4586,  0.5515],

Epoch number 0
 Current loss 0.7292150259017944

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8435,  1.3659,  0.4893],
        [-0.9271,  1.4723,  0.5353],
        [-0.8644,  1.3629,  0.5223],
        [-0.8271,  1.3003,  0.5025],
        [-0.9233,  1.4296,  0.5340],
        [-0.8774,  1.3773,  0.5360],
        [-0.8797,  1.3138,  0.5508],
        [-0.8720,  1.2937,  0.5502],
        [-0.9286,  1.3978,  0.5551],
        [-0.9084,  1.3189,  0.5581],
        [-0.7574,  1.1226,  0.5011],
        [-0.9107,  1.4665,  0.5291],
        [-0.9163,  1.5129,  0.5322],
        [-0.8970,  1.4189,  0.5329],
        [-0.9147,  1.5162,  0.5211],
        [-0.9387,  1.5981,  0.5329],
        [-0.8989,  1.4283,  0.5385],
        [-0.9034,  1.4762,  0.5305],
        [-0.8345,  1.2207,  0.5590],
        [-0.8596,  1.3970,  0.5311],
        [-0.9125,  1.3881,  0.5329],
        [-0.8913,  1.5176,  0.5274],
        [-0.9573,  1.5944,  0.5426],
        [-0.9421,  1.4574,  0.5506],

Epoch number 0
 Current loss 0.7436034083366394

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8379,  1.3762,  0.5118],
        [-0.9082,  1.5557,  0.4948],
        [-0.8373,  1.3635,  0.5133],
        [-0.8399,  1.4430,  0.5014],
        [-0.9033,  1.5858,  0.5016],
        [-0.9054,  1.6061,  0.4863],
        [-0.9321,  1.6048,  0.5124],
        [-0.8883,  1.3550,  0.5025],
        [-0.9286,  1.6105,  0.4983],
        [-0.9262,  1.5155,  0.5183],
        [-0.8965,  1.5546,  0.4738],
        [-0.8259,  1.4433,  0.4693],
        [-0.8893,  1.4667,  0.5070],
        [-0.9137,  1.3761,  0.5410],
        [-0.8728,  1.3948,  0.5251],
        [-0.8950,  1.3663,  0.5172],
        [-0.8806,  1.3953,  0.5027],
        [-0.8929,  1.3370,  0.5468],
        [-0.9161,  1.5140,  0.5167],
        [-0.8708,  1.5186,  0.4950],
        [-0.8118,  1.5511,  0.4416],
        [-0.8978,  1.5766,  0.4931],
        [-0.9035,  1.4181,  0.5347],
        [-0.8897,  1.4103,  0.5005],

Epoch number 0
 Current loss 0.7880611419677734

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8593,  1.3476,  0.5077],
        [-0.8863,  1.3703,  0.5171],
        [-0.8920,  1.3830,  0.4841],
        [-0.9420,  1.3912,  0.5534],
        [-0.9232,  1.7269,  0.4482],
        [-0.9204,  1.5932,  0.4925],
        [-0.8951,  1.3650,  0.5080],
        [-0.9237,  1.4944,  0.5186],
        [-0.8400,  1.2837,  0.4930],
        [-0.8607,  1.2998,  0.5159],
        [-0.9112,  1.4963,  0.5077],
        [-0.8622,  1.2346,  0.5113],
        [-0.9011,  1.5034,  0.4967],
        [-0.9527,  1.5763,  0.4943],
        [-0.8429,  1.3232,  0.5023],
        [-0.9106,  1.3862,  0.5217],
        [-0.9307,  1.7394,  0.4730],
        [-0.9213,  1.6331,  0.4914],
        [-0.8910,  1.3312,  0.5605],
        [-0.8658,  1.4225,  0.4858],
        [-0.8486,  1.4532,  0.4958],
        [-0.8542,  1.2813,  0.4922],
        [-0.8659,  1.4268,  0.5036],
        [-0.9085,  1.4737,  0.5253],

Epoch number 0
 Current loss 0.6476776003837585

inputs are
torch.Size([100, 87])
torch.Size([100, 87])
OUTPUT
tensor([[-0.9879,  1.7425,  0.4944],
        [-0.9487,  1.7666,  0.4545],
        [-0.9353,  1.5747,  0.4819],
        [-0.9476,  1.6748,  0.4804],
        [-0.9503,  1.5605,  0.5372],
        [-0.9158,  1.3568,  0.5444],
        [-0.8230,  1.1864,  0.5058],
        [-0.9487,  1.4838,  0.5282],
        [-0.8911,  1.3726,  0.5167],
        [-0.9216,  1.4424,  0.5106],
        [-0.9407,  1.5627,  0.5028],
        [-0.9485,  1.7138,  0.4739],
        [-0.9566,  1.7182,  0.4611],
        [-0.8782,  1.2497,  0.5409],
        [-0.8881,  1.3303,  0.5490],
        [-0.8978,  1.4876,  0.4801],
        [-0.8469,  1.3184,  0.5104],
        [-0.8291,  1.2741,  0.4819],
        [-0.9042,  1.4870,  0.4963],
        [-0.8592,  1.5718,  0.4419],
        [-0.8596,  1.5641,  0.4828],
        [-0.9040,  1.4287,  0.4986],
        [-0.9213,  1.5070,  0.4939],
        [-0.9097,  1.3835,  0.5221],
 

Epoch number 0
 Current loss 0.7337682247161865

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9480,  1.6018,  0.4907],
        [-1.0032,  1.7755,  0.4611],
        [-1.0075,  1.7356,  0.4954],
        [-0.8116,  1.0841,  0.5300],
        [-0.9713,  1.7652,  0.4659],
        [-0.9996,  1.6810,  0.5073],
        [-1.0014,  1.8359,  0.4632],
        [-0.9898,  1.3921,  0.5754],
        [-0.9723,  1.6269,  0.4773],
        [-0.8907,  1.1332,  0.5791],
        [-0.9400,  1.5310,  0.5108],
        [-0.8948,  1.4174,  0.5317],
        [-0.9740,  1.6567,  0.4913],
        [-0.9589,  1.8031,  0.4686],
        [-0.9460,  1.4081,  0.5290],
        [-0.9704,  1.4917,  0.5344],
        [-0.9326,  1.6363,  0.4578],
        [-0.9120,  1.5653,  0.5208],
        [-0.9826,  1.8089,  0.4678],
        [-0.9758,  1.8465,  0.4473],
        [-0.9513,  1.5593,  0.5059],
        [-0.9557,  1.4539,  0.5377],
        [-0.9805,  1.4900,  0.5461],
        [-0.9135,  1.5281,  0.4955],

Epoch number 0
 Current loss 0.7145039439201355

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8815,  1.3463,  0.4930],
        [-0.9804,  1.4367,  0.5587],
        [-0.9421,  1.5543,  0.4853],
        [-1.0055,  1.6408,  0.5161],
        [-0.9366,  1.4654,  0.5175],
        [-0.9196,  1.3625,  0.5414],
        [-0.9394,  1.5845,  0.5026],
        [-0.8793,  1.3262,  0.5218],
        [-0.9766,  1.4383,  0.5513],
        [-0.9757,  1.4609,  0.5622],
        [-0.9160,  1.4308,  0.5101],
        [-1.0177,  1.7832,  0.5220],
        [-1.0288,  1.8941,  0.4647],
        [-0.9588,  1.7809,  0.4496],
        [-0.9868,  1.8022,  0.4788],
        [-1.0477,  2.0153,  0.4521],
        [-1.0132,  1.7337,  0.4767],
        [-0.9854,  1.4572,  0.5583],
        [-0.9656,  1.8221,  0.4502],
        [-0.9636,  1.4849,  0.5161],
        [-0.9888,  1.8753,  0.4385],
        [-0.9014,  1.4636,  0.4935],
        [-0.7528,  0.9726,  0.4725],
        [-0.9757,  1.4127,  0.5621],

Epoch number 0
 Current loss 0.7669123411178589

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.0666,  2.1190,  0.3932],
        [-0.9093,  1.2742,  0.5360],
        [-1.0176,  1.6215,  0.5222],
        [-1.0057,  1.4150,  0.5876],
        [-0.9802,  1.6696,  0.4860],
        [-0.9761,  1.5585,  0.5010],
        [-1.0067,  1.8690,  0.4428],
        [-0.9536,  1.6538,  0.4574],
        [-1.0168,  1.5457,  0.5595],
        [-1.0302,  2.1197,  0.4106],
        [-1.0268,  2.0060,  0.4242],
        [-1.0010,  1.7270,  0.4783],
        [-0.9927,  1.8165,  0.4538],
        [-1.0109,  1.8580,  0.4546],
        [-0.9233,  1.5460,  0.4728],
        [-0.8718,  1.1966,  0.5474],
        [-0.9641,  1.5218,  0.5243],
        [-1.0233,  1.8578,  0.4605],
        [-1.0140,  1.6241,  0.5306],
        [-0.9919,  1.2502,  0.6052],
        [-1.0070,  1.9491,  0.4123],
        [-1.0310,  1.6974,  0.5556],
        [-1.0486,  2.1405,  0.3914],
        [-1.0065,  2.0258,  0.4025],

Epoch number 0
 Current loss 0.7640464305877686

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9941,  1.9688,  0.3977],
        [-1.0159,  1.8465,  0.4592],
        [-0.9251,  1.4902,  0.4982],
        [-0.9533,  1.3320,  0.5632],
        [-0.9486,  1.3626,  0.5359],
        [-0.8980,  1.1736,  0.5764],
        [-0.8887,  1.2101,  0.5670],
        [-0.9748,  1.6415,  0.4792],
        [-0.8908,  1.2951,  0.5193],
        [-1.0084,  1.9093,  0.4318],
        [-0.9108,  1.3007,  0.5500],
        [-1.0703,  2.0225,  0.4451],
        [-1.0539,  1.8301,  0.4795],
        [-0.9997,  1.7518,  0.4631],
        [-1.0670,  2.1218,  0.4256],
        [-1.0396,  1.8324,  0.4662],
        [-1.0613,  2.1910,  0.4172],
        [-1.0048,  1.9662,  0.4147],
        [-0.9814,  1.9420,  0.3849],
        [-0.9247,  1.4285,  0.5041],
        [-0.9869,  1.9247,  0.4294],
        [-0.8994,  0.9903,  0.6005],
        [-1.0260,  1.7266,  0.5132],
        [-1.0296,  2.0813,  0.3815],

Epoch number 0
 Current loss 0.7390372157096863

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.0906,  2.0750,  0.4523],
        [-0.9501,  1.4766,  0.5361],
        [-1.0897,  1.9872,  0.4757],
        [-1.0043,  1.8984,  0.4365],
        [-0.9521,  1.2625,  0.5976],
        [-1.0170,  2.0781,  0.4060],
        [-1.0098,  1.6456,  0.5081],
        [-0.9904,  1.8042,  0.4405],
        [-1.0344,  1.6789,  0.5121],
        [-0.9265,  1.1624,  0.5979],
        [-0.9989,  1.4919,  0.5680],
        [-0.9766,  1.2607,  0.6239],
        [-0.7916,  0.8953,  0.5397],
        [-0.9524,  1.3371,  0.5527],
        [-1.0739,  2.0685,  0.4553],
        [-0.9093,  1.1905,  0.5773],
        [-0.9624,  1.5050,  0.5241],
        [-1.0372,  1.9236,  0.4478],
        [-1.0270,  1.9226,  0.4480],
        [-1.0216,  2.0348,  0.3941],
        [-1.0059,  1.5802,  0.5488],
        [-0.9569,  1.5679,  0.4793],
        [-1.0116,  1.6566,  0.5170],
        [-1.0219,  1.3034,  0.6282],

Epoch number 0
 Current loss 0.7839645147323608

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.1003,  2.0796,  0.4656],
        [-0.8788,  1.1267,  0.5857],
        [-1.0034,  1.7018,  0.5130],
        [-1.0295,  1.2464,  0.6558],
        [-0.9561,  1.3007,  0.5847],
        [-1.1392,  2.0950,  0.4934],
        [-0.9157,  0.9594,  0.6432],
        [-1.0244,  1.6045,  0.5064],
        [-0.9967,  1.3987,  0.5844],
        [-0.9942,  1.3528,  0.6192],
        [-1.1031,  1.6829,  0.5575],
        [-0.9877,  1.6324,  0.4997],
        [-0.8946,  1.2388,  0.5282],
        [-1.0113,  1.6016,  0.5327],
        [-1.1115,  1.8980,  0.5152],
        [-1.1683,  2.0528,  0.5288],
        [-1.0166,  1.6881,  0.5164],
        [-1.1449,  2.1157,  0.4930],
        [-0.9363,  1.2491,  0.5816],
        [-0.9401,  1.2395,  0.5976],
        [-1.0880,  1.8961,  0.5337],
        [-1.0381,  1.6370,  0.5611],
        [-1.0953,  1.9201,  0.5226],
        [-1.0472,  1.8338,  0.5094],

Epoch number 0
 Current loss 0.6486985087394714

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9124,  0.8089,  0.6724],
        [-0.9839,  1.4981,  0.5313],
        [-1.1140,  2.1730,  0.4246],
        [-1.1529,  2.4099,  0.3967],
        [-1.0963,  1.9758,  0.4744],
        [-1.0002,  1.6044,  0.5297],
        [-0.9222,  1.0004,  0.6339],
        [-1.0853,  2.0006,  0.4581],
        [-1.0378,  1.1971,  0.6975],
        [-1.0727,  1.8231,  0.4858],
        [-1.0918,  1.8795,  0.5153],
        [-1.0131,  0.9536,  0.7378],
        [-1.0727,  1.8956,  0.4649],
        [-0.9440,  0.9338,  0.6746],
        [-1.0856,  2.1345,  0.4253],
        [-0.9499,  1.1272,  0.6385],
        [-1.0706,  2.0268,  0.4341],
        [-1.0380,  1.3329,  0.6492],
        [-1.0114,  1.2800,  0.6212],
        [-1.0603,  1.9145,  0.4414],
        [-0.9071,  0.8358,  0.6684],
        [-0.9178,  0.8543,  0.6612],
        [-1.0522,  2.0063,  0.4248],
        [-0.9893,  1.0317,  0.7002],

Epoch number 0
 Current loss 0.6342433094978333

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.0296,  1.2571,  0.6392],
        [-1.1565,  2.1907,  0.4620],
        [-0.9350,  1.0923,  0.6488],
        [-0.9958,  0.9414,  0.7393],
        [-1.0105,  1.0809,  0.6805],
        [-1.0239,  1.4647,  0.5952],
        [-1.1742,  2.0037,  0.5367],
        [-0.9359,  1.1053,  0.6473],
        [-1.0260,  0.8629,  0.7690],
        [-1.1419,  1.9395,  0.4881],
        [-1.0758,  1.9213,  0.4590],
        [-1.1088,  1.7693,  0.5495],
        [-0.9423,  0.7773,  0.7047],
        [-1.0721,  1.3717,  0.6377],
        [-1.1704,  1.4680,  0.7023],
        [-1.1110,  2.1382,  0.4123],
        [-1.0062,  0.9021,  0.7561],
        [-1.0953,  2.1561,  0.4049],
        [-1.0244,  0.9981,  0.7437],
        [-0.9823,  0.8677,  0.7110],
        [-1.1585,  2.3219,  0.4269],
        [-0.9025,  0.7973,  0.6419],
        [-1.0198,  1.0094,  0.7342],
        [-1.1715,  2.1616,  0.4603],

Epoch number 0
 Current loss 0.7564306855201721

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.2123,  2.5596,  0.3568],
        [-1.0554,  1.5974,  0.5213],
        [-1.0688,  2.0022,  0.4043],
        [-1.1222,  2.2604,  0.3389],
        [-1.1286,  2.2358,  0.3768],
        [-0.9339,  0.6590,  0.6950],
        [-1.0888,  1.1548,  0.7342],
        [-1.0719,  1.2403,  0.6918],
        [-1.1115,  2.3679,  0.3022],
        [-1.1597,  2.4267,  0.3095],
        [-1.1454,  2.0668,  0.4397],
        [-1.1146,  1.5037,  0.6252],
        [-1.0883,  1.8310,  0.4944],
        [-1.0404,  0.9815,  0.7301],
        [-1.1680,  1.9345,  0.4861],
        [-1.1738,  2.0274,  0.4850],
        [-0.9025,  0.7023,  0.6626],
        [-1.0588,  1.1181,  0.6975],
        [-0.9345,  1.0034,  0.6348],
        [-1.1458,  2.4160,  0.3110],
        [-0.9048,  0.7061,  0.6756],
        [-1.1002,  0.7679,  0.8544],
        [-1.0944,  1.1321,  0.7528],
        [-0.9576,  0.7039,  0.7279],

Epoch number 0
 Current loss 0.5719172954559326

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.1825,  2.3734,  0.3450],
        [-1.1502,  2.4043,  0.2828],
        [-1.0844,  1.7159,  0.5065],
        [-1.1330,  2.2811,  0.3394],
        [-1.0678,  1.5899,  0.5519],
        [-1.1329,  1.4535,  0.6535],
        [-1.0037,  0.8351,  0.7091],
        [-1.1498,  2.5220,  0.2382],
        [-0.9608,  0.5293,  0.7500],
        [-1.0949,  1.8697,  0.4377],
        [-0.8726,  0.5262,  0.6595],
        [-1.0401,  0.5960,  0.8053],
        [-1.1011,  2.2691,  0.2675],
        [-1.1089,  0.8603,  0.8377],
        [-1.0703,  1.4784,  0.5736],
        [-1.0863,  2.3343,  0.2700],
        [-1.1398,  1.0670,  0.8066],
        [-1.0925,  1.1597,  0.7220],
        [-1.1601,  2.4359,  0.2955],
        [-1.1349,  2.2969,  0.2850],
        [-1.1066,  1.0914,  0.7873],
        [-0.9921,  1.0167,  0.6756],
        [-1.1168,  2.3537,  0.2761],
        [-1.1080,  2.2881,  0.2951],

Epoch number 0
 Current loss 0.6825023889541626

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.0296,  1.9857,  0.2869],
        [-0.9471,  0.7176,  0.7085],
        [-0.9879,  0.5032,  0.7769],
        [-1.0020,  0.9808,  0.6368],
        [-1.0405,  1.9788,  0.3370],
        [-1.1526,  2.6266,  0.1629],
        [-1.0745,  2.3566,  0.2184],
        [-1.0818,  2.4073,  0.1916],
        [-1.0408,  1.6819,  0.4194],
        [-1.0392,  2.3963,  0.1477],
        [-0.9888,  1.8787,  0.3147],
        [-0.8072,  0.5504,  0.5936],
        [-1.0668,  1.2872,  0.6104],
        [-0.9790,  0.5528,  0.7503],
        [-1.0889,  2.5672,  0.1437],
        [-1.0596,  2.1308,  0.3049],
        [-1.0379,  1.7351,  0.4262],
        [-1.0474,  2.2620,  0.2179],
        [-1.0447,  2.3811,  0.1937],
        [-1.0508,  2.0443,  0.3252],
        [-1.1140,  2.5424,  0.1626],
        [-0.9824,  1.9075,  0.3103],
        [-0.9775,  0.4440,  0.7881],
        [-1.0802,  1.6157,  0.5013],

Epoch number 0
 Current loss 0.6643162965774536

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9743,  0.4985,  0.7870],
        [-1.0822,  2.3702,  0.1922],
        [-0.9862,  0.5549,  0.7832],
        [-0.9941,  2.2752,  0.1399],
        [-1.0355,  2.5095,  0.0739],
        [-1.0603,  2.5292,  0.0901],
        [-1.0781,  1.1459,  0.6466],
        [-1.0691,  2.3717,  0.1999],
        [-1.0353,  1.0573,  0.6690],
        [-0.9856,  1.0333,  0.6002],
        [-0.9600,  0.7161,  0.6930],
        [-0.9886,  2.0304,  0.2721],
        [-1.0195,  2.5088,  0.0679],
        [-1.0485,  0.9805,  0.7069],
        [-0.8407,  0.3487,  0.6762],
        [-0.9747,  0.7685,  0.6906],
        [-0.9029,  0.5953,  0.6829],
        [-1.0556,  2.4037,  0.1387],
        [-1.0648,  2.6207,  0.0913],
        [-1.0174,  2.3751,  0.1248],
        [-0.9748,  0.7439,  0.6945],
        [-1.0495,  1.3900,  0.5367],
        [-0.9433,  1.2276,  0.4947],
        [-0.9642,  1.7211,  0.3469],

Epoch number 0
 Current loss 0.7651817202568054

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9189,  0.3778,  0.7813],
        [-0.9732,  0.2803,  0.8310],
        [-0.8940,  0.8218,  0.6142],
        [-0.9840,  1.7918,  0.2965],
        [-1.0314,  2.0373,  0.2615],
        [-1.0271,  2.4214,  0.1359],
        [-0.9484,  0.2404,  0.8199],
        [-1.0108,  0.5662,  0.7916],
        [-0.9514,  0.7736,  0.6598],
        [-1.0449,  1.8194,  0.3588],
        [-0.9706,  0.5335,  0.7556],
        [-0.8993,  0.6385,  0.6846],
        [-0.9610,  1.9710,  0.2109],
        [-1.0107,  0.4143,  0.8567],
        [-0.9544,  2.3890,  0.0378],
        [-0.9238,  2.0461,  0.1707],
        [-0.9856,  2.2057,  0.1787],
        [-1.0334,  2.2244,  0.1696],
        [-0.9730,  2.4448,  0.0604],
        [-0.9368,  0.5291,  0.7497],
        [-0.9288,  0.5500,  0.7008],
        [-0.9334,  0.2993,  0.7971],
        [-0.9975,  2.3188,  0.1298],
        [-0.9784,  0.3051,  0.8626],

Epoch number 0
 Current loss 0.700944185256958

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9734,  1.0530,  0.5724],
        [-1.0092,  0.7320,  0.6894],
        [-0.9748,  1.8918,  0.2357],
        [-0.9423,  0.3943,  0.7836],
        [-0.9788,  2.3041,  0.1204],
        [-0.9406,  2.2505,  0.0834],
        [-0.9255,  0.5113,  0.7241],
        [-1.0034,  1.8228,  0.2743],
        [-0.9654,  1.2395,  0.4981],
        [-0.9614,  1.5291,  0.3945],
        [-0.8836,  0.5389,  0.6885],
        [-0.9126,  2.1817,  0.0680],
        [-0.8968,  1.3499,  0.3765],
        [-0.8870,  0.4418,  0.7257],
        [-0.9934,  1.5164,  0.4243],
        [-0.9311,  2.0589,  0.1563],
        [-0.9553,  2.4978, -0.0386],
        [-1.0118,  2.4499,  0.1054],
        [-1.0834,  1.6282,  0.4640],
        [-0.8952,  2.0360,  0.1258],
        [-0.9909,  1.1012,  0.5946],
        [-0.9641,  1.8546,  0.2376],
        [-0.9991,  0.9912,  0.6032],
        [-0.9405,  1.9505,  0.2082],


Epoch number 0
 Current loss 0.6258754134178162

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8790,  0.3238,  0.7450],
        [-0.8897,  0.7926,  0.5319],
        [-0.8646,  1.4373,  0.2874],
        [-0.8801,  2.1565,  0.0429],
        [-1.0269,  0.2278,  0.8867],
        [-0.8659,  1.1580,  0.4259],
        [-0.8557,  0.2497,  0.7085],
        [-0.9100,  0.8976,  0.5384],
        [-0.9405,  1.4958,  0.3312],
        [-0.8751,  1.8347,  0.1568],
        [-0.9276,  2.0246,  0.1116],
        [-0.9074,  0.9908,  0.4935],
        [-0.9222,  2.4476, -0.0929],
        [-0.9014,  1.4304,  0.3306],
        [-1.0689,  2.0441,  0.2684],
        [-0.8688,  1.8725,  0.1216],
        [-0.9068,  1.3009,  0.4079],
        [-0.9489,  1.9605,  0.2158],
        [-0.8450,  0.3520,  0.6981],
        [-0.9451,  1.8658,  0.2451],
        [-0.9472,  2.5312, -0.0831],
        [-1.0736,  0.5070,  0.8802],
        [-0.8586,  0.2762,  0.7238],
        [-0.9114,  0.2853,  0.7641],

Epoch number 0
 Current loss 0.7342166304588318

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9954,  2.6617, -0.1490],
        [-0.9391,  2.4443, -0.0905],
        [-0.9338,  2.4667, -0.1202],
        [-0.9536,  2.5657, -0.1645],
        [-0.9879,  1.4411,  0.4267],
        [-0.9466,  2.5890, -0.1572],
        [-0.9637,  2.6652, -0.1817],
        [-1.0319,  2.4905, -0.0106],
        [-0.9996,  2.7138, -0.1675],
        [-0.9746,  2.6512, -0.1912],
        [-0.8712,  0.3728,  0.7158],
        [-0.9708,  2.6786, -0.1775],
        [-1.0015,  2.6814, -0.1558],
        [-1.0058,  1.3982,  0.4408],
        [-0.7740,  0.6738,  0.4818],
        [-0.8726,  2.0666,  0.0128],
        [-0.9234,  1.1631,  0.4543],
        [-0.9355,  2.4738, -0.1089],
        [-0.9309,  2.3398, -0.0696],
        [-0.9008,  2.3868, -0.1012],
        [-0.8942,  0.4094,  0.7282],
        [-0.9775,  2.6566, -0.1766],
        [-0.9206,  2.4243, -0.1376],
        [-1.0193,  2.7514, -0.1691],

Epoch number 0
 Current loss 0.6818815469741821

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9519,  2.4681, -0.1103],
        [-0.7853,  1.7926,  0.0657],
        [-0.9189,  0.7235,  0.5753],
        [-0.8431,  0.2587,  0.6625],
        [-0.7581,  0.6566,  0.4494],
        [-0.9795,  2.5813, -0.1439],
        [-0.7939,  0.2306,  0.6643],
        [-0.8929,  2.2995, -0.1007],
        [-1.0226,  1.4648,  0.4209],
        [-0.8734,  0.2113,  0.7304],
        [-0.9589,  2.4760, -0.1193],
        [-0.9423,  2.4753, -0.1418],
        [-0.9302,  0.2207,  0.7929],
        [-0.9525,  2.5683, -0.1556],
        [-0.9574,  2.5247, -0.1598],
        [-0.9016,  2.0371,  0.0889],
        [-1.0013,  2.7249, -0.1979],
        [-0.9677,  0.5411,  0.7088],
        [-0.9381,  2.4810, -0.1217],
        [-0.8797,  2.3539, -0.1141],
        [-0.9638,  2.6204, -0.1710],
        [-0.9718,  2.5462, -0.1355],
        [-0.9249,  0.6159,  0.6777],
        [-0.7420,  0.8284,  0.3992],

Epoch number 1
 Current loss 0.6574572324752808

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8175,  1.1420,  0.3434],
        [-0.8998,  0.2996,  0.7429],
        [-0.9898,  2.5713, -0.1203],
        [-0.9239,  1.0375,  0.4903],
        [-0.9737,  2.5646, -0.1566],
        [-0.9481,  0.5699,  0.7053],
        [-0.8903,  0.4945,  0.6386],
        [-1.0075,  2.6569, -0.1700],
        [-0.8846,  1.6736,  0.1834],
        [-0.9176,  2.1617, -0.0394],
        [-1.0469,  2.7564, -0.1661],
        [-0.9283,  0.4277,  0.7282],
        [-0.9962,  2.5347, -0.0710],
        [-0.9377,  1.9730,  0.0855],
        [-0.9864,  2.4099, -0.0074],
        [-0.9252,  2.1916, -0.0160],
        [-0.9448,  2.4560, -0.1147],
        [-1.0234,  0.1421,  0.8923],
        [-0.9925,  1.5677,  0.3748],
        [-0.9898,  2.5479, -0.1243],
        [-1.0299,  2.6771, -0.1515],
        [-0.7649,  0.2131,  0.6259],
        [-0.8507,  1.3301,  0.3314],
        [-0.9686,  2.4944, -0.1285],

Epoch number 1
 Current loss 0.6741161346435547

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8069,  0.1355,  0.7197],
        [-0.9042,  2.0820,  0.0309],
        [-0.7163,  0.2446,  0.5744],
        [-0.9716,  0.1937,  0.8746],
        [-0.7917,  0.1516,  0.6711],
        [-0.9176,  0.8101,  0.5746],
        [-0.9570,  2.1309,  0.0513],
        [-1.0627,  2.6435, -0.0505],
        [-0.8861,  0.2077,  0.7946],
        [-0.8428,  1.3207,  0.3104],
        [-0.8757,  0.2985,  0.7427],
        [-0.9820,  2.4105, -0.0356],
        [-1.0364,  2.5614, -0.0417],
        [-0.9647,  0.4060,  0.7762],
        [-0.7794,  0.8556,  0.3954],
        [-1.0179,  0.9023,  0.6544],
        [-0.9870,  0.3194,  0.8742],
        [-0.8095,  0.7504,  0.4957],
        [-0.9078,  1.8797,  0.1593],
        [-1.0300,  2.6162, -0.1256],
        [-0.9392,  0.6437,  0.6575],
        [-0.8550,  0.1670,  0.7458],
        [-1.0387,  1.4207,  0.4404],
        [-0.9832,  2.2845,  0.0056],

Epoch number 1
 Current loss 0.599022626876831

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.0260,  0.1750,  0.9669],
        [-1.0385,  2.5958, -0.0900],
        [-0.8949,  1.5644,  0.2629],
        [-1.0176,  2.0330,  0.1700],
        [-1.0324,  2.2640,  0.0961],
        [-0.9468,  0.5520,  0.6854],
        [-0.7844,  0.5604,  0.5664],
        [-0.9449,  1.8471,  0.2154],
        [-0.9089,  0.1499,  0.8195],
        [-0.8747,  0.1067,  0.7901],
        [-0.7639,  0.1679,  0.6731],
        [-0.9743,  2.2753,  0.0097],
        [-0.9650,  2.2697,  0.0075],
        [-1.0060,  2.4305, -0.0345],
        [-0.9528,  0.5723,  0.7092],
        [-1.0346,  2.3077,  0.0686],
        [-0.9500,  0.5932,  0.6845],
        [-0.9048,  0.4843,  0.7204],
        [-0.7325,  0.1652,  0.6281],
        [-0.8660,  0.4220,  0.7092],
        [-1.0042,  1.9513,  0.2168],
        [-1.0445,  0.9511,  0.6394],
        [-0.9863,  2.2936,  0.0315],
        [-0.7985,  0.9757,  0.4127],


Epoch number 1
 Current loss 0.5595455169677734

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8790,  0.1795,  0.7874],
        [-1.0400,  2.3437,  0.0164],
        [-1.0002,  2.0960,  0.1480],
        [-0.8331,  1.4754,  0.2273],
        [-1.0781,  2.6975, -0.0748],
        [-1.1226,  2.7168, -0.0571],
        [-0.7481,  0.2887,  0.6398],
        [-0.9944,  0.5695,  0.7114],
        [-0.8647,  1.5411,  0.2292],
        [-0.9639,  2.0712,  0.0849],
        [-1.0929,  2.6957, -0.0755],
        [-1.0796,  1.5163,  0.4576],
        [-1.0936,  2.7030, -0.1008],
        [-1.0549,  2.5021, -0.0510],
        [-0.9809,  1.8004,  0.2596],
        [-0.8470,  0.9022,  0.5056],
        [-1.0872,  2.6589, -0.0598],
        [-0.8561,  0.4460,  0.6565],
        [-0.9926,  1.9729,  0.1895],
        [-1.0662,  2.0368,  0.2281],
        [-0.9018, -0.0534,  0.8694],
        [-0.9977,  2.4073, -0.0261],
        [-0.9179,  1.3117,  0.3980],
        [-1.0292,  2.2473,  0.0509],

Epoch number 1
 Current loss 0.5244063138961792

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.1715,  2.8965, -0.1109],
        [-0.9157,  1.1190,  0.4330],
        [-1.1182,  1.9016,  0.3343],
        [-1.1144,  2.7639, -0.1239],
        [-0.9858,  0.9135,  0.5966],
        [-0.9339,  0.3772,  0.8044],
        [-1.0987,  2.6047, -0.0635],
        [-0.9299,  2.1626,  0.0105],
        [-1.1058,  2.7593, -0.1426],
        [-0.8024,  0.0091,  0.7465],
        [-0.8722,  0.1295,  0.8225],
        [-0.8853,  0.0091,  0.8217],
        [-1.1324,  2.7940, -0.0953],
        [-0.9306,  0.2091,  0.8268],
        [-0.9765,  1.3752,  0.4069],
        [-1.1026,  2.6719, -0.0750],
        [-0.9676,  0.1703,  0.8813],
        [-0.8946,  0.0086,  0.8442],
        [-1.2106,  1.8044,  0.4484],
        [-0.8841,  0.1243,  0.8207],
        [-1.0330,  0.0293,  0.9912],
        [-1.0638,  2.5698, -0.0774],
        [-1.1112,  2.7699, -0.1474],
        [-0.7280,  0.0298,  0.6427],

Epoch number 1
 Current loss 0.4877483546733856

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.0579,  2.3809, -0.0263],
        [-1.1496,  2.6837, -0.0829],
        [-1.0748,  2.5643, -0.1012],
        [-1.1376,  2.7332, -0.1011],
        [-1.1910,  2.8332, -0.0804],
        [-1.0566,  2.5643, -0.1217],
        [-1.1608,  0.5751,  0.9122],
        [-1.1945,  2.4535,  0.1216],
        [-1.1620,  2.7723, -0.1305],
        [-0.8288, -0.0664,  0.7919],
        [-1.0300,  1.1645,  0.4947],
        [-0.9997,  1.2660,  0.4532],
        [-1.0944,  2.4106, -0.0023],
        [-1.0312,  2.1029,  0.0889],
        [-0.8559,  0.0828,  0.8043],
        [-0.9747,  0.0127,  0.9365],
        [-0.7982, -0.0250,  0.7481],
        [-1.1264,  2.7164, -0.1427],
        [-1.1618,  2.8051, -0.1350],
        [-1.1279,  2.2824,  0.0726],
        [-1.1090,  2.6174, -0.0767],
        [-1.1079,  2.7149, -0.1130],
        [-1.0888,  2.7327, -0.1600],
        [-1.1417,  0.6530,  0.8267],

Epoch number 1
 Current loss 0.5563654899597168

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.1895,  2.0208,  0.3193],
        [-1.1603,  2.5610, -0.0500],
        [-0.9526,  1.5804,  0.2361],
        [-0.9630, -0.1275,  0.9591],
        [-0.9896,  0.1453,  0.9207],
        [-1.1749,  1.2750,  0.6407],
        [-1.1375,  2.3921,  0.0555],
        [-1.1289,  2.6359, -0.1046],
        [-1.1582,  2.7314, -0.1295],
        [-0.9932, -0.0294,  0.9764],
        [-0.7385,  0.2252,  0.6313],
        [-1.1293,  2.6952, -0.1490],
        [-1.1584,  2.4887,  0.0283],
        [-1.0662,  0.2120,  0.9520],
        [-0.9298,  0.6295,  0.6451],
        [-1.1726,  2.6756, -0.0727],
        [-1.1933,  2.6966, -0.0718],
        [-1.0754,  2.1831,  0.0913],
        [-1.0191,  1.1154,  0.5025],
        [-1.1828,  2.7423, -0.1106],
        [-1.1803,  1.9620,  0.3250],
        [-1.1713,  2.4414,  0.0666],
        [-0.8649,  0.6268,  0.5865],
        [-1.0584,  2.1802,  0.1058],

Epoch number 1
 Current loss 0.5221048593521118

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.1392, -0.2893,  1.1657],
        [-1.2247, -0.0877,  1.2559],
        [-1.0533,  0.0289,  1.0702],
        [-1.2929,  2.7576, -0.0471],
        [-1.0097,  0.4801,  0.7406],
        [-1.2066,  2.8130, -0.1478],
        [-0.8332,  0.3059,  0.6989],
        [-1.2014,  2.8181, -0.1714],
        [-1.0230, -0.2239,  1.0544],
        [-1.2519,  2.7999, -0.0662],
        [-0.8951,  0.5826,  0.6514],
        [-1.2811,  2.9589, -0.1290],
        [-1.0865,  0.0934,  1.0431],
        [-0.8440, -0.1227,  0.8301],
        [-1.2618,  2.9503, -0.1493],
        [-1.0316,  0.2716,  0.9299],
        [-1.1987,  2.5815, -0.0413],
        [-1.2312,  2.8888, -0.1663],
        [-1.2852,  3.0152, -0.1485],
        [-1.2566,  2.9781, -0.1812],
        [-1.1519,  2.6125, -0.0983],
        [-1.3376,  3.1562, -0.1619],
        [-1.1724, -0.1726,  1.1977],
        [-1.2365,  2.9083, -0.1480],

Epoch number 1
 Current loss 0.5336052775382996

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.2603,  2.8614, -0.1295],
        [-1.2564,  2.7720, -0.0945],
        [-1.2468,  0.7930,  0.8717],
        [-1.1837, -0.0829,  1.2262],
        [-1.2724,  2.8710, -0.0917],
        [-1.3342,  1.8577,  0.5124],
        [-0.9406, -0.1212,  0.9398],
        [-1.3036,  3.0797, -0.1997],
        [-1.3356,  2.6670,  0.0769],
        [-1.0465,  1.2735,  0.4391],
        [-1.1910,  1.3598,  0.5444],
        [-1.1630,  2.3528,  0.0577],
        [-0.9667,  0.1002,  0.9315],
        [-1.2743,  0.7615,  0.9557],
        [-1.0039, -0.2182,  1.0378],
        [-1.2124,  1.1738,  0.6708],
        [-1.2512,  2.7434, -0.0690],
        [-1.0123, -0.1921,  1.0479],
        [-0.9579,  0.1360,  0.8800],
        [-1.2847,  2.9161, -0.1194],
        [-1.2825,  2.8640, -0.1003],
        [-1.1331,  2.3064,  0.0226],
        [-0.9745, -0.1010,  0.9974],
        [-1.2674,  2.9456, -0.1707],

Epoch number 1
 Current loss 0.5279650688171387

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.2039,  0.7578,  0.8942],
        [-1.0355, -0.2502,  1.0912],
        [-1.3039,  2.7254,  0.0133],
        [-1.0912,  0.8939,  0.7129],
        [-1.3217,  2.8997, -0.0624],
        [-1.3133,  2.9597, -0.1009],
        [-1.0537,  0.9019,  0.6657],
        [-1.2874,  2.8608, -0.1141],
        [-1.0586, -0.2934,  1.1397],
        [-0.9501, -0.1806,  0.9961],
        [-1.3007,  2.9786, -0.1585],
        [-1.1315,  1.1647,  0.5657],
        [-0.9325, -0.2057,  0.9835],
        [-1.0222,  2.0226,  0.0854],
        [-1.0871,  0.9445,  0.6866],
        [-1.2061, -0.3867,  1.3177],
        [-1.1282, -0.0009,  1.1322],
        [-1.0463,  1.2214,  0.5106],
        [-1.0870,  0.7151,  0.7473],
        [-0.9097, -0.2466,  0.9589],
        [-1.3175,  2.9476, -0.0784],
        [-1.2743,  2.1718,  0.3077],
        [-1.0221, -0.0958,  1.0432],
        [-1.1225,  1.3891,  0.4740],

Epoch number 1
 Current loss 0.46771132946014404

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.7665, -0.1804,  0.7904],
        [-1.2016,  1.6949,  0.3799],
        [-1.0529, -0.2695,  1.1305],
        [-1.1066,  1.5041,  0.4312],
        [-1.0527,  0.6010,  0.7832],
        [-1.0101, -0.0195,  1.0183],
        [-1.2200,  2.5697,  0.0059],
        [-1.2722,  2.9082, -0.1549],
        [-1.2010,  2.3188,  0.1299],
        [-0.8579, -0.1246,  0.8690],
        [-1.0673, -0.2994,  1.1592],
        [-1.3012,  2.9180, -0.1567],
        [-1.3608,  2.6809,  0.0879],
        [-1.2483,  2.2202,  0.2236],
        [-1.2059,  2.4010,  0.0513],
        [-1.3342,  1.0203,  0.8592],
        [-1.3126,  2.7109, -0.0130],
        [-1.3041,  2.3565,  0.2336],
        [-0.9672, -0.3433,  1.0533],
        [-1.1235, -0.3693,  1.2464],
        [-1.0101, -0.2753,  1.0776],
        [-1.3123,  2.6851,  0.0464],
        [-1.2836,  2.6556,  0.0023],
        [-1.3638,  3.0261, -0.1060]

Epoch number 1
 Current loss 0.5708951354026794

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.3733,  2.8989, -0.0174],
        [-1.2931,  2.9192, -0.1438],
        [-0.9163, -0.1911,  0.9686],
        [-0.6858, -0.2384,  0.7056],
        [-1.3187,  2.7799, -0.0374],
        [-1.2737,  2.8181, -0.1374],
        [-0.7934, -0.0504,  0.7656],
        [-1.3509,  3.0313, -0.1518],
        [-1.1283, -0.4105,  1.2555],
        [-1.2359,  2.4783,  0.0412],
        [-1.0078, -0.4086,  1.1095],
        [-1.3689,  3.0460, -0.1268],
        [-1.4016,  3.1416, -0.1454],
        [-1.1008, -0.2231,  1.1928],
        [-1.3203,  1.9735,  0.4310],
        [-1.3818,  3.0944, -0.1599],
        [-1.2154,  2.5444,  0.0155],
        [-0.6969, -0.2471,  0.7334],
        [-1.3498,  3.0593, -0.1733],
        [-1.2559,  2.6178, -0.0060],
        [-1.2560,  1.9393,  0.3683],
        [-0.6907, -0.0451,  0.6659],
        [-1.2490,  0.9328,  0.8055],
        [-0.9731, -0.3477,  1.0897],

Epoch number 1
 Current loss 0.6013031601905823

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.4438,  2.3154,  0.3958],
        [-1.1913,  0.5260,  0.9575],
        [-0.8247, -0.2355,  0.8816],
        [-0.9153, -0.3112,  1.0256],
        [-1.3234,  2.7537, -0.0309],
        [-0.7426,  0.1298,  0.6450],
        [-0.8284, -0.0734,  0.8908],
        [-1.4140,  3.1074, -0.1304],
        [-0.8188,  0.5247,  0.5748],
        [-1.1966,  1.5501,  0.4274],
        [-1.4105,  3.1165, -0.1660],
        [-1.0349,  1.2802,  0.3947],
        [-1.2333,  1.4252,  0.5580],
        [-1.3503,  3.0497, -0.1707],
        [-0.9089, -0.2524,  1.0096],
        [-1.1250,  2.1133,  0.0808],
        [-1.0694,  0.4797,  0.8089],
        [-1.4000,  3.0053, -0.0978],
        [-0.5861, -0.1954,  0.5790],
        [-1.0687,  1.1716,  0.4634],
        [-1.4354,  3.2176, -0.1759],
        [-1.3793,  3.0825, -0.1670],
        [-1.1098,  2.0961,  0.1209],
        [-1.3380,  2.7617,  0.0081],

Epoch number 1
 Current loss 0.5443297624588013

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.6412,  0.5013,  0.3927],
        [-0.8833,  0.2960,  0.7445],
        [-1.3208,  2.4791,  0.1740],
        [-1.2906,  2.9694, -0.2144],
        [-0.8269, -0.3785,  0.9446],
        [-1.4044,  2.7293,  0.0684],
        [-1.2415,  2.0147,  0.3510],
        [-0.8452,  0.1544,  0.7857],
        [-0.7428,  0.4045,  0.5209],
        [-0.8108,  0.9861,  0.3129],
        [-1.3262,  2.9216, -0.1422],
        [-1.2966,  2.0419,  0.3178],
        [-1.5389,  1.3139,  0.8560],
        [-0.6713, -0.1943,  0.6836],
        [-0.7673, -0.1426,  0.7496],
        [-1.0232,  1.1483,  0.4437],
        [-1.2556,  1.9356,  0.3441],
        [-1.3579,  2.6577,  0.0436],
        [-1.3550,  1.9296,  0.4277],
        [-1.3157,  3.0233, -0.2350],
        [-1.2803,  1.4538,  0.5537],
        [-1.3786,  2.9907, -0.1160],
        [-0.8088,  0.6957,  0.4244],
        [-1.1222,  0.9709,  0.6357],

Epoch number 1
 Current loss 0.44933316111564636

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.7463, -0.5033,  0.8571],
        [-0.8870, -0.4651,  1.0460],
        [-0.6986,  0.7294,  0.2487],
        [-0.6827, -0.1918,  0.7261],
        [-1.3146,  2.7010, -0.0057],
        [-1.4668,  3.0619, -0.0383],
        [-1.3477,  2.8048, -0.0714],
        [-0.7203,  0.6599,  0.3742],
        [-1.3728,  2.6584,  0.0941],
        [-1.3590,  2.9537, -0.1246],
        [-1.2990,  2.6829, -0.0597],
        [-1.0323, -0.5368,  1.2362],
        [-1.0605,  2.1292,  0.0011],
        [-1.3234,  2.8167, -0.0589],
        [-1.0815,  0.3002,  0.9507],
        [-1.3836,  3.0526, -0.1740],
        [-0.9420,  0.0565,  0.8974],
        [-1.2816,  2.4116,  0.1121],
        [-1.2516,  1.2889,  0.6502],
        [-1.2257,  2.5909, -0.0737],
        [-1.1068,  2.0878,  0.0872],
        [-1.1613,  0.4224,  0.9638],
        [-1.1950,  2.3211,  0.0836],
        [-0.6781,  0.0920,  0.5255]

Epoch number 1
 Current loss 0.6135730147361755

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.1875,  1.3142,  0.5427],
        [-1.3041,  2.7417, -0.0721],
        [-1.3386,  2.8461, -0.1329],
        [-0.9327,  0.5372,  0.6402],
        [-1.4266,  3.1939, -0.2269],
        [-1.0644,  0.8958,  0.6003],
        [-1.3440,  3.1213, -0.2681],
        [-1.2054,  2.1981,  0.1253],
        [-1.3095,  2.8409, -0.1605],
        [-1.5020,  1.9841,  0.5385],
        [-1.2548,  2.0203,  0.2405],
        [-1.3459,  3.0833, -0.2332],
        [-0.7338, -0.4500,  0.8672],
        [-0.9832,  0.1253,  0.8923],
        [-1.3619,  2.9431, -0.1510],
        [-0.7635, -0.5119,  0.8875],
        [-1.3965,  3.1077, -0.2357],
        [-1.3476,  3.0854, -0.2712],
        [-1.1566,  1.7490,  0.2857],
        [-1.2377,  2.5135, -0.0254],
        [-0.5757, -0.1936,  0.5734],
        [-1.4261,  3.0422, -0.0984],
        [-1.1388,  0.3699,  0.9472],
        [-1.0553,  2.0481, -0.0132],

Epoch number 1
 Current loss 0.5041872262954712

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9824,  0.3621,  0.7650],
        [-1.3567,  3.0288, -0.2417],
        [-1.3662,  1.6930,  0.4850],
        [-0.8758,  0.3533,  0.6119],
        [-1.3297,  2.9636, -0.2524],
        [-1.0552,  1.2019,  0.3689],
        [-0.6131, -0.3037,  0.6135],
        [-1.4648,  3.3092, -0.3146],
        [-0.6322, -0.3779,  0.6999],
        [-0.7749, -0.0710,  0.7660],
        [-0.6305, -0.1653,  0.6002],
        [-0.8356, -0.1583,  0.8811],
        [-0.8642,  0.2599,  0.7154],
        [-1.3854,  2.9046, -0.1284],
        [-1.4662,  3.2632, -0.2429],
        [-1.0689,  2.2277, -0.0982],
        [-0.8336,  1.4984, -0.0037],
        [-1.2769,  2.8956, -0.3102],
        [-1.4234,  2.7856,  0.0143],
        [-0.5645,  0.3014,  0.2806],
        [-1.3077,  2.7636, -0.1539],
        [-0.6418,  0.7491,  0.1897],
        [-0.7004,  0.0156,  0.5646],
        [-1.4277,  3.1824, -0.2465],

Epoch number 1
 Current loss 0.4343120753765106

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.5694,  0.0017,  0.4470],
        [-1.1412,  2.0815,  0.0393],
        [-1.0974,  2.0299,  0.0740],
        [-1.4143,  2.8998, -0.1119],
        [-1.3571,  2.9110, -0.1804],
        [-1.2959,  2.0419,  0.2685],
        [-0.8514,  0.4095,  0.6412],
        [-1.4585,  3.2211, -0.3060],
        [-1.3242,  1.9410,  0.3173],
        [-0.8759,  0.7401,  0.4208],
        [-0.9772,  1.0737,  0.3800],
        [-1.3590,  2.9622, -0.2262],
        [-1.0971, -0.1180,  1.0958],
        [-1.1421,  1.8481,  0.1203],
        [-0.4558, -0.4109,  0.5112],
        [-0.9105,  0.4216,  0.6343],
        [-1.4752,  3.2553, -0.2759],
        [-0.6468, -0.1653,  0.6168],
        [-1.1666,  1.8880,  0.1595],
        [-1.4065,  2.9756, -0.1680],
        [-1.4504,  3.1750, -0.2683],
        [-0.6088, -0.4000,  0.6978],
        [-1.5353,  3.2905, -0.1989],
        [-1.1592,  2.3490, -0.0244],

Epoch number 1
 Current loss 0.5649864077568054

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.4972,  3.3002, -0.3037],
        [-1.5019,  3.3688, -0.3296],
        [-1.4478,  3.2428, -0.2967],
        [-1.0674, -0.3392,  1.2579],
        [-1.4116,  3.0403, -0.2016],
        [-0.9056,  0.7775,  0.4028],
        [-1.3651,  2.2332,  0.2626],
        [-1.4670,  3.1094, -0.1545],
        [-0.7116,  0.4137,  0.3316],
        [-1.3977,  2.9356, -0.1340],
        [-1.4735,  3.2313, -0.2417],
        [-0.8284,  1.2906,  0.0661],
        [-1.2062,  0.8085,  0.7379],
        [-0.6575, -0.4173,  0.7250],
        [-0.9627,  1.6789,  0.0511],
        [-1.5389,  1.6183,  0.6327],
        [-0.7693,  0.4803,  0.4781],
        [-0.5638, -0.2147,  0.5470],
        [-1.1282,  0.8277,  0.7136],
        [-0.5314, -0.0671,  0.4284],
        [-1.3856,  2.6759,  0.0124],
        [-0.4613, -0.2969,  0.4349],
        [-1.5677,  3.2163, -0.1503],
        [-1.4564,  2.8950, -0.0042],

Epoch number 1
 Current loss 0.4926966726779938

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.5176,  3.1286, -0.1229],
        [-0.6859,  0.2923,  0.3560],
        [-1.3941,  3.2006, -0.3304],
        [-1.4128,  3.0050, -0.1690],
        [-1.1463,  2.0385,  0.0350],
        [-0.7816, -0.0067,  0.6857],
        [-0.8777,  0.0712,  0.7556],
        [-0.8732, -0.6857,  1.1120],
        [-1.5197,  3.2102, -0.2127],
        [-1.5470,  3.0429, -0.0345],
        [-0.5352, -0.5675,  0.6620],
        [-1.2932,  2.2299,  0.0999],
        [-0.4480, -0.3244,  0.3831],
        [-1.3090,  2.8149, -0.2802],
        [-1.2516,  1.7391,  0.3143],
        [-1.2465,  0.4467,  0.9097],
        [-1.4440,  2.1436,  0.3270],
        [-1.1521,  1.8840,  0.1830],
        [-0.8391, -0.7569,  1.1177],
        [-0.7976,  0.5222,  0.4427],
        [-1.2986,  2.2240,  0.1523],
        [-1.2769,  2.1774,  0.1426],
        [-0.6875,  0.5865,  0.2736],
        [-0.8348, -0.2914,  0.9519],

Epoch number 1
 Current loss 0.47729921340942383

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.5803,  3.3095, -0.2086],
        [-1.2380,  2.1056,  0.1355],
        [-0.8295,  0.0304,  0.6953],
        [-1.4199,  3.1617, -0.3346],
        [-0.4681, -0.3065,  0.4235],
        [-0.9866,  1.5058,  0.1586],
        [-1.0585, -0.7180,  1.3421],
        [-1.4472,  2.8201, -0.0621],
        [-1.4753,  3.1208, -0.2630],
        [-1.4045,  2.5686,  0.0653],
        [-1.4736,  2.9739, -0.0914],
        [-1.4112,  3.0584, -0.3035],
        [-0.6545, -0.5358,  0.8277],
        [-1.1140,  0.4646,  0.7883],
        [-0.7468,  0.9019,  0.1618],
        [-1.4501,  3.1135, -0.1952],
        [-1.1564,  1.2788,  0.4639],
        [-1.3946,  1.8709,  0.3783],
        [-0.9916, -0.7113,  1.3066],
        [-0.7630, -0.7371,  1.0302],
        [-0.6098, -0.4977,  0.7125],
        [-1.3401,  2.1259,  0.1947],
        [-0.7802,  0.5781,  0.4174],
        [-1.0085, -0.1101,  1.0401]

Epoch number 1
 Current loss 0.4710540771484375

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.4031,  3.1846, -0.3531],
        [-1.3337,  2.7529, -0.2094],
        [-0.5025, -0.6420,  0.6221],
        [-0.9479,  1.4983, -0.0170],
        [-1.0122,  0.0593,  0.9269],
        [-1.5702,  3.0356, -0.0938],
        [-1.4810,  2.9143, -0.0952],
        [-1.5617,  1.7136,  0.5906],
        [-1.1966,  1.7157,  0.1749],
        [-0.5621, -0.6860,  0.7133],
        [-1.4182,  2.9384, -0.2183],
        [-1.3302,  2.3136,  0.0898],
        [-1.4587,  2.5306,  0.0353],
        [-1.4602,  2.0196,  0.3566],
        [-1.4744,  3.3206, -0.4105],
        [-1.5916,  1.8208,  0.6111],
        [-0.6786,  0.5538,  0.2528],
        [-1.5122,  3.1182, -0.2247],
        [-0.6636, -0.6870,  0.9017],
        [-1.1000,  1.5746,  0.1160],
        [-1.5802,  3.2808, -0.2235],
        [-1.0105,  2.0887, -0.1799],
        [-1.4228,  3.2466, -0.4307],
        [-0.5664,  0.0729,  0.3593],

Epoch number 1
 Current loss 0.4580323398113251

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.5440,  3.5452, -0.5146],
        [-1.0462,  0.7584,  0.4862],
        [-1.4892,  3.1754, -0.2614],
        [-1.2325,  2.4116, -0.1425],
        [-0.7299, -0.7373,  0.9663],
        [-1.5793,  3.0910, -0.0968],
        [-1.5004,  3.4973, -0.5362],
        [-1.5241,  3.1769, -0.2619],
        [-0.7793, -0.4117,  0.8810],
        [-1.3407,  2.0450,  0.1737],
        [-0.5165, -0.3625,  0.4343],
        [-1.3650,  2.0222,  0.2177],
        [-1.4651,  1.9945,  0.3779],
        [-0.9840, -0.4683,  1.2239],
        [-1.1708,  0.8410,  0.6539],
        [-1.4818,  0.5649,  1.0526],
        [-1.3758,  1.1641,  0.6801],
        [-1.5043,  3.2405, -0.3176],
        [-1.4868,  3.3896, -0.4542],
        [-1.0826,  1.5611,  0.1068],
        [-1.4514,  2.5060,  0.0861],
        [-1.1019,  0.2284,  0.8708],
        [-1.4846,  2.9894, -0.1627],
        [-0.5405, -0.5336,  0.6239],

Epoch number 1
 Current loss 0.4761046767234802

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.4604,  3.3989, -0.5346],
        [-0.6342, -0.3477,  0.6154],
        [-0.8033,  1.4044, -0.0603],
        [-0.5754, -0.5816,  0.6455],
        [-0.8323, -0.0460,  0.6438],
        [-1.2234,  0.1979,  1.0126],
        [-1.5556,  3.3624, -0.2736],
        [-1.3297,  2.3136,  0.0224],
        [-0.6394, -0.5000,  0.7515],
        [-0.5590, -0.7598,  0.7556],
        [-0.8835,  0.3313,  0.6231],
        [-1.5060,  2.7249, -0.0241],
        [-0.9323, -0.6094,  1.1936],
        [-1.3669,  2.3305,  0.0479],
        [-0.6040,  0.1403,  0.2701],
        [-0.8970,  1.6830, -0.1080],
        [-1.5057,  3.5484, -0.5484],
        [-1.1476,  2.2940, -0.1884],
        [-0.4832, -0.8018,  0.5757],
        [-1.5854,  3.1330, -0.1200],
        [-1.3964,  2.9924, -0.2974],
        [-0.9190, -0.7948,  1.2782],
        [-1.4537,  3.1086, -0.2631],
        [-0.7398, -0.5746,  0.8533],

Epoch number 1
 Current loss 0.5131300091743469

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8036, -0.8687,  1.1952],
        [-1.3069,  1.7678,  0.2167],
        [-1.4278,  3.2405, -0.4484],
        [-1.5763,  2.8760, -0.0137],
        [-1.2792,  1.5335,  0.2739],
        [-1.6239,  3.4017, -0.3124],
        [-0.6405, -0.8850,  0.8912],
        [-1.0772, -0.2559,  1.1494],
        [-0.9075, -0.6878,  1.1860],
        [-1.5855,  3.3166, -0.3027],
        [-0.8964,  0.4802,  0.5232],
        [-1.5907,  3.4408, -0.3841],
        [-0.9335, -0.0523,  0.8271],
        [-1.5434,  3.0121, -0.1889],
        [-0.4233, -0.4865,  0.4186],
        [-1.5335,  3.4156, -0.4362],
        [-1.5131,  3.2740, -0.3776],
        [-1.0237,  1.3434,  0.1430],
        [-1.0980,  0.2181,  0.8635],
        [-1.4664,  3.3929, -0.4981],
        [-0.8013, -0.7943,  1.0573],
        [-1.4525,  3.2989, -0.4852],
        [-0.4497, -0.5069,  0.4612],
        [-1.6142,  3.4334, -0.3162],

Epoch number 1
 Current loss 0.42648616433143616

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.4994,  3.4263, -0.4506],
        [-1.3817,  2.9004, -0.3554],
        [-1.3922,  2.7962, -0.2504],
        [-1.3442,  2.3508, -0.0890],
        [-0.4624, -0.2688,  0.3326],
        [-0.4632, -0.6480,  0.4414],
        [-1.2210,  2.5651, -0.3311],
        [-1.5734,  2.9865, -0.0975],
        [-1.4610,  2.0266,  0.2733],
        [-1.3474,  1.4298,  0.4518],
        [-1.1410,  0.6256,  0.6579],
        [-1.4180,  2.5159, -0.0176],
        [-1.5479,  3.5999, -0.6086],
        [-1.4502,  2.3877,  0.0805],
        [-1.0283,  0.0488,  0.8793],
        [-1.2594,  1.2990,  0.4276],
        [-0.6139,  0.3017,  0.1744],
        [-1.3988,  1.5941,  0.4221],
        [-1.0235, -0.0606,  0.9026],
        [-1.3845,  0.4100,  1.0566],
        [-1.3827,  2.5302, -0.0840],
        [-1.3766,  2.2787, -0.0361],
        [-1.4564,  1.9284,  0.2377],
        [-1.3478,  2.0114,  0.1499]

Epoch number 1
 Current loss 0.4696670472621918

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.3790,  3.1436, -0.4711],
        [-1.2456,  0.2591,  0.9801],
        [-1.5639,  3.4349, -0.5416],
        [-1.5638,  3.4728, -0.5963],
        [-1.5848,  3.5700, -0.5865],
        [-1.0153,  1.6668, -0.1111],
        [-1.6195,  3.5116, -0.4699],
        [-1.4462,  2.8863, -0.2556],
        [-1.4385,  3.2698, -0.4845],
        [-0.9890, -0.3852,  1.0454],
        [-1.5454,  3.4308, -0.5241],
        [-0.5971, -0.6796,  0.5477],
        [-1.2729,  2.8665, -0.5826],
        [-1.0619, -0.8039,  1.3486],
        [-0.8578,  0.4018,  0.3253],
        [-1.4498,  2.5702, -0.1207],
        [-1.5983,  3.5741, -0.5994],
        [-1.5935,  3.5462, -0.5626],
        [-1.5751,  3.5949, -0.6020],
        [-1.2047, -0.7376,  1.5574],
        [-1.0804,  0.3011,  0.8439],
        [-1.6395,  3.7263, -0.6691],
        [-1.5882,  3.4302, -0.3970],
        [-1.6877,  2.4405,  0.1601],

Epoch number 1
 Current loss 0.5403326153755188

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.3560,  2.8118, -0.4611],
        [-0.7218, -0.9274,  0.9486],
        [-1.4661,  2.6609, -0.0834],
        [-1.4640,  3.0620, -0.3773],
        [-0.6467, -0.8486,  0.7988],
        [-1.4491,  2.6753, -0.1984],
        [-1.0383, -1.0082,  1.3877],
        [-1.5314,  3.5722, -0.5733],
        [-0.5071, -0.9707,  0.6281],
        [-1.5265,  2.9793, -0.2066],
        [-1.0568, -0.4859,  1.2318],
        [-1.4865,  2.2380,  0.1291],
        [-1.4554,  2.8774, -0.3238],
        [-0.8925, -0.8171,  1.1213],
        [-0.5474, -0.9589,  0.6411],
        [-1.6698,  3.1046, -0.1703],
        [-1.4612,  3.2144, -0.4209],
        [-1.3699,  2.4593, -0.1411],
        [-0.9061, -1.0442,  1.3511],
        [-0.4089, -0.8324,  0.4546],
        [-1.0363,  1.2451,  0.0055],
        [-0.7081, -0.8604,  0.9022],
        [-1.3170,  2.6077, -0.3488],
        [-1.6583,  3.2832, -0.3630],

Epoch number 1
 Current loss 0.40633219480514526

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.5719, -0.9836,  0.7680],
        [-1.5542,  2.0724,  0.1632],
        [-1.3816,  2.5891, -0.3314],
        [-1.4815,  3.1705, -0.5453],
        [-1.0380, -1.1518,  1.5942],
        [-0.6310,  0.1242,  0.1017],
        [-1.0371,  0.8044,  0.2857],
        [-1.1197, -1.2684,  1.7286],
        [-1.2593, -0.3443,  1.3004],
        [-1.5243,  3.2073, -0.4620],
        [-1.4211,  3.1995, -0.5106],
        [-1.7656,  3.8922, -0.6327],
        [-0.9563, -0.0022,  0.7734],
        [-1.6488,  3.5032, -0.5070],
        [-1.4237,  2.0593,  0.1650],
        [-0.5290, -0.1620,  0.0348],
        [-1.7447,  3.3778, -0.3052],
        [-0.7265,  0.6522, -0.0587],
        [-1.6066,  2.8466, -0.1487],
        [-1.2215,  1.1231,  0.2985],
        [-1.6019,  1.9555,  0.2992],
        [-1.5197,  2.4336, -0.0590],
        [-0.6618, -1.0815,  0.9257],
        [-1.1925,  1.7740,  0.0462]

Epoch number 1
 Current loss 0.4493340849876404

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.4800,  0.8207,  0.9435],
        [-1.5430,  3.2632, -0.4224],
        [-1.7024,  2.8877, -0.0449],
        [-1.5637,  3.0234, -0.2643],
        [-1.5084,  3.2875, -0.4841],
        [-1.2529,  2.0743, -0.0981],
        [-1.2908,  2.8270, -0.4683],
        [-0.6790, -0.2641,  0.5242],
        [-1.3763,  2.9905, -0.6107],
        [-1.0921,  1.1358, -0.0195],
        [-0.3044, -1.1295,  0.3148],
        [-1.6529,  3.4648, -0.4792],
        [-0.9813,  0.2854,  0.5377],
        [-1.1143, -0.0170,  0.9548],
        [-0.5854, -0.6849,  0.4106],
        [-1.5897,  3.5647, -0.5917],
        [-1.4162,  1.0375,  0.7131],
        [-1.5986,  2.4762,  0.0774],
        [-0.3285, -0.9636,  0.3249],
        [-1.4813,  3.2768, -0.5701],
        [-1.3229,  1.6029,  0.1214],
        [-1.6659,  1.7588,  0.4438],
        [-0.7311, -1.0535,  0.9347],
        [-0.6080, -1.1279,  0.8700],

Epoch number 1
 Current loss 0.4923170208930969

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.5654e+00,  3.3246e+00, -6.2573e-01],
        [-1.8069e+00,  4.0782e+00, -8.3933e-01],
        [-1.3551e+00,  1.5305e+00,  2.4533e-01],
        [-1.7213e+00,  3.8076e+00, -7.4869e-01],
        [-1.3324e+00,  2.2254e+00, -2.3921e-01],
        [-1.8296e+00,  3.9151e+00, -6.9975e-01],
        [-1.5974e+00,  1.5339e+00,  5.2919e-01],
        [-1.2603e+00,  2.5988e+00, -4.4791e-01],
        [-1.2184e+00,  5.7518e-01,  6.2332e-01],
        [-1.5093e+00,  7.0639e-01,  9.7654e-01],
        [-3.9997e-01, -1.2383e+00,  5.1430e-01],
        [-1.4918e+00,  3.1612e+00, -4.5659e-01],
        [-1.2450e+00,  6.5508e-01,  5.6326e-01],
        [-3.8001e-01, -7.0170e-01,  2.3129e-01],
        [-1.5535e+00,  2.5948e+00, -1.0269e-01],
        [-1.6157e+00,  2.1612e+00,  1.6366e-01],
        [-1.8281e+00,  3.4280e+00, -2.9220e-01],
        [-1.4132e+00,  2.6037e+00, -2.5590e-01],
     

Epoch number 1
 Current loss 0.3522850573062897

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.2182,  0.8135,  0.5003],
        [-1.9111,  4.0627, -0.7144],
        [-0.4774,  0.0024, -0.3467],
        [-1.8182,  2.6446,  0.0981],
        [-1.9192,  4.0655, -0.7344],
        [-1.6121,  2.6360, -0.1420],
        [-0.9291, -0.3021,  0.7362],
        [-1.3947,  2.6256, -0.4728],
        [-0.2421, -1.2441,  0.2810],
        [-1.6068,  3.1156, -0.4773],
        [-1.8138,  3.8099, -0.6720],
        [-1.8594,  4.0068, -0.7468],
        [-0.3816, -0.8014,  0.2054],
        [-1.1393, -0.8279,  1.5157],
        [-1.1191,  1.4197, -0.2383],
        [-1.6919,  2.0999,  0.3074],
        [-1.4492,  0.4787,  1.0522],
        [-1.8943,  4.0071, -0.7237],
        [-0.6527, -1.0628,  0.6929],
        [-2.0174,  4.0669, -0.5124],
        [-1.7596,  3.7841, -0.6017],
        [-1.9072,  4.1133, -0.7251],
        [-2.0190,  3.7777, -0.3371],
        [-1.6360,  3.4880, -0.5602],

Epoch number 1
 Current loss 0.3686477243900299

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.8123,  3.7138, -0.6055],
        [-1.8526,  3.9707, -0.8028],
        [-0.4882, -0.1302, -0.2096],
        [-1.8112,  3.5511, -0.4827],
        [-0.2301, -1.1880,  0.1933],
        [-1.9077,  3.8047, -0.5333],
        [-1.7098,  3.2075, -0.4187],
        [-1.1501, -0.2162,  1.0479],
        [-1.7002,  2.8643, -0.0573],
        [-1.1168,  1.1830,  0.1093],
        [-1.8960,  3.9101, -0.6417],
        [-1.8536,  4.0438, -0.8621],
        [-0.9329, -0.0477,  0.5401],
        [-0.6488, -1.3780,  0.9629],
        [-1.0068,  0.2339,  0.4274],
        [-1.8927,  3.5770, -0.3928],
        [-1.6040,  2.0381,  0.1301],
        [-1.7541,  2.0939,  0.2799],
        [-1.7063,  1.6664,  0.5270],
        [-1.2742,  0.3696,  0.8852],
        [-1.8724,  3.9960, -0.7151],
        [-1.9340,  4.1991, -0.8446],
        [-2.0532,  3.8381, -0.4007],
        [-0.4849, -1.1457,  0.6515],

Epoch number 1
 Current loss 0.37906405329704285

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.3304,  1.9419, -0.1404],
        [-1.9998,  3.9977, -0.5246],
        [-0.9201, -1.1740,  1.3802],
        [-1.9138,  4.1451, -0.8454],
        [-1.8642,  3.4898, -0.3621],
        [-0.8732, -1.1040,  1.2081],
        [-0.8538, -0.9105,  0.8886],
        [-0.9216, -0.8659,  1.1355],
        [-1.7761,  3.7018, -0.6609],
        [-0.7991, -0.0002,  0.3529],
        [-1.8337,  3.3441, -0.3698],
        [-0.7866, -1.2551,  1.3182],
        [-1.5160,  2.1043,  0.0741],
        [-1.6616,  3.4166, -0.5124],
        [-2.1369,  3.1994,  0.1004],
        [-2.1033,  3.7306, -0.2010],
        [-1.9938,  4.2104, -0.8438],
        [-1.4500,  2.2599, -0.1665],
        [-1.7710,  2.5125,  0.1189],
        [-1.8663,  4.0001, -0.7768],
        [-0.6438, -1.5142,  1.0303],
        [-1.3166, -0.0906,  1.2691],
        [-1.8219,  3.3340, -0.3568],
        [-1.7860,  3.1070, -0.2050]

Epoch number 1
 Current loss 0.3869125545024872

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.8821,  3.8450, -0.6339],
        [-1.8271,  3.7287, -0.5543],
        [-1.0821, -1.3524,  1.6873],
        [-1.9231,  2.7611,  0.2045],
        [-1.9236,  4.1763, -0.8617],
        [-2.0076,  2.9285,  0.1298],
        [-2.0187,  3.6222, -0.2499],
        [-1.9485,  3.6008, -0.3421],
        [-1.9508,  3.2248, -0.1493],
        [-2.0384,  3.6436, -0.2719],
        [-1.8430,  3.3419, -0.3712],
        [-2.1164,  4.3795, -0.8092],
        [-1.1258,  0.8210,  0.3008],
        [-0.4145, -1.4947,  0.6899],
        [-0.9866, -1.5129,  1.6901],
        [-1.7631,  3.5069, -0.5960],
        [-1.0167, -0.4016,  1.0197],
        [-1.8757,  1.7733,  0.6509],
        [-2.0172,  3.9354, -0.4909],
        [-0.6092, -1.2051,  0.8645],
        [-1.3636,  1.6129,  0.1749],
        [-1.8065,  2.0262,  0.4784],
        [-1.9386,  2.7573,  0.1574],
        [-1.9618,  3.9959, -0.6326],

Epoch number 1
 Current loss 0.35952842235565186

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8882,  0.4270,  0.0796],
        [-1.9727,  3.5394, -0.3099],
        [-1.8375,  2.9528, -0.0992],
        [-0.7051, -1.1721,  0.9216],
        [-1.6601,  2.6257, -0.1001],
        [-2.0054,  4.1060, -0.7773],
        [-1.8734,  3.4437, -0.3744],
        [-1.9227,  3.8567, -0.5395],
        [-1.6739,  2.1911,  0.1748],
        [-1.5386,  2.5863, -0.2751],
        [-1.7883,  2.9562, -0.1890],
        [-0.5404, -1.2111,  0.7619],
        [-1.8646,  3.2341, -0.2453],
        [-1.8859,  3.2611, -0.2252],
        [-1.8706,  3.9716, -0.7665],
        [-1.4785,  0.1812,  1.2191],
        [-1.5182,  2.6852, -0.2726],
        [-1.2981, -0.8505,  1.6670],
        [-0.6152, -1.6543,  1.2523],
        [-2.0358,  3.8661, -0.4843],
        [-0.7111, -0.9010,  0.8194],
        [-1.7991,  3.6710, -0.5842],
        [-1.7413,  3.7508, -0.8121],
        [-0.7826, -0.8053,  0.8362]

Epoch number 1
 Current loss 0.3456800580024719

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.3023,  4.3760, -0.5607],
        [-2.0640,  4.1918, -0.7307],
        [-2.1011,  4.3605, -0.8948],
        [-0.7322, -1.1835,  1.0187],
        [-1.9660,  4.0012, -0.7290],
        [-1.9392,  3.9786, -0.6587],
        [-2.1622,  3.9902, -0.4588],
        [-1.8828,  0.0178,  1.7468],
        [-2.1136,  4.1537, -0.6170],
        [-2.1440,  4.1482, -0.6294],
        [-1.1479, -1.1150,  1.6909],
        [-1.8199,  3.5633, -0.5239],
        [-1.4268,  0.8680,  0.7186],
        [-2.1227,  4.1770, -0.6388],
        [-1.4002,  1.4447,  0.2767],
        [-1.8443,  3.7538, -0.7080],
        [-1.1901,  0.0943,  1.0014],
        [-2.0511,  3.7243, -0.3704],
        [-2.2136,  4.4162, -0.7620],
        [-1.5796,  0.3965,  1.1812],
        [-1.4477,  0.4110,  1.0046],
        [-2.0527,  3.6444, -0.3791],
        [-2.0825,  4.2395, -0.7647],
        [-1.2233,  0.8717,  0.3897],

Epoch number 1
 Current loss 0.3644694685935974

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.1284,  4.1844, -0.6890],
        [-2.1255,  3.5303, -0.1572],
        [-1.8975,  1.6500,  0.7513],
        [-1.6155,  0.1710,  1.4566],
        [-2.1940,  4.2004, -0.6038],
        [-2.0112,  1.1943,  1.1284],
        [-0.9715, -1.5280,  1.6882],
        [-1.7430,  0.6869,  1.2179],
        [-0.2497, -1.5089,  0.3601],
        [-2.1323,  3.9034, -0.4427],
        [-2.0450,  3.4841, -0.2157],
        [-2.0192,  3.7327, -0.3999],
        [-1.0570, -1.1472,  1.4964],
        [-2.1297,  3.5787, -0.1948],
        [-1.5372, -0.0075,  1.4447],
        [-2.0532,  0.0839,  1.8799],
        [-2.2364,  1.3397,  1.2718],
        [-1.8861,  3.0589, -0.1479],
        [-1.8909,  0.7168,  1.3203],
        [-1.7669,  0.4852,  1.3633],
        [-0.1654, -1.4954,  0.2706],
        [-1.7341,  2.1243,  0.2355],
        [-2.2955,  3.5579, -0.0157],
        [-2.1955,  3.6554, -0.1885],

Epoch number 1
 Current loss 0.3899758756160736

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.1982,  0.9782,  1.4875],
        [-2.2312,  3.3854,  0.0547],
        [-1.1650, -0.2408,  1.1054],
        [-2.2604,  4.3467, -0.5852],
        [-1.9917,  3.3081, -0.2107],
        [-1.3098, -1.2630,  1.9616],
        [-0.7477, -1.0307,  0.9039],
        [-1.3660, -0.7556,  1.7343],
        [-2.1458,  3.4089, -0.0958],
        [-1.6331, -0.8699,  2.1297],
        [-2.2352,  4.4122, -0.7357],
        [-1.9517,  2.6459,  0.1969],
        [-1.8761,  3.2984, -0.3300],
        [-2.2444,  3.0599,  0.2735],
        [-1.1715,  0.0988,  0.9290],
        [-2.3399,  2.1503,  0.9085],
        [-1.4123,  0.2144,  1.1622],
        [-2.0756,  1.4284,  1.0666],
        [-2.2427,  4.4139, -0.7358],
        [-2.4642,  2.6798,  0.7036],
        [-2.2562,  4.0550, -0.4139],
        [-2.0562,  1.1236,  1.2409],
        [-1.8518,  0.8627,  1.1831],
        [-2.0076,  4.0538, -0.6555],

Epoch number 1
 Current loss 0.4369398057460785

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.3127, -1.6183,  0.4606],
        [-2.0369,  3.5802, -0.3088],
        [-2.1654,  3.7092, -0.2602],
        [-1.8980,  3.1526, -0.2572],
        [-2.3589,  4.4632, -0.6367],
        [-2.1501,  3.7101, -0.2855],
        [-2.2822,  4.0759, -0.4272],
        [-1.9006,  3.7024, -0.6400],
        [-1.9112,  0.0254,  1.7733],
        [-2.4187,  3.6328,  0.0903],
        [-2.0375,  3.9422, -0.6819],
        [-1.1517,  1.5663, -0.3207],
        [-2.0427,  3.4002, -0.2388],
        [-0.9887,  0.4789,  0.2493],
        [-0.4482, -1.3118,  0.5621],
        [-1.9631,  3.4955, -0.3315],
        [-2.1538,  3.6163, -0.2256],
        [-2.4269,  4.1598, -0.2966],
        [-2.0821,  3.9294, -0.6307],
        [-1.4054,  0.1858,  1.1774],
        [-2.1973,  4.1715, -0.6079],
        [-1.9239,  0.7240,  1.3576],
        [-2.0208,  3.8261, -0.5873],
        [-2.3414,  3.4066,  0.1652],

Epoch number 2
 Current loss 0.24042215943336487

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.0815,  3.8035, -0.5263],
        [-0.5966, -1.5495,  1.0127],
        [-1.8470,  3.5128, -0.6453],
        [-2.3608,  4.4184, -0.6439],
        [-1.4548,  1.4280,  0.3689],
        [-0.7003, -0.9178,  0.7653],
        [-2.3065,  4.4810, -0.7626],
        [-1.9629,  1.3618,  0.9712],
        [-2.5557,  4.2103, -0.1986],
        [-1.0938, -0.1453,  0.8395],
        [-2.3298,  4.2532, -0.5514],
        [-2.2499,  2.9963,  0.2645],
        [-1.0146, -1.0579,  1.4247],
        [-2.3254,  4.3387, -0.5951],
        [-2.3369,  4.4543, -0.6849],
        [-2.1652,  1.7960,  0.8886],
        [-2.2764,  4.2421, -0.6295],
        [-1.9051,  3.8932, -0.7566],
        [-2.2336,  4.4657, -0.8857],
        [-2.1764,  4.4248, -0.9330],
        [-2.1704,  3.9463, -0.5915],
        [-1.5348,  0.9646,  0.7912],
        [-2.2847,  4.6456, -0.9926],
        [-0.6238, -0.0875, -0.0411]

Epoch number 2
 Current loss 0.34087762236595154

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.7715, -0.7456,  0.7869],
        [-2.1726,  3.9342, -0.5667],
        [-2.3327,  4.5258, -0.8634],
        [-2.2709,  4.4168, -0.9229],
        [-0.4505, -1.4009,  0.6203],
        [-1.8449,  2.2457,  0.2465],
        [-2.2597,  4.5606, -1.0463],
        [-0.6865, -0.5422,  0.4250],
        [-0.8768, -0.3923,  0.6537],
        [-2.0865,  3.9295, -0.6375],
        [-2.2451,  4.4730, -0.9399],
        [-2.3474,  4.7164, -1.0958],
        [-2.1769,  4.4180, -1.0272],
        [-2.2591,  4.4133, -0.9490],
        [-1.7224,  1.2848,  0.7579],
        [-2.2771,  4.2534, -0.6381],
        [-2.1700,  4.3780, -1.0769],
        [-0.9704,  1.4043, -0.6903],
        [-1.5992,  1.1607,  0.7281],
        [-2.1994,  4.3203, -0.9351],
        [-1.2878, -0.2115,  1.2556],
        [-1.3052,  2.5690, -0.8814],
        [-2.3556,  4.7586, -1.1189],
        [-2.1084,  3.9813, -0.7360]

Epoch number 2
 Current loss 0.33146265149116516

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.0787,  4.2247, -0.9678],
        [-1.2957,  0.8661,  0.2267],
        [-1.9745,  3.8450, -0.8959],
        [-2.2091,  3.2870, -0.0476],
        [-2.1138,  4.3281, -1.1514],
        [-0.2147, -0.7413, -0.2790],
        [-2.0584,  3.4623, -0.3754],
        [-1.9260,  3.8168, -0.9359],
        [-1.8693,  3.2902, -0.4923],
        [-2.0088,  3.3739, -0.3885],
        [-1.7258,  1.6901,  0.4153],
        [-1.8056,  3.0282, -0.5015],
        [-2.0114,  4.0142, -1.0506],
        [-1.4294, -0.5949,  1.6944],
        [-2.2454,  4.4174, -1.0461],
        [-2.2846,  4.6472, -1.1996],
        [-1.8032,  3.1890, -0.5177],
        [-0.1950, -1.5494,  0.2578],
        [-1.6058,  1.0837,  0.6391],
        [-1.9432,  3.6398, -0.7288],
        [-1.8165,  1.3873,  0.7107],
        [-1.9625,  3.9802, -1.0703],
        [-2.2355,  4.3008, -0.8764],
        [-2.2023,  4.4458, -1.1509]

Epoch number 2
 Current loss 0.2947412133216858

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.8021,  3.6438, -0.9702],
        [-1.8050,  3.5368, -0.8900],
        [-0.1821, -1.6582,  0.4624],
        [-0.6257, -1.6012,  1.1177],
        [-1.3440,  2.2563, -0.6144],
        [-0.4732, -1.7268,  1.0264],
        [-0.8369,  0.6613, -0.4473],
        [-1.3684,  0.7770,  0.4170],
        [-0.6815, -1.0210,  0.5245],
        [-2.2843,  4.6618, -1.1745],
        [-2.0317,  3.9877, -0.9651],
        [-2.2605,  3.9699, -0.6048],
        [-1.8348,  3.7213, -1.0261],
        [-1.7387,  2.2869,  0.0182],
        [-1.8347,  2.1098,  0.1241],
        [-0.4925, -1.6077,  0.8201],
        [-2.1145,  3.7753, -0.5307],
        [-1.4216,  2.2292, -0.6322],
        [-1.6344,  1.4025,  0.4128],
        [-1.0836, -1.0933,  1.5075],
        [-2.0600,  4.1303, -0.9261],
        [-1.7518,  2.6830, -0.3340],
        [-2.0612,  3.6392, -0.5570],
        [-1.8441,  3.5390, -0.7851],

Epoch number 2
 Current loss 0.31074458360671997

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.3664, -1.8288,  0.9443],
        [-1.9749,  3.9304, -1.0480],
        [-0.1746, -1.7188,  0.4317],
        [-0.3061, -1.8828,  0.8029],
        [-1.8759,  3.3338, -0.6495],
        [-0.5388, -0.5572,  0.1404],
        [-2.0839,  3.8682, -0.8299],
        [-0.7815, -1.7735,  1.5986],
        [-1.8564,  3.6139, -0.9728],
        [-1.9420,  3.7352, -0.8141],
        [-1.9228,  1.9966,  0.3763],
        [-1.9646,  3.0492, -0.3060],
        [-2.2858,  4.5045, -1.0596],
        [-1.8653,  3.7978, -1.0677],
        [-1.8793,  3.3848, -0.6210],
        [-1.7469,  3.5698, -1.0913],
        [-2.1016,  3.9609, -0.8627],
        [-1.9894,  4.0929, -1.2174],
        [-2.0969,  4.1139, -0.9768],
        [-1.2148,  2.1196, -1.0395],
        [-0.5700, -1.9547,  1.4852],
        [-0.7716,  0.3295, -0.2976],
        [-0.6360, -1.9248,  1.4321],
        [-2.1426,  4.2154, -0.9917]

Epoch number 2
 Current loss 0.4033747613430023

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.2038,  1.8591, -0.7952],
        [-0.0783, -1.7967,  0.4744],
        [-2.0419,  4.1110, -0.9865],
        [-0.3839, -1.9896,  1.0181],
        [-1.9791,  4.0268, -1.2449],
        [-0.9666, -1.2741,  1.5591],
        [-1.9657,  3.4850, -0.6819],
        [-1.4608,  2.2890, -0.5299],
        [-0.7653, -0.5096,  0.4782],
        [-0.0827, -1.0552, -0.2449],
        [-0.6381,  0.4568, -0.5670],
        [-1.3624,  1.2348, -0.0117],
        [-1.6220,  3.1809, -1.0003],
        [-2.1067,  2.6280,  0.1601],
        [-0.4399, -2.0245,  1.1959],
        [-0.8348, -1.7086,  1.7009],
        [-1.9710,  3.1997, -0.3769],
        [-2.0754,  4.1139, -1.1342],
        [-2.1351,  4.3017, -1.2258],
        [-1.9219,  3.5229, -0.6894],
        [-0.9228, -0.3341,  0.6852],
        [-1.8511,  3.6330, -0.9307],
        [-0.7537, -1.6908,  1.5191],
        [-2.0868,  4.3301, -1.2085],

Epoch number 2
 Current loss 0.2891773283481598

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.2201, -1.2386,  1.8649],
        [-1.9565,  3.9291, -1.1108],
        [-2.3701,  4.7071, -1.2442],
        [-0.7719, -1.9041,  1.7279],
        [-1.9462,  3.7328, -0.9894],
        [-2.1733,  4.4474, -1.2761],
        [-2.0868,  4.1534, -1.0413],
        [-0.7615,  0.1179, -0.1420],
        [-0.4339, -0.8896,  0.1579],
        [-1.8317,  3.1075, -0.5836],
        [-2.2370,  1.7173,  0.9835],
        [-1.7129,  2.4851, -0.3324],
        [-2.0532,  3.8402, -0.8920],
        [-2.2074,  4.2631, -0.9850],
        [-2.0679,  4.2368, -1.2958],
        [-1.9546,  4.0443, -1.1969],
        [-1.8682,  2.5933, -0.1926],
        [-1.8574,  3.1494, -0.6613],
        [-0.4498, -0.4889, -0.0893],
        [-1.6140,  1.3446,  0.4291],
        [-1.6001,  1.3987,  0.3724],
        [-1.7248,  3.0476, -0.6934],
        [-0.9743, -1.7690,  1.8594],
        [-0.6545, -0.7769,  0.7837],

Epoch number 2
 Current loss 0.2463313490152359

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.0870, -1.1367,  1.5016],
        [-0.9417, -0.6795,  0.9120],
        [-1.2598, -0.4876,  1.4123],
        [-0.8335,  0.2372, -0.1842],
        [-1.8171,  3.4771, -0.9605],
        [-2.0463,  3.4057, -0.5636],
        [-1.4096,  2.6957, -1.1543],
        [-1.2117, -0.4999,  1.2762],
        [-0.2955, -1.3126,  0.2353],
        [-2.2436,  4.4661, -1.2423],
        [-1.4343, -0.2245,  1.2869],
        [-0.8176, -1.7141,  1.5321],
        [-2.4160,  4.2536, -0.6624],
        [-1.8354,  3.4203, -0.8678],
        [-2.1715,  4.2836, -1.1229],
        [-2.0282,  3.7006, -0.7037],
        [-0.1164, -1.7160,  0.2604],
        [-1.8024,  2.4798, -0.2017],
        [-2.0843,  0.8655,  1.3659],
        [-2.0005,  2.6217,  0.0492],
        [-0.5905, -1.5801,  1.0367],
        [-2.3571,  4.6420, -1.1831],
        [-2.0774,  4.1276, -1.2226],
        [-0.5463, -1.9889,  1.4435],

Epoch number 2
 Current loss 0.25492697954177856

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.6606, -1.7116,  1.4403],
        [-0.2591, -2.1889,  0.8668],
        [-1.6919, -0.5021,  1.8725],
        [-2.1139,  4.1617, -1.1252],
        [-2.0807,  4.0262, -1.0982],
        [-0.1764, -1.5075,  0.0049],
        [-2.2101,  3.3022, -0.2521],
        [-1.5043,  2.3200, -0.6852],
        [-2.3670,  4.2717, -0.8296],
        [-2.1041,  3.4638, -0.4995],
        [-1.8941,  1.9209,  0.3612],
        [-1.0094, -0.4810,  0.9067],
        [-2.0855,  3.8323, -0.8731],
        [-1.1838, -0.3549,  0.9192],
        [-1.9956,  3.9403, -1.1469],
        [-2.4591,  4.8948, -1.2995],
        [-2.0798,  4.1951, -1.1013],
        [-2.1684,  3.7527, -0.6088],
        [-2.3805,  3.3637, -0.0680],
        [-2.0267,  4.1146, -1.1992],
        [-2.2417,  4.4807, -1.2594],
        [-2.0689,  4.0574, -1.0981],
        [-1.7391,  3.2615, -0.8094],
        [-2.4356,  4.1912, -0.6508]

Epoch number 2
 Current loss 0.24576793611049652

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.2354,  3.6226, -0.4365],
        [-1.0942,  1.4166, -0.6819],
        [-2.2159,  3.1540, -0.0490],
        [-1.9348,  2.9037, -0.3623],
        [-2.3488,  4.4020, -1.0001],
        [-2.2642,  2.6006,  0.4377],
        [-2.3215,  4.6600, -1.3616],
        [-2.0733,  4.0315, -0.9212],
        [-2.5428,  4.8496, -1.0953],
        [-2.1894,  3.6356, -0.5425],
        [-1.6805,  0.3427,  1.2996],
        [-0.8096, -1.7375,  1.6374],
        [-1.0310, -1.6355,  1.9179],
        [-2.1386,  4.1590, -1.0660],
        [-1.9874,  3.8008, -0.9809],
        [-2.0448,  4.0255, -1.0353],
        [-1.2790, -0.7494,  1.6577],
        [-1.4638, -1.0019,  2.0227],
        [-2.3317,  4.5653, -1.1772],
        [-2.1339,  3.7396, -0.5801],
        [-2.0863,  4.1757, -1.2100],
        [-0.0705, -1.9261,  0.4048],
        [-0.9290, -0.5231,  0.5713],
        [-2.4505,  4.2548, -0.6584]

Epoch number 2
 Current loss 0.3556882441043854

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.1530, -1.9699,  0.4501],
        [-1.7724,  3.3482, -1.0173],
        [-0.6283, -1.7143,  1.2957],
        [-2.2131,  4.4021, -1.3222],
        [-2.2305,  2.6735,  0.2433],
        [ 0.0328, -1.9499,  0.1350],
        [-0.9638, -1.6808,  1.8563],
        [-1.0167, -1.9510,  2.1491],
        [-2.0723,  3.4380, -0.5403],
        [-2.4258,  4.7395, -1.2678],
        [-2.1306,  4.1982, -1.1444],
        [-2.2547,  4.4485, -1.3125],
        [-2.1955,  2.7487,  0.0459],
        [-2.3970,  3.8790, -0.4197],
        [-0.0157, -2.0309,  0.2431],
        [-2.1514,  0.9131,  1.4345],
        [-1.2415, -0.3638,  1.1333],
        [-2.2705,  4.5687, -1.4070],
        [-0.0938, -2.1577,  0.4522],
        [-2.3107,  4.3324, -0.9739],
        [-2.3834,  3.8803, -0.5121],
        [-2.3038,  4.4266, -1.1205],
        [-1.4180,  2.3497, -0.7681],
        [-2.5275,  3.8404, -0.2417],

Epoch number 2
 Current loss 0.3134291470050812

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.2427e+00,  4.3919e+00, -1.3181e+00],
        [-1.8558e+00,  2.1260e+00,  1.4564e-01],
        [-2.0499e+00,  3.9533e+00, -1.2172e+00],
        [-2.5475e+00,  3.9645e+00, -2.5638e-01],
        [-1.9832e+00,  2.4391e+00, -3.1432e-03],
        [-7.2572e-02, -2.1717e+00,  5.4593e-01],
        [-2.2580e+00,  3.1648e+00, -1.8387e-01],
        [-2.0670e+00,  1.0771e+00,  1.1543e+00],
        [-2.1446e+00,  4.0593e+00, -9.1086e-01],
        [-2.1194e+00,  4.2375e+00, -1.3076e+00],
        [-2.1012e+00,  3.1529e+00, -3.2401e-01],
        [-1.3556e+00,  3.7930e-01,  7.2652e-01],
        [-2.3814e+00,  3.0688e+00,  7.3657e-02],
        [-1.3709e+00, -5.7854e-01,  1.6678e+00],
        [-2.3978e+00,  4.7552e+00, -1.3607e+00],
        [-2.4130e+00,  4.4549e+00, -9.2336e-01],
        [-2.0359e+00,  3.6805e+00, -8.5895e-01],
        [-1.9522e+00,  2.6281e+00, -9.3448e-02],
     

Epoch number 2
 Current loss 0.23531018197536469

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.0586, -1.5991,  0.0554],
        [-0.8100,  0.6060, -0.6410],
        [-2.1892,  4.1037, -0.9122],
        [-2.1860,  4.3500, -1.2785],
        [-2.5307,  4.7647, -1.1383],
        [-1.8360,  3.6042, -1.1295],
        [-1.5670, -0.0831,  1.4294],
        [-0.3711, -1.8269,  0.8466],
        [-2.3315,  4.3920, -1.1386],
        [-2.4945,  4.7287, -1.1435],
        [-0.4220, -2.0075,  1.2020],
        [-2.4246,  4.4234, -0.9128],
        [-0.5496, -1.8850,  1.4634],
        [-2.3632,  4.7220, -1.4080],
        [-2.5223,  4.0434, -0.4664],
        [-2.4680,  4.7469, -1.2297],
        [-1.3364,  2.3739, -0.9095],
        [-2.3654,  2.8351,  0.2581],
        [-2.1080,  4.0992, -1.1152],
        [-2.0953,  3.0905, -0.2692],
        [-1.5737,  0.2022,  1.2924],
        [-2.3038,  3.1407, -0.0580],
        [-2.3732,  4.4246, -0.9944],
        [-2.4122,  4.6124, -1.2149]

Epoch number 2
 Current loss 0.2957366406917572

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.9710,  1.1193,  1.0816],
        [-2.1720,  4.1294, -1.1287],
        [-2.3461,  3.2655, -0.0649],
        [-0.3558, -2.0576,  1.1813],
        [-2.2295,  4.2900, -1.1029],
        [-0.3062, -1.0815, -0.0534],
        [ 0.0538, -1.8127,  0.2520],
        [-2.2887,  1.1764,  1.3169],
        [ 0.0390, -2.1514,  0.3352],
        [-0.6565, -1.4932,  1.0218],
        [-2.2563,  4.0136, -0.9363],
        [-2.0283,  3.7283, -1.0433],
        [-1.3136,  2.3996, -1.1607],
        [-1.8268,  0.7184,  1.1743],
        [-2.0748,  4.1257, -1.2295],
        [-2.2499,  4.4340, -1.3037],
        [-2.1565,  3.8012, -0.7155],
        [-0.3608, -1.7506,  0.8577],
        [-2.1594,  4.1059, -1.0601],
        [-1.9757,  3.6892, -1.1578],
        [-2.1562,  4.0411, -0.9626],
        [-2.2394,  3.3781, -0.3431],
        [-2.0791,  4.1937, -1.3495],
        [-2.0714,  1.9396,  0.5208],

Epoch number 2
 Current loss 0.13018698990345

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.0756,  2.2777,  0.3652],
        [-2.4949,  4.9480, -1.4356],
        [-0.5392, -1.5236,  1.0035],
        [-2.2668,  4.2156, -1.0579],
        [-0.9907,  0.8870, -0.4979],
        [-1.6669,  1.4348,  0.2864],
        [-1.8412,  0.4087,  1.4814],
        [-2.3935,  4.6854, -1.3478],
        [-1.5890,  0.7168,  0.9367],
        [-2.1452,  4.2656, -1.3015],
        [-2.3914,  4.7761, -1.5151],
        [-2.0722,  4.0266, -1.2329],
        [-2.4717,  4.7890, -1.3723],
        [-2.2317,  1.5110,  1.1426],
        [-2.4766,  3.2965,  0.1655],
        [-2.3993,  4.7056, -1.4078],
        [-1.9077,  1.7845,  0.5239],
        [-1.9343,  3.3434, -0.9525],
        [-2.2885,  3.7722, -0.5978],
        [-2.2604,  4.1365, -0.9416],
        [-2.4273,  4.7041, -1.2633],
        [-2.7057,  4.9460, -1.0776],
        [-2.3468,  4.6084, -1.3202],
        [-0.4091, -1.1811,  0.3286],
 

Epoch number 2
 Current loss 0.2186216115951538

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.0657, -2.3537,  0.5441],
        [-1.2621, -1.0749,  2.0292],
        [-1.2627,  0.3676,  0.7979],
        [-1.2844,  1.3694, -0.3335],
        [-1.4019,  0.8012,  0.2752],
        [-2.7192,  3.5795,  0.1621],
        [-0.6988, -1.5518,  1.4285],
        [-2.6625,  4.4966, -0.7151],
        [-1.6534,  0.1104,  1.6418],
        [-1.8486,  2.5870, -0.3608],
        [-2.3341,  4.6317, -1.4020],
        [-1.8071,  3.3827, -0.9891],
        [-2.2761,  4.1205, -0.9529],
        [-1.3735, -1.9304,  2.8969],
        [-2.5821,  4.9748, -1.3096],
        [-1.0490, -1.5500,  2.0322],
        [-2.7848,  4.6413, -0.7020],
        [-1.1257, -1.0944,  1.8182],
        [-2.5087,  4.9981, -1.5582],
        [-1.9194,  2.8145, -0.4915],
        [-1.8971,  2.3290,  0.0336],
        [-0.9967, -0.3246,  0.7112],
        [-0.2159, -2.1985,  1.0610],
        [-1.9716,  2.7271, -0.2107],

Epoch number 2
 Current loss 0.3026135563850403

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.2921, -0.8336,  1.8078],
        [-1.4064, -0.5248,  1.8191],
        [-2.3290,  4.1830, -0.9058],
        [-1.9476,  0.3354,  1.9030],
        [-2.3589,  4.6748, -1.3581],
        [-2.1651,  4.1980, -1.3049],
        [-1.5244,  0.1670,  1.4057],
        [-2.1757,  3.3030, -0.3619],
        [-2.3306,  4.6358, -1.5180],
        [-2.2454,  3.1288, -0.1942],
        [-2.1731,  3.5736, -0.5883],
        [-1.4121, -0.8083,  2.0340],
        [-2.2200,  2.4338,  0.4925],
        [-1.9182,  2.9761, -0.6099],
        [-0.1486, -1.3376,  0.1162],
        [-2.4837,  4.5197, -0.9541],
        [-1.3135,  2.1077, -0.8337],
        [-1.9455,  3.4864, -0.9568],
        [-1.1801, -0.5836,  1.4255],
        [-0.4913, -1.5358,  0.8798],
        [-2.1401,  3.9324, -0.8958],
        [-0.7826, -1.8208,  1.8406],
        [-2.5106,  5.0806, -1.6569],
        [-2.1281,  3.0996, -0.2021],

Epoch number 2
 Current loss 0.3209376335144043

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.2869,  4.6336, -1.5755],
        [-0.8305, -1.2633,  1.3735],
        [-2.7061,  5.2065, -1.4536],
        [-2.5345,  4.8128, -1.3103],
        [-2.4139,  3.7229, -0.4266],
        [-2.5034,  4.7412, -1.3936],
        [-2.4733,  3.7254, -0.3504],
        [-2.2173,  4.2598, -1.3032],
        [-1.7421,  2.3544, -0.3001],
        [-0.0530, -1.2856, -0.2480],
        [-2.4452,  4.5399, -1.0913],
        [-2.5222,  4.5734, -1.0901],
        [-2.0172,  2.8123, -0.2391],
        [-1.7505,  1.8568,  0.1081],
        [-2.5253,  4.7120, -1.2281],
        [-2.3749,  4.4348, -1.1511],
        [-1.4087, -1.8857,  2.9489],
        [-2.1650,  4.0886, -1.0855],
        [ 0.0263, -2.4261,  0.5920],
        [-2.5675,  4.6362, -1.0892],
        [ 0.1330, -2.2561,  0.2849],
        [-2.5939,  4.7892, -1.1950],
        [-2.7685,  5.0634, -1.1059],
        [-1.7888,  2.9819, -0.7046],

Epoch number 2
 Current loss 0.34117385745048523

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.2809, -2.2784,  1.2377],
        [-2.5802,  5.1304, -1.6751],
        [-2.6601,  4.8216, -1.1591],
        [-1.8842,  2.6393, -0.3786],
        [-2.2809,  4.2230, -1.1541],
        [ 0.2051, -2.2698,  0.2373],
        [-1.7696,  2.5866, -0.5887],
        [-2.3106,  4.5263, -1.4632],
        [-2.4128,  4.7976, -1.5575],
        [-1.4125,  2.2402, -0.7403],
        [-2.3695,  4.6236, -1.4776],
        [-2.4384,  4.8491, -1.5717],
        [-2.5635,  4.8154, -1.3106],
        [-1.6794, -0.3604,  1.9333],
        [-2.5899,  5.0513, -1.5158],
        [-2.4762,  4.5374, -1.1251],
        [-1.5566,  0.9992,  0.6040],
        [-3.1758,  3.3489,  0.7033],
        [-0.6972, -1.9767,  1.5351],
        [-2.7426,  4.6530, -0.8052],
        [ 0.1285, -2.1429,  0.2989],
        [-1.8934, -0.0557,  1.9219],
        [-2.3914,  4.5923, -1.3336],
        [-1.3699, -0.4060,  1.5181]

Epoch number 2
 Current loss 0.2896365523338318

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.5748,  5.0095, -1.4972],
        [-2.5308,  4.8483, -1.4544],
        [-1.5203,  0.5244,  0.8773],
        [ 0.2636, -2.0994, -0.0670],
        [-2.2570,  4.1232, -1.0557],
        [-2.5174,  4.8151, -1.3944],
        [-2.2583,  4.3757, -1.4935],
        [-2.6050,  4.6255, -1.0103],
        [-2.4963,  4.9814, -1.6347],
        [-2.6149,  5.1108, -1.5512],
        [-1.7241,  2.7305, -0.9013],
        [-1.3738, -1.0799,  2.1551],
        [-0.7514,  0.2271, -0.3495],
        [-0.5792,  0.3657, -0.7777],
        [-2.3643,  4.7596, -1.6914],
        [-2.2787,  4.5286, -1.6689],
        [-2.5569,  5.1039, -1.6669],
        [-2.6293,  5.0815, -1.5092],
        [ 0.1906, -2.2890,  0.2339],
        [-2.4277,  4.7398, -1.4983],
        [-2.6445,  5.3285, -1.7627],
        [ 0.2530, -2.4292,  0.1061],
        [-1.8582,  3.3032, -1.1566],
        [-2.2500,  4.4656, -1.5781],

Epoch number 2
 Current loss 0.34899577498435974

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.0385,  3.8090, -1.3595],
        [ 0.0738, -2.4254,  0.5733],
        [-2.0534,  1.2540,  1.1178],
        [-2.1345,  4.1106, -1.4850],
        [-2.0798,  3.9123, -1.3065],
        [-2.6206,  2.6686,  0.6140],
        [-1.8452,  3.2578, -1.1857],
        [-1.4375,  0.5224,  0.7513],
        [-0.9596, -1.6747,  1.9935],
        [-2.0881,  3.8832, -1.0909],
        [-2.6758,  5.2197, -1.5443],
        [-1.8081,  2.0960,  0.0273],
        [-2.4972,  4.9750, -1.6894],
        [-1.4199, -0.0406,  1.3089],
        [-2.2113,  4.4133, -1.7485],
        [ 0.3785, -2.3775, -0.1549],
        [-1.9186,  0.7507,  1.3991],
        [ 0.1602, -2.0981,  0.0223],
        [-1.9779,  3.4910, -1.0425],
        [ 0.1476, -2.3205,  0.3363],
        [-1.8663,  3.3205, -0.9345],
        [-0.9797, -2.3726,  2.7032],
        [-2.1091,  3.8644, -1.2669],
        [ 0.3310, -2.6541,  0.0320]

Epoch number 2
 Current loss 0.2885352373123169

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9968, -2.3149,  2.6475],
        [-2.1354,  4.0841, -1.5136],
        [-2.3120,  4.6157, -1.6201],
        [ 0.2932, -2.1232, -0.1421],
        [-2.3778,  4.7957, -1.7523],
        [ 0.2581, -2.4952,  0.3219],
        [-0.2939, -2.3790,  1.2734],
        [-1.1717,  0.9336, -0.3880],
        [-1.6231,  2.6127, -0.9664],
        [-1.2228, -1.1914,  2.1199],
        [ 0.2733, -2.0142, -0.3154],
        [-1.2254, -1.1354,  1.8647],
        [-0.7488, -1.8405,  1.6525],
        [-1.4205,  0.9533,  0.2838],
        [ 0.1482, -1.4330, -0.6385],
        [ 0.3699, -2.5615, -0.0002],
        [-0.2248, -1.1177, -0.0944],
        [-2.2757,  4.3551, -1.4305],
        [-0.6598, -1.4949,  1.1735],
        [-2.1159,  4.0071, -1.6037],
        [-2.5379,  5.1028, -1.7499],
        [-2.4522,  4.1144, -0.7948],
        [-2.4855,  4.4981, -1.2328],
        [-0.4062, -2.0952,  1.1793],

Epoch number 2
 Current loss 0.31362348794937134

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.2360, -2.0212,  2.8423],
        [-1.2467, -0.0347,  0.7623],
        [-1.7568,  2.9532, -1.0748],
        [-2.1753,  3.9146, -1.1476],
        [-1.3352,  1.1883, -0.1522],
        [-1.7194,  2.1599, -0.3546],
        [ 0.3422, -2.1186, -0.2919],
        [-0.6529, -1.0486,  0.8636],
        [-1.7703,  2.8342, -0.9195],
        [-2.0170,  3.8922, -1.5625],
        [-1.8033,  2.7577, -0.6766],
        [-1.9353,  3.6873, -1.5440],
        [ 0.2658, -2.1841, -0.1198],
        [-0.6773, -2.3638,  2.1011],
        [-1.8974,  3.6036, -1.6524],
        [-2.2836,  4.5781, -1.6884],
        [-1.6326,  0.3307,  1.2907],
        [-1.7589,  3.1360, -1.3555],
        [-1.3979,  0.5302,  0.6774],
        [-2.2666,  3.5855, -0.5566],
        [ 0.3915, -2.6971,  0.1778],
        [-2.3209,  4.1855, -1.1468],
        [-1.5005,  2.5105, -1.0895],
        [-2.4332,  4.8826, -1.7831]

Epoch number 2
 Current loss 0.3414407968521118

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.2917, -2.4047, -0.0159],
        [ 0.2152, -2.8162,  0.5560],
        [-0.5578, -2.2861,  1.8775],
        [-2.1575,  2.3412,  0.4002],
        [-0.9962, -1.9852,  2.4446],
        [-0.7955,  0.4358, -0.4887],
        [-0.6930, -1.2054,  1.0179],
        [ 0.0114, -2.6351,  1.0270],
        [-1.1002, -1.7115,  2.3523],
        [-2.1791,  3.7756, -0.8778],
        [ 0.2185, -2.4425,  0.2268],
        [-2.4267,  4.8884, -1.7407],
        [-0.9785,  1.6502, -1.3659],
        [ 0.2545, -2.4946,  0.0419],
        [-0.2929, -2.2393,  1.0439],
        [-1.6925,  1.8148,  0.1047],
        [-1.8340,  2.6164, -0.5856],
        [-2.0706,  3.1738, -0.6006],
        [-0.4906, -2.6068,  2.0079],
        [-2.0630,  3.9159, -1.4199],
        [-1.6979,  2.7952, -0.9378],
        [-2.3973,  4.0615, -0.7849],
        [-2.2893,  4.4062, -1.5149],
        [-2.2991,  4.4419, -1.5461],

Epoch number 2
 Current loss 0.2313361018896103

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.3860,  4.6549, -1.5335],
        [-2.3235,  4.5288, -1.5666],
        [-2.5832,  4.9266, -1.4829],
        [-1.3856, -1.7393,  2.8848],
        [-2.4651,  4.6847, -1.4402],
        [-2.4686,  4.5241, -1.2138],
        [-2.6015,  4.0800, -0.4941],
        [-2.2793,  3.7457, -0.6699],
        [-1.6676,  1.6412,  0.0833],
        [-1.6101,  2.9784, -1.4211],
        [-2.6668,  4.8126, -1.1971],
        [-1.9859,  2.4098,  0.0306],
        [-0.3288, -2.3011,  1.2406],
        [-1.1761, -1.7019,  2.4764],
        [-0.4189, -1.8993,  1.1784],
        [-2.7715,  4.1704, -0.3675],
        [-2.3336,  4.5628, -1.5976],
        [-2.2063,  4.1776, -1.3308],
        [-1.1241, -1.9425,  2.5647],
        [-2.2970,  4.5471, -1.6302],
        [-0.8003, -2.2328,  2.2458],
        [-2.6116,  4.6038, -1.0917],
        [-2.5253,  5.0245, -1.7176],
        [-1.9853,  1.9508,  0.5189],

Epoch number 2
 Current loss 0.33199402689933777

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6989,  5.1643, -1.5454],
        [-2.3400,  4.1679, -1.1351],
        [-1.2098, -2.0319,  2.7929],
        [ 0.1936, -2.2358,  0.0873],
        [-2.5360,  4.8554, -1.4831],
        [-1.9756,  1.0558,  1.3073],
        [-1.3071,  2.2322, -1.2072],
        [-2.4558,  3.5145, -0.1136],
        [-2.4333,  4.8188, -1.6020],
        [-2.1011,  3.9510, -1.3979],
        [-2.2148,  3.9079, -1.0606],
        [-2.4110,  2.5670,  0.5452],
        [-0.5617, -0.8615,  0.3968],
        [-1.4810, -1.8533,  3.0129],
        [-1.6942,  2.4349, -0.6483],
        [-2.1541,  3.3540, -0.5966],
        [-2.2671,  3.8779, -0.9561],
        [-1.1393,  0.8307, -0.3779],
        [-1.7626,  2.4051, -0.6244],
        [-2.4668,  4.4210, -1.1677],
        [-2.5283,  4.9285, -1.6287],
        [-0.3687, -2.4355,  1.5129],
        [-1.3123, -1.0508,  2.1145],
        [-2.1630,  3.1195, -0.4089]

Epoch number 2
 Current loss 0.23788893222808838

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.7731,  5.4088, -1.6374],
        [-0.1910, -1.6711,  0.3266],
        [-1.0865, -1.2845,  1.8047],
        [-2.6211,  5.0598, -1.6099],
        [-2.6841,  4.9917, -1.4008],
        [-2.3937,  4.5693, -1.4006],
        [-2.5147,  3.8250, -0.4833],
        [-1.6380,  2.8248, -1.0474],
        [-1.8814, -0.4314,  2.3194],
        [-2.6756,  1.7677,  1.4156],
        [-1.9873,  3.3403, -0.9747],
        [-1.3035, -1.9476,  2.8317],
        [-2.4578,  4.5772, -1.2749],
        [-1.6570, -0.9069,  2.4227],
        [-2.5745,  4.7082, -1.2884],
        [-1.0267, -0.5104,  1.0163],
        [-2.3313,  4.1012, -0.9624],
        [-1.2308, -1.4925,  2.3424],
        [-2.5090,  4.8979, -1.5079],
        [-2.6546,  4.0780, -0.3877],
        [-2.4295,  4.3995, -1.1421],
        [-1.7590, -0.4301,  2.1755],
        [-2.8061,  5.2958, -1.4025],
        [-1.9263, -0.7977,  2.6871]

Epoch number 2
 Current loss 0.16558778285980225

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.1949, -2.3482,  0.1335],
        [-0.2065, -2.1346,  0.9580],
        [-2.3876,  4.6284, -1.5563],
        [-2.3990,  3.7243, -0.4604],
        [-2.7128,  5.2942, -1.7153],
        [-2.7579,  5.4291, -1.7179],
        [-2.4035,  4.6618, -1.6013],
        [-2.7847,  5.2152, -1.4621],
        [-1.5956,  0.2208,  1.3001],
        [-2.6540,  4.4350, -0.7673],
        [-2.4789,  4.7768, -1.4986],
        [-2.5253,  4.7838, -1.5417],
        [-2.7281,  5.2064, -1.5000],
        [-1.9704, -0.2389,  2.1930],
        [-2.6282,  4.9567, -1.4318],
        [-2.5253,  4.7975, -1.5095],
        [-0.1866, -1.6140,  0.3223],
        [-1.1700,  0.7470, -0.0034],
        [-2.6347,  5.2249, -1.7118],
        [ 0.1400, -2.1387,  0.2016],
        [-2.7924,  5.4077, -1.5470],
        [-1.8286,  3.1827, -1.1860],
        [-2.8064,  5.4254, -1.6523],
        [-2.4603,  4.4517, -1.2838]

Epoch number 2
 Current loss 0.28111255168914795

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8441, -1.4645,  1.5703],
        [-2.7569,  5.4021, -1.6812],
        [-2.1777,  4.2120, -1.5017],
        [-2.7161,  3.2232,  0.3901],
        [-1.6019,  1.2965,  0.1375],
        [-1.3875, -1.4103,  2.4724],
        [-0.9552, -1.5688,  1.9715],
        [-1.3727, -1.3873,  2.5278],
        [-2.3621,  4.4985, -1.3778],
        [-2.6881,  4.3741, -0.6080],
        [-2.1824,  0.3600,  2.0533],
        [-0.7934, -1.7997,  1.8103],
        [-0.8488, -0.5016,  0.6876],
        [-1.8620, -0.3592,  2.2673],
        [-1.6897, -0.6326,  2.2501],
        [-2.6313,  5.1140, -1.5832],
        [-2.6321,  4.5072, -0.8205],
        [-1.3632, -1.0447,  2.2675],
        [-2.6725,  5.2234, -1.7208],
        [-2.5458,  4.6184, -1.1759],
        [-1.6162,  0.1769,  1.5003],
        [-1.7007,  1.8906,  0.0081],
        [-1.4455, -0.6481,  1.8919],
        [-0.1329, -1.2063, -0.2368]

Epoch number 2
 Current loss 0.29059943556785583

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6843,  5.2549, -1.6389],
        [-2.6545,  5.1844, -1.7451],
        [-0.1449, -2.3862,  1.1965],
        [-1.8430,  2.8069, -0.6556],
        [-0.3392, -1.9151,  1.0590],
        [-1.6920,  1.5004,  0.3736],
        [-2.1153,  2.8476, -0.1533],
        [-0.9125, -1.3435,  1.6707],
        [-1.4064, -0.7927,  2.0768],
        [-2.5145,  1.9652,  0.9904],
        [-2.4607,  4.6098, -1.2621],
        [-1.2638, -1.7178,  2.5633],
        [-2.5744,  4.9804, -1.5675],
        [-1.6430,  1.6716, -0.0524],
        [-2.6520,  4.9061, -1.3274],
        [-2.1348,  3.6721, -1.0092],
        [-1.2123, -2.2756,  3.0012],
        [-2.4849,  4.8185, -1.5173],
        [-2.7315,  5.2978, -1.5695],
        [-2.1842,  1.7291,  0.9467],
        [-2.3305,  1.1292,  1.5917],
        [-2.3063,  2.7198,  0.2284],
        [-1.2128,  0.8585, -0.0923],
        [-2.3162,  2.3502,  0.5585]

Epoch number 2
 Current loss 0.25164586305618286

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6992,  4.9101, -1.1992],
        [-2.8217,  5.5673, -1.7715],
        [-2.6085,  5.0020, -1.6086],
        [-0.3586, -1.0884,  0.1962],
        [-2.3496,  4.4191, -1.3705],
        [-2.4470,  4.7767, -1.6321],
        [-1.1074, -2.0277,  2.6102],
        [-2.4249,  2.6116,  0.4109],
        [-2.4929,  4.7463, -1.5267],
        [-2.1742,  3.9028, -1.1758],
        [-2.2240,  3.7337, -1.0262],
        [-2.6757,  5.1258, -1.5587],
        [-2.3059,  4.3101, -1.2920],
        [-1.1916, -2.0519,  2.7740],
        [-0.4976, -1.9762,  1.4795],
        [-0.3239, -2.1337,  1.3320],
        [-2.6551,  4.9840, -1.3992],
        [-2.0860,  3.2882, -0.7083],
        [-0.3471, -2.1234,  1.2641],
        [-2.6824,  4.9215, -1.3016],
        [-0.8668, -1.9883,  2.2354],
        [-0.7022, -0.7483,  0.5426],
        [ 0.2296, -2.3511,  0.1741],
        [-0.0169, -2.1674,  0.6375]

Epoch number 2
 Current loss 0.2075171321630478

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.5041,  4.8479, -1.5760],
        [-2.2687,  2.8999,  0.1246],
        [-2.4346,  4.4672, -1.1915],
        [-2.6751,  5.2600, -1.7212],
        [-2.6358,  5.1735, -1.7191],
        [-2.5566,  4.5417, -1.1474],
        [-2.7531,  5.2813, -1.5352],
        [-2.8783,  5.3465, -1.3875],
        [-0.4453, -1.2971,  0.6762],
        [-0.9036, -1.8592,  2.2175],
        [-2.8388,  5.5577, -1.7215],
        [-0.0908, -1.2649, -0.2849],
        [-0.0621, -2.4553,  0.8941],
        [-2.5896,  4.8812, -1.3499],
        [-2.5948,  4.8316, -1.3773],
        [-2.3712,  4.4602, -1.3933],
        [-0.6881, -1.7830,  1.5222],
        [-2.4222,  4.7026, -1.5878],
        [-2.4426,  4.6612, -1.4845],
        [-1.1589, -0.3318,  0.9535],
        [-2.5715,  4.7260, -1.3959],
        [-0.7981, -0.7144,  0.7314],
        [-0.4899, -0.6425,  0.0714],
        [-2.8441,  5.4462, -1.5740],

Epoch number 2
 Current loss 0.19133782386779785

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.9153,  5.1590, -1.1297],
        [-1.6460,  1.3486,  0.3112],
        [-2.4619,  4.6201, -1.4347],
        [-1.6190,  1.1487,  0.5438],
        [-2.2062,  3.9909, -1.1376],
        [-0.8017, -1.9415,  2.0910],
        [-0.2405, -2.4268,  1.4741],
        [-0.7454,  0.7492, -0.7483],
        [-2.4084,  4.6424, -1.5730],
        [ 0.2497, -2.5283,  0.2374],
        [-2.5199,  4.9598, -1.6402],
        [-0.3952, -2.4678,  1.6381],
        [-2.9081,  5.1731, -1.1101],
        [-2.3835,  4.4682, -1.4491],
        [-2.5975,  5.0923, -1.6915],
        [-2.7289,  5.2509, -1.6591],
        [-1.9997,  3.8691, -1.5708],
        [-1.6469,  2.8444, -1.1095],
        [-2.4932,  4.6876, -1.4825],
        [-1.4254,  0.6596,  0.4466],
        [-2.4432,  4.8063, -1.7130],
        [-2.5899,  5.0163, -1.6469],
        [-2.5928,  5.0783, -1.7118],
        [-2.2018,  3.3845, -0.4501]

Epoch number 2
 Current loss 0.22977620363235474

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.7537,  5.4365, -1.7481],
        [-2.2791,  2.8163, -0.0211],
        [-2.5031,  4.8737, -1.6186],
        [-0.5753, -1.6156,  1.1393],
        [-0.0569, -1.9239,  0.4173],
        [-0.6936, -0.4224,  0.3880],
        [-2.3018,  2.7260,  0.2833],
        [-2.2444,  4.1785, -1.4121],
        [-2.2906,  3.9420, -0.9472],
        [-0.6327, -0.8196,  0.5054],
        [-2.5935,  5.1089, -1.7829],
        [-2.7577,  5.2899, -1.5596],
        [-2.7518,  4.9005, -1.0639],
        [-1.8781,  2.2471, -0.1991],
        [-2.3645,  4.3561, -1.3544],
        [-0.5108, -2.4850,  2.1129],
        [-0.3054, -2.7167,  1.7294],
        [-2.2922,  4.4473, -1.5080],
        [-1.0749, -1.2390,  1.9058],
        [-1.3194, -0.4185,  1.5968],
        [-2.5115,  4.9405, -1.5547],
        [-1.4257, -1.3456,  2.5477],
        [-2.7580,  4.9901, -1.3181],
        [-2.5959,  4.9574, -1.5935]

Epoch number 2
 Current loss 0.17490410804748535

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.6884, -1.9756,  1.6618],
        [-2.2342,  4.0994, -1.4022],
        [-1.6514,  1.5897,  0.1118],
        [-2.2251,  3.7129, -0.9947],
        [-1.2447, -1.2250,  2.1847],
        [-2.7876,  4.7979, -0.9388],
        [-2.5543,  4.7356, -1.3827],
        [-1.1523,  1.1084, -0.3691],
        [-2.3643,  4.5537, -1.4721],
        [-1.7411,  0.2469,  1.6262],
        [-2.4792,  4.5254, -1.2750],
        [-2.6655,  5.2060, -1.6882],
        [-1.1127, -2.3107,  2.9397],
        [-2.1440,  3.7738, -1.2302],
        [ 0.1884, -2.8568,  0.5609],
        [-2.2489,  3.9162, -1.1997],
        [-2.2318,  3.6367, -0.6986],
        [-2.7837,  5.2519, -1.6195],
        [-2.0144,  3.7652, -1.4772],
        [-1.1331, -0.6963,  1.4642],
        [-2.4871,  4.5253, -1.3036],
        [-2.5463,  4.8195, -1.5519],
        [-0.7962, -0.4915,  0.4410],
        [-2.3546,  4.4644, -1.6387]

Epoch number 2
 Current loss 0.24782609939575195

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.2829,  3.8478, -0.8791],
        [-0.5360, -0.8742,  0.3687],
        [-2.4213,  4.6580, -1.5851],
        [-2.1579,  4.2512, -1.6460],
        [-0.8267, -0.3410,  0.3689],
        [-1.9179,  0.1902,  1.9194],
        [-2.0263,  2.3333, -0.1594],
        [-2.4265,  4.5886, -1.4210],
        [-2.3052,  3.6776, -0.5724],
        [-0.9822, -0.3040,  0.5345],
        [-2.3923,  4.6043, -1.5563],
        [-0.5401, -2.3110,  1.8753],
        [ 0.1484, -2.3707,  0.1726],
        [ 0.0155, -2.0368,  0.1654],
        [-2.1038,  3.9125, -1.2743],
        [-2.1190,  3.9376, -1.4916],
        [-2.4165,  4.6546, -1.5544],
        [-2.6574,  4.9633, -1.4221],
        [-1.1695, -0.1539,  1.0277],
        [-1.0328, -2.1418,  2.6883],
        [ 0.0344, -1.2314, -0.5545],
        [-2.5681,  4.8767, -1.4998],
        [-2.7063,  4.9712, -1.3628],
        [-0.3953, -1.5490,  0.7405]

Epoch number 2
 Current loss 0.2832638621330261

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.5962, -2.2099,  1.9765],
        [-0.9654, -0.2272,  0.2966],
        [-1.6948,  3.2164, -1.5154],
        [-2.0302,  1.4889,  0.9147],
        [-2.4980,  3.3758,  0.0040],
        [-1.9868,  0.3335,  1.8965],
        [-1.8766,  2.5794, -0.4555],
        [-2.1939,  2.5495,  0.2742],
        [-2.2458,  3.8243, -1.0506],
        [-1.2628, -1.6999,  2.5674],
        [-2.7481,  5.3328, -1.7048],
        [-1.5929, -0.9663,  2.4209],
        [-0.1744, -2.3745,  1.2738],
        [-2.4025,  4.3417, -1.2973],
        [-2.0065,  2.8637, -0.5632],
        [-0.8020, -2.4933,  2.6487],
        [-2.6628,  4.1386, -0.6053],
        [-1.1785, -1.7920,  2.6038],
        [-1.4327,  0.8246,  0.2750],
        [-2.9485,  5.7553, -1.8021],
        [ 0.2097, -2.8001,  0.4539],
        [-2.1258,  3.9478, -1.5370],
        [-2.6398,  5.1355, -1.6663],
        [-1.8236,  2.5810, -0.6703],

Epoch number 2
 Current loss 0.23836159706115723

inputs are
torch.Size([100, 87])
torch.Size([100, 87])
OUTPUT
tensor([[-2.6389,  5.0775, -1.7515],
        [-1.1902,  0.2266,  0.5039],
        [-2.3015,  4.3216, -1.6073],
        [-1.0636, -1.8489,  2.4515],
        [-2.5647,  4.9429, -1.6415],
        [-2.2026,  3.6563, -0.9568],
        [-2.6125,  5.1314, -1.8556],
        [-2.5733,  4.9793, -1.7436],
        [-2.3848,  4.5533, -1.6461],
        [-2.0582,  3.5742, -1.1981],
        [-2.3357,  4.4803, -1.6342],
        [-0.4625, -2.1977,  1.5180],
        [-2.9972,  4.5608, -0.4381],
        [-2.4902,  4.6799, -1.3997],
        [-0.1906, -2.1741,  1.0258],
        [-2.5234,  2.2855,  0.8068],
        [-2.5552,  4.9018, -1.6158],
        [-1.8067,  1.7428,  0.0176],
        [ 0.1619, -2.4920,  0.4302],
        [-2.5949,  4.7329, -1.3907],
        [-1.9353,  3.1232, -1.0367],
        [-2.5996,  4.9034, -1.6162],
        [ 0.2823, -2.5752,  0.1086],
        [-2.7896,  5.2069, -1.4576],


Epoch number 2
 Current loss 0.25743794441223145

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.5607,  4.9371, -1.6253],
        [-0.9831, -2.2076,  2.5812],
        [-2.3467,  4.2519, -1.4656],
        [-2.7675,  5.3761, -1.7343],
        [-1.6706,  2.8343, -1.4132],
        [-1.9801,  2.9791, -0.7669],
        [-0.3184, -2.0817,  1.2187],
        [-0.5981, -2.2035,  1.7018],
        [-1.1892,  1.1420, -0.6715],
        [-2.5257,  4.4729, -1.1676],
        [-0.7044, -1.6802,  1.4576],
        [-2.3688,  4.2355, -1.3023],
        [-1.9514, -1.0152,  2.9032],
        [-2.4652,  4.2726, -1.1276],
        [-2.5873,  4.6389, -1.2023],
        [-0.2586, -1.8526,  0.8661],
        [-1.2194, -1.0747,  1.9043],
        [-0.5513, -2.5245,  2.0441],
        [-0.9609, -0.2830,  0.5046],
        [-2.4378,  4.5398, -1.5592],
        [-2.6012,  5.0557, -1.7592],
        [-2.7568,  5.0612, -1.4357],
        [-2.3946,  4.5589, -1.6134],
        [-2.5753,  4.9446, -1.6327]

Epoch number 3
 Current loss 0.10979176312685013

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.8528,  5.0751, -1.2934],
        [-2.8153,  5.0489, -1.3716],
        [-2.6066,  4.7241, -1.2852],
        [-0.9970,  1.2654, -0.6842],
        [-2.6919,  4.6280, -1.0563],
        [-2.8434,  5.3202, -1.5988],
        [-1.3284, -0.5167,  1.5734],
        [-2.7534,  5.2335, -1.6256],
        [-0.8033, -2.2589,  2.3524],
        [-2.9082,  5.6612, -1.8444],
        [-2.4854,  4.6227, -1.6156],
        [-1.6683,  2.5263, -0.4702],
        [-0.0580, -2.3415,  0.6953],
        [-2.8506,  5.5804, -1.8831],
        [-2.3540,  4.3844, -1.4183],
        [-2.7857,  5.4219, -1.8121],
        [-2.5255,  0.8431,  1.9825],
        [-0.6050, -1.4299,  1.1533],
        [-2.4047,  4.6649, -1.7295],
        [-1.3232, -0.8307,  1.9633],
        [-0.9092, -2.4953,  2.7216],
        [-2.6544,  4.9210, -1.4385],
        [-0.3715, -2.4961,  1.6379],
        [-2.6469,  4.4384, -0.8536]

Epoch number 3
 Current loss 0.16357256472110748

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.1207, -2.5857,  1.0665],
        [ 0.0157, -2.6640,  1.0200],
        [-1.5068, -1.6099,  2.8984],
        [-2.9861,  5.7754, -1.8393],
        [-2.7165,  5.1911, -1.7128],
        [-2.0932,  3.9240, -1.4945],
        [-1.2284, -1.3616,  2.3663],
        [-2.7415,  4.5493, -0.8587],
        [-2.3585,  4.0944, -1.1770],
        [-2.2496,  4.1530, -1.3242],
        [-2.3841,  4.3930, -1.5280],
        [-2.3147,  3.6595, -0.6670],
        [-2.6798,  5.0777, -1.7456],
        [ 0.0010, -2.0318,  0.1423],
        [-2.9624,  5.7039, -1.7979],
        [-0.0029, -2.6759,  1.0609],
        [-2.3820,  4.3560, -1.4393],
        [-1.8579,  3.4593, -1.4518],
        [-2.5338,  4.7231, -1.5618],
        [ 0.2234, -2.4371, -0.0281],
        [-2.7723,  5.3695, -1.8043],
        [-0.2283, -2.2959,  1.0542],
        [-1.4543, -1.8121,  3.0518],
        [-2.3727,  4.2047, -1.3370]

Epoch number 3
 Current loss 0.16656813025474548

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.7322,  5.1277, -1.6033],
        [-2.7086,  4.5648, -0.8917],
        [ 0.1579, -2.8227,  0.6435],
        [-1.2915, -1.8390,  2.8773],
        [-1.1913,  0.8546, -0.0513],
        [-2.6501,  4.9430, -1.5108],
        [-2.5942,  5.0888, -1.8944],
        [-2.1878,  3.4372, -0.7230],
        [-1.1538,  0.3973,  0.2544],
        [-1.2406, -1.7377,  2.6264],
        [-1.9250,  2.7225, -0.6087],
        [-1.0374, -1.1784,  1.6481],
        [-2.6633,  4.8970, -1.4784],
        [ 0.3831, -2.2280, -0.3272],
        [-2.9726,  5.1050, -1.0832],
        [-1.7279,  1.6320,  0.1249],
        [-2.6228,  4.7009, -1.2695],
        [ 0.2038, -2.4448,  0.2321],
        [-2.1885,  3.5341, -0.9902],
        [-2.9434,  4.8626, -0.8391],
        [-2.6330,  4.9946, -1.5828],
        [-2.7995,  5.1752, -1.5429],
        [-0.3891, -0.9674,  0.0677],
        [-1.8738,  3.2169, -1.3158]

Epoch number 3
 Current loss 0.2073354721069336

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.4000, -2.4724,  1.6384],
        [ 0.3950, -2.7080, -0.0524],
        [-0.3556, -1.5315,  0.6329],
        [-2.3264,  4.1989, -1.5134],
        [-1.9426,  3.5791, -1.3098],
        [-1.2361, -1.4642,  2.4942],
        [-3.0201,  5.8563, -1.9064],
        [-1.1992,  1.2144, -0.4715],
        [-2.3438,  4.3447, -1.5245],
        [ 0.2901, -1.9590, -0.4684],
        [-2.5149,  4.8764, -1.7907],
        [-1.3280, -1.6648,  2.8412],
        [-2.8539,  5.5250, -1.8457],
        [-2.7409,  5.2506, -1.7112],
        [-2.7656,  5.3151, -1.7660],
        [-2.8525,  5.4500, -1.7491],
        [-2.7859,  5.3941, -1.9282],
        [-0.0734, -2.5553,  1.2031],
        [-2.4016,  4.4901, -1.5868],
        [-1.6356, -1.0569,  2.6327],
        [-2.5249,  4.4529, -1.2307],
        [-2.7535,  5.2143, -1.6615],
        [-2.2940,  4.4334, -1.7713],
        [-2.2161,  4.1293, -1.7957],

Epoch number 3
 Current loss 0.15970538556575775

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.7986,  1.7051,  0.3551],
        [-2.2674,  4.0771, -1.5009],
        [-2.5488,  4.5531, -1.4035],
        [ 0.3063, -2.1531,  0.0346],
        [-1.7133,  2.2810, -0.5930],
        [-0.7965, -2.3407,  2.3750],
        [-2.1617,  0.9851,  1.5388],
        [-2.4020,  4.2380, -1.5205],
        [-2.0154,  2.5645, -0.3523],
        [-3.0004,  5.8399, -1.9588],
        [-1.0546, -1.6959,  2.3241],
        [ 0.4633, -2.5719, -0.1176],
        [-2.2141,  3.5656, -1.1068],
        [-1.6614,  2.0261, -0.6483],
        [-2.9409,  4.9695, -1.0572],
        [-1.9798,  3.7340, -1.9021],
        [-3.0221,  5.7242, -1.7728],
        [-1.8524,  3.0139, -1.0308],
        [-2.8599,  5.5450, -1.9194],
        [-2.9434,  5.6759, -1.8721],
        [-2.0168,  3.4125, -1.4650],
        [-1.6580, -0.4178,  2.1618],
        [-1.0064, -0.7629,  0.9360],
        [-2.8612,  5.0884, -1.2962]

Epoch number 3
 Current loss 0.1790577918291092

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.3808, -2.2639,  1.4265],
        [-3.0616,  5.8039, -1.8056],
        [-2.9070,  5.5983, -1.9876],
        [-2.9627,  5.5934, -1.7914],
        [-1.0647, -0.8758,  1.4211],
        [-2.8659,  5.1188, -1.4585],
        [-1.4860, -1.8777,  3.1305],
        [-2.9365,  5.6897, -1.9639],
        [-2.4502,  3.5854, -0.5386],
        [-2.0210,  3.0771, -0.7776],
        [ 0.5819, -2.6502, -0.4043],
        [-2.8107,  5.3037, -1.7953],
        [-2.8944,  5.2389, -1.4948],
        [-2.4919,  4.7620, -1.7798],
        [-1.1326,  1.8398, -1.6345],
        [-2.6567,  4.4236, -1.0351],
        [-2.3938,  4.0822, -0.9067],
        [-2.4130,  4.2269, -1.3480],
        [ 0.0247, -2.9948,  1.3124],
        [-0.5054, -2.5011,  2.0224],
        [-2.7818,  4.5916, -1.0435],
        [-2.8759,  5.3161, -1.6137],
        [-1.4576, -1.7591,  2.9909],
        [-2.1242,  2.5541, -0.0030],

Epoch number 3
 Current loss 0.20643194019794464

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9315, -2.0495,  2.3522],
        [-2.2830,  4.3465, -1.8885],
        [-0.6364, -1.4385,  0.9836],
        [-2.4903,  4.4207, -1.4353],
        [-2.4294,  4.1296, -1.1066],
        [-2.9085,  5.3331, -1.6005],
        [ 0.1171, -2.4425,  0.4363],
        [-2.6827,  4.9917, -1.5869],
        [-3.2399,  5.5971, -1.2390],
        [-1.1524, -2.2610,  3.0060],
        [-2.8733,  5.4893, -1.8236],
        [-1.4579, -1.9374,  3.1760],
        [-2.6154,  4.7939, -1.5438],
        [-0.6678, -2.7157,  2.5849],
        [-0.8519, -1.7921,  2.0199],
        [-2.5447,  4.4401, -1.3433],
        [-3.0138,  5.7323, -1.9338],
        [-1.0480,  0.8160, -0.7051],
        [ 0.0777, -2.8593,  0.8866],
        [-1.4391, -2.0600,  3.2640],
        [-2.0785,  3.5723, -1.0856],
        [-2.0323,  3.7935, -1.4756],
        [-2.8767,  5.3928, -1.7218],
        [-2.8555,  5.4346, -1.8400]

Epoch number 3
 Current loss 0.2134752869606018

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.5039,  4.5391, -1.5816],
        [-0.4452, -2.7270,  2.0932],
        [-3.2009,  4.1909, -0.0438],
        [-2.7930,  5.0176, -1.4488],
        [-2.8782,  5.3828, -1.7327],
        [-2.7720,  5.2573, -1.7513],
        [-0.7530, -2.0537,  1.9815],
        [-1.6576, -0.5754,  2.2277],
        [-1.9200, -0.0728,  2.1080],
        [-1.6973, -1.4739,  3.0180],
        [-3.1377,  5.9969, -1.9668],
        [-2.1574,  3.2205, -0.7864],
        [-2.5204,  4.3222, -1.1678],
        [-2.7544,  3.5983, -0.0124],
        [-1.2542,  2.0020, -1.6239],
        [-2.8928,  5.6298, -2.0422],
        [-2.5870,  4.3336, -0.9285],
        [-3.0501,  5.8708, -2.0771],
        [ 0.3120, -2.4419,  0.0676],
        [-2.5446,  4.2416, -0.9809],
        [-2.9805,  5.4175, -1.5829],
        [-2.4122,  4.1547, -1.4580],
        [-0.5790, -2.8289,  2.2824],
        [-2.6626,  4.8695, -1.5324],

Epoch number 3
 Current loss 0.12099731713533401

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.3973,  4.3798, -1.6935],
        [-2.0369,  3.2202, -1.0538],
        [-3.1141,  5.9066, -1.8765],
        [-3.4235,  4.6526, -0.1877],
        [ 0.5354, -2.4343, -0.3419],
        [-2.6301,  3.6871, -0.2895],
        [-2.9287,  4.9565, -0.9878],
        [-0.7773, -1.8332,  1.7949],
        [-1.2013,  1.3243, -0.8401],
        [-2.0843,  0.1846,  2.0618],
        [-0.6445,  0.4935, -0.9621],
        [-3.2509,  5.7175, -1.3599],
        [-2.3233,  0.8259,  1.7634],
        [ 0.3059, -2.8310,  0.3542],
        [-0.1620, -1.6601,  0.3533],
        [-3.0417,  5.7786, -1.8698],
        [-1.5693,  0.3182,  1.1498],
        [-1.3489, -1.2720,  2.4107],
        [-2.6920,  5.0380, -1.7790],
        [-1.3825, -0.1538,  1.3932],
        [-1.5177,  0.1837,  1.2699],
        [-2.4850,  4.6204, -1.6968],
        [-2.7703,  5.1121, -1.6183],
        [-2.8057,  5.1004, -1.5092]

Epoch number 3
 Current loss 0.10044784098863602

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.1565,  6.0209, -2.0237],
        [-0.1739, -2.6000,  1.0988],
        [-0.3662, -2.0872,  1.2456],
        [-3.1780,  5.8357, -1.6544],
        [-2.8344,  5.2712, -1.7903],
        [-2.7346,  4.7919, -1.3525],
        [ 0.0005, -2.5643,  0.7747],
        [-1.2999, -2.1065,  3.0621],
        [-1.8670,  2.9839, -1.0037],
        [-0.2880, -2.7393,  1.5151],
        [-0.9511, -2.3731,  2.7727],
        [-2.8546,  5.4496, -1.8990],
        [-2.5676,  4.8741, -1.7751],
        [-2.6942,  4.7709, -1.4104],
        [ 0.4876, -2.7635, -0.0342],
        [-3.0147,  5.7286, -1.9264],
        [-1.9366, -1.3788,  3.2514],
        [-2.8509,  5.3832, -1.7784],
        [-2.6104,  4.1313, -0.8844],
        [ 0.3835, -2.5411, -0.2482],
        [ 0.1726, -3.0072,  0.7292],
        [ 0.2847, -3.1553,  0.6703],
        [-2.9068,  4.6644, -0.7716],
        [-3.1944,  3.9691,  0.1131]

Epoch number 3
 Current loss 0.13596639037132263

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.8388, -0.9538,  2.7499],
        [-1.0322, -2.5952,  3.1328],
        [-2.3114,  3.0131, -0.0360],
        [-2.7186,  5.1033, -1.6975],
        [-2.6865,  4.6066, -1.1567],
        [-2.5549,  4.6846, -1.5811],
        [-1.1428, -2.3227,  3.0448],
        [-0.6063, -2.3654,  2.0590],
        [ 0.5550, -2.9829, -0.0340],
        [-2.7825,  4.1947, -0.5506],
        [ 0.5022, -3.0561,  0.2071],
        [-3.3886,  6.1662, -1.6019],
        [-2.5219,  4.4609, -1.3928],
        [-2.6649,  4.8820, -1.4840],
        [-0.1345, -2.3964,  0.9214],
        [-0.8722, -2.3478,  2.5551],
        [-2.9430,  5.3252, -1.5766],
        [-0.2023, -1.6397,  0.0794],
        [ 0.4808, -2.6869, -0.1468],
        [-2.5643,  3.9800, -0.6777],
        [-1.6004, -1.3240,  2.8241],
        [-2.2674,  3.7845, -1.3463],
        [-2.4920,  4.4007, -1.3458],
        [-2.9464,  5.3235, -1.5146]

Epoch number 3
 Current loss 0.15865206718444824

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6525,  4.7573, -1.3499],
        [-2.6171,  4.5372, -1.3686],
        [-3.1143,  5.7488, -1.8301],
        [-1.2768, -2.6269,  3.4666],
        [-2.5729,  4.5920, -1.5898],
        [-1.4406, -2.1878,  3.3435],
        [-3.2689,  5.2588, -0.9301],
        [ 0.1743, -2.9216,  0.7886],
        [-1.9051,  0.0316,  2.0571],
        [-3.0981,  5.9695, -2.0753],
        [-3.1794,  6.1960, -2.1534],
        [-1.8319, -1.8057,  3.5376],
        [-3.1566,  5.8998, -1.8676],
        [-1.4258, -2.7166,  3.7403],
        [-2.6046,  4.5622, -1.4963],
        [-2.1010,  1.3799,  0.9642],
        [-1.6984, -1.9762,  3.5397],
        [-2.9380,  2.2908,  1.2811],
        [-2.8109,  5.1725, -1.7655],
        [ 0.2597, -2.6081,  0.3178],
        [-3.0158,  5.4709, -1.7210],
        [-1.5617, -1.8101,  3.1615],
        [-2.8113,  5.2590, -1.8029],
        [-2.7960,  5.2460, -1.7954]

Epoch number 3
 Current loss 0.16572222113609314

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.6194, -1.7263,  3.2570],
        [ 0.0585, -2.8950,  1.2429],
        [-2.8059,  5.2686, -1.7912],
        [-0.7967, -2.8468,  2.8405],
        [-1.1330, -1.2871,  1.9646],
        [ 0.1514, -1.7680, -0.3393],
        [-3.0218,  5.3091, -1.4455],
        [-1.2773, -0.3596,  1.1400],
        [-2.6248,  4.8139, -1.5966],
        [-1.4665,  1.8153, -0.7943],
        [-2.8595,  5.0944, -1.4239],
        [-2.3056,  4.0627, -1.4150],
        [-2.2775,  3.6318, -1.1138],
        [-2.7169,  4.7920, -1.2684],
        [-0.3706, -1.5330,  0.7931],
        [-2.8444,  3.9163, -0.2314],
        [-1.5050,  0.1913,  1.2470],
        [-2.8078,  5.0013, -1.4504],
        [-3.3149,  5.4587, -1.0431],
        [-3.0972,  5.7543, -1.8307],
        [-1.0263, -1.6038,  2.0220],
        [-0.0692, -3.1114,  1.7017],
        [-2.5626,  4.5582, -1.4348],
        [-1.3657, -2.5017,  3.4759]

Epoch number 3
 Current loss 0.1135440468788147

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.3511, -0.5302, -0.7037],
        [-2.6195,  4.6812, -1.5026],
        [-2.6587,  4.3929, -1.0555],
        [-2.0057, -1.0249,  3.0286],
        [-2.3947,  4.3584, -1.6963],
        [-2.0799,  2.9972, -0.6844],
        [-2.7305,  5.0813, -1.6877],
        [-3.1859,  6.0870, -2.0735],
        [-0.5943, -1.4147,  0.9103],
        [-1.4651, -2.6111,  3.7609],
        [-2.0271,  0.9423,  1.3427],
        [-2.5515,  4.4900, -1.1913],
        [-1.2277, -2.1830,  3.1364],
        [-2.6254,  4.0040, -0.6656],
        [-1.5787,  1.5556, -0.0408],
        [-2.3903,  4.2403, -1.4759],
        [-1.9827,  0.0564,  2.1194],
        [-2.6165,  4.6545, -1.4783],
        [-3.1154,  5.6823, -1.6967],
        [-1.4299,  0.2074,  0.9267],
        [-2.8527,  3.5798,  0.0874],
        [-2.7132,  4.9057, -1.6142],
        [-2.4757,  4.5534, -1.8455],
        [-2.3472,  3.8334, -1.1403],

Epoch number 3
 Current loss 0.23549307882785797

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9832, -1.6285,  1.9884],
        [-1.6736, -1.9211,  3.4984],
        [-2.7761,  5.0812, -1.6698],
        [-3.1571,  4.7789, -0.5944],
        [-2.6320,  3.4030,  0.0152],
        [-2.5004,  4.4130, -1.4857],
        [-2.0865,  3.3754, -1.2673],
        [-2.9054,  5.5464, -2.0746],
        [-1.6029, -1.8544,  3.3140],
        [-1.7307,  2.7072, -1.2643],
        [-3.2402,  5.8608, -1.7000],
        [-0.5173, -2.7921,  2.3875],
        [-2.4541,  3.4361, -0.5017],
        [-1.4674, -1.2449,  2.5810],
        [-2.0701,  0.3559,  1.9085],
        [-2.5390,  2.9259,  0.1633],
        [-2.0207,  2.6356, -0.6147],
        [-0.0516, -0.9099, -0.7850],
        [-2.7861,  4.4938, -0.9622],
        [-2.7813,  5.2446, -1.9478],
        [-0.5376, -1.8562,  1.1537],
        [-0.9546, -2.5171,  2.7658],
        [-2.3669,  3.5683, -0.7295],
        [-2.5580,  4.1612, -0.8788]

Epoch number 3
 Current loss 0.1556725800037384

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.3962, -2.8455,  3.8490],
        [-0.3124, -0.1265, -1.2591],
        [-1.5643, -1.8272,  3.2500],
        [-2.4575,  4.4547, -1.6896],
        [-0.2409, -2.3875,  1.4168],
        [-3.2400,  4.8945, -0.6407],
        [-2.3486,  4.1227, -1.6608],
        [-3.0251,  5.9410, -2.2088],
        [-2.4889,  3.6856, -0.5809],
        [-1.7846,  0.4684,  1.3469],
        [-1.0596, -1.1741,  1.8412],
        [-3.2052,  6.1306, -2.0429],
        [-3.1691,  6.1515, -2.1288],
        [-2.9428,  5.3434, -1.7584],
        [-1.8548,  3.1701, -1.3537],
        [-1.1256, -2.8454,  3.5362],
        [-1.1052,  2.0305, -2.0578],
        [-2.6174,  5.0420, -1.9262],
        [-2.7184,  4.8355, -1.3445],
        [-0.6078, -2.4207,  2.1492],
        [-2.5545,  3.1312,  0.1658],
        [-0.1251, -2.9075,  1.6301],
        [-2.8962,  5.4673, -1.9415],
        [-1.4006, -2.0165,  3.2253],

Epoch number 3
 Current loss 0.15563642978668213

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.4401, -2.6624,  1.9113],
        [-2.9981,  5.6968, -1.9439],
        [-1.6708, -0.7376,  2.3895],
        [-2.9378,  5.5136, -1.9171],
        [-0.1973, -2.3581,  1.0889],
        [-2.9326,  5.2740, -1.5756],
        [ 0.0833, -2.1402,  0.3734],
        [-2.1026,  3.1743, -1.0300],
        [-2.2025,  4.0197, -1.8365],
        [-1.4394, -2.1050,  3.3762],
        [-0.6872, -2.7051,  2.7219],
        [-0.0345, -2.1141,  0.5723],
        [-2.2658,  2.7300,  0.1406],
        [-3.2628,  6.1850, -2.0404],
        [-2.8552,  5.4989, -2.1190],
        [-1.3936, -0.7231,  1.8282],
        [-1.0458, -2.0867,  2.8215],
        [-2.5280,  4.5018, -1.4610],
        [-1.7946,  1.8021, -0.0098],
        [-2.5708,  4.7918, -1.9187],
        [ 0.1963, -2.5449,  0.6447],
        [-1.3442,  0.6033,  0.2924],
        [-2.6748,  5.1776, -2.0925],
        [-3.1649,  6.0932, -2.0928]

Epoch number 3
 Current loss 0.16762879490852356

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.8217,  5.5508, -2.1759],
        [-2.7480,  5.2604, -1.9953],
        [ 0.1443, -2.8306,  0.7553],
        [-2.8212,  5.3300, -1.8775],
        [-2.5352,  4.5889, -1.6757],
        [-2.6172,  4.4059, -1.1409],
        [-0.7766,  0.2995, -0.9235],
        [-3.1908,  6.2002, -2.1886],
        [-2.3026,  4.1247, -1.6991],
        [-1.1507, -2.7884,  3.5660],
        [-3.0972,  5.9843, -2.1207],
        [-2.9010,  5.5412, -2.0453],
        [ 0.5656, -2.9768,  0.0463],
        [-1.9847,  0.1826,  1.9995],
        [-3.0459,  6.0231, -2.2715],
        [-3.1470,  6.0003, -2.0688],
        [ 0.5626, -2.6464, -0.2986],
        [-2.3559,  4.4284, -1.9737],
        [-2.5873,  4.7498, -1.5888],
        [-1.4459, -2.4368,  3.5945],
        [-1.1654, -2.6296,  3.4443],
        [-2.4347,  3.4598, -0.6435],
        [-2.3867,  4.3796, -1.6719],
        [-2.4293,  4.5100, -1.8182]

Epoch number 3
 Current loss 0.20340949296951294

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.7822,  3.7608, -0.1618],
        [-2.1422,  3.7516, -1.5320],
        [-2.9918,  5.9366, -2.4126],
        [-3.1186,  6.0677, -2.2289],
        [-2.7978,  5.1146, -1.7452],
        [ 0.2365, -2.4496,  0.3126],
        [-2.4963,  4.4542, -1.3218],
        [-3.1750,  6.3025, -2.3905],
        [-2.8299,  4.7370, -0.9800],
        [-0.9679, -2.5094,  2.9698],
        [-3.1499,  6.1346, -2.2745],
        [-2.3406,  4.1702, -1.6376],
        [ 0.0908, -2.9759,  1.2062],
        [-2.9815,  5.7748, -2.1370],
        [-2.9461,  5.8135, -2.2445],
        [ 0.1914, -2.8868,  0.8274],
        [-1.5929, -0.2638,  1.7736],
        [ 0.3345, -3.1626,  0.8655],
        [-3.1457,  6.1471, -2.3209],
        [-2.5244,  0.2955,  2.5217],
        [-3.0179,  5.6878, -1.8538],
        [-2.8629,  5.4376, -1.9517],
        [-2.7614,  4.8983, -1.4532],
        [-1.0303, -2.6889,  3.2265]

Epoch number 3
 Current loss 0.19912712275981903

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.0292, -2.3823,  2.9902],
        [-1.9884,  2.3593, -0.2495],
        [-2.0353, -0.9025,  2.9469],
        [-3.2779,  6.1780, -2.0406],
        [-2.5337,  4.9219, -2.0041],
        [-1.0079, -2.2676,  2.9084],
        [-1.2648, -2.3387,  3.2779],
        [-2.7833,  5.2624, -1.9002],
        [-0.4961, -2.9574,  2.3411],
        [-1.1501, -2.6037,  3.3353],
        [-2.9242,  5.3969, -1.7638],
        [-2.7430,  5.3989, -2.1419],
        [-2.8032,  5.1170, -1.5973],
        [-1.9168, -1.9349,  3.7387],
        [-1.4660, -0.3529,  1.6571],
        [-3.0616,  5.9796, -2.1784],
        [-2.2847,  4.1618, -1.5632],
        [-3.1228,  6.1052, -2.2882],
        [-3.0038,  4.5756, -0.6285],
        [-2.2388,  3.9629, -1.6032],
        [-3.0540,  5.9452, -2.1280],
        [-0.1718, -2.3016,  0.7878],
        [-2.7803,  5.3142, -1.9500],
        [-1.6010, -1.1635,  2.7162]

Epoch number 3
 Current loss 0.1394985020160675

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.3237, -2.5451, -0.0699],
        [-2.2343,  1.4894,  1.1688],
        [-2.5029,  4.7421, -1.8806],
        [-1.6163, -0.0124,  1.5737],
        [-2.8017,  4.3082, -0.8054],
        [-2.7226,  4.9923, -1.6448],
        [-2.3903,  3.0123,  0.0826],
        [-2.8462,  5.0124, -1.3841],
        [-0.8676, -2.9766,  3.2756],
        [-1.3290, -1.5560,  2.6303],
        [-3.1925,  6.3125, -2.3684],
        [-2.7152,  4.1590, -0.6871],
        [ 0.2565, -2.8740,  0.4773],
        [-3.2039,  6.1657, -2.1101],
        [ 0.3984, -3.2139,  0.5556],
        [-2.0096,  3.0509, -0.9615],
        [-1.1569, -2.7033,  3.5581],
        [-2.5317,  4.2286, -1.1349],
        [-1.2665, -2.9790,  3.8704],
        [-2.7158,  4.9964, -1.6417],
        [ 0.1623, -2.6670,  0.4768],
        [-3.2956,  5.1805, -0.9017],
        [-0.8647, -2.2021,  2.4970],
        [-2.1125,  3.9896, -2.0383],

Epoch number 3
 Current loss 0.13040503859519958

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.4404, -2.9975,  0.1775],
        [-2.8595,  5.4750, -1.9344],
        [-1.3313, -1.1896,  2.2240],
        [-2.9908,  5.1039, -1.2150],
        [-1.1194, -2.9197,  3.6598],
        [-2.5904,  3.7411, -0.6161],
        [-2.3915,  3.7947, -1.0773],
        [-0.0283, -3.0780,  1.5904],
        [-2.2320,  4.0364, -1.5902],
        [-2.6924,  5.0618, -1.7660],
        [-2.8146,  5.0538, -1.4494],
        [-0.7890,  0.8216, -1.2723],
        [-0.6050, -1.8310,  1.5159],
        [-2.7057,  4.4358, -1.0498],
        [-2.8885,  5.1261, -1.4841],
        [-2.9104,  5.4644, -1.8604],
        [-2.9708,  5.7338, -2.1563],
        [-2.4618,  4.7374, -1.9248],
        [-2.6652,  5.0587, -1.9303],
        [ 0.0476, -3.0871,  1.4613],
        [-2.3349,  3.5691, -0.8214],
        [-3.1890,  6.2594, -2.2412],
        [-1.0454, -2.8209,  3.4198],
        [-0.1636, -2.8097,  1.6092]

Epoch number 3
 Current loss 0.13082623481750488

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.8575, -1.3203,  3.2023],
        [-2.2706,  3.9299, -1.3230],
        [-1.0729, -0.0686,  0.4510],
        [-2.9512,  4.9249, -1.1696],
        [-2.7487,  5.1691, -1.8360],
        [-2.0632,  3.5473, -1.6354],
        [-0.6753, -2.8650,  2.8174],
        [-2.8662,  5.5185, -2.1733],
        [-2.8912,  5.6647, -2.2516],
        [-2.0701, -0.2185,  2.4605],
        [-1.2421, -2.3509,  3.4178],
        [-0.4941, -1.6037,  1.1276],
        [-1.6864, -2.2350,  3.7295],
        [ 0.4919, -3.0060,  0.1015],
        [-2.9209,  4.9501, -1.3003],
        [-2.2156, -1.4562,  3.6681],
        [-2.4579,  4.7981, -2.1230],
        [-2.7419,  4.9618, -1.5676],
        [-2.6159,  4.8567, -1.6913],
        [-2.0222,  3.6194, -1.7185],
        [-1.4936, -1.5193,  2.8142],
        [-2.9030,  5.2910, -1.7110],
        [-2.8555,  5.6381, -2.3274],
        [-2.0705,  3.9356, -2.0163]

Epoch number 3
 Current loss 0.11465397477149963

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.3306, -3.2425,  0.7313],
        [-0.2107, -2.5177,  1.1744],
        [-2.8305,  5.7157, -2.4689],
        [-0.1204, -0.4827, -1.3111],
        [-1.9782,  3.3663, -1.4731],
        [-0.0108, -2.9723,  1.3740],
        [-2.2927,  4.5506, -2.2564],
        [-2.6686,  5.2249, -2.1622],
        [-1.1046,  0.2725,  0.2899],
        [-1.1673,  1.1091, -0.8332],
        [-0.5596, -1.5493,  0.9684],
        [-2.9955,  2.8049,  0.8376],
        [-2.3382,  4.2197, -1.6005],
        [-3.2263,  6.4359, -2.4502],
        [-1.6228, -2.3744,  3.7614],
        [-1.9512,  3.5985, -1.5242],
        [-2.8628,  5.7321, -2.2819],
        [-0.7336, -2.7692,  2.6550],
        [-2.7061,  5.1142, -1.9616],
        [-1.7266,  3.3233, -2.1993],
        [-2.7620,  4.5826, -1.2527],
        [-2.3446,  2.2860,  0.4638],
        [-0.5726, -2.7884,  2.5171],
        [-2.1996,  3.3255, -0.7188]

Epoch number 3
 Current loss 0.12242439389228821

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.9514,  5.9615, -2.4194],
        [-2.3771,  0.2389,  2.3541],
        [-2.7530,  5.1105, -1.7993],
        [-2.2508, -0.2808,  2.5976],
        [-3.0586,  6.0475, -2.3565],
        [-0.9742, -1.8005,  2.1073],
        [-2.2402, -0.4739,  2.8024],
        [ 0.0987, -2.6808,  0.9738],
        [-2.8769,  5.2188, -1.6970],
        [-0.8992, -2.5136,  2.7179],
        [ 0.5748, -3.0593,  0.0766],
        [ 0.4474, -2.9522,  0.2442],
        [-2.8596,  5.6340, -2.3217],
        [-2.9523,  5.8059, -2.1988],
        [-2.6169,  4.6598, -1.5195],
        [ 0.0794, -2.6592,  1.0578],
        [-1.8219,  0.7808,  1.0596],
        [-3.1785,  5.3004, -1.1150],
        [-2.5608,  2.8900,  0.3455],
        [-0.9487, -2.6888,  3.2153],
        [-1.2239,  1.4433, -1.2241],
        [-2.7829,  5.2933, -2.0443],
        [-3.0440,  6.2001, -2.5354],
        [-1.7376,  2.5226, -1.1732]

Epoch number 3
 Current loss 0.1299801468849182

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.7552, -2.5468,  2.7412],
        [-2.8338,  5.6475, -2.2955],
        [-3.0652,  6.1782, -2.4445],
        [-3.1680,  6.4067, -2.5464],
        [-0.2522, -2.6091,  1.5354],
        [-3.3039,  6.4335, -2.2766],
        [-2.5689,  4.5051, -1.4162],
        [-1.9851, -1.1395,  3.1685],
        [-2.9180,  4.6110, -0.8441],
        [-3.1983,  4.1947, -0.2111],
        [-3.0784,  6.2360, -2.5269],
        [-2.0677, -1.0041,  2.9663],
        [-2.7969,  4.0696, -0.5000],
        [-3.1167,  6.2321, -2.4740],
        [ 0.4616, -2.7196, -0.0953],
        [-1.6555, -2.1095,  3.5507],
        [-2.8833,  5.8289, -2.4259],
        [-1.8702,  3.7280, -2.3289],
        [-1.7420, -1.1348,  2.7863],
        [-3.1684,  6.3876, -2.4864],
        [-2.2683, -0.7462,  3.0425],
        [-2.3021,  3.9972, -1.6134],
        [-0.0584, -2.8460,  1.1612],
        [-2.6943,  5.1721, -1.9251],

Epoch number 3
 Current loss 0.14641110599040985

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.0343,  1.0623,  1.2214],
        [-2.5146,  4.6644, -1.8818],
        [-0.5347, -2.5793,  2.1302],
        [ 0.1690, -1.7133, -0.5926],
        [-2.9683,  5.5854, -1.9083],
        [-2.9935,  5.8620, -2.2784],
        [-1.3033, -0.9596,  1.5599],
        [-1.6817,  0.0632,  1.5977],
        [-2.8792,  5.5721, -2.1293],
        [-3.2188,  5.8974, -1.8435],
        [-1.2326, -2.7644,  3.6409],
        [-3.1440,  6.2822, -2.4765],
        [-2.6833,  4.9379, -1.8716],
        [-2.8667,  2.2122,  1.1525],
        [-2.4852,  3.8352, -1.0610],
        [-1.8563,  3.5900, -2.2641],
        [-3.2548,  6.5690, -2.5749],
        [-1.5476, -1.9283,  3.2889],
        [-3.1434,  6.4668, -2.6554],
        [-3.1098,  6.3623, -2.6447],
        [-2.7262,  5.0011, -1.5904],
        [-2.9261,  5.8159, -2.2934],
        [-3.1483,  6.1290, -2.2049],
        [-0.7288, -0.4280, -0.0941]

Epoch number 3
 Current loss 0.20297463238239288

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.7638,  5.3959, -2.1491],
        [-3.2151,  6.2038, -2.2242],
        [-3.2411,  6.1633, -2.0997],
        [-3.0680,  5.9494, -2.1782],
        [-3.1367,  5.3326, -1.3327],
        [ 0.3462, -3.3559,  0.8043],
        [-1.8400,  1.7671,  0.0351],
        [-3.1444,  6.1562, -2.3240],
        [-1.8429,  1.2807,  0.5186],
        [-2.0195,  2.8137, -0.8508],
        [-0.4489, -1.6599,  0.8642],
        [-3.0267,  5.9523, -2.2834],
        [-3.1003,  5.8355, -1.9713],
        [-2.9787,  5.8071, -2.1543],
        [-3.0255,  6.0861, -2.4503],
        [-2.6977,  5.1906, -2.0904],
        [-2.9103,  5.7923, -2.2747],
        [-1.9044,  1.4609,  0.2908],
        [-0.5001, -2.2574,  1.9152],
        [-3.1373,  6.2469, -2.4749],
        [-3.1712,  6.3018, -2.4330],
        [-2.6598,  4.6893, -1.4286],
        [-2.0034,  2.7895, -1.0131],
        [-2.8417,  5.6157, -2.4128]

Epoch number 3
 Current loss 0.1379421204328537

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.5056, -1.3679,  2.6110],
        [-2.6604,  5.1490, -2.0135],
        [-2.7594,  5.0698, -1.7377],
        [ 0.3311, -3.2474,  0.6910],
        [ 0.1469, -3.1723,  1.1256],
        [ 0.5561, -3.5592,  0.4128],
        [-3.0315,  6.0057, -2.3618],
        [-2.9940,  5.7665, -2.2312],
        [-2.6747,  5.0846, -2.1010],
        [-1.4644, -1.9625,  3.2041],
        [-2.9276,  5.7287, -2.3262],
        [-2.8996,  5.7447, -2.3285],
        [-2.9946,  5.9469, -2.3566],
        [-1.0662, -1.7980,  2.3903],
        [-2.5109,  4.7362, -2.0028],
        [-2.9925,  6.1108, -2.5422],
        [-1.3853, -1.1061,  2.0695],
        [-2.6523,  5.2152, -2.2788],
        [-1.7076, -1.2426,  2.8523],
        [-2.5407,  4.9045, -2.2391],
        [-2.5046,  4.7049, -2.0308],
        [-2.1060,  3.3693, -1.3088],
        [-2.1934,  3.5521, -1.3760],
        [-2.0411,  3.2304, -0.7237],

Epoch number 3
 Current loss 0.14291507005691528

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.7258,  5.2244, -2.1155],
        [-3.2778,  6.4924, -2.5182],
        [-2.7712,  5.2999, -2.0164],
        [-1.9284,  3.2660, -1.7637],
        [-3.0655,  6.0845, -2.4212],
        [-2.8381,  5.5286, -2.3844],
        [-1.7756,  2.7826, -1.4823],
        [-1.4956,  1.7750, -1.0491],
        [-0.1307, -2.2989,  0.6944],
        [-2.3633,  4.6669, -2.4844],
        [-0.8894, -2.3389,  2.7738],
        [ 0.6077, -3.5785,  0.3154],
        [-2.1929,  3.9502, -1.8551],
        [-2.5592,  4.9246, -2.2666],
        [-2.4031,  4.6988, -2.4556],
        [ 0.3933, -3.1392,  0.8069],
        [-3.1999,  6.3360, -2.4226],
        [-2.3809,  3.9325, -1.2780],
        [-2.7655,  5.2462, -2.0533],
        [-0.7805, -0.4828, -0.0951],
        [ 0.4236, -1.6674, -1.1941],
        [-2.7865,  5.3268, -2.0953],
        [-0.7667, -2.0429,  1.7276],
        [-2.3541,  4.0804, -1.6814]

Epoch number 3
 Current loss 0.12060971558094025

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.3130, -2.8661,  3.7379],
        [-2.9046,  5.6389, -2.3051],
        [-2.9479,  5.6746, -2.1519],
        [-1.6035, -1.7384,  3.1415],
        [-2.8741,  5.6621, -2.2731],
        [-3.3112,  6.2390, -2.1224],
        [-2.6335,  4.5354, -1.5163],
        [-3.2910,  6.6197, -2.6458],
        [ 0.0723, -2.9789,  1.1011],
        [-2.9271,  5.7166, -2.4225],
        [ 0.2076, -3.0861,  1.0934],
        [-3.1191,  6.3664, -2.7335],
        [-0.4513, -2.9836,  2.3273],
        [-2.3276,  4.6126, -2.3631],
        [-2.5627,  4.8702, -2.0801],
        [-3.1081,  6.1715, -2.5204],
        [-2.6881,  3.8285, -0.6313],
        [-0.0703, -3.2380,  1.7049],
        [-1.2311, -2.0580,  2.8666],
        [-1.5126, -0.6974,  1.9446],
        [-3.3089,  6.7055, -2.6913],
        [-2.7624,  5.1204, -2.0539],
        [-1.6600,  2.9573, -1.3020],
        [-2.6002,  4.9015, -2.0703]

Epoch number 3
 Current loss 0.14728286862373352

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.7357,  5.0986, -1.9973],
        [-2.9917,  3.0964,  0.5403],
        [-2.9151,  4.5185, -1.0731],
        [-2.7836,  5.4842, -2.5494],
        [-1.8123,  2.5864, -0.6819],
        [-2.7558,  4.7072, -1.4094],
        [-0.9132, -0.9732,  0.9190],
        [-2.7425,  5.1194, -2.0087],
        [-2.9153,  5.5744, -2.2950],
        [-1.7077,  0.2974,  1.4819],
        [-0.0367, -0.6806, -1.4855],
        [-2.9861,  5.9201, -2.6283],
        [-2.8055,  4.2573, -0.8602],
        [ 0.4350, -3.6147,  0.6858],
        [-0.4855, -3.4382,  2.8685],
        [-2.7859,  5.1177, -1.8307],
        [-1.0828, -2.4715,  3.0911],
        [-3.1935,  4.2080, -0.1935],
        [-3.1717,  6.3889, -2.6189],
        [-0.5297, -0.6105, -0.2897],
        [-1.7492,  1.9956, -0.5398],
        [-2.1606,  4.0715, -2.1409],
        [-1.6447, -1.4534,  2.9704],
        [-2.6594,  4.9442, -2.0828]

Epoch number 3
 Current loss 0.19167038798332214

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6257,  4.3881, -1.3549],
        [-0.0097, -3.4076,  1.6780],
        [-3.0482,  6.0228, -2.4667],
        [ 0.7577, -3.3290, -0.0163],
        [-2.0522, -0.2176,  2.2664],
        [-3.4333,  3.9732,  0.2682],
        [-2.6319,  4.9222, -2.1023],
        [-2.6771,  5.0406, -2.0271],
        [ 0.7179, -3.5198,  0.0119],
        [-2.6394,  4.3252, -1.3951],
        [-2.8305,  5.2254, -1.7576],
        [ 0.2946, -3.1223,  0.6929],
        [-2.5025,  4.3369, -1.7523],
        [-0.9226, -2.6680,  3.0107],
        [-1.8059,  1.0932,  0.5929],
        [-2.5507,  3.5565, -0.6341],
        [-2.7043,  4.8344, -1.7155],
        [-2.4207,  3.4065, -0.8389],
        [-2.6614,  5.1205, -2.4858],
        [ 0.5132, -3.3422,  0.3173],
        [-2.8685,  5.5621, -2.5086],
        [-3.0415,  6.1133, -2.6152],
        [-3.0640,  6.0226, -2.4777],
        [-1.7293, -1.8813,  3.4276]

Epoch number 3
 Current loss 0.19052277505397797

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.9605,  5.3429, -1.6549],
        [-2.9343,  5.6709, -2.2823],
        [-2.5908,  4.9885, -2.2967],
        [-2.8456,  5.6208, -2.4082],
        [-1.2123, -2.5255,  3.3751],
        [-2.3737,  3.0956, -0.3884],
        [-2.6575,  4.9654, -1.9729],
        [-2.5601,  3.1547, -0.0264],
        [-2.7192,  5.1721, -2.0366],
        [-2.1040,  2.2056, -0.1277],
        [-1.7874, -0.1635,  1.8023],
        [-0.8749, -2.0933,  2.1684],
        [-3.0983,  6.0029, -2.2674],
        [-2.8850,  5.5325, -2.0820],
        [-2.8567,  5.2428, -1.9414],
        [-2.6235,  4.5288, -1.5472],
        [-3.1709,  4.5181, -0.6174],
        [-0.6421, -2.5174,  2.1834],
        [-1.2368, -2.3301,  3.2030],
        [-1.7864,  2.5433, -1.1911],
        [-1.5487, -2.5618,  3.8635],
        [-2.3899,  3.0345, -0.4348],
        [ 0.2821, -3.1953,  0.8514],
        [-2.7666,  4.3078, -0.9635]

Epoch number 3
 Current loss 0.14699232578277588

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.1671,  5.8747, -2.1352],
        [-3.0148,  5.8267, -2.3468],
        [-2.7680,  5.1231, -2.1571],
        [-3.2390,  6.2912, -2.4099],
        [-0.6217, -3.4952,  3.1202],
        [-2.6236,  4.8157, -2.0713],
        [ 0.7757, -3.2898, -0.2844],
        [-2.8748,  5.4689, -2.3105],
        [-2.8335,  5.3386, -2.1111],
        [-1.3070, -0.0976,  0.8123],
        [ 0.2676, -2.3996, -0.0783],
        [-2.0042, -2.0331,  3.8955],
        [-0.1311, -2.5285,  1.0164],
        [-3.4114,  6.7089, -2.6345],
        [-2.8027,  5.2891, -2.1102],
        [-2.3728,  3.9175, -1.5685],
        [ 0.3544, -3.7056,  1.1427],
        [-3.1392,  5.9945, -2.2969],
        [-2.9125,  5.5323, -2.1514],
        [ 0.5899, -3.4184,  0.1587],
        [-2.6613,  5.0255, -2.1315],
        [-1.3210, -2.1637,  3.2140],
        [-0.3203, -3.2010,  2.2367],
        [ 0.2557, -3.3197,  0.9724]

Epoch number 3
 Current loss 0.1727559119462967

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.4270, -3.0622,  0.2368],
        [-2.1277,  3.8688, -1.8194],
        [-0.5529, -2.1148,  1.3709],
        [-1.7330, -2.1811,  3.6791],
        [-3.0847,  6.0156, -2.4954],
        [-2.7915,  5.1539, -2.1704],
        [-0.3885, -1.7909,  0.5737],
        [-3.2700,  6.4273, -2.6094],
        [ 0.2307, -1.7497, -0.7516],
        [-0.9181, -2.2199,  2.2175],
        [-2.8468,  5.2554, -1.9180],
        [-0.9795, -1.7635,  2.0946],
        [-0.3165, -2.9986,  1.9079],
        [-3.4401,  6.7647, -2.6395],
        [-2.8161,  5.3213, -2.1925],
        [-3.2751,  6.3790, -2.5858],
        [-1.1074, -1.1720,  1.6121],
        [-1.5917, -2.2688,  3.6333],
        [-2.6718,  5.0833, -2.1597],
        [-2.0908,  3.4581, -1.2666],
        [-0.1613, -1.0255, -0.4730],
        [-3.2515,  6.2109, -2.4126],
        [-2.9856,  5.7477, -2.3525],
        [-2.9264,  5.5501, -2.1912],

Epoch number 3
 Current loss 0.17378470301628113

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.3115,  6.3102, -2.4208],
        [ 0.6828, -2.7701, -0.7323],
        [-2.8686,  5.4846, -2.3641],
        [-1.1611, -1.7287,  2.3356],
        [-3.3551,  5.6805, -1.5364],
        [-3.0848,  5.6612, -1.9671],
        [-3.0265,  5.8446, -2.4133],
        [-0.7404, -3.1304,  3.2017],
        [-1.9820, -0.4346,  2.3549],
        [-3.3986,  5.8396, -1.6707],
        [-3.3745,  6.6062, -2.6548],
        [-0.0690, -2.5611,  0.8285],
        [-1.2283, -2.6388,  3.4194],
        [-3.5639,  7.0735, -2.7759],
        [-3.1526,  6.0789, -2.4658],
        [-3.2784,  6.1016, -2.1735],
        [-3.3108,  5.8430, -1.8167],
        [-2.6660,  4.9678, -2.0636],
        [-2.9861,  5.4005, -1.8078],
        [-2.7160,  3.5284, -0.4058],
        [-3.0611,  5.9159, -2.5181],
        [-2.1005,  2.1638, -0.1453],
        [-3.0018,  5.4823, -2.1130],
        [ 0.4418, -2.7147, -0.3727]

Epoch number 3
 Current loss 0.11312359571456909

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.3309,  6.6495, -2.7800],
        [-0.1337, -2.4011,  0.6631],
        [-1.0711, -3.2675,  3.7932],
        [-0.6899, -3.0591,  2.8371],
        [-2.4349,  4.4345, -2.1313],
        [-3.2232,  5.5312, -1.5831],
        [-2.8108,  4.9813, -1.7800],
        [ 0.5957, -2.3304, -0.8311],
        [-2.8382,  5.4057, -2.2643],
        [-3.5671,  6.9874, -2.6655],
        [-3.2700,  3.8534,  0.1606],
        [ 0.5857, -2.7628, -0.5383],
        [-3.3789,  6.7143, -2.7369],
        [-3.0392,  5.2195, -1.6220],
        [-2.9381,  5.1000, -1.6531],
        [-2.5564,  3.1324, -0.0771],
        [-2.3068,  2.1571,  0.1999],
        [-0.9251, -2.6980,  3.0450],
        [-1.6197, -2.3174,  3.7187],
        [ 0.3742, -2.2791, -0.5062],
        [-3.1678,  6.0929, -2.4719],
        [ 0.2549, -1.7223, -0.8533],
        [-2.6416,  4.8939, -1.9948],
        [-3.1054,  5.9339, -2.4534]

Epoch number 3
 Current loss 0.16606464982032776

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.3107,  6.5025, -2.7070],
        [-2.2744, -1.0595,  3.3452],
        [-2.3568,  3.3284, -0.5460],
        [-3.2184,  5.6866, -1.7790],
        [-3.2774,  6.2747, -2.3876],
        [-0.5177, -2.8812,  2.2477],
        [-3.3624,  6.2899, -2.2736],
        [ 0.7292, -3.4521, -0.2044],
        [-3.3150,  5.8675, -1.8082],
        [-3.4798,  6.2734, -1.9457],
        [-0.5395, -2.4194,  1.6406],
        [-2.7092,  5.1686, -2.2648],
        [-3.0170,  5.6623, -2.2550],
        [-0.6768, -2.9278,  2.6249],
        [-3.4505,  6.6840, -2.5977],
        [-1.3490, -2.6400,  3.6264],
        [ 0.7515, -2.9880, -0.6644],
        [-3.2551,  6.3447, -2.5677],
        [-3.0793,  5.7993, -2.3498],
        [-2.5126,  4.2218, -1.6571],
        [-3.1319,  5.4883, -1.6757],
        [-1.5766,  0.0177,  1.2461],
        [-2.1383,  2.3643, -0.2070],
        [-0.9328, -2.8226,  3.1087]

Epoch number 3
 Current loss 0.15871328115463257

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.1654,  5.4534, -1.5534],
        [-2.8724,  5.3769, -2.0756],
        [-3.2109,  5.7991, -1.8969],
        [-1.6318,  1.2608,  0.0089],
        [-2.8232,  5.0969, -1.9093],
        [-3.1798,  5.7637, -1.9320],
        [-2.9171,  5.4910, -2.2628],
        [-3.4474,  5.8311, -1.5488],
        [ 0.5365, -3.5681,  0.6021],
        [-3.0019,  5.4822, -1.9639],
        [-3.5984,  6.4241, -1.9766],
        [-3.4987,  6.3501, -2.0900],
        [-2.8199,  5.3826, -2.3724],
        [-1.6309, -1.9221,  3.4617],
        [-3.1604,  5.7413, -1.9158],
        [ 0.3655, -2.7648,  0.0501],
        [ 0.6751, -3.5073,  0.2663],
        [-3.2761,  6.4418, -2.5742],
        [-2.4995,  3.3586, -0.5395],
        [-3.2091,  5.3929, -1.3861],
        [-0.7496, -2.2646,  2.1227],
        [-2.8233,  5.3352, -2.3031],
        [-2.9667,  5.6123, -2.2217],
        [-3.5112,  6.4829, -2.2183]

Epoch number 4
 Current loss 0.14388751983642578

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.5261,  6.6728, -2.4059],
        [-3.0106,  5.5124, -2.0063],
        [-3.1339,  6.1305, -2.5488],
        [-2.9182,  5.4256, -2.0174],
        [-2.4627,  4.2045, -1.4911],
        [-3.1150,  5.6939, -1.9345],
        [-2.2192, -0.7403,  2.9398],
        [-2.2343,  0.7934,  1.6216],
        [-0.4280, -3.4219,  2.7720],
        [-1.9308,  3.2064, -1.0540],
        [-3.1900,  6.2331, -2.5934],
        [-2.3949,  4.2699, -1.8232],
        [-2.5622,  4.1705, -1.4016],
        [-3.0159,  4.4679, -0.7488],
        [-3.2624,  6.4045, -2.6718],
        [-0.7908, -2.9124,  3.0759],
        [-3.2867,  6.2726, -2.3807],
        [-3.0451,  5.5491, -2.0480],
        [ 0.3227, -3.6691,  1.1502],
        [-1.0132, -2.5475,  2.9313],
        [-0.9932,  0.2235, -0.2214],
        [-1.5388, -2.3213,  3.6791],
        [-3.2420,  5.5169, -1.5757],
        [-1.8811,  2.8730, -0.8688]

Epoch number 4
 Current loss 0.10532325506210327

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.0667,  5.6975, -2.1340],
        [ 0.3727, -2.0108, -0.8524],
        [-2.2390, -0.7320,  2.9850],
        [-3.4464,  6.5963, -2.4966],
        [-3.5184,  6.7116, -2.4114],
        [-0.7445, -3.2638,  3.0152],
        [-2.0292, -0.8571,  2.9126],
        [-3.2109,  6.1592, -2.4400],
        [ 0.5234, -2.9826, -0.1424],
        [-2.9120,  4.8694, -1.3530],
        [-2.6639,  4.9940, -2.1517],
        [-1.3045,  1.3862, -0.9553],
        [-3.2500,  5.4646, -1.4607],
        [-3.3129,  6.1938, -2.2575],
        [-2.4808,  0.3889,  2.2834],
        [-1.4497, -1.5506,  2.7896],
        [-3.2700,  6.1578, -2.2342],
        [-3.2897,  6.2828, -2.4843],
        [-3.2465,  5.4370, -1.4540],
        [-1.4581, -3.1554,  4.2158],
        [-3.4158,  6.4764, -2.4138],
        [ 0.1565, -3.3834,  1.2711],
        [-3.3331,  6.3423, -2.3996],
        [-1.4983, -2.1371,  3.3437]

Epoch number 4
 Current loss 0.10161558538675308

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.5441, -3.3284,  0.2492],
        [-2.1499, -0.8214,  3.0817],
        [-3.6876,  7.0462, -2.5273],
        [-1.4645, -2.9809,  4.1228],
        [-2.6885,  4.9272, -2.0417],
        [-2.1933,  3.9056, -2.0112],
        [-0.0499, -2.8977,  1.2260],
        [-2.7134,  4.5763, -1.4742],
        [-2.3715, -0.7729,  3.2340],
        [-2.8018,  5.1523, -2.1148],
        [-3.6204,  6.8513, -2.4706],
        [-0.5764, -3.1573,  2.6937],
        [-2.8885,  5.2028, -1.9168],
        [-2.6004,  4.7505, -2.1305],
        [-2.5655,  2.9651,  0.1318],
        [-0.7297, -2.2083,  1.9969],
        [-2.8028,  4.0160, -0.7302],
        [ 0.3009, -1.5676, -1.2356],
        [-3.4801,  6.2937, -2.0285],
        [-3.4307,  6.4820, -2.4228],
        [-3.3993,  6.6906, -2.7738],
        [-3.2609,  6.0541, -2.1130],
        [-1.1102, -2.4332,  3.1126],
        [-3.4896,  6.7390, -2.6921]

Epoch number 4
 Current loss 0.07749554514884949

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.0746, -0.8736,  1.2691],
        [-0.9629, -2.7548,  3.0736],
        [-3.2047,  5.5830, -1.7216],
        [-3.3484,  6.3853, -2.4881],
        [-2.5802,  4.6587, -2.0951],
        [-3.5416,  6.7193, -2.5049],
        [-1.3392, -2.2368,  3.2246],
        [-2.9539,  5.1368, -1.7443],
        [-3.0400,  5.5807, -2.0366],
        [ 0.1930, -3.6291,  1.3364],
        [-3.3887,  6.3124, -2.3100],
        [-2.6513,  4.1585, -1.2010],
        [-3.4825,  6.7773, -2.7314],
        [-3.2212,  5.7445, -1.8248],
        [-2.7514,  4.7649, -1.6857],
        [-2.6600,  4.5781, -1.6520],
        [-3.5729,  6.7421, -2.4388],
        [-2.9339,  5.1033, -1.7188],
        [-3.3975,  6.2737, -2.1590],
        [-3.4400,  6.3245, -2.1802],
        [-3.5981,  6.9743, -2.6864],
        [-3.5518,  6.7626, -2.5167],
        [-3.0290,  5.6674, -2.2463],
        [-3.1275,  5.4991, -1.7606]

Epoch number 4
 Current loss 0.08708074688911438

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.5938,  4.3947, -1.4988],
        [-0.7580, -3.2283,  3.1675],
        [-3.2129,  5.7982, -1.9558],
        [-3.7236,  6.1759, -1.4975],
        [-0.7854, -2.9085,  2.9262],
        [-2.9638,  4.1597, -0.6096],
        [-2.9601,  4.5873, -0.9870],
        [-3.3362,  6.1829, -2.1832],
        [-3.4769,  6.7436, -2.6322],
        [-3.4444,  6.5455, -2.4005],
        [-3.2950,  6.1462, -2.1908],
        [-3.3654,  6.2928, -2.3394],
        [-2.8624,  5.1623, -1.8920],
        [-3.0308,  4.5781, -0.9516],
        [-3.4646,  6.4327, -2.3375],
        [-3.4724,  6.6497, -2.5762],
        [-3.2955,  6.2993, -2.5454],
        [-3.5868,  6.3298, -1.8989],
        [-3.5365,  6.3152, -1.8985],
        [-3.4153,  6.1556, -2.0592],
        [-3.1720,  5.4532, -1.6896],
        [ 0.4442, -3.5135,  0.7473],
        [-2.7354,  5.0604, -2.1913],
        [-2.8958,  5.4931, -2.2922]

Epoch number 4
 Current loss 0.13859687745571136

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.1376,  5.3760, -1.6798],
        [-0.3033, -3.2053,  2.0930],
        [-3.3757,  6.2413, -2.3285],
        [-2.9729,  5.0458, -1.5523],
        [-2.8938,  5.2719, -1.9523],
        [-2.9064,  5.0845, -1.6349],
        [-0.3749, -3.5160,  2.9518],
        [-3.3217,  3.5571,  0.4742],
        [-0.6462, -3.1875,  2.7745],
        [-0.6732, -3.2024,  3.0398],
        [-2.9706,  5.2007, -1.6923],
        [-2.3493,  0.9703,  1.7152],
        [-3.2887,  6.1109, -2.2598],
        [-1.9283, -1.5362,  3.3650],
        [-1.3668, -3.1385,  4.0167],
        [-3.0475,  5.2684, -1.7479],
        [-3.4802,  6.4605, -2.3271],
        [-3.1754,  5.6496, -1.9449],
        [-1.7554,  3.1590, -2.2022],
        [-2.5456,  4.3620, -1.8121],
        [ 0.4486, -3.6077,  0.9741],
        [-3.0998,  5.6491, -2.0707],
        [-0.4243, -3.2691,  2.4997],
        [-1.2777, -1.7415,  2.5884]

Epoch number 4
 Current loss 0.15029194951057434

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.1602,  5.7164, -2.0177],
        [-3.5737,  6.2383, -1.8299],
        [-3.4403,  6.4829, -2.4125],
        [-1.4815, -2.7945,  4.0120],
        [-3.4971,  5.3278, -1.0468],
        [-2.8277,  5.2754, -2.4947],
        [-2.3965,  2.6417, -0.1099],
        [-3.4340,  6.5823, -2.5465],
        [-3.5453,  6.4403, -2.1322],
        [-3.5600,  6.8455, -2.6101],
        [-1.5512, -2.7754,  4.0053],
        [ 0.5132, -3.7096,  0.6528],
        [ 0.4741, -2.6403, -0.5422],
        [-1.6862,  2.4985, -1.3207],
        [-3.1471,  5.8061, -2.1398],
        [-3.6126,  6.9717, -2.6722],
        [-3.1150,  5.6350, -2.0679],
        [-2.5431,  2.1514,  0.7442],
        [-1.8916,  1.5981,  0.0014],
        [-2.9125,  5.2286, -2.0340],
        [-3.3908,  4.9560, -0.7438],
        [-2.8201,  4.9527, -1.7455],
        [-1.1199, -2.3782,  3.0663],
        [ 0.8166, -3.8243,  0.1609]

Epoch number 4
 Current loss 0.10228884220123291

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.3335, -3.7094,  1.2419],
        [-0.7201, -3.1234,  3.1571],
        [-3.6322,  6.5885, -2.1678],
        [-2.8989,  5.0941, -1.7808],
        [-2.7204,  5.0274, -2.0830],
        [-2.9251,  5.3558, -2.2802],
        [-1.1634, -3.0257,  3.6800],
        [-3.5119,  6.6227, -2.5411],
        [-3.1362,  5.5485, -1.8306],
        [ 0.4999, -3.5841,  0.9003],
        [-0.4287, -3.4785,  2.7731],
        [-0.8422, -3.0482,  3.1694],
        [-3.1823,  5.3986, -1.6464],
        [-1.9258,  0.8873,  1.0830],
        [-2.2785,  1.5288,  0.9189],
        [-2.7834,  3.4414, -0.0693],
        [-3.3217,  5.8679, -1.8086],
        [-2.7636,  5.1221, -2.3202],
        [-3.4596,  6.6961, -2.6580],
        [-2.9709,  4.8409, -1.2332],
        [-3.1003,  5.9609, -2.5277],
        [-3.1357,  5.6489, -1.9860],
        [-3.1987,  5.0011, -1.0539],
        [-3.3402,  6.2817, -2.4550]

Epoch number 4
 Current loss 0.11036057770252228

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.6287, -2.1801,  1.7398],
        [ 0.7418, -3.1367, -0.1542],
        [-3.3915,  5.8829, -1.7606],
        [-1.3511, -2.4635,  3.5133],
        [ 0.2953, -2.8133,  0.3571],
        [-3.0330,  5.3240, -1.8325],
        [-3.6033,  6.9045, -2.7174],
        [-3.2275,  5.9963, -2.2945],
        [-0.6353, -1.6457,  1.1039],
        [-3.0826,  5.7937, -2.4778],
        [-2.9885,  5.4288, -2.0362],
        [-3.2656,  6.2213, -2.4561],
        [-1.5325, -3.2510,  4.3813],
        [-2.0990,  3.0208, -0.5981],
        [-3.2859,  6.0821, -2.4020],
        [-1.1398, -1.3475,  1.8611],
        [ 0.3752, -1.7022, -1.1411],
        [ 0.8853, -3.6972, -0.1081],
        [-0.9327, -3.2942,  3.5586],
        [-0.5071, -3.2649,  2.7285],
        [-2.8929,  3.5311, -0.0235],
        [-3.3457,  6.0663, -2.1766],
        [-2.7255,  5.0394, -2.2607],
        [-3.1140,  5.8058, -2.3280]

Epoch number 4
 Current loss 0.08840668946504593

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6481,  4.6292, -1.8516],
        [-3.2348,  6.0372, -2.4155],
        [-3.4388,  6.4953, -2.5512],
        [-2.3923, -1.9171,  4.2487],
        [-3.1511,  5.8974, -2.3635],
        [-2.3070, -0.8751,  3.2202],
        [-1.5262, -3.1679,  4.2662],
        [-2.8591,  5.5506, -2.4075],
        [-3.1765,  4.9419, -1.0689],
        [-2.9296,  5.0343, -1.5779],
        [-0.8896, -2.1305,  2.1793],
        [-3.5825,  6.9151, -2.7594],
        [-2.0901,  2.2457, -0.2395],
        [-2.8955,  5.2130, -2.0579],
        [-3.6520,  7.0953, -2.8120],
        [-3.6443,  7.0036, -2.7407],
        [-2.0134, -2.4931,  4.3397],
        [-3.3392,  6.4152, -2.5628],
        [-2.6826,  4.9213, -2.2414],
        [-3.1522,  5.8795, -2.4752],
        [-2.9603,  5.4075, -2.0727],
        [-2.3251,  2.9478, -0.5431],
        [-3.5791,  6.9839, -2.8460],
        [-2.2146,  3.0044, -0.4849]

Epoch number 4
 Current loss 0.15809743106365204

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.5696,  6.8903, -2.7578],
        [ 0.5612, -3.5794,  0.3739],
        [-1.3830,  1.0071, -0.2965],
        [-2.7331,  5.0075, -2.0051],
        [-3.5680,  6.9798, -2.8897],
        [-3.2190,  6.0286, -2.3922],
        [-3.0081,  5.6996, -2.2379],
        [-2.1168, -0.5142,  2.7097],
        [-3.0775,  5.8066, -2.4453],
        [-2.5126,  4.7451, -2.5138],
        [-2.9336,  5.0815, -1.6062],
        [-1.9316, -0.2107,  2.1246],
        [-1.2988, -3.1222,  3.9949],
        [ 0.7725, -3.2841, -0.1875],
        [-2.7950,  2.5184,  0.7917],
        [ 0.3141, -2.2696, -0.4641],
        [-1.8847, -2.9370,  4.5732],
        [ 0.7552, -3.5095,  0.0283],
        [-3.3816,  5.5990, -1.2817],
        [-2.8096,  5.2478, -2.3941],
        [-3.5467,  6.2332, -2.0404],
        [-0.2748, -3.0103,  1.9236],
        [-3.5410,  6.3229, -2.0648],
        [-3.4381,  6.5571, -2.6413]

Epoch number 4
 Current loss 0.1274096816778183

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.6573,  6.9686, -2.6818],
        [-3.1005,  5.5603, -1.9486],
        [-1.2108, -1.8507,  2.5647],
        [-1.3293,  1.8808, -1.0901],
        [-2.7000,  4.6816, -1.7007],
        [-1.4943, -1.3936,  2.7169],
        [-3.0318,  4.9578, -1.2183],
        [-3.3623,  6.3693, -2.6180],
        [-3.4471,  6.4426, -2.3856],
        [-3.5404,  6.8302, -2.7767],
        [-3.0822,  5.8665, -2.6531],
        [-1.7905, -2.6014,  4.1920],
        [-0.7350, -2.2987,  1.9165],
        [-0.0516, -3.2635,  1.7242],
        [-2.2516,  3.8451, -1.8210],
        [-2.1316,  2.0613, -0.0851],
        [-1.4941, -2.3997,  3.6516],
        [-3.4029,  6.1915, -2.1143],
        [ 0.2797, -3.7113,  1.4281],
        [-0.4877, -2.4072,  1.5009],
        [-3.3929,  6.6386, -2.8420],
        [-1.5457, -2.9465,  4.1164],
        [-3.0143,  4.5689, -0.9364],
        [-3.5011,  6.6374, -2.5731],

Epoch number 4
 Current loss 0.1261814385652542

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.1363,  5.9268, -2.5376],
        [-2.5876,  1.1786,  1.8399],
        [-3.2747,  6.2248, -2.5566],
        [-0.7392, -3.3095,  3.3125],
        [-3.6242,  7.0005, -2.8060],
        [-3.0062,  5.6348, -2.4547],
        [-1.5664,  2.2041, -1.3502],
        [-1.7271, -3.2336,  4.6116],
        [-3.0670,  5.4756, -2.0043],
        [-2.7628,  2.6540,  0.5686],
        [-1.3180, -2.9460,  3.8371],
        [-2.0344, -1.2940,  3.3031],
        [ 0.1638, -3.6183,  1.7842],
        [-3.1656,  5.7113, -2.1448],
        [-3.2312,  6.0658, -2.5002],
        [-1.2320, -3.2274,  4.1229],
        [-3.2360,  6.2954, -2.7764],
        [-2.2265,  3.0362, -0.8755],
        [-1.2635, -3.2361,  4.0736],
        [ 0.0695, -2.2001,  0.1235],
        [-3.6446,  6.8506, -2.5730],
        [-2.6462,  4.6001, -1.9033],
        [-1.2455, -2.2879,  3.2172],
        [-2.9883,  5.0576, -1.3993],

Epoch number 4
 Current loss 0.1610337793827057

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.5687, -4.0065,  0.9168],
        [-1.8678, -2.7372,  4.3820],
        [-3.1119,  5.7691, -2.3492],
        [-3.0784,  5.4841, -2.2219],
        [-3.1530,  5.6510, -2.1228],
        [-3.6361,  6.8437, -2.6439],
        [-3.1973,  5.3741, -1.6673],
        [-2.9543,  5.0290, -1.5378],
        [-2.8592,  5.2201, -2.3003],
        [ 0.3831, -3.5081,  0.9807],
        [-0.8347, -3.2942,  3.3781],
        [-1.4964, -3.5529,  4.6761],
        [-2.1525,  1.3339,  0.8888],
        [-2.8355,  5.2536, -2.3390],
        [-3.0752,  4.8756, -1.2357],
        [-3.1076,  6.0976, -2.8687],
        [-3.4118,  6.5773, -2.7787],
        [-0.7209, -3.4649,  3.3712],
        [-1.2075, -3.6085,  4.3292],
        [-3.1644,  5.4278, -1.6869],
        [ 0.1317, -2.3716,  0.3011],
        [-3.0549,  5.4100, -1.7616],
        [-2.0577,  3.3521, -1.8259],
        [-2.7598,  4.6359, -1.5866],

Epoch number 4
 Current loss 0.11733192950487137

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.1031, -3.3018,  1.4401],
        [-3.2595,  6.2820, -2.6588],
        [ 0.4658, -3.3947,  0.9126],
        [-3.0698,  5.6092, -2.1555],
        [-2.6101,  4.8801, -2.2418],
        [-1.7811, -3.0974,  4.5480],
        [-0.3674, -2.6817,  1.5595],
        [-2.9444,  5.4207, -2.2926],
        [-3.3573,  6.5026, -2.7470],
        [-2.9906,  5.6997, -2.5348],
        [-2.6737,  4.3861, -1.4389],
        [-1.9987,  3.2610, -1.8585],
        [-2.8828,  5.4775, -2.7811],
        [-2.0362,  3.2264, -1.5386],
        [-3.0652,  5.7501, -2.6094],
        [-3.1300,  5.6440, -2.2720],
        [-3.1220,  5.5703, -2.0836],
        [-2.8635,  4.8725, -1.5720],
        [-2.9748,  5.2457, -2.0851],
        [-1.6144, -0.4760,  1.9346],
        [ 0.3815, -2.9404,  0.0835],
        [ 0.8390, -3.7172,  0.2408],
        [-2.5434,  4.6568, -2.4568],
        [-0.5866, -3.6046,  3.4220]

Epoch number 4
 Current loss 0.04904405027627945

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.5965,  7.0144, -2.9365],
        [-3.0720,  5.7992, -2.5996],
        [-3.2379,  5.0341, -1.1848],
        [-3.1127,  5.9738, -2.7816],
        [-0.6864, -3.3952,  3.0919],
        [-3.2341,  6.1890, -2.6741],
        [-2.7552,  4.8043, -1.7953],
        [-2.8447,  4.9997, -1.9418],
        [-2.9654,  5.3399, -1.9329],
        [ 0.1479, -3.1541,  1.0207],
        [-2.7563,  4.7801, -1.9775],
        [-0.7315, -3.7492,  3.6987],
        [-3.0466,  2.9678,  0.6165],
        [-1.6035, -0.5115,  1.9193],
        [-3.2285,  6.1307, -2.4771],
        [ 0.8773, -3.5634, -0.2615],
        [-1.2187, -0.1631,  0.7906],
        [-3.0649,  5.6671, -2.3584],
        [-2.9281,  5.5300, -2.5623],
        [-3.0415,  5.2354, -1.9027],
        [-2.9895,  5.0862, -1.6322],
        [-0.8536, -3.6699,  3.8182],
        [-1.6490,  0.1153,  0.9809],
        [-2.0239,  1.0082,  0.9115]

Epoch number 4
 Current loss 0.15255102515220642

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.4800,  6.6497, -2.7828],
        [-2.2646,  3.6181, -1.1984],
        [-0.4856, -0.6439, -0.3303],
        [-1.2515, -3.1914,  4.0098],
        [-1.5319, -2.4717,  3.8694],
        [-3.0973,  5.5718, -2.0798],
        [-2.8565,  5.3643, -2.3243],
        [ 0.5875, -3.1023, -0.1666],
        [-3.6307,  6.8946, -2.6601],
        [-2.9422,  5.4549, -2.4107],
        [-0.3809, -1.6576,  0.5382],
        [-3.2591,  5.9307, -2.1182],
        [-3.1783,  5.7126, -2.0543],
        [-3.2427,  5.6326, -1.9366],
        [ 0.1754, -2.9221,  0.5992],
        [ 0.6183, -2.4238, -0.9218],
        [-0.8671, -3.2146,  3.3445],
        [-3.0240,  5.7070, -2.4501],
        [-3.8073,  5.8006, -1.1877],
        [-2.7155,  4.9651, -2.2753],
        [-3.3767,  6.5276, -2.7454],
        [-3.3875,  6.3634, -2.5238],
        [-3.0896,  5.8023, -2.6027],
        [-3.3275,  4.7117, -0.6377]

Epoch number 4
 Current loss 0.13691936433315277

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.7649,  5.4040, -2.8056],
        [-3.0655,  5.8094, -2.6543],
        [-2.6901,  4.7929, -2.1783],
        [ 0.7874, -3.3518, -0.2598],
        [-3.2406,  6.2089, -2.6857],
        [-2.8694,  5.5251, -2.6640],
        [-2.7533,  5.2653, -2.5351],
        [-2.4272,  4.4255, -2.3554],
        [ 0.7291, -3.6986,  0.1771],
        [-2.1791,  4.0481, -2.2809],
        [-2.7456,  5.0547, -2.3292],
        [-1.4908, -2.7902,  3.9161],
        [-2.8756,  5.5455, -2.7162],
        [-1.1526, -3.0674,  3.7006],
        [-3.1907,  6.0595, -2.5989],
        [-1.2090, -1.0931,  1.7951],
        [ 0.4078, -3.1936,  0.4196],
        [-2.6816,  4.7574, -2.1571],
        [-2.9012,  5.5815, -2.6708],
        [-1.3055, -3.2102,  4.1524],
        [-1.8394, -0.2278,  2.0734],
        [-3.1351,  5.5434, -1.7966],
        [-1.4954, -3.4925,  4.5680],
        [-3.0684,  5.9674, -3.0431]

Epoch number 4
 Current loss 0.15200896561145782

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.2661, -3.6023,  4.3939],
        [-3.1419,  5.6771, -2.2008],
        [-2.8605,  2.8715,  0.4359],
        [-0.2340, -3.3513,  2.0222],
        [ 0.5395, -3.9118,  0.8552],
        [-1.2259, -2.1489,  2.8215],
        [-0.9247, -2.6986,  2.9867],
        [-1.0106, -3.3518,  3.8113],
        [-2.8332,  5.4063, -2.5101],
        [-3.0654,  5.9057, -2.9534],
        [-1.2278, -2.9303,  3.7644],
        [-2.6893,  4.5755, -1.9182],
        [-3.4098,  6.5933, -2.8146],
        [ 0.7386, -3.9424,  0.4886],
        [-2.9993,  5.8024, -2.8173],
        [-3.4451,  6.7842, -2.9934],
        [-1.2088, -2.9094,  3.7180],
        [-3.0210,  5.8079, -2.6547],
        [-2.7251,  4.9278, -2.1037],
        [-2.5804,  4.6386, -1.8402],
        [-1.6565, -1.6528,  3.1215],
        [ 0.7790, -3.8001,  0.3590],
        [-2.0477,  3.4185, -1.8182],
        [-1.2389, -3.5062,  4.2552]

Epoch number 4
 Current loss 0.11510303616523743

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.6786, -3.4707,  3.1868],
        [-1.3529, -1.6020,  2.4652],
        [-1.2962, -2.8896,  3.8248],
        [-1.6788, -3.4064,  4.6941],
        [-2.3213,  4.4317, -2.6292],
        [-1.4760, -3.5308,  4.6021],
        [-3.5174,  6.9164, -3.0402],
        [ 0.3801, -1.9083, -0.9447],
        [-0.0868, -3.4828,  2.1269],
        [-2.2051,  3.6690, -1.2232],
        [-1.3183, -3.2100,  4.0879],
        [-0.0567, -1.7234,  0.0409],
        [-3.1182,  5.8819, -2.7038],
        [-0.8368, -3.0720,  3.1023],
        [-3.1943,  6.0592, -2.5922],
        [-2.8858,  5.0396, -1.8177],
        [ 0.8404, -3.7352,  0.0447],
        [-2.9086,  5.5825, -2.5767],
        [-1.2238, -1.4167,  2.0652],
        [-3.4167,  6.5708, -2.7787],
        [-3.3115,  2.6848,  1.0892],
        [-1.0589, -3.1534,  3.6773],
        [-1.0144, -3.1739,  3.5777],
        [-3.4549,  6.7083, -2.9270]

Epoch number 4
 Current loss 0.06839904189109802

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.1039,  3.7774, -1.7159],
        [-2.8417,  5.2055, -2.4192],
        [-3.0605,  5.9104, -2.9242],
        [-3.5094,  6.8276, -2.9521],
        [-3.4690,  6.7250, -3.0200],
        [ 0.8656, -3.5247, -0.0074],
        [-3.6005,  6.7891, -2.6770],
        [-3.2807,  6.3627, -2.9367],
        [-3.7749,  7.2479, -2.9002],
        [-2.3966,  4.6186, -2.7850],
        [ 0.9288, -3.6856, -0.2572],
        [-3.5990,  6.8269, -2.7448],
        [-2.7512,  4.6651, -1.8431],
        [ 0.3948, -2.9825,  0.6034],
        [-3.5400,  6.7109, -2.7082],
        [-3.7215,  7.0658, -2.7775],
        [-2.7896,  5.3204, -2.6578],
        [-3.4384,  6.7535, -3.0260],
        [-3.0063,  5.6597, -2.6626],
        [-2.8553,  5.3312, -2.4970],
        [-2.8277,  5.3411, -2.5524],
        [-2.8802,  4.8066, -1.5037],
        [-3.3218,  6.5510, -3.0383],
        [-3.0529,  5.5750, -2.0530]

Epoch number 4
 Current loss 0.13049942255020142

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6665,  5.0566, -2.6233],
        [-0.6561, -2.9331,  2.9064],
        [-3.0719,  5.8685, -2.8323],
        [-3.0599,  5.4840, -2.1056],
        [-1.3915, -1.5838,  2.7754],
        [-3.6924,  7.0970, -2.8997],
        [-1.6508,  2.1949, -1.1023],
        [-0.7766, -2.0493,  1.8500],
        [ 0.7657, -3.8162,  0.3996],
        [-1.0467, -3.5940,  4.0657],
        [-3.1463,  5.9153, -2.5612],
        [-3.6272,  6.9727, -2.8866],
        [-3.2058,  5.8786, -2.2808],
        [-3.1162,  5.7922, -2.4064],
        [-3.2169,  5.8586, -2.0275],
        [-3.5900,  7.1464, -3.2198],
        [-2.9289,  5.4870, -2.2470],
        [-3.7598,  7.1382, -2.8572],
        [-3.4710,  5.7001, -1.5499],
        [ 0.1412, -3.7430,  1.7582],
        [-3.3246,  5.9092, -2.0182],
        [-3.6572,  6.4106, -2.0914],
        [-0.4539, -2.7601,  1.7944],
        [-3.6013,  6.4845, -2.2735]

Epoch number 4
 Current loss 0.1024102121591568

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.3995,  2.9751, -0.2689],
        [-3.3370,  6.3039, -2.6959],
        [-1.5858, -2.8023,  4.1100],
        [-3.7581,  7.1363, -2.8067],
        [-1.0083, -3.3663,  3.9217],
        [-3.3317,  6.0889, -2.4807],
        [-0.7704, -3.1870,  3.0151],
        [-3.4475,  6.6155, -2.7691],
        [-3.6731,  7.1498, -2.9681],
        [-3.3086,  6.1979, -2.5800],
        [-2.3177,  3.1524, -1.0596],
        [-1.3163, -3.3488,  4.3186],
        [-3.2480,  6.0372, -2.3096],
        [-3.1222,  3.0754,  0.6987],
        [-0.1362, -1.9323,  0.3752],
        [-3.3010,  5.3971, -1.3232],
        [-1.4257, -3.3776,  4.3843],
        [ 0.0834, -2.6248,  0.7802],
        [-3.3497,  6.0480, -2.1446],
        [-0.9804, -3.2341,  3.7990],
        [-2.1817, -2.5353,  4.5799],
        [-2.9634,  5.5604, -2.5115],
        [-3.6416,  6.7521, -2.5235],
        [-3.0909,  5.0060, -1.4783],

Epoch number 4
 Current loss 0.10266601294279099

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.0904,  4.9770, -1.4303],
        [-2.7743,  4.2770, -1.1636],
        [ 0.6172, -3.5945,  0.5087],
        [-3.6484,  6.8420, -2.7090],
        [-3.2611,  5.7192, -1.7353],
        [-3.5747,  6.6522, -2.6345],
        [-3.0170,  5.2578, -1.8152],
        [-0.4795, -3.5261,  2.9809],
        [-0.3892, -3.4735,  2.7995],
        [-3.6059,  6.6945, -2.5112],
        [-1.7485, -3.2054,  4.6110],
        [-1.3545, -2.6870,  3.6729],
        [-3.6039,  6.6855, -2.4912],
        [-3.7252,  7.2796, -3.0506],
        [-3.2024,  5.9613, -2.4189],
        [-2.6762,  4.8474, -2.0524],
        [-2.7997,  2.5393,  0.6192],
        [ 0.7604, -3.6279,  0.0926],
        [-2.8366,  5.0183, -1.9433],
        [-2.6705,  3.9185, -1.1303],
        [-1.4885, -3.0404,  4.1356],
        [-3.5356,  6.8502, -2.9522],
        [-3.0656,  4.8193, -1.2887],
        [-0.4905, -3.0263,  2.4479]

Epoch number 4
 Current loss 0.07662448287010193

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.6113, -3.1753,  4.4359],
        [-1.6817, -2.0139,  3.5283],
        [-3.4815,  6.1176, -2.1942],
        [ 0.0772, -3.7319,  1.9755],
        [-3.1340,  5.8406, -2.5288],
        [-3.5108,  6.6730, -2.7091],
        [-2.0356,  2.7888, -0.6131],
        [-3.2009,  5.9553, -2.5423],
        [-1.3628, -2.6528,  3.6740],
        [-1.1374,  0.6056, -0.8348],
        [-3.1283,  5.4228, -1.7202],
        [-3.6232,  6.9084, -2.7900],
        [-3.2709,  5.7676, -1.8420],
        [-2.9998,  5.4687, -2.4027],
        [-1.3230, -2.9377,  3.9050],
        [-2.9714,  4.8352, -1.5136],
        [-1.4283, -1.8533,  2.8689],
        [ 0.2044, -3.3506,  1.2286],
        [-1.6377, -2.4653,  3.8582],
        [ 0.8152, -3.6856,  0.0399],
        [-3.1240,  5.4674, -1.8244],
        [-1.2552, -1.8133,  2.7065],
        [ 0.6140, -3.3399,  0.1733],
        [-2.9714,  5.0980, -1.8618]

Epoch number 4
 Current loss 0.10773206502199173

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.6210, -3.4905,  3.1268],
        [-3.7179,  6.8189, -2.4929],
        [-2.6946,  3.8911, -0.9229],
        [-0.8949, -1.0349,  0.8520],
        [-3.5843,  6.4580, -2.3019],
        [-2.2235,  3.9318, -2.3634],
        [-3.1546,  5.8496, -2.5543],
        [-3.3334,  6.2145, -2.5660],
        [-2.5634,  4.4742, -1.9714],
        [-3.6460,  4.0569,  0.3118],
        [-0.7523, -3.1759,  2.9936],
        [-1.2188, -2.8937,  3.7401],
        [-3.0711,  5.7603, -2.5990],
        [-1.1754, -1.9744,  2.8938],
        [-3.6491,  7.0903, -2.9543],
        [-2.1675,  1.5192,  0.5937],
        [-3.4385,  6.3894, -2.6344],
        [ 0.1813, -2.9066,  0.5294],
        [-3.4597,  5.8764, -1.8501],
        [-2.1267,  3.0245, -0.9184],
        [-2.5162,  2.1764,  0.5897],
        [-3.2666,  5.7535, -2.0688],
        [-0.5488, -3.3983,  2.6950],
        [-1.9372, -2.9442,  4.6091]

Epoch number 4
 Current loss 0.1440911740064621

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.6146,  6.9722, -2.9547],
        [-0.5695, -3.0380,  2.4625],
        [-3.5527,  6.6028, -2.5537],
        [ 0.1628, -1.6878, -0.9236],
        [-1.3737, -3.5371,  4.4328],
        [-3.6478,  6.9711, -2.9003],
        [-3.2984,  6.0741, -2.4262],
        [-2.9956,  5.3772, -2.0943],
        [ 0.0721, -1.3530, -0.7241],
        [-3.2912,  5.4437, -1.6533],
        [-3.7880,  5.7832, -1.1384],
        [-3.2738,  5.4176, -1.6851],
        [-1.0809, -2.5982,  3.0246],
        [-1.9278, -0.2106,  2.1425],
        [ 0.5774, -3.3756,  0.3633],
        [-2.9565,  5.5261, -2.5241],
        [-3.5710,  6.2988, -2.0791],
        [-0.8880, -2.8507,  2.9351],
        [ 0.2470, -3.1659,  0.7748],
        [-3.7523,  7.0721, -2.7406],
        [-3.8162,  7.3410, -2.9557],
        [-2.9792,  3.7829, -0.1563],
        [-3.6851,  6.7983, -2.5723],
        [-1.4597, -1.2724,  2.3871],

Epoch number 4
 Current loss 0.15540754795074463

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.5243,  6.6205, -2.6323],
        [-2.4460,  4.4302, -1.9796],
        [ 0.5818, -3.4456,  0.3783],
        [-3.6477,  7.0129, -2.8521],
        [-3.6083,  6.2608, -1.9299],
        [-1.8747, -2.6340,  4.2869],
        [-0.4299, -3.5276,  2.7780],
        [-3.7052,  7.1148, -2.9448],
        [-3.7065,  6.0972, -1.5887],
        [-0.8609, -3.1511,  3.1654],
        [ 0.2467, -1.9095, -0.6612],
        [-1.1417, -2.8450,  3.4909],
        [-3.5763,  6.6285, -2.5636],
        [-2.2120,  2.8027, -0.3821],
        [-1.8384, -1.8556,  3.6180],
        [-1.1941, -1.2451,  2.0457],
        [-2.0501, -1.2609,  3.2065],
        [-2.8571,  5.0483, -2.0570],
        [-3.9611,  7.4429, -2.7806],
        [-3.7652,  7.0239, -2.7509],
        [-3.9002,  7.0012, -2.3863],
        [-0.4507, -3.5758,  2.9180],
        [-1.4898, -2.4754,  3.6394],
        [ 0.5624, -3.7128,  1.0737]

Epoch number 4
 Current loss 0.056068696081638336

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.2465,  5.8962, -2.2333],
        [-1.1589,  1.7087, -1.3414],
        [-3.6844,  6.4645, -2.0668],
        [-0.4643, -1.7237,  0.5916],
        [-3.2861,  5.5916, -1.5900],
        [-1.1368, -3.4559,  4.1618],
        [-2.1871,  3.6590, -1.6563],
        [-3.5898,  6.4470, -2.2157],
        [-1.4135, -2.1880,  3.3734],
        [ 0.5665, -3.5435,  0.6388],
        [ 0.7622, -2.9594, -0.4909],
        [-3.4606,  6.5507, -2.6768],
        [-1.8112, -1.1667,  2.9242],
        [-3.7657,  6.9093, -2.5204],
        [-1.1342, -0.5923,  0.9358],
        [-1.0734, -3.5651,  4.1575],
        [-0.3675, -3.6707,  2.7953],
        [ 0.9342, -3.7591, -0.1179],
        [-1.3080, -3.2076,  4.1270],
        [-3.6283,  6.8088, -2.7019],
        [-3.8779,  7.0890, -2.4885],
        [-3.1836,  5.0164, -1.1908],
        [-3.8379,  7.2324, -2.8137],
        [ 0.0542, -2.7086,  0.5630

Epoch number 4
 Current loss 0.09394494444131851

inputs are
torch.Size([100, 87])
torch.Size([100, 87])
OUTPUT
tensor([[-3.7279,  6.9066, -2.5529],
        [-3.1257,  5.0651, -1.2558],
        [-2.5856,  2.4632,  0.5068],
        [-3.7229,  6.9404, -2.6407],
        [-0.0200, -3.9216,  2.3585],
        [-2.5780,  4.4229, -1.4915],
        [-3.8177,  6.3609, -1.6025],
        [-3.5345,  6.3760, -2.2519],
        [-3.6712,  6.7797, -2.5744],
        [-3.5571,  6.3754, -2.2386],
        [-2.8909,  4.3366, -0.8763],
        [-3.4500,  6.1881, -2.1990],
        [-3.2935,  6.0864, -2.3120],
        [-3.8056,  6.2728, -1.6449],
        [-3.0976,  5.3210, -1.6686],
        [-3.4882,  6.0276, -1.8329],
        [-3.2969,  4.8548, -0.7609],
        [-2.4778,  4.0963, -1.5437],
        [-2.2344,  1.6112,  0.6362],
        [ 0.5348, -3.6294,  0.9424],
        [-1.6758, -1.2433,  2.8005],
        [-0.6926, -2.0833,  1.6168],
        [-3.2692,  4.7929, -0.9053],
        [-3.7969,  7.0306, -2.6070],


Epoch number 4
 Current loss 0.06237168610095978

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6230,  4.2621, -1.4527],
        [-1.1428, -2.0347,  2.6441],
        [-0.9190, -2.9379,  3.0726],
        [-3.4615,  5.8494, -1.5744],
        [-3.8552,  6.8001, -2.0624],
        [-2.6231,  3.8801, -1.0876],
        [-2.8579,  4.9679, -1.7316],
        [-2.8298,  3.9622, -0.7435],
        [-2.2669,  1.7917,  0.5323],
        [ 0.3592, -3.2145,  0.6229],
        [-1.3988, -3.3280,  4.2851],
        [-1.5801, -3.1040,  4.3696],
        [-1.9427, -1.6751,  3.5456],
        [-4.0200,  7.3645, -2.4922],
        [-2.9859,  5.1790, -1.8748],
        [-3.2219,  5.1037, -1.3454],
        [-2.6203,  4.2114, -1.4371],
        [-3.8958,  7.1859, -2.6279],
        [-3.3950,  5.8130, -1.8384],
        [-3.5130,  5.0995, -0.6441],
        [-2.1038,  3.7841, -2.3092],
        [-3.8526,  6.8300, -2.1393],
        [-3.6159,  6.6027, -2.3821],
        [-2.3348,  4.5454, -2.3010]

Epoch number 4
 Current loss 0.1412161886692047

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.4804, -1.1639,  2.3169],
        [-2.7589,  4.8636, -1.9312],
        [-2.5189,  3.5358, -0.7026],
        [-3.8133,  7.2839, -2.8996],
        [ 0.6564, -3.9998,  0.7855],
        [-3.3635,  6.2403, -2.5124],
        [-3.7858,  6.6398, -2.0418],
        [-3.1096,  5.2553, -1.5419],
        [-3.4900,  6.4199, -2.3278],
        [-0.4535, -1.8152,  0.9416],
        [-0.2545, -3.7516,  2.6285],
        [-3.2201,  4.5021, -0.7124],
        [-2.5115,  4.1769, -1.5370],
        [-3.6474,  6.4647, -2.2416],
        [-2.4781,  3.3253, -0.7525],
        [-2.2596,  0.5158,  1.8912],
        [-3.2953,  5.9689, -2.2554],
        [-3.1656,  5.4537, -1.7766],
        [-3.2754,  6.2150, -2.6104],
        [-3.3273,  5.4057, -1.3604],
        [-2.9524,  5.0563, -1.6787],
        [-2.6954,  4.4884, -1.5884],
        [-1.9088,  2.9747, -1.6468],
        [-1.6279, -3.1815,  4.4699],

Epoch number 4
 Current loss 0.12115173786878586

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.3951,  6.2798, -2.4999],
        [ 0.9342, -4.0192, -0.0515],
        [-3.6648,  6.9109, -2.6042],
        [-3.6700,  6.6621, -2.3975],
        [-3.4832,  6.2503, -2.2615],
        [-2.0743,  3.6284, -2.1946],
        [-2.8457,  4.3784, -1.1396],
        [-3.9191,  6.5494, -1.7435],
        [-1.4197, -2.7514,  3.8347],
        [-3.1912,  5.6118, -1.9440],
        [-3.2892,  6.0261, -2.3303],
        [-3.3633,  5.3972, -1.2807],
        [-2.5549,  1.5080,  1.3542],
        [-3.4139,  6.1741, -2.2306],
        [-4.0462,  7.2366, -2.3113],
        [-3.6793,  6.6443, -2.3798],
        [-0.8067, -3.0561,  3.0621],
        [-3.2100,  5.3309, -1.4816],
        [-3.9068,  7.1795, -2.5698],
        [-1.6458, -1.7445,  3.2209],
        [-0.3335, -3.5715,  2.5475],
        [-0.7512, -3.0270,  2.9119],
        [-3.2278,  5.5522, -1.8811],
        [-3.8465,  7.0118, -2.3706]

Epoch number 4
 Current loss 0.1204693466424942

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.1704, -3.7555,  4.4322],
        [-2.0385,  0.0154,  2.0492],
        [-2.8576,  5.0096, -1.9121],
        [-1.0269, -3.2407,  3.6777],
        [-3.6250,  6.5597, -2.4054],
        [-2.4894, -1.8962,  4.3304],
        [-2.2533, -1.1882,  3.5232],
        [-1.9690, -2.9525,  4.6212],
        [-1.4082, -2.9440,  3.8806],
        [-2.4315, -1.7284,  4.0627],
        [-3.3985,  6.0922, -2.2488],
        [-3.3526,  5.9264, -2.0733],
        [-3.0177,  5.4015, -2.1227],
        [-2.7327,  3.7729, -0.7086],
        [-2.8216,  2.5393,  0.7176],
        [-3.3838,  6.2599, -2.4654],
        [-1.2295, -2.7142,  3.4829],
        [-1.6058,  2.2925, -1.1243],
        [-3.1340,  5.8443, -2.5594],
        [-3.4263,  6.3778, -2.6555],
        [-3.8839,  6.9858, -2.3549],
        [-1.2254, -3.7313,  4.4550],
        [-3.3576,  5.8723, -2.0307],
        [-3.7824,  7.0363, -2.6899],

Epoch number 4
 Current loss 0.0783727616071701

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.0990,  5.6532, -2.2620],
        [-3.4473,  6.6230, -2.8690],
        [-2.8336,  5.3031, -2.5353],
        [-2.7641,  4.8022, -1.8594],
        [-3.6904,  6.7695, -2.6092],
        [-3.6909,  5.3112, -0.8311],
        [-3.1866,  5.2763, -1.4517],
        [-2.8147,  5.0579, -2.3093],
        [-3.1942,  6.1583, -2.7734],
        [-3.6705,  7.0129, -2.8157],
        [ 0.9333, -3.2503, -0.6936],
        [-3.1614,  5.5055, -1.9148],
        [ 0.5886, -2.5002, -0.7126],
        [-3.0329,  5.5858, -2.1976],
        [-3.0790,  4.4965, -0.8578],
        [-2.5796,  4.8435, -2.5642],
        [-1.4955, -3.4196,  4.4993],
        [-0.2168, -3.2993,  2.0278],
        [-2.7952,  3.3121, -0.0609],
        [-0.0388, -3.5955,  1.9207],
        [-3.2689,  5.6603, -2.0407],
        [-3.3939,  5.4527, -1.4614],
        [-0.8540, -2.6547,  2.7673],
        [-2.9727,  5.1800, -2.0788],

Epoch number 4
 Current loss 0.12751826643943787

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.9805, -3.3966,  3.7996],
        [-3.1569,  5.7563, -2.3862],
        [-0.8643, -3.4106,  3.5165],
        [-2.2501, -1.7486,  3.9216],
        [-1.4527,  0.9424, -0.0983],
        [-1.9222, -2.4530,  4.2182],
        [-1.3776, -1.0238,  1.7685],
        [-3.1313,  5.7605, -2.4272],
        [ 0.7750, -4.0210,  0.6079],
        [-3.5138,  5.9198, -1.8041],
        [-3.5279,  4.9483, -0.6893],
        [-0.6681, -2.3757,  1.9386],
        [-3.4116,  4.0774,  0.1534],
        [-1.0779, -3.4966,  4.0251],
        [-3.4170,  5.7883, -1.8125],
        [ 0.4396, -3.9247,  1.2761],
        [-1.6295, -2.4163,  3.7307],
        [-2.0714, -0.0752,  2.2367],
        [-3.7508,  7.0350, -2.7237],
        [-3.6540,  5.9815, -1.5768],
        [-3.8295,  6.5372, -1.9109],
        [-2.9784,  4.7791, -1.3609],
        [-3.2620,  5.3985, -1.4424],
        [-2.3328,  3.7880, -1.7265]

Epoch number 4
 Current loss 0.11160293221473694

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.2503,  4.2068, -2.6012],
        [-2.7969,  5.1513, -2.4984],
        [-2.6765,  4.8401, -2.2166],
        [-1.1132, -0.3148,  0.6424],
        [-3.0256,  5.0239, -1.5499],
        [-3.5367,  6.4787, -2.3577],
        [-1.4304, -3.5639,  4.5722],
        [-3.1437,  5.6618, -2.0999],
        [-3.1191,  5.4725, -1.7912],
        [-3.6932,  7.1433, -2.9996],
        [ 0.6693, -3.7786,  0.7676],
        [-1.1685, -3.8199,  4.4716],
        [-3.0050,  5.4571, -2.3731],
        [-3.2573,  6.0301, -2.3847],
        [-3.6927,  7.0986, -2.9569],
        [-3.5774,  7.0035, -3.0760],
        [-3.3548,  4.8749, -0.7743],
        [-1.3950, -2.7196,  3.7544],
        [-3.6214,  6.7026, -2.6059],
        [-3.3555,  6.5100, -2.8555],
        [ 0.8259, -4.1065,  0.4676],
        [-2.5656,  3.6944, -1.0531],
        [-2.4131,  2.2550,  0.1522],
        [-3.0762,  5.4257, -1.8176]

Epoch number 4
 Current loss 0.1191234290599823

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.4677, -3.8545,  3.2188],
        [-2.6403,  4.1685, -1.4523],
        [-3.2456,  5.9601, -2.4365],
        [-3.5501,  6.8240, -2.8409],
        [-3.5936,  6.5950, -2.4132],
        [-1.5913,  2.7354, -1.5641],
        [-3.1227,  5.4103, -2.0622],
        [-3.2722,  5.8984, -2.2789],
        [-1.1690, -3.5331,  4.1590],
        [-0.7684, -3.4315,  3.3582],
        [ 0.9305, -3.8965, -0.0332],
        [-0.3394, -1.8943,  0.5926],
        [-3.1355,  5.6011, -2.1547],
        [-3.5074,  5.0629, -0.9384],
        [-3.6411,  4.6145, -0.1580],
        [-3.4368,  6.5151, -2.5944],
        [-3.3244,  4.7547, -0.7492],
        [-2.9549,  5.5177, -2.6432],
        [-2.9589,  4.8643, -1.4947],
        [-3.9216,  7.5163, -2.9787],
        [-3.3887,  6.0511, -2.1133],
        [-3.2221,  5.8530, -2.3252],
        [-2.4428,  4.4180, -2.2309],
        [-2.8070,  0.1035,  2.9829],

Epoch number 4
 Current loss 0.0991588905453682

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.1253, -3.4430,  4.0585],
        [-3.7879,  7.1438, -2.7824],
        [-1.1068, -2.9988,  3.7248],
        [ 0.4056, -3.4949,  0.6189],
        [-1.0737, -3.7450,  4.2998],
        [-0.7238, -4.0043,  3.9040],
        [-2.5100,  3.9755, -1.5688],
        [-3.3147,  6.0883, -2.3565],
        [-1.4660, -2.7643,  3.8469],
        [-2.9766,  4.5110, -1.1890],
        [-1.2904, -3.8787,  4.6712],
        [ 0.0940, -3.6393,  1.7708],
        [-0.7915, -2.2429,  1.8499],
        [-3.3311,  6.4468, -3.0448],
        [ 0.1857, -3.5797,  1.4482],
        [-3.6646,  6.0149, -1.6933],
        [-2.6269,  4.9196, -2.6262],
        [-3.6001,  6.9483, -2.9284],
        [-3.2926,  5.6348, -1.9324],
        [-1.0661, -2.1944,  2.4279],
        [-1.8086, -2.8683,  4.4419],
        [ 0.3905, -3.4178,  0.8262],
        [-3.2156,  5.6334, -1.9177],
        [-3.3964,  3.9547,  0.0984],

Epoch number 5
 Current loss 0.08045840263366699

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.1298, -2.1465,  4.1042],
        [-3.7533,  6.4129, -2.0026],
        [-3.3642,  6.1614, -2.3174],
        [-2.8714,  4.6179, -1.5017],
        [-3.5942,  6.4775, -2.3916],
        [-1.7292, -3.1233,  4.5140],
        [-3.8384,  7.0872, -2.6863],
        [-2.0492,  2.9574, -1.4426],
        [-3.7812,  7.1415, -2.9661],
        [-2.0038,  2.2164, -0.4801],
        [-3.5979,  6.8335, -2.8245],
        [-3.9513,  6.5073, -1.7414],
        [-2.6176,  3.3779, -0.4195],
        [-3.8190,  7.3400, -2.9405],
        [ 0.1342, -3.0950,  0.7558],
        [-3.5113,  6.0062, -1.9189],
        [-3.5917,  6.7027, -2.6467],
        [-0.2633, -3.8481,  2.9489],
        [-0.7109, -3.4336,  3.1920],
        [-3.2661,  5.8671, -2.1490],
        [ 1.0095, -4.0476, -0.0262],
        [-2.6326,  3.2585, -0.1432],
        [-0.9567, -3.4791,  3.8033],
        [-2.0094, -1.5092,  3.4639]

Epoch number 5
 Current loss 0.07943509519100189

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.9911,  4.3206, -1.0130],
        [-1.0471, -3.5228,  3.9528],
        [-0.2569, -0.0479, -1.6751],
        [-3.1928,  5.7349, -2.3139],
        [-3.7693,  5.7862, -1.1797],
        [-3.6763,  6.4889, -2.3116],
        [-1.6109, -3.0768,  4.3138],
        [-1.5359, -2.2638,  3.6007],
        [-3.1160,  5.4921, -2.0173],
        [-3.7500,  7.0807, -2.8949],
        [-3.3659,  3.4590,  0.5516],
        [-1.5744, -3.3063,  4.4884],
        [-3.8225,  7.3094, -2.9360],
        [-3.1601,  6.0549, -2.6778],
        [-0.8655, -3.2203,  3.2629],
        [ 1.0395, -3.4684, -0.7567],
        [-3.7818,  7.1279, -2.7720],
        [-0.0340, -1.8018, -0.1509],
        [-1.7530, -2.2961,  3.8245],
        [ 0.2039, -3.3106,  0.9520],
        [-3.7865,  7.1928, -2.9897],
        [-1.9467, -3.1752,  4.7940],
        [-3.4546,  5.5544, -1.5609],
        [-1.5497, -2.9554,  4.1800]

Epoch number 5
 Current loss 0.059058379381895065

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.1301, -2.7935,  3.3635],
        [-3.8678,  7.2519, -2.8416],
        [-3.9661,  7.1882, -2.6321],
        [-3.1506,  5.4466, -1.9937],
        [ 0.4666, -3.7663,  1.0772],
        [-3.7820,  6.6017, -2.2900],
        [-2.7600,  4.7596, -1.8670],
        [-3.0333,  5.6349, -2.6503],
        [-3.5471,  6.5513, -2.5285],
        [-3.2429,  4.5304, -0.6288],
        [-3.8949,  7.1373, -2.6954],
        [-3.4170,  6.3777, -2.6123],
        [-3.5019,  6.6764, -2.8418],
        [-3.9204,  5.6397, -0.8614],
        [-1.3592, -3.0424,  4.0064],
        [ 0.9868, -4.3147,  0.1391],
        [-3.6234,  6.8105, -2.8017],
        [-3.1495,  4.0431, -0.3077],
        [-3.8918,  7.2107, -2.6942],
        [-3.3998,  5.9282, -1.9173],
        [-3.5460,  6.5721, -2.5809],
        [-1.8930, -3.2838,  4.8433],
        [-2.4848,  4.6454, -2.4707],
        [-3.4867,  4.6997, -0.5563

Epoch number 5
 Current loss 0.08689722418785095

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.1317,  5.2834, -1.8807],
        [ 0.8152, -3.5901,  0.1145],
        [-3.4708,  6.5999, -2.8529],
        [-3.9339,  7.5877, -3.1362],
        [-3.3263,  5.2958, -1.4037],
        [-1.8863, -1.2981,  3.1052],
        [-2.4273,  3.2926, -1.0867],
        [-3.3695,  6.4535, -2.8111],
        [-1.8870, -3.5092,  5.0267],
        [-1.3245, -3.5453,  4.3838],
        [-3.6543,  7.0938, -3.1129],
        [-0.3158, -3.7329,  2.7836],
        [-2.7506,  5.1330, -2.8650],
        [-1.6305, -0.3913,  1.5418],
        [-3.0973,  5.6908, -2.5650],
        [-0.6642, -3.6237,  3.3824],
        [ 0.1129, -2.9688,  0.8978],
        [-3.5346,  6.4288, -2.5280],
        [-2.3150,  3.4503, -0.8879],
        [-1.0240, -3.2396,  3.6510],
        [-3.7085,  6.7300, -2.5792],
        [-3.3396,  6.0958, -2.4756],
        [-3.4587,  6.0598, -2.1043],
        [-2.3932,  3.7648, -1.7082]

Epoch number 5
 Current loss 0.045227546244859695

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.8160,  3.7926, -0.6009],
        [-4.0910,  7.1207, -2.2477],
        [-4.0476,  7.1077, -2.3617],
        [-3.9349,  7.1992, -2.6884],
        [-0.1009, -3.6474,  2.0450],
        [-2.3112,  3.4707, -1.3788],
        [ 0.9570, -3.5886, -0.1184],
        [-0.0093, -3.5084,  1.7053],
        [-3.9442,  7.5644, -3.1187],
        [-1.1015, -3.5512,  4.0040],
        [-3.1849,  5.7824, -2.5082],
        [-3.4526,  5.6733, -1.6981],
        [ 1.1000, -3.9534, -0.0343],
        [-1.8756, -2.6745,  4.2840],
        [-3.3461,  6.1350, -2.5740],
        [-3.4917,  6.2454, -2.2809],
        [-0.8975, -3.4921,  3.6481],
        [ 1.0496, -4.1815, -0.0507],
        [-3.1372,  5.3430, -1.7381],
        [-3.8731,  7.4699, -3.1508],
        [-2.0662, -2.0284,  3.9783],
        [-0.9360, -3.2636,  3.6320],
        [-2.9620,  5.1537, -2.0677],
        [-3.5499,  6.0699, -1.9488

Epoch number 5
 Current loss 0.11973035335540771

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.0130,  5.5903, -2.8412],
        [-3.5848,  6.8407, -2.9123],
        [-3.3218,  6.1955, -2.6527],
        [-3.7292,  7.1655, -3.0842],
        [-3.6189,  6.9030, -2.9841],
        [-1.0118, -2.1449,  2.3746],
        [-3.3509,  6.2751, -2.8649],
        [-3.8387,  7.3335, -3.0554],
        [-3.2636,  5.5793, -1.8006],
        [-1.3262, -2.8381,  3.7981],
        [-3.3093,  6.0023, -2.3193],
        [-3.8728,  7.5961, -3.3289],
        [-4.0454,  7.6646, -3.0813],
        [-3.6210,  6.5340, -2.4644],
        [-2.9961,  5.3684, -2.5278],
        [ 0.9413, -3.8068, -0.2950],
        [-3.3530,  6.3526, -2.9725],
        [-3.4255,  5.5287, -1.6654],
        [-3.4445,  6.2541, -2.4169],
        [-3.3928,  6.1330, -2.4262],
        [-3.9068,  7.0801, -2.5753],
        [-3.3417,  6.2470, -2.7624],
        [-1.7837, -3.8559,  5.2406],
        [-3.7911,  7.2229, -3.0185]

Epoch number 5
 Current loss 0.11381065100431442

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.5406, -4.0001,  1.2762],
        [-3.5172,  6.7496, -2.9684],
        [-2.6571,  4.0837, -1.5063],
        [-3.8486,  6.3610, -1.8923],
        [-0.7347, -3.8027,  3.6299],
        [-3.4076,  6.0701, -2.4520],
        [-0.8913, -4.0487,  4.2182],
        [-3.5281,  6.3722, -2.5546],
        [-3.6493,  6.6738, -2.7064],
        [-1.8565,  3.2754, -2.1406],
        [-3.0368,  5.7272, -2.9657],
        [-1.0231, -3.0481,  3.2890],
        [-3.5490,  6.7237, -2.8963],
        [-3.4970,  6.5484, -2.7334],
        [-2.8533,  5.3348, -2.7892],
        [-0.8945, -2.4561,  2.6490],
        [-3.5081,  6.8092, -3.1064],
        [-3.3905,  6.2569, -2.6564],
        [-3.6461,  7.1112, -3.2880],
        [-2.9247,  5.2004, -2.4857],
        [ 0.2633, -2.9255,  0.4808],
        [-3.3606,  5.8977, -2.1805],
        [-3.8902,  7.4734, -3.1296],
        [-3.1380,  5.6652, -2.4955]

Epoch number 5
 Current loss 0.07962318509817123

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.0896,  5.9321, -3.0488],
        [-3.4184,  5.2832, -1.4190],
        [-1.0397, -3.2672,  3.6128],
        [ 1.0753, -4.0510, -0.2198],
        [-3.1456,  5.7572, -2.6611],
        [ 0.7114, -3.9997,  0.5851],
        [-2.0004, -1.1965,  3.1670],
        [-0.0721, -0.2497, -2.5673],
        [-1.0000, -3.6068,  4.0104],
        [-1.2058, -2.2430,  2.9748],
        [-2.3547,  4.1055, -2.3957],
        [-1.5498, -3.8085,  4.9217],
        [-1.2399, -1.6321,  2.2244],
        [-3.2574,  5.4818, -1.9532],
        [-2.0691, -3.2990,  5.0748],
        [-3.1052,  5.5250, -2.4950],
        [-2.8177,  4.3494, -1.4768],
        [ 0.9594, -4.0406,  0.0208],
        [-3.5179,  6.5652, -2.7959],
        [-0.5407, -3.1037,  2.3689],
        [-1.0552, -0.9058,  1.3233],
        [-3.9330,  7.0706, -2.5276],
        [-3.2234,  6.0290, -2.8825],
        [-3.8952,  7.3124, -3.0274]

Epoch number 5
 Current loss 0.10955072939395905

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.8877,  7.5779, -3.3460],
        [-3.9292,  7.4755, -3.1730],
        [-3.8467,  7.4281, -3.1945],
        [ 1.0910, -3.6694, -0.4563],
        [-3.7564,  7.2718, -3.2062],
        [ 0.7340, -4.3476,  1.0593],
        [-3.8046,  7.3323, -3.1778],
        [-4.0139,  7.3580, -2.7740],
        [-3.7719,  6.5907, -2.3720],
        [-3.3476,  6.3302, -3.1234],
        [-3.0696,  5.6834, -2.4820],
        [ 0.8934, -4.0405,  0.3085],
        [-0.3860, -3.1453,  2.1228],
        [-3.2380,  5.6492, -2.1437],
        [-1.7853, -3.3701,  4.7629],
        [-2.0334, -1.6660,  3.6306],
        [-3.9387,  6.3574, -1.7129],
        [-0.0510, -2.6674,  0.6951],
        [-3.5232,  6.7363, -3.0305],
        [-3.7499,  6.2486, -1.9911],
        [ 0.4601, -4.1362,  1.5564],
        [-2.6578,  3.8985, -1.3210],
        [-1.3453, -3.3799,  4.3268],
        [-1.6345, -3.8305,  4.9742]

Epoch number 5
 Current loss 0.1286153495311737

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.0509,  7.1105, -2.4306],
        [-2.8382,  4.7085, -1.9507],
        [-1.4908, -4.1483,  5.1128],
        [-3.7930,  6.9997, -2.8587],
        [-3.5386,  5.5254, -1.4708],
        [-3.9475,  7.5888, -3.1873],
        [-1.5340, -3.4940,  4.6054],
        [-3.7948,  6.9647, -2.7215],
        [-0.7425, -4.1280,  3.9699],
        [-3.6686,  6.5545, -2.6169],
        [-3.8380,  7.1513, -2.9592],
        [-1.8902, -3.4879,  5.0017],
        [-3.8364,  7.1249, -2.9254],
        [ 0.6824, -2.8945, -0.7764],
        [-2.7662,  3.2496, -0.0061],
        [-0.9712, -3.3992,  3.6103],
        [ 1.0753, -3.4813, -0.9572],
        [-2.0862,  3.7600, -2.7257],
        [-3.4165,  5.6578, -1.8870],
        [-0.8922, -1.1538,  0.9195],
        [-1.8788,  3.4799, -2.2712],
        [-0.8158, -3.8581,  3.9601],
        [-1.0423, -2.9811,  3.3371],
        [-3.0793,  1.0965,  2.3228],

Epoch number 5
 Current loss 0.12239497900009155

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.9106,  7.5328, -3.3117],
        [-4.0275,  7.6170, -3.1135],
        [-3.9696,  7.1811, -2.7146],
        [-1.1263, -0.4284,  0.8801],
        [-3.7969,  7.1659, -3.0198],
        [-3.3253,  5.9915, -2.5690],
        [-1.4421,  1.8546, -1.5167],
        [-0.8835, -2.5900,  2.4563],
        [-0.7535, -3.8352,  3.8562],
        [-3.8253,  7.2890, -3.1334],
        [-3.6500,  6.3338, -2.1865],
        [-3.6115,  6.7418, -2.9445],
        [ 1.0804, -4.4084,  0.0487],
        [-2.6737,  2.7198,  0.1025],
        [-3.7680,  7.0884, -2.9672],
        [-3.9954,  7.4912, -3.0995],
        [-3.2830,  5.7146, -2.1134],
        [-1.4067, -4.2392,  5.1169],
        [-1.8642,  0.1888,  1.6295],
        [-3.9083,  7.3386, -3.0506],
        [-3.9383,  7.4381, -3.1441],
        [-1.8606,  0.4870,  0.8559],
        [-3.7159,  6.9303, -2.9382],
        [-2.0632,  4.0114, -3.1912]

Epoch number 5
 Current loss 0.05022991821169853

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.7310, -2.6347,  4.0530],
        [-0.9885, -4.4279,  4.7319],
        [-3.2052,  5.7263, -2.3979],
        [-3.9090,  7.5618, -3.3896],
        [-4.0720,  7.5151, -2.9314],
        [-3.0761,  5.0538, -1.7422],
        [-3.3960,  5.4594, -1.4148],
        [-3.2352,  5.7707, -2.3057],
        [-2.7881,  4.1541, -1.0121],
        [-1.2619, -3.4020,  4.1380],
        [-3.8897,  7.5169, -3.3053],
        [-3.7757,  6.7229, -2.4472],
        [ 0.5903, -3.4427,  0.3661],
        [-3.7491,  7.0139, -2.9626],
        [ 0.5303, -3.3653,  0.3350],
        [-1.8986, -3.1744,  4.7692],
        [-3.8140,  7.3559, -3.2082],
        [ 0.8945, -4.0718,  0.5240],
        [-3.7500,  7.1749, -3.2076],
        [-3.0049,  5.4446, -2.6862],
        [-3.9086,  7.4548, -3.2012],
        [ 0.9578, -4.2449,  0.2494],
        [-1.7963,  1.4421, -0.6706],
        [-3.0510,  5.2037, -2.0885]

Epoch number 5
 Current loss 0.06254537403583527

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.6204, -2.2598,  3.5785],
        [-3.9580,  7.5992, -3.2443],
        [ 0.9302, -3.0850, -0.9879],
        [-3.8050,  7.1718, -3.0793],
        [-3.8923,  7.1354, -2.7902],
        [-3.2670,  5.6001, -2.2222],
        [-1.2113, -3.3607,  4.0612],
        [-0.9087, -3.3607,  3.5532],
        [-3.4599,  5.5745, -1.5690],
        [-3.9497,  7.4692, -3.1022],
        [-3.4261,  6.0095, -2.3436],
        [-1.3491, -3.8086,  4.6353],
        [-2.3990,  4.0144, -2.2460],
        [-3.6737,  6.8945, -3.0548],
        [ 0.9438, -4.7249,  0.7894],
        [-1.9985,  3.7257, -2.9503],
        [-4.0937,  7.0508, -2.3800],
        [-3.8637,  7.4299, -3.3203],
        [ 0.6875, -3.4568,  0.0413],
        [-0.8278, -2.7239,  2.5051],
        [-3.2148,  4.7128, -1.2189],
        [-3.9399,  7.5394, -3.2137],
        [-3.3124,  5.9393, -2.5420],
        [-3.5797,  6.3894, -2.5804]

Epoch number 5
 Current loss 0.07545655965805054

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6660,  4.9358, -2.8691],
        [-1.0441, -3.6844,  4.0566],
        [-3.8387,  7.1028, -2.9184],
        [-2.8691,  3.7018, -0.3477],
        [-3.6237,  6.6699, -2.7217],
        [-4.0691,  7.1339, -2.4573],
        [-1.2886, -3.8511,  4.6346],
        [-3.6798,  6.8985, -2.9194],
        [-3.7723,  6.6201, -2.3824],
        [-3.0734,  5.5548, -2.7739],
        [-1.2614, -4.0189,  4.7071],
        [-3.5226,  6.4253, -2.6248],
        [-2.6038, -0.6548,  3.2358],
        [-3.4542,  5.9575, -2.1207],
        [-4.0779,  7.1879, -2.5341],
        [-0.6087, -2.8293,  2.1386],
        [-3.4382,  5.5272, -1.7492],
        [-3.9734,  7.1673, -2.6853],
        [-2.8842,  4.4432, -1.0939],
        [-2.7628,  4.5138, -1.8511],
        [-2.4621,  3.2633, -1.0851],
        [-3.5981,  6.3952, -2.4418],
        [-3.2903,  5.4639, -1.8143],
        [-1.7523, -3.9529,  5.2368]

Epoch number 5
 Current loss 0.03778766468167305

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.5041,  6.0873, -2.4269],
        [-2.9887,  5.2467, -2.2120],
        [-2.5459,  4.2544, -2.1796],
        [-0.6344, -3.8501,  3.3242],
        [-3.9053,  7.1530, -2.8432],
        [ 1.1185, -3.6838, -0.6946],
        [-2.7415,  3.2696, -0.4850],
        [-2.3519,  3.7675, -1.7376],
        [ 0.9948, -4.2480,  0.1720],
        [-3.0544,  5.3381, -2.4343],
        [-2.8296,  4.9118, -2.3245],
        [-3.6079,  6.7351, -2.8791],
        [-3.3474,  6.0020, -2.4452],
        [-3.0332,  5.3978, -2.3813],
        [-3.3639,  6.3829, -2.9762],
        [ 0.8980, -2.8920, -1.1995],
        [-3.6209,  7.0607, -3.3962],
        [ 0.9846, -4.2962,  0.1819],
        [-3.4656,  6.3629, -2.7197],
        [-1.7761, -2.9795,  4.4226],
        [-0.3737, -3.7375,  2.6953],
        [-3.3250,  5.7523, -2.1525],
        [ 0.8152, -3.8644,  0.2331],
        [-3.9005,  7.4317, -3.2905]

Epoch number 5
 Current loss 0.01826365292072296

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 1.2169, -4.3623, -0.2227],
        [ 0.8510, -3.8610,  0.1201],
        [-3.5521,  5.8385, -1.8288],
        [-3.9365,  7.5351, -3.2599],
        [-3.6305,  6.7161, -2.8639],
        [-2.5003,  0.5762,  1.9910],
        [-3.9670,  7.6288, -3.3463],
        [-3.3694,  6.0540, -2.5333],
        [-3.8548,  7.3916, -3.3287],
        [-4.1276,  6.6744, -1.8395],
        [ 0.4199, -4.0937,  1.4158],
        [-3.1657,  5.9683, -3.1278],
        [ 0.4047, -2.1166, -1.1421],
        [-2.0089,  2.6291, -1.3828],
        [-2.9697,  4.7934, -1.7709],
        [-2.9652,  1.3896,  1.8888],
        [-1.3054, -3.0652,  3.9124],
        [-3.2763,  5.5616, -2.1663],
        [-2.9244,  5.1193, -2.3235],
        [-1.7087, -1.0649,  2.5206],
        [-1.0618, -3.5605,  3.8903],
        [-3.1749,  3.6301, -0.2047],
        [-2.4369,  3.6796, -1.6388],
        [-3.8183,  6.8336, -2.5349]

Epoch number 5
 Current loss 0.08868871629238129

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.4219, -2.4353,  4.6212],
        [-2.8085,  5.1073, -2.7687],
        [-2.4557,  3.9042, -2.1867],
        [ 0.1664, -3.8782,  1.8750],
        [-3.6628,  4.5575, -0.2866],
        [-3.2084,  5.7599, -2.5290],
        [-2.7135,  4.6366, -2.2276],
        [ 1.0299, -4.6358,  0.4478],
        [-1.1212, -3.7690,  4.1504],
        [-1.2620, -2.9808,  3.6210],
        [ 1.1731, -4.3375,  0.2526],
        [-4.0118,  7.0031, -2.4581],
        [-2.2375,  3.9223, -2.6239],
        [-4.1163,  7.8811, -3.3444],
        [-0.8519, -3.8860,  3.7965],
        [-1.4741, -3.7908,  4.7389],
        [-2.8502,  4.8980, -2.2908],
        [-1.2639, -2.8379,  3.7571],
        [-3.3928,  6.2704, -2.9125],
        [-3.9431,  7.6724, -3.4766],
        [-3.4299,  5.6648, -2.0045],
        [ 1.0965, -3.5736, -0.9034],
        [-3.5678,  6.0433, -2.0452],
        [-0.0473, -3.9971,  2.4442]

Epoch number 5
 Current loss 0.08846452832221985

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.8129,  7.4363, -3.4183],
        [-4.0722,  7.5360, -3.0018],
        [ 1.1186, -4.1922,  0.1881],
        [-1.9382, -3.6853,  5.2004],
        [ 0.2950, -3.8903,  1.7126],
        [-3.2051,  5.3184, -1.9369],
        [-4.0825,  7.9129, -3.4304],
        [-1.0257, -3.9332,  4.2822],
        [-0.8581, -3.6540,  3.6372],
        [ 1.0296, -3.5602, -0.8232],
        [-0.7991, -2.1914,  1.7976],
        [-0.3931, -2.7478,  1.5464],
        [-4.0094,  7.6955, -3.4258],
        [-3.1363,  5.8071, -2.9324],
        [ 0.3828, -4.4758,  2.0035],
        [-1.5207, -3.6180,  4.6390],
        [ 0.5598, -4.0619,  1.0354],
        [ 1.0512, -3.4645, -0.7513],
        [-3.9457,  7.4682, -3.2071],
        [-4.1189,  7.9104, -3.3854],
        [-3.8916,  7.5350, -3.4216],
        [-2.8333,  4.9014, -2.3744],
        [ 0.8913, -4.1743,  0.2234],
        [-1.2079, -3.6773,  4.2937]

Epoch number 5
 Current loss 0.026364831253886223

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.5360,  5.6081, -1.5223],
        [-3.3636,  5.4090, -1.6901],
        [-2.9993,  4.9813, -1.8677],
        [-0.6929, -4.1883,  3.7461],
        [-2.2843,  4.1072, -2.8832],
        [-3.0356,  5.2223, -2.2606],
        [-3.0339,  5.0757, -2.1135],
        [-0.3948, -3.9925,  2.9963],
        [-3.8001,  7.0046, -2.9029],
        [-1.1368, -3.2435,  3.6865],
        [-0.8026, -3.9667,  3.8034],
        [-2.8974,  5.5682, -2.8918],
        [-1.3308, -3.0791,  3.7511],
        [-2.5947,  2.6315,  0.3486],
        [-1.3368, -3.7401,  4.4903],
        [-3.2812,  6.0439, -2.9968],
        [-3.1462,  5.7068, -2.6774],
        [-2.0117, -3.6059,  5.1739],
        [-1.0724, -2.7459,  3.0546],
        [-1.7554,  3.3643, -2.4594],
        [-3.4998,  5.8774, -1.9224],
        [-3.2677,  5.4890, -1.9018],
        [ 1.0518, -4.3346,  0.6154],
        [ 0.3987, -2.6876, -0.4105

Epoch number 5
 Current loss 0.08304902911186218

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.6633,  6.4064, -2.3859],
        [-1.8516, -1.2917,  3.0238],
        [-0.9742, -4.0384,  4.1838],
        [-3.6677,  7.1474, -3.3100],
        [-1.8022, -4.2705,  5.5717],
        [-3.9703,  7.7747, -3.5047],
        [-0.9204, -3.9698,  4.0219],
        [-1.6854, -3.7282,  4.9131],
        [-3.4766,  6.1152, -2.5528],
        [-2.7893,  4.9840, -2.6179],
        [ 1.1917, -4.2806, -0.2747],
        [-3.6145,  6.8945, -3.2737],
        [-3.9660,  7.5288, -3.2844],
        [-1.6958, -4.2140,  5.4133],
        [ 0.1036, -2.8116,  0.9742],
        [-3.3614,  5.9641, -2.3115],
        [-3.3154,  5.8814, -2.6318],
        [-3.6533,  7.1546, -3.5465],
        [-3.3388,  5.9616, -2.6683],
        [-3.6404,  6.3269, -2.3782],
        [-2.9423,  3.1847, -0.2914],
        [-3.8066,  7.2837, -3.3440],
        [-3.0813,  5.5477, -2.7952],
        [-2.9900,  5.3316, -2.6152]

Epoch number 5
 Current loss 0.12362844496965408

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.8162,  5.1958, -2.8533],
        [-3.9670,  7.8848, -3.6709],
        [-0.8520, -4.0492,  3.9708],
        [-3.9379,  7.3879, -3.1045],
        [-3.0002,  5.5886, -2.8222],
        [-1.9243,  2.9126, -1.2159],
        [-3.1847,  5.5816, -2.3865],
        [-0.2231, -2.9520,  1.4485],
        [-4.0827,  7.7857, -3.3625],
        [-3.6071,  6.7256, -2.8870],
        [-3.0302,  3.5009, -0.3047],
        [-1.8684, -3.6781,  5.1378],
        [ 0.6151, -4.0698,  1.2595],
        [-3.0021,  5.6846, -3.2361],
        [-3.9010,  7.5190, -3.4064],
        [-2.4845,  4.3662, -2.4742],
        [-3.4084,  4.6341, -0.9026],
        [ 0.7792, -4.9358,  1.3960],
        [-2.9902,  5.0277, -2.0807],
        [-3.2741,  5.6903, -2.4290],
        [-3.4906,  6.3404, -2.8154],
        [-1.7277,  3.2950, -2.3756],
        [-3.9113,  7.6244, -3.5366],
        [-3.5262,  5.7522, -1.8633]

Epoch number 5
 Current loss 0.04389116168022156

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.9823,  5.3055, -2.5963],
        [ 1.2870, -4.2470, -0.3414],
        [-2.7775,  4.9649, -2.4668],
        [-1.1355, -3.9071,  4.3772],
        [-2.6302,  4.8773, -2.9093],
        [-3.8995,  7.5447, -3.4134],
        [-0.7126,  0.6083, -2.1855],
        [-3.4997,  6.9538, -3.3947],
        [ 0.1485, -4.0919,  1.8642],
        [-1.4409, -3.9698,  4.9015],
        [-0.9425, -2.7612,  2.6945],
        [-3.5737,  5.6689, -1.7440],
        [-0.9193, -3.6851,  3.8863],
        [-3.3100,  6.2728, -3.1671],
        [-3.8273,  6.3986, -2.2373],
        [-2.4330,  2.5460, -0.4159],
        [-3.9778,  7.0209, -2.6133],
        [-3.6365,  6.6374, -2.7307],
        [-3.3317,  5.9359, -2.4934],
        [-2.8277,  4.1206, -1.2114],
        [-1.2088, -3.7164,  4.3622],
        [-1.3616, -3.9216,  4.7307],
        [-2.3798,  1.9522,  0.0730],
        [-3.1738,  4.7315, -1.3073]

Epoch number 5
 Current loss 0.06626251339912415

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.0564,  7.9573, -3.5846],
        [-3.8853,  7.7612, -3.6865],
        [ 1.1802, -3.9582, -0.3243],
        [-3.7444,  7.0551, -3.0935],
        [ 1.2630, -4.5784,  0.2152],
        [-3.8957,  7.7166, -3.6353],
        [ 0.9095, -3.4551, -0.2735],
        [-4.0751,  7.8653, -3.4165],
        [-3.8027,  7.2592, -3.1910],
        [-3.6117,  6.7777, -2.9767],
        [-3.1938,  5.9468, -2.9678],
        [-1.7714, -0.7959,  2.2576],
        [ 1.2390, -4.5264, -0.0064],
        [-2.3852,  4.1606, -2.5037],
        [-2.9865,  5.4299, -2.7694],
        [-2.1117, -2.5899,  4.3836],
        [-3.5544,  5.2762, -1.2397],
        [-3.2954,  5.0281, -1.4679],
        [-0.3729, -4.3600,  3.4821],
        [-1.4762, -3.3898,  4.3257],
        [-2.2656,  3.7042, -2.0697],
        [-1.4620, -3.2815,  4.3432],
        [-0.8453, -3.5430,  3.4988],
        [-3.7996,  6.5970, -2.4785]

Epoch number 5
 Current loss 0.07273358106613159

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.8568,  7.3803, -3.2580],
        [-2.4549,  4.3956, -2.5832],
        [-3.0661,  5.7104, -2.9628],
        [-2.4640,  4.7253, -2.9957],
        [ 1.2170, -3.9558, -0.6285],
        [-3.9800,  8.0038, -3.8532],
        [-1.1781, -3.7948,  4.4708],
        [ 0.7563, -3.9410,  0.7791],
        [-2.9401,  5.2355, -2.7898],
        [-3.0151,  5.5554, -2.9670],
        [-2.9315,  5.5376, -3.1054],
        [-3.1420,  5.7478, -2.9696],
        [-2.8790,  5.0322, -2.3961],
        [-0.8485, -3.8972,  3.9558],
        [-3.1405,  5.2006, -1.9623],
        [-3.4752,  5.8914, -2.3840],
        [-3.3988,  6.3833, -3.1956],
        [-2.9916,  5.1302, -2.1934],
        [-3.8370,  7.5952, -3.6873],
        [-3.7559,  7.5396, -3.8280],
        [-3.6942,  7.2975, -3.6344],
        [-3.7486,  7.0173, -3.1081],
        [-3.9968,  7.3684, -2.9805],
        [-1.9424,  3.8624, -2.9152]

Epoch number 5
 Current loss 0.04888784512877464

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.5696, -0.7689,  3.3516],
        [-1.5744, -3.8948,  4.9605],
        [-1.4766, -3.4594,  4.4251],
        [-1.6778, -4.1476,  5.2939],
        [-2.9959,  5.1336, -2.2023],
        [-1.4002, -3.9544,  4.8221],
        [ 1.0992, -3.9512, -0.6384],
        [-3.2812,  6.3126, -3.3893],
        [-3.4002,  6.2840, -2.7964],
        [-4.1419,  8.1025, -3.6404],
        [-1.3071, -3.7389,  4.3924],
        [-2.8949,  5.4844, -3.2533],
        [-4.0429,  7.7107, -3.3113],
        [-2.9483,  5.9465, -3.7739],
        [-3.7314,  7.0827, -3.2349],
        [-2.6387,  4.1901, -1.8591],
        [ 0.5814, -3.9140,  1.0653],
        [-4.0230,  7.9013, -3.6525],
        [-3.2363,  5.9466, -2.8257],
        [-3.1780,  5.4999, -2.3092],
        [-4.0794,  7.6946, -3.2214],
        [-3.1948,  6.1973, -3.6210],
        [ 1.0249, -3.3019, -1.2714],
        [ 1.1454, -4.6664,  0.1446]

Epoch number 5
 Current loss 0.06473463773727417

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.9634, -2.2557,  3.9211],
        [-3.7580,  6.5651, -2.5529],
        [-3.3081,  5.7481, -2.1791],
        [-3.9630,  7.8602, -3.6782],
        [ 1.1381, -3.5417, -0.7862],
        [-2.6555,  4.7352, -2.4476],
        [-3.2859,  6.3853, -3.5915],
        [ 0.2390, -4.5753,  2.2707],
        [-3.6369,  6.7319, -2.9339],
        [ 0.5693, -4.5459,  1.7482],
        [-3.7032,  7.3239, -3.6284],
        [-3.9819,  7.9880, -3.8239],
        [-0.6454, -4.3778,  4.1514],
        [-3.0229,  5.2298, -2.3605],
        [-3.0694,  5.6760, -3.0124],
        [-3.9546,  7.9506, -3.8435],
        [-4.0099,  7.7331, -3.4082],
        [-3.2290,  5.5550, -2.2433],
        [-0.3265, -3.4611,  2.3026],
        [-3.9187,  7.6416, -3.4638],
        [-1.5353, -3.6426,  4.6735],
        [-3.9495,  7.8725, -3.7620],
        [-4.0178,  7.8528, -3.6217],
        [-3.2397,  6.2064, -3.0526]

Epoch number 5
 Current loss 0.08809837698936462

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 1.1654, -3.8122, -0.3788],
        [-2.9162,  3.6970, -0.5661],
        [-0.8172, -4.1737,  4.1311],
        [-1.6251, -3.2415,  4.4498],
        [-3.2684,  5.2821, -1.6277],
        [-3.9712,  7.9361, -3.7709],
        [-3.6820,  7.0657, -3.2354],
        [ 1.1938, -4.6852,  0.1050],
        [-0.4860, -4.6136,  3.7506],
        [-1.9509, -3.6572,  5.1672],
        [-4.0833,  7.1331, -2.5529],
        [-3.2817,  5.2638, -1.8261],
        [-1.1320,  2.5471, -2.8505],
        [ 1.1490, -4.9223,  0.2435],
        [-3.8088,  7.4523, -3.5042],
        [-3.6456,  7.3169, -3.6702],
        [-1.5557, -2.9337,  4.1738],
        [-2.0834, -1.6694,  3.5821],
        [-3.9539,  7.7945, -3.5659],
        [-3.4702,  6.7770, -3.2405],
        [-2.0891, -2.9961,  4.7775],
        [-3.1700,  5.6721, -2.6499],
        [-3.6970,  7.0667, -3.1280],
        [-2.3671,  2.6958, -0.3086]

Epoch number 5
 Current loss 0.12928158044815063

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.0753,  5.1695, -1.9763],
        [-1.9108,  3.7512, -2.9754],
        [-0.2100, -4.2142,  2.9234],
        [-3.5282,  6.7763, -3.2274],
        [-3.3501,  5.1074, -1.4717],
        [-2.2096, -3.7131,  5.5279],
        [-2.0656,  4.0042, -3.1828],
        [-3.7815,  7.5727, -3.6936],
        [-2.7483,  5.6662, -3.5986],
        [-0.3125, -1.9918,  0.6591],
        [-1.6842,  2.5622, -1.3301],
        [-3.1060,  5.8565, -3.0110],
        [-4.0218,  7.5033, -3.0907],
        [-3.4457,  6.2076, -2.5980],
        [-2.0461, -2.7825,  4.5686],
        [-2.9353,  5.5994, -3.1184],
        [-3.4626,  5.7114, -1.9747],
        [-3.0409,  4.9562, -1.8839],
        [-3.8064,  7.4492, -3.4585],
        [-0.8443, -3.7405,  3.8240],
        [-3.7975,  7.7558, -3.8723],
        [-1.4819, -3.0950,  4.1580],
        [-3.3717,  6.0425, -2.3493],
        [-2.9687,  2.0285,  1.2355]

Epoch number 5
 Current loss 0.10241258889436722

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.3976, -4.4490,  1.7968],
        [-1.0237, -3.3835,  3.5980],
        [-1.4858, -3.0154,  4.1303],
        [-0.7143, -3.2951,  2.8391],
        [-2.8430,  4.9500, -2.1672],
        [ 1.2283, -4.0236, -0.9113],
        [-2.7111,  4.9362, -2.6150],
        [-3.6742,  6.9412, -3.0702],
        [-4.0433,  6.7257, -1.9119],
        [-2.0188, -1.5170,  3.4318],
        [ 0.3685, -4.1461,  1.4771],
        [ 1.2017, -4.3800, -0.1507],
        [-2.4082,  2.8617, -0.8268],
        [-3.8636,  7.9147, -3.9163],
        [-1.3598, -4.3894,  5.1628],
        [-2.0681, -0.1567,  2.3211],
        [-3.4547,  5.8875, -1.9174],
        [-3.6153,  7.3101, -3.6744],
        [-3.6455,  7.2998, -3.4921],
        [-3.0195,  5.4685, -2.6696],
        [-2.8077,  5.2135, -2.8813],
        [-3.3119,  5.0308, -1.3718],
        [-1.0990, -3.8898,  4.2970],
        [-3.7283,  6.9877, -2.9560]

Epoch number 5
 Current loss 0.16106747090816498

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.0856, -1.5262,  1.6827],
        [-1.8749, -3.9855,  5.3802],
        [-3.9456,  7.9541, -3.8591],
        [-2.8624,  6.0888, -3.8546],
        [ 1.1957, -4.6308, -0.1384],
        [-3.2258,  5.2308, -1.9295],
        [-3.7279,  7.4927, -3.7255],
        [-1.8372, -0.4130,  2.2821],
        [-3.1501,  5.2987, -1.9383],
        [-2.6053,  3.5282, -1.2054],
        [-2.8339,  5.5343, -3.3791],
        [-3.2352,  5.3155, -2.0534],
        [-1.8383,  2.8654, -2.1692],
        [-1.7970, -3.5392,  4.9229],
        [-3.5978,  7.0377, -3.2628],
        [-3.4290,  6.5687, -3.1418],
        [ 1.1679, -4.0048, -0.6254],
        [-3.9574,  7.9211, -3.6808],
        [-2.6944,  3.7371, -1.2245],
        [-3.0310,  5.9109, -3.1876],
        [-0.4387, -4.0092,  3.1082],
        [-3.0236,  4.8036, -1.7065],
        [-2.8993,  5.5590, -2.4810],
        [-3.3618,  6.1551, -2.7289]

Epoch number 5
 Current loss 0.05039665102958679

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.4059,  6.4477, -3.0410],
        [-3.5593,  7.3722, -3.9076],
        [-3.4490,  7.2065, -3.9048],
        [-3.5246,  6.8095, -3.1807],
        [-1.4899, -2.1805,  3.2200],
        [-1.6095,  3.1322, -2.7154],
        [ 1.0040, -4.1326,  0.0514],
        [-3.0739,  5.3809, -2.1815],
        [ 1.1427, -4.5618,  0.3238],
        [-3.0768,  5.8199, -2.9380],
        [ 0.5965, -4.6211,  1.4352],
        [-4.0556,  7.6550, -3.2322],
        [-2.7177,  3.7384, -1.1901],
        [-3.3764,  6.8061, -3.7053],
        [-3.6635,  7.3701, -3.7133],
        [-3.5663,  6.6078, -2.9326],
        [-1.5204, -4.2576,  5.2257],
        [ 0.8947, -3.1861, -0.7649],
        [-1.5994,  3.1009, -2.1663],
        [-1.4189, -3.2275,  4.0945],
        [-3.7883,  6.9802, -2.9298],
        [-3.0116,  5.5689, -2.6586],
        [-3.8037,  7.4456, -3.4290],
        [-0.8984, -3.5913,  3.7030]

Epoch number 5
 Current loss 0.13050222396850586

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.6451, -2.0374,  1.1479],
        [-3.0335,  5.9956, -3.5551],
        [-3.3826,  6.1867, -2.7449],
        [ 1.1171, -4.4588,  0.1053],
        [-2.2664, -0.1390,  2.4708],
        [-2.9400,  5.3726, -2.3297],
        [-2.4094,  5.0600, -3.5534],
        [ 0.9663, -4.5662,  0.6102],
        [-3.2454,  6.0395, -2.8393],
        [-3.8850,  7.4652, -3.3552],
        [-3.8466,  7.8445, -3.8715],
        [ 0.9444, -3.8162, -0.1182],
        [-4.0811,  8.1406, -3.7864],
        [-2.8373,  5.3140, -2.8608],
        [-3.9027,  7.8573, -3.8374],
        [-3.9526,  7.9234, -3.8150],
        [-1.5341, -2.3645,  3.6102],
        [-3.2105,  5.9665, -2.8283],
        [-3.6751,  7.2415, -3.4817],
        [ 1.2142, -4.3983, -0.0686],
        [-0.7883, -3.8541,  3.6067],
        [-2.6703,  4.5377, -2.2557],
        [-3.5953,  6.4476, -2.7090],
        [-3.8721,  7.8697, -3.8969]

Epoch number 5
 Current loss 0.09221825748682022

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.7772, -3.2868,  4.6655],
        [-3.4655,  6.7340, -3.1963],
        [-3.7413,  7.5582, -3.7950],
        [-2.8839,  5.0787, -2.0519],
        [-1.8662, -2.2267,  3.9150],
        [-1.5652, -2.5447,  3.7415],
        [-1.8331, -2.0475,  3.6460],
        [-0.5801, -3.7028,  3.0968],
        [-3.4810,  6.7542, -3.3843],
        [-3.7949,  7.7120, -3.8917],
        [-4.0595,  7.6888, -3.2999],
        [-2.6230,  3.7031, -1.2622],
        [-3.5424,  6.7997, -3.0625],
        [-3.3800,  6.6197, -3.2713],
        [-3.0245,  5.6824, -2.9186],
        [-3.7149,  7.5054, -3.7517],
        [-3.7908,  6.4442, -2.1403],
        [ 0.2650, -4.3454,  2.0387],
        [-1.2802, -3.3488,  4.0807],
        [-3.3504,  6.0894, -2.4943],
        [-1.5526, -3.6326,  4.6850],
        [-3.1274,  5.9976, -3.3804],
        [-2.5613,  5.3809, -3.5643],
        [-1.6745, -3.2098,  4.5070]

Epoch number 5
 Current loss 0.08419680595397949

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.5036,  6.9940, -3.5567],
        [ 0.1364, -4.4011,  2.4303],
        [-2.5978,  3.8202, -1.3932],
        [-3.7165,  7.6268, -3.9343],
        [-2.6401,  3.8986, -1.0527],
        [-1.9600, -3.8560,  5.3851],
        [-2.6978,  5.2320, -3.2571],
        [-3.0573,  5.8456, -3.1391],
        [-3.3219,  6.1774, -2.8399],
        [-2.1777,  4.0934, -2.8429],
        [-3.7806,  7.5611, -3.7394],
        [-1.0692, -3.5625,  4.0272],
        [-3.9879,  5.2626, -0.6445],
        [-3.3691,  6.4388, -3.2646],
        [-3.9020,  7.8949, -3.8962],
        [-3.2696,  5.5367, -1.9437],
        [-3.9143,  7.6775, -3.5813],
        [-3.6894,  6.6111, -2.6701],
        [-3.7302,  7.3064, -3.5764],
        [-1.8482, -2.1982,  3.8269],
        [ 0.8297, -4.4624,  0.7124],
        [ 0.8504, -4.0462,  0.7139],
        [-2.0265, -3.5377,  5.1444],
        [-2.9861,  5.5260, -2.7739]

Epoch number 5
 Current loss 0.12223301827907562

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.3092,  2.0496, -1.5787],
        [-3.2797,  6.3134, -3.3057],
        [-3.9419,  7.6868, -3.5371],
        [ 0.9214, -2.9913, -0.8654],
        [-3.4020,  6.3448, -2.9315],
        [-3.8774,  7.6943, -3.7374],
        [ 0.7462, -4.5436,  1.5929],
        [-3.6192,  6.6544, -2.9274],
        [-3.0553,  3.9154, -0.6216],
        [-1.3247,  3.0262, -3.5919],
        [-3.8381,  7.4283, -3.3864],
        [-2.5924,  5.0197, -3.1413],
        [-3.5985,  6.8799, -3.2046],
        [-3.0273,  5.4002, -2.3390],
        [-2.5828,  4.7403, -2.6448],
        [-0.8874,  1.5016, -1.9915],
        [-3.2763,  5.1617, -1.6582],
        [-3.4165,  6.4105, -2.9911],
        [-3.0331,  5.1445, -1.9219],
        [-3.2273,  6.0184, -2.7729],
        [-3.9796,  7.7309, -3.5544],
        [-3.2721,  6.1548, -2.7524],
        [-3.8191,  6.8716, -2.6837],
        [-4.0664,  7.8114, -3.4721]

Epoch number 5
 Current loss 0.10795716941356659

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.3897,  6.5181, -3.2648],
        [-1.4050,  1.2175, -1.1235],
        [-4.0422,  6.6823, -1.9629],
        [-3.4729,  6.5013, -2.8968],
        [-0.9401,  1.9728, -3.2192],
        [-0.2658, -4.2117,  3.0375],
        [-1.6305, -4.0204,  5.0750],
        [-3.5876,  6.7363, -3.0583],
        [-2.0115, -3.8397,  5.3874],
        [-4.0513,  7.9748, -3.7200],
        [-3.2835,  6.2705, -3.0722],
        [-2.6694,  4.6203, -2.3659],
        [-3.9260,  7.5863, -3.4596],
        [-3.7639,  5.7777, -1.3295],
        [-1.2524, -3.2718,  4.0131],
        [-3.8013,  7.3781, -3.4121],
        [-1.3050,  2.8455, -3.4804],
        [-0.5500, -4.1846,  3.5970],
        [-3.8823,  7.7149, -3.8045],
        [-3.3335,  5.9556, -2.2902],
        [ 0.1653, -2.8400,  0.7943],
        [-3.3494,  6.3076, -3.0734],
        [-1.3085,  2.6989, -3.3583],
        [-2.7240,  5.1465, -2.9909]

Epoch number 5
 Current loss 0.05846448615193367

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 1.1105, -4.5274,  0.0560],
        [-2.5987, -0.8498,  3.4087],
        [-2.9648,  5.6205, -2.9576],
        [ 0.4499, -2.7531, -0.3135],
        [-3.7664,  7.4270, -3.5872],
        [-2.9707,  4.3323, -1.2356],
        [-3.9571,  6.1644, -1.6946],
        [ 0.7439, -4.6459,  1.2105],
        [-0.1338, -4.0552,  2.7301],
        [-2.9402,  5.4213, -2.8924],
        [-1.4167, -3.2805,  4.2033],
        [-3.2858,  6.3043, -3.1570],
        [-2.6706,  5.0079, -2.9868],
        [-1.7292, -4.1862,  5.4296],
        [-0.4733, -4.2399,  3.5377],
        [-3.4224,  6.1302, -2.4465],
        [ 1.0985, -4.5615,  0.1497],
        [-1.6277, -3.8998,  5.0021],
        [-3.9679,  7.8064, -3.7300],
        [-0.0530, -3.2001,  1.4002],
        [-1.2428, -3.7610,  4.4674],
        [-3.4360,  5.9218, -2.3381],
        [-1.9061, -3.4286,  4.9266],
        [-3.3624,  5.4330, -1.7941]

Epoch number 5
 Current loss 0.08475644141435623

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.0276,  5.5527, -2.7299],
        [ 0.9555, -4.2608,  0.2707],
        [-3.7437,  7.0015, -3.0270],
        [-4.0956,  7.9740, -3.6661],
        [-1.2851, -3.6461,  4.3720],
        [-3.0139,  4.0488, -0.5594],
        [-1.0774, -3.8653,  4.3091],
        [ 0.1088, -3.1374,  0.9368],
        [-1.9320, -3.8930,  5.4072],
        [-1.8927,  3.4772, -2.1638],
        [-3.1941,  3.0973,  0.3878],
        [-3.8946,  7.3401, -3.2921],
        [-1.3590, -4.4495,  5.2281],
        [-2.4651,  3.4418, -1.2164],
        [-4.2406,  6.8548, -1.8787],
        [-3.9770,  7.6266, -3.3729],
        [-1.1794, -3.1587,  3.8252],
        [-0.6403, -3.9768,  3.5547],
        [-2.7716,  4.4398, -1.7724],
        [-3.8043,  7.3815, -3.5362],
        [-3.4613,  6.4270, -3.0212],
        [-1.4443, -3.8497,  4.7655],
        [-3.7509,  6.9514, -3.0429],
        [ 1.1064, -4.2176, -0.2401]

Epoch number 5
 Current loss 0.09082101285457611

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.1460,  7.7957, -3.3588],
        [-3.2894,  5.8099, -2.5115],
        [-4.1667,  7.9505, -3.5021],
        [ 0.9927, -4.2579,  0.0823],
        [ 0.5082, -3.3237,  0.3121],
        [-2.9640,  5.5702, -2.9167],
        [-4.0273,  7.7064, -3.4658],
        [-3.0117,  4.9173, -1.8548],
        [-3.9070,  6.9762, -2.7896],
        [-3.8881,  7.4376, -3.4641],
        [-0.9809, -3.9614,  4.2714],
        [-3.4703,  6.0299, -2.2252],
        [-4.3992,  7.2380, -2.1463],
        [-3.3531,  5.9249, -2.5263],
        [-4.1740,  7.8738, -3.4442],
        [ 0.2143, -4.1255,  2.2297],
        [-3.9749,  7.0615, -2.7686],
        [-3.2808,  4.5147, -0.7693],
        [-3.0433,  1.9823,  1.4275],
        [-3.1101,  5.6172, -2.6386],
        [-2.1346, -2.3732,  4.3100],
        [-3.2157,  5.7509, -2.6817],
        [-3.2773,  5.6793, -2.1470],
        [-3.4118,  5.9844, -2.2524]

Epoch number 5
 Current loss 0.1108224019408226

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.8811,  7.2189, -3.0622],
        [-3.5978,  6.6697, -3.0463],
        [-3.2164,  6.0310, -3.1033],
        [-3.2860,  5.5148, -1.8920],
        [ 1.0730, -4.2865,  0.4082],
        [-4.0115,  7.5815, -3.3933],
        [-2.1742, -3.1094,  4.9584],
        [-0.7936, -3.5693,  3.3741],
        [-2.8243,  5.2890, -2.8960],
        [-2.6332,  4.9398, -3.0012],
        [-2.5080,  3.2283, -1.1481],
        [-3.2711,  5.8951, -2.5784],
        [-2.9987,  5.5617, -2.9818],
        [-3.5685,  6.1555, -2.3604],
        [-3.9989,  7.3140, -3.0011],
        [-4.0589,  7.7202, -3.4272],
        [-3.5307,  6.4362, -2.9129],
        [ 0.7743, -3.4978, -0.2982],
        [-2.1542, -2.5369,  4.4457],
        [-2.1391,  2.9247, -1.5554],
        [-4.0577,  7.4259, -3.0977],
        [-3.6631,  6.3605, -2.4557],
        [-2.9984,  3.9754, -0.8019],
        [-2.9638,  4.3283, -1.3582],

Epoch number 6
 Current loss 0.055400192737579346

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.2826,  4.1517, -2.7068],
        [-1.2168, -4.0743,  4.7201],
        [-1.2318, -3.3696,  4.0963],
        [-3.4824,  6.4606, -3.0182],
        [ 0.7990, -3.8292,  0.8144],
        [-3.6738,  6.7295, -3.0029],
        [-4.0918,  7.9229, -3.6665],
        [-3.6594,  6.7346, -2.9390],
        [-4.1395,  7.3869, -2.9393],
        [-1.8250, -3.9415,  5.2295],
        [-3.4104,  6.3410, -2.9439],
        [-3.4802,  6.4384, -2.8469],
        [-1.5417, -2.9992,  4.1154],
        [-1.7959, -3.8688,  5.2044],
        [-0.5455,  1.0616, -2.8164],
        [-1.4792,  2.9139, -3.4362],
        [-3.9862,  7.5253, -3.3265],
        [-4.1513,  7.8989, -3.4913],
        [-2.8067,  4.7628, -1.8309],
        [-2.6532,  3.3462, -0.7248],
        [-2.6999,  4.5818, -2.1666],
        [-1.3716, -3.7672,  4.6328],
        [-2.7078, -0.1819,  2.8919],
        [-4.1166,  7.9244, -3.5719

Epoch number 6
 Current loss 0.06646300852298737

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.3285,  5.9932, -2.6722],
        [-2.1940, -0.9206,  3.0183],
        [ 0.9264, -3.8899, -0.1570],
        [ 0.9995, -3.3429, -1.0124],
        [ 0.4337, -4.1040,  1.8411],
        [-3.1664, -0.2707,  3.4580],
        [-3.1883,  5.6806, -2.6656],
        [-3.6181,  6.8401, -3.1930],
        [-3.3205,  5.9741, -2.5043],
        [-4.1954,  7.9619, -3.4327],
        [-0.5264, -2.7134,  1.6405],
        [-3.5031,  6.1982, -2.4447],
        [-3.8619,  7.3101, -3.3744],
        [-1.3780, -3.8519,  4.6744],
        [-1.7290, -0.5961,  1.5753],
        [-1.6845, -3.4200,  4.5963],
        [-4.0694,  7.9319, -3.6893],
        [-3.2400,  5.9161, -2.6769],
        [-3.0823,  5.3103, -2.1854],
        [-3.6838,  6.9147, -3.1296],
        [-2.7133,  4.8420, -2.6724],
        [-4.1478,  8.0475, -3.6990],
        [-3.3294,  6.3502, -3.4803],
        [-4.0731,  7.1451, -2.7125]

Epoch number 6
 Current loss 0.04039313644170761

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.9854,  7.3981, -3.2029],
        [-2.8738,  5.1129, -2.4499],
        [-3.9410,  6.3919, -1.9702],
        [-2.0920, -2.5050,  4.2855],
        [-3.4605,  6.3551, -2.7998],
        [-1.0791, -2.9204,  3.5730],
        [-1.4010, -3.5505,  4.3833],
        [-3.2467,  5.6074, -2.2621],
        [-3.0661,  5.3578, -2.4414],
        [ 0.9638, -3.3339, -0.6242],
        [-2.4340,  2.5417,  0.0861],
        [-3.9392,  6.8644, -2.5390],
        [-3.7227,  6.9814, -3.1758],
        [-3.6636,  6.7568, -3.0178],
        [ 0.0330, -2.9474,  0.7990],
        [-4.2820,  8.0365, -3.3610],
        [-0.9656, -2.9610,  3.2335],
        [-3.2533,  6.4367, -3.4290],
        [-3.8810,  7.6398, -3.6738],
        [-3.3452,  6.1661, -2.8442],
        [-3.2921,  5.3202, -1.8725],
        [-3.8386,  7.0794, -3.0435],
        [-4.1555,  7.6078, -3.0274],
        [-3.9265,  6.7725, -2.4245]

Epoch number 6
 Current loss 0.05680733546614647

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.3121,  6.5085, -1.4907],
        [-3.9267,  6.9786, -2.7322],
        [-1.9749, -1.3810,  3.1179],
        [-3.1444,  5.2562, -2.0805],
        [-1.6480, -3.2429,  4.4514],
        [-3.5015,  6.6178, -3.2527],
        [-4.0814,  7.4248, -2.9258],
        [-1.3244, -3.8108,  4.5070],
        [-1.6266, -2.6785,  4.0064],
        [ 1.0824, -3.5775, -0.7915],
        [-2.8927,  5.4954, -3.1939],
        [-3.5239,  5.9693, -2.1600],
        [-3.4900,  6.3973, -2.7687],
        [ 0.5413, -4.3761,  1.5112],
        [-2.9212,  5.3537, -2.8211],
        [-3.3182,  6.1086, -2.9107],
        [-2.8689,  5.1125, -2.5124],
        [-1.3820, -4.2473,  5.0716],
        [-2.4700,  3.7985, -1.7968],
        [-2.7570,  0.8168,  2.1995],
        [ 0.8202, -4.1951,  0.6427],
        [-3.1613,  5.8193, -2.8375],
        [-3.1131,  4.1096, -0.7382],
        [-1.0478, -4.1377,  4.5414]

Epoch number 6
 Current loss 0.04775793105363846

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.9206, -4.5179,  1.0455],
        [-3.4873,  6.4993, -3.0070],
        [-1.3432, -4.1882,  4.9413],
        [-1.8281, -4.0702,  5.4000],
        [ 0.6060, -2.7715, -0.8301],
        [-3.1685,  5.4902, -2.2597],
        [-1.8151,  0.5320,  0.6984],
        [-3.4116,  6.4442, -3.2121],
        [-3.5880,  6.6842, -3.0086],
        [-1.9800, -2.4978,  4.2854],
        [-1.9917, -0.6337,  2.3691],
        [-3.9719,  6.6784, -2.2305],
        [-0.5865, -4.3659,  4.0360],
        [-4.0425,  7.7650, -3.5357],
        [-0.8530, -4.3598,  4.3844],
        [-4.1974,  7.5198, -2.8647],
        [-4.2316,  6.3186, -1.2862],
        [-3.7058,  7.0914, -3.3617],
        [-2.6659,  4.8658, -2.8072],
        [-2.7396,  4.3990, -1.8190],
        [-3.3150,  5.5912, -2.0270],
        [-3.8337,  7.2092, -3.1653],
        [-1.3506, -2.0290,  2.8943],
        [-1.7514, -2.9777,  4.4161]

Epoch number 6
 Current loss 0.08520832657814026

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.5267,  0.1269, -1.7436],
        [-2.9377,  5.3467, -2.8829],
        [-2.1135, -1.7344,  3.6672],
        [-3.0885,  5.2580, -2.3684],
        [-3.8993,  7.0652, -2.8271],
        [-2.7662,  3.0453,  0.0007],
        [-2.8260,  1.8780,  1.1246],
        [-4.0032,  7.7894, -3.6245],
        [-2.4052,  2.0783,  0.6103],
        [-0.7306, -3.9933,  3.7994],
        [-1.3146, -4.0466,  4.7973],
        [-0.7779, -4.3642,  4.3491],
        [-3.3570,  6.1794, -2.6932],
        [-1.8993,  3.7395, -3.1684],
        [-3.3228,  5.7325, -2.1801],
        [-1.1386, -4.2683,  4.7513],
        [-3.9478,  6.3015, -1.7227],
        [-4.1409,  7.5448, -2.9107],
        [ 0.2438, -4.0860,  2.2950],
        [-2.4552,  4.0083, -2.0873],
        [-0.7342, -4.1022,  3.8322],
        [-3.4029,  4.1611, -0.3313],
        [-2.9632,  2.3675,  0.9159],
        [-1.3017,  2.3118, -2.8235]

Epoch number 6
 Current loss 0.06550265848636627

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.3573,  5.8953, -2.4176],
        [-3.0965,  5.3520, -2.1368],
        [-1.6739, -4.2595,  5.4352],
        [ 0.0640, -4.4422,  2.8700],
        [-3.7183,  6.6362, -2.5739],
        [ 1.1719, -4.3345, -0.0411],
        [-2.9506,  5.1208, -2.2933],
        [-3.7322,  6.7304, -2.7137],
        [-3.7188,  6.7397, -2.8170],
        [-1.3648, -3.1612,  4.0069],
        [-0.8812,  1.0546, -1.5387],
        [-2.9781,  5.1607, -2.2205],
        [-1.5137, -0.1190,  0.9540],
        [-3.0535,  6.1165, -3.4358],
        [-2.3489,  3.2304, -1.3302],
        [-4.0199,  7.9053, -3.6946],
        [-3.1465,  6.1824, -3.1721],
        [-3.9428,  7.5885, -3.4536],
        [-0.4741, -1.9405,  0.8366],
        [-3.0184,  5.3492, -2.2209],
        [-3.1761,  5.5699, -2.5539],
        [-3.6585,  4.7863, -0.6021],
        [-1.1264, -3.8355,  4.3483],
        [-2.9682,  4.9075, -1.9948]

Epoch number 6
 Current loss 0.04410333186388016

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.2331, -2.5492,  4.5118],
        [-1.2295, -3.7225,  4.3126],
        [-2.7729,  4.7906, -2.3786],
        [-3.5332,  6.1153, -2.1329],
        [-3.8640,  7.1791, -3.0435],
        [-1.3827, -3.5724,  4.4260],
        [ 1.3181, -4.6953,  0.1561],
        [-3.6041,  7.0035, -3.5722],
        [-2.3462,  3.8969, -2.1226],
        [-3.9288,  6.1635, -1.7835],
        [-2.8295,  4.6466, -2.0337],
        [-3.5785,  6.4343, -2.7119],
        [-3.5019,  6.7479, -3.2419],
        [-3.1478,  5.4459, -2.1877],
        [-3.3976,  6.5674, -3.3276],
        [-2.3024,  4.4645, -3.1678],
        [-3.5968,  6.2989, -2.5381],
        [-2.9384,  5.2056, -2.6002],
        [-2.8259,  5.2265, -2.9140],
        [-3.1096,  5.7567, -2.8534],
        [-2.2153,  3.7371, -2.2970],
        [-3.1219,  5.7513, -2.7811],
        [-3.6545,  6.6368, -2.6905],
        [-3.4785,  6.2487, -2.6048]

Epoch number 6
 Current loss 0.10709531605243683

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.8869, -3.9283,  5.3206],
        [-1.7911, -3.7588,  5.0065],
        [-1.9730, -2.8913,  4.5355],
        [-4.1248,  7.3195, -2.7306],
        [-1.8813,  3.3179, -2.5598],
        [-2.6437,  4.7433, -2.8926],
        [-3.5952,  7.1070, -3.5082],
        [-3.3168,  6.0266, -2.6600],
        [-2.9461,  5.5076, -2.9028],
        [-4.0149,  7.2729, -2.9225],
        [-0.5455, -1.9336,  0.8415],
        [-4.0283,  7.5242, -3.2430],
        [-2.9014,  5.4285, -3.0964],
        [-3.1307,  5.5516, -2.3112],
        [-3.8076,  6.4806, -2.3365],
        [-3.8993,  7.6166, -3.5683],
        [-3.7100,  5.9937, -1.7494],
        [-3.9086,  7.5631, -3.4702],
        [-3.1518,  5.7118, -2.7907],
        [ 1.1629, -4.2358, -0.1275],
        [-1.9840,  3.8461, -3.2378],
        [-1.4431, -3.6024,  4.5324],
        [-3.8007,  6.8823, -2.9413],
        [-3.5425,  6.6534, -3.0445]

Epoch number 6
 Current loss 0.04132625088095665

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.2602,  8.0870, -3.4019],
        [-3.8690,  5.7309, -1.3965],
        [-2.4543, -2.8056,  4.9886],
        [-4.1561,  7.5354, -2.9319],
        [-3.4685,  6.5850, -3.1311],
        [-1.7981, -3.3330,  4.6799],
        [-1.2611, -4.3096,  4.9225],
        [-3.0093,  4.9680, -1.8494],
        [-3.8546,  7.2962, -3.2109],
        [-3.6848,  7.2959, -3.6817],
        [-3.8062,  7.2497, -3.3885],
        [-3.8686,  7.6622, -3.7584],
        [-3.4291,  6.4376, -3.1369],
        [-4.0812,  7.5570, -3.1307],
        [-3.9266,  7.6433, -3.6200],
        [-2.9188,  5.2567, -2.5852],
        [ 0.0847, -0.7293, -1.8360],
        [-1.8316, -3.8157,  5.1750],
        [-1.7356, -3.0432,  4.3472],
        [-1.7021, -3.8802,  5.0499],
        [-3.3321,  5.8017, -2.3499],
        [-3.3191,  6.5012, -3.3128],
        [-3.8829,  7.7672, -3.7902],
        [-3.9424,  7.5581, -3.4291]

Epoch number 6
 Current loss 0.10788251459598541

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.7095,  6.5620, -2.6965],
        [-3.3069,  5.3723, -1.7510],
        [-3.4906,  6.0316, -2.1873],
        [-1.2730, -2.0208,  2.5607],
        [-3.5292,  6.4733, -2.8961],
        [-4.0877,  8.0178, -3.7728],
        [-1.9329, -3.7162,  5.1248],
        [ 1.2271, -4.2316, -0.2509],
        [-2.1915, -1.4221,  3.3969],
        [-3.5313,  6.1476, -2.1624],
        [-1.8933, -3.7446,  5.1506],
        [-0.5837,  1.0170, -3.0089],
        [-3.6584,  7.0018, -3.3186],
        [ 1.1374, -3.7331, -0.2429],
        [-4.3110,  8.1257, -3.3747],
        [-3.8982,  7.4040, -3.2961],
        [-4.1270,  8.0149, -3.7058],
        [-0.2249, -3.3636,  1.9183],
        [-3.2797,  6.0048, -2.8601],
        [ 1.0256, -4.0926, -0.1306],
        [-2.7839,  5.0511, -2.7133],
        [-2.0791, -3.8933,  5.4830],
        [-2.1342,  1.7920,  0.3712],
        [-2.2164, -3.1896,  5.0512]

Epoch number 6
 Current loss 0.03983399644494057

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.3821,  6.4859, -3.0993],
        [-3.5561,  6.2301, -2.4335],
        [-0.4259,  0.2875, -1.7970],
        [-0.3703, -4.3154,  3.3837],
        [-2.1129, -3.1435,  4.9234],
        [-3.1722,  4.2640, -0.8314],
        [-3.6330,  6.2368, -2.2435],
        [-3.0537,  4.7139, -1.4590],
        [-1.6235, -3.6068,  4.7242],
        [-3.0014,  5.2511, -2.3232],
        [-3.9559,  6.2872, -1.7424],
        [-3.5242,  6.1931, -2.3116],
        [-1.5802, -3.4365,  4.4691],
        [ 0.8163, -2.9517, -0.6885],
        [-3.6816,  6.6956, -2.8159],
        [-2.5261,  3.9664, -1.2734],
        [-4.0234,  6.4638, -1.8236],
        [-0.9225, -3.2134,  3.2980],
        [-2.0309, -3.3017,  4.9097],
        [-2.4410,  4.5712, -2.9682],
        [-1.8405, -3.8616,  5.1718],
        [ 0.8827, -4.1560,  0.5776],
        [-3.7058,  7.5012, -3.8557],
        [-2.8996,  5.0068, -2.0048]

Epoch number 6
 Current loss 0.09273052215576172

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.8026, -2.6224,  2.3582],
        [-3.7468,  6.3741, -2.0506],
        [ 1.2161, -4.3134, -0.3530],
        [-4.0159,  7.6204, -3.4307],
        [-0.5295, -4.6410,  3.9425],
        [-2.9589,  5.2437, -2.6571],
        [-3.6969,  6.4325, -2.3810],
        [-3.8162,  6.2028, -1.8531],
        [-3.7853,  7.2223, -3.2948],
        [-2.6912,  4.2959, -1.9772],
        [-4.2369,  8.0308, -3.4558],
        [-3.5428,  6.1467, -2.3919],
        [-3.1535,  5.4802, -2.4076],
        [-3.4841,  6.2721, -2.6450],
        [-2.8707,  5.2227, -2.7322],
        [ 1.0129, -4.6006,  0.4308],
        [-2.3067, -3.8446,  5.7139],
        [-0.2314, -4.2488,  3.0201],
        [-2.3747, -2.9571,  5.0404],
        [ 0.5151, -3.0245,  0.2422],
        [-3.5663,  6.7519, -3.2862],
        [-2.1194,  1.7800, -0.0538],
        [-2.1396, -3.5185,  5.2264],
        [-3.0026,  5.6940, -3.1739]

Epoch number 6
 Current loss 0.05544862896203995

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.3813,  3.7595, -1.3308],
        [-3.8208,  6.9094, -2.7288],
        [-2.8385,  5.2921, -2.4752],
        [-4.0418,  7.7160, -3.5460],
        [-3.6852,  5.2573, -1.1498],
        [-3.7878,  7.3468, -3.5315],
        [-3.4288,  4.8330, -0.8979],
        [-1.8789, -2.8793,  4.4033],
        [-3.9265,  7.3907, -3.1773],
        [-3.9630,  7.8700, -3.7959],
        [-1.8142, -2.3843,  3.9182],
        [-2.5131,  2.5555,  0.2280],
        [-1.9620,  3.7864, -3.0754],
        [-4.0253,  7.7194, -3.4952],
        [-3.7616,  1.2610,  2.7949],
        [-3.9698,  7.2186, -2.8979],
        [-3.6357,  6.4403, -2.4715],
        [-3.2284,  5.5533, -2.1447],
        [-3.4376,  4.5865, -0.6442],
        [ 0.2286, -4.4625,  2.1597],
        [ 0.7869, -4.4259,  0.8149],
        [-0.8187,  0.7403, -1.9722],
        [ 0.9426, -4.8455,  0.9980],
        [-3.4509,  4.0277, -0.0256]

Epoch number 6
 Current loss 0.09836927056312561

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.9557,  4.0503, -3.5457],
        [-1.5969, -1.7878,  3.0061],
        [-0.8451, -4.1617,  4.0513],
        [ 1.0173, -4.8541,  0.6820],
        [-3.4455,  6.2933, -2.8520],
        [-3.7164,  5.9955, -1.8296],
        [-3.5626,  6.4124, -2.6660],
        [-4.2053,  8.0980, -3.6652],
        [-3.4116,  5.5817, -1.7618],
        [-3.3678,  6.2614, -2.9081],
        [-4.2772,  8.2249, -3.7137],
        [-3.7456,  7.1211, -3.3899],
        [-3.7821,  6.7401, -2.6763],
        [-2.9105,  5.3932, -2.9450],
        [-3.2393,  5.8510, -2.8345],
        [ 1.2285, -3.9498, -0.6913],
        [-3.9313,  7.4126, -3.3212],
        [-1.8757, -3.8827,  5.2284],
        [-3.4902,  6.2981, -2.6396],
        [-3.7272,  6.4812, -2.3999],
        [-2.8185,  5.2634, -3.0978],
        [-3.3314,  5.8737, -2.4229],
        [-2.8111,  5.7394, -3.5311],
        [-2.4553, -1.1865,  3.4133]

Epoch number 6
 Current loss 0.06994232535362244

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.0672, -2.8933,  4.5675],
        [-0.9089, -3.3882,  3.3339],
        [-3.8076,  7.2284, -3.3907],
        [-2.3971, -3.0294,  5.0794],
        [-1.9351, -3.8384,  5.1958],
        [-3.8739,  6.7010, -2.6103],
        [-1.6533, -3.8569,  4.9473],
        [-1.6260,  3.0286, -3.0640],
        [ 0.0883, -3.5492,  1.3533],
        [-4.0284,  5.7642, -1.1729],
        [-1.2411, -3.6498,  4.1414],
        [-0.9125, -2.8897,  3.0307],
        [-3.8216,  6.5446, -2.3328],
        [-4.3987,  6.0356, -0.9175],
        [-4.2497,  7.7928, -3.1836],
        [-1.0267, -4.0015,  4.2241],
        [ 0.7248, -3.9427,  0.7005],
        [-3.6624,  7.0094, -3.4023],
        [-4.1348,  7.8491, -3.4079],
        [-3.3711,  5.9492, -2.4354],
        [-3.1729,  5.9419, -3.2583],
        [-4.2020,  7.9922, -3.4725],
        [-2.1468, -3.7164,  5.3556],
        [-3.7333,  7.0304, -3.2620]

Epoch number 6
 Current loss 0.10135588049888611

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.6020, -3.6781,  4.7126],
        [-3.9850,  7.1740, -2.8162],
        [-3.2927,  6.6149, -3.5217],
        [-4.3022,  8.0692, -3.4314],
        [-1.9759, -2.7238,  4.3354],
        [-0.7343, -2.2797,  1.6206],
        [ 1.3230, -4.7170, -0.1441],
        [-3.1699,  5.7106, -2.4578],
        [-4.1226,  7.8712, -3.5316],
        [ 0.8240, -4.3127,  0.5997],
        [-1.7082, -4.0024,  5.0895],
        [-3.4378,  5.1300, -1.4486],
        [-4.0086,  7.5030, -3.2685],
        [-3.9378,  7.5213, -3.4827],
        [-3.2807,  6.0207, -2.8355],
        [-0.4689, -3.9994,  2.9453],
        [-2.9772,  3.8072, -0.8465],
        [-3.0212,  5.8849, -3.2336],
        [-4.0305,  7.3367, -3.0055],
        [-2.0038, -3.5395,  5.0560],
        [-1.4075, -3.4586,  4.2157],
        [-3.0076,  5.2539, -2.3914],
        [ 1.3152, -4.1600, -0.8663],
        [-2.4135, -3.6769,  5.6158]

Epoch number 6
 Current loss 0.06333933770656586

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.0813,  7.6435, -3.4318],
        [-4.1678,  7.5907, -3.0081],
        [-0.4189, -2.6181,  1.7024],
        [-3.7023,  6.5358, -2.6649],
        [-2.8061,  5.0525, -2.7844],
        [-2.5497,  4.9959, -3.2002],
        [-3.3299,  5.1528, -1.6493],
        [-3.8173,  6.7263, -2.6821],
        [-4.3662,  8.3260, -3.6397],
        [-2.8964,  5.2371, -2.9305],
        [-3.1896,  5.7147, -2.7728],
        [-3.3184,  3.6319,  0.1281],
        [-4.0394,  7.8190, -3.6436],
        [ 0.7600, -2.5792, -1.2306],
        [-2.7697, -2.3131,  4.7990],
        [-3.3764,  4.8418, -1.0671],
        [-3.7913,  6.7249, -2.6928],
        [ 0.9131, -4.6682,  0.7281],
        [-3.3621,  6.2244, -3.0598],
        [-1.6811, -3.3062,  4.4735],
        [ 1.0674, -3.3566, -1.0986],
        [-3.2815,  5.8295, -2.5750],
        [-3.3366,  5.8707, -2.4848],
        [-0.8155, -4.2316,  3.9479]

Epoch number 6
 Current loss 0.07024529576301575

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.0521,  8.0602, -4.0503],
        [-3.6718,  5.5521, -1.3526],
        [-4.1897,  7.5450, -3.0474],
        [-3.6080,  6.8400, -3.2007],
        [ 1.4430, -4.9364, -0.3548],
        [-4.1526,  7.7359, -3.3339],
        [-3.8171,  7.2218, -3.4017],
        [-2.9492,  5.6044, -3.3671],
        [-3.7851,  6.5582, -2.3003],
        [-0.0907, -3.9794,  2.0942],
        [ 1.1531, -4.0402, -0.4844],
        [ 0.9190, -3.2816, -0.8238],
        [-3.8216,  6.9749, -2.8898],
        [ 1.4333, -4.5589, -0.4257],
        [-3.0693,  5.9085, -3.2207],
        [-4.2967,  8.4218, -3.9596],
        [-3.4803,  6.4189, -3.0715],
        [-3.1987,  5.4968, -2.3320],
        [-0.2210, -3.7393,  2.4143],
        [-4.4369,  8.1587, -3.2618],
        [-3.7770,  7.1676, -3.5036],
        [ 0.0407, -0.4142, -2.3070],
        [-3.7136,  7.0395, -3.3419],
        [ 0.9670, -2.9796, -1.2757]

Epoch number 6
 Current loss 0.09781324863433838

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.9316, -2.2092,  3.9141],
        [-1.4515, -3.3478,  4.1785],
        [-3.3756,  5.8190, -2.2615],
        [-3.0347,  5.4548, -2.3171],
        [-3.2057,  5.7421, -2.7381],
        [-0.2325, -4.3826,  3.1273],
        [-4.3077,  8.1704, -3.5943],
        [-3.0254,  3.6894, -0.2927],
        [-1.7016, -0.2903,  1.4322],
        [-3.8855,  7.5338, -3.6466],
        [-2.0333, -1.5829,  3.3359],
        [-4.1218,  8.1991, -3.9570],
        [-3.6715,  6.7367, -3.1606],
        [-1.2834, -3.9211,  4.5176],
        [-3.6338,  6.4697, -2.8269],
        [-2.9776,  5.7604, -3.4544],
        [-3.9067,  6.6167, -2.3375],
        [-3.2575,  6.0575, -3.2795],
        [-4.0064,  7.5038, -3.2401],
        [-1.7708, -1.8158,  3.3514],
        [-4.1028,  7.7776, -3.4169],
        [-2.2550, -1.6887,  3.6778],
        [-1.8553, -3.0725,  4.4329],
        [-4.3010,  8.0013, -3.3796]

Epoch number 6
 Current loss 0.07802814245223999

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.4770,  6.3725, -2.9799],
        [-4.3993,  8.4129, -3.7081],
        [-0.8225, -2.2701,  1.8271],
        [-3.7114,  6.9230, -3.1059],
        [-0.9416, -3.3731,  3.2195],
        [-1.6323, -1.6698,  2.6752],
        [-0.7855, -3.8551,  3.5248],
        [-4.0031,  7.7564, -3.7058],
        [-4.3362,  8.2738, -3.6931],
        [-3.6499,  5.7818, -1.6495],
        [-2.5751, -0.2152,  2.8242],
        [ 1.0898, -4.5352,  0.4438],
        [-3.6598,  5.9271, -2.0328],
        [-0.2242, -2.3250,  0.6612],
        [ 0.7997, -4.5256,  1.0143],
        [-3.7087,  6.9302, -3.2603],
        [-4.1239,  7.8274, -3.5422],
        [-1.1770, -4.1578,  4.6024],
        [-2.0082, -3.1328,  4.6892],
        [-4.0051,  7.3833, -3.1549],
        [-2.7116, -1.4871,  3.9944],
        [-3.1387,  5.8160, -3.1754],
        [-4.0530,  7.6861, -3.3674],
        [-3.9272,  7.5961, -3.5675]

Epoch number 6
 Current loss 0.058011047542095184

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.9910,  6.2943, -1.7835],
        [-4.3506,  8.1328, -3.4723],
        [-3.5313,  6.2848, -2.7357],
        [-4.2561,  8.0678, -3.6046],
        [-0.7322, -2.9318,  2.3118],
        [-3.4469,  6.1632, -2.8259],
        [-2.8737,  4.6608, -2.0863],
        [-1.6706,  1.7367, -0.6816],
        [-4.1950,  7.4332, -2.9236],
        [-2.5849, -3.3468,  5.5019],
        [ 1.3081, -4.6244, -0.3094],
        [-4.4444,  7.7386, -2.7210],
        [-4.2587,  8.2303, -3.8001],
        [-4.5575,  7.9543, -2.7641],
        [-3.0138,  4.5551, -1.5516],
        [-2.3287,  0.1025,  2.0066],
        [-4.3567,  8.3081, -3.7058],
        [-3.7348,  6.1802, -2.0038],
        [-3.4437,  5.7097, -2.0394],
        [-3.4531,  5.8519, -2.3116],
        [-3.5151,  6.1684, -2.5166],
        [-4.2036,  8.0262, -3.6241],
        [-0.4974, -4.2485,  3.4686],
        [-3.3738,  6.1358, -2.9108

Epoch number 6
 Current loss 0.05583946779370308

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.7415,  4.1811, -0.0149],
        [-3.4968,  6.4541, -3.2283],
        [-4.2415,  8.0385, -3.5707],
        [ 1.1591, -4.8868,  0.2701],
        [-4.2315,  8.2304, -3.8798],
        [-2.4403, -3.3607,  5.3969],
        [-4.3513,  8.0725, -3.3098],
        [-3.3864,  5.2540, -1.6924],
        [-1.5119, -3.5410,  4.3483],
        [-3.7350,  6.1465, -2.0884],
        [-2.9404,  1.4371,  1.4928],
        [-3.5494,  6.0797, -2.3202],
        [-4.2738,  8.1975, -3.6869],
        [-2.9169,  5.4823, -2.6709],
        [-1.7786, -3.4581,  4.6224],
        [-1.8746, -3.3555,  4.6557],
        [-4.4003,  8.0529, -3.2428],
        [-3.5238,  6.3546, -2.7455],
        [-3.8165,  5.9533, -1.6375],
        [-3.4983,  5.9916, -2.3235],
        [-1.6051, -3.3361,  4.3745],
        [-2.2249, -3.8581,  5.5033],
        [-4.3456,  8.3062, -3.6955],
        [-3.5830,  5.9277, -1.9987]

Epoch number 6
 Current loss 0.06243453547358513

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.8146,  7.1311, -3.1949],
        [-3.8849,  6.3624, -2.0744],
        [-1.3457, -2.7677,  3.3679],
        [-3.3155,  5.8603, -2.6760],
        [-4.2572,  6.0174, -1.0228],
        [-4.0074,  7.2170, -2.9395],
        [-1.0533, -4.2199,  4.3490],
        [-4.1530,  7.5999, -3.1325],
        [-1.4401, -4.1634,  4.9071],
        [-3.6633,  6.6610, -3.0210],
        [-2.3198, -2.8097,  4.8675],
        [-2.4815, -3.0880,  5.1727],
        [-4.0124,  6.2749, -1.6085],
        [-2.1964, -4.1196,  5.7036],
        [-1.2153, -3.5490,  4.0407],
        [-3.6616,  6.3890, -2.4324],
        [-3.3890,  5.9612, -2.6336],
        [-1.2745, -3.6340,  4.1426],
        [ 0.0848, -4.7751,  2.8276],
        [-4.1351,  7.2694, -2.8398],
        [-4.1355,  7.6477, -3.3159],
        [-4.2750,  8.2216, -3.8183],
        [-2.0929, -3.6348,  5.1506],
        [-1.8911, -3.9646,  5.2367]

Epoch number 6
 Current loss 0.05765650421380997

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.1998,  7.1673, -2.5064],
        [-3.2102,  5.0626, -1.8533],
        [ 1.3019, -4.4432, -0.4780],
        [ 1.0213, -4.2196,  0.1631],
        [-3.9852,  6.5077, -1.9928],
        [-4.0039,  6.9684, -2.5681],
        [-3.9014,  5.0147, -0.6311],
        [ 1.3481, -4.9762, -0.0143],
        [-0.0589, -3.9820,  2.2061],
        [-3.5626,  6.1033, -2.4774],
        [-3.6543,  6.2117, -2.1316],
        [-0.2207, -4.5204,  3.2112],
        [-3.7920,  0.0093,  3.9201],
        [-4.3072,  8.1193, -3.5470],
        [-3.5856,  6.0100, -2.1125],
        [ 0.9220, -4.7488,  0.9013],
        [-3.9074,  5.8287, -1.3474],
        [-3.7193,  6.8986, -3.2596],
        [-4.0324,  6.7208, -2.2454],
        [ 0.4154, -4.7559,  1.9767],
        [-3.7566,  6.8134, -2.9321],
        [-4.1810,  8.0894, -3.8112],
        [-4.3727,  8.1492, -3.4388],
        [-3.6341,  6.7889, -3.2717]

Epoch number 6
 Current loss 0.08520855009555817

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.6620,  7.4472, -1.9689],
        [-3.6647,  6.7517, -3.0857],
        [-1.7920, -3.0506,  4.3854],
        [-2.1585, -3.9596,  5.5666],
        [-4.0370,  5.9759, -1.3833],
        [-3.9432,  7.4387, -3.3580],
        [-1.2024, -3.7311,  4.1688],
        [-3.3736,  5.6163, -1.9878],
        [-2.2358, -3.5331,  5.2347],
        [-3.3347,  5.7603, -2.5304],
        [-2.0671, -2.5072,  4.2416],
        [-3.5347,  4.5149, -0.5253],
        [-2.4795, -2.5631,  4.7293],
        [-0.9638, -3.8027,  3.7474],
        [-3.3667,  5.6371, -2.0859],
        [-4.5222,  8.3426, -3.3759],
        [-3.3604,  5.2791, -1.7923],
        [ 0.3969, -2.8778,  0.1299],
        [-3.6496,  4.2498, -0.1454],
        [-4.0937,  7.4725, -3.2198],
        [-3.5245,  6.2724, -2.8259],
        [-1.1074, -3.9941,  4.2290],
        [ 1.4354, -4.7391, -0.2842],
        [-3.7857,  6.2552, -2.2164]

Epoch number 6
 Current loss 0.04782635346055031

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.3556,  8.1970, -3.5442],
        [-3.0893,  4.5636, -1.5607],
        [-1.8026, -3.4885,  4.6972],
        [-2.3619,  3.5979, -1.9784],
        [-1.6464, -3.9433,  4.9978],
        [-4.0635,  6.4215, -1.8863],
        [-3.9821,  5.7599, -1.3102],
        [-3.5690,  6.3875, -2.8210],
        [-4.3661,  8.1625, -3.4860],
        [-3.8316,  5.2792, -0.9855],
        [ 1.0063, -5.1641,  1.5216],
        [-4.1251,  7.1601, -2.6733],
        [-4.0384,  4.7695, -0.0538],
        [-4.2824,  8.0901, -3.5676],
        [-2.3534,  4.4361, -3.2601],
        [-4.0432,  5.7134, -0.9355],
        [-4.2309,  7.0644, -2.3169],
        [-2.0722, -4.2073,  5.6832],
        [-2.5919,  2.4022,  0.0100],
        [-3.7740,  6.7095, -2.7225],
        [-4.1343,  7.9824, -3.7588],
        [-3.0122,  5.1681, -2.0485],
        [-3.9796,  6.9008, -2.5889],
        [ 1.4677, -4.8215, -0.1733]

Epoch number 6
 Current loss 0.06839773058891296

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.4433,  8.1838, -3.3336],
        [-1.5590, -3.2914,  4.2106],
        [-3.0202,  4.1023, -0.7542],
        [-4.0131,  7.7218, -3.7110],
        [ 0.9932, -5.0819,  1.1022],
        [-2.1058, -4.0901,  5.5736],
        [ 0.1242, -3.7047,  1.5989],
        [-3.0638,  4.5329, -1.5270],
        [-4.1049,  7.3750, -2.9873],
        [-3.0186,  5.4263, -2.9564],
        [-4.0240,  7.0993, -2.8093],
        [-4.3831,  8.1553, -3.4528],
        [-4.3814,  8.0234, -3.3666],
        [-3.5591,  6.8502, -3.4220],
        [-4.4060,  8.2961, -3.6075],
        [-3.1132,  5.7515, -3.1111],
        [-3.6855,  6.0326, -1.9646],
        [-4.1439,  8.1571, -3.9990],
        [-1.8440, -2.7341,  4.1063],
        [ 0.8996, -4.4009,  1.0289],
        [-3.3922,  5.7977, -2.3951],
        [-1.9335, -3.1246,  4.5728],
        [-4.2781,  7.1442, -2.3994],
        [-2.1691, -1.8366,  3.7772]

Epoch number 6
 Current loss 0.07419431209564209

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.7195, -5.1078,  1.8892],
        [ 1.4396, -4.7781, -0.2135],
        [-0.0740, -3.6928,  1.7368],
        [-2.4847,  0.0424,  2.3420],
        [-4.0843,  7.7136, -3.4882],
        [-4.2682,  7.6616, -3.0247],
        [-4.1967,  8.1531, -3.8865],
        [-4.0820,  6.7357, -2.2197],
        [ 0.6949, -3.3452, -0.2072],
        [-3.8383,  5.3534, -1.0978],
        [-1.7844, -3.4842,  4.7072],
        [-4.2734,  6.8988, -2.0598],
        [-1.6457, -3.8063,  4.8291],
        [-1.2557, -3.2297,  3.9334],
        [-3.3664,  6.2343, -3.1564],
        [-1.7086, -2.9904,  4.1979],
        [-4.1425,  7.8842, -3.6363],
        [ 1.1969, -4.7847,  0.2410],
        [-0.9495, -3.0186,  3.3282],
        [-3.2045,  5.5355, -2.5041],
        [-3.8485,  6.9895, -2.9417],
        [-4.0310,  6.9909, -2.5348],
        [-2.2553, -3.7588,  5.4713],
        [-4.2487,  8.1677, -3.8009]

Epoch number 6
 Current loss 0.04960569739341736

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.3067,  7.5911, -2.8260],
        [-2.7466,  5.7074, -3.8526],
        [-4.4843,  8.5124, -3.7852],
        [ 0.5392, -1.7038, -1.9815],
        [ 0.7999, -2.9857, -1.0921],
        [-2.5193,  2.8815, -0.6855],
        [-4.5532,  7.0806, -1.7557],
        [-3.3131,  5.8507, -2.6983],
        [-3.6104,  6.2153, -2.4148],
        [-2.8706,  4.2233, -1.6189],
        [-4.5351,  5.0705,  0.0493],
        [-4.5755,  8.5809, -3.7490],
        [-3.2764,  6.0915, -3.0100],
        [-3.7544,  6.1894, -2.1080],
        [-4.7729,  8.0534, -2.6740],
        [-2.4994, -0.4040,  2.8242],
        [-4.2531,  8.2982, -3.9816],
        [-4.4198,  7.9155, -3.0939],
        [-1.6461, -4.0808,  5.0439],
        [ 1.3308, -4.7231, -0.0793],
        [-4.2246,  7.5265, -2.9498],
        [-3.6559,  5.2930, -1.0438],
        [-2.0358, -4.0967,  5.5271],
        [-3.2437,  5.4176, -2.3160]

Epoch number 6
 Current loss 0.08515126258134842

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6678, -3.8989,  6.0514],
        [-3.2976,  5.1124, -1.5635],
        [-3.9518,  6.7599, -2.6024],
        [-3.4720,  5.7788, -2.1074],
        [ 0.0149, -4.4994,  2.6539],
        [-4.2364,  8.1309, -3.8407],
        [-4.2410,  7.1463, -2.4853],
        [ 1.1917, -4.5188,  0.4018],
        [-4.1558,  7.8511, -3.5970],
        [-4.4971,  7.2007, -2.0664],
        [-4.4229,  6.2723, -1.1243],
        [-4.3326,  8.2411, -3.7043],
        [-4.2218,  6.6452, -1.9254],
        [-3.7896,  6.5936, -2.5978],
        [-4.0784,  7.7399, -3.6678],
        [-3.8173,  6.8446, -2.8745],
        [-3.7210,  5.7135, -1.5195],
        [-3.4891,  6.4264, -3.1132],
        [-3.9355,  7.2276, -3.1319],
        [-4.3792,  8.1980, -3.5644],
        [-4.4476,  8.4061, -3.7232],
        [-3.2481,  5.6479, -2.4483],
        [-4.1763,  7.8995, -3.6427],
        [-4.0001,  7.4982, -3.4241]

Epoch number 6
 Current loss 0.03474534675478935

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.6780,  8.6299, -3.5677],
        [-1.0376, -3.0157,  3.0532],
        [-3.9199,  3.9955,  0.2979],
        [-4.4263,  4.9524,  0.2747],
        [-3.4442,  5.1512, -1.5774],
        [-4.2815,  7.5681, -2.9842],
        [-4.4514,  8.3085, -3.5288],
        [-4.2668,  6.1831, -1.2689],
        [-3.6734,  6.4916, -2.8440],
        [-3.8715,  6.0126, -1.6860],
        [-2.0632, -4.1084,  5.5523],
        [-4.5575,  5.4940, -0.1306],
        [-2.2817, -3.8764,  5.5505],
        [ 1.1732, -3.7325, -0.6262],
        [-2.5310, -0.9132,  3.3423],
        [-2.0860, -1.1275,  3.0325],
        [-4.0793,  5.6977, -1.0070],
        [-4.4158,  8.3475, -3.7671],
        [-4.0327,  7.6194, -3.5193],
        [-3.4690,  5.9323, -2.2678],
        [-3.0764, -2.5364,  5.4256],
        [ 0.9619, -4.4600,  0.8875],
        [-1.4735, -4.0951,  4.8855],
        [-4.6366,  8.6295, -3.6877]

Epoch number 6
 Current loss 0.03381292521953583

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.7035, -4.5727,  1.0863],
        [-4.0878,  7.1852, -2.8158],
        [-1.5590, -4.0781,  4.8846],
        [-2.5270, -3.7781,  5.7870],
        [ 1.1806, -5.1309,  0.9067],
        [-3.4161,  5.8734, -2.4812],
        [-4.2450,  7.7972, -3.3597],
        [-4.3533,  8.2959, -3.7874],
        [-3.7577,  4.6327, -0.2800],
        [ 1.0512, -4.2197, -0.0190],
        [-4.2463,  7.9525, -3.6237],
        [-3.8666,  6.9585, -3.0031],
        [-0.7925, -4.3473,  4.1360],
        [-4.2680,  7.7791, -3.3486],
        [-2.5123,  4.2764, -1.9760],
        [ 1.0259, -4.7179,  0.4345],
        [-4.4414,  8.2544, -3.6445],
        [-4.3848,  8.2470, -3.7040],
        [-3.7016,  5.5110, -1.2949],
        [-1.6226, -4.1681,  5.1000],
        [-3.9138,  6.3398, -2.0500],
        [-1.2090, -3.4817,  3.7768],
        [-3.7448,  7.1587, -3.5093],
        [-3.1911,  3.5116, -0.1294]

Epoch number 6
 Current loss 0.13496273756027222

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.6584,  8.1006, -2.8943],
        [-2.6149, -3.5005,  5.7101],
        [-4.5148,  8.6642, -3.9601],
        [-4.0958,  6.3121, -1.6648],
        [-4.5419,  8.6346, -3.9219],
        [-1.8873, -3.9788,  5.2513],
        [-3.8344,  4.2260,  0.2471],
        [-1.5431, -2.5182,  3.6622],
        [-2.3691, -4.3140,  6.1490],
        [-5.1770,  6.4300, -0.5621],
        [-3.8827,  6.9447, -3.0193],
        [-3.8415,  3.6677,  0.6873],
        [-1.7587, -3.3373,  4.5466],
        [-2.5728, -3.5300,  5.6014],
        [-3.9995,  7.0311, -2.8414],
        [-4.2545,  7.9687, -3.6251],
        [ 1.0432, -4.5300,  0.1471],
        [-4.4707,  6.4196, -1.2588],
        [-4.4017,  8.4484, -3.9389],
        [-4.0745,  7.5204, -3.4804],
        [-4.4375,  7.9555, -3.1471],
        [-3.8514,  6.5385, -2.4097],
        [-1.1566, -3.9129,  4.1363],
        [-3.9081,  6.7836, -2.5435]

Epoch number 6
 Current loss 0.04424787685275078

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.4936,  4.5260, -0.5262],
        [-3.4580,  5.0588, -1.3282],
        [-2.6637, -3.1257,  5.4353],
        [-4.4988,  6.0454, -0.9230],
        [-3.5072, -2.2461,  5.6498],
        [ 0.0906, -4.3587,  2.1881],
        [-4.1016,  6.8279, -2.2738],
        [-4.0068,  6.6082, -2.1100],
        [-3.8608,  6.9624, -3.0662],
        [-4.6068,  7.9513, -2.8181],
        [-3.4004,  5.8964, -2.6379],
        [-2.2665, -4.1355,  5.7960],
        [-3.6672,  5.5169, -1.3194],
        [-4.1995,  6.0885, -1.3606],
        [-2.4506, -4.1971,  6.0935],
        [-3.5524,  5.5536, -1.8675],
        [-3.5034,  6.2625, -2.9140],
        [-4.5001,  8.0142, -3.1677],
        [-4.2047,  7.9804, -3.7633],
        [-1.7873, -4.0924,  5.2219],
        [-0.4813, -3.5889,  2.4978],
        [-4.4905,  7.2953, -2.1929],
        [-3.9617,  7.1759, -3.1717],
        [-4.3246,  7.3710, -2.6088]

Epoch number 6
 Current loss 0.11475372314453125

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.7224,  6.3435, -2.3325],
        [-4.4854,  8.0199, -3.1741],
        [-1.9753, -3.0525,  4.4976],
        [-4.4384,  8.4011, -3.7908],
        [-4.1153,  7.8364, -3.8165],
        [-4.0812,  7.4533, -3.1961],
        [ 0.7375, -4.4385,  0.8911],
        [-3.6539,  6.5481, -2.9705],
        [-2.6087,  3.6574, -1.5905],
        [-1.2805, -2.2489,  2.6505],
        [-2.8411,  5.3201, -3.3160],
        [-4.2985,  8.3209, -4.0717],
        [-4.4790,  8.2077, -3.4202],
        [-4.3839,  8.2528, -3.6621],
        [-3.7748,  6.0335, -1.8178],
        [-3.7878,  4.7992, -0.5739],
        [-3.6699,  5.7936, -1.7572],
        [-4.3564,  6.0564, -1.1220],
        [-2.2682, -3.8998,  5.6171],
        [-2.9651,  4.7704, -2.0074],
        [-4.3773,  8.4001, -3.9288],
        [ 0.4910, -4.0734,  1.0673],
        [-4.0729,  7.4766, -3.2917],
        [-2.9367, -1.8347,  4.6542]

Epoch number 6
 Current loss 0.08805008232593536

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.7517, -3.1910,  4.3761],
        [-0.0190, -3.2899,  1.1598],
        [-3.8007,  6.9272, -3.0698],
        [-3.8574,  6.1825, -1.8557],
        [-4.2609,  7.7625, -3.1736],
        [-3.6311,  5.6555, -1.7021],
        [-3.5026,  1.4075,  2.3986],
        [-4.1254,  7.8418, -3.7610],
        [-3.8274, -0.3190,  4.1859],
        [-3.1339,  4.6548, -1.5432],
        [-4.1938,  7.7987, -3.4447],
        [-2.0300, -3.5400,  5.0044],
        [-3.1920, -0.2367,  3.4423],
        [ 1.3308, -4.5130, -0.2890],
        [-0.7915, -4.5875,  4.2279],
        [-1.3685, -1.8122,  2.1693],
        [-3.5884,  6.0399, -2.2991],
        [-1.0148, -1.4186,  1.1186],
        [-2.0694, -0.4023,  2.2122],
        [-4.3849,  8.2445, -3.7133],
        [-2.6379, -2.5701,  4.8729],
        [ 0.1630, -4.7957,  2.5957],
        [-3.8118,  4.7408, -0.6025],
        [ 1.3176, -4.1386, -0.6230]

Epoch number 6
 Current loss 0.08514008671045303

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.7066, -5.1293,  1.7233],
        [-3.9074,  7.5747, -3.7712],
        [-1.2166, -2.6666,  2.9496],
        [-3.5813,  6.8089, -3.4001],
        [-0.1851, -3.4038,  1.6851],
        [-4.2066,  8.2905, -4.1735],
        [-3.1917,  4.7691, -1.6885],
        [-3.4322,  6.3877, -3.4008],
        [-2.4657,  4.4740, -3.0170],
        [-3.2709,  5.8709, -2.9825],
        [-3.2643,  5.8259, -2.8484],
        [-3.4758,  5.1548, -1.4677],
        [-4.0160,  7.3200, -3.1908],
        [ 1.3931, -4.4461, -0.3971],
        [-3.1987,  5.5659, -2.6692],
        [-3.8263,  6.2180, -2.1687],
        [-3.6338,  6.5337, -3.1846],
        [-4.3752,  8.4134, -4.0087],
        [-3.1280,  4.9716, -1.9568],
        [-2.7596, -3.1644,  5.5422],
        [-3.3635,  6.8473, -4.0080],
        [-1.5434, -3.7369,  4.6307],
        [-4.2028,  7.7272, -3.5409],
        [-2.3036, -3.0597,  4.9356]

Epoch number 6
 Current loss 0.07885386049747467

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.9080,  6.5214, -2.3146],
        [-2.1312, -3.8071,  5.3729],
        [-4.3673,  8.2045, -3.6828],
        [-3.3409,  5.8792, -2.6191],
        [-4.0341,  7.1128, -2.9373],
        [-4.0914,  7.8686, -3.7704],
        [-2.7731, -3.3128,  5.7153],
        [-3.8209,  7.1259, -3.4798],
        [-4.2036,  7.8285, -3.4194],
        [ 0.3300, -4.3585,  1.7409],
        [ 0.4604, -4.6320,  1.8126],
        [-4.2334,  8.0462, -3.7145],
        [-3.7368,  6.6136, -2.7433],
        [ 0.9526, -4.3692,  0.1446],
        [-3.1374,  4.4203, -1.3923],
        [-4.1841,  7.8467, -3.6228],
        [-4.1242,  7.7512, -3.5523],
        [-4.0823,  7.6272, -3.6298],
        [-3.6645,  5.4584, -1.4700],
        [-3.0376,  5.4998, -3.1107],
        [-3.6960,  6.7886, -3.0655],
        [-3.5939,  6.2895, -2.7148],
        [-3.8872,  6.6316, -2.4759],
        [-4.0741,  7.6271, -3.4948]

Epoch number 7
 Current loss 0.0395817868411541

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.9858,  6.8125, -2.3595],
        [-4.3390,  8.4920, -4.1538],
        [-1.2742, -3.7089,  4.1551],
        [-4.0751,  7.6100, -3.4148],
        [-4.2219,  7.9720, -3.6579],
        [-4.2002,  6.5799, -1.8694],
        [-4.1288,  5.7200, -1.1102],
        [-3.7915,  6.0999, -2.0905],
        [ 1.1491, -4.4061, -0.1637],
        [-3.7499,  6.1005, -2.1572],
        [ 0.8632, -4.7221,  0.9220],
        [-3.2065,  6.3515, -3.6576],
        [-4.5101,  6.9579, -1.8203],
        [-2.2603,  4.1320, -2.8153],
        [-3.8521,  7.2047, -3.4167],
        [-2.7247, -2.0299,  4.5730],
        [-3.9890,  7.0821, -2.9049],
        [-2.3927, -2.8722,  4.9029],
        [-3.2953,  5.8285, -2.8523],
        [-3.9040,  7.2858, -3.5213],
        [-4.1564,  6.0774, -1.2790],
        [-3.3061,  6.2894, -3.3677],
        [-2.0678, -3.3921,  4.9581],
        [-3.9350,  6.2797, -2.0008],

Epoch number 7
 Current loss 0.02603679522871971

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.5329, -3.0695,  5.2456],
        [-3.7770,  5.6789, -1.6192],
        [-4.2331,  8.1236, -3.8620],
        [-2.1722, -3.7712,  5.3708],
        [-4.0713,  7.8137, -3.7322],
        [-4.2160,  8.2277, -4.0800],
        [-4.3024,  7.9601, -3.3286],
        [-2.8726,  3.7481, -0.6446],
        [-2.1862, -3.7102,  5.3115],
        [-3.8603,  6.6997, -2.6895],
        [ 1.1058, -3.5422, -1.0406],
        [-2.8389,  3.1225, -0.5549],
        [-3.5677,  5.6512, -1.8824],
        [-4.6262,  7.1517, -1.8994],
        [-3.6329,  5.8924, -2.0743],
        [-4.0850,  5.7536, -1.3280],
        [-1.8422,  3.9600, -3.5451],
        [-4.3673,  8.5586, -4.1511],
        [-3.5957,  5.4762, -1.7081],
        [-4.3127,  8.3316, -3.9912],
        [-3.5160,  5.3030, -1.5534],
        [-0.4089, -4.0666,  3.1505],
        [-3.7135,  5.8475, -1.6773],
        [-3.9388,  6.2670, -2.0056]

Epoch number 7
 Current loss 0.05373111739754677

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.3873,  6.1293, -2.8412],
        [ 1.0137, -4.5265,  0.2283],
        [-4.3621,  8.3311, -3.8705],
        [-3.1390,  4.6019, -1.5864],
        [-4.0934,  7.9031, -3.8153],
        [-4.4916,  8.6388, -3.9686],
        [-4.2528,  7.4970, -2.9610],
        [-0.9564, -4.1022,  4.0937],
        [-3.2777,  5.1343, -1.7796],
        [-2.6650, -3.4624,  5.6435],
        [-1.8652, -3.5534,  4.7707],
        [ 1.4343, -4.9360, -0.0217],
        [-4.3418,  6.2692, -1.2482],
        [-4.0702,  7.7787, -3.7356],
        [-3.3575,  5.7570, -2.4169],
        [-4.2315,  7.7040, -3.1878],
        [-3.9045,  6.7006, -2.4842],
        [-4.0258,  6.5386, -2.1148],
        [ 0.7974, -2.8051, -0.5906],
        [-1.1273, -4.0541,  4.3514],
        [-3.9143,  6.3666, -2.0707],
        [ 1.3632, -4.7448, -0.1002],
        [-3.2076,  4.8404, -1.6597],
        [-4.2642,  8.0332, -3.6255]

Epoch number 7
 Current loss 0.07093606144189835

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.3386,  5.7570, -2.5296],
        [-4.0764,  6.8979, -2.3759],
        [-2.8172, -3.0645,  5.5695],
        [-3.9706,  7.3163, -3.2218],
        [-1.9022,  2.8085, -2.2700],
        [-4.2366,  7.5817, -3.0913],
        [-3.2588,  5.5750, -2.5584],
        [-4.1986,  7.6685, -3.2408],
        [-1.0097, -4.6053,  4.6493],
        [-3.8214,  4.4269, -0.0234],
        [-3.7110,  6.6787, -2.9416],
        [-1.8035, -3.7805,  4.9681],
        [-4.2847,  5.5374, -0.6798],
        [-2.8067,  3.3175, -0.5826],
        [-3.6062,  6.0681, -2.3913],
        [-2.9079,  5.2572, -2.4614],
        [-4.5466,  8.2361, -3.3437],
        [-3.1043,  3.4493, -0.5508],
        [-2.1444, -4.0577,  5.6140],
        [-3.2934,  5.8575, -2.7471],
        [-2.4444,  4.0480, -2.6101],
        [-2.6457, -2.1204,  4.6108],
        [-2.0613, -3.6311,  5.1512],
        [-3.8702,  7.2341, -3.3854]

Epoch number 7
 Current loss 0.05255427956581116

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.0099, -3.0228,  5.7222],
        [-4.2257,  8.0441, -3.6668],
        [-3.3508,  5.9556, -2.8570],
        [ 0.8098, -5.1754,  1.4552],
        [-4.1112,  7.5404, -3.3426],
        [-3.3010,  5.3217, -2.1530],
        [-3.9291,  6.9033, -2.8101],
        [-2.2991, -3.6234,  5.3665],
        [-4.2668,  8.0054, -3.5465],
        [-2.3877, -3.7587,  5.5575],
        [-4.1501,  7.8686, -3.6326],
        [-2.0266, -3.9305,  5.3278],
        [-2.7917,  3.6988, -0.7665],
        [-4.2495,  5.0078, -0.2505],
        [-2.1487,  0.8520,  0.7205],
        [-2.2840, -3.1429,  4.9825],
        [-3.7316,  6.2669, -2.4211],
        [-2.4374,  3.0676, -1.1765],
        [-4.2656,  8.2250, -3.9397],
        [-2.1877, -4.0219,  5.5954],
        [-1.7005, -4.4696,  5.4484],
        [-1.4118,  2.8602, -3.2918],
        [-1.9366, -3.9053,  5.2069],
        [-3.8829,  7.2983, -3.4851]

Epoch number 7
 Current loss 0.047824546694755554

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.3460,  8.4375, -4.0392],
        [-4.1451,  7.0904, -2.7614],
        [-3.5222,  5.6239, -1.9533],
        [-1.9648, -3.8484,  5.2177],
        [-4.3013,  8.3576, -4.1034],
        [ 1.3490, -4.5570, -0.5617],
        [-3.7852,  7.5347, -4.0732],
        [-4.3384,  7.3449, -2.6836],
        [-1.3595, -4.0244,  4.6328],
        [-4.5050,  3.9883,  1.0082],
        [ 0.4709, -2.7716, -0.4317],
        [-0.2565, -4.5353,  3.1274],
        [-0.9145, -4.4234,  4.3358],
        [ 0.5284, -2.7232, -0.7233],
        [-1.6925, -3.7998,  4.8221],
        [-4.3014,  6.1782, -1.3036],
        [-4.0442,  6.6180, -2.2353],
        [-3.5687,  6.0899, -2.2990],
        [-1.2923, -4.2487,  4.6309],
        [-0.8004, -4.7488,  4.6849],
        [-4.4920,  7.3887, -2.4077],
        [-3.1227,  4.6301, -1.7205],
        [-3.7251,  3.3556,  0.6577],
        [-1.6695, -4.1690,  5.1368

Epoch number 7
 Current loss 0.031126797199249268

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.3546,  8.5069, -4.2119],
        [-3.3236,  5.5525, -2.2606],
        [-4.4002,  8.5675, -4.2057],
        [-3.4755,  6.0415, -2.6779],
        [-2.6329, -2.1143,  4.5665],
        [-2.1542, -3.8911,  5.4623],
        [-4.4145,  7.8641, -3.2521],
        [-3.8247,  6.3170, -2.1557],
        [-3.8367,  7.0514, -3.3807],
        [-3.4506,  6.3136, -3.2282],
        [-2.0051, -3.6186,  5.0291],
        [ 1.4061, -4.9921, -0.2353],
        [-4.4373,  6.7930, -1.8284],
        [-3.6657, -1.6896,  5.2461],
        [-2.5347,  5.1456, -3.5888],
        [-2.3354, -4.2017,  5.9593],
        [-3.5687,  6.2741, -2.7658],
        [-4.1888,  7.2360, -2.7438],
        [-2.4091, -3.9751,  5.7995],
        [-3.7435,  5.2698, -1.2456],
        [-3.9296,  7.3415, -3.6396],
        [-4.2326,  8.1477, -3.9217],
        [-4.5166,  8.1290, -3.3747],
        [-1.0698, -2.5859,  2.4573

Epoch number 7
 Current loss 0.04052930325269699

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.1580,  8.1498, -4.0662],
        [-4.2300,  5.2913, -0.4638],
        [-3.3484,  5.0782, -1.7483],
        [-3.7734,  6.7635, -2.9426],
        [-3.6221,  5.2288, -1.4369],
        [-3.9877,  7.5097, -3.5648],
        [-2.5407, -3.2417,  5.4122],
        [-2.4471, -1.5162,  3.7123],
        [-2.3991, -2.9057,  4.9784],
        [-1.2666, -2.4279,  2.6628],
        [-4.4663,  7.9456, -3.2235],
        [-2.9283,  5.4034, -3.3530],
        [-4.1288,  6.8448, -2.3422],
        [-4.6011,  7.9052, -2.9157],
        [-2.7655, -3.8109,  5.9833],
        [-2.4015, -3.9403,  5.7766],
        [-2.6591, -1.0746,  3.6454],
        [-3.4248,  5.3453, -1.8617],
        [-3.9145,  6.2041, -2.0276],
        [-4.0643,  4.1481,  0.5795],
        [-4.5134,  8.5568, -3.8980],
        [-4.0601,  6.0431, -1.5322],
        [ 1.2634, -5.3200,  0.3927],
        [-4.4522,  8.2224, -3.6647]

Epoch number 7
 Current loss 0.07681232690811157

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.1679,  6.9891, -2.3595],
        [-4.6847,  8.7772, -3.8738],
        [-3.6798,  5.8906, -2.0216],
        [-1.9793, -3.3959,  4.7971],
        [-2.4240, -4.0329,  5.8168],
        [-4.4900,  8.7028, -4.2038],
        [-0.8207, -4.4764,  4.3007],
        [-4.1231,  7.9505, -3.9165],
        [-3.9917,  7.6477, -3.7653],
        [-2.0581, -3.6559,  5.1420],
        [-4.0326,  6.5768, -2.1701],
        [-1.3715, -2.0887,  2.8822],
        [-4.4500,  7.4810, -2.6676],
        [-4.9163,  8.4299, -2.9329],
        [-4.7139,  7.7809, -2.5117],
        [-2.1843, -4.1902,  5.7498],
        [-4.5683,  7.9262, -3.0455],
        [-3.7445,  6.7603, -2.9089],
        [ 0.2116, -3.3649,  0.7332],
        [-4.6680,  7.9501, -2.8141],
        [-4.5630,  8.8013, -4.1473],
        [-4.5493,  8.1879, -3.2968],
        [-4.7201,  6.1082, -0.7370],
        [-2.6538, -4.0326,  6.0817]

Epoch number 7
 Current loss 0.05244475230574608

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.1218,  7.6850, -3.5797],
        [-4.4688,  7.6225, -2.7845],
        [-3.2956,  5.3719, -2.1206],
        [-4.3442,  8.1205, -3.7332],
        [-4.3631,  8.1309, -3.6083],
        [-4.3617,  7.8777, -3.2804],
        [-3.5932,  6.6151, -3.5579],
        [-3.8220,  6.9299, -3.3897],
        [-3.3622,  5.3243, -1.8797],
        [ 1.2215, -3.8909, -1.1777],
        [-4.4414,  7.7727, -2.9706],
        [-4.1471,  7.4021, -3.2169],
        [-4.1315,  5.9081, -1.3016],
        [-4.6368,  8.6180, -3.7756],
        [-0.7095, -4.6813,  4.1335],
        [-4.4007,  7.3574, -2.6564],
        [-2.6884,  5.2135, -3.2299],
        [-1.7432,  3.2661, -3.3024],
        [-3.0335,  5.4873, -2.7353],
        [-4.2361,  6.8400, -2.2103],
        [-4.2495,  8.3605, -4.2308],
        [-4.5333,  8.7317, -4.1282],
        [-3.8257,  6.7673, -2.9543],
        [-3.9611,  7.5430, -3.7444]

Epoch number 7
 Current loss 0.049466364085674286

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 1.4424, -4.6935, -0.6303],
        [-4.2050,  8.1906, -4.0820],
        [-4.2982,  6.4311, -1.5891],
        [-2.2601, -3.0023,  4.8697],
        [-4.6153,  8.1485, -3.2438],
        [-2.6868, -2.8349,  5.1282],
        [ 0.8998, -3.2122, -0.9948],
        [-4.0499,  7.2044, -3.0421],
        [-3.8890,  6.8354, -3.1523],
        [-2.5794, -1.1234,  3.4636],
        [-3.3383,  6.6384, -3.6961],
        [-3.2516, -2.3648,  5.3785],
        [-3.9743,  6.0927, -1.8298],
        [-2.4495, -3.0413,  5.1260],
        [ 1.2402, -5.1425,  0.3781],
        [-4.0693,  7.4702, -3.4587],
        [-3.0961,  4.3318, -1.4943],
        [-3.7622,  5.4919, -1.3806],
        [-4.3512,  7.7976, -3.3457],
        [-4.5295,  7.9434, -3.1065],
        [-2.3952, -3.8977,  5.6740],
        [-3.3358,  5.8371, -2.8500],
        [-4.1694,  7.6516, -3.3753],
        [-4.5953,  8.8251, -4.1773

Epoch number 7
 Current loss 0.06271816790103912

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.7083, -1.5357,  4.0587],
        [ 1.1584, -4.6877,  0.5246],
        [-1.9844, -2.6770,  4.2369],
        [ 1.3625, -5.2851,  0.4018],
        [-2.3648, -3.0974,  5.0597],
        [-4.2364,  6.5598, -2.0291],
        [-4.1102,  5.2876, -0.6476],
        [-2.6138, -1.1983,  3.7110],
        [-4.1160,  7.2498, -3.0250],
        [-3.2853,  6.0453, -2.8866],
        [-1.8518,  3.7241, -3.9215],
        [-2.7915, -1.9784,  4.5360],
        [-4.8721,  6.8666, -1.3380],
        [-3.6750,  5.7763, -2.0149],
        [-4.6195,  8.6437, -3.8830],
        [-4.4261,  6.8259, -1.8570],
        [-2.4181, -4.2494,  6.0234],
        [-2.4393, -4.0534,  5.9223],
        [-0.6851, -1.7787,  0.5961],
        [-4.3077,  8.3888, -4.1914],
        [-3.5167,  6.0092, -2.5305],
        [-2.4314,  4.2824, -2.7591],
        [-2.5975, -3.6721,  5.7254],
        [-4.3689,  7.4567, -2.7282]

Epoch number 7
 Current loss 0.031119024381041527

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.8485, -3.7208,  4.9629],
        [-2.1771, -2.2230,  3.9982],
        [-4.2303,  6.4392, -1.9438],
        [-3.0016,  6.0064, -3.5874],
        [-3.7265,  6.7435, -3.1208],
        [-4.0136,  7.5444, -3.6379],
        [-4.1382,  6.4965, -1.9862],
        [-3.8810,  5.9910, -1.7584],
        [-4.2524,  8.0457, -3.8059],
        [-4.4616,  8.5334, -4.0806],
        [-3.1985,  3.3232,  0.0960],
        [-4.5497,  8.7699, -4.2329],
        [-3.6888,  4.9437, -0.8198],
        [-1.1024, -2.3930,  2.3095],
        [-4.9546,  8.2858, -2.7608],
        [-4.5224,  7.9021, -3.1121],
        [-4.2865,  8.1325, -3.8616],
        [-3.8946,  6.6369, -2.4388],
        [-4.1464,  5.7622, -1.1856],
        [-2.8878, -0.0338,  2.9661],
        [-3.6667,  4.9965, -0.8522],
        [ 0.7846, -4.8875,  1.3915],
        [-2.3149, -3.8873,  5.5375],
        [-3.6261,  6.4925, -3.1036

Epoch number 7
 Current loss 0.03863462805747986

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.4622,  8.5476, -4.0139],
        [-2.6533, -1.9699,  4.3964],
        [-4.5026,  8.6537, -4.2276],
        [-4.4768,  8.5703, -4.0492],
        [ 0.9701, -3.8315, -0.5298],
        [ 1.3192, -4.1451, -0.7254],
        [-3.6374,  5.6370, -2.0004],
        [-4.5886,  8.4727, -3.6743],
        [-4.4230,  8.5302, -4.1686],
        [-4.1297,  4.0156,  0.3771],
        [ 0.4048, -4.3773,  1.9238],
        [-4.7515,  8.5819, -3.5750],
        [-0.2510, -3.3592,  1.5898],
        [-3.2817, -0.8960,  4.1333],
        [-2.9504, -1.9916,  4.6825],
        [-3.8581,  7.1859, -3.4887],
        [ 0.6560, -2.4213, -1.5143],
        [-4.6263,  7.7967, -2.7415],
        [ 1.4413, -4.9335, -0.3930],
        [-0.2210, -3.0771,  1.2522],
        [-4.6177,  8.9206, -4.2331],
        [-2.8695, -1.9342,  4.5788],
        [-3.2277,  5.8271, -3.3096],
        [-3.0120, -3.0128,  5.6221]

Epoch number 7
 Current loss 0.03257102891802788

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.7567, -2.3878,  3.4057],
        [-1.5838, -3.3966,  4.3110],
        [-3.8785,  6.6431, -2.6250],
        [-3.7589,  6.0759, -2.1807],
        [-1.4373,  2.6267, -2.3939],
        [-4.5573,  8.4902, -3.7756],
        [-4.5242,  8.3691, -3.7114],
        [-4.5307,  8.7242, -4.2510],
        [-3.6955,  6.3361, -2.5631],
        [-3.1253,  5.4963, -3.0919],
        [-4.5486,  8.6434, -4.0803],
        [-4.7149,  7.0521, -1.6871],
        [-4.5317,  8.6753, -4.1328],
        [-3.9129,  6.5859, -2.4442],
        [-3.9013,  7.5286, -3.8987],
        [ 0.0329, -0.3280, -3.1312],
        [-4.2702,  7.8107, -3.5362],
        [-2.0302, -4.3570,  5.6945],
        [-4.0344,  6.7096, -2.4207],
        [-4.4635,  7.7001, -2.9443],
        [-4.7113,  3.6996,  1.5189],
        [-1.3507, -2.7959,  3.3680],
        [-2.3287, -4.1951,  5.8514],
        [-4.5179,  8.2300, -3.5556]

Epoch number 7
 Current loss 0.1029052585363388

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.7724,  4.5155, -2.6433],
        [-4.7900,  8.3347, -3.1859],
        [-4.6162,  8.5415, -3.6863],
        [-4.4596,  7.1027, -2.2389],
        [-1.6787,  3.5254, -3.8584],
        [-2.6604, -3.2456,  5.4071],
        [-1.8646, -4.4201,  5.5505],
        [ 0.9820, -4.2241, -0.1083],
        [-3.7694,  6.4404, -2.6805],
        [ 1.2546, -4.8456, -0.0915],
        [ 1.2791, -4.2657, -0.7699],
        [-4.3347,  8.2747, -3.9231],
        [-3.6070,  4.0749,  0.0340],
        [-3.8162,  4.4893, -0.3681],
        [-4.3943,  8.1220, -3.6663],
        [-4.2350,  7.9031, -3.5813],
        [-4.0124,  6.9110, -2.8842],
        [-3.8171,  6.8520, -3.1535],
        [-4.3882,  8.3540, -4.0219],
        [-3.4082,  6.0911, -3.0720],
        [-3.4132, -1.0821,  4.3634],
        [-2.1106, -4.3105,  5.6896],
        [-4.7445,  7.7399, -2.5170],
        [-2.0674, -3.7370,  5.1760],

Epoch number 7
 Current loss 0.07488138973712921

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.5808, -3.0357,  4.0109],
        [-4.5222,  8.5066, -3.8450],
        [-4.6067,  8.1949, -3.3192],
        [-3.9912,  6.5214, -2.3994],
        [-3.3520,  2.6508,  0.7685],
        [-2.0338, -4.1900,  5.4930],
        [-4.1199,  7.8276, -3.8167],
        [-3.5036,  5.9260, -2.2932],
        [-2.2930, -3.9135,  5.5839],
        [-4.2166,  7.3794, -3.0895],
        [ 1.3141, -4.0948, -0.8320],
        [-2.6376, -2.6743,  4.9709],
        [-3.7599,  7.0354, -3.4714],
        [-1.4277, -3.0605,  3.9750],
        [-4.4877,  8.4492, -3.9675],
        [-2.4932, -2.9791,  4.9809],
        [ 0.9855, -5.1676,  1.0344],
        [-1.6782, -4.6336,  5.5670],
        [-4.4440,  8.6575, -4.2720],
        [-4.3433,  8.3361, -4.0908],
        [-1.0298, -4.0179,  4.4242],
        [-2.3516, -3.8548,  5.5602],
        [-1.6005, -4.5513,  5.3619],
        [-4.3749,  8.0303, -3.4496]

Epoch number 7
 Current loss 0.04563109576702118

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.6909, -1.0774,  1.8314],
        [-1.3588, -2.2720,  2.7481],
        [-4.0323,  7.5742, -3.6373],
        [-4.6772,  8.7590, -3.9715],
        [-4.2220,  7.7671, -3.5218],
        [-0.6266, -4.6188,  3.8805],
        [-4.5560,  8.6110, -3.9293],
        [-3.7306,  5.6145, -1.7590],
        [-3.6272,  6.5791, -3.4018],
        [-2.0027, -4.1672,  5.4333],
        [-4.7033,  8.6953, -3.8093],
        [-4.8234,  8.2198, -2.9421],
        [-4.1646,  7.8180, -3.7621],
        [-4.6054,  8.7227, -4.0394],
        [-1.7754, -3.6461,  4.7731],
        [-4.4616,  7.9054, -3.2120],
        [ 0.8712, -2.4586, -2.4644],
        [-2.1345, -4.7351,  6.1236],
        [-1.9430, -4.4040,  5.6033],
        [-3.5700,  5.5460, -1.9757],
        [ 1.6195, -5.6067,  0.0081],
        [-4.6523,  8.6272, -3.8234],
        [ 1.3909, -5.3326,  0.7435],
        [-4.2397,  7.8286, -3.5748]

Epoch number 7
 Current loss 0.04682590067386627

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.7169,  6.9374, -1.4938],
        [-3.4068,  5.2379, -2.0570],
        [-4.6086,  6.9845, -1.9125],
        [ 1.2999, -3.9295, -0.8869],
        [ 1.1979, -5.0754,  0.4210],
        [-4.5142,  7.1353, -2.1035],
        [-3.8379,  6.7898, -2.9481],
        [-4.6730,  6.4468, -1.2974],
        [-4.7705,  7.3822, -2.0886],
        [-0.2133, -4.7112,  3.2331],
        [-4.5488,  8.7703, -4.1699],
        [-3.3861,  5.7335, -2.6232],
        [-4.4578,  8.1823, -3.5566],
        [-4.7850,  6.7998, -1.3776],
        [-4.1522,  4.8325, -0.2328],
        [-4.6226,  8.8704, -4.1057],
        [-3.7786,  6.4679, -2.6795],
        [-1.8831, -3.9319,  5.0966],
        [-3.2693, -2.3236,  5.3412],
        [ 1.3289, -3.8727, -1.3724],
        [-1.8665,  3.7720, -3.5731],
        [-3.9134,  4.1721,  0.1705],
        [-2.8688,  5.1184, -3.0508],
        [-4.6152,  8.0559, -3.1109]

Epoch number 7
 Current loss 0.041386380791664124

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.2362, -4.4695,  6.0011],
        [-4.4514,  8.0965, -3.4360],
        [-3.5880,  5.1480, -1.4600],
        [-2.5353,  4.9431, -3.2845],
        [-1.4030, -4.3280,  4.8652],
        [-3.6325,  6.3910, -3.0638],
        [-3.8381,  6.3578, -2.2952],
        [-1.8805, -4.6337,  5.6901],
        [-2.9358,  5.4456, -3.5108],
        [-4.4270,  8.5122, -4.1016],
        [-1.4428, -4.5408,  5.1528],
        [-2.9760,  4.7664, -2.4129],
        [-3.9891,  5.6891, -1.3986],
        [-4.4859,  7.2922, -2.3752],
        [-1.6423, -2.5776,  3.2624],
        [-4.4472,  8.4257, -3.9186],
        [-3.1311,  4.0153, -0.5831],
        [-3.3180,  6.0904, -3.4400],
        [-3.8908,  7.1788, -3.6962],
        [-1.4586, -3.7664,  4.6668],
        [-3.8936,  6.5347, -2.4071],
        [-0.4518, -3.9106,  2.9186],
        [-2.7629, -1.7546,  4.2494],
        [-4.7189,  8.3650, -3.3625

Epoch number 7
 Current loss 0.05501943454146385

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.4999,  8.1675, -3.5795],
        [-3.7622,  5.2993, -1.2333],
        [ 1.1726, -5.3000,  0.7231],
        [-3.4890,  6.0240, -2.9103],
        [-1.1948, -4.9666,  5.1656],
        [-2.5237, -3.6113,  5.6008],
        [-2.2054, -4.0797,  5.6199],
        [ 1.5226, -4.7029, -0.7077],
        [-3.5765,  4.8603, -0.8250],
        [-2.1703, -4.0571,  5.4849],
        [-4.7070,  7.8231, -2.6390],
        [ 1.5022, -5.0476,  0.1604],
        [-1.8966,  2.9936, -1.9061],
        [-3.8261,  4.9860, -0.9251],
        [-4.5906,  8.4848, -3.7241],
        [-3.0572,  5.1716, -2.2333],
        [-2.1166, -2.1313,  3.8996],
        [-2.0056,  2.6342, -1.0167],
        [-3.9132,  5.7562, -1.5425],
        [-3.7738,  6.8297, -3.5731],
        [-1.8029, -4.4590,  5.5008],
        [-4.3854,  8.0087, -3.6033],
        [-4.5280,  8.6401, -4.1574],
        [-1.8048, -3.4005,  4.5738]

Epoch number 7
 Current loss 0.06230840086936951

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-0.6587, -4.6123,  4.0457],
        [-4.6259,  8.6212, -3.8800],
        [-4.6027,  8.6800, -4.0513],
        [-2.2429, -1.1420,  3.1706],
        [-3.7631,  6.7355, -2.9972],
        [-4.4540,  8.4069, -3.9842],
        [-4.4714,  6.0133, -0.8853],
        [-4.3295,  7.7156, -3.3195],
        [-4.6078,  8.4944, -3.6836],
        [-3.9962,  7.0919, -3.1610],
        [-2.5782,  2.5525, -0.2703],
        [-1.6725, -3.3524,  4.3674],
        [-4.6898,  8.7686, -3.9348],
        [-4.4556,  8.4100, -3.9135],
        [-0.5176, -4.2639,  3.3147],
        [-2.5703, -2.4175,  4.6555],
        [-3.6290,  5.9698, -2.2028],
        [-2.6561,  4.1874, -2.3664],
        [-3.3369,  5.6294, -2.6033],
        [-1.8405, -1.6920,  2.8073],
        [-4.6013,  8.7931, -4.0504],
        [-2.3614, -2.0074,  4.0811],
        [-3.9676,  6.5258, -2.4321],
        [-4.7853,  7.0427, -1.5936]

Epoch number 7
 Current loss 0.07552470266819

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.9065, -3.5119,  4.8081],
        [-3.6840,  5.3729, -1.2652],
        [-4.0137,  7.3282, -3.4606],
        [-0.5013, -4.4212,  3.9064],
        [ 1.5026, -4.6758, -0.7249],
        [-4.5903,  8.7611, -4.1279],
        [-1.2930, -2.5328,  2.7090],
        [-4.7840,  8.6478, -3.5934],
        [-3.8117,  6.9765, -3.6417],
        [-4.2425,  5.3573, -0.6955],
        [-4.6244,  8.4308, -3.5824],
        [-4.0666,  6.7752, -2.4879],
        [-4.2159,  6.9713, -2.5461],
        [-3.9293,  7.5662, -3.9165],
        [-1.9203, -4.1755,  5.3454],
        [-4.4175,  6.8162, -1.9250],
        [-4.6148,  8.8309, -4.1401],
        [ 1.5870, -5.0782, -0.2751],
        [-4.1950,  7.6828, -3.6757],
        [-0.1290, -4.8398,  3.1165],
        [ 1.3881, -4.2776, -0.8861],
        [-2.2212, -3.6482,  5.2540],
        [ 1.5216, -5.2107, -0.2660],
        [-3.4172,  4.9226, -1.5943],
 

Epoch number 7
 Current loss 0.05268746241927147

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.2246,  5.6601, -0.8717],
        [-3.9856,  7.0207, -3.1761],
        [-4.1133,  7.4607, -3.4558],
        [-4.4451,  8.3438, -3.9439],
        [-3.9491,  6.1562, -1.9395],
        [ 1.3479, -4.2144, -1.2969],
        [-4.5089,  8.1692, -3.4534],
        [-4.6664,  6.8998, -1.4965],
        [-3.8697,  6.5337, -2.5277],
        [-3.9374,  5.7922, -1.5380],
        [-1.8316, -4.8055,  5.8034],
        [-4.0786,  7.3251, -3.3206],
        [-4.4908,  8.1467, -3.5806],
        [-4.8469,  8.9725, -3.8486],
        [-3.5796,  6.8491, -3.7526],
        [-2.1227, -3.9567,  5.3410],
        [-1.6998, -1.7626,  2.8051],
        [-4.0460,  5.8055, -1.4216],
        [-1.8329, -3.0363,  4.3268],
        [ 0.0570, -4.3197,  2.2568],
        [-4.5260,  8.5556, -3.9807],
        [-3.2358, -1.2401,  4.3444],
        [-3.9852, -1.0616,  5.0606],
        [-2.8439,  5.4145, -3.3766]

Epoch number 7
 Current loss 0.04877536743879318

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.1396,  6.5583, -2.0017],
        [-4.6871,  8.7840, -3.9223],
        [-4.5909,  8.7858, -4.0966],
        [-4.6581,  8.6186, -3.8198],
        [ 0.1468, -1.6902, -1.0891],
        [-4.2112,  7.5996, -3.4939],
        [-4.5826,  8.6814, -4.0094],
        [-4.0561,  6.6419, -2.3543],
        [-4.8831,  9.0882, -3.9241],
        [-3.0354, -3.2359,  5.9000],
        [-4.5112,  8.3676, -3.6839],
        [-1.4320, -4.0912,  4.6978],
        [-3.7450,  5.7786, -1.8061],
        [-3.1476,  5.5858, -3.0756],
        [-4.2937,  6.3937, -1.7526],
        [-4.4989,  5.1568, -0.0338],
        [-4.5611,  8.4433, -3.7431],
        [-4.5821,  8.2314, -3.4708],
        [-3.7940,  4.8996, -0.6832],
        [-3.4143,  6.0316, -3.1757],
        [-3.8246,  6.8677, -3.3093],
        [ 0.6330, -3.6001,  0.0925],
        [-3.0327, -3.3149,  5.8794],
        [ 1.1819, -3.8642, -0.9931]

Epoch number 7
 Current loss 0.06554325670003891

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6733, -1.0051,  3.4502],
        [-4.8301,  9.0563, -4.0409],
        [-3.8442,  6.0890, -1.9180],
        [-2.1721, -3.6848,  5.1997],
        [-4.5990,  6.1852, -1.0729],
        [-1.8951, -3.7621,  4.9189],
        [-4.5194,  8.0071, -3.2733],
        [-3.8763,  6.5415, -2.2944],
        [-3.6893,  6.6712, -3.2347],
        [-2.1580, -4.2880,  5.7394],
        [-3.9681,  5.9853, -1.6840],
        [-4.8184,  7.3281, -1.8878],
        [ 1.0984, -5.1130,  0.7683],
        [-4.1131,  6.9035, -2.4986],
        [-4.6905,  7.7605, -2.6127],
        [-4.6629,  8.7267, -3.9063],
        [-3.9440,  7.0795, -3.1865],
        [-4.1570,  4.5372,  0.1785],
        [-4.6017,  8.0824, -3.1686],
        [-1.1919, -3.7673,  3.9969],
        [-4.3076,  8.3367, -4.1026],
        [-4.2709,  7.1527, -2.6883],
        [-0.0198, -2.9036,  0.3564],
        [-4.3771,  8.0476, -3.6524]

Epoch number 7
 Current loss 0.06582800298929214

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.5844,  4.8089, -3.5423],
        [-3.8799,  6.1854, -2.1255],
        [-4.2266,  7.5768, -3.2882],
        [-4.6227,  8.9192, -4.2620],
        [-3.3916,  5.3547, -2.3235],
        [-4.6157,  8.8997, -4.2257],
        [ 1.5141, -4.8459, -0.3578],
        [-4.1352,  7.1968, -2.8956],
        [-3.8804,  6.3092, -2.2840],
        [-5.1052,  7.2681, -1.5239],
        [-4.6659,  8.4262, -3.5588],
        [-3.9136,  7.1051, -3.4063],
        [-3.9357,  6.2634, -1.9621],
        [-4.4941,  8.5839, -4.1304],
        [-3.6305,  0.7894,  2.8903],
        [-4.5456,  8.7163, -4.1900],
        [-4.5545,  8.7923, -4.2875],
        [-0.8033, -3.8013,  3.2165],
        [-3.7265,  6.7429, -3.3981],
        [-1.3500, -3.8852,  4.3060],
        [ 0.8731, -4.8310,  0.9684],
        [-3.1481, -3.3671,  6.1272],
        [-4.7061,  8.3294, -3.3771],
        [-3.9741,  5.9319, -1.6598]

Epoch number 7
 Current loss 0.050968918949365616

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.1227,  8.0008, -4.0986],
        [ 1.3884, -4.5993, -0.1465],
        [-4.4926,  8.2336, -3.6265],
        [-1.7031, -4.0739,  5.0440],
        [-4.3115,  8.4588, -4.3073],
        [-4.6463,  8.8448, -4.1227],
        [-2.7088, -2.8806,  5.2206],
        [-2.2239, -0.6282,  2.3447],
        [-4.5471,  8.2679, -3.5376],
        [-3.7872,  6.9081, -3.2936],
        [-3.7779,  6.5004, -2.6732],
        [-1.8527,  4.1085, -3.8015],
        [-4.5164,  8.0882, -3.3558],
        [-4.4837,  8.1308, -3.5134],
        [-5.0784,  6.5870, -0.8504],
        [-3.9284,  6.2546, -2.1499],
        [-4.2691,  7.0084, -2.4166],
        [-4.3092,  7.8953, -3.4982],
        [-2.2153, -3.7061,  5.2875],
        [-4.4368,  8.1249, -3.6443],
        [-4.6969,  7.0222, -1.7886],
        [-4.4976,  8.1508, -3.5412],
        [-3.9402,  7.1722, -3.4838],
        [-4.2643,  7.9204, -3.7735

Epoch number 7
 Current loss 0.06313074380159378

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.4940,  8.5031, -3.9713],
        [-1.3266, -3.5531,  3.9044],
        [-4.0664,  7.6616, -3.7099],
        [-1.0350, -3.9136,  3.7235],
        [ 0.7629, -3.6535, -0.3788],
        [-4.0918,  7.5063, -3.4515],
        [-4.6757,  8.4047, -3.4844],
        [-3.8576,  6.6498, -2.8154],
        [-4.4511,  8.1463, -3.5485],
        [-4.0730,  7.4480, -3.3606],
        [-4.1872,  6.9070, -2.4838],
        [-2.0211, -1.2924,  2.8478],
        [-3.6723,  6.4032, -2.7798],
        [-1.5179, -4.1441,  4.9289],
        [ 0.4575, -4.3538,  1.5978],
        [-4.5217,  8.2111, -3.5052],
        [-2.3817, -2.3294,  4.2970],
        [-2.9514,  4.5741, -2.1473],
        [-2.8402, -3.3253,  5.6484],
        [-2.4018, -2.9324,  4.9052],
        [-4.8613,  8.6981, -3.5111],
        [-2.6404, -4.5415,  6.4671],
        [-4.4231,  8.0490, -3.5257],
        [-3.9679,  7.1633, -3.4260]

Epoch number 7
 Current loss 0.051595646888017654

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.8447,  7.2614, -3.6314],
        [-0.9982, -4.2435,  4.0341],
        [-1.9137,  4.0356, -3.7681],
        [-3.7052,  6.3289, -2.5908],
        [-1.5005, -1.9790,  2.7326],
        [-4.4001,  8.6557, -4.2957],
        [-4.4227,  8.6766, -4.2466],
        [-3.4405,  6.0606, -2.9248],
        [-4.3450,  7.2828, -2.5890],
        [-4.7696,  7.4306, -2.0751],
        [-4.4021,  8.5343, -4.1577],
        [-4.3985,  8.3508, -3.9296],
        [-1.9269, -4.4396,  5.5836],
        [-4.6807,  7.2976, -2.1152],
        [-4.6833,  8.8900, -4.1087],
        [-0.3260, -4.1205,  2.3248],
        [-4.5895,  8.0971, -3.2452],
        [-3.6895,  6.0187, -2.0871],
        [-4.2468,  7.6006, -3.1142],
        [-2.5856, -4.7853,  6.6317],
        [-2.0887, -3.1340,  4.7143],
        [-4.6079,  8.6734, -3.9312],
        [-3.3954,  5.6980, -2.4278],
        [-3.1720,  5.9697, -3.5877

Epoch number 7
 Current loss 0.08292928338050842

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.8244,  6.4610, -2.3413],
        [-4.1551,  5.5234, -0.9615],
        [-2.8009,  4.2616, -1.3822],
        [-0.1956, -5.3946,  3.8042],
        [-4.1824,  7.4055, -3.1150],
        [-3.6220,  4.9958, -0.9803],
        [-5.0615,  7.9000, -2.1833],
        [ 1.3952, -4.9987,  0.6048],
        [ 1.0905, -5.3517,  0.9007],
        [-4.4328,  5.5280, -0.5202],
        [-4.4462,  7.2324, -2.3321],
        [-2.4902,  4.6605, -3.2243],
        [-4.1753,  7.6354, -3.5337],
        [ 1.0283, -5.1421,  0.9169],
        [-4.3877,  7.8458, -3.2648],
        [-3.9536,  6.4110, -2.3113],
        [-3.0570, -3.6441,  6.1500],
        [-2.6289, -0.6838,  3.1909],
        [-2.3785, -4.7381,  6.3322],
        [-1.8434, -4.9935,  5.9911],
        [-1.4092, -4.2029,  4.7528],
        [-1.6322, -4.6827,  5.4772],
        [-4.2588,  7.7709, -3.4879],
        [-1.6985, -1.4376,  2.5861]

Epoch number 7
 Current loss 0.06913729012012482

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.2286,  7.9791, -3.6726],
        [-4.5278,  7.1031, -2.1632],
        [-3.3741,  6.2057, -2.9698],
        [-2.8111, -4.3485,  6.4751],
        [-3.1371,  4.1829, -1.2752],
        [-3.2700,  5.9682, -3.1672],
        [-3.7149,  6.3463, -2.7193],
        [-0.4157, -3.4608,  2.4757],
        [-4.4204,  8.4148, -3.8950],
        [ 1.3946, -4.6522, -0.2596],
        [-2.5796, -2.9270,  5.1375],
        [ 0.1357, -2.5360,  0.5423],
        [-3.3754,  3.5627,  0.2565],
        [-2.4565, -3.8549,  5.6314],
        [-3.2063,  6.0419, -3.3463],
        [-2.9530, -2.7460,  5.3575],
        [-2.7044,  4.6106, -2.5787],
        [-3.2085, -3.2831,  6.0217],
        [-2.2719, -4.4958,  5.9927],
        [-4.3414,  8.4730, -4.2354],
        [-3.7365,  6.9390, -3.6503],
        [-4.4426,  7.9260, -3.2429],
        [-3.6617,  6.2922, -2.5530],
        [-4.5286,  7.9337, -3.1483]

Epoch number 7
 Current loss 0.035061608999967575

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.0134,  6.9537, -2.7596],
        [-3.4019,  5.7287, -2.2320],
        [-0.9751, -4.3416,  4.3112],
        [-4.9713,  7.7090, -2.0899],
        [-4.3268,  8.6217, -4.5086],
        [-4.0114,  7.3465, -3.4742],
        [-3.3521,  4.6604, -1.3192],
        [-3.5563,  3.1099,  0.7246],
        [-4.8201,  7.9155, -2.5907],
        [-3.4210,  4.5774, -0.5791],
        [-3.9628,  7.8476, -4.1495],
        [-2.1300, -4.6959,  6.0490],
        [-1.9691, -4.0696,  5.3221],
        [-1.5784, -4.7224,  5.4421],
        [-2.2434, -4.8455,  6.2906],
        [-3.1063,  5.7538, -3.3032],
        [-1.8178, -4.4682,  5.5123],
        [-4.2804,  7.5742, -3.1320],
        [ 1.1806, -5.0268,  0.2365],
        [ 1.2712, -4.8376,  0.0804],
        [-3.0793,  4.9382, -2.3780],
        [-4.4431,  7.7481, -2.9876],
        [-3.6951,  6.8210, -3.5666],
        [-1.9937, -4.9029,  6.1375

Epoch number 7
 Current loss 0.06371458619832993

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.7018,  6.6225, -3.1094],
        [ 0.8821, -3.0653, -0.4325],
        [-3.9961,  7.5595, -3.5947],
        [-4.3558,  8.5909, -4.3340],
        [-2.2870, -2.6527,  4.5191],
        [-3.4748,  6.2557, -3.0412],
        [-4.4939,  8.8269, -4.2967],
        [-4.5963,  8.9162, -4.3204],
        [-3.2065,  5.8282, -3.1380],
        [-3.4967,  6.3717, -3.4882],
        [ 0.2898, -4.4083,  1.5862],
        [-2.9586,  4.4653, -1.9241],
        [-4.5491,  8.7566, -4.1489],
        [-3.9349,  7.6469, -3.9774],
        [-3.0321,  5.9066, -3.7294],
        [-4.3016,  6.8808, -2.2412],
        [-4.4422,  8.3635, -3.8062],
        [-4.5517,  8.6428, -3.9842],
        [ 1.1522, -3.6716, -1.0387],
        [ 1.3365, -4.3082, -0.6036],
        [-4.4530,  8.1603, -3.5855],
        [-4.1210,  6.1414, -1.6381],
        [-4.3018,  7.6961, -3.2794],
        [-3.1690,  5.5708, -2.9603]

Epoch number 7
 Current loss 0.04697712883353233

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.7063,  7.2291, -3.7391],
        [-4.1380,  7.4817, -3.4057],
        [-4.6197,  8.9406, -4.2482],
        [-1.6718, -3.9438,  4.8425],
        [-3.8039,  4.4404, -0.0297],
        [-2.7156, -3.8218,  6.0213],
        [-4.5788,  8.8824, -4.2296],
        [-3.6788,  5.0561, -1.3520],
        [-1.8420, -3.6485,  4.8349],
        [-3.3028,  5.7907, -2.9558],
        [-4.4952,  8.7972, -4.2396],
        [-3.6634,  4.8670, -1.0255],
        [-4.1043,  5.6862, -1.2646],
        [-4.5060,  8.6824, -4.2016],
        [-1.6588, -2.1520,  3.0748],
        [-4.3822,  8.7608, -4.4868],
        [-4.5379,  8.8689, -4.3005],
        [-3.8085,  5.8449, -1.6375],
        [-3.5590,  6.4717, -3.2086],
        [ 1.4749, -4.7455, -0.2008],
        [ 1.3191, -4.3523, -0.7203],
        [-3.9228,  6.1855, -2.0627],
        [ 1.4177, -4.6354, -0.5499],
        [-4.3758,  5.8757, -1.0020]

Epoch number 7
 Current loss 0.051660191267728806

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.2573,  8.2649, -4.1672],
        [-3.8545,  5.6722, -1.4817],
        [-4.1396,  7.9285, -4.0403],
        [-4.2634,  8.1959, -3.9765],
        [-1.0302, -3.9928,  3.9243],
        [-2.5889,  0.6764,  1.6909],
        [-3.4513,  6.4874, -3.6536],
        [-4.5388,  8.8770, -4.3578],
        [-2.0122, -4.3870,  5.6704],
        [-3.6334,  6.5935, -3.1146],
        [-3.5655,  4.4454, -0.7345],
        [-1.3318, -3.8525,  4.1752],
        [-0.4628, -4.3544,  3.0461],
        [-0.6641, -3.8789,  3.3716],
        [-1.6363, -2.4528,  3.2867],
        [-3.7619,  7.2379, -3.7734],
        [-4.4459,  8.7190, -4.3274],
        [ 1.4782, -4.8877, -0.4749],
        [-2.9702,  4.4431, -1.8922],
        [-3.9760,  6.9721, -2.9172],
        [-2.5744, -3.1307,  5.2172],
        [-4.4493,  8.3213, -3.8057],
        [-3.5853,  6.2088, -2.6971],
        [-3.5692,  6.1391, -2.7225

Epoch number 7
 Current loss 0.09156060218811035

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.7450,  7.9451, -2.7073],
        [-3.7805,  6.7826, -3.1701],
        [-4.3763,  8.2660, -3.7745],
        [-1.3169, -1.2122,  1.4333],
        [ 1.3749, -4.4949, -0.5154],
        [-1.1112, -1.7247,  1.6350],
        [ 1.0377, -5.0135,  0.6745],
        [-3.5527,  6.3838, -3.1104],
        [-3.6936, -0.2566,  3.8721],
        [-4.9503,  7.7810, -2.2676],
        [-3.7812,  4.9897, -0.8902],
        [-3.4828,  6.1667, -2.9665],
        [-2.9391, -3.0029,  5.5593],
        [-4.3290,  8.5973, -4.3687],
        [-4.2849,  8.4326, -4.2783],
        [-2.3882, -3.5806,  5.4547],
        [-4.5106,  8.3585, -3.6672],
        [-1.9522, -3.4331,  4.7897],
        [-4.6682,  8.7937, -3.9263],
        [-4.3260,  8.3775, -4.0967],
        [ 1.2751, -5.3909,  0.5574],
        [-3.5209,  6.3433, -2.9065],
        [-3.3748,  5.7720, -2.6032],
        [-3.9711,  7.7868, -4.0101]

Epoch number 7
 Current loss 0.07433836907148361

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.5826,  8.5688, -3.8298],
        [-3.2349,  6.2875, -3.7280],
        [-3.9769,  6.6170, -2.4549],
        [-4.5261,  8.2688, -3.5931],
        [ 1.3388, -4.2827, -1.0053],
        [-2.6071, -4.4556,  6.3488],
        [-4.2530,  8.2328, -3.9853],
        [-4.3927,  8.5375, -4.2626],
        [-2.5601, -4.3725,  6.2270],
        [-3.4021,  6.3805, -3.3506],
        [-4.1744,  7.8142, -3.6897],
        [-1.8459, -1.2644,  2.4507],
        [-3.8805,  6.6571, -2.5261],
        [ 1.2076, -3.8340, -0.4596],
        [-3.3997,  6.3643, -3.4939],
        [-3.3887,  6.2189, -3.3579],
        [-4.3071,  8.0724, -3.7225],
        [-4.6867,  8.5047, -3.5420],
        [-4.5619,  8.8159, -4.2155],
        [-4.4737,  8.9522, -4.4849],
        [-3.2683,  6.2174, -3.4971],
        [ 0.1759, -3.3542,  0.7549],
        [-2.0445, -4.1758,  5.5463],
        [-2.2721, -4.8129,  6.2960]

Epoch number 7
 Current loss 0.03148062154650688

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.8402,  5.4626, -1.2333],
        [-1.7458, -4.4019,  5.3778],
        [-2.3466, -3.4502,  5.2216],
        [-3.2760, -2.3508,  5.2274],
        [-3.7444,  5.8932, -1.8722],
        [-4.5700,  8.9723, -4.4376],
        [-4.2066,  8.4983, -4.5318],
        [-1.9625, -4.1242,  5.3461],
        [-4.4453,  8.8063, -4.5208],
        [-3.3246,  6.2843, -3.6386],
        [-3.7287,  6.0887, -2.2311],
        [-3.1777,  4.8227, -1.8146],
        [-3.3373,  6.4248, -3.7692],
        [-1.0946, -4.2509,  4.4737],
        [-4.3478,  8.2054, -3.8265],
        [-4.2825,  8.4311, -4.3115],
        [-1.0038, -4.0209,  4.4215],
        [-3.2955,  5.4633, -2.1076],
        [-4.1396,  8.3063, -4.4151],
        [-3.7697,  5.7855, -1.8157],
        [-2.8058, -2.6289,  5.0145],
        [ 0.1213, -4.4438,  2.0990],
        [-4.4533,  8.4998, -3.9687],
        [ 0.1933, -3.8042,  1.1491]

Epoch number 7
 Current loss 0.0461619570851326

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.5613, -3.0769,  4.0266],
        [-2.3958, -3.6481,  5.4371],
        [-2.2481,  5.1023, -4.2999],
        [-4.0453,  7.5818, -3.6766],
        [-3.6456,  6.2228, -2.5264],
        [-1.7871, -3.8342,  4.9190],
        [-4.0158,  6.2831, -2.0419],
        [-3.2806,  6.2502, -3.4702],
        [-4.4502,  8.8259, -4.4199],
        [-3.3469,  5.9830, -2.9195],
        [-4.0619,  7.7723, -3.7850],
        [-4.1770,  7.6758, -3.3537],
        [-3.5538,  6.0686, -2.5157],
        [-3.7831,  6.2872, -2.5967],
        [-2.0684, -4.5506,  5.8615],
        [-3.9499,  7.1319, -3.3134],
        [-0.1052, -3.6584,  1.8225],
        [-2.3354, -4.2277,  5.8236],
        [-4.5152,  6.3086, -1.3413],
        [-3.8895,  7.1707, -3.4282],
        [-3.3678,  6.1267, -3.1665],
        [-4.6885,  8.7499, -3.8274],
        [-3.0774,  4.3110, -1.0965],
        [-1.7025, -0.9491,  2.1172],

Epoch number 8
 Current loss 0.04098046198487282

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.6389,  5.9999, -2.3227],
        [-3.7117,  6.4514, -2.7703],
        [-4.4296,  8.8479, -4.5166],
        [-3.5453,  6.1727, -2.6832],
        [ 1.4870, -4.7653, -0.5081],
        [-4.1558,  8.0241, -4.1701],
        [-2.2682, -4.6541,  6.1785],
        [-2.2782, -1.6280,  3.6549],
        [-4.1645,  8.0320, -3.9414],
        [-3.2937,  5.0698, -2.0716],
        [-3.3285,  6.0593, -3.2758],
        [-3.5950,  6.0022, -2.3705],
        [-4.6225,  8.5518, -3.7489],
        [-3.8454,  6.6123, -2.5320],
        [-4.6345,  7.8962, -2.9332],
        [-3.5320,  6.1388, -2.9517],
        [-3.8633,  6.6124, -2.6134],
        [-1.7397, -3.8357,  4.8613],
        [-0.4923, -2.4538,  1.3945],
        [-3.7731, -1.4432,  5.0372],
        [-3.7178,  7.1032, -3.9990],
        [-4.2922,  7.8944, -3.4443],
        [-4.4635,  8.8268, -4.3318],
        [-3.8312,  7.0731, -3.6221]

Epoch number 8
 Current loss 0.05747537314891815

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.4919,  8.7763, -4.2954],
        [-3.1364,  4.2822, -1.2899],
        [-4.5213,  8.9967, -4.4982],
        [-3.6988,  6.4761, -2.8805],
        [-4.5709,  8.8944, -4.2508],
        [-1.8318, -1.6971,  2.6718],
        [-4.5138,  9.0002, -4.5107],
        [-3.1898,  3.5448, -0.1003],
        [-3.7029,  6.5878, -3.0405],
        [-4.2780,  8.6054, -4.4460],
        [-4.2935,  7.8411, -3.5160],
        [-2.4492, -3.6169,  5.4362],
        [-3.4202,  5.6746, -2.3678],
        [-4.6060,  8.0889, -3.1197],
        [-2.0175, -2.2075,  3.8170],
        [-4.5081,  8.8393, -4.3376],
        [-1.9026, -4.3405,  5.4679],
        [ 1.2694, -3.7609, -1.1451],
        [-2.1481, -2.7939,  4.5279],
        [-3.5814,  6.2297, -2.7443],
        [-3.1570,  4.4678, -1.1165],
        [ 1.3616, -4.1534, -1.0530],
        [-4.2391,  7.3189, -2.9330],
        [-2.3052, -4.4835,  6.0322]

Epoch number 8
 Current loss 0.02182803303003311

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.0721,  7.7341, -3.6656],
        [-3.4415,  6.0903, -3.0536],
        [-3.7668,  5.8142, -2.0106],
        [-1.6879, -3.9185,  4.9245],
        [-4.4038,  8.2982, -3.7235],
        [-2.9686,  5.8407, -4.0715],
        [-4.1044,  8.3060, -4.5731],
        [-4.0027,  7.0115, -2.9241],
        [ 1.4011, -5.2467,  0.2529],
        [ 1.2106, -5.3092,  1.1116],
        [-1.9183, -3.7647,  5.0661],
        [-4.5594,  8.8819, -4.2983],
        [-1.8184, -4.2904,  5.3313],
        [-4.7794,  8.0422, -2.7661],
        [-3.9484,  5.5722, -1.3749],
        [-4.5285,  8.7840, -4.1933],
        [-4.2246,  8.3584, -4.2562],
        [ 1.6408, -5.3852,  0.0228],
        [-1.7463, -4.9307,  5.8127],
        [-4.3523,  8.7524, -4.4920],
        [-0.6600,  1.5017, -2.8745],
        [ 1.3522, -5.3754,  0.2814],
        [ 1.5677, -5.1385, -0.2067],
        [-2.0705, -3.9548,  5.3066]

Epoch number 8
 Current loss 0.04667847231030464

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.5103,  5.7870, -2.4372],
        [-4.3621,  8.3551, -3.9342],
        [-3.8145,  6.6504, -2.7989],
        [-3.0060,  5.4860, -3.1430],
        [-4.3963,  8.6487, -4.3543],
        [-3.2967,  5.7657, -2.7448],
        [-4.3043,  6.8795, -2.3081],
        [-2.0055, -4.6092,  5.8761],
        [-4.3512,  8.6810, -4.4963],
        [ 1.4962, -4.7437, -0.4349],
        [-3.3469,  5.0776, -1.8665],
        [-3.9259,  5.8541, -1.8054],
        [-3.1740,  5.6137, -3.0784],
        [-3.7006,  6.4503, -3.0359],
        [-1.6863, -4.8307,  5.7467],
        [-2.6779, -3.3480,  5.5114],
        [-2.6548, -3.4462,  5.5799],
        [-3.7130,  6.7188, -3.2841],
        [-0.8014, -4.6715,  4.2906],
        [-1.6740, -4.1184,  4.9900],
        [ 0.5109, -5.4861,  2.6269],
        [-3.6460,  5.1804, -1.6079],
        [-3.9729,  6.3174, -2.1963],
        [-4.8401,  6.8130, -1.4582]

Epoch number 8
 Current loss 0.05135853588581085

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.3911,  8.8224, -4.5992],
        [-3.9624,  6.8520, -2.7556],
        [-4.0524,  7.1912, -3.2454],
        [-4.4116,  5.2137, -0.3528],
        [-4.5295,  8.7550, -4.2201],
        [-3.5597,  6.2246, -2.8724],
        [-3.5032,  5.8723, -2.3485],
        [-2.6664, -3.6797,  5.7073],
        [ 1.3235, -3.9877, -0.7876],
        [ 0.0416, -3.8040,  1.6989],
        [-3.7429,  6.2231, -2.4413],
        [-1.2960, -4.3302,  4.8198],
        [-3.5158,  6.2995, -3.2957],
        [-1.9164, -4.1604,  5.3301],
        [-3.7337,  6.3572, -2.6843],
        [-4.4738,  8.1185, -3.4497],
        [-4.4500,  8.6943, -4.2675],
        [-2.6532,  5.0022, -3.5881],
        [-4.5283,  8.4031, -3.5976],
        [-3.7256,  7.1359, -3.9622],
        [-3.8425,  6.3183, -2.3134],
        [-1.3772, -4.9852,  5.4243],
        [-3.4269,  6.1470, -3.1037],
        [-3.3997,  5.7262, -2.5987]

Epoch number 8
 Current loss 0.06142790615558624

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.5040, -4.0670,  4.6923],
        [-0.4805, -4.4202,  3.3375],
        [-2.7588, -3.4656,  5.6819],
        [-3.2273,  4.1044, -0.5680],
        [-3.5059,  6.5308, -3.7422],
        [-3.6197,  6.3779, -3.0550],
        [-1.8633, -4.1885,  5.3150],
        [-3.6812,  6.2003, -2.5592],
        [-4.5393,  7.6051, -2.7564],
        [ 1.5471, -5.1431, -0.4712],
        [-2.2052, -3.9046,  5.3947],
        [-4.0113,  7.6371, -3.9811],
        [-4.3318,  7.9081, -3.5141],
        [-1.8638, -4.6268,  5.6869],
        [-4.3299,  6.3440, -1.7267],
        [-4.6913,  8.7905, -3.9104],
        [-1.3435, -4.0721,  4.4607],
        [-3.4850,  6.1332, -3.0646],
        [-4.2443,  7.9390, -3.6172],
        [-4.5407,  7.3036, -2.4510],
        [-3.6407,  6.4141, -3.0253],
        [-4.1880,  5.5593, -1.0000],
        [-2.8778,  4.3836, -1.4794],
        [-2.4967, -3.3480,  5.3447]

Epoch number 8
 Current loss 0.05255516245961189

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.2180,  7.3092, -3.0078],
        [-3.4275,  6.7562, -3.7741],
        [-3.4411,  6.4815, -3.6687],
        [-2.9706,  5.5556, -3.6808],
        [-1.9988, -4.3157,  5.5959],
        [-3.4311,  6.2781, -3.4982],
        [-5.2891,  7.8461, -1.8510],
        [-3.6380,  6.6113, -3.5216],
        [-4.2316,  7.9516, -3.7470],
        [-4.0366,  6.6059, -2.2821],
        [-3.4554,  5.9834, -2.8010],
        [-3.5911,  5.5088, -1.9804],
        [-4.5261,  8.9532, -4.4593],
        [-5.0480,  7.1698, -1.4332],
        [-0.6515, -2.7538,  1.6646],
        [-4.1261,  8.0417, -4.0080],
        [-3.5325,  6.0345, -2.7496],
        [-1.7973, -1.6702,  2.5943],
        [-3.3584,  6.3594, -3.7421],
        [-2.2782, -2.7031,  4.5207],
        [-4.1566,  7.9877, -4.0574],
        [-4.1600,  6.0740, -1.6193],
        [-4.9766,  8.3366, -2.8634],
        [-4.3398,  8.1463, -3.7968]

Epoch number 8
 Current loss 0.04187614470720291

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.7447,  6.4206, -2.8044],
        [-3.5130,  6.4944, -3.5676],
        [-4.4647,  8.5233, -4.0176],
        [-2.3035, -4.3319,  5.8970],
        [-3.6735,  6.5995, -3.2136],
        [ 1.5897, -5.2901, -0.2060],
        [-2.1348, -3.8061,  5.3593],
        [-1.9924, -4.4917,  5.6773],
        [-0.4100, -4.9281,  3.5622],
        [-4.0267,  7.2653, -3.2042],
        [-3.0691, -1.6440,  4.4273],
        [-3.2983,  5.0485, -2.1033],
        [ 1.2126, -4.8906,  0.2830],
        [-3.9305,  8.0218, -4.5333],
        [-0.3132, -0.2020, -1.9911],
        [-1.9524, -4.5556,  5.6859],
        [-1.8474, -4.7426,  5.7163],
        [-2.6229,  2.9499, -1.0657],
        [-3.6886,  6.6073, -3.2756],
        [-4.0964,  7.8873, -3.9133],
        [-0.4409, -4.6032,  3.3436],
        [-3.9647,  7.7867, -4.3306],
        [ 1.4350, -4.6544, -0.0308],
        [-4.7233,  8.5832, -3.6383]

Epoch number 8
 Current loss 0.026041202247142792

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-5.0941,  8.7350, -3.2402],
        [-4.0133,  6.3962, -2.0765],
        [-3.8041,  5.9771, -1.8492],
        [-3.9048,  7.6501, -3.9831],
        [-1.8394, -4.3420,  5.3519],
        [-3.8945,  6.5188, -2.6265],
        [-1.9048,  3.3295, -3.3572],
        [-4.1052,  7.0222, -2.6759],
        [-2.3270, -3.9892,  5.5708],
        [ 1.5886, -5.0649, -0.1657],
        [-3.1944,  5.8981, -3.4919],
        [-4.1109,  6.3424, -2.1381],
        [-2.2723,  4.2117, -3.5784],
        [-2.4847, -3.5402,  5.4735],
        [-3.5979,  6.1735, -2.5870],
        [ 1.3537, -4.2079, -0.9145],
        [-3.9000,  6.0848, -2.0492],
        [-4.5426,  8.6331, -4.0925],
        [-1.6498, -4.6562,  5.4600],
        [ 1.5763, -5.4673,  0.0593],
        [-2.4192, -3.6021,  5.4261],
        [-3.3187,  5.7839, -3.1927],
        [-1.0126, -5.2837,  5.0899],
        [-3.7949,  7.0755, -3.7812

Epoch number 8
 Current loss 0.0483209602534771

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.6030,  6.5051, -3.4523],
        [-3.6813,  6.9176, -3.7251],
        [-2.9584,  5.6375, -3.7742],
        [ 0.2959, -2.0433, -0.6883],
        [ 1.4450, -4.4293, -1.2000],
        [-3.3276, -0.2285,  3.4575],
        [-4.1404,  7.6638, -3.5900],
        [-4.4023,  8.6807, -4.4938],
        [-4.3513,  6.5419, -1.9766],
        [-2.9566,  5.2633, -3.2155],
        [-3.2486,  6.1251, -3.7560],
        [ 1.2838, -4.2615, -0.2748],
        [-3.7141,  4.9287, -0.9507],
        [-0.8305, -4.5385,  4.1993],
        [-4.2567,  8.6102, -4.6083],
        [-4.0089,  7.1513, -3.2746],
        [-3.0167,  4.5787, -2.2293],
        [-4.6368,  7.7152, -2.7260],
        [-4.4310,  8.6877, -4.3658],
        [-4.5621,  8.7286, -4.2335],
        [-4.1201,  7.4135, -3.2565],
        [-4.3827,  8.4093, -4.0598],
        [ 0.1542, -2.7816, -0.1025],
        [-3.3407,  6.1683, -3.1618],

Epoch number 8
 Current loss 0.05171814560890198

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.4124,  8.4872, -4.1073],
        [-4.3894,  8.4673, -4.1364],
        [-4.3544,  8.6929, -4.4821],
        [-4.4262,  8.8071, -4.5279],
        [-0.8296, -4.2940,  3.7618],
        [-4.1603,  6.7649, -2.2301],
        [-3.9428,  6.8170, -3.0545],
        [-3.7795,  6.9375, -3.3403],
        [-4.9045,  7.4175, -2.0235],
        [-3.2712,  5.8747, -3.3829],
        [-3.6124,  5.6944, -2.2675],
        [-4.4383,  5.1794, -0.3085],
        [-4.6092,  8.7292, -3.9584],
        [-4.5032,  8.8169, -4.3470],
        [-4.4717,  8.5853, -4.2083],
        [-4.7050,  8.5975, -3.6094],
        [-4.4929,  8.9666, -4.6258],
        [-4.5502,  8.4626, -3.8016],
        [-3.0087, -3.5983,  6.0457],
        [-1.3362,  3.1904, -3.8654],
        [-4.2361,  8.3191, -4.3118],
        [-3.7199,  5.8254, -2.2029],
        [-4.1465,  6.5562, -2.1190],
        [-3.7843,  6.4402, -2.6897]

Epoch number 8
 Current loss 0.03999391198158264

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 1.5530, -4.9827, -0.4382],
        [-3.7981,  6.7057, -3.2744],
        [-4.0212,  7.4288, -3.7381],
        [-4.2902,  4.9925, -0.3913],
        [ 1.1948, -5.2180,  0.5545],
        [-1.3779, -4.6530,  5.0440],
        [ 1.5839, -5.1539, -0.2041],
        [-2.1657, -4.8075,  6.1485],
        [-2.9472,  4.6001, -2.4658],
        [-4.3513,  8.1186, -3.8158],
        [-0.6223, -4.7108,  3.7999],
        [ 1.1779, -5.2133,  0.6282],
        [-2.3675, -3.0709,  4.9529],
        [-0.9632, -5.0318,  4.8613],
        [-4.6747,  8.7186, -3.8669],
        [-3.9479,  7.1386, -3.4540],
        [-3.8467,  7.1352, -3.6216],
        [-4.0905,  6.3914, -2.0511],
        [-4.8679,  8.1777, -2.8689],
        [-4.4478,  8.7441, -4.4580],
        [-2.8098, -2.6967,  5.2606],
        [-4.8182,  5.3460,  0.0212],
        [-2.8418, -4.3066,  6.4021],
        [-3.0278,  5.6781, -3.6762]

Epoch number 8
 Current loss 0.02190900407731533

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.1796, -4.1604,  5.5606],
        [-1.9256, -4.6754,  5.7228],
        [-3.8927,  6.9797, -3.4345],
        [-4.0482,  6.9943, -2.9821],
        [-2.6963, -0.6281,  3.1086],
        [-2.6969, -4.1274,  6.1877],
        [ 1.5530, -5.6563,  0.3624],
        [-1.2730, -4.4069,  4.6939],
        [-3.8300,  6.3210, -2.5474],
        [-3.9931,  6.1587, -2.0772],
        [-1.6481, -4.5934,  5.4149],
        [-4.5445,  8.8517, -4.4337],
        [-4.4501,  8.6445, -4.3513],
        [-4.1767,  7.9734, -4.0397],
        [-3.5646,  4.9972, -1.5061],
        [-2.5689, -2.9117,  4.9285],
        [-4.3542,  8.5534, -4.3567],
        [-4.2204,  8.4594, -4.5724],
        [-3.8795,  6.6507, -2.9978],
        [-2.5136, -4.8325,  6.5281],
        [-4.9204,  8.1557, -2.8283],
        [-3.1758,  4.8274, -1.5540],
        [ 0.9369, -3.3674, -1.0135],
        [-4.2058,  7.9719, -3.9108]

Epoch number 8
 Current loss 0.05804586783051491

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.4955,  8.8768, -4.5162],
        [-1.0856, -4.8926,  4.8263],
        [-4.5638,  9.0991, -4.6649],
        [-2.5603, -4.1101,  5.9130],
        [-4.1748,  7.8268, -3.7538],
        [-4.5138,  8.9801, -4.6185],
        [-3.5294,  5.9312, -2.7645],
        [-3.7874,  7.4170, -4.0036],
        [-1.7003, -4.1774,  5.0594],
        [-1.5168,  2.5418, -3.2921],
        [-3.1529, -0.7379,  3.6043],
        [-3.1408,  0.8881,  2.2221],
        [-4.6716,  8.5629, -3.7585],
        [-4.9332,  8.7901, -3.5832],
        [-1.0487, -3.9969,  3.7595],
        [-2.4236, -4.0033,  5.6971],
        [-3.0535, -2.6530,  5.3822],
        [-4.6914,  9.1381, -4.4614],
        [-4.7828,  9.2538, -4.3798],
        [-1.3079, -1.2240,  1.0832],
        [-3.5867,  5.6946, -2.3189],
        [-4.9362,  8.1133, -2.7326],
        [-1.5805, -3.9157,  4.5494],
        [-3.6368,  6.1856, -2.9596]

Epoch number 8
 Current loss 0.02816324308514595

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.9104,  5.3200, -1.0245],
        [-3.0777,  5.2136, -2.9217],
        [-3.4141,  6.2864, -3.6670],
        [-3.4987, -1.4391,  4.6411],
        [-3.6810,  6.3919, -2.9262],
        [-3.2320,  4.9682, -1.7217],
        [-1.6210, -1.6511,  2.1159],
        [-3.1964,  4.2137, -1.3162],
        [-4.4673,  8.2851, -3.7827],
        [-3.8692,  6.2683, -2.5444],
        [-4.5464,  7.8840, -3.1465],
        [-4.1292,  6.6863, -2.2530],
        [-2.5812, -4.3193,  6.1346],
        [-4.8350,  7.7567, -2.4773],
        [-4.4114,  7.5105, -2.9515],
        [-3.4895,  2.3082,  1.2583],
        [-4.6566,  8.8912, -4.1253],
        [-3.7147,  6.3636, -2.8512],
        [-4.5479,  8.6111, -4.0061],
        [-3.6079,  6.7007, -3.7218],
        [ 0.3955, -5.7304,  2.8727],
        [-4.6049,  8.8428, -4.2329],
        [-3.9993,  6.5441, -2.4077],
        [-4.0731,  6.8843, -2.7796]

Epoch number 8
 Current loss 0.07354708760976791

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.3247,  6.0916, -3.5731],
        [-4.3775,  8.4209, -4.1674],
        [-2.0783, -4.3299,  5.6240],
        [-2.3209, -3.7612,  5.3969],
        [-4.4609,  5.0639, -0.2263],
        [-3.3165,  3.8133, -0.4338],
        [-4.0494,  7.5277, -3.6207],
        [-3.2222,  5.7092, -3.2657],
        [-3.0021,  4.2704, -1.1535],
        [-4.7262,  7.0534, -2.0301],
        [ 1.6697, -5.4624,  0.0292],
        [-3.9869,  6.1942, -1.9809],
        [-4.7109,  8.6096, -3.6915],
        [-4.6662,  8.9524, -4.2314],
        [-0.2712, -2.8300,  0.9054],
        [-4.1459,  6.2404, -1.8263],
        [-4.5472,  9.0906, -4.7296],
        [-4.4533,  8.5402, -4.1807],
        [-4.5363,  8.9550, -4.5196],
        [-3.3379,  5.7946, -2.8167],
        [-2.1147, -1.6755,  3.1860],
        [-4.3871,  8.6157, -4.3865],
        [-4.7114,  9.1099, -4.4454],
        [-1.4992, -4.7479,  5.2735]

Epoch number 8
 Current loss 0.04622642695903778

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.4628,  8.0739, -3.5867],
        [-1.6251, -4.8726,  5.5392],
        [-3.9208,  7.8003, -4.2875],
        [ 1.4276, -4.5534, -0.5795],
        [-1.8521, -4.3521,  5.3308],
        [-4.5578,  9.1269, -4.6440],
        [-3.9789,  7.4603, -3.6341],
        [-4.6210,  9.1259, -4.6234],
        [-4.7525,  9.2300, -4.4384],
        [-2.0077, -4.9813,  6.1039],
        [-3.4094,  5.9795, -3.2812],
        [-2.1711, -4.7016,  5.9786],
        [-3.6890,  5.9354, -2.5374],
        [-2.7987, -3.2835,  5.5483],
        [ 1.6674, -5.3611, -0.1497],
        [-4.6042,  8.6451, -3.9681],
        [-4.6732,  9.2698, -4.5997],
        [-3.6462,  5.8053, -2.3113],
        [-4.2530,  8.1299, -4.2127],
        [-3.1567,  6.0198, -3.8359],
        [-1.9165, -3.6765,  4.9434],
        [-3.0328, -1.4373,  4.1589],
        [-4.3673,  6.5882, -1.8847],
        [-2.5157,  4.8723, -3.6529]

Epoch number 8
 Current loss 0.06736165285110474

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.4134, -3.9280,  4.3859],
        [-2.3879,  4.6774, -3.8573],
        [ 1.5063, -5.1833, -0.2431],
        [-2.1421, -3.7908,  5.2112],
        [-2.7294, -2.5930,  4.8402],
        [-4.1605,  7.8143, -3.9667],
        [-4.3969,  6.9018, -2.1184],
        [-3.7009,  4.9637, -1.2108],
        [-4.6559,  9.2904, -4.7238],
        [-2.2446, -4.7198,  6.0814],
        [-4.4913,  8.4270, -4.0001],
        [-4.4337,  8.4985, -4.1520],
        [-1.4269, -4.3419,  4.8028],
        [-2.2261, -4.2995,  5.7391],
        [ 1.4809, -5.5537,  0.0666],
        [-1.3551, -4.8109,  5.1399],
        [-4.2838,  5.7928, -1.2887],
        [-2.6253,  5.1324, -3.9242],
        [ 1.5299, -5.6080,  0.2147],
        [-3.2197,  5.2472, -2.4969],
        [-4.5118,  8.1384, -3.3554],
        [-4.5806,  9.0477, -4.5052],
        [-2.1732, -4.3483,  5.7104],
        [-2.3000, -4.2723,  5.7882]

Epoch number 8
 Current loss 0.051785923540592194

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.4594,  8.5380, -4.1174],
        [-4.3614,  8.1534, -3.8849],
        [-4.6376,  8.2838, -3.4195],
        [-1.1469, -5.0141,  5.0935],
        [-3.8714,  6.7832, -3.0632],
        [-1.2742, -2.4673,  2.5032],
        [-4.6022,  9.1409, -4.6528],
        [-4.9784,  8.3123, -2.9212],
        [ 1.4818, -4.5309, -0.5884],
        [-4.3286,  8.6196, -4.4531],
        [-2.2848, -4.2230,  5.6794],
        [-3.3136,  6.3565, -4.2019],
        [-2.5848,  3.8203, -1.4319],
        [-4.2995,  7.9467, -3.6726],
        [-4.3938,  8.5338, -4.3395],
        [-4.3484,  5.7588, -1.1046],
        [-4.2206,  4.5816,  0.2463],
        [-4.5208,  8.9223, -4.5138],
        [-4.0383,  6.3923, -2.2842],
        [-0.7669, -4.3731,  3.5510],
        [-4.0130,  6.7299, -2.8146],
        [ 1.4786, -4.5125, -1.3235],
        [-4.5434,  8.8281, -4.3344],
        [-4.5155,  9.0507, -4.6192

Epoch number 8
 Current loss 0.03502919152379036

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.5213,  9.0993, -4.7805],
        [-1.7771, -2.8926,  3.8886],
        [-3.9032,  6.3161, -2.4834],
        [-3.8757,  7.6325, -4.1890],
        [ 1.6954, -5.6617, -0.1278],
        [-3.9916,  6.6802, -2.6763],
        [ 0.0590, -5.4748,  3.4753],
        [-2.1433,  2.2762, -1.2571],
        [-2.3015,  4.5559, -3.9904],
        [-4.8802,  7.4729, -2.0250],
        [-2.3991, -4.5343,  6.0577],
        [-3.6224,  5.8463, -2.3543],
        [-2.2654, -4.5664,  5.9686],
        [-2.2349, -4.9742,  6.3613],
        [-4.7399,  4.9844,  0.2697],
        [-4.5160,  8.7419, -4.2590],
        [-4.4401,  8.5949, -4.2061],
        [-3.3251,  4.2887, -0.7106],
        [-4.5867,  9.1813, -4.7371],
        [-2.6881,  5.3876, -4.1888],
        [-4.6172,  8.2763, -3.5124],
        [-4.3303,  8.3399, -4.0911],
        [-3.6407,  5.5245, -1.9744],
        [-4.7540,  8.3215, -3.2438]

Epoch number 8
 Current loss 0.08117201179265976

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.9094,  5.8057, -1.9859],
        [-3.6883,  3.7619, -0.0151],
        [-4.5729,  9.1969, -4.7857],
        [ 1.5886, -4.9882, -0.6842],
        [-2.1334, -3.4656,  5.0294],
        [-4.5611,  8.7495, -4.2056],
        [-4.5142,  8.6505, -4.1565],
        [-3.6345,  5.9014, -2.4926],
        [-3.8973,  6.0814, -2.2418],
        [-3.1063,  5.9499, -3.8847],
        [-1.7910, -3.7808,  4.7571],
        [-4.4634,  8.7159, -4.3192],
        [-3.7869,  7.2940, -3.9020],
        [-4.0394,  5.9906, -1.8420],
        [-3.9699,  7.8054, -4.1291],
        [-4.7367,  9.3233, -4.5323],
        [-4.3844,  7.3873, -2.8575],
        [-3.5880,  6.2223, -3.1084],
        [-3.7246,  7.1091, -3.5412],
        [-4.5307,  9.1020, -4.7461],
        [-3.2010, -3.5742,  6.2060],
        [-2.2452, -4.6780,  6.0495],
        [-3.7848,  5.7437, -1.9798],
        [-4.3091,  6.7949, -2.1460]

Epoch number 8
 Current loss 0.06484261155128479

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.3027, -4.5873,  5.9686],
        [-4.5106,  8.2078, -3.5889],
        [-3.1815,  0.1792,  3.0066],
        [-4.3499,  5.0918, -0.4909],
        [-4.0347,  6.9599, -3.0847],
        [-3.9618,  6.5466, -2.4489],
        [-4.7397,  8.5736, -3.6043],
        [-4.4499,  8.4997, -4.1075],
        [-0.8945, -4.6099,  4.2652],
        [-4.4197,  8.5228, -4.1288],
        [-2.6407, -1.7053,  3.8733],
        [ 1.6708, -5.2499, -0.2635],
        [ 1.5150, -5.6471,  0.2033],
        [-4.3717,  8.7644, -4.6532],
        [-4.7335,  6.7541, -1.4631],
        [-4.0808,  6.3588, -2.1752],
        [-4.1577,  6.3210, -1.9440],
        [-1.9397, -2.5567,  3.8669],
        [-2.0934, -1.8365,  3.2734],
        [-1.5885, -4.7275,  5.4441],
        [-4.3974,  8.7049, -4.4663],
        [-3.6271,  6.7226, -3.6487],
        [-4.5350,  9.1639, -4.7658],
        [-3.3290, -2.4906,  5.4121]

Epoch number 8
 Current loss 0.024617329239845276

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.3873,  6.5986, -4.1028],
        [-4.6549,  9.2959, -4.6737],
        [-2.3768, -3.8842,  5.6218],
        [-4.0098,  7.6646, -4.0385],
        [ 1.5755, -4.7587, -0.6553],
        [-2.4092, -4.4116,  5.9744],
        [-1.6900, -3.6557,  4.6175],
        [-4.4218,  8.8328, -4.5788],
        [-2.6013, -4.6820,  6.4572],
        [-4.8108,  9.4613, -4.6273],
        [-1.9810, -4.8161,  5.9037],
        [-4.1836,  8.3851, -4.4134],
        [-2.0899, -4.3953,  5.6598],
        [-2.6927, -3.2802,  5.3505],
        [ 1.5859, -4.8920, -1.0514],
        [ 1.4705, -4.3434, -0.9493],
        [-4.3414,  8.4696, -4.1858],
        [-1.8595, -4.2976,  5.3365],
        [-4.5878,  8.8406, -4.2731],
        [-4.5176,  8.9145, -4.4694],
        [-4.3000,  8.8042, -4.7449],
        [-2.8764, -3.9188,  6.1165],
        [-3.3971,  6.3411, -3.6722],
        [ 1.6178, -5.0080, -0.6344

Epoch number 8
 Current loss 0.03241492062807083

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.5406,  9.0642, -4.6127],
        [-4.0521,  7.0670, -3.2508],
        [-4.5421,  9.0973, -4.7210],
        [-2.8713, -3.4602,  5.7950],
        [-4.6732,  8.3860, -3.5138],
        [-5.1502,  7.7064, -1.9818],
        [-4.4639,  9.0099, -4.7822],
        [-3.8266,  6.9958, -3.5893],
        [-1.5322, -4.0362,  4.7442],
        [ 1.2571, -5.2112,  0.5249],
        [-4.1801,  8.1700, -4.2203],
        [-4.4171,  8.6571, -4.3101],
        [-1.3698, -2.4451,  2.3889],
        [-3.6764,  6.6999, -3.5524],
        [-4.4122,  8.7143, -4.4254],
        [-3.8785,  6.6664, -2.9104],
        [-4.6475,  9.1351, -4.5573],
        [-4.5419,  9.2436, -4.8986],
        [-4.4559,  8.5820, -4.1400],
        [-2.6401, -4.6841,  6.4870],
        [-2.5996, -3.6939,  5.6461],
        [-3.2466, -3.0091,  5.8341],
        [-1.6630, -5.0975,  5.8150],
        [-4.4111,  8.1337, -3.8674]

Epoch number 8
 Current loss 0.09109018743038177

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 1.6347, -5.2055,  0.0395],
        [-3.5962,  6.7024, -3.6972],
        [-3.7047,  6.6029, -3.1646],
        [-3.0540, -3.2679,  5.7536],
        [-2.4138, -4.1395,  5.7310],
        [-4.1544,  8.1052, -4.3223],
        [-4.0266,  6.2131, -2.2365],
        [ 1.5077, -4.7693, -0.3088],
        [-0.9620, -2.3108,  2.0006],
        [-2.8210,  5.3531, -3.5508],
        [-4.6550,  9.3580, -4.7529],
        [-4.4100,  8.5668, -4.2075],
        [-3.3637, -1.8309,  4.9148],
        [-3.9739,  6.8399, -2.7696],
        [-4.4171,  6.0509, -1.3094],
        [-2.0291, -4.1876,  5.3835],
        [-4.7984,  8.2713, -3.1594],
        [-4.2378,  8.6162, -4.5774],
        [-4.9418,  8.5066, -3.1482],
        [-4.3686,  8.6471, -4.4196],
        [-3.8000,  6.4869, -2.7351],
        [-2.9986, -3.9198,  6.2384],
        [-2.5197, -3.7895,  5.6515],
        [-4.5894,  8.6144, -3.9690]

Epoch number 8
 Current loss 0.049022089689970016

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.2220,  0.6374,  2.4453],
        [-4.5092,  9.0557, -4.6226],
        [-4.1222,  8.4452, -4.7161],
        [-3.3445,  6.3223, -3.8673],
        [-2.4950, -3.9214,  5.6586],
        [-4.3438,  6.2287, -1.6717],
        [-4.8840,  8.9601, -3.8772],
        [-0.0913, -3.3550,  0.9467],
        [-4.3449,  6.9494, -2.2895],
        [-4.9952,  7.5527, -2.1289],
        [-3.9589,  5.6247, -1.4443],
        [-3.2769, -0.0675,  3.3398],
        [-1.7283, -4.0877,  4.8230],
        [-1.1040, -2.0804,  1.6682],
        [-4.6146,  8.5735, -3.8270],
        [-4.3298,  8.5583, -4.4288],
        [-4.4294,  8.1806, -3.6750],
        [-4.4549,  8.8553, -4.4664],
        [-2.9989, -2.9467,  5.4843],
        [-0.5177, -4.3826,  3.4547],
        [ 1.0590, -5.1638,  0.8756],
        [-4.7283,  8.6687, -3.7241],
        [-4.1610,  7.9862, -3.8725],
        [-0.3400, -3.3923,  1.4526

Epoch number 8
 Current loss 0.030643176287412643

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.7909,  6.3117, -2.6416],
        [-3.5064,  6.5129, -3.7242],
        [-2.5342, -4.1867,  5.9343],
        [-3.9267,  6.7826, -3.1605],
        [-3.6713,  6.6461, -3.3959],
        [-3.8582,  7.1572, -3.7736],
        [ 1.5533, -4.8523, -0.9543],
        [-3.3022,  3.6937, -0.7519],
        [-3.5330,  6.6195, -3.6702],
        [-3.4090,  6.1937, -3.3758],
        [-1.3759, -5.0543,  5.4281],
        [-4.5838,  9.2085, -4.7492],
        [-4.4715,  8.5091, -4.1044],
        [-3.6454, -1.7227,  5.0437],
        [-3.0192,  5.5989, -3.3947],
        [-4.4462,  7.2509, -2.5109],
        [ 1.5972, -5.4454, -0.2399],
        [-4.6917,  9.2326, -4.5100],
        [-1.6198, -5.3554,  5.9652],
        [-3.2063,  4.5290, -1.7784],
        [ 0.9231, -2.4504, -2.1915],
        [-4.1926,  8.5035, -4.6414],
        [-3.1575,  5.0768, -2.5239],
        [-2.4331,  3.6458, -2.5508

Epoch number 8
 Current loss 0.04878493398427963

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.8576, -1.8223,  2.6811],
        [-3.6869,  6.8764, -3.8959],
        [-2.7750, -2.6937,  4.8396],
        [-4.1681,  7.1747, -3.0214],
        [-3.3044, -3.1933,  5.9673],
        [-4.2921,  5.0421, -0.5872],
        [-3.3598,  6.3313, -3.8274],
        [-4.0196,  7.4675, -3.9181],
        [-4.5424,  9.1916, -4.7600],
        [-3.7241,  5.0849, -1.4276],
        [-4.8677,  8.6859, -3.5118],
        [-3.4530,  6.6659, -3.9156],
        [-3.8068,  6.4146, -2.7598],
        [ 1.0524, -5.2964,  0.9690],
        [-3.8252,  6.4851, -2.6926],
        [-4.0414,  7.6959, -3.7394],
        [-4.1378,  7.8281, -3.7877],
        [-4.4531,  8.6949, -4.3461],
        [-3.4502,  5.7321, -2.6872],
        [-2.8380, -3.3378,  5.6282],
        [-2.8907, -0.4298,  2.9418],
        [-4.7880,  4.2610,  0.9442],
        [-4.5975,  9.1580, -4.6298],
        [-3.9392,  7.3777, -3.9112]

Epoch number 8
 Current loss 0.022482391446828842

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.9647, -3.6126,  4.9325],
        [-4.1928,  8.3116, -4.4241],
        [-1.8397, -2.1816,  3.2671],
        [-4.5543,  9.0424, -4.5866],
        [-4.8778,  4.0003,  1.2822],
        [ 1.5025, -4.7828, -0.3489],
        [-4.2542,  8.1168, -3.9407],
        [-3.6530,  5.7895, -2.2246],
        [-3.6848,  0.1061,  3.6026],
        [-1.4231, -4.5116,  4.9137],
        [-4.0733,  7.5859, -3.6879],
        [-1.5712, -5.1803,  5.7860],
        [-4.0774,  6.6158, -2.3635],
        [-3.4698,  5.3919, -1.4684],
        [-1.2125, -4.4713,  4.5148],
        [-1.8399, -5.3352,  6.2099],
        [-3.4175, -3.3988,  6.2720],
        [-2.1139, -4.4269,  5.7362],
        [-2.7765,  3.1776, -1.0832],
        [-4.6333,  9.2341, -4.5830],
        [-3.6213,  6.8654, -3.9455],
        [-4.0841,  6.4061, -2.1458],
        [-3.5918,  6.6776, -3.7750],
        [-4.6000,  8.6411, -3.9981

Epoch number 8
 Current loss 0.023654110729694366

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.6817,  8.3809, -3.5136],
        [-3.0454, -3.5973,  6.0781],
        [-4.4265,  7.7858, -3.2151],
        [-4.2682,  8.5729, -4.6441],
        [ 1.2718, -5.6352,  0.7052],
        [-4.0636,  6.0143, -1.8424],
        [-1.4206, -4.5775,  4.9129],
        [-4.6197,  9.1724, -4.5874],
        [-4.4470,  5.2233, -0.4429],
        [-4.7754,  7.4402, -2.1777],
        [-2.6398, -4.5260,  6.3541],
        [-3.8458,  6.8404, -3.3845],
        [-0.9047, -5.1781,  4.6989],
        [-4.3643,  8.6998, -4.6327],
        [-3.5944,  6.5946, -3.5208],
        [-3.8501,  7.0618, -3.6622],
        [-3.8685,  7.2339, -4.0039],
        [ 1.0004, -5.5683,  1.8120],
        [-4.1644,  8.5478, -4.8953],
        [-2.6726, -2.3159,  4.4769],
        [-4.1712,  8.6674, -4.9126],
        [-4.4459,  8.3398, -3.8294],
        [-4.7148,  9.1698, -4.4541],
        [-4.5373,  8.8149, -4.3289

Epoch number 8
 Current loss 0.08507373929023743

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.5618,  8.7262, -4.1549],
        [-4.3227,  8.0688, -3.7021],
        [-4.6782,  7.4665, -2.4241],
        [-4.8119,  8.1142, -2.9191],
        [-3.5018,  6.8161, -3.8017],
        [-0.2715, -5.1298,  3.5807],
        [-4.3119,  7.9719, -3.8849],
        [-4.2866,  8.5416, -4.4065],
        [-4.0084,  7.2597, -3.2739],
        [-4.5114,  7.3408, -2.5216],
        [-4.0428,  6.7655, -2.6665],
        [-3.8367,  6.7197, -3.0502],
        [-4.0003,  7.7241, -4.2205],
        [ 1.6164, -5.3583, -0.0818],
        [-4.6890,  9.0493, -4.3336],
        [-2.0237, -2.7871,  4.1712],
        [ 1.4857, -4.8305, -0.6140],
        [-3.7896,  7.8562, -4.8957],
        [-3.6014,  7.1500, -4.2595],
        [-4.5893,  9.1798, -4.7030],
        [-4.1930,  7.4487, -3.2929],
        [-4.2345,  8.3377, -4.3730],
        [-4.0646,  7.7620, -3.9546],
        [ 1.5974, -5.2482, -0.1506]

Epoch number 8
 Current loss 0.024626296013593674

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.4467,  8.4095, -3.9426],
        [-3.8189,  7.2032, -4.0476],
        [-1.6006, -4.6562,  5.3016],
        [-4.8305,  8.0121, -2.8755],
        [-0.8748, -2.9753,  2.2779],
        [-3.8879,  6.5701, -2.6564],
        [-2.7171, -1.7668,  4.1296],
        [-3.9576,  7.7719, -4.5366],
        [-4.3016,  8.2937, -4.1449],
        [-3.4780,  6.5904, -3.6596],
        [-3.3622, -1.5983,  4.5529],
        [-3.6379,  6.4897, -3.3064],
        [-1.7778, -3.2155,  4.1471],
        [-4.3245,  7.4035, -2.8591],
        [-3.8746,  7.2403, -3.9866],
        [-4.6011,  9.3388, -4.8357],
        [-3.9366,  6.9242, -3.0541],
        [-4.3812,  8.8280, -4.5793],
        [ 1.5380, -5.0184, -0.6897],
        [-3.9894,  8.0332, -4.4947],
        [-3.5779,  6.4517, -3.3293],
        [-4.0501,  8.0931, -4.5053],
        [-4.5695,  9.2510, -4.8108],
        [-4.4361,  7.1102, -2.4596

Epoch number 8
 Current loss 0.05002906173467636

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 1.5347, -5.1082, -0.1956],
        [-0.8576, -4.1830,  4.1281],
        [-4.2621,  8.1649, -4.0845],
        [-4.3098,  8.7905, -4.7054],
        [-3.7196,  6.4063, -2.8654],
        [-4.6342,  9.2500, -4.6683],
        [-0.7796, -4.2741,  3.4469],
        [-3.4525,  0.2771,  3.1933],
        [-4.0952,  8.1564, -4.2701],
        [-4.6548,  7.4494, -2.4690],
        [-4.3902,  8.6283, -4.4101],
        [-1.7160, -3.3790,  4.5426],
        [-3.8843,  7.7313, -4.3011],
        [-2.1185, -4.8904,  6.1295],
        [-0.1345, -3.4823,  1.2117],
        [-4.4503,  9.0316, -4.8830],
        [-2.5117, -2.9776,  4.7478],
        [-3.6522,  5.0329, -1.0600],
        [-4.0618,  7.7536, -4.2235],
        [-3.6518,  7.1080, -4.1327],
        [-3.3012,  5.1179, -2.3727],
        [-4.2792,  8.9561, -4.9812],
        [-4.3169,  7.5138, -3.0645],
        [-2.8415, -3.3418,  5.5839]

Epoch number 8
 Current loss 0.08896563202142715

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.6020,  4.2944, -0.5602],
        [-5.3743,  6.8885, -0.9178],
        [-1.8558, -2.8080,  3.7486],
        [-3.9877,  7.0136, -3.3013],
        [-4.5580,  8.9262, -4.4428],
        [-4.5328,  9.2696, -4.8821],
        [ 1.0170, -3.1193, -1.8197],
        [-4.0550,  7.5310, -3.5494],
        [-4.6028,  5.3155, -0.1841],
        [ 1.3435, -5.4904,  0.4507],
        [-4.5873,  8.7649, -4.1950],
        [-3.0130,  4.6098, -1.5662],
        [-2.8010, -3.2145,  5.3755],
        [-4.4557,  8.3294, -3.8671],
        [-4.6887,  9.1627, -4.4531],
        [-4.3677,  7.8380, -3.4614],
        [-3.8426,  6.4868, -2.6918],
        [-3.5293,  6.6629, -3.7313],
        [-4.0961,  7.5432, -3.5072],
        [ 1.3601, -4.6160, -0.2490],
        [-4.6259,  7.6526, -2.6595],
        [-3.2326,  5.9632, -3.5142],
        [-4.7430,  8.8007, -3.9425],
        [-4.4189,  7.4701, -2.8276]

Epoch number 8
 Current loss 0.02437664568424225

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6899, -4.2324,  6.1396],
        [-2.9349, -4.2719,  6.4289],
        [-3.7160, -1.0052,  4.5085],
        [-4.0295,  7.5649, -3.8311],
        [-4.2736,  7.0628, -2.5182],
        [-4.3366,  5.2094, -0.5346],
        [-4.5105,  9.1621, -4.9256],
        [-4.1992,  8.5244, -4.7795],
        [-4.1292,  7.2224, -3.2276],
        [ 1.4246, -4.6889, -0.5000],
        [-2.2182, -4.1504,  5.5065],
        [-2.9086,  4.0119, -1.6443],
        [-1.7263, -4.0170,  4.8767],
        [-2.6195,  4.8829, -3.2422],
        [-4.3100,  7.5199, -3.2122],
        [-1.1611,  1.9263, -2.2613],
        [-4.4120,  8.7769, -4.5316],
        [-3.4994,  6.4291, -3.5213],
        [-4.6257,  8.9271, -4.3139],
        [-4.5245,  6.8975, -2.1218],
        [-3.9785,  7.7363, -4.3421],
        [-3.3562, -3.6738,  6.4304],
        [-3.6596,  5.1051, -1.4688],
        [-3.6290,  6.5266, -3.3611]

Epoch number 8
 Current loss 0.06790433824062347

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.9248,  6.8403, -2.9648],
        [-2.3466, -2.3778,  4.2020],
        [-4.2358,  6.1787, -1.8458],
        [-4.0540,  3.5058,  0.5774],
        [-2.7527, -3.7587,  5.8002],
        [-3.7135,  6.4923, -3.1550],
        [-4.8244,  7.7346, -2.5819],
        [-2.3064,  4.6260, -3.8719],
        [-2.2186, -3.7053,  5.1892],
        [ 1.3876, -4.9544, -0.3044],
        [-3.9957,  7.9536, -4.4498],
        [-4.9616,  9.1947, -4.0194],
        [-4.4208,  8.6171, -4.3821],
        [-3.9402,  6.5220, -2.6486],
        [-4.0032,  7.4233, -3.7615],
        [-3.2170,  5.7739, -3.1382],
        [-4.2233,  8.6348, -4.9179],
        [-4.3029,  6.8998, -2.3331],
        [-3.3865, -2.5891,  5.5205],
        [-3.7759,  5.1481, -1.4194],
        [-4.1784,  8.1997, -4.5577],
        [-2.9332,  5.3396, -3.5011],
        [-2.3827, -4.3955,  5.8985],
        [-4.3914,  6.3960, -1.7100]

Epoch number 8
 Current loss 0.025987472385168076

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.1248,  6.5052, -2.1656],
        [-3.0500,  5.4892, -3.0963],
        [-4.2770,  8.4835, -4.5578],
        [-4.9028,  8.5971, -3.4309],
        [-1.8294, -2.3113,  3.2223],
        [-4.3923,  8.9384, -4.7598],
        [-2.8255,  5.6735, -4.0930],
        [-3.9948,  8.0914, -4.7225],
        [-3.6753, -2.6973,  5.9978],
        [-3.2944,  5.9431, -3.3255],
        [-3.8355,  5.9014, -1.8132],
        [-3.5713,  6.4583, -3.3726],
        [-4.3536,  8.7308, -4.7204],
        [-4.4190,  4.5351,  0.2468],
        [-4.1744,  7.8841, -3.9067],
        [-4.0601,  8.1290, -4.5272],
        [-2.9876, -3.3748,  5.7664],
        [-1.9874, -4.4205,  5.5593],
        [-4.5391,  7.0643, -2.2383],
        [-4.0976,  8.0842, -4.5019],
        [-4.2385,  8.2931, -4.4092],
        [-3.6442,  5.7477, -2.2064],
        [-1.6329, -4.4470,  5.0806],
        [-4.5119,  7.7398, -3.0597

Epoch number 8
 Current loss 0.04658462107181549

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.4834,  6.5532, -4.0508],
        [-3.9350,  6.7503, -2.8126],
        [ 0.5739, -5.0290,  2.0329],
        [ 1.5156, -5.1385, -0.4735],
        [-3.3593, -1.7771,  4.7011],
        [-4.6884,  8.3353, -3.5656],
        [ 1.1374, -5.2881,  0.5867],
        [ 1.0506, -5.4625,  1.0213],
        [-3.4768,  6.4081, -3.5903],
        [-4.3562,  6.4956, -1.8860],
        [-3.9043,  6.6578, -2.8326],
        [-4.2716,  8.7369, -4.8790],
        [-2.4578, -2.4745,  4.3501],
        [-4.1295,  7.9775, -4.0844],
        [-4.5283,  7.9945, -3.3595],
        [-4.8020,  9.1796, -4.3342],
        [-3.8143,  4.8873, -0.7003],
        [-2.9491,  1.8896,  0.5747],
        [-3.7958,  6.5086, -3.0921],
        [-4.0170,  7.1011, -3.3186],
        [-1.7737, -4.6427,  5.5629],
        [-2.5289, -3.5675,  5.3968],
        [-2.2531, -4.3372,  5.7554],
        [-4.8910,  8.6488, -3.5050]

Epoch number 8
 Current loss 0.05944881588220596

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.6564,  6.9260, -3.9036],
        [-4.6810,  6.8313, -1.7464],
        [ 1.0214, -5.1244,  0.7991],
        [-1.5750, -3.9266,  4.7424],
        [-1.4719, -4.3559,  4.9046],
        [-3.7557, -0.2640,  3.8103],
        [-2.7135, -2.9421,  5.1097],
        [-3.8108,  7.1556, -3.8918],
        [-4.8329,  7.7538, -2.5383],
        [-4.1900,  8.2612, -4.5208],
        [-2.4094, -3.4789,  5.2650],
        [-4.5324,  8.8873, -4.4761],
        [-4.2752,  8.6091, -4.6513],
        [-4.5442,  9.2247, -4.9805],
        [-3.4847, -0.8284,  4.2269],
        [-4.6889,  7.9152, -3.0294],
        [-2.4846, -4.2998,  5.9263],
        [-4.1595,  8.2009, -4.4347],
        [-4.0176,  7.3607, -3.9079],
        [-3.2033,  5.1280, -1.9927],
        [-3.3765, -1.8994,  4.8612],
        [-1.2346, -4.5173,  4.5527],
        [-4.5074,  8.8369, -4.3819],
        [ 0.8799, -5.1323,  1.1135]

Epoch number 9
 Current loss 0.024902742356061935

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.3056,  8.5862, -4.4389],
        [-4.7925,  8.3425, -3.2339],
        [-4.1701,  8.5438, -4.8511],
        [-4.4354,  8.7401, -4.4917],
        [ 1.5688, -5.3171,  0.0790],
        [-4.6630,  9.1922, -4.6196],
        [-2.5582, -4.2680,  5.9932],
        [-4.1437,  8.1897, -4.5388],
        [-4.6208,  9.2819, -4.8645],
        [-4.1268,  6.9575, -2.7520],
        [-4.5422,  9.2943, -4.9998],
        [-4.4678,  9.0080, -4.6915],
        [-4.0079,  6.8673, -3.0092],
        [ 1.6305, -5.5302, -0.3913],
        [-2.9479,  5.6434, -3.8320],
        [-4.6497,  9.1578, -4.6412],
        [-1.5882, -4.5212,  5.2370],
        [-4.6973,  7.7656, -2.7777],
        [-4.9493,  9.3703, -4.3037],
        [-5.0904,  6.8975, -1.3035],
        [-4.6138,  9.2752, -4.8117],
        [-1.6675, -1.4138,  1.6574],
        [-5.6176,  7.7198, -1.4896],
        [ 1.5469, -5.1192, -0.1880

Epoch number 9
 Current loss 0.05379084125161171

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.9018, -2.2336,  5.6947],
        [-1.9203,  3.3719, -2.3271],
        [-3.4777, -2.0279,  5.1047],
        [-3.2010, -3.0011,  5.5793],
        [-1.7405, -1.9478,  2.6211],
        [-1.7492, -4.8892,  5.6820],
        [-1.3757, -4.6812,  4.9332],
        [-3.0831, -3.4653,  5.8798],
        [-2.3187, -3.4040,  5.0785],
        [-0.5395, -4.5533,  3.5104],
        [-3.0244, -2.2921,  4.8466],
        [-4.6360,  5.0765,  0.1115],
        [ 1.3062, -5.6822,  0.5116],
        [-4.5677,  5.6890, -0.9156],
        [-3.8330,  7.6968, -4.5606],
        [-2.5470, -3.6603,  5.4469],
        [-3.4136,  6.2421, -3.3594],
        [-4.6992,  9.2731, -4.7420],
        [-4.5185,  8.8716, -4.4831],
        [-3.0084, -4.1904,  6.3767],
        [-4.3681,  8.4260, -4.3937],
        [-4.4553,  6.3585, -1.6680],
        [-5.2301,  8.9412, -3.2583],
        [-3.1235,  5.6729, -3.3072]

Epoch number 9
 Current loss 0.022037049755454063

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.1814,  7.1847, -2.9250],
        [-3.4699,  6.4895, -3.7239],
        [-4.6620,  8.8435, -4.2191],
        [-4.9576,  8.9505, -3.6625],
        [-3.8208,  6.5080, -2.9083],
        [ 1.4153, -4.5452, -0.6516],
        [-4.5082,  6.9912, -2.1176],
        [-3.6062,  6.7723, -3.9936],
        [-4.5536,  8.7242, -4.3308],
        [-2.5247, -5.0276,  6.6673],
        [-4.2728,  8.4652, -4.4062],
        [-4.1935,  8.5426, -4.6565],
        [-3.7389,  6.4979, -3.2482],
        [ 1.3731, -4.3073, -1.1290],
        [-2.6039,  4.3882, -3.0443],
        [-4.6394,  9.0042, -4.4396],
        [-4.4099,  8.4225, -4.2120],
        [-2.3135,  4.4613, -2.9006],
        [-4.3688,  9.0170, -5.0444],
        [-4.5182,  8.9610, -4.6303],
        [-3.6566,  7.5310, -4.6438],
        [-4.7898,  6.5103, -1.4102],
        [-0.2701, -3.4735,  1.3363],
        [-4.9079,  6.1636, -0.8278

Epoch number 9
 Current loss 0.060688018798828125

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.0687,  4.3973, -4.1382],
        [-4.3041,  8.6562, -4.9149],
        [-4.0664,  7.4796, -3.5733],
        [-4.4937,  8.2868, -3.8462],
        [-0.8754, -4.4176,  3.7094],
        [-4.6605,  8.6076, -3.7978],
        [-1.7035, -3.5908,  4.2849],
        [-3.0636,  3.3167, -0.1796],
        [-2.2810, -4.6276,  6.0900],
        [-4.8082,  7.3940, -2.2962],
        [-3.4344,  6.4845, -3.8963],
        [-2.7111, -3.5338,  5.5462],
        [ 1.5437, -5.5358, -0.1490],
        [-2.4128, -4.3661,  5.8933],
        [-4.2246,  7.9090, -3.7657],
        [-1.1643, -4.5322,  4.5012],
        [-3.6141,  5.5989, -2.2847],
        [-1.9326, -4.4017,  5.5077],
        [-1.3089, -5.0767,  5.2766],
        [-3.5640,  6.6600, -3.8878],
        [-2.0016, -3.9798,  5.1774],
        [-3.6961,  5.2538, -1.8314],
        [-2.0774, -3.9898,  5.1903],
        [-4.6051,  9.2234, -4.8017

Epoch number 9
 Current loss 0.037912942469120026

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.6331,  6.9786, -2.1490],
        [-3.9865,  6.7647, -3.0085],
        [-4.6824,  9.4453, -4.9143],
        [ 0.3645, -4.4227,  1.6182],
        [-3.6571,  6.9899, -4.2151],
        [-1.0334, -3.5774,  3.2436],
        [ 1.6744, -5.3411, -0.7877],
        [-1.8430, -5.0720,  5.9295],
        [-3.2472,  3.4924, -0.5504],
        [-2.6038, -4.2556,  6.0959],
        [-2.8747,  3.2418, -0.9527],
        [-2.3905, -4.2890,  5.8039],
        [-5.1982,  8.3319, -2.6241],
        [-4.3697,  6.0852, -1.5348],
        [-3.0660,  5.9310, -4.2340],
        [-4.6576,  6.7109, -1.7444],
        [-2.3760, -2.7122,  4.4239],
        [-3.8303, -2.4835,  5.8254],
        [-2.5830, -3.4432,  5.4078],
        [-3.2426,  5.9910, -3.6030],
        [-1.6122, -5.0837,  5.6967],
        [-3.8621,  7.6675, -4.2284],
        [-4.4016,  8.2271, -3.9045],
        [-5.2180,  8.5321, -2.8858

Epoch number 9
 Current loss 0.05502694845199585

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6908, -4.7714,  6.6139],
        [-4.7897,  8.5504, -3.6532],
        [-2.5741, -3.2870,  5.2070],
        [-3.8827,  6.4685, -2.5640],
        [-4.0366,  7.8906, -4.3634],
        [-2.1315,  4.4118, -4.2244],
        [-3.7258,  6.8836, -3.6560],
        [-3.7332,  7.1956, -4.3808],
        [-3.5859,  6.0283, -2.7112],
        [-4.0022,  7.4289, -3.8481],
        [-4.2197,  6.3530, -1.9170],
        [-4.5640,  8.5648, -4.0075],
        [-2.8988, -3.8200,  5.9956],
        [-2.1016,  3.9780, -3.3737],
        [-4.0743,  4.3432, -0.1866],
        [ 1.7150, -5.4883, -0.6645],
        [-4.6051,  8.5504, -3.9187],
        [-3.9317,  6.8389, -3.0264],
        [-4.2429,  7.7474, -3.5898],
        [-1.1559, -4.8378,  4.6681],
        [-3.5886,  5.6532, -2.0791],
        [-4.7517,  6.5456, -1.3693],
        [-0.2084, -5.7632,  3.8109],
        [ 1.3253, -6.0503,  1.0244]

Epoch number 9
 Current loss 0.03759152814745903

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.6841, -4.4198,  6.2347],
        [ 1.6185, -5.2642, -0.2594],
        [-4.3264,  8.9333, -5.0421],
        [-3.6316,  6.9117, -4.2825],
        [-4.2947,  8.1107, -4.1949],
        [-2.9325,  5.7321, -4.1305],
        [ 1.6935, -5.4012, -0.1080],
        [-4.0537,  4.4910, -0.3794],
        [-3.8184,  4.8018, -0.7032],
        [-2.5507, -4.6489,  6.3544],
        [-0.9972, -4.0411,  3.8830],
        [-2.0567, -4.8755,  5.9566],
        [ 0.6649, -5.2163,  2.0404],
        [-4.6476,  6.1100, -1.0611],
        [-3.5268,  6.8203, -3.8864],
        [-3.9860,  6.8874, -3.1316],
        [-3.5625,  6.9576, -4.4233],
        [-2.1161, -5.0657,  6.2036],
        [-3.0853,  3.8357, -0.6178],
        [-3.6898,  6.9315, -3.8747],
        [-4.1622,  3.2576,  1.1820],
        [-3.9070,  6.6128, -2.7949],
        [-4.1021,  6.2802, -2.2054],
        [-2.8444,  2.5457, -0.4104]

Epoch number 9
 Current loss 0.02217882126569748

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.3818, -4.5138,  4.7042],
        [-4.8494,  8.9923, -4.1387],
        [-4.3236,  7.7260, -3.3156],
        [-4.7740,  8.4526, -3.5367],
        [-0.9996, -4.4184,  4.2416],
        [-3.8514,  6.5960, -3.1914],
        [-2.0982, -5.2163,  6.3402],
        [-2.2459, -4.8467,  6.1249],
        [-5.2168,  8.2776, -2.6577],
        [-3.7138,  6.5223, -3.3362],
        [-5.0171,  8.2603, -2.9315],
        [-4.7522,  9.3717, -4.7160],
        [-4.5936,  8.4911, -3.9246],
        [-3.6541,  6.5938, -3.5684],
        [-4.6159,  8.9730, -4.4323],
        [-3.9816,  6.5967, -2.7686],
        [-2.2573, -4.6530,  5.9757],
        [-2.3754, -4.1164,  5.6776],
        [-4.2568,  6.9734, -2.5921],
        [-4.1730,  6.9852, -2.8797],
        [-4.3112,  8.2361, -4.3613],
        [-3.9566,  7.9063, -4.7498],
        [-3.8070,  6.4448, -3.0432],
        [-4.5499,  9.0800, -4.6788]

Epoch number 9
 Current loss 0.050509013235569

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-1.8189, -5.0996,  6.0145],
        [-4.8689,  9.5326, -4.7074],
        [-3.8444, -2.1705,  5.5336],
        [-4.6838,  7.7703, -2.9153],
        [-3.4744,  5.7644, -2.5619],
        [-4.4940,  8.8529, -4.5676],
        [-4.1564,  3.5563,  0.7970],
        [-4.5165,  8.4566, -4.0011],
        [ 1.4254, -5.9006,  0.4195],
        [-4.2969,  8.8170, -5.1236],
        [-3.4484,  5.5124, -2.9201],
        [ 1.3378, -5.6614,  0.4520],
        [-4.9739,  9.3753, -4.2383],
        [-3.1825, -3.5039,  6.0597],
        [-4.7775,  9.4453, -4.8400],
        [-2.4967, -3.5778,  5.3948],
        [-4.0095,  8.0515, -4.6752],
        [-3.6769,  6.7423, -3.7922],
        [-4.0537,  7.8835, -4.4637],
        [-3.8795,  7.0209, -3.7381],
        [-2.1044, -2.7752,  4.2093],
        [-4.9397,  7.6049, -2.3566],
        [-4.7863,  9.2593, -4.4788],
        [-4.7132,  8.6766, -3.8531],


Epoch number 9
 Current loss 0.019392699003219604

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.6760,  9.1400, -4.6011],
        [ 1.2514, -5.5687,  0.8070],
        [-4.7455,  9.1729, -4.4772],
        [-4.4246,  8.6875, -4.5988],
        [-2.5899, -4.7981,  6.4309],
        [-1.6897, -3.7236,  4.6842],
        [-4.6399,  9.1783, -4.8072],
        [-4.0399,  6.9619, -3.1048],
        [-2.7187, -2.3660,  4.4277],
        [-4.9520,  9.0952, -3.9665],
        [-2.6423, -2.1986,  4.3229],
        [-2.7732,  5.2879, -3.5814],
        [-4.5409,  8.7802, -4.4181],
        [-3.9329,  6.8604, -3.2755],
        [-4.4397,  7.3455, -2.6494],
        [-2.8618, -2.4483,  4.8221],
        [ 1.3912, -4.0913, -1.4115],
        [-2.8974,  5.5970, -3.6340],
        [-4.4007,  8.2184, -4.0250],
        [-4.6945,  8.9367, -4.2926],
        [-4.0565,  6.3991, -2.5090],
        [-4.6932,  5.2427, -0.2030],
        [ 1.7475, -5.7170, -0.1379],
        [-1.6104, -4.6133,  5.2759

Epoch number 9
 Current loss 0.03656632453203201

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.8004,  9.0302, -4.2035],
        [-3.1333, -4.1017,  6.4758],
        [-4.7822,  8.0909, -3.1269],
        [ 1.7261, -5.4785, -0.1855],
        [-2.7124, -4.4366,  6.2699],
        [-4.5113,  8.6716, -4.3934],
        [-4.6698,  8.7762, -4.1408],
        [-2.2952, -3.6609,  5.1930],
        [-2.9462, -3.3103,  5.5633],
        [-4.5941,  8.4456, -3.7928],
        [-2.8154, -4.0785,  6.1090],
        [-4.3382,  7.8055, -3.6388],
        [-5.0256,  6.5896, -1.1474],
        [-3.6181, -2.2476,  5.3297],
        [-4.5758,  9.3737, -5.1188],
        [-3.0208, -1.4226,  4.1012],
        [ 1.2835, -3.6854, -1.8869],
        [-4.4495,  9.0796, -5.0294],
        [-4.9429,  8.2336, -3.0371],
        [-4.7096,  9.4427, -4.8973],
        [-3.7056,  6.5254, -3.1640],
        [ 1.3388, -3.7868, -1.9707],
        [-4.0031,  7.4313, -3.9199],
        [-4.6469,  7.7837, -3.0053]

Epoch number 9
 Current loss 0.0395321249961853

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.3351,  7.9880, -4.0170],
        [-4.1250,  7.9503, -4.4071],
        [-4.3469,  8.0958, -4.0474],
        [-3.3464, -3.1297,  5.8840],
        [-4.3064,  8.3390, -4.3733],
        [-4.2803,  7.1419, -2.7413],
        [-5.0501,  7.7984, -2.4365],
        [ 1.5189, -4.8035, -0.6150],
        [-4.6649,  9.3841, -4.9923],
        [-4.2613,  7.5520, -3.3768],
        [-2.0409, -4.4283,  5.5286],
        [-2.2573, -4.2776,  5.6584],
        [-3.6433,  5.5621, -2.2570],
        [-3.1855, -3.7097,  6.2490],
        [-4.5214,  8.8509, -4.5794],
        [-3.2987,  6.0964, -3.7882],
        [ 1.0880, -5.1577,  0.5698],
        [-4.3833,  8.5499, -4.5321],
        [-4.0738,  7.5076, -3.9548],
        [-4.5585,  8.7360, -4.3480],
        [-3.4787,  6.8741, -4.5857],
        [-3.7172,  6.1424, -2.8565],
        [-3.5406, -0.8158,  3.9465],
        [ 1.4256, -4.3341, -1.0128],

Epoch number 9
 Current loss 0.05923157557845116

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.4858,  9.1633, -4.9390],
        [-0.6320, -2.8193,  1.4108],
        [-4.0360,  7.2844, -3.4002],
        [-4.7938,  9.1364, -4.3002],
        [-4.7424,  9.4889, -4.9127],
        [-4.4525,  8.9350, -4.8206],
        [-4.5631,  7.9866, -3.3817],
        [-2.1037, -5.2201,  6.2872],
        [-4.5753,  7.1115, -2.2278],
        [-4.6463,  9.4527, -5.0832],
        [-3.9287,  6.1746, -2.3054],
        [-4.9066,  9.3587, -4.4191],
        [ 1.1123, -2.9355, -2.2068],
        [-4.6961,  9.2524, -4.7531],
        [-3.0668,  0.4302,  2.1022],
        [-2.3507, -2.8346,  4.4574],
        [-4.8803,  9.1586, -4.2647],
        [-1.9394, -5.1023,  5.9978],
        [-4.3658,  7.9447, -3.7352],
        [-1.9744, -3.2875,  4.4634],
        [-4.6078,  9.3424, -5.1365],
        [-4.3786,  7.7107, -3.5723],
        [-2.4967, -1.3538,  3.3450],
        [-0.8747, -4.1742,  3.6094]

Epoch number 9
 Current loss 0.048907309770584106

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.7535,  9.4439, -4.8431],
        [-4.5101,  8.1685, -3.6575],
        [-3.8712,  7.1559, -3.8111],
        [-3.6102,  6.8679, -3.8188],
        [-4.2803,  8.3959, -4.4452],
        [-4.0903,  7.5653, -4.0798],
        [-5.0661,  9.1437, -3.8277],
        [-3.9959,  5.2417, -1.0104],
        [-4.6079,  9.3839, -5.0227],
        [-3.3937,  6.3981, -3.9291],
        [-4.2284,  6.7332, -2.4509],
        [-4.8457,  9.2662, -4.4157],
        [-3.6914,  6.8214, -3.7639],
        [-4.1223,  7.3829, -3.6816],
        [-3.7818,  7.1683, -4.2020],
        [-3.9854,  6.7319, -2.8040],
        [-4.8555,  9.0655, -4.0819],
        [ 1.6483, -5.1768, -0.5570],
        [-4.6510,  9.1610, -4.6937],
        [-4.1165, -1.9558,  5.6603],
        [-4.1768,  7.3080, -3.3774],
        [ 1.6681, -5.3500, -0.5579],
        [-4.2677,  6.9799, -2.6599],
        [ 1.0985, -5.0052,  0.5493

Epoch number 9
 Current loss 0.07030482590198517

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.4968,  9.0665, -4.9565],
        [-4.7323,  7.2643, -2.1990],
        [-1.8600, -2.8964,  3.6385],
        [-4.7378,  9.1698, -4.5995],
        [-1.8064, -4.3648,  5.1791],
        [-4.1980,  7.2261, -3.1942],
        [-2.6781, -4.0443,  5.8735],
        [-3.7426, -3.0655,  6.2308],
        [-4.0177,  7.0724, -3.6324],
        [-2.4684, -4.4462,  5.9914],
        [-3.6945,  5.4488, -1.5273],
        [-3.1310, -4.0783,  6.4750],
        [-2.2543, -1.0053,  2.6403],
        [-4.6978,  9.4613, -4.9602],
        [-4.7485,  7.1865, -2.0945],
        [-4.4704,  8.9805, -4.8992],
        [-1.5895, -4.5912,  5.1298],
        [-4.7245,  9.3770, -4.8111],
        [ 1.4676, -5.8393,  0.4579],
        [-3.6884,  6.5941, -3.2458],
        [ 0.2141, -4.8091,  2.7166],
        [-3.1093, -3.9678,  6.3206],
        [-4.7266,  7.8601, -3.0221],
        [-0.2598, -5.4342,  3.6256]

Epoch number 9
 Current loss 0.026035809889435768

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.0126, -3.7413,  1.2549],
        [-4.9604,  7.3802, -2.2555],
        [-4.6529,  9.1103, -4.6610],
        [-4.6753,  8.8946, -4.2573],
        [-3.8500,  6.0625, -2.4975],
        [-4.1459,  8.1305, -4.4970],
        [-2.7435, -4.9709,  6.8161],
        [ 1.6513, -5.5263,  0.0401],
        [-4.1543,  6.7909, -2.5114],
        [-4.7576,  7.6062, -2.5124],
        [-1.8197, -2.1333,  2.9928],
        [-4.4302,  7.3592, -2.9571],
        [-4.9524,  9.0574, -4.0590],
        [-2.9862, -4.0240,  6.2531],
        [-3.8016,  6.9646, -4.0244],
        [-4.4008,  7.0180, -2.3615],
        [-2.6407, -4.4318,  6.2593],
        [-3.9601,  6.9414, -3.3970],
        [-4.9950,  9.7664, -4.8034],
        [-4.7530,  9.3842, -4.8695],
        [ 1.6061, -5.0002, -1.0959],
        [-3.3000,  5.9647, -3.5740],
        [ 0.5036, -5.7352,  2.3803],
        [ 1.5517, -4.8822, -0.9023

Epoch number 9
 Current loss 0.01857166737318039

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.9230,  7.3389, -2.0730],
        [-4.2470,  5.6013, -1.4165],
        [-2.9978, -4.2896,  6.4341],
        [ 0.6908, -5.7683,  2.1725],
        [-4.5209,  7.7163, -3.1499],
        [-2.8033, -4.3805,  6.2608],
        [-5.0867,  7.9783, -2.4369],
        [-4.4569,  7.6170, -3.1241],
        [-3.0134, -4.1785,  6.3831],
        [-2.9379, -5.1954,  7.2214],
        [-2.3896, -4.7635,  6.2272],
        [-1.6167, -5.1501,  5.6111],
        [-5.0188,  6.3124, -0.8474],
        [-4.8779,  9.4273, -4.5952],
        [-3.5241,  6.9806, -3.9750],
        [-4.4289,  4.7578, -0.1888],
        [ 0.7869, -3.3277, -0.7577],
        [-4.6376,  6.6807, -1.8245],
        [-1.8068, -5.3070,  6.0606],
        [-4.3419,  7.4887, -3.1752],
        [-4.3671,  6.4205, -1.8158],
        [-3.7896,  5.1429, -1.4245],
        [-3.3651, -3.7294,  6.3972],
        [-4.3522,  7.0544, -2.4490]

Epoch number 9
 Current loss 0.04412613436579704

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.9727,  9.5968, -4.6636],
        [-4.0757,  7.3341, -3.8288],
        [ 0.9443, -2.8355, -1.5868],
        [-4.6688,  8.9920, -4.5437],
        [-3.8417,  5.9337, -2.2350],
        [-2.7199, -4.5701,  6.3568],
        [-2.1578, -4.7879,  5.9623],
        [-4.1737,  6.3425, -2.0121],
        [-3.1755, -4.6296,  6.9573],
        [ 1.6324, -5.0265, -1.0567],
        [-4.7659,  8.5059, -3.6806],
        [-3.0383, -1.4663,  4.2028],
        [-5.2547,  8.1764, -2.4627],
        [ 1.6493, -5.6346,  0.2556],
        [-4.9548,  9.4114, -4.5027],
        [-2.9070,  3.9496, -1.1766],
        [-3.6829, -1.1569,  4.5610],
        [-3.6274,  6.1474, -2.9372],
        [-3.5338,  6.2851, -3.4791],
        [-2.9080, -4.7020,  6.7130],
        [ 0.2639, -5.9897,  3.1493],
        [-4.6573,  9.2825, -5.0332],
        [-4.7776,  8.8562, -4.0923],
        [-3.3242,  5.3551, -2.5420]

Epoch number 9
 Current loss 0.04099850356578827

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.7909, -3.2731,  6.4614],
        [-5.5545,  7.0267, -0.9091],
        [-4.9902,  5.2987,  0.2023],
        [-4.6667,  9.1140, -4.6155],
        [-4.8153,  9.0667, -4.3617],
        [-3.8692,  3.6921,  0.5132],
        [-4.8480,  7.4368, -2.3047],
        [-4.3663,  6.7025, -2.1081],
        [-4.9814,  7.2517, -1.8694],
        [ 1.1982, -6.0872,  1.3317],
        [-4.8078,  8.7902, -4.0869],
        [-4.9696,  5.9760, -0.6788],
        [-4.9048,  9.2780, -4.3895],
        [-5.0859,  9.4771, -4.3247],
        [-4.7277,  8.7102, -3.9643],
        [-3.6434,  5.2881, -2.0959],
        [-4.0076,  6.4786, -2.4968],
        [-4.5625,  8.4686, -4.2089],
        [-4.9833,  9.5302, -4.5963],
        [-4.1648,  6.3840, -2.1686],
        [-4.7662,  6.3665, -1.2753],
        [-4.7162,  8.6202, -3.9274],
        [-4.3727,  7.0322, -2.5083],
        [-0.6651, -5.6974,  4.5414]

Epoch number 9
 Current loss 0.05588392913341522

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.6128, -3.2443,  6.2640],
        [-3.6125, -3.5077,  6.4852],
        [-4.7070,  9.2370, -4.8093],
        [-4.0383,  6.8327, -2.9224],
        [-4.7645,  8.5039, -3.6827],
        [ 0.9191, -6.4524,  2.4982],
        [-4.5583,  7.7721, -3.3904],
        [-2.9340, -3.9915,  6.1054],
        [-3.9848,  7.0179, -3.4017],
        [-3.0664,  6.2464, -4.4026],
        [-4.3220,  7.7755, -3.5190],
        [-5.2383,  7.7533, -2.0427],
        [-4.6124,  9.0165, -4.5603],
        [ 1.5568, -5.1066, -0.3627],
        [-3.2368, -2.2162,  5.0414],
        [-5.1491,  7.6182, -2.0644],
        [-2.9258, -4.3989,  6.4513],
        [-4.6075,  9.2576, -5.1118],
        [-1.7342, -5.4009,  6.0304],
        [-3.0530,  5.0314, -2.3297],
        [-3.6979,  6.1469, -2.7909],
        [ 1.6297, -5.3349, -0.5804],
        [-2.4965, -3.4858,  5.3277],
        [-3.8800, -1.3424,  4.8893]

Epoch number 9
 Current loss 0.025740262120962143

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.2478,  7.8052, -4.1973],
        [-4.2409,  6.2619, -1.9175],
        [-0.4898, -5.7930,  4.6074],
        [-4.3823,  8.1612, -3.9288],
        [-4.2057,  5.0883, -0.5644],
        [-3.6754,  4.6087, -0.5796],
        [-4.8914,  8.0731, -3.0182],
        [-5.2609,  8.4635, -2.8028],
        [-3.1169, -3.9828,  6.3425],
        [-5.2833,  9.2033, -3.5207],
        [-3.4878,  5.9648, -3.0095],
        [-5.5913,  8.1829, -2.0613],
        [-4.6232,  8.9184, -4.6268],
        [-4.0721,  6.4399, -2.3395],
        [-3.0781,  5.8741, -3.8493],
        [-2.7300,  5.5440, -4.2085],
        [-4.8420,  9.3960, -4.6722],
        [ 1.9082, -6.3747, -0.0850],
        [-2.7919, -3.4464,  5.4930],
        [-4.4217,  8.3655, -4.5020],
        [-4.2720,  7.3595, -3.4205],
        [-3.3864, -1.1893,  4.3067],
        [-3.8030,  7.0490, -3.9489],
        [-4.1495,  8.1646, -4.8012

Epoch number 9
 Current loss 0.028156468644738197

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.6500,  4.6845, -1.2187],
        [-5.2649,  7.3705, -1.5454],
        [-4.6956,  4.4956,  0.6051],
        [-4.5581,  8.6700, -4.3366],
        [-4.7933,  9.4652, -4.8128],
        [-4.9328,  9.4561, -4.5374],
        [-1.8946, -4.5805,  5.4792],
        [-5.1454,  5.9914, -0.3572],
        [ 1.7925, -5.4011, -0.8425],
        [-3.2119,  6.2273, -4.1074],
        [-2.5203, -4.0635,  5.7275],
        [-4.4052,  6.6826, -2.3173],
        [ 1.1690, -6.0532,  1.3096],
        [-1.5651, -5.1668,  5.5801],
        [-4.8699,  7.4262, -2.3505],
        [-2.4803, -5.1622,  6.6255],
        [-2.7746, -2.8600,  5.0647],
        [-5.2343,  8.9425, -3.4045],
        [-4.7333,  6.8580, -1.9070],
        [-3.8324,  7.4363, -4.7041],
        [-4.9990,  9.1338, -4.0192],
        [-2.4920, -4.7214,  6.2623],
        [-5.0070,  8.9123, -3.7133],
        [-2.9667, -4.4758,  6.5710

Epoch number 9
 Current loss 0.025984138250350952

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.7661,  9.3429, -4.7748],
        [-5.1402,  8.8590, -3.5147],
        [ 1.6351, -5.1899, -0.4332],
        [-4.1274,  7.5525, -4.0326],
        [-5.0155,  8.1126, -2.8364],
        [-4.9459,  6.4460, -1.2862],
        [-5.1442,  8.3049, -2.7637],
        [-3.5914,  5.9861, -3.2937],
        [-3.6355,  5.4351, -2.1492],
        [-4.5278,  8.0336, -3.6915],
        [-4.7201,  8.9420, -4.3283],
        [ 0.1113, -5.1206,  2.6394],
        [-4.1776,  6.3390, -2.0808],
        [-4.5647,  6.9942, -2.3176],
        [ 0.8009, -1.9789, -3.2207],
        [-3.3124, -3.0954,  5.8288],
        [-3.9254,  7.5998, -4.6120],
        [-5.0751,  7.7403, -2.3143],
        [-4.8653,  8.0219, -2.9822],
        [-3.4864,  4.3210, -1.2055],
        [-4.8273,  9.5070, -4.8915],
        [-2.6436, -5.2003,  6.8658],
        [-4.3288,  8.9054, -5.2258],
        [-4.4222,  6.5911, -1.8043

Epoch number 9
 Current loss 0.047577425837516785

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.8903,  9.0099, -3.9998],
        [-4.9723,  8.6559, -3.4542],
        [-2.8933,  4.4372, -1.7922],
        [-4.6932,  8.9287, -4.3449],
        [-3.1053, -2.8799,  5.3650],
        [-5.0420,  7.4787, -1.9565],
        [-1.8710, -3.1587,  4.1831],
        [-2.7480, -2.9564,  5.0533],
        [-4.6031,  8.1586, -3.4798],
        [-4.5708,  8.3761, -3.8548],
        [-4.6812,  7.5938, -2.8822],
        [-2.5090, -4.5627,  6.0621],
        [-2.0504, -4.5876,  5.6524],
        [-3.0998, -4.2311,  6.4586],
        [-4.5283,  8.8159, -4.6732],
        [-5.1235,  7.5320, -2.0232],
        [-3.2656, -1.9164,  4.6993],
        [-2.4218, -3.9607,  5.5619],
        [-4.5914,  9.1895, -4.9048],
        [-3.2311, -3.5076,  6.0341],
        [-1.6341,  2.3574, -1.8842],
        [-2.1252,  3.6259, -3.2351],
        [-3.2794, -3.2734,  5.9327],
        [-4.2313,  7.2127, -3.0160

Epoch number 9
 Current loss 0.08004221320152283

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.6923,  8.5932, -3.9324],
        [-4.4774,  8.3802, -4.2963],
        [-2.4869, -3.0824,  4.8224],
        [-4.1917,  5.5584, -1.4404],
        [-4.1569, -1.4530,  5.3144],
        [-3.8756,  7.0228, -3.7322],
        [-2.8402, -4.8524,  6.7500],
        [ 0.7384, -2.1926, -1.9055],
        [-4.1640, -2.6489,  6.2947],
        [ 1.1020, -5.5703,  1.1345],
        [ 1.0984, -3.3277, -1.8260],
        [-5.5854,  8.6392, -2.5015],
        [-5.0332,  8.7209, -3.5294],
        [-5.1480,  6.7446, -1.2548],
        [-4.7330,  9.5227, -5.0541],
        [-5.0100,  9.2169, -4.1816],
        [-4.0401,  4.8471, -0.4140],
        [-3.8625, -1.6390,  5.0975],
        [ 1.5402, -6.2232,  0.4688],
        [-3.8802,  7.3575, -4.4251],
        [-4.7907,  9.4006, -4.7175],
        [-4.9032,  9.7442, -4.9882],
        [-4.8065,  9.3460, -4.7002],
        [ 1.7091, -5.3121, -0.7564]

Epoch number 9
 Current loss 0.012918601743876934

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.2483,  8.0128, -4.5564],
        [-4.7365,  9.5668, -5.0817],
        [ 1.5073, -4.3158, -1.2174],
        [-4.9787,  8.4888, -3.2597],
        [-3.7749,  5.4307, -1.8923],
        [-5.2555,  8.4590, -2.7764],
        [-3.6315,  6.8686, -4.0135],
        [ 1.5862, -4.7030, -1.1982],
        [-4.8510,  7.7974, -2.5973],
        [-4.0083,  7.7568, -4.5553],
        [-3.3198,  6.1969, -3.6551],
        [-4.7155,  9.1677, -4.6789],
        [-4.6380,  1.5369,  3.1707],
        [-4.4058,  4.9538, -0.1113],
        [ 1.3382, -5.5044,  0.2917],
        [-4.7180,  8.4528, -3.8225],
        [-4.4616,  7.0636, -2.7076],
        [-1.9192, -4.8325,  5.7245],
        [-1.5449, -3.0811,  3.6987],
        [-3.4665, -3.3337,  6.1151],
        [-4.8606,  8.6845, -3.8214],
        [-4.6772,  9.3032, -5.0234],
        [-1.7013, -4.6891,  5.3668],
        [-5.3687,  8.4222, -2.6477

Epoch number 9
 Current loss 0.02358645759522915

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-5.4958,  9.1611, -3.1837],
        [-4.3187,  8.4629, -4.6898],
        [-4.5483,  7.5823, -3.0141],
        [-3.3663,  6.4281, -4.3331],
        [-4.4060,  8.5732, -4.4621],
        [-2.1446, -3.3628,  4.7638],
        [-3.9894,  3.7259,  0.2358],
        [-3.7248,  6.5239, -3.2458],
        [-4.8357,  7.6707, -2.6273],
        [-3.8597,  7.4697, -4.6653],
        [-4.5654,  9.1916, -5.0335],
        [-5.2213,  8.6291, -3.0993],
        [-2.6101, -3.6728,  5.4468],
        [-2.3504, -3.3800,  5.0971],
        [-4.8508,  9.7801, -5.1107],
        [-2.5668, -4.1331,  5.8268],
        [-4.5245,  9.2223, -5.1489],
        [-1.2713, -5.0845,  5.0496],
        [-5.1503,  7.7629, -2.3583],
        [-3.9387,  6.7867, -3.1796],
        [-4.7382,  7.2488, -2.3533],
        [-4.7385,  7.3430, -2.5352],
        [-4.6450,  9.2648, -4.9782],
        [-4.7296,  7.1574, -2.3309]

Epoch number 9
 Current loss 0.04307808727025986

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.7141,  9.2058, -4.6905],
        [-4.5119,  9.1788, -5.1193],
        [-2.2093, -1.6526,  2.9521],
        [-4.2087, -3.0281,  6.6583],
        [-1.2003, -5.0179,  5.1321],
        [-3.6420,  5.5293, -2.5088],
        [-4.6818,  9.1489, -4.6995],
        [-4.3347,  8.1693, -4.4075],
        [-4.4889,  8.9213, -4.9192],
        [-2.1579, -5.3587,  6.4347],
        [-3.8287,  6.6393, -3.1617],
        [-3.9580,  7.3793, -3.8418],
        [-4.7404,  9.2046, -4.6497],
        [-2.0102, -4.1161,  5.3242],
        [-4.2576,  7.0187, -2.7050],
        [-3.2034, -2.6573,  5.1861],
        [-1.9632, -5.1173,  6.0689],
        [-3.8962,  5.4661, -1.4275],
        [-4.9937,  8.0556, -2.8555],
        [-4.4247,  8.1333, -4.0796],
        [-4.7627,  9.3972, -4.8541],
        [-3.9833,  6.5759, -2.8035],
        [-4.2566,  8.0615, -4.1879],
        [-2.7656, -4.4898,  6.3384]

Epoch number 9
 Current loss 0.0327397920191288

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.4155,  8.4398, -4.1748],
        [-4.3700,  8.2468, -4.4365],
        [-4.5108,  8.8355, -4.5497],
        [-1.4343, -3.6832,  3.5658],
        [-2.3501, -5.0652,  6.3983],
        [-2.5842, -5.0781,  6.6932],
        [-4.5730,  7.9430, -3.3655],
        [-5.5329,  5.9172,  0.0917],
        [-4.1006,  8.2105, -4.7729],
        [-3.9311,  7.3869, -4.3776],
        [-4.0055,  7.2781, -4.0230],
        [-4.3963,  8.5366, -4.3701],
        [-4.9024,  9.2601, -4.3573],
        [-4.0234,  7.0099, -3.4912],
        [-4.1076,  7.6787, -4.3094],
        [-3.0476, -4.2017,  6.4158],
        [-4.9053,  9.2315, -4.2804],
        [-2.4412, -5.5473,  6.9358],
        [-4.9550,  7.6977, -2.4165],
        [-4.9218,  9.6811, -4.8849],
        [-4.3858,  6.5234, -2.0681],
        [ 1.6224, -5.3358, -0.5418],
        [-2.5219, -3.1138,  4.8546],
        [-3.8240,  6.4666, -2.9537],

Epoch number 9
 Current loss 0.09978099167346954

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.5331,  9.0643, -4.9050],
        [-3.9227,  6.4837, -2.9530],
        [-2.5831, -4.5641,  6.1457],
        [-3.1553, -4.7852,  6.9677],
        [-4.3631,  7.1372, -2.6356],
        [-4.6961,  9.1875, -4.6915],
        [-5.1464,  8.7592, -3.3313],
        [-4.4728,  6.6767, -2.2778],
        [ 1.7351, -6.5765,  0.2659],
        [-1.3360, -5.0758,  5.2507],
        [-3.6686,  6.6527, -3.7431],
        [-4.5637,  8.3794, -3.8641],
        [-4.8867,  9.6660, -4.8970],
        [-5.2974,  4.6780,  1.1050],
        [-5.4570,  7.8179, -1.7915],
        [-1.1929, -4.6130,  4.7277],
        [-0.6062, -0.2771, -1.6790],
        [-4.6120,  7.9603, -3.3521],
        [-1.8181, -2.6222,  3.4035],
        [-2.7070, -4.8621,  6.5666],
        [-1.1830, -3.5141,  3.1602],
        [-3.7856,  6.9170, -3.7159],
        [-3.6358,  6.1510, -2.9618],
        [-4.4275,  5.5211, -0.9503]

Epoch number 9
 Current loss 0.05836720019578934

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.9527,  9.4780, -4.6102],
        [-5.1009,  8.6919, -3.3020],
        [-4.8669,  9.6330, -4.8814],
        [-3.6726,  6.8848, -3.7739],
        [-4.1323,  6.8993, -3.0284],
        [-3.2952, -4.7445,  7.0775],
        [-4.7654,  9.5456, -5.0815],
        [-2.0578, -3.3276,  4.6452],
        [-4.9397,  7.6736, -2.3544],
        [-1.2176, -4.8544,  4.9019],
        [-4.0912,  4.4361, -0.0715],
        [-4.7017,  8.7578, -4.1768],
        [-5.1207,  7.6859, -2.0769],
        [-4.3608,  7.5937, -3.6542],
        [-2.3927, -3.4376,  5.0389],
        [-4.9007,  9.4243, -4.6404],
        [-2.2771, -4.2578,  5.6392],
        [-3.9447,  7.2704, -4.0564],
        [-3.4419,  4.3847, -1.2535],
        [-1.8629, -3.9341,  4.8684],
        [-0.9156, -3.1456,  2.3444],
        [ 0.4181, -6.1066,  3.0260],
        [-4.5404,  9.1918, -5.0117],
        [-4.7546,  9.3534, -4.7577]

Epoch number 9
 Current loss 0.04634445905685425

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.5180,  6.7192, -3.8947],
        [-4.8336,  9.6265, -4.9186],
        [-2.7485,  2.9077, -1.3475],
        [-3.6054,  6.4683, -3.4466],
        [ 1.8509, -6.2787,  0.1657],
        [-2.9942,  5.3808, -3.4269],
        [-4.5110,  7.4000, -2.7760],
        [-3.8965,  6.3330, -2.6141],
        [-2.9339, -3.6585,  5.8701],
        [-4.6562,  9.4267, -5.0115],
        [-3.4220, -4.8692,  7.3450],
        [-4.7690,  8.8769, -4.0102],
        [-4.3710,  7.5722, -3.4324],
        [-2.8885, -4.0929,  6.1959],
        [-3.9248,  7.1964, -4.0972],
        [-3.5049,  4.4921, -1.3658],
        [-4.7410,  8.9067, -4.1111],
        [-4.5380,  9.1566, -4.9442],
        [-4.1447,  7.3652, -3.5486],
        [-2.3025, -5.0611,  6.3242],
        [ 1.7466, -5.4650, -0.6679],
        [-0.6762, -4.7167,  3.6298],
        [-0.4252, -6.1710,  4.5589],
        [-4.8525,  8.4218, -3.2602]

Epoch number 9
 Current loss 0.03259538486599922

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.5959, -5.5951,  7.1132],
        [-1.8183, -5.1139,  5.8888],
        [-4.7891,  6.0039, -0.7516],
        [-2.9264,  3.7467, -1.7476],
        [-4.1389,  7.1034, -3.0546],
        [-1.3865, -4.9169,  5.0660],
        [-3.8572,  6.3838, -2.7187],
        [-4.6180,  9.2569, -4.8338],
        [-4.5921,  9.3420, -4.9768],
        [ 0.3314, -6.1497,  3.1730],
        [-4.0974,  6.1578, -2.1221],
        [-4.4494,  9.1733, -5.2067],
        [-1.6137,  3.6375, -4.7539],
        [-3.7897,  5.4747, -1.7982],
        [-2.5093, -4.5385,  6.0191],
        [ 1.2945, -6.3041,  1.2819],
        [-3.0951,  5.6201, -3.6652],
        [-4.1790,  6.5718, -2.2952],
        [-4.2080,  8.0551, -4.6248],
        [ 0.0435, -6.2481,  4.0360],
        [-4.6497,  8.7453, -4.2414],
        [-4.2817,  8.2445, -4.2266],
        [-3.3843,  5.3206, -1.9090],
        [-1.1879, -5.6196,  5.6405]

Epoch number 9
 Current loss 0.04169566556811333

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-3.5267,  6.8867, -4.4072],
        [-4.8546,  9.2846, -4.3542],
        [-4.1634,  8.3928, -4.7810],
        [-4.1735,  8.4196, -4.9805],
        [-4.6683,  9.3067, -4.7688],
        [-2.0146, -5.4298,  6.3554],
        [ 1.2615, -3.9427, -1.2923],
        [-3.6375,  5.8131, -2.4928],
        [-1.7222, -4.5720,  5.2311],
        [-4.7956,  8.2393, -3.2055],
        [-4.5732,  8.8853, -4.4234],
        [-1.8752, -5.5727,  6.3160],
        [-4.8020,  8.5052, -3.5030],
        [-3.7175,  6.5256, -3.2041],
        [-2.2352,  1.0725, -0.3186],
        [-3.2714, -3.2709,  5.8480],
        [-2.8548,  4.2645, -1.6142],
        [-1.5531, -1.5065,  1.5451],
        [-1.5303, -4.4229,  4.9706],
        [-4.1311,  6.4663, -2.2997],
        [-4.8004,  9.3631, -4.6334],
        [ 1.8358, -5.8037, -0.2408],
        [-4.5452,  9.1715, -4.9387],
        [ 1.5451, -4.5978, -0.3171]

Epoch number 9
 Current loss 0.028904687613248825

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-4.9036,  8.3168, -3.1136],
        [-3.5371,  6.1593, -3.3338],
        [-4.4251,  8.2880, -4.1777],
        [-4.6453,  7.1470, -2.1109],
        [-2.0120, -5.5508,  6.5359],
        [ 1.9213, -6.0407, -0.4769],
        [-4.5149,  8.8830, -4.4936],
        [-4.4397,  7.4079, -2.8360],
        [ 1.4311, -4.1426, -0.6478],
        [-1.9994, -4.8522,  5.8224],
        [-3.9678,  6.0321, -2.0971],
        [-4.7102,  8.6833, -3.9292],
        [-3.9294,  8.2540, -5.2053],
        [-4.0069,  8.2012, -4.7327],
        [-1.7823, -5.5877,  6.2779],
        [-1.3129, -1.9868,  1.6787],
        [-3.2394,  4.2845, -1.0366],
        [-4.4781,  9.0008, -4.7482],
        [-4.5825,  9.3634, -5.0892],
        [-2.3720,  5.1888, -4.6696],
        [-3.4601,  6.0752, -2.9156],
        [ 0.0190, -3.2943,  0.5066],
        [-4.7450,  9.3880, -4.7395],
        [-4.4117,  7.6682, -3.0790

Epoch number 9
 Current loss 0.07294070720672607

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 1.6023, -5.1140, -0.3433],
        [-4.6720,  9.6953, -5.3182],
        [-2.4624, -5.5343,  6.9302],
        [-3.7064,  8.0658, -5.4340],
        [ 1.6778, -4.9829, -0.9538],
        [-4.5188,  7.7303, -3.0626],
        [-4.7053,  9.2297, -4.6115],
        [-3.4810,  6.2787, -3.7313],
        [-3.4298,  1.6231,  1.6220],
        [-1.5059, -2.8890,  3.0308],
        [-3.8267,  7.5439, -4.7147],
        [-3.4590,  6.3225, -3.7187],
        [-3.3026,  6.1321, -3.7805],
        [ 0.0255, -2.9246,  0.8005],
        [-3.6315,  5.8348, -2.8598],
        [-2.6635, -3.0646,  5.0404],
        [-3.3892,  4.3741, -1.5631],
        [-4.7125,  7.0603, -2.1588],
        [-4.8749,  8.1445, -3.0604],
        [-4.0096,  8.3324, -4.8708],
        [-4.4485,  6.5397, -1.7680],
        [ 1.3743, -5.8800,  0.4137],
        [-4.5192,  9.4231, -5.2301],
        [-2.5635,  4.7739, -3.9968]

Epoch number 9
 Current loss 0.02271922677755356

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.5050, -5.1444,  6.6104],
        [-2.4470, -4.4205,  5.9158],
        [-0.8282,  1.7753, -3.0061],
        [ 1.4286, -4.2614, -1.0778],
        [-3.1483, -4.5264,  6.7390],
        [-3.5031,  6.0323, -3.1477],
        [-1.8684,  1.3534, -1.1213],
        [ 0.2382, -5.5744,  3.2907],
        [-3.5169,  6.1015, -3.1358],
        [-4.3848,  8.8102, -4.8373],
        [-1.9545, -4.9184,  5.8389],
        [-3.7751,  6.5872, -3.1987],
        [-3.5336,  7.4258, -5.0711],
        [-4.5470,  8.6419, -4.0716],
        [-1.9793, -5.5572,  6.4611],
        [-4.3417,  8.4792, -4.4275],
        [-4.9150,  7.6303, -2.2604],
        [-4.2784,  9.0799, -5.4561],
        [-4.9746,  8.7038, -3.5009],
        [-1.6852, -3.4256,  4.2960],
        [-2.3294, -5.3534,  6.5840],
        [-3.3977, -0.1510,  3.3213],
        [-4.1461,  7.0857, -3.0333],
        [-3.2672,  5.9836, -3.6242]

Epoch number 9
 Current loss 0.03261885792016983

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[-2.4826, -3.0298,  4.7787],
        [-5.0900,  8.7666, -3.4534],
        [-1.9219, -4.9454,  5.8581],
        [-5.6777,  8.6493, -2.3478],
        [-4.7897,  9.2169, -4.3921],
        [-4.0053,  6.4868, -2.6547],
        [-1.4510, -4.5082,  4.8007],
        [-4.1244,  6.6400, -2.4914],
        [-2.6689, -2.7763,  4.8155],
        [-4.0666,  8.3923, -5.1507],
        [-4.7658,  9.7216, -5.1526],
        [-2.3652, -5.1312,  6.4191],
        [-4.3616,  9.0820, -5.4454],
        [-1.9100, -5.3127,  6.1482],
        [-2.3227, -4.9466,  6.2701],
        [-4.5823,  8.3437, -3.7442],
        [-0.7355, -4.5737,  3.5716],
        [-5.5653,  8.4186, -2.2933],
        [-4.2298,  6.0594, -1.8925],
        [-1.8012, -4.9350,  5.7447],
        [ 1.6842, -5.5634,  0.1905],
        [-1.2628, -2.4964,  2.3202],
        [-3.8457,  5.6563, -2.0891],
        [-2.8983,  5.3519, -3.7260]

Epoch number 9
 Current loss 0.04039338231086731

inputs are
torch.Size([100, 128])
torch.Size([100, 128])
OUTPUT
tensor([[ 0.1538, -3.9674,  0.9364],
        [-3.8854,  6.0092, -2.4569],
        [-3.1324, -4.6526,  6.8262],
        [-4.7796,  9.7089, -5.1449],
        [-4.3849,  5.9286, -1.2634],
        [-4.7024,  9.3269, -4.8398],
        [-2.3852, -2.6508,  4.3870],
        [-1.6027, -4.9865,  5.4800],
        [-4.7068,  9.7252, -5.3470],
        [-4.7510,  9.1453, -4.5074],
        [-2.2108, -1.3189,  2.7937],
        [-4.8878,  9.3748, -4.4009],
        [-4.4990,  9.2118, -5.0611],
        [-5.1635,  8.1896, -2.5666],
        [ 1.1347, -6.1184,  1.4470],
        [-4.5199,  7.5564, -2.9183],
        [-4.6131,  8.6691, -4.2045],
        [-4.2515,  8.3381, -4.3790],
        [-4.7379,  8.5769, -3.8357],
        [-4.6386,  7.8482, -2.9940],
        [ 1.2721, -6.4567,  1.5102],
        [-3.2464,  6.0898, -3.9479],
        [-0.6661,  2.5262, -4.6240],
        [ 1.4591, -4.5495, -1.0760]

Epoch number 9
 Current loss 0.04700294882059097



In [20]:
def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    #round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds)).float()
    print(rounded_preds.shape)
    correct = (rounded_preds == y).sum().item() #convert into float for division
    onf_matrix = confusion_matrix(rounded_preds, y)
    print(onf_matrix)
    acc = correct/y.shape[0]
    print(acc)
    return acc

In [82]:
def classification_accuracy(preds, y):
    _, predicted = torch.max(preds.data, 1)
    num_correct = (predicted == y).sum().item()
    print("num correct")
    print(num_correct)
    print(y.shape[0])
    conf_matrix = confusion_matrix(predicted, y)
    print(conf_matrix)
    return num_correct / y.shape[0]

In [80]:
def evaluate(model, iterator):
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    batch_num = 0
    accuracies = []
    i = 0
    with torch.no_grad():
        for i, data in enumerate(iterator,0):
            img0, img1 , label = data
            _, batch_size = img0.shape
            model.hidden1 = model.init_hidden(batch_size)
            model.hidden2 = model.init_hidden(batch_size)
            output = model(img0,img1)
            #accuracy = binary_accuracy(output.squeeze(), label.float())
            accuracy = classification_accuracy(output.squeeze(), label.long())
            accuracies.append(accuracy)
            i+=1
            print("current accuracy is")
            print(accuracy)
            if i % 20 == 0:
                print("cumulative accuracy")
                print(np.sum(accuracies) / i)
        print("FINAL")
        print(np.sum(accuracies) / i)

In [81]:
#most error comes from when it's actually 1 but we predict 0 (related but we predict unrelated)
evaluate(model, val_batch_it)


num correct
124
128
current accuracy is
0.96875
num correct
117
128
current accuracy is
0.9140625
num correct
115
128
current accuracy is
0.8984375
num correct
124
128
current accuracy is
0.96875
num correct
118
128
current accuracy is
0.921875
num correct
123
128
current accuracy is
0.9609375
num correct
113
128
current accuracy is
0.8828125
num correct
114
128
current accuracy is
0.890625
num correct
115
128
current accuracy is
0.8984375
num correct
116
128
current accuracy is
0.90625
num correct
120
128
current accuracy is
0.9375
num correct
125
128
current accuracy is
0.9765625
num correct
118
128
current accuracy is
0.921875
num correct
122
128
current accuracy is
0.953125
num correct
120
128
current accuracy is
0.9375
num correct
116
128
current accuracy is
0.90625
num correct
121
128
current accuracy is
0.9453125
num correct
119
128
current accuracy is
0.9296875
num correct
120
128
current accuracy is
0.9375
num correct
119
128
current accuracy is
0.9296875
cumulative accuracy
0

In [None]:

# 0.8380412635833554 without 
# 0.8366860339915186 with extra layersa?!?!
# 0.8412300390935594 with only looking at 100 words
# 0.9252252849191624 with 100 and actual article lol, 10 epochs

# stance detection 0.9340277777777778