In [1]:
import torch
import torch.nn as nn
from distillation.datafreeDistiller import DataFreeDistiller
from distillation.utils import PseudoDataset, CNN, Generator
from distillation.datasetDistiller import DatasetDistiller
from distillation.utils import SigmoidScaler

# --- Models and distiller -------------------------------------------------
# Student and teacher are two separately constructed CNN instances built
# with the same arguments.
imgSize = (3, 32, 32)
noiseDim = 100  # not referenced anywhere else in this cell
student = CNN(imgSize, 64)
teacher = CNN(imgSize, 64)
distiller = DatasetDistiller(
    pseudoIters=3,
    studentIters=2,
    pseudoLR=1e-1,
    scaler=SigmoidScaler((0, 1)),
    batchSize=64,
    pseudoSize=imgSize,
)

# --- Objectives and optimizer ---------------------------------------------
# KL divergence is handed to the training step, cross-entropy to validation.
objective = nn.KLDivLoss(reduction='batchmean')
validObjective = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

# --- Validation data ------------------------------------------------------
# Validation batches are drawn from a PseudoDataset of image size `imgSize`.
validloader = torch.utils.data.DataLoader(
    PseudoDataset(size=imgSize), batch_size=512, shuffle=True
)

# --- Checkpoint handling --------------------------------------------------
# No checkpoint path is given here; load_state supplies the first epoch index.
checkpoint = None
startEpoch = distiller.load_state(checkpoint, student, teacher, optimizer)
epochs = 15

# --- TensorBoard logger ---------------------------------------------------
distiller.init_tensorboard_logger()

for epoch in range(startEpoch, epochs + 1):
    # Train for one epoch; no dataloader and no distillObjective are passed.
    trainMetrics = distiller.train_step(
        student=student,
        teacher=teacher,
        dataloader=None,
        optimizer=optimizer,
        objective=objective,
        distillObjective=None,
    )

    # Evaluate the student on the validation loader.
    validMetrics = distiller.validate(
        student=student, dataloader=validloader, objective=validObjective
    )

    # Combine both metric mappings; validation entries win on a key clash.
    metrics = dict(trainMetrics)
    metrics.update(validMetrics)

    # Log to TensorBoard, write a checkpoint, then print the epoch summary.
    distiller.log(epoch, metrics)
    distiller.save(epoch, student, teacher, optimizer)
    distiller.print_epoch(epoch, epochs, metrics)

Epoch:   1/15	 Loss/Pseudo: 0.019	 Loss/Student: 2.113	 Valid/Loss: 4.464	 Valid/Metric: 0.005
Epoch:   2/15	 Loss/Pseudo: 0.339	 Loss/Student: 2.922	 Valid/Loss: 4.486	 Valid/Metric: 0.027
Epoch:   3/15	 Loss/Pseudo: 0.416	 Loss/Student: 2.300	 Valid/Loss: 4.381	 Valid/Metric: 0.077
Epoch:   4/15	 Loss/Pseudo: 0.318	 Loss/Student: 2.183	 Valid/Loss: 4.858	 Valid/Metric: 0.049
Epoch:   5/15	 Loss/Pseudo: 0.435	 Loss/Student: 2.224	 Valid/Loss: 4.717	 Valid/Metric: 0.023
Epoch:   6/15	 Loss/Pseudo: 0.411	 Loss/Student: 2.143	 Valid/Loss: 3.998	 Valid/Metric: 0.038
Epoch:   7/15	 Loss/Pseudo: 0.377	 Loss/Student: 1.616	 Valid/Loss: 4.438	 Valid/Metric: 0.058
Epoch:   8/15	 Loss/Pseudo: 0.480	 Loss/Student: 2.065	 Valid/Loss: 4.566	 Valid/Metric: 0.001
Epoch:   9/15	 Loss/Pseudo: 0.373	 Loss/Student: 1.750	 Valid/Loss: 4.605	 Valid/Metric: 0.031
Epoch:  10/15	 Loss/Pseudo: 0.488	 Loss/Student: 1.608	 Valid/Loss: 4.647	 Valid/Metric: 0.032
Epoch:  11/15	 Loss/Pseudo: 0.447	 Loss/Student: 1

In [2]:
import torch
import torch.nn as nn
from distillation.datasetDistiller import DatasetDistiller
from distillation.utils import PseudoDataset, CNN




# --- Models and distiller -------------------------------------------------
# NOTE: SigmoidScaler is not imported by this cell's own import lines; it has
# to be in the kernel namespace already (the first cell imports it from
# distillation.utils).
imgSize = (3, 32, 32)
student = CNN(imgSize, 64)
teacher = CNN(imgSize, 64)
distiller = DatasetDistiller(
    pseudoIters=3,
    studentIters=2,
    pseudoLR=1e-1,
    scaler=SigmoidScaler((0, 1)),
    batchSize=64,
    pseudoSize=imgSize,
)

# --- Objectives and optimizer ---------------------------------------------
# KL divergence is handed to the training step, cross-entropy to validation.
objective = nn.KLDivLoss(reduction='batchmean')
validObjective = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

# --- Validation data ------------------------------------------------------
# Validation batches are drawn from a PseudoDataset of image size `imgSize`.
validloader = torch.utils.data.DataLoader(
    PseudoDataset(size=imgSize), batch_size=512, shuffle=True
)

# --- Checkpoint handling --------------------------------------------------
# No checkpoint path is given here; load_state supplies the first epoch index.
checkpoint = None
startEpoch = distiller.load_state(checkpoint, student, teacher, optimizer)
epochs = 15

# --- TensorBoard logger ---------------------------------------------------
distiller.init_tensorboard_logger()

for epoch in range(startEpoch, epochs + 1):
    # Train for one epoch; no dataloader and no distillObjective are passed.
    trainMetrics = distiller.train_step(
        student=student,
        teacher=teacher,
        dataloader=None,
        optimizer=optimizer,
        objective=objective,
        distillObjective=None,
    )

    # Evaluate the student on the validation loader.
    validMetrics = distiller.validate(
        student=student, dataloader=validloader, objective=validObjective
    )

    # Combine both metric mappings; validation entries win on a key clash.
    metrics = dict(trainMetrics)
    metrics.update(validMetrics)

    # Log to TensorBoard, write a checkpoint, then print the epoch summary.
    distiller.log(epoch, metrics)
    distiller.save(epoch, student, teacher, optimizer)
    distiller.print_epoch(epoch, epochs, metrics)

Epoch:   1/15	 Loss/Pseudo: 0.019	 Loss/Student: 2.177	 Valid/Loss: 4.061	 Valid/Metric: 0.081
Epoch:   2/15	 Loss/Pseudo: 0.267	 Loss/Student: 3.182	 Valid/Loss: 4.766	 Valid/Metric: 0.003
Epoch:   3/15	 Loss/Pseudo: 0.338	 Loss/Student: 2.260	 Valid/Loss: 4.315	 Valid/Metric: 0.056
Epoch:   4/15	 Loss/Pseudo: 0.342	 Loss/Student: 2.396	 Valid/Loss: 4.540	 Valid/Metric: 0.001
Epoch:   5/15	 Loss/Pseudo: 0.329	 Loss/Student: 2.095	 Valid/Loss: 4.402	 Valid/Metric: 0.072
Epoch:   6/15	 Loss/Pseudo: 0.303	 Loss/Student: 1.921	 Valid/Loss: 4.391	 Valid/Metric: 0.088
Epoch:   7/15	 Loss/Pseudo: 0.379	 Loss/Student: 1.965	 Valid/Loss: 4.452	 Valid/Metric: 0.018
Epoch:   8/15	 Loss/Pseudo: 0.375	 Loss/Student: 1.849	 Valid/Loss: 4.541	 Valid/Metric: 0.014
Epoch:   9/15	 Loss/Pseudo: 0.394	 Loss/Student: 2.082	 Valid/Loss: 4.922	 Valid/Metric: 0.010
Epoch:  10/15	 Loss/Pseudo: 0.439	 Loss/Student: 1.509	 Valid/Loss: 4.253	 Valid/Metric: 0.001
Epoch:  11/15	 Loss/Pseudo: 0.469	 Loss/Student: 1

In [2]:
import torch
import torch.nn as nn
from distillation.datasetDistiller import DatasetDistiller
from distillation.utils import PseudoDataset, CNN
from torchvision import models
from torchvision import datasets, transforms

# Scaler
class ScaleWrapper(nn.Module):
    """Base class for scalers that map their response onto a target interval.

    Subclasses implement ``_scaler``. Calling the module returns
    ``(hi - lo) * _scaler(x) + lo`` where ``lo = interval[0]`` and
    ``hi = interval[1]``.

    Args:
        interval: Indexable pair ``(lo, hi)`` describing the target interval.
    """

    def __init__(self, interval):
        super(ScaleWrapper, self).__init__()
        # Keep the interval on the base class so the rescaling in ``forward``
        # works even when a subclass does not assign ``self.interval`` itself.
        self.interval = interval

    def _scaler(self, x):
        """Return the un-rescaled response for ``x``; subclasses must override.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError('_scaler should be implemented in descendent of ScaleWrap class!')

    def forward(self, x):
        """Apply ``_scaler`` to ``x`` and return the result on the interval scale.

        Defined as ``forward`` (rather than overriding ``__call__``) so that
        ``scaler(x)`` is dispatched through ``nn.Module.__call__``.
        """
        # Apply scaler and return to interval scale.
        return (self.interval[1]-self.interval[0])*self._scaler(x) + self.interval[0]
        
class SigmoidScaler(ScaleWrapper):
    """Scaler whose response is a sigmoid centred on the middle of ``interval``.

    Args:
        interval: Indexable pair ``(lo, hi)``.
        p: Numerator of the sigmoid slope; the slope used is ``p`` divided by
            half the interval width.
    """

    def __init__(self, interval, p=2.463):
        super().__init__(interval)
        self.p = p
        self.interval = interval
        lower, upper = self.interval[0], self.interval[1]
        # Half of the interval width.
        self.scale = (upper - lower)/2
        # Midpoint of the interval.
        self.center = (lower + upper)/2

    def _scaler(self, x):
        """Return ``sigmoid(p / scale * (x - center))``."""
        slope = self.p/self.scale
        return torch.sigmoid(slope * (x - self.center))


# Initialize the student, an AlexNet-based teacher, and the distiller
imgSize = (3, 32, 32)
student = CNN(imgSize, 10)
teacher = models.alexnet(pretrained=True)
# The first feature layer and the last classifier layer are swapped for newly
# constructed layers (3 -> 64 channel conv, 4096 -> 10 linear); these two
# layers therefore do not carry the pretrained weights.
teacher.features[0] = nn.Conv2d(3, 64, kernel_size=(2, 2), stride=(1, 1), padding=(3, 3))
teacher.classifier[6] = nn.Linear(4096, 10)
distiller = DatasetDistiller(pseudoIters=3,
                             studentIters=2,
                             pseudoLR=1e-2,
                             scaler=SigmoidScaler((0,1)),
                             batchSize=10,
                             pseudoSize=imgSize)

# Initialize objectives and optimizer
objective = nn.KLDivLoss(reduction='batchmean')
validObjective = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

# Validation dataloader over the CIFAR-10 split selected by train=False.
# Only deterministic preprocessing is applied: random crop / horizontal flip
# are training-time augmentations and would make the validation metrics
# change from one evaluation to the next.
validloader = torch.utils.data.DataLoader(
            datasets.CIFAR10('data',
                             train=False,
                             download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                             ])),
            batch_size=8,
            shuffle=True,
            num_workers=4)

# Load state if checkpoint is provided
checkpoint = None
startEpoch = distiller.load_state(checkpoint, student, teacher, optimizer)
epochs = 15

# Construct TensorBoard logger
distiller.init_tensorboard_logger()

for epoch in range(startEpoch, epochs+1):
    # Training step for one full epoch (no dataloader / distillObjective passed)
    trainMetrics = distiller.train_step(student=student,
                                        teacher=teacher,
                                        dataloader=None,
                                        optimizer=optimizer,
                                        objective=objective,
                                        distillObjective=None)

    # Validation step for one full epoch
    validMetrics = distiller.validate(student=student,
                                      dataloader=validloader,
                                      objective=validObjective)

    # Validation entries override training entries on a key clash
    metrics = {**trainMetrics, **validMetrics}

    # Log to TensorBoard
    distiller.log(epoch, metrics)

    # Save model
    distiller.save(epoch, student, teacher, optimizer)

    # Print epoch performance
    distiller.print_epoch(epoch, epochs, metrics)

Files already downloaded and verified
Epoch:   1/15	 Loss/Pseudo: 0.030	 Loss/Student: 14.693	 Valid/Loss: 8.820	 Valid/Metric: 0.100
Epoch:   2/15	 Loss/Pseudo: 2.779	 Loss/Student: 13.850	 Valid/Loss: 8.825	 Valid/Metric: 0.100
Epoch:   3/15	 Loss/Pseudo: 2.748	 Loss/Student: 14.720	 Valid/Loss: 8.051	 Valid/Metric: 0.100
Epoch:   4/15	 Loss/Pseudo: 2.410	 Loss/Student: 14.289	 Valid/Loss: 8.543	 Valid/Metric: 0.100
Epoch:   5/15	 Loss/Pseudo: 2.838	 Loss/Student: 12.676	 Valid/Loss: 9.321	 Valid/Metric: 0.103
Epoch:   6/15	 Loss/Pseudo: 2.859	 Loss/Student: 9.734	 Valid/Loss: 11.452	 Valid/Metric: 0.100
Epoch:   7/15	 Loss/Pseudo: 2.428	 Loss/Student: 10.201	 Valid/Loss: 11.617	 Valid/Metric: 0.100
Epoch:   8/15	 Loss/Pseudo: 3.046	 Loss/Student: 12.765	 Valid/Loss: 11.770	 Valid/Metric: 0.100
Epoch:   9/15	 Loss/Pseudo: 3.667	 Loss/Student: 11.770	 Valid/Loss: 9.640	 Valid/Metric: 0.100
Epoch:  10/15	 Loss/Pseudo: 2.771	 Loss/Student: 7.450	 Valid/Loss: 11.943	 Valid/Metric: 0.100


KeyboardInterrupt: 

In [17]:
# Swap the teacher's first feature layer for a newly constructed Conv2d;
# the scalar arguments are the square (n, n) form of kernel/stride/padding.
teacher.features[0] = nn.Conv2d(3, 64, kernel_size=2, stride=1, padding=3)

In [19]:
# Shape check: run a random batch of 8 tensors of size 3x32x32 through the
# teacher and display the shape of what comes back.
teacher(torch.rand(8, 3, 32, 32)).shape

torch.Size([8, 10])