In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer,BertModel
from torchtext.legacy import data,datasets
import numpy as np
import random
import time


SEED = 2022
TRAIN = False
BATCH_SIZE=128
N_EPOCHS=5
HIDDEN_DIM=256
OUTPUT_DIM=1
N_LAYERS=2
BIDIRECTIONAL=True
DROPOUT=0.25

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic=True

In [3]:
# 通过类的静态方法获取对象,这是预训练的对象
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [4]:
init_token_id = tokenizer.cls_token_id
eos_token_id = tokenizer.sep_token_id
pad_token_id = tokenizer.pad_token_id
unk_token_id = tokenizer.unk_token_id
print('init_token_id',init_token_id)
print('eos_token_id',eos_token_id)
print('pad_token_id',pad_token_id)
print('unk_token_id',unk_token_id)

init_token_id 101
eos_token_id 102
pad_token_id 0
unk_token_id 100


In [5]:
max_input_len = tokenizer.max_model_input_sizes['bert-base-uncased']# bert的输入句子长度
print('max_input_len',max_input_len)

max_input_len 512


In [26]:
device='cuda' if torch.cuda.is_available() else 'cpu'
# 把句子长度切成510，加入开头跟结尾符
def tokenize_and_crop(sentence):
    tokens = tokenizer.tokenize(sentence)
    tokens = tokens[:max_input_len-2]
    return tokens

def load_data():
    text = data.Field(
        batch_first=True,
        use_vocab=False,
        tokenize=tokenize_and_crop,
        preprocessing=tokenizer.convert_tokens_to_ids,
        init_token=init_token_id,
        pad_token=pad_token_id,
        unk_token=unk_token_id
    )
    label = data.LabelField(dtype=torch.float)
    # 对于自己的数据集可以修改IMDB，对外调用只要
    train_data,test_data = datasets.IMDB.splits(text,label)
    print(train_data)
    train_data,valid_data = train_data.split(random_state=random.seed(SEED))
    print(f'train examples counts:{len(train_data)}')
    print(f'test examples counts:{len(test_data)}')
    print(f'valid examples counts:{len(valid_data)}')

    label.build_vocab(train_data)

    train_iter,valid_iter,test_iter = data.BucketIterator.splits(
        (train_data,valid_data,test_data),
        batch_size=BATCH_SIZE,
        device=device
        )
    return train_iter,valid_iter,test_iter

res = load_data()

<torchtext.legacy.datasets.imdb.IMDB object at 0x7f912f81b588>
train examples counts:17500
test examples counts:25000
valid examples counts:7500


TypeError: 'int' object is not subscriptable

In [38]:
# 查看loaddata内部实现
text = data.Field(
    batch_first=True,
    use_vocab=False,
    tokenize=tokenize_and_crop,
    preprocessing=tokenizer.convert_tokens_to_ids,
    init_token=init_token_id,
    pad_token=pad_token_id,
    unk_token=unk_token_id
)
label = data.LabelField(dtype=torch.float)
# IMDB
train_data,test_data = datasets.IMDB.splits(text_field=text,label_field=label)
print(train_data)
train_data,valid_data = train_data.split(random_state=random.seed(SEED))
print(f'train examples counts:{len(train_data)}')
print(f'test examples counts:{len(test_data)}')
print(f'valid examples counts:{len(valid_data)}')

train examples counts:17500
test examples counts:25000
valid examples counts:7500


In [39]:
train_data.__dict__

{'examples': [<torchtext.legacy.data.example.Example at 0x7f912f771eb8>,
  <torchtext.legacy.data.example.Example at 0x7f912f7d67b8>,
  <torchtext.legacy.data.example.Example at 0x7f912f1e3f98>,
  <torchtext.legacy.data.example.Example at 0x7f912ef0bd68>,
  <torchtext.legacy.data.example.Example at 0x7f9140438e48>,
  <torchtext.legacy.data.example.Example at 0x7f912ec68080>,
  <torchtext.legacy.data.example.Example at 0x7f9140c13da0>,
  <torchtext.legacy.data.example.Example at 0x7f912f75c898>,
  <torchtext.legacy.data.example.Example at 0x7f912ece44a8>,
  <torchtext.legacy.data.example.Example at 0x7f91402b0908>,
  <torchtext.legacy.data.example.Example at 0x7f912f86ae80>,
  <torchtext.legacy.data.example.Example at 0x7f912ecc9d68>,
  <torchtext.legacy.data.example.Example at 0x7f912ee1cd30>,
  <torchtext.legacy.data.example.Example at 0x7f914044db00>,
  <torchtext.legacy.data.example.Example at 0x7f9140aa1358>,
  <torchtext.legacy.data.example.Example at 0x7f91401e8e10>,
  <torchtext

In [35]:
train_iter,valid_iter,test_iter = data.BucketIterator.splits(
    (train_data,valid_data,test_data),
    batch_size=BATCH_SIZE,
    device=device
    )
train_iter

<torchtext.legacy.data.iterator.BucketIterator at 0x7f912f86a7f0>

In [36]:
label.build_vocab(train_data)
label.vocab

<torchtext.legacy.vocab.Vocab at 0x7f912f86a860>

In [40]:
count = datasets.IMDB('')#.data/imdb/aclImdb/test/neg

TypeError: __init__() missing 2 required positional arguments: 'text_field' and 'label_field'