In [1]:
from torch import nn
import torch
import glob
import copy
import math
import random
import time
import math

In [2]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        
        self.hidden_size = hidden_size
        
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden
    
    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

n_input = 20
n_output = 2
n_hidden = 20
rnn = RNN(n_input, n_hidden, n_output)

# Read DataSet

In [3]:
def findFiles(path): return glob.glob(path)

filenames = findFiles('dataset/pretrain/*.csv')

In [4]:
def readCSV(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [line.split('\t') for line in lines]

def readCSVNumber(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [[float(num) for num in line.split('\t')] for line in lines]

In [5]:
datas = []
errdatas = []

for filename in filenames:
    try:
        datas.append(readCSVNumber(filename))
    except Exception as ex:
        errdatas.append(readCSV(filename))
        print(filename)
        print(ex)

In [6]:
datas[0]

[[0.08333333333333333,
  0.01694915254237288,
  0.03465982028241335,
  0.08333333333333333,
  0.01694915254237288,
  0.03465982028241335,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.08333333333333333,
  0.11864406779661017,
  0.04236200256739409,
  0.08333333333333333,
  0.11864406779661017,
  0.04236200256739409,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.08333333333333333,
  0.288135593220339,
  0.055198973042362,
  0.08333333333333333,
  0.288135593220339,
  0.055198973042362,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.19444444444444445,
  0.01694915254237288,
  0.09242618741976893,
  0.19444444444444445,
  0.01694915254237288,
  0.09242618741976893,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.19444444444444445,
  0.11864406779661017,
  0.1001283697

## 6 + 7 + 2 => DeclRefExpr
## 6 + 7 + 1 => VarDecl
## 6 + 7 + 3 => FunctionDecl
## 6 + 7 + 6 => ParmDecl

# Make Training Dataset

In [7]:
def changeData(data, declindex, refindex, changeNum):
    newdata = copy.deepcopy(data)
    changeList = []
    
    for i in range(changeNum):
        while True:
            idx = random.randint(0, len(refindex) - 1)
            
            # 선언된 변수가 1개거나 이미 바꾼 줄이면 다음 랜덤 인덱스
            if refindex[idx][1] == 1 or (idx in changeList):
                continue
            else:
                changeList.append(idx)

                while True:
                    decidx = random.randint(0, len(declindex) - 1)
                    # 기존의 변수랑 같은 거면 다른 변수
                    if newdata[declindex[decidx]][5] == newdata[refindex[idx][0]][2]:
                        continue
                    newdata[refindex[idx][0]][0] = newdata[declindex[decidx]][3]
                    newdata[refindex[idx][0]][1] = newdata[declindex[decidx]][4]
                    newdata[refindex[idx][0]][2] = newdata[declindex[decidx]][5]
                    break
                break
    
    return newdata, changeList

In [8]:
def makeTraining(dataset):
    trainData = []
    output = []
    
    for data in dataset:
        declindex = []
        refindex = []
        for idx, line in enumerate(data):
            if (line[14] == 1 or line[16] == 1 or line[19] == 1):
                declindex.append(idx)
            elif line[15] == 1:
                refindex.append([idx, len(declindex)])

        changeRate = 0.1
        changeNum = math.floor(changeRate * len(refindex))
        if changeNum > 0:
            for i in range(10):
                newdata, changeList = changeData(data, declindex, refindex, changeNum)
                trainData.append(newdata)
                output.append([0 if idx in changeList else 1 for idx in range(len(newdata))])
                
        print('done ', len(trainData), '/', len(dataset) * 10)
                
    return trainData, output

In [9]:
trainDataOrigin, outputOrigin = makeTraining(datas)

done  10 / 37600
done  20 / 37600
done  30 / 37600
done  40 / 37600
done  50 / 37600
done  60 / 37600
done  70 / 37600
done  80 / 37600
done  90 / 37600
done  100 / 37600
done  110 / 37600
done  120 / 37600
done  130 / 37600
done  140 / 37600
done  150 / 37600
done  160 / 37600
done  170 / 37600
done  180 / 37600
done  190 / 37600
done  200 / 37600
done  210 / 37600
done  220 / 37600
done  230 / 37600
done  240 / 37600
done  250 / 37600
done  260 / 37600
done  270 / 37600
done  280 / 37600
done  290 / 37600
done  300 / 37600
done  310 / 37600
done  320 / 37600
done  330 / 37600
done  340 / 37600
done  350 / 37600
done  360 / 37600
done  370 / 37600
done  380 / 37600
done  390 / 37600
done  400 / 37600
done  410 / 37600
done  420 / 37600
done  430 / 37600
done  440 / 37600
done  450 / 37600
done  460 / 37600
done  470 / 37600
done  480 / 37600
done  490 / 37600
done  500 / 37600
done  510 / 37600
done  520 / 37600
done  530 / 37600
done  540 / 37600
done  550 / 37600
done  560 / 37600
d

done  4530 / 37600
done  4540 / 37600
done  4550 / 37600
done  4560 / 37600
done  4570 / 37600
done  4580 / 37600
done  4590 / 37600
done  4600 / 37600
done  4610 / 37600
done  4620 / 37600
done  4630 / 37600
done  4640 / 37600
done  4650 / 37600
done  4660 / 37600
done  4670 / 37600
done  4680 / 37600
done  4690 / 37600
done  4700 / 37600
done  4710 / 37600
done  4720 / 37600
done  4730 / 37600
done  4740 / 37600
done  4750 / 37600
done  4760 / 37600
done  4770 / 37600
done  4780 / 37600
done  4790 / 37600
done  4800 / 37600
done  4810 / 37600
done  4820 / 37600
done  4830 / 37600
done  4840 / 37600
done  4850 / 37600
done  4860 / 37600
done  4870 / 37600
done  4880 / 37600
done  4890 / 37600
done  4900 / 37600
done  4910 / 37600
done  4920 / 37600
done  4930 / 37600
done  4940 / 37600
done  4950 / 37600
done  4960 / 37600
done  4970 / 37600
done  4980 / 37600
done  4990 / 37600
done  5000 / 37600
done  5010 / 37600
done  5020 / 37600
done  5030 / 37600
done  5040 / 37600
done  5050 /

done  8900 / 37600
done  8910 / 37600
done  8920 / 37600
done  8930 / 37600
done  8940 / 37600
done  8950 / 37600
done  8960 / 37600
done  8970 / 37600
done  8980 / 37600
done  8990 / 37600
done  9000 / 37600
done  9010 / 37600
done  9020 / 37600
done  9030 / 37600
done  9040 / 37600
done  9050 / 37600
done  9060 / 37600
done  9070 / 37600
done  9080 / 37600
done  9090 / 37600
done  9100 / 37600
done  9110 / 37600
done  9120 / 37600
done  9130 / 37600
done  9140 / 37600
done  9150 / 37600
done  9160 / 37600
done  9170 / 37600
done  9180 / 37600
done  9190 / 37600
done  9200 / 37600
done  9210 / 37600
done  9220 / 37600
done  9230 / 37600
done  9240 / 37600
done  9250 / 37600
done  9260 / 37600
done  9270 / 37600
done  9280 / 37600
done  9290 / 37600
done  9300 / 37600
done  9310 / 37600
done  9320 / 37600
done  9330 / 37600
done  9340 / 37600
done  9350 / 37600
done  9360 / 37600
done  9370 / 37600
done  9380 / 37600
done  9390 / 37600
done  9400 / 37600
done  9410 / 37600
done  9420 /

done  13220 / 37600
done  13230 / 37600
done  13240 / 37600
done  13250 / 37600
done  13260 / 37600
done  13270 / 37600
done  13280 / 37600
done  13290 / 37600
done  13300 / 37600
done  13310 / 37600
done  13320 / 37600
done  13330 / 37600
done  13340 / 37600
done  13350 / 37600
done  13360 / 37600
done  13370 / 37600
done  13380 / 37600
done  13390 / 37600
done  13400 / 37600
done  13410 / 37600
done  13420 / 37600
done  13430 / 37600
done  13440 / 37600
done  13450 / 37600
done  13460 / 37600
done  13470 / 37600
done  13480 / 37600
done  13490 / 37600
done  13500 / 37600
done  13510 / 37600
done  13520 / 37600
done  13530 / 37600
done  13540 / 37600
done  13550 / 37600
done  13560 / 37600
done  13570 / 37600
done  13580 / 37600
done  13590 / 37600
done  13600 / 37600
done  13610 / 37600
done  13620 / 37600
done  13630 / 37600
done  13640 / 37600
done  13650 / 37600
done  13660 / 37600
done  13670 / 37600
done  13680 / 37600
done  13690 / 37600
done  13700 / 37600
done  13710 / 37600


done  17360 / 37600
done  17370 / 37600
done  17380 / 37600
done  17390 / 37600
done  17400 / 37600
done  17410 / 37600
done  17420 / 37600
done  17430 / 37600
done  17440 / 37600
done  17450 / 37600
done  17460 / 37600
done  17470 / 37600
done  17480 / 37600
done  17490 / 37600
done  17500 / 37600
done  17510 / 37600
done  17520 / 37600
done  17530 / 37600
done  17540 / 37600
done  17550 / 37600
done  17560 / 37600
done  17570 / 37600
done  17580 / 37600
done  17590 / 37600
done  17600 / 37600
done  17610 / 37600
done  17620 / 37600
done  17630 / 37600
done  17640 / 37600
done  17650 / 37600
done  17660 / 37600
done  17670 / 37600
done  17680 / 37600
done  17690 / 37600
done  17700 / 37600
done  17710 / 37600
done  17720 / 37600
done  17730 / 37600
done  17740 / 37600
done  17750 / 37600
done  17760 / 37600
done  17770 / 37600
done  17780 / 37600
done  17790 / 37600
done  17800 / 37600
done  17810 / 37600
done  17820 / 37600
done  17830 / 37600
done  17840 / 37600
done  17850 / 37600


done  21560 / 37600
done  21570 / 37600
done  21580 / 37600
done  21590 / 37600
done  21600 / 37600
done  21610 / 37600
done  21620 / 37600
done  21630 / 37600
done  21640 / 37600
done  21650 / 37600
done  21660 / 37600
done  21670 / 37600
done  21680 / 37600
done  21690 / 37600
done  21700 / 37600
done  21710 / 37600
done  21720 / 37600
done  21730 / 37600
done  21740 / 37600
done  21750 / 37600
done  21760 / 37600
done  21770 / 37600
done  21780 / 37600
done  21790 / 37600
done  21800 / 37600
done  21810 / 37600
done  21820 / 37600
done  21830 / 37600
done  21840 / 37600
done  21850 / 37600
done  21860 / 37600
done  21870 / 37600
done  21880 / 37600
done  21890 / 37600
done  21900 / 37600
done  21910 / 37600
done  21920 / 37600
done  21930 / 37600
done  21940 / 37600
done  21950 / 37600
done  21960 / 37600
done  21970 / 37600
done  21980 / 37600
done  21990 / 37600
done  22000 / 37600
done  22010 / 37600
done  22020 / 37600
done  22030 / 37600
done  22040 / 37600
done  22050 / 37600


done  25690 / 37600
done  25700 / 37600
done  25710 / 37600
done  25720 / 37600
done  25730 / 37600
done  25740 / 37600
done  25750 / 37600
done  25760 / 37600
done  25770 / 37600
done  25780 / 37600
done  25790 / 37600
done  25800 / 37600
done  25810 / 37600
done  25820 / 37600
done  25830 / 37600
done  25840 / 37600
done  25850 / 37600
done  25860 / 37600
done  25870 / 37600
done  25880 / 37600
done  25890 / 37600
done  25900 / 37600
done  25910 / 37600
done  25920 / 37600
done  25930 / 37600
done  25940 / 37600
done  25950 / 37600
done  25960 / 37600
done  25970 / 37600
done  25980 / 37600
done  25990 / 37600
done  26000 / 37600
done  26010 / 37600
done  26020 / 37600
done  26030 / 37600
done  26040 / 37600
done  26050 / 37600
done  26060 / 37600
done  26070 / 37600
done  26080 / 37600
done  26090 / 37600
done  26100 / 37600
done  26110 / 37600
done  26120 / 37600
done  26130 / 37600
done  26140 / 37600
done  26150 / 37600
done  26160 / 37600
done  26170 / 37600
done  26180 / 37600


done  29930 / 37600
done  29940 / 37600
done  29950 / 37600
done  29960 / 37600
done  29970 / 37600
done  29980 / 37600
done  29990 / 37600
done  30000 / 37600
done  30010 / 37600
done  30020 / 37600
done  30030 / 37600
done  30040 / 37600
done  30050 / 37600
done  30060 / 37600
done  30070 / 37600
done  30080 / 37600
done  30090 / 37600
done  30100 / 37600
done  30110 / 37600
done  30120 / 37600
done  30130 / 37600
done  30140 / 37600
done  30150 / 37600
done  30160 / 37600
done  30170 / 37600
done  30180 / 37600
done  30190 / 37600
done  30200 / 37600
done  30210 / 37600
done  30220 / 37600
done  30230 / 37600
done  30240 / 37600
done  30250 / 37600
done  30260 / 37600
done  30270 / 37600
done  30280 / 37600
done  30290 / 37600
done  30300 / 37600
done  30310 / 37600
done  30320 / 37600
done  30330 / 37600
done  30340 / 37600
done  30350 / 37600
done  30360 / 37600
done  30370 / 37600
done  30380 / 37600
done  30390 / 37600
done  30400 / 37600
done  30410 / 37600
done  30420 / 37600


done  34110 / 37600
done  34120 / 37600
done  34130 / 37600
done  34140 / 37600
done  34150 / 37600
done  34160 / 37600
done  34170 / 37600
done  34180 / 37600
done  34190 / 37600
done  34200 / 37600
done  34210 / 37600
done  34220 / 37600
done  34230 / 37600
done  34240 / 37600
done  34250 / 37600
done  34260 / 37600
done  34270 / 37600
done  34280 / 37600
done  34290 / 37600
done  34300 / 37600
done  34310 / 37600
done  34320 / 37600
done  34330 / 37600
done  34340 / 37600
done  34350 / 37600
done  34360 / 37600
done  34370 / 37600
done  34380 / 37600
done  34390 / 37600
done  34400 / 37600
done  34410 / 37600
done  34420 / 37600
done  34430 / 37600
done  34440 / 37600
done  34450 / 37600
done  34460 / 37600
done  34470 / 37600
done  34480 / 37600
done  34490 / 37600
done  34500 / 37600
done  34510 / 37600
done  34520 / 37600
done  34530 / 37600
done  34540 / 37600
done  34550 / 37600
done  34560 / 37600
done  34570 / 37600
done  34580 / 37600
done  34590 / 37600
done  34600 / 37600


In [10]:
trainData = [torch.FloatTensor(data) for data in trainDataOrigin]
outputData = [torch.FloatTensor(data) for data in outputOrigin]

In [11]:
trainData = [data.reshape(len(data), 1, -1) for data in trainData]

outputData = [[data.reshape(1).long() for data in line] for line in outputData]

# Training

In [12]:
criterion = nn.NLLLoss()

In [17]:
learning_rate = 0.005

def train(tokenTensor, outputTensor):
    hidden = rnn.initHidden()
    rnn.zero_grad()

    for i in range(tokenTensor.size()[0]):
        output, hidden = rnn(tokenTensor[i], hidden)
        #print(output)
        #print(outputTensor[i].reshape(1, n_output).long())
        loss = criterion(output, outputTensor[i].reshape(1).long())
        loss.backward(retain_graph=True)

    for p in rnn.parameters():
        p.data.add_(-learning_rate, p.grad.data)
    
    return output, loss.item()

In [14]:
n_iters = 100000
print_every = 50
plot_every = 10

In [15]:
current_loss = 0
all_losses = []

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

start = time.time()

In [20]:
idx = 0
for iter in range(1, n_iters + 1):
    tokenTensor = trainData[idx]
    outputTensor = outputData[idx]
    output, loss = train(tokenTensor, outputTensor)
    current_loss += loss
    
    if iter % print_every == 0:
        print('%d %d%% (%s) %.4f' % (iter, iter / n_iters * 100, timeSince(start), loss))
    
    if iter % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        current_loss = 0
        
    print(iter)
    idx = (idx + 1) % len(trainData)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50 0% (5m 0s) 0.0005
50
51
52
53
54
55
56
57
58
59
60
61


KeyboardInterrupt: 

In [24]:
correct_result = {
    0: False,
    1: True
}
def correctFromOutput(output):
    top_n, top_i = output.topk(1)
    correct_i = top_i[0].item()
    return correct_result[correct_i], correct_i

In [32]:
hidden = torch.zeros(1, n_hidden)
for i in range(trainData[0].size()[0]):
    result, hidden = rnn(trainData[0][i], hidden)
    #print(result)
    print('predict: ', correctFromOutput(result), 'actual: ', correct_result[int(outputData[0][i])], 'diff: ', float(result[0][1] - result[0][0]))

predict:  (True, 1) actual:  True diff:  1.570624589920044
predict:  (True, 1) actual:  True diff:  1.8894290924072266
predict:  (True, 1) actual:  True diff:  1.36505126953125
predict:  (True, 1) actual:  True diff:  1.387604832649231
predict:  (True, 1) actual:  True diff:  1.1451308727264404
predict:  (True, 1) actual:  True diff:  1.1784156560897827
predict:  (True, 1) actual:  True diff:  1.1185340881347656
predict:  (True, 1) actual:  True diff:  1.1479504108428955
predict:  (True, 1) actual:  True diff:  1.166602373123169
predict:  (True, 1) actual:  True diff:  1.2617146968841553
predict:  (True, 1) actual:  False diff:  1.285836935043335
predict:  (True, 1) actual:  True diff:  1.6597962379455566
predict:  (True, 1) actual:  True diff:  1.661882758140564
predict:  (True, 1) actual:  True diff:  1.3852134943008423
predict:  (True, 1) actual:  True diff:  1.2588224411010742
predict:  (True, 1) actual:  True diff:  1.1229243278503418
predict:  (True, 1) actual:  True diff:  1.156

predict:  (True, 1) actual:  True diff:  4.976753234863281
predict:  (True, 1) actual:  True diff:  5.278386116027832
predict:  (True, 1) actual:  True diff:  5.148708343505859
predict:  (True, 1) actual:  True diff:  5.350506782531738
predict:  (True, 1) actual:  True diff:  4.442850112915039
predict:  (True, 1) actual:  True diff:  4.71136474609375
predict:  (True, 1) actual:  True diff:  4.910245895385742
predict:  (True, 1) actual:  True diff:  4.954710960388184
predict:  (True, 1) actual:  True diff:  5.377714157104492
predict:  (True, 1) actual:  True diff:  5.255571365356445
predict:  (True, 1) actual:  True diff:  5.469278335571289
predict:  (True, 1) actual:  True diff:  4.523590087890625
predict:  (True, 1) actual:  True diff:  4.834252834320068
predict:  (True, 1) actual:  True diff:  5.460104465484619
predict:  (True, 1) actual:  True diff:  5.329298496246338
predict:  (True, 1) actual:  True diff:  5.277484893798828
predict:  (True, 1) actual:  True diff:  4.90444564819335