## `DeepLabv3+`

In [1]:
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from tqdm import tqdm
import numpy as np

import segmentation_models_pytorch as smp

# Dataset
from src.core.config import DatasetConfig
from src.utils.helpers import split_dataset
import albumentations as A
import albumentations.augmentations.functional as F
from albumentations.pytorch import ToTensorV2
from src.dataset.kitti import KittiSegDataset

# Model

import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from src.models.deeplabplus import DeepLabV3Plus
from src.core.config import HyperParameters
from src.utils.helpers import get_color_maps
from src.core.metrics import meanIoU


import logging
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

import gc


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Notebook-wide settings object and the lookup table that is later indexed
# with predicted class ids to colorize them for plotting.
hyper_parameters = HyperParameters()
id_to_color = get_color_maps()

# Let cuDNN benchmark its convolution algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True

  from .autonotebook import tqdm as notebook_tqdm


#### 1. Dataset

In [2]:
# train_size=0.8 is forwarded to split_dataset
# (presumably an 80/20 train/validation split -- confirm in src.utils.helpers).
train_files_list, val_files_list = split_dataset(train_size=0.8)

config = DatasetConfig()

# Pipelines follow the Albumentations segmentation example:
# https://albumentations.ai/docs/examples/pytorch_semantic_segmentation/
# Both pipelines normalize with the same per-channel statistics
# (the customary ImageNet mean/std values).
_normalize_kwargs = dict(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

# Training: resize, random geometric + color augmentation, normalize, to tensor.
train_transform = A.Compose([
    A.Resize(config.image_size, config.image_size),
    A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5),
    A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
    A.Normalize(**_normalize_kwargs),
    ToTensorV2(),
])

# Validation / test: deterministic resize + normalize only.
val_and_test_transform = A.Compose([
    A.Resize(config.image_size, config.image_size),
    A.Normalize(**_normalize_kwargs),
    ToTensorV2(),
])

# Four loader workers per visible CUDA device; evaluates to 0 on a CPU-only
# machine, in which case batches are loaded in the main process.
num_worker = int(torch.cuda.device_count()) * 4

# Datasets
train_dataset = KittiSegDataset(train_files_list, transform=train_transform)
val_dataset = KittiSegDataset(val_files_list, transform=val_and_test_transform)

# Settings shared by both loaders.
_loader_kwargs = dict(
    batch_size=hyper_parameters.batch_size,
    num_workers=num_worker,
    pin_memory=True,
)

# Only the training loader shuffles and drops the last incomplete batch.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, drop_last=True, **_loader_kwargs
)
val_dataloader = torch.utils.data.DataLoader(val_dataset, **_loader_kwargs)
    

#### 2. Model

In [3]:
# Create the model, optimizer, loss and LR scheduler used by the training loop.
model = DeepLabV3Plus(in_channels=3, output_stride=8, num_classes=hyper_parameters.n_classes).to(device)

optimizer = optim.Adam(model.parameters(), lr=hyper_parameters.max_lr)

# Multiclass Dice loss restricted to class ids 0 and 1, in its log form.
criterion = smp.losses.DiceLoss('multiclass', classes=[0,1], log_loss = True, smooth=1.0)

# OneCycleLR derives its total length from epochs * steps_per_epoch, so
# steps_per_epoch must equal the number of scheduler.step() calls per epoch.
# training() steps the scheduler exactly once per training batch, hence
# len(train_dataloader). (The previous value of 2 * len(train_dataloader)
# stretched the cycle to twice the real run, so training stopped halfway
# through the schedule and the learning rate was never fully annealed.)
scheduler = OneCycleLR(
    optimizer,
    max_lr=hyper_parameters.max_lr,
    epochs=hyper_parameters.n_epochs,
    steps_per_epoch=len(train_dataloader),
    pct_start=0.3,
    div_factor=10,
    anneal_strategy='cos',
)



#### 3. Train and validate model

In [4]:
# Lowest validation loss seen so far. Starting at +inf means any finite loss
# from the first epoch is accepted as the new best.
# Uses np.inf: the capitalized alias np.Inf was removed in NumPy 2.0.
min_val_loss = np.inf

def training(model, train_dataloader, criterion, scheduler):
    """Run one optimization pass over ``train_dataloader``.

    The optimizer and the target device are not parameters: this function
    reads the notebook-level globals ``optimizer`` and ``device``.

    Args:
        model: Network to optimize; it is switched to train mode here.
        train_dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable invoked as ``criterion(predictions, labels)``.
        scheduler: LR scheduler stepped once after every batch, or ``None``
            to leave the learning rate untouched.

    Returns:
        The sum of the per-batch loss values divided by
        ``len(train_dataloader)``.
    """
    torch.cuda.empty_cache()
    model.train()

    n_batches = len(train_dataloader)
    running_loss = 0.0

    for batch_inputs, batch_labels in tqdm(train_dataloader, total=n_batches):
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        # Forward pass and loss bookkeeping.
        batch_loss = criterion(model(batch_inputs), batch_labels)
        running_loss += batch_loss.item()

        # Backward pass; gradients are cleared after the update so the next
        # batch starts from zeroed gradients.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Advance the learning-rate schedule once per batch, if one is given.
        if scheduler is not None:
            scheduler.step()

    # Average loss per batch.
    return running_loss / n_batches

def evaluating(model, dataloader, criterion, metric_class, num_classes, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable invoked as ``criterion(predictions, labels)``.
        metric_class: Class instantiated as ``metric_class(num_classes)``; the
            instance must provide ``update(predictions, labels)`` and
            ``compute()``.
        num_classes: Forwarded to ``metric_class``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(evaluation_loss, evaluation_metric)``: the summed batch losses
        divided by ``len(dataloader)``, and the value of ``compute()`` on the
        metric object after all batches were fed to it.
    """
    torch.cuda.empty_cache()
    model.eval()

    n_batches = len(dataloader)
    loss_sum = 0.0
    metric_object = metric_class(num_classes)

    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(dataloader, total=n_batches):
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            predictions = model(batch_inputs)

            # Accumulate the scalar loss of this batch.
            loss_sum += criterion(predictions, batch_labels).item()

            # The metric is updated with CPU copies of predictions and labels.
            metric_object.update(predictions.cpu().detach(), batch_labels.cpu().detach())

    return loss_sum / n_batches, metric_object.compute()

#### 4. Testing

In [5]:
def testing(model, test_dataset):
    """Plot input, ground truth and prediction for one random dataset sample.

    Relies on the notebook-level globals ``device`` and ``id_to_color``.

    Args:
        model: Segmentation network; it is switched to eval mode here.
        test_dataset: Indexable dataset whose items are ``(image, mask)``
            tensor pairs (the image is assumed to be channel-first -- confirm
            against the dataset class).

    Returns:
        ``(image_base, label_class_predicted)``: a NumPy copy of the input
        image after moving its first axis to the end, and the predicted
        class-id map as a NumPy array.

    NOTE(review): the image is shown exactly as the dataset returns it. When
    the dataset's transform includes ``A.Normalize`` the values fall outside
    [0, 1] and matplotlib clips them, so the "RGB Image" panel looks
    distorted.
    """
    model.eval()
    with torch.no_grad():
        # Draw a single random index into the dataset.
        sample_id = np.random.choice(len(test_dataset), 1).tolist()[0]

        _, (ax_rgb, ax_gt, ax_pred) = plt.subplots(1, 3, figsize=(20, 10))

        image_tensor, gt_mask = test_dataset[sample_id]
        image_tensor = image_tensor.to(device)

        # Move the first axis last so imshow can display the image.
        hwc_image = image_tensor.permute(1, 2, 0).cpu().detach().numpy()
        image_base = hwc_image.copy()

        ax_rgb.imshow(hwc_image)
        ax_rgb.set_title("RGB Image")
        ax_rgb.axis('off')

        ax_gt.imshow(gt_mask.cpu().detach().numpy())
        ax_gt.set_title("Groundtruth")
        ax_gt.axis('off')

        # Add a batch axis for the forward pass, take the arg-max over dim 1
        # and drop the batch axis again.
        logits = model(image_tensor.unsqueeze(0))
        class_map = torch.argmax(logits, dim=1).squeeze(0).cpu().detach().numpy()

        # Use np.where to replace every predicted id greater than 1 with 0.
        label_class_predicted = np.where(class_map > 1, 0, class_map)

        ax_pred.imshow(id_to_color[label_class_predicted])
        ax_pred.set_title("Predicted")
        ax_pred.axis('off')

        plt.show()

    return image_base, label_class_predicted

In [6]:
# Training
# Runs n_epochs of train + validation. Whenever the validation loss is not
# worse than the best one seen so far (min_val_loss, defined in an earlier
# cell) the model weights are written to hyper_parameters.weights_path.
results = []  # one record per completed epoch
try:
    for epoch in range(hyper_parameters.n_epochs):
        print(f"Starting {epoch + 1} epoch ...")
        train_loss = training(model, train_dataloader, criterion, scheduler)
        validation_loss, validation_metric = evaluating(model, val_dataloader, criterion, meanIoU, hyper_parameters.n_classes, device)

        print(f'Epoch: {epoch+1}, trainLoss:{train_loss:6.5f}, validationLoss:{validation_loss:6.5f}, validationMetric:{validation_metric:6.5f}')

        # Keep the per-epoch history so it can be inspected or plotted later.
        results.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "validation_loss": validation_loss,
            "validation_metric": validation_metric,
        })

        if validation_loss <= min_val_loss:
            print("New best model")
            min_val_loss = validation_loss
            best_validation_metric = validation_metric
            torch.save(model.state_dict(), hyper_parameters.weights_path)

        torch.cuda.empty_cache()
except Exception:
    # Still best-effort (the cleanup below runs either way), but log the full
    # traceback instead of only the exception message so failures can be
    # diagnosed.
    logger.exception("Training aborted by an exception")

# Release the training model and any memory cached for it.
del model
torch.cuda.empty_cache()
gc.collect()


Starting 1 epoch ...


  5%|▍         | 33/686 [00:24<04:08,  2.63it/s] 

In [None]:

# Load model
# Rebuild the architecture and restore the weights saved by the training cell.
model_test = DeepLabV3Plus(in_channels=3, output_stride=8, num_classes=hyper_parameters.n_classes).to(device)

# map_location remaps the stored tensors onto the currently selected device,
# so a checkpoint written from a CUDA run also loads on a CPU-only machine.
model_test.load_state_dict(torch.load(hyper_parameters.weights_path, map_location=device))

# Visualize one random validation sample with the restored model.
image_base, label_class_predicted = testing(model_test, val_dataset)

# Release the model and any memory cached for it.
del model_test
torch.cuda.empty_cache()
gc.collect()