<a href="https://colab.research.google.com/github/AlexandreAdam/Censai/blob/eager2.4/notebooks/train_raytracer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RayTracer

Uncomment the following block to start

In [1]:
!git clone https://github.com/AlexandreAdam/Censai.git
%cd Censai
!git checkout eager2.4
!python setup.py install

Cloning into 'Censai'...
remote: Enumerating objects: 99, done.[K
remote: Counting objects: 100% (99/99), done.[K
remote: Compressing objects: 100% (72/72), done.[K
remote: Total 526 (delta 38), reused 69 (delta 18), pack-reused 427[K
Receiving objects: 100% (526/526), 13.44 MiB | 33.74 MiB/s, done.
Resolving deltas: 100% (271/271), done.
/content/Censai
Branch 'eager2.4' set up to track remote branch 'eager2.4' from 'origin'.
Switched to a new branch 'eager2.4'
running install
running bdist_egg
running egg_info
creating censai.egg-info
writing censai.egg-info/PKG-INFO
writing dependency_links to censai.egg-info/dependency_links.txt
writing top-level names to censai.egg-info/top_level.txt
writing manifest file 'censai.egg-info/SOURCES.txt'
writing manifest file 'censai.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build
creating build/lib
creating build/lib/censai
copying censai/physical_model.py -> build

In [2]:
# wanb login, uncomment 
%pip install wandb -q
import wandb


[K     |████████████████████████████████| 2.0MB 17.0MB/s 
[K     |████████████████████████████████| 163kB 57.8MB/s 
[K     |████████████████████████████████| 102kB 15.6MB/s 
[K     |████████████████████████████████| 133kB 49.8MB/s 
[K     |████████████████████████████████| 71kB 11.1MB/s 
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone
  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [9]:
import tensorflow as tf
from censai.definitions import RayTracer
from censai.data_generator import NISGenerator
from datetime import datetime
import os
# os.mkdir(Config.logdir) # run only once

In [12]:
def main(args):
    if wndb:
        config = wandb.config
        config.learning_rate = args.lr
        config.batch_size = args.batch_size
        config.epochs = args.epochs
        config.architecture="RayTracer UNET"
        config.update({"total_items_per_batch": args.total_items, "validation_size": args.validation})
    gen = NISGenerator(args.total_items, args.batch_size)
    gen_test = NISGenerator(args.validation, args.validation, train=False)
    ray_tracer = RayTracer()
    optim = tf.optimizers.Adam(lr=args.lr)

    # setup tensorboard writer (nullwriter in case we do not want to sync)
    if args.logdir.lower() != "none":
        logdir = os.path.join(args.logdir, args.logname)
        traindir = os.path.join(logdir, "train")
        testdir = os.path.join(logdir, "test")
        if not os.path.isdir(logdir):
            os.mkdir(logdir)
        if not os.path.isdir(traindir):
            os.mkdir(traindir)
        if not os.path.isdir(testdir):
            os.mkdir(testdir)
        train_writer = tf.summary.create_file_writer(traindir)
        test_writer = tf.summary.create_file_writer(testdir)
    else:
        test_writer = nullwriter()
        train_writer = nullwriter()

    epoch_loss = tf.metrics.Mean()
    step = 1
    for epoch in range(1, args.epochs + 1):
        with train_writer.as_default():
            for batch, (kappa, alpha) in enumerate(gen):
                with tf.GradientTape() as tape:
                    tape.watch(ray_tracer.trainable_variables)
                    cost = ray_tracer.cost(kappa, alpha) # call + MSE loss function
                    cost += tf.reduce_sum(ray_tracer.losses) # add regularizer loss
                gradient = tape.gradient(cost, ray_tracer.trainable_variables)
                clipped_gradient = [tf.clip_by_value(grad, -10, 10) for grad in gradient]
                optim.apply_gradients(zip(clipped_gradient, ray_tracer.trainable_variables)) # backprop

                #========== Summary and logs ==========
                epoch_loss.update_state([cost])
                tf.summary.scalar("MSE", cost, step=step)
                step += 1
                
        with test_writer.as_default():
            for (kappa, alpha) in gen_test:
                test_cost = ray_tracer.cost(kappa, alpha)
            tf.summary.scalar("MSE", test_cost, step=step)
        print(f"epoch {epoch} | train loss {epoch_loss.result().numpy():.3e} | val loss {test_cost.numpy():.3e}")
    return gen, gen_test, ray_tracer 

In [15]:
# quick hack to make a config like args of ArgumentParser
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

# setup hyperparameter and other configs
date = datetime.now().strftime("%y-%m-%d_%H-%M-%S")
Config = AttrDict()
Config.update({
    #hparams
    "epochs": 100,
    "lr": 1e-4,
    "batch_size": 50,
    #configs
    "total_items": 1000, # items per epochs
    "logdir": "logs",
    "logname": date,
    "validation": 100
})


In [16]:
wndb = True
wandb.init(project="censai", entity="adam-alexandre01123", sync_tensorboard=True)
gen, gen_test, ray_tracer = main(Config)

VBox(children=(Label(value=' 0.21MB of 0.21MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train/global_step,3400.0
_timestamp,1613939997.25475
train/MSE,95603497369600.0
global_step,3401.0
_step,3400.0
test/global_step,3401.0
test/MSE,39107774382080.0


0,1
train/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▃▃▃▆▆▆▆▆▆▆▆▆▆▆█████████
train/MSE,▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test/MSE,▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁




epoch 1 | train loss 5.769e+00 | val loss 5.788e+00
epoch 2 | train loss 5.839e+00 | val loss 5.810e+00
epoch 3 | train loss 5.899e+00 | val loss 5.789e+00
epoch 4 | train loss 5.898e+00 | val loss 5.786e+00
epoch 5 | train loss 5.878e+00 | val loss 5.799e+00
epoch 6 | train loss 5.864e+00 | val loss 5.781e+00
epoch 7 | train loss 5.852e+00 | val loss 5.787e+00
epoch 8 | train loss 5.862e+00 | val loss 5.797e+00
epoch 9 | train loss 5.859e+00 | val loss 5.785e+00
epoch 10 | train loss 5.872e+00 | val loss 5.783e+00
epoch 11 | train loss 5.868e+00 | val loss 5.817e+00
epoch 12 | train loss 5.865e+00 | val loss 5.789e+00
epoch 13 | train loss 5.868e+00 | val loss 5.797e+00
epoch 14 | train loss 5.872e+00 | val loss 5.779e+00
epoch 15 | train loss 5.866e+00 | val loss 5.793e+00
epoch 16 | train loss 5.869e+00 | val loss 5.789e+00
epoch 17 | train loss 5.874e+00 | val loss 5.783e+00
epoch 18 | train loss 5.873e+00 | val loss 5.801e+00
epoch 19 | train loss 5.872e+00 | val loss 5.789e+00
ep