<a href="https://colab.research.google.com/github/OOAAHH/WDL_Tools_docs/blob/main/docs/test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 点击运行后，在下方输入矩阵尺寸

In [3]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Matrix Splitting Algorithm with Enhanced Visualization

This script captures user input for matrix dimensions, validates the input,
generates a matrix, splits it based on factorization, and visualizes the
original and split matrices with clearly defined grid lines and highlighted
split boundaries. The final visualization is saved as a PNG image.

Enhancements:
1. Added green thick lines to represent row splits.
2. Incorporated tqdm progress bars to monitor plotting progress.

Author: OpenAI ChatGPT
Date: 2024-11-08
"""

import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from tqdm import tqdm  # Import tqdm for progress bars

class MatrixSplitter:
    """
    Split a random integer matrix into equal blocks and visualize the result.

    The block height/width along each axis is the smallest factor of that
    dimension greater than three (see ``smallest_factor_greater_than_three``).
    Besides the blocks themselves, the "seam" values straddling each block
    boundary are collected for later inspection.

    Attributes:
        rows (int): Number of rows in the matrix.
        cols (int): Number of columns in the matrix.
        matrix (numpy.ndarray | None): The matrix; ``None`` until
            ``generate_matrix`` is called.
        small_matrices (list): 2D list; ``small_matrices[i][j]`` is the block
            in block-row ``i`` and block-column ``j``.
        seam_matrices_cols (list): One list per block-row; each entry is a
            two-column array holding the last column of a block and the first
            column of the block to its right.
        seam_matrices_rows (list): One list per pair of vertically adjacent
            block-rows; each entry is a two-row array holding the last row of
            the upper block and the first row of the block below it.
    """

    def __init__(self, rows, cols):
        """
        Initializes the MatrixSplitter with the specified number of rows and columns.

        Parameters:
            rows (int): Number of rows in the matrix (must be >= 1).
            cols (int): Number of columns in the matrix (must be >= 1).

        Raises:
            ValueError: If either dimension is smaller than 1. A zero dimension
                would otherwise make ``split_matrix`` divide by zero later.
        """
        if rows < 1 or cols < 1:
            raise ValueError(
                f"Matrix dimensions must be at least 1, got {rows}x{cols}."
            )
        self.rows = rows
        self.cols = cols
        self.matrix = None
        self.small_matrices = []  # 2D list: small_matrices[i][j]
        self.seam_matrices_cols = []
        self.seam_matrices_rows = []

    def smallest_factor_greater_than_three(self, n):
        """
        Finds the smallest factor of n that is greater than three.

        Parameters:
            n (int): The number to find a factor for.

        Returns:
            int: The smallest divisor of n that is >= 4. For n >= 4 this is at
            most n itself; for n < 4 the search range is empty and n is
            returned unchanged.
        """
        for i in range(4, n + 1):
            if n % i == 0:
                return i
        return n

    def generate_matrix(self):
        """
        Generates a matrix with random integers between 0 and 99 (inclusive).
        """
        self.matrix = np.random.randint(0, 100, size=(self.rows, self.cols))
        print(f"Generated matrix of size {self.rows}x{self.cols}.")

    def split_matrix(self):
        """
        Splits the matrix into blocks and collects the seams between them.

        The matrix is first cut along axis 0 into bands of ``row_factor`` rows.
        Each band is then cut along axis 1 into blocks of ``col_factor``
        columns. Each factor divides its dimension exactly, so every block has
        the shape ``row_factor x col_factor``. Previous results are replaced
        rather than appended to, so the method can be called again (e.g. after
        regenerating the matrix) without corrupting the 2D layout.

        Raises:
            RuntimeError: If ``generate_matrix`` has not been called yet.
        """
        if self.matrix is None:
            raise RuntimeError(
                "generate_matrix() must be called before split_matrix()."
            )

        # Block height and block width.
        row_factor = self.smallest_factor_greater_than_three(self.rows)
        col_factor = self.smallest_factor_greater_than_three(self.cols)
        print(f"Row factor: {row_factor}, Column factor: {col_factor}")

        # Number of blocks along each axis.
        num_row_splits = self.rows // row_factor
        num_col_splits = self.cols // col_factor

        # First split: cut along axis 0 into horizontal bands of rows.
        row_split = np.array_split(self.matrix, num_row_splits, axis=0)
        print(f"Matrix split into {len(row_split)} row-wise submatrices.")

        # Further split: cut every band along axis 1, giving a 2D grid of blocks.
        self.small_matrices = [
            np.array_split(band, num_col_splits, axis=1) for band in row_split
        ]
        total_small_matrices = len(self.small_matrices) * num_col_splits
        print(f"Total small matrices after horizontal split: {total_small_matrices}")

        # Column seams: inside each band, pair the last column before every
        # internal column boundary with the first column after it.
        self.seam_matrices_cols = []
        for band in row_split:
            band_seams = []
            for j in range(num_col_splits - 1):
                boundary = (j + 1) * col_factor
                last_col = band[:, boundary - 1].reshape(-1, 1)
                next_col = band[:, boundary].reshape(-1, 1)
                band_seams.append(np.hstack((last_col, next_col)))
            self.seam_matrices_cols.append(band_seams)
        print(f"Generated seam matrices for columns: {len(self.seam_matrices_cols)} sets.")

        # Row seams: for every pair of vertically adjacent blocks, pair the
        # last row of the upper block with the first row of the lower block.
        self.seam_matrices_rows = []
        for i in range(num_row_splits - 1):
            pair_seams = []
            for j in range(num_col_splits):
                top_last_row = self.small_matrices[i][j][-1, :].reshape(1, -1)
                bottom_first_row = self.small_matrices[i + 1][j][0, :].reshape(1, -1)
                pair_seams.append(np.vstack((top_last_row, bottom_first_row)))
            self.seam_matrices_rows.append(pair_seams)
        print(f"Generated seam matrices for rows: {len(self.seam_matrices_rows)} sets.")

    def visualize_splits(self, output_filename="matrix_split.png"):
        """
        Renders the matrix as a numbered grid and saves it as a PNG image.

        Every cell shows its value and is outlined with thin black lines.
        Block boundaries are overdrawn with thick green lines (between
        block-rows) and thick red lines (between block-columns). The figure is
        sized at one inch per cell and saved at 300 dpi, so very large
        matrices produce very large images.

        Parameters:
            output_filename (str): The filename for the saved PNG image.

        Raises:
            RuntimeError: If ``generate_matrix`` has not been called yet.
        """
        if self.matrix is None:
            raise RuntimeError(
                "generate_matrix() must be called before visualize_splits()."
            )

        row_factor = self.smallest_factor_greater_than_three(self.rows)
        col_factor = self.smallest_factor_greater_than_three(self.cols)
        num_row_splits = self.rows // row_factor
        num_col_splits = self.cols // col_factor

        fig, ax = plt.subplots(figsize=(self.cols, self.rows))
        try:
            ax.set_xlim(0, self.cols)
            ax.set_ylim(0, self.rows)
            # Invert y-axis to have the first row at the top
            ax.invert_yaxis()

            # Draw the cell borders as explicit line collections. ax.grid()
            # cannot be used here: grid lines belong to the axis ticks, and
            # ax.axis('off') below stops the axes' tick artists (and therefore
            # their grid lines) from being drawn.
            ax.hlines(np.arange(self.rows + 1), 0, self.cols,
                      colors='black', linewidth=1, zorder=1)
            ax.vlines(np.arange(self.cols + 1), 0, self.rows,
                      colors='black', linewidth=1, zorder=1)

            # Add numbers to each cell with a progress bar
            total_cells = self.rows * self.cols
            with tqdm(total=total_cells, desc="Plotting cells", unit="cell") as pbar:
                for i in range(self.rows):
                    for j in range(self.cols):
                        ax.text(j + 0.5, i + 0.5, str(self.matrix[i, j]),
                                va='center', ha='center', fontsize=8)
                        pbar.update(1)

            # Horizontal split lines between block-rows - Green
            with tqdm(total=num_row_splits - 1, desc="Drawing row splits", unit="split") as pbar:
                for i in range(1, num_row_splits):
                    ax.axhline(i * row_factor, color='green', linewidth=2, zorder=2)
                    pbar.update(1)

            # Vertical split lines between block-columns - Red
            with tqdm(total=num_col_splits - 1, desc="Drawing column splits", unit="split") as pbar:
                for j in range(1, num_col_splits):
                    ax.axvline(j * col_factor, color='red', linewidth=2, zorder=2)
                    pbar.update(1)

            # Hide ticks, tick labels and spines; the explicit lines above stay visible.
            ax.axis('off')
            fig.tight_layout()

            # Save the figure
            fig.savefig(output_filename, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)
        print(f"Visualization saved as '{output_filename}'.")

def validate_input(input_str):
    """
    Validates that the input string contains two integers greater than or equal to 4.

    Parameters:
        input_str (str): The input string containing two whitespace-separated numbers.

    Returns:
        tuple: A tuple containing two integers (rows, cols).

    Raises:
        ValueError: If input is invalid. Every message starts with
            "Invalid input: ". When the failure comes from parsing, the
            original exception is attached as ``__cause__``.
    """
    try:
        parts = input_str.strip().split()
    except AttributeError as err:
        # Non-string input (e.g. None) is reported as a validation error too.
        raise ValueError(f"Invalid input: {err}") from err

    if len(parts) != 2:
        raise ValueError("Invalid input: Please enter exactly two numbers separated by space.")

    try:
        rows, cols = map(int, parts)
    except ValueError as err:
        raise ValueError(f"Invalid input: {err}") from err

    if rows < 4 or cols < 4:
        raise ValueError("Invalid input: Both dimensions must be at least 4.")
    return rows, cols

def main():
    """
    Main function to execute the matrix splitting algorithm.

    Reads "<rows> <cols>" from standard input, validates it, then generates,
    splits and visualizes the matrix. Exits with status 1 if input is missing
    (stdin closed) or invalid; the error message goes to stderr.
    """
    print("Matrix Splitting Algorithm with Enhanced Visualization")
    print("------------------------------------------------------")
    try:
        user_input = input("Enter the number of rows and columns (e.g., '8 12'): ")
    except EOFError:
        # input() raises EOFError when stdin is closed, e.g. in a
        # non-interactive run; report it instead of dumping a traceback.
        print("Invalid input: no input received.", file=sys.stderr)
        sys.exit(1)

    try:
        rows, cols = validate_input(user_input)
    except ValueError as ve:
        print(ve, file=sys.stderr)
        sys.exit(1)

    splitter = MatrixSplitter(rows, cols)
    splitter.generate_matrix()
    splitter.split_matrix()
    splitter.visualize_splits()

if __name__ == "__main__":
    main()

Matrix Splitting Algorithm with Enhanced Visualization
------------------------------------------------------
Enter the number of rows and columns (e.g., '8 12'): 8 12
Generated matrix of size 8x12.
Row factor: 4, Column factor: 4
Matrix split into 2 row-wise submatrices.
Total small matrices after horizontal split: 6
Generated seam matrices for columns: 2 sets.
Generated seam matrices for rows: 1 sets.


Plotting cells: 100%|██████████| 96/96 [00:00<00:00, 4477.36cell/s]
Drawing row splits: 100%|██████████| 1/1 [00:00<00:00, 338.25split/s]
Drawing column splits: 100%|██████████| 2/2 [00:00<00:00, 789.81split/s]


Visualization saved as 'matrix_split.png'.
