In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as mcolors
import random
import os

# Output directory for every figure saved by this notebook; created on first run.
# exist_ok=True makes re-running the cell a no-op once the directory exists.
figDir = os.path.join(os.getcwd(), 'figures_output')
os.makedirs(figDir, exist_ok=True)

# Column colours (one colour per grid column). The full palette is used for
# 6-column grids; the reduced palette (after feature selection) drops blue and gold.
colors = ['red', 'green', 'blue', 'purple', 'orange', 'gold']
colorSub = [c for c in colors if c not in ('blue', 'gold')]

# Gradient parameters and figure size (width, height) in inches.
# NOTE(review): num_steps / color_step are not referenced in this cell; they may be
# used elsewhere in the notebook.
num_steps = 6
color_step = .16
gridFigSize = (2.69, 3.16)

# Figure 3A shows the shadow feature set, which needs a shuffled copy of the grid.
# Each shuffle flag pairs positionally with the output filename below it.
shuffleSet = [False, True]
shuffleSet_name = ['Grid.svg', 'Grid_Shuffle.svg']
# Render each grid both as a single image (False) and as one strip per row (True).
perRow = [False, True]

# One alpha per grid row, fading from opaque-ish (top) to faint (bottom).
alphaSet = [0.9, 0.74, 0.58, 0.42, 0.26, 0.1]
gridlinewidth = 0.5

# (rows, columns) of the gradient grid; other layouts used: (4, 6) and (4, 4).
grid_size = (6, 6)
# Seeded, local RNG for the alpha shuffle so 'Grid_Shuffle.svg' is identical on
# every Restart & Run All (an unseeded global `random` gives a new figure each run).
shuffleSeed = 0
shuffler = random.Random(shuffleSeed)

# The palette and the unshuffled gradient depend only on grid_size, so build them
# once instead of on every pass through the output loop.
if grid_size[1] == 6:
    colorSet = colors
else:
    colorSet = colorSub

# alphaSet supplies one alpha per grid row, so it must cover every row.
assert grid_size[0] <= len(alphaSet), 'grid_size has more rows than alphaSet provides alphas'

# RGBA gradient: column j takes colour colorSet[j] (the last colour is reused if
# there are more columns than colours); row i takes alpha alphaSet[i].
base_matrix = np.zeros((grid_size[0], grid_size[1], 4))
for i in range(grid_size[0]):
    for j in range(grid_size[1]):
        color_index = min(j, len(colorSet) - 1)
        base_matrix[i, j] = mcolors.to_rgba(colorSet[color_index], alphaSet[i])

# Remove the 1st and 4th rows to illustrate the train/test split
base_matrix = np.delete(base_matrix, [0, 3], axis=0)
rows, columns = base_matrix.shape[:2]

for shuffle_switch, fName in zip(shuffleSet, shuffleSet_name):
    # Fresh copy per output so a shuffle never leaks into the unshuffled figure.
    gradient_matrix = base_matrix.copy()

    # Shuffle only the alpha channel within each column: colours stay put while
    # the row ordering of the alphas is destroyed (the "shadow" feature set).
    if shuffle_switch:
        for col_i in range(columns):
            shuffler.shuffle(gradient_matrix[:, col_i, 3])

    for row_switch in perRow:
        if row_switch:
            # One subplot per matrix row, stacked vertically. squeeze=False keeps
            # `ax` indexable even if only a single row remains.
            fig, axes = plt.subplots(rows, 1, figsize=gridFigSize, squeeze=False)
            ax = axes[:, 0]

            for i in range(rows):
                # Row i drawn as a 1 x columns RGBA image strip
                row_data = gradient_matrix[i, :, :]
                ax[i].imshow(row_data[None, :, :], aspect='auto')
                ax[i].set_xticks([])
                ax[i].set_yticks([])

                # Vertical lines separate the cells within the strip
                for j in range(1, columns):
                    ax[i].axvline(j - 0.5, color='black', linewidth=gridlinewidth)

            fig.savefig(os.path.join(figDir, f'Rows_{fName}'), format='svg', dpi=1200)

        else:
            fig, ax = plt.subplots(figsize=gridFigSize)

            # Swap the first two axes so matrix rows become image columns
            # (image is columns tall, rows wide).
            ax.imshow(np.swapaxes(gradient_matrix, 0, 1), aspect='auto')
            ax.set_xticks([])
            ax.set_yticks([])

            # Grid lines between cells: horizontal between image rows (matrix
            # columns), vertical between image columns (matrix rows).
            for i in range(1, columns):
                ax.axhline(i - 0.5, color='black', linewidth=gridlinewidth)
            for j in range(1, rows):
                ax.axvline(j - 0.5, color='black', linewidth=gridlinewidth)

            fig.savefig(os.path.join(figDir, fName), format='svg', dpi=1200)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Example matrix of RGB values (one row per swatch, channels in [0, 1])
rgb_matrix = np.array([
    [1, 0, 0],  # Red
    [0, 1, 0],  # Green
    [0, 0, 1],  # Blue
    [1, 1, 0],  # Yellow
    [0, 1, 1],  # Cyan
    [1, 0, 1]   # Magenta
])
# Layout is derived from the number of swatches, so rows can be added or removed
# without touching the plotting code below.
n_swatches = len(rgb_matrix)

# Plot a single column of unit squares; figure height scales with the swatch count
fig, ax = plt.subplots(figsize=(1, n_swatches))
for i, color in enumerate(rgb_matrix):
    # Row 0 of rgb_matrix is drawn at the top (y = n_swatches - 1), the last at y = 0
    rect = patches.Rectangle((0, n_swatches - 1 - i), 1, 1, linewidth=1, edgecolor='black', facecolor=color)
    ax.add_patch(rect)

ax.set_xlim(0, 1)
ax.set_ylim(0, n_swatches)
ax.axis('off')  # Hide axis for a cleaner look
plt.show()
