# NEW TRAINER

In [1]:
from sage.config import load_config
from sage.training.trainer import MRITrainer

# Encoder presets that used to be switched by (un)commenting whole blocks.
# Pick one with ENCODER; keys besides 'encoder' are extra config overrides that preset needs.
ENCODER = 'vanillaconv'  # one of: 'vanillaconv', 'resnet', 'convit'

ENCODER_PRESETS = {
    'vanillaconv': {
        'encoder': {'name': 'vanillaconv', 'config': {'start_channels': 16}},
    },
    'resnet': {
        'encoder': {'name': 'resnet', 'config': {'start_channels': 64}},
    },
    'convit': {
        'encoder': {
            'name': 'convit',
            'config': {
                'embed_dim': 96,
                'local_up_to_layer': 4,
                'depth': 5,
                'num_heads': 8,
            },
        },
        'use_amp': False,  # the ConViT setup disabled mixed precision
        'reg_lr': 1e-4,    # learning rate written to cfg.reg_opt.lr
    },
}

# Base config: augmentation on, a single 200-epoch phase that updates only ['reg'].
cfg = load_config()
cfg.augment = True
cfg.phase_config = {
    'epochs': [200],
    'update': [['reg']]
}

preset = ENCODER_PRESETS[ENCODER]
cfg.encoder = preset['encoder']
if 'use_amp' in preset:
    cfg.use_amp = preset['use_amp']
if 'reg_lr' in preset:
    cfg.reg_opt.lr = preset['reg_lr']

trainer = MRITrainer(cfg)

Use cuda:0 as a device.
Output from encoder is 256.
Total Number of parameters: 3618069
TOTAL TRAIN 2322 | VALID 291
MIXED PRECISION:: True


In [2]:
import wandb
# Open a W&B run for the training below; the full config is stored with the run.
wandb.login()
run_settings = dict(
    project='3d_smri',
    config=vars(cfg),
    name='Vanilla Base(256) AUG',
    tags=['VanillaConv', 'baseline_aug', 'test'],
)
wandb.init(**run_settings)

wandb: Currently logged in as: 1pha (use `wandb login --relogin` to force relogin)


In [None]:
# Train with the config built above; per-epoch TRAIN/VALID metrics are printed below
# (and presumably logged to the W&B run opened in the previous cell).
trainer.run(cfg)

----- Epoch 1 / 200 (phase: 0) BEST MAE inf -----
[train] 445.1 sec [valid] 16.9 sec 
TRAIN
acc   : 0.2442  |  auc   : 0.4757  |  corr  : 0.1368
mae   : 52.6294  |  r2    : -8.0410  |  rmse  : 56.0100

VALID
acc   : 0.1100  |  auc   : 0.4674  |  corr  : 0.5807
mae   : 44.2429  |  r2    : -5.2823  |  rmse  : 47.5484


----- Epoch 2 / 200 (phase: 0) BEST MAE 44.243 -----
[train] 431.1 sec [valid] 12.9 sec 
TRAIN
acc   : 0.1555  |  auc   : 0.4349  |  corr  : 0.2955
mae   : 23.4327  |  r2    : -1.4910  |  rmse  : 27.9713

VALID
acc   : 0.1959  |  auc   : 0.3015  |  corr  : 0.8254
mae   : 11.2073  |  r2    : 0.5150  |  rmse  : 13.1111


----- Epoch 3 / 200 (phase: 0) BEST MAE 11.207 -----
[train] 418.3 sec [valid] 12.9 sec 
TRAIN
acc   : 0.1675  |  auc   : 0.4013  |  corr  : 0.6929
mae   : 10.8964  |  r2    : 0.3825  |  rmse  : 10.8840

VALID
acc   : 0.1409  |  auc   : 0.3640  |  corr  : 0.8607
mae   : 7.7556  |  r2    : 0.7342  |  rmse  : 9.4077


----- Epoch 4 / 200 (phase: 0) BEST MAE 7.

[train] 413.0 sec [valid] 12.7 sec 
TRAIN
acc   : 0.1490  |  auc   : 0.4349  |  corr  : 0.9639
mae   : 3.6764  |  r2    : 0.9286  |  rmse  : 4.5009

VALID
acc   : 0.1100  |  auc   : 0.3680  |  corr  : 0.9257
mae   : 6.2982  |  r2    : 0.8103  |  rmse  : 7.9184


----- Epoch 28 / 200 (phase: 0) BEST MAE 5.452 -----
[train] 410.1 sec [valid] 12.5 sec 
TRAIN
acc   : 0.1529  |  auc   : 0.4387  |  corr  : 0.9629
mae   : 3.6979  |  r2    : 0.9268  |  rmse  : 4.4337

VALID
acc   : 0.1340  |  auc   : 0.3836  |  corr  : 0.9322
mae   : 5.7937  |  r2    : 0.8472  |  rmse  : 7.0346


----- Epoch 29 / 200 (phase: 0) BEST MAE 5.452 -----
[train] 407.7 sec [valid] 12.6 sec 
TRAIN
acc   : 0.1503  |  auc   : 0.4333  |  corr  : 0.9593
mae   : 3.8677  |  r2    : 0.9188  |  rmse  : 4.2660

VALID
acc   : 0.1203  |  auc   : 0.3750  |  corr  : 0.9333
mae   : 5.4436  |  r2    : 0.8656  |  rmse  : 6.5616


----- Epoch 30 / 200 (phase: 0) BEST MAE 5.444 -----
[train] 408.7 sec [valid] 12.8 sec 
TRAIN
acc   : 0.

## Resume from checkpoint

In [3]:
# Run directory of the model to restore. Must be a raw string: in a normal literal the
# Windows separators turn "\b", "\r" and "\202" into control characters and corrupt the path.
# NOTE(review): absolute, machine-specific path -- consider deriving it from the config.
PREFIX = r'G:\My Drive\brain_data\workspace\result\models\20210818-1439'
SUFFIX = 'ep139_mae7.95.pt'

# Per-module weight files, presumably laid out as <run dir>/<module>/<SUFFIX> by the trainer.
# NOTE(review): resume_epoch is 0 although the weights come from epoch 139 -- confirm this
# is intended (fresh schedule) rather than a missing 139.
checkpoint = {
    'resume_epoch': 0,
    'models': {
        'encoder': f'{PREFIX}/encoder/{SUFFIX}',
        'domainer': f'{PREFIX}/domainer/{SUFFIX}',
        'regressor': f'{PREFIX}/regressor/{SUFFIX}',
    }
}

# W&B run id to continue; find it on the run's page at wandb.ai.
resume_id = '287lhmi7'

In [None]:
import wandb
# Re-attach to the existing W&B run so its charts continue instead of starting a new run.
# The run id goes in `id`; resume='must' makes wandb raise if that run does not exist
# instead of silently creating a fresh one.
wandb.login()
wandb.init(project='3d_smri',
           config=vars(cfg),
           id=resume_id,
           resume='must',
           )

In [None]:
# Continue training from the checkpoint paths defined above.
# NOTE(review): unlike the earlier `trainer.run(cfg)` call, cfg is not passed here --
# confirm MRITrainer.run falls back to the config given to its constructor.
trainer.run(checkpoint=checkpoint)

## Grid Search

In [1]:
import wandb

from sage.config import load_config
from sage.training.trainer import MRITrainer

# Grid sweep over the VanillaConv baseline: augmentation on/off x encoder start_channels
# (2 x 3 = 6 runs). run_sweep() below hard-codes the 'vanillaconv' encoder, so the sweep
# is named after it rather than "ConViT".
sweep_config = {
    "name": "VanillaConv",
    "method": "grid",
    "metric": {
        "name": "valid_mae",
        "goal": "minimize",
    },
    "parameters": {
        "augment": {"values": [True, False]},
        "start_channels": {"values": [8, 16, 32]},
    },
}

sweep_id = wandb.sweep(sweep_config, project='3d_smri')

Create sweep with ID: gr9jzr82
Sweep URL: https://wandb.ai/1pha/3d_smri/sweeps/gr9jzr82


In [2]:
from IPython.display import clear_output

def vanilla_run_labels(start_channels, augment):
    """Return the ``(run_name, tag)`` pair used to label a VanillaConv sweep run.

    Args:
        start_channels: VanillaConv ``start_channels`` value; must be 8, 16 or 32.
        augment: whether data augmentation is enabled for the run.

    Returns:
        e.g. ``('VanillaConv Base (256) AUG', 'baseline_aug')`` for ``(16, True)``.

    Raises:
        KeyError: if ``start_channels`` is not one of the tabulated values.
    """
    # Encoder output width shown in the run name (the trainer reports 256 for 16, 512 for 32).
    embed_size = {8: 128, 16: 256, 32: 512}[start_channels]
    name = f'VanillaConv Base ({embed_size})'
    if augment:
        return name + ' AUG', 'baseline_aug'
    return name, 'baseline'


def run_sweep():
    """Train one VanillaConv configuration chosen by the active W&B sweep agent.

    Builds the 200-epoch baseline config, overlays the sweep's ``augment`` and
    ``start_channels`` choices, labels the W&B run accordingly, trains, then clears
    the cell output so successive trials do not pile up.
    """
    with wandb.init(tags=['VanillaConv']) as run:
        # PRESET: single phase of 200 epochs that updates only ['reg'].
        cfg = load_config()
        cfg.phase_config = {
            'epochs': [200],
            'update': [['reg']]
        }
        cfg.encoder = {
            'name': 'vanillaconv',
            'config': {
                'start_channels': 8  # placeholder; overridden by the sweep below
            }
        }

        # LOAD COMBINATIONS: copy the sweep choices into a plain dict so 'augment' can be
        # popped off and the remaining keys merged into the encoder config.
        sweep_params = dict(wandb.config)
        cfg.augment = sweep_params.pop('augment')
        cfg.encoder.config.update(sweep_params)

        # DEFINE TRAINER
        trainer = MRITrainer(cfg)

        # WANDB SETUP
        name, tag = vanilla_run_labels(cfg.encoder.config.start_channels, cfg.augment)
        wandb.run.name = name
        run.tags = run.tags + (tag,)
        wandb.config.update(cfg)

        # RUN
        trainer.run(cfg)
        clear_output()

In [3]:
# Launch a sweep agent: it fetches each grid combination from W&B and calls run_sweep() for it.
wandb.agent(sweep_id, function=run_sweep)

wandb: Agent Starting Run: hn4sl9ie with config:
wandb: 	augment: True
wandb: 	start_channels: 32




Use cuda:0 as a device.
Output from encoder is 512.
Total Number of parameters: 14464037
TOTAL TRAIN 2322 | VALID 291
MIXED PRECISION:: True




----- Epoch 1 / 200 (phase: 0) BEST MAE inf -----
[train] 3016.8 sec [valid] 17.6 sec 
TRAIN
acc   : 0.3105  |  auc   : 0.5977  |  corr  : 0.0928
mae   : 43.9015  |  r2    : -5.9433  |  rmse  : 47.7639

VALID
acc   : 0.3952  |  auc   : 0.6430  |  corr  : 0.3976
mae   : 20.5329  |  r2    : -0.5693  |  rmse  : 23.2881


----- Epoch 2 / 200 (phase: 0) BEST MAE 20.533 -----
[train] 3350.5 sec [valid] 17.7 sec 
TRAIN
acc   : 0.2050  |  auc   : 0.5548  |  corr  : 0.5233
mae   : 14.9239  |  r2    : -0.0512  |  rmse  : 13.6341

VALID
acc   : 0.1821  |  auc   : 0.4804  |  corr  : 0.8247
mae   : 14.7517  |  r2    : 0.1697  |  rmse  : 17.1580


----- Epoch 3 / 200 (phase: 0) BEST MAE 14.752 -----
[train] 3352.7 sec [valid] 17.9 sec 
TRAIN
acc   : 0.1895  |  auc   : 0.5555  |  corr  : 0.7935
mae   : 9.1926  |  r2    : 0.6113  |  rmse  : 10.2257

VALID
acc   : 0.1821  |  auc   : 0.5223  |  corr  : 0.8823
mae   : 10.4311  |  r2    : 0.5421  |  rmse  : 12.7290


----- Epoch 4 / 200 (phase: 0) BEST MA

[train] 3049.8 sec [valid] 20.1 sec 
TRAIN
acc   : 0.1895  |  auc   : 0.6322  |  corr  : 0.9639
mae   : 3.8605  |  r2    : 0.9287  |  rmse  : 4.6400

VALID
acc   : 0.1821  |  auc   : 0.5914  |  corr  : 0.9347
mae   : 5.5458  |  r2    : 0.8632  |  rmse  : 6.8461


----- Epoch 28 / 200 (phase: 0) BEST MAE 5.546 -----
[train] 3047.9 sec [valid] 20.0 sec 
TRAIN
acc   : 0.1895  |  auc   : 0.6289  |  corr  : 0.9651
mae   : 3.7193  |  r2    : 0.9310  |  rmse  : 4.5968

VALID
acc   : 0.1821  |  auc   : 0.5973  |  corr  : 0.9364
mae   : 5.4397  |  r2    : 0.8642  |  rmse  : 6.8299


----- Epoch 29 / 200 (phase: 0) BEST MAE 5.440 -----
[train] 3045.0 sec [valid] 20.1 sec 
TRAIN
acc   : 0.1895  |  auc   : 0.6341  |  corr  : 0.9678
mae   : 3.5678  |  r2    : 0.9365  |  rmse  : 4.3655

VALID
acc   : 0.1821  |  auc   : 0.6095  |  corr  : 0.9347
mae   : 6.1310  |  r2    : 0.8231  |  rmse  : 7.5990


----- Epoch 30 / 200 (phase: 0) BEST MAE 5.440 -----
[train] 3046.4 sec [valid] 20.2 sec 
TRAIN
acc   

[train] 3039.6 sec [valid] 20.1 sec 
TRAIN
acc   : 0.1895  |  auc   : 0.6395  |  corr  : 0.9796
mae   : 2.7654  |  r2    : 0.9588  |  rmse  : 3.4078

VALID
acc   : 0.1821  |  auc   : 0.6153  |  corr  : 0.9402
mae   : 7.3044  |  r2    : 0.7707  |  rmse  : 8.7905


----- Epoch 54 / 200 (phase: 0) BEST MAE 5.214 -----
[train] 3064.4 sec [valid] 20.2 sec 
TRAIN
acc   : 0.1895  |  auc   : 0.6367  |  corr  : 0.9764
mae   : 3.0574  |  r2    : 0.9531  |  rmse  : 3.4685

VALID
acc   : 0.1821  |  auc   : 0.6118  |  corr  : 0.9417
mae   : 5.0778  |  r2    : 0.8820  |  rmse  : 6.4173


----- Epoch 55 / 200 (phase: 0) BEST MAE 5.078 -----
[train] 3053.6 sec [valid] 20.2 sec 
TRAIN
acc   : 0.1895  |  auc   : 0.6444  |  corr  : 0.9795
mae   : 2.8177  |  r2    : 0.9589  |  rmse  : 3.4269

VALID
acc   : 0.1821  |  auc   : 0.6126  |  corr  : 0.9412
mae   : 5.8437  |  r2    : 0.8445  |  rmse  : 7.3266


----- Epoch 56 / 200 (phase: 0) BEST MAE 5.078 -----
[train] 3053.7 sec [valid] 20.2 sec 
TRAIN
acc   

--- Logging error ---
Traceback (most recent call last):
  File "<ipython-input-2-2ad6a4146aca>", line 51, in run_sweep
    trainer.run(cfg)
  File "G:\My Drive\brain_data\workspace\3DCNN_sMRI\sage\training\trainer.py", line 141, in run
    model_name = f'ep{str(e).zfill(3)}_mae{results.valid_mae:.2f}.pt'
  File "C:\Users\pha\anaconda3\envs\cnn\lib\site-packages\wandb\sdk\wandb_run.py", line 1141, in log
    self.history._row_add(data)
  File "C:\Users\pha\anaconda3\envs\cnn\lib\site-packages\wandb\sdk\wandb_history.py", line 44, in _row_add
    self._flush()
  File "C:\Users\pha\anaconda3\envs\cnn\lib\site-packages\wandb\sdk\wandb_history.py", line 59, in _flush
    self._callback(row=self._data, step=self._step)
  File "C:\Users\pha\anaconda3\envs\cnn\lib\site-packages\wandb\sdk\wandb_run.py", line 885, in _history_callback
    self._backend.interface.publish_history(
  File "C:\Users\pha\anaconda3\envs\cnn\lib\site-packages\wandb\sdk\interface\interface.py", line 224, in publish_his

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

Run sdt0hqdu errored: OSError(22, 'Invalid argument')
wandb: ERROR Run sdt0hqdu errored: OSError(22, 'Invalid argument')
wandb: Sweep Agent: Waiting for job.
wandb: Sweep Agent: Exiting.


# NAIVE LEARNING

In [1]:
import os
import wandb

from sage.config import *
from sage.training.runner import run
from utils.misc import seed_everything, get_today
from IPython.display import clear_output

In [2]:
# Naive-learning ResNet baseline: load defaults, seed RNGs, then apply overrides.
cfg = load_config()
seed_everything(cfg.seed)

NAIVE_OVERRIDES = {
    # data
    'registration': 'mni',
    'unused_src': [],  # no source excluded
    'plot': False,
    # training
    'augment': True,
    'epochs': 200,
    'model_name': 'resnet_no_maxpool',
    'early_patience': 20,
    'start_channels': 32,
    'weight_decay': 0.01,
}
for option, setting in NAIVE_OVERRIDES.items():
    setattr(cfg, option, setting)

In [3]:
# Optional: resume a previous resnet_no_maxpool run. Uncomment and call
# `run(cfg, checkpoint)` instead of `run(cfg)` (as the EfficientNet section does).
# checkpoint = {
#     'resume_epoch': 45,
#     'path': '../result/models/savetest20210608-1451/resnet_no_maxpool_ep45_mae5.86.pth'
# }

In [3]:
name = 'RES(256) DALLAS ADD'
# NOTE: extends cfg.RESULT_PATH in place, so re-running this cell without first re-running
# the config cell nests another "<name><date>" directory level.
run_folder = name + get_today()
cfg.RESULT_PATH = os.path.join(cfg.RESULT_PATH, run_folder)

In [4]:
# Open a W&B run named after this experiment, storing the full config.
wandb.login()
init_kwargs = {'project': '3d_smri', 'config': vars(cfg), 'name': name}
wandb.init(**init_kwargs)

wandb: Currently logged in as: 1pha (use `wandb login --relogin` to force relogin)


In [None]:
# Train with the naive-learning config assembled above (sage.training.runner.run).
run(cfg)

# EFFICIENTNET

In [1]:
import os
import wandb

from sage.config import *
from sage.training.runner import run
from utils.misc import seed_everything, get_today
from IPython.display import clear_output

In [2]:
# EfficientNet-b0 baseline: load defaults, seed RNGs, then apply overrides.
cfg = load_config()
seed_everything(cfg.seed)

EFFICIENTNET_OVERRIDES = {
    # data
    'registration': 'mni',
    'unused_src': [],  # no source excluded
    'plot': False,
    # training
    'augment': True,
    'epochs': 200,
    'model_name': 'efficientnet-b0',
    'early_patience': 30,
}
for option, setting in EFFICIENTNET_OVERRIDES.items():
    setattr(cfg, option, setting)

In [3]:
# Weights saved at epoch 77 of the earlier run; `resume_epoch` tells the runner where to continue.
checkpoint = dict(
    resume_epoch=78,
    path='../result/models/EFFICIENTNET TEST20210613-2121/efficientnet-b0_ep77_mae7.69.pth',
)

In [4]:
name = 'EFFICIENTNET TEST RESUME 78'
# NOTE: extends cfg.RESULT_PATH in place, so re-running this cell without first re-running
# the config cell nests another "<name><date>" directory level.
run_folder = name + get_today()
cfg.RESULT_PATH = os.path.join(cfg.RESULT_PATH, run_folder)

In [5]:
# Open a W&B run for the resumed EfficientNet training, storing the full config.
wandb.login()
init_kwargs = dict(project='3d_smri', config=vars(cfg), name=name)
wandb.init(**init_kwargs)

wandb: Currently logged in as: 1pha (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.10.32 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


In [6]:
# Resume EfficientNet training from `checkpoint`; the log below starts at epoch 79
# with resume_epoch=78.
run(cfg, checkpoint)

Model Efficientnet-b0 is selected.
TOTAL TRAIN 2322 | VALID 291
Epoch 79 / 200, BEST MAE inf
[train] 297.6 sec [valid] 11.9 sec 
TRAIN :: LOSS 48.788 | RMSE 7.00 | MAE 5.14 | R2 0.86 | CORR 0.93
VALID :: LOSS 113.676 | RMSE 10.66 | MAE 8.78 | R2 0.69 | CORR 0.90
Saving ...
Epoch 80 / 200, BEST MAE 8.777
[train] 288.4 sec [valid] 12.2 sec 
TRAIN :: LOSS 49.436 | RMSE 7.05 | MAE 5.21 | R2 0.86 | CORR 0.93
VALID :: LOSS 124.287 | RMSE 11.14 | MAE 9.09 | R2 0.66 | CORR 0.89
Saving ...
Epoch 81 / 200, BEST MAE 8.777
[train] 296.3 sec [valid] 12.5 sec 
TRAIN :: LOSS 50.444 | RMSE 7.12 | MAE 5.17 | R2 0.85 | CORR 0.92
VALID :: LOSS 104.505 | RMSE 10.22 | MAE 8.28 | R2 0.71 | CORR 0.90
Saving ...
Epoch 82 / 200, BEST MAE 8.282
[train] 280.1 sec [valid] 11.4 sec 
TRAIN :: LOSS 46.338 | RMSE 6.82 | MAE 5.07 | R2 0.86 | CORR 0.93
VALID :: LOSS 96.752 | RMSE 9.94 | MAE 8.02 | R2 0.73 | CORR 0.87
Saving ...
Epoch 83 / 200, BEST MAE 8.021
[train] 287.2 sec [valid] 12.0 sec 
TRAIN :: LOSS 47.388 | RM

[train] 287.1 sec [valid] 11.4 sec 
TRAIN :: LOSS 30.440 | RMSE 5.53 | MAE 4.05 | R2 0.91 | CORR 0.95
VALID :: LOSS 99.226 | RMSE 9.94 | MAE 8.11 | R2 0.73 | CORR 0.90
Saving ...
Epoch 119 / 200, BEST MAE 7.410
[train] 293.7 sec [valid] 11.5 sec 
TRAIN :: LOSS 30.661 | RMSE 5.55 | MAE 4.05 | R2 0.91 | CORR 0.95
VALID :: LOSS 84.560 | RMSE 9.17 | MAE 7.47 | R2 0.77 | CORR 0.90
Saving ...
Epoch 120 / 200, BEST MAE 7.410
[train] 305.8 sec [valid] 11.6 sec 
TRAIN :: LOSS 30.559 | RMSE 5.54 | MAE 4.06 | R2 0.91 | CORR 0.95
VALID :: LOSS 94.010 | RMSE 9.69 | MAE 7.95 | R2 0.74 | CORR 0.90
Saving ...
Epoch 121 / 200, BEST MAE 7.410
[train] 286.3 sec [valid] 11.8 sec 
TRAIN :: LOSS 29.863 | RMSE 5.48 | MAE 4.07 | R2 0.91 | CORR 0.96
VALID :: LOSS 104.936 | RMSE 10.07 | MAE 8.24 | R2 0.72 | CORR 0.90
Saving ...
Epoch 122 / 200, BEST MAE 7.410
[train] 282.7 sec [valid] 11.3 sec 
TRAIN :: LOSS 26.777 | RMSE 5.19 | MAE 3.79 | R2 0.92 | CORR 0.96
VALID :: LOSS 105.450 | RMSE 10.12 | MAE 8.24 | R2 0

VBox(children=(Label(value=' 0.03MB of 0.03MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_rmse,5.10696
train_mae,3.72215
train_r2,0.92423
train_corr,0.96142
train_loss,25.92911
valid_loss,99.7807
valid_rmse,9.9606
valid_mae,8.27774
valid_r2,0.72663
valid_corr,0.89647


0,1
train_rmse,███▇▇▇▆▇▅▆▇▅▅▆▅▄▅▄▅▄▃▃▃▄▃▃▃▃▃▃▂▃▂▂▂▂▁▁▁▁
train_mae,████▇▇▆▆▅▆▇▅▅▅▅▅▅▄▅▄▃▃▃▄▃▃▃▃▃▃▂▃▂▂▂▂▁▂▁▁
train_r2,▁▁▁▂▃▂▄▃▅▄▃▄▅▄▄▅▅▅▅▆▆▆▆▅▆▆▆▆▆▆█▆▇▇▇▇████
train_corr,▁▁▁▂▃▂▄▃▅▄▃▄▅▄▄▅▅▅▅▆▆▆▆▅▆▆▆▆▆▆█▆▇▇▇▇████
train_loss,███▇▆▇▅▆▄▅▆▅▄▅▅▄▄▄▄▃▃▃▃▄▃▃▃▃▃▃▁▃▂▂▂▂▁▁▁▁
valid_loss,▆█▅▆▆▄▃▅▅▃▄▃▄▂▄█▄▇▁▄▃▄▂▆▂▄▂▅▂▃▅▄▅▃▆▄▃▄▃▄
valid_rmse,▆█▅▇▆▄▃▅▅▃▄▃▅▂▄█▄▇▁▄▃▄▂▆▂▄▂▅▂▃▅▄▅▄▆▄▃▄▂▄
valid_mae,▆▇▄▆▆▄▃▅▄▃▃▂▄▂▄█▃▆▁▃▂▄▂▅▁▃▁▄▁▃▄▄▄▃▆▃▃▃▂▄
valid_r2,▃▁▄▂▃▅▆▄▅▆▅▆▄▇▅▁▅▂█▅▆▅▇▃▇▅▇▅▇▆▅▅▄▆▃▆▆▆▇▅
valid_corr,▅▄▅▂▄▂▆▄▄▂▅▅▂▅▆▄▄▅▄▆▃▄▄▆▅▄▇▅▅▆▆▅█▄▁▅▇▆▅▅


EfficientNet3D(
  (_conv_stem): Conv3dStaticSamePadding(
    1, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), bias=False
    (static_padding): ZeroPad2d(padding=(0, 1, 0, 1, 0, 1), value=0.0)
  )
  (_bn0): BatchNorm3d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock3D(
      (_depthwise_conv): Conv3dStaticSamePadding(
        32, 32, kernel_size=(3, 3, 3), stride=[2, 2, 2], groups=32, bias=False
        (static_padding): ZeroPad2d(padding=(0, 1, 0, 1, 0, 1), value=0.0)
      )
      (_bn1): BatchNorm3d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv3dStaticSamePadding(
        32, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv3dStaticSamePadding(
        8, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv3dStaticS

In [7]:
# Close the W&B run opened above so remaining data is flushed and the run is marked finished.
wandb.finish()

## SWEEP

In [2]:
import wandb

# Hyper-parameter grid to explore; every combination is one run (2 x 2 x 3 = 12).
SWEEP_GRID = {
    'optimizer': ['adam', 'adamW'],
    'weight_decay': [0, 0.01],
    'start_channels': [8, 16, 32],
}

sweep_config = {
    "name": "Resnet",
    "method": "grid",  # alternative: "random"
    # Metric the sweep optimises.
    "metric": {"name": "valid_mae", "goal": "minimize"},
    # Combination of hyper-parameters to try.
    "parameters": {param: {"values": choices} for param, choices in SWEEP_GRID.items()},
}

sweep_id = wandb.sweep(sweep_config, project='3d_smri')

Create sweep with ID: i7iv565s
Sweep URL: https://wandb.ai/1pha/3d_smri/sweeps/i7iv565s


In [3]:
def run_sweep():
    """Run one ResNet trial for the active W&B sweep.

    Starts from the naive-learning baseline (MNI registration, no excluded source,
    augmentation, 100 epochs, early-stopping patience 20), overlays the sweep's
    optimizer / weight_decay / start_channels choice, names the run after it and trains.
    """
    with wandb.init():
        # Baseline shared by every trial.
        cfg = load_config()
        seed_everything(cfg.seed)
        cfg.registration = 'mni'
        cfg.unused_src = []
        cfg.plot = False
        cfg.augment = True
        cfg.epochs = 100
        cfg.model_name = 'resnet_no_maxpool'
        cfg.early_patience = 20

        # Overlay this trial's sweep choices (copied into a plain dict first).
        trial_params = dict(wandb.config)
        cfg.update(trial_params)

        trial_name = f'SWEEP_OPT-{cfg.optimizer}_WD{cfg.weight_decay}_SC{cfg.start_channels}'
        cfg.RESULT_PATH = os.path.join(cfg.RESULT_PATH, trial_name + get_today())

        wandb.run.name = trial_name
        wandb.config.update(cfg)

        run(cfg)
        clear_output()

In [None]:
# Launch a sweep agent that calls run_sweep() for each grid combination W&B hands out.
wandb.agent(sweep_id, function=run_sweep)

# UNLEARNING

In [1]:
import os
import wandb

from sage.config import *
from sage.training import unlearner

from utils.misc import seed_everything, get_today

In [2]:
cfg = load_config()
seed_everything(cfg.seed)

# Data: MNI registration, leaving out both OASIS cohorts.
cfg.registration = 'mni'
cfg.unused_src = ['Oasis1', 'Oasis3']

# Unlearning setup.
cfg.unlearn = True
cfg.unlearn_cfg.encoder.name = 'resnet'
cfg.unlearn_cfg.opt_conf.point = 0
cfg.loss = 'rmse'

# NOTE(review): 4 is presumably the total number of source datasets -- confirm.
NUM_SOURCE_DBS = 4
cfg.unlearn_cfg.domainer.num_dbs = NUM_SOURCE_DBS - len(cfg.unused_src)

In [3]:
name = 'Grad_zero_test_venvcnn_amp_correct'
# NOTE: extends cfg.RESULT_PATH in place, so re-running this cell without first re-running
# the config cell nests another "<name><date>" directory level.
run_folder = name + get_today()
cfg.RESULT_PATH = os.path.join(cfg.RESULT_PATH, run_folder)

In [4]:
# W&B tags: an 'amp' marker plus the domainer's num_dbs setting.
num_db_tag = f'num_db-{cfg.unlearn_cfg.domainer.num_dbs}'
tags = ['amp', num_db_tag]

In [None]:
# Open a W&B run in the unlearning project, storing the full config and tags.
wandb.login()
wandb.init(**{
    'project': '3d_smri_unlearning',
    'config': vars(cfg),
    'name': name,
    'tags': tags,
})

In [None]:
# Run the training loop in sage.training.unlearner with the config assembled above.
unlearner.run(cfg)

## BACKWARD: autograd gradient-accumulation sanity check

In [30]:
# Tiny autograd graph shared by the next two cells: a -> b = a**2 -> {c = 2b, d = 4b}.
# `t` (presumably torch.tensor) was not defined anywhere in this section, so the cell
# failed on a fresh kernel; import torch explicitly instead.
import torch

a = torch.tensor(2., requires_grad=True)
b = a ** 2
c = 2 * b  # dc/da = 4a = 8 at a = 2
d = 4 * b  # dd/da = 8a = 16 at a = 2

In [31]:
# Backprop c = 2 * a**2: a.grad becomes 4a = 8 at a = 2.
# retain_graph=True keeps the graph alive because the next cell backprops d = 4 * b,
# which goes through the same b = a**2 node.
c.backward(retain_graph=True)
print(a.grad)

tensor(8.)


In [32]:
# Second backward into the same leaf: gradients accumulate in a.grad rather than
# replacing it, so the result is 8 (from c) + 16 (from d) = 24.
d.backward()
a.grad

tensor(24.)

In [33]:
# Control: a fresh leaf tensor has no accumulated .grad, so backprop through d = 4 * a**2
# alone gives dd/da = 8a = 16 at a = 2 (versus the accumulated 24 above).
# Uses torch.tensor explicitly because `t` is not defined in this section.
import torch

a = torch.tensor(2., requires_grad=True)
b = a ** 2
d = 4 * b
d.backward()
print(a.grad)

tensor(16.)
