In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn

In [2]:
class WordDataset(torch.utils.data.Dataset):
    """Map-style dataset of (center word, positive context, negative samples) triples.

    The three label containers are stored unchanged and indexed in parallel;
    the dataset length is taken from ``input_labels`` alone.
    """

    def __init__(self, input_labels, pos_labels, neg_labels):
        super().__init__()
        # Keep these attribute names stable: if word_dataset.pkl holds a pickled
        # instance (presumably), unpickling restores attributes by name.
        self.input_labels, self.pos_labels, self.neg_labels = (
            input_labels,
            pos_labels,
            neg_labels,
        )

    def __len__(self):
        return len(self.input_labels)

    def __getitem__(self, index):
        columns = (self.input_labels, self.pos_labels, self.neg_labels)
        return tuple(column[index] for column in columns)

In [3]:
import pickle
# Load the prepared training triples. `with` closes the file even if unpickling fails.
# NOTE(review): pickle.load can execute arbitrary code — only load files you produced.
# Unpickling presumably needs the WordDataset class defined in the cell above.
with open('word_dataset.pkl', 'rb') as f:
    word_dataset = pickle.load(f)

In [4]:
# Load the word-frequency table. `with` closes the file even if unpickling fails.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('word_freq.pkl', 'rb') as f:
    word_freq = pickle.load(f)

In [5]:
# Load the word -> index vocabulary mapping; its length sizes the embedding tables.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('word2idx.pkl', 'rb') as f:
    word2idx = pickle.load(f)

In [6]:
class Word2Vec(nn.Module):
    """Skip-gram word2vec model trained with negative sampling.

    Holds two embedding tables of identical shape: ``in_embed`` for center
    (input) words and ``out_embed`` for context / negative-sample words.
    ``forward`` returns the per-example negative-sampling loss.
    """

    def __init__(self, num_embeddings=None, embedding_dim=256):
        """
        Args:
            num_embeddings: vocabulary size. ``None`` (the default) means
                ``len(word2idx)``, looked up when the model is constructed
                rather than when the class is defined, so the class does not
                capture whatever ``word2idx`` happened to be at definition time.
            embedding_dim: length of each word vector.
        """
        super(Word2Vec, self).__init__()
        if num_embeddings is None:
            num_embeddings = len(word2idx)
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Both tables start as small uniform noise in [-1/V, 1/V].
        init_range = 1 / self.num_embeddings
        self.in_embed = nn.Embedding(num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim)
        self.in_embed.weight.data.uniform_(-init_range, init_range)
        self.out_embed = nn.Embedding(num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim)
        self.out_embed.weight.data.uniform_(-init_range, init_range)

    def forward(self, input_labels, pos_labels, neg_lables):
        """Negative-sampling loss for a batch.

        Args:
            input_labels: center-word indices, shape (batch,).
            pos_labels: context-word indices, shape (batch, n_pos).
            neg_lables: negative-sample indices, shape (batch, n_neg).
                (Misspelled name kept for keyword-argument compatibility.)

        Returns:
            Tensor of shape (batch,) holding
            ``-(sum log sigmoid(u_pos . v) + sum log sigmoid(-u_neg . v))``.
        """
        input_embedding = self.in_embed(input_labels).unsqueeze(2)  # (B, D, 1)
        pos_embedding = self.out_embed(pos_labels)                   # (B, P, D)
        neg_embedding = self.out_embed(neg_lables)                   # (B, K, D)

        # squeeze(2), not squeeze(): a bare squeeze() also removes the batch or
        # window axis when it has size 1 (e.g. a final batch of one example),
        # which makes the .sum(1) below fail or reduce the wrong axis.
        pos_dot = torch.bmm(pos_embedding, input_embedding).squeeze(2)   # (B, P)
        neg_dot = torch.bmm(-neg_embedding, input_embedding).squeeze(2)  # (B, K)

        pos_log = nn.functional.logsigmoid(pos_dot).sum(1)
        neg_log = nn.functional.logsigmoid(neg_dot).sum(1)

        return -(pos_log + neg_log)

    def input_embedding(self):
        """Return the center-word embedding table as a NumPy array of shape (V, D)."""
        return self.in_embed.weight.data.cpu().numpy()

In [7]:
BATCH_SIZE = 128
# shuffle=True: SGD expects roughly independent minibatches. If the triples in
# word_dataset were generated by sliding over the corpus (presumably), then
# consecutive examples are strongly correlated and unshuffled batches bias each step.
word_dataloader = torch.utils.data.DataLoader(word_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [8]:
# Choose the device before touching the model so this cell also runs on
# CPU-only machines; every tensor is then moved with .to(device).
USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda' if USE_CUDA else 'cpu')
SEED = 0
torch.manual_seed(SEED)  # reproducible init and DataLoader shuffling

model = Word2Vec().to(device)
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.8)
EPOCH = 100
LOG_EVERY = 50   # print a progress line every LOG_EVERY batches within an epoch
itr = 0          # running count of batches covered by progress lines
loss_list = []   # per-batch mean loss as Python floats
model.train()
for epoch in range(EPOCH):
    for i, (cen, pos, neg) in enumerate(word_dataloader):
        input_labels = cen.long().to(device)
        pos_labels = pos.long().to(device)
        neg_labels = neg.long().to(device)

        # The training step runs on every device, not only when CUDA is present.
        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()  # calls forward()
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())

        if (i + 1) % LOG_EVERY == 0:
            itr += LOG_EVERY
            print('epoch\t:{}\titr:{}\tloss:{}\n'.format(epoch, itr, loss.item()))

epoch	:0	itr:50	loss:280.02642822265625

epoch	:0	itr:100	loss:233.15731811523438

epoch	:0	itr:150	loss:186.1634521484375

epoch	:0	itr:200	loss:172.9337615966797

epoch	:0	itr:250	loss:170.69053649902344

epoch	:0	itr:300	loss:173.353759765625

epoch	:0	itr:350	loss:162.4599151611328

epoch	:0	itr:400	loss:171.92755126953125

epoch	:0	itr:450	loss:181.28237915039062

epoch	:0	itr:500	loss:175.28729248046875

epoch	:0	itr:550	loss:173.6910400390625

epoch	:0	itr:600	loss:158.22682189941406

epoch	:0	itr:650	loss:148.56533813476562

epoch	:0	itr:700	loss:166.8938446044922

epoch	:0	itr:750	loss:139.20364379882812

epoch	:0	itr:800	loss:157.62513732910156

epoch	:0	itr:850	loss:145.94558715820312

epoch	:0	itr:900	loss:138.74493408203125

epoch	:0	itr:950	loss:126.54096984863281

epoch	:0	itr:1000	loss:145.09091186523438

epoch	:0	itr:1050	loss:139.12619018554688

epoch	:0	itr:1100	loss:128.59689331054688

epoch	:0	itr:1150	loss:129.71397399902344

epoch	:0	itr:1200	loss:125.27648925781

In [None]:
# Use the center-word ("input") table as the final word vectors.
word_embedding = model.input_embedding()
# Refuse to persist a diverged model: SGD at lr=0.8 can drive weights to inf/NaN.
assert word_embedding.ndim == 2 and np.isfinite(word_embedding).all(), 'embedding contains non-finite values'

In [None]:
# Persist the embedding matrix. pickle.dump returns None, so its result is not
# bound to any name (binding it would silently set that name, e.g. idx2word, to None).
with open('word_embedding.pkl', 'wb') as f:
    pickle.dump(word_embedding, f)