# PyTorch 使用 Embedding 层
此处使用之前 TensorFlow 教程中训练好的 gensim 模型，读者可结合自身情况替换为自己训练的模型。

PyTorch word embedding 训练实现：<br>   https://github.com/bamtercelboo/pytorch_word2vec/blob/master/word2vec.py


PyTorch 导入训练好的 word embedding

In [1]:
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange

import torch
from torch import nn

import gensim
import numpy as np
from keras.layers import Embedding

# Load the gensim Word2Vec model trained in the earlier TensorFlow tutorial.
emb_weights = gensim.models.Word2Vec.load('brown_skipgram.model')

# Small demo vocabulary; row i of embedding_matrix belongs to word_list[i].
word_list = "woman women man girl boy green blue did".split()
word2ind = {word: ind for ind, word in enumerate(word_list)}
vocab_size = len(word_list)
embedding_dim = emb_weights.wv.vector_size

# Fill the matrix row by row with the pretrained vectors.
# `wv[word]` works on gensim 3.x and 4.x, whereas the `wv.syn0` / `wv.vocab`
# attributes used previously were removed in gensim 4.0. A word that is not in
# the model's vocabulary raises KeyError.
# float32 matches PyTorch's default floating dtype, so no float64 round-trip
# is needed when the matrix is turned into a tensor.
embedding_matrix = np.zeros(shape=(vocab_size, embedding_dim), dtype=np.float32)
for ind, word in enumerate(word_list):
    embedding_matrix[ind, :] = emb_weights.wv[word]

print(embedding_matrix)

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


[[ 2.41059089 -5.51927614  3.74764085  1.30366039 -1.64450049  2.61859632
  -2.77704644 -1.37271416 -1.7589196   0.82422775]
 [ 6.86335182  1.33035862  1.73910975  0.27735218 -0.28538892 -0.17378837
  -7.10598087  0.72247446  1.57231748 -1.01711917]
 [ 1.34075868 -3.23909616  3.30166626  0.99042195  0.02437754  5.61165476
  -2.85212517 -0.9763515  -2.08036757  2.70800734]
 [ 1.83362269 -5.63690138  3.66578293  1.75363398 -0.98059785  3.33684063
  -2.4922204  -1.45831227 -1.53534102  0.64063424]
 [ 1.43617606 -5.53966761  3.85731053  2.8325603  -0.07618857  3.39432216
  -2.63929248 -1.73143554 -1.87638187 -0.61770976]
 [ 3.38406587 -0.33416516  6.58938122  1.98220897 -4.39226103 -0.14040969
   2.61482954 -5.58151579  3.7593286  -1.1867137 ]
 [ 1.03018129  0.23436923  7.27011728  3.36135554 -3.79686093  0.83886689
   2.79096937 -4.7255497   4.0163517  -1.10945129]
 [ 6.96005535  0.48799345 -4.4505558   1.86931384 -7.57206964 -0.22840905
   4.44346809  8.77457142 -3.10547376  0.57873034]]



## 方式1

In [2]:
# Method 1: create a randomly initialised embedding layer, then overwrite its
# weight in place with the pretrained matrix.
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
# Copy under no_grad() rather than through `.data`, which bypasses autograd's
# safety checks. copy_ also casts the source to the parameter's dtype and
# raises if the source shape cannot be broadcast to the weight's shape.
with torch.no_grad():
    embedding_layer.weight.copy_(torch.from_numpy(embedding_matrix))
print(embedding_layer.weight)

Parameter containing:
tensor([[ 2.4106, -5.5193,  3.7476,  1.3037, -1.6445,  2.6186, -2.7770,
         -1.3727, -1.7589,  0.8242],
        [ 6.8634,  1.3304,  1.7391,  0.2774, -0.2854, -0.1738, -7.1060,
          0.7225,  1.5723, -1.0171],
        [ 1.3408, -3.2391,  3.3017,  0.9904,  0.0244,  5.6117, -2.8521,
         -0.9764, -2.0804,  2.7080],
        [ 1.8336, -5.6369,  3.6658,  1.7536, -0.9806,  3.3368, -2.4922,
         -1.4583, -1.5353,  0.6406],
        [ 1.4362, -5.5397,  3.8573,  2.8326, -0.0762,  3.3943, -2.6393,
         -1.7314, -1.8764, -0.6177],
        [ 3.3841, -0.3342,  6.5894,  1.9822, -4.3923, -0.1404,  2.6148,
         -5.5815,  3.7593, -1.1867],
        [ 1.0302,  0.2344,  7.2701,  3.3614, -3.7969,  0.8389,  2.7910,
         -4.7255,  4.0164, -1.1095],
        [ 6.9601,  0.4880, -4.4506,  1.8693, -7.5721, -0.2284,  4.4435,
          8.7746, -3.1055,  0.5787]])


## 方式2

In [3]:
# Method 2: replace the layer's weight Parameter wholesale with a new one
# built from the pretrained matrix.
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
# torch.tensor (lowercase) is the recommended factory; the legacy torch.Tensor
# constructor is discouraged. The explicit float32 dtype reproduces what
# torch.Tensor produced implicitly under the default settings.
pretrained_weight = torch.tensor(embedding_matrix, dtype=torch.float32)
# Reassigning the Parameter does not update num_embeddings / embedding_dim,
# so refuse a matrix whose shape disagrees with the layer's configuration.
if pretrained_weight.shape != embedding_layer.weight.shape:
    raise ValueError(
        "pretrained matrix shape %s does not match embedding weight shape %s"
        % (tuple(pretrained_weight.shape), tuple(embedding_layer.weight.shape))
    )
# Do this before creating any optimizer: an optimizer built earlier would
# still reference the old Parameter object.
embedding_layer.weight = nn.Parameter(pretrained_weight)
print(embedding_layer.weight)

Parameter containing:
tensor([[ 2.4106, -5.5193,  3.7476,  1.3037, -1.6445,  2.6186, -2.7770,
         -1.3727, -1.7589,  0.8242],
        [ 6.8634,  1.3304,  1.7391,  0.2774, -0.2854, -0.1738, -7.1060,
          0.7225,  1.5723, -1.0171],
        [ 1.3408, -3.2391,  3.3017,  0.9904,  0.0244,  5.6117, -2.8521,
         -0.9764, -2.0804,  2.7080],
        [ 1.8336, -5.6369,  3.6658,  1.7536, -0.9806,  3.3368, -2.4922,
         -1.4583, -1.5353,  0.6406],
        [ 1.4362, -5.5397,  3.8573,  2.8326, -0.0762,  3.3943, -2.6393,
         -1.7314, -1.8764, -0.6177],
        [ 3.3841, -0.3342,  6.5894,  1.9822, -4.3923, -0.1404,  2.6148,
         -5.5815,  3.7593, -1.1867],
        [ 1.0302,  0.2344,  7.2701,  3.3614, -3.7969,  0.8389,  2.7910,
         -4.7255,  4.0164, -1.1095],
        [ 6.9601,  0.4880, -4.4506,  1.8693, -7.5721, -0.2284,  4.4435,
          8.7746, -3.1055,  0.5787]])
