In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell execution, so edits to the project package are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from colorcloud.biasutti2019riu import LitLearner, LitData
from colorcloud.behley2019iccv import ProjectionVizTransform
import lightning as L
import wandb
from lightning.pytorch.loggers import WandbLogger
from datetime import datetime
from torch import nn

In [None]:
def wandb_hook(logger, step):
    """Create a callback that logs a "dead rate" for ReLU-named modules.

    The returned callable takes ``(module, inputs, output)``, which matches the
    signature of a PyTorch forward hook; where it gets registered is not
    visible in this notebook.

    Args:
        logger: Object exposing ``log(dict, step=...)``.
        step: Step value forwarded unchanged to ``logger.log``.

    Returns:
        The hook callable. The hook itself always returns ``None``.
    """
    # Every axis except dim 1 -- assumes (N, C, H, W) activations, TODO confirm.
    reduce_dims = (0, 2, 3)

    def hook(module, inputs, output):
        layer_name = module.name
        if 'relu' not in layer_name:
            return  # only modules whose name contains 'relu' are tracked

        activations = output.detach()
        # Number of activations below the 0.1 threshold, per index of dim 1.
        below_threshold = (activations < 1e-1).sum(reduce_dims)
        dims = activations.shape
        n_per_channel = dims[0] * dims[2] * dims[3]
        dead_fraction = below_threshold / n_per_channel
        # Report the largest fraction found across dim 1.
        logger.log({"dead_rate/" + str(layer_name): dead_fraction.max()}, step=step)

    return hook

In [None]:
# Data module with a training batch size of 10.
data = LitData(train_batch_size=10)

# Set up the 'fit' stage first: the number of batches in the training
# dataloader is the number of optimisation steps in one epoch.
data.setup('fit')
epoch_steps = len(data.train_dataloader())

# Set up the 'test' stage so that `ds_test` is available; its colour map and
# inverse learning map parameterise the projection visualisation transform.
data.setup('test')
proj_viz_tfm = ProjectionVizTransform(
    data.ds_test.color_map_rgb_np,
    data.ds_test.learning_map_inv_np,
)

In [None]:
# Number of passes over the training set for this run.
n_epochs = 1

# total_steps = epochs x batches per epoch. Debugging mode is switched on and
# wandb_hook is handed over as the debugging hook.
learner = LitLearner(
    total_steps=n_epochs * epoch_steps,
    debugging=True,
    debugging_hook=wandb_hook,
    proj_viz_tfm=proj_viz_tfm,
)

In [None]:
# Name the W&B run after its start time. strftime zero-pads every field, so
# run names have a fixed width and sort chronologically
# (e.g. 2024-03-05_09-04-07 rather than 2024-3-5_9-4-7).
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# log_model="all" asks the Lightning W&B logger to upload checkpoints during
# training rather than only at the end.
wandb_logger = WandbLogger(project="colorcloud", name=timestamp, log_model="all")
# Have W&B track the model; log="all" requests both gradients and parameters.
wandb_logger.watch(learner.model, log="all")

In [None]:
# Fit the learner on the data module for `n_epochs` epochs, reporting to W&B.
trainer = L.Trainer(
    max_epochs=n_epochs,
    logger=wandb_logger,
)
trainer.fit(learner, data)

In [None]:
# Mark the active W&B run as finished and upload any remaining data.
wandb.finish()