# Experiments and investigation
Two settings seem worth investigating:
- Decaying the teacher's influence over the epochs. As the student improves, the teacher's hints may become a constraint, and the student may overfit to the teacher.
- Although the teacher outperforms the student, it still makes errors. Restricting the transferred knowledge to the samples that the teacher predicts correctly might let the student learn more freely what the teacher cannot.

In [4]:
import torch
import torch.nn as nn
from distillation.hintonDistiller import HintonDistiller
from distillation.utils import CNN, Accuracy
from torchvision import datasets, transforms
from trainer.baseTrainer import BaseTrainer

In [10]:
# Settings shared by the teacher and student models below.
# imgSize is passed to CNN; presumably (channels, height, width) -- confirm in CNN.
imgSize = (1, 28, 28)
batchSize = 256

# Prepare data: convert images to tensors and standardise them with the
# values commonly quoted as the MNIST training-set mean (0.1307) and std (0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training batches are reshuffled every epoch. DataLoader defaults to
# shuffle=False, which would replay the identical batch order each epoch.
trainloader = torch.utils.data.DataLoader(
    datasets.MNIST('data/',
                   train=True,
                   download=True,
                   transform=transform
                  ),
    batch_size=batchSize,
    shuffle=True)

# Validation is order-independent, so it is left unshuffled (deterministic).
validloader = torch.utils.data.DataLoader(
    datasets.MNIST('data/',
                   train=False,
                   download=True,
                   transform=transform
                  ),
    batch_size=batchSize)

## Train teacher on MNIST

In [11]:
EPOCHS = 5  # number of training epochs for the teacher (last epoch index in the loop below)

In [12]:
# Teacher network: same input size as the student, built with 64 as CNN's
# second argument (presumably a width/channel count -- confirm in CNN).
teacher = CNN(imgSize, 64)

# Supervised training setup: hard-label loss, plain SGD over the teacher's
# parameters, and accuracy as the reported metric (OH=False as in the original).
objective = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(teacher.parameters(), lr=1e-2)
metric = Accuracy(OH=False)

trainer = BaseTrainer()

# Resume support: with checkpoint set to None, load_state just hands back 1
# as the first epoch; otherwise it restores model/optimizer state.
checkpoint = None
startEpoch = trainer.load_state(checkpoint, teacher, optimizer)

# Saving an initial epoch-0 checkpoint creates the checkpoint directory,
# whose location is reported right after.
trainer.save(0, teacher, optimizer, {}, subDirectory='train')
print(f'Saves checkpoint in {trainer.checkpointDir:s}')

print('\n========= Training =========')
for epoch in range(startEpoch, EPOCHS + 1):
    # One full pass over the training set, updating the teacher.
    trainMetrics = trainer.train_step(
        model=teacher,
        dataloader=trainloader,
        objective=objective,
        metric=metric,
        optimizer=optimizer,
    )
    # One full pass over the held-out set, for reporting only.
    validMetrics = trainer.validate(
        model=teacher,
        dataloader=validloader,
        objective=objective,
        metric=metric,
    )
    # Merge both metric dicts; validation entries win on any key clash.
    metrics = {**trainMetrics, **validMetrics}

    # Persist this epoch's teacher state, then log the epoch summary.
    trainer.save(epoch, teacher, optimizer, metrics, subDirectory='train')
    trainer.print_epoch(epoch, EPOCHS, metrics)

Saves checkpoint in checkpoint/train/20200908-131208

Epoch:   1/5	 Train/Loss: 0.440	 Train/Metric: 0.886	 Valid/Loss: 0.211	 Valid/Metric: 0.939
Epoch:   2/5	 Train/Loss: 0.176	 Train/Metric: 0.952	 Valid/Loss: 0.157	 Valid/Metric: 0.955
Epoch:   3/5	 Train/Loss: 0.136	 Train/Metric: 0.963	 Valid/Loss: 0.129	 Valid/Metric: 0.964
Epoch:   4/5	 Train/Loss: 0.114	 Train/Metric: 0.970	 Valid/Loss: 0.113	 Valid/Metric: 0.967
Epoch:   5/5	 Train/Loss: 0.099	 Train/Metric: 0.975	 Valid/Loss: 0.101	 Valid/Metric: 0.969


## Distill with Hinton KD

In [None]:
# Student network: built with 32 as CNN's second argument versus 64 for the
# teacher (presumably a width/channel count -- confirm in CNN).
student = CNN(imgSize, 32)
# alpha and the layer indices are forwarded unchanged to HintonDistiller.
# NOTE(review): alpha presumably weights the distillation term against the
# hard-label loss, and -2 selects the layer read on each network -- confirm.
distiller = HintonDistiller(alpha=0.1,
                            studentLayer=-2,
                            teacherLayer=-2)

# Hard-label loss, soft-label KL loss, and an optimizer that holds only the
# student's parameters.
objective = nn.CrossEntropyLoss()
distillObjective = nn.KLDivLoss(reduction='batchmean')
optimizer = torch.optim.SGD(student.parameters(), lr=1e-2)

# Load state if checkpoint is provided (None starts distillation from scratch).
checkpoint = None
startEpoch = distiller.load_state(checkpoint, student, teacher, optimizer)
epochs = 2

# The teacher only supplies targets here. Put it in inference mode explicitly so
# any dropout / batch-norm layers give deterministic outputs and batch-norm
# running statistics are not modified by the distillation forward passes;
# nothing else in this cell establishes the teacher's mode.
teacher.eval()

for epoch in range(startEpoch, epochs + 1):
    # Training step for one full epoch (student learns from labels + teacher).
    trainMetrics = distiller.train_step(student=student,
                                        teacher=teacher,
                                        dataloader=trainloader,
                                        optimizer=optimizer,
                                        objective=objective,
                                        distillObjective=distillObjective)

    # Validation step for one full epoch (student only).
    validMetrics = distiller.validate(student=student,
                                      dataloader=validloader,
                                      objective=objective)
    metrics = {**trainMetrics, **validMetrics}

    # Save student/teacher/optimizer state for this epoch.
    distiller.save(epoch, student, teacher, optimizer)

    # Print epoch performance.
    distiller.print_epoch(epoch, epochs, metrics)