In [100]:
import numpy as np
import pywt
import matplotlib.pyplot as plt

def read_wavelet3d_result(filename):
    """Read the eight 3-D wavelet sub-bands stored sequentially in a binary file.

    Sub-bands are read in the fixed order
    ``aaa, aad, ada, add, daa, dad, dda, ddd``.  Each record is three
    ``uint64`` values ``(depth, rows, cols)`` followed by
    ``depth * rows * cols`` ``float32`` values, both read with numpy's
    plain dtypes (i.e. native byte order).

    Parameters
    ----------
    filename : str or path-like
        Path to the binary result file.

    Returns
    -------
    dict[str, numpy.ndarray]
        Sub-band key -> ``float32`` array of shape ``(depth, rows, cols)``,
        in the order listed above.

    Raises
    ------
    ValueError
        If the file ends before a sub-band's header or data is complete.
    """
    subband_keys = ['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd']
    subbands = {}

    with open(filename, 'rb') as file:
        for key in subband_keys:
            # Header: the sub-band's dimensions as three uint64 values.
            header = np.fromfile(file, dtype=np.uint64, count=3)
            if header.size != 3:
                raise ValueError(
                    f"{filename}: truncated header for sub-band {key!r} "
                    f"(read {header.size} of 3 dimension values)"
                )
            shape = tuple(int(dim) for dim in header)

            # np.fromfile returns fewer items at EOF instead of raising, so
            # check the count explicitly rather than relying on reshape.
            expected = shape[0] * shape[1] * shape[2]
            data = np.fromfile(file, dtype=np.float32, count=expected)
            if data.size != expected:
                raise ValueError(
                    f"{filename}: truncated data for sub-band {key!r} "
                    f"(read {data.size} of {expected} float32 values)"
                )

            subbands[key] = data.reshape(shape)

    return subbands

def perform_3d_dwt(input_vol, wavelet, levels):
    """Decompose a 3-D volume with PyWavelets and return the coarsest sub-bands.

    Parameters
    ----------
    input_vol : array_like
        The 3-D volume to transform.
    wavelet : str or pywt.Wavelet
        Wavelet passed straight to ``pywt.wavedecn`` (e.g. ``'db2'``).
    levels : int or None
        Passed as ``level`` to ``pywt.wavedecn``; ``None`` lets PyWavelets
        pick the maximum level for the input size.

    Returns
    -------
    dict[str, numpy.ndarray]
        Keys ``aaa, aad, ada, add, daa, dad, dda, ddd`` in that order.
        ``aaa`` is the approximation (``coeffs[0]``); the other seven are the
        detail coefficients of the coarsest level (``coeffs[1]``).  With
        ``levels=1`` this is the complete single-level decomposition.

    Raises
    ------
    ValueError
        If ``input_vol`` is not 3-D, or the decomposition produced no detail
        level (e.g. ``levels=0``, or an input too small for ``wavelet``).
    """
    ndim = np.ndim(input_vol)
    if ndim != 3:
        # Previously a non-3-D input silently returned only 'aaa'.
        raise ValueError(f"perform_3d_dwt expects a 3-D volume, got a {ndim}-D input")

    coeffs3 = pywt.wavedecn(input_vol, wavelet=wavelet, level=levels, mode='periodization')
    if len(coeffs3) < 2:
        raise ValueError(
            f"pywt.wavedecn produced no detail coefficients (levels={levels!r}); "
            "use levels >= 1 with an input large enough for the wavelet"
        )

    approx, coarsest = coeffs3[0], coeffs3[1]
    subbands = {'aaa': approx}
    # For 3-D input the detail dict contains exactly these seven keys, so index
    # directly instead of skipping missing ones.
    # NOTE(review): per the PyWavelets docs each key character names the filter
    # applied along axes 0, 1, 2 in order; confirm the program that wrote the
    # file uses the same axis-to-character convention before comparing.
    for key in ('aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'):
        subbands[key] = coarsest[key]

    return subbands

def plot_subbands(subbands, slice_index=0):
    """Show one 2-D slice of every sub-band as a grid of grayscale images.

    Parameters
    ----------
    subbands : dict[str, array_like]
        Sub-band name -> array.  Each array is displayed as
        ``data[slice_index]``, so it must be at least 3-D (e.g. the volumes
        returned by ``read_wavelet3d_result`` / ``perform_3d_dwt``).
    slice_index : int, optional
        Index along axis 0 to display for every sub-band.  Defaults to 0,
        the first slice.

    The grid has four columns and as many rows as needed.  With eight
    sub-bands this is the same 2x4 layout at figsize (16, 8) as before.
    """
    ncols = 4
    nrows = max(1, -(-len(subbands) // ncols))  # ceiling division, at least one row
    # squeeze=False keeps `axes` 2-D even when there is a single row.
    _, axes = plt.subplots(nrows, ncols, figsize=(16, 4 * nrows), squeeze=False)

    # axes.flat walks the grid row-major, matching the old axes[i // 4, i % 4].
    for ax, (key, data) in zip(axes.flat, subbands.items()):
        ax.imshow(data[slice_index], cmap='gray')
        ax.set_title(key)

    # Hide frames/ticks on every cell, including unused trailing ones.
    for ax in axes.flat:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

def compare_subbands(subbands_file, subbands_dwt, atol=1e-6, rtol=1e-5, max_report=10):
    """Compare sub-bands read from file against the PyWavelets result, key by key.

    For every key in *subbands_file*, prints whether it is missing from
    *subbands_dwt*, has a different shape, matches, or does not match.  For
    non-matching sub-bands it prints the mismatch count, the maximum absolute
    difference and up to *max_report* individual mismatching elements.

    Parameters
    ----------
    subbands_file, subbands_dwt : dict[str, numpy.ndarray]
        Sub-band name -> array.
    atol, rtol : float, optional
        Tolerances for ``np.isclose(file, dwt)``.  The defaults equal the
        previous hard-coded ``atol=1e-6`` and numpy's default ``rtol``.  Note
        that ``rtol`` is scaled by the DWT value.
    max_report : int or None, optional
        Maximum number of individual mismatches printed per sub-band.
        ``None`` prints all of them, which can be millions of lines for
        full volumes.

    Returns
    -------
    dict[str, bool]
        For each key of *subbands_file*, ``True`` if it exists in
        *subbands_dwt* with the same shape and all elements are close.
    """
    results = {}
    for key, data_file in subbands_file.items():
        if key not in subbands_dwt:
            print(f"Sub-band {key} not found in DWT result.")
            results[key] = False
            continue

        data_dwt = subbands_dwt[key]
        if data_file.shape != data_dwt.shape:
            print(f"Shape mismatch for sub-band {key}: file shape {data_file.shape}, DWT shape {data_dwt.shape}")
            results[key] = False
            continue

        comparison = np.isclose(data_file, data_dwt, rtol=rtol, atol=atol)
        if comparison.all():
            print(f"Sub-band {key} matches between file and DWT.")
            results[key] = True
            continue

        print(f"Sub-band {key} does not match between file and DWT.")
        mismatches = np.argwhere(~comparison)
        n_bad = len(mismatches)
        # Compute in float64 so float32 inputs do not lose precision in the diff.
        max_diff = np.max(np.abs(np.subtract(data_file, data_dwt, dtype=np.float64)))
        print(f"  {n_bad} of {comparison.size} elements differ; max |file - DWT| = {max_diff}")

        shown = mismatches if max_report is None else mismatches[:max_report]
        for row in shown:
            idx = tuple(int(i) for i in row)  # plain ints print as (0, 1, 2)
            print(f"Mismatch at {idx}: file value {data_file[idx]}, DWT value {data_dwt[idx]}")
        if len(shown) < n_bad:
            print(f"  ... {n_bad - len(shown)} more mismatches not shown")

        results[key] = False

    return results

# Example usage: spot-check the serial program's 3-D DWT output against PyWavelets.

# Voxel sampled from every sub-band for a quick side-by-side comparison.
PROBE_INDEX = (20, 100, 100)


def print_probe_values(subbands, source, index=PROBE_INDEX):
    """Print ``data[index]`` for every sub-band, tagged with *source*.

    Sub-bands too small to contain *index* get a "does not have a value"
    line instead of raising.
    """
    for key, data in subbands.items():
        try:
            value = data[index]
        except IndexError:
            print(f"{key} does not have a value at {index}")
        else:
            print(f"{key} at {index} from {source}: {value}")


filename = '../../Final/serial/outputs/3out.bin'
subbands_file = read_wavelet3d_result(filename)
print_probe_values(subbands_file, "file")

# Read the original input volume.  The shape file holds comma-separated ints.
original_filename = '../../Final/data/3/3.bin'
shape_filename = '../../Final/data/3/3_shape.txt'
with open(shape_filename, 'r', encoding='utf-8') as f:
    shape = tuple(map(int, f.read().strip().split(',')))
input_vol = np.fromfile(original_filename, dtype=np.float32).reshape(shape)

# Perform the 3D DWT; adjust the wavelet and level to match the file's producer.
wavelet = 'db2'
levels = 1
subbands_dwt = perform_3d_dwt(input_vol, wavelet, levels)

# Optional diagnostics; uncomment as needed (plots are interactive).
# plot_subbands(subbands_dwt)
# plot_subbands(subbands_file)
# compare_subbands(subbands_file, subbands_dwt)

# NOTE(review): the recorded output of this cell shows the file and PyWavelets
# disagree even for 'aaa'; check the wavelet, boundary mode ('periodization'
# in perform_3d_dwt) and axis/key ordering on the producer side.
print_probe_values(subbands_dwt, "pywavlets")

aaa at (20, 100, 100) from file: 3029.12060546875
aad at (20, 100, 100) from file: -1.9417352676391602
ada at (20, 100, 100) from file: 1.5019131898880005
add at (20, 100, 100) from file: -5.563841819763184
daa at (20, 100, 100) from file: 203.12469482421875
dad at (20, 100, 100) from file: -41.26618957519531
dda at (20, 100, 100) from file: 37.67226028442383
ddd at (20, 100, 100) from file: -0.8715343475341797
aaa at (20, 100, 100) from pywavlets: 2869.1455078125
aad at (20, 100, 100) from pywavlets: -4.988800048828125
ada at (20, 100, 100) from pywavlets: -20.25414276123047
add at (20, 100, 100) from pywavlets: -12.716032028198242
daa at (20, 100, 100) from pywavlets: 116.29927825927734
dad at (20, 100, 100) from pywavlets: -24.41055679321289
dda at (20, 100, 100) from pywavlets: 4.876809597015381
ddd at (20, 100, 100) from pywavlets: -29.195022583007812
