In [8]:
from typing import List

import numpy as np
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.base.modules import Activation
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets
from torchvision.transforms import transforms

from baal import ActiveLearningDataset

import argparse
from copy import deepcopy
from pprint import pprint

import torch.backends
from PIL import Image
from torch import optim
from torchvision.transforms import transforms
from tqdm import tqdm

import baal
# from baal import get_heuristic, ActiveLearningLoop
# from baal.bayesian.dropout import MCDropoutModule
# from baal import ModelWrapper
# from baal import ClassificationReport
# from baal import PILToLongTensor

try:
    import segmentation_models_pytorch as smp
except ImportError:
    raise Exception("This example requires `smp`.\n pip install segmentation_models_pytorch")

import torch
import torch.nn.functional as F
import numpy as np

In [12]:
from pathlib import Path

In [3]:
!pip install baal
!pip install segmentation-models-pytorch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [13]:
class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
        dropout = nn.Dropout2d(0.5)
        conv2d = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
        )
        upsampling = (
            nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        )
        activation = Activation(activation)
        super().__init__(dropout, conv2d, upsampling, activation)


def add_dropout(
    model: smp.Unet,
    decoder_channels: List[int] = (256, 128, 64, 32, 16),
    classes=1,
    activation=None,
):
    seg_head = SegmentationHead(
        in_channels=decoder_channels[-1],
        out_channels=classes,
        activation=activation,
        kernel_size=3,
    )
    model.add_module("segmentation_head", seg_head)
    model.initialize()


class FocalLoss(nn.Module):
    """
    References:
        Author: clcarwin
        Site https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = logpt.data.exp()

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            select = (target != 0).type(torch.LongTensor).to(self.alpha.device)
            at = self.alpha.gather(0, select.data.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [18]:
def mean_regions(n, grid_size=16):
    # Compute the mean uncertainty per regions.
    # [batch_size, W, H]
    n = torch.from_numpy(n[:, None, ...])
    # [Batch_size, 1, grid, grid]
    out = F.adaptive_avg_pool2d(n, grid_size)
    return np.mean(out.view([-1, grid_size**2]).numpy(), -1)


class ArrayDataset(torch.utils.data.Dataset):
    def __init__(self, array, image_transforms=None, both_transforms=None):
        self.array = array
                
        self.image_transforms = image_transforms
        self.segment_transforms = transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.array)

    def __getitem__(self, index):
        id, imagePath, segmentPath = self.array[index]
        image = Image.open(imagePath).convert('RGB')
        segment = Image.open(segmentPath).convert('L')
        segment = self.segment_transforms(segment)        
            
        if self.image_transforms is not None:
            image = self.image_transforms(image)
        
        # print(image.shape, segment.shape)

        if image.shape != (3, 448, 448):
            print(f"Image shape is {image.shape}")
        if segment.shape != (1, 448, 448):
            print(f"Segment shape is {segment.shape}")
        
        return image, segment
    
    def split(self, p=0.5):
        count = len(self.array)
        index = np.arange(count)
        first = int(count * p)
        return [
            ArrayDataset(self.array[index[:first]], 
                    image_transforms=self.image_transforms), 
            ArrayDataset(self.array[index[first:]], 
                    image_transforms=self.image_transforms)
        ]



def get_datasets(initial_pool, path):
    IM_SIZE = 224
    
    transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    X_dir = Path(path)/'train'/'images'
    y_dir = Path(path)/'train'/'masks'

    files = [y for y in X_dir.glob('*')] 

    for i in files:
        assert((y_dir / i.name).exists())

    data = np.array([(i, (y_dir / i.name)) for id, i in enumerate(files)])

    dataset = ArrayDataset(data, image_transforms = transform)

    active_set, test_set = dataset.split(0.8)
    active_set = ActiveLearningDataset(active_set)
    
    active_set.label_randomly(initial_pool)
    return active_set, test_set

In [19]:

p_active_learning_steps = 200
p_batch_size = 8
p_initial_pool = 40
p_query_size = 20
p_lr = 0.001
p_heuristic = "random"
p_reduce="sum"
p_data_path = "./data/roadsegm"
p_iterations=20
p_leaning_epoch=30

batch_size = 8

use_cuda = torch.cuda.is_available()

active_set, test_set = get_datasets(p_initial_pool, p_data_path)

# We will use the FocalLoss
criterion = FocalLoss(gamma=2, alpha=0.25)

# Our model is a simple Unet
model = smp.Unet(
    encoder_name="resnext50_32x4d",
    encoder_depth=5,
    encoder_weights="imagenet",
    decoder_use_batchnorm=False,
    classes=1,
)

# Add a Dropout layerto use MC-Dropout
add_dropout(model, classes=1, activation=None)

# This will enable Dropout at test time.
model = baal.bayesian.dropout.MCDropoutModule(model)

# Put everything on GPU.
if use_cuda:
    model.cuda()


criterion = FocalLoss(gamma=2, alpha=0.25)
# Make an optimizer
optimizer = optim.SGD(model.parameters(), lr=p_lr, momentum=0.9, weight_decay=5e-4)
# Keep a copy of the original weights
initial_weights = deepcopy(model.state_dict())

# Add metrics
model = baal.ModelWrapper(model, criterion)

# Which heuristic you want to use?
# We will use our custom reduction function.
heuristic = baal.ModelWrapper.get_heuristic(p_heuristic, reduction=mean_regions)

# The ALLoop is in charge of predicting the uncertainty and
loop = baal.ActiveLearningLoop(
    active_set,
    model.predict_on_dataset_generator,
    heuristic=heuristic,
    query_size=p_query_size,
    # Instead of predicting on the entire pool, only a subset is used
    max_sample=1000,
    batch_size=batch_size,
    iterations=p_iterations,
    use_cuda=use_cuda,
)
acc = []
for epoch in tqdm(p_active_learning_steps):
    # Following Gal et al. 2016, we reset the weights.
    model.load_state_dict(initial_weights)
    # Train 50 epochs before sampling.
    model.train_on_dataset(
        active_set, optimizer, batch_size, p_leaning_epoch, use_cuda
    )

    # Validation!
    model.test_on_dataset(test_set, batch_size, use_cuda)
    should_continue = loop.step()

    logs = model.get_metrics()
    pprint(logs)
    acc.append(logs)
    if not should_continue:
        break

ValueError: ignored