In [1]:
from torch import nn
import torch
import glob
import copy
import math
import random
import time
import math

In [2]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        
        self.hidden_size = hidden_size
        
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden
    
    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

n_input = 20
n_output = 2
n_hidden = 20
rnn = RNN(n_input, n_hidden, n_output)

In [3]:
rnn = RNN(n_input, n_hidden, n_output)
rnn.load_state_dict(torch.load('./rnn_state'))
rnn.eval()

RNN(
  (i2h): Linear(in_features=40, out_features=20, bias=True)
  (i2o): Linear(in_features=40, out_features=2, bias=True)
  (softmax): LogSoftmax()
)

# Read DataSet

In [4]:
def findFiles(path): return glob.glob(path)

filenames = findFiles('dataset/pretrain/*.csv')

In [5]:
def readCSV(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [line.split('\t') for line in lines]

def readCSVNumber(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [[float(num) for num in line.split('\t')] for line in lines]

In [6]:
datas = []
errdatas = []

for filename in filenames:
    try:
        datas.append(readCSVNumber(filename))
    except Exception as ex:
        errdatas.append(readCSV(filename))
        print(filename)
        print(ex)

In [7]:
datas[0]

[[0.08333333333333333,
  0.01694915254237288,
  0.03465982028241335,
  0.08333333333333333,
  0.01694915254237288,
  0.03465982028241335,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.08333333333333333,
  0.11864406779661017,
  0.04236200256739409,
  0.08333333333333333,
  0.11864406779661017,
  0.04236200256739409,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.08333333333333333,
  0.288135593220339,
  0.055198973042362,
  0.08333333333333333,
  0.288135593220339,
  0.055198973042362,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.19444444444444445,
  0.01694915254237288,
  0.09242618741976893,
  0.19444444444444445,
  0.01694915254237288,
  0.09242618741976893,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.19444444444444445,
  0.11864406779661017,
  0.1001283697

## 6 + 7 + 2 => DeclRefExpr
## 6 + 7 + 1 => VarDecl
## 6 + 7 + 3 => FunctionDecl
## 6 + 7 + 6 => ParmDecl

# Make Training Dataset

In [7]:
def changeData(data, declindex, refindex, changeNum):
    newdata = copy.deepcopy(data)
    changeList = []
    
    for i in range(changeNum):
        while True:
            idx = random.randint(0, len(refindex) - 1)
            
            # 선언된 변수가 1개거나 이미 바꾼 줄이면 다음 랜덤 인덱스
            if refindex[idx][1] == 1 or (idx in changeList):
                continue
            else:
                changeList.append(refindex[idx][0])

                while True:
                    decidx = random.randint(0, len(declindex) - 1)
                    # 기존의 변수랑 같은 거면 다른 변수
                    if newdata[declindex[decidx]][5] == newdata[refindex[idx][0]][2]:
                        continue
                    newdata[refindex[idx][0]][0] = newdata[declindex[decidx]][3]
                    newdata[refindex[idx][0]][1] = newdata[declindex[decidx]][4]
                    newdata[refindex[idx][0]][2] = newdata[declindex[decidx]][5]
                    break
                break
    
    return newdata, changeList

In [8]:
def makeTraining(dataset):
    trainData = []
    output = []
    
    for data in dataset:
        declindex = []
        refindex = []
        for idx, line in enumerate(data):
            if (line[14] == 1 or line[16] == 1 or line[19] == 1):
                declindex.append(idx)
            elif line[15] == 1:
                refindex.append([idx, len(declindex)])

        changeRate = 0.1
        changeNum = math.floor(changeRate * len(refindex))
        if changeNum > 0:
            for i in range(10):
                newdata, changeList = changeData(data, declindex, refindex, changeNum)
                trainData.append(newdata)
                output.append([0 if idx in changeList else 1 for idx in range(len(newdata))])
                
        print('done ', len(trainData), '/', len(dataset) * 10)
                
    return trainData, output

In [9]:
trainDataOrigin, outputOrigin = makeTraining(datas)

done  10 / 37600
done  20 / 37600
done  30 / 37600
done  40 / 37600
done  50 / 37600
done  60 / 37600
done  70 / 37600
done  80 / 37600
done  90 / 37600
done  100 / 37600
done  110 / 37600
done  120 / 37600
done  130 / 37600
done  140 / 37600
done  150 / 37600
done  160 / 37600
done  170 / 37600
done  180 / 37600
done  190 / 37600
done  200 / 37600
done  210 / 37600
done  220 / 37600
done  230 / 37600
done  240 / 37600
done  250 / 37600
done  260 / 37600
done  270 / 37600
done  280 / 37600
done  290 / 37600
done  300 / 37600
done  310 / 37600
done  320 / 37600
done  330 / 37600
done  340 / 37600
done  350 / 37600
done  360 / 37600
done  370 / 37600
done  380 / 37600
done  390 / 37600
done  400 / 37600
done  410 / 37600
done  420 / 37600
done  430 / 37600
done  440 / 37600
done  450 / 37600
done  460 / 37600
done  470 / 37600
done  480 / 37600
done  490 / 37600
done  500 / 37600
done  510 / 37600
done  520 / 37600
done  530 / 37600
done  540 / 37600
done  550 / 37600
done  560 / 37600
d

done  4540 / 37600
done  4550 / 37600
done  4560 / 37600
done  4570 / 37600
done  4580 / 37600
done  4590 / 37600
done  4600 / 37600
done  4610 / 37600
done  4620 / 37600
done  4630 / 37600
done  4640 / 37600
done  4650 / 37600
done  4660 / 37600
done  4670 / 37600
done  4680 / 37600
done  4690 / 37600
done  4700 / 37600
done  4710 / 37600
done  4720 / 37600
done  4730 / 37600
done  4740 / 37600
done  4750 / 37600
done  4760 / 37600
done  4770 / 37600
done  4780 / 37600
done  4790 / 37600
done  4800 / 37600
done  4810 / 37600
done  4820 / 37600
done  4830 / 37600
done  4840 / 37600
done  4850 / 37600
done  4860 / 37600
done  4870 / 37600
done  4880 / 37600
done  4890 / 37600
done  4900 / 37600
done  4910 / 37600
done  4920 / 37600
done  4930 / 37600
done  4940 / 37600
done  4950 / 37600
done  4960 / 37600
done  4970 / 37600
done  4980 / 37600
done  4990 / 37600
done  5000 / 37600
done  5010 / 37600
done  5020 / 37600
done  5030 / 37600
done  5040 / 37600
done  5050 / 37600
done  5060 /

done  8890 / 37600
done  8900 / 37600
done  8910 / 37600
done  8920 / 37600
done  8930 / 37600
done  8940 / 37600
done  8950 / 37600
done  8960 / 37600
done  8970 / 37600
done  8980 / 37600
done  8990 / 37600
done  9000 / 37600
done  9010 / 37600
done  9020 / 37600
done  9030 / 37600
done  9040 / 37600
done  9050 / 37600
done  9060 / 37600
done  9070 / 37600
done  9080 / 37600
done  9090 / 37600
done  9100 / 37600
done  9110 / 37600
done  9120 / 37600
done  9130 / 37600
done  9140 / 37600
done  9150 / 37600
done  9160 / 37600
done  9170 / 37600
done  9180 / 37600
done  9190 / 37600
done  9200 / 37600
done  9210 / 37600
done  9220 / 37600
done  9230 / 37600
done  9240 / 37600
done  9250 / 37600
done  9260 / 37600
done  9270 / 37600
done  9280 / 37600
done  9290 / 37600
done  9300 / 37600
done  9310 / 37600
done  9320 / 37600
done  9330 / 37600
done  9340 / 37600
done  9350 / 37600
done  9360 / 37600
done  9370 / 37600
done  9380 / 37600
done  9390 / 37600
done  9400 / 37600
done  9410 /

done  13180 / 37600
done  13190 / 37600
done  13200 / 37600
done  13210 / 37600
done  13220 / 37600
done  13230 / 37600
done  13240 / 37600
done  13250 / 37600
done  13260 / 37600
done  13270 / 37600
done  13280 / 37600
done  13290 / 37600
done  13300 / 37600
done  13310 / 37600
done  13320 / 37600
done  13330 / 37600
done  13340 / 37600
done  13350 / 37600
done  13360 / 37600
done  13370 / 37600
done  13380 / 37600
done  13390 / 37600
done  13400 / 37600
done  13410 / 37600
done  13420 / 37600
done  13430 / 37600
done  13440 / 37600
done  13450 / 37600
done  13460 / 37600
done  13470 / 37600
done  13480 / 37600
done  13490 / 37600
done  13500 / 37600
done  13510 / 37600
done  13520 / 37600
done  13530 / 37600
done  13540 / 37600
done  13550 / 37600
done  13560 / 37600
done  13570 / 37600
done  13580 / 37600
done  13590 / 37600
done  13600 / 37600
done  13610 / 37600
done  13620 / 37600
done  13630 / 37600
done  13640 / 37600
done  13650 / 37600
done  13660 / 37600
done  13670 / 37600


done  17380 / 37600
done  17390 / 37600
done  17400 / 37600
done  17410 / 37600
done  17420 / 37600
done  17430 / 37600
done  17440 / 37600
done  17450 / 37600
done  17460 / 37600
done  17470 / 37600
done  17480 / 37600
done  17490 / 37600
done  17500 / 37600
done  17510 / 37600
done  17520 / 37600
done  17530 / 37600
done  17540 / 37600
done  17550 / 37600
done  17560 / 37600
done  17570 / 37600
done  17580 / 37600
done  17590 / 37600
done  17600 / 37600
done  17610 / 37600
done  17620 / 37600
done  17630 / 37600
done  17640 / 37600
done  17650 / 37600
done  17660 / 37600
done  17670 / 37600
done  17680 / 37600
done  17690 / 37600
done  17700 / 37600
done  17710 / 37600
done  17720 / 37600
done  17730 / 37600
done  17740 / 37600
done  17750 / 37600
done  17760 / 37600
done  17770 / 37600
done  17780 / 37600
done  17790 / 37600
done  17800 / 37600
done  17810 / 37600
done  17820 / 37600
done  17830 / 37600
done  17840 / 37600
done  17850 / 37600
done  17860 / 37600
done  17870 / 37600


done  21510 / 37600
done  21520 / 37600
done  21530 / 37600
done  21540 / 37600
done  21550 / 37600
done  21560 / 37600
done  21570 / 37600
done  21580 / 37600
done  21590 / 37600
done  21600 / 37600
done  21610 / 37600
done  21620 / 37600
done  21630 / 37600
done  21640 / 37600
done  21650 / 37600
done  21660 / 37600
done  21670 / 37600
done  21680 / 37600
done  21690 / 37600
done  21700 / 37600
done  21710 / 37600
done  21720 / 37600
done  21730 / 37600
done  21740 / 37600
done  21750 / 37600
done  21760 / 37600
done  21770 / 37600
done  21780 / 37600
done  21790 / 37600
done  21800 / 37600
done  21810 / 37600
done  21820 / 37600
done  21830 / 37600
done  21840 / 37600
done  21850 / 37600
done  21860 / 37600
done  21870 / 37600
done  21880 / 37600
done  21890 / 37600
done  21900 / 37600
done  21910 / 37600
done  21920 / 37600
done  21930 / 37600
done  21940 / 37600
done  21950 / 37600
done  21960 / 37600
done  21970 / 37600
done  21980 / 37600
done  21990 / 37600
done  22000 / 37600


done  25610 / 37600
done  25620 / 37600
done  25630 / 37600
done  25640 / 37600
done  25650 / 37600
done  25660 / 37600
done  25670 / 37600
done  25680 / 37600
done  25690 / 37600
done  25700 / 37600
done  25710 / 37600
done  25720 / 37600
done  25730 / 37600
done  25740 / 37600
done  25750 / 37600
done  25760 / 37600
done  25770 / 37600
done  25780 / 37600
done  25790 / 37600
done  25800 / 37600
done  25810 / 37600
done  25820 / 37600
done  25830 / 37600
done  25840 / 37600
done  25850 / 37600
done  25860 / 37600
done  25870 / 37600
done  25880 / 37600
done  25890 / 37600
done  25900 / 37600
done  25910 / 37600
done  25920 / 37600
done  25930 / 37600
done  25940 / 37600
done  25950 / 37600
done  25960 / 37600
done  25970 / 37600
done  25980 / 37600
done  25990 / 37600
done  26000 / 37600
done  26010 / 37600
done  26020 / 37600
done  26030 / 37600
done  26040 / 37600
done  26050 / 37600
done  26060 / 37600
done  26070 / 37600
done  26080 / 37600
done  26090 / 37600
done  26100 / 37600


done  29730 / 37600
done  29740 / 37600
done  29750 / 37600
done  29760 / 37600
done  29770 / 37600
done  29780 / 37600
done  29790 / 37600
done  29800 / 37600
done  29810 / 37600
done  29820 / 37600
done  29830 / 37600
done  29840 / 37600
done  29850 / 37600
done  29860 / 37600
done  29870 / 37600
done  29880 / 37600
done  29890 / 37600
done  29900 / 37600
done  29910 / 37600
done  29920 / 37600
done  29930 / 37600
done  29940 / 37600
done  29950 / 37600
done  29960 / 37600
done  29970 / 37600
done  29980 / 37600
done  29990 / 37600
done  30000 / 37600
done  30010 / 37600
done  30020 / 37600
done  30030 / 37600
done  30040 / 37600
done  30050 / 37600
done  30060 / 37600
done  30070 / 37600
done  30080 / 37600
done  30090 / 37600
done  30100 / 37600
done  30110 / 37600
done  30120 / 37600
done  30130 / 37600
done  30140 / 37600
done  30150 / 37600
done  30160 / 37600
done  30170 / 37600
done  30180 / 37600
done  30190 / 37600
done  30200 / 37600
done  30210 / 37600
done  30220 / 37600


done  33830 / 37600
done  33840 / 37600
done  33850 / 37600
done  33860 / 37600
done  33870 / 37600
done  33880 / 37600
done  33890 / 37600
done  33900 / 37600
done  33910 / 37600
done  33920 / 37600
done  33930 / 37600
done  33940 / 37600
done  33950 / 37600
done  33960 / 37600
done  33970 / 37600
done  33980 / 37600
done  33990 / 37600
done  34000 / 37600
done  34010 / 37600
done  34020 / 37600
done  34030 / 37600
done  34040 / 37600
done  34050 / 37600
done  34060 / 37600
done  34070 / 37600
done  34080 / 37600
done  34090 / 37600
done  34100 / 37600
done  34110 / 37600
done  34120 / 37600
done  34130 / 37600
done  34140 / 37600
done  34150 / 37600
done  34160 / 37600
done  34170 / 37600
done  34180 / 37600
done  34190 / 37600
done  34200 / 37600
done  34210 / 37600
done  34220 / 37600
done  34230 / 37600
done  34240 / 37600
done  34250 / 37600
done  34260 / 37600
done  34270 / 37600
done  34280 / 37600
done  34290 / 37600
done  34300 / 37600
done  34310 / 37600
done  34320 / 37600


In [None]:
trainData = [torch.FloatTensor(data) for data in trainDataOrigin]
outputData = [torch.FloatTensor(data) for data in outputOrigin]

In [None]:
trainData = [data.reshape(len(data), 1, -1) for data in trainData]

outputData = [[data.reshape(1).long() for data in line] for line in outputData]

# Training

In [None]:
criterion = nn.NLLLoss()

In [None]:
learning_rate = 0.005

def train(tokenTensor, outputTensor):
    for idx in range(1, tokenTensor.size()[0]):
        if int(tokenTensor[idx][0][15]) != 1:
            continue
        
        hidden = rnn.initHidden()
        rnn.zero_grad()

        for i in range(idx + 1):
            output, hidden = rnn(tokenTensor[i], hidden)
        loss = criterion(output, outputTensor[idx].reshape(1).long())
        loss.backward(retain_graph=True)

        for p in rnn.parameters():
            p.data.add_(-learning_rate, p.grad.data)
    
    return output, loss.item()

In [None]:
n_iters = 100000
print_every = 50
plot_every = 10

In [None]:
current_loss = 0
all_losses = []

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

start = time.time()

In [16]:
len(trainData)

37600

In [65]:
idx = 0
for iter in range(0, n_iters + 1):
    tokenTensor = trainData[idx]
    outputTensor = outputData[idx]
    output, loss = train(tokenTensor, outputTensor)
    current_loss += loss
    
    if iter % print_every == 0:
        print('%d %d%% (%s) %.4f' % (iter, iter / n_iters * 100, timeSince(start), loss))
    
    if iter % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        current_loss = 0

    print(iter)
    idx = (idx + 1) % 230 # len(trainData)

0 0% (1155m 54s) 0.2577
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50 0% (1158m 35s) 0.1098
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100 0% (1160m 27s) 0.0097
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150 0% (1164m 51s) 0.1122
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200 0% (1165m 31s) 0.1187
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245

1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700 1% (1222m 19s) 0.0423
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750 1% (1226m 42s) 0.0519
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800 1% (1227m 18s) 0.1255
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850 1% (1227m 46s) 2.7387
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
18

3163
3164
3165
3166
3167
3168
3169
3170
3171
3172
3173
3174
3175
3176
3177
3178
3179
3180
3181
3182
3183
3184
3185
3186
3187
3188
3189
3190
3191
3192
3193
3194
3195
3196
3197
3198
3199
3200 3% (1279m 53s) 0.0184
3200
3201
3202
3203
3204
3205
3206
3207
3208
3209
3210
3211
3212
3213
3214
3215
3216
3217
3218
3219
3220
3221
3222
3223
3224
3225
3226
3227
3228
3229
3230
3231
3232
3233
3234
3235
3236
3237
3238
3239
3240
3241
3242
3243
3244
3245
3246
3247
3248
3249
3250 3% (1280m 53s) 0.0114
3250
3251
3252
3253
3254
3255
3256
3257
3258
3259
3260
3261
3262
3263
3264
3265
3266
3267
3268
3269
3270
3271
3272
3273
3274
3275
3276
3277
3278
3279
3280
3281
3282
3283
3284
3285
3286
3287
3288
3289
3290
3291
3292
3293
3294
3295
3296
3297
3298
3299
3300 3% (1283m 23s) 0.1516
3300
3301
3302
3303
3304
3305
3306
3307
3308
3309
3310
3311
3312
3313
3314
3315
3316
3317
3318
3319
3320
3321
3322
3323
3324
3325
3326
3327
3328
3329
3330
3331
3332
3333
3334
3335
3336
3337
3338
3339
3340
3341
3342
3343
3344
3345
3346

KeyboardInterrupt: 

In [None]:
correct_result = {
    0: False,
    1: True
}
def correctFromOutput(output):
    top_n, top_i = output.topk(1)
    correct_i = top_i[0].item()
    return correct_result[correct_i], correct_i

In [None]:
hidden = torch.zeros(1, n_hidden)
testidx = 500
for i in range(trainData[testidx].size()[0]):
    result, hidden = rnn(trainData[testidx][i], hidden)
    if int(trainData[testidx][i][0][15]) == 1:
        print('declrefexpr')
        #print(result)
        print('predict: ', correctFromOutput(result), 'actual: ', correct_result[int(outputData[testidx][i])], 'diff: ', float(result[0][1] - result[0][0]))

In [72]:
torch.save(rnn.state_dict(), './rnn_state')

In [12]:
def testRange(st=0, ed=1000):
    totalRef = 0
    totalFalse = 0
    falseButTrue = 0
    falseAndFalse = 0
    trueButFalse = 0

    for testidx in range(st, ed):
        if testidx % 100 == 0:
            print('testidx: ', testidx)
        hidden = torch.zeros(1, n_hidden)
        for i in range(trainData[testidx].size()[0]):
            result, hidden = rnn(trainData[testidx][i], hidden)
            if int(trainData[testidx][i][0][15]) == 1:
                totalRef += 1
                #print('declrefexpr')
                #print(result)
                #print('predict: ', correctFromOutput(result), 'actual: ', correct_result[int(outputData[testidx][i])], 'diff: ', float(result[0][1] - result[0][0]))
                if correct_result[int(outputData[testidx][i])] == False:
                    totalFalse += 1

                if correctFromOutput(result)[0] == False:
                    #print('predict: ', correctFromOutput(result), 'actual: ', correct_result[int(outputData[testidx][i])], 'diff: ', float(result[0][1] - result[0][0]))

                    if correct_result[int(outputData[testidx][i])] == False:
                        falseAndFalse += 1
                    else:
                        trueButFalse += 1
                else:
                    if correct_result[int(outputData[testidx][i])] == False:
                        falseButTrue += 1
    
    precision = falseAndFalse / (falseAndFalse + trueButFalse)
    recall = falseAndFalse / totalFalse

    print('totalRef: ', totalRef)
    print('totalFalse: ', totalFalse)
    print('falseAndFalse: ', falseAndFalse)
    print('precision: ', precision)
    print('recall: ', recall)

In [71]:
testRange(0, 1000)

testidx:  0
testidx:  100
testidx:  200
testidx:  300
testidx:  400
testidx:  500
testidx:  600
testidx:  700
testidx:  800
testidx:  900
totalRef:  202230
totalFalse:  18835
falseAndFalse:  523
precision:  0.526686807653575
recall:  0.02776745420759225


In [67]:
testRange(1000, 2000)

testidx:  1000
testidx:  1100
testidx:  1200
testidx:  1300
testidx:  1400
testidx:  1500
testidx:  1600
testidx:  1700
testidx:  1800
testidx:  1900
totalRef:  96670
totalFalse:  8806
falseAndFalse:  545
precision:  0.44308943089430897
recall:  0.06188962071315012


In [68]:
testRange(2000, 3000)

testidx:  2000
testidx:  2100
testidx:  2200
testidx:  2300
testidx:  2400
testidx:  2500
testidx:  2600
testidx:  2700
testidx:  2800
testidx:  2900
totalRef:  106180
totalFalse:  9761
falseAndFalse:  273
precision:  0.7822349570200573
recall:  0.027968445855957383


In [69]:
testRange(3000, 4000)
testRange(4000, 5000)

testidx:  3000
testidx:  3100
testidx:  3200
testidx:  3300
testidx:  3400
testidx:  3500
testidx:  3600
testidx:  3700
testidx:  3800
testidx:  3900
totalRef:  143780
totalFalse:  13295
falseAndFalse:  237
precision:  0.9011406844106464
recall:  0.017826250470101543
testidx:  4000
testidx:  4100
testidx:  4200
testidx:  4300
testidx:  4400
testidx:  4500
testidx:  4600
testidx:  4700
testidx:  4800
testidx:  4900
totalRef:  185960
totalFalse:  17242
falseAndFalse:  561
precision:  0.4010007147962831
recall:  0.03253682867416773


In [73]:
testRange(0, 1000)
testRange(1000, 2000)
testRange(2000, 3000)
testRange(3000, 4000)
testRange(4000, 5000)

testidx:  0
testidx:  100
testidx:  200
testidx:  300
testidx:  400
testidx:  500
testidx:  600
testidx:  700
testidx:  800
testidx:  900
totalRef:  202230
totalFalse:  18835
falseAndFalse:  523
precision:  0.526686807653575
recall:  0.02776745420759225
testidx:  1000
testidx:  1100
testidx:  1200
testidx:  1300
testidx:  1400
testidx:  1500
testidx:  1600
testidx:  1700
testidx:  1800
testidx:  1900
totalRef:  96670
totalFalse:  8806
falseAndFalse:  619
precision:  0.4858712715855573
recall:  0.07029298205768794
testidx:  2000
testidx:  2100
testidx:  2200
testidx:  2300
testidx:  2400
testidx:  2500
testidx:  2600
testidx:  2700
testidx:  2800
testidx:  2900
totalRef:  106180
totalFalse:  9761
falseAndFalse:  337
precision:  0.8042959427207638
recall:  0.03452515111156644
testidx:  3000
testidx:  3100
testidx:  3200
testidx:  3300
testidx:  3400
testidx:  3500
testidx:  3600
testidx:  3700
testidx:  3800
testidx:  3900
totalRef:  143780
totalFalse:  13295
falseAndFalse:  300
precisio

# Print with diff degree sequence

In [10]:
def printDiffSequence(testidx = 500):
    hidden = torch.zeros(1, n_hidden)

    predictList = []
    for i in range(trainData[testidx].size()[0]):
        result, hidden = rnn(trainData[testidx][i], hidden)
        if int(trainData[testidx][i][0][15]) == 1:
            # print('predict: ', correctFromOutput(result), 'actual: ', correct_result[int(outputData[testidx][i])], 'diff: ', float(result[0][1] - result[0][0]))
            predictList.append([i, correctFromOutput(result), correct_result[int(outputData[testidx][i])], float(result[0][1] - result[0][0])])
            
    predictList.sort(key = lambda x: x[3])
    
    for predict in predictList:
        print('index: ', predict[0], 'predict: ', predict[1], 'actual: ', predict[2], 'diff: ', predict[3])
        

In [11]:
printDiffSequence(3001)

NameError: name 'trainData' is not defined