In [None]:
#-------------------------------------------------------------------------------------JUPYTER NOTEBOOK SETTINGS-------------------------------------------------------------------------------------
# Widen the notebook container to the full browser width (targets the classic
# Notebook `.container` element; presumably has no effect in JupyterLab).
# `IPython.display` is the public home of these names; importing `display`
# from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
# NOTE(review): this rebinds `display` from the function imported above to the
# IPython.display module, so later cells must use `display.display(...)`,
# `display.HTML(...)`, etc. Kept as-is in case later cells rely on the module.
import IPython.display as display

In [None]:
import os
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Reshape, UpSampling2D, Conv2D, Dense, Dropout, Flatten
from tensorflow.keras.applications import MobileNetV3Small

In [None]:
# Function to load the latest weights
def load_latest_weights(weights_dir, file_pattern):
    """Return the path of the most recently modified weights file in a directory.

    Despite its name, this function does not load anything: it only selects a
    path for the caller to pass to ``model.load_weights``.

    Only regular files whose name contains ``file_pattern`` (plain substring
    match, not a glob) are considered. Sub-directories are skipped even if their
    name matches, because they cannot be loaded as a weights file.

    Args:
        weights_dir: Directory to scan (non-recursive).
        file_pattern: Substring a candidate file name must contain,
            e.g. ``'.weights.h5'``.

    Returns:
        ``os.path.join(weights_dir, name)`` for the newest matching file by
        modification time (ties broken by path, for determinism), or ``None``
        if no file matches.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``weights_dir`` is missing
            (``FileNotFoundError``) or is not a directory.
    """
    candidates = []
    for name in os.listdir(weights_dir):
        path = os.path.join(weights_dir, name)
        # A directory whose name happens to contain the pattern must not be
        # picked as the "latest weights file".
        if file_pattern in name and os.path.isfile(path):
            candidates.append(path)

    # Several files can share an mtime on filesystems with coarse timestamps;
    # including the path in the key makes the choice independent of the
    # arbitrary order returned by os.listdir.
    latest_weights = max(
        candidates,
        key=lambda p: (os.path.getmtime(p), p),
        default=None,
    )
    if latest_weights is None:
        print("No weights file found.")
        return None
    print(f"Loading weights from {latest_weights}")
    return latest_weights

# Gradient Reversal Layer
@tf.keras.utils.register_keras_serializable()
class GradientReversalLayer(tf.keras.layers.Layer):
    """Identity in the forward pass; multiplies incoming gradients by ``-lambda_``.

    Placed between the shared feature extractor and the domain classifier, the
    domain head is trained normally while the features receive the negated
    (scaled) domain gradient. This is the usual gradient-reversal setup for
    domain-adversarial training.

    Registered as Keras-serializable and given a ``get_config`` so that a model
    saved with ``model.save(...)`` can be reloaded without hand-passing
    ``custom_objects``.

    Args:
        lambda_: Scale applied to the reversed gradient.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``, ...),
            forwarded to the base class. They are needed for renaming and for
            deserialization.
    """

    def __init__(self, lambda_, **kwargs):
        super().__init__(**kwargs)
        self.lambda_ = lambda_

    @tf.custom_gradient
    def call(self, x):
        def grad(dy):
            # Flip (and scale) the gradient flowing back into the features.
            return -self.lambda_ * dy
        return x, grad

    def get_config(self):
        """Return the layer config, including ``lambda_``, for serialization."""
        config = super().get_config()
        config.update({"lambda_": self.lambda_})
        return config

# Domain Classifier
def create_domain_classifier(feature_extractor_output, grl_lambda=1.0, hidden_units=128,
                             l2_factor=0.01, dropout_rate=0.4):
    """Build the adversarial domain-classification head on top of shared features.

    The features pass through a ``GradientReversalLayer`` first. Training on the
    domain loss therefore fits this head, while the reversed gradient pushes the
    shared feature extractor in the opposite direction.

    All keyword defaults equal the previously hard-coded values, so existing
    calls build exactly the same layers.

    Args:
        feature_extractor_output: Keras tensor holding the shared features
            (2-D ``(batch, features)`` in this notebook; TODO confirm for
            other callers).
        grl_lambda: Gradient-reversal scale passed to ``GradientReversalLayer``.
        hidden_units: Width of the hidden ReLU ``Dense`` layer.
        l2_factor: L2 kernel-regularization factor for the hidden layer.
        dropout_rate: Dropout rate applied after the hidden layer.

    Returns:
        Keras tensor produced by the sigmoid ``Dense(1)`` layer named
        ``'domain_output'`` (shape ``(batch, 1)`` for a 2-D input).
    """
    x = GradientReversalLayer(lambda_=grl_lambda)(feature_extractor_output)
    x = Dense(hidden_units, activation='relu',
              kernel_regularizer=tf.keras.regularizers.l2(l2_factor))(x)
    x = Dropout(dropout_rate)(x)
    domain_output = Dense(1, activation='sigmoid', name='domain_output')(x)
    return domain_output

# Input shapes:
#   original_input_shape - per-sample 2-D matrix fed to the model's Input layer
#   reshaped_input_shape - the same matrix with a trailing single channel axis
#   target_input_shape   - (height, width, channels) the MobileNetV3Small backbone is built with
original_input_shape = (13, 332)
reshaped_input_shape = original_input_shape + (1,)
target_input_shape = (224, 224, 3)

# Model architecture
# Backbone: ImageNet-pretrained MobileNetV3Small without its classifier top,
# with global average pooling, left fully trainable.
base_model = MobileNetV3Small(weights='imagenet', include_top=False, input_shape=target_input_shape, pooling='avg')
base_model.trainable = True

# Adapter from the raw matrix to an image the backbone accepts:
# (13, 332) -> add channel axis -> repeat rows 7x -> learned 3-channel conv
# -> resize to the backbone's spatial size.
input_layer = Input(shape=original_input_shape)
reshaped_input = Reshape(reshaped_input_shape)(input_layer)
upsampled_input = UpSampling2D(size=(7, 1))(reshaped_input)
conv_input = Conv2D(3, (3, 3), activation='relu', padding='same')(upsampled_input)
# Resize target is derived from target_input_shape (same 224x224 as before), so
# the adapter and the backbone's input_shape cannot drift apart.
resized_input = tf.keras.layers.Resizing(target_input_shape[0], target_input_shape[1])(conv_input)
# NOTE(review): training=True is baked into the functional graph. Any
# BatchNormalization layers inside the backbone will therefore use per-batch
# statistics, including at inference time in the model exported below.
# The Keras fine-tuning guide calls the base model with training=False; confirm
# this is intended.
x = base_model(resized_input, training=True)

# With pooling='avg' the backbone already outputs (batch, features), so this
# Flatten is effectively a no-op. It is presumably kept so the layer graph
# matches the one the saved weights came from.
feature_extractor_output = Flatten()(x)

# Replace with the actual number of classes
# NOTE(review): this must match the class count the saved weights were trained
# with; presumably load_weights below rejects a mismatched task_output kernel.
num_classes = 10  

# Task-specific classifier
# Dense(64, ReLU, L2 0.01) -> Dropout(0.4) -> softmax over num_classes.
# Keep the layer creation order in this cell unchanged: it determines the
# auto-generated layer names (dense, dense_1, ...) and, presumably, how the
# saved weights are matched back onto the graph.
task_output = Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01))(feature_extractor_output)
task_output = Dropout(0.4)(task_output)
task_output = Dense(num_classes, activation='softmax', name='task_output')(task_output)

# Domain classifier
# Shares feature_extractor_output with the task head; gradient reversal is
# applied inside create_domain_classifier.
domain_output = create_domain_classifier(feature_extractor_output)

# Create the combined model
# One input, two outputs in this order: [task_output, domain_output].
model = Model(inputs=input_layer, outputs=[task_output, domain_output])

# Load the latest weights
# Paths live in one place so the log message below always names the file that
# was actually written.
WEIGHTS_DIR = 'saved_data/models/adversarial-training_medium-masked_mobilenetv3small-finetuned_v2'
WEIGHTS_FILE_PATTERN = '.weights.h5'
FULL_MODEL_PATH = 'saved_data/full_model.keras'

latest_weights_file = load_latest_weights(WEIGHTS_DIR, WEIGHTS_FILE_PATTERN)
if latest_weights_file:
    model.load_weights(latest_weights_file)
# NOTE(review): if no weights file was found, the export below still runs. It
# would then contain only the ImageNet backbone plus freshly initialised adapter
# and heads, and would overwrite any previous export at FULL_MODEL_PATH.

# Save the full model (architecture + weights) in the native Keras format.
model.save(FULL_MODEL_PATH)
# The previous message claimed 'saved_data/full_model.h5', a file that is never written.
print(f"Full model saved as '{FULL_MODEL_PATH}'")