In [6]:
import os

# Allow a second copy of the Intel OpenMP runtime to load (e.g. one bundled with
# torch and one with numpy/MKL). Set this BEFORE importing the numeric libraries
# so the flag is already in the environment if a runtime initialises at import
# time. Intel's own error message calls this an unsafe workaround that may crash
# or silently give wrong results; prefer fixing the environment when possible.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import torch
import matplotlib.pyplot as plt
import numpy as np

In [2]:
def _plot_binary_panel(fig, ax, data, title, samples_per_ms, vmax=None):
    """Draw one raster heatmap on `ax` with a millisecond x-axis and a colorbar.

    Columns of `data` are treated as time samples at `samples_per_ms` samples
    per ms; rows are labelled as neurons. Returns the AxesImage.
    """
    n_rows, n_cols = data.shape
    duration_ms = n_cols / samples_per_ms
    # `extent` maps columns straight onto a ms axis, so matplotlib places ticks
    # at exact ms values. Relabelling sample-index ticks with int(x / 30) (the old
    # approach) truncated, e.g. the tick at sample 100 read "3" instead of 3.33.
    # The y-extent keeps each row centred on its integer index, as imshow's
    # default (origin='upper') does.
    extent = [0, duration_ms, n_rows - 0.5, -0.5]
    im = ax.imshow(data, aspect='auto', interpolation='nearest', cmap='gray_r',
                   vmin=0, vmax=vmax, extent=extent)
    ax.set_title(title)
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Neuron')
    fig.colorbar(im, ax=ax)
    return im


def plot_comparison_binary(input_data, output_data, sample_idx, save_path=None,
                           samples_per_ms=30):
    """Plot side-by-side raster heatmaps of original vs. reconstructed data.

    Parameters
    ----------
    input_data : 2-D array
        Original data; rows are plotted as neurons, columns as time samples.
        Shown on a fixed [0, 1] grayscale.
    output_data : 2-D array
        Reconstruction. Only vmin=0 is fixed; matplotlib picks the upper limit.
    sample_idx : int
        Accepted for interface compatibility; not used in the figure.
    save_path : str, optional
        If given, the figure is written there as PDF (parent directory created
        if missing) and closed; otherwise it is shown.
    samples_per_ms : float, default 30
        Columns per millisecond used to build the time axis. The default matches
        the divisor previously hard-coded in the tick labels.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

    _plot_binary_panel(fig, ax1, input_data, 'Original Data', samples_per_ms, vmax=1)
    # Total-duration label just above the top-right corner of the left panel
    # (previously the literal '30ms', which was only right for 900 columns).
    duration_ms = input_data.shape[1] / samples_per_ms
    ax1.text(duration_ms, -1, f'{duration_ms:g}ms', ha='right')

    _plot_binary_panel(fig, ax2, output_data, 'Reconstructed Data', samples_per_ms)

    fig.tight_layout()

    if save_path:
        # savefig does not create missing directories.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, format='pdf', dpi=300)
        plt.close(fig)  # close this figure explicitly, not "whatever is current"
    else:
        plt.show()

In [17]:
# Binary samples: load each saved input/reconstruction pair and write a
# side-by-side comparison PDF.
BINARY_DIR = "model_outputs_binary"
VIS_DIR = "visualization_results"
N_SAMPLES = 5

# savefig does not create missing directories; without this every sample would
# fail and only be reported by the except branch below.
os.makedirs(VIS_DIR, exist_ok=True)

for i in range(1, N_SAMPLES + 1):
    input_path = os.path.join(BINARY_DIR, f"input_{i}.pt")
    output_path = os.path.join(BINARY_DIR, f"output_{i}.pt")

    try:
        # map_location="cpu" lets tensors saved from a GPU session load on a
        # CPU-only machine; .numpy() below also requires CPU tensors.
        input_tensor = torch.load(input_path, map_location="cpu", weights_only=True).squeeze(0).to(torch.float32)
        output_tensor = torch.load(output_path, map_location="cpu", weights_only=True).squeeze(0).to(torch.float32)

        print(f"Input shape: {input_tensor.shape}")
        print(f"Output shape: {output_tensor.shape}")

        # detach() in case a saved tensor still requires grad, which .numpy() rejects.
        input_data = input_tensor.detach().numpy()
        output_data = output_tensor.detach().numpy()

        plot_comparison_binary(
            input_data,
            output_data,
            sample_idx=i,
            save_path=os.path.join(VIS_DIR, f"comparison_binary_encodec_{i}.pdf"),
        )

    except Exception as e:
        # Best-effort batch: report the failure and continue with the next sample.
        print(f"Error processing sample {i}: {e}")


Input shape: torch.Size([75, 900])
Output shape: torch.Size([75, 900])
Input shape: torch.Size([75, 900])
Output shape: torch.Size([75, 900])
Input shape: torch.Size([75, 900])
Output shape: torch.Size([75, 900])
Input shape: torch.Size([75, 900])
Output shape: torch.Size([75, 900])
Input shape: torch.Size([75, 900])
Output shape: torch.Size([75, 900])


In [7]:
def _plot_binned_panel(fig, ax, data, title, bin_ms, vmin, vmax):
    """Draw one binned heatmap on `ax` with a ms x-axis and a colorbar.

    Each column of `data` is drawn `bin_ms` ms wide; rows are labelled as
    neurons. Returns the AxesImage.
    """
    # [left, right, bottom, top]: time grows to the right, row 0 at the top.
    # Built from this panel's own shape so each panel's axis matches its data.
    extent = [0, data.shape[1] * bin_ms, data.shape[0], 0]
    im = ax.imshow(data, aspect='auto', cmap='viridis', extent=extent,
                   vmin=vmin, vmax=vmax)
    fig.colorbar(im, ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Neuron')
    return im


def plot_comparison_binned(input_data, output_data, sample_idx, save_path=None,
                           bin_ms=10):
    """Plot side-by-side heatmaps of original vs. reconstructed binned data.

    Both panels share one colour scale, [0, max over both arrays], so they are
    directly comparable.

    Parameters
    ----------
    input_data, output_data : 2-D arrays
        Rows are plotted as neurons, columns as time bins.
    sample_idx : int
        Accepted for interface compatibility; not used in the figure.
    save_path : str, optional
        If given, the figure is saved there (format inferred from the
        extension; parent directory created if missing) and closed; otherwise
        it is shown.
    bin_ms : float, default 10
        Width of one column in ms; the default reproduces the previous
        hard-coded factor. NOTE(review): with the recorded binned shape
        (75, 27) this gives a 270 ms axis, whereas the binary plots span
        900 / 30 = 30 ms. Confirm the bin width if these are the same
        recordings.
    """
    fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(12, 5), dpi=300)

    # nanmax so a stray NaN in either array cannot turn vmax into NaN.
    vmin = 0
    vmax = max(np.nanmax(input_data), np.nanmax(output_data))

    _plot_binned_panel(fig, ax_in, input_data, 'Original Data', bin_ms, vmin, vmax)
    _plot_binned_panel(fig, ax_out, output_data, 'Reconstructed Data', bin_ms, vmin, vmax)

    fig.tight_layout()

    if save_path:
        # savefig does not create missing directories.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)  # close this figure explicitly, not "whatever is current"
    else:
        plt.show()

In [8]:
# Binned samples: load each saved input/reconstruction pair and write a
# side-by-side comparison PDF.
BINNED_DIR = "model_outputs_binned_maintain_time"
VIS_DIR = "visualization_results"
N_SAMPLES = 5

# savefig does not create missing directories; without this every sample would
# fail and only be reported by the except branch below.
os.makedirs(VIS_DIR, exist_ok=True)

for i in range(1, N_SAMPLES + 1):
    input_path = os.path.join(BINNED_DIR, f"input_{i}.pt")
    output_path = os.path.join(BINNED_DIR, f"output_{i}.pt")

    try:
        # map_location="cpu" lets tensors saved from a GPU session load on a
        # CPU-only machine; .numpy() below also requires CPU tensors.
        input_tensor = torch.load(input_path, map_location="cpu", weights_only=True).squeeze(0).to(torch.float32)
        output_tensor = torch.load(output_path, map_location="cpu", weights_only=True).squeeze(0).to(torch.float32)

        print(f"Input shape: {input_tensor.shape}")
        print(f"Output shape: {output_tensor.shape}")

        # detach() in case a saved tensor still requires grad, which .numpy() rejects.
        input_data = input_tensor.detach().numpy()
        output_data = output_tensor.detach().numpy()

        plot_comparison_binned(
            input_data,
            output_data,
            sample_idx=i,
            save_path=os.path.join(VIS_DIR, f"comparison_binned_encodec_{i}.pdf"),
        )

    except Exception as e:
        # Best-effort batch: report the failure and continue with the next sample.
        print(f"Error processing sample {i}: {e}")


Input shape: torch.Size([75, 27])
Output shape: torch.Size([75, 27])
Input shape: torch.Size([75, 27])
Output shape: torch.Size([75, 27])
Input shape: torch.Size([75, 27])
Output shape: torch.Size([75, 27])
Input shape: torch.Size([75, 27])
Output shape: torch.Size([75, 27])
Input shape: torch.Size([75, 27])
Output shape: torch.Size([75, 27])
