In [1]:
#!git clone https://github.com/milesial/Pytorch-UNet.git
# Fix for cuda init error
#!pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html

In [2]:
import argparse
import logging
import sys
from pathlib import Path
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from utils.clothing_dataset import BasicDataset, ClothingDataset
from utils.dice_score import dice_loss
from utils.evaluate import evaluate
from unet import UNet
import gc
import matplotlib.pyplot as plt
from datetime import datetime

In [3]:
import torch.cuda
print(torch.cuda.is_available())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

True


device(type='cuda')

# Setup wandb

In [4]:
# wandb.init(project='unet', entity='endev')

# Dataset Path Variables

In [5]:
root = "./data/"
dir_checkpoint = Path('./training/')

In [6]:
train_dir_img = Path('/data/datasets/clothing-size/train/imgs/')
train_dir_mask = Path('/data/datasets/clothing-size/train/mask/')

In [7]:
val_dir_img = Path('/data/datasets/clothing-size/val/imgs/')
val_dir_mask = Path('/data/datasets/clothing-size/val/mask/')

# Check data

In [8]:
ls /data/datasets/clothing-size/val/imgs/

[0m[01;32mDSC_2702.jpg[0m*  [01;32mDSC_2716.jpg[0m*  [01;32mIMG_9283.jpg[0m*  [01;32mIMG_9386.jpg[0m*  [01;32mIMG_9427.jpg[0m*
[01;32mDSC_2704.jpg[0m*  [01;32mDSC_2721.jpg[0m*  [01;32mIMG_9315.jpg[0m*  [01;32mIMG_9388.jpg[0m*  [01;32mIMG_9429.jpg[0m*
[01;32mDSC_2707.jpg[0m*  [01;32mDSC_2722.jpg[0m*  [01;32mIMG_9326.jpg[0m*  [01;32mIMG_9389.jpg[0m*  [01;32mIMG_9430.jpg[0m*
[01;32mDSC_2708.jpg[0m*  [01;32mIMG_9263.jpg[0m*  [01;32mIMG_9329.jpg[0m*  [01;32mIMG_9398.jpg[0m*
[01;32mDSC_2714.jpg[0m*  [01;32mIMG_9279.jpg[0m*  [01;32mIMG_9330.jpg[0m*  [01;32mIMG_9422.jpg[0m*


In [9]:
ls /data/datasets/clothing-size/val/mask/

[0m[01;32mDSC_2702_mask.gif[0m*  [01;32mDSC_2721_mask.gif[0m*  [01;32mIMG_9326_mask.gif[0m*  [01;32mIMG_9398_mask.gif[0m*
[01;32mDSC_2704_mask.gif[0m*  [01;32mDSC_2722_mask.gif[0m*  [01;32mIMG_9329_mask.gif[0m*  [01;32mIMG_9422_mask.gif[0m*
[01;32mDSC_2707_mask.gif[0m*  [01;32mIMG_9263_mask.gif[0m*  [01;32mIMG_9330_mask.gif[0m*  [01;32mIMG_9427_mask.gif[0m*
[01;32mDSC_2708_mask.gif[0m*  [01;32mIMG_9279_mask.gif[0m*  [01;32mIMG_9386_mask.gif[0m*  [01;32mIMG_9429_mask.gif[0m*
[01;32mDSC_2714_mask.gif[0m*  [01;32mIMG_9283_mask.gif[0m*  [01;32mIMG_9388_mask.gif[0m*  [01;32mIMG_9430_mask.gif[0m*
[01;32mDSC_2716_mask.gif[0m*  [01;32mIMG_9315_mask.gif[0m*  [01;32mIMG_9389_mask.gif[0m*


In [10]:
def train_net(net,
              device,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 0.001,
              save_checkpoint: bool = True,
              img_scale: float = 0.5,
              amp: bool = False):
    # 1. Create dataset
    train_dataset = ClothingDataset(train_dir_img, train_dir_mask, img_scale)
    val_dataset = ClothingDataset(val_dir_img, val_dir_mask, img_scale)

    # 2. Totals
    n_train = len(train_dataset)
    n_val = len(val_dataset)

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=1, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
#     experiment = wandb.init(project='U-Net-mask', resume='allow', anonymous='must')
#     experiment.config.update(dict(epochs=epochs, 
#                                   batch_size=batch_size, 
#                                   learning_rate=learning_rate,
#                                   save_checkpoint=save_checkpoint, 
#                                   img_scale=img_scale,
#                                   amp=amp))
    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss()
    global_step = 0

    # 5. Begin training
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images = batch['image']
                true_masks = batch['mask']

                assert images.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    loss = criterion(masks_pred, true_masks) \
                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                       F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                                       multiclass=True)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
#                 experiment.log({
#                     'train loss': loss.item(),
#                     'step': global_step,
#                     'epoch': epoch
#                 })
                
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                if global_step % (n_train // (2 * batch_size)) == 0:
                    histograms = {}
                    for tag, value in net.named_parameters():
                        tag = tag.replace('/', '.')
                        histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                        histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                    val_score = evaluate(net, val_loader, device)
                    scheduler.step(val_score)

                    logging.info('Validation Dice score: {}'.format(val_score))
#                     experiment.log({
#                         'learning rate': optimizer.param_groups[0]['lr'],
#                         'validation Dice': val_score,
#                         'images': wandb.Image(images[0].cpu()),
#                         'masks': {
#                             'true': wandb.Image(true_masks[0].float().cpu()),
#                             'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
#                         },
#                         'step': global_step,
#                         'epoch': epoch,
#                         **histograms
#                     })

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}_{}.pth'.format(epoch, datetime.now().strftime("%m_%d_%Y_%H_%M_%S"))))
            logging.info(f'Checkpoint {epoch + 1} saved at {datetime.now().strftime("%m_%d_%Y_%H_%M_%S")}!')

In [11]:
# Number of epochs
epochs = 5

# Batch size
batch_size=1

# Learning rate
lr = 0.0001

# Load model from a .pth file
load = False

# Downscaling factor of the image
scale = 1

# Use mixed precision
amp = False

weight_decay=1e-8

momentum=0.9

config = wandb.config
config.epochs = epochs
config.batch_size = batch_size
config.lr = lr
config.load = load
config.scale = scale
config.amp = amp
config.weight_decay = weight_decay
config.momentum = momentum

config.model = "Unet-mask"

In [12]:
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')

INFO: Using device cuda


In [13]:
import torch
import torch.nn as nn
import torchvision
resnet = torchvision.models.resnet.resnet18(pretrained=True)


class ConvBlock(nn.Module):
    """
    Helper module that consists of a Conv -> BN -> ReLU
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.with_nonlinearity:
            x = self.relu(x)
        return x


class Bridge(nn.Module):
    """
    This is the middle layer of the UNet which just consists of some
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bridge = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )

    def forward(self, x):
        return self.bridge(x)


class UpBlockForUNetWithResNet50(nn.Module):
    """
    Up block that encapsulates one up-sampling step which consists of Upsample -> ConvBlock -> ConvBlock
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        if up_conv_in_channels == None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels == None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x


class UNetWithResnet50Encoder(nn.Module):
    DEPTH = 6

    def __init__(self, n_classes=2):
        super().__init__()
        resnet = torchvision.models.resnet.resnet18(pretrained=True)
        down_blocks = []
        up_blocks = []
        self.input_block = nn.Sequential(*list(resnet.children()))[:3]
        self.input_pool = list(resnet.children())[3]
        for bottleneck in list(resnet.children()):
            if isinstance(bottleneck, nn.Sequential):
                down_blocks.append(bottleneck)
        self.down_blocks = nn.ModuleList(down_blocks)
        self.bridge = Bridge(2048, 2048)
        up_blocks.append(UpBlockForUNetWithResNet50(2048, 1024))
        up_blocks.append(UpBlockForUNetWithResNet50(1024, 512))
        up_blocks.append(UpBlockForUNetWithResNet50(512, 256))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                                    up_conv_in_channels=256, up_conv_out_channels=128))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                                    up_conv_in_channels=128, up_conv_out_channels=64))

        self.up_blocks = nn.ModuleList(up_blocks)

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)

    def forward(self, x, with_output_feature_map=False):
        pre_pools = dict()
        pre_pools[f"layer_0"] = x
        x = self.input_block(x)
        pre_pools[f"layer_1"] = x
        x = self.input_pool(x)

        for i, block in enumerate(self.down_blocks, 2):
            x = block(x)
            if i == (UNetWithResnet50Encoder.DEPTH - 1):
                continue
            pre_pools[f"layer_{i}"] = x

        x = self.bridge(x)

        for i, block in enumerate(self.up_blocks, 1):
            key = f"layer_{UNetWithResnet50Encoder.DEPTH - 1 - i}"
            x = block(x, pre_pools[key])
        output_feature_map = x
        x = self.out(x)
        del pre_pools
        if with_output_feature_map:
            return x, output_feature_map
        else:
            return x

In [14]:
model = UNetWithResnet50Encoder().cuda()

In [15]:
inp = torch.rand((2, 3, 2048, 2048)).cuda()

In [16]:
out = model(inp)

RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 3.95 GiB total capacity; 3.16 GiB already allocated; 53.25 MiB free; 3.18 GiB reserved in total by PyTorch)

In [None]:
net = UNet(n_channels=3, n_classes=2, bilinear=True)
net.to(device="cuda")

In [None]:
logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
try:
    train_net(net=net,
              epochs=epochs,
              batch_size=batch_size,
              learning_rate=lr,
              device=device,
              img_scale=scale,
              amp=amp)
except KeyboardInterrupt:
    torch.save(net.state_dict(), 'INTERRUPTED.pth')
    logging.info('Saved interrupt')
    sys.exit(0)