# Installing Requirements

In [None]:
# !pip install transformers

# Imports

In [None]:
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from transformers import Trainer, TrainingArguments, BertTokenizer, BertForMaskedLM


# Connecting to Google Drive

# Hyperparameters

In [None]:
# Notebook-wide settings.
NUM_CLASSES = 6
LEARNING_RATE = 1e-5
VALID_BATCH_SIZE = 32
TRAIN_BATCH_SIZE = 64
# Passed as ``max_length`` when the dialogues are batch-tokenized.
MAX_LEN = 64

# Processing data

## Creating a dataframe

In [None]:
# Load the cleaned dialogue table, then discard every row whose "person"
# column holds the literal string "person".
df = pd.read_csv("../input/friends-dialogues/dialogues_cleaned.csv")
df = df.drop(index=df.index[df["person"] == "person"])

## Selecting Rachel's dialogues

In [None]:
# Array with the "dialogue" text of every row whose "person" is "rachel".
rachel_dlgs = df.loc[df["person"] == "rachel", "dialogue"].values

In [None]:
# Pretrained BERT with a masked-language-modelling head.
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# Matching tokenizer with the special tokens spelled out explicitly.
# The keyword is ``mask_token``; the earlier ``mask_toke`` spelling is not a
# tokenizer argument, so the mask token was never actually being set here.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', mask_token="[MASK]", sep_token="[SEP]", pad_token="[PAD]")

def tokenize_batch(batch):
    """Convert every token sequence in ``batch`` to vocabulary ids.

    Returns a list with one entry per input sequence, each produced by the
    module-level ``tokenizer.convert_tokens_to_ids``.
    """
    encoded = []
    for tokens in batch:
        encoded.append(tokenizer.convert_tokens_to_ids(tokens))
    return encoded

def untokenize_batch(batch):
    """Convert every id sequence in ``batch`` back to token strings.

    Returns a list with one entry per input sequence, each produced by the
    module-level ``tokenizer.convert_ids_to_tokens``.
    """
    decoded = []
    for ids in batch:
        decoded.append(tokenizer.convert_ids_to_tokens(ids))
    return decoded

def detokenize(sent):
    """Roughly detokenize a token list by merging WordPiece continuations.

    A token starting with ``"##"`` is glued (without the marker) onto the
    word before it. When there is no preceding word, i.e. the sequence
    starts with a continuation piece, the piece becomes a word of its own
    instead of raising ``IndexError``.

    Args:
        sent: iterable of token strings.

    Returns:
        A new list of merged word strings; ``sent`` is not modified.
    """
    words = []
    for tok in sent:
        if tok.startswith("##"):
            piece = tok[2:]
            if words:
                words[-1] += piece
            else:
                # Nothing to attach to yet: keep the piece as its own word.
                words.append(piece)
        else:
            words.append(tok)
    return words

# Special-token strings used when building generation inputs.
CLS = '[CLS]'
SEP = '[SEP]'
MASK = '[MASK]'
# Vocabulary ids of the special tokens, looked up in one call.
mask_id, sep_id, cls_id = tokenizer.convert_tokens_to_ids([MASK, SEP, CLS])

In [None]:
from transformers import DataCollatorForLanguageModeling

# Collator handed to the Trainer: masked-LM mode with a masking
# probability of 0.15.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes raw strings on access.

    Args:
        x: indexable collection of strings.
        max_length: optional cap on the number of token ids per example.
            ``None`` leaves the limit to the tokenizer's own model maximum.
    """

    def __init__(self, x, max_length=None):
        self.x = x
        self.max_length = max_length

    def __getitem__(self, idx):
        """Return the ``input_ids`` the module-level tokenizer yields for ``self.x[idx]``."""
        # Truncate so one very long dialogue cannot exceed the number of
        # positions the model supports and break the forward pass.
        encoded = tokenizer(self.x[idx], truncation=True, max_length=self.max_length)
        return encoded["input_ids"]

    def __len__(self):
        """Return the number of stored strings."""
        return len(self.x)

In [None]:
# Wrap the selected dialogues (converted to a plain Python list) in the
# Dataset class, which tokenizes each item when it is indexed.
rachel_ds = Dataset(rachel_dlgs.tolist())


# Model

In [None]:
def compute_metrics(p):
    """Compute accuracy over the positions that carry a real label.

    Args:
        p: pair ``(logits, labels)``. The predicted class is the argmax of
            ``logits`` over the last axis. Positions whose label is ``-100``
            (the ignore marker used for non-masked tokens in masked-LM
            batches) are excluded from the score.

    Returns:
        ``{"accuracy": float}``; ``0.0`` when no position has a real label.
    """
    logits, labels = p
    predictions = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)

    # Boolean selection flattens both arrays, so this works for 1-D
    # classification labels as well as 2-D (batch x seq_len) MLM labels.
    scored = labels != -100
    if not scored.any():
        return {"accuracy": 0.0}

    accuracy = float((predictions[scored] == labels[scored]).mean())
    return {"accuracy": accuracy}

In [None]:
# Training configuration handed to the Trainer.
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=10,             # total number of training epochs
    per_device_train_batch_size=64,  # batch size per device during training
    per_device_eval_batch_size=256,  # batch size for evaluation
    evaluation_strategy="epoch",     # evaluate once per epoch
    # The string "none" turns off every reporting integration; the Python
    # value None falls back to the library default instead of disabling them.
    report_to="none",
)

# trainer_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=6)

In [None]:
# Batch-encode all selected dialogues, padded and truncated to at most
# MAX_LEN tokens.
rachel_dlgs_tokenized = tokenizer(
    rachel_dlgs.tolist(),
    max_length=MAX_LEN,
    truncation=True,
    padding=True,
)

In [None]:
# Show the batch encoding as the cell output.
rachel_dlgs_tokenized

In [None]:
 # Upgrade pip itself inside the notebook environment.
 !pip install --upgrade pip

In [None]:
# Pin the transformers release used by this notebook.
# NOTE(review): transformers is already imported in an earlier cell; presumably
# the kernel has to be restarted for this pin to take effect — confirm.
!pip install transformers==4.20.1

In [None]:
import transformers
# Show which transformers version is loaded in this kernel.
transformers.__version__

In [None]:
# Fine-tuning driver for the masked-LM model.
trainer = Trainer(
    args=training_args,           # configuration built in the TrainingArguments cell
    model=model,                  # the BertForMaskedLM instance being trained
    data_collator=data_collator,  # collator created from DataCollatorForLanguageModeling
    train_dataset=rachel_ds,      # dialogues tokenized on access
    eval_dataset=rachel_ds[:10],  # slice taken from the same training data
    # compute_metrics=compute_metrics is left disabled.
)

In [None]:
import math
import time

def generate_step(out, gen_idx, top_k=0, sample=False, return_list=True):
    """Pick one token id per batch row from the logits at position ``gen_idx``.

    Args:
        out: model output whose ``"logits"`` entry is indexed as
            ``[:, gen_idx]``; expected to be a tensor of size
            batch_size x seq_len x vocab_size.
        gen_idx (int): sequence position to generate for.
        top_k (int): if > 0, sample only among the ``top_k`` highest logits.
        sample (bool): if True, sample from the full distribution.
            Overridden by ``top_k``.
        return_list (bool): return a Python list when True, otherwise the
            tensor of chosen ids.

    Returns:
        One chosen token id per batch row (list or 1-D tensor).
    """
    logits = out["logits"][:, gen_idx]

    if top_k > 0:
        kth_vals, kth_idx = logits.topk(top_k, dim=-1)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        # Samples index into the top-k slice; map them back to vocabulary ids.
        idx = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1)
    elif sample:
        dist = torch.distributions.categorical.Categorical(logits=logits)
        # The sample already has one entry per row. Squeezing it would turn a
        # batch of one into a 0-d tensor, and ``tolist()`` would then return a
        # bare int instead of a list.
        idx = dist.sample()
    else:
        idx = torch.argmax(logits, dim=-1)
    return idx.tolist() if return_list else idx
  
  
def get_init_text(seed_text, max_len, batch_size=1, rand_init=False):
    """Build ``batch_size`` identical rows of token ids to start generation from.

    Each row is ``seed_text`` followed by ``max_len`` mask tokens and one
    separator token, converted to ids with ``tokenize_batch``.
    ``rand_init`` is accepted but not consulted; rows are always mask-filled.
    """
    rows = []
    for _ in range(batch_size):
        rows.append(seed_text + [MASK] * max_len + [SEP])
    return tokenize_batch(rows)

def printer(sent, should_detokenize=True):
    """Optionally detokenize ``sent`` and strip its first and last token.

    The print call is commented out, so the function emits nothing and
    always returns ``None``.
    """
    if not should_detokenize:
        return
    sent = detokenize(sent)[1:-1]
    # print(" ".join(sent))


def generate(n_samples, seed_text="[CLS]", batch_size=10, max_len=15, leed_out_len=15,
             sample=True, top_k=10, temperature=1.0, burnin=200, max_iter=500, print_every=1):
    """Fill ``max_len`` masked positions left to right, one position per model call.

    Starts from ``batch_size`` copies of ``seed_text`` followed by ``max_len``
    mask tokens and a separator. At each step the whole batch is fed to the
    module-level ``model`` and the token at the current position is replaced
    by the id chosen by ``generate_step``.

    Args:
        seed_text: list of prefix tokens; a plain string is split on
            whitespace first.
        batch_size: number of sequences generated in parallel.
        max_len: number of positions to generate after the seed.
        sample, top_k: forwarded to ``generate_step``.
        n_samples, leed_out_len, temperature, burnin, max_iter, print_every:
            accepted for interface compatibility; they do not affect the
            result.

    Returns:
        A list of ``batch_size`` token-string lists (via ``untokenize_batch``).
    """
    # The default seed is a string, but the batch is built by list
    # concatenation, so normalise it to a token list first.
    if isinstance(seed_text, str):
        seed_text = seed_text.split()
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)

    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()  # disable dropout while generating
    try:
        with torch.no_grad():  # inference only: no autograd graph needed
            for ii in range(max_len):
                inp = torch.tensor(batch, device=device)
                out = model(inp)
                idxs = generate_step(out, gen_idx=seed_len + ii, top_k=top_k, sample=sample)
                for jj in range(batch_size):
                    batch[jj][seed_len + ii] = idxs[jj]
    finally:
        # Leave the model in the mode it was found in.
        model.train(was_training)

    return untokenize_batch(batch)


In [None]:
import os
# Set the WANDB_DISABLED environment variable for this process.
# NOTE(review): presumably meant to switch off Weights & Biases logging; this
# cell sits after the Trainer is constructed — confirm it runs early enough.
os.environ["WANDB_DISABLED"] = "true"

In [None]:
# Run fine-tuning; whatever train() returns is kept in `history`.
history = trainer.train()

In [None]:
# Settings for the generation run below.
n_samples = 1
batch_size = 5
max_len = 40
top_k = 100
temperature = 1.0
# NOTE: leed_out_len is defined here but is not passed to generate() below.
leed_out_len = 5 # max_len
burnin = 250
sample = True
max_iter = 500

# Choose the prefix context
# "[CLS]".split() yields the one-element token list ["[CLS]"].
seed_text = "[CLS]".split()
bert_sents = generate(n_samples, seed_text=seed_text, batch_size=batch_size, max_len=max_len,
                      sample=sample, top_k=top_k, temperature=temperature, burnin=burnin, max_iter=max_iter)

In [None]:
# Show the generated token lists as the cell output.
bert_sents

In [None]:
# Print each generated sequence on its own line, tokens separated by spaces.
for sent in bert_sents:
  print(' '.join(sent))

In [None]:
# trainer_model.save_pretrained("/content/gdrive/My Drive/nlp project/")