In [2]:
# LSTM 
# explanation:  http://colah.github.io/posts/2015-08-Understanding-LSTMs/

In [1]:
import os
import time
import numpy as np
from tqdm import tqdm
from string import punctuation
from collections import Counter
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(123)

<torch._C.Generator at 0x7f9be4ba91f0>

In [2]:
import random
from torchtext import (data, datasets)

In [3]:
TEXT_FIELD = data.Field(tokenize = data.get_tokenizer("basic_english"), include_lengths = True)
LABEL_FIELD = data.LabelField(dtype = torch.float)

train_dataset, test_dataset = datasets.IMDB.splits(TEXT_FIELD, LABEL_FIELD)
train_dataset, valid_dataset = train_dataset.split(random_state = random.seed(123))

In [4]:
MAX_VOCABULARY_SIZE = 25000

TEXT_FIELD.build_vocab(train_dataset, 
                 max_size = MAX_VOCABULARY_SIZE)

LABEL_FIELD.build_vocab(train_dataset)

In [5]:
B_SIZE = 64 # batch size

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_data_iterator, valid_data_iterator, test_data_iterator = data.BucketIterator.splits(
    (train_dataset, valid_dataset, test_dataset), 
    batch_size = B_SIZE,
    sort_within_batch = True,
    device = device)

In [27]:
## If you are training using GPUs, we need to use the following function for the pack_padded_sequence method to work 
## (reference : https://discuss.pytorch.org/t/error-with-lengths-in-pack-padded-sequence/35517/3)
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence

def cuda_pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=False):
    # length becoms a tensor. ref: https://pytorch.org/docs/stable/generated/torch.as_tensor.html
    lengths = torch.as_tensor(lengths, dtype=torch.int64)
    lengths = lengths.cpu()
    
    if enforce_sorted:
        sorted_indices = None
    else:
        # sorted_indices means original position at after sorted tensor.
        # ref: https://pytorch.org/docs/stable/generated/torch.sort.html
        lengths, sorted_indices = torch.sort(lengths, descending=True)
        sorted_indices = sorted_indices.to(input.device)
        
    batch_dim = 0 if batch_first else 1
    input = input.index_select(batch_dim, sorted_indices)

    data, batch_sizes = \
    torch._C._VariableFunctions._pack_padded_sequence(input, lengths, batch_first)
    return PackedSequence(data, batch_sizes, sorted_indices)

In [28]:
class LSTM(nn.Module):
    def __init__(self, vocabulary_size, embedding_dimension, hidden_dimension, output_dimension, dropout, pad_index):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocabulary_size, embedding_dimension, padding_idx = pad_index)
        # num_layers should be > 1. otherwise, it cause below warning
        # UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1
        self.lstm_layer = nn.LSTM(embedding_dimension, 
                           hidden_dimension, 
                           num_layers=2, 
                           bidirectional=True, 
                           dropout=dropout)
        self.fc_layer = nn.Linear(hidden_dimension * 2, output_dimension)
        self.dropout_layer = nn.Dropout(dropout)
        
    def forward(self, sequence, sequence_lengths=None):
        if sequence_lengths is None:
            sequence_lengths = torch.LongTensor([len(sequence)])
        
        # sequence := (sequence_length, batch_size)
        embedded_output = self.dropout_layer(self.embedding_layer(sequence))
        
        
        # embedded_output := (sequence_length, batch_size, embedding_dimension)
        if torch.cuda.is_available():
            packed_embedded_output = cuda_pack_padded_sequence(embedded_output, sequence_lengths)
        else:
            packed_embedded_output = nn.utils.rnn.pack_padded_sequence(embedded_output, sequence_lengths)
        
        packed_output, (hidden_state, cell_state) = self.lstm_layer(packed_embedded_output)
        # hidden_state := (num_layers * num_directions, batch_size, hidden_dimension)
        # num_directions = 2 if bidirectional LSTM.
        # cell_state := (num_layers * num_directions, batch_size, hidden_dimension)
        
        op, op_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)
        # op := (sequence_length, batch_size, hidden_dimension * num_directions)
        
        hidden_output = torch.cat((hidden_state[-2,:,:], hidden_state[-1,:,:]), dim = 1)        
        # hidden_output := (batch_size, hidden_dimension * num_directions)
        
        return self.fc_layer(hidden_output)

    
INPUT_DIMENSION = len(TEXT_FIELD.vocab)
EMBEDDING_DIMENSION = 100
HIDDEN_DIMENSION = 32
OUTPUT_DIMENSION = 1
DROPOUT = 0.5
PAD_INDEX = TEXT_FIELD.vocab.stoi[TEXT_FIELD.pad_token]

lstm_model = LSTM(INPUT_DIMENSION, 
            EMBEDDING_DIMENSION, 
            HIDDEN_DIMENSION, 
            OUTPUT_DIMENSION, 
            DROPOUT, 
            PAD_INDEX)

In [29]:
UNK_INDEX = TEXT_FIELD.vocab.stoi[TEXT_FIELD.unk_token] # unk means unknown

lstm_model.embedding_layer.weight.data[UNK_INDEX] = torch.zeros(EMBEDDING_DIMENSION)
lstm_model.embedding_layer.weight.data[PAD_INDEX] = torch.zeros(EMBEDDING_DIMENSION)

In [30]:
optim = torch.optim.Adam(lstm_model.parameters())
loss_func = nn.BCEWithLogitsLoss() # binary cross entropy.

lstm_model = lstm_model.to(device)
loss_func = loss_func.to(device)

In [31]:
def accuracy_metric(predictions, ground_truth):
    """
    Returns 0-1 accuracy for the given set of predictions and ground truth
    """
    # round predictions to either 0 or 1
    rounded_predictions = torch.round(torch.sigmoid(predictions))
    success = (rounded_predictions == ground_truth).float() #convert into float for division 
    accuracy = success.sum() / len(success)
    return accuracy

In [63]:
def train(model, data_iterator, optim, loss_func):
    loss = 0
    accuracy = 0
    model.train()
    
    for curr_batch in data_iterator:
        optim.zero_grad()
        sequence, sequence_lengths = curr_batch.text
        preds = lstm_model(sequence, sequence_lengths).squeeze(1)
        
        loss_curr = loss_func(preds, curr_batch.label)
        accuracy_curr = accuracy_metric(preds, curr_batch.label)
        
        loss_curr.backward()
        optim.step()
        
        loss += loss_curr.item()
        accuracy += accuracy_curr.item()
        
    return loss/len(data_iterator), accuracy/len(data_iterator)

In [64]:
def validate(model, data_iterator, loss_func):
    loss = 0
    accuracy = 0
    model.eval()
    
    with torch.no_grad():
        for curr_batch in data_iterator:
            sequence, sequence_lengths = curr_batch.text
            preds = model(sequence, sequence_lengths).squeeze(1)
            
            loss_curr = loss_func(preds, curr_batch.label)
            accuracy_curr = accuracy_metric(preds, curr_batch.label)

            loss += loss_curr.item()
            accuracy += accuracy_curr.item()
        
    return loss/len(data_iterator), accuracy/len(data_iterator)

In [34]:
num_epochs = 10
best_validation_loss = float('inf')

for ep in range(num_epochs):

    time_start = time.time()
    
    training_loss, train_accuracy = train(lstm_model, train_data_iterator, optim, loss_func)
    validation_loss, validation_accuracy = validate(lstm_model, valid_data_iterator, loss_func)
    
    time_end = time.time()
    time_delta = time_end - time_start 
    
    if validation_loss < best_validation_loss:
        best_validation_loss = validation_loss
        torch.save(lstm_model.state_dict(), 'lstm_model.pt')
    
    print(f'epoch number: {ep+1} | time elapsed: {time_delta}s')
    print(f'training loss: {training_loss:.3f} | training accuracy: {train_accuracy*100:.2f}%')
    print(f'validation loss: {validation_loss:.3f} |  validation accuracy: {validation_accuracy*100:.2f}%')
    print()

epoch number: 1 | time elapsed: 7.216773271560669s
training loss: 0.683 | training accuracy: 55.51%
validation loss: 0.632 |  validation accuracy: 65.00%

epoch number: 2 | time elapsed: 7.319506645202637s
training loss: 0.607 | training accuracy: 67.18%
validation loss: 0.513 |  validation accuracy: 74.96%

epoch number: 3 | time elapsed: 7.346320629119873s
training loss: 0.538 | training accuracy: 73.13%
validation loss: 0.475 |  validation accuracy: 77.69%

epoch number: 4 | time elapsed: 7.020611524581909s
training loss: 0.504 | training accuracy: 75.76%
validation loss: 0.426 |  validation accuracy: 80.49%

epoch number: 5 | time elapsed: 7.159215688705444s
training loss: 0.450 | training accuracy: 79.33%
validation loss: 0.382 |  validation accuracy: 84.11%

epoch number: 6 | time elapsed: 7.411106586456299s
training loss: 0.477 | training accuracy: 77.29%
validation loss: 0.558 |  validation accuracy: 71.92%

epoch number: 7 | time elapsed: 7.131444215774536s
training loss: 0.48

In [37]:
lstm_model.load_state_dict(torch.load('../../Mastering-PyTorch/Chapter04/lstm_model.pt'))

test_loss, test_accuracy = validate(lstm_model, test_data_iterator, loss_func)

print(f'test loss: {test_loss:.3f} | test accuracy: {test_accuracy*100:.2f}%')

test loss: 0.401 | test accuracy: 82.86%


In [38]:
def sentiment_inference(model, sentence):
    model.eval()
    
    # text transformations
    tokenized = data.get_tokenizer("basic_english")(sentence)
    tokenized = [TEXT_FIELD.vocab.stoi[t] for t in tokenized]
    
    # model inference
    model_input = torch.LongTensor(tokenized).to(device)
    model_input = model_input.unsqueeze(1)
    
    pred = torch.sigmoid(model(model_input))
    
    return pred.item()

In [39]:
print(sentiment_inference(lstm_model, "This film is horrible"))
print(sentiment_inference(lstm_model, "Director tried too hard but this film is bad"))
print(sentiment_inference(lstm_model, "Decent movie, although could be shorter"))
print(sentiment_inference(lstm_model, "This film will be houseful for weeks"))
print(sentiment_inference(lstm_model, "I loved the movie, every part of it"))

0.2899036407470703
0.04821621999144554
0.4011892080307007
0.5961778163909912
0.8934914469718933


# Visualization ??
### but doesn't work...

In [44]:
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

In [52]:
! python -m spacy download en_core_web_sm
import spacy
nlp = spacy.load('en_core_web_sm')

Collecting en-core-web-sm==3.3.0
  Using cached https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.3.0/en_core_web_sm-3.3.0-py3-none-any.whl (12.8 MB)
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


In [53]:
lig = LayerIntegratedGradients(lstm_model, lstm_model.embedding_layer)

In [54]:
def forward_with_sigmoid(input, l):
    return torch.sigmoid(lstm_model(input, l))

In [55]:
token_reference = TokenReferenceBase(reference_token_idx=PAD_INDEX)

In [56]:
# accumalate couple samples in this array for visualization purposes
vis_data_records_ig = []

def interpret_sentence(model, sentence, min_len = 7, label = 0):
    # text transformations
    tokenized = data.get_tokenizer("basic_english")(sentence)
    tokenized = [TEXT_FIELD.vocab.stoi[t] for t in tokenized]
    
    # model inference
    model_input = torch.LongTensor(tokenized).to(device)
    model_input = model_input.unsqueeze(1)
    length_input = torch.LongTensor([len(tokenized)])
    pred = torch.sigmoid(model(model_input, length_input))

    # generate reference indices for each sample
    reference_indices = token_reference.generate_reference(len(tokenized), device=device).unsqueeze(0)
    
    
    print(model_input.shape)
    print(reference_indices)
    # compute attributions and approximation delta using layer integrated gradients
    attributions_ig, delta = lig.attribute(model_input,
                                           reference_indices.reshape(model_input.shape[1], model_input.shape[0]), 
                                           n_steps=500, return_convergence_delta=True)

    print('pred: ', Label.vocab.itos[pred_ind], '(', '%.2f'%pred, ')', ', delta: ', abs(delta))

    add_attributions_to_visualizer(attributions_ig, text, pred, pred_ind, label, delta, vis_data_records_ig)
    
def add_attributions_to_visualizer(attributions, text, pred, pred_ind, label, delta, vis_data_records):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()

    # storing couple samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
                            attributions,
                            pred,
                            Label.vocab.itos[pred_ind],
                            Label.vocab.itos[label],
                            Label.vocab.itos[1],
                            attributions.sum(),       
                            text,
                            delta))

In [57]:
interpret_sentence(lstm_model, 'It was a fantastic performance !', label=1)
interpret_sentence(lstm_model, 'Best film ever', label=1)
interpret_sentence(lstm_model, 'Such a great show!', label=1)
interpret_sentence(lstm_model, 'It was a horrible movie', label=0)
interpret_sentence(lstm_model, 'I\'ve never watched something as bad', label=0)
interpret_sentence(lstm_model, 'It is a disgusting movie!', label=0)

torch.Size([6, 1])
tensor([[1, 1, 1, 1, 1, 1]], device='cuda:0')


AssertionError: Baseline can be provided as a tensor for just one input and broadcasted to the batch or input and baseline must have the same shape or the baseline corresponding to each input tensor must be a scalar. Found baseline: tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0') and input: tensor([[[ 1.5899e-01,  3.9581e-01,  8.3462e-01,  9.0878e-01, -1.8682e-02,
           1.2946e+00,  4.9903e-01, -3.9060e-01,  1.2779e+00, -1.1718e+00,
           2.2578e-01, -2.1784e+00, -8.9297e-01,  2.3828e-01,  1.0838e+00,
           3.1132e-02,  5.1392e-01,  7.9335e-01,  1.2815e-01, -8.2359e-01,
           1.0550e+00,  2.1603e-01, -7.1140e-01,  2.1493e-03, -8.0563e-01,
           7.3644e-01,  8.6861e-01,  3.3154e-01, -2.3092e+00, -1.0909e+00,
           5.0913e-01,  8.7766e-01, -6.7373e-01,  1.3036e+00, -3.6899e-01,
          -8.1205e-01, -5.6677e-01, -7.9048e-01,  6.9638e-01, -2.2907e-01,
          -2.7089e+00, -3.4715e-01,  1.0964e+00, -4.3381e-01, -6.3328e-01,
           1.1284e+00,  8.4442e-01,  7.3844e-01,  2.5823e-02,  2.0916e+00,
           5.5208e-02, -5.9757e-01, -8.5522e-01,  1.1001e+00,  6.4866e-01,
           3.9835e-01, -1.6432e+00, -1.0061e+00,  5.0756e-01, -9.1122e-01,
           1.6324e-01,  1.4564e+00, -3.8453e-01, -1.7927e-01,  1.9327e-01,
          -1.0681e+00, -5.3395e-01,  1.0046e+00,  1.0359e+00,  6.3755e-01,
          -1.2106e+00,  1.0613e+00, -1.7825e+00,  9.5887e-01, -1.0199e+00,
           9.5073e-01,  9.9059e-01, -4.8548e-01,  5.1139e-01, -9.8376e-01,
           7.3883e-01,  7.8875e-01,  7.1675e-01,  4.7122e-01,  1.4556e+00,
           6.9142e-01,  2.3638e-01,  8.7639e-02,  1.5929e-01, -3.0165e-01,
          -5.9262e-01,  7.9961e-01,  6.3470e-01, -5.0957e-01, -1.3268e-01,
           4.5197e-01,  5.7480e-01, -1.2900e+00, -1.4176e+00,  2.2525e+00]],

        [[-6.0990e-01, -7.1162e-03, -3.5230e-01, -7.1586e-01, -6.4691e-01,
          -5.3946e-01,  4.3544e-01, -1.1125e+00, -1.7052e+00,  1.4048e-02,
           8.4729e-01, -1.4919e+00, -2.3127e-01, -1.4540e+00,  4.7199e-01,
           4.3575e-01, -1.1604e+00,  1.7407e+00,  1.4828e-01,  2.5753e+00,
          -3.9196e-01,  1.5379e+00, -9.2076e-01, -2.8512e-01,  8.2832e-01,
           7.7631e-01, -9.2027e-01,  1.2276e+00, -1.0036e+00,  3.4194e-01,
           9.5555e-01, -1.2601e+00,  5.1857e-01, -2.9635e-01, -1.0057e+00,
           1.4105e+00, -5.4511e-01, -7.8574e-01,  2.2185e-01, -1.5449e+00,
           1.4163e+00, -1.9123e+00, -5.9268e-01, -9.6424e-03,  1.1232e+00,
          -7.1503e-01,  1.2369e+00, -3.5510e-01, -2.4510e+00,  2.3202e+00,
           4.4226e-01, -1.4638e+00,  4.7718e-02,  3.7105e-01,  8.7604e-01,
          -1.2044e+00,  7.1035e-01, -1.3140e+00, -2.0518e-01,  1.3727e-01,
          -1.6167e+00,  2.5406e-01, -7.8117e-01, -3.7308e-01, -7.0169e-01,
           8.6800e-01, -9.4508e-01, -1.1207e-01, -6.8230e-01,  1.8243e+00,
           2.0093e-01, -6.9027e-01, -1.0147e+00,  1.1779e+00, -2.2387e+00,
          -2.6378e+00, -2.8828e-01,  3.3515e-01,  8.3937e-01,  6.2541e-01,
          -8.6416e-01,  4.7975e-01,  9.8031e-01,  1.1215e+00, -9.3458e-01,
          -1.5531e-01,  1.4716e+00, -4.0777e-01,  4.4154e-02, -6.4398e-01,
           9.0853e-02, -3.6660e-01,  4.7068e-01, -9.0916e-01,  5.7365e-02,
          -1.4034e+00,  3.3440e-01,  1.4931e-01, -1.3336e+00, -1.0846e-02]],

        [[-9.5110e-01, -7.0620e-03,  1.4664e+00, -7.2822e-02,  1.6167e+00,
           2.0529e-01,  2.0434e+00,  6.3391e-01, -6.7473e-01, -1.4003e+00,
           1.4847e+00, -1.1046e-01,  8.9979e-01, -1.9165e+00, -4.4945e-01,
          -2.9848e-01, -1.5685e+00,  4.3572e-01,  1.4080e-01, -1.3210e+00,
          -7.1064e-01,  5.6128e-01,  8.8877e-01,  9.6616e-01, -7.2182e-01,
           1.5497e+00,  1.2478e+00, -2.7926e-01,  2.3774e+00,  4.6843e-01,
           8.2842e-01, -5.0900e-01,  4.3904e-01,  9.3011e-03,  5.6372e-01,
          -6.7623e-01, -1.3174e+00, -3.3618e-01, -8.5175e-01,  1.7084e+00,
           3.9251e-01,  6.1506e-01,  1.6341e+00, -2.7244e-02,  1.0603e+00,
           6.2143e-01,  1.8539e+00,  8.0723e-01,  7.0309e-01,  6.7047e-01,
           2.0411e-01, -5.3258e-01, -1.4166e+00,  6.0213e-01, -1.0231e+00,
          -1.7525e+00,  2.0353e+00, -5.4666e-01,  1.4422e+00, -3.2823e-01,
          -1.1709e+00,  8.1312e-01,  3.5337e-01, -8.8316e-01, -1.2471e+00,
          -1.9352e-01,  5.1235e-01, -8.0290e-01, -6.0411e-01,  4.7489e-02,
          -3.0681e-01, -8.0887e-01, -1.2147e+00,  1.2481e+00, -1.6954e+00,
           1.7670e-01, -6.9225e-01,  7.1540e-02,  2.0898e+00,  1.1484e-01,
          -8.2643e-01,  8.3625e-02,  1.1815e+00, -4.9900e-01,  9.8698e-02,
          -7.1430e-01,  1.7285e+00,  1.4402e+00, -3.2650e-03, -1.4089e-01,
          -1.3567e+00, -9.2281e-02,  3.2848e-01, -1.7771e-01, -2.0604e-01,
           1.4197e+00, -4.3751e-01, -3.2615e-01,  6.2895e-01,  5.3981e-01]],

        [[ 1.6892e+00, -1.9431e-01, -3.2224e-01,  3.4490e-01, -1.1653e+00,
          -3.9481e-01, -3.5266e-01,  5.9497e-01,  2.1040e+00, -1.6459e+00,
          -9.0557e-01, -1.1838e+00, -6.3269e-01, -9.6686e-01, -1.1722e+00,
           5.4055e-01, -2.9333e-01,  1.9952e+00, -1.3771e-01,  2.9921e+00,
          -6.5558e-01, -9.3074e-01, -2.2860e-01, -2.2551e-02,  9.0229e-01,
          -2.0580e-03, -7.8316e-01, -8.9419e-01, -2.2447e-01, -1.8828e+00,
           1.0121e+00,  1.3446e+00, -4.9502e-01, -3.0010e-01, -5.9109e-01,
          -1.4185e-01,  5.4657e-02, -6.8540e-01,  6.5147e-02,  2.9379e+00,
          -4.9744e-01, -1.7516e+00, -2.7392e-01,  1.7361e+00,  2.5768e-01,
           8.0858e-01, -7.7159e-01, -1.6262e+00, -1.6733e+00, -2.0280e-01,
          -4.3399e-01,  2.9403e-01,  4.1337e-01,  1.4134e+00,  1.7226e+00,
          -8.6325e-01,  1.0139e+00,  1.3789e-01, -1.9410e-01,  5.0410e-01,
          -9.0640e-01, -1.4849e+00,  4.1672e-02, -1.7339e-01,  7.6474e-01,
          -6.8545e-01, -8.4875e-01, -6.3901e-01, -1.9897e+00,  9.4109e-02,
          -1.0168e+00, -8.8822e-01,  3.0812e-02,  5.4795e-01, -5.3808e-01,
           3.4739e-01,  1.3666e+00, -7.9314e-02, -5.8982e-03,  8.4411e-01,
           1.1872e-01, -1.1269e+00, -2.4777e+00,  1.0928e+00,  2.2439e+00,
          -5.6637e-01,  4.0842e-01, -1.0176e+00,  1.6326e+00, -1.7560e-01,
           1.1970e-01,  7.9557e-02,  4.5318e-01,  1.2080e+00,  1.6906e+00,
          -3.1913e+00,  6.5363e-01,  1.0044e+00,  6.5124e-01, -4.6142e-01]],

        [[ 1.1474e+00, -1.2219e+00,  1.5311e+00, -1.9181e+00,  1.0266e+00,
          -1.0445e+00, -1.5458e-01, -8.7398e-01, -3.9716e-01,  1.2200e+00,
           5.9102e-01, -2.0237e-01,  2.6036e-01,  2.8829e-01, -6.0330e-01,
          -2.4380e+00, -1.6508e+00,  6.3733e-01, -8.7281e-01,  9.6396e-01,
           3.6009e-02,  4.7853e-01,  5.8551e-01, -1.0222e+00,  4.6850e-01,
           8.5582e-01,  2.3504e-01,  4.0865e-02,  5.8716e-01, -1.5681e+00,
           1.2935e+00, -2.1845e-01, -2.3343e-02,  5.1217e-01, -1.4539e+00,
          -1.0669e-01, -1.1420e+00,  7.4076e-01, -6.0516e-01,  9.7153e-01,
          -1.7203e+00,  1.6090e-01,  7.2770e-01,  1.3474e-01, -2.3276e-01,
          -9.8349e-01,  8.7394e-01,  1.6044e+00,  2.2305e+00,  1.3159e+00,
           1.2003e+00,  6.3528e-01,  4.0749e-01,  1.7186e+00,  9.2511e-01,
          -6.4882e-01, -2.2362e+00, -1.8320e+00,  2.0051e+00,  7.4933e-01,
          -4.9061e-02, -4.2367e-01, -2.9759e-01, -5.2820e-01, -1.0170e+00,
           1.9610e+00, -7.1650e-01,  1.2064e+00, -1.3965e+00,  8.5272e-01,
          -1.1305e+00,  4.8074e-01,  6.8236e-01, -2.9311e-02,  1.4540e+00,
          -2.0741e+00, -1.4512e+00,  5.4786e-01, -3.1652e-01,  6.1505e-01,
           8.1605e-01, -3.6835e-01,  1.6427e-01, -1.1652e+00,  5.2819e-01,
          -9.3690e-01, -1.0042e+00,  2.5815e-01, -2.2093e-01,  1.0568e+00,
           1.8710e-01, -1.6438e+00, -1.0410e+00, -1.6262e-01,  2.7776e-01,
          -1.2441e+00,  7.0846e-01, -1.2571e+00, -1.9005e+00, -5.5108e-01]],

        [[-6.2225e-01,  1.1263e-01,  1.5201e+00,  1.6294e+00,  6.9333e-01,
           1.4749e+00,  4.0515e-01,  1.3435e+00, -1.6834e+00, -1.9123e-01,
          -2.9999e-01, -1.7239e+00,  3.9591e-01, -1.3796e+00,  1.8547e+00,
          -1.6333e+00,  1.7782e+00,  2.2036e-01, -1.0546e+00, -3.7253e-01,
           2.2099e+00, -1.9922e-01, -1.1540e+00,  1.2550e+00, -1.5669e+00,
          -5.9739e-01, -8.7843e-01,  1.5402e+00,  7.8677e-01,  2.5974e+00,
           1.7088e+00,  3.4667e-01, -1.3186e+00, -8.7469e-01,  2.6160e+00,
           1.1151e+00, -2.2571e-01, -5.5196e-01, -3.2358e-01,  6.2215e-02,
          -5.6477e-01, -9.4670e-01,  9.1513e-03, -1.6173e+00,  4.9922e-01,
          -1.4955e+00, -9.5343e-01, -1.1535e+00,  3.5272e-01,  6.0270e-01,
          -3.2133e-01, -5.4476e-01,  1.2251e+00, -1.2561e+00,  1.1752e+00,
          -8.2474e-01,  6.7103e-01, -1.1552e-01, -8.4066e-01,  1.2818e+00,
           1.4465e+00,  5.5877e-01,  6.8670e-01, -2.1592e-01, -1.2302e+00,
          -1.3102e+00,  5.6326e-01,  3.2067e-01, -1.4668e+00,  6.4914e-01,
           8.9192e-01,  1.1502e+00, -6.2337e-02, -8.9526e-01,  7.5140e-01,
           9.3156e-01,  8.2075e-01, -2.6678e-01, -1.6224e+00,  1.0034e+00,
           5.8191e-01, -6.9103e-01, -3.7767e-01,  7.8173e-01, -1.1547e+00,
           5.6047e-02, -5.8822e-01,  5.8918e-01, -7.3669e-01, -1.3374e+00,
          -5.9425e-01,  6.1755e-01,  6.9365e-01,  1.0297e+00,  9.5625e-01,
          -2.8796e-01,  1.7674e-01, -3.5438e-01, -1.3552e+00, -4.5781e-01]]],
       device='cuda:0')