In [1]:
import utilities as u

import torch
from torch_geometric.data import Data, InMemoryDataset, DataLoader
import torch.nn as nn

import pickle
import random

torch.manual_seed(42)

<torch._C.Generator at 0x7f678c095390>

In [2]:
# EXPNAME = 'test_adding_edge_features'
EXPNAME = 'bug_fix'
binary = True
only_top = True

# Tensorboard Plotting 

In [3]:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(f'runs/{EXPNAME}')

# Define Dataset

In [4]:
from tqdm import tqdm

class TopLevelProofDataset(InMemoryDataset):
    def __init__(self, root='', transform=None, pre_transform=None):
        super(TopLevelProofDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
        
    @property
    def raw_file_names(self):
        return []
    
    @property
    def processed_file_names(self):
        return [f'../bug_fix.dataset']
    
    def download(self):
        pass
    
    def process(self):
        global data
        data_list = []
        all_features = set()
        trees = []
        
        for thm, y in tqdm(data):
            thm = u.process_theorem(thm)
            tree, distinct_features = u.thm_to_tree(thm)
            all_features = all_features | distinct_features
            trees.append((tree, y))
        
        normalized_features = {k: [random.random() for i in range(128)] for k in list(all_features)}
            
        for tree, y in tqdm(trees):
            merged_tree = u.merge_subexpressions(tree)
            x, (edge_index_up, edge_index_down), (edge_features_up, edge_features_down) = u.graph_to_data(tree, 
                                                                                                           normalized_features)
            data = Data(x=x, 
                        y=y, 
                        edge_index=torch.cat((edge_index_up, edge_index_down), dim=1),
                        edge_attr=torch.cat((edge_features_up, edge_features_down)),
                       )
            data_list.append(data)
            
        
#         for thm, y in tqdm(data):
#             thm = u.process_theorem(thm)
#             tree, distinct_features = u.thm_to_tree(thm)
#             normalized_features = {k: [random.random() for i in range(16)] for k in list(distinct_features)}
#             tree = u.merge_subexpressions(tree)
            
# #             x, edge_index = u.graph_to_data(tree, list(distinct_features))
#             x, edge_index = u.graph_to_data(tree, normalized_features)
            
#             data = Data(x=x, edge_index=edge_index, y=y)
#             data_list.append(data)
        
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

# SAGEConv Layer

In [5]:
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops

class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SAGEConv, self).__init__(aggr='mean') #  "Max" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False)
        self.update_act = torch.nn.ReLU()
        
    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        
        
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        # x_j has shape [E, in_channels]

        x_j = self.lin(x_j)
        x_j = self.act(x_j)
        
        return x_j

    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]


        new_embedding = torch.cat([aggr_out, x], dim=1)
        
        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)
        
        return new_embedding

# GNN definition

In [46]:
embed_dim = 128
from torch_geometric.nn import TopKPooling, GCNConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch.nn.functional as F
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        self.conv1 = SAGEConv(dataset.num_features, embed_dim)
#         self.conv1 = GCNConv(embed_dim, 128)
        self.embedding = torch.nn.Embedding(num_embeddings=len(distinct_features)+1, embedding_dim=embed_dim)
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.conv2 = SAGEConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.8)
        self.conv3 = SAGEConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.8)
        self.lin1 = torch.nn.Linear(256, 128)
        self.lin2 = torch.nn.Linear(128, 64)
        self.lin3 = torch.nn.Linear(64, 11)
        self.lin4 = torch.nn.Linear(64, 1)
        self.bn1 = torch.nn.BatchNorm1d(128)
        self.bn2 = torch.nn.BatchNorm1d(64)
        self.act1 = torch.nn.ReLU()
        self.act2 = torch.nn.ReLU()  
  
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch    

        x = self.conv1(x, edge_index)
        x = F.relu(x)

        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
#         x = x1

        x = F.relu(self.conv2(x, edge_index))
     
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv3(x, edge_index))

        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = x1 + x2 + x3

        x = self.lin1(x)
        x = self.act1(x)
        x = self.lin2(x)
        x = self.act2(x)      
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.lin4(x))

#         x = F.log_softmax(self.lin3(x), dim=1).squeeze(1)


        return x

# Model 2 (Subgraph Pooling Paper)

In [30]:
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops

class PaliwalMP(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(PaliwalMP, self).__init__(aggr='mean', flow='target_to_source') #  "Mean" aggregation.
        
        # MLP for Parents and Children, step 2 of Paliwal MP
        self.MLP_edge = BuildingBlock(3*128, 128)
        self.MLP_edge_hat = BuildingBlock(3*128, 128)
        
        # MLP to pass aggregated message through, step 3 of Paliwal MP
        self.MLP_aggr = BuildingBlock(3*128, 128)

        
    def forward(self, x, edge_index_parents, edge_index_children, edge_attr_parents, edge_attr_children):
        # x has shape [N, in_channels]
        # edge_index_x has shape [2, E/2]
        out_parents = self.propagate(edge_index_parents, 
                                     x=x, 
                                     edge_attr=edge_attr_parents, 
                                     direction='up', 
                                     size=None)
        out_children = self.propagate(edge_index_children, 
                                      x=x, 
                                      edge_attr=edge_attr_children, 
                                      direction='down', 
                                      size=None)
        
        out = torch.cat([x, out_parents, out_children], dim=1)
        out = self.MLP_aggr(out) + x
        
        
        return out

    
    def message(self, x_i, x_j, edge_attr, direction):

        s_ij = torch.cat([x_i, x_j, edge_attr], dim=1)
        if direction == 'up':
            s_ij = self.MLP_edge(s_ij)
        elif direction == 'down':
            s_ij = self.MLP_edge_hat(s_ij)
        
        return s_ij

    
    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]
        return aggr_out


In [76]:
from torch_geometric.nn import global_max_pool as gmp

embed_dim = 128

# TODO: Apply dropout

class BuildingBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, dim=0):
        super(BuildingBlock, self).__init__()
        self.lin1 = Linear(in_channels, 256, dim)
        self.hidden = Linear(256, 128)
        self.lin2 = Linear(128, 128)
        
    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = F.relu(self.hidden(x))
        x = F.dropout(x, 0.5)
        x = F.relu(self.lin2(x))
        
        return x
    

class PaliwalNet(torch.nn.Module):
    def __init__(self, t):
        super(PaliwalNet, self).__init__()
        self.MLP_V = BuildingBlock(embed_dim, 128)
        self.MLP_E = BuildingBlock(1, 128)
        
        self.message_passing_steps = nn.ModuleList()
        for i in range(t):
            self.message_passing_steps.append(PaliwalMP(embed_dim, embed_dim))
            
        self.conv1 = nn.Conv1d(128, 512, (1,1))
        self.conv2 = nn.Conv1d(512, 1024, (1,1))
        
        # FCNN for final prediction
        self.lin1 = Linear(1024, 512)
        self.lin2 = Linear(512, 512)
        self.lin3 = Linear(512, 256)
        self.lin4 = Linear(256, 256)
        self.lin5 = Linear(256, 128)
        self.lin6 = Linear(128, 1)

  
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        edge_index_u, edge_index_d = torch.split(edge_index, int(edge_index.shape[1]/2), dim=1)
        edge_attr_u, edge_attr_d = torch.split(edge_attr, int(edge_attr.shape[0]/2))
        
        # Embed node and edge features into high dimensional space
        x = self.MLP_V(x)
        edge_attr_u = self.MLP_E(edge_attr_u.float())
        edge_attr_d = self.MLP_E(edge_attr_d.float())
        
        for message_passing_step in self.message_passing_steps:
            x = message_passing_step(x, edge_index_u, edge_index_d, edge_attr_u, edge_attr_d)
        
        x = x.unsqueeze(-1).unsqueeze(-1)
        x = self.conv1(x)
        x = self.conv2(x)
        
        # Final prediction network
        g = gmp(x, batch)
        g = g.squeeze(-1).squeeze(-1)
        g = F.relu(self.lin1(g))
        g = F.relu(self.lin2(g))
        g = F.relu(self.lin3(g))
        g = F.relu(self.lin4(g))
        g = F.relu(self.lin5(g))
        g = F.relu(self.lin6(g))
        
        return x

In [77]:
pnet = PaliwalNet(t=1)
net = Net()
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=False, num_workers=1)

for i, data in enumerate(train_loader):
    pnet(data)
    break
    
for i, data in enumerate(train_loader):
    net(data)
    break

torch.Size([439, 128, 1, 1])
torch.Size([439, 1024, 1, 1])
torch.Size([16, 1]) tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]], grad_fn=<ReluBackward0>)
hi torch.Size([439, 128])
torch.Size([16, 256])


# Data inspections

In [9]:
new_data = True

if new_data == True:
    data = u.make_data(binary=binary, only_top=only_top)
    with open(EXPNAME, 'wb') as outfile:
        pickle.dump(data, outfile)
else:
    with open(EXPNAME, 'rb') as infile:
        data = pickle.load(infile)
        
        
# data = data[0:2]

100%|██████████| 10/10 [00:07<00:00,  1.32it/s]


In [10]:
distinct_features = set()

for idx, (thm, _) in enumerate(data):
    if idx % 1000 == 0:
        print(f'{idx} / {len(data)}')
    thm = u.process_theorem(thm)
    thm_tree, features = u.thm_to_tree(thm)
    distinct_features = distinct_features.union(features)
len(distinct_features)

0 / 1427
1000 / 1427


350

In [11]:
distinct_features = set(i for i in range(len(distinct_features)))

In [12]:
test_thm = '(fun (a A B) (a A (a A B)))'
print(test_thm)
thm = u.process_theorem(test_thm)
print(thm)
thm_tree, _ = u.thm_to_tree(thm)
print(len(thm_tree))
print(thm_tree.subtrees[0].parents[0])
thm_tree = u.merge_subexpressions(thm_tree)
x = u.graph_to_data(thm_tree)
print(x)

#print([t.root for t in thm_tree.subtrees[0].subtrees])


print(thm_tree.root)
print([t.root for t in thm_tree.subtrees])
t_0, t_1 = thm_tree.subtrees
print([t.root for t in t_0.subtrees])
print([t.root for t in t_1.subtrees])
print(t_1.subtrees[0].subtree_str)
print(len(thm_tree))

(fun (a A B) (a A (a A B)))
['(', 'fun', '(', 'a', 'A', 'B', ')', '(', 'a', 'A', '(', 'a', 'A', 'B', ')', ')', ')']
9
Tree(root=(fun, index=None), parents=None, size=9)
5
(['fun', 'a', 'A', 'B', 'a'], (tensor([[1, 1, 3, 3, 4, 2],
        [0, 2, 1, 2, 1, 0]]), tensor([[0, 0, 1, 1, 2, 2],
        [2, 1, 4, 3, 1, 3]])), (tensor([[0],
        [0],
        [0],
        [0],
        [1],
        [1]]), tensor([[1],
        [0],
        [1],
        [0],
        [1],
        [0]])))
(fun, index=0)
[(a, index=1), (a, index=2)]
[(A, index=3), (B, index=4)]
[(A, index=3), (a, index=1)]
A
5


In [13]:
counter = dict()
for _, y in data:
    if y in counter:
        counter[y] += 1
    else:
        counter[y] = 1
counter = list(counter.items())
counter.sort(key=lambda x: x[0], reverse=False)
percentages = [(x, y/len(data)*100) for x,y in counter]
percentages

[(0, 45.05956552207428), (1, 54.94043447792571)]

# Create Dataset

In [14]:
from math import floor

dataset = TopLevelProofDataset()
dataset.shuffle()

train_dataset = dataset[:floor(len(dataset)/2)]
valid_dataset = dataset[floor(len(dataset)/2) : 3*floor(len(dataset)/2)]
test_dataset = dataset[3*floor(len(dataset)/2):]

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False, num_workers=8)
print(train_dataset)

  3%|▎         | 39/1427 [00:00<00:03, 384.38it/s]

Processing...


100%|██████████| 1427/1427 [00:38<00:00, 36.87it/s]
  5%|▌         | 73/1427 [00:00<00:04, 282.32it/s]

25
15
27
26
31
31
24
31
45
25
34
22
23
17
42
21
44
46
21
49
21
38
36
36
39
38
37
39
38
37
37
38
37
36
36
36
39
38
29
28
30
32
40
46
37
16
40
20
20
41
15
17
40
34
28
21
18
19
31
31
21
19
37
19
36
36
19
39
19
37
37
41
54
54
40
41
13
62
47


 11%|█         | 151/1427 [00:00<00:04, 313.69it/s]

55
31
44
60
41
17
21
28
19
19
21
13
24
24
23
28
27
63
49
25
25
20
35
47
31
22
24
18
16
21
22
22
27
27
29
24
31
9
21
39
29
35
29
30
21
33
24
23
25
21
30
21
37
58
34
34
46
20
46
58
191
49
27
50
70
52
61
21
64
79
21
83
84
49
17
66
25


 13%|█▎        | 180/1427 [00:00<00:04, 251.00it/s]

39
59
14
63
30
65
34
43
37
37
24
27
31
24
32
33
27
37
34
43
48
36
36
46
37
38
29
36
39
33
32
32
24
19
34
32
36
36


 16%|█▌        | 229/1427 [00:00<00:05, 209.37it/s]

38
33
29
36
32
32
29
25
33
40
25
25
38
39
38
49
51
49
49
50
51
54
21
32
57
48
24
46
25
25
13
33
35
46
27
36
30
34
33
62
49


 18%|█▊        | 251/1427 [00:00<00:06, 194.66it/s]

36
33
26
29
39
34
36
37
39
42
44
27
28
48
36
69
49
39
55
54
57
49
68
68
76


 19%|█▉        | 272/1427 [00:01<00:07, 158.12it/s]

57
25
60
51
29
25
44
52
43
53
56
56
66
30
55
61
56
37
41
55
57
39
57
70
43
30
75
54
33
52
33


 22%|██▏       | 309/1427 [00:01<00:08, 127.07it/s]

62
51
61
83
51
71
69
30
74
36
103
54
47
19
45
89
93
79
79


 23%|██▎       | 324/1427 [00:01<00:08, 126.66it/s]

106
41
108
56
73
22
42
16
39
58
21
64
64
41
26
57
126
134
57
128
124
135


 26%|██▌       | 367/1427 [00:01<00:07, 147.81it/s]

30
40
44
38
14
196
28
35
31
14
212
25
41
34
14
68
105
67
14
142
47
33
39
49
14
14
47
35
34
42
50
24
14
39
39
42
40
21
17
33
54
21
36
434
148
24
63
86
43
50
37
38


 29%|██▉       | 417/1427 [00:02<00:05, 185.24it/s]

47
101
38
79
20
150
38
68
78
13
12
83
46
67
66
37
30
40
52
73
67
35
40
35
21
54
36
40
50
39
60
61
42
58
22
60
73
54
49
56
45
58
48
45
43
47
46
61
58
46
36
58
77
43
38
25
64
75
64
18
42
47
56
46
72
47


 34%|███▍      | 483/1427 [00:02<00:04, 232.87it/s]

68
41
37
23
62
23
74
68
33
35
49
37
53
36
32
33
44
36
45
47
59
36
87
76
43
57
60
64
48
63
59
50
59
15
27
34
36
47
15
35
74
13
28
37
49
34
32
25
39
39
34
37
38
39
29
38
39
39


 38%|███▊      | 538/1427 [00:02<00:03, 240.17it/s]

60
54
58
52
71
69
84
114
104
100
87
51
23
40
38
38
32
32
32
49
16
28
30
38
40
28
28
35
36
35
36
16
37
28
49
52
38
40
40
50
50
33
74
17
17
17
22
58
40


 42%|████▏     | 598/1427 [00:02<00:03, 221.68it/s]

37
32
32
20
30
32
32
47
47
37
25
31
21
35
31
35
73
45
62
36
52
41
55
35
28
28
38
39
48
44
30
44
43
39
51
48
49
32
40
159
28
16


 44%|████▎     | 623/1427 [00:02<00:04, 179.66it/s]

101
42
53
43
37
37
39
41
56
40
38
43
43
46
30
20
38
30
30
32
38
38
38
43
24
38
49
49
47
42
49
29
39
38
37
31


 47%|████▋     | 676/1427 [00:03<00:03, 205.98it/s]

39
44
44
44
41
41
41
32
45
35
51
33
32
36
40
26
29
39
39
31
32
37
37
23
32
46
30
42
19
39
44
40
48
50
32
33
40
40
34
33
27
27
41
41
43
43
34
34
50


 49%|████▉     | 699/1427 [00:03<00:04, 156.27it/s]

45
45
55
55
56
55
50
20
37
37
24
37
26
40
48
51
24
32
45
36
45


 50%|█████     | 718/1427 [00:03<00:04, 147.22it/s]

47
50
39
56
57
58
56
55
59
70
60
58
59
59
58
44
56
67
44
31
49
49
49
51
51
41
68
55
58
57


 53%|█████▎    | 754/1427 [00:03<00:04, 149.72it/s]

70
61
63
50
59
58
48
83
40
40
53
52
50
63
69
60
72
65
50
50
50
50
72
44
28
53
23
27
41
24
32
32


 56%|█████▌    | 794/1427 [00:04<00:03, 160.14it/s]

69
132
42
20
53
70
49
73
70
44
31
63
60
58
67
68
87
25
34
25
92
59
38
15
48
13
19
35
30
39
39
30
39
42
49
36
52
19
53
19
41


 58%|█████▊    | 831/1427 [00:04<00:03, 157.51it/s]

38
57
44
29
76
13
36
55
83
31
21
70
35
31
70
35
73
45
62
36
56
41
55
31
159


 59%|█████▉    | 848/1427 [00:04<00:04, 142.30it/s]

92
16
85
74
29
25
72
95
21
24
75
41
57
65
19
34
43
192
76
21
43
43
46
46
37
63
56
60


 64%|██████▍   | 910/1427 [00:04<00:03, 169.70it/s]

61
43
17
17
37
36
35
27
27
40
37
40
20
22
51
35
28
35
36
41
64
38
35
27
55
38
35
28
40
36
42
38
45
26
27
24
26
28
26
27
51
37
35
23
33
39
50
48
50
34
51
46
23
36
50
45
54
50
44
41
23
33
59
44
112
36
39
52
55


 66%|██████▌   | 937/1427 [00:04<00:02, 185.06it/s]

16
29
44
23
30
46
49
50
26
60
77
90
79
45
64
89
59
68
49
33
33
77
285
273
78
259
267
61
58
60
76


 69%|██████▊   | 980/1427 [00:05<00:04, 107.91it/s]

94
19
46
120
111
36
136
99
64
50
82
62
79
47
53
63
75
75
53
87
87
34
36
36
32
32
32
73
98
97
77


 70%|███████   | 1002/1427 [00:05<00:03, 127.35it/s]

76
63
27
64
99
69
105
27
21
27
112
27
72
70
34
13
76
43
25
74
43
25
74
88


 71%|███████▏  | 1020/1427 [00:05<00:03, 109.00it/s]

26
113
62
116
114
112
122
29
68
71
46
72
39
69
65
40
126
36
76


 76%|███████▌  | 1082/1427 [00:06<00:02, 143.81it/s]

81
45
51
50
24
54
70
107
24
24
63
70
25
87
93
66
88
59
61
23
16
78
23
102
29
113
95
102
68
69
66
61
23
21
87
86
62
57
57
43
44
45
47
47
48
62
203


 77%|███████▋  | 1100/1427 [00:06<00:02, 117.59it/s]

274
113
53
51
26
27
17
76
46
39
43
53
57
45
45
10
58
9
25
18
66
52


 78%|███████▊  | 1115/1427 [00:06<00:02, 125.24it/s]

34
45
38
40
35
41
28
75
54
73
82
41
60
42


 79%|███████▉  | 1130/1427 [00:06<00:03, 81.44it/s] 

49
49
54
51
123
99
106
83
99
98
60
128


 80%|████████  | 1142/1427 [00:06<00:03, 76.15it/s]

84
117
105
30
110
114
43
72
59
61
65
64
64
54
71
63
67
59
46
71


 81%|████████▏ | 1163/1427 [00:07<00:03, 77.60it/s]

55
56
70
61
63
63
66
53
69
63
67
27
46


 83%|████████▎ | 1184/1427 [00:07<00:02, 95.55it/s]

82
77
54
56
68
72
56
45
41
36
44
46
33
73
98
53
63
50
51
31
25
43
38
146
19
99
71
43
42
317


 85%|████████▍ | 1210/1427 [00:07<00:02, 99.01it/s]

85
53
109
112
85
41
82
23
21
29
27
86
109
73
82
84
118
101
50
75
147


 86%|████████▋ | 1233/1427 [00:07<00:01, 97.09it/s]

114
72
71
135
55
55
87
103
95
24
10
83
73
70
46
110
59
50
66
78
59
66
60


 89%|████████▊ | 1264/1427 [00:07<00:01, 104.54it/s]

77
40
60
100
59
71
24
84
64
76
88
88
78
44
132
50
138
116
117
27
56
99
119
142
114
98
123
108


 89%|████████▉ | 1276/1427 [00:08<00:01, 100.66it/s]

56
69
57
69
35
79
75
70
68
74
76
72
93
30
10
88
73
70
46
116
64
50
42
60
78
82
82
51
58
58
80
105


 93%|█████████▎| 1327/1427 [00:08<00:00, 145.23it/s]

70
33
72
66
74
85
24
89
70
93
85
93
78
100
44
91
56
152
102
117
58
27
90
90
90
114
129
98
108
61
135
95
102
60
72
51


 94%|█████████▍| 1346/1427 [00:08<00:00, 120.14it/s]

75
74
84
88
83
63
76
82
76
110
81
81
99
86
93
84
128
50
117
23
107
13
43
102
10
40
37
28
45
40
39
45


 98%|█████████▊| 1401/1427 [00:08<00:00, 166.82it/s]

81
36
64
31
25
38
26
30
34
73
64
78
84
31
92
49
55
34
38
35
46
62
59
50
49
55
34
78
21
45
22
81
77
59
69
66
27
73
67
61
46
47
110
48
50
110
50
51
110
71


100%|██████████| 1427/1427 [00:09<00:00, 157.69it/s]

286
111
139
51
81
88
98
117
144
141
34
159





Done!
TopLevelProofDataset(713)


# Train

In [54]:
batch_size = 1
num_epochs = 100

def train():
    global epoch
    model.train()
    
    loss_all = 0
    correct = 0
    for i, data in enumerate(train_loader):
        x = data.x.squeeze(1)

        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        label = torch.unsqueeze(data.y.to(device), 1).float()
#         torch.unsqueeze(label, 1)
#         print(data.shape, output.shape, label.shape)
#         print(label, output)
#         loss = F.nll_loss(output, label)
        loss = F.mse_loss(output, label)
    
        if torch.isnan(loss):
            print(output, label)
        loss.backward()
        loss_all += data.num_graphs * loss.item()
        pred = output.data.max(1, keepdim=True)[1]
        correct_this_time = pred.eq(label.data.view_as(pred)).sum()
        correct += correct_this_time
#         if correct_this_time == 0:
#             print(label, output)
        

        optimizer.step()
#         print(correct.item())
    
    writer.add_scalar('training loss',
                     loss_all / len(train_dataset),
                     epoch)
    
    writer.add_scalar('training accuracy',
                     correct.item() / len(train_dataset),
                     epoch)
    #print(correct.item())
    
    return loss_all / len(train_dataset), correct.item() / len(train_dataset)


def test():
    global epoch
    model.eval()
    
    loss_all = 0
    correct = 0
    for i, data in enumerate(valid_loader):
        x = data.x.squeeze(1)

        data = data.to(device)
        output = model(data)
        label = torch.unsqueeze(data.y.to(device), 1)
        
        loss = F.mse_loss(output, label)
        loss_all += data.num_graphs * loss.item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(label.data.view_as(pred)).sum()
    
    writer.add_scalar('validation loss',
                     loss_all / len(valid_dataset),
                     epoch)
    
    writer.add_scalar('validation accuracy',
                     correct.item() / len(valid_dataset),
                     epoch)
    
    return loss_all / len(valid_dataset), correct.item() / len(valid_dataset)
    
    
# device = torch.device('cuda:2')
device = torch.device('cpu')

model = Net().to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.00005, momentum=0.8)
crit = torch.nn.CrossEntropyLoss()
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

assert 1 == 2

epoch = 0
valid_loss, valid_acc = test()
for epoch in tqdm(range(num_epochs)):
    epoch_loss, epoch_acc = train()
    #print(f'train - loss: {epoch_loss}, acc: {epoch_acc}')
    
    if epoch % 5 == 4:
        valid_loss, valid_acc = test()
#         print(f'valid - loss: {valid_loss}, acc: {valid_acc}')
        
    #print()
        
    

AssertionError: 

In [60]:
def save_model(model, PATH):
    torch.save(model.state_dict(), PATH)
    
def load_model(model_type, PATH):
    model = model_type()
    model.load_state_dict(torch.load(PATH))
    return model

In [61]:
m = load_model(Net, 'bug_fix.pth')
m

Net(
  (conv1): SAGEConv(
    (lin): Linear(in_features=128, out_features=128, bias=True)
    (act): ReLU()
    (update_lin): Linear(in_features=256, out_features=128, bias=False)
    (update_act): ReLU()
  )
  (embedding): Embedding(351, 128)
  (pool1): TopKPooling(128, ratio=0.8, multiplier=1)
  (conv2): SAGEConv(
    (lin): Linear(in_features=128, out_features=128, bias=True)
    (act): ReLU()
    (update_lin): Linear(in_features=256, out_features=128, bias=False)
    (update_act): ReLU()
  )
  (pool2): TopKPooling(128, ratio=0.8, multiplier=1)
  (conv3): SAGEConv(
    (lin): Linear(in_features=128, out_features=128, bias=True)
    (act): ReLU()
    (update_lin): Linear(in_features=256, out_features=128, bias=False)
    (update_act): ReLU()
  )
  (pool3): TopKPooling(128, ratio=0.8, multiplier=1)
  (lin1): Linear(in_features=256, out_features=128, bias=True)
  (lin2): Linear(in_features=128, out_features=64, bias=True)
  (lin3): Linear(in_features=64, out_features=11, bias=True)
  (