In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset
import pandas as pd
from numpy import argmax
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
import json
import time
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import csv

In [2]:
np.random.seed(12345)

# run on the GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class DataReader:
    NEGATIVE_TABLE_SIZE = 1e8

    def __init__(self, inputFileName, min_count):

        self.negatives = []
        self.discards = []
        self.negpos = 0

        self.word2id = dict()
        self.id2word = dict()
        self.sentences_count = 0
        self.token_count = 0
        self.word_frequency = dict()

        self.inputFileName = inputFileName
        self.read_words(min_count)
        self.initTableNegatives()
        self.initTableDiscards()
        
    def plot_frequency(self):
        pass
      
    def read_words(self, min_count):
        word_frequency = dict()
        for line in open(self.inputFileName, encoding="utf8"):
            line = line.split()
            if len(line) > 1:
                self.sentences_count += 1
                for word in line:
                    if len(word) > 0:
                        self.token_count += 1
                        word_frequency[word] = word_frequency.get(word, 0) + 1

                        if self.token_count % 1000000 == 0:
                            print("Read " + str(int(self.token_count / 1000000)) + "M words.")
        # show each word's frequency before the discard action
#         plot_frequency()
        wid = 0
        print()
        # w represents the word; c is the frequency of the word
        for w, c in word_frequency.items():
            if c < min_count:
                continue
            # if the counts of one word is less than min_count, then don't put this word in the vocabulary
            self.word2id[w] = wid
            self.id2word[wid] = w
            self.word_frequency[wid] = c
            wid += 1
        print("Total embeddings: " + str(len(self.word2id)))

    def initTableDiscards(self):
        t = 0.00001
        f = np.array(list(self.word_frequency.values())) / self.token_count
        # every ingredient's Probability to be discarded
        self.discards = np.sqrt(t / f) + (t / f)

    def initTableNegatives(self):
        pow_frequency = np.array(list(self.word_frequency.values())) ** 0.5
        words_pow = sum(pow_frequency)
        ratio = pow_frequency / words_pow
        count = np.round(ratio * DataReader.NEGATIVE_TABLE_SIZE)
        for wid, c in enumerate(count):
            self.negatives += [wid] * int(c)
        self.negatives = np.array(self.negatives)
        np.random.shuffle(self.negatives)

    def getNegatives(self, target, size):  # TODO check equality with target
        response = self.negatives[self.negpos:self.negpos + size]
        self.negpos = (self.negpos + size) % len(self.negatives)
        if len(response) != size:
            return np.concatenate((response, self.negatives[0:self.negpos]))
        return response

In [3]:
# -----------------------------------------------------------------------------------------------------------------

class Word2vecDataset(Dataset):
  # data is the object of class DataReader
    def __init__(self, data, window_size):
        self.data = data
        self.window_size = window_size
        self.input_file = open(data.inputFileName, encoding="utf8")

    def __len__(self):
        return self.data.sentences_count

    def __getitem__(self, idx):
        while True:
            line = self.input_file.readline()
            if not line:
                self.input_file.seek(0, 0)
                line = self.input_file.readline()

            if len(line) > 1:
                words = line.split()

                if len(words) > 1:
                    word_ids = [self.data.word2id[w] for w in words if
                                # according to the discard probabilty to decide keep this word or not so called: subsampling
                                w in self.data.word2id and np.random.rand() < self.data.discards[self.data.word2id[w]]]

                    boundary = np.random.randint(1, self.window_size)
                    # negative sampling
                    return [(u, v, self.data.getNegatives(v, 5)) for i, u in enumerate(word_ids) for j, v in
                            enumerate(word_ids[max(i - boundary, 0):i + boundary]) if u != v]

    @staticmethod
    def collate(batches):
        # u - center word
        all_u = [u for batch in batches for u, _, _ in batch if len(batch) > 0]
        # v - neighbor words
        all_v = [v for batch in batches for _, v, _ in batch if len(batch) > 0]
        all_neg_v = [neg_v for batch in batches for _, _, neg_v in batch if len(batch) > 0]

        return torch.LongTensor(all_u), torch.LongTensor(all_v), torch.LongTensor(all_neg_v)

In [4]:

"""
    u_embedding: Embedding for center word.
    v_embedding: Embedding for neighbor words.
"""


class SkipGramModel(nn.Module):

    def __init__(self, emb_size, emb_dimension):
        super(SkipGramModel, self).__init__()
        self.emb_size = emb_size
        self.emb_dimension = emb_dimension
        self.u_embeddings = nn.Embedding(emb_size, emb_dimension, sparse=True)
        self.v_embeddings = nn.Embedding(emb_size, emb_dimension, sparse=True)

        initrange = 1.0 / self.emb_dimension
        init.uniform_(self.u_embeddings.weight.data, -initrange, initrange)
        init.constant_(self.v_embeddings.weight.data, 0)

    def forward(self, pos_u, pos_v, neg_v):
        emb_u = self.u_embeddings(pos_u)
        emb_v = self.v_embeddings(pos_v)
        emb_neg_v = self.v_embeddings(neg_v)

        score = torch.sum(torch.mul(emb_u, emb_v), dim=1)
        score = torch.clamp(score, max=10, min=-10)
        score = -F.logsigmoid(score)

        neg_score = torch.bmm(emb_neg_v, emb_u.unsqueeze(2)).squeeze()
        neg_score = torch.clamp(neg_score, max=10, min=-10)
        neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1)

        return torch.mean(score + neg_score)

    def save_embedding(self, id2word, file_name):
        embedding = self.u_embeddings.weight.cpu().data.numpy()
        with open(file_name, 'w') as f:
            csv_writer = csv.writer(f)
            csv_writer.writerow(['Ingredient', 'Vector'])
            for wid, w in id2word.items():
                vector = str(list(embedding[wid]))
                csv_writer.writerow([w,vector])
        return embedding, id2word

In [5]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# from data_reader import DataReader, Word2vecDataset
# from model import SkipGramModel


class Word2VecTrainer:
    def __init__(self, input_file, output_file, emb_dimension=100, batch_size=32, window_size=5, iterations=10,
                 initial_lr=0.001, min_count=5):

        self.data = DataReader(input_file, min_count)
        # dataset is the object of class Word2vecDataset
        dataset = Word2vecDataset(self.data, window_size)
        self.dataloader = DataLoader(dataset, batch_size=batch_size,
                                     shuffle=False, num_workers=0, collate_fn=dataset.collate)

        self.output_file_name = output_file
        self.emb_size = len(self.data.word2id)
        self.emb_dimension = emb_dimension
        self.batch_size = batch_size
        self.iterations = iterations
        self.initial_lr = initial_lr
        # put model on the GPU
        self.skip_gram_model = SkipGramModel(self.emb_size, self.emb_dimension).to(device)


    def train(self):

        for iteration in range(self.iterations):

            print("\n\n\nIteration: " + str(iteration + 1))
            optimizer = optim.SparseAdam(self.skip_gram_model.parameters(), lr=self.initial_lr)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(self.dataloader))

            running_loss = 0.0
            for i, sample_batched in enumerate(tqdm(self.dataloader)):

                if len(sample_batched[0]) > 1:
                    # put training data on the GPU
                    pos_u = sample_batched[0].to(device)
                    pos_v = sample_batched[1].to(device)
                    neg_v = sample_batched[2].to(device)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    loss = self.skip_gram_model.forward(pos_u, pos_v, neg_v)
                    loss.backward()
                    running_loss = running_loss * 0.9 + loss.item() * 0.1
                    if i > 0 and i % 500 == 0:
                        print(" Loss: " + str(running_loss))
        self.skip_gram_model.save_embedding(self.data.id2word, self.output_file_name)
        return self.skip_gram_model 

In [6]:
inputFileName = 'ingredients_1127.txt'
w2v = Word2VecTrainer(input_file=inputFileName, output_file="food2vec_1127.csv")
skip_gram_model = w2v.train()

Read 1M words.

Total embeddings: 7032


  0%|          | 0/5571 [00:00<?, ?it/s]




Iteration: 1


  9%|▉         | 517/5571 [00:15<00:55, 91.42it/s]

 Loss: 4.139847924957127


 18%|█▊        | 1015/5571 [00:20<00:52, 86.60it/s]

 Loss: 3.8480568948127964


 27%|██▋       | 1508/5571 [00:35<01:57, 34.61it/s]

 Loss: 3.5640727781362838


 36%|███▌      | 2012/5571 [00:45<00:50, 70.13it/s]

 Loss: 3.3168417040717935


 45%|████▌     | 2509/5571 [00:57<01:52, 27.15it/s] 

 Loss: 3.195380287736837


 54%|█████▍    | 3018/5571 [01:11<00:29, 86.91it/s]

 Loss: 3.1357663411729346


 63%|██████▎   | 3505/5571 [01:20<01:13, 28.10it/s]

 Loss: 3.047042379150491


 72%|███████▏  | 4018/5571 [01:36<00:29, 52.04it/s]

 Loss: 3.040834763262996


 81%|████████  | 4519/5571 [01:41<00:08, 118.52it/s]

 Loss: 3.0060440985173815


 90%|████████▉ | 5005/5571 [01:49<00:12, 46.02it/s] 

 Loss: 3.0032440430832974


 99%|█████████▉| 5529/5571 [02:00<00:00, 72.16it/s]

 Loss: 3.0040370480446974


100%|██████████| 5571/5571 [02:01<00:00, 46.01it/s]
  0%|          | 13/5571 [00:00<00:45, 122.33it/s]




Iteration: 2


  9%|▉         | 514/5571 [00:04<00:39, 128.84it/s]

 Loss: 2.9515817018562442


 18%|█▊        | 1008/5571 [00:12<01:38, 46.38it/s]

 Loss: 2.8823479633322995


 27%|██▋       | 1523/5571 [00:23<01:01, 65.68it/s]

 Loss: 2.8683153330112607


 36%|███▌      | 2015/5571 [00:27<00:29, 121.17it/s]

 Loss: 2.8328218875359106


 45%|████▌     | 2507/5571 [00:36<01:17, 39.35it/s] 

 Loss: 2.831770650405441


 54%|█████▍    | 3011/5571 [00:47<00:43, 59.26it/s]

 Loss: 2.802843534853335


 63%|██████▎   | 3514/5571 [00:53<00:16, 123.89it/s]

 Loss: 2.7993617084499185


 72%|███████▏  | 4008/5571 [01:00<00:32, 47.50it/s] 

 Loss: 2.816676378554389


 81%|████████  | 4505/5571 [01:11<00:25, 41.56it/s]

 Loss: 2.8028577475665606


 90%|█████████ | 5015/5571 [01:18<00:04, 118.00it/s]

 Loss: 2.8091602478751954


 99%|█████████▉| 5503/5571 [01:24<00:01, 34.61it/s] 

 Loss: 2.786024008427933


100%|██████████| 5571/5571 [01:26<00:00, 64.44it/s]
  0%|          | 4/5571 [00:00<02:35, 35.86it/s]




Iteration: 3


  9%|▉         | 504/5571 [00:12<02:11, 38.45it/s]

 Loss: 2.7440209542194784


 18%|█▊        | 1019/5571 [00:17<00:39, 115.09it/s]

 Loss: 2.7602062351555987


 27%|██▋       | 1504/5571 [00:25<01:40, 40.66it/s] 

 Loss: 2.7424957473910014


 36%|███▌      | 2005/5571 [00:38<01:36, 37.04it/s]

 Loss: 2.7243280742475546


 45%|████▌     | 2513/5571 [00:42<00:27, 110.07it/s]

 Loss: 2.7254826156105967


 54%|█████▍    | 3008/5571 [00:51<01:03, 40.41it/s] 

 Loss: 2.724473701948927


 63%|██████▎   | 3505/5571 [01:03<00:51, 40.33it/s]

 Loss: 2.7058312855285775


 72%|███████▏  | 4018/5571 [01:08<00:13, 117.83it/s]

 Loss: 2.729286872126195


 81%|████████  | 4504/5571 [01:16<00:29, 36.70it/s] 

 Loss: 2.7237828127459354


 90%|████████▉ | 5011/5571 [01:29<00:12, 45.02it/s]

 Loss: 2.7241899773201306


 99%|█████████▉| 5519/5571 [01:34<00:00, 121.92it/s]

 Loss: 2.7038060441983434


100%|██████████| 5571/5571 [01:34<00:00, 58.86it/s] 
  0%|          | 11/5571 [00:00<00:52, 105.68it/s]




Iteration: 4


  9%|▉         | 512/5571 [00:09<01:41, 49.90it/s] 

 Loss: 2.660800967725313


 18%|█▊        | 1020/5571 [00:21<01:19, 57.08it/s]

 Loss: 2.682956657695339


 27%|██▋       | 1519/5571 [00:26<00:34, 117.90it/s]

 Loss: 2.645354539321998


 36%|███▌      | 2009/5571 [00:32<00:53, 67.04it/s] 

 Loss: 2.6249812877485477


 45%|████▍     | 2505/5571 [00:41<00:45, 67.82it/s]

 Loss: 2.628112560406823


 54%|█████▍    | 3020/5571 [00:49<00:26, 95.31it/s]

 Loss: 2.6034430725239424


 63%|██████▎   | 3525/5571 [00:53<00:15, 128.05it/s]

 Loss: 2.632271560199589


 72%|███████▏  | 4013/5571 [01:01<00:22, 70.52it/s] 

 Loss: 2.642648429027285


 81%|████████  | 4511/5571 [01:08<00:12, 82.58it/s]

 Loss: 2.615639231657524


 90%|█████████ | 5019/5571 [01:16<00:06, 85.80it/s]

 Loss: 2.637912364663753


 99%|█████████▉| 5519/5571 [01:20<00:00, 119.89it/s]

 Loss: 2.618997578734389


100%|██████████| 5571/5571 [01:20<00:00, 68.90it/s] 
  0%|          | 13/5571 [00:00<00:46, 120.63it/s]




Iteration: 5


  9%|▉         | 513/5571 [00:04<00:53, 95.02it/s] 

 Loss: 2.5198663354441195


 18%|█▊        | 1015/5571 [00:10<00:57, 79.69it/s]

 Loss: 2.567281840551757


 27%|██▋       | 1519/5571 [00:16<00:44, 90.49it/s]

 Loss: 2.563981840078459


 36%|███▋      | 2020/5571 [00:23<00:39, 89.65it/s]

 Loss: 2.521522516654612


 45%|████▌     | 2516/5571 [00:27<00:29, 104.68it/s]

 Loss: 2.5606106469665546


 54%|█████▍    | 3010/5571 [00:32<00:30, 82.63it/s] 

 Loss: 2.495678138339537


 63%|██████▎   | 3514/5571 [00:40<00:26, 76.80it/s]

 Loss: 2.5223152710200445


 72%|███████▏  | 4010/5571 [00:46<00:18, 84.03it/s]

 Loss: 2.547286349710944


 81%|████████  | 4514/5571 [00:53<00:08, 126.26it/s]

 Loss: 2.539727472460436


 90%|█████████ | 5016/5571 [00:56<00:04, 130.23it/s]

 Loss: 2.5194052024760496


 99%|█████████▉| 5518/5571 [01:03<00:00, 84.35it/s] 

 Loss: 2.501226310250492


100%|██████████| 5571/5571 [01:04<00:00, 86.97it/s]
  0%|          | 9/5571 [00:00<01:15, 73.26it/s]




Iteration: 6


  9%|▉         | 511/5571 [00:06<00:57, 87.30it/s]

 Loss: 2.4336504810745634


 18%|█▊        | 1010/5571 [00:13<00:55, 82.45it/s]

 Loss: 2.4758584205650296


 27%|██▋       | 1523/5571 [00:17<00:30, 132.64it/s]

 Loss: 2.453572214100809


 36%|███▌      | 2010/5571 [00:22<00:43, 81.80it/s] 

 Loss: 2.4903822856863513


 45%|████▌     | 2515/5571 [00:29<00:37, 81.66it/s]

 Loss: 2.444287697598808


 54%|█████▍    | 3011/5571 [00:36<00:33, 76.48it/s]

 Loss: 2.4701494621995166


 63%|██████▎   | 3516/5571 [00:42<00:17, 114.76it/s]

 Loss: 2.4196488621769916


 72%|███████▏  | 4013/5571 [00:46<00:12, 120.30it/s]

 Loss: 2.4536837934975786


 81%|████████  | 4519/5571 [00:52<00:11, 88.07it/s] 

 Loss: 2.4624668725407344


 90%|█████████ | 5015/5571 [00:58<00:06, 86.83it/s]

 Loss: 2.492080903367785


 99%|█████████▉| 5513/5571 [01:04<00:00, 88.86it/s]

 Loss: 2.4619928794033465


100%|██████████| 5571/5571 [01:04<00:00, 85.83it/s]
  0%|          | 7/5571 [00:00<01:24, 65.66it/s]




Iteration: 7


  9%|▉         | 525/5571 [00:06<00:50, 100.02it/s]

 Loss: 2.3699157284982606


 18%|█▊        | 1016/5571 [00:10<00:37, 120.80it/s]

 Loss: 2.373519565195243


 27%|██▋       | 1516/5571 [00:15<00:46, 87.17it/s] 

 Loss: 2.3911318259781504


 36%|███▌      | 2016/5571 [00:21<00:40, 88.86it/s]

 Loss: 2.443549282024269


 45%|████▌     | 2514/5571 [00:27<00:39, 77.11it/s]

 Loss: 2.3928473837649067


 54%|█████▍    | 3011/5571 [00:32<00:27, 91.55it/s]

 Loss: 2.447947100794162


 63%|██████▎   | 3528/5571 [00:37<00:15, 134.36it/s]

 Loss: 2.4137069847980155


 72%|███████▏  | 4011/5571 [00:42<00:24, 62.65it/s] 

 Loss: 2.441586039865295


 81%|████████  | 4508/5571 [00:48<00:14, 73.28it/s]

 Loss: 2.4297900119639406


 90%|████████▉ | 5010/5571 [00:55<00:08, 64.25it/s]

 Loss: 2.4712136610529956


 99%|█████████▉| 5509/5571 [01:02<00:00, 63.91it/s]

 Loss: 2.3885407249875574


100%|██████████| 5571/5571 [01:03<00:00, 87.83it/s]
  0%|          | 13/5571 [00:00<00:44, 125.99it/s]




Iteration: 8


  9%|▉         | 515/5571 [00:03<00:37, 134.46it/s]

 Loss: 2.315892755827841


 18%|█▊        | 1011/5571 [00:10<00:54, 83.88it/s]

 Loss: 2.365391099433201


 27%|██▋       | 1508/5571 [00:17<01:11, 56.46it/s]

 Loss: 2.36299237233387


 36%|███▌      | 2014/5571 [00:24<00:43, 81.24it/s]

 Loss: 2.327916575341117


 45%|████▌     | 2515/5571 [00:29<00:23, 128.61it/s]

 Loss: 2.3604177541538762


 54%|█████▍    | 3016/5571 [00:33<00:18, 134.98it/s]

 Loss: 2.374227170410341


 63%|██████▎   | 3514/5571 [00:39<00:27, 75.76it/s] 

 Loss: 2.3547904911607325


 72%|███████▏  | 4009/5571 [00:46<00:19, 78.98it/s]

 Loss: 2.4043821913318544


 81%|████████  | 4517/5571 [00:52<00:12, 83.35it/s]

 Loss: 2.422301238332371


 90%|█████████ | 5026/5571 [00:58<00:04, 121.60it/s]

 Loss: 2.432962692226536


 99%|█████████▉| 5527/5571 [01:02<00:00, 131.76it/s]

 Loss: 2.3622515489962774


100%|██████████| 5571/5571 [01:02<00:00, 89.08it/s] 
  0%|          | 10/5571 [00:00<00:59, 94.13it/s]




Iteration: 9


  9%|▉         | 498/5571 [00:06<01:13, 69.02it/s]

 Loss: 2.270567944366378


 18%|█▊        | 1007/5571 [00:14<00:57, 79.23it/s]

 Loss: 2.2675454937906045


 27%|██▋       | 1510/5571 [00:20<00:50, 79.82it/s]

 Loss: 2.2798013374818225


 36%|███▋      | 2025/5571 [00:25<00:26, 133.57it/s]

 Loss: 2.2643594725738168


 45%|████▌     | 2514/5571 [00:29<00:40, 76.04it/s] 

 Loss: 2.326492753648909


 54%|█████▍    | 3012/5571 [00:36<00:44, 58.07it/s]

 Loss: 2.332367007717966


 63%|██████▎   | 3513/5571 [00:43<00:26, 77.07it/s]

 Loss: 2.2907259139062974


 72%|███████▏  | 4020/5571 [00:50<00:11, 129.85it/s]

 Loss: 2.4273314651250413


 81%|████████  | 4516/5571 [00:53<00:07, 132.06it/s]

 Loss: 2.4151809040188397


 90%|████████▉ | 5009/5571 [00:59<00:08, 68.36it/s] 

 Loss: 2.3476744649200123


 99%|█████████▉| 5511/5571 [01:05<00:00, 63.82it/s]

 Loss: 2.312617082804779


100%|██████████| 5571/5571 [01:06<00:00, 83.30it/s]
  0%|          | 8/5571 [00:00<01:13, 76.04it/s]




Iteration: 10


  9%|▉         | 515/5571 [00:07<01:08, 73.78it/s]

 Loss: 2.2662763176138574


 18%|█▊        | 1020/5571 [00:12<00:34, 132.88it/s]

 Loss: 2.299467134512746


 27%|██▋       | 1512/5571 [00:16<00:52, 77.62it/s] 

 Loss: 2.257600671722373


 36%|███▌      | 2013/5571 [00:22<00:48, 73.91it/s]

 Loss: 2.2833146693749615


 45%|████▌     | 2514/5571 [00:28<00:39, 76.84it/s]

 Loss: 2.312456284216155


 54%|█████▍    | 3009/5571 [00:35<00:32, 78.62it/s]

 Loss: 2.3162044422864416


 63%|██████▎   | 3521/5571 [00:40<00:17, 116.63it/s]

 Loss: 2.2588533490038785


 72%|███████▏  | 4012/5571 [00:45<00:19, 81.86it/s] 

 Loss: 2.3609788712199813


 81%|████████  | 4513/5571 [00:51<00:13, 78.63it/s]

 Loss: 2.389014120332841


 90%|████████▉ | 5008/5571 [00:57<00:08, 67.74it/s]

 Loss: 2.367199815999722


 99%|█████████▉| 5508/5571 [01:05<00:00, 79.97it/s]

 Loss: 2.3233798079621124


100%|██████████| 5571/5571 [01:06<00:00, 84.34it/s]


In [7]:
skip_gram_model

SkipGramModel(
  (u_embeddings): Embedding(7032, 100, sparse=True)
  (v_embeddings): Embedding(7032, 100, sparse=True)
)