# Using BERT-base to train compact BERT models with knowledge distillation

Resources
- BERT-base-uncased: https://huggingface.co/google-bert/bert-base-uncased
- TensorFlow sources for the compact BERT model family: https://github.com/google-research/bert
- PyTorch ports of the compact BERT models, uploaded to the Hugging Face Hub by user prajjwal1: https://huggingface.co/prajjwal1
- Dataset: https://huggingface.co/datasets/SetFit/imdb

Tutorials
- PyTorch knowledge distillation: https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html
- Hugging Face, fine-tune a pretrained model: https://huggingface.co/docs/transformers/en/training

In [1]:
import torch
import torch_directml
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler
from datasets import load_dataset
import evaluate
from tqdm.auto import tqdm

# Select the default DirectML device. Functions defined further down use this
# value as the default for their `device` argument.
device = torch_directml.device(torch_directml.default_device())
device  # display the selected device in the notebook

  from .autonotebook import tqdm as notebook_tqdm


device(type='privateuseone', index=0)

## Load pre-trained + fine-tuned BERT

In [2]:
# Rebuild the 2-label BERT-base classifier and overwrite its weights with the
# checkpoint fine-tuned on IMDB.
# - map_location='cpu' deserializes the checkpoint on the CPU regardless of
#   the device it was saved from; the freshly built model is still on the CPU
#   here, and it is moved to `device` afterwards.
# - weights_only=True restricts unpickling to tensors and plain containers
#   instead of executing arbitrary pickled objects.
bert_base_teacher = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
teacher_state_dict = torch.load('bert-base-uncased_IMDB.pt', map_location='cpu', weights_only=True)
bert_base_teacher.load_state_dict(teacher_state_dict)
bert_base_teacher.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

## Load IMDB dataset
- This is the same dataset that was used to fine-tune the BERT-base teacher model
- Contains movie reviews labeled positive/negative

In [3]:
# Load the SetFit/imdb dataset: "train" and "test" splits of labeled movie
# reviews with "text", "label" and "label_text" columns.
dataset = load_dataset("SetFit/imdb")
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 25000
    })
})

### Preprocessing
- Tokenize the raw review text with the BERT tokenizer
- Reformat the dataset into the structure expected by BERT models
- Select a subset of the data for faster training

In [4]:
def build_data_loaders(dataset, subset_size, batch_size, tokenizer_name='bert-base-uncased', seed=42):
    """Tokenize a random subset of each split and wrap it in DataLoaders.

    The subset is drawn *before* tokenization, so only the examples that are
    actually used get tokenized and padded. `Dataset.shuffle(seed=...)`
    permutes row indices independently of the row contents, so the selected
    examples are the same as when subsetting after tokenization.

    Args:
        dataset: DatasetDict with "train" and "test" splits, each having
            "text", "label" and "label_text" columns.
        subset_size: number of examples to keep from each split; clamped to
            the size of the split.
        batch_size: batch size of both DataLoaders.
        tokenizer_name: Hugging Face tokenizer used for every split.
        seed: shuffle seed used to pick the subsets.

    Returns:
        (train_dataloader, test_dataloader). The train loader reshuffles every
        epoch; the test loader keeps a fixed order. Batches are dicts of
        tensors containing the tokenizer outputs plus "labels".
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def _tokenize(examples):
        # Pad every review to the tokenizer's maximum length so all batches
        # share one shape; longer reviews are truncated.
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    def _prepare_split(split):
        """Shuffle, subset, tokenize and reformat one split for BERT models."""
        data = dataset[split]
        n = min(subset_size, len(data))
        subset = data.shuffle(seed=seed).select(range(n))
        subset = subset.map(_tokenize, batched=True)
        # BERT models take integer "labels" and no raw-text columns.
        subset = subset.remove_columns(["text", "label_text"])
        subset = subset.rename_column("label", "labels")
        subset.set_format("torch")
        return subset

    train_subset = _prepare_split("train")
    test_subset = _prepare_split("test")

    train_dataloader = torch.utils.data.DataLoader(train_subset, shuffle=True, batch_size=batch_size)
    test_dataloader = torch.utils.data.DataLoader(test_subset, batch_size=batch_size)

    return (train_dataloader, test_dataloader)

## Training and testing functions

### Standard Training
- Train the model using only the loss between its predictions and the true labels

In [5]:
def standard_train(model, dataloader, num_epochs, device=device, lr=5e-5):
    """Fine-tune `model` on the true labels only (no teacher).

    Each step moves the batch to `device`, runs the model on it and
    back-propagates the loss the model returns. The batches carry a "labels"
    entry, so the model computes that loss itself. The learning rate follows
    a "cosine_with_restarts" schedule without warm-up over the whole run.

    Args:
        model: sequence-classification model, trained in place; expected to
            already be on `device`.
        dataloader: iterable of dicts of tensors that include "labels".
        num_epochs: number of full passes over `dataloader`.
        device: device each batch is moved to.
        lr: AdamW learning rate (defaults to the previously hard-coded 5e-5).

    Returns:
        None. The model is left in training mode.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # The scheduler is stepped once per optimizer step across all epochs.
    num_training_steps = num_epochs * len(dataloader)
    lr_scheduler = get_scheduler(
        name="cosine_with_restarts", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )

    model.train()
    # The context manager closes the progress bar even if training raises.
    with tqdm(total=num_training_steps) as progress_bar:
        for epoch in range(num_epochs):
            for batch in dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                loss = model(**batch).loss
                loss.backward()

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
            print(f'Epoch {epoch+1}/{num_epochs} complete.')
        

### Knowledge distillation training
- Train the student on the loss between its predictions and the true labels plus the loss between the student's and the teacher's predictions
- Divide the logits by a temperature T to control the "smoothness" of the resulting probability distributions (a larger T gives smoother distributions); the distillation loss is multiplied by T² to keep its scale comparable as T changes
- Weight the two losses to balance how much each one affects training

In [6]:
def KD_train(student, teacher, dataloader, num_epochs, T, soft_target_loss_weight, target_loss_weight, device=device, lr=5e-5):
    """Train `student` with knowledge distillation from a frozen `teacher`.

    Per batch the loss is

        soft_target_loss_weight * T**2 * KL(p_teacher || p_student)
        + target_loss_weight * label_loss

    Terms:
        - p = softmax(logits / T) for each model.
        - The KL term is summed over classes and averaged over the batch.
        - label_loss is the loss the student returns for the batch's true
          "labels".
        - The T**2 factor is the usual Hinton et al. scaling, which keeps the
          distillation gradients comparable in size as T changes.

    Args:
        student: model being trained in place; left in training mode.
        teacher: model supplying soft targets; put in eval mode, never updated.
        dataloader: iterable of dicts of tensors that include "labels".
        num_epochs: number of full passes over `dataloader`.
        T: softmax temperature (> 0); larger values give smoother targets.
        soft_target_loss_weight: weight of the distillation (KL) term.
        target_loss_weight: weight of the true-label term.
        device: device each batch is moved to.
        lr: AdamW learning rate for the student (defaults to the previously
            hard-coded 5e-5).

    Returns:
        None.
    """
    F = torch.nn.functional
    optimizer = torch.optim.AdamW(student.parameters(), lr=lr)

    # The scheduler is stepped once per optimizer step across all epochs.
    num_training_steps = num_epochs * len(dataloader)
    lr_scheduler = get_scheduler(
        name="cosine_with_restarts", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )

    student.train()
    teacher.eval()
    # The context manager closes the progress bar even if training raises.
    with tqdm(total=num_training_steps) as progress_bar:
        for epoch in range(num_epochs):
            for batch in dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}

                # Teacher forward pass: no autograd graph, and no labels, so it
                # does not compute a loss that would only be thrown away.
                with torch.no_grad():
                    teacher_inputs = {k: v for k, v in batch.items() if k != "labels"}
                    teacher_logits = teacher(**teacher_inputs).logits
                student_outputs = student(**batch)

                # Temperature-softened log-probabilities. Staying in log space
                # avoids softmax(...).log(), which gives -inf (and then a NaN
                # loss through 0 * -inf) when a teacher probability underflows.
                teacher_log_probs = F.log_softmax(teacher_logits / T, dim=-1)
                student_log_probs = F.log_softmax(student_outputs.logits / T, dim=-1)

                # KL(teacher || student): summed over classes, averaged over the
                # batch, scaled by T^2.
                batch_size = student_log_probs.size(0)
                soft_targets_loss = (
                    torch.sum(teacher_log_probs.exp() * (teacher_log_probs - student_log_probs))
                    / batch_size
                    * (T ** 2)
                )

                # Loss against the true labels, computed by the student itself.
                label_loss = student_outputs.loss

                loss = soft_target_loss_weight * soft_targets_loss + target_loss_weight * label_loss

                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
            print(f'Epoch {epoch+1}/{num_epochs} complete.')

### Evaluation

In [7]:
def test(model, dataloader, device=device):
    """Measure the classification accuracy of `model` on `dataloader`.

    The model is switched to eval mode (and left there). Every batch is run
    without gradient tracking, and the highest-scoring logit is taken as the
    predicted class.

    Args:
        model: sequence-classification model whose outputs expose `.logits`.
        dataloader: iterable of dicts of tensors that include "labels".
        device: device each batch is moved to.

    Returns:
        The dict computed by the `evaluate` "accuracy" metric.
    """
    accuracy = evaluate.load("accuracy")
    model.eval()
    with torch.no_grad():
        for raw_batch in dataloader:
            inputs = {name: tensor.to(device) for name, tensor in raw_batch.items()}
            preds = model(**inputs).logits.argmax(dim=-1)
            accuracy.add_batch(predictions=preds, references=inputs["labels"])
    return accuracy.compute()

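# Standard training vs. knowledge distillation for fine-tuning a pretrained model

Each compact model below is loaded twice. One copy is fine-tuned with standard training (the baseline); the other is trained with knowledge distillation from the BERT-base teacher. Both copies are evaluated on the same test subset.

Note: the recorded progress bars show different numbers of steps per epoch across runs (10, 32 and 79). Not every recorded result was therefore produced with the `build_data_loaders(dataset, 1000, 32)` loaders from the next cell. Re-run all cells with the same loaders before comparing the accuracies.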

In [8]:
# 1000 shuffled examples per split in batches of 32. These loaders are shared
# by every distillation and baseline run below.
train_loader, test_loader = build_data_loaders(dataset, 1000, 32)

## BERT-tiny

In [9]:
# Two copies of the pretrained BERT-tiny encoder, each with its own newly
# initialized classification head: `baseline` gets standard training and
# `bert_tiny` is distilled from the BERT-base teacher.
baseline = AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-tiny')
bert_tiny = AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-tiny')

baseline.to(device)
bert_tiny.to(device)

  return self.fget.__get__(instance, owner)()
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, e

In [10]:
# Distill the teacher into BERT-tiny for 10 epochs:
# loss = 0.10 * soft-target loss at temperature T=4 + 0.90 * true-label loss.
KD_train(
    bert_tiny, 
    bert_base_teacher, 
    train_loader,
    num_epochs=10,
    T=4,
    soft_target_loss_weight=0.10,
    target_loss_weight=0.90
)

 10%|█         | 10/100 [00:40<05:03,  3.37s/it]

Epoch 1/10 complete.


 20%|██        | 20/100 [01:13<04:25,  3.32s/it]

Epoch 2/10 complete.


 30%|███       | 30/100 [01:46<03:51,  3.30s/it]

Epoch 3/10 complete.


 40%|████      | 40/100 [02:19<03:17,  3.29s/it]

Epoch 4/10 complete.


 50%|█████     | 50/100 [02:52<02:43,  3.28s/it]

Epoch 5/10 complete.


 60%|██████    | 60/100 [03:25<02:12,  3.31s/it]

Epoch 6/10 complete.


 70%|███████   | 70/100 [03:58<01:38,  3.27s/it]

Epoch 7/10 complete.


 80%|████████  | 80/100 [04:31<01:06,  3.30s/it]

Epoch 8/10 complete.


 90%|█████████ | 90/100 [05:04<00:32,  3.29s/it]

Epoch 9/10 complete.


100%|██████████| 100/100 [05:37<00:00,  3.37s/it]

Epoch 10/10 complete.





In [11]:
# Test-subset accuracy of the distilled BERT-tiny.
test(bert_tiny, test_loader)

{'accuracy': 0.62}

In [12]:
# Baseline: same architecture and number of epochs, trained on the labels only.
standard_train(
    baseline,
    train_loader,
    num_epochs=10,
)

 10%|█         | 10/100 [00:05<00:45,  2.00it/s]

Epoch 1/10 complete.


 20%|██        | 20/100 [00:10<00:39,  2.02it/s]

Epoch 2/10 complete.


 30%|███       | 30/100 [00:15<00:34,  2.03it/s]

Epoch 3/10 complete.


 40%|████      | 40/100 [00:20<00:30,  1.97it/s]

Epoch 4/10 complete.


 50%|█████     | 50/100 [00:25<00:25,  1.97it/s]

Epoch 5/10 complete.


 60%|██████    | 60/100 [00:30<00:19,  2.03it/s]

Epoch 6/10 complete.


 70%|███████   | 70/100 [00:35<00:14,  2.01it/s]

Epoch 7/10 complete.


 80%|████████  | 80/100 [00:40<00:10,  1.98it/s]

Epoch 8/10 complete.


 90%|█████████ | 90/100 [00:45<00:04,  2.01it/s]

Epoch 9/10 complete.


100%|██████████| 100/100 [00:50<00:00,  1.99it/s]

Epoch 10/10 complete.





In [13]:
# Test-subset accuracy of the standard-trained BERT-tiny baseline.
test(baseline, test_loader)

{'accuracy': 0.632}

In [14]:
# Save both BERT-tiny students' weights. Create ./models first so torch.save
# does not fail when the directory does not exist yet.
import os

os.makedirs('./models', exist_ok=True)
torch.save(baseline.state_dict(), './models/bert-tiny-baseline_IMDB.pt')
torch.save(bert_tiny.state_dict(), './models/bert-tiny_IMDB.pt')

In [15]:
# Drop the references to the BERT-tiny models before loading the next size.
del baseline
del bert_tiny

## BERT-mini

In [16]:
# Two copies of the pretrained BERT-mini encoder, each with its own newly
# initialized classification head: `baseline` gets standard training and
# `bert_mini` is distilled from the BERT-base teacher.
baseline = AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-mini')
bert_mini = AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-mini')

baseline.to(device)
bert_mini.to(device)

  return self.fget.__get__(instance, owner)()
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-mini and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-mini and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 256, padding_idx=0)
      (position_embeddings): Embedding(512, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, e

In [17]:
# Distill the teacher into BERT-mini for 10 epochs:
# loss = 0.10 * soft-target loss at temperature T=4 + 0.90 * true-label loss.
KD_train(
    bert_mini, 
    bert_base_teacher, 
    train_loader,
    num_epochs=10,
    T=4,
    soft_target_loss_weight=0.10,
    target_loss_weight=0.90
)

 10%|█         | 10/100 [00:43<06:35,  4.40s/it]

Epoch 1/10 complete.


 20%|██        | 20/100 [01:27<05:50,  4.38s/it]

Epoch 2/10 complete.


 30%|███       | 30/100 [02:10<05:00,  4.29s/it]

Epoch 3/10 complete.


 40%|████      | 40/100 [02:53<04:20,  4.33s/it]

Epoch 4/10 complete.


 50%|█████     | 50/100 [03:37<03:35,  4.31s/it]

Epoch 5/10 complete.


 60%|██████    | 60/100 [04:20<02:55,  4.38s/it]

Epoch 6/10 complete.


 70%|███████   | 70/100 [05:03<02:09,  4.33s/it]

Epoch 7/10 complete.


 80%|████████  | 80/100 [05:47<01:26,  4.34s/it]

Epoch 8/10 complete.


 90%|█████████ | 90/100 [06:30<00:43,  4.35s/it]

Epoch 9/10 complete.


100%|██████████| 100/100 [07:14<00:00,  4.34s/it]

Epoch 10/10 complete.





In [18]:
# Test-subset accuracy of the distilled BERT-mini.
test(bert_mini, test_loader)

{'accuracy': 0.759}

In [19]:
# Baseline: same architecture and number of epochs, trained on the labels only.
standard_train(
    baseline, 
    train_loader, 
    num_epochs=10
)

 10%|█         | 10/100 [00:18<02:44,  1.82s/it]

Epoch 1/10 complete.


 20%|██        | 20/100 [00:36<02:25,  1.82s/it]

Epoch 2/10 complete.


 30%|███       | 30/100 [00:54<02:03,  1.77s/it]

Epoch 3/10 complete.


 40%|████      | 40/100 [01:11<01:41,  1.69s/it]

Epoch 4/10 complete.


 50%|█████     | 50/100 [01:28<01:27,  1.75s/it]

Epoch 5/10 complete.


 60%|██████    | 60/100 [01:45<01:09,  1.73s/it]

Epoch 6/10 complete.


 70%|███████   | 70/100 [02:03<00:52,  1.75s/it]

Epoch 7/10 complete.


 80%|████████  | 80/100 [02:20<00:34,  1.74s/it]

Epoch 8/10 complete.


 90%|█████████ | 90/100 [02:38<00:17,  1.74s/it]

Epoch 9/10 complete.


100%|██████████| 100/100 [02:55<00:00,  1.75s/it]

Epoch 10/10 complete.





In [20]:
# Test-subset accuracy of the standard-trained BERT-mini baseline.
test(baseline, test_loader)

{'accuracy': 0.762}

In [21]:
# Save both BERT-mini students' weights. Create ./models first so torch.save
# does not fail when the directory does not exist yet.
import os

os.makedirs('./models', exist_ok=True)
torch.save(baseline.state_dict(), './models/bert-mini-baseline_IMDB.pt')
torch.save(bert_mini.state_dict(), './models/bert-mini_IMDB.pt')

In [22]:
# Drop the references to the BERT-mini models before loading the next size.
del baseline
del bert_mini

## BERT-small

In [9]:
# Two copies of the pretrained BERT-small encoder, each with its own newly
# initialized classification head: `baseline` gets standard training and
# `bert_small` is distilled from the BERT-base teacher.
baseline = AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-small')
bert_small = AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-small')

baseline.to(device)
bert_small.to(device)

  return self.fget.__get__(instance, owner)()
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-small and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-small and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, e

In [10]:
# Distill the teacher into BERT-small for 10 epochs:
# loss = 0.10 * soft-target loss at temperature T=4 + 0.90 * true-label loss.
KD_train(
    bert_small, 
    bert_base_teacher, 
    train_loader,
    num_epochs=10,
    T=4,
    soft_target_loss_weight=0.10,
    target_loss_weight=0.90
)

 10%|█         | 32/320 [01:11<07:03,  1.47s/it] 

Epoch 1/10 complete.


 20%|██        | 64/320 [02:10<06:20,  1.49s/it]

Epoch 2/10 complete.


 30%|███       | 96/320 [03:09<05:30,  1.48s/it]

Epoch 3/10 complete.


 40%|████      | 128/320 [04:08<04:43,  1.48s/it]

Epoch 4/10 complete.


 50%|█████     | 160/320 [05:07<03:58,  1.49s/it]

Epoch 5/10 complete.


 60%|██████    | 192/320 [06:04<03:02,  1.42s/it]

Epoch 6/10 complete.


 70%|███████   | 224/320 [07:03<02:18,  1.45s/it]

Epoch 7/10 complete.


 80%|████████  | 256/320 [08:00<01:32,  1.44s/it]

Epoch 8/10 complete.


 90%|█████████ | 288/320 [08:58<00:46,  1.45s/it]

Epoch 9/10 complete.


100%|██████████| 320/320 [09:56<00:00,  1.86s/it]

Epoch 10/10 complete.





In [None]:
# Test-subset accuracy of the distilled BERT-small.
test(bert_small, test_loader)

{'accuracy': 0.876}

In [None]:
# Baseline: same architecture and number of epochs, trained on the labels only.
standard_train(
    baseline, 
    train_loader, 
    num_epochs=10
)

 10%|█         | 79/790 [01:23<09:27,  1.25it/s]

Epoch 1/10 complete.


 20%|██        | 158/790 [02:49<08:36,  1.22it/s]

Epoch 2/10 complete.


 30%|███       | 237/790 [04:16<07:31,  1.22it/s]

Epoch 3/10 complete.


 40%|████      | 316/790 [05:42<06:27,  1.22it/s]

Epoch 4/10 complete.


 50%|█████     | 395/790 [07:10<05:20,  1.23it/s]

Epoch 5/10 complete.


 60%|██████    | 474/790 [08:37<04:23,  1.20it/s]

Epoch 6/10 complete.


 70%|███████   | 553/790 [10:04<03:18,  1.19it/s]

Epoch 7/10 complete.


 80%|████████  | 632/790 [11:30<02:10,  1.21it/s]

Epoch 8/10 complete.


 90%|█████████ | 711/790 [12:56<01:05,  1.21it/s]

Epoch 9/10 complete.


100%|██████████| 790/790 [14:22<00:00,  1.09s/it]

Epoch 10/10 complete.





In [None]:
# Test-subset accuracy of the standard-trained BERT-small baseline.
test(baseline, test_loader)

{'accuracy': 0.8788}

In [None]:
# Save both BERT-small students' weights. Create ./models first so torch.save
# does not fail when the directory does not exist yet.
import os

os.makedirs('./models', exist_ok=True)
torch.save(baseline.state_dict(), './models/bert-small-baseline_IMDB.pt')
torch.save(bert_small.state_dict(), './models/bert-small_IMDB.pt')

In [None]:
# Drop the references to the BERT-small models before loading the next size.
del baseline
del bert_small

## BERT-medium

In [9]:
# Two copies of the pretrained BERT-medium encoder, each with its own newly
# initialized classification head: `baseline` gets standard training and
# `bert_medium` is distilled from the BERT-base teacher.
baseline = AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-medium')
bert_medium = AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-medium')

baseline.to(device)
bert_medium.to(device)

  return self.fget.__get__(instance, owner)()
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-medium and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-medium and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-7): 8 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, e

In [10]:
# Distill the teacher into BERT-medium for 10 epochs:
# loss = 0.10 * soft-target loss at temperature T=4 + 0.90 * true-label loss.
KD_train(
    bert_medium, 
    bert_base_teacher, 
    train_loader,
    num_epochs=10,
    T=4,
    soft_target_loss_weight=0.10,
    target_loss_weight=0.90
)

 10%|█         | 32/320 [01:34<09:43,  2.03s/it] 

Epoch 1/10 complete.


 20%|██        | 64/320 [02:56<08:43,  2.05s/it]

Epoch 2/10 complete.


 30%|███       | 96/320 [04:17<07:33,  2.02s/it]

Epoch 3/10 complete.


 40%|████      | 128/320 [05:38<06:26,  2.01s/it]

Epoch 4/10 complete.


 50%|█████     | 160/320 [06:59<05:17,  1.98s/it]

Epoch 5/10 complete.


 60%|██████    | 192/320 [08:20<04:17,  2.01s/it]

Epoch 6/10 complete.


 70%|███████   | 224/320 [09:47<03:31,  2.20s/it]

Epoch 7/10 complete.


 80%|████████  | 256/320 [11:16<02:21,  2.20s/it]

Epoch 8/10 complete.


 90%|█████████ | 288/320 [12:44<01:10,  2.19s/it]

Epoch 9/10 complete.


100%|██████████| 320/320 [14:12<00:00,  2.66s/it]

Epoch 10/10 complete.





In [11]:
# Test-subset accuracy of the distilled BERT-medium.
test(bert_medium, test_loader)

{'accuracy': 0.839}

In [None]:
# Baseline: same architecture and number of epochs, trained on the labels only.
standard_train(
    baseline,
    train_loader,
    num_epochs=10
)

 10%|█         | 79/790 [02:34<17:36,  1.49s/it]

Epoch 1/10 complete.


 20%|██        | 158/790 [05:09<15:17,  1.45s/it]

Epoch 2/10 complete.


 30%|███       | 237/790 [07:44<13:34,  1.47s/it]

Epoch 3/10 complete.


 40%|████      | 316/790 [10:20<11:42,  1.48s/it]

Epoch 4/10 complete.


 50%|█████     | 395/790 [12:55<09:55,  1.51s/it]

Epoch 5/10 complete.


 60%|██████    | 474/790 [15:29<07:40,  1.46s/it]

Epoch 6/10 complete.


 70%|███████   | 553/790 [18:04<05:47,  1.46s/it]

Epoch 7/10 complete.


 80%|████████  | 632/790 [20:39<03:51,  1.47s/it]

Epoch 8/10 complete.


 90%|█████████ | 711/790 [23:13<01:57,  1.49s/it]

Epoch 9/10 complete.


100%|██████████| 790/790 [25:46<00:00,  1.96s/it]

Epoch 10/10 complete.





In [None]:
# Test-subset accuracy of the standard-trained BERT-medium baseline.
test(baseline, test_loader)

{'accuracy': 0.89}

In [None]:
# Save both BERT-medium students' weights. Create ./models first so torch.save
# does not fail when the directory does not exist yet.
import os

os.makedirs('./models', exist_ok=True)
torch.save(baseline.state_dict(), './models/bert-medium-baseline_IMDB.pt')
torch.save(bert_medium.state_dict(), './models/bert-medium_IMDB.pt')

In [14]:
# Drop the references to the BERT-medium models.
del baseline
del bert_medium