In [None]:
import torch
from datasets import load_dataset

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(text, label)`` pairs from one split.

    The data is read with ``load_dataset`` from ``data_dir/stanfordnlp_imdb/``.

    Args:
        split: Split name forwarded unchanged to ``load_dataset`` (this file
            uses ``'train'`` and ``'test'``).
    """

    def __init__(self, split):
        self.dataset = load_dataset(path='data_dir/stanfordnlp_imdb/', split=split)

    def __len__(self):
        """Return the number of rows in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return the ``(text, label)`` pair stored at row ``i``."""
        # Fetch the row a single time; indexing the underlying dataset once per
        # field would repeat the whole row lookup for every sample.
        row = self.dataset[i]
        return row['text'], row['label']



In [None]:
from transformers import BertModel
from transformers import BertTokenizer
import os

# Prefer the second GPU (index 1) when one exists. On a machine with a single
# visible GPU 'cuda:1' is not a valid device, so fall back to index 0 there,
# and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    _gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f'cuda:{_gpu_index}')
else:
    device = torch.device('cpu')

class Model(torch.nn.Module):
    """BERT encoder followed by a linear 2-class head on the first token.

    Args:
        pretrained_model: Name or path handed to ``BertModel.from_pretrained``.
        tokenizer: Name or path handed to ``BertTokenizer.from_pretrained``.
    """

    def __init__(self, pretrained_model, tokenizer):
        super().__init__()
        encoder = BertModel.from_pretrained(pretrained_model)
        # Size the head from the loaded config instead of hard-coding 1024 so
        # encoders with another hidden width work too. ``fc`` is still assigned
        # before ``pretrained`` to keep the original registration order.
        self.fc = torch.nn.Linear(encoder.config.hidden_size, 2)
        self.pretrained = encoder
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
        # The encoder weights are trained together with the head.
        for param in self.pretrained.parameters():
            param.requires_grad_(True)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return the raw 2-class scores for every sequence in the batch.

        The three arguments are forwarded to the encoder under the same
        keyword names. No softmax is applied to the result.
        """
        encoded = self.pretrained(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  token_type_ids=token_type_ids)
        # Only position 0 along the sequence axis feeds the classifier.
        return self.fc(encoded.last_hidden_state[:, 0])

    def save(self, save_directory):
        """Write the weights, encoder and tokenizer below ``save_directory``.

        Creates ``model_weights.pth`` (full state dict), ``bert_model/`` and
        ``bert_tokenizer/``; the directory is created when missing.
        """
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(save_directory, exist_ok=True)

        torch.save(self.state_dict(), os.path.join(save_directory, 'model_weights.pth'))
        self.pretrained.save_pretrained(os.path.join(save_directory, 'bert_model'))
        self.tokenizer.save_pretrained(os.path.join(save_directory, 'bert_tokenizer'))

        print(f'model save in: {save_directory}')

    @classmethod
    def load(cls, load_directory):
        """Rebuild a model from a directory previously written by ``save``.

        Returns:
            A new instance with the stored state dict applied. The weights are
            loaded onto the CPU; callers move the model to a device afterwards.
        """
        pretrained_model_path = os.path.join(load_directory, 'bert_model')
        tokenizer_path = os.path.join(load_directory, 'bert_tokenizer')

        print(pretrained_model_path)
        model = cls(pretrained_model=pretrained_model_path, tokenizer=tokenizer_path)

        model_weight_path = os.path.join(load_directory, 'model_weights.pth')
        # Map every tensor to the CPU so a checkpoint written from a GPU process
        # still loads on a machine where that device does not exist.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        state = torch.load(model_weight_path, map_location='cpu')
        model.load_state_dict(state)

        print(f'Model loaded from: {load_directory}')

        return model
        
        

# Restore the checkpoint stored under model_dir/hello and place it on `device`.
# Disabled alternative kept from the notebook:
# model = Model(pretrained_model='model_dir/bert-large-uncased/').to(device)
model = Model.load(load_directory='model_dir/hello').to(device)
print("finished")



In [None]:
import torch
from datasets import load_dataset

# Module-level handle to the tokenizer held by `model`; collate_fn reads it as a global.
token = model.tokenizer
def collate_fn(data):
    """Turn a list of ``(text, label)`` samples into one tokenized batch.

    Uses the module-level ``token`` tokenizer.

    Args:
        data: Sequence of samples; item 0 of each is the sentence and item 1
            is its label.

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids, labels)``. The
        first three come straight from the tokenizer output; ``labels`` is a
        ``torch.LongTensor`` built from the sample labels.
    """
    sents = [sample[0] for sample in data]
    labels = [sample[1] for sample in data]

    # Call the tokenizer directly: batch_encode_plus is the deprecated spelling
    # of the same operation. Every sequence is truncated or padded to exactly
    # 512 tokens, whatever the longest sentence in the batch is.
    encoded = token(sents,
                    truncation=True,
                    padding='max_length',
                    max_length=512,
                    return_tensors='pt')

    # input_ids: the token ids produced by the encoding
    # attention_mask: padding positions are 0, all others are 1
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    token_type_ids = encoded['token_type_ids']
    labels = torch.LongTensor(labels)

    return input_ids, attention_mask, token_type_ids, labels

# Training loader: shuffled batches of 16, dropping the final incomplete batch.
dataset = Dataset('train')
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

# Pull a single batch to inspect what the loader produces.
input_ids, attention_mask, token_type_ids, labels = next(iter(loader))

print(len(loader))
# A bare expression in the middle of a cell is silently discarded, so the
# shapes and labels are printed explicitly.
print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels)

print("finished")



In [None]:
# Move the sample batch onto `device`, then run one forward pass; the shape of
# the result is the cell's displayed value.
input_ids, attention_mask, token_type_ids = (
    tensor.to(device) for tensor in (input_ids, attention_mask, token_type_ids)
)
model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
).shape

In [None]:
def test():
    # model = torch.load('model_dir/bert_large_uncased_imdb_bs16').to(device)
    # model = BertModel.from_pretrained('model_dir/bert-large-uncased/').to(device)
    model.eval()
    correct = 0
    total = 0
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                            batch_size=32,
                                            collate_fn=collate_fn,
                                            shuffle=True,
                                            drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        if i == 15:
            break
        print(i)
        with torch.no_grad():
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.to(device)
            labels = labels.to(device)
            out = model(input_ids=input_ids,
                        attention_mask = attention_mask,
                        token_type_ids = token_type_ids)
            out = out.argmax(dim=1)
            correct += (out == labels).sum().item()
            total += len(labels)
            print(correct, total)
            # print(out)
            # print(labels)
            # print("==========================")
    print(correct / total)

# Run the evaluation defined above and print its accuracy.
test()