In [1]:
import re

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AutoTokenizer

In [2]:
# Use AutoTokenizer so that `use_fast=True` is actually honoured: BertTokenizer is
# the slow, pure-Python class and silently ignores that flag, whereas AutoTokenizer
# resolves to the fast (Rust-backed) tokenizer for this checkpoint.
MODEL_NAME = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# Loading the bare encoder from this checkpoint drops its pre-training heads
# (the unused "cls.*" weights), so the warning emitted here is expected.
model = BertModel.from_pretrained(MODEL_NAME)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
import string
import random

# Synthetic corpus: each sentence has 2-10 space-separated "words", each word 2-10
# random ASCII letters. A fixed seed keeps Restart & Run All reproducible.
SEED = 0
N_SENTENCES = 16384
WORDS_PER_SENTENCE = (2, 10)  # inclusive (min, max), as for random.randint
CHARS_PER_WORD = (2, 10)      # inclusive (min, max)

char = string.ascii_letters


def make_random_sentences(n, alphabet=char, words=WORDS_PER_SENTENCE,
                          word_len=CHARS_PER_WORD, rng=None):
    """Return ``n`` random sentences built from letters of ``alphabet``.

    Args:
        n: number of sentences to generate (0 yields an empty list).
        alphabet: characters that words are drawn from, with replacement.
        words: inclusive (min, max) number of words per sentence.
        word_len: inclusive (min, max) number of characters per word.
        rng: a ``random.Random`` instance; defaults to a fresh one seeded with
            ``SEED`` so repeated calls return the same corpus.

    Returns:
        A list of ``n`` strings whose words are separated by single spaces.
    """
    if rng is None:
        rng = random.Random(SEED)

    def random_word():
        return ''.join(rng.choices(alphabet, k=rng.randint(*word_len)))

    return [' '.join(random_word() for _ in range(rng.randint(*words)))
            for _ in range(n)]


data = make_random_sentences(N_SENTENCES)
# Longest sentence measured in *characters* (not word-piece tokens); used as the
# tokenizer's max_length further down.
MAX_LEN = max(map(len, data), default=0)

In [4]:
# `data` is a plain in-memory list of strings, so each batch is simply a list of
# raw sentences that the tokenizer consumes directly. Worker processes have no
# loading/decoding work to parallelise here and would only add start-up and IPC
# overhead, so batches are built in the main process.
BATCH_SIZE = 2048
dloader = DataLoader(data, shuffle=False, batch_size=BATCH_SIZE, num_workers=0)

In [5]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes end-to-end on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [6]:
# Batched feature extraction over the whole corpus. eval() disables dropout and
# no_grad() skips autograd bookkeeping, since only the forward pass is needed.
OUTPUT_HIDDEN_STATES = False  # set True to also return every layer's hidden states

model.eval()
# Send inputs to wherever the model's weights live (GPU or CPU) rather than
# assuming CUDA is present.
device = next(model.parameters()).device

cls_chunks = []
with torch.no_grad():
    for inputs in dloader:
        # `inputs` is this batch's list of raw strings. MAX_LEN counts characters,
        # not word pieces, so truncate explicitly: without truncation=True a
        # sentence that tokenises past max_length is not cut, and the batch can
        # then contain ragged rows that cannot be stacked into one tensor.
        token = tokenizer(
            inputs,
            max_length=MAX_LEN,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        token = {k: v.to(device) for k, v in token.items()}
        output = model(**token, output_hidden_states=OUTPUT_HIDDEN_STATES)
        # `output` is overwritten every batch; keep the first-position ([CLS])
        # vector of each sentence on the CPU so every batch's result survives.
        cls_chunks.append(output.last_hidden_state[:, 0].cpu())

# One row per input sentence, in DataLoader order (the loader does not shuffle).
cls_embeddings = torch.cat(cls_chunks) if cls_chunks else torch.empty(0)

tensor([[[-0.4848, -0.0151,  0.0911,  ..., -0.3857,  0.1335,  0.6159],
         [ 0.0987,  0.4093,  0.5425,  ..., -0.4955,  0.7590,  0.4389],
         [ 0.1751, -0.0626,  0.6250,  ...,  0.0959, -0.0108, -0.1789],
         ...,
         [-0.3851, -0.1833,  0.6297,  ..., -0.2146, -0.0583, -0.0016],
         [-0.3558, -0.2439,  0.7346,  ..., -0.2488, -0.1108, -0.0095],
         [-0.2118, -0.2361,  0.7782,  ..., -0.2775, -0.3132, -0.1115]],

        [[-0.3338, -0.0226,  0.1558,  ..., -0.2014,  0.3652,  0.7062],
         [ 0.3280, -0.6815,  0.9401,  ..., -0.5096,  0.1239,  0.3813],
         [ 0.5583,  0.4054,  1.1353,  ..., -0.2188,  0.3670,  0.3122],
         ...,
         [-0.2832, -0.2842,  1.3165,  ..., -0.0517, -0.2613,  0.3995],
         [-0.3255, -0.7889,  0.1604,  ...,  0.0504,  0.1252,  0.2000],
         [-0.3552, -0.8441,  0.1096,  ...,  0.0565,  0.1755,  0.2790]],

        [[-0.4440, -0.0039,  0.2062,  ...,  0.0895,  0.5349,  0.3113],
         [-0.1341,  0.0839,  1.0658,  ..., -0

tensor([[[-0.4635,  0.2641,  0.3248,  ..., -0.1524,  0.2391,  0.6577],
         [ 0.6241, -0.1916,  1.2526,  ...,  0.3227,  0.4890,  0.3152],
         [ 0.3638, -0.2104,  1.5256,  ..., -0.1833,  0.0624,  0.0318],
         ...,
         [-0.0744,  0.0380,  0.6655,  ..., -0.1614, -0.0075,  0.1989],
         [-0.0342,  0.0979,  0.6397,  ..., -0.0475, -0.0650,  0.3023],
         [-0.0184,  0.0367,  0.5702,  ..., -0.0415,  0.0446,  0.3295]],

        [[-0.4902,  0.1815,  0.2397,  ..., -0.2602,  0.3162,  0.6480],
         [ 0.9713, -0.5351,  1.5599,  ..., -0.3720,  0.4008,  1.0717],
         [ 0.5049, -0.2564,  1.3704,  ...,  0.0292, -0.7465,  0.0756],
         ...,
         [-0.0016, -0.1034,  0.5715,  ...,  0.0462, -0.1485,  0.4610],
         [ 0.1327, -0.1750,  0.4338,  ..., -0.0790,  0.1011,  0.3128],
         [ 0.2605, -0.0033,  0.3789,  ..., -0.0085,  0.0167,  0.3821]],

        [[-0.4854,  0.1374,  0.4754,  ...,  0.0568,  0.2923,  0.3618],
         [-0.7586, -0.3206,  1.2048,  ..., -0

tensor([[[-7.5309e-01, -2.2668e-01,  1.4697e-01,  ..., -4.4257e-01,
          -1.8147e-01,  6.1107e-01],
         [-8.3634e-02, -3.6715e-01,  7.6174e-01,  ..., -3.2895e-01,
           9.5023e-02,  2.2395e-01],
         [-4.3786e-01, -4.4977e-01,  1.0203e+00,  ..., -2.3588e-01,
          -3.2342e-01, -2.2941e-01],
         ...,
         [ 1.3243e-02, -3.3638e-02,  6.2547e-01,  ...,  9.6060e-02,
          -1.2072e-01,  3.5177e-01],
         [-1.5282e-01, -1.0529e-01,  5.5747e-01,  ..., -1.7104e-02,
          -7.7191e-02,  2.6851e-01],
         [-3.5911e-01, -4.3155e-01,  4.1382e-01,  ..., -2.8063e-01,
          -2.1131e-02,  1.3007e-01]],

        [[-6.2550e-01,  1.5979e-01,  2.8757e-01,  ...,  1.3973e-01,
           3.5426e-01,  2.0924e-01],
         [-7.1754e-01, -1.3344e-01,  1.2894e+00,  ..., -1.6469e-02,
           4.6079e-01,  1.4940e-01],
         [-3.2191e-01,  4.4603e-01,  3.6463e-01,  ...,  5.7548e-02,
           1.5289e-01,  1.0667e-01],
         ...,
         [-4.5201e-01,  5

In [7]:
# Unpack the model output of the *last* batch from the loop above
# (`output` is reassigned on every iteration).
last_hidden_state = output.last_hidden_state  # per-token embeddings
pooler_output = output.pooler_output          # pooled, sentence-level vector
hidden_states = output.hidden_states          # only populated when output_hidden_states=True