# 🏆 Every Layer, Everywhere, All at Once: Segmenting Subsurface

<img src='assets/banner.png'>

## 🌋 Context

Seismic data serves as a window into the Earth, helping geophysicists identify and map rock layers and structures, much as radiologists use MRI data for medical diagnoses. Its applications include reservoir identification, CO2 sequestration monitoring, oil and gas exploration, and groundwater management. However, interpreting seismic data is challenging because of its geological complexity, including ambiguous features such as faults and pinch-outs.

The challenge provides around 9,000 pre-interpreted seismic volumes with segment masks for model training, each containing unique geological features. Participants are encouraged to use SAM (Segment Anything Model) in their 3D data segmentation pipelines. Solutions can involve tools like U-Net for segment generation, feeding into the SAM model for refinement. The output should be a NumPy array with segment IDs for labeled intervals. Experimentation is encouraged, and the provided training data is split into four tranches for accessibility.

## 📚 Libraries

Our code runs on `Python 3.10.13`.

Installing external libraries

In [None]:
! pip install -r requirements.txt

Installing the segmenting-subsurface library

In [None]:
! pip install -e .
%load_ext autoreload
%autoreload 2

In [None]:
# Generic
import os
import yaml

# Visualization
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Data manipulation
import numpy as np
import torch
import torchvision.transforms.functional as tvF

# Custom package (located in the src directory)
from src import utils
import src.visualization.utils as vutils
import src.data.make_dataset as md
import src.models.segformer.train_model as segformer_tm
import src.models.mask2former.train_model as mask2former_tm
import src.features.segment_anything_inference as sam_inf
import src.features.segformer_inference as segformer_inf
import src.features.mask2former_inference as mask2former_inf

Connecting to the RosIA demonstration account on WandB

In [None]:
# WandB Initialisation 
wandb_api_key = '02ca932e1203e93aaa8c97b8331d6c0b04c3170a' # Please do not share this api key
! wandb login --relogin {wandb_api_key}

## 📸 Data

In [None]:
volume_path = 'data/raw/train/69764103/seismic_block-2_vol_69764103.npy'
volume = np.load(volume_path, allow_pickle=True)

label_path = volume_path.replace('seismic', 'horizon_labels')
label = np.load(label_path, allow_pickle=True)

volume_hollow = vutils.get_volume_hollow(volume)
volume_plotly = vutils.get_plotly_volume(volume_hollow, colorscale='Greys')

label_hollow = vutils.get_volume_hollow(label)
label_plotly = vutils.get_plotly_volume(label_hollow, colorscale='Viridis')

fig = make_subplots(
    rows=1, cols=2,
    specs=[
        [{"type": "scatter3d"}, {"type": "scatter3d"}]
    ],
    subplot_titles=("Original volume", "Labeled volume")
)
fig.add_trace(volume_plotly, row=1, col=1)
fig.add_trace(label_plotly, row=1, col=2)
fig.update_layout(showlegend=False)
fig.show()

del label_hollow, volume_hollow

## ⚒️ Preprocessing

In computer vision, the classic preprocessing steps for an image are as follows:

1. `Scaling`: Maps the values to the range 0 to 1 (using a min-max scaler).

2. `Normalization`: Standardizes each channel to zero mean and unit variance (using a standard scaler).

3. `Rescaling`: Resizes the image, if necessary, to the input size the model accepts (using bilinear interpolation).

However, in our dataset the objective is to delineate adjacent areas that differ only slightly in brightness. That is why we added a `contrast` step, which accentuates these differences in shade between the layers.
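To make the effect of the contrast step concrete, here is a minimal NumPy sketch of what we understand `adjust_contrast` to do on a single-channel float image: blend every pixel with the image mean, then clip to the range 0 to 1. The helper `adjust_contrast_np` and the toy values are illustrative assumptions, not part of our package.

```python
import numpy as np

def adjust_contrast_np(image: np.ndarray, factor: float) -> np.ndarray:
    """Blend a [0, 1] grayscale image with its mean intensity, then clip.

    Assumption: this mirrors the behavior of torchvision's adjust_contrast
    for a single-channel float tensor.
    """
    mean = image.mean()
    return np.clip(factor * image + (1.0 - factor) * mean, 0.0, 1.0)

# Two faint "layers" whose intensities differ by only 0.02 (made-up values).
faint = np.array([[0.49, 0.49],
                  [0.51, 0.51]])
sharp = adjust_contrast_np(faint, factor=25)

gap_before = faint[1, 0] - faint[0, 0]
gap_after = sharp[1, 0] - sharp[0, 0]
print(f"gap before: {gap_before:.2f}, gap after: {gap_after:.2f}")
```

With a factor of 25, deviations from the mean are multiplied by 25, so small differences in shade between neighboring layers become clearly visible, while values far from the mean saturate at 0 or 1.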

In [None]:
config = utils.get_config()

def scale(image):
    image = (image - config['data']['min']) / (config['data']['max'] - config['data']['min'])
    return image

def contrast(image):
    tensor = torch.from_numpy(image).unsqueeze(0)
    tensor = tvF.adjust_contrast(tensor, contrast_factor=25)
    
    return tensor.squeeze().numpy(force=True)

def normalize(image):
    tensor = torch.from_numpy(image).unsqueeze(0)
    tensor = tvF.normalize(tensor, mean=[config['data']['mean']], std=[config['data']['std']])
    
    return tensor.squeeze().numpy(force=True)

In [None]:
slice_idx = 200
image = volume[slice_idx, :, :].T

fig = make_subplots(
    rows=2, cols=2,
    specs=[
        [{"type": "heatmap"}, {"type": "heatmap"}],
        [{"type": "heatmap"}, {"type": "heatmap"}]
    ],
    subplot_titles=("0 - Original image", "1 - Scaled image", "2 - Contrasted image", "3 - Normalized image")
)

fig.add_trace(go.Heatmap(z=image.tolist(), showscale=False, colorscale='Greys'), row=1, col=1)
image = scale(image)
fig.add_trace(go.Heatmap(z=image.tolist(), showscale=False, colorscale='Greys'), row=1, col=2)
image = contrast(image)
fig.add_trace(go.Heatmap(z=image.tolist(), showscale=False, colorscale='Greys'), row=2, col=1)
image = normalize(image)
fig.add_trace(go.Heatmap(z=image.tolist(), showscale=False, colorscale='Greys'), row=2, col=2)
fig.update_layout(showlegend=False)
fig.show()

## 🤖 Segment Anything Model (SAM): Binary mask

We initially attempted to predict all the classes for each image, assuming that each label corresponded to a distinct rock layer. In other words, we framed the task as an instance segmentation problem with as many classes as distinct labels and fine-tuned a Mask2former model, but this approach did not yield good results.

Therefore, we pursued another approach. Since instance segmentation did not produce usable results, we simplified the problem by converting each label to 0 or 1 based on its parity, which turns the task into binary semantic segmentation. For this, we used a model from the [Segment Anything Model](https://arxiv.org/pdf/2304.02643.pdf) family.
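The small example below illustrates why the parity trick can keep the layer boundaries intact. It assumes that vertically adjacent layers carry consecutive integer IDs, which is an assumption made for the illustration; the `column` values are made up.

```python
import numpy as np

# Hypothetical vertical trace crossing five stacked layers with consecutive IDs.
column = np.array([1, 1, 2, 2, 2, 3, 4, 4, 5])

# Same rule as the binarization used in this notebook: even IDs -> 1, odd IDs -> 0.
binary = np.where(column % 2 == 0, 1, 0)

# A boundary sits wherever the value changes between two consecutive samples.
boundaries_full = np.flatnonzero(np.diff(column) != 0)
boundaries_binary = np.flatnonzero(np.diff(binary) != 0)

print(binary)
print(boundaries_full, boundaries_binary)
```

Under that assumption the binary mask alternates at every layer transition, so no boundary is lost. If two neighboring layers shared the same parity (for example IDs 1 and 3), the boundary between them would disappear from the binary mask.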

In [None]:
def get_binary_label(label):
    binary_label = np.where(label % 2 == 0, 1, 0)

    return binary_label

fig = make_subplots(
    rows=1, cols=3,
    specs=[
        [{"type": "heatmap"}]*3
    ],
    subplot_titles=("Processed image", "Original label", "Binarized label")
)

fig.add_trace(go.Heatmap(z=image.tolist(), showscale=False, colorscale='Greys', hoverinfo='skip'), row=1, col=1)
image_label = label[slice_idx, :, :].T
fig.add_trace(go.Heatmap(z=image_label.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=2)
binary_label = get_binary_label(image_label.copy())
fig.add_trace(go.Heatmap(z=binary_label.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=3)

fig.update_layout(showlegend=False)
fig.show()

In [None]:
# Creating a fake WandB run to make the code work.

class RunDemo:
    def __init__(self, config_file, id, name) -> None:
        self.config = self.get_config(config_file)
        self.name = name
        self.id = id
    
    @staticmethod
    def get_config(config_file) -> dict:
        root = os.path.join('config', config_file)
        notebooks = os.path.join(os.pardir, root)
        path = root if os.path.exists(root) else notebooks

        with open(path, 'r') as f:
            config = yaml.safe_load(f)

        return config

In [None]:
# Using the SAM inference class to make predictions on the volume.

sam_inference = sam_inf.SAMInference(
    config=config,
    cuda_idx=0,
    list_volume=[volume_path],
    run=None,
    split='train',
    batch=5 # Reduce the batch size if you encounter cuda memory issues running the inference. Min 2 Max 300
)
volume_name = os.path.basename(volume_path)
binary_mask_path = sam_inference.get_mask_path(volume_name)
sam_dir = os.path.split(binary_mask_path)[0]
os.makedirs(sam_dir, exist_ok=True)
sam_inference()
binary_mask = np.load(binary_mask_path, allow_pickle=True)
sam_binary_pred = binary_mask[slice_idx, :, :].T.astype(np.uint8)


# Deleting the instance to free up memory space.
del sam_inference, binary_mask

In [None]:
fig = make_subplots(
    rows=1, cols=3,
    specs=[
        [{"type": "heatmap"}]*3
    ],
    subplot_titles=("Processed image", "Binarized label", "SAM Prediction")
)

fig.add_trace(go.Heatmap(z=image.tolist(), showscale=False, colorscale='Greys', hoverinfo='skip'), row=1, col=1)
fig.add_trace(go.Heatmap(z=binary_label.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=2)
fig.add_trace(go.Heatmap(z=sam_binary_pred.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=3)

fig.update_layout(showlegend=False)
fig.show()

We used SAM without fine-tuning because its predictions were good enough to serve as input to the next stage of our pipeline.

To create binary masks with the SAM model from our solution, run the following cell.

In [None]:
list_volume = md.get_volumes(config, 'test')
sam_inference = sam_inf.SAMInference(
    config=config,
    cuda_idx=0,
    list_volume=list_volume,
    run=None,
    split='test',
    batch=5
)

sam_dir = sam_inference.get_folder_path()
os.makedirs(sam_dir, exist_ok=True)
sam_inference()

del sam_inference

## 0️⃣1️⃣ Segformer: Binary mask

As the binary mask generated by SAM was not of sufficiently good quality, we refined the prediction by using another model to generate the binary mask. For this, we used a model from the [Segformer](https://proceedings.neurips.cc/paper/2021/file/64f1f27bf1b4ec22924fd0acb550c235-Paper.pdf) family.
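As a rough sketch of how such a three-channel input could be assembled, the snippet below stacks the processed image and the SAM mask in the channel order shown in the decomposition figure of this section (processed image, SAM prediction, processed image). The helper `stack_channels` and the random data are hypothetical; the actual construction lives in our `src` package and may differ.

```python
import numpy as np

def stack_channels(red: np.ndarray, green: np.ndarray, blue: np.ndarray) -> np.ndarray:
    """Stack three (H, W) arrays into one (H, W, 3) float32 image."""
    return np.stack([red, green, blue], axis=-1).astype(np.float32)

# Made-up stand-ins for a processed slice and its SAM binary prediction.
processed = np.random.default_rng(0).random((4, 6))
sam_mask = (processed > 0.5).astype(np.uint8)

# Red and blue carry the processed image, green carries the SAM prediction.
segformer_input = stack_channels(processed, sam_mask, processed)
print(segformer_input.shape)
```

Packing the mask into one of the RGB channels lets a model pretrained on three-channel images consume the extra information without any change to its architecture.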

In [None]:
fig = make_subplots(
    rows=1, cols=3,
    specs=[
        [{"type": "heatmap"}]*3
    ],
    subplot_titles=("Red channel (Processed image)", "Green channel (SAM prediction)", "Blue channel (Processed image)"),
)
fig.add_trace(go.Heatmap(z=image.tolist(), showscale=False, colorscale='Greys', hoverinfo='skip'), row=1, col=1)
fig.add_trace(go.Heatmap(z=sam_binary_pred.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=2)
fig.add_trace(go.Heatmap(z=image.tolist(), showscale=False, colorscale='Greys', hoverinfo='skip'), row=1, col=3)
fig.update_layout(showlegend=False, title='Decomposition of the Segformer input image into channels (RGB)')
fig.show()

In [None]:
fig = make_subplots(
    rows=1, cols=2,
    specs=[
        [{"type": "heatmap"}]*2
    ],
    subplot_titles=("Original label", "Binarized label")
)

image_label = label[slice_idx, :, :].T
fig.add_trace(go.Heatmap(z=image_label.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=1)
binary_label = get_binary_label(image_label)
fig.add_trace(go.Heatmap(z=binary_label.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=2)

fig.update_layout(showlegend=False)
fig.show()

In [None]:
# Using the Segformer inference class to make predictions on the volume.

run = RunDemo('segformer.yml', id='mmw4795a', name='abundant-lantern-1231')

segformer_inference = segformer_inf.SegformerInference(
    config=config,
    cuda_idx=0,
    list_volume=[volume_path],
    run=run,
    split='train',
    batch=150 # Reduce the batch size if you encounter issues running the inference. Min 2 Max 300
)
volume_name = os.path.basename(volume_path)
binary_mask_path = segformer_inference.get_mask_path(volume_name)
segformer_dir = os.path.split(binary_mask_path)[0]
os.makedirs(segformer_dir, exist_ok=True)
segformer_inference()
binary_mask = np.load(binary_mask_path, allow_pickle=True)
seg_binary_pred = binary_mask[slice_idx, :, :].T.astype(np.uint8)


# Deleting the instance to free up memory space.
del segformer_inference, binary_mask

In [None]:
fig = make_subplots(
    rows=1, cols=3,
    specs=[
        [{"type": "heatmap"}]*3
    ],
    subplot_titles=("Processed image", "Binarized label", "Segformer Prediction")
)

fig.add_trace(go.Heatmap(z=image.tolist(), showscale=False, colorscale='Greys', hoverinfo='skip'), row=1, col=1)
fig.add_trace(go.Heatmap(z=binary_label.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=2)
fig.add_trace(go.Heatmap(z=seg_binary_pred.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=3)

fig.update_layout(showlegend=False)
fig.show()

With this approach, we achieved a Dice score of `0.811` and an intersection over union (IoU) of `0.6857` on the binary masks.
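For reference, here is a minimal sketch of how these two metrics are commonly computed on a pair of binary masks. The function `dice_and_iou` and the toy masks are illustrative assumptions; the figures reported above come from our training runs, not from this snippet.

```python
import numpy as np

def dice_and_iou(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7):
    """Return (Dice, IoU) for two binary masks of identical shape."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    dice = 2 * intersection / (pred.sum() + target.sum() + eps)
    iou = intersection / (union + eps)
    return float(dice), float(iou)

# Toy masks: 2 overlapping pixels, 3 positives in each mask, 4 in the union.
pred = np.array([[1, 1, 0],
                 [0, 1, 0]])
target = np.array([[1, 0, 0],
                   [0, 1, 1]])
dice, iou = dice_and_iou(pred, target)
print(f"Dice={dice:.4f}, IoU={iou:.4f}")
```

On a single pair of masks the two metrics are linked by Dice = 2·IoU / (1 + IoU). When both are averaged over many slices, the averaged values only satisfy this relation approximately.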

To create binary masks with the Segformer model from our solution, run the following cell.

You need to run the SAM inference first, or rename `data/processed/test/facebook_sam-vit-base-original` to `data/processed/test/facebook_sam-vit-base`.

In [None]:
run = RunDemo('segformer.yml', id='mmw4795a', name='abundant-lantern-1231')
list_volume = md.get_volumes(config, 'test')

segformer_inference = segformer_inf.SegformerInference(
    config=config,
    cuda_idx=0,
    list_volume=list_volume,
    run=run,
    split='test',
    batch=5
)

segformer_dir = segformer_inference.get_folder_path()
os.makedirs(segformer_dir, exist_ok=True)
segformer_inference()

del segformer_inference

## 🎭 Mask2former: Instance mask

Once we have the binary mask, we need to separate each layer individually, regardless of its parity, in order to build a prompt that is as precise as possible for the Segment Anything model.

For this purpose, we use another segmentation model called [Mask2former](http://openaccess.thecvf.com/content/CVPR2022/papers/Cheng_Masked-Attention_Mask_Transformer_for_Universal_Image_Segmentation_CVPR_2022_paper.pdf). This model takes as input the combination of the binary mask from Segformer and the original image.

In [None]:
def get_instance_label(label):
    instance_label = np.full(label.shape, np.nan)
    old_labels = np.unique(label)
    new_labels = range(len(old_labels))
    for old_label, new_label in zip(old_labels, new_labels):
        instance_label = np.where(label == old_label, new_label, instance_label)

    return instance_label

In [None]:
fig = make_subplots(
    rows=1, cols=3,
    specs=[
        [{"type": "heatmap"}]*3
    ],
    subplot_titles=("Red channel (Seg prediction)", "Green channel (Original image)", "Blue channel (Seg prediction)"),
)
fig.add_trace(go.Heatmap(z=seg_binary_pred.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=1)
fig.add_trace(go.Heatmap(z=image.tolist(), showscale=False, colorscale='Greys', hoverinfo='skip'), row=1, col=2)
fig.add_trace(go.Heatmap(z=seg_binary_pred.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=3)
fig.update_layout(showlegend=False, title='Decomposition of the Mask2former input image into channels (RGB)')
fig.show()

Its objective is to predict a variant of the original label in which the IDs of the original masks are remapped to consecutive integers, from 0 up to the number of masks present in the original label minus one.
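The remapping can also be expressed with `np.unique`, as in the sketch below. The helper name `relabel_consecutive` and the toy label are assumptions for illustration. Unlike the loop-based `get_instance_label` in this notebook, this version returns an integer array.

```python
import numpy as np

def relabel_consecutive(label: np.ndarray) -> np.ndarray:
    """Map arbitrary label IDs to 0..N-1, preserving their sorted order."""
    _, inverse = np.unique(label, return_inverse=True)
    # Reshape explicitly so the result matches the input on any NumPy version.
    return inverse.reshape(label.shape)

# Made-up label slice with three non-consecutive IDs.
toy_label = np.array([[12, 12, 40],
                      [7, 40, 40]])
toy_variant = relabel_consecutive(toy_label)
print(toy_variant)
```

The smallest original ID becomes 0, the next one 1, and so on, so the relative ordering of the layers is kept.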

In [None]:
image_label = label[slice_idx, :, :].T
variant_label = get_instance_label(image_label)

fig = make_subplots(
    rows=1, cols=2,
    specs=[
        [{"type": "heatmap"}]*2
    ],
    subplot_titles=("Original label", "Variant label"),
)
fig.add_trace(go.Heatmap(z=image_label.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=1)
fig.add_trace(go.Heatmap(z=variant_label.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=2)
fig.update_layout(showlegend=False)
fig.show()

In [None]:
# Using the Mask2former inference class to make predictions on the volume.

run = RunDemo('mask2former.yml', id='nvbtr9k2', name='vermilion-moon-1241')

mask2former_inference = mask2former_inf.Mask2formerInference(
    config=config,
    cuda_idx=0,
    list_volume=[volume_path],
    run=run,
    split='train',
    batch=25 # Reduce the batch size if you encounter cuda memory issues running the inference. Min 2 Max 300
)
volume_name = os.path.basename(volume_path)
instance_mask_path = mask2former_inference.get_mask_path(volume_name)
mask2former_dir = os.path.split(instance_mask_path)[0]
os.makedirs(mask2former_dir, exist_ok=True)

mask2former_inference()
instance_mask = np.load(instance_mask_path, allow_pickle=True)
instance_pred = instance_mask[slice_idx, :, :].T

del mask2former_inference, instance_mask

In [None]:
fig = make_subplots(
    rows=1, cols=3,
    specs=[
        [{"type": "heatmap"}]*3
    ],
    subplot_titles=("Original image", "Variant label", "Mask2former Prediction")
)

fig.add_trace(go.Heatmap(z=image.tolist(), showscale=False, colorscale='Greys', hoverinfo='skip'), row=1, col=1)
fig.add_trace(go.Heatmap(z=variant_label.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=2)
fig.add_trace(go.Heatmap(z=instance_pred.tolist(), showscale=False, colorscale='viridis', hoverinfo='skip'), row=1, col=3)

fig.update_layout(showlegend=False)
fig.show()

Since we do not need any class or background prediction, we set the weights of the focal loss and the cross-entropy loss to 1, while the dice loss has a weight of 10. See the [Hugging Face documentation](https://huggingface.co/docs/transformers/v4.37.2/en/model_doc/mask2former#transformers.Mask2FormerForUniversalSegmentation).
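The sketch below shows how such weights combine into a single training objective. It assumes that the first of the three losses named above is the per-pixel mask term, and it borrows the parameter names `class_weight`, `mask_weight` and `dice_weight` from the Hugging Face configuration as we understand it. The function `total_loss` and the loss values are made up for illustration.

```python
def total_loss(
    class_loss: float,
    mask_loss: float,
    dice_loss: float,
    class_weight: float = 1.0,
    mask_weight: float = 1.0,
    dice_weight: float = 10.0,
) -> float:
    """Weighted sum of the three loss terms (illustrative only)."""
    return class_weight * class_loss + mask_weight * mask_loss + dice_weight * dice_loss

# Made-up loss values: the dice term dominates despite being the smallest raw value.
loss = total_loss(class_loss=0.2, mask_loss=0.3, dice_loss=0.05)
print(round(loss, 6))
```

With a weight of 10, the dice term contributes as much to the objective as the two other terms together in this example, which pushes the optimization toward mask overlap rather than class prediction.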

With this approach, we achieved a validation loss of `0.811` on the instance masks.

To create instance masks with the Mask2former model from our solution, run the following cell.

You need to run the Segformer inference first, or rename `data/processed/test/abundant-lantern-1231-mmw4795a-original` to `data/processed/test/abundant-lantern-1231-mmw4795a`.

In [None]:
run = RunDemo('mask2former.yml', id='nvbtr9k2', name='vermilion-moon-1241')
list_volume = md.get_volumes(config, 'test')

mask2former_inference = mask2former_inf.Mask2formerInference(
    config=config,
    cuda_idx=0,
    list_volume=list_volume,
    run=run,
    split='test',
    batch=25
)
mask2former_dir = mask2former_inference.get_folder_path()
os.makedirs(mask2former_dir, exist_ok=True)
mask2former_inference()

del mask2former_inference

## 🧑🏻‍💻 Code Submission

If you want to change the configuration of the models, please refer to the YAML file available in the config folder.

The script takes volumes from the `data/raw/train` folder for training and from `data/raw/test` for inference. A directory `data/processed/[train, test]/{run_id}` is created for all intermediate masks.
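As an illustration of this layout, the following sketch builds the directory of intermediate masks for a given run. It assumes the folder name is the run name and run ID joined by a hyphen, as in the Segformer folder mentioned earlier in this notebook. The helper `processed_dir` is hypothetical; in our package the path comes from the `get_folder_path` method of each inference class.

```python
import os

def processed_dir(split: str, run_name: str, run_id: str, root: str = "data/processed") -> str:
    """Build the folder that would hold intermediate masks for one run (assumed layout)."""
    return os.path.join(root, split, f"{run_name}-{run_id}")

path = processed_dir("test", "abundant-lantern-1231", "mmw4795a")
print(path)
```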

In [None]:
# WandB Initialisation 
wandb_api_key = '02ca932e1203e93aaa8c97b8331d6c0b04c3170a' # Please do not share this api key
! wandb login --relogin {wandb_api_key}

In [None]:
from src import utils
import src.models.segformer.train_model as segformer_tm
import src.models.mask2former.train_model as mask2former_tm
import src.features.segment_anything_inference as sam_inf
import src.features.segformer_inference as segformer_inf
import src.features.mask2former_inference as mask2former_inf
import src.data.make_dataset as md
import wandb
import os
import yaml

# Creating a fake WandB run to make the code work.

class RunDemo:
    def __init__(self, config_file, id, name) -> None:
        self.config = self.get_config(config_file)
        self.name = name
        self.id = id
    
    @staticmethod
    def get_config(config_file) -> dict:
        root = os.path.join('config', config_file)
        notebooks = os.path.join(os.pardir, root)
        path = root if os.path.exists(root) else notebooks

        with open(path, 'r') as f:
            config = yaml.safe_load(f)

        return config

class SegmentationPipeline:
    def __init__(self, sam_batch, segformer_batch, mask2former_batch) -> None:
        self.config = utils.get_config()
        self.sam_batch = sam_batch
        self.segformer_batch = segformer_batch
        self.mask2former_batch = mask2former_batch
        
    def make_sam_inference(self, split):
        sam_inference = sam_inf.SAMInference(
            config=self.config,
            cuda_idx=0,
            list_volume=md.get_volumes(self.config, split),
            run=None,
            split=split,
            batch=self.sam_batch
        )
        os.makedirs(sam_inference.get_folder_path(), exist_ok=True)
        sam_inference()
        del sam_inference
        
    def make_segformer_inference(self, split, segformer_id):
        segformer_inference = segformer_inf.SegformerInference(
            config=self.config,
            cuda_idx=0,
            list_volume=md.get_volumes(self.config, split),
            run=RunDemo('segformer.yml', **segformer_id),
            split=split,
            batch=self.segformer_batch
        )
        os.makedirs(segformer_inference.get_folder_path(), exist_ok=True)
        segformer_inference()
        del segformer_inference
        
    def make_mask2former_inference(self, split, mask2former_id):
        mask2former_inference = mask2former_inf.Mask2formerInference(
            config=self.config,
            cuda_idx=0,
            list_volume=md.get_volumes(self.config, split),
            run=RunDemo('mask2former.yml', **mask2former_id),
            split=split,
            batch=self.mask2former_batch
        )
        os.makedirs(mask2former_inference.get_folder_path(), exist_ok=True)
        mask2former_inference()
        del mask2former_inference
    
    def train(self):
        print('Create Segment Anything binary masks...')
        self.make_sam_inference(split='train')
        print('Segment Anything binary masks done!')
        print('Train Segformer...')
        segformer_id = self.train_segformer()
        print('Segformer training finished!')
        print('Create Segformer binary masks...')
        self.make_segformer_inference(split='train', segformer_id=segformer_id)
        print('Segformer binary masks done!')
        print('Train Mask2former...')
        mask2former_id = self.train_mask2former(segformer_id)
        print('Mask2former training finished!')
        
        return {'segformer_id': segformer_id, 'mask2former_id': mask2former_id}
    
    def train_segformer(self):
        wandb_config = utils.init_wandb('segformer.yml')
        segformer_id = {'name': wandb.run.name, 'id': wandb.run.id}
        trainer = segformer_tm.get_trainer(self.config)
        lightning = segformer_tm.get_lightning(self.config, wandb_config)
        trainer.fit(model=lightning)
        wandb.finish()
        
        return segformer_id
    
    def train_mask2former(self, segformer_id):
        utils.init_wandb('mask2former.yml')
        wandb.config.update(segformer_id=f'{segformer_id["name"]}-{segformer_id["id"]}')
        wandb_config = wandb.config
        mask2former_id = {'name': wandb.run.name, 'id': wandb.run.id}
        trainer = mask2former_tm.get_trainer(self.config)
        lightning = mask2former_tm.get_lightning(self.config, wandb_config)
        trainer.fit(model=lightning)
        wandb.finish()
        
        return mask2former_id
    
    def predict(self, segformer_id, mask2former_id):
        print('Create Segment Anything binary masks...')
        self.make_sam_inference(split='test')
        print('Segment Anything binary masks done!')
        print('Create Segformer binary masks...')
        self.make_segformer_inference(split='test', segformer_id=segformer_id)
        print('Segformer binary masks done!')
        print('Create Mask2former instance masks...')
        self.make_mask2former_inference(split='test', mask2former_id=mask2former_id)
        print('Mask2former instance masks done!')

In [None]:
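# Batch sizes: min 2, max 300
segmentation_pipeline = SegmentationPipeline(sam_batch=5, segformer_batch=150, mask2former_batch=25)
# train() returns {'segformer_id': ..., 'mask2former_id': ...}, which predict() expects as keyword arguments.
run_ids = segmentation_pipeline.train()
# segmentation_pipeline.predict(**run_ids)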