In [20]:
import csv
import networkx as nx
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torchvision
from sklearn.linear_model import LogisticRegression

In [2]:
print(torch.cuda.is_available())

True


In [3]:
# prepare dataset

f = open('train.csv', 'r', encoding='utf-8')
reader = csv.reader(f)
trains = []
for line in reader:
    trains.append(line)
f.close()    
print("trains", len(trains))

f = open('node_ingredient.csv', 'r', encoding='utf-8')
reader = csv.reader(f)
node_ingredients = []
for line in reader:
    node_ingredients.append(line)
f.close()    
print("node_ingredients", len(node_ingredients))
num_ing = len(node_ingredients)

# f = open('validation_classification_question.csv', 'r', encoding='utf-8')
# reader = csv.reader(f)
# val_cls_q = []
# for line in reader:
#     val_cls_q.append(line)
# f.close()
# print("val_cls_q", len(val_cls_q))

# f = open('validation_classification_answer.csv', 'r', encoding='utf-8')
# reader = csv.reader(f)
# val_cls_a = []
# for line in reader:
#     val_cls_a.append(line)
# f.close()
# print("val_cls_a", len(val_cls_a))

f = open('validation_completion_question.csv', 'r', encoding='utf-8')
reader = csv.reader(f)
val_cpt_q = []
for line in reader:
    val_cpt_q.append(line)
f.close()
print("val_cpt_q", len(val_cpt_q))

f = open('validation_completion_answer.csv', 'r', encoding='utf-8')
reader = csv.reader(f)
val_cpt_a = []
for line in reader:
    val_cpt_a.append(line)
f.close()
print("val_cpt_a", len(val_cpt_a))

trains 23547
node_ingredients 6714
val_cpt_q 7848
val_cpt_a 7848


## Completion Task : Method 0
find neighbors of largest weight (no use of embedding)

In [62]:
# make graph of ingredients
G = nx.Graph()
for i in range(len(node_ingredients)):
    G.add_node(str(i))
    
print(G.number_of_nodes())

for data in trains:
    for i in range(len(data) - 2):
        for j in range(i+1, len(data) - 1):
            if G.has_edge(data[i], data[j]):
                G[data[i]][data[j]]['weight'] += 1
            else:
                G.add_edge(data[i], data[j], weight=1)

print(G.number_of_edges())

6714
355816


In [67]:
acc = 0
for i, data in tqdm(enumerate(val_cpt_q)):
#     print(data)
    weight_dict = {}
    for node in data:
        for adv, w in G.adj[node].items():
            if adv in weight_dict.keys():
                weight_dict[adv] += w['weight']
            else:
                weight_dict[adv] = w['weight']
    for node in data:
        if node in weight_dict.keys():
            del weight_dict[node]
    
    weight_dict = sorted(weight_dict.items(), key=(lambda x: x[1]), reverse=True)
    
#     print(weight_dict)

    if weight_dict[0][0] == val_cpt_a[i][0]:
        acc += 1
        
print("accuracy: ", acc / len(val_cpt_q) * 100, "%")

7848it [01:20, 97.01it/s] 

accuracy:  6.167176350662589 %





## Completion Task : Method 1
cosine similarity of embedding

In [4]:
def cos_sim(X,y):
    return np.dot(X, y) / (np.linalg.norm(X, axis=1) * np.linalg.norm(y))

In [14]:
def make_embedding(filename, hiddensize):
    f = open(filename, 'r', encoding='utf-8')
    reader = csv.reader(f)
    f.readline()
    embedding = np.zeros((6714, hiddensize))
    for line in reader:
        i = int(line[1])
        j = 0
        for node in line[2][2:-1].split(' '):
            if node != '':
                embedding[i][j] = float(node.strip())
                j += 1

    f.close()
    return embedding

In [15]:
embp1q10 = make_embedding('Embedding/Embp1q10.csv', 64)
svd128 = make_embedding('Embedding/SVD128.csv', 128)
svd64 = make_embedding('Embedding/SVD64.csv', 64)
svd32 = make_embedding('Embedding/SVD32.csv', 32)

In [16]:
# method 1-1: similarity of average

def similarityBased(embedding):
    acc = 0
    for i, data in tqdm(enumerate(val_cpt_q)):
        nodes = [int(node) for node in data]
        avg_node = np.average(embedding[nodes][:], axis=0)
        sims = cos_sim(embedding, avg_node)
        ranks = np.argsort(-sims)

        target = int(val_cpt_a[i][0])

    #     print("input: ", data)
    #     print("target: ", target)
    #     print("estimation: ", ranks[0:10])
    #     print("rank: ", np.where(ranks == target))
    #     print("")

        j = 0
        while ranks[j] in data:
            j += 1

        estimation = ranks[j]

        if estimation == target:
            acc += 1
            
    return acc / len(val_cpt_q)

In [17]:
print(similarityBased(embp1q10))
print(similarityBased(svd128))
print(similarityBased(svd64))
print(similarityBased(svd32))

  
7848it [00:12, 643.49it/s]
43it [00:00, 429.95it/s]

0.002420998980632008


7848it [00:16, 465.30it/s]
66it [00:00, 658.74it/s]

0.004332313965341488


7848it [00:11, 663.53it/s]
75it [00:00, 747.52it/s]

0.01210499490316004


7848it [00:10, 745.96it/s]

0.018603465851172275





In [18]:
# method 1-2: average of similarity

def similarityBased2(embedding):
    acc = 0
    for i, data in tqdm(enumerate(val_cpt_q)):
        nodes = [int(node) for node in data]
        sims = []
        for node in nodes:
            sims.append(cos_sim(embedding, embedding[node]))
        sims = np.stack(sims, axis=0)
        avg_sim = np.average(sims, axis=0)
        ranks = np.argsort(-avg_sim)

        target = int(val_cpt_a[i][0])

        j = 0
        while ranks[j] in data:
            j += 1

        estimation = ranks[j]

        if estimation == target:
            acc += 1

    return acc / len(val_cpt_q)

In [19]:
print(similarityBased2(embp1q10))
print(similarityBased2(svd128))
print(similarityBased2(svd64))
print(similarityBased2(svd32))

  
7848it [01:07, 116.99it/s]
6it [00:00, 56.00it/s]

0.00191131498470948


7848it [01:55, 68.07it/s]
11it [00:00, 105.48it/s]

0.006625891946992864


7848it [01:04, 121.86it/s]
14it [00:00, 132.34it/s]

0.011213047910295617


7848it [00:49, 157.66it/s]

0.014780835881753314





## Logistic Regression

In [108]:
def make_dataset(trainset):
    xs = []
    ys = []
    for data in tqdm(trainset):
        nodes = [int(node) for node in data[:-1]]
        x = np.zeros((len(nodes), num_ing))
        y = np.zeros(len(nodes))
        for i, node in enumerate(nodes):
            other_nodes = [nd for nd in nodes if nd != node]
            for other in other_nodes:
                x[i][other] = 1
                y[i] = node
        xs.append(x)
        ys.append(y)
    train_x = np.concatenate(xs, axis=0)
    train_y = np.concatenate(ys, axis=0)
    return train_x, train_y

In [114]:
train_x, train_y = make_dataset(trains)

100%|██████████| 23547/23547 [00:06<00:00, 3542.06it/s]


In [113]:
clf = LogisticRegression(penalty='l2', max_iter=10, verbose=True).fit(train_x, train_y)
clf.score(train_x, train_y)

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed: 45.0min finished


0.10297918006462584

In [124]:
def make_valset(question, answer):
    val_x = np.zeros((len(question), num_ing))
    val_y = np.zeros(len(question))
    for i in range(len(question)):
        nodes = [int(node) for node in question[i]]
        target = int(answer[i][0])
        for node in nodes:
            val_x[i][node] = 1
        val_y[i] = target
        
    return val_x, val_y

In [125]:
val_x, val_y = make_valset(val_cpt_q, val_cpt_a)

In [127]:
clf.score(val_x, val_y)

0.09174311926605505

## RNN

In [141]:
class CompleteRNN(nn.Module):
    def __init__(self, input_size, n_hidden, n_class):
        super(CompleteRNN, self).__init__()

        self.rnn = nn.RNN(input_size=input_size, hidden_size=n_hidden)
        self.W = nn.Parameter(torch.randn([n_hidden, n_class]).type(torch.float32))
        self.b = nn.Parameter(torch.randn([n_class]).type(torch.float32))
        self.Softmax = nn.Softmax(dim=1)

    def forward(self, hidden, X):
        X = X.transpose(0, 1)
        X = X.type(torch.float32)
        outputs, hidden = self.rnn(X, hidden)
        outputs = outputs[-1]
        model = torch.mm(outputs, self.W) + self.b
#         model = torch.softmax(model, dim=1)
        return model

In [142]:
class TrainDataset(torch.utils.data.Dataset):
    def __init__(self, train_list, embedding):
        self.input = []
        self.target = [] # missing ingredient
        
        for data in train_list:
            nodes = [int(node) for node in data[0:-1]]
            for node in nodes:
                other_nodes = [nd for nd in nodes if nd != node]
                if len(other_nodes) == 0:
                    continue
                examples = []
                for other in other_nodes:
                    examples.append(embedding[other])
                self.input.append(np.stack(examples))
                self.target.append(node)                

    def __len__(self):
        return len(self.input)

    def __getitem__(self, idx):
        inp = self.input[idx]
        target = self.target[idx]
        return inp, target

In [143]:
# train_loader = torch.utils.data.DataLoader(
#     TrainDataset(trains, transform=torchvision.transforms.ToTensor()), 
#     batch_size=32, 
#     shuffle=True
# )

batch_size = 1

train_loader = torch.utils.data.DataLoader(
    TrainDataset(trains, svd32), 
    batch_size=batch_size, 
    shuffle=True
)

In [144]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [148]:
input_size = 32
n_hidden = 64
model = CompleteRNN(input_size=input_size, n_hidden=n_hidden, n_class=6714)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,  milestones=[10000, 20000], gamma=0.1)

for epoch in range(2):
    losses = AverageMeter()
    acces = AverageMeter()
    num_datas = 0
    for i, (input_data, label) in enumerate(train_loader):
#         print(i, data, label)
        hidden = torch.zeros(1, batch_size, n_hidden, requires_grad=True)
        output = model(hidden, input_data)
        loss = criterion(output, label)
        pred = torch.argmax(output, dim=1)
        
        losses.update(loss.item(), input_data.size(0))
        pred = pred.numpy()
        label = label.numpy()
        acces.update((pred == label).sum(), input_data.size(0))
        if i % 100 == 0:
            print(i, losses.avg, acces.sum)
            losses.reset()
            acces.reset()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        

0 15.374687194824219 0
100 19.374585380554198 0
200 19.57853721141815 0
300 19.580387902259826 0
400 17.843226635456084 1
500 18.361026492118835 0
600 16.92524642586708 1
700 18.4481254863739 0
800 17.31158528327942 0
900 16.573138194084166 1
1000 15.243335678577424 1
1100 15.274201422929764 1
1200 14.661488585472107 0
1300 14.738458690643311 0
1400 14.953477597236633 0
1500 14.698786931037903 0
1600 13.769261434078217 0
1700 14.43545117855072 0
1800 13.677596998214721 0
1900 13.282979457378387 0
2000 14.098332224488258 2
2100 14.032453281879425 0
2200 13.52780065536499 0
2300 12.183848358392716 1
2400 12.71117362856865 1
2500 13.017272582054138 0
2600 11.873878622055054 0
2700 11.486643298864365 1
2800 11.903489807844162 1
2900 10.688079844713211 3
3000 11.77183242559433 0
3100 11.3980481672287 1
3200 11.481877042651176 3
3300 10.960198636054992 4
3400 11.449365649223328 1
3500 12.12933185338974 0
3600 10.754436737298965 1
3700 10.920536754131318 1
3800 10.885261986255646 1
3900 11.20

31900 7.154901423454285 3
32000 7.520479341745377 2
32100 7.47623389005661 2
32200 7.64449277639389 3
32300 7.030284390449524 1
32400 7.146286296844482 2
32500 7.208355479240417 2
32600 6.949906814098358 2
32700 6.735969498157501 3
32800 7.851794321537017 3
32900 7.2170648503303525 5
33000 6.785430545806885 6
33100 7.832793092727661 4
33200 6.46053030014038 8
33300 7.065535321235656 4
33400 7.339726450443268 4
33500 7.498548729419708 2
33600 7.583739566802978 2
33700 7.239987273216247 6
33800 7.185142378807068 2
33900 6.688694207668305 7
34000 7.864405128955841 1
34100 7.270276598930359 4
34200 6.864362444877624 5
34300 7.386942417621612 3
34400 7.1820155930519105 5
34500 7.468501472473145 3
34600 6.785965142250061 2
34700 7.401382775306701 1
34800 7.586453430652618 1
34900 7.386287748813629 2
35000 7.1984251022338865 2
35100 6.758884518146515 1
35200 7.062820489406586 6
35300 7.70552941083908 2
35400 6.986703677177429 0
35500 7.211084382534027 4
35600 7.659829244613648 5
35700 7.27214

63400 6.84498865365982 4
63500 6.937641415596008 6
63600 7.241287338733673 5
63700 7.080218646526337 5
63800 6.91431054353714 2
63900 7.217849309444428 4
64000 7.17675950050354 5
64100 7.214565570354462 4
64200 7.319699521064758 6
64300 6.8353430676460265 3
64400 6.751939253807068 3
64500 7.459752638339996 3
64600 7.054292616844177 2
64700 7.223762998580932 1
64800 6.914902451038361 2
64900 6.668531696796418 2
65000 7.3009673738479615 0
65100 7.051951453685761 3
65200 7.228599135875702 4
65300 7.2173068857192995 3
65400 7.1688108730316165 1
65500 7.307386965751648 4
65600 7.187533636093139 5
65700 7.290075631141662 1
65800 6.955978364944458 4
65900 7.100870130062103 2
66000 6.494882409572601 2
66100 7.511677827835083 5
66200 7.174930989742279 2
66300 7.357598750591278 4
66400 7.471115720272064 3
66500 6.724519150257111 2
66600 7.58888436794281 4
66700 7.210755574703216 5
66800 7.478342537879944 2
66900 6.873490688800811 3
67000 7.391270935535431 8
67100 7.230512218475342 6
67200 7.2792

95000 7.176479723453522 3
95100 6.976846332550049 3
95200 7.576539771556854 2
95300 6.975901792049408 3
95400 7.4005149507522585 1
95500 7.497273757457733 2
95600 7.006978697776795 1
95700 7.800937674045563 0
95800 6.828184669017792 1
95900 7.475473091602326 2
96000 6.826198923587799 5
96100 7.374911706447602 4
96200 7.6348898267745975 5
96300 7.316362731456756 1
96400 6.560499639511108 5
96500 6.6873007988929745 3
96600 7.056467182636261 3
96700 6.972753264904022 1
96800 7.245385429859161 7
96900 7.548638846874237 1
97000 6.66304217338562 5
97100 7.365385890007019 2
97200 6.627035963535309 1
97300 7.537740443944931 3
97400 6.784115636348725 3
97500 6.874075889587402 4
97600 6.982552983760834 3
97700 6.850386300086975 3
97800 7.731203074455261 3
97900 7.0695641040802 0
98000 6.498878397941589 4
98100 7.128067145347595 5
98200 6.737031710147858 2
98300 6.748430461883545 4
98400 6.838670542240143 4
98500 7.150673258304596 5
98600 6.95179500579834 3
98700 6.886696164608002 9
98800 7.40716

125600 7.15909209728241 2
125700 6.669286494255066 8
125800 6.56604311466217 5
125900 6.966272072792053 0
126000 7.035099375247955 1
126100 6.930219621658325 7
126200 6.628274819850922 4
126300 7.268741466999054 0
126400 7.250357401371002 4
126500 7.561825714111328 5
126600 6.955405013561249 5
126700 6.641995096206665 5
126800 6.988300197124481 1
126900 7.180316619873047 2
127000 7.214948122501373 4
127100 6.75256276845932 6
127200 7.188424654006958 6
127300 7.079657306671143 1
127400 7.787568476200104 1
127500 7.449623901844024 2
127600 6.862340300083161 3
127700 7.087326910495758 6
127800 7.300344631671906 6
127900 6.927349505424499 3
128000 6.941821885108948 4
128100 7.274910001754761 1
128200 7.013846671581268 6
128300 6.573863279819489 4
128400 6.775543677806854 1
128500 7.117142689228058 3
128600 7.6337146162986755 3
128700 6.665072214603424 5
128800 7.033474688529968 4
128900 7.101547629833221 6
129000 7.57883661031723 4
129100 7.004908268451691 2
129200 7.218770534992218 6
1293

156000 6.6828845334053035 7
156100 7.030652022361755 2
156200 7.229731116294861 4
156300 6.684633775949478 7
156400 6.4416193008422855 7
156500 6.873400385379791 3
156600 7.253653297424316 1
156700 7.035288581848144 4
156800 7.261496107578278 5
156900 6.882407443523407 5
157000 6.76620224237442 6
157100 6.9581289911270146 4
157200 6.863882222175598 2
157300 7.229012839794159 4
157400 7.00016419172287 1
157500 6.963833169937134 5
157600 7.239280152320862 3
157700 6.584468743801117 4
157800 6.992405257225037 5
157900 7.083361852169037 3
158000 6.942444279193878 3
158100 6.806631968021393 4
158200 7.172645940780639 3
158300 6.805762138366699 4
158400 7.338484191894532 5
158500 6.736437065601349 7
158600 6.745430123806 5
158700 6.859397387504577 3
158800 6.845731616020203 3
158900 6.4711945796012875 2
159000 6.9007844686508175 0
159100 6.450087802410126 6
159200 7.0610861444473265 8
159300 7.677125966548919 2
159400 6.641911776065826 2
159500 6.694877696037293 6
159600 6.449976966381073 4


186400 7.124929316043854 4
186500 7.203703422546386 2
186600 6.638663942813873 1
186700 7.013517305850983 5
186800 7.19395370721817 3
186900 7.017847626209259 1
187000 6.749518995285034 5
187100 6.666532940864563 5
187200 7.027728250026703 2
187300 6.844066898822785 4
187400 6.919526498317719 1
187500 7.072970304489136 3
187600 7.5153495621681214 1
187700 7.105697684288025 1
187800 7.100562314987183 3
187900 6.762823603153229 2
188000 6.4675081038475035 8
188100 7.0657967400550845 6
188200 7.137330403327942 2
188300 6.968423550128937 4
188400 7.093355386257172 3
188500 6.8497523021698 1
188600 6.593829231262207 6
188700 7.397594966888428 2
188800 7.214577341079712 1
188900 6.640513589382172 8
189000 6.744074394702912 7
189100 6.896816260814667 4
189200 7.1515656042099 3
189300 7.100028502941131 1
189400 7.210868334770202 5
189500 7.004195370674133 3
189600 6.79760635137558 5
189700 6.712225515842437 2
189800 6.910598680973053 4
189900 6.4137613749504085 2
190000 6.790923452377319 4
190

216800 6.890592367649078 5
216900 7.213592801094055 2
217000 6.7220820045471195 3
217100 6.618222761154175 0
217200 7.147051985263825 4
217300 6.84053328037262 3
217400 6.937117474079132 4
217500 6.664889032840729 2
217600 6.5004031944274905 2
217700 6.611577196121216 6
217800 6.602038621902466 1
217900 7.446102578639984 3
218000 6.550645220279693 6
218100 6.752140793800354 2
218200 6.76152679681778 6
218300 6.927421059608459 5
218400 6.829293096065522 1
218500 6.577257845401764 3
218600 6.867750985622406 4
218700 7.1766777205467225 3
218800 6.922265963554382 2
218900 6.9491413617134095 7
219000 7.324111590385437 1
219100 7.299522781372071 3
219200 6.888962273597717 11
219300 7.3970811295509336 3
219400 6.641555380821228 2
219500 7.090563650131226 3
219600 7.123720602989197 2
219700 6.68016058921814 5
219800 6.888676352500916 2
219900 6.814228184223175 4
220000 6.757257196903229 1
220100 6.82343067407608 3
220200 7.224719877243042 2
220300 7.436165442466736 3
220400 6.504537997245788 2

247200 6.722720863819123 3
247300 7.124034576416015 3
247400 6.719180901050567 2
247500 7.114870467185974 3
247600 7.237853770256042 3
247700 7.318141808509827 4
247800 7.091161279678345 3
247900 6.843859782218933 5
248000 6.849147951602935 4
248100 6.714320156574249 3
248200 6.5040846753120425 3
248300 7.014209027290344 2
248400 6.993542730808258 4
248500 6.151724820137024 6
248600 6.723886730670929 6
248700 6.58050999879837 4
248800 7.045079820156097 0
248900 7.188861124515533 6
249000 7.197810592651368 2
249100 6.392468161582947 5
249200 6.63064945936203 3
249300 6.388994712829589 5
249400 7.546133198738098 0
249500 6.8527153205871585 3
249600 6.925460107326508 2
249700 7.431050243377686 2
249800 6.894740147590637 3
249900 6.500521609783172 5
250000 6.871890590190888 3
250100 6.526388771533966 5
250200 6.898290176391601 3
250300 6.95222172498703 6
250400 7.123258528709411 2
250500 6.607738890647888 1
250600 7.50290896654129 0
250700 6.754002406597137 4
250800 6.745398147106171 5
250

25400 6.36045168876648 4
25500 6.7662415599823 1
25600 6.878597643375397 5
25700 6.737729278802871 4
25800 7.186935398578644 6
25900 6.920461373329163 6
26000 7.119861382246017 4
26100 6.8962046241760255 3
26200 6.377115609645844 7
26300 6.552403070926666 4
26400 6.558730070590973 2
26500 6.830173580646515 7
26600 7.0345120525360105 1
26700 6.835423791408539 2
26800 6.8103675365448 2
26900 7.033551144599914 3
27000 6.913795547485352 4
27100 6.602837293148041 7
27200 7.060369131565094 4
27300 7.15983368396759 2
27400 7.3563861608505245 4
27500 6.939527454376221 3
27600 6.486365480422974 4
27700 6.50501496553421 3
27800 6.965613231658936 3
27900 7.00997832775116 4
28000 6.457216980457306 5
28100 6.449131472110748 3
28200 6.690820724964142 2
28300 6.954440026283264 3
28400 6.521247413158417 1
28500 6.23460964679718 6
28600 6.934534778594971 4
28700 7.051971898078919 4
28800 6.6717229437828065 4
28900 6.815408625602722 6
29000 6.994308404922485 6
29100 6.92110965013504 5
29200 6.9574405908

57000 7.247671089172363 1
57100 7.016212604045868 3
57200 6.695577783584595 2
57300 6.6053338360786436 4
57400 6.675328862667084 5
57500 6.883283214569092 8
57600 6.4283315944671635 4
57700 6.574748051166535 3
57800 6.4294608688354495 4
57900 6.77347011089325 3
58000 7.081472985744476 2
58100 6.622395210266113 1
58200 6.334725694656372 3
58300 6.628086977005005 4
58400 6.969429702758789 3
58500 6.867393698692322 4
58600 6.603544621467591 8
58700 6.8656825041770935 4
58800 7.03664217710495 4
58900 6.367851567268372 7
59000 7.171080996990204 1
59100 7.333765139579773 4
59200 7.031007368564605 4
59300 6.477743167877197 6
59400 6.976366124153137 4
59500 7.184284851551056 0
59600 6.898500185012818 7
59700 7.100098612308503 7
59800 7.10104252576828 0
59900 6.9393627643585205 6
60000 6.798528573513031 6
60100 7.0814039063453675 4
60200 6.616687965393067 3
60300 6.813712949752808 4
60400 6.629106113910675 4
60500 7.014660799503327 4
60600 6.881980149745941 5
60700 6.935091042518616 2
60800 7.0

88500 6.590228917598725 5
88600 7.121970705986023 1
88700 6.846021838188172 7
88800 6.7050481295585636 6
88900 6.203523328304291 8
89000 7.20861450433731 4
89100 6.870063352584839 4
89200 6.278512818813324 3
89300 7.018622393608093 4
89400 7.1222656726837155 3
89500 6.814326245784759 3
89600 7.0229185628890995 2
89700 6.755971875190735 4
89800 6.635331857204437 3
89900 6.513960747718811 7
90000 6.807104597091675 4
90100 6.471226599216461 4
90200 6.878586418628693 4
90300 6.570857729911804 5
90400 7.11155953168869 3
90500 6.980793023109436 2
90600 6.6574823713302616 2
90700 6.488963422775268 1
90800 7.100018987655639 4
90900 6.866182749271393 2
91000 6.781264803409576 5
91100 6.683271543979645 5
91200 6.519714198112488 4
91300 6.898021416664124 2
91400 6.470247621536255 3
91500 6.553708388805389 4
91600 6.580149872303009 1
91700 6.797233965396881 2
91800 6.8280401611328125 4
91900 6.826344830989838 5
92000 7.0061075663566585 1
92100 6.579518795013428 4
92200 7.10079835653305 5
92300 6.7

119200 6.760909316539764 2
119300 6.54102377653122 3
119400 7.060202622413636 1
119500 7.069377164840699 0
119600 6.348305277824402 5
119700 7.049747376441956 4
119800 7.221133873462677 3
119900 6.558932130336761 3
120000 7.107211935520172 1
120100 6.302219233512878 9
120200 6.738128716945648 7
120300 6.748521928787231 6
120400 6.494344704151153 1
120500 6.817625458240509 4
120600 6.687827863693237 5
120700 6.811140348911286 3
120800 7.159312572479248 6
120900 7.049356067180634 5
121000 6.7738955950737 3
121100 7.072186107635498 4
121200 6.916715664863586 3
121300 6.66447012424469 1
121400 6.707567937374115 3
121500 6.898292021751404 3
121600 6.9330134773254395 0
121700 6.444379162788391 1
121800 6.485517373085022 7
121900 7.00812273979187 7
122000 6.853996605873108 3
122100 6.554837625026703 3
122200 6.940169694423676 3
122300 6.458538126945496 6
122400 6.198776805400849 7
122500 6.763362421989441 4
122600 6.617280278205872 2
122700 6.744263663291931 4
122800 6.918353638648987 4
12290

149600 6.819679534435272 2
149700 7.231401131153107 4
149800 6.521005806922912 7
149900 6.716386880874634 5
150000 6.575366034507751 8
150100 6.467121374607086 4
150200 7.00654274225235 1
150300 6.826994824409485 5
150400 6.3257135796546935 6
150500 6.1464368891716 4
150600 6.928511836528778 5
150700 6.849932832717895 6
150800 6.553460865020752 6
150900 6.802862505912781 2
151000 6.96381938457489 4
151100 6.5680940103530885 7
151200 6.555665943622589 3
151300 6.952633690834046 0
151400 6.873202896118164 5
151500 6.8089936828613284 4
151600 6.791659226417542 2
151700 7.084989151954651 4
151800 6.958052628040313 3
151900 6.60045895576477 2
152000 7.046672308444977 3
152100 6.737929673194885 2
152200 6.882343056201935 2
152300 6.93231128692627 4
152400 6.697493770122528 4
152500 7.050726997852325 3
152600 6.482473487854004 9
152700 6.859070422649384 5
152800 6.614722707271576 2
152900 6.921848866939545 4
153000 7.195308017730713 4
153100 6.857428565025329 1
153200 6.6375672817230225 2
153

180000 7.013855576515198 5
180100 7.030774738788605 2
180200 6.812599620819092 2
180300 7.1401810908317564 6
180400 7.261033382415771 4
180500 6.671103556156158 5
180600 6.666586210727692 4
180700 6.8470693731307986 7
180800 6.930000684261322 1
180900 6.408447251319886 4
181000 6.333750278949737 2
181100 6.462708115577698 2
181200 6.786386826038361 3
181300 6.582305324077606 2
181400 6.7500363230705265 5
181500 6.832119414806366 5
181600 7.088359990119934 7
181700 6.534718654155731 5
181800 6.643716735839844 6
181900 6.95783964395523 1
182000 6.696452059745789 5
182100 7.076474997997284 4
182200 6.950661416053772 1
182300 7.159578506946564 0
182400 6.844729759693146 2
182500 7.2456727528572085 2
182600 7.1782873916625975 2
182700 6.6841849207878115 6
182800 6.608316805362701 6
182900 6.696042764186859 5
183000 6.164285533428192 8
183100 6.689278814792633 2
183200 6.622063324451447 5
183300 7.217384901046753 1
183400 6.6608408617973325 3
183500 7.099773836135864 4
183600 6.8878920507431

210400 6.66224915266037 2
210500 6.78635947227478 2
210600 6.711798782348633 4
210700 6.815303785800934 4
210800 6.927816028594971 1
210900 6.585294151306153 2
211000 6.902429227828979 3
211100 6.686187992095947 8
211200 6.369476730823517 8
211300 6.38289690732956 2
211400 6.8263252210617065 5
211500 6.27353452205658 4
211600 7.061796877384186 3
211700 7.100047430992126 3
211800 6.663939003944397 4
211900 7.067718110084534 2
212000 7.092592346668243 3
212100 6.346686692237854 3
212200 6.318030381202698 7
212300 6.710843489170075 5
212400 6.384895095825195 4
212500 6.793319804668426 3
212600 6.636327707767487 2
212700 6.574633891582489 4
212800 7.026480135917663 2
212900 6.329782958030701 4
213000 6.396968522071838 4
213100 6.749592649936676 4
213200 6.528168544769287 4
213300 6.674688997268677 1
213400 6.61786782503128 8
213500 6.501457595825196 2
213600 6.788432374000549 4
213700 6.431233637332916 3
213800 6.809881753921509 4
213900 6.909410374164581 6
214000 6.035952510833741 5
21410

240800 6.855206196308136 7
240900 6.732122001647949 6
241000 6.91329256772995 1
241100 6.7436951899528506 4
241200 7.1107927465438845 4
241300 6.687762672901154 2
241400 6.392432708740234 7
241500 6.99603374004364 8
241600 6.307219414710999 4
241700 6.719997401237488 5
241800 6.777257852554321 4
241900 6.270802500247956 8
242000 6.757286648750306 2
242100 6.992397301197052 2
242200 6.42862765789032 5
242300 6.430842258930206 5
242400 6.602453601360321 8
242500 6.879839525222779 3
242600 6.48357590675354 4
242700 6.8547330093383785 4
242800 6.8053667569160465 1
242900 6.927319989204407 3
243000 6.769316284656525 5
243100 6.681656222343445 3
243200 7.23164835691452 0
243300 6.483326108455658 2
243400 6.207678062915802 5
243500 6.675399742126465 7
243600 6.850461790561676 3
243700 6.785667147636413 5
243800 6.510945415496826 3
243900 6.722158527374267 3
244000 6.81515450000763 6
244100 6.842250428199768 1
244200 6.5968734383583065 3
244300 6.7337169218063355 1
244400 7.0791642022132875 2


In [159]:
class ValDataset(torch.utils.data.Dataset):
    def __init__(self, val_question, val_answer, embedding):
        self.input = []
        self.target = [] # missing ingredient
        
        for i in range(len(val_question)):
            nodes = [int(node) for node in val_question[i]]
            examples = []
            for node in nodes:
                examples.append(embedding[node])
            self.input.append(np.stack(examples))
            self.target.append(int(val_answer[i][0]))

    def __len__(self):
        return len(self.input)

    def __getitem__(self, idx):
        inp = self.input[idx]
        target = self.target[idx]
        return inp, target

In [160]:
val_loader = torch.utils.data.DataLoader(
    ValDataset(val_cpt_q, val_cpt_a, svd32), 
    batch_size=batch_size, 
    shuffle=False
)

In [162]:
acc = 0
with torch.no_grad():
    for i, (input_data, label) in enumerate(val_loader):
        hidden = torch.zeros(1, batch_size, n_hidden, requires_grad=True)
        output = model(hidden, input_data)
        pred = torch.argmax(output, dim=1)
        if pred.item() == label.item():
            acc += 1
print(acc / len(val_loader))

0.03835372069317024
