# Pre-training script for MAE

> Training script for MAE pre-training

In [1]:
#| default_exp mae.pretraining_training

In [1]:
#| hide
%load_ext autoreload
%autoreload 2

In [4]:
#| export
import argparse
import datetime
import json
import numpy as np
from typing import Iterable
from types import SimpleNamespace
import os
import time
from pathlib import Path

import torch
from torch import inf
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from argparse import ArgumentParser

import timm

In [3]:
timm.__version__

'1.0.9'

In [4]:
torch.__version__

'2.4.1'

In [5]:
torch.cuda.is_available()

True

In [6]:
#| exporti
import segmentation_test.mae.misc as misc
import segmentation_test.mae.model_development as models_mae

In [7]:
#| export
import timm.optim.optim_factory as optim_factory
import math
from fastcore.test import *

In [8]:
torch.cuda.is_available()

True

### Gradient Scaling and Norm Counting: 


Mixed precision training uses both 16-bit and 32-bit floating-point numbers to speed up computation and reduce memory usage, but this can lead to some tricky numerical issues. Enter the `NativeScalerWithGradNormCount` class!

## What's this class all about?

This nifty little class is a wrapper around PyTorch's `GradScaler`. It's designed to handle the intricacies of mixed precision training while also giving us some extra goodies like gradient norm calculation.

## Let's break it down:

1. **Initialization**: We start by creating a `GradScaler` object. This is PyTorch's built-in tool for automatic mixed precision training.

2. **The `__call__` method**: This is where the magic happens!
   - It scales the loss and performs backpropagation.
   - If we're updating gradients, it handles gradient unscaling and clipping.
   - It also calculates the gradient norm, which is super useful for monitoring training stability.

3. **State management**: The `state_dict` and `load_state_dict` methods allow us to save and load the scaler's state. This is crucial for resuming training from checkpoints.

## Why is this so cool?

- It seamlessly integrates mixed precision training into our workflow.
- It provides gradient clipping out of the box, which helps prevent exploding gradients.
- The gradient norm calculation gives us valuable insights into our training process.

By using this class, we're not just training our model - we're training it smartly and efficiently. It's like having a personal trainer for your neural network!


In [52]:
#| export
def get_grad_norm_(
        parameters,
        norm_type: float = 2.0
        ) -> torch.Tensor:
    """
    Compute the overall norm of the gradients attached to ``parameters``.

    Args:
        parameters: a single tensor or an iterable of tensors. Entries whose
            ``.grad`` is ``None`` are ignored.
        norm_type: order of the norm; ``inf`` selects the largest absolute
            gradient entry.

    Returns:
        A scalar tensor holding the norm, or ``torch.tensor(0.)`` when none of
        the parameters carries a gradient.
    """
    candidates = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    with_grad = [p for p in candidates if p.grad is not None]
    norm_type = float(norm_type)
    if not with_grad:
        return torch.tensor(0.)

    # All partial results are moved to the device of the first gradient so
    # they can be combined.
    target = with_grad[0].grad.device
    if norm_type == inf:
        return max(p.grad.detach().abs().max().to(target) for p in with_grad)

    per_param = [
        torch.norm(p.grad.detach(), norm_type).to(target)
        for p in with_grad
    ]
    return torch.norm(torch.stack(per_param), norm_type)

In [9]:
#| export
class NativeScalerWithGradNormCount:
    """Wrapper around ``torch.cuda.amp.GradScaler`` that also reports the gradient norm.

    Calling an instance scales the loss, runs backward and, on update steps,
    unscales the gradients, optionally clips them, steps the optimizer and
    updates the scaler.
    """
    # Key under which callers may store this scaler's state in a checkpoint.
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(
        self,
        loss,
        optimizer,
        clip_grad=None,
        parameters=None,
        create_graph=False,
        update_grad=True):
        """Backpropagate ``loss`` and optionally apply an optimizer step.

        Args:
            loss: loss tensor to scale and backpropagate.
            optimizer: optimizer whose gradients are unscaled and stepped.
            clip_grad: max gradient norm; when given, ``parameters`` is required
                and gradients are clipped with ``clip_grad_norm_``.
            parameters: tensors whose gradient norm is clipped/measured.
            create_graph: forwarded to ``backward``.
            update_grad: when False only the backward pass runs (gradient
                accumulation step).

        Returns:
            The gradient norm, or ``None`` when ``update_grad`` is False or when
            neither ``clip_grad`` nor ``parameters`` was supplied.
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            return None

        if clip_grad is not None:
            assert parameters is not None
        # unscale the gradients of optimizer's assigned params in-place
        self._scaler.unscale_(optimizer)
        if clip_grad is not None:
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        elif parameters is not None:
            norm = get_grad_norm_(parameters)
        else:
            # Nothing to measure: previously this path crashed on
            # `get_grad_norm_(None)` before the optimizer could step.
            norm = None
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """Return the wrapped scaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped scaler."""
        self._scaler.load_state_dict(state_dict)

In [10]:
#| export
def adjust_learning_rate(
    optimizer,
    epoch,
    args
    ):
    """Set the learning rate for ``epoch``: linear warmup, then half-cycle cosine decay.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts).
        epoch: current (possibly fractional) epoch.
        args: namespace with ``lr``, ``min_lr``, ``warmup_epochs`` and ``epochs``.

    Returns:
        The unscaled learning rate. Each param group receives this value,
        multiplied by its ``"lr_scale"`` entry when present.
    """
    warmup = args.warmup_epochs
    if epoch < warmup:
        lr = args.lr * epoch / warmup
    else:
        angle = math.pi * (epoch - warmup) / (args.epochs - warmup)
        cosine_factor = 1. + math.cos(angle)
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * cosine_factor

    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
    return lr


In [11]:
#| export
def test_adjust_learning_rate():
    """Check ``adjust_learning_rate`` in the warmup and cosine phases, including ``lr_scale``."""
    # Minimal stand-ins: only `param_groups` and the four schedule fields are read.
    optimizer = SimpleNamespace(
        param_groups=[
            {"lr": 0.1},
            {"lr": 0.2, "lr_scale": 2},
        ]
    )
    args = SimpleNamespace(warmup_epochs=5, epochs=100, lr=0.1, min_lr=0.001)
    plain_group, scaled_group = optimizer.param_groups

    # Warmup phase: lr grows linearly, 0.1 * 2 / 5
    lr = adjust_learning_rate(optimizer, 2, args)
    test_eq(lr, 0.04)
    test_eq(plain_group["lr"], 0.04)
    test_eq(scaled_group["lr"], 0.08)

    # Cosine phase: 45 of the 95 post-warmup epochs have elapsed
    lr = adjust_learning_rate(optimizer, 50, args)
    cosine_term = 1 + math.cos(math.pi * 45 / 95)
    expected_lr = 0.001 + (0.1 - 0.001) * 0.5 * cosine_term
    test_close(lr, expected_lr, eps=1e-6)
    test_close(plain_group["lr"], expected_lr, eps=1e-6)
    test_close(scaled_group["lr"], expected_lr * 2, eps=1e-6)
    print("All tests passed!")


In [12]:
test_adjust_learning_rate()

All tests passed!


In [13]:
#| export
def train_one_epoch(
                model: torch.nn.Module,
                data_loader: Iterable,
                optimizer: torch.optim.Optimizer,
                device: torch.device,
                epoch: int,
                loss_scaler: NativeScalerWithGradNormCount,
                log_writer=None,
                args=None):
    """Run one pre-training epoch and return the averaged meters.

    Args:
        model: called as ``model(samples, mask_ratio=...)``; the first element
            of the returned triple is used as the loss.
        data_loader: sized iterable yielding ``(samples, target)`` pairs; the
            second item is ignored.
        optimizer: optimizer whose learning rate is re-scheduled per iteration.
        device: device the samples are moved to.
        epoch: index of the current epoch (used for the schedule and logging).
        loss_scaler: callable performing backward and the optimizer step.
        log_writer: optional writer exposing ``log_dir`` and ``add_scalar``.
        args: namespace providing ``accum_iter``, ``mask_ratio`` and the fields
            read by ``adjust_learning_rate``. Required despite the ``None`` default.

    Returns:
        Dict mapping each meter name to its ``global_avg``.

    Exits the process with status 1 when the loss is not finite.
    """
    # Local import: `sys` is not among the module-level imports, so the
    # non-finite-loss exit below used to fail with NameError.
    import sys

    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # True on the iteration that closes a gradient-accumulation window.
        is_update_step = (data_iter_step + 1) % accum_iter == 0

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples = samples.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            loss, _, _ = model(samples, mask_ratio=args.mask_ratio)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Average over the accumulation window so the summed gradients match
        # a single large batch.
        loss /= accum_iter
        loss_scaler(
            loss,
            optimizer,
            parameters=model.parameters(),
            update_grad=is_update_step
        )
        if is_update_step:
            optimizer.zero_grad()
        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and is_update_step:
            # We use epoch_1000x as the x-axis in tensorboard.
            # This calibrates different curves when batch size changes.
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

# The main_pretrain.py script is implemented here

> The engine_pretrain.py logic is implemented in this notebook as well, and is called from the main function

In [14]:
from fastai.vision.all import untar_data, URLs

# Download and extract the ImageWang dataset (a subset of ImageNet)
# `path` is kept around: the hyper-parameter cell further down derives DATA_PATH from it again.
path = untar_data(URLs.IMAGEWANG_160)

# Update the DATA_PATH to point to the downloaded dataset
DATA_PATH = str(path)

# Report where the dataset was placed
print(f"ImageWang dataset downloaded and extracted to: {DATA_PATH}")


ImageWang dataset downloaded and extracted to: /home/hasan/.fastai/data/imagewang-160


In [15]:
Path(DATA_PATH).ls()
# MAE pre-training arguments

(#3) [Path('/home/hasan/.fastai/data/imagewang-160/unsup'),Path('/home/hasan/.fastai/data/imagewang-160/train'),Path('/home/hasan/.fastai/data/imagewang-160/val')]

In [7]:
#| export
def parse_args_(argv=None):
    """Build the MAE pre-training argument parser and parse the options.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` as before; passing ``[]`` yields the defaults,
            which is useful inside notebooks and tests.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = ArgumentParser(description="MAE Pre-training Arguments")

    # MAE pre-training arguments
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size per GPU (effective batch size is BATCH_SIZE * ACCUM_ITER * # gpus)")
    parser.add_argument("--epochs", type=int, default=400, help="Number of epochs for training")
    parser.add_argument("--accum_iter", type=int, default=1, help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)")

    # Model parameters
    parser.add_argument("--model", type=str, default='mae_vit_base_patch16', help="Name of model to train")
    parser.add_argument("--input_size", type=int, default=224, help="Images input size")
    parser.add_argument("--mask_ratio", type=float, default=0.75, help="Masking ratio (percentage of removed patches)")
    parser.add_argument("--norm_pix_loss", action='store_true', help="Use (per-patch) normalized pixels as targets for computing loss")

    # Optimizer parameters
    parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay for optimizer")
    parser.add_argument("--lr", type=float, default=None, help="Learning rate (absolute lr)")
    parser.add_argument("--blr", type=float, default=1e-3, help="Base learning rate: absolute_lr = base_lr * total_batch_size / 256")
    parser.add_argument("--min_lr", type=float, default=0., help="Lower lr bound for cyclic schedulers that hit 0")
    parser.add_argument("--warmup_epochs", type=int, default=40, help="Epochs to warmup LR")

    # Dataset parameters
    # The default previously started with a space, which made the path invalid.
    parser.add_argument("--data_path", type=str, default="/home/hasan/.fastai/data/imagewang-160", help="Dataset path")
    parser.add_argument("--output_dir", type=str, default='./output_dir', help="Path where to save, empty for no saving")
    parser.add_argument("--log_dir", type=str, default='./output_dir', help="Path where to tensorboard log")
    parser.add_argument("--device", type=str, default='cuda', help="Device to use for training / testing")
    parser.add_argument("--seed", type=int, default=0, help="Seed for reproducibility")
    parser.add_argument("--resume", type=str, default='', help="Resume from checkpoint")
    parser.add_argument("--start_epoch", type=int, default=0, help="Starting epoch")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading")
    parser.add_argument("--pin_mem", action='store_true', help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.")

    # Distributed training parameters
    parser.add_argument("--world_size", type=int, default=1, help="Number of distributed processes")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training")
    parser.add_argument("--dist_on_itp", action='store_true', help="Distributed training on ITP")
    parser.add_argument("--dist_url", type=str, default='env://', help="URL used to set up distributed training")

    return parser.parse_args(argv)


In [None]:

# Notebook-side stand-ins for the command-line options declared in `parse_args_`.
# Values match the parser defaults except where noted below.
BATCH_SIZE = 2
EPOCHS = 400
ACCUM_ITER = 1
MODEL = 'mae_vit_base_patch16'
INPUT_SIZE = 224
MASK_RATIO = 0.75
NORM_PIX_LOSS = True  # parser flag is store_true, i.e. False unless passed
WEIGHT_DECAY = 0.05
LR = None  # None means the absolute lr is derived from BLR later on
BLR = 1e-3
MIN_LR = 0.
WARMUP_EPOCHS = 40
DATA_PATH = str(path)  # dataset location from the download cell
OUTPUT_DIR = './output_dir'
LOG_DIR = './output_dir'
DEVICE = 'cuda'
SEED = 0
RESUME = ''
START_EPOCH = 0
NUM_WORKERS = 4
PIN_MEM = True  # parser flag is store_true, i.e. False unless passed
WORLD_SIZE = 1
LOCAL_RANK = -1
DIST_ON_ITP = True  # parser flag is store_true, i.e. False unless passed
DIST_URL = 'env://'


In [50]:
# Create a SimpleNamespace object to mimic argparse.Namespace, so the notebook
# cells can use `args.<name>` exactly like the `main` function does.
# Field names follow the option names declared in `parse_args_`.
args = SimpleNamespace(
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    accum_iter=ACCUM_ITER,
    model=MODEL,
    input_size=INPUT_SIZE,
    mask_ratio=MASK_RATIO,
    norm_pix_loss=NORM_PIX_LOSS,
    weight_decay=WEIGHT_DECAY,
    lr=LR,
    blr=BLR,
    min_lr=MIN_LR,
    warmup_epochs=WARMUP_EPOCHS,
    data_path=DATA_PATH,
    output_dir=OUTPUT_DIR,
    log_dir=LOG_DIR,
    device=DEVICE,
    seed=SEED,
    resume=RESUME,
    start_epoch=START_EPOCH,
    num_workers=NUM_WORKERS,
    pin_mem=PIN_MEM,
    world_size=WORLD_SIZE,
    local_rank=LOCAL_RANK,
    dist_on_itp=DIST_ON_ITP,
    dist_url=DIST_URL
)


In [18]:
misc.init_distributed_mode(args)

Not using distributed mode


In [19]:
print('job dir: {}'.format(os.getcwd()))

[08:54:12.740073] job dir: /home/hasan/Schreibtisch/projects/git_data/segmentation_test/nbs


In [20]:
device = torch.device(args.device)
device

device(type='cuda')

In [21]:
# fix the seed for reproducibility
# The rank is added so that each process gets a different seed.
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
# Turn on cuDNN benchmark mode.
cudnn.benchmark = True

In [22]:
# simple augmentation: random resized crop to the model input size, random
# horizontal flip, tensor conversion and per-channel normalisation
transform_train = transforms.Compose([
            transforms.RandomResizedCrop(
                args.input_size, 
                scale=(0.2, 1.0), 
                interpolation=3),  # 3 is bicubic
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # NOTE(review): these look like the usual ImageNet channel statistics - confirm
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225])])

In [23]:
transform_train

Compose(
    RandomResizedCrop(size=(224, 224), scale=(0.2, 1.0), ratio=(0.75, 1.3333), interpolation=bicubic, antialias=True)
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)

In [24]:
# Training set: an ImageFolder rooted at `<data_path>/train`, using the
# `transform_train` pipeline defined in the earlier cell.
dataset_train = datasets.ImageFolder(
    os.path.join(
        args.data_path, 'train'), 
        transform=transform_train)

In [25]:
dataset_train

Dataset ImageFolder
    Number of datapoints: 14669
    Root location: /home/hasan/.fastai/data/imagewang-160/train
    StandardTransform
Transform: Compose(
               RandomResizedCrop(size=(224, 224), scale=(0.2, 1.0), ratio=(0.75, 1.3333), interpolation=bicubic, antialias=True)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [26]:
# The condition is hard-coded to True (the `args.distributed` check is commented
# out), so a DistributedSampler is always built and the RandomSampler branch
# below is never taken.
if True:  # args.distributed:
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))
else:
    sampler_train = torch.utils.data.RandomSampler(dataset_train)

[08:54:17.903333] Sampler_train = <torch.utils.data.distributed.DistributedSampler object at 0x7f1122cae810>


In [27]:
# Only rank 0 creates the log directory and a TensorBoard SummaryWriter;
# every other rank (or a missing log_dir) leaves `log_writer` as None.
# `global_rank` comes from the sampler cell.
if global_rank == 0 and args.log_dir is not None:
    os.makedirs(args.log_dir, exist_ok=True)
    log_writer = SummaryWriter(log_dir=args.log_dir)
else:
    log_writer = None

In [28]:
# Training DataLoader driven by `sampler_train`; batch size, worker count and
# memory pinning are taken from `args`.
data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,  # drop the last incomplete batch
    )

In [29]:
args.model

'mae_vit_large_patch16'

In [30]:
%%capture
# define the model: look up the factory named by `args.model` in the
# `models_mae` module namespace and call it with the norm_pix_loss setting
model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)
model

In [31]:
# Move the model to the selected device.
model.to(device)

# Keep a handle to the unwrapped model; a later cell rebinds it to
# `model.module` when the model gets wrapped in DistributedDataParallel.
model_without_ddp = model

In [32]:
%%capture
print("Model = %s" % str(model_without_ddp))

In [33]:
# Effective batch size = per-process batch size * accumulation steps * number of processes.
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
eff_batch_size

2

In [34]:
# Derive the absolute lr from the base lr, scaled linearly by the effective
# batch size relative to 256.
if args.lr is None:  # only base_lr is specified
	args.lr = args.blr * eff_batch_size / 256
args.lr

7.8125e-06

In [35]:
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
print("actual lr: %.2e" % args.lr)


[08:54:32.810797] base lr: 1.00e-03
[08:54:32.810862] actual lr: 7.81e-06


In [36]:
print("accumulate grad iterations: %d" % args.accum_iter)
print("effective batch size: %d" % eff_batch_size)

[08:54:36.804349] accumulate grad iterations: 1
[08:54:36.804411] effective batch size: 2


In [37]:
# Wrap the model in DistributedDataParallel when running distributed.
# NOTE(review): `distributed` and `gpu` are not among the fields of the
# SimpleNamespace built above; presumably `misc.init_distributed_mode` sets
# them - confirm.
if args.distributed:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
    model_without_ddp = model.module
    

In [38]:
%%capture
# following timm: set wd as 0 for bias and norm layers
# The original repo used `add_weight_decay`, which may belong to an older timm version:
#param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
# That call did not work here, so `param_groups_weight_decay` is used instead.
param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay)
param_groups

In [39]:
# AdamW over the weight-decay parameter groups, with the derived absolute lr.
optimizer = torch.optim.AdamW(
    param_groups, 
    lr=args.lr, 
    betas=(0.9, 0.95))

In [40]:
#| export
# Short alias for `NativeScalerWithGradNormCount`.
NativeScaler = NativeScalerWithGradNormCount

In [41]:
# Show the optimizer configuration, then create the AMP loss scaler.
print(optimizer)
loss_scaler = NativeScaler()

[08:54:46.765071] AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 7.8125e-06
    maximize: False
    weight_decay: 0.0

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 7.8125e-06
    maximize: False
    weight_decay: 0.05
)


  self._scaler = torch.cuda.amp.GradScaler()


In [42]:
# Hand the model, optimizer and loss scaler to the project's checkpoint loader.
# NOTE(review): presumably this restores state from `args.resume` when it is
# set - confirm in `misc.load_model`.
misc.load_model(
    args=args, 
    model_without_ddp=model_without_ddp, 
    optimizer=optimizer, 
    loss_scaler=loss_scaler)


This function call is responsible for loading the model, optimizer, and loss scaler for the training process.
It takes in the following arguments:
- `args`: The arguments parsed from the command line or configuration file.
- `model_without_ddp`: The model instance without DistributedDataParallel (DDP) wrapping.
- `optimizer`: The optimizer instance to be used for training.
- `loss_scaler`: The loss scaler instance for automatic mixed precision training.
The purpose of this function is to prepare the model, optimizer, and loss scaler for the training loop.
It ensures that the model is correctly configured for training, including setting the optimizer and loss scaler.
This is a crucial step in the training process as it sets up the necessary components for the model to learn from the data.

In [5]:
#| export
def main():
    """Entry point: parse the CLI options, build data/model/optimizer and run MAE pre-training.

    Steps: initialise distributed mode, seed the RNGs, build the ImageFolder
    dataset and DataLoader, create the model and AdamW optimizer, call
    ``misc.load_model``, then train for ``args.epochs`` epochs. Checkpoints are
    saved every 20 epochs and after the last one; per-epoch stats are appended
    to ``<output_dir>/log.txt`` as JSON lines.
    """
    args = parse_args_()
    misc.init_distributed_mode(args)
    print('job dir: {}'.format(os.getcwd()))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so processes differ)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # simple augmentation
    transform_train = transforms.Compose([
            transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3),  # 3 is bicubic
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
    print(dataset_train)

    # A DistributedSampler is used unconditionally (also for single-process
    # runs); the former `if True: ... else: RandomSampler` had a dead branch.
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    # define the model
    model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)

    model.to(device)

    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.param_groups_weight_decay(
        model_without_ddp,
        args.weight_decay)
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=args.lr,
        betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    misc.load_model(
        args=args,
        model_without_ddp=model_without_ddp,
        optimizer=optimizer,
        loss_scaler=loss_scaler)

    # Make sure the output directory exists before log.txt is opened in it;
    # it is only created implicitly when it coincides with log_dir.
    if args.output_dir and misc.is_main_process():
        os.makedirs(args.output_dir, exist_ok=True)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        # The sampler is always a DistributedSampler, which reshuffles only
        # when told the epoch; without this call every epoch would iterate the
        # data in the same order in non-distributed runs.
        sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer, args=args
        )
        if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs):
            misc.save_model(
                args=args,
                model=model,
                model_without_ddp=model_without_ddp,
                optimizer=optimizer,
                loss_scaler=loss_scaler,
                epoch=epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                        'epoch': epoch,}

        # On the main process, flush the TensorBoard writer (if any) and append
        # this epoch's stats to <output_dir>/log.txt as one JSON line.
        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    # Release the TensorBoard event file handle.
    if log_writer is not None:
        log_writer.close()

In [None]:
#| export
# Run the pre-training entry point only when executed as a script.
if __name__ == '__main__':
	main()

In [10]:
#| hide
import nbdev; nbdev.nbdev_export('14_mae.pretraining_training.ipynb')