<a href="https://colab.research.google.com/github/BushmelevKostya/NLP_course_task_2/blob/hw/NLP_generate_text.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers
!pip install datasets
!pip install accelerate -U
!pip install evaluate

In [None]:
# Seed passed to the dataset shuffles when the train/eval subsets are sampled.
random_state = 42

# Загрузка dataset

In [2]:
from datasets import load_dataset

## Fairy tale dataset

In [None]:
# Load the "vicclab/fairy_tales" dataset.
# NOTE(review): the Yelp cell below rebinds `dataset`, so with Run All this
# load is overwritten before anything uses it — run only one of the two
# dataset cells.
dataset = load_dataset("vicclab/fairy_tales")

## Yelp_review_full dataset

In [None]:
# Load "yelp_review_full"; later cells read its "train"/"test" splits and the
# "text" and "label" columns.
dataset = load_dataset("yelp_review_full")

# Загрузка Tokenizer

In [None]:
from transformers import AutoTokenizer

# Tokenizer for the "bert-base-cased" checkpoint, the same checkpoint the BERT
# model cell loads.
# NOTE(review): this does not match the GPT-2 model loaded in the GPT cell.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


def tokenize_function(examples):
    """Tokenize the ``"text"`` field of a batch with the notebook-level tokenizer.

    Args:
        examples: Mapping with a ``"text"`` entry holding the raw strings.

    Returns:
        The tokenizer's output for those strings, truncated and padded with
        ``padding="max_length"``.
    """
    raw_texts = examples["text"]
    encoded = tokenizer(raw_texts, truncation=True, padding="max_length")
    return encoded


# Run tokenize_function over the dataset; batched=True hands it batches of
# rows instead of one row at a time.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [5]:
# Keep the run small: shuffle each split with the shared seed and keep the
# first 1000 rows.
subset_indices = range(1000)
small_train_dataset = (
    tokenized_datasets["train"].shuffle(seed=random_state).select(subset_indices)
)
small_eval_dataset = (
    tokenized_datasets["test"].shuffle(seed=random_state).select(subset_indices)
)

In [7]:
# Show the raw text of the third row in the training subset.
small_train_dataset["text"][2]

'I LOVE Bloom Salon... all of their stylist are very qualified and provide excellent hair care...I prefer to book my appointments with Andrea, but if she is not available I am not afraid to book with anyone else.  Not only does this salon provide hair care, but they also offer skin, nails and massage therapy!!  What a great place with a relaxing atmosphere...I HIGHLY recommend this place.'

# Train

## GPT-2 model

In [None]:
from transformers import GPT2LMHeadModel
# Load the pretrained "gpt2" checkpoint as a GPT2LMHeadModel.
# NOTE(review): the BERT cell below rebinds `model`, so with Run All this
# model is replaced before training starts.
model = GPT2LMHeadModel.from_pretrained('gpt2')

## BERT model

In [None]:
from transformers import AutoModelForSequenceClassification

# "bert-base-cased" with a sequence-classification head configured for 5 labels.
# NOTE(review): presumably 5 matches the number of Yelp star ratings — confirm.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

# Evaluate

In [None]:
import numpy as np
import evaluate

# Accuracy metric object; compute_metrics calls its .compute() method.
metric = evaluate.load("accuracy")

In [10]:
def compute_metrics(eval_pred):
    """Score a batch of evaluation outputs with the notebook-level ``metric``.

    Args:
        eval_pred: A pair ``(logits, labels)``.

    Returns:
        Whatever ``metric.compute`` returns for the arg-max class predictions
        compared against ``labels``.
    """
    scores, references = eval_pred
    # The predicted class is the index of the largest score along axis 1.
    predicted_classes = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_classes, references=references)

# Training hyperparameters


In [11]:
from transformers import TrainingArguments

# Write outputs under "test_trainer" and evaluate at the end of each epoch;
# every other hyperparameter keeps its library default.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version (the install cell
# does not pin one).
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")

In [12]:
# Display the resolved training configuration, defaults included.
training_args

TrainingArguments(
_n_gpu=0,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=epoch,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
gradient_checkpointing_kwargs=None,
greater_is_better=None,
group_by_length=False,
half_precision_backend=

# Trainer

In [None]:
from transformers import Trainer
# Bundle the model, training configuration, the two 1000-row subsets and the
# accuracy callback into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,  # called on evaluation outputs
)

# Fine-tune; evaluation frequency follows training_args.
trainer.train()

# Качество модели после дообучения

In [None]:
# DataLoader is provided by PyTorch, not by transformers.
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from datasets import load_dataset
# `evaluate` exposes `load`; the alias keeps the `load_metric` name used by the
# cell below available.
from evaluate import load as load_metric

## Оценка качества модели на тестовых данных

In [None]:
# Evaluate the fine-tuned model on the held-out subset.
# Trainer.predict runs batched inference and returns the raw predictions
# together with the reference labels, so no hand-built DataLoader or custom
# prediction method on the model is needed.
eval_output = trainer.predict(small_eval_dataset)
predicted_labels = np.argmax(eval_output.predictions, axis=1)

metric = evaluate.load("accuracy")
# Compute once and print that result.
accuracy = metric.compute(predictions=predicted_labels, references=eval_output.label_ids)
print("Accuracy:", accuracy)