### Import libraries

In [None]:
! pip install torchmetrics

In [None]:
# Standard library
import os
import sys

# Third-party
import torch
from torch import nn
from torch import optim
from torchmetrics import JaccardIndex

# Binary IoU (Jaccard index) used to score predicted masks. `task` is documented as the
# lowercase literal "binary"; the constructor returns a BinaryJaccardIndex instance.
jaccardidx = JaccardIndex(task="binary")
# Presumably read by FitTrainEval to label the metric in its log
# (the training output prints "Train JaccardIndex" / "Val JaccardIndex").
jaccardidx.__name__ = 'JaccardIndex'

In [None]:
# Mount Google Drive: the project modules, the zipped dataset and the model
# checkpoints used later in this notebook all live under /content/drive/MyDrive.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Mounted at /content/drive


In [None]:
# Project folder on Drive that holds the helper modules imported below.
work_dir = "/content/drive/MyDrive/Deep_learning_course/Lung_segmentation/segmentation_problem/"
# Make the project modules importable. The membership check keeps the cell
# idempotent: re-running it does not append the same directory to sys.path again.
if work_dir not in sys.path:
    sys.path.append(work_dir)

from unet_model import CustomUnet
from fit_func_segmentation import FitTrainEval
from custom_load import LungDatasetSeg, LungDatasetLoader

### Extract data to local storage

In [None]:
path_to_zip_folder = '/content/drive/MyDrive/Deep_learning_course/Lung_segmentation/data/m3ex02-data.zip'
current_folder_path = '/content/localpath/'
!unzip -q $path_to_zip_folder -d $current_folder_path

### Define parameters

In [None]:
# ---- data parameters ----
# Root of the extracted archive (derived from the unzip target so the two cannot drift apart).
data_path: str = os.path.join(current_folder_path, 'm3ex02-data/')
path_test: str = os.path.join(data_path, 'Test/')
path_train: str = os.path.join(data_path, 'Train/')
path_eval: str = os.path.join(data_path, 'Val/')
# ---- training parameters ----
batch_size: int = 70
workers: int = 8  # NOTE(review): confirm the runtime has enough CPU cores for 8 loader workers
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch's RNG so any torch-driven randomness (e.g. loader shuffling) is reproducible.
SEED: int = 42
torch.manual_seed(SEED)
# ---- model parameters ----
# Checkpoint folder inside the project directory (derived from work_dir instead of repeating it).
path_save: str = os.path.join(work_dir, 'models/')
name_save: str = 'covid_segmentation_2.pkl'  # checkpoint written by this run
name_load: str = 'covid_segmentation_1.pkl'  # checkpoint the model is initialised from
in_channels: int = 1
out_channels: int = 1  # NOTE(review): not passed to CustomUnet in the model cell
metrics: list = [jaccardidx]
EPOCHS: int = 30

### Load data

In [None]:
# Train and validation loaders share every setting; only the source folder differs.
# NOTE(review): pin_memory_device receives a torch.device, while torch's DataLoader
# documents that argument as a str — confirm LungDatasetLoader converts it.
shared_loader_kwargs = dict(
    batch_size=batch_size,
    workers=workers,
    pin_memory_device=device,
)
data_train = LungDatasetLoader(path_train, **shared_loader_kwargs)
data_eval = LungDatasetLoader(path_eval, **shared_loader_kwargs)

### Instantiate model and load checkpoint

In [None]:
# Place every evaluation metric on the training device so it can consume the model's
# predictions there. Iterating `metrics` (rather than naming jaccardidx) stays correct
# if more metrics are added in the config cell, and the loop does not echo a repr.
for metric in metrics:
    metric.to(device)

BinaryJaccardIndex()

In [None]:
# Build the U-Net and warm-start it from a previously saved checkpoint.
model = CustomUnet(in_channels=in_channels)
# map_location remaps stored tensors onto the current device, so a checkpoint saved in a
# GPU session also loads on a CPU-only runtime. torch.load unpickles the file: only load
# checkpoints you trust.
state = torch.load(os.path.join(path_save, name_load), map_location=device)
# The model weights are stored under the "model" key of the saved object.
model.load_state_dict(state["model"])
model = model.to(device)

Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


### Train the model

In [None]:
# BCEWithLogitsLoss expects raw scores (logits) and applies the sigmoid internally.
# NOTE(review): in the logged epochs the train/val loss sits at ~0.609 and barely moves
# while Jaccard is ~0.95. If CustomUnet ends in a sigmoid (the
# mateuszbuda/brain-segmentation-pytorch hub model it pulls in appears to), this applies
# a second sigmoid and the loss has a floor far above zero. Confirm the model returns
# logits; otherwise switch to nn.BCELoss().
criterion = nn.BCEWithLogitsLoss()
# Presumably kept tiny because training resumes from an already-trained checkpoint.
learning_rate: float = 1e-6
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
# Destination of the checkpoint; judging by the training log ("Best metric" /
# "Saving model") it is rewritten whenever the validation metric improves.
checkpoint_path = os.path.join(path_save, name_save)

optimization = FitTrainEval(
    model=model,
    optimizer=optimizer,
    loss=criterion,
    metrics=metrics,
    device=device,
    path_save=checkpoint_path,
)

In [None]:
# Alternate training and validation for EPOCHS epochs. Because the best checkpoint is
# saved during the run, interrupting it (as in the log below, at epoch 11) keeps the
# best weights on disk — but `trained` is then never assigned.
trained = optimization.fit_train_eval(data_train, data_eval, epochs=EPOCHS)

Epoch 1/30
Training


  0%|          | 0/311 [00:00<?, ?it/s]

Evaluating


  0%|          | 0/78 [00:00<?, ?it/s]

Train loss: 0.6091580375597792 	 Val loss: 0.6099696709559514
Train JaccardIndex: 0.9475890503435657 	 Val JaccardIndex: 0.949759134115317

Best metric: -inf
Current metric: 0.949759134115317
Epoch time: 137.4778175354004 s
Saving model
Epoch 2/30
Training


  0%|          | 0/311 [00:00<?, ?it/s]

Evaluating


  0%|          | 0/78 [00:00<?, ?it/s]

Train loss: 0.6090796384780737 	 Val loss: 0.6098704437414805
Train JaccardIndex: 0.9480312336296131 	 Val JaccardIndex: 0.949452131222456

Best metric: 0.949759134115317
Current metric: 0.949452131222456
Epoch time: 137.51830291748047 s
Epoch 3/30
Training


  0%|          | 0/311 [00:00<?, ?it/s]

Evaluating


  0%|          | 0/78 [00:00<?, ?it/s]

Train loss: 0.6090440890030079 	 Val loss: 0.6098662507839692
Train JaccardIndex: 0.948480083053135 	 Val JaccardIndex: 0.9505055638460013

Best metric: 0.949759134115317
Current metric: 0.9505055638460013
Epoch time: 137.55609583854675 s
Saving model
Epoch 4/30
Training


  0%|          | 0/311 [00:00<?, ?it/s]

Evaluating


  0%|          | 0/78 [00:00<?, ?it/s]

Train loss: 0.6090534850715441 	 Val loss: 0.6098784880760388
Train JaccardIndex: 0.9486798759058741 	 Val JaccardIndex: 0.950832595427831

Best metric: 0.9505055638460013
Current metric: 0.950832595427831
Epoch time: 137.6877899169922 s
Saving model
Epoch 5/30
Training


  0%|          | 0/311 [00:00<?, ?it/s]

Evaluating


  0%|          | 0/78 [00:00<?, ?it/s]

Train loss: 0.6090444311068373 	 Val loss: 0.6098623841236799
Train JaccardIndex: 0.9486604915554501 	 Val JaccardIndex: 0.9510671175443209

Best metric: 0.950832595427831
Current metric: 0.9510671175443209
Epoch time: 137.4990394115448 s
Saving model
Epoch 6/30
Training


  0%|          | 0/311 [00:00<?, ?it/s]

Evaluating


  0%|          | 0/78 [00:00<?, ?it/s]

Train loss: 0.6090123273935348 	 Val loss: 0.6098503294663552
Train JaccardIndex: 0.9488809076152814 	 Val JaccardIndex: 0.9511415263017019

Best metric: 0.9510671175443209
Current metric: 0.9511415263017019
Epoch time: 137.5471532344818 s
Saving model
Epoch 7/30
Training


  0%|          | 0/311 [00:00<?, ?it/s]

Evaluating


  0%|          | 0/78 [00:00<?, ?it/s]

Train loss: 0.609018716781469 	 Val loss: 0.6098634012234516
Train JaccardIndex: 0.9488768964718393 	 Val JaccardIndex: 0.9508686195581387

Best metric: 0.9511415263017019
Current metric: 0.9508686195581387
Epoch time: 137.45852398872375 s
Epoch 8/30
Training


  0%|          | 0/311 [00:00<?, ?it/s]

Evaluating


  0%|          | 0/78 [00:00<?, ?it/s]

Train loss: 0.6089825852507563 	 Val loss: 0.6098667421402075
Train JaccardIndex: 0.9489952785792458 	 Val JaccardIndex: 0.9505611314223363

Best metric: 0.9511415263017019
Current metric: 0.9505611314223363
Epoch time: 137.41750955581665 s
Epoch 9/30
Training


  0%|          | 0/311 [00:00<?, ?it/s]

Evaluating


  0%|          | 0/78 [00:00<?, ?it/s]

Train loss: 0.6089608197810182 	 Val loss: 0.6098808332895621
Train JaccardIndex: 0.9491068247430194 	 Val JaccardIndex: 0.9505560161211551

Best metric: 0.9511415263017019
Current metric: 0.9505560161211551
Epoch time: 137.7095901966095 s
Epoch 10/30
Training


  0%|          | 0/311 [00:00<?, ?it/s]

Evaluating


  0%|          | 0/78 [00:00<?, ?it/s]

Train loss: 0.608998439702957 	 Val loss: 0.6098107695579529
Train JaccardIndex: 0.9491528054142305 	 Val JaccardIndex: 0.9510755561865293

Best metric: 0.9511415263017019
Current metric: 0.9510755561865293
Epoch time: 137.33726119995117 s
Epoch 11/30
Training


  0%|          | 0/311 [00:00<?, ?it/s]

KeyboardInterrupt: 