## Applying classification by neural network on iris data set

In [33]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import datasets
import numpy as np

In [34]:
# import some data to play with
iris = datasets.load_iris()

In [35]:
predictor = iris.data
target = iris.target

In [36]:
predictor

array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4],
       [5.2, 4.1, 1.5, 0.1],
       [5.5, 4.2, 1.4, 0.2],
       [4.9, 3

In [37]:
predictor1=predictor.reshape(predictor.shape[0]*predictor.shape[1])
target

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [38]:
data_set=np.column_stack((predictor,target))

In [39]:
np.random.shuffle(data_set)
training, test = data_set[:120,:], data_set[120:,:]

In [40]:
train_x = training[:,:4]
train_y = training[:,4]
test_x = test[:,:4]
test_y = test[:,4]

creating a neural network for classification

In [41]:
import torch.nn as nn
import torch.nn.functional as F
#our class must extend nn.Module
class MyClassifier(nn.Module):
    def __init__(self):
        super(MyClassifier,self).__init__()
        #Our network consists of 3 layers. 1 input, 1 hidden and 1 output layer
        #This applies Linear transformation to input data. 
        self.layer1 = nn.Linear(4,3)
        self.layer2 = nn.Linear(3,3)
    
    
    #This must be implemented
    def forward(self,x):
        #of the first layer
        x = self.layer1(x)
        #Activation function is Relu. Feel free to experiment with this
        x = F.tanh(x)
        #This produces output
        x = self.layer2(x)    
        return x
        
    #This function takes an input and predicts the class, (0 or 1)        
    def predict(self,x):
        #Apply softmax to output. 
        prediction = F.softmax(self.forward(x))
        ans = []
        #Pick the class with maximum weight
        for t in prediction:
            #print("\nt0 :- ",t[0])
            #print("\nt1 :- ",t[1])
            #print("\nt2 :-",t[2])
            if t[0]>t[1]:
                if(t[0]>t[2]):
                    ans.append(0)
                else:
                    ans.append(2)
            elif(t[1]>t[2]):
                ans.append(1)
            else:
                ans.append(2)
        return torch.tensor(ans)

In [69]:
model.layer1.weight.data

tensor([[ 0.5160,  0.9782, -1.4440, -1.8860],
        [-0.1577, -1.0606,  1.2048,  1.5239],
        [ 3.2123,  3.1528, -6.4492, -5.2640]])

In [70]:
model.layer2.weight.data

tensor([[  3.7085,  -3.1482,   3.4774],
        [ -4.8522,   4.8621,  16.0103],
        [ -1.9160,   1.4510, -14.7721]])

In [42]:
import torch
#Initialize the model        
model = MyClassifier()
#Define loss criterion
criterion = nn.CrossEntropyLoss()
#Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [43]:
#convert the array into pytorch object
train_x = torch.from_numpy(train_x).type(torch.FloatTensor)
train_y = torch.from_numpy(train_y).type(torch.LongTensor)
test_x = torch.from_numpy(test_x).type(torch.FloatTensor)
test_y = torch.from_numpy(test_y).type(torch.LongTensor)

In [44]:
train_x.requires_grad = True
test_x.requires_grad = True

In [71]:
train_x

tensor([[6.4000, 2.7000, 5.3000, 1.9000],
        [5.8000, 2.7000, 4.1000, 1.0000],
        [5.4000, 3.9000, 1.3000, 0.4000],
        [6.7000, 3.0000, 5.2000, 2.3000],
        [6.4000, 2.8000, 5.6000, 2.2000],
        [6.8000, 3.0000, 5.5000, 2.1000],
        [4.4000, 2.9000, 1.4000, 0.2000],
        [4.7000, 3.2000, 1.3000, 0.2000],
        [6.7000, 3.1000, 4.7000, 1.5000],
        [6.7000, 3.3000, 5.7000, 2.5000],
        [5.7000, 4.4000, 1.5000, 0.4000],
        [6.3000, 2.7000, 4.9000, 1.8000],
        [6.9000, 3.1000, 4.9000, 1.5000],
        [7.2000, 3.0000, 5.8000, 1.6000],
        [6.7000, 3.3000, 5.7000, 2.1000],
        [5.1000, 3.5000, 1.4000, 0.2000],
        [5.0000, 3.3000, 1.4000, 0.2000],
        [5.4000, 3.0000, 4.5000, 1.5000],
        [4.6000, 3.4000, 1.4000, 0.3000],
        [5.0000, 3.5000, 1.6000, 0.6000],
        [5.0000, 2.3000, 3.3000, 1.0000],
        [7.6000, 3.0000, 6.6000, 2.1000],
        [6.0000, 2.2000, 5.0000, 1.5000],
        [6.4000, 3.1000, 5.5000, 1

In [45]:
#Number of epochs
epochs = 10000
#List to store losses
losses = []
for i in range(epochs):
    #Precit the output for Given input
    y_pred = model.forward(train_x)
    #print(y_pred)
    #print(train_y)
    #Clear the previous gradients
    optimizer.zero_grad()
    #Compute Cross entropy loss
    loss = criterion(y_pred,train_y)
    #Add loss to the list
    losses.append(loss.item())
    #Compute gradients
    loss.backward()
    #Adjust weights
    optimizer.step()
    print('epoch {}, loss {} '.format(i, loss.item()))

epoch 0, loss 1.1723262071609497 
epoch 1, loss 1.153887391090393 
epoch 2, loss 1.1381486654281616 
epoch 3, loss 1.124848484992981 
epoch 4, loss 1.1137205362319946 
epoch 5, loss 1.1044681072235107 
epoch 6, loss 1.0967869758605957 
epoch 7, loss 1.0903762578964233 
epoch 8, loss 1.0849391222000122 
epoch 9, loss 1.0801780223846436 
epoch 10, loss 1.0757888555526733 
epoch 11, loss 1.071455478668213 
epoch 12, loss 1.0668590068817139 
epoch 13, loss 1.0616931915283203 
epoch 14, loss 1.0556968450546265 
epoch 15, loss 1.0486921072006226 
epoch 16, loss 1.040618658065796 
epoch 17, loss 1.031565546989441 
epoch 18, loss 1.0217591524124146 
epoch 19, loss 1.0114400386810303 
epoch 20, loss 1.000741720199585 
epoch 21, loss 0.9898277521133423 
epoch 22, loss 0.9789243936538696 
epoch 23, loss 0.9682205319404602 
epoch 24, loss 0.9577823877334595 
epoch 25, loss 0.9475077390670776 
epoch 26, loss 0.9371415972709656 
epoch 27, loss 0.9263619780540466 
epoch 28, loss 0.9149017333984375 
e

epoch 283, loss 0.06438995897769928 
epoch 284, loss 0.06413354724645615 
epoch 285, loss 0.06387953460216522 
epoch 286, loss 0.06362791359424591 
epoch 287, loss 0.06337862461805344 
epoch 288, loss 0.06313169002532959 
epoch 289, loss 0.06288697570562363 
epoch 290, loss 0.06264455616474152 
epoch 291, loss 0.0624043233692646 
epoch 292, loss 0.06216631084680557 
epoch 293, loss 0.061930425465106964 
epoch 294, loss 0.06169668957591057 
epoch 295, loss 0.06146503612399101 
epoch 296, loss 0.06123543903231621 
epoch 297, loss 0.061007898300886154 
epoch 298, loss 0.06078238785266876 
epoch 299, loss 0.06055881828069687 
epoch 300, loss 0.060337286442518234 
epoch 301, loss 0.06011759489774704 
epoch 302, loss 0.059899892657995224 
epoch 303, loss 0.059684015810489655 
epoch 304, loss 0.05947002023458481 
epoch 305, loss 0.05925782769918442 
epoch 306, loss 0.059047505259513855 
epoch 307, loss 0.05883891507983208 
epoch 308, loss 0.05863213539123535 
epoch 309, loss 0.058427069336175

epoch 627, loss 0.030407164245843887 
epoch 628, loss 0.030369002372026443 
epoch 629, loss 0.030330996960401535 
epoch 630, loss 0.030293118208646774 
epoch 631, loss 0.030255382880568504 
epoch 632, loss 0.030217763036489487 
epoch 633, loss 0.030180301517248154 
epoch 634, loss 0.030142920091748238 
epoch 635, loss 0.030105730518698692 
epoch 636, loss 0.030068648979067802 
epoch 637, loss 0.030031686648726463 
epoch 638, loss 0.029994826763868332 
epoch 639, loss 0.029958128929138184 
epoch 640, loss 0.029921555891633034 
epoch 641, loss 0.02988511137664318 
epoch 642, loss 0.029848789796233177 
epoch 643, loss 0.029812566936016083 
epoch 644, loss 0.029776491224765778 
epoch 645, loss 0.029740547761321068 
epoch 646, loss 0.02970471791923046 
epoch 647, loss 0.029669005423784256 
epoch 648, loss 0.029633423313498497 
epoch 649, loss 0.029597951099276543 
epoch 650, loss 0.029562612995505333 
epoch 651, loss 0.029527364298701286 
epoch 652, loss 0.02949225716292858 
epoch 653, loss

epoch 979, loss 0.02199924737215042 
epoch 980, loss 0.021983863785862923 
epoch 981, loss 0.021968552842736244 
epoch 982, loss 0.02195325680077076 
epoch 983, loss 0.021937953308224678 
epoch 984, loss 0.021922709420323372 
epoch 985, loss 0.021907508373260498 
epoch 986, loss 0.021892298012971878 
epoch 987, loss 0.021877119317650795 
epoch 988, loss 0.021862009540200233 
epoch 989, loss 0.021846869960427284 
epoch 990, loss 0.021831799298524857 
epoch 991, loss 0.021816730499267578 
epoch 992, loss 0.021801704540848732 
epoch 993, loss 0.02178669348359108 
epoch 994, loss 0.021771728992462158 
epoch 995, loss 0.02175678126513958 
epoch 996, loss 0.02174186147749424 
epoch 997, loss 0.021726958453655243 
epoch 998, loss 0.021712077781558037 
epoch 999, loss 0.02169724553823471 
epoch 1000, loss 0.021682431921362877 
epoch 1001, loss 0.021667610853910446 
epoch 1002, loss 0.02165289595723152 
epoch 1003, loss 0.02163812704384327 
epoch 1004, loss 0.02162342704832554 
epoch 1005, loss

epoch 1411, loss 0.017077576369047165 
epoch 1412, loss 0.017068779096007347 
epoch 1413, loss 0.017059974372386932 
epoch 1414, loss 0.017051197588443756 
epoch 1415, loss 0.017042428255081177 
epoch 1416, loss 0.017033647745847702 
epoch 1417, loss 0.017024924978613853 
epoch 1418, loss 0.017016150057315826 
epoch 1419, loss 0.017007432878017426 
epoch 1420, loss 0.016998667269945145 
epoch 1421, loss 0.01698993891477585 
epoch 1422, loss 0.01698124408721924 
epoch 1423, loss 0.01697254367172718 
epoch 1424, loss 0.016963820904493332 
epoch 1425, loss 0.01695512980222702 
epoch 1426, loss 0.016946446150541306 
epoch 1427, loss 0.01693776622414589 
epoch 1428, loss 0.016929076984524727 
epoch 1429, loss 0.0169204194098711 
epoch 1430, loss 0.01691177673637867 
epoch 1431, loss 0.016903134062886238 
epoch 1432, loss 0.01689445786178112 
epoch 1433, loss 0.016885830089449883 
epoch 1434, loss 0.016877206042408943 
epoch 1435, loss 0.016868580132722855 
epoch 1436, loss 0.016859957948327

epoch 1659, loss 0.015145214274525642 
epoch 1660, loss 0.015093998052179813 
epoch 1661, loss 0.015094565227627754 
epoch 1662, loss 0.015125470235943794 
epoch 1663, loss 0.015143000520765781 
epoch 1664, loss 0.01511699240654707 
epoch 1665, loss 0.01507297158241272 
epoch 1666, loss 0.0150490403175354 
epoch 1667, loss 0.015057509765028954 
epoch 1668, loss 0.015072527341544628 
epoch 1669, loss 0.015064025297760963 
epoch 1670, loss 0.015036596916615963 
epoch 1671, loss 0.015014903619885445 
epoch 1672, loss 0.015013904310762882 
epoch 1673, loss 0.015020953491330147 
epoch 1674, loss 0.015015692450106144 
epoch 1675, loss 0.014997403137385845 
epoch 1676, loss 0.014980697073042393 
epoch 1677, loss 0.01497633382678032 
epoch 1678, loss 0.014977877959609032 
epoch 1679, loss 0.014972826465964317 
epoch 1680, loss 0.014959516935050488 
epoch 1681, loss 0.014946486800909042 
epoch 1682, loss 0.014940603636205196 
epoch 1683, loss 0.014938818290829659 
epoch 1684, loss 0.01493347622

epoch 2028, loss 0.012842564843595028 
epoch 2029, loss 0.012837124988436699 
epoch 2030, loss 0.012831714935600758 
epoch 2031, loss 0.012826263904571533 
epoch 2032, loss 0.012820884585380554 
epoch 2033, loss 0.012815441004931927 
epoch 2034, loss 0.01281003002077341 
epoch 2035, loss 0.012804622761905193 
epoch 2036, loss 0.012799199670553207 
epoch 2037, loss 0.012793789617717266 
epoch 2038, loss 0.012788394466042519 
epoch 2039, loss 0.012782976031303406 
epoch 2040, loss 0.012777570635080338 
epoch 2041, loss 0.012772179208695889 
epoch 2042, loss 0.01276677381247282 
epoch 2043, loss 0.012761387974023819 
epoch 2044, loss 0.012755995616316795 
epoch 2045, loss 0.012750571593642235 
epoch 2046, loss 0.012745204381644726 
epoch 2047, loss 0.012739831581711769 
epoch 2048, loss 0.01273442804813385 
epoch 2049, loss 0.012729013338685036 
epoch 2050, loss 0.012723693624138832 
epoch 2051, loss 0.01271828729659319 
epoch 2052, loss 0.012712887488305569 
epoch 2053, loss 0.0127075016

epoch 2263, loss 0.011678212322294712 
epoch 2264, loss 0.011673521250486374 
epoch 2265, loss 0.011668801307678223 
epoch 2266, loss 0.011664093472063541 
epoch 2267, loss 0.011659392155706882 
epoch 2268, loss 0.01165466383099556 
epoch 2269, loss 0.011649983935058117 
epoch 2270, loss 0.011645258404314518 
epoch 2271, loss 0.011640523560345173 
epoch 2272, loss 0.011635838076472282 
epoch 2273, loss 0.01163115631788969 
epoch 2274, loss 0.011626468040049076 
epoch 2275, loss 0.011621776036918163 
epoch 2276, loss 0.011617062613368034 
epoch 2277, loss 0.011612395755946636 
epoch 2278, loss 0.01160768885165453 
epoch 2279, loss 0.011603000573813915 
epoch 2280, loss 0.011598372831940651 
epoch 2281, loss 0.011593726463615894 
epoch 2282, loss 0.011589116416871548 
epoch 2283, loss 0.011584523133933544 
epoch 2284, loss 0.011579995974898338 
epoch 2285, loss 0.011575578711926937 
epoch 2286, loss 0.011571329087018967 
epoch 2287, loss 0.01156734861433506 
epoch 2288, loss 0.0115638347

epoch 2502, loss 0.01064330618828535 
epoch 2503, loss 0.010644283145666122 
epoch 2504, loss 0.010635343380272388 
epoch 2505, loss 0.010623099282383919 
epoch 2506, loss 0.010616646148264408 
epoch 2507, loss 0.010616925545036793 
epoch 2508, loss 0.01061722356826067 
epoch 2509, loss 0.010612105019390583 
epoch 2510, loss 0.010603328235447407 
epoch 2511, loss 0.010596397332847118 
epoch 2512, loss 0.01059379056096077 
epoch 2513, loss 0.010592790320515633 
epoch 2514, loss 0.010589411482214928 
epoch 2515, loss 0.010583131574094296 
epoch 2516, loss 0.010576712898910046 
epoch 2517, loss 0.010572511702775955 
epoch 2518, loss 0.010570118203759193 
epoch 2519, loss 0.010567246936261654 
epoch 2520, loss 0.010562625713646412 
epoch 2521, loss 0.010557081550359726 
epoch 2522, loss 0.01055220328271389 
epoch 2523, loss 0.010548721067607403 
epoch 2524, loss 0.010545636527240276 
epoch 2525, loss 0.010541853494942188 
epoch 2526, loss 0.01053721085190773 
epoch 2527, loss 0.01053237356

epoch 2747, loss 0.009677578695118427 
epoch 2748, loss 0.00967391487210989 
epoch 2749, loss 0.009670166298747063 
epoch 2750, loss 0.009666447527706623 
epoch 2751, loss 0.009662702679634094 
epoch 2752, loss 0.009658986702561378 
epoch 2753, loss 0.009655223228037357 
epoch 2754, loss 0.00965154729783535 
epoch 2755, loss 0.009647882543504238 
epoch 2756, loss 0.009644222445786 
epoch 2757, loss 0.009640475735068321 
epoch 2758, loss 0.009636769071221352 
epoch 2759, loss 0.009633001871407032 
epoch 2760, loss 0.009629334323108196 
epoch 2761, loss 0.009625640697777271 
epoch 2762, loss 0.009621924720704556 
epoch 2763, loss 0.00961822085082531 
epoch 2764, loss 0.009614560753107071 
epoch 2765, loss 0.009610859677195549 
epoch 2766, loss 0.009607161395251751 
epoch 2767, loss 0.009603452868759632 
epoch 2768, loss 0.009599718265235424 
epoch 2769, loss 0.009596000425517559 
epoch 2770, loss 0.009592309594154358 
epoch 2771, loss 0.009588642977178097 
epoch 2772, loss 0.009584922343

epoch 3008, loss 0.008747763000428677 
epoch 3009, loss 0.00874435342848301 
epoch 3010, loss 0.00874088890850544 
epoch 3011, loss 0.008737468160688877 
epoch 3012, loss 0.008733974769711494 
epoch 3013, loss 0.008730503730475903 
epoch 3014, loss 0.008727101609110832 
epoch 3015, loss 0.008723604492843151 
epoch 3016, loss 0.008720184676349163 
epoch 3017, loss 0.008716759271919727 
epoch 3018, loss 0.008713338524103165 
epoch 3019, loss 0.008709931746125221 
epoch 3020, loss 0.008706608787178993 
epoch 3021, loss 0.008703256025910378 
epoch 3022, loss 0.008700037375092506 
epoch 3023, loss 0.008697008714079857 
epoch 3024, loss 0.008694254793226719 
epoch 3025, loss 0.00869197491556406 
epoch 3026, loss 0.008690622635185719 
epoch 3027, loss 0.008690728805959225 
epoch 3028, loss 0.008693461306393147 
epoch 3029, loss 0.008700575679540634 
epoch 3030, loss 0.008714825846254826 
epoch 3031, loss 0.008740361779928207 
epoch 3032, loss 0.008780336938798428 
epoch 3033, loss 0.008834907

epoch 3230, loss 0.008021264337003231 
epoch 3231, loss 0.008018020540475845 
epoch 3232, loss 0.008014807477593422 
epoch 3233, loss 0.008011573925614357 
epoch 3234, loss 0.008008407428860664 
epoch 3235, loss 0.008005128242075443 
epoch 3236, loss 0.008001946844160557 
epoch 3237, loss 0.007998806424438953 
epoch 3238, loss 0.007995577529072762 
epoch 3239, loss 0.007992464117705822 
epoch 3240, loss 0.00798932183533907 
epoch 3241, loss 0.007986321114003658 
epoch 3242, loss 0.007983357645571232 
epoch 3243, loss 0.007980579510331154 
epoch 3244, loss 0.007978098466992378 
epoch 3245, loss 0.007975989021360874 
epoch 3246, loss 0.007974623702466488 
epoch 3247, loss 0.007974255830049515 
epoch 3248, loss 0.007975662127137184 
epoch 3249, loss 0.007979799062013626 
epoch 3250, loss 0.00798818189650774 
epoch 3251, loss 0.008002785965800285 
epoch 3252, loss 0.008025486022233963 
epoch 3253, loss 0.008057203143835068 
epoch 3254, loss 0.008093480952084064 
epoch 3255, loss 0.00812257

epoch 3577, loss 0.006975032854825258 
epoch 3578, loss 0.006972132716327906 
epoch 3579, loss 0.006969287060201168 
epoch 3580, loss 0.0069664097391068935 
epoch 3581, loss 0.006963549181818962 
epoch 3582, loss 0.006960692815482616 
epoch 3583, loss 0.006957849022001028 
epoch 3584, loss 0.006954987067729235 
epoch 3585, loss 0.006952071096748114 
epoch 3586, loss 0.0069492775946855545 
epoch 3587, loss 0.006946356035768986 
epoch 3588, loss 0.006943507120013237 
epoch 3589, loss 0.0069406465627253056 
epoch 3590, loss 0.006937773898243904 
epoch 3591, loss 0.0069349380210042 
epoch 3592, loss 0.0069320728071033955 
epoch 3593, loss 0.0069291722029447556 
epoch 3594, loss 0.006926358677446842 
epoch 3595, loss 0.0069234841503202915 
epoch 3596, loss 0.006920627783983946 
epoch 3597, loss 0.0069177658297121525 
epoch 3598, loss 0.006914861965924501 
epoch 3599, loss 0.006912045180797577 
epoch 3600, loss 0.00690919766202569 
epoch 3601, loss 0.0069063385017216206 
epoch 3602, loss 0.0

epoch 3812, loss 0.006317889783531427 
epoch 3813, loss 0.006315203383564949 
epoch 3814, loss 0.0063125016167759895 
epoch 3815, loss 0.006309791933745146 
epoch 3816, loss 0.0063071222975850105 
epoch 3817, loss 0.006304414942860603 
epoch 3818, loss 0.006301715970039368 
epoch 3819, loss 0.006299005355685949 
epoch 3820, loss 0.006296333856880665 
epoch 3821, loss 0.0062936157919466496 
epoch 3822, loss 0.006290921010077 
epoch 3823, loss 0.006288240663707256 
epoch 3824, loss 0.006285538896918297 
epoch 3825, loss 0.006282808259129524 
epoch 3826, loss 0.006280149798840284 
epoch 3827, loss 0.006277427077293396 
epoch 3828, loss 0.0062746950425207615 
epoch 3829, loss 0.006272035650908947 
epoch 3830, loss 0.00626930920407176 
epoch 3831, loss 0.006266655400395393 
epoch 3832, loss 0.0062639592215418816 
epoch 3833, loss 0.006261234171688557 
epoch 3834, loss 0.006258578971028328 
epoch 3835, loss 0.0062558711506426334 
epoch 3836, loss 0.006253193132579327 
epoch 3837, loss 0.0062

epoch 4168, loss 0.005409634672105312 
epoch 4169, loss 0.005411614663898945 
epoch 4170, loss 0.005403755698353052 
epoch 4171, loss 0.005392863415181637 
epoch 4172, loss 0.0053878771141171455 
epoch 4173, loss 0.005389708559960127 
epoch 4174, loss 0.005391939543187618 
epoch 4175, loss 0.005388904828578234 
epoch 4176, loss 0.005381749011576176 
epoch 4177, loss 0.005375843029469252 
epoch 4178, loss 0.005374290514737368 
epoch 4179, loss 0.005374937783926725 
epoch 4180, loss 0.005373804364353418 
epoch 4181, loss 0.005369633436203003 
epoch 4182, loss 0.005364554934203625 
epoch 4183, loss 0.005361237097531557 
epoch 4184, loss 0.005360099021345377 
epoch 4185, loss 0.005359260831028223 
epoch 4186, loss 0.005356834270060062 
epoch 4187, loss 0.0053529939614236355 
epoch 4188, loss 0.005349310114979744 
epoch 4189, loss 0.0053469520062208176 
epoch 4190, loss 0.005345477722585201 
epoch 4191, loss 0.0053437077440321445 
epoch 4192, loss 0.005340952891856432 
epoch 4193, loss 0.00

epoch 4398, loss 0.0048575205728411674 
epoch 4399, loss 0.004851380363106728 
epoch 4400, loss 0.004854364320635796 
epoch 4401, loss 0.004858790896832943 
epoch 4402, loss 0.0048570334911346436 
epoch 4403, loss 0.004849113058298826 
epoch 4404, loss 0.00484111811965704 
epoch 4405, loss 0.004838147200644016 
epoch 4406, loss 0.004839481320232153 
epoch 4407, loss 0.00484036048874259 
epoch 4408, loss 0.0048374030739068985 
epoch 4409, loss 0.004831711295992136 
epoch 4410, loss 0.004826765041798353 
epoch 4411, loss 0.004824773874133825 
epoch 4412, loss 0.004824659787118435 
epoch 4413, loss 0.004823796916753054 
epoch 4414, loss 0.004820824600756168 
epoch 4415, loss 0.004816601052880287 
epoch 4416, loss 0.004813049919903278 
epoch 4417, loss 0.004811177030205727 
epoch 4418, loss 0.004810108803212643 
epoch 4419, loss 0.004808507394045591 
epoch 4420, loss 0.004805779550224543 
epoch 4421, loss 0.0048024761490523815 
epoch 4422, loss 0.004799605812877417 
epoch 4423, loss 0.0047

epoch 4813, loss 0.003971452359110117 
epoch 4814, loss 0.0039694784209132195 
epoch 4815, loss 0.003967467229813337 
epoch 4816, loss 0.003965480253100395 
epoch 4817, loss 0.003963496536016464 
epoch 4818, loss 0.003961507696658373 
epoch 4819, loss 0.003959499299526215 
epoch 4820, loss 0.003957543056458235 
epoch 4821, loss 0.003955542575567961 
epoch 4822, loss 0.003953556530177593 
epoch 4823, loss 0.0039515565149486065 
epoch 4824, loss 0.0039495606906712055 
epoch 4825, loss 0.003947578836232424 
epoch 4826, loss 0.003945586271584034 
epoch 4827, loss 0.003943626303225756 
epoch 4828, loss 0.003941652365028858 
epoch 4829, loss 0.003939658869057894 
epoch 4830, loss 0.00393767561763525 
epoch 4831, loss 0.003935685846954584 
epoch 4832, loss 0.003933694679290056 
epoch 4833, loss 0.003931727726012468 
epoch 4834, loss 0.003929764032363892 
epoch 4835, loss 0.003927768673747778 
epoch 4836, loss 0.003925805911421776 
epoch 4837, loss 0.0039238715544342995 
epoch 4838, loss 0.003

epoch 5035, loss 0.0035520908422768116 
epoch 5036, loss 0.0035503054969012737 
epoch 5037, loss 0.003548477543517947 
epoch 5038, loss 0.003546673571690917 
epoch 5039, loss 0.003544823732227087 
epoch 5040, loss 0.003543056547641754 
epoch 5041, loss 0.0035412160214036703 
epoch 5042, loss 0.0035394050646573305 
epoch 5043, loss 0.003537548938766122 
epoch 5044, loss 0.0035357887391000986 
epoch 5045, loss 0.0035339572932571173 
epoch 5046, loss 0.0035321225877851248 
epoch 5047, loss 0.0035303609911352396 
epoch 5048, loss 0.0035285097546875477 
epoch 5049, loss 0.0035266787745058537 
epoch 5050, loss 0.0035249332431703806 
epoch 5051, loss 0.003523109946399927 
epoch 5052, loss 0.00352127174846828 
epoch 5053, loss 0.003519483143463731 
epoch 5054, loss 0.003517660079523921 
epoch 5055, loss 0.0035158449318259954 
epoch 5056, loss 0.0035140400286763906 
epoch 5057, loss 0.003512245137244463 
epoch 5058, loss 0.0035104602575302124 
epoch 5059, loss 0.0035086290445178747 
epoch 5060,

epoch 5381, loss 0.0029700256418436766 
epoch 5382, loss 0.002968456130474806 
epoch 5383, loss 0.0029668521601706743 
epoch 5384, loss 0.002965281717479229 
epoch 5385, loss 0.002963717794045806 
epoch 5386, loss 0.0029621373396366835 
epoch 5387, loss 0.002960583893582225 
epoch 5388, loss 0.0029590048361569643 
epoch 5389, loss 0.0029574783984571695 
epoch 5390, loss 0.0029558732639998198 
epoch 5391, loss 0.0029543312266469 
epoch 5392, loss 0.0029527531005442142 
epoch 5393, loss 0.002951208967715502 
epoch 5394, loss 0.002949605928733945 
epoch 5395, loss 0.0029480555094778538 
epoch 5396, loss 0.002946472493931651 
epoch 5397, loss 0.00294494372792542 
epoch 5398, loss 0.002943352796137333 
epoch 5399, loss 0.002941801445558667 
epoch 5400, loss 0.0029402519576251507 
epoch 5401, loss 0.0029386919923126698 
epoch 5402, loss 0.002937135985121131 
epoch 5403, loss 0.002935613738372922 
epoch 5404, loss 0.002933999290689826 
epoch 5405, loss 0.002932443516328931 
epoch 5406, loss 0

epoch 5773, loss 0.0024082933086901903 
epoch 5774, loss 0.0024069608189165592 
epoch 5775, loss 0.0024056383408606052 
epoch 5776, loss 0.0024043400771915913 
epoch 5777, loss 0.002403017831966281 
epoch 5778, loss 0.0024017025716602802 
epoch 5779, loss 0.0024003952275961637 
epoch 5780, loss 0.0023990620393306017 
epoch 5781, loss 0.0023977705277502537 
epoch 5782, loss 0.002396434312686324 
epoch 5783, loss 0.0023951493203639984 
epoch 5784, loss 0.002393814269453287 
epoch 5785, loss 0.0023925083223730326 
epoch 5786, loss 0.002391175599768758 
epoch 5787, loss 0.0023898971267044544 
epoch 5788, loss 0.0023885779082775116 
epoch 5789, loss 0.002387271961197257 
epoch 5790, loss 0.0023859681095927954 
epoch 5791, loss 0.0023846891708672047 
epoch 5792, loss 0.0023833741433918476 
epoch 5793, loss 0.0023820623755455017 
epoch 5794, loss 0.002380742458626628 
epoch 5795, loss 0.0023794430308043957 
epoch 5796, loss 0.002378139179199934 
epoch 5797, loss 0.0023768225219100714 
epoch 5

epoch 6127, loss 0.001988898729905486 
epoch 6128, loss 0.001987806288525462 
epoch 6129, loss 0.001986725488677621 
epoch 6130, loss 0.0019856400322169065 
epoch 6131, loss 0.0019845315255224705 
epoch 6132, loss 0.0019834500271826982 
epoch 6133, loss 0.0019823648035526276 
epoch 6134, loss 0.001981286099180579 
epoch 6135, loss 0.0019802036695182323 
epoch 6136, loss 0.00197910750284791 
epoch 6137, loss 0.0019780455622822046 
epoch 6138, loss 0.0019769396167248487 
epoch 6139, loss 0.001975876046344638 
epoch 6140, loss 0.0019747980404645205 
epoch 6141, loss 0.0019736881367862225 
epoch 6142, loss 0.001972620142623782 
epoch 6143, loss 0.001971562160179019 
epoch 6144, loss 0.001970474375411868 
epoch 6145, loss 0.0019693844951689243 
epoch 6146, loss 0.0019683053251355886 
epoch 6147, loss 0.0019672615453600883 
epoch 6148, loss 0.001966154668480158 
epoch 6149, loss 0.001965085044503212 
epoch 6150, loss 0.0019640030805021524 
epoch 6151, loss 0.0019629313610494137 
epoch 6152, 

epoch 6417, loss 0.0016931849531829357 
epoch 6418, loss 0.0016922299982979894 
epoch 6419, loss 0.0016912574646994472 
epoch 6420, loss 0.0016903403447940946 
epoch 6421, loss 0.0016893831780180335 
epoch 6422, loss 0.0016884164651855826 
epoch 6423, loss 0.0016874877037480474 
epoch 6424, loss 0.0016865261131897569 
epoch 6425, loss 0.0016855846624821424 
epoch 6426, loss 0.001684633200056851 
epoch 6427, loss 0.0016836781287565827 
epoch 6428, loss 0.0016827190993353724 
epoch 6429, loss 0.0016817764844745398 
epoch 6430, loss 0.0016808488871902227 
epoch 6431, loss 0.0016798845026642084 
epoch 6432, loss 0.0016789360670372844 
epoch 6433, loss 0.0016779900761321187 
epoch 6434, loss 0.001677040127106011 
epoch 6435, loss 0.0016761079896241426 
epoch 6436, loss 0.0016751649091020226 
epoch 6437, loss 0.001674216240644455 
epoch 6438, loss 0.0016732565127313137 
epoch 6439, loss 0.0016723336884751916 
epoch 6440, loss 0.0016714067896828055 
epoch 6441, loss 0.0016704824520274997 
epo

epoch 6738, loss 0.0014139882987365127 
epoch 6739, loss 0.0014132015639916062 
epoch 6740, loss 0.0014124122681096196 
epoch 6741, loss 0.001411589328199625 
epoch 6742, loss 0.0014107865281403065 
epoch 6743, loss 0.001410007826052606 
epoch 6744, loss 0.0014092139899730682 
epoch 6745, loss 0.001408413052558899 
epoch 6746, loss 0.0014076113002374768 
epoch 6747, loss 0.0014067954616621137 
epoch 6748, loss 0.0014060239773243666 
epoch 6749, loss 0.0014052080223336816 
epoch 6750, loss 0.0014043966075405478 
epoch 6751, loss 0.001403627684339881 
epoch 6752, loss 0.0014028240693733096 
epoch 6753, loss 0.0014020373346284032 
epoch 6754, loss 0.001401259913109243 
epoch 6755, loss 0.0014004678232595325 
epoch 6756, loss 0.0013996540801599622 
epoch 6757, loss 0.0013988581486046314 
epoch 6758, loss 0.00139806407969445 
epoch 6759, loss 0.0013972831657156348 
epoch 6760, loss 0.0013964889803901315 
epoch 6761, loss 0.0013956811744719744 
epoch 6762, loss 0.0013949015410616994 
epoch 6

epoch 7043, loss 0.0011920217657461762 
epoch 7044, loss 0.0011913706548511982 
epoch 7045, loss 0.001190682640299201 
epoch 7046, loss 0.0011900325771421194 
epoch 7047, loss 0.0011893659830093384 
epoch 7048, loss 0.0011887216242030263 
epoch 7049, loss 0.001188049092888832 
epoch 7050, loss 0.0011874132324010134 
epoch 7051, loss 0.0011867467546835542 
epoch 7052, loss 0.0011860898230224848 
epoch 7053, loss 0.0011854150798171759 
epoch 7054, loss 0.001184772583656013 
epoch 7055, loss 0.0011840991210192442 
epoch 7056, loss 0.0011834405595436692 
epoch 7057, loss 0.0011827977141365409 
epoch 7058, loss 0.0011821499792858958 
epoch 7059, loss 0.0011814830359071493 
epoch 7060, loss 0.0011808163253590465 
epoch 7061, loss 0.001180157414637506 
epoch 7062, loss 0.0011795054888352752 
epoch 7063, loss 0.001178865204565227 
epoch 7064, loss 0.0011782106012105942 
epoch 7065, loss 0.0011775498278439045 
epoch 7066, loss 0.0011768742697313428 
epoch 7067, loss 0.0011762288631871343 
epoch

epoch 7298, loss 0.0010335748083889484 
epoch 7299, loss 0.0010329923825338483 
epoch 7300, loss 0.0010324151953682303 
epoch 7301, loss 0.0010318311396986246 
epoch 7302, loss 0.0010312425438314676 
epoch 7303, loss 0.001030677929520607 
epoch 7304, loss 0.0010301050497218966 
epoch 7305, loss 0.0010295140091329813 
epoch 7306, loss 0.0010289415949955583 
epoch 7307, loss 0.0010283553274348378 
epoch 7308, loss 0.0010277662659063935 
epoch 7309, loss 0.001027201535180211 
epoch 7310, loss 0.0010266279568895698 
epoch 7311, loss 0.0010260391281917691 
epoch 7312, loss 0.0010254571679979563 
epoch 7313, loss 0.0010248646140098572 
epoch 7314, loss 0.0010242968564853072 
epoch 7315, loss 0.0010237176902592182 
epoch 7316, loss 0.0010231430642306805 
epoch 7317, loss 0.001022560172714293 
epoch 7318, loss 0.0010219925316050649 
epoch 7319, loss 0.0010214179055765271 
epoch 7320, loss 0.0010208452586084604 
epoch 7321, loss 0.001020276453346014 
epoch 7322, loss 0.0010196934454143047 
epoc

epoch 7630, loss 0.0008559087291359901 
epoch 7631, loss 0.0008554170490242541 
epoch 7632, loss 0.0008549338090233505 
epoch 7633, loss 0.0008544367155991495 
epoch 7634, loss 0.0008539654081687331 
epoch 7635, loss 0.0008534545777365565 
epoch 7636, loss 0.0008529787883162498 
epoch 7637, loss 0.0008524945587851107 
epoch 7638, loss 0.0008519875118508935 
epoch 7639, loss 0.0008515011868439615 
epoch 7640, loss 0.0008510214393027127 
epoch 7641, loss 0.0008505320292897522 
epoch 7642, loss 0.0008500387775711715 
epoch 7643, loss 0.0008495569927617908 
epoch 7644, loss 0.000849067815579474 
epoch 7645, loss 0.0008485897560603917 
epoch 7646, loss 0.0008481083204969764 
epoch 7647, loss 0.0008476156508550048 
epoch 7648, loss 0.0008471266482956707 
epoch 7649, loss 0.0008466523140668869 
epoch 7650, loss 0.0008461536490358412 
epoch 7651, loss 0.0008456691866740584 
epoch 7652, loss 0.0008451808243989944 
epoch 7653, loss 0.0008446939173154533 
epoch 7654, loss 0.0008442261605523527 
e

epoch 7930, loss 0.0007272749789990485 
epoch 7931, loss 0.0007268933113664389 
epoch 7932, loss 0.0007265136810019612 
epoch 7933, loss 0.0007261399878188968 
epoch 7934, loss 0.0007257727556861937 
epoch 7935, loss 0.0007254018564708531 
epoch 7936, loss 0.0007250231574289501 
epoch 7937, loss 0.0007246492896229029 
epoch 7938, loss 0.0007242606952786446 
epoch 7939, loss 0.0007238878752104938 
epoch 7940, loss 0.0007235266384668648 
epoch 7941, loss 0.0007231517811305821 
epoch 7942, loss 0.0007227788446471095 
epoch 7943, loss 0.0007224070141091943 
epoch 7944, loss 0.0007220233674161136 
epoch 7945, loss 0.0007216533413156867 
epoch 7946, loss 0.000721280463039875 
epoch 7947, loss 0.0007209084578789771 
epoch 7948, loss 0.0007205307483673096 
epoch 7949, loss 0.0007201575208455324 
epoch 7950, loss 0.0007197913946583867 
epoch 7951, loss 0.000719416537322104 
epoch 7952, loss 0.0007190416217781603 
epoch 7953, loss 0.0007186774746514857 
epoch 7954, loss 0.0007183064008131623 
ep

epoch 8149, loss 0.0006490873056463897 
epoch 8150, loss 0.0006487310747615993 
epoch 8151, loss 0.0006483980105258524 
epoch 8152, loss 0.0006480671581812203 
epoch 8153, loss 0.0006477158749476075 
epoch 8154, loss 0.000647399399895221 
epoch 8155, loss 0.0006470626103691757 
epoch 8156, loss 0.0006467151688411832 
epoch 8157, loss 0.000646380300167948 
epoch 8158, loss 0.0006460454314947128 
epoch 8159, loss 0.0006457055569626391 
epoch 8160, loss 0.000645363936200738 
epoch 8161, loss 0.0006450311047956347 
epoch 8162, loss 0.0006446932093240321 
epoch 8163, loss 0.0006443574675358832 
epoch 8164, loss 0.0006440108409151435 
epoch 8165, loss 0.0006436720723286271 
epoch 8166, loss 0.0006433401140384376 
epoch 8167, loss 0.0006430130451917648 
epoch 8168, loss 0.0006426731706596911 
epoch 8169, loss 0.0006423324230127037 
epoch 8170, loss 0.0006419955170713365 
epoch 8171, loss 0.0006416703690774739 
epoch 8172, loss 0.0006413374212570488 
epoch 8173, loss 0.0006410054629668593 
epo

epoch 8522, loss 0.0005334625020623207 
epoch 8523, loss 0.000533174374140799 
epoch 8524, loss 0.0005328909610398114 
epoch 8525, loss 0.0005326125537976623 
epoch 8526, loss 0.0005323312943801284 
epoch 8527, loss 0.0005320558557286859 
epoch 8528, loss 0.0005317607428878546 
epoch 8529, loss 0.000531485304236412 
epoch 8530, loss 0.0005312096909619868 
epoch 8531, loss 0.0005309234838932753 
epoch 8532, loss 0.0005306460661813617 
epoch 8533, loss 0.0005303627112880349 
epoch 8534, loss 0.0005300686461851001 
epoch 8535, loss 0.0005297959432937205 
epoch 8536, loss 0.0005295078735798597 
epoch 8537, loss 0.0005292342393659055 
epoch 8538, loss 0.0005289596738293767 
epoch 8539, loss 0.0005286754458211362 
epoch 8540, loss 0.000528393080458045 
epoch 8541, loss 0.0005281263147480786 
epoch 8542, loss 0.000527836091350764 
epoch 8543, loss 0.000527557625900954 
epoch 8544, loss 0.0005272772395983338 
epoch 8545, loss 0.0005269899847917259 
epoch 8546, loss 0.0005267203086987138 
epoch

epoch 8895, loss 0.0004372941912151873 
epoch 8896, loss 0.0004370560636743903 
epoch 8897, loss 0.000436820846516639 
epoch 8898, loss 0.00043657986680045724 
epoch 8899, loss 0.00043635640759021044 
epoch 8900, loss 0.0004361104220151901 
epoch 8901, loss 0.0004358792502898723 
epoch 8902, loss 0.00043565777014009655 
epoch 8903, loss 0.0004354235716164112 
epoch 8904, loss 0.00043517854646779597 
epoch 8905, loss 0.0004349473165348172 
epoch 8906, loss 0.0004347180074546486 
epoch 8907, loss 0.00043448180076666176 
epoch 8908, loss 0.000434250570833683 
epoch 8909, loss 0.00043402129085734487 
epoch 8910, loss 0.00043377920519560575 
epoch 8911, loss 0.0004335548437666148 
epoch 8912, loss 0.0004333215765655041 
epoch 8913, loss 0.00043308830936439335 
epoch 8914, loss 0.0004328560607973486 
epoch 8915, loss 0.0004326287016738206 
epoch 8916, loss 0.0004323925531934947 
epoch 8917, loss 0.0004321602755226195 
epoch 8918, loss 0.00043192997691221535 
epoch 8919, loss 0.00043169577838

epoch 9131, loss 0.00038516975473612547 
epoch 9132, loss 0.0003849616623483598 
epoch 9133, loss 0.0003847554908134043 
epoch 9134, loss 0.00038455333560705185 
epoch 9135, loss 0.0003843471931759268 
epoch 9136, loss 0.00038413910078816116 
epoch 9137, loss 0.0003839330456685275 
epoch 9138, loss 0.0003837151452898979 
epoch 9139, loss 0.00038351293187588453 
epoch 9140, loss 0.0003833107475657016 
epoch 9141, loss 0.0003831075446214527 
epoch 9142, loss 0.0003828935441561043 
epoch 9143, loss 0.0003826913598459214 
epoch 9144, loss 0.000382479396648705 
epoch 9145, loss 0.00038227226468734443 
epoch 9146, loss 0.0003820651618298143 
epoch 9147, loss 0.00038186684832908213 
epoch 9148, loss 0.00038166268495842814 
epoch 9149, loss 0.000381452584406361 
epoch 9150, loss 0.0003812425711657852 
epoch 9151, loss 0.00038103939732536674 
epoch 9152, loss 0.0003808302863035351 
epoch 9153, loss 0.0003806290333159268 
epoch 9154, loss 0.0003804219886660576 
epoch 9155, loss 0.000380219746148

epoch 9534, loss 0.0003096169966738671 
epoch 9535, loss 0.0003094526764471084 
epoch 9536, loss 0.00030928055639378726 
epoch 9537, loss 0.0003091143153142184 
epoch 9538, loss 0.0003089470264967531 
epoch 9539, loss 0.0003087818040512502 
epoch 9540, loss 0.0003086184442508966 
epoch 9541, loss 0.00030844329739920795 
epoch 9542, loss 0.0003082790062762797 
epoch 9543, loss 0.00030811273609288037 
epoch 9544, loss 0.0003079405287280679 
epoch 9545, loss 0.00030777137726545334 
epoch 9546, loss 0.0003076080174650997 
epoch 9547, loss 0.0003074387786909938 
epoch 9548, loss 0.00030727352714166045 
epoch 9549, loss 0.00030710428836755455 
epoch 9550, loss 0.0003069340600632131 
epoch 9551, loss 0.00030676875030621886 
epoch 9552, loss 0.00030660250922665 
epoch 9553, loss 0.0003064312622882426 
epoch 9554, loss 0.00030627287924289703 
epoch 9555, loss 0.0003060987510252744 
epoch 9556, loss 0.0003059323353227228 
epoch 9557, loss 0.0003057671128772199 
epoch 9558, loss 0.000305598863633

epoch 9752, loss 0.00027492738445289433 
epoch 9753, loss 0.0002747717662714422 
epoch 9754, loss 0.0002746279933489859 
epoch 9755, loss 0.000274479272775352 
epoch 9756, loss 0.00027432857314124703 
epoch 9757, loss 0.0002741847711149603 
epoch 9758, loss 0.0002740341005846858 
epoch 9759, loss 0.00027387653244659305 
epoch 9760, loss 0.00027373371995054185 
epoch 9761, loss 0.00027358104125596583 
epoch 9762, loss 0.000273434299742803 
epoch 9763, loss 0.0002732875873334706 
epoch 9764, loss 0.00027313685859553516 
epoch 9765, loss 0.00027298516943119466 
epoch 9766, loss 0.00027283935924060643 
epoch 9767, loss 0.00027267978293821216 
epoch 9768, loss 0.0002725379599723965 
epoch 9769, loss 0.0002723892102949321 
epoch 9770, loss 0.00027223851066082716 
epoch 9771, loss 0.0002720888005569577 
epoch 9772, loss 0.00027194106951355934 
epoch 9773, loss 0.0002717903407756239 
epoch 9774, loss 0.00027164554921910167 
epoch 9775, loss 0.00027149191009812057 
epoch 9776, loss 0.0002713412

epoch 9969, loss 0.00024415034567937255 
epoch 9970, loss 0.00024401926202699542 
epoch 9971, loss 0.0002438832016196102 
epoch 9972, loss 0.0002437481307424605 
epoch 9973, loss 0.0002436091162962839 
epoch 9974, loss 0.00024347702856175601 
epoch 9975, loss 0.0002433478948660195 
epoch 9976, loss 0.00024321679666172713 
epoch 9977, loss 0.00024308761931024492 
epoch 9978, loss 0.00024294962349813432 
epoch 9979, loss 0.00024282047525048256 
epoch 9980, loss 0.0002426804567221552 
epoch 9981, loss 0.00024255427706521004 
epoch 9982, loss 0.00024241922073997557 
epoch 9983, loss 0.0002422880643280223 
epoch 9984, loss 0.00024215300800278783 
epoch 9985, loss 0.00024201892665587366 
epoch 9986, loss 0.00024188679526560009 
epoch 9987, loss 0.00024175667203962803 
epoch 9988, loss 0.0002416215866105631 
epoch 9989, loss 0.0002414865157334134 
epoch 9990, loss 0.00024135931744240224 
epoch 9991, loss 0.00024122028844431043 
epoch 9992, loss 0.00024109410878736526 
epoch 9993, loss 0.00024

In [46]:
def greates_3(x):
    if x[0]> x[1]:
        if(x[0]>x[2]):
            return x[0]
        else:
            return x[2]
    elif(x[1]>x[2]):
        return x[1]
    else:
        return x[2]
greates_3([2,3,1])

3

In [47]:
pred_y=model.predict(test_x)



In [61]:
test_x.shape

torch.Size([30, 4])

In [66]:
x=torch.tensor([[8.6,7.3,2.0,1.3]])

In [67]:
model.predict(x)



tensor([0])

In [51]:
test_x

tensor([[7.7000, 2.8000, 6.7000, 2.0000],
        [5.7000, 2.5000, 5.0000, 2.0000],
        [4.6000, 3.6000, 1.0000, 0.2000],
        [6.0000, 2.7000, 5.1000, 1.6000],
        [6.1000, 2.6000, 5.6000, 1.4000],
        [4.3000, 3.0000, 1.1000, 0.1000],
        [5.6000, 2.7000, 4.2000, 1.3000],
        [7.7000, 3.8000, 6.7000, 2.2000],
        [6.1000, 3.0000, 4.9000, 1.8000],
        [5.6000, 3.0000, 4.5000, 1.5000],
        [6.2000, 2.8000, 4.8000, 1.8000],
        [5.4000, 3.4000, 1.5000, 0.4000],
        [6.0000, 2.2000, 4.0000, 1.0000],
        [5.7000, 2.9000, 4.2000, 1.3000],
        [6.8000, 2.8000, 4.8000, 1.4000],
        [5.2000, 2.7000, 3.9000, 1.4000],
        [7.1000, 3.0000, 5.9000, 2.1000],
        [5.2000, 3.4000, 1.4000, 0.2000],
        [6.9000, 3.1000, 5.4000, 2.1000],
        [5.0000, 3.2000, 1.2000, 0.2000],
        [6.3000, 2.3000, 4.4000, 1.3000],
        [4.4000, 3.2000, 1.3000, 0.2000],
        [6.2000, 2.9000, 4.3000, 1.3000],
        [6.3000, 3.4000, 5.6000, 2

In [48]:
pred_y

tensor([2, 2, 0, 2, 2, 0, 1, 2, 2, 1, 2, 0, 1, 1, 1, 1, 2, 0, 2, 0, 1, 0, 1, 2,
        2, 2, 1, 0, 2, 2])

# testing error

In [49]:
# Python script for confusion matrix creation. 
from sklearn.metrics import confusion_matrix 
from sklearn.metrics import accuracy_score 
from sklearn.metrics import classification_report 

results = confusion_matrix(test_y, pred_y) 
  
print('Confusion Matrix :')
print(results) 
print('Accuracy Score :',accuracy_score(test_y, pred_y)) 
print('Report : ')
print(classification_report(test_y, pred_y)) 

Confusion Matrix :
[[ 7  0  0]
 [ 0  9  1]
 [ 0  0 13]]
Accuracy Score : 0.9666666666666667
Report : 
              precision    recall  f1-score   support

           0       1.00      1.00      1.00         7
           1       1.00      0.90      0.95        10
           2       0.93      1.00      0.96        13

    accuracy                           0.97        30
   macro avg       0.98      0.97      0.97        30
weighted avg       0.97      0.97      0.97        30



# training error

In [50]:
results = confusion_matrix(train_y, model.predict(train_x)) 
print('Confusion Matrix :')
print(results) 
print('Accuracy Score :',accuracy_score(train_y, model.predict(train_x))) 
print('Report : ')
print(classification_report(train_y,model.predict(train_x))) 

Confusion Matrix :
[[43  0  0]
 [ 0 40  0]
 [ 0  0 37]]
Accuracy Score : 1.0
Report : 
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        43
           1       1.00      1.00      1.00        40
           2       1.00      1.00      1.00        37

    accuracy                           1.00       120
   macro avg       1.00      1.00      1.00       120
weighted avg       1.00      1.00      1.00       120



