In [None]:
import torch

from torchvision import datasets

import numpy as np

from tqdm.notebook import tqdm

from torch.utils.tensorboard import SummaryWriter

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

# Модель

In [None]:
class BayesianNetworkFNN(torch.nn.Module):
    r'''
    Fully connected network trained with a variational (Bayesian) objective.

    The flattened parameters of ``self.network`` play the role of the
    posterior mean.  Three extra parameter vectors of the same length are
    stored: ``log_sigma_posterior``, ``log_sigma_prior`` and ``mu_prior``.
    In :meth:`_D_KL_loss` the quantities ``exp(log_sigma_*)`` enter the
    formula as the diagonal variances of the two Gaussians.
    '''
    @property
    def device(self):
        r'''Device of the first registered parameter of the module.'''
        return next(self.parameters()).device
    
    def __init__(self, 
                 input_dim=10, 
                 output_dim=1, 
                 layers=(),
                 mu_prior=None,
                 sigma_prior=None):
        r'''
        Initialise Neural Network for Bayesian inference

        :param input_dim: size of feature space
        :type input_dim: int
        :param output_dim: size of target space
        :type output_dim: int
        :param layers: sizes of each hidden layer of the fully connected
            network; entries equal to 0 are skipped
        :type layers: list[int]
        :param mu_prior: parameter of prior for :math:`\mathsf{E}w`; when
            given it is also copied into the network weights
        :type mu_prior: array
        :param sigma_prior: parameter of prior for :math:`\mathsf{V}w`
            (stored as its logarithm)
        :type sigma_prior: array
        :raises AssertionError: if ``mu_prior`` or ``sigma_prior`` does not
            have one entry per network parameter
        '''
        super(BayesianNetworkFNN, self).__init__()

        # Zero-sized hidden layers mean "layer removed" and are dropped.
        sizes = [input_dim] + \
                [size for size in layers if size != 0] + \
                [output_dim]

        self.network = torch.nn.Sequential()
        for i in range(1, len(sizes)):
            self.network.add_module('layer{}'.format(i), 
                                    torch.nn.Linear(sizes[i-1], 
                                                    sizes[i]))
            # No activation after the last (output) layer.
            if i != len(sizes) - 1:
                self.network.add_module('relu{}'.format(i), 
                                        torch.nn.LeakyReLU())
        weight = self.get_weight(self.network)

        if sigma_prior is None:
            log_sigma_prior = torch.zeros_like(weight).float()
        else:
            log_sigma_prior = torch.log(torch.tensor(sigma_prior).float())

        if mu_prior is None:
            # Default prior mean: a copy of the freshly initialised weights.
            mu_prior_init = weight.detach().clone()
        else:
            mu_prior_init = torch.tensor(mu_prior).float()

        # Registration order is kept: posterior sigma, prior sigma, prior mu.
        self.log_sigma_posterior = torch.nn.Parameter(
            torch.zeros_like(weight).float(), requires_grad=True)
        self.log_sigma_prior = torch.nn.Parameter(
            log_sigma_prior, requires_grad=True)
        self.mu_prior = torch.nn.Parameter(
            mu_prior_init, requires_grad=True)

        # Validate before touching the network so the descriptive message
        # (not the bare assert inside set_weight) reports a size mismatch.
        assert len(self.log_sigma_prior) == len(weight), \
          'length of sigma_prior and weight must bu equels but {} != {}'.format(
              len(self.log_sigma_prior), len(weight)
          )
        assert len(self.mu_prior) == len(weight), \
          'length of mu_prior and weight must bu equels but {} != {}'.format(
              len(self.mu_prior), len(weight)
          )

        if mu_prior is not None:
            # Start the posterior mean at the supplied prior mean.
            self.set_weight(self.network, weight, self.mu_prior.data)

    def posterior_parameters(self):
        r'''Yield ``log_sigma_posterior`` followed by the network parameters.'''
        yield self.log_sigma_posterior
        for param in self.network.parameters():
            yield param

    def prior_parameters(self):
        r'''Yield the prior parameters ``log_sigma_prior`` and ``mu_prior``.'''
        yield self.log_sigma_prior
        yield self.mu_prior

    @staticmethod  
    def set_weight(seq, weight, new_weight):
        r'''
        Overwrite the parameters of ``seq`` with consecutive chunks of the
        flat vector ``new_weight``.

        :param seq: module whose parameters are replaced
        :param weight: current flat weight vector (used for the length check)
        :param new_weight: flat vector with the new values
        :raises AssertionError: if the two vectors differ in length
        '''
        assert len(weight) == len(new_weight)

        offset = 0
        for param in seq.parameters():
            param_size = param.numel()
            # clone(): the parameter must own its storage.  Assigning the
            # view itself would alias ``new_weight`` (e.g. ``mu_prior``), so
            # training the network would silently move the prior mean too.
            param.data = new_weight[offset:offset+param_size] \
                .view_as(param).clone()
            
            offset += param_size

    @property
    def weight(self):
        r'''All parameters of ``self.network`` concatenated into one vector.'''
        return self.get_weight(self.network)

    @staticmethod
    def get_weight(seq, requires_grad=True):
        r'''
        Concatenate the flattened parameters of ``seq``.

        :param seq: module whose parameters are collected
        :param requires_grad: when False the vector is built under
            ``torch.no_grad()``
        :return: 1-D tensor with all parameters in ``seq.parameters()`` order
        '''
        if requires_grad:
            return torch.cat([param.view(-1) for param in seq.parameters()])
        with torch.no_grad():
            return torch.cat([param.view(-1) for param in seq.parameters()])

    def forward(self, x_batch):
        r'''
        Model inference for one batch

        :param x_batch: 
            input tensor of shape `batch_size` :math:`\times` `input_dim`
        :type x_batch: Tensor
        :return: 
            output tensor of shape `batch_size` :math:`\times` `output_dim`
        :rtype: Tensor
        '''

        return self.network(x_batch)

    def _D_KL_loss(self):
        r'''
        The method return KL divergence between prior and posterior. 
        The distributions are assumed to be normal.
        '''

        mu_posterior = self.weight

        # Trace term: sum of variance ratios posterior / prior.
        D_KL_1 = 0.5*torch.exp(self.log_sigma_posterior \
                               -self.log_sigma_prior).sum()
        # Squared distance between the means, scaled by the prior variance.
        D_KL_2 = (0.5*torch.exp(-self.log_sigma_prior) \
                     *(self.mu_prior-mu_posterior)**2).sum()
        # Log-determinant ratio.
        D_KL_3 = 0.5*(self.log_sigma_prior.sum() \
                      - self.log_sigma_posterior.sum())
        # Constant term: minus half the number of parameters.
        D_KL_4 = -1*0.5*len(mu_posterior)

        return D_KL_1 + D_KL_2 + D_KL_3 + D_KL_4

    def loss(self, likelihood, alpha=1.):
        r'''
        Compute completed loss for bayesian model.
        :param likelihood: data likelihood
        :type likelihood: Tensor
        :param alpha: multiplier of the KL term
        :type alpha: float
        :return: ``alpha * KL - likelihood``
        '''
        return alpha*self._D_KL_loss() - likelihood

    @staticmethod
    def _batch_likelihood(loss_funct, output, y_batch, batch_len):
        r'''Negated loss of one batch divided by the batch length.'''
        if len(y_batch.shape) > 1:
            value = loss_funct(output.transpose(1, -1), 
                               y_batch.transpose(1, -1))
        else:
            value = loss_funct(output, y_batch)
        return -1*value/batch_len

    def fit(self, 
            train_dataset,
            loss_function=torch.nn.MSELoss,
            optim_function=torch.optim.Adam,
            epochs=100,
            lr=0.001,
            hyp_lr=0.00001, 
            batch_size=8,
            callback=None):
        r'''
        Train the model, alternating one posterior step and one prior step
        on every batch.

        :param train_dataset: dataset yielding ``(x, y)`` pairs
        :param loss_function: loss class; its negated value divided by the
            batch length is used as the likelihood
        :param optim_function: optimiser class used for both parameter groups
        :param epochs: number of passes over ``train_dataset``
        :param lr: learning rate for :meth:`posterior_parameters`
        :param hyp_lr: learning rate for :meth:`prior_parameters`
        :param batch_size: batch size of the training loader
        :param callback: optional callable invoked after every batch as
            ``callback(model, likelihood_1, likelihood_2)``
        '''
      
        optimiser_posterior = optim_function(self.posterior_parameters(), lr=lr)
        optimiser_prior = optim_function(self.prior_parameters(), lr=hyp_lr)
        loss_funct = loss_function()
      
        # Pinned memory only helps host-to-GPU copies.
        dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, 
            pin_memory=self.device.type == 'cuda', shuffle=True)
        
        for _ in tqdm(range(epochs)):
            for x_batch, y_batch in dataloader:
                x_batch = x_batch.to(self.device)
                y_batch = y_batch.to(self.device)

                # Step 1: update the posterior (weights and its log-sigma).
                optimiser_posterior.zero_grad()
                likelihood_1 = self._batch_likelihood(
                    loss_funct, self(x_batch), y_batch, len(x_batch))
                self.loss(likelihood_1).backward()
                optimiser_posterior.step()

                # Step 2: update the prior with a fresh forward pass through
                # the just-updated network.
                optimiser_prior.zero_grad()
                likelihood_2 = self._batch_likelihood(
                    loss_funct, self(x_batch), y_batch, len(x_batch))
                self.loss(likelihood_2).backward()
                optimiser_prior.step()

                if callback is not None:
                    callback(self, likelihood_1, likelihood_2)


# CallBack

In [None]:
class callback():
    r'''
    Training callback that logs to a TensorBoard-style writer.

    Every call logs the passed likelihood values; every ``delimeter`` calls
    it additionally evaluates the model on ``dataset`` and logs the test
    likelihood, the KL term and a histogram of the model weights.
    '''
    def __init__(self, writer, dataset, 
                 loss_function=torch.nn.MSELoss, 
                 delimeter = 100, 
                 batch_size = 64):
        r'''
        :param writer: object providing ``add_scalar`` and ``add_histogram``
        :param dataset: dataset of ``(x, y)`` pairs used for evaluation
        :param loss_function: loss class instantiated for the evaluation
        :param delimeter: evaluation period, in callback calls
        :param batch_size: batch size of the evaluation loader
        '''
        self.step = 0
        self.writer = writer
        self.delimeter = delimeter
        self.loss_function = loss_function
        self.batch_size = batch_size

        self.dataset = dataset

    def forward(self, model, *loss):
        r'''
        Log one training step.

        :param model: model exposing ``device``, ``weight`` and
            ``_D_KL_loss()``
        :param loss: likelihood values logged as ``TRAIN/likelihood_<i>``
        '''
        self.step += 1
        for i, ls in enumerate(loss):
            self.writer.add_scalar('TRAIN/likelihood_{}'.format(i), ls, self.step)

        # The evaluation below only runs on every ``delimeter``-th step.
        if self.step % self.delimeter != 0:
            return

        loss_funct = self.loss_function()
        batch_generator = torch.utils.data.DataLoader(
            dataset = self.dataset, batch_size=self.batch_size)
        
        test_loss = 0
        with torch.no_grad():
            for x_batch, y_batch in batch_generator:
                x_batch = x_batch.to(model.device)
                y_batch = y_batch.to(model.device)
                
                output = model(x_batch)

                if len(y_batch.shape) > 1:
                    batch_loss = loss_funct(
                        output.transpose(1, -1), 
                        y_batch.transpose(1, -1))
                else:
                    batch_loss = loss_funct(output, y_batch)
                # Weight by the batch length so the division below yields
                # the per-sample average over the whole dataset.
                test_loss += -1*batch_loss.cpu().item()*len(x_batch)
        
        test_loss /= len(self.dataset)
        
        self.writer.add_scalar('TEST/likelihood', test_loss, self.step)

        self.writer.add_scalar('VARIATION/D_KL', 
                               model._D_KL_loss(), self.step)
        
        # Use the writer owned by this callback, not a global ``writer``.
        self.writer.add_histogram('weight', model.weight, self.step)


    def __call__(self, model, *loss):
        r'''Alias of :meth:`forward` so the object can be called directly.'''
        return self.forward(model, *loss)

# Синтетика

## Tensorboard

In [None]:
# Load the TensorBoard notebook extension and open it on the log directory
# written by the synthetic-data experiments below.
%load_ext tensorboard
%tensorboard --logdir experiment/

## Данные

In [None]:
# Fixed seed: the draws below depend on the order of the random calls.
np.random.seed(0)

# l samples with n standard-normal features each.
l = 1024
n = 10
X = np.random.randn(l, n)
# Ground-truth weights of the linear model.
w = np.random.randn(n)

# Linear response with additive Gaussian noise scaled by 0.1.
y = X@w + 0.1*np.random.randn(l)

In [None]:
# First 900 rows form the training part, the remaining rows the test part;
# everything is converted to float32 tensors.
X_train_tr, X_test_tr = (
    torch.tensor(part).float() for part in (X[:900], X[900:]))

# Targets are reshaped into column vectors.
y_train_tr, y_test_tr = (
    torch.tensor(part).view(-1, 1).float() for part in (y[:900], y[900:]))


In [None]:
# Pair features with targets so a DataLoader yields (x, y) batches.
train_dataset = torch.utils.data.TensorDataset(X_train_tr, y_train_tr)
test_dataset = torch.utils.data.TensorDataset(X_test_tr, y_test_tr)

## Параметры обучения

In [None]:
# Number of passes over the training set.
epochs = 150
# Period, in training steps, of the test-set evaluation done by the callback.
delimeter = 100
# Hidden-layer sizes of the teacher and of the student with the same depth.
teacher_3W_layers = [100, 50]
student_3W_layers = [10, 10]
# Although the first layer is kept and the second one is removed, the
# parameters keep the minimum of the two dimensions of the parameter matrix.
student_2W_layers = [50, 0]
input_dim = 10
output_dim = 1

## Обучение

In [None]:
teacher = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=teacher_3W_layers)
_ = teacher.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'experiment/3W/teacher')

call = callback(writer, test_dataset, delimeter=delimeter)

In [None]:
teacher.fit(train_dataset, callback=call, epochs=epochs)

In [None]:
student = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=student_3W_layers)
_ = student.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'experiment/3W/student')

call = callback(writer, test_dataset, delimeter=delimeter)

In [None]:
student.fit(train_dataset, callback=call, epochs=epochs)

In [None]:
student = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=student_2W_layers)
_ = student.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'experiment/2W/student')

call = callback(writer, test_dataset, delimeter=delimeter)

In [None]:
student.fit(train_dataset, callback=call, epochs=epochs)

## Дистилляция

### Удаление нейронов
Так как ковариация отсутствует, то в данном случае формула преобразования параметров будет тривиальной:
$$
p(u) = \mathcal{N}\bigl(\mu_{posterior}', \Sigma_{posterior}'\bigr)
$$

In [None]:
# Layer widths of teacher and student, input and output layers included.
teacher_3W_layers_distil=[input_dim] + teacher_3W_layers + [output_dim]
student_3W_layers_distil=[input_dim] + student_3W_layers + [output_dim]

# For every layer draw (reproducibly) the teacher neurons the student keeps.
np.random.seed(0)
preference = [np.random.choice(np.arange(0, teacher_layer_size), 
                               size=student_layer_size, 
                               replace=False ).tolist() \
              for teacher_layer_size, student_layer_size in zip(
                  teacher_3W_layers_distil, student_3W_layers_distil)]
 
# Flat 0/1 mask over the teacher's parameter vector (weight, then bias, for
# each layer in order), marking the entries that survive in the student.
weight_mask = []
for layer_idx in range(1, len(teacher_3W_layers_distil)):
    # torch.nn.Linear stores its weight as (out_features, in_features) and
    # the model flattens it row-major, so the mask must use the same layout;
    # an (in, out) mask would mark unrelated entries after flattening.
    layer_mask = np.zeros(
        shape=[teacher_3W_layers_distil[layer_idx], 
               teacher_3W_layers_distil[layer_idx-1]])
    
    bias_mask = np.zeros(teacher_3W_layers_distil[layer_idx])

    for cur_neuron_id in preference[layer_idx]:
        bias_mask[cur_neuron_id] = 1
        for pref_neuron_id in preference[layer_idx-1]:
            layer_mask[cur_neuron_id, pref_neuron_id] = 1
    
    weight_mask += layer_mask.reshape(-1).tolist()
    weight_mask += bias_mask.reshape(-1).tolist()

weight_mask = np.array(weight_mask)

In [None]:
# Boolean selector of the teacher parameters kept by the student; applied on
# CPU copies so it works wherever the teacher lives.
keep_mask = torch.from_numpy(weight_mask == 1)

# Student prior mean: the teacher's posterior mean (its trained weights).
mu_prior = teacher.weight.detach().cpu()[keep_mask].tolist()
# Student prior variance: the teacher's *posterior* variance, matching
# p(u) = N(mu'_posterior, Sigma'_posterior); the teacher's own prior
# variance is not what the formula prescribes.
sigma_prior = torch.exp(
    teacher.log_sigma_posterior).detach().cpu()[keep_mask].tolist()

In [None]:
student = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=student_3W_layers, 
                             mu_prior=mu_prior)
_ = student.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'experiment/3W/distil_student')

call = callback(writer, test_dataset, delimeter=delimeter)

In [None]:
student.fit(train_dataset, callback=call, epochs=epochs)

In [None]:
student = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=student_3W_layers, 
                             mu_prior=mu_prior, 
                             sigma_prior=sigma_prior)
_ = student.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'experiment/3W/distil_student_all')

call = callback(writer, test_dataset, delimeter=delimeter)

In [None]:
student.fit(train_dataset, callback=call, epochs=epochs)

### Удаление слоя

Так как ковариация отсутствует, то в данном случае формула преобразования параметров будет тривиальной:
$$
p(u) = \mathcal{N}\bigl(\mu_{posterior}', \Sigma_{posterior}'\bigr)
$$

In [None]:
# First choose which layer to remove
#      (here the second matrix W_2 is removed).
# Again, since there is no covariance, the second matrix is simply dropped
# and a subset of the rows of the first matrix is taken.
teacher_3W_layers_distil=[input_dim] + teacher_3W_layers + [output_dim]
student_2W_layers_distil=[input_dim] + student_2W_layers + [output_dim]

# For every layer draw (reproducibly) the teacher neurons the student keeps;
# the removed layer has student size 0 and therefore an empty selection.
np.random.seed(0)
preference = [np.random.choice(np.arange(0, teacher_layer_size), 
                               size=student_layer_size, 
                               replace=False ).tolist() \
              for teacher_layer_size, student_layer_size in zip(
                  teacher_3W_layers_distil, student_2W_layers_distil)]
 
# Flat 0/1 mask over the teacher's parameter vector (weight, then bias, for
# each layer in order), marking the entries that survive in the student.
weight_mask = []
for layer_idx in range(1, len(teacher_3W_layers_distil)):
    # torch.nn.Linear stores its weight as (out_features, in_features) and
    # the model flattens it row-major, so the mask must use the same layout;
    # an (in, out) mask would mark unrelated entries after flattening.
    layer_mask = np.zeros(
        shape=[teacher_3W_layers_distil[layer_idx], 
               teacher_3W_layers_distil[layer_idx-1]])
    
    bias_mask = np.zeros(teacher_3W_layers_distil[layer_idx])

    # An empty selection on the input side (the removed layer) means that
    # all inputs of this layer are kept.
    prev_pref = preference[layer_idx-1]
    if not prev_pref:
        prev_pref = np.arange(0, teacher_3W_layers_distil[layer_idx-1])
    for cur_neuron_id in preference[layer_idx]:
        bias_mask[cur_neuron_id] = 1
        for pref_neuron_id in prev_pref:
            layer_mask[cur_neuron_id, pref_neuron_id] = 1
        
    weight_mask += layer_mask.reshape(-1).tolist()
    weight_mask += bias_mask.reshape(-1).tolist()

weight_mask = np.array(weight_mask)

In [None]:
# Display the student and teacher layer widths side by side.
student_2W_layers_distil, teacher_3W_layers_distil

In [None]:
# Boolean selector of the teacher parameters kept by the student; applied on
# CPU copies so it works wherever the teacher lives.
keep_mask = torch.from_numpy(weight_mask == 1)

# Student prior mean: the teacher's posterior mean (its trained weights).
mu_prior = teacher.weight.detach().cpu()[keep_mask].tolist()
# Student prior variance: the teacher's *posterior* variance, matching
# p(u) = N(mu'_posterior, Sigma'_posterior); the teacher's own prior
# variance is not what the formula prescribes.
sigma_prior = torch.exp(
    teacher.log_sigma_posterior).detach().cpu()[keep_mask].tolist()

In [None]:
student = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=student_2W_layers, 
                             mu_prior=mu_prior)
_ = student.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'experiment/2W/distil_student')

call = callback(writer, test_dataset, delimeter=delimeter)

In [None]:
student.fit(train_dataset, callback=call, epochs=epochs)

In [None]:
student = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=student_2W_layers, 
                             mu_prior=mu_prior, 
                             sigma_prior=sigma_prior)
_ = student.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'experiment/2W/distil_student_all')

call = callback(writer, test_dataset, delimeter=delimeter)

In [None]:
student.fit(train_dataset, callback=call, epochs=epochs)

# FashionMnist

## Tensorboard

In [None]:
# Load the TensorBoard notebook extension and open it on the log directory
# written by the FashionMNIST experiments below.
%load_ext tensorboard
%tensorboard --logdir fashionmnist/

## Данные

In [None]:
# Training split: images flattened to 784-vectors and scaled to [0, 1].
MNIST = datasets.FashionMNIST(
    root='../data/', train=True, download=True, transform=None)
train_dataset = torch.utils.data.TensorDataset(
    MNIST.data.reshape(-1, 28 * 28).float().div(255),
    MNIST.targets,
)

# Test split, prepared the same way (the MNIST name is reused on purpose).
MNIST = datasets.FashionMNIST(
    root='../data/', train=False, download=True, transform=None)
test_dataset = torch.utils.data.TensorDataset(
    MNIST.data.reshape(-1, 28 * 28).float().div(255),
    MNIST.targets,
)


## Параметры обучения

In [None]:
# Number of passes over the training set.
epochs = 30
# Period, in training steps, of the test-set evaluation done by the callback.
delimeter = 100
# Hidden-layer sizes of the teacher and of the student with the same depth.
teacher_3W_layers = [100, 50]
student_3W_layers = [10, 10]
# Although the first layer is kept and the second one is removed, the
# parameters keep the minimum of the two dimensions of the parameter matrix.
student_2W_layers = [50, 0]
# 28 * 28 flattened pixels in, 10 classes out.
input_dim = 784
output_dim = 10
batch_size = 64

## Обучение

In [None]:
teacher = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=teacher_3W_layers)
_ = teacher.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'fashionmnist/3W/teacher')

call = callback(writer, test_dataset, 
                delimeter=delimeter, loss_function=torch.nn.CrossEntropyLoss)

In [None]:
teacher.fit(train_dataset, 
            callback=call, 
            epochs=epochs,
            batch_size=batch_size,
            loss_function=torch.nn.CrossEntropyLoss)

In [None]:
student = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=student_3W_layers)
_ = student.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'fashionmnist/3W/student')

call = callback(writer, test_dataset, 
                delimeter=delimeter, loss_function=torch.nn.CrossEntropyLoss)

In [None]:
student.fit(train_dataset, 
            callback=call, 
            epochs=epochs,
            batch_size=batch_size,
            loss_function=torch.nn.CrossEntropyLoss)

In [None]:
student = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=student_2W_layers)
_ = student.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'fashionmnist/2W/student')

call = callback(writer, test_dataset, 
                delimeter=delimeter, loss_function=torch.nn.CrossEntropyLoss)

In [None]:
student.fit(train_dataset, 
            callback=call, 
            epochs=epochs,
            batch_size=batch_size,
            loss_function=torch.nn.CrossEntropyLoss)

## Дистилляция

### Удаление нейронов
Так как ковариация отсутствует, то в данном случае формула преобразования параметров будет тривиальной:
$$
p(u) = \mathcal{N}\bigl(\mu_{posterior}', \Sigma_{posterior}'\bigr)
$$

In [None]:
# Layer widths of teacher and student, input and output layers included.
teacher_3W_layers_distil=[input_dim] + teacher_3W_layers + [output_dim]
student_3W_layers_distil=[input_dim] + student_3W_layers + [output_dim]

# For every layer draw (reproducibly) the teacher neurons the student keeps.
np.random.seed(0)
preference = [np.random.choice(np.arange(0, teacher_layer_size), 
                               size=student_layer_size, 
                               replace=False ).tolist() \
              for teacher_layer_size, student_layer_size in zip(
                  teacher_3W_layers_distil, student_3W_layers_distil)]
 
# Flat 0/1 mask over the teacher's parameter vector (weight, then bias, for
# each layer in order), marking the entries that survive in the student.
weight_mask = []
for layer_idx in range(1, len(teacher_3W_layers_distil)):
    # torch.nn.Linear stores its weight as (out_features, in_features) and
    # the model flattens it row-major, so the mask must use the same layout;
    # an (in, out) mask would mark unrelated entries after flattening.
    layer_mask = np.zeros(
        shape=[teacher_3W_layers_distil[layer_idx], 
               teacher_3W_layers_distil[layer_idx-1]])
    
    bias_mask = np.zeros(teacher_3W_layers_distil[layer_idx])

    for cur_neuron_id in preference[layer_idx]:
        bias_mask[cur_neuron_id] = 1
        for pref_neuron_id in preference[layer_idx-1]:
            layer_mask[cur_neuron_id, pref_neuron_id] = 1
    
    weight_mask += layer_mask.reshape(-1).tolist()
    weight_mask += bias_mask.reshape(-1).tolist()

weight_mask = np.array(weight_mask)

In [None]:
# Boolean selector of the teacher parameters kept by the student; applied on
# CPU copies so it works wherever the teacher lives.
keep_mask = torch.from_numpy(weight_mask == 1)

# Student prior mean: the teacher's posterior mean (its trained weights).
mu_prior = teacher.weight.detach().cpu()[keep_mask].tolist()
# Student prior variance: the teacher's *posterior* variance, matching
# p(u) = N(mu'_posterior, Sigma'_posterior); the teacher's own prior
# variance is not what the formula prescribes.
sigma_prior = torch.exp(
    teacher.log_sigma_posterior).detach().cpu()[keep_mask].tolist()

In [None]:
student = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=student_3W_layers, 
                             mu_prior=mu_prior)
_ = student.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'fashionmnist/3W/distil_student')

call = callback(writer, test_dataset, 
                delimeter=delimeter, loss_function=torch.nn.CrossEntropyLoss)

In [None]:
student.fit(train_dataset, 
            callback=call, 
            epochs=epochs,
            batch_size=batch_size,
            loss_function=torch.nn.CrossEntropyLoss)

In [None]:
student = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=student_3W_layers, 
                             mu_prior=mu_prior, 
                             sigma_prior=sigma_prior)
_ = student.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'fashionmnist/3W/distil_student_all')

call = callback(writer, test_dataset, 
                delimeter=delimeter, loss_function=torch.nn.CrossEntropyLoss)

In [None]:
student.fit(train_dataset, 
            callback=call, 
            epochs=epochs,
            batch_size=batch_size,
            loss_function=torch.nn.CrossEntropyLoss)

### Удаление слоя

Так как ковариация отсутствует, то в данном случае формула преобразования параметров будет тривиальной:
$$
p(u) = \mathcal{N}\bigl(\mu_{posterior}', \Sigma_{posterior}'\bigr)
$$

In [None]:
# First choose which layer to remove
#      (here the second matrix W_2 is removed).
# Again, since there is no covariance, the second matrix is simply dropped
# and a subset of the rows of the first matrix is taken.
teacher_3W_layers_distil=[input_dim] + teacher_3W_layers + [output_dim]
student_2W_layers_distil=[input_dim] + student_2W_layers + [output_dim]

# For every layer draw (reproducibly) the teacher neurons the student keeps;
# the removed layer has student size 0 and therefore an empty selection.
np.random.seed(0)
preference = [np.random.choice(np.arange(0, teacher_layer_size), 
                               size=student_layer_size, 
                               replace=False ).tolist() \
              for teacher_layer_size, student_layer_size in zip(
                  teacher_3W_layers_distil, student_2W_layers_distil)]
 
# Flat 0/1 mask over the teacher's parameter vector (weight, then bias, for
# each layer in order), marking the entries that survive in the student.
weight_mask = []
for layer_idx in range(1, len(teacher_3W_layers_distil)):
    # torch.nn.Linear stores its weight as (out_features, in_features) and
    # the model flattens it row-major, so the mask must use the same layout;
    # an (in, out) mask would mark unrelated entries after flattening.
    layer_mask = np.zeros(
        shape=[teacher_3W_layers_distil[layer_idx], 
               teacher_3W_layers_distil[layer_idx-1]])
    
    bias_mask = np.zeros(teacher_3W_layers_distil[layer_idx])

    # An empty selection on the input side (the removed layer) means that
    # all inputs of this layer are kept.
    prev_pref = preference[layer_idx-1]
    if not prev_pref:
        prev_pref = np.arange(0, teacher_3W_layers_distil[layer_idx-1])
    for cur_neuron_id in preference[layer_idx]:
        bias_mask[cur_neuron_id] = 1
        for pref_neuron_id in prev_pref:
            layer_mask[cur_neuron_id, pref_neuron_id] = 1
        
    weight_mask += layer_mask.reshape(-1).tolist()
    weight_mask += bias_mask.reshape(-1).tolist()

weight_mask = np.array(weight_mask)

In [None]:
# Display the student and teacher layer widths side by side.
student_2W_layers_distil, teacher_3W_layers_distil

In [None]:
# Boolean selector of the teacher parameters kept by the student; applied on
# CPU copies so it works wherever the teacher lives.
keep_mask = torch.from_numpy(weight_mask == 1)

# Student prior mean: the teacher's posterior mean (its trained weights).
mu_prior = teacher.weight.detach().cpu()[keep_mask].tolist()
# Student prior variance: the teacher's *posterior* variance, matching
# p(u) = N(mu'_posterior, Sigma'_posterior); the teacher's own prior
# variance is not what the formula prescribes.
sigma_prior = torch.exp(
    teacher.log_sigma_posterior).detach().cpu()[keep_mask].tolist()

In [None]:
student = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=student_2W_layers, 
                             mu_prior=mu_prior)
_ = student.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'fashionmnist/2W/distil_student')

call = callback(writer, test_dataset, 
                delimeter=delimeter, loss_function=torch.nn.CrossEntropyLoss)

In [None]:
student.fit(train_dataset, 
            callback=call, 
            epochs=epochs,
            batch_size=batch_size,
            loss_function=torch.nn.CrossEntropyLoss)

In [None]:
student = BayesianNetworkFNN(input_dim=input_dim, 
                             output_dim=output_dim, 
                             layers=student_2W_layers, 
                             mu_prior=mu_prior, 
                             sigma_prior=sigma_prior)
_ = student.to(device)

In [None]:
writer = SummaryWriter(log_dir = 'fashionmnist/2W/distil_student_all')

call = callback(writer, test_dataset, 
                delimeter=delimeter, loss_function=torch.nn.CrossEntropyLoss)

In [None]:
student.fit(train_dataset, 
            callback=call, 
            epochs=epochs,
            batch_size=batch_size,
            loss_function=torch.nn.CrossEntropyLoss)