## Load Dataset

In [2]:
from parallel import ParallelExecutor, TaskSpec
from utils import prepare_dataset
# Build the train/test splits via the project helper.
# NOTE(review): the arguments are presumably (dataset name, model type) — confirm in utils.prepare_dataset.
(trainset, testset) = prepare_dataset("MNIST", "MLP")

# Report the size of each split as a quick sanity check.
print(
    f"Datset loaded, trainset size: {len(trainset)}, testset size: {len(testset)}"
)

Datset loaded, trainset size: 60000, testset size: 10000


之前做过一个实验，用来判断让各层的 weight 更加 orthogonal 一些，是否会影响 aggregation 的结果。
在 Keras 里只需要使用 `tf.keras.regularizers.OrthogonalRegularizer` 即可。

在 PyTorch 里则需要重写 trainer，或者重写一个 loss function。

##  Trainer

In [None]:
from trainer import DefaultTrainer
import torch 
class TrainerWithOrthogonalRegularization(DefaultTrainer):
    """Trainer that adds an orthogonality penalty on every ``torch.nn.Linear`` weight.

    For each linear layer with weight ``W`` the penalty is the Frobenius norm
    ``||W^T W - I||_F``; the penalties of all layers are summed, scaled by the
    regularization factor and added to the task loss in ``train_step``.

    The trainer state used here (``step``, ``optimizer``, ``model``,
    ``criterion``, ``device``, ``metrics``, ``log_interval``, ``log_metrics``)
    is presumably provided by ``DefaultTrainer`` -- confirm against that class.
    """

    def __init__(self, orthogonal_regularization_factor, *args, **kwargs):
        """Store the penalty factor and forward all other arguments to the base class.

        Args:
            orthogonal_regularization_factor: Multiplier applied to the summed
                orthogonality penalty before it is added to the loss.
            *args: Positional arguments passed through to ``DefaultTrainer``.
            **kwargs: Keyword arguments passed through to ``DefaultTrainer``.
        """
        # Assigned before the base constructor runs so that configure_optimizers()
        # can read it even if the base constructor happens to call that hook.
        self.orthogonal_regularization = orthogonal_regularization_factor
        super().__init__(*args, **kwargs)

    def train_step(self, batch):
        """Run one optimization step on ``batch`` with the orthogonality penalty.

        Args:
            batch: An ``(inputs, targets)`` pair; both are moved to ``self.device``.
        """
        self.step += 1
        self.optimizer.zero_grad()
        inputs, targets = batch
        inputs, targets = inputs.to(self.device), targets.to(self.device)
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)
        # A zero factor contributes nothing to the loss, so skip the extra
        # matrix products entirely in that case.
        if self.orthogonal_regularization:
            loss = loss + self.orthogonal_regularization * self._orthogonal_regularization_loss()
        # the same in DefaultTrainer
        loss.backward()
        self.optimizer.step()
        metrics = self.metrics(outputs, targets)
        if self.step % self.log_interval == 0:
            # The logged loss includes the regularization term.
            self.log_metrics(metrics, loss)

    def _orthogonal_regularization_loss(self) -> torch.Tensor:
        """Return the summed ``||W^T W - I||_F`` over all ``torch.nn.Linear`` layers.

        Returns:
            A scalar tensor. When the model has no linear layers this is a
            zero scalar tensor, so callers always receive a tensor.
        """
        # Kept as an attribute (as before) in case other code inspects it.
        self.linear_layers = [
            module for module in self.model.modules() if isinstance(module, torch.nn.Linear)
        ]
        penalties = []
        for layer in self.linear_layers:
            weight = layer.weight
            # nn.Linear stores weight as (out_features, in_features), so the
            # Gram matrix W^T W is (in_features, in_features). The identity must
            # match the Gram matrix, not weight.shape[0], or non-square layers fail.
            gram = weight.t() @ weight
            identity = torch.eye(gram.shape[0], device=weight.device, dtype=weight.dtype)
            penalties.append(torch.norm(gram - identity, p='fro'))
        if not penalties:
            return torch.zeros(())
        # Built-in sum starts from 0, mirroring the original scalar accumulator.
        return sum(penalties)

    def configure_optimizers(self):
        """Return the base optimizer, disabling weight decay when the penalty is active.

        Returns:
            The optimizer produced by ``DefaultTrainer.configure_optimizers``.
        """
        optimizer = super().configure_optimizers()
        if self.orthogonal_regularization > 0.0:
            # NOTE(review): only the first parameter group is changed; any further
            # groups keep their weight decay -- confirm this is intended.
            optimizer.param_groups[0]['weight_decay'] = 0.0
        return optimizer