In [None]:
!pip install pytorch_transformers

In [None]:
from pytorch_transformers import  BertModel, BertConfig,BertTokenizer
import torch
import torch.nn as nn
import pandas as pd
import gc

In [None]:
class TextNet(nn.Module):
    """BERT encoder followed by a linear projection squashed with tanh.

    The hidden state at sequence position 0 of the encoder output is projected
    to ``code_length`` dimensions and passed through ``tanh``, so every output
    component lies in (-1, 1).
    """

    def __init__(self, code_length, model_name='bert-base-uncased'):
        """Build the encoder and the projection head.

        Args:
            code_length: output dimension of the ``fc`` projection layer.
            model_name: identifier handed to ``from_pretrained`` for both the
                config and the model. The default keeps the checkpoint that
                was previously hard-coded.
        """
        super(TextNet, self).__init__()

        modelConfig = BertConfig.from_pretrained(model_name)
        self.textExtractor = BertModel.from_pretrained(model_name, config=modelConfig)
        embedding_dim = self.textExtractor.config.hidden_size

        self.fc = nn.Linear(embedding_dim, code_length)
        self.tanh = torch.nn.Tanh()

    def forward(self, tokens, segments, input_masks):
        """Encode a batch and return its tanh-activated projection.

        Args:
            tokens: passed to the encoder as its first positional argument.
            segments: passed to the encoder as ``token_type_ids``.
            input_masks: passed to the encoder as ``attention_mask``.

        Returns:
            Tensor of shape (batch size, code_length).
        """
        output = self.textExtractor(tokens, token_type_ids=segments,
                                    attention_mask=input_masks)
        # output[0] is (batch size, sequence length, model hidden dimension);
        # keep only the vector at sequence position 0 of every sample.
        text_embeddings = output[0][:, 0, :]

        # No explicit `del` / `gc.collect()` here: the locals are released when
        # the method returns, and a full collector pass on every forward call
        # only costs time.
        return self.tanh(self.fc(text_embeddings))

In [None]:
# Seed the RNG before building the model: the linear projection on top of BERT
# is freshly (randomly) initialised, so without a fixed seed every kernel
# restart would produce different codes for the same titles.
torch.manual_seed(0)
textNet = TextNet(code_length=32)

## Dataset

In [None]:
# Load the Shopee product-matching training table.
train_csv_path = "../input/shopee-product-matching/train.csv"
titles = pd.read_csv(train_csv_path)

In [None]:
# Wrap every product title in the [CLS] ... [SEP] markers.
texts = ["".join(("[CLS] ", title, " [SEP]")) for title in titles.title]

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Split every sentence with the tokenizer and map the pieces to vocabulary ids.
tokens = [
    tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence))
    for sentence in texts
]

# Length of the longest tokenised sentence; every row is padded up to it.
max_len = max(len(ids) for ids in tokens)

segments, input_masks = [], []
for row, ids in enumerate(tokens):
    real_len = len(ids)
    pad_len = max_len - real_len
    # Segment ids are 0 everywhere (real tokens and padding alike).
    segments.append([0] * max_len)
    # Mask is 1 on real tokens and 0 on the padded tail.
    input_masks.append([1] * real_len + [0] * pad_len)
    # Token ids are right-padded with 0.
    tokens[row] = ids + [0] * pad_len

In [None]:
# Turn the three padded lists into tensors, in the same order as before.
tokens_tensor, segments_tensors, input_masks_tensors = map(
    torch.tensor, (tokens, segments, input_masks)
)

In [None]:
# Encode the corpus in mini-batches. A single forward pass over every title
# would need the BERT activations of the whole dataset in memory at once.
BATCH_SIZE = 64

textNet.eval()  # make sure dropout layers run in inference mode
hash_code_batches = []
with torch.no_grad():  # inference only: do not record an autograd graph
    for start in range(0, tokens_tensor.size(0), BATCH_SIZE):
        stop = start + BATCH_SIZE
        hash_code_batches.append(
            textNet(tokens_tensor[start:stop],
                    segments_tensors[start:stop],
                    input_masks_tensors[start:stop])
        )

# Same (number of titles, code_length) result as one big call, minus the graph.
text_hashCodes = torch.cat(hash_code_batches)
text_hashCodes.shape