In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt 
import albumentations as albu

from torch.utils.data import DataLoader, Dataset as BaseDataset
from pathlib import Path

In [3]:
np.random.seed(0)

# load the data
datapath = Path.cwd().parent / 'data'
data_train = np.load(datapath/'data_train.npz', allow_pickle=True, mmap_mode='r')
data_train = data_train['data']

labels_train = np.load(datapath/'labels_train.npz', allow_pickle=True, mmap_mode='r')
labels_train = labels_train['labels']


# seperate data into train and test sets
test_idx = np.random.randint(low=0, high=data_train.shape[-1], size=int(data_train.shape[-1] * .15))
train_idx = list(set(range(data_train.shape[-1])) - set(test_idx))

data_test = data_train[...,test_idx]
data_train = data_train[...,train_idx]

labels_test = labels_train[...,test_idx]
labels_train = labels_train[...,train_idx]

In [4]:
def transformations(image, mask):
  image_trans = transforms.Compose([transforms.ToPILImage(),
                                    transforms.Resize((224,224)),
#                                     transforms.Resize((1024,800)),
                                    transforms.ToTensor()])
 
  return image_trans(image), image_trans(mask)

In [5]:
def get_training_augmentation():
    train_transform = [
                        albu.Resize(1024,800),
#                        albu.Resize(224,224),
#                        albu.HorizontalFlip(p=1)
                       albu.pytorch.transforms.ToTensor()
                      ]
    return albu.Compose(train_transform)


In [12]:
class Dataset(BaseDataset):
    
    def __init__(
            self, 
            data, 
            masks, 
            augmentation=None, 
    ):
        self.data = data
        self.masks = masks
        
        # get the class values
        self.class_values = list(range(6))
        self.augmentation = augmentation
    
    def __getitem__(self, i):
        
        # read data
        image = self.data[:,:,i]
        mask = self.masks[:,:,i] - 1 # to shift class labes from 1, 2, ... 6 to 0, 1, ... 5
            
        if self.augmentation:
            image, mask = self.augmentation(image=image, mask=mask.astype(np.float32))
            
        # extract certain classes from mask (e.g. cars)
        masks = [(mask == v) for v in self.class_values]
#         mask = np.stack(masks, axis=-1)
        mask = torch.cat(masks, axis=0)

        return image, mask
      
    def __len__(self):
        return self.data.shape[-1]


In [15]:
train_dataset = Dataset(data=data_train, masks=labels_train, augmentation=transformations)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True,)

test_dataset = Dataset(data=data_test, masks=labels_train, augmentation=transformations)
test_loader = DataLoader(test_dataset, batch_size=10)

In [16]:
model = smp.Unet(in_channels=1, classes=6, activation='softmax')

loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(threshold=0.5)]
optimizer = torch.optim.Adam([dict(params=model.parameters(), lr=0.0001)])
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [17]:
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

In [18]:
max_score = 0

for i in range(0, 2):
    
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    
    #do something (save model, change lr, etc.)
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './best_model.pth')
        print('Model saved!')
        
    if i == 25:
        optimizer.param_groups[0]['lr'] = 1e-5
        print('Decrease decoder learning rate to 1e-5!')



Epoch: 0
train:   0%|                                                                                    | 0/51 [00:00<?, ?it/s]

  img = torch.from_numpy(np.array(pic, np.float32, copy=False))
  return self.activation(x)


train:  10%|███▍                               | 5/51 [00:21<03:21,  4.37s/it, dice_loss - 0.7939, iou_score - 0.03853]


KeyboardInterrupt: 