In [None]:
import torch
from datasets import load_dataset

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over the local IMDB copy; each item is a ``(text, label)`` pair."""

    def __init__(self, split):
        # `split` is forwarded verbatim to `load_dataset` (this notebook uses 'train' and 'test').
        self.dataset = load_dataset(path='data_dir/stanfordnlp_imdb/', split=split)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        # Fetch the row once and pull both fields from it.
        record = self.dataset[i]
        return record['text'], record['label']

# Training split, wrapped so every item is a (text, label) pair.
dataset = Dataset('train')

num_train = len(dataset)
print(num_train)

# Peek at the first example (shown as the cell's output).
dataset[0]

    

In [None]:
from transformers import BertTokenizer

# Tokenizer used by collate_fn, loaded from the local bert-large-uncased copy.
tokenizer_dir = 'model_dir/bert-large-uncased/'
token = BertTokenizer.from_pretrained(tokenizer_dir)

token


In [None]:
def collate_fn(data):
    """Turn a batch of ``(text, label)`` samples into padded BERT input tensors.

    Args:
        data: the samples the DataLoader drew for one batch; element ``[0]`` of
            each sample is the review text and element ``[1]`` its label.

    Returns:
        ``(input_ids, attention_mask, token_type_ids, labels)``. The first three
        come from the global ``token`` tokenizer, truncated and padded to exactly
        512 tokens. ``labels`` is a ``torch.LongTensor``.
    """
    texts, targets = [], []
    for sample in data:
        texts.append(sample[0])
        targets.append(sample[1])

    encoded = token.batch_encode_plus(
        batch_text_or_text_pairs=texts,
        truncation=True,
        padding='max_length',  # fixed width: every batch is max_length tokens wide
        max_length=512,
        return_tensors='pt',
        return_length=True,
    )

    # input_ids: token ids; attention_mask: 1 for real tokens, 0 for padding.
    return (
        encoded['input_ids'],
        encoded['attention_mask'],
        encoded['token_type_ids'],
        torch.LongTensor(targets),
    )

# Shuffled training batches of 16. drop_last=True discards the final partial
# batch, so every batch has the same size.
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

# Pull a single batch to sanity-check collate_fn's output (idiomatic replacement
# for a `for ... break` loop, which also leaked a stray loop index).
input_ids, attention_mask, token_type_ids, labels = next(iter(loader))

print(len(loader))  # number of full batches per epoch
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

    

In [None]:
from transformers import BertModel
from transformers import BertTokenizer
# import os
from MyModel import Model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): Model comes from the project-local MyModel module; the exact
# meaning of its `pretrained_model` / `tokenizer` arguments is not visible here.
model = Model(pretrained_model='model_dir/bert-large-uncased/', tokenizer='model_dir/bert-large-uncased').to(device)

# Sanity-check a forward pass on the sample batch drawn in the DataLoader cell.
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
token_type_ids = token_type_ids.to(device)

# This pass only checks the output shape, so run it without autograd. Otherwise
# every layer's activations for a 16 x 512-token bert-large batch are kept for a
# backward pass that never happens, which can exhaust GPU memory.
with torch.no_grad():
    sample_out = model(input_ids=input_ids,
                       attention_mask=attention_mask,
                       token_type_ids=token_type_ids)

# Presumably (batch, num_classes), since training feeds it to CrossEntropyLoss.
sample_out.shape


In [None]:
import time

# Training hyperparameters.
# NOTE(review): 5e-4 is an order of magnitude above the 2e-5 to 5e-5 range usually
# used to fine-tune BERT, and large BERT models often diverge at such rates.
# Kept as-is; confirm it is intended.
LEARNING_RATE = 5e-4
MAX_STEP = 300   # last step index: steps 0..MAX_STEP run (MAX_STEP + 1 optimizer updates)
LOG_EVERY = 30   # report loss/accuracy on steps divisible by this

print(device)

# transformers.AdamW is deprecated in favor of torch.optim.AdamW and has been
# removed in recent transformers releases. eps=1e-6 and weight_decay=0.0 are
# transformers.AdamW's defaults; torch's own defaults are 1e-8 and 0.01. Passing
# them keeps the update rule essentially the same.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0)
criterion = torch.nn.CrossEntropyLoss()

start = time.time()
model.train()
iter_start = start
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    labels = labels.to(device)

    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids)
    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Log before the step-limit check. Previously the break came first, so the
    # final step (300, itself a multiple of 30) was never reported.
    if i % LOG_EVERY == 0:
        preds = out.argmax(dim=1)
        # Accuracy on this single training batch only, so it is noisy.
        accuracy = (preds == labels).sum().item() / len(labels)
        iter_time = time.time() - iter_start
        print(iter_time, i, loss.item(), accuracy)
        iter_start = time.time()

    if i == MAX_STEP:
        break

end = time.time()
elapsed_time = end - start
print(f"the Elapsed time: {elapsed_time}")

In [None]:
# Persist the fine-tuned model through the project-defined `Model.save` (MyModel).
# NOTE(review): 'hello1' looks like a scratch name; consider a descriptive path.
# `torch.save(model.state_dict(), path)` is the portable alternative if
# Model.save is unavailable where the weights are reloaded.
model.save('model_dir/hello1')
print("finished")

In [None]:
# Inspect selected parameters of the fine-tuned backbone. The original nested a
# second identical loop inside the first, printing every match once per
# parameter of the model (hundreds of repeats). A single pass is enough.
# NOTE(review): bert-large-uncased presumably has 24 encoder layers (0-23), so
# "encoder.layer.11" is a middle layer, not the last one. Confirm which layer
# was meant.
for name, param in model.pretrained.named_parameters():
    if "encoder.layer.11" in name or "pooler" in name:
        print(f"Parameter name: {name}")
        print(f"Parameter value: {param}\n")

print('--------')
# Only the first registered parameter of the whole model is shown.
for name, param in model.named_parameters():
    print(f"Parameter name: {name}")
    print(f"Parameter value: {param}\n")
    break

# `new_model` is an alias of `model` (same object, not a copy); test() evaluates it.
new_model = model
print("finish")

In [None]:
def test(max_batches=15, batch_size=32):
    """Print and return ``new_model``'s accuracy on a random subset of the IMDB test split.

    Builds a shuffled DataLoader over ``Dataset('test')`` (tokenized by
    ``collate_fn``) and scores the first ``max_batches`` batches under
    ``torch.no_grad()``. The model is switched to eval mode for the run, then
    restored to whichever train/eval mode it was in before.

    Args:
        max_batches: number of batches to score (default 15, i.e. 480 reviews at
            the default batch size). ``None`` scores every full batch of the split.
        batch_size: reviews per batch.

    Returns:
        The accuracy ``correct / total`` as a float.

    Raises:
        ValueError: if no batch was scored (e.g. ``max_batches=0``), instead of
            an opaque ZeroDivisionError.
    """
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=batch_size,
                                              collate_fn=collate_fn,
                                              # A random subset, so the score varies between runs.
                                              # NOTE(review): keep shuffling if the split is stored
                                              # grouped by label; an unshuffled prefix could be one class.
                                              shuffle=True,
                                              drop_last=True)

    was_training = new_model.training
    new_model.eval()  # eval mode for scoring (e.g. dropout off)
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
                if max_batches is not None and i >= max_batches:
                    break
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
                token_type_ids = token_type_ids.to(device)
                labels = labels.to(device)
                out = new_model(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
                preds = out.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += len(labels)
    finally:
        # Don't leave the shared model stuck in eval mode if training resumes later.
        new_model.train(was_training)

    if total == 0:
        raise ValueError("no test batches were evaluated; check max_batches and batch_size")
    accuracy = correct / total
    print(accuracy)
    return accuracy


# The trailing ';' stops the returned accuracy from being displayed a second time.
test();

In [None]:
import torch 
# Release unoccupied blocks held by PyTorch's CUDA caching allocator so the
# memory shows as free to other GPU processes / nvidia-smi. Memory still
# referenced by live objects (`model`/`new_model`, the optimizer's state, the
# last batch tensors) is NOT freed by this call; `del` those first if the goal
# is to reclaim it.
torch.cuda.empty_cache()