In [None]:
import os

import torch
from torch.utils.data import DataLoader
import super_gradients as sg
from super_gradients.common.object_names import Models
from super_gradients.training.metrics.segmentation_metrics import IoU

## Load the data
#
# Placeholder dataset root: point this at the local GOOSE checkout before running.
PATH = '/path/to/goose'
# The helper's result is unpacked into four objects; only the train/val ones are used in this cell.
test_dict, train_dict, val_dict, mapping_dict = goose_create_dataDict(PATH)

train_dataset = GOOSE_SemanticDataset(train_dict, crop=True, resize_size=(768,768))
val_dataset   = GOOSE_SemanticDataset(val_dict, crop=True, resize_size=(768,768))

## Set-up for training
#

# Create output directory (checkpoints land in <cwd>/output/ckpts)
EXPERIMENT_NAME = "GOOSE_train"
WS_PATH = os.getcwd()
CHECKPOINT_DIR = os.path.join(WS_PATH, 'output', 'ckpts')

# exist_ok=True makes the cell idempotent on re-run and avoids the check-then-create race.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Params
BATCH_SIZE      = 5
N_EPOCHS        = 10
NUM_WORKERS     = 5

# Dataloaders
# Training: shuffle every epoch and drop the ragged last batch so every step sees BATCH_SIZE samples.
train_dataloader    = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True)
# Validation: keep a fixed order and keep every sample. Dropping the last partial batch would
# silently exclude up to BATCH_SIZE-1 images from the reported metrics, and shuffling would make
# the excluded subset change between epochs.
val_dataloader      = DataLoader(val_dataset,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, drop_last=False)

In [None]:
# Trainer Set-up
# `device` stays a torch.device for any later tensor/model placement, while
# super_gradients' setup_device is given the plain 'cuda' / 'cpu' string form.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
sg.setup_device(device=device.type)
trainer = sg.Trainer(experiment_name=EXPERIMENT_NAME, ckpt_root_dir=CHECKPOINT_DIR)

# Number of segmentation classes; shared by the model head and both metric instances
# so they cannot drift apart.
NUM_CLASSES = 64

## Load Model
#
# DDRNet-39 initialised from the 'cityscapes' pretrained weights, with the requested class count.
model = sg.training.models.get(model_name=Models.DDRNET_39,
                    num_classes=NUM_CLASSES,
                    pretrained_weights='cityscapes')
# NOTE(review): the previous `model.eval()` call here was dropped because this model goes
# straight into training; assumes the Trainer manages train/eval mode itself - confirm.

# Set-up Training params
# Step schedule: multiply the LR by lr_decay_factor at 30%, 60% and 90% of the run.
lr_updates = [int(.3 * N_EPOCHS), int(.6 * N_EPOCHS), int(.9 * N_EPOCHS)]
train_params = {
            "max_epochs": N_EPOCHS,
            "lr_mode":"step",
            "lr_updates": lr_updates,
            "lr_decay_factor": 0.1,
            "initial_lr": 0.005,
            "optimizer": 'sgd',
            "loss": 'cross_entropy',
            "average_best_models": False,
            # The watched metric is IoU, where higher is better.
            "greater_metric_to_watch_is_better": True,
            "loss_logging_items_names": ["loss"],
            # NOTE(review): kept from the original; confirm the Trainer actually reads this key.
            "drop_last": True,
            }

# Separate metric instances for train and validation so their accumulated state is not shared.
train_params["train_metrics_list"] = [IoU(num_classes=NUM_CLASSES)]
train_params["valid_metrics_list"] = [IoU(num_classes=NUM_CLASSES)]
train_params["metric_to_watch"]    = "IoU"

In [None]:
## Train
# Launch the training run with the trainer, model, params and loaders prepared in the cells above.
trainer.train(
    model=model,
    training_params=train_params,
    train_loader=train_dataloader,
    valid_loader=val_dataloader,
)