In [1]:
# !/usr/bin/python
# -*- coding: UTF-8 -*-
# train

import model
import dataloader
import torch

In [2]:
# --- corpus / output ---------------------------------------------------------
name = "att_baike"  # run name: used as the result file name and written to the tuning log
file = "seg_baike"  # segmented corpus, read from source/<file>.txt

# --- vocabulary --------------------------------------------------------------
MAX_VOCAB_SIZE = "all"  # vocabulary size limit; "all" keeps every word of the corpus

# --- model / sampling --------------------------------------------------------
EMBED_SIZE = 256  # dimensionality of each word vector
WINDOW = 5  # context window
INTERVAL = 0  # forwarded to dataloader.WordDataset; exact meaning not visible here -- TODO confirm
NEGATIVE = 10  # negative samples drawn for each context word

# --- optimisation ------------------------------------------------------------
BATCH_SIZE = 150  # mini-batch size given to the DataLoader
EPOCH = 10  # number of passes of the training loop
LR = 0.001  # learning rate given to the Adam optimizer

In [3]:
# dataloader: read the segmented corpus, build the vocabulary and the batch loader
from collections import Counter
from torch.utils import data

word_dict = dataloader.WordCounter2Dict()
word_list = []  # every token of the corpus, in reading order
c = Counter()   # token -> number of occurrences in the whole corpus
with open("source/{}.txt".format(file), 'r', encoding='utf-8') as fin:
    # Iterate the file lazily; readlines() would hold the whole corpus in
    # memory a second time next to word_list.
    for line in fin:
        # Tokens are separated by single spaces. Blank lines and doubled
        # spaces yield "" entries, which are not words and must not end up
        # in the vocabulary or the training samples.
        text = [token for token in line.strip().split(" ") if token]
        word_list.extend(text)
        # Count in place. `c += Counter(text)` builds a throwaway Counter and
        # rescans all of `c` on every line, which is quadratic on large corpora.
        c.update(text)
word_dict.get_counter(c, MAX_VOCAB_SIZE)

word_dataset = dataloader.WordDataset(word_list, word_dict, WINDOW, INTERVAL, NEGATIVE)
wordloader = data.DataLoader(word_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
# From here on MAX_VOCAB_SIZE holds the actual vocabulary size reported by
# word_dict; the model construction and the tuning log read this value.
# NOTE(review): re-running this cell therefore passes an int instead of "all"
# to get_counter -- confirm that is equivalent.
MAX_VOCAB_SIZE = word_dict.getlen()

In [4]:
# model and cuda
# Train on the first GPU when CUDA is usable, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Show which of the two cases applies (True means the GPU was selected).
print(torch.cuda.is_available())

# Build the embedding model for the final vocabulary size and move it to the device.
net = model.WordVector(MAX_VOCAB_SIZE, EMBED_SIZE)
net = net.to(device)

True


In [5]:
# train
from tqdm import tqdm

# One Python float per epoch: the loss of the last mini-batch of that epoch.
list_loss = []
optimizer = torch.optim.Adam(net.parameters(), lr=LR)
for epoch in tqdm(range(EPOCH)):
    net.train()
    # Optional step decay of the learning rate (currently disabled):
    #     if epoch%10==0:
    #         optimizer.param_groups[0]['lr'] *= 0.1
    for cen_vocab_id, con_vocabs, con_vocabs_weight, neg_samples in wordloader:
        # Move the whole batch to the device the model lives on.
        cen_vocab_id = cen_vocab_id.to(device)
        con_vocabs = con_vocabs.to(device)
        con_vocabs_weight = con_vocabs_weight.to(device)
        neg_samples = neg_samples.to(device)

        optimizer.zero_grad()
        # The model's forward pass returns the training loss directly.
        loss = net(cen_vocab_id, con_vocabs, con_vocabs_weight, neg_samples)
        loss.backward()
        optimizer.step()

    # Store a plain number, not the tensor: appending `loss` itself keeps a
    # device tensor and its autograd history referenced for every epoch, and
    # later formats as a tensor repr instead of a value.
    list_loss.append(loss.item())

100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [35:47<00:00, 214.71s/it]


In [6]:
import numpy
import pandas as pd

# Pair each vocabulary word with its learned vector, one row per word.
words = word_dict.id2word()
embedding_matrix = net.getvectors().cpu().numpy()  # copied to host memory as a numpy array
vectors = list(embedding_matrix)  # one array per row, so each DataFrame cell holds a full vector
dataframe = pd.DataFrame({"words": words, "vectors": vectors})
dataframe.head()

Unnamed: 0,words,vectors
0,，,"[-0.03429354, 0.050134704, 0.024060125, 0.0936..."
1,的,"[0.05852192, 0.04801449, 0.103424594, 0.145902..."
2,。,"[0.12474231, 0.009065641, 0.007999456, 0.16179..."
3,是,"[0.066203885, 0.0727198, -0.0519521, 0.0806786..."
4,了,"[-0.15804285, 0.09725754, 0.12928239, 0.087344..."


In [7]:
import pickle

# Persist the words/vectors table as a pickle under result/<name>.
result_path = "result/{}".format(name)
with open(result_path, 'wb') as fout:
    pickle.dump(dataframe, fout)

In [8]:
# Append this run's hyper-parameters and final loss to the tuning log.
# Adjacent string literals are concatenated, so the record is still written as
# exactly two lines: the run name, then the comma-separated settings.
record = (
    "name: {name}\n"
    "MAX_VOCAB_SIZE:{a}, "
    "EMBED_SIZE:{b}, "
    "WINDOW:{w}, "
    "INTERVAL:{i}, "
    "NEGATIVE:{c}, "
    "BATCH_SIZE:{d}, "
    "EPOCH:{e}, "
    "LR:{f}, "
    "loss:{loss}\n"
).format(
    name=name,
    a=MAX_VOCAB_SIZE,
    b=EMBED_SIZE,
    w=WINDOW,
    i=INTERVAL,
    c=NEGATIVE,
    d=BATCH_SIZE,
    e=EPOCH,
    f=LR,
    # float() accepts both a scalar tensor and a plain number, so the log
    # shows the value itself rather than a tensor repr with device/grad_fn.
    loss=float(list_loss[-1]),
)
with open("param_adjustment.txt", 'a', encoding='utf-8') as f:
    f.write(record)

In [9]:
# with open("result/ca_evaluation/test.txt", 'w', encoding='utf-8') as f:
#     f.write("{} {}\n".format(MAX_VOCAB_SIZE, EMBED_SIZE))
#     for word, vector in zip(words, vectors):
#         f.write("{} {}\n".format(word, " ".join('{}'.format(i) for i in vector)))

In [10]:
# " ".join('{}'.format(i) for i in vectors[0])