In [1]:
import pandas as pd
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
from torch import cuda
import re
import html

device = 'cuda' if cuda.is_available() else 'cpu'
print(device)

  from .autonotebook import tqdm as notebook_tqdm


cuda


# Settings

In [2]:
MAX_LEN = 512
BATCH_SIZE = 4
EPOCHS = 1
LEARNING_RATE = 1e-05

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased', do_lower_case=True)
model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-cased').to(device)

Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['pre_classifier.bias', 'pre_classifier.weight', 'classifier

# Preprocessing

In [3]:
train = pd.read_csv("../data/preprocessed/fakes_train.csv", index_col=0, encoding='utf-8', engine='python')
validation = pd.read_csv("../data/preprocessed/fakes_validation.csv", index_col=0, encoding='utf-8', engine='python')

sep_token = '<\|reply\|>'

def regex_text(text):
    text = html.unescape(text)
    text = re.sub(r"\#\'", r"'", text)
    text = re.sub(r"\s+$", '', text)    
    return re.findall(sep_token + " (,?.*)", text)

def label_to_list(label):
    if label:
        return [1]
    else:
        return [0]

def clean_dataframe(df):
    df = df[df['text'].str.contains(sep_token)]
    df['text'] = df['text'].apply(regex_text)
    df = df[df['text'].str.len() != 0]
    df['text'] = df['text'].apply(lambda x: x[0])
    df['label'] = df['label'].apply(label_to_list)
    return df

validation = clean_dataframe(validation)
train = clean_dataframe(train)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['text'] = df['text'].apply(regex_text)


In [4]:
dataset = dict()
dataset['validation'] = Dataset.from_pandas(validation, preserve_index=False)
dataset['train'] = Dataset.from_pandas(train, preserve_index=False)
datasets = DatasetDict(dataset)

In [5]:
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding=True, max_length=MAX_LEN)

tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=1,
    remove_columns=["text"],
    )

                                                                      

# Training

In [6]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

training_args = TrainingArguments(
    output_dir="../output/bert_discriminator",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    save_steps=10000,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
)

trainer.train()

  0%|          | 0/30823 [00:00<?, ?it/s]You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  2%|▏         | 500/30823 [03:04<2:46:42,  3.03it/s]

{'loss': 0.4571, 'learning_rate': 1.967556694676054e-05, 'epoch': 0.02}


  3%|▎         | 1000/30823 [06:03<2:39:22,  3.12it/s]

{'loss': 0.4036, 'learning_rate': 1.9351133893521072e-05, 'epoch': 0.03}


  5%|▍         | 1500/30823 [09:03<2:15:17,  3.61it/s]

{'loss': 0.3722, 'learning_rate': 1.902670084028161e-05, 'epoch': 0.05}


  6%|▋         | 2000/30823 [11:27<1:53:54,  4.22it/s]

{'loss': 0.3485, 'learning_rate': 1.8702267787042146e-05, 'epoch': 0.06}


  8%|▊         | 2500/30823 [13:50<2:16:17,  3.46it/s]

{'loss': 0.3035, 'learning_rate': 1.8377834733802682e-05, 'epoch': 0.08}


 10%|▉         | 3000/30823 [16:12<2:04:55,  3.71it/s]

{'loss': 0.3485, 'learning_rate': 1.8053401680563216e-05, 'epoch': 0.1}


 11%|█▏        | 3500/30823 [18:33<2:02:39,  3.71it/s]

{'loss': 0.3229, 'learning_rate': 1.7728968627323753e-05, 'epoch': 0.11}


 13%|█▎        | 4000/30823 [20:56<2:21:25,  3.16it/s]

{'loss': 0.2961, 'learning_rate': 1.740453557408429e-05, 'epoch': 0.13}


 15%|█▍        | 4500/30823 [23:18<2:04:34,  3.52it/s]

{'loss': 0.2842, 'learning_rate': 1.7080102520844826e-05, 'epoch': 0.15}


 16%|█▌        | 5000/30823 [25:42<1:58:06,  3.64it/s]

{'loss': 0.3009, 'learning_rate': 1.675566946760536e-05, 'epoch': 0.16}


 18%|█▊        | 5500/30823 [28:05<2:07:43,  3.30it/s]

{'loss': 0.3169, 'learning_rate': 1.6431236414365897e-05, 'epoch': 0.18}


 19%|█▉        | 6000/30823 [30:28<1:46:43,  3.88it/s]

{'loss': 0.3069, 'learning_rate': 1.6106803361126434e-05, 'epoch': 0.19}


 21%|██        | 6500/30823 [32:52<2:01:08,  3.35it/s]

{'loss': 0.2965, 'learning_rate': 1.578237030788697e-05, 'epoch': 0.21}


 23%|██▎       | 7000/30823 [35:14<1:47:12,  3.70it/s]

{'loss': 0.3083, 'learning_rate': 1.5457937254647504e-05, 'epoch': 0.23}


 24%|██▍       | 7500/30823 [37:36<2:02:58,  3.16it/s]

{'loss': 0.298, 'learning_rate': 1.513350420140804e-05, 'epoch': 0.24}


 26%|██▌       | 8000/30823 [39:58<1:52:22,  3.38it/s]

{'loss': 0.3036, 'learning_rate': 1.4809071148168578e-05, 'epoch': 0.26}


 28%|██▊       | 8500/30823 [42:21<1:42:41,  3.62it/s]

{'loss': 0.3086, 'learning_rate': 1.4484638094929113e-05, 'epoch': 0.28}


 29%|██▉       | 9000/30823 [44:47<2:03:40,  2.94it/s]

{'loss': 0.308, 'learning_rate': 1.4160205041689648e-05, 'epoch': 0.29}


 31%|███       | 9500/30823 [47:09<1:52:29,  3.16it/s]

{'loss': 0.3013, 'learning_rate': 1.3835771988450185e-05, 'epoch': 0.31}


 32%|███▏      | 10000/30823 [49:31<1:37:32,  3.56it/s]

{'loss': 0.2895, 'learning_rate': 1.351133893521072e-05, 'epoch': 0.32}


 34%|███▍      | 10500/30823 [51:56<1:32:03,  3.68it/s]

{'loss': 0.2776, 'learning_rate': 1.3186905881971257e-05, 'epoch': 0.34}


 36%|███▌      | 11000/30823 [54:17<1:27:47,  3.76it/s]

{'loss': 0.2913, 'learning_rate': 1.2862472828731792e-05, 'epoch': 0.36}


 37%|███▋      | 11500/30823 [56:39<1:26:04,  3.74it/s]

{'loss': 0.2677, 'learning_rate': 1.2538039775492327e-05, 'epoch': 0.37}


 39%|███▉      | 12000/30823 [59:03<1:28:45,  3.53it/s]

{'loss': 0.298, 'learning_rate': 1.2213606722252864e-05, 'epoch': 0.39}


 41%|████      | 12500/30823 [1:01:26<1:33:38,  3.26it/s]

{'loss': 0.2878, 'learning_rate': 1.18891736690134e-05, 'epoch': 0.41}


 42%|████▏     | 13000/30823 [1:03:49<1:36:32,  3.08it/s]

{'loss': 0.2712, 'learning_rate': 1.1564740615773936e-05, 'epoch': 0.42}


 44%|████▍     | 13500/30823 [1:06:11<1:22:04,  3.52it/s]

{'loss': 0.233, 'learning_rate': 1.1240307562534471e-05, 'epoch': 0.44}


 45%|████▌     | 14000/30823 [1:08:34<1:08:34,  4.09it/s]

{'loss': 0.2641, 'learning_rate': 1.0915874509295008e-05, 'epoch': 0.45}


 47%|████▋     | 14500/30823 [1:10:57<1:23:11,  3.27it/s]

{'loss': 0.2351, 'learning_rate': 1.0591441456055545e-05, 'epoch': 0.47}


 49%|████▊     | 15000/30823 [1:13:18<1:22:22,  3.20it/s]

{'loss': 0.2404, 'learning_rate': 1.0267008402816078e-05, 'epoch': 0.49}


 50%|█████     | 15500/30823 [1:15:40<1:16:56,  3.32it/s]

{'loss': 0.271, 'learning_rate': 9.942575349576615e-06, 'epoch': 0.5}


 52%|█████▏    | 16000/30823 [1:18:04<1:04:05,  3.85it/s]

{'loss': 0.2846, 'learning_rate': 9.618142296337152e-06, 'epoch': 0.52}


 54%|█████▎    | 16500/30823 [1:20:26<1:05:47,  3.63it/s]

{'loss': 0.2521, 'learning_rate': 9.293709243097687e-06, 'epoch': 0.54}


 55%|█████▌    | 17001/30823 [1:22:48<54:54,  4.20it/s]  

{'loss': 0.2587, 'learning_rate': 8.969276189858224e-06, 'epoch': 0.55}


 57%|█████▋    | 17500/30823 [1:25:14<1:04:34,  3.44it/s]

{'loss': 0.2455, 'learning_rate': 8.644843136618759e-06, 'epoch': 0.57}


 58%|█████▊    | 18000/30823 [1:27:36<59:08,  3.61it/s]  

{'loss': 0.2459, 'learning_rate': 8.320410083379296e-06, 'epoch': 0.58}


 60%|██████    | 18500/30823 [1:29:59<1:01:01,  3.37it/s]

{'loss': 0.2622, 'learning_rate': 7.99597703013983e-06, 'epoch': 0.6}


 62%|██████▏   | 19000/30823 [1:32:24<1:01:46,  3.19it/s]

{'loss': 0.2922, 'learning_rate': 7.671543976900368e-06, 'epoch': 0.62}


 63%|██████▎   | 19500/30823 [1:34:47<50:23,  3.75it/s]  

{'loss': 0.24, 'learning_rate': 7.347110923660903e-06, 'epoch': 0.63}


 65%|██████▍   | 20000/30823 [1:37:10<50:30,  3.57it/s]  

{'loss': 0.2538, 'learning_rate': 7.02267787042144e-06, 'epoch': 0.65}


 67%|██████▋   | 20500/30823 [1:39:36<48:22,  3.56it/s]  

{'loss': 0.2462, 'learning_rate': 6.698244817181975e-06, 'epoch': 0.67}


 68%|██████▊   | 21000/30823 [1:42:01<51:39,  3.17it/s]  

{'loss': 0.2541, 'learning_rate': 6.373811763942512e-06, 'epoch': 0.68}


 70%|██████▉   | 21500/30823 [1:44:24<38:17,  4.06it/s]

{'loss': 0.2605, 'learning_rate': 6.049378710703047e-06, 'epoch': 0.7}


 71%|███████▏  | 22000/30823 [1:46:45<39:17,  3.74it/s]

{'loss': 0.244, 'learning_rate': 5.724945657463583e-06, 'epoch': 0.71}


 73%|███████▎  | 22500/30823 [1:49:08<38:14,  3.63it/s]

{'loss': 0.2757, 'learning_rate': 5.400512604224119e-06, 'epoch': 0.73}


 75%|███████▍  | 23000/30823 [1:51:29<36:38,  3.56it/s]

{'loss': 0.2269, 'learning_rate': 5.076079550984655e-06, 'epoch': 0.75}


 76%|███████▌  | 23500/30823 [1:53:54<33:00,  3.70it/s]

{'loss': 0.237, 'learning_rate': 4.751646497745191e-06, 'epoch': 0.76}


 78%|███████▊  | 24000/30823 [1:56:16<29:17,  3.88it/s]

{'loss': 0.2361, 'learning_rate': 4.427213444505727e-06, 'epoch': 0.78}


 79%|███████▉  | 24500/30823 [1:58:39<33:04,  3.19it/s]

{'loss': 0.2486, 'learning_rate': 4.102780391266263e-06, 'epoch': 0.79}


 81%|████████  | 25000/30823 [2:01:04<28:47,  3.37it/s]

{'loss': 0.224, 'learning_rate': 3.7783473380267983e-06, 'epoch': 0.81}


 83%|████████▎ | 25500/30823 [2:03:27<23:38,  3.75it/s]

{'loss': 0.2535, 'learning_rate': 3.4539142847873343e-06, 'epoch': 0.83}


 84%|████████▍ | 26000/30823 [2:05:53<23:59,  3.35it/s]

{'loss': 0.2362, 'learning_rate': 3.1294812315478703e-06, 'epoch': 0.84}


 86%|████████▌ | 26500/30823 [2:08:15<19:56,  3.61it/s]

{'loss': 0.2042, 'learning_rate': 2.805048178308406e-06, 'epoch': 0.86}


 88%|████████▊ | 27000/30823 [2:10:38<20:50,  3.06it/s]

{'loss': 0.2447, 'learning_rate': 2.4806151250689423e-06, 'epoch': 0.88}


 89%|████████▉ | 27500/30823 [2:13:02<16:01,  3.46it/s]

{'loss': 0.2073, 'learning_rate': 2.1561820718294783e-06, 'epoch': 0.89}


 91%|█████████ | 28000/30823 [2:15:24<12:57,  3.63it/s]

{'loss': 0.2573, 'learning_rate': 1.831749018590014e-06, 'epoch': 0.91}


 92%|█████████▏| 28500/30823 [2:17:48<11:57,  3.24it/s]

{'loss': 0.2489, 'learning_rate': 1.50731596535055e-06, 'epoch': 0.92}


 94%|█████████▍| 29000/30823 [2:20:11<08:17,  3.66it/s]

{'loss': 0.214, 'learning_rate': 1.182882912111086e-06, 'epoch': 0.94}


 96%|█████████▌| 29500/30823 [2:22:34<05:15,  4.19it/s]

{'loss': 0.2068, 'learning_rate': 8.584498588716219e-07, 'epoch': 0.96}


 97%|█████████▋| 30000/30823 [2:24:56<04:27,  3.08it/s]

{'loss': 0.2434, 'learning_rate': 5.340168056321578e-07, 'epoch': 0.97}


 99%|█████████▉| 30500/30823 [2:27:21<01:30,  3.58it/s]

{'loss': 0.2871, 'learning_rate': 2.0958375239269377e-07, 'epoch': 0.99}


                                                       
100%|██████████| 30823/30823 [2:32:50<00:00,  3.36it/s]

{'eval_loss': 0.24188634753227234, 'eval_runtime': 236.7721, 'eval_samples_per_second': 57.95, 'eval_steps_per_second': 14.491, 'epoch': 1.0}
{'train_runtime': 9170.304, 'train_samples_per_second': 13.444, 'train_steps_per_second': 3.361, 'train_loss': 0.27492766681514336, 'epoch': 1.0}





TrainOutput(global_step=30823, training_loss=0.27492766681514336, metrics={'train_runtime': 9170.304, 'train_samples_per_second': 13.444, 'train_steps_per_second': 3.361, 'train_loss': 0.27492766681514336, 'epoch': 1.0})

In [8]:
#trainer.save_model('../output/bert_discriminator/final')

In [61]:
import torch
test_input = tokenizer(validation['text'][13871], return_tensors='pt').to(device)
with torch.no_grad():
    logits = model(**test_input).logits

predicted_class_id = logits.argmax().item()
predicted_class_id

1

In [48]:
validation

Unnamed: 0,text,label
0,"> Andy could come out first, if he hasn't alre...",[1]
1,[deleted] �� ��,[1]
2,[deleted],[1]
3,"My writing is good, and I feel like I have to ...",[1]
4,I read it and liked it.,[1]
...,...,...
13869,It has already been taught. This was recorded ...,[0]
13870,That's exactly what you're saying. You're actu...,[1]
13871,I don't care what other people think of Scalzi...,[1]
13872,I was thinking of it in a similar vein. I thin...,[1]
