In [8]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from torch.amp import autocast

from unet import Unet

from dataset import get_train_data


  from .autonotebook import tqdm as notebook_tqdm


In [35]:
class LossScaler:
    def __init__(self, init_scale):
        self.scale = init_scale

    def get_scale(self):
        pass

    def step(self, optimizer):
        pass

    def _scale_gradients(self, optimizer):
        has_nan_or_inf = False
        has_zero = False
        
        for param_group in optimizer.param_groups:
            for param in param_group['params']:
                if param.grad is not None:
                    param.grad.data.div_(self.scale)

                    grad = param.grad.data
                    if torch.isinf(grad).any() or torch.isnan(grad).any():
                        has_nan_or_inf = True
                        param.grad.data.zero_()

                    if param.grad.data.abs().sum().item() == 0:
                        has_zero = True

        if has_zero:
            print(f'Zero gradients with scale {self.scale}')
        return has_nan_or_inf


class StaticLossScaler(LossScaler):
    def __init__(self, init_scale):
        super().__init__(init_scale)

    def get_scale(self):
        return self.scale

    def step(self, optimizer):
        has_nan_of_inf = self._scale_gradients(optimizer)

        if not has_nan_of_inf:
            optimizer.step()

class DynamicLossScaler(LossScaler):
    def __init__(self, init_scale):
        super().__init__(init_scale)

    def get_scale(self):
        return self.scale

    def step(self, optimizer):
        has_nan_or_inf = self._scale_gradients(optimizer)

        if not has_nan_or_inf:
            optimizer.step()
        else:
            self._update_scale()

    def _update_scale(self):
        self.scale /= 2
        print(f'Scale updated: {self.scale}')


def train_epoch(
    train_loader: DataLoader,
    model: torch.nn.Module,
    criterion: torch.nn.modules.loss._Loss,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    scaler: LossScaler
) -> None:
    model.train()

    pbar = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, (images, labels) in pbar:
        images = images.to(device)
        labels = labels.to(device)

        with autocast(device.type, dtype=torch.float16):
            outputs = model(images)
            loss = criterion(outputs, labels)
        # TODO: your code for loss scaling here
        optimizer.zero_grad()
        scale = scaler.get_scale()

        scaled_loss = scale * loss
        scaled_loss.backward()

        scaler.step(optimizer)

        accuracy = ((outputs > 0.5) == labels).float().mean()

        pbar.set_description(f"Loss: {round(loss.item(), 4)} " f"Accuracy: {round(accuracy.item() * 100, 4)}")

def train(loss_scaling, device_id):
    device = torch.device(f"cuda:{device_id}")
    model = Unet().to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    train_loader = get_train_data()

    num_epochs = 5
    for epoch in range(0, num_epochs):
        train_epoch(train_loader, model, criterion, optimizer, device, loss_scaling)

In [31]:
init_scale = 2.0**16
static_scaler = StaticLossScaler(init_scale)

train(static_scaler, 4)

Loss: 0.6044 Accuracy: 95.163: 100%|██████████| 40/40 [00:32<00:00,  1.21it/s] 
Loss: 0.5942 Accuracy: 97.6608: 100%|██████████| 40/40 [00:32<00:00,  1.22it/s]
Loss: 0.5869 Accuracy: 98.5456: 100%|██████████| 40/40 [00:32<00:00,  1.21it/s]
Loss: 0.5833 Accuracy: 98.546: 100%|██████████| 40/40 [00:32<00:00,  1.22it/s] 
Loss: 0.5851 Accuracy: 98.8529: 100%|██████████| 40/40 [00:33<00:00,  1.18it/s]


In [36]:
init_scale = 2.0**20
dynamic_scaler = DynamicLossScaler(init_scale)

train(dynamic_scaler, 4)

Loss: 0.8495 Accuracy: 61.0644:   2%|▎         | 1/40 [00:03<02:22,  3.65s/it]

Zero gradients with scale 1048576.0
Scale updated: 524288.0


Loss: 0.6092 Accuracy: 94.6808: 100%|██████████| 40/40 [00:33<00:00,  1.20it/s]
Loss: 0.5933 Accuracy: 97.2532: 100%|██████████| 40/40 [00:32<00:00,  1.22it/s]
Loss: 0.5875 Accuracy: 98.516: 100%|██████████| 40/40 [00:33<00:00,  1.21it/s] 
Loss: 0.5845 Accuracy: 98.6235: 100%|██████████| 40/40 [00:33<00:00,  1.19it/s]
Loss: 0.5804 Accuracy: 98.6768: 100%|██████████| 40/40 [00:33<00:00,  1.19it/s]
