In [1]:
import sys, os
sys.path.append('/home/A00512318/TCN')
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
from TCN.poly_music.model import TCN
import numpy as np

# set hyperparameters and settings
cuda_ = True
dropout_ = 0.25
clip_ = 0.4
epochs_ = 100
kernel_size_ = 6
levels_ = 4
log_interval_ = 100
lr_ = 1e-3
optim_ = 'Adam'
nhid_ = 150
input_size_ = 88 # number of keys on a piano
seed_ = 1111
n_channels_ = [nhid_] * levels_
data_ = "Muse"

In [2]:
# Set the random seed manually for reproducibility.
torch.manual_seed(seed_)

<torch._C.Generator at 0x7f7884a37df0>

In [3]:
from scipy.io import loadmat

def data_generator(dataset):
    if dataset == "JSB":
        print('loading JSB data...')
        data = loadmat('./mdata/JSB_Chorales.mat')
    elif dataset == "Muse":
        print('loading Muse data...')
        data = loadmat('./mdata/MuseData.mat')
    elif dataset == "Nott":
        print('loading Nott data...')
        data = loadmat('./mdata/Nottingham.mat')
    elif dataset == "Piano":
        print('loading Piano data...')
        data = loadmat('./mdata/Piano_midi.mat')

    X_train = data['traindata'][0]
    X_valid = data['validdata'][0]
    X_test = data['testdata'][0]

    for data in [X_train, X_valid, X_test]:
        for i in range(len(data)):
            data[i] = torch.Tensor(data[i].astype(np.float64))

    return X_train, X_valid, X_test

X_train_, X_valid_, X_test_ = data_generator(data_) # tests will be done with the Muse data set

loading Muse data...


In [4]:
# set up device and model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = TCN(input_size_, input_size_, n_channels_, kernel_size_, dropout_)

model = nn.DataParallel(model)
model.to(device)

DataParallel(
  (module): TCN(
    (tcn): TemporalConvNet(
      (network): Sequential(
        (0): TemporalBlock(
          (conv1): Conv1d(88, 150, kernel_size=(6,), stride=(1,), padding=(5,))
          (chomp1): Chomp1d()
          (relu1): ReLU()
          (dropout1): Dropout(p=0.25)
          (conv2): Conv1d(150, 150, kernel_size=(6,), stride=(1,), padding=(5,))
          (chomp2): Chomp1d()
          (relu2): ReLU()
          (dropout2): Dropout(p=0.25)
          (net): Sequential(
            (0): Conv1d(88, 150, kernel_size=(6,), stride=(1,), padding=(5,))
            (1): Chomp1d()
            (2): ReLU()
            (3): Dropout(p=0.25)
            (4): Conv1d(150, 150, kernel_size=(6,), stride=(1,), padding=(5,))
            (5): Chomp1d()
            (6): ReLU()
            (7): Dropout(p=0.25)
          )
          (downsample): Conv1d(88, 150, kernel_size=(1,), stride=(1,))
          (relu): ReLU()
        )
        (1): TemporalBlock(
          (conv1): Conv1d(150, 150,

In [5]:
criterion_ = nn.CrossEntropyLoss()
optimizer_ = getattr(optim, optim_)(model.parameters(), lr=lr_)

In [6]:
def train(ep):
    model.train()
    total_loss = 0
    count = 0
    train_idx_list = np.arange(len(X_train_), dtype="int32")
    np.random.shuffle(train_idx_list)
    for idx in train_idx_list:
        data_line = X_train_[idx]
        x, y = Variable(data_line[:-1]), Variable(data_line[1:])
        if cuda_:
            x, y = x.cuda(), y.cuda()
        optimizer_.zero_grad()
        output = model(x.unsqueeze(0)).squeeze(0)
        loss = -torch.trace(torch.matmul(y, torch.log(output).float().t()) +
                            torch.matmul((1 - y), torch.log(1 - output).float().t()))
        total_loss += loss.data[0]
        count += output.size(0)

        if clip_ > 0:
            torch.nn.utils.clip_grad_norm(model.parameters(), clip_)
        loss.backward()
        optimizer_.step()
        if idx > 0 and idx % log_interval_ == 0:
            cur_loss = total_loss / count
            print("Epoch {:2d} | lr {:.5f} | loss {:.5f}".format(ep, lr_, cur_loss))
            total_loss = 0.0
            count = 0

In [28]:
def evaluate(X_data):
    eval_idx_list = np.arange(len(X_data), dtype="int32")
#     print(eval_idx_list)
    total_loss = 0.0
    count = 0
    with torch.no_grad():
        for idx in eval_idx_list:
            data_line = X_data[idx]
            x, y = Variable(data_line[:-1]), Variable(data_line[1:])
            if cuda_:
                x, y = x.to(device), y.to(device)
            output = model(x.unsqueeze(0))
            print(output)
            output = output.squeeze(0)
            loss = -torch.trace(torch.matmul(y, torch.log(output).float().t()) +
                                torch.matmul((1-y), torch.log(1-output).float().t()))
            total_loss += loss.data[0]
            count += output.size(0)
    eval_loss = total_loss / count
    print("Validation/Test loss: {:.5f}".format(eval_loss))
    return eval_loss

In [8]:
best_vloss = 1e8
vloss_list = []
model_name = "poly_music_{0}.pt".format(data_)
for ep in range(1, epochs_+1):
    train(ep)
    vloss = evaluate(X_valid_)
    tloss = evaluate(X_test_)
    if vloss < best_vloss:
        with open(model_name, "wb") as f:
            torch.save(model, f)
            print("Saved model!\n")
        best_vloss = vloss
    if ep > 10 and vloss > max(vloss_list[-3:]):
        lr_ /= 10
        for param_group in optimizer_.param_groups:
            param_group['lr'] = lr_

    vloss_list.append(vloss)

# print('-' * 89)
model = torch.load(open(model_name, "rb"))
tloss = evaluate(X_test_)




Epoch  1 | lr 0.00100 | loss 20.11985
Epoch  1 | lr 0.00100 | loss 12.03653
Epoch  1 | lr 0.00100 | loss 13.35413
Epoch  1 | lr 0.00100 | loss 10.98336
Epoch  1 | lr 0.00100 | loss 11.60595




Validation/Test loss: 11.44709
Validation/Test loss: 11.03024
Saved model!

Epoch  2 | lr 0.00100 | loss 9.72717
Epoch  2 | lr 0.00100 | loss 10.86648
Epoch  2 | lr 0.00100 | loss 11.19038
Epoch  2 | lr 0.00100 | loss 10.78325
Epoch  2 | lr 0.00100 | loss 10.20871
Validation/Test loss: 10.03831
Validation/Test loss: 9.74944
Saved model!

Epoch  3 | lr 0.00100 | loss 7.67366
Epoch  3 | lr 0.00100 | loss 9.77530
Epoch  3 | lr 0.00100 | loss 9.57626
Epoch  3 | lr 0.00100 | loss 8.61017
Epoch  3 | lr 0.00100 | loss 8.60887
Validation/Test loss: 9.00673
Validation/Test loss: 8.84206
Saved model!

Epoch  4 | lr 0.00100 | loss 8.96128
Epoch  4 | lr 0.00100 | loss 8.78838
Epoch  4 | lr 0.00100 | loss 8.62562
Epoch  4 | lr 0.00100 | loss 8.39875
Epoch  4 | lr 0.00100 | loss 7.13647
Validation/Test loss: 8.51429
Validation/Test loss: 8.39670
Saved model!

Epoch  5 | lr 0.00100 | loss 8.42384
Epoch  5 | lr 0.00100 | loss 9.06066
Epoch  5 | lr 0.00100 | loss 7.79880
Epoch  5 | lr 0.00100 | loss 8.

Validation/Test loss: 7.57710
Validation/Test loss: 7.52446
Saved model!

Epoch 34 | lr 0.00010 | loss 6.21665
Epoch 34 | lr 0.00010 | loss 7.06398
Epoch 34 | lr 0.00010 | loss 7.45604
Epoch 34 | lr 0.00010 | loss 7.19895
Epoch 34 | lr 0.00010 | loss 7.45482
Validation/Test loss: 7.57365
Validation/Test loss: 7.52934
Saved model!

Epoch 35 | lr 0.00010 | loss 7.62508
Epoch 35 | lr 0.00010 | loss 7.21105
Epoch 35 | lr 0.00010 | loss 7.40499
Epoch 35 | lr 0.00010 | loss 7.13224
Epoch 35 | lr 0.00010 | loss 8.01859
Validation/Test loss: 7.56072
Validation/Test loss: 7.51885
Saved model!

Epoch 36 | lr 0.00010 | loss 7.20219
Epoch 36 | lr 0.00010 | loss 7.52927
Epoch 36 | lr 0.00010 | loss 6.24914
Epoch 36 | lr 0.00010 | loss 6.59910
Epoch 36 | lr 0.00010 | loss 6.73766
Validation/Test loss: 7.55102
Validation/Test loss: 7.50343
Saved model!

Epoch 37 | lr 0.00010 | loss 7.24002
Epoch 37 | lr 0.00010 | loss 7.18043
Epoch 37 | lr 0.00010 | loss 7.15107
Epoch 37 | lr 0.00010 | loss 7.13163
E

Epoch 66 | lr 0.00001 | loss 6.90704
Epoch 66 | lr 0.00001 | loss 6.65765
Epoch 66 | lr 0.00001 | loss 5.60899
Epoch 66 | lr 0.00001 | loss 6.35242
Epoch 66 | lr 0.00001 | loss 6.84721
Validation/Test loss: 7.38604
Validation/Test loss: 7.36734
Saved model!

Epoch 67 | lr 0.00001 | loss 7.43584
Epoch 67 | lr 0.00001 | loss 6.81624
Epoch 67 | lr 0.00001 | loss 6.83270
Epoch 67 | lr 0.00001 | loss 6.12827
Epoch 67 | lr 0.00001 | loss 6.50078
Validation/Test loss: 7.39354
Validation/Test loss: 7.37368
Epoch 68 | lr 0.00000 | loss 6.79593
Epoch 68 | lr 0.00000 | loss 6.97215
Epoch 68 | lr 0.00000 | loss 6.50284
Epoch 68 | lr 0.00000 | loss 6.70975
Epoch 68 | lr 0.00000 | loss 7.30790
Validation/Test loss: 7.38862
Validation/Test loss: 7.36449
Epoch 69 | lr 0.00000 | loss 6.79025
Epoch 69 | lr 0.00000 | loss 6.53696
Epoch 69 | lr 0.00000 | loss 6.43586
Epoch 69 | lr 0.00000 | loss 7.22675
Epoch 69 | lr 0.00000 | loss 7.07770
Validation/Test loss: 7.38710
Validation/Test loss: 7.36800
Epoch 

Epoch 99 | lr 0.00000 | loss 6.90879
Epoch 99 | lr 0.00000 | loss 6.83108
Validation/Test loss: 7.38381
Validation/Test loss: 7.36738
Epoch 100 | lr 0.00000 | loss 6.27534
Epoch 100 | lr 0.00000 | loss 6.87706
Epoch 100 | lr 0.00000 | loss 5.59217
Epoch 100 | lr 0.00000 | loss 6.66237
Epoch 100 | lr 0.00000 | loss 6.80876
Validation/Test loss: 7.39405
Validation/Test loss: 7.36440
Validation/Test loss: 7.36607


In [30]:
model = torch.load(open(model_name, "rb"))
model.module

TCN(
  (tcn): TemporalConvNet(
    (network): Sequential(
      (0): TemporalBlock(
        (conv1): Conv1d(88, 150, kernel_size=(6,), stride=(1,), padding=(5,))
        (chomp1): Chomp1d()
        (relu1): ReLU()
        (dropout1): Dropout(p=0.25)
        (conv2): Conv1d(150, 150, kernel_size=(6,), stride=(1,), padding=(5,))
        (chomp2): Chomp1d()
        (relu2): ReLU()
        (dropout2): Dropout(p=0.25)
        (net): Sequential(
          (0): Conv1d(88, 150, kernel_size=(6,), stride=(1,), padding=(5,))
          (1): Chomp1d()
          (2): ReLU()
          (3): Dropout(p=0.25)
          (4): Conv1d(150, 150, kernel_size=(6,), stride=(1,), padding=(5,))
          (5): Chomp1d()
          (6): ReLU()
          (7): Dropout(p=0.25)
        )
        (downsample): Conv1d(88, 150, kernel_size=(1,), stride=(1,))
        (relu): ReLU()
      )
      (1): TemporalBlock(
        (conv1): Conv1d(150, 150, kernel_size=(6,), stride=(1,), padding=(10,), dilation=(2,))
        (chomp1)

In [29]:
tloss = evaluate(X_test_)

tensor([[[2.4248e-07, 6.6521e-06, 1.6307e-06,  ..., 1.4699e-12,
          1.2990e-12, 1.4892e-12],
         [1.1001e-06, 5.9317e-06, 1.8735e-06,  ..., 2.5942e-12,
          1.9067e-12, 3.1333e-12],
         [1.5711e-06, 5.4267e-05, 1.8997e-06,  ..., 2.8511e-12,
          2.5467e-12, 3.0656e-12],
         ...,
         [4.9202e-07, 1.0769e-07, 3.4478e-08,  ..., 5.3692e-15,
          1.0886e-14, 8.5966e-15],
         [3.7291e-08, 1.8440e-07, 1.8596e-08,  ..., 4.6915e-14,
          6.6713e-14, 6.5617e-14],
         [1.7144e-06, 4.6529e-07, 2.4021e-07,  ..., 3.7354e-15,
          5.8928e-15, 5.1455e-15]]], device='cuda:0', dtype=torch.float64)
tensor([[[1.8027e-07, 1.1935e-06, 2.9993e-07,  ..., 1.9007e-12,
          3.3156e-12, 2.3901e-12],
         [2.2026e-06, 4.3573e-06, 8.6490e-06,  ..., 5.5595e-12,
          9.6718e-12, 7.2919e-12],
         [9.1781e-08, 2.0032e-07, 1.7952e-07,  ..., 6.4057e-13,
          1.3559e-12, 8.5712e-13],
         ...,
         [1.5706e-12, 4.1756e-09, 1.5639e



tensor([[[4.8085e-08, 1.7467e-06, 1.3539e-06,  ..., 1.4559e-13,
          1.7607e-13, 2.8611e-13],
         [7.7369e-08, 7.2075e-06, 1.0215e-07,  ..., 1.2179e-13,
          1.0421e-13, 2.0993e-13],
         [1.8920e-08, 1.8891e-06, 9.2665e-08,  ..., 2.9795e-15,
          3.9205e-15, 4.7688e-15],
         ...,
         [4.6292e-08, 8.7380e-08, 2.0614e-09,  ..., 6.8397e-16,
          1.4134e-15, 8.3155e-16],
         [2.9411e-07, 8.1486e-07, 1.0862e-07,  ..., 2.5796e-13,
          1.9963e-13, 2.6618e-13],
         [2.0414e-07, 6.5542e-08, 5.5680e-08,  ..., 4.1223e-15,
          8.0421e-15, 8.9531e-15]]], device='cuda:0', dtype=torch.float64)
tensor([[[1.2206e-05, 1.9617e-07, 6.5945e-06,  ..., 7.3913e-15,
          5.2710e-15, 7.1575e-15],
         [3.2302e-06, 4.7433e-06, 8.8365e-06,  ..., 3.1034e-13,
          3.3685e-13, 5.0584e-13],
         [5.1054e-06, 1.4364e-06, 2.4945e-06,  ..., 9.3266e-15,
          9.5586e-15, 1.4468e-14],
         ...,
         [3.6120e-05, 1.5216e-07, 5.2619e

tensor([[[7.4126e-08, 8.5747e-07, 3.4675e-07,  ..., 9.0461e-13,
          2.5662e-12, 1.4897e-12],
         [1.1980e-07, 3.0326e-07, 2.3470e-07,  ..., 1.1711e-13,
          2.1430e-13, 1.3459e-13],
         [1.1555e-05, 1.7077e-05, 1.2229e-05,  ..., 1.3068e-10,
          1.2310e-10, 1.3141e-10],
         ...,
         [8.9174e-08, 1.0251e-06, 7.3345e-08,  ..., 3.3769e-12,
          2.5305e-12, 3.3107e-12],
         [1.9999e-06, 4.9080e-06, 2.4318e-07,  ..., 7.8642e-11,
          8.8705e-11, 8.6019e-11],
         [8.7571e-07, 1.3193e-06, 6.1928e-07,  ..., 1.5912e-11,
          1.6208e-11, 1.9234e-11]]], device='cuda:0', dtype=torch.float64)
tensor([[[3.2573e-06, 2.5247e-06, 6.1333e-05,  ..., 2.5923e-13,
          2.8896e-13, 2.3633e-13],
         [4.2013e-06, 3.2026e-06, 2.9635e-05,  ..., 1.7895e-13,
          2.1281e-13, 1.6376e-13],
         [8.4600e-06, 9.3774e-06, 2.4981e-05,  ..., 1.0664e-14,
          1.4490e-14, 8.9019e-15],
         ...,
         [9.6028e-06, 7.5057e-06, 9.6277e

tensor([[[1.6145e-06, 6.3292e-06, 3.7528e-06,  ..., 3.3386e-11,
          4.2830e-11, 4.6172e-11],
         [4.2188e-04, 1.4067e-04, 7.0493e-03,  ..., 2.3890e-10,
          2.6094e-10, 2.4494e-10],
         [5.8662e-05, 6.3177e-05, 6.8595e-04,  ..., 6.3189e-11,
          8.7242e-11, 7.1364e-11],
         ...,
         [3.9351e-05, 3.7702e-05, 3.9046e-05,  ..., 5.0233e-09,
          7.2396e-09, 5.5582e-09],
         [1.6333e-07, 1.8475e-07, 3.1505e-06,  ..., 1.8450e-15,
          2.5438e-15, 8.8598e-16],
         [2.3409e-07, 1.3213e-07, 2.2534e-06,  ..., 3.3406e-15,
          6.5475e-15, 4.7468e-15]]], device='cuda:0', dtype=torch.float64)
tensor([[[6.8877e-07, 4.1154e-06, 8.1786e-07,  ..., 3.8787e-12,
          4.4160e-12, 4.3468e-12],
         [8.2103e-07, 2.3326e-06, 6.2791e-07,  ..., 3.4939e-13,
          5.1004e-13, 4.7030e-13],
         [1.5510e-06, 3.9677e-06, 6.1700e-07,  ..., 7.9938e-12,
          7.4857e-12, 5.4984e-12],
         ...,
         [6.6020e-06, 1.5758e-05, 2.6444e

tensor([[[3.7984e-06, 3.4920e-05, 1.2132e-05,  ..., 1.7865e-10,
          1.8011e-10, 1.3994e-10],
         [3.4673e-06, 2.0305e-06, 1.5460e-05,  ..., 9.0005e-12,
          1.2717e-11, 1.1729e-11],
         [2.6463e-06, 3.5671e-06, 1.6309e-05,  ..., 1.4351e-11,
          1.8105e-11, 1.8552e-11],
         ...,
         [9.0138e-08, 5.0890e-07, 7.1970e-08,  ..., 1.8346e-11,
          2.8200e-11, 1.5507e-11],
         [8.8602e-08, 1.3981e-07, 5.6371e-08,  ..., 3.0516e-11,
          3.8188e-11, 2.2512e-11],
         [8.0482e-07, 1.3673e-06, 4.7895e-07,  ..., 4.3978e-11,
          7.2045e-11, 4.2485e-11]]], device='cuda:0', dtype=torch.float64)
tensor([[[1.2018e-07, 3.8915e-07, 2.7794e-07,  ..., 1.8531e-13,
          3.0734e-13, 1.9544e-13],
         [1.0167e-06, 2.1184e-06, 1.6118e-06,  ..., 2.2553e-11,
          4.6564e-11, 2.8168e-11],
         [5.0333e-07, 1.1706e-05, 1.8022e-06,  ..., 3.8183e-11,
          4.0836e-11, 3.8552e-11],
         ...,
         [5.9649e-08, 6.4393e-08, 1.1408e

tensor([[[8.0371e-08, 1.4638e-06, 3.7194e-07,  ..., 3.7462e-14,
          2.3868e-14, 5.1091e-14],
         [3.4074e-06, 1.0382e-05, 3.6834e-06,  ..., 2.8657e-14,
          2.1329e-14, 4.1118e-14],
         [3.5591e-05, 9.8823e-05, 6.9074e-05,  ..., 1.7591e-11,
          1.2697e-11, 1.5382e-11],
         ...,
         [4.6820e-08, 2.4651e-07, 5.5289e-08,  ..., 4.9694e-14,
          6.0677e-14, 6.9261e-14],
         [1.5442e-07, 8.6189e-07, 7.4755e-08,  ..., 2.1219e-13,
          1.7062e-13, 1.9631e-13],
         [4.3497e-08, 2.0433e-07, 4.7180e-08,  ..., 1.6172e-14,
          3.2344e-14, 2.8519e-14]]], device='cuda:0', dtype=torch.float64)
tensor([[[3.7696e-05, 5.6363e-05, 1.0917e-04,  ..., 1.4207e-10,
          3.2239e-10, 1.5711e-10],
         [1.7028e-06, 9.4524e-06, 4.9190e-06,  ..., 5.8293e-12,
          9.8284e-12, 5.4707e-12],
         [9.3391e-06, 1.4563e-05, 7.9755e-06,  ..., 2.9921e-11,
          3.0973e-11, 3.0824e-11],
         ...,
         [2.7881e-05, 9.7950e-06, 5.5851e

tensor([[[1.2921e-07, 1.2858e-07, 3.3456e-07,  ..., 3.0369e-15,
          3.4967e-15, 3.6365e-15],
         [1.6464e-07, 3.0623e-08, 1.7678e-07,  ..., 1.7688e-15,
          2.5372e-15, 2.0915e-15],
         [7.6556e-06, 8.7062e-06, 1.1731e-05,  ..., 2.2321e-10,
          3.9716e-10, 3.7970e-10],
         ...,
         [1.2509e-09, 1.8151e-09, 5.6871e-09,  ..., 1.0115e-15,
          2.2665e-15, 6.6443e-16],
         [3.7769e-05, 6.2554e-06, 2.7101e-05,  ..., 6.6065e-10,
          7.8831e-10, 4.9373e-10],
         [3.1826e-08, 2.5024e-08, 4.3987e-08,  ..., 3.7194e-17,
          1.2327e-16, 4.5684e-17]]], device='cuda:0', dtype=torch.float64)
tensor([[[2.1195e-05, 1.9624e-08, 1.2035e-06,  ..., 2.5718e-17,
          1.8700e-17, 3.4579e-17],
         [4.9821e-05, 2.2180e-08, 8.1294e-07,  ..., 2.9259e-17,
          2.0741e-17, 4.9050e-17],
         [8.8306e-05, 8.0548e-08, 1.2761e-06,  ..., 9.2398e-17,
          1.0448e-16, 2.0249e-16],
         ...,
         [4.5156e-06, 2.8859e-06, 3.7329e

tensor([[[1.0702e-06, 8.6254e-07, 8.8779e-07,  ..., 3.7974e-13,
          3.4143e-13, 3.9908e-13],
         [5.6963e-06, 9.9417e-07, 2.2196e-06,  ..., 1.0583e-13,
          1.1077e-13, 1.0213e-13],
         [8.5755e-06, 3.4554e-07, 4.8902e-06,  ..., 4.1234e-14,
          2.6456e-14, 3.1578e-14],
         ...,
         [2.2798e-04, 4.9767e-07, 9.7888e-05,  ..., 8.9267e-15,
          8.7583e-15, 9.1113e-15],
         [1.4983e-04, 2.3392e-06, 2.0597e-04,  ..., 5.3876e-14,
          3.9955e-14, 4.0610e-14],
         [1.4541e-04, 1.5741e-06, 1.5157e-04,  ..., 3.7942e-13,
          4.7748e-13, 5.3966e-13]]], device='cuda:0', dtype=torch.float64)
tensor([[[3.5977e-07, 1.6518e-06, 2.0520e-06,  ..., 6.6633e-12,
          6.1851e-12, 1.0311e-11],
         [4.7479e-07, 9.4410e-07, 3.2411e-06,  ..., 3.9623e-12,
          4.4750e-12, 4.1881e-12],
         [4.5250e-08, 1.7508e-06, 3.1283e-07,  ..., 7.5928e-13,
          1.1509e-12, 9.4369e-13],
         ...,
         [3.5885e-06, 3.7443e-07, 1.0836e

tensor([[[4.3397e-05, 1.8965e-06, 1.9970e-05,  ..., 2.1527e-13,
          2.0374e-13, 3.3702e-13],
         [1.1847e-05, 1.1854e-06, 8.6004e-06,  ..., 2.9470e-14,
          2.7045e-14, 4.3145e-14],
         [6.6244e-05, 6.1320e-06, 2.8070e-05,  ..., 1.1317e-11,
          7.5834e-12, 1.5600e-11],
         ...,
         [1.1995e-06, 3.8083e-08, 2.0729e-06,  ..., 7.0035e-16,
          2.4534e-15, 2.0472e-15],
         [4.9060e-06, 1.1090e-06, 1.1200e-05,  ..., 9.9893e-11,
          9.6646e-11, 1.0140e-10],
         [6.9894e-05, 9.7074e-07, 1.8806e-04,  ..., 2.2463e-14,
          6.9870e-14, 5.9797e-14]]], device='cuda:0', dtype=torch.float64)
tensor([[[9.0215e-06, 2.9370e-06, 9.1288e-06,  ..., 4.2406e-13,
          3.3702e-13, 3.7637e-13],
         [4.2059e-06, 8.3350e-07, 1.7710e-06,  ..., 1.5815e-13,
          1.2119e-13, 1.6183e-13],
         [1.9991e-06, 5.7156e-07, 7.0399e-07,  ..., 3.1867e-14,
          3.5876e-14, 4.5024e-14],
         ...,
         [1.3061e-06, 5.9013e-08, 8.7219e