<a href="https://colab.research.google.com/github/stebbibg/MSc_Fstudent_SLAM/blob/main/Deeplab_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:

#from google.colab import drive
import PIL
from PIL import Image
#drive.mount('/content/drive')
import numpy as np
import torchvision
import torch
import torchvision.transforms.functional as TF

In [2]:
from skimage.util import random_noise
import random
#from google.colab import drive
import PIL
from PIL import Image

# The validation images were not augmented
def getValImg(img, mask):
  """Return the validation image/mask pair exactly as received.

  No augmentation is applied; the function exists so the validation path
  has the same call shape as ``getTrainImg``.
  """
  pair = (img, mask)
  return pair

def _add_noise(img, **noise_kwargs):
  """Run skimage ``random_noise`` on a PIL image and return an RGB PIL image.

  ``noise_kwargs`` are forwarded unchanged to ``random_noise``.
  """
  noisy = random_noise(np.asarray(img), **noise_kwargs)
  # Move the last axis first (HWC -> CHW) so ToPILImage can consume the tensor
  noisy = torch.from_numpy(noisy.transpose((2, 0, 1)))
  return torchvision.transforms.ToPILImage()(noisy).convert("RGB")


def getTrainImg(img, mask, *, p_crop=0.8, p_affine=0.2, p_color_jitter=0.5,
                p_sp=0.1, p_speckle=0.1, p_erase=0.2, p_gauss=0.5):
  """Randomly augment a training image together with its mask.

  Each augmentation is applied independently with its own probability.
  Photometric augmentations (blur, colour jitter, speckle and
  salt-and-pepper noise) touch only ``img``; geometric ones (affine,
  resize + crop, erase) are applied with identical parameters to ``img``
  and ``mask`` so the two stay aligned.

  Args:
    img: input image (a PIL image, judging by the use of ``img.size`` and
      ``np.asarray``; the noise steps index three axes, so presumably RGB).
    mask: mask matching ``img``. The erase step repeats it to three
      channels, so it is presumably single-channel -- TODO confirm.
    p_crop: probability of the random resize followed by a 1088x1456 crop.
    p_affine: probability of the random affine transform.
    p_color_jitter: probability of colour jitter.
    p_sp: probability of salt-and-pepper noise.
    p_speckle: probability of speckle noise.
    p_erase: probability of erasing a random rectangle.
    p_gauss: probability of Gaussian blur.

  Returns:
    Tuple ``(img, mask)`` after the selected augmentations.
  """
  # Gaussian blur with a random odd kernel size in [5, 23]
  if random.random() < p_gauss:
    kernel_size = random.randrange(5, 25, 2)
    blur_t = torchvision.transforms.GaussianBlur(kernel_size, sigma=(0.1, 3.0))
    img = blur_t(img)

  if random.random() < p_color_jitter:
    color_jitter_t = torchvision.transforms.ColorJitter(
        brightness=0.2, contrast=0.2, saturation=0.1, hue=0.01)
    img = color_jitter_t(img)

  # Random affine: the same sampled parameters are used for image and mask
  if random.random() < p_affine:
    affine_params = torchvision.transforms.RandomAffine.get_params(
        (-8, 8), (0.05, 0.05), (0.95, 0.95), (-8, 8), img.size)
    img, mask = TF.affine(img, *affine_params), TF.affine(mask, *affine_params)

  # Random crop: upscale to a random size, then cut out a 1088x1456 window
  if random.random() < p_crop:
    # Resize takes size=(height, width); the ranges below start at the crop
    # size so the crop window always fits inside the resized image.
    resized_h = random.randint(1088, 1554)
    resized_w = random.randint(1456, 2080)
    resize = torchvision.transforms.Resize(
        size=(resized_h, resized_w), interpolation=PIL.Image.NEAREST)
    img = resize(img)
    mask = resize(mask)

    i, j, h, w = torchvision.transforms.RandomCrop.get_params(
        img, output_size=(1088, 1456))
    img = TF.crop(img, i, j, h, w)
    mask = TF.crop(mask, i, j, h, w)

  # Speckle noise. No fixed seed is passed: a constant seed would generate
  # the identical noise field on every call and defeat the augmentation.
  if random.random() < p_speckle:
    img = _add_noise(img, mode='speckle', mean=0.1, var=0.2)

  # Salt and pepper noise
  if random.random() < p_sp:
    img = _add_noise(img, mode='s&p', salt_vs_pepper=0.5, clip=True)

  # Random erasing: the same rectangle is filled with 0 in image and mask
  if random.random() < p_erase:
    mask_tensor = TF.to_tensor(mask)
    img = TF.to_tensor(img)
    i, j, h, w, v = torchvision.transforms.RandomErasing.get_params(
        img, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=[0])

    # Make the mask three channels so it can be erased like the image
    mask_tensor = mask_tensor.repeat(3, 1, 1)

    img = TF.erase(img, i, j, h, w, v)
    mask_tensor = TF.erase(mask_tensor, i, j, h, w, v)
    # Extract the first channel again
    mask, _, _ = mask_tensor.unbind(0)
    mask = TF.to_pil_image(mask)
    img = TF.to_pil_image(img)

  return img, mask


In [4]:
from os.path import splitext
from os import listdir
import numpy as np
from glob import glob
import torchvision
import torch
from torch.utils.data import Dataset
import logging
from PIL import Image
import torchvision.transforms.functional as TF
import random

#train_indx = []  # Indices for the training images
#val_indx = []    # Indices for the validation images  

class TrainingDataset(Dataset):
    """Segmentation dataset of image/mask pairs read from two directories.

    Sample ids are the file stems found in ``imgs_dir``; the mask for an id
    is looked up in ``masks_dir`` as ``<id><mask_suffix>.*``. Samples whose
    index is listed in ``train_indx`` are augmented with ``getTrainImg``;
    all others go through ``getValImg``.
    """

    def __init__(self, imgs_dir, masks_dir, train_indx, val_indx, scale=0.5, mask_suffix=''):
        """Index the files in ``imgs_dir``.

        Args:
            imgs_dir: directory with the images. It is concatenated directly
                with the file stem, so it must end with a path separator.
            masks_dir: directory with the masks (same trailing-separator
                requirement).
            train_indx: dataset indices that receive training augmentation.
            val_indx: dataset indices of the validation split (stored only;
                not read by this class).
            scale: resize factor in (0, 1] applied to image and mask.
            mask_suffix: text inserted between the id and the mask extension.
        """
        # Keep the indices passed by the caller; they may still be
        # reassigned on the instance after the dataset has been split.
        self.train_indx = train_indx
        self.val_indx = val_indx
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.scale = scale
        self.mask_suffix = mask_suffix
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'

        # Hidden files (leading dot) are not samples
        self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
                    if not file.startswith('.')]
        logging.info(f'Creating dataset with {len(self.ids)} examples')

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.ids)

    @classmethod
    def preprocess(cls, pil_img, scale, is_mask=False):
        """Resize a PIL image by ``scale`` and return it as a CHW numpy array.

        Args:
            pil_img: PIL image to convert.
            scale: multiplicative resize factor.
            is_mask: when True, resize with nearest-neighbour sampling so no
                interpolated (non-existent) label values are created.

        Returns:
            numpy array with the channel axis first; 2-D inputs get a
            leading channel axis of length 1.
        """
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, 'Scale is too small'
        if is_mask:
            pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST)
        else:
            pil_img = pil_img.resize((newW, newH))

        img_nd = np.array(pil_img)

        if img_nd.ndim == 2:
            img_nd = np.expand_dims(img_nd, axis=2)

        # HWC to CHW
        return img_nd.transpose((2, 0, 1))

    def __getitem__(self, i):
        """Load, optionally augment, and tensorise sample ``i``.

        Returns:
            dict with ``'image'`` (normalised float tensor) and ``'mask'``
            (float tensor), both channel-first.
        """
        idx = self.ids[i]
        mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*')
        img_file = glob(self.imgs_dir + idx + '.*')

        mask = Image.open(mask_file[0])
        img = Image.open(img_file[0])

        # Only samples of the training split are augmented
        if int(i) in self.train_indx:
            img, mask = getTrainImg(img, mask)
        else:
            img, mask = getValImg(img, mask)

        img = self.preprocess(img, self.scale)
        mask = self.preprocess(mask, self.scale, is_mask=True)

        img_tensor = torch.from_numpy(img).type(torch.FloatTensor)

        # NOTE(review): the array from `preprocess` is not rescaled
        # (typically 0-255 for 8-bit images), while these mean/std values
        # are the ImageNet statistics for inputs in [0, 1]. Confirm whether
        # a division by 255 is intended; left unchanged so that existing
        # checkpoints and inference code keep seeing the same input range.
        normalize = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        img_tensor = normalize(img_tensor)

        return {
            'image': img_tensor,
            'mask': torch.from_numpy(mask).type(torch.FloatTensor)
        }

In [5]:
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.models.segmentation import deeplabv3_resnet101

# A pretrained deeplab module
def custom_DeepLabv3(out_channel):
  """Build a pretrained DeepLabv3 (ResNet-101) with a fresh classifier head.

  The classifier is replaced by a new ``DeepLabHead`` taking 2048 input
  channels and producing ``out_channel`` output channels. The network is
  returned in training mode.
  """
  net = deeplabv3_resnet101(pretrained=True)
  # Swap the stock output layer for one with the requested channel count
  new_head = DeepLabHead(2048, out_channel)
  net.classifier = new_head

  # Hand the network back in training mode
  net.train()
  return net

In [6]:
# Build the network with 5 output channels; its parameters are counted below
model = custom_DeepLabv3(5)
ct = 0  # NOTE(review): not read anywhere in the visible code
def count_parameters(model):
    """Return the total element count of all parameters that require grad."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Print the number of parameters of `model` that require gradients
print(count_parameters(model))

60992090


In [7]:
from torch.utils.data import DataLoader, random_split
import torch
import torch.nn as nn
from torch import optim
import os
import sys
import copy
import matplotlib.pyplot as plt
from torchvision import transforms
from sklearn.metrics import jaccard_score
import tqdm

# Use the first GPU when available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

batch_size = 2
img_scale = 0.5
dir_img = '/home/agervig/git/FSM/MSc_Fstudent_SLAM/data/training_data/imgs/'
dir_mask = '/home/agervig/git/FSM/MSc_Fstudent_SLAM/data/training_data/masks/'
lr = 0.00002
dir_checkpoint = '/home/agervig/git/FSM/MSc_Fstudent_SLAM/Segmentation/experiments/checkpoints/'
dataset = TrainingDataset(dir_img, dir_mask, [], [], img_scale)

# Split sizes; random_split requires them to add up to len(dataset)
n_val = 133
n_test = 100
n_train = 850
train, val, test = random_split(dataset, [n_train, n_val, n_test])
# Tell the dataset which indices belong to which split so that only the
# training samples are augmented
dataset.train_indx = train.indices
dataset.val_indx = val.indices

# Pinned host memory only helps for host-to-GPU copies
pin_memory = device.type == 'cuda'
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=pin_memory)
# Evaluation loaders are not shuffled: together with drop_last=True a
# shuffled loader would drop a different sample every epoch, which makes
# the per-epoch metrics incomparable.
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=pin_memory, drop_last=True)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=pin_memory, drop_last=True)

optimizer = optim.Adam(model.parameters(), lr=lr)

criterion = nn.CrossEntropyLoss()

epochs = 2 #70
# To load a specific state
#model.load_state_dict(torch.load('/content/drive/My Drive/Colab Notebooks//checkpoints/8.pth'))
train_loop = tqdm.tqdm(range(epochs))



# One iteration per epoch: train, validate, score the test split, then log
# the metrics and write a checkpoint.
for i in train_loop:
  # ---- training pass ----
  model.train()
  current_epoch_loss = 0
  for batch in train_loader:
    optimizer.zero_grad()
    imgs = batch['image'].to(device=device, dtype=torch.float32)
    true_masks = batch['mask'].to(device=device, dtype=torch.long)
    masks_pred = model(imgs)

    # Drop the mask's channel axis (dim 1) so the target has one dimension
    # fewer than the prediction, as CrossEntropyLoss expects
    true_masks_flat = true_masks.squeeze(1)

    loss = criterion(masks_pred['out'], true_masks_flat)

    current_epoch_loss += loss.item()
    loss.backward()
    optimizer.step()

  # ---- evaluation passes ----
  model.eval()
  val_loss = 0
  total_iou = 0
  test_counter = 0
  # No gradients are needed for evaluation; skipping the autograd graph
  # keeps GPU memory usage down.
  with torch.no_grad():
    for batch in val_loader:
      imgs = batch['image'].to(device=device, dtype=torch.float32)
      true_masks = batch['mask'].to(device=device, dtype=torch.long)
      masks_pred = model(imgs)

      true_masks_flat = true_masks.squeeze(1)

      loss = criterion(masks_pred['out'], true_masks_flat)
      val_loss += loss.item()

    for batch in test_loader:
      imgs = batch['image'].to(device=device, dtype=torch.float32)
      true_masks = batch['mask'].to(device=device, dtype=torch.long)
      masks_pred = model(imgs)['out']

      # Predicted class per pixel
      masks_pred = torch.argmax(masks_pred, dim=1)

      true_masks_flat = true_masks.squeeze(1).cpu().numpy().reshape(-1)
      masks_pred = masks_pred.cpu().numpy().reshape(-1)

      # Micro-averaged IoU over labels 1-4 (label 0 is left out)
      iou = jaccard_score(true_masks_flat, masks_pred, labels=[1, 2, 3, 4], average='micro')
      total_iou += iou
      test_counter += 1

  iou = total_iou / test_counter

  # The criterion already averages within a batch, so the per-epoch mean is
  # the sum of batch losses divided by the number of batches.
  current_epoch_loss /= len(train_loader)
  val_loss /= len(val_loader)
  print(f"epoch: {i + 1} training loss: {current_epoch_loss} val loss: {val_loss} iou: {iou}\n")
  # Append this epoch's metrics; the context manager closes the file
  with open("/home/agervig/git/FSM/MSc_Fstudent_SLAM/Segmentation/experiments/training.txt", "a") as file1:  #"/content/drive/My Drive/Colab Notebooks/training.txt"
    file1.write(f"{current_epoch_loss} {val_loss} {iou}\n")
  # One checkpoint per epoch, named after the zero-based epoch index
  torch.save(model.state_dict(), dir_checkpoint + str(i) + ".pth")
  0%|          | 0/2 [00:03<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 98.00 MiB (GPU 0; 7.92 GiB total capacity; 5.90 GiB already allocated; 112.94 MiB free; 6.25 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

: 

# New Section
# Mount Google Drive at /content/gdrive (uses the google.colab package)
from google.colab import drive
drive.mount('/content/gdrive')
