In [None]:
# data.py

import numpy as np
import torch
import os
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class SegmentationDataset(Dataset):
    """Dataset suitable for segmentation tasks.

        Each sample is an (image, mask) pair. The image and the mask share the same
        filename and live in two separate directories. Both are read as single-channel
        ('L') images. The image is returned as a float tensor with a leading channel
        axis, and the mask as a long tensor of class labels.
    """

    def __init__(self, image_dir, mask_dir, filenames, transform=None, device=torch.device('cuda:0')):
        """Constructor.

            Args:
                image_dir: The directory containing the images
                mask_dir: The directory containing the masks
                filenames: The filenames of the images associated with this dataset (the
                    same names are looked up in mask_dir)
                transform: Optional transform to be applied on a sample (default: None).
                    NOTE(review): the transform is stored but __getitem__ does not
                    currently apply it.
                device: The device on which tensors should be created (default: 'cuda:0')
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.filenames = filenames
        # Normalisation is applied as (value - shift) / norm, only when self.normalise is True
        self.shift = 0.
        self.norm = 255.
        self.normalise = False
        self.device = device

    def set_image_stats(self, shift, norm):
        """Set the image normalisation parameters.

            Applied as norm_val = (val-shift) / norm

            Args:
                shift: The shift parameter (e.g. mean)
                norm: The normalisation parameter (e.g. standard deviation)
        """
        self.shift = shift
        self.norm = norm

    def set_normalisation(self, norm=True):
        """Sets whether or not the image should be normalised.

            Args:
                norm: True if the image should be normalised, False otherwise (default: True)
        """
        self.normalise = norm

    def __len__(self):
        """Retrieve the number of samples in the dataset.

            Returns:
                The number of samples in the dataset
        """
        return len(self.filenames)

    def __getitem__(self, idx):
        """Retrieve a sample from the dataset.

            Args:
                idx: The index of the sample to be retrieved

            Returns:
                The tuple (image, mask). image is a float tensor with a leading channel
                axis of size 1, normalised if enabled. mask is a long tensor of labels.
        """
        filename = self.filenames[idx]
        img_name = os.path.join(self.image_dir, filename)
        image = np.asarray(self.open_image(img_name)).astype(np.float32)

        mask_name = os.path.join(self.mask_dir, filename)
        # When using categorical cross entropy, need an un-normalised long
        mask = np.asarray(self.open_image(mask_name)).astype(np.int_)

        # Add the channel axis expected by the convolutional network
        image = torch.as_tensor(np.expand_dims(image, axis=0), device=self.device, dtype=torch.float)
        if self.normalise:
            image -= self.shift
            image /= self.norm
        mask = torch.as_tensor(mask, device=self.device, dtype=torch.long)

        return (image, mask)

    def open_image(self, path):
        """Retrieve an image as a single-channel ('L' mode) PIL image.

            The file is closed before returning; the returned image is fully loaded
            into memory.

            Args:
                path: The path of the image

            Returns:
                The image
        """
        from PIL import Image
        with Image.open(path) as img:
            # convert() and copy() both load the pixel data and return a new image,
            # so the result stays valid after the file is closed on exiting the with-block.
            if img.mode != 'L':
                return img.convert('L')
            return img.copy()

class SegmentationBunch():
    """Associates batches of training, validation and testing datasets suitable
        for segmentation tasks.

        The files directly inside root_dir/image_dir are split at random into
        training and validation sets, using NumPy's global RNG. Masks are looked up
        under root_dir/mask_dir using the same filenames.
    """

    def __init__(self, root_dir, image_dir, mask_dir, batch_size, valid_pct=0.25,
                 test_pct=0.0, transform=None, device=torch.device('cuda:0')):
        """Constructor.

            Args:
                root_dir: The top-level directory containing the images
                image_dir: The relative directory containing the images
                mask_dir: The relative directory containing the masks
                batch_size: The batch size
                valid_pct: The fraction of images to be used for validation (default: 0.25)
                test_pct: The fraction of images to be used for testing (default: 0.0).
                    NOTE(review): only used in the sanity check below. No test split is
                    built and test_dl is always None.
                transform: Any transforms to be applied to the images (default: None)
                device: The device on which tensors should be created (default: 'cuda:0')

            Raises:
                FileNotFoundError: If root_dir/image_dir is not a directory
        """
        assert((valid_pct + test_pct) < 1.)
        image_dir = os.path.join(root_dir, image_dir)
        mask_dir = os.path.join(root_dir, mask_dir)
        if not os.path.isdir(image_dir):
            # next(os.walk(...)) would otherwise fail with an unhelpful StopIteration
            raise FileNotFoundError(f"Image directory not found: {image_dir}")
        # os.walk yields names in filesystem order, which is not guaranteed to be stable.
        # Sort them so that a seeded permutation always produces the same split.
        image_filenames = np.array(sorted(next(os.walk(image_dir))[2]))

        n_files = len(image_filenames)
        valid_size = int(n_files * valid_pct)
        sample = np.random.permutation(n_files)
        # The first valid_size shuffled indices form the validation set; the rest are training
        train_filenames = image_filenames[sample[valid_size:]]
        valid_filenames = image_filenames[sample[:valid_size]]
        print(valid_filenames[0:10])

        self.train_dl = self._make_loader(image_dir, mask_dir, train_filenames, transform,
                                          device, batch_size)
        self.valid_dl = self._make_loader(image_dir, mask_dir, valid_filenames, transform,
                                          device, batch_size)

        self.test_dl = None

    def _make_loader(self, image_dir, mask_dir, filenames, transform, device, batch_size):
        """Wrap a normalised SegmentationDataset over filenames in a DataLoader."""
        dataset = SegmentationDataset(image_dir, mask_dir, filenames, transform, device)
        dataset.set_image_stats(*self.image_stats())
        dataset.set_normalisation(True)
        # NOTE(review): shuffle=False is used for training too, so batches arrive in the same
        # order every epoch; drop_last=True discards a trailing partial batch. Confirm both.
        return DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=0)

    def image_stats(self):
        """Retrieve the normalisation statistics.

            Returns:
                The normalisation statistics as the tuple (shift, norm)
        """
        return 0., 255.

    def count_classes(self, num_classes):
        """Count the number of instances of each class in the training set

            Iterates the whole training loader, so only pixels from batches it yields are
            counted (drop_last=True excludes any trailing partial batch).

            Args:
                num_classes: The number of classes in the training set

            Returns:
                A float NumPy array of the number of instances of each class
        """
        count = np.zeros(num_classes)
        for _, truth in self.train_dl:
            labels, counts = torch.unique(truth, return_counts=True)
            for label, n in zip(labels.tolist(), counts.tolist()):
                count[label] += n
        return count

In [None]:
# model.py

import torch.nn as nn
import torch


def maxpool():
    """Build the down-sampling layer used between contracting-path blocks.

        A 2x2 max pool with stride 2 and no padding, which halves height and width
        (rounding down).

        Returns:
            An nn.MaxPool2d instance
    """
    pool = nn.MaxPool2d(2, stride=2, padding=0)
    return pool


def dropout(prob):
    """Build a dropout layer.

        Args:
            prob: The probability that each element is zeroed during training.

        Returns:
            An nn.Dropout instance
    """
    layer = nn.Dropout(prob)
    return layer


def reinit_layer(seq_block, leak = 0.0, use_kaiming_normal=True):
    """Reinitialises convolutional layer weights and zeroes their biases.

        The default Kaiming initialisation in PyTorch is not optimal; this method
        reinitialises every Conv2d / ConvTranspose2d in the block using Kaiming
        initialisation with the given leak, and sets any bias to zero.

        Args:
            seq_block: An iterable of layers (e.g. nn.Sequential) to be reinitialised.
            leak: The leakiness of ReLU, passed as Kaiming's negative slope ``a`` (default: 0.0)
            use_kaiming_normal: Use Kaiming normal if True, Kaiming uniform otherwise (default: True)
    """
    init_weight = nn.init.kaiming_normal_ if use_kaiming_normal else nn.init.kaiming_uniform_
    for layer in seq_block:
        if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
            init_weight(layer.weight, a = leak)
            # Zero the bias for both schemes. Previously only the uniform branch did this, and
            # it crashed on layers created with bias=False.
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)


class ConvBlock(nn.Module):
    """A convolution block.

        Two stride-1 2D convolutions, each followed by GroupNorm with 8 groups, with a
        ReLU in between. No activation follows the second normalisation. With the
        default kernel size and padding the spatial size is preserved.
    """

    def __init__(self, c_in, c_out, k_size = 3, k_pad = 1):
        """Constructor.

            Args:
                c_in: The number of input channels
                c_out: The number of output channels (GroupNorm(8, c_out) needs a multiple of 8)
                k_size: The size of the convolution filter (default: 3)
                k_pad: The amount of zero padding around the images (default: 1)
        """
        super().__init__()
        conv_opts = {"kernel_size": k_size, "padding": k_pad, "stride": 1}
        layers = [
            nn.Conv2d(c_in, c_out, **conv_opts),
            nn.GroupNorm(8, c_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_out, c_out, **conv_opts),
            nn.GroupNorm(8, c_out),
        ]
        self.block = nn.Sequential(*layers)
        # Replace PyTorch's default convolution initialisation
        reinit_layer(self.block)

    def forward(self, x):
        """Apply the block.

            Args:
                x: The input feature map

            Returns:
                The output feature map with c_out channels
        """
        out = self.block(x)
        return out

class TransposeConvBlock(nn.Module):
    """A transpose convolution block.

        A stride-2 transposed convolution followed by GroupNorm (8 groups) and a ReLU.
        With the default kernel size and padding (and output_padding=1), the output
        height and width are twice those of the input.
    """

    def __init__(self, c_in, c_out, k_size = 3, k_pad = 1):
        """Constructor.

            Args:
                c_in: The number of input channels
                c_out: The number of output channels (GroupNorm(8, c_out) needs a multiple of 8)
                k_size: The size of the convolution filter (default: 3)
                k_pad: The amount of padding around the images (default: 1)
        """
        super().__init__()
        upsample = nn.ConvTranspose2d(c_in, c_out, kernel_size=k_size, padding=k_pad,
                                      output_padding=1, stride=2)
        layers = [upsample, nn.GroupNorm(8, c_out), nn.ReLU(inplace=True)]
        self.block = nn.Sequential(*layers)
        # Replace PyTorch's default convolution initialisation
        reinit_layer(self.block)

    def forward(self, x):
        """Apply the block.

            Args:
                x: The input feature map

            Returns:
                The up-sampled feature map with c_out channels
        """
        out = self.block(x)
        return out

class Sigmoid(nn.Module):
    """A sigmoid activation whose output can optionally be rescaled to a range.
    """

    def __init__(self, out_range = None):
        """Constructor.

            Args:
                out_range: Optional (low, high) tuple. When given, the sigmoid output in
                    (0, 1) is mapped linearly onto (low, high).
        """
        super().__init__()
        self.low, self.high, self.range = None, None, None
        if out_range is not None:
            self.low, self.high = out_range
            self.range = self.high - self.low

    def forward(self, x):
        """Apply the sigmoid, rescaled if a range was supplied at construction.

            Args:
                x: The input to the layer

            Returns:
                sigmoid(x), or sigmoid(x) * (high - low) + low when a range is set
        """
        squashed = torch.sigmoid(x)
        if self.low is None:
            return squashed
        return squashed * (self.range) + self.low

class ListModule(nn.Module):
    """An indexable, iterable container of sub-modules.

        Children are registered under their string position ("0", "1", ...) so that
        torch.nn.Module tracks their parameters. This lets a network with a
        configurable number of layers keep them in lists.
    """

    def __init__(self, *args):
        """Constructor.

            Args:
                args: The modules to register, in order
        """
        super().__init__()
        for position, child in enumerate(args):
            self.add_module(str(position), child)

    def __getitem__(self, idx):
        """Return the module at position idx.

            Only non-negative indices within range are accepted.

            Args:
                idx: The position of the module

            Returns:
                The requested module

            Raises:
                IndexError: If idx is negative or not less than the number of modules
        """
        size = len(self._modules)
        if not 0 <= idx < size:
            raise IndexError('index {} is out of range'.format(idx))
        return list(self._modules.values())[idx]

    def __iter__(self):
        """Iterate over the registered modules in insertion order."""
        yield from self._modules.values()

    def __len__(self):
        """Return how many modules are registered."""
        return len(self._modules)

class UNet(nn.Module):
    """A U-Net for semantic segmentation.

        Contracting path: depth ConvBlocks. Each is followed by 2x2 max pooling and
        dropout, and the channel count doubles at every level. A bridge ConvBlock joins
        the two arms.

        Expansive path: depth steps of transpose convolution, concatenation with the
        matching contracting-path output, dropout, and a ConvBlock.

        A final 1x1 convolution maps to n_classes channels, followed by the
        (optionally rescaled) Sigmoid. Input height and width should be divisible by
        2**depth so the up-sampled maps match the skip connections.
    """

    def __init__(self, in_dim, n_classes, depth = 4, n_filters = 16, drop_prob = 0.1, y_range = None):
        """Constructor.

            Args:
                in_dim: The number of input channels
                n_classes: The number of classes
                depth: The number of convolution blocks in the downsampling and upsampling arms of the U (default: 4)
                n_filters: The number of filters in the first layer (doubles for each downsample) (default: 16)
                drop_prob: The dropout probability for each layer (default: 0.1)
                y_range: The range of values (low, high) to map to in the output (default: None)
        """
        super().__init__()
        # widths[k] is the channel count at level k of the U
        widths = [n_filters * 2**level for level in range(depth + 1)]

        # Contracting Path (modules are built in the same order as before, so RNG use is unchanged)
        down_inputs = [in_dim] + widths[:depth - 1]
        self.ds_convs = ListModule(*[ConvBlock(c_in, c_out)
                                     for c_in, c_out in zip(down_inputs, widths[:depth])])
        self.ds_maxpools = ListModule(*[maxpool() for _ in range(depth)])
        self.ds_dropouts = ListModule(*[dropout(drop_prob) for _ in range(depth)])

        self.bridge = ConvBlock(n_filters * 2**(depth - 1), n_filters * 2**depth)

        # Expansive Path: each ConvBlock sees widths[k] channels because the up-sampled map
        # (widths[k - 1]) is concatenated with the skip connection (another widths[k - 1])
        up_levels = range(depth, 0, -1)
        self.us_tconvs = ListModule(*[TransposeConvBlock(widths[k], widths[k - 1]) for k in up_levels])
        self.us_convs = ListModule(*[ConvBlock(widths[k], widths[k - 1]) for k in up_levels])
        self.us_dropouts = ListModule(*[dropout(drop_prob) for _ in range(depth)])

        self.output = nn.Sequential(nn.Conv2d(widths[0], n_classes, 1), Sigmoid(y_range))

    def forward(self, x):
        """Forward pass.

            Args:
                x: The input batch; channel dimension is dim 1 (concatenation uses dim=1)

            Returns:
                The per-pixel output with n_classes channels
        """
        skips = []
        res = x

        # Downsample, remembering each block's output for the skip connections
        for conv, pool, drop in zip(self.ds_convs, self.ds_maxpools, self.ds_dropouts):
            res = conv(res)
            skips.append(res)
            res = drop(pool(res))

        res = self.bridge(res)

        # Upsample, consuming skip connections deepest-first
        for tconv, drop, conv in zip(self.us_tconvs, self.us_dropouts, self.us_convs):
            res = torch.cat([tconv(res), skips.pop()], dim=1)
            res = conv(drop(res))

        return self.output(res)

In [None]:
# network.py

import numpy as np
import torch
import torch.optim as opt


def set_seed(seed):
    """Seed NumPy and PyTorch and configure cuDNN for deterministic runs.

        Sets cuDNN to deterministic mode and disables its benchmark auto-tuner, which
        can reduce performance, then seeds NumPy's global RNG and PyTorch's RNG.

        Args:
            seed: The random seed
    """
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True   # Note, can impede performance
    cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)


def get_class_weights(stats):
    """Get the weights for each class

        Each class with a positive count gets a weight inversely proportional to its
        number of instances in the training set. The weights are normalised to sum to 1.
        Classes with a zero (or negative) count get weight 0. Previously a single
        absent class produced an infinite inverse, which turned every weight into NaN.

        Args:
            stats: The number of instances of each class (array-like)

        Returns:
            A list with the weight for each class

        Raises:
            ValueError: If no class has a positive count
    """
    counts = np.asarray(stats, dtype=float)
    inverse = np.zeros_like(counts)
    np.divide(1.0, counts, out=inverse, where=counts > 0)
    total = inverse.sum()
    if total <= 0:
        raise ValueError("get_class_weights requires at least one class with a positive count")
    return list(inverse / total)


def load_model(filename, num_classes, weights, device):
    """Load a model

        Builds a UNet (1 input channel, depth 4, 16 base filters, output range
        (0, num_classes)). It restores the parameters from filename, moves the model to
        device and puts it in evaluation mode.

        Args:
            filename: The name of the file with the pretrained model parameters
            num_classes: The number of classes available to predict
            weights: The weights to apply to the classes
            device: The device on which to run

        Returns:
            A tuple composed (in order) of the model, loss function, and optimiser
    """
    model = UNet(1, n_classes = num_classes, depth = 4, n_filters = 16, y_range = (0, num_classes))
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(filename, map_location="cpu"))
    # The model is built wherever the default tensor type places it. Move it explicitly
    # so its outputs live on the same device as the class weights of the loss.
    model = model.to(device)
    model.eval()
    loss_fn = nn.CrossEntropyLoss(torch.as_tensor(weights, device=device, dtype=torch.float))
    optim = opt.Adam(model.parameters())
    return model, loss_fn, optim


def save_model(model, input, filename):
    """Save the model

        The model is saved in evaluation mode in two forms:
        - a state dict in "{filename}.pkl", loadable via
            model.load_state_dict(torch.load(PATH))
            model.eval()
        - a TorchScript trace in "{filename}_traced.pt".

        The model's original training/evaluation mode is restored afterwards, so saving
        in the middle of training does not silently leave dropout disabled.

        Args:
            model: The model to save
            input: An example input to the model, used for tracing
            filename: The output filename, without file extension
    """
    was_training = model.training
    eval_model = model.eval()
    try:
        torch.save(eval_model.state_dict(), f"{filename}.pkl")
        torch_script_model = torch.jit.trace(eval_model, input, check_trace=False)
        torch_script_model.save(f"{filename}_traced.pt")
    finally:
        model.train(was_training)


def accuracy(pred, truth, type=None):
    """Fraction of selected pixels whose predicted class matches the truth.

        The predicted class is the argmax over dim 1 of pred. truth has dim 1 squeezed
        if that dimension has size 1.

        Args:
            pred: The network prediction (class scores along dim 1)
            truth: The true class labels
            type: Restrict to pixels whose true class equals this value. The default
                (None) selects every pixel whose true class is non-zero.

        Returns:
            A 0-d float tensor with the accuracy. It is NaN when no pixel is selected,
            because it is the mean of an empty tensor.
    """
    target = truth.squeeze(1)
    if type is None:
        selected = target != 0
    else:
        selected = target == type
    predicted = pred.argmax(dim=1)
    hits = predicted[selected] == target[selected]
    return hits.float().mean()


def create_model(num_classes, weights, device):
    """Create the model

        Builds a freshly initialised UNet (1 input channel, depth 4, 16 base filters,
        output range (0, num_classes)) on device. It is paired with a class-weighted
        cross-entropy loss and an Adam optimiser.

        Args:
            num_classes: The number of classes available to predict
            weights: The weights to apply to the classes
            device: The device on which to run

        Returns:
            A tuple composed (in order) of the model, loss function, and optimiser
    """
    model = UNet(1, n_classes = num_classes, depth = 4, n_filters = 16, y_range = (0, num_classes))
    # Move the model explicitly rather than relying on a global default tensor type,
    # so its outputs share a device with the loss weights below.
    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss(torch.as_tensor(weights, device=device, dtype=torch.float))
    optim = opt.Adam(model.parameters())
    return model, loss_fn, optim

In [None]:
from tqdm.notebook import tqdm

# This line is important for GPU running, otherwise some weights end up on the CPU
torch.set_default_tensor_type(torch.cuda.FloatTensor)

view = "C"
the_seed = 42
gpu = torch.device('cuda:0')
batch_size = 32
NUM_CLASSES = 5   # Standard case with MIP = 1, HIP = 2, SHOWER = 3, DIFFUSE = 4, NULL = 0

set_seed(the_seed)
bunch = SegmentationBunch(f"Images{view}", "Hits", "Truth", batch_size=batch_size, device=gpu)
train_stats = bunch.count_classes(NUM_CLASSES)
weights = get_class_weights(train_stats)
print(weights)

model, loss_fn, optim = create_model(NUM_CLASSES, weights, gpu)

n_epochs = 20
# Per-batch training history (one entry per training batch across all epochs)
train_losses = torch.zeros(n_epochs * len(bunch.train_dl), device=gpu)
val_losses = torch.zeros(n_epochs, device=gpu)
batch_losses = torch.zeros(len(bunch.valid_dl), device=gpu)

# Row 0 holds the overall (non-null) accuracy; row k holds the accuracy for true class k
train_accs = torch.zeros([NUM_CLASSES, n_epochs * len(bunch.train_dl)], device=gpu)
val_accs = torch.zeros([NUM_CLASSES, n_epochs], device=gpu)
batch_accs = torch.zeros([NUM_CLASSES, len(bunch.valid_dl)], device=gpu)


def nan_free_mean(values):
    """Mean of values ignoring NaNs (accuracy() returns NaN when a batch lacks a class)."""
    return torch.mean(values[~torch.isnan(values)])


import time
t0 = time.perf_counter()
i = 0
set_seed(the_seed)
for e in tqdm(range(0, n_epochs), desc="Training"):
    model = model.train()
    n_batches = len(bunch.train_dl)
    for b, batch in enumerate(bunch.train_dl):
        x, y = batch
        pred = model(x)
        loss = loss_fn(pred, y)

        train_losses[i] = loss.item()
        train_accs[0][i] = accuracy(pred, y)
        for k in range(1, NUM_CLASSES):
            train_accs[k][i] = accuracy(pred, y, type = k)

        loss.backward()
        optim.step()
        optim.zero_grad()
        i += 1
        if b == (n_batches - 1):
            save_model(model, x, f"unet_{view}_{e}")
    # The history has one entry per batch, so this epoch's entries span n_batches
    # (not batch_size) positions
    start = e * n_batches
    finish = start + n_batches
    print(f"Train - Loss: {train_losses[start:finish].mean():.3f} MIP: {nan_free_mean(train_accs[1][start:finish]):.3f} HIP: {nan_free_mean(train_accs[2][start:finish]):.3f} Shower: {nan_free_mean(train_accs[3][start:finish]):.3f}  Diffuse: {nan_free_mean(train_accs[4][start:finish]):.3f}")

    # Validate
    model = model.eval()
    with torch.no_grad():
        for b, batch in enumerate(bunch.valid_dl):
            x, y = batch
            pred = model(x)
            loss = loss_fn(pred, y)

            batch_losses[b] = loss.item()
            batch_accs[0][b] = accuracy(pred, y)
            for k in range(1, NUM_CLASSES):
                batch_accs[k][b] = accuracy(pred, y, type = k)
        val_losses[e] = torch.mean(batch_losses)
        # Average each class over its own non-NaN batches
        for k in range(NUM_CLASSES):
            val_accs[k][e] = nan_free_mean(batch_accs[k])
    print(f"Valid - Loss: {val_losses[e]:.3f} MIP: {val_accs[1][e]:.3f} HIP: {val_accs[2][e]:.3f} Shower: {val_accs[3][e]:.3f}  Diffuse: {val_accs[4][e]:.3f}")
t1 = time.perf_counter()
print(f"Networked trained in {t1 - t0:0.3f} s")

In [None]:
import scipy.stats as stats
binning = np.linspace(0, NUM_CLASSES, NUM_CLASSES + 1, dtype=int)

model = model.to(gpu)
# Evaluation mode disables dropout; no_grad avoids building an autograd graph per batch
model.eval()
# confusion[p, t] counts pixels predicted as class p whose true class is t
# (binned_statistic_2d bins its first argument along the first axis)
confusion = np.zeros((NUM_CLASSES,NUM_CLASSES))
with torch.no_grad():
    for img, cls in bunch.valid_dl:
        img = img.to(gpu)
        output = model(img)
        _, preds = torch.max(output, 1)

        cls_detached = cls.cpu().numpy().flatten()
        preds_detached = preds.cpu().numpy().flatten()

        H, *_ = stats.binned_statistic_2d(preds_detached, cls_detached, None,
                                          bins=[binning, binning], statistic='count')
        confusion += H

In [None]:
# Display the integer bin edges used to histogram class labels into `confusion`
binning

In [None]:
# Collapse the confusion matrix into a 2x2 grouping that ignores class 0 (null).
# Classes 1-2 form the first group and classes 3-4 the second (MIP/HIP vs SHOWER/DIFFUSE
# under the 5-class legend used for NUM_CLASSES).
# NOTE(review): `confusion` is filled as [predicted, true] by the binned_statistic_2d call,
# so e12 is "predicted in group 1, truly in group 2". Confirm this orientation is intended.
e11 = np.sum(confusion[1:3,1:3])
e12 = np.sum(confusion[1:3,3:])
e21 = np.sum(confusion[3:,1:3])
e22 = np.sum(confusion[3:,3:])

In [None]:
# Assemble the grouped counts into the 2x2 matrix [[e11, e12], [e21, e22]]
reduced_confusion = np.vstack((np.hstack((e11,e12)), np.hstack((e21, e22))))

In [None]:
# Normalise each row of the 2x2 matrix to sum to 1. `sums` repeats each row total across
# its row, so the element-wise division scales row by row (an all-zero row gives NaN).
sums = np.sum(reduced_confusion, axis=1).repeat(2).reshape((2,2))
reduced_confusion /= sums

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Step plot of the fraction of non-null-predicted pixels falling in each class 1..NUM_CLASSES-1.
# NOTE(review): rows of `confusion` are predicted classes and axis=1 sums over true classes,
# so the x axis here is the *predicted* class even though it is labelled 'true class'.
# If the true-class distribution was intended, sum confusion[:, 1:] over axis=0 instead. Confirm.
temporary = confusion.copy()
fig = plt.figure(figsize=(15, 10))
plt.xlabel('true class')
plt.ylabel('fraction')
plt.step(list(np.arange(1, NUM_CLASSES)), np.sum(temporary[1:], axis=1) / np.sum(temporary[1:]), where="mid")

In [None]:
# Normalise each row of `confusion` to sum to 1. Rows are predicted classes, so afterwards
# row p holds the distribution of true classes among pixels predicted as p.
# NOTE(review): a class that is never predicted has an all-zero row and becomes NaN here.
sums = np.sum(confusion, axis=1).repeat(NUM_CLASSES).reshape((NUM_CLASSES,NUM_CLASSES))
confusion /= sums

In [None]:
# Print each diagonal entry as a percentage of its row total.
# NOTE(review): with [predicted, true] indexing this is the purity of predicted class t
# (the fraction of pixels predicted as t that are truly t). It is not the per-true-class
# accuracy that accuracy() computes during training. Confirm which metric is wanted.
print(f"--- Class Accuracy")
for t in range(confusion.shape[0]):
    print(f"{t:2}: {100*(confusion[t,t] / confusion[t].sum()):.1f}")
print()

In [None]:
# Remove the null class (0) from both axes, then renormalise the remaining rows
confusion[0,:] = 0
confusion[:,0] = 0

sums = np.sum(confusion, axis=1).repeat(NUM_CLASSES).reshape((NUM_CLASSES,NUM_CLASSES))
# Row 0 is now all zeros; divide it by 1 instead of 0 to avoid NaNs
sums[0,:] = 1
confusion /= sums

# Rows 1.. now sum to 1 (assuming each of those classes was predicted at least once),
# so the diagonal entry is printed directly as a percentage
print(f"--- Class Accuracy")
for t in range(confusion.shape[0]):
    print(f"{t:2}: {100*confusion[t,t]:.1f}")
print()

In [None]:
# For each predicted class 1..NUM_CLASSES-1, the (row-normalised) fraction whose true class is 1 or 2
np.sum(confusion[1:,1:3], axis=1)

In [None]:
# Show the normalised matrix. imshow puts the row index (predicted class) on the y axis and
# the column index (true class) on the x axis, matching the axis labels below.
fig = plt.figure(figsize=(15, 15))
plt.xlabel('truth')
plt.ylabel('network')
plt.imshow(confusion)
plt.colorbar()

# Also print the non-null part of the matrix, one row (predicted class) per line
for t in range(1, NUM_CLASSES):
    for n in range(1, NUM_CLASSES):
        print(f"{confusion[t, n]:.2f}", end=" ")
    print()

In [None]:
# Print the 2x2 grouped matrix, one row per line
for t in range(2):
    for n in range(2):
        print(f"{reduced_confusion[t, n]:.2f}", end=" ")
    print()