In [1]:
import numpy as np
import math
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, Reshape
from tensorflow.keras.utils import plot_model, Sequence
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.losses import CategoricalCrossentropy, BinaryCrossentropy
from gradientreversal import GradientReversal
from tensorflow.keras.callbacks import EarlyStopping
import itertools

In [2]:
@tf.custom_gradient
def grad_reverse(x):
    """Gradient-reversal op: identity in the forward pass, sign-flipped gradient in the backward pass.

    Returns the ``(output, grad_fn)`` pair that ``tf.custom_gradient`` expects.
    """
    def _negated_upstream(dy):
        # Backward pass: hand the upstream gradient back with its sign inverted.
        return -dy

    forward = tf.identity(x)
    return forward, _negated_upstream

In [3]:
class DataGenerator:
    """Directory-backed image iterators for one source domain and one target domain.

    Three Keras directory iterators are created:

    * ``source_data``  - every image under ``<data_folder>/<source_domain>``.
    * ``target_data``  - the ``'training'`` subset of the target domain, i.e. the
      fraction ``target_labels`` of the target images.
    * ``valid_target`` - the ``'validation'`` subset of the target domain, i.e. the
      remaining ``1 - target_labels`` fraction.

    Args:
        data_folder: Root folder that contains one sub-folder per domain.
        source_domain: Name of the source-domain sub-folder.
        target_domain: Name of the target-domain sub-folder.
        input_shape: ``(height, width)`` passed as ``target_size`` to the iterators.
        target_labels: Fraction of the target domain used for training.
        batch_size: Batch size of the source and target training iterators
            (previously hard-coded to 32).
        valid_batch_size: Batch size of the target validation iterator
            (previously hard-coded to 64).

    Attributes:
        classes: Number of classes reported by the source iterator.
    """

    def __init__(
        self,
        data_folder,
        source_domain,
        target_domain,
        input_shape,
        target_labels=0.1,
        batch_size=32,
        valid_batch_size=64
    ):
        source_gen = tf.keras.preprocessing.image.ImageDataGenerator()
        # Keras assigns the `validation_split` fraction to subset='validation',
        # so the 'training' subset below holds the `target_labels` fraction.
        target_gen = tf.keras.preprocessing.image.ImageDataGenerator(
            validation_split=1 - target_labels
        )

        source_dir = data_folder + "/" + source_domain
        target_dir = data_folder + "/" + target_domain

        self.source_data = source_gen.flow_from_directory(
            source_dir,
            batch_size=batch_size,
            target_size=input_shape
        )

        self.target_data = target_gen.flow_from_directory(
            target_dir,
            batch_size=batch_size,
            target_size=input_shape,
            subset='training'
        )

        self.valid_target = target_gen.flow_from_directory(
            target_dir,
            batch_size=valid_batch_size,
            target_size=input_shape,
            subset='validation'
        )

        self.classes = self.source_data.num_classes

    def train_data(self):
        """Return a MergeSequence pairing source batches with target training batches."""
        return MergeSequence(self.source_data, self.target_data)

    def valid_data(self):
        """Return a MergeSequence over the target validation split only.

        The empty list stands in for the source side, so every batch is
        target-only.
        """
        return MergeSequence([], self.valid_target)

class MergeSequence(Sequence):
    """Keras ``Sequence`` that merges a source and a target batch sequence.

    Batch ``idx`` is the concatenation (along the batch axis) of ``source[idx]``
    and ``target[idx]`` when both exist; once the shorter sequence is exhausted
    only the longer one contributes. Labels of shape ``(batch, classes)`` are
    expanded to ``(batch, 2, classes)``: row 0 carries source labels, row 1
    carries target labels, and the other row is zero-filled.

    Args:
        source: Indexable, sized collection of ``(images, one_hot_labels)`` batches
            (may be empty, e.g. ``[]``).
        target: Same, for the target domain.
    """

    def __init__(self, source, target):
        # Let the base class initialise itself (needed by Keras versions whose
        # Sequence defines __init__; a no-op otherwise).
        super().__init__()
        self.source = source
        self.target = target

    @staticmethod
    def _joint_labels(one_hot, domain_row):
        """Expand ``(batch, classes)`` labels to ``(batch, 2, classes)``.

        The labels are placed in row ``domain_row`` (0 = source, 1 = target);
        the other row is filled with zeros.
        """
        labels = one_hot[:, np.newaxis, :]
        zeros = np.zeros(labels.shape)
        parts = (labels, zeros) if domain_row == 0 else (zeros, labels)
        return np.concatenate(parts, axis=1)

    def __getitem__(self, idx):
        """Return ``(images, joint_labels)`` for batch ``idx``.

        Raises:
            IndexError: If ``idx`` is beyond both underlying sequences
                (previously this surfaced as an UnboundLocalError).
        """
        has_source = idx < len(self.source)
        has_target = idx < len(self.target)

        if not (has_source or has_target):
            raise IndexError(
                "batch index {} out of range for MergeSequence of length {}".format(
                    idx, len(self)
                )
            )

        if has_source and has_target:
            s_img, s_class = self.source[idx]
            t_img, t_class = self.target[idx]
            img = np.concatenate((s_img, t_img), axis=0)
            cls = np.concatenate(
                (self._joint_labels(s_class, 0), self._joint_labels(t_class, 1)),
                axis=0,
            )
        elif has_source:
            img, cls = self.source[idx]
            cls = self._joint_labels(cls, 0)
        else:
            img, cls = self.target[idx]
            cls = self._joint_labels(cls, 1)

        return img, cls

    def __len__(self):
        """Number of batches: the length of the longer underlying sequence."""
        return max(len(self.source), len(self.target))

In [4]:
def KLLoss(y_pred, classes):
    """KL divergence between a joint prediction and the product of its marginals.

    Args:
        y_pred: Joint prediction tensor; it is summed over axis 2 and axis 1 and
            reshaped to ``(-1, 2 * classes)``, so it is expected to be
            ``(batch, 2, classes)``.
        classes: Number of classes per domain.

    Returns:
        The value of ``tf.keras.losses.KLDivergence`` evaluated on the transposed,
        flattened joint (as ``y_true``) and the transposed, flattened product of
        marginals (as ``y_pred``).
    """
    flat_shape = (-1, 2 * classes)

    # Marginals, kept broadcastable against each other:
    # domain marginal -> (batch, 2, 1), class marginal -> (batch, 1, classes).
    domain_marginal = tf.expand_dims(tf.reduce_sum(y_pred, axis=2), -1)
    class_marginal = tf.expand_dims(tf.reduce_sum(y_pred, axis=1), 1)

    joint_flat = tf.reshape(y_pred, flat_shape)
    independent_flat = tf.reshape(domain_marginal * class_marginal, flat_shape)

    divergence = tf.keras.losses.KLDivergence()
    return divergence(tf.transpose(joint_flat), tf.transpose(independent_flat))

In [5]:
def build_encoder(input_shape=(224, 224, 3)):
    """Build the feature encoder: a headless ResNet50 followed by global average pooling.

    Args:
        input_shape: Image shape ``(height, width, channels)`` passed to ResNet50.

    Returns:
        A Keras ``Model`` named ``"Encoder"`` mapping images to pooled features.
    """
    backbone = ResNet50(include_top=False, input_shape=input_shape)
    pooled_features = GlobalAveragePooling2D()(backbone.output)
    return Model(inputs=backbone.input, outputs=pooled_features, name="Encoder")

def build_classifier(input_shape, classes):
    """Build the joint (domain, class) classifier head.

    Args:
        input_shape: Encoder output shape including the batch dimension; the
            leading (batch) entry is dropped before being given to the first layer.
        classes: Number of classes per domain.

    Returns:
        A ``Sequential`` named ``"Classifier"`` whose softmax over ``2 * classes``
        units is reshaped to ``(2, classes)`` per sample.
    """
    feature_shape = input_shape[1:]
    head_layers = [
        Dense(1280, input_shape=feature_shape, activation='relu'),
        Dense(1280, activation='relu'),
        # One softmax over all (domain, class) pairs ...
        Dense(2 * classes, activation='softmax', name='classifier_output'),
        # ... viewed as a 2 x classes table.
        Reshape((2, classes)),
    ]
    return Sequential(head_layers, name="Classifier")

def build_discriminator(input_shape):
    """Build the two-way domain discriminator head.

    Args:
        input_shape: Shape (including the batch dimension) used to build the model.

    Returns:
        A built ``Sequential`` named ``"Discriminator"`` ending in a 2-unit softmax.
    """
    head_layers = [
        Dense(1280, activation='relu'),
        Dense(1280, activation='relu'),
        Dense(2, activation='softmax', name='domain_output'),
    ]
    domain_head = Sequential(head_layers, name="Discriminator")
    # No layer declares an input shape, so create the weights explicitly.
    domain_head.build(input_shape)
    return domain_head

In [6]:
class SingleDomainModel(tf.keras.Model):
    """Encoder plus joint (domain, class) classifier, optionally with a domain discriminator.

    The classifier output is treated as a joint distribution laid out as
    ``(batch, 2, classes)``: summing over axis 2 gives the domain marginal and
    summing over axis 1 gives the class marginal. Targets are expected in the
    same layout.

    Args:
        encoder: Feature extractor; built with ``build_encoder()`` when ``None``.
        classifier: Joint classifier head; built with ``build_classifier`` when ``None``.
        use_discriminator: When true, a discriminator head is attached behind a
            gradient-reversal op and ``call`` returns ``(classes, domains)``.
        classes: Number of classes per domain.
    """

    def __init__(self,
                 encoder=None,
                 classifier=None,
                 use_discriminator=False,
                 classes=65
                ):
        super(SingleDomainModel, self).__init__()
        self.classes = classes
        self.encoder = build_encoder() \
            if encoder is None else encoder
        self.classifier = build_classifier(self.encoder.output_shape, self.classes) \
            if classifier is None else classifier
        self.discriminator = build_discriminator(self.encoder.output_shape) \
            if use_discriminator else None

        self.build(self.encoder.input_shape)

    def call(self, inputs):
        """Forward pass.

        Returns the joint classifier output, or ``(joint_output, domain_output)``
        when a discriminator is attached.

        NOTE(review): no ``training`` flag is accepted or forwarded, so the
        sub-models are invoked without an explicit mode - confirm this is the
        intended behaviour for mode-dependent layers in the encoder.
        """
        features = self.encoder(inputs)
        classes = self.classifier(features)
        if self.discriminator is not None:
            # Reverse gradients flowing from the discriminator into the encoder.
            domains = grad_reverse(features)
            domains = self.discriminator(domains)
            return classes, domains
        else:
            return classes

    def compile(self, optimizer=None, pred_loss=None, rep_loss=None):
        """Configure optimizer, losses and metrics.

        Args:
            optimizer: Optimizer; defaults to SGD (momentum 0.9) with an
                exponentially decaying learning rate starting at 1e-3.
            pred_loss: Prediction loss; defaults to ``CategoricalCrossentropy()``.
            rep_loss: Representation loss called as ``rep_loss(y_joint, classes)``;
                defaults to ``KLLoss``.
        """
        super(SingleDomainModel, self).compile()
        if optimizer is None:
            lr_schedule = ExponentialDecay(
                initial_learning_rate=1e-3,
                decay_steps=100,
                decay_rate=0.9)
            self.optimizer = SGD(learning_rate=lr_schedule, momentum=0.9)
        else:
            self.optimizer = optimizer

        if self.discriminator is not None:
            self.dis_loss = BinaryCrossentropy()
            self.domain_accuracy = tf.keras.metrics.CategoricalAccuracy()

        if pred_loss is None:
            self.pred_loss = CategoricalCrossentropy()
        else:
            self.pred_loss = pred_loss

        if rep_loss is None:
            self.rep_loss = KLLoss
        else:
            self.rep_loss = rep_loss

        self.joint_accuracy = tf.keras.metrics.CategoricalAccuracy()
        self.class_accuracy = tf.keras.metrics.CategoricalAccuracy()

    def _split_targets(self, y):
        """Derive ``(domain_marginal, class_marginal, flattened_joint)`` from joint targets."""
        y_dom = tf.reduce_sum(y, axis=2)
        y_class = tf.reduce_sum(y, axis=1)
        y_flat = tf.reshape(y, (-1, 2*self.classes))
        return y_dom, y_class, y_flat

    def train_step(self, data):
        """Run one optimisation step and return the logged losses/metrics.

        Without a discriminator two updates are made: the encoder on
        ``kl_loss + pred_loss``, then the classifier on the joint prediction loss
        (computed with a fresh forward pass after the encoder update).
        With a discriminator a single update of all weights is made on
        ``kl_loss + pred_loss + dis_loss``.
        """
        x, y = data
        y_dom, y_class, y = self._split_targets(y)

        if self.discriminator is None:
            # --- Train encoder ---
            with tf.GradientTape() as tape:
                y_joint = self(x)
                kl_loss = self.rep_loss(y_joint, self.classes)

                y_class_pred = tf.reduce_sum(y_joint, axis=1)
                pred_loss = self.pred_loss(y_class, y_class_pred)

                encoder_loss = kl_loss + pred_loss

            self.class_accuracy.update_state(y_class, y_class_pred)

            # Gradients are taken after leaving the tape context so the
            # gradient ops themselves are not recorded.
            encoder_weights = self.encoder.trainable_weights
            grads = tape.gradient(encoder_loss, encoder_weights)
            self.optimizer.apply_gradients(zip(grads, encoder_weights))

            # --- Train classifier ---
            with tf.GradientTape() as tape:
                y_joint_pred = tf.reshape(self(x), (-1, 2*self.classes))
                joint_pred_loss = self.pred_loss(y, y_joint_pred)

            self.joint_accuracy.update_state(y, y_joint_pred)

            classifier_weights = self.classifier.trainable_weights
            grads = tape.gradient(joint_pred_loss, classifier_weights)
            self.optimizer.apply_gradients(zip(grads, classifier_weights))

            return {
                "loss": encoder_loss,
                "kl_loss": kl_loss,
                "pred_loss": pred_loss,
                "joint_pred_loss": joint_pred_loss,
                "accuracy": self.joint_accuracy.result(),
                "class_accuracy": self.class_accuracy.result()
            }

        else:
            with tf.GradientTape() as tape:
                y_joint, y_domains = self(x)
                kl_loss = self.rep_loss(y_joint, self.classes)

                y_class_pred = tf.reduce_sum(y_joint, axis=1)
                pred_loss = self.pred_loss(y_class, y_class_pred)

                y_joint_pred = tf.reshape(y_joint, (-1, 2*self.classes))
                # NOTE(review): joint_pred_loss is reported but is not part of
                # the optimised loss in this branch - confirm that is intended.
                joint_pred_loss = self.pred_loss(y, y_joint_pred)

                dis_loss = self.dis_loss(y_dom, y_domains)

                encoder_loss = kl_loss + pred_loss + dis_loss

            self.class_accuracy.update_state(y_class, y_class_pred)
            self.joint_accuracy.update_state(y, y_joint_pred)
            self.domain_accuracy.update_state(y_dom, y_domains)

            weights = self.classifier.trainable_weights + \
                      self.encoder.trainable_weights + \
                      self.discriminator.trainable_weights

            grads = tape.gradient(encoder_loss, weights)
            self.optimizer.apply_gradients(zip(grads, weights))

            return {
                "loss": encoder_loss,
                "kl_loss": kl_loss,
                "pred_loss": pred_loss,
                "dis_loss": dis_loss,
                "joint_pred_loss": joint_pred_loss,
                "accuracy": self.joint_accuracy.result(),
                "class_accuracy": self.class_accuracy.result(),
                "domain_accuracy": self.domain_accuracy.result(),
            }

    def test_step(self, data):
        """Evaluate one batch and return the same losses/metrics as ``train_step``."""
        x, y = data
        y_dom, y_class, y = self._split_targets(y)

        if self.discriminator is None:
            y_joint = self(x)
            # Previously this branch read an undefined `y_joint` and fed the
            # unflattened prediction to the joint-accuracy metric.
            y_class_pred = tf.reduce_sum(y_joint, axis=1)
            y_joint_pred = tf.reshape(y_joint, (-1, 2*self.classes))

            kl_loss = self.rep_loss(y_joint, self.classes)
            pred_loss = self.pred_loss(y_class, y_class_pred)
            self.class_accuracy.update_state(y_class, y_class_pred)
            encoder_loss = kl_loss + pred_loss

            joint_pred_loss = self.pred_loss(y, y_joint_pred)
            self.joint_accuracy.update_state(y, y_joint_pred)

            return {
                "loss": encoder_loss,
                "kl_loss": kl_loss,
                "pred_loss": pred_loss,
                "joint_pred_loss": joint_pred_loss,
                "accuracy": self.joint_accuracy.result(),
                "class_accuracy": self.class_accuracy.result()
            }

        else:
            y_joint, y_domains = self(x)
            y_class_pred = tf.reduce_sum(y_joint, axis=1)
            y_joint_pred = tf.reshape(y_joint, (-1, 2*self.classes))

            kl_loss = self.rep_loss(y_joint, self.classes)
            pred_loss = self.pred_loss(y_class, y_class_pred)
            self.class_accuracy.update_state(y_class, y_class_pred)
            dis_loss = self.dis_loss(y_dom, y_domains)
            self.domain_accuracy.update_state(y_dom, y_domains)
            encoder_loss = kl_loss + pred_loss + dis_loss

            joint_pred_loss = self.pred_loss(y, y_joint_pred)
            self.joint_accuracy.update_state(y, y_joint_pred)

            return {
                "loss": encoder_loss,
                "kl_loss": kl_loss,
                "pred_loss": pred_loss,
                "dis_loss": dis_loss,
                "joint_pred_loss": joint_pred_loss,
                "accuracy": self.joint_accuracy.result(),
                "class_accuracy": self.class_accuracy.result(),
                "domain_accuracy": self.domain_accuracy.result(),
            }
    

In [7]:
# Source domain "Real World" -> target domain "Product"; 30% of the target
# images go to the training subset, the rest to validation.
d = DataGenerator(
    data_folder="../Datasets/OfficeHomeDataset_10072016",
    source_domain="Real World",
    target_domain="Product",
    input_shape=(224, 224),
    target_labels=0.3,
)

Found 4357 images belonging to 65 classes.
Found 1361 images belonging to 65 classes.
Found 3078 images belonging to 65 classes.


In [20]:
# Build the variant with a domain discriminator, sized to the dataset's class count.
m = SingleDomainModel(classes=d.classes, use_discriminator=True)
# No arguments: compile() falls back to its default optimizer and losses.
m.compile()
m.summary()

Model: "single_domain_model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Encoder (Functional)         (None, 2048)              23587712  
_________________________________________________________________
Classifier (Sequential)      (None, 2, 65)             4428930   
_________________________________________________________________
Discriminator (Sequential)   (None, 2)                 4264962   
_________________________________________________________________
categorical_accuracy (Catego multiple                  2         
_________________________________________________________________
categorical_accuracy (Catego multiple                  2         
_________________________________________________________________
categorical_accuracy (Catego multiple                  2         
Total params: 32,281,610
Trainable params: 32,228,484
Non-trainable params: 53,126
____________________________

In [21]:
# Stop training after a single epoch without improvement in `val_loss`
# (patience=1) and roll the model back to the best weights seen.
early_stopping_monitor = EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=1,
    verbose=0,
    mode='auto',
    baseline=None,
    restore_best_weights=True
)

In [23]:
# Train on the merged source/target batches and validate on the target
# validation split, with early stopping; placed on the first GPU.
with tf.device('/device:GPU:0'):
    history = m.fit(
        d.train_data(), 
        epochs=100,
        callbacks=[early_stopping_monitor],
        validation_data = d.valid_data()
    )

Epoch 1/100
Epoch 2/100


In [18]:
# Persist the model under the 'real_world_to_product_30' directory.
m.save('real_world_to_product_30')

INFO:tensorflow:Assets written to: real_world_to_product_30/assets


In [16]:
# Functional model reusing m's encoder and discriminator: images -> features ->
# gradient reversal -> domain prediction. Only the discriminator output is used.
# NOTE(review): `input` shadows the Python builtin of the same name for the rest
# of the kernel session.
input = Input(shape=m.encoder.input_shape[1:])
encoder = m.encoder(input)
# Built but not included in the model's outputs below.
classifier = m.classifier(encoder)
discriminator = grad_reverse(encoder)
# discriminator = GradientReversal(1.0)(encoder)
discriminator = m.discriminator(discriminator)

# Same schedule/optimizer settings as SingleDomainModel.compile() uses by default.
lr_schedule = ExponentialDecay(
    initial_learning_rate=1e-3,
    decay_steps=100,
    decay_rate=0.9)
optimizer = SGD(learning_rate=lr_schedule, momentum=0.9)

model = Model(inputs=[input], outputs=[discriminator])
model.compile(
    loss=[CategoricalCrossentropy()], 
    optimizer=optimizer, 
    metrics=['accuracy'])

model.summary()

Model: "functional_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_7 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
Encoder (Functional)         (None, 2048)              23587712  
_________________________________________________________________
tf_op_layer_Identity (Tensor [(None, 2048)]            0         
_________________________________________________________________
tf_op_layer_Identity_1 (Tens [(None, 2048)]            0         
_________________________________________________________________
Discriminator (Sequential)   (None, 2)                 4264962   
Total params: 27,852,674
Trainable params: 27,799,554
Non-trainable params: 53,120
_________________________________________________________________


In [18]:
# Grab one (images, labels) batch for experiments.
# NOTE(review): the DataGenerator class defined in this notebook has no
# __getitem__, so `d` is presumably a MergeSequence (e.g. d.train_data()) left
# over from another kernel state - confirm before re-running.
x, y = d[1]

In [21]:
# Shape of the image batch.
x.shape

(64, 224, 224, 3)

In [22]:
# Fit the discriminator-only model repeatedly on this single batch; summing the
# joint labels over axis 2 yields the per-sample domain indicator used as target.
history = model.fit(x, tf.reduce_sum(y, axis=2), epochs=100)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100

KeyboardInterrupt: 

In [91]:
# Train the discriminator head alone on encoder features of the single batch.
# `model` is rebound to m.discriminator, and `optimizer` is reused from an
# earlier cell.
model = m.discriminator
model.compile(loss=CategoricalCrossentropy(), optimizer=optimizer)

model.fit(m.encoder(x), tf.reduce_sum(y, axis=2), epochs=100)
# model(m.encoder(x))

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78



<tensorflow.python.keras.callbacks.History at 0x7faaae6e8580>

In [28]:
class KLLoss():
    """Callable KL loss with the ``(y_true, y_pred)`` call signature.

    Measures the KL divergence between the joint prediction and the product of
    its two marginals; ``y_true`` is accepted but not used.

    NOTE: this class reuses the name of the module-level ``KLLoss`` function
    defined earlier in the notebook and replaces it once this cell runs.
    """

    def __init__(self, classes):
        # `__name__` is set so the instance carries a function-like name.
        self.__name__ = "kl_loss"
        self.classes = classes

    def __call__(self, y_true, y_pred):
        flat_shape = (-1, 2 * self.classes)

        # Marginals kept broadcastable: (batch, 2, 1) and (batch, 1, classes).
        domain_marginal = tf.expand_dims(tf.reduce_sum(y_pred, axis=2), -1)
        class_marginal = tf.expand_dims(tf.reduce_sum(y_pred, axis=1), 1)

        joint_flat = tf.reshape(y_pred, flat_shape)
        independent_flat = tf.reshape(domain_marginal * class_marginal, flat_shape)

        divergence = tf.keras.losses.KLDivergence()
        return divergence(tf.transpose(joint_flat), tf.transpose(independent_flat))

In [55]:
# Grab one (images, labels) batch for the loss experiments below.
# NOTE(review): requires `d` to be indexable; the DataGenerator class defined in
# this notebook is not - confirm what `d` is bound to here.
x, y = d[1]

In [59]:
# Compare the encoder feature shape with the joint label shape.
m.encoder(x).shape, y.shape

(TensorShape([64, 2048]), (64, 2, 65))

In [253]:
# Domain indicator per sample: sum the joint labels over axis 2.
y_dom = tf.reduce_sum(y, axis=2)

In [295]:
# Discriminator loss on this batch; m(x)[1] requires m to have a discriminator
# (its call then returns a (classes, domains) pair).
dis_loss = BinaryCrossentropy()(y_dom, m(x)[1])

In [292]:
# Representation (KL) loss on the joint classifier output of this batch.
kl_loss = m.rep_loss(m(x)[0], m.classes)

In [298]:
# Combined value of the two loss terms.
kl_loss + dis_loss

<tf.Tensor: shape=(), dtype=float32, numpy=0.60576>

In [299]:
# Show the two loss terms individually.
kl_loss, dis_loss

(<tf.Tensor: shape=(), dtype=float32, numpy=0.06295085>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5428091>)

In [168]:
# NOTE(review): `y_pred` is not defined by any cell visible in this notebook.
# tf.argmax is called without `axis`, so it reduces over axis 0; the recorded
# output has one entry per column (130), not one per sample.
tf.argmax(y_pred)

<tf.Tensor: shape=(130,), dtype=int64, numpy=
array([15, 23, 36, 28, 57, 16, 11, 14, 13, 35,  7, 57, 17, 54, 60, 58,  9,
       16, 57, 54, 54, 61,  1, 14, 20, 40, 38, 22, 36, 44, 48, 28,  4, 37,
       16, 32, 53,  1, 46,  4, 16, 31, 26, 45,  2, 32, 12, 61,  1,  4,  7,
       18, 57, 15, 15,  0, 36, 36,  7,  6, 51,  9, 61, 12, 47, 14, 24, 18,
       56,  2,  1, 28,  4, 38, 60, 14, 27, 51, 22, 57, 44, 31, 55, 22,  6,
       46, 28,  6, 61, 13, 60, 45, 53, 28,  9, 13, 11, 40, 55,  2,  2, 22,
       45, 31, 26,  4, 34, 54, 18, 51, 37, 27, 36, 23,  0, 51,  9, 61, 32,
       62, 55, 49, 48, 18, 20, 31, 60, 24, 38, 10])>

In [170]:
# Flatten the joint labels to 2*classes columns. This overwrites `y`, so cells
# that expect the 3-D labels must re-fetch the batch first.
y = tf.reshape(y, (-1, 2*m.classes))

In [172]:
# Called without `axis`, so the reduction is over axis 0: for each of the 130
# columns this gives the index of the first row holding the maximum.
tf.argmax(y)

<tf.Tensor: shape=(130,), dtype=int64, numpy=
array([26,  0,  4, 34, 22, 23,  0,  0,  0,  0,  0, 11, 30,  9,  0,  0,  0,
        0,  0,  8,  0, 55,  7,  0,  0, 24,  2, 25,  0,  0, 45,  1,  6, 59,
       29, 27,  0,  0, 13,  0, 35, 49,  0, 16, 42,  0, 50,  3, 60,  0,  0,
       31, 17,  0, 57, 36, 28,  0, 14, 58,  0,  5,  0,  0, 20,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])>

In [257]:
# Only valid when m was built with use_discriminator=True; otherwise
# m.discriminator is None (the recorded run hit that AttributeError).
m.discriminator.summary()

AttributeError: 'NoneType' object has no attribute 'summary'

In [69]:
# Two-output functional model over m's sub-models: (classifier, discriminator).
# The encoder is applied to the input twice (once per head), and no gradient
# reversal is inserted before the discriminator here.
input_layer = Input(shape=m.encoder.input_shape[1:])
classifier_out = m.classifier(m.encoder(input_layer))
discriminator_out = m.discriminator(m.encoder(input_layer))
output_layer = [classifier_out, discriminator_out]
model = Model(input_layer, output_layer)
model.trainable = True

In [9]:
class SingleDomainModel:
    """Plain wrapper around a functional encoder / classifier / discriminator model.

    NOTE: this class reuses the name of the ``tf.keras.Model`` subclass
    ``SingleDomainModel`` defined earlier in the notebook and replaces it once
    this cell runs.

    Args:
        classes: Number of classes per domain.

    Attributes:
        input_shape: Fixed image shape ``(224, 224, 3)``.
        model: Functional model with outputs ``[classifier_out, discriminator_out]``.
        encoder, classifier, discriminator: The named sub-models of ``model``.
    """

    def __init__(self,
                 classes=65):
        self.input_shape = (224,224,3)
        self.classes = classes

        self.model = self.build_model()
        self.classifier = self.model.get_layer("Classifier")
        self.discriminator = self.model.get_layer("Discriminator")
        self.encoder = self.model.get_layer("Encoder")

    def _build_encoder(self):
        """Headless ResNet50 followed by global average pooling, named "Encoder"."""
        resnet = ResNet50(include_top=False, input_shape=self.input_shape)
        pool = GlobalAveragePooling2D()(resnet.output)
        return Model(inputs=resnet.input, outputs=pool, name="Encoder")

    def _build_classifier(self, input_shape):
        """Joint classifier head with dropout; output reshaped to ``(2, classes)``.

        ``input_shape`` includes the batch dimension, which is dropped here.
        """
        classifier = Sequential(name="Classifier")
        classifier.add(Dense(1280,
                      input_shape=input_shape[1:],
                      activation='relu'))
        classifier.add(Dropout(0.5))
        classifier.add(Dense(1280, activation='relu'))
        classifier.add(Dropout(0.5))
        classifier.add(Dense(
            2 * self.classes,
            activation='softmax'))
        classifier.add(Reshape((2, self.classes)))
        return classifier

    def _build_discriminator(self, input_shape):
        """Two-way domain head that starts with a ``GradientReversal`` layer."""
        discriminator = Sequential(name="Discriminator")
        discriminator.add(GradientReversal(1))
        discriminator.add(Dense(1280,
                      activation='relu'))
        discriminator.add(Dropout(0.5))
        discriminator.add(Dense(1280, activation='relu'))
        discriminator.add(Dropout(0.5))
        discriminator.add(Dense(
            2,
            activation='softmax'))
        discriminator.build(input_shape)
        return discriminator

    def build_model(self):
        """Assemble the functional model with outputs ``[classifier_out, discriminator_out]``."""
        encoder = self._build_encoder()
        classifier = self._build_classifier(encoder.output_shape)
        discriminator = self._build_discriminator(encoder.output_shape)

        input_layer = Input(shape=encoder.input_shape[1:])
        # Run the encoder once and feed the same features to both heads
        # (previously the encoder was applied to the input twice).
        features = encoder(input_layer)
        classifier_out = classifier(features)
        discriminator_out = discriminator(features)

        output_layer = [classifier_out, discriminator_out]
        return Model(input_layer, output_layer)

    def train(self, X, Ydomain, Yclass, epochs=100, batch_size=128):
        """Placeholder training loop.

        TODO: not implemented - the loop body is empty, so calling this method
        iterates ``epochs`` times, performs no training and returns ``None``.
        """
        for epoch in range(epochs):
            pass
            

In [89]:
# NOTE(review): the recorded output lists layers named Feature_Encoder and
# Joint_Classifier, which no model-building code visible in this notebook
# creates - `model` presumably comes from an earlier kernel state; confirm.
model.summary()

Model: "functional_22"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_19 (InputLayer)           [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
Feature_Encoder (Functional)    (None, 2048)         23587712    input_19[0][0]                   
                                                                 input_19[0][0]                   
__________________________________________________________________________________________________
Joint_Classifier (Sequential)   (None, 4, 65)        4595460     Feature_Encoder[8][0]            
__________________________________________________________________________________________________
Discriminator (Sequential)      (None, 4)            4267524     Feature_Encoder[9][0]

In [92]:
# Weight-tensor counts: all weights, trainable weights, then trainable weights
# of each named sub-model.
print(len(model.weights))
print(len(model.trainable_weights))
for layer_name in ("Feature_Encoder", "Joint_Classifier", "Discriminator"):
    print(len(model.get_layer(layer_name).trainable_weights))

330
224
212
6
6


In [91]:
# Short handles for the three sub-models.
# NOTE(review): `d` is rebound here; other cells use `d` for the data generator,
# so re-running those after this cell will operate on the discriminator layer.
c = model.get_layer("Joint_Classifier")
d = model.get_layer("Discriminator")
f = model.get_layer("Feature_Encoder")

c.trainable = True

In [83]:
# Toggle which sub-models are trainable: encoder and discriminator on,
# joint classifier frozen. The commented lines are the alternative settings.
model.get_layer("Feature_Encoder").trainable = True
# model.get_layer("Feature_Encoder").trainable = False
# model.get_layer("Joint_Classifier").trainable = True
model.get_layer("Joint_Classifier").trainable = False
model.get_layer("Discriminator").trainable = True
# model.get_layer("Discriminator").trainable = False
# model.trainable = False

In [79]:
# Number of trainable weight tensors currently exposed by `model`.
len(model.trainable_weights)

0

In [6]:
# NOTE(review): the recorded output (dropout layers, a 260-unit output reshaped
# to (4, 65)) does not match build_classifier as defined in this notebook -
# presumably stale output from another version; confirm.
m.classifier.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 1280)              2622720   
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1280)              1639680   
_________________________________________________________________
dropout_1 (Dropout)          (None, 1280)              0         
_________________________________________________________________
Joint_Classifier (Dense)     (None, 260)               333060    
_________________________________________________________________
reshape (Reshape)            (None, 4, 65)             0         
Total params: 4,595,460
Trainable params: 4,595,460
Non-trainable params: 0
______________________________________________

In [25]:
# NOTE(review): the recorded output (dropout layers, 4-unit output) does not
# match build_discriminator as defined in this notebook - presumably stale
# output from another version; confirm.
m.discriminator.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_2 (Dense)              (None, 1280)              2622720   
_________________________________________________________________
dropout_2 (Dropout)          (None, 1280)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 1280)              1639680   
_________________________________________________________________
dropout_3 (Dropout)          (None, 1280)              0         
_________________________________________________________________
Domain_Discriminator (Dense) (None, 4)                 5124      
Total params: 4,267,524
Trainable params: 4,267,524
Non-trainable params: 0
_________________________________________________________________


In [26]:
# Layer-by-layer listing of the encoder.
m.encoder.summary()

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, 112, 112, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
conv1_bn (BatchNormalization)   (None, 112, 112, 64) 256         conv1_conv[0][0]                 
_______________________________________________________________________________________