# Train a Model from Hugging Face

## Install and Import Dependencies

In [None]:
from warnings import filterwarnings

import pandas as pd
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from configs import (
  TOXIC_DB_PATH, BENIGN_DB_PATH, PROMPT_TEMPLATE, EPOCHS, BATCH_SIZE,
  LEARNING_RATE, FACTOR, PATIENCE, THRESHOLD
)
from utils import DatabaseInterface, Trainer, get_device

filterwarnings('ignore')

## Initialize Database Interfaces

In [None]:
# Build one interface per example store. Both are created with n_neighbors=1
# and differ only in the data path they are pointed at.
toxic_db, benign_db = (
  DatabaseInterface(n_neighbors=1, data_path=db_path)
  for db_path in (TOXIC_DB_PATH, BENIGN_DB_PATH)
)

## Import Model and Tokenizer from Hugging Face

In [None]:
# Hugging Face checkpoint name or local directory. The same value is handed to
# both `from_pretrained` calls so the model and tokenizer come from one source.
checkpoint = 's-nlp/roberta_toxicity_classifier'

# The two loads are independent of each other.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

## Initialize Optimizer and Scheduler

In [None]:
# AdamW over every model parameter, starting at the configured learning rate.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# Plateau scheduler in 'min' mode: the learning rate is multiplied by FACTOR
# once the monitored value has stopped decreasing for PATIENCE steps, with
# THRESHOLD as the minimum change that counts as an improvement.
# NOTE(review): which value the Trainer feeds to `scheduler.step` is not
# visible here -- presumably a loss; confirm in utils.Trainer.
scheduler = ReduceLROnPlateau(
  optimizer, mode='min', factor=FACTOR, patience=PATIENCE, threshold=THRESHOLD
)

## Load Data

In [None]:
# Root directory of the processed datasets. Each dataset sits in its own
# sub-directory and contributes its `train.csv`; edit this single line to
# point the notebook at a different checkout.
data_dir = '/home/sunil/nani/Detecting-Toxicity-Social-Media/datasets/processed'

# Path to the csv files
data_paths = [
  f'{data_dir}/{dataset}/train.csv'
  for dataset in ('DHate', 'SBIC', 'ToxiGen')
]

# Read every file and stack the rows in the order listed above.
# ignore_index=True gives the combined frame a fresh 0..n-1 index; without it
# each file keeps its own 0-based index, so index labels repeat across files
# and `df.tail()` shows per-file row numbers rather than positions in `df`.
df = pd.concat([pd.read_csv(path) for path in data_paths], ignore_index=True)
df.tail()

In [None]:
# Cut-off used to binarise the `label` column: values strictly greater than
# it become 1, everything else becomes 0. This lowercase `threshold` is
# unrelated to the imported THRESHOLD constant, which configures the scheduler.
threshold = 0.3

texts = df['text'].tolist()
labels = [1 if score > threshold else 0 for score in df['label']]

# Preview the last five text/label pairs.
texts[-5:], labels[-5:]

## Initialize Trainer

In [None]:
# Hand the Trainer everything created in the earlier cells: the model and its
# tokenizer, the optimizer/scheduler pair, both database interfaces, and the
# prompt template from configs.
trainer = Trainer(
  model=model, tokenizer=tokenizer,
  optimizer=optimizer, scheduler=scheduler,
  toxic_db=toxic_db, benign_db=benign_db,
  prompt_template=PROMPT_TEMPLATE,
)

## Get Device and Train

In [None]:
# Ask the project helper which device to use and report the choice.
device = get_device()
print(f'Using {device} device')

# Run training on the prepared texts/labels; batch size and epoch count come
# from configs, and the selected device is passed straight to the Trainer.
trainer.train(
  texts=texts, labels=labels,
  batch_size=BATCH_SIZE, epochs=EPOCHS,
  device=device,
)

## Save Trained Model and Tokenizer

In [None]:
# Single output directory shared by the model and the tokenizer, so the
# result can later be reloaded from one path.
save_path = '/home/sunil/nani/Detecting-Toxicity-Social-Media/model'

model.save_pretrained(save_directory=save_path)
tokenizer.save_pretrained(save_directory=save_path)