# Imports and Helper Functions

In [5]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import csv

# Data Generation and Classification

In [1]:
def gen3Ddata(number):
    """Generate random labelled 3-D points and random spheroid parameters.

    Draws ``number`` points with x, y and z each sampled uniformly from
    [-1.5, 1.5), then draws four parameters ``[a, b, c, r]`` uniformly from
    [0, 10).  A point is labelled 1 when ``a*x^2 + b*y^2 + c*z^2 < 1`` and
    -1 otherwise.

    NOTE: ``r`` (``params[3]``) is generated and returned, but the
    classification below compares against the constant 1 and never reads it.

    Args:
        number: how many points to generate.

    Returns:
        A tuple ``(data, params)``: ``data`` is a list of rows of the form
        ``[bias, x, y, z, classification]`` with ``bias`` always 1, and
        ``params`` is the list ``[a, b, c, r]``.
    """
    low, high = -1.5, 1.5
    draw = np.random.uniform

    # Build every row up front with a placeholder label of 0; the real label
    # is filled in once the spheroid coefficients are known.
    samples = [
        [1, draw(low, high), draw(low, high), draw(low, high), 0]
        for _ in range(number)
    ]

    # a, b, c and r of the figure (drawn after the points, in this order).
    params = [draw(0, 10) for _ in range(4)]
    coeffs = params[0:3]

    for row in samples:
        squares = [row[1]**2, row[2]**2, row[3]**2]
        # Inside the spheroid -> 1, on or outside it -> -1.
        row[4] = 1 if np.dot(squares, coeffs) < 1 else -1

    return samples, params

In [2]:
def outputSphere(params, num_points=100000):
    """Sample points on the spheroid surface ``a*x^2 + b*y^2 + c*z^2 = 1``.

    This is the boundary of the region that ``gen3Ddata`` labels as 1, so the
    returned points can be plotted as the decision surface.  Only the first
    three entries of ``params`` are used; ``r`` is ignored here because the
    classifier in this file compares against the constant 1.

    Args:
        params: spheroid parameters ``[a, b, c, r]``; ``a``, ``b`` and ``c``
            must be strictly positive.
        num_points: number of surface points to generate (default 100000).

    Returns:
        A NumPy array of shape ``(3, num_points)`` whose rows are the x, y
        and z coordinates of the sampled points.

    Raises:
        ValueError: if fewer than three coefficients are supplied or any of
            ``a``, ``b``, ``c`` is not strictly positive (the surface would
            not be a bounded spheroid).
    """
    coeffs = np.asarray(params[:3], dtype=float)
    if coeffs.shape != (3,) or not np.all(coeffs > 0):
        raise ValueError(
            "params must start with three strictly positive coefficients a, b, c"
        )

    # Normalised Gaussian vectors lie on the unit sphere.
    vec = np.random.randn(3, num_points)
    vec /= np.linalg.norm(vec, axis=0)

    # For a*x^2 + b*y^2 + c*z^2 = 1 the semi-axis along each coordinate is
    # 1/sqrt(coefficient), so the unit sphere is shrunk (not stretched) by
    # sqrt(coefficient) along each axis.
    vec /= np.sqrt(coeffs)[:, np.newaxis]

    return vec

# Output Data

In [3]:
def outputPoints(data, sphere, filename):
    """Write the labelled points and the surface points to a CSV file.

    The file starts with the header ``x_coords, y_coords, z_coords, type``.
    Each row of ``data`` is written as its x, y, z and classification
    (elements 1 to 4; the bias at element 0 is dropped).  The surface
    points follow, each tagged with type 0.

    Args:
        data: rows of the form ``[bias, x, y, z, classification]``.
        sphere: three equally long sequences holding the x, y and z
            coordinates of the surface points.
        filename: path of the CSV file to create or overwrite.
    """
    header = ['x_coords', 'y_coords', 'z_coords', 'type']
    with open(filename, mode='w', newline='') as out:
        writer = csv.writer(out)
        writer.writerow(header)

        # Labelled samples: X Y Z classification (-1 or 1).
        writer.writerows([row[1], row[2], row[3], row[4]] for row in data)

        # Surface samples, marked with type 0.
        xs, ys, zs = sphere
        for k, x_val in enumerate(xs):
            writer.writerow([x_val, ys[k], zs[k], 0])

In [6]:
# Generate 1000 labelled points plus the spheroid parameters, then export the
# points together with the sampled surface to a CSV file.
data, params = gen3Ddata(1000)
outputPoints(
    data=data,
    sphere=outputSphere(params),
    filename="NonlinearData.csv",
)