-
Notifications
You must be signed in to change notification settings - Fork 103
/
dropover.py
36 lines (25 loc) · 1.34 KB
/
dropover.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import tensorflow as tf
from model import Model
class DropoverLayer(Model):
    """Randomly mix the codes of two upstream components ("dropover").

    In ``'train'`` mode each output element is taken from either
    ``next_component`` or ``next_component_2``. The choice comes from a
    per-element draw of ``tf.random_uniform(self.shape, -1, 1)``: positive
    draws select the first component, the rest select the second. In
    ``'test'`` mode only ``next_component`` is used.

    Args:
        shape: Shape of the random selection tensor. It presumably has to
            match the shape of both components' codes so that ``tf.where``
            can select element-wise (TODO: confirm against callers).
        next_component: Upstream component whose ``get_all_codes(mode)[0]``
            is used in both modes.
        next_component_2: Upstream component whose ``get_all_codes(mode)[0]``
            is mixed in during training only.
    """

    def __init__(self, shape, next_component=None, next_component_2=None):
        self.next_component = next_component
        self.next_component_2 = next_component_2
        self.shape = shape
        # Per-instance cache of the embedding tensor, keyed by mode. This used
        # to be a class attribute. That meant every DropoverLayer shared one
        # dict, and a second layer silently returned the first layer's cached
        # training tensor.
        self.vertex_embedding_function = {'train': None, 'test': None}

    def compute_vertex_embeddings(self, mode='train'):
        """Return the vertex embedding tensor for ``mode``.

        The ``'train'`` tensor is built once and then cached on this instance.
        The ``'test'`` tensor is re-fetched from ``next_component`` on every
        call.

        Args:
            mode: ``'train'`` or ``'test'``.

        Returns:
            The embedding tensor for ``mode``.

        Raises:
            KeyError: If ``mode`` is neither ``'train'`` nor ``'test'``.
        """
        if self.vertex_embedding_function[mode] is None and mode == 'train':
            code_1 = self.next_component.get_all_codes(mode=mode)[0]
            code_2 = self.next_component_2.get_all_codes(mode=mode)[0]
            # The draw is uniform over [-1, 1), so "> 0" picks each component
            # for roughly half of the elements.
            choice = tf.random_uniform(self.shape, -1, 1, dtype=tf.float32)
            self.vertex_embedding_function[mode] = tf.where(
                choice > 0, x=code_1, y=code_2)
        elif mode == 'test':
            # Test time ignores next_component_2 and does not use the cache.
            self.vertex_embedding_function[mode] = (
                self.next_component.get_all_codes(mode=mode)[0])
        return self.vertex_embedding_function[mode]

    def get_all_codes(self, mode='train'):
        """Return ``(embeddings, None, embeddings)`` for ``mode``.

        NOTE(review): the triple layout presumably mirrors the ``get_all_codes``
        contract of the upstream components (only index 0 is read here).
        Confirm against the ``Model`` interface.
        """
        collected_messages = self.compute_vertex_embeddings(mode=mode)
        return collected_messages, None, collected_messages

    def get_all_subject_codes(self, mode='train'):
        """Return the embeddings for ``mode`` (the same tensor as the object codes)."""
        return self.compute_vertex_embeddings(mode=mode)

    def get_all_object_codes(self, mode='train'):
        """Return the embeddings for ``mode`` (the same tensor as the subject codes)."""
        return self.compute_vertex_embeddings(mode=mode)