In [32]:
import torch
from datasets import load_dataset

class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        self.dataset = load_dataset(path='data_dir/stanfordnlp_imdb/', split=split)
        
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, i):
        text = self.dataset[i]['text']
        label = self.dataset[i]['label']

        return text, label

dataset = Dataset('train')

print(len(dataset))

dataset[0]

    

25000


('I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, e

In [33]:
from transformers import BertTokenizer

token = BertTokenizer.from_pretrained('model_dir/bert-large-uncased/')

token




BertTokenizer(name_or_path='model_dir/bert-large-uncased/', vocab_size=30522, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [34]:
def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]

    # encoding
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                    truncation=True,
                                    padding = 'max_length',
                                    max_length = 512,
                                    return_tensors='pt',
                                    return_length=True)

    # input_ids numbers after encoding
    # atention_mask: padding positions are 0, the others are 1
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)

    return input_ids, attention_mask, token_type_ids, labels

loader = torch.utils.data.DataLoader(dataset = dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    break

print(len(loader))
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

    

1562


(torch.Size([16, 512]),
 torch.Size([16, 512]),
 torch.Size([16, 512]),
 tensor([1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1]))

In [49]:
from transformers import BertModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

pretrained = BertModel.from_pretrained('model_dir/bert-large-uncased/').to(device)

# num = 0
for param in pretrained.parameters():
    param.requires_grad_(True)

# test model
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
token_type_ids = token_type_ids.to(device)

out = pretrained(input_ids=input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids)

out.last_hidden_state.shape

391


torch.Size([16, 512, 1024])

In [50]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(1024, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # with torch.no_grad():
        out = pretrained(input_ids=input_ids,
                         attention_mask=attention_mask,
                         token_type_ids=token_type_ids)
        out = self.fc(out.last_hidden_state[:, 0])
        # out = out.softmax(dim=1)
        return out

model = Model().to(device)

model(input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids).shape


torch.Size([16, 2])

In [60]:
from transformers import AdamW
import time
print(device)

optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()

start = time.time()
model.train()
# for param in model.parameters():
#     print(param.requires_grad)
iter_start = start
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    labels = labels.to(device)

    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids)
    # print(out.grad_fn)
    # print(input_ids.requires_grad)
    # break
    loss = criterion(out, labels)
    # print(loss.shape)
    # print(loss)
    # break
    # loss = loss.mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 100 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item()/ len(labels)
        # print(len(labels))
        iter_time = time.time() - iter_start
        print(iter_time, i, loss.item(), accuracy)
        iter_start = time.time()
        
end = time.time()
elapsed_time = end - start
print(f"the Elapsed time: {elapsed_time}")

cuda
1.0817129611968994 0 0.49904462695121765 0.75
104.07298970222473 100 0.34007152915000916 0.8125
104.49229693412781 200 0.4733685553073883 0.75
104.48357939720154 300 0.6417906284332275 0.6875
104.49652194976807 400 0.5343965888023376 0.8125
104.48903036117554 500 0.30215781927108765 0.875
104.39909744262695 600 0.5261951088905334 0.875
104.40833878517151 700 0.42340588569641113 0.8125
104.4056625366211 800 0.32407280802726746 0.8125
104.47412610054016 900 0.32360363006591797 0.8125
104.47718501091003 1000 0.5297964215278625 0.8125
104.47092652320862 1100 0.49231386184692383 0.8125
104.42364931106567 1200 0.2538117468357086 1.0
104.45667576789856 1300 0.9778323173522949 0.5625
104.47153663635254 1400 0.3411208987236023 0.8125
104.51145648956299 1500 0.29280781745910645 0.8125
the Elapsed time: 1630.9658122062683


In [80]:
torch.save(model.state_dict(), 'model_dir/bert_large_uncased_imdb_bs16.pth')
torch.save(model, 'model_dir/bert_large_uncased_imdb_bs16')
pretrained.save_pretrained('model_dir/bert_large_uncased_imdb_pretrained')
token.save_pretrained('model_dir/bert_large_uncased_imdb_pretrained')
# torch.save(pretrained.state_dict(), 'model_dir/bert_large_uncased_imdb_pretrained.pth')
# torch.save(pretrained, 'model_dir/bert_large_uncased_imdb_pretrained')
print("finished")

finished


In [85]:
# new_model = Model().to(device)
new_model = model
print("finish")

finish


In [86]:
def test():
    # model = torch.load('model_dir/bert_large_uncased_imdb_bs16').to(device)
    # model = BertModel.from_pretrained('model_dir/bert-large-uncased/').to(device)
    new_model.eval()
    correct = 0
    total = 0
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                            batch_size=32,
                                            collate_fn=collate_fn,
                                            shuffle=True,
                                            drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        if i == 15:
            break
        print(i)
        with torch.no_grad():
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.to(device)
            labels = labels.to(device)
            out = new_model(input_ids=input_ids,
                        attention_mask = attention_mask,
                        token_type_ids = token_type_ids)
            out = out.argmax(dim=1)
            correct += (out == labels).sum().item()
            total += len(labels)
            # print(out)
            # print(labels)
            # print("==========================")
    print(correct / total)

test()

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
0.8708333333333333


In [87]:
quit()