In [43]:
import keras
import torch
import transformers
import pandas as pd
from datasets import Dataset
import torch
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from datasets import load_from_disk


In [44]:
# Load the already-tokenized dataset that was saved to disk earlier.
# The trailing bare expression displays the splits and their features.
tokenized_dir = 'data/distilbert-base-uncased_tokenized_dataset'
dataset = load_from_disk(tokenized_dir)
dataset

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 7887
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 1972
    })
})

In [45]:
# Show the column names of every split, to confirm train and test expose the same fields.
dataset.column_names

{'train': ['labels', 'input_ids', 'attention_mask'],
 'test': ['labels', 'input_ids', 'attention_mask']}

In [46]:
# Checkpoint identifier; the tokenizer is loaded from it here and the same
# name is reused later when the model itself is loaded.
model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Collator that pads the examples of a batch with the tokenizer's padding rules.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Loader over the training split: 16 examples per batch, padded by the collator.
# NOTE(review): shuffle is left at its default (off), so batches follow dataset
# order -- fine for inspection; confirm whether shuffling is wanted for training.
train_dataloader = DataLoader(
    dataset=dataset['train'],
    collate_fn=data_collator,
    batch_size=16,
)

In [47]:
# Pull only the first batch to inspect what the collator produces.
# next(iter(...)) states that intent directly, and it raises StopIteration on
# an empty loader instead of silently leaving `batch` undefined for later cells
# (which is what a for/break loop does when there is nothing to iterate).
batch = next(iter(train_dataloader))
print(batch)

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'labels': tensor([[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
         0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1,
         1, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
         0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 

In [54]:
# Split the targets off from the model inputs; `batch` keeps only the tensors
# that are fed to the model.
# The membership guard makes this cell safe to re-run: an unguarded pop()
# raises KeyError the second time it is executed on the same batch, whereas
# here a re-run simply leaves the previously extracted `labels` in place.
if 'labels' in batch:
    labels = batch.pop('labels')
# NOTE(review): the labels print as integers in the batch output above; a
# multi-label (BCE-with-logits) loss expects float targets, so a
# `labels.float()` cast is probably needed before computing a loss -- confirm.

In [55]:
# Inspect what remains in the batch now that 'labels' has been popped out of it.
batch

{'input_ids': tensor([[  101,  1041,  3676,  ...,     0,     0,     0],
        [  101,  6622,  1997,  ...,     0,     0,     0],
        [  101,  3795,  1060,  ...,     0,     0,     0],
        ...,
        [  101,  2019,  2552,  ...,     0,     0,     0],
        [  101,  3010, 12012,  ...,     0,     0,     0],
        [  101,  4861,  2012,  ...,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

In [56]:
#load model
from transformers import AutoModelForSequenceClassification

# Size the classification head from the data rather than hard-coding it: each
# example's 'labels' entry is a multi-hot vector (50 entries in the batch
# printed earlier), and the head must emit exactly one logit per entry.
num_labels = len(dataset['train'][0]['labels'])

# problem_type='multi_label_classification' configures the model for
# independent per-label predictions instead of a single-class softmax.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    problem_type='multi_label_classification',
)

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'pre_classifier.weight', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [59]:
# Forward pass on the label-free batch (keys are unpacked as keyword arguments).
# There is no torch.no_grad() around this call, so autograd is not disabled here;
# the following cell repeats the same call with gradients turned off.
output = model(**batch)

In [65]:
# Repeat the forward pass with autograd disabled -- the result is only inspected,
# never back-propagated, so no gradient bookkeeping is needed.
with torch.no_grad():
    output = model(**batch)
# The printed result shows loss=None because no labels are passed to the model.
print(output)

SequenceClassifierOutput(loss=None, logits=tensor([[ 0.0617,  0.0664, -0.0906,  0.1200, -0.1792,  0.0214, -0.0105,  0.1984,
          0.1139, -0.1736,  0.0925,  0.0239, -0.1112, -0.0003,  0.0889, -0.1375,
         -0.0091,  0.0537,  0.0994,  0.0878,  0.1000,  0.0403, -0.0995,  0.0165,
         -0.1018,  0.1070, -0.0768, -0.1758,  0.2154, -0.0615, -0.1563,  0.1224,
         -0.0645,  0.0374, -0.0796, -0.0093,  0.0024,  0.0665,  0.1369,  0.0858,
         -0.0881,  0.0556,  0.0411,  0.0982, -0.0972, -0.0228,  0.0071,  0.0676,
         -0.1466, -0.1093],
        [ 0.0366,  0.0838, -0.1000,  0.0857, -0.1110,  0.0560, -0.0626,  0.1722,
          0.1241, -0.1956,  0.0987,  0.0126, -0.1091, -0.0401,  0.1115, -0.1211,
         -0.0608,  0.0620,  0.0470,  0.0753,  0.0720, -0.0136, -0.0844, -0.0160,
         -0.0615,  0.1050, -0.0488, -0.1687,  0.1910, -0.0127, -0.1063,  0.1125,
         -0.0453,  0.0122, -0.0347, -0.0080,  0.0497,  0.0742,  0.0893,  0.1397,
         -0.1149,  0.0124,  0.0150,  0