<a href="https://colab.research.google.com/github/OmidGhadami95/metaphor-detection-cnn-lstm/blob/main/CNN_LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import sys
import pandas as pd
import string
import re
import nltk
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer, PorterStemmer
from nltk.corpus import wordnet
from nltk import pos_tag
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score, classification_report, precision_score, recall_score
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, LSTM, Conv1D, GlobalMaxPooling1D, Concatenate, Add, LayerNormalization, Dropout, Reshape, Embedding, Bidirectional
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.callbacks import ReduceLROnPlateau, Callback
from tensorflow.keras import backend as K
import pickle

# Download the NLTK data packages this script relies on. The first three are
# fetched silently; the last two use NLTK's default (non-quiet) output.
for _package, _quiet in (
    ('punkt', True),
    ('averaged_perceptron_tagger', True),
    ('wordnet', True),
    ('punkt_tab', False),
    ('averaged_perceptron_tagger_eng', False),
):
    nltk.download(_package, quiet=_quiet)

def get_wordnet_pos(treebank_tag):
    """Translate a treebank-style POS tag into a WordNet POS constant.

    Tags beginning with ``J``, ``V``, ``N`` or ``R`` map to adjective, verb,
    noun and adverb respectively; every other tag falls back to noun.
    """
    prefix_table = (
        ('J', wordnet.ADJ),
        ('V', wordnet.VERB),
        ('N', wordnet.NOUN),
        ('R', wordnet.ADV),
    )
    for prefix, wn_pos in prefix_table:
        if treebank_tag.startswith(prefix):
            return wn_pos
    # Unknown tag families are treated as nouns.
    return wordnet.NOUN

# Translation table that deletes every ASCII punctuation character. It never
# changes, so it is built once instead of on every call.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)


def preprocess_text(text):
    """Normalise one raw text into a space-joined string of stemmed lemmas.

    Steps: strip ASCII punctuation, lowercase, tokenize, POS-tag, lemmatize
    each token using its POS, then Porter-stem each lemma.

    Args:
        text: The raw text. ``None`` and float NaN (what pandas yields for a
            missing CSV cell) are treated as empty; any other non-string
            value is converted with ``str()``.

    Returns:
        The processed tokens joined by single spaces, or ``''`` when nothing
        remains after cleaning.
    """
    # NaN is the only float that is not equal to itself.
    if text is None or (isinstance(text, float) and text != text):
        return ''
    if not isinstance(text, str):
        text = str(text)

    text = text.translate(_PUNCTUATION_TABLE).lower()
    if not text.strip():
        # Nothing to tokenize; the result would be '' anyway.
        return ''

    tokens = word_tokenize(text)
    pos_tags = pos_tag(tokens)
    lemmatizer = WordNetLemmatizer()
    lemmatized = [lemmatizer.lemmatize(word, get_wordnet_pos(pos)) for word, pos in pos_tags]
    stemmer = PorterStemmer()
    return ' '.join(stemmer.stem(word) for word in lemmatized)

def create_complex_model(input_shape, num_classes, max_words):
    """Assemble the uncompiled CNN / residual-dense / BiLSTM classifier.

    Args:
        input_shape: Shape tuple handed to the Keras ``Input`` layer.
        num_classes: Width of the final softmax layer.
        max_words: Vocabulary size used as the embedding's input dimension.

    Returns:
        A Keras ``Model`` mapping the input to ``num_classes`` softmax scores.
    """
    token_ids = Input(shape=input_shape)
    embedded = Embedding(max_words, 128)(token_ids)

    # Three parallel convolutions over the same embeddings, kernel sizes 5, 3, 7.
    conv_branches = [
        Conv1D(128, kernel_size, activation='relu', padding='same')(embedded)
        for kernel_size in (5, 3, 7)
    ]
    merged = Concatenate()(conv_branches)
    features = GlobalMaxPooling1D()(merged)

    # Five dense blocks. The first one changes the feature width to 256, so
    # only blocks 1-4 add their input back as a skip connection.
    for block_index in range(5):
        shortcut = features
        features = Dense(256, activation='relu')(features)
        features = LayerNormalization()(features)
        features = Dropout(0.3)(features)
        features = Dense(256, activation='relu')(features)
        features = LayerNormalization()(features)
        if block_index > 0:
            features = Add()([features, shortcut])

    # Present the feature vector to the recurrent layers as a one-step sequence.
    sequence = Reshape((1, -1))(features)
    sequence = Bidirectional(LSTM(128, return_sequences=True))(sequence)
    encoded = Bidirectional(LSTM(64))(sequence)

    head = Dense(64, activation='relu')(encoded)
    head = LayerNormalization()(head)
    head = Dropout(0.3)(head)
    probabilities = Dense(num_classes, activation='softmax')(head)

    return Model(inputs=token_ids, outputs=probabilities)

class AccuracyThresholdCallback(Callback):
    """Stop training as soon as validation accuracy exceeds a threshold.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, threshold):
        """Store the ``val_accuracy`` value that must be exceeded to stop."""
        super().__init__()
        self.threshold = threshold
        # 1-based epoch at which training was stopped; stays 0 if the
        # threshold was never exceeded.
        self.stopped_epoch = 0

    def on_epoch_end(self, epoch, logs=None):
        """Request a stop when this epoch's ``val_accuracy`` beats the threshold.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            logs: Metric dict for the epoch; may be ``None`` or lack
                ``val_accuracy`` (e.g. when no validation data is used), in
                which case the callback does nothing.
        """
        val_accuracy = (logs or {}).get('val_accuracy')
        if val_accuracy is None:
            # Nothing to compare against; comparing None with a float would raise.
            return
        if val_accuracy > self.threshold:
            self.stopped_epoch = epoch + 1
            self.model.stop_training = True

def f1_macro(y_true, y_pred):
    """Batch-wise macro-averaged F1 score for use as a Keras metric.

    Precision, recall and F1 are computed separately for each class (summing
    over the batch axis) and the per-class F1 values are then averaged.
    Summing the counts over all classes first, as this function previously
    did, yields the micro average instead.

    Args:
        y_true: One-hot targets; assumed to be (batch, num_classes), matching
            the ``to_categorical`` labels used for training.
        y_pred: Predicted class probabilities with the same shape.

    Returns:
        A scalar tensor with the mean per-class F1 for the batch. A class
        with no true or predicted samples in the batch contributes 0.
    """
    # Predictions are hardened by rounding, i.e. thresholded at 0.5.
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)), axis=0)
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)), axis=0)
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)), axis=0)

    # K.epsilon() keeps the divisions finite when a count is zero.
    class_precision = true_positives / (predicted_positives + K.epsilon())
    class_recall = true_positives / (possible_positives + K.epsilon())
    class_f1 = 2 * ((class_precision * class_recall)
                    / (class_precision + class_recall + K.epsilon()))
    return K.mean(class_f1)

def calculate_metrics(model, X_test_pad, y_test, y_test_cat):
    """Score ``model`` on a labelled set.

    Args:
        model: Object whose ``predict`` returns per-class scores per sample.
        X_test_pad: Inputs passed straight to ``model.predict``.
        y_test: Integer class labels aligned with ``X_test_pad``.
        y_test_cat: Accepted for interface compatibility; not used.

    Returns:
        Tuple ``(accuracy, macro_f1, macro_precision, macro_recall)``.
    """
    # The predicted class is the index of the highest score in each row.
    predicted_labels = np.argmax(model.predict(X_test_pad), axis=1)

    accuracy = np.mean(predicted_labels == y_test)
    macro_f1, macro_precision, macro_recall = (
        scorer(y_test, predicted_labels, average='macro')
        for scorer in (f1_score, precision_score, recall_score)
    )
    return accuracy, macro_f1, macro_precision, macro_recall

def main(input_file):
    """Train the classifier on a CSV dataset and pickle the fitted model.

    The CSV must provide a ``text`` column and a ``label`` column. Rows are
    de-duplicated, texts are preprocessed, labels are integer-encoded, and
    the data is split 80/20 into train and test sets. The model trains with
    a further 20% of the training data held out for validation, and the
    printed metrics are computed on the 20% test split.

    Args:
        input_file: Path to the CSV file to train on.

    Side effects:
        Prints the final metrics and writes ``saved_model.pkl`` to the
        current working directory.
    """
    df = pd.read_csv(input_file)
    df = df.drop_duplicates()
    df['processed_text'] = df['text'].apply(preprocess_text)
    # Drop rows whose text vanished entirely during preprocessing.
    df = df[df['processed_text'].str.strip() != '']
    df = df.reset_index(drop=True)

    # Boolean labels are turned into strings before encoding.
    if df['label'].dtype == bool:
        df['label'] = df['label'].map({True: 'True', False: 'False'})

    le = LabelEncoder()
    df['encoded_label'] = le.fit_transform(df['label'])

    X = df['processed_text']
    y = df['encoded_label']

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    max_words = 20000
    max_len = 300
    # The vocabulary is fitted on the training texts only.
    tokenizer = Tokenizer(num_words=max_words)
    tokenizer.fit_on_texts(X_train)
    X_train_seq = tokenizer.texts_to_sequences(X_train)
    X_test_seq = tokenizer.texts_to_sequences(X_test)
    X_train_pad = pad_sequences(X_train_seq, maxlen=max_len)
    X_test_pad = pad_sequences(X_test_seq, maxlen=max_len)

    num_classes = len(le.classes_)
    y_train_cat = to_categorical(y_train, num_classes)
    y_test_cat = to_categorical(y_test, num_classes)

    model = create_complex_model((max_len,), num_classes, max_words)
    model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy', f1_macro])

    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.8, patience=5, min_lr=0.0001)
    accuracy_threshold = AccuracyThresholdCallback(threshold=0.83)

    history = model.fit(X_train_pad, y_train_cat, epochs=100, batch_size=32, validation_split=0.2,
                        callbacks=[reduce_lr, accuracy_threshold], verbose=0)

    # Calculate and print final metrics (on the held-out test split).
    val_accuracy, val_f1_macro, val_precision, val_recall = calculate_metrics(model, X_test_pad, y_test, y_test_cat)

    # The callback leaves stopped_epoch at 0 when the accuracy threshold is
    # never reached; in that case report how many epochs actually ran.
    epochs_run = accuracy_threshold.stopped_epoch or len(history.history['loss'])

    print(f"Training stopped at epoch {epochs_run}")
    print(f"Final Validation Accuracy: {val_accuracy:.4f}")
    print(f"Final Validation F1 Macro: {val_f1_macro:.4f}")
    print(f"Final Validation Precision: {val_precision:.4f}")
    print(f"Final Validation Recall: {val_recall:.4f}")

    # Save the model
    # NOTE(review): only the model is written here; the fitted tokenizer and
    # label encoder are not persisted by this function. Pickle files must
    # only ever be loaded from trusted sources.
    with open('saved_model.pkl', 'wb') as f:
        pickle.dump(model, f)

# if __name__ == "__main__":
#     if len(sys.argv) != 2:
#         print("Usage: python3 run_train.py <input_file>")
#         sys.exit(1)

#     input_file = sys.argv[1]
#     main(input_file)

if __name__ == "__main__":
    # Replace with the path to your dataset
    main('train-1.csv')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


[1m12/12[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 85ms/step
Training stopped at epoch 45
Final Validation Accuracy: 0.7936
Final Validation F1 Macro: 0.6923
Final Validation Precision: 0.7136
Final Validation Recall: 0.6793
