In [15]:
from pathlib import Path
import torch
!pip install segmentation_models_pytorch
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader, random_split
from torch.utils.data.dataset import Dataset
import glob
torch.manual_seed(2022)  # Setting random seed so that augmentations can be reproduced.
from SeaIce.unet.dataset_preparation import CustomImageDataset
import numpy as np



In [4]:
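### TODO: this class should be imported from a separate file (the first cell already imports a CustomImageDataset from SeaIce.unet.dataset_preparation); running this cell replaces that imported class with the copy below.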
class CustomImageDataset(Dataset):
    """Dataset that serves (image, mask) tensor pairs read from .npy files.

    The dataset is built from a sequence of ``(image_path, label_path)`` pairs, e.g. the list
    returned by ``create_npy_list``. Each item is loaded lazily in ``__getitem__``, so the
    instance itself only holds the paths. It can be passed directly to a ``DataLoader``.

    Args:
        paths: indexable collection of ``(image_path, label_path)`` pairs.
        isSingleBand: when truthy, a leading channel axis is added to the image; otherwise the
            image's last axis is moved to the front (channels-last -> channels-first).
    """

    def __init__(self, paths, isSingleBand=True):
        self.paths = paths
        self.isSingleBand = isSingleBand

    def __len__(self):
        """Number of image/label pairs."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the pair at ``index`` and return ``(image, mask)`` as float64 tensors.

        Label values of 100 are remapped to 0 and values of 200 to 1; every other value is kept.
        The mask always gets a leading axis of size one.
        """
        # np.vstack concatenates the sub-arrays found along the first axis of the loaded array.
        # NOTE(review): presumably tiles are stored with a leading singleton axis that this
        # collapses -- confirm against the tiling code.
        image_array = np.vstack(np.load(self.paths[index][0]))
        image = torch.from_numpy(image_array.astype(float))
        image = image.unsqueeze(0) if self.isSingleBand else image.permute(2, 0, 1)

        label_array = np.load(self.paths[index][1])
        label_array = np.where(label_array == 100, 0, label_array)
        label_array = np.where(label_array == 200, 1, label_array)
        mask = torch.from_numpy(np.vstack(label_array).astype(float)).unsqueeze(0)

        return image, mask

In [5]:
def create_npy_list(image_directory, img_string="sar"):
    """Pair every image tile in a directory with its label tile.

    Image tiles are the files matching ``*_<img_string>.npy`` and label tiles are the files
    matching ``*_labels.npy``, both directly inside ``image_directory``. The returned list can
    be used as the ``paths`` argument of the Dataset class.

    Args:
        image_directory: directory to scan (a ``str`` or ``Path``).
        img_string: filename suffix identifying the imagery type, e.g. ``"sar"``.

    Returns:
        A list of ``(image_path, label_path)`` tuples of strings, ordered by image path.
        The list is empty when the directory holds no matching image tiles.

    Raises:
        Exception: if any image tile has no label tile with the same filename prefix.
    """
    img_suffix = '_' + img_string + '.npy'
    label_suffix = '_labels.npy'

    img_names = sorted(glob.glob(str(image_directory) + '/*' + img_suffix))
    # A set makes each membership test O(1), keeping the pairing check linear in the tile count.
    label_names = set(glob.glob(str(image_directory) + '/*' + label_suffix))

    img_label_pairs = []
    for image in img_names:
        # Swap only the trailing "_<img_string>.npy". A plain str.replace would also rewrite
        # any occurrence of img_string earlier in the path (e.g. a directory called "sar_tiles"),
        # producing a label path that cannot exist.
        expected_label_name = image[:-len(img_suffix)] + label_suffix
        if expected_label_name not in label_names:
            raise Exception(f'{img_string} tile name {image} does not have a matching labeled tile.')
        img_label_pairs.append((image, expected_label_name))

    return img_label_pairs

In [12]:
# If running on colab, use the following instead of the local assignment below:
#
#     from google.colab import drive
#     drive.mount('/content/drive')
#     from pathlib import Path
#     dir_img = Path('/content/drive/Shareddrives/2021-gtc-sea-ice/trainingdata/tiled/')

# If running locally:
#dir_img = Path('tiled/')
# Directory holding the tiled .npy image / label files that the dataset is built from.
dir_img = Path('/mnt/g/Shared drives/2021-gtc-sea-ice/trainingdata/tiled/')


In [13]:
# Inputs

imagery = 'sar' # SAR / MODIS -- also the tile filename suffix passed to create_npy_list
val_percent = 0.1  # fraction of the tiles held out for validation
batch_size = 10

DEVICE = 'cpu'#torch.device('cuda')

# Model Creation
# 1. Create dataset
# Look up tiles for the selected imagery type so the files that are loaded always agree with
# the channel configuration chosen just below.
img_list = create_npy_list(dir_img, imagery)
if imagery == 'sar':
    single_channel = True
    n_channels = 1
else:
    single_channel = False
    n_channels = 3
dataset = CustomImageDataset(img_list, single_channel)

# 2. Split into train / validation partitions
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
# Fail here with a clear message rather than later inside the training loop.
assert n_train > 0 and n_val > 0, (
    f'Need at least one training and one validation tile; got {n_train} and {n_val} from {dir_img}'
)
# Fixed generator seed so the same tiles land in each partition on every run.
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

# 3. Create data loaders
# Pinned host memory only helps host-to-GPU copies, so it is enabled for CUDA devices only.
loader_args = dict(batch_size=batch_size, num_workers=2, pin_memory=str(DEVICE).startswith('cuda'))
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
# Keep the final partial batch: dropping it would silently exclude validation tiles and leaves
# the loader empty whenever n_val < batch_size.
val_loader = DataLoader(val_set, shuffle=False, drop_last=False, **loader_args)

# Specify models settings
loss = smp.utils.losses.DiceLoss()
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]
# Can also specify number of UNet steps and channel numbers.
model = smp.Unet(encoder_name='resnet34', encoder_weights='imagenet', decoder_use_batchnorm=True,
                 decoder_attention_type=None, in_channels=n_channels, classes=1, encoder_depth=5)
# Cast the weights to float64 to match the float64 tensors produced by the dataset.
model = model.double()

optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.0001),
])

train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    verbose=True,
    device=DEVICE
)

valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    verbose=True,
    device=DEVICE
)

# Training schedule, consumed by the training cell below.

max_score = 0  # best validation IoU seen so far
n_epochs = 1  # number of epochs to train for



Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /home/mlisaius/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth


  0%|          | 0.00/83.3M [00:00<?, ?B/s]

In [None]:
# Training
if __name__ == '__main__':
    # Epoch index after which the learning rate is lowered. Integer division is used because
    # n_epochs / 2 is a non-integer float for odd n_epochs, which no epoch index ever equals,
    # so the drop would silently never happen. Single-epoch runs have no midpoint to drop at.
    lr_drop_epoch = n_epochs // 2 if n_epochs > 1 else None

    for i in range(0, n_epochs):

        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(val_loader)

        # Keep a checkpoint of the model with the best validation IoU so far.
        # Note: torch.save(model, ...) pickles the whole module object, not just its weights.
        if max_score < valid_logs['iou_score']:
            max_score = valid_logs['iou_score']
            torch.save(model, './best_model.pth')
            print('Model saved!')

        # The optimizer has a single parameter group, so this lowers the rate for all weights.
        if i == lr_drop_epoch:
            optimizer.param_groups[0]['lr'] = 1e-5
            print('Decrease decoder learning rate to 1e-5!')


Epoch: 0
train:   0%|                                                                                    | 0/167 [00:00<?, ?it/s]