## Importing Required Components

In [None]:
import torch
import random
import re
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from transformers import AutoTokenizer, AutoConfig, GPT2LMHeadModel, AdamW, BertForNextSentencePrediction, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler, random_split

## Setup: Inline Plots and Random Seeds

In [None]:
%matplotlib inline

# One seed shared by Python's, NumPy's and PyTorch's (CPU and every CUDA device) RNGs, so
# shuffling, dataset splitting and sampling repeat between runs (GPU kernels may still be
# nondeterministic).
SEED = 17
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

## Reading Dataset

In [None]:
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data.csv must provide a "rap" column (the training texts) and a "rapper" column
# (the artist names later used to build generation prompts).
df = pd.read_csv("data.csv")
rap_column, rapper_column = df["rap"], df["rapper"]
raps = rap_column.values.tolist()
rappers = rapper_column.unique().tolist()  # distinct artists, in first-appearance order

## Setting Special Tokens

In [None]:
# Special tokens registered with the tokenizer; <bos>/<eos> wrap every rap in RapDataset.
beginning_of_sentence_token = "<bos>"
end_of_sentence_token = "<eos>"
pad_token = "<pad>"
unknown_token = "<unk>"
# Placed between the rapper's name and the lyrics when building a generation prompt.
start_of_rap_token = "<|rap_start|>"

## Configuring Tokenizer and Model Config

In [None]:
model_name = "HooshvareLab/gpt2-fa"

# Load the pretrained GPT-2 tokenizer and register this notebook's special tokens on it.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    bos_token=beginning_of_sentence_token,
    eos_token=end_of_sentence_token,
    pad_token=pad_token,
    unk_token=unknown_token,
)

# Take the ids from the tokenizer's special-token registry. Re-tokenizing each marker string
# and keeping element [0] would silently return only the first piece's id if a marker were
# ever split into several sub-tokens.
config = AutoConfig.from_pretrained(
    model_name,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    unk_token_id=tokenizer.unk_token_id,
)

## Defining Class For Managing The Dataset

In [None]:
class RapDataset(Dataset):
    """Pre-tokenized rap texts for causal language-model fine-tuning.

    Each text is wrapped as ``beginning_of_sentence_token + rap + end_of_sentence_token``
    (module-level constants), then truncated/padded to exactly ``max_length`` tokens.
    Items are ``(input_ids, attention_mask)`` pairs of 1-D tensors.
    """

    def __init__(self, raps, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.input_ids, self.attention_masks = [], []

        for lyric in raps:
            wrapped = beginning_of_sentence_token + lyric + end_of_sentence_token
            encoded = tokenizer(wrapped,
                                truncation=True,
                                max_length=max_length,
                                padding="max_length")
            self.input_ids.append(torch.tensor(encoded['input_ids']))
            self.attention_masks.append(torch.tensor(encoded['attention_mask']))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        ids = self.input_ids[idx]
        mask = self.attention_masks[idx]
        return ids, mask

## Splitting Train & Val Datasets

In [None]:
max_length = 32
dataset = RapDataset(raps, tokenizer, max_length=max_length)

# Hold out 10% of the examples for validation; the rest is used for training.
split_fractions = [0.9, 0.1]
train_dataset, val_dataset = random_split(dataset, split_fractions)

print(f'Train Dataset Size: {len(train_dataset)}, Validation Dataset Size: {len(val_dataset)}')

## Configuring Dataloaders

In [None]:
batch_size = 32

# Training batches are drawn in a fresh random order every epoch; validation order is fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=RandomSampler(train_dataset))
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, sampler=SequentialSampler(val_dataset))

## Loading the GPT-2 Model

In [None]:
model = GPT2LMHeadModel.from_pretrained(model_name, config=config)
# The tokenizer may hold special tokens the pretrained embedding matrix does not cover;
# resize the embeddings to the full tokenizer size.
model.resize_token_embeddings(len(tokenizer))

# Move to the device chosen when the data was loaded (CUDA if available, else CPU).
# A hard-coded .cuda() would raise on machines without a GPU.
model.to(device)

## Configuring Optimizer and Scheduler

In [None]:
epochs = 5
learning_rate = 5e-5

# torch.optim.AdamW replaces the deprecated transformers.AdamW. weight_decay=0.0 keeps the
# previous setting: transformers.AdamW defaulted to no decay, while torch's default is 1e-2.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=0.0)

# No warmup; the learning rate decays linearly to 0 over every optimizer step of training.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_dataloader) * epochs,
)

## Training and Evaluating the Model

In [None]:
train_losses, val_losses = [], []

for epoch in range(epochs):
    print(f'Epoch {epoch + 1}')

    # ---- training ----
    model.train()
    train_loss = 0

    for batch_input_ids, batch_masks in train_dataloader:
        batch_input_ids = batch_input_ids.to(device)
        batch_masks = batch_masks.to(device)
        # Padding positions get label -100, which the LM loss ignores. Without this the model
        # is also trained (and scored) on predicting <pad> after the end of every short rap.
        batch_labels = batch_input_ids.masked_fill(batch_masks == 0, -100)

        optimizer.zero_grad()

        res = model(batch_input_ids,
                    attention_mask=batch_masks,
                    labels=batch_labels,
                    return_dict=True)

        batch_loss = res.loss
        train_loss += batch_loss.item()

        batch_loss.backward()
        optimizer.step()
        scheduler.step()

    avg_train_loss = train_loss / len(train_dataloader)
    train_losses.append(avg_train_loss)

    # ---- validation (same label masking, no gradients) ----
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for batch_input_ids, batch_masks in val_dataloader:
            batch_input_ids = batch_input_ids.to(device)
            batch_masks = batch_masks.to(device)
            batch_labels = batch_input_ids.masked_fill(batch_masks == 0, -100)

            res = model(batch_input_ids,
                        attention_mask=batch_masks,
                        labels=batch_labels,
                        return_dict=True)
            val_loss += res.loss.item()

    avg_val_loss = val_loss / len(val_dataloader)
    val_losses.append(avg_val_loss)

    print(f'Epoch Training Loss: {avg_train_loss}')
    print(f'Epoch Validation loss: {avg_val_loss}')

## Plotting Losses

In [None]:
# The training log numbers epochs from 1, so the x-axis does the same.
epoch_numbers = range(1, len(train_losses) + 1)

plt.plot(epoch_numbers, train_losses, label='Train')
plt.plot(epoch_numbers, val_losses, label='Val')
plt.title("Model Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.xticks(list(epoch_numbers))
plt.legend(loc='upper left')

plt.show()

## Defining Helper Functions for Inference

In [None]:
def clean_generated_text(text):
    """Turn decoded model output into plain multi-line lyrics.

    The <bos>/<eos> markers are removed, and the "<sep>" and <|rap_start|> markers become line
    breaks (applied in that order). Runs of consecutive newlines are then collapsed into one.
    """
    substitutions = (
        (beginning_of_sentence_token, ""),
        (end_of_sentence_token, ""),
        ("<sep>", "\n"),
        (start_of_rap_token, '\n'),
    )
    for marker, replacement in substitutions:
        text = text.replace(marker, replacement)
    return re.sub("\n+", "\n", text)

def generate(base_text, max_length=128, num_return_sequences=3, max_new_tokens=None):
    """Sample continuations of ``base_text`` from the module-level GPT-2 ``model``.

    Args:
        base_text: Prompt text, tokenized with the module-level ``tokenizer``.
        max_length: Total token budget *including* the prompt (Hugging Face ``generate``
            semantics). Ignored when ``max_new_tokens`` is given.
        num_return_sequences: Number of independent samples to draw.
        max_new_tokens: Optional budget counted over newly generated tokens only. Prefer it
            for long prompts, where a fixed ``max_length`` can leave little or no room for
            new text.

    Returns:
        A list of ``num_return_sequences`` strings. Each is the prompt plus its continuation,
        decoded with special tokens kept and then passed through ``clean_generated_text``.
    """
    model.eval()

    # Pass the attention mask explicitly instead of letting generate() infer it.
    encoded = tokenizer(base_text, return_tensors="pt").to(device)

    if max_new_tokens is not None:
        length_budget = {"max_new_tokens": max_new_tokens}
    else:
        length_budget = {"max_length": max_length}

    generated_output = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        do_sample=True,
        top_k=50,
        top_p=0.95,
        num_return_sequences=num_return_sequences,
        **length_budget,
    )

    # Decode with special tokens kept; clean_generated_text strips or converts the markers.
    return [clean_generated_text(tokenizer.decode(output, skip_special_tokens=False))
            for output in generated_output]

def find_candidates_with_lowest_common_words(candidates, previous_candidate_with_max_prob, num_of_candidates_to_choose=3):
    """Return the candidates that share the fewest words with the previous line.

    Words come from splitting on single spaces. A candidate's overlap is the number of
    *distinct* words it has in common with ``previous_candidate_with_max_prob``. Candidates are
    ranked by ascending overlap, and ties keep their input order.

    Args:
        candidates: Candidate lines (strings).
        previous_candidate_with_max_prob: The line the candidates would follow.
        num_of_candidates_to_choose: Maximum number of candidates to return.

    Returns:
        A list of at most ``num_of_candidates_to_choose`` candidates, fewest shared words
        first. When fewer candidates are available, all of them are returned (never None).
    """
    previous_sentence_words = set(previous_candidate_with_max_prob.split(" "))

    def _num_common_words(candidate):
        return len(previous_sentence_words.intersection(candidate.split(" ")))

    # sorted() is stable, so candidates with equal overlap stay in their original order.
    return sorted(candidates, key=_num_common_words)[:num_of_candidates_to_choose]

## Loading the BERT Next-Sentence-Prediction Model and Tokenizer

In [None]:
# Next-sentence-prediction model that ranks candidate lyric lines. Its weights come from the
# local "models/bert_model" checkpoint (presumably fine-tuned from ParsBERT — verify), while
# the tokenizer is loaded from the hub model named below.
bert_model_name = "HooshvareLab/bert-base-parsbert-uncased"
bert_model = BertForNextSentencePrediction.from_pretrained("models/bert_model")
bert_model = bert_model.to(device)
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

## Generating The Final Lyrics

In [None]:
num_of_lines = 15
num_of_candidates_to_consider = 10
# Token budget for each sampled continuation, on top of the prompt. The prompt grows by one
# accepted line per iteration, so a fixed total length would soon leave no room for new
# text, and the line lookup below would have nothing to read.
max_new_tokens_per_line = 64

rapper = rappers[np.random.randint(0, len(rappers))]
base_text = f"{beginning_of_sentence_token}{rapper}{start_of_rap_token}"
# The first lyric line is judged by the NSP model against the prompt itself.
previous_candidate_with_max_prob = base_text

# After clean_generated_text, a decoded sample reads "<rapper>\n<line 1>\n<line 2>...".
# Index 0 is the rapper name, and after i - 1 accepted lines the first freshly generated line
# sits at index i. Index 0 carries no lyric, so the loop starts at 1 and produces
# num_of_lines - 1 lines.
for i in range(1, num_of_lines):
    prompt_length = len(tokenizer.encode(base_text))
    generated_raps = generate(
        base_text,
        max_length=min(prompt_length + max_new_tokens_per_line, model.config.n_positions),
        num_return_sequences=num_of_candidates_to_consider,
    )

    # Keep only samples that reached line i with some text on it. A sample that stopped early
    # has no line i, and an empty line would shift the indexing once appended.
    candidates = []
    for generated_rap in generated_raps:
        generated_rap_lines = generated_rap.split("\n")
        if len(generated_rap_lines) > i and generated_rap_lines[i].strip():
            candidates.append(generated_rap_lines[i])
    if not candidates:
        print(f"Could not generate a candidate for line {i}; stopping early.")
        break

    # Prefer candidates that reuse few words of the previous line. Fall back to the full list
    # if the filter yields nothing.
    candidates = find_candidates_with_lowest_common_words(candidates, previous_candidate_with_max_prob) or candidates

    # Score every candidate as the continuation of the previous line in a single batch.
    encoding = bert_tokenizer(
        [previous_candidate_with_max_prob] * len(candidates),
        candidates,
        add_special_tokens=True,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        logits = bert_model(**encoding, return_dict=True).logits

    # Column 0 of the NSP head means "B follows A" under the Hugging Face label convention
    # (assumes models/bert_model was fine-tuned with the same labels — confirm). Softmax turns
    # each logit pair into P(is next), so candidates are not ranked on a raw logit.
    is_next_probs = torch.softmax(logits, dim=-1)[:, 0]
    candidate_with_max_prob = candidates[int(is_next_probs.argmax())]

    base_text += candidate_with_max_prob + "<sep>"
    previous_candidate_with_max_prob = candidate_with_max_prob

print(clean_generated_text(base_text))