In [1]:
import numpy as np
import setigen as stg
from blimpy import Waterfall
import matplotlib.pyplot as plt
import random
import os
from astropy import units as u
from tqdm import tqdm
from sklearn.metrics import silhouette_score
import tensorflow as tf
from tensorflow.keras import layers

os.environ["CUDA_VISIBLE_DEVICES"]="1"
num_classes = 100
num_samples_per_class = 1000


2023-08-23 14:30:15.789946: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def painting(data):
    all_data = []
    labels = []
    for c in range(num_classes):
        drift = 2*random.random()*(-1)**random.randint(0,2)
        snr = random.randint(100, 150)
        width = random.randint(20, 50)
        for s in range(num_samples_per_class):
            index = random.randint(0, data.shape[0]-1)
            window = data[index, :,:]
            
            start = random.randint(50, 180)
            
            frame = stg.Frame.from_data(df=2.7939677238464355*u.Hz,
                                        dt=18.253611008*u.s,
                                        fch1=1289*u.MHz,
                                        ascending=True,
                                        data=window)
            frame.add_signal(stg.constant_path(
                                        f_start=frame.get_frequency(index=start),
                                       drift_rate=drift*u.Hz/u.s),
                                      stg.constant_t_profile(level=frame.get_intensity(snr=snr)),
                                      stg.gaussian_f_profile(width=width*u.Hz),
                                      stg.constant_bp_profile(level=1))
            all_data.append(frame.data)
            labels.append(c)
    all_data = np.array(all_data)
    labels = np.vstack(labels)
    return all_data, labels

In [3]:
import cv2
import numpy as np


In [4]:
from tqdm import tqdm
import gc

In [5]:
from tqdm import tqdm
import gc
import keras
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.beta = 8
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(
            name="kl_loss"
        )
        self.kl_additional = tf.keras.losses.KLDivergence()
    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
        ]
    def gaussanity_loss(self, data, base):
        return self.kl_additional(data, base)
    
    def train_step(self, data_in):
        data = data_in
        print(data.shape)
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            
            
            total_loss = reconstruction_loss + self.beta * kl_loss
        
        mse_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.mse(data, reconstruction), axis=(1, 2)
                )
            )
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.reconstruction_loss_tracker.update_state(mse_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
        }
    def test_step(self, data_in):
        data, _ = data_in
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
            )
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))

        total_loss = reconstruction_loss + self.beta * kl_loss 
        
        mse_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.mse(data, reconstruction), axis=(1, 2)
                )
            )
        self.total_loss_tracker.update_state(total_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.reconstruction_loss_tracker.update_state(mse_loss)
        return {
            "test_loss": self.total_loss_tracker.result(),
            "test_kl_loss": self.kl_loss_tracker.result(),
            "test_reconstruction_loss": self.reconstruction_loss_tracker.result()
        }
    def __call__ (self, inputs):
        return self.decoder(self.encoder(inputs)[0])

In [6]:
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [7]:
latent_dim = 10
time_samples = 16
freq_sample =  256
encoder_inputs = keras.Input(shape=(time_samples, freq_sample, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=1, padding="same")(encoder_inputs)
x = layers.MaxPool2D(pool_size=(1, 2))(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(64, 3, activation="relu", strides=1, padding="same")(x)
x = layers.MaxPool2D(pool_size=(1, 2))(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(64, 3, activation="relu", strides=1, padding="same")(x)
x = layers.MaxPool2D(pool_size=(1, 2))(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(128, 3, activation="relu", strides=1, padding="same")(x)
x = layers.MaxPool2D(pool_size=(1, 2))(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(128, 3, activation="relu", strides=1, padding="same")(x)
x = layers.MaxPool2D(pool_size=(1, 2))(x)
x = layers.BatchNormalization()(x)
x_shape = x.shape
x = layers.Flatten()(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Dense(32, activation="relu")(x)
x = layers.BatchNormalization()(x)


z_mean = layers.Dense(32, activation="relu")(x)
z_mean = layers.BatchNormalization()(z_mean)
z_mean = layers.Dense(latent_dim, name="z_mean")(z_mean)

z_log_var = layers.Dense(32, activation="relu")(x)
z_log_var = layers.BatchNormalization()(z_log_var)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(z_log_var)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 16, 256, 1)  0           []                               
                                ]                                                                 
                                                                                                  
 conv2d (Conv2D)                (None, 16, 256, 32)  320         ['input_1[0][0]']                
                                                                                                  
 max_pooling2d (MaxPooling2D)   (None, 16, 128, 32)  0           ['conv2d[0][0]']                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 16, 128, 32)  128        ['max_pooling2d[0][0]']    

2023-08-23 14:30:20.258544: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-08-23 14:30:20.737174: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13888 MB memory:  -> device: 0, name: NVIDIA RTX A4000, pci bus id: 0000:61:00.0, compute capability: 8.6


 rmalization)                                                                                     
                                                                                                  
 flatten (Flatten)              (None, 16384)        0           ['batch_normalization_4[0][0]']  
                                                                                                  
 dense (Dense)                  (None, 256)          4194560     ['flatten[0][0]']                
                                                                                                  
 batch_normalization_5 (BatchNo  (None, 256)         1024        ['dense[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 dense_1 (Dense)                (None, 32)           8224        ['batch_normalization_5[0][0]']  
          

In [8]:
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(256, activation="relu")(latent_inputs)
x = layers.BatchNormalization()(x)
x = layers.Dense(x_shape[1]* x_shape[2]* x_shape[3], activation="relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Reshape((x_shape[1], x_shape[2], x_shape[3]))(x)
x = layers.Conv2DTranspose(128, 3, activation="relu", strides=2, padding="same")(x)
x = layers.MaxPool2D(pool_size=(2, 1))(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2DTranspose(128, 3, activation="relu", strides=2, padding="same")(x)
x = layers.MaxPool2D(pool_size=(2, 1))(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.MaxPool2D(pool_size=(2, 1))(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.MaxPool2D(pool_size=(2, 1))(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
x = layers.MaxPool2D(pool_size=(2, 1))(x)
x = layers.BatchNormalization()(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="linear", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 10)]              0         
                                                                 
 dense_4 (Dense)             (None, 256)               2816      
                                                                 
 batch_normalization_9 (Batc  (None, 256)              1024      
 hNormalization)                                                 
                                                                 
 dense_5 (Dense)             (None, 16384)             4210688   
                                                                 
 batch_normalization_10 (Bat  (None, 16384)            65536     
 chNormalization)                                                
                                                                 
 reshape (Reshape)           (None, 16, 8, 128)        0   

In [9]:
autoencoder = VAE(encoder, decoder)
autoencoder.compile(optimizer=keras.optimizers.Adam(learning_rate = 1e-3))
# autoencoder.load_weights("../b-vae/models/full-weights-"+'07-02-2023-15-19-23')
autoencoder.load_weights("../b-vae/models/full-weights-"+'08-23-2023-13-58-57')


<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7fa00e38b0a0>

In [10]:
def normalize(data):
    epsilon = 1
    min_val = data.min()
    data = data - min_val + epsilon
    new_data = np.log(data)
    min_val = data.min()
    max_val = data.max()
    final_data = (data - min_val) / (max_val - min_val)
    return final_data
    
def normalize_data(data):
    for i in tqdm(range(data.shape[0])):
        data[i,:,:] = normalize(data[i,:,:])
    return data

In [11]:
import os
from tqdm import tqdm
total_scores = []
for i in tqdm(range(10)):
    directory = os.fsencode( "../../../../../datax/scratch/pma/reverse_search/test/")
    count = 0
    data = []
    for folder in os.listdir(directory):
        print(folder)
        for subfolder in os.listdir(directory+folder):
            back = os.fsencode( "/")
            if '.' not in str(subfolder):
                for file in os.listdir(directory+folder+back+subfolder):
                    file_directory = str(os.path.join(directory+folder+back+subfolder, file)).replace('b', '').replace("'","")
                    if 'filtered.npy' in file_directory:
                        data.append(np.load(str(file_directory)))
                        count += 1
    data = np.vstack(data)
    print(data.shape)
    injected, labels = painting(data)
    
    print(injected.shape)
    
    input_data = np.expand_dims(normalize_data(injected), axis = -1)
    del data
    gc.collect()
    input_data = normalize_data(input_data)
    print(input_data[0,:,:].max(), input_data[0,:,:].min())
    print(input_data.shape)
    features = []
    # for i in range(1,101):
    tensor = tf.convert_to_tensor(input_data, dtype=tf.float32)
    X = autoencoder.encoder.predict(tensor, batch_size= 1024)[0]
    # del input_data
    gc.collect()
    # features = np.vstack(features)
    score = silhouette_score(X, labels = labels[:, 0])
    print("SCORE IS: ", score)
    total_scores.append(score)

  0%|                                                                                              | 0/10 [00:00<?, ?it/s]

b'HIP104887-1850'
b'HIP87579-1008'
b'clustering_tests'
(347064, 16, 256)
(100000, 16, 256)



  0%|                                                                                          | 0/100000 [00:00<?, ?it/s][A
  1%|▊                                                                           | 1011/100000 [00:00<00:09, 10104.90it/s][A
  2%|█▌                                                                          | 2029/100000 [00:00<00:09, 10147.49it/s][A
  3%|██▎                                                                         | 3048/100000 [00:00<00:09, 10164.65it/s][A
  4%|███                                                                         | 4065/100000 [00:00<00:09, 10123.74it/s][A
  5%|███▊                                                                        | 5084/100000 [00:00<00:09, 10144.58it/s][A
  6%|████▋                                                                       | 6104/100000 [00:00<00:09, 10161.34it/s][A
  7%|█████▍                                                                      | 7122/100000 [00:00<00:09, 10166.25

1.0 0.0
(100000, 16, 256, 1)


2023-08-23 14:34:51.904642: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8401


 6/98 [>.............................] - ETA: 2s 

2023-08-23 14:34:54.615816: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:630] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.




 10%|████████▎                                                                          | 1/10 [06:58<1:02:49, 418.88s/it]

SCORE IS:  -0.12153472
b'HIP104887-1850'
b'HIP87579-1008'
b'clustering_tests'
(347064, 16, 256)
(100000, 16, 256)



  0%|                                                                                          | 0/100000 [00:00<?, ?it/s][A
  1%|▊                                                                           | 1025/100000 [00:00<00:09, 10245.91it/s][A
  2%|█▌                                                                          | 2050/100000 [00:00<00:09, 10065.61it/s][A
  3%|██▎                                                                         | 3064/100000 [00:00<00:09, 10098.20it/s][A
  4%|███                                                                         | 4074/100000 [00:00<00:09, 10032.74it/s][A
  5%|███▊                                                                        | 5089/100000 [00:00<00:09, 10071.45it/s][A
  6%|████▋                                                                       | 6097/100000 [00:00<00:09, 10029.18it/s][A
  7%|█████▍                                                                      | 7109/100000 [00:00<00:09, 10057.72

1.0 0.0
(100000, 16, 256, 1)


 20%|█████████████████                                                                    | 2/10 [12:51<50:37, 379.67s/it]

SCORE IS:  -0.1277936
b'HIP104887-1850'
b'HIP87579-1008'
b'clustering_tests'
(347064, 16, 256)
(100000, 16, 256)



  0%|                                                                                          | 0/100000 [00:00<?, ?it/s][A
  1%|▊                                                                           | 1016/100000 [00:00<00:09, 10158.07it/s][A
  2%|█▌                                                                          | 2032/100000 [00:00<00:09, 10066.34it/s][A
  3%|██▎                                                                          | 3039/100000 [00:00<00:09, 9923.33it/s][A
  4%|███                                                                         | 4057/100000 [00:00<00:09, 10019.88it/s][A
  5%|███▊                                                                        | 5078/100000 [00:00<00:09, 10084.48it/s][A
  6%|████▋                                                                       | 6095/100000 [00:00<00:09, 10111.47it/s][A
  7%|█████▍                                                                      | 7109/100000 [00:00<00:09, 10118.70

1.0 0.0
(100000, 16, 256, 1)


 30%|█████████████████████████▌                                                           | 3/10 [17:45<39:45, 340.77s/it]

SCORE IS:  -0.12080323
b'HIP104887-1850'
b'HIP87579-1008'
b'clustering_tests'
(347064, 16, 256)
(100000, 16, 256)



  0%|                                                                                          | 0/100000 [00:00<?, ?it/s][A
  1%|▊                                                                            | 1000/100000 [00:00<00:09, 9995.55it/s][A
  2%|█▌                                                                           | 2000/100000 [00:00<00:09, 9807.39it/s][A
  3%|██▎                                                                          | 2981/100000 [00:00<00:09, 9765.53it/s][A
  4%|███                                                                          | 4004/100000 [00:00<00:09, 9944.83it/s][A
  5%|███▊                                                                        | 5019/100000 [00:00<00:09, 10017.17it/s][A
  6%|████▌                                                                       | 6030/100000 [00:00<00:09, 10048.39it/s][A
  7%|█████▎                                                                      | 7048/100000 [00:00<00:09, 10089.96

1.0 0.0
(100000, 16, 256, 1)


 40%|██████████████████████████████████                                                   | 4/10 [23:24<33:59, 339.96s/it]

SCORE IS:  -0.11706964
b'HIP104887-1850'
b'HIP87579-1008'
b'clustering_tests'
(347064, 16, 256)
(100000, 16, 256)



  0%|                                                                                          | 0/100000 [00:00<?, ?it/s][A
  1%|▊                                                                           | 1030/100000 [00:00<00:09, 10292.80it/s][A
  2%|█▌                                                                          | 2060/100000 [00:00<00:09, 10247.03it/s][A
  3%|██▎                                                                         | 3085/100000 [00:00<00:09, 10228.18it/s][A
  4%|███                                                                         | 4108/100000 [00:00<00:09, 10061.65it/s][A
  5%|███▉                                                                        | 5131/100000 [00:00<00:09, 10121.40it/s][A
  6%|████▋                                                                       | 6155/100000 [00:00<00:09, 10160.99it/s][A
  7%|█████▍                                                                      | 7179/100000 [00:00<00:09, 10184.44

1.0 0.0
(100000, 16, 256, 1)


 50%|██████████████████████████████████████████▌                                          | 5/10 [28:46<27:47, 333.41s/it]

SCORE IS:  -0.12145381
b'HIP104887-1850'
b'HIP87579-1008'
b'clustering_tests'
(347064, 16, 256)
(100000, 16, 256)



  0%|                                                                                          | 0/100000 [00:00<?, ?it/s][A
  1%|▊                                                                           | 1024/100000 [00:00<00:09, 10228.79it/s][A
  2%|█▌                                                                          | 2047/100000 [00:00<00:09, 10178.87it/s][A
  3%|██▎                                                                         | 3065/100000 [00:00<00:09, 10162.40it/s][A
  4%|███                                                                         | 4082/100000 [00:00<00:09, 10144.10it/s][A
  5%|███▉                                                                        | 5103/100000 [00:00<00:09, 10165.93it/s][A
  6%|████▋                                                                       | 6120/100000 [00:00<00:09, 10166.37it/s][A
  7%|█████▍                                                                      | 7138/100000 [00:00<00:09, 10167.94

1.0 0.0
(100000, 16, 256, 1)


 60%|███████████████████████████████████████████████████                                  | 6/10 [34:08<21:58, 329.68s/it]

SCORE IS:  -0.12062328
b'HIP104887-1850'
b'HIP87579-1008'
b'clustering_tests'
(347064, 16, 256)
(100000, 16, 256)



  0%|                                                                                          | 0/100000 [00:00<?, ?it/s][A
  1%|▊                                                                           | 1013/100000 [00:00<00:09, 10121.76it/s][A
  2%|█▌                                                                          | 2033/100000 [00:00<00:09, 10164.08it/s][A
  3%|██▎                                                                          | 3050/100000 [00:00<00:09, 9991.58it/s][A
  4%|███                                                                         | 4070/100000 [00:00<00:09, 10069.50it/s][A
  5%|███▊                                                                        | 5094/100000 [00:00<00:09, 10130.31it/s][A
  6%|████▋                                                                       | 6118/100000 [00:00<00:09, 10165.64it/s][A
  7%|█████▍                                                                       | 7135/100000 [00:00<00:09, 9993.54

1.0 0.0
(100000, 16, 256, 1)


 70%|███████████████████████████████████████████████████████████▍                         | 7/10 [39:24<16:15, 325.05s/it]

SCORE IS:  -0.12067214
b'HIP104887-1850'
b'HIP87579-1008'
b'clustering_tests'
(347064, 16, 256)
(100000, 16, 256)



  0%|                                                                                          | 0/100000 [00:00<?, ?it/s][A
  1%|▊                                                                           | 1028/100000 [00:00<00:09, 10275.89it/s][A
  2%|█▌                                                                          | 2056/100000 [00:00<00:09, 10192.02it/s][A
  3%|██▎                                                                         | 3079/100000 [00:00<00:09, 10204.76it/s][A
  4%|███                                                                         | 4100/100000 [00:00<00:09, 10184.79it/s][A
  5%|███▉                                                                        | 5119/100000 [00:00<00:09, 10181.10it/s][A
  6%|████▋                                                                       | 6138/100000 [00:00<00:09, 10174.93it/s][A
  7%|█████▍                                                                      | 7156/100000 [00:00<00:09, 10171.00

1.0 0.0
(100000, 16, 256, 1)


 80%|████████████████████████████████████████████████████████████████████                 | 8/10 [44:33<10:40, 320.08s/it]

SCORE IS:  -0.12660365
b'HIP104887-1850'
b'HIP87579-1008'
b'clustering_tests'
(347064, 16, 256)
(100000, 16, 256)



  0%|                                                                                          | 0/100000 [00:00<?, ?it/s][A
  1%|▊                                                                           | 1005/100000 [00:00<00:09, 10039.38it/s][A
  2%|█▌                                                                          | 2026/100000 [00:00<00:09, 10134.27it/s][A
  3%|██▎                                                                         | 3048/100000 [00:00<00:09, 10172.35it/s][A
  4%|███                                                                         | 4069/100000 [00:00<00:09, 10186.64it/s][A
  5%|███▊                                                                        | 5088/100000 [00:00<00:09, 10068.60it/s][A
  6%|████▋                                                                       | 6109/100000 [00:00<00:09, 10115.07it/s][A
  7%|█████▍                                                                      | 7133/100000 [00:00<00:09, 10153.61

1.0 0.0
(100000, 16, 256, 1)


 90%|████████████████████████████████████████████████████████████████████████████▌        | 9/10 [49:46<05:18, 318.02s/it]

SCORE IS:  -0.12378133
b'HIP104887-1850'
b'HIP87579-1008'
b'clustering_tests'
(347064, 16, 256)
(100000, 16, 256)



  0%|                                                                                          | 0/100000 [00:00<?, ?it/s][A
  1%|▊                                                                           | 1031/100000 [00:00<00:09, 10306.89it/s][A
  2%|█▌                                                                          | 2062/100000 [00:00<00:09, 10251.00it/s][A
  3%|██▎                                                                         | 3088/100000 [00:00<00:09, 10235.15it/s][A
  4%|███▏                                                                        | 4112/100000 [00:00<00:09, 10229.59it/s][A
  5%|███▉                                                                        | 5135/100000 [00:00<00:09, 10220.58it/s][A
  6%|████▋                                                                       | 6158/100000 [00:00<00:09, 10218.28it/s][A
  7%|█████▍                                                                      | 7183/100000 [00:00<00:09, 10227.94

1.0 0.0
(100000, 16, 256, 1)


100%|████████████████████████████████████████████████████████████████████████████████████| 10/10 [54:56<00:00, 329.63s/it]

SCORE IS:  -0.12686808





In [12]:
print(np.mean(total_scores))
print(np.std(total_scores))

-0.122720346
0.0032612237
