In [168]:
import numpy as np
import scipy
import scipy.fftpack as fft

import sys
sys.path.append('../')
from common import *

import torch
import torch_dct as dct
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import StepLR

import torchvision
import torchvision.transforms as transforms
from torchvision import models

In [169]:
def unfold(A):
    return torch.tensor(bcirc(A)[:, :A.shape[1]])

def fold(A, shape):
    return torch.tensor(np.reshape(A, shape))

def bcirc(A):
    s = A.shape
    bcirc_A = np.zeros((s[0] * s[2], s[1] * s[2]))
    A_matriced = np.reshape(np.transpose(A, [0, 1, 2]), (s[0] * s[2], s[1]))
    for k in range(s[2]):
        bcirc_A[:, k * s[1] : (k + 1) * s[1]] = np.roll(A_matriced, k * s[0], axis=0)
    return torch.tensor(bcirc_A)

def t_product(A, B):
    dct_A = dct.dct(A)
    dct_B = dct.dct(B)
    dct_C = torch.zeros(A.shape[0], B.shape[1], A.shape[2])
    for k in range(A.shape[2]):
        dct_C[..., k] = dct_A[..., k] @ dct_B[..., k]
    return torch.tensor(dct.idct(dct_C))

def t_product_v2(A, B):
    shape = [A.shape[0], B.shape[1], A.shape[2]]
    return torch.tensor(fold(bcirc(A) @ unfold(B), shape))

def objective_function(outputs, C):
    return torch.norm(C - outputs) ** 2 / 2

In [276]:
A = torch.rand(2, 3, 4)
B = torch.rand(3, 5, 4)
C = torch.rand(2, 5)
D = torch.rand(5, 5, 4)

In [277]:
print(fft.dct(np.array(A)) - np.array(dct.dct(A)))

[[[ 0.0000000e+00  1.4901161e-08  0.0000000e+00 -1.4901161e-08]
  [ 0.0000000e+00 -5.9604645e-08  0.0000000e+00  1.1920929e-07]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 -1.1920929e-07]]

 [[ 0.0000000e+00  1.4901161e-08  0.0000000e+00 -1.4901161e-08]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  3.7252903e-09]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]]]


In [278]:
E = t_product(A, B)
t_product(E, D)



tensor([[[109.1029, 108.3608, 108.0490, 109.1149],
         [102.9536, 102.0753, 102.8634, 103.5111],
         [118.8796, 120.4292, 120.6319, 119.1123],
         [ 91.4567,  90.2142,  91.0926,  91.4754],
         [111.0466, 110.0189, 109.2627, 110.6241]],

        [[113.0708, 111.2088, 111.2240, 112.9554],
         [106.4739, 105.4831, 105.1709, 106.9399],
         [122.5860, 123.6918, 123.2501, 122.8943],
         [ 94.0933,  92.8866,  92.6624,  94.3410],
         [114.3657, 112.2729, 112.4617, 114.0144]]])

In [259]:
class Transform_Layer(nn.Module):
    def __init__(self, size_in, size_out, n):
        super().__init__()
        self.size_in = size_in
        self.size_out = size_out
        weights = torch.rand(size_out, size_in, n)
        bias = torch.rand(size_out, 1, n)
        self.weights = nn.Parameter(weights)
        self.bias = nn.Parameter(bias)
        
    def forward(self, x):
        Wx = t_product(self.weights, x).to(device)
        return torch.add(Wx, self.bias)

In [269]:
class Transform_Net(nn.Module):
    def __init__(self):
        super(Transform_Net, self).__init__()
        self.features = nn.Sequential(
            Transform_Layer(28, 16, 128),
            nn.ReLU(inplace=True),
            #Transform_Layer(16, 12, 128),
            #nn.ReLU(inplace=True),
            # Transform_Layer(16, 10, 128),
        )

    def forward(self, x):
        x = self.features(x)
        x = dct.idct(x)
        x = torch.sum(F.log_softmax(x, dim=1), axis=1)
        
        return x

In [270]:
trainloader, testloader = load_mnist()

model = Transform_Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=5e-4)

def train_step_transform(epoch, train_acc, model, trainloader, optimizer):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    train_loss = 0
    correct = 0
    total = 0
    train_acc = []
    criterion = nn.CrossEntropyLoss()
    
    print('\nEpoch: ', epoch)
    print('|', end='')
    for batch_idx, (inputs, labels) in enumerate(trainloader):   
        inputs = inputs.squeeze()
        inputs = torch.transpose(inputs, 0, 2)
        inputs = torch.transpose(inputs, 0, 1).to(device)
        labels = labels.to(device)
        #targets = torch.zeros(10, 128)
        #for j in range(targets.shape[1]):
        #    targets[i[1][j], j] = 1
        #inputs, targets = inputs.to(device), targets.to(device)
        if not inputs.shape[2] == 128:
            continue

        
        optimizer.zero_grad()
        outputs = model(inputs)
        print(outputs)
        loss = criterion(torch.transpose(outputs, 0, 1), labels)
        loss.backward()
        
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(0)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        if batch_idx % 10 == 0:
            print('=', end='')
    print('|', 'Accuracy:', 100. * correct / total,'% ', correct, '/', total)
    train_acc.append(correct / total)
    
def train_transform(i, model, trainloader, testloader, optimizer):
    train_acc = []
    test_acc = []
    scheduler = StepLR(optimizer, step_size=5, gamma=0.9)
    for epoch in range(i):
        s = time.time()
        train_acc = train_step_transform(epoch, train_acc, model, trainloader, optimizer)
        # test_acc, _ = test(test_acc, model, testloader)
        scheduler.step()
        e = time.time()
        print('This epoch took', e - s, 'seconds to train')
        print('Current learning rate: ', scheduler.get_last_lr()[0])
    print('Best training accuracy overall: ', max(test_acc))
    return train_acc, test_acc

==> Loading data..


In [271]:
train_transform(20, model, trainloader, testloader, optimizer)


Epoch:  0
|tensor([[-27697.8926,  -3948.7805,  -5625.8989,  ...,   -103.8578,
           -103.3520,   -102.8667],
        [-28093.1504,  -4006.5271,  -5568.8608,  ...,   -111.7997,
            -99.6266,   -100.8963],
        [-28500.7812,  -4072.5767,  -5660.5713,  ...,   -115.3003,
           -105.2653,    -99.3140],
        ...,
        [-27886.3008,  -3974.2185,  -5595.7441,  ...,    -99.6970,
           -146.1663,   -112.4078],
        [-28089.1660,  -4028.7615,  -5601.0986,  ...,   -105.4042,
           -102.5223,   -110.2952],
        [-27794.9141,  -3981.5056,  -5551.2129,  ...,   -121.8633,
           -105.3973,   -100.8923]], device='cuda:0', grad_fn=<SumBackward1>)
=tensor([[-26871.4434,  -3733.9688,  -5307.6079,  ...,   -108.6834,
           -100.5483,    -98.9826],
        [-27236.3008,  -3819.2751,  -5499.8330,  ...,   -116.6707,
           -109.6091,   -107.9723],
        [-27668.0859,  -3857.9419,  -5567.2607,  ...,   -106.9280,
           -109.8920,   -178.0916],
     



tensor([[-26550.5117,  -3666.1963,  -5248.6182,  ...,    -99.7950,
           -100.6818,    -99.4391],
        [-26531.8125,  -3753.3809,  -5326.4531,  ...,    -99.3516,
            -95.7179,    -99.0060],
        [-26947.2422,  -3786.1184,  -5383.4004,  ...,    -98.2939,
           -124.2143,   -112.9248],
        ...,
        [-26510.6914,  -3690.1653,  -5292.1123,  ...,   -115.7153,
           -105.5032,    -99.5904],
        [-26764.1582,  -3775.8735,  -5368.7690,  ...,   -113.9205,
            -96.2489,   -102.9120],
        [-26456.6113,  -3701.9856,  -5296.6865,  ...,   -104.9386,
           -115.6209,    -97.1204]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-28461.7148,  -4039.4221,  -5602.6328,  ...,    -97.3947,
            -97.3721,   -102.3243],
        [-28531.8047,  -4092.1877,  -5728.6362,  ...,    -96.1919,
            -95.2210,   -109.7178],
        [-29006.1406,  -4143.6562,  -5892.4595,  ...,   -134.0677,
           -115.4386,   -100.0791],
        ...,
     

tensor([[-28172.2578,  -3896.9163,  -5415.2124,  ...,   -111.0823,
           -110.7522,    -96.3401],
        [-28213.2715,  -3987.5107,  -5731.2031,  ...,   -108.1947,
            -98.8523,   -106.5010],
        [-28666.3418,  -4041.7754,  -5809.8525,  ...,    -96.4946,
           -114.7759,    -99.3390],
        ...,
        [-28080.9023,  -3933.6475,  -5678.5840,  ...,   -104.2115,
           -120.0958,   -106.4868],
        [-28397.6309,  -4007.1040,  -5664.4268,  ...,   -119.8614,
           -110.6204,   -102.4622],
        [-28102.9453,  -3955.0569,  -5617.0469,  ...,   -102.9085,
           -112.3636,   -105.3164]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-27473.4922,  -3915.5935,  -5606.1416,  ...,   -101.7995,
           -107.2471,   -118.8536],
        [-27635.7598,  -3974.5752,  -5483.4839,  ...,    -97.3449,
            -97.0126,   -120.4476],
        [-28085.3906,  -4042.5439,  -5607.6880,  ...,   -112.5196,
           -120.2363,   -115.8220],
        ...,
     

tensor([[-24663.3418,  -3750.6106,  -5027.3955,  ...,   -113.0627,
            -99.7885,   -126.6148],
        [-24936.5117,  -3822.4595,  -5000.6333,  ...,    -95.6880,
           -105.5549,   -119.1233],
        [-25314.1562,  -3861.1274,  -5025.2168,  ...,   -109.4901,
           -108.4183,   -101.5530],
        ...,
        [-24789.3242,  -3768.2065,  -4972.2148,  ...,   -100.0158,
           -118.3875,   -112.5276],
        [-25067.1367,  -3835.9561,  -5078.8613,  ...,   -101.1760,
            -94.8870,   -117.8821],
        [-24769.0098,  -3796.7683,  -4957.7632,  ...,   -100.6659,
            -98.1625,   -102.8384]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-25070.0098,  -3536.2524,  -5112.1523,  ...,   -104.4627,
           -103.4173,    -97.6597],
        [-25196.3398,  -3599.0437,  -4972.4175,  ...,   -107.2891,
           -101.1323,   -111.4810],
        [-25481.0391,  -3629.8972,  -5082.6001,  ...,   -109.0928,
            -97.6241,   -115.1010],
        ...,
     

tensor([[-29093.0684,  -3752.3208,  -5878.2476,  ...,   -101.7096,
           -101.8214,    -99.1657],
        [-29333.2949,  -3803.5303,  -5863.4932,  ...,   -102.2569,
           -101.7099,   -129.2059],
        [-29698.9121,  -3845.6228,  -5967.2666,  ...,    -97.7793,
           -128.9712,   -210.4419],
        ...,
        [-29106.1836,  -3773.2170,  -5817.3291,  ...,    -96.9163,
           -106.3109,    -96.9741],
        [-29411.0234,  -3859.4683,  -5878.9282,  ...,    -97.8174,
           -105.1408,   -104.2459],
        [-29205.3906,  -3798.7271,  -5872.5107,  ...,    -96.1970,
            -97.6727,   -124.3292]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-25973.3086,  -3831.6401,  -5257.0518,  ...,    -99.4525,
            -99.8062,   -101.6083],
        [-25877.3594,  -3899.3176,  -5236.3008,  ...,   -102.7854,
           -110.1999,   -105.2572],
        [-26306.1797,  -3975.4417,  -5280.4951,  ...,   -105.0712,
            -98.3214,   -121.4571],
        ...,
     

tensor([[-29657.0332,  -3972.0298,  -6049.7280,  ...,    -98.7817,
           -101.2136,   -115.1268],
        [-29596.8438,  -4024.3501,  -5910.0645,  ...,   -100.6853,
           -104.9046,   -102.2128],
        [-30121.3828,  -4095.4175,  -6004.3516,  ...,   -109.0359,
           -110.1736,   -142.2424],
        ...,
        [-29534.8418,  -3993.0464,  -5916.8604,  ...,   -116.3781,
           -172.8033,   -114.0783],
        [-29800.3789,  -4041.2275,  -5919.7852,  ...,   -124.9653,
           -122.4395,   -120.4790],
        [-29547.3984,  -3982.4602,  -5961.3535,  ...,   -105.8707,
            -95.6308,   -106.0421]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-27838.4961,  -3766.1997,  -5503.3975,  ...,   -118.1520,
           -104.0972,   -106.6457],
        [-28052.2168,  -3839.4224,  -5585.3096,  ...,    -96.2683,
           -102.1132,   -102.2084],
        [-28480.3516,  -3890.0923,  -5753.8677,  ...,   -116.7310,
           -102.8925,   -108.3793],
        ...,
     

tensor([[-25943.9102,  -4077.8679,  -5125.6660,  ...,   -100.9172,
           -102.9502,    -99.1579],
        [-26002.0410,  -4118.7397,  -5204.4609,  ...,   -100.2975,
            -97.4947,   -102.8887],
        [-26344.9961,  -4165.9922,  -5238.2168,  ...,   -105.8758,
           -123.6046,   -111.5263],
        ...,
        [-25824.4199,  -4070.8140,  -5220.9463,  ...,   -116.0222,
           -109.9284,   -101.0140],
        [-26073.0508,  -4158.2471,  -5192.0107,  ...,   -104.6298,
           -100.1487,   -116.8857],
        [-25830.9824,  -4086.2891,  -5195.1504,  ...,   -134.8685,
           -104.6893,    -98.7277]], device='cuda:0', grad_fn=<SumBackward1>)
=tensor([[-25236.1582,  -4235.6743,  -5029.3491,  ...,    -97.6169,
            -98.9183,   -101.2645],
        [-25431.9180,  -4285.8984,  -5084.7427,  ...,   -104.6142,
           -113.3480,   -111.2870],
        [-25796.0254,  -4339.4917,  -5130.5757,  ...,   -113.3423,
           -133.2783,   -147.9792],
        ...,
    

           -102.3372,    -95.3831]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-26901.5703,  -3608.1758,  -5313.3032,  ...,    -98.8894,
           -112.1041,   -109.0574],
        [-27301.7188,  -3665.8911,  -5519.3818,  ...,   -103.5575,
            -96.1866,   -119.3723],
        [-27668.3789,  -3714.4717,  -5560.1655,  ...,    -96.7371,
           -102.8340,   -117.7153],
        ...,
        [-27101.5586,  -3623.8716,  -5449.1484,  ...,   -108.0486,
           -108.2526,    -96.8183],
        [-27410.3691,  -3686.5093,  -5410.8623,  ...,   -113.9005,
            -96.7785,   -104.8866],
        [-27038.5898,  -3628.6567,  -5427.8701,  ...,   -103.8646,
           -105.3064,    -96.0234]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-24554.7578,  -3521.1807,  -4908.7988,  ...,   -109.3319,
           -100.8652,   -107.5869],
        [-24557.3301,  -3562.5234,  -4882.0869,  ...,   -123.4205,
            -96.5733,   -108.4066],
        [-24918.3359,  -3632.4563,  -4998.0

=tensor([[-27823.9805,  -3436.0330,  -5502.2275,  ...,    -99.0292,
            -97.2002,   -105.3316],
        [-28212.2227,  -3485.9795,  -5614.4282,  ...,   -102.0532,
           -107.1910,   -107.3562],
        [-28585.5098,  -3522.9702,  -5718.5425,  ...,   -107.2796,
           -103.9871,   -115.4937],
        ...,
        [-27942.2070,  -3450.2742,  -5561.6748,  ...,   -105.1427,
           -101.4725,    -95.6549],
        [-28198.6348,  -3532.3877,  -5563.6309,  ...,    -97.8495,
           -102.6304,   -108.8493],
        [-27918.6719,  -3469.1956,  -5527.6865,  ...,   -100.1172,
           -113.4158,   -116.2661]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-26394.8008,  -3773.6753,  -5291.9575,  ...,   -102.3051,
           -100.6971,   -108.9271],
        [-26515.3105,  -3857.4749,  -5285.5078,  ...,   -104.2105,
            -99.6358,   -117.5562],
        [-26857.0430,  -3946.5210,  -5416.9990,  ...,    -99.3489,
           -112.9219,   -111.5316],
        ...,
    

tensor([[-27390.0430,  -3587.5542,  -5508.3936,  ...,    -97.1775,
           -106.4219,   -102.9514],
        [-27676.1406,  -3651.1265,  -5524.3452,  ...,   -103.4281,
            -96.9153,   -105.6846],
        [-28016.7070,  -3681.4465,  -5589.5190,  ...,    -99.4517,
           -111.9072,   -124.3341],
        ...,
        [-27497.0000,  -3594.0630,  -5432.6025,  ...,   -109.7407,
           -113.8953,   -119.1210],
        [-27588.5645,  -3670.0056,  -5490.3296,  ...,    -97.6185,
           -111.8717,   -105.0322],
        [-27388.1250,  -3620.8958,  -5505.9497,  ...,    -99.0580,
           -105.0950,   -108.0448]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-30443.0273,  -3577.0505,  -6027.2061,  ...,   -100.9872,
            -97.4551,    -97.7261],
        [-30450.1035,  -3661.8135,  -6058.5938,  ...,    -98.0309,
            -96.3571,    -97.0037],
        [-30932.7500,  -3702.9028,  -6194.1758,  ...,   -111.4732,
           -111.1436,   -109.1865],
        ...,
     

tensor([[-25398.6504,  -3684.0996,  -5009.0571,  ...,   -110.9197,
           -101.3641,   -100.6151],
        [-25556.0840,  -3752.6758,  -5087.7017,  ...,   -105.1189,
           -110.4691,    -98.9328],
        [-25877.3867,  -3808.8872,  -5172.0605,  ...,   -106.3663,
           -144.2460,   -112.0036],
        ...,
        [-25327.6914,  -3710.0073,  -5083.2109,  ...,   -116.3866,
           -118.8390,   -100.8246],
        [-25597.7324,  -3763.4365,  -5116.5859,  ...,   -100.4179,
           -127.0890,   -105.3405],
        [-25317.7227,  -3736.2568,  -4988.3696,  ...,   -102.8362,
           -109.5156,    -96.2555]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-26252.5430,  -3414.3657,  -5306.4561,  ...,   -106.0393,
            -97.9241,   -100.7142],
        [-26316.9121,  -3458.0063,  -5269.9551,  ...,   -102.2271,
            -96.9434,    -96.7672],
        [-26748.6543,  -3529.6023,  -5353.7568,  ...,   -106.2693,
           -110.9175,   -115.9902],
        ...,
     

tensor([[-29017.6250,  -3771.4604,  -5750.9248,  ...,   -100.6898,
           -108.1009,   -121.1219],
        [-29124.6074,  -3805.6423,  -5854.4653,  ...,   -132.5526,
           -108.5588,   -108.7402],
        [-29586.5137,  -3871.2659,  -5914.6299,  ...,   -115.1401,
           -140.3654,   -102.0919],
        ...,
        [-29016.0156,  -3754.1523,  -5792.9307,  ...,   -117.3872,
           -145.8654,   -124.3170],
        [-29373.3984,  -3850.0562,  -5834.5342,  ...,   -112.1513,
           -108.5468,    -98.9013],
        [-28965.3105,  -3786.7754,  -5739.6201,  ...,   -105.5056,
           -101.3773,   -109.8320]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-27653.4922,  -3974.7747,  -5683.8877,  ...,    -99.9726,
           -101.0714,   -101.9212],
        [-27788.3828,  -4057.3457,  -5472.7563,  ...,    -97.0443,
           -106.4513,   -110.9792],
        [-28146.3301,  -4101.7207,  -5629.0713,  ...,    -98.0289,
            -99.1717,   -120.2204],
        ...,
     

tensor([[-25500.0820,  -3686.4441,  -5074.0615,  ...,    -98.2370,
           -104.3678,   -104.7135],
        [-25538.3438,  -3730.7778,  -5121.9722,  ...,    -98.3842,
           -113.0862,   -108.5772],
        [-25852.3398,  -3772.1738,  -5147.4990,  ...,    -99.3532,
           -107.0486,   -108.2073],
        ...,
        [-25438.3340,  -3685.9902,  -5104.0176,  ...,   -108.8853,
           -122.8591,   -100.3022],
        [-25659.7461,  -3810.4954,  -5111.5342,  ...,   -129.8167,
           -107.3536,   -117.2656],
        [-25390.5117,  -3690.5005,  -5094.1992,  ...,   -103.1147,
           -101.1769,   -101.1856]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-27859.7188,  -4030.0334,  -5610.9604,  ...,    -95.5001,
           -101.0059,   -105.2922],
        [-27742.3906,  -4073.5728,  -5544.7754,  ...,    -98.5461,
            -99.4792,   -120.1447],
        [-28181.8828,  -4142.8530,  -5666.3569,  ...,   -105.9034,
            -96.8987,   -146.5091],
        ...,
     

           -107.7975,   -102.3404]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-28211.3203,  -4283.0552,  -5647.9888,  ...,   -114.9528,
            -98.9008,   -104.7865],
        [-28405.8789,  -4379.8057,  -5616.1401,  ...,   -107.5700,
           -115.6368,   -106.7947],
        [-28758.0977,  -4438.2021,  -5808.8745,  ...,   -102.6235,
           -123.3615,    -96.8352],
        ...,
        [-28262.8203,  -4321.3706,  -5660.2827,  ...,   -104.8692,
           -120.0957,   -106.9126],
        [-28534.3770,  -4396.3730,  -5748.0674,  ...,   -135.2244,
            -98.7365,   -101.5898],
        [-28192.3770,  -4325.2842,  -5579.4683,  ...,   -100.6981,
            -99.3802,    -94.7713]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-28213.0586,  -4148.1138,  -5721.3745,  ...,   -108.2043,
           -105.5410,   -120.9171],
        [-28291.5645,  -4176.1313,  -5641.5645,  ...,    -99.5779,
            -94.7915,   -100.2438],
        [-28677.3965,  -4268.9854,  -5775.5

tensor([[-26419.0859,  -3529.2202,  -5226.4668,  ...,    -99.0557,
           -114.6266,   -110.4377],
        [-26426.8730,  -3607.2905,  -5243.1567,  ...,    -98.4597,
            -97.9798,   -126.7176],
        [-26844.5742,  -3665.4392,  -5402.3545,  ...,   -100.7169,
           -115.0652,   -114.4217],
        ...,
        [-26372.1953,  -3549.1980,  -5276.5718,  ...,   -106.2042,
           -197.3991,    -99.7301],
        [-26532.4531,  -3631.8843,  -5316.2637,  ...,   -100.1569,
           -113.1099,   -121.8333],
        [-26296.1387,  -3554.8699,  -5256.2100,  ...,   -109.0402,
           -115.4172,    -99.2378]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-31600.0508,  -3700.9780,  -6321.6885,  ...,    -98.4789,
            -95.1348,   -100.7347],
        [-31681.9668,  -3765.7327,  -6338.8574,  ...,   -103.9371,
           -100.0599,   -102.0391],
        [-32094.2148,  -3826.5200,  -6412.8149,  ...,   -127.5976,
           -135.7019,   -115.3816],
        ...,
     

tensor([[-26616.0859,  -3840.3232,  -5278.9917,  ...,   -101.0584,
           -103.0928,   -100.0665],
        [-26643.7324,  -3898.4370,  -5390.9434,  ...,   -101.4929,
           -107.4530,   -118.8843],
        [-27073.6738,  -3933.6094,  -5426.6440,  ...,   -110.1085,
           -103.3325,   -110.0870],
        ...,
        [-26549.6953,  -3858.1221,  -5345.7085,  ...,   -102.3947,
            -99.1662,    -97.9292],
        [-26748.3828,  -3958.1255,  -5380.0405,  ...,   -103.0258,
           -118.9105,   -115.0086],
        [-26468.0703,  -3869.2065,  -5211.8037,  ...,   -100.4402,
           -102.4184,    -95.4206]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-32168.6738,  -4111.3887,  -6369.0674,  ...,    -97.2169,
           -102.3493,   -102.9894],
        [-32163.9785,  -4182.0024,  -6440.7627,  ...,   -126.5556,
           -117.1791,   -112.4019],
        [-32582.8145,  -4232.2617,  -6475.1519,  ...,    -97.2489,
           -106.7547,   -121.9707],
        ...,
     

tensor([[-32213.0312,  -3968.6921,  -6374.6650,  ...,    -95.5873,
           -130.3595,    -98.7274],
        [-32203.3047,  -4018.0452,  -6519.5107,  ...,    -96.5267,
           -119.8512,   -107.5323],
        [-32780.4258,  -4095.0537,  -6517.3408,  ...,   -100.1820,
           -118.1859,   -130.9664],
        ...,
        [-32111.4492,  -3972.5938,  -6422.9150,  ...,    -96.2555,
           -182.4842,   -118.7055],
        [-32428.7129,  -4067.1943,  -6414.5430,  ...,    -99.5310,
           -107.8507,   -116.3569],
        [-32133.9961,  -3992.5532,  -6438.9414,  ...,   -133.9325,
           -100.0723,   -101.8190]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-28129.2852,  -4281.6733,  -5615.3403,  ...,   -106.8378,
           -115.1835,    -99.5727],
        [-27875.2539,  -4355.3867,  -5614.9468,  ...,   -104.1826,
           -106.4231,    -98.8455],
        [-28261.9746,  -4410.8540,  -5611.8872,  ...,   -105.9644,
           -107.2134,   -159.9601],
        ...,
     

tensor([[-29490.3633,  -3772.3030,  -5857.0039,  ...,   -101.7617,
            -97.5126,    -99.3435],
        [-29480.0215,  -3844.0657,  -5929.2402,  ...,    -99.4067,
            -97.3038,   -108.4928],
        [-29989.6934,  -3906.9966,  -5977.8315,  ...,   -107.8072,
           -130.3372,   -117.5990],
        ...,
        [-29366.6289,  -3799.5430,  -5924.1289,  ...,   -104.1904,
           -122.2388,   -100.7133],
        [-29682.4141,  -3867.2202,  -5859.7490,  ...,    -96.0618,
           -110.8734,   -102.2393],
        [-29319.5391,  -3808.7812,  -5881.2295,  ...,   -112.8849,
           -121.9820,   -101.0222]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-27647.8145,  -3872.4058,  -5479.8564,  ...,    -97.5266,
            -99.0413,   -108.1187],
        [-27910.3887,  -3936.6421,  -5586.8984,  ...,   -102.7058,
            -99.6690,   -107.5672],
        [-28249.6250,  -4001.3003,  -5644.7285,  ...,   -101.0046,
           -108.8335,   -101.3265],
        ...,
     

tensor([[-30015.4375,  -4045.3394,  -6024.5762,  ...,   -100.6754,
           -102.6417,   -109.4224],
        [-30210.5234,  -4109.1943,  -5990.1152,  ...,   -103.0561,
            -98.7603,   -100.2777],
        [-30572.8555,  -4169.4136,  -6142.3223,  ...,   -115.5532,
           -101.8633,   -103.0591],
        ...,
        [-30075.1699,  -4048.9492,  -6022.4043,  ...,   -100.3919,
            -97.5989,   -115.1837],
        [-30316.4375,  -4118.7866,  -6079.8535,  ...,   -107.5030,
           -102.1056,   -104.7874],
        [-29926.4434,  -4086.4502,  -6006.5815,  ...,   -113.6880,
           -134.2599,   -103.4608]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-29186.1152,  -3869.5986,  -5766.9990,  ...,    -99.9222,
           -110.3917,    -99.9634],
        [-29083.6211,  -3926.2314,  -5844.7876,  ...,    -98.3270,
           -104.6000,   -121.6616],
        [-29645.9141,  -3981.6548,  -5908.5137,  ...,    -99.8633,
           -109.8178,   -182.6616],
        ...,
     

tensor([[-29462.0664,  -4162.0293,  -5824.2842,  ...,   -104.0531,
           -100.9723,    -99.8566],
        [-29383.0977,  -4197.2593,  -5856.1680,  ...,   -105.8340,
           -115.5551,   -104.9662],
        [-29873.0449,  -4296.5879,  -5949.6040,  ...,   -111.0216,
           -107.3899,   -129.6761],
        ...,
        [-29301.1465,  -4135.9829,  -5790.1250,  ...,   -107.0395,
            -94.9991,   -123.0347],
        [-29627.2090,  -4258.0801,  -5944.6646,  ...,    -98.3366,
            -96.2091,   -104.6123],
        [-29228.6777,  -4172.6836,  -5872.7886,  ...,   -112.0306,
           -108.2981,   -117.9199]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-27897.8828,  -3635.2737,  -5543.9194,  ...,   -100.5874,
           -120.8000,   -108.1551],
        [-27857.8730,  -3707.0725,  -5526.3135,  ...,   -107.9836,
            -97.3483,   -113.6136],
        [-28311.5605,  -3734.0615,  -5663.3799,  ...,   -102.3706,
           -102.3229,   -122.2007],
        ...,
     

tensor([[-26608.7090,  -3778.6653,  -5314.7285,  ...,   -101.5212,
            -96.9226,   -100.1325],
        [-26660.7930,  -3833.1338,  -5321.2988,  ...,   -105.7802,
           -103.4762,   -100.8511],
        [-27005.9844,  -3870.7268,  -5463.8936,  ...,   -101.6596,
            -96.6505,    -97.7435],
        ...,
        [-26588.8750,  -3790.7344,  -5326.4014,  ...,   -106.9669,
           -103.0852,    -99.6990],
        [-26740.1367,  -3877.7788,  -5362.9150,  ...,   -106.4308,
            -98.5028,   -112.8503],
        [-26562.2617,  -3807.7551,  -5307.1494,  ...,    -99.8486,
            -97.6306,   -106.9976]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-25031.0078,  -3543.2791,  -5004.8408,  ...,   -109.8043,
           -104.8065,    -97.6604],
        [-24916.2500,  -3594.4487,  -4977.6396,  ...,   -102.2185,
           -114.9814,   -116.5941],
        [-25331.0332,  -3648.1538,  -5054.0698,  ...,   -105.6981,
           -106.0975,   -131.3170],
        ...,
     

tensor([[-29356.2383,  -3848.1943,  -5843.0679,  ...,    -99.1376,
           -110.0670,   -114.9065],
        [-29599.5938,  -3902.0076,  -5921.3721,  ...,   -101.7087,
            -97.7180,   -124.5844],
        [-29975.2812,  -3956.4761,  -5926.4385,  ...,   -124.1219,
           -106.5145,   -119.8485],
        ...,
        [-29401.8320,  -3861.9180,  -5870.0723,  ...,   -108.6327,
           -103.7439,   -100.9906],
        [-29485.0312,  -3912.3464,  -5838.0806,  ...,   -109.1431,
           -107.2531,   -103.3486],
        [-29265.6152,  -3885.4031,  -5858.8486,  ...,   -121.3054,
           -101.3957,   -101.8355]], device='cuda:0', grad_fn=<SumBackward1>)
=tensor([[-26839.2461,  -3733.6230,  -5435.1729,  ...,    -99.9305,
            -96.3287,   -100.8337],
        [-27194.2500,  -3792.0493,  -5468.9575,  ...,    -97.5870,
            -98.6014,    -98.2579],
        [-27501.5781,  -3824.6895,  -5489.8994,  ...,   -102.0450,
           -107.3611,   -115.8975],
        ...,
    

           -105.8303,    -98.2107]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-27768.9961,  -3597.8716,  -5575.3008,  ...,   -109.8605,
           -111.8363,    -99.3656],
        [-27744.9492,  -3666.3691,  -5542.8462,  ...,    -97.9284,
            -96.3547,   -103.3232],
        [-28198.8203,  -3701.2532,  -5584.0762,  ...,    -99.3433,
           -122.6797,   -132.0747],
        ...,
        [-27710.0801,  -3615.9509,  -5574.4707,  ...,   -105.0218,
            -95.6494,    -99.3905],
        [-27979.9336,  -3677.3552,  -5584.9922,  ...,   -115.9819,
           -101.3225,   -112.4698],
        [-27628.6230,  -3623.9417,  -5532.0527,  ...,   -114.5046,
           -103.1094,    -99.1372]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-28367.9688,  -3959.3872,  -5598.2446,  ...,    -97.4531,
           -130.4385,   -100.2323],
        [-28353.9004,  -4023.4065,  -5642.7778,  ...,   -109.1143,
            -98.4170,   -120.7880],
        [-28840.7910,  -4074.0261,  -5742.8

           -113.5271,    -96.5710]], device='cuda:0', grad_fn=<SumBackward1>)
=tensor([[-26282.8750,  -3861.4004,  -5259.6270,  ...,    -99.8793,
            -98.1659,    -99.9234],
        [-26410.9980,  -3924.6157,  -5411.7339,  ...,   -102.0436,
           -100.1400,   -110.0899],
        [-26742.3730,  -3977.0676,  -5335.9058,  ...,   -102.7987,
           -117.9989,   -105.3172],
        ...,
        [-26229.3926,  -3883.9973,  -5307.0513,  ...,   -133.6361,
           -102.2086,    -97.4975],
        [-26442.0293,  -3960.7051,  -5351.5781,  ...,   -104.7719,
           -121.8722,   -107.9092],
        [-26280.1016,  -3890.6636,  -5245.7578,  ...,   -113.4170,
           -111.4768,   -100.4063]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-27812.4688,  -3773.2139,  -5605.0996,  ...,   -105.6260,
            -95.7308,    -97.2051],
        [-27863.1875,  -3844.1436,  -5553.5977,  ...,   -100.9858,
            -98.6728,   -103.9337],
        [-28222.2012,  -3890.5640,  -5653.

tensor([[-26678.1758,  -3712.9109,  -5339.0771,  ...,   -105.3205,
           -100.2919,   -105.9425],
        [-26753.6738,  -3754.3320,  -5399.9707,  ...,   -105.8830,
            -96.8440,    -99.1047],
        [-27101.1797,  -3814.1741,  -5469.2944,  ...,   -100.4977,
           -115.7078,   -116.9635],
        ...,
        [-26556.6797,  -3697.6738,  -5335.8516,  ...,   -107.9850,
           -106.4783,   -102.6897],
        [-26847.0938,  -3775.8354,  -5380.6440,  ...,   -100.9485,
           -100.8694,   -121.5639],
        [-26601.1055,  -3735.3262,  -5396.5171,  ...,    -98.8080,
           -107.2287,   -101.8301]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-26153.6484,  -3792.0859,  -5277.2822,  ...,   -101.0188,
           -103.6035,   -102.7643],
        [-26374.9062,  -3887.8284,  -5280.9932,  ...,   -106.9509,
            -99.2424,   -112.3525],
        [-26690.0781,  -3924.5181,  -5374.6484,  ...,   -113.5108,
            -98.4763,   -123.7832],
        ...,
     

tensor([[-28192.3789,  -3814.9873,  -5673.8203,  ...,   -101.0092,
           -107.1350,   -118.8691],
        [-28439.1719,  -3881.6484,  -5627.1357,  ...,   -106.3858,
            -96.7402,   -148.5062],
        [-28789.4961,  -3929.4482,  -5748.8955,  ...,    -97.5563,
           -151.5889,   -110.7031],
        ...,
        [-28206.7129,  -3836.7510,  -5676.6206,  ...,   -102.5650,
           -107.3316,    -98.5539],
        [-28500.0039,  -3922.2847,  -5683.2744,  ...,   -114.1114,
           -106.5010,    -98.7955],
        [-28205.7754,  -3851.6711,  -5579.8608,  ...,   -104.3493,
            -97.0789,   -102.0686]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-30266.5332,  -3457.6240,  -6125.5156,  ...,    -96.2458,
            -99.1428,   -110.9294],
        [-30242.9258,  -3515.4578,  -6034.2168,  ...,    -97.9363,
            -98.3818,   -104.3195],
        [-30734.0137,  -3565.8569,  -6146.7090,  ...,   -102.4725,
           -101.2706,   -110.3645],
        ...,
     

           -113.0341,   -127.7116]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-27093.2578,  -3752.6074,  -5366.5879,  ...,    -96.0307,
           -100.9246,   -106.5418],
        [-27063.8848,  -3791.0564,  -5451.6309,  ...,   -101.9485,
            -96.3648,   -106.7849],
        [-27434.7207,  -3855.8774,  -5527.4961,  ...,    -97.5570,
           -104.8651,   -104.1083],
        ...,
        [-27091.0820,  -3735.2437,  -5399.3091,  ...,   -102.9447,
           -101.6246,   -100.0771],
        [-27331.4062,  -3798.8843,  -5441.4951,  ...,   -100.6368,
           -100.4535,   -102.1235],
        [-26932.1328,  -3786.7031,  -5397.9971,  ...,   -136.3498,
           -108.9560,    -97.2329]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-29307.4375,  -4019.0371,  -5853.5874,  ...,    -98.5897,
           -100.3024,    -98.2523],
        [-29142.8164,  -4082.4729,  -5868.3242,  ...,   -108.1451,
           -115.7802,   -105.3222],
        [-29661.3750,  -4134.3550,  -5971.6

tensor([[-27435.9922,  -3576.3904,  -5542.0132,  ...,   -107.2256,
           -105.1972,   -100.6711],
        [-27547.1523,  -3624.5205,  -5518.1133,  ...,    -99.2175,
           -109.8758,   -123.2603],
        [-27981.7910,  -3690.2322,  -5660.1343,  ...,   -103.4595,
           -114.3569,   -107.6970],
        ...,
        [-27430.7793,  -3592.4375,  -5475.1816,  ...,    -99.1526,
           -118.5176,    -96.0953],
        [-27809.9590,  -3650.5442,  -5527.9995,  ...,   -105.0534,
           -104.4484,   -145.0869],
        [-27332.5703,  -3613.7080,  -5469.6357,  ...,   -109.3510,
           -100.1567,    -98.6735]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-27020.1328,  -3920.5483,  -5480.3447,  ...,   -101.1127,
           -102.6547,   -111.5900],
        [-27067.2246,  -3987.9805,  -5437.7954,  ...,    -99.9556,
           -105.4231,   -108.9794],
        [-27401.8555,  -4046.9407,  -5502.3462,  ...,   -103.5825,
           -103.2555,   -120.3190],
        ...,
     

KeyboardInterrupt: 