# Pretraining Training script in MAE

> Pretraining Training script in MAE

In [1]:
#| default_exp mae.pretraining_training

In [1]:
#| export
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import timm

2024-09-14 18:31:25.617501: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-14 18:31:25.617568: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-14 18:31:25.617601: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-14 18:31:25.818799: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Gradient Scaling and Norm Counting


Mixed precision training uses both 16-bit and 32-bit floating-point numbers to speed up computation and reduce memory usage. However, the reduced range of 16-bit values can lead to some tricky numerical issues, such as small gradients underflowing to zero. Enter the `NativeScalerWithGradNormCount` class!

## What's this class all about?

This nifty little class is a wrapper around PyTorch's `GradScaler`. It's designed to handle the intricacies of mixed precision training while also giving us some extra goodies like gradient norm calculation.

## Let's break it down:

1. **Initialization**: We start by creating a `GradScaler` object. This is PyTorch's built-in tool for automatic mixed precision training.

2. **The `__call__` method**: This is where the magic happens!
   - It scales the loss and performs backpropagation.
   - If `update_grad` is true, it unscales the gradients, optionally clips them to `clip_grad`, then steps the optimizer and updates the scaler.
   - It returns the gradient norm (or `None` when `update_grad` is false), which is super useful for monitoring training stability.

3. **State management**: The `state_dict` and `load_state_dict` methods allow us to save and load the scaler's state. This is crucial for resuming training from checkpoints.

## Why is this so cool?

- It seamlessly integrates mixed precision training into our workflow.
- It provides gradient clipping out of the box, which helps prevent exploding gradients.
- The gradient norm calculation gives us valuable insights into our training process.

By using this class, we're not just training our model - we're training it smartly and efficiently. It's like having a personal trainer for your neural network!


In [2]:
#| export
class NativeScalerWithGradNormCount:
    """Thin wrapper around ``torch.cuda.amp.GradScaler`` that also reports a gradient norm.

    Calling an instance runs the scaled backward pass and, when requested,
    unscales the gradients, optionally clips them, steps the optimizer and
    updates the scaler.
    """

    # Not read inside this class; presumably the checkpoint key under which
    # callers store this scaler's state -- confirm against the save/load code.
    state_dict_key = "amp_scaler"

    def __init__(self):
        """Create the wrapped ``GradScaler`` with its default settings."""
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None,
                 create_graph=False, update_grad=True):
        """Backpropagate the scaled ``loss`` and optionally apply an optimizer step.

        Args:
            loss: Value to scale and call ``backward`` on.
            optimizer: Optimizer whose gradients are unscaled and which is stepped.
            clip_grad: Maximum norm passed to ``clip_grad_norm_``; ``None``
                disables clipping.
            parameters: Parameters whose gradient norm is measured (and clipped
                when ``clip_grad`` is given). Must not be ``None`` when
                ``clip_grad`` is set.
            create_graph: Forwarded to ``backward``.
            update_grad: When false, only the backward pass runs; gradients are
                not unscaled and no optimizer step is taken.

        Returns:
            The value returned by ``clip_grad_norm_`` or ``get_grad_norm_`` when
            ``update_grad`` is true, otherwise ``None``.
        """
        scaled_loss = self._scaler.scale(loss)
        scaled_loss.backward(create_graph=create_graph)

        # Backward-only call: leave the gradients untouched and skip the step.
        if not update_grad:
            return None

        wants_clipping = clip_grad is not None
        if wants_clipping:
            # Checked before unscaling so a bad call fails before gradients are touched.
            assert parameters is not None

        # Unscale in place so the norm below is measured on true gradient values.
        self._scaler.unscale_(optimizer)
        if wants_clipping:
            grad_norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            # `get_grad_norm_` is not defined in this class; it is expected to
            # come from elsewhere in the module.
            grad_norm = get_grad_norm_(parameters)

        self._scaler.step(optimizer)
        self._scaler.update()
        return grad_norm

    def state_dict(self):
        """Return the wrapped scaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped scaler."""
        self._scaler.load_state_dict(state_dict)