In [52]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from torchvision.transforms import ToTensor
from jax.scipy.special import logsumexp
#import jax.opt as jopt
import pickle
import numpy as np

In [53]:
def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n, ))
    
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    
    print(keys)
    print(sizes[:-1])
    print(sizes[1:])
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], 
                                                            sizes[1:], 
                                                            keys)]

layer_sizes = [6, 100, 100, 1]
step_size = 0.01
num_epochs = 8
batch_size = 10000
n_targets = 1
params = init_network_params(layer_sizes, random.PRNGKey(0))

[[2285895361 1501764800]
 [1518642379 4090693311]
 [ 433833334 4221794875]
 [ 839183663 3740430601]]
[6, 100, 100]
[100, 100, 1]


In [54]:
def relu(x):
    return jnp.maximum(0, x)

def tanh(x):
    return jnp.tanh(x)

def selu(x):
    return jax.nn.selu(x)

def predict(params, image):
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [55]:
random_flattened_images = random.normal(random.PRNGKey(1), 
                                        (10, 6))

In [56]:
batched_predict = vmap(predict, in_axes = (None, 0))
batched_predict_alt = vmap(predict, in_axes = (0, 0))

In [57]:
batched_preds = batched_predict(params, random_flattened_images)

In [58]:
batched_preds.shape

(10, 1)

In [59]:
# def one_hot(x, k, dtype=jnp.float32):
#   """Create a one-hot encoding of x of size k."""
#   return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
# def accuracy(params, images, targets):
#     target_class = jnp.argmax(targets, axis=1)
#     predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
#     return jnp.mean(predicted_class == target_class)

# def loss(params, images, targets):
#     preds = batched_predict(params, images)
#     return -jnp.mean(preds * targets)
def huber_loss(target: float, pred: float, delta: float = 1.0) -> float:
    """Huber loss.

    Args:
    target: ground truth
    pred: predictions
    delta: radius of quadratic behavior
    Returns:
    loss value

    References:
    https://en.wikipedia.org/wiki/Huber_loss
    """
    abs_diff = jnp.abs(target - pred)
    return jnp.where(abs_diff > delta,
                   delta * (abs_diff - .5 * delta),
                   0.5 * abs_diff ** 2)

def loss(params, features, targets, delta = 1.):
    print('preds shape')
    preds = batched_predict(params, features)
    print(preds.shape)
    my_huber_loss = jnp.mean(huber_loss(targets, preds, delta = delta))
    print('loss shape')
    print(my_huber_loss.shape)
    return my_huber_loss

@jit
def update(params, x, y):
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

In [60]:
import os
folder = '/users/afengler/data/' + \
                       'proj_lan_pipeline/LAN_scripts/data/lan_mlp/' + \
                       'training_data_0_nbins_0_n_2000/ddm/'
file_list = [folder + file_ for file_ in os.listdir(folder)]

In [61]:
example_file = pickle.load(open(file_list[0], 'rb'))
print(example_file['data'].shape)
print(len(file_list))


(200000, 6)
101


In [62]:
n_steps_per_epoch = np.floor(len(file_list) * example_file['data'].shape[0] / batch_size)

In [63]:
n_steps_per_epoch

2020.0

In [64]:
import torch
class DatasetTorch(torch.utils.data.Dataset):
    def __init__(self, 
                file_IDs, 
                batch_size = 1,
                label_prelog_cutoff_low = 1e-7,
                label_prelog_cutoff_high = None
                ):

        # Initialization
        self.batch_size = batch_size
        self.file_IDs = file_IDs
        self.indexes = np.arange(len(self.file_IDs))
        self.label_prelog_cutoff_low = label_prelog_cutoff_low
        self.label_prelog_cutoff_high = label_prelog_cutoff_high
        self.tmp_data = None
        self.__init_file_shape()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor((len(self.file_IDs) * self.file_shape_dict['inputs'][0]) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch

        # Find list of IDs
        if index % self.batches_per_file == 0 or self.tmp_data == None:
            self.__load_file(file_index = self.indexes[index // self.batches_per_file])

        # Generate data
        batch_ids = np.arange(((index % self.batches_per_file) * self.batch_size), ((index % self.batches_per_file) + 1) * self.batch_size, 1)
        X, y = self.__data_generation(batch_ids)
        return X.astype(jnp.float32), y.astype(jnp.float32)

    def __load_file(self, file_index):
        self.tmp_data = pickle.load(open(self.file_IDs[file_index], 'rb'))
        shuffle_idx = np.random.choice(self.tmp_data['data'].shape[0], size = self.tmp_data['data'].shape[0], replace = True)
        self.tmp_data['data'] = self.tmp_data['data'][shuffle_idx, :]
        self.tmp_data['labels'] = self.tmp_data['labels'][shuffle_idx]       
        return

    def __init_file_shape(self):
        init_file = pickle.load(open(self.file_IDs[0], 'rb'))
        #print('Init file shape: ', init_file['data'].shape, init_file['labels'].shape)
        
        self.file_shape_dict = {'inputs': init_file['data'].shape, 'labels': init_file['labels'].shape}
        self.batches_per_file = int(self.file_shape_dict['inputs'][0] / self.batch_size)
        self.input_dim = self.file_shape_dict['inputs'][1]
        
        if len(self.file_shape_dict['labels']) > 1:
            self.label_dim = self.file_shape_dict['labels'][1]
        else:
            self.label_dim = 1
        return

    def __data_generation(self, batch_ids = None):
        #print('passed datageneration')
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        X = np.squeeze(self.tmp_data['data'][batch_ids, :]) #tmp_file[batch_ids, :-1]
        y = self.tmp_data['labels'][batch_ids] #tmp_file[batch_ids, -1]
        #print(X.shape)
#         print(type(self.tmp_data['data']))
#         print(type(self.tmp_data['labels']))
#         print(type(X))
#         print(type(y))
        if self.label_prelog_cutoff_low is not None:
            y[y < np.log(self.label_prelog_cutoff_low)] = np.log(self.label_prelog_cutoff_low)
        
        if self.label_prelog_cutoff_high is not None:
            y[y > np.log(self.label_prelog_cutoff_high)] = np.log(self.label_prelog_cutoff_high)
#         print(type(y))
#         print(type(X))
#         print(type(y))
#         print(type(x))
        return X, y

In [68]:
torch_training_dataset = DatasetTorch(file_IDs = file_list,
                                      batch_size = batch_size)

In [75]:
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)
    
torch_training_dataloader = torch.utils.data.DataLoader(torch_training_dataset,
                                                        shuffle = True,
                                                        batch_size = 1,
                                                        num_workers = 0,
                                                        collate_fn = numpy_collate,
                                                        pin_memory = True)

In [78]:
for x, y in torch_training_dataloader:
    print(jnp.squeeze(x).shape)
    print(jnp.squeeze(y).shape)
    
    break

(10000, 6)
(10000,)


In [80]:
import time

for epoch in range(num_epochs):
    start_time = time.time()
    cnt = 0
    for x, y in torch_training_dataloader:
        #y = one_hot(y, n_targets)
        params = update(params, jnp.squeeze(x), jnp.squeeze(y))
        cnt += 1
        if (cnt % 100) == 0:
            print(cnt, 'of', n_steps_per_epoch, ' epochs ran')
    epoch_time = time.time() - start_time
    print(epoch_time)
#     for x, y in training_generator:
#         my_loss = 
    #train_acc = accuracy(params, train_images, train_labels)
    #test_acc = accuracy(params, test_images, test_labels)
#     print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
#     print("Training set accuracy {}".format(train_acc))
#     print("Test set accuracy {}".format(test_acc))

100 of 2020.0  epochs ran
200 of 2020.0  epochs ran
300 of 2020.0  epochs ran
400 of 2020.0  epochs ran
500 of 2020.0  epochs ran
600 of 2020.0  epochs ran
700 of 2020.0  epochs ran
800 of 2020.0  epochs ran
900 of 2020.0  epochs ran
1000 of 2020.0  epochs ran
1100 of 2020.0  epochs ran
1200 of 2020.0  epochs ran
1300 of 2020.0  epochs ran
1400 of 2020.0  epochs ran
1500 of 2020.0  epochs ran
1600 of 2020.0  epochs ran
1700 of 2020.0  epochs ran
1800 of 2020.0  epochs ran
1900 of 2020.0  epochs ran
2000 of 2020.0  epochs ran
179.55225157737732
100 of 2020.0  epochs ran
200 of 2020.0  epochs ran
300 of 2020.0  epochs ran
400 of 2020.0  epochs ran
500 of 2020.0  epochs ran
600 of 2020.0  epochs ran
700 of 2020.0  epochs ran
800 of 2020.0  epochs ran
900 of 2020.0  epochs ran
1000 of 2020.0  epochs ran
1100 of 2020.0  epochs ran
1200 of 2020.0  epochs ran
1300 of 2020.0  epochs ran
1400 of 2020.0  epochs ran
1500 of 2020.0  epochs ran
1600 of 2020.0  epochs ran
1700 of 2020.0  epochs ran


In [None]:
import numpy as np
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

class NumpyLoader(data.DataLoader):
    def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
        super(self.__class__, self).__init__(dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=numpy_collate,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn)

# This is applied when the __getitem__ method in the dataset (mnist_dataset below)
# is invoked
class FlattenAndCast(object):
    def __call__(self, pic):
        #print(pic)
        return np.ravel(np.array(pic, dtype=jnp.float32)).astype(jnp.float32)

In [None]:
# Define our dataset, using torch datasets
mnist_dataset = MNIST('data/mnist/', 
                      download=True,
                      transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, 
                                 batch_size=batch_size,
                                 num_workers=0)

In [116]:
# Get the full train dataset ( for checking accuracy while training)
train_images = np.array(mnist_dataset.data[500:, :, :]).reshape(len(mnist_dataset.data[500:]), - 1) 
train_labels = one_hot(np.array(mnist_dataset.targets[500:]), n_targets)                                                                  

# Get test dataset
test_images = np.array(mnist_dataset.data[:500, :, :]).reshape(len(mnist_dataset.data[:500]), - 1) 
test_labels = one_hot(np.array(mnist_dataset.targets[:500]), n_targets) 

In [117]:
import time

for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in training_generator:
        y = one_hot(y, n_targets)
        params = update(params, x, y)
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

Epoch 0 in 5.89 sec
Training set accuracy 0.942874014377594
Test set accuracy 0.9500000476837158
Epoch 1 in 5.57 sec
Training set accuracy 0.9526386857032776
Test set accuracy 0.956000030040741
Epoch 2 in 5.16 sec
Training set accuracy 0.960084080696106
Test set accuracy 0.9600000381469727
Epoch 3 in 5.27 sec
Training set accuracy 0.9652605652809143
Test set accuracy 0.9620000720024109
Epoch 4 in 5.37 sec
Training set accuracy 0.9689244031906128
Test set accuracy 0.968000054359436
Epoch 5 in 5.19 sec
Training set accuracy 0.9720168709754944
Test set accuracy 0.9720000624656677
Epoch 6 in 5.29 sec
Training set accuracy 0.974907636642456
Test set accuracy 0.9720000624656677
Epoch 7 in 5.29 sec
Training set accuracy 0.9772941470146179
Test set accuracy 0.9740000367164612
