In [1]:
""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the second convolution.
        mid_channels: number of channels produced by the first convolution;
            when omitted (or otherwise falsy) ``out_channels`` is used.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels

        def conv_bn_relu(c_in, c_out):
            # kernel_size=3 with padding=1 leaves the spatial size unchanged.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        stages = conv_bn_relu(in_channels, hidden) + conv_bn_relu(hidden, out_channels)
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x`` and return the result."""
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and run the result through the double convolution."""
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Decoder stage: upsample, align with the skip tensor, concatenate, ``DoubleConv``.

    Args:
        in_channels: channels of the concatenated tensor fed to the ``DoubleConv``.
        out_channels: channels produced by the ``DoubleConv``.
        bilinear: when True, upsample with fixed bilinear interpolation and let
            the ``DoubleConv`` shrink the channels (mid channels = in_channels // 2);
            otherwise upsample with a learned transposed convolution that halves
            the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Interpolation has no weights, so the convolutions do the channel reduction.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the spatial size of ``x2`` and fuse the two.

        ``x1`` is the tensor coming from the deeper stage, ``x2`` the skip tensor.
        """
        upsampled = self.up(x1)

        # Tensors are indexed as (N, C, H, W): dim 2 is height, dim 3 is width.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2

        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        # Skip tensor first, upsampled tensor second, stacked along channels.
        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)


class OutConv(nn.Module):
    """Final 1x1 convolution projecting ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return the 1x1 convolution of ``x``."""
        projected = self.conv(x)
        return projected

In [2]:
""" Full assembly of the parts to form the complete network """

class UNet(nn.Module):
    """U-Net assembled from ``DoubleConv``, ``Down``, ``Up`` and ``OutConv``.

    Args:
        n_channels: number of channels of the input image.
        n_classes: number of output channels produced by the final 1x1 convolution.
        bilinear: forwarded to every ``Up`` stage; also halves the channel
            count of the deepest encoder stage and of the first three decoder
            outputs.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels, self.n_classes, self.bilinear = n_channels, n_classes, bilinear

        shrink = 2 if bilinear else 1

        # Encoder path
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // shrink)

        # Decoder path
        self.up1 = Up(1024, 512 // shrink, bilinear)
        self.up2 = Up(512, 256 // shrink, bilinear)
        self.up3 = Up(256, 128 // shrink, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # Output head
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections; return the head's output."""
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottleneck = self.down4(skip4)

        # Each decoder stage consumes the encoder output of matching depth.
        decoded = self.up1(bottleneck, skip4)
        decoded = self.up2(decoded, skip3)
        decoded = self.up3(decoded, skip2)
        decoded = self.up4(decoded, skip1)

        return self.outc(decoded)

In [21]:
import torch
from torch.utils.data.dataset import Dataset  # For custom data-sets
from torchvision import transforms
import glob
import numpy as np
# From https://discuss.pytorch.org/t/beginner-how-do-i-write-a-custom-dataset-that-allows-me-to-return-image-and-its-target-image-not-label-as-a-pair/13988/4
# And https://discuss.pytorch.org/t/how-make-customised-dataset-for-semantic-segmentation/30881

#folder_data = "/mnt/g/Shared drives/2021-gtc-sea-ice/trainingdata/tiled/"

#img_list = []
#for filename in glob.glob(folder_data + '*sar.npy'):
#        filename_label = filename.replace("sar", "label")
#        img_list.append((filename, filename_label))

class CustomImageDataset(Dataset):
    """Dataset of paired SAR / label ``.npy`` tiles (GTC code).

    ``paths`` is an indexable collection whose entries hold the SAR file at
    position 0 and the matching label file at position 1. Each item is a dict
    with an ``'image'`` tensor (a leading axis is added) and a ``'mask'``
    tensor in which label value 100 becomes 0 and 200 becomes 1; any other
    label value is kept as is. The dataset can be passed to a DataLoader.
    """

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        entry = self.paths[index]

        sar = np.load(entry[0])
        # float64 tensor with an extra leading axis in front of the stacked array.
        image = torch.from_numpy(np.vstack(sar).astype(float))[None, :]

        labels = np.load(entry[1])
        # 100 -> 0 and 200 -> 1; every other value passes through untouched.
        remapped = np.where(labels == 200, 1, np.where(labels == 100, 0, labels))
        mask = torch.from_numpy(np.vstack(remapped).astype(float))

        # No check is made here that image and mask have matching sizes.
        return dict(image=image, mask=mask)

In [9]:
# from https://gitcode.net/mirrors/milesial/Pytorch-UNet/-/blob/master/utils/dice_score.py
def dice_coeff(input: torch.Tensor, target: torch.Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    """Dice coefficient of two same-shaped tensors.

    A 2-D input is scored directly. With ``reduce_batch_first=True`` the whole
    tensor (all batch elements together) is scored as one flattened pair.
    Otherwise the function recurses over the first dimension and returns the
    mean of the per-element scores.

    Args:
        input: predicted mask tensor.
        target: reference mask tensor, same size and dtype as ``input``.
        reduce_batch_first: score all batch elements together instead of averaging.
        epsilon: smoothing term added to numerator and denominator.

    Returns:
        A scalar tensor. When both tensors sum to zero the score is
        ``epsilon / epsilon``, i.e. 1 for non-zero ``epsilon``.

    Raises:
        ValueError: if ``reduce_batch_first`` is set but the tensors are 2-D.
    """
    # Annotations use ``torch.Tensor`` so this definition does not depend on a
    # ``from torch import Tensor`` executed in some other notebook cell.
    assert input.size() == target.size()
    if input.dim() == 2 and reduce_batch_first:
        raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')

    if input.dim() == 2 or reduce_batch_first:
        inter = torch.dot(input.reshape(-1), target.reshape(-1))
        sets_sum = torch.sum(input) + torch.sum(target)
        if sets_sum.item() == 0:
            sets_sum = 2 * inter

        return (2 * inter + epsilon) / (sets_sum + epsilon)

    # compute and average metric for each batch element, forwarding the
    # caller's epsilon (it was previously reset to the default here)
    dice = 0
    for i in range(input.shape[0]):
        dice += dice_coeff(input[i, ...], target[i, ...], epsilon=epsilon)
    return dice / input.shape[0]


def multiclass_dice_coeff(input: torch.Tensor, target: torch.Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    """Mean Dice coefficient over the channels (dimension 1) of two same-shaped tensors.

    Args:
        input: predicted tensor with the class channels on dimension 1.
        target: reference tensor of the same size.
        reduce_batch_first: forwarded to ``dice_coeff`` for every channel.
        epsilon: forwarded to ``dice_coeff`` for every channel.

    Returns:
        The sum of the per-channel Dice scores divided by the number of channels.
    """
    # Annotations use ``torch.Tensor`` so this definition does not depend on a
    # ``from torch import Tensor`` executed in some other notebook cell.
    assert input.size() == target.size()
    dice = 0
    for channel in range(input.shape[1]):
        dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)

    return dice / input.shape[1]


def dice_loss(input: torch.Tensor, target: torch.Tensor, multiclass: bool = False):
    """Dice loss (objective to minimize): ``1 - dice``.

    Args:
        input: predicted tensor.
        target: reference tensor of the same size.
        multiclass: use ``multiclass_dice_coeff`` (mean over channels) instead
            of ``dice_coeff``.

    Returns:
        One minus the Dice score computed with ``reduce_batch_first=True``.
    """
    # Annotations use ``torch.Tensor`` so this definition does not depend on a
    # ``from torch import Tensor`` executed in some other notebook cell.
    assert input.size() == target.size()
    fn = multiclass_dice_coeff if multiclass else dice_coeff
    return 1 - fn(input, target, reduce_batch_first=True)

In [10]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

#from utils.dice_score import multiclass_dice_coeff, dice_coeff


def evaluate(net, dataloader, device):
    """Compute the mean per-batch Dice score of ``net`` over ``dataloader``.

    Args:
        net: model exposing ``n_classes``; it is switched to eval mode for the
            duration of the call and put back into its previous mode afterwards.
        dataloader: yields dicts with ``'image'`` and ``'mask'`` entries. The
            mask is used as integer class indices of shape (N, H, W) — the
            one-hot + ``permute(0, 3, 1, 2)`` below requires that layout.
        device: device the batches are moved to.

    Returns:
        The accumulated Dice score divided by the number of batches, or the
        initial value 0 when the loader yields no batches.
    """
    was_training = net.training
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # No gradients are needed anywhere in the validation pass.
    with torch.no_grad():
        # iterate over the validation set
        for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
            image, mask_true = batch['image'], batch['mask']
            # move images and labels to correct device and type
            image = image.to(device=device, dtype=torch.float32)
            mask_true = mask_true.to(device=device, dtype=torch.long)

            # predict the mask
            mask_pred = net(image)

            if net.n_classes == 1:
                # Single-logit output: drop the channel axis and threshold the
                # probability. The mask is compared as-is; one-hot encoding it
                # with a single class would reject every foreground pixel.
                mask_pred = (torch.sigmoid(mask_pred.squeeze(1)) > 0.5).float()
                # compute the Dice score
                dice_score += dice_coeff(mask_pred, mask_true.float(), reduce_batch_first=False)
            else:
                # convert both masks to one-hot format, channels on dim 1
                mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
                # compute the Dice score, ignoring background
                dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False)

    # Restore whichever mode the caller had the network in.
    net.train(was_training)

    # Fixes a potential division by zero error
    if num_val_batches == 0:
        return dice_score
    return dice_score / num_val_batches

In [22]:
import argparse
import logging
import sys
from pathlib import Path

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

#from utils.dice_score import dice_loss
# from unet import UNet
# from customImgDataset import CustomImageDataset

# Directory that train_net() passes to create_img_list(), which globs it for '*sar.npy' tiles.
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
dir_img = Path('/mnt/g/Shared drives/2021-gtc-sea-ice/trainingdata/tiled/')
# Directory where train_net() writes 'checkpoint_epoch{N}.pth' after each epoch.
dir_checkpoint = Path('/mnt/g/Shared drives/2021-gtc-sea-ice/trainingdata/')

def create_img_list(data_path):
    """List (sar_path, label_path) pairs for every ``*sar.npy`` file in ``data_path``.

    The label path is the SAR path with its trailing ``sar.npy`` swapped for
    ``label.npy``. Only that suffix is rewritten, so a directory or file stem
    that happens to contain "sar" is left intact. The existence of the label
    file is not checked here.

    Args:
        data_path: directory to search (anything ``str()`` turns into a path).

    Returns:
        A list of ``(sar_filename, label_filename)`` tuples sorted by SAR
        filename, so the order does not depend on filesystem listing order.
    """
    sar_suffix = 'sar.npy'
    label_suffix = 'label.npy'

    img_list = []
    for filename in sorted(glob.glob(str(data_path) + '/*' + sar_suffix)):
        # Every match ends with sar_suffix, so slicing it off is safe.
        filename_label = filename[:-len(sar_suffix)] + label_suffix
        img_list.append((filename, filename_label))
    return img_list


def train_net(net,
              device,
              epochs: int = 5,
              batch_size: int = 10,
              learning_rate: float = 0.001,
              val_percent: float = 0.1,
              save_checkpoint: bool = True,
              img_scale: float = 0.5,
              amp: bool = False):
    """Train ``net`` on the tiles found in ``dir_img`` and log the run to Weights & Biases.

    Args:
        net: model to train; must expose ``n_channels`` and ``n_classes``.
        device: torch device the batches are moved to.
        epochs: number of passes over the training split.
        batch_size: samples per batch for both data loaders.
        learning_rate: initial RMSprop learning rate.
        val_percent: fraction (0-1) of the dataset held out for validation.
        save_checkpoint: when True, save ``checkpoint_epoch{N}.pth`` into
            ``dir_checkpoint`` after every epoch.
        img_scale: only recorded in the run config and the log banner; it is
            not applied to the data inside this function.
        amp: enable automatic mixed precision (autocast + gradient scaling).

    Raises:
        RuntimeError: if no ``*sar.npy`` tiles are found in ``dir_img``.
    """
    # 1. Create dataset
    img_list = create_img_list(dir_img)
    if not img_list:
        # Fail early with a clear message instead of an obscure sampler error
        # (or a run that only writes untrained checkpoints).
        raise RuntimeError(f'No *sar.npy tiles found in {dir_img}; nothing to train on.')
    dataset = CustomImageDataset(img_list)

    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    # Keep the last partial batch: dropping it would silently ignore validation
    # samples and yields zero batches whenever n_val < batch_size.
    val_loader = DataLoader(val_set, shuffle=False, drop_last=False, **loader_args)

    # Initialize logging
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
                                  val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
                                  amp=amp))

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss()
    global_step = 0

    # Validate roughly ten times per epoch; 0 disables the in-epoch validation
    # (training split smaller than 10 batches). Loop-invariant, so computed once.
    division_step = n_train // (10 * batch_size)

    # 5. Begin training
    for epoch in range(epochs):
        net.train()
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images = batch['image']
                true_masks = batch['mask']

                assert images.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels of {images.shape}. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    # Cross-entropy on the logits plus a Dice term on the
                    # softmax probabilities vs. the one-hot masks (channels on dim 1).
                    loss = criterion(masks_pred, true_masks) \
                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                       F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                                       multiclass=True)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                experiment.log({
                    'train loss': loss.item(),
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                if division_step > 0 and global_step % division_step == 0:
                    histograms = {}
                    for tag, value in net.named_parameters():
                        tag = tag.replace('/', '.')
                        histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                        # A parameter that took no part in the backward pass has no gradient.
                        if value.grad is not None:
                            histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                    val_score = evaluate(net, val_loader, device)
                    scheduler.step(val_score)

                    logging.info('Validation Dice score: {}'.format(val_score))
                    experiment.log({
                        'learning rate': optimizer.param_groups[0]['lr'],
                        'validation Dice': val_score,
                        'images': wandb.Image(images[0].cpu()),
                        'masks': {
                            'true': wandb.Image(true_masks[0].float().cpu()),
                            'pred': wandb.Image(torch.softmax(masks_pred, dim=1).argmax(dim=1)[0].float().cpu()),
                        },
                        'step': global_step,
                        'epoch': epoch,
                        **histograms
                    })

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
            logging.info(f'Checkpoint {epoch + 1} saved!')


def get_args(argv=None):
    """Parse the training options.

    Args:
        argv: optional list of argument strings. When None, the process
            command line is parsed — except inside an IPython/Jupyter kernel,
            where an empty list is parsed so that every option takes its
            default.

    Returns:
        The ``argparse.Namespace`` with ``epochs``, ``batch_size``, ``lr``,
        ``load``, ``scale``, ``val`` and ``amp``.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.00001,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')

    if argv is None and 'ipykernel' in sys.modules:
        # A Jupyter kernel is started with `-f <connection file>` on its own
        # command line; parsing that would hand the connection file to
        # --load/-f. Those arguments belong to the kernel, not to this script.
        argv = []

    return parser.parse_args(argv)


if __name__ == '__main__':
    # Entry point: parse options, build the network, train, and save the
    # weights if the run is interrupted from the keyboard.
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    net = UNet(n_channels=1, n_classes=2, bilinear=True)

    upscaling = "Bilinear" if net.bilinear else "Transposed conv"
    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{upscaling} upscaling')

    # Loading weights from args.load is currently disabled:
#    if args.load:
#        net.load_state_dict(torch.load(args.load, map_location=device))
#        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)

    # --validation is given in percent on the command line; train_net expects a fraction.
    train_kwargs = dict(
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        img_scale=args.scale,
        val_percent=args.val / 100,
        amp=args.amp,
    )
    try:
        train_net(net=net, device=device, **train_kwargs)
    except KeyboardInterrupt:
        # Keep whatever has been learned so far before exiting.
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        sys.exit(0)

INFO: Using device cpu
INFO: Network:
	1 input channels
	2 output channels (classes)
	Bilinear upscaling


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: wandb version 0.12.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


INFO: Starting training:
        Epochs:          5
        Batch size:      1
        Learning rate:   1e-05
        Training size:   133
        Validation size: 14
        Checkpoints:     True
        Device:          cpu
        Images scaling:  0.5
        Mixed Precision: False
    
Epoch 1/5:  10%|████▉                                              | 13/133 [02:16<18:14,  9.12s/img, loss (batch)=1.07]
Validation round:   0%|                                                                       | 0/14 [00:00<?, ?batch/s][A
Validation round:   7%|████▌                                                          | 1/14 [00:26<05:37, 26.00s/batch][A
Validation round:  14%|█████████                                                      | 2/14 [00:30<02:38, 13.19s/batch][A
Validation round:  21%|█████████████▌                                                 | 3/14 [00:34<01:38,  8.99s/batch][A
Validation round:  29%|██████████████████                                             | 4/14

AssertionError  File "/home/mlisaius/miniconda3/envs/unetenv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
:     can only test a child process    if w.is_alive():

if w.is_alive():  File "/home/mlisaius/miniconda3/envs/unetenv/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process

  File "/home/mlisaius/miniconda3/envs/unetenv/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f1049b58710>
Traceback (most recent call last):
  File "/home/mlisaius/miniconda3/envs/unetenv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/mlisai

Validation round:  93%|█████████████████████████████████████████████████████████▌    | 13/14 [00:37<00:02,  2.95s/batch][A
Validation round: 100%|██████████████████████████████████████████████████████████████| 14/14 [00:40<00:00,  3.01s/batch][A
                                                                                                                        [AINFO: Validation Dice score: 0.6299758553504944
Epoch 1/5: 100%|█████████████████████████████████████████████████| 133/133 [29:57<00:00, 13.51s/img, loss (batch)=0.959]
INFO: Checkpoint 1 saved!
Epoch 2/5:   4%|█▌                                         | 5/133 [8:28:15<278:18:12, 7827.28s/img, loss (batch)=0.893]wandb: Network error (ConnectionError), entering retry loop.
Epoch 2/5:   8%|███▏                                       | 10/133 [8:29:05<38:47:52, 1135.55s/img, loss (batch)=0.885]
Validation round:   0%|                                                                       | 0/14 [00:00<?, ?batch/s][A
Validati

Validation round:  21%|█████████████▌                                                 | 3/14 [00:09<00:33,  3.03s/batch][A
Validation round:  29%|██████████████████                                             | 4/14 [00:12<00:29,  2.92s/batch][A
Validation round:  36%|██████████████████████▌                                        | 5/14 [00:14<00:25,  2.83s/batch][A
Validation round:  43%|███████████████████████████                                    | 6/14 [00:17<00:22,  2.77s/batch][A
Validation round:  50%|███████████████████████████████▌                               | 7/14 [00:20<00:19,  2.74s/batch][A
Validation round:  57%|████████████████████████████████████                           | 8/14 [00:22<00:16,  2.69s/batch][A
Validation round:  64%|████████████████████████████████████████▌                      | 9/14 [00:25<00:13,  2.66s/batch][A
Validation round:  71%|████████████████████████████████████████████▎                 | 10/14 [00:27<00:10,  2.63s/batch][A
Validati

Validation round:  79%|████████████████████████████████████████████████▋             | 11/14 [00:30<00:07,  2.66s/batch][A
Validation round:  86%|█████████████████████████████████████████████████████▏        | 12/14 [00:33<00:05,  2.68s/batch][A
Validation round:  93%|█████████████████████████████████████████████████████████▌    | 13/14 [00:36<00:02,  2.73s/batch][A
Validation round: 100%|██████████████████████████████████████████████████████████████| 14/14 [00:38<00:00,  2.69s/batch][A
                                                                                                                        [AINFO: Validation Dice score: 0.6098114848136902
Epoch 3/5:  54%|███████████████████████████                       | 72/133 [13:23<08:18,  8.17s/img, loss (batch)=0.877]
Validation round:   0%|                                                                       | 0/14 [00:00<?, ?batch/s][A
Validation round:   7%|████▌                                                          | 

Validation round:  14%|█████████                                                      | 2/14 [00:08<00:46,  3.91s/batch][A
Validation round:  21%|█████████████▌                                                 | 3/14 [00:11<00:40,  3.65s/batch][A
Validation round:  29%|██████████████████                                             | 4/14 [00:14<00:33,  3.38s/batch][A
Validation round:  36%|██████████████████████▌                                        | 5/14 [00:17<00:29,  3.24s/batch][A
Validation round:  43%|███████████████████████████                                    | 6/14 [00:20<00:25,  3.15s/batch][A
Validation round:  50%|███████████████████████████████▌                               | 7/14 [00:23<00:21,  3.08s/batch][A
Validation round:  57%|████████████████████████████████████                           | 8/14 [00:26<00:17,  2.99s/batch][A
Validation round:  64%|████████████████████████████████████████▌                      | 9/14 [00:28<00:14,  2.89s/batch][A
Validati

Epoch 4/5:  62%|██████████████████████████████▊                   | 82/133 [19:07<09:09, 10.77s/img, loss (batch)=0.934]
Validation round:   0%|                                                                       | 0/14 [00:00<?, ?batch/s][A
Validation round:   7%|████▌                                                          | 1/14 [00:04<00:51,  4.00s/batch][A
Validation round:  14%|█████████                                                      | 2/14 [00:07<00:42,  3.53s/batch][A
Validation round:  21%|█████████████▌                                                 | 3/14 [00:10<00:38,  3.46s/batch][A
Validation round:  29%|██████████████████                                             | 4/14 [00:14<00:34,  3.48s/batch][A
Validation round:  36%|██████████████████████▌                                        | 5/14 [00:17<00:30,  3.36s/batch][A
Validation round:  43%|███████████████████████████                                    | 6/14 [00:20<00:26,  3.25s/batch][A
Validation 

Validation round:  57%|████████████████████████████████████                           | 8/14 [00:27<00:17,  3.00s/batch][A
Validation round:  64%|████████████████████████████████████████▌                      | 9/14 [00:29<00:14,  2.94s/batch][A
Validation round:  71%|████████████████████████████████████████████▎                 | 10/14 [00:33<00:11,  3.00s/batch][A
Validation round:  79%|████████████████████████████████████████████████▋             | 11/14 [00:36<00:09,  3.02s/batch][A
Validation round:  86%|█████████████████████████████████████████████████████▏        | 12/14 [00:39<00:05,  2.97s/batch][A
Validation round:  93%|█████████████████████████████████████████████████████████▌    | 13/14 [00:41<00:02,  2.97s/batch][A
Validation round: 100%|██████████████████████████████████████████████████████████████| 14/14 [00:44<00:00,  2.88s/batch][A
                                                                                                                        [AINFO: Val