## Load libraries

In [1]:
import os
import sys

# Repository root. Defaults to the original development machine; set the
# HRLCM_ROOT environment variable to run the notebook elsewhere.
HRLCM_ROOT = os.environ.get('HRLCM_ROOT', '/home/ubuntu/hrlcm')

# Make the in-repo modules (augmentation, dataset, models, train, ...) importable.
# The membership check keeps sys.path free of duplicates when this cell is re-run.
_hrlcm_pkg_dir = os.path.join(HRLCM_ROOT, 'hrlcm')
if _hrlcm_pkg_dir not in sys.path:
    sys.path.append(_hrlcm_pkg_dir)

In [2]:
import argparse
from augmentation import *
from dataset import *
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import pickle as pkl
from models.deeplab import DeepLab
from models.unet import UNet
from train import Trainer
from loss import BalancedCrossEntropyLoss
from sync_batchnorm import convert_model

In [3]:
import os

# Relative paths used below (e.g. 'results/north') are resolved against the
# repository root. Override the default with the HRLCM_ROOT environment variable.
os.chdir(os.environ.get('HRLCM_ROOT', '/home/ubuntu/hrlcm'))

## Define placeholder arguments (stand-in for the command-line parser)

In [15]:
# Define a dummy args for testing
class args_dummy:
    """Hard-coded stand-in for the training script's command-line arguments.

    Lets the pipeline be run interactively without an argparse parser. Values
    are plain attributes; string-encoded lists (rg_rotate, gpu_devices) are
    parsed into Python objects by later cells.
    """

    def __init__(self):
        # Experiment name and I/O locations (relative to the working directory)
        self.exp_name = 'unet_test'
        self.data_dir = 'results/north'
        self.out_dir = 'results/dl'
        # Passed to the training NFSEN1LC dataset
        self.lowest_score = 10
        self.noise_ratio = 0
        # Probability given to each random augmentation transform
        self.trans_prob = 0.5
        # Passed to both datasets
        self.label_offset = 1
        # Rotation range as a "min, max" string; parsed into a float tuple later
        self.rg_rotate = '-90, 90'
        # 'deeplab' selects DeepLab; any other value builds a UNet
        self.model = 'unet'
        # NOTE(review): not read in this notebook; presumably used by Trainer
        self.train_mode = 'single'
        # DeepLab output stride (unused for UNet)
        self.out_stride = 8
        # Comma-separated GPU ids; parsed into a list of ints later
        self.gpu_devices = '0, 1, 2, 3'
        # Convert BatchNorm via sync_batchnorm.convert_model when True
        self.sync_norm = False
        self.lr = 0.001
        # Weight decay (currently only applied by the Adadelta optimizer)
        self.decay = 1e-5
        # Checkpoint every `save_freq` epochs
        self.save_freq = 10
        # NOTE(review): spelling kept as-is; presumably read under this name
        # by Trainer — confirm before renaming to log_freq.
        self.log_feq = 10
        self.batch_size = 32
        self.epochs = 100
        self.optimizer_name = 'Adam'
        self.resume = None
        self.checkpoint_dir = os.path.join(self.out_dir, self.exp_name, 'checkpoints')
        self.logs_dir = os.path.join(self.out_dir, self.exp_name, 'logs')


# Initialize dummy args as a plain argparse.Namespace (same attributes, as the
# real CLI would produce). The config is pickled to args.pkl later, and
# instances of a class defined in this notebook's __main__ cannot be
# unpickled from any other script; a Namespace can.
args = argparse.Namespace(**vars(args_dummy()))

In [5]:
# Set directory for saving files: outputs of this run are grouped under the
# experiment name, falling back to the model name when exp_name is empty.
run_dir = os.path.join(args.out_dir, args.exp_name or args.model)
args.checkpoint_dir = os.path.join(run_dir, 'checkpoints')
args.logs_dir = os.path.join(run_dir, 'logs')

# Create dirs if necessary. exist_ok avoids the check-then-create race and
# keeps the cell safe to re-run.
for out_subdir in (args.checkpoint_dir, args.logs_dir):
    os.makedirs(out_subdir, exist_ok=True)

# Dir for mean and sd pickles
args.stats_dir = os.path.join(args.data_dir, 'norm_stats')

In [6]:
# Set flags for GPU processing if available
args.use_gpu = torch.cuda.is_available()

# Parse the rotation range string "min, max" into a tuple of floats. The
# isinstance guard makes the cell idempotent: on a re-run rg_rotate is already
# a tuple, and calling .split on it would raise AttributeError.
if isinstance(args.rg_rotate, str):
    args.rg_rotate = tuple(float(each) for each in args.rg_rotate.split(','))

# synchronize transform for train dataset: each random transform is given
# prob=args.trans_prob, followed by tensor conversion.
sync_transform = Compose([
    RandomScale(prob=args.trans_prob),
    RandomFlip(prob=args.trans_prob),
    RandomCenterRotate(degree=args.rg_rotate,
                       prob=args.trans_prob),
    SyncToTensor()
])

# synchronize transform for validate dataset: tensor conversion only, no
# random augmentation.
val_transform = Compose([
    SyncToTensor()
])

## Load datasets 

In [7]:
# Constructor options common to both splits.
_shared_ds_kwargs = dict(data_dir=args.data_dir,
                         label_offset=args.label_offset,
                         img_transform=None,
                         label_transform=None)

# Training split: additionally receives the lowest_score / noise_ratio
# options (their semantics live in NFSEN1LC) and the random sync transforms.
train_dataset = NFSEN1LC(usage='train',
                         lowest_score=args.lowest_score,
                         noise_ratio=args.noise_ratio,
                         sync_transform=sync_transform,
                         **_shared_ds_kwargs)

# Validation split: NFSEN1LC defaults for the remaining options, deterministic
# transform only.
validate_dataset = NFSEN1LC(usage='validate',
                            sync_transform=val_transform,
                            **_shared_ds_kwargs)

# Training batches are shuffled and the incomplete last batch is dropped;
# validation keeps order and every sample.
train_loader = DataLoader(train_dataset,
                          batch_size=args.batch_size,
                          shuffle=True,
                          drop_last=True)
validate_loader = DataLoader(validate_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             drop_last=False)

## Set up devices, model, loss, and optimizer

In [9]:
# Number of CUDA devices visible to this process; every id listed in
# args.gpu_devices must be smaller than this.
torch.cuda.device_count()

4

In [11]:
# Inspect the GPU ids. The string in args is converted to a list of ints only
# in the model-setup cell below.
# NOTE(review): the saved output [1, 2, 3, 4] does not match the configured
# '0, 1, 2, 3'. It was produced (In [11]) before args was redefined (In [15]),
# so it reflects stale kernel state. With 4 visible devices, id 4 would also be
# out of range. Re-run the notebook top to bottom.
args.gpu_devices

[1, 2, 3, 4]

In [16]:
# Set up network: input channels and output classes come from the training dataset.
args.n_classes = train_dataset.n_classes
args.n_channels = train_dataset.n_channels
if args.model == "deeplab":
    model = DeepLab(num_classes=args.n_classes,
                    backbone='resnet',
                    pretrained_backbone=False,
                    output_stride=args.out_stride,
                    sync_bn=False,
                    freeze_bn=False,
                    n_in=args.n_channels)
else:
    # Any model name other than "deeplab" falls back to UNet.
    model = UNet(n_classes=args.n_classes,
                 n_channels=args.n_channels)

args.use_gpu = torch.cuda.is_available()

# Get devices: parse the comma-separated id string (e.g. '0, 1, 2, 3') into
# ints. Empty tokens are skipped, so '' yields [] (no DataParallel).
# The isinstance guard keeps the cell re-runnable: after the first run
# gpu_devices is already a list and has no .split method.
if isinstance(args.gpu_devices, str):
    args.gpu_devices = [int(each) for each in args.gpu_devices.split(',')
                        if each.strip()]

# Set model
if args.use_gpu:
    if args.gpu_devices:
        # Fail early with a clear message instead of an opaque CUDA error
        # when an id is not among the visible devices.
        n_visible = torch.cuda.device_count()
        invalid_ids = [d for d in args.gpu_devices if not 0 <= d < n_visible]
        if invalid_ids:
            raise ValueError(f'GPU ids {invalid_ids} are not available: only '
                             f'{n_visible} CUDA device(s) visible.')
        # DataParallel requires the module's parameters on device_ids[0];
        # making it the current device sends the later .cuda() there.
        torch.cuda.set_device(args.gpu_devices[0])
        model = torch.nn.DataParallel(model, device_ids=args.gpu_devices)
        if args.sync_norm:
            # presumably swaps BatchNorm layers for synchronized multi-GPU ones
            model = convert_model(model)
    model = model.cuda()

# Define loss function
loss_fn = BalancedCrossEntropyLoss()

# Define optimizer. Unknown names fall back to Adam with a notice.
# NOTE(review): args.decay is applied as weight decay only for Adadelta; Adam
# ignores it. Kept as-is to preserve training behaviour — confirm whether Adam
# should also receive weight_decay=args.decay.
if args.optimizer_name == 'Adadelta':
    optimizer = torch.optim.Adadelta(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.decay)
else:
    if args.optimizer_name != 'Adam':
        print('Not supported optimizer, use Adam instead.')
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr)

# Set up tensorboard logging
writer = SummaryWriter(log_dir=args.logs_dir)

# Save config; `with` guarantees the pickle is flushed and the file closed.
with open(os.path.join(args.checkpoint_dir, "args.pkl"), "wb") as args_file:
    pkl.dump(args, args_file)

## Train and validate for `args.epochs` epochs

In [None]:
# Global step counter shared by training and validation logging.
step = 0
trainer = Trainer(args)

# The context manager closes the epoch bar even if an epoch raises.
with tqdm(total=args.epochs, desc="[Epoch]") as pbar:
    for epoch in range(args.epochs):
        # Run training for one epoch
        model, step = trainer.train(model, train_loader, loss_fn,
                                    optimizer, writer, step=step)
        # Run validation
        trainer.validate(model, validate_loader, step, loss_fn, writer)

        # Save checkpoint (after epochs 0, save_freq, 2*save_freq, ...)
        if epoch % args.save_freq == 0:
            trainer.export_model(model, optimizer=optimizer, step=step)

        # Update pbar
        pbar.update()

# Export final set of weights
trainer.export_model(model, optimizer, name="final")

# Flush pending TensorBoard events and release the event file; otherwise the
# last scalars may never reach disk. torch's SummaryWriter recreates its file
# writer on the next add_* call, so re-running training cells stays safe.
writer.close()

HBox(children=(FloatProgress(value=0.0, description='[Epoch]', style=ProgressStyle(description_width='initial'…

HBox(children=(FloatProgress(value=0.0, description='[Train]', max=114.0, style=ProgressStyle(description_widt…