# Baseline model

### original file

https://github.com/openai/gpt-2-output-dataset/blob/master/baseline.py

## Before running this notebook

1. Create /output folder
   1. Insert all crawled dataset(csv)
   1. Rename them as same with GPT dataset files.
1. Create /log folder

In [27]:
# import packages

import os
import csv
import json

import numpy as np

from scipy import sparse

from sklearn.model_selection import PredefinedSplit, GridSearchCV, train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer

In [28]:
# create tokenizer
# example code from https://github.com/SKT-AI/KoGPT2

from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast.from_pretrained(
    "skt/kogpt2-base-v2",
    bos_token="</s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<mask>",
)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


In [29]:
# load data + preprocessing

def load_data(data_dir, crawled_dir, source):
    path = os.path.join(data_dir, "{}.csv".format(source))
    crawled_path = os.path.join(crawled_dir, "{}.csv".format(source))
    dataset = list(csv.reader(open(path, encoding="utf8")))
    crawled_dataset = list(csv.reader(open(crawled_path, encoding="cp949")))
    n = len(dataset)

    texts = []
    labels = [1, 0] * n

    for data in dataset:
        idx = int(data[0])
        tokens = tokenizer.tokenize(data[5])
        texts.append(' '.join(tokens))
        tokens = tokenizer.tokenize(crawled_dataset[idx][4])
        texts.append(' '.join(tokens))
    return texts, labels

In [30]:
# main function

def main(
    data_dir="data/",
    crawl_dir="output/",
    log_dir="log/",
    topics=["culture", "economy", "it_science", "politics", "society", "world"],
    train_test_ratio=0.1,
):
    texts_list, labels_list = [], []
    for topic in topics:
        texts, labels = load_data(data_dir, crawl_dir, topic)
        texts_list.extend(texts)
        labels_list.extend(labels)

    texts_train, texts_test, labels_train, labels_test = train_test_split(
        texts_list, labels_list, test_size=train_test_ratio, random_state=42, shuffle=True
    )

    vect = TfidfVectorizer()
    train_features = vect.fit_transform(texts_train)
    test_features = vect.transform(texts_test)

    model = LogisticRegression()
    model.fit(train_features, labels_train)
    test_accuracy = model.score(test_features, labels_test) * 100.0
    result_proba = model.predict_proba(test_features)
    result_log_proba = model.predict_log_proba(test_features)
    ce_loss = 0
    for label, value in zip(labels_test, result_log_proba):
        ce_loss -= label * value[1] + (1 - label) * value[0]
    ce_loss /= len(labels_test)
    data = {
        "test_accuracy": test_accuracy,
        "mse_loss": np.sum(np.array(labels_test) - model.predict_proba(test_features)[:, 1]) ** 2 / len(labels_test),
        "ce_loss": ce_loss,
        "label_and_result": list(zip(labels_test, model.predict(test_features).tolist(), result_proba.tolist())),
    }
    print(data)
    json.dump(data, open(os.path.join(log_dir, "result.json"), "w"), indent=4)

In [31]:
# run main function

main(topics=["culture", "economy", "it_science", "society"])

{'test_accuracy': 89.0, 'mse_loss': 0.06409735769225443, 'ce_loss': 0.4115393467022112, 'label_and_result': [(1, 1, [0.10611377796899635, 0.8938862220310037]), (0, 0, [0.7001699268963859, 0.29983007310361404]), (0, 0, [0.9373373834419062, 0.06266261655809384]), (0, 1, [0.48952946641272266, 0.5104705335872773]), (0, 0, [0.6970723336025466, 0.3029276663974534]), (0, 0, [0.8652607930036528, 0.1347392069963472]), (1, 1, [0.12125276108672156, 0.8787472389132784]), (0, 0, [0.7525794275363951, 0.24742057246360494]), (0, 0, [0.6680837107942625, 0.3319162892057374]), (0, 0, [0.8638836627478211, 0.1361163372521789]), (1, 1, [0.2769855907662836, 0.7230144092337164]), (1, 1, [0.1680113579910617, 0.8319886420089383]), (1, 1, [0.14080138402764975, 0.8591986159723503]), (1, 1, [0.27633333394205795, 0.723666666057942]), (1, 1, [0.4082642457485335, 0.5917357542514665]), (0, 0, [0.8449242298382185, 0.15507577016178148]), (1, 1, [0.27856944688494467, 0.7214305531150553]), (1, 1, [0.3228210987730924, 0.67