In [2]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
import glob
import math



def plot_profiles(profiles, max_val=None, min_val=None):
    """Render a 2-D array as a blue-to-red heat map image.

    Values are clipped to ``[min_val, max_val]`` and mapped linearly onto the
    OpenCV hue channel: ``min_val`` -> hue 120 (blue), ``max_val`` -> hue 0
    (red). Saturation and value are fixed at 255.

    Args:
        profiles: 2-D array-like of numeric values (rows x columns).
        max_val: Upper colour limit. ``None`` uses ``np.max(profiles)``.
        min_val: Lower colour limit. ``None`` uses ``np.min(profiles)``.

    Returns:
        ``(rows, columns, 3)`` ``uint8`` image in BGR order.
    """
    max_h = 0       # hue used for max_val (red)
    min_h = 120     # hue used for min_val (blue)
    # Compare against None so an explicit limit of 0 is honoured instead of
    # being silently replaced by the data's own min/max.
    if max_val is None:
        max_val = np.max(profiles)
    if min_val is None:
        min_val = np.min(profiles)

    heat_map_val = np.clip(profiles, min_val, max_val).astype(np.float64)

    # Normalise relative to min_val (not 0) so every value lands inside the
    # [max_h, min_h] hue band. Without this, values below 0 produced hues
    # above 120, past OpenCV's 8-bit hue range [0, 180), and wrapped around.
    span = float(max_val) - float(min_val)
    if span > 0:
        frac = (heat_map_val - min_val) / span
    else:
        # Degenerate range: everything sits at min_val -> min_h.
        frac = np.zeros_like(heat_map_val)

    heat_map = np.empty(heat_map_val.shape + (3,), dtype=np.uint8)
    heat_map[:, :, 0] = (frac * (max_h - min_h) + min_h).astype(np.uint8)
    heat_map[:, :, 1] = 255
    heat_map[:, :, 2] = 255
    return cv2.cvtColor(heat_map, cv2.COLOR_HSV2BGR)


def plot_profiles_split_channels(profiles, n_channels, maxval=None, minval=None, gap=5):
    """Render consecutive row bands of ``profiles`` as stacked heat maps.

    Axis 0 of ``profiles`` is split into ``n_channels`` bands of
    ``profiles.shape[0] // n_channels`` rows each; any remainder rows at the
    end are dropped. Each band is rendered with ``plot_profiles`` and is
    followed by ``gap`` all-zero (black) separator rows, including after the
    last band.

    Args:
        profiles: 2-D array whose axis 0 is split into bands.
        n_channels: Number of bands.
        maxval: Upper colour limit forwarded to ``plot_profiles``
            (``None`` lets each band use its own maximum).
        minval: Lower colour limit forwarded to ``plot_profiles``
            (``None`` lets each band use its own minimum).
        gap: Number of separator rows after each band (default 5).

    Returns:
        float64 array of shape
        ``((channel_width + gap) * n_channels, profiles.shape[1], 3)``
        holding BGR values in [0, 255].
    """
    channel_width = profiles.shape[0] // n_channels
    band_height = channel_width + gap

    profiles_img = np.zeros((band_height * n_channels, profiles.shape[1], 3))

    for n in range(n_channels):
        channel_profiles = profiles[n * channel_width: (n + 1) * channel_width]
        top = n * band_height
        # Rows [top + channel_width, top + band_height) stay zero as the separator.
        profiles_img[top: top + channel_width] = plot_profiles(channel_profiles, maxval, minval)

    return profiles_img

def vis(input, save_path='./fake_img.png'):
    """Draw ``input`` as a heat map on the current matplotlib axes and save the figure.

    The image comes from ``vis_out``: ``input`` is transposed, rendered as a
    single band with fixed colour limits of +/-50,000,000 and converted to RGB.
    No new figure is created, so this draws onto whatever axes are current.

    Args:
        input: 2-D array to visualise.
        save_path: File the current figure is written to
            (default ``'./fake_img.png'``, the previously hard-coded path).
    """
    plt.imshow(vis_out(input), aspect = 'auto')
    plt.savefig(save_path)

def vis_save(input, save_path='./fake_img.png'):
    """Draw ``input`` as a heat map on the current axes and save the figure.

    Same behaviour as ``vis``. The image is built by ``vis_out``, so the
    render pipeline (transpose, +/-50,000,000 limits, BGR->RGB) lives in
    one place instead of being copy-pasted.

    Args:
        input: 2-D array to visualise.
        save_path: File the current figure is written to
            (default ``'./fake_img.png'``, the previously hard-coded path).
    """
    plt.imshow(vis_out(input), aspect = 'auto')
    plt.savefig(save_path)

def vis_out(input):
    """Return ``input`` rendered as a heat map: an RGB image with ``uint16`` dtype.

    ``input`` is copied and transposed, then rendered as a single band by
    ``plot_profiles_split_channels`` with fixed colour limits of
    +/-50,000,000 (that helper appends black separator rows below the band).
    The BGR result is converted to RGB.
    """
    transposed = input.copy().T
    bgr = np.float32(
        plot_profiles_split_channels(transposed, 1, 50000000, -50000000))
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return rgb.astype(np.uint16)

def vis_out_one_channel(input):
    """Return ``input`` (not transposed) as an RGB heat map with ``uint16`` dtype.

    Rendering goes straight through ``plot_profiles`` with fixed colour
    limits of +/-50,000,000 (no band splitting, no separator rows).
    """
    heat_bgr = plot_profiles(input.copy(), 50000000, -50000000)
    return cv2.cvtColor(np.float32(heat_bgr), cv2.COLOR_BGR2RGB).astype(np.uint16)


In [None]:
# Side-by-side comparison of acoustic diff profiles for each letter across three
# recording sessions (left / middle / right column), one subplot row per channel.
left_data_dir = '/data/asl_test/dataset/session_0201/acoustic/diff/'
middle_data_dir = '/data/asl_test/dataset/session_0301/acoustic/diff/'
right_data_dir = '/data/asl_test/dataset/session_1201/acoustic/diff/'

# glob returns files in arbitrary (filesystem) order; sort so the i-th sample
# of each session is paired reproducibly between runs.
left_npy_files = sorted(glob.glob(os.path.join(left_data_dir, '*.npy')))
middle_npy_files = sorted(glob.glob(os.path.join(middle_data_dir, '*.npy')))
right_npy_files = sorted(glob.glob(os.path.join(right_data_dir, '*.npy')))

N_CHANNELS = 4           # channels indexed along axis 0 of each .npy array
MAX_SAMPLES_SHOWN = 3    # samples (N_CHANNELS rows each) drawn per letter
VMAX_PERCENTILE = 99.97  # upper colour limit; trims rare extreme values
SUBPLOT_WIDTH = 3        # inches per column (compact layout)


def _session_label(data_dir):
    """Return the id after 'session_' in ``data_dir`` (e.g. '0201'), else its last path component."""
    parts = os.path.normpath(data_dir).split(os.sep)
    for part in parts:
        if part.startswith('session_'):
            return part[len('session_'):]
    return parts[-1]


# (label, files) per column, left to right. Labels are derived from the
# directory names so titles and printouts always match the data plotted
# (the right column is session 1201, previously mislabelled "0101").
sessions = [
    (_session_label(left_data_dir), left_npy_files),
    (_session_label(middle_data_dir), middle_npy_files),
    (_session_label(right_data_dir), right_npy_files),
]

# ZHJMNOA
letters = 'Z'

for letter in letters:
    suffix = f'_{letter}.npy'
    files_by_session = [
        (label, [f for f in files if f.endswith(suffix)]) for label, files in sessions
    ]

    # Only compare as many samples as the smallest session provides.
    n_files = min(len(files) for _, files in files_by_session)
    if n_files == 0:
        print(f"No files found for letter {letter} in one or more sessions")
        for label, files in files_by_session:
            print(f"Session {label}: {len(files)} files")
        continue

    # Load each file exactly once; the arrays feed both the colour limits and the plots.
    session_samples = [
        (label, files[:n_files], [np.load(f) for f in files[:n_files]])
        for label, files in files_by_session
    ]

    # Per-channel colour limits shared by every session and sample, so the
    # same colour means the same value in every subplot of a channel.
    vmin_by_channel = []
    vmax_by_channel = []
    for channel in range(N_CHANNELS):
        channel_flat = np.concatenate(
            [arr[channel].ravel() for _, _, arrays in session_samples for arr in arrays])
        vmin_by_channel.append(np.min(channel_flat))
        vmax_by_channel.append(np.percentile(channel_flat, VMAX_PERCENTILE))

    # Subplot aspect follows the (h, w) of one channel of the first sample.
    h, w = session_samples[0][2][0][0].shape
    subplot_height = SUBPLOT_WIDTH * h / w

    samples_to_show = min(MAX_SAMPLES_SHOWN, n_files)
    n_rows = samples_to_show * N_CHANNELS
    n_cols = len(session_samples)

    # squeeze=False keeps `axes` 2-D regardless of the row/column counts.
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(SUBPLOT_WIDTH * n_cols, subplot_height * n_rows),
        squeeze=False)

    for sample_idx in range(samples_to_show):
        for channel in range(N_CHANNELS):
            row = sample_idx * N_CHANNELS + channel
            for col, (label, files, arrays) in enumerate(session_samples):
                img_bgr = plot_profiles(arrays[sample_idx][channel],
                                        max_val=vmax_by_channel[channel],
                                        min_val=vmin_by_channel[channel])
                ax = axes[row, col]
                ax.imshow(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB), aspect='auto')
                ax.set_title(f'{label} Sample {sample_idx+1} Ch{channel}\n{os.path.basename(files[sample_idx])}', fontsize=8)

            # Add labels
            if channel == 0:
                axes[row, 0].set_ylabel(f'Sample {sample_idx+1}')
            axes[row, -1].set_ylabel(f'Channel {channel}')

    plt.tight_layout()
    plt.show()

    print(f"Displayed {samples_to_show} samples with all {N_CHANNELS} channels for letter {letter}")
    for label, files, _ in session_samples:
        print(f"Session {label}: {len(files)} total files")