In [19]:
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from sklearn import preprocessing

from cuisine.cookbook import get_ingredient_list, get_cookbook_train, get_cookbook_valid_question, get_cookbook_valid_answer
from cuisine.embedding import import_embedding, create_random_embedding, create_one_hot_embedding
from cuisine.chef import make_recipe_embedding_data
from cuisine.utils import move_to_top_directory

In [20]:
move_to_top_directory()
%pwd

'/home/felix/cuisine'

In [21]:
cookbook_train = get_cookbook_train()
cookbook_valid_question = get_cookbook_valid_question()
cookbook_valid_answer = get_cookbook_valid_answer()
ingredient_list = get_ingredient_list()

cookbook_train[15]

{'recipe_id': 15,
 'ingredients': [2866, 4243, 4362, 5377, 5408, 6187, 6352, 6568],
 'kitchen_name': 'italian',
 'kitchen_id': 5}

In [22]:
embedding_random_walk_2 = import_embedding("Embp1q2", ingredient_list)
embedding_random_walk_10 = import_embedding("Embp1q10", ingredient_list)
embedding_uniform_rand = create_random_embedding(ingredient_list, 128)
embedding_svd_32 = import_embedding("SVD32", ingredient_list)
embedding_svd_64 = import_embedding("SVD64", ingredient_list)
embedding_svd_128 = import_embedding("SVD128", ingredient_list)
embedding_one_hot = create_one_hot_embedding(ingredient_list)

embedding_svd_32[2813]

array([34.63750259, 20.35864365, -1.20763975, -0.97736046,  0.23479278,
        5.40115298,  8.51684704, -0.53078058, -0.53245975, -0.8002612 ,
       -3.08002422,  0.61509874,  1.03727886, -1.68163813,  4.22448238,
        4.41397363, -2.43951389,  1.38136843, -0.5487753 , -0.09986042,
       -1.99060814, -0.10028059,  0.17231477, -0.5892525 , -0.08022406,
       -2.57672924,  4.58388978, -2.31123317, -0.80378753,  1.565003  ,
        0.76383124, -3.8348879 ])

In [23]:
embedding = embedding_one_hot

X_train, y_train = make_recipe_embedding_data(cookbook_train, embedding, avg=False)
X_valid, _       = make_recipe_embedding_data(cookbook_valid_question, embedding, avg=False)
_      , y_valid = make_recipe_embedding_data(cookbook_valid_answer, embedding, avg=False)



In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# scaler = preprocessing.StandardScaler().fit(X_train)
# X_train = scaler.transform(X_train)
# X_valid = scaler.transform(X_valid)

X_train = torch.from_numpy(X_train).float().to(device)
y_train = torch.from_numpy(y_train).long().to(device)
X_valid = torch.from_numpy(X_valid).float().to(device)
y_valid = torch.from_numpy(y_valid).long().to(device)

In [25]:
class LogisticNN(torch.nn.Module):

    def __init__(self, embedding_dim):
        super().__init__()

        self.linear1 = torch.nn.Linear(embedding_dim, 64)
        self.linear2 = torch.nn.Linear(64, 64)

    def forward(self, x):

        # x = self.linear1(x)

        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)

        return F.softmax(x, dim=1)

In [26]:
writer = SummaryWriter()

embedding_dim = X_train.shape[1]
model = LogisticNN(embedding_dim).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in range(200):
    out = model(X_train)
    loss = F.cross_entropy(out, y_train)
    
    writer.add_scalar("Loss/train", loss, epoch)

    acc_train = (model(X_train).argmax(dim=1) == y_train).sum() / len(y_train)
    acc_valid = (model(X_valid).argmax(dim=1) == y_valid).sum() / len(y_valid)
    writer.add_scalar("Acc_train", acc_train, epoch)
    writer.add_scalar("Acc_valid", acc_valid, epoch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

writer.flush()

In [27]:
model.eval()
correct_train = (model(X_train).argmax(dim=1) == y_train).sum()
correct_valid = (model(X_valid).argmax(dim=1) == y_valid).sum()
acc_train = int(correct_train) / len(y_train)
acc_valid = int(correct_valid) / len(y_valid)
print(f'Train Accuracy: {acc_train:.2%}: {correct_train} out of {len(y_train)}')
print(f'Valid Accuracy: {acc_valid:.2%}: {correct_valid} out of {len(y_valid)}')

Train Accuracy: 77.45%: 18236 out of 23547
Valid Accuracy: 66.03%: 5182 out of 7848
