# Example
## Pretraining for Diffusion MRI

This example demonstrates how the code works by fine-tuning a pretrained network to segment
dMRI data.

In [1]:
from ExperimentModule import ExperimentModule
import ExperimentDataloader
import pytorch_lightning as pl

First, we load the model. To do this, we take the appropriate pretrained network from the "PretrainedModels" folder
and use it to initialise a new network. Here we choose a network for segmentation that is initialised from a classic
autoencoding pretraining without artificial distortions.

In [2]:
model = ExperimentModule(learning_mode='segmentation', pretrained='pre', distortions='nodist')
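How the pretrained weights are transferred into the new network is not shown in this example. As a rough, framework-free sketch of one common approach: copy every pretrained parameter whose name and shape match the new network, and leave everything else (for instance a task-specific output layer) at its fresh initialisation. The helper `transfer_matching` and all parameter names and shapes below are hypothetical and are not taken from `ExperimentModule`.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def transfer_matching(pretrained: Dict[str, Shape],
                      target: Dict[str, Shape]) -> Tuple[List[str], List[str]]:
    """Split the target's parameter names into (copied, kept_fresh).

    A parameter counts as transferable only if the pretrained network
    has an entry with the same name and the same shape.
    """
    copied, kept_fresh = [], []
    for name, shape in target.items():
        if pretrained.get(name) == shape:
            copied.append(name)
        else:
            kept_fresh.append(name)
    return copied, kept_fresh


# Hypothetical parameter shapes, for illustration only.
pretrained_shapes = {
    "unet.enc1.weight": (16, 1, 3, 3, 3),
    "unet.enc1.bias": (16,),
    "out_block.weight": (1, 16, 1, 1, 1),   # reconstruction head
}
segmentation_shapes = {
    "unet.enc1.weight": (16, 1, 3, 3, 3),
    "unet.enc1.bias": (16,),
    "out_block.weight": (4, 16, 1, 1, 1),   # segmentation head, different shape
    "out_block.bias": (4,),
}

copied, kept_fresh = transfer_matching(pretrained_shapes, segmentation_shapes)
print("copied from pretrained:", copied)
print("freshly initialised:  ", kept_fresh)
```

With real tensors the same name-and-shape check would decide which entries of a checkpoint are loaded into the new model.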

Now we load the data, which are automatically split into training, validation and test sets.

In [3]:
dataloader = ExperimentDataloader.DataModule(learning_mode='segmentation')
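The split itself happens inside `ExperimentDataloader.DataModule` and is not shown here. As a minimal sketch of what an automatic, reproducible split could look like, the hypothetical helper `split_subjects` below shuffles subject IDs with a fixed seed and cuts them into three disjoint sets. The fractions, the seed and the subject IDs are assumptions for illustration, not the module's actual settings.

```python
import random
from typing import Dict, List, Sequence


def split_subjects(subjects: Sequence[str],
                   val_frac: float = 0.2,
                   test_frac: float = 0.2,
                   seed: int = 0) -> Dict[str, List[str]]:
    """Shuffle subject IDs reproducibly and split them into three disjoint sets."""
    shuffled = list(subjects)
    random.Random(seed).shuffle(shuffled)  # local RNG, global state untouched

    n_val = int(len(shuffled) * val_frac)
    n_test = int(len(shuffled) * test_frac)
    return {
        "val": shuffled[:n_val],
        "test": shuffled[n_val:n_val + n_test],
        "train": shuffled[n_val + n_test:],
    }


# Hypothetical subject IDs.
subject_ids = [f"sub-{i:02d}" for i in range(10)]
splits = split_subjects(subject_ids)
print({name: len(ids) for name, ids in splits.items()})
```

Splitting by subject rather than by individual volume or slice keeps data from one subject out of more than one set.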

We use PyTorch Lightning to run the training.

In [4]:
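trainer = pl.Trainer(gpus=1,                       # train on a single GPU
                     max_epochs=10,
                     deterministic=True,           # favour reproducible runs
                     log_every_n_steps=10,
                     resume_from_checkpoint=None)  # start fresh; takes a checkpoint path or None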

trainer.fit(model, datamodule=dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type     | Params
---------------------------------------
0 | unet      | UNet3d   | 2.1 M 
1 | out_block | Conv3d   | 68    
2 | loss      | L1Loss   | 0     
3 | metric    | F1       | 0     
4 | metric2   | Accuracy | 0     
---------------------------------------
2.1 M     Trainable params
0         Non-trainable params
2.1 M     Total params
8.271     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]
 Loading data...


  edw = np.divide(dwi, meanb0)
100%|██████████| 3/3 [00:10<00:00,  3.55s/it]

Validation sanity check:   0%|          | 0/1 [00:00<?, ?it/s]


  f"The dataloader, {name}, does not have many workers which may be a bottleneck."




Validation Loss: tensor(0.5032, device='cuda:0')


                                                                      
 Loading data...


100%|██████████| 3/3 [00:03<00:00,  1.27s/it]
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"


Epoch 0:  50%|█████     | 1/2 [00:01<00:00,  1.27it/s, loss=0.503, v_num=0]

Validation Loss: tensor(0.4382, device='cuda:0')


Epoch 1:  50%|█████     | 1/2 [00:00<00:00,  3.37it/s, loss=0.471, v_num=0]   

Validation Loss: tensor(0.3793, device='cuda:0')


Epoch 2:  50%|█████     | 1/2 [00:00<00:00,  3.29it/s, loss=0.44, v_num=0]    

Validation Loss: tensor(0.3539, device='cuda:0')


Epoch 3:  50%|█████     | 1/2 [00:00<00:00,  3.28it/s, loss=0.419, v_num=0]  

Validation Loss: tensor(0.3344, device='cuda:0')


Epoch 4:  50%|█████     | 1/2 [00:00<00:00,  3.32it/s, loss=0.402, v_num=0]   

Validation Loss: tensor(0.3197, device='cuda:0')


Epoch 5:  50%|█████     | 1/2 [00:00<00:00,  3.31it/s, loss=0.388, v_num=0]   

Validation Loss: tensor(0.3110, device='cuda:0')


Epoch 6:  50%|█████     | 1/2 [00:00<00:00,  3.04it/s, loss=0.377, v_num=0]   

Validation Loss: tensor(0.3033, device='cuda:0')


Epoch 7:  50%|█████     | 1/2 [00:00<00:00,  3.21it/s, loss=0.368, v_num=0]   

Validat
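Because the per-epoch validation loss is only printed to the log, it can be handy to pull the numbers out for a quick check or a plot. The sketch below uses a hypothetical helper, `parse_validation_losses` (not part of the example code), that extracts the values with a regular expression. The sample lines are copied from the start of the output above.

```python
import re
from typing import List

# Matches lines such as: Validation Loss: tensor(0.5032, device='cuda:0')
VAL_LOSS = re.compile(r"Validation Loss: tensor\(([0-9]*\.?[0-9]+)")


def parse_validation_losses(log_text: str) -> List[float]:
    """Return all validation-loss values in the order they were printed."""
    return [float(value) for value in VAL_LOSS.findall(log_text)]


log_excerpt = """
Validation Loss: tensor(0.5032, device='cuda:0')
Validation Loss: tensor(0.4382, device='cuda:0')
Validation Loss: tensor(0.3793, device='cuda:0')
Validation Loss: tensor(0.3539, device='cuda:0')
"""

losses = parse_validation_losses(log_excerpt)
best = min(range(len(losses)), key=losses.__getitem__)
print(losses)
print(f"lowest validation loss {losses[best]} at list index {best}")
```

The first value comes from the validation sanity check that runs before training, so list index `i` corresponds to epoch `i - 1`.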

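Finally, the network can be tested. The results are written to an Excel file, which is created in the same folder. The call below passes the data module as `test_dataloaders`, which is deprecated since v1.4 according to the warning in the output; `trainer.test` also accepts it as `datamodule=dataloader`.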

In [5]:
trainer.test(ckpt_path='best', test_dataloaders=dataloader)

  "`trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6."
  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



 Loading data...


100%|██████████| 3/3 [00:03<00:00,  1.27s/it]
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Testing: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'AccuracyORMAE': 0.757314920425415,
 'AccuracyORMAE_epoch': 0.757314920425415,
 'f1ORMSE': 0.01439349539577961,
 'f1ORMSE_epoch': 0.01439349539577961,
 'test_loss': 0.2902667820453644,
 'test_loss_epoch': 0.2902667820453644}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 1/1 [00:00<00:00,  1.13it/s]


[{'test_loss': 0.2902667820453644,
  'test_loss_epoch': 0.2902667820453644,
  'f1ORMSE': 0.01439349539577961,
  'f1ORMSE_epoch': 0.01439349539577961,
  'AccuracyORMAE': 0.757314920425415,
  'AccuracyORMAE_epoch': 0.757314920425415}]
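`trainer.test` returns a list with one metrics dictionary per test dataloader, which is what the last output shows. Here is a small sketch of how one might tidy it for reporting. The values are copied from the output above; dropping the duplicate `_epoch` keys and rounding to four decimals are illustrative choices only.

```python
# Return value of trainer.test(...), copied from the output above.
results = [{'test_loss': 0.2902667820453644,
            'test_loss_epoch': 0.2902667820453644,
            'f1ORMSE': 0.01439349539577961,
            'f1ORMSE_epoch': 0.01439349539577961,
            'AccuracyORMAE': 0.757314920425415,
            'AccuracyORMAE_epoch': 0.757314920425415}]

# Only one test dataloader was used, so there is a single entry.
metrics = results[0]

# The "_epoch" entries repeat the same numbers; keep one copy of each metric.
summary = {name: round(value, 4)
           for name, value in metrics.items()
           if not name.endswith("_epoch")}

for name, value in summary.items():
    print(f"{name:>14}: {value}")
```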