# 8️⃣ Reflection Connection: Bringing New Algorithms to Old Data

<img src=assets/reflection-connection.png>

## 🌋 Context

For those unfamiliar with geophysics, seismic images can appear enigmatic: layers of black-and-white waveforms stacked atop one another. However, with increasing familiarity, distinct features within seismic data become discernible. These features signify prevalent geological formations such as river channels, salt pans, or faults. The process of identifying seismic features parallels that of a medical professional distinguishing between anatomical structures on an echocardiogram. Geoscientists amalgamate these identified features to formulate hypotheses regarding the geological evolution of the surveyed area. An algorithm capable of pinpointing specific segments within seismic images holds the promise of empowering geoscientists to construct more comprehensive hypotheses, thereby facilitating the integration of diverse datasets into a cohesive model of Earth's composition and evolution.

## 🧠 Our approach

## 📚 Libraries

Our code runs on Python 3.10.13.

Installing external libraries

In [None]:
! pip install -r requirements.txt

Installing the reflection-connection library

In [None]:
! pip install -e .

In [None]:
# Generic
import os
import random
import multiprocessing as mp

import matplotlib.pyplot as plt

from PIL import Image
import numpy as np
import torch
import torchvision.transforms.v2.functional as tvF

import wandb

from src import utils
import src.models.utils as mutils
import src.data.make_pretrain_data as mpd
import src.data.datasets.inference as inference_d
from src.models.inference import EmbeddingsBuilder
from src.models.retriever import FaissRetriever
from src.submissions.make_submissions import ResultBuilder, dist_to_conf
from src.models.iterative import IterativeTrainer

Connecting to a placeholder W&B account in offline mode

In [None]:
# W&B initialisation
os.environ["WANDB_MODE"]="offline"
wandb_api_key = 'X'*40
! wandb login --relogin {wandb_api_key}

We create a multiprocessing wrapper because PyTorch struggles to manage models loaded in the main process alongside those loaded in subprocesses. Running each step entirely in its own subprocess avoids the problem.

In [None]:
from typing import Any

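def multiprocess_wrapper(func):
    """Run the decorated function in a separate process and wait for it to finish.

    The return value of the function is discarded.
    """
    def wrapper(*args: Any, **kwargs: Any):
        p = mp.Process(target=func, args=args, kwargs=kwargs)
        p.start()
        p.join()

    return wrapper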

## 📸 Data

In [None]:
# Path to the directory containing the class folders
base_path = "data/raw/train/"
# List of class names (folder names)
class_names = [folder for folder in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, folder))]
# Create a grid of subplots
num_classes = len(class_names)
num_images_per_class = 5
fig, axs = plt.subplots(num_images_per_class, num_classes, figsize=(20, 15))

# Loop over each class
for i, class_name in enumerate(class_names):
    # Path to the class folder
    class_path = os.path.join(base_path, class_name)

    # List of image files in the class folder
    image_files = os.listdir(class_path)

    # Randomly select num_images_per_class images from the class
    selected_images = random.sample(image_files, num_images_per_class)

    # Display the images in the corresponding column
    for j, image_file in enumerate(selected_images):
        # Full path to the image
        image_path = os.path.join(class_path, image_file)

        # Read the image and display it in the corresponding subplot
        img = Image.open(image_path)
        axs[j, i].imshow(img)
        axs[j, i].axis('off')

    # Add the class name as the column title
    axs[0, i].set_title(class_name, fontsize=16, fontweight='bold', loc='center', pad=20)

# Adjust the spacing between subplots
plt.tight_layout()

# Show the mosaic figure
plt.show()

## ⚒️ Preprocessing

In computer vision, the classic preprocessing steps for an image are as follows:

Scaling: Maps pixel values to the range 0 to 1 (min-max scaling).

Normalization: Standardizes each channel to zero mean and unit variance (standard scaling).

Resizing: If necessary, depending on the input size the model accepts (using bicubic interpolation).

Cropping: If necessary, depending on the input size the model accepts (random crop for training, center crop for inference).

However, in the images of our dataset, the features of interest show up as differences in brightness between adjacent layers. That's why we decided to increase the contrast, which highlights these differences in shade between the layers.
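As a rough illustration of the scaling, contrast, and normalization steps, here is a minimal NumPy sketch. The function names and the use of per-image channel statistics are assumptions made for this example; the project itself relies on torchvision, as shown in the next cell.

```python
import numpy as np


def min_max_scale(image: np.ndarray) -> np.ndarray:
    """Map pixel values to the [0, 1] range."""
    image = image.astype(np.float32)
    lo, hi = image.min(), image.max()
    if hi == lo:  # constant image: avoid a division by zero
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)


def adjust_contrast(image: np.ndarray, factor: float) -> np.ndarray:
    """Push values away from the mean gray level (factor > 1 adds contrast).

    Expects values in [0, 1] and clips the result to that range.
    """
    mean = image.mean()
    return np.clip(mean + factor * (image - mean), 0.0, 1.0)


def standardize(image: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Center and scale each channel (last axis) with the given statistics."""
    return (image - mean) / std


# Synthetic 8x8 RGB image standing in for a seismic image (assumption).
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(8, 8, 3)).astype(np.uint8)

scaled = min_max_scale(raw)
contrasted = adjust_contrast(scaled, factor=2.0)
normalized = standardize(
    contrasted,
    mean=contrasted.mean(axis=(0, 1)),
    std=contrasted.std(axis=(0, 1)),
)
```

In practice the normalization statistics usually come from the dataset the backbone was trained on rather than from each image; per-image statistics are used here only to keep the sketch self-contained.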

In [None]:
def processing_wrapper(func):
    def mp_wrapper(image):
        manager = mp.Manager()
        namespace = manager.Namespace()
        p = mp.Process(target=func, args=(image, namespace))
        p.start()
        p.join()

        return namespace.image

    return mp_wrapper


@processing_wrapper
def scale(image, namespace):
    image = tvF.to_image(image)
    image = tvF.to_dtype_image(image, torch.float32, scale=True)
    
    namespace.image = np.transpose(image.numpy(force=True), (1, 2, 0))

@processing_wrapper
def contrast(image, namespace):
    image = tvF.to_image(image)
    image = tvF.adjust_contrast(image, contrast_factor=10)
    
    namespace.image = np.transpose(image.numpy(force=True), (1, 2, 0))

@processing_wrapper
def resize(image, namespace):
    image = tvF.to_image(image)
    image = tvF.resize(image, size=256, interpolation=tvF.InterpolationMode.BICUBIC)

    namespace.image = np.transpose(image.numpy(force=True), (1, 2, 0))

@processing_wrapper
def crop(image, namespace):
    image = tvF.to_image(image)
    image = tvF.center_crop(image, output_size=224)

    namespace.image = np.transpose(image.numpy(force=True), (1, 2, 0))

In [None]:
# Define a function to plot the image
def plot_image(image, title, subplot_pos):
    plt.subplot(*subplot_pos)
    plt.imshow(image, cmap='gray')
    plt.title(title)
    # plt.axis('off')

# Create subplots
plt.figure(figsize=(10, 10))

image = Image.open('data/raw/test/image_corpus/afbaz.png').convert("RGB")

# Original image
plot_image(np.array(image), '0 - Original image', (2, 3, 1))

# Scaled image
scaled_image = scale(image)
plot_image(scaled_image, '1 - Scaled image', (2, 3, 2))

# Contrasted image
contrasted_image = contrast(scaled_image)
plot_image(contrasted_image, '2 - Contrasted image', (2, 3, 3))

# Resized image
resized_image = resize(contrasted_image)
plot_image(resized_image, '3 - Resized image', (2, 3, 4))

# Cropped image
cropped_image = crop(resized_image)
plot_image(cropped_image, '4 - Cropped image', (2, 3, 5))

plt.tight_layout()
plt.show()


## 🖨️ Data Augmentation

For data augmentation, we relied on the datasets from the Patch the Planet: Restore Missing Data challenge. We took all available volumes and created PNG images by cropping the slices according to the distribution of image dimensions in the challenge. These crops were then scaled to the 0 to 255 range using the image creation function provided in the challenge. This yielded just over a million images.
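The cropping procedure can be sketched roughly as follows. This is a minimal illustration, not the project's `make_pretrain_data` code: the `(height, width)` distribution, the helper names `sample_tile_coords` and `to_uint8`, and the min-max scaling are assumptions standing in for the challenge's own image creation function.

```python
import numpy as np


def sample_tile_coords(slice_shape, sizes, counts, n_tiles, rng):
    """Draw tile sizes in proportion to their counts, then place them at random."""
    probs = np.asarray(counts, dtype=float) / np.sum(counts)
    coords = []
    for idx in rng.choice(len(sizes), size=n_tiles, p=probs):
        h, w = sizes[idx]
        # Keep the tile inside the slice.
        h, w = min(h, slice_shape[0]), min(w, slice_shape[1])
        x0 = int(rng.integers(0, slice_shape[0] - h + 1))
        y0 = int(rng.integers(0, slice_shape[1] - w + 1))
        coords.append((x0, x0 + h, y0, y0 + w))
    return coords


def to_uint8(tile):
    """Min-max scale a float tile to the 0-255 range."""
    tile = tile.astype(np.float32)
    span = tile.max() - tile.min()
    if span == 0:
        return np.zeros(tile.shape, dtype=np.uint8)
    return np.round(255 * (tile - tile.min()) / span).astype(np.uint8)


rng = np.random.default_rng(42)
# Synthetic stand-in for one seismic slice (assumption).
slice_array = rng.normal(size=(300, 1259)).astype(np.float32)
sizes = [(118, 200), (150, 300), (224, 224)]  # hypothetical image dimensions
counts = [10, 5, 1]                           # hypothetical frequency of each dimension

tiles = [
    to_uint8(slice_array[x0:x1, y0:y1])
    for x0, x1, y0, y1 in sample_tile_coords(slice_array.shape, sizes, counts, 4, rng)
]
```

Sampling sizes in proportion to their observed frequency keeps the pretraining images close to the shapes the model will see in the challenge data.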

In [None]:
# Define a function to plot the image
def plot_image(image, title, subplot_pos, fontsize):
    plt.subplot(*subplot_pos)
    plt.imshow(image, cmap='gray')
    plt.title(title, fontsize=fontsize)
    # plt.axis('off')

volume_name = '0kamixt53o'
volume = np.load(f'data/raw/pretrain/patch-the-planet-real-train-data/{volume_name}.npy')
slice_idx = random.randint(0, 300)
slice_array = volume[slice_idx, :, :].T

# Create subplots
plt.figure(figsize=(10, 10))
plt.title(f'Volume {volume_name}, slice {slice_idx}', pad=30)
plt.axis('off')

values, counts = mpd.get_values_counts(utils.get_config())
tiles_coords = mpd.get_tiles_coords(values, counts)
for i, (x0, x1, y0, y1) in enumerate(tiles_coords):
    tile = slice_array[x0:x1, y0:y1]
    tile = mpd.normalize_pretrain_slice(tile)
    image = Image.fromarray(tile).convert('RGB')
    plot_image(image, f'Image: x_min: {x0}, x_max: {x1}, y_min: {y0}, y_max: {y1}', (2, 2, i+1), 8)

plt.tight_layout()
plt.show()

## 🧩 Pretraining

TODO: add an explanation of the DINOv2 and ViTMAE pretraining, include a visual (original image -> masked image -> reconstructed image), and offer a cell to launch a pretraining run.
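For readers unfamiliar with masked-autoencoder pretraining (the ViTMAE approach), the masking step can be sketched as below. This is a generic NumPy illustration with an assumed patch size of 16 and the 75% mask ratio used in the MAE paper; it is not the project's pretraining code, and `random_patch_mask` is a hypothetical helper.

```python
import numpy as np


def random_patch_mask(image, patch_size=16, mask_ratio=0.75, rng=None):
    """Zero out a random subset of non-overlapping square patches of a 2D image.

    Returns the masked image and the patch-level boolean mask (True = masked).
    """
    if rng is None:
        rng = np.random.default_rng()
    rows, cols = image.shape[0] // patch_size, image.shape[1] // patch_size
    n_patches = rows * cols
    n_masked = int(round(mask_ratio * n_patches))

    # Pick which patches to hide.
    mask = np.zeros(n_patches, dtype=bool)
    mask[rng.permutation(n_patches)[:n_masked]] = True
    mask = mask.reshape(rows, cols)

    # Expand the patch-level mask to pixel level.
    pixel_mask = mask.repeat(patch_size, axis=0).repeat(patch_size, axis=1)

    masked = image.copy()
    masked[: rows * patch_size, : cols * patch_size][pixel_mask] = 0
    return masked, mask


image = np.ones((224, 224), dtype=np.float32)  # synthetic input (assumption)
masked_image, patch_mask = random_patch_mask(image, rng=np.random.default_rng(0))
```

During pretraining, the encoder only sees the visible patches and a lightweight decoder is trained to reconstruct the hidden ones.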

## 🎛️ Fine-Tuning

For model fine-tuning, we used a triplet loss. We removed boring images from the anchors (but not from the pool of negative images) because they were not in the final dataset. We tested two distances: Euclidean and cosine. Several models were tested for their ability to produce one-shot predictions, including [CLIP](https://arxiv.org/pdf/2103.00020.pdf), [DINOv2](https://arxiv.org/pdf/2304.07193.pdf), and [ViTMAE](https://arxiv.org/pdf/2111.06377v2.pdf). We also tested the [ViT](https://arxiv.org/pdf/2010.11929.pdf) models from PyTorch proposed by Onward.
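To make the objective concrete, here is a minimal NumPy sketch of a triplet loss with the two distances mentioned above. The margin value and the function names are assumptions for illustration; the actual training code lives in the project's `src.models` package and may differ.

```python
import numpy as np


def euclidean_distance(a, b):
    """Row-wise Euclidean distance between two batches of embeddings."""
    return np.linalg.norm(a - b, axis=-1)


def cosine_distance(a, b):
    """Row-wise cosine distance (1 - cosine similarity)."""
    a_n = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b_n = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return 1.0 - np.sum(a_n * b_n, axis=-1)


def triplet_loss(anchor, positive, negative, distance=euclidean_distance, margin=1.0):
    """Mean of max(d(anchor, positive) - d(anchor, negative) + margin, 0)."""
    d_pos = distance(anchor, positive)
    d_neg = distance(anchor, negative)
    return float(np.mean(np.maximum(d_pos - d_neg + margin, 0.0)))


# Toy 2D embeddings: the positive is close to the anchor, the negative is far.
anchor = np.array([[1.0, 0.0]])
positive = np.array([[0.9, 0.1]])
negative = np.array([[0.0, 1.0]])

well_separated = triplet_loss(anchor, positive, negative)
swapped = triplet_loss(anchor, negative, positive)
cosine_version = triplet_loss(anchor, positive, negative, cosine_distance, margin=0.5)
```

The loss is zero once the negative is farther from the anchor than the positive by at least the margin, so only triplets that still violate the margin contribute to the gradient.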

In [None]:
# Launch a fine-tuning using a model by selecting a configuration file
@multiprocess_wrapper
def fine_tune():
    config = utils.get_config()
    # CLIP from OpenAI
    # wandb_config = utils.init_wandb('fine_tuning/clip.yml')
    # DINOv2 from META
    # wandb_config = utils.init_wandb('fine_tuning/dinov2.yml')
    # VITMAE from META
    # wandb_config = utils.init_wandb('fine_tuning/vitmae.yml')
    # VIT from Hugging Face
    # wandb_config = utils.init_wandb('fine_tuning/vit.yml', 'transformers')
    # VIT from Pytorch
    wandb_config = utils.init_wandb('fine_tuning/vit.yml', 'torchvision')
    trainer = mutils.get_trainer(config)
    lightning = mutils.get_lightning(config, wandb_config)
    trainer.fit(model=lightning)
    wandb.finish()

    del lightning, trainer
    torch.cuda.empty_cache()

# Run in a subprocess: loading the model in the main process causes issues
fine_tune()

## 🕵️ Retriever & Submissions

For the similar-image search, we used the [FAISS](https://ai.meta.com/tools/faiss/) library provided by Meta. It lets us run a brute-force search over all image embeddings in our corpus and efficiently retrieve the most similar images. The retriever uses the same distance metric as the triplet loss used for fine-tuning, so we implemented two retrievers: one based on Euclidean distance and the other on cosine similarity.
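The search itself can be sketched in plain NumPy as shown below. This is a stand-in for illustration, not the project's `FaissRetriever`: FAISS flat indexes such as `faiss.IndexFlatL2` and `faiss.IndexFlatIP` perform the same kind of exhaustive comparison far more efficiently (note that `IndexFlatL2` reports squared distances). The function name and the toy data are assumptions.

```python
import numpy as np


def brute_force_search(corpus, queries, k=3, metric="euclidean"):
    """Return (distances, indices) of the k nearest corpus vectors per query."""
    if metric == "cosine":
        corpus = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)
        queries = queries / np.linalg.norm(queries, axis=1, keepdims=True)
        dists = 1.0 - queries @ corpus.T
    else:
        # Pairwise Euclidean distances via broadcasting.
        dists = np.linalg.norm(queries[:, None, :] - corpus[None, :, :], axis=-1)
    idx = np.argsort(dists, axis=1)[:, :k]
    return np.take_along_axis(dists, idx, axis=1), idx


# Toy corpus of four 2D embeddings with hypothetical file names.
corpus_names = ["a.png", "b.png", "c.png", "d.png"]
corpus = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]])
queries = np.array([[0.9, 0.1]])

l2_dists, l2_idx = brute_force_search(corpus, queries, k=3, metric="euclidean")
cos_dists, cos_idx = brute_force_search(corpus, queries, k=3, metric="cosine")
matched = [[corpus_names[i] for i in row] for row in l2_idx]
```

Using the same metric for retrieval as for training matters: embeddings optimized for cosine distance are not guaranteed to rank the same way under Euclidean distance unless they are normalized.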

In [None]:
# Specify the ID of the model you wish to use.
wandb_id = 'nszfciym'
# Specify the Name of the model you wish to use.
wandb_name = 'key-lime-pie-110'

# If you are using a WandB account to record the runs, use the code below.
# wandb_run = utils.get_run(wandb_id)
# Otherwise, specify the name and ID of the model and choose the corresponding configuration file for training the model.
wandb_run = utils.RunDemo('fine_tuning/vit.yml', id=wandb_id, name=wandb_name, sub_config='torchvision')

@multiprocess_wrapper
def make_submission(wandb_run):
    config = utils.get_config()
    # You can adjust the number of workers and batch size based on your system configuration.
    embeddings_builder = EmbeddingsBuilder(devices=[0], batch_size=4, num_workers=4)

    corpus_dataset = inference_d.make_submission_corpus_inference_dataset(config, wandb_run.config)
    corpus_embeddings, corpus_names = embeddings_builder.build_embeddings(config, wandb_run, dataset=corpus_dataset)
    query_dataset = inference_d.make_submission_query_inference_dataset(config, wandb_run.config)
    query_embeddings, query_names = embeddings_builder.build_embeddings(config, wandb_run, dataset=query_dataset)

    metric = utils.get_metric(wandb_run.config)
    retriever = FaissRetriever(embeddings_size=corpus_embeddings.shape[1], metric=metric)
    retriever.add_to_index(corpus_embeddings, labels=corpus_names)
    distances, matched_labels = retriever.query(query_embeddings, k=3)
    confidence_scores = dist_to_conf(distances)

    # Create submission file
    result_builder = ResultBuilder(config['path']['submissions'], k=3)
    result_builder(
        query_names,
        matched_labels,
        confidence_scores,
        f'{wandb_run.name}-{wandb_run.id}'
    )

# Run in a subprocess: loading the model in the main process causes issues
make_submission(wandb_run)

You will find the result in the `submissions` folder.

## 🔁 Iterative Fine-Tuning

Iterative training builds on everything presented above and is heavily inspired by the dataset augmentation process Meta used for DINOv2. We start by fine-tuning a model with triplet loss on the challenge data, then retrieve the images most similar to the training images of each class to augment our dataset. We then rerun fine-tuning on the augmented dataset while keeping exactly the same validation dataset. We repeat this process as long as the model's performance improves.

TODO: add a diagram showing how the process works.
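The control flow of this loop can be sketched as follows. It is a schematic outline under stated assumptions, not the `IterativeTrainer` implementation: `train_fn`, `eval_fn`, and `augment_fn` are hypothetical callables, and stopping at the first round without improvement is one possible reading of the stopping rule.

```python
from typing import Any, Callable, List, Tuple


def iterative_fine_tuning(
    dataset: List[Any],
    train_fn: Callable[[List[Any]], Any],
    eval_fn: Callable[[Any], float],
    augment_fn: Callable[[Any, List[Any]], List[Any]],
    max_rounds: int = 10,
) -> Tuple[Any, List[float]]:
    """Train, evaluate on a fixed validation set, augment, and repeat while improving."""
    best_model, best_score, history = None, float("-inf"), []
    for _ in range(max_rounds):
        model = train_fn(dataset)
        score = eval_fn(model)  # the validation set never changes between rounds
        history.append(score)
        if score <= best_score:
            break  # no improvement: keep the previous best model
        best_model, best_score = model, score
        # Add the images most similar to each class, as retrieved by the new model.
        dataset = augment_fn(model, dataset)
    return best_model, history


# Toy run: the "model" is just the dataset size and the scores are scripted.
scores = iter([0.60, 0.70, 0.68])
best_model, history = iterative_fine_tuning(
    dataset=["img_0"],
    train_fn=lambda ds: len(ds),
    eval_fn=lambda model: next(scores),
    augment_fn=lambda model, ds: ds + [f"img_{len(ds)}"],
)
```

Keeping the validation set fixed is what makes the scores comparable from one round to the next.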

In [None]:
config = utils.get_config()
iterative_config = utils.load_config('fine_tuning/iterative.yml')
curated_folder = os.path.join(config['path']['data'], 'raw', 'train')

iterative_trainer = IterativeTrainer(
    config,
    iterative_config,
    curated_folder
)

# The iterative trainer was built to natively support loading in subprocesses, so it doesn't need a wrapper.
iterative_trainer.fit()

## 🧑🏻‍💻 Code Submission

If you want to change the configuration of the models, please refer to the YAML files available in the `config` folder.

The script takes images from the `data/raw/train` folder for training and from `data/raw/test` for inference.

In the code submission, we only provide the solution that yielded the best results. However, you can find all our approaches in the preceding cells.

In [None]:
# W&B initialisation
# Replace the placeholder below with your own W&B API key; never commit a real key.
wandb_api_key = 'X'*40
! wandb login --relogin {wandb_api_key}

In [None]:
import multiprocessing as mp
import wandb

from src import utils
import src.models.utils as mutils
import src.data.datasets.inference as inference_d
from src.models.inference import EmbeddingsBuilder
from src.models.retriever import FaissRetriever
from src.submissions.make_submissions import ResultBuilder, dist_to_conf


class RefConPipeline:
    def __init__(self, yml_file: str='fine_tuning/vit.yml', sub_config: str='torchvision'):
        self.yml_file = yml_file
        self.sub_config = sub_config
        self.config = utils.get_config()
        self.manager = mp.Manager()
    
    def _train(self, wandb_dict):
        wandb_config = utils.init_wandb(self.yml_file, self.sub_config)
        wandb_dict['wandb_id'] = wandb.run.id
        wandb_dict['wandb_name'] = ''
        trainer = mutils.get_trainer(self.config)
        lightning = mutils.get_lightning(self.config, wandb_config)
        trainer.fit(model=lightning)
        wandb.finish()
       
    def train(self):
        wandb_dict = self.manager.dict({'wandb_id': '', 'wandb_name': 'demo-name'})
        p = mp.Process(target=self._train, args=(wandb_dict,))
        p.start()
        p.join()
        
        return dict(wandb_dict)
    
    def _predict(self, query_folder: str, corpus_folder: str, wandb_run: utils.RunDemo, k: int, batch_size: int, num_workers: int):
        embeddings_builder = EmbeddingsBuilder(devices=wandb_run.config['devices'], batch_size=batch_size, num_workers=num_workers)
        corpus_dataset = inference_d.make_submission_inference_dataset(corpus_folder, self.config, wandb_run.config)
        corpus_embeddings, corpus_names = embeddings_builder.build_embeddings(self.config, wandb_run, dataset=corpus_dataset)
        query_dataset = inference_d.make_submission_inference_dataset(query_folder, self.config, wandb_run.config)
        query_embeddings, query_names = embeddings_builder.build_embeddings(self.config, wandb_run, dataset=query_dataset)

        metric = utils.get_metric(wandb_run.config)
        retriever = FaissRetriever(embeddings_size=corpus_embeddings.shape[1], metric=metric)
        retriever.add_to_index(corpus_embeddings, labels=corpus_names)
        distances, matched_labels = retriever.query(query_embeddings, k=k)
        confidence_scores = dist_to_conf(distances)

        # Create submission file
        result_builder = ResultBuilder(self.config['path']['submissions'], k=k)
        result_builder(
            query_names,
            matched_labels,
            confidence_scores,
            f'{wandb_run.name}-{wandb_run.id}'
        )
        
    def predict(self, query_folder: str, corpus_folder: str, wandb_id: str='nszfciym', wandb_name: str='key-lime-pie-110', k: int=3, batch_size: int=16, num_workers: int=16):
        wandb_run = utils.RunDemo(self.yml_file, id=wandb_id, name=wandb_name, sub_config=self.sub_config)
        p = mp.Process(target=self._predict, args=(query_folder, corpus_folder, wandb_run, k, batch_size, num_workers))
        p.start()
        p.join()
    
    def __call__(self, query_folder, corpus_folder, k: int=3, batch_size: int=16, num_workers: int=16):
        wandb_dict = self.train()
        self.predict(query_folder=query_folder, corpus_folder=corpus_folder, **wandb_dict, k=k, batch_size=batch_size, num_workers=num_workers)

### 🔮 Predict pipeline

The prediction is run on the images located in the folders defined by `corpus_folder` and `query_folder`. The final results are saved in `submissions/key-lime-pie-110-nszfciym.json`.

In [None]:
# To predict with our solution
ref_conf_pipeline = RefConPipeline()

# Set corpus_folder and query_folder to the paths of your data.
corpus_folder = 'data/raw/test/image_corpus'
query_folder = 'data/raw/test/query'
# You can adjust the number of workers and batch size based on your system configuration.
ref_conf_pipeline.predict(query_folder=query_folder, corpus_folder=corpus_folder, batch_size=16, num_workers=16)

### ⚙️ Full pipeline (fine-tuning + inference)

In [None]:
# Feel free to modify the configuration according to your preferences.
yml_file = 'fine_tuning/vit.yml'
sub_config = 'torchvision'

ref_conf_pipeline = RefConPipeline(yml_file, sub_config)

# Set corpus_folder and query_folder to the paths of your data.
corpus_folder = 'data/raw/test/image_corpus'
query_folder = 'data/raw/test/query'
# You can adjust the number of workers and batch size based on your system configuration.
# To change workers and batch size for fine-tuning refer to the corresponding configuration file.
ref_conf_pipeline(query_folder=query_folder, corpus_folder=corpus_folder, k=3, batch_size=16, num_workers=16)