In [1]:
import os
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'

In [2]:
import tensorflow as tf
from typing import Any

2026-01-30 00:26:38.570824: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-01-30 00:26:38.592016: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-01-30 00:26:38.597607: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-01-30 00:26:38.618200: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
from __future__ import annotations

import logging
import sys
import os
from pathlib import Path
from datetime import datetime
from typing import Any, Literal
from dataclasses import dataclass, field
import json

In [4]:
class Colours:
    RESET = "\033[0m"
    BOLD = "\033[1m"
    DIM = "\033[2m"
    
    # Standard colors
    BLACK = "\033[30m"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"
    WHITE = "\033[37m"
    
    # Bright colors
    BRIGHT_RED = "\033[91m"
    BRIGHT_GREEN = "\033[92m"
    BRIGHT_YELLOW = "\033[93m"
    BRIGHT_BLUE = "\033[94m"
    BRIGHT_MAGENTA = "\033[95m"
    BRIGHT_CYAN = "\033[96m"
    BRIGHT_WHITE = "\033[97m"
    
    # Background colors
    BG_RED = "\033[41m"
    BG_GREEN = "\033[42m"
    BG_YELLOW = "\033[43m"
    BG_BLUE = "\033[44m"

    @classmethod
    def disable(cls):
        for attr in dir(cls):
            if not attr.startswith('_') and isinstance(getattr(cls, attr), str):
                setattr(cls, attr, "")

In [5]:
@dataclass
class LogLevel:
    name: str
    colour: str
    icon: str
    level: int

In [6]:
LOG_LEVELS = {
    "debug": LogLevel("DEBUG", Colours.DIM, "üîç", logging.DEBUG),
    "info": LogLevel("INFO", Colours.BLUE, "‚ÑπÔ∏è ", logging.INFO),
    "success": LogLevel("SUCCESS", Colours.BRIGHT_GREEN, "‚úì", logging.INFO + 1),
    "metric": LogLevel("METRIC", Colours.CYAN, "üìä", logging.INFO + 2),
    "warning": LogLevel("WARNING", Colours.BRIGHT_YELLOW, "‚ö†Ô∏è ", logging.WARNING),
    "error": LogLevel("ERROR", Colours.BRIGHT_RED, "‚úó", logging.ERROR),
    "critical": LogLevel("CRITICAL", Colours.BG_RED + Colours.WHITE, "üíÄ", logging.CRITICAL),
    "checkpoint": LogLevel("CHECKPOINT", Colours.BRIGHT_GREEN, "üíæ", logging.INFO + 3),
    "epoch": LogLevel("EPOCH", Colours.BRIGHT_MAGENTA, "üîÑ", logging.INFO + 4),
}

In [7]:
class ConsoleFormatter(logging.Formatter):
    def __init__(self, frmt: str | None = None, date_frmt: str | None = None, use_colours: bool = True):
        super().__init__(frmt,date_frmt)
        self.use_colours = use_colours

    def format(self, record: logging.LogRecord):
        # Getting the logging info
        level_name = record.levelname.lower()
        level_config = LOG_LEVELS.get(level_name, LOG_LEVELS['info'])

        # Checking if the colours are needed
        if self.use_colours:

            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            coloured_time = f"{Colours.DIM}{timestamp}{Colours.RESET}"

            colourized_level = f"{level_config.colour}{level_config.icon}{level_config.name:10}{Colours.RESET}"

            match level_name:
                case "error" | "critical":
                    coloured_msg = f"{Colours.RED}{record.getMessage()}{Colours.RESET}"
                case "success":
                    coloured_msg = f"{Colours.GREEN}{record.getMessage()}{Colours.RESET}"
                case "checkpoint":
                    coloured_msg = f"{Colours.BRIGHT_GREEN}{record.getMessage()}{Colours.RESET}"
                case "warning":
                    coloured_msg = f"{Colours.YELLOW}{record.getMessage()}{Colours.RESET}"
                case "metric":
                    coloured_msg = f"{Colours.CYAN}{record.getMessage()}{Colours.RESET}"
                case "epoch":
                    coloured_msg = f"{Colours.MAGENTA}{record.getMessage()}{Colours.RESET}"
                case _:
                    coloured_msg = record.getMessage()

            return f"{coloured_time} | {colourized_level} | {coloured_msg}"
        else:
            super().format(record)

In [8]:
class FileFormatter(logging.Formatter):
    def format(self,record: logging.LogRecord):
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
        level = record.levelname

        # Including extra stuff I might want to add later
        extra = ""
        if hasattr(record,"step"):
            extra = extra + f" [step={record.step}"
        if hasattr(record,"epoch"):
            extra = extra + f" [epoch={record.epoch}"

        return f"{timestamp} | {level:10} | {record.getMessage()}{extra}"

In [9]:
class TensorBoardWriter:
    def __init__(self, log_directory: Path):
        self.log_directory = log_directory
        self._writer = tf.summary.create_file_writer(str(log_directory))

    @property
    def writer(self):
        return self._writer

    def scalar(self, tag: str, value: float, step: int):
        with self.writer.as_default(step = step):
            tf.summary.scalar(tag, value)

    def scalars(self, main_tag: str, values: dict[str, float], step: int):
        with self.writer.as_default(step = step):
            for name, value in values.items():
                # Writing the scalars to the tensorboard
                tf.summary.scalar(f"{main_tag}/{name}", value)

    def image(self, tag: str, image: tf.Tensor | np.ndarray, step: int):
        with self._writer.as_default(step = step):
            # Writing the image to the tensorboard
            if len(image.shape) == 3:
                image = tf.expand_dims(image,axis = 0)

            tf.summary.image(tag, image)

    def histogram(self, tag: str, values: tf.Tensor | np.ndarray, step : int):
        with self._writer.as_default(step = step):
            tf.summary.histogram(tag, values)

    def text(self, tag: str, text: str, step: int):
        with self._writer.as_default(step = step):
            tf.summary.text(tag, text)

    def flush(self):
        self._writer.flush()

    def close(self):
        self._writer.close()

In [10]:
class Logger:
    def __init__(self, job_name: str, log_dir: str | Path = "logs", tensorboard: bool = True, console: bool = True, file: bool = True, level: str = "info", config: dict | None = None):
        
        self.job_name = job_name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        self.job_dir = Path(log_dir) / f"{job_name}_{timestamp}"
        self.job_dir.mkdir(parents = True, exist_ok = True)
        
        self.tensorboard_dir = self.job_dir / "tensorboard"
        self.checkpoints_dir = self.job_dir / "checkpoints"
        self.checkpoints_dir.mkdir(exist_ok = True)

        self._logger = logging.getLogger(f"training.{job_name}.{timestamp}")
        self._logger.setLevel(logging.DEBUG)
        self._logger.handlers.clear()
        self._logger.propogate = False

        self._console_logging_enabled = console
        if console:
            console_logging_handler = logging.StreamHandler(sys.stdout)
            console_logging_handler.setLevel(LOG_LEVELS.get(level, LOG_LEVELS['info']).level)
            console_logging_handler.setFormatter(ConsoleFormatter())
            self._logger.addHandler(console_logging_handler)

        self._file_logging_enabled = file
        if file:
            log_file = self.job_dir / "training.log"
            file_logging_handler = logging.FileHandler(log_file, encoding = "utf-8")
            file_logging_handler.setLevel(logging.DEBUG)
            file_logging_handler.setFormatter(FileFormatter())
            self._logger.addHandler(file_logging_handler)

        self._tensorboard_writer: TensorBoardWriter | None = None
        if tensorboard:
            self.tensorboard_dir.mkdir(exist_ok = True)
            self._tensorboard_writer = TensorBoardWriter(self.tensorboard_dir)

        # Storing a metric history
        self._metric_history: list[dict] = []

        # Saving the config file snapshot for examination too
        if config:
            config_path = self.job_dir / "config.json"
            with open(config_path, "w") as file:
                json.dump(config, file, indent = 2, default = str)

        self.info(f"Logger Initialized: {self.job_dir}")

    def _log(self, level: str, message: str, **extra):
        level_config = LOG_LEVELS.get(level, LOG_LEVELS["info"])

        record = self._logger.makeRecord(name = self._logger.name, level = level_config.level, fn = "", lno = 0, msg = message, args = (), exc_info = None)

        record.level_name = level.upper()
        for key, value in extra.items():
            setattr(record, key, value)

        self._logger.handle(record)

    def debug(self, message: str, **extra):
        self._log("debug", message, **extra)

    def info(self, message: str, **extra):
        self._log("info", message, **extra)

    def success(self, message: str, **extra):
        self._log("success", message, **extra)

    def warning(self, message: str, **extra):
        self._log("warning", message, **extra)

    def critical(self, message: str, **extra):
        self._log("critical", message, **extra)

    def error(self, message: str, **extra):
        self._log("error", message, **extra)

    def checkpoint(self, message: str, path: str | Path | None = None, **extra):
        full_message = f"{message} -> {path}" if path else message
        self._log("checkpoint", message, **extra)

    def epoch(self, epoch: int, total: int | None = None, **extra):
        message = f"Epoch {epoch}/{total}" if total else f"Epoch {epoch}"
        self._log("epoch", message, epoch = epoch, **extra)

    def metric(self, message: str, **extra):
        self._log("metric", message, **extra)

    def log_scalar(self, tag: str, value: float, step: int):

        if self._tensorboard_writer:
            self._tensorboard_writer.scalar(tag, value, step)

    def log_scalars(self, tag: str, values: dict[str, float], step: int):
        if self._tensorboard_writer:
            self._tensorboard_writer.scalars(tag, values, step)

    def log_image(self, tag: str, image: tf.Tensor | np.ndarray, step: int):
        if self._tensorboard_writer:
            self._tensorboard_writer.image(tag, image, step)

    def log_histogram(self, tag: str, values: tf.Tensor | np.ndarray, step: int):
        if self._tensorboard_writer:
            self._tensorboard_writer.histogram(tag, image, step)

    def log_text(self, tag: str, text: str, step: int):
        if self._tensorboard_writer:
            self._tensorboard_writer.text(tag, image, step)

    def log_metrics(self, metrics: dict[str, float], step: int, prefix: str = "", to_tensorboard: bool = True, to_console: bool = True):
        if prefix:
            prefixed = {f"{prefix}/{key}": value for key, value in metrics.items()}
        else:
            prefixed = metrics

        if to_tensorboard and self._tensorboard_writer:
            # Write to tensorboard
            for tag, value in prefixed.items():
                self._tensorboard_writer.scalar(tag, value, step)

        if to_console:
            metrics_message = " | ".join(f"{key}: {value:.4f}" for key, value in prefixed.items())
            self.metric(f"[Step {step}] {metrics_message}", step = step)

        # Adding to the metric history
        self._metric_history.append({
            "step": step,
            "timestamp": datetime.now().isoformat(),
            **prefixed
        })

    def log_training_step(self, step: int, loss: float, learning_rate: float, extra: dict[str,float] | None = None, log_every: int = 100):

        # Checking if the step needs to be logged
        if step % log_every != 0:
            return

        metrics = {"loss": loss, "lr": learning_rate}
        if extra:
            metrics.update(extra)

        self.log_metrics(metrics, step, prefix = "train", to_console = True)

    def log_validation(self, metrics: dict[str, float], step: int):

        self.log_metrics(metrics, step = step, prefix = "val", to_console = True)

        # Checking to highlight classic metrics
        for key in ["mAP@0.50", "mAP", "AP"]:
            if key in metrics:
                self.success(f"Validation {key}: {metrics[key]:.4f}")
                break

    def log_epoch_summary(self, epoch: int, train_metrics: dict[str,float], val_metrics: dict[str, float] | None = None):

        # Line divider
        self.info(f"{'-' * 50}")

        training_message = " | ".join(f"{key}: {value:.4f}" for key,value in train_metrics.items())
        self.info(f"Epoch {epoch} Train: {training_message}")

        if val_metrics:
            validation_message = " | ".join(f"{key}: {value:.4f}" for key,value in val_metrics.items())
            self.info(f"Epoch {epoch} Val: {validation_message}")

        self.info(f"{'-' * 50}")

    def get_checkpoint_path(self, filename: str):
        return self.checkpoints_dir / filename

    def save_metric_history(self):
        path = self.job_dir / "metric_history.json"
        with open(path, "w") as file:
            json.dump(self._metric_history, file, indent = 2)

    def flush(self):
        # Flushing each handler
        for handler in self._logger.handlers:
            handler.flush()

        if self._tensorboard_writer:
            self._tensorboard_writer.flush()

    def close(self):

        # Wrapping up everything
        self.save_metric_history()

        if self._tensorboard_writer:
            self._tensorboard_writer.close()

        for handler in self._logger.handlers:
            handler.close()
            self._logger.removeHandler(handler)


    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type:
            self.error(f"Exception: {exc_type.__name__}: {exc_val}")

        self.close()

In [11]:
def build_logger_from_config(config: dict, job_name: str | None = None):
    logging_config = config.get('logging', {})

    return Logger(job_name = job_name, log_dir = logging_config.get('log_dir', "logs"), tensorboard = logging_config.get('tensorboard', True), console = logging_config.get('console', True), file = logging_config.get('file', True), level = logging_config.get('level', "info"))

In [12]:
with Logger("demo", log_dir= Path("logs")) as logger:
    logger.debug("Debug message (only in file)")
    logger.info("Info message")
    logger.success("Success message")
    logger.warning("Warning message")
    logger.error("Error message")

    print()

    for epoch in range(1, 4):
        logger.epoch(epoch, total=3)
            
        for step in range(100):
            global_step = (epoch - 1) * 100 + step
            loss = 1.0 / (global_step + 1)
                
            # Rate-limited logging
            logger.log_training_step(
                step=global_step,
                loss=loss,
                learning_rate=0.001,
                log_every=50,
            )
            
        # Validation
        logger.log_validation(
            {"mAP@0.50": 0.5 + epoch * 0.1, "mAP@0.75": 0.3 + epoch * 0.05},
            step=epoch,
        )
            
        # Checkpoint
        if epoch == 2:
            logger.checkpoint("Best model", path="model_best.h5")
        
    logger.success("Training complete!")
print("\n‚úì Demo complete\n")

I0000 00:00:1769750801.228972   21028 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1769750801.381298   21028 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1769750801.381401   21028 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1769750801.408659   21028 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1769750801.408846   21028 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:0

[2m2026-01-30 00:26:41[0m | [34m‚ÑπÔ∏è INFO      [0m | Logger Initialized: logs/demo_20260130_002641
[2m2026-01-30 00:26:41[0m | [34m‚ÑπÔ∏è INFO      [0m | Info message
[2m2026-01-30 00:26:41[0m | [34m‚ÑπÔ∏è INFO      [0m | Success message
[2m2026-01-30 00:26:41[0m | [91m‚úóERROR     [0m | [31mError message[0m

[2m2026-01-30 00:26:41[0m | [34m‚ÑπÔ∏è INFO      [0m | Epoch 1/3


 00:26:41.619915: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2112] Could not identify NUMA node of platform GPU id 0, defaulting to 0.  Your kernel may not have been built with NUMA support.
2026-01-30 00:26:41.619963: I tensorflow/core/common_runtime/gpu/gpu_process_state.cc:198] Using CUDA malloc Async allocator for GPU: 0
I0000 00:00:1769750801.621536   21028 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2026-01-30 00:26:41.621613: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9558 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4070, pci bus id: 0000:01:00.0, compute capability: 8.9


[2m2026-01-30 00:26:42[0m | [34m‚ÑπÔ∏è INFO      [0m | [Step 0] train/loss: 1.0000 | train/lr: 0.0010
[2m2026-01-30 00:26:42[0m | [34m‚ÑπÔ∏è INFO      [0m | [Step 50] train/loss: 0.0196 | train/lr: 0.0010
[2m2026-01-30 00:26:42[0m | [34m‚ÑπÔ∏è INFO      [0m | [Step 1] val/mAP@0.50: 0.6000 | val/mAP@0.75: 0.3500
[2m2026-01-30 00:26:42[0m | [34m‚ÑπÔ∏è INFO      [0m | Validation mAP@0.50: 0.6000
[2m2026-01-30 00:26:42[0m | [34m‚ÑπÔ∏è INFO      [0m | Epoch 2/3
[2m2026-01-30 00:26:42[0m | [34m‚ÑπÔ∏è INFO      [0m | [Step 100] train/loss: 0.0099 | train/lr: 0.0010
[2m2026-01-30 00:26:42[0m | [34m‚ÑπÔ∏è INFO      [0m | [Step 150] train/loss: 0.0066 | train/lr: 0.0010
[2m2026-01-30 00:26:42[0m | [34m‚ÑπÔ∏è INFO      [0m | [Step 2] val/mAP@0.50: 0.7000 | val/mAP@0.75: 0.4000
[2m2026-01-30 00:26:42[0m | [34m‚ÑπÔ∏è INFO      [0m | Validation mAP@0.50: 0.7000
[2m2026-01-30 00:26:42[0m | [34m‚ÑπÔ∏è INFO      [0m | Best model
[2m2026-01-30 00:26:42[0m | [34