In [20]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
torch.manual_seed(1)

<torch._C.Generator at 0x7f907072d130>

In [182]:
training_data = np.zeros((2,4))

for x in range(0, 20):
    training_data = np.append(training_data, np.array([[x, x, x+1, np.cos(x+1)], [x+1, x+1, x+2, np.cos(x+1)]]),axis=1)


In [184]:
training_data = training_data.reshape(-1,2,4)
print (training_data)
print (training_data.shape)

[[[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  1.00000000e+00  5.40302306e-01]]

 [[ 1.00000000e+00  1.00000000e+00  2.00000000e+00 -4.16146837e-01]
  [ 2.00000000e+00  2.00000000e+00  3.00000000e+00 -9.89992497e-01]]

 [[ 3.00000000e+00  3.00000000e+00  4.00000000e+00 -6.53643621e-01]
  [ 4.00000000e+00  4.00000000e+00  5.00000000e+00  2.83662185e-01]]

 [[ 5.00000000e+00  5.00000000e+00  6.00000000e+00  9.60170287e-01]
  [ 6.00000000e+00  6.00000000e+00  7.00000000e+00  7.53902254e-01]]

 [[ 7.00000000e+00  7.00000000e+00  8.00000000e+00 -1.45500034e-01]
  [ 8.00000000e+00  8.00000000e+00  9.00000000e+00 -9.11130262e-01]]

 [[ 9.00000000e+00  9.00000000e+00  1.00000000e+01 -8.39071529e-01]
  [ 1.00000000e+01  1.00000000e+01  1.10000000e+01  4.42569799e-03]]

 [[ 1.10000000e+01  1.10000000e+01  1.20000000e+01  8.43853959e-01]
  [ 1.20000000e+01  1.20000000e+01  1.30000000e+01  9.07446781e-01]]

 [[ 1.30000000e+01  1.30000000e+01

In [4]:
lstm = nn.LSTM(3, 3)  # Input dim is 3, output dim is 3
inputs = [torch.randn(1, 3) for _ in range(5)]  # make a sequence of length 5

# initialize the hidden state.
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))
for i in inputs:
    # Step through the sequence one element at a time.
    # after each step, hidden contains the hidden state.
    out, hidden = lstm(i.view(1, 1, -1), hidden)

# alternatively, we can do the entire sequence all at once.
# the first value returned by LSTM is all of the hidden states throughout
# the sequence. the second is just the most recent hidden state
# (compare the last slice of "out" with "hidden" below, they are the same)
# The reason for this is that:
# "out" will give you access to all hidden states in the sequence
# "hidden" will allow you to continue the sequence and backpropagate,
# by passing it as an argument  to the lstm at a later time
# Add the extra 2nd dimension
inputs = torch.cat(inputs).view(len(inputs), 1, -1)
hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))  # clean out hidden state
out, hidden = lstm(inputs, hidden)
print(out)
print(hidden)

tensor([[[ 0.2490, -0.0525,  0.3253]],

        [[ 0.1655, -0.0304,  0.3348]],

        [[-0.1104, -0.1085,  0.7568]],

        [[-0.0148, -0.0855,  0.4162]],

        [[ 0.0703, -0.1089,  0.2071]]], grad_fn=<StackBackward>)
(tensor([[[ 0.0703, -0.1089,  0.2071]]], grad_fn=<StackBackward>), tensor([[[ 0.2099, -0.3541,  0.9947]]], grad_fn=<StackBackward>))


In [192]:
class LSTMTagger(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim

        self.word_embeddings = nn.Linear(vocab_size, embedding_dim)

        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
#         tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_space

In [193]:
model = LSTMTagger(10, 30, 4, 4)

In [194]:
loss_function = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

In [195]:
for epoch in range(300):  # again, normally you would NOT do 300 epochs, it is toy data
    for sentence, tags in training_data:
        # Step 1. Remember that Pytorch accumulates gradients.
        # We need to clear them out before each instance
        model.zero_grad()

        # Step 2. Get our inputs ready for the network, that is, turn them into
        # Tensors of word indices.
        sentence = torch.Tensor(sentence)
        tags = torch.Tensor(tags)
#         print ("\tinput ", sentence)
#         print ("\tgt ", tags)

        # Step 3. Run our forward pass.
        
        tag_scores = model(sentence[None])

        # Step 4. Compute the loss, gradients, and update the parameters by
        #  calling optimizer.step()
        loss = loss_function(tags[None],tag_scores)
        loss.backward()
        print (loss.item())
        optimizer.step()

# See what the scores are after training
# with torch.no_grad():
#     inputs = prepare_sequence(training_data[0][0], word_to_ix)
#     tag_scores = model(inputs)

#     # The sentence is "the dog ate the apple".  i,j corresponds to score for tag j
#     # for word i. The predicted tag is the maximum scoring tag.
#     # Here, we can see the predicted sequence below is 0 1 2 0 1
#     # since 0 is index of the maximum value of row 1,
#     # 1 is the index of maximum value of row 2, etc.
#     # Which is DET NOUN VERB DET NOUN, the correct sequence!
#     print(tag_scores)

0.4388449192047119
4.53697395324707
14.50612735748291
30.632246017456055
52.2845344543457
79.46220397949219
112.60518646240234
150.55796813964844
193.41403198242188
242.257080078125
0.17097727954387665
4.063918590545654
13.458547592163086
28.450071334838867
48.36225128173828
74.52374267578125
106.1427993774414
143.10214233398438
186.25827026367188
234.4187469482422
287.8338928222656
0.3919041156768799
3.898266315460205
12.322640419006348
25.917055130004883
44.94999694824219
69.48167419433594
99.79977416992188
135.68357849121094
177.15838623046875
224.16871643066406
1.0203379392623901
3.385488748550415
11.348526954650879
24.02127456665039
41.535888671875
65.4464111328125
94.59959411621094
129.22671508789062
170.12579345703125
215.93299865722656
267.1986083984375
0.3605233430862427
3.2560253143310547
10.004737854003906
21.11387062072754
37.80496597290039
59.90294647216797
87.76136779785156
121.4417724609375
160.65087890625
205.485595703125
2.6904850006103516
2.6775293350219727
8.96600055

0.3149908185005188
0.7970337867736816
2.6944284439086914
0.6971984505653381
0.18245863914489746
2.767364025115967
9.785754203796387
22.325916290283203
40.91630172729492
91.41619110107422
0.18769203126430511
0.49072104692459106
2.1724612712860107
0.6056065559387207
0.3941621780395508
2.8361172676086426
9.470135688781738
22.326438903808594
40.503292083740234
64.55146789550781
0.5891233682632446
0.3177763819694519
0.7513906955718994
2.759830951690674
0.7450255751609802
0.14345209300518036
2.457249402999878
9.019888877868652
21.035476684570312
39.09041213989258
94.22077178955078
0.18027238547801971
0.4508402943611145
2.1984925270080566
0.6614905595779419
0.35745489597320557
2.5446577072143555
8.722249031066895
21.078536987304688
38.736732482910156
62.29167556762695
0.5849721431732178
0.3212854564189911
0.7127202153205872
2.8283448219299316
0.79433274269104
0.11229103803634644
2.190638542175293
8.332637786865234
19.853456497192383
37.40160369873047
96.91593933105469
0.17411495745182037
0.41

1.4974033832550049
1.0665079355239868
0.1912039965391159
0.48081105947494507
2.0856096744537354
8.005053520202637
18.485315322875977
34.984161376953125
0.41820237040519714
0.6150044202804565
0.13875049352645874
2.560913562774658
1.233542561531067
0.007686689030379057
0.40600132942199707
2.2281877994537354
7.41314172744751
17.927217483520508
138.66427612304688
0.398188054561615
0.1954638957977295
1.461035132408142
1.0622289180755615
0.1900085210800171
0.4630632996559143
2.002089262008667
7.785253524780273
18.102312088012695
34.43170166015625
0.41352415084838867
0.6270105838775635
0.13496145606040955
2.5274658203125
1.2398114204406738
0.008631991222500801
0.3946952819824219
2.1477601528167725
7.201992034912109
17.55719757080078
139.75071716308594
0.4160032868385315
0.2038198709487915
1.4249799251556396
1.0573370456695557
0.1888723224401474
0.4469451308250427
1.9243741035461426
7.577167987823486
17.737693786621094
33.9039421081543
0.4090512692928314
0.6385906338691711
0.13176369667053223


0.3487503230571747
1.1100645065307617
4.367530822753906
12.411163330078125
157.01947021484375
0.7712250351905823
0.3563379943370819
0.764594316482544
0.8869372010231018
0.1982358694076538
0.3176960051059723
1.0223712921142578
4.781576156616211
12.633759498596191
26.299306869506836
0.3465079963207245
0.7743287086486816
0.12404751032590866
1.6749399900436401
1.4182963371276855
0.05413433536887169
0.3515847325325012
1.0839531421661377
4.2985663414001465
12.281899452209473
157.51876831054688
0.7822185158729553
0.3598180413246155
0.744236171245575
0.879845142364502
0.2009119689464569
0.3171863555908203
1.0036845207214355
4.713726997375488
12.504545211791992
26.10028648376465
0.34495335817337036
0.7753487229347229
0.12437482178211212
1.6424273252487183
1.4277902841567993
0.057245973497629166
0.3545967936515808
1.0587828159332275
4.232424736022949
12.157739639282227
158.0023193359375
0.7927684783935547
0.36305901408195496
0.7245493531227112
0.8728817105293274
0.20380383729934692
0.31682032346

0.7409186959266663
0.1267831027507782
1.0612496137619019
1.6372296810150146
0.15513616800308228
0.42397308349609375
0.7035620212554932
3.2971882820129395
10.369344711303711
165.47817993164062
0.9174360036849976
0.3911493420600891
0.4378068447113037
0.7463400363922119
0.3078569173812866
0.31127557158470154
0.7585278749465942
3.742140769958496
10.592887878417969
23.07971954345703
0.32253023982048035
0.7376149296760559
0.1266893744468689
1.0412445068359375
1.645431637763977
0.16061042249202728
0.42575979232788086
0.6948534846305847
3.271810293197632
10.318984985351562
165.70596313476562
0.9187872409820557
0.3911468982696533
0.43020355701446533
0.7413243651390076
0.31348299980163574
0.31031203269958496
0.7535445094108582
3.717942953109741
10.542698860168457
22.9974365234375
0.32199496030807495
0.7342196106910706
0.12658050656318665
1.0217818021774292
1.6534340381622314
0.16612356901168823
0.42735904455184937
0.6866337656974792
3.2474746704101562
10.270506858825684
165.92637634277344
0.9198

21.748798370361328
0.31616756319999695
0.6470267176628113
0.12124443054199219
0.7108351588249207
1.776257038116455
0.2825544476509094
0.42262864112854004
0.5911204218864441
2.903062105178833
9.551230430603027
169.347900390625
0.885684609413147
0.3788204789161682
0.3330138921737671
0.633784830570221
0.4261682331562042
0.27720531821250916
0.6929634809494019
3.363316535949707
9.776863098144531
21.710662841796875
0.316114604473114
0.6428831219673157
0.12091484665870667
0.701020359992981
1.7796125411987305
0.28712910413742065
0.4212174713611603
0.5892089009284973
2.893369674682617
9.529728889465332
169.45465087890625
0.8823559284210205
0.37796738743782043
0.33117109537124634
0.6295418739318848
0.42995545268058777
0.2757119834423065
0.6915057897567749
3.3537631034851074
9.755340576171875
21.673694610595703
0.3160756528377533
0.6387497186660767
0.12058216333389282
0.6914921998977661
1.7828012704849243
0.29161733388900757
0.419756680727005
0.5874140858650208
2.8840107917785645
9.50889587402343

9.17624568939209
171.22430419921875
0.7898356914520264
0.35669857263565063
0.31402507424354553
0.5463434457778931
0.49095186591148376
0.2480512410402298
0.6608215570449829
3.198090076446533
9.400900840759277
21.066001892089844
0.3179590106010437
0.5499347448348999
0.1129886731505394
0.5320363640785217
1.8204429149627686
0.36938241124153137
0.38136452436447144
0.5652420520782471
2.734083414077759
9.165295600891113
171.27850341796875
0.7854925394058228
0.35577139258384705
0.31397396326065063
0.5433752536773682
0.4926089644432068
0.24715444445610046
0.659419059753418
3.1932458877563477
9.389925003051758
21.047496795654297
0.3181244134902954
0.5464311242103577
0.11269113421440125
0.5270085334777832
1.8208746910095215
0.3716861307621002
0.37973546981811523
0.5647217035293579
2.729494333267212
9.154607772827148
171.33128356933594
0.7811444997787476
0.3548465371131897
0.3139568269252777
0.5404603481292725
0.4941953122615814
0.2462812066078186
0.6580038070678711
3.188514232635498
9.37921333312

0.624263346195221
3.107473611831665
9.20162582397461
20.742351531982422
0.3225348889827728
0.4773486256599426
0.10732804983854294
0.43917006254196167
1.814204216003418
0.402513325214386
0.34822943806648254
0.5553267598152161
2.6486005783081055
8.969541549682617
172.2043914794922
0.6869171857833862
0.33538582921028137
0.31902119517326355
0.4888392984867096
0.5136963129043579
0.2322794795036316
0.6224472522735596
3.104238510131836
9.1948881149292
20.73210334777832
0.3227561414241791
0.47454798221588135
0.1071396917104721
0.43595612049102783
1.8133063316345215
0.4030866026878357
0.3470018208026886
0.5549179315567017
2.6455445289611816
8.962821960449219
172.23373413085938
0.6827796697616577
0.33453693985939026
0.31939709186553955
0.4870138168334961
0.5139778256416321
0.23184412717819214
0.6206125020980835
3.1010403633117676
9.188264846801758
20.72211265563965
0.32297828793525696
0.47177422046661377
0.1069558635354042
0.4327927827835083
1.81236732006073
0.4035930633544922
0.3457903265953064

20.558725357055664
0.3277342915534973
0.4190727174282074
0.10406429320573807
0.37549930810928345
1.7842159271240234
0.4000920057296753
0.32345250248908997
0.5450869798660278
2.5843636989593506
8.836624145507812
172.730224609375
0.5964116454124451
0.3160589039325714
0.32856330275535583
0.45636460185050964
0.5097634792327881
0.22540289163589478
0.5751064419746399
3.0373549461364746
9.06545352935791
20.552749633789062
0.3279612958431244
0.41679805517196655
0.10396652668714523
0.3731115460395813
1.7825251817703247
0.3993045389652252
0.3224978446960449
0.5445965528488159
2.581739902496338
8.831622123718262
172.74732971191406
0.592707633972168
0.31521251797676086
0.32899340987205505
0.45535552501678467
0.5091490745544434
0.22522568702697754
0.5728060007095337
3.0346505641937256
9.060657501220703
20.546903610229492
0.3281881809234619
0.4145422577857971
0.10387156903743744
0.370747447013855
1.78080415725708
0.3984658122062683
0.3215499520301819
0.5441023707389832
2.579120397567749
8.8266658782

2.5244455337524414
8.73155403137207
173.0402374267578
0.5160529613494873
0.29592281579971313
0.3382716476917267
0.4397578537464142
0.487789124250412
0.22274388372898102
0.5176535248756409
2.9764626026153564
8.965659141540527
20.44651985168457
0.3331257104873657
0.36915379762649536
0.10234583169221878
0.32365208864212036
1.7363160848617554
0.36841341853141785
0.3010975122451782
0.532725989818573
2.52182674407959
8.727380752563477
173.05030822753906
0.5127800703048706
0.295005738735199
0.33868274092674255
0.43930041790008545
0.4864947199821472
0.22266942262649536
0.5149630308151245
2.9738385677337646
8.961721420288086
20.44304847717285
0.333347886800766
0.36726170778274536
0.10229241847991943
0.3216937184333801
1.7340129613876343
0.3665929138660431
0.30014121532440186
0.5322045087814331
2.519197702407837
8.723235130310059
173.06015014648438
0.5095247030258179
0.2940843403339386
0.3390929102897644
0.43886011838912964
0.48517611622810364
0.2225954681634903
0.5122590661048889
2.971216678619

0.2207007259130478
0.4532332718372345
2.916057586669922
8.881511688232422
20.385637283325195
0.3382101058959961
0.32876530289649963
0.10107006132602692
0.2816906273365021
1.6777153015136719
0.31995639204978943
0.27720656991004944
0.5214444398880005
2.460427761077881
8.639106750488281
173.22080993652344
0.44208428263664246
0.2727237343788147
0.3480117917060852
0.43238747119903564
0.4504338800907135
0.22056889533996582
0.450359582901001
2.9134297370910645
8.878143310546875
20.383811950683594
0.33843064308166504
0.32714736461639404
0.10100333392620087
0.28000539541244507
1.6749086380004883
0.31762999296188354
0.2760710120201111
0.521003782749176
2.4577109813690186
8.635608673095703
173.2257537841797
0.4391956329345703
0.2717018723487854
0.34841716289520264
0.43218469619750977
0.4486280679702759
0.22043095529079437
0.4474833309650421
2.9108030796051025
8.874794960021973
20.38205909729004
0.33865100145339966
0.3255401849746704
0.10093499720096588
0.27833205461502075
1.6720831394195557
0.315

0.3432484269142151
0.2942403554916382
0.09899849444627762
0.24600492417812347
1.6080251932144165
0.26550379395484924
0.24977919459342957
0.5129586458206177
2.3971991539001465
8.566244125366211
173.28684997558594
0.3790168762207031
0.24821630120277405
0.35744306445121765
0.4275360107421875
0.40562722086906433
0.21545597910881042
0.3849167823791504
2.8533976078033447
8.807300567626953
20.358856201171875
0.34346410632133484
0.2928657829761505
0.09887849539518356
0.24461007118225098
1.6047558784484863
0.2631640136241913
0.24856270849704742
0.5126768946647644
2.3944172859191895
8.563448905944824
173.2875518798828
0.37642329931259155
0.24710848927497864
0.35785654187202454
0.4272719621658325
0.40356820821762085
0.21512742340564728
0.38215357065200806
2.850818395614624
8.804520606994629
20.35847282409668
0.34367918968200684
0.29150182008743286
0.09875573962926865
0.24322891235351562
1.601464867591858
0.2608341872692108
0.247349351644516
0.5124025940895081
2.3916399478912354
8.560685157775879


0.2177196443080902
1.527978539466858
0.2144203931093216
0.22308886051177979
0.5088485479354858
2.332571029663086
8.509644508361816
173.26077270507812
0.3221375048160553
0.22219769656658173
0.36677801609039307
0.4188583195209503
0.357865571975708
0.205505833029747
0.32481294870376587
2.795320987701416
8.74985122680664
20.363414764404297
0.3481594920158386
0.2642436623573303
0.09535570442676544
0.2166803926229477
1.5242758989334106
0.21235248446464539
0.22202646732330322
0.5087976455688477
2.329716444015503
8.507532119750977
173.25775146484375
0.31979089975357056
0.22105099260807037
0.3671657145023346
0.41834011673927307
0.35581743717193604
0.2049703299999237
0.32239270210266113
2.792860269546509
8.74765682220459
20.364200592041016
0.34834611415863037
0.26313287019729614
0.09517060965299606
0.21565790474414825
1.5205503702163696
0.2103007435798645
0.22097475826740265
0.5087598562240601
2.3268558979034424
8.505448341369629
173.2545928955078
0.3174546957015991
0.21990475058555603
0.3675514

In [191]:
imput = torch.Tensor([0.,0.,1., 0],)
model(imput[None])

tensor([[ 1.2011,  1.2104,  1.8353, -0.0036]], grad_fn=<AddmmBackward>)

In [240]:
import torch
import torch.nn as nn

input_dim = 2
hidden_dim = 4
n_layers = 2

lstm_layer = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)

batch_size = 1
seq_len = 1

inp = torch.randn(batch_size, seq_len, input_dim) # batch size, sequence length, input dimension).

hidden_state = torch.randn(n_layers, batch_size, hidden_dim)
cell_state = torch.randn(n_layers, batch_size, hidden_dim)
hidden = (hidden_state, cell_state)

out, hidden = lstm_layer(inp, hidden)
print("Output shape: ", out.shape)
print("Hidden: ", hidden)



Output shape:  torch.Size([1, 1, 4])
Hidden:  (tensor([[[ 0.0171,  0.1125,  0.1405, -0.3655]],

        [[-0.2261, -0.1112,  0.3362,  0.0356]]], grad_fn=<StackBackward>), tensor([[[ 0.0340,  0.1782,  0.5202, -0.7167]],

        [[-0.5972, -0.2874,  0.9275,  0.1313]]], grad_fn=<StackBackward>))


In [241]:
seq_len = 3
inp = torch.randn(batch_size, seq_len, input_dim)
out, hidden = lstm_layer(inp, hidden)
print(out.shape)

torch.Size([1, 3, 4])


In [242]:
inp = torch.tensor([[0.1,0.2],[0.1,0.2],[0.1,0.2]] ).reshape(1,3,2)
hidden_state = torch.randn(n_layers, batch_size, hidden_dim)
cell_state = torch.randn(n_layers, batch_size, hidden_dim)
hidden = (hidden_state, cell_state)
# out, hidden = lstm_layer(inp, hidden)
out, hidden = lstm_layer(inp, hidden)
print ("out = ", out)
print (hidden)

out =  tensor([[[ 0.1854,  0.1087,  0.3529,  0.2931],
         [ 0.0575,  0.1505,  0.2047,  0.2179],
         [-0.0173,  0.1635,  0.1669,  0.1857]]], grad_fn=<TransposeBackward0>)
(tensor([[[-0.1179,  0.1196, -0.0678, -0.0368]],

        [[-0.0173,  0.1635,  0.1669,  0.1857]]], grad_fn=<StackBackward>), tensor([[[-0.3541,  0.2019, -0.1455, -0.0615]],

        [[-0.0420,  0.4253,  0.4563,  0.7389]]], grad_fn=<StackBackward>))
