In [None]:
# import settings and functions
%run ./../../imports.ipynb

print('Importing dataset creation functions...')

In [None]:
def make_grid_linear_dataset(g_max, how_many, csv_name="grid_dataset", k=1.0, plot_dataset=False):
    """Generate a 2-D grid of gradients with linear fluxes and save it as CSV.

    Gradients (g_x, g_y) are sampled on a regular ``how_many`` x ``how_many``
    grid spanning ``[-g_max, g_max]`` on each axis. Fluxes follow the linear
    relation ``q = -k * g`` component-wise. The plot labels name these the
    temperature gradient and the heat flux.

    Parameters
    ----------
    g_max : float
        Half-width of the sampled range; each axis spans ``[-g_max, g_max]``.
    how_many : int
        Number of sample points per axis; the CSV gets ``how_many**2`` rows.
    csv_name : str
        Output path passed verbatim to ``np.savetxt``. No ``.csv`` extension
        is appended automatically.
    k : float
        Proportionality constant of the linear relation ``q = -k * g``.
    plot_dataset : bool
        If True, also show two figures and save them to the current working
        directory:
        - ``simplest_G_x_Q_x.pdf``: q_x versus g_x along one axis.
        - ``simplest_G_x_G_y_Q_y.pdf``: 3-D scatter of q_y over the grid.

    Returns
    -------
    None
        The dataset is written to ``csv_name``. Its plain header line is
        ``gradx,grady,fluxx,fluxy``, followed by one row per grid point.
    """
    # Both axes use the same symmetric sampling.
    g_axis = np.linspace(-g_max, g_max, how_many)
    q_axis = -k * g_axis

    # Default 'xy' indexing: X varies along columns and Y along rows, so in
    # the row-major flattened output g_x changes fastest.
    X, Y = np.meshgrid(g_axis, g_axis)
    Q_X = -k * X
    Q_Y = -k * Y

    gx_flat = X.ravel()
    gy_flat = Y.ravel()
    qx_flat = Q_X.ravel()
    qy_flat = Q_Y.ravel()

    data = np.column_stack((gx_flat, gy_flat, qx_flat, qy_flat))

    # comments='' writes the header as a plain CSV header line instead of
    # prefixing it with '# '.
    header = "gradx,grady,fluxx,fluxy"
    np.savetxt(csv_name, data, delimiter=",", header=header, comments='')

    if plot_dataset:
        # 2D scatter: flux q_x against gradient g_x along a single axis.
        plt.figure()
        plt.scatter(g_axis, q_axis, s=2)
        plt.grid(True, ls=':')
        plt.xlabel(r'Temperature gradient $g_x^*$')
        plt.ylabel(r'Heat flux $q_x^*$')
        plt.tight_layout()
        plt.savefig('simplest_G_x_Q_x.pdf')
        plt.show()

        # 3D scatter of q_y over the (g_x, g_y) grid. Q_Y (not Q_X) is
        # plotted so the data matches the q_y axis label and file name.
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(gx_flat, gy_flat, qy_flat, s=2)
        ax.set_xlabel(r'$g_x^*$')
        ax.set_ylabel(r'$g_y^*$')
        ax.set_zlabel(r'$q_y^*$')
        plt.tight_layout()
        plt.savefig('simplest_G_x_G_y_Q_y.pdf')
        plt.show()