In [2]:
# FizzBuzz简单小游戏：整除3:fizz;整除5:Buzz;整除15:FizzBuzz;
# 4分类问题：0，1，2，3 

def fizz_buzz_encoder(i):
    if i % 15 == 0 : return 3
    elif i % 5 == 0 : return 2
    elif i % 3 == 0 : return 1
    else : return 0

def fizz_buzz_decoder(i,prediction):
    return [str(i),"fizz","buzz","fizzbuzz"][prediction] # 输出这个长度为4的字符串的第几个值

def helper(i):
    print(fizz_buzz_decoder(i,fizz_buzz_encoder(i)))
    
for i in range(1,16):
    helper(i)

1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz


In [3]:
# 现在我们训练一个神经网络来玩这个游戏
import numpy as np
import torch

NUM_DIGITS = 10

# 把数字转换为二进制的数组:相当于是做词嵌入表征
def binary_encode(i,num_digits):
    return np.array([i>>d & 1 for d in range(num_digits)][::-1])

# 15移0，1，2，3位置时的值分别为15，7，3，1 与1 求与运算得到的是1，后面等到的是0，
# 相当于转换成了一个10位的二进制数
print(binary_encode(15,NUM_DIGITS))# array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])
 
#我们101-1024(2的十次方)之间的数当做训练数据
trainX = torch.Tensor([binary_encode(i,NUM_DIGITS) for i in range(101,2**NUM_DIGITS)])
trainY = torch.LongTensor([fizz_buzz_encoder(i) for i in range(101,2**NUM_DIGITS)])

[0 0 0 0 0 0 1 1 1 1]


In [4]:
print(trainX.shape)
print(trainY.shape)

torch.Size([923, 10])
torch.Size([923])


In [5]:
NUM_HIDDEN = 100
# 自定义一个模型
model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS,NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN,4) # 4中类型的分类问题
)

# 定义LossFunction
loss_fn = torch.nn.CrossEntropyLoss() # 多分类问题的损失函数

# 定义优化器
# optimizer = torch.optim.SGD(model.parameters(),lr=0.05)
optimizer = torch.optim.Adam(model.parameters(),lr=0.01)



In [6]:
BATCH_SIZE = 128
# 进行模型训练
for epoch in range(1000):
    for start in range(0,len(trainX),BATCH_SIZE):
#         print(start)#0 128 256 384 512 640 768 896
        end = start + BATCH_SIZE
        batchX = trainX[start:end]
        batchY = trainY[start:end]
        
        if torch.cuda.is_available():
            batchX = batchX.cuda()
            batchY = batchY.cuda()
            
        # 前向传播计算出预测值
        y_pred = model(batchX)
        
        # 计算Loss的值
        loss = loss_fn(y_pred,batchY)
#         print(y_pred.shape)#torch.Size([128, 4])
#         print(batchY.shape)#torch.Size([128])
        
        print("Epoch",epoch,":",loss.item())
        
        # 优化器清空grad
        optimizer.zero_grad()
        
        # 进行反向传播:backpass
        loss.backward() 
        
        #进行梯度下降:gradient descent
        optimizer.step() 

torch.Size([128, 4])
torch.Size([128])
Epoch 0 : 1.456317663192749
torch.Size([128, 4])
torch.Size([128])
Epoch 0 : 1.3361343145370483
torch.Size([128, 4])
torch.Size([128])
Epoch 0 : 1.216068148612976
torch.Size([128, 4])
torch.Size([128])
Epoch 0 : 1.1882133483886719
torch.Size([128, 4])
torch.Size([128])
Epoch 0 : 1.170426368713379
torch.Size([128, 4])
torch.Size([128])
Epoch 0 : 1.1891570091247559
torch.Size([128, 4])
torch.Size([128])
Epoch 0 : 1.220849633216858
torch.Size([27, 4])
torch.Size([27])
Epoch 0 : 1.1772310733795166
torch.Size([128, 4])
torch.Size([128])
Epoch 1 : 1.137791633605957
torch.Size([128, 4])
torch.Size([128])
Epoch 1 : 1.1521611213684082
torch.Size([128, 4])
torch.Size([128])
Epoch 1 : 1.1414523124694824
torch.Size([128, 4])
torch.Size([128])
Epoch 1 : 1.1672433614730835
torch.Size([128, 4])
torch.Size([128])
Epoch 1 : 1.1808069944381714
torch.Size([128, 4])
torch.Size([128])
Epoch 1 : 1.1634910106658936
torch.Size([128, 4])
torch.Size([128])
Epoch 1 : 1.1741

torch.Size([128, 4])
torch.Size([128])
Epoch 21 : 1.0105615854263306
torch.Size([128, 4])
torch.Size([128])
Epoch 21 : 0.8528810143470764
torch.Size([128, 4])
torch.Size([128])
Epoch 21 : 0.9396006464958191
torch.Size([27, 4])
torch.Size([27])
Epoch 21 : 0.7874893546104431
torch.Size([128, 4])
torch.Size([128])
Epoch 22 : 1.0033429861068726
torch.Size([128, 4])
torch.Size([128])
Epoch 22 : 1.0306566953659058
torch.Size([128, 4])
torch.Size([128])
Epoch 22 : 0.8780520558357239
torch.Size([128, 4])
torch.Size([128])
Epoch 22 : 0.8857570290565491
torch.Size([128, 4])
torch.Size([128])
Epoch 22 : 0.9517205953598022
torch.Size([128, 4])
torch.Size([128])
Epoch 22 : 0.9130266904830933
torch.Size([128, 4])
torch.Size([128])
Epoch 22 : 1.1013177633285522
torch.Size([27, 4])
torch.Size([27])
Epoch 22 : 0.7737247347831726
torch.Size([128, 4])
torch.Size([128])
Epoch 23 : 0.8046326041221619
torch.Size([128, 4])
torch.Size([128])
Epoch 23 : 0.8357383012771606
torch.Size([128, 4])
torch.Size([128])

torch.Size([128, 4])
torch.Size([128])
Epoch 45 : 0.5180726647377014
torch.Size([128, 4])
torch.Size([128])
Epoch 45 : 0.4662136733531952
torch.Size([128, 4])
torch.Size([128])
Epoch 45 : 0.5870369076728821
torch.Size([27, 4])
torch.Size([27])
Epoch 45 : 0.3789123594760895
torch.Size([128, 4])
torch.Size([128])
Epoch 46 : 0.46079719066619873
torch.Size([128, 4])
torch.Size([128])
Epoch 46 : 0.5034244656562805
torch.Size([128, 4])
torch.Size([128])
Epoch 46 : 0.4433671534061432
torch.Size([128, 4])
torch.Size([128])
Epoch 46 : 0.5050382018089294
torch.Size([128, 4])
torch.Size([128])
Epoch 46 : 0.5371629595756531
torch.Size([128, 4])
torch.Size([128])
Epoch 46 : 0.4970918595790863
torch.Size([128, 4])
torch.Size([128])
Epoch 46 : 0.541995108127594
torch.Size([27, 4])
torch.Size([27])
Epoch 46 : 0.36541298031806946
torch.Size([128, 4])
torch.Size([128])
Epoch 47 : 0.4673336446285248
torch.Size([128, 4])
torch.Size([128])
Epoch 47 : 0.4843168258666992
torch.Size([128, 4])
torch.Size([128]

torch.Size([128, 4])
torch.Size([128])
Epoch 71 : 0.30288833379745483
torch.Size([27, 4])
torch.Size([27])
Epoch 71 : 0.18284361064434052
torch.Size([128, 4])
torch.Size([128])
Epoch 72 : 0.2793649137020111
torch.Size([128, 4])
torch.Size([128])
Epoch 72 : 0.3126964867115021
torch.Size([128, 4])
torch.Size([128])
Epoch 72 : 0.22883205115795135
torch.Size([128, 4])
torch.Size([128])
Epoch 72 : 0.25571855902671814
torch.Size([128, 4])
torch.Size([128])
Epoch 72 : 0.2585347294807434
torch.Size([128, 4])
torch.Size([128])
Epoch 72 : 0.24915169179439545
torch.Size([128, 4])
torch.Size([128])
Epoch 72 : 0.33076485991477966
torch.Size([27, 4])
torch.Size([27])
Epoch 72 : 0.223793625831604
torch.Size([128, 4])
torch.Size([128])
Epoch 73 : 0.2572113871574402
torch.Size([128, 4])
torch.Size([128])
Epoch 73 : 0.23438617587089539
torch.Size([128, 4])
torch.Size([128])
Epoch 73 : 0.2121599316596985
torch.Size([128, 4])
torch.Size([128])
Epoch 73 : 0.30095621943473816
torch.Size([128, 4])
torch.Size

torch.Size([128, 4])
torch.Size([128])
Epoch 98 : 0.14713962376117706
torch.Size([128, 4])
torch.Size([128])
Epoch 98 : 0.12497281283140182
torch.Size([128, 4])
torch.Size([128])
Epoch 98 : 0.17898447811603546
torch.Size([27, 4])
torch.Size([27])
Epoch 98 : 0.08144471049308777
torch.Size([128, 4])
torch.Size([128])
Epoch 99 : 0.1112695038318634
torch.Size([128, 4])
torch.Size([128])
Epoch 99 : 0.09766382724046707
torch.Size([128, 4])
torch.Size([128])
Epoch 99 : 0.10723351687192917
torch.Size([128, 4])
torch.Size([128])
Epoch 99 : 0.1549077033996582
torch.Size([128, 4])
torch.Size([128])
Epoch 99 : 0.15721547603607178
torch.Size([128, 4])
torch.Size([128])
Epoch 99 : 0.13447365164756775
torch.Size([128, 4])
torch.Size([128])
Epoch 99 : 0.16113439202308655
torch.Size([27, 4])
torch.Size([27])
Epoch 99 : 0.08077407628297806
torch.Size([128, 4])
torch.Size([128])
Epoch 100 : 0.11235202103853226
torch.Size([128, 4])
torch.Size([128])
Epoch 100 : 0.10213059931993484
torch.Size([128, 4])
tor

Epoch 123 : 0.11638812720775604
torch.Size([128, 4])
torch.Size([128])
Epoch 123 : 0.07600893825292587
torch.Size([128, 4])
torch.Size([128])
Epoch 123 : 0.11021140217781067
torch.Size([27, 4])
torch.Size([27])
Epoch 123 : 0.04577842727303505
torch.Size([128, 4])
torch.Size([128])
Epoch 124 : 0.06793590635061264
torch.Size([128, 4])
torch.Size([128])
Epoch 124 : 0.05685149505734444
torch.Size([128, 4])
torch.Size([128])
Epoch 124 : 0.0636427029967308
torch.Size([128, 4])
torch.Size([128])
Epoch 124 : 0.10137689113616943
torch.Size([128, 4])
torch.Size([128])
Epoch 124 : 0.12756973505020142
torch.Size([128, 4])
torch.Size([128])
Epoch 124 : 0.08060071617364883
torch.Size([128, 4])
torch.Size([128])
Epoch 124 : 0.10663250833749771
torch.Size([27, 4])
torch.Size([27])
Epoch 124 : 0.058357734233140945
torch.Size([128, 4])
torch.Size([128])
Epoch 125 : 0.07978034019470215
torch.Size([128, 4])
torch.Size([128])
Epoch 125 : 0.06812399625778198
torch.Size([128, 4])
torch.Size([128])
Epoch 125 

torch.Size([128, 4])
torch.Size([128])
Epoch 149 : 0.08113852143287659
torch.Size([128, 4])
torch.Size([128])
Epoch 149 : 0.07129088789224625
torch.Size([128, 4])
torch.Size([128])
Epoch 149 : 0.05613846331834793
torch.Size([128, 4])
torch.Size([128])
Epoch 149 : 0.11445190757513046
torch.Size([128, 4])
torch.Size([128])
Epoch 149 : 0.11168619245290756
torch.Size([128, 4])
torch.Size([128])
Epoch 149 : 0.09515704214572906
torch.Size([128, 4])
torch.Size([128])
Epoch 149 : 0.10104775428771973
torch.Size([27, 4])
torch.Size([27])
Epoch 149 : 0.03858821094036102
torch.Size([128, 4])
torch.Size([128])
Epoch 150 : 0.05584358423948288
torch.Size([128, 4])
torch.Size([128])
Epoch 150 : 0.08575879782438278
torch.Size([128, 4])
torch.Size([128])
Epoch 150 : 0.053071435540914536
torch.Size([128, 4])
torch.Size([128])
Epoch 150 : 0.14491549134254456
torch.Size([128, 4])
torch.Size([128])
Epoch 150 : 0.10742249339818954
torch.Size([128, 4])
torch.Size([128])
Epoch 150 : 0.07100262492895126
torch.S

torch.Size([128, 4])
torch.Size([128])
Epoch 174 : 0.03258374705910683
torch.Size([128, 4])
torch.Size([128])
Epoch 174 : 0.040293704718351364
torch.Size([128, 4])
torch.Size([128])
Epoch 174 : 0.03238930180668831
torch.Size([128, 4])
torch.Size([128])
Epoch 174 : 0.04367567598819733
torch.Size([27, 4])
torch.Size([27])
Epoch 174 : 0.017783381044864655
torch.Size([128, 4])
torch.Size([128])
Epoch 175 : 0.026614613831043243
torch.Size([128, 4])
torch.Size([128])
Epoch 175 : 0.02138206921517849
torch.Size([128, 4])
torch.Size([128])
Epoch 175 : 0.021536827087402344
torch.Size([128, 4])
torch.Size([128])
Epoch 175 : 0.033284951001405716
torch.Size([128, 4])
torch.Size([128])
Epoch 175 : 0.038989171385765076
torch.Size([128, 4])
torch.Size([128])
Epoch 175 : 0.030837824568152428
torch.Size([128, 4])
torch.Size([128])
Epoch 175 : 0.043671660125255585
torch.Size([27, 4])
torch.Size([27])
Epoch 175 : 0.017438678070902824
torch.Size([128, 4])
torch.Size([128])
Epoch 176 : 0.026627054437994957


torch.Size([128, 4])
torch.Size([128])
Epoch 199 : 0.01639380119740963
torch.Size([128, 4])
torch.Size([128])
Epoch 199 : 0.024394920095801353
torch.Size([128, 4])
torch.Size([128])
Epoch 199 : 0.03021998703479767
torch.Size([128, 4])
torch.Size([128])
Epoch 199 : 0.02408766560256481
torch.Size([128, 4])
torch.Size([128])
Epoch 199 : 0.032929543405771255
torch.Size([27, 4])
torch.Size([27])
Epoch 199 : 0.01291579008102417
torch.Size([128, 4])
torch.Size([128])
Epoch 200 : 0.020150061696767807
torch.Size([128, 4])
torch.Size([128])
Epoch 200 : 0.01603006012737751
torch.Size([128, 4])
torch.Size([128])
Epoch 200 : 0.016109708696603775
torch.Size([128, 4])
torch.Size([128])
Epoch 200 : 0.023941170424222946
torch.Size([128, 4])
torch.Size([128])
Epoch 200 : 0.029921414330601692
torch.Size([128, 4])
torch.Size([128])
Epoch 200 : 0.02363966405391693
torch.Size([128, 4])
torch.Size([128])
Epoch 200 : 0.03313090652227402
torch.Size([27, 4])
torch.Size([27])
Epoch 200 : 0.01290909107774496
torc

torch.Size([128, 4])
torch.Size([128])
Epoch 226 : 0.015432382002472878
torch.Size([128, 4])
torch.Size([128])
Epoch 226 : 0.012373784556984901
torch.Size([128, 4])
torch.Size([128])
Epoch 226 : 0.0122937997803092
torch.Size([128, 4])
torch.Size([128])
Epoch 226 : 0.018415328115224838
torch.Size([128, 4])
torch.Size([128])
Epoch 226 : 0.023458560928702354
torch.Size([128, 4])
torch.Size([128])
Epoch 226 : 0.01847185753285885
torch.Size([128, 4])
torch.Size([128])
Epoch 226 : 0.02568710781633854
torch.Size([27, 4])
torch.Size([27])
Epoch 226 : 0.009690598584711552
torch.Size([128, 4])
torch.Size([128])
Epoch 227 : 0.015289459377527237
torch.Size([128, 4])
torch.Size([128])
Epoch 227 : 0.0122224697843194
torch.Size([128, 4])
torch.Size([128])
Epoch 227 : 0.012227443978190422
torch.Size([128, 4])
torch.Size([128])
Epoch 227 : 0.018075555562973022
torch.Size([128, 4])
torch.Size([128])
Epoch 227 : 0.023423977196216583
torch.Size([128, 4])
torch.Size([128])
Epoch 227 : 0.01820148155093193
t

torch.Size([128, 4])
torch.Size([128])
Epoch 251 : 0.00982347596436739
torch.Size([128, 4])
torch.Size([128])
Epoch 251 : 0.009653371758759022
torch.Size([128, 4])
torch.Size([128])
Epoch 251 : 0.014344749040901661
torch.Size([128, 4])
torch.Size([128])
Epoch 251 : 0.018856283277273178
torch.Size([128, 4])
torch.Size([128])
Epoch 251 : 0.014533767476677895
torch.Size([128, 4])
torch.Size([128])
Epoch 251 : 0.020310213789343834
torch.Size([27, 4])
torch.Size([27])
Epoch 251 : 0.007174591068178415
torch.Size([128, 4])
torch.Size([128])
Epoch 252 : 0.011871150694787502
torch.Size([128, 4])
torch.Size([128])
Epoch 252 : 0.00984397903084755
torch.Size([128, 4])
torch.Size([128])
Epoch 252 : 0.009655985981225967
torch.Size([128, 4])
torch.Size([128])
Epoch 252 : 0.014327890239655972
torch.Size([128, 4])
torch.Size([128])
Epoch 252 : 0.0185927115380764
torch.Size([128, 4])
torch.Size([128])
Epoch 252 : 0.01416432298719883
torch.Size([128, 4])
torch.Size([128])
Epoch 252 : 0.020228620618581772

Epoch 276 : 0.009398060850799084
torch.Size([128, 4])
torch.Size([128])
Epoch 276 : 0.007891172543168068
torch.Size([128, 4])
torch.Size([128])
Epoch 276 : 0.007691994309425354
torch.Size([128, 4])
torch.Size([128])
Epoch 276 : 0.011407752521336079
torch.Size([128, 4])
torch.Size([128])
Epoch 276 : 0.01522168517112732
torch.Size([128, 4])
torch.Size([128])
Epoch 276 : 0.011572700925171375
torch.Size([128, 4])
torch.Size([128])
Epoch 276 : 0.016153041273355484
torch.Size([27, 4])
torch.Size([27])
Epoch 276 : 0.005661958362907171
torch.Size([128, 4])
torch.Size([128])
Epoch 277 : 0.009295986965298653
torch.Size([128, 4])
torch.Size([128])
Epoch 277 : 0.007866193540394306
torch.Size([128, 4])
torch.Size([128])
Epoch 277 : 0.007648048456758261
torch.Size([128, 4])
torch.Size([128])
Epoch 277 : 0.011247863993048668
torch.Size([128, 4])
torch.Size([128])
Epoch 277 : 0.015226149931550026
torch.Size([128, 4])
torch.Size([128])
Epoch 277 : 0.011574163101613522
torch.Size([128, 4])
torch.Size([1

torch.Size([128, 4])
torch.Size([128])
Epoch 302 : 0.012421690858900547
torch.Size([128, 4])
torch.Size([128])
Epoch 302 : 0.00949771422892809
torch.Size([128, 4])
torch.Size([128])
Epoch 302 : 0.012943770736455917
torch.Size([27, 4])
torch.Size([27])
Epoch 302 : 0.004584113601595163
torch.Size([128, 4])
torch.Size([128])
Epoch 303 : 0.007509865798056126
torch.Size([128, 4])
torch.Size([128])
Epoch 303 : 0.006336728576570749
torch.Size([128, 4])
torch.Size([128])
Epoch 303 : 0.006152798887342215
torch.Size([128, 4])
torch.Size([128])
Epoch 303 : 0.0089888209477067
torch.Size([128, 4])
torch.Size([128])
Epoch 303 : 0.012208588421344757
torch.Size([128, 4])
torch.Size([128])
Epoch 303 : 0.009213319979608059
torch.Size([128, 4])
torch.Size([128])
Epoch 303 : 0.012943301349878311
torch.Size([27, 4])
torch.Size([27])
Epoch 303 : 0.004414386581629515
torch.Size([128, 4])
torch.Size([128])
Epoch 304 : 0.007396046072244644
torch.Size([128, 4])
torch.Size([128])
Epoch 304 : 0.006211587693542242

torch.Size([27, 4])
torch.Size([27])
Epoch 328 : 0.0036923051811754704
torch.Size([128, 4])
torch.Size([128])
Epoch 329 : 0.005963511765003204
torch.Size([128, 4])
torch.Size([128])
Epoch 329 : 0.005088323727250099
torch.Size([128, 4])
torch.Size([128])
Epoch 329 : 0.004971825052052736
torch.Size([128, 4])
torch.Size([128])
Epoch 329 : 0.007128894794732332
torch.Size([128, 4])
torch.Size([128])
Epoch 329 : 0.010065010748803616
torch.Size([128, 4])
torch.Size([128])
Epoch 329 : 0.007331345230340958
torch.Size([128, 4])
torch.Size([128])
Epoch 329 : 0.010435022413730621
torch.Size([27, 4])
torch.Size([27])
Epoch 329 : 0.0035397051833570004
torch.Size([128, 4])
torch.Size([128])
Epoch 330 : 0.005903124343603849
torch.Size([128, 4])
torch.Size([128])
Epoch 330 : 0.0050730351358652115
torch.Size([128, 4])
torch.Size([128])
Epoch 330 : 0.004972781985998154
torch.Size([128, 4])
torch.Size([128])
Epoch 330 : 0.007046358194202185
torch.Size([128, 4])
torch.Size([128])
Epoch 330 : 0.009983559139

torch.Size([128, 4])
torch.Size([128])
Epoch 355 : 0.008239947259426117
torch.Size([128, 4])
torch.Size([128])
Epoch 355 : 0.0059218439273536205
torch.Size([128, 4])
torch.Size([128])
Epoch 355 : 0.008534586057066917
torch.Size([27, 4])
torch.Size([27])
Epoch 355 : 0.0028543586377054453
torch.Size([128, 4])
torch.Size([128])
Epoch 356 : 0.00473268236964941
torch.Size([128, 4])
torch.Size([128])
Epoch 356 : 0.004184400197118521
torch.Size([128, 4])
torch.Size([128])
Epoch 356 : 0.00408244039863348
torch.Size([128, 4])
torch.Size([128])
Epoch 356 : 0.005737940315157175
torch.Size([128, 4])
torch.Size([128])
Epoch 356 : 0.008236756548285484
torch.Size([128, 4])
torch.Size([128])
Epoch 356 : 0.005858319811522961
torch.Size([128, 4])
torch.Size([128])
Epoch 356 : 0.008539360016584396
torch.Size([27, 4])
torch.Size([27])
Epoch 356 : 0.002807829761877656
torch.Size([128, 4])
torch.Size([128])
Epoch 357 : 0.0046953777782619
torch.Size([128, 4])
torch.Size([128])
Epoch 357 : 0.00413306010887026

torch.Size([27, 4])
torch.Size([27])
Epoch 381 : 0.002309066243469715
torch.Size([128, 4])
torch.Size([128])
Epoch 382 : 0.0038835820741951466
torch.Size([128, 4])
torch.Size([128])
Epoch 382 : 0.003462644526734948
torch.Size([128, 4])
torch.Size([128])
Epoch 382 : 0.003387709381058812
torch.Size([128, 4])
torch.Size([128])
Epoch 382 : 0.004812545143067837
torch.Size([128, 4])
torch.Size([128])
Epoch 382 : 0.0068412297405302525
torch.Size([128, 4])
torch.Size([128])
Epoch 382 : 0.004831473343074322
torch.Size([128, 4])
torch.Size([128])
Epoch 382 : 0.007097106892615557
torch.Size([27, 4])
torch.Size([27])
Epoch 382 : 0.0023179948329925537
torch.Size([128, 4])
torch.Size([128])
Epoch 383 : 0.00384985888376832
torch.Size([128, 4])
torch.Size([128])
Epoch 383 : 0.0034205184783786535
torch.Size([128, 4])
torch.Size([128])
Epoch 383 : 0.0033693236764520407
torch.Size([128, 4])
torch.Size([128])
Epoch 383 : 0.004754809197038412
torch.Size([128, 4])
torch.Size([128])
Epoch 383 : 0.00689155980

torch.Size([128, 4])
torch.Size([128])
Epoch 408 : 0.004009442403912544
torch.Size([128, 4])
torch.Size([128])
Epoch 408 : 0.005704295821487904
torch.Size([128, 4])
torch.Size([128])
Epoch 408 : 0.0040125311352312565
torch.Size([128, 4])
torch.Size([128])
Epoch 408 : 0.005942832212895155
torch.Size([27, 4])
torch.Size([27])
Epoch 408 : 0.0018767458386719227
torch.Size([128, 4])
torch.Size([128])
Epoch 409 : 0.0031746807508170605
torch.Size([128, 4])
torch.Size([128])
Epoch 409 : 0.002873812336474657
torch.Size([128, 4])
torch.Size([128])
Epoch 409 : 0.002796979621052742
torch.Size([128, 4])
torch.Size([128])
Epoch 409 : 0.003982518799602985
torch.Size([128, 4])
torch.Size([128])
Epoch 409 : 0.0056937928311526775
torch.Size([128, 4])
torch.Size([128])
Epoch 409 : 0.003999597392976284
torch.Size([128, 4])
torch.Size([128])
Epoch 409 : 0.005865642335265875
torch.Size([27, 4])
torch.Size([27])
Epoch 409 : 0.0018803172279149294
torch.Size([128, 4])
torch.Size([128])
Epoch 410 : 0.0031582787

Epoch 434 : 0.0026651020161807537
torch.Size([128, 4])
torch.Size([128])
Epoch 434 : 0.002444923622533679
torch.Size([128, 4])
torch.Size([128])
Epoch 434 : 0.002362548839300871
torch.Size([128, 4])
torch.Size([128])
Epoch 434 : 0.0033483344595879316
torch.Size([128, 4])
torch.Size([128])
Epoch 434 : 0.004800016526132822
torch.Size([128, 4])
torch.Size([128])
Epoch 434 : 0.0033716424368321896
torch.Size([128, 4])
torch.Size([128])
Epoch 434 : 0.005008834879845381
torch.Size([27, 4])
torch.Size([27])
Epoch 434 : 0.0015385757433250546
torch.Size([128, 4])
torch.Size([128])
Epoch 435 : 0.0026284786872565746
torch.Size([128, 4])
torch.Size([128])
Epoch 435 : 0.0024237208999693394
torch.Size([128, 4])
torch.Size([128])
Epoch 435 : 0.0023551927879452705
torch.Size([128, 4])
torch.Size([128])
Epoch 435 : 0.00332235568203032
torch.Size([128, 4])
torch.Size([128])
Epoch 435 : 0.004725622478872538
torch.Size([128, 4])
torch.Size([128])
Epoch 435 : 0.0033925010357052088
torch.Size([128, 4])
torch

torch.Size([128, 4])
torch.Size([128])
Epoch 460 : 0.0019950163550674915
torch.Size([128, 4])
torch.Size([128])
Epoch 460 : 0.002796960761770606
torch.Size([128, 4])
torch.Size([128])
Epoch 460 : 0.004029013216495514
torch.Size([128, 4])
torch.Size([128])
Epoch 460 : 0.002865745685994625
torch.Size([128, 4])
torch.Size([128])
Epoch 460 : 0.004230877384543419
torch.Size([27, 4])
torch.Size([27])
Epoch 460 : 0.0012683294480666518
torch.Size([128, 4])
torch.Size([128])
Epoch 461 : 0.002211441518738866
torch.Size([128, 4])
torch.Size([128])
Epoch 461 : 0.0020676993299275637
torch.Size([128, 4])
torch.Size([128])
Epoch 461 : 0.0019986422266811132
torch.Size([128, 4])
torch.Size([128])
Epoch 461 : 0.002750921528786421
torch.Size([128, 4])
torch.Size([128])
Epoch 461 : 0.00404219888150692
torch.Size([128, 4])
torch.Size([128])
Epoch 461 : 0.0028825898189097643
torch.Size([128, 4])
torch.Size([128])
Epoch 461 : 0.0041698068380355835
torch.Size([27, 4])
torch.Size([27])
Epoch 461 : 0.0012627196

Epoch 486 : 0.0034445347264409065
torch.Size([128, 4])
torch.Size([128])
Epoch 486 : 0.0024697664193809032
torch.Size([128, 4])
torch.Size([128])
Epoch 486 : 0.0035702157765626907
torch.Size([27, 4])
torch.Size([27])
Epoch 486 : 0.0010569012956693769
torch.Size([128, 4])
torch.Size([128])
Epoch 487 : 0.0018702060915529728
torch.Size([128, 4])
torch.Size([128])
Epoch 487 : 0.001772098010405898
torch.Size([128, 4])
torch.Size([128])
Epoch 487 : 0.0017097970703616738
torch.Size([128, 4])
torch.Size([128])
Epoch 487 : 0.0023575425148010254
torch.Size([128, 4])
torch.Size([128])
Epoch 487 : 0.0034451051615178585
torch.Size([128, 4])
torch.Size([128])
Epoch 487 : 0.0024319752119481564
torch.Size([128, 4])
torch.Size([128])
Epoch 487 : 0.003572860499843955
torch.Size([27, 4])
torch.Size([27])
Epoch 487 : 0.0010342627065256238
torch.Size([128, 4])
torch.Size([128])
Epoch 488 : 0.001853502239100635
torch.Size([128, 4])
torch.Size([128])
Epoch 488 : 0.001762391533702612
torch.Size([128, 4])
torc

torch.Size([128, 4])
torch.Size([128])
Epoch 513 : 0.0015165851218625903
torch.Size([128, 4])
torch.Size([128])
Epoch 513 : 0.0014792685396969318
torch.Size([128, 4])
torch.Size([128])
Epoch 513 : 0.0019870165269821882
torch.Size([128, 4])
torch.Size([128])
Epoch 513 : 0.002991398563608527
torch.Size([128, 4])
torch.Size([128])
Epoch 513 : 0.002100193640217185
torch.Size([128, 4])
torch.Size([128])
Epoch 513 : 0.003051300533115864
torch.Size([27, 4])
torch.Size([27])
Epoch 513 : 0.0008743410580791533
torch.Size([128, 4])
torch.Size([128])
Epoch 514 : 0.0015701825032010674
torch.Size([128, 4])
torch.Size([128])
Epoch 514 : 0.0015140945324674249
torch.Size([128, 4])
torch.Size([128])
Epoch 514 : 0.0014586823526769876
torch.Size([128, 4])
torch.Size([128])
Epoch 514 : 0.001993841491639614
torch.Size([128, 4])
torch.Size([128])
Epoch 514 : 0.0029126820154488087
torch.Size([128, 4])
torch.Size([128])
Epoch 514 : 0.0020642150193452835
torch.Size([128, 4])
torch.Size([128])
Epoch 514 : 0.0030

torch.Size([27, 4])
torch.Size([27])
Epoch 539 : 0.000736223766580224
torch.Size([128, 4])
torch.Size([128])
Epoch 540 : 0.0013458789326250553
torch.Size([128, 4])
torch.Size([128])
Epoch 540 : 0.0013147388817742467
torch.Size([128, 4])
torch.Size([128])
Epoch 540 : 0.001259359996765852
torch.Size([128, 4])
torch.Size([128])
Epoch 540 : 0.0017059051897376776
torch.Size([128, 4])
torch.Size([128])
Epoch 540 : 0.002530014608055353
torch.Size([128, 4])
torch.Size([128])
Epoch 540 : 0.0017187884077429771
torch.Size([128, 4])
torch.Size([128])
Epoch 540 : 0.002597064943984151
torch.Size([27, 4])
torch.Size([27])
Epoch 540 : 0.0007293560192920268
torch.Size([128, 4])
torch.Size([128])
Epoch 541 : 0.0013352017849683762
torch.Size([128, 4])
torch.Size([128])
Epoch 541 : 0.001305832527577877
torch.Size([128, 4])
torch.Size([128])
Epoch 541 : 0.001262157573364675
torch.Size([128, 4])
torch.Size([128])
Epoch 541 : 0.0017021059757098556
torch.Size([128, 4])
torch.Size([128])
Epoch 541 : 0.00250815

Epoch 565 : 0.0006213102606125176
torch.Size([128, 4])
torch.Size([128])
Epoch 566 : 0.0011518142418935895
torch.Size([128, 4])
torch.Size([128])
Epoch 566 : 0.0011515733785927296
torch.Size([128, 4])
torch.Size([128])
Epoch 566 : 0.001076775835826993
torch.Size([128, 4])
torch.Size([128])
Epoch 566 : 0.0014790691202506423
torch.Size([128, 4])
torch.Size([128])
Epoch 566 : 0.002180651528760791
torch.Size([128, 4])
torch.Size([128])
Epoch 566 : 0.0014791232533752918
torch.Size([128, 4])
torch.Size([128])
Epoch 566 : 0.002242917660623789
torch.Size([27, 4])
torch.Size([27])
Epoch 566 : 0.0006158375763334334
torch.Size([128, 4])
torch.Size([128])
Epoch 567 : 0.001144189853221178
torch.Size([128, 4])
torch.Size([128])
Epoch 567 : 0.0011363792000338435
torch.Size([128, 4])
torch.Size([128])
Epoch 567 : 0.0010841527255252004
torch.Size([128, 4])
torch.Size([128])
Epoch 567 : 0.0014618764398619533
torch.Size([128, 4])
torch.Size([128])
Epoch 567 : 0.0021664099767804146
torch.Size([128, 4])
to

torch.Size([128, 4])
torch.Size([128])
Epoch 592 : 0.0009339911048300564
torch.Size([128, 4])
torch.Size([128])
Epoch 592 : 0.001288411091081798
torch.Size([128, 4])
torch.Size([128])
Epoch 592 : 0.0018899372080340981
torch.Size([128, 4])
torch.Size([128])
Epoch 592 : 0.001276288996450603
torch.Size([128, 4])
torch.Size([128])
Epoch 592 : 0.001935568987391889
torch.Size([27, 4])
torch.Size([27])
Epoch 592 : 0.0005298250471241772
torch.Size([128, 4])
torch.Size([128])
Epoch 593 : 0.000982463825494051
torch.Size([128, 4])
torch.Size([128])
Epoch 593 : 0.0010002064518630505
torch.Size([128, 4])
torch.Size([128])
Epoch 593 : 0.000943807594012469
torch.Size([128, 4])
torch.Size([128])
Epoch 593 : 0.0012576229637488723
torch.Size([128, 4])
torch.Size([128])
Epoch 593 : 0.0019102506339550018
torch.Size([128, 4])
torch.Size([128])
Epoch 593 : 0.0012803319841623306
torch.Size([128, 4])
torch.Size([128])
Epoch 593 : 0.0019137038616463542
torch.Size([27, 4])
torch.Size([27])
Epoch 593 : 0.0005246

Epoch 617 : 0.0004544180992525071
torch.Size([128, 4])
torch.Size([128])
Epoch 618 : 0.0008568065823055804
torch.Size([128, 4])
torch.Size([128])
Epoch 618 : 0.0008870541350916028
torch.Size([128, 4])
torch.Size([128])
Epoch 618 : 0.000812011887319386
torch.Size([128, 4])
torch.Size([128])
Epoch 618 : 0.0011171638034284115
torch.Size([128, 4])
torch.Size([128])
Epoch 618 : 0.0016436240402981639
torch.Size([128, 4])
torch.Size([128])
Epoch 618 : 0.0011029024608433247
torch.Size([128, 4])
torch.Size([128])
Epoch 618 : 0.00166081462521106
torch.Size([27, 4])
torch.Size([27])
Epoch 618 : 0.00045126205077394843
torch.Size([128, 4])
torch.Size([128])
Epoch 619 : 0.0008491865010000765
torch.Size([128, 4])
torch.Size([128])
Epoch 619 : 0.0008759570773690939
torch.Size([128, 4])
torch.Size([128])
Epoch 619 : 0.0008131748181767762
torch.Size([128, 4])
torch.Size([128])
Epoch 619 : 0.0010983776301145554
torch.Size([128, 4])
torch.Size([128])
Epoch 619 : 0.0016429921379312873
torch.Size([128, 4])


torch.Size([128, 4])
torch.Size([128])
Epoch 644 : 0.0009671240695752203
torch.Size([128, 4])
torch.Size([128])
Epoch 644 : 0.001436246675439179
torch.Size([27, 4])
torch.Size([27])
Epoch 644 : 0.0003819902194663882
torch.Size([128, 4])
torch.Size([128])
Epoch 645 : 0.0007386216311715543
torch.Size([128, 4])
torch.Size([128])
Epoch 645 : 0.000776932982262224
torch.Size([128, 4])
torch.Size([128])
Epoch 645 : 0.0007054626476019621
torch.Size([128, 4])
torch.Size([128])
Epoch 645 : 0.0009604039369150996
torch.Size([128, 4])
torch.Size([128])
Epoch 645 : 0.001427637878805399
torch.Size([128, 4])
torch.Size([128])
Epoch 645 : 0.0009613758302293718
torch.Size([128, 4])
torch.Size([128])
Epoch 645 : 0.0014213136164471507
torch.Size([27, 4])
torch.Size([27])
Epoch 645 : 0.0003810174821410328
torch.Size([128, 4])
torch.Size([128])
Epoch 646 : 0.0007341243326663971
torch.Size([128, 4])
torch.Size([128])
Epoch 646 : 0.0007717980188317597
torch.Size([128, 4])
torch.Size([128])
Epoch 646 : 0.00070

torch.Size([128, 4])
torch.Size([128])
Epoch 670 : 0.0008380005601793528
torch.Size([128, 4])
torch.Size([128])
Epoch 670 : 0.0012665499234572053
torch.Size([128, 4])
torch.Size([128])
Epoch 670 : 0.0008397568017244339
torch.Size([128, 4])
torch.Size([128])
Epoch 670 : 0.0012439407873898745
torch.Size([27, 4])
torch.Size([27])
Epoch 670 : 0.0003283279074821621
torch.Size([128, 4])
torch.Size([128])
Epoch 671 : 0.0006430171779356897
torch.Size([128, 4])
torch.Size([128])
Epoch 671 : 0.0006836111424490809
torch.Size([128, 4])
torch.Size([128])
Epoch 671 : 0.0006174220470711589
torch.Size([128, 4])
torch.Size([128])
Epoch 671 : 0.0008379171486012638
torch.Size([128, 4])
torch.Size([128])
Epoch 671 : 0.0012482001911848783
torch.Size([128, 4])
torch.Size([128])
Epoch 671 : 0.0008366277324967086
torch.Size([128, 4])
torch.Size([128])
Epoch 671 : 0.0012442374136298895
torch.Size([27, 4])
torch.Size([27])
Epoch 671 : 0.0003278430085629225
torch.Size([128, 4])
torch.Size([128])
Epoch 672 : 0.00

torch.Size([128, 4])
torch.Size([128])
Epoch 696 : 0.0006039994186721742
torch.Size([128, 4])
torch.Size([128])
Epoch 696 : 0.0005476378719322383
torch.Size([128, 4])
torch.Size([128])
Epoch 696 : 0.0007354446570388973
torch.Size([128, 4])
torch.Size([128])
Epoch 696 : 0.0011020044330507517
torch.Size([128, 4])
torch.Size([128])
Epoch 696 : 0.0007327980711124837
torch.Size([128, 4])
torch.Size([128])
Epoch 696 : 0.0010793578112497926
torch.Size([27, 4])
torch.Size([27])
Epoch 696 : 0.00028298739925958216
torch.Size([128, 4])
torch.Size([128])
Epoch 697 : 0.0005579678690992296
torch.Size([128, 4])
torch.Size([128])
Epoch 697 : 0.0006040083826519549
torch.Size([128, 4])
torch.Size([128])
Epoch 697 : 0.0005452851764857769
torch.Size([128, 4])
torch.Size([128])
Epoch 697 : 0.000724976125638932
torch.Size([128, 4])
torch.Size([128])
Epoch 697 : 0.0011130101047456264
torch.Size([128, 4])
torch.Size([128])
Epoch 697 : 0.0007282417500391603
torch.Size([128, 4])
torch.Size([128])
Epoch 697 : 0.

torch.Size([128, 4])
torch.Size([128])
Epoch 722 : 0.0005338290357030928
torch.Size([128, 4])
torch.Size([128])
Epoch 722 : 0.00048049038741737604
torch.Size([128, 4])
torch.Size([128])
Epoch 722 : 0.0006415495299734175
torch.Size([128, 4])
torch.Size([128])
Epoch 722 : 0.0009648756822571158
torch.Size([128, 4])
torch.Size([128])
Epoch 722 : 0.0006372607895173132
torch.Size([128, 4])
torch.Size([128])
Epoch 722 : 0.0009404463926330209
torch.Size([27, 4])
torch.Size([27])
Epoch 722 : 0.0002457688970025629
torch.Size([128, 4])
torch.Size([128])
Epoch 723 : 0.0004836150910705328
torch.Size([128, 4])
torch.Size([128])
Epoch 723 : 0.0005295822629705071
torch.Size([128, 4])
torch.Size([128])
Epoch 723 : 0.00048199211596511304
torch.Size([128, 4])
torch.Size([128])
Epoch 723 : 0.0006353732314892113
torch.Size([128, 4])
torch.Size([128])
Epoch 723 : 0.0009630775894038379
torch.Size([128, 4])
torch.Size([128])
Epoch 723 : 0.0006364801665768027
torch.Size([128, 4])
torch.Size([128])
Epoch 723 : 

torch.Size([128, 4])
torch.Size([128])
Epoch 748 : 0.0008518424001522362
torch.Size([128, 4])
torch.Size([128])
Epoch 748 : 0.0005602528108283877
torch.Size([128, 4])
torch.Size([128])
Epoch 748 : 0.0008224641787819564
torch.Size([27, 4])
torch.Size([27])
Epoch 748 : 0.0002144313621101901
torch.Size([128, 4])
torch.Size([128])
Epoch 749 : 0.00042248022509738803
torch.Size([128, 4])
torch.Size([128])
Epoch 749 : 0.00047214870573952794
torch.Size([128, 4])
torch.Size([128])
Epoch 749 : 0.0004240695561747998
torch.Size([128, 4])
torch.Size([128])
Epoch 749 : 0.0005614091060124338
torch.Size([128, 4])
torch.Size([128])
Epoch 749 : 0.0008616759441792965
torch.Size([128, 4])
torch.Size([128])
Epoch 749 : 0.0005538551486097276
torch.Size([128, 4])
torch.Size([128])
Epoch 749 : 0.0008204224868677557
torch.Size([27, 4])
torch.Size([27])
Epoch 749 : 0.00021366917644627392
torch.Size([128, 4])
torch.Size([128])
Epoch 750 : 0.0004190137260593474
torch.Size([128, 4])
torch.Size([128])
Epoch 750 : 0

torch.Size([128, 4])
torch.Size([128])
Epoch 775 : 0.00036807660944759846
torch.Size([128, 4])
torch.Size([128])
Epoch 775 : 0.0004103494866285473
torch.Size([128, 4])
torch.Size([128])
Epoch 775 : 0.0003737105580512434
torch.Size([128, 4])
torch.Size([128])
Epoch 775 : 0.0004921344225294888
torch.Size([128, 4])
torch.Size([128])
Epoch 775 : 0.0007502353400923312
torch.Size([128, 4])
torch.Size([128])
Epoch 775 : 0.0004894481971859932
torch.Size([128, 4])
torch.Size([128])
Epoch 775 : 0.0007172704790718853
torch.Size([27, 4])
torch.Size([27])
Epoch 775 : 0.00018423364963382483
torch.Size([128, 4])
torch.Size([128])
Epoch 776 : 0.0003671057347673923
torch.Size([128, 4])
torch.Size([128])
Epoch 776 : 0.00041096878703683615
torch.Size([128, 4])
torch.Size([128])
Epoch 776 : 0.0003713242767844349
torch.Size([128, 4])
torch.Size([128])
Epoch 776 : 0.00048684634384699166
torch.Size([128, 4])
torch.Size([128])
Epoch 776 : 0.0007494613528251648
torch.Size([128, 4])
torch.Size([128])
Epoch 776 

torch.Size([128, 4])
torch.Size([128])
Epoch 801 : 0.0003289424057584256
torch.Size([128, 4])
torch.Size([128])
Epoch 801 : 0.00042891554767265916
torch.Size([128, 4])
torch.Size([128])
Epoch 801 : 0.0006720182718709111
torch.Size([128, 4])
torch.Size([128])
Epoch 801 : 0.00043309529428370297
torch.Size([128, 4])
torch.Size([128])
Epoch 801 : 0.0006304983398877084
torch.Size([27, 4])
torch.Size([27])
Epoch 801 : 0.00016138523642439395
torch.Size([128, 4])
torch.Size([128])
Epoch 802 : 0.000321669940603897
torch.Size([128, 4])
torch.Size([128])
Epoch 802 : 0.0003586288949009031
torch.Size([128, 4])
torch.Size([128])
Epoch 802 : 0.0003261610690969974
torch.Size([128, 4])
torch.Size([128])
Epoch 802 : 0.0004319229337852448
torch.Size([128, 4])
torch.Size([128])
Epoch 802 : 0.0006581469788216054
torch.Size([128, 4])
torch.Size([128])
Epoch 802 : 0.00042935245437547565
torch.Size([128, 4])
torch.Size([128])
Epoch 802 : 0.0006353424396365881
torch.Size([27, 4])
torch.Size([27])
Epoch 802 : 0

torch.Size([128, 4])
torch.Size([128])
Epoch 827 : 0.0005544381565414369
torch.Size([27, 4])
torch.Size([27])
Epoch 827 : 0.00014047138392925262
torch.Size([128, 4])
torch.Size([128])
Epoch 828 : 0.00028196710627526045
torch.Size([128, 4])
torch.Size([128])
Epoch 828 : 0.00031623183167539537
torch.Size([128, 4])
torch.Size([128])
Epoch 828 : 0.000288106850348413
torch.Size([128, 4])
torch.Size([128])
Epoch 828 : 0.0003787284076679498
torch.Size([128, 4])
torch.Size([128])
Epoch 828 : 0.0005847903667017817
torch.Size([128, 4])
torch.Size([128])
Epoch 828 : 0.0003794281801674515
torch.Size([128, 4])
torch.Size([128])
Epoch 828 : 0.0005557829281315207
torch.Size([27, 4])
torch.Size([27])
Epoch 828 : 0.00013938185293227434
torch.Size([128, 4])
torch.Size([128])
Epoch 829 : 0.00028042125632055104
torch.Size([128, 4])
torch.Size([128])
Epoch 829 : 0.00031557222246192396
torch.Size([128, 4])
torch.Size([128])
Epoch 829 : 0.00028657991788350046
torch.Size([128, 4])
torch.Size([128])
Epoch 829 

torch.Size([128, 4])
torch.Size([128])
Epoch 854 : 0.0002798674686346203
torch.Size([128, 4])
torch.Size([128])
Epoch 854 : 0.00025561300572007895
torch.Size([128, 4])
torch.Size([128])
Epoch 854 : 0.0003325811994727701
torch.Size([128, 4])
torch.Size([128])
Epoch 854 : 0.0005289223045110703
torch.Size([128, 4])
torch.Size([128])
Epoch 854 : 0.0003359847760293633
torch.Size([128, 4])
torch.Size([128])
Epoch 854 : 0.0004900358035229146
torch.Size([27, 4])
torch.Size([27])
Epoch 854 : 0.00012277044879738241
torch.Size([128, 4])
torch.Size([128])
Epoch 855 : 0.0002485253498889506
torch.Size([128, 4])
torch.Size([128])
Epoch 855 : 0.0002781289513222873
torch.Size([128, 4])
torch.Size([128])
Epoch 855 : 0.0002531821664888412
torch.Size([128, 4])
torch.Size([128])
Epoch 855 : 0.0003369449113961309
torch.Size([128, 4])
torch.Size([128])
Epoch 855 : 0.0005158047424629331
torch.Size([128, 4])
torch.Size([128])
Epoch 855 : 0.0003318731614854187
torch.Size([128, 4])
torch.Size([128])
Epoch 855 : 

torch.Size([128, 4])
torch.Size([128])
Epoch 880 : 0.00043539932812564075
torch.Size([27, 4])
torch.Size([27])
Epoch 880 : 0.00010661577107384801
torch.Size([128, 4])
torch.Size([128])
Epoch 881 : 0.00021713480236940086
torch.Size([128, 4])
torch.Size([128])
Epoch 881 : 0.0002445575955789536
torch.Size([128, 4])
torch.Size([128])
Epoch 881 : 0.00022334013192448765
torch.Size([128, 4])
torch.Size([128])
Epoch 881 : 0.00029348841053433716
torch.Size([128, 4])
torch.Size([128])
Epoch 881 : 0.00045570102520287037
torch.Size([128, 4])
torch.Size([128])
Epoch 881 : 0.00029106211150065064
torch.Size([128, 4])
torch.Size([128])
Epoch 881 : 0.00042810203740373254
torch.Size([27, 4])
torch.Size([27])
Epoch 881 : 0.00010583880066405982
torch.Size([128, 4])
torch.Size([128])
Epoch 882 : 0.00021693328744731843
torch.Size([128, 4])
torch.Size([128])
Epoch 882 : 0.00024200919142458588
torch.Size([128, 4])
torch.Size([128])
Epoch 882 : 0.0002225840580649674
torch.Size([128, 4])
torch.Size([128])
Epoch

torch.Size([128, 4])
torch.Size([128])
Epoch 907 : 0.0002141021832358092
torch.Size([128, 4])
torch.Size([128])
Epoch 907 : 0.00019835842249449342
torch.Size([128, 4])
torch.Size([128])
Epoch 907 : 0.00025845092022791505
torch.Size([128, 4])
torch.Size([128])
Epoch 907 : 0.0004034504818264395
torch.Size([128, 4])
torch.Size([128])
Epoch 907 : 0.0002575954422354698
torch.Size([128, 4])
torch.Size([128])
Epoch 907 : 0.0003793641517404467
torch.Size([27, 4])
torch.Size([27])
Epoch 907 : 9.286760905524716e-05
torch.Size([128, 4])
torch.Size([128])
Epoch 908 : 0.00019106837862636894
torch.Size([128, 4])
torch.Size([128])
Epoch 908 : 0.00021390008623711765
torch.Size([128, 4])
torch.Size([128])
Epoch 908 : 0.00019669887842610478
torch.Size([128, 4])
torch.Size([128])
Epoch 908 : 0.00025944196386262774
torch.Size([128, 4])
torch.Size([128])
Epoch 908 : 0.00040111347334459424
torch.Size([128, 4])
torch.Size([128])
Epoch 908 : 0.00025641321553848684
torch.Size([128, 4])
torch.Size([128])
Epoch 

torch.Size([128, 4])
torch.Size([128])
Epoch 933 : 0.0002283749490743503
torch.Size([128, 4])
torch.Size([128])
Epoch 933 : 0.0003580485063139349
torch.Size([128, 4])
torch.Size([128])
Epoch 933 : 0.0002266226219944656
torch.Size([128, 4])
torch.Size([128])
Epoch 933 : 0.00033567420905455947
torch.Size([27, 4])
torch.Size([27])
Epoch 933 : 8.129981142701581e-05
torch.Size([128, 4])
torch.Size([128])
Epoch 934 : 0.000168421640410088
torch.Size([128, 4])
torch.Size([128])
Epoch 934 : 0.0001897000620374456
torch.Size([128, 4])
torch.Size([128])
Epoch 934 : 0.00017430313164368272
torch.Size([128, 4])
torch.Size([128])
Epoch 934 : 0.00022792200616095215
torch.Size([128, 4])
torch.Size([128])
Epoch 934 : 0.0003588536928873509
torch.Size([128, 4])
torch.Size([128])
Epoch 934 : 0.00022493579308502376
torch.Size([128, 4])
torch.Size([128])
Epoch 934 : 0.0003353562206029892
torch.Size([27, 4])
torch.Size([27])
Epoch 934 : 8.055827493080869e-05
torch.Size([128, 4])
torch.Size([128])
Epoch 935 : 0

torch.Size([128, 4])
torch.Size([128])
Epoch 959 : 0.0002023371635004878
torch.Size([128, 4])
torch.Size([128])
Epoch 959 : 0.00031833723187446594
torch.Size([128, 4])
torch.Size([128])
Epoch 959 : 0.00020044227130711079
torch.Size([128, 4])
torch.Size([128])
Epoch 959 : 0.0002982816658914089
torch.Size([27, 4])
torch.Size([27])
Epoch 959 : 7.109527359716594e-05
torch.Size([128, 4])
torch.Size([128])
Epoch 960 : 0.00014908892626408488
torch.Size([128, 4])
torch.Size([128])
Epoch 960 : 0.00016892218263819814
torch.Size([128, 4])
torch.Size([128])
Epoch 960 : 0.00015415535017382354
torch.Size([128, 4])
torch.Size([128])
Epoch 960 : 0.0002033899654634297
torch.Size([128, 4])
torch.Size([128])
Epoch 960 : 0.00031569035490974784
torch.Size([128, 4])
torch.Size([128])
Epoch 960 : 0.00019846783834509552
torch.Size([128, 4])
torch.Size([128])
Epoch 960 : 0.00029797962633892894
torch.Size([27, 4])
torch.Size([27])
Epoch 960 : 7.080846262397245e-05
torch.Size([128, 4])
torch.Size([128])
Epoch 96

torch.Size([128, 4])
torch.Size([128])
Epoch 984 : 0.00013290089555084705
torch.Size([128, 4])
torch.Size([128])
Epoch 984 : 0.0001491714210715145
torch.Size([128, 4])
torch.Size([128])
Epoch 984 : 0.00013774560648016632
torch.Size([128, 4])
torch.Size([128])
Epoch 984 : 0.0001817760057747364
torch.Size([128, 4])
torch.Size([128])
Epoch 984 : 0.00028413531254045665
torch.Size([128, 4])
torch.Size([128])
Epoch 984 : 0.00017779454356059432
torch.Size([128, 4])
torch.Size([128])
Epoch 984 : 0.0002680054458323866
torch.Size([27, 4])
torch.Size([27])
Epoch 984 : 6.249255238799378e-05
torch.Size([128, 4])
torch.Size([128])
Epoch 985 : 0.0001318814029218629
torch.Size([128, 4])
torch.Size([128])
Epoch 985 : 0.00014845834812149405
torch.Size([128, 4])
torch.Size([128])
Epoch 985 : 0.00013858877355232835
torch.Size([128, 4])
torch.Size([128])
Epoch 985 : 0.00017892583855427802
torch.Size([128, 4])
torch.Size([128])
Epoch 985 : 0.00028550621937029064
torch.Size([128, 4])
torch.Size([128])
Epoch 

In [33]:
# 
testX = torch.Tensor([binary_encode(i,NUM_DIGITS) for i in range(1,101)])
# testY = torch.Tensor([])
with torch.no_grad(): # 在进行预测的时候不需要进行梯度计算，所以要加上这个
    testY = model(testX)

In [44]:
# testY.max(1)# 会返回两组数，第一组是4个数中最大的那个数的值是多少，第二组数值这个最大值对应的4个位置的值索引
# testY.max(1)[1] # 我们把位置索引取出来就行
testY.max(1)[1].cpu().data.numpy() # 转成cpu的numpy
# testY.max(1)[1].cpu().data.tolist() # 转成1维的List

array([0, 0, 1, 0, 2, 1, 0, 0, 1, 2, 0, 1, 0, 0, 3, 0, 2, 1, 2, 2, 1, 0,
       1, 1, 2, 0, 1, 0, 0, 3, 0, 0, 1, 0, 2, 1, 0, 1, 1, 2, 0, 1, 0, 0,
       3, 0, 0, 1, 0, 2, 1, 0, 0, 1, 2, 0, 1, 0, 0, 3, 0, 0, 1, 0, 0, 1,
       0, 0, 1, 2, 0, 1, 0, 0, 3, 2, 0, 1, 0, 2, 1, 0, 2, 1, 2, 0, 1, 0,
       0, 3, 0, 0, 1, 0, 2, 1, 0, 1, 1, 2])

In [50]:
predictions = zip(range(1,101),testY.max(1)[1].cpu().data.tolist())
# predictions
print([fizz_buzz_decoder(i,x) for i,x in predictions])

['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', 'buzz', 'fizz', 'buzz', 'buzz', 'fizz', '22', 'fizz', 'fizz', 'buzz', '26', 'fizz', '28', '29', 'fizzbuzz', '31', '32', 'fizz', '34', 'buzz', 'fizz', '37', 'fizz', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', '65', 'fizz', '67', '68', 'fizz', 'buzz', '71', 'fizz', '73', '74', 'fizzbuzz', 'buzz', '77', 'fizz', '79', 'buzz', 'fizz', '82', 'buzz', 'fizz', 'buzz', '86', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', 'fizz', 'fizz', 'buzz']


In [60]:
realTestY =[]
for i in range(1,101):
    realTestY.append([fizz_buzz_decoder(i,fizz_buzz_encoder(i))])
realTestY = np.array(realTestY).reshape((100,))
realTestY.shape
realTestY

array(['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz',
       '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', '19',
       'buzz', 'fizz', '22', '23', 'fizz', 'buzz', '26', 'fizz', '28',
       '29', 'fizzbuzz', '31', '32', 'fizz', '34', 'buzz', 'fizz', '37',
       '38', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46',
       '47', 'fizz', '49', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz',
       '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64',
       'buzz', 'fizz', '67', '68', 'fizz', 'buzz', '71', 'fizz', '73',
       '74', 'fizzbuzz', '76', '77', 'fizz', '79', 'buzz', 'fizz', '82',
       '83', 'fizz', 'buzz', '86', 'fizz', '88', '89', 'fizzbuzz', '91',
       '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', 'buzz'],
      dtype='<U8')