In [1]:
'''
  code by Tae Hwan Jung(Jeff Jung) @graykode
'''
import numpy as np
import torch
#神經網路
import torch.nn as nn
#各種優化算法
import torch.optim as optim
#自動求導數
from torch.autograd import Variable
import torch.nn.functional as F

dtype = torch.FloatTensor

# Text-CNN Parameter
embedding_size = 2 # 詞向量的維度
sequence_length = 3
num_classes = 2  # 0 or 1
filter_sizes = [2, 2, 2] # n-gram window
num_filters = 3

# 3 words sentences (=sequence_length is 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.

word_list = " ".join(sentences).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}
vocab_size = len(word_dict)
print(vocab_size)

inputs = []
for sen in sentences:
    inputs.append(np.asarray([word_dict[n] for n in sen.split()]))
print(len(inputs))
targets = []
for out in labels:
    targets.append(out) # To using Torch Softmax Loss function

input_batch = Variable(torch.LongTensor(inputs))
print(input_batch.shape)
target_batch = Variable(torch.LongTensor(targets))
print(target_batch)
class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()
        # 總共會有num_filters * len(filter_sizes)個卷積核的輸出
        self.num_filters_total = num_filters * len(filter_sizes)
        # W為我們的嵌入詞典，第一維為詞典的大小，第二維為每個詞的詞向量的嵌入維度。其中每一行為一個詞的詞向量。
        self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1)).type(dtype)
        print(self.W)
        # 最後輸出層的權重，維度為(num_filters_total, num_classes)。
        self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1)).type(dtype)
        # 最後輸出層的bias
        self.Bias = nn.Parameter(0.1 * torch.ones([num_classes])).type(dtype)

    def forward(self, X):
        # 將輸入轉化為詞向量，維度為[batch_size, sequence_length, sequence_length]
        embedded_chars = self.W[X] 
        print(embedded_chars)
        # 添加額外的一個channel(=1)，為了適應Conv2d。維度為[batch, channel(=1), sequence_length, embedding_size]
        embedded_chars = embedded_chars.unsqueeze(1) 
        pooled_outputs = []
        for filter_size in filter_sizes:
            # conv : [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]
            conv = nn.Conv2d(1, num_filters, (filter_size, embedding_size), bias=True)(embedded_chars)
            h = F.relu(conv)
            # mp : ((filter_height, filter_width))
            mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))
            # pooled : [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]
            pooled = mp(h).permute(0, 3, 2, 1)
            pooled_outputs.append(pooled)

        h_pool = torch.cat(pooled_outputs, len(filter_sizes)) # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]
        h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total]) # [batch_size(=6), output_height * output_width * (output_channel * 3)]

        model = torch.mm(h_pool_flat, self.Weight) + self.Bias # [batch_size, num_classes]
        return model

model = TextCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training
for epoch in range(5000):
    optimizer.zero_grad()
    output = model(input_batch)

    # output : [batch_size, num_classes], target_batch : [batch_size] (LongTensor, not one-hot)
    loss = criterion(output, target_batch)
    
    if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))

    loss.backward()
    optimizer.step()

# Test
test_text = 'sorry hate you'
tests = [np.asarray([word_dict[n] for n in test_text.split()])]
test_batch = Variable(torch.LongTensor(tests))

# Predict
predict = model(test_batch).data.max(1, keepdim=True)[1]
if predict[0][0] == 0:
    print(test_text,"is Bad Mean...")
else:
    print(test_text,"is Good Mean!!")

16
6
torch.Size([6, 3])
tensor([1, 1, 1, 0, 0, 0])
Parameter containing:
tensor([[-0.3962, -0.9944],
        [ 0.3783, -0.0988],
        [-0.8973,  0.3972],
        [-0.0756,  0.2531],
        [-0.9290, -0.5933],
        [-0.7448, -0.9448],
        [ 0.9964,  0.6607],
        [ 0.7130,  0.7256],
        [ 0.7278,  0.8840],
        [-0.1993, -0.5791],
        [ 0.6343, -0.8486],
        [ 0.4235, -0.7311],
        [ 0.8845,  0.8907],
        [ 0.3052, -0.4557],
        [ 0.3112, -0.4683],
        [ 0.1069,  0.7204]], requires_grad=True)
tensor([[[-0.7448, -0.9448],
         [ 0.7130,  0.7256],
         [-0.3962, -0.9944]],

        [[ 0.1069,  0.7204],
         [ 0.9964,  0.6607],
         [ 0.3783, -0.0988]],

        [[ 0.4235, -0.7311],
         [-0.1993, -0.5791],
         [-0.9290, -0.5933]],

        [[-0.7448, -0.9448],
         [ 0.3112, -0.4683],
         [-0.3962, -0.9944]],

        [[ 0.6343, -0.8486],
         [ 0.8845,  0.8907],
         [-0.0756,  0.2531]],

        [[-0.

tensor([[[-0.7348, -0.9321],
         [ 0.6775,  0.7033],
         [-0.3687, -0.9715]],

        [[ 0.1350,  0.7027],
         [ 0.9553,  0.6349],
         [ 0.3902, -0.0791]],

        [[ 0.3964, -0.7108],
         [-0.1976, -0.5646],
         [-0.8861, -0.5843]],

        [[-0.7348, -0.9321],
         [ 0.3359, -0.4744],
         [-0.3687, -0.9715]],

        [[ 0.6255, -0.8711],
         [ 0.9064,  0.9105],
         [-0.1076,  0.2528]],

        [[-0.9203,  0.4069],
         [ 0.3107, -0.4792],
         [ 0.7326,  0.9016]]], grad_fn=<IndexBackward>)
tensor([[[-0.7346, -0.9314],
         [ 0.6772,  0.7033],
         [-0.3686, -0.9715]],

        [[ 0.1351,  0.7027],
         [ 0.9551,  0.6346],
         [ 0.3900, -0.0792]],

        [[ 0.3960, -0.7103],
         [-0.1974, -0.5641],
         [-0.8855, -0.5843]],

        [[-0.7346, -0.9314],
         [ 0.3358, -0.4747],
         [-0.3686, -0.9715]],

        [[ 0.6258, -0.8714],
         [ 0.9064,  0.9105],
         [-0.1079,  0.2529]

tensor([[[-0.7194, -0.9151],
         [ 0.6681,  0.6837],
         [-0.3598, -0.9576]],

        [[ 0.1433,  0.6848],
         [ 0.9268,  0.6178],
         [ 0.3940, -0.0708]],

        [[ 0.3818, -0.6904],
         [-0.1879, -0.5498],
         [-0.8572, -0.5701]],

        [[-0.7194, -0.9151],
         [ 0.3427, -0.4722],
         [-0.3598, -0.9576]],

        [[ 0.6027, -0.8812],
         [ 0.9196,  0.9096],
         [-0.1174,  0.2561]],

        [[-0.9349,  0.4002],
         [ 0.3049, -0.4829],
         [ 0.7352,  0.9152]]], grad_fn=<IndexBackward>)
tensor([[[-0.7193, -0.9149],
         [ 0.6681,  0.6833],
         [-0.3598, -0.9575]],

        [[ 0.1434,  0.6847],
         [ 0.9264,  0.6176],
         [ 0.3937, -0.0708]],

        [[ 0.3816, -0.6903],
         [-0.1878, -0.5495],
         [-0.8569, -0.5703]],

        [[-0.7193, -0.9149],
         [ 0.3429, -0.4723],
         [-0.3598, -0.9575]],

        [[ 0.6024, -0.8814],
         [ 0.9201,  0.9097],
         [-0.1174,  0.2564]

tensor([[[-0.7188, -0.9037],
         [ 0.6517,  0.6687],
         [-0.3552, -0.9353]],

        [[ 0.1388,  0.6773],
         [ 0.8929,  0.6061],
         [ 0.3929, -0.0692]],

        [[ 0.3569, -0.6666],
         [-0.1875, -0.5336],
         [-0.8446, -0.5670]],

        [[-0.7188, -0.9037],
         [ 0.3767, -0.4798],
         [-0.3552, -0.9353]],

        [[ 0.6128, -0.9027],
         [ 0.9378,  0.9181],
         [-0.1194,  0.2540]],

        [[-0.9387,  0.3927],
         [ 0.3246, -0.5009],
         [ 0.7416,  0.9186]]], grad_fn=<IndexBackward>)
tensor([[[-0.7187, -0.9037],
         [ 0.6516,  0.6685],
         [-0.3554, -0.9350]],

        [[ 0.1386,  0.6769],
         [ 0.8927,  0.6061],
         [ 0.3926, -0.0690]],

        [[ 0.3567, -0.6666],
         [-0.1874, -0.5335],
         [-0.8444, -0.5670]],

        [[-0.7187, -0.9037],
         [ 0.3767, -0.4796],
         [-0.3554, -0.9350]],

        [[ 0.6130, -0.9027],
         [ 0.9378,  0.9182],
         [-0.1193,  0.2537]

tensor([[[-0.6991, -0.9035],
         [ 0.6398,  0.6454],
         [-0.3537, -0.9296]],

        [[ 0.1390,  0.6614],
         [ 0.8794,  0.5791],
         [ 0.3807, -0.0584]],

        [[ 0.3496, -0.6617],
         [-0.1811, -0.5268],
         [-0.8240, -0.5588]],

        [[-0.6991, -0.9035],
         [ 0.3759, -0.4708],
         [-0.3537, -0.9296]],

        [[ 0.6220, -0.9105],
         [ 0.9348,  0.9274],
         [-0.1286,  0.2525]],

        [[-0.9528,  0.3908],
         [ 0.3247, -0.5032],
         [ 0.7583,  0.9256]]], grad_fn=<IndexBackward>)
tensor([[[-0.6987, -0.9030],
         [ 0.6393,  0.6451],
         [-0.3533, -0.9292]],

        [[ 0.1394,  0.6612],
         [ 0.8791,  0.5787],
         [ 0.3805, -0.0583]],

        [[ 0.3494, -0.6616],
         [-0.1812, -0.5267],
         [-0.8238, -0.5586]],

        [[-0.6987, -0.9030],
         [ 0.3761, -0.4705],
         [-0.3533, -0.9292]],

        [[ 0.6217, -0.9105],
         [ 0.9351,  0.9278],
         [-0.1288,  0.2523]

tensor([[[-0.6884, -0.8874],
         [ 0.6214,  0.6109],
         [-0.3437, -0.9147]],

        [[ 0.1530,  0.6462],
         [ 0.8527,  0.5554],
         [ 0.3917, -0.0625]],

        [[ 0.3371, -0.6544],
         [-0.1846, -0.5268],
         [-0.7931, -0.5530]],

        [[-0.6884, -0.8874],
         [ 0.3936, -0.4554],
         [-0.3437, -0.9147]],

        [[ 0.6189, -0.9228],
         [ 0.9404,  0.9503],
         [-0.1541,  0.2497]],

        [[-0.9587,  0.4045],
         [ 0.3241, -0.5121],
         [ 0.7565,  0.9398]]], grad_fn=<IndexBackward>)
tensor([[[-0.6883, -0.8872],
         [ 0.6214,  0.6107],
         [-0.3435, -0.9145]],

        [[ 0.1532,  0.6460],
         [ 0.8526,  0.5553],
         [ 0.3916, -0.0624]],

        [[ 0.3371, -0.6542],
         [-0.1845, -0.5264],
         [-0.7928, -0.5529]],

        [[-0.6883, -0.8872],
         [ 0.3935, -0.4553],
         [-0.3435, -0.9145]],

        [[ 0.6186, -0.9231],
         [ 0.9406,  0.9502],
         [-0.1542,  0.2498]

tensor([[[-0.6819, -0.8603],
         [ 0.6098,  0.6011],
         [-0.3337, -0.8920]],

        [[ 0.1609,  0.6474],
         [ 0.8273,  0.5511],
         [ 0.3886, -0.0527]],

        [[ 0.3208, -0.6223],
         [-0.1856, -0.5023],
         [-0.7698, -0.5403]],

        [[-0.6819, -0.8603],
         [ 0.4082, -0.4622],
         [-0.3337, -0.8920]],

        [[ 0.6211, -0.9634],
         [ 0.9479,  0.9500],
         [-0.1609,  0.2479]],

        [[-0.9717,  0.3966],
         [ 0.3313, -0.5436],
         [ 0.7621,  0.9482]]], grad_fn=<IndexBackward>)
tensor([[[-0.6817, -0.8602],
         [ 0.6094,  0.6007],
         [-0.3336, -0.8918]],

        [[ 0.1611,  0.6475],
         [ 0.8270,  0.5509],
         [ 0.3886, -0.0526]],

        [[ 0.3205, -0.6223],
         [-0.1856, -0.5018],
         [-0.7695, -0.5404]],

        [[-0.6817, -0.8602],
         [ 0.4085, -0.4617],
         [-0.3336, -0.8918]],

        [[ 0.6210, -0.9637],
         [ 0.9482,  0.9503],
         [-0.1612,  0.2476]

tensor([[[-0.6791, -0.8352],
         [ 0.5945,  0.5676],
         [-0.3395, -0.8654]],

        [[ 0.1761,  0.6322],
         [ 0.8004,  0.5268],
         [ 0.3946, -0.0406]],

        [[ 0.3099, -0.6227],
         [-0.1769, -0.4912],
         [-0.7381, -0.5275]],

        [[-0.6791, -0.8352],
         [ 0.4253, -0.4563],
         [-0.3395, -0.8654]],

        [[ 0.6202, -0.9705],
         [ 0.9601,  0.9795],
         [-0.1691,  0.2325]],

        [[-0.9952,  0.4196],
         [ 0.3298, -0.5665],
         [ 0.7761,  0.9606]]], grad_fn=<IndexBackward>)
tensor([[[-0.6791, -0.8346],
         [ 0.5945,  0.5673],
         [-0.3395, -0.8649]],

        [[ 0.1764,  0.6320],
         [ 0.8001,  0.5267],
         [ 0.3948, -0.0406]],

        [[ 0.3100, -0.6226],
         [-0.1767, -0.4914],
         [-0.7376, -0.5273]],

        [[-0.6791, -0.8346],
         [ 0.4253, -0.4564],
         [-0.3395, -0.8649]],

        [[ 0.6203, -0.9707],
         [ 0.9601,  0.9797],
         [-0.1695,  0.2325]

tensor([[[-0.6896, -0.8168],
         [ 0.5799,  0.5504],
         [-0.3397, -0.8487]],

        [[ 0.1859,  0.6058],
         [ 0.7762,  0.5171],
         [ 0.3886, -0.0257]],

        [[ 0.2806, -0.6032],
         [-0.1710, -0.4866],
         [-0.7172, -0.5154]],

        [[-0.6896, -0.8168],
         [ 0.4474, -0.4668],
         [-0.3397, -0.8487]],

        [[ 0.6370, -0.9959],
         [ 0.9676,  0.9993],
         [-0.1819,  0.2402]],

        [[-1.0213,  0.4241],
         [ 0.3459, -0.5862],
         [ 0.7827,  0.9752]]], grad_fn=<IndexBackward>)
tensor([[[-0.6897, -0.8165],
         [ 0.5792,  0.5501],
         [-0.3397, -0.8482]],

        [[ 0.1861,  0.6057],
         [ 0.7756,  0.5168],
         [ 0.3886, -0.0253]],

        [[ 0.2804, -0.6032],
         [-0.1709, -0.4865],
         [-0.7171, -0.5153]],

        [[-0.6897, -0.8165],
         [ 0.4480, -0.4669],
         [-0.3397, -0.8482]],

        [[ 0.6368, -0.9962],
         [ 0.9681,  0.9998],
         [-0.1822,  0.2404]

tensor([[[-0.6743, -0.7906],
         [ 0.5555,  0.5121],
         [-0.3307, -0.8386]],

        [[ 0.2101,  0.5939],
         [ 0.7407,  0.4894],
         [ 0.3960, -0.0080]],

        [[ 0.2746, -0.5968],
         [-0.1697, -0.4819],
         [-0.6828, -0.5114]],

        [[-0.6743, -0.7906],
         [ 0.4732, -0.4504],
         [-0.3307, -0.8386]],

        [[ 0.6174, -1.0177],
         [ 0.9939,  1.0332],
         [-0.2030,  0.2431]],

        [[-1.0425,  0.4415],
         [ 0.3444, -0.6106],
         [ 0.7855,  1.0044]]], grad_fn=<IndexBackward>)
tensor([[[-0.6742, -0.7903],
         [ 0.5551,  0.5116],
         [-0.3305, -0.8384]],

        [[ 0.2102,  0.5938],
         [ 0.7403,  0.4888],
         [ 0.3960, -0.0079]],

        [[ 0.2745, -0.5965],
         [-0.1699, -0.4821],
         [-0.6823, -0.5114]],

        [[-0.6742, -0.7903],
         [ 0.4737, -0.4501],
         [-0.3305, -0.8384]],

        [[ 0.6173, -1.0181],
         [ 0.9944,  1.0337],
         [-0.2033,  0.2432]

tensor([[[-0.6667, -0.7677],
         [ 0.5156,  0.4786],
         [-0.3274, -0.8189]],

        [[ 0.2297,  0.5760],
         [ 0.6984,  0.4589],
         [ 0.3979,  0.0040]],

        [[ 0.2512, -0.5705],
         [-0.1839, -0.4843],
         [-0.6595, -0.5159]],

        [[-0.6667, -0.7677],
         [ 0.5207, -0.4399],
         [-0.3274, -0.8189]],

        [[ 0.6159, -1.0533],
         [ 1.0239,  1.0653],
         [-0.2182,  0.2499]],

        [[-1.0695,  0.4531],
         [ 0.3664, -0.6314],
         [ 0.8034,  1.0364]]], grad_fn=<IndexBackward>)
tensor([[[-0.6665, -0.7674],
         [ 0.5151,  0.4782],
         [-0.3275, -0.8189]],

        [[ 0.2299,  0.5756],
         [ 0.6979,  0.4585],
         [ 0.3979,  0.0042]],

        [[ 0.2508, -0.5703],
         [-0.1840, -0.4843],
         [-0.6592, -0.5159]],

        [[-0.6665, -0.7674],
         [ 0.5213, -0.4400],
         [-0.3275, -0.8189]],

        [[ 0.6158, -1.0538],
         [ 1.0243,  1.0657],
         [-0.2185,  0.2501]

tensor([[[-0.6758, -0.7415],
         [ 0.4926,  0.4388],
         [-0.3280, -0.7945]],

        [[ 0.2464,  0.5494],
         [ 0.6580,  0.4275],
         [ 0.4072,  0.0177]],

        [[ 0.2221, -0.5532],
         [-0.1810, -0.4796],
         [-0.6210, -0.5041]],

        [[-0.6758, -0.7415],
         [ 0.5602, -0.4395],
         [-0.3280, -0.7945]],

        [[ 0.6272, -1.0952],
         [ 1.0345,  1.1135],
         [-0.2296,  0.2397]],

        [[-1.0937,  0.4774],
         [ 0.3773, -0.6702],
         [ 0.8140,  1.0711]]], grad_fn=<IndexBackward>)
tensor([[[-0.6760, -0.7412],
         [ 0.4925,  0.4384],
         [-0.3278, -0.7943]],

        [[ 0.2465,  0.5493],
         [ 0.6575,  0.4274],
         [ 0.4069,  0.0178]],

        [[ 0.2217, -0.5529],
         [-0.1808, -0.4794],
         [-0.6206, -0.5039]],

        [[-0.6760, -0.7412],
         [ 0.5604, -0.4397],
         [-0.3278, -0.7943]],

        [[ 0.6275, -1.0959],
         [ 1.0345,  1.1140],
         [-0.2297,  0.2397]

tensor([[[-0.6761, -0.7144],
         [ 0.4676,  0.3940],
         [-0.3311, -0.7644]],

        [[ 0.2610,  0.5225],
         [ 0.6183,  0.3920],
         [ 0.4121,  0.0495]],

        [[ 0.1921, -0.5404],
         [-0.1677, -0.4710],
         [-0.5730, -0.4821]],

        [[-0.6761, -0.7144],
         [ 0.5981, -0.4371],
         [-0.3311, -0.7644]],

        [[ 0.6311, -1.1317],
         [ 1.0567,  1.1658],
         [-0.2528,  0.2339]],

        [[-1.1269,  0.5013],
         [ 0.3866, -0.7041],
         [ 0.8237,  1.1091]]], grad_fn=<IndexBackward>)
tensor([[[-0.6761, -0.7143],
         [ 0.4672,  0.3935],
         [-0.3314, -0.7641]],

        [[ 0.2615,  0.5218],
         [ 0.6179,  0.3915],
         [ 0.4118,  0.0501]],

        [[ 0.1913, -0.5402],
         [-0.1674, -0.4708],
         [-0.5725, -0.4822]],

        [[-0.6761, -0.7143],
         [ 0.5986, -0.4369],
         [-0.3314, -0.7641]],

        [[ 0.6313, -1.1321],
         [ 1.0568,  1.1665],
         [-0.2531,  0.2341]

tensor([[[-0.6734, -0.7025],
         [ 0.4170,  0.3595],
         [-0.3291, -0.7384]],

        [[ 0.2868,  0.4869],
         [ 0.5647,  0.3582],
         [ 0.4060,  0.0800]],

        [[ 0.1526, -0.5202],
         [-0.1711, -0.4605],
         [-0.5324, -0.4708]],

        [[-0.6734, -0.7025],
         [ 0.6614, -0.4423],
         [-0.3291, -0.7384]],

        [[ 0.6378, -1.1639],
         [ 1.0899,  1.2211],
         [-0.2752,  0.2368]],

        [[-1.1563,  0.5296],
         [ 0.4064, -0.7365],
         [ 0.8443,  1.1467]]], grad_fn=<IndexBackward>)
tensor([[[-0.6734, -0.7018],
         [ 0.4162,  0.3590],
         [-0.3296, -0.7372]],

        [[ 0.2872,  0.4863],
         [ 0.5636,  0.3582],
         [ 0.4058,  0.0807]],

        [[ 0.1515, -0.5198],
         [-0.1711, -0.4600],
         [-0.5320, -0.4706]],

        [[-0.6734, -0.7018],
         [ 0.6622, -0.4425],
         [-0.3296, -0.7372]],

        [[ 0.6380, -1.1644],
         [ 1.0903,  1.2217],
         [-0.2754,  0.2365]

tensor([[[-0.6705, -0.6591],
         [ 0.3805,  0.3191],
         [-0.3321, -0.6975]],

        [[ 0.3045,  0.4565],
         [ 0.5112,  0.3319],
         [ 0.4116,  0.1202]],

        [[ 0.1127, -0.5072],
         [-0.1608, -0.4430],
         [-0.4867, -0.4484]],

        [[-0.6705, -0.6591],
         [ 0.7196, -0.4472],
         [-0.3321, -0.6975]],

        [[ 0.6422, -1.2019],
         [ 1.1207,  1.2808],
         [-0.2836,  0.2237]],

        [[-1.1926,  0.5629],
         [ 0.4286, -0.7781],
         [ 0.8577,  1.1828]]], grad_fn=<IndexBackward>)
tensor([[[-0.6705, -0.6584],
         [ 0.3801,  0.3187],
         [-0.3316, -0.6972]],

        [[ 0.3049,  0.4562],
         [ 0.5107,  0.3319],
         [ 0.4117,  0.1207]],

        [[ 0.1121, -0.5068],
         [-0.1604, -0.4427],
         [-0.4860, -0.4482]],

        [[-0.6705, -0.6584],
         [ 0.7205, -0.4477],
         [-0.3316, -0.6972]],

        [[ 0.6425, -1.2025],
         [ 1.1210,  1.2816],
         [-0.2839,  0.2234]

tensor([[[-0.6658, -0.6343],
         [ 0.3336,  0.2777],
         [-0.3200, -0.6568]],

        [[ 0.3261,  0.4196],
         [ 0.4594,  0.3128],
         [ 0.4152,  0.1537]],

        [[ 0.0722, -0.4800],
         [-0.1623, -0.4218],
         [-0.4401, -0.4305]],

        [[-0.6658, -0.6343],
         [ 0.7961, -0.4608],
         [-0.3200, -0.6568]],

        [[ 0.6440, -1.2436],
         [ 1.1598,  1.3342],
         [-0.3036,  0.2052]],

        [[-1.2401,  0.5885],
         [ 0.4619, -0.8285],
         [ 0.8709,  1.2220]]], grad_fn=<IndexBackward>)
tensor([[[-0.6656, -0.6339],
         [ 0.3329,  0.2772],
         [-0.3200, -0.6563]],

        [[ 0.3263,  0.4193],
         [ 0.4586,  0.3128],
         [ 0.4154,  0.1539]],

        [[ 0.0718, -0.4796],
         [-0.1626, -0.4214],
         [-0.4395, -0.4305]],

        [[-0.6656, -0.6339],
         [ 0.7971, -0.4611],
         [-0.3200, -0.6563]],

        [[ 0.6439, -1.2443],
         [ 1.1603,  1.3347],
         [-0.3038,  0.2051]

tensor([[[-0.6520, -0.5877],
         [ 0.2861,  0.2219],
         [-0.3152, -0.6253]],

        [[ 0.3424,  0.3816],
         [ 0.4109,  0.2859],
         [ 0.4144,  0.1849]],

        [[ 0.0388, -0.4452],
         [-0.1663, -0.4088],
         [-0.3801, -0.4110]],

        [[-0.6520, -0.5877],
         [ 0.8774, -0.4657],
         [-0.3152, -0.6253]],

        [[ 0.6571, -1.2935],
         [ 1.1967,  1.3878],
         [-0.3239,  0.1921]],

        [[-1.2896,  0.6307],
         [ 0.4864, -0.8812],
         [ 0.8900,  1.2642]]], grad_fn=<IndexBackward>)
tensor([[[-0.6520, -0.5870],
         [ 0.2853,  0.2208],
         [-0.3150, -0.6245]],

        [[ 0.3423,  0.3808],
         [ 0.4102,  0.2855],
         [ 0.4145,  0.1853]],

        [[ 0.0383, -0.4448],
         [-0.1664, -0.4090],
         [-0.3794, -0.4109]],

        [[-0.6520, -0.5870],
         [ 0.8785, -0.4655],
         [-0.3150, -0.6245]],

        [[ 0.6572, -1.2940],
         [ 1.1971,  1.3886],
         [-0.3243,  0.1920]

tensor([[[-0.6338, -0.5260],
         [ 0.2437,  0.1632],
         [-0.3183, -0.5636]],

        [[ 0.3518,  0.3350],
         [ 0.3806,  0.2725],
         [ 0.3918,  0.2119]],

        [[-0.0066, -0.4136],
         [-0.1541, -0.3775],
         [-0.3187, -0.4053]],

        [[-0.6338, -0.5260],
         [ 0.9647, -0.4865],
         [-0.3183, -0.5636]],

        [[ 0.6504, -1.3500],
         [ 1.2239,  1.4506],
         [-0.3509,  0.1766]],

        [[-1.3393,  0.6657],
         [ 0.5057, -0.9375],
         [ 0.9007,  1.3136]]], grad_fn=<IndexBackward>)
tensor([[[-0.6338, -0.5251],
         [ 0.2430,  0.1625],
         [-0.3185, -0.5629]],

        [[ 0.3516,  0.3343],
         [ 0.3804,  0.2724],
         [ 0.3912,  0.2122]],

        [[-0.0069, -0.4134],
         [-0.1545, -0.3773],
         [-0.3179, -0.4050]],

        [[-0.6338, -0.5251],
         [ 0.9659, -0.4868],
         [-0.3185, -0.5629]],

        [[ 0.6504, -1.3506],
         [ 1.2244,  1.4513],
         [-0.3511,  0.1764]

tensor([[[-0.6132, -0.4704],
         [ 0.1908,  0.1258],
         [-0.3121, -0.5120]],

        [[ 0.3459,  0.2850],
         [ 0.3555,  0.2605],
         [ 0.3618,  0.2387]],

        [[-0.0437, -0.3760],
         [-0.1478, -0.3512],
         [-0.2597, -0.3743]],

        [[-0.6132, -0.4704],
         [ 1.0634, -0.5164],
         [-0.3121, -0.5120]],

        [[ 0.6523, -1.4054],
         [ 1.2761,  1.5035],
         [-0.3754,  0.1742]],

        [[-1.3994,  0.6893],
         [ 0.5354, -1.0018],
         [ 0.9257,  1.3618]]], grad_fn=<IndexBackward>)
tensor([[[-0.6126, -0.4696],
         [ 0.1900,  0.1257],
         [-0.3117, -0.5113]],

        [[ 0.3460,  0.2844],
         [ 0.3550,  0.2605],
         [ 0.3612,  0.2393]],

        [[-0.0443, -0.3752],
         [-0.1480, -0.3505],
         [-0.2587, -0.3737]],

        [[-0.6126, -0.4696],
         [ 1.0645, -0.5167],
         [-0.3117, -0.5113]],

        [[ 0.6525, -1.4061],
         [ 1.2763,  1.5041],
         [-0.3757,  0.1740]

tensor([[[-0.5835, -0.4272],
         [ 0.1448,  0.0916],
         [-0.2875, -0.4576]],

        [[ 0.3437,  0.2526],
         [ 0.3425,  0.2530],
         [ 0.3421,  0.2527]],

        [[-0.0862, -0.3416],
         [-0.1335, -0.3336],
         [-0.1967, -0.3384]],

        [[-0.5835, -0.4272],
         [ 1.1450, -0.5287],
         [-0.2875, -0.4576]],

        [[ 0.6463, -1.4572],
         [ 1.3019,  1.5534],
         [-0.4072,  0.1568]],

        [[-1.4448,  0.7152],
         [ 0.5474, -1.0599],
         [ 0.9336,  1.4077]]], grad_fn=<IndexBackward>)
tensor([[[-0.5832, -0.4268],
         [ 0.1440,  0.0911],
         [-0.2874, -0.4572]],

        [[ 0.3434,  0.2520],
         [ 0.3420,  0.2527],
         [ 0.3419,  0.2525]],

        [[-0.0866, -0.3414],
         [-0.1336, -0.3335],
         [-0.1962, -0.3384]],

        [[-0.5832, -0.4268],
         [ 1.1465, -0.5286],
         [-0.2874, -0.4572]],

        [[ 0.6462, -1.4576],
         [ 1.3025,  1.5541],
         [-0.4077,  0.1567]

tensor([[[-0.5636, -0.3729],
         [ 0.0815,  0.0438],
         [-0.2846, -0.4090]],

        [[ 0.3075,  0.2260],
         [ 0.3060,  0.2255],
         [ 0.3081,  0.2252]],

        [[-0.1339, -0.3247],
         [-0.1410, -0.3213],
         [-0.1538, -0.3206]],

        [[-0.5636, -0.3729],
         [ 1.2630, -0.5478],
         [-0.2846, -0.4090]],

        [[ 0.6440, -1.4978],
         [ 1.3508,  1.6097],
         [-0.4259,  0.1468]],

        [[-1.4874,  0.7467],
         [ 0.5823, -1.1104],
         [ 0.9647,  1.4637]]], grad_fn=<IndexBackward>)
tensor([[[-0.5635, -0.3724],
         [ 0.0807,  0.0437],
         [-0.2844, -0.4086]],

        [[ 0.3072,  0.2260],
         [ 0.3057,  0.2250],
         [ 0.3079,  0.2254]],

        [[-0.1346, -0.3241],
         [-0.1410, -0.3211],
         [-0.1531, -0.3202]],

        [[-0.5635, -0.3724],
         [ 1.2643, -0.5485],
         [-0.2844, -0.4086]],

        [[ 0.6445, -1.4986],
         [ 1.3512,  1.6102],
         [-0.4262,  0.1467]

tensor([[[-0.5200, -0.3309],
         [ 0.0197,  0.0035],
         [-0.2760, -0.3555]],

        [[ 0.2837,  0.2139],
         [ 0.2839,  0.2137],
         [ 0.2839,  0.2127]],

        [[-0.1369, -0.2995],
         [-0.1356, -0.3001],
         [-0.1368, -0.2994]],

        [[-0.5200, -0.3309],
         [ 1.3656, -0.5839],
         [-0.2760, -0.3555]],

        [[ 0.6376, -1.5499],
         [ 1.3940,  1.6741],
         [-0.4512,  0.1241]],

        [[-1.5329,  0.7783],
         [ 0.6073, -1.1718],
         [ 0.9861,  1.5142]]], grad_fn=<IndexBackward>)
tensor([[[-0.5192, -0.3298],
         [ 0.0195,  0.0030],
         [-0.2760, -0.3549]],

        [[ 0.2838,  0.2136],
         [ 0.2837,  0.2135],
         [ 0.2835,  0.2128]],

        [[-0.1365, -0.2996],
         [-0.1354, -0.2994],
         [-0.1365, -0.2991]],

        [[-0.5192, -0.3298],
         [ 1.3667, -0.5847],
         [-0.2760, -0.3549]],

        [[ 0.6377, -1.5506],
         [ 1.3942,  1.6751],
         [-0.4516,  0.1237]

tensor([[[-0.4847, -0.2712],
         [-0.0197, -0.0345],
         [-0.2737, -0.3125]],

        [[ 0.2672,  0.1956],
         [ 0.2683,  0.1966],
         [ 0.2677,  0.1977]],

        [[-0.1195, -0.2803],
         [-0.1197, -0.2787],
         [-0.1186, -0.2783]],

        [[-0.4847, -0.2712],
         [ 1.4667, -0.6087],
         [-0.2737, -0.3125]],

        [[ 0.6334, -1.5933],
         [ 1.4309,  1.7391],
         [-0.4669,  0.1022]],

        [[-1.5769,  0.8178],
         [ 0.6320, -1.2202],
         [ 1.0075,  1.5646]]], grad_fn=<IndexBackward>)
tensor([[[-0.4844, -0.2706],
         [-0.0205, -0.0345],
         [-0.2736, -0.3118]],

        [[ 0.2671,  0.1957],
         [ 0.2677,  0.1967],
         [ 0.2674,  0.1977]],

        [[-0.1196, -0.2795],
         [-0.1197, -0.2784],
         [-0.1186, -0.2778]],

        [[-0.4844, -0.2706],
         [ 1.4678, -0.6091],
         [-0.2736, -0.3118]],

        [[ 0.6332, -1.5939],
         [ 1.4314,  1.7396],
         [-0.4671,  0.1022]

tensor([[[-0.4377, -0.2366],
         [-0.0744, -0.0633],
         [-0.2694, -0.2678]],

        [[ 0.2513,  0.1809],
         [ 0.2505,  0.1807],
         [ 0.2487,  0.1818]],

        [[-0.1110, -0.2616],
         [-0.1111, -0.2624],
         [-0.1104, -0.2612]],

        [[-0.4377, -0.2366],
         [ 1.5569, -0.6227],
         [-0.2694, -0.2678]],

        [[ 0.6207, -1.6302],
         [ 1.4702,  1.7898],
         [-0.4904,  0.0970]],

        [[-1.6162,  0.8569],
         [ 0.6562, -1.2692],
         [ 1.0255,  1.6109]]], grad_fn=<IndexBackward>)
tensor([[[-0.4371, -0.2360],
         [-0.0747, -0.0637],
         [-0.2694, -0.2672]],

        [[ 0.2511,  0.1807],
         [ 0.2503,  0.1804],
         [ 0.2485,  0.1818]],

        [[-0.1108, -0.2617],
         [-0.1110, -0.2618],
         [-0.1102, -0.2610]],

        [[-0.4371, -0.2360],
         [ 1.5580, -0.6233],
         [-0.2694, -0.2672]],

        [[ 0.6203, -1.6307],
         [ 1.4706,  1.7905],
         [-0.4906,  0.0969]

tensor([[[-0.3937, -0.1900],
         [-0.1207, -0.0858],
         [-0.2668, -0.2190]],

        [[ 0.2254,  0.1719],
         [ 0.2262,  0.1741],
         [ 0.2264,  0.1728]],

        [[-0.1083, -0.2397],
         [-0.1096, -0.2404],
         [-0.1102, -0.2382]],

        [[-0.3937, -0.1900],
         [ 1.6305, -0.6510],
         [-0.2668, -0.2190]],

        [[ 0.6130, -1.6659],
         [ 1.5009,  1.8337],
         [-0.5086,  0.0856]],

        [[-1.6514,  0.8780],
         [ 0.6668, -1.3102],
         [ 1.0399,  1.6351]]], grad_fn=<IndexBackward>)
tensor([[[-0.3933, -0.1895],
         [-0.1212, -0.0860],
         [-0.2666, -0.2186]],

        [[ 0.2251,  0.1717],
         [ 0.2260,  0.1741],
         [ 0.2266,  0.1725]],

        [[-0.1080, -0.2394],
         [-0.1095, -0.2401],
         [-0.1101, -0.2382]],

        [[-0.3933, -0.1895],
         [ 1.6316, -0.6514],
         [-0.2666, -0.2186]],

        [[ 0.6131, -1.6665],
         [ 1.5015,  1.8344],
         [-0.5088,  0.0851]

tensor([[[-0.3558, -0.1515],
         [-0.1680, -0.1020],
         [-0.2562, -0.1756]],

        [[ 0.2126,  0.1654],
         [ 0.2117,  0.1658],
         [ 0.2113,  0.1640]],

        [[-0.1010, -0.2204],
         [-0.1023, -0.2208],
         [-0.1035, -0.2209]],

        [[-0.3558, -0.1515],
         [ 1.7131, -0.6900],
         [-0.2562, -0.1756]],

        [[ 0.6099, -1.7010],
         [ 1.5365,  1.8819],
         [-0.5301,  0.0694]],

        [[-1.6902,  0.9007],
         [ 0.6929, -1.3585],
         [ 1.0546,  1.6740]]], grad_fn=<IndexBackward>)
tensor([[[-0.3558, -0.1509],
         [-0.1684, -0.1024],
         [-0.2564, -0.1750]],

        [[ 0.2126,  0.1653],
         [ 0.2116,  0.1655],
         [ 0.2111,  0.1639]],

        [[-0.1007, -0.2206],
         [-0.1022, -0.2204],
         [-0.1033, -0.2207]],

        [[-0.3558, -0.1509],
         [ 1.7142, -0.6905],
         [-0.2564, -0.1750]],

        [[ 0.6097, -1.7015],
         [ 1.5371,  1.8825],
         [-0.5303,  0.0692]

tensor([[[-0.3227, -0.1285],
         [-0.2067, -0.1082],
         [-0.2620, -0.1414]],

        [[ 0.1969,  0.1528],
         [ 0.1957,  0.1535],
         [ 0.1959,  0.1536]],

        [[-0.0959, -0.2034],
         [-0.0967, -0.2031],
         [-0.0970, -0.2018]],

        [[-0.3227, -0.1285],
         [ 1.7939, -0.7314],
         [-0.2620, -0.1414]],

        [[ 0.6133, -1.7462],
         [ 1.5648,  1.9300],
         [-0.5474,  0.0599]],

        [[-1.7279,  0.9321],
         [ 0.7225, -1.4084],
         [ 1.0648,  1.7109]]], grad_fn=<IndexBackward>)
tensor([[[-0.3221, -0.1282],
         [-0.2072, -0.1085],
         [-0.2619, -0.1412]],

        [[ 0.1969,  0.1525],
         [ 0.1955,  0.1531],
         [ 0.1958,  0.1536]],

        [[-0.0960, -0.2036],
         [-0.0965, -0.2031],
         [-0.0970, -0.2017]],

        [[-0.3221, -0.1282],
         [ 1.7952, -0.7318],
         [-0.2619, -0.1412]],

        [[ 0.6131, -1.7468],
         [ 1.5652,  1.9309],
         [-0.5477,  0.0598]

tensor([[[-0.2800, -0.1140],
         [-0.2520, -0.1154],
         [-0.2595, -0.1188]],

        [[ 0.1837,  0.1354],
         [ 0.1832,  0.1343],
         [ 0.1832,  0.1358]],

        [[-0.0899, -0.1997],
         [-0.0889, -0.1993],
         [-0.0904, -0.1991]],

        [[-0.2800, -0.1140],
         [ 1.8784, -0.7554],
         [-0.2595, -0.1188]],

        [[ 0.6144, -1.7780],
         [ 1.5936,  1.9700],
         [-0.5695,  0.0531]],

        [[-1.7628,  0.9668],
         [ 0.7466, -1.4461],
         [ 1.0764,  1.7514]]], grad_fn=<IndexBackward>)
tensor([[[-0.2793, -0.1140],
         [-0.2526, -0.1154],
         [-0.2597, -0.1186]],

        [[ 0.1835,  0.1352],
         [ 0.1830,  0.1340],
         [ 0.1829,  0.1355]],

        [[-0.0899, -0.1995],
         [-0.0889, -0.1996],
         [-0.0904, -0.1988]],

        [[-0.2793, -0.1140],
         [ 1.8794, -0.7558],
         [-0.2597, -0.1186]],

        [[ 0.6146, -1.7785],
         [ 1.5937,  1.9706],
         [-0.5696,  0.0532]

tensor([[[-0.2692, -0.1081],
         [-0.2688, -0.1072],
         [-0.2693, -0.1069]],

        [[ 0.1696,  0.1248],
         [ 0.1694,  0.1256],
         [ 0.1694,  0.1250]],

        [[-0.0844, -0.1835],
         [-0.0838, -0.1834],
         [-0.0821, -0.1828]],

        [[-0.2692, -0.1081],
         [ 1.9541, -0.7792],
         [-0.2693, -0.1069]],

        [[ 0.6022, -1.8141],
         [ 1.6134,  2.0124],
         [-0.5902,  0.0451]],

        [[-1.7982,  0.9924],
         [ 0.7622, -1.4880],
         [ 1.0873,  1.7897]]], grad_fn=<IndexBackward>)
tensor([[[-0.2696, -0.1078],
         [-0.2688, -0.1072],
         [-0.2693, -0.1070]],

        [[ 0.1694,  0.1247],
         [ 0.1692,  0.1254],
         [ 0.1696,  0.1249]],

        [[-0.0844, -0.1834],
         [-0.0835, -0.1832],
         [-0.0821, -0.1827]],

        [[-0.2696, -0.1078],
         [ 1.9549, -0.7796],
         [-0.2693, -0.1070]],

        [[ 0.6020, -1.8149],
         [ 1.6138,  2.0130],
         [-0.5902,  0.0452]

tensor([[[-0.2724, -0.0962],
         [-0.2724, -0.0960],
         [-0.2722, -0.0957]],

        [[ 0.1547,  0.1174],
         [ 0.1544,  0.1170],
         [ 0.1542,  0.1167]],

        [[-0.0813, -0.1702],
         [-0.0813, -0.1702],
         [-0.0807, -0.1695]],

        [[-0.2724, -0.0962],
         [ 2.0208, -0.8049],
         [-0.2722, -0.0957]],

        [[ 0.6045, -1.8449],
         [ 1.6387,  2.0415],
         [-0.5992,  0.0370]],

        [[-1.8230,  1.0210],
         [ 0.7786, -1.5210],
         [ 1.1044,  1.8216]]], grad_fn=<IndexBackward>)
tensor([[[-0.2725, -0.0961],
         [-0.2724, -0.0956],
         [-0.2724, -0.0955]],

        [[ 0.1542,  0.1169],
         [ 0.1540,  0.1173],
         [ 0.1542,  0.1166]],

        [[-0.0815, -0.1699],
         [-0.0811, -0.1702],
         [-0.0805, -0.1690]],

        [[-0.2725, -0.0961],
         [ 2.0217, -0.8049],
         [-0.2724, -0.0955]],

        [[ 0.6046, -1.8452],
         [ 1.6389,  2.0419],
         [-0.5994,  0.0369]

tensor([[[-0.2764, -0.0851],
         [-0.2743, -0.0838],
         [-0.2747, -0.0828]],

        [[ 0.1484,  0.1069],
         [ 0.1489,  0.1079],
         [ 0.1487,  0.1086]],

        [[-0.0687, -0.1599],
         [-0.0683, -0.1590],
         [-0.0683, -0.1581]],

        [[-0.2764, -0.0851],
         [ 2.0887, -0.8178],
         [-0.2747, -0.0828]],

        [[ 0.5969, -1.8766],
         [ 1.6514,  2.0859],
         [-0.6156,  0.0216]],

        [[-1.8457,  1.0497],
         [ 0.7846, -1.5560],
         [ 1.1062,  1.8635]]], grad_fn=<IndexBackward>)
tensor([[[-0.2767, -0.0850],
         [-0.2745, -0.0838],
         [-0.2748, -0.0830]],

        [[ 0.1483,  0.1068],
         [ 0.1488,  0.1074],
         [ 0.1487,  0.1084]],

        [[-0.0686, -0.1599],
         [-0.0682, -0.1590],
         [-0.0681, -0.1583]],

        [[-0.2767, -0.0850],
         [ 2.0897, -0.8178],
         [-0.2748, -0.0830]],

        [[ 0.5969, -1.8768],
         [ 1.6514,  2.0865],
         [-0.6160,  0.0215]

tensor([[[-0.2752, -0.0749],
         [-0.2768, -0.0734],
         [-0.2772, -0.0748]],

        [[ 0.1408,  0.1002],
         [ 0.1400,  0.1004],
         [ 0.1399,  0.0996]],

        [[-0.0603, -0.1453],
         [-0.0609, -0.1471],
         [-0.0615, -0.1472]],

        [[-0.2752, -0.0749],
         [ 2.1509, -0.8458],
         [-0.2772, -0.0748]],

        [[ 0.6021, -1.9076],
         [ 1.6724,  2.1251],
         [-0.6445,  0.0039]],

        [[-1.8764,  1.0768],
         [ 0.7993, -1.5963],
         [ 1.1237,  1.8922]]], grad_fn=<IndexBackward>)
tensor([[[-0.2753, -0.0747],
         [-0.2766, -0.0738],
         [-0.2774, -0.0745]],

        [[ 0.1408,  0.0997],
         [ 0.1397,  0.1003],
         [ 0.1398,  0.0993]],

        [[-0.0602, -0.1454],
         [-0.0607, -0.1472],
         [-0.0616, -0.1470]],

        [[-0.2753, -0.0747],
         [ 2.1517, -0.8460],
         [-0.2774, -0.0745]],

        [[ 0.6020, -1.9079],
         [ 1.6725,  2.1256],
         [-0.6447,  0.0037]

tensor([[[-0.2767, -0.0662],
         [-0.2736, -0.0664],
         [-0.2780, -0.0654]],

        [[ 0.1371,  0.0879],
         [ 0.1377,  0.0895],
         [ 0.1364,  0.0878]],

        [[-0.0513, -0.1410],
         [-0.0508, -0.1403],
         [-0.0503, -0.1401]],

        [[-0.2767, -0.0662],
         [ 2.2167, -0.8674],
         [-0.2780, -0.0654]],

        [[ 0.6036, -1.9406],
         [ 1.6963,  2.1679],
         [-0.6680, -0.0175]],

        [[-1.9110,  1.1018],
         [ 0.8052, -1.6223],
         [ 1.1296,  1.9342]]], grad_fn=<IndexBackward>)
tensor([[[-0.2767, -0.0664],
         [-0.2738, -0.0662],
         [-0.2780, -0.0653]],

        [[ 0.1367,  0.0876],
         [ 0.1379,  0.0894],
         [ 0.1362,  0.0878]],

        [[-0.0510, -0.1409],
         [-0.0508, -0.1406],
         [-0.0504, -0.1397]],

        [[-0.2767, -0.0664],
         [ 2.2174, -0.8678],
         [-0.2780, -0.0653]],

        [[ 0.6031, -1.9410],
         [ 1.6967,  2.1684],
         [-0.6682, -0.0176]

tensor([[[-0.2779, -0.0544],
         [-0.2779, -0.0549],
         [-0.2789, -0.0551]],

        [[ 0.1243,  0.0844],
         [ 0.1246,  0.0855],
         [ 0.1238,  0.0848]],

        [[-0.0495, -0.1281],
         [-0.0500, -0.1294],
         [-0.0486, -0.1288]],

        [[-0.2779, -0.0544],
         [ 2.2671, -0.8861],
         [-0.2789, -0.0551]],

        [[ 0.5908, -1.9698],
         [ 1.7135,  2.2093],
         [-0.6841, -0.0369]],

        [[-1.9391,  1.1109],
         [ 0.8134, -1.6502],
         [ 1.1287,  1.9578]]], grad_fn=<IndexBackward>)
tensor([[[-0.2780, -0.0541],
         [-0.2777, -0.0548],
         [-0.2790, -0.0547]],

        [[ 0.1242,  0.0844],
         [ 0.1243,  0.0854],
         [ 0.1238,  0.0849]],

        [[-0.0495, -0.1280],
         [-0.0500, -0.1292],
         [-0.0484, -0.1284]],

        [[-0.2780, -0.0541],
         [ 2.2678, -0.8862],
         [-0.2790, -0.0547]],

        [[ 0.5908, -1.9705],
         [ 1.7138,  2.2097],
         [-0.6841, -0.0372]

tensor([[[-0.2816, -0.0480],
         [-0.2828, -0.0463],
         [-0.2859, -0.0446]],

        [[ 0.1106,  0.0730],
         [ 0.1106,  0.0763],
         [ 0.1112,  0.0774]],

        [[-0.0520, -0.1240],
         [-0.0520, -0.1236],
         [-0.0523, -0.1235]],

        [[-0.2816, -0.0480],
         [ 2.3166, -0.9103],
         [-0.2859, -0.0446]],

        [[ 0.5899, -1.9988],
         [ 1.7318,  2.2395],
         [-0.6996, -0.0569]],

        [[-1.9586,  1.1358],
         [ 0.8238, -1.6835],
         [ 1.1248,  1.9893]]], grad_fn=<IndexBackward>)
tensor([[[-0.2817, -0.0479],
         [-0.2828, -0.0464],
         [-0.2860, -0.0447]],

        [[ 0.1103,  0.0729],
         [ 0.1107,  0.0762],
         [ 0.1111,  0.0772]],

        [[-0.0520, -0.1238],
         [-0.0520, -0.1237],
         [-0.0523, -0.1237]],

        [[-0.2817, -0.0479],
         [ 2.3173, -0.9103],
         [-0.2860, -0.0447]],

        [[ 0.5899, -1.9990],
         [ 1.7320,  2.2398],
         [-0.6998, -0.0569]

tensor([[[-0.2830, -0.0349],
         [-0.2812, -0.0360],
         [-0.2818, -0.0354]],

        [[ 0.1075,  0.0743],
         [ 0.1074,  0.0744],
         [ 0.1077,  0.0744]],

        [[-0.0455, -0.1119],
         [-0.0450, -0.1131],
         [-0.0465, -0.1143]],

        [[-0.2830, -0.0349],
         [ 2.3723, -0.9253],
         [-0.2818, -0.0354]],

        [[ 0.5743, -2.0290],
         [ 1.7456,  2.2755],
         [-0.7146, -0.0756]],

        [[-1.9805,  1.1596],
         [ 0.8378, -1.7148],
         [ 1.1231,  2.0162]]], grad_fn=<IndexBackward>)
tensor([[[-0.2829, -0.0349],
         [-0.2814, -0.0358],
         [-0.2816, -0.0355]],

        [[ 0.1072,  0.0742],
         [ 0.1075,  0.0744],
         [ 0.1077,  0.0743]],

        [[-0.0455, -0.1120],
         [-0.0448, -0.1128],
         [-0.0463, -0.1143]],

        [[-0.2829, -0.0349],
         [ 2.3730, -0.9254],
         [-0.2816, -0.0355]],

        [[ 0.5742, -2.0293],
         [ 1.7458,  2.2758],
         [-0.7148, -0.0756]

tensor([[[-0.2783, -0.0314],
         [-0.2788, -0.0310],
         [-0.2799, -0.0312]],

        [[ 0.1048,  0.0683],
         [ 0.1040,  0.0681],
         [ 0.1046,  0.0685]],

        [[-0.0388, -0.1076],
         [-0.0393, -0.1064],
         [-0.0388, -0.1070]],

        [[-0.2783, -0.0314],
         [ 2.4254, -0.9393],
         [-0.2799, -0.0312]],

        [[ 0.5689, -2.0569],
         [ 1.7618,  2.3042],
         [-0.7255, -0.0797]],

        [[-2.0075,  1.1898],
         [ 0.8546, -1.7410],
         [ 1.1172,  2.0456]]], grad_fn=<IndexBackward>)
tensor([[[-0.2782, -0.0314],
         [-0.2788, -0.0307],
         [-0.2803, -0.0311]],

        [[ 0.1049,  0.0681],
         [ 0.1039,  0.0680],
         [ 0.1044,  0.0685]],

        [[-0.0387, -0.1076],
         [-0.0393, -0.1065],
         [-0.0387, -0.1067]],

        [[-0.2782, -0.0314],
         [ 2.4262, -0.9395],
         [-0.2803, -0.0311]],

        [[ 0.5690, -2.0571],
         [ 1.7617,  2.3051],
         [-0.7262, -0.0796]

tensor([[[-0.2810, -0.0263],
         [-0.2797, -0.0263],
         [-0.2800, -0.0244]],

        [[ 0.1004,  0.0618],
         [ 0.1006,  0.0615],
         [ 0.1007,  0.0626]],

        [[-0.0361, -0.1046],
         [-0.0364, -0.1048],
         [-0.0357, -0.1043]],

        [[-0.2810, -0.0263],
         [ 2.4835, -0.9516],
         [-0.2800, -0.0244]],

        [[ 0.5534, -2.0837],
         [ 1.7810,  2.3430],
         [-0.7491, -0.0884]],

        [[-2.0429,  1.2210],
         [ 0.8750, -1.7619],
         [ 1.1116,  2.0842]]], grad_fn=<IndexBackward>)
tensor([[[-0.2810, -0.0261],
         [-0.2797, -0.0259],
         [-0.2798, -0.0247]],

        [[ 0.1006,  0.0618],
         [ 0.1004,  0.0615],
         [ 0.1006,  0.0623]],

        [[-0.0360, -0.1044],
         [-0.0364, -0.1046],
         [-0.0357, -0.1046]],

        [[-0.2810, -0.0261],
         [ 2.4841, -0.9518],
         [-0.2798, -0.0247]],

        [[ 0.5532, -2.0840],
         [ 1.7813,  2.3434],
         [-0.7492, -0.0886]

tensor([[[-0.2786, -0.0165],
         [-0.2791, -0.0174],
         [-0.2803, -0.0182]],

        [[ 0.0932,  0.0657],
         [ 0.0939,  0.0633],
         [ 0.0932,  0.0634]],

        [[-0.0347, -0.0874],
         [-0.0350, -0.0902],
         [-0.0338, -0.0915]],

        [[-0.2786, -0.0165],
         [ 2.5260, -0.9687],
         [-0.2803, -0.0182]],

        [[ 0.5531, -2.1088],
         [ 1.7971,  2.3634],
         [-0.7555, -0.0946]],

        [[-2.0552,  1.2377],
         [ 0.8812, -1.7912],
         [ 1.1200,  2.1124]]], grad_fn=<IndexBackward>)
tensor([[[-0.2788, -0.0166],
         [-0.2792, -0.0173],
         [-0.2800, -0.0180]],

        [[ 0.0928,  0.0655],
         [ 0.0940,  0.0633],
         [ 0.0931,  0.0636]],

        [[-0.0350, -0.0873],
         [-0.0348, -0.0901],
         [-0.0339, -0.0912]],

        [[-0.2788, -0.0166],
         [ 2.5264, -0.9688],
         [-0.2800, -0.0180]],

        [[ 0.5534, -2.1094],
         [ 1.7976,  2.3636],
         [-0.7556, -0.0949]

tensor([[[-0.2786, -0.0145],
         [-0.2784, -0.0121],
         [-0.2793, -0.0141]],

        [[ 0.0840,  0.0602],
         [ 0.0832,  0.0600],
         [ 0.0827,  0.0601]],

        [[-0.0384, -0.0854],
         [-0.0379, -0.0843],
         [-0.0388, -0.0835]],

        [[-0.2786, -0.0145],
         [ 2.5633, -0.9722],
         [-0.2793, -0.0141]],

        [[ 0.5576, -2.1320],
         [ 1.8164,  2.3758],
         [-0.7618, -0.1014]],

        [[-2.0658,  1.2534],
         [ 0.8924, -1.8242],
         [ 1.1273,  2.1359]]], grad_fn=<IndexBackward>)
tensor([[[-0.2788, -0.0142],
         [-0.2784, -0.0122],
         [-0.2792, -0.0137]],

        [[ 0.0837,  0.0602],
         [ 0.0831,  0.0601],
         [ 0.0827,  0.0601]],

        [[-0.0386, -0.0852],
         [-0.0380, -0.0841],
         [-0.0387, -0.0833]],

        [[-0.2788, -0.0142],
         [ 2.5638, -0.9724],
         [-0.2792, -0.0137]],

        [[ 0.5577, -2.1324],
         [ 1.8168,  2.3761],
         [-0.7618, -0.1015]

tensor([[[-0.2774, -0.0075],
         [-0.2773, -0.0072],
         [-0.2769, -0.0066]],

        [[ 0.0780,  0.0577],
         [ 0.0781,  0.0578],
         [ 0.0793,  0.0584]],

        [[-0.0373, -0.0767],
         [-0.0371, -0.0783],
         [-0.0350, -0.0769]],

        [[-0.2774, -0.0075],
         [ 2.6015, -0.9930],
         [-0.2769, -0.0066]],

        [[ 0.5552, -2.1550],
         [ 1.8377,  2.4034],
         [-0.7714, -0.1141]],

        [[-2.0871,  1.2590],
         [ 0.9066, -1.8503],
         [ 1.1406,  2.1493]]], grad_fn=<IndexBackward>)
tensor([[[-0.2775, -0.0072],
         [-0.2774, -0.0071],
         [-0.2767, -0.0067]],

        [[ 0.0780,  0.0578],
         [ 0.0782,  0.0582],
         [ 0.0791,  0.0581]],

        [[-0.0372, -0.0764],
         [-0.0369, -0.0781],
         [-0.0351, -0.0769]],

        [[-0.2775, -0.0072],
         [ 2.6021, -0.9933],
         [-0.2767, -0.0067]],

        [[ 0.5552, -2.1552],
         [ 1.8378,  2.4036],
         [-0.7715, -0.1142]

tensor([[[-2.8055e-01, -1.2614e-03],
         [-2.7935e-01, -2.3888e-03],
         [-2.8095e-01, -1.5240e-03]],

        [[ 7.3113e-02,  5.1802e-02],
         [ 7.2136e-02,  5.2252e-02],
         [ 7.3061e-02,  5.3847e-02]],

        [[-3.5958e-02, -7.5486e-02],
         [-3.5558e-02, -7.4234e-02],
         [-3.5266e-02, -7.2679e-02]],

        [[-2.8055e-01, -1.2614e-03],
         [ 2.6473e+00, -1.0066e+00],
         [-2.8095e-01, -1.5240e-03]],

        [[ 5.4153e-01, -2.1782e+00],
         [ 1.8588e+00,  2.4279e+00],
         [-7.7753e-01, -1.2436e-01]],

        [[-2.1046e+00,  1.2931e+00],
         [ 9.2861e-01, -1.8918e+00],
         [ 1.1569e+00,  2.1790e+00]]], grad_fn=<IndexBackward>)
tensor([[[-2.8096e-01, -9.2671e-04],
         [-2.7934e-01, -2.4076e-03],
         [-2.8080e-01, -1.4202e-03]],

        [[ 7.2862e-02,  5.1850e-02],
         [ 7.2096e-02,  5.2124e-02],
         [ 7.3150e-02,  5.3622e-02]],

        [[-3.6045e-02, -7.5259e-02],
         [-3.5455e-02, -7.4340e-02

tensor([[[-2.7463e-01,  8.8631e-04],
         [-2.7544e-01,  3.6659e-04],
         [-2.7726e-01, -1.0690e-03]],

        [[ 7.2020e-02,  4.6939e-02],
         [ 7.2589e-02,  4.6981e-02],
         [ 7.1120e-02,  4.7584e-02]],

        [[-3.1072e-02, -7.3205e-02],
         [-3.0633e-02, -7.3133e-02],
         [-3.0590e-02, -7.2948e-02]],

        [[-2.7463e-01,  8.8631e-04],
         [ 2.6834e+00, -1.0182e+00],
         [-2.7726e-01, -1.0690e-03]],

        [[ 5.3813e-01, -2.1947e+00],
         [ 1.8710e+00,  2.4486e+00],
         [-7.8772e-01, -1.3681e-01]],

        [[-2.1225e+00,  1.3112e+00],
         [ 9.4402e-01, -1.9216e+00],
         [ 1.1504e+00,  2.2031e+00]]], grad_fn=<IndexBackward>)
tensor([[[-2.7469e-01,  7.0159e-04],
         [-2.7533e-01,  5.6269e-04],
         [-2.7712e-01, -7.9649e-04]],

        [[ 7.1978e-02,  4.6915e-02],
         [ 7.2435e-02,  4.6870e-02],
         [ 7.0991e-02,  4.7579e-02]],

        [[-3.1017e-02, -7.3116e-02],
         [-3.0603e-02, -7.3288e-02

tensor([[[-2.7574e-01,  2.3511e-03],
         [-2.7554e-01,  3.4148e-03],
         [-2.7616e-01,  1.1309e-03]],

        [[ 6.3308e-02,  4.3118e-02],
         [ 6.2923e-02,  4.2844e-02],
         [ 6.3043e-02,  4.2490e-02]],

        [[-3.2000e-02, -7.2373e-02],
         [-3.3711e-02, -7.2731e-02],
         [-3.3578e-02, -7.2726e-02]],

        [[-2.7574e-01,  2.3511e-03],
         [ 2.7212e+00, -1.0306e+00],
         [-2.7616e-01,  1.1309e-03]],

        [[ 5.2968e-01, -2.2121e+00],
         [ 1.8883e+00,  2.4737e+00],
         [-7.9156e-01, -1.4356e-01]],

        [[-2.1462e+00,  1.3343e+00],
         [ 9.6246e-01, -1.9358e+00],
         [ 1.1456e+00,  2.2221e+00]]], grad_fn=<IndexBackward>)
tensor([[[-2.7567e-01,  2.6298e-03],
         [-2.7561e-01,  3.2498e-03],
         [-2.7623e-01,  1.3282e-03]],

        [[ 6.3289e-02,  4.3122e-02],
         [ 6.2806e-02,  4.2835e-02],
         [ 6.2851e-02,  4.2628e-02]],

        [[-3.2159e-02, -7.2277e-02],
         [-3.3664e-02, -7.2472e-02

tensor([[[-0.2785,  0.0077],
         [-0.2772,  0.0082],
         [-0.2793,  0.0095]],

        [[ 0.0596,  0.0456],
         [ 0.0603,  0.0451],
         [ 0.0598,  0.0444]],

        [[-0.0313, -0.0623],
         [-0.0313, -0.0636],
         [-0.0315, -0.0642]],

        [[-0.2785,  0.0077],
         [ 2.7608, -1.0604],
         [-0.2793,  0.0095]],

        [[ 0.5149, -2.2378],
         [ 1.9074,  2.5148],
         [-0.8003, -0.1583]],

        [[-2.1726,  1.3422],
         [ 0.9749, -1.9642],
         [ 1.1480,  2.2467]]], grad_fn=<IndexBackward>)
tensor([[[-0.2783,  0.0075],
         [-0.2775,  0.0080],
         [-0.2789,  0.0098]],

        [[ 0.0602,  0.0455],
         [ 0.0598,  0.0448],
         [ 0.0599,  0.0452]],

        [[-0.0312, -0.0628],
         [-0.0313, -0.0632],
         [-0.0315, -0.0637]],

        [[-0.2783,  0.0075],
         [ 2.7612, -1.0607],
         [-0.2789,  0.0098]],

        [[ 0.5146, -2.2382],
         [ 1.9075,  2.5152],
         [-0.8005, -0.1585]

tensor([[[-0.2770,  0.0067],
         [-0.2789,  0.0059],
         [-0.2776,  0.0069]],

        [[ 0.0501,  0.0426],
         [ 0.0492,  0.0420],
         [ 0.0490,  0.0423]],

        [[-0.0384, -0.0617],
         [-0.0365, -0.0618],
         [-0.0391, -0.0619]],

        [[-0.2770,  0.0067],
         [ 2.7945, -1.0715],
         [-0.2776,  0.0069]],

        [[ 0.5137, -2.2596],
         [ 1.9273,  2.5293],
         [-0.8021, -0.1732]],

        [[-2.1902,  1.3522],
         [ 0.9868, -1.9816],
         [ 1.1511,  2.2623]]], grad_fn=<IndexBackward>)
tensor([[[-0.2769,  0.0065],
         [-0.2789,  0.0057],
         [-0.2775,  0.0066]],

        [[ 0.0499,  0.0424],
         [ 0.0493,  0.0419],
         [ 0.0490,  0.0420]],

        [[-0.0382, -0.0619],
         [-0.0365, -0.0618],
         [-0.0393, -0.0620]],

        [[-0.2769,  0.0065],
         [ 2.7950, -1.0715],
         [-0.2775,  0.0066]],

        [[ 0.5137, -2.2597],
         [ 1.9275,  2.5295],
         [-0.8022, -0.1733]

tensor([[[-0.2754,  0.0096],
         [-0.2710,  0.0087],
         [-0.2739,  0.0087]],

        [[ 0.0484,  0.0424],
         [ 0.0471,  0.0430],
         [ 0.0468,  0.0433]],

        [[-0.0338, -0.0570],
         [-0.0347, -0.0573],
         [-0.0359, -0.0564]],

        [[-0.2754,  0.0096],
         [ 2.8344, -1.0857],
         [-0.2739,  0.0087]],

        [[ 0.5180, -2.2809],
         [ 1.9348,  2.5564],
         [-0.8179, -0.1790]],

        [[-2.2161,  1.3722],
         [ 0.9913, -2.0097],
         [ 1.1638,  2.2847]]], grad_fn=<IndexBackward>)
tensor([[[-0.2753,  0.0094],
         [-0.2712,  0.0086],
         [-0.2738,  0.0092]],

        [[ 0.0480,  0.0424],
         [ 0.0470,  0.0431],
         [ 0.0468,  0.0433]],

        [[-0.0341, -0.0569],
         [-0.0348, -0.0571],
         [-0.0358, -0.0564]],

        [[-0.2753,  0.0094],
         [ 2.8349, -1.0859],
         [-0.2738,  0.0092]],

        [[ 0.5180, -2.2811],
         [ 1.9349,  2.5568],
         [-0.8180, -0.1791]

tensor([[[-0.2710,  0.0188],
         [-0.2708,  0.0173],
         [-0.2713,  0.0162]],

        [[ 0.0434,  0.0463],
         [ 0.0441,  0.0460],
         [ 0.0435,  0.0443]],

        [[-0.0336, -0.0464],
         [-0.0344, -0.0471],
         [-0.0327, -0.0485]],

        [[-0.2710,  0.0188],
         [ 2.8684, -1.1014],
         [-0.2713,  0.0162]],

        [[ 0.5184, -2.3226],
         [ 1.9432,  2.5833],
         [-0.8256, -0.1961]],

        [[-2.2294,  1.3813],
         [ 1.0021, -2.0282],
         [ 1.1622,  2.3028]]], grad_fn=<IndexBackward>)
tensor([[[-0.2708,  0.0196],
         [-0.2704,  0.0174],
         [-0.2711,  0.0165]],

        [[ 0.0434,  0.0464],
         [ 0.0442,  0.0460],
         [ 0.0434,  0.0444]],

        [[-0.0334, -0.0463],
         [-0.0343, -0.0469],
         [-0.0325, -0.0483]],

        [[-0.2708,  0.0196],
         [ 2.8691, -1.1016],
         [-0.2711,  0.0165]],

        [[ 0.5178, -2.3231],
         [ 1.9437,  2.5841],
         [-0.8255, -0.1966]

tensor([[[-0.2704,  0.0213],
         [-0.2683,  0.0214],
         [-0.2681,  0.0224]],

        [[ 0.0391,  0.0407],
         [ 0.0394,  0.0412],
         [ 0.0399,  0.0409]],

        [[-0.0322, -0.0463],
         [-0.0332, -0.0460],
         [-0.0331, -0.0462]],

        [[-0.2704,  0.0213],
         [ 2.9107, -1.1057],
         [-0.2681,  0.0224]],

        [[ 0.5093, -2.3420],
         [ 1.9527,  2.6232],
         [-0.8433, -0.2311]],

        [[-2.2530,  1.3949],
         [ 1.0194, -2.0518],
         [ 1.1537,  2.3285]]], grad_fn=<IndexBackward>)
tensor([[[-0.2703,  0.0213],
         [-0.2681,  0.0211],
         [-0.2682,  0.0223]],

        [[ 0.0392,  0.0408],
         [ 0.0396,  0.0407],
         [ 0.0399,  0.0407]],

        [[-0.0324, -0.0461],
         [-0.0328, -0.0464],
         [-0.0329, -0.0464]],

        [[-0.2703,  0.0213],
         [ 2.9112, -1.1057],
         [-0.2682,  0.0223]],

        [[ 0.5091, -2.3421],
         [ 1.9527,  2.6235],
         [-0.8435, -0.2312]

tensor([[[-0.2627,  0.0258],
         [-0.2614,  0.0258],
         [-0.2587,  0.0274]],

        [[ 0.0424,  0.0369],
         [ 0.0428,  0.0375],
         [ 0.0424,  0.0378]],

        [[-0.0254, -0.0446],
         [-0.0256, -0.0444],
         [-0.0263, -0.0441]],

        [[-0.2627,  0.0258],
         [ 2.9439, -1.1136],
         [-0.2587,  0.0274]],

        [[ 0.5057, -2.3558],
         [ 1.9600,  2.6442],
         [-0.8479, -0.2394]],

        [[-2.2637,  1.4051],
         [ 1.0329, -2.0697],
         [ 1.1475,  2.3430]]], grad_fn=<IndexBackward>)
tensor([[[-0.2627,  0.0261],
         [-0.2612,  0.0259],
         [-0.2589,  0.0275]],

        [[ 0.0424,  0.0369],
         [ 0.0426,  0.0377],
         [ 0.0425,  0.0378]],

        [[-0.0255, -0.0446],
         [-0.0255, -0.0442],
         [-0.0265, -0.0441]],

        [[-0.2627,  0.0261],
         [ 2.9443, -1.1138],
         [-0.2589,  0.0275]],

        [[ 0.5055, -2.3559],
         [ 1.9601,  2.6444],
         [-0.8479, -0.2395]

tensor([[[-0.2633,  0.0287],
         [-0.2640,  0.0273],
         [-0.2626,  0.0284]],

        [[ 0.0376,  0.0363],
         [ 0.0393,  0.0340],
         [ 0.0395,  0.0343]],

        [[-0.0264, -0.0434],
         [-0.0260, -0.0436],
         [-0.0272, -0.0442]],

        [[-0.2633,  0.0287],
         [ 2.9803, -1.1217],
         [-0.2626,  0.0284]],

        [[ 0.5009, -2.3687],
         [ 1.9675,  2.6541],
         [-0.8528, -0.2464]],

        [[-2.2659,  1.4180],
         [ 1.0437, -2.0870],
         [ 1.1478,  2.3538]]], grad_fn=<IndexBackward>)
tensor([[[-0.2635,  0.0284],
         [-0.2641,  0.0274],
         [-0.2625,  0.0283]],

        [[ 0.0375,  0.0362],
         [ 0.0391,  0.0340],
         [ 0.0395,  0.0341]],

        [[-0.0261, -0.0433],
         [-0.0263, -0.0438],
         [-0.0273, -0.0443]],

        [[-0.2635,  0.0284],
         [ 2.9806, -1.1219],
         [-0.2625,  0.0283]],

        [[ 0.5008, -2.3688],
         [ 1.9675,  2.6541],
         [-0.8528, -0.2465]

tensor([[[-0.2620,  0.0311],
         [-0.2601,  0.0331],
         [-0.2607,  0.0327]],

        [[ 0.0360,  0.0403],
         [ 0.0367,  0.0391],
         [ 0.0378,  0.0391]],

        [[-0.0267, -0.0359],
         [-0.0257, -0.0349],
         [-0.0257, -0.0357]],

        [[-0.2620,  0.0311],
         [ 3.0089, -1.1318],
         [-0.2607,  0.0327]],

        [[ 0.4966, -2.3857],
         [ 1.9729,  2.6710],
         [-0.8606, -0.2537]],

        [[-2.2768,  1.4328],
         [ 1.0572, -2.1066],
         [ 1.1415,  2.3740]]], grad_fn=<IndexBackward>)
tensor([[[-0.2620,  0.0310],
         [-0.2602,  0.0330],
         [-0.2606,  0.0327]],

        [[ 0.0361,  0.0403],
         [ 0.0366,  0.0391],
         [ 0.0376,  0.0391]],

        [[-0.0266, -0.0360],
         [-0.0259, -0.0348],
         [-0.0255, -0.0358]],

        [[-0.2620,  0.0310],
         [ 3.0093, -1.1319],
         [-0.2606,  0.0327]],

        [[ 0.4966, -2.3859],
         [ 1.9730,  2.6712],
         [-0.8607, -0.2537]

tensor([[[-0.2606,  0.0356],
         [-0.2629,  0.0350],
         [-0.2673,  0.0327]],

        [[ 0.0366,  0.0367],
         [ 0.0367,  0.0361],
         [ 0.0365,  0.0377]],

        [[-0.0225, -0.0352],
         [-0.0228, -0.0360],
         [-0.0254, -0.0355]],

        [[-0.2606,  0.0356],
         [ 3.0571, -1.1348],
         [-0.2673,  0.0327]],

        [[ 0.4961, -2.4006],
         [ 2.0014,  2.7026],
         [-0.8994, -0.2695]],

        [[-2.3022,  1.4649],
         [ 1.0564, -2.1403],
         [ 1.1383,  2.4015]]], grad_fn=<IndexBackward>)
tensor([[[-0.2603,  0.0359],
         [-0.2629,  0.0348],
         [-0.2675,  0.0327]],

        [[ 0.0368,  0.0368],
         [ 0.0371,  0.0361],
         [ 0.0364,  0.0375]],

        [[-0.0220, -0.0352],
         [-0.0230, -0.0361],
         [-0.0252, -0.0354]],

        [[-0.2603,  0.0359],
         [ 3.0578, -1.1346],
         [-0.2675,  0.0327]],

        [[ 0.4962, -2.4011],
         [ 2.0013,  2.7031],
         [-0.8994, -0.2695]

tensor([[[-0.2605,  0.0340],
         [-0.2609,  0.0363],
         [-0.2645,  0.0381]],

        [[ 0.0377,  0.0370],
         [ 0.0362,  0.0363],
         [ 0.0366,  0.0363]],

        [[-0.0194, -0.0334],
         [-0.0214, -0.0337],
         [-0.0200, -0.0333]],

        [[-0.2605,  0.0340],
         [ 3.0840, -1.1352],
         [-0.2645,  0.0381]],

        [[ 0.4876, -2.4199],
         [ 2.0144,  2.7255],
         [-0.8998, -0.2762]],

        [[-2.3113,  1.4832],
         [ 1.0658, -2.1638],
         [ 1.1408,  2.4081]]], grad_fn=<IndexBackward>)
tensor([[[-0.2606,  0.0344],
         [-0.2608,  0.0363],
         [-0.2645,  0.0382]],

        [[ 0.0378,  0.0369],
         [ 0.0362,  0.0365],
         [ 0.0365,  0.0363]],

        [[-0.0196, -0.0333],
         [-0.0209, -0.0334],
         [-0.0204, -0.0333]],

        [[-0.2606,  0.0344],
         [ 3.0844, -1.1353],
         [-0.2645,  0.0382]],

        [[ 0.4874, -2.4206],
         [ 2.0146,  2.7259],
         [-0.9000, -0.2764]

tensor([[[-0.2655,  0.0379],
         [-0.2644,  0.0386],
         [-0.2636,  0.0383]],

        [[ 0.0339,  0.0341],
         [ 0.0331,  0.0322],
         [ 0.0320,  0.0332]],

        [[-0.0208, -0.0325],
         [-0.0214, -0.0338],
         [-0.0227, -0.0336]],

        [[-0.2655,  0.0379],
         [ 3.1160, -1.1550],
         [-0.2636,  0.0383]],

        [[ 0.4807, -2.4363],
         [ 2.0191,  2.7473],
         [-0.9136, -0.2778]],

        [[-2.3305,  1.4900],
         [ 1.0663, -2.1902],
         [ 1.1397,  2.4211]]], grad_fn=<IndexBackward>)
tensor([[[-0.2656,  0.0381],
         [-0.2645,  0.0386],
         [-0.2637,  0.0382]],

        [[ 0.0337,  0.0340],
         [ 0.0329,  0.0322],
         [ 0.0318,  0.0332]],

        [[-0.0210, -0.0324],
         [-0.0215, -0.0337],
         [-0.0229, -0.0336]],

        [[-0.2656,  0.0381],
         [ 3.1163, -1.1553],
         [-0.2637,  0.0382]],

        [[ 0.4807, -2.4364],
         [ 2.0191,  2.7475],
         [-0.9137, -0.2777]

tensor([[[-0.2676,  0.0402],
         [-0.2636,  0.0392],
         [-0.2671,  0.0405]],

        [[ 0.0280,  0.0294],
         [ 0.0278,  0.0278],
         [ 0.0267,  0.0275]],

        [[-0.0253, -0.0338],
         [-0.0241, -0.0353],
         [-0.0215, -0.0341]],

        [[-0.2676,  0.0402],
         [ 3.1414, -1.1703],
         [-0.2671,  0.0405]],

        [[ 0.4848, -2.4534],
         [ 2.0188,  2.7697],
         [-0.9069, -0.2957]],

        [[-2.3365,  1.4950],
         [ 1.0761, -2.2068],
         [ 1.1399,  2.4357]]], grad_fn=<IndexBackward>)
tensor([[[-0.2677,  0.0403],
         [-0.2637,  0.0394],
         [-0.2671,  0.0406]],

        [[ 0.0280,  0.0292],
         [ 0.0276,  0.0280],
         [ 0.0271,  0.0277]],

        [[-0.0254, -0.0340],
         [-0.0238, -0.0349],
         [-0.0215, -0.0341]],

        [[-0.2677,  0.0403],
         [ 3.1417, -1.1707],
         [-0.2671,  0.0406]],

        [[ 0.4847, -2.4535],
         [ 2.0187,  2.7699],
         [-0.9069, -0.2958]

tensor([[[-0.2661,  0.0387],
         [-0.2667,  0.0389],
         [-0.2655,  0.0402]],

        [[ 0.0252,  0.0229],
         [ 0.0238,  0.0240],
         [ 0.0223,  0.0238]],

        [[-0.0227, -0.0354],
         [-0.0245, -0.0351],
         [-0.0245, -0.0346]],

        [[-0.2661,  0.0387],
         [ 3.1631, -1.1823],
         [-0.2655,  0.0402]],

        [[ 0.4771, -2.4653],
         [ 2.0197,  2.7890],
         [-0.9168, -0.2997]],

        [[-2.3478,  1.5029],
         [ 1.0825, -2.2171],
         [ 1.1406,  2.4438]]], grad_fn=<IndexBackward>)
tensor([[[-0.2662,  0.0385],
         [-0.2666,  0.0388],
         [-0.2656,  0.0401]],

        [[ 0.0252,  0.0227],
         [ 0.0236,  0.0239],
         [ 0.0222,  0.0237]],

        [[-0.0228, -0.0354],
         [-0.0246, -0.0352],
         [-0.0247, -0.0347]],

        [[-0.2662,  0.0385],
         [ 3.1633, -1.1825],
         [-0.2656,  0.0401]],

        [[ 0.4771, -2.4656],
         [ 2.0199,  2.7891],
         [-0.9170, -0.2998]

tensor([[[-0.2634,  0.0415],
         [-0.2644,  0.0406],
         [-0.2644,  0.0416]],

        [[ 0.0197,  0.0246],
         [ 0.0191,  0.0241],
         [ 0.0185,  0.0246]],

        [[-0.0263, -0.0316],
         [-0.0267, -0.0326],
         [-0.0279, -0.0329]],

        [[-0.2634,  0.0415],
         [ 3.1897, -1.1933],
         [-0.2644,  0.0416]],

        [[ 0.4729, -2.4751],
         [ 2.0317,  2.8067],
         [-0.9132, -0.3028]],

        [[-2.3583,  1.5170],
         [ 1.0904, -2.2373],
         [ 1.1319,  2.4599]]], grad_fn=<IndexBackward>)
tensor([[[-0.2636,  0.0414],
         [-0.2642,  0.0403],
         [-0.2643,  0.0417]],

        [[ 0.0198,  0.0245],
         [ 0.0193,  0.0240],
         [ 0.0182,  0.0245]],

        [[-0.0263, -0.0316],
         [-0.0267, -0.0327],
         [-0.0280, -0.0329]],

        [[-0.2636,  0.0414],
         [ 3.1902, -1.1932],
         [-0.2643,  0.0417]],

        [[ 0.4728, -2.4752],
         [ 2.0318,  2.8069],
         [-0.9132, -0.3028]

tensor([[[-0.2620,  0.0402],
         [-0.2602,  0.0383],
         [-0.2650,  0.0404]],

        [[ 0.0170,  0.0200],
         [ 0.0166,  0.0205],
         [ 0.0165,  0.0199]],

        [[-0.0271, -0.0349],
         [-0.0268, -0.0336],
         [-0.0265, -0.0356]],

        [[-0.2620,  0.0402],
         [ 3.2204, -1.1979],
         [-0.2650,  0.0404]],

        [[ 0.4634, -2.4849],
         [ 2.0361,  2.8324],
         [-0.9228, -0.3107]],

        [[-2.3815,  1.5303],
         [ 1.1092, -2.2613],
         [ 1.1340,  2.4773]]], grad_fn=<IndexBackward>)
tensor([[[-0.2620,  0.0404],
         [-0.2601,  0.0385],
         [-0.2652,  0.0405]],

        [[ 0.0170,  0.0199],
         [ 0.0165,  0.0206],
         [ 0.0165,  0.0199]],

        [[-0.0272, -0.0350],
         [-0.0268, -0.0334],
         [-0.0263, -0.0357]],

        [[-0.2620,  0.0404],
         [ 3.2212, -1.1977],
         [-0.2652,  0.0405]],

        [[ 0.4630, -2.4851],
         [ 2.0361,  2.8329],
         [-0.9228, -0.3110]

tensor([[[-0.2631,  0.0456],
         [-0.2557,  0.0419],
         [-0.2649,  0.0454]],

        [[ 0.0173,  0.0184],
         [ 0.0180,  0.0181],
         [ 0.0192,  0.0185]],

        [[-0.0223, -0.0344],
         [-0.0227, -0.0340],
         [-0.0224, -0.0342]],

        [[-0.2631,  0.0456],
         [ 3.2581, -1.2036],
         [-0.2649,  0.0454]],

        [[ 0.4581, -2.4973],
         [ 2.0382,  2.8518],
         [-0.9318, -0.3120]],

        [[-2.3919,  1.5445],
         [ 1.1180, -2.2783],
         [ 1.1329,  2.5013]]], grad_fn=<IndexBackward>)
tensor([[[-0.2629,  0.0455],
         [-0.2557,  0.0421],
         [-0.2652,  0.0455]],

        [[ 0.0175,  0.0184],
         [ 0.0179,  0.0181],
         [ 0.0191,  0.0185]],

        [[-0.0223, -0.0342],
         [-0.0227, -0.0340],
         [-0.0222, -0.0341]],

        [[-0.2629,  0.0455],
         [ 3.2587, -1.2041],
         [-0.2652,  0.0455]],

        [[ 0.4583, -2.4976],
         [ 2.0382,  2.8523],
         [-0.9320, -0.3119]

tensor([[[-0.2650,  0.0455],
         [-0.2644,  0.0486],
         [-0.2642,  0.0475]],

        [[ 0.0203,  0.0156],
         [ 0.0182,  0.0185],
         [ 0.0173,  0.0169]],

        [[-0.0193, -0.0335],
         [-0.0198, -0.0323],
         [-0.0202, -0.0334]],

        [[-0.2650,  0.0455],
         [ 3.2914, -1.2217],
         [-0.2642,  0.0475]],

        [[ 0.4399, -2.5158],
         [ 2.0369,  2.8697],
         [-0.9423, -0.3243]],

        [[-2.4014,  1.5563],
         [ 1.1167, -2.2951],
         [ 1.1326,  2.5135]]], grad_fn=<IndexBackward>)
tensor([[[-0.2653,  0.0457],
         [-0.2642,  0.0482],
         [-0.2642,  0.0475]],

        [[ 0.0203,  0.0153],
         [ 0.0187,  0.0183],
         [ 0.0174,  0.0168]],

        [[-0.0193, -0.0336],
         [-0.0194, -0.0326],
         [-0.0199, -0.0333]],

        [[-0.2653,  0.0457],
         [ 3.2920, -1.2215],
         [-0.2642,  0.0475]],

        [[ 0.4398, -2.5159],
         [ 2.0369,  2.8698],
         [-0.9424, -0.3243]

tensor([[[-0.2639,  0.0432],
         [-0.2633,  0.0431],
         [-0.2639,  0.0433]],

        [[ 0.0247,  0.0102],
         [ 0.0255,  0.0097],
         [ 0.0251,  0.0091]],

        [[-0.0119, -0.0385],
         [-0.0124, -0.0383],
         [-0.0125, -0.0397]],

        [[-0.2639,  0.0432],
         [ 3.3355, -1.2225],
         [-0.2639,  0.0433]],

        [[ 0.4439, -2.5216],
         [ 2.0476,  2.8846],
         [-0.9491, -0.3265]],

        [[-2.4092,  1.5775],
         [ 1.1345, -2.3103],
         [ 1.1225,  2.5309]]], grad_fn=<IndexBackward>)
tensor([[[-0.2640,  0.0433],
         [-0.2634,  0.0434],
         [-0.2639,  0.0429]],

        [[ 0.0247,  0.0102],
         [ 0.0254,  0.0097],
         [ 0.0251,  0.0091]],

        [[-0.0119, -0.0384],
         [-0.0123, -0.0383],
         [-0.0125, -0.0397]],

        [[-0.2640,  0.0433],
         [ 3.3362, -1.2226],
         [-0.2639,  0.0429]],

        [[ 0.4441, -2.5219],
         [ 2.0479,  2.8848],
         [-0.9492, -0.3266]

tensor([[[-0.2664,  0.0449],
         [-0.2667,  0.0447],
         [-0.2670,  0.0450]],

        [[ 0.0193,  0.0055],
         [ 0.0169,  0.0068],
         [ 0.0180,  0.0071]],

        [[-0.0185, -0.0393],
         [-0.0176, -0.0402],
         [-0.0176, -0.0405]],

        [[-0.2664,  0.0449],
         [ 3.3650, -1.2312],
         [-0.2670,  0.0450]],

        [[ 0.4415, -2.5342],
         [ 2.0611,  2.9126],
         [-0.9625, -0.3361]],

        [[-2.4192,  1.5912],
         [ 1.1455, -2.3277],
         [ 1.1113,  2.5417]]], grad_fn=<IndexBackward>)
tensor([[[-0.2659,  0.0451],
         [-0.2665,  0.0447],
         [-0.2669,  0.0450]],

        [[ 0.0198,  0.0058],
         [ 0.0167,  0.0069],
         [ 0.0181,  0.0069]],

        [[-0.0181, -0.0392],
         [-0.0176, -0.0403],
         [-0.0176, -0.0404]],

        [[-0.2659,  0.0451],
         [ 3.3652, -1.2312],
         [-0.2669,  0.0450]],

        [[ 0.4415, -2.5345],
         [ 2.0611,  2.9129],
         [-0.9625, -0.3362]

tensor([[[-2.6292e-01,  4.3385e-02],
         [-2.6248e-01,  4.4129e-02],
         [-2.6367e-01,  4.5248e-02]],

        [[ 2.0324e-02,  4.0648e-03],
         [ 1.9278e-02,  3.3093e-03],
         [ 2.0026e-02,  3.1285e-03]],

        [[-1.4222e-02, -3.8883e-02],
         [-1.4792e-02, -4.0313e-02],
         [-1.4812e-02, -3.9980e-02]],

        [[-2.6292e-01,  4.3385e-02],
         [ 3.3939e+00, -1.2349e+00],
         [-2.6367e-01,  4.5248e-02]],

        [[ 4.4010e-01, -2.5536e+00],
         [ 2.0706e+00,  2.9292e+00],
         [-9.8018e-01, -3.3654e-01]],

        [[-2.4277e+00,  1.5922e+00],
         [ 1.1557e+00, -2.3411e+00],
         [ 1.1089e+00,  2.5540e+00]]], grad_fn=<IndexBackward>)
tensor([[[-2.6289e-01,  4.3467e-02],
         [-2.6221e-01,  4.4190e-02],
         [-2.6354e-01,  4.5061e-02]],

        [[ 2.0349e-02,  4.0206e-03],
         [ 1.9435e-02,  3.1914e-03],
         [ 2.0078e-02,  3.0816e-03]],

        [[-1.4131e-02, -3.8958e-02],
         [-1.4683e-02, -4.0408e-02

tensor([[[-2.6019e-01,  4.2748e-02],
         [-2.5815e-01,  4.3207e-02],
         [-2.5896e-01,  4.4679e-02]],

        [[ 1.3875e-02,  3.2787e-03],
         [ 1.5563e-02,  4.2923e-03],
         [ 1.5268e-02,  5.9892e-03]],

        [[-1.7835e-02, -3.8532e-02],
         [-1.8396e-02, -3.6813e-02],
         [-1.6711e-02, -3.5720e-02]],

        [[-2.6019e-01,  4.2748e-02],
         [ 3.4080e+00, -1.2424e+00],
         [-2.5896e-01,  4.4679e-02]],

        [[ 4.3461e-01, -2.5701e+00],
         [ 2.0882e+00,  2.9481e+00],
         [-9.8578e-01, -3.4867e-01]],

        [[-2.4312e+00,  1.5959e+00],
         [ 1.1670e+00, -2.3549e+00],
         [ 1.1076e+00,  2.5633e+00]]], grad_fn=<IndexBackward>)
tensor([[[-2.6023e-01,  4.2704e-02],
         [-2.5812e-01,  4.3324e-02],
         [-2.5916e-01,  4.4358e-02]],

        [[ 1.4026e-02,  3.1979e-03],
         [ 1.5337e-02,  4.4410e-03],
         [ 1.5233e-02,  5.9259e-03]],

        [[-1.7880e-02, -3.8404e-02],
         [-1.8376e-02, -3.6795e-02

tensor([[[-0.2590,  0.0463],
         [-0.2596,  0.0458],
         [-0.2608,  0.0452]],

        [[ 0.0128,  0.0043],
         [ 0.0118,  0.0041],
         [ 0.0128,  0.0038]],

        [[-0.0197, -0.0362],
         [-0.0191, -0.0360],
         [-0.0198, -0.0355]],

        [[-0.2590,  0.0463],
         [ 3.4396, -1.2446],
         [-0.2608,  0.0452]],

        [[ 0.4283, -2.5821],
         [ 2.0945,  2.9594],
         [-0.9952, -0.3543]],

        [[-2.4449,  1.6058],
         [ 1.1773, -2.3688],
         [ 1.1047,  2.5865]]], grad_fn=<IndexBackward>)
tensor([[[-0.2589,  0.0465],
         [-0.2595,  0.0460],
         [-0.2609,  0.0451]],

        [[ 0.0126,  0.0045],
         [ 0.0120,  0.0042],
         [ 0.0129,  0.0039]],

        [[-0.0196, -0.0360],
         [-0.0190, -0.0357],
         [-0.0196, -0.0356]],

        [[-0.2589,  0.0465],
         [ 3.4399, -1.2449],
         [-0.2609,  0.0451]],

        [[ 0.4282, -2.5823],
         [ 2.0945,  2.9596],
         [-0.9952, -0.3543]