# Load .mat file

In [6]:
import h5py
import pandas as pd
import numpy as np

def mat_to_dataframes(file_path):
    """
    Extract data from MATLAB v7.3 file and convert suitable structures to DataFrames
    """
    dataframes = {}  # Dictionary to store all extracted dataframes
    
    try:
        with h5py.File(file_path, 'r') as file:
            # Print file structure first to understand what we're working with
            print("File structure:")
            file.visititems(lambda name, obj: print(f"{'Dataset' if isinstance(obj, h5py.Dataset) else 'Group'}: {name}, "
                                                   f"{'Shape: ' + str(obj.shape) + ', Type: ' + str(obj.dtype) if isinstance(obj, h5py.Dataset) else ''}"))
            
            # Process each top-level element
            print("\nExtracting data to DataFrames:")
            for key in file.keys():
                # Check if it's a dataset directly
                if isinstance(file[key], h5py.Dataset):
                    df = dataset_to_dataframe(file[key], key)
                    if df is not None:
                        dataframes[key] = df
                # If it's a group, try to extract structured data
                else:
                    group_dfs = extract_group_dataframes(file[key], parent_name=key)
                    dataframes.update(group_dfs)
                    
            return dataframes
    
    except Exception as e:
        print(f"Error processing file: {e}")
        return None

def dataset_to_dataframe(dataset, name):
    """Convert an h5py dataset to a pandas DataFrame if possible"""
    try:
        data = dataset[()]
        
        # Handle string data
        if data.dtype.kind in ('S', 'O', 'U'):
            try:
                # Try to convert character arrays to strings
                if len(data.shape) == 2 and min(data.shape) == 1:
                    data = ''.join(chr(c) for c in data.flat)
                    print(f"  Extracted string from {name}: {data}")
                    return None  # Single strings don't need a DataFrame
            except:
                pass  # If conversion fails, continue with normal processing
        
        # For 1D or 2D numeric arrays
        if len(data.shape) <= 2:
            # Convert to DataFrame (handles both 1D and 2D arrays)
            df = pd.DataFrame(data)
            print(f"  Created DataFrame from {name}: {df.shape}")
            return df
            
        else:
            print(f"  Skipping {name}: {len(data.shape)}-dimensional data")
            return None
            
    except Exception as e:
        print(f"  Error converting {name} to DataFrame: {e}")
        return None

def extract_group_dataframes(group, parent_name=""):
    """
    Recursively extract DataFrames from group structure
    Handles common MATLAB structure patterns in HDF5
    """
    dataframes = {}
    
    # For each item in the group
    for key in group.keys():
        full_key = f"{parent_name}/{key}"
        
        # If it's a dataset, try to convert to DataFrame
        if isinstance(group[key], h5py.Dataset):
            df = dataset_to_dataframe(group[key], full_key)
            if df is not None:
                dataframes[full_key] = df
        
        # If it's a group, process recursively
        else:
            sub_dataframes = extract_group_dataframes(group[key], full_key)
            dataframes.update(sub_dataframes)
    
    # Handle special case: MATLAB structure arrays
    # These typically have fields like 'field1', 'field2' that should be combined
    if all(k.startswith(parent_name + "/") for k in dataframes.keys()):
        # Check if we can combine fields into a single DataFrame
        try:
            # Get all direct children fields
            fields = [k.split('/')[-1] for k in dataframes.keys() 
                     if len(k.split('/')) == len(parent_name.split('/')) + 2]
            
            if fields:
                print(f"  Detected potential structure in {parent_name} with fields: {fields}")
                
                # Future enhancement: combine structure fields into a single DataFrame
                # This would require additional logic to align the data properly
        except:
            pass
    
    return dataframes

# Main execution
file_path = "/home/work/OCT_DL/CDAC_OCT/CDAC_PYTHON/data/2_3Dregistration/d5_Int_05_CAO.mat"
all_dataframes = mat_to_dataframes(file_path)

# Display summary of extracted DataFrames
if all_dataframes:
    print("\nExtracted DataFrames Summary:")
    for name, df in all_dataframes.items():
        print(f"- {name}: Shape {df.shape}")
    
    # Show the first DataFrame as an example
    if len(all_dataframes) > 0:
        first_key = list(all_dataframes.keys())[0]
        print(f"\nFirst 5 rows of '{first_key}':")
        print(all_dataframes[first_key].head())
else:
    print("No DataFrames were extracted.")

File structure:
Dataset: CAOfilter, Shape: (625, 625), Type: [('real', '<f8'), ('imag', '<f8')]

Extracting data to DataFrames:
  Error converting CAOfilter to DataFrame: Data must be 1-dimensional, got ndarray of shape (625, 625) instead
No DataFrames were extracted.


# Convert .mat into .pt

In [14]:
# /scripts/convert_mat_v73_to_pt.py
from __future__ import annotations
import h5py
import numpy as np
import torch
from pathlib import Path
from typing import Dict, Any

def _is_complex_struct(arr: np.ndarray) -> bool:
    return hasattr(arr.dtype, "names") and arr.dtype.names is not None and {"real", "imag"} <= set(arr.dtype.names)

def _to_numpy_numeric(ds: h5py.Dataset) -> np.ndarray | None:
    """h5py Dataset -> numpy numeric (float/complex). 비수치형은 None."""
    arr = ds[()]  # Load to numpy
    # compound dtype (real/imag) → complex
    if isinstance(arr, np.ndarray) and _is_complex_struct(arr):
        arr = arr["real"] + 1j * arr["imag"]
    # 문자열/객체 등 비수치형 스킵
    if not isinstance(arr, np.ndarray):
        return None
    if arr.dtype.kind in ("b", "i", "u", "f", "c"):  # bool/int/uint/float/complex
        return arr
    return None

def _walk_collect_numeric(f: h5py.File) -> Dict[str, np.ndarray]:
    """파일 전체를 순회하며 수치형 dataset만 수집 (group 경로를 키로 사용)."""
    out: Dict[str, np.ndarray] = {}
    def visit(name: str, obj: Any):
        if isinstance(obj, h5py.Dataset):
            np_arr = _to_numpy_numeric(obj)
            if np_arr is not None:
                out[name] = np_arr
    f.visititems(visit)
    return out

def convert_mat_v73_to_pt(mat_path: str | Path, out_pt_path: str | Path, keys: list[str] | None = None) -> dict:
    """v7.3 .mat → .pt(dict). keys가 None이면 모든 수치형 dataset 저장."""
    mat_path = str(mat_path)
    tensors: Dict[str, torch.Tensor] = {}
    with h5py.File(mat_path, "r") as f:
        if keys is None:
            np_dict = _walk_collect_numeric(f)
        else:
            np_dict = {}
            for k in keys:
                if k in f and isinstance(f[k], h5py.Dataset):
                    arr = _to_numpy_numeric(f[k])
                    if arr is not None:
                        np_dict[k] = arr

    for name, arr in np_dict.items():
        # numpy → torch (복소 가능)
        t = torch.from_numpy(arr)  # dtype 유지 (float64→torch.float64, complex128→torch.complex128)
        tensors[name] = t

    torch.save(tensors, str(out_pt_path))
    return tensors

if __name__ == "__main__":
    src = "/home/work/OCT_DL/CDAC_OCT/CDAC_PYTHON/data/2_3Dregistration/d5_Int_08_subLayers_0704.mat"
    dst = "/home/work/OCT_DL/CDAC_OCT/CDAC_PYTHON/data/2_3Dregistration/d5_Int_08_subLayers_0704.pt"
    tensors = convert_mat_v73_to_pt(src, dst)  # 모든 수치형 dataset 저장
    print(f"Saved {len(tensors)} tensors to {dst}")
    for k, v in tensors.items():
        print(f"- {k}: shape={tuple(v.shape)}, dtype={v.dtype}")

Saved 4 tensors to /home/work/OCT_DL/CDAC_OCT/CDAC_PYTHON/data/2_3Dregistration/d5_Int_08_subLayers_0704.pt
- ILMnew: shape=(625, 625), dtype=torch.float64
- ISOSnew: shape=(625, 625), dtype=torch.float64
- NFLnew: shape=(625, 625), dtype=torch.float64
- RPEnew: shape=(625, 625), dtype=torch.float64


In [None]:
# ==== 임의 z 단면: CAO 전/후(+volume) 시각화 유틸 (디바이스 정합 버전) ========
import numpy as np
import matplotlib.pyplot as plt
import torch
from utils.utils import fft2c, _center_embed, load_layers,ifft2c, _center_crop


def _amp_phase_np(t: torch.Tensor):
    x = t.detach().cpu().numpy()
    return np.abs(x), np.angle(x)

def _soft_pupil(H: int, W: int, device, r0=0.98, feather=0.02):
    yy = torch.linspace(-1, 1, H, device=device)
    xx = torch.linspace(-1, 1, W, device=device)
    Yg, Xg = torch.meshgrid(yy, xx, indexing='ij')
    r = torch.sqrt(Xg*Xg + Yg*Yg)
    return torch.clamp((1 - (r - r0)/feather), min=0.0, max=1.0)

@torch.no_grad()
def visualize_cao_slice(z_idx: int):
    assert 'img_vol' in globals(), "img_vol (Z,X,Y) complex tensor가 필요합니다."
    Z, X, Y = img_vol.shape
    dev = img_vol.device
    z = int(np.clip(z_idx, 0, Z-1))

    # ---- 3D 임베드/크롭 헬퍼 ----
    def _center_embed_hwz(x: torch.Tensor, out_hw):
        X0, Y0, Zs = x.shape
        H, W = out_hw
        rs, cs = H//2 - X0//2, W//2 - Y0//2
        out = torch.zeros((H, W, Zs), dtype=x.dtype, device=x.device)
        out[rs:rs+X0, cs:cs+Y0, :] = x
        return out, (rs, cs, X0, Y0)

    def _center_crop_hwz(x: torch.Tensor, rs_cs_hw):
        rs, cs, X0, Y0 = rs_cs_hw
        return x[rs:rs+X0, cs:cs+Y0, :]

    # ---- phase/filter 준비 (자가복구) ----
    phase_fd_local = None
    if 'phase_fd' in globals():
        phase_fd_local = globals()['phase_fd']
    elif ('A_AdamW' in globals()) and ('Zfd' in globals()):
        phase_fd_local = torch.tensordot(Zfd, A_AdamW, dims=([2],[0]))  # (Hfd,Wfd)

    CAOfilter_local = globals().get('CAOfilter', None)

    if (phase_fd_local is None) and (CAOfilter_img_local is None):
        raise RuntimeError("phase_fd / (A_AdamW+Zfd) / CAOfilter 중 하나가 필요합니다.")

    # 크기 파라미터
    Him_local = int(globals().get('Him', X))
    Wim_local = int(globals().get('Wim', Y))
    if phase_fd_local is not None:
        Hfd_local, Wfd_local = phase_fd_local.shape
    else:
        Hfd_local = int(globals().get('Hfd', Him_local))
        Wfd_local = int(globals().get('Wfd', Wim_local))

    # ---------- Before ----------
    sl_before = img_vol[z]  # (X,Y) complex

    # ---------- After (slice-only) ----------
    if phase_fd_local is not None:
        sl_embed, rs_cs_hw_local = _center_embed(sl_before, (Hfd_local, Wfd_local))
        Fz = fft2c(sl_embed)
        pupil_soft = _soft_pupil(Hfd_local, Wfd_local, sl_before.device)
        # ↓ 디바이스/자료형 통일 + ones_like 사용
        P_fd = torch.exp(-1j * phase_fd_local.to(device=sl_before.device, dtype=sl_before.dtype))
        ones_fd = torch.ones_like(P_fd)
        P_fd = (P_fd - ones_fd) * pupil_soft + ones_fd
        cz_full = ifft2c(Fz * P_fd)
        sl_after_slice = _center_crop(cz_full, (Him_local, Wim_local), rs_cs_hw_local)
    else:
        pupil_soft_img = _soft_pupil(Him_local, Wim_local, sl_before.device)
        Cimg = CAOfilter_local.to(device=sl_before.device, dtype=sl_before.dtype)
        ones_img = torch.ones_like(Cimg)
        P_img = (Cimg - ones_img) * pupil_soft_img + ones_img
        F_xy = torch.fft.fft2(sl_before)
        sl_after_slice = torch.fft.ifft2(F_xy * P_img)

    # ---------- After (volume-corrected) ----------
    if 'corrected_vol' in globals():
        sl_after_vol = corrected_vol[z]
        vol_title_suffix = " (from corrected_vol)"
    else:
        assert 'ISOS' in globals() and 'RPE' in globals(), "ISOS, RPE 전역이 필요합니다."
        PRLstart = int(np.floor(np.nanmin(ISOS))); PRLstart = max(PRLstart, 0)
        PRLend   = int(np.ceil (np.nanmax(RPE)));  PRLend   = min(PRLend, Z-1)
        slab = img_vol[PRLstart:PRLend+1]                 # (Zs,X,Y)
        slab_xy = slab.permute(1,2,0).contiguous()        # (X,Y,Zs)

        if phase_fd_local is None:
            sl_after_vol = sl_after_slice
            vol_title_suffix = " (slice-like; no phase_fd)"
        else:
            slab_zp_hwz, rs_cs_hw2 = _center_embed_hwz(slab_xy, (Hfd_local, Wfd_local))
            slab_zp_bhw = slab_zp_hwz.permute(2,0,1).contiguous()  # (Zs,Hfd,Wfd)
            F_zp = fft2c(slab_zp_bhw)

            pupil_soft = _soft_pupil(Hfd_local, Wfd_local, F_zp.device)
            P_fd = torch.exp(-1j * phase_fd_local.to(device=F_zp.device, dtype=F_zp.dtype))
            ones_fd = torch.ones_like(P_fd)
            P_fd = ((P_fd - ones_fd) * pupil_soft + ones_fd).unsqueeze(0)  # (1,Hfd,Wfd)
            F_corr = F_zp * P_fd

            slab_corr_bhw = ifft2c(F_corr)                         # (Zs,Hfd,Wfd)
            slab_corr_hwz = slab_corr_bhw.permute(1,2,0).contiguous()
            slab_corr_xy  = _center_crop_hwz(slab_corr_hwz, rs_cs_hw2)

            if PRLstart <= z <= PRLend:
                sl_after_vol = slab_corr_xy[..., z-PRLstart]       # (X,Y)
            else:
                sl_after_vol = sl_before
            vol_title_suffix = " (simulated full-slab)"

    # ---------- 수치/시각화 ----------
    def _H_shannon(t: torch.Tensor) -> float:
        I = (t.real**2 + t.imag**2).clamp_min(0)
        p = I / (I.sum() + 1e-12)
        return float((-(p * (p + 1e-12).log())).sum().item())

    amp_b = torch.abs(sl_before).cpu().numpy()
    amp_s = torch.abs(sl_after_slice).cpu().numpy()
    amp_v = torch.abs(sl_after_vol).cpu().numpy()

    vmin = np.percentile(np.hstack([amp_b.ravel(), amp_s.ravel(), amp_v.ravel()]), 1)
    vmax = np.percentile(np.hstack([amp_b.ravel(), amp_s.ravel(), amp_v.ravel()]), 99)

    H_before = _H_shannon(sl_before)
    H_slice  = _H_shannon(sl_after_slice)
    H_vol    = _H_shannon(sl_after_vol)

    Imax_b = float((sl_before.real**2 + sl_before.imag**2).max().item())
    Imax_s = float((sl_after_slice.real**2 + sl_after_slice.imag**2).max().item())
    Imax_v = float((sl_after_vol  .real**2 + sl_after_vol  .imag**2).max().item())

    plt.figure(figsize=(18, 6))
    plt.suptitle(
        f"CAO slice visualization @ z={z} | "
        f"H_b={H_before:.4f}, H_s={H_slice:.4f}, H_v={H_vol:.4f} | "
        f"peak gain s×{(Imax_s+1e-12)/(Imax_b+1e-12):.2f}, v×{(Imax_v+1e-12)/(Imax_b+1e-12):.2f}"
    )

    ax1 = plt.subplot(1, 3, 1)
    im1 = ax1.imshow(amp_b, cmap="gray", vmin=vmin, vmax=vmax)
    ax1.set_title("Amplitude (Before)")
    ax1.axis("off"); plt.colorbar(im1, ax=ax1, shrink=0.8)

    ax2 = plt.subplot(1, 3, 2)
    im2 = ax2.imshow(amp_s, cmap="gray", vmin=vmin, vmax=vmax)
    ax2.set_title("Amplitude (After CAO — slice-only)")
    ax2.axis("off"); plt.colorbar(im2, ax=ax2, shrink=0.8)

    ax3 = plt.subplot(1, 3, 3)
    im3 = ax3.imshow(amp_v, cmap="gray", vmin=vmin, vmax=vmax)
    ax3.set_title(f"Amplitude (After CAO — volume){vol_title_suffix}")
    ax3.axis("off"); plt.colorbar(im3, ax=ax3, shrink=0.8)

    plt.tight_layout(); plt.show()

In [15]:
# /scripts/bootstrap_visualize_cao_slice_from_fringes.py
from __future__ import annotations
# /scripts/bootstrap_visualize_cao_slice_from_fringes.py

import sys
from pathlib import Path
ROOT = Path("/home/work/OCT_DL/CDAC_OCT/CDAC_PYTHON")
sys.path.append(str(ROOT))

import torch
import numpy as np
from utils.utils import load_layers

# 입력 경로
CAO_PT       = ROOT / "data/2_3Dregistration/d5_Int_04_CAO.pt"          # (CAO 적용된 fringes 들어있음)
RAW_FRINGES  = ROOT / "cache/fringes.pt"                                 # 원본 fringes (Before)
LAYERS_PT    = ROOT / "data/2_3Dregistration/d5_Int_08_Layers_0704.pt"
SUBLAYERS_PT = ROOT / "data/2_3Dregistration/d5_Int_08_subLayers_0704.pt"

dev = "cuda" if torch.cuda.is_available() else "cpu"

def _pick_fringes_key(d: dict) -> str | None:
    # 3D 복소 텐서(Nz,Nx,Ny)인 키를 우선 탐색
    for k, v in d.items():
        if isinstance(v, torch.Tensor) and v.ndim == 3 and torch.is_complex(v):
            return k
    # fallback: 흔한 이름 힌트
    for k in ("fringes_cao", "fringes_after", "fringes", "F_after", "F"):
        if k in d and isinstance(d[k], torch.Tensor):
            return k
    return None

# --- Before: 원본 fringes → img_vol ---
if not RAW_FRINGES.exists():
    raise FileNotFoundError(f"Not found: {RAW_FRINGES}")
fr_before = torch.load(str(RAW_FRINGES), map_location="cpu")
if not torch.is_complex(fr_before) or fr_before.ndim != 3:
    raise TypeError("RAW fringes must be complex (Nz,Nx,Ny).")
globals()["img_vol"] = torch.fft.ifft(fr_before, dim=0).to(dev)  # (Z,X,Y) complex

# --- After: CAO 적용된 fringes → corrected_vol ---
d_cao = torch.load(str(CAO_PT), map_location="cpu")
k = _pick_fringes_key(d_cao)
if k is None:
    raise KeyError(
        f"{CAO_PT} 안에서 (Nz,Nx,Ny) 복소 텐서를 찾지 못했습니다. "
        "저장 시 CAO 적용된 fringes를 그대로 포함시켜 주세요."
    )
fr_after = d_cao[k]
if not torch.is_complex(fr_after) or fr_after.ndim != 3:
    raise TypeError(f"'{k}' must be complex (Nz,Nx,Ny).")
globals()["corrected_vol"] = torch.fft.ifft(fr_after, dim=0).to(dev)  # (Z,X,Y) complex

# 크기/레이어
Z, Him, Wim = globals()["img_vol"].shape
globals()["Him"], globals()["Wim"] = int(Him), int(Wim)

sublayers_arg = SUBLAYERS_PT if SUBLAYERS_PT.exists() else None
ILM, NFL, ISOS, RPE = load_layers(LAYERS_PT, sublayers_arg)
globals()["ISOS"] = np.asarray(ISOS, dtype=float)
globals()["RPE"]  = np.asarray(RPE,  dtype=float)

# 플래그: 프리코렉티드 모드(phase/CAOfilter 재적용 금지)
globals()["USE_PRECORRECTED"] = True

print(f"[ready] img_vol={tuple(globals()['img_vol'].shape)}, "
      f"corrected_vol={tuple(globals()['corrected_vol'].shape)}, "
      f"Him×Wim=({globals()['Him']},{globals()['Wim']})")

KeyError: '/home/work/OCT_DL/CDAC_OCT/CDAC_PYTHON/data/2_3Dregistration/d5_Int_04_CAO.pt 안에서 (Nz,Nx,Ny) 복소 텐서를 찾지 못했습니다. 저장 시 CAO 적용된 fringes를 그대로 포함시켜 주세요.'