In [None]:
cd ..

In [None]:
import os
import random

In [None]:
import numpy as np

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

In [None]:
import echofilter.shardloader

In [None]:
ROOT_DATA_DIR = '/media/scott/scratch/Datasets/dsforce/'

In [None]:
transect_pth = 'Survey17/Survey17_GR1_S3W_F'
timestamps, depths, signals, d_top, d_bot = echofilter.shardloader.load_transect_from_shards_rel(
    transect_pth, 100, 800,
    root_data_dir=ROOT_DATA_DIR,
)

plt.figure(figsize=(12, 12))
plt.pcolormesh(timestamps, -depths, signals.T)
plt.plot(timestamps, -d_bot, 'b')
plt.plot(timestamps, -d_top, 'c')
plt.show()

In [None]:
transect_pth = 'Survey17/Survey17_GR1_S3W_F'
timestamps, depths, signals, d_top, d_bot = echofilter.shardloader.load_transect_from_shards_rel(
    transect_pth, -100, 800,
    root_data_dir=ROOT_DATA_DIR,
)

plt.figure(figsize=(12, 12))
plt.pcolormesh(timestamps, -depths, signals.T)
plt.plot(timestamps, -d_bot, 'b')
plt.plot(timestamps, -d_top, 'c')
plt.show()

In [None]:
transect_pth = 'Survey17/Survey17_GR1_S3W_F'
timestamps, depths, signals, d_top, d_bot = echofilter.shardloader.load_transect_from_shards_rel(
    transect_pth, 0, 128,
    root_data_dir=ROOT_DATA_DIR,
)

plt.figure(figsize=(12, 12))
plt.pcolormesh(timestamps, -depths, signals.T)
plt.plot(timestamps, -d_bot, 'b')
plt.plot(timestamps, -d_top, 'c')
plt.show()

In [None]:
import torch.utils.data

In [None]:
class TransectDataset(torch.utils.data.Dataset):

    def __init__(
            self,
            transect_paths,
            window_len=128,
            crop_depth=70,
            num_windows_per_transect=0,
            use_dynamic_offsets=True,
            transform_pre=None,
            transform_post=None,
            ):
        '''
        Dataset of fixed-width windows cut out of sharded transects.

        Every window is described by a `(transect_path, center_idx)` pair in
        `self.datapoints`. The pairs are computed once, here, from each
        transect's `shard_size.txt`.

        Parameters
        ----------
        transect_paths : list
            Absolute paths to transects. Each directory must hold a
            `shard_size.txt` whose first line reads `<n_timestamps>,<shard_len>`.
        window_len : int
            Width (number of timestamps) of each window. Default is `128`.
        crop_depth : float
            Maximum depth to keep, in metres; deeper rows are dropped.
            Default is `70`.
        num_windows_per_transect : int
            How many windows to take from each transect. Their starts are
            equally spaced across the transect. `0` (or `None`) means "just
            enough windows of `window_len` to cover the transect", so the
            count varies per transect. Default is `0`.
        use_dynamic_offsets : bool
            If `True`, every window of a transect is shifted by one shared
            random offset (training). If `False`, windows sit in the middle of
            their intervals (testing). Default is `True`.
        transform_pre : callable
            Applied to the sample dictionary before the depth crop and mask
            generation. Default is `None`.
        transform_post : callable
            Applied to the sample dictionary after mask generation.
            Default is `None`.
        '''
        super(TransectDataset, self).__init__()
        self.window_len = window_len
        self.crop_depth = crop_depth
        self.num_windows = num_windows_per_transect
        self.use_dynamic_offsets = use_dynamic_offsets
        self.transform_pre = transform_pre
        self.transform_post = transform_post

        self.datapoints = []
        for path in transect_paths:
            n_timestamps = self._read_n_timestamps(path)
            self.datapoints.extend(
                (path, int(center)) for center in self._window_centers(n_timestamps)
            )

    @staticmethod
    def _read_n_timestamps(transect_path):
        # Only the first field (the timestamp count) of shard_size.txt is needed.
        with open(os.path.join(transect_path, 'shard_size.txt'), 'r') as handle:
            n_timestamps, _shard_len = handle.readline().strip().split(',')
            return int(n_timestamps)

    def _window_centers(self, n_timestamps):
        # Explicit window count, or just enough windows to tile the transect.
        if self.num_windows:
            count = self.num_windows
        else:
            count = int(np.ceil(n_timestamps / self.window_len))
        starts = np.linspace(0, n_timestamps, count + 1)[:count]
        spacing = starts[1] - starts[0] if len(starts) > 1 else n_timestamps
        if self.use_dynamic_offsets:
            # Training: one random shift shared by every window of the transect.
            offset = np.random.rand() * spacing
        else:
            # Evaluation: deterministic shift to the middle of each interval.
            offset = spacing / 2
        return np.round(starts + offset)

    def __getitem__(self, index):
        transect_pth, center_idx = self.datapoints[index]
        start = center_idx - int(self.window_len / 2)
        timestamps, depths, signals, d_top, d_bot = echofilter.shardloader.load_transect_from_shards_abs(
            transect_pth,
            start,
            start + self.window_len,
        )
        sample = dict(
            timestamps=timestamps,
            depths=depths,
            signals=signals,
            d_top=d_top,
            d_bot=d_bot,
        )
        if self.transform_pre is not None:
            sample = self.transform_pre(sample)

        # Drop rows deeper than crop_depth, from both the depth axis and the image.
        keep = sample['depths'] <= self.crop_depth
        sample['depths'] = sample['depths'][keep]
        sample['signals'] = sample['signals'][:, keep]

        # Turn the per-timestamp top/bottom lines into per-pixel masks.
        depth_grid = np.broadcast_to(sample['depths'], sample['signals'].shape)
        sample['mask_top'] = np.single(depth_grid < np.expand_dims(sample['d_top'], -1))
        sample['mask_bot'] = np.single(depth_grid > np.expand_dims(sample['d_bot'], -1))

        # Line depths divided by the extent of the (cropped) depth axis.
        depth_span = abs(sample['depths'][-1] - sample['depths'][0])
        sample['r_top'] = sample['d_top'] / depth_span
        sample['r_bot'] = sample['d_bot'] / depth_span

        if self.transform_post is not None:
            sample = self.transform_post(sample)
        return sample

    def __len__(self):
        return len(self.datapoints)

In [None]:
transect_paths = [os.path.join(ROOT_DATA_DIR, 'surveyExports_sharded/Survey17/Survey17_GR1_S3W_F')] * 2

In [None]:
dataset = TransectDataset(transect_paths)

In [None]:
dataset.datapoints

In [None]:
sample = dataset[0]

plt.figure(figsize=(12, 12))
plt.imshow(sample['signals'])
plt.show()
plt.figure(figsize=(12, 12))
plt.imshow(sample['mask_top'])
plt.show()
plt.figure(figsize=(12, 12))
plt.imshow(sample['mask_bot'])
plt.show()

In [None]:
sample['signals'].shape

In [None]:
loader = torch.utils.data.DataLoader(dataset, batch_size=2)

In [None]:
for sample in loader:
    print(sample['signals'].shape)

In [None]:
import skimage.transform

In [None]:
class Rescale(object):
    '''
    Resize the 2-D arrays of a sample to a fixed output shape.

    Only the `'signals'`, `'mask_top'` and `'mask_bot'` entries are resized
    (whichever are present), using `skimage.transform.resize`. All other
    entries, including the 1-D axes and lines, are passed through unchanged.

    Parameters
    ----------
    output_size : tuple or int
        Target shape. A tuple is used as given; an int `n` means `(n, n)`.
    '''

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = (output_size, output_size) if isinstance(output_size, int) else output_size

    def __call__(self, sample):
        resized = {
            key: skimage.transform.resize(
                sample[key],
                self.output_size,
                clip=False,
                preserve_range=False,
            )
            for key in ('signals', 'mask_top', 'mask_bot')
            if key in sample
        }
        sample.update(resized)
        return sample

In [None]:
class Normalize(object):
    '''
    Standardise the intensities of a sample's echogram.

    `sample['signals']` has `mean` subtracted and is then divided by `stdev`.
    Both steps use augmented assignment, so an ndarray is modified in place.
    The in-place division needs a floating-point dtype.

    Parameters
    ----------
    mean : float
        Expected pixel mean of a sample.
    stdev : float
        Expected standard deviation of a sample's pixel intensities.
    '''

    def __init__(self, mean, stdev):
        self.mean, self.stdev = mean, stdev

    def __call__(self, sample):
        signals = sample['signals']
        signals -= self.mean
        signals /= self.stdev
        sample['signals'] = signals
        return sample

In [None]:
class RandomReflection(object):
    '''
    Randomly reflect a sample.

    With probability `p`, `'timestamps'` is reversed and the image and line
    arrays are flipped along `axis`. The flipped arrays are returned as
    contiguous copies. `np.flip` and `[::-1]` on their own produce
    negative-stride views, which `torch.as_tensor` (used by the default
    DataLoader collate function) refuses to convert.

    Parameters
    ----------
    axis : int, optional
        Axis to reflect. Default is 0. Note that `'timestamps'` is always
        reversed, and the 1-D `'d_top'` / `'d_bot'` arrays are flipped along
        this same axis. In practice only axis 0 is therefore consistent.
    p : float, optional
        Probability of reflection. Default is 0.5.
    '''

    def __init__(self, axis=0, p=0.5):
        self.axis = axis
        self.p = p

    def __call__(self, sample):

        if random.random() > self.p:
            # Nothing to do
            return sample

        # Reflect x co-ordinates (copied so the array has positive strides).
        sample['timestamps'] = np.ascontiguousarray(sample['timestamps'][::-1])

        # Reflect data
        for key in ('signals', 'd_top', 'd_bot', 'mask_top', 'mask_bot'):
            if key in sample:
                sample[key] = np.ascontiguousarray(np.flip(sample[key], self.axis))

        return sample

In [None]:
class RandomStretchDepth(object):
    '''
    Randomly stretch or compress the depth axis of a sample.

    Note that this transform doesn't change images, only the `depths`, `d_top`,
    and `d_bot` entries. The sample dictionary is updated in place, but those
    arrays are replaced by new arrays rather than modified. As a result,
    integer-typed inputs work and arrays shared with the caller are left
    untouched.

    Parameters
    ----------
    max_factor : float
        Maximum stretch factor. A number in `[1, 1 + max_factor]` is generated,
        and the depths are then either divided or multiplied by it (each with
        roughly equal probability).
    expected_bottom_gap : float, optional
        Expected gap between actual ocean floor and target bottom line. The
        bottom line is stretched about the point this far below it, and the
        gap is then restored. Default is `1`.
    '''

    def __init__(self, max_factor, expected_bottom_gap=1):
        self.max_factor = max_factor
        self.expected_bottom_gap = expected_bottom_gap

    def __call__(self, sample):

        factor = random.uniform(1.0, 1.0 + self.max_factor)

        if random.random() > 0.5:
            factor = 1. / factor

        sample['depths'] = sample['depths'] * factor
        sample['d_top'] = sample['d_top'] * factor
        # Scale the (expected) ocean floor position, then shift back up by
        # the gap, so the gap itself is not stretched.
        sample['d_bot'] = (sample['d_bot'] + self.expected_bottom_gap) * factor - self.expected_bottom_gap

        return sample

In [None]:
class RandomCropWidth(object):
    '''
    Randomly crop a sample along its width (time) axis.

    A fraction of the columns is drawn uniformly from
    `[0, max_crop_fraction]` and removed, split as evenly as integer rounding
    allows between the two ends. The same slice is applied to `'timestamps'`,
    `'signals'`, the lines and the masks (whichever are present).

    Parameters
    ----------
    max_crop_fraction : float
        Largest fraction of the total width that may be cropped away.
    '''

    def __init__(self, max_crop_fraction):
        self.max_crop_fraction = max_crop_fraction

    def __call__(self, sample):
        n_columns = sample['signals'].shape[0]
        removed = random.uniform(0., self.max_crop_fraction) * n_columns

        first = int(removed / 2)
        keep = slice(first, first + n_columns - int(removed))

        for key in ('timestamps', 'signals', 'd_top', 'd_bot', 'mask_top', 'mask_bot'):
            if key in sample:
                sample[key] = sample[key][keep]
        return sample

In [None]:
class ColorJitter(object):
    '''
    Randomly change the brightness and contrast of a normalized image.

    Brightness (an additive offset) and contrast (a multiplicative factor) are
    applied to `sample['signals']` in a random order. Note that changes are
    made inplace.

    Parameters
    ----------
    brightness : float or tuple of float (min, max)
        How much to jitter brightness. `brightness_factor` is chosen uniformly from
        `[-brightness, brightness]`
        or the given `[min, max]`. `brightness_factor` is then added to the image.
        Default is `0` (no brightness jitter).
    contrast : float or tuple of float (min, max)
        How much to jitter contrast. `contrast_factor` is chosen uniformly from
        `[max(0, 1 - contrast), 1 + contrast]`
        or the given `[min, max]`. Should be non negative numbers.
        The image is then multiplied by `contrast_factor`.
        Default is `0` (no contrast jitter).
    '''
    def __init__(self, brightness=0, contrast=0):
        self.brightness = self._check_input(
            brightness,
            'brightness',
            center=0,
            bound=(float('-inf'), float('inf')),
            clip_first_on_zero=False,
        )
        self.contrast = self._check_input(contrast, 'contrast')

    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        '''
        Validate a jitter argument and convert it to a `[min, max]` range.

        A single number `v` becomes `[center - v, center + v]`. If
        `clip_first_on_zero` is set, the lower end is clipped at 0. A
        `(min, max)` pair is checked to be ordered and within `bound`.

        Returns
        -------
        list, tuple or None
            The range, or `None` if it collapses to `[center, center]`
            (i.e. the operation would be a no-op).

        Raises
        ------
        ValueError
            If a single number is negative, or a pair is unordered or outside
            `bound`.
        TypeError
            If `value` is neither a number nor a length-2 list/tuple.
        '''
        if isinstance(value, (float, int)):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        else:
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))

        if value[0] == value[1] == center:
            value = None
        return value

    def __call__(self, sample):
        # Randomise which operation is applied first; adding then scaling is
        # not the same as scaling then adding.
        init_op = random.randint(0, 1)
        for i_op in range(2):
            op_num = (init_op + i_op) % 2
            if op_num == 0 and self.brightness is not None:
                brightness_factor = random.uniform(self.brightness[0], self.brightness[1])
                sample['signals'] += brightness_factor
            elif op_num == 1 and self.contrast is not None:
                contrast_factor = random.uniform(self.contrast[0], self.contrast[1])
                sample['signals'] *= contrast_factor
        return sample

    def __repr__(self):
        # Exactly one closing parenthesis (the previous version emitted two).
        return '{}(brightness={}, contrast={})'.format(
            self.__class__.__name__, self.brightness, self.contrast,
        )

In [None]:
import torchvision.transforms

In [None]:
train_transform_pre = torchvision.transforms.Compose([
    RandomCropWidth(0.5),
    RandomStretchDepth(0.5),
    RandomReflection(),
])
train_transform_post = torchvision.transforms.Compose([
    Rescale((128, 512)),
    Normalize(-70, 22),
    ColorJitter(0.5, 0.3),
])

In [None]:
dataset_train = TransectDataset(
    transect_paths,
    window_len=192,
    crop_depth=70,
    num_windows_per_transect=10,
    use_dynamic_offsets=True,
    transform_pre=composed_pre,
    transform_post=composed_post,
)

In [None]:
sample = dataset_train[0]

plt.figure(figsize=(12, 12))
plt.pcolormesh(
    np.linspace(*sample['timestamps'][[0, -1]], sample['signals'].shape[0]),
    -np.linspace(sample['depths'][0], sample['depths'][-1], sample['signals'].shape[1]),
    sample['signals'].T
)
plt.plot(np.linspace(*sample['timestamps'][[0, -1]], sample['d_bot'].shape[0]), -sample['d_bot'], 'b')
plt.plot(np.linspace(*sample['timestamps'][[0, -1]], sample['d_top'].shape[0]), -sample['d_top'], 'c')
plt.show()

plt.figure(figsize=(12, 12))
plt.imshow(sample['signals'])
plt.show()

plt.figure(figsize=(12, 12))
plt.imshow(sample['mask_top'])
plt.show()

plt.figure(figsize=(12, 12))
plt.imshow(sample['mask_bot'])
plt.show()

In [None]:
sample['r_top']

In [None]:
sample['r_bot']

In [None]:
val_transform = torchvision.transforms.Compose([
    Rescale((128, 512)),
    Normalize(-70, 22),
])

dataset_val = TransectDataset(
    transect_paths,
    window_len=128,
    crop_depth=70,
    num_windows_per_transect=20,
    use_dynamic_offsets=False,
    transform_post=val_transform,
)

In [None]:
sample = dataset_val[0]

plt.figure(figsize=(12, 12))
plt.pcolormesh(
    np.linspace(*sample['timestamps'][[0, -1]], sample['signals'].shape[0]),
    -np.linspace(sample['depths'][0], sample['depths'][-1], sample['signals'].shape[1]),
    sample['signals'].T
)
plt.plot(np.linspace(*sample['timestamps'][[0, -1]], sample['d_bot'].shape[0]), -sample['d_bot'], 'b')
plt.plot(np.linspace(*sample['timestamps'][[0, -1]], sample['d_top'].shape[0]), -sample['d_top'], 'c')
plt.show()

plt.figure(figsize=(12, 12))
plt.imshow(sample['signals'])
plt.show()

plt.figure(figsize=(12, 12))
plt.imshow(sample['mask_top'])
plt.show()

plt.figure(figsize=(12, 12))
plt.imshow(sample['mask_bot'])
plt.show()

In [None]:
dataset_val.datapoints