In [2]:
import os
import random
import sys
#sys.path.append('../../../') 

import numpy as np
import torch as th
from torch import nn
from torch import optim
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.autograd import grad
import learn2learn as l2l
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels

In [3]:
#def transformer_loss(L0,L1):
def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)

def transformbatch(batch, transformer):
    data, labels = batch
    data, labels = data.to(device), labels.to(device)
    data_transformed = transformer(data)
    return [data_transformed, labels]

def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evalutation sets
    adaptation_indices = th.zeros(data.size(0)).byte()
    adaptation_indices[th.arange(shots*ways) * 2] = 1
    print(adaptation_indices.size())
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    print(adaptation_labels.size())
    evaluation_data, evaluation_labels = data[1 - adaptation_indices], labels[1 - adaptation_indices]
    
    #before adaptation, get the needed parameters for adaptation of transformer
    L0 = loss(learner(adaptation_data), adaptation_labels)
    L1 = loss(learner(evaluation_data), evaluation_labels)
     # Adapt the model
    for step in range(adaptation_steps):
        train_error = loss(learner(adaptation_data), adaptation_labels)
        train_error /= len(adaptation_data)
        learner.adapt(train_error)

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    valid_error = loss(predictions, evaluation_labels)
   # print(valid_error)
    valid_error /= len(evaluation_data)
    valid_accuracy = accuracy(predictions, evaluation_labels)
    return valid_error, valid_accuracy, L0, L1

In [4]:
#variables to change
train_dataset = l2l.vision.datasets.MiniImagenet(root='../data', mode='train')
valid_dataset = l2l.vision.datasets.MiniImagenet(root='../data', mode='validation')
test_dataset = l2l.vision.datasets.MiniImagenet(root='../data', mode='test')
train_dataset = l2l.data.MetaDataset(train_dataset)
valid_dataset = l2l.data.MetaDataset(valid_dataset)
test_dataset = l2l.data.MetaDataset(test_dataset)

In [5]:
ways=5
shots=5
meta_lr=0.003
fast_lr=0.5
meta_batch_size=32
adaptation_steps=1
num_iterations=5
cuda=True
seed=42

In [6]:
random.seed(seed)
np.random.seed(seed)
th.manual_seed(seed)
device = th.device('cpu')
if cuda and th.cuda.device_count():
    th.cuda.manual_seed(seed)
    device = th.device('cuda')

In [7]:
train_transforms = [NWays(train_dataset, ways),KShots(train_dataset, 2*shots),LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
train_tasks = l2l.data.TaskDataset(train_dataset, task_transforms=train_transforms,num_tasks=20000)

In [8]:
valid_transforms = [NWays(valid_dataset, ways), KShots(valid_dataset, 2*shots), LoadData(valid_dataset),
        ConsecutiveLabels(train_dataset),
        RemapLabels(valid_dataset),
    ]
valid_tasks = l2l.data.TaskDataset(valid_dataset, task_transforms=valid_transforms, num_tasks=600)

In [9]:
test_transforms = [NWays(test_dataset, ways), KShots(test_dataset, 2*shots), LoadData(test_dataset), 
        RemapLabels(test_dataset),
        ConsecutiveLabels(train_dataset),
    ]
test_tasks = l2l.data.TaskDataset(test_dataset, task_transforms=test_transforms, num_tasks=600)

In [10]:
# Create models, 2 networks, one is the transformer (conditioned on gradients of minibatches)
# the other is the meta-learner doing MAML
embedding = 64
#same as miniImagenetCNN to learn the latent space/conditioned on the gradients,
#used size 64 like LEO encoder
transformer = l2l.vision.models.MiniImagenetCNN(embedding)
transformer.to(device)
metal = l2l.vision.models.LinearBlock(embedding, ways)
metal.to(device)
maml = l2l.algorithms.MAML(metal, lr=fast_lr, first_order=False)
opt = optim.Adam(maml.parameters(), meta_lr)
opt_transform = optim.Adam(transformer.parameters(), meta_lr)
loss = nn.CrossEntropyLoss(reduction='mean')
print(transformer)
print(metal)

MiniImagenetCNN(
  (base): ConvBase(
    (0): ConvBlock(
      (max_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): ConvBlock(
      (max_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (2): ConvBlock(
      (max_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1),

In [None]:
 for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error = 0.0
        meta_test_accuracy = 0.0
        for task in range(meta_batch_size):
            # Compute meta-training loss
            learner = maml.clone()
            batch = train_tasks.sample()
            print(batch[0].size())
            batch_transformed = transformbatch(batch, transformer)
            evaluation_error, evaluation_accuracy, L0, L1 = fast_adapt(batch_transformed, learner, loss, adaptation_steps, shots, ways,
                           device)
            
            #check the gradients, torch.autogra.grad does not accumulate gradients ;)
          #  print("before")
           # for p in maml.parameters():
            #    print(p.grad)
            gradients0 = grad(L0, maml.parameters(), retain_graph=True, create_graph=True)
           # print("after")
            #for p in maml.parameters():
             #   print(p.grad)
            #print("gradients")
            #print(gradients0)
            gradients1 = grad(L1, maml.parameters(), retain_graph=True, create_graph=True)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            learner = maml.clone()
            batch = valid_tasks.sample()
            batch_transformed = transformbatch(batch, transformer)
            evaluation_error, evaluation_accuracy, _, _ = fast_adapt(batch_transformed, learner,loss, adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

            # Compute meta-testing loss
            learner = maml.clone()
            batch = test_tasks.sample()
            batch_transformed = transformbatch(batch, transformer)
            evaluation_error, evaluation_accuracy, _,_ = fast_adapt(batch_transformed, learner, loss, adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            meta_test_error += evaluation_error.item()
            meta_test_accuracy += evaluation_accuracy.item()

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_batch_size)
        print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
        print('Meta Valid Error', meta_valid_error / meta_batch_size)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)
        print('Meta Test Error', meta_test_error / meta_batch_size)
        print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)

        # Average the accumulated gradients and optimize., last step of meta-gradient
        for p in maml.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()

torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
None
None
None
None
after
None
None
None
None
gradients
(tensor([ 0.0022, -0.0689,  0.0215,  0.0557,  0.0846], device='cuda:0',
       grad_fn=<CudnnBatchNormBackwardBackward>), tensor([ 0.0045, -0.0644,  0.0225,  0.0511,  0.0427], device='cuda:0',
       grad_fn=<CudnnBatchNormBackwardBackward>), tensor([[-0.0594, -0.4706, -0.0256, -0.7164,  0.2068, -0.2853,  0.0563, -0.2484,
         -0.2776, -0.1635, -0.1548, -0.4708,  0.6861, -0.8108, -0.0441, -0.5403,
         -0.4345,  0.3895, -1.3378,  0.1365,  0.8986,  0.7473, -0.2026, -0.2166,
          0.3864,  0.0314, -0.3231,  0.4612,  1.2349,  0.2354, -0.0862, -0.1428,
         -0.2070, -1.0633, -1.1912,  0.0122, -0.3351, -0.4603,  0.1450,  0.0349,
          0.6610, -0.6435, -0.8402,  1.2083,  0.1899, -0.3261,  0.0086, -0.2803,
         -0.0327, -0.2220, -0.5531,  0.2499,  0.5999, -0.4719, -0.5790,  0.9992,
         -0.1266,  0.0381,  0.4507, -0.0801, -0.3023,  0.2652,  0



torch.Size([50])
torch.Size([25])
torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([-0.0013, -0.0038,  0.0027,  0.0002,  0.0043], device='cuda:0')
tensor([-0.0007, -0.0016,  0.0016, -0.0009, -0.0002], device='cuda:0')
tensor([[ 1.2863e-02, -8.5229e-03, -3.4589e-03, -1.8406e-02,  4.2207e-02,
         -1.0658e-02,  1.1379e-02,  2.6383e-02, -1.8469e-02,  1.0257e-02,
         -2.0640e-02, -2.5471e-02,  3.8182e-02, -3.4612e-02, -2.3995e-02,
         -7.5102e-03, -1.3458e-02,  1.9898e-02, -7.9386e-02, -3.3519e-02,
          1.9469e-02,  3.7168e-02, -9.8294e-03,  1.3595e-02,  1.2333e-02,
          1.3435e-02, -1.0970e-02,  1.9749e-02,  7.3184e-02, -6.3881e-04,
         -1.3651e-02,  5.6752e-03, -9.2215e-03, -5.8860e-02, -2.5233e-02,
         -8.1557e-03, -1.5794e-02, -4.4173e-02, -1.3665e-02, -2.1507e-02,
          5.1636e-02, -1.8799e-02, -1.8702e-02,  3.1885e-02, -8.5048e-04,
         -5.6002e-02, -3.7429e-02, -2.4789e-02, -3.7624e-03, -3.3553e-02,
         -5.1792e-02,  3.3851e-02,  5.8055e-02, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([-0.0032, -0.0016,  0.0034,  0.0022,  0.0053], device='cuda:0')
tensor([-1.9351e-03, -2.8045e-03,  1.3531e-03,  9.1958e-05, -7.9565e-04],
       device='cuda:0')
tensor([[-0.0037, -0.0093, -0.0043, -0.0213,  0.0371, -0.0091, -0.0133,  0.0272,
         -0.0272,  0.0319, -0.0356, -0.0281,  0.0454, -0.0422, -0.0269, -0.0021,
         -0.0337,  0.0136, -0.0657, -0.0445,  0.0161,  0.0528, -0.0072,  0.0207,
          0.0133, -0.0076, -0.0105,  0.0599,  0.0931,  0.0104, -0.0436, -0.0021,
         -0.0161, -0.0663, -0.0388, -0.0149, -0.0168, -0.0344,  0.0033, -0.0034,
          0.0526, -0.0422, -0.0207,  0.0226,  0.0052, -0.0603, -0.0367, -0.0236,
          0.0105, -0.0276, -0.0435,  0.0197,  0.0650, -0.0222, -0.0214,  0.0836,
         -0.0561, -0.0677,  0.0686, -0.0087, -0.0283, -0.0273, -0.0110,  0.0731],
        [-0.0211, -0.0134,  0.0073,  0.0047, -0.0071,  0.0097,  0.0213,  0.0269,
         -0.0191,  0.0192, -0.01



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([-0.0020, -0.0042,  0.0017,  0.0090,  0.0121], device='cuda:0')
tensor([ 0.0007, -0.0058, -0.0010,  0.0036,  0.0047], device='cuda:0')
tensor([[-2.1788e-02,  8.2883e-03, -4.0329e-03, -3.1276e-02,  4.6859e-02,
         -2.7944e-02, -2.4630e-02,  2.1390e-02, -4.1173e-02,  3.6472e-02,
         -3.9077e-02, -1.9734e-02,  4.0855e-02, -4.6095e-02, -3.0617e-02,
         -1.3933e-02, -2.8515e-02,  4.1341e-02, -7.8511e-02, -4.0019e-02,
          1.7908e-02,  5.8335e-02, -1.7026e-02,  1.9054e-02,  2.6019e-02,
         -1.3837e-02, -3.7560e-03,  7.2127e-02,  1.0903e-01,  5.9862e-03,
         -3.8196e-02,  1.7205e-04, -2.5770e-02, -7.0822e-02, -2.9745e-02,
         -9.2662e-03, -1.9250e-03, -2.6153e-02, -1.1109e-02,  1.3873e-02,
          3.3052e-02, -3.1969e-02, -3.6404e-02,  2.5604e-02,  4.2005e-03,
         -5.7079e-02, -3.5518e-02, -1.9798e-02,  3.6606e-04, -2.1327e-02,
         -3.9341e-02,  1.2460e-03,  5.7579e-02, -



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([-0.0004, -0.0083,  0.0044,  0.0098,  0.0184], device='cuda:0')
tensor([ 0.0038, -0.0079,  0.0007,  0.0065,  0.0013], device='cuda:0')
tensor([[-3.4816e-02,  5.2309e-02, -3.2821e-02, -3.7538e-02,  2.0780e-02,
         -6.7195e-02, -2.4937e-02, -1.2593e-04,  2.2731e-04,  7.8442e-03,
          1.9688e-02,  1.3514e-02, -1.1662e-02, -1.2123e-02, -1.1117e-02,
         -3.2515e-02,  1.4715e-02,  6.6881e-03, -7.3809e-02, -5.8165e-02,
          3.0859e-03,  7.4775e-03,  1.5172e-02,  4.8273e-02,  6.5289e-02,
          1.2232e-02,  4.1942e-02,  2.0255e-02,  5.8533e-02, -2.8461e-02,
         -2.6647e-03,  5.3080e-04, -4.1345e-02, -5.6186e-02,  6.2907e-03,
         -2.3879e-02,  2.9110e-02,  1.0603e-02, -3.6131e-03,  2.9894e-02,
         -7.3245e-02, -6.3652e-02, -4.0439e-02,  1.0704e-02, -1.1183e-02,
         -1.1609e-02, -1.0450e-01,  2.9932e-02, -3.9403e-02,  3.3066e-02,
         -6.8390e-02, -3.5808e-02,  5.6975e-02,  



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([-0.0025, -0.0070,  0.0043,  0.0109,  0.0123], device='cuda:0')
tensor([ 0.0034, -0.0054, -0.0015,  0.0075, -0.0009], device='cuda:0')
tensor([[-2.7798e-03,  4.5380e-02,  5.0002e-02, -7.7736e-02,  1.5175e-02,
         -1.4637e-01, -4.0126e-02,  5.1021e-02,  2.6302e-02, -6.0594e-02,
          9.6242e-03, -4.6608e-02,  4.3923e-02, -3.4083e-02, -1.3146e-02,
         -8.7724e-02,  1.3574e-03,  3.2430e-02, -3.6946e-02, -2.1955e-02,
         -4.4696e-02,  7.8654e-02,  5.8812e-02,  5.5657e-03,  5.0266e-02,
          5.0381e-02,  7.4137e-02,  4.3527e-02,  5.7198e-02,  3.1681e-03,
          3.8678e-03,  1.0669e-02, -1.2267e-01, -5.6011e-02, -3.6949e-02,
         -7.0662e-03, -2.1514e-02,  4.1010e-02, -1.0208e-02,  1.0465e-02,
         -8.1642e-02, -1.8587e-02, -1.0396e-01,  1.3433e-02, -4.9917e-02,
         -3.2595e-02, -9.6989e-02,  4.5988e-02, -3.2622e-02, -5.0964e-02,
         -6.2525e-02, -7.9396e-02,  1.1348e-01, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0015, -0.0051,  0.0009,  0.0127,  0.0128], device='cuda:0')
tensor([ 0.0070, -0.0045, -0.0015,  0.0089,  0.0013], device='cuda:0')
tensor([[ 1.0268e-02,  7.5471e-02,  5.6455e-02, -6.1525e-02,  1.5565e-02,
         -1.3266e-01, -4.4643e-02,  5.9283e-03,  1.4208e-02, -1.2352e-01,
         -7.3073e-03, -4.2033e-02, -2.5912e-03, -2.4042e-02, -1.6194e-02,
         -1.0833e-01,  3.0528e-02,  1.8017e-03, -1.2466e-02, -3.2291e-02,
         -6.6876e-02,  7.9907e-02,  6.5185e-02,  1.5088e-02,  6.7731e-02,
          6.2742e-02,  8.0136e-02,  3.2884e-02,  7.4445e-02, -2.1673e-02,
         -1.9366e-03,  4.2364e-03, -1.2724e-01, -4.8375e-02, -2.1147e-02,
         -3.1754e-02, -6.5266e-03,  5.0195e-02,  2.7631e-02, -1.2222e-04,
         -1.0460e-01, -2.1618e-02, -9.0810e-02,  1.4777e-02, -3.9034e-02,
          4.8863e-03, -9.4082e-02,  5.6985e-02, -4.5713e-02, -3.1940e-02,
         -5.9618e-02, -8.5588e-02,  1.2599e-01,  



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([0.0015, 0.0011, 0.0031, 0.0163, 0.0116], device='cuda:0')
tensor([ 0.0069, -0.0039, -0.0013,  0.0125, -0.0004], device='cuda:0')
tensor([[ 6.3013e-03,  6.8304e-02,  6.6665e-02, -6.0166e-02,  5.1691e-03,
         -1.4040e-01, -5.0369e-02,  4.9134e-03,  1.1150e-02, -1.3210e-01,
         -2.3096e-02, -4.6190e-02,  1.7656e-02, -1.7821e-02, -9.9868e-03,
         -1.0199e-01,  2.2647e-02, -1.7236e-02, -5.9137e-03, -3.2682e-02,
         -7.5464e-02,  8.8786e-02,  6.7952e-02,  1.8015e-02,  6.7724e-02,
          6.1856e-02,  8.2932e-02,  3.9836e-02,  6.5181e-02, -1.3359e-02,
         -1.1775e-02,  3.2947e-03, -1.2234e-01, -4.2226e-02, -2.6719e-02,
         -1.7439e-02,  8.8582e-05,  5.4605e-02,  2.7362e-02, -1.3545e-02,
         -1.0142e-01, -1.8222e-02, -8.6790e-02,  6.1067e-03, -3.3204e-02,
         -3.3727e-03, -9.2035e-02,  6.5860e-02, -5.1172e-02, -2.5875e-02,
         -5.4414e-02, -9.0296e-02,  1.1439e-01,  2.696



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0046, -0.0036, -0.0009,  0.0119,  0.0128], device='cuda:0')
tensor([ 0.0076, -0.0058, -0.0049,  0.0093,  0.0006], device='cuda:0')
tensor([[ 1.7139e-03,  8.0644e-02,  6.4108e-02, -5.4161e-02,  3.5105e-03,
         -1.3791e-01, -6.0853e-02, -1.0022e-03,  9.1918e-03, -1.3699e-01,
         -2.0438e-02, -4.4749e-02,  1.2811e-02, -2.4022e-02, -1.8056e-02,
         -1.1505e-01,  2.4489e-02, -2.6674e-02, -1.4735e-02, -4.0606e-02,
         -7.3058e-02,  8.5368e-02,  5.7448e-02,  2.2951e-02,  6.1261e-02,
          6.8875e-02,  8.5710e-02,  3.5456e-02,  5.5166e-02, -1.4930e-02,
         -6.6674e-03, -3.4811e-03, -1.2178e-01, -4.1674e-02, -1.8213e-02,
         -1.8702e-02,  3.2415e-04,  4.7279e-02,  3.5021e-02, -2.6268e-02,
         -1.0483e-01, -2.6856e-02, -9.1301e-02, -3.8708e-04, -2.3047e-02,
          3.1490e-03, -9.7417e-02,  5.9454e-02, -4.2002e-02, -2.2032e-02,
         -5.2919e-02, -8.7724e-02,  9.6390e-02,  



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0040, -0.0035, -0.0005,  0.0038,  0.0164], device='cuda:0')
tensor([ 0.0080, -0.0064, -0.0045,  0.0045,  0.0029], device='cuda:0')
tensor([[-2.0569e-02,  1.1564e-01,  6.5490e-02, -4.1395e-02, -9.4845e-03,
         -1.3685e-01, -7.6148e-02, -1.6275e-02, -2.2718e-03, -1.4305e-01,
         -7.2911e-03, -4.3641e-02, -1.0652e-02, -4.8238e-02, -1.1762e-02,
         -1.2614e-01,  1.2654e-02, -3.3466e-02, -1.6009e-02, -4.7967e-02,
         -6.7680e-02,  7.3247e-02,  3.8934e-02,  1.3663e-02,  7.0695e-02,
          8.0178e-02,  1.1496e-01,  5.3021e-02,  5.8250e-02, -3.0463e-02,
         -1.1093e-02, -8.6111e-03, -1.1880e-01, -4.3419e-02, -8.5187e-03,
         -1.2076e-02, -6.5428e-04,  4.3950e-02,  6.5043e-02, -1.3881e-02,
         -1.3079e-01, -4.3309e-02, -9.9820e-02,  7.0364e-03, -9.0615e-03,
          1.1025e-02, -9.3530e-02,  5.4484e-02, -6.0576e-02, -1.3665e-02,
         -5.8230e-02, -9.8491e-02,  5.2480e-02,  



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0097, -0.0019,  0.0032,  0.0008,  0.0196], device='cuda:0')
tensor([ 0.0107, -0.0045, -0.0044,  0.0040,  0.0029], device='cuda:0')
tensor([[-3.6333e-02,  1.2093e-01,  6.2144e-02, -3.9040e-02, -1.1253e-02,
         -1.5256e-01, -9.1778e-02, -1.4270e-02,  1.0327e-03, -1.3420e-01,
          3.4096e-04, -4.4764e-02, -7.2727e-03, -6.1270e-02, -1.0393e-02,
         -1.2160e-01,  1.3864e-02, -4.6286e-02, -1.3160e-02, -6.4411e-02,
         -5.9064e-02,  7.9873e-02,  2.9226e-02,  1.3154e-02,  7.0320e-02,
          8.0093e-02,  1.1484e-01,  4.2263e-02,  4.8762e-02, -2.6793e-02,
          5.9134e-03, -1.3483e-03, -1.1349e-01, -5.2091e-02, -4.8130e-03,
         -1.9870e-02,  4.0401e-03,  5.0420e-02,  7.6251e-02, -2.7767e-02,
         -1.3906e-01, -5.3169e-02, -9.6630e-02,  1.2720e-02, -2.1221e-02,
          3.6003e-03, -8.9081e-02,  6.6958e-02, -4.9551e-02, -1.4850e-02,
         -6.0422e-02, -9.4393e-02,  7.2764e-02,  



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0119, -0.0046,  0.0026, -0.0023,  0.0193], device='cuda:0')
tensor([ 1.3064e-02, -7.4636e-03, -4.9817e-03,  3.3994e-03,  2.6408e-05],
       device='cuda:0')
tensor([[-4.1063e-02,  1.2686e-01,  5.7880e-02, -4.3712e-02, -2.5176e-02,
         -1.5390e-01, -9.8659e-02, -2.0914e-02,  1.4137e-03, -1.3677e-01,
          8.3611e-03, -3.9273e-02, -1.1867e-02, -6.2228e-02, -5.5497e-03,
         -1.2740e-01,  1.3729e-02, -3.9373e-02, -7.8618e-03, -5.9316e-02,
         -6.8316e-02,  7.1444e-02,  2.7458e-02,  4.3041e-03,  7.7388e-02,
          8.1806e-02,  1.1849e-01,  4.3322e-02,  4.6000e-02, -3.7888e-02,
          1.0171e-02, -9.8863e-04, -1.1423e-01, -4.9581e-02, -2.6485e-03,
         -1.8465e-02,  6.0794e-03,  5.5058e-02,  7.9943e-02, -3.3566e-02,
         -1.5015e-01, -5.4618e-02, -9.3411e-02,  1.6055e-02, -1.9562e-02,
          8.9889e-03, -8.4231e-02,  7.1679e-02, -6.0046e-02, -1.3709e-02,
         -6.0898e-02, 



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0128, -0.0052,  0.0006, -0.0070,  0.0154], device='cuda:0')
tensor([ 0.0141, -0.0071, -0.0067,  0.0030, -0.0021], device='cuda:0')
tensor([[ 2.7270e-02,  1.3864e-03,  9.3057e-02, -7.0515e-02,  6.5777e-02,
         -1.2891e-01, -7.9681e-02,  5.1559e-02, -1.1314e-02, -7.7828e-02,
         -9.6921e-02, -7.0242e-02,  8.1229e-02, -4.6204e-02, -1.0059e-02,
         -7.6582e-02, -1.4681e-02, -4.0951e-02,  2.0037e-04, -7.8093e-02,
         -1.0437e-01,  1.1986e-01,  4.8670e-02,  2.8426e-03,  3.7622e-02,
          4.8781e-02,  2.7366e-02,  1.6937e-02,  3.8113e-02, -5.5047e-03,
          2.0486e-02,  1.7397e-03, -1.2115e-01, -2.3803e-02, -5.0948e-02,
         -7.0481e-02, -1.5553e-02,  6.3210e-02,  5.9196e-02, -6.4270e-02,
         -3.3053e-02,  2.5455e-02, -4.0229e-02,  1.3882e-02, -4.1585e-02,
          4.8594e-03, -1.2474e-01,  2.5735e-02,  1.1122e-02, -6.3566e-02,
         -5.3660e-02, -9.4430e-02,  1.4889e-01,  



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0096, -0.0058,  0.0011,  0.0015,  0.0197], device='cuda:0')
tensor([ 0.0119, -0.0086, -0.0066,  0.0059,  0.0035], device='cuda:0')
tensor([[-4.5080e-02,  3.9317e-02,  6.8927e-02, -4.6296e-02, -3.1558e-03,
         -1.1149e-01, -1.3452e-01,  2.2743e-02, -1.0713e-02, -9.5136e-02,
         -9.1632e-02, -7.7077e-02,  6.4948e-02, -1.2460e-01,  8.0735e-03,
         -9.5836e-02, -4.4520e-02, -2.3979e-03, -2.0598e-02, -1.0932e-01,
         -5.4192e-02,  1.3290e-01,  1.6223e-02, -6.9386e-03,  6.1675e-02,
          7.3533e-02,  5.1090e-02,  3.2283e-02,  9.3593e-02, -1.2316e-02,
          2.4300e-02,  3.5127e-04, -1.1697e-01, -6.5408e-02, -4.4101e-02,
         -2.6954e-02, -2.7597e-02,  5.3110e-02,  9.8801e-02, -7.1773e-02,
         -8.0327e-02, -3.3623e-02, -8.2849e-02,  3.4866e-02,  2.8954e-03,
         -1.3263e-02, -9.8777e-02,  1.9280e-02, -1.6962e-02, -2.7615e-02,
         -3.3128e-02, -1.5167e-01,  1.3746e-01,  



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0119, -0.0032,  0.0020, -0.0003,  0.0249], device='cuda:0')
tensor([ 0.0133, -0.0081, -0.0073,  0.0054,  0.0065], device='cuda:0')
tensor([[-3.7336e-02,  1.8777e-02,  9.5923e-02, -5.6305e-02,  1.5419e-02,
         -1.1701e-01, -1.3250e-01,  4.1565e-02, -3.8442e-03, -6.4619e-02,
         -1.1218e-01, -7.4153e-02,  8.9607e-02, -1.1014e-01,  2.0141e-02,
         -1.0279e-01, -5.3470e-02, -5.6144e-03, -4.0474e-02, -9.4425e-02,
         -5.0708e-02,  1.5969e-01,  1.0181e-02, -2.1193e-02,  5.6757e-02,
          8.6250e-02,  5.1697e-02,  2.3974e-02,  1.0303e-01, -1.9338e-02,
          2.0769e-02,  3.8252e-03, -1.1557e-01, -8.8559e-02, -5.2184e-02,
         -3.6957e-02, -4.8941e-02,  5.7713e-02,  9.9889e-02, -9.0207e-02,
         -6.4631e-02, -2.9706e-02, -7.7692e-02,  4.4322e-02,  1.6595e-02,
         -2.9095e-03, -1.0556e-01,  2.0850e-02, -2.2009e-02, -6.0384e-02,
         -4.7369e-02, -1.4403e-01,  1.2779e-01,  



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0112, -0.0035, -0.0014,  0.0037,  0.0258], device='cuda:0')
tensor([ 0.0124, -0.0077, -0.0101,  0.0069,  0.0040], device='cuda:0')
tensor([[-3.9323e-02,  2.3585e-02,  8.2110e-02, -5.2564e-02,  5.8850e-03,
         -1.5692e-01, -1.2754e-01,  6.7588e-02,  1.5189e-03, -7.8382e-02,
         -1.0067e-01, -6.4319e-02,  6.7158e-02, -1.3097e-01,  2.1524e-02,
         -1.0357e-01, -4.9491e-02,  8.4980e-03, -3.3290e-02, -1.2935e-01,
         -7.2795e-02,  1.6457e-01,  3.3164e-02, -2.2277e-02,  6.4556e-02,
          8.9606e-02,  4.9580e-02,  5.3179e-02,  9.1550e-02, -2.6241e-02,
          1.1941e-02,  1.3070e-02, -1.1913e-01, -1.0218e-01, -6.3135e-02,
         -2.1558e-02, -4.1769e-02,  5.6091e-02,  1.1797e-01, -1.0978e-01,
         -6.1285e-02, -4.3127e-02, -8.0189e-02,  5.3220e-02,  9.1202e-03,
         -4.8949e-02, -1.4321e-01,  4.5769e-02, -1.1112e-02, -7.4937e-02,
         -6.1846e-02, -1.4594e-01,  1.4513e-01,  



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0170, -0.0035, -0.0045,  0.0016,  0.0277], device='cuda:0')
tensor([ 0.0157, -0.0078, -0.0132,  0.0066,  0.0048], device='cuda:0')
tensor([[-0.0365,  0.0144,  0.0930, -0.0451,  0.0142, -0.1561, -0.1340,  0.0773,
          0.0119, -0.0656, -0.1033, -0.0660,  0.0720, -0.1306,  0.0173, -0.0906,
         -0.0479,  0.0104, -0.0414, -0.1168, -0.0869,  0.1673,  0.0428, -0.0207,
          0.0561,  0.0725,  0.0220,  0.0555,  0.0974, -0.0140,  0.0103,  0.0153,
         -0.1216, -0.0983, -0.0683, -0.0177, -0.0484,  0.0759,  0.1150, -0.1210,
         -0.0551, -0.0482, -0.0748,  0.0638,  0.0084, -0.0657, -0.1450,  0.0471,
         -0.0023, -0.0852, -0.0769, -0.1540,  0.1507,  0.0274, -0.0793,  0.0768,
         -0.0365,  0.0733,  0.0153, -0.0545, -0.0122, -0.0066,  0.0501,  0.1193],
        [-0.1275,  0.0890,  0.0337,  0.0471, -0.1156, -0.0146, -0.0704, -0.0152,
         -0.1157,  0.0732, -0.0900,  0.0481,  0.0135, -0.02



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0121, -0.0012, -0.0075,  0.0024,  0.0245], device='cuda:0')
tensor([ 0.0131, -0.0054, -0.0156,  0.0054,  0.0041], device='cuda:0')
tensor([[-3.6976e-02,  1.3891e-02,  9.3380e-02, -3.8034e-02,  1.3753e-02,
         -1.4965e-01, -1.3705e-01,  8.4186e-02,  1.1416e-02, -6.6019e-02,
         -1.0478e-01, -6.8693e-02,  7.1894e-02, -1.2087e-01,  2.1682e-02,
         -8.5971e-02, -4.4909e-02,  2.6740e-03, -3.9664e-02, -1.0938e-01,
         -8.3875e-02,  1.6041e-01,  4.2569e-02, -2.0667e-02,  5.4764e-02,
          8.1152e-02,  2.0385e-02,  4.4110e-02,  9.3944e-02, -1.5798e-02,
          1.2253e-02,  1.9426e-02, -1.1866e-01, -8.9102e-02, -6.7351e-02,
         -1.6703e-02, -4.9143e-02,  7.1735e-02,  1.2028e-01, -1.2279e-01,
         -4.7357e-02, -5.1470e-02, -8.2203e-02,  7.6297e-02,  5.6618e-03,
         -6.5985e-02, -1.4711e-01,  3.6385e-02,  3.9110e-03, -7.6438e-02,
         -6.7082e-02, -1.5424e-01,  1.5195e-01,  



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0113, -0.0029, -0.0032,  0.0062,  0.0210], device='cuda:0')
tensor([ 0.0132, -0.0061, -0.0147,  0.0062,  0.0004], device='cuda:0')
tensor([[-0.0528,  0.0376,  0.0925, -0.0362, -0.0010, -0.1491, -0.1487,  0.0683,
          0.0139, -0.0699, -0.1112, -0.0634,  0.0764, -0.1236,  0.0170, -0.0917,
         -0.0497,  0.0011, -0.0314, -0.1128, -0.0785,  0.1645,  0.0324, -0.0230,
          0.0734,  0.0890,  0.0216,  0.0372,  0.0875, -0.0209,  0.0217,  0.0174,
         -0.1126, -0.0851, -0.0471, -0.0095, -0.0526,  0.0663,  0.1418, -0.1216,
         -0.0650, -0.0565, -0.0705,  0.0779,  0.0111, -0.0575, -0.1408,  0.0407,
          0.0109, -0.0529, -0.0678, -0.1594,  0.1436,  0.0230, -0.0865,  0.0720,
         -0.0277,  0.0965,  0.0012, -0.0809, -0.0028, -0.0074,  0.0439,  0.0995],
        [-0.1256,  0.0943,  0.0260,  0.0966, -0.1150,  0.0376, -0.0635,  0.0610,
         -0.1805,  0.1166, -0.1297,  0.0757, -0.0197, -0.04



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0082, -0.0084, -0.0045,  0.0125,  0.0258], device='cuda:0')
tensor([ 0.0123, -0.0085, -0.0167,  0.0109,  0.0033], device='cuda:0')
tensor([[-3.6137e-02,  2.9782e-02,  9.8512e-02, -4.9818e-02, -9.4955e-03,
         -1.5229e-01, -1.4468e-01,  6.3311e-02,  9.2589e-03, -7.1217e-02,
         -1.0923e-01, -6.5250e-02,  9.2652e-02, -1.3122e-01,  6.6209e-03,
         -9.0939e-02, -3.3958e-02, -1.3364e-02, -4.3324e-02, -1.0989e-01,
         -7.0130e-02,  1.7860e-01,  2.8001e-02, -3.9425e-02,  6.8983e-02,
          9.0711e-02,  3.7697e-04,  4.1914e-02,  8.7563e-02, -1.1370e-02,
          1.4910e-02,  2.4733e-02, -1.0665e-01, -8.1905e-02, -5.5583e-02,
         -4.1617e-03, -5.5443e-02,  7.9470e-02,  1.4181e-01, -1.2018e-01,
         -5.8272e-02, -4.6490e-02, -7.5638e-02,  7.3656e-02,  1.6550e-02,
         -6.6991e-02, -1.3780e-01,  4.9468e-02,  1.2312e-02, -4.0600e-02,
         -7.3766e-02, -1.5981e-01,  1.5454e-01,  



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0075, -0.0103,  0.0007,  0.0154,  0.0259], device='cuda:0')
tensor([ 0.0114, -0.0097, -0.0142,  0.0136,  0.0045], device='cuda:0')
tensor([[-0.0125, -0.0328,  0.1194, -0.0543,  0.0185, -0.1456, -0.1311,  0.0852,
          0.0129, -0.0546, -0.1326, -0.0906,  0.1388, -0.1167,  0.0072, -0.0761,
         -0.0362, -0.0108, -0.0393, -0.0806, -0.0972,  0.1926,  0.0408, -0.0312,
          0.0506,  0.0689, -0.0208,  0.0213,  0.0956,  0.0033,  0.0085,  0.0277,
         -0.1154, -0.0633, -0.0734,  0.0044, -0.0721,  0.0914,  0.1040, -0.1236,
         -0.0145, -0.0156, -0.0755,  0.0553,  0.0291, -0.0559, -0.1462,  0.0538,
          0.0387, -0.0720, -0.0799, -0.1544,  0.1742, -0.0256, -0.0887,  0.1175,
         -0.0563,  0.0562,  0.0044, -0.0254, -0.0218, -0.0128,  0.0532,  0.1131],
        [-0.1071,  0.1376, -0.0221,  0.1197, -0.1233,  0.0132, -0.0614,  0.0572,
         -0.2137,  0.1467, -0.1338,  0.0684, -0.0192, -0.08



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0034, -0.0097,  0.0073,  0.0135,  0.0275], device='cuda:0')
tensor([ 0.0107, -0.0090, -0.0117,  0.0113,  0.0050], device='cuda:0')
tensor([[-5.4551e-02,  3.7710e-03,  1.3926e-01, -7.6057e-02, -1.4002e-02,
         -1.4601e-01, -1.3183e-01,  7.9398e-02,  1.8192e-02, -3.1359e-02,
         -1.5341e-01, -9.1218e-02,  1.3454e-01, -1.0908e-01,  1.2060e-02,
         -1.0393e-01,  1.1304e-02,  3.3826e-02, -6.1070e-02, -1.2664e-01,
         -5.5081e-02,  1.9807e-01, -1.3842e-02, -7.0898e-02,  5.0004e-02,
          1.1769e-01, -4.8860e-04,  6.9650e-02,  1.1552e-01, -5.0353e-02,
          1.7179e-02,  6.1951e-02, -1.0886e-01, -7.6059e-02, -1.4898e-02,
          3.1792e-02, -7.8020e-02,  6.3160e-02,  1.3796e-01, -9.1011e-02,
         -5.2189e-02, -2.6179e-02, -1.0182e-01,  4.4146e-02,  5.4163e-02,
         -6.9696e-02, -1.3288e-01,  6.1388e-02,  1.9741e-02, -6.6401e-03,
         -1.0337e-01, -1.5581e-01,  1.4226e-01, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0089, -0.0101,  0.0069,  0.0134,  0.0229], device='cuda:0')
tensor([ 0.0147, -0.0097, -0.0103,  0.0110,  0.0041], device='cuda:0')
tensor([[-5.2995e-02,  4.1848e-03,  1.4440e-01, -7.3470e-02, -1.9212e-02,
         -1.4470e-01, -1.3021e-01,  8.0515e-02,  2.0243e-02, -2.6674e-02,
         -1.4160e-01, -8.7611e-02,  1.1638e-01, -1.0535e-01,  2.5131e-02,
         -9.9035e-02,  2.1014e-02,  2.7175e-02, -5.2566e-02, -1.2899e-01,
         -6.1254e-02,  1.8670e-01, -9.9654e-03, -5.0338e-02,  5.6899e-02,
          1.1695e-01,  1.3997e-02,  5.7951e-02,  9.6909e-02, -4.5952e-02,
          2.1111e-02,  5.2557e-02, -1.1427e-01, -7.6157e-02, -1.8566e-02,
          4.1443e-02, -7.2414e-02,  8.2758e-02,  1.4153e-01, -8.9537e-02,
         -5.0730e-02, -2.0010e-02, -9.3981e-02,  4.4247e-02,  4.7164e-02,
         -5.8210e-02, -1.1393e-01,  6.7785e-02,  2.3517e-02, -2.4780e-03,
         -1.0309e-01, -1.5835e-01,  1.2718e-01, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0149, -0.0094,  0.0090,  0.0154,  0.0312], device='cuda:0')
tensor([ 0.0161, -0.0089, -0.0093,  0.0145,  0.0099], device='cuda:0')
tensor([[-4.4594e-02,  4.0889e-03,  1.4616e-01, -7.0269e-02, -1.8088e-02,
         -1.4151e-01, -1.2784e-01,  7.9763e-02,  2.4923e-02, -7.0173e-03,
         -1.3880e-01, -9.7463e-02,  1.1707e-01, -1.1191e-01,  2.3475e-02,
         -8.0039e-02,  1.1553e-02,  2.5413e-02, -5.0238e-02, -1.1528e-01,
         -5.5785e-02,  1.9070e-01, -7.5055e-03, -4.8282e-02,  5.1905e-02,
          1.3239e-01,  4.6157e-03,  4.5839e-02,  1.1809e-01, -4.2863e-02,
          2.6399e-02,  5.3281e-02, -1.0517e-01, -8.3694e-02, -3.6055e-02,
          3.2030e-02, -8.9925e-02,  8.2522e-02,  1.3936e-01, -8.8086e-02,
         -4.2920e-02,  2.8797e-04, -9.4392e-02,  4.7356e-02,  4.5759e-02,
         -6.9816e-02, -1.1285e-01,  6.8247e-02,  1.3259e-02, -1.0774e-02,
         -1.0835e-01, -1.4907e-01,  1.3464e-01, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0198, -0.0100,  0.0074,  0.0162,  0.0301], device='cuda:0')
tensor([ 0.0185, -0.0107, -0.0111,  0.0143,  0.0076], device='cuda:0')
tensor([[-5.4720e-02,  3.1614e-04,  1.4557e-01, -7.4781e-02, -3.1124e-02,
         -1.4985e-01, -1.3520e-01,  6.9295e-02,  2.5422e-02, -1.1660e-02,
         -1.2920e-01, -1.0019e-01,  1.1756e-01, -1.0791e-01,  2.5948e-02,
         -9.7120e-02, -3.6305e-05,  2.9562e-02, -4.3019e-02, -1.1197e-01,
         -4.6031e-02,  1.8745e-01, -8.9725e-03, -5.3570e-02,  4.6092e-02,
          1.4302e-01,  4.6009e-03,  5.5761e-02,  1.2225e-01, -4.5053e-02,
          3.5634e-02,  5.4696e-02, -1.0133e-01, -9.1869e-02, -3.3307e-02,
          3.6555e-02, -9.8690e-02,  7.6557e-02,  1.3802e-01, -9.8857e-02,
         -5.0433e-02, -3.2591e-03, -8.8075e-02,  4.7193e-02,  5.3613e-02,
         -6.6892e-02, -1.0531e-01,  5.4743e-02,  1.9565e-02, -9.8059e-03,
         -1.1869e-01, -1.4773e-01,  1.1836e-01, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0178, -0.0114,  0.0114,  0.0229,  0.0317], device='cuda:0')
tensor([ 0.0166, -0.0115, -0.0067,  0.0185,  0.0085], device='cuda:0')
tensor([[-1.0434e-01,  2.0768e-02,  1.2563e-01, -8.3198e-04, -7.2209e-02,
         -1.4740e-01, -1.0445e-01,  6.5963e-02,  1.3124e-02,  2.3857e-02,
         -1.4553e-01, -6.6562e-02,  1.1000e-01, -5.3043e-02,  4.5272e-02,
         -9.0169e-02, -6.9100e-02,  6.5832e-02, -4.5250e-02, -1.6789e-01,
         -1.9890e-02,  2.5181e-01, -5.0265e-02, -4.5452e-02,  5.7657e-02,
          1.6809e-01, -4.2321e-03,  7.0776e-02,  1.4377e-01, -2.8323e-02,
          3.6489e-02,  4.2219e-02, -9.7374e-02, -1.2432e-01, -9.2101e-02,
         -1.9486e-04, -1.3051e-01, -1.4502e-02,  1.5398e-01, -1.1105e-01,
         -4.1813e-02,  6.1751e-03, -8.2524e-02,  5.9273e-02,  3.7371e-02,
         -1.2255e-01, -7.0693e-02,  4.0545e-02,  6.0473e-02, -2.7115e-02,
         -1.1734e-01, -1.8742e-01,  1.4790e-01, -



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0235, -0.0076,  0.0140,  0.0246,  0.0312], device='cuda:0')
tensor([ 0.0171, -0.0084, -0.0043,  0.0176,  0.0075], device='cuda:0')
tensor([[ 1.7856e-02, -1.3626e-01,  1.5671e-01,  7.0946e-02,  8.0214e-03,
         -1.0875e-01, -5.9252e-02,  1.5182e-01,  9.0851e-03,  1.1898e-01,
         -2.7046e-01, -8.5163e-02,  1.6583e-01, -3.2661e-02,  5.7629e-02,
          1.2730e-02, -4.5199e-02,  8.8211e-02, -5.9216e-02, -1.4514e-01,
         -4.9182e-02,  3.0652e-01, -9.3267e-03,  1.2438e-03, -1.7732e-02,
          8.0964e-02, -1.3239e-01,  8.8513e-02,  1.4897e-01,  3.4171e-02,
         -2.2833e-02,  1.6364e-02, -1.2457e-01, -8.4707e-02, -1.7190e-01,
         -1.8826e-02, -1.6477e-01,  6.3249e-02,  9.9629e-02, -1.4083e-01,
          6.5775e-02,  1.0578e-01, -8.0366e-02,  5.5928e-02,  1.9557e-02,
         -1.5761e-01, -1.0122e-01, -1.3790e-02,  1.7178e-01, -7.8295e-02,
         -1.0064e-01, -1.6591e-01,  2.5039e-01, -



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0260, -0.0048,  0.0165,  0.0250,  0.0366], device='cuda:0')
tensor([ 0.0199, -0.0072, -0.0018,  0.0168,  0.0133], device='cuda:0')
tensor([[ 4.1767e-02, -1.8858e-01,  1.3284e-01,  9.9445e-02,  3.9154e-02,
         -9.2541e-02, -4.8519e-02,  1.6946e-01, -9.6210e-03,  1.3618e-01,
         -2.7040e-01, -8.3474e-02,  1.6142e-01, -4.0211e-02,  7.6314e-02,
          5.7526e-02, -4.2097e-02,  8.0314e-02, -5.4510e-02, -1.5312e-01,
         -3.4843e-02,  3.3210e-01, -1.5434e-02, -3.8149e-03, -5.4956e-02,
          4.0654e-02, -1.2701e-01,  1.2957e-01,  1.5838e-01,  5.4970e-02,
         -6.2219e-02,  3.0089e-02, -1.1845e-01, -9.3449e-02, -2.1635e-01,
         -4.2691e-02, -1.7460e-01,  6.6908e-02,  9.0101e-02, -1.5866e-01,
          8.6371e-02,  1.1612e-01, -9.5400e-02,  4.6366e-02, -2.1323e-03,
         -1.9132e-01, -1.1709e-01, -2.3748e-02,  1.9716e-01, -1.0430e-01,
         -1.1434e-01, -1.8055e-01,  2.7453e-01, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0325, -0.0095,  0.0144,  0.0256,  0.0391], device='cuda:0')
tensor([ 0.0226, -0.0078, -0.0018,  0.0170,  0.0134], device='cuda:0')
tensor([[-2.5677e-02, -1.4425e-01,  1.3935e-01,  1.4784e-01, -1.9164e-03,
         -9.2910e-02, -4.3872e-02,  1.1515e-01, -2.6155e-02,  1.0252e-01,
         -2.4298e-01, -8.2099e-02,  1.0618e-01, -4.5600e-02,  9.3958e-02,
          2.8357e-02, -2.5917e-02,  6.6467e-02, -6.8326e-02, -2.1218e-01,
         -1.1846e-02,  3.3393e-01, -6.4669e-02,  1.1207e-03, -8.4996e-02,
          4.0533e-02, -9.2677e-02,  1.6418e-01,  1.5069e-01,  4.5581e-02,
         -6.0648e-02,  4.5322e-02, -1.2123e-01, -9.0289e-02, -2.0953e-01,
         -3.4606e-02, -1.8070e-01,  6.5612e-02,  1.1013e-01, -1.4484e-01,
          3.1391e-02,  1.0279e-01, -1.0756e-01,  3.7104e-02, -1.6292e-02,
         -1.9823e-01, -9.6636e-02, -2.5654e-02,  2.1636e-01, -5.2945e-02,
         -1.0739e-01, -2.3738e-01,  2.4379e-01, -



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0352, -0.0076,  0.0141,  0.0172,  0.0439], device='cuda:0')
tensor([ 0.0245, -0.0075, -0.0025,  0.0115,  0.0173], device='cuda:0')
tensor([[-1.7190e-02, -1.5270e-01,  1.3443e-01,  1.3326e-01,  4.0774e-03,
         -9.4599e-02, -3.7095e-02,  1.2536e-01, -1.4987e-02,  1.0286e-01,
         -2.4280e-01, -8.5654e-02,  1.1720e-01, -4.7925e-02,  9.5195e-02,
          3.0859e-02, -3.0676e-02,  8.1297e-02, -6.5553e-02, -2.1276e-01,
         -2.2907e-02,  3.4142e-01, -5.3168e-02,  2.1781e-04, -8.7051e-02,
          4.2045e-02, -9.8832e-02,  1.5552e-01,  1.5800e-01,  5.5922e-02,
         -6.1281e-02,  4.7145e-02, -1.2078e-01, -1.0036e-01, -2.1911e-01,
         -3.6384e-02, -1.9066e-01,  5.8352e-02,  1.0212e-01, -1.4918e-01,
          4.5069e-02,  1.1187e-01, -1.0309e-01,  3.3878e-02, -1.9226e-02,
         -1.8601e-01, -1.0053e-01, -2.5353e-02,  2.1702e-01, -6.8967e-02,
         -1.1403e-01, -2.3231e-01,  2.5352e-01, -



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0325, -0.0068,  0.0169,  0.0175,  0.0457], device='cuda:0')
tensor([ 0.0228, -0.0060, -0.0014,  0.0132,  0.0184], device='cuda:0')
tensor([[ 5.3760e-03, -1.9604e-01,  1.4291e-01,  1.2428e-01,  3.3293e-03,
         -1.0791e-01, -2.7702e-02,  1.3109e-01, -3.2269e-02,  1.3564e-01,
         -2.3097e-01, -1.1794e-01,  1.4115e-01, -3.9401e-02,  9.4953e-02,
          6.1948e-02, -5.9889e-02,  8.0045e-02, -7.1925e-02, -1.9161e-01,
         -3.7245e-02,  3.5323e-01, -4.9130e-02, -9.7835e-03, -1.0148e-01,
          4.3657e-02, -1.0274e-01,  1.5277e-01,  1.8650e-01,  5.9956e-02,
         -8.0155e-02,  4.4006e-02, -1.1901e-01, -1.0272e-01, -2.3591e-01,
         -2.9095e-02, -2.0738e-01,  7.8769e-02,  1.0197e-01, -1.5232e-01,
          4.5126e-02,  1.3232e-01, -1.1712e-01,  1.6734e-02, -2.1714e-02,
         -1.7312e-01, -9.1947e-02, -3.0687e-02,  2.2527e-01, -9.7604e-02,
         -1.1567e-01, -2.4060e-01,  3.0344e-01, -



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0361, -0.0081,  0.0141,  0.0164,  0.0467], device='cuda:0')
tensor([ 0.0208, -0.0079, -0.0035,  0.0119,  0.0201], device='cuda:0')
tensor([[ 1.2018e-02, -1.8954e-01,  1.4231e-01,  1.3188e-01,  6.2326e-03,
         -1.0687e-01, -2.6362e-02,  1.2934e-01, -3.7346e-02,  1.3358e-01,
         -2.3364e-01, -1.1569e-01,  1.4319e-01, -4.2857e-02,  9.9971e-02,
          6.5340e-02, -5.9024e-02,  8.9250e-02, -7.1013e-02, -1.9081e-01,
         -3.3031e-02,  3.6064e-01, -4.7400e-02, -9.1201e-03, -1.0950e-01,
          5.0891e-02, -1.0795e-01,  1.4928e-01,  1.9199e-01,  5.7083e-02,
         -7.0063e-02,  4.0220e-02, -1.1616e-01, -1.0625e-01, -2.4388e-01,
         -3.4589e-02, -2.0569e-01,  8.1462e-02,  1.0336e-01, -1.5003e-01,
          5.3323e-02,  1.3150e-01, -1.1669e-01,  1.5277e-02, -1.5113e-02,
         -1.7371e-01, -9.1510e-02, -3.8738e-02,  2.2474e-01, -9.4812e-02,
         -1.1540e-01, -2.3972e-01,  3.0834e-01, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([0., 0., 0., 0., 0.], device='cuda:0')
tensor([0., 0., 0., 0., 0.], device='cuda:0')
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 



torch.Size([50])
torch.Size([25])
torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0061,  0.0031, -0.0036,  0.0077,  0.0034], device='cuda:0')
tensor([ 0.0027,  0.0011, -0.0017,  0.0045,  0.0011], device='cuda:0')
tensor([[ 5.5468e-02, -9.4363e-02,  2.0761e-02,  1.2455e-02, -2.3252e-03,
         -3.1441e-02,  4.2351e-02,  1.0263e-01, -1.2985e-02,  1.0121e-01,
         -1.9956e-02, -3.1236e-02,  5.7674e-02, -4.7554e-02,  1.0384e-02,
          3.2684e-02,  7.9394e-03, -3.4905e-02, -3.9464e-03,  3.2418e-03,
          2.8410e-02,  3.9818e-02, -7.5590e-02,  2.5578e-02, -4.7145e-02,
         -3.1922e-02,  2.6990e-02, -1.1196e-02,  1.6015e-02,  3.6062e-02,
         -4.0029e-02,  4.3809e-02, -2.1391e-02,  6.9067e-03, -1.9860e-02,
          2.5916e-02, -8.0064e-03,  2.1094e-02,  2.3420e-02,  2.9930e-02,
          6.2295e-02,  3.1938e-02,  7.1807e-02, -4.3946e-02, -2.5078e-02,
          1.9863e-02, -1.7991e-02,  1.4665e-02,  5.7939e-02, -2.3066e-02,
         -3.286



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0068,  0.0015, -0.0015,  0.0085,  0.0032], device='cuda:0')
tensor([2.5257e-03, 2.8327e-06, 1.5347e-03, 3.9884e-03, 1.6196e-04],
       device='cuda:0')
tensor([[ 8.3761e-02, -4.6051e-02,  1.4827e-02,  9.8625e-03,  1.3083e-03,
         -2.3180e-02,  2.7057e-02,  7.9257e-02, -1.9881e-02,  1.3177e-01,
         -2.6702e-02, -2.9248e-02,  2.4058e-02, -8.5809e-02,  1.4913e-02,
          5.3740e-02,  1.6113e-02, -4.0587e-02, -4.1449e-03, -3.4229e-02,
          1.8330e-02,  2.0388e-02, -3.4912e-02,  1.2379e-03, -4.4434e-02,
         -1.3920e-02,  1.5656e-02,  3.5619e-03, -5.7278e-03,  6.2465e-02,
         -3.1376e-02,  3.4696e-02, -1.4275e-02,  1.6841e-02,  2.8647e-02,
         -3.1397e-03,  2.1962e-03, -3.1803e-03,  7.6017e-02,  4.3435e-02,
          2.9596e-02,  1.1791e-02,  6.1522e-02, -3.9335e-02, -2.2743e-02,
          3.0614e-02, -2.3079e-02, -9.5861e-03,  4.7261e-02, -7.9059e-03,
         -1.6307e-02, -1.51



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0054, -0.0029,  0.0014,  0.0096,  0.0063], device='cuda:0')
tensor([ 0.0013, -0.0031,  0.0023,  0.0053,  0.0036], device='cuda:0')
tensor([[ 1.2166e-01, -8.7681e-02,  1.7783e-02, -1.0295e-02,  2.0202e-02,
         -2.6164e-02,  4.5175e-02,  9.9396e-02, -2.4560e-02,  1.4114e-01,
         -3.9262e-02, -5.5893e-02,  6.7088e-02, -9.4173e-02,  1.8734e-02,
          6.8408e-02,  2.2876e-02, -3.8058e-02,  5.6126e-03, -3.3862e-03,
          8.9885e-03,  4.3278e-02, -2.4348e-02, -1.4749e-02, -5.6336e-02,
         -4.3064e-02,  1.1677e-02,  2.1232e-03,  1.8366e-03,  8.8006e-02,
         -5.1952e-02,  3.9842e-02, -1.7281e-02,  3.2936e-02,  2.4676e-02,
         -2.2828e-03,  1.6894e-03, -8.1961e-03,  4.8894e-02,  5.3180e-02,
          4.5909e-02,  3.6673e-02,  6.7505e-02, -6.7628e-02, -1.3492e-02,
          3.9804e-02, -2.2859e-02, -2.4629e-03,  6.9781e-02, -3.1223e-02,
         -2.0117e-02,  1.5794e-03, -1.0299e-02,  



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0109, -0.0029, -0.0079,  0.0203,  0.0042], device='cuda:0')
tensor([ 0.0034, -0.0031, -0.0047,  0.0071,  0.0017], device='cuda:0')
tensor([[ 1.6040e-01, -1.6756e-01,  2.6600e-02,  6.0567e-04,  4.7743e-02,
          8.1545e-03,  8.9442e-02,  1.4220e-01, -3.4713e-02,  1.8469e-01,
         -6.2389e-02, -4.5378e-02,  1.0749e-01, -6.1652e-02,  5.7457e-03,
          1.0871e-01,  9.4242e-02, -8.4793e-03,  2.5134e-02,  2.4933e-02,
          3.9090e-02,  8.2022e-02, -3.5258e-02, -2.3185e-02, -1.1939e-01,
         -8.7724e-02, -5.9302e-02,  2.1345e-03, -8.8837e-03,  9.2647e-02,
         -9.9203e-02,  4.9414e-02, -5.2565e-02,  4.1056e-02, -1.6944e-02,
         -4.8905e-03,  1.1496e-03, -1.5962e-02, -5.2762e-03,  6.3193e-02,
          1.3515e-01,  7.0560e-02,  5.4512e-02, -7.3194e-02, -3.4899e-02,
          2.5008e-02, -1.6535e-02,  2.1460e-02,  1.0654e-01, -1.7646e-02,
         -1.0614e-02,  1.8568e-02,  2.4933e-02, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0081, -0.0051, -0.0059,  0.0203,  0.0060], device='cuda:0')
tensor([-0.0007, -0.0036, -0.0043,  0.0074,  0.0029], device='cuda:0')
tensor([[ 0.1047, -0.1541, -0.0028, -0.0128,  0.0543, -0.0023,  0.1080,  0.1213,
         -0.0114,  0.2055, -0.0546, -0.0531,  0.1045, -0.0701, -0.0069,  0.0649,
          0.1038, -0.0021,  0.0622,  0.0157,  0.0111,  0.0730, -0.0543, -0.0053,
         -0.0700, -0.0958, -0.0443,  0.0257, -0.0500,  0.1138, -0.0566,  0.0444,
         -0.0743,  0.0564, -0.0337, -0.0398, -0.0026,  0.0126, -0.0055,  0.0929,
          0.1024,  0.0248,  0.0591, -0.0965, -0.0347,  0.0493,  0.0082,  0.0131,
          0.1378, -0.0099, -0.0352,  0.0282,  0.0021, -0.0191,  0.0411,  0.2229,
         -0.0879, -0.0680,  0.0292,  0.1161, -0.0775, -0.0012, -0.0579,  0.0314],
        [ 0.0434, -0.0142, -0.0098,  0.0024,  0.0083,  0.0048,  0.0016,  0.0328,
          0.0116,  0.0609, -0.0077,  0.0339,  0.0569,  0.04



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 8.9618e-03, -4.3328e-03,  6.8463e-05,  1.9915e-02,  6.1963e-03],
       device='cuda:0')
tensor([-0.0035, -0.0024, -0.0027,  0.0074,  0.0041], device='cuda:0')
tensor([[ 0.0979, -0.1582, -0.0182, -0.0245,  0.0688, -0.0043,  0.0986,  0.1513,
         -0.0231,  0.2065, -0.0247, -0.0413,  0.1130, -0.0721,  0.0287,  0.0441,
          0.1255, -0.0252,  0.0804, -0.0241,  0.0116,  0.0634, -0.0928, -0.0407,
         -0.0806, -0.1105, -0.0221,  0.0888, -0.0643,  0.1328, -0.0819,  0.0326,
         -0.0922,  0.0265, -0.0786, -0.0348,  0.0004,  0.0090,  0.0063,  0.1044,
          0.1193, -0.0096,  0.0964, -0.0944, -0.0265,  0.0375, -0.0026,  0.0176,
          0.1618,  0.0058, -0.0360,  0.0336,  0.0052, -0.0259,  0.0419,  0.2624,
         -0.1028, -0.0731,  0.0375,  0.1553, -0.0879,  0.0029, -0.0616,  0.0548],
        [ 0.0514, -0.0098, -0.0078, -0.0074,  0.0346,  0.0311, -0.0449,  0.0079,
          0.0262,  0.0743, -0.01



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0150, -0.0002, -0.0015,  0.0162,  0.0071], device='cuda:0')
tensor([-0.0012,  0.0013, -0.0037,  0.0051,  0.0062], device='cuda:0')
tensor([[ 9.4497e-02, -2.1390e-01, -5.0911e-03, -3.7141e-02,  7.4698e-02,
         -2.9455e-03,  1.1606e-01,  1.5696e-01, -2.7466e-02,  2.2347e-01,
         -1.3238e-02, -5.5516e-02,  1.4432e-01, -5.9982e-02,  7.2412e-03,
          2.6184e-02,  1.0651e-01, -3.1438e-03,  8.4949e-02, -1.5119e-02,
         -1.4772e-02,  8.9982e-02, -1.0793e-01, -5.5243e-02, -9.0546e-02,
         -1.2097e-01, -6.2484e-02,  9.2955e-02, -5.1238e-02,  1.5384e-01,
         -8.2858e-02,  3.4762e-02, -9.3039e-02,  2.2644e-02, -1.0671e-01,
         -5.6090e-02, -1.1813e-03,  9.8048e-03, -3.4410e-02,  1.0879e-01,
          1.4972e-01,  2.7604e-02,  8.6870e-02, -1.1473e-01, -1.5919e-02,
          5.7106e-02, -2.6249e-02,  2.8878e-02,  1.7635e-01, -6.5056e-03,
         -4.5496e-02,  5.5420e-02,  1.0613e-02, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0091, -0.0032, -0.0089,  0.0177,  0.0074], device='cuda:0')
tensor([-0.0034, -0.0006, -0.0057,  0.0041,  0.0067], device='cuda:0')
tensor([[ 0.0830, -0.2152,  0.0051, -0.0409,  0.0683,  0.0010,  0.1187,  0.1418,
         -0.0337,  0.2027, -0.0193, -0.0455,  0.1458, -0.0317,  0.0030, -0.0074,
          0.0946, -0.0152,  0.0328, -0.0057, -0.0257,  0.0830, -0.0931, -0.0425,
         -0.0520, -0.0994, -0.0755,  0.0346, -0.0399,  0.1185, -0.0718,  0.0387,
         -0.0831,  0.0482, -0.0597, -0.0629,  0.0143, -0.0241, -0.0592,  0.0962,
          0.1524,  0.0123,  0.0833, -0.0971, -0.0120,  0.0891, -0.0437,  0.0498,
          0.1298, -0.0153, -0.0339,  0.0561,  0.0086, -0.0306,  0.0548,  0.2401,
         -0.1131, -0.0482,  0.0767,  0.1153, -0.1106, -0.0312, -0.0183,  0.0320],
        [ 0.0618, -0.0395,  0.0210, -0.0021,  0.0479,  0.0303, -0.0349,  0.0804,
          0.0404,  0.1051, -0.0519,  0.0645,  0.0885,  0.00



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0073, -0.0046, -0.0010,  0.0162,  0.0081], device='cuda:0')
tensor([-0.0067, -0.0031,  0.0019,  0.0040,  0.0039], device='cuda:0')
tensor([[ 8.1540e-02, -2.2763e-01,  7.8080e-03, -5.6212e-02,  7.4319e-02,
         -2.5233e-02,  1.2088e-01,  1.6822e-01, -4.4046e-02,  2.1061e-01,
         -2.0570e-02, -6.3105e-02,  1.7208e-01, -4.3819e-02, -7.7426e-03,
          2.1624e-03,  8.2839e-02, -1.6002e-02,  2.6530e-02,  1.4869e-02,
         -3.2026e-02,  9.6132e-02, -9.2777e-02, -4.8593e-02, -7.0037e-02,
         -1.1328e-01, -7.4563e-02,  4.4983e-02, -5.2193e-02,  1.3270e-01,
         -9.8935e-02,  3.0355e-02, -1.0370e-01,  6.4111e-02, -8.9554e-02,
         -6.3580e-02, -5.1599e-03, -8.9901e-03, -8.6114e-02,  1.1069e-01,
          1.7332e-01,  2.1735e-02,  7.6481e-02, -1.0277e-01, -2.1508e-02,
          7.8514e-02, -1.9653e-02,  5.7999e-02,  1.5004e-01, -3.9414e-02,
         -4.8976e-02,  4.6767e-02,  3.2683e-02, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0120, -0.0086, -0.0014,  0.0192,  0.0087], device='cuda:0')
tensor([-0.0052, -0.0063,  0.0008,  0.0036,  0.0032], device='cuda:0')
tensor([[ 1.1072e-01, -2.5647e-01,  1.9198e-03, -8.8600e-02,  1.1866e-01,
         -6.3306e-02,  1.2319e-01,  1.8693e-01, -8.7680e-02,  3.0297e-01,
         -3.0308e-02, -7.3993e-02,  1.9763e-01, -7.3306e-02, -1.4100e-02,
          5.9482e-02,  7.3917e-02, -7.5044e-05,  1.0949e-02,  4.0689e-02,
         -2.7650e-02,  1.4787e-01, -9.9767e-02, -3.4647e-02, -8.8366e-02,
         -1.2396e-01, -7.7904e-02,  1.1230e-01, -6.8108e-02,  1.6591e-01,
         -1.3558e-01,  5.5055e-02, -1.1635e-01,  1.1287e-01, -1.0510e-01,
         -8.7530e-02, -9.8359e-04,  4.9856e-02, -1.0473e-01,  1.2660e-01,
          1.6983e-01,  4.0976e-02,  2.2999e-02, -1.2642e-01, -3.9079e-02,
          8.1054e-02, -3.2239e-02,  4.4433e-02,  2.1089e-01, -2.8333e-02,
         -7.9679e-02,  6.4154e-02,  3.2345e-02, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0116, -0.0024,  0.0044,  0.0199,  0.0119], device='cuda:0')
tensor([-0.0068, -0.0016,  0.0050,  0.0060,  0.0044], device='cuda:0')
tensor([[ 7.7656e-02, -2.1458e-01, -2.1980e-02, -8.2345e-02,  7.5665e-02,
         -6.4393e-02,  9.8757e-02,  1.4608e-01, -1.0211e-01,  3.1785e-01,
         -3.7797e-05, -6.7961e-02,  1.4301e-01, -5.1372e-02, -3.4001e-02,
          4.2610e-02,  7.1619e-02, -1.9743e-02, -2.3439e-02,  2.6576e-02,
         -3.7477e-02,  2.6880e-02, -1.0410e-01, -2.2247e-02, -6.3200e-02,
         -1.4678e-01, -4.0498e-02,  4.9699e-02, -1.1507e-01,  1.4211e-01,
         -8.4886e-02,  5.6113e-02, -1.2840e-01,  1.5797e-01, -1.7624e-02,
         -1.0064e-01,  3.9909e-02,  8.4183e-02, -1.7298e-01,  1.4043e-01,
          1.0298e-01, -6.0688e-03,  3.2546e-02, -1.3505e-01, -6.3254e-02,
          9.4677e-02,  1.6489e-02,  8.0903e-02,  1.6916e-01, -8.0675e-03,
         -7.2620e-02,  5.9897e-02, -9.8634e-03, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0113, -0.0049,  0.0083,  0.0199,  0.0151], device='cuda:0')
tensor([-0.0053, -0.0012,  0.0039,  0.0048,  0.0068], device='cuda:0')
tensor([[ 6.5065e-02, -2.0473e-01, -2.4434e-02, -6.3693e-02,  5.3906e-02,
         -4.8720e-02,  1.0877e-01,  1.4182e-01, -9.1175e-02,  3.0068e-01,
         -1.7524e-03, -8.1147e-02,  1.4566e-01, -4.1404e-02, -1.6137e-02,
          1.1854e-02,  5.6369e-02,  1.3607e-03, -4.8663e-02,  2.3058e-02,
         -4.2063e-02,  3.8458e-02, -9.6697e-02, -3.5107e-02, -5.5817e-02,
         -1.3295e-01, -4.8062e-02,  7.6546e-02, -1.0590e-01,  1.4929e-01,
         -5.3035e-02,  8.3809e-02, -1.3923e-01,  1.6415e-01, -1.6041e-02,
         -8.7017e-02,  1.5449e-02,  6.7232e-02, -1.6396e-01,  1.1487e-01,
          9.0036e-02,  5.2487e-04,  3.7353e-02, -1.4200e-01, -6.7570e-02,
          1.0849e-01,  1.2635e-02,  8.3463e-02,  1.5458e-01, -2.0317e-02,
         -4.7999e-02,  4.4588e-02, -3.9763e-02, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0154, -0.0034,  0.0111,  0.0179,  0.0151], device='cuda:0')
tensor([-0.0042, -0.0015,  0.0064,  0.0043,  0.0067], device='cuda:0')
tensor([[ 4.6709e-02, -1.9921e-01, -6.3167e-03, -8.4407e-02,  6.1494e-02,
         -7.3931e-02,  7.0889e-02,  1.1653e-01, -8.6367e-02,  3.2194e-01,
         -5.0242e-03, -8.4456e-02,  1.7757e-01, -7.4461e-02, -2.5867e-02,
          3.7437e-03,  4.6272e-02, -3.4903e-03, -6.1834e-02, -1.0025e-02,
         -4.3327e-02,  2.9952e-02, -9.8677e-02, -2.6577e-02, -4.1470e-02,
         -1.5080e-01,  5.0083e-03,  1.1373e-01, -1.0873e-01,  1.4129e-01,
         -5.0120e-02,  9.7542e-02, -1.2426e-01,  1.2624e-01, -2.3426e-02,
         -6.2618e-02,  8.8176e-03,  4.2536e-02, -1.4671e-01,  9.3056e-02,
          8.8100e-02, -1.3807e-02,  8.2539e-02, -1.3382e-01, -5.3779e-02,
          1.0763e-01,  2.2410e-02,  1.1404e-01,  1.4873e-01, -1.1359e-02,
         -7.2787e-02,  6.1115e-02, -7.8369e-02, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([0.0184, 0.0038, 0.0127, 0.0202, 0.0153], device='cuda:0')
tensor([-0.0022,  0.0039,  0.0071,  0.0067,  0.0059], device='cuda:0')
tensor([[ 0.0787, -0.2067,  0.0088, -0.0574,  0.0653, -0.0777,  0.0525,  0.1148,
         -0.1008,  0.3163, -0.0722, -0.1055,  0.1598, -0.0370, -0.0506,  0.0262,
          0.0130, -0.0197, -0.0769, -0.0104, -0.0655,  0.0436, -0.0862,  0.0023,
         -0.0210, -0.1298, -0.0266,  0.0993, -0.1155,  0.1778, -0.0401,  0.0588,
         -0.0901,  0.1618, -0.0404, -0.0721,  0.0316,  0.0631, -0.2104,  0.1265,
          0.1361,  0.0279,  0.1161, -0.1263, -0.0845,  0.0959,  0.0439,  0.1357,
          0.2169, -0.0008, -0.0826,  0.0289, -0.0582, -0.0542,  0.0814,  0.2674,
         -0.1005, -0.0601,  0.0088,  0.2002, -0.0718,  0.0228, -0.0698,  0.1029],
        [ 0.1400, -0.1173,  0.0297,  0.0244,  0.0722,  0.0084,  0.0205,  0.0983,
          0.0508,  0.1532, -0.0688,  0.0300,  0.1792,  0.0219, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0152, -0.0005,  0.0089,  0.0263,  0.0198], device='cuda:0')
tensor([-0.0048,  0.0012,  0.0044,  0.0116,  0.0092], device='cuda:0')
tensor([[ 4.7530e-02, -2.0960e-01, -2.5311e-03, -9.3858e-02,  1.3138e-02,
         -5.4192e-02,  4.1777e-02,  9.2930e-02, -9.8512e-02,  3.1003e-01,
         -1.5458e-02, -7.3037e-02,  1.3636e-01, -3.6092e-02, -3.4663e-03,
         -4.3241e-03,  9.9356e-03,  2.5077e-04, -9.0901e-02,  3.5655e-02,
         -4.5479e-02,  7.3016e-03, -8.7195e-02, -5.3430e-03,  2.2435e-03,
         -9.8547e-02,  8.7503e-03,  5.1286e-02, -1.4496e-01,  1.5710e-01,
         -3.7657e-02,  7.0846e-02, -8.8461e-02,  1.4105e-01,  4.8817e-02,
         -8.1647e-02,  2.8848e-02,  4.5802e-02, -2.1385e-01,  1.3948e-01,
          7.0664e-02,  7.1284e-03,  1.0541e-01, -6.2086e-02, -8.0894e-02,
          1.1861e-01,  4.3044e-02,  1.3591e-01,  1.8365e-01, -4.4926e-02,
         -7.4243e-02,  3.8212e-02, -8.1717e-02,  



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0110, -0.0033,  0.0072,  0.0235,  0.0196], device='cuda:0')
tensor([-0.0054, -0.0006,  0.0035,  0.0092,  0.0104], device='cuda:0')
tensor([[ 2.5982e-02, -2.0010e-01,  1.2851e-02, -1.0369e-01,  2.8945e-02,
         -3.6527e-02,  2.2964e-02,  7.5163e-02, -8.3283e-02,  3.1479e-01,
         -1.1501e-02, -8.1272e-02,  1.1158e-01, -3.9829e-02, -1.9430e-03,
         -4.9747e-02,  3.4112e-02, -3.1632e-03, -8.6602e-02,  3.6839e-02,
         -5.2742e-02, -1.0041e-02, -7.0723e-02, -2.5560e-02,  1.6873e-02,
         -1.0584e-01,  3.3919e-03,  6.3460e-02, -1.4645e-01,  1.5726e-01,
         -2.4013e-02,  7.0527e-02, -1.0532e-01,  1.5494e-01,  5.8318e-02,
         -8.4333e-02,  3.9713e-02,  1.7491e-02, -2.0605e-01,  1.3266e-01,
          6.5211e-02, -1.1735e-02,  8.8139e-02, -4.5192e-02, -5.1110e-02,
          1.3095e-01,  4.1602e-02,  1.4891e-01,  1.8024e-01, -2.1979e-02,
         -5.8059e-02,  5.7677e-02, -8.8355e-02,  



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0162, -0.0008,  0.0100,  0.0214,  0.0180], device='cuda:0')
tensor([ 0.0004, -0.0007,  0.0069,  0.0068,  0.0030], device='cuda:0')
tensor([[ 0.0036, -0.1832,  0.0188, -0.0702,  0.0308, -0.0290, -0.0125,  0.0616,
         -0.0810,  0.3358, -0.0233, -0.0733,  0.0902, -0.0182,  0.0279, -0.0117,
          0.0295, -0.0062, -0.0876,  0.0598, -0.0424, -0.0155, -0.0552,  0.0185,
          0.0267, -0.0885,  0.0118,  0.0527, -0.1492,  0.1521, -0.0094,  0.0698,
         -0.0756,  0.1367,  0.0935, -0.1267,  0.0436,  0.0171, -0.1726,  0.1340,
          0.0728, -0.0107,  0.1099, -0.0598, -0.0469,  0.1575,  0.0033,  0.1280,
          0.2036, -0.0336, -0.0427,  0.0442, -0.0874,  0.0751,  0.0983,  0.1622,
         -0.0361, -0.0661, -0.0080,  0.1327, -0.0071,  0.0095, -0.0533,  0.0594],
        [ 0.2071, -0.2476,  0.0258,  0.0624,  0.0595, -0.0227,  0.0664,  0.1687,
          0.0531,  0.2303, -0.0680, -0.0255,  0.2298,  0.04



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0179, -0.0047,  0.0106,  0.0183,  0.0162], device='cuda:0')
tensor([ 0.0008, -0.0024,  0.0068,  0.0048,  0.0011], device='cuda:0')
tensor([[ 4.2929e-02, -2.0941e-01,  3.0397e-02, -8.5067e-02,  2.7923e-02,
         -4.9535e-02,  9.6801e-03,  6.2159e-02, -7.6965e-02,  3.2893e-01,
         -2.5842e-02, -8.6773e-02,  1.1372e-01,  2.9591e-03,  2.9578e-02,
          3.6404e-03,  1.1761e-02, -6.3677e-03, -1.0210e-01,  8.4855e-02,
         -2.8757e-02,  4.9297e-03, -4.9342e-02, -6.5573e-03,  2.7945e-02,
         -9.7659e-02, -7.7333e-03,  5.1273e-02, -1.1746e-01,  1.4965e-01,
         -2.7312e-02,  7.0636e-02, -8.2003e-02,  1.1791e-01,  6.7434e-02,
         -1.2215e-01,  1.1877e-02,  1.9846e-02, -1.7377e-01,  1.4052e-01,
          7.3124e-02, -6.6394e-03,  9.0738e-02, -5.8147e-02, -4.0444e-02,
          1.5734e-01,  7.8563e-04,  1.0641e-01,  2.1219e-01, -4.2531e-02,
         -5.8176e-02,  3.2176e-02, -8.6516e-02,  



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 1.4456e-02, -6.2546e-05,  1.0024e-02,  1.4703e-02,  2.0100e-02],
       device='cuda:0')
tensor([-5.2120e-04, -7.7308e-05,  7.2751e-03,  2.4837e-03,  2.3907e-03],
       device='cuda:0')
tensor([[ 3.9300e-02, -1.8992e-01,  2.2570e-02, -8.6859e-02,  1.0765e-02,
         -4.4088e-02, -3.1707e-03,  5.1638e-02, -8.3409e-02,  3.2018e-01,
         -3.5675e-02, -8.8628e-02,  1.0728e-01,  4.3659e-03,  3.2512e-02,
          6.9635e-03,  1.8641e-02, -7.3135e-03, -9.7695e-02,  8.0382e-02,
         -2.3337e-02,  3.4252e-04, -5.4932e-02,  2.8996e-03,  1.9741e-02,
         -1.0816e-01,  2.3091e-03,  3.7937e-02, -1.2178e-01,  1.4639e-01,
         -1.8348e-02,  6.3753e-02, -7.8227e-02,  1.2124e-01,  6.8898e-02,
         -1.2541e-01,  1.6117e-02,  2.2787e-02, -1.6428e-01,  1.5178e-01,
          6.7195e-02, -8.9327e-03,  1.0422e-01, -5.3328e-02, -5.3375e-02,
          1.5048e-01,  7.4951e-03,  1.0755e-01,  2.1319e-01, -3.3312e



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0134, -0.0001,  0.0193,  0.0138,  0.0194], device='cuda:0')
tensor([-0.0013, -0.0020,  0.0146, -0.0010,  0.0013], device='cuda:0')
tensor([[ 3.7822e-02, -1.8748e-01,  2.0795e-02, -9.3326e-02,  3.6308e-03,
         -5.5119e-02, -1.5390e-03,  4.1615e-02, -8.5175e-02,  3.1810e-01,
         -2.8391e-02, -8.7234e-02,  1.0126e-01, -5.1386e-05,  2.9018e-02,
          5.4657e-03,  1.2138e-02, -2.0150e-02, -1.0062e-01,  7.8067e-02,
         -2.1403e-02, -5.2753e-03, -4.9625e-02, -1.1552e-03,  2.1632e-02,
         -9.4724e-02,  6.1359e-03,  2.7205e-02, -1.2786e-01,  1.4222e-01,
         -2.0909e-02,  5.7628e-02, -6.9430e-02,  1.2297e-01,  6.9712e-02,
         -1.1808e-01,  1.6025e-02,  2.5707e-02, -1.6201e-01,  1.4726e-01,
          5.8419e-02, -1.4104e-02,  9.3792e-02, -4.8016e-02, -5.4985e-02,
          1.4633e-01,  7.5539e-03,  1.0380e-01,  2.1602e-01, -3.7990e-02,
         -5.0973e-02,  3.0461e-02, -8.9826e-02,  



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0125, -0.0007,  0.0170,  0.0155,  0.0189], device='cuda:0')
tensor([-0.0009, -0.0030,  0.0113,  0.0018,  0.0019], device='cuda:0')
tensor([[ 7.7065e-03, -1.2488e-01,  4.1363e-03, -5.8529e-02, -4.1863e-03,
         -4.8254e-02, -1.8321e-02,  2.8275e-02, -1.0302e-01,  3.0798e-01,
         -3.1157e-02, -8.9483e-02,  6.2174e-02,  2.7232e-02,  4.4599e-02,
          8.2485e-03, -1.7378e-02, -2.6305e-02, -9.6357e-02,  5.4157e-02,
         -9.5750e-03, -1.2483e-02, -5.3364e-02,  7.4611e-03,  3.2152e-02,
         -8.2468e-02,  3.9385e-02, -7.1870e-04, -1.4334e-01,  1.3712e-01,
         -4.2884e-03,  4.4619e-02, -5.6105e-02,  1.3262e-01,  9.6049e-02,
         -8.4438e-02,  4.5082e-02,  1.2463e-02, -1.6092e-01,  1.5240e-01,
          4.6965e-02, -5.3158e-02,  1.0704e-01, -4.8455e-02, -6.3725e-02,
          1.6294e-01,  7.2359e-03,  8.8159e-02,  2.0533e-01, -2.0959e-02,
         -3.6573e-02,  2.8072e-02, -1.4141e-01,  



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0132, -0.0063,  0.0144,  0.0151,  0.0191], device='cuda:0')
tensor([-2.3709e-04, -5.3211e-03,  1.0380e-02,  4.4567e-05,  1.8205e-03],
       device='cuda:0')
tensor([[ 0.0098, -0.1134, -0.0168, -0.0641,  0.0110, -0.0395, -0.0386,  0.0595,
         -0.1078,  0.2986, -0.0254, -0.0918,  0.0608,  0.0426,  0.0467,  0.0023,
         -0.0201, -0.0124, -0.0864,  0.0540, -0.0124, -0.0311, -0.0549,  0.0144,
          0.0601, -0.1002,  0.0381, -0.0335, -0.1639,  0.1372, -0.0051,  0.0386,
         -0.0376,  0.1472,  0.1286, -0.0799,  0.0436,  0.0167, -0.1564,  0.1497,
          0.0421, -0.0537,  0.1263, -0.0463, -0.0523,  0.1808, -0.0187,  0.1091,
          0.2128, -0.0160, -0.0134,  0.0295, -0.1418,  0.1300,  0.1012,  0.1283,
         -0.0114, -0.0426, -0.0086,  0.1395,  0.0114,  0.0394,  0.0123,  0.0695],
        [ 0.2795, -0.3260,  0.0866,  0.0368,  0.1032, -0.0882,  0.1515,  0.2629,
          0.0781,  0.2783, -0.02



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0145, -0.0030,  0.0179,  0.0167,  0.0236], device='cuda:0')
tensor([ 5.0048e-05, -2.8084e-03,  1.2214e-02, -8.9882e-04,  3.6116e-03],
       device='cuda:0')
tensor([[ 4.1135e-03, -1.1739e-01, -4.4163e-03, -6.8145e-02,  2.8585e-02,
         -4.7370e-02, -2.9098e-02,  5.2040e-02, -8.9249e-02,  2.6404e-01,
         -3.9649e-02, -8.9950e-02,  4.4380e-02,  6.7669e-02,  3.5842e-02,
         -6.5484e-03, -2.4639e-02, -3.3711e-02, -8.8552e-02,  2.4109e-02,
          1.8489e-02, -5.0882e-02, -4.4350e-02,  2.7421e-02,  6.7964e-02,
         -1.0421e-01,  2.5630e-02, -1.1628e-02, -1.8488e-01,  1.4058e-01,
          2.4467e-02,  5.4561e-02, -2.7152e-02,  1.3192e-01,  1.3904e-01,
         -7.0627e-02,  6.1287e-02,  1.5111e-02, -1.6930e-01,  1.5260e-01,
          3.8893e-02, -9.4815e-02,  1.4889e-01, -5.2731e-02, -3.1558e-02,
          1.6868e-01, -2.6469e-02,  9.4312e-02,  2.1897e-01, -1.3268e-02,
         -8.2488e-03, 



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0162, -0.0027,  0.0199,  0.0166,  0.0227], device='cuda:0')
tensor([ 0.0016, -0.0053,  0.0129, -0.0020,  0.0026], device='cuda:0')
tensor([[-7.3574e-02, -6.9709e-02, -1.5550e-02, -5.2535e-02,  2.2172e-02,
         -5.4343e-02, -3.9444e-02,  1.2237e-02, -5.1723e-02,  2.4722e-01,
         -1.0971e-02, -5.0870e-02,  2.5387e-02,  6.2144e-02,  6.0152e-02,
         -4.7985e-02, -2.1290e-02, -1.7752e-02, -8.2061e-02, -3.4938e-02,
          1.7248e-02, -8.7884e-02, -3.5336e-02,  4.9027e-02,  7.7629e-02,
         -1.0421e-01,  7.4740e-02, -1.4794e-02, -2.0945e-01,  1.3049e-01,
          4.5965e-02,  8.0156e-02, -2.4811e-02,  1.2265e-01,  1.8728e-01,
         -7.4312e-02,  1.0310e-01, -2.4490e-02, -1.1916e-01,  1.5011e-01,
         -8.7097e-03, -1.1739e-01,  1.6075e-01, -3.3740e-02, -3.1203e-02,
          1.7551e-01, -1.4700e-02,  9.9085e-02,  1.9563e-01,  1.1762e-02,
         -1.4381e-02, -8.3303e-03, -1.6342e-01,  



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0218, -0.0008,  0.0166,  0.0159,  0.0219], device='cuda:0')
tensor([ 0.0040, -0.0045,  0.0119, -0.0015,  0.0022], device='cuda:0')
tensor([[-5.8332e-02, -8.6527e-02, -1.4906e-02, -5.2975e-02,  1.4865e-02,
         -4.0618e-02, -2.9429e-02,  2.4728e-02, -6.4644e-02,  2.6061e-01,
         -1.8236e-02, -5.5816e-02,  3.5856e-02,  7.5174e-02,  7.5103e-02,
         -2.4138e-02, -5.0689e-02, -3.6517e-02, -7.9367e-02, -1.5586e-02,
          2.4122e-02, -6.3424e-02, -3.6789e-02,  5.1267e-02,  8.7999e-02,
         -1.0746e-01,  5.8267e-02, -4.0123e-02, -1.9079e-01,  1.2695e-01,
          3.9808e-02,  7.2399e-02, -3.7834e-02,  1.2214e-01,  1.8021e-01,
         -7.1049e-02,  8.6674e-02, -1.5164e-02, -1.2547e-01,  1.6390e-01,
         -1.1410e-02, -1.1205e-01,  1.5868e-01, -2.8440e-02, -1.3250e-02,
          1.7099e-01, -4.3664e-03,  9.1876e-02,  2.2791e-01,  1.1330e-02,
         -7.0301e-03, -8.4258e-04, -1.7338e-01,  



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0179, -0.0020,  0.0125,  0.0167,  0.0213], device='cuda:0')
tensor([ 0.0016, -0.0058,  0.0103, -0.0028,  0.0031], device='cuda:0')
tensor([[-5.4345e-02, -7.5529e-02, -1.3151e-02, -3.2580e-02,  3.8550e-02,
         -5.5505e-02, -1.3369e-02,  6.0408e-02, -7.6228e-02,  2.6092e-01,
         -1.1975e-02, -4.2760e-02,  3.7846e-02,  9.9664e-02,  7.8054e-02,
         -2.4335e-02, -3.5270e-02, -4.2437e-02, -7.5721e-02, -3.0310e-03,
          4.2546e-02, -6.6696e-02, -2.7567e-02,  5.0842e-02,  8.2420e-02,
         -9.3836e-02,  4.1104e-02, -3.7339e-02, -1.7533e-01,  1.1560e-01,
          6.3538e-02,  6.3817e-02, -3.2333e-02,  1.2841e-01,  1.6374e-01,
         -7.2139e-02,  8.5687e-02, -2.8105e-02, -1.3194e-01,  1.8644e-01,
          9.2085e-03, -1.0445e-01,  1.5487e-01, -1.9999e-02, -1.3482e-02,
          1.6799e-01, -1.1177e-02,  6.9485e-02,  2.3935e-01,  2.5921e-02,
         -8.9690e-03,  2.5064e-03, -1.7867e-01,  



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0170, -0.0080,  0.0144,  0.0197,  0.0260], device='cuda:0')
tensor([-0.0002, -0.0061,  0.0112, -0.0012,  0.0012], device='cuda:0')
tensor([[-3.2703e-02, -1.1589e-01,  7.0727e-03, -6.0384e-02,  5.7741e-02,
         -1.0562e-01,  3.0321e-02,  1.1371e-01, -9.8279e-02,  2.2543e-01,
         -3.8295e-02, -4.6777e-02,  7.3406e-02,  1.4121e-01,  6.0817e-02,
         -1.6164e-02, -4.3917e-02, -5.0513e-02, -4.9496e-02,  3.0887e-02,
          5.8670e-02, -6.4328e-02, -3.9768e-02,  1.0424e-01,  7.3583e-02,
         -9.2422e-02,  3.4456e-02, -8.6259e-02, -1.7285e-01,  1.1946e-01,
          4.9901e-02,  6.0756e-02, -4.1488e-02,  1.5228e-01,  1.7918e-01,
         -1.1250e-01,  7.6660e-02, -1.7920e-02, -2.1756e-01,  2.0476e-01,
          5.2390e-02, -8.2633e-02,  1.6499e-01, -3.5564e-02, -1.1420e-02,
          1.6547e-01,  3.2298e-03,  1.0422e-01,  2.4338e-01, -6.1647e-03,
         -3.2112e-02,  8.8149e-03, -1.3699e-01,  



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0150, -0.0078,  0.0136,  0.0254,  0.0280], device='cuda:0')
tensor([-0.0017, -0.0040,  0.0078,  0.0021,  0.0018], device='cuda:0')
tensor([[-1.6297e-02, -1.0716e-01,  7.7296e-03, -4.9185e-02,  3.3758e-02,
         -1.0522e-01,  2.5004e-02,  7.9171e-02, -1.0601e-01,  2.0967e-01,
         -4.8900e-02, -4.4385e-02,  3.7154e-02,  1.2860e-01,  7.6792e-02,
         -3.4292e-02, -2.1544e-02, -6.6885e-02, -6.2590e-02,  5.9457e-02,
          6.2477e-02, -8.6168e-02, -2.2423e-02,  1.1749e-01,  1.0646e-01,
         -1.0611e-01,  1.3702e-02, -1.2477e-01, -1.5700e-01,  1.2577e-01,
          5.0961e-02,  6.2857e-02, -3.5628e-02,  1.4988e-01,  1.8696e-01,
         -8.1193e-02,  9.4886e-02, -9.3963e-04, -2.3526e-01,  2.2122e-01,
          3.4476e-02, -7.1352e-02,  1.8131e-01, -1.3235e-02,  4.3517e-03,
          1.6980e-01,  2.3975e-02,  1.0464e-01,  2.4944e-01, -1.4823e-02,
         -2.1395e-02,  3.1013e-03, -1.2357e-01,  



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0156, -0.0055,  0.0195,  0.0150,  0.0324], device='cuda:0')
tensor([-0.0019, -0.0018,  0.0108, -0.0037,  0.0028], device='cuda:0')
tensor([[-0.0329, -0.0947,  0.0187, -0.0457,  0.0385, -0.1179,  0.0495,  0.0735,
         -0.1133,  0.2195, -0.0606, -0.0326,  0.0128,  0.1549,  0.0991, -0.0500,
         -0.0123, -0.0439, -0.0463,  0.0917,  0.0441, -0.0902, -0.0341,  0.1068,
          0.1142, -0.0997,  0.0174, -0.0825, -0.1441,  0.1248,  0.0703,  0.0237,
         -0.0326,  0.1419,  0.1843, -0.0867,  0.0790, -0.0110, -0.2305,  0.2237,
          0.0137, -0.0753,  0.1868, -0.0218,  0.0147,  0.1758,  0.0102,  0.1254,
          0.2581, -0.0201, -0.0259,  0.0071, -0.1481,  0.1727,  0.1196, -0.0192,
         -0.0106, -0.0789, -0.0425,  0.1465,  0.0685,  0.0494, -0.0307, -0.0144],
        [ 0.3393, -0.4426,  0.1362, -0.0654,  0.2011, -0.0576,  0.2141,  0.3214,
          0.0581,  0.3541, -0.1388, -0.2065,  0.4467, -0.00



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([0.0190, 0.0005, 0.0205, 0.0174, 0.0351], device='cuda:0')
tensor([-0.0014,  0.0046,  0.0118, -0.0032,  0.0048], device='cuda:0')
tensor([[-2.6396e-02, -9.3389e-02,  2.6287e-02, -3.6637e-02,  1.9168e-02,
         -1.1645e-01,  6.5717e-02,  8.4997e-02, -9.3261e-02,  2.3286e-01,
         -6.6010e-02, -3.0396e-02,  2.2943e-02,  1.6082e-01,  8.4622e-02,
         -5.7420e-02,  9.4684e-03, -3.9081e-02, -4.2483e-02,  1.0075e-01,
          5.0339e-02, -8.7732e-02, -4.1400e-02,  1.2866e-01,  1.0577e-01,
         -9.6181e-02,  3.5976e-02, -8.6708e-02, -1.4005e-01,  1.0939e-01,
          8.9822e-02,  4.4256e-02, -5.7661e-02,  1.2876e-01,  1.8221e-01,
         -8.3255e-02,  8.1314e-02,  1.7070e-02, -2.2711e-01,  2.0250e-01,
          1.7302e-02, -4.7986e-02,  1.7503e-01, -2.6905e-02,  2.6551e-03,
          1.7912e-01,  1.3098e-02,  1.2715e-01,  2.6571e-01, -4.5033e-02,
         -2.2742e-02,  7.6756e-05, -1.2930e-01,  1.549



torch.Size([50])
torch.Size([25])




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([0.0203, 0.0015, 0.0249, 0.0234, 0.0335], device='cuda:0')
tensor([-0.0009,  0.0028,  0.0126,  0.0009,  0.0022], device='cuda:0')
tensor([[-1.6880e-02, -4.4897e-02, -3.7235e-03, -2.1186e-02, -9.6999e-03,
         -1.2442e-01,  4.6885e-02,  3.9150e-02, -1.1494e-01,  1.8956e-01,
         -3.8074e-02,  8.5363e-03, -2.4038e-02,  1.5510e-01,  1.1044e-01,
         -4.0051e-02,  3.7910e-02, -6.2818e-02, -6.9648e-02,  7.3275e-02,
          7.8479e-02, -9.3321e-02, -5.5751e-02,  1.0838e-01,  8.6443e-02,
         -9.5920e-02,  7.6508e-02, -1.0703e-01, -1.1870e-01,  1.0635e-01,
          1.0612e-01,  4.9455e-02, -6.5937e-02,  1.3174e-01,  2.0380e-01,
         -8.5340e-02,  8.5409e-02,  2.9228e-02, -1.6676e-01,  2.4600e-01,
         -4.3574e-02, -5.7039e-02,  1.4667e-01, -8.3993e-03, -7.6943e-03,
          1.7191e-01,  1.8631e-03,  1.1075e-01,  2.8436e-01,  3.4580e-03,
         -4.9118e-03, -1.6426e-02, -1.8880e-01,  1.688



torch.Size([50])
torch.Size([25])


Iteration 1
Meta Train Error 0.062271493254229426
Meta Train Accuracy 0.31499999132938683
Meta Valid Error 0.06434055068530142
Meta Valid Accuracy 0.2787499923724681
Meta Test Error 0.06290630542207509
Meta Test Accuracy 0.3049999901559204




torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([0., 0., 0., 0., 0.], device='cuda:0')
tensor([0., 0., 0., 0., 0.], device='cuda:0')
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0046,  0.0030, -0.0007,  0.0049, -0.0021], device='cuda:0')
tensor([ 0.0014,  0.0025, -0.0032,  0.0040,  0.0001], device='cuda:0')
tensor([[ 1.9282e-03, -1.9190e-02, -5.7396e-03,  1.6157e-03, -8.8029e-03,
          3.7480e-03, -1.8085e-03, -8.5671e-03, -1.4274e-02,  2.2595e-03,
          1.5340e-02,  2.1985e-03,  3.2121e-02,  7.2659e-03,  4.1815e-03,
          1.5694e-02,  6.2456e-03, -8.5056e-03, -5.6233e-03,  1.8687e-02,
         -3.1923e-03, -4.1959e-04, -1.0889e-02, -2.6260e-02, -1.1877e-02,
         -5.1600e-03,  1.3273e-02, -4.0324e-03,  8.7394e-03,  6.8014e-03,
          6.3673e-03, -2.5091e-02,  5.9708e-03,  1.8045e-02, -1.3913e-02,
         -2.6573e-03, -8.0057e-03,  5.1869e-03,  2.1598e-03,  1.0555e-02,
         -9.3080e-03,  9.6361e-03, -1.1027e-02, -2.6547e-02, -1.8047e-02,
          1.2588e-02,  1.5282e-02,  1.3613e-02,  1.5626e-02,  2.5592e-02,
          2.3923e-03, -8.8979e-03, -5.1752e-04, -



torch.Size([50, 3, 84, 84])
torch.Size([50])
torch.Size([25])
before
tensor([ 0.0027,  0.0005, -0.0007,  0.0080, -0.0002], device='cuda:0')
tensor([ 0.0023,  0.0020, -0.0028,  0.0056,  0.0041], device='cuda:0')
tensor([[ 0.0014, -0.0202,  0.0103,  0.0124,  0.0104,  0.0033, -0.0161, -0.0138,
         -0.0054, -0.0069,  0.0266,  0.0012,  0.0153, -0.0037, -0.0029, -0.0306,
         -0.0025,  0.0085, -0.0162,  0.0377, -0.0005, -0.0133, -0.0204, -0.0299,
         -0.0036, -0.0101,  0.0058, -0.0170,  0.0110, -0.0008,  0.0093, -0.0310,
          0.0081,  0.0184, -0.0072,  0.0125,  0.0040, -0.0101, -0.0165,  0.0020,
         -0.0020,  0.0162, -0.0110, -0.0279, -0.0031,  0.0316, -0.0010,  0.0464,
          0.0142,  0.0160, -0.0055, -0.0228, -0.0330,  0.0066,  0.0292, -0.0379,
          0.0091,  0.0117, -0.0065,  0.0215,  0.0016,  0.0004, -0.0101,  0.0048],
        [-0.0066,  0.0447, -0.0256, -0.0019, -0.0613,  0.0214,  0.0097, -0.0514,
          0.0128, -0.0173,  0.0266,  0.0391, -0.0542, -0.00



torch.Size([50])
torch.Size([25])


