In [1]:
# Uncomment to run the dataset download script before the first training run:
# !sh download_data.sh

In [2]:
import torch
from torch import nn
from tqdm.cli import tqdm
from typing import Self, Union, Optional

from unet import Unet

from dataset import get_train_data

In [3]:
class Scaler:
    _scale: torch.Tensor
    found_inf: bool

    def __init__(self: Self, init_scale: float = 2**16) -> None:
        self._scale = torch.tensor(init_scale, dtype=torch.float32)
        
        self.found_inf = False

    def scale(self: Self, outputs: torch.Tensor) -> torch.Tensor:
        return outputs * self._scale
    
    def step(self: Self, optimizer: torch.optim.Optimizer) -> None:
        self.unscale_(optimizer)
        self.found_inf = self._has_inf_or_nan(optimizer)
        if not self.found_inf:
            optimizer.step()

    def update(self: Self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
        if new_scale is not None:
            self._scale = new_scale

    def _has_inf_or_nan(self: Self, optimizer: torch.optim.Optimizer) -> bool:
        for group in optimizer.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                if not torch.isfinite(p.grad).all():
                    return True
        return False
    
    def unscale_(self: Self, optimizer: torch.optim.Optimizer) -> None:
        for group in optimizer.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.grad /= self._scale

    
class StaticScaler(Scaler):
    """Scaler that adds nothing to ``Scaler``; all behaviour is inherited unchanged."""


class DynamicScaler(Scaler):
    """Loss scaler that adapts its scale factor to gradient overflow.

    On ``update`` the scale is multiplied by ``backoff_factor`` when the last
    ``step`` found a non-finite gradient, and by ``growth_factor`` once
    ``growth_interval`` consecutive updates have passed without one.
    """

    # Number of consecutive updates without an overflow since the last
    # growth or backoff.
    correct_steps_counter: int

    def __init__(
        self,
        init_scale=2**16,
        growth_factor=2,
        backoff_factor=0.5,
        growth_interval=2000,
    ) -> None:
        """Store the scaling policy and start with a zero overflow-free streak.

        Args:
            init_scale: Initial scale factor, forwarded to ``Scaler.__init__``.
            growth_factor: Multiplier applied to the scale when it grows.
            backoff_factor: Multiplier applied to the scale after an overflow.
            growth_interval: Overflow-free updates required before growing.
        """
        super().__init__(init_scale)

        self.correct_steps_counter = 0
        self.growth_interval = growth_interval
        self.backoff_factor = backoff_factor
        self.growth_factor = growth_factor

    def update(self: Self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
        """Adjust the scale after a ``step``, or overwrite it with ``new_scale``.

        An explicit ``new_scale`` replaces the scale and leaves the streak
        counter untouched; otherwise the backoff/growth policy is applied.
        """
        if new_scale is not None:
            # Explicit override: no policy bookkeeping happens.
            self._scale = new_scale
        elif self.found_inf:
            # Overflow in the last step: shrink the scale and restart the streak.
            self._scale *= self.backoff_factor
            self.correct_steps_counter = 0
        else:
            self.correct_steps_counter += 1
            streak_complete = self.correct_steps_counter >= self.growth_interval
            if streak_complete:
                self.correct_steps_counter = 0
                self._scale *= self.growth_factor

In [4]:
def train_epoch(
    train_loader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    criterion: torch.nn.modules.loss._Loss,
    optimizer: torch.optim.Optimizer,
    scaler: Union[Scaler, torch.amp.GradScaler],
    device: torch.device,
) -> None:
    """Run one float16-autocast training pass over ``train_loader``.

    For every batch the forward pass and loss run under autocast, the loss is
    scaled by ``scaler`` before ``backward``, and the optimizer is stepped
    through ``scaler``. The progress bar shows the batch loss and accuracy.

    Args:
        train_loader: Iterable of ``(images, labels)`` batches.
        model: Network being trained; it is put into train mode.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped via ``scaler.step``.
        scaler: Object providing ``scale``, ``step`` and ``update``.
        device: Device the batches are moved to and autocast is enabled for.
    """
    model.train()

    progress = tqdm(enumerate(train_loader), total=len(train_loader))
    for _, (batch_images, batch_labels) in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        with torch.amp.autocast(device.type, dtype=torch.float16):
            predictions = model(batch_images)
            loss = criterion(predictions, batch_labels)

        # Gradients are cleared only now, right before the new backward pass.
        optimizer.zero_grad()
        scaled_loss = scaler.scale(loss)
        scaled_loss.backward()
        scaler.step(optimizer)
        scaler.update()

        # NOTE(review): the 0.5 threshold treats the model outputs as
        # probabilities; if the model returns raw logits the matching cut-off
        # would be 0 — confirm against the model's final layer and the criterion.
        hits = (predictions > 0.5) == batch_labels
        accuracy = hits.float().mean()

        loss_shown = round(loss.item(), 4)
        accuracy_shown = round(accuracy.item() * 100, 4)
        progress.set_description(f"Loss: {loss_shown} " f"Accuracy: {accuracy_shown}")
    

def train(
    scaler,
    num_epochs: int = 5,
    lr: float = 1e-4,
    device: Union[str, torch.device] = "cuda:0",
):
    """Train a fresh ``Unet`` with Adam and BCE-with-logits loss using ``scaler``.

    Args:
        scaler: Loss scaler passed to ``train_epoch`` (anything exposing
            ``scale``, ``step`` and ``update``).
        num_epochs: Number of passes over the training data.
        lr: Learning rate for the Adam optimizer.
        device: Device the model and batches are placed on. Defaults to the
            first CUDA device, as the original hard-coded value did.
    """
    device = torch.device(device)
    model = Unet().to(device)
    # NOTE(review): BCEWithLogitsLoss applies a sigmoid internally, so this
    # presumes Unet returns raw logits — confirm against the Unet definition.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_loader = get_train_data()

    for _ in range(num_epochs):
        train_epoch(train_loader, model, criterion, optimizer, scaler, device=device)

In [5]:
# Baseline run: PyTorch's built-in GradScaler with a growth interval of 2000 steps.
scaler = torch.amp.GradScaler(init_scale=2**16, growth_interval=2000)
train(scaler)

Loss: 0.602 Accuracy: 95.9908: 100%|██████████| 40/40 [00:20<00:00,  1.94it/s] 
Loss: 0.5946 Accuracy: 98.1044: 100%|██████████| 40/40 [00:19<00:00,  2.01it/s]
Loss: 0.589 Accuracy: 98.4951: 100%|██████████| 40/40 [00:19<00:00,  2.04it/s] 
Loss: 0.5894 Accuracy: 98.793: 100%|██████████| 40/40 [00:18<00:00,  2.11it/s] 
Loss: 0.5816 Accuracy: 98.8142: 100%|██████████| 40/40 [00:19<00:00,  2.04it/s]


In [6]:
# Custom scaler whose scale factor stays at its initial value (update() gets no argument).
scaler = StaticScaler(init_scale=2**16)
train(scaler)

Loss: 0.6078 Accuracy: 95.1426: 100%|██████████| 40/40 [00:20<00:00,  1.97it/s]
Loss: 0.5935 Accuracy: 97.6958: 100%|██████████| 40/40 [00:19<00:00,  2.02it/s]
Loss: 0.5879 Accuracy: 98.3618: 100%|██████████| 40/40 [00:19<00:00,  2.04it/s]
Loss: 0.5856 Accuracy: 98.7929: 100%|██████████| 40/40 [00:19<00:00,  2.01it/s]
Loss: 0.581 Accuracy: 98.8224: 100%|██████████| 40/40 [00:20<00:00,  2.00it/s] 


In [7]:
# Built-in GradScaler configured with growth_interval=1, as the reference for
# the custom DynamicScaler run with the same settings.
scaler = torch.amp.GradScaler(init_scale=2**16, growth_factor=2, backoff_factor=0.5, growth_interval=1)
train(scaler)

Loss: 0.6201 Accuracy: 92.3393: 100%|██████████| 40/40 [00:20<00:00,  1.97it/s]
Loss: 0.6073 Accuracy: 94.9721: 100%|██████████| 40/40 [00:19<00:00,  2.03it/s]
Loss: 0.5952 Accuracy: 97.2112: 100%|██████████| 40/40 [00:20<00:00,  1.99it/s]
Loss: 0.5921 Accuracy: 98.2141: 100%|██████████| 40/40 [00:20<00:00,  1.98it/s]
Loss: 0.5876 Accuracy: 98.2719: 100%|██████████| 40/40 [00:20<00:00,  1.99it/s]


In [8]:
# Custom DynamicScaler with growth_interval=1: the scale doubles after every
# overflow-free update and halves after an overflow.
scaler = DynamicScaler(init_scale=2**16, growth_factor=2, backoff_factor=0.5, growth_interval=1)
train(scaler)

Loss: 0.6219 Accuracy: 92.5257: 100%|██████████| 40/40 [00:19<00:00,  2.04it/s]
Loss: 0.6051 Accuracy: 94.9057: 100%|██████████| 40/40 [00:19<00:00,  2.01it/s]
Loss: 0.6018 Accuracy: 96.4858: 100%|██████████| 40/40 [00:20<00:00,  1.97it/s]
Loss: 0.5926 Accuracy: 97.5584: 100%|██████████| 40/40 [00:20<00:00,  1.99it/s]
Loss: 0.5894 Accuracy: 98.0231: 100%|██████████| 40/40 [00:19<00:00,  2.06it/s]
