# RWI: image, XRF and XRD

In [None]:
from pathlib import Path
from tqdm.auto import tqdm
import torch
import webdataset as wds

import lasio
import polars as pl

import lightning as L
from sklearn.model_selection import train_test_split

# Input locations. `expanduser()` is required: pathlib never expands "~" on
# its own, so without it every glob under TABLES / IMAGES would search a
# literal "~" directory relative to the working directory and match nothing.
BASE = Path("~/Developer/cutting_cbir/data/cuttings").expanduser()
TABLES = BASE / "las"  # XRF / XRD LAS tables
IMAGES = BASE / "raw"  # CR2 photographs, one sub-directory per well
OUTPUT = Path("data/rwi")  # webdataset shards + metadata.json

OUTPUT.mkdir(parents=True, exist_ok=True)

SEED = 1337

# Patch extraction: square patch side (pixels) and the training grid layout.
PATCH_SIZE = 256
GRID_ROWS = 2
GRID_COLS = 5

L.seed_everything(seed=SEED)  # global RNG seeding via Lightning

## Load tables

In [None]:
def load_las(f: Path) -> pl.DataFrame:
    """Read one LAS file into a polars DataFrame tagged with its well name.

    The LAS curve table keeps its index as a regular column (later cells join
    on it as ``DEPTH``). A constant ``well_name`` column is added. Its value is
    the header's ``WELL`` entry with the first ``/`` replaced by ``_``, which
    matches the image directory naming.

    Raises:
        ValueError: If the ``WELL`` entry contains no ``/`` (tuple unpacking
            of the split fails).
    """
    las_file = lasio.read(f.as_posix())
    raw_name: str = las_file.well["WELL"].value
    prefix, remainder = raw_name.split("/", maxsplit=1)
    normalised_name = f"{prefix}_{remainder}"
    frame = pl.from_pandas(las_file.df(), include_index=True)
    return frame.with_columns(pl.lit(normalised_name).alias("well_name"))


In [None]:
# Stack every XRF table; columns missing from some wells are null-filled.
xrf_data = pl.concat(
    [
        load_las(las_path)
        for las_path in tqdm(TABLES.glob("*_XRF_1.LAS"), desc="Loading LAS files")
    ],
    how="diagonal_relaxed",
)
xrf_data

In [None]:
# Stack every XRD table; columns missing from some wells are null-filled.
xrd_data = pl.concat(
    [
        load_las(las_path)
        for las_path in tqdm(TABLES.glob("*_XRD_*.LAS"), desc="Loading LAS files")
    ],
    how="diagonal_relaxed",
)
xrd_data

In [None]:
def parse_depth(f: Path) -> float:
    """Return the depth encoded in an image filename.

    The third-from-last ``_``-separated field of the file stem is parsed as a
    number, rounded to the nearest integer and divided by 100
    (e.g. ``"123456"`` -> ``1234.56``).
    """
    fields = f.stem.split("_")
    hundredths = round(float(fields[-3]))
    return hundredths / 100


def parse_well_name(f: Path) -> str:
    """Return the well name, i.e. the name of the directory holding ``f``."""
    containing_dir = f.parent
    return containing_dir.name


# One row per CR2 photograph: its well (parent directory), the depth parsed
# from the filename, and the file stem used later to locate the raw file.
images_data = pl.DataFrame(
    (
        {
            "well_name": parse_well_name(cr2_path),
            "DEPTH": parse_depth(cr2_path),
            "image": cr2_path.stem,
        }
        for cr2_path in tqdm(IMAGES.glob("**/*.CR2"), desc="Loading image paths")
    )
)
images_data

## Merge tables

In [None]:
# Number of CR2 images per well, ordered by well name.
images_data.get_column("well_name").value_counts().sort(by="well_name")

In [None]:
# Print the number of XRF rows per well, ordered by well name.
per_well = xrf_data["well_name"].value_counts().sort("well_name")
for well, n_rows in per_well.iter_rows():
    print(f"{well}: {n_rows}")

In [None]:
# Match XRF and XRD measurements on (DEPTH, well_name), then attach the
# photograph taken at the same well and depth.
join_keys = ["DEPTH", "well_name"]
df = (
    xrf_data
    .join(xrd_data, on=join_keys)
    .join(images_data, on=join_keys)
)

# Every merged row must map to a distinct photograph.
assert df["image"].n_unique() == len(df)
df


## Process features

In [None]:
# Columns of the merged table that contain nulls, with their null counts.
(
    df.null_count()
    .transpose(include_header=True)
    .filter(pl.col("column_0").gt(0))
)

In [None]:
# BALANCE is discarded; missing MGO readings are treated as 0.0.
df = df.drop("BALANCE")
df = df.with_columns(pl.col("MGO").fill_null(value=0.0))

## Splits

In [None]:
wells = df["well_name"].unique().sort().to_list()
print(f"Total wells: {len(wells):,}")

# Split by well so no well contributes samples to more than one split.
# This is a plain seeded random split of the well list (not stratified):
# 80% of wells go to train+val, 20% to test.
train_val_wells, test_wells = train_test_split(
    wells,
    test_size=0.2,
    random_state=SEED,
)

# Split the train+val wells again: 80% train, 20% val, i.e. roughly
# 64% / 16% of all wells.
train_wells, val_wells = train_test_split(
    train_val_wells,
    test_size=0.2,
    random_state=SEED,
)

print(
    f"Train: {len(train_wells):,}, Val: {len(val_wells):,}, Test: {len(test_wells):,}"
)

# Put rows into a canonical order before the seeded shuffle. The train order
# (and the val/test write order) then depends only on SEED, not on glob or
# join output order. `image` is unique per row (asserted after the merge),
# so this key fully determines the order.
row_order = ["well_name", "DEPTH", "image"]

train_df = (
    df.filter(pl.col("well_name").is_in(train_wells))
    .sort(row_order)
    .sample(
        fraction=1.0,
        with_replacement=False,
        seed=SEED,
        shuffle=True,
    )
)
val_df = df.filter(pl.col("well_name").is_in(val_wells)).sort(row_order)
test_df = df.filter(pl.col("well_name").is_in(test_wells)).sort(row_order)

splits = {"train": train_df, "val": val_df, "test": test_df}

print(
    f"Ratios: {len(train_df) / len(df):.1%} / {len(val_df) / len(df):.1%} / {len(test_df) / len(df):.1%}"
)

## Tabular features

In [None]:
# Identifier columns, XRD targets (prefixed "XRD_"), and everything else as
# XRF input features.
index_cols = ["well_name", "DEPTH", "image"]
xrd_cols = list(filter(lambda name: name.startswith("XRD_"), df.columns))
excluded = {*index_cols, *xrd_cols}
xrf_cols = [name for name in df.columns if name not in excluded]

In [None]:
# Normalise continuous features with statistics from the training split only,
# so nothing from val/test leaks into the scaling.
train_means = train_df[xrf_cols].mean()
train_stds = train_df[xrf_cols].std()

# A constant training column has std 0, and a single training row gives a
# null std. Either would turn that feature's normalised values into NaN/inf.
# Such columns fall back to a unit scale (they are only mean-centred). The
# adjusted values are what `train_stds` holds from here on.
train_stds = train_stds.with_columns(
    pl.when(pl.col(col) > 0).then(pl.col(col)).otherwise(1.0).alias(col)
    for col in xrf_cols
)

splits = {
    k: v.with_columns(
        pl.col(col).sub(train_means[col]).truediv(train_stds[col]).alias(f"{col}_norm")
        for col in xrf_cols
    )
    for k, v in splits.items()
}

print("Normalised continuous features")

## Save webdataset

In [None]:
from pathlib import Path
from typing import cast

import numpy as np
import numpy.typing as npt

import rawpy


def read_cr2_image(
    file_path: Path,
    *,
    clip_width: int = 12,
    clip_height: int = 10,
) -> npt.NDArray[np.uint8]:
    """Decode a CR2 raw file into an 8-bit sRGB image with fixed processing.

    LibRaw post-processing options that could vary per shot are pinned:
    fixed white balance, fixed black level, no auto-brightness and clipped
    highlights. Every image is therefore rendered with the same pipeline.

    Args:
        file_path: Path to the ``.CR2`` file; a ``str`` path is also accepted.
        clip_width: Pixels cropped from both the left and the right edge.
        clip_height: Pixels cropped from both the top and the bottom edge.

    Returns:
        The decoded ``uint8`` image as ``(height, width, channels)`` with the
        clip margins removed.

    Raises:
        ValueError: If a clip margin is negative.
    """
    if clip_width < 0 or clip_height < 0:
        raise ValueError(
            f"clip margins must be non-negative, got width={clip_width}, height={clip_height}"
        )

    # https://letmaik.github.io/rawpy/api/rawpy.Params.html
    # https://www.libraw.org/docs/API-datastruct-eng.html
    with rawpy.imread(Path(file_path).absolute().as_posix()) as raw:
        rgb_image = raw.postprocess(
            # demosaicing / de-noise
            demosaic_algorithm=rawpy.DemosaicAlgorithm.AHD,  # default
            fbdd_noise_reduction=rawpy.FBDDNoiseReductionMode.Off,
            median_filter_passes=0,
            # white balance: fixed
            use_camera_wb=False,
            use_auto_wb=False,
            user_wb=(2.0, 1.0, 1.3, 0.0),  # daylight balanced
            # color space + bit depth
            output_color=rawpy.ColorSpace.sRGB,
            output_bps=8,
            # exposure / brightness / scaling
            no_auto_bright=True,  # disable LibRaw auto brightness
            auto_bright_thr=0.00,  # ignored if no_auto_bright=True
            bright=1.5,  # fixed 1.5x brightness scale, identical for every image
            exp_shift=None,  # no exposure shift
            exp_preserve_highlights=0.0,
            no_auto_scale=False,  # KEEP scaling, you want proper WB + normalization
            # black/white levels
            user_black=512,  # all images have black level ~512
            user_sat=0.0,  # no user saturation level
            adjust_maximum_thr=0.0,  # no automatic white level adjustment
            # gamma / tone curve
            gamma=(2.222, 4.5),  # default, BT.709
            # misc
            highlight_mode=rawpy.HighlightMode.Clip,  # dont recover highlights
            user_flip=0,  # no gyro shenanigans
        )

    rgb_image = cast(
        npt.NDArray[np.uint8],
        rgb_image,
    )

    # Crop with explicit end indices rather than `-clip`, so that a margin of
    # 0 keeps the whole axis (`x[0:-0]` would be empty).
    height, width = rgb_image.shape[:2]
    return rgb_image[
        clip_height : height - clip_height,
        clip_width : width - clip_width,
        :,
    ]



def _grid_offsets(extent: int, count: int, patch_size: int) -> list[int]:
    """Return `count` start offsets spreading `patch_size` windows over `extent`."""
    if count == 1:
        # A single window is centred instead of dividing by (count - 1) == 0.
        return [(extent - patch_size) // 2]
    stride = (extent - patch_size) // (count - 1)
    return [i * stride for i in range(count)]


def extract_grid_patches(
    well_name: str,
    image_path: str,
) -> npt.NDArray[np.uint8]:
    """Read one CR2 image and cut a GRID_ROWS x GRID_COLS grid of square patches.

    Patches are ``PATCH_SIZE`` pixels square. They are spread evenly across the
    image, with the first row/column starting at the top/left edge. Any
    remainder from the integer stride is left at the bottom/right edge.

    Args:
        well_name: Well sub-directory under ``IMAGES``.
        image_path: File stem of the ``.CR2`` image in that directory.

    Returns:
        Array of shape ``(GRID_ROWS * GRID_COLS, PATCH_SIZE, PATCH_SIZE, C)``
        with patches in row-major grid order.

    Raises:
        ValueError: If the decoded image is smaller than ``PATCH_SIZE`` along
            either axis. Without this check the patches would have unequal
            shapes and could not be stacked.
    """
    path = IMAGES / well_name / f"{image_path}.CR2"
    rgb_image = read_cr2_image(path)
    img_height, img_width, _ = rgb_image.shape

    if img_height < PATCH_SIZE or img_width < PATCH_SIZE:
        raise ValueError(
            f"{path}: image {img_height}x{img_width} is smaller than "
            f"patch size {PATCH_SIZE}"
        )

    tops = _grid_offsets(img_height, GRID_ROWS, PATCH_SIZE)
    lefts = _grid_offsets(img_width, GRID_COLS, PATCH_SIZE)

    patches = [
        rgb_image[top : top + PATCH_SIZE, left : left + PATCH_SIZE, :]
        for top in tops
        for left in lefts
    ]
    return np.stack(patches)


def extract_centre_patch(
    well_name: str,
    image_path: str,
) -> npt.NDArray[np.uint8]:
    """Read one CR2 image and return a single fixed-position evaluation patch.

    The window is ``PATCH_SIZE`` pixels square, so evaluation patches always
    match the training grid patches. It is centred on a fixed pixel position
    rather than on the decoded image's own centre.

    Args:
        well_name: Well sub-directory under ``IMAGES``.
        image_path: File stem of the ``.CR2`` image in that directory.

    Returns:
        Array of shape ``(1, PATCH_SIZE, PATCH_SIZE, C)``.

    Raises:
        ValueError: If the decoded image does not fully contain the window.
            Without this check a truncated patch would be returned silently.
    """
    # NOTE(review): (2000, 3000) is presumably the centre of an uncropped
    # 4000x6000 frame. read_cr2_image crops its borders first, so this sits a
    # few pixels off the true centre of the returned array. Confirm intent.
    centre_row, centre_col = 2000, 3000
    top = centre_row - PATCH_SIZE // 2
    left = centre_col - PATCH_SIZE // 2

    path = IMAGES / well_name / f"{image_path}.CR2"
    rgb_image = read_cr2_image(path)
    img_height, img_width = rgb_image.shape[:2]

    if (
        top < 0
        or left < 0
        or top + PATCH_SIZE > img_height
        or left + PATCH_SIZE > img_width
    ):
        raise ValueError(
            f"{path}: image {img_height}x{img_width} does not contain the "
            f"{PATCH_SIZE}px window at ({top}, {left})"
        )

    patch = rgb_image[top : top + PATCH_SIZE, left : left + PATCH_SIZE, :]
    return patch[np.newaxis, ...]


import vizy

# Visual sanity check: every training-grid patch of the first merged sample.
vizy.plot(
    [
        patch
        for patch in extract_grid_patches(
            df.get_column("well_name")[0], df.get_column("image")[0]
        )
    ]
)

In [None]:
feature_cols = [col + "_norm" for col in xrf_cols]


def process_single_image(
    row: dict[str, str | float],
    split: str,
) -> dict[str, object] | None:
    """Convert one merged table row into a webdataset sample dict.

    The ``"train"`` split uses the full patch grid; every other split uses the
    single centre patch. Labels are the XRD columns, and features are the
    normalised XRF columns. Both are stored as float32 tensors. Patches are
    stored as a uint8 tensor.

    Returns:
        The sample dict, or ``None`` after printing the error if anything
        fails. One bad image therefore does not abort the whole export.
    """
    try:
        extract = extract_grid_patches if split == "train" else extract_centre_patch
        patch_array = extract(row["well_name"], row["image"])

        label = torch.tensor([row[c] for c in xrd_cols], dtype=torch.float32)
        features = torch.tensor([row[c] for c in feature_cols], dtype=torch.float32)

        sample = {
            # NOTE(review): webdataset groups tar members by the text before
            # the first "." of the member name. Confirm image stems contain
            # no "." or samples would be split incorrectly.
            "__key__": row["image"],
            "label.pth": label,
            "features.pth": features,
            "patches.pth": torch.tensor(patch_array, dtype=torch.uint8),
        }
    except Exception as e:
        print(f"Error processing {row['image']}: {e}")
        return None
    return sample


from concurrent.futures import ThreadPoolExecutor
from functools import partial

OUTPUT.mkdir(parents=True, exist_ok=True)

# Number of tar shards written per split (stored in metadata.json).
shard_counts: dict[str, int] = {}

for name, split_df in splits.items():
    pattern = str(OUTPUT / f"{name}-%04d.tar")
    rows = list(split_df.iter_rows(named=True))
    process_fn = partial(process_single_image, split=name)
    n_skipped = 0

    with wds.ShardWriter(  # type: ignore
        pattern,
        maxsize=2e8,  # ~200MB per shard
    ) as sink:
        # Images are decoded in worker threads. `executor.map` yields results
        # in input order, so shard contents follow the split's row order.
        with ThreadPoolExecutor(max_workers=10) as executor:
            for sample in tqdm(
                executor.map(process_fn, rows),
                total=len(rows),
                desc=f"Processing {name} split",
            ):
                if sample is None:
                    n_skipped += 1  # error already printed by process_single_image
                    continue
                sink.write(sample)
        # ShardWriter increments `shard` each time it opens a shard file.
        # After writing it therefore already equals the number of shards
        # created, so no `+ 1` is needed.
        shard_counts[name] = sink.shard

    print(f"{name}: {shard_counts[name]} shards")
    if n_skipped:
        print(f"{name}: skipped {n_skipped} of {len(rows)} samples")

In [None]:
import json

# Everything needed to reuse the shards: column groups, the training-split
# normalisation statistics, split sizes/wells and shard counts.
metadata = dict(
    xrd=xrd_cols,
    xrf=xrf_cols,
    xrf_means={c: train_means[c].item() for c in xrf_cols},
    xrf_stds={c: train_stds[c].item() for c in xrf_cols},
    split_sizes={split: len(frame) for split, frame in splits.items()},
    split_wells={
        split: frame["well_name"].unique().sort().to_list()
        for split, frame in splits.items()
    },
    shard_counts=shard_counts,
)
metadata_path = OUTPUT / "metadata.json"
metadata_path.write_text(json.dumps(metadata, indent=4))

print("Saved metadata")