In [None]:
# Mount Google Drive so the parquet datasets and training logs stored under
# /content/drive/MyDrive/... are reachable from this Colab runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
import os
import pyarrow.parquet as pq
import pandas as pd
import numpy as np
import re
import tensorflow as tf
import matplotlib.pyplot as plt
from keras.utils import to_categorical

import time

In [None]:
def LoadDataset(file_name):
    """Read one parquet file into a pandas DataFrame.

    The first rows are printed as a quick sanity check of the column layout.

    Args:
        file_name: Path of the parquet file to read.

    Returns:
        The file contents as a pandas DataFrame.
    """
    arrow_table = pq.ParquetFile(file_name).read()
    frame = arrow_table.to_pandas()
    preview = frame.head()
    print(preview)
    return frame

# Dataset selector. 1: SST2, 2: IMDB, 3: Yelp (only option 3 is wired up below).
# NOTE(review): option 3 is labelled Yelp and logs to ".../yelp/cnn_yelp_...", but
# the paths below load the SST-2 parquet files. Confirm which dataset is intended
# before trusting the logged results.
dataset_type = 3
if dataset_type == 3:
    train_data = LoadDataset('/content/drive/MyDrive/Colab Notebooks/data/sst2_process/train.parquet')
    validation_data = LoadDataset('/content/drive/MyDrive/Colab Notebooks/data/sst2_process/test.parquet')
    # Define the path where the CSV file will be saved
    log_dir = '/content/drive/MyDrive/Colab Notebooks/yelp'
    csv_file_path = os.path.join(log_dir, 'cnn_yelp_training1_log.csv')

    # The processed files do not share a schema: the SST-2 files store the text in
    # a "sentence" column, the Yelp files in "text". Hard-coding "text" raised
    # KeyError: 'text' on SST-2, so use whichever column is present.
    text_column = 'text' if 'text' in train_data.columns else 'sentence'

    # The SST-2 test file is unlabelled (its rows carry label -1). to_categorical
    # does not reject -1; it one-hot encodes it as the LAST class, which would make
    # every validation metric meaningless. Fall back to a reproducible 10% hold-out
    # of the labelled training data instead.
    if (validation_data['label'] < 0).any():
        print('Validation file has unlabelled rows (label < 0); '
              'holding out 10% of the training data for validation instead.')
        # Positional split after shuffling, so duplicate index values cannot leak
        # rows into both splits.
        shuffled_train = train_data.sample(frac=1.0, random_state=42)
        n_holdout = len(shuffled_train) // 10
        validation_data = shuffled_train.iloc[:n_holdout]
        train_data = shuffled_train.iloc[n_holdout:]

    # Derive the class count from the 0-based labels instead of hard-coding 5:
    # SST-2 labels are 0/1, so 5 output columns would be wrong there.
    all_labels = pd.concat([train_data['label'], validation_data['label']])
    if all_labels.min() < 0:
        raise ValueError(f'labels must be non-negative, found {all_labels.min()}')
    num_classes = int(all_labels.max()) + 1

    train_texts = train_data[text_column]
    train_labels = to_categorical(train_data['label'], num_classes)
    validation_texts = validation_data[text_column]
    validation_labels = to_categorical(validation_data['label'], num_classes)
else:
    # Without this, an unconfigured choice silently leaves train_texts etc. undefined.
    raise ValueError(f'dataset_type {dataset_type} is not configured in this cell')

   idx                                           sentence  label
0    0          hide new secretion from the parental unit      0
1    1                      contain no wit only labor gag      0
2    2  that love its character and communicates somet...      1
3    3  remains utterly satisfy to remain the same thr...      0
4    4  on the bad revenge of the nerd cliches the fil...      0
   idx                                           sentence  label
0    0                 uneasy mishmash of style and genre     -1
1    1  this film s relationship to actual tension be ...     -1
2    2  by the end of no such thing the audience like ...     -1
3    3  director rob marshall go out gun to make a gre...     -1
4    4  lathan and diggs have considerable personal ch...     -1


KeyError: 'text'

In [None]:
if __name__ == '__main__':
    # Experiment configuration: recurrent cell type ('lstm', 'gru' or 'rnn') and a
    # dataset tag used only to name the training-log file.
    net = 'rnn'
    datset = 'yelp'
    num_classes = 5  # width of the softmax output layer

    # Load the training and test splits.
    data_train = LoadDataset('/content/drive/MyDrive/Colab Notebooks/data/yelp5_process/train.parquet')
    data_test = LoadDataset('/content/drive/MyDrive/Colab Notebooks/data/yelp5_process/test.parquet')

    # Shift labels down only when the training labels really are 1-based. The
    # printed heads of these files show labels 0-4 already; unconditionally
    # subtracting 1 maps 0 to -1, which to_categorical does not reject and encodes
    # as the LAST class, cyclically shifting every label. The same offset is
    # applied to both splits so they stay consistent.
    label_offset = 1 if data_train['label'].min() == 1 else 0
    train_label_ids = data_train['label'] - label_offset
    test_label_ids = data_test['label'] - label_offset
    for split_name, label_ids in (('train', train_label_ids), ('test', test_label_ids)):
        if label_ids.min() < 0 or label_ids.max() >= num_classes:
            raise ValueError(
                f'{split_name} labels must lie in [0, {num_classes}) after subtracting '
                f'{label_offset}; got range [{label_ids.min()}, {label_ids.max()}]'
            )
    data_train_label = to_categorical(train_label_ids, num_classes=num_classes)
    data_test_label = to_categorical(test_label_ids, num_classes=num_classes)

    # Show the container type and columns of each split.
    print("Training data type:", type(data_train))
    print("Training data columns:", data_train.columns)
    print("Test data type:", type(data_test))
    print("Test data columns:", data_test.columns)

    # NOTE(review): preparing_encoder is not defined in any cell visible here, so on
    # a fresh kernel this line raises NameError. The model below calls
    # encoder.get_vocabulary(), so it presumably returns an adapted
    # TextVectorization layer — confirm and define it in a cell above this one.
    encoder = preparing_encoder(training_set=data_train)

    # The architectures differ only in the recurrent layer, so select it from a
    # table instead of repeating the whole Sequential definition per branch.
    recurrent_layers = {
        'lstm': tf.keras.layers.LSTM,
        'gru': tf.keras.layers.GRU,
        'rnn': tf.keras.layers.SimpleRNN,
    }
    if net not in recurrent_layers:
        raise ValueError('Invalid network type')
    model = tf.keras.Sequential([
        encoder,
        tf.keras.layers.Embedding(input_dim=len(encoder.get_vocabulary()), output_dim=64, mask_zero=True),
        recurrent_layers[net](64),
        tf.keras.layers.Dense(num_classes, activation='softmax'),  # multi-class output
    ])

    # Stop when validation accuracy stalls for 5 epochs and keep the best weights.
    earlystopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_accuracy", mode="max", patience=5, restore_best_weights=True
    )

    # Per-epoch metrics go to a timestamped CSV so separate runs never collide.
    log_dir = '/content/drive/MyDrive/Colab Notebooks/yelp'
    os.makedirs(log_dir, exist_ok=True)
    training_start_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
    csv_logger = tf.keras.callbacks.CSVLogger(
        os.path.join(log_dir, f'{net}_{datset}_training_log_{training_start_time}.csv'), append=True
    )

    # Compile and train. from_logits=False because the last layer already applies softmax.
    model.compile(
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
        optimizer=tf.keras.optimizers.Adam(clipvalue=0.5),
        metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall(), tf.keras.metrics.AUC()]
    )

    trained_model = model.fit(
        data_train['text'], data_train_label,
        epochs=10,
        validation_data=(data_test['text'], data_test_label),
        batch_size=32,
        callbacks=[earlystopping, csv_logger]
    )

        label                                               text
177288      0  first of all I m not a big fan of buffet I try...
238756      1  thanks yelp I be look for the word to describe...
604225      2  service be so so they be receive a delivery so...
2838        2  stamoolis brother be one of the strip district...
586957      0  I want to give a star because the service staf...
       label                                               text
33553      4  come a few day ago for a lease wasn t sure of ...
9427       0  I choose the queen for my visit to las vegas f...
199        3  I go here on the day of a wedding I m from out...
12447      1  isn t it strange how the little thing can sour...
39489      4  visit here several time a year the food be alw...
Training data type: <class 'pandas.core.frame.DataFrame'>
Training data columns: Index(['label', 'text'], dtype='object')
Test data type: <class 'pandas.core.frame.DataFrame'>
Test data columns: Index(['label', 'text'], dtype=