In [1]:
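# One-time environment setup (uncomment the lines starting with `#!` to run them).
#!git clone https://github.com/milesial/Pytorch-UNet.git
# Fix for the CUDA init error: install the PyTorch 1.8.0 wheels tagged cu111.
#!pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html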

In [2]:
import argparse
import logging
import sys
from pathlib import Path
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from utils.clothing_dataset_threshold import BasicDataset, ClothingDataset
from utils.dice_score import dice_loss
from utils.evaluate import evaluate
from unet import UNet
import gc
import matplotlib.pyplot as plt
from datetime import datetime

In [3]:
# Report whether a CUDA device is usable and pick the compute device accordingly.
import torch.cuda

cuda_ok = torch.cuda.is_available()
print(cuda_ok)

if cuda_ok:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

True


device(type='cuda')

# Setup wandb

In [4]:
# Start the notebook-level Weights & Biases run.
# NOTE(review): train_net() calls wandb.init() again with project 'U-Net-mask';
# confirm which run the metrics and config are meant to end up in.
wandb.init(
    project='unet',
    entity='endev',
)

[34m[1mwandb[0m: Currently logged in as: [33mendev[0m (use `wandb login --relogin` to force relogin)


# Dataset Path Variables

In [5]:
# Relative data root (not referenced by the other cells visible in this notebook).
root = "./data/"

# Dataset location. Two copies of the data are mentioned:
#   "/data/datasets/clothing-size/"           - SSD: takes time to move files and access them, but faster
#   "/data/clothing-size-identification/data" - easy to move files and edit, but very slow
# root_dir = "/data/datasets/clothing-size/"
root_dir = "/data/clothing-size-identification/data"

# Directory that train_net() writes its per-epoch checkpoints into.
dir_checkpoint = Path('./training/')

In [6]:
# Training images and their masks live side by side under <root_dir>/train.
train_dir_img = Path(f"{root_dir}/train/imgs/")
train_dir_mask = Path(f"{root_dir}/train/mask/")

In [7]:
# Validation images and their masks live side by side under <root_dir>/val.
val_dir_img = Path(f"{root_dir}/val/imgs/")
val_dir_mask = Path(f"{root_dir}/val/mask/")

# Check data

In [8]:
ls /data/clothing-size-identification/data/val/imgs/

[0m[01;32mDSC_2702.jpg[0m*  [01;32mDSC_2716.jpg[0m*  [01;32mIMG_9283.jpg[0m*  [01;32mIMG_9386.jpg[0m*  [01;32mIMG_9427.jpg[0m*
[01;32mDSC_2704.jpg[0m*  [01;32mDSC_2721.jpg[0m*  [01;32mIMG_9315.jpg[0m*  [01;32mIMG_9388.jpg[0m*  [01;32mIMG_9429.jpg[0m*
[01;32mDSC_2707.jpg[0m*  [01;32mDSC_2722.jpg[0m*  [01;32mIMG_9326.jpg[0m*  [01;32mIMG_9389.jpg[0m*  [01;32mIMG_9430.jpg[0m*
[01;32mDSC_2708.jpg[0m*  [01;32mIMG_9263.jpg[0m*  [01;32mIMG_9329.jpg[0m*  [01;32mIMG_9398.jpg[0m*
[01;32mDSC_2714.jpg[0m*  [01;32mIMG_9279.jpg[0m*  [01;32mIMG_9330.jpg[0m*  [01;32mIMG_9422.jpg[0m*


In [9]:
ls /data/datasets/clothing-size/val/imgs/

[0m[01;32mDSC_2702.jpg[0m*  [01;32mDSC_2716.jpg[0m*  [01;32mIMG_9283.jpg[0m*  [01;32mIMG_9386.jpg[0m*  [01;32mIMG_9427.jpg[0m*
[01;32mDSC_2704.jpg[0m*  [01;32mDSC_2721.jpg[0m*  [01;32mIMG_9315.jpg[0m*  [01;32mIMG_9388.jpg[0m*  [01;32mIMG_9429.jpg[0m*
[01;32mDSC_2707.jpg[0m*  [01;32mDSC_2722.jpg[0m*  [01;32mIMG_9326.jpg[0m*  [01;32mIMG_9389.jpg[0m*  [01;32mIMG_9430.jpg[0m*
[01;32mDSC_2708.jpg[0m*  [01;32mIMG_9263.jpg[0m*  [01;32mIMG_9329.jpg[0m*  [01;32mIMG_9398.jpg[0m*
[01;32mDSC_2714.jpg[0m*  [01;32mIMG_9279.jpg[0m*  [01;32mIMG_9330.jpg[0m*  [01;32mIMG_9422.jpg[0m*


In [10]:
ls /data/datasets/clothing-size/val/mask/

[0m[01;32mDSC_2702_mask.gif[0m*  [01;32mDSC_2721_mask.gif[0m*  [01;32mIMG_9326_mask.gif[0m*  [01;32mIMG_9398_mask.gif[0m*
[01;32mDSC_2704_mask.gif[0m*  [01;32mDSC_2722_mask.gif[0m*  [01;32mIMG_9329_mask.gif[0m*  [01;32mIMG_9422_mask.gif[0m*
[01;32mDSC_2707_mask.gif[0m*  [01;32mIMG_9263_mask.gif[0m*  [01;32mIMG_9330_mask.gif[0m*  [01;32mIMG_9427_mask.gif[0m*
[01;32mDSC_2708_mask.gif[0m*  [01;32mIMG_9279_mask.gif[0m*  [01;32mIMG_9386_mask.gif[0m*  [01;32mIMG_9429_mask.gif[0m*
[01;32mDSC_2714_mask.gif[0m*  [01;32mIMG_9283_mask.gif[0m*  [01;32mIMG_9388_mask.gif[0m*  [01;32mIMG_9430_mask.gif[0m*
[01;32mDSC_2716_mask.gif[0m*  [01;32mIMG_9315_mask.gif[0m*  [01;32mIMG_9389_mask.gif[0m*


In [11]:
def train_net(net,
              device,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 0.001,
              save_checkpoint: bool = True,
              img_scale: float = 0.5,
              amp: bool = False,
              weight_decay: float = 1e-8,
              momentum: float = 0.9):
    """Train ``net`` on the clothing dataset and log progress to Weights & Biases.

    The training and validation sets are read from the notebook-level
    ``train_dir_img`` / ``train_dir_mask`` and ``val_dir_img`` / ``val_dir_mask``
    paths, and checkpoints are written under the notebook-level ``dir_checkpoint``.

    Args:
        net: Segmentation model. Must expose ``n_channels`` and ``n_classes`` and
            is expected to already live on ``device``.
        device: ``torch.device`` that each batch is moved to.
        epochs: Number of passes over the training set.
        batch_size: Batch size used by both data loaders.
        learning_rate: Initial learning rate of the RMSprop optimizer.
        save_checkpoint: When true, save ``net.state_dict()`` after every epoch.
        img_scale: Scale factor forwarded to ``ClothingDataset``.
        amp: Enable CUDA automatic mixed precision (autocast + gradient scaling).
        weight_decay: Weight decay of the RMSprop optimizer.
        momentum: Momentum of the RMSprop optimizer.

    Raises:
        AssertionError: If a batch's channel dimension differs from ``net.n_channels``.
    """
    # 1. Create dataset
    train_dataset = ClothingDataset(train_dir_img, train_dir_mask, img_scale)
    val_dataset = ClothingDataset(val_dir_img, val_dir_mask, img_scale)

    # 2. Totals
    n_train = len(train_dataset)
    n_val = len(val_dataset)

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=1, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)
    # NOTE(review): drop_last=True discards a trailing partial validation batch
    # whenever batch_size does not divide n_val - confirm this is intended.
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    experiment = wandb.init(project='U-Net-mask', resume='allow', anonymous='must')
    experiment.config.update(dict(epochs=epochs, 
                                  batch_size=batch_size, 
                                  learning_rate=learning_rate,
                                  save_checkpoint=save_checkpoint, 
                                  img_scale=img_scale,
                                  amp=amp))
    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate,
                              weight_decay=weight_decay, momentum=momentum)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss()
    global_step = 0

    # Run validation roughly twice per epoch. Clamp to at least 1 so that a
    # dataset smaller than 2 * batch_size does not cause a modulo-by-zero below.
    eval_interval = max(1, n_train // (2 * batch_size))

    # 5. Begin training
    for epoch in range(epochs):
        net.train()
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images = batch['image']
                true_masks = batch['mask']

                assert images.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32)
                # Masks hold integer class indices, as required by CrossEntropyLoss and one_hot
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    # Combined objective: cross-entropy on the logits plus Dice loss on
                    # the softmax probabilities against one-hot, channel-first targets
                    loss = criterion(masks_pred, true_masks) \
                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                       F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                                       multiclass=True)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                # Read the scalar once; every .item() call copies the value to the host
                batch_loss = loss.item()
                experiment.log({
                    'train loss': batch_loss,
                    'step': global_step,
                    'epoch': epoch
                })
                
                pbar.set_postfix(**{'loss (batch)': batch_loss})

                # Evaluation round
                if global_step % eval_interval == 0:
                    histograms = {}
                    for tag, value in net.named_parameters():
                        tag = tag.replace('/', '.')
                        histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                        # Frozen or unused parameters have no gradient to plot
                        if value.grad is not None:
                            histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                    val_score = evaluate(net, val_loader, device)
                    scheduler.step(val_score)

                    logging.info('Validation Dice score: {}'.format(val_score))
                    experiment.log({
                        'learning rate': optimizer.param_groups[0]['lr'],
                        'validation Dice': val_score,
                        'images': wandb.Image(images[0].cpu()),
                        'masks': {
                            'true': wandb.Image(true_masks[0].float().cpu()),
                            'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
                        },
                        'step': global_step,
                        'epoch': epoch,
                        **histograms
                    })

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            # Take the timestamp once so the log line matches the file name exactly.
            # The file name carries the zero-based epoch index; the log message is one-based.
            timestamp = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
            torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}_{}.pth'.format(epoch, timestamp)))
            logging.info(f'Checkpoint {epoch + 1} saved at {timestamp}!')

In [12]:
# --- Hyperparameters for this run ---
epochs = 10        # number of epochs
batch_size = 1     # batch size
lr = 0.0001        # learning rate
load = False       # load model from a .pth file
scale = 1          # downscaling factor of the image
amp = False        # use mixed precision

# NOTE(review): these two values are only recorded in the W&B config below. The
# train_net(...) call in this notebook does not pass them, so the optimizer runs
# with weight decay 1e-8 and momentum 0.9 - confirm which momentum is intended.
weight_decay = 1e-8
momentum = 0.8

# Mirror every setting into the active W&B run configuration.
config = wandb.config
for _name, _value in (
    ("epochs", epochs),
    ("batch_size", batch_size),
    ("lr", lr),
    ("load", load),
    ("scale", scale),
    ("amp", amp),
    ("weight_decay", weight_decay),
    ("momentum", momentum),
    ("model", "Unet-mask-pretrained"),
):
    setattr(config, _name, _value)

In [13]:
# Emit INFO-level messages as "<LEVEL>: <message>" and report the compute device.
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

logging.info(f'Using device {device}')

INFO: Using device cuda


In [14]:
# Fetch the 'unet_carvana' entry point (with pretrained weights) from the
# milesial/Pytorch-UNet repository through torch.hub.
net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True)
# Move the model to the device selected earlier instead of a hard-coded "cuda",
# so the notebook's CPU fallback also works on machines without a GPU.
net.to(device=device)

Using cache found in /root/.cache/torch/hub/milesial_Pytorch-UNet_master


UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, moment

In [15]:
# Summarise the loaded network: channel counts and the upsampling flavour.
upscaling_kind = "Bilinear" if net.bilinear else "Transposed conv"
network_summary = '\n'.join([
    'Network:',
    f'\t{net.n_channels} input channels',
    f'\t{net.n_classes} output channels (classes)',
    f'\t{upscaling_kind} upscaling',
])
logging.info(network_summary)

INFO: Network:
	3 input channels
	2 output channels (classes)
	Bilinear upscaling


In [16]:
# Free memory before training starts: first drop unreachable Python objects,
# then ask PyTorch to release the cached GPU memory it no longer uses.
# The order matters - tensors must be collected before their memory can be released.
gc.collect()
torch.cuda.empty_cache()

In [17]:
# Gather the settings from the cells above and launch training. Interrupting
# the kernel saves the current weights before the cell exits.
training_kwargs = dict(
    net=net,
    device=device,
    epochs=epochs,
    batch_size=batch_size,
    learning_rate=lr,
    img_scale=scale,
    amp=amp,
)

try:
    train_net(**training_kwargs)
except KeyboardInterrupt:
    # Keep whatever progress was made so the run can be resumed or inspected
    torch.save(net.state_dict(), 'INTERRUPTED.pth')
    logging.info('Saved interrupt')
    sys.exit(0)

INFO: Creating dataset with 97 examples
INFO: Creating dataset with 23 examples


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

INFO: Starting training:
        Epochs:          10
        Batch size:      1
        Learning rate:   0.0001
        Training size:   97
        Validation size: 23
        Checkpoints:     True
        Device:          cuda
        Images scaling:  1
        Mixed Precision: False
    
Epoch 1/10:  49%|████▉     | 48/97 [00:09<00:09,  5.23img/s, loss (batch)=0.132] 
Validation round:   0%|          | 0/23 [00:00<?, ?batch/s][A
Validation round:   9%|▊         | 2/23 [00:00<00:01, 13.28batch/s][A
Validation round:  17%|█▋        | 4/23 [00:00<00:01, 15.93batch/s][A
Validation round:  26%|██▌       | 6/23 [00:00<00:01, 16.98batch/s][A
Validation round:  35%|███▍      | 8/23 [00:00<00:00, 17.53batch/s][A
Validation round:  43%|████▎     | 10/23 [00:00<00:00, 17.89batch/s][A
Validation round:  52%|█████▏    | 12/23 [00:00<00:00, 18.12batch/s][A
Validation round:  61%|██████    | 14/23 [00:00<00:00, 18.25batch/s][A
Validation round:  70%|██████▉   | 16/23 [00:00<00:00, 18.33batc

Validation round:   9%|▊         | 2/23 [00:00<00:01, 13.91batch/s][A
Validation round:  17%|█▋        | 4/23 [00:00<00:01, 16.09batch/s][A
Validation round:  26%|██▌       | 6/23 [00:00<00:00, 17.03batch/s][A
Validation round:  35%|███▍      | 8/23 [00:00<00:00, 17.56batch/s][A
Validation round:  43%|████▎     | 10/23 [00:00<00:00, 17.86batch/s][A
Validation round:  52%|█████▏    | 12/23 [00:00<00:00, 18.08batch/s][A
Validation round:  61%|██████    | 14/23 [00:00<00:00, 18.21batch/s][A
Validation round:  70%|██████▉   | 16/23 [00:00<00:00, 18.28batch/s][A
Validation round:  78%|███████▊  | 18/23 [00:01<00:00, 18.33batch/s][A
Validation round:  87%|████████▋ | 20/23 [00:01<00:00, 18.34batch/s][A
Validation round:  96%|█████████▌| 22/23 [00:01<00:00, 18.39batch/s][A
                                                                    [AINFO: Validation Dice score: 0.6667773127555847
Epoch 4/10: 100%|██████████| 97/97 [00:22<00:00,  4.37img/s, loss (batch)=0.0353]
INFO: Check

Validation round:  43%|████▎     | 10/23 [00:00<00:00, 17.79batch/s][A
Validation round:  52%|█████▏    | 12/23 [00:00<00:00, 17.98batch/s][A
Validation round:  61%|██████    | 14/23 [00:00<00:00, 18.07batch/s][A
Validation round:  70%|██████▉   | 16/23 [00:00<00:00, 18.14batch/s][A
Validation round:  78%|███████▊  | 18/23 [00:01<00:00, 18.19batch/s][A
Validation round:  87%|████████▋ | 20/23 [00:01<00:00, 18.25batch/s][A
Validation round:  96%|█████████▌| 22/23 [00:01<00:00, 18.27batch/s][A
                                                                    [AINFO: Validation Dice score: 0.7390815019607544
Epoch 8/10:  92%|█████████▏| 89/97 [00:18<00:01,  5.20img/s, loss (batch)=0.0329]
Validation round:   0%|          | 0/23 [00:00<?, ?batch/s][A
Validation round:   9%|▊         | 2/23 [00:00<00:01, 13.62batch/s][A
Validation round:  17%|█▋        | 4/23 [00:00<00:01, 15.93batch/s][A
Validation round:  26%|██▌       | 6/23 [00:00<00:01, 16.94batch/s][A
Validation round:  