In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from PIL import Image, ImageOps

import random
import os
import json

# Default size, in inches, for every figure created after this point.
# plt.rcParams is the same global settings object as matplotlib.rcParams.
plt.rcParams['figure.figsize'] = (11.75, 8.5)

In [2]:
def pad_data(data, divisor=16):
    """Zero-pad a 2-D array so that both dimensions are multiples of ``divisor``.

    Zeros are appended on the right (extra columns) and at the bottom (extra
    rows); existing values keep their positions.

    Args:
        data: 2-D NumPy array.
        divisor: Positive integer each dimension must become a multiple of.

    Returns:
        A ``(padded, padding)`` tuple. ``padding`` is a two-element list:
        ``padding[0]`` is minus the number of columns added and ``padding[1]``
        is minus the number of rows added; an entry is ``None`` when that axis
        needed no padding. The entries can be used directly as slice ends to
        undo the padding: ``padded[:padding[1], :padding[0]]``.

    Raises:
        ValueError: If ``divisor`` is not positive or ``data`` is not 2-D.
    """
    if divisor <= 0:
        raise ValueError(f'divisor must be a positive integer, got {divisor!r}')

    shape_y, shape_x = data.shape
    padding = [None, None]

    # Number of columns missing to reach the next multiple of `divisor`.
    extra_x = (-shape_x) % divisor
    if extra_x:
        # np.zeros is float64, so integer input is promoted to float here,
        # exactly as when zeros were appended one column at a time.
        data = np.concatenate([data, np.zeros((shape_y, extra_x))], axis=1)
        padding[0] = -extra_x

    # Number of rows missing to reach the next multiple of `divisor`.
    extra_y = (-shape_y) % divisor
    if extra_y:
        data = np.concatenate([data, np.zeros((extra_y, data.shape[1]))], axis=0)
        padding[1] = -extra_y

    return data, padding

In [4]:
# --- Dataset generation settings -------------------------------------------
DIVISOR = 16              # size multiple for pad_data(); not referenced by the window-sampling loop
WINDOW_SIZE = 512         # side length, in samples, of each square window that is cut out
WINDOW_THRESHOLD = 2000   # a window is kept only when the sum of its labels exceeds this
WINDOWS_PER_SLICE = 5     # number of windows drawn from every 2-D slice
RAW_PATH = os.path.join('data', 'raw')

# Stage name -> list of [seismic archive, fault-label archive] pairs in RAW_PATH.
# Note that the 'eval' stage reads the *val* archives and the 'val' stage reads
# the *test* archives.
files_for_stages = {
    'train': [
        [f'seistrain{volume}.npz', f'faulttrain{volume}.npz']
        for volume in range(1, 10)
    ],
    'eval': [['seisval1.npz', 'faultval1.npz']],
    'val': [['seistest1.npz', 'faulttest1.npz']],
}

# Upper bound on random draws per window. Without a cap, a slice in which no
# window exceeds WINDOW_THRESHOLD would keep the sampling loop running forever.
MAX_SAMPLING_ATTEMPTS = 1000


def _load_volume(file_name):
    """Return the transposed ``arr_0`` array stored in an .npz file under RAW_PATH."""
    # The context manager closes the archive's file handle once the array is read.
    with np.load(os.path.join(RAW_PATH, file_name)) as archive:
        return archive['arr_0'].T


def _sample_fault_window(labels_slice):
    """Pick a random WINDOW_SIZE x WINDOW_SIZE window whose label sum is large enough.

    Windows are drawn uniformly at random from the 2-D ``labels_slice`` until
    the sum of the labels inside one exceeds WINDOW_THRESHOLD.

    Returns:
        ``(random_timeline, random_xline)``, the first-axis and second-axis
        start indices of the accepted window, or ``None`` when the slice is
        smaller than the window or no draw qualified within
        MAX_SAMPLING_ATTEMPTS tries.
    """
    timelines, xlines = labels_slice.shape
    if timelines < WINDOW_SIZE or xlines < WINDOW_SIZE:
        # random.randint(0, negative) would raise ValueError.
        return None
    for _ in range(MAX_SAMPLING_ATTEMPTS):
        random_xline = random.randint(0, xlines - WINDOW_SIZE)
        random_timeline = random.randint(0, timelines - WINDOW_SIZE)
        candidate = labels_slice[random_timeline:random_timeline + WINDOW_SIZE,
                                 random_xline:random_xline + WINDOW_SIZE]
        if np.sum(candidate) > WINDOW_THRESHOLD:
            return random_timeline, random_xline
    return None


# For every stage: cut random windows out of each slice of each volume, save the
# seismic window and its label window as same-named JPEGs, and write an index
# of the saved pairs to data/<stage>/metadata.json.
for stage in ('train', 'eval', 'val'):
    fault_path = os.path.join('data', stage, 'fault')
    seis_path = os.path.join('data', stage, 'seis')
    # Image.save() does not create missing directories.
    os.makedirs(fault_path, exist_ok=True)
    os.makedirs(seis_path, exist_ok=True)

    img_cnt = 0
    metadata = []
    for data_names in files_for_stages[stage]:
        print(data_names)
        data = _load_volume(data_names[0])

        # Min-max scale the whole volume to the 0..255 range.
        min_data_value = np.min(data)
        max_data_value = np.max(data)
        value_range = max_data_value - min_data_value
        if value_range == 0:
            # Constant volume: avoid 0/0 (NaN pixels) and emit all-zero images.
            data = np.zeros(data.shape)
        else:
            data = (data - min_data_value) / value_range * 255

        labels = _load_volume(data_names[1])
        assert data.shape == labels.shape

        skipped_windows = 0
        # Slices are taken along the third axis of the transposed volume.
        for horizon_num in range(data.shape[2]):
            data_slice = data[:, :, horizon_num]
            labels_slice = labels[:, :, horizon_num]

            for window_num in range(WINDOWS_PER_SLICE):
                window_origin = _sample_fault_window(labels_slice)
                if window_origin is None:
                    # Treat the slice as unsuitable and give up on its
                    # remaining windows, which keeps the run time bounded.
                    skipped_windows += WINDOWS_PER_SLICE - window_num
                    break

                random_timeline, random_xline = window_origin
                rows = slice(random_timeline, random_timeline + WINDOW_SIZE)
                cols = slice(random_xline, random_xline + WINDOW_SIZE)
                random_data_window = data_slice[rows, cols]
                random_labels_window = labels_slice[rows, cols]

                data_img = ImageOps.grayscale(Image.fromarray(random_data_window))
                # Labels are scaled by 255 before conversion
                # (presumably a 0/1 mask, so faults render white; TODO confirm).
                label_img = ImageOps.grayscale(Image.fromarray(random_labels_window * 255))
                metadata.append({
                    'data': f'{img_cnt}.jpeg',
                    'label': f'{img_cnt}.jpeg'
                })
                data_img.save(os.path.join(seis_path, f'{img_cnt}.jpeg'))
                label_img.save(os.path.join(fault_path, f'{img_cnt}.jpeg'))
                img_cnt += 1

        if skipped_windows:
            print(f'  skipped {skipped_windows} window(s): no window with '
                  f'label sum > {WINDOW_THRESHOLD} was found')

    with open(os.path.join('data', stage, 'metadata.json'), 'w') as file:
        json.dump(metadata, file)

['seistrain1.npz', 'faulttrain1.npz']
['seistrain2.npz', 'faulttrain2.npz']
['seistrain3.npz', 'faulttrain3.npz']
['seistrain4.npz', 'faulttrain4.npz']
['seistrain5.npz', 'faulttrain5.npz']
['seistrain6.npz', 'faulttrain6.npz']
['seistrain7.npz', 'faulttrain7.npz']
['seistrain8.npz', 'faulttrain8.npz']
['seistrain9.npz', 'faulttrain9.npz']
['seisval1.npz', 'faultval1.npz']
['seistest1.npz', 'faulttest1.npz']
