In [16]:
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from mpl_toolkits.mplot3d import Axes3D
import plotly.express as px

# Load your dataset
df = pd.read_csv("../data/train.csv")

# Define the column name for the sales price
sales_price_column = "SalePrice"

# Filter out the sales price column and non-numeric columns from the columns to plot against
columns_to_plot = [
    col
    for col in df.columns
    if col != sales_price_column
    and col != "Id"
    and pd.api.types.is_numeric_dtype(df[col])
]

# Loop through the DataFrame's columns and create a 3D plot for each unique pair
for i in tqdm(range(len(columns_to_plot)), desc="Creating plots", unit="plot"):
    for j in range(i + 1, len(columns_to_plot)):
        # Create a new figure for each plot
        fig = plt.figure(figsize=(15, 10))
        ax = fig.add_subplot(111, projection="3d")
        col_x = columns_to_plot[i]
        col_y = columns_to_plot[j]

        # Create a 3D scatter plot
        ax.scatter(df[col_x], df[col_y], df[sales_price_column])

        # Set the plot title and labels
        ax.set_title(f"{col_x}, {col_y}, {sales_price_column}")
        ax.set_xlabel(col_x)
        ax.set_ylabel(col_y)
        ax.set_zlabel(sales_price_column)

        # Save the plot to a file
        plt.savefig(
            f"../diagram/3D/3D_plot_{col_x}_vs_{col_y}_vs_{sales_price_column}.png"
        )

        # Close the figure to free memory
        plt.close(fig)

Creating plots:   0%|          | 0/36 [00:00<?, ?plot/s]

Creating plots: 100%|██████████| 36/36 [04:16<00:00,  7.12s/plot]
