In [4]:
import argparse
import logging
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

import matplotlib.pyplot as plt
from unet.network_structure import UNet

import matplotlib.pyplot as plt


def plot_img_and_mask(img, mask):
    """Show the input image next to its mask(s) in a single figure row.

    Args:
        img: Image data accepted by ``Axes.imshow``.
        mask: Array-like with a ``shape`` attribute. When it has more than two
            dimensions it is indexed class-first (``mask[i, :, :]``) and one
            panel is drawn per leading index; otherwise it is drawn as a
            single panel.
    """
    if len(mask.shape) > 2:
        # Slice even when there is only one class: imshow cannot draw a
        # (1, H, W) array directly.
        layers = [mask[i, :, :] for i in range(mask.shape[0])]
    else:
        layers = [mask]

    _, axes = plt.subplots(1, len(layers) + 1)
    axes[0].set_title('Input image')
    axes[0].imshow(img)

    for i, layer in enumerate(layers):
        title = f'Output mask (class {i + 1})' if len(layers) > 1 else 'Output mask'
        axes[i + 1].set_title(title)
        axes[i + 1].imshow(layer)

    # plt.xticks / plt.yticks only touch the current (last) axes, so the ticks
    # are cleared on every panel explicitly.
    for panel in axes:
        panel.set_xticks([])
        panel.set_yticks([])
    plt.show()
    

def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """Run ``net`` on one image tensor and return a hard mask as a NumPy array.

    Args:
        net: Model called as ``net(batch)`` that exposes an ``n_classes``
            attribute. It is switched to eval mode here; the caller is
            responsible for placing it on ``device``.
        full_img: Image tensor without a batch axis (one is added here).
        device: Device the input batch is moved to before the forward pass.
        scale_factor: Unused. Kept only so existing callers keep working; the
            preprocessing/rescaling step that consumed it is disabled.
        out_threshold: Probability above which a pixel counts as foreground
            when ``net.n_classes == 1``.

    Returns:
        For ``n_classes == 1``: a boolean array that is True where the sigmoid
        probability exceeds ``out_threshold``.
        Otherwise: an integer one-hot array with the class axis first, built
        from the per-pixel argmax of the softmax probabilities.
    """
    net.eval()
    batch = full_img.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        logits = net(batch)
        if net.n_classes > 1:
            probs = F.softmax(logits, dim=1)[0]
        else:
            probs = torch.sigmoid(logits)[0]

    # The probabilities are used directly. Routing them through
    # ToPILImage/ToTensor (only needed for the disabled resize) rounded them to
    # 8 bits, which flipped pixels close to the threshold and limited the
    # number of classes to what a PIL image mode can hold.
    probs = probs.cpu()

    if net.n_classes == 1:
        # Drop only the channel axis so that size-1 spatial axes survive.
        return (probs.squeeze(0) > out_threshold).numpy()
    return F.one_hot(probs.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()


def get_args(argv=None):
    """Parse the command-line options for mask prediction.

    Args:
        argv: Optional list of argument strings. With the default ``None``
            argparse reads ``sys.argv[1:]``, as before; passing a list makes
            the parser usable from a notebook or a test.

    Returns:
        argparse.Namespace with ``model``, ``input``, ``output``, ``viz``,
        ``no_save``, ``mask_threshold`` and ``scale``.
    """
    parser = argparse.ArgumentParser(description='Predict masks from input images')
    parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE',
                        help='Specify the file in which the model is stored')
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True)
    # The metavar used to read INPUT, which mislabelled this option in --help.
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help='Visualize the images as they are processed')
    parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
                        help='Minimum probability value to consider a mask pixel white')
    parser.add_argument('--scale', '-s', type=float, default=0.5,
                        help='Scale factor for the input images')

    return parser.parse_args(argv)


def get_output_filenames(args):
    """Return the output filenames to use for the parsed arguments.

    Explicitly supplied ``args.output`` wins. When it is missing or empty, one
    name is derived per entry of ``args.input`` by inserting ``_OUT`` between
    the file's root and its extension.
    """
    if args.output:
        return args.output

    derived = []
    for path in args.input:
        root, ext = os.path.splitext(path)
        derived.append(f'{root}_OUT{ext}')
    return derived


def mask_to_image(mask: np.ndarray):
    """Convert a predicted mask array into an 8-bit PIL image.

    Args:
        mask: Either a 2-D mask, whose values are multiplied by 255, or a 3-D
            class-first array, which is reduced with an argmax over axis 0 and
            mapped so that class ``k`` becomes ``k * 255 / n_classes``.

    Returns:
        The image produced by ``Image.fromarray`` from the ``uint8`` pixels.

    Raises:
        ValueError: If ``mask`` is neither 2-D nor 3-D. Previously this case
            fell through and returned ``None``, which only failed later at the
            call site.
    """
    if mask.ndim == 2:
        pixels = mask * 255
    elif mask.ndim == 3:
        pixels = np.argmax(mask, axis=0) * 255 / mask.shape[0]
    else:
        raise ValueError(f'mask must be 2-D or 3-D, got {mask.ndim} dimension(s)')
    return Image.fromarray(pixels.astype(np.uint8))
    

def make_predictions(dir_test, dir_out, viz = False, save = False):
    """Predict a mask for every test tile and optionally show and/or save it.

    Depends on names that are not defined in this function and must already
    exist at module level: ``create_npy_list``, ``dir_checkpoint`` and
    ``device``.

    Args:
        dir_test: Passed to ``create_npy_list``; each returned entry's first
            element is the path of an array file readable by ``np.load``.
        dir_out: Directory that receives the ``*_prediction.png`` files. It is
            created when ``save`` is true and it does not exist yet.
        viz: When true, plot each input next to its predicted mask.
        save: When true, write each predicted mask as a PNG into ``dir_out``.
    """
    img_list = create_npy_list(dir_test)

    net = UNet(n_channels=1, n_classes=2)

    checkpoint = str(dir_checkpoint) + "/checkpoint_epoch1.pth"
    # map_location lets a checkpoint saved on another device (e.g. a GPU) load
    # here, and the net has to sit on the same device as the inputs that
    # predict_img moves there.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    net.load_state_dict(torch.load(checkpoint, map_location=device))
    net.to(device=device)

    if save:
        out_dir = Path(dir_out)
        out_dir.mkdir(parents=True, exist_ok=True)

    for entry in img_list:
        filename = entry[0]
        logging.info(f'\nPredicting image {filename} ...')
        # Stack the loaded array along its first axis and add a leading
        # channel axis for the single-channel network.
        img = torch.from_numpy(np.vstack(np.load(filename)).astype(float))[None,:]

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=0.5,
                           out_threshold=0.5,
                           device=device)

        if viz:
            logging.info(f'Visualizing results for image {filename}, close to continue...')
            plot_img_and_mask(img.squeeze(), mask)

        if save:
            # Same name as the former filename[-33:-4] slice for '.npy' files
            # (the last 29 characters before the extension), but a short file
            # name can no longer pull directory separators into the result.
            # NOTE(review): confirm the 29-character truncation is intended.
            tile_name = Path(filename).stem[-29:]
            out_filename = out_dir / f"{tile_name}_prediction.png"
            print(out_filename)
            result = mask_to_image(mask)
            result.save(out_filename)
            logging.info(f'Mask saved to {out_filename}')
            

ModuleNotFoundError: No module named 'SeaIce'

In [5]:
# Folder holding the test tiles and folder that receives the predicted masks.
dir_test = Path('/mnt/g/Shared drives/2021-gtc-sea-ice/trainingdata/testing/')
dir_out = Path('/mnt/g/Shared drives/2021-gtc-sea-ice/model/outtiles/')

# Display every prediction and also write it to disk.
make_predictions(dir_test, dir_out, viz=True, save=True)

NameError: name 'make_predictions' is not defined