# WingsNet

In [None]:
#General OS and numerical
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import sys
import argparse
import pandas as pd
from tqdm import tqdm_notebook as tqdm
from tqdm import tnrange as trange
import itertools
from PIL import Image
from collections import OrderedDict
import re
import glob

#Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

#Data management
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.kps import Keypoint, KeypointsOnImage

#Image processing
import cv2 as cv
from torchvision import transforms

## Args/Params

In [None]:
# --- GPUs: log which CUDA devices this kernel can see ----------------------
print("CUDA availability = {}, number devices = {}".format(
    torch.cuda.is_available(), torch.cuda.device_count()))
for gpu_idx in range(torch.cuda.device_count()):
    print(gpu_idx, torch.cuda.get_device_name(gpu_idx))

# --- Flags -------------------------------------------------------------------
TRAIN = True

NUM_LAYERS = 16
TRAIN_RATIO = 1.0      # fraction of each image folder given to random_split's first subset
RESIZE = (256, 256)    # network input size used by WingData and KPT_DIV

# --- Training / inference settings ------------------------------------------
gpu_name = "cuda:0"
# DEVICE = torch.device("cpu")
DEVICE = torch.device(gpu_name) if torch.cuda.is_available() else torch.device("cpu")
BATCH_SIZE = 16
N_ITERS = 4

z_scale = 843
c_scale = 828

IMG_SIZE = (2048, 1536)
# Per-coordinate scale for the 16 interleaved (x, y) model outputs:
# RESIZE[0] for every even slot, RESIZE[1] for every odd slot (8 pairs).
KPT_DIV = np.tile(RESIZE, 8)

## Dataloader

In [None]:
class WingData(Dataset):
    """Dataset of wing images for inference.

    Each item is ``(input_tensor, sample_path)``. ``input_tensor`` is a
    float32, channels-first (C, H, W) RGB tensor. It is min-max scaled to
    [0, 1] per image and then resized to ``resize_dims``. ``sample_path`` is
    the file the image was read from.
    """

    def __init__(self, list_paths, resize_dims=(512, 512), device='cpu'):
        """Store the image paths and build the resize pipeline.

        Parameters
        ----------
        list_paths : sequence of str
            Image file paths, indexed by ``__getitem__``.
        resize_dims : int or (int, int)
            Output size. A pair is read as (height, width), matching
            torchvision's ``transforms.Resize``. An int gives a square output.
        device : str or torch.device
            Stored on the instance only. ``__getitem__`` does not move
            tensors to it.
        """
        super().__init__()

        self.list_paths = list_paths
        self.device = device
        self.resize_dims = resize_dims

        # torchvision version of the resize step. __getitem__ does not use it;
        # it is kept as a public attribute for existing callers.
        self.data_transform = transforms.Compose([
            transforms.Resize(resize_dims),
            transforms.ToTensor()])

        # No ImageNet mean/std normalisation is applied; inputs stay min-max scaled.

        # imgaug reads a 2-tuple of ints as a range [a..b] from which to
        # *sample* a square side length. A non-square (h, w) would therefore
        # give random square outputs. Pass explicit height/width instead.
        if isinstance(resize_dims, (int, np.integer)):
            resize_size = resize_dims
        else:
            height, width = resize_dims
            resize_size = {"height": height, "width": width}
        self.seq = iaa.Sequential([iaa.Resize(resize_size)])

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.list_paths)

    def __getitem__(self, index):
        """Load, normalise and resize the image at ``index``.

        Returns
        -------
        (torch.Tensor, str)
            The (C, H, W) float32 image tensor and its source path.

        Raises
        ------
        FileNotFoundError
            If the path does not point to an existing file.
        IOError
            If OpenCV cannot decode the file.
        """
        sample_path = self.list_paths[index]

        # Fail loudly here. Otherwise cv.imread returns None and cvtColor
        # raises an unrelated-looking OpenCV assertion error.
        if not os.path.isfile(sample_path):
            raise FileNotFoundError("{} is not a file/does not exist!".format(sample_path))

        image = cv.imread(sample_path)  # default IMREAD_COLOR: 3-channel BGR
        if image is None:
            # cv.imread signals a decode failure by returning None, not by raising.
            raise IOError("OpenCV could not read image {}".format(sample_path))
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
        # Per-image min-max stretch to [0, 1], stored as float32.
        image = cv.normalize(image, None, alpha=0, beta=1, norm_type=cv.NORM_MINMAX, dtype=cv.CV_32FC3)

        image_aug = self.seq(image=image)

        # HWC -> CHW, the layout torch conv models take.
        input_tensor = torch.tensor(image_aug).permute(2, 0, 1)

        return input_tensor, sample_path

In [None]:
# # DATA_PATH = "/storage/data/wingNet/avi_data"
# data_files = pd.read_csv(DATA_PATH, header=None, delimiter="\n").values.flatten().tolist()

# image_paths = []

# for file in data_files:
#     paths = glob.glob(file+'/*.tif')
#     for path in paths:
#         image_paths.append(path)
    
# data = WingData(image_paths, resize_dims=RESIZE, device=DEVICE)
# train_size = int(len(data)*TRAIN_RATIO)
# data_train, data_test = random_split(data, [train_size, len(data)-train_size])
# dataloader = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False)
# print("Using {} images in training, {} in validation.".format(len(data_train), len(data_test)))

In [None]:
# dataloader_test = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False)

# # dataloader_test = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=True)
# batch_iter = iter(dataloader_test)    
# batch_test = batch_iter.__next__()

# in_imgs = batch_test[0].numpy()
# # keypoints = batch_test[1].numpy()*KPT_DIV
# path = batch_test[1]

# for i in range(0, BATCH_SIZE, 1):
#     img_in = np.transpose(in_imgs[i], (1, 2, 0))
#     plt.figure()
#     plt.imshow(img_in)
# #     plt.scatter(keypoints[i][::2], keypoints[i][1::2], c='r')
# #     plt.scatter(keypoints[i][4], keypoints[i][5], c='r')

In [None]:
# Load the checkpoint onto DEVICE and switch it to inference mode (this
# affects layers such as BatchNorm/Dropout). The loaded object is called
# directly as a model later on, so it appears to be a full pickled module
# rather than a state_dict.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# files. torch>=2.6 defaults to weights_only=True, which rejects full pickled
# modules; pass weights_only=False there if this load fails.
model = torch.load("/storage/data/wing_models/wings_resnet34_color_256x256_good",
                   map_location=DEVICE).eval()  # nn.Module.eval() returns the module itself

In [None]:
CONFIG_PATH = "/storage/data/wingNet/inference"
# Tab-separated file with no header. Each row is
# "<display name>\t<image directory>", and rows are indexed below as
# row[0] / row[1].
data_paths = np.array(
    pd.read_csv(CONFIG_PATH, header=None, delimiter="\t").values.tolist()
)
for path in data_paths:
    print(f"{path[0]}: {path[1]}")

In [None]:
def get_image_paths(data_path, extensions=(".tif", ".bmp", ".png", ".jpg")):
    """Return the paths of image files directly inside ``data_path``.

    Results are grouped by extension, in the order given by ``extensions``.
    Within one extension the order is whatever ``glob.glob`` returns. The
    search is not recursive, and case sensitivity follows ``glob.glob`` on
    the host filesystem.

    Parameters
    ----------
    data_path : str
        Directory to search.
    extensions : iterable of str, optional
        File suffixes, including the leading dot, to collect.

    Returns
    -------
    list of str
        Matching file paths.
    """
    # Escape the directory so glob metacharacters in its name ("[", "*", "?")
    # are matched literally rather than treated as a pattern.
    base = glob.escape(data_path)
    image_paths = []
    for ext in extensions:
        image_paths.extend(glob.glob(os.path.join(base, "*" + ext)))
    return image_paths

In [None]:
def pass_forward(data, *, net=None, batch_size=None, device=None, kpt_scale=None):
    """Run the model on one shuffled batch drawn from ``data``.

    Parameters
    ----------
    data : torch.utils.data.Dataset
        Dataset whose items are tuples with the input tensor first, such as
        ``WingData`` items ``(image, path)``.
    net, batch_size, device, kpt_scale : optional
        Overrides for the notebook globals ``model``, ``BATCH_SIZE``,
        ``DEVICE`` and ``KPT_DIV``. ``None`` (the default) uses the global.

    Returns
    -------
    input_tensor : torch.Tensor
        The batch fed to the model, cast to float and placed on ``device``.
    output_valid : numpy.ndarray
        Model outputs reshaped to ``(batch, n_outputs)`` and multiplied
        element-wise by ``kpt_scale``.
    """
    net = model if net is None else net
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    device = DEVICE if device is None else device
    kpt_scale = KPT_DIV if kpt_scale is None else kpt_scale

    # shuffle=True: each call draws a different random batch from the dataset.
    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    batch = next(iter(dataloader))

    input_tensor = batch[0].to(device, dtype=torch.float)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output_valid = net(input_tensor).cpu().numpy()
    # Unlike np.squeeze, this keeps the batch axis when the batch has one image.
    output_valid = output_valid.reshape(output_valid.shape[0], -1)

    return input_tensor, output_valid * kpt_scale

In [None]:
# Prediction gallery: for each dataset in `data_paths`, run one batch through
# the model and draw up to `grid_width` images with the predicted keypoints
# overlaid. Datasets are laid out as rows (show_horizontally=True) or as columns.
show_horizontally = True
grid_width = 15                 # images shown per dataset
grid_height = len(data_paths)   # one row/column per dataset
image_width = 5                 # figure inches per panel
image_whitespace = 0.03
# squeeze=False keeps axarr 2-D even when there is a single dataset, so the
# row/column indexing below always works.
if show_horizontally:
    f, axarr = plt.subplots(grid_height, grid_width,
                            figsize=(grid_width*image_width,
                                     grid_height*image_width + image_whitespace*(grid_height)),
                            squeeze=False)
    f.subplots_adjust(hspace=image_whitespace*2, wspace=image_whitespace)
else:
    f, axarr = plt.subplots(grid_width, grid_height,
                            figsize=(grid_height*image_width + image_whitespace*(grid_height),
                                     grid_width*image_width),
                            squeeze=False)
    f.subplots_adjust(hspace=image_whitespace, wspace=image_whitespace)

for img_xpos, paths in enumerate(data_paths):
    name = paths[0]
    path = paths[1]

    # The panels belonging to this dataset: one row of axarr, or one column.
    panels = axarr[img_xpos, :] if show_horizontally else axarr[:, img_xpos]
    for ax in panels:
        ax.set_xticklabels([])
        ax.set_yticklabels([])
    if show_horizontally:
        panels[0].set_ylabel(name, fontsize=40)
    else:
        panels[0].set_title(name, fontsize=40)

    data = WingData(get_image_paths(path), resize_dims=RESIZE, device=DEVICE)
    train_size = int(len(data)*TRAIN_RATIO)
    data_infer, data_test = random_split(data, [train_size, len(data)-train_size])
    if len(data_infer) == 0:
        # A DataLoader cannot draw a batch from an empty subset; leave these panels blank.
        print("No images to show for {} ({}), skipping.".format(name, path))
        continue
    input_img, output_labels = pass_forward(data_infer)
    # One row of interleaved (x, y) predictions per image, even for a one-image batch.
    output_labels = np.reshape(output_labels, (len(input_img), -1))

    show_idx = 0
    # Folders (or batches) with fewer than grid_width images fill only the first panels.
    n_show = min(grid_width, len(input_img) - show_idx)
    for slot in range(n_show):
        i = show_idx + slot
        img_in = np.transpose(input_img[i].cpu().detach().numpy(), (1, 2, 0))  # CHW -> HWC for imshow
        ax = panels[slot]
        ax.imshow(img_in)
        ax.scatter(output_labels[i][::2], output_labels[i][1::2], c='r', marker='x')
f.savefig('foo.png')
f.savefig('foo.pdf')

In [None]:
# dataloader_test = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False)

# # dataloader_test = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=True)
# batch_iter = iter(dataloader_test)    
# batch_test = batch_iter.__next__()

# input_tensor = batch_test[0].to(DEVICE, dtype=torch.float)
# output_valid = model(input_tensor).cpu().detach().numpy()
# output_valid = np.squeeze(output_valid)

# print("Shape data = {}".format(output_valid.shape))
# output_valid = output_valid*KPT_DIV

In [None]:
# show_idx = 0
# for i in range(show_idx, show_idx+5, 1):
#     img_in = input_tensor[i].cpu().detach().numpy()
# #     gt_pts = output_valid[i][0]
# #     print("Shape est={}, gt={}".format(img_in.shape, img_in.shape))

# #     f, axarr = plt.subplots(1,2)
#     plt.figure()
#     img_in = np.transpose(img_in, (1, 2, 0))
#     plt.imshow(img_in)
#     plt.scatter(output_valid[i][::2], output_valid[i][1::2], c='r', marker='x')
# #     plt.scatter(input_valid[i][::2], input_valid[i][1::2], c='b')
# #     plt.scatter(input_valid[i][8], input_valid[i][9], c='b')
# #     print(input_valid[i])
# #     axarr[0].imshow(img_in, cmap='gray')