In [None]:
# All stages of Data Preprocessing

In [None]:
# generally all the modules needed
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.transforms.functional as TF
from torch.utils.data import WeightedRandomSampler
import h5py
import numpy as np
import optuna
import random
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix, roc_curve, auc
from joblib import Parallel, delayed
import cv2
import glob
from optuna.importance import get_param_importances
import gc
import os
import pandas as pd
from scipy.ndimage import zoom
from multiprocessing import Pool, cpu_count

## Preprocessing of Data
For each file, a late measure in the exposure was chosen.

First, reference pixel corrections must be applied to extract the cross-hatch pattern. A standard row correction is applied, followed by a background subtraction of the second measure of the exposure, and then a normalization that excludes outliers (CNNs typically expect normalized data). During processing, each image is divided into 16 patches, which are stored as tensors with a label indicating whether or not they contain cross-hatching (along with their image and patch IDs). Finally, so that the data can be loaded quickly, these tensors are stored in 4 chunks of about 5000 each.

In [None]:
# the functions needed for this
def madstat(a, axis=None, keepdims=False, std=False):
    """
    MADSTAT - Robust statistics using median and median absolute deviation

    Parameters:  a, array_like
                   Calculate the median absolute deviation of these values.
                 axis, None, int or tuple of ints (optional)
                   Axis or axes along which the median absolute deviation is
                   computed. None (the default) reduces over all elements.
                 keepdims, bool (optional)
                   If this is set to True, the axes which are
                   reduced are left in the result as dimensions with
                   size one. With this option, the result will
                   broadcast correctly against the input array.
                 std, bool (optional)
                   If truthy, the MAD is multiplied by 1.4826 to
                   estimate the standard deviation of normal deviates.
    Returns: Tuple (median, mad). When std is truthy the second element is
             the estimated standard deviation instead of the raw MAD.
    """
    a = np.asanyarray(a)

    # Compute the median along the requested axes
    median = np.median(a, axis=axis, keepdims=keepdims)

    # The deviations need a median that broadcasts against `a`. With
    # keepdims=False the reduced axes are gone, so put them back as size-one
    # axes; otherwise `a - median` aligns along the wrong axis (or fails).
    if keepdims or axis is None:
        center = median
    else:
        center = np.expand_dims(median, axis)
    mad = np.median(np.abs(a - center), axis=axis, keepdims=keepdims)

    # Estimate std() if requested (truthiness test so np.bool_ flags work too)
    if std:
        mad = mad * 1.4826

    # Done
    return median, mad

def rowcor(D):
    """
        rowcor(D)
        Reference correction using only the top four rows of reference pixels.

        For every frame and every readout output, a trimmed mean of the
        reference pixels in the last four rows of that output is computed and
        subtracted from all pixels of that output.

        parameters: D, array
                      The input data cube, indexed (frame, row, column).
                      It is modified in place, so it must have a floating
                      point dtype that can hold the subtracted mean.
        Returns:    nothing
                      This overwrites the input data
        Raises:     ValueError
                      If the cube has fewer columns than readout outputs.
    """
    # Get cube dimensions
    nz, ny, nx = D.shape

    # Definitions
    nout = 32       # The WFI uses 32 outputs for full frame readout
    w = nx // nout  # Width of each output in pixels
    count = 3       # Clip off the best/worse few samples for robust statistics
    nref = 4        # Number of reference rows at the top edge of each frame

    if w == 0:
        raise ValueError(f"rowcor needs at least {nout} columns, got {nx}")

    # First reference row; the last `nref` rows of the frame are used
    ref_start = max(ny - nref, 0)

    # Apply reference correction working frame-by-frame and output-by-output
    for z in range(nz):
        for op in range(nout):
            # Python slices exclude the end index, so the slice must run to
            # (op + 1) * w to cover every column of this output.
            cols = slice(op * w, (op + 1) * w)
            refpix = D[z, ref_start:ny, cols]                        # Get ref. pixels
            refpix = np.sort(refpix.flatten())[count-1:-count+1]    # Trim outliers
            mu = np.mean(refpix)                                     # Robust mean
            D[z, :, cols] -= mu

In [None]:
def apply_gabor_filters(image, angles=[68,114, 34, 170], ksize=15, sigma=4.0, lambd=10.0, gamma=0.5):
    """
    Filter an image with one Gabor kernel per orientation.

    Parameters:  image, array
                   The image handed to cv2.filter2D.
                 angles, sequence of numbers
                   Kernel orientations in degrees (converted to radians
                   before being passed to OpenCV).
                 ksize, int
                   Side length of the square kernel.
                 sigma, lambd, gamma, float
                   Passed straight through to cv2.getGaborKernel.
    Returns: Array with the float32 filter responses stacked along a new
             leading axis, one entry per angle in the order given.
    """
    def _response(theta_deg):
        # Phase offset (psi) is fixed at 0; the kernel is built as float32
        kern = cv2.getGaborKernel((ksize, ksize), sigma, np.deg2rad(theta_deg), lambd, gamma, 0, ktype=cv2.CV_32F)
        return cv2.filter2D(image, cv2.CV_32F, kern)

    return np.stack([_response(theta_deg) for theta_deg in angles], axis=0)
 
def preprocess_patch_with_gabor_and_edges(image):
    """
    Build a multi-channel tensor from a single-channel patch.

    Channel layout along axis 0:
      0      the patch itself, min-max rescaled to [-1, 1]
      1      Canny edge map (thresholds 50/150) scaled to 0..1
      2...   Gabor responses from apply_gabor_filters with its defaults

    The edge map and Gabor responses are computed on an 8-bit, min-max
    normalized copy of the patch.

    Parameters:  image, array
                   2-D patch (each channel is given a leading axis before
                   concatenation).
    Returns: torch float32 tensor holding the stacked channels.
    """
    # 8-bit version used by the OpenCV feature extractors
    as_uint8 = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    edge_map = cv2.Canny(as_uint8, 50, 150).astype(np.float32) / 255.0
    gabor_stack = apply_gabor_filters(as_uint8)

    # Rescale the raw patch to [-1, 1]; the epsilon guards a constant patch
    lo = image.min()
    hi = image.max()
    scaled = 2 * (image - lo) / (hi - lo + 1e-8) - 1

    channels = [scaled[np.newaxis, :, :], edge_map[np.newaxis, :, :], gabor_stack]
    return torch.from_numpy(np.concatenate(channels, axis=0)).float()

In [None]:
# Read the label table (no header row) and drop rows whose path cell is
# missing or blank.
df = pd.read_csv('new_classes.csv', header=None)
has_path = df[0].notna()
path_not_blank = df[0].astype(str).str.strip() != ''
df = df[has_path & path_not_blank]

# Column 0: file path, column 1: 'y' or 'n', column 2: positive patch numbers
file_paths = df[0].tolist()
labels_raw = df[1].tolist()
patch_column = df[2].fillna('').tolist()

# File-level label: 1 for 'y', 0 for anything else
labels = [1 if raw == 'y' else 0 for raw in labels_raw]


def _parse_patch_ids(entry):
    """Turn a cell such as "3;7;12" into zero-based ids [2, 6, 11].

    An empty cell gives an empty list; tokens that are not plain digits
    are ignored.
    """
    if entry == '':
        return []
    return [int(tok.strip()) - 1 for tok in str(entry).split(';') if tok.strip().isdigit()]


positive_patches = [_parse_patch_ids(entry) for entry in patch_column]

# Patch grid: NUM_PATCHES x NUM_PATCHES tiles of PATCH_SIZE pixels per side
NUM_PATCHES = 4
PATCH_SIZE = 1024
save_dir = "/explore/nobackup/people/cemeehan/ProcessedPatches"
os.makedirs(save_dir, exist_ok=True)
 
def preprocess_and_save_patch(args):
    """
    Preprocess one exposure file and save each of its patches to disk.

    Steps: read two frames (index 1 and the frame 15 from the end), apply the
    row reference correction, difference the two frames, clip symmetrically
    at the 99th percentile, rescale to [-1, 1], then cut the image into a
    NUM_PATCHES x NUM_PATCHES grid and save each patch (resized to 512x512)
    as a dict under save_dir, named "{idx}_{patch}.pt".

    Parameters:  args, tuple (fpath, label, patch_ids, idx)
                   fpath      path of the HDF5 file holding a "Frames" dataset
                   label      file-level label; not used here, the per-patch
                              label comes from patch_ids
                   patch_ids  zero-based ids of the patches labelled positive
                   idx        index of the file, used in the output file names
    Returns: nothing. Files that cannot be read, or whose clip range is
             degenerate, are reported with a printed message and skipped.
    """
    fpath, _label, patch_ids, idx = args
    try:
        with h5py.File(fpath, 'r') as f:
            frames = f["Frames"][[1, len(f["Frames"]) - 15], :, :4096].astype(np.float32)
    except Exception as e:
        print(f"Error with {fpath}: {e}")
        return

    rowcor(frames)
    cds = frames[1] - frames[0]
    vmax = np.percentile(cds, 99)
    # A zero, negative or NaN percentile gives an empty/inverted clip range
    # and a division by zero below, which would silently save garbage patches.
    if not np.isfinite(vmax) or vmax <= 0:
        print(f"Error with {fpath}: degenerate clip range (99th percentile = {vmax})")
        return
    vmin = -vmax
    cds = np.clip(cds, vmin, vmax)
    cds = 2 * (cds - vmin) / (vmax - vmin) - 1

    positive_ids = set(patch_ids)
    for i in range(NUM_PATCHES * NUM_PATCHES):
        # Patches are numbered row-major across the grid
        row, col = divmod(i, NUM_PATCHES)
        patch = cds[row*PATCH_SIZE:(row+1)*PATCH_SIZE, col*PATCH_SIZE:(col+1)*PATCH_SIZE]
        patch_tensor = torch.from_numpy(patch).unsqueeze(0)  # Add a leading channel axis
        patch_tensor = TF.resize(patch_tensor, [512, 512])   # Always resized to 512x512

        meta = {
            'data': patch_tensor,
            'label': torch.tensor([1.0]) if i in positive_ids else torch.tensor([0.0]),
            'filename': os.path.basename(fpath),
            'patch_id': i
        }

        torch.save(meta, os.path.join(save_dir, f"{idx}_{i}.pt"))

# Fan the per-file preprocessing out over a pool of at most 8 worker processes.
# Each work item is (file path, file label, positive patch ids, file index).
args_list = [
    (fpath, file_label, patch_ids, file_idx)
    for file_idx, (fpath, file_label, patch_ids) in enumerate(zip(file_paths, labels, positive_patches))
]
n_workers = min(cpu_count(), 8)
with Pool(n_workers) as pool:
    pool.map(preprocess_and_save_patch, args_list)

In [None]:
# Bundle the individual patch files into chunk files of up to `chunk_size`
# samples each, so later cells can load a few large files.
# NOTE(review): this reads from .../PPProcessedPatchesGabor, whereas
# preprocess_and_save_patch writes to .../ProcessedPatches -- confirm that
# this is the intended input directory.
paths = sorted(glob.glob('/explore/nobackup/people/cemeehan/PPProcessedPatchesGabor/*.pt'))
chunk_size = 5000
chunk = []
chunk_idx = 0

for i, path in enumerate(paths):
    # NOTE(review): torch.load unpickles the file; only use on trusted files.
    loaded = torch.load(path)
    if isinstance(loaded, dict):
        # Files written as dicts (as preprocess_and_save_patch does) must be
        # read by key: tuple-unpacking a dict would yield its key names.
        img = loaded['data']
        label = loaded['label']
        fname = loaded['filename']
        patch_id = loaded['patch_id']
    else:
        # Otherwise expect a 4-item sequence (data, label, filename, patch_id)
        img, label, fname, patch_id = loaded
    sample = {
        'data': img,
        'label': label,
        'filename': fname,
        'patch_id': patch_id
    }
    chunk.append(sample)

    # Flush when the buffer is full or the last file has been read
    if (i + 1) % chunk_size == 0 or (i + 1) == len(paths):
        # Save this chunk
        torch.save(chunk, f'/explore/nobackup/people/cemeehan/chunk_{chunk_idx}.pt')
        print(f"[INFO] Saved chunk {chunk_idx} with {len(chunk)} samples")
        chunk = []  # reset buffer
        chunk_idx += 1