In [None]:
import argparse
import gym
from gym import spaces

import numpy as np
from itertools import count, chain
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Shared mean-squared-error loss (mean reduction is MSELoss's default).
criterion = nn.MSELoss(reduction="mean")

class CustomEnv(gym.Env):
    """Custom Environment that follows gym interface.

    On ``reset`` two stored images are drawn at random and their per-organ
    eigenvalue vectors are blended at one of three fixed ratios; every organ
    contour is then rebuilt from the blend. Each ``step`` nudges one blended
    torso eigenvalue, rebuilds the torso contour, and rewards based on
    whether the torso polygon contains the other organ polygons.

    ``loadData`` (not visible here) is presumably responsible for setting the
    attributes below; layouts are taken from the original author's notes —
    TODO confirm against loadData:
        curves      -- per image: per organ, a list of ``[x, y]`` points.
        samples     -- per image: per organ (6), a vector of eigenvalues.
        max_list    -- upper bound per torso eigen-component (8).
        min_list    -- lower bound per torso eigen-component (8).
        estimators  -- per organ (6), an object exposing ``mean_`` and
                       ``components_`` (presumably a fitted sklearn PCA).
        torso_d, lung_d, sp_d -- used to size the observation space.
    """
    metadata = {'render.modes': ['human']}

    # (weight of image 1, weight of image 2) for each blend level 0..2.
    _BLEND_WEIGHTS = ((0.7, 0.3), (0.5, 0.5), (0.3, 0.7))

    def __init__(self, device):
        super().__init__()
        # TODO: Normalization needed
        self.loadData()

        self.current_step = 0
        self.MAX_STEPS = 9
        self.device = device

        # Action = (torso eigen-component 0..7, or 8 for "no change";
        #           modification code 0..7, decoded in _take_action).
        self.all_spaces = spaces.Tuple((spaces.Discrete(9), spaces.Discrete(8)))
        self.action_space = spaces.Tuple((spaces.Discrete(9), spaces.Discrete(8)))
        obs_dim = self.torso_d * 3 + self.lung_d * 3 + self.sp_d * 2 + 1
        # The trailing comma is required: without it the argument is a bare
        # Box rather than a 1-tuple, and spaces.Tuple cannot iterate it.
        # NOTE(review): reset/step return a flat array, which matches the
        # inner Box rather than this Tuple -- consider exposing the Box directly.
        self.observation_space = spaces.Tuple((
            spaces.Box(low=0, high=1.0, shape=(obs_dim,), dtype=np.float16),
        ))

    def _reconstruct_curve(self, org):
        """Rebuild organ ``org``'s contour from ``self.samples_int[org]``.

        Computes ``mean_ + sum_i components_[i] * val_i``, applies
        ``v * 255.5 + 255.5`` (mapping [-1, 1] onto [0, 511]), and returns
        the result as a list of integer ``[x, y]`` points.
        """
        estimator = self.estimators[org]
        curve = estimator.mean_
        for i, val in enumerate(self.samples_int[org]):
            curve = curve + estimator.components_[i] * val
        return np.reshape(curve * 255.5 + 255.5, (-1, 2)).astype(int).tolist()

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        ``done`` becomes True once ``current_step`` exceeds ``MAX_STEPS``
        (i.e. after ``MAX_STEPS + 1`` steps); the counter then wraps to 0.
        The base reward is scaled by ``current_step / MAX_STEPS``.
        """
        self._take_action(action)
        self.current_step += 1
        # Scale by episode progress *before* the counter is rewound below;
        # rewinding first made the terminal step's reward always 0.
        delay_modifier = self.current_step / self.MAX_STEPS
        # We are not checking for the 'no action' action since longer routes
        # would mean larger rewards.
        done = self.current_step > self.MAX_STEPS
        if done:
            self.current_step = 0

        # Polygon and co_in are defined outside this class (presumably
        # shapely.geometry.Polygon and a point-index list).
        torso = Polygon([self.curves_int[0][index] for index in co_in])
        organs_inside = all(
            torso.contains(organ)
            for organ in (self.leftlung, self.rightlung, self.heart, self.esophagus)
        )
        # NOTE(review): the penalty is applied when the torso DOES contain all
        # organs, and spinalcord is not checked -- confirm both are intended.
        reward = -1 if organs_inside else 0
        reward = reward * delay_modifier

        observation = self.__preprocess_input(
            self.samples_img1, self.samples_img2, self.samples_chain, self.int_level
        )
        return observation, reward, done, {}

    def reset(self):
        """Start a new episode and return the initial observation.

        Draws two image indices (they may coincide) and a blend level in
        {0, 1, 2}, blends the images' per-organ eigenvalues, rebuilds all six
        organ contours, and stores the non-torso organ polygons for ``step``.
        """
        self.current_step = 0
        # numpy has no ``randrange``; ``randint(high)`` draws from [0, high).
        self.selected_idx1 = np.random.randint(len(self.samples))
        self.selected_idx2 = np.random.randint(len(self.samples))
        self.int_level = np.random.randint(3)
        self.samples_img1 = self.samples[self.selected_idx1]
        self.samples_img2 = self.samples[self.selected_idx2]
        self.curves_img1 = self.curves[self.selected_idx1]
        self.curves_img2 = self.curves[self.selected_idx2]

        w1, w2 = self._BLEND_WEIGHTS[self.int_level]
        # np.asarray lets the blend work for plain lists as well as arrays and
        # yields arrays that _take_action can update in place.
        self.samples_int = [
            w1 * np.asarray(sample1) + w2 * np.asarray(sample2)
            for sample1, sample2 in zip(self.samples_img1, self.samples_img2)
        ]
        # All organs' blended eigenvalues, flattened in organ order.
        self.samples_chain = [ev for org in self.samples_int for ev in org]

        self.curves_int = [self._reconstruct_curve(org) for org in range(6)]

        # Kept on the instance because step() tests the torso against them;
        # as locals they were discarded when reset() returned.
        # co_in123 / co_in45 are defined outside this class.
        self.leftlung = Polygon([self.curves_int[1][index] for index in co_in123])
        self.rightlung = Polygon([self.curves_int[2][index] for index in co_in123])
        self.heart = Polygon([self.curves_int[3][index] for index in co_in123])
        self.spinalcord = Polygon([self.curves_int[4][index] for index in co_in45])
        self.esophagus = Polygon([self.curves_int[5][index] for index in co_in45])

        return self.__preprocess_input(
            self.samples_img1, self.samples_img2, self.samples_chain, self.int_level
        )

    def __preprocess_input(self, samples_img1, samples_img2, samples_chain, int_level):
        """Build the flat observation vector without mutating any input.

        Concatenates, in order: ``samples_chain`` (all blended eigenvalues),
        the torso eigenvalues of image 1, those of image 2, and ``int_level``.
        """
        tor1 = samples_img1[0]
        tor2 = samples_img2[0]
        # list.extend/append return None and cannot be chained; they would
        # also grow self.samples_chain on every call. Build a fresh list.
        return np.asarray([*samples_chain, *tor1, *tor2, int_level])

    def _take_action(self, action):
        """Move one blended torso eigenvalue towards a target value.

        ``action`` is ``(eig_cmp, eig_mod)``. ``eig_cmp == 8`` means no
        change; otherwise it selects the torso eigen-component to edit.
        ``eig_mod`` (0..7) encodes a target and a step size:
        ``eig_mod // 2`` picks the target from
        ``[min bound, image-1 value, image-2 value, max bound]`` and
        ``eig_mod % 2`` picks the step (10% or 30% of the way there).
        The torso contour is then rebuilt.
        """
        eig_cmp = action[0]
        eig_mod = action[1]
        if eig_cmp == 8:
            return

        targets = [
            self.min_list[eig_cmp],
            self.samples_img1[0][eig_cmp],
            self.samples_img2[0][eig_cmp],
            self.max_list[eig_cmp],
        ]
        current_ev = self.samples_int[0][eig_cmp]
        # Previously the target index was int(eig_mod / 2) // 4, which is
        # always 0, so only the min bound was ever used.
        target_ev = targets[eig_mod // 2]
        if eig_mod % 2 == 0:
            new_ev = target_ev * 0.1 + current_ev * 0.9
        else:
            new_ev = target_ev * 0.3 + current_ev * 0.7

        self.samples_int[0][eig_cmp] = new_ev
        # Torso (organ 0) values come first in samples_chain, so the same
        # index addresses them there.
        self.samples_chain[eig_cmp] = new_ev
        self.curves_int[0] = self._reconstruct_curve(0)

    def create_patient_seg_adj_dict(self):
        """Create the static graph structure shared by all instances.

        Nodes ``0..n-1`` form a complete graph. In addition, each ring pair
        ``(i, i+1)`` (wrapping ``n-1`` back to ``0``) is joined through a
        path of two extra nodes numbered from ``n`` upwards. The symmetric
        adjacency list (node -> neighbours) is stored in
        ``self.patient_seg_adj_dict``. The layout relies on ``gap == 2 * n``.
        """
        n = 12
        gap = 24

        self.patient_seg_adj_dict = {}
        for i in range(n + gap):
            self.patient_seg_adj_dict[i] = []

        for i in range(n):
            for j in range(n):
                if i != j:
                    self.patient_seg_adj_dict[i].append(j)

        add_co = n
        for i in range(n):
            self.patient_seg_adj_dict[i].append(add_co)
            self.patient_seg_adj_dict[add_co].append(i)
            self.patient_seg_adj_dict[add_co].append(add_co + 1)
            self.patient_seg_adj_dict[add_co + 1].append(add_co)
            add_co += 1
            if i + 1 != n:
                self.patient_seg_adj_dict[i + 1].append(add_co)
                self.patient_seg_adj_dict[add_co].append(i + 1)
                add_co += 1
            else:
                self.patient_seg_adj_dict[0].append(add_co)
                self.patient_seg_adj_dict[add_co].append(0)