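# U-Net Training Notebook

Trains a U-Net segmentation model on 4-channel orthoimagery chips and logs the training runs to Weights & Biases (W&B).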

In [1]:
# Optional setup for a fresh hosted runtime (presumably Google Colab, which ships
# a `sample_data/` folder). Uncomment these lines to install wandb and copy the
# project code (including the `src` package imported below) into the working
# directory. Prefer `%pip` over `!pip` so the package is installed into the
# kernel's own environment.
# !pip install wandb
# !rm -r sample_data
# !git clone https://github.com/PhilipMathieu/unet-orthoimagery.git
# !mv unet-orthoimagery/* ./

See the [Weights & Biases (W&B) Quickstart](https://docs.wandb.ai/quickstart#:~:text=Provide%20your%20API%20key%20when%20prompted.) for setup instructions. If you are logged in to W&B, that page also links to your API key.

In [2]:
# Authenticate with Weights & Biases before training so runs can be logged.
# The bare call is the last expression, so its return value is shown as the cell output.
import wandb

wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mphilipmathieu[0m ([33munet-ortho[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
import logging
import torch
from src.train import train_model
from src.unet.unet_model import UNet

In [4]:
# Route INFO-and-above log records to stderr (rendered in the notebook output) as
# "LEVEL: message". basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)

# Use the current CUDA device when PyTorch can see a GPU; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logging.info(f'Using device: {device}')

INFO: Using device: cpu


In [5]:
# Build the U-Net with 4 input channels and a single output channel. It uses
# transposed convolutions (bilinear=False) rather than bilinear interpolation
# for upsampling.
# NOTE(review): the 4 input channels are presumably RGB + DEM, given the `_dem`
# data directory below; confirm.
# The module is then switched to the channels_last memory format.
model = UNet(n_channels=4, n_classes=1, bilinear=False).to(
    memory_format=torch.channels_last
)

In [6]:
# Record the network configuration in the log before moving it to the device.
logging.info(
    'Network:\n'
    f'\t{model.n_channels} input channels\n'
    f'\t{model.n_classes} output channels (classes)\n'
    f'\t{"Bilinear" if model.bilinear else "transposed conv"} upscaling'
)

# Move the model's parameters to the device chosen above. The trailing `;`
# suppresses the module repr in the cell output.
model.to(device=device);

INFO: Network:
	4 input channels
	1 output channels (classes)
	transposed conv upscaling


In [8]:
# Root directory of the training image chips, passed straight to train_model.
# NOTE(review): the name suggests 128 px, non-overlapping ("nostride"),
# class-unbalanced chips that include a DEM band; confirm.
data_dir = 'data/Image_Chips_128_nostride_unbalanced_dem/'

In [9]:
# Train for EPOCHS epochs. If the GPU runs out of memory, retry once with
# activation checkpointing enabled (slower, but uses less memory).
#
# NOTE: `torch.cuda.OutOfMemoryError` is a real exception class (added in
# PyTorch 1.13). Static checkers such as Pylance may report it as "not a valid
# exception class", presumably because it comes from the compiled `torch._C`
# extension. That is a false positive from the checker, not a Python syntax error.
# NOTE(review): on a CPU-only runtime this handler never fires, because CPU
# allocations do not raise torch.cuda.OutOfMemoryError.
EPOCHS = 20
train_kwargs = dict(model=model, epochs=EPOCHS, device=device, data_dir=data_dir)

out_of_memory = False
try:
    train_model(**train_kwargs)
except torch.cuda.OutOfMemoryError:
    # Only record the failure here. While the except clause is running, the
    # active traceback still references the failed call's frames (and the CUDA
    # tensors held in their locals), so empty_cache() could not release that
    # memory yet. Recovery happens below, after the handler has exited.
    out_of_memory = True

if out_of_memory:
    logging.error('Detected OutOfMemoryError! '
                  'Enabling checkpointing to reduce memory usage, but this slows down training. '
                  'Consider enabling AMP (--amp) for fast and memory efficient training')
    torch.cuda.empty_cache()
    model.use_checkpointing()
    # NOTE(review): the model is not re-initialised here, so this presumably
    # continues from whatever weights the failed attempt reached.
    train_model(**train_kwargs)