# WingsNet

In [None]:
#General OS and numerical
# %matplotlib qt
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import os
import sys
import argparse
import pandas as pd
from tqdm import tqdm_notebook as tqdm
from tqdm import tnrange as trange
import itertools
from PIL import Image
from collections import OrderedDict
import re

#Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

#Data management
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.kps import Keypoint, KeypointsOnImage

#Image processing
import cv2 as cv
from torchvision import transforms

import random

# Select the non-interactive Agg backend.
# NOTE(review): this runs after `matplotlib.pyplot` has already been imported,
# and later cells execute `%matplotlib inline`, which switches the backend
# again — confirm which backend is actually intended.
matplotlib.use("Agg")

## Args/Params

In [None]:
# --- Hardware: report what torch can see before a device is chosen ---------
print("CUDA availability = {}, number devices = {}".format(torch.cuda.is_available(), torch.cuda.device_count()))
for x in range(torch.cuda.device_count()):
    print(x, torch.cuda.get_device_name(x))

# --- Flags ------------------------------------------------------------------
TRAIN = True

# --- Data -------------------------------------------------------------------
# Each *_PATH is a text file listing TPS annotation files (one per line,
# relative to PATH_PREFIX); it is consumed by get_paths_from_tps_file.
DATA_PATH = "/storage/data_storage/wings/wings/all_wings_ordered"  # good and consistent
VALID_PATH = "/storage/data_storage/wings/wings/all_wings_valid"
# Alternative sources, with the original annotation-quality notes:
# DATA_PATH = "/storage/data_storage/wings/wings/all_wings"
# DATA_PATH = "/storage/data_storage/wings/wings/clem_wings.txt"       # good and consistent
# DATA_PATH = "/storage/data_storage/wings/wings/fiona_wings.txt"      # de-flipped and consistent
# DATA_PATH = "/storage/data_storage/wings/wings/fiona_wings_2.txt"    # de-flipped and consistent
# DATA_PATH = "/storage/data_storage/wings/wings/ness_wings_2.txt"     # de-flipped and consistent
# DATA_PATH = "/storage/data_storage/wings/wings/ness_wings_3.txt"     # de-flipped and consistent
# DATA_PATH = "/storage/data_storage/wings/wings/sandra_wings.txt"     # good and consistent
# DATA_PATH = "/storage/data_storage/wings/wings/shaun_wings.txt"      # de-flipped and consistent
# DATA_PATH = "/storage/data_storage/wings/wings/tamblyn_wings.txt"    # de-flipped and consistent
# DATA_PATH = "/storage/data_storage/wings/wings/tamblyn_wings_2.txt"  # de-flipped and consistent
# DATA_PATH = "/storage/data_storage/wings/wings/teresa_wings.txt"     # de-flipped and consistent
# DATA_PATH = "/storage/data_storage/wings/wings/zoe_wings.txt"        # good and consistent
# Rejected sources:
# DATA_PATH = "/storage/data_storage/wings/wings/ness_wings_4.txt"     # bad - point orders are not consistent
# VALID_PATH = "/storage/data_storage/wings/wings/ilaria_wings.txt"    # bad - too many overlapping wings
# DATA_PATH = "/storage/data_storage/wings/wings/ness_wings.txt"       # bad - broken/overlapping wings
PATH_PREFIX = "/storage/data_storage/wings/"
TRAIN_RATIO = 1.0
RESIZE = (256, 256)

# --- Training ---------------------------------------------------------------
gpu_name = "cuda:0"
# DEVICE = torch.device("cpu")
if torch.cuda.is_available():
    DEVICE = torch.device(gpu_name)
else:
    DEVICE = torch.device("cpu")
BATCH_SIZE = 64
N_ITERS = 4

# NOTE(review): z_scale, c_scale and IMG_SIZE are not referenced in the
# visible cells; their meaning cannot be confirmed from here.
z_scale = 843
c_scale = 828

IMG_SIZE = (2048, 1536)
# Divisor for flattened (x0, y0, x1, y1, ...) keypoints of 8 points:
# x entries are divided by RESIZE[0], y entries by RESIZE[1].
KPT_DIV = np.array(RESIZE * 8)

## Dataloader

In [None]:
class WingData(Dataset):
    """Dataset of wing images paired with 8 annotated (x, y) landmarks.

    ``__getitem__`` returns ``(image, edges, keypoints, path)``:
      * ``image``: float32 tensor (3, H, W), RGB channel order, min-max
        normalised to [0, 1];
      * ``edges``: float32 tensor (1, H, W), Canny edge map of ``image``
        normalised to [0, 1];
      * ``keypoints``: flat tensor ``(x0, y0, x1, y1, ...)`` with x divided by
        ``resize_dims[0]`` and y by ``resize_dims[1]``;
      * ``path``: the image file path.

    Args:
        list_paths: image file paths, one per sample.
        labels: per-sample landmark arrays, flattened to x/y pairs. Label y
            values are measured from the bottom edge (see ``np_to_keypoints``).
        resize_dims: output size for the resize augmenters and the keypoint
            normalisation. NOTE(review): imgaug documents a 2-tuple passed to
            ``iaa.Resize`` as a sampling range, and x/y are checked against
            ``resize_dims[0]``/``[1]`` while ``transforms.Resize`` reads (h, w);
            only square sizes (such as RESIZE) are unambiguous — confirm before
            using a non-square size.
        augment: if True apply the random geometric + colour + noise pipeline,
            otherwise only resize.
        device: stored as ``self.device``; not used inside this class.
    """

    def __init__(self, list_paths, labels, resize_dims=(512, 512), augment=False, device='cpu'):
        """Store the sample list and build the augmentation pipelines."""
        super().__init__()

        self.list_paths = list_paths
        self.labels = labels
        self.device = device
        self.resize_dims = resize_dims
        self.augment = augment

        # torchvision pipeline; not used by __getitem__.
        self.data_transform = transforms.Compose([
            transforms.Resize(resize_dims),
            transforms.ToTensor()])

        # No augmentation: resize only.
        self.seq_basic = iaa.Sequential([iaa.Resize(resize_dims)])

        # Mild pipeline (scale + flips); fallback when seq2 pushes a landmark out of frame.
        self.seq1 = iaa.Sequential([
            iaa.Affine(scale=(0.7, 1.0), mode='edge'),  # 'reflect'
            iaa.Fliplr(0.5),
            iaa.Flipud(0.5),
            iaa.Resize(resize_dims)])

        # Main pipeline: rotation/scale, random crop, flips, then resize.
        self.seq2 = iaa.Sequential([
            iaa.Affine(rotate=(-60, 60), scale=(0.7, 1.1), mode='edge'),  # 'reflect'
            iaa.Crop(px=(0, 25)),  # crop 0 to 25 px from each side (randomly chosen)
            iaa.Fliplr(0.5),  # horizontally flip 50% of the images
            iaa.Flipud(0.5),
            # iaa.AddToHueAndSaturation((-20, 20), per_channel=True),
            # iaa.Grayscale(),
            iaa.Resize(resize_dims)])

    def add_noise(self, image, mean, var):
        """Return ``image`` plus i.i.d. Gaussian noise of the given mean and variance.

        ``image`` must be an (H, W, C) array. The result is not clipped; for a
        uint8 input it is float64.
        """
        row, col, ch = image.shape
        sigma = var ** 0.5
        gauss = np.random.normal(mean, sigma, (row, col, ch))
        return image + gauss

    def np_to_keypoints(self, np_kpoints, image_size):
        """Convert a flat ``(x0, y0, x1, y1, ...)`` array to imgaug ``Keypoint``s.

        y is flipped as ``image_size[0] - y`` (image_size is an OpenCV
        (H, W, C) shape, so this is the height): labels count y from the
        bottom edge, whereas image rows count from the top.
        """
        height = image_size[0]
        return [Keypoint(x=np_kpoints[2 * k], y=height - np_kpoints[2 * k + 1])
                for k in range(len(np_kpoints) // 2)]

    def point_out_of_range(self, kpts, image_size):
        """Return True if any keypoint lies outside the resized output image.

        ``kpts`` must provide ``to_xy_array()`` returning an (N, 2) array.
        ``image_size`` is unused; it is kept for call compatibility.
        """
        xy = kpts.to_xy_array()
        x_ok = np.all((xy[:, 0] >= 0) & (xy[:, 0] < self.resize_dims[0]))
        y_ok = np.all((xy[:, 1] >= 0) & (xy[:, 1] < self.resize_dims[1]))
        return not (x_ok and y_ok)

    def apply_canny(self, img):
        """Return an 8-bit Canny edge map of a 3-channel image (thresholds 10/50)."""
        # NOTE(review): callers pass RGB images, but the conversion uses the BGR
        # weights; left unchanged so edge maps match those seen in training.
        img_gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
        img_gray = cv.normalize(img_gray, None, alpha=0, beta=255,
                                norm_type=cv.NORM_MINMAX, dtype=cv.CV_8U)
        img_gray = cv.GaussianBlur(img_gray, (3, 3), 0)
        img_edges = cv.Canny(img_gray, 10, 50)
        return img_edges

    @staticmethod
    def _shift_hsv(image_hsv, add_hue, add_sat, add_val):
        """Return a uint8 copy of an OpenCV HSV image with hue rotated and S/V shifted.

        Hue wraps modulo 180 (OpenCV 8-bit hue range); saturation and value are
        clipped to [0, 254]. The arithmetic is done in int16: adding a Python
        int to a uint8 array would otherwise wrap on overflow before the clip
        (and NumPy 2 rejects negative Python ints for uint8 operands).
        """
        shifted = image_hsv.astype(np.int16)
        shifted[..., 0] = np.mod(shifted[..., 0] + int(add_hue), 180)
        shifted[..., 1] = np.clip(shifted[..., 1] + int(add_sat), 0, 254)
        shifted[..., 2] = np.clip(shifted[..., 2] + int(add_val), 0, 254)
        return shifted.astype(np.uint8)

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_paths)

    def __getitem__(self, index):
        """Load, (optionally) augment and normalise sample ``index``.

        Raises:
            FileNotFoundError: if the image path is not an existing file.
            ValueError: if OpenCV cannot decode the image.
        """
        sample_path = self.list_paths[index]
        if not os.path.isfile(sample_path):
            raise FileNotFoundError("{} is not a file/does not exist!".format(sample_path))
        image = cv.imread(sample_path)
        if image is None:
            raise ValueError("{} is not a valid image".format(sample_path))
        image_size = image.shape

        kps = self.np_to_keypoints(self.labels[index].flatten(), image_size)
        kpsoi = KeypointsOnImage(kps, shape=image.shape)

        if self.augment:
            image_aug, kpsoi_aug = self.seq2(image=image, keypoints=kpsoi)
            if self.point_out_of_range(kpsoi_aug, image_size):
                # Rotation/crop pushed a landmark out of frame: use the milder pipeline.
                image_aug, kpsoi_aug = self.seq1(image=image, keypoints=kpsoi)
            # Colour jitter in HSV (OpenCV 8-bit: H 0-179, S 0-255, V 0-255).
            image_aug = cv.cvtColor(image_aug, cv.COLOR_BGR2HSV)
            add_hue = np.random.normal(0, 5)
            add_sat = np.random.normal(0, 10)
            add_val = np.random.normal(0, 0)
            image_aug = self._shift_hsv(image_aug, add_hue, add_sat, add_val)
            image_aug = cv.cvtColor(image_aug, cv.COLOR_HSV2RGB)

            variance = random.uniform(0, 80)
            image_aug = self.add_noise(image_aug, 0, variance)
        else:
            image_aug, kpsoi_aug = self.seq_basic(image=image, keypoints=kpsoi)
            image_aug = cv.cvtColor(image_aug, cv.COLOR_BGR2RGB)

        image_aug = cv.normalize(image_aug, None, alpha=0., beta=1.,
                                 norm_type=cv.NORM_MINMAX, dtype=cv.CV_32FC3)
        img_edges = self.apply_canny(image_aug)
        img_edges = cv.normalize(img_edges, None, alpha=0., beta=1.,
                                 norm_type=cv.NORM_MINMAX, dtype=cv.CV_32FC1)

        input_tensor = torch.tensor(image_aug).permute(2, 0, 1)  # HWC -> CHW
        edge_tensor = torch.tensor(img_edges)[None, :]            # add channel axis
        # Normalise keypoints by this dataset's own output size (x by
        # resize_dims[0], y by resize_dims[1]) rather than a module-level constant.
        keypoints_xy = kpsoi_aug.to_xy_array()
        output_tensor = torch.tensor((keypoints_xy / np.asarray(self.resize_dims)).flatten())

        return input_tensor, edge_tensor, output_tensor, sample_path

In [None]:
# Landmark reordering per annotation source, keyed by the folder name that
# get_paths_from_tps_file extracts from each TPS file path. For every specimen,
# output landmark i is taken from raw landmark point_orders[key][i], bringing
# all sources into one common landmark order.
point_orders = {
    "clem_wings": [0,1,2,3,4,5,6,7],
    "fiona_wings": [0,1,2,3,4,5,6,7],
    "fiona_wings_2": [7,6,5,4,3,2,1,0],
    "ilaria_wings": [0,1,2,3,4,5,6,7],
    "ness_wings": [7,6,5,4,0,1,2,3],
    "ness_wings_2": [4,5,6,7,3,2,1,0],
    "ness_wings_3": [7,6,5,4,0,1,2,3],
    "ness_wings_4": [7,6,5,4,3,2,1,0],
    "sandra_wings": [0,1,2,3,4,5,6,7],
    "shaun_wings": [0,1,2,3,4,5,6,7],
    "tamblyn_wings": [6,7,4,5,3,2,1,0],
    "tamblyn_wings_2": [0,1,2,3,7,6,5,4],
    "tamblyn_wings_3": [0,1,2,3,7,6,5,4],
    "teresa_wings": [7,6,5,4,3,2,1,0],
    "teresa_wings_2": [0,1,2,3,4,5,6,7],
    "zoe_wings": [0,1,2,3,4,5,6,7],
    "validation": [7,6,5,4,3,2,1,0],
    }

def get_paths_from_tps_file(path_to_file):
    """Collect image paths and reordered landmarks from a list of TPS files.

    ``path_to_file`` is a text file with one TPS file path per line, each
    relative to ``PATH_PREFIX`` (blank lines are ignored). Inside a TPS file,
    coordinate lines ("x y") precede the ``IMAGE=`` line they belong to;
    ``SCALE``, ``LM`` and ``ID`` lines and blank lines are ignored. A specimen
    is kept only if its image file exists and exactly 8 coordinate lines were
    read. Its landmarks are then reordered with ``point_orders[key]``, where
    ``key`` is component 5 of ``PATH_PREFIX + tps_path`` split on '/', so that
    output slot i holds raw landmark ``point_order[i]``.

    Prints one warning per TPS file (for its first rejected specimen) and a
    final success/fail count.

    Returns:
        (image_paths, feature_coords): equal-length lists; each element of
        feature_coords is a float32 array with one row per landmark
        (presumably (8, 2) when each coordinate line holds two numbers).

    Raises:
        KeyError: if the folder name is missing from ``point_orders``.
    """
    with open(path_to_file, "r") as list_file:
        data_files = [entry.strip() for entry in list_file if entry.strip()]

    image_paths = []
    feature_coords = []
    success_cnt = 0
    fail_cnt = 0

    for file in data_files:
        file_path = PATH_PREFIX + file
        folder_names = re.split('/|\n', file_path)
        point_order = point_orders[folder_names[5]]
        folder_path = os.path.dirname(file)
        img_feature_coords = []
        warning_given = False

        with open(file_path, 'r') as tps_file:
            for line in tps_file:
                str_in = re.split('=|\n', line)
                key = str_in[0]
                if key in ("SCALE", "LM", "ID"):
                    continue
                if key == "IMAGE":
                    image_name = str_in[1]
                    # Names like "./img.jpg" are relative to the TPS file's folder.
                    if image_name.startswith("."):
                        image_name = image_name[1:]
                    image_path = (PATH_PREFIX + folder_path + "/" + image_name).strip()
                    if os.path.isfile(image_path) and len(img_feature_coords) == 8:
                        image_paths.append(image_path)
                        permuted_features = [img_feature_coords[point_order[i]]
                                             for i in range(len(img_feature_coords))]
                        feature_coords.append(
                            np.asarray(permuted_features, dtype=np.float32, order='C'))
                        success_cnt += 1
                    else:
                        if not warning_given:
                            print("Issue with {} (has {} coordinates)".format(
                                image_path, len(img_feature_coords)))
                            warning_given = True
                        fail_cnt += 1
                    # Coordinates belong to one IMAGE entry only.
                    img_feature_coords = []
                else:
                    coords_str = line.split()
                    if coords_str:  # a blank line is not a landmark
                        img_feature_coords.append(coords_str)
    print("Success/fail = {}/{}".format(success_cnt, fail_cnt))
    return image_paths, feature_coords

In [None]:
# Training set: every specimen listed under DATA_PATH, with augmentation on.
# With TRAIN_RATIO = 1.0 the held-out split (data_test) is empty.
image_paths, feature_coords = get_paths_from_tps_file(DATA_PATH)
data = WingData(image_paths, feature_coords, resize_dims=RESIZE, augment=True, device=DEVICE)
train_size = int(len(data)*TRAIN_RATIO)
data_train, data_test = random_split(data, [train_size, len(data)-train_size])

# Validation set: a separate list (VALID_PATH), no augmentation. The split
# reuses TRAIN_RATIO and overwrites data_test from the split above.
image_paths, feature_coords = get_paths_from_tps_file(VALID_PATH)
data = WingData(image_paths, feature_coords, resize_dims=RESIZE, augment=False, device=DEVICE)
train_size = int(len(data)*TRAIN_RATIO)
data_valid, data_test = random_split(data, [train_size, len(data)-train_size])

dataloader_valid = DataLoader(data_valid, batch_size=BATCH_SIZE, shuffle=True)
dataloader = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)

print("Using {} images in training, {} in validation.".format(len(data_train), len(data_valid)))

In [None]:
# Sanity check: show the first sample of one shuffled training batch next to
# its Canny edge map. data_train was built with augment=True, so the image is
# an augmented view.
%matplotlib inline
dataloader_test = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)

# dataloader_test = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=True)
cnt=0
for batch_test in dataloader_test:
    cnt+=1
    # Skip knob: cnt starts at 1, so `cnt < 0` never skips anything.
    if cnt < 0:
        print(cnt)
        continue
    print(cnt)
    # Batch layout from WingData.__getitem__: (image, edges, keypoints, path).
    images = batch_test[0]
    edge_images = batch_test[1]
    labels = batch_test[2]
    filename = batch_test[3]
    img = images[0].permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
    edge_img = edge_images[0][0].numpy()
    f, axarr = plt.subplots(1,2)
    
    axarr[0].imshow(img)
    axarr[1].imshow(edge_img)
    break

# Disabled keypoint-overlay code; note it indexes batch_test as
# (images, labels, filename), not the current 4-tuple.
#     images = batch_test[0].to(DEVICE, dtype=torch.float)
#     labels = batch_test[1].to(DEVICE, dtype=torch.float)
#     filename = batch_test[2]
    
#     #Forward pass
#     model.eval()
#     NN_out = model(images)
#     model.train()
    
#     input_valid = np.squeeze(labels.cpu().detach().numpy())*KPT_DIV
#     output_valid = np.squeeze(NN_out.cpu().detach().numpy())*KPT_DIV
#     img_in = images[0][0].permute(1, 2, 0).cpu().detach().numpy()
#     for i in range(0, 8, 1):
#         plt.figure()
#         plt.imshow(img_in, cmap='gray')
#         plt.scatter(input_valid[2*i], input_valid[2*i+1], c='r')
#     break

In [None]:
def find_best_orientation(kpts_gt, kpts_est, index=-1, loss_fn=None):
    """Check whether a flipped copy of one estimate matches the reference much better.

    Keypoints are flattened ``(x0, y0, x1, y1, ...)`` and assumed normalised to
    [0, 1], so a horizontal flip maps x -> 1 - x and a vertical flip y -> 1 - y.
    Candidates are tried in the order identity, horizontal, horizontal+vertical,
    vertical. A candidate replaces the current best only if its loss is below
    ``DIFF_THRESH`` (0.5) times the current best loss.

    Only the sample at ``index`` is evaluated. The default (-1, the last sample)
    is the sample whose result this function has always reported; the other
    batch entries never affected the return value.

    Args:
        kpts_gt: (B, 2K) reference keypoints.
        kpts_est: (B, 2K) keypoints to test under flips (not modified).
        index: which sample of the batch to evaluate.
        loss_fn: callable(reference, estimate) -> scalar tensor; defaults to the
            notebook-global ``criterion``.

    Returns:
        (original_loss, min_loss, min_op): loss of the unflipped estimate, loss of
        the chosen orientation, and its name ("identity", "hor", "hor_ver", "ver").

    Raises:
        ValueError: if the batch is empty.
    """
    if len(kpts_gt) == 0:
        raise ValueError("find_best_orientation needs a non-empty batch")
    if loss_fn is None:
        loss_fn = criterion
    DIFF_THRESH = 0.5

    gt = kpts_gt[index]
    tmp = kpts_est[index].clone()  # flipped in place below; keep the caller's tensor intact

    # no flip
    min_loss = loss_fn(gt, tmp)
    original_loss = min_loss
    min_op = "identity"
    # horizontal flip
    tmp[0::2] = 1.0 - tmp[0::2]
    loss = loss_fn(gt, tmp)
    if loss < min_loss * DIFF_THRESH:
        min_loss, min_op = loss, "hor"
    # horizontal and vertical flip
    tmp[1::2] = 1.0 - tmp[1::2]
    loss = loss_fn(gt, tmp)
    if loss < min_loss * DIFF_THRESH:
        min_loss, min_op = loss, "hor_ver"
    # vertical flip only (undo the horizontal flip)
    tmp[0::2] = 1.0 - tmp[0::2]
    loss = loss_fn(gt, tmp)
    if loss < min_loss * DIFF_THRESH:
        min_loss, min_op = loss, "ver"

    return original_loss, min_loss, min_op

In [None]:
from IPython import display
import pylab as pl
import time
from torch.utils.data import Subset

# Interactive orientation review: compare the model's prediction with each
# image's labels; if a flipped prediction fits much better, offer to flip
# (h / v / hv), delete (d) or flag (i) the image file ON DISK.
# NOTE: find_best_orientation uses the global `criterion` (defined in a later cell).
fix_automatically = False
# Review un-augmented samples, one per batch. data_train was built with
# augment=True (random flips/rotations), which would make suggested flips
# meaningless - and with fix_automatically would rewrite correct images.
# The code below (filename[0], images[0], np.squeeze(...)[::2]) also assumes
# a batch of exactly one sample.
inspect_data = WingData(data_train.dataset.list_paths, data_train.dataset.labels,
                        resize_dims=RESIZE, augment=False, device=DEVICE)
dataloader_test = DataLoader(Subset(inspect_data, data_train.indices), batch_size=1, shuffle=False)
skip_batches = 0  # set > 0 to resume part-way through the dataset
cnt=0
incorrect_label_order = []
show_dp = 5
colors = [0,1,2,3,4,5,6,7]
leg_pts_x = np.arange(0, RESIZE[0], RESIZE[0]/8)
leg_pts_y = np.full((8), RESIZE[0])
corrupted = []


def _flip_image_file(path, flip_codes):
    """Overwrite the image at ``path`` after applying ``cv.flip`` with each code in turn."""
    img = cv.imread(path)
    if img is None:
        print("Could not read {}; left unchanged".format(path))
        return
    for code in flip_codes:
        img = cv.flip(img, code)
    cv.imwrite(path, img)


for batch_test in dataloader_test:
    cnt += 1
    if cnt <= skip_batches:
        if cnt % 100 == 0:
            print(cnt)
        continue

    images = batch_test[0].to(DEVICE, dtype=torch.float)
    edge_images = batch_test[1].to(DEVICE, dtype=torch.float)
    labels = batch_test[2].to(DEVICE, dtype=torch.float)
    filename = batch_test[3]

    # Forward pass: the model's conv1 takes RGB + edge channels, as in training.
    model.eval()
    with torch.no_grad():
        NN_out = model(torch.cat((images, edge_images), 1))
    model.train()
    loss, min_loss, min_op = find_best_orientation(NN_out, labels)
    print("{}: {:0.8f}, {:0.8f}, {}".format(cnt, loss, min_loss, min_op))
    if fix_automatically:
        if cnt % 250 == 0:
            print(cnt)
        if min_op == "hor":
            print("{}: Flipping {} horizontally".format(cnt, filename[0]))
            _flip_image_file(filename[0], (1,))
        if min_op == "ver":
            print("{}: Flipping {} vertically".format(cnt, filename[0]))
            _flip_image_file(filename[0], (0,))
        if min_op == "hor_ver":
            print("{}: Flipping {} horizontally then vertically".format(cnt, filename[0]))
            _flip_image_file(filename[0], (1, 0))
        if min_op != "identity":
            print("loss={:0.8f}, flip_loss={:0.8f}".format(loss, min_loss))
    else:
        print("{}: {} = {}".format(cnt, filename, loss.item()))
        if loss.item() > 0.01 or min_op != "identity":
            # Back to resized-image pixel coordinates for plotting.
            input_valid = np.squeeze(labels.cpu().detach().numpy())*KPT_DIV
            output_valid = np.squeeze(NN_out.cpu().detach().numpy())*KPT_DIV
            img_in = images[0][0].cpu().detach().numpy()
            plt.figure()
            plt.imshow(img_in, cmap='gray')
            plt.scatter(output_valid[::2], output_valid[1::2], c='r', marker='x')
            plt.scatter(input_valid[::2], input_valid[1::2], c=colors, cmap='rainbow')
            # Colour legend: landmark index 0..7 along the bottom edge.
            plt.scatter(leg_pts_x, leg_pts_y, c=colors, cmap='rainbow')

            display.clear_output(wait=True)
            display.display(plt.gcf())
            input_txt = input("Continue...{}: loss={:0.8f}, flip_loss={:0.8f} path={} - suggested op={}".format(
                cnt, loss.item(), min_loss, filename[0], min_op))
            if input_txt == "v":
                _flip_image_file(filename[0], (0,))
            if input_txt == "h":
                _flip_image_file(filename[0], (1,))
            if input_txt in ("hv", "vh"):
                _flip_image_file(filename[0], (1, 0))
            if input_txt == "d":
                os.remove(filename[0])
            if input_txt == "i":
                incorrect_label_order.append(filename[0])

            input("Action={}: loss={:0.8f} path={}".format(input_txt, loss.item(), filename[0]))
            plt.close()

print(incorrect_label_order)

In [None]:
# Model: torchvision ResNet-34 with pretrained weights, adapted for keypoint regression.
# - fc -> Linear(512, 16): 16 outputs = 8 (x, y) keypoints.
# - conv1 -> new 4-input-channel conv (RGB + Canny edge map, matching the
#   torch.cat((images, edge_images), 1) input used in training). This replaces
#   the pretrained first layer, so conv1 starts from random initialisation.
# The creation order (Linear, Conv2d, then the ResNet) fixes how the RNG is
# consumed for the fresh layers' initial weights.
# NOTE(review): newer torchvision deprecates `pretrained=True` in favour of
#   `weights=...` — confirm against the installed version.
output_linear = torch.nn.Linear(512, 16, bias=True)
conv1 = torch.nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model = models.resnet34(pretrained=True)
model.fc = output_linear
model.conv1 = conv1

# Alternatively resume from a previously saved full (pickled) model:
# model = torch.load("/storage/data/models/wings_resnet34_color_256x256", map_location=DEVICE)
# model = torch.load("/storage/data/wing_models/wings_resnet34_gs_all_1mlp_very_good", map_location=DEVICE)
model.to(DEVICE)

In [None]:
# for param in model.parameters():
#     param.requires_grad = True

In [None]:
# model.conv1.requires_grad = False
# model.bn1.requires_grad = False

# for param in model.layer1.parameters():
#     param.requires_grad = False
    
# for param in model.layer2.parameters():
#     param.requires_grad = False
    
# for param in model.layer3.parameters():
#     param.requires_grad = False
    
# for param in model.parameters():
#     print(param.requires_grad)

In [None]:
# Mean-squared error between predicted and ground-truth keypoint vectors.
criterion = nn.MSELoss()

# Adam over every model parameter at its default learning rate; the training
# cell builds its own optimiser with an explicit learning rate.
optim = torch.optim.Adam(params=model.parameters())
# Per-step loss histories (training / validation) read by the plotting cells.
loss_arr, valid_arr = [], []

In [None]:
def flip_loss(kpts_gt, kpts_est, diff_thresh=0.1, loss_fn=None):
    """Loss after undoing per-sample horizontal/vertical flips of ``kpts_est``.

    For every sample, the flips identity, horizontal, horizontal+vertical and
    vertical are tried in that order on ``kpts_est`` (keypoints flattened as
    ``(x0, y0, x1, y1, ...)`` and assumed normalised to [0, 1], so a flip maps
    v -> 1 - v). A candidate is chosen only if its loss is below ``diff_thresh``
    times the current best. The loss is then computed over the whole batch with
    each sample in its chosen orientation.

    ``kpts_est`` is NOT modified: a flipped copy is built instead. Earlier
    versions flipped it in place, which corrupted tensors reused across calls,
    such as a validation label batch. The orientation search runs without
    autograd; gradients still flow through the returned loss to both inputs.

    Args:
        kpts_gt: (B, 2K) keypoints compared against.
        kpts_est: (B, 2K) keypoints that may be flipped.
        diff_thresh: relative improvement required to accept a flip.
        loss_fn: callable(a, b) -> scalar tensor; defaults to the
            notebook-global ``criterion``.
    """
    if loss_fn is None:
        loss_fn = criterion
    flip_mask = torch.zeros(kpts_est.shape, dtype=torch.bool, device=kpts_est.device)

    with torch.no_grad():
        for batch in range(len(kpts_gt)):
            gt = kpts_gt[batch]
            tmp = kpts_est[batch].clone()

            # no flip
            min_loss = loss_fn(gt, tmp)
            min_op = "identity"
            # horizontal flip
            tmp[0::2] = 1.0 - tmp[0::2]
            loss = loss_fn(gt, tmp)
            if loss < min_loss * diff_thresh:
                min_loss, min_op = loss, "hor"
            # horizontal and vertical flip
            tmp[1::2] = 1.0 - tmp[1::2]
            loss = loss_fn(gt, tmp)
            if loss < min_loss * diff_thresh:
                min_loss, min_op = loss, "hor_ver"
            # vertical flip only (undo the horizontal flip)
            tmp[0::2] = 1.0 - tmp[0::2]
            loss = loss_fn(gt, tmp)
            if loss < min_loss * diff_thresh:
                min_loss, min_op = loss, "ver"

            if min_op in ("hor", "hor_ver"):
                flip_mask[batch, 0::2] = True
            if min_op in ("ver", "hor_ver"):
                flip_mask[batch, 1::2] = True

    # Out-of-place flip: keeps the caller's tensor intact and stays differentiable.
    adjusted = torch.where(flip_mask, 1.0 - kpts_est, kpts_est)
    return loss_fn(kpts_gt, adjusted)
    

In [None]:
N_ITERS = 30
lr = 0.00002
use_flip_loss = False
use_flip_loss_valid = False
epoch = 0
loss_thresh = 1000.  # batches with a larger (or NaN) training loss are skipped and recorded in `ignored`
if TRAIN:
    # Build the optimiser once so Adam's running moment estimates persist
    # across epochs; the per-epoch learning rate is pushed into param_groups.
    optim = torch.optim.Adam(model.parameters(), lr)
    t = tqdm(range(N_ITERS), desc="epoch: ")
    for i in t:
        for param_group in optim.param_groups:
            param_group["lr"] = lr
        rec = True
        inner = tqdm(dataloader, "batch: ", leave=False)
        ignored = []
        loss = None
        loss_valid = None

        # Get a fresh validation batch (dataloader_valid shuffles); it is reused
        # for every validation step in this epoch.
        batch_iter_valid = iter(dataloader_valid)
        batch_valid = next(batch_iter_valid)
        input_valid = batch_valid[0].to(DEVICE, dtype=torch.float)
        edge_valid = batch_valid[1].to(DEVICE, dtype=torch.float)
        label_valid = batch_valid[2].to(DEVICE, dtype=torch.float)
        valid_tensor = torch.cat((input_valid, edge_valid), 1)

        for batch in inner:
            optim.zero_grad()
            images = batch[0].to(DEVICE, dtype=torch.float)
            edge_images = batch[1].to(DEVICE, dtype=torch.float)
            labels = batch[2].to(DEVICE, dtype=torch.float)
            # Model input: RGB channels + 1 edge channel.
            input_tensor = torch.cat((images, edge_images), 1)

            # Forward pass
            NN_out = model(input_tensor)
            if use_flip_loss:
                loss = flip_loss(NN_out, labels)
            else:
                loss = criterion(NN_out, labels)
            loss_value = loss.item()

            # Written as `not <=` so NaN losses are skipped as well.
            if not (loss_value <= loss_thresh):
                ignored.append(loss_value)
                continue

            # Training step
            loss.backward()
            optim.step()
            loss_arr.append(loss_value)

            # Validation loss: inference only, no autograd graph needed.
            model.eval()
            with torch.no_grad():
                output_valid = model(valid_tensor)
                if use_flip_loss_valid:
                    loss_valid = flip_loss(output_valid, label_valid)
                else:
                    loss_valid = criterion(output_valid, label_valid)
            model.train()
            valid_value = loss_valid.item()
            valid_arr.append(valid_value)

            inner.set_description("loss: {:.6f}, v_loss: {:.6f}".format(loss_value, valid_value))
            # Set the first batch loss as the loss in the tqdm description
            if rec:
                t.set_description("loss: {:.8f}, v_loss: {:.8f}".format(loss_value, valid_value))
                rec = False

        print("epoch {}:lr={}, loss={}, v_loss={}".format(
            epoch, lr,
            loss.item() if loss is not None else None,
            loss_valid.item() if loss_valid is not None else None))
        torch.save(model, "/storage/data/models/wings_resnet34_color_canny")
        # NOTE(review): halves the rate after epochs 0, 4, 8, ... (i.e. already
        # after the first epoch) — confirm this schedule is intended.
        if epoch % 4 == 0:
            lr = lr * 0.5
        epoch += 1

In [None]:
print(lr)
lim = 0.0005
if TRAIN:
    # Training loss per optimisation step, y-axis clamped to [0, lim].
    plt.gca().set_ylim(0, lim)
    plt.scatter(np.arange(len(loss_arr)), loss_arr)

In [None]:
if TRAIN:
    # Validation loss per step; reuses the `lim` set in the preceding cell.
    plt.gca().set_ylim(0, lim)
    plt.scatter(np.arange(len(valid_arr)), valid_arr)

In [None]:
print(lr)
lim = 0.00008
if TRAIN:
    # Training loss per step with a tighter y-range for late-training detail.
    plt.gca().set_ylim(0, lim)
    plt.scatter(np.arange(len(loss_arr)), loss_arr)

In [None]:
lim = 0.0003
if TRAIN:
    # Validation loss per step with a tighter y-range.
    plt.gca().set_ylim(0, lim)
    plt.scatter(np.arange(len(valid_arr)), valid_arr)

In [None]:
# Start a fresh training-loss history (valid_arr is left untouched).
loss_arr = list()

In [None]:
# Snapshot the whole pickled model (not only its state_dict) under a "good" name.
torch.save(obj=model, f="/storage/data/models/wings_resnet34_color_canny_good")

In [None]:
# Pull one shuffled validation batch and run the model on it for the plots below.
# dataloader_test = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)  # inspect training samples instead
dataloader_test = DataLoader(data_valid, batch_size=BATCH_SIZE, shuffle=True)
batch_iter = iter(dataloader_test)
batch_test = next(batch_iter)

input_tensor = batch_test[0].to(DEVICE, dtype=torch.float)   # RGB images
edge_tensor = batch_test[1].to(DEVICE, dtype=torch.float)    # Canny edge maps
gt_tensor = batch_test[2].numpy()                            # normalised ground-truth keypoints
validation_paths = batch_test[3]
combined_tensor = torch.cat((input_tensor, edge_tensor), 1)  # 4-channel model input

model.eval()
input_valid = batch_test[2].cpu().detach().numpy()  # ground-truth keypoints
with torch.no_grad():  # inference only: do not build an autograd graph
    output_valid = model(combined_tensor).cpu().numpy()
output_valid = np.squeeze(output_valid)
model.train()

# input_valid holds the labels and output_valid the predictions.
print("Shape GT = {}, prediction = {}".format(input_valid.shape, output_valid.shape))

In [None]:
# Undo the keypoint normalisation so both arrays are in resized-image pixels.
# NOTE: not idempotent — re-running this cell scales the arrays a second time.
output_valid, input_valid = output_valid * KPT_DIV, input_valid * KPT_DIV

In [None]:
%matplotlib inline
colors = [0,1,2,3,4,5,6,7]
colormap = 'viridis'
leg_pts_x = np.arange(10, RESIZE[0], RESIZE[0]/8)
leg_pts_y = np.full((8), RESIZE[0]-10)
strt_idx = 20
# for i in range(strt_idx, strt_idx+5, 1):
for i in range(0, BATCH_SIZE, 1):
#     print(validation_paths[i])
    img_in = input_tensor[i].cpu().detach().permute(1, 2, 0).numpy()
#     print("Shape est={}, gt={}".format(output_valid.shape, input_valid.shape))
#     print("{}, {}".format(np.max(img_in), np.min(img_in)))
    plt.figure()
    plt.imshow(img_in)
#     plt.scatter(input_valid[i][::2], input_valid[i][1::2], c=colors, cmap=colormap)
#     plt.scatter(leg_pts_x, leg_pts_y, c=colors, cmap=colormap)
    plt.scatter(output_valid[i][::2], output_valid[i][1::2], c='r', marker='x')