In [1]:
## Standard libraries
import os
import numpy as np
import pandas as pd
import random
import math
import json
from functools import partial

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
#from IPython.display import set_matplotlib_formats
#matplotlib_inline.backend_inline.set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

# ## tqdm for loading bars
# from tqdm.notebook import tqdm

# for one-hot encoding
from keras.utils.np_utils import to_categorical   

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch.utils.data.dataloader import default_collate
import torch.optim as optim
from torch.optim import Adam
from mlm_pytorch import MLM

# Transformer wrapper
import tensorflow as tf
from x_transformers import TransformerWrapper, Encoder
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# Positional encoding in two dimensions
from positional_encodings.torch_encodings import PositionalEncodingPermute1D, PositionalEncoding1D

from functools import reduce
import math


# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Path to the folder where the datasets are
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "saved_models/simple_transformer"

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)



2024-12-16 09:25:06.660551: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-12-16 09:25:06.732413: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-16 09:25:06.750223: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Seed set to 42


Device: cuda:0


In [2]:
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.5
    import pytorch_lightning as pl

### Data preparation

In [3]:
# Global variables
NUM_FEATURES = 7
NUM_FIX = 30 
BATCH_SIZE = 64
NUM_CLASSES = 31


num_fix = NUM_FIX
seq_len = NUM_FEATURES

In [4]:
# Helper functions

def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask
    
# def mse_loss(target, input, ignored_index):
#     mask = target == ignored_index
#     out = (input[~mask]-target[~mask])**2
#     return out.mean()

def mse_loss(target, input, mask):
    out = (input[mask]-target[mask])**2
    return out.mean()

def mask_with_tokens_3D(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    reduced = torch.any(mask, dim=-1, keepdim=True)
    expanded = reduced.expand_as(mask)
    return expanded

def get_mask_subset_with_prob_3D(mask, prob):
    batch, num_fix, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * num_fix)

    num_tokens = mask.sum(dim=-2, keepdim=True)
    mask_excess = (mask.cumsum(dim=-2)[:,:,0] > (num_tokens[:,:,0] * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, num_fix, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-2)
    sampled_indices = (sampled_indices[:,:,0] + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, num_fix + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    new_mask = new_mask[:, 1:].bool()
    
    return new_mask.unsqueeze_(2).expand(-1,-1, seq_len)
    

def prob_mask_like_3D(t, prob):
    temp = torch.zeros_like(t[:,:,0]).float().uniform_(0, 1) < prob
    return temp.unsqueeze_(2).expand(-1,-1, seq_len)


# class CustomPositionalEncoding(nn.Module):
#   """Learnable positional encoding for features """
#     def __init__(self, embed_dim, max_len=5000):
#         super(CustomPositionalEncoding, self).__init__()
        
#         # Initialize a learnable positional encoding matrix
#         self.encoding = nn.Parameter(torch.zeros(max_len, embed_dim))
#         nn.init.xavier_uniform_(self.encoding)  # Xavier initialization for better training stability

#     def forward(self, x):
#         # Add the learnable positional encoding to the input tensor
#         return x + self.encoding[:x.size(1), :]

    
class CustomPositionalEncoding(nn.Module):
    """Learnable positional encoding for both features and fixations"""
    def __init__(self, fixations, embed_dim, max_len=5000):
        super(CustomPositionalEncoding, self).__init__()
        
        # Initialize a learnable positional encoding matrix
        self.encoding = nn.Parameter(torch.zeros(fixations, embed_dim))
        nn.init.xavier_uniform_(self.encoding)  # Xavier initialization for better training stability

    def forward(self, x, mask = None):
        if mask is not None:
            # Apply the mask to ignore padded positions
            pos_encoding = self.encoding * mask
        # Assumes input `x` is of shape [batch_size, fixations, embed_dim]
        return x + pos_encoding#.unsqueeze(0)

    
    
def pad_group_with_zeros(group, target_rows):
    # Calculate the number of rows to add
    num_missing_rows = target_rows - len(group)
    if num_missing_rows > 0:
        # Create a DataFrame with the required number of padding rows
        # input padding
        zero_rows = pd.DataFrame(0.3333, index=range(num_missing_rows), columns=group.columns)
        # Label padding
        # zero_rows.iloc[:, 0] = 31
        # Concatenate the group with the zero rows
        group = pd.concat([group, zero_rows], ignore_index=True)
    return group

class ToTensor(object):
    """Convert Series in sample to Tensors."""

    def __call__(self, sample):
        trial, label = sample['trial'], sample['label']
        trial = torch.from_numpy(trial).float()
        label = torch.from_numpy(label).float()
        return trial, label

In [5]:
class RSCFixationsOrder(Dataset):
    """Dataset with the long-format sequence of fixations made during reading by dyslexic 
    and normally-developing Russian-speaking monolingual children."""

    def __init__(self, csv_file, transform=None, target_transform = None, 
                 n_fix = NUM_FIX, 
                 dropPhonologyFeatures = True, dropPhonologySubjects = True):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            target_transform (callable, optional): Optional transform to be applied
                on a label.
        """
        self.fixations_frame = pd.read_csv(csv_file)
        
        # remove demography and identification
        self.fixations_frame = self.fixations_frame.drop(columns = ['fix_x', 
                                                                    'fix_y', 'full_text' 
                                                                    ]) 
        
        #self.fixations_frame = self.fixations_frame[order_cols]
        
        # Log-transforming appropriate measures
        to_transform = ['frequency', 'fix_dur'] 
        for column in to_transform:
            self.fixations_frame[column] = self.fixations_frame[column].apply(lambda x: np.log(x) if x > 0 else 0) 

        # Center 
        cols = ['fix_dur', 'landing', 'predictability',
                'frequency', 'word_length', 'number.morphemes', 
                'next_fix_dist', 'sac_ampl', 'sac_angle', 
                'sac_vel']
        for col in cols:
            self.fixations_frame[col] = np.where(self.fixations_frame[col] == 0, -4,
                (self.fixations_frame[col] - self.fixations_frame[col].mean())/self.fixations_frame[col].std(ddof=0)) 
        

        # Drop padding 
        self.fixations_frame = self.fixations_frame[self.fixations_frame['fix_dur'] != -4]
        
        
        # Convert direction to a dummy-coded variable
        self.fixations_frame['direction'] = np.where(self.fixations_frame['direction'].isnull(), 0,
                                                     self.fixations_frame['direction'])
        self.fixations_frame = pd.concat([self.fixations_frame, 
                                          pd.get_dummies(self.fixations_frame['direction'], 
                                                         prefix='dummy')], axis=1)
        
        # Center exceptions: 
        # columns where 0 is meaningful and should not be counted as NA
#         cols = ['rel.position', 'dummy_DOWN', 'dummy_LEFT','dummy_RIGHT', 'dummy_UP']
#         for col in cols:
#             self.fixations_frame[col] = (self.fixations_frame[col] - \
#                                              self.fixations_frame[col].mean())/self.fixations_frame[col].std(ddof=0)
            
        if dropPhonologySubjects == True:
            # Drop subjects
            self.fixations_frame.dropna(axis = 0, how = 'any', inplace = True)
        else:
            # Drop columns
            self.fixations_frame.dropna(axis = 1, how = 'any', inplace = True)
        
        
        self.fixations_frame['subj'] = self.fixations_frame['subj'].astype(str)
        self.fixations_frame['item'] = self.fixations_frame['sn'].astype(str)
        self.fixations_frame['Combined'] = self.fixations_frame[['subj', 'item']].agg('_'.join, axis=1)
        
        # cleaning up
        self.fixations_frame.drop(columns = ['subj', 'item', 'direction', 'dummy_0', 'sn',\
                                            'dummy_DOWN', 'dummy_LEFT', 'dummy_UP'\
                                            ], inplace = True)
        self.fixations_frame.drop(columns = ['word_length', 'predictability', 
                                             'frequency', 'number.morphemes', 'fix_index'], 
                                             inplace = True)
        

        padded = self.fixations_frame.groupby('Combined', group_keys=False).apply(lambda x: 
                                                                           pad_group_with_zeros(x, n_fix))
        padded.drop(columns = "Combined", inplace = True)
    
        
        #### Fixation index is the label
        self.fixations_frame = padded.to_numpy()
        dataReshaped = np.reshape(self.fixations_frame, (int(len(self.fixations_frame)/n_fix), 
                                                         n_fix, self.fixations_frame.shape[1]))

        self.predictors = dataReshaped[:,:,1:]
        self.labels = dataReshaped[:,:,0]
        self.labels = to_categorical(self.labels-1, num_classes=NUM_CLASSES)   # one-hot encoding

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        trial = self.predictors[idx]
        label = self.labels[idx]
        
        sample = {'trial': trial, 'label': label}

        if self.transform:
            sample = self.transform(sample)
            
        if self.target_transform:
            sample = self.target_transform(sample)
            
        return sample

In [6]:
transformed_dataset = RSCFixationsOrder(csv_file='data/RSC_long_no_word_padded_word_pos.csv', 
                                    transform=ToTensor(),
                                    dropPhonologySubjects = True)

train_size = int(0.8 * len(transformed_dataset))
val_size = int(0.1 * len(transformed_dataset))
test_size = len(transformed_dataset) - val_size - train_size
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(transformed_dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))

In [7]:
trials, labels = next(iter(train_loader))

train_features, train_labels = next(iter(train_loader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

Feature batch shape: torch.Size([64, 30, 7])
Labels batch shape: torch.Size([64, 30, 31])


In [8]:
trials[1, 0:15,:]

tensor([[ 1.6186e+00,  1.3636e-03,  4.8822e-02,  2.0223e-01, -4.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 2.1263e-01,  5.3885e-01,  6.6894e-01, -2.4092e-01, -4.0000e+00,
          1.0000e-01,  1.0000e+00],
        [ 4.5904e-01,  3.6327e-01,  5.0904e-01,  4.1091e+00, -4.0000e+00,
          2.0000e-01,  0.0000e+00],
        [ 9.0372e-01,  1.0978e+00,  1.2774e+00, -1.1564e-01, -4.0000e+00,
          1.0000e-01,  1.0000e+00],
        [ 1.4327e+00,  1.0728e+00,  2.0483e-01, -6.6295e-02, -4.0000e+00,
          3.0000e-01,  1.0000e+00],
        [ 6.1280e-01, -5.2385e-02, -1.8799e-03, -7.7553e-02, -4.0000e+00,
          5.0000e-01,  1.0000e+00],
        [ 1.8812e+00,  2.1636e-01,  1.7753e-01, -5.4318e-02, -4.0000e+00,
          5.0000e-01,  1.0000e+00],
        [ 6.6885e-01,  2.8802e-01,  2.6723e-01, -1.5469e-01, -4.0000e+00,
          6.0000e-01,  1.0000e+00],
        [ 3.0441e+00,  4.0985e-01,  3.0623e-01,  1.1316e-02, -4.0000e+00,
          7.0000e-01,  1.0000e+00],
        [-

### Embedding

-- Instead of reconstructing the vocabulary, we would use MSE and try to reconstruct *x, y, fix dur* and we can look into how to reconstruct the rest of the features (i.e. all the additional that you used).

-- and an adjustable pytorch implementation here:
https://github.com/lucidrains/mlm-pytorch/tree/master

-- https://github.com/lucidrains/mlm-pytorch

-- Embedding for position - DONE

In [9]:
class AbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len, l2norm_embed = False):
        super().__init__()
        self.scale = dim ** -0.5 if not l2norm_embed else 1.
        self.max_seq_len = max_seq_len
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

#         if not exists(pos):
#             pos = torch.arange(seq_len, device = device)

#         if exists(seq_start_pos):
#             pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        pos_emb = self.emb(pos)
        pos_emb = pos_emb * self.scale
        return l2norm(pos_emb) if self.l2norm_embed else pos_emb


In [11]:
# pos = torch.arange(seq_len, device = device)
# pos = pos_emb(pos)
# pos

NameError: name 'pos_emb' is not defined

In [None]:
# pos_emb = AbsolutePositionalEmbedding(dim=seq_len, max_seq_len = num_fix)
# pos_emb(X_train)

In [None]:
# nn.Embedding(30, 10)

In [12]:
# Annas changes
class TransformerWithCustomPositionalEncoding(nn.Module):
    def __init__(
        self, 
        embed_dim, 
        num_heads, 
        num_layers, 
        max_len=5000,
        mask_prob = 0.15,
        replace_prob = 1, # 0.9
        mask_token_id = 2,
        pad_token_id = 0.3333,
        mask_ignore_token_ids = []
        ):
        super(TransformerWithCustomPositionalEncoding, self).__init__()
        
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob

        # token ids
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id])
        
        # positional encoding
        self.positional_encoding = CustomPositionalEncoding(num_fix, embed_dim).to(device)
        
        # transformer
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, 
                                                        nhead=num_heads, 
                                                        batch_first = True)
        self.transformer = TransformerEncoder(self.encoder_layer, num_layers=num_layers).to(device)

    def forward(self, seq):
        # do not mask [pad] tokens, or any other tokens in the tokens designated to be excluded ([cls], [sep])
        # also do not include these special tokens in the tokens chosen at random
        no_mask = mask_with_tokens_3D(seq, self.mask_ignore_token_ids) 
        mask = get_mask_subset_with_prob_3D(~no_mask, self.mask_prob)

        # mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob)
        masked_seq = seq.clone().detach()
        masked_seq_pos = self.positional_encoding(masked_seq, mask = ~no_mask)

        # [mask] input
        masked_replace_prob = prob_mask_like_3D(seq, self.replace_prob) # Anna: select 90% of all values  (ignore all masking for now)
        masked_seq = masked_seq_pos.masked_fill(mask * masked_replace_prob, self.mask_token_id) # Anna: select 90% only of those selected for masking
        
        # derive labels to predict
        labels = seq # self.positional_encoding(seq) # add positional encoding before masking, so that encoding does not affect mask
        labels = labels.masked_fill(~mask, self.pad_token_id)
        
        # Pass through the transformer
        preds = self.transformer(masked_seq)
        
        my_loss = mse_loss(
            labels,
            preds,
            #ignored_index = self.pad_token_id
            mask = mask
        )

        return preds, my_loss

In [17]:
# X_train = trials
# seq = X_train

# mask_prob = 0.15
# replace_prob = 1
# mask_token_id = 2
# pad_token_id = 0.3333
# mask_ignore_token_ids = set([pad_token_id])
# embed_dim = 10

# positional_encoding = CustomPositionalEncoding(num_fix, seq_len).to(device)

# # transformer
# encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, 
#                                                 nhead=2, 
#                                                 batch_first = True)
# transformer = TransformerEncoder(encoder_layer, num_layers=2).to(device)

# no_mask = mask_with_tokens_3D(seq, mask_ignore_token_ids) # works in 3D
# mask = get_mask_subset_with_prob_3D(~no_mask, mask_prob)

# # mask input with mask tokens with probability of `replace_prob` (keep tokens the same with probability 1 - replace_prob)
# masked_seq = seq.clone().detach()
# masked_seq_pos = positional_encoding(masked_seq, mask = ~no_mask)

# # [mask] input
# masked_replace_prob = prob_mask_like_3D(seq, replace_prob) # Anna: select 90% of all values  (ignore all masking for now)
# masked_seq = masked_seq_pos.masked_fill(mask * masked_replace_prob, mask_token_id) # Anna: select 90% only of those selected for masking

# masked_seq[4]

In [26]:
# # derive labels to predict
# labels = seq # self.positional_encoding(seq) # add positional encoding before masking, so that encoding does not affect mask
# labels = labels.masked_fill(~mask, pad_token_id)

# labels[4]

In [None]:
# labels[0,1] == seq[0,1]

In [13]:
# Ok, let's see:
trainer = TransformerWithCustomPositionalEncoding(num_heads = 1, 
                                                  num_layers= 2, 
                                                  embed_dim = seq_len).cuda()

print(trainer)

TransformerWithCustomPositionalEncoding(
  (positional_encoding): CustomPositionalEncoding()
  (encoder_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=7, out_features=7, bias=True)
    )
    (linear1): Linear(in_features=7, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=7, bias=True)
    (norm1): LayerNorm((7,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((7,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=7, out_features=7, bias=True)
        )
        (linear1): Linear(in_features=7, out_features=2048, bias=True)




In [14]:
# # Setup optimizer to optimize model's parameters
optimizer = Adam(trainer.parameters(), lr=3e-4)

epochs = 81

for epoch in range(epochs):
    train_loss = 0.0
    ### Training
    trainer.train()
    
    for X_train, y_train in train_loader:
        # 1. Forward pass 
        # 2. Calculate loss/accuracy
        train_preds, loss = trainer(X_train)
        
        # 3. Optimizer zero grad
        optimizer.zero_grad()

        # 4. Loss backwards
        loss.backward()
 
        # 5. Optimizer step
        optimizer.step()
        
        train_loss += loss.item() * X_train.size(0)
        
    train_loss /= len(train_loader.dataset)

    #### Evaluation
    test_loss = 0.0
    trainer.eval()
    with torch.no_grad():
        for X_test, y_test in test_loader:
            test_preds, tloss = trainer(X_test)
            test_loss += tloss.item() * X_test.size(0)

    test_loss /= len(test_loader.dataset)


    # Print out what's happening every 10 epochs
    if epoch % 10 == 0:
        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')


Epoch 1: Train Loss: 2.2353, Test Loss: 2.1076
Epoch 11: Train Loss: 1.3506, Test Loss: 1.2883
Epoch 21: Train Loss: 1.2263, Test Loss: 1.2021
Epoch 31: Train Loss: 1.1994, Test Loss: 1.1825
Epoch 41: Train Loss: 1.1741, Test Loss: 1.1848
Epoch 51: Train Loss: 1.1644, Test Loss: 1.2053
Epoch 61: Train Loss: 1.1734, Test Loss: 1.1306
Epoch 71: Train Loss: 1.1592, Test Loss: 1.1860
Epoch 81: Train Loss: 1.1578, Test Loss: 1.2287


In [15]:
X_test[6, 0:8]

tensor([[ 1.2460, -0.0846,  0.0449,  0.2025,  0.4812,  0.0000,  1.0000],
        [-0.1541,  0.6463,  0.8288, -0.0110,  1.4887,  0.0909,  1.0000],
        [ 0.4718,  0.2880,  0.4778, -0.0474,  1.0629,  0.2727,  1.0000],
        [ 1.6479,  0.7001,  0.7508, -0.0155,  1.0056,  0.3636,  1.0000],
        [ 0.2884,  0.5639,  0.7391,  0.0501,  0.9180,  0.5455,  1.0000],
        [ 0.9838,  0.7431,  0.9654,  0.0322,  1.3550,  0.6364,  1.0000],
        [ 3.1893,  0.2450,  0.4388,  0.0614,  0.9648,  0.7273,  1.0000],
        [ 1.0871,  3.8139,  4.3000, -0.8036,  3.7348,  0.9091,  1.0000]],
       device='cuda:0')

In [16]:
test_preds[6, 0:8]

tensor([[ 1.0681,  0.5923,  0.4687,  0.0894,  1.0752,  0.4437,  0.8271],
        [ 1.0517,  0.6961,  0.5399,  0.0664,  1.0765,  0.4452,  0.8137],
        [ 1.0538,  0.3869,  0.3900,  0.0081,  0.9554,  0.4292,  0.8181],
        [ 1.0685,  0.5520,  0.3248,  0.0790,  0.9086,  0.4410,  0.8213],
        [ 1.0538,  0.3869,  0.3900,  0.0081,  0.9554,  0.4292,  0.8181],
        [ 1.0357,  0.6291,  0.2132, -0.0023,  0.5807,  0.4299,  0.7967],
        [ 1.0453,  0.7976,  0.3094,  0.0062,  0.5926,  0.4323,  0.8061],
        [ 1.0034,  0.5951,  0.7288, -0.0642,  1.0518,  0.4237,  0.8004]],
       device='cuda:0')