In [147]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

import lightning as L
from torch.utils.data import DataLoader, TensorDataset

In [148]:
class LSTMbyHand(L.LightningModule):
    def __init__(self):
        super().__init__()
        
        # use normal distribution to generate random weights and biases
        mean = torch.tensor(0.0)
        std = torch.tensor(1.0)

        # short term memory weights and biases
        self.wlr1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wlr2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.blr1 = nn.Parameter(torch.tensor(0.), requires_grad=True)

        # candidate long term memory/cell state weights and biases
        self.wpr1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wpr2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bpr1 = nn.Parameter(torch.tensor(0.), requires_grad=True)

        # potential memory weights and biases
        self.wp1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wp2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bp1 = nn.Parameter(torch.tensor(0.), requires_grad=True)

        # output weights and biases
        self.wo1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wo2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bo1 = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def input_format(self, wsh, win, b):
        data = {
            "input_weights": int(win.item()*1000),
            "gate_biases":int(b.item()*1000000),
            "short_weights":int(wsh.item()*1000)
        }
        return data
    
    def lstm_unit(self, input_value, long_memory, short_memory):
        # determine how much of long term memory to remember - forget gate
        long_remember_percent = torch.sigmoid((short_memory * self.wlr1)
                                              + (input_value * self.wlr2) +
                                              self.blr1)
        print(f"forget = {(short_memory * self.wlr1)+ (input_value * self.wlr2) +self.blr1}")
        # calculate how much of potential memory should be remembered - Candidate gate
        potential_remember_percent = torch.sigmoid((short_memory * self.wpr1)
                                              + (input_value * self.wpr2) +
                                              self.bpr1)
        print(f"candidate = {(short_memory * self.wpr1)  + (input_value * self.wpr2) + self.bpr1}")
        # calculate candidate memory - Input gate
        potential_memory = torch.tanh((short_memory * self.wp1)
                                              + (input_value * self.wp2) +
                                              self.bp1)
        print(f"input = {(short_memory * self.wp1)+ (input_value * self.wp2) +self.bp1}")
        # update long term memory
        updated_long_memory = ((long_memory * long_remember_percent) +
                               (potential_remember_percent * potential_memory))
        print(f"new_long = {updated_long_memory}")
        # calculate new short term memory and how much of it to output - Output gate
        output_percent = torch.sigmoid((short_memory * self.wo1) + 
                                       (input_value * self.wo2) +
                                       self.bo1)
        print(f"output = {(short_memory * self.wo1) + (input_value * self.wo2) +self.bo1}")
        updated_short_memory = torch.tanh(updated_long_memory) * output_percent
        print(f"new short = {updated_short_memory}")
        print(f"\n")

        return ([updated_long_memory, updated_short_memory])
    
    def forward(self, input):
        long_memory = 0
        short_memory = 0
        # test inputs
        day1 = input[0]
        day2 = input[1]
        day3 = input[2]
        day4 = input[3]

        long_memory, short_memory = self.lstm_unit(day1, long_memory, short_memory)
        long_memory, short_memory = self.lstm_unit(day2, long_memory, short_memory)
        long_memory, short_memory = self.lstm_unit(day3, long_memory, short_memory)
        long_memory, short_memory = self.lstm_unit(day4, long_memory, short_memory)

        return short_memory

    def configure_optimizers(self):
        return Adam(self.parameters())

    def training_step(self, batch, batch_idx):
        input_i, label_i = batch
        output_i = self.forward(input_i[0])
        loss = (output_i - label_i)**2

        self.log("train_loss ", loss)
        
        if (label_i == 0):
            self.log("Out 0", output_i)
        else:
            self.log("Out 1", output_i)
        return loss


In [149]:

model = LSTMbyHand()
print("Company A: observed = 0, Predicted = ", model(torch.tensor([0., 0.25, 0.5, 1.])).detach())

forget = 0.0
candidate = 0.0
input = 0.0
new_long = 0.0
output = 0.0
new short = 0.0


forget = -0.22095584869384766
candidate = -0.605329692363739
input = 0.07093647867441177
new_long = 0.02500753290951252
output = -0.23819378018379211
new short = 0.011019309982657433


forget = -0.46344390511512756
candidate = -1.1927615404129028
input = 0.16682933270931244
new_long = 0.048132941126823425
output = -0.46684446930885315
new short = 0.018534362316131592


forget = -0.9200403094291687
candidate = -2.391214609146118
input = 0.32572227716445923
new_long = 0.04009915515780449
output = -0.9367237687110901
new short = 0.011284374631941319


Company A: observed = 0, Predicted =  tensor(0.0113)


In [150]:
model = LSTMbyHand()
print("Company B: observed = 0.5, Predicted = ", model(torch.tensor([1., 0.25, 0.5, 1.])).detach())

forget = 0.9016904234886169
candidate = -1.142091155052185
input = -0.3660004436969757
new_long = -0.0847959816455841
output = -0.9722270369529724
new short = -0.02321552485227585


forget = 0.24212688207626343
candidate = -0.32943686842918396
input = -0.08277017623186111
new_long = -0.08205623924732208
output = -0.17942604422569275
new short = -0.0372735895216465


forget = 0.47766467928886414
candidate = -0.641551673412323
input = -0.16898390650749207
new_long = -0.10837815701961517
output = -0.38395148515701294
new short = -0.04374090954661369


forget = 0.9331633448600769
candidate = -1.2248307466506958
input = -0.3495521545410156
new_long = -0.15408125519752502
output = -0.8523389101028442
new short = -0.04570034146308899


Company B: observed = 0.5, Predicted =  tensor(-0.0457)


In [151]:
inputs = torch.tensor([[0., 0.25, 0.5, 1.], [1., 0.5, 0.25, 1.]])
labels = torch.tensor([0., 0.5])

In [152]:
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)

In [162]:
trainer = L.Trainer(max_epochs=3000)
trainer.fit(model, train_dataloaders=dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name         | Type | Params
--------------------------------------
  | other params | n/a  | 12    
--------------------------------------
12        Trainable params
0         Non-trainable params
12        Total params
0.000     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

forget = 0.38814979791641235
candidate = 0.3732115626335144
input = 0.2586579918861389
new_long = 0.1498590111732483
output = 0.40231114625930786
new short = 0.08913566917181015


forget = 0.6959879398345947
candidate = 0.4718503952026367
input = 0.3512807786464691
new_long = 0.30784711241722107
output = 0.1209852397441864
new short = 0.1582554578781128


forget = 1.0048000812530518
candidate = 0.5218281149864197
input = 0.44331878423690796
new_long = 0.48666149377822876
output = -0.120849609375
new short = 0.21215513348579407


forget = 1.6265287399291992
candidate = 0.4167427122592926
input = 0.6249308586120605
new_long = 0.7409301996231079
output = -0.4381179213523865
new short = 0.2469644397497177


forget = 1.6348533630371094
candidate = -0.10103294253349304
input = 0.6167328357696533
new_long = 0.2605733871459961
output = -0.021539002656936646
new short = 0.12604373693466187


forget = 1.0047416687011719
candidate = 0.4418913722038269
input = 0.44075173139572144
new_long = 0.4429

`Trainer.fit` stopped: `max_epochs=3000` reached.


forget = 1.2189215421676636
candidate = 0.6538562774658203
input = -0.1766098290681839
new_long = -0.11499476432800293
output = 0.869757890701294
new short = -0.0806809514760971


forget = 1.5959937572479248
candidate = 0.39116373658180237
input = -0.11996976286172867
new_long = -0.16684138774871826
output = 0.9862515926361084
new short = -0.12040328979492188


forget = 2.0079565048217773
candidate = 0.2501838803291321
input = -0.018551647663116455
new_long = -0.1575213074684143
output = 1.05076003074646
new short = -0.11575499176979065


forget = 2.9035162925720215
candidate = 0.2181149125099182
input = 0.27621936798095703
new_long = 7.450580596923828e-08
output = 1.073044776916504
new short = 5.551990511776239e-08


forget = 3.002121925354004
candidate = 0.5620926022529602
input = 0.40276867151260376
new_long = 0.24351070821285248
output = 0.9261268973350525
new short = 0.17105622589588165


forget = 2.2562355995178223
candidate = 1.1162853240966797
input = 0.3000868558883667
new_lon

In [3]:
print("Company B: observed = 0.5, Predicted = ", model.forward(torch.tensor([1., 0.5, 0.25, 1.])).detach())

NameError: name 'model' is not defined

In [168]:
print(model.input_format(model.wlr1, model.wlr2, model.blr1))
print(model.input_format(model.wpr1, model.wpr2, model.bpr1))
print(model.input_format(model.wp1, model.wp2, model.bp1))
print(model.input_format(model.wo1, model.wo2, model.bo1))

{'input_weights': 1783, 'gate_biases': 1218921, 'short_weights': 851}
{'input_weights': -91, 'gate_biases': 653856, 'short_weights': 2971}
{'input_weights': 579, 'gate_biases': -176609, 'short_weights': 1093}
{'input_weights': 56, 'gate_biases': 869757, 'short_weights': -1269}


In [156]:
x = list([[5, 9], [5, 6]])
s = type(x)
if (s == list):
    if type(x[0]) == list:
        print("list")
    else:
        print("int")
else:
    print((s))

list


In [157]:
def values(input):
        output_string = ''
        if (type(input) == int):
            output_string += str(input)
        elif type(input) == list:
            output_string = '('
            if type(input[0]) == int:
                for k, i in enumerate(input):
                    print(i)
                    output_string += str(i)
                    if k == len(input)-1:
                        break
                    output_string += ", "
                output_string += ')'

            elif type(input[0]) == list:
                for k, i in enumerate(input):
                    output_string += '('
                    for inner, j in enumerate(i):
                        output_string += str(j)
                        if inner == len(i)-1:
                            break
                        output_string += ", "
                    output_string += ')'
                    if k == len(input)-1:
                        break
                    output_string += ','
                output_string += ')'
        return output_string

In [158]:
print(values([[2,3,4], [9,8,9]]))

((2, 3, 4),(9, 8, 9))


In [159]:
a = 0.95 * 1 + 0.86
print(a)
b = 95 * 1 + 86
print(f"b = {b/10}")

1.81
b = 18.1


In [2]:
tanh = [-761,-757,-753,-748,-744,-739,-735,-730,-725,-721,-716,-711,-706,-701,-696,-691,-685,-680,-675,-669,-664,-658,-652,-646,-641,-635,-629,-623,-616,-610,-604,-597,-591,-584,-578,-571,-564,-558,-551,-544,-537,-529,-522,-515,-507,-500,-492,-485,-477,-469,-462,-454,-446,-438,-430,-421,-413,-405,-396,-388,-379,-371,-362,-353,-345,-336,-327,-318,-309,-300,-291,-282,-272,-263,-254,-244,-235,-226,-216,-206,-197,-187,-178,-168,-158,-148,-139,-129,-119,-109,-99,-89,-79,-69,-59,-49,-39,-29,-19,-9,0,9,19,29,39,49,59,69,79,89,99,109,119,129,139,148,158,168,178,187,197,206,216,226,235,244,254,263,272,282,291,300,309,318,327,336,345,353,362,371,379,388,396,405,413,421,430,438,446,454,462,469,477,485,492,500,507,515,522,529,537,544,551,558,564,571,578,584,591,597,604,610,616,623,629,635,641,646,652,658,664,669,675,680,685,691,696,701,706,711,716,721,725,730,735,739,744,748,753,757,761,765,769,773,777,781,785,789,793,796,800,804,807,811,814,817,821,824,827,830,833,836,839,842,845,848,851,853,856,859,861,864,866,869,871,874,876,878,880,883,885,887,889,891,893,895,897,899,901,903,905,906,908,910,912,913,915,917,918,920,921,923,924,926,927,928,930,931,932,934,935,936,937,939,940,941,942,943,944,945,946,947,948,949,950,951,952,953,954,955,956,957,957,958,959,960,961,961,962,963,964,964,965,966,966,967,968,968,969,969,970,971,971,972,972,973,973,974,974,975,975,976,976,977,977,978,978,978,979,979,980,980,980,981,981,981,982,982,983,983,983,983,984,984,984,985,985,985,986,986,986,986,987,987,987,987,988,988,988,988,989,989,989,989,989,990,990,990,990,990,991,991,991,991,991,991,992,992,992,992,992,992,992,993,993,993,993,993,993,993,993,994,994,994,994,994,994,994,994,994,995,995,995,995,995,995,995,995,995,995,995,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999]
tanh[1100]
len(tanh)#

1101

In [217]:
sig = [268,270,272,274,276,278,280,282,284,286,289,291,293,295,297,299,301,303,305,307,310,312,314,316,318,320,323,325,327,329,331,334,336,338,340,342,345,347,349,352,354,356,358,361,363,365,368,370,372,375,377,379,382,384,386,389,391,394,396,398,401,403,406,408,410,413,415,418,420,423,425,428,430,432,435,437,440,442,445,447,450,452,455,457,460,462,465,467,470,472,475,477,480,482,485,487,490,492,495,497,500,502,504,507,509,512,514,517,519,522,524,527,529,532,534,537,539,542,544,547,549,552,554,557,559,562,564,567,569,571,574,576,579,581,584,586,589,591,593,596,598,601,603,605,608,610,613,615,617,620,622,624,627,629,631,634,636,638,641,643,645,647,650,652,654,657,659,661,663,665,668,670,672,674,676,679,681,683,685,687,689,692,694,696,698,700,702,704,706,708,710,713,715,717,719,721,723,725,727,729,731,733,734,736,738,740,742,744,746,748,750,752,753,755,757,759,761,763,764,766,768,770,772,773,775,777,779,780,782,784,785,787,789,790,792,794,795,797,798,800,802,803,805,806,808,809,811,813,814,816,817,819,820,822,823,824,826,827,829,830,832,833,834,836,837,838,840,841,842,844,845,846,848,849,850,851,853,854,855,856,858,859,860,861,862,864,865,866,867,868,869,871,872,873,874,875,876,877,878,879,880,881,882,883,884,885,886,887,888,889,890,891,892,893,894,895,896,897,898,899,900,901,902,902,903,904,905,906,907,908,908,909,910,911,912,912,913,914,915,916,916,917,918,919,919,920,921,922,922,923,924,924,925,926,926,927,928,928,929,930,930,931,932,932,933,934,934,935,935,936,937,937,938,938,939,939,940,941,941,942,942,943,943,944,944,945,945,946,946,947,947,948,948,949,949,950,950,951,951,952,952,953,953,953,954,954,955,955,956,956,956,957,957,958,958,958,959,959,960,960,960,961,961,961,962,962,963,963,963,964,964,964,965,965,965,966,966,966,967,967,967,968,968,968,968,969,969,969,970,970,970,970,971,971,971,972,972,972,972,973,973,973,973,974,974,974,974,975,975,975,975,976,976,976,976,977,977,977,977,977,978,978,978,978,978,979,979,979,979,979,980,980,980,980,980,981,981,981,981,981,982,982,982,982,982,982,983,983,983,983,983,983,984,984,984,984,984,984,984,985,985,985,985,985,985,985,986,986,986,986,986,986,986,987,987,987,987,987,987,987,987,987,988,988,988,988,988,988,988,988,989,989,989,989,989,989,989,989,989,989,990,990,990,990,990,990,990,990,990,990,990,991,991,991,991,991,991,991,991,991,991,991,991,992,992,992,992,992,992,992,992,992,992,992,992,992,993,993,993,993,993,993,993,993,993,993,993,993,993,993,993,994,994,994,994,994,994,994,994,994,994,994,994,994,994,994,994,994,994,994,995,995,995,995,995,995,995,995,995,995,995,995,995,995,995,995,995,995,995,995,995,995,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,996,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,997,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,998,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999,999]
sig[151]

624