In [None]:
# Change the kernel's working directory to its parent (IPython automagic form of `%cd ..`).
# NOTE(review): presumably done so that `import echofilter` resolves from the parent
# directory - confirm. Re-running this cell moves up one more level each time, so run it
# exactly once per kernel session.
cd ..

In [None]:
import os

import numpy as np

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

In [None]:
import echofilter.shardloader

In [None]:
# Load part of one transect from its shards and plot it with the two line annotations.
# NOTE(review): 100 and 800 are presumably the start/end sample indices along the
# timestamp axis - confirm against echofilter.shardloader.
transect_pth = 'Survey17/Survey17_GR1_S3W_F'
shard_window = echofilter.shardloader.load_transect_from_shards(
    transect_pth,
    100,
    800,
    root_data_dir='/media/scott/scratch/Datasets/dsforce',
)
timestamps, depths, signals, d_top, d_bot = shard_window

# Depths are negated so that larger depths are drawn lower on the figure.
fig, ax = plt.subplots(figsize=(12, 12))
ax.pcolormesh(timestamps, -depths, signals.T)
# Overlay the bottom line (blue) first, then the top line (cyan).
for line_depths, line_fmt in ((d_bot, 'b'), (d_top, 'c')):
    ax.plot(timestamps, -line_depths, line_fmt)
plt.show()

In [None]:
# Same plot as for the 100..800 span, but requesting a span that begins at -100.
# NOTE(review): a negative start presumably exercises how the loader handles a request
# reaching before the first sample of the transect - confirm against echofilter.shardloader.
transect_pth = 'Survey17/Survey17_GR1_S3W_F'
shard_window = echofilter.shardloader.load_transect_from_shards(
    transect_pth,
    -100,
    800,
    root_data_dir='/media/scott/scratch/Datasets/dsforce',
)
timestamps, depths, signals, d_top, d_bot = shard_window

# Depths are negated so that larger depths are drawn lower on the figure.
fig, ax = plt.subplots(figsize=(12, 12))
ax.pcolormesh(timestamps, -depths, signals.T)
# Overlay the bottom line (blue) first, then the top line (cyan).
for line_depths, line_fmt in ((d_bot, 'b'), (d_top, 'c')):
    ax.plot(timestamps, -line_depths, line_fmt)
plt.show()

In [None]:
# Plot a short span (0..128) of the transect; 128 matches the default `window_len`
# of the TransectDataset class defined in this notebook.
transect_pth = 'Survey17/Survey17_GR1_S3W_F'
shard_window = echofilter.shardloader.load_transect_from_shards(
    transect_pth,
    0,
    128,
    root_data_dir='/media/scott/scratch/Datasets/dsforce',
)
timestamps, depths, signals, d_top, d_bot = shard_window

# Depths are negated so that larger depths are drawn lower on the figure.
fig, ax = plt.subplots(figsize=(12, 12))
ax.pcolormesh(timestamps, -depths, signals.T)
# Overlay the bottom line (blue) first, then the top line (cyan).
for line_depths, line_fmt in ((d_bot, 'b'), (d_top, 'c')):
    ax.plot(timestamps, -line_depths, line_fmt)
plt.show()

In [None]:
import torch.utils.data

In [None]:
class TransectDataset(torch.utils.data.Dataset):
    """
    Dataset of fixed-length windows cut out of sharded transects.

    On construction, every transect's `shard_size.txt` is read to find how many
    timestamps it contains, and a list of `(transect_path, center_idx)` pairs is
    built in `self.datapoints`. Each item is then loaded lazily from the shards
    in `__getitem__`.

    Parameters
    ----------
    transect_paths : iterable of str
        Transect paths, forwarded as-is to
        `echofilter.shardloader.load_transect_from_shards`.
    window_len : int, optional
        Number of timestamps in each window. Default is `128`.
    num_windows_per_transect : int or None, optional
        Number of windows to draw from each transect. If `0` or `None`
        (default `0`), `ceil(n_timestamps / window_len)` windows are used so
        that the windows span the whole transect.
    use_dynamic_offsets : bool, optional
        If `True` (default), all windows of a transect are shifted by one
        shared random offset, drawn when the dataset is constructed. If
        `False`, the window positions are deterministic.
    transform : callable or None, optional
        If given, called as `transform(signals, mask_top, mask_bot)`; it must
        return the three (possibly modified) values.
    root_data_dir : str or None, optional
        Root directory, forwarded to the shard loader and used to locate
        `shard_size.txt`.
    """

    def __init__(
            self,
            transect_paths,
            window_len=128,
            num_windows_per_transect=0,
            use_dynamic_offsets=True,
            transform=None,
            root_data_dir=None,
            ):
        super().__init__()
        self.window_len = window_len
        self.num_windows = num_windows_per_transect
        self.use_dynamic_offsets = use_dynamic_offsets
        self.transform = transform
        self.root_data_dir = root_data_dir

        # One (transect_path, center_idx) entry per window, across all transects
        self.datapoints = []
        for transect_path in transect_paths:
            n_timestamps = self._read_n_timestamps(transect_path)
            for center_idx in self._window_centers(n_timestamps):
                self.datapoints.append((transect_path, int(center_idx)))

    def _transect_dir(self, transect_path):
        """Return the directory expected to contain the shards of `transect_path`."""
        # NOTE(review): the original code read from an undefined `dirname`.
        # Joining root_data_dir and the transect path is an assumption about
        # the on-disk layout; it must match what
        # echofilter.shardloader.load_transect_from_shards does internally.
        # Override this method if the shards live elsewhere.
        if self.root_data_dir is None:
            return transect_path
        return os.path.join(self.root_data_dir, transect_path)

    def _read_n_timestamps(self, transect_path):
        """Read the number of timestamps of a transect from its `shard_size.txt`."""
        metadata_path = os.path.join(
            self._transect_dir(transect_path), 'shard_size.txt'
        )
        with open(metadata_path, 'r') as f:
            # First line is "<n_timestamps>,<shard_len>"; only the former is needed
            n_timestamps, _shard_len = f.readline().strip().split(',')
        return int(n_timestamps)

    def _window_centers(self, n_timestamps):
        """
        Generate the window reference indices for a transect.

        The indices are evenly spaced over `[0, n_timestamps)`, starting at 0.
        With `use_dynamic_offsets`, one shared random shift of up to one
        spacing is added to all of them (suitable for training); without it
        the positions are stable (suitable for validation).

        Returns
        -------
        numpy.ndarray
            Integer array of indices, one per window.
        """
        num_windows = self.num_windows
        if not num_windows:
            # Use enough windows to include all datapoints
            num_windows = np.ceil(n_timestamps / self.window_len)
        # np.ceil returns a float, but np.linspace and slicing need an integer
        num_windows = int(num_windows)
        centers = np.linspace(0, n_timestamps, num_windows + 1)[:num_windows]
        if self.use_dynamic_offsets:
            if len(centers) > 1:
                max_dy_offset = centers[1] - centers[0]
            else:
                max_dy_offset = n_timestamps
            centers = centers + np.random.rand() * max_dy_offset
        return np.round(centers).astype(int)

    def __getitem__(self, index):
        """
        Load the window at position `index` and convert its lines to masks.

        Returns
        -------
        signals
            Signal values of the window, as returned by the shard loader.
        mask_top
            Boolean mask, `True` where the depth is less than the top line.
        mask_bot
            Boolean mask, `True` where the depth is greater than the bottom line.

        If a `transform` was supplied, the three values are passed through it
        before being returned.
        """
        transect_pth, center_idx = self.datapoints[index]
        # Load data from shards; the window spans window_len samples around center_idx
        start_idx = center_idx - int(self.window_len / 2)
        timestamps, depths, signals, d_top, d_bot = echofilter.shardloader.load_transect_from_shards(
            transect_pth,
            start_idx,
            start_idx + self.window_len,
            root_data_dir=self.root_data_dir,
        )
        # Convert lines to masks
        # (assumes signals is (timestamp, depth), with one d_top/d_bot value per
        # timestamp - TODO confirm against the shard loader)
        ddepths = np.broadcast_to(depths, signals.shape)
        mask_top = ddepths < np.expand_dims(d_top, -1)
        mask_bot = ddepths > np.expand_dims(d_bot, -1)
        if self.transform is not None:
            signals, mask_top, mask_bot = self.transform(signals, mask_top, mask_bot)
        return signals, mask_top, mask_bot

    def __len__(self):
        """Return the total number of windows over all transects."""
        return len(self.datapoints)

In [None]:
# Smoke test: build a dataset from the same transect and data root used by the
# plotting cells in this notebook. `transect_paths` is a required argument, so
# a bare `TransectDataset()` call raises TypeError.
TransectDataset(
    ['Survey17/Survey17_GR1_S3W_F'],
    root_data_dir='/media/scott/scratch/Datasets/dsforce',
)