In [1]:
import torch
import os

# Point HF_ENDPOINT at a mirror. NOTE(review): this is assigned before the
# `transformers` import below, presumably so the Hugging Face hub client picks
# the mirror up when it is first loaded -- keep this ordering.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from transformers import AutoTokenizer

# 'distilbert-imdb' is resolved by from_pretrained; presumably a local
# directory or hub id -- TODO confirm which.
tokenizer = AutoTokenizer.from_pretrained('distilbert-imdb')

# Bare expression: shows the tokenizer repr as the cell output.
tokenizer

DistilBertTokenizerFast(name_or_path='distilbert-imdb', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [2]:
from datasets import load_dataset, concatenate_datasets

# Load IMDB and merge its 'train' and 'test' splits into one dataset.
imdb_splits = load_dataset('imdb')
dataset = concatenate_datasets(
    [imdb_splits[split_name] for split_name in ('train', 'test')])


def f(data):
    """Collate a list of dataset rows into a single tokenized batch.

    Args:
        data: iterable of rows, each indexable by 'text' and 'label'.

    Returns:
        The tokenizer's output for the batch (padded, truncated to at most
        50 tokens, as PyTorch tensors) moved to `device`, with an extra
        'labels' entry holding the row labels as a LongTensor on `device`.
    """
    texts, labels = [], []
    for row in data:
        texts.append(row['text'])
        labels.append(row['label'])

    # Pad to the longest sequence in this batch; cut anything past 50 tokens.
    batch = tokenizer(
        texts,
        max_length=50,
        truncation=True,
        padding=True,
        return_tensors='pt',
    )
    batch = batch.to(device)

    # Labels travel with the inputs so the model can compute the loss itself.
    batch['labels'] = torch.LongTensor(labels).to(device)

    return batch


# Shuffled batches of 4 rows, collated by `f`; a trailing incomplete batch
# is dropped.
loader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=f,
    batch_size=4,
    drop_last=True,
    shuffle=True,
)

# Cell output: number of batches and one sample batch.
len(loader), next(iter(loader))

(12500,
 {'input_ids': tensor([[  101,  1996, 25672, 12083,  3064, 12944,  1997,  2023,  3185,  2003,
           2848,  1051,  1005,  6994,  2063,  1005,  1055,  2836,  1012,  1999,
           2735, 13544, 29257,  1998, 16668, 16668, 13800,  1012,  2515, 10334,
           2079,  2009,  2488,  2084,  1051,  1005,  6994,  2063,  1029,  1045,
           2123,  1005,  1056,  2228,  2061,  1012,  2054,  1037,  2307,   102],
         [  101, 13865, 12001,  1005,  1055,  1999,  1996,  2240,  1997,  2543,
           2003, 23626,  1998,  5681,  1037,  5621,  2317,  1011, 14161, 12722,
           3709,  4536,  1010,  2130,  2065,  2320,  2030,  3807,  2057,  2453,
           2514,  2066,  2057,  1005,  2310,  2042,  2091,  2714,  4925,  2077,
           1012,  2129,  2071,  2028,  2025,  2043, 16235, 24201,  1010,   102],
         [  101,  2009,  1005,  1055,  2307,  2129,  2023,  3185,  8005,  2017,
           2247,  1012,  1045,  2018,  7481, 11680,  1012,  2123,  1005,  1056,
           2729,

In [3]:
from transformers import AutoModelForSequenceClassification

# Load the sequence-classification checkpoint, then move it to the target device.
model_critic = AutoModelForSequenceClassification.from_pretrained(
    'distilbert-imdb')
model_critic = model_critic.to(device)

# Cell output: the loaded model configuration.
model_critic.config

  return self.fget.__get__(instance, owner)()


DistilBertConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "distilbert-imdb",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "id2label": {
    "0": "NEGATIVE",
    "1": "POSITIVE"
  },
  "initializer_range": 0.02,
  "label2id": {
    "NEGATIVE": 0,
    "POSITIVE": 1
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "problem_type": "single_label_classification",
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "torch_dtype": "float32",
  "transformers_version": "4.49.0",
  "vocab_size": 30522
}

In [4]:
# Fine-tune the critic for one pass over `loader`, then save it to 'model/critic'.
optimizer = torch.optim.Adam(model_critic.parameters(), lr=1e-5)

# from_pretrained() returns the model in eval mode (dropout disabled), so
# training mode has to be enabled explicitly for fine-tuning. The previous
# mode is restored afterwards -- also when the loop is interrupted -- so any
# later use of `model_critic` sees the same mode as before this cell ran.
was_training = model_critic.training
model_critic.train()
try:
    for i, data in enumerate(loader):
        # The batch carries 'labels', so the forward pass also returns the loss.
        out = model_critic(**data)
        out.loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Report loss and accuracy of the current batch every 1000 steps.
        if i % 1000 == 0:
            acc = (out.logits.argmax(1) == data['labels']).sum() / len(
                data['labels'])
            print(i, len(loader), out.loss.item(), acc.item())
finally:
    model_critic.train(was_training)

# Only reached when the full pass completes without an exception.
model_critic.save_pretrained('model/critic')

0 12500 0.44071608781814575 0.75


KeyboardInterrupt: 