## Preparation Step

### Data Preprocessing

Check library installation

In [1]:
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.11.0-py3-none-any.whl (468 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.7/468.7 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface-hub<1.0.0,>=0.11.0
  Downloading huggingface_hub-0.13.4-py3-none-any.whl (200 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m200.1/200.1 kB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
Collecting aiohttp
  Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m33.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dill<0.3.7,>=0.3.0
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
Collecting responses<0.19
  Down

In [2]:
from datasets import load_dataset

# Load the "ag_news" dataset; the returned object is indexed by split name
# ("train" / "test") in the next cell.
ag_news_dataset = load_dataset("ag_news")

Downloading builder script:   0%|          | 0.00/4.06k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.65k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.95k [00:00<?, ?B/s]

Downloading and preparing dataset ag_news/default to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548...


Downloading data:   0%|          | 0.00/11.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/751k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

Dataset ag_news downloaded and prepared to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Raw (untokenized) splits.
# NOTE: both names are rebound to tokenized example lists by the
# preprocessing cell further down.
train_dataset = ag_news_dataset["train"]
test_dataset = ag_news_dataset["test"]

Install and import the necessary libraries

In [4]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m53.4 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m100.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, transformers
Successfully installed tokenizers-0.13.3 transformers-4.28.1


In [5]:
import torch
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification

In [6]:
print("PyTorch version:", torch.__version__)

PyTorch version: 2.0.0+cu118


Load dataset and tokenize it

In [7]:
# Tokenizer for the 'bert-base-uncased' checkpoint. It is read as a global by
# preprocess_data, and its vocabulary size sets the student embedding tables.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [8]:
def preprocess_data(dataset, max_length=64):
  """Tokenize a dataset split and pair every example with its label.

  Uses the notebook-level `tokenizer`.

  Args:
    dataset: Object indexable by 'text' (sequence of strings) and 'label'
      (sequence of integer class ids).
    max_length: Number of tokens every example is padded or truncated to.
      Defaults to 64, the value previously hard-coded here.

  Returns:
    A list with one (input_ids, attention_mask, token_type_ids, label) tuple
    of tensors per example.
  """
  texts = dataset['text']
  labels = dataset['label']

  # Tokenize: pad/truncate every text to exactly `max_length` tokens so the
  # per-example tensors can later be stacked into a batch.
  inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')

  # Convert labels to a tensor
  labels = torch.tensor(labels)

  # Combine inputs and labels: zipping the 2-D tensors yields one row per example
  processed_dataset = list(zip(inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids'], labels))

  return processed_dataset


In [9]:
# Tokenize straight from the raw splits in `ag_news_dataset`. Reading from
# `train_dataset` / `test_dataset` here would make the cell fail on a second
# run, because this cell rebinds those names to tokenized lists.
train_dataset = preprocess_data(ag_news_dataset["train"])
test_dataset = preprocess_data(ag_news_dataset["test"])

In [10]:
def collate_fn(batch):
  """Merge per-example tuples into batched tensors.

  Args:
    batch: Sequence of (input_ids, attention_mask, token_type_ids, label)
      examples.

  Returns:
    (input_ids, attention_mask, token_type_ids, labels), where the first three
    are stacked along a new leading dimension and labels is a 1-D long tensor.
  """
  # Transpose "list of examples" into "one list per field".
  ids, masks, segments, targets = map(list, zip(*batch))
  stacked_ids = torch.stack(ids)
  stacked_masks = torch.stack(masks)
  stacked_segments = torch.stack(segments)
  return stacked_ids, stacked_masks, stacked_segments, torch.tensor(targets, dtype=torch.long)


Create DataLoaders for training and testing:

In [11]:
# Number of examples per mini-batch, shared by both loaders.
batch_size = 128

# Training batches are reshuffled on every pass; test batches keep a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)


### Model Building

#### Import library:

In [12]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import BertModel

#### Teacher

VDCNN with pretrained BERT embeddings

In [13]:
class BasicConvResBlock(nn.Module):
  """Two Conv1d + BatchNorm1d stages with an optional residual connection.

  Args:
    input_dim: Number of input channels.
    n_filters: Number of output channels of both convolutions.
    kernel_size, padding, stride: Passed to both convolutions.
    shortcut: When True, the block input is added to the output of the
      second batch norm before the final ReLU.
    downsample: Optional module applied to the block input before the
      addition; only consulted when `shortcut` is True.
  """

  def __init__(self, input_dim=128, n_filters=256, kernel_size=3, padding=1, stride=1, shortcut=False, downsample=None):
    super().__init__()

    self.downsample = downsample
    self.shortcut = shortcut

    # Both convolutions share the same geometry.
    conv_kwargs = dict(kernel_size=kernel_size, padding=padding, stride=stride)
    self.conv1 = nn.Conv1d(input_dim, n_filters, **conv_kwargs)
    self.bn1 = nn.BatchNorm1d(n_filters)
    self.relu = nn.ReLU()
    self.conv2 = nn.Conv1d(n_filters, n_filters, **conv_kwargs)
    self.bn2 = nn.BatchNorm1d(n_filters)

  def forward(self, x):
    """Apply conv-bn-relu, conv-bn, the optional skip addition, then ReLU."""
    hidden = self.relu(self.bn1(self.conv1(x)))
    hidden = self.bn2(self.conv2(hidden))

    if self.shortcut:
      skip = x if self.downsample is None else self.downsample(x)
      hidden += skip

    return self.relu(hidden)


In [14]:
class VDCNN(nn.Module):
  """VDCNN-style 1-D convolutional classifier over pre-computed embeddings.

  The network applies a 3-wide convolution down to 64 channels, four stages
  of BasicConvResBlock (64, 128, 256 and 512 filters) separated by stride-2
  max pooling, an adaptive max pool to 8 positions, and a three-layer
  fully connected head.

  Args:
    num_classes: Size of the output layer.
    embedding_dim: Number of input channels (default 768).
    depth: One of 9, 17, 29 or 49; selects the number of blocks per stage.

  Raises:
    ValueError: If `depth` is not one of the supported values.
  """

  # Blocks per stage, in (64, 128, 256, 512)-filter order, for each depth.
  _BLOCKS_PER_STAGE = {
    9: (1, 1, 1, 1),
    17: (2, 2, 2, 2),
    29: (5, 5, 2, 2),
    49: (8, 8, 5, 3),
  }

  def __init__(self, num_classes, embedding_dim=768, depth=9):
    super(VDCNN, self).__init__()
    # No embedding layer: forward() receives already-embedded input.
    self.conv1 = nn.Conv1d(embedding_dim, 64, kernel_size=3, padding=1)
    self.layers = self._make_layers(depth)
    self.pool = nn.AdaptiveMaxPool1d(8)
    fc_layers = []
    # fully connected head; 8 pooled positions x 512 channels are flattened
    fc_layers.extend([nn.Linear(8 * 512, 2048), nn.ReLU()])
    fc_layers.extend([nn.Linear(2048, 2048), nn.ReLU()])
    # Use the constructor argument, not a notebook-level global.
    fc_layers.extend([nn.Linear(2048, num_classes)])
    self.fc_layers = nn.Sequential(*fc_layers)

  def _make_layers(self, depth):
    """Build the convolutional trunk for the requested depth."""
    if depth not in self._BLOCKS_PER_STAGE:
      supported = sorted(self._BLOCKS_PER_STAGE)
      raise ValueError(f"Unsupported VDCNN depth {depth!r}; expected one of {supported}")

    layers = []
    in_channels = 64
    stages = zip((64, 128, 256, 512), self._BLOCKS_PER_STAGE[depth])
    for stage_idx, (out_channels, n_blocks) in enumerate(stages):
      if stage_idx > 0:
        # Halve the sequence length between consecutive stages.
        layers.append(nn.MaxPool1d(kernel_size=3, stride=2, padding=1))

      ds = None
      if out_channels != in_channels:
        # 1x1 projection handed to the first block of a widening stage. The
        # block is created with its default shortcut=False, so the projection
        # is registered on the block but not applied in its forward pass.
        ds = nn.Sequential(nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), nn.BatchNorm1d(out_channels))
      layers.append(BasicConvResBlock(input_dim=in_channels, n_filters=out_channels, kernel_size=3, padding=1, downsample=ds))

      for _ in range(n_blocks - 1):
        layers.append(BasicConvResBlock(input_dim=out_channels, n_filters=out_channels, kernel_size=3, padding=1))
      in_channels = out_channels

    return nn.Sequential(*layers)

  def forward(self, x):
    """Classify a batch of embedded sequences.

    Args:
      x: Tensor laid out as (batch, embedding_dim, length), the layout
        nn.Conv1d consumes.

    Returns:
      Tensor of shape (batch, num_classes) produced by the final linear layer.
    """
    x = self.conv1(x)
    x = self.layers(x)
    x = self.pool(x)
    # Flatten (batch, 512, 8) into (batch, 4096) for the linear head.
    x = x.view(x.size(0), -1)
    x = self.fc_layers(x)
    return x


In [15]:
class BERT_VDCNN(nn.Module):
  """Teacher model: a 'bert-base-uncased' encoder feeding a VDCNN classifier.

  Args:
    n_classes: Number of output classes, forwarded to VDCNN.
    freeze_bert: When True, every BERT parameter has requires_grad disabled.
    depth: VDCNN depth, forwarded to VDCNN.
  """

  def __init__(self, n_classes, freeze_bert=True, depth=9):
    super().__init__()
    self.bert = BertModel.from_pretrained('bert-base-uncased')
    self.vdcnn = VDCNN(n_classes, depth=depth)
    if freeze_bert:
      # Turn off gradients for the whole encoder in one call.
      self.bert.requires_grad_(False)

  def forward(self, input_ids, attention_mask, token_type_ids):
    """Encode the tokens with BERT and classify the hidden states with the VDCNN."""
    encoded = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    # Swap the last two dimensions so the hidden size becomes the channel
    # axis expected by the VDCNN: (batch_size, embed_size, seq_length).
    channels_first = encoded.last_hidden_state.transpose(1, 2)
    return self.vdcnn(channels_first)


#### Student

In [None]:
# Student 1: a simple CNN that will be trained and coached
class SimpleCNN(nn.Module):
  """Embedding layer, four 3-wide Conv1d layers, global max pooling and a linear head.

  Args:
    vocab_size: Number of rows in the embedding table.
    embed_dim: Embedding width (input channels of the first convolution).
    n_classes: Size of the output layer.
  """

  def __init__(self, vocab_size, embed_dim, n_classes):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, embed_dim)
    self.conv1 = nn.Conv1d(embed_dim, 128, kernel_size=3, padding=1)
    self.conv2 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
    self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
    self.conv4 = nn.Conv1d(512, 1024, kernel_size=3, padding=1)
    self.fc = nn.Linear(1024, n_classes)

  def forward(self, x):
    """Map a batch of token ids to one score per class."""
    # Conv1d expects channels before length.
    features = self.embedding(x).permute(0, 2, 1)

    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
      features = F.relu(conv(features))

    # Global max pooling: one window spanning the whole sequence.
    pooled = F.max_pool1d(features, features.size(2)).squeeze(2)
    return self.fc(pooled)

##### **Train without Teacher**

For reference only, no need to run multiple times

In [16]:
import torch.optim as optim
from tqdm import tqdm

In [17]:
# Training function
def trainSimple(model, dataloader, criterion, optimizer, device):
  """Train `model` for one pass over `dataloader`.

  The model is called with `input_ids` only.

  Returns:
    (mean per-batch loss, fraction of correctly classified examples).
  """
  model.train()
  loss_sum = 0.0
  n_correct = 0
  n_seen = 0
  # The loader also yields attention_mask and token_type_ids; they are unused here.
  for input_ids, _attention_mask, _token_type_ids, labels in tqdm(dataloader):
    input_ids = input_ids.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    logits = model(input_ids)
    batch_loss = criterion(logits, labels)
    batch_loss.backward()
    optimizer.step()

    predicted = torch.max(logits, 1).indices
    n_correct += (predicted == labels).sum().item()
    n_seen += labels.size(0)

    loss_sum += batch_loss.item()

  return loss_sum / len(dataloader), n_correct / n_seen

In [18]:
def evaluateSimple(model, dataloader, device, criterion=None):
  """Evaluate `model` on `dataloader` without computing gradients.

  The model is called with `input_ids` only.

  Args:
    model: Module mapping input_ids to class logits.
    dataloader: Iterable yielding (input_ids, attention_mask, token_type_ids,
      labels) batches.
    device: Device the batches are moved to.
    criterion: Loss called as criterion(logits, labels). Defaults to
      cross-entropy, so the function no longer depends on a notebook-level
      `criterion` variable having been defined before the call.

  Returns:
    (mean per-batch loss, fraction of correctly classified examples).
  """
  if criterion is None:
    criterion = torch.nn.CrossEntropyLoss()

  model.eval()
  correct_predictions = 0
  total_predictions = 0
  running_loss = 0.0
  with torch.no_grad():
    # attention_mask and token_type_ids are yielded by the loader but unused here.
    for input_ids, attention_mask, token_type_ids, labels in tqdm(dataloader):
      input_ids, labels = input_ids.to(device), labels.to(device)

      outputs = model(input_ids)
      loss = criterion(outputs, labels)

      _, preds = torch.max(outputs, 1)
      correct_predictions += (preds == labels).sum().item()
      total_predictions += labels.size(0)

      running_loss += loss.item()

  loss = running_loss / len(dataloader)
  acc = correct_predictions / total_predictions

  return loss, acc

In [None]:
# Hyperparameters
# The embedding table gets one row per entry of the tokenizer vocabulary.
vocab_size = len(tokenizer.vocab)
embed_dim = 128
n_classes = 4

# Create the model (student 1, trained here without a teacher)
model = SimpleCNN(vocab_size, embed_dim, n_classes)

# Set device: use CUDA when it is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Training loop: one training pass followed by one evaluation pass per epoch
n_epochs = 20
best_acc = 0.0
for epoch in range(n_epochs):
    train_loss, train_acc = trainSimple(model, train_dataloader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}/{n_epochs}, train loss: {train_loss:.4f}, acc: {train_acc:.4f}")
    # NOTE: the evaluation metrics come from the test loader.
    val_loss, val_acc = evaluateSimple(model, test_dataloader, device)
    print(f"test loss: {val_loss:.4f}, acc: {val_acc:.4f}")
    # Keep the highest test accuracy seen so far (no checkpoint is saved here).
    if best_acc < val_acc:
      best_acc = val_acc

100%|██████████| 938/938 [00:35<00:00, 26.38it/s]


Epoch 1/20, train loss: 0.5714, acc: 0.7721


100%|██████████| 60/60 [00:00<00:00, 76.81it/s]


test loss: 0.3609, acc: 0.8747


100%|██████████| 938/938 [00:35<00:00, 26.51it/s]


Epoch 2/20, train loss: 0.2872, acc: 0.9013


100%|██████████| 60/60 [00:00<00:00, 77.52it/s]


test loss: 0.3153, acc: 0.8939


100%|██████████| 938/938 [00:35<00:00, 26.57it/s]


Epoch 3/20, train loss: 0.2175, acc: 0.9249


100%|██████████| 60/60 [00:00<00:00, 76.37it/s]


test loss: 0.2977, acc: 0.9036


100%|██████████| 938/938 [00:35<00:00, 26.55it/s]


Epoch 4/20, train loss: 0.1712, acc: 0.9412


100%|██████████| 60/60 [00:00<00:00, 77.40it/s]


test loss: 0.2853, acc: 0.9087


100%|██████████| 938/938 [00:35<00:00, 26.60it/s]


Epoch 5/20, train loss: 0.1359, acc: 0.9532


100%|██████████| 60/60 [00:00<00:00, 77.06it/s]


test loss: 0.3568, acc: 0.8924


100%|██████████| 938/938 [00:35<00:00, 26.70it/s]


Epoch 6/20, train loss: 0.1054, acc: 0.9643


100%|██████████| 60/60 [00:00<00:00, 76.62it/s]


test loss: 0.3266, acc: 0.9049


100%|██████████| 938/938 [00:35<00:00, 26.73it/s]


Epoch 7/20, train loss: 0.0851, acc: 0.9703


100%|██████████| 60/60 [00:00<00:00, 77.44it/s]


test loss: 0.3703, acc: 0.9039


100%|██████████| 938/938 [00:35<00:00, 26.73it/s]


Epoch 8/20, train loss: 0.0700, acc: 0.9757


100%|██████████| 60/60 [00:00<00:00, 77.42it/s]


test loss: 0.4385, acc: 0.9029


100%|██████████| 938/938 [00:35<00:00, 26.73it/s]


Epoch 9/20, train loss: 0.0576, acc: 0.9798


100%|██████████| 60/60 [00:00<00:00, 76.61it/s]


test loss: 0.4719, acc: 0.9005


100%|██████████| 938/938 [00:35<00:00, 26.70it/s]


Epoch 10/20, train loss: 0.0494, acc: 0.9839


100%|██████████| 60/60 [00:00<00:00, 76.88it/s]


test loss: 0.4562, acc: 0.9076


100%|██████████| 938/938 [00:35<00:00, 26.69it/s]


Epoch 11/20, train loss: 0.0444, acc: 0.9854


100%|██████████| 60/60 [00:00<00:00, 77.20it/s]


test loss: 0.4628, acc: 0.9041


100%|██████████| 938/938 [00:35<00:00, 26.75it/s]


Epoch 12/20, train loss: 0.0378, acc: 0.9873


100%|██████████| 60/60 [00:00<00:00, 76.78it/s]


test loss: 0.6042, acc: 0.8889


100%|██████████| 938/938 [00:35<00:00, 26.69it/s]


Epoch 13/20, train loss: 0.0382, acc: 0.9877


100%|██████████| 60/60 [00:00<00:00, 76.63it/s]


test loss: 0.5763, acc: 0.9049


100%|██████████| 938/938 [00:35<00:00, 26.76it/s]


Epoch 14/20, train loss: 0.0339, acc: 0.9887


100%|██████████| 60/60 [00:00<00:00, 76.28it/s]


test loss: 0.5789, acc: 0.9004


100%|██████████| 938/938 [00:34<00:00, 26.82it/s]


Epoch 15/20, train loss: 0.0302, acc: 0.9900


100%|██████████| 60/60 [00:00<00:00, 76.67it/s]


test loss: 0.5893, acc: 0.9055


100%|██████████| 938/938 [00:35<00:00, 26.74it/s]


Epoch 16/20, train loss: 0.0321, acc: 0.9898


100%|██████████| 60/60 [00:00<00:00, 76.98it/s]


test loss: 0.6324, acc: 0.9011


100%|██████████| 938/938 [00:35<00:00, 26.74it/s]


Epoch 17/20, train loss: 0.0302, acc: 0.9905


100%|██████████| 60/60 [00:00<00:00, 76.73it/s]


test loss: 0.5982, acc: 0.9005


100%|██████████| 938/938 [00:35<00:00, 26.79it/s]


Epoch 18/20, train loss: 0.0268, acc: 0.9915


100%|██████████| 60/60 [00:00<00:00, 77.54it/s]


test loss: 0.6502, acc: 0.9036


100%|██████████| 938/938 [00:35<00:00, 26.79it/s]


Epoch 19/20, train loss: 0.0263, acc: 0.9919


100%|██████████| 60/60 [00:00<00:00, 77.19it/s]


test loss: 0.6793, acc: 0.9041


100%|██████████| 938/938 [00:34<00:00, 26.84it/s]


Epoch 20/20, train loss: 0.0271, acc: 0.9913


100%|██████████| 60/60 [00:00<00:00, 75.82it/s]

test loss: 0.6836, acc: 0.8988





In [None]:
# Highest test accuracy reached in any epoch of the loop above.
print(f"best acc: {best_acc}")

best acc: 0.9086842105263158


#### Student 2

In [25]:
# Student 2: an even smaller CNN with a single convolution
class SimplierCNN(nn.Module):
  """Embedding layer, one 3-wide Conv1d layer, global max pooling and a linear head.

  Args:
    vocab_size: Number of rows in the embedding table.
    embed_dim: Embedding width (input channels of the convolution).
    n_classes: Size of the output layer.
  """

  def __init__(self, vocab_size, embed_dim, n_classes):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, embed_dim)
    self.conv1 = nn.Conv1d(embed_dim, 128, kernel_size=3, padding=1)
    self.fc = nn.Linear(128, n_classes)

  def forward(self, x):
    """Map a batch of token ids to one score per class."""
    # Conv1d expects channels before length.
    features = self.embedding(x).permute(0, 2, 1)
    features = F.relu(self.conv1(features))

    # Global max pooling: one window spanning the whole sequence.
    pooled = F.max_pool1d(features, features.size(2)).squeeze(2)
    return self.fc(pooled)

##### **Train without Teacher**

For reference only, no need to run multiple times

In [26]:
# Hyperparameters
# The embedding table gets one row per entry of the tokenizer vocabulary.
vocab_size = len(tokenizer.vocab)
embed_dim = 128
n_classes = 4

# Create the model (student 2, trained here without a teacher)
model = SimplierCNN(vocab_size, embed_dim, n_classes)

# Set device: use CUDA when it is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [27]:
# Training loop: one training pass followed by one evaluation pass per epoch
n_epochs = 20
best_acc = 0.0
for epoch in range(n_epochs):
    train_loss, train_acc = trainSimple(model, train_dataloader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}/{n_epochs}, train loss: {train_loss:.4f}, acc: {train_acc:.4f}")
    # NOTE: the evaluation metrics come from the test loader.
    val_loss, val_acc = evaluateSimple(model, test_dataloader, device)
    print(f"test loss: {val_loss:.4f}, acc: {val_acc:.4f}")
    # Keep the highest test accuracy seen so far (no checkpoint is saved here).
    if best_acc < val_acc:
      best_acc = val_acc

100%|██████████| 938/938 [00:04<00:00, 197.62it/s]


Epoch 1/20, train loss: 0.4973, acc: 0.8276


100%|██████████| 60/60 [00:00<00:00, 878.09it/s]


test loss: 0.3361, acc: 0.8866


100%|██████████| 938/938 [00:04<00:00, 218.86it/s]


Epoch 2/20, train loss: 0.2451, acc: 0.9186


100%|██████████| 60/60 [00:00<00:00, 839.93it/s]


test loss: 0.2958, acc: 0.9055


100%|██████████| 938/938 [00:04<00:00, 215.51it/s]


Epoch 3/20, train loss: 0.1672, acc: 0.9447


100%|██████████| 60/60 [00:00<00:00, 627.72it/s]


test loss: 0.3104, acc: 0.9054


100%|██████████| 938/938 [00:04<00:00, 203.02it/s]


Epoch 4/20, train loss: 0.1116, acc: 0.9647


100%|██████████| 60/60 [00:00<00:00, 843.73it/s]


test loss: 0.3162, acc: 0.9064


100%|██████████| 938/938 [00:04<00:00, 221.89it/s]


Epoch 5/20, train loss: 0.0708, acc: 0.9787


100%|██████████| 60/60 [00:00<00:00, 844.18it/s]


test loss: 0.3435, acc: 0.9024


100%|██████████| 938/938 [00:04<00:00, 219.42it/s]


Epoch 6/20, train loss: 0.0428, acc: 0.9882


100%|██████████| 60/60 [00:00<00:00, 644.82it/s]


test loss: 0.3684, acc: 0.9038


100%|██████████| 938/938 [00:04<00:00, 198.64it/s]


Epoch 7/20, train loss: 0.0286, acc: 0.9928


100%|██████████| 60/60 [00:00<00:00, 888.86it/s]


test loss: 0.4064, acc: 0.9021


100%|██████████| 938/938 [00:04<00:00, 220.30it/s]


Epoch 8/20, train loss: 0.0233, acc: 0.9943


100%|██████████| 60/60 [00:00<00:00, 861.54it/s]


test loss: 0.4293, acc: 0.9041


100%|██████████| 938/938 [00:04<00:00, 217.25it/s]


Epoch 9/20, train loss: 0.0208, acc: 0.9950


100%|██████████| 60/60 [00:00<00:00, 666.73it/s]


test loss: 0.4640, acc: 0.9043


100%|██████████| 938/938 [00:04<00:00, 195.91it/s]


Epoch 10/20, train loss: 0.0191, acc: 0.9955


100%|██████████| 60/60 [00:00<00:00, 784.50it/s]


test loss: 0.5070, acc: 0.9018


100%|██████████| 938/938 [00:04<00:00, 218.55it/s]


Epoch 11/20, train loss: 0.0169, acc: 0.9963


100%|██████████| 60/60 [00:00<00:00, 802.91it/s]


test loss: 0.5010, acc: 0.9039


100%|██████████| 938/938 [00:04<00:00, 218.50it/s]


Epoch 12/20, train loss: 0.0179, acc: 0.9961


100%|██████████| 60/60 [00:00<00:00, 669.96it/s]


test loss: 0.5397, acc: 0.8982


100%|██████████| 938/938 [00:04<00:00, 193.87it/s]


Epoch 13/20, train loss: 0.0144, acc: 0.9968


100%|██████████| 60/60 [00:00<00:00, 856.05it/s]


test loss: 0.5435, acc: 0.9092


100%|██████████| 938/938 [00:04<00:00, 218.57it/s]


Epoch 14/20, train loss: 0.0157, acc: 0.9968


100%|██████████| 60/60 [00:00<00:00, 876.74it/s]


test loss: 0.5656, acc: 0.9049


100%|██████████| 938/938 [00:04<00:00, 219.26it/s]


Epoch 15/20, train loss: 0.0135, acc: 0.9973


100%|██████████| 60/60 [00:00<00:00, 883.93it/s]


test loss: 0.5843, acc: 0.9047


100%|██████████| 938/938 [00:04<00:00, 194.60it/s]


Epoch 16/20, train loss: 0.0132, acc: 0.9976


100%|██████████| 60/60 [00:00<00:00, 871.28it/s]


test loss: 0.5938, acc: 0.9026


100%|██████████| 938/938 [00:04<00:00, 219.72it/s]


Epoch 17/20, train loss: 0.0136, acc: 0.9974


100%|██████████| 60/60 [00:00<00:00, 863.48it/s]


test loss: 0.6178, acc: 0.9028


100%|██████████| 938/938 [00:04<00:00, 218.90it/s]


Epoch 18/20, train loss: 0.0123, acc: 0.9977


100%|██████████| 60/60 [00:00<00:00, 900.63it/s]


test loss: 0.6956, acc: 0.8889


100%|██████████| 938/938 [00:04<00:00, 195.65it/s]


Epoch 19/20, train loss: 0.0119, acc: 0.9977


100%|██████████| 60/60 [00:00<00:00, 845.17it/s]


test loss: 0.6329, acc: 0.8986


100%|██████████| 938/938 [00:04<00:00, 216.03it/s]


Epoch 20/20, train loss: 0.0118, acc: 0.9980


100%|██████████| 60/60 [00:00<00:00, 858.26it/s]

test loss: 0.6604, acc: 0.9007





In [28]:
# Highest test accuracy reached in any epoch of the loop above.
print(f"best acc: {best_acc}")

best acc: 0.9092105263157895


## Training Teacher

In [None]:
import torch.optim as optim
from tqdm import tqdm

In [None]:
def train(model, dataloader, optimizer, criterion, device):
  """Train `model` for one pass over `dataloader`.

  Args:
    model: Module called as model(input_ids, attention_mask, token_type_ids)
      and returning class logits.
    dataloader: Iterable yielding (input_ids, attention_mask, token_type_ids,
      labels) batches.
    optimizer: Optimizer stepped once per batch.
    criterion: Loss called as criterion(logits, labels).
    device: Device every batch tensor is moved to.

  Returns:
    (mean per-batch loss, fraction of correctly classified examples).
  """
  model.train()
  running_loss = 0.0
  correct_predictions = 0
  total_predictions = 0
  for input_ids, attention_mask, token_type_ids, labels in tqdm(dataloader):
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()

    logits = model(input_ids, attention_mask, token_type_ids)
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()

    # Predicted class = index of the largest logit.
    _, preds = torch.max(logits, 1)

    correct_predictions += (preds == labels).sum().item()
    total_predictions += labels.size(0)

    running_loss += loss.item()

  loss = running_loss / len(dataloader)
  acc = correct_predictions / total_predictions
  return loss, acc


In [None]:
def evaluate(model, dataloader, device, criterion=None):
    """Evaluate `model` on `dataloader` without computing gradients.

    Args:
        model: Module called as model(input_ids, attention_mask,
            token_type_ids) and returning class logits.
        dataloader: Iterable yielding (input_ids, attention_mask,
            token_type_ids, labels) batches.
        device: Device every batch tensor is moved to.
        criterion: Loss called as criterion(logits, labels). Defaults to
            cross-entropy, so the function no longer depends on a
            notebook-level `criterion` variable having been defined before
            the call.

    Returns:
        (mean per-batch loss, fraction of correctly classified examples).
    """
    if criterion is None:
        criterion = torch.nn.CrossEntropyLoss()

    model.eval()
    correct_predictions = 0
    total_predictions = 0
    running_loss = 0.0
    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, labels in tqdm(dataloader):
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.to(device)
            labels = labels.to(device)

            logits = model(input_ids, attention_mask, token_type_ids)
            loss = criterion(logits, labels)
            # Predicted class = index of the largest logit.
            _, preds = torch.max(logits, 1)
            running_loss += loss.item()
            correct_predictions += (preds == labels).sum().item()
            total_predictions += labels.size(0)

    loss = running_loss / len(dataloader)
    acc = correct_predictions / total_predictions

    return loss, acc

In [None]:
# Use CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_classes = 4
# Teacher: BERT encoder with gradients disabled, plus a depth-9 VDCNN classifier.
model_T = BERT_VDCNN(n_classes, freeze_bert=True, depth=9)
model_T.to(device)

# All parameters are handed to Adam, including the BERT ones whose
# requires_grad was switched off by freeze_bert=True.
optimizer = optim.Adam(model_T.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
num_epochs = 30
# Per-epoch histories; the test histories are plotted further down.
train_loss = []
train_acc = []
best_acc = 0.0
test_loss_record = []
test_acc_record = []

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    current_loss, current_acc = train(model_T, train_dataloader, optimizer, criterion, device)
    train_loss.append(current_loss)
    train_acc.append(current_acc)
    test_loss, test_accuracy = evaluate(model_T, test_dataloader, device)
    test_loss_record.append(test_loss)
    test_acc_record.append(test_accuracy)
    # Checkpoint the teacher whenever the test accuracy improves.
    # NOTE: the checkpoint is therefore selected on the test loader.
    if best_acc < test_accuracy:
      best_acc = test_accuracy
      torch.save(model_T.state_dict(), 'best_teacher.pth')
    print(f'Train Loss: {current_loss}, Acc: {current_acc}')
    print(f"Test Loss: {test_loss},Test Accuracy: {test_accuracy:.4f}")
print(f'best test acc: {best_acc}')

Epoch 1/30


100%|██████████| 938/938 [09:12<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.16it/s]


Train Loss: 1.4028130614204701, Acc: 0.7205166666666667
Test Loss: 0.34167931402722995,Test Accuracy: 0.8830
Epoch 2/30


100%|██████████| 938/938 [09:14<00:00,  1.69it/s]
100%|██████████| 60/60 [00:27<00:00,  2.17it/s]


Train Loss: 0.4320402960660361, Acc: 0.8678833333333333
Test Loss: 0.32749261235197386,Test Accuracy: 0.8861
Epoch 3/30


100%|██████████| 938/938 [09:13<00:00,  1.69it/s]
100%|██████████| 60/60 [00:27<00:00,  2.15it/s]


Train Loss: 0.3242894551996737, Acc: 0.8848583333333333
Test Loss: 0.3209668296078841,Test Accuracy: 0.8888
Epoch 4/30


100%|██████████| 938/938 [09:13<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.17it/s]


Train Loss: 0.5000845436761374, Acc: 0.8188583333333334
Test Loss: 0.39180633748571075,Test Accuracy: 0.8695
Epoch 5/30


100%|██████████| 938/938 [09:12<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.17it/s]


Train Loss: 0.36397088420734225, Acc: 0.8858083333333333
Test Loss: 0.31378791158397995,Test Accuracy: 0.8953
Epoch 6/30


100%|██████████| 938/938 [09:12<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.15it/s]


Train Loss: 0.4715917584325459, Acc: 0.8912416666666667
Test Loss: 0.2851089959343274,Test Accuracy: 0.9043
Epoch 7/30


100%|██████████| 938/938 [09:12<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.15it/s]


Train Loss: 0.3350436176572527, Acc: 0.894775
Test Loss: 0.29518255988756814,Test Accuracy: 0.9045
Epoch 8/30


100%|██████████| 938/938 [09:13<00:00,  1.69it/s]
100%|██████████| 60/60 [00:27<00:00,  2.15it/s]


Train Loss: 0.30036312800798337, Acc: 0.9004916666666667
Test Loss: 0.281724089384079,Test Accuracy: 0.9043
Epoch 9/30


100%|██████████| 938/938 [09:13<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.17it/s]


Train Loss: 0.28241916588628724, Acc: 0.9047833333333334
Test Loss: 0.26625862816969553,Test Accuracy: 0.9116
Epoch 10/30


100%|██████████| 938/938 [09:13<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.15it/s]


Train Loss: 0.2783354803570298, Acc: 0.90715
Test Loss: 0.2572039429098368,Test Accuracy: 0.9117
Epoch 11/30


100%|██████████| 938/938 [09:13<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.16it/s]


Train Loss: 0.2659723780461466, Acc: 0.9083166666666667
Test Loss: 0.26709608907500904,Test Accuracy: 0.9116
Epoch 12/30


100%|██████████| 938/938 [09:14<00:00,  1.69it/s]
100%|██████████| 60/60 [00:27<00:00,  2.16it/s]


Train Loss: 0.2607342643556056, Acc: 0.908975
Test Loss: 0.24885400608181954,Test Accuracy: 0.9132
Epoch 13/30


100%|██████████| 938/938 [09:13<00:00,  1.69it/s]
100%|██████████| 60/60 [00:27<00:00,  2.17it/s]


Train Loss: 0.2526881830143268, Acc: 0.91265
Test Loss: 0.24411788284778596,Test Accuracy: 0.9146
Epoch 14/30


100%|██████████| 938/938 [09:12<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.17it/s]


Train Loss: 0.24594970167413957, Acc: 0.9138166666666667
Test Loss: 0.24156106288234394,Test Accuracy: 0.9175
Epoch 15/30


100%|██████████| 938/938 [09:12<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.16it/s]


Train Loss: 0.2535448242892335, Acc: 0.9140916666666666
Test Loss: 0.24667892021437485,Test Accuracy: 0.9158
Epoch 16/30


100%|██████████| 938/938 [09:12<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.16it/s]


Train Loss: 0.2407683611138543, Acc: 0.91535
Test Loss: 0.23574548860390981,Test Accuracy: 0.9186
Epoch 17/30


100%|██████████| 938/938 [09:11<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.15it/s]


Train Loss: 0.22701736518156046, Acc: 0.9207333333333333
Test Loss: 0.2528752771516641,Test Accuracy: 0.9143
Epoch 18/30


100%|██████████| 938/938 [09:11<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.16it/s]


Train Loss: 0.2185129800410286, Acc: 0.92315
Test Loss: 0.22361003781358402,Test Accuracy: 0.9216
Epoch 19/30


100%|██████████| 938/938 [09:11<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.16it/s]


Train Loss: 0.21449082079472573, Acc: 0.92475
Test Loss: 0.2165831613043944,Test Accuracy: 0.9243
Epoch 20/30


100%|██████████| 938/938 [09:11<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.16it/s]


Train Loss: 0.21156923561049168, Acc: 0.9266666666666666
Test Loss: 0.2274640622238318,Test Accuracy: 0.9212
Epoch 21/30


100%|██████████| 938/938 [09:10<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.17it/s]


Train Loss: 0.2042828661514752, Acc: 0.9286583333333334
Test Loss: 0.22430123972396057,Test Accuracy: 0.9230
Epoch 22/30


100%|██████████| 938/938 [09:11<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.16it/s]


Train Loss: 0.19927882558024768, Acc: 0.93005
Test Loss: 0.21761774371067683,Test Accuracy: 0.9278
Epoch 23/30


100%|██████████| 938/938 [09:11<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.16it/s]


Train Loss: 0.19639955089290514, Acc: 0.9311166666666667
Test Loss: 0.21140520982444286,Test Accuracy: 0.9280
Epoch 24/30


100%|██████████| 938/938 [09:10<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.16it/s]


Train Loss: 0.19255278824266595, Acc: 0.9317583333333334
Test Loss: 0.21319906152784823,Test Accuracy: 0.9268
Epoch 25/30


100%|██████████| 938/938 [09:10<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.17it/s]


Train Loss: 0.1885082815676483, Acc: 0.9335666666666667
Test Loss: 0.21796221199134985,Test Accuracy: 0.9245
Epoch 26/30


100%|██████████| 938/938 [09:11<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.17it/s]


Train Loss: 0.18336903898398887, Acc: 0.9355833333333333
Test Loss: 0.20843046406904855,Test Accuracy: 0.9304
Epoch 27/30


100%|██████████| 938/938 [09:11<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.16it/s]


Train Loss: 0.18253517586952334, Acc: 0.9361333333333334
Test Loss: 0.21753815648456415,Test Accuracy: 0.9258
Epoch 28/30


100%|██████████| 938/938 [09:11<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.16it/s]


Train Loss: 0.1774194588475644, Acc: 0.9369166666666666
Test Loss: 0.21946321092545987,Test Accuracy: 0.9266
Epoch 29/30


100%|██████████| 938/938 [09:10<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.17it/s]


Train Loss: 0.17480221632987197, Acc: 0.9379666666666666
Test Loss: 0.21774864867329596,Test Accuracy: 0.9326
Epoch 30/30


100%|██████████| 938/938 [09:11<00:00,  1.70it/s]
100%|██████████| 60/60 [00:27<00:00,  2.17it/s]

Train Loss: 0.17272642991547263, Acc: 0.9392916666666666
Test Loss: 0.22159153409302235,Test Accuracy: 0.9289
best test acc: 0.9326315789473684





### Utility Function for Plotting

In [None]:
import matplotlib.pyplot as plt

In [None]:
def plot_train_curve(val_loss, val_acc, name_1, name_2):
  """Draw a loss curve and an accuracy curve on one figure with twin y-axes.

  Both sequences are plotted against their index, and the shared x-axis is
  labelled 'Epoch'. The loss uses the left axis (red, dashed) and the
  accuracy uses the right axis (blue, solid).

  Args:
    val_loss: sequence of loss values, one entry per point on the x-axis.
    val_acc: sequence of accuracy values, one entry per point on the x-axis.
    name_1: legend label for the loss curve.
    name_2: legend label for the accuracy curve.

  Returns:
    None. The figure is displayed with plt.show().
  """
  loss_color, acc_color = 'tab:red', 'tab:blue'

  # Left axis: loss
  _, loss_ax = plt.subplots()
  loss_ax.set_xlabel('Epoch')
  loss_ax.set_ylabel('Loss', color=loss_color)
  loss_ax.plot(val_loss, label=name_1, linestyle='--', color=loss_color)
  loss_ax.tick_params(axis='y', labelcolor=loss_color)
  loss_ax.legend(loc='upper left')

  # Right axis: accuracy, sharing the x-axis with the loss curve
  acc_ax = loss_ax.twinx()
  acc_ax.set_ylabel('Accuracy', color=acc_color)
  acc_ax.plot(val_acc, label=name_2, color=acc_color)
  acc_ax.tick_params(axis='y', labelcolor=acc_color)
  acc_ax.legend(loc='upper right')

  plt.show()

In [None]:
# Example usage: plot the test-loss and test-accuracy records on one figure
plot_train_curve(
    val_loss=test_loss_record,
    val_acc=test_acc_record,
    name_1='test loss',
    name_2='test acc',
)

## Train Student

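DIST loss: the student's cross-entropy plus inter-class and intra-class Pearson-correlation terms computed between the student's and the teacher's softmax outputs.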

In [29]:
def intra_class_relation(y_s, y_t):
    """Relation loss computed after swapping dims 0 and 1 of both inputs.

    Delegates to inter_class_relation on the transposed tensors, so the
    correlation is taken along what was originally dim 0.
    Assumes y_s and y_t are (batch, classes) probabilities, as produced in
    DISTLoss.forward -- TODO confirm for other callers.
    """
    student_swapped = y_s.transpose(0, 1)
    teacher_swapped = y_t.transpose(0, 1)
    return inter_class_relation(student_swapped, teacher_swapped)

def inter_class_relation(y_s, y_t):
    """Return 1 minus the mean row-wise Pearson correlation of y_s and y_t.

    The correlation is computed along dim 1 for every row and then averaged,
    so perfectly correlated rows give a loss of (approximately) 0.
    """
    row_correlation = pearson_correlation(y_s, y_t)
    return 1 - row_correlation.mean()

def cosine_similarity(a, b, eps=1e-8):
    """Cosine similarity between matching rows of a and b, taken along dim 1.

    Args:
        a, b: tensors reduced along dim 1.
        eps: added to the product of the two norms so all-zero rows
            do not divide by zero.

    Returns:
        Tensor with dim 1 reduced away, one similarity per row.
    """
    dot_product = (a * b).sum(1)
    norm_product = a.norm(dim=1) * b.norm(dim=1)
    return dot_product / (norm_product + eps)


def pearson_correlation(a, b, eps=1e-8):
    """Row-wise Pearson correlation of a and b along dim 1.

    Each row is centred by subtracting its own mean over dim 1; the
    correlation is then the cosine similarity of the centred rows.
    eps is forwarded to cosine_similarity unchanged.
    """
    a_centered = a - a.mean(1, keepdim=True)
    b_centered = b - b.mean(1, keepdim=True)
    return cosine_similarity(a_centered, b_centered, eps)



class DISTLoss(nn.Module):
    """Student loss combining cross-entropy with two relation terms.

    The total is
        cross_entropy(z_s, labels) + beta * inter_loss + gamma * intra_loss
    where both relation terms compare the student's and the teacher's
    temperature-scaled softmax outputs and are multiplied by tau**2.

    Args:
        beta: weight of the inter-class relation term.
        gamma: weight of the intra-class relation term.
        tau: softmax temperature applied to both sets of logits.
    """

    def __init__(self, beta=2.0, gamma=2.0, tau=1.0):
        super().__init__()
        self.tau = tau
        self.beta = beta
        self.gamma = gamma

    def forward(self, z_s, z_t, labels):
        """Return the scalar loss for student logits z_s and teacher logits z_t.

        Softmax is taken over dim 1 of both logit tensors; labels are the
        targets for the cross-entropy term on the raw student logits.
        """
        tau_squared = self.tau ** 2

        # Temperature-scaled class probabilities
        student_probs = (z_s / self.tau).softmax(dim=1)
        teacher_probs = (z_t / self.tau).softmax(dim=1)

        inter_loss = tau_squared * inter_class_relation(student_probs, teacher_probs)
        intra_loss = tau_squared * intra_class_relation(student_probs, teacher_probs)

        # Supervised term uses the unscaled student logits
        hard_label_loss = nn.functional.cross_entropy(z_s, labels)

        return hard_label_loss + self.beta * inter_loss + self.gamma * intra_loss

In [30]:
# Training function
def trainWithT(model_S, model_T, dataloader, criterion, optimizer, device):
  """Train the student for one pass over dataloader, guided by the teacher.

  Args:
    model_S: student model; called as model_S(input_ids).
    model_T: teacher model; called as
      model_T(input_ids, attention_mask, token_type_ids). It is put in eval
      mode and only supplies target logits.
    dataloader: yields (input_ids, attention_mask, token_type_ids, labels)
      batches.
    criterion: called as criterion(student_logits, teacher_logits, labels)
      and must return a scalar loss tensor.
    optimizer: optimizer that is zeroed and stepped once per batch.
    device: device every batch tensor is moved to.

  Returns:
    (loss, acc): the mean of the per-batch losses and the fraction of
    samples whose arg-max student prediction equals the label.
  """
  model_S.train()
  model_T.eval()
  running_loss = 0.0
  correct_predictions = 0
  total_predictions = 0
  for input_ids, attention_mask, token_type_ids, labels in tqdm(dataloader):
    input_ids, labels = input_ids.to(device), labels.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    optimizer.zero_grad()
    student_logit = model_S(input_ids)
    # The teacher only provides targets. Without no_grad, autograd records
    # its forward pass and loss.backward() computes (and keeps accumulating)
    # gradients for every teacher parameter that still requires grad.
    with torch.no_grad():
      teacher_logit = model_T(input_ids, attention_mask, token_type_ids)
    loss = criterion(student_logit, teacher_logit, labels)
    loss.backward()
    optimizer.step()

    _, preds = torch.max(student_logit, 1)
    correct_predictions += (preds == labels).sum().item()
    total_predictions += labels.size(0)

    running_loss += loss.item()

  loss = running_loss / len(dataloader)
  acc = correct_predictions / total_predictions
  return loss, acc

def evaluate(model, dataloader, device, criterion=None):
  """Compute the mean batch loss and the accuracy of model over dataloader.

  Args:
    model: called as model(input_ids, attention_mask, token_type_ids).
    dataloader: yields (input_ids, attention_mask, token_type_ids, labels)
      batches.
    device: device every batch tensor is moved to.
    criterion: callable taking (logits, labels) and returning a scalar loss
      tensor. Defaults to cross-entropy, so the function no longer depends
      on a notebook-level `criterion` variable happening to exist.

  Returns:
    (loss, acc): the mean of the per-batch losses and the fraction of
    samples whose arg-max prediction equals the label.
  """
  if criterion is None:
    criterion = torch.nn.CrossEntropyLoss()

  model.eval()
  correct_predictions = 0
  total_predictions = 0
  running_loss = 0.0
  # Inference only: no autograd graph is needed
  with torch.no_grad():
    for input_ids, attention_mask, token_type_ids, labels in tqdm(dataloader):
      input_ids = input_ids.to(device)
      attention_mask = attention_mask.to(device)
      token_type_ids = token_type_ids.to(device)
      labels = labels.to(device)

      logits = model(input_ids, attention_mask, token_type_ids)
      loss = criterion(logits, labels)
      _, preds = torch.max(logits, 1)
      running_loss += loss.item()
      correct_predictions += (preds == labels).sum().item()
      total_predictions += labels.size(0)

  loss = running_loss / len(dataloader)
  acc = correct_predictions / total_predictions
  return loss, acc

##### Student 1

In [None]:
# Student 1 configuration; the vocabulary size is taken from the tokenizer
n_classes = 4
embed_dim = 128
vocab_size = len(tokenizer.vocab)

# Use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the student and place it on the chosen device
model_S = SimpleCNN(vocab_size, embed_dim, n_classes)
model_S.to(device)

# DIST objective with its default weights, optimised with Adam
criterion_S = DISTLoss()
optimizer_S = optim.Adam(model_S.parameters(), lr=1e-3)

Load teacher:

(Reminder: check that the checkpoint path below points to the saved teacher weights.)

In [None]:
# Rebuild the teacher architecture; load_state_dict needs it to match the
# layout of the saved weights.
model_T = BERT_VDCNN(n_classes, freeze_bert=True, depth=9)
checkpoint_path = './best_teacher.pth' # Will need adjustment if model path or name is different
# Load the state_dict into model_T.
# map_location remaps the saved tensors onto `device`, so a checkpoint written
# on a GPU still loads when this runtime has fallen back to the CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model_T.load_state_dict(torch.load(checkpoint_path, map_location=device))
model_T.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BERT_VDCNN(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine

Train Loop, with teacher

In [None]:
# Distil the teacher into the student, keeping track of the best test accuracy
n_epochs = 20
best_acc = 0.0
for epoch in range(n_epochs):
    # One pass over the training data with teacher guidance
    train_loss, train_acc = trainWithT(
        model_S, model_T, train_dataloader, criterion_S, optimizer_S, device
    )
    print(f"Epoch {epoch+1}/{n_epochs}, train loss: {train_loss:.4f}, acc: {train_acc:.4f}")

    # Score the student on its own on the test data
    val_loss, val_acc = evaluateSimple(model_S, test_dataloader, device)
    print(f"test loss: {val_loss:.4f}, acc: {val_acc:.4f}")

    best_acc = max(best_acc, val_acc)

100%|██████████| 938/938 [09:29<00:00,  1.65it/s]


Epoch 1/20, train loss: 1.3036, acc: 0.7993


100%|██████████| 60/60 [00:00<00:00, 75.81it/s]


test loss: 0.3732, acc: 0.8746


100%|██████████| 938/938 [09:36<00:00,  1.63it/s]


Epoch 2/20, train loss: 0.5621, acc: 0.9055


100%|██████████| 60/60 [00:00<00:00, 76.09it/s]


test loss: 0.3000, acc: 0.9012


100%|██████████| 938/938 [09:36<00:00,  1.63it/s]


Epoch 3/20, train loss: 0.4215, acc: 0.9249


100%|██████████| 60/60 [00:00<00:00, 76.62it/s]


test loss: 0.2718, acc: 0.9109


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 4/20, train loss: 0.3432, acc: 0.9360


100%|██████████| 60/60 [00:00<00:00, 74.94it/s]


test loss: 0.2771, acc: 0.9105


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 5/20, train loss: 0.2948, acc: 0.9434


100%|██████████| 60/60 [00:00<00:00, 75.40it/s]


test loss: 0.2643, acc: 0.9143


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 6/20, train loss: 0.2607, acc: 0.9495


100%|██████████| 60/60 [00:00<00:00, 74.80it/s]


test loss: 0.2765, acc: 0.9141


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 7/20, train loss: 0.2385, acc: 0.9527


100%|██████████| 60/60 [00:00<00:00, 77.13it/s]


test loss: 0.2755, acc: 0.9107


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 8/20, train loss: 0.2166, acc: 0.9564


100%|██████████| 60/60 [00:00<00:00, 76.96it/s]


test loss: 0.2741, acc: 0.9171


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 9/20, train loss: 0.2089, acc: 0.9578


100%|██████████| 60/60 [00:00<00:00, 77.33it/s]


test loss: 0.2842, acc: 0.9086


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 10/20, train loss: 0.1957, acc: 0.9602


100%|██████████| 60/60 [00:00<00:00, 77.56it/s]


test loss: 0.2636, acc: 0.9170


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 11/20, train loss: 0.1867, acc: 0.9614


100%|██████████| 60/60 [00:00<00:00, 75.42it/s]


test loss: 0.2846, acc: 0.9155


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 12/20, train loss: 0.1861, acc: 0.9616


100%|██████████| 60/60 [00:00<00:00, 77.62it/s]


test loss: 0.2776, acc: 0.9166


100%|██████████| 938/938 [09:36<00:00,  1.63it/s]


Epoch 13/20, train loss: 0.1773, acc: 0.9624


100%|██████████| 60/60 [00:00<00:00, 77.44it/s]


test loss: 0.2903, acc: 0.9142


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 14/20, train loss: 0.1728, acc: 0.9639


100%|██████████| 60/60 [00:00<00:00, 77.42it/s]


test loss: 0.2759, acc: 0.9182


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 15/20, train loss: 0.1712, acc: 0.9631


100%|██████████| 60/60 [00:00<00:00, 77.29it/s]


test loss: 0.2903, acc: 0.9099


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 16/20, train loss: 0.1673, acc: 0.9640


100%|██████████| 60/60 [00:00<00:00, 76.99it/s]


test loss: 0.2737, acc: 0.9225


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 17/20, train loss: 0.1644, acc: 0.9644


100%|██████████| 60/60 [00:00<00:00, 77.87it/s]


test loss: 0.2944, acc: 0.9159


100%|██████████| 938/938 [09:34<00:00,  1.63it/s]


Epoch 18/20, train loss: 0.1627, acc: 0.9648


100%|██████████| 60/60 [00:00<00:00, 77.87it/s]


test loss: 0.2798, acc: 0.9163


100%|██████████| 938/938 [09:35<00:00,  1.63it/s]


Epoch 19/20, train loss: 0.1627, acc: 0.9641


100%|██████████| 60/60 [00:00<00:00, 77.75it/s]


test loss: 0.2983, acc: 0.9124


100%|██████████| 938/938 [09:34<00:00,  1.63it/s]


Epoch 20/20, train loss: 0.1578, acc: 0.9653


100%|██████████| 60/60 [00:00<00:00, 77.49it/s]

test loss: 0.2865, acc: 0.9161





##### Student 2

In [35]:
# Student 2 configuration; the vocabulary size is taken from the tokenizer
n_classes = 4
embed_dim = 128
vocab_size = len(tokenizer.vocab)

# Use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the student and place it on the chosen device
model_S = SimplierCNN(vocab_size, embed_dim, n_classes)
model_S.to(device)

# DIST objective with its default weights, optimised with Adam
criterion_S = DISTLoss()
optimizer_S = optim.Adam(model_S.parameters(), lr=1e-3)

Load teacher:

(Reminder: check that the checkpoint path below points to the saved teacher weights.)

In [32]:
# Mount Google Drive at /content/drive (requires the google.colab package).
# The teacher checkpoint loaded below is read from a path under ./drive/MyDrive;
# NOTE(review): that relative path presumably assumes the working directory is
# /content -- confirm before running elsewhere.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [33]:
# Rebuild the teacher architecture; load_state_dict needs it to match the
# layout of the saved weights.
model_T = BERT_VDCNN(n_classes, freeze_bert=True, depth=9)
checkpoint_path = './drive/MyDrive/DeepLearn-Project/NLP_Attempts/best_teacher.pth' # Will need adjustment if model path or name is different
# Load the state_dict into model_T.
# map_location remaps the saved tensors onto `device`, so a checkpoint written
# on a GPU still loads when this runtime has fallen back to the CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model_T.load_state_dict(torch.load(checkpoint_path, map_location=device))
model_T.to(device)

Downloading pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BERT_VDCNN(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine

In [36]:
# Distil the teacher into the student, keeping track of the best test accuracy
n_epochs = 20
best_acc = 0.0
for epoch in range(n_epochs):
    # One pass over the training data with teacher guidance
    train_loss, train_acc = trainWithT(
        model_S, model_T, train_dataloader, criterion_S, optimizer_S, device
    )
    print(f"Epoch {epoch+1}/{n_epochs}, train loss: {train_loss:.4f}, acc: {train_acc:.4f}")

    # Score the student on its own on the test data
    val_loss, val_acc = evaluateSimple(model_S, test_dataloader, device)
    print(f"test loss: {val_loss:.4f}, acc: {val_acc:.4f}")

    best_acc = max(best_acc, val_acc)

100%|██████████| 938/938 [08:44<00:00,  1.79it/s]


Epoch 1/20, train loss: 1.3345, acc: 0.8265


100%|██████████| 60/60 [00:00<00:00, 607.85it/s]


test loss: 0.3723, acc: 0.8847


100%|██████████| 938/938 [08:49<00:00,  1.77it/s]


Epoch 2/20, train loss: 0.5524, acc: 0.9114


100%|██████████| 60/60 [00:00<00:00, 630.09it/s]


test loss: 0.3132, acc: 0.9012


100%|██████████| 938/938 [08:46<00:00,  1.78it/s]


Epoch 3/20, train loss: 0.3885, acc: 0.9328


100%|██████████| 60/60 [00:00<00:00, 795.13it/s]


test loss: 0.2867, acc: 0.9095


100%|██████████| 938/938 [08:46<00:00,  1.78it/s]


Epoch 4/20, train loss: 0.2977, acc: 0.9450


100%|██████████| 60/60 [00:00<00:00, 818.16it/s]


test loss: 0.2887, acc: 0.9070


100%|██████████| 938/938 [08:45<00:00,  1.78it/s]


Epoch 5/20, train loss: 0.2398, acc: 0.9530


100%|██████████| 60/60 [00:00<00:00, 537.70it/s]


test loss: 0.2820, acc: 0.9109


100%|██████████| 938/938 [08:48<00:00,  1.78it/s]


Epoch 6/20, train loss: 0.2083, acc: 0.9569


100%|██████████| 60/60 [00:00<00:00, 559.54it/s]


test loss: 0.2896, acc: 0.9089


100%|██████████| 938/938 [08:49<00:00,  1.77it/s]


Epoch 7/20, train loss: 0.1886, acc: 0.9605


100%|██████████| 60/60 [00:00<00:00, 773.74it/s]


test loss: 0.2885, acc: 0.9105


100%|██████████| 938/938 [08:48<00:00,  1.77it/s]


Epoch 8/20, train loss: 0.1761, acc: 0.9626


100%|██████████| 60/60 [00:00<00:00, 816.84it/s]


test loss: 0.3053, acc: 0.9055


100%|██████████| 938/938 [08:48<00:00,  1.78it/s]


Epoch 9/20, train loss: 0.1714, acc: 0.9643


100%|██████████| 60/60 [00:00<00:00, 822.80it/s]


test loss: 0.3004, acc: 0.9112


100%|██████████| 938/938 [08:46<00:00,  1.78it/s]


Epoch 10/20, train loss: 0.1682, acc: 0.9653


100%|██████████| 60/60 [00:00<00:00, 819.55it/s]


test loss: 0.3130, acc: 0.9104


100%|██████████| 938/938 [08:46<00:00,  1.78it/s]


Epoch 11/20, train loss: 0.1654, acc: 0.9663


100%|██████████| 60/60 [00:00<00:00, 805.94it/s]


test loss: 0.3014, acc: 0.9108


100%|██████████| 938/938 [08:46<00:00,  1.78it/s]


Epoch 12/20, train loss: 0.1620, acc: 0.9673


100%|██████████| 60/60 [00:00<00:00, 836.12it/s]


test loss: 0.3062, acc: 0.9109


100%|██████████| 938/938 [08:48<00:00,  1.77it/s]


Epoch 13/20, train loss: 0.1618, acc: 0.9672


100%|██████████| 60/60 [00:00<00:00, 757.21it/s]


test loss: 0.3063, acc: 0.9139


100%|██████████| 938/938 [08:49<00:00,  1.77it/s]


Epoch 14/20, train loss: 0.1603, acc: 0.9686


100%|██████████| 60/60 [00:00<00:00, 842.71it/s]


test loss: 0.3296, acc: 0.9091


100%|██████████| 938/938 [08:49<00:00,  1.77it/s]


Epoch 15/20, train loss: 0.1582, acc: 0.9683


100%|██████████| 60/60 [00:00<00:00, 781.38it/s]


test loss: 0.3085, acc: 0.9093


100%|██████████| 938/938 [08:48<00:00,  1.77it/s]


Epoch 16/20, train loss: 0.1569, acc: 0.9680


100%|██████████| 60/60 [00:00<00:00, 807.53it/s]


test loss: 0.3036, acc: 0.9128


100%|██████████| 938/938 [08:46<00:00,  1.78it/s]


Epoch 17/20, train loss: 0.1554, acc: 0.9681


100%|██████████| 60/60 [00:00<00:00, 823.39it/s]


test loss: 0.3051, acc: 0.9128


100%|██████████| 938/938 [08:46<00:00,  1.78it/s]


Epoch 18/20, train loss: 0.1522, acc: 0.9686


100%|██████████| 60/60 [00:00<00:00, 825.94it/s]


test loss: 0.3058, acc: 0.9113


100%|██████████| 938/938 [08:47<00:00,  1.78it/s]


Epoch 19/20, train loss: 0.1510, acc: 0.9686


100%|██████████| 60/60 [00:00<00:00, 793.65it/s]


test loss: 0.3154, acc: 0.9083


100%|██████████| 938/938 [08:49<00:00,  1.77it/s]


Epoch 20/20, train loss: 0.1505, acc: 0.9682


100%|██████████| 60/60 [00:00<00:00, 804.30it/s]

test loss: 0.2957, acc: 0.9146





In [37]:
# Report the highest test accuracy the training loop stored in `best_acc`
print(f'best acc: {best_acc}')

best acc: 0.9146052631578947
