In [None]:
#tag test decoder


import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"

from dataset_new import *
from models_new_2 import *
from visualisation_new_2 import *
from perturbation import *



### data loading ###

dataset_ID = 6661 # ID of a specific dataset. 6661 refers to preprocessed data with a mask of shape (4609,). 6660 refers to preprocessed data with a mask of shape (15364,)
mask_size = 4609 # number of voxels in the preprocessed fMRI data. either 4609 or 15364
# Previous loader, kept for reference; get_dataset2 below is the one in use.
#trainset, valset, testset = get_dataset(dataset_ID, mask_size) # data are loaded into dictionaries

# NOTE(review): get_dataset2 only receives dataset_ID; mask_size is not passed here
# (it is still used later, e.g. for TemporalDecoder(mask_size)) — confirm they agree.
trainset, valset, testset = get_dataset2(dataset_ID)

print_dict_tree(testset)

# Disabled alternative: build the test set directly from per-movie .npy files of
# subject S01. It is wrapped in a string literal, so none of it executes.
'''
testset2 = {
    "fMRIs": {
        "Sintel": np.load("processed_data/sub-S01/test/Sintel.npy"),
        "Payload": np.load("processed_data/sub-S01/test/Payload.npy"),
        "Chatter": np.load("processed_data/sub-S01/test/Chatter.npy")
    },
    "videos": {
        "Sintel": np.load("processed_data/videos/test/Sintel.npy"),
        "Payload": np.load("processed_data/videos/test/Payload.npy"),
        "Chatter": np.load("processed_data/videos/test/Chatter.npy")
    }
}

print_dict_tree(testset2)
'''

def prepare_temporal_data_by_movie(fmri_data_dict, video_data_dict, window_size=3, middle_frame_idx=15):
    """
    Prepare temporal data with overlapping windows of TRs and their corresponding middle frames,
    separately for each movie.

    For every movie, window ``i`` holds TRs ``i, i + 1, ..., i + window_size - 1``, giving
    ``n_trs - window_size + 1`` windows (zero windows if the movie has fewer TRs than
    ``window_size``; the returned arrays are then empty but keep the full rank).

    Args:
        fmri_data_dict: Dictionary of fMRI data per movie, where keys are movie names and
                        values have shape (n_trs, mask_size)
        video_data_dict: Dictionary of video clips per movie with the same keys as
                         fmri_data_dict; values have shape (n_trs, 3, 112, 112, 32), with the
                         frame index on the last axis
        window_size: Number of consecutive TRs per window (must be >= 1)
        middle_frame_idx: Index along the last video axis of the single frame kept for each
                          TR. The default 15 is the 16th of 32 frames (0-indexed).

    Returns:
        tr_windows_dict: movie -> array of shape (n_windows, window_size, mask_size)
        frame_targets_dict: movie -> array of shape (n_windows, window_size, 3, 112, 112),
                            i.e. the kept frame of every TR in every window

    Raises:
        ValueError: if window_size < 1.
        KeyError: if a movie in fmri_data_dict has no entry in video_data_dict.
        IndexError: if a movie's video has fewer TRs than its fMRI data (and at least one
                    window exists).
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    tr_windows_dict = {}
    frame_targets_dict = {}

    # Process each movie separately (movies have different numbers of TRs)
    for movie_name, fmri_data in fmri_data_dict.items():
        fmri_data = np.asarray(fmri_data)
        all_frames = np.asarray(video_data_dict[movie_name])

        n_windows = max(fmri_data.shape[0] - window_size + 1, 0)
        # (n_windows, window_size) table of TR indices; row i is i, i+1, ..., i+window_size-1.
        # With n_windows == 0 it has shape (0, window_size), so the gathers below still
        # return arrays of the correct rank instead of a bare shape-(0,) array.
        window_idx = np.arange(n_windows)[:, None] + np.arange(window_size)[None, :]

        # Keep one frame per TR. Basic indexing returns a view, so the full clip is not copied.
        middle_frames = all_frames[..., middle_frame_idx]

        # Advanced indexing with the index table gathers every window in one step
        tr_windows_dict[movie_name] = fmri_data[window_idx]
        frame_targets_dict[movie_name] = middle_frames[window_idx]

    return tr_windows_dict, frame_targets_dict


window_size = 3

# Windowed view of the test set: for each movie, overlapping runs of
# `window_size` TRs ("fMRIs") paired with one frame per TR ("videos").
testset3 = dict(zip(
    ("fMRIs", "videos"),
    prepare_temporal_data_by_movie(
        testset['fMRIs'],
        testset['videos'],
        window_size=window_size,
    ),
))

print_dict_tree(testset3)
print("testset3 videos =", testset3["videos"].keys())



# Train-set indices of interest, each listed three times in a row
# (presumably one copy per TR of a 3-TR window — confirm).
specific_frames_train = [
    frame
    for frame in (681, 248, 3008, 1561, 1821, 2639, 467, 3558, 2173, 2119)
    for _ in range(3)
]

# Empty container for a small subset: both top-level keys get their own
# (distinct) "subdict1" entry.
testset_small = {key: {"subdict1": {}} for key in ("fMRIs", "videos")}

# Process training data to create temporal windows.
# NOTE(review): `prepare_temporal_data` is not defined in this section (the
# windowing helper here is `prepare_temporal_data_by_movie`), and testset['fMRIs']
# is a per-movie dict, so indexing it with a list of ints would fail. Rework
# before re-enabling.
#testset_small['fMRIs']["subdict1"], testset_small['videos']["subdict1"] = prepare_temporal_data(
#    testset['fMRIs'][specific_frames_train],
#    testset['videos'][specific_frames_train],
#    window_size=window_size
#)


#tag test decoder


def test_new_decoder(real=True, model_name='decoder_4609_1650', test_on_train=False, test_input=None, test_label=None, add_name='', regions=None, temporal=False):
    '''
    Test a saved TemporalDecoder and pass everything to ``test_model``.

    Args:
        real: if True (the default), decode the fMRIs in ``test_input``.
            If False, ``test_input``/``test_label`` are ignored. The input is
            ``model1.encoder(testset['videos'])`` and the labels are
            ``testset['videos']``.
        model_name: path of the saved state_dict that is loaded into
            ``TemporalDecoder(mask_size)``.
        test_on_train: if True, ``test_input``/``test_label`` are replaced by
            30 random samples (drawn without replacement) from ``trainset``,
            stored under the key "test".
        test_input: fMRIs to test on. This is a dict with one entry per film, e.g.
            (N, 4609) per film where N is the number of TRs, or the windowed
            arrays produced by prepare_temporal_data_by_movie. Defaults to
            ``testset['fMRIs']``, looked up when the function is called.
        test_label: the matching videos, also one entry per film. Defaults to
            ``testset['videos']``, looked up when the function is called.
        add_name: suffix appended to ``model_name`` for the name given to
            ``test_model``, so outputs of different runs are not overwritten.
        regions: ids of the regions to turn off, e.g. [1, 4, 5] (see
            turn_off_regions). Currently ignored, because the code that applied
            it is commented out below.
        temporal: forwarded unchanged to ``test_model``.
    '''
    print("testing decoder", model_name)

    # Resolve the defaults at call time. A default of ``testset['fMRIs']`` in the
    # signature would be frozen when the ``def`` runs and would go stale once
    # ``testset`` is rebound later in the notebook.
    if test_input is None:
        test_input = testset['fMRIs']
    if test_label is None:
        test_label = testset['videos']

    # Region turn-off is disabled; `regions` is currently unused.
#    if regions:
#        fmri_regions_off = test_input.copy()
#        for video_name in test_input.keys():
#            fmri_regions_off[video_name] = turn_off_regions(test_input[video_name], regions)
#        test_input = fmri_regions_off

    # Load the temporal decoder weights
    model = TemporalDecoder(mask_size)
    state_dict = torch.load(model_name)
    model.load_state_dict(state_dict)

    if not real:
        # NOTE(review): `model1` is not defined anywhere in this section, so this
        # branch raises NameError unless another cell defines it. Also confirm that
        # the encoder accepts the per-film dict testset['videos'].
        test_input = model1.encoder(testset['videos'])
        test_label = testset['videos']

    if test_on_train:
        num_samples = trainset["fMRIs"].shape[0]

        # Random subset of the training samples (unseeded, so it differs between runs)
        random_indices = np.random.choice(num_samples, size=30, replace=False)

        test_input = {"test": trainset["fMRIs"][random_indices]}   # e.g. (30, 4609)
        test_label = {"test": trainset["videos"][random_indices]}  # e.g. (30, 3, 112, 112, 32)

    criterion = Temporal_D_Loss()
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    pretrained_decoder = None
    model_to_test = 'decoder'
    statistical_testing = False
    display_plots = True
    save_plots = False

    test_model(test_input, test_label, model, criterion, device, pretrained_decoder, model_to_test, statistical_testing, display_plots, save_plots, model_name=model_name + add_name, temporal=temporal)
    return


# Earlier run with the temporal_decoder_4609_350 checkpoint, kept for reference.
#test_new_decoder(real=True, model_name='/media/RCPNAS/MIP/Michael/students_work/rodrigo/temporal_decoder_4609_350', test_input = testset3['fMRIs'], test_label = testset3['videos'], regions=[], add_name='', temporal=True)

# Evaluate the temporal decoder on the sliding-window test set (testset3).
# NOTE(review): the checkpoint name mentions "TRwindow5", but testset3 was built
# with window_size = 3. Confirm the checkpoint expects 3-TR windows.
test_new_decoder(real=True, model_name='/media/RCPNAS/MIP/Michael/students_work/rodrigo/temporal_decoder_4609_351_TRwindow5', test_input = testset3['fMRIs'], test_label = testset3['videos'], regions=[], add_name='', temporal=True)


# All code being run for the current perturbation (the losses used for plotting are not updated, though)

In [None]:
### necessary imports ###

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"

from dataset_new import *
from models_new_2 import *
from visualisation_new_2 import *
from perturbation import *


### data loading ###

dataset_ID = 6661 # ID of a specific dataset. 6661 refers to preprocessed data with a mask of shape (4609,). 6660 refers to preprocessed data with a mask of shape (15364,)
mask_size = 4609 # number of voxels in the preprocessed fMRI data. either 4609 or 15364
# NOTE: unlike the first cell (which uses get_dataset2), this cell loads via
# get_dataset and rebinds trainset / valset / testset.
trainset, valset, testset = get_dataset(dataset_ID, mask_size) # data are loaded into dictionaries




#------------------------------------------------------------------------------end of cell 1--------------------------------------------------------------------




#test_new_decoder(real=True, model_name='decoder_4609_1650', test_on_train=False, test_input = testset['fMRIs'], test_label = testset['videos'], add_name='', regions = [], mask_slice=None)

def extract_frames(testset, specific_frames):
    """
    Extract specific frames from testset and combine them into a single array.

    Parameters:
    - testset: Dictionary with 'fMRIs' and 'videos' keys, each containing subdictionaries for each movie
    - specific_frames: Dictionary mapping movie names to lists of frame indices to extract

    Returns:
    - filtered_data: {'fMRIs': {'combined': array}, 'videos': {'combined': array}}, where each
      array stacks the selected frames (all movies concatenated, in the order given),
      or None if no valid frame was found.

    Movies missing from testset are skipped with a warning. A frame index is valid only if
    0 <= index < the length of BOTH the movie's fMRI and video data; other indices are
    skipped with a warning. In particular, negative indices are rejected instead of
    silently wrapping around to the end of the movie.
    """
    selected_fmris = []
    selected_videos = []

    for movie_name, frames in specific_frames.items():
        if movie_name not in testset['fMRIs'] or movie_name not in testset['videos']:
            print(f"Warning: {movie_name} not found in testset")
            continue

        movie_fmri = testset['fMRIs'][movie_name]
        movie_video = testset['videos'][movie_name]
        # Only indices present in both arrays can be paired
        n_valid = min(len(movie_fmri), len(movie_video))

        for frame in frames:
            if not 0 <= frame < n_valid:
                print(f"Warning: Frame {frame} out of range for {movie_name} (max={n_valid-1})")
                continue

            selected_fmris.append(movie_fmri[frame])
            selected_videos.append(movie_video[frame])

    if not selected_fmris:
        print("No valid frames found")
        return None

    filtered_data = {
        'fMRIs': {'combined': np.array(selected_fmris)},
        'videos': {'combined': np.array(selected_videos)}
    }

    # Print the structure of the filtered dataset
    print("Filtered dataset structure:")
    print("fMRIs")
    for movie, data in filtered_data['fMRIs'].items():
        print(f"  {movie} (shape: {data.shape})")
    print("videos")
    for movie, data in filtered_data['videos'].items():
        print(f"  {movie} (shape: {data.shape})")

    return filtered_data


def extract_frames_train(trainset, specific_frames_train):
    """
    Extract specific samples from trainset and combine them into single arrays.

    Parameters:
    - trainset: Dictionary with 'fMRIs' and 'videos' keys, each holding an indexable
      sequence (e.g. an array) whose first axis runs over samples
    - specific_frames_train: list of sample indices to extract

    Returns:
    - filtered_data: {'fMRIs': {'combined': array}, 'videos': {'combined': array}} stacking
      the selected samples in the order given, or None if no valid index was found.

    An index is valid only if 0 <= index < the length of BOTH trainset['fMRIs'] and
    trainset['videos']; other indices (including negative ones, which would otherwise
    wrap around silently) are skipped with a warning.
    """
    selected_fmris = []
    selected_videos = []

    train_fmri = trainset['fMRIs']
    train_video = trainset['videos']
    # Only indices present in both arrays can be paired
    n_valid = min(len(train_fmri), len(train_video))

    for frame in specific_frames_train:
        if not 0 <= frame < n_valid:
            print(f"Warning: Frame {frame} out of range (max={n_valid-1})")
            continue

        selected_fmris.append(train_fmri[frame])
        selected_videos.append(train_video[frame])

    if not selected_fmris:
        print("No valid frames found")
        return None

    filtered_data = {
        'fMRIs': {'combined': np.array(selected_fmris)},
        'videos': {'combined': np.array(selected_videos)}
    }

    # Print the structure of the filtered dataset
    print("Filtered dataset structure:")
    print("fMRIs")
    for movie, data in filtered_data['fMRIs'].items():
        print(f"  {movie} (shape: {data.shape})")
    print("videos")
    for movie, data in filtered_data['videos'].items():
        print(f"  {movie} (shape: {data.shape})")

    return filtered_data


def extract_frames_train_new(specific_frames_train,
                             fmri_path='processed_data/sub-S32/train.npy',
                             video_path='processed_data/videos/videos.npy'):
    """
    Extract specific samples from the training data of the new dataset.

    Parameters:
    - specific_frames_train: list of indices into the first axis of the saved arrays
    - fmri_path: .npy file holding the training fMRIs (default: subject S32)
    - video_path: .npy file holding the matching video clips

    Returns:
    - filtered_data: {'fMRIs': {'combined': list}, 'videos': {'combined': list}}, where each
      list holds an in-memory copy of every selected sample, in the order given.
      Unlike extract_frames / extract_frames_train, the lists are not stacked, and an
      empty selection still returns this (empty) dictionary instead of None.

    An index is valid only if 0 <= index < the length of BOTH files; other indices
    (including negative ones) are skipped with a warning.
    """
    selected_fmris = []
    selected_videos = []

    # Memory-map rather than load: only the selected rows are actually read from
    # disk, instead of the whole (large) video array.
    fmris = np.load(fmri_path, mmap_mode='r')
    videos = np.load(video_path, mmap_mode='r')
    # Only indices present in both arrays can be paired
    n_valid = min(len(fmris), len(videos))

    for frame in specific_frames_train:
        if not 0 <= frame < n_valid:
            print(f"Warning: Frame {frame} out of range (max={n_valid-1})")
            continue

        # np.array copies the row out of the memory map, so the result is a plain
        # ndarray that does not keep the file mapped
        selected_fmris.append(np.array(fmris[frame]))
        selected_videos.append(np.array(videos[frame]))

    filtered_data = {
        'fMRIs': {'combined': selected_fmris},
        'videos': {'combined': selected_videos}
    }

    # Print the structure of the filtered dataset (number of selected samples)
    print("Filtered dataset structure:")
    print("fMRIs")
    for movie, data in filtered_data['fMRIs'].items():
        print(f"  {movie} (shape: {len(data)})")
    print("videos")
    for movie, data in filtered_data['videos'].items():
        print(f"  {movie} (shape: {len(data)})")

    print(filtered_data['videos'])

    return filtered_data



# Example usage:
# Test-set frame (TR) indices to extract, per movie.
specific_frames = {
    'AfterTheRain': [42],
    'BetweenViewings': [111],
    'Chatter': [21],
    'FirstBite': [33],          # TODO: choose a different frame for this movie
    'LessonLearned': [15, 36],
    'Payload': [18, 30],
    'Spaceman': [12],
    'TearsOfSteel': [39],
    'YouAgain': [300, 495]
}

# Indices into the training set to extract.
specific_frames_train = [681, 248, 3008, 1561, 1821, 2639, 467, 3558, 2173, 2119]
#frame 2119 from the trainset is a very nice frame with a face


# Call the functions to create the filtered datasets (each returns None if no index was valid)
filtered_testset = extract_frames(testset, specific_frames)
filtered_trainset = extract_frames_train(trainset, specific_frames_train)

# NOTE(review): np.memmap maps raw bytes without parsing any header and without a
# shape, so both arrays below are 1-D float32. If these files were written with
# np.save (real .npy files), their header bytes are also read as data, and
# np.load(path, mmap_mode='r') would be the correct call. Confirm how the files
# are written.
trainset2 = {}
trainset2['fMRIs'] = np.memmap(f'encoder_dataset_{dataset_ID}/trainset/fMRIs.npy', dtype='float32', mode='r')
trainset2['videos'] = np.memmap(f'encoder_dataset_{dataset_ID}/trainset/videos.npy', dtype='float32', mode='r')

print("trainset fmris shape =", trainset2['fMRIs'].shape)
print("trainset videos shape =", trainset2['videos'].shape)

# Disabled experiment (subject-averaged training data and older decoder runs).
# It is kept inside a string literal, so none of it executes.
'''
trainset_new = {}
#trainset_new['fMRIs'] = np.memmap(f'processed_data/sub-S32/train.npy', dtype='float32', mode='r').reshape(-1, mask_size)
#trainset_new['videos'] = np.memmap(f'processed_data/videos/videos.npy', dtype='float32', mode='r').reshape(-1, 3, 112, 112, 32)

trainset_new['fMRIs'] = np.load('processed_data/sub-average/train.npy')
trainset_new['videos'] = np.load('processed_data/videos/videos.npy')

print("trainset_new fmris shape =", trainset_new['fMRIs'].shape)
print("trainset_new videos shape =", trainset_new['videos'].shape)

filtered_trainset_new = extract_frames_train(trainset_new, specific_frames_train)     #made so the fmri data is on a normalized subject

#print_dict_tree(filtered_testset)

#test_new_decoder(real=True, model_name='decoder_4609_1650', test_on_train=False, test_input = testset['fMRIs'], test_label = testset['videos'], add_name='', regions = [], mask_slice=None)

#tag test filtered
#test_new_decoder(real=True, model_name='decoder_4609_350', test_input = filtered_testset['fMRIs'], test_label = filtered_testset['videos'], add_name='_mask_slice20', save_plots=False, mask_slice=0, all_frames=True)
'''







#------------------------------------------------------------------------------end of cell 2--------------------------------------------------------------------








#tag blocks

import numpy as np
import matplotlib.pyplot as plt

#i took out the comment here



def load_and_reshape_data(file_path, shape=(91, 109, 91)):
    """Load a flat region-label array and reshape it to 3D.

    Parameters
    ----------
    file_path : str or path-like
        .npy file containing the region labels (its printed shape is shown for checking).
    shape : tuple of int, optional
        Target 3D shape. The default (91, 109, 91) is the grid that was previously
        hard-coded (presumably the MNI152 2 mm grid — confirm). The file must hold
        exactly prod(shape) values.

    Returns
    -------
    numpy.ndarray
        The labels reshaped to ``shape``. The file is opened with mmap_mode='r', so
        the data is read lazily from disk and is read-only.
    """
    regions_2d = np.load(file_path, mmap_mode='r')
    print("regions_2d shape =", regions_2d.shape)
    regions_3d = regions_2d.reshape(shape)
    return regions_3d


def find_bounds(data_3d):
    """
    Compute the bounding box of the brain voxels in a 3D volume.

    Parameters
    -----------
    data_3d : numpy.ndarray
        3D volume; every voxel with a value > 0 counts as brain.

    Returns
    --------
    tuple
        ((x_min, x_max), (y_min, y_max), (z_min, z_max)), the inclusive index
        bounds of the brain voxels along each axis. If no voxel is > 0, a
        warning is printed and the full extent of the volume is returned.
    """
    brain_coords = np.nonzero(data_3d > 0)

    # Guard clause: nothing above zero, so fall back to the whole volume
    if brain_coords[0].size == 0:
        print("Warning: No non-zero values found in the data")
        full_shape = data_3d.shape
        return (0, full_shape[0] - 1), (0, full_shape[1] - 1), (0, full_shape[2] - 1)

    (x_min, x_max), (y_min, y_max), (z_min, z_max) = [
        (np.min(axis_coords), np.max(axis_coords))
        for axis_coords in (brain_coords[0], brain_coords[1], brain_coords[2])
    ]

    print(f"Brain boundaries found:")
    print(f"  X range: {x_min} to {x_max} (width: {x_max-x_min+1})")
    print(f"  Y range: {y_min} to {y_max} (height: {y_max-y_min+1})")
    print(f"  Z range: {z_min} to {z_max} (depth: {z_max-z_min+1})")

    return (x_min, x_max), (y_min, y_max), (z_min, z_max)


def visualize_blocks(data_3d, blocks, num_blocks=(3, 3, 3), selected_block=None, figsize=None):
    """
    Visualize brain blocks using a maximum intensity projection for each z-layer.

    One axial panel is drawn per block layer along z. Each panel shows which
    (x, y) positions contain any voxel > 0 inside that layer, the block grid in
    cyan, and the block IDs in yellow.

    Parameters
    -----------
    data_3d : numpy.ndarray
        3D brain data indexed as (x, y, z); voxels > 0 count as brain.
    blocks : dict
        Block ID -> ((x_min, x_max), (y_min, y_max), (z_min, z_max)), in the
        format returned by divide_brain_into_blocks.
    num_blocks : tuple
        Number of blocks along each dimension (x, y, z).
    selected_block : int, optional
        If provided, this block is outlined in red and the figure is saved to
        ``brain_blocks_{nx}x{ny}x{nz}_with_{selected_block}_highlighted.png``.
    figsize : tuple, optional
        Figure size; defaults to (6 * nz, 6).

    Returns
    --------
    dict
        The ``blocks`` argument, unchanged.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches

    nx, ny, nz = num_blocks

    if figsize is None:
        figsize = (6 * nz, 6)

    # One subplot per z-layer
    fig, axes = plt.subplots(1, nz, figsize=figsize)
    if nz == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single panel

    fig.suptitle(f"Brain Divided into {nx}x{ny}x{nz} Blocks", fontsize=16)

    # Recover the grid lines along each axis from the block boundaries
    x_divisions = []
    y_divisions = []
    z_divisions = []
    for block_id, ((x_min, x_max), (y_min, y_max), (z_min, z_max)) in blocks.items():
        x_divisions.extend([x_min, x_max])
        y_divisions.extend([y_min, y_max])
        z_divisions.extend([z_min, z_max])

    x_divisions = sorted(set(x_divisions))
    y_divisions = sorted(set(y_divisions))
    z_divisions = sorted(set(z_divisions))

    print("X divisions:", x_divisions)
    print("Y divisions:", y_divisions)
    print("Z divisions:", z_divisions)

    text_box = dict(facecolor='black', alpha=0.5, boxstyle='round')

    for z_idx in range(nz):
        ax = axes[z_idx]
        z_min = z_divisions[z_idx]
        z_max = z_divisions[z_idx + 1]

        layer_name = f"Layer {z_idx} (z={z_min}-{z_max-1})"

        # Binary maximum projection over this layer's z-range. Shape (x, y), so
        # imshow puts x on the vertical axis and y on the horizontal axis.
        layer_data = data_3d[:, :, z_min:z_max]
        layer_projection = np.max(layer_data > 0, axis=2).astype(float)
        print(f"Max layer_projection for layer {z_idx}=", np.max(layer_projection))

        ax.imshow(layer_projection, cmap='gray', origin='lower')
        ax.set_title(layer_name)

        # x boundaries are horizontal lines and y boundaries vertical lines (see above)
        for x in x_divisions[1:-1]:
            ax.axhline(x, color='cyan', linestyle='-', alpha=0.7, linewidth=1)
        for y in y_divisions[1:-1]:
            ax.axvline(y, color='cyan', linestyle='-', alpha=0.7, linewidth=1)

        for y_idx in range(ny):
            for x_idx in range(nx):
                # 1-based block ID, x fastest then y then z (same order as divide_brain_into_blocks)
                block_id = 1 + x_idx + y_idx * nx + z_idx * nx * ny

                x_min, x_max = x_divisions[x_idx], x_divisions[x_idx + 1]
                y_min, y_max = y_divisions[y_idx], y_divisions[y_idx + 1]

                x_center = (x_min + x_max) / 2
                y_center = (y_min + y_max) / 2

                # Plot coordinates are (horizontal=y, vertical=x)
                ax.text(y_center, x_center, str(block_id),
                        ha='center', va='center', color='yellow', fontweight='bold',
                        fontsize=12, bbox=text_box)

                if selected_block and block_id == selected_block:
                    rect = patches.Rectangle((y_min, x_min), y_max-y_min, x_max-x_min,
                                             fill=False, edgecolor='red', linewidth=2)
                    ax.add_patch(rect)

    plt.tight_layout()

    # Save BEFORE showing. plt.show() can close the figure (e.g. inline/non-interactive
    # backends), and a later plt.savefig would then write a blank image.
    if selected_block is not None:
        fig.savefig(f'brain_blocks_{nx}x{ny}x{nz}_with_{selected_block}_highlighted.png', dpi=300, bbox_inches='tight')

    plt.show()

    return blocks


def visualize_blocks_2(data_3d, blocks, losses, num_blocks=(3, 3, 3), figsize=None):
    """
    Visualize brain blocks per z-layer and annotate each block with its loss.

    One axial panel is drawn per block layer along z (binary maximum projection
    of voxels > 0). Every block is labelled with its ID and its loss (3 decimals).
    The block with the largest absolute loss is outlined in red. The figure is
    always saved to ``brain_blocks_{nx}x{ny}x{nz}_losses.png``.

    Parameters
    -----------
    data_3d : numpy.ndarray
        3D brain data indexed as (x, y, z).
    blocks : dict
        Block ID -> ((x_min, x_max), (y_min, y_max), (z_min, z_max)).
    losses : array-like
        One loss value per block (index 0 corresponds to block 1).
    num_blocks : tuple
        Number of blocks along each dimension (x, y, z).
    figsize : tuple, optional
        Figure size; defaults to (6 * nz, 6).

    Returns
    --------
    dict
        The ``blocks`` argument, unchanged.

    Raises
    -------
    ValueError
        If the number of losses differs from nx * ny * nz.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches

    nx, ny, nz = num_blocks

    losses = np.array(losses)

    total_blocks = nx * ny * nz
    if len(losses) != total_blocks:
        raise ValueError(f"Expected {total_blocks} loss values, but got {len(losses)}")

    # Block with the highest absolute loss (block IDs are 1-based)
    max_abs_loss_idx = np.argmax(np.abs(losses))
    selected_block = max_abs_loss_idx + 1
    max_abs_loss_value = losses[max_abs_loss_idx]

    print(f"Block {selected_block} has the highest absolute loss: {max_abs_loss_value}")

    if figsize is None:
        figsize = (6 * nz, 6)

    fig, axes = plt.subplots(1, nz, figsize=figsize)
    if nz == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single panel

    fig.suptitle(f"Brain Divided into {nx}x{ny}x{nz} Blocks with Loss Values", fontsize=16)

    # Recover the grid lines along each axis from the block boundaries
    x_divisions = []
    y_divisions = []
    z_divisions = []
    for block_id, ((x_min, x_max), (y_min, y_max), (z_min, z_max)) in blocks.items():
        x_divisions.extend([x_min, x_max])
        y_divisions.extend([y_min, y_max])
        z_divisions.extend([z_min, z_max])

    x_divisions = sorted(set(x_divisions))
    y_divisions = sorted(set(y_divisions))
    z_divisions = sorted(set(z_divisions))

    print("X divisions:", x_divisions)
    print("Y divisions:", y_divisions)
    print("Z divisions:", z_divisions)

    text_box = dict(facecolor='black', alpha=0.5, boxstyle='round')

    for z_idx in range(nz):
        ax = axes[z_idx]
        z_min = z_divisions[z_idx]
        z_max = z_divisions[z_idx + 1]

        layer_name = f"Layer {z_idx} (z={z_min}-{z_max-1})"

        # Binary maximum projection over this layer's z-range, shape (x, y):
        # imshow puts x on the vertical axis and y on the horizontal axis
        layer_data = data_3d[:, :, z_min:z_max]
        layer_projection = np.max(layer_data > 0, axis=2).astype(float)
        print(f"Max layer_projection for layer {z_idx}=", np.max(layer_projection))

        ax.imshow(layer_projection, cmap='gray', origin='lower')
        ax.set_title(layer_name)

        for x in x_divisions[1:-1]:
            ax.axhline(x, color='cyan', linestyle='-', alpha=0.7, linewidth=1)
        for y in y_divisions[1:-1]:
            ax.axvline(y, color='cyan', linestyle='-', alpha=0.7, linewidth=1)

        for y_idx in range(ny):
            for x_idx in range(nx):
                # 1-based block ID, x fastest then y then z
                block_id = 1 + x_idx + y_idx * nx + z_idx * nx * ny
                loss_value = losses[block_id - 1]

                x_min, x_max = x_divisions[x_idx], x_divisions[x_idx + 1]
                y_min, y_max = y_divisions[y_idx], y_divisions[y_idx + 1]

                x_center = (x_min + x_max) / 2
                y_center = (y_min + y_max) / 2

                # Plot coordinates are (horizontal=y, vertical=x)
                ax.text(y_center, x_center, f"{block_id}\n{loss_value:.3f}",
                        ha='center', va='center', color='yellow', fontweight='bold',
                        fontsize=10, bbox=text_box)

                if block_id == selected_block:
                    rect = patches.Rectangle((y_min, x_min), y_max-y_min, x_max-x_min,
                                             fill=False, edgecolor='red', linewidth=2)
                    ax.add_patch(rect)

    plt.tight_layout()

    # Save BEFORE showing. plt.show() can close the figure, and a later
    # plt.savefig would then write a blank image.
    fig.savefig(f'brain_blocks_{nx}x{ny}x{nz}_losses.png', dpi=300, bbox_inches='tight')

    plt.show()

    return blocks


def divide_brain_into_blocks(data_3d, num_blocks=(3, 3, 3)):
    """
    Partition the index grid of a 3D volume into a regular grid of blocks.

    Along each axis of length L split into n parts, every part spans L // n
    voxels except the last, which absorbs the remainder and ends at L.
    Block IDs start at 1 and run with x fastest, then y, then z.

    Returns a dict block_id -> ((x_min, x_max), (y_min, y_max), (z_min, z_max))
    with half-open ranges (max excluded), suitable for slicing ``data_3d``.
    """
    x_size, y_size, z_size = data_3d.shape

    def axis_ranges(length, parts):
        # (start, stop) of each slab along one axis; the final slab runs to `length`
        step = length // parts
        return [
            (i * step, length if i == parts - 1 else (i + 1) * step)
            for i in range(parts)
        ]

    x_ranges = axis_ranges(x_size, num_blocks[0])
    y_ranges = axis_ranges(y_size, num_blocks[1])
    z_ranges = axis_ranges(z_size, num_blocks[2])

    # Loop order in the comprehension (z outermost, x innermost) fixes the numbering
    ordered = [
        (x_range, y_range, z_range)
        for z_range in z_ranges
        for y_range in y_ranges
        for x_range in x_ranges
    ]
    return dict(enumerate(ordered, start=1))


def get_regions_in_block(data_3d, block_boundaries):
    """Return the sorted unique region IDs (> 0) found inside a block.

    ``block_boundaries`` is ((x_min, x_max), (y_min, y_max), (z_min, z_max)),
    interpreted as half-open slice ranges. Background (0) and any negative
    labels are excluded.
    """
    (x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi) = block_boundaries
    labels = np.unique(data_3d[x_lo:x_hi, y_lo:y_hi, z_lo:z_hi])
    return labels[labels > 0]


def get_block_indices_1d(regions_3d, block_boundaries):
    """
    Get the flat (row-major) indices of every voxel inside a 3D block.

    Parameters:
    -----------
    regions_3d : numpy.ndarray
        3D array of brain regions; only its shape is used.
    block_boundaries : tuple
        ((x_min, x_max), (y_min, y_max), (z_min, z_max)) half-open ranges for the block.

    Returns:
    --------
    numpy.ndarray
        Sorted 1D indices into ``regions_3d.reshape(-1)`` that fall inside the block,
        regardless of the region value stored there (background voxels included).
    """
    (x_min, x_max), (y_min, y_max), (z_min, z_max) = block_boundaries

    # A boolean mask is sufficient; there is no need to allocate in regions_3d's dtype.
    block_mask = np.zeros(np.shape(regions_3d), dtype=bool)
    block_mask[x_min:x_max, y_min:y_max, z_min:z_max] = True

    # flatnonzero walks the mask in C order, i.e. the order of reshape(-1).
    return np.flatnonzero(block_mask)


def convert_blocks_to_uncut_space(blocks, x_min, y_min, z_min):
    """
    Shift block coordinate ranges from the cropped ("cut") volume back to the original one.

    Parameters
    -----------
    blocks : dict
        Block ID -> ((x_start, x_end), (y_start, y_end), (z_start, z_end)) in cut space.
    x_min, y_min, z_min : int
        Offset of the cut region's origin inside the original volume.

    Returns
    --------
    dict
        Same keys, with every range translated by the matching offset.
    """
    def _shift(block_ranges):
        (x0, x1), (y0, y1), (z0, z1) = block_ranges
        return (
            (x0 + x_min, x1 + x_min),
            (y0 + y_min, y1 + y_min),
            (z0 + z_min, z1 + z_min),
        )

    return {block_id: _shift(block_ranges) for block_id, block_ranges in blocks.items()}


def turn_off_block_new(fmri_2d, flat_mask, block_id, blocks, x_min, y_min, z_min,
                       original_shape=(91, 109, 91)):
    """
    Turn off (set to 0) the fMRI voxels that fall inside one spatial block.

    The columns of ``fmri_2d`` are taken to be the non-zero voxels of ``flat_mask``
    in row-major (C) order. NOTE(review): assumes this is how the fMRI data were
    masked; confirm against the preprocessing code.

    Parameters
    -----------
    fmri_2d : numpy.ndarray
        2D array of fMRI data (samples x voxels).
    flat_mask : numpy.ndarray
        Flattened mask indicating which voxels in 3D space are used in fMRI data.
    block_id : int
        ID of the block to turn off (a key of ``blocks``).
    blocks : dict
        Dictionary with block IDs and coordinates in cut space.
    x_min, y_min, z_min : int
        Minimum coordinates of the cut region in the original space.
    original_shape : tuple of int
        Shape of the uncut 3D volume ``flat_mask`` is reshaped to.
        Defaults to (91, 109, 91), the previously hard-coded value.

    Returns
    --------
    numpy.ndarray
        Copy of ``fmri_2d`` with the block's voxel columns set to 0 (input is not modified).

    Raises
    ------
    KeyError
        If ``block_id`` is not in ``blocks``.
    """
    mask_3d = np.asarray(flat_mask).reshape(original_shape)

    # Only the requested block is needed, so shift just its ranges into uncut space.
    (bx0, bx1), (by0, by1), (bz0, bz1) = blocks[block_id]
    bx0, bx1 = bx0 + x_min, bx1 + x_min
    by0, by1 = by0 + y_min, by1 + y_min
    bz0, bz1 = bz0 + z_min, bz1 + z_min

    # np.nonzero enumerates mask voxels in C order, the same order as
    # np.nonzero(flat_mask), so position k here corresponds to column k of fmri_2d.
    xs, ys, zs = np.nonzero(mask_3d)
    in_block = (
        (xs >= bx0) & (xs < bx1)
        & (ys >= by0) & (ys < by1)
        & (zs >= bz0) & (zs < bz1)
    )
    indices_to_turn_off = np.flatnonzero(in_block)
    n_mask_voxels = xs.size

    # Statistics before turning off (denominator guarded against empty input).
    zeros_before = (fmri_2d == 0).sum()
    total_values = fmri_2d.size
    denom = total_values if total_values else 1
    percent_zeros_before = (zeros_before / denom) * 100
    print(f"Before turning off block {block_id} - Zeros: {zeros_before}/{total_values} ({percent_zeros_before:.2f}%)")
    print(f"Block {block_id} contains {len(indices_to_turn_off)} voxels out of {n_mask_voxels} total non-zero voxels")

    fmri_modified = fmri_2d.copy()
    if indices_to_turn_off.size:
        fmri_modified[:, indices_to_turn_off] = 0

    # Statistics after turning off; only previously non-zero values count as "turned off".
    zeros_after = (fmri_modified == 0).sum()
    percent_zeros_after = (zeros_after / denom) * 100
    new_zeros = zeros_after - zeros_before
    percent_increase = (new_zeros / denom) * 100

    print(f"After turning off block {block_id} - Zeros: {zeros_after}/{total_values} ({percent_zeros_after:.2f}%)")
    print(f"Values turned off: {new_zeros}/{total_values} ({percent_increase:.2f}%)")

    return fmri_modified










#--------------------------------------------------------------------end of cell 3------------------------------------------------------------------------------






#tag blocks + ou tag blocks 2
#  2

'''
def visualize_blocks_2(data_3d, blocks, losses, num_blocks=(3, 3, 3), figsize=None):
    """
    Visualize brain blocks using maximum intensity projection for each z-layer
    Shows axial views for each layer along the z-dimension and displays loss values
    
    Parameters:
    -----------
    data_3d : numpy.ndarray
        3D brain data
    blocks : dict
        Dictionary mapping block IDs to block boundaries
    losses : array-like
        Array of loss values, one per block (index 0 corresponds to block 1)
    num_blocks : tuple
        Number of blocks along each dimension (x, y, z)
    figsize : tuple, optional
        Figure size, if None will be calculated based on z-dimension blocks
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    
    # Unpack the number of blocks in each dimension
    nx, ny, nz = num_blocks
    
    # Convert losses to numpy array if it isn't already
    losses = np.array(losses)
    
    # Verify number of loss values matches number of blocks
    total_blocks = nx * ny * nz
    if len(losses) != total_blocks:
        raise ValueError(f"Expected {total_blocks} loss values, but got {len(losses)}")
    
    # Find the block with the highest absolute loss
    max_abs_loss_idx = np.argmax(np.abs(losses))
    # Convert to 1-based indexing for block ID
    selected_block = max_abs_loss_idx + 1
    max_abs_loss_value = losses[max_abs_loss_idx]
    
    print(f"Block {selected_block} has the highest absolute loss: {max_abs_loss_value}")
    
    # If figsize is not specified, calculate it based on number of z layers
    if figsize is None:
        figsize = (6 * nz, 6)
    
    # Create the figure with the appropriate number of subplots (one per z-layer)
    fig, axes = plt.subplots(1, nz, figsize=figsize)
    if nz == 1:
        axes = [axes]  # Make sure axes is always a list for consistency
    
    fig.suptitle(f"Brain Divided into {nx}x{ny}x{nz} Blocks with Loss Values", fontsize=16)
    
    # Get block division boundaries
    x_divisions = []
    y_divisions = []
    z_divisions = []
    
    for block_id, ((x_min, x_max), (y_min, y_max), (z_min, z_max)) in blocks.items():
        x_divisions.extend([x_min, x_max])
        y_divisions.extend([y_min, y_max])
        z_divisions.extend([z_min, z_max])
    
    # Get unique boundary values
    x_divisions = sorted(list(set(x_divisions)))
    y_divisions = sorted(list(set(y_divisions)))
    z_divisions = sorted(list(set(z_divisions)))
    
    # Print the division boundaries
    print("X divisions:", x_divisions)
    print("Y divisions:", y_divisions)
    print("Z divisions:", z_divisions)
    
    # Process each z-layer
    for z_idx in range(nz):
        # Get z-boundaries for this layer
        z_min = z_divisions[z_idx]
        z_max = z_divisions[z_idx + 1]
        
        # Create layer name with actual z-range
        layer_name = f"Layer {z_idx} (z={z_min}-{z_max-1})"
        
        # Extract the data for this z-layer
        layer_data = data_3d[:, :, z_min:z_max]
        
        # Create maximum intensity projection along z-axis for just this layer
        layer_projection = np.max(layer_data > 0, axis=2).astype(float)
        print(f"Max layer_projection for layer {z_idx}=", np.max(layer_projection))
        
        # Plot the projection
        im = axes[z_idx].imshow(layer_projection, cmap='gray', origin='lower')
        axes[z_idx].set_title(layer_name)
        
        # Add grid lines using data coordinates instead of pixel coordinates
        for x in x_divisions[1:-1]:
            axes[z_idx].axhline(x, color='cyan', linestyle='-', alpha=0.7, linewidth=1)
        
        for y in y_divisions[1:-1]:
            axes[z_idx].axvline(y, color='cyan', linestyle='-', alpha=0.7, linewidth=1)
        
        # Add block numbers and loss values for this layer
        for y_idx in range(ny):
            for x_idx in range(nx):
                # Calculate block ID (1-based index) using the formula:
                # block_id = 1 + x_idx + y_idx * nx + z_idx * nx * ny
                block_id = 1 + x_idx + y_idx * nx + z_idx * nx * ny
                
                # Get loss value for this block (subtract 1 for 0-based indexing)
                loss_value = losses[block_id - 1]
                
                # Get x, y boundaries for this block
                x_min, x_max = x_divisions[x_idx], x_divisions[x_idx + 1]
                y_min, y_max = y_divisions[y_idx], y_divisions[y_idx + 1]
                
                # Calculate center of block in data coordinates
                x_center = (x_min + x_max) / 2
                y_center = (y_min + y_max) / 2
                
                # Add block ID and loss value label with a bounding box
                text_box = dict(facecolor='black', alpha=0.5, boxstyle='round')
                
                # Place the text in the center of the block
                axes[z_idx].text(y_center, x_center, f"{block_id}\n{loss_value:.3f}", 
                               ha='center', va='center', color='yellow', fontweight='bold',
                               fontsize=10, bbox=text_box)
                
                # Highlight the block with the highest absolute loss
                if block_id == selected_block:
                    rect = patches.Rectangle((y_min, x_min), y_max-y_min, x_max-x_min, 
                                          fill=False, edgecolor='red', linewidth=2)
                    axes[z_idx].add_patch(rect)
    
    plt.tight_layout()
    plt.show()
    
    # Save figure if requested
    if hasattr(plt, 'savefig'):
        plt.savefig(f'brain_blocks_{nx}x{ny}x{nz}_losses.png', dpi=300, bbox_inches='tight')
    
    return blocks
'''



def visualize_blocks_3(data_3d, blocks, losses, num_blocks=(3, 3, 3), figsize=None, colormap='viridis',
                       save_fig=True):
    """
    Visualize brain blocks as one axial panel per z-layer, colour-coding each block by its loss.

    For each z-layer of blocks, a pixel counts as brain if any voxel of ``data_3d`` in that
    layer is > 0. Brain pixels are filled with the colormap colour of their block's loss.
    Each block is labelled with its ID and loss, and the block with the largest absolute
    loss is outlined in red. A horizontal colorbar maps colours back to loss values.

    Parameters
    -----------
    data_3d : numpy.ndarray
        3D brain data indexed as (x, y, z); only ``data_3d > 0`` is used.
    blocks : dict
        Dictionary mapping block IDs to ((x_min, x_max), (y_min, y_max), (z_min, z_max)).
    losses : array-like
        Array of loss values, one per block (index 0 corresponds to block 1). Block IDs
        are assumed to advance fastest along x, then y, then z.
    num_blocks : tuple
        Number of blocks along each dimension (x, y, z); must match ``blocks``.
    figsize : tuple, optional
        Figure size; if None, ``(6 * nz, 6)`` is used.
    colormap : str
        Matplotlib colormap name to use for loss values.
    save_fig : bool
        If True (default, matching the previous always-save behaviour), save the figure to
        ``brain_blocks_{nx}x{ny}x{nz}_losses_colormap.png`` before showing it.

    Returns
    -------
    dict
        The ``blocks`` argument, unchanged.

    Raises
    ------
    ValueError
        If ``len(losses) != nx * ny * nz`` or the block boundaries do not form an
        ``nx`` x ``ny`` x ``nz`` grid.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    import matplotlib.colors as colors

    nx, ny, nz = num_blocks
    losses = np.array(losses)

    total_blocks = nx * ny * nz
    if len(losses) != total_blocks:
        raise ValueError(f"Expected {total_blocks} loss values, but got {len(losses)}")

    max_abs_loss_idx = np.argmax(np.abs(losses))
    selected_block = max_abs_loss_idx + 1  # block IDs are 1-based
    max_abs_loss_value = losses[max_abs_loss_idx]
    print(f"Block {selected_block} has the highest absolute loss: {max_abs_loss_value}")

    norm = colors.Normalize(vmin=np.min(losses), vmax=np.max(losses))
    # plt.get_cmap works across Matplotlib versions; plt.cm.get_cmap was removed in 3.9.
    cmap = plt.get_cmap(colormap)

    # The unique sorted block boundaries along each axis define the grid lines.
    x_bounds, y_bounds, z_bounds = set(), set(), set()
    for bx, by, bz in blocks.values():
        x_bounds.update(bx)
        y_bounds.update(by)
        z_bounds.update(bz)
    x_divisions = sorted(x_bounds)
    y_divisions = sorted(y_bounds)
    z_divisions = sorted(z_bounds)

    print("X divisions:", x_divisions)
    print("Y divisions:", y_divisions)
    print("Z divisions:", z_divisions)

    grid_shape = (len(x_divisions) - 1, len(y_divisions) - 1, len(z_divisions) - 1)
    if grid_shape != (nx, ny, nz):
        raise ValueError(
            f"Block boundaries form a {grid_shape[0]}x{grid_shape[1]}x{grid_shape[2]} grid, "
            f"but num_blocks={tuple(num_blocks)}"
        )

    if figsize is None:
        figsize = (6 * nz, 6)

    # squeeze=False keeps axes indexable even when nz == 1.
    fig, axes = plt.subplots(1, nz, figsize=figsize, squeeze=False)
    axes = axes[0]
    fig.suptitle(f"Brain Divided into {nx}x{ny}x{nz} Blocks with Loss Values", fontsize=16)

    for z_idx, ax in enumerate(axes):
        z_min, z_max = z_divisions[z_idx], z_divisions[z_idx + 1]
        layer_name = f"Layer {z_idx} (z={z_min}-{z_max-1})"

        # Projection along z of "is brain" for this layer only.
        brain_mask = np.any(data_3d[:, :, z_min:z_max] > 0, axis=2)
        print(f"Max layer_projection for layer {z_idx}=", np.max(brain_mask.astype(float)))

        # Black background.
        ax.imshow(np.zeros(brain_mask.shape), cmap='gray', origin='lower')
        ax.set_title(layer_name)

        # Blocks do not overlap, so a single RGBA overlay per layer looks the same as
        # one overlay per block while drawing far fewer images.
        overlay = np.zeros((*brain_mask.shape, 4))

        for y_idx in range(ny):
            for x_idx in range(nx):
                block_id = 1 + x_idx + y_idx * nx + z_idx * nx * ny
                loss_value = losses[block_id - 1]

                x_min, x_max = x_divisions[x_idx], x_divisions[x_idx + 1]
                y_min, y_max = y_divisions[y_idx], y_divisions[y_idx + 1]

                # Colour only the brain pixels of this block (basic slicing yields a view).
                block_region = overlay[x_min:x_max, y_min:y_max]
                block_region[brain_mask[x_min:x_max, y_min:y_max]] = cmap(norm(loss_value))

                # Image rows are x and columns are y, hence (y, x) ordering for text/patches.
                ax.text((y_min + y_max) / 2, (x_min + x_max) / 2, f"{block_id}\n{loss_value:.3f}",
                        ha='center', va='center', color='white', fontweight='bold',
                        fontsize=10, bbox=dict(facecolor='black', alpha=0.7, boxstyle='round'))

                if block_id == selected_block:
                    ax.add_patch(patches.Rectangle(
                        (y_min, x_min), y_max - y_min, x_max - x_min,
                        fill=False, edgecolor='red', linewidth=2,
                    ))

        if overlay[..., 3].any():
            ax.imshow(overlay, origin='lower', interpolation='nearest')

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])

    # Separate axis for the colorbar below the brain images: [left, bottom, width, height].
    cax = fig.add_axes([0.15, 0.05, 0.7, 0.02])
    cbar = fig.colorbar(sm, cax=cax, orientation='horizontal')
    cbar.set_label('Loss Value')

    plt.tight_layout(rect=[0, 0.1, 1, 0.95])

    # Save BEFORE plt.show(): with most backends show() releases the figure, so saving
    # afterwards writes an empty canvas.
    if save_fig:
        fig.savefig(f'brain_blocks_{nx}x{ny}x{nz}_losses_colormap.png', dpi=300, bbox_inches='tight')
    plt.show()

    return blocks

def test_new_decoder(real=True, model_name='decoder_4609_1650', test_on_train=False, test_input=testset['fMRIs'], 
                     test_label=testset['videos'], add_name='', regions=None, block_id=None, save_plots=False, all_frames=False,
                     change_mode='off', num_blocks=None, metric="ssim", zones=None):
    '''
    Test the decoder, optionally after turning off brain blocks or regions.

    Modes, in order of precedence:

    * ``num_blocks`` and ``block_id`` given: turn off that one block in every input video
      and run the test once.
    * ``num_blocks`` given, ``block_id`` None: sweep all blocks. Each block is turned off in
      turn, the mean loss is collected from ``test_model_all``, and the per-block losses
      are plotted with ``visualize_blocks_3``. Requires ``all_frames=True``.
    * Otherwise: turn off ``regions`` (if any) and run the test once.

    Parameters
    -----------
    real : bool
        If True, tests on real brain activity; if False, tests on brain activity from encoder.
        NOTE(review): not currently used by this function.
    model_name : str
        Name of the model file to be used.
    test_on_train : bool
        If True, tests on 30 random samples of the training set instead of ``test_input``.
    test_input : dict
        Dictionary with fMRIs for testing (one entry for each film).
    test_label : dict
        Dictionary with films (one entry for each film).
    add_name : str
        String to add to the end of output name to avoid overwriting.
    regions : list, optional
        Region IDs to turn off (legacy path, used only when ``num_blocks`` is None).
    block_id : int, optional
        ID of the 3D block to turn off (1 .. number of blocks). None sweeps all blocks.
    save_plots : bool
        If True, saves plots.
    all_frames : bool
        If True, runs test_model_all, otherwise runs test_model.
    change_mode : str
        'off' or 'amplify'. NOTE(review): 'amplify' is not implemented for blocks; both
        modes currently turn the block off.
    num_blocks : tuple of int, optional
        Block grid (nx, ny, nz) used to divide the brain; enables the block modes.
    metric : str
        Metric passed to ``test_model_all`` (e.g. "ssim" or "tv").
    zones : str or int, optional
        Passed through to ``test_model_all``. Default None. Documented values are
        "quadrants", "center_bg", or an integer k meaning k x k zones (e.g. 4 -> 16 zones).

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If all blocks are swept with ``all_frames=False`` (``test_model`` provides no
        per-block loss), or ``block_id`` is not a valid block.
    '''
    import numpy as np

    print("zones = ", zones, "\n")

    sweep_all_blocks = num_blocks is not None and block_id is None
    if sweep_all_blocks and not all_frames:
        raise ValueError("Sweeping all blocks requires all_frames=True to obtain a loss per block")

    print("Testing decoder", model_name)

    # NOTE(review): torch.load unpickles the file; model_name must be a trusted checkpoint.
    model = Decoder(mask_size)
    state_dict = torch.load(model_name)
    model.load_state_dict(state_dict)

    # Testing parameters
    criterion = D_Loss()
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    pretrained_decoder = None
    model_to_test = 'decoder'
    statistical_testing = False
    display_plots = True
    output_name = model_name + add_name

    if test_on_train:
        # Evaluate on 30 random training samples, wrapped as a single "test" film.
        num_samples = trainset["fMRIs"].shape[0]
        random_indices = np.random.choice(num_samples, size=30, replace=False)
        test_input = {"test": trainset["fMRIs"][random_indices]}
        test_label = {"test": trainset["videos"][random_indices]}

    if num_blocks is not None:
        # The cropped brain volume only provides block coordinates; the voxels are
        # then turned off in the original masked fMRI data.
        regions_3d = load_and_reshape_data('region_ids_4609+.npy')
        (x_min, x_max), (y_min, y_max), (z_min, z_max) = find_bounds(regions_3d)
        brain_data = regions_3d[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1]
        print(f"Original shape: {regions_3d.shape}, Brain shape: {brain_data.shape}")

        blocks = divide_brain_into_blocks(brain_data, num_blocks)

        flat_mask = np.load('mask_schaefer1000_4609.npy').flatten()
        print("brain_data.shape =", brain_data.shape)

        if not sweep_all_blocks:
            if block_id not in blocks:
                raise ValueError(f"Block ID {block_id} not found; valid IDs are 1-{len(blocks)}")

            # Draw the block layout once (it is identical for every video).
            visualize_blocks(brain_data, blocks, num_blocks=num_blocks, selected_block=block_id)

            modified_input = {}
            for video_name, fmri in test_input.items():
                print(f"Turning off block {block_id} for video {video_name}")
                modified_input[video_name] = turn_off_block_new(fmri, flat_mask, block_id, blocks,
                                                                x_min, y_min, z_min)

            if all_frames:
                test_model_all(modified_input, test_label, model, criterion, device,
                               pretrained_decoder, model_to_test, statistical_testing,
                               display_plots, save_plots, model_name=output_name, metric=metric)
            else:
                test_model(modified_input, test_label, model, criterion, device,
                           pretrained_decoder, model_to_test, statistical_testing,
                           display_plots, save_plots, model_name=output_name)
            return None

        # Sweep: turn off each block in turn and record the resulting mean loss.
        losses = []
        n_blocks = num_blocks[0] * num_blocks[1] * num_blocks[2]
        for current_block in range(1, n_blocks + 1):
            modified_input = {
                video_name: turn_off_block_new(fmri, flat_mask, current_block, blocks, x_min, y_min, z_min)
                for video_name, fmri in test_input.items()
            }
            results = test_model_all(modified_input, test_label, model, criterion, device,
                                     pretrained_decoder, model_to_test, statistical_testing,
                                     display_plots, save_plots, model_name=output_name, metric=metric,
                                     mean_flag=True, zones=zones)
            # With mean_flag=True, test_performance is used directly as the block's mean
            # loss (the same for "tv" and "ssim").
            mean_loss = results['test_performance']
            print("Turned off block", current_block, ", mean loss =", mean_loss)
            losses.append(mean_loss)

        # losses[k] belongs to block k + 1 (block IDs are 1-based).
        lossiest = int(np.argmax(losses)) + 1
        print("Block with the biggest change =", lossiest)

        visualize_blocks_3(brain_data, blocks, losses, num_blocks=num_blocks, colormap='viridis')
        return None

    if regions:
        # Legacy path: turn off whole atlas regions instead of spatial blocks.
        test_input = {
            video_name: turn_off_regions(fmri, regions)
            for video_name, fmri in test_input.items()
        }

    print("test input shape =", print_dict_tree(test_input))
    print("test label shape =", print_dict_tree(test_label))

    if all_frames:
        # The per-metric mean is not needed here, so the results are not inspected.
        test_model_all(test_input, test_label, model, criterion, device,
                       pretrained_decoder, model_to_test, statistical_testing,
                       display_plots, save_plots, model_name=output_name, metric=metric, mean_flag=True,
                       zones=zones)
    else:
        test_model(test_input, test_label, model, criterion, device,
                   pretrained_decoder, model_to_test, statistical_testing,
                   display_plots, save_plots, model_name=output_name)
    return None
    

def run_all_blocks_test():
    """Sweep every block of a 2x2x2 grid on the filtered training set.

    Uses the ``decoder_4609_350`` checkpoint and the "tv" metric over 4x4 zones.
    No ``block_id`` is passed, so ``test_new_decoder`` turns off each block in turn
    and plots the per-block losses.
    """
    sweep_settings = dict(
        real=True,
        model_name='decoder_4609_350',
        test_input=filtered_trainset['fMRIs'],
        test_label=filtered_trainset['videos'],
        all_frames=True,
        save_plots=False,
        # Legacy output suffix; no single block is selected in this sweep.
        add_name='_block_14',
        change_mode='off',
        num_blocks=(2, 2, 2),
        metric="tv",
        zones=4,
    )
    test_new_decoder(**sweep_settings)


# Runs the full block sweep as soon as this cell/module is executed.
run_all_blocks_test()

#tag blocks end










#---------------------------------end of cell 4 (code above gives output with good brain bad reconstructions, code below gives bad brain good reconstructions)----------------------------












def test_new_decoder(real=True, model_name='decoder_4609_1650', test_on_train=False, test_input=testset['fMRIs'], 
                     test_label=testset['videos'], add_name='', regions=None, block_id=None, save_plots=False, all_frames=False,
                     change_mode='off', num_blocks=(3,3,3), metric="ssim", zones=None):
    '''
    Test the decoder, optionally after silencing one 3D brain block or a list of regions.

    Parameters:
    -----------
    real : bool
        Not used by this function; kept for backward compatibility.
    model_name : str
        Path/name of the saved state_dict to load into ``Decoder(mask_size)``.
    test_on_train : bool
        If True, replaces ``test_input``/``test_label`` with 30 random samples
        drawn (without replacement) from the global ``trainset``.
    test_input : dict
        Dictionary with fMRIs for testing (one entry per film).
        NOTE: the default is evaluated once, when this function is defined.
    test_label : dict
        Dictionary with films (one entry per film).
        NOTE: the default is evaluated once, when this function is defined.
    add_name : str
        Suffix appended to ``model_name`` for output naming (avoids overwriting).
    regions : list or None
        Region IDs to turn off (legacy parameter; use ``block_id`` instead).
        Ignored when ``block_id`` is given. ``None`` or empty means no regions.
    block_id : int or None
        ID of the 3D brain block to turn off (1-27 for a 3x3x3 partition).
    save_plots : bool
        If True, saves plots.
    all_frames : bool
        If True, runs ``test_model_all``; otherwise runs ``test_model``.
    change_mode : str
        'off' or 'amplify'. NOTE: 'amplify' is not implemented yet; both
        modes currently turn the block off.
    num_blocks : tuple of int
        Number of blocks along each axis when partitioning the brain; only
        used when ``block_id`` is given.
    metric : str
        Metric forwarded to ``test_model_all`` (e.g. "ssim" or "tv"); only
        used when ``all_frames`` is True.
    zones : str, int or None
        Zone layout forwarded to ``test_model_all`` (e.g. "quadrants",
        "center_bg" or an int n). Only used when ``all_frames`` is True; when
        None the argument is not passed, so ``test_model_all`` uses its own default.

    Returns:
    --------
    None
    '''
    import numpy as np

    print("Testing decoder", model_name)

    # Load the model. map_location="cpu" allows checkpoints saved on a GPU to be
    # loaded on machines without that GPU; load_state_dict copies the tensors
    # into the model's existing parameters either way.
    model = Decoder(mask_size)
    state_dict = torch.load(model_name, map_location="cpu")
    model.load_state_dict(state_dict)

    # Testing parameters
    criterion = D_Loss()
    # With CUDA_VISIBLE_DEVICES="1,2,3" (set at the top of the file), "cuda:2"
    # is the third visible device.
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    pretrained_decoder = None
    model_to_test = 'decoder'
    statistical_testing = False
    display_plots = True

    # Handle special case for training data: evaluate on 30 random training samples.
    if test_on_train:
        num_samples = trainset["fMRIs"].shape[0]
        random_indices = np.random.choice(num_samples, size=30, replace=False)
        test_input = {"test": trainset["fMRIs"][random_indices]}
        test_label = {"test": trainset["videos"][random_indices]}

    if block_id is not None:
        print(f"Testing with brain block {block_id}")

        # Load the 3D region map (only used to get block coordinates; the
        # blocks are then turned off in the original fMRI data).
        regions_3d = load_and_reshape_data('region_ids_4609+.npy')

        # Cut to the tightest rectangular prism around the brain
        # (an earlier inspection found non-zero z-slices spanning 26..58).
        (x_min, x_max), (y_min, y_max), (z_min, z_max) = find_bounds(regions_3d)
        brain_data = regions_3d[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1]
        print(f"Original shape: {regions_3d.shape}, Brain shape: {brain_data.shape}")

        # Divide the cropped brain into blocks
        blocks = divide_brain_into_blocks(brain_data, num_blocks)
        print("blocks =", blocks)
        print()

        # Validate before visualizing so an invalid ID fails fast with a clear message.
        if block_id not in blocks:
            print(f"Error: Block ID {block_id} not found in blocks dictionary")
            return None

        # Visualize the blocks with the selected block highlighted
        visualize_blocks(brain_data, blocks, num_blocks=num_blocks, selected_block=block_id)

        mask4609 = np.load('mask_schaefer1000_4609.npy')
        flat_mask = mask4609.flatten()

        print("brain_data.shape =", brain_data.shape)
        if change_mode != 'off':
            print(f"Warning: change_mode={change_mode!r} is not implemented; turning block off instead")

        # Turn off the specified block in every input video. The crop offsets map
        # block coordinates back into the uncropped mask space.
        modified_input = {}
        for video_name in test_input.keys():
            print(f"Turning off block {block_id} for video {video_name}")
            modified_input[video_name] = turn_off_block_new(
                test_input[video_name], flat_mask, block_id, blocks, x_min, y_min, z_min
            )
        test_input = modified_input

    elif regions:
        # Legacy path: turn off an explicit list of regions in every video.
        test_input = {
            video_name: turn_off_regions(fmri, regions)
            for video_name, fmri in test_input.items()
        }

    # print_dict_tree prints the structure itself; only echo a return value if it has one.
    print("test input shape =")
    input_tree = print_dict_tree(test_input)
    if input_tree is not None:
        print(input_tree)
    print("test label shape =")
    label_tree = print_dict_tree(test_label)
    if label_tree is not None:
        print(label_tree)

    common_args = (test_input, test_label, model, criterion, device,
                   pretrained_decoder, model_to_test, statistical_testing,
                   display_plots, save_plots)

    # Run the appropriate test model function
    if all_frames:
        extra_kwargs = {} if zones is None else {"zones": zones}
        test_model_all(*common_args, model_name=model_name + add_name, metric=metric, **extra_kwargs)
    else:
        # test_model does not take a metric argument.
        test_model(*common_args, model_name=model_name + add_name)
    return None

# Example usage
def run_block_test():
    """Evaluate 'decoder_4609_350' with block 14 of a 3x3x3 brain partition silenced.

    Uses the filtered training set, evaluates all frames with the "tv" metric,
    and does not save plots. Valid block ids for a 3x3x3 partition are 1-27.
    To evaluate on the test split instead, swap ``filtered_trainset`` for
    ``filtered_testset``.
    """
    options = {
        'real': True,
        'model_name': 'decoder_4609_350',
        'test_input': filtered_trainset['fMRIs'],
        'test_label': filtered_trainset['videos'],
        'block_id': 14,  # middle block of the 3x3x3 grid
        'all_frames': True,
        'save_plots': False,
        'add_name': '_block_14',
        'change_mode': 'off',
        'num_blocks': (3, 3, 3),
        'metric': "tv",
    }
    test_new_decoder(**options)
    
# Call the function
#run_block_test()

# Call the function
# run_block_test()  # Uncomment to test a single block
# run_all_blocks_test()  # Uncomment to test all blocks sequentially






















#tag debug


def plot_all_predictions7(predictions, videos, performance_dict=None, display_plots=True, save_plots=False, 
                 save_path_prefix=None, model_name="", device="cuda" if torch.cuda.is_available() else "cpu",
                 metric="ssim", mean_flag=False, zone_type="quadrants", max_frames=None, baseline_predictions=None):
    """
    Display comparison plots between original videos and predictions.
    Shows: Original image, baseline reconstruction, perturbed reconstruction, and difference heatmap.
    
    Parameters:
    -----------
    predictions : dict
        Dictionary of prediction arrays (perturbed reconstructions)
    videos : dict
        Dictionary of ground truth video arrays
    performance_dict : dict, optional
        Dictionary to store performance metrics
    display_plots : bool
        Whether to display the plots
    save_plots : bool
        Whether to save the plots
    save_path_prefix : str, optional
        Path prefix for saving plots
    model_name : str
        Name of the model for saving plots
    device : str
        Device to use for computations
    metric : str
        Metric to use for evaluation: "ssim" or "tv" (Total Variation)
    mean_flag : bool
        Whether to return mean metrics or not
    zone_type : str or int
        Type of zones: 
        - "quadrants" for 2×2 grid
        - "center_bg" for center and background
        - integer n for n×n grid (e.g., 4 creates a 4×4 grid with 16 zones)
    max_frames : int, optional
        Maximum number of frames to plot. If None, all frames will be plotted.
    baseline_predictions : dict, optional
        Dictionary of baseline prediction arrays (without perturbation)
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import torch
    import os
    from matplotlib.patches import Rectangle
    from matplotlib.colors import LinearSegmentedColormap, Normalize
    
    # Create output directory if it doesn't exist
    if save_plots and save_path_prefix:
        os.makedirs(save_path_prefix, exist_ok=True)

    # Debug information about inputs
    print("\n=== DEBUG INFO ===")
    print(f"Predictions dictionary contains {len(predictions)} keys: {list(predictions.keys())}")
    print(f"Videos dictionary contains {len(videos)} keys: {list(videos.keys())}")
    print(f"Using metric: {metric}")
    print(f"Baseline predictions provided: {baseline_predictions is not None}")
    
    if isinstance(zone_type, int):
        print(f"Using {zone_type}×{zone_type} grid zones ({zone_type*zone_type} total zones)")
    else:
        print(f"Using zone type: {zone_type}")
    
    # Find the overlapping keys between predictions and videos
    common_keys = [key for key in predictions.keys() if key in videos]
    print(f"Common keys in both dictionaries: {common_keys}")
    
    # Create key mapping between predictions and videos
    if not common_keys and len(videos) > 0:
        print("No common keys found. Trying to match prediction keys to video keys.")
        ref_video_key = list(videos.keys())[0]
        print(f"Using {ref_video_key} as reference video for all predictions")
        key_mapping = {pred_key: ref_video_key for pred_key in predictions.keys()}
    else:
        key_mapping = {}
        for pred_key in predictions.keys():
            if pred_key in videos:
                key_mapping[pred_key] = pred_key
            else:
                matched = False
                for video_key in videos.keys():
                    if video_key in pred_key:
                        key_mapping[pred_key] = video_key
                        matched = True
                        break
                if not matched and len(videos) > 0:
                    key_mapping[pred_key] = list(videos.keys())[0]
    
    print(f"Key mapping from prediction keys to video keys: {key_mapping}")
    
    # Helper function to split frame into zones
    def split_into_zones(frame, zone_type="quadrants", center_ratio=0.5):
        """
        Split a frame into zones.
        """
        if isinstance(frame, torch.Tensor):
            C, H, W = frame.shape
        else:
            C, H, W = frame.shape
            
        zones = {}
        
        if zone_type == "quadrants":
            # Split into 4 quadrants (2×2 grid)
            h_mid = H // 2
            w_mid = W // 2
            
            zones["top_left"] = (slice(None), slice(0, h_mid), slice(0, w_mid))
            zones["top_right"] = (slice(None), slice(0, h_mid), slice(w_mid, W))
            zones["bottom_left"] = (slice(None), slice(h_mid, H), slice(0, w_mid))
            zones["bottom_right"] = (slice(None), slice(h_mid, H), slice(w_mid, W))
            
        elif zone_type == "center_bg":
            # Split into center and background
            h_center = int(H * center_ratio)
            w_center = int(W * center_ratio)
            
            h_start = (H - h_center) // 2
            h_end = h_start + h_center
            w_start = (W - w_center) // 2
            w_end = w_start + w_center
            
            zones["center"] = (slice(None), slice(h_start, h_end), slice(w_start, w_end))
            
            # Background is everything except the center
            center_mask = np.zeros((H, W), dtype=bool)
            center_mask[h_start:h_end, w_start:w_end] = True
            
            zones["background"] = {"mask": ~center_mask, 
                                   "bounds": (h_start, h_end, w_start, w_end)}
            
        elif isinstance(zone_type, int) and zone_type > 0:
            # Create an n×n grid where n = zone_type
            n = zone_type
            
            # Calculate heights of each section
            h_sections = [i * H // n for i in range(n+1)]
            w_sections = [i * W // n for i in range(n+1)]
            
            # Create zones for each grid cell
            for i in range(n):
                for j in range(n):
                    zone_name = f"grid_{i}_{j}"  # Row_Column naming
                    zones[zone_name] = (
                        slice(None),
                        slice(h_sections[i], h_sections[i+1]),
                        slice(w_sections[j], w_sections[j+1])
                    )
                    
        else:
            raise ValueError(f"Unknown zone type: {zone_type}")
            
        return zones
    
    # Helper function to calculate zone metrics
    def calculate_zone_metrics(orig_frame, pred_frame, baseline_frame=None, zones=None, metric="ssim", device=device):
        """
        Calculate metrics for each zone.
        If baseline_frame is provided, calculate the difference: baseline_metrics - pred_metrics
        """
        from pytorch_msssim import ssim
        
        # If zones not provided, calculate them
        if zones is None:
            zones = split_into_zones(orig_frame, zone_type=zone_type)
        
        zone_metrics = {}
        
        # Convert to torch tensors if needed
        if not isinstance(orig_frame, torch.Tensor):
            orig_tensor = torch.from_numpy(orig_frame).unsqueeze(0)
        else:
            orig_tensor = orig_frame.unsqueeze(0)
            
        if not isinstance(pred_frame, torch.Tensor):
            pred_tensor = torch.from_numpy(pred_frame).unsqueeze(0)
        else:
            pred_tensor = pred_frame.unsqueeze(0)
        
        # Process baseline frame if provided    
        if baseline_frame is not None:
            if not isinstance(baseline_frame, torch.Tensor):
                baseline_tensor = torch.from_numpy(baseline_frame).unsqueeze(0)
            else:
                baseline_tensor = baseline_frame.unsqueeze(0)
        
        # Calculate metrics for each zone
        for zone_name, zone_slice in zones.items():
            # Special handling for background in center_bg mode
            if isinstance(zone_slice, dict):  # Background in center_bg mode
                mask = zone_slice["mask"]
                
                orig_zone = orig_tensor.clone()
                pred_zone = pred_tensor.clone()
                
                # Apply mask to all channels
                for c in range(orig_zone.shape[1]):  # For each channel
                    orig_zone[0, c][~mask] = 0
                    pred_zone[0, c][~mask] = 0
                
                # Calculate metric for prediction
                if metric == "ssim":
                    pred_metric = ssim(orig_zone, pred_zone, data_range=1, size_average=True).item()
                else:
                    # TV Loss calculation for masked region
                    tv_loss = torch.abs(pred_zone[:,:,1:,:] - pred_zone[:,:,:-1,:]).sum() + \
                              torch.abs(pred_zone[:,:,:,1:] - pred_zone[:,:,:,:-1]).sum()
                    # Normalize by number of pixels in the zone
                    pred_metric = tv_loss.item() / mask.sum()
                
                # If baseline provided, calculate baseline metric and difference
                if baseline_frame is not None:
                    baseline_zone = baseline_tensor.clone()
                    for c in range(baseline_zone.shape[1]):
                        baseline_zone[0, c][~mask] = 0
                        
                    if metric == "ssim":
                        base_metric = ssim(orig_zone, baseline_zone, data_range=1, size_average=True).item()
                        # For SSIM, higher is better, so baseline - perturbed shows how much we lost
                        # (negative value means perturbation improved SSIM)
                        zone_metrics[zone_name] = base_metric - pred_metric
                    else:
                        # TV Loss
                        tv_loss = torch.abs(baseline_zone[:,:,1:,:] - baseline_zone[:,:,:-1,:]).sum() + \
                                 torch.abs(baseline_zone[:,:,:,1:] - baseline_zone[:,:,:,:-1]).sum()
                        base_metric = tv_loss.item() / mask.sum()
                        # For TV loss, lower is better, so perturbed - baseline shows how much we lost
                        # (positive value means perturbation worsened TV loss)
                        zone_metrics[zone_name] = pred_metric - base_metric
                else:
                    # No baseline, just use the prediction metric
                    zone_metrics[zone_name] = pred_metric
                
            else:  # Normal zones
                # Get the zone data
                orig_zone = orig_tensor[0][zone_slice].unsqueeze(0)
                pred_zone = pred_tensor[0][zone_slice].unsqueeze(0)
                
                # Calculate metric for prediction
                if metric == "ssim":
                    pred_metric = ssim(orig_zone, pred_zone, data_range=1, size_average=True).item()
                else:
                    # TV Loss calculation
                    tv_loss = torch.abs(pred_zone[:,:,1:,:] - pred_zone[:,:,:-1,:]).sum() + \
                              torch.abs(pred_zone[:,:,:,1:] - pred_zone[:,:,:,:-1]).sum()
                    # Normalize by number of pixels in the zone
                    pred_metric = tv_loss.item() / (orig_zone.shape[2] * orig_zone.shape[3])
                
                # If baseline provided, calculate baseline metric and difference
                if baseline_frame is not None:
                    baseline_zone = baseline_tensor[0][zone_slice].unsqueeze(0)
                    
                    if metric == "ssim":
                        base_metric = ssim(orig_zone, baseline_zone, data_range=1, size_average=True).item()
                        # For SSIM, higher is better, so baseline - perturbed shows how much we lost
                        zone_metrics[zone_name] = base_metric - pred_metric
                    else:
                        # TV Loss
                        tv_loss = torch.abs(baseline_zone[:,:,1:,:] - baseline_zone[:,:,:-1,:]).sum() + \
                                 torch.abs(baseline_zone[:,:,:,1:] - baseline_zone[:,:,:,:-1]).sum()
                        base_metric = tv_loss.item() / (baseline_zone.shape[2] * baseline_zone.shape[3])
                        # For TV loss, lower is better, so perturbed - baseline shows how much we lost
                        # (positive value means perturbation worsened TV loss)
                        zone_metrics[zone_name] = pred_metric - base_metric
                else:
                    # No baseline, just use the prediction metric
                    zone_metrics[zone_name] = pred_metric
        
        return zone_metrics
    
    # Helper function for normalizing images for display
    def normalize(img):
        """Normalize image for display"""
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
        
        img = img.copy()
        if img.min() < 0:
            img = (img + 1) / 2  # [-1, 1] -> [0, 1]
        return np.clip(img, 0, 1)
    
    # Initialize metrics
    all_zone_metrics = {}  # Will store metrics for each prediction key
    
    # Get appropriate metric name
    if baseline_predictions is not None:
        metric_description = "Difference from Baseline" 
        if metric == "ssim":
            metric_name = "SSIM Difference"
        else:
            metric_name = "TV Loss Difference"
    else:
        metric_name = "SSIM" if metric == "ssim" else "TV Loss"
        metric_description = metric_name
    
    # Create a TV loss calculator if needed
    if metric == "tv":
        tv_calculator = TotalVariation().to(device)
    
    # ===== PART 1: Calculate zone metrics for each prediction =====
    for pred_key, video_key in key_mapping.items():
        prediction = predictions[pred_key]
        video = videos[video_key][..., 15]  # Middle frame
        
        # Get baseline prediction if available
        baseline_pred = None
        if baseline_predictions is not None:
            # Find the appropriate key in baseline predictions
            baseline_keys = list(baseline_predictions.keys())
            if pred_key in baseline_predictions:
                baseline_pred = baseline_predictions[pred_key]
            elif len(baseline_keys) > 0:
                # If exact key not found, use first available baseline key
                baseline_pred = baseline_predictions[baseline_keys[0]]
                print(f"Using {baseline_keys[0]} as baseline for {pred_key}")
        
        # Check shapes
        print(f"Prediction {pred_key} shape: {prediction.shape}")
        print(f"Video {video_key} shape: {video.shape}")
        if baseline_pred is not None:
            print(f"Baseline prediction shape: {baseline_pred.shape}")
        
        # Ensure prediction and video have compatible shapes
        if prediction.shape[0] != video.shape[0]:
            print(f"Warning: Shape mismatch for {pred_key} vs {video_key}. Skipping.")
            continue
        
        # If baseline exists, ensure it has compatible shape
        if baseline_pred is not None and baseline_pred.shape[0] != prediction.shape[0]:
            print(f"Warning: Baseline shape {baseline_pred.shape} doesn't match prediction shape {prediction.shape}. Ignoring baseline.")
            baseline_pred = None
        
        N = video.shape[0]
        
        # Store metrics for all frames in this prediction
        pred_metrics = []
        
        # Calculate metrics for each frame
        for i in range(N):
            # Get zones for this frame
            zones = split_into_zones(video[i], zone_type=zone_type)
            
            # Calculate metrics for each zone
            try:
                # If baseline exists, calculate difference metrics
                if baseline_pred is not None:
                    zone_metrics = calculate_zone_metrics(
                        video[i], prediction[i], baseline_frame=baseline_pred[i], zones=zones, metric=metric
                    )
                else:
                    zone_metrics = calculate_zone_metrics(
                        video[i], prediction[i], zones=zones, metric=metric
                    )
                pred_metrics.append(zone_metrics)
            except Exception as e:
                print(f"Error calculating zone metrics for {pred_key}, frame {i}: {e}")
                # Create empty metrics
                if zone_type == "quadrants":
                    pred_metrics.append({
                        "top_left": 0, "top_right": 0, 
                        "bottom_left": 0, "bottom_right": 0
                    })
                elif zone_type == "center_bg":
                    pred_metrics.append({"center": 0, "background": 0})
                elif isinstance(zone_type, int):
                    empty_metrics = {}
                    for ii in range(zone_type):
                        for jj in range(zone_type):
                            empty_metrics[f"grid_{ii}_{jj}"] = 0
                    pred_metrics.append(empty_metrics)
        
        # Store metrics for this prediction
        all_zone_metrics[pred_key] = pred_metrics
        
        # Print average metrics for this prediction
        print(f"\nAverage {metric_description} for {pred_key} by zone:")
        
        # Calculate and print mean metrics across frames for each zone
        if len(pred_metrics) > 0:
            zone_names = list(pred_metrics[0].keys())
            
            for zone in zone_names:
                zone_values = [metrics[zone] for metrics in pred_metrics]
                mean_zone = np.mean(zone_values)
                print(f"  - {zone}: {mean_zone:.4f}")
    
    # ===== PART 2: Plot the original image, baseline, reconstruction, and heatmap side by side =====
    if display_plots and len(key_mapping) > 0:
        # Get a reference video key and shape
        ref_video_key = list(videos.keys())[0]
        ref_video = videos[ref_video_key][..., 15]
        N = ref_video.shape[0]
        
        # Determine the frames to plot
        if max_frames is not None and max_frames < N:
            # Evenly sample frames if max_frames is specified
            indices = np.linspace(0, N-1, max_frames, dtype=int)
        else:
            # Plot all frames
            indices = np.arange(N)
        
        # Determine number of panels based on whether baseline is available
        if baseline_predictions is not None:
            num_panels = 4  # Original, Baseline, Perturbed, Difference
            print(f"\nPlotting {len(indices)} frames with Original, Baseline, Perturbed, and Difference")
        else:
            num_panels = 3  # Original, Reconstruction, Heatmap
            print(f"\nPlotting {len(indices)} frames with Original, Reconstruction, and Heatmap")
        
        # REORDERING: Create an ordered list of prediction keys with "original_combined" first
        ordered_keys = []
        for key in key_mapping.keys():
            if key != "original_combined":
                ordered_keys.append(key)
        
        # If "original_combined" exists, insert it at the beginning of the list
        if "original_combined" in key_mapping:
            ordered_keys.insert(0, "original_combined")
        
        # For each frame index
        for frame_idx in indices:
            # Plot original reference frame
            ref_frame = ref_video[frame_idx]
            
            # For each prediction
            for pred_key in ordered_keys:
                try:
                    video_key = key_mapping[pred_key]
                    perturbed_frame = predictions[pred_key][frame_idx]
                    
                    # Get baseline frame if available
                    baseline_frame = None
                    if baseline_predictions is not None:
                        if pred_key in baseline_predictions:
                            baseline_frame = baseline_predictions[pred_key][frame_idx]
                        elif len(baseline_predictions) > 0:
                            # Use first available baseline prediction
                            first_key = list(baseline_predictions.keys())[0]
                            baseline_frame = baseline_predictions[first_key][frame_idx]
                    
                    # Get zone metrics for this prediction
                    zone_metrics = all_zone_metrics[pred_key][frame_idx]
                    
                    # Determine if we're showing differences or absolute values
                    is_difference = baseline_frame is not None
                    
                    # Create a figure with the appropriate number of subplots
                    fig, axes = plt.subplots(1, num_panels, figsize=(5 * num_panels, 5))
                    
                    # 1. Original Frame (leftmost)
                    axes[0].imshow(np.transpose(normalize(ref_frame), (1, 2, 0)))
                    axes[0].set_title(f"Original Frame {frame_idx}")
                    axes[0].axis('off')
                    
                    # Panel index for reconstruction and heatmap depends on whether baseline exists
                    recon_idx = 2 if is_difference else 1
                    heatmap_idx = 3 if is_difference else 2
                    
                    # 2. Baseline Reconstruction (if available)
                    if is_difference:
                        axes[1].imshow(np.transpose(normalize(baseline_frame), (1, 2, 0)))
                        axes[1].set_title(f"Baseline Reconstruction")
                        axes[1].axis('off')
                    
                    # 3. Perturbed Reconstruction 
                    axes[recon_idx].imshow(np.transpose(normalize(perturbed_frame), (1, 2, 0)))
                    if is_difference:
                        axes[recon_idx].set_title(f"Perturbed Reconstruction")
                    else:
                        axes[recon_idx].set_title(f"{pred_key} (Frame {frame_idx})")
                    axes[recon_idx].axis('off')
                    
                    # 4. Heatmap of metrics (rightmost)
                    # Determine colormap based on if we're showing differences
                    cmap_name = 'coolwarm' if is_difference else 'viridis'
                    
                    # Also determine normalization based on if we're showing differences
                    if is_difference:
                        # For differences, use symmetric normalization around zero
                        values = list(zone_metrics.values())
                        max_abs = max(abs(min(values)), abs(max(values))) if values else 1.0
                        norm = Normalize(vmin=-max_abs, vmax=max_abs)
                    else:
                        # For absolute values, use standard normalization
                        norm = None  # Let matplotlib handle it
                    
                    if isinstance(zone_type, int):
                        # For n×n grid, create a grid to display metrics
                        grid_values = np.zeros((zone_type, zone_type))
                        
                        for i in range(zone_type):
                            for j in range(zone_type):
                                zone_name = f"grid_{i}_{j}"
                                grid_values[i, j] = zone_metrics.get(zone_name, 0)
                        
                        # Show the heatmap
                        im = axes[heatmap_idx].imshow(grid_values, cmap=cmap_name, interpolation='nearest', norm=norm)
                        
                        # Add text labels with adjustable font size
                        fontsize = max(6, min(10, 16 - zone_type))  # Scale font size based on grid density
                        for i in range(zone_type):
                            for j in range(zone_type):
                                # Format the value based on magnitude
                                val = grid_values[i, j]
                                if abs(val) >= 0.01:
                                    text = f"{val:.3f}"
                                else:
                                    text = f"{val:.1e}"
                                    
                                axes[heatmap_idx].text(j, i, text,
                                           ha="center", va="center", color="white",
                                           fontsize=fontsize, fontweight='bold')
                        
                        # Add colorbar
                        cbar = plt.colorbar(im, ax=axes[heatmap_idx])
                        cbar.set_label(metric_name)
                        
                        axes[heatmap_idx].set_title(f"Zone {metric_description}")
                        
                    elif zone_type == "quadrants":
                        # For quadrants, create a 2×2 heatmap
                        quadrant_values = np.zeros((2, 2))
                        quadrant_values[0, 0] = zone_metrics.get("top_left", 0)
                        quadrant_values[0, 1] = zone_metrics.get("top_right", 0)
                        quadrant_values[1, 0] = zone_metrics.get("bottom_left", 0)
                        quadrant_values[1, 1] = zone_metrics.get("bottom_right", 0)
                        
                        # Show the heatmap
                        im = axes[heatmap_idx].imshow(quadrant_values, cmap=cmap_name, interpolation='nearest', norm=norm)
                        
                        # Add text labels
                        for i, j in [(0,0), (0,1), (1,0), (1,1)]:
                            val = quadrant_values[i, j]
                            if abs(val) >= 0.01:
                                text = f"{val:.3f}"
                            else:
                                text = f"{val:.1e}"
                            axes[heatmap_idx].text(j, i, text, ha="center", va="center", color="white", fontsize=10)
                        
                        # Add colorbar
                        cbar = plt.colorbar(im, ax=axes[heatmap_idx])
                        cbar.set_label(metric_name)
                        
                        axes[heatmap_idx].set_title(f"Quadrant {metric_description}")
                        
                    elif zone_type == "center_bg":
                        # For center/background, special visualization
                        center_val = zone_metrics.get("center", 0)
                        bg_val = zone_metrics.get("background", 0)
                        
                        # Create a mask-based visualization
                        mask = np.zeros((3, 3), dtype=bool)
                        mask[1, 1] = True  # Center is True, background is False
                        
                        # Create values array where center has one value, background another
                        values = np.ones((3, 3)) * bg_val
                        values[1, 1] = center_val
                        
                        # Show the heatmap
                        im = axes[heatmap_idx].imshow(values, cmap=cmap_name, interpolation='nearest', norm=norm)
                        
                        # Format values for display
                        if abs(center_val) >= 0.01:
                            center_text = f"Center\n{center_val:.3f}"
                        else:
                            center_text = f"Center\n{center_val:.1e}"
                            
                        if abs(bg_val) >= 0.01:
                            bg_text = f"BG\n{bg_val:.3f}"
                        else:
                            bg_text = f"BG\n{bg_val:.1e}"
                        
                        # Add text labels
                        axes[heatmap_idx].text(1, 1, center_text, ha="center", va="center", color="white", fontsize=10)
                        axes[heatmap_idx].text(0, 0, bg_text, ha="center", va="center", color="white", fontsize=10)
                        
                        # Add colorbar
                        cbar = plt.colorbar(im, ax=axes[heatmap_idx])
                        cbar.set_label(metric_name)
                        
                        axes[heatmap_idx].set_title(f"Center/Background {metric_description}")
                    
                    plt.tight_layout()
                    
                    # Save figure if requested
                    if save_plots and save_path_prefix:
                        zone_type_str = zone_type if isinstance(zone_type, str) else f"grid{zone_type}x{zone_type}"
                        type_str = "diff" if is_difference else "abs"
                        fig_path = f"{save_path_prefix}{pred_key}_frame{frame_idx}_{metric}_{zone_type_str}_{type_str}.png"
                        plt.savefig(fig_path, bbox_inches='tight', dpi=300)
                    
                    plt.show()
                    
                except Exception as e:
                    print(f"Error creating visualization for {pred_key} (Frame {frame_idx}): {e}")
    
    # Calculate overall mean metric
    overall_mean = 0
    if len(all_zone_metrics) > 0:
        # Average across all predictions and all zones
        all_values = []
        for pred_metrics in all_zone_metrics.values():
            for metrics in pred_metrics:
                all_values.extend(list(metrics.values()))
        
        if all_values:
            overall_mean = np.mean(all_values)
    
    # Update performance dictionary if provided
    if performance_dict is not None:
        try:
            # Calculate overall metrics across all zones
            for pred_key, pred_metrics in all_zone_metrics.items():
                if len(pred_metrics) > 0:
                    zone_names = list(pred_metrics[0].keys())
                    
                    for zone in zone_names:
                        zone_values = [metrics[zone] for metrics in pred_metrics]
                        zone_mean = np.mean(zone_values)
                        zone_median = np.median(zone_values)
                        
                        # Add to performance dict
                        if baseline_predictions is not None:
                            # This is a difference metric
                            if metric == "ssim":
                                performance_dict[f'diff_ssim_{zone}_D'] = zone_mean
                                performance_dict[f'median_diff_ssim_{zone}_D'] = zone_median
                            else:
                                performance_dict[f'diff_tv_{zone}_D'] = zone_mean
                                performance_dict[f'median_diff_tv_{zone}_D'] = zone_median
                        else:
                            # This is an absolute metric
                            if metric == "ssim":
                                performance_dict[f'mean_ssim_{zone}_D'] = zone_mean
                                performance_dict[f'median_ssim_{zone}_D'] = zone_median
                            else:
                                performance_dict[f'mean_tv_{zone}_D'] = zone_mean
                                performance_dict[f'median_tv_{zone}_D'] = zone_median
            
            # Add overall mean metric
            if baseline_predictions is not None:
                # This is a difference metric
                if metric == "ssim":
                    performance_dict['diff_ssim_D'] = overall_mean
                else:
                    performance_dict['diff_tv_D'] = overall_mean
            else:
                # This is an absolute metric
                if metric == "ssim":
                    performance_dict['mean_ssim_D'] = overall_mean
                else:
                    performance_dict['mean_tv_D'] = overall_mean
                
        except Exception as e:
            print(f"Error updating performance dictionary: {e}")
    
    if mean_flag:
        return overall_mean
    
    return performance_dict



def _resolve_label_key(key, labels_dict):
    """
    Return the key of ``labels_dict`` whose labels should be paired with input ``key``.

    Resolution order:
      1. ``key`` itself, when it is a key of ``labels_dict``;
      2. the part of ``key`` after its last underscore (e.g. "slice_0_Payload" -> "Payload"),
         when that is a key of ``labels_dict``;
      3. otherwise the first key of ``labels_dict``. This fallback may pair the input with the
         wrong movie, so the chosen key is always printed.
    """
    if key in labels_dict:
        return key

    if '_' in key:
        extracted_movie = key.split('_')[-1]
        if extracted_movie in labels_dict:
            print(f"Input key '{key}' not found in labels. Extracted and using '{extracted_movie}' for labels.")
            return extracted_movie
        fallback_key = list(labels_dict.keys())[0]
        print(f"Input key '{key}' and extracted movie '{extracted_movie}' not found in labels. Using '{fallback_key}' for labels.")
        return fallback_key

    fallback_key = list(labels_dict.keys())[0]
    print(f"Input key '{key}' not found in labels and no movie name could be extracted. Using '{fallback_key}' for labels.")
    return fallback_key


def test_model_all(inputs_dict, labels_dict, model, criterion, device, pretrained_decoder=None, model_to_test=None, 
               statistical_testing=False, display_plots=True, save_plots=False, model_name="", metric="ssim", 
               mean_flag=False, zones=None, baseline_predictions=None):
    """
    Test the pretrained model on the provided dataset, plot its performance and collect predictions.

    Arguments:
        inputs_dict (dict): Dictionary of input data. Keys are movie names or slice identifiers.
                           If model_to_test is 'encoder' or 'encoder_decoder', then elements have a shape of (TR, 3, 112, 112, 32).
                           Else, shape of (TR, mask_size).
        labels_dict (dict): Dictionary of labels. Keys are movie names. Each input key is paired with a
                           label key by exact match, then by the suffix after its last underscore
                           (e.g. "slice_0_Payload" -> "Payload"), then by falling back to the first label key.
                           If model_to_test is 'encoder' or 'encoder_decoder', then elements have a shape of (TR, mask_size).
                           Else, shape of (TR, 3, 112, 112, 32).
        model (nn.Module): The pretrained neural network model to be tested.
        criterion (nn.Module): Loss function for testing. Its return value is unpacked so that the
                               second-to-last element is the total loss (a tensor; ``.item()`` is taken).
        device (torch.device): Device to test the model on (CPU or GPU).
        pretrained_decoder (str, optional): Path to a state dict for a ``Decoder`` applied to the model's
                                            outputs. Default is None.
        model_to_test (str): Which part of the model to test: 'encoder', 'decoder' or 'encoder_decoder'.
                             Any other value prints an error and returns ``(None, None)``.
        statistical_testing (bool, optional): Run ``one_sample_permutation_test`` on the encoder
                                              predictions (non-'decoder' runs only). Default is False.
        display_plots (bool, optional): Whether to display plots. Default is True.
        save_plots (bool, optional): Whether to save plots under ``outputs/`` (created if missing). Default is False.
        model_name (str, optional): Name of the model used in saved plot file names. Default is "".
        metric (str, optional): Metric forwarded to the reconstruction plotting function. Default is "ssim".
        mean_flag (bool, optional): Forwarded to the reconstruction plotting function on the 'decoder' path.
                                    Default is False.
        zones (str or int, optional): Zone layout for the zone-wise analysis on the 'decoder' path.
                                      Default is None, which (without baseline_predictions) uses the non-zoned
                                      ``plot_all_predictions5``. Otherwise "quadrants", "center_bg", or an
                                      integer n for an n x n grid (e.g. 4 -> 16 zones), handled by
                                      ``plot_all_predictions7``.
        baseline_predictions (dict, optional): Baseline predictions forwarded to ``plot_all_predictions7``
                                               for comparison.

    Returns:
        results (dict): 
            - '<model_to_test>_predictions' when a single stage is tested without a pretrained decoder.
              Otherwise 'encoder_predictions' and 'decoder_predictions'. Each is keyed like inputs_dict
              and holds numpy arrays.
            - 'total_losses': per-key numpy array of per-batch total losses.
            - 'test_performance': value returned by the last plotting call.
            - 'decoder_saliency': only for 'encoder_decoder'.
        (None, None) when model_to_test is not recognized.

    Raises:
        ValueError: if ``inputs_dict`` or ``labels_dict`` is empty.
    """
    print('Start testing:')
    tic = time.time()

    # Create outputs directory if it doesn't exist
    if save_plots:
        import os
        os.makedirs('outputs', exist_ok=True)

    model_type = ['encoder', 'decoder', 'encoder_decoder']
    if model_to_test not in model_type:
        print(f'model_to_test: {model_to_test} not recognized. Must be one of {model_type}')
        return None, None

    # Fail with a clear message instead of an IndexError further down.
    if not inputs_dict:
        raise ValueError("inputs_dict is empty: there is nothing to test.")
    if not labels_dict:
        raise ValueError("labels_dict is empty: there are no labels to compare predictions against.")

    first_input = inputs_dict[next(iter(inputs_dict))]
    inputs_shape = list(first_input.shape)
    inputs_shape[0] = 'TR'
    print(f'### Testing {model_to_test} on inputs of shape {inputs_shape} over {len(inputs_dict)} videos/slices ###')

    if baseline_predictions is not None:
        print('### Using baseline predictions for comparison ###')

    criterion = criterion.to(device)
    # Set model in testing phase
    model.to(device)
    model.eval()

    # Load a frozen pretrained decoder, applied on top of the model's outputs, if specified.
    decoder = None
    if pretrained_decoder:
        decoder = Decoder(labels_dict[next(iter(labels_dict))].shape[1])  # Assuming shape is consistent across labels
        decoder.load_state_dict(torch.load(pretrained_decoder))
        decoder.to(device)
        for param in decoder.parameters():
            param.requires_grad = False
        decoder.eval()
        print(f'Also using pretrained decoder {pretrained_decoder}')

    # A single stage without a pretrained decoder stores its outputs under '<stage>_predictions';
    # otherwise both the encoder (first-stage) and decoder outputs are kept.
    single_stage = model_to_test != 'encoder_decoder' and pretrained_decoder is None
    if single_stage:
        results = {
            model_to_test + '_predictions': {},
            'total_losses': {}
        }
    else:
        results = {
            'encoder_predictions': {},
            'decoder_predictions': {},
            'total_losses': {}
        }
    results['test_performance'] = {}

    # Performance returned by plot_metrics for the most recently processed key (non-decoder runs).
    last_encoder_performance = None

    for key, inputs in inputs_dict.items():
        label_key = _resolve_label_key(key, labels_dict)
        input_tensor = torch.from_numpy(inputs.astype('float32'))
        label_tensor = torch.from_numpy(labels_dict[label_key].astype('float32'))

        print(f"input_tensor shape: {input_tensor.shape}, dtype: {input_tensor.dtype}")
        print(f"label_tensor shape: {label_tensor.shape}, dtype: {label_tensor.dtype}")

        test_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(input_tensor, label_tensor),
            batch_size=16,
            shuffle=False,
            pin_memory=torch.cuda.is_available(),
            num_workers=4,
        )

        model_outputs, decoder_outputs, total_losses = [], [], []
        with torch.no_grad():
            for batch_input, batch_label in test_loader:
                batch_input, batch_label = batch_input.to(device), batch_label.to(device)

                decoder_output = None
                if model_to_test == 'encoder_decoder':
                    model_output, decoder_output = model(batch_input.float())
                elif decoder is not None:
                    model_output = model(batch_input.float()).to(device)
                    decoder_output = decoder(model_output.float())
                else:
                    model_output = model(batch_input.float())

                model_outputs.append(model_output.detach().cpu())
                if decoder_output is not None:
                    decoder_outputs.append(decoder_output.detach().cpu())

                # Apply the appropriate criterion based on the presence of decoder outputs;
                # frame index 15 is used as the middle frame of each clip.
                if model_to_test == 'decoder':
                    *_, total_loss, _ = criterion(model_output, batch_label[..., 15])
                elif decoder_output is None:
                    *_, total_loss, _ = criterion(model_output, batch_label)
                else:
                    *_, total_loss, _ = criterion(model_output, batch_label, decoder_output, batch_input[..., 15])

                total_losses.append(total_loss.item())

        # Store the outputs in results
        if single_stage:
            results[model_to_test + '_predictions'][key] = torch.cat(model_outputs, dim=0).numpy()
        else:
            results['encoder_predictions'][key] = torch.cat(model_outputs, dim=0).numpy()
            results['decoder_predictions'][key] = torch.cat(decoder_outputs, dim=0).numpy()

        results['total_losses'][key] = np.asarray(total_losses)

        if model_to_test != 'decoder':
            last_encoder_performance = plot_metrics(
                labels_dict[label_key], results['encoder_predictions'][key], key,
                plot_TR=False, performance_dict=None,
                display_plots=display_plots,
                save_plots=save_plots,
                save_path=f'outputs/{key}_{model_name}.png' if save_plots else None,
            )

    if model_to_test != 'decoder':
        # Only the last processed key's metrics are reported (as before), but without
        # re-plotting/re-saving that key a second time.
        results['test_performance'] = last_encoder_performance

        if statistical_testing:
            # Only keys present both as labels and as predictions take part in the test.
            matched_keys = [k for k in labels_dict if k in results['encoder_predictions']]
            if matched_keys:
                all_predictions = np.concatenate([results['encoder_predictions'][k] for k in matched_keys], axis=0)
                all_labels = np.concatenate([labels_dict[k] for k in matched_keys], axis=0)
                one_sample_permutation_test(all_labels, all_predictions)
            else:
                print('Statistical testing skipped: no input key matches a label key.')

    if model_to_test != 'encoder' or pretrained_decoder is not None:
        plot_kwargs = {
            'save_plots': save_plots,
            'save_path_prefix': 'outputs/' if save_plots else None,
            'model_name': model_name,
            'metric': metric,
        }
        if model_to_test == 'decoder':
            print("\n\n\n ALRIGHT ZONES =", zones, "\n\n\n")
            # Decoder predictions are compared against the labels.
            ground_truth = labels_dict
            plot_kwargs['mean_flag'] = mean_flag
            # The zone-wise plot is used for a baseline comparison or an explicit zone layout.
            use_zone_plot = baseline_predictions is not None or zones is not None
            if use_zone_plot:
                plot_kwargs['zone_type'] = zones
        else:
            # For encoder (+ pretrained decoder) or encoder_decoder, use inputs_dict for ground truth.
            # NOTE(review): zones and mean_flag are not forwarded on this path - confirm this is intended.
            ground_truth = inputs_dict
            use_zone_plot = baseline_predictions is not None
        if baseline_predictions is not None:
            plot_kwargs['baseline_predictions'] = baseline_predictions

        plot_fn = plot_all_predictions7 if use_zone_plot else plot_all_predictions5
        results['test_performance'] = plot_fn(
            results['decoder_predictions'],
            ground_truth,
            results['test_performance'],
            display_plots,
            **plot_kwargs,
        )
    print("using new function")

    if model_to_test == 'encoder_decoder':
        # Accumulate decoder saliency over every predicted fMRI (one TR at a time).
        first_label = labels_dict[next(iter(labels_dict))]
        decoder_saliency = np.zeros(first_label.shape[1])
        with torch.enable_grad():
            for key in inputs_dict:
                predicted_fMRIs = torch.from_numpy(results['encoder_predictions'][key])
                ground_truth_frames = torch.from_numpy(inputs_dict[key][..., 15])
                for i in range(predicted_fMRIs.shape[0]):
                    decoder_saliency += compute_saliency(
                        model.decoder, predicted_fMRIs[i:i + 1], ground_truth_frames[i:i + 1], device
                    )

        if display_plots:
            plot_saliency_distribution(decoder_saliency)
        results['decoder_saliency'] = decoder_saliency

    print("Testing completed. Total time: {:.2f} minutes".format((time.time() - tic) / 60))
    print('---')
    return results

def visualize_blocks_3(data_3d, blocks, losses, num_blocks=(3, 3, 3), figsize=None, colormap='viridis', 
                   title_prefix="Brain Divided into", is_difference=False, save_plots=True):
    """
    Visualize per-block loss (or difference) values on slabs of a 3D brain volume.

    The volume is split along its third axis into ``nz`` slabs, with one subplot per slab.
    In each slab, an (x, y) position counts as brain if any voxel in the slab is > 0.
    Inside that footprint, every block is painted with the colour of its value and labelled
    with its block ID and value. The block with the largest absolute value is outlined in red.

    Parameters:
    -----------
    data_3d : numpy.ndarray
        3D brain data; voxels > 0 are treated as brain.
    blocks : dict
        Dictionary mapping block IDs to ((x_min, x_max), (y_min, y_max), (z_min, z_max)).
        Only the unique sorted boundaries along each axis are used. Ranges are treated as
        half-open [min, max), and each axis presumably has exactly n + 1 unique boundaries.
    losses : array-like
        One loss or difference value per block. ``losses[i]`` belongs to block ``i + 1``.
        Block IDs follow x-fastest ordering: id = 1 + x + y*nx + z*nx*ny.
    num_blocks : tuple
        Number of blocks along each dimension (x, y, z).
    figsize : tuple, optional
        Figure size. If None, it is (6 * nz, 6).
    colormap : str
        Matplotlib colormap name ('viridis' suits absolute values, 'coolwarm' suits differences).
        With ``is_difference`` and 'coolwarm', the colour scale is symmetric around zero.
    title_prefix : str
        Prefix for the figure title.
    is_difference : bool
        If True, values are treated as differences from baseline.
    save_plots : bool
        If True (default, matching the previous always-save behaviour), save the figure to
        ``brain_blocks_{nx}x{ny}x{nz}_{differences|losses}_colormap.png`` in the working
        directory before it is shown.

    Returns:
    --------
    dict
        The ``blocks`` argument, unchanged.

    Raises:
    -------
    ValueError
        If the number of values in ``losses`` is not nx * ny * nz.
    """
    import numpy as np

    # Unpack the number of blocks in each dimension
    nx, ny, nz = num_blocks
    losses = np.array(losses)

    print("Loss/difference values for each block:")
    for block_number, value in enumerate(losses, start=1):
        print(f"  Block {block_number}: {value:.5f}")

    total_blocks = nx * ny * nz
    if len(losses) != total_blocks:
        raise ValueError(f"Expected {total_blocks} loss values, but got {len(losses)}")

    # Plotting modules are imported only once the arguments are known to be valid.
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    import matplotlib.colors as mcolors

    # Block with the highest absolute loss/difference (1-based block ID)
    max_abs_loss_idx = np.argmax(np.abs(losses))
    selected_block = max_abs_loss_idx + 1
    max_abs_loss_value = losses[max_abs_loss_idx]

    if is_difference:
        print(f"Block {selected_block} has the highest absolute difference from baseline: {max_abs_loss_value:.5f}")
    else:
        print(f"Block {selected_block} has the highest absolute loss: {max_abs_loss_value:.5f}")

    if figsize is None:
        figsize = (6 * nz, 6)

    # One subplot per z-layer; keep axes indexable even for a single layer.
    fig, axes = plt.subplots(1, nz, figsize=figsize)
    if nz == 1:
        axes = [axes]

    title_text = f"{title_prefix} {nx}x{ny}x{nz} Blocks"
    if is_difference:
        title_text += " with Differences from Baseline"
    else:
        title_text += " with Loss Values"
    fig.suptitle(title_text, fontsize=16)

    # Unique, sorted block boundaries along each axis
    x_bounds, y_bounds, z_bounds = set(), set(), set()
    for (x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi) in blocks.values():
        x_bounds.update((x_lo, x_hi))
        y_bounds.update((y_lo, y_hi))
        z_bounds.update((z_lo, z_hi))
    x_divisions = sorted(x_bounds)
    y_divisions = sorted(y_bounds)
    z_divisions = sorted(z_bounds)

    print("X divisions:", x_divisions)
    print("Y divisions:", y_divisions)
    print("Z divisions:", z_divisions)

    # Colour normalization: symmetric around zero for differences shown with 'coolwarm'
    if is_difference and colormap == 'coolwarm':
        max_abs = max(abs(np.min(losses)), abs(np.max(losses)))
        norm = mcolors.Normalize(vmin=-max_abs, vmax=max_abs)
    else:
        norm = mcolors.Normalize(vmin=np.min(losses), vmax=np.max(losses))

    # pyplot.get_cmap replaces matplotlib.cm.get_cmap (deprecated in 3.7, removed in 3.9).
    cmap = plt.get_cmap(colormap)
    text_box = dict(facecolor='black', alpha=0.7, boxstyle='round')

    for z_idx in range(nz):
        ax = axes[z_idx]
        z_min, z_max = z_divisions[z_idx], z_divisions[z_idx + 1]

        # Binary footprint of the slab: 1 where any voxel along z in [z_min, z_max) is > 0
        layer_projection = np.max(data_3d[:, :, z_min:z_max] > 0, axis=2).astype(float)
        print(f"Max layer_projection for layer {z_idx}=", np.max(layer_projection))
        brain_mask = layer_projection > 0

        # Black background
        ax.imshow(np.zeros_like(layer_projection), cmap='gray', origin='lower')
        ax.set_title(f"Layer {z_idx} (z={z_min}-{z_max-1})")

        # A single RGBA overlay per slab, transparent everywhere except painted brain voxels
        overlay = np.zeros((*layer_projection.shape, 4))
        painted = False

        for y_idx in range(ny):
            for x_idx in range(nx):
                block_id = 1 + x_idx + y_idx * nx + z_idx * nx * ny
                loss_value = losses[block_id - 1]

                x_min, x_max = x_divisions[x_idx], x_divisions[x_idx + 1]
                y_min, y_max = y_divisions[y_idx], y_divisions[y_idx + 1]

                # Colour only the brain voxels inside this block (writes through the slice view)
                block_brain = brain_mask[x_min:x_max, y_min:y_max]
                if np.any(block_brain):
                    overlay[x_min:x_max, y_min:y_max][block_brain] = cmap(norm(loss_value))
                    painted = True

                if is_difference:
                    # Differences: fixed-point for larger magnitudes, scientific for tiny ones
                    if abs(loss_value) >= 0.01:
                        value_str = f"{loss_value:.4f}"
                    else:
                        value_str = f"{loss_value:.2e}"
                else:
                    value_str = f"{loss_value:.4f}"

                # Array rows (x) are the vertical axis and columns (y) the horizontal one,
                # so plot coordinates are (y, x).
                ax.text((y_min + y_max) / 2, (x_min + x_max) / 2, f"{block_id}\n{value_str}",
                        ha="center", va="center", color='white', fontweight='bold',
                        fontsize=10, bbox=text_box)

                if block_id == selected_block:
                    ax.add_patch(patches.Rectangle(
                        (y_min, x_min), y_max - y_min, x_max - x_min,
                        fill=False, edgecolor='red', linewidth=2
                    ))

        if painted:
            ax.imshow(overlay, origin='lower', interpolation='nearest')

    # Colorbar mapping values to colours, in its own axis below the brain images
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cax = fig.add_axes([0.15, 0.05, 0.7, 0.02])  # [left, bottom, width, height]
    cbar = fig.colorbar(sm, cax=cax, orientation='horizontal')
    cbar.set_label('Difference from Baseline' if is_difference else 'Loss Value')

    # Leave room for the colorbar
    plt.tight_layout(rect=[0, 0.1, 1, 0.95])

    # Save BEFORE showing: plt.show() may close the figure, which would leave savefig with a blank canvas.
    if save_plots:
        suffix = "differences" if is_difference else "losses"
        fig.savefig(f'brain_blocks_{nx}x{ny}x{nz}_{suffix}_colormap.png', dpi=300, bbox_inches='tight')
    plt.show()

    return blocks



def calculate_baseline_losses(model_name='decoder_4609_350', test_input=None, test_label=None, 
                           all_frames=True, save_plots=False, metric="tv", zones=4):
    """
    Calculate baseline reconstruction losses without any perturbation.

    Parameters:
    -----------
    model_name : str
        Path of the decoder state_dict file to load.
    test_input : dict
        Dictionary with fMRIs for testing.
    test_label : dict
        Dictionary with films.
    all_frames : bool
        If True, runs test_model_all, otherwise runs test_model.
    save_plots : bool
        If True, saves plots.
    metric : str
        Metric to use for evaluation: "ssim" or "tv" (Total Variation).
        Only forwarded when all_frames is True.
    zones : str or int
        Zone configuration forwarded to test_model_all (only used when all_frames is True).

    Returns:
    --------
    dict or None
        ``test_model_all(...)['test_performance']`` (computed with mean_flag=False)
        when all_frames is True. Otherwise None: test_model is run only for its
        printed/plotted output.
    """
    print("Calculating baseline losses with no perturbation")

    # Load the model. map_location='cpu' keeps loading working when the device
    # below falls back to CPU and the checkpoint was saved from a GPU;
    # load_state_dict copies the tensors into the (CPU-resident) model either way.
    model = Decoder(mask_size)
    state_dict = torch.load(model_name, map_location='cpu')
    model.load_state_dict(state_dict)

    # Testing parameters
    criterion = D_Loss()
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    pretrained_decoder = None
    model_to_test = 'decoder'
    statistical_testing = False
    display_plots = True
    baseline_name = model_name + '_baseline'

    if not all_frames:
        test_model(
            test_input,
            test_label,
            model,
            criterion,
            device,
            pretrained_decoder,
            model_to_test,
            statistical_testing,
            display_plots,
            save_plots,
            model_name=baseline_name,
        )
        return None

    # Run test_model_all without perturbation to get baseline performance
    baseline_results = test_model_all(
        test_input,
        test_label,
        model,
        criterion,
        device,
        pretrained_decoder,
        model_to_test,
        statistical_testing,
        display_plots,
        save_plots,
        model_name=baseline_name,
        metric=metric,
        mean_flag=False,
        zones=zones,
    )
    print("Baseline test completed")
    return baseline_results['test_performance']
    


def _print_baseline_comparison(baseline_performance, block_performance, main_metric_key):
    """Print every ``mean_*`` metric of a perturbed run next to its baseline value and the difference.

    The main metric is printed first, followed by the other ``mean_*`` keys in the
    baseline's key order. Keys missing from either dict are skipped.
    """
    print("\nComparison with baseline:")
    ordered_keys = [main_metric_key] + [
        k for k in baseline_performance.keys()
        if k.startswith('mean_') and k != main_metric_key
    ]
    for key in ordered_keys:
        if key in baseline_performance and key in block_performance:
            baseline_val = baseline_performance[key]
            block_val = block_performance[key]
            diff = block_val - baseline_val
            print(f"  {key}: {block_val:.5f} (baseline: {baseline_val:.5f}, diff: {diff:.5f})")


def test_new_decoder(real=True, model_name='decoder_4609_1650', test_on_train=False, test_input=testset['fMRIs'], 
                     test_label=testset['videos'], add_name='', regions=None, block_id=None, save_plots=False, all_frames=False,
                     change_mode='off', num_blocks=None, metric="ssim", zones=None, compare_to_baseline=True):
    '''
    Tests the decoder, optionally perturbing the input fMRI data first.

    The modes are checked in this order:
      * ``num_blocks`` and ``block_id`` given: turn off that single 3D brain block
        (via ``turn_off_block_new``) and test once.
      * ``num_blocks`` given, ``block_id`` None: turn off every block in turn, test each
        one, rank the blocks by impact, visualise them and pickle the summary.
      * ``num_blocks`` None: optionally turn off ``regions`` (via ``turn_off_regions``)
        and test once.

    Parameters
    -----------
    real : bool
        Currently not used by this function (kept for backward compatibility).
    model_name : str
        Name of the model file to be used.
    test_on_train : bool
        If True, tests on (up to) 30 random samples of the training set instead.
    test_input : dict
        Dictionary with fMRIs for testing (one entry per film).
    test_label : dict
        Dictionary with films (one entry per film).
    add_name : str
        String to add to the end of output name to avoid overwriting.
    regions : list or None
        Region IDs to turn off (legacy parameter; only used when ``num_blocks`` is None).
        None means no regions.
    block_id : int or None
        1-based ID of the 3D block to turn off.
    save_plots : bool
        If True, saves plots.
    all_frames : bool
        If True, runs test_model_all, otherwise runs test_model.
    change_mode : str
        'off' or 'amplify'. NOTE: 'amplify' is not implemented yet; both values apply
        the same turn-off perturbation.
    num_blocks : tuple of 3 ints or None
        Block grid used to divide the brain; the number of blocks is the product.
    metric : str
        Metric forwarded to test_model_all (e.g. "ssim" or "tv"); the compared key is
        f'mean_{metric}_D'.
    zones (str or int, optional): Zones to consider for testing. Default is "quadrants", can also be "center_bg".
                                  If it is an integer, the function will analyze a number of zones = that integer squared.
                                  For example, if zones = 4, the function will analyze 16 zones (4x4).
    compare_to_baseline : bool
        If True (and all_frames is True), compute baseline performance without
        perturbation and report differences. Only used when num_blocks is given.

    Returns
    -------
    The test_model_all result for a single run with all_frames=True; the summary dict
    when looping over all blocks (regardless of all_frames); otherwise None.
    '''
    import numpy as np
    import pickle

    if regions is None:
        regions = []

    print("zones = ", zones, "\n")

    # Load the model; map_location lets a GPU-saved checkpoint load on a CPU-only host.
    model = Decoder(mask_size)
    state_dict = torch.load(model_name, map_location='cpu')
    model.load_state_dict(state_dict)

    # Testing parameters
    criterion = D_Loss()
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    pretrained_decoder = None
    model_to_test = 'decoder'
    statistical_testing = False
    display_plots = True

    # Handle special case for training data
    if test_on_train:
        num_samples = trainset["fMRIs"].shape[0]
        # replace=False cannot draw more samples than exist
        random_indices = np.random.choice(num_samples, size=min(30, num_samples), replace=False)
        test_input = {"test": trainset["fMRIs"][random_indices]}
        test_label = {"test": trainset["videos"][random_indices]}

    if num_blocks is not None:
        main_metric_key = f'mean_{metric}_D'

        # Baseline (no perturbation). Only computed in all_frames mode, since
        # test_model() returns nothing that could be compared against.
        baseline_performance = None
        if compare_to_baseline and all_frames:
            print("Computing baseline performance with no perturbation...")
            baseline_results = test_model_all(
                test_input,
                test_label,
                model,
                criterion,
                device,
                pretrained_decoder,
                model_to_test,
                statistical_testing,
                display_plots,
                save_plots,
                model_name=model_name + '_baseline',
                metric=metric,
                mean_flag=False,
                zones=zones,
            )
            baseline_performance = baseline_results['test_performance']

            baseline_file = f'baseline_{model_name}_{metric}_zones{zones}.pkl'
            with open(baseline_file, 'wb') as f:
                pickle.dump(baseline_performance, f)
            print(f"Baseline performance saved to {baseline_file}")

            print("Baseline performance:")
            for k, v in baseline_performance.items():
                if isinstance(v, (int, float)):
                    print(f"  {k}: {v:.5f}")

        # Load the 3D brain data and cut to the tightest box around the brain
        regions_3d = load_and_reshape_data('region_ids_4609+.npy')
        (x_min, x_max), (y_min, y_max), (z_min, z_max) = find_bounds(regions_3d)
        brain_data = regions_3d[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1]
        print(f"Original shape: {regions_3d.shape}, Brain shape: {brain_data.shape}")

        blocks = divide_brain_into_blocks(brain_data, num_blocks)

        mask4609 = np.load('mask_schaefer1000_4609.npy')
        flat_mask = mask4609.flatten()

        def perturb(fmri, bid):
            # 'amplify' is not implemented: both change modes turn the block off.
            return turn_off_block_new(fmri, flat_mask, bid, blocks, x_min, y_min, z_min)

        if block_id is not None:  # A single block is specified
            visualize_blocks(brain_data, blocks, num_blocks=num_blocks, selected_block=block_id)

            modified_input = {}
            for video_name, fmri in test_input.items():
                print(f"Turning off block {block_id} for video {video_name}")
                modified_input[video_name] = perturb(fmri, block_id)

            if not all_frames:
                test_model(
                    modified_input,
                    test_label,
                    model,
                    criterion,
                    device,
                    pretrained_decoder,
                    model_to_test,
                    statistical_testing,
                    display_plots,
                    save_plots,
                    model_name=model_name + add_name,
                )
                return None

            # Same mean_flag/zones as the baseline so their metric keys line up.
            results = test_model_all(
                modified_input,
                test_label,
                model,
                criterion,
                device,
                pretrained_decoder,
                model_to_test,
                statistical_testing,
                display_plots,
                save_plots,
                model_name=model_name + add_name,
                metric=metric,
                mean_flag=False,
                zones=zones,
            )
            if baseline_performance is not None:
                _print_baseline_comparison(baseline_performance, results['test_performance'], main_metric_key)
            return results

        # Loop through all blocks
        losses = []
        loss_diffs = []  # Differences from baseline
        block_results = {}  # Full performance dict per block

        total_blocks = num_blocks[0] * num_blocks[1] * num_blocks[2]
        for block_id in range(1, total_blocks + 1):
            print(f"\nProcessing block {block_id}...")
            modified_input = {
                video_name: perturb(fmri, block_id)
                for video_name, fmri in test_input.items()
            }

            if not all_frames:
                test_model(
                    modified_input,
                    test_label,
                    model,
                    criterion,
                    device,
                    pretrained_decoder,
                    model_to_test,
                    statistical_testing,
                    display_plots,
                    save_plots,
                    model_name=model_name + f'_block_{block_id}',
                )
                continue

            results = test_model_all(
                modified_input,
                test_label,
                model,
                criterion,
                device,
                pretrained_decoder,
                model_to_test,
                statistical_testing,
                display_plots,
                save_plots,
                model_name=model_name + f'_block_{block_id}',
                metric=metric,
                mean_flag=False,  # Return full performance dict
                zones=zones,
            )

            block_performance = results['test_performance']
            block_results[block_id] = block_performance

            mean_loss = block_performance.get(main_metric_key, 0)
            print(f"Block {block_id}, mean loss = {mean_loss:.5f}")
            losses.append(mean_loss)

            if baseline_performance is not None and main_metric_key in baseline_performance:
                diff = mean_loss - baseline_performance[main_metric_key]
                loss_diffs.append(diff)
                print(f"  Difference from baseline: {diff:.5f}")
            else:
                loss_diffs.append(0)  # Default if baseline not available

        use_diffs = compare_to_baseline and baseline_performance is not None
        lossiest_block = None
        if not losses:
            # test_model() records no per-block metrics, so there is nothing to rank or plot.
            print("No per-block metrics were collected (all_frames=False); skipping impact ranking.")
        else:
            if use_diffs:
                lossiest = int(np.argmax(np.abs(loss_diffs)))
                impact_val = loss_diffs[lossiest]
            else:
                lossiest = int(np.argmax(losses))
                impact_val = losses[lossiest]

            lossiest_block = lossiest + 1  # Convert to 1-based indexing
            print(f"\nBlock with the biggest impact: {lossiest_block} (value: {impact_val:.5f})")

            if use_diffs:
                print("Visualizing differences from baseline...")
                visualize_blocks_3(
                    brain_data,
                    blocks,
                    loss_diffs,
                    num_blocks=num_blocks,
                    colormap='coolwarm',  # Diverging map for positive/negative differences
                    is_difference=True,
                )
            else:
                print("Visualizing absolute loss values...")
                visualize_blocks_3(
                    brain_data,
                    blocks,
                    losses,
                    num_blocks=num_blocks,
                    colormap='viridis',  # For absolute values
                )

        results_data = {
            'losses': losses,
            'baseline_performance': baseline_performance,
            'loss_differences': loss_diffs if baseline_performance is not None else None,
            'block_with_biggest_impact': lossiest_block,
            'block_results': block_results,
        }

        results_file = f'block_analysis_{model_name}_{metric}_zones{zones}.pkl'
        with open(results_file, 'wb') as f:
            pickle.dump(results_data, f)
        print(f"Results saved to {results_file}")

        return results_data

    # Legacy path: turn off specific regions (only when no block grid is given)
    if regions:
        test_input = {
            video_name: turn_off_regions(fmri, regions)
            for video_name, fmri in test_input.items()
        }

    print("test input shape =", print_dict_tree(test_input))
    print("test label shape =", print_dict_tree(test_label))

    if all_frames:
        return test_model_all(
            test_input,
            test_label,
            model,
            criterion,
            device,
            pretrained_decoder,
            model_to_test,
            statistical_testing,
            display_plots,
            save_plots,
            model_name=model_name + add_name,
            metric=metric,
            mean_flag=True,
            zones=zones,
        )

    test_model(
        test_input,
        test_label,
        model,
        criterion,
        device,
        pretrained_decoder,
        model_to_test,
        statistical_testing,
        display_plots,
        save_plots,
        model_name=model_name + add_name,
    )
    return None

# NOTE: disabled legacy wrapper kept inside a string literal so it never runs;
# the active run_all_blocks_test is defined below.
'''
def run_all_blocks_test():
    """
    Run test with all blocks analysis, comparing to baseline performance
    """
    # Call the test_new_decoder function with baseline comparison
    results = test_new_decoder(
        real=True,
        model_name='decoder_4609_350',
        test_input=filtered_trainset['fMRIs'],
        test_label=filtered_trainset['videos'],
        all_frames=True,
        save_plots=False,
        add_name='_block_analysis',
        change_mode='off',
        num_blocks=(2, 2, 2),
        metric="tv",
        zones=4,
        compare_to_baseline=True  # Enable baseline comparison
    )
    
    return results

# Call the function to run the analysis
results = run_all_blocks_test()
'''







def _roi_voxel_indices(brain_mask, roi_mask):
    """
    Return the column indices of the masked fMRI data that lie inside an ROI mask.

    Assumes the fMRI columns follow the flattened order of the non-zero voxels of
    ``brain_mask`` (TODO confirm against the preprocessing). The returned indices
    are positions within that ordering whose voxel is > 0 in both masks.
    """
    intersection = (brain_mask > 0) & (roi_mask > 0)
    brain_indices = np.where(brain_mask.flatten())[0]
    roi_indices = np.where(intersection.flatten())[0]
    return np.where(np.isin(brain_indices, roi_indices))[0]


def run_all_blocks_test(model_name='decoder_4609_350', test_input=None, test_label=None,
                        num_blocks=(1, 1, 2), metric="tv", zones=4, save_plots=False):
    """
    Run a perturbation analysis and compare each perturbed run to an unperturbed baseline.

    Despite the "blocks" naming, each block ID currently ablates an ROI mask rather
    than a spatial brain block. Block 1 zeroes the voxels of 'enhanced_union_FFA.npy',
    and every other block zeroes the voxels of 'resampled_ppa.npy'. The
    turn_off_block_new path is disabled. The spatial blocks are still computed so that
    the per-block results can be drawn with visualize_blocks_3.

    Parameters
    ----------
    model_name : str
        Path of the decoder state_dict to load.
    test_input, test_label : dict or None
        Per-video fMRI / video dictionaries. They default to
        filtered_trainset['fMRIs'] and filtered_trainset['videos'].
    num_blocks : tuple of 3 ints
        Block grid; the number of perturbed runs is the product of its entries.
    metric : str
        Metric forwarded to test_model_all; the compared key is f'mean_{metric}_D'.
    zones : str or int
        Zone configuration forwarded to test_model_all.
    save_plots : bool
        Forwarded to the perturbed test_model_all runs (the baseline never saves plots).

    Returns
    -------
    dict
        The summary, also pickled to 'block_analysis_{model_name}_{metric}_zones{zones}.pkl'.
        Keys: baseline_performance, block_losses, loss_differences, max_impact_block,
        max_impact_value, block_results (per-block test_performance dicts).
    """
    import pickle

    if test_input is None:
        test_input = filtered_trainset['fMRIs']
    if test_label is None:
        test_label = filtered_trainset['videos']

    # Step 1: Calculate baseline reconstructions (no perturbation)
    print("=== STEP 1: Calculating baseline reconstructions (no perturbation) ===")

    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    model = Decoder(mask_size)
    state_dict = torch.load(model_name, map_location='cpu')
    model.load_state_dict(state_dict)

    criterion = D_Loss()
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    pretrained_decoder = None
    model_to_test = 'decoder'
    statistical_testing = False

    baseline_results = test_model_all(
        test_input,
        test_label,
        model,
        criterion,
        device,
        pretrained_decoder,
        model_to_test,
        statistical_testing,
        display_plots=False,  # Don't display plots for baseline
        save_plots=False,
        model_name=model_name + '_baseline',
        metric=metric,
        mean_flag=False,
        zones=zones,
    )
    baseline_performance = baseline_results['test_performance']
    baseline_predictions = baseline_results['decoder_predictions']

    print("Baseline performance metrics:")
    for k, v in baseline_performance.items():
        if isinstance(v, (int, float)):
            print(f"  {k}: {v:.5f}")

    # Step 2: Run the perturbation analysis for each block
    print("\n=== STEP 2: Running perturbation analysis for each block ===")

    regions_3d = load_and_reshape_data('region_ids_4609+.npy')
    (x_min, x_max), (y_min, y_max), (z_min, z_max) = find_bounds(regions_3d)
    brain_data = regions_3d[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1]
    print(f"Original shape: {regions_3d.shape}, Brain shape: {brain_data.shape}")

    blocks = divide_brain_into_blocks(brain_data, num_blocks)

    # Masks are read once; ROI indices are cached per file instead of being
    # recomputed for every video of every block.
    brain_mask = np.load('mask_schaefer1000_4609.npy')
    n_brain_voxels = np.count_nonzero(brain_mask)
    roi_index_cache = {}

    block_losses = []
    loss_differences = []
    block_results = {}
    main_metric_key = f'mean_{metric}_D'

    total_blocks = num_blocks[0] * num_blocks[1] * num_blocks[2]
    for block_id in range(1, total_blocks + 1):
        print(f"\nProcessing block {block_id}/{total_blocks}")

        if block_id == 1:
            print("\n\nDOING FFA NOW\n\n")
            roi_file = 'enhanced_union_FFA.npy'
        else:
            print("\n\nDOING PPA NOW\n\n")
            roi_file = 'resampled_ppa.npy'
        if roi_file not in roi_index_cache:
            roi_index_cache[roi_file] = _roi_voxel_indices(brain_mask, np.load(roi_file))
        voxels_to_zero = roi_index_cache[roi_file]
        print(f"Zeroing out {len(voxels_to_zero)} voxels out of {n_brain_voxels} total "
              f"({100*len(voxels_to_zero)/n_brain_voxels:.1f}%)")

        modified_input = {}
        for video_name, fmri in test_input.items():
            modified_data = fmri.copy()
            print("\n\nmodified_data shape =", modified_data.shape, "\n\n")
            modified_data[:, voxels_to_zero] = 0
            modified_input[video_name] = modified_data

        # Pass baseline predictions so test_model_all can compare frame by frame
        perturbed_results = test_model_all(
            modified_input,
            test_label,
            model,
            criterion,
            device,
            pretrained_decoder,
            model_to_test,
            statistical_testing,
            display_plots=True,  # Show plots with difference visualization
            save_plots=save_plots,
            model_name=model_name + f'_block_{block_id}',
            metric=metric,
            mean_flag=False,
            zones=zones,
            baseline_predictions=baseline_predictions,
        )

        block_performance = perturbed_results['test_performance']
        block_results[block_id] = block_performance
        if main_metric_key in block_performance:
            mean_loss = block_performance[main_metric_key]
        else:
            # Fall back to the first key that looks like a mean of this metric
            metric_keys = [k for k in block_performance.keys() if metric in k.lower() and 'mean' in k.lower()]
            mean_loss = block_performance[metric_keys[0]] if metric_keys else 0

        print(f"Block {block_id} - Mean loss: {mean_loss:.5f}")
        block_losses.append(mean_loss)

        if main_metric_key in baseline_performance:
            diff = mean_loss - baseline_performance[main_metric_key]
            loss_differences.append(diff)
            print(f"  Difference from baseline: {diff:.5f}")
        else:
            loss_differences.append(0)

    # Step 3: Visualize block-level results
    print("\n=== STEP 3: Visualizing block-level impact analysis ===")

    max_diff_idx = np.argmax(np.abs(loss_differences))
    max_diff_block = max_diff_idx + 1
    max_diff_value = loss_differences[max_diff_idx]
    print(f"Block with biggest impact: {max_diff_block} (difference: {max_diff_value:.5f})")

    visualize_blocks_3(
        brain_data,
        blocks,
        loss_differences,
        num_blocks=num_blocks,
        colormap='coolwarm',  # Diverging map for positive/negative differences
        is_difference=True,
    )

    results_data = {
        'baseline_performance': baseline_performance,
        'block_losses': block_losses,
        'loss_differences': loss_differences,
        'max_impact_block': max_diff_block,
        'max_impact_value': max_diff_value,
        'block_results': block_results,
    }

    results_file = f'block_analysis_{model_name}_{metric}_zones{zones}.pkl'
    with open(results_file, 'wb') as f:
        pickle.dump(results_data, f)
    print(f"Results saved to {results_file}")

    return results_data

# Call the function to run the analysis.
# NOTE: this executes immediately whenever this cell/module runs (model load,
# baseline run, then one perturbed run per block); the summary dict is kept in `results`.
results = run_all_blocks_test()


# NOTE: unlike run_all_blocks_test, this debug variant does not visualize the blocks.
def _mean_metric_keys(performance, metric):
    """
    Return the keys of a performance dict that look like a mean of `metric`.

    A key matches when, compared case-insensitively, it contains 'mean' and
    'tv' (when metric == "tv") or 'ssim' (for any other metric). Keys are
    returned in the dict's iteration order.
    """
    token = 'tv' if metric == "tv" else 'ssim'
    return [k for k in performance.keys() if token in k.lower() and 'mean' in k.lower()]


def run_all_blocks_test_debug(*, model_name='decoder_4609_350', dataset=None,
                              num_blocks=(2, 2, 2), metric="tv", zones=4):
    """
    Run the all-blocks perturbation analysis with extensive debugging output,
    to find out why all blocks report the same difference value.

    No block visualisation is produced here; per-block differences from the
    baseline are computed manually and printed twice as a cross-check.

    Parameters (keyword-only; the defaults reproduce the original setup):
    - model_name: path of the saved decoder state_dict; also used as a label
      prefix passed to test_model_all.
    - dataset: dict with 'fMRIs' and 'videos' entries. Defaults to the
      notebook-global `filtered_trainset`, looked up at call time.
    - num_blocks: number of blocks along each of the three brain axes.
    - metric: "tv", or any other value (matched as SSIM when searching for
      fallback metric keys).
    - zones: forwarded to test_model_all.

    Returns:
    - dict with the baseline performance, the baseline value, the metric key
      used, per-block losses, per-block differences from baseline and the raw
      per-block performance dicts (keyed by 1-based block id).
    - None when no suitable metric key exists in the baseline performance.
    """
    # Step 1: Calculate baseline reconstructions (no perturbation)
    print("=== STEP 1: Calculating baseline reconstructions (no perturbation) ===")

    if dataset is None:
        dataset = filtered_trainset  # notebook global, resolved at call time
    test_input = dataset['fMRIs']
    test_label = dataset['videos']

    # Load the model. map_location='cpu' lets a checkpoint saved from a GPU be
    # loaded on any host; load_state_dict copies the tensors into the model anyway.
    model = Decoder(mask_size)
    state_dict = torch.load(model_name, map_location='cpu')
    model.load_state_dict(state_dict)

    # Testing parameters
    criterion = D_Loss()
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    pretrained_decoder = None
    model_to_test = 'decoder'
    statistical_testing = False

    # Generate baseline predictions (no plots for the baseline)
    baseline_results = test_model_all(
        test_input,
        test_label,
        model,
        criterion,
        device,
        pretrained_decoder,
        model_to_test,
        statistical_testing,
        display_plots=False,
        save_plots=False,
        model_name=model_name + '_baseline',
        metric=metric,
        mean_flag=False,
        zones=zones
    )

    baseline_performance = baseline_results['test_performance']

    print("\n==== DEBUG: Baseline performance metrics ====")
    for k, v in baseline_performance.items():
        if isinstance(v, (int, float)):
            print(f"  {k}: {v}")

    # Candidate "mean <metric>" keys, used as fallback if the expected key is absent
    main_keys = _mean_metric_keys(baseline_performance, metric)
    print(f"\n==== DEBUG: Available metric keys: {main_keys} ====")

    # Determine the main metric key to use
    requested_key = f'mean_{metric}_D'
    main_metric_key = requested_key
    if main_metric_key not in baseline_performance:
        if not main_keys:
            print("ERROR: No suitable metric keys found in baseline performance!")
            return None
        main_metric_key = main_keys[0]
        # Report the key that was missing as well as its replacement.
        print(f"Main metric key '{requested_key}' not found. Using '{main_metric_key}' instead.")

    baseline_val = baseline_performance[main_metric_key]
    print(f"Using main metric key: {main_metric_key} with baseline value: {baseline_val}")

    # Step 2: Run the perturbation analysis for each block
    print("\n=== STEP 2: Running perturbation analysis for each block ===")

    # Load 3D brain data and crop it to its bounding box
    regions_3d = load_and_reshape_data('region_ids_4609+.npy')
    (x_min, x_max), (y_min, y_max), (z_min, z_max) = find_bounds(regions_3d)
    brain_data = regions_3d[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1]
    print(f"Original shape: {regions_3d.shape}, Brain shape: {brain_data.shape}")

    # Divide into blocks
    blocks = divide_brain_into_blocks(brain_data, num_blocks)

    # Prepare mask for block manipulation
    mask4609 = np.load('mask_schaefer1000_4609.npy')
    flat_mask = mask4609.flatten()

    # Containers for results
    block_losses = []
    loss_differences = []
    block_results_dict = {}

    # Loop through every block (block ids are 1-based)
    total_blocks = num_blocks[0] * num_blocks[1] * num_blocks[2]
    for block_id in range(1, total_blocks + 1):
        print(f"\n==== PROCESSING BLOCK {block_id}/{total_blocks} ====")

        # Create modified input with the current block turned off
        modified_input = {}
        for video_name in test_input.keys():
            modified_input[video_name] = turn_off_block_new(
                test_input[video_name],
                flat_mask,
                block_id,
                blocks,
                x_min,
                y_min,
                z_min
            )

        # Run test with this block turned off
        block_results = test_model_all(
            modified_input,
            test_label,
            model,
            criterion,
            device,
            pretrained_decoder,
            model_to_test,
            statistical_testing,
            display_plots=False,  # Don't show plots during debugging
            save_plots=False,
            model_name=model_name + f'_block_{block_id}',
            metric=metric,
            mean_flag=False,
            zones=zones,
            baseline_predictions=None  # Differences are computed manually below
        )

        print(f"\n==== DEBUG: Block {block_id} Results ====")
        block_performance = block_results['test_performance']
        block_results_dict[block_id] = block_performance

        # Show all scalar metrics in block performance
        for k, v in block_performance.items():
            if isinstance(v, (int, float)):
                print(f"  {k}: {v}")

        # Calculate actual block mean loss value
        if main_metric_key in block_performance:
            mean_loss = block_performance[main_metric_key]
        else:
            print(f"WARNING: Main metric key '{main_metric_key}' not found in block_performance!")
            similar_keys = _mean_metric_keys(block_performance, metric)
            if similar_keys:
                mean_loss = block_performance[similar_keys[0]]
                print(f"Using alternative key: {similar_keys[0]}")
            else:
                mean_loss = 0
                print("No similar keys found! Using 0.")

        # Manually calculate the difference
        diff = mean_loss - baseline_val

        print(f"BLOCK {block_id}:")
        print(f"  Raw loss value: {mean_loss}")
        print(f"  Baseline value: {baseline_val}")
        print(f"  Difference: {diff}")

        block_losses.append(mean_loss)
        loss_differences.append(diff)

    # Print all the collected values
    print("\n==== FINAL DEBUG: All collected values ====")
    print("Block Losses:")
    for i, loss in enumerate(block_losses):
        print(f"  Block {i+1}: {loss}")

    print("\nDifferences from baseline:")
    for i, diff in enumerate(loss_differences):
        print(f"  Block {i+1}: {diff}")

    # Recompute the differences independently as a cross-check
    print("\nRecalculating differences to double-check:")
    for i, loss in enumerate(block_losses):
        recalc_diff = loss - baseline_val
        print(f"  Block {i+1}: Loss={loss}, Baseline={baseline_val}, Diff={recalc_diff}")

    # Find block with biggest absolute impact
    max_diff_idx = np.argmax(np.abs(loss_differences))
    max_diff_block = max_diff_idx + 1
    max_diff_value = loss_differences[max_diff_idx]

    print(f"\nBlock with biggest impact: {max_diff_block} (difference: {max_diff_value})")

    # Return the debug data for inspection
    debug_data = {
        'baseline_performance': baseline_performance,
        'baseline_value': baseline_val,
        'main_metric_key': main_metric_key,
        'block_losses': block_losses,
        'loss_differences': loss_differences,
        'block_results': block_results_dict
    }

    return debug_data

# Run the debug function
#debug_results = run_all_blocks_test_debug()

#tag blocks with original and perturbed reconstruction

In [3]:
# make dataset

# Dataset 6661 holds preprocessed data with a (4609,) mask; dataset 6660 uses a (15364,) mask.
dataset_ID = 6661
# Number of voxels per fMRI sample, derived from dataset_ID so the two settings cannot drift apart.
mask_size = {6661: 4609, 6660: 15364}[dataset_ID]
trainset, valset, testset = get_dataset(dataset_ID, mask_size) # data are loaded into dictionaries


def extract_frames(testset, specific_frames):
    """
    Extract specific frames from testset and stack them into a single array.

    Parameters:
    - testset: Dictionary with 'fMRIs' and 'videos' keys, each containing subdictionaries for each movie
    - specific_frames: Dictionary mapping movie names to lists of frame indices to extract.
      For each movie only indices with 0 <= index < min(len(fMRIs), len(videos)) are used;
      missing movies and out-of-range indices are reported and skipped.

    Returns:
    - filtered_data: Dictionary with 'fMRIs' and 'videos' keys, each holding {'combined': array}
      where the selected frames are stacked along a new first axis, in the order given by
      specific_frames. Returns None if no valid frame was selected.
    """
    selected_fmris = []
    selected_videos = []

    for movie_name, frames in specific_frames.items():
        if movie_name not in testset['fMRIs'] or movie_name not in testset['videos']:
            print(f"Warning: {movie_name} not found in testset")
            continue

        movie_fmri = testset['fMRIs'][movie_name]
        movie_video = testset['videos'][movie_name]
        # The same index is used for both arrays, so bound it by the shorter one.
        # Negative indices are rejected instead of silently wrapping to the end.
        n_frames = min(len(movie_fmri), len(movie_video))

        for frame in frames:
            if not 0 <= frame < n_frames:
                print(f"Warning: Frame {frame} out of range for {movie_name} (max={n_frames-1})")
                continue
            selected_fmris.append(movie_fmri[frame])
            selected_videos.append(movie_video[frame])

    if not selected_fmris:
        print("No valid frames found")
        return None

    # Create filtered dataset with all frames in a single array
    filtered_data = {
        'fMRIs': {'combined': np.array(selected_fmris)},
        'videos': {'combined': np.array(selected_videos)}
    }

    # Print the structure of the filtered dataset
    print("Filtered dataset structure:")
    print("fMRIs")
    for movie, data in filtered_data['fMRIs'].items():
        print(f"  {movie} (shape: {data.shape})")
    print("videos")
    for movie, data in filtered_data['videos'].items():
        print(f"  {movie} (shape: {data.shape})")

    return filtered_data


def extract_frames_train(trainset, specific_frames_train):
    """
    Extract specific frames from trainset and stack them into a single array.

    Parameters:
    - trainset: Dictionary with 'fMRIs' and 'videos' keys, each holding an indexable sequence
      of samples (indexed directly by frame, not split per movie)
    - specific_frames_train: Iterable of frame indices to extract. Only indices with
      0 <= index < min(len(fMRIs), len(videos)) are used; others are reported and skipped.

    Returns:
    - filtered_data: Dictionary with 'fMRIs' and 'videos' keys, each holding {'combined': array}
      where the selected frames are stacked along a new first axis in the requested order.
      Returns None if no valid frame was selected.
    """
    selected_fmris = []
    selected_videos = []

    fmris = trainset['fMRIs']
    videos = trainset['videos']
    # The same index is used for both sequences, so bound it by the shorter one.
    # Negative indices are rejected instead of silently wrapping to the end.
    n_frames = min(len(fmris), len(videos))

    for frame in specific_frames_train:
        if not 0 <= frame < n_frames:
            print(f"Warning: Frame {frame} out of range (max={n_frames-1})")
            continue
        selected_fmris.append(fmris[frame])
        selected_videos.append(videos[frame])

    if not selected_fmris:
        print("No valid frames found")
        return None

    # Create filtered dataset with all frames in a single array
    filtered_data = {
        'fMRIs': {'combined': np.array(selected_fmris)},
        'videos': {'combined': np.array(selected_videos)}
    }

    # Print the structure of the filtered dataset
    print("Filtered dataset structure:")
    print("fMRIs")
    for movie, data in filtered_data['fMRIs'].items():
        print(f"  {movie} (shape: {data.shape})")
    print("videos")
    for movie, data in filtered_data['videos'].items():
        print(f"  {movie} (shape: {data.shape})")

    return filtered_data


def extract_frames_train_new(specific_frames_train, *,
                             fmri_path='processed_data/sub-S32/train.npy',
                             video_path='processed_data/videos/videos.npy'):
    """
    Extract specific frames from the training split of the new dataset.

    Parameters:
    - specific_frames_train: Iterable of frame indices to extract. Only indices with
      0 <= index < min(len(fMRIs), len(videos)) are used; others are reported and skipped.
    - fmri_path: Path of the .npy file with the training fMRI samples (indexed along axis 0).
    - video_path: Path of the .npy file with the matching video samples (indexed along axis 0).

    Returns:
    - filtered_data: Dictionary {'fMRIs': {'combined': list}, 'videos': {'combined': list}},
      where each list holds the selected samples (in-memory numpy arrays) in the requested
      order. The lists are empty if no valid frame was selected.
    """
    selected_fmris = []
    selected_videos = []

    # Memory-map the files so only the requested rows are read from disk instead of
    # loading both arrays in full; each selected row is copied out with np.array.
    fmris = np.load(fmri_path, mmap_mode='r')
    videos = np.load(video_path, mmap_mode='r')
    # The same index is used for both arrays, so bound it by the shorter one.
    # Negative indices are rejected instead of silently wrapping to the end.
    n_frames = min(len(fmris), len(videos))

    for frame in specific_frames_train:
        if not 0 <= frame < n_frames:
            print(f"Warning: Frame {frame} out of range (max={n_frames-1})")
            continue
        selected_fmris.append(np.array(fmris[frame]))
        selected_videos.append(np.array(videos[frame]))

    # Collect the selected frames (kept as lists, as before)
    filtered_data = {
        'fMRIs': {'combined': selected_fmris},
        'videos': {'combined': selected_videos}
    }

    # Print the structure of the filtered dataset
    print("Filtered dataset structure:")
    print("fMRIs")
    for movie, data in filtered_data['fMRIs'].items():
        print(f"  {movie} (shape: {len(data)})")
    print("videos")
    for movie, data in filtered_data['videos'].items():
        print(f"  {movie} (shape: {len(data)})")

    print(filtered_data['videos'])

    return filtered_data



# Example usage: hand-picked frame indices for each test movie.
specific_frames = {
    'AfterTheRain': [42],
    'BetweenViewings': [111],
    'Chatter': [21],
    'FirstBite': [33],  # TODO: choose a different frame for this movie
    'LessonLearned': [15, 36],
    'Payload': [18, 30],
    'Spaceman': [12],
    'TearsOfSteel': [39],
    'YouAgain': [300, 495],
}

# Hand-picked training-set frame indices (frame 2119 is a clear frame showing a face).
specific_frames_train = [681, 248, 3008, 1561, 1821, 2639, 467, 3558, 2173, 2119]

# Build the filtered test and train subsets.
filtered_testset = extract_frames(testset, specific_frames)
filtered_trainset = extract_frames_train(trainset, specific_frames_train)

# Memory-map the encoder training arrays read-only instead of loading them eagerly.
# NOTE(review): np.memmap is given no shape here, so both arrays come back 1-D
# (see the printed shapes); reshape them (e.g. to (-1, mask_size)) before indexing samples.
trainset2 = {
    array_name: np.memmap(f'encoder_dataset_{dataset_ID}/trainset/{array_name}.npy', dtype='float32', mode='r')
    for array_name in ('fMRIs', 'videos')
}

print("trainset fmris shape =", trainset2['fMRIs'].shape)
print("trainset videos shape =", trainset2['videos'].shape)

Filtered dataset structure:
fMRIs
  combined (shape: (12, 4609))
videos
  combined (shape: (12, 3, 112, 112, 32))
Filtered dataset structure:
fMRIs
  combined (shape: (10, 4609))
videos
  combined (shape: (10, 3, 112, 112, 32))
trainset fmris shape = (19915489,)
trainset videos shape = (5203451904,)


In [None]:
#tag debug


import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"

from dataset_new import *
from models_new_2 import *
from visualisation_new_2 import *
from perturbation import *


def plot_all_predictions7(predictions, videos, performance_dict=None, display_plots=True, save_plots=False, 
                 save_path_prefix=None, model_name="", device="cuda" if torch.cuda.is_available() else "cpu",
                 metric="ssim", mean_flag=False, zone_type="quadrants", max_frames=None, baseline_predictions=None):
    """
    Display comparison plots between original videos and predictions.
    Shows: Original image, baseline reconstruction, perturbed reconstruction, and difference heatmap.
    
    Parameters:
    -----------
    predictions : dict
        Dictionary of prediction arrays (perturbed reconstructions)
    videos : dict
        Dictionary of ground truth video arrays
    performance_dict : dict, optional
        Dictionary to store performance metrics
    display_plots : bool
        Whether to display the plots
    save_plots : bool
        Whether to save the plots
    save_path_prefix : str, optional
        Path prefix for saving plots
    model_name : str
        Name of the model for saving plots
    device : str
        Device to use for computations
    metric : str
        Metric to use for evaluation: "ssim" or "tv" (Total Variation)
    mean_flag : bool
        Whether to return mean metrics or not
    zone_type : str or int
        Type of zones: 
        - "quadrants" for 2×2 grid
        - "center_bg" for center and background
        - integer n for n×n grid (e.g., 4 creates a 4×4 grid with 16 zones)
    max_frames : int, optional
        Maximum number of frames to plot. If None, all frames will be plotted.
    baseline_predictions : dict, optional
        Dictionary of baseline prediction arrays (without perturbation)
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import torch
    import os
    from matplotlib.patches import Rectangle
    from matplotlib.colors import LinearSegmentedColormap, Normalize
    
    # Create output directory if it doesn't exist
    if save_plots and save_path_prefix:
        os.makedirs(save_path_prefix, exist_ok=True)

    # Debug information about inputs
    print("\n=== DEBUG INFO ===")
    print(f"Predictions dictionary contains {len(predictions)} keys: {list(predictions.keys())}")
    print(f"Videos dictionary contains {len(videos)} keys: {list(videos.keys())}")
    print(f"Using metric: {metric}")
    print(f"Baseline predictions provided: {baseline_predictions is not None}")
    
    if isinstance(zone_type, int):
        print(f"Using {zone_type}×{zone_type} grid zones ({zone_type*zone_type} total zones)")
    else:
        print(f"Using zone type: {zone_type}")
    
    # Find the overlapping keys between predictions and videos
    common_keys = [key for key in predictions.keys() if key in videos]
    print(f"Common keys in both dictionaries: {common_keys}")
    
    # Create key mapping between predictions and videos
    if not common_keys and len(videos) > 0:
        print("No common keys found. Trying to match prediction keys to video keys.")
        ref_video_key = list(videos.keys())[0]
        print(f"Using {ref_video_key} as reference video for all predictions")
        key_mapping = {pred_key: ref_video_key for pred_key in predictions.keys()}
    else:
        key_mapping = {}
        for pred_key in predictions.keys():
            if pred_key in videos:
                key_mapping[pred_key] = pred_key
            else:
                matched = False
                for video_key in videos.keys():
                    if video_key in pred_key:
                        key_mapping[pred_key] = video_key
                        matched = True
                        break
                if not matched and len(videos) > 0:
                    key_mapping[pred_key] = list(videos.keys())[0]
    
    print(f"Key mapping from prediction keys to video keys: {key_mapping}")
    
    # Helper function to split frame into zones
    def split_into_zones(frame, zone_type="quadrants", center_ratio=0.5):
        """
        Split a frame into zones.
        """
        if isinstance(frame, torch.Tensor):
            C, H, W = frame.shape
        else:
            C, H, W = frame.shape
            
        zones = {}
        
        if zone_type == "quadrants":
            # Split into 4 quadrants (2×2 grid)
            h_mid = H // 2
            w_mid = W // 2
            
            zones["top_left"] = (slice(None), slice(0, h_mid), slice(0, w_mid))
            zones["top_right"] = (slice(None), slice(0, h_mid), slice(w_mid, W))
            zones["bottom_left"] = (slice(None), slice(h_mid, H), slice(0, w_mid))
            zones["bottom_right"] = (slice(None), slice(h_mid, H), slice(w_mid, W))
            
        elif zone_type == "center_bg":
            # Split into center and background
            h_center = int(H * center_ratio)
            w_center = int(W * center_ratio)
            
            h_start = (H - h_center) // 2
            h_end = h_start + h_center
            w_start = (W - w_center) // 2
            w_end = w_start + w_center
            
            zones["center"] = (slice(None), slice(h_start, h_end), slice(w_start, w_end))
            
            # Background is everything except the center
            center_mask = np.zeros((H, W), dtype=bool)
            center_mask[h_start:h_end, w_start:w_end] = True
            
            zones["background"] = {"mask": ~center_mask, 
                                   "bounds": (h_start, h_end, w_start, w_end)}
            
        elif isinstance(zone_type, int) and zone_type > 0:
            # Create an n×n grid where n = zone_type
            n = zone_type
            
            # Calculate heights of each section
            h_sections = [i * H // n for i in range(n+1)]
            w_sections = [i * W // n for i in range(n+1)]
            
            # Create zones for each grid cell
            for i in range(n):
                for j in range(n):
                    zone_name = f"grid_{i}_{j}"  # Row_Column naming
                    zones[zone_name] = (
                        slice(None),
                        slice(h_sections[i], h_sections[i+1]),
                        slice(w_sections[j], w_sections[j+1])
                    )
                    
        else:
            raise ValueError(f"Unknown zone type: {zone_type}")
            
        return zones
    
    # Helper function to calculate zone metrics
    def calculate_zone_metrics(orig_frame, pred_frame, baseline_frame=None, zones=None, metric="ssim", device=device):
        """
        Calculate metrics for each zone.
        If baseline_frame is provided, calculate the difference: baseline_metrics - pred_metrics
        """
        from pytorch_msssim import ssim
        
        # If zones not provided, calculate them
        if zones is None:
            zones = split_into_zones(orig_frame, zone_type=zone_type)
        
        zone_metrics = {}
        
        # Convert to torch tensors if needed
        if not isinstance(orig_frame, torch.Tensor):
            orig_tensor = torch.from_numpy(orig_frame).unsqueeze(0)
        else:
            orig_tensor = orig_frame.unsqueeze(0)
            
        if not isinstance(pred_frame, torch.Tensor):
            pred_tensor = torch.from_numpy(pred_frame).unsqueeze(0)
        else:
            pred_tensor = pred_frame.unsqueeze(0)
        
        # Process baseline frame if provided    
        if baseline_frame is not None:
            if not isinstance(baseline_frame, torch.Tensor):
                baseline_tensor = torch.from_numpy(baseline_frame).unsqueeze(0)
            else:
                baseline_tensor = baseline_frame.unsqueeze(0)
        
        # Calculate metrics for each zone
        for zone_name, zone_slice in zones.items():
            # Special handling for background in center_bg mode
            if isinstance(zone_slice, dict):  # Background in center_bg mode
                mask = zone_slice["mask"]
                
                orig_zone = orig_tensor.clone()
                pred_zone = pred_tensor.clone()
                
                # Apply mask to all channels
                for c in range(orig_zone.shape[1]):  # For each channel
                    orig_zone[0, c][~mask] = 0
                    pred_zone[0, c][~mask] = 0
                
                # Calculate metric for prediction
                if metric == "ssim":
                    pred_metric = ssim(orig_zone, pred_zone, data_range=1, size_average=True).item()
                else:
                    # TV Loss calculation for masked region
                    tv_loss = torch.abs(pred_zone[:,:,1:,:] - pred_zone[:,:,:-1,:]).sum() + \
                              torch.abs(pred_zone[:,:,:,1:] - pred_zone[:,:,:,:-1]).sum()
                    # Normalize by number of pixels in the zone
                    pred_metric = tv_loss.item() / mask.sum()
                
                # If baseline provided, calculate baseline metric and difference
                if baseline_frame is not None:
                    baseline_zone = baseline_tensor.clone()
                    for c in range(baseline_zone.shape[1]):
                        baseline_zone[0, c][~mask] = 0
                        
                    if metric == "ssim":
                        base_metric = ssim(orig_zone, baseline_zone, data_range=1, size_average=True).item()
                        # For SSIM, higher is better, so baseline - perturbed shows how much we lost
                        # (negative value means perturbation improved SSIM)
                        zone_metrics[zone_name] = base_metric - pred_metric
                    else:
                        # TV Loss
                        tv_loss = torch.abs(baseline_zone[:,:,1:,:] - baseline_zone[:,:,:-1,:]).sum() + \
                                 torch.abs(baseline_zone[:,:,:,1:] - baseline_zone[:,:,:,:-1]).sum()
                        base_metric = tv_loss.item() / mask.sum()
                        # For TV loss, lower is better, so perturbed - baseline shows how much we lost
                        # (positive value means perturbation worsened TV loss)
                        zone_metrics[zone_name] = pred_metric - base_metric
                else:
                    # No baseline, just use the prediction metric
                    zone_metrics[zone_name] = pred_metric
                
            else:  # Normal zones
                # Get the zone data
                orig_zone = orig_tensor[0][zone_slice].unsqueeze(0)
                pred_zone = pred_tensor[0][zone_slice].unsqueeze(0)
                
                # Calculate metric for prediction
                if metric == "ssim":
                    pred_metric = ssim(orig_zone, pred_zone, data_range=1, size_average=True).item()
                else:
                    # TV Loss calculation
                    tv_loss = torch.abs(pred_zone[:,:,1:,:] - pred_zone[:,:,:-1,:]).sum() + \
                              torch.abs(pred_zone[:,:,:,1:] - pred_zone[:,:,:,:-1]).sum()
                    # Normalize by number of pixels in the zone
                    pred_metric = tv_loss.item() / (orig_zone.shape[2] * orig_zone.shape[3])
                
                # If baseline provided, calculate baseline metric and difference
                if baseline_frame is not None:
                    baseline_zone = baseline_tensor[0][zone_slice].unsqueeze(0)
                    
                    if metric == "ssim":
                        base_metric = ssim(orig_zone, baseline_zone, data_range=1, size_average=True).item()
                        # For SSIM, higher is better, so baseline - perturbed shows how much we lost
                        zone_metrics[zone_name] = base_metric - pred_metric
                    else:
                        # TV Loss
                        tv_loss = torch.abs(baseline_zone[:,:,1:,:] - baseline_zone[:,:,:-1,:]).sum() + \
                                 torch.abs(baseline_zone[:,:,:,1:] - baseline_zone[:,:,:,:-1]).sum()
                        base_metric = tv_loss.item() / (baseline_zone.shape[2] * baseline_zone.shape[3])
                        # For TV loss, lower is better, so perturbed - baseline shows how much we lost
                        # (positive value means perturbation worsened TV loss)
                        zone_metrics[zone_name] = pred_metric - base_metric
                else:
                    # No baseline, just use the prediction metric
                    zone_metrics[zone_name] = pred_metric
        
        return zone_metrics
    
    '''
    # Helper function for normalizing images for display
    def normalize(img):
        """Normalize image for display"""
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
        
        img = img.copy()
        if img.min() < 0:
            img = (img + 1) / 2  # [-1, 1] -> [0, 1]
        return np.clip(img, 0, 1)
    '''

    def normalize(img):
        """Normalize image for display"""
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
        
        img = img.copy().astype(np.float32)
        
        # Use actual min/max instead of assuming [-1,1]
        img_min, img_max = img.min(), img.max()
        if img_max > img_min:
            img = (img - img_min) / (img_max - img_min)
        else:
            img = np.zeros_like(img)
        
        return np.clip(img, 0, 1)
    
    # Initialize metrics
    all_zone_metrics = {}  # Will store metrics for each prediction key
    
    # Get appropriate metric name
    if baseline_predictions is not None:
        metric_description = "Difference from Baseline" 
        if metric == "ssim":
            metric_name = "SSIM Difference"
        else:
            metric_name = "TV Loss Difference"
    else:
        metric_name = "SSIM" if metric == "ssim" else "TV Loss"
        metric_description = metric_name
    
    # Create a TV loss calculator if needed
    if metric == "tv":
        tv_calculator = TotalVariation().to(device)
    
    # ===== PART 1: Calculate zone metrics for each prediction =====
    for pred_key, video_key in key_mapping.items():
        prediction = predictions[pred_key]
        video = videos[video_key][..., 15]  # Middle frame
        
        # Get baseline prediction if available
        baseline_pred = None
        if baseline_predictions is not None:
            # Find the appropriate key in baseline predictions
            baseline_keys = list(baseline_predictions.keys())
            if pred_key in baseline_predictions:
                baseline_pred = baseline_predictions[pred_key]
            elif len(baseline_keys) > 0:
                # If exact key not found, use first available baseline key
                baseline_pred = baseline_predictions[baseline_keys[0]]
                print(f"Using {baseline_keys[0]} as baseline for {pred_key}")
        
        # Check shapes
        print(f"Prediction {pred_key} shape: {prediction.shape}")
        print(f"Video {video_key} shape: {video.shape}")
        if baseline_pred is not None:
            print(f"Baseline prediction shape: {baseline_pred.shape}")
        
        # Ensure prediction and video have compatible shapes
        if prediction.shape[0] != video.shape[0]:
            print(f"Warning: Shape mismatch for {pred_key} vs {video_key}. Skipping.")
            continue
        
        # If baseline exists, ensure it has compatible shape
        if baseline_pred is not None and baseline_pred.shape[0] != prediction.shape[0]:
            print(f"Warning: Baseline shape {baseline_pred.shape} doesn't match prediction shape {prediction.shape}. Ignoring baseline.")
            baseline_pred = None
        
        N = video.shape[0]
        
        # Store metrics for all frames in this prediction
        pred_metrics = []
        
        # Calculate metrics for each frame
        for i in range(N):
            # Get zones for this frame
            zones = split_into_zones(video[i], zone_type=zone_type)
            
            # Calculate metrics for each zone
            try:
                # If baseline exists, calculate difference metrics
                if baseline_pred is not None:
                    zone_metrics = calculate_zone_metrics(
                        video[i], prediction[i], baseline_frame=baseline_pred[i], zones=zones, metric=metric
                    )
                else:
                    zone_metrics = calculate_zone_metrics(
                        video[i], prediction[i], zones=zones, metric=metric
                    )
                pred_metrics.append(zone_metrics)
            except Exception as e:
                print(f"Error calculating zone metrics for {pred_key}, frame {i}: {e}")
                # Create empty metrics
                if zone_type == "quadrants":
                    pred_metrics.append({
                        "top_left": 0, "top_right": 0, 
                        "bottom_left": 0, "bottom_right": 0
                    })
                elif zone_type == "center_bg":
                    pred_metrics.append({"center": 0, "background": 0})
                elif isinstance(zone_type, int):
                    empty_metrics = {}
                    for ii in range(zone_type):
                        for jj in range(zone_type):
                            empty_metrics[f"grid_{ii}_{jj}"] = 0
                    pred_metrics.append(empty_metrics)
        
        # Store metrics for this prediction
        all_zone_metrics[pred_key] = pred_metrics
        
        # Print average metrics for this prediction
        print(f"\nAverage {metric_description} for {pred_key} by zone:")
        
        # Calculate and print mean metrics across frames for each zone
        if len(pred_metrics) > 0:
            zone_names = list(pred_metrics[0].keys())
            
            for zone in zone_names:
                zone_values = [metrics[zone] for metrics in pred_metrics]
                mean_zone = np.mean(zone_values)
                print(f"  - {zone}: {mean_zone:.4f}")
    
    # ===== PART 2: Plot the original image, baseline, reconstruction, and heatmap side by side =====
    if display_plots and len(key_mapping) > 0:
        # Get a reference video key and shape
        ref_video_key = list(videos.keys())[0]
        ref_video = videos[ref_video_key][..., 15]
        N = ref_video.shape[0]
        
        # Determine the frames to plot
        if max_frames is not None and max_frames < N:
            # Evenly sample frames if max_frames is specified
            indices = np.linspace(0, N-1, max_frames, dtype=int)
        else:
            # Plot all frames
            indices = np.arange(N)
        
        # Determine number of panels based on whether baseline is available
        if baseline_predictions is not None:
            num_panels = 4  # Original, Baseline, Perturbed, Difference
            print(f"\nPlotting {len(indices)} frames with Original, Baseline, Perturbed, and Difference")
        else:
            num_panels = 3  # Original, Reconstruction, Heatmap
            print(f"\nPlotting {len(indices)} frames with Original, Reconstruction, and Heatmap")
        
        # REORDERING: Create an ordered list of prediction keys with "original_combined" first
        ordered_keys = []
        for key in key_mapping.keys():
            if key != "original_combined":
                ordered_keys.append(key)
        
        # If "original_combined" exists, insert it at the beginning of the list
        if "original_combined" in key_mapping:
            ordered_keys.insert(0, "original_combined")
        
        # For each frame index
        for frame_idx in indices:
            # Plot original reference frame
            ref_frame = ref_video[frame_idx]
            
            # For each prediction
            for pred_key in ordered_keys:
                try:
                    video_key = key_mapping[pred_key]
                    perturbed_frame = predictions[pred_key][frame_idx]
                    
                    # Get baseline frame if available
                    baseline_frame = None
                    if baseline_predictions is not None:
                        if pred_key in baseline_predictions:
                            baseline_frame = baseline_predictions[pred_key][frame_idx]
                        elif len(baseline_predictions) > 0:
                            # Use first available baseline prediction
                            first_key = list(baseline_predictions.keys())[0]
                            baseline_frame = baseline_predictions[first_key][frame_idx]
                    
                    # Get zone metrics for this prediction
                    zone_metrics = all_zone_metrics[pred_key][frame_idx]
                    
                    # Determine if we're showing differences or absolute values
                    is_difference = baseline_frame is not None
                    
                    # Create a figure with the appropriate number of subplots
                    fig, axes = plt.subplots(1, num_panels, figsize=(5 * num_panels, 5))
                    
                    # 1. Original Frame (leftmost)
                    axes[0].imshow(np.transpose(normalize(ref_frame), (1, 2, 0)))
                    axes[0].set_title(f"Original Frame {frame_idx}")
                    axes[0].axis('off')
                    
                    # Panel index for reconstruction and heatmap depends on whether baseline exists
                    recon_idx = 2 if is_difference else 1
                    heatmap_idx = 3 if is_difference else 2
                    
                    # 2. Baseline Reconstruction (if available)
                    if is_difference:
                        axes[1].imshow(np.transpose(normalize(baseline_frame), (1, 2, 0)))
                        axes[1].set_title(f"Baseline Reconstruction")
                        axes[1].axis('off')
                    
                    # 3. Perturbed Reconstruction 
                    axes[recon_idx].imshow(np.transpose(normalize(perturbed_frame), (1, 2, 0)))
                    if is_difference:
                        axes[recon_idx].set_title(f"Perturbed Reconstruction")
                    else:
                        axes[recon_idx].set_title(f"{pred_key} (Frame {frame_idx})")
                    axes[recon_idx].axis('off')
                    
                    # 4. Heatmap of metrics (rightmost)
                    # Determine colormap based on if we're showing differences
                    cmap_name = 'coolwarm' if is_difference else 'viridis'
                    
                    # Also determine normalization based on if we're showing differences
                    if is_difference:
                        # For differences, use symmetric normalization around zero
                        values = list(zone_metrics.values())
                        max_abs = max(abs(min(values)), abs(max(values))) if values else 1.0
                        norm = Normalize(vmin=-max_abs, vmax=max_abs)
                    else:
                        # For absolute values, use standard normalization
                        norm = None  # Let matplotlib handle it
                    
                    if isinstance(zone_type, int):
                        # For n×n grid, create a grid to display metrics
                        grid_values = np.zeros((zone_type, zone_type))
                        
                        for i in range(zone_type):
                            for j in range(zone_type):
                                zone_name = f"grid_{i}_{j}"
                                grid_values[i, j] = zone_metrics.get(zone_name, 0)
                        
                        # Show the heatmap
                        im = axes[heatmap_idx].imshow(grid_values, cmap=cmap_name, interpolation='nearest', norm=norm)
                        
                        # Add text labels with adjustable font size
                        fontsize = max(6, min(10, 16 - zone_type))  # Scale font size based on grid density
                        for i in range(zone_type):
                            for j in range(zone_type):
                                # Format the value based on magnitude
                                val = grid_values[i, j]
                                if abs(val) >= 0.01:
                                    text = f"{val:.3f}"
                                else:
                                    text = f"{val:.1e}"
                                    
                                axes[heatmap_idx].text(j, i, text,
                                           ha="center", va="center", color="white",
                                           fontsize=fontsize, fontweight='bold')
                        
                        # Add colorbar
                        cbar = plt.colorbar(im, ax=axes[heatmap_idx])
                        cbar.set_label(metric_name)
                        
                        axes[heatmap_idx].set_title(f"Zone {metric_description}")
                        
                    elif zone_type == "quadrants":
                        # For quadrants, create a 2×2 heatmap
                        quadrant_values = np.zeros((2, 2))
                        quadrant_values[0, 0] = zone_metrics.get("top_left", 0)
                        quadrant_values[0, 1] = zone_metrics.get("top_right", 0)
                        quadrant_values[1, 0] = zone_metrics.get("bottom_left", 0)
                        quadrant_values[1, 1] = zone_metrics.get("bottom_right", 0)
                        
                        # Show the heatmap
                        im = axes[heatmap_idx].imshow(quadrant_values, cmap=cmap_name, interpolation='nearest', norm=norm)
                        
                        # Add text labels
                        for i, j in [(0,0), (0,1), (1,0), (1,1)]:
                            val = quadrant_values[i, j]
                            if abs(val) >= 0.01:
                                text = f"{val:.3f}"
                            else:
                                text = f"{val:.1e}"
                            axes[heatmap_idx].text(j, i, text, ha="center", va="center", color="white", fontsize=10)
                        
                        # Add colorbar
                        cbar = plt.colorbar(im, ax=axes[heatmap_idx])
                        cbar.set_label(metric_name)
                        
                        axes[heatmap_idx].set_title(f"Quadrant {metric_description}")
                        
                    elif zone_type == "center_bg":
                        # For center/background, special visualization
                        center_val = zone_metrics.get("center", 0)
                        bg_val = zone_metrics.get("background", 0)
                        
                        # Create a mask-based visualization
                        mask = np.zeros((3, 3), dtype=bool)
                        mask[1, 1] = True  # Center is True, background is False
                        
                        # Create values array where center has one value, background another
                        values = np.ones((3, 3)) * bg_val
                        values[1, 1] = center_val
                        
                        # Show the heatmap
                        im = axes[heatmap_idx].imshow(values, cmap=cmap_name, interpolation='nearest', norm=norm)
                        
                        # Format values for display
                        if abs(center_val) >= 0.01:
                            center_text = f"Center\n{center_val:.3f}"
                        else:
                            center_text = f"Center\n{center_val:.1e}"
                            
                        if abs(bg_val) >= 0.01:
                            bg_text = f"BG\n{bg_val:.3f}"
                        else:
                            bg_text = f"BG\n{bg_val:.1e}"
                        
                        # Add text labels
                        axes[heatmap_idx].text(1, 1, center_text, ha="center", va="center", color="white", fontsize=10)
                        axes[heatmap_idx].text(0, 0, bg_text, ha="center", va="center", color="white", fontsize=10)
                        
                        # Add colorbar
                        cbar = plt.colorbar(im, ax=axes[heatmap_idx])
                        cbar.set_label(metric_name)
                        
                        axes[heatmap_idx].set_title(f"Center/Background {metric_description}")
                    
                    plt.tight_layout()
                    
                    # Save figure if requested
                    if save_plots and save_path_prefix:
                        zone_type_str = zone_type if isinstance(zone_type, str) else f"grid{zone_type}x{zone_type}"
                        type_str = "diff" if is_difference else "abs"
                        fig_path = f"{save_path_prefix}{pred_key}_frame{frame_idx}_{metric}_{zone_type_str}_{type_str}.png"
                        plt.savefig(fig_path, bbox_inches='tight', dpi=300)
                    
                    plt.show()
                    
                except Exception as e:
                    print(f"Error creating visualization for {pred_key} (Frame {frame_idx}): {e}")
    
    # Calculate overall mean metric
    overall_mean = 0
    if len(all_zone_metrics) > 0:
        # Average across all predictions and all zones
        all_values = []
        for pred_metrics in all_zone_metrics.values():
            for metrics in pred_metrics:
                all_values.extend(list(metrics.values()))
        
        if all_values:
            overall_mean = np.mean(all_values)
    
    # Update performance dictionary if provided
    if performance_dict is not None:
        try:
            # Calculate overall metrics across all zones
            for pred_key, pred_metrics in all_zone_metrics.items():
                if len(pred_metrics) > 0:
                    zone_names = list(pred_metrics[0].keys())
                    
                    for zone in zone_names:
                        zone_values = [metrics[zone] for metrics in pred_metrics]
                        zone_mean = np.mean(zone_values)
                        zone_median = np.median(zone_values)
                        
                        # Add to performance dict
                        if baseline_predictions is not None:
                            # This is a difference metric
                            if metric == "ssim":
                                performance_dict[f'diff_ssim_{zone}_D'] = zone_mean
                                performance_dict[f'median_diff_ssim_{zone}_D'] = zone_median
                            else:
                                performance_dict[f'diff_tv_{zone}_D'] = zone_mean
                                performance_dict[f'median_diff_tv_{zone}_D'] = zone_median
                        else:
                            # This is an absolute metric
                            if metric == "ssim":
                                performance_dict[f'mean_ssim_{zone}_D'] = zone_mean
                                performance_dict[f'median_ssim_{zone}_D'] = zone_median
                            else:
                                performance_dict[f'mean_tv_{zone}_D'] = zone_mean
                                performance_dict[f'median_tv_{zone}_D'] = zone_median
            
            # Add overall mean metric
            if baseline_predictions is not None:
                # This is a difference metric
                if metric == "ssim":
                    performance_dict['diff_ssim_D'] = overall_mean
                else:
                    performance_dict['diff_tv_D'] = overall_mean
            else:
                # This is an absolute metric
                if metric == "ssim":
                    performance_dict['mean_ssim_D'] = overall_mean
                else:
                    performance_dict['mean_tv_D'] = overall_mean
                
        except Exception as e:
            print(f"Error updating performance dictionary: {e}")
    
    if mean_flag:
        return overall_mean
    
    return performance_dict



def _resolve_label_key(key, labels_dict):
    """Return the key of `labels_dict` whose labels serve as ground truth for input `key`.

    Exact matches are used directly. Otherwise the text after the last underscore is
    tried (e.g. "slice_0_Payload" -> "Payload"); as a last resort the first label key
    is used. A message is printed for every fallback.
    """
    if key in labels_dict:
        return key
    if '_' in key:
        extracted_movie = key.split('_')[-1]
        if extracted_movie in labels_dict:
            print(f"Input key '{key}' not found in labels. Extracted and using '{extracted_movie}' for labels.")
            return extracted_movie
        # If extracted name not found, use first available label
        label_key = list(labels_dict.keys())[0]
        print(f"Input key '{key}' and extracted movie '{extracted_movie}' not found in labels. Using '{label_key}' for labels.")
        return label_key
    # If no underscore in key, use first available label
    label_key = list(labels_dict.keys())[0]
    print(f"Input key '{key}' not found in labels and no movie name could be extracted. Using '{label_key}' for labels.")
    return label_key


def test_model_all(inputs_dict, labels_dict, model, criterion, device, pretrained_decoder=None, model_to_test=None, 
               statistical_testing=False, display_plots=True, save_plots=False, model_name="", metric="ssim", 
               mean_flag=False, zones=None, baseline_predictions=None):
    """
    Test the pretrained model on the provided dataset.

    Arguments:
        inputs_dict (dict): Dictionary of input data. Keys are movie names or slice identifiers
                           (e.g. "slice_0_Payload"; the text after the last underscore is used
                           to find the matching labels when the key itself is not a label key).
                           If model_to_test is 'encoder' or 'encoder_decoder', then elements have a shape of (TR, 3, 112, 112, 32). 
                           Else, shape of (TR, mask_size).
        labels_dict (dict): Dictionary of labels. Keys are movie names. 
                           If model_to_test is 'encoder' or 'encoder_decoder', then elements have a shape of (TR, mask_size). 
                           Else, shape of (TR, 3, 112, 112, 32).
        model (nn.Module): The pretrained neural network model to be tested.
        criterion (nn.Module): Loss function for testing. Must return (*metrics, total_loss, metric_names).
        device (torch.device): Device to test the model on (CPU or GPU).
        pretrained_decoder (str, optional): Path to a pretrained decoder state dict applied to the
                           model's output. Default is None.
        model_to_test (str): Specifies which part of the model to test. Options are 'encoder', 'decoder', or 'encoder_decoder'.
        statistical_testing (bool, optional): Whether to run a permutation test on the encoder
                           predictions. Default is False.
        display_plots (bool, optional): Whether to display plots. Default is True.
        save_plots (bool, optional): Whether to save plots under 'outputs/'. Default is False.
        model_name (str, optional): Name of the model for saving plots. Default is "".
        metric (str, optional): Metric name forwarded to the reconstruction plotting helpers. Default is "ssim".
        mean_flag (bool, optional): Forwarded to the reconstruction plotting helpers; presumably makes
                           them return a single mean value instead of a performance dict. Default is False.
        zones (str or int, optional): Zones to consider for testing. Default is None (no zone analysis);
                                      can be "quadrants" or "center_bg".
                                      If it is an integer, the function will analyze a number of zones = that integer squared.
                                      For example, if zones = 4, the function will analyze 16 zones (4x4).
        baseline_predictions (dict, optional): Dictionary of baseline predictions for comparison.

    Returns:
        results (dict): Dictionary containing test results: per-key predictions
            ('<model_to_test>_predictions', or 'encoder_predictions' and 'decoder_predictions'
            when a decoder output is produced), 'total_losses', 'test_performance', and
            'decoder_saliency' for 'encoder_decoder'.
        Returns (None, None) if model_to_test is not recognized.
    """
    print('Start testing:')
    tic = time.time()

    # Create outputs directory if it doesn't exist
    if save_plots:
        import os
        os.makedirs('outputs', exist_ok=True)

    model_type = ['encoder', 'decoder', 'encoder_decoder']
    if model_to_test not in model_type:
        print(f'model_to_test: {model_to_test} not recognized. Must be one of {model_type}')
        # NOTE: a 2-tuple is kept for backward compatibility with existing callers.
        return None, None

    # Truthiness is used everywhere so that e.g. "" behaves exactly like None.
    use_pretrained_decoder = bool(pretrained_decoder)

    # Get list of input keys (movie names or slice identifiers)
    videos = list(inputs_dict.keys())
    inputs_shape = list(inputs_dict[videos[0]].shape)
    inputs_shape[0] = 'TR'
    print(f'### Testing {model_to_test} on inputs of shape {inputs_shape} over {len(videos)} videos/slices ###')
    
    if baseline_predictions is not None:
        print(f'### Using baseline predictions for comparison ###')

    criterion = criterion.to(device)
    # Set model in testing phase
    model.to(device)
    model.eval()

    # Load and freeze the pretrained decoder if specified
    if use_pretrained_decoder:
        decoder = Decoder(labels_dict[next(iter(labels_dict))].shape[1])  # Assuming shape is consistent across labels
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        state_dict = torch.load(pretrained_decoder, map_location=device)
        decoder.load_state_dict(state_dict)
        decoder.to(device)
        for param in decoder.parameters():
            param.requires_grad = False
        decoder.eval()

        print(f'Also using pretrained decoder {pretrained_decoder}')

    # A single prediction stream unless a decoder output is produced as well.
    single_output = model_to_test != 'encoder_decoder' and not use_pretrained_decoder
    if single_output:
        results = {
            model_to_test + '_predictions': {},
            'total_losses': {}
        }
    else:
        results = {
            'encoder_predictions': {},
            'decoder_predictions': {},
            'total_losses': {}
        }

        # One saliency value per label feature (per voxel for encoder-style labels).
        decoder_saliency = np.zeros(labels_dict[list(labels_dict.keys())[0]].shape[1])

    results['test_performance'] = {}
    last_encoder_performance = {}
    
    # Process each item in the inputs and labels dictionaries
    for key in inputs_dict.keys():
        input_tensor = torch.from_numpy(inputs_dict[key].astype('float32'))
        
        # Get the corresponding label - slice names are mapped back to their movie
        label_key = _resolve_label_key(key, labels_dict)
        label_tensor = torch.from_numpy(labels_dict[label_key].astype('float32'))

        # Debug info about tensors but without using print_dict_tree
        print(f"input_tensor shape: {input_tensor.shape}, dtype: {input_tensor.dtype}")
        print(f"label_tensor shape: {label_tensor.shape}, dtype: {label_tensor.dtype}")
        
        test_set = torch.utils.data.TensorDataset(input_tensor, label_tensor)
        test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=16,
            shuffle=False,
            pin_memory=torch.cuda.is_available(),
            num_workers=4
        )

        model_outputs, decoder_outputs, total_losses = [], [], []
        with torch.no_grad():
            for batch_input, batch_label in test_loader:
                batch_input, batch_label = batch_input.to(device), batch_label.to(device)
            
                decoder_output = None
                if model_to_test == 'encoder_decoder':
                    model_output, decoder_output = model(batch_input.float())
                elif use_pretrained_decoder:
                    model_output = model(batch_input.float()).to(device)
                    decoder_output = decoder(model_output.float())
                else:
                    model_output = model(batch_input.float())
                        
                model_outputs.append(model_output.detach().cpu())
                if decoder_output is not None:
                    decoder_outputs.append(decoder_output.detach().cpu())
            
                # Apply the appropriate criterion based on the presence of decoder outputs
                if model_to_test == 'decoder':
                    # Decoder is scored against frame index 15 of each clip (the middle frame).
                    *_, total_loss, _ = criterion(model_output, batch_label[..., 15])
                elif decoder_output is None:
                    *_, total_loss, _ = criterion(model_output, batch_label)
                else:
                    *_, total_loss, _ = criterion(model_output, batch_label, decoder_output, batch_input[..., 15])
                
                total_losses.append(total_loss.item())

        # Store the outputs in results
        if single_output:
            results[model_to_test + '_predictions'][key] = torch.cat(model_outputs, dim=0).numpy()
        else:
            results['encoder_predictions'][key] = torch.cat(model_outputs, dim=0).numpy()
            results['decoder_predictions'][key] = torch.cat(decoder_outputs, dim=0).numpy()
        
        results['total_losses'][key] = np.asarray(total_losses)

        if model_to_test != 'decoder':
            # Compare encoder predictions with the resolved labels; the return value of the
            # last processed key becomes results['test_performance'] below.
            last_encoder_performance = plot_metrics(
                labels_dict[label_key], results['encoder_predictions'][key], key,
                plot_TR=False, performance_dict=None,
                display_plots=display_plots,
                save_plots=save_plots,
                save_path=f'outputs/{key}_{model_name}.png' if save_plots else None)

    if model_to_test != 'decoder':
        # Reuse the last in-loop result instead of plotting/saving that figure a second time.
        results['test_performance'] = last_encoder_performance

        if statistical_testing:
            stat_labels, stat_predictions = [], []
            for label_name in labels_dict.keys():
                if label_name in results['encoder_predictions']:
                    stat_predictions.append(results['encoder_predictions'][label_name])
                    stat_labels.append(labels_dict[label_name])
            if stat_predictions:
                one_sample_permutation_test(np.concatenate(stat_labels, axis=0),
                                            np.concatenate(stat_predictions, axis=0))
            else:
                # np.concatenate([]) would raise; this happens when only slice keys were tested.
                print('Statistical testing skipped: no input key matches a label key.')

    if model_to_test != 'encoder' or use_pretrained_decoder:
        if model_to_test == 'decoder':
            print("\n\n\n ALRIGHT ZONES =", zones, "\n\n\n")
            
            # Check if we have baseline predictions to use
            if baseline_predictions is not None:
                # Use the modified plot_all_predictions7 with baseline comparison
                results['test_performance'] = plot_all_predictions7(
                    results['decoder_predictions'], 
                    labels_dict, 
                    results['test_performance'], 
                    display_plots,
                    save_plots=save_plots,
                    save_path_prefix='outputs/' if save_plots else None,
                    model_name=model_name, 
                    metric=metric, 
                    mean_flag=mean_flag,
                    zone_type=zones,
                    baseline_predictions=baseline_predictions
                )
            elif zones is None:
                # No baseline and no zones: whole-frame metrics via plot_all_predictions5
                results['test_performance'] = plot_all_predictions5(
                    results['decoder_predictions'], 
                    labels_dict, 
                    results['test_performance'], 
                    display_plots,
                    save_plots=save_plots,
                    save_path_prefix='outputs/' if save_plots else None,
                    model_name=model_name, 
                    metric=metric, 
                    mean_flag=mean_flag
                )
            else:
                # No baseline but zones requested: zone metrics via plot_all_predictions7
                results['test_performance'] = plot_all_predictions7(
                    results['decoder_predictions'], 
                    labels_dict, 
                    results['test_performance'], 
                    display_plots,
                    save_plots=save_plots,
                    save_path_prefix='outputs/' if save_plots else None,
                    model_name=model_name, 
                    metric=metric, 
                    mean_flag=mean_flag,
                    zone_type=zones
                )

        else:
            # For encoder or encoder_decoder, use inputs_dict (the videos) for ground truth.
            # zones / mean_flag are forwarded only when set, so default calls are unchanged.
            optional_plot_kwargs = {}
            if zones is not None:
                optional_plot_kwargs['zone_type'] = zones
            if mean_flag:
                optional_plot_kwargs['mean_flag'] = mean_flag

            if baseline_predictions is not None:
                results['test_performance'] = plot_all_predictions7(
                    results['decoder_predictions'], 
                    inputs_dict, 
                    results['test_performance'], 
                    display_plots,
                    save_plots=save_plots,
                    save_path_prefix='outputs/' if save_plots else None,
                    model_name=model_name, 
                    metric=metric, 
                    baseline_predictions=baseline_predictions,
                    **optional_plot_kwargs
                )
            elif zones is not None:
                results['test_performance'] = plot_all_predictions7(
                    results['decoder_predictions'], 
                    inputs_dict, 
                    results['test_performance'], 
                    display_plots,
                    save_plots=save_plots,
                    save_path_prefix='outputs/' if save_plots else None,
                    model_name=model_name, 
                    metric=metric,
                    **optional_plot_kwargs
                )
            else:
                results['test_performance'] = plot_all_predictions5(
                    results['decoder_predictions'], 
                    inputs_dict, 
                    results['test_performance'], 
                    display_plots,
                    save_plots=save_plots,
                    save_path_prefix='outputs/' if save_plots else None,
                    model_name=model_name, 
                    metric=metric,
                    **optional_plot_kwargs
                )
    print("using new function")
        
    if model_to_test == 'encoder_decoder':
        # Saliency needs gradients, so re-enable them (the forward passes above ran under no_grad).
        with torch.enable_grad():
            for key in inputs_dict.keys():
                predicted_fMRIs = torch.from_numpy(results['encoder_predictions'][key])
                # The input videos are the decoder's ground truth (frame index 15).
                ground_truth_frames = torch.from_numpy(inputs_dict[key][..., 15])
                for i in range(predicted_fMRIs.shape[0]):
                    decoder_saliency += compute_saliency(model.decoder, predicted_fMRIs[i:i+1], ground_truth_frames[i:i+1], device)

        if display_plots:
            plot_saliency_distribution(decoder_saliency)
        results['decoder_saliency'] = decoder_saliency

    print("Testing completed. Total time: {:.2f} minutes".format((time.time() - tic) / 60))
    print('---')
    return results

def visualize_blocks_3(data_3d, blocks, losses, num_blocks=(3, 3, 3), figsize=None, colormap='viridis', 
                   title_prefix="Brain Divided into", is_difference=False, save_figure=True):
    """
    Visualize one value per brain block (a loss or a difference from baseline).

    For each layer of blocks along the z-axis, the voxels of ``data_3d`` that are > 0
    anywhere in that layer's z-range are projected onto the x/y plane. The part of that
    projection inside each block is colored by the block's value. Each block is labelled
    with its ID and value, and the block with the largest absolute value is outlined in red.

    Parameters:
    -----------
    data_3d : numpy.ndarray
        3D brain data indexed as data_3d[x, y, z]; voxels > 0 are drawn as brain.
    blocks : dict
        Mapping block ID -> ((x_min, x_max), (y_min, y_max), (z_min, z_max)).
        The sorted unique boundary values along each axis are used as slice bounds
        (end excluded). This assumes adjacent blocks share boundary values; TODO
        confirm against divide_brain_into_blocks.
    losses : array-like
        One value per block; index 0 corresponds to block 1. Block IDs are assumed
        to follow 1 + x_idx + y_idx * nx + z_idx * nx * ny.
    num_blocks : tuple
        Number of blocks along each dimension (x, y, z).
    figsize : tuple, optional
        Figure size; defaults to (6 * nz, 6).
    colormap : str
        Matplotlib colormap name. 'viridis' suits absolute values, 'coolwarm' suits differences.
    title_prefix : str
        Prefix for the figure title.
    is_difference : bool
        If True, values are labelled as differences from baseline. Combined with
        'coolwarm', the color scale is made symmetric around zero.
    save_figure : bool
        If True (default), save the figure to
        'brain_blocks_{nx}x{ny}x{nz}_{losses|differences}_colormap.png'.

    Returns:
    --------
    dict
        The ``blocks`` argument, unchanged.

    Raises:
    -------
    ValueError
        If ``len(losses)`` differs from nx * ny * nz.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    import matplotlib.colors as colors

    nx, ny, nz = num_blocks
    losses = np.array(losses)

    # Print the actual loss values for debugging
    print("Loss/difference values for each block:")
    for i, val in enumerate(losses):
        print(f"  Block {i+1}: {val:.5f}")

    total_blocks = nx * ny * nz
    if len(losses) != total_blocks:
        raise ValueError(f"Expected {total_blocks} loss values, but got {len(losses)}")

    # Block with the highest absolute value (block IDs are 1-based)
    max_abs_loss_idx = int(np.argmax(np.abs(losses)))
    selected_block = max_abs_loss_idx + 1
    max_abs_loss_value = losses[max_abs_loss_idx]
    if is_difference:
        print(f"Block {selected_block} has the highest absolute difference from baseline: {max_abs_loss_value:.5f}")
    else:
        print(f"Block {selected_block} has the highest absolute loss: {max_abs_loss_value:.5f}")

    if figsize is None:
        figsize = (6 * nz, 6)

    # squeeze=False always returns a 2D array of Axes, even when nz == 1
    fig, axes = plt.subplots(1, nz, figsize=figsize, squeeze=False)
    axes = axes[0]

    title_text = f"{title_prefix} {nx}x{ny}x{nz} Blocks"
    title_text += " with Differences from Baseline" if is_difference else " with Loss Values"
    fig.suptitle(title_text, fontsize=16)

    # Recover the grid lines along each axis from the block boundaries
    x_set, y_set, z_set = set(), set(), set()
    for (bx_min, bx_max), (by_min, by_max), (bz_min, bz_max) in blocks.values():
        x_set.update((bx_min, bx_max))
        y_set.update((by_min, by_max))
        z_set.update((bz_min, bz_max))
    x_divisions = sorted(x_set)
    y_divisions = sorted(y_set)
    z_divisions = sorted(z_set)

    print("X divisions:", x_divisions)
    print("Y divisions:", y_divisions)
    print("Z divisions:", z_divisions)

    if is_difference and colormap == 'coolwarm':
        # Symmetric scale so that zero difference maps to the colormap's midpoint
        max_abs = max(abs(np.min(losses)), abs(np.max(losses)))
        norm = colors.Normalize(vmin=-max_abs, vmax=max_abs)
    else:
        norm = colors.Normalize(vmin=np.min(losses), vmax=np.max(losses))

    # plt.get_cmap is used because matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in Matplotlib 3.9
    cmap = plt.get_cmap(colormap)

    text_box = dict(facecolor='black', alpha=0.7, boxstyle='round')

    for z_idx in range(nz):
        ax = axes[z_idx]
        z_min, z_max = z_divisions[z_idx], z_divisions[z_idx + 1]
        layer_name = f"Layer {z_idx} (z={z_min}-{z_max-1})"

        # Binary projection: 1 where any voxel of this z-slab is > 0
        layer_data = data_3d[:, :, z_min:z_max]
        layer_projection = np.max(layer_data > 0, axis=2).astype(float)
        print(f"Max layer_projection for layer {z_idx}=", np.max(layer_projection))
        brain_mask = layer_projection > 0

        # Black background
        ax.imshow(np.zeros_like(layer_projection), cmap='gray', origin='lower')
        ax.set_title(layer_name)

        # Single RGBA overlay for the layer; untouched pixels keep alpha 0 (transparent)
        overlay = np.zeros((*layer_projection.shape, 4))

        for y_idx in range(ny):
            for x_idx in range(nx):
                block_id = 1 + x_idx + y_idx * nx + z_idx * nx * ny
                loss_value = losses[block_id - 1]

                x_min, x_max = x_divisions[x_idx], x_divisions[x_idx + 1]
                y_min, y_max = y_divisions[y_idx], y_divisions[y_idx + 1]

                # Color only brain pixels that fall inside this block
                block_mask = np.zeros_like(brain_mask)
                block_mask[x_min:x_max, y_min:y_max] = True
                overlay[block_mask & brain_mask] = cmap(norm(loss_value))

                if is_difference and abs(loss_value) < 0.01:
                    value_str = f"{loss_value:.2e}"
                else:
                    value_str = f"{loss_value:.4f}"

                # Array axis 0 (x) is drawn vertically and axis 1 (y) horizontally,
                # so plot coordinates are (y, x).
                ax.text((y_min + y_max) / 2, (x_min + x_max) / 2, f"{block_id}\n{value_str}",
                        ha="center", va="center", color='white', fontweight='bold',
                        fontsize=10, bbox=text_box)

                if block_id == selected_block:
                    ax.add_patch(patches.Rectangle(
                        (y_min, x_min), y_max - y_min, x_max - x_min,
                        fill=False, edgecolor='red', linewidth=2
                    ))

        ax.imshow(overlay, origin='lower', interpolation='nearest')

    # Colorbar mapping values to colors, on its own axis below the brain images
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cax = fig.add_axes([0.15, 0.05, 0.7, 0.02])  # [left, bottom, width, height]
    cbar = fig.colorbar(sm, cax=cax, orientation='horizontal')
    cbar.set_label('Difference from Baseline' if is_difference else 'Loss Value')

    plt.tight_layout(rect=[0, 0.1, 1, 0.95])

    # Save BEFORE showing: after plt.show() the figure may be closed, and a later
    # plt.savefig would write a blank image.
    if save_figure:
        suffix = "differences" if is_difference else "losses"
        fig.savefig(f'brain_blocks_{nx}x{ny}x{nz}_{suffix}_colormap.png', dpi=300, bbox_inches='tight')
    plt.show()

    return blocks



def calculate_baseline_losses(model_name='decoder_4609_350', test_input=None, test_label=None, 
                           all_frames=True, save_plots=False, metric="tv", zones=4):
    """
    Calculate baseline reconstruction performance of a decoder without any perturbation.

    Loads the decoder weights stored at ``model_name`` and runs ``test_model_all``
    (or ``test_model``) on the unmodified inputs.

    Parameters:
    -----------
    model_name : str
        Path/name of the saved state dict to load into ``Decoder(mask_size)``
    test_input : dict
        Dictionary with fMRIs for testing
    test_label : dict
        Dictionary with films
    all_frames : bool
        If True, runs test_model_all, otherwise runs test_model
    save_plots : bool
        If True, saves plots
    metric : str
        Metric to use for evaluation: "ssim" or "tv" (Total Variation)
    zones : str or int
        Zone configuration forwarded to test_model_all

    Returns:
    --------
    dict or None
        ``test_performance`` from test_model_all when ``all_frames`` is True.
        None otherwise, because test_model is run only for its printed/plotted output.
    """
    print("Calculating baseline losses with no perturbation")

    # Load the weights onto the CPU first. Without map_location, torch.load restores
    # tensors onto the GPU they were saved from, which fails if that device index is not
    # visible in this process. load_state_dict copies them into the CPU model either way.
    model = Decoder(mask_size)
    state_dict = torch.load(model_name, map_location="cpu")
    model.load_state_dict(state_dict)

    # Testing parameters
    criterion = D_Loss()
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    pretrained_decoder = None
    model_to_test = 'decoder'
    statistical_testing = False
    display_plots = True

    if all_frames:
        baseline_results = test_model_all(
            test_input, 
            test_label, 
            model, 
            criterion, 
            device, 
            pretrained_decoder, 
            model_to_test, 
            statistical_testing, 
            display_plots, 
            save_plots, 
            model_name=model_name + '_baseline', 
            metric=metric, 
            mean_flag=False,
            zones=zones
        )
        print("Baseline test completed")
        return baseline_results['test_performance']

    test_model(
        test_input, 
        test_label, 
        model, 
        criterion, 
        device, 
        pretrained_decoder, 
        model_to_test, 
        statistical_testing, 
        display_plots, 
        save_plots, 
        model_name=model_name + '_baseline'
    )
    return None
    


def _ffa_voxel_indices_4609(ffa_mask_path='enhanced_union_FFA.npy',
                            brain_mask_path='mask_schaefer1000_4609.npy'):
    """
    Return positions, among the non-zero brain-mask voxels, of voxels that are > 0 in both masks.

    Positions follow the flattened (C-order) brain mask, so the result can be used as
    column indices into fMRI arrays whose second axis enumerates those voxels. That
    layout is an assumption; TODO confirm against the preprocessing code.
    """
    import numpy as np

    ffa_mask = np.load(ffa_mask_path)
    brain_mask = np.load(brain_mask_path)

    intersection = (brain_mask > 0) & (ffa_mask > 0)
    brain_indices = np.where(brain_mask.flatten())[0]
    ffa_indices = np.where(intersection.flatten())[0]
    return np.where(np.isin(brain_indices, ffa_indices))[0]


def _print_baseline_comparison(baseline_performance, block_performance, metric):
    """Print every 'mean_*' metric shared by both dicts next to its baseline value and difference."""
    print("\nComparison with baseline:")

    main_metric_key = f'mean_{metric}_D'
    if main_metric_key in baseline_performance and main_metric_key in block_performance:
        baseline_val = baseline_performance[main_metric_key]
        block_val = block_performance[main_metric_key]
        diff = block_val - baseline_val
        print(f"  {main_metric_key}: {block_val:.5f} (baseline: {baseline_val:.5f}, diff: {diff:.5f})")

    # Zone-specific and other mean metrics
    for k in baseline_performance.keys():
        if k.startswith('mean_') and k != main_metric_key and k in block_performance:
            baseline_val = baseline_performance[k]
            block_val = block_performance[k]
            diff = block_val - baseline_val
            print(f"  {k}: {block_val:.5f} (baseline: {baseline_val:.5f}, diff: {diff:.5f})")


def test_new_decoder(real=True, model_name='decoder_4609_1650', test_on_train=False, test_input=testset['fMRIs'], 
                     test_label=testset['videos'], add_name='', regions=None, block_id=None, save_plots=False, all_frames=False,
                     change_mode='off', num_blocks=None, metric="ssim", zones=None, compare_to_baseline=True):
    '''
    Tests the decoder, optionally with brain-block or region perturbation analysis.

    Parameters:
    -----------
    real : bool
        Intended to choose real vs encoder-predicted brain activity; currently not used by this function.
    model_name : str
        Path/name of the saved decoder state dict to load
    test_on_train : bool
        If True, tests on 30 randomly chosen training samples instead of test_input/test_label
    test_input : dict
        Dictionary with fMRIs for testing (one entry per film). Note that the default is
        bound once, when the function is defined.
    test_label : dict
        Dictionary with films (one entry per film)
    add_name : str
        String to add to the end of the output name to avoid overwriting
    regions : list or None
        Region IDs to turn off (legacy parameter; only used when num_blocks is None)
    block_id : int or None
        ID of the 3D block to turn off. If None and num_blocks is given, every block is processed.
    save_plots : bool
        If True, saves plots
    all_frames : bool
        If True, runs test_model_all, otherwise runs test_model
    change_mode : str
        'off' or 'amplify' (amplification is not implemented; both currently turn voxels off)
    num_blocks : tuple or None
        Number of blocks along (x, y, z); enables the block analysis when given
    metric : str
        Evaluation metric passed to test_model_all (e.g. "ssim" or "tv")
    zones : str, int or None
        Zones to consider for testing, forwarded to test_model_all. If it is an integer,
        zones**2 zones are analysed (e.g. zones = 4 gives a 4x4 grid). Default None.
    compare_to_baseline : bool
        If True (and all_frames is True), compute baseline performance without
        perturbation and report differences

    Returns:
    --------
    dict or None
        - Single-block mode with all_frames: test_model_all's results.
        - All-blocks mode with all_frames: a summary dict (losses, baseline, differences,
          most impactful block, per-block results).
        - Regions/plain mode with all_frames: test_model_all's results.
        - None whenever all_frames is False.
    '''
    import numpy as np
    import pickle

    if regions is None:
        regions = []

    print("zones = ", zones, "\n")

    # Load the weights onto the CPU first so a checkpoint saved on an unavailable
    # GPU index still loads; load_state_dict copies them into the model either way.
    model = Decoder(mask_size)
    state_dict = torch.load(model_name, map_location="cpu")
    model.load_state_dict(state_dict)

    # Testing parameters
    criterion = D_Loss()
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    pretrained_decoder = None
    model_to_test = 'decoder'
    statistical_testing = False
    display_plots = True

    # Special case: evaluate on a random subset of the training data
    if test_on_train:
        num_samples = trainset["fMRIs"].shape[0]
        random_indices = np.random.choice(num_samples, size=30, replace=False)
        test_input = {"test": trainset["fMRIs"][random_indices]}
        test_label = {"test": trainset["videos"][random_indices]}

    if num_blocks is not None:
        baseline_performance = None
        if compare_to_baseline:
            print("Computing baseline performance with no perturbation...")
            # test_model returns no metrics, so a baseline is only available with all_frames
            if all_frames:
                baseline_results = test_model_all(
                    test_input, 
                    test_label, 
                    model, 
                    criterion, 
                    device, 
                    pretrained_decoder, 
                    model_to_test, 
                    statistical_testing, 
                    display_plots, 
                    save_plots, 
                    model_name=model_name + '_baseline', 
                    metric=metric,
                    mean_flag=False,
                    zones=zones
                )
                baseline_performance = baseline_results['test_performance']

                baseline_file = f'baseline_{model_name}_{metric}_zones{zones}.pkl'
                with open(baseline_file, 'wb') as f:
                    pickle.dump(baseline_performance, f)
                print(f"Baseline performance saved to {baseline_file}")

                print("Baseline performance:")
                for k, v in baseline_performance.items():
                    if isinstance(v, (int, float)):
                        print(f"  {k}: {v:.5f}")

        # Load the 3D brain data and crop it to the tightest box around the brain
        regions_3d = load_and_reshape_data('region_ids_4609+.npy')
        (x_min, x_max), (y_min, y_max), (z_min, z_max) = find_bounds(regions_3d)
        brain_data = regions_3d[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1]
        print(f"Original shape: {regions_3d.shape}, Brain shape: {brain_data.shape}")

        blocks = divide_brain_into_blocks(brain_data, num_blocks)

        mask4609 = np.load('mask_schaefer1000_4609.npy')
        flat_mask = mask4609.flatten()

        if block_id is not None:
            # ---- Single specified block ----
            # The selected block does not depend on the video, so draw it once
            visualize_blocks(brain_data, blocks, num_blocks=num_blocks, selected_block=block_id)

            modified_input = {}
            for video_name, fmri in test_input.items():
                print(f"Turning off block {block_id} for video {video_name}")
                # NOTE(review): change_mode 'amplify' is not implemented; every mode turns the block off.
                modified_input[video_name] = turn_off_block_new(fmri, flat_mask, block_id, blocks, x_min, y_min, z_min)
            test_input = modified_input

            if not all_frames:
                test_model(
                    test_input, 
                    test_label, 
                    model, 
                    criterion, 
                    device, 
                    pretrained_decoder, 
                    model_to_test, 
                    statistical_testing, 
                    display_plots, 
                    save_plots, 
                    model_name=model_name + add_name
                )
                return None

            # Same mean_flag/zones as the baseline run so the metric dicts are comparable
            results = test_model_all(
                test_input, 
                test_label, 
                model, 
                criterion, 
                device, 
                pretrained_decoder, 
                model_to_test, 
                statistical_testing, 
                display_plots, 
                save_plots, 
                model_name=model_name + add_name, 
                metric=metric,
                mean_flag=False,
                zones=zones
            )
            if baseline_performance is not None:
                _print_baseline_comparison(baseline_performance, results['test_performance'], metric)
            return results

        # ---- Loop through all blocks ----
        losses = []
        loss_diffs = []  # differences from baseline
        block_results = {}
        main_metric_key = f'mean_{metric}_D'

        # NOTE(review): the per-block perturbation is currently replaced by an FFA ablation
        # that does not depend on the block index, so every iteration applies the same change
        # (for any change_mode). The per-block call it replaced was:
        #   turn_off_block_new(fmri, flat_mask, current_block, blocks, x_min, y_min, z_min)
        # The masks are loop-invariant, so they are loaded once here.
        voxels_to_zero = _ffa_voxel_indices_4609()

        n_total_blocks = num_blocks[0] * num_blocks[1] * num_blocks[2]
        for current_block in range(1, n_total_blocks + 1):
            print(f"\nProcessing block {current_block}...")

            modified_input = {}
            for video_name, fmri in test_input.items():
                modified_data = fmri.copy()
                print("\n\nmodified_data shape =", modified_data.shape, "\n\n")
                modified_data[:, voxels_to_zero] = 0
                modified_input[video_name] = modified_data

            if not all_frames:
                test_model(
                    modified_input, 
                    test_label, 
                    model, 
                    criterion, 
                    device, 
                    pretrained_decoder, 
                    model_to_test, 
                    statistical_testing, 
                    display_plots, 
                    save_plots, 
                    model_name=model_name + f'_block_{current_block}'
                )
                continue

            results = test_model_all(
                modified_input, 
                test_label, 
                model, 
                criterion, 
                device, 
                pretrained_decoder, 
                model_to_test, 
                statistical_testing, 
                display_plots, 
                save_plots, 
                model_name=model_name + f'_block_{current_block}', 
                metric=metric, 
                mean_flag=False,  # Return full performance dict
                zones=zones
            )

            block_performance = results['test_performance']
            block_results[current_block] = block_performance

            mean_loss = block_performance.get(main_metric_key, 0)
            print(f"Block {current_block}, mean loss = {mean_loss:.5f}")
            losses.append(mean_loss)

            if baseline_performance is not None and main_metric_key in baseline_performance:
                diff = mean_loss - baseline_performance[main_metric_key]
                loss_diffs.append(diff)
                print(f"  Difference from baseline: {diff:.5f}")
            else:
                loss_diffs.append(0)  # Default if baseline not available

        if not losses:
            # test_model (all_frames=False) yields no metrics: nothing to rank or plot,
            # and np.argmax on an empty list would raise.
            print("No per-block metrics collected (all_frames=False); skipping block comparison.")
            return None

        use_diffs = compare_to_baseline and baseline_performance is not None
        impact_values = loss_diffs if use_diffs else losses
        lossiest = int(np.argmax(np.abs(impact_values))) if use_diffs else int(np.argmax(losses))
        impact_val = impact_values[lossiest]
        lossiest_block = lossiest + 1  # 1-based block ID
        print(f"\nBlock with the biggest impact: {lossiest_block} (value: {impact_val:.5f})")

        if use_diffs:
            print("Visualizing differences from baseline...")
            visualize_blocks_3(
                brain_data,
                blocks,
                loss_diffs,
                num_blocks=num_blocks,
                colormap='coolwarm',
                is_difference=True  # symmetric color scale and difference labels
            )
        else:
            print("Visualizing absolute loss values...")
            visualize_blocks_3(
                brain_data,
                blocks,
                losses,
                num_blocks=num_blocks,
                colormap='viridis'
            )

        results_data = {
            'losses': losses,
            'baseline_performance': baseline_performance,
            'loss_differences': loss_diffs if baseline_performance is not None else None,
            'block_with_biggest_impact': lossiest_block,
            'block_results': block_results
        }

        results_file = f'block_analysis_{model_name}_{metric}_zones{zones}.pkl'
        with open(results_file, 'wb') as f:
            pickle.dump(results_data, f)
        print(f"Results saved to {results_file}")

        return results_data

    if regions:
        # Legacy path: turn off a list of regions
        test_input = {video_name: turn_off_regions(fmri, regions) for video_name, fmri in test_input.items()}

    # print_dict_tree prints the tree itself, so print the label separately
    print("test input shape =")
    print_dict_tree(test_input)
    print("test label shape =")
    print_dict_tree(test_label)

    # Regions / unperturbed case (blocks are handled above)
    if all_frames:
        results = test_model_all(
            test_input, 
            test_label, 
            model, 
            criterion, 
            device, 
            pretrained_decoder, 
            model_to_test, 
            statistical_testing, 
            display_plots, 
            save_plots, 
            model_name=model_name + add_name, 
            metric=metric, 
            mean_flag=True,
            zones=zones
        )
        return results

    test_model(
        test_input, 
        test_label, 
        model, 
        criterion, 
        device, 
        pretrained_decoder, 
        model_to_test, 
        statistical_testing, 
        display_plots, 
        save_plots, 
        model_name=model_name + add_name
    )
    return None

# Disabled earlier version of run_all_blocks_test (decoder_4609_350, (2, 2, 2) block grid).
# It is wrapped in a bare string literal so it never executes; the active
# run_all_blocks_test is defined further below.
'''
def run_all_blocks_test():
    """
    Run test with all blocks analysis, comparing to baseline performance
    """
    # Call the test_new_decoder function with baseline comparison
    results = test_new_decoder(
        real=True,
        model_name='decoder_4609_350',
        test_input=filtered_trainset['fMRIs'],
        test_label=filtered_trainset['videos'],
        all_frames=True,
        save_plots=False,
        add_name='_block_analysis',
        change_mode='off',
        num_blocks=(2, 2, 2),
        metric="tv",
        zones=4,
        compare_to_baseline=True  # Enable baseline comparison
    )
    
    return results

# Call the function to run the analysis
results = run_all_blocks_test()
'''







def run_all_blocks_test():
    """
    Run the ROI ablation analysis and compare each ablation to the baseline.

    Despite the historical "block" naming, the perturbation applied here is
    not a geometric block.
      - Iteration ``block_id == 1`` zeroes the fMRI voxels inside the FFA
        mask (``enhanced_union_FFA.npy``).
      - Every later iteration zeroes the voxels inside the PPA mask
        (``resampled_ppa.npy``).
    The brain is still split into ``num_blocks`` geometric blocks, but only
    so ``visualize_blocks_3`` can display one value per iteration. Geometric
    block ablation via ``turn_off_block_new`` is done in
    ``run_all_blocks_test_debug``.

    Steps:
        1. Run ``test_model_all`` on the unperturbed ``filtered_trainset``
           fMRIs to get the baseline performance and predictions.
        2. For each iteration, zero the ROI voxels in every video's fMRI
           data (on a copy) and rerun ``test_model_all``. The baseline
           predictions are passed along so per-frame differences can be
           plotted.
        3. Visualise the per-iteration metric differences and pickle a
           summary.

    Relies on module-level names: ``filtered_trainset``, ``mask_size``,
    ``Decoder``, ``D_Loss``, ``test_model_all``, ``load_and_reshape_data``,
    ``find_bounds``, ``divide_brain_into_blocks`` and ``visualize_blocks_3``.

    Returns:
        dict with keys ``baseline_performance``, ``block_losses``,
        ``loss_differences``, ``max_impact_block`` and ``max_impact_value``.
        The same dict is pickled to
        ``block_analysis_{model_name}_{metric}_zones{zones}.pkl``.
    """
    import pickle

    # Step 1: Calculate baseline reconstructions (no perturbation)
    print("=== STEP 1: Calculating baseline reconstructions (no perturbation) ===")

    # Previously evaluated checkpoint: 'decoder_4609_350'
    model_name = 'decoder_4936_201'
    test_input = filtered_trainset['fMRIs']
    test_label = filtered_trainset['videos']
    num_blocks = (1, 1, 2)  # two iterations: FFA (block 1), then PPA (block 2)
    metric = "tv"
    zones = 4
    save_plots = False

    # Load the model. map_location="cpu" lets a checkpoint saved on any GPU
    # load on this machine; load_state_dict copies the weights into the model.
    model = Decoder(mask_size)
    state_dict = torch.load(model_name, map_location="cpu")
    model.load_state_dict(state_dict)

    # Testing parameters
    criterion = D_Loss()
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    pretrained_decoder = None
    model_to_test = 'decoder'
    statistical_testing = False

    # Generate baseline predictions
    baseline_results = test_model_all(
        test_input,
        test_label,
        model,
        criterion,
        device,
        pretrained_decoder,
        model_to_test,
        statistical_testing,
        display_plots=False,  # Don't display plots for baseline
        save_plots=False,
        model_name=model_name + '_baseline',
        metric=metric,
        mean_flag=False,
        zones=zones
    )

    # Save baseline performance metrics and predictions
    baseline_performance = baseline_results['test_performance']
    baseline_predictions = baseline_results['decoder_predictions']

    print("Baseline performance metrics:")
    for k, v in baseline_performance.items():
        if isinstance(v, (int, float)):
            print(f"  {k}: {v:.5f}")

    # Resolve the metric key ONCE against the baseline. Every iteration then
    # reads the same key as the baseline value it is subtracted from.
    main_metric_key = f'mean_{metric}_D'
    if main_metric_key not in baseline_performance:
        fallback_keys = [k for k in baseline_performance
                         if metric in k.lower() and 'mean' in k.lower()]
        if fallback_keys:
            print(f"Metric key '{main_metric_key}' not found in baseline; "
                  f"using '{fallback_keys[0]}' instead.")
            main_metric_key = fallback_keys[0]
    baseline_val = baseline_performance.get(main_metric_key)
    if baseline_val is None:
        print(f"WARNING: no '{metric}' metric in baseline performance; "
              "differences will be reported as 0.")

    # Step 2: Run the perturbation analysis for each block
    print("\n=== STEP 2: Running perturbation analysis for each block ===")

    # Load 3D brain data and prepare blocks (used for visualisation only)
    regions_3d = load_and_reshape_data('region_ids_4609+.npy')
    (x_min, x_max), (y_min, y_max), (z_min, z_max) = find_bounds(regions_3d)
    brain_data = regions_3d[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1]
    print(f"Original shape: {regions_3d.shape}, Brain shape: {brain_data.shape}")

    blocks = divide_brain_into_blocks(brain_data, num_blocks)

    # The nonzero voxels of this mask are assumed to define the columns of
    # the fMRI vectors, in flattened order. TODO confirm against the
    # preprocessing. It does not vary per block or per video, so load it once.
    brain_mask = np.load('mask_4609+ffa.npy')
    brain_indices = np.where(brain_mask.flatten())[0]

    # Container for results
    block_losses = []
    loss_differences = []

    # Loop through all blocks
    total_blocks = num_blocks[0] * num_blocks[1] * num_blocks[2]
    for block_id in range(1, total_blocks + 1):
        print(f"\nProcessing block {block_id}/{total_blocks}")

        # ROI to ablate. It is the same for every video, so load it once per block.
        if block_id == 1:
            print("\n\nDOING FFA NOW\n\n")
            roi_mask = np.load('enhanced_union_FFA.npy')
        else:
            print("\n\nDOING PPA NOW\n\n")
            roi_mask = np.load('resampled_ppa.npy')

        # Map voxels in both the brain mask and the ROI mask to their column
        # positions within brain_indices.
        intersection = (brain_mask > 0) & (roi_mask > 0)
        roi_indices = np.where(intersection.flatten())[0]
        voxels_to_zero = np.where(np.isin(brain_indices, roi_indices))[0]
        n_zero = len(voxels_to_zero)
        print(f"Zeroing out {n_zero} voxels out of {mask_size} total "
              f"({100 * n_zero / mask_size:.1f}%)")

        # Build the perturbed inputs on copies so the baseline data stays intact
        modified_input = {}
        for video_name, fmri in test_input.items():
            modified_data = fmri.copy()
            print("\n\nmodified_data shape =", modified_data.shape, "\n\n")
            modified_data[:, voxels_to_zero] = 0
            modified_input[video_name] = modified_data

        # Run test with this ROI turned off. Baseline predictions are passed
        # for frame-level comparison.
        block_results = test_model_all(
            modified_input,
            test_label,
            model,
            criterion,
            device,
            pretrained_decoder,
            model_to_test,
            statistical_testing,
            display_plots=True,  # Show plots with difference visualization
            save_plots=save_plots,
            model_name=model_name + f'_block_{block_id}',
            metric=metric,
            mean_flag=False,
            zones=zones,
            baseline_predictions=baseline_predictions
        )

        # Get the main metric value
        block_performance = block_results['test_performance']
        if main_metric_key in block_performance:
            mean_loss = block_performance[main_metric_key]
        else:
            metric_keys = [k for k in block_performance
                           if metric in k.lower() and 'mean' in k.lower()]
            if metric_keys:
                mean_loss = block_performance[metric_keys[0]]
            else:
                print(f"WARNING: no '{metric}' metric found for block "
                      f"{block_id}; using 0.")
                mean_loss = 0

        print(f"Block {block_id} - Mean loss: {mean_loss:.5f}")
        block_losses.append(mean_loss)

        # Calculate difference from baseline
        if baseline_val is not None:
            diff = mean_loss - baseline_val
            loss_differences.append(diff)
            print(f"  Difference from baseline: {diff:.5f}")
        else:
            loss_differences.append(0)

    # Step 3: Visualize block-level results
    print("\n=== STEP 3: Visualizing block-level impact analysis ===")

    # Find block with biggest impact (largest absolute change)
    max_diff_idx = int(np.argmax(np.abs(loss_differences)))
    max_diff_block = max_diff_idx + 1
    max_diff_value = loss_differences[max_diff_idx]

    print(f"Block with biggest impact: {max_diff_block} (difference: {max_diff_value:.5f})")

    # Visualize loss differences across blocks
    visualize_blocks_3(
        brain_data,
        blocks,
        loss_differences,  # Use differences from baseline
        num_blocks=num_blocks,
        colormap='coolwarm',  # Diverging map for positive/negative differences
        is_difference=True  # Indicate we're showing differences
    )

    # Save all results
    results_data = {
        'baseline_performance': baseline_performance,
        'block_losses': block_losses,
        'loss_differences': loss_differences,
        'max_impact_block': max_diff_block,
        'max_impact_value': max_diff_value
    }

    results_file = f'block_analysis_{model_name}_{metric}_zones{zones}.pkl'
    with open(results_file, 'wb') as f:
        pickle.dump(results_data, f)
    print(f"Results saved to {results_file}")

    return results_data

# Call the function to run the analysis.
# NOTE: this runs the full ROI ablation as soon as this cell/module executes:
# model loading, baseline and perturbed test_model_all runs, plotting, and
# pickling the summary.
results = run_all_blocks_test()


# NOTE: unlike run_all_blocks_test, this debug variant does not call visualize_blocks_3.
def run_all_blocks_test_debug():
    """
    Run a geometric block ablation with verbose diagnostics.

    Written to investigate why every block produced the same difference from
    baseline. The steps are:
      1. Split the brain into ``num_blocks`` geometric blocks.
      2. For each block, turn off its voxels with ``turn_off_block_new``.
      3. Rerun ``test_model_all`` on the perturbed inputs.
      4. Compare the chosen metric against the unperturbed baseline by
         subtraction. No baseline predictions are passed to
         ``test_model_all``.
    Nothing is plotted or pickled.

    Relies on module-level names: ``filtered_trainset``, ``mask_size``,
    ``Decoder``, ``D_Loss``, ``test_model_all``, ``load_and_reshape_data``,
    ``find_bounds``, ``divide_brain_into_blocks`` and ``turn_off_block_new``.

    Returns:
        dict with the following keys, or ``None`` if the baseline has no
        suitable metric key:
          - ``baseline_performance``
          - ``baseline_value``
          - ``main_metric_key``
          - ``block_losses``
          - ``loss_differences``
          - ``block_results`` (block_id -> that block's ``test_performance``)
    """
    # Step 1: Calculate baseline reconstructions (no perturbation)
    print("=== STEP 1: Calculating baseline reconstructions (no perturbation) ===")

    model_name = 'decoder_4609_350'
    test_input = filtered_trainset['fMRIs']
    test_label = filtered_trainset['videos']
    num_blocks = (2, 2, 2)
    metric = "tv"
    zones = 4

    # Load the model. map_location="cpu" lets a checkpoint saved on any GPU
    # load on this machine; load_state_dict copies the weights into the model.
    model = Decoder(mask_size)
    state_dict = torch.load(model_name, map_location="cpu")
    model.load_state_dict(state_dict)

    # Testing parameters
    criterion = D_Loss()
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    pretrained_decoder = None
    model_to_test = 'decoder'
    statistical_testing = False

    # Generate baseline predictions
    baseline_results = test_model_all(
        test_input,
        test_label,
        model,
        criterion,
        device,
        pretrained_decoder,
        model_to_test,
        statistical_testing,
        display_plots=False,  # Don't display plots for baseline
        save_plots=False,
        model_name=model_name + '_baseline',
        metric=metric,
        mean_flag=False,
        zones=zones
    )

    # Save baseline performance metrics and predictions
    baseline_performance = baseline_results['test_performance']

    print("\n==== DEBUG: Baseline performance metrics ====")
    for k, v in baseline_performance.items():
        if isinstance(v, (int, float)):
            print(f"  {k}: {v}")

    # Candidate metric keys: 'tv' keys for the TV metric, otherwise 'ssim' keys
    needle = 'tv' if metric == "tv" else 'ssim'
    main_keys = [k for k in baseline_performance.keys()
                 if needle in k.lower() and 'mean' in k.lower()]

    print(f"\n==== DEBUG: Available metric keys: {main_keys} ====")

    # Determine the main metric key to use. Keep the preferred name so the
    # fallback message shows both the missing key and its replacement.
    preferred_key = f'mean_{metric}_D'
    main_metric_key = preferred_key
    if main_metric_key not in baseline_performance:
        if main_keys:
            main_metric_key = main_keys[0]
            print(f"Main metric key '{preferred_key}' not found. Using '{main_metric_key}' instead.")
        else:
            print("ERROR: No suitable metric keys found in baseline performance!")
            return None

    baseline_val = baseline_performance[main_metric_key]
    print(f"Using main metric key: {main_metric_key} with baseline value: {baseline_val}")

    # Step 2: Run the perturbation analysis for each block
    print("\n=== STEP 2: Running perturbation analysis for each block ===")

    # Load 3D brain data and prepare blocks
    regions_3d = load_and_reshape_data('region_ids_4609+.npy')
    (x_min, x_max), (y_min, y_max), (z_min, z_max) = find_bounds(regions_3d)
    brain_data = regions_3d[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1]
    print(f"Original shape: {regions_3d.shape}, Brain shape: {brain_data.shape}")

    # Divide into blocks
    blocks = divide_brain_into_blocks(brain_data, num_blocks)

    # Prepare mask for block manipulation.
    # NOTE(review): run_all_blocks_test pairs the same region file with
    # 'mask_4609+ffa.npy'. Confirm this mask matches the voxel layout of
    # region_ids_4609+.npy and of the fMRI columns. A mismatch could make
    # block ablations hit the wrong voxels, or none at all.
    mask4609 = np.load('mask_schaefer1000_4609.npy')
    flat_mask = mask4609.flatten()

    # Container for results
    block_losses = []
    loss_differences = []
    block_results_dict = {}

    # Loop through every block
    total_blocks = num_blocks[0] * num_blocks[1] * num_blocks[2]
    for block_id in range(1, total_blocks + 1):
        print(f"\n==== PROCESSING BLOCK {block_id}/{total_blocks} ====")

        # Create modified input with current block turned off. Pass a copy so
        # that, if turn_off_block_new works in place, it cannot corrupt
        # test_input (and therefore every later block) across iterations.
        modified_input = {}
        for video_name in test_input.keys():
            modified_input[video_name] = turn_off_block_new(
                test_input[video_name].copy(),
                flat_mask,
                block_id,
                blocks,
                x_min,
                y_min,
                z_min
            )

        # Run test with this block turned off
        block_results = test_model_all(
            modified_input,
            test_label,
            model,
            criterion,
            device,
            pretrained_decoder,
            model_to_test,
            statistical_testing,
            display_plots=False,  # Don't show plots during debugging
            save_plots=False,
            model_name=model_name + f'_block_{block_id}',
            metric=metric,
            mean_flag=False,
            zones=zones,
            baseline_predictions=None  # Differences are computed manually below
        )

        print(f"\n==== DEBUG: Block {block_id} Results ====")
        block_performance = block_results['test_performance']
        block_results_dict[block_id] = block_performance

        # Show all metrics in block performance
        for k, v in block_performance.items():
            if isinstance(v, (int, float)):
                print(f"  {k}: {v}")

        # Calculate actual block mean loss value
        if main_metric_key in block_performance:
            mean_loss = block_performance[main_metric_key]
        else:
            print(f"WARNING: Main metric key '{main_metric_key}' not found in block_performance!")
            similar_keys = [k for k in block_performance.keys()
                            if needle in k.lower() and 'mean' in k.lower()]
            if similar_keys:
                mean_loss = block_performance[similar_keys[0]]
                print(f"Using alternative key: {similar_keys[0]}")
            else:
                mean_loss = 0
                print("No similar keys found! Using 0.")

        # Manually calculate the difference
        diff = mean_loss - baseline_val

        print(f"BLOCK {block_id}:")
        print(f"  Raw loss value: {mean_loss}")
        print(f"  Baseline value: {baseline_val}")
        print(f"  Difference: {diff}")

        # Store the values
        block_losses.append(mean_loss)
        loss_differences.append(diff)

    # Print all the collected values
    print("\n==== FINAL DEBUG: All collected values ====")
    print("Block Losses:")
    for i, loss in enumerate(block_losses, start=1):
        print(f"  Block {i}: {loss}")

    print("\nDifferences from baseline:")
    for i, diff in enumerate(loss_differences, start=1):
        print(f"  Block {i}: {diff}")

    # Recompute the differences independently as a sanity check
    print("\nRecalculating differences to double-check:")
    for i, loss in enumerate(block_losses, start=1):
        recalc_diff = loss - baseline_val
        print(f"  Block {i}: Loss={loss}, Baseline={baseline_val}, Diff={recalc_diff}")

    # Find block with biggest impact (largest absolute change)
    max_diff_idx = int(np.argmax(np.abs(loss_differences)))
    max_diff_block = max_diff_idx + 1
    max_diff_value = loss_differences[max_diff_idx]

    print(f"\nBlock with biggest impact: {max_diff_block} (difference: {max_diff_value})")

    # Return the debug data for inspection
    debug_data = {
        'baseline_performance': baseline_performance,
        'baseline_value': baseline_val,
        'main_metric_key': main_metric_key,
        'block_losses': block_losses,
        'loss_differences': loss_differences,
        'block_results': block_results_dict
    }

    return debug_data

# Run the debug function
#debug_results = run_all_blocks_test_debug()

#tag blocks with original and perturbed reconstruction

NameError: name 'testset' is not defined