# Baseline model

### Original file

Adapted from the baseline in OpenAI's GPT-2 output dataset repository:
https://github.com/openai/gpt-2-output-dataset/blob/master/baseline.py

## Before running this notebook

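1. Place the GPT dataset CSV files (one per topic, e.g. `culture.csv`) in a `data/` folder next to this notebook. They are read as UTF-8.
1. Create an `output/` folder next to this notebook.
   1. Put all crawled datasets (CSV files) in it. They are read as CP949.
   1. Rename each crawled file to match the corresponding GPT dataset file name (e.g. `culture.csv`).
1. Create a `log/` folder.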

In [9]:
# import packages

import os
import csv

from scipy import sparse

from sklearn.model_selection import PredefinedSplit, GridSearchCV, train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer

In [10]:
# load data

def load_data(data_dir, crawled_dir, source):
    """Load paired GPT-dataset and crawled texts for one topic.

    Reads ``<data_dir>/<source>.csv`` (UTF-8) and ``<crawled_dir>/<source>.csv``
    (CP949). For every row of the GPT file, column 0 is a row index into the
    crawled file. The GPT text (column 5) and the matching crawled text
    (column 4) are appended as a consecutive pair.

    Args:
        data_dir: Directory containing the GPT dataset CSVs.
        crawled_dir: Directory containing the crawled dataset CSVs.
        source: Topic name, used as the CSV file stem in both directories.

    Returns:
        A ``(texts, labels)`` tuple. ``texts`` alternates GPT text / crawled
        text, and ``labels`` alternates 1 / 0 to match (1 = GPT dataset,
        0 = crawled dataset).

    Raises:
        ValueError: if column 0 of a GPT row is not an integer.
        IndexError: if a GPT row's index is outside the crawled file.
    """
    path = os.path.join(data_dir, "{}.csv".format(source))
    crawled_path = os.path.join(crawled_dir, "{}.csv".format(source))

    # Context managers close both file handles deterministically. newline=""
    # is what the csv module expects, so quoted fields that contain line
    # breaks are parsed correctly.
    with open(path, encoding="utf8", newline="") as gpt_file:
        dataset = list(csv.reader(gpt_file))
    with open(crawled_path, encoding="cp949", newline="") as crawled_file:
        crawled_dataset = list(csv.reader(crawled_file))

    texts = []
    for data in dataset:
        idx = int(data[0])
        texts.append(data[5])  # GPT dataset text
        texts.append(crawled_dataset[idx][4])  # crawled text for the same index
    # One (1, 0) label pair per GPT row, matching the order texts were appended.
    labels = [1, 0] * len(dataset)
    return texts, labels

In [11]:
# main function

def main(
    data_dir="data/",
    crawl_dir="output/",
    topics=("culture", "economy", "it_science", "politics", "society", "world"),
    train_test_ratio=0.1,
    random_state=42,
):
    """Train and evaluate a TF-IDF + logistic-regression classifier.

    The classifier separates GPT-dataset texts (label 1) from crawled texts
    (label 0), pooled over all requested topics.

    Args:
        data_dir: Directory with the GPT dataset CSVs (one per topic).
        crawl_dir: Directory with the crawled dataset CSVs (one per topic).
        topics: Iterable of topic names to load and pool together. The default
            is a tuple rather than a list, to avoid a shared mutable default.
        train_test_ratio: Fraction of samples held out for testing. It is
            passed to ``train_test_split`` as ``test_size``.
        random_state: Seed for the train/test shuffle, for reproducibility.

    Prints a dict with the test accuracy in percent.
    """
    texts_list, labels_list = [], []
    for topic in topics:
        texts, labels = load_data(data_dir, crawl_dir, topic)
        texts_list.extend(texts)
        labels_list.extend(labels)

    texts_train, texts_test, labels_train, labels_test = train_test_split(
        texts_list,
        labels_list,
        test_size=train_test_ratio,
        random_state=random_state,
        shuffle=True,
    )

    # Fit the vocabulary/IDF on the training split only, so no test-set
    # statistics leak into the features.
    vect = TfidfVectorizer()
    train_features = vect.fit_transform(texts_train)
    test_features = vect.transform(texts_test)

    model = LogisticRegression()
    model.fit(train_features, labels_train)
    test_accuracy = model.score(test_features, labels_test) * 100.0
    metrics = {
        "test_accuracy": test_accuracy,
    }
    print(metrics)

In [12]:
# Run the experiment on a subset of topics ("politics" and "world" are left out).
SELECTED_TOPICS = ["culture", "economy", "it_science", "society"]
main(topics=SELECTED_TOPICS)

{'test_accuracy': 88.5}
