In [None]:
# MIT License

# Copyright (c) 2022 Ghasem Abdi

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Original Author       : Ghasem Abdi, ghasem.abdi@yahoo.com
# File Last Update Date : April 15, 2022

In [None]:
#import dependencies
import os
import torch
import changeDetector as cd #from src import changeDetector as cd
from pytorch_toolbelt import losses as L

# Select the compute device and fix torch's global random seed so runs are reproducible:
# torch's RNG drives, among others, random weight initialisation and the shuffled
# (shuffle=True) training loader.
# NOTE(review): if cd's data augmentation draws from `random`/NumPy, seed those as well — TODO confirm.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU RNG and the RNG of every CUDA device

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Chunk the train, valid and test splits into tiles (optional, one-time preprocessing).
# NOTE(review): this presumably rewrites the dataset on disk, so it is disabled by default
# to avoid re-chunking already-chunked data on every Restart & Run All. Set
# RUN_CHUNKING = True only for a first run on raw data — TODO confirm.
RUN_CHUNKING = False

if RUN_CHUNKING:
    chunker = cd.chunk_data(number_tiles=16)
    for split in ('train', 'valid', 'test'):
        _ = chunker.chunk(data_root=f'LEVIR_dataset/{split}')

In [None]:
# Prepare the train, valid and test sets. All three splits share the same directory
# layout (A = base images, B = target images, label = label masks, all *.png) and the
# same size, so one helper builds each split; only the split name and transform differ.
DATA_ROOT = 'LEVIR_dataset'
IMAGE_SIZE = 256


def _prepare_split(split, transform):
    """Build a ``cd.prepare_data`` dataset for one split under DATA_ROOT.

    Args:
        split: sub-directory of DATA_ROOT, e.g. 'train', 'valid' or 'test'.
        transform: forwarded unchanged to ``cd.prepare_data``.

    Returns:
        Whatever ``cd.prepare_data`` returns for that split.
    """
    return cd.prepare_data(
        data_root=f'{DATA_ROOT}/{split}',
        base_dir='A',
        base_img_suffix='*.png',
        target_dir='B',
        target_img_suffix='*.png',
        label_dir='label',
        label_mask_suffix='*.png',
        size=IMAGE_SIZE,
        transform=transform,
    )


# NOTE(review): train passes transform=None while valid/test pass transform=False.
# Presumably None selects cd's default (training) augmentation and False disables it —
# confirm against cd.prepare_data. The values are kept exactly as originally written.
train_ds = _prepare_split('train', transform=None)
valid_ds = _prepare_split('valid', transform=False)
test_ds = _prepare_split('test', transform=False)

In [None]:
# Prepare the training/validation and testing data loaders.
BATCH_SIZE = 64
# os.cpu_count() returns None when the number of CPUs cannot be determined; fall back to
# 0 (load in the main process) rather than passing None on as a worker count.
NUM_WORKERS = os.cpu_count() or 0

# train_dl bundles both the 'train' and 'valid' loaders; it is passed as-is to learner.train.
train_dl = {
    'train': cd.prepare_dataloader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS),
    'valid': cd.prepare_dataloader(dataset=valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS),
}

test_dl = cd.prepare_dataloader(dataset=test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# Build the change-detection network (available options: cd.UNet and cd.UNetPlusPlus).
model = cd.UNet(
    encoder_name='resnet34',                  # encoder backbone
    pretrained=True,                          # start from pretrained encoder weights
    in_channels=3,                            # channels per input image (see the summary cell's input shapes)
    encoder_fusion_type='concat',             # presumably how base/target features are combined — see cd.UNet
    decoder_channels=(256, 128, 64, 32, 16),
    decoder_attention_type='se',
    classes=2,                                # number of output classes in the change map
)

In [None]:
# Print the network summary for one base/target image pair, each of shape (1, 3, 256, 256).
cd.summary(model=model, input_size=((1, 3, 256, 256),) * 2)

In [None]:
# Prepare the change-detection learner: focal loss, an Adam optimiser, and a StepLR
# schedule that scales the learning rate by 0.1 every 10 scheduler steps
# (presumably stepped once per epoch by the learner — confirm in cd.prepare_learning).
loss = L.FocalLoss()
optim = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.5, 0.99),
    weight_decay=0,
)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=10, gamma=0.1)

learner = cd.prepare_learning(
    model=model,
    loss=loss,
    optim=optim,
    scheduler=scheduler,
    device=device,
    num_epoch=25,
)

In [None]:
# Train the network. train_dl carries both the 'train' and 'valid' loaders, and the call
# returns a pair that is unpacked into the training and validation logs.
(train_logs, valid_logs) = learner.train(
    data_loader=train_dl,
    average='micro',
)

In [None]:
# Save the trained change-detection network as ONNX. The input shapes match the summary
# cell: one base image and one target image, each (1, 3, 256, 256).
# NOTE(review): confirm cd.export_onnx builds its dummy inputs on the same device as
# `model`, which the learner may have moved to `device`.
cd.export_onnx(
    model=model,
    input_size=((1, 3, 256, 256), (1, 3, 256, 256)),
    filename='change detection.onnx',
    input_names=['base image', 'target image'],
    output_names=['change map'],
    opset_version=11,
)

In [None]:
# Evaluate the trained network on the test loader, using the same average='micro'
# setting as training.
test_logs = learner.predict(
    data_loader=test_dl,
    average='micro',
)

In [None]:
#dechunk results (optional): merge the outputs stored under res/vis.
# NOTE(review): presumably this reassembles tiles produced from chunked inputs, so
# number_tiles must match the value used in the optional chunking step (16), and the
# step only makes sense when the dataset on disk was chunked — confirm.
chunker = cd.chunk_data(number_tiles=16)
_ = chunker.dechunk(data_root='res/vis')