# In [1]:
"""
The main code for the recurrent and convolutional networks assignment.
See README.md for details.
"""
from typing import Tuple, List, Dict

import tensorflow
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, LSTM, Dense, Conv2D, MaxPool2D, Flatten, Embedding, Conv1D, GlobalMaxPool1D, Dropout, BatchNormalization, TimeDistributed
from tensorflow.keras.callbacks import EarlyStopping

def create_toy_rnn(input_shape: tuple, n_outputs: int) \
        -> Tuple[tensorflow.keras.models.Model, Dict]:
    """Build and compile a stacked-LSTM regressor for the toy sequence task.

    Both LSTM layers return full sequences and the dense head is applied per
    timestep, so the network emits one ``n_outputs``-dimensional prediction
    for every input timestep.

    :param input_shape: shape of a single example without the batch axis;
        LSTM layers expect this to be (timesteps, features).
    :param n_outputs: size of the per-timestep output vector.
    :return: a tuple (model, fit_kwargs) where model is compiled with MSE
        loss / MAE metric and fit_kwargs holds extra keyword arguments for
        ``model.fit``.
    """
    layer_stack = [Input(shape=input_shape)]
    # First recurrent stage: wide LSTM, normalized and regularized.
    layer_stack += [LSTM(64, return_sequences=True),
                    BatchNormalization(),
                    Dropout(0.2)]
    # Second recurrent stage, still sequence-to-sequence.
    layer_stack += [LSTM(32, return_sequences=True),
                    BatchNormalization()]
    # Per-timestep dense head ending in an unbounded (linear) regression output.
    layer_stack += [TimeDistributed(Dense(16, activation='relu')),
                    BatchNormalization(),
                    Dropout(0.2),
                    Dense(n_outputs, activation='linear')]
    network = Sequential(layer_stack)

    network.compile(
        optimizer=tensorflow.keras.optimizers.Adam(learning_rate=0.015,
                                                   weight_decay=0.005),
        loss='mse',
        metrics=['mae'],
    )

    # batch_size=1: one example per gradient step.
    # EarlyStopping watches val_loss, so it only acts when fit() is given
    # validation data; when it stops, the best-scoring weights are restored.
    early_stop = EarlyStopping(monitor='val_loss', patience=5,
                               restore_best_weights=True)
    return network, dict(batch_size=1, callbacks=[early_stop])

def create_mnist_cnn(input_shape: tuple, n_outputs: int) \
        -> Tuple[tensorflow.keras.models.Model, Dict]:
    """Build and compile a small CNN classifier for digit images.

    Two conv -> batch-norm -> 2x2 max-pool stages feed a dense head with
    dropout and a softmax over ``n_outputs`` classes.

    :param input_shape: shape of one image without the batch axis; with
        Keras' default channels_last format this is (height, width, channels).
    :param n_outputs: number of classes.
    :return: a tuple (model, fit_kwargs): the compiled model and extra
        keyword arguments for ``model.fit``.
    """
    stack = [Input(shape=input_shape)]
    # Feature extractor: identical stages that differ only in filter count.
    for n_filters in (32, 64):
        stack.extend([
            Conv2D(n_filters, (3, 3), activation='relu'),
            BatchNormalization(),
            MaxPool2D((2, 2)),
        ])
    # Classification head.
    stack.extend([
        Flatten(),
        Dense(128, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),
        Dense(n_outputs, activation='softmax'),
    ])
    classifier = Sequential(stack)

    # categorical_crossentropy expects one-hot targets (integer labels would
    # need sparse_categorical_crossentropy instead).
    classifier.compile(
        optimizer=tensorflow.keras.optimizers.Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    # Monitors val_loss, so it only acts when fit() receives validation data.
    stopper = EarlyStopping(monitor='val_loss', patience=5,
                            restore_best_weights=True)
    return classifier, {'batch_size': 32, 'callbacks': [stopper]}

def create_youtube_comment_rnn(vocabulary: List[str], n_outputs: int) \
        -> Tuple[tensorflow.keras.models.Model, Dict]:
    """Build and compile an LSTM classifier for spam detection on comments.

    Inputs are variable-length sequences of integer token ids. Because the
    embedding uses ``mask_zero=True``, id 0 is treated as padding and masked
    out for the LSTM, which reduces each sequence to a single vector.

    :param vocabulary: token list; only its length is used, as the
        embedding's input dimension (valid ids are 0 .. len(vocabulary) - 1).
    :param n_outputs: number of independent sigmoid outputs.
    :return: a tuple (model, fit_kwargs): the compiled model and extra
        keyword arguments for ``model.fit``.
    """
    n_tokens = len(vocabulary)
    # NOTE(review): mask_zero reserves id 0 for padding; confirm that
    # vocabulary[0] is a padding entry rather than a real token.
    stack = [
        Input(shape=(None,)),
        Embedding(input_dim=n_tokens, output_dim=64, mask_zero=True),
        LSTM(64, return_sequences=False),
    ]
    # Normalize/regularize the sequence summary, then the hidden dense layer.
    stack += [BatchNormalization(), Dropout(0.3)]
    stack += [Dense(32, activation='relu'), BatchNormalization(), Dropout(0.3)]
    stack.append(Dense(n_outputs, activation='sigmoid'))
    spam_model = Sequential(stack)

    adam = tensorflow.keras.optimizers.Adam(learning_rate=0.001, weight_decay=0.01)
    spam_model.compile(optimizer=adam,
                       loss='binary_crossentropy',
                       metrics=['accuracy'])

    # Short patience: stop after 3 epochs without val_loss improvement
    # (requires validation data in fit()).
    callbacks = [EarlyStopping(monitor='val_loss', patience=3,
                               restore_best_weights=True)]
    return spam_model, {'batch_size': 32, 'callbacks': callbacks}

def create_youtube_comment_cnn(vocabulary: List[str], n_outputs: int) \
        -> Tuple[tensorflow.keras.models.Model, Dict]:
    """Build and compile a 1-D convolutional classifier for spam detection.

    Token ids are embedded, scanned with width-5 convolution filters and
    max-pooled over the whole sequence, then passed through a small dense
    head with ``n_outputs`` sigmoid outputs.

    :param vocabulary: token list; only its length is used, as the
        embedding's input dimension (valid ids are 0 .. len(vocabulary) - 1).
    :param n_outputs: number of independent sigmoid outputs.
    :return: a tuple (model, fit_kwargs): the compiled model and extra
        keyword arguments for ``model.fit``.
    """
    vocab_size = len(vocabulary)

    model = Sequential([
        Input(shape=(None,)),
        # NOTE(review): no mask_zero here, so id 0 (if used for padding) is
        # embedded and convolved like any other token.
        Embedding(input_dim=vocab_size, output_dim=32),
        # padding='same' keeps the conv output as long as its input, so a
        # batch whose padded length is shorter than the kernel width (5)
        # still leaves positions for the global max-pool instead of
        # producing an empty/negative-length output.
        Conv1D(64, 5, activation='relu', padding='same'),
        GlobalMaxPool1D(),
        Dense(32, activation='relu'),
        Dense(n_outputs, activation='sigmoid')
    ])

    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    fit_kwargs = {
        'batch_size': 32,
        # restore_best_weights=True matches the other builders in this file.
        # Without it, an early stop leaves the weights from the final epoch
        # run, up to `patience` (10) epochs past the best val_loss.
        'callbacks': [EarlyStopping(monitor='val_loss', patience=10,
                                    restore_best_weights=True)]
    }

    return model, fit_kwargs