## Install required libraries

In [None]:
!pip install hub
!pip install matplotlib
!pip install torch
!pip install segmentation_models_pytorch

## Log in to Hub and import libraries

In [None]:
# Authenticate with Activeloop Hub from the command line.
# NOTE(review): presumably required before loading the `activeloop/...` datasets in
# the next cells — confirm whether those public datasets actually need a login.
!hub login

In [None]:
import psutil

import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset
import torch
from torch.utils.data import random_split
import segmentation_models_pytorch as smp

import hub
from hub.compute.transforms_generic.ds_transforms import shift_scale_rotate, gaussian_noise
from hub.api.sharded_datasetview import ShardedDatasetView

## Load the dataset

In [None]:
# Load both splits: `ds` is the training split (augmented in the next cell) and
# `ds_test` is the split the validation loader is built from.
ds, ds_test = (
    hub.load("activeloop/lost_and_found_semantic_segmentation_train"),
    hub.load("activeloop/lost_and_found_semantic_segmentation_test"),
)

## Augment the images and add them to the original dataset

In [None]:
# Build an augmented copy of the training set: shift/scale on both the image and
# its mask (rotation disabled via rotate_limit=0), then Gaussian noise on the
# image only, and write the result to disk.
# NOTE(review): confirm that shift_scale_rotate applies the *same* random
# parameters to image and mask, and resamples the mask with nearest-neighbour
# interpolation so label ids are not blended together.
ds_augmented = shift_scale_rotate(ds, keys=['image_left', 'segmentation_label'], rotate_limit=0, shift_limit=0.1)
ds_augmented = gaussian_noise(ds_augmented, keys=['image_left'])
ds_augmented = ds_augmented.store("/tmp/lost_and_found_aug")

# Wrap the original and the augmented dataset in a single sharded view.
ds_sharded = ShardedDatasetView([ds, ds_augmented])

# psutil.cpu_count() may return None when the count cannot be determined, and on a
# single-core machine `cpu_count() - 1` would be 0 workers; always keep at least one.
_num_workers = max(1, (psutil.cpu_count() or 1) - 1)


@hub.transform(schema=ds_sharded.schema, scheduler="threaded", workers=_num_workers)
def transform_identity(sample):
    """Return *sample* unchanged so the sharded view can be materialised with ``.store()``."""
    return sample


# Materialise the combined (original + augmented) data; from here on `ds` is that
# combined dataset and is what the training loader is built from.
ds = transform_identity(ds_sharded).store('/tmp/lost_and_found_all')

In [None]:
# Use the GPU when one is available instead of hard-coding 'cuda', so the notebook
# still runs (slowly) on CPU-only machines rather than failing on first use.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# U-Net with a ResNet-50 encoder: 3 input channels, 39 output channels.
# NOTE(review): no `activation` is passed to the model, the IoU metric uses
# activation=None, and DiceLoss is built with its defaults, so loss and metric
# presumably see raw logits (IoU thresholds them at 0.5). Consider
# activation='sigmoid' / 'softmax2d' once the mask encoding (one-hot vs. class ids)
# produced by `transform` is confirmed.
model = smp.Unet(
    'resnet50',
    classes=39,
    in_channels=3,
)
iou = smp.utils.metrics.IoU(eps=1.0, activation=None, threshold=0.5)
loss = smp.utils.losses.DiceLoss()
# A single param group containing every model parameter.
optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.0001),
])

# Epoch runners: TrainEpoch receives the optimizer and ValidEpoch does not; both
# report the Dice loss and IoU on DEVICE.
train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=[iou],
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=[iou],
    device=DEVICE,
    verbose=True,
)


def train(
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    num_epochs: int = 10,
    lr_decay_epoch: "int | None" = None,
):
    """Train the module-level ``model`` and checkpoint it on best validation IoU.

    Each epoch runs ``train_epoch.run(train_loader)`` followed by
    ``valid_epoch.run(val_loader)``. Whenever the validation ``iou_score``
    improves on the best seen so far, the whole model object is written to
    ``./best_model.pth`` with ``torch.save``.

    Args:
        train_loader: loader passed to ``train_epoch.run`` every epoch.
        val_loader: loader passed to ``valid_epoch.run`` every epoch.
        num_epochs: number of epochs to run (default 10, as before).
        lr_decay_epoch: epoch index after which the learning rate of
            ``optimizer.param_groups[0]`` is set to 1e-5. Defaults to halfway
            through training. The previous hard-coded epoch 25 could never be
            reached with 10 epochs, so the decay silently never happened.

    Returns:
        The best validation ``iou_score`` observed (0 if it never exceeded 0).
    """
    if lr_decay_epoch is None:
        lr_decay_epoch = num_epochs // 2

    max_score = 0
    for epoch in range(num_epochs):
        print(f'\nEpoch: {epoch}')
        train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(val_loader)

        if max_score < valid_logs['iou_score']:
            max_score = valid_logs['iou_score']
            # Pickles the whole module (not just its state_dict); reload with torch.load.
            torch.save(model, './best_model.pth')
            print('Model saved!')

        if epoch == lr_decay_epoch:
            # Param group 0 is the optimizer's only group and holds all model
            # parameters, so this lowers the LR of the whole network, not just the decoder.
            optimizer.param_groups[0]['lr'] = 1e-5
            print('Decrease learning rate to 1e-5!')

    return max_score

In [None]:
def transform(sample):
    """Turn a hub sample into an ``(image, mask)`` tuple of float32 tensors.

    ``sample['image_left']`` and ``sample['segmentation_label']`` are each
    permuted with ``(2, 0, 1)``, moving the last axis to the front (assumed
    channel-last H, W, C -> channel-first C, H, W), and then cast with ``.float()``.
    """
    # NOTE(review): values are only cast, not rescaled or normalised; check whether
    # the ResNet-50 encoder's expected input preprocessing should be applied here.
    image_key, mask_key = 'image_left', 'segmentation_label'
    return tuple(sample[key].permute(2, 0, 1).float() for key in (image_key, mask_key))

def _to_torch_dataset(hub_ds):
    """Expose *hub_ds* to PyTorch as ``(image, mask)`` tuples produced by ``transform``."""
    return hub_ds.to_pytorch(
        key_list=['image_left', 'segmentation_label'],
        output_type=tuple,
        transform=transform,
    )


def _make_loader(dataset, shuffle):
    """Wrap *dataset* in a DataLoader with batch size 1 and 4 worker processes."""
    return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=4)


torch_ds = _to_torch_dataset(ds)
torch_ds_test = _to_torch_dataset(ds_test)

# Only the training data is shuffled; validation keeps a fixed order.
train_dataloader = _make_loader(torch_ds, shuffle=True)
val_dataloader = _make_loader(torch_ds_test, shuffle=False)

train(train_dataloader, val_dataloader)