In [1]:
import sys
import os
import torch
import visdom
import argparse
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
import Augmentor as aug
import numpy as np
import pydensecrf.densecrf as dcrf
import cv2
import random
from torch.utils.data.sampler import SubsetRandomSampler
import scipy

from torch.utils import data
import tqdm
from PIL import Image
import torch.backends.cudnn as cudnn
import torch.backends.cudnn

import matplotlib.pyplot as plt
import joblib
import scipy.io as io
import scipy.misc as m
import imageio as mio

## run / prepare

In [2]:
# Mode switch for the notebook: the data-augmentation cell below is guarded by
# `if not run:`, so augmentation is only (re)generated when this is set to False.
run = True

## data augmentation

**for pix2pix**

Run the bash snippet below to convert the ground-truth (gt) PNG images to JPEG.

In [3]:
# for i in *.png ; do convert "$i" "${i%.*}.jpg" ; done

In [4]:
# base_str = '/home/sanityseeker/Documents/Datasets'
# types = ['train', 'test', 'val']
# nums = [2200, 15, 20]

In [5]:
# for i, t in enumerate(types):
#     p = aug.Pipeline(f'{base_str}/A/{t}/')
#     p.ground_truth(f'{base_str}/B/{t}/')
#     p.flip_left_right(probability=0.6)
#     p.rotate(probability=0.7, max_left_rotation=25, max_right_rotation=25)
#     p.rotate90(probability=0.1)
#     p.zoom_random(probability=0.6, percentage_area=0.8)
#     p.flip_top_bottom(probability=0.5)
#     p.sample(nums[i], multi_threaded=True)

Run the commands below to move the ground-truth (gt) images into a separate folder and rename them so that each gt file has the same name as its training image.

In [6]:
# mv *_groundtruth\_\(1\)* /home/sanityseeker/Documents/Datasets/B/train
# rename \_groundtruth\_\(1\)\_train\_ train\_original\_ *

In [7]:
# !python train.py --dataroot ~/Documents/Datasets --name airports_default_lr --model pix2pix --batch_size 1 --num_threads 4 --lr_decay_iters 1000

## main_aug

In [8]:
# image = Image.open(os.path.join(lol, 'codes', 'vko19bing_1-4.png'))
# result = image.convert('P', palette=Image.ADAPTIVE)
# result.putalpha(0)
# colors = result.getcolors()

# sorted(colors)

In [9]:
def get_airports_labels():
    """Return the colour palette used by the airport masks.

    Returns:
        np.ndarray: integer array of shape ``(10, 3)``; row ``k`` holds the
        RGB colour that stands for class index ``k``.
    """
    palette = (
        (0, 255, 0),      # 0 : green
        (255, 255, 255),  # 1 : white
        (255, 255, 0),    # 2 : yellow
        (0, 255, 255),    # 3 : cyan
        (0, 0, 255),      # 4 : blue
        (0, 0, 0),        # 5 : black
        (255, 0, 0),      # 6 : red
        (255, 0, 255),    # 7 : purple
        (0, 128, 128),    # 8 : teal
        (128, 128, 0),    # 9 : olive
    )
    return np.array(palette)

In [10]:
def encode_label(rgb_image):
    """Convert an RGB mask into a map of class indices.

    Every pixel is assigned the index of the palette colour (see
    ``get_airports_labels``) with the smallest Euclidean distance to it.
    When several colours are equally close, the lowest class index wins.

    Args:
        rgb_image: anything ``np.array`` can turn into an ``(..., 3)`` integer
            array, e.g. an RGB PIL image.

    Returns:
        np.ndarray: integer array with the input's shape minus the last
        (channel) axis, holding class indices.
    """
    classes = get_airports_labels()
    mask = np.array(rgb_image, dtype=int)

    # One vectorised pass per class instead of one Python call per pixel.
    # Squared integer distances are exact and order pixels exactly like the
    # Euclidean norm, so the result matches a per-pixel argmin.
    best_dist = ((mask - classes[0]) ** 2).sum(axis=-1)
    label_mask = np.zeros(mask.shape[:-1], dtype=np.intp)
    for class_idx in range(1, len(classes)):
        dist = ((mask - classes[class_idx]) ** 2).sum(axis=-1)
        # Strict '<' keeps the earliest class on ties, like np.argmin.
        closer = dist < best_dist
        label_mask[closer] = class_idx
        best_dist = np.where(closer, dist, best_dist)
    return label_mask
    
def decode_label(label_img, n_classes=10, save_img_path=None):
    """Turn a 2-D map of class indices back into an RGB image.

    Args:
        label_img: 2-D array-like (or single-channel PIL image) of class indices.
        n_classes: number of leading palette entries to apply. Pixels whose
            value is not in ``range(n_classes)`` keep their raw value in all
            three channels.
        save_img_path: optional path; when given, the RGB array is also
            written there with imageio.

    Returns:
        PIL.Image.Image: the colourised mask in RGB mode.
    """
    label_colours = get_airports_labels()
    label_mask = np.array(label_img)

    # Start each channel from the raw labels and recolour class by class; the
    # selection always uses the untouched `label_mask`, so freshly written
    # colour values can never be mistaken for class indices.
    r = label_mask.copy()
    g = label_mask.copy()
    b = label_mask.copy()
    for ll in range(0, n_classes):
        selected = label_mask == ll
        r[selected] = label_colours[ll, 0]
        g[selected] = label_colours[ll, 1]
        b[selected] = label_colours[ll, 2]

    # Build the image as uint8 up front: previously a float64 buffer was
    # passed to imageio, which is not a valid 8-bit PNG payload.
    rgb = np.stack([r, g, b], axis=-1).astype('uint8')
    img = Image.fromarray(rgb)
    if save_img_path:
        mio.imwrite(save_img_path, rgb)

    return img

In [11]:
def body_loop(image_name, train_imgs, train_labels, result_labels_decoded_path,
              result_labels_encoded_path, result_labels_path, result_imgs_path, crop_num, rot_num,
              random_flip_crop, random_rotation):
    """Generate augmented image / mask pairs for a single source image.

    For each of ``rot_num`` random rotations, ``crop_num`` random flip+crop
    variants are produced. For every variant four files are written, all named
    ``_{image_name}_{i}_{j}.png``:

    * the augmented image            -> ``result_imgs_path``
    * the augmented RGB mask         -> ``result_labels_path``
    * the class-index (encoded) mask -> ``result_labels_encoded_path``
    * the re-colourised encoded mask -> ``result_labels_decoded_path``

    Args:
        image_name: file name present in both ``train_imgs`` and ``train_labels``.
        train_imgs: directory with the source images.
        train_labels: directory with the source RGB masks.
        result_labels_decoded_path, result_labels_encoded_path,
        result_labels_path, result_imgs_path: output directories (must exist).
        crop_num: number of flip+crop variants per rotation.
        rot_num: number of rotations per source image.
        random_flip_crop: random transform applied after the rotation.
        random_rotation: random rotation transform.

    Raises:
        ValueError: if an encoded mask contains a class index above 9.

    Entries of ``train_imgs`` that are not regular files are skipped silently.
    """
    img_train_path = os.path.join(train_imgs, image_name)
    img_label_path = os.path.join(train_labels, image_name)

    if not os.path.isfile(img_train_path):
        return

    def apply_seeded(transform, seed, image):
        # The image and its masks must receive the *same* random transform.
        # Depending on the torchvision version the parameters are drawn from
        # Python's `random` or from the torch RNG, so both are seeded.
        random.seed(seed)
        torch.manual_seed(seed)
        return transform(image)

    # Context managers make sure the source files are closed again.
    with Image.open(img_train_path) as img, Image.open(img_label_path) as label:
        for i in range(rot_num):
            seed = random.randint(0, 2**32)
            rot_img = apply_seeded(random_rotation, seed, img)
            rot_label = apply_seeded(random_rotation, seed, label)

            rot_label_encoded = encode_label(rot_label)
            if np.any(rot_label_encoded > 9):
                raise ValueError

            # Class indices are known to be within 0..9 here, so they fit into
            # an 8-bit single-channel image. (scipy.misc.toimage, used before,
            # no longer exists in current SciPy releases.)
            rot_label_encoded = Image.fromarray(rot_label_encoded.astype(np.uint8))

            for j in range(crop_num):
                seed = random.randint(0, 2**32)
                rot_cropped_img = apply_seeded(random_flip_crop, seed, rot_img)
                rot_cropped_label = apply_seeded(random_flip_crop, seed, rot_label)
                rot_cropped_label_encoded = apply_seeded(random_flip_crop, seed, rot_label_encoded)

                if np.any(np.array(rot_cropped_label_encoded) > 9):
                    raise ValueError

                decode_label(rot_cropped_label_encoded,
                             save_img_path=f'{result_labels_decoded_path}/_{image_name}_{i}_{j}.png')

                rot_cropped_img.save(f'{result_imgs_path}/_{image_name}_{i}_{j}.png')
                rot_cropped_label.save(f'{result_labels_path}/_{image_name}_{i}_{j}.png')
                mio.imwrite(f'{result_labels_encoded_path}/_{image_name}_{i}_{j}.png',
                            np.array(rot_cropped_label_encoded))

                
def augment_data(data_path, prefix='train', crop_size=512, crop_scale=(0.5, 0.8),
                 crop_num = 15, rot_num = 15, rotation_degrees=(3, 20)):
    """Create an augmented copy of one dataset split on disk.

    Source images are read from ``{data_path}/{prefix}`` and their RGB masks
    from ``{data_path}/{prefix}_masks``. Results are written to:

    * ``{prefix}/augmented_{crop_size}``            - augmented images
    * ``{prefix}_masks/augmented_{crop_size}``      - augmented RGB masks
    * ``{prefix}_masks/labels_{crop_size}``         - class-index masks
    * ``{prefix}_masks/labels_decoded_{crop_size}`` - re-colourised class-index masks

    Each source image yields ``rot_num * crop_num`` samples; the images are
    processed in parallel with joblib (``n_jobs=-1``).

    Args:
        data_path: root directory of the dataset.
        prefix: name of the split sub-directory (e.g. ``'train'``).
        crop_size: side length of the square output crops.
        crop_scale: ``scale`` range passed to ``RandomResizedCrop``.
        crop_num: number of random crops per rotation.
        rot_num: number of random rotations per image.
        rotation_degrees: ``degrees`` range passed to ``RandomRotation``.
    """
    train_imgs = os.path.join(data_path, prefix)
    train_labels = os.path.join(data_path, f'{prefix}_masks')
    result_imgs_path = os.path.join(train_imgs, f'augmented_{crop_size}')
    result_labels_path = os.path.join(train_labels, f'augmented_{crop_size}')
    result_labels_encoded_path = os.path.join(train_labels, f'labels_{crop_size}')
    result_labels_decoded_path = os.path.join(train_labels, f'labels_decoded_{crop_size}')

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    for out_dir in (result_imgs_path, result_labels_path,
                    result_labels_encoded_path, result_labels_decoded_path):
        os.makedirs(out_dir, exist_ok=True)

    # The listing also contains the output sub-directory created above;
    # body_loop ignores every entry that is not a regular file.
    filenames = os.listdir(train_imgs)

    random_rotation = transforms.RandomRotation(rotation_degrees, expand=False)
    random_flip_crop = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomResizedCrop(size=crop_size, scale=crop_scale, interpolation=Image.NEAREST)
    ])

    joblib.Parallel(n_jobs=-1, verbose=1)(joblib.delayed(body_loop)
                                          (image_name, train_imgs, train_labels, result_labels_decoded_path,
                                           result_labels_encoded_path, result_labels_path, result_imgs_path,
                                           crop_num, rot_num, random_flip_crop, random_rotation) for image_name in tqdm.tqdm(filenames))
    

In [12]:
# Offline augmentation step: executed only when `run` is False, i.e. when the
# notebook is used to prepare data instead of training.
if not run:
    # NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
    data_path = '/home/sanityseeker/Documents/semantic-segmentation/data/'
    # Both splits get 3 rotations x 20 crops of size 512 per source image.
    for mode in ['train', 'test']:
        augment_data(data_path=data_path, prefix=mode, crop_size=512, rot_num=3, crop_num=20)

### checking cuda

In [1]:
# NOTE(review): the recorded output below is a NameError and the execution
# count is 1, i.e. this cell was last run before the import cell. Run the
# imports first so that `torch` is defined.
torch.cuda.is_available()

NameError: name 'torch' is not defined

## U-NET

**sample logic**

In [14]:
# # a sample down block
# def make_conv_bn_relu(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
#     return [
#         nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,  stride=stride, padding=padding, bias=False),
#         nn.BatchNorm2d(out_channels),
#         nn.ReLU(inplace=True)
#     ]
# self.down1 = nn.Sequential(
#     *make_conv_bn_relu(in_channels, 64, kernel_size=3, stride=1, padding=1 ),
#     *make_conv_bn_relu(64, 64, kernel_size=3, stride=1, padding=1 ),
# )

# # convolutions followed by a maxpool
# down1 = self.down1(x)
# out1 = F.max_pool2d(down1, kernel_size=2, stride=2)

# # a sample up block
# def make_conv_bn_relu(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
#     return [
#         nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,  stride=stride, padding=padding, bias=False),
#         nn.BatchNorm2d(out_channels),
#         nn.ReLU(inplace=True)
#     ]
# self.up4 = nn.Sequential(
#     *make_conv_bn_relu(128,64, kernel_size=3, stride=1, padding=1 ),
#     *make_conv_bn_relu(64,64, kernel_size=3, stride=1, padding=1 )
# )
# self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1, stride=1, padding=0 )

# # upsample out_last, concatenate with down1 and apply conv operations
# out   = F.upsample(out_last, scale_factor=2, mode='bilinear')  
# out   = torch.cat([down1, out], 1)
# out   = self.up4(out)

# # final 1x1 conv for predictions
# final_out = self.final_conv(out)

### unet init

In [15]:
def conv3x3(in_, out):
    """Build a 3x3 convolution whose padding of 1 keeps the spatial size.

    Args:
        in_: number of input channels.
        out: number of output channels.
    """
    return nn.Conv2d(in_channels=in_, out_channels=out, kernel_size=3, padding=1)

class ConvRelu(nn.Module):
    """A 3x3 convolution (see ``conv3x3``) followed by an in-place ReLU."""

    def __init__(self, in_: int, out: int):
        """
        :param in_: number of input channels
        :param out: number of output channels
        """
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the convolution and then the activation to ``x``."""
        return self.activation(self.conv(x))


class DecoderBlock(nn.Module):
    """
    Decoder stage that doubles the spatial resolution of its input.

    Parameters for the deconvolution were chosen to avoid artifacts, following
    https://distill.pub/2016/deconv-checkerboard/
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        """
        :param in_channels: channels of the incoming feature map
        :param middle_channels: channels after the first convolution
        :param out_channels: channels of the upsampled output
        :param is_deconv:
            True - ConvRelu followed by a stride-2 transposed convolution and ReLU
            False - bilinear upsampling followed by two ConvRelu layers
        """
        super().__init__()
        self.in_channels = in_channels

        if not is_deconv:
            layers = [
                nn.Upsample(scale_factor=2, mode='bilinear'),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            ]
        else:
            layers = [
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
                                   padding=1),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the configured upsampling block."""
        return self.block(x)


class UNet11(nn.Module):
    """U-Net style segmentation network with a VGG-11 encoder.

    The encoder re-uses the convolution layers of torchvision's VGG-11
    ``features`` stack; the decoder is built from ``DecoderBlock`` stages with
    skip connections to the matching encoder stage.
    """

    def __init__(self, num_classes=1, num_filters=32, pretrained=False):
        """
        :param num_classes: number of output channels; for ``num_classes > 1``
            the forward pass returns per-class log-probabilities
        :param num_filters: base width of the decoder
        :param pretrained:
            False - no pre-trained network used
            True - encoder pre-trained with VGG11
        """
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)

        self.num_classes = num_classes

        self.encoder = models.vgg11(pretrained=pretrained).features

        # A single ReLU module is shared by all encoder stages (it has no state).
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Sequential(self.encoder[0],
                                   self.relu)

        self.conv2 = nn.Sequential(self.encoder[3],
                                   self.relu)

        self.conv3 = nn.Sequential(
            self.encoder[6],
            self.relu,
            self.encoder[8],
            self.relu,
        )
        self.conv4 = nn.Sequential(
            self.encoder[11],
            self.relu,
            self.encoder[13],
            self.relu,
        )

        self.conv5 = nn.Sequential(
            self.encoder[16],
            self.relu,
            self.encoder[18],
            self.relu,
        )

        # The center block consumes the pooled conv5 output, which always has
        # 512 channels for VGG-11. The previous expression
        # `256 + num_filters * 8` only equals 512 for num_filters == 32 and
        # produced a channel mismatch for every other width.
        self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8, is_deconv=True)
        # Every decoder stage receives [previous decoder output, encoder skip].
        self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv=True)
        self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 4, is_deconv=True)
        self.dec3 = DecoderBlock(256 + num_filters * 4, num_filters * 4 * 2, num_filters * 2, is_deconv=True)
        self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv=True)
        self.dec1 = ConvRelu(64 + num_filters, num_filters)

        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Compute the segmentation output for a batch ``x``.

        Returns log-softmax scores over the channel axis when
        ``num_classes > 1``, otherwise the raw output of the final 1x1
        convolution.
        """
        # Encoder: each stage works at half the resolution of the previous one.
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.pool(conv1))
        conv3 = self.conv3(self.pool(conv2))
        conv4 = self.conv4(self.pool(conv3))
        conv5 = self.conv5(self.pool(conv4))
        center = self.center(self.pool(conv5))

        # Decoder: upsample and concatenate with the matching encoder output.
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))

        if self.num_classes > 1:
            x_out = F.log_softmax(self.final(dec1), dim=1)
        else:
            x_out = self.final(dec1)

        return x_out



## loss function

In [16]:
def _cuda(x):
    return x.cuda(async=True) if torch.cuda.is_available() else x

class LossMulti:
    """Multi-class segmentation loss: NLL blended with a soft Jaccard term.

    ``loss = (1 - jaccard_weight) * NLL - jaccard_weight * sum_c log(J_c)``

    where ``J_c`` is the soft Jaccard index of class ``c`` computed over the
    whole batch. The model output is expected to hold log-probabilities along
    dimension 1 (they are exponentiated for the Jaccard term).
    """

    def __init__(self, jaccard_weight=0, class_weights=None, num_classes=10):
        """
        :param jaccard_weight: weight of the Jaccard term; 0 gives plain NLL
        :param class_weights: optional numpy array of per-class NLL weights
        :param num_classes: number of classes included in the Jaccard term
        """
        if class_weights is not None:
            nll_weight = _cuda(
                torch.from_numpy(class_weights.astype(np.float32)))
        else:
            nll_weight = None
        # nn.NLLLoss handles (N, C, H, W) inputs itself; NLLLoss2d is only a
        # deprecated alias of it.
        self.nll_loss = nn.NLLLoss(weight=nll_weight)
        self.jaccard_weight = jaccard_weight
        self.num_classes = num_classes

    def __call__(self, outputs, targets):
        """Return the scalar loss for log-probabilities ``outputs`` and integer ``targets``."""
        loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets)

        if self.jaccard_weight:
            eps = 1e-15  # guards against log(0) and division by zero
            for cls in range(self.num_classes):
                jaccard_target = (targets == cls).float()
                jaccard_output = outputs[:, cls].exp()
                intersection = (jaccard_output * jaccard_target).sum()

                union = jaccard_output.sum() + jaccard_target.sum()
                jaccard = (intersection + eps) / (union - intersection + eps)
                # Out-of-place update: do not modify a tensor that is part of
                # the autograd graph in place.
                loss = loss - torch.log(jaccard) * self.jaccard_weight
        return loss

class CrossEntropyLoss2d(nn.Module):
    """Cross-entropy for dense prediction: log-softmax over channels, then NLL."""

    def __init__(self, weight=None, size_average=True):
        """
        :param weight: optional per-class weight tensor forwarded to the NLL loss
        :param size_average: forwarded to the NLL loss (average vs. sum of the
            per-element losses)
        """
        super(CrossEntropyLoss2d, self).__init__()
        # nn.NLLLoss accepts (N, C, H, W) inputs; NLLLoss2d is a deprecated alias.
        self.nll_loss = nn.NLLLoss(weight=weight, size_average=size_average)

    def forward(self, inputs, targets):
        """Return the loss for raw scores ``inputs`` (classes along dim 1) and integer ``targets``."""
        # The softmax axis is given explicitly: the implicit choice is
        # deprecated. The debug prints that dumped both tensors on every call
        # were removed.
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)

## Dataset

In [17]:
class AirportsDataset(torch.utils.data.Dataset):
    """Airports dataset: images paired with class-index masks of the same file name."""

    def __init__(self, img_dir, labels_dir, num_classes=10):
        """
        Images and corresponding labels are supposed to have same filenames
        Args:
            img_dir (string): Path to images
            labels_dir (string): Path to image labels
            num_classes (int): Number of valid classes; label values must lie
                in ``0 .. num_classes - 1``
        """
        # os.listdir returns entries in arbitrary order. Sorting makes the
        # index -> file mapping stable, so a seeded train/validation split
        # selects the same files on every run and machine.
        self.names = sorted(os.listdir(img_dir))
        self.img_dir = img_dir
        self.labels_dir = labels_dir
        self.num_classes = num_classes

        self.transform = transforms.ToTensor()

    def __len__(self):
        """Number of files found in ``img_dir``."""
        return len(self.names)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict: ``{'image': tensor produced by ToTensor, 'label': LongTensor of class indices}``

        Raises:
            ValueError: if the label file contains a value >= ``num_classes``.
        """
        img_path = os.path.join(self.img_dir, self.names[idx])
        label_path = os.path.join(self.labels_dir, self.names[idx])
        # Close the image file as soon as it has been converted.
        with Image.open(img_path) as pil_img:
            img = self.transform(pil_img)
        label = np.asarray(mio.imread(label_path))
        if np.any(label > self.num_classes - 1):
            raise ValueError('{}: label values {} exceed the largest class index {}'.format(
                label_path, np.unique(label).tolist(), self.num_classes - 1))
        label = torch.from_numpy(label.astype(np.int64))

        sample = {'image': img, 'label': label}
        return sample

In [18]:
def train_val_split(dataset, validation_split=.1, shuffle=True, batch_size=1, seed=1337):
    """Split ``dataset`` into a training and a validation ``DataLoader``.

    Args:
        dataset: any object accepted by ``torch.utils.data.DataLoader``.
        validation_split: fraction of the samples used for validation
            (rounded down to a whole number of samples).
        shuffle: shuffle the indices before splitting; when False the first
            ``validation_split`` part of the dataset becomes the validation set.
        batch_size: batch size of both loaders.
        seed: seed of the shuffle; the same seed always yields the same split.

    Returns:
        tuple: ``(train_loader, validation_loader)``. Both draw from their
        index subset in random order (``SubsetRandomSampler``).
    """
    # Creating data indices for training and validation splits:
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    if shuffle:
        # A private RandomState yields the same permutation as seeding the
        # global generator, but leaves np.random's global state untouched.
        np.random.RandomState(seed).shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                               sampler=train_sampler)
    validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                    sampler=valid_sampler)
    return train_loader, validation_loader

### check dataloader

In [19]:
# Root directory of the dataset; the split folders ('train', 'train_masks',
# 'test', 'test_masks') are looked up below it.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
data_path = '/home/sanityseeker/Documents/semantic-segmentation/data/'

In [20]:
# Datasets over the augmented 512-pixel crops: images come from
# `<split>/augmented_512`, class-index masks from `<split>_masks/labels_512`
# (the directories written by augment_data with crop_size=512).
aug_airports = AirportsDataset(img_dir=os.path.join(data_path, 'train', 'augmented_512'),
                                    labels_dir=os.path.join(data_path, 'train_masks', 'labels_512'))
aug_airports_test = AirportsDataset(img_dir=os.path.join(data_path, 'test', 'augmented_512'),
                                    labels_dir=os.path.join(data_path, 'test_masks', 'labels_512'))

In [21]:
# def plot_figures(figures, nrows = 1, ncols=1):
#     """Plot a dictionary of figures.

#     Parameters
#     ----------
#     figures : <title, figure> dictionary
#     ncols : number of columns of subplots wanted in the display
#     nrows : number of rows of subplots wanted in the figure
#     """

#     fig, axeslist = plt.subplots(ncols=ncols, nrows=nrows)
#     for ind,title in enumerate(figures):
#         axeslist.ravel()[ind].imshow(figures[title], cmap=plt.gray())
#         axeslist.ravel()[ind].set_title(title)
#         axeslist.ravel()[ind].set_axis_off()
#     plt.tight_layout() # optional

# for i in range(1):
#     sample = aug_airports[i]
#     print(i, sample['image'].shape, sample['label'].shape)
#     a = transforms.ToPILImage()
#     b = transforms.ToPILImage()
#     a = a(sample['image'])
#     b = b(sample['label'])
# #     a.show()
# #     b.show()
#     images = [a, b]
#     figures = {'im'+str(i): images[i] for i in range(len(images))}
#     plot_figures(figures, 1, 2)

In [22]:
# Sanity check of the split. train_val_split returns DataLoaders, so the last
# two numbers are batch counts; with the default batch_size=1 they equal the
# number of samples in each subset.
aug_airports_train, aug_airports_val = train_val_split(dataset=aug_airports)
print(len(aug_airports), len(aug_airports_train), len(aug_airports_val))

1800 1620 180


## training

In [23]:
from validation import validation_multi
import json
from pathlib import Path
from datetime import datetime

In [24]:
def write_event(log, step, **data):
    """Append one event to ``log`` as a single line of JSON.

    The record consists of the keyword arguments plus ``step`` and ``dt``
    (current local time in ISO format), serialised with sorted keys. The
    stream is flushed after the line has been written.
    """
    record = dict(data, step=step, dt=datetime.now().isoformat())
    line = json.dumps(record, sort_keys=True)
    log.write(line + '\n')
    log.flush()

In [25]:
def train(args, model, criterion, train_loader, valid_loader, validation, init_optimizer, n_epochs=None, fold=None,
          num_classes=10):
    """Train ``model`` with per-epoch validation, checkpointing and JSON logging.

    Files used below ``args['root']``:

    * ``model_{fold}.pt``  - checkpoint (model weights, next epoch, step
      counter). If it exists, training resumes from it.
    * ``train_{fold}.log`` - one JSON record per line (see ``write_event``).

    Args:
        args: dict with the keys ``'lr'`` (learning rate), ``'root'``
            (output directory) and ``'batch_size'`` (only used for the
            progress-bar total).
        model: network to train.
        criterion: callable ``criterion(outputs, targets)`` returning the loss.
        train_loader: iterable of dicts with the keys ``'image'`` and ``'label'``.
        valid_loader: loader handed on to ``validation``.
        validation: callable ``validation(model, criterion, valid_loader,
            num_classes)`` returning a dict of metrics that contains
            ``'valid_loss'``.
        init_optimizer: callable ``init_optimizer(lr)`` returning the optimizer.
        n_epochs: number of the last epoch to run (inclusive); required.
        fold: identifier used in the checkpoint and log file names.
        num_classes: forwarded to ``validation``.

    Raises:
        TypeError: if ``n_epochs`` is not given.

    A ``KeyboardInterrupt`` during an epoch saves a snapshot for the current
    epoch and returns.
    """
    if n_epochs is None:
        # Fail early with a clear message instead of `None + 1` further down.
        raise TypeError('n_epochs must be an integer, got None')

    lr = args['lr']
    optimizer = init_optimizer(lr)

    root = Path(args['root'])
    model_path = root / 'model_{fold}.pt'.format(fold=fold)
    if model_path.exists():
        # Deserialise onto the CPU; load_state_dict copies the values into the
        # model's own (possibly GPU) parameters, so no second copy of the
        # weights has to be allocated on the GPU.
        state = torch.load(str(model_path), map_location=lambda storage, loc: storage)
        epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        print('Restored model, epoch {}, step {:,}'.format(epoch, step))
    else:
        epoch = 1
        step = 0

    def save(ep):
        # `step` is looked up at call time, so the current counter is stored.
        torch.save({
            'model': model.state_dict(),
            'epoch': ep,
            'step': step,
        }, str(model_path))

    report_each = 10
    log_path = root.joinpath('train_{fold}.log'.format(fold=fold))
    valid_losses = []
    # The context manager closes the log on every exit path.
    with log_path.open('at', encoding='utf8') as log:
        for epoch in range(epoch, n_epochs + 1):
            model.train()
            random.seed()
            tq = tqdm.tqdm(total=(len(train_loader) * args['batch_size']))
            tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
            losses = []
            try:
                mean_loss = 0
                for i, sample in enumerate(train_loader):
                    inputs = _cuda(sample['image'])
                    targets = _cuda(sample['label'])

                    outputs = model(inputs)

                    loss = criterion(outputs, targets)
                    optimizer.zero_grad()
                    batch_size = inputs.size(0)
                    loss.backward()
                    optimizer.step()
                    step += 1
                    tq.update(batch_size)
                    losses.append(loss.item())
                    # Running mean over the last `report_each` batches.
                    mean_loss = np.mean(losses[-report_each:])
                    tq.set_postfix(loss='{:.5f}'.format(mean_loss))
                    if i and i % report_each == 0:
                        write_event(log, step, loss=mean_loss)
                write_event(log, step, loss=mean_loss)
                tq.close()
                # The epoch is complete: a restart continues with the next one.
                save(epoch + 1)
                valid_metrics = validation(model, criterion, valid_loader, num_classes)
                write_event(log, step, **valid_metrics)
                valid_loss = valid_metrics['valid_loss']
                valid_losses.append(valid_loss)
            except KeyboardInterrupt:
                tq.close()
                print('Ctrl+C, saving snapshot')
                # Interrupted mid-epoch: a restart repeats this epoch.
                save(epoch)
                print('done.')
                return


In [26]:
# Training configuration consumed by train().
args = {}
args['lr'] = 1e-3
args['batch_size'] = 1
args['root'] = data_path  # checkpoints and logs are written next to the data

num_classes = 10

# Build the loaders with the configured batch size, so that the progress-bar
# total computed in train() (len(loader) * args['batch_size']) stays correct
# when the setting is changed.
train_loader, val_loader = train_val_split(dataset=aug_airports, batch_size=args['batch_size'])
test_loader = torch.utils.data.DataLoader(aug_airports_test, batch_size=1,
                                          shuffle=False, num_workers=4)

In [27]:
# Report the index of the current CUDA device, whether CUDA is usable, and the
# name of device 0.
print(torch.cuda.current_device(), torch.cuda.is_available(),
torch.cuda.get_device_name(0))

0 True GeForce MX150


In [28]:
# Release cached GPU memory held by PyTorch's allocator before building the model.
torch.cuda.empty_cache()

## Training UNet11 without pretrained weights

In [29]:
# Randomly initialised encoder (pretrained=False).
model = UNet11(num_classes=num_classes, pretrained=False)
# model = unet(n_classes=10)

model = nn.DataParallel(model, device_ids=None).cuda()
# Default LossMulti: jaccard_weight=0, i.e. the plain NLL term only.
loss = LossMulti()
cudnn.benchmark = True

# fold=0 selects the checkpoint / log files model_0.pt and train_0.log under
# args['root']; an existing checkpoint is resumed by train().
train(
        args,
        init_optimizer=lambda l_rate: torch.optim.Adam(model.parameters(), lr=l_rate),
        model=model,
        criterion=loss,
        train_loader=train_loader,
        valid_loader=val_loader,
        validation=validation_multi,
        n_epochs=50,
        fold=0,
        num_classes=num_classes
    )

Epoch 3, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Restored model, epoch 3, step 3,240


Epoch 3, lr 0.001: 100%|██████████| 1620/1620 [22:33<00:00,  1.21it/s, loss=0.73133]
Epoch 4, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.5844, average IoU: 0.3471, average Dice: 0.4036


Epoch 4, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.21it/s, loss=0.79237]
Epoch 5, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.5613, average IoU: 0.3503, average Dice: 0.4175


Epoch 5, lr 0.001: 100%|██████████| 1620/1620 [22:24<00:00,  1.21it/s, loss=0.42548]
Epoch 6, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.4808, average IoU: 0.4007, average Dice: 0.4719


Epoch 6, lr 0.001: 100%|██████████| 1620/1620 [22:24<00:00,  1.21it/s, loss=0.60826]
Epoch 7, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.4586, average IoU: 0.4148, average Dice: 0.4942


Epoch 7, lr 0.001: 100%|██████████| 1620/1620 [22:22<00:00,  1.21it/s, loss=0.55589]
Epoch 8, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.4339, average IoU: 0.4200, average Dice: 0.5022


Epoch 8, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.20it/s, loss=0.51785]
Epoch 9, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.4656, average IoU: 0.4181, average Dice: 0.5001


Epoch 9, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.21it/s, loss=0.45500]
Epoch 10, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.4265, average IoU: 0.4565, average Dice: 0.5522


Epoch 10, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.20it/s, loss=0.57385]
Epoch 11, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.3831, average IoU: 0.4829, average Dice: 0.5800


Epoch 11, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.21it/s, loss=0.71235]
Epoch 12, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.5618, average IoU: 0.3709, average Dice: 0.4357


Epoch 12, lr 0.001: 100%|██████████| 1620/1620 [22:22<00:00,  1.21it/s, loss=0.55908]
Epoch 13, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.3992, average IoU: 0.4703, average Dice: 0.5673


Epoch 13, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.21it/s, loss=0.39655]
Epoch 14, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.3554, average IoU: 0.4888, average Dice: 0.5853


Epoch 14, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.20it/s, loss=0.35687]
Epoch 15, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.3915, average IoU: 0.4634, average Dice: 0.5595


Epoch 15, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.21it/s, loss=0.41313]
Epoch 16, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.4534, average IoU: 0.4219, average Dice: 0.5052


Epoch 16, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.21it/s, loss=1.29949]
Epoch 17, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.7416, average IoU: 0.3493, average Dice: 0.4338


Epoch 17, lr 0.001: 100%|██████████| 1620/1620 [22:22<00:00,  1.20it/s, loss=0.40826]
Epoch 18, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.3334, average IoU: 0.4995, average Dice: 0.5925


Epoch 18, lr 0.001: 100%|██████████| 1620/1620 [22:24<00:00,  1.21it/s, loss=0.24524]
Epoch 19, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.3850, average IoU: 0.4532, average Dice: 0.5375


Epoch 19, lr 0.001: 100%|██████████| 1620/1620 [22:24<00:00,  1.21it/s, loss=0.28088]
Epoch 20, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.3100, average IoU: 0.4996, average Dice: 0.6018


Epoch 20, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.21it/s, loss=0.55087]
Epoch 21, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.4125, average IoU: 0.4523, average Dice: 0.5511


Epoch 21, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.21it/s, loss=0.47439]
Epoch 22, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.4561, average IoU: 0.4097, average Dice: 0.4846


Epoch 22, lr 0.001: 100%|██████████| 1620/1620 [22:24<00:00,  1.21it/s, loss=0.16341]
Epoch 23, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.2990, average IoU: 0.5246, average Dice: 0.6230


Epoch 23, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.21it/s, loss=0.41926]
Epoch 24, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.4055, average IoU: 0.4593, average Dice: 0.5538


Epoch 24, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.21it/s, loss=0.36161]
Epoch 25, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.2915, average IoU: 0.5417, average Dice: 0.6391


Epoch 25, lr 0.001: 100%|██████████| 1620/1620 [22:24<00:00,  1.21it/s, loss=0.32871]
Epoch 26, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.2844, average IoU: 0.5458, average Dice: 0.6391


Epoch 26, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.21it/s, loss=0.25258]
Epoch 27, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.2940, average IoU: 0.5423, average Dice: 0.6416


Epoch 27, lr 0.001: 100%|██████████| 1620/1620 [22:24<00:00,  1.21it/s, loss=0.41701]
Epoch 28, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.2789, average IoU: 0.5660, average Dice: 0.6641


Epoch 28, lr 0.001: 100%|██████████| 1620/1620 [22:24<00:00,  1.21it/s, loss=0.19526]
Epoch 29, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.2785, average IoU: 0.5598, average Dice: 0.6596


Epoch 29, lr 0.001: 100%|██████████| 1620/1620 [22:23<00:00,  1.21it/s, loss=0.22531]
Epoch 30, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.2785, average IoU: 0.5406, average Dice: 0.6410


Epoch 30, lr 0.001: 100%|██████████| 1620/1620 [22:24<00:00,  1.21it/s, loss=0.26725]
Epoch 31, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.2483, average IoU: 0.5765, average Dice: 0.6736


Epoch 31, lr 0.001: 100%|██████████| 1620/1620 [22:25<00:00,  1.21it/s, loss=0.27937]
Epoch 32, lr 0.001:   0%|          | 0/1620 [00:00<?, ?it/s]

Valid loss: 0.2503, average IoU: 0.5697, average Dice: 0.6661


Epoch 32, lr 0.001:  19%|█▊        | 301/1620 [04:09<18:14,  1.21it/s, loss=0.30250]


Ctrl+C, saving snapshot
done.


## Training UNet11 with an encoder pre-trained on VGG-11 weights

In [29]:
# Same setup as the previous experiment, but the encoder starts from
# pre-trained VGG-11 weights (pretrained=True).
model = UNet11(num_classes=num_classes, pretrained=True)
# model = unet(n_classes=10)

model = nn.DataParallel(model, device_ids=None).cuda()
# Default LossMulti: jaccard_weight=0, i.e. the plain NLL term only.
loss = LossMulti()
cudnn.benchmark = True

# fold=1 keeps this run's checkpoint / log (model_1.pt, train_1.log) separate
# from the fold=0 run.
# NOTE(review): the recorded run stopped with a CUDA out-of-memory error after
# the first epoch - cause not visible from this cell; confirm GPU memory
# headroom before re-running.
train(
        args,
        init_optimizer=lambda l_rate: torch.optim.Adam(model.parameters(), lr=l_rate),
        model=model,
        criterion=loss,
        train_loader=train_loader,
        valid_loader=val_loader,
        validation=validation_multi,
        n_epochs=50,
        fold=1,
        num_classes=num_classes
    )

Epoch 1, lr 0.001: 100%|██████████| 1620/1620 [22:31<00:00,  1.21it/s, loss=0.65994]


RuntimeError: cuda runtime error (2) : out of memory at /pytorch/torch/csrc/generic/serialization.cpp:17