In [None]:

# resnet50.py

import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.layers import (Conv2D, BatchNormalization, Activation, 
                                    MaxPooling2D, Add, GlobalAveragePooling2D, 
                                    Flatten, Dense, Input)
import numpy as np
import os
import datetime
from sklearn.preprocessing import LabelEncoder
from utils.metrics import calculate_metrics, print_metrics_summary
from utils.visualization import save_visualizations

class ResNet50:
    """Keras ResNet50 classifier with training, persistence and evaluation helpers.

    Labels passed to ``train``/``evaluate`` may be arbitrary (e.g. strings);
    they are mapped to integer class ids with a scikit-learn ``LabelEncoder``
    fitted during ``train``.
    """

    def __init__(self, input_shape, num_classes, model_dir="saved_models"):
        """Create output directories and build the network.

        Args:
            input_shape: Shape of one input sample, without the batch axis.
            num_classes: Number of output classes (size of the softmax layer).
            model_dir: Directory for checkpoints, final models, metrics and logs.
        """
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model_dir = model_dir
        self.log_dir = os.path.join(model_dir, "logs")
        # Set by train(); evaluate() needs it to encode ground-truth labels.
        self.encoder = None
        os.makedirs(self.model_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)
        self.model = self._build_model()

    def _residual_block(self, X_start, filters, name, reduce=False, res_conv2d=False):
        """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convs plus a shortcut.

        Args:
            X_start: Input tensor.
            filters: ``(f1, f2, f3)`` filter counts for the three convolutions.
            name: Prefix used for all named layers of the block.
            reduce: If True, the first 1x1 conv (and the projection shortcut)
                use stride 2, halving the spatial resolution.
            res_conv2d: If True, the shortcut is a 1x1 conv + BN projection.
                Must be True whenever ``reduce`` is True or the input channel
                count differs from ``f3``, otherwise the ``Add`` shapes mismatch.

        Returns:
            The ReLU-activated sum of the main path and the shortcut.
        """
        nb_filters_1, nb_filters_2, nb_filters_3 = filters
        strides_1 = [2, 2] if reduce else [1, 1]

        X = Conv2D(filters=nb_filters_1, kernel_size=[1, 1], strides=strides_1,
                   padding='same', name=f'{name}_conv1')(X_start)
        X = BatchNormalization(name=f'{name}_bn1')(X)
        X = Activation('relu')(X)

        X = Conv2D(filters=nb_filters_2, kernel_size=[3, 3], strides=[1, 1],
                   padding='same', name=f'{name}_conv2')(X)
        X = BatchNormalization(name=f'{name}_bn2')(X)
        X = Activation('relu')(X)

        X = Conv2D(filters=nb_filters_3, kernel_size=[1, 1], strides=[1, 1],
                   padding='same', name=f'{name}_conv3')(X)
        X = BatchNormalization(name=f'{name}_bn3')(X)

        if res_conv2d:
            # Projection shortcut so the shortcut matches the main path's shape.
            X_res = Conv2D(filters=nb_filters_3, kernel_size=[1, 1], strides=strides_1,
                           padding='same', name=f'{name}_conv_res')(X_start)
            X_res = BatchNormalization(name=f'{name}_bn_res')(X_res)
        else:
            X_res = X_start

        X = Add(name=f'{name}_add')([X, X_res])
        return Activation('relu', name=f'{name}_relu')(X)

    def _build_model(self):
        """Build the ResNet50 graph: stem, four bottleneck stages (3-4-6-3), softmax head.

        Returns:
            An uncompiled ``tf.keras.Model`` mapping ``input_shape`` inputs to
            ``num_classes`` softmax probabilities.
        """
        X_input = Input(shape=self.input_shape, name='input')

        # Stem: 7x7/2 convolution followed by a 3x3/2 max-pool.
        X = Conv2D(filters=64, kernel_size=[7, 7], strides=[2, 2],
                   padding='same', name='conv1')(X_input)
        X = BatchNormalization(name='bn_conv1')(X)
        X = Activation('relu', name='conv1_relu')(X)
        X = MaxPooling2D(pool_size=[3, 3], strides=[2, 2],
                         padding='same', name='pool1')(X)

        # (bottleneck filters, number of blocks, downsample in first block)
        stages = [
            ([64, 64, 256], 3, False),    # stage 2: already downsampled by pool1
            ([128, 128, 512], 4, True),
            ([256, 256, 1024], 6, True),
            ([512, 512, 2048], 3, True),
        ]
        for stage_idx, (filters, n_blocks, downsample) in enumerate(stages, start=2):
            for block_idx in range(n_blocks):
                is_first = block_idx == 0
                # The first block of each stage changes channels (and maybe
                # resolution), so it needs a projection shortcut.
                X = self._residual_block(
                    X, filters,
                    name=f'stage{stage_idx}_block{block_idx + 1}',
                    reduce=downsample and is_first,
                    res_conv2d=is_first,
                )

        X = GlobalAveragePooling2D(name='avg_pool')(X)
        # Softmax probabilities pair with the non-logit
        # sparse_categorical_crossentropy loss used in train().
        outputs = Dense(self.num_classes, activation='softmax', name='predictions')(X)
        return Model(inputs=X_input, outputs=outputs, name='ResNet50')

    def _get_callbacks(self, model_name):
        """Create the training callbacks.

        Returns checkpointing on best ``val_accuracy``, TensorBoard logging,
        early stopping on ``val_loss`` (patience 10, best weights restored) and
        LR halving on a ``val_loss`` plateau (patience 3).
        """
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        checkpoint_path = os.path.join(self.model_dir, f"{model_name}_best_{timestamp}.h5")
        checkpoint = tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_path,
            monitor='val_accuracy',
            save_best_only=True,
            save_weights_only=False,
            mode='max',
            verbose=1
        )
        tb_log_dir = os.path.join(self.log_dir, model_name, timestamp)
        tensorboard = tf.keras.callbacks.TensorBoard(
            log_dir=tb_log_dir,
            histogram_freq=1,
            update_freq='epoch'
        )
        early_stop = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True
        )
        lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=3,
            verbose=1
        )
        return [checkpoint, tensorboard, early_stop, lr_scheduler]

    def train(self, train_x, train_y, valid_x, valid_y,
              batch_size=32, epochs=100, learning_rate=1e-4,
              weight_decay=1e-4, model_name="ResNet50"):
        """Fit a label encoder, compile with AdamW, train, and save the final model.

        Args:
            train_x, valid_x: Model inputs for training / validation.
            train_y, valid_y: Raw labels; ``valid_y`` must only contain labels
                seen in ``train_y`` (``LabelEncoder.transform`` raises otherwise).
            batch_size, epochs: Passed to ``Model.fit``.
            learning_rate, weight_decay: AdamW hyper-parameters.
            model_name: Prefix for checkpoint, log and saved-model names.

        Returns:
            The Keras ``History`` object returned by ``fit``.
        """
        self.encoder = LabelEncoder()
        train_y_enc = self.encoder.fit_transform(train_y)
        valid_y_enc = self.encoder.transform(valid_y)
        callbacks = self._get_callbacks(model_name)
        optimizer = tf.keras.optimizers.AdamW(
            learning_rate=learning_rate,
            weight_decay=weight_decay
        )
        self.model.compile(
            optimizer=optimizer,
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"]
        )
        history = self.model.fit(
            x=train_x,
            y=train_y_enc,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=(valid_x, valid_y_enc),
            callbacks=callbacks
        )
        self.save_model(model_name)
        return history

    def save_model(self, model_name, save_format='h5'):
        """Save the current model under ``model_dir`` with a timestamped name.

        Args:
            model_name: File-name prefix.
            save_format: ``'h5'`` writes a single ``.h5`` file; any other value
                passes an extension-less path to ``Model.save``.

        Returns:
            The path the model was saved to.
        """
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        suffix = ".h5" if save_format == 'h5' else ""
        save_path = os.path.join(self.model_dir, f"{model_name}_final_{timestamp}{suffix}")
        self.model.save(save_path)
        print(f"Model saved to {save_path}")
        return save_path

    def load_model(self, model_path):
        """Load a saved Keras model and warn if its I/O shapes differ from this config.

        Note: the label encoder is not restored; ``evaluate`` still requires a
        prior call to ``train`` (or ``self.encoder`` set manually).

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"No model found at {model_path}")
        self.model = tf.keras.models.load_model(model_path)
        # Normalise to tuples: Keras reports shapes as tuples, while the
        # configured input_shape may have been given as a list.
        loaded_input_shape = tuple(self.model.input_shape[1:])
        if loaded_input_shape != tuple(self.input_shape):
            print(f"Warning: Loaded model input shape {loaded_input_shape} "
                  f"doesn't match expected {tuple(self.input_shape)}")
        if self.model.output_shape[-1] != self.num_classes:
            print(f"Warning: Loaded model output shape {self.model.output_shape[-1]} "
                  f"doesn't match expected {self.num_classes}")
        print(f"Successfully loaded model from {model_path}")
        return self.model

    def evaluate(self, test_x, test_y, model_name="ResNet50", class_names=None):
        """Predict on test data, compute/print metrics, save plots and a metrics file.

        Args:
            test_x: Model inputs.
            test_y: Raw labels, encoded with the encoder fitted in ``train``.
            model_name: Name used in reports and output file names.
            class_names: Optional display names forwarded to the visualizations.

        Returns:
            The metrics dict produced by ``calculate_metrics``.

        Raises:
            RuntimeError: If no label encoder is available (``train`` not called).
        """
        if getattr(self, "encoder", None) is None:
            raise RuntimeError(
                "No label encoder available: call train() before evaluate(), "
                "or assign a fitted LabelEncoder to self.encoder."
            )
        test_y_enc = self.encoder.transform(test_y)
        # The head outputs per-class probabilities; take the most likely class.
        y_pred = np.argmax(self.model.predict(test_x), axis=1)
        metrics = calculate_metrics(test_y_enc, y_pred, model_name)
        print_metrics_summary(metrics)
        save_visualizations(
            model=self.model,
            x_data=test_x,
            y_true=test_y_enc,
            y_pred=y_pred,
            model_name=model_name,
            class_names=class_names
        )
        self._save_evaluation_metrics(metrics, model_name)
        return metrics

    @staticmethod
    def _format_metric_value(value, ndigits=4):
        """Round numbers (recursively inside arrays/lists/tuples) for text output.

        NumPy arrays and scalars are converted to plain Python objects so they
        print without NumPy reprs; non-numeric values are returned unchanged.
        """
        if isinstance(value, np.ndarray):
            value = value.tolist()
        elif isinstance(value, np.generic):
            value = value.item()
        if isinstance(value, (list, tuple)):
            return [ResNet50._format_metric_value(v, ndigits) for v in value]
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return round(value, ndigits)
        return value

    def _save_evaluation_metrics(self, metrics, model_name):
        """Write ``key: value`` lines (numbers rounded to 4 places) to a timestamped file.

        Returns:
            The path of the written metrics file.
        """
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        metrics_path = os.path.join(self.model_dir, f"{model_name}_metrics_{timestamp}.txt")
        with open(metrics_path, 'w', encoding='utf-8') as f:
            for key, value in metrics.items():
                f.write(f"{key}: {self._format_metric_value(value)}\n")
        print(f"Metrics saved to {metrics_path}")
        return metrics_path
