In [13]:
# This code is written by Yara Al-Shorman
# Date created: May 18 2022
# Last modified: May 22 2022
# Github repo: https://github.com/YaraAlShorman/Research-spring-22
# The purpose of this code is to predict the bulkflows for a large given set of galaxies (using machine learning)

### imports

In [34]:
# imports
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
# import matplotlib.pyplot as plt

### data preprocessing

In [22]:
# stacking all of the data into one humungous tensor
# I am using tensors because they can be ragged and can be stacked
# these means we won't have to cut off our data to standardize array sizes
data = np.load(f'C:/Users/yaras/Documents/Research/Feldman/rotated-outerrim-cz-rand/rotated-0-error-40.npy.npz')
input_data = tf.convert_to_tensor(data['data'])
scaler = MinMaxScaler()
scaler.fit(input_data)
input_data = scaler.transform(input_data)
input_data = tf.expand_dims(input_data, axis=0)
input_data = tf.RaggedTensor.from_tensor(input_data)

output_data = tf.convert_to_tensor(data['header'])
output_data = tf.expand_dims(output_data, axis=0)
# output_data = tf.RaggedTensor.from_tensor(output_data)

for i in range(1,10): # limit 3000
        array_data = np.load(f'C:/Users/yaras/Documents/Research/Feldman/rotated-outerrim-cz-rand/rotated-{i}-error-40.npy.npz')
        # 'data' data
        temp = tf.convert_to_tensor(array_data['data'], dtype='float64')
        scaler = MinMaxScaler()
        scaler.fit(temp)
        temp = scaler.transform(temp)
        temp = tf.expand_dims(temp, axis=0)  # dimensions have to be expandad to be able to concat along the outer-dimension 
        input_data = tf.concat([input_data, temp], axis=0)  # concat adds to existing dimensions, does not create new ones
        # 'header' data (bulkflows)
        temp = array_data['header']
        temp = np.expand_dims(temp, axis=0)
        output_data = np.concatenate([output_data, temp], axis=0)
        
# next is scaling

In [31]:
# normalizing/scaling the data
scaler = MinMaxScaler()
scaler.fit(output_data)
output_data = scaler.transform(output_data)
output_data = tf.convert_to_tensor(output_data)

# data is ready

In [39]:
def train_test_split_tensors(X, y, **options):
    """
    encapsulation for the sklearn.model_selection.train_test_split function
    in order to split tensors objects and return tensors as output

    :param X: tensorflow.Tensor object
    :param y: tensorflow.Tensor object
    :dict **options: typical sklearn options are available, such as test_size and train_size
    """

    from sklearn.model_selection import train_test_split

    X_train, X_test, y_train, y_test = train_test_split(np.asarray(X), np.asarray(y), **options)

    X_train, X_test = tf.constant(X_train), tf.constant(X_test)
    y_train, y_test = tf.constant(y_train), tf.constant(y_test)

    del(train_test_split)

    return X_train, X_test, y_train, y_test

In [40]:
# train-test splitting
train_test_split_tensors(input_data, output_data, train_size=0.85, random_state=0, shuffle=True)

TypeError: Singleton array array(<tf.RaggedTensor [[[0.04476167604186792, 0.4319348470437463, 0.6600869858736284,
   0.10562281276603319],
  [0.04470801316027669, 0.2964123450175338, 0.07774368327560849,
   0.45829554731734073],
  [0.03172237834438992, 0.3748306898537166, 0.5074990849844845,
   0.07828420800789393],
  ...,
  [0.9203440848389758, 0.9166299008698995, 0.9097912849111669,
   0.46985195993497325],
  [0.9973402144941484, 0.9373032164146093, 0.03445411172523899,
   0.13522825316059572],
  [1.0, 0.9376738124403152, 0.8614850196591373, 0.15025120160109273]],
 [[0.04476167604186792, 0.35549937554079314, 0.11415318614489606,
   0.2616085383021353],
  [0.03172237834438992, 0.35488049713897984, 0.11140956913496375,
   0.1683793761334571],
  [0.04253692822809806, 0.313280451454661, 0.6326202464701759,
   0.1588511671586666],
  ...,
  [0.9102376093501364, 0.9098281299488606, 0.741198125000201,
   0.30017373612523124],
  [0.9973402144941484, 0.9216550949622406, 0.27213734061030426,
   0.27830875813205336],
  [1.0, 0.9558731697428202, 0.1940244049557851, 0.3545961199969835]],
 [[0.04476167604186792, 0.3196367976397032, 0.9130267340038772,
   0.7966851751156288],
  [0.04470801316027669, 0.2284764641929149, 0.15972317664037589,
   0.4240256444944078],
  [0.03172237834438992, 0.2992224102108061, 0.9153222036717249,
   0.7034497600504326],
  ...,
  [0.9102376093501364, 0.9402396912503876, 0.8089464905308076,
   0.3237117017379686],
  [0.9973402144941484, 0.979874900638495, 0.04955081096588493,
   0.6725864980281844],
  [1.0, 0.9664629965211011, 0.053321559880914415, 0.8237698281916794]],
 [[0.04476167604186792, 0.36012233281780803, 0.13103016470734263,
   0.7145934026031384],
  [0.04470801316027669, 0.2769216248057358, 0.6480294341489201,
   0.7407476689313115],
  [0.03172237834438992, 0.3562910033672755, 0.18969825676793461,
   0.7302267796632919],
  ...,
  [0.9102376093501364, 0.8962580121375718, 0.3521967418883952,
   0.42620960835882216],
  [0.9973402144941484, 0.8939487646532025, 0.11304839376406023,
   0.9485280573400574],
  [1.0, 0.9213137103176692, 0.03571380939654273, 0.8065934418548002]],
 [[0.04476167604186792, 0.3839275087359342, 0.21026212493907542,
   0.42214979035263756],
  [0.04470801316027669, 0.3317848331679407, 0.9132596360121805,
   0.3782634451199948],
  [0.04253692822809806, 0.34285204138544145, 0.06483000342731782,
   0.751883862196961],
  ...,
  [0.9102376093501364, 0.9592404877001837, 0.16964643152359848,
   0.9284803197151871],
  [0.9973402144941484, 0.9410588429602806, 0.0911742109866846,
   0.35768668067225723],
  [1.0, 0.8888228363210113, 0.16785868250260066, 0.2801469583893811]],
 [[0.04476167604186792, 0.37949106705094904, 0.0945291810014463,
   0.5927720979168991],
  [0.04470801316027669, 0.23527688172546135, 0.5720392669425723,
   0.8586121698128514],
  [0.03172237834438992, 0.34230884209742873, 0.1426245913654885,
   0.6016329788833947],
  ...,
  [0.9102376093501364, 0.9234944213963323, 0.32485625434201015,
   0.3641541755172344],
  [0.9973402144941484, 0.9231479933972699, 0.11669997517337549,
   0.8249036581802016],
  [1.0, 0.9355648623882642, 0.03677012595262741, 0.7162842934067022]],
 [[0.04476167604186792, 0.4256653570102338, 0.22430811160017466,
   0.6119172409789656],
  [0.04470801316027669, 0.3460209453276666, 0.42286311173742747,
   0.13609545105152715],
  [0.03172237834438992, 0.40414054758960427, 0.18661195118566337,
   0.5540767724161276],
  ...,
  [0.9371342293961398, 0.9097622698097769, 0.08523832433000345,
   0.1336911308482866],
  [0.9973402144941484, 0.901550977555909, 0.26638835568549735,
   0.3914512691276043],
  [1.0, 0.9276180622354988, 0.2954059949948066, 0.5321616920329105]],
 [[0.04476167604186792, 0.3941314336850845, 0.7706529306830818,
   0.7021741919683584],
  [0.04470801316027669, 0.3203773599777975, 0.5037745456810879,
   0.45145646007008383],
  [0.03172237834438992, 0.2902176011593989, 0.7509313559599824,
   0.7909652925082666],
  ...,
  [0.9102376093501364, 0.9264160675126405, 0.19059688660319457,
   0.7847580354191271],
  [0.9973402144941484, 0.9331055660885013, 0.6392277966017237,
   0.6426383736264991],
  [1.0, 0.9307749669095937, 0.7122177860947065, 0.5816677191641729]],
 [[0.04476167604186792, 0.3761044400965412, 0.2948071591983945,
   0.7757858779881248],
  [0.03172237834438992, 0.3677164921562901, 0.31026058040246884,
   0.6851692458250461],
  [0.04253692822809806, 0.3731258462277167, 0.33979003951909437,
   0.35966765462511086],
  ...,
  [0.9102376093501364, 0.9138068683539999, 0.23300940900771433,
   0.2789733350293794],
  [0.9973402144941484, 0.9001145951175373, 0.4462894484779493,
   0.7006424502097648],
  [1.0, 0.9680607616363606, 0.4224840370213973, 0.8493486530177722]],
 [[0.04476167604186792, 0.42042858714206455, 0.5642182038925748,
   0.6983340284169863],
  [0.04470801316027669, 0.3530431745005731, 0.15846862419825208,
   0.7127967042244923],
  [0.03172237834438992, 0.40580997805212293, 0.6151754583354292,
   0.7487436171514269],
  ...,
  [0.9203440848389758, 0.9318805269950832, 0.3175406733368905,
   0.5380717488119464],
  [0.9973402144941484, 0.9322191436791121, 0.4348470125189277,
   0.8830095468838342],
  [1.0, 0.8839729434841583, 0.46431531657742897, 0.734822831930793]]]>,
      dtype=object) cannot be considered a valid collection.

In [32]:
print(input_data.shape[0], output_data.shape[0])

10 10


### model creation

In [None]:
# model creations

model0 = tf.keras.models.Sequential(
    [
        # add layers here
        # figure out what layers to use
    ]
)
# add model stuff here