In [1]:
import ast
import gensim.downloader as api
import pandas as pd
import numpy as np


wv = api.load('word2vec-google-news-300')

In [2]:
df = pd.DataFrame({"w2v_vec": [], "label": []})

ft_size = 10
for topic in ['politics_cleaned', 'science_cleaned', 'sports_cleaned', 'weather_cleaned', 'worldnews_cleaned']:
    print(topic)
    df_read = pd.read_csv(topic + '.csv')
    df_read = df_read.dropna()
    df_read['cleaned_article'] = df_read['cleaned_article'].map(lambda x: ast.literal_eval(x))
    df_read['l'] = df_read['cleaned_article'].map(lambda x: len(x))
    df_read = df_read[df_read['l'] != 0]
    df_read['cleaned_article'] = df_read['cleaned_article'].map(lambda x: [y for y in x if y in wv])
    df_read['l'] = df_read['cleaned_article'].map(lambda x: len(x))
    df_read = df_read[df_read['l'] != 0]
    df_read['w2v_vec'] = df_read['cleaned_article'].map(lambda x: wv[x])
    df_read['label'] = topic
    df_read = df_read[['w2v_vec', 'label']]
    df_read['w2v_vec'] = df_read['w2v_vec'].map(lambda x: [y[:ft_size] for y in x])
    df = pd.concat([df, df_read])

label_dict = {'politics_cleaned': 0,
            'science_cleaned': 1,
            'sports_cleaned': 2,
            'weather_cleaned': 3,
            'worldnews_cleaned': 4}

df['label'] = df['label'].map(lambda x: label_dict[x])

politics_cleaned
science_cleaned
sports_cleaned
weather_cleaned
worldnews_cleaned


In [3]:
sample_size = 1000
df_sample = df.sample(n=sample_size, random_state=1)

df_sample['dims'] = df_sample['w2v_vec'].map(lambda x: np.array(x).shape)
mx = max(x[0] for x in df_sample['dims'])

def pad(arr, padding):
    diff = padding - len(arr)
    arr = np.pad(arr, ((0, diff), (0, 0)), 'constant')
    return arr
    
df_sample['w2v_vec_padded'] = df_sample['w2v_vec'].map(lambda x: pad(x, mx))
df_sample['w2v_vec_padded'] = df_sample['w2v_vec_padded'].map(lambda x: x.ravel())
df_sample['dims'] = df_sample['w2v_vec_padded'].map(lambda x: x.shape)

df_sample

Unnamed: 0,w2v_vec,label,dims,w2v_vec_padded
859,"[[-0.13085938, 0.006134033, 0.0138549805, 0.09...",3,"(164530,)","[-0.13085938, 0.006134033, 0.0138549805, 0.090..."
789,"[[-0.030761719, 0.04345703, 0.15527344, 0.3769...",4,"(164530,)","[-0.030761719, 0.04345703, 0.15527344, 0.37695..."
440,"[[0.14355469, 0.23632812, 0.16113281, 0.220703...",4,"(164530,)","[0.14355469, 0.23632812, 0.16113281, 0.2207031..."
781,"[[0.099609375, -0.119628906, -0.048339844, 0.1...",3,"(164530,)","[0.099609375, -0.119628906, -0.048339844, 0.17..."
317,"[[0.099609375, -0.119628906, -0.048339844, 0.1...",3,"(164530,)","[0.099609375, -0.119628906, -0.048339844, 0.17..."
...,...,...,...,...
292,"[[-0.00030708313, -0.027709961, -0.13769531, 0...",4,"(164530,)","[-0.00030708313, -0.027709961, -0.13769531, 0...."
397,"[[0.009033203, 0.12158203, 0.3984375, 0.213867...",1,"(164530,)","[0.009033203, 0.12158203, 0.3984375, 0.2138671..."
919,"[[0.099609375, -0.119628906, -0.048339844, 0.1...",3,"(164530,)","[0.099609375, -0.119628906, -0.048339844, 0.17..."
68,"[[0.013305664, 0.029907227, -0.0049438477, 0.0...",0,"(164530,)","[0.013305664, 0.029907227, -0.0049438477, 0.00..."


In [4]:
print(mx)

16453


In [5]:
df_sample['w2v_vec_padded'].iloc[0].shape

(164530,)

In [6]:
df_sample['label'].value_counts()

0    242
1    240
2    194
4    173
3    151
Name: label, dtype: int64

In [7]:
from sklearn.model_selection import train_test_split

X, y = df_sample['w2v_vec_padded'], np.array(df_sample['label'])
X = np.array([x.reshape(-1, 1) for x in X])
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)

In [8]:
import torch
from torch.autograd import Variable
import torch.nn as nn


input_size = X_train.shape[1]
output_size = len(set(y_train))


class DNNLog(torch.nn.Module):

    def __init__(self, input_size, output_size):
        super(DNNLog, self).__init__()
        self.linear = torch.nn.Linear(input_size, output_size)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

In [9]:
X_train.shape

(750, 164530, 1)

In [None]:
from sklearn import preprocessing


torch_X_train = torch.from_numpy(np.asarray(X_train).astype(np.float16)).view(X_train.shape[0], input_size)
torch_y_train = Variable(torch.as_tensor(y_train))

torch_X_test = torch.from_numpy(np.asarray(X_test).astype(np.float16)).view(X_test.shape[0], input_size)
torch_y_test = Variable(torch.as_tensor(y_test))

model = DNNLog(input_size, output_size)

learning_rate = 0.1
l = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

num_epochs = 10000
for epoch in range(num_epochs):
    optimizer.zero_grad()
    y_pred = model(torch_X_train.float())
    loss = l(y_pred, torch_y_train)
    loss.backward()
    optimizer.step()
    print('epoch {}, loss {}'.format(epoch, loss.item()))

epoch 0, loss 1.6098663806915283
epoch 1, loss 1.6057137250900269
epoch 2, loss 1.601664662361145
epoch 3, loss 1.5977369546890259
epoch 4, loss 1.5939353704452515
epoch 5, loss 1.590273380279541
epoch 6, loss 1.5867483615875244
epoch 7, loss 1.5833691358566284
epoch 8, loss 1.5801262855529785
epoch 9, loss 1.5770206451416016
epoch 10, loss 1.5740476846694946
epoch 11, loss 1.5712000131607056
epoch 12, loss 1.5684735774993896
epoch 13, loss 1.5658621788024902
epoch 14, loss 1.5633567571640015
epoch 15, loss 1.560955286026001
epoch 16, loss 1.5586490631103516
epoch 17, loss 1.5564345121383667
epoch 18, loss 1.5543038845062256
epoch 19, loss 1.5522540807724
epoch 20, loss 1.5502766370773315
epoch 21, loss 1.5483700037002563
epoch 22, loss 1.546531081199646
epoch 23, loss 1.5447516441345215
epoch 24, loss 1.543031930923462
epoch 25, loss 1.5413631200790405
epoch 26, loss 1.5397498607635498
epoch 27, loss 1.5381827354431152
epoch 28, loss 1.5366607904434204
epoch 29, loss 1.535181164741516

epoch 240, loss 1.3842957019805908
epoch 241, loss 1.3838005065917969
epoch 242, loss 1.383306622505188
epoch 243, loss 1.3828155994415283
epoch 244, loss 1.3823246955871582
epoch 245, loss 1.3818360567092896
epoch 246, loss 1.3813486099243164
epoch 247, loss 1.3808623552322388
epoch 248, loss 1.3803772926330566
epoch 249, loss 1.379894495010376
epoch 250, loss 1.3794111013412476
epoch 251, loss 1.3789308071136475
epoch 252, loss 1.3784527778625488
epoch 253, loss 1.3779743909835815
epoch 254, loss 1.377497911453247
epoch 255, loss 1.3770239353179932
epoch 256, loss 1.3765506744384766
epoch 257, loss 1.376078724861145
epoch 258, loss 1.3756080865859985
epoch 259, loss 1.375139832496643
epoch 260, loss 1.3746720552444458
epoch 261, loss 1.3742057085037231
epoch 262, loss 1.3737410306930542
epoch 263, loss 1.3732789754867554
epoch 264, loss 1.3728152513504028
epoch 265, loss 1.3723548650741577
epoch 266, loss 1.3718955516815186
epoch 267, loss 1.3714380264282227
epoch 268, loss 1.3709814

epoch 476, loss 1.2990639209747314
epoch 477, loss 1.298802375793457
epoch 478, loss 1.2985414266586304
epoch 479, loss 1.2982810735702515
epoch 480, loss 1.2980207204818726
epoch 481, loss 1.2977612018585205
epoch 482, loss 1.2975038290023804
epoch 483, loss 1.2972455024719238
epoch 484, loss 1.2969894409179688
epoch 485, loss 1.2967325448989868
epoch 486, loss 1.2964767217636108
epoch 487, loss 1.296221375465393
epoch 488, loss 1.2959659099578857
epoch 489, loss 1.2957115173339844
epoch 490, loss 1.295458197593689
epoch 491, loss 1.2952038049697876
epoch 492, loss 1.294952630996704
epoch 493, loss 1.2946994304656982
epoch 494, loss 1.2944477796554565
epoch 495, loss 1.2941981554031372
epoch 496, loss 1.2939479351043701
epoch 497, loss 1.2936983108520508
epoch 498, loss 1.2934476137161255
epoch 499, loss 1.293200135231018
epoch 500, loss 1.2929503917694092
epoch 501, loss 1.2927049398422241
epoch 502, loss 1.2924574613571167
epoch 503, loss 1.2922108173370361
epoch 504, loss 1.2919640

epoch 712, loss 1.2502267360687256
epoch 713, loss 1.250062108039856
epoch 714, loss 1.2498964071273804
epoch 715, loss 1.2497326135635376
epoch 716, loss 1.249568223953247
epoch 717, loss 1.2494056224822998
epoch 718, loss 1.2492421865463257
epoch 719, loss 1.249079704284668
epoch 720, loss 1.2489145994186401
epoch 721, loss 1.2487531900405884
epoch 722, loss 1.248590111732483
epoch 723, loss 1.2484294176101685
epoch 724, loss 1.2482666969299316
epoch 725, loss 1.2481060028076172
epoch 726, loss 1.247943639755249
epoch 727, loss 1.247782826423645
epoch 728, loss 1.2476228475570679
epoch 729, loss 1.2474615573883057
epoch 730, loss 1.247301697731018
epoch 731, loss 1.247140884399414
epoch 732, loss 1.2469826936721802
epoch 733, loss 1.2468233108520508
epoch 734, loss 1.246663212776184
epoch 735, loss 1.2465041875839233
epoch 736, loss 1.2463467121124268
epoch 737, loss 1.2461872100830078
epoch 738, loss 1.2460294961929321
epoch 739, loss 1.2458704710006714
epoch 740, loss 1.24571430683

epoch 948, loss 1.2176178693771362
epoch 949, loss 1.2175023555755615
epoch 950, loss 1.2173842191696167
epoch 951, loss 1.2172706127166748
epoch 952, loss 1.2171529531478882
epoch 953, loss 1.2170367240905762
epoch 954, loss 1.216922402381897
epoch 955, loss 1.2168055772781372
epoch 956, loss 1.2166903018951416
epoch 957, loss 1.216575026512146
epoch 958, loss 1.2164608240127563
epoch 959, loss 1.2163448333740234
epoch 960, loss 1.216230869293213
epoch 961, loss 1.2161160707473755
epoch 962, loss 1.2160006761550903
epoch 963, loss 1.2158857583999634
epoch 964, loss 1.2157725095748901
epoch 965, loss 1.2156589031219482
epoch 966, loss 1.2155444622039795
epoch 967, loss 1.2154306173324585
epoch 968, loss 1.2153165340423584
epoch 969, loss 1.2152044773101807
epoch 970, loss 1.2150896787643433
epoch 971, loss 1.2149766683578491
epoch 972, loss 1.2148641347885132
epoch 973, loss 1.2147505283355713
epoch 974, loss 1.2146378755569458
epoch 975, loss 1.2145243883132935
epoch 976, loss 1.21441

epoch 1179, loss 1.1941275596618652
epoch 1180, loss 1.1940377950668335
epoch 1181, loss 1.1939507722854614
epoch 1182, loss 1.1938625574111938
epoch 1183, loss 1.1937752962112427
epoch 1184, loss 1.1936854124069214
epoch 1185, loss 1.1935993432998657
epoch 1186, loss 1.1935120820999146
epoch 1187, loss 1.1934235095977783
epoch 1188, loss 1.1933361291885376
epoch 1189, loss 1.1932497024536133
epoch 1190, loss 1.1931629180908203
epoch 1191, loss 1.1930758953094482
epoch 1192, loss 1.1929892301559448
epoch 1193, loss 1.1929019689559937
epoch 1194, loss 1.1928157806396484
epoch 1195, loss 1.192728877067566
epoch 1196, loss 1.192641258239746
epoch 1197, loss 1.1925560235977173
epoch 1198, loss 1.192469835281372
epoch 1199, loss 1.1923831701278687
epoch 1200, loss 1.1922974586486816
epoch 1201, loss 1.192210078239441
epoch 1202, loss 1.1921249628067017
epoch 1203, loss 1.1920403242111206
epoch 1204, loss 1.1919535398483276
epoch 1205, loss 1.1918678283691406
epoch 1206, loss 1.1917834281921

epoch 1408, loss 1.1762577295303345
epoch 1409, loss 1.1761897802352905
epoch 1410, loss 1.1761208772659302
epoch 1411, loss 1.1760505437850952
epoch 1412, loss 1.1759822368621826
epoch 1413, loss 1.1759124994277954
epoch 1414, loss 1.1758427619934082
epoch 1415, loss 1.1757742166519165
epoch 1416, loss 1.1757060289382935
epoch 1417, loss 1.175635814666748
epoch 1418, loss 1.1755670309066772
epoch 1419, loss 1.1755001544952393
epoch 1420, loss 1.1754300594329834
epoch 1421, loss 1.1753607988357544
epoch 1422, loss 1.1752933263778687
epoch 1423, loss 1.1752253770828247
epoch 1424, loss 1.1751552820205688
epoch 1425, loss 1.1750872135162354
epoch 1426, loss 1.1750192642211914
epoch 1427, loss 1.1749504804611206
epoch 1428, loss 1.1748825311660767
epoch 1429, loss 1.1748141050338745
epoch 1430, loss 1.1747475862503052
epoch 1431, loss 1.1746786832809448
epoch 1432, loss 1.17461097240448
epoch 1433, loss 1.1745432615280151
epoch 1434, loss 1.1744741201400757
epoch 1435, loss 1.174406528472

epoch 1637, loss 1.1618680953979492
epoch 1638, loss 1.161810278892517
epoch 1639, loss 1.161755084991455
epoch 1640, loss 1.1616984605789185
epoch 1641, loss 1.1616414785385132
epoch 1642, loss 1.161584734916687
epoch 1643, loss 1.16152822971344
epoch 1644, loss 1.161471962928772
epoch 1645, loss 1.1614158153533936
epoch 1646, loss 1.1613584756851196
epoch 1647, loss 1.1613019704818726
epoch 1648, loss 1.1612459421157837
epoch 1649, loss 1.161190152168274
epoch 1650, loss 1.1611328125
epoch 1651, loss 1.1610780954360962
epoch 1652, loss 1.1610219478607178
epoch 1653, loss 1.160966157913208
epoch 1654, loss 1.1609094142913818
epoch 1655, loss 1.1608542203903198
epoch 1656, loss 1.1607987880706787
epoch 1657, loss 1.1607435941696167
epoch 1658, loss 1.160686731338501
epoch 1659, loss 1.1606312990188599
epoch 1660, loss 1.1605758666992188
epoch 1661, loss 1.160520076751709
epoch 1662, loss 1.1604652404785156
epoch 1663, loss 1.1604108810424805
epoch 1664, loss 1.1603542566299438
epoch 16

epoch 1867, loss 1.150065302848816
epoch 1868, loss 1.1500186920166016
epoch 1869, loss 1.1499724388122559
epoch 1870, loss 1.1499263048171997
epoch 1871, loss 1.1498810052871704
epoch 1872, loss 1.1498345136642456
epoch 1873, loss 1.1497877836227417
epoch 1874, loss 1.1497411727905273
epoch 1875, loss 1.1496955156326294
epoch 1876, loss 1.1496492624282837
epoch 1877, loss 1.1496037244796753
epoch 1878, loss 1.1495577096939087
epoch 1879, loss 1.1495119333267212
epoch 1880, loss 1.149465799331665
epoch 1881, loss 1.149419903755188
epoch 1882, loss 1.1493736505508423
epoch 1883, loss 1.1493279933929443
epoch 1884, loss 1.1492819786071777
epoch 1885, loss 1.1492372751235962
epoch 1886, loss 1.14919114112854
epoch 1887, loss 1.1491451263427734
epoch 1888, loss 1.1490994691848755
epoch 1889, loss 1.1490548849105835
epoch 1890, loss 1.1490095853805542
epoch 1891, loss 1.1489630937576294
epoch 1892, loss 1.1489182710647583
epoch 1893, loss 1.1488721370697021
epoch 1894, loss 1.14882767200469

epoch 2097, loss 1.1403117179870605
epoch 2098, loss 1.1402734518051147
epoch 2099, loss 1.140235185623169
epoch 2100, loss 1.1401963233947754
epoch 2101, loss 1.1401575803756714
epoch 2102, loss 1.1401184797286987
epoch 2103, loss 1.1400789022445679
epoch 2104, loss 1.140041708946228
epoch 2105, loss 1.1400026082992554
epoch 2106, loss 1.139963984489441
epoch 2107, loss 1.1399260759353638
epoch 2108, loss 1.139887809753418
epoch 2109, loss 1.1398488283157349
epoch 2110, loss 1.1398106813430786
epoch 2111, loss 1.1397722959518433
epoch 2112, loss 1.1397329568862915
epoch 2113, loss 1.1396952867507935
epoch 2114, loss 1.139656901359558
epoch 2115, loss 1.1396191120147705
epoch 2116, loss 1.139579176902771
epoch 2117, loss 1.139541745185852
epoch 2118, loss 1.1395028829574585
epoch 2119, loss 1.1394659280776978
epoch 2120, loss 1.1394262313842773
epoch 2121, loss 1.1393886804580688
epoch 2122, loss 1.1393505334854126
epoch 2123, loss 1.1393115520477295
epoch 2124, loss 1.1392748355865479

epoch 2327, loss 1.1320585012435913
epoch 2328, loss 1.1320247650146484
epoch 2329, loss 1.131991982460022
epoch 2330, loss 1.1319584846496582
epoch 2331, loss 1.1319257020950317
epoch 2332, loss 1.131892204284668
epoch 2333, loss 1.1318600177764893
epoch 2334, loss 1.1318268775939941
epoch 2335, loss 1.1317921876907349
epoch 2336, loss 1.131760597229004
epoch 2337, loss 1.131726622581482
epoch 2338, loss 1.1316947937011719
epoch 2339, loss 1.1316615343093872
epoch 2340, loss 1.1316282749176025
epoch 2341, loss 1.131595492362976
epoch 2342, loss 1.1315627098083496
epoch 2343, loss 1.131529688835144
epoch 2344, loss 1.1314977407455444
epoch 2345, loss 1.1314637660980225
epoch 2346, loss 1.1314311027526855
epoch 2347, loss 1.1313989162445068
epoch 2348, loss 1.1313655376434326
epoch 2349, loss 1.1313337087631226
epoch 2350, loss 1.1313005685806274
epoch 2351, loss 1.1312675476074219
epoch 2352, loss 1.1312353610992432
epoch 2353, loss 1.1312023401260376
epoch 2354, loss 1.131169319152832

epoch 2557, loss 1.1249395608901978
epoch 2558, loss 1.1249104738235474
epoch 2559, loss 1.1248818635940552
epoch 2560, loss 1.1248524188995361
epoch 2561, loss 1.124825119972229
epoch 2562, loss 1.1247953176498413
epoch 2563, loss 1.1247665882110596
epoch 2564, loss 1.124738335609436
epoch 2565, loss 1.1247094869613647
epoch 2566, loss 1.1246803998947144
epoch 2567, loss 1.1246521472930908
epoch 2568, loss 1.1246235370635986
epoch 2569, loss 1.1245944499969482
epoch 2570, loss 1.1245654821395874
epoch 2571, loss 1.1245371103286743
epoch 2572, loss 1.1245087385177612
epoch 2573, loss 1.1244800090789795
epoch 2574, loss 1.1244522333145142
epoch 2575, loss 1.1244220733642578
epoch 2576, loss 1.1243940591812134
epoch 2577, loss 1.1243666410446167
epoch 2578, loss 1.1243371963500977
epoch 2579, loss 1.1243090629577637
epoch 2580, loss 1.1242808103561401
epoch 2581, loss 1.1242520809173584
epoch 2582, loss 1.1242228746414185
epoch 2583, loss 1.124194860458374
epoch 2584, loss 1.124165773391

epoch 2786, loss 1.1187357902526855
epoch 2787, loss 1.1187108755111694
epoch 2788, loss 1.1186842918395996
epoch 2789, loss 1.1186602115631104
epoch 2790, loss 1.118633508682251
epoch 2791, loss 1.1186084747314453
epoch 2792, loss 1.1185822486877441
epoch 2793, loss 1.1185580492019653
epoch 2794, loss 1.1185319423675537
epoch 2795, loss 1.1185070276260376
epoch 2796, loss 1.1184816360473633
epoch 2797, loss 1.1184560060501099
epoch 2798, loss 1.1184308528900146
epoch 2799, loss 1.1184059381484985
epoch 2800, loss 1.1183803081512451
epoch 2801, loss 1.1183555126190186
epoch 2802, loss 1.1183305978775024
epoch 2803, loss 1.1183054447174072
epoch 2804, loss 1.1182799339294434
epoch 2805, loss 1.1182551383972168
epoch 2806, loss 1.1182290315628052
epoch 2807, loss 1.1182054281234741
epoch 2808, loss 1.1181780099868774
epoch 2809, loss 1.1181544065475464
epoch 2810, loss 1.1181282997131348
epoch 2811, loss 1.1181024312973022
epoch 2812, loss 1.1180777549743652
epoch 2813, loss 1.1180536746

epoch 3016, loss 1.1132148504257202
epoch 3017, loss 1.1131922006607056
epoch 3018, loss 1.1131701469421387
epoch 3019, loss 1.113147497177124
epoch 3020, loss 1.113126277923584
epoch 3021, loss 1.113102674484253
epoch 3022, loss 1.113079309463501
epoch 3023, loss 1.113057017326355
epoch 3024, loss 1.1130348443984985
epoch 3025, loss 1.1130123138427734
epoch 3026, loss 1.1129889488220215
epoch 3027, loss 1.1129674911499023
epoch 3028, loss 1.1129440069198608
epoch 3029, loss 1.1129223108291626
epoch 3030, loss 1.1129000186920166
epoch 3031, loss 1.1128768920898438
epoch 3032, loss 1.1128554344177246
epoch 3033, loss 1.1128325462341309
epoch 3034, loss 1.1128103733062744
epoch 3035, loss 1.1127877235412598
epoch 3036, loss 1.1127643585205078
epoch 3037, loss 1.112742304801941
epoch 3038, loss 1.112718939781189
epoch 3039, loss 1.112697958946228
epoch 3040, loss 1.112675428390503
epoch 3041, loss 1.1126528978347778
epoch 3042, loss 1.1126302480697632
epoch 3043, loss 1.1126068830490112
e

epoch 3246, loss 1.108278751373291
epoch 3247, loss 1.108258605003357
epoch 3248, loss 1.1082390546798706
epoch 3249, loss 1.108217716217041
epoch 3250, loss 1.1081981658935547
epoch 3251, loss 1.1081784963607788
epoch 3252, loss 1.108157753944397
epoch 3253, loss 1.108136773109436
epoch 3254, loss 1.1081163883209229
epoch 3255, loss 1.1080964803695679
epoch 3256, loss 1.108076572418213
epoch 3257, loss 1.1080560684204102
epoch 3258, loss 1.1080353260040283
epoch 3259, loss 1.1080162525177002
epoch 3260, loss 1.1079949140548706
epoch 3261, loss 1.1079747676849365
epoch 3262, loss 1.1079541444778442
epoch 3263, loss 1.1079341173171997
epoch 3264, loss 1.1079151630401611
epoch 3265, loss 1.1078938245773315
epoch 3266, loss 1.107873558998108
epoch 3267, loss 1.107853889465332
epoch 3268, loss 1.1078332662582397
epoch 3269, loss 1.107814073562622
epoch 3270, loss 1.1077930927276611
epoch 3271, loss 1.1077728271484375
epoch 3272, loss 1.107752799987793
epoch 3273, loss 1.107732892036438
epo

epoch 3476, loss 1.1038283109664917
epoch 3477, loss 1.1038084030151367
epoch 3478, loss 1.1037912368774414
epoch 3479, loss 1.1037724018096924
epoch 3480, loss 1.1037538051605225
epoch 3481, loss 1.1037354469299316
epoch 3482, loss 1.1037169694900513
epoch 3483, loss 1.1036992073059082
epoch 3484, loss 1.1036794185638428
epoch 3485, loss 1.1036617755889893
epoch 3486, loss 1.1036432981491089
epoch 3487, loss 1.1036252975463867
epoch 3488, loss 1.1036070585250854
epoch 3489, loss 1.1035887002944946
epoch 3490, loss 1.1035692691802979
epoch 3491, loss 1.1035529375076294
epoch 3492, loss 1.1035338640213013
epoch 3493, loss 1.1035162210464478
epoch 3494, loss 1.1034976243972778
epoch 3495, loss 1.103479266166687
epoch 3496, loss 1.103460669517517
epoch 3497, loss 1.1034419536590576
epoch 3498, loss 1.1034250259399414
epoch 3499, loss 1.1034061908721924
epoch 3500, loss 1.1033873558044434
epoch 3501, loss 1.103368878364563
epoch 3502, loss 1.1033508777618408
epoch 3503, loss 1.103332757949

epoch 3706, loss 1.0997846126556396
epoch 3707, loss 1.099768042564392
epoch 3708, loss 1.099751591682434
epoch 3709, loss 1.0997344255447388
epoch 3710, loss 1.0997167825698853
epoch 3711, loss 1.0997015237808228
epoch 3712, loss 1.0996845960617065
epoch 3713, loss 1.0996671915054321
epoch 3714, loss 1.0996507406234741
epoch 3715, loss 1.0996344089508057
epoch 3716, loss 1.0996168851852417
epoch 3717, loss 1.0995999574661255
epoch 3718, loss 1.099584937095642
epoch 3719, loss 1.0995657444000244
epoch 3720, loss 1.099550485610962
epoch 3721, loss 1.0995334386825562
epoch 3722, loss 1.0995172262191772
epoch 3723, loss 1.0995010137557983
epoch 3724, loss 1.0994843244552612
epoch 3725, loss 1.099467158317566
epoch 3726, loss 1.0994502305984497
epoch 3727, loss 1.0994333028793335
epoch 3728, loss 1.0994174480438232
epoch 3729, loss 1.0994008779525757
epoch 3730, loss 1.0993831157684326
epoch 3731, loss 1.0993679761886597
epoch 3732, loss 1.0993505716323853
epoch 3733, loss 1.09933435916900

epoch 3936, loss 1.0960904359817505
epoch 3937, loss 1.096075415611267
epoch 3938, loss 1.0960599184036255
epoch 3939, loss 1.0960443019866943
epoch 3940, loss 1.0960291624069214
epoch 3941, loss 1.0960140228271484
epoch 3942, loss 1.0959985256195068
epoch 3943, loss 1.0959839820861816
epoch 3944, loss 1.095967411994934
epoch 3945, loss 1.0959525108337402
epoch 3946, loss 1.0959371328353882
epoch 3947, loss 1.0959211587905884
epoch 3948, loss 1.095906138420105
epoch 3949, loss 1.0958917140960693
epoch 3950, loss 1.09587562084198
epoch 3951, loss 1.0958609580993652
epoch 3952, loss 1.095845341682434
epoch 3953, loss 1.0958302021026611
epoch 3954, loss 1.0958141088485718
epoch 3955, loss 1.0957989692687988
epoch 3956, loss 1.0957847833633423
epoch 3957, loss 1.0957698822021484
epoch 3958, loss 1.0957533121109009
epoch 3959, loss 1.0957378149032593
epoch 3960, loss 1.0957231521606445
epoch 3961, loss 1.095707654953003
epoch 3962, loss 1.0956940650939941
epoch 3963, loss 1.095677375793457


epoch 4166, loss 1.0926971435546875
epoch 4167, loss 1.0926828384399414
epoch 4168, loss 1.0926668643951416
epoch 4169, loss 1.092653751373291
epoch 4170, loss 1.0926400423049927
epoch 4171, loss 1.092625617980957
epoch 4172, loss 1.0926117897033691
epoch 4173, loss 1.0925973653793335
epoch 4174, loss 1.0925829410552979
epoch 4175, loss 1.092569351196289
epoch 4176, loss 1.0925554037094116
epoch 4177, loss 1.0925416946411133
epoch 4178, loss 1.092527151107788
epoch 4179, loss 1.0925142765045166
epoch 4180, loss 1.0924993753433228
epoch 4181, loss 1.0924850702285767
epoch 4182, loss 1.0924702882766724
epoch 4183, loss 1.0924568176269531
epoch 4184, loss 1.0924434661865234
epoch 4185, loss 1.0924286842346191
epoch 4186, loss 1.0924155712127686
epoch 4187, loss 1.092401146888733
epoch 4188, loss 1.0923867225646973
epoch 4189, loss 1.0923715829849243
epoch 4190, loss 1.092359185218811
epoch 4191, loss 1.0923436880111694
epoch 4192, loss 1.0923303365707397
epoch 4193, loss 1.092317104339599

In [None]:
predicted = model(torch_X_test.float()).detach().numpy()
true_test = torch_y_test.detach().numpy()

In [None]:
len(predicted), len(true_test)

In [None]:
sum(1 if x[0] == x[1] else 0 for x in list(zip([np.argmax(x) for x in predicted], true_test))) / len(true_test)

In [None]:
torch_X_test[0].shape

In [None]:
torch.save({'state_dict': model.state_dict()}, 'checkpoint.pth.tar')