In [None]:
# Add the parent directory to the module search path.
# NOTE(review): the project-local imports used later (e.g. `train`, `models.unet`)
# presumably live there — confirm against the repository layout.
import sys
sys.path.append("..")

Import the necessary modules and classes

In [5]:
import os
import random

import numpy as np
import pandas as pd
import rasterio
from PIL import Image
import torch
import torchvision
import albumentations as aug
from sklearn.model_selection import train_test_split

from train import Trainer
from models.unet import ResNetUNet
from custom_datasets import SemSegImageFileDataFrameDataset
from transforms import AlbumentationTorchCompat
import constants
from losses import JaccardLoss

SEED = 17

# Seed every global RNG the notebook relies on, not just the train/val split,
# so that augmentation sampling, loader shuffling and weight initialisation
# are repeatable after a "Restart Kernel -> Run All".
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

Get path to the dataset

In [None]:
# Build the dataset path from its components so it resolves on every OS;
# a literal backslash separator only works on Windows.
data_dir = os.path.join("dataset", "train")

images_dir = os.path.join(data_dir, 'images')
mask_dir = os.path.join(data_dir, 'labels')

# os.listdir returns entries in arbitrary order. Sorting gives the seeded
# train/val split the same input order on every machine and file system.
images = sorted(os.listdir(images_dir))
print('Total Images:', len(images))

Split the image files into train/val sets

In [6]:
# Hold out 10% of the image files for validation; the fixed random_state keeps
# the split repeatable for a given input order.
train_images, val_images = train_test_split(
    images,
    test_size=0.10,
    random_state=SEED,
)

# Report the size of each subset.
for label, subset in (('Train Images:', train_images), ('Val Images:', val_images)):
    print(label, len(subset))

Train Images: 46452
Val Images: 5162


In [8]:
def create_bfp_mask(mask):
    """Return the first channel of ``mask`` as a PIL image.

    ``mask`` is converted with ``np.array`` and indexed as ``[:, :, 0]``,
    so it must be array-like with three dimensions.
    """
    mask_array = np.array(mask)
    first_channel = mask_array[:, :, 0]
    return Image.fromarray(first_channel)

def rasterio_to_pillow(full_img_file_path):
    """Open a raster file with rasterio and return its first three bands as a PIL image.

    Bands 1, 2 and 3 are read individually and stacked along a new last axis
    before the array is handed to ``Image.fromarray``.
    """
    with rasterio.open(full_img_file_path) as dataset:
        bands = [dataset.read(band_index) for band_index in (1, 2, 3)]
        stacked = np.stack(bands, axis=-1)

    return Image.fromarray(stacked)

Define data augmentation and torch transforms for train/val dataset

In [9]:
batch_size = 4

# --- Augmentation pipeline (training only) ---------------------------------
# Flips, transposes, rotations and crops come first.
geometric_ops = [
    aug.HorizontalFlip(p=0.5),
    aug.VerticalFlip(p=0.5),
    aug.Transpose(p=0.5),
    aug.RandomRotate90(p=0.5),
    aug.ShiftScaleRotate(p=0.5),
    aug.RandomSizedCrop(p=0.5, min_max_height=(180, 180), width=512, height=512),
]

# At most one of these contrast/brightness/noise operations is picked per sample.
pixel_ops = aug.OneOf([
    aug.CLAHE(p=0.5),
    aug.RandomContrast(p=0.5),
    aug.RandomBrightness(p=0.5),
    aug.RandomGamma(p=0.5),
    aug.GaussNoise(p=0.5),
    aug.ChannelShuffle(p=0.25),
    aug.Blur(p=0.3, blur_limit=2),
])

# The whole pipeline is itself applied with probability 0.5.
data_augmentation = aug.Compose(geometric_ops + [pixel_ops], p=0.5)

# --- Tensor-level transforms -----------------------------------------------
torch_transforms = torchvision.transforms.Normalize(
    mean=constants.IMAGENET_MEAN,
    std=constants.IMAGENET_STD,
)

# NOTE(review): with apply_torch_transforms_to_mask=True the Normalize transform
# is presumably applied to the mask as well as the image — confirm that this is
# intended, since it would rescale the label values.
train_transforms = AlbumentationTorchCompat(
    albu_transforms=data_augmentation,
    torch_transforms=torch_transforms,
    apply_torch_transforms_to_mask=True,
)

# Validation uses no augmentation, only the tensor-level transforms.
val_transforms = AlbumentationTorchCompat(
    albu_transforms=None,
    torch_transforms=torch_transforms,
    apply_torch_transforms_to_mask=True,
)

# --- Datasets ---------------------------------------------------------------
# Each dataset is driven by a one-column DataFrame of image file names.
train_frame = pd.DataFrame(train_images, columns=['image'])
val_frame = pd.DataFrame(val_images, columns=['image'])

train_dataset = SemSegImageFileDataFrameDataset(
    train_frame, images_dir, mask_dir, train_transforms, rasterio_to_pillow
)

val_dataset = SemSegImageFileDataFrameDataset(
    val_frame, images_dir, mask_dir, val_transforms, rasterio_to_pillow
)



Define train/val data loaders

In [None]:
# Training batches are reshuffled every epoch; validation keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
)

Select model architecture, optimizer and loss function

In [11]:
epochs = 25

# Three-class U-Net optimised with Adam.
model = ResNetUNet(n_class=3)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Multiclass Jaccard loss.
criterion = JaccardLoss(is_multiclass=True)

# Model files are saved under ./unet; no previously saved model is loaded.
trainer = Trainer(
    model,
    optimizer,
    model_save_path=os.path.join('.', 'unet'),
    load_saved_model=False,
)

Start Training

In [12]:
# Use the GPU only when CUDA is actually available, so the cell also runs on
# CPU-only machines; on a CUDA machine this is the same as use_gpu=True.
trainer.fit(
    criterion,
    train_loader,
    val_loader,
    use_gpu=torch.cuda.is_available(),
    epochs=epochs,
)

Epoch 1/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Saving best validation model.

Epoch: 1/25	Train Loss: 0.805125	Val Loss: 0.783987
Epoch 2/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 2/25	Train Loss: 0.795918	Val Loss: 0.783987
Epoch 3/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 3/25	Train Loss: 0.765271	Val Loss: 0.783987
Epoch 4/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 4/25	Train Loss: 0.810570	Val Loss: 0.783987
Epoch 5/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 5/25	Train Loss: 0.798464	Val Loss: 0.783987
Epoch 6/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 6/25	Train Loss: 0.797220	Val Loss: 0.783987
Epoch 7/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 7/25	Train Loss: 0.784075	Val Loss: 0.783987
Epoch 8/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 8/25	Train Loss: 0.798820	Val Loss: 0.783987
Epoch 9/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 9/25	Train Loss: 0.781193	Val Loss: 0.783987
Epoch 10/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 10/25	Train Loss: 0.796209	Val Loss: 0.783987
Epoch 11/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 11/25	Train Loss: 0.804846	Val Loss: 0.783987
Epoch 12/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 12/25	Train Loss: 0.805126	Val Loss: 0.783987
Epoch 13/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 13/25	Train Loss: 0.786607	Val Loss: 0.783987
Epoch 14/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 14/25	Train Loss: 0.777860	Val Loss: 0.783987
Epoch 15/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 15/25	Train Loss: 0.798997	Val Loss: 0.783987
Epoch 16/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 16/25	Train Loss: 0.773023	Val Loss: 0.783987

Epoch 17/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 17/25	Train Loss: 0.797978	Val Loss: 0.783987
Epoch 18/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 18/25	Train Loss: 0.809295	Val Loss: 0.783987
Epoch 19/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 19/25	Train Loss: 0.795974	Val Loss: 0.783987
Epoch 20/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 20/25	Train Loss: 0.806342	Val Loss: 0.783987
Epoch 21/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 21/25	Train Loss: 0.790568	Val Loss: 0.783987
Epoch 22/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 22/25	Train Loss: 0.801428	Val Loss: 0.783987
Epoch 23/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 23/25	Train Loss: 0.787593	Val Loss: 0.783987
Epoch 24/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 24/25	Train Loss: 0.800940	Val Loss: 0.783987
Epoch 25/25:


HBox(children=(FloatProgress(value=0.0, description='Training    ', max=10.0, style=ProgressStyle(description_…

adding training tensorboard images


HBox(children=(FloatProgress(value=0.0, description='Validation  ', max=5.0, style=ProgressStyle(description_w…

Epoch: 25/25	Train Loss: 0.802672	Val Loss: 0.783987
