In [40]:
from pytorch_pretrained_bert import BertTokenizer, BertConfig, BertAdam, BertForSequenceClassification, BertModel
import torch


In [2]:
# Location of the training CSV, also used by later cells.
data_file = "data/train_data.csv"

# Split every line on commas and strip newlines from the first two columns.
with open(data_file) as f:
    dataset_BERT = [
        [columns[0].replace("\n", ""), columns[1].replace("\n", "")]
        for columns in (row.split(",") for row in f)
    ]

# Keep only the first column, wrapped in the [CLS] ... [SEP] markers.
dataset_BERT = ["[CLS] " + columns[0] + " [SEP]" for columns in dataset_BERT]

In [3]:
# Load the tokenizer that ships with the multilingual uncased BERT checkpoint
# and run it over every marked-up sentence.
tokeniser = BertTokenizer.from_pretrained(
    'bert-base-multilingual-uncased',
    do_lower_case=True,
)
tokenised_dataset = list(map(tokeniser.tokenize, dataset_BERT))

In [5]:
# NOTE(review): presumably the intended maximum sequence length; no padding
# or truncation to MAX_LEN is applied in this cell.
MAX_LEN = 128

# Convert each token sequence into its list of vocabulary ids.
input_ids = [
    tokeniser.convert_tokens_to_ids(tokens)
    for tokens in tokenised_dataset
]

In [9]:
# One mask per sequence: 1.0 where the token id is positive, 0.0 otherwise.
attention_masks = []
for seq in input_ids:
    seq_mask = [1.0 if token_id > 0 else 0.0 for token_id in seq]
    attention_masks.append(seq_mask)

In [132]:
# Load the pretrained 'bert-base-multilingual-uncased' weights into a bare BertModel.
model_bert = BertModel.from_pretrained(
    'bert-base-multilingual-uncased'
)

In [133]:
# Take the second encoded sentence, print its ids together with their count,
# then convert it to a tensor with a leading batch dimension of size 1.
t = input_ids[1]
print(t, len(t))
t = torch.tensor(t).unsqueeze(0)

[101, 12666, 46177, 10503, 10169, 19719, 47454, 10840, 89405, 27421, 49885, 10165, 57466, 10165, 10263, 10106, 13701, 53600, 143, 70705, 143, 41287, 38866, 143, 10106, 31164, 10137, 143, 10106, 46476, 35403, 10218, 10119, 10205, 10688, 59359, 80780, 10368, 10119, 28673, 102] 41


In [88]:
# Display the architecture of the pretrained BERT encoder.
# `model_bert` is used here (not `model`): on a fresh top-to-bottom run `model`
# is only defined further down, as the FactsOrAnalysisBERT classifier.
model_bert

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(105879, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias

In [62]:
# Encode one hand-written example: tokenise it, surround the tokens with the
# [CLS] / [SEP] markers and convert them to a tensor of vocabulary ids.
sentence = "the red cube is at your left"
tokens = ["[CLS]", *tokeniser.tokenize(sentence), "[SEP]"]
token_ids = tokeniser.convert_tokens_to_ids(tokens)
ids = torch.tensor(token_ids)

tensor([[ 0.1260, -0.1680]], grad_fn=<AddmmBackward>)

In [112]:
import torch
from pytorch_pretrained_bert import BertModel
from torch import nn
# Target device for the model (it is moved there with `.to(device)` when it is
# built); fixed to the CPU in this notebook.
device = torch.device("cpu")

## FIRST EXPERIMENT: BERT as a third encoder

## The original architecture combined the outputs of a GRU network and a CNN; it is now being
## augmented with a BERT model whose output is also added to the final context
## vector.
class FactsOrAnalysisBERT(nn.Module):
    """Two-class sentence classifier built from a GRU branch and a TextCNN branch.

    Token indices are looked up in a pretrained embedding table. The embedded
    sequence feeds (a) a Conv1d -> ReLU -> MaxPool1d branch and (b), after
    dropout, a 3-layer bidirectional GRU. One slice of the GRU's final hidden
    state is concatenated with the pooled convolution features, passed through
    dropout and a linear layer, and normalised with a softmax over 2 classes.

    A pretrained BERT encoder is also loaded and stored as ``self.bert``.
    NOTE(review): ``forward`` does not call ``self.bert``, so BERT currently
    contributes nothing to the prediction.

    Args:
        embeddings_tensor: 2-D float tensor passed to
            ``nn.Embedding.from_pretrained``; its second dimension must equal
            ``embedding_size`` because the GRU and the convolution both take
            ``embedding_size`` input features.
        hidden_size: hidden size of the GRU.
        dropout: probability of the ``nn.Dropout`` applied to the embeddings
            fed to the GRU and to the concatenated feature vector.
        gru_dropout: dropout between the stacked GRU layers.
        embedding_size: dimensionality of the word embeddings.
        out_channels: number of convolution filters.
        kernel_size: width of the convolution filters.
        max_sen_len: sentence length the pooling window is sized for (the
            window spans ``max_sen_len - kernel_size + 1`` positions, i.e. the
            whole convolution output of a ``max_sen_len``-token sentence).
    """

    def __init__(self, embeddings_tensor, hidden_size=512, dropout=.5, gru_dropout=.3, embedding_size=200,
                 out_channels=100,
                 kernel_size=3,
                 max_sen_len=50):
        super(FactsOrAnalysisBERT, self).__init__()
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.max_sen_len = max_sen_len
        self.embedding = nn.Embedding.from_pretrained(embeddings_tensor)
        self.rnn = nn.GRU(embedding_size,
                          hidden_size,
                          num_layers=3,
                          batch_first=True,
                          bidirectional=True,
                          dropout=gru_dropout)

        # Taken from the TextCNN implementation by Anubhav Gupta
        # https://github.com/AnubhavGupta3377/Text-Classification-Models-Pytorch
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels=self.embedding_size, out_channels=self.out_channels, kernel_size=self.kernel_size),
            nn.ReLU(),
            # Window covers the full convolution output of a max_sen_len-token
            # sentence, leaving a single pooled position per filter.
            nn.MaxPool1d(self.max_sen_len - self.kernel_size + 1)
        )

        # Classifier input = one GRU hidden slice + pooled convolution features.
        self.linear = nn.Linear(self.hidden_size + self.out_channels, 2)
        self.softmax = nn.Softmax(dim=1)
        self.drop = nn.Dropout(dropout)

        # BERT (loaded here; see the class-level note about forward()).
        self.bert = BertModel.from_pretrained('bert-base-multilingual-uncased')

    def initHidden(self, batch_size=1):
        """Return an all-zero initial hidden state for ``self.rnn``.

        Args:
            batch_size: size of the batch dimension (defaults to 1, the value
                that used to be hard-coded).

        Returns:
            Zero tensor of shape
            ``(num_directions * num_layers, batch_size, hidden_size)`` created
            on the same device as the embedding weights, so it always matches
            the device the module itself lives on.
        """
        num_directions = 2 if self.rnn.bidirectional else 1
        return torch.zeros(num_directions * self.rnn.num_layers,
                           batch_size,
                           self.hidden_size,
                           device=self.embedding.weight.device)

    def forward(self, input_tensor):
        """Classify a batch of index-encoded sentences.

        Args:
            input_tensor: integer tensor of embedding indices, batch first.
                The sequence length must let the pooling layer produce exactly
                one position (true for ``max_sen_len`` tokens); otherwise the
                concatenation below fails.

        Returns:
            Tensor with one row per sentence and 2 columns holding the softmax
            probabilities.
        """
        embedded = self.embedding(input_tensor)

        # Conv1d wants channels before length, hence the permute; squeeze(2)
        # drops the single pooled position.
        conv_output = self.conv1(embedded.permute(0, 2, 1)).squeeze(2)

        # The GRU is given no explicit initial state (it defaults to zeros), so
        # no hidden tensor is allocated here. Only the GRU branch sees dropout
        # on the embeddings; the CNN branch above used them unchanged.
        _, hidden = self.rnn(self.drop(embedded))

        # NOTE(review): index 0 of the final hidden state is, per the
        # torch.nn.GRU documentation, the first layer's forward direction -
        # confirm this is intended rather than the last layer's states.
        hidden = hidden[0]

        cat_tensors = self.drop(torch.cat((conv_output, hidden), 1))
        return self.softmax(self.linear(cat_tensors))

In [120]:
import gensim
from src.dataset import tensorFromSentence, Lexicon

# Load the word2vec vectors (binary format, undecodable bytes ignored) and
# expose the raw vector matrix as a float tensor for the embedding layer.
embeddings_file="../frWac_non_lem_no_postag_no_phrase_200_cbow_cut100.bin"
embeddings = gensim.models.KeyedVectors.load_word2vec_format(embeddings_file,binary=True,unicode_errors='ignore')
embeddings_tensor = torch.FloatTensor(embeddings.vectors)

# Build the classifier on top of those embeddings and move it to `device`.
model=FactsOrAnalysisBERT(embeddings_tensor).to(device)

# Project helper from src.dataset, constructed from the loaded vectors.
lexicon=Lexicon(embeddings)

# Show the first BERT-formatted sentence. `dataset_BERT` is used here (not
# `dataset`): on a fresh top-to-bottom run `dataset` is only defined in a
# later cell, and the string shown below this cell is a "[CLS] ... [SEP]" sentence.
dataset_BERT[0]

'[CLS] ils sont stupéfiés entendre la voisine du haut vivre [SEP]'

In [121]:
# Re-read the training CSV, keeping the first two comma-separated columns of
# every line with newlines stripped.
with open(data_file) as f:
    dataset = [
        [columns[0].replace("\n", ""), columns[1].replace("\n", "")]
        for columns in (row.split(",") for row in f)
    ]

In [136]:
# Encode the first column of the second row with the project lexicon and add
# a leading batch dimension of size 1.
u = tensorFromSentence(lexicon, dataset[1][0]).unsqueeze(0)

In [145]:
# Run the pretrained BERT encoder on the single-sentence batch `t` and show
# element [1] of what it returns.
# NOTE(review): in pytorch_pretrained_bert this is presumably the pooled
# [CLS] representation - confirm. If `model_bert` is still in training mode,
# dropout would make this output vary between runs; verify whether
# `model_bert.eval()` is needed first.
model_bert(t)[1]

tensor([[ 8.6810e-02, -3.0413e-02,  1.8487e-01,  1.1927e-01,  1.5238e-01,
          4.3714e-01,  1.8695e-01, -1.5598e-03, -1.5734e-01,  2.4240e-01,
         -2.2433e-02, -1.0361e-01,  2.7190e-01, -3.2352e-02, -8.2205e-02,
          1.6817e-01,  1.1829e-01,  1.1273e-01,  1.7296e-01,  1.5149e-01,
         -1.5408e-01, -6.5293e-02,  3.1607e-03,  1.1950e-01,  1.4664e-01,
         -1.3633e-01,  2.2866e-01,  7.2574e-02,  2.0212e-01,  1.2555e-01,
          1.6389e-01,  1.6958e-01,  2.6743e-01, -1.4073e-01,  2.2370e-01,
         -6.0207e-02, -7.5467e-02,  6.2049e-02,  2.5658e-01,  2.5464e-02,
         -9.8667e-03,  2.8484e-01,  1.0591e-01, -1.0415e-01, -2.6662e-01,
          8.6102e-02, -1.7451e-01, -6.4375e-02,  9.9998e-01,  2.1277e-01,
          1.6553e-01, -9.7316e-02,  2.7344e-02, -3.9301e-01,  2.9343e-01,
          9.9998e-01, -3.5759e-01, -1.1904e-01,  1.1037e-01, -1.3515e-01,
         -9.1141e-02, -6.2128e-02,  3.0840e-01,  1.7984e-01, -1.0346e-01,
          6.2051e-02, -9.1614e-02,  2.