In [1]:
import torch, sys
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel


In [24]:
# Pretrained multilingual (cased) BERT encoder, placed on the current CUDA device.
bert = BertModel.from_pretrained(
    'bert-base-multilingual-cased'
).to('cuda')

In [5]:
# Inspect the module hierarchy: bert.embeddings, then the BertLayer stack in bert.encoder.layer.
bert

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(119547, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
         

In [9]:
# The ModuleList of BertLayer blocks; layers 0 and 1 are called one at a time further down.
bert.encoder.layer

ModuleList(
  (0): BertLayer(
    (attention): BertAttention(
      (self): BertSelfAttention(
        (query): Linear(in_features=768, out_features=768, bias=True)
        (key): Linear(in_features=768, out_features=768, bias=True)
        (value): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (output): BertSelfOutput(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (intermediate): BertIntermediate(
      (dense): Linear(in_features=768, out_features=3072, bias=True)
    )
    (output): BertOutput(
      (dense): Linear(in_features=3072, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (1): BertLayer(
    (attention): BertAttention(
      (self)

In [14]:
from create_dataset import DataloaderSC, DataloaderTC

# --- XNLI sentence-classification data configuration ---------------------------
# Only English / train is enabled here. Other values previously listed as options:
#   languages: 'zh', 'de', 'es'      splits: 'dev', 'test'
# NOTE(review): the recorded log for this cell mentions dev/test and zh/de/es cache
# files, which does not match these settings -- either the output is stale or
# DataloaderSC loads more than requested; confirm.
batch_size = 32
max_seq_length = 128
model_type = 'bert'
model_name_or_path = 'bert-base-multilingual-cased'
data_dir = './data/xnli'
lang_list = ['en']
mode_list = ['train']

dataloader, iter_dataloader, labels = DataloaderSC(
    lang_list=lang_list,
    model_name_or_path=model_name_or_path,
    model_type=model_type,
    mode_list=mode_list,
    data_dir=data_dir,
    max_seq_length=max_seq_length,
    batch_size=batch_size,
)

Loading features from cached file ./data/xnli/cached_feature_train_en_bert-base-multilingual-cased_128
Loading features from cached file ./data/xnli/cached_feature_dev_en_bert-base-multilingual-cased_128
Loading features from cached file ./data/xnli/cached_feature_test_en_bert-base-multilingual-cased_128
Loading features from cached file ./data/xnli/cached_feature_train_zh_bert-base-multilingual-cased_128
Loading features from cached file ./data/xnli/cached_feature_dev_zh_bert-base-multilingual-cased_128
Loading features from cached file ./data/xnli/cached_feature_test_zh_bert-base-multilingual-cased_128
Loading features from cached file ./data/xnli/cached_feature_train_de_bert-base-multilingual-cased_128
Loading features from cached file ./data/xnli/cached_feature_dev_de_bert-base-multilingual-cased_128
Loading features from cached file ./data/xnli/cached_feature_test_de_bert-base-multilingual-cased_128
Loading features from cached file ./data/xnli/cached_feature_train_es_bert-base-mu

In [17]:
from utils import get_data, get_metric

# Grab one training batch for the first configured language to probe the model with.
# An explicit guard instead of `for ...: break`: with an empty lang_list that loop body
# never ran, leaving `inputs` undefined or silently stale from earlier kernel state.
if not lang_list:
    raise ValueError("lang_list is empty: no language to fetch a batch for")
lg_index, lg = 0, lang_list[0]
inputs = get_data(lg, 'train', dataloader, iter_dataloader)

In [20]:
# One batch from get_data: a dict of tensors keyed by
# input_ids / attention_mask / token_type_ids / labels (already on cuda:0 in the recorded run).
inputs

{'input_ids': tensor([[  101,   128,   114,  ...,     0,     0,     0],
         [  101,   113,   177,  ...,     0,     0,     0],
         [  101, 10167, 14763,  ...,     0,     0,     0],
         ...,
         [  101, 15946, 10355,  ...,     0,     0,     0],
         [  101, 10357, 30518,  ...,     0,     0,     0],
         [  101, 10882, 30488,  ...,     0,     0,     0]], device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0'),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0'),
 'labels': tensor([2, 2, 2, 1, 1, 1, 2, 0, 2, 1, 0, 1, 0, 2, 1, 2, 1, 2, 2, 0,

In [25]:
# Reference forward pass through the whole model on the batch.
# no_grad: this is inspection only, so don't keep autograd buffers for every
# encoder layer alive on the GPU.
with torch.no_grad():
    bert_out = bert(input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    token_type_ids=inputs['token_type_ids'])
bert_out

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 0.1560,  0.1230,  0.2009,  ...,  0.4411, -0.1311, -0.0379],
         [ 0.0520,  0.0826,  0.0486,  ...,  0.7898,  0.0891, -0.1957],
         [-0.0241, -0.1823,  0.4683,  ...,  0.5369,  0.0843, -0.4512],
         ...,
         [-0.0347, -0.0031,  0.4630,  ...,  0.0147,  0.1195,  0.3126],
         [-0.2506, -0.0086,  0.5796,  ...,  0.1662, -0.0300,  0.4108],
         [-0.2346, -0.1130,  0.4524,  ...,  0.5567,  0.0497, -0.2308]],

        [[ 0.2691,  0.2161,  0.4421,  ...,  0.5187, -0.0949, -0.0696],
         [-0.0736, -0.2158,  0.4108,  ...,  0.8857,  0.0219, -0.5870],
         [-0.3079, -0.6461,  0.6990,  ...,  0.5714, -0.3067,  0.3699],
         ...,
         [-0.4748, -0.3428,  0.1623,  ..., -0.1648, -0.1211, -0.1244],
         [-0.3601, -0.2409,  0.1604,  ...,  0.1717,  0.0998, -0.3816],
         [-0.3106, -0.3102,  0.2656,  ..., -0.1083, -0.1761, -0.0273]],

        [[ 0.3186,  0.1389,  0.3918,  ...,  0.8319, -

In [41]:
# Run the embeddings and the first encoder layer by hand.
# BertLayer needs an *additive* attention mask broadcastable over
# (batch, heads, query, key): 0 where attention_mask == 1, a very large negative
# number where attention_mask == 0 (padding). Without it every token attends to the
# [PAD] positions and the result no longer matches the model's own layer-0 output.
with torch.no_grad():
    emb_out = bert.embeddings(input_ids=inputs['input_ids'],
                              token_type_ids=inputs['token_type_ids'])
    pad_mask = inputs['attention_mask'][:, None, None, :].to(emb_out.dtype)
    ext_attention_mask = (1.0 - pad_mask) * torch.finfo(emb_out.dtype).min
    # BertLayer returns a tuple; element 0 is the hidden states.
    l1out = bert.encoder.layer[0](emb_out, attention_mask=ext_attention_mask)
l1out[0].shape

torch.Size([32, 128, 768])

In [43]:
# Feed layer 0's hidden states into layer 1, again with the additive padding mask
# (0 on real tokens, a very large negative number where attention_mask == 0).
with torch.no_grad():
    hidden_l0 = l1out[0]
    pad_mask = inputs['attention_mask'][:, None, None, :].to(hidden_l0.dtype)
    ext_attention_mask = (1.0 - pad_mask) * torch.finfo(hidden_l0.dtype).min
    l2out = bert.encoder.layer[1](hidden_l0, attention_mask=ext_attention_mask)
l2out[0].shape

torch.Size([32, 128, 768])