In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn

In [3]:
class WordDataset(torch.utils.data.Dataset):
    """Index-aligned container of three parallel label sequences.

    Item ``i`` is the triple
    ``(input_labels[i], pos_labels[i], neg_labels[i])``.
    The dataset length is taken from ``input_labels`` only.
    """

    def __init__(self, input_labels, pos_labels, neg_labels):
        """Store the three sequences as-is (no copy, no validation)."""
        super().__init__()
        self.input_labels = input_labels
        self.pos_labels = pos_labels
        self.neg_labels = neg_labels

    def __len__(self):
        """Return the number of entries in ``input_labels``."""
        return len(self.input_labels)

    def __getitem__(self, index):
        """Return the ``index``-th entry of each sequence as a 3-tuple."""
        columns = (self.input_labels, self.pos_labels, self.neg_labels)
        return tuple(column[index] for column in columns)

In [4]:
import pickle

# Load the pickled training dataset. The context manager closes the file even
# if unpickling raises.
# NOTE(review): presumably the pickle holds a WordDataset instance, which would
# be why the class has to be defined before this cell runs — confirm.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open('word_dataset.pkl', 'rb') as f:
    word_dataset = pickle.load(f)

In [5]:
# Load the pickled ``word_freq`` object. The context manager closes the file
# even if unpickling raises.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open('word_freq.pkl', 'rb') as f:
    word_freq = pickle.load(f)

In [6]:
# Load the pickled ``word2idx`` object; its length is used as the default
# vocabulary size of the model. The context manager closes the file even if
# unpickling raises.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open('word2idx_1.pkl', 'rb') as f:
    word2idx = pickle.load(f)

In [7]:
class Word2Vec(nn.Module):
    """Two-table embedding model trained with a negative-sampling objective.

    ``in_embed`` embeds the centre (input) labels and ``out_embed`` embeds the
    positive and negative context labels. ``forward`` returns, per example,
    the negated sum of ``logsigmoid`` scores over its positive and negative
    contexts.
    """

    def __init__(self, num_embeddings=None, embedding_dim=128):
        """Create and initialise both embedding tables.

        Args:
            num_embeddings: Number of rows in each table. When ``None``
                (the default) it is resolved to ``len(word2idx)`` at call
                time, so the value always reflects the currently loaded
                vocabulary rather than whatever was loaded when the class
                was defined.
            embedding_dim: Width of each embedding vector.
        """
        super().__init__()
        if num_embeddings is None:
            num_embeddings = len(word2idx)
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        # Weights start uniformly in [-1/num_embeddings, 1/num_embeddings].
        # NOTE(review): this range shrinks with vocabulary size; with a large
        # vocabulary the initial weights are tiny — confirm this is intended.
        init_range = 1 / self.num_embeddings
        self.in_embed = nn.Embedding(num_embeddings=self.num_embeddings,
                                     embedding_dim=self.embedding_dim)
        nn.init.uniform_(self.in_embed.weight, -init_range, init_range)
        self.out_embed = nn.Embedding(num_embeddings=self.num_embeddings,
                                      embedding_dim=self.embedding_dim)
        nn.init.uniform_(self.out_embed.weight, -init_range, init_range)

    def forward(self, input_labels, pos_labels, neg_lables):
        """Compute the per-example loss.

        Args:
            input_labels: Integer tensor of shape ``(B,)`` with centre labels.
            pos_labels: Integer tensor of shape ``(B, P)`` with positive
                context labels.
            neg_lables: Integer tensor of shape ``(B, N)`` with negative
                sample labels.

        Returns:
            Tensor of shape ``(B,)`` with one loss value per example.
        """
        centre = self.in_embed(input_labels).unsqueeze(2)   # (B, D, 1)
        pos_embedding = self.out_embed(pos_labels)          # (B, P, D)
        neg_embedding = self.out_embed(neg_lables)          # (B, N, D)

        # Squeeze only the trailing singleton axis produced by bmm. A bare
        # squeeze() would also drop the batch axis when B == 1 (or the
        # context axis when P == 1 / N == 1) and make sum(1) fail.
        pos_dot = torch.bmm(pos_embedding, centre).squeeze(2)    # (B, P)
        neg_dot = torch.bmm(-neg_embedding, centre).squeeze(2)   # (B, N)

        pos_log = nn.functional.logsigmoid(pos_dot).sum(1)
        neg_log = nn.functional.logsigmoid(neg_dot).sum(1)

        return -(pos_log + neg_log)

    def input_embedding(self):
        """Return the ``in_embed`` weight matrix as a NumPy array on the CPU."""
        return self.in_embed.weight.detach().cpu().numpy()

In [8]:
# Number of examples per mini-batch.
BATCH_SIZE = 128

# Iterate over word_dataset in mini-batches; every other loader option is left
# at its default.
word_dataloader = torch.utils.data.DataLoader(
    dataset=word_dataset,
    batch_size=BATCH_SIZE,
)

In [9]:
# --- Training configuration -------------------------------------------------
USE_CUDA = torch.cuda.is_available()
# Pick the device once so the model and every batch end up in the same place;
# calling .cuda() unconditionally would crash on a CPU-only machine.
device = torch.device('cuda' if USE_CUDA else 'cpu')
EPOCH = 50
LEARNING_RATE = 0.008
LOG_EVERY = 50  # print the current loss every LOG_EVERY batches

model = Word2Vec().to(device)
optimizer = torch.optim.SGD(params=model.parameters(), lr=LEARNING_RATE)

# --- Training loop ----------------------------------------------------------
itr = 0          # global number of optimisation steps taken so far
loss_list = []   # mean batch loss recorded after every step
for epoch in range(EPOCH):
    for i, (cen, pos, neg) in enumerate(word_dataloader):
        input_labels = cen.long().to(device)
        pos_labels = pos.long().to(device)
        neg_labels = neg.long().to(device)

        # The optimisation step runs on every device. It must not be nested
        # under the CUDA check, otherwise nothing is trained on CPU and
        # `loss` is undefined when it is printed below.
        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()  # calls forward()
        loss.backward()
        optimizer.step()
        loss_list.append(loss.detach().cpu().numpy())

        # Count every step so the reported iteration stays exact even when an
        # epoch's batch count is not a multiple of LOG_EVERY.
        itr += 1
        if (i + 1) % LOG_EVERY == 0:
            print('epoch\t:{}\titr:{}\tloss:{}\n'.format(epoch, itr, loss))

epoch	:0	itr:50	loss:280.0314636230469

epoch	:0	itr:100	loss:280.0314636230469

epoch	:0	itr:150	loss:280.0314636230469

epoch	:0	itr:200	loss:280.03143310546875

epoch	:0	itr:250	loss:280.03143310546875

epoch	:0	itr:300	loss:280.0314025878906

epoch	:0	itr:350	loss:280.03125

epoch	:0	itr:400	loss:280.0309143066406

epoch	:0	itr:450	loss:280.02996826171875

epoch	:0	itr:500	loss:280.02691650390625

epoch	:0	itr:550	loss:280.0179443359375

epoch	:0	itr:600	loss:279.98162841796875

epoch	:0	itr:650	loss:279.88641357421875

epoch	:0	itr:700	loss:279.71282958984375

epoch	:0	itr:750	loss:278.8403015136719

epoch	:0	itr:800	loss:277.3780212402344

epoch	:0	itr:850	loss:273.03570556640625

epoch	:0	itr:900	loss:263.02069091796875

epoch	:0	itr:950	loss:251.3513641357422

epoch	:0	itr:1000	loss:250.59353637695312

epoch	:0	itr:1050	loss:240.92051696777344

epoch	:0	itr:1100	loss:228.3973388671875

epoch	:0	itr:1150	loss:233.49720764160156

epoch	:0	itr:1200	loss:213.37744140625

epoch	:0	i

In [10]:
# Export the model's input-side embedding table (the ``in_embed`` weights) as a
# NumPy array on the CPU.
word_embedding = model.input_embedding()

In [11]:
# Persist the embedding matrix. The context manager flushes and closes the
# file even if pickling raises. pickle.dump returns None, so its result is
# deliberately not bound to any name.
with open('word_embedding.pkl', 'wb') as f:
    pickle.dump(word_embedding, f)