## LoadDataset

In [1]:
from parallel import ParallelExecutor, TaskSpec
import torch
from utils import prepare_dataset

之前做过一个实验，用来判断让各层的 weight 更加正交（orthogonal）是否会影响 aggregation 的结果。
在 Keras 里只需要使用

```python
tf.keras.regularizers.OrthogonalRegularizer
```

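在 PyTorch 里则需要重写 trainer，或者自己写一个 loss function。下面采用重写 trainer 的方式：对模型中每个 `nn.Linear` 层的权重 $W_l$ 计算正则项 $\lambda \sum_l \lVert W_l^\top W_l - I \rVert_F$（$\lambda$ 即 `factor`），并把它加到原本的 loss 上。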

## Trainer

In [2]:
from trainer import DefaultTrainer

class TrainerWithOrthogonalRegularization(DefaultTrainer):
    """`DefaultTrainer` that adds a soft orthogonality penalty on every `nn.Linear` weight.

    The optimized objective is ``criterion(outputs, targets) + factor * R`` with
    ``R = sum_l ||W_l^T W_l - I||_F`` over all ``torch.nn.Linear`` modules of the model.

    Args:
        factor: Weight of the orthogonality penalty. With ``0.0`` the penalty does not
            contribute to the gradient (it is still computed and reported).
        *args, **kwargs: Forwarded unchanged to ``DefaultTrainer.__init__``.
    """

    def __init__(self, factor, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.factor = factor

    def train_step(self, batch):
        """Run one optimization step on ``batch`` and return its metrics.

        Args:
            batch: An ``(inputs, targets)`` pair; both are moved to ``self.device``.

        Returns:
            The dict produced by ``self.metrics`` extended with ``'loss'`` (the task
            loss only, without the penalty) and ``'reg_loss'`` (the unscaled penalty).
        """
        self.step += 1
        self.optimizer.zero_grad()
        inputs, targets = batch
        inputs, targets = inputs.to(self.device), targets.to(self.device)
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)
        reg_loss = self._orthogonal_regularization_loss()
        # Only the backward pass sees the penalty; the reported 'loss' stays comparable
        # across different factors.
        total_loss = loss + self.factor * reg_loss
        # The remaining steps are the same as in DefaultTrainer.
        total_loss.backward()
        self.optimizer.step()
        metrics = self.metrics(outputs, targets)
        metrics['loss'] = loss.item()
        metrics['reg_loss'] = reg_loss.item()
        return metrics

    def _orthogonal_regularization_loss(self) -> torch.Tensor:
        """Return ``sum_l ||W_l^T W_l - I||_F`` over all ``torch.nn.Linear`` layers.

        An ``nn.Linear`` weight has shape ``(out_features, in_features)``, so the Gram
        matrix ``W^T W`` is ``(in_features, in_features)``. The penalty therefore pushes
        the *columns* of each weight towards an orthonormal set.

        Always returns a tensor. When the model has no Linear layers it returns a
        0-dim zero on ``self.device``, so callers can safely call ``.item()`` on it.
        """
        # Kept as an attribute (as before) for any code that inspects it.
        self.linear_layers = [
            module for module in self.model.modules() if isinstance(module, torch.nn.Linear)
        ]
        norms = []
        for layer in self.linear_layers:
            weight = layer.weight
            identity = torch.eye(weight.shape[1], device=weight.device, dtype=weight.dtype)
            norms.append(torch.linalg.matrix_norm(weight.t() @ weight - identity, ord='fro'))
        if not norms:
            # Previously this returned the Python float 0.0, which broke `.item()`.
            return torch.zeros((), device=self.device)
        # Sequential sum keeps the same accumulation order as the original loop.
        return sum(norms)

    def configure_optimizers(self):
        """Build the optimizer via ``DefaultTrainer`` and zero its weight decay
        when the orthogonality penalty is active (``factor > 0``).

        Presumably this avoids stacking L2 weight decay on top of the orthogonality
        penalty. All parameter groups are updated, not just the first.
        """
        optimizer = super().configure_optimizers()
        # Fixed: previously read `self.orthogonal_regularization`, which is never set.
        if self.factor > 0.0:
            for group in optimizer.param_groups:
                group['weight_decay'] = 0.0
        return optimizer

In [7]:
from lightning import seed_everything
from utils import prepare_dataset, build_mlp_model
import torchmetrics
import logger
def reg_factor_experiment(factor, *, seed=42, epochs=30, batch_size=256, lr=1e-3, device="cuda"):
    """Train an MLP on MNIST with orthogonal-regularization strength ``factor``.

    This is the task function of the ``ParallelExecutor`` sweep. It outputs:
    - a CSV log under ``different_reg_factors/`` named ``factor_{factor}``;
    - the last checkpoint under ``different_reg_factors/checkpoints/``, with the same name.

    Args:
        factor: Weight of the orthogonality penalty passed to the trainer.
        seed: Global seed for ``seed_everything`` (keyword-only, default 42).
        epochs: Number of training epochs (default 30).
        batch_size: Training batch size (default 256).
        lr: Adam learning rate (default 1e-3).
        device: Device passed to the trainer (default ``"cuda"``).

    Returns:
        Whatever ``trainer.fit`` returns.

    NOTE(review): the recorded run fails with ``NameError: name 'WandbLogger' is not
    defined`` inside ``trainer.py``'s ``save`` (reached because ``need_saving=True``).
    That must be fixed in ``trainer.py`` (presumably a missing import), not here.
    """
    seed_everything(seed)
    trainset, _ = prepare_dataset("MNIST", "MLP")
    model = build_mlp_model()
    metrics = torchmetrics.MetricCollection(
        [torchmetrics.Accuracy(task='multiclass', num_classes=10)]
    )
    loggers = logger.LoggerCollection(
        [logger.CSVLogger(name=f"factor_{factor}", saving_dir="different_reg_factors")]
    )
    trainer = TrainerWithOrthogonalRegularization(
        factor=factor,
        model=model,
        optimizer=torch.optim.Adam(model.parameters(), lr=lr),
        criterion=torch.nn.CrossEntropyLoss(),
        metrics=metrics,
        loggers=loggers,
        device=device,
        need_saving=True,
        saving_on='last',
        saving_dir="different_reg_factors/checkpoints",
        saving_name=f"factor_{factor}",
    )
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    try:
        return trainer.fit(train_loader, epochs=epochs)
    finally:
        # Runs even when `fit` raises (as in the recorded output), so the references
        # are dropped and cached GPU memory is released in the worker either way.
        del model, trainer
        torch.cuda.empty_cache()

In [8]:
from parallel import SpecGenerator
class FactorsSpecGenerator(SpecGenerator):
    """Yield one `TaskSpec` per orthogonal-regularization factor to sweep.

    Each spec has ``id=f"factor_{factor}"`` and passes ``factor`` as the only
    positional argument of the task function.

    Args:
        factors: Factors to sweep. Defaults to ``DEFAULT_FACTORS``. The values are
            copied into a list, so the generator can be iterated more than once
            (it is iterated for the preview and again by the executor).
    """

    DEFAULT_FACTORS = (0.0, 1e-4, 1e-3, 1e-2, 1e-1)

    def __init__(self, factors=None):
        self.factors = list(self.DEFAULT_FACTORS if factors is None else factors)

    def __iter__(self):
        for factor in self.factors:
            yield TaskSpec(
                id=f"factor_{factor}",
                args=(factor,),
                kwargs={},
            )

In [9]:
task_spec_generator = FactorsSpecGenerator()

# Dry run: list every TaskSpec the executor will receive in the next cell.
for spec in task_spec_generator:
    print(spec)

TaskSpec(id='factor_0.0', args=(0.0,), kwargs={})
TaskSpec(id='factor_0.0001', args=(0.0001,), kwargs={})
TaskSpec(id='factor_0.001', args=(0.001,), kwargs={})
TaskSpec(id='factor_0.01', args=(0.01,), kwargs={})
TaskSpec(id='factor_0.1', args=(0.1,), kwargs={})


In [10]:
# Launch the sweep: one `reg_factor_experiment` task per spec.
# NOTE(review): `gpu_fraction` is presumably the share of one GPU reserved per task — confirm in `parallel`.
executor = ParallelExecutor(gpu_fraction=0.3)
executor.run(
    spec_generator=task_spec_generator,
    func=reg_factor_experiment,
)

[36m(reg_factor_experiment pid=3464655)[0m Global seed set to 42
  0%|          | 0/235 [00:00<?, ?it/s])[0m 
Epoch 0 - loss: 2.3111:   0%|          | 1/235 [00:01<04:13,  1.08s/it]
Epoch 0 - loss: 2.2479:   1%|          | 2/235 [00:01<02:05,  1.86it/s]
Epoch 0 - loss: 2.2259:   1%|▏         | 3/235 [00:01<01:22,  2.80it/s]
Epoch 0 - loss: 1.7092:   5%|▍         | 11/235 [00:02<00:33,  6.76it/s]
Epoch 0 - loss: 1.7092:   5%|▌         | 12/235 [00:02<00:32,  6.87it/s]
[36m(reg_factor_experiment pid=3464659)[0m Global seed set to 42[32m [repeated 4x across cluster][0m
  0%|          | 0/235 [00:00<?, ?it/s][32m [repeated 4x across cluster][0m
Epoch 0 - loss: 2.3111:   0%|          | 1/235 [00:00<03:20,  1.17it/s][32m [repeated 4x across cluster][0m
Epoch 0 - loss: 0.7173:  16%|█▌        | 38/235 [00:06<00:27,  7.25it/s][32m [repeated 161x across cluster][0m
Epoch 0 - loss: 0.8835:  12%|█▏        | 28/235 [00:04<00:30,  6.79it/s][32m [repeated 2x across cluster][0m
Epoch 0 

Caught exception: [36mray::reg_factor_experiment()[39m (pid=3464655, ip=172.21.47.117)
  File "/tmp/ipykernel_3451343/1599710119.py", line 33, in reg_factor_experiment
  File "/home/hypeng/Research/notebooks_experiments/trainer.py", line 189, in fit
    self.save() if self.need_saving else None
  File "/home/hypeng/Research/notebooks_experiments/trainer.py", line 246, in save
    if isinstance(logger, WandbLogger):
NameError: name 'WandbLogger' is not defined. Terminating workers.


RayTaskError(NameError): [36mray::reg_factor_experiment()[39m (pid=3464655, ip=172.21.47.117)
  File "/tmp/ipykernel_3451343/1599710119.py", line 33, in reg_factor_experiment
  File "/home/hypeng/Research/notebooks_experiments/trainer.py", line 189, in fit
    self.save() if self.need_saving else None
  File "/home/hypeng/Research/notebooks_experiments/trainer.py", line 246, in save
    if isinstance(logger, WandbLogger):
NameError: name 'WandbLogger' is not defined

Epoch 29 - loss: 0.0093: 100%|██████████| 235/235 [00:33<00:00,  7.06it/s]
Epoch 29 - loss: 0.0526:  63%|██████▎   | 148/235 [00:21<00:12,  7.01it/s]
2024-01-05 19:05:16,100	ERROR worker.py:405 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): [36mray::reg_factor_experiment()[39m (pid=3464657, ip=172.21.47.117)
  File "/tmp/ipykernel_3451343/1599710119.py", line 33, in reg_factor_experiment
  File "/home/hypeng/Research/notebooks_experiments/trainer.py", line 189, in fit
    self.save() if self.need_saving else None
  File "/home/hypeng/Research/notebooks_experiments/trainer.py", line 246, in save
    if isinstance(logger, WandbLogger):
NameError: name 'WandbLogger' is not defined
2024-01-05 19:05:19,101	ERROR worker.py:405 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): [36mray::reg_factor_experiment()[39m (pid=3464656, ip=172.21.47.117)
  File "/tmp/ipykernel_3451343/1599710119.py", line 33, in reg_factor_experiment
  File "/home/hypeng/Resea