# <center>Deep Learning for Medical Imaging</center>
## <center>Dose Predicition Challenge</center>
### <center>Marie Bouvard, Amandine Allmang</center>
#### <center>March, 26th 2023</center>

In [45]:
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
import torchvision.transforms as transforms #.functional

from torch.utils.data import DataLoader
from helpers.dataset import DoseDataset, TestDataset
from helpers.constants import BATCH, LR, EPOCHS
from helpers.model_functions_mps import train_and_eval
from helpers.predict_and_submit import predict_and_submit_mps

In [46]:
data_dir = './MVA-Dose-Prediction/'
train_dir = data_dir + "train/"
test_dir = data_dir + "test/"
val_dir = data_dir + "validation/"

## Loading Data

### Train set

In [47]:
train_dataset = DoseDataset(train_dir)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True, pin_memory=True)

In [25]:
for item, d in train_dataset:
    print(item.size())
    break

torch.Size([12, 128, 128])


### Validation Set

In [48]:
val_dataset = DoseDataset(val_dir)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH, shuffle=True, pin_memory=True)

### Test Set

In [49]:
test_dataset = TestDataset(test_dir)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, pin_memory=True)

# Model Investigation

In [50]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

Using device: mps


## `U-Net` on 50 epochs

In [29]:
from models.UNet import UNet

In [51]:
generator = UNet()

EPOCH = 50
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(generator.parameters(),lr=LR)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

generator_unet, df, history = train_and_eval(generator, train_dataloader, val_dataloader, EPOCH, optimizer, criterion, lr_scheduler=scheduler)

Starting training...
--------------------------------------------------
Using device: mps
Epoch 1/50


100%|██████████| 975/975 [06:08<00:00,  2.65it/s]


Loss: 2.17
--------------------------------------------------
Epoch 2/50


100%|██████████| 975/975 [05:28<00:00,  2.97it/s]


Loss: 1.41
--------------------------------------------------
Epoch 3/50


100%|██████████| 975/975 [05:31<00:00,  2.94it/s]


Loss: 0.74
--------------------------------------------------
Epoch 4/50


100%|██████████| 975/975 [05:31<00:00,  2.94it/s]


Loss: 0.59
--------------------------------------------------
Epoch 5/50


100%|██████████| 975/975 [05:36<00:00,  2.89it/s]


Loss: 0.56
--------------------------------------------------
Epoch 6/50


100%|██████████| 975/975 [05:22<00:00,  3.02it/s]


Loss: 0.52
--------------------------------------------------
Epoch 7/50


100%|██████████| 975/975 [06:00<00:00,  2.70it/s]


Loss: 0.48
--------------------------------------------------
Epoch 8/50


100%|██████████| 975/975 [06:05<00:00,  2.67it/s]


Loss: 0.45
--------------------------------------------------
Epoch 9/50


100%|██████████| 975/975 [05:37<00:00,  2.89it/s]


Loss: 0.42
--------------------------------------------------
Epoch 10/50


100%|██████████| 975/975 [05:37<00:00,  2.89it/s]


Loss: 0.40
--------------------------------------------------
Epoch 11/50


100%|██████████| 975/975 [05:44<00:00,  2.83it/s]


Loss: 0.39
--------------------------------------------------
Epoch 12/50


100%|██████████| 975/975 [05:49<00:00,  2.79it/s]


Loss: 0.37
--------------------------------------------------
Epoch 13/50


100%|██████████| 975/975 [06:04<00:00,  2.67it/s]


Loss: 0.36
--------------------------------------------------
Epoch 14/50


100%|██████████| 975/975 [05:24<00:00,  3.00it/s]


Loss: 0.31
--------------------------------------------------
Epoch 15/50


100%|██████████| 975/975 [05:36<00:00,  2.90it/s]


Loss: 0.30
--------------------------------------------------
Epoch 16/50


100%|██████████| 975/975 [05:53<00:00,  2.76it/s]


Loss: 0.29
--------------------------------------------------
Epoch 17/50


100%|██████████| 975/975 [05:33<00:00,  2.92it/s]


Loss: 0.29
--------------------------------------------------
Epoch 18/50


100%|██████████| 975/975 [06:13<00:00,  2.61it/s]


Loss: 0.28
--------------------------------------------------
Epoch 19/50


100%|██████████| 975/975 [06:18<00:00,  2.58it/s]


Loss: 0.28
--------------------------------------------------
Epoch 20/50


100%|██████████| 975/975 [07:48<00:00,  2.08it/s]


Loss: 0.28
--------------------------------------------------
Epoch 21/50


100%|██████████| 975/975 [07:01<00:00,  2.31it/s]


Loss: 0.28
--------------------------------------------------
Epoch 22/50


100%|██████████| 975/975 [06:22<00:00,  2.55it/s]


Loss: 0.28
--------------------------------------------------
Epoch 23/50


100%|██████████| 975/975 [05:43<00:00,  2.84it/s]


Loss: 0.28
--------------------------------------------------
Epoch 24/50


100%|██████████| 975/975 [05:55<00:00,  2.75it/s]


Loss: 0.28
--------------------------------------------------
Epoch 25/50


100%|██████████| 975/975 [07:15<00:00,  2.24it/s]


Loss: 0.28
--------------------------------------------------
Epoch 26/50


100%|██████████| 975/975 [05:36<00:00,  2.90it/s]


Loss: 0.28
--------------------------------------------------
Epoch 27/50


100%|██████████| 975/975 [05:39<00:00,  2.87it/s]


Loss: 0.28
--------------------------------------------------
Epoch 28/50


100%|██████████| 975/975 [05:41<00:00,  2.85it/s]


Loss: 0.28
--------------------------------------------------
Epoch 29/50


100%|██████████| 975/975 [05:36<00:00,  2.90it/s]


Loss: 0.28
--------------------------------------------------
Epoch 30/50


100%|██████████| 975/975 [05:34<00:00,  2.91it/s]


Loss: 0.28
--------------------------------------------------
Epoch 31/50


100%|██████████| 975/975 [05:33<00:00,  2.93it/s]


Loss: 0.28
--------------------------------------------------
Epoch 32/50


100%|██████████| 975/975 [05:38<00:00,  2.88it/s]


Loss: 0.28
--------------------------------------------------
Epoch 33/50


100%|██████████| 975/975 [05:48<00:00,  2.80it/s]


Loss: 0.28
--------------------------------------------------
Epoch 34/50


100%|██████████| 975/975 [05:28<00:00,  2.97it/s]


Loss: 0.28
--------------------------------------------------
Epoch 35/50


100%|██████████| 975/975 [05:27<00:00,  2.98it/s]


Loss: 0.28
--------------------------------------------------
Epoch 36/50


100%|██████████| 975/975 [05:35<00:00,  2.90it/s]


Loss: 0.28
--------------------------------------------------
Epoch 37/50


100%|██████████| 975/975 [05:28<00:00,  2.97it/s]


Loss: 0.28
--------------------------------------------------
Epoch 38/50


100%|██████████| 975/975 [05:30<00:00,  2.95it/s]


Loss: 0.28
--------------------------------------------------
Epoch 39/50


100%|██████████| 975/975 [05:53<00:00,  2.76it/s]


Loss: 0.28
--------------------------------------------------
Epoch 40/50


100%|██████████| 975/975 [07:40<00:00,  2.12it/s]


Loss: 0.28
--------------------------------------------------
Epoch 41/50


100%|██████████| 975/975 [05:30<00:00,  2.95it/s]


Loss: 0.28
--------------------------------------------------
Epoch 42/50


100%|██████████| 975/975 [06:00<00:00,  2.70it/s]


Loss: 0.28
--------------------------------------------------
Epoch 43/50


100%|██████████| 975/975 [06:07<00:00,  2.65it/s]


Loss: 0.28
--------------------------------------------------
Epoch 44/50


100%|██████████| 975/975 [05:46<00:00,  2.82it/s]


Loss: 0.28
--------------------------------------------------
Epoch 45/50


100%|██████████| 975/975 [05:47<00:00,  2.81it/s]


Loss: 0.28
--------------------------------------------------
Epoch 46/50


100%|██████████| 975/975 [05:43<00:00,  2.84it/s]


Loss: 0.28
--------------------------------------------------
Epoch 47/50


100%|██████████| 975/975 [05:47<00:00,  2.80it/s]


Loss: 0.28
--------------------------------------------------
Epoch 48/50


100%|██████████| 975/975 [05:42<00:00,  2.85it/s]


Loss: 0.28
--------------------------------------------------
Epoch 49/50


100%|██████████| 975/975 [05:59<00:00,  2.71it/s]


Loss: 0.28
--------------------------------------------------
Epoch 50/50


100%|██████████| 975/975 [05:52<00:00,  2.77it/s]


Loss: 0.28
--------------------------------------------------
Training done. Took 8650.65s, 173.01s per epoch.

Starting evalution...
--------------------------------------------------
Evaluation on Training Set


100%|██████████| 975/975 [04:35<00:00,  3.54it/s]


Mean Absolute Error on Training set is 0.28
--------------------------------------------------
Evaluation on Validation Set


100%|██████████| 150/150 [00:41<00:00,  3.61it/s]

Mean Absolute Error on Validation set is 0.42
Evaluation done. Took 70.51s.





In [52]:
predict_and_submit_mps(generator, test_dataloader)

Making predictions...


100%|██████████| 1200/1200 [00:39<00:00, 30.33it/s]

Predictions made. Took 0.014087000017752871s.



