# Knowledge Distillation with BERT

## Overview
An advanced implementation of knowledge distillation using a custom distillation loss for sentiment classification.

In [2]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [3]:
# Import necessary libraries
import torch
import torch.nn as nn
from transformers import BertForSequenceClassification, BertTokenizer, BertModel
from datasets import load_dataset
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm


# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


## Fining Tuning BERT

In [None]:
# Step 1: Import Required Libraries
import torch
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

# Step 2: Load the SST-2 Dataset
dataset = load_dataset('glue', 'sst2')

# Step 3: Preprocess the Data
# Load the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples['sentence'], padding='max_length', truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Step 4: Set up the Model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Step 5: Define the Trainer and Training Arguments
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    eval_strategy="epoch",           # change to eval_strategy
    learning_rate=2e-5,              # learning rate
    per_device_train_batch_size=16,  # batch size for training
    per_device_eval_batch_size=16,   # batch size for evaluation
    num_train_epochs=3,              # number of training epochs
    weight_decay=0.01,               # strength of weight decay
)




trainer = Trainer(
    model=model,                     # the model to train
    args=training_args,              # training arguments
    train_dataset=tokenized_datasets['train'],   # training dataset
    eval_dataset=tokenized_datasets['validation'], # evaluation dataset
)

# Step 6: Fine-Tune the Model
trainer.train()

# Step 7: Evaluate the Model
results = trainer.evaluate()

# Print evaluation results
print(results)

# Save the model and tokenizer
model.save_pretrained('./fine_tuned_bert')
tokenizer.save_pretrained('./fine_tuned_bert')


In [None]:
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import load_dataset

# Load fine-tuned model and tokenizer
model = BertForSequenceClassification.from_pretrained('./fine_tuned_bert')
tokenizer = BertTokenizer.from_pretrained('./fine_tuned_bert')

# Load dataset
dataset = load_dataset('glue', 'sst2')

# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples['sentence'], padding='max_length', truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Convert to PyTorch format
tokenized_datasets = tokenized_datasets.with_format("torch", columns=["input_ids", "attention_mask", "label"])

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',          # Directory to save results
    per_device_eval_batch_size=32,  # Batch size for evaluation
    logging_dir='./logs',           # Directory for logs
    eval_strategy="epoch",    # Evaluate every epoch
)

from sklearn.metrics import accuracy_score

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)
    return {'accuracy': accuracy_score(labels, predictions)}

# Set up Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    eval_dataset=tokenized_datasets['validation'],
    compute_metrics=compute_metrics
)

# Evaluate the model
results = trainer.evaluate()
print(results)


{'eval_loss': 0.33054065704345703, 'eval_model_preparation_time': 0.0016, 'eval_accuracy': 0.926605504587156, 'eval_runtime': 8.9137, 'eval_samples_per_second': 97.827, 'eval_steps_per_second': 3.141}


## Import BERT Teacher Model

In [None]:
from google.colab import files
uploaded = files.upload()


In [None]:
!ls


pytorch_model.bin  sample_data


In [None]:
from transformers import BertForSequenceClassification, BertTokenizer

# Define the path to your pytorch_model.bin
model_path = './pytorch_model.bin'

# Load the tokenizer and model using the checkpoint
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
teacher_model = BertForSequenceClassification.from_pretrained(model_path)

# Ensure the model is in evaluation mode
teacher_model.eval()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

OSError: Incorrect path_or_model_id: './pytorch_model.bin'. Please provide either the path to a local folder or the repo_id of a model on the Hub.

## My Distillation Class

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """
    Implements the distillation loss for MINILM as described in the paper.

    This loss consists of two components:
    1. Attention Distribution Transfer Loss
    2. Value-Relation Transfer Loss

    The loss captures the knowledge transfer between teacher and student models
    by minimizing the KL-divergence of attention distributions and value relations.
    """

    def __init__(self, temperature=2.0):
        """
        Initializes the DistillationLoss module.

        Args:
            temperature (float): Softening temperature for the softmax distributions.
                               Default is 2.0 as suggested in the paper.
        """
        super(DistillationLoss, self).__init__()
        self.temperature = temperature

    def _compute_attention_distribution(self, attention_scores):
        """
        Computes the attention distribution using scaled dot-product.

        Args:
            queries (torch.Tensor): Query vectors
            keys (torch.Tensor): Key vectors

        Returns:
            torch.Tensor: Attention distribution
        """


        # Apply softmax with temperature scaling
        attention_dist = F.softmax(attention_scores / self.temperature, dim=-1)
        return attention_dist

    def _compute_value_relation(self, values):
        """
        Computes the value relation matrix using scaled dot-product.

        Args:
            values (torch.Tensor): Value vectors

        Returns:
            torch.Tensor: Value relation matrix
        """
        # Compute scaled dot-product between values
        value_relation = torch.matmul(values, values.transpose(-2, -1))
        value_relation = value_relation / (values.size(-1) ** 0.5)

        # Apply softmax
        value_relation_dist = F.softmax(value_relation, dim=-1)
        return value_relation_dist

    def forward(self,
            teacher_A, teacher_values,
            student_A, student_values):
        """
        Computes the distillation loss between teacher and student models
        with proper LAT (attention transfer) and VR (value relation) scaling.

        Args:
            teacher_A (torch.Tensor): Attention distribution from the teacher model
            teacher_values (torch.Tensor): Value vectors from the teacher model
            student_A (torch.Tensor): Attention distribution from the student model
            student_values (torch.Tensor): Value vectors from the student model

        Returns:
            torch.Tensor: Total distillation loss with LAT and VR scaling.
        """

        # Ensure all tensors require gradients
        # Check gradients of relevant tensors

        # Ensure all tensors involved in loss computation are part of the computation graph


        student_A = self._compute_attention_distribution(student_A)
        teacher_A = self._compute_attention_distribution(teacher_A)


        # Attention Distribution Transfer (LAT) - KL Divergence
        kl_attention = F.kl_div(
            torch.log(student_A + 1e-8),  # Log with stability added
            teacher_A, reduction="none"
        )  # Shape: [batch, A_h, |x|, |x|]

        Ah = student_A.size(1)  # Number of attention heads
        X = student_A.size(2)   # Sequence length

        # Summing over sequence positions and averaging over attention heads
        attention_transfer_loss = kl_attention.sum(dim=(-2, -1)).mean() / (Ah * X)

        # Value Relation Transfer (VR) - KL Divergence for value relations
        teacher_value_relation = self._compute_value_relation(teacher_values)
        student_value_relation = self._compute_value_relation(student_values)

        kl_value_relation = F.kl_div(
            torch.log(student_value_relation + 1e-8),  # Log with stability added
            teacher_value_relation, reduction="none"
        )  # Shape: [batch, A_h, |x|, |x|]

        # Summing over sequence positions and averaging over attention heads
        value_relation_loss = kl_value_relation.sum(dim=(-2, -1)).mean() / (Ah * X)

        # Total loss as the sum of both components
        total_loss = attention_transfer_loss + value_relation_loss

        # Ensure total loss requires gradients
        assert total_loss.requires_grad, "Total loss must require gradients!"

        return total_loss


## Knowledge Distillation

In [None]:
# Load dataset
dataset = load_dataset('glue', 'sst2')['train']

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize and prepare data
def prepare_data(examples, max_length=512):
    encodings = tokenizer(
        examples['sentence'],
        truncation=True,
        padding='max_length',
        max_length=max_length
    )

    return {
        'input_ids': torch.tensor(encodings['input_ids']),
        'attention_mask': torch.tensor(encodings['attention_mask']),
        'labels': torch.tensor(examples['label'])
    }

# Prepare dataset
processed_data = prepare_data(dataset)

In [None]:
# Custom Student Model with Reduced Complexity
class StudentBertModel(nn.Module):
    def __init__(self, num_labels=2, hidden_size=768, num_hidden_layers=6):
        super(StudentBertModel, self).__init__()

        # Base BERT model with reduced layers
        self.bert = BertModel.from_pretrained('bert-base-uncased',
                                               num_hidden_layers=num_hidden_layers,
                                               attn_implementation="eager")

        # Simplified classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, num_labels)
        )

    def forward(self, input_ids, attention_mask):
        # Get BERT outputs
        bert_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Use the pooled output for classification
        pooled_output = bert_outputs.pooler_output
        logits = self.classifier(pooled_output)

        return logits

    def get_attention_and_values(self, input_ids, attention_mask):
        # Get BERT outputs including attention
        bert_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True
        )

        # Return attention and values
        return (
            bert_outputs.attentions,  # Attention weights
            bert_outputs.last_hidden_state  # Value vectors
        )

In [None]:
# Training Hyperparameters
batch_size = 32
num_epochs = 100
learning_rate = 2e-5

# Create DataLoader
dataset = TensorDataset(
    processed_data['input_ids'],
    processed_data['attention_mask'],
    processed_data['labels']
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize teacher and student models
teacher_model = model

teacher_base_model = teacher_model.bert  # Extract the base BERT model
student_model = StudentBertModel(num_labels=2).to(device)

# Freeze teacher model
for param in teacher_model.parameters():
    param.requires_grad = False

# Training Setup
optimizer = torch.optim.AdamW(student_model.parameters(), lr=2e-5)
distillation_criterion = DistillationLoss(temperature=2.0).to(device)
classification_criterion = nn.CrossEntropyLoss()



In [None]:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Initialize TensorBoard writer
writer = SummaryWriter(log_dir="runs/knowledge_distillation")

# Training Loop
for epoch in range(num_epochs):
    student_model.train()
    teacher_model.eval()
    total_loss = 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
    for batch_idx, batch in enumerate(progress_bar):
        # Unpack batch
        input_ids, attention_mask, labels = [b.to(device) for b in batch]

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass for teacher
        with torch.no_grad():
            teacher_output = teacher_model(input_ids, attention_mask=attention_mask)
            teacher_bert_output = teacher_base_model(
                input_ids,
                attention_mask=attention_mask,
                output_attentions=True
            )
            teacher_attention = teacher_bert_output.attentions
            teacher_values = teacher_bert_output.last_hidden_state

        # Forward pass for student
        student_logits = student_model(input_ids, attention_mask)
        student_attention, student_values = student_model.get_attention_and_values(input_ids, attention_mask)

        # Compute distillation loss
        student_A = student_attention[-1]
        teacher_A = teacher_attention[-1]

        knowledge_loss = distillation_criterion(
            teacher_A, teacher_values,
            student_A, student_values
        )

        # Compute classification loss
        classification_loss = classification_criterion(student_logits, labels)

        # Combine losses
        total_batch_loss = knowledge_loss

        # Backward pass
        total_batch_loss.backward()
        optimizer.step()

        total_loss += total_batch_loss.item()

        # Log batch loss to TensorBoard
        writer.add_scalar("Loss/Batch", total_batch_loss.item(), epoch * len(dataloader) + batch_idx)

        progress_bar.set_description(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_batch_loss.item():.4f}")

    # Compute and log average training loss
    avg_loss = total_loss / len(dataloader)
    writer.add_scalar("Loss/Epoch", avg_loss, epoch)

    print(f"Epoch {epoch+1}, Training Loss: {avg_loss:.4f}")

    # Validation
    student_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask, labels = [b.to(device) for b in batch]

            logits = student_model(input_ids, attention_mask)
            predictions = torch.argmax(logits, dim=1)

            correct += torch.sum(predictions == labels).item()
            total += len(labels)

    # Compute and log validation accuracy
    accuracy = correct / total
    writer.add_scalar("Accuracy/Validation", accuracy, epoch)
    print(f"Epoch {epoch+1}, Validation Accuracy: {accuracy:.4f}")

# Close the TensorBoard writer
writer.close()


                                                                          

Epoch 1, Training Loss: 0.0140
Epoch 1, Validation Accuracy: 0.9130


                                                                          

Epoch 2, Training Loss: 0.0138
Epoch 2, Validation Accuracy: 0.9090


                                                                          

Epoch 3, Training Loss: 0.0135
Epoch 3, Validation Accuracy: 0.9090


                                                                          

Epoch 4, Training Loss: 0.0133
Epoch 4, Validation Accuracy: 0.9090


                                                                          

Epoch 5, Training Loss: 0.0131
Epoch 5, Validation Accuracy: 0.9090


                                                                          

Epoch 6, Training Loss: 0.0127
Epoch 6, Validation Accuracy: 0.9090


                                                                          

Epoch 7, Training Loss: 0.0126
Epoch 7, Validation Accuracy: 0.9100


                                                                          

Epoch 8, Training Loss: 0.0125
Epoch 8, Validation Accuracy: 0.9100


                                                                          

Epoch 9, Training Loss: 0.0122
Epoch 9, Validation Accuracy: 0.9110


                                                                           

Epoch 10, Training Loss: 0.0122
Epoch 10, Validation Accuracy: 0.9100


                                                                           

Epoch 11, Training Loss: 0.0117
Epoch 11, Validation Accuracy: 0.9130


                                                                           

Epoch 12, Training Loss: 0.0118
Epoch 12, Validation Accuracy: 0.9130


                                                                           

Epoch 13, Training Loss: 0.0114
Epoch 13, Validation Accuracy: 0.9130


                                                                           

Epoch 14, Training Loss: 0.0115
Epoch 14, Validation Accuracy: 0.9130


                                                                           

Epoch 15, Training Loss: 0.0113
Epoch 15, Validation Accuracy: 0.9140


                                                                           

Epoch 16, Training Loss: 0.0112
Epoch 16, Validation Accuracy: 0.9150


                                                                           

Epoch 17, Training Loss: 0.0110
Epoch 17, Validation Accuracy: 0.9160


                                                                           

Epoch 18, Training Loss: 0.0109
Epoch 18, Validation Accuracy: 0.9160


                                                                           

Epoch 19, Training Loss: 0.0107
Epoch 19, Validation Accuracy: 0.9160


                                                                           

Epoch 20, Training Loss: 0.0107
Epoch 20, Validation Accuracy: 0.9150


                                                                           

Epoch 21, Training Loss: 0.0103
Epoch 21, Validation Accuracy: 0.9130


                                                                           

Epoch 22, Training Loss: 0.0103
Epoch 22, Validation Accuracy: 0.9110


                                                                           

Epoch 23, Training Loss: 0.0102
Epoch 23, Validation Accuracy: 0.9110


                                                                           

Epoch 24, Training Loss: 0.0103
Epoch 24, Validation Accuracy: 0.9110


                                                                           

Epoch 25, Training Loss: 0.0100
Epoch 25, Validation Accuracy: 0.9100


                                                                           

Epoch 26, Training Loss: 0.0100
Epoch 26, Validation Accuracy: 0.9080


                                                                           

Epoch 27, Training Loss: 0.0098
Epoch 27, Validation Accuracy: 0.9080


                                                                           

Epoch 28, Training Loss: 0.0095
Epoch 28, Validation Accuracy: 0.9050


                                                                           

Epoch 29, Training Loss: 0.0095
Epoch 29, Validation Accuracy: 0.9020


                                                                           

Epoch 30, Training Loss: 0.0095
Epoch 30, Validation Accuracy: 0.9020


                                                                           

Epoch 31, Training Loss: 0.0092
Epoch 31, Validation Accuracy: 0.9010


                                                                           

Epoch 32, Training Loss: 0.0092
Epoch 32, Validation Accuracy: 0.8980


                                                                           

Epoch 33, Training Loss: 0.0090
Epoch 33, Validation Accuracy: 0.8960


                                                                           

Epoch 34, Training Loss: 0.0090
Epoch 34, Validation Accuracy: 0.8940


                                                                           

Epoch 35, Training Loss: 0.0090
Epoch 35, Validation Accuracy: 0.8920


                                                                           

Epoch 36, Training Loss: 0.0090
Epoch 36, Validation Accuracy: 0.8910


                                                                           

Epoch 37, Training Loss: 0.0086
Epoch 37, Validation Accuracy: 0.8910


                                                                           

Epoch 38, Training Loss: 0.0086
Epoch 38, Validation Accuracy: 0.8890


                                                                           

Epoch 39, Training Loss: 0.0087
Epoch 39, Validation Accuracy: 0.8880


                                                                           

Epoch 40, Training Loss: 0.0085
Epoch 40, Validation Accuracy: 0.8900


                                                                           

Epoch 41, Training Loss: 0.0085
Epoch 41, Validation Accuracy: 0.8880


                                                                           

Epoch 42, Training Loss: 0.0084
Epoch 42, Validation Accuracy: 0.8880


                                                                           

Epoch 43, Training Loss: 0.0082
Epoch 43, Validation Accuracy: 0.8880


                                                                           

Epoch 44, Training Loss: 0.0081
Epoch 44, Validation Accuracy: 0.8880


                                                                           

Epoch 45, Training Loss: 0.0081
Epoch 45, Validation Accuracy: 0.8910


                                                                           

Epoch 46, Training Loss: 0.0080
Epoch 46, Validation Accuracy: 0.8950


                                                                           

Epoch 47, Training Loss: 0.0079
Epoch 47, Validation Accuracy: 0.8940


                                                                           

Epoch 48, Training Loss: 0.0079
Epoch 48, Validation Accuracy: 0.8970


                                                                           

Epoch 49, Training Loss: 0.0079
Epoch 49, Validation Accuracy: 0.8960


                                                                           

Epoch 50, Training Loss: 0.0078
Epoch 50, Validation Accuracy: 0.9040


                                                                           

Epoch 51, Training Loss: 0.0076
Epoch 51, Validation Accuracy: 0.9010


                                                                           

Epoch 52, Training Loss: 0.0075
Epoch 52, Validation Accuracy: 0.9030


                                                                           

Epoch 53, Training Loss: 0.0075
Epoch 53, Validation Accuracy: 0.9000


                                                                           

Epoch 54, Training Loss: 0.0076
Epoch 54, Validation Accuracy: 0.9010


                                                                           

Epoch 55, Training Loss: 0.0074
Epoch 55, Validation Accuracy: 0.9020


                                                                           

Epoch 56, Training Loss: 0.0072
Epoch 56, Validation Accuracy: 0.9010


                                                                           

Epoch 57, Training Loss: 0.0072
Epoch 57, Validation Accuracy: 0.9030


                                                                           

Epoch 58, Training Loss: 0.0072
Epoch 58, Validation Accuracy: 0.9050


                                                                           

Epoch 59, Training Loss: 0.0072
Epoch 59, Validation Accuracy: 0.9040


                                                                           

Epoch 60, Training Loss: 0.0072
Epoch 60, Validation Accuracy: 0.9010


                                                                           

Epoch 61, Training Loss: 0.0069
Epoch 61, Validation Accuracy: 0.8990


                                                                           

Epoch 62, Training Loss: 0.0068
Epoch 62, Validation Accuracy: 0.8990


                                                                           

Epoch 63, Training Loss: 0.0066
Epoch 63, Validation Accuracy: 0.8980


                                                                           

Epoch 64, Training Loss: 0.0067
Epoch 64, Validation Accuracy: 0.8970


                                                                           

Epoch 65, Training Loss: 0.0066
Epoch 65, Validation Accuracy: 0.8980


                                                                           

Epoch 66, Training Loss: 0.0065
Epoch 66, Validation Accuracy: 0.8990


                                                                           

Epoch 67, Training Loss: 0.0064
Epoch 67, Validation Accuracy: 0.8990


                                                                           

Epoch 68, Training Loss: 0.0063
Epoch 68, Validation Accuracy: 0.8950


                                                                           

Epoch 69, Training Loss: 0.0063
Epoch 69, Validation Accuracy: 0.8950


                                                                           

Epoch 70, Training Loss: 0.0062
Epoch 70, Validation Accuracy: 0.8950


                                                                           

Epoch 71, Training Loss: 0.0062
Epoch 71, Validation Accuracy: 0.8950


                                                                           

Epoch 72, Training Loss: 0.0061
Epoch 72, Validation Accuracy: 0.8920


                                                                           

Epoch 73, Training Loss: 0.0061
Epoch 73, Validation Accuracy: 0.8920


                                                                           

Epoch 74, Training Loss: 0.0058
Epoch 74, Validation Accuracy: 0.8890


                                                                           

Epoch 75, Training Loss: 0.0058
Epoch 75, Validation Accuracy: 0.8860


                                                                           

Epoch 76, Training Loss: 0.0059
Epoch 76, Validation Accuracy: 0.8860


                                                                           

Epoch 77, Training Loss: 0.0058
Epoch 77, Validation Accuracy: 0.8830


                                                                           

Epoch 78, Training Loss: 0.0057
Epoch 78, Validation Accuracy: 0.8800


                                                                           

Epoch 79, Training Loss: 0.0055
Epoch 79, Validation Accuracy: 0.8760


                                                                           

Epoch 80, Training Loss: 0.0055
Epoch 80, Validation Accuracy: 0.8750


                                                                           

Epoch 81, Training Loss: 0.0055
Epoch 81, Validation Accuracy: 0.8700


                                                                           

Epoch 82, Training Loss: 0.0053
Epoch 82, Validation Accuracy: 0.8670


                                                                           

Epoch 83, Training Loss: 0.0053
Epoch 83, Validation Accuracy: 0.8660


                                                                           

Epoch 84, Training Loss: 0.0052
Epoch 84, Validation Accuracy: 0.8640


                                                                           

Epoch 85, Training Loss: 0.0052
Epoch 85, Validation Accuracy: 0.8640


                                                                           

Epoch 86, Training Loss: 0.0051
Epoch 86, Validation Accuracy: 0.8620


                                                                           

Epoch 87, Training Loss: 0.0051
Epoch 87, Validation Accuracy: 0.8600


                                                                           

Epoch 88, Training Loss: 0.0050
Epoch 88, Validation Accuracy: 0.8570


                                                                           

Epoch 89, Training Loss: 0.0049
Epoch 89, Validation Accuracy: 0.8540


                                                                           

Epoch 90, Training Loss: 0.0049
Epoch 90, Validation Accuracy: 0.8530


                                                                           

Epoch 91, Training Loss: 0.0048
Epoch 91, Validation Accuracy: 0.8480


                                                                           

Epoch 92, Training Loss: 0.0048
Epoch 92, Validation Accuracy: 0.8470


                                                                           

Epoch 93, Training Loss: 0.0047
Epoch 93, Validation Accuracy: 0.8470


                                                                           

Epoch 94, Training Loss: 0.0046
Epoch 94, Validation Accuracy: 0.8510


                                                                           

Epoch 95, Training Loss: 0.0046
Epoch 95, Validation Accuracy: 0.8470


                                                                           

Epoch 96, Training Loss: 0.0046
Epoch 96, Validation Accuracy: 0.8480


                                                                           

Epoch 97, Training Loss: 0.0045
Epoch 97, Validation Accuracy: 0.8480


                                                                           

Epoch 98, Training Loss: 0.0044
Epoch 98, Validation Accuracy: 0.8470


                                                                           

Epoch 99, Training Loss: 0.0043
Epoch 99, Validation Accuracy: 0.8480


                                                                            

Epoch 100, Training Loss: 0.0043
Epoch 100, Validation Accuracy: 0.8520


In [None]:

# Save student model
torch.save(student_model.state_dict(), './runs/student_model.pt')
