In [1]:
import sys
sys.path.append("../")

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.autograd.profiler as profiler

import apex.fp16_utils as fp16

import os
import numpy as np
from sklearn.datasets import make_classification
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

from utils.moduleCodeProfiler import rankByCriteria


In [2]:
!nvidia-smi

Sun Nov 22 17:06:13 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 455.32.00    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           On   | 000047DD:00:00.0 Off |                    0 |
| N/A   51C    P8    51W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [3]:
cuda0 = torch.device('cuda:0') 

In [4]:
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
args = parser.parse_args('')

args.data_dir = '~/datadrive'
args.dataset_dir = 'toy_mlp_1'
args.seed = 123
args.batch_size = 1000
# https://stackoverflow.com/questions/15753701/how-can-i-pass-a-list-as-a-command-line-argument-with-argparse
args.hidden_layer_dims = [10, 10, 10, 10]
args.lr = 0.01
args.epochs = 20

# Toy Data Generation

In [5]:
# construct and save toydataset

m_train = 9000
m_total = m_train

X, y = make_classification(n_samples=m_total, n_features=10, n_informative=10, n_redundant=0, n_repeated=0, n_classes=5, n_clusters_per_class=2, weights=None, flip_y=0.01, class_sep=1.0, hypercube=True, shift=0.0, scale=1.0, shuffle=True, random_state=args.seed)
# y = np.expand_dims(y, -1)

np.random.seed(args.seed)
permutation = np.random.permutation(m_total)
print('First 10 training indices', permutation[:10])
print('X shape', X.shape)
print('y shape', y.shape)

train_indices = permutation[0:m_train]

dataset_dir = 'toy_mlp_1'
os.makedirs(os.path.join(args.data_dir, dataset_dir, 'train'), mode = 0o777, exist_ok = True) 

np.save(os.path.join(args.data_dir, dataset_dir, 'train', 'features.npy'), X[train_indices])
np.save(os.path.join(args.data_dir, dataset_dir, 'train', 'labels.npy'), y[train_indices])

First 10 training indices [1603 8472 2213  498 1038 8399 3324 7535 1519 1959]
X shape (9000, 10)
y shape (9000,)


In [6]:
class ToyDataset(Dataset):
    """Toy dataset construction."""

    def __init__(self, data_dir):
        """
        Args:
            data_dir (string): Path to the directory with data files.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        # shape (m, nx)
        self.X = np.load(os.path.join(data_dir, 'features.npy'))
        # shape (m, ny=1)
        self.y = np.load(os.path.join(data_dir, 'labels.npy'))
        

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        else:
            X = torch.from_numpy(self.X[idx, :]).type(torch.HalfTensor)
            y = torch.tensor(self.y[idx], dtype=torch.long)
#             y = torch.from_numpy(self.y[idx, :]).type(torch.FloatTensor)
            sample = {'X': X, 'y': y}

        return sample

# Model

In [7]:
class MLPLazy(nn.Module):

    def __init__(self, nx, hidden_layer_dims, ny):
        super(MLPLazy, self).__init__()
        self.hidden_layer_dims = hidden_layer_dims
        
        linear_layers = []
        last_dim = nx
        for next_dim in hidden_layer_dims:
            linear_layer = nn.Linear(last_dim, next_dim)
            linear_layers.append(linear_layer)
            last_dim = next_dim
        # should push to ModuleList so that params stay on cuda
        self.linear_layers = nn.ModuleList(linear_layers)
        self.scorer = nn.Linear(last_dim, ny)

    def forward(self, X):
        '''
        X has shape (m, nx)
        '''
        last_X = X
        for i, linear_layer in enumerate(self.linear_layers):
            # shape (m, self.hidden_layer_dims[i])
            last_X = linear_layer(last_X)
            # shape (m, self.hidden_layer_dims[i])
            last_X = torch.relu(last_X)
        # shape (m, ny)
        z = self.scorer(last_X)
        # shape (m, ny)
        a = torch.softmax(z, dim=1)
        return z, a

# Workflow

In [8]:
def check_weights_precision(model):
    '''check weight precisions for each layer of MLP'''
    for i, layer in enumerate(model.linear_layers):
        print(f'layer {i}, weight dtype {layer.weight.dtype}')
        print(f'layer {i}, bias dtype {layer.bias.dtype}')
    print(f'scorer weight dtype {model.scorer.weight.dtype}')
    print(f'scorer bias dtype {model.scorer.bias.dtype}')

In [9]:
def get_master(opt):
    '''create a float32 master copy of float16 model weights in optimizer'''
    model_pgs = [[param for param in pg['params'] if param.requires_grad] for pg in opt.param_groups]
    master_pgs = [[param.clone().float().detach() for param in pg] for pg in model_pgs]
    for pg in master_pgs:
        for param in pg: param.requires_grad_(True)
    return model_pgs, master_pgs

In [10]:
def push_master_to_optimizer(opt, master_pgs):
    '''
        link master copy pgs to optimizer, 
        keeping other hparams such as lr, momentum dampening, weight_decay...'''
    for opt_pg, master_pg in zip(opt.param_groups, master_pgs):
        opt_pg['params'] = master_pg

In [11]:
def to_master_grads(model_pgs, master_pgs, flat_master:bool=False):
    '''copy float16 gradients from model to float32 gradients in master copy of weights'''
    for (model_params,master_params) in zip(model_pgs,master_pgs):
        fp16.model_grads_to_master_grads(model_params, master_params, flat_master=flat_master)

In [12]:
def to_model_params(model_pgs, master_pgs, flat_master:bool=False)->None:
    '''copy master copy of updated weights in float32 to model weights in float 16'''
    for (model_params,master_params) in zip(model_pgs,master_pgs):
        fp16.master_params_to_model_params(model_params, master_params, flat_master=flat_master)

In [13]:
def scale_down_master_grad(master_pgs, loss_scale):
    '''
    scale down all gradients for all master param groups
    '''
    for master_params in master_pgs:
        for param in master_params:
            if param.grad is not None:
                param.grad.div_(loss_scale)

In [14]:
def get_max_memory_alloc():
    '''read and reset max memory allocation'''
    devices_max_memory_alloc = {}
    for i in range(torch.cuda.device_count()):
        device = torch.device(f'cuda:{i}')
        devices_max_memory_alloc[device] = torch.cuda.max_memory_allocated(device) / 1e6
        torch.cuda.reset_max_memory_allocated(device)
    return devices_max_memory_alloc

In [15]:
def main_train(args, gpu=0, debug=False):

    torch.manual_seed(args.seed)
    
    ################################################################
    # load datasets
    training_set = ToyDataset(data_dir=os.path.join(args.data_dir, args.dataset_dir, 'train'))
    training_generator = torch.utils.data.DataLoader(dataset=training_set, 
                                                        batch_size=args.batch_size, 
                                                        shuffle=True, 
                                                        num_workers=0, 
                                                        pin_memory=True)

    nx = training_set.X.shape[1]
    ny = max(training_set.y) + 1
    ################################################################
    loss_scale = 512

    # 1. Create model
    model = MLPLazy(nx, args.hidden_layer_dims, ny)  # single
    loss_criterion = nn.CrossEntropyLoss(reduction='mean')
    torch.cuda.set_device(gpu)
    model.to(device=gpu)    

    # 2. initialize optimizer
    opt = torch.optim.SGD(model.parameters(), lr=args.lr)  # half
    if debug:
        print('\nmodel weights at init')
        check_weights_precision(model)

    # 3. Cast model to float16
    fp16.convert_network(model, torch.float16)
    if debug:
        print('\nmodel weights after casting')
        check_weights_precision(model)

    # 4. Create a copy of this float16 model's weight in float32 as the master copy
    model_pgs, master_pgs = get_master(opt)  # half, single

    # 5. replace optimizer float16 weights with float32 master copy
    push_master_to_optimizer(opt, master_pgs)  # opt single

    def check_grad():
        print('optimizer grad:\n', opt.param_groups[0]['params'][0].grad)
        print('master pg grad:\n', master_pgs[0][0].grad)
        print('model pg grad:\n', model_pgs[0][0].grad)

    def check_weights():
        print('optimizer weights:\n', opt.param_groups[0]['params'][0])
        print('master pg weights:\n', master_pgs[0][0])
        print('model pg weights:\n', model_pgs[0][0])

    history = {'train_losses': [], 'max_memory_allocation':[]}

    for e in range(2):
        model.train()
        sum_batch_losses = torch.tensor([0.], dtype=torch.float, device=gpu)
        batch_max_memory_alloc = []
        for batch_i, batch_data in enumerate(training_generator):
            if debug: 
                print(f'\nRunning batch_{batch_i}-----------------------------------------')

            batch_max_memory_alloc.append(get_max_memory_alloc())

            # 6. model forward with float16 data
            # NOTE: model zero grad, master last grad
            batch_X = batch_data['X'].cuda(gpu, non_blocking=True) # half
            batch_y = batch_data['y'].cuda(gpu, non_blocking=True) # long
            logits, activations = model(batch_X) # Half

            # 7. compute loss in float16
            # NOTE: model zero grad, master last grad
            loss = loss_criterion(logits, batch_y)  # half
            if debug:
                print('\nComputed loss')
                check_grad()
                check_weights()

            # 8. scale up loss here
            print('\nloss before scaling:', loss.item(), loss.dtype)
            loss *= loss_scale
            print('loss after scaling:', loss.item(), loss.dtype)

            # 9. backprop to compute gradients
            # NOTE: model new grad, master last grad
            # NOTE: when we call backward, gradient accumulate on model gradient, not master gradient
            loss.backward()  # half
            if debug:
                print('\nCalled backward to compute gradients')
                check_grad()
                check_weights()

            # 10. copy float16 gradients from model to float32 gradients in master copy of weights
            # NOTE: model new grad, master new grad
            to_master_grads(model_pgs, master_pgs)
            if debug:
                print('\nCopied model grad to master grad')
                check_grad()
                check_weights()

            # 11. scale down master copy gradients
            scale_down_master_grad(master_pgs, loss_scale)
            if debug:
                print('\nScaled down gradients in master copy')
                check_grad()
                check_weights()

            # 12.copy float16 gradients from model to float32 gradients in master copy of weights
            opt.step()
            if debug:
                print('\nOptimizer stepped')
                check_grad()
                check_weights()

            # 13.zero out gradients in model
            model.zero_grad()
            if debug:
                print('\nModel gradients zeroed out')
                check_grad()
                check_weights()

            # 14.copy float32 master weights to float16 model weights
            # NOTE: model zero grad, master new grad
            to_model_params(model_pgs, master_pgs)
            if debug:
                print('\nCopy master weights to model weight')
                check_grad()
                check_weights()
                            
            sum_batch_losses += loss


        num_batches = batch_i + 1.
        history['train_losses'].append(sum_batch_losses/num_batches)
        history['max_memory_allocation'] += batch_max_memory_alloc
    
    itemize = lambda x: [tensor_val.item()/loss_scale for tensor_val in x]
    history['train_losses'] = itemize(history['train_losses'])    

    return history, model

# Train

In [16]:
!nvidia-smi

Sun Nov 22 17:06:44 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 455.32.00    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           On   | 000047DD:00:00.0 Off |                    0 |
| N/A   46C    P8    27W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [17]:
with profiler.profile(profile_memory=True, record_shapes=True, use_cuda=False, with_stack=True) as prof:
    with profiler.record_function("forward"):
        history, model = main_train(args, debug=True)


model weights at init
layer 0, weight dtype torch.float32
layer 0, bias dtype torch.float32
layer 1, weight dtype torch.float32
layer 1, bias dtype torch.float32
layer 2, weight dtype torch.float32
layer 2, bias dtype torch.float32
layer 3, weight dtype torch.float32
layer 3, bias dtype torch.float32
scorer weight dtype torch.float32
scorer bias dtype torch.float32

model weights after casting
layer 0, weight dtype torch.float16
layer 0, bias dtype torch.float16
layer 1, weight dtype torch.float16
layer 1, bias dtype torch.float16
layer 2, weight dtype torch.float16
layer 2, bias dtype torch.float16
layer 3, weight dtype torch.float16
layer 3, bias dtype torch.float16
scorer weight dtype torch.float16
scorer bias dtype torch.float16

Running batch_0-----------------------------------------

Computed loss
optimizer grad:
 None
master pg grad:
 None
model pg grad:
 None
optimizer weights:
 tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,



tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1015],
        [ 0.2236,  0.0589,  0.0865,  0.3052, -0.1426,  0.1002, -0.1407,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0417, -0.2673, -0.0908,
         -0.2227,  0.0209],
        [-0.0591, -0.1697, -0.0287,  0.2996, -0.0249,  0.0100, -0.0493,  0.0497,
          0.2817,  0.1934],
        [ 0.1122,  0.0687,  0.0746,  0.1222, -0.0409, -0.2939, -0.1956,  0.2700,
          0.0189, -0.2561],
        [ 0.0499,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1481],
        [ 0.2776,  0.1932, -0.2239, -0.2549,  0.1313,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2230],
        [ 0.1467,  0.0116,  0.0622, -0.0299, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.1106],
        [-0.2426,  

 tensor([[-0.2788,  0.1672,  0.0238, -0.5366,  0.0590, -0.2313,  0.2269,  0.3840,
         -0.5186,  0.4875],
        [-0.5679, -0.0563,  0.1699, -0.0338,  0.0695, -0.2069,  0.3701, -0.0789,
         -0.4741,  0.6426],
        [ 0.7925,  0.4333,  0.6313, -0.7314, -1.1006,  0.2135, -1.5195,  0.2527,
          0.7837, -0.6299],
        [-0.6255, -0.6328, -0.1174, -0.6680, -0.0582,  0.4646, -0.3479,  0.0603,
         -0.2394,  1.7266],
        [ 0.5229,  0.6533,  0.4727,  0.0843,  0.7075, -0.6187, -0.2072, -0.4336,
          0.4414,  0.1710],
        [-0.2864, -0.6733, -0.5811, -1.1514, -0.2260,  0.2661, -0.8428, -0.3711,
         -0.8711,  2.4180],
        [ 0.4114,  0.1060, -0.1840,  0.5049, -0.2231,  0.2957,  0.0325, -0.0106,
          0.2186, -0.6450],
        [-0.9302,  0.1155,  0.7334, -0.7227,  0.6450,  0.0662,  0.1432, -0.2998,
          0.6089,  0.9917],
        [-0.1735,  0.1075, -0.4229, -0.0087,  0.1104,  0.0100,  0.3557,  0.8838,
         -0.4785,  0.2861],
        [-0.1775, 

 tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1015],
        [ 0.2236,  0.0589,  0.0865,  0.3052, -0.1426,  0.1002, -0.1407,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0417, -0.2673, -0.0908,
         -0.2227,  0.0209],
        [-0.0591, -0.1697, -0.0287,  0.2996, -0.0249,  0.0100, -0.0493,  0.0497,
          0.2817,  0.1934],
        [ 0.1122,  0.0687,  0.0746,  0.1222, -0.0409, -0.2939, -0.1956,  0.2700,
          0.0189, -0.2561],
        [ 0.0499,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1481],
        [ 0.2776,  0.1932, -0.2239, -0.2549,  0.1313,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2230],
        [ 0.1467,  0.0116,  0.0622, -0.0299, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.1106],
        [-0.2426, 

 tensor([[-5.4455e-04,  3.2663e-04,  4.6521e-05, -1.0481e-03,  1.1522e-04,
         -4.5180e-04,  4.4322e-04,  7.5006e-04, -1.0128e-03,  9.5224e-04],
        [-1.1091e-03, -1.1003e-04,  3.3188e-04, -6.6042e-05,  1.3566e-04,
         -4.0412e-04,  7.2289e-04, -1.5402e-04, -9.2602e-04,  1.2550e-03],
        [ 1.5478e-03,  8.4639e-04,  1.2331e-03, -1.4286e-03, -2.1496e-03,
          4.1699e-04, -2.9678e-03,  4.9353e-04,  1.5306e-03, -1.2302e-03],
        [-1.2217e-03, -1.2360e-03, -2.2936e-04, -1.3046e-03, -1.1373e-04,
          9.0742e-04, -6.7949e-04,  1.1784e-04, -4.6754e-04,  3.3722e-03],
        [ 1.0214e-03,  1.2760e-03,  9.2316e-04,  1.6463e-04,  1.3819e-03,
         -1.2083e-03, -4.0460e-04, -8.4686e-04,  8.6212e-04,  3.3402e-04],
        [-5.5933e-04, -1.3151e-03, -1.1349e-03, -2.2488e-03, -4.4131e-04,
          5.1975e-04, -1.6460e-03, -7.2479e-04, -1.7014e-03,  4.7226e-03],
        [ 8.0347e-04,  2.0695e-04, -3.5930e-04,  9.8610e-04, -4.3583e-04,
          5.7745e-04,  6.3539e-

 tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1015],
        [ 0.2236,  0.0589,  0.0865,  0.3052, -0.1426,  0.1002, -0.1407,  0.2259,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0209],
        [-0.0591, -0.1697, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0497,
          0.2817,  0.1934],
        [ 0.1122,  0.0687,  0.0746,  0.1222, -0.0409, -0.2940, -0.1955,  0.2700,
          0.0189, -0.2562],
        [ 0.0499,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1481],
        [ 0.2776,  0.1932, -0.2239, -0.2549,  0.1313,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2230],
        [ 0.1467,  0.0115,  0.0622, -0.0299, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.1106],
        [-0.2426, 

 tensor([[-5.4455e-04,  3.2663e-04,  4.6521e-05, -1.0481e-03,  1.1522e-04,
         -4.5180e-04,  4.4322e-04,  7.5006e-04, -1.0128e-03,  9.5224e-04],
        [-1.1091e-03, -1.1003e-04,  3.3188e-04, -6.6042e-05,  1.3566e-04,
         -4.0412e-04,  7.2289e-04, -1.5402e-04, -9.2602e-04,  1.2550e-03],
        [ 1.5478e-03,  8.4639e-04,  1.2331e-03, -1.4286e-03, -2.1496e-03,
          4.1699e-04, -2.9678e-03,  4.9353e-04,  1.5306e-03, -1.2302e-03],
        [-1.2217e-03, -1.2360e-03, -2.2936e-04, -1.3046e-03, -1.1373e-04,
          9.0742e-04, -6.7949e-04,  1.1784e-04, -4.6754e-04,  3.3722e-03],
        [ 1.0214e-03,  1.2760e-03,  9.2316e-04,  1.6463e-04,  1.3819e-03,
         -1.2083e-03, -4.0460e-04, -8.4686e-04,  8.6212e-04,  3.3402e-04],
        [-5.5933e-04, -1.3151e-03, -1.1349e-03, -2.2488e-03, -4.4131e-04,
          5.1975e-04, -1.6460e-03, -7.2479e-04, -1.7014e-03,  4.7226e-03],
        [ 8.0347e-04,  2.0695e-04, -3.5930e-04,  9.8610e-04, -4.3583e-04,
          5.7745e-04,  6.3539e-

 tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1015],
        [ 0.2236,  0.0589,  0.0865,  0.3052, -0.1426,  0.1002, -0.1407,  0.2259,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0209],
        [-0.0591, -0.1697, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0497,
          0.2817,  0.1934],
        [ 0.1122,  0.0687,  0.0746,  0.1222, -0.0409, -0.2940, -0.1955,  0.2700,
          0.0189, -0.2562],
        [ 0.0499,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1481],
        [ 0.2776,  0.1932, -0.2239, -0.2549,  0.1313,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2230],
        [ 0.1467,  0.0115,  0.0622, -0.0299, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.1106],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1015],
        [ 0.2236,  0.0589,  0.0865,  0.3052, -0.1426,  0.1002, -0.1407,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0417, -0.2673, -0.0908,
         -0.2227,  0.0209],
        [-0.0591, -0.1697, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0497,
          0.2817,  0.1934],
        [ 0.1122,  0.0687,  0.0746,  0.1222, -0.0409, -0.2939, -0.1956,  0.2700,
          0.0189, -0.2561],
        [ 0.0499,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1481],
        [ 0.2776,  0.1932, -0.2239, -0.2549,  0.1313,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2230],
        [ 0.1467,  0.0116,  0.0622, -0.0299, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.110

 tensor([[-2.1303e-04,  9.1839e-04,  4.0650e-04, -4.6682e-04,  6.8140e-04,
         -2.2972e-04,  1.0557e-03,  5.2786e-04, -1.2436e-03,  5.8699e-04],
        [-6.6710e-04,  1.4579e-04,  4.7064e-04,  4.1151e-04,  1.7703e-04,
         -4.2105e-04,  2.2340e-04, -3.7265e-04, -1.9908e-04,  1.1110e-03],
        [ 1.9474e-03, -7.5459e-05,  5.2547e-04, -2.4147e-03, -2.3327e-03,
          4.9639e-04, -1.9426e-03,  1.8253e-03,  1.6394e-03, -1.9274e-03],
        [-5.3406e-04, -4.9400e-04,  8.1968e-04, -9.0885e-04,  4.5943e-04,
          4.0442e-05, -9.9957e-05, -7.5531e-04,  3.3975e-04,  1.7214e-03],
        [ 1.2102e-03,  9.0170e-04,  8.9359e-04,  6.8069e-05,  1.0252e-03,
         -1.5965e-03, -6.6853e-04, -6.1750e-04,  1.2846e-03, -2.7680e-04],
        [ 3.5596e-04, -2.0599e-03, -2.1255e-04, -2.7466e-03, -9.0027e-04,
         -4.5395e-04, -1.0900e-03,  5.6458e-04, -3.4847e-03,  4.3716e-03],
        [ 9.3365e-04,  1.7452e-04,  2.9802e-04,  8.2541e-04, -6.4278e-04,
          9.2268e-04, -4.8256e-

 Parameter containing:
tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1015],
        [ 0.2236,  0.0589,  0.0865,  0.3052, -0.1426,  0.1002, -0.1407,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0417, -0.2673, -0.0908,
         -0.2227,  0.0209],
        [-0.0591, -0.1697, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0497,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0409, -0.2939, -0.1956,  0.2700,
          0.0190, -0.2561],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1481],
        [ 0.2776,  0.1932, -0.2239, -0.2549,  0.1313,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2230],
        [ 0.1467,  0.0115,  0.0622, -0.0299, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.110

 tensor([[-2.1303e-04,  9.1839e-04,  4.0650e-04, -4.6682e-04,  6.8140e-04,
         -2.2972e-04,  1.0557e-03,  5.2786e-04, -1.2436e-03,  5.8699e-04],
        [-6.6710e-04,  1.4579e-04,  4.7064e-04,  4.1151e-04,  1.7703e-04,
         -4.2105e-04,  2.2340e-04, -3.7265e-04, -1.9908e-04,  1.1110e-03],
        [ 1.9474e-03, -7.5459e-05,  5.2547e-04, -2.4147e-03, -2.3327e-03,
          4.9639e-04, -1.9426e-03,  1.8253e-03,  1.6394e-03, -1.9274e-03],
        [-5.3406e-04, -4.9400e-04,  8.1968e-04, -9.0885e-04,  4.5943e-04,
          4.0442e-05, -9.9957e-05, -7.5531e-04,  3.3975e-04,  1.7214e-03],
        [ 1.2102e-03,  9.0170e-04,  8.9359e-04,  6.8069e-05,  1.0252e-03,
         -1.5965e-03, -6.6853e-04, -6.1750e-04,  1.2846e-03, -2.7680e-04],
        [ 3.5596e-04, -2.0599e-03, -2.1255e-04, -2.7466e-03, -9.0027e-04,
         -4.5395e-04, -1.0900e-03,  5.6458e-04, -3.4847e-03,  4.3716e-03],
        [ 9.3365e-04,  1.7452e-04,  2.9802e-04,  8.2541e-04, -6.4278e-04,
          9.2268e-04, -4.8256e-

 tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1015],
        [ 0.2236,  0.0589,  0.0865,  0.3052, -0.1425,  0.1001, -0.1407,  0.2259,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0209],
        [-0.0591, -0.1697, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0497,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1222, -0.0408, -0.2939, -0.1955,  0.2700,
          0.0190, -0.2562],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1481],
        [ 0.2776,  0.1932, -0.2239, -0.2549,  0.1313,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2230],
        [ 0.1467,  0.0115,  0.0622, -0.0299, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.1106],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1015],
        [ 0.2236,  0.0589,  0.0865,  0.3052, -0.1426,  0.1002, -0.1407,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0417, -0.2673, -0.0908,
         -0.2227,  0.0209],
        [-0.0591, -0.1697, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0497,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0409, -0.2939, -0.1956,  0.2700,
          0.0190, -0.2561],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1481],
        [ 0.2776,  0.1932, -0.2239, -0.2549,  0.1313,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2230],
        [ 0.1467,  0.0115,  0.0622, -0.0299, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.110

 tensor([[ 5.6922e-05,  7.9250e-04,  8.0967e-04, -5.7268e-04,  1.6987e-04,
          1.8466e-04,  4.0007e-04,  3.8433e-04, -7.1621e-04,  3.7503e-04],
        [-5.3835e-04,  5.4121e-04,  1.0920e-04,  3.5048e-05,  2.0993e-04,
         -5.5218e-04,  2.9945e-04,  2.3079e-04, -4.2439e-04,  7.2002e-04],
        [ 1.5182e-03,  2.4691e-05,  3.9697e-05, -2.2850e-03, -1.1854e-03,
          1.0915e-05, -2.3594e-03, -2.2864e-04,  1.5326e-03, -9.4318e-04],
        [-1.4439e-03, -1.3933e-03, -4.4560e-04, -2.7490e-04, -8.3399e-04,
          1.4429e-03, -7.7534e-04,  3.1412e-05, -6.6423e-04,  2.6188e-03],
        [ 1.4849e-03,  1.2217e-03,  6.4278e-04, -4.3201e-04,  1.4515e-03,
         -1.4391e-03, -3.5262e-04, -7.1716e-04,  1.1005e-03,  9.8586e-05],
        [ 8.1241e-05, -1.7662e-03, -6.0749e-04, -9.0694e-04, -1.3704e-03,
          1.3180e-03, -1.4448e-03, -1.8728e-04, -2.8820e-03,  3.4618e-03],
        [ 7.3862e-04, -2.8849e-04, -9.7609e-04,  6.4707e-04, -5.7936e-04,
         -1.0812e-04,  1.0091e-

 Parameter containing:
tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2236,  0.0589,  0.0865,  0.3052, -0.1426,  0.1002, -0.1406,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0208],
        [-0.0591, -0.1697, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0497,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2939, -0.1956,  0.2700,
          0.0190, -0.2563],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1481],
        [ 0.2776,  0.1932, -0.2239, -0.2549,  0.1313,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2230],
        [ 0.1467,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.110

 tensor([[ 5.6922e-05,  7.9250e-04,  8.0967e-04, -5.7268e-04,  1.6987e-04,
          1.8466e-04,  4.0007e-04,  3.8433e-04, -7.1621e-04,  3.7503e-04],
        [-5.3835e-04,  5.4121e-04,  1.0920e-04,  3.5048e-05,  2.0993e-04,
         -5.5218e-04,  2.9945e-04,  2.3079e-04, -4.2439e-04,  7.2002e-04],
        [ 1.5182e-03,  2.4691e-05,  3.9697e-05, -2.2850e-03, -1.1854e-03,
          1.0915e-05, -2.3594e-03, -2.2864e-04,  1.5326e-03, -9.4318e-04],
        [-1.4439e-03, -1.3933e-03, -4.4560e-04, -2.7490e-04, -8.3399e-04,
          1.4429e-03, -7.7534e-04,  3.1412e-05, -6.6423e-04,  2.6188e-03],
        [ 1.4849e-03,  1.2217e-03,  6.4278e-04, -4.3201e-04,  1.4515e-03,
         -1.4391e-03, -3.5262e-04, -7.1716e-04,  1.1005e-03,  9.8586e-05],
        [ 8.1241e-05, -1.7662e-03, -6.0749e-04, -9.0694e-04, -1.3704e-03,
          1.3180e-03, -1.4448e-03, -1.8728e-04, -2.8820e-03,  3.4618e-03],
        [ 7.3862e-04, -2.8849e-04, -9.7609e-04,  6.4707e-04, -5.7936e-04,
         -1.0812e-04,  1.0091e-

 tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1015],
        [ 0.2236,  0.0589,  0.0865,  0.3052, -0.1425,  0.1001, -0.1407,  0.2259,
          0.2524, -0.2915],
        [ 0.2701,  0.1510,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2226,  0.0208],
        [-0.0591, -0.1697, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0497,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2940, -0.1955,  0.2700,
          0.0190, -0.2562],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1481],
        [ 0.2776,  0.1932, -0.2239, -0.2549,  0.1313,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2230],
        [ 0.1468,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.1106],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2236,  0.0589,  0.0865,  0.3052, -0.1426,  0.1002, -0.1406,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0208],
        [-0.0591, -0.1697, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0497,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2939, -0.1956,  0.2700,
          0.0190, -0.2563],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1481],
        [ 0.2776,  0.1932, -0.2239, -0.2549,  0.1313,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2230],
        [ 0.1467,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.110

 tensor([[-2.2817e-04,  3.6478e-04,  5.7173e-04, -4.7421e-04,  4.5586e-04,
          3.0804e-04,  1.8358e-04,  5.7173e-04, -4.8208e-04,  3.9268e-04],
        [-1.0662e-03,  1.9240e-04,  1.6499e-04,  5.1546e-04,  3.1662e-04,
         -3.5930e-04,  5.8556e-04, -7.1645e-05, -3.7980e-04,  1.2350e-03],
        [ 2.0618e-03,  2.5535e-04,  8.5592e-04, -2.2163e-03, -2.1572e-03,
         -6.1893e-04, -2.4014e-03,  9.3222e-05,  1.3514e-03, -1.9455e-03],
        [-2.0719e-04, -8.9931e-04,  7.5245e-04, -6.4087e-04,  9.4414e-04,
          6.4516e-04, -5.2357e-04,  3.0947e-04,  2.3043e-04,  1.9417e-03],
        [ 1.3027e-03,  1.6766e-03,  2.5296e-04, -1.9222e-05,  1.5783e-03,
         -1.7605e-03,  3.0234e-05, -1.4544e-03,  1.1148e-03, -5.5599e-04],
        [ 2.4962e-04, -2.0218e-03,  1.1086e-04, -1.0748e-03,  3.5620e-04,
          1.0719e-03, -1.7872e-03, -1.1587e-03, -2.3270e-03,  3.9902e-03],
        [ 6.1226e-04,  3.4499e-04, -6.1798e-04,  6.9571e-04, -1.6155e-03,
         -2.8086e-04,  4.0293e-

 Parameter containing:
tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0865,  0.3052, -0.1425,  0.1002, -0.1406,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0208],
        [-0.0591, -0.1697, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0497,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2939, -0.1956,  0.2700,
          0.0190, -0.2563],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1481],
        [ 0.2776,  0.1932, -0.2239, -0.2549,  0.1313,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2230],
        [ 0.1467,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.110

 tensor([[-2.2817e-04,  3.6478e-04,  5.7173e-04, -4.7421e-04,  4.5586e-04,
          3.0804e-04,  1.8358e-04,  5.7173e-04, -4.8208e-04,  3.9268e-04],
        [-1.0662e-03,  1.9240e-04,  1.6499e-04,  5.1546e-04,  3.1662e-04,
         -3.5930e-04,  5.8556e-04, -7.1645e-05, -3.7980e-04,  1.2350e-03],
        [ 2.0618e-03,  2.5535e-04,  8.5592e-04, -2.2163e-03, -2.1572e-03,
         -6.1893e-04, -2.4014e-03,  9.3222e-05,  1.3514e-03, -1.9455e-03],
        [-2.0719e-04, -8.9931e-04,  7.5245e-04, -6.4087e-04,  9.4414e-04,
          6.4516e-04, -5.2357e-04,  3.0947e-04,  2.3043e-04,  1.9417e-03],
        [ 1.3027e-03,  1.6766e-03,  2.5296e-04, -1.9222e-05,  1.5783e-03,
         -1.7605e-03,  3.0234e-05, -1.4544e-03,  1.1148e-03, -5.5599e-04],
        [ 2.4962e-04, -2.0218e-03,  1.1086e-04, -1.0748e-03,  3.5620e-04,
          1.0719e-03, -1.7872e-03, -1.1587e-03, -2.3270e-03,  3.9902e-03],
        [ 6.1226e-04,  3.4499e-04, -6.1798e-04,  6.9571e-04, -1.6155e-03,
         -2.8086e-04,  4.0293e-

 tensor([[ 1.1469e-01,  1.8591e-01,  5.8105e-01, -2.3596e-01,  1.6858e-01,
          3.2153e-01,  2.4170e-01,  6.0120e-02, -2.3914e-01,  3.4912e-01],
        [-3.1299e-01, -9.5337e-02,  1.4709e-01, -1.7737e-01,  3.2031e-01,
         -5.3467e-01, -1.8492e-03,  8.5831e-03, -1.4465e-01,  7.4023e-01],
        [ 9.2578e-01,  4.0918e-01, -4.4165e-01, -1.1875e+00, -4.6094e-01,
         -1.3660e-01, -1.3545e+00, -9.7351e-02,  6.1768e-01, -9.4287e-01],
        [-5.7373e-01, -4.3823e-01,  2.2656e-01, -6.6553e-01, -4.0680e-02,
          3.1445e-01, -5.8984e-01, -2.0459e-01, -1.7371e-01,  1.8330e+00],
        [ 6.8604e-01,  9.8096e-01,  4.1968e-01, -7.3792e-02,  1.0654e+00,
         -9.1016e-01, -4.9780e-01, -6.0596e-01,  5.8398e-01, -3.9453e-01],
        [ 3.1204e-03, -9.4531e-01,  1.9379e-02, -1.0586e+00, -1.3159e-01,
          4.6362e-01, -9.9902e-01, -9.0039e-01, -9.1211e-01,  2.3691e+00],
        [ 1.2543e-02,  1.5497e-03, -4.1992e-01,  2.3950e-01, -6.1133e-01,
         -4.3243e-02,  1.2280e-

 tensor([[-0.1289,  0.0105, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1015],
        [ 0.2236,  0.0589,  0.0865,  0.3053, -0.1425,  0.1002, -0.1407,  0.2259,
          0.2524, -0.2914],
        [ 0.2701,  0.1510,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0208],
        [-0.0591, -0.1697, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0497,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2940, -0.1955,  0.2700,
          0.0190, -0.2563],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1481],
        [ 0.2776,  0.1932, -0.2239, -0.2548,  0.1313,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2230],
        [ 0.1468,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2626,
          0.1741,  0.1106],
        [-0.2426, 

 tensor([[ 2.2399e-04,  3.6311e-04,  1.1349e-03, -4.6086e-04,  3.2926e-04,
          6.2799e-04,  4.7207e-04,  1.1742e-04, -4.6706e-04,  6.8188e-04],
        [-6.1131e-04, -1.8620e-04,  2.8729e-04, -3.4642e-04,  6.2561e-04,
         -1.0443e-03, -3.6117e-06,  1.6764e-05, -2.8253e-04,  1.4458e-03],
        [ 1.8082e-03,  7.9918e-04, -8.6260e-04, -2.3193e-03, -9.0027e-04,
         -2.6679e-04, -2.6455e-03, -1.9014e-04,  1.2064e-03, -1.8415e-03],
        [-1.1206e-03, -8.5592e-04,  4.4250e-04, -1.2999e-03, -7.9453e-05,
          6.1417e-04, -1.1520e-03, -3.9959e-04, -3.3927e-04,  3.5801e-03],
        [ 1.3399e-03,  1.9159e-03,  8.1968e-04, -1.4412e-04,  2.0809e-03,
         -1.7776e-03, -9.7227e-04, -1.1835e-03,  1.1406e-03, -7.7057e-04],
        [ 6.0946e-06, -1.8463e-03,  3.7849e-05, -2.0676e-03, -2.5702e-04,
          9.0551e-04, -1.9512e-03, -1.7586e-03, -1.7815e-03,  4.6272e-03],
        [ 2.4498e-05,  3.0268e-06, -8.2016e-04,  4.6778e-04, -1.1940e-03,
         -8.4460e-05,  2.3985e-

 tensor([[-0.1289,  0.0104, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0865,  0.3053, -0.1425,  0.1002, -0.1406,  0.2259,
          0.2524, -0.2914],
        [ 0.2701,  0.1510,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2226,  0.0208],
        [-0.0591, -0.1697, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2940, -0.1955,  0.2701,
          0.0190, -0.2563],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1481],
        [ 0.2777,  0.1932, -0.2239, -0.2548,  0.1313,  0.0071,  0.1296, -0.3090,
         -0.0189,  0.2230],
        [ 0.1468,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2626,
          0.1741,  0.1106],
        [-0.2426, 

 tensor([[ 2.2399e-04,  3.6311e-04,  1.1349e-03, -4.6086e-04,  3.2926e-04,
          6.2799e-04,  4.7207e-04,  1.1742e-04, -4.6706e-04,  6.8188e-04],
        [-6.1131e-04, -1.8620e-04,  2.8729e-04, -3.4642e-04,  6.2561e-04,
         -1.0443e-03, -3.6117e-06,  1.6764e-05, -2.8253e-04,  1.4458e-03],
        [ 1.8082e-03,  7.9918e-04, -8.6260e-04, -2.3193e-03, -9.0027e-04,
         -2.6679e-04, -2.6455e-03, -1.9014e-04,  1.2064e-03, -1.8415e-03],
        [-1.1206e-03, -8.5592e-04,  4.4250e-04, -1.2999e-03, -7.9453e-05,
          6.1417e-04, -1.1520e-03, -3.9959e-04, -3.3927e-04,  3.5801e-03],
        [ 1.3399e-03,  1.9159e-03,  8.1968e-04, -1.4412e-04,  2.0809e-03,
         -1.7776e-03, -9.7227e-04, -1.1835e-03,  1.1406e-03, -7.7057e-04],
        [ 6.0946e-06, -1.8463e-03,  3.7849e-05, -2.0676e-03, -2.5702e-04,
          9.0551e-04, -1.9512e-03, -1.7586e-03, -1.7815e-03,  4.6272e-03],
        [ 2.4498e-05,  3.0268e-06, -8.2016e-04,  4.6778e-04, -1.1940e-03,
         -8.4460e-05,  2.3985e-

 tensor([[-0.1289,  0.0104, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0865,  0.3053, -0.1425,  0.1002, -0.1406,  0.2259,
          0.2524, -0.2914],
        [ 0.2701,  0.1510,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2226,  0.0208],
        [-0.0591, -0.1697, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2940, -0.1955,  0.2701,
          0.0190, -0.2563],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1481],
        [ 0.2777,  0.1932, -0.2239, -0.2548,  0.1313,  0.0071,  0.1296, -0.3090,
         -0.0189,  0.2230],
        [ 0.1468,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2626,
          0.1741,  0.1106],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0865,  0.3052, -0.1425,  0.1002, -0.1406,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0208],
        [-0.0591, -0.1698, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0190, -0.2563],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1481],
        [ 0.2776,  0.1932, -0.2240, -0.2549,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2229],
        [ 0.1467,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.110

 tensor([[ 2.2873e-05,  1.9205e-04,  6.2227e-04, -1.0872e-03,  5.4026e-04,
          1.0610e-04,  3.4499e-04,  3.1543e-04, -9.7418e-04,  8.1205e-04],
        [-5.3257e-05, -5.2738e-04,  2.9945e-04, -2.3568e-04, -2.0766e-04,
         -7.2479e-04,  2.0254e-04,  2.5272e-04, -3.3069e-04,  1.0939e-03],
        [ 1.6708e-03, -4.7946e-04,  9.5844e-04, -1.8940e-03, -2.1210e-03,
          6.4564e-04, -2.1210e-03, -1.4019e-04,  1.3905e-03, -1.6375e-03],
        [-1.5106e-03, -9.7561e-04, -1.1139e-03, -6.9475e-04, -4.5395e-04,
          5.6458e-04, -9.0742e-04, -2.6989e-04,  1.1317e-05,  2.6188e-03],
        [ 2.0409e-03,  1.5469e-03, -4.2096e-06, -5.9032e-04,  1.6384e-03,
         -1.5001e-03, -4.9353e-04, -1.3828e-03,  1.5087e-03, -1.1339e-03],
        [ 3.2973e-04, -1.8263e-03, -2.0695e-03, -2.3060e-03, -3.3307e-04,
          1.3790e-03, -1.7729e-03, -9.4223e-04, -2.2945e-03,  4.7340e-03],
        [ 2.8181e-04,  1.3626e-04, -4.1890e-04,  8.9264e-04, -9.4366e-04,
         -2.6846e-04, -8.3521e-

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0865,  0.3054, -0.1425,  0.1002, -0.1406,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0208],
        [-0.0591, -0.1698, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0190, -0.2563],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1481],
        [ 0.2776,  0.1932, -0.2240, -0.2549,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.110

 tensor([[ 2.2873e-05,  1.9205e-04,  6.2227e-04, -1.0872e-03,  5.4026e-04,
          1.0610e-04,  3.4499e-04,  3.1543e-04, -9.7418e-04,  8.1205e-04],
        [-5.3257e-05, -5.2738e-04,  2.9945e-04, -2.3568e-04, -2.0766e-04,
         -7.2479e-04,  2.0254e-04,  2.5272e-04, -3.3069e-04,  1.0939e-03],
        [ 1.6708e-03, -4.7946e-04,  9.5844e-04, -1.8940e-03, -2.1210e-03,
          6.4564e-04, -2.1210e-03, -1.4019e-04,  1.3905e-03, -1.6375e-03],
        [-1.5106e-03, -9.7561e-04, -1.1139e-03, -6.9475e-04, -4.5395e-04,
          5.6458e-04, -9.0742e-04, -2.6989e-04,  1.1317e-05,  2.6188e-03],
        [ 2.0409e-03,  1.5469e-03, -4.2096e-06, -5.9032e-04,  1.6384e-03,
         -1.5001e-03, -4.9353e-04, -1.3828e-03,  1.5087e-03, -1.1339e-03],
        [ 3.2973e-04, -1.8263e-03, -2.0695e-03, -2.3060e-03, -3.3307e-04,
          1.3790e-03, -1.7729e-03, -9.4223e-04, -2.2945e-03,  4.7340e-03],
        [ 2.8181e-04,  1.3626e-04, -4.1890e-04,  8.9264e-04, -9.4366e-04,
         -2.6846e-04, -8.3521e-

 tensor([[-0.1289,  0.0104, -0.1571,  0.1193, -0.2696,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0865,  0.3053, -0.1425,  0.1002, -0.1406,  0.2259,
          0.2524, -0.2914],
        [ 0.2701,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0907,
         -0.2226,  0.0207],
        [-0.0591, -0.1698, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2940, -0.1955,  0.2701,
          0.0191, -0.2564],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1481],
        [ 0.2777,  0.1932, -0.2240, -0.2548,  0.1313,  0.0071,  0.1296, -0.3090,
         -0.0189,  0.2229],
        [ 0.1468,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2626,
          0.1741,  0.1105],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0865,  0.3054, -0.1425,  0.1002, -0.1406,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1510,  0.1378,  0.1301,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0208],
        [-0.0591, -0.1698, -0.0288,  0.2996, -0.0250,  0.0101, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0190, -0.2563],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1260,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1481],
        [ 0.2776,  0.1932, -0.2240, -0.2549,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.110

 tensor([[-2.9278e-04,  3.6097e-04,  8.4114e-04, -4.7350e-04,  2.4557e-04,
          5.5075e-04,  7.8154e-04,  8.4448e-04, -7.5769e-04,  5.4741e-04],
        [-7.6008e-04,  7.7772e-04,  3.2091e-04,  5.1308e-04,  1.0617e-05,
         -4.3774e-04,  4.0859e-05, -2.4819e-04, -1.4627e-04,  5.9795e-04],
        [ 1.9064e-03, -7.7629e-04,  6.9523e-04, -1.1768e-03, -1.7672e-03,
          6.3324e-04, -2.0409e-03,  2.8968e-04,  1.3437e-03, -1.6289e-03],
        [-7.3624e-04, -3.5882e-04,  2.2388e-04, -1.1549e-03, -6.4135e-04,
          7.5293e-04, -6.9439e-05, -3.1233e-04,  3.6526e-04,  2.4681e-03],
        [ 8.8072e-04,  1.7719e-03, -7.0190e-04, -9.2697e-04,  1.9646e-03,
         -2.1420e-03, -6.5684e-05, -7.8726e-04, -1.7285e-05, -3.6931e-04],
        [-3.8886e-04, -8.1825e-04, -2.6536e-04, -1.7643e-03, -1.1997e-03,
          7.9393e-04, -4.3774e-04,  1.4472e-04, -1.2560e-03,  4.2534e-03],
        [ 8.8072e-04,  8.8632e-05, -4.3249e-04,  1.1387e-03, -1.0653e-03,
         -6.0320e-04, -5.3263e-

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0864,  0.3054, -0.1425,  0.1002, -0.1406,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0207],
        [-0.0591, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0191, -0.2563],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1481],
        [ 0.2776,  0.1932, -0.2240, -0.2549,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.110

 tensor([[-2.9278e-04,  3.6097e-04,  8.4114e-04, -4.7350e-04,  2.4557e-04,
          5.5075e-04,  7.8154e-04,  8.4448e-04, -7.5769e-04,  5.4741e-04],
        [-7.6008e-04,  7.7772e-04,  3.2091e-04,  5.1308e-04,  1.0617e-05,
         -4.3774e-04,  4.0859e-05, -2.4819e-04, -1.4627e-04,  5.9795e-04],
        [ 1.9064e-03, -7.7629e-04,  6.9523e-04, -1.1768e-03, -1.7672e-03,
          6.3324e-04, -2.0409e-03,  2.8968e-04,  1.3437e-03, -1.6289e-03],
        [-7.3624e-04, -3.5882e-04,  2.2388e-04, -1.1549e-03, -6.4135e-04,
          7.5293e-04, -6.9439e-05, -3.1233e-04,  3.6526e-04,  2.4681e-03],
        [ 8.8072e-04,  1.7719e-03, -7.0190e-04, -9.2697e-04,  1.9646e-03,
         -2.1420e-03, -6.5684e-05, -7.8726e-04, -1.7285e-05, -3.6931e-04],
        [-3.8886e-04, -8.1825e-04, -2.6536e-04, -1.7643e-03, -1.1997e-03,
          7.9393e-04, -4.3774e-04,  1.4472e-04, -1.2560e-03,  4.2534e-03],
        [ 8.8072e-04,  8.8632e-05, -4.3249e-04,  1.1387e-03, -1.0653e-03,
         -6.0320e-04, -5.3263e-

 tensor([[-0.1289,  0.0104, -0.1571,  0.1193, -0.2696,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0865,  0.3053, -0.1425,  0.1001, -0.1406,  0.2259,
          0.2523, -0.2914],
        [ 0.2701,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0907,
         -0.2227,  0.0207],
        [-0.0591, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2940, -0.1955,  0.2701,
          0.0191, -0.2564],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1481],
        [ 0.2777,  0.1932, -0.2240, -0.2548,  0.1312,  0.0071,  0.1296, -0.3090,
         -0.0189,  0.2229],
        [ 0.1468,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2626,
          0.1741,  0.1105],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0864,  0.3054, -0.1425,  0.1002, -0.1406,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0207],
        [-0.0591, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1223, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0191, -0.2563],
        [ 0.0498,  0.2612, -0.2988, -0.2129, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1481],
        [ 0.2776,  0.1932, -0.2240, -0.2549,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0189,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1741,  0.110

 tensor([[-8.2350e-04,  5.1165e-04,  6.6280e-04, -1.3602e-04,  5.5552e-04,
         -4.3082e-04,  4.0317e-04,  5.7077e-04, -8.1253e-04,  3.1447e-04],
        [ 1.8045e-05,  3.8457e-04,  3.9506e-04, -3.0899e-04,  1.7881e-04,
         -3.2544e-04,  5.6982e-05, -3.2216e-05, -2.6941e-04,  3.8433e-04],
        [ 1.3866e-03,  8.3268e-05,  5.6171e-04, -1.1635e-03, -2.1458e-03,
          7.2479e-04, -2.1648e-03,  9.1934e-04,  1.3876e-03, -6.6423e-04],
        [-1.2264e-03, -9.8801e-04, -5.6177e-05, -6.6519e-04,  3.4690e-04,
          3.0041e-04, -6.7997e-04, -6.2704e-04, -4.7773e-05,  2.3365e-03],
        [ 1.1501e-03,  1.5535e-03,  5.3835e-04, -7.2002e-04,  2.0580e-03,
         -1.0586e-03, -5.5981e-04, -5.4932e-04,  4.7255e-04,  3.3587e-05],
        [ 9.0957e-05, -1.7653e-03, -1.0548e-03, -2.3136e-03, -8.7214e-04,
         -8.2612e-05, -1.1759e-03, -9.3079e-04, -2.5063e-03,  2.8687e-03],
        [ 1.3971e-03,  7.1704e-05, -5.2404e-04,  5.7125e-04, -7.8583e-04,
         -8.0168e-05, -1.3006e-

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0864,  0.3054, -0.1425,  0.1002, -0.1405,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0207],
        [-0.0591, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1224, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0191, -0.2563],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2776,  0.1932, -0.2240, -0.2549,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0190,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1742,  0.110

 tensor([[-8.2350e-04,  5.1165e-04,  6.6280e-04, -1.3602e-04,  5.5552e-04,
         -4.3082e-04,  4.0317e-04,  5.7077e-04, -8.1253e-04,  3.1447e-04],
        [ 1.8045e-05,  3.8457e-04,  3.9506e-04, -3.0899e-04,  1.7881e-04,
         -3.2544e-04,  5.6982e-05, -3.2216e-05, -2.6941e-04,  3.8433e-04],
        [ 1.3866e-03,  8.3268e-05,  5.6171e-04, -1.1635e-03, -2.1458e-03,
          7.2479e-04, -2.1648e-03,  9.1934e-04,  1.3876e-03, -6.6423e-04],
        [-1.2264e-03, -9.8801e-04, -5.6177e-05, -6.6519e-04,  3.4690e-04,
          3.0041e-04, -6.7997e-04, -6.2704e-04, -4.7773e-05,  2.3365e-03],
        [ 1.1501e-03,  1.5535e-03,  5.3835e-04, -7.2002e-04,  2.0580e-03,
         -1.0586e-03, -5.5981e-04, -5.4932e-04,  4.7255e-04,  3.3587e-05],
        [ 9.0957e-05, -1.7653e-03, -1.0548e-03, -2.3136e-03, -8.7214e-04,
         -8.2612e-05, -1.1759e-03, -9.3079e-04, -2.5063e-03,  2.8687e-03],
        [ 1.3971e-03,  7.1704e-05, -5.2404e-04,  5.7125e-04, -7.8583e-04,
         -8.0168e-05, -1.3006e-

 tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2696,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0864,  0.3053, -0.1424,  0.1001, -0.1406,  0.2259,
          0.2523, -0.2914],
        [ 0.2701,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0907,
         -0.2227,  0.0207],
        [-0.0592, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0689,  0.0746,  0.1223, -0.0408, -0.2940, -0.1954,  0.2701,
          0.0191, -0.2564],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1481],
        [ 0.2777,  0.1932, -0.2240, -0.2548,  0.1312,  0.0071,  0.1296, -0.3090,
         -0.0189,  0.2229],
        [ 0.1468,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2626,
          0.1741,  0.1105],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0864,  0.3054, -0.1425,  0.1002, -0.1405,  0.2260,
          0.2524, -0.2915],
        [ 0.2700,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0207],
        [-0.0591, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1224, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0191, -0.2563],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2776,  0.1932, -0.2240, -0.2549,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0190,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1742,  0.110

 tensor([[-3.4404e-04,  3.9554e-04,  4.8518e-04, -8.5497e-04,  4.5681e-04,
          3.7968e-05,  7.0953e-04,  9.4604e-04, -1.0691e-03,  8.0776e-04],
        [-5.8699e-04,  8.3065e-04,  3.5548e-04, -2.3097e-05, -9.6798e-05,
         -4.1437e-04,  2.9087e-04,  1.6057e-04,  2.6941e-05,  9.9373e-04],
        [ 1.5783e-03, -5.1880e-04,  6.6376e-04, -1.5650e-03, -1.5764e-03,
         -8.9407e-05, -2.7714e-03,  1.0830e-04,  8.3685e-04, -1.6308e-03],
        [-1.7748e-03, -1.7703e-04,  1.8787e-04, -6.5804e-04, -3.8171e-04,
          1.4896e-03,  2.7633e-04,  2.1541e-04, -2.9159e-04,  2.3403e-03],
        [ 1.2541e-03,  1.6804e-03,  8.5449e-04, -3.0446e-04,  1.3781e-03,
         -3.3951e-04,  5.7650e-04, -4.2200e-04,  1.9627e-03,  3.4547e-04],
        [ 2.9421e-04, -1.0395e-03, -9.9599e-05, -2.2259e-03, -1.0920e-03,
          6.4802e-04, -7.8106e-04, -2.0397e-04, -2.3003e-03,  3.9520e-03],
        [ 1.5326e-03, -3.4928e-04, -9.3746e-04,  1.1187e-03, -7.1526e-04,
         -4.0150e-04, -5.0449e-

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0864,  0.3054, -0.1425,  0.1002, -0.1405,  0.2260,
          0.2522, -0.2913],
        [ 0.2700,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0207],
        [-0.0592, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1224, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0191, -0.2566],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2240, -0.2549,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0190,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1742,  0.110

 tensor([[-3.4404e-04,  3.9554e-04,  4.8518e-04, -8.5497e-04,  4.5681e-04,
          3.7968e-05,  7.0953e-04,  9.4604e-04, -1.0691e-03,  8.0776e-04],
        [-5.8699e-04,  8.3065e-04,  3.5548e-04, -2.3097e-05, -9.6798e-05,
         -4.1437e-04,  2.9087e-04,  1.6057e-04,  2.6941e-05,  9.9373e-04],
        [ 1.5783e-03, -5.1880e-04,  6.6376e-04, -1.5650e-03, -1.5764e-03,
         -8.9407e-05, -2.7714e-03,  1.0830e-04,  8.3685e-04, -1.6308e-03],
        [-1.7748e-03, -1.7703e-04,  1.8787e-04, -6.5804e-04, -3.8171e-04,
          1.4896e-03,  2.7633e-04,  2.1541e-04, -2.9159e-04,  2.3403e-03],
        [ 1.2541e-03,  1.6804e-03,  8.5449e-04, -3.0446e-04,  1.3781e-03,
         -3.3951e-04,  5.7650e-04, -4.2200e-04,  1.9627e-03,  3.4547e-04],
        [ 2.9421e-04, -1.0395e-03, -9.9599e-05, -2.2259e-03, -1.0920e-03,
          6.4802e-04, -7.8106e-04, -2.0397e-04, -2.3003e-03,  3.9520e-03],
        [ 1.5326e-03, -3.4928e-04, -9.3746e-04,  1.1187e-03, -7.1526e-04,
         -4.0150e-04, -5.0449e-

 tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2696,  0.2318, -0.2299, -0.2515,
         -0.1998,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0864,  0.3053, -0.1424,  0.1001, -0.1405,  0.2259,
          0.2523, -0.2914],
        [ 0.2701,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0907,
         -0.2226,  0.0207],
        [-0.0592, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2816,  0.1934],
        [ 0.1122,  0.0689,  0.0746,  0.1224, -0.0408, -0.2940, -0.1954,  0.2701,
          0.0191, -0.2565],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1481],
        [ 0.2777,  0.1932, -0.2240, -0.2548,  0.1312,  0.0071,  0.1296, -0.3090,
         -0.0190,  0.2229],
        [ 0.1468,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2626,
          0.1742,  0.1105],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1571,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0864,  0.3054, -0.1425,  0.1002, -0.1405,  0.2260,
          0.2522, -0.2913],
        [ 0.2700,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0207],
        [-0.0592, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0688,  0.0746,  0.1224, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0191, -0.2566],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2240, -0.2549,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0190,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2627,
          0.1742,  0.110

 tensor([[ 1.2733e-05,  2.3413e-04,  1.0386e-03, -4.8018e-04, -4.6790e-05,
          1.8418e-04,  3.2187e-04,  2.5558e-04, -9.0504e-04,  3.3164e-04],
        [-3.9649e-04,  4.9305e-04,  2.8896e-04, -2.9445e-04,  3.5262e-04,
         -2.7347e-04,  1.7226e-04, -4.7994e-04, -2.3162e-04,  7.7057e-04],
        [ 2.1057e-03,  2.9421e-04,  1.1700e-04, -1.8120e-03, -1.5440e-03,
          1.0478e-04, -1.8053e-03,  9.7942e-04,  1.0366e-03, -1.0424e-03],
        [-1.7385e-03, -5.5981e-04, -5.2738e-04, -2.1744e-03,  1.5187e-04,
          7.4673e-04, -8.7595e-04, -6.3372e-04, -1.9693e-04,  2.2907e-03],
        [ 1.3361e-03,  1.7509e-03,  3.0088e-04, -6.8712e-04,  1.2779e-03,
         -1.4935e-03, -7.8154e-04, -1.4620e-03,  9.9087e-04, -7.2670e-04],
        [-7.1096e-04, -1.7176e-03, -1.0519e-03, -3.0155e-03, -5.8842e-04,
          1.5736e-03, -1.7080e-03, -1.3599e-03, -2.0733e-03,  3.6716e-03],
        [ 1.0843e-03,  2.0131e-05, -3.0684e-04,  1.6570e-04, -9.0170e-04,
          9.3222e-05,  1.4162e-

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2235,  0.0589,  0.0864,  0.3054, -0.1425,  0.1002, -0.1405,  0.2260,
          0.2522, -0.2913],
        [ 0.2700,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0206],
        [-0.0592, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0689,  0.0746,  0.1224, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0191, -0.2566],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2240, -0.2549,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0190,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2625,
          0.1742,  0.110

 tensor([[ 1.2733e-05,  2.3413e-04,  1.0386e-03, -4.8018e-04, -4.6790e-05,
          1.8418e-04,  3.2187e-04,  2.5558e-04, -9.0504e-04,  3.3164e-04],
        [-3.9649e-04,  4.9305e-04,  2.8896e-04, -2.9445e-04,  3.5262e-04,
         -2.7347e-04,  1.7226e-04, -4.7994e-04, -2.3162e-04,  7.7057e-04],
        [ 2.1057e-03,  2.9421e-04,  1.1700e-04, -1.8120e-03, -1.5440e-03,
          1.0478e-04, -1.8053e-03,  9.7942e-04,  1.0366e-03, -1.0424e-03],
        [-1.7385e-03, -5.5981e-04, -5.2738e-04, -2.1744e-03,  1.5187e-04,
          7.4673e-04, -8.7595e-04, -6.3372e-04, -1.9693e-04,  2.2907e-03],
        [ 1.3361e-03,  1.7509e-03,  3.0088e-04, -6.8712e-04,  1.2779e-03,
         -1.4935e-03, -7.8154e-04, -1.4620e-03,  9.9087e-04, -7.2670e-04],
        [-7.1096e-04, -1.7176e-03, -1.0519e-03, -3.0155e-03, -5.8842e-04,
          1.5736e-03, -1.7080e-03, -1.3599e-03, -2.0733e-03,  3.6716e-03],
        [ 1.0843e-03,  2.0131e-05, -3.0684e-04,  1.6570e-04, -9.0170e-04,
          9.3222e-05,  1.4162e-

 tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2696,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0746,  0.1014],
        [ 0.2235,  0.0589,  0.0864,  0.3054, -0.1424,  0.1001, -0.1405,  0.2259,
          0.2523, -0.2914],
        [ 0.2701,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0907,
         -0.2226,  0.0207],
        [-0.0592, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2816,  0.1934],
        [ 0.1122,  0.0689,  0.0746,  0.1224, -0.0408, -0.2940, -0.1954,  0.2701,
          0.0191, -0.2565],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1482],
        [ 0.2777,  0.1932, -0.2240, -0.2548,  0.1312,  0.0071,  0.1296, -0.3090,
         -0.0190,  0.2229],
        [ 0.1468,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2626,
          0.1742,  0.1105],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1432],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2235,  0.0589,  0.0864,  0.3054, -0.1425,  0.1002, -0.1405,  0.2260,
          0.2522, -0.2913],
        [ 0.2700,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0206],
        [-0.0592, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0689,  0.0746,  0.1224, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0191, -0.2566],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2240, -0.2549,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0190,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1926,  0.2625,
          0.1742,  0.110

 tensor([[-1.3232e-04,  6.5565e-04,  3.3355e-04, -5.8508e-04,  2.0766e-04,
         -5.0688e-04,  3.5286e-04,  6.0225e-04, -1.3828e-03,  5.2071e-04],
        [-5.0497e-04,  2.1368e-05,  4.2725e-04, -8.2493e-05,  1.4380e-05,
         -8.2779e-04,  1.3530e-04,  6.0034e-04, -8.0729e-04,  1.2083e-03],
        [ 1.6584e-03, -3.8505e-04,  6.3467e-04, -2.2125e-03, -2.0065e-03,
          7.2622e-04, -2.4700e-03,  5.2118e-04,  1.8358e-03, -1.0996e-03],
        [-5.8365e-04, -5.5695e-04, -4.2391e-04, -7.8821e-04, -1.2341e-03,
          1.4105e-03, -6.5947e-04,  4.6539e-04, -1.2789e-03,  2.8400e-03],
        [ 8.8835e-04,  1.1816e-03,  5.0545e-04, -3.6061e-05,  9.7752e-04,
         -1.5793e-03, -3.7980e-04, -8.7452e-04,  1.1740e-03, -4.2081e-04],
        [ 9.7632e-05, -1.5945e-03, -4.4751e-04, -1.7920e-03, -1.9341e-03,
          7.5626e-04, -1.4162e-03,  8.4114e-04, -3.3569e-03,  5.0964e-03],
        [ 1.0319e-03, -3.1495e-04, -7.5817e-04,  1.4763e-03, -3.2091e-04,
          1.3769e-04, -6.0940e-

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1423,  0.1002, -0.1405,  0.2260,
          0.2522, -0.2913],
        [ 0.2700,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0206],
        [-0.0592, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0689,  0.0746,  0.1224, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0192, -0.2566],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2240, -0.2549,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0190,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.110

 tensor([[-1.3232e-04,  6.5565e-04,  3.3355e-04, -5.8508e-04,  2.0766e-04,
         -5.0688e-04,  3.5286e-04,  6.0225e-04, -1.3828e-03,  5.2071e-04],
        [-5.0497e-04,  2.1368e-05,  4.2725e-04, -8.2493e-05,  1.4380e-05,
         -8.2779e-04,  1.3530e-04,  6.0034e-04, -8.0729e-04,  1.2083e-03],
        [ 1.6584e-03, -3.8505e-04,  6.3467e-04, -2.2125e-03, -2.0065e-03,
          7.2622e-04, -2.4700e-03,  5.2118e-04,  1.8358e-03, -1.0996e-03],
        [-5.8365e-04, -5.5695e-04, -4.2391e-04, -7.8821e-04, -1.2341e-03,
          1.4105e-03, -6.5947e-04,  4.6539e-04, -1.2789e-03,  2.8400e-03],
        [ 8.8835e-04,  1.1816e-03,  5.0545e-04, -3.6061e-05,  9.7752e-04,
         -1.5793e-03, -3.7980e-04, -8.7452e-04,  1.1740e-03, -4.2081e-04],
        [ 9.7632e-05, -1.5945e-03, -4.4751e-04, -1.7920e-03, -1.9341e-03,
          7.5626e-04, -1.4162e-03,  8.4114e-04, -3.3569e-03,  5.0964e-03],
        [ 1.0319e-03, -3.1495e-04, -7.5817e-04,  1.4763e-03, -3.2091e-04,
          1.3769e-04, -6.0940e-

 tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2696,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1424,  0.1001, -0.1405,  0.2259,
          0.2523, -0.2913],
        [ 0.2701,  0.1511,  0.1378,  0.1302,  0.2630, -0.0418, -0.2673, -0.0907,
         -0.2226,  0.0206],
        [-0.0592, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2816,  0.1934],
        [ 0.1122,  0.0689,  0.0747,  0.1224, -0.0408, -0.2940, -0.1954,  0.2701,
          0.0192, -0.2566],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1482],
        [ 0.2777,  0.1932, -0.2240, -0.2548,  0.1312,  0.0071,  0.1296, -0.3090,
         -0.0190,  0.2229],
        [ 0.1468,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2626,
          0.1742,  0.1105],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2695,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1423,  0.1002, -0.1405,  0.2260,
          0.2522, -0.2913],
        [ 0.2700,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0206],
        [-0.0592, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2817,  0.1934],
        [ 0.1122,  0.0689,  0.0746,  0.1224, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0192, -0.2566],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2240, -0.2549,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0190,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.110

 tensor([[-3.3927e-04,  4.4012e-04,  7.6151e-04, -5.4026e-04,  4.9496e-04,
          9.8765e-05,  6.3610e-04,  7.7868e-04, -1.2131e-03,  3.9220e-04],
        [-4.5609e-04,  8.0347e-04,  4.5848e-04, -1.0073e-05,  4.5508e-05,
         -6.0797e-04,  8.8501e-04, -9.7454e-06, -5.1117e-04,  1.0700e-03],
        [ 2.1038e-03,  5.3853e-05,  1.4124e-03, -2.4376e-03, -2.3880e-03,
          1.0300e-03, -3.2310e-03,  8.3685e-04,  1.6203e-03, -1.4095e-03],
        [-1.3418e-03, -1.1024e-03, -7.6008e-04, -7.4863e-04, -6.0177e-04,
          4.0793e-04, -3.5453e-04, -2.2507e-04, -3.2973e-04,  2.3460e-03],
        [ 9.8324e-04,  1.5516e-03, -3.2330e-04,  1.5602e-05,  1.3008e-03,
         -1.4162e-03,  1.3041e-04, -4.6277e-04,  1.1120e-03, -7.1049e-04],
        [-2.1291e-04, -1.9627e-03, -1.4124e-03, -2.4948e-03, -4.1246e-04,
          9.8133e-04, -8.4829e-04, -9.4748e-04, -2.5520e-03,  4.3831e-03],
        [ 1.0204e-03,  6.6805e-04, -1.6272e-04,  3.7289e-04, -8.8739e-04,
          3.1447e-04,  8.0824e-

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2695,  0.2318, -0.2300, -0.2515,
         -0.1997,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1423,  0.1001, -0.1405,  0.2260,
          0.2522, -0.2913],
        [ 0.2703,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0206],
        [-0.0592, -0.1698, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2815,  0.1934],
        [ 0.1122,  0.0689,  0.0746,  0.1224, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0192, -0.2566],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2240, -0.2546,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0190,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.110

 tensor([[-3.3927e-04,  4.4012e-04,  7.6151e-04, -5.4026e-04,  4.9496e-04,
          9.8765e-05,  6.3610e-04,  7.7868e-04, -1.2131e-03,  3.9220e-04],
        [-4.5609e-04,  8.0347e-04,  4.5848e-04, -1.0073e-05,  4.5508e-05,
         -6.0797e-04,  8.8501e-04, -9.7454e-06, -5.1117e-04,  1.0700e-03],
        [ 2.1038e-03,  5.3853e-05,  1.4124e-03, -2.4376e-03, -2.3880e-03,
          1.0300e-03, -3.2310e-03,  8.3685e-04,  1.6203e-03, -1.4095e-03],
        [-1.3418e-03, -1.1024e-03, -7.6008e-04, -7.4863e-04, -6.0177e-04,
          4.0793e-04, -3.5453e-04, -2.2507e-04, -3.2973e-04,  2.3460e-03],
        [ 9.8324e-04,  1.5516e-03, -3.2330e-04,  1.5602e-05,  1.3008e-03,
         -1.4162e-03,  1.3041e-04, -4.6277e-04,  1.1120e-03, -7.1049e-04],
        [-2.1291e-04, -1.9627e-03, -1.4124e-03, -2.4948e-03, -4.1246e-04,
          9.8133e-04, -8.4829e-04, -9.4748e-04, -2.5520e-03,  4.3831e-03],
        [ 1.0204e-03,  6.6805e-04, -1.6272e-04,  3.7289e-04, -8.8739e-04,
          3.1447e-04,  8.0824e-

 tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2696,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1424,  0.1001, -0.1405,  0.2259,
          0.2523, -0.2913],
        [ 0.2702,  0.1511,  0.1378,  0.1302,  0.2630, -0.0418, -0.2673, -0.0907,
         -0.2226,  0.0206],
        [-0.0592, -0.1699, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2816,  0.1934],
        [ 0.1122,  0.0689,  0.0747,  0.1224, -0.0408, -0.2940, -0.1954,  0.2701,
          0.0192, -0.2566],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1482],
        [ 0.2777,  0.1932, -0.2240, -0.2548,  0.1312,  0.0071,  0.1296, -0.3090,
         -0.0190,  0.2229],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.1105],
        [-0.2426, 

 tensor([[-3.1805e-04,  3.3641e-04,  7.0047e-04, -9.9277e-04,  8.4019e-04,
          3.9363e-04,  3.8695e-04,  4.7326e-04, -9.3746e-04,  9.9850e-04],
        [-7.1812e-04, -4.7266e-05, -3.9130e-05,  6.5136e-04,  1.2732e-04,
         -6.8426e-04, -1.8752e-04, -7.5817e-05, -5.3942e-05,  9.6607e-04],
        [ 1.2875e-03,  8.8513e-05, -7.1573e-04, -2.0790e-03, -1.4744e-03,
         -3.7193e-04, -2.8381e-03, -7.2777e-05,  1.5411e-03, -1.5793e-03],
        [ 2.8968e-04, -8.4686e-04,  6.3419e-04, -8.6546e-04,  1.4520e-04,
          3.8409e-04, -3.9721e-04, -9.8610e-04, -1.0777e-04,  1.9836e-03],
        [ 1.6756e-03,  1.7681e-03,  6.2370e-04, -4.6611e-04,  2.1782e-03,
         -1.7214e-03, -6.9618e-04, -1.5583e-03,  8.9455e-04, -1.0234e-04],
        [ 1.4076e-03, -2.3785e-03, -3.2276e-05, -2.5616e-03, -5.0962e-05,
          2.5105e-04, -1.5488e-03, -6.3610e-04, -2.0714e-03,  4.2343e-03],
        [ 3.2544e-04, -5.2303e-05, -7.6675e-04,  1.1206e-03, -1.6823e-03,
         -6.3896e-04, -3.3069e-

 tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2696,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1431],
        [-0.1167,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1423,  0.1001, -0.1404,  0.2259,
          0.2523, -0.2913],
        [ 0.2702,  0.1511,  0.1378,  0.1302,  0.2630, -0.0418, -0.2673, -0.0907,
         -0.2226,  0.0206],
        [-0.0592, -0.1699, -0.0288,  0.2996, -0.0252,  0.0102, -0.0493,  0.0498,
          0.2816,  0.1934],
        [ 0.1122,  0.0689,  0.0747,  0.1225, -0.0408, -0.2940, -0.1954,  0.2701,
          0.0192, -0.2566],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2548,  0.1312,  0.0071,  0.1296, -0.3090,
         -0.0190,  0.2228],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.1105],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2695,  0.2318, -0.2300, -0.2515,
         -0.1997,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1423,  0.1001, -0.1404,  0.2260,
          0.2522, -0.2913],
        [ 0.2703,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0907,
         -0.2227,  0.0206],
        [-0.0592, -0.1699, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2815,  0.1934],
        [ 0.1122,  0.0690,  0.0746,  0.1225, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0192, -0.2566],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2546,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0190,  0.2228],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.110

 tensor([[-3.1805e-04,  3.3641e-04,  7.0047e-04, -9.9277e-04,  8.4019e-04,
          3.9363e-04,  3.8695e-04,  4.7326e-04, -9.3746e-04,  9.9850e-04],
        [-7.1812e-04, -4.7266e-05, -3.9130e-05,  6.5136e-04,  1.2732e-04,
         -6.8426e-04, -1.8752e-04, -7.5817e-05, -5.3942e-05,  9.6607e-04],
        [ 1.2875e-03,  8.8513e-05, -7.1573e-04, -2.0790e-03, -1.4744e-03,
         -3.7193e-04, -2.8381e-03, -7.2777e-05,  1.5411e-03, -1.5793e-03],
        [ 2.8968e-04, -8.4686e-04,  6.3419e-04, -8.6546e-04,  1.4520e-04,
          3.8409e-04, -3.9721e-04, -9.8610e-04, -1.0777e-04,  1.9836e-03],
        [ 1.6756e-03,  1.7681e-03,  6.2370e-04, -4.6611e-04,  2.1782e-03,
         -1.7214e-03, -6.9618e-04, -1.5583e-03,  8.9455e-04, -1.0234e-04],
        [ 1.4076e-03, -2.3785e-03, -3.2276e-05, -2.5616e-03, -5.0962e-05,
          2.5105e-04, -1.5488e-03, -6.3610e-04, -2.0714e-03,  4.2343e-03],
        [ 3.2544e-04, -5.2303e-05, -7.6675e-04,  1.1206e-03, -1.6823e-03,
         -6.3896e-04, -3.3069e-

 tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2696,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1431],
        [-0.1167,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1423,  0.1001, -0.1404,  0.2259,
          0.2523, -0.2913],
        [ 0.2702,  0.1511,  0.1378,  0.1302,  0.2630, -0.0418, -0.2673, -0.0907,
         -0.2226,  0.0206],
        [-0.0592, -0.1699, -0.0288,  0.2996, -0.0252,  0.0102, -0.0493,  0.0498,
          0.2816,  0.1934],
        [ 0.1122,  0.0689,  0.0747,  0.1225, -0.0408, -0.2940, -0.1954,  0.2701,
          0.0192, -0.2566],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2548,  0.1312,  0.0071,  0.1296, -0.3090,
         -0.0190,  0.2228],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.1105],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2695,  0.2318, -0.2300, -0.2515,
         -0.1997,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1423,  0.1001, -0.1404,  0.2260,
          0.2522, -0.2913],
        [ 0.2703,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0907,
         -0.2227,  0.0206],
        [-0.0592, -0.1699, -0.0288,  0.2996, -0.0251,  0.0102, -0.0493,  0.0498,
          0.2815,  0.1934],
        [ 0.1122,  0.0690,  0.0746,  0.1225, -0.0408, -0.2939, -0.1954,  0.2700,
          0.0192, -0.2566],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2546,  0.1312,  0.0071,  0.1296, -0.3091,
         -0.0190,  0.2228],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.110

 tensor([[-7.7307e-05,  5.0211e-04,  7.6675e-04, -3.9768e-04,  4.3607e-04,
          3.6573e-04,  3.9148e-04,  1.4150e-04, -2.7108e-04,  3.8815e-04],
        [-5.6839e-04,  1.7130e-04,  4.6182e-04,  2.0730e-04, -1.9896e-04,
         -1.8942e-04, -1.5989e-05, -2.3234e-04, -9.8944e-05,  9.9850e-04],
        [ 2.3727e-03, -1.8030e-05, -5.6773e-05, -1.2398e-03, -1.5497e-03,
         -4.2152e-04, -2.0275e-03, -1.7405e-04,  6.3086e-04, -1.8501e-03],
        [-1.0395e-03, -1.0023e-03,  6.5994e-04, -5.8508e-04,  3.1638e-04,
          6.8140e-04, -2.0778e-04,  1.0437e-04,  5.4455e-04,  3.1490e-03],
        [ 1.7223e-03,  1.4305e-03,  9.3126e-04, -5.8556e-04,  1.6527e-03,
         -9.1553e-04, -3.0160e-04, -3.7837e-04,  7.0429e-04,  5.0545e-04],
        [ 5.5265e-04, -1.7061e-03,  2.1434e-04, -1.4439e-03, -4.4084e-04,
          1.1015e-03, -7.5579e-04, -2.4271e-04, -1.5564e-03,  4.6501e-03],
        [ 4.8923e-04, -3.3522e-04, -9.3269e-04,  8.4829e-04, -9.6703e-04,
         -1.1784e-04, -2.7585e-

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2695,  0.2318, -0.2300, -0.2515,
         -0.1997,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1423,  0.1002, -0.1404,  0.2260,
          0.2522, -0.2913],
        [ 0.2703,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0206],
        [-0.0592, -0.1699, -0.0288,  0.2996, -0.0252,  0.0102, -0.0493,  0.0498,
          0.2815,  0.1934],
        [ 0.1122,  0.0690,  0.0746,  0.1225, -0.0408, -0.2939, -0.1953,  0.2700,
          0.0192, -0.2566],
        [ 0.0497,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2546,  0.1311,  0.0071,  0.1296, -0.3091,
         -0.0190,  0.2228],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.110

 tensor([[-7.7307e-05,  5.0211e-04,  7.6675e-04, -3.9768e-04,  4.3607e-04,
          3.6573e-04,  3.9148e-04,  1.4150e-04, -2.7108e-04,  3.8815e-04],
        [-5.6839e-04,  1.7130e-04,  4.6182e-04,  2.0730e-04, -1.9896e-04,
         -1.8942e-04, -1.5989e-05, -2.3234e-04, -9.8944e-05,  9.9850e-04],
        [ 2.3727e-03, -1.8030e-05, -5.6773e-05, -1.2398e-03, -1.5497e-03,
         -4.2152e-04, -2.0275e-03, -1.7405e-04,  6.3086e-04, -1.8501e-03],
        [-1.0395e-03, -1.0023e-03,  6.5994e-04, -5.8508e-04,  3.1638e-04,
          6.8140e-04, -2.0778e-04,  1.0437e-04,  5.4455e-04,  3.1490e-03],
        [ 1.7223e-03,  1.4305e-03,  9.3126e-04, -5.8556e-04,  1.6527e-03,
         -9.1553e-04, -3.0160e-04, -3.7837e-04,  7.0429e-04,  5.0545e-04],
        [ 5.5265e-04, -1.7061e-03,  2.1434e-04, -1.4439e-03, -4.4084e-04,
          1.1015e-03, -7.5579e-04, -2.4271e-04, -1.5564e-03,  4.6501e-03],
        [ 4.8923e-04, -3.3522e-04, -9.3269e-04,  8.4829e-04, -9.6703e-04,
         -1.1784e-04, -2.7585e-

 tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2696,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1431],
        [-0.1167,  0.1183, -0.2684, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1423,  0.1001, -0.1404,  0.2259,
          0.2523, -0.2913],
        [ 0.2702,  0.1511,  0.1378,  0.1303,  0.2630, -0.0419, -0.2673, -0.0907,
         -0.2226,  0.0205],
        [-0.0592, -0.1699, -0.0288,  0.2996, -0.0252,  0.0102, -0.0493,  0.0498,
          0.2816,  0.1934],
        [ 0.1122,  0.0690,  0.0747,  0.1225, -0.0408, -0.2941, -0.1954,  0.2701,
          0.0192, -0.2567],
        [ 0.0498,  0.2612, -0.2988, -0.2130, -0.1258,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2547,  0.1311,  0.0071,  0.1296, -0.3090,
         -0.0190,  0.2228],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.1105],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1572,  0.1193, -0.2695,  0.2318, -0.2300, -0.2515,
         -0.1997,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1423,  0.1002, -0.1404,  0.2260,
          0.2522, -0.2913],
        [ 0.2703,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0908,
         -0.2227,  0.0206],
        [-0.0592, -0.1699, -0.0288,  0.2996, -0.0252,  0.0102, -0.0493,  0.0498,
          0.2815,  0.1934],
        [ 0.1122,  0.0690,  0.0746,  0.1225, -0.0408, -0.2939, -0.1953,  0.2700,
          0.0192, -0.2566],
        [ 0.0497,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2546,  0.1311,  0.0071,  0.1296, -0.3091,
         -0.0190,  0.2228],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.110

 tensor([[ 1.2034e-04,  8.8787e-04,  9.4032e-04, -6.8808e-04,  9.1982e-04,
         -9.2924e-05,  3.6073e-04,  6.0129e-04, -4.9829e-04,  6.4373e-04],
        [-6.2656e-04, -1.1843e-04,  1.9801e-04,  1.1092e-04,  3.6478e-04,
         -4.0174e-04,  4.3344e-04,  9.3058e-06, -6.3133e-04,  1.1978e-03],
        [ 1.4753e-03,  2.3413e-04,  3.2282e-04, -1.5507e-03, -1.6632e-03,
          5.3787e-04, -2.4071e-03,  3.2854e-04,  1.6727e-03, -1.5297e-03],
        [-1.0357e-03, -1.2121e-03,  4.7445e-04, -5.7507e-04, -3.1441e-05,
          2.2352e-04, -8.4829e-04, -6.0034e-04, -4.0460e-04,  2.6073e-03],
        [ 8.9359e-04,  1.1358e-03,  1.7476e-04, -3.4595e-04,  1.5564e-03,
         -1.4372e-03, -6.6459e-05, -1.1597e-03,  1.4877e-03, -2.2888e-04],
        [ 5.3930e-04, -1.1578e-03, -4.7898e-04, -1.4153e-03, -4.8685e-04,
          3.2163e-04, -1.3952e-03, -1.9798e-03, -1.8225e-03,  3.1509e-03],
        [ 3.0875e-04,  9.5010e-05, -4.4727e-04,  8.1778e-04, -8.9693e-04,
          1.3876e-04,  7.2658e-

 tensor([[-0.1289,  0.0104, -0.1572,  0.1194, -0.2696,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1431],
        [-0.1167,  0.1183, -0.2684, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3055, -0.1423,  0.1001, -0.1404,  0.2259,
          0.2522, -0.2913],
        [ 0.2702,  0.1511,  0.1378,  0.1303,  0.2630, -0.0419, -0.2673, -0.0907,
         -0.2226,  0.0205],
        [-0.0592, -0.1699, -0.0288,  0.2996, -0.0252,  0.0103, -0.0493,  0.0499,
          0.2816,  0.1934],
        [ 0.1122,  0.0690,  0.0747,  0.1225, -0.0408, -0.2941, -0.1954,  0.2701,
          0.0192, -0.2567],
        [ 0.0498,  0.2612, -0.2987, -0.2130, -0.1258,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2547,  0.1311,  0.0071,  0.1296, -0.3089,
         -0.0190,  0.2228],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1739, -0.1195, -0.1927,  0.2625,
          0.1742,  0.1105],
        [-0.2426, 

 tensor([[ 1.2034e-04,  8.8787e-04,  9.4032e-04, -6.8808e-04,  9.1982e-04,
         -9.2924e-05,  3.6073e-04,  6.0129e-04, -4.9829e-04,  6.4373e-04],
        [-6.2656e-04, -1.1843e-04,  1.9801e-04,  1.1092e-04,  3.6478e-04,
         -4.0174e-04,  4.3344e-04,  9.3058e-06, -6.3133e-04,  1.1978e-03],
        [ 1.4753e-03,  2.3413e-04,  3.2282e-04, -1.5507e-03, -1.6632e-03,
          5.3787e-04, -2.4071e-03,  3.2854e-04,  1.6727e-03, -1.5297e-03],
        [-1.0357e-03, -1.2121e-03,  4.7445e-04, -5.7507e-04, -3.1441e-05,
          2.2352e-04, -8.4829e-04, -6.0034e-04, -4.0460e-04,  2.6073e-03],
        [ 8.9359e-04,  1.1358e-03,  1.7476e-04, -3.4595e-04,  1.5564e-03,
         -1.4372e-03, -6.6459e-05, -1.1597e-03,  1.4877e-03, -2.2888e-04],
        [ 5.3930e-04, -1.1578e-03, -4.7898e-04, -1.4153e-03, -4.8685e-04,
          3.2163e-04, -1.3952e-03, -1.9798e-03, -1.8225e-03,  3.1509e-03],
        [ 3.0875e-04,  9.5010e-05, -4.4727e-04,  8.1778e-04, -8.9693e-04,
          1.3876e-04,  7.2658e-

 tensor([[-0.1289,  0.0104, -0.1572,  0.1194, -0.2696,  0.2318, -0.2299, -0.2515,
         -0.1997,  0.1431],
        [-0.1167,  0.1183, -0.2684, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3055, -0.1423,  0.1001, -0.1404,  0.2259,
          0.2522, -0.2913],
        [ 0.2702,  0.1511,  0.1378,  0.1303,  0.2630, -0.0419, -0.2673, -0.0907,
         -0.2226,  0.0205],
        [-0.0592, -0.1699, -0.0288,  0.2996, -0.0252,  0.0103, -0.0493,  0.0499,
          0.2816,  0.1934],
        [ 0.1122,  0.0690,  0.0747,  0.1225, -0.0408, -0.2941, -0.1954,  0.2701,
          0.0192, -0.2567],
        [ 0.0498,  0.2612, -0.2987, -0.2130, -0.1258,  0.0127, -0.0737, -0.0348,
         -0.3084,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2547,  0.1311,  0.0071,  0.1296, -0.3089,
         -0.0190,  0.2228],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1739, -0.1195, -0.1927,  0.2625,
          0.1742,  0.1105],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1572,  0.1194, -0.2695,  0.2318, -0.2300, -0.2515,
         -0.1997,  0.1431],
        [-0.1168,  0.1183, -0.2683, -0.1919, -0.1161, -0.0620, -0.2412,  0.2070,
         -0.0745,  0.1014],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1423,  0.1001, -0.1404,  0.2260,
          0.2522, -0.2913],
        [ 0.2703,  0.1511,  0.1378,  0.1302,  0.2629, -0.0418, -0.2673, -0.0907,
         -0.2227,  0.0205],
        [-0.0592, -0.1699, -0.0288,  0.2996, -0.0252,  0.0103, -0.0493,  0.0499,
          0.2815,  0.1934],
        [ 0.1122,  0.0690,  0.0746,  0.1225, -0.0408, -0.2939, -0.1953,  0.2700,
          0.0192, -0.2568],
        [ 0.0497,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2546,  0.1311,  0.0071,  0.1296, -0.3088,
         -0.0190,  0.2228],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.110

 tensor([[-2.4557e-04,  5.7745e-04,  6.0987e-04, -8.3590e-04,  6.9284e-04,
         -1.4210e-04,  8.7500e-04,  1.0719e-03, -1.3723e-03,  7.3099e-04],
        [-6.4468e-04, -3.0947e-04,  2.2054e-04,  2.9162e-05,  1.3387e-04,
         -5.8508e-04,  1.9237e-05,  3.2902e-04, -2.6965e-04,  8.4019e-04],
        [ 1.4696e-03, -4.5085e-04,  1.5745e-03, -1.9321e-03, -2.3270e-03,
          6.7711e-04, -2.5139e-03,  4.1771e-04,  1.7357e-03, -1.1797e-03],
        [-2.1439e-03, -4.7112e-04, -6.4087e-04, -5.8365e-04, -2.4414e-04,
          9.8419e-04, -3.5405e-04, -3.2496e-04,  2.7013e-04,  1.9245e-03],
        [ 1.1406e-03,  1.4896e-03,  1.2910e-04, -3.2473e-04,  1.4830e-03,
         -1.4439e-03, -2.6271e-05, -4.2844e-04,  8.0538e-04, -1.5712e-04],
        [-4.0388e-04, -1.3018e-03, -1.1425e-03, -1.9350e-03, -1.1854e-03,
          4.8971e-04, -1.3170e-03, -7.9870e-06, -3.2864e-03,  3.9253e-03],
        [ 6.1560e-04,  2.7037e-04, -6.7282e-04,  5.5075e-04, -5.1737e-04,
          2.1982e-04, -2.7776e-

 tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0',
       dtype=torch.float16)
optimizer weights:
 tensor([[-0.1289,  0.0104, -0.1572,  0.1194, -0.2696,  0.2318, -0.2299, -0.2516,
         -0.1997,  0.1431],
        [-0.1167,  0.1183, -0.2684, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1013],
        [ 0.2234,  0.0589,  0.0864,  0.3055, -0.1423,  0.1001, -0.1404,  0.2259,
          0.2522, -0.2913],
        [ 0.2702,  0.1511,  0.1378,  0.1303,  0.2630, -0.0419, -0.2673, -0.0907,
         -0.2226,  0.0205

 tensor([[-2.4557e-04,  5.7745e-04,  6.0987e-04, -8.3590e-04,  6.9284e-04,
         -1.4210e-04,  8.7500e-04,  1.0719e-03, -1.3723e-03,  7.3099e-04],
        [-6.4468e-04, -3.0947e-04,  2.2054e-04,  2.9162e-05,  1.3387e-04,
         -5.8508e-04,  1.9237e-05,  3.2902e-04, -2.6965e-04,  8.4019e-04],
        [ 1.4696e-03, -4.5085e-04,  1.5745e-03, -1.9321e-03, -2.3270e-03,
          6.7711e-04, -2.5139e-03,  4.1771e-04,  1.7357e-03, -1.1797e-03],
        [-2.1439e-03, -4.7112e-04, -6.4087e-04, -5.8365e-04, -2.4414e-04,
          9.8419e-04, -3.5405e-04, -3.2496e-04,  2.7013e-04,  1.9245e-03],
        [ 1.1406e-03,  1.4896e-03,  1.2910e-04, -3.2473e-04,  1.4830e-03,
         -1.4439e-03, -2.6271e-05, -4.2844e-04,  8.0538e-04, -1.5712e-04],
        [-4.0388e-04, -1.3018e-03, -1.1425e-03, -1.9350e-03, -1.1854e-03,
          4.8971e-04, -1.3170e-03, -7.9870e-06, -3.2864e-03,  3.9253e-03],
        [ 6.1560e-04,  2.7037e-04, -6.7282e-04,  5.5075e-04, -5.1737e-04,
          2.1982e-04, -2.7776e-

 tensor([[-0.1254,  0.2778,  0.0470, -0.2379,  0.1655,  0.2629,  0.3984,  0.4397,
         -0.4978,  0.5156],
        [-0.3635,  0.5586,  0.1519,  0.1465,  0.1143, -0.2595,  0.3267, -0.2163,
          0.0527,  0.4858],
        [ 0.4846, -0.1309,  0.5200, -1.0469, -0.8076, -0.1471, -1.2012,  0.4534,
          0.6675, -1.0488],
        [-0.3987, -0.4922,  0.5439, -0.1608,  0.0697,  0.0759, -0.5146,  0.1414,
          0.4214,  1.4785],
        [ 0.8101,  0.7642,  0.5063, -0.3557,  0.8501, -0.8330, -0.3691, -0.7109,
          0.5571, -0.2147],
        [ 0.3428, -0.4871,  0.0047, -0.8901, -0.0477, -0.0279, -0.9131, -0.3650,
         -1.1680,  2.3242],
        [ 0.5889,  0.1029, -0.2484,  0.5308, -0.4375, -0.1172,  0.0073, -0.1892,
          0.3738, -0.6880],
        [-0.7563,  0.1937,  0.7822, -0.5171,  0.6587, -0.2952, -0.0143, -0.5474,
          0.5356,  0.5796],
        [-0.6416,  0.2986, -0.1970,  0.6704,  0.0980,  0.1558,  0.4487,  0.8794,
         -0.5479,  0.4041],
        [-0.0889, 

 tensor([[-0.1289,  0.0104, -0.1572,  0.1194, -0.2696,  0.2318, -0.2299, -0.2516,
         -0.1997,  0.1431],
        [-0.1167,  0.1183, -0.2684, -0.1919, -0.1161, -0.0621, -0.2412,  0.2070,
         -0.0745,  0.1013],
        [ 0.2234,  0.0589,  0.0864,  0.3055, -0.1423,  0.1001, -0.1404,  0.2259,
          0.2522, -0.2913],
        [ 0.2702,  0.1511,  0.1378,  0.1303,  0.2630, -0.0419, -0.2673, -0.0907,
         -0.2226,  0.0205],
        [-0.0593, -0.1699, -0.0288,  0.2996, -0.0252,  0.0103, -0.0493,  0.0499,
          0.2816,  0.1934],
        [ 0.1122,  0.0690,  0.0747,  0.1225, -0.0408, -0.2941, -0.1953,  0.2701,
          0.0193, -0.2568],
        [ 0.0497,  0.2612, -0.2987, -0.2130, -0.1258,  0.0127, -0.0737, -0.0347,
         -0.3084,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2547,  0.1311,  0.0071,  0.1296, -0.3089,
         -0.0190,  0.2228],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1739, -0.1195, -0.1927,  0.2625,
          0.1742,  0.1105],
        [-0.2426, 

 tensor([[-2.4486e-04,  5.4264e-04,  9.1851e-05, -4.6468e-04,  3.2330e-04,
          5.1355e-04,  7.7820e-04,  8.5878e-04, -9.7227e-04,  1.0071e-03],
        [-7.1001e-04,  1.0910e-03,  2.9659e-04,  2.8610e-04,  2.2316e-04,
         -5.0688e-04,  6.3801e-04, -4.2248e-04,  1.0300e-04,  9.4891e-04],
        [ 9.4652e-04, -2.5558e-04,  1.0157e-03, -2.0447e-03, -1.5774e-03,
         -2.8729e-04, -2.3460e-03,  8.8549e-04,  1.3037e-03, -2.0485e-03],
        [-7.7868e-04, -9.6130e-04,  1.0624e-03, -3.1400e-04,  1.3614e-04,
          1.4818e-04, -1.0052e-03,  2.7609e-04,  8.2302e-04,  2.8877e-03],
        [ 1.5821e-03,  1.4925e-03,  9.8896e-04, -6.9475e-04,  1.6603e-03,
         -1.6270e-03, -7.2098e-04, -1.3885e-03,  1.0881e-03, -4.1938e-04],
        [ 6.6948e-04, -9.5129e-04,  9.2760e-06, -1.7385e-03, -9.3102e-05,
         -5.4419e-05, -1.7834e-03, -7.1287e-04, -2.2812e-03,  4.5395e-03],
        [ 1.1501e-03,  2.0099e-04, -4.8518e-04,  1.0366e-03, -8.5449e-04,
         -2.2888e-04,  1.4208e-

 tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0',
       dtype=torch.float16)
optimizer weights:
 tensor([[-0.1289,  0.0104, -0.1572,  0.1194, -0.2696,  0.2318, -0.2299, -0.2516,
         -0.1997,  0.1431],
        [-0.1167,  0.1183, -0.2684, -0.1919, -0.1161, -0.0620, -0.2413,  0.2070,
         -0.0745,  0.1013],
        [ 0.2233,  0.0589,  0.0864,  0.3055, -0.1423,  0.1001, -0.1403,  0.2259,
          0.2522, -0.2913],
        [ 0.2702,  0.1511,  0.1378,  0.1303,  0.2630, -0.0419, -0.2672, -0.0907,
         -0.2226,  0.0205

 tensor([[-2.4486e-04,  5.4264e-04,  9.1851e-05, -4.6468e-04,  3.2330e-04,
          5.1355e-04,  7.7820e-04,  8.5878e-04, -9.7227e-04,  1.0071e-03],
        [-7.1001e-04,  1.0910e-03,  2.9659e-04,  2.8610e-04,  2.2316e-04,
         -5.0688e-04,  6.3801e-04, -4.2248e-04,  1.0300e-04,  9.4891e-04],
        [ 9.4652e-04, -2.5558e-04,  1.0157e-03, -2.0447e-03, -1.5774e-03,
         -2.8729e-04, -2.3460e-03,  8.8549e-04,  1.3037e-03, -2.0485e-03],
        [-7.7868e-04, -9.6130e-04,  1.0624e-03, -3.1400e-04,  1.3614e-04,
          1.4818e-04, -1.0052e-03,  2.7609e-04,  8.2302e-04,  2.8877e-03],
        [ 1.5821e-03,  1.4925e-03,  9.8896e-04, -6.9475e-04,  1.6603e-03,
         -1.6270e-03, -7.2098e-04, -1.3885e-03,  1.0881e-03, -4.1938e-04],
        [ 6.6948e-04, -9.5129e-04,  9.2760e-06, -1.7385e-03, -9.3102e-05,
         -5.4419e-05, -1.7834e-03, -7.1287e-04, -2.2812e-03,  4.5395e-03],
        [ 1.1501e-03,  2.0099e-04, -4.8518e-04,  1.0366e-03, -8.5449e-04,
         -2.2888e-04,  1.4208e-

 tensor([[-0.1289,  0.0104, -0.1572,  0.1194, -0.2696,  0.2318, -0.2299, -0.2516,
         -0.1997,  0.1431],
        [-0.1167,  0.1183, -0.2684, -0.1919, -0.1161, -0.0620, -0.2413,  0.2070,
         -0.0745,  0.1013],
        [ 0.2233,  0.0589,  0.0864,  0.3055, -0.1423,  0.1001, -0.1403,  0.2259,
          0.2522, -0.2913],
        [ 0.2702,  0.1511,  0.1378,  0.1303,  0.2630, -0.0419, -0.2672, -0.0907,
         -0.2226,  0.0205],
        [-0.0593, -0.1699, -0.0288,  0.2996, -0.0252,  0.0103, -0.0493,  0.0499,
          0.2816,  0.1934],
        [ 0.1122,  0.0690,  0.0747,  0.1225, -0.0408, -0.2941, -0.1953,  0.2701,
          0.0193, -0.2568],
        [ 0.0497,  0.2612, -0.2987, -0.2130, -0.1258,  0.0127, -0.0737, -0.0347,
         -0.3084,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2547,  0.1311,  0.0071,  0.1296, -0.3089,
         -0.0191,  0.2228],
        [ 0.1469,  0.0115,  0.0622, -0.0300, -0.1739, -0.1195, -0.1927,  0.2625,
          0.1742,  0.1105],
        [-0.2426, 

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1572,  0.1194, -0.2695,  0.2318, -0.2300, -0.2515,
         -0.1997,  0.1431],
        [-0.1167,  0.1183, -0.2683, -0.1919, -0.1161, -0.0620, -0.2412,  0.2070,
         -0.0745,  0.1013],
        [ 0.2234,  0.0589,  0.0864,  0.3054, -0.1422,  0.1001, -0.1404,  0.2258,
          0.2522, -0.2913],
        [ 0.2703,  0.1511,  0.1378,  0.1302,  0.2629, -0.0419, -0.2673, -0.0907,
         -0.2227,  0.0205],
        [-0.0593, -0.1699, -0.0288,  0.2996, -0.0252,  0.0103, -0.0493,  0.0499,
          0.2815,  0.1934],
        [ 0.1122,  0.0690,  0.0747,  0.1226, -0.0407, -0.2939, -0.1953,  0.2700,
          0.0193, -0.2568],
        [ 0.0497,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0348,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2546,  0.1311,  0.0071,  0.1296, -0.3088,
         -0.0191,  0.2228],
        [ 0.1470,  0.0115,  0.0623, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.110

 tensor([[-5.6696e-04,  6.6614e-04,  6.1417e-04, -1.0939e-03,  3.4809e-04,
         -1.5521e-04,  6.6757e-04,  2.6131e-04, -3.1114e-04,  3.6168e-04],
        [-6.5327e-04,  7.6115e-05,  2.0432e-04, -5.3167e-04,  2.7895e-05,
         -5.9319e-04,  4.7302e-04,  4.3082e-04, -3.5548e-04,  8.5688e-04],
        [ 1.9684e-03,  3.5095e-04,  1.8215e-04, -1.0862e-03, -1.8911e-03,
          2.4796e-04, -2.0065e-03, -6.5136e-04,  8.0299e-04, -1.4639e-03],
        [-1.3189e-03, -6.3086e-04,  6.6710e-04, -9.3222e-04,  7.1955e-04,
          1.7796e-03,  1.9863e-05,  5.3358e-04, -6.8307e-05,  3.0823e-03],
        [ 1.3399e-03,  1.4734e-03,  8.7118e-04,  1.4985e-04,  1.9779e-03,
         -1.1654e-03,  1.2338e-04, -1.7250e-04,  1.3857e-03, -1.6248e-04],
        [-7.9823e-04, -1.7042e-03, -1.4305e-04, -1.5774e-03, -5.8985e-04,
          6.5374e-04, -1.6928e-03, -2.9397e-04, -1.7080e-03,  3.5954e-03],
        [ 1.1339e-03, -2.8610e-04, -5.6934e-04,  9.7752e-04, -1.2293e-03,
         -2.6345e-04,  1.4758e-

 Parameter containing:
tensor([[-0.1289,  0.0104, -0.1572,  0.1194, -0.2695,  0.2318, -0.2300, -0.2515,
         -0.1997,  0.1431],
        [-0.1167,  0.1183, -0.2683, -0.1919, -0.1161, -0.0620, -0.2412,  0.2070,
         -0.0745,  0.1013],
        [ 0.2233,  0.0589,  0.0864,  0.3054, -0.1422,  0.1001, -0.1403,  0.2258,
          0.2522, -0.2913],
        [ 0.2703,  0.1511,  0.1378,  0.1302,  0.2629, -0.0419, -0.2673, -0.0907,
         -0.2227,  0.0204],
        [-0.0593, -0.1699, -0.0288,  0.2996, -0.0252,  0.0103, -0.0493,  0.0499,
          0.2815,  0.1934],
        [ 0.1122,  0.0690,  0.0747,  0.1226, -0.0407, -0.2939, -0.1953,  0.2700,
          0.0193, -0.2568],
        [ 0.0497,  0.2612, -0.2988, -0.2130, -0.1259,  0.0127, -0.0737, -0.0347,
         -0.3083,  0.1482],
        [ 0.2778,  0.1932, -0.2241, -0.2546,  0.1311,  0.0071,  0.1296, -0.3088,
         -0.0191,  0.2228],
        [ 0.1470,  0.0115,  0.0623, -0.0300, -0.1738, -0.1195, -0.1927,  0.2625,
          0.1742,  0.110

# Training Results

In [27]:
# Ran without profiler
history

{'train_losses': [1.61328125,
  1.6127387285232544,
  1.6127387285232544,
  1.6124131679534912,
  1.6119791269302368,
  1.6117621660232544,
  1.611328125,
  1.6112196445465088,
  1.6108940839767456,
  1.6108940839767456,
  1.610568642616272,
  1.6101346015930176,
  1.6101346015930176,
  1.6098090410232544,
  1.6097005605697632,
  1.6094834804534912,
  1.6092665195465088,
  1.6090494394302368,
  1.6089409589767456,
  1.608615517616272],
 'max_memory_allocation': [{device(type='cuda', index=0): 0.010752},
  {device(type='cuda', index=0): 0.256512},
  {device(type='cuda', index=0): 0.266752},
  {device(type='cuda', index=0): 0.266752},
  {device(type='cuda', index=0): 0.266752},
  {device(type='cuda', index=0): 0.266752},
  {device(type='cuda', index=0): 0.266752},
  {device(type='cuda', index=0): 0.266752},
  {device(type='cuda', index=0): 0.266752},
  {device(type='cuda', index=0): 0.266752},
  {device(type='cuda', index=0): 0.267264},
  {device(type='cuda', index=0): 0.267264},
  {devi

# Profiler Results

In [19]:
rankByCriteria(prof, model, criteria='cuda_memory_usage', per_thread=False, per_inp_shapes=False, include_external=False)

Ranked by cuda_memory_usage

41.18 Mb
##############################################
model, aten::empty, forward, (26) last_X = torch.relu(last_X)
2.81 Mb
##############################################
model, aten::addmm, forward, (24) last_X = linear_layer(last_X)
1.41 Mb
##############################################
model.scorer, aten::addmm, forward, (93) return F.linear(input, self.weight, self.bias)
model, aten::addmm, forward, (24) last_X = linear_layer(last_X)
1.41 Mb
##############################################
model, aten::resize_, forward, (24) last_X = linear_layer(last_X)
1.41 Mb
##############################################
model.scorer, aten::resize_, forward, (93) return F.linear(input, self.weight, self.bias)
model, aten::resize_, forward, (24) last_X = linear_layer(last_X)
1.41 Mb
##############################################
model, aten::relu, forward, (26) last_X = torch.relu(last_X)
1.41 Mb
##############################################
model, aten::threshold, 

# Nvidia Results

In [28]:
# Ran without profiler
!nvidia-smi

Sun Nov 22 16:36:20 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 455.32.00    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           On   | 000047DD:00:00.0 Off |                    0 |
| N/A   43C    P0    55W / 149W |    338MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [29]:
# Ran without profiler
print(torch.cuda.memory_summary(cuda0))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |       0 B  |  276480 B  |  160123 KB |  160123 KB |
|       from large pool |       0 B  |       0 B  |       0 KB |       0 KB |
|       from small pool |       0 B  |  276480 B  |  160123 KB |  160123 KB |
|---------------------------------------------------------------------------|
| Active memory         |       0 B  |  276480 B  |  160123 KB |  160123 KB |
|       from large pool |       0 B  |       0 B  |       0 KB |       0 KB |
|       from small pool |       0 B  |  276480 B  |  160123 KB |  160123 KB |
|---------------------------------------------------------------