# 0 - Set up the data as in ViT_TransferLearning.py

In [1]:
from torch.utils.data import DataLoader, Dataset
from random import random
from Utilities import *
from data_augmentation import *

# Root folder of the GeoFlow tutorial datasets and the dataset used here.
data_path = '/home/rbertille/data/pycharm/ViT_project/pycharm_Geoflow/GeoFlow/Tutorial/Datasets/'
dataset_name = 'TutorialDataset'
files_path = os.path.join(data_path, dataset_name)

# File lists of the three splits, globbed in train / validate / test order.
train_folder, validate_folder, test_folder = (
    glob.glob(f'{files_path}/{split}/*') for split in ('train', 'validate', 'test')
)

# Identity pipeline for now: no augmentation is applied to the training data yet.
data_aug = transforms.Compose([])
class CustomDataset(Dataset):
    """In-memory dataset of shot gathers and their ``'vsdepth'`` labels.

    Every file listed in ``folder`` is read and preprocessed once at
    construction time; ``__getitem__`` then serves samples from memory as
    ``{'data': tensor, 'label': tensor}``.

    Args:
        folder: Iterable of HDF5 file paths, each containing a
            ``'shotgather'`` dataset (inputs) and a ``'vsdepth'`` dataset
            (labels).
        transform: Optional callable applied to the data tensor of each
            sample on access. Labels are never transformed.
    """

    def __init__(self, folder, transform=None):
        self.data, self.labels = self.load_data_from_folder(folder)
        # Applied in __getitem__ rather than at load time, so the stored data
        # stays untouched and the transform runs again on every access.
        self.transform = transform

    @staticmethod
    def _min_max_normalize(array):
        """Linearly rescale ``array`` to the range [0, 1].

        A constant array (max == min) would make the original formula divide
        0 by 0 and produce NaNs; it is mapped to all zeros instead.
        """
        lowest = np.min(array)
        span = np.max(array) - lowest
        if span == 0:
            return np.zeros_like(array)
        return (array - lowest) / span

    def load_data_from_folder(self, folder):
        """Read, preprocess and stack every HDF5 file listed in ``folder``.

        For each file:
          1. Keep the second half of the ``'shotgather'`` columns (the Z component).
          2. Min-max normalize the result to [0, 1].
          3. Round-trip it through PIL to get a ``(1, H, W)`` tensor, then
             replicate the single channel to 3 channels.

        Each preprocessed input is stored as a NumPy array together with the
        file's ``'vsdepth'`` array.

        Returns:
            Tuple ``(data, labels)`` of NumPy arrays stacked along a new first
            axis (one entry per file, in ``folder`` order).
        """
        data = []
        labels = []

        # Built once instead of once per file.
        # NOTE: ToPILImage converts a float tensor to 8-bit ('L' mode), so this
        # round-trip quantizes the normalized values to 256 levels.
        # Uncomment Resize to get 224x224 images (presumably the ViT input size).
        transform_resize = transforms.Compose([
            transforms.ToPILImage(),
            # transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        for file_path in folder:
            # `[:]` copies both datasets into memory, so the file can be closed
            # right away.
            with h5.File(file_path, 'r') as h5file:
                inputs = h5file['shotgather'][:]
                labels_data = h5file['vsdepth'][:]

            # Take the second half of the columns only (Z component).
            inputs = inputs[:, inputs.shape[1] // 2:]
            inputs = self._min_max_normalize(inputs)

            inputs = transform_resize(torch.tensor(inputs, dtype=torch.float32))
            if inputs.shape[0] == 1:  # single-channel (grayscale) image
                inputs = inputs.repeat(3, 1, 1)  # replicate to 3 (RGB) channels

            data.append(inputs.numpy())
            labels.append(labels_data)

        return np.array(data), np.array(labels)

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{'data': tensor, 'label': tensor}`` (float32).

        ``self.transform``, when set, is applied to the data tensor only.
        """
        inputs = torch.tensor(self.data[idx], dtype=torch.float32)
        labels = torch.tensor(self.labels[idx], dtype=torch.float32)

        sample = {'data': inputs, 'label': labels}

        if self.transform:
            sample['data'] = self.transform(sample['data'])

        return sample

def create_datasets(data_path, dataset_name):
    """Build the train / validate / test datasets of ``dataset_name``.

    Files are collected from ``<data_path>/<dataset_name>/<split>/*`` for each
    split. Only the training split gets the module-level ``data_aug`` transform
    (looked up when this function is called). The validation and test splits
    are built without a transform.

    The file lists are sorted because ``glob`` returns paths in an order that
    depends on the file system. Sorting makes sample indices (e.g.
    ``dataset.data[0]``) reproducible across machines and runs.

    Returns:
        Tuple ``(train_dataset, validate_dataset, test_dataset)`` of
        ``CustomDataset`` instances.
    """
    def split_files(split):
        # Sorted, so the sample order does not depend on the file system.
        return sorted(glob.glob(os.path.join(data_path, dataset_name, split, '*')))

    train_folder = split_files('train')
    validate_folder = split_files('validate')
    test_folder = split_files('test')

    train_dataset = CustomDataset(train_folder, transform=data_aug)
    validate_dataset = CustomDataset(validate_folder)
    test_dataset = CustomDataset(test_folder)

    return train_dataset, validate_dataset, test_dataset


# Only the training split is used in this notebook; the other two are discarded.
train_dataset, _, _ = create_datasets(data_path, dataset_name)

# One sample per batch, shuffled; default collation.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=None,
)

# I - Example without any augmentation

In [2]:
# Display the first training example.
# At this point data_aug is the empty Compose from the setup cell, so the image
# comes out unchanged.
image = train_dataloader.dataset.data[0]  # NumPy array: (channels, time samples, traces)
nb_traces = image.shape[2]
modified_image = data_aug(image)
# Time axis: one value per row. Rows are axis 1; axis 0 is the channel axis
# (using shape[0] here only worked by accident). The 1.5 s end time is
# hard-coded, so it is assumed to be the record length.
time_vector = np.linspace(0, 1.5, modified_image.shape[1])
# Plot channel 0: traces along x, time increasing downward along y.
plt.imshow(modified_image[0], aspect='auto', cmap='gray',
           extent=[0, nb_traces, time_vector[-1], time_vector[0]])
plt.xlabel('Traces')
plt.ylabel('Time (s)')
plt.show()


# II - Plot a sample with a trace shift

The purpose is to augment the data by shifting traces by a random number of grid points. The shift is a random number of grid points between 0 and a given ratio (`shift_ratio`) of the number of grid points in the data. Traces are shifted up or down, and the grid points left empty by the shift are set to 0.

In [3]:
# Augmentation under test: TraceShift with shift_ratio=0.02, wrapped in a
# Compose so that more transforms could be chained later.
data_aug = transforms.Compose([TraceShift(shift_ratio=0.02)])

# The augmentation runs on a float32 torch copy of the example image.
torch_image = torch.tensor(image, dtype=torch.float32)
modified_image = data_aug(torch_image)

# Plot channel 0: traces along x, time increasing downward along y.
plt.imshow(
    modified_image[0],
    aspect='auto',
    cmap='gray',
    extent=[0, nb_traces, time_vector[-1], time_vector[0]],
)
plt.xlabel('Traces')
plt.ylabel('Time (s)')
plt.show()

# III - Plot a sample with missing traces

The purpose is to augment the data by randomly removing traces. The number of traces to remove is a random number between 0 and a given ratio (`missing_trace_ratio`) of the total number of traces. Each removed trace is replaced by the average value of that trace.

In [4]:
# Augmentation under test: MissingTraces with missing_trace_ratio=0.04.
data_aug = transforms.Compose([MissingTraces(missing_trace_ratio=0.04)])

# The augmentation runs on a float32 torch copy of the example image.
torch_image = torch.tensor(image, dtype=torch.float32)
modified_image = data_aug(torch_image)

# Plot channel 0: traces along x, time increasing downward along y.
plt.imshow(
    modified_image[0],
    aspect='auto',
    cmap='gray',
    extent=[0, nb_traces, time_vector[-1], time_vector[0]],
)
plt.xlabel('Traces')
plt.ylabel('Time (s)')
plt.show()

# IV - Plot a sample with random dead traces

The purpose is to augment the data by randomly setting traces to 0. The number of traces set to 0 is a random number between 0 and a given ratio (`dead_trace_ratio`) of the total number of traces.


In [5]:
# Augmentation under test: DeadTraces with dead_trace_ratio=0.04.
data_aug = transforms.Compose([DeadTraces(dead_trace_ratio=0.04)])

# The augmentation runs on a float32 torch copy of the example image.
torch_image = torch.tensor(image, dtype=torch.float32)
modified_image = data_aug(torch_image)

# Plot channel 0: traces along x, time increasing downward along y.
plt.imshow(
    modified_image[0],
    aspect='auto',
    cmap='gray',
    extent=[0, nb_traces, time_vector[-1], time_vector[0]],
)
plt.xlabel('Traces')
plt.ylabel('Time (s)')
plt.show()

# V - Plot a sample with random noise

The purpose is to augment the data by adding random Gaussian noise. The noise is centered on `mean`, and its strength is controlled by the standard deviation `std`.

In [6]:
# Augmentation under test: GaussianNoise with mean 0 and std 0.05.
data_aug = transforms.Compose([GaussianNoise(mean=0., std=0.05)])

# The augmentation runs on a float32 torch copy of the example image.
torch_image = torch.tensor(image, dtype=torch.float32)
modified_image = data_aug(torch_image)

# Plot channel 0: traces along x, time increasing downward along y.
plt.imshow(
    modified_image[0],
    aspect='auto',
    cmap='gray',
    extent=[0, nb_traces, time_vector[-1], time_vector[0]],
)
plt.xlabel('Traces')
plt.ylabel('Time (s)')
plt.show()

# VI - Put everything together

Show what happens when all the augmentations are applied to our example.

In [7]:
# Full augmentation pipeline, applied in this order: contiguous trace shift,
# Gaussian noise, missing traces, dead traces.
pipeline_steps = [
    TraceShift2(shift_ratio=0.01, contiguous_ratio=0.2),
    GaussianNoise(mean=0., std=0.05),
    MissingTraces(missing_trace_ratio=0.05),
    DeadTraces(dead_trace_ratio=0.04),
]
data_aug = transforms.Compose(pipeline_steps)
del pipeline_steps  # only needed to build the Compose

# The pipeline runs on a float32 torch copy of the example image.
torch_image = torch.tensor(image, dtype=torch.float32)
modified_image = data_aug(torch_image)

# Plot channel 0: traces along x, time increasing downward along y.
plt.imshow(
    modified_image[0],
    aspect='auto',
    cmap='gray',
    extent=[0, nb_traces, time_vector[-1], time_vector[0]],
)
plt.xlabel('Traces')
plt.ylabel('Time (s)')
plt.show()

# VII - Test a better way to shift the traces

The idea here is to shift only a small number of traces, and to make the shifted traces close to each other (contiguous groups) rather than scattered.

In [8]:
# Augmentation under test: TraceShift2 with shift_ratio=0.01 and
# contiguous_ratio=0.2.
data_aug = transforms.Compose([TraceShift2(shift_ratio=0.01, contiguous_ratio=0.2)])

# The augmentation runs on a float32 torch copy of the example image.
torch_image = torch.tensor(image, dtype=torch.float32)
modified_image = data_aug(torch_image)

# Plot channel 0: traces along x, time increasing downward along y.
plt.imshow(
    modified_image[0],
    aspect='auto',
    cmap='gray',
    extent=[0, nb_traces, time_vector[-1], time_vector[0]],
)
plt.xlabel('Traces')
plt.ylabel('Time (s)')
plt.show()