In [None]:
!rm -r sample_data
!git clone https://github.com/PhilipMathieu/unet-orthoimagery.git .

In [10]:
import logging
import torch
from src.train import train_model
from src.unet.unet_model import UNet

In [11]:
# Emit root-logger records at INFO and above as "LEVEL: message".
# basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)
# Prefer the GPU whenever CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logging.info(f'Using device: {device}')

INFO: Using device: cpu


In [12]:
# Build a U-Net with 4 input channels and 1 output class. It upsamples with
# transposed convolutions (bilinear=False) rather than bilinear interpolation.
# Its parameters are then converted to the channels_last memory format.
# NOTE(review): the 4 input channels presumably correspond to the image chips'
# bands (e.g. RGB + DEM elevation) — confirm against the dataset.
model = UNet(
    n_channels=4,
    n_classes=1,
    bilinear=False,
).to(memory_format=torch.channels_last)

In [15]:
# Log a short summary of the network configuration, one tab-indented line
# per property.
logging.info('\n'.join([
    'Network:',
    f'\t{model.n_channels} input channels',
    f'\t{model.n_classes} output channels (classes)',
    f'\t{"Bilinear" if model.bilinear else "transposed conv"} upscaling',
]))

# Move the model to the selected device; the trailing ';' suppresses the
# module repr that Jupyter would otherwise display.
model.to(device=device);

INFO: Network:
	4 input channels
	1 output channels (classes)
	transposed conv upscaling


In [17]:
# Train the model. If CUDA runs out of memory, free cached GPU memory, enable
# checkpointing on the model, and retry once with the same arguments.
# A second out-of-memory error in the retry is not caught.
#
# NOTE: `torch.cuda.OutOfMemoryError` is a real exception class (PyTorch >= 1.13).
# The message '"OutOfMemoryError" is not a valid exception class' is presumably a
# static type-checker diagnostic (e.g. Pylance/Pyright). It is not a Python
# SyntaxError, so it is silenced with `# type: ignore` instead of changing the
# except clause.
EPOCHS = 20
DATA_DIR = "data/Image_Chips_128_overlap_unbalanced_dem/"

# One set of arguments shared by the first attempt and the retry, so the two
# calls cannot drift apart.
train_kwargs = dict(model=model, epochs=EPOCHS, device=device, data_dir=DATA_DIR)

try:
    train_model(**train_kwargs)
except torch.cuda.OutOfMemoryError:  # type: ignore
    # NOTE(review): "--amp" names a command-line flag. From this notebook the
    # equivalent would be an `amp` argument to train_model, if it accepts one —
    # confirm before acting on this advice.
    logging.error('Detected OutOfMemoryError! '
                  'Enabling checkpointing to reduce memory usage, but this slows down training. '
                  'Consider enabling AMP (--amp) for fast and memory efficient training')
    torch.cuda.empty_cache()
    model.use_checkpointing()
    train_model(**train_kwargs)

KeyboardInterrupt: 