# I - Start as in ViT_TransferLearning.py

In [36]:
import os
import glob
import h5py as h5
import shutil
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
from random import random
import random
import matplotlib.pyplot as plt
import torch.nn as nn
from transformers import ViTModel
import torch.optim as optim
from Utilities import *

# Paths
# NOTE(review): absolute path under /home; adjust it for your machine.
data_path = '/home/rbertille/data/pycharm/ViT_project/pycharm_Geoflow/GeoFlow/Tutorial/Datasets/'
dataset_name = 'TutorialDataset'
files_path = os.path.join(data_path, dataset_name)

# File lists for each split. create_datasets() below does not use these
# variables; it globs the same folders again itself.
train_folder = glob.glob(f'{files_path}/train/*')
validate_folder = glob.glob(f'{files_path}/validate/*')
test_folder = glob.glob(f'{files_path}/test/*')

# Training-time data augmentation; create_datasets() reads this global when
# it is called.
# NOTE(review): DeadTraces is only defined further down in this file, so at
# this point it must come from `from Utilities import *` or from earlier
# kernel state — confirm.
data_aug = transforms.Compose(
    [
        DeadTraces(),
    ]
)

class CustomDataset(Dataset):
    """In-memory dataset of shot gathers resized for a ViT (224 x 224, 3 channels).

    Every HDF5 file in ``folder`` yields one sample:

    * input: the second half of the columns of ``'shotgather'`` (the Z
      component), min-max normalised to [0, 1], resized to 224x224 and, when
      single-channel, repeated to 3 channels;
    * label: the file's ``'vsdepth'`` array, unchanged.

    All files are loaded eagerly in ``__init__``.

    :param folder: iterable of HDF5 file paths.
    :param transform: optional callable applied to the input tensor in
        ``__getitem__`` (e.g. data augmentation); labels are never transformed.
    """

    def __init__(self, folder, transform=None):
        self.data, self.labels = self.load_data_from_folder(folder)
        self.transform = transform  # applied lazily, per sample, in __getitem__

    @staticmethod
    def _normalize(inputs):
        """Min-max scale ``inputs`` to [0, 1]; a constant array maps to all zeros.

        Without the guard a constant gather would compute 0/0 and fill the
        whole sample with NaN.
        """
        shifted = inputs - np.min(inputs)
        value_range = np.max(shifted)  # == max(inputs) - min(inputs)
        if value_range == 0:
            return shifted  # already all zeros
        return shifted / value_range

    def load_data_from_folder(self, folder):
        """Read, preprocess and stack every file of ``folder``.

        :return: ``(data, labels)`` NumPy arrays; for 2-D gathers ``data`` has
            shape (N, 3, 224, 224).
        """
        # The resize pipeline is the same for every file: build it once.
        transform_resize = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])
        data = []
        labels = []

        for file_path in folder:
            with h5.File(file_path, 'r') as h5file:
                shotgather = h5file['shotgather']
                # Keep only the second half of the columns (Z component); slicing
                # the HDF5 dataset reads just that half from disk.
                inputs = shotgather[:, shotgather.shape[1] // 2:]
                labels_data = h5file['vsdepth'][:]

            inputs = self._normalize(inputs)
            inputs = transform_resize(torch.tensor(inputs, dtype=torch.float32))

            if inputs.shape[0] == 1:  # grayscale image -> replicate to 3 (RGB-like) channels
                inputs = inputs.repeat(3, 1, 1)

            data.append(inputs.numpy())
            labels.append(labels_data)

        return np.array(data), np.array(labels)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'data': float32 tensor, 'label': float32 tensor}`` for sample ``idx``."""
        inputs = torch.tensor(self.data[idx], dtype=torch.float32)
        labels = torch.tensor(self.labels[idx], dtype=torch.float32)

        sample = {'data': inputs, 'label': labels}

        if self.transform:
            # Augment the input only; the label is left untouched.
            sample['data'] = self.transform(sample['data'])

        return sample

def create_datasets(data_path, dataset_name):
    """Build the train / validate / test datasets of ``<data_path>/<dataset_name>``.

    Each split is the set of files in ``<data_path>/<dataset_name>/<split>/``.
    Only the training set receives the module-level ``data_aug`` transform.
    That global is looked up when this function runs, so calling it again
    after redefining ``data_aug`` picks up the new pipeline.

    File lists are sorted: ``glob.glob`` returns files in an arbitrary,
    filesystem-dependent order, which would make sample indices
    (e.g. ``dataset.data[0]``) differ between runs or machines.

    :return: ``(train_dataset, validate_dataset, test_dataset)``.
    """
    root = os.path.join(data_path, dataset_name)

    def _split_files(split):
        # Deterministic, sorted list of every file in one split folder.
        return sorted(glob.glob(os.path.join(root, split, '*')))

    train_dataset = CustomDataset(_split_files('train'), transform=data_aug)
    validate_dataset = CustomDataset(_split_files('validate'))
    test_dataset = CustomDataset(_split_files('test'))

    return train_dataset, validate_dataset, test_dataset


# Load every split into memory; the training set uses the current global `data_aug`.
train_dataset, validate_dataset, test_dataset = create_datasets(data_path, dataset_name)

# One sample per batch, reshuffled every epoch; collate_fn=None keeps the default collation.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=None)

In [37]:
def plot_a_sample(train_dataloader, i=0):
    """Display sample ``i`` of ``train_dataloader.dataset`` next to its label.

    The sample is fetched through the dataset's ``__getitem__``, so any
    transform held by the dataset is applied; the loader's shuffling and
    batching are bypassed. The input's shape is printed.

    Left panel: first channel of ``sample['data']`` as a grayscale image.
    Right panel: ``sample['label']`` (Vs) against its index, with the y axis
    inverted so depth increases downward.

    NOTE(review): the left title always says "reshaped 224x224", even for
    datasets that do not resize — confirm the wording is intended.
    """
    fig, (ax_image, ax_label) = plt.subplots(1, 2, figsize=(12, 5))

    item = train_dataloader.dataset[i]
    image, label = item['data'], item['label']
    print('shape image=', image.shape)

    # Input image (first channel).
    ax_image.imshow(image[0], aspect='auto', cmap='gray')
    ax_image.set_title(f'Original Shot Gather {i + 1} reshaped 224x224')
    ax_image.set_xlabel('Distance (grid points, reshaped)')
    ax_image.set_ylabel('Time (dt, reshaped)')

    # Velocity profile: Vs on x, depth index on y.
    depth_index = range(len(label))
    ax_label.plot(label, depth_index)
    ax_label.invert_yaxis()
    ax_label.set_xlabel('Vs (m/s)')
    ax_label.set_ylabel('Depth (grid points)')
    ax_label.set_title('Vs Depth ')

    plt.tight_layout()
    plt.show()
    
# Show index 0 of the dataset (not of the shuffled loader); the training transform is applied.
plot_a_sample(train_dataloader, i=0)

# II - Data Augmentation

In [38]:
# Add dead traces
def add_dead_traces(image, dead_trace_ratio=0.04):
    """Return a copy of ``image`` in which random traces are replaced by dead traces.

    A dead trace is a trace (= column, i.e. the last axis) whose values are all
    set to 1. ``int(num_columns * dead_trace_ratio)`` distinct columns are
    chosen at random with :func:`random.sample`. The input is not modified.

    :param image: torch tensor or NumPy array of shape (C, H, W) or (H, W).
    :param dead_trace_ratio: fraction in [0, 1] of the traces to kill.
    :return: augmented copy, same type and shape as ``image``.
    :raises ValueError: if ``dead_trace_ratio`` is outside [0, 1].
    """
    if not 0 <= dead_trace_ratio <= 1:
        raise ValueError(f'dead_trace_ratio must be in [0, 1], got {dead_trace_ratio}')

    # Work on a copy so the caller's data is untouched.
    augmented_data = image.clone() if isinstance(image, torch.Tensor) else image.copy()

    # Traces are the last axis for both (H, W) and (C, H, W) inputs.
    num_columns = augmented_data.shape[-1]
    num_dead_traces = int(num_columns * dead_trace_ratio)
    dead_traces_indices = random.sample(range(num_columns), num_dead_traces)

    # Ellipsis indexing hits every channel and row of the chosen columns.
    augmented_data[..., dead_traces_indices] = 1
    return augmented_data

In [39]:
# Same augmentation, but written as a custom transform class so it can be applied to each sample on the fly without taking too much memory

class DeadTraces:
    """Transform that kills a random subset of the traces (columns) of a shot gather.

    Meant for ``transforms.Compose``. Calling an instance returns a copy of the
    input with ``int(num_columns * dead_trace_ratio)`` randomly chosen columns
    set to 1. Accepts torch tensors and NumPy arrays of shape (C, H, W) or
    (H, W).

    :param dead_trace_ratio: fraction in [0, 1] of the traces to kill.
    """

    def __init__(self, dead_trace_ratio=0.04):
        self.dead_trace_ratio = dead_trace_ratio

    def add_dead_traces(self, image, dead_trace_ratio):
        """Return a copy of ``image`` in which random traces are replaced by dead traces.

        A dead trace is a trace (= column, the last axis) with all values set
        to 1. The input is not modified.

        :param image: torch tensor or NumPy array of shape (C, H, W) or (H, W).
        :param dead_trace_ratio: fraction in [0, 1] of the traces to kill.
        :return: augmented copy, same type and shape as ``image``.
        :raises ValueError: if ``dead_trace_ratio`` is outside [0, 1].
        """
        if not 0 <= dead_trace_ratio <= 1:
            raise ValueError(f'dead_trace_ratio must be in [0, 1], got {dead_trace_ratio}')

        # Copy first so the caller's data is untouched.
        augmented_data = image.clone() if isinstance(image, torch.Tensor) else image.copy()

        # Traces are the last axis for both (H, W) and (C, H, W). Ellipsis
        # indexing avoids the old unsqueeze/squeeze dance, which NumPy arrays
        # do not support.
        num_columns = augmented_data.shape[-1]
        num_dead_traces = int(num_columns * dead_trace_ratio)
        dead_traces_indices = random.sample(range(num_columns), num_dead_traces)

        augmented_data[..., dead_traces_indices] = 1
        return augmented_data

    def __call__(self, sample: torch.Tensor) -> torch.Tensor:
        return self.add_dead_traces(sample, self.dead_trace_ratio)

    def __repr__(self):
        return f'{type(self).__name__}(dead_trace_ratio={self.dead_trace_ratio})'


In [40]:
# Apply the transform to the first example.

# Create the transform.
# Compose the custom augmentations with available augmentations.
# Rebinding the global `data_aug` does not change datasets that were already
# built: each CustomDataset stored its transform when it was constructed.
data_aug = transforms.Compose(
    [
        DeadTraces(),
    ]
)

# Apply it to our example. `dataset.data` holds the preprocessed NumPy arrays,
# so this bypasses __getitem__ (and the dataset's own transform).
image=train_dataloader.dataset.data[0]
modified_image = data_aug(image)
print('shape image:',train_dataloader.dataset.data[0].shape)
# Plot the first channel of the modified example.
plt.imshow(modified_image[0], aspect='auto', cmap='gray')


In [41]:
# Add missing signal: replace some traces by average value
def Missing_traces(image, missing_trace_ratio=0.04):
    """Return a copy of ``image`` in which random traces are replaced by a constant.

    ``int(num_columns * missing_trace_ratio)`` traces (columns, the last axis)
    are picked at random with :func:`random.sample`. Every value in them is
    overwritten with a single scalar: the mean of all values those traces held
    (across channels and rows). The input is not modified.

    :param image: torch tensor or NumPy array of shape (C, H, W) or (H, W).
    :param missing_trace_ratio: fraction in [0, 1] of the traces to replace.
    :return: augmented copy, same type and shape as ``image``.
    :raises ValueError: if ``missing_trace_ratio`` is outside [0, 1].
    """
    if not 0 <= missing_trace_ratio <= 1:
        raise ValueError(f'missing_trace_ratio must be in [0, 1], got {missing_trace_ratio}')

    # Copy first so the caller's data is untouched.
    augmented_data = image.clone() if isinstance(image, torch.Tensor) else image.copy()

    # Traces are the last axis for both (H, W) and (C, H, W) inputs.
    num_columns = augmented_data.shape[-1]
    num_missing_traces = int(num_columns * missing_trace_ratio)
    missing_traces_indices = random.sample(range(num_columns), num_missing_traces)
    if not missing_traces_indices:
        # Nothing to replace; avoid the mean of an empty selection (NaN and a
        # RuntimeWarning for NumPy).
        return augmented_data

    # One scalar for all selected traces: the mean of every value they contain.
    average_value = augmented_data[..., missing_traces_indices].mean()
    augmented_data[..., missing_traces_indices] = average_value
    return augmented_data

# Apply to the first example.
# Plot example 1 (index 0). NOTE: dataset[i] goes through the dataset's
# training transform, so this is not necessarily the raw sample.
plot_a_sample(train_dataloader, i=0)
# Apply missing traces directly to the stored (untransformed) NumPy array.
image = train_dataloader.dataset.data[0]
image_dead = Missing_traces(image, missing_trace_ratio=0.04)  # holds missing (not dead) traces, despite its name
# Plot the first channel of the modified example.
plt.imshow(image_dead[0], aspect='auto', cmap='gray')

In [42]:
class MissingTraces:
    """Transform that replaces a random subset of traces (columns) by a constant.

    Meant for ``transforms.Compose``. Calling an instance returns a copy of the
    input in which ``int(num_columns * missing_trace_ratio)`` random columns
    are overwritten with the mean of the values they held. Accepts torch
    tensors and NumPy arrays of shape (C, H, W) or (H, W).

    :param missing_trace_ratio: fraction in [0, 1] of the traces to replace.
    """
    def __init__(self, missing_trace_ratio=0.04):
        self.missing_trace_ratio = missing_trace_ratio

    def Missing_traces(self, image, missing_trace_ratio):
        """Return a copy of ``image`` in which random traces are replaced by a constant.

        Every value in the selected traces is set to one scalar: the mean of
        all values those traces held (across channels and rows). The input is
        not modified.

        :param image: torch tensor or NumPy array of shape (C, H, W) or (H, W).
        :param missing_trace_ratio: fraction in [0, 1] of the traces to replace.
        :return: augmented copy, same type and shape as ``image``.
        :raises ValueError: if ``missing_trace_ratio`` is outside [0, 1].
        """
        if not 0 <= missing_trace_ratio <= 1:
            raise ValueError(f'missing_trace_ratio must be in [0, 1], got {missing_trace_ratio}')

        # Copy first so the caller's data is untouched.
        augmented_data = image.clone() if isinstance(image, torch.Tensor) else image.copy()

        # Traces are the last axis for both (H, W) and (C, H, W). Ellipsis
        # indexing replaces the old unsqueeze/squeeze, which NumPy arrays lack.
        num_columns = augmented_data.shape[-1]
        num_missing_traces = int(num_columns * missing_trace_ratio)
        missing_traces_indices = random.sample(range(num_columns), num_missing_traces)
        if not missing_traces_indices:
            # Nothing to replace; avoid the mean of an empty selection.
            return augmented_data

        average_value = augmented_data[..., missing_traces_indices].mean()
        augmented_data[..., missing_traces_indices] = average_value
        return augmented_data

    def __call__(self, sample: torch.Tensor) -> torch.Tensor:
        return self.Missing_traces(sample, self.missing_trace_ratio)

    def __repr__(self):
        return f'{type(self).__name__}(missing_trace_ratio={self.missing_trace_ratio})'
    
# Create a transform: missing traces first, then dead traces (Compose applies
# the transforms in list order).
data_aug = transforms.Compose(
    [
        MissingTraces(),
        DeadTraces()
    ]
)

# Apply it to our example: the stored NumPy array, bypassing __getitem__.
image=train_dataloader.dataset.data[0]
modified_image = data_aug(image)
print('shape image:',train_dataloader.dataset.data[0].shape)
# Plot the first channel of the modified example.
plt.imshow(modified_image[0], aspect='auto', cmap='gray')

In [43]:
# Rebuild the datasets so the training set picks up the current global
# `data_aug` (create_datasets reads it at call time).
train_dataset, validate_dataset, test_dataset = create_datasets(data_path, dataset_name)

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=None)

In [44]:
# Plot an example (dataset index 0, with the training transform applied).
plot_a_sample(train_dataloader, i=0)

# III - For ViT_network.py


In [45]:
class CustomDataset(Dataset):
    """In-memory dataset of full-resolution, single-channel shot gathers.

    Every HDF5 file in ``folder`` yields one sample:

    * input: the second half of the columns of ``'shotgather'``, min-max
      normalised to [0, 1], with a leading channel axis added by ``ToTensor``.
      No resizing and no channel replication are done.
    * label: the file's ``'vsdepth'`` array, unchanged.

    All files are loaded eagerly in ``__init__``.

    :param folder: iterable of HDF5 file paths.
    :param transform: optional callable applied to the input tensor in
        ``__getitem__`` (e.g. data augmentation); labels are never transformed.
    """

    def __init__(self, folder, transform=None):
        self.data, self.labels = self.load_data_from_folder(folder)
        self.transform = transform

    @staticmethod
    def _normalize(inputs):
        """Min-max scale ``inputs`` to [0, 1]; a constant array maps to all zeros.

        Without the guard a constant gather would compute 0/0 and fill the
        whole sample with NaN.
        """
        shifted = inputs - np.min(inputs)
        value_range = np.max(shifted)  # == max(inputs) - min(inputs)
        if value_range == 0:
            return shifted  # already all zeros
        return shifted / value_range

    def load_data_from_folder(self, folder):
        """Read, preprocess and stack every file of ``folder``.

        :return: ``(data, labels)`` NumPy arrays, one entry per file.
        """
        # The same conversion is used for every file: build it once.
        transform_tensor = transforms.Compose([
            transforms.ToTensor()
        ])
        data = []
        labels = []

        for file_path in folder:
            with h5.File(file_path, 'r') as h5file:
                shotgather = h5file['shotgather']
                # Read only the second half of the columns straight from disk,
                # instead of loading the full gather and then the half again.
                inputs = shotgather[:, shotgather.shape[1] // 2:]
                labels_data = h5file['vsdepth'][:]

            inputs = self._normalize(inputs)
            # Store plain NumPy arrays so np.array() below stacks them directly.
            data.append(transform_tensor(inputs).numpy())
            labels.append(labels_data)

        return np.array(data), np.array(labels)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'data': float32 tensor, 'label': float32 tensor}`` for sample ``idx``."""
        inputs = torch.tensor(self.data[idx], dtype=torch.float32)
        labels = torch.tensor(self.labels[idx], dtype=torch.float32)

        if self.transform:
            # Augment the input only; the label is left untouched.
            inputs = self.transform(inputs)

        sample = {'data': inputs, 'label': labels}

        return sample

def create_datasets(data_path, dataset_name):
    """Build the train / validate / test datasets of ``<data_path>/<dataset_name>``.

    Each split is the set of files in ``<data_path>/<dataset_name>/<split>/``.
    Only the training set receives the module-level ``data_aug`` transform.
    That global is looked up when this function runs, so calling it again
    after redefining ``data_aug`` picks up the new pipeline.

    File lists are sorted: ``glob.glob`` returns files in an arbitrary,
    filesystem-dependent order, which would make sample indices differ
    between runs or machines.

    :return: ``(train_dataset, validate_dataset, test_dataset)``.
    """
    root = os.path.join(data_path, dataset_name)

    def _split_files(split):
        # Deterministic, sorted list of every file in one split folder.
        return sorted(glob.glob(os.path.join(root, split, '*')))

    train_dataset = CustomDataset(_split_files('train'), transform=data_aug)
    validate_dataset = CustomDataset(_split_files('validate'))
    test_dataset = CustomDataset(_split_files('test'))

    return train_dataset, validate_dataset, test_dataset


# Rebuild the datasets with the CustomDataset / create_datasets defined in this
# cell (no resize, single channel); training uses the current global `data_aug`.
train_dataset, validate_dataset, test_dataset= create_datasets(data_path, dataset_name)

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

In [46]:
# Plot dataset index 0 of the full-resolution dataset (training transform applied).
plot_a_sample(train_dataloader, i=0)

In [52]:
#add gaussian noise

class GaussianNoise:
    """Transform that adds i.i.d. Gaussian noise N(mean, std**2) to a shot gather.

    Meant for ``transforms.Compose``. A new object is returned and the input is
    not modified. Torch tensors get noise with their own dtype and device;
    NumPy arrays get a NumPy array back.

    :param mean: mean of the noise.
    :param std: standard deviation of the noise.
    """
    def __init__(self, mean=0., std=0.01):
        self.std = std
        self.mean = mean

    def add_gaussian_noise(self, image, mean, std):
        '''
        Add Gaussian noise to the data

        :param image: tensor or NumPy array of any shape, typically (C, H, W)
        :param mean: mean of the Gaussian noise
        :param std: standard deviation of the Gaussian noise

        :return: augmented_data: new tensor/array of the same shape with Gaussian noise added
        '''
        if isinstance(image, torch.Tensor):
            # randn_like matches the input's dtype and device, so GPU tensors
            # work too (torch.randn(shape) would always be on the CPU).
            noise = torch.randn_like(image) * std + mean
            return image + noise

        # NumPy (or array-like) input: draw from torch's RNG as before, but add
        # as a NumPy array so the result type is well defined.
        image = np.asarray(image)
        noise = torch.randn(image.shape).numpy() * std + mean
        return image + noise

    def __call__(self, sample: torch.Tensor) -> torch.Tensor:
        return self.add_gaussian_noise(sample, self.mean, self.std)

    def __repr__(self):
        return f'{type(self).__name__}(mean={self.mean}, std={self.std})'

# Example usage: Gaussian noise first, then missing traces, then dead traces
# (Compose applies the transforms in list order). Note that std=0.1 is
# applied to inputs that were min-max normalised to [0, 1].
data_aug = transforms.Compose(
    [
        GaussianNoise(mean=0., std=0.1),
        MissingTraces(),
        DeadTraces()
    ]
)

# Rebuild the Dataset and DataLoader so the training set uses this pipeline
# (create_datasets reads the global `data_aug` when called).
train_dataset, validate_dataset, test_dataset = create_datasets(data_path, dataset_name)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=None)

plot_a_sample(train_dataloader, i=0)
