In [13]:
import nbimporter
import torch
import torchvision
import glob
import numpy as np
import re
import os.path
from dataset import KITTIDataset
import random
from torch.utils.data import DataLoader

In [14]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Persist a training-state object to ``filename`` via ``torch.save``.

    Args:
        state: Any object ``torch.save`` can serialize (typically a dict holding a
            model ``state_dict`` and optimizer state).
        filename: Destination path. Defaults to ``"my_checkpoint.pth.tar"``.
    """
    banner = "=> Saving checkpoint"
    print(banner)
    torch.save(state, filename)

def load_checkpoint(checkpoint, model):
    """Copy the weights stored under ``checkpoint["state_dict"]`` into ``model``.

    Args:
        checkpoint: Mapping that contains a ``"state_dict"`` entry (for example the
            object written by ``save_checkpoint``).
        model: Module whose ``load_state_dict`` receives those weights.
    """
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

In [15]:
def check_accuracy(loader, model, device="mps"):
    """Print pixel accuracy and mean per-batch Dice score of a binary segmentation model.

    Each model output is passed through a sigmoid and thresholded at 0.5. Each target
    batch gets a channel axis inserted at dim 1. If its shape still differs from the
    predictions, it is reshaped with ``view_as``.

    The model is put in eval mode for the evaluation. Its previous train/eval mode is
    restored afterwards, even if an exception is raised.

    Args:
        loader: Iterable of ``(x, y)`` batches.
        model: Module returning one logit per pixel.
        device: Device the batches are moved to.

    Returns:
        Pixel accuracy as a float in ``[0, 1]``.

    Raises:
        ValueError: If ``loader`` yields no pixels to evaluate.
    """
    correct = 0
    pixels = 0
    dice_total = 0.0
    batches = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)

                predictions = (torch.sigmoid(model(x)) > 0.5).float()
                if predictions.shape != y.shape:
                    y = y.view_as(predictions)

                correct += (predictions == y).sum().item()
                pixels += y.numel()
                # .item() keeps the running total a plain float instead of a device tensor,
                # so the printed score is a number rather than a tensor repr.
                dice = (2 * (predictions * y).sum()) / ((predictions + y).sum() + 1e-8)
                dice_total += dice.item()
                batches += 1
    finally:
        # Restore whatever mode the caller had, rather than unconditionally training.
        model.train(was_training)

    if pixels == 0:
        raise ValueError("check_accuracy: loader yielded no pixels to evaluate")

    accuracy = correct / pixels
    avg_dice_score = dice_total / batches

    print(f"got {correct}/{pixels} with accuracy {accuracy*100:.3f}")
    print(f"Dice score: {avg_dice_score}")
    return accuracy

In [16]:
def save_predictions(loader, model, folder="saved_results", device="mps"):
    """Save thresholded predictions and ground-truth masks for every sample as PNGs.

    Files are written to ``folder`` as ``pred_{batch}_{sample}.png`` and
    ``gt_{batch}_{sample}.png``. ``folder`` (including parents) is created if missing.
    The model runs in eval mode under ``torch.no_grad``. Its previous train/eval mode
    is restored afterwards, even on error.

    Args:
        loader: Iterable of ``(x, y)`` batches.
        model: Module returning one logit per pixel; outputs go through a sigmoid and
            are thresholded at 0.5.
        folder: Output directory.
        device: Device the inputs are moved to. Targets stay where the loader put them.
    """
    os.makedirs(folder, exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for idx, (x, y) in enumerate(loader):
                predictions = (torch.sigmoid(model(x.to(device=device))) > 0.5).float()
                for i in range(predictions.shape[0]):
                    torchvision.utils.save_image(
                        predictions[i], os.path.join(folder, f"pred_{idx}_{i}.png")
                    )
                    # .float() is a no-op for float masks; for integer masks it avoids
                    # save_image's in-place float arithmetic failing (mask dtype unknown here).
                    torchvision.utils.save_image(
                        y[i].float(), os.path.join(folder, f"gt_{idx}_{i}.png")
                    )
    finally:
        model.train(was_training)


KITTI road-segmentation data loading:

In [17]:
def _read_resized_rgb(path, image_shape, smooth):
    """Read an image as an RGB ``uint8`` array of shape ``(H, W, 3)`` resized to ``image_shape[:2]``.

    This replaces ``scipy.misc.imresize(scipy.misc.imread(path), image_shape)``. Both
    SciPy functions have been removed from current SciPy releases.

    Args:
        path: Image file to read.
        image_shape: Target size; only the first two entries ``(height, width)`` are used.
        smooth: ``True`` for antialiased bilinear resampling (camera images). ``False``
            for nearest-neighbour, which keeps label colours exact so they can be
            compared by equality.
    """
    import torchvision  # also imported at notebook level; repeated so the helper is self-contained

    image = torchvision.io.read_image(path, mode=torchvision.io.ImageReadMode.RGB)  # (3, H, W) uint8
    size = (int(image_shape[0]), int(image_shape[1]))
    batch = image.unsqueeze(0).float()
    if smooth:
        resized = torch.nn.functional.interpolate(
            batch, size=size, mode="bilinear", align_corners=False, antialias=True
        )
    else:
        resized = torch.nn.functional.interpolate(batch, size=size, mode="nearest")
    resized = resized.round().clamp(0, 255).to(torch.uint8)
    return resized.squeeze(0).permute(1, 2, 0).numpy()


def get_batches(data_folder, image_shape, batch_size):
    """Yield shuffled batches of KITTI road training images and two-channel masks.

    Images are read from ``data_folder/image_2/*.png``. Labels are read from
    ``data_folder/gt_image_2/*_road_*.png`` and matched to images by removing the
    ``_road_``/``_lane_`` infix from the label file name. A label pixel equal to
    ``(255, 0, 0)`` is background; every other pixel is treated as foreground.

    Args:
        data_folder: Root folder containing ``image_2`` and ``gt_image_2``.
        image_shape: Output ``(height, width)``; extra trailing entries are ignored.
        batch_size: Maximum number of samples per yielded batch.

    Yields:
        ``(images, gt_images)`` where ``images`` is ``uint8`` of shape ``(B, H, W, 3)``
        and ``gt_images`` is ``bool`` of shape ``(B, H, W, 2)``, with channel 0 as
        background and channel 1 as not-background.

    Raises:
        KeyError: If an image has no matching ``*_road_*`` label file.
    """
    # Sort before shuffling so the order depends only on the `random` module's seed,
    # not on filesystem listing order.
    image_paths = sorted(glob.glob(os.path.join(data_folder, 'image_2', '*.png')))
    label_paths = {
        re.sub(r'_(lane|road)_', '_', os.path.basename(path)): path
        for path in glob.glob(os.path.join(data_folder, 'gt_image_2', '*_road_*.png'))
    }
    background_color = np.array([255, 0, 0])

    random.shuffle(image_paths)
    for batch_start in range(0, len(image_paths), batch_size):
        images = []
        gt_images = []
        for image_file in image_paths[batch_start:batch_start + batch_size]:
            gt_image_file = label_paths[os.path.basename(image_file)]

            image = _read_resized_rgb(image_file, image_shape, smooth=True)
            # Nearest-neighbour keeps label colours exact; bilinear would blend the
            # class boundaries and break the equality test below.
            gt_image = _read_resized_rgb(gt_image_file, image_shape, smooth=False)

            gt_bg = np.all(gt_image == background_color, axis=2)
            gt_image = np.stack((gt_bg, ~gt_bg), axis=2)

            images.append(image)
            gt_images.append(gt_image)

        yield np.array(images), np.array(gt_images)


In [18]:
def get_loaders(image_dir, mask_dir, batch, train_transform, val_transform, val_split=0.2, num_workers=0, pin_memory=True, seed=42):
    """Build training and validation DataLoaders over a disjoint split of the PNGs in ``image_dir``.

    The file list is sorted, then shuffled with ``random.Random(seed)``. The first
    ``int(len * val_split)`` files become the validation set; the rest are used for
    training. When ``val_split > 0`` and there are at least two images, at least one
    image is held out for validation.

    Args:
        image_dir: Folder with the input ``.png`` images.
        mask_dir: Folder with the masks, passed through to ``KITTIDataset``.
        batch: Batch size for both loaders.
        train_transform: Transform for the training dataset.
        val_transform: Transform for the validation dataset.
        val_split: Fraction of images held out for validation, in ``[0, 1)``. With
            ``0`` the validation loader is empty.
        num_workers: DataLoader worker count.
        pin_memory: Forwarded to both DataLoaders.
        seed: Seed for the split, so the same files land in validation on every run
            (``None`` gives a non-reproducible split).

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles; the validation
        loader does not.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)`` or ``image_dir`` has no ``.png`` files.
    """
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

    # Sort first so the split depends only on `seed`, not on os.listdir order.
    all_images = sorted(img for img in os.listdir(image_dir) if img.endswith(".png"))
    if not all_images:
        raise ValueError(f"no .png images found in {image_dir!r}")

    shuffled = list(all_images)
    random.Random(seed).shuffle(shuffled)

    n_val = int(len(shuffled) * val_split)
    if val_split > 0 and n_val == 0 and len(shuffled) > 1:
        n_val = 1  # honour a requested split even for tiny datasets
    # Disjoint split: validating on training images would overstate accuracy.
    val_images = shuffled[:n_val]
    train_images = shuffled[n_val:]

    train_dataset = KITTIDataset(image_dir=image_dir, mask_dir=mask_dir, image_files=train_images, transform=train_transform)
    train_loader = DataLoader(train_dataset, batch_size=batch, num_workers=num_workers, pin_memory=pin_memory, shuffle=True)

    val_dataset = KITTIDataset(image_dir=image_dir, mask_dir=mask_dir, image_files=val_images, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=batch, num_workers=num_workers, pin_memory=pin_memory, shuffle=False)

    return train_loader, val_loader