# Training UNet for drop segmentation

In [9]:
import torch
from checkpoints import *
from measures import *

In [10]:
# Run configuration: optimisation settings, data locations, device selection
# and checkpoint options, all read by the cells below through `config[...]`.
config = dict(
    lr=0.001,
    batch_size=8,
    epochs=20,
    threshold=0.3,
    init_from_checkpoint=False,
    image_dir="../../data/stereo/train/image",
    mask_dir="../../data/stereo/train/mask",
    # Use the GPU when torch reports one, otherwise fall back to the CPU.
    device="cuda" if torch.cuda.is_available() else "cpu",
    checkpoint_dir="checkpoints",
    checkpoint=None,  # if None loads last saved checkpoint
    print_model=False,
    seed=3407,  # if None uses random seed
)
print("Training using {}".format(config["device"]))

Training using cuda


In [11]:
# Set seed
# Seed every random number generator this notebook can draw from, not only
# torch: Python's `random` is used further down to pick a preview sample, and
# numpy is seeded as well.
# NOTE(review): whether the dataset's cropping/splitting uses `random`, numpy
# or torch is not visible from this notebook - confirm in dataset.py.
if config["seed"] is not None:
    import random

    import numpy as np

    random.seed(config["seed"])
    np.random.seed(config["seed"])
    torch.manual_seed(config["seed"])

In [12]:
# Load train and val datasets and prepare loaders

from torch.utils.data import DataLoader
import importlib
import dataset
importlib.reload(dataset)
from dataset import WaterDropDataset
# Build the full dataset from the configured image/mask directories.
dataset = WaterDropDataset(
    image_dir=config["image_dir"],
    mask_dir=config["mask_dir"],
    threshold=config["threshold"],
    crop_shape=(256, 256)
)

# Explicit length check instead of relying on the object's truthiness.
assert len(dataset) > 0, "Dataset is empty!"

# NOTE(review): 0.1 is presumably the validation fraction (the recorded output
# shows a 611 / 67 split of 678 images) - confirm in WaterDropDataset.random_split.
train_dataset, val_dataset = dataset.random_split(0.1)
train_loader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=True
)
# Validation batches are not shuffled: the metrics are averaged per batch, so a
# fixed batch composition keeps the reported numbers comparable between epochs.
val_loader = DataLoader(
    val_dataset,
    batch_size=config["batch_size"],
    shuffle=False
)
print(len(train_dataset), len(val_dataset))
print(f'Loaded {len(dataset)} images\n')
print(f'Train: {len(train_dataset)} images, {len(train_loader)} batches')
print(f'Val: {len(val_dataset)} images, {len(val_loader)} batches')

611 67
Loaded 678 images

Train: 611 images, 77 batches
Val: 67 images, 9 batches


In [13]:
# Load model, loss function and optimizer
from torch import nn
from unet import UNet
from unet import init_weights
from pathlib import Path

# NOTE(review): the argument 4 is presumably the number of input channels - confirm in unet.py.
model = UNet(4).to(config['device'])

# Load or fill weights
# And set the start_epoch of model
if config["init_from_checkpoint"]:
    if config["checkpoint"] is None:
        path = last_checkpoint(config["checkpoint_dir"])
    else:
        path = Path(config["checkpoint_dir"], config["checkpoint"])
    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint written on a GPU machine can still be restored on a CPU-only one.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=config["device"])
    model.load_state_dict(checkpoint["model_state_dict"])
    print(f"Loaded parameters from '{path}'")
    print_checkpoint(checkpoint)
    start_epoch = checkpoint["epochs"]
    # NOTE(review): only the model weights are restored here; the optimizer
    # below starts from a fresh state - confirm whether that is intended.
else:
    init_weights(model, torch.nn.init.normal_, mean=0., std=1)
    print("Randomly initiated parameters")
    start_epoch = 0

# Set optimizer & loss_fn
optimizer = torch.optim.Adam(params=model.parameters(), lr=config['lr'])

loss_fn = torch.nn.BCEWithLogitsLoss()
# Gradient scaling is only meaningful for CUDA mixed precision; disabling it
# otherwise turns the scaler into a pass-through instead of warning or failing
# when the configured device is the CPU.
scaler = torch.cuda.amp.GradScaler(
    enabled=str(config["device"]).startswith("cuda")
)

# model.train() switches the module to training mode and returns the module itself.
layers = model.train()
if config["print_model"]:
    print(layers)

Randomly initiated parameters


In [14]:
def check_accuracy_and_save(model, optimizer, loss_fn, epoch, train_loss, save=True):
    """Evaluate `model` on the validation loader, print the metrics and optionally checkpoint.

    Iterates over the notebook-level `val_loader`, accumulating the loss and the
    `accuracy` / `precision` / `recall` measures per batch, then prints the
    mean of each over all batches.

    Args:
        model: network to evaluate; it is put into eval mode for the duration
            of the validation pass and switched to training mode afterwards.
        optimizer: optimizer forwarded unchanged to `save_checkpoint`.
        loss_fn: loss applied to the raw model outputs and the targets.
        epoch: epoch index forwarded to `save_checkpoint`.
        train_loss: training loss forwarded to `save_checkpoint`.
        save: when true, `save_checkpoint` is called after validation.

    Raises:
        ZeroDivisionError: if the validation loader yields no batches.
    """
    losses = []
    accuracies = []
    precisions = []
    recalls = []

    model.eval()
    try:
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(config['device'])
                y = y.to(config['device'])

                # One forward pass per batch: the raw outputs feed the loss and,
                # after a sigmoid, the metrics.
                logits = model(x)
                # Cast the target the same way the training loop does.
                losses.append(loss_fn(logits, y.float()).item())

                pred = torch.sigmoid(logits).cpu().numpy()
                y = y.cpu().numpy()

                accuracies.append(accuracy(y, pred))
                precisions.append(precision(y, pred))
                recalls.append(recall(y, pred))
    finally:
        # Always hand the model back in training mode, even when validation
        # is interrupted or raises.
        model.train()

    def mean(values):
        return sum(values) / len(values)

    mean_loss = mean(losses)
    print("Valid loss:", mean_loss, '\n')
    print("Accuracy: ", mean(accuracies))
    print("Precision:", mean(precisions))
    print("Recall:   ", mean(recalls), '\n')

    if save:
        save_checkpoint(
            config["checkpoint_dir"],
            model,
            optimizer,
            loss_fn,
            epoch,
            train_loss,
            mean_loss
        )

In [15]:
from tqdm.notebook import tqdm
from sys import stdout
def train(save_checkpoints=True, save_loss_threshold=0.31):
    """Train the notebook-level `model` for `config['epochs']` epochs.

    Epoch numbering starts at the notebook-level `start_epoch`. After every
    epoch the mean training loss is printed and `check_accuracy_and_save` runs
    the validation pass.

    Args:
        save_checkpoints: master switch for checkpointing.
        save_loss_threshold: a checkpoint is written only for epochs whose mean
            training loss is strictly below this value.

    Raises:
        ZeroDivisionError: if the training loader yields no batches.
    """
    # Mixed precision is only used when the configured device is CUDA.
    use_amp = str(config['device']).startswith('cuda')

    for epoch in range(start_epoch, start_epoch + config['epochs']):
        print("Epoch", epoch)

        loader = tqdm(train_loader)
        losses = []

        for image, gt in loader:
            image = image.to(config['device'])
            gt = gt.float().to(config['device'])
            with torch.cuda.amp.autocast(enabled=use_amp):
                pred = model(image)
                loss = loss_fn(pred, gt)
            optimizer.zero_grad()
            # The scaler multiplies the loss before backward and performs the
            # optimizer step; update() then adjusts the scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss = loss.item()
            losses.append(train_loss)
            # Show the latest batch loss next to the progress bar.
            loader.set_postfix(loss=train_loss)

        mean_loss = sum(losses) / len(losses)
        print("Train loss:", mean_loss)
        check_accuracy_and_save(
            model,
            optimizer,
            loss_fn,
            epoch,
            mean_loss,
            save=save_checkpoints and mean_loss < save_loss_threshold
        )

In [None]:
# Run the training loop, then advance the epoch offset that `train` reads, so
# that running this cell again continues the epoch numbering.
train(save_checkpoints=True)
start_epoch += config["epochs"]

Epoch 0


  0%|          | 0/77 [00:00<?, ?it/s]

Train loss: 0.9012227453194656
Valid loss: 0.6169067621231079 

Accuracy:  0.7860359461219223
Precision: 0.3682696513003773
Recall:    0.3522972898316328 

Epoch 1


  0%|          | 0/77 [00:00<?, ?it/s]

Train loss: 0.5969705779057044
Valid loss: 0.5816483232710097 

Accuracy:  0.8001757965043739
Precision: 0.31787168200093285
Recall:    0.3661659701934291 

Epoch 2


  0%|          | 0/77 [00:00<?, ?it/s]

Train loss: 0.5697598766970944
Valid loss: 0.5761347512404124 

Accuracy:  0.8030943095132157
Precision: 0.4520794084226643
Recall:    0.3613169506813089 

Epoch 3


  0%|          | 0/77 [00:00<?, ?it/s]

Train loss: 0.5540530735796149
Valid loss: 0.5394074585702684 

Accuracy:  0.8184413352498301
Precision: 0.3512517511844635
Recall:    0.40694226583258974 

Epoch 4


  0%|          | 0/77 [00:00<?, ?it/s]

Train loss: 0.5388863779507674
Valid loss: 0.5141156448258294 

Accuracy:  0.8280531054845564
Precision: 0.3892309615319526
Recall:    0.4315181139132215 

Epoch 5


  0%|          | 0/77 [00:00<?, ?it/s]

Train loss: 0.5169836658161956
Valid loss: 0.48825517627927995 

Accuracy:  0.8383113245169321
Precision: 0.5096869829490229
Recall:    0.4194192105593781 

Epoch 6


  0%|          | 0/77 [00:00<?, ?it/s]

Train loss: 0.5001560844384231
Valid loss: 0.564020938343472 

Accuracy:  0.8144431354271041
Precision: 0.2952349045555349
Recall:    0.48371766780123665 

Epoch 7


  0%|          | 0/77 [00:00<?, ?it/s]

Train loss: 0.4903015320177202
Valid loss: 0.45507612493303085 

Accuracy:  0.8523819325146852
Precision: 0.5362839533223046
Recall:    0.4271659716894781 

Epoch 8


  0%|          | 0/77 [00:00<?, ?it/s]

Train loss: 0.48391300439834595
Valid loss: 0.4481518632835812 

Accuracy:  0.8550867474189512
Precision: 0.4604626137211367
Recall:    0.42696460119138163 

Epoch 9


  0%|          | 0/77 [00:00<?, ?it/s]

Train loss: 0.47564365143899795
Valid loss: 0.4740421242184109 

Accuracy:  0.845158516532845
Precision: 0.5720568773923097
Recall:    0.4337090523568568 

Epoch 10


  0%|          | 0/77 [00:00<?, ?it/s]

In [None]:
# Manually run validation on the current model and write a checkpoint.
# The epoch index and the training loss are literal values typed in here.
check_accuracy_and_save(
    model,
    optimizer,
    loss_fn,
    epoch=9,
    train_loss=0.4662005650062187,
    save=True,
)

# Estimation

In [None]:
def predict(model, x, y, binary_map=True, binary_threshold=0.4):
    """Run `model` on `x` and return numpy views of the input, target and prediction.

    Args:
        model: network to run; evaluated in eval mode and switched to training
            mode afterwards.
        x: input tensor; element 0 along the first axis is returned.
        y: target tensor; element [0][0] is returned.
        binary_map: when true, the prediction is thresholded into a boolean map.
        binary_threshold: values >= this threshold become True when
            `binary_map` is set.

    Returns:
        Tuple `(image_np, gt_np, pred_np)` of numpy arrays.

    NOTE(review): the [0] / [0][0] indexing presumably selects the first batch
    element and its single mask channel - confirm against the dataset.
    """
    # .cpu() makes the conversion work for tensors that already live on the GPU.
    image_np = x.detach().cpu().numpy()[0]
    gt_np = y.detach().cpu().numpy()[0][0]

    model.eval()
    try:
        with torch.no_grad():
            pred = torch.sigmoid(model(x.to(config['device'])))
    finally:
        # Always hand the model back in training mode, even if inference fails.
        model.train()

    pred_np = pred.cpu().numpy()[0][0]
    if binary_map:
        pred_np = pred_np >= binary_threshold
    return image_np, gt_np, pred_np

In [None]:
def plot_prediction(x, y, binary_map=True, binary_threshold=0.42):
    """Predict a single sample with the notebook-level `model`, print metrics and plot it.

    Prints accuracy, precision and recall for the sample and draws three
    panels side by side: the input, the target mask and the prediction.

    Args:
        x: input tensor for one sample (no batch axis); a batch axis is added here.
        y: target tensor for the same sample (no batch axis).
        binary_map: forwarded to `predict`; thresholds the prediction when true.
        binary_threshold: forwarded to `predict`.
    """
    from matplotlib import pyplot as plt
    import numpy as np

    # Add a leading batch axis of size 1, as `predict` indexes element 0 of it.
    x = torch.stack([x])
    y = torch.stack([y])
    x, y, pred = predict(model, x, y, binary_map, binary_threshold)
    print("Accuracy: ", accuracy([y], [pred]))
    print("Precision:", precision([y], [pred]))
    print("Recall:   ", recall([y], [pred]))

    fig, axs = plt.subplots(1, 3, figsize=(17, 10))
    # Only the first three channels are displayed, moved to the last axis
    # because imshow expects channel-last images.
    img = np.transpose(x[:3], (1, 2, 0))
    axs[0].imshow(img, cmap='gray')
    axs[0].set_title("Input (first 3 channels)")
    axs[1].imshow(y, cmap='gray')
    axs[1].set_title("Ground truth")
    axs[2].imshow(pred, cmap='gray')
    axs[2].set_title("Prediction")

In [None]:
# Get prediction for random image and crop
from random import randint
# Draw one validation index uniformly (both bounds of randint are inclusive),
# report it, and visualise the model's prediction for that sample.
idx = randint(0, len(val_dataset) - 1)
print(f"Index: {idx}")
x, y = val_dataset[idx]
plot_prediction(x, y, binary_map=True)

In [None]:
# Clean up the checkpoint directory using a predicate on each checkpoint's
# recorded training loss.
# NOTE(review): presumably checkpoints matching the condition are removed and
# `save_last=1` protects the most recent one - confirm in checkpoints.py.
clear_checkpoints(
    config["checkpoint_dir"],
    condition=lambda ch: ch["train_loss"] >= 0.30,
    save_last=1,
)

In [None]:
# List the contents of the local "checkpoints" directory (the path is typed literally here).
!ls -la checkpoints

In [None]:
# Count the coverage of the dataset
import numpy as np

# Mean of each sample's target tensor, one value per dataset item.
areas = np.array([float(mask.mean()) for [image, mask] in dataset])

# Ten equal-width bins over [0, 1]. np.histogram closes the last bin on the
# right, so samples whose mean is exactly 1.0 are counted as well.
bin_edges = np.linspace(0, 1, 11)
bin_counts = np.histogram(areas, bins=bin_edges)[0]
for lower, upper, count in zip(bin_edges[:-1], bin_edges[1:], bin_counts):
    # Edges are printed with one decimal to avoid float noise in the labels.
    print(f'{lower:.1f}-{upper:.1f}: {count / len(areas) * 100}')