<a href="https://colab.research.google.com/github/anwesham-lab/cs-229-230-project/blob/main/Baseline_IMDB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Necessary Packages 
- datasets
- tokenizers
- transformers

From HuggingFace

In [None]:
# Show the GPU (if any) attached to this runtime before starting any training.
!nvidia-smi

In [None]:
!pip install datasets tokenizers wandb seqeval
!pip install -qqq git+https://github.com/huggingface/transformers

Run all necessary imports at the top

In [None]:
# import os
import numpy as np
import torch
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, DataCollatorForTokenClassification, AutoModelForTokenClassification, TrainingArguments, Trainer, EarlyStoppingCallback
import wandb
import random

In [None]:
from google.colab import drive

# Mount Google Drive into the Colab filesystem; force_remount=True re-mounts
# even when a previous mount is still present.
DRIVE_MOUNT_POINT = '/content/drive/'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

In [None]:
# Switch to the project folder on the mounted Drive so that relative paths used
# later (e.g. the "./distilbert_imdb" output directory) resolve inside it.
%cd /content/drive/My\ Drive/230
# Echo the working directory to confirm the change took effect.
%pwd

# Load in the Dataset

Try the IMDB dataset that's on huggingface. 

In [None]:
# Load the IMDB reviews via the Hugging Face `datasets` library.
# NOTE(review): later cells index dataset['validation']; confirm that this
# path/config pair really yields a 'validation' split. The config name suggests
# the Hub dataset "Lucylulu/imdb" may have been the intended source - verify.
dataset = load_dataset(path='imdb', name='Lucylulu--imdb')

In [None]:
# Spot-check one training example (index 2).
dataset['train'][2]

In [None]:
# Spot-check one validation example (index 4888).
dataset['validation'][4888]

In [None]:
# Spot-check one test example (index 904).
dataset['test'][904]

In [None]:
# Report how many examples each split holds, in the order train / validation / test.
print(*(len(dataset[split_name]) for split_name in ('train', 'validation', 'test')))

# Tokenization and Labeling Scheme


In [None]:
# Tokenizer matching the "distilbert-base-cased" checkpoint that is fine-tuned below.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")

def tokenize(all_samples_per_split):
    """Tokenize a batch of raw review texts with the notebook-level tokenizer.

    Args:
        all_samples_per_split: Batch mapping whose "text" entry holds the raw
            review strings (the form `dataset.map(..., batched=True)` passes in).

    Returns:
        The tokenizer's batch encoding for those texts, truncated to at most
        512 tokens per sample. No padding is applied here; that is left to the
        data collator at batch time.
    """
    # Call the tokenizer directly: `batch_encode_plus` is deprecated in favour
    # of `__call__`, which handles a list of strings as a batch the same way.
    return tokenizer(
        all_samples_per_split["text"],
        is_split_into_words=False,
        truncation=True,
        max_length=512,
    )

In [None]:
# Run `tokenize` over the dataset in batches; the tokenizer outputs are added
# as new columns next to the existing ones.
token_data = dataset.map(function=tokenize, batched=True)

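Verify the tokenized data looks as expected: each sample should now include `input_ids` and an `attention_mask` next to the original fields. (DistilBERT's tokenizer does not produce token type IDs, so a pair rather than a triple is expected.)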

In [None]:
# Inspect one tokenized test example (index 2) to check the new tokenizer columns.
token_data["test"][2]

# Padding

All samples in a batch must have the same length so the attention model can process them as a single tensor. A data collator (Hugging Face's implementation of PyTorch's `collate_fn`, but a little more portable) pads each batch up to the length of its longest sample.

In [None]:
# The model trained below is a sequence classifier (one label per review), so
# batches only need their token inputs padded: use DataCollatorWithPadding.
# DataCollatorForTokenClassification is meant for per-token label lists and
# tries to pad each sample's labels to the sequence length, which does not work
# with a single scalar label per sample.
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Set Up Weights and Biases Logs

In [None]:
# Authenticate this session with Weights & Biases before starting a run.
wandb.login()

In [None]:
# Start the W&B run that the Trainer reports to (see report_to in the training arguments).
wandb.init(
    entity="anwesham",
    project="trial_imdb",
)

# Evaluation Setup

Want to evaluate the precision, recall, f1, and general accuracy. We want both the f1 and the accuracy because generally, we'll want to gauge not only how impactful false positives and negatives are, but the general rate of correct predictions as well. 

# Initialize the Model

In [None]:
from transformers import AutoModelForSequenceClassification

# Load the DistilBERT checkpoint with a two-class sequence-classification head.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-cased",
    num_labels=2,
)

# Define the training arguments and trainer

In [None]:
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

def compute_metrics(p):
    """Score a set of predictions with accuracy, precision, recall and F1.

    Args:
        p: Pair that unpacks into (predictions, labels). The predictions are
            reduced with argmax along axis 1, so they are presumably
            per-class scores of shape (n_samples, n_classes) - confirm
            against what the Trainer passes in.

    Returns:
        Dict with the keys "accuracy", "precision", "recall" and "f1", each
        computed by the corresponding scikit-learn scorer with its defaults.
    """
    scores, gold = p
    # Highest-scoring class per sample becomes the predicted label.
    predicted = np.argmax(scores, axis=1)

    scorers = {
        "accuracy": accuracy_score,
        "precision": precision_score,
        "recall": recall_score,
        "f1": f1_score,
    }
    return {
        metric_name: scorer(y_true=gold, y_pred=predicted)
        for metric_name, scorer in scorers.items()
    }

In [None]:
# Training configuration, collected in one mapping and unpacked into TrainingArguments.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm the spelling against the version actually installed.
training_config = dict(
    # Where checkpoints are written (relative to the current working directory).
    output_dir="./distilbert_imdb",
    # Batch sizes and schedule length.
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=7,
    # Optimisation settings.
    learning_rate=2e-5,
    weight_decay=0.01,
    # Log every 500 steps; evaluate and checkpoint once per epoch.
    logging_strategy='steps',
    logging_steps=500,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    # Restore the best checkpoint when training ends and report to W&B.
    load_best_model_at_end=True,
    report_to='wandb',
)
training_args = TrainingArguments(**training_config)

# Assemble the Trainer: model and configuration first, then the train and
# validation splits, then the batching and metric helpers.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=token_data["train"],
    eval_dataset=token_data["validation"],
)

# Early stopping with a patience of one evaluation and a zero improvement threshold.
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=1,
    early_stopping_threshold=0.0,
)
trainer.add_callback(early_stopping)

# Train

In [None]:
# Run fine-tuning with the arguments, datasets and callback configured above.
trainer.train()

In [None]:
# Write the model to the "distilbert_imdb" directory (relative to the current
# working directory). Only the model is saved by this call.
# NOTE(review): the tokenizer is not saved alongside it here - confirm whether
# downstream loading needs it in the same folder.
model.save_pretrained("distilbert_imdb")

In [None]:
# Close the W&B run started earlier in the notebook.
wandb.finish()