In [1]:
from datasets import load_dataset, load_from_disk
from transformers import BertModel, AutoTokenizer, DataCollatorWithPadding, get_scheduler, AdamW
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import time, datetime
from datasets import load_metric

In [2]:
imdb = load_from_disk("../data/imdb")
imdb

DatasetDict({
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 24000
    })
    attack_eval_truncated: Dataset({
        features: ['text', 'label'],
        num_rows: 100
    })
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    dev: Dataset({
        features: ['text', 'label'],
        num_rows: 1000
    })
})

In [3]:
max_sequence_length = 2
batch_size = 8
learning_rate=2e-05
num_epochs=3
num_log_steps = 1
output_dir = "../output/"
model_dir = "../models/"
checkpoint = "bert-base-cased"

In [4]:
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [5]:
def tokenize_function(example):
    return tokenizer(example["text"],  truncation=True, padding="max_length", max_length=max_sequence_length)


tokenized_datasets = imdb.map(tokenize_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

  0%|          | 0/25 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/24 [00:00<?, ?ba/s]

In [6]:
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
tokenized_datasets["train"].column_names

['labels', 'input_ids', 'token_type_ids', 'attention_mask']

In [7]:
train_dataloader = DataLoader(
    tokenized_datasets["train"].select(range(64)), shuffle=True, batch_size=batch_size, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    tokenized_datasets["dev"].select(range(64)), batch_size=batch_size, collate_fn=data_collator
)

test_dataloader = DataLoader(
    tokenized_datasets["test"].select(range(64)), batch_size=batch_size, collate_fn=data_collator
)

In [8]:
for batch in train_dataloader:
    break
{k: v.shape for k, v in batch.items()}

{'labels': torch.Size([8]),
 'input_ids': torch.Size([8, 2]),
 'token_type_ids': torch.Size([8, 2]),
 'attention_mask': torch.Size([8, 2])}

In [9]:
class BertClassifier(nn.Module):
    def __init__(self, checkpoint, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.model = BertModel.from_pretrained(checkpoint, num_labels=self.n_classes)
        self.dropout = nn.Dropout(0.1) 
        self.hidden_dim = self.model.embeddings.word_embeddings.embedding_dim
        self.classifier_layer = nn.Linear(self.hidden_dim, self.n_classes)
        
    def forward(self, input_ids=None, attention_mask=None,labels=None, token_type_ids=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, )
        cls_rep = outputs['last_hidden_state'][:,0,:]
        cls_rep = self.dropout(cls_rep)
        logits =self.classifier_layer(cls_rep)
        
        loss = None
        if labels is not None:
            # print(logits.shape, labels.shape)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,attentions=outputs.attentions)

In [10]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model=BertClassifier(checkpoint=checkpoint,n_classes=2).to(device)
device

device(type='cpu')

In [14]:
model

BertClassifier(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [13]:
checkpoint

'bert-base-cased'

In [11]:
optimizer = AdamW(model.parameters(), lr=learning_rate)
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
print(num_training_steps)

24


In [15]:
progress_bar = tqdm(range(num_training_steps))
training_start_time = time.time()

for epoch in range(num_epochs):
    print("")
    print(f'======== Epoch {epoch+1} / {num_epochs} ========')
    print('Training...')

    # Measure how long the training epoch takes.
    epoch_start_time = time.time()

    # Reset the total loss for this epoch.
    total_train_loss = 0

    # Put the model into training mode. Don't be mislead--the call to 
    # `train` just changes the *mode*, it doesn't *perform* the training.
    # `dropout` and `batchnorm` layers behave differently during training
    # vs. test (source: https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch)
    model.train()
 
    for n_iter, batch in enumerate(train_dataloader):
        if n_iter % 3 == 0 and not n_iter == 0:
            # Calculate elapsed time in minutes.
            elapsed = datetime.timedelta(seconds=int(time.time() - epoch_start_time))
            # Report progress.
            print(f'  Batch {n_iter:>5,} of {len(train_dataloader):>5,}. Elapsed: {elapsed}.')
        
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        
        loss = outputs.loss
        total_train_loss+=loss
        
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
    
    # Measure how long this epoch took.
    print("Running Eval...")
    accuracy = load_metric("accuracy")
    f1 = load_metric("f1")
    total_eval_loss=0
    
    model.eval()
    for batch in eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        eval_loss = outputs.loss
        total_eval_loss += eval_loss
        logits = outputs.logits
        labels = batch['labels']
        predictions = torch.argmax(logits, dim=-1)
        f1.add_batch(predictions=predictions, references=labels)
        accuracy.add_batch(predictions=predictions, references=labels)
   
    train_loss = total_train_loss/len(train_dataloader)
    eval_loss = total_eval_loss/len(eval_dataloader)
    epoch_time = datetime.timedelta(seconds=int(time.time() - epoch_start_time))
    print(f'Time taken:{epoch_time}')     
    print(f"Train Loss: {train_loss:5.3f}. Eval Loss: {eval_loss:5.3f}. Eval Accuracy:{accuracy.compute()}. Eval F1:{f1.compute()}.")
    
print("")
print("Training complete!")
train_time = datetime.timedelta(seconds=int(time.time()-training_start_time))
print(f"Total training took {train_time}")

    

  0%|          | 0/24 [00:00<?, ?it/s]


Training...
  Batch     3 of     8. Elapsed: 0:00:01.
  Batch     6 of     8. Elapsed: 0:00:03.
Running Eval...
Time taken:0:00:08
Train Loss: 0.004. Eval Loss: 3.519. Eval Accuracy:{'accuracy': 0.5}. Eval F1:{'f1': 0.0}.

Training...
  Batch     3 of     8. Elapsed: 0:00:01.
  Batch     6 of     8. Elapsed: 0:00:03.
Running Eval...
Time taken:0:00:07
Train Loss: 0.003. Eval Loss: 3.569. Eval Accuracy:{'accuracy': 0.5}. Eval F1:{'f1': 0.0}.

Training...
  Batch     3 of     8. Elapsed: 0:00:01.
  Batch     6 of     8. Elapsed: 0:00:03.
Running Eval...
Time taken:0:00:08
Train Loss: 0.002. Eval Loss: 3.569. Eval Accuracy:{'accuracy': 0.5}. Eval F1:{'f1': 0.0}.

Training complete!
Total training took 0:00:24


In [16]:
torch.save(model, model_dir+"bert-base-cased-CLS-finetuned-imdb")

In [18]:
finetuned_model = torch.load(model_dir+"bert-base-cased-CLS-finetuned-imdb")
finetuned_model

BertClassifier(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [19]:
test_accuracy = load_metric("accuracy")
test_f1 = load_metric("f1")
finetuned_model.eval()
test_progress_bar = tqdm(range(len(test_dataloader)))
for batch in test_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = finetuned_model(**batch)

    logits = outputs.logits
    labels = batch['labels']
    predictions = torch.argmax(logits, dim=-1)
    test_f1.add_batch(predictions=predictions, references=labels)
    test_accuracy.add_batch(predictions=predictions, references=labels)
    test_progress_bar.update(1)

print(f"Test Accuracy:{test_accuracy.compute()}. Test F1:{test_f1.compute()}.")

  0%|          | 0/8 [00:00<?, ?it/s]

Test Accuracy:{'accuracy': 0.515625}. Test F1:{'f1': 0.0}.
