# HuggingArtists - Train a model to generate lyrics

Choose your favorite artist and train a language model to write new lyrics in their unique voice in just 5 minutes.

Training automatically continues from the last checkpoint on the [Hub](https://huggingface.co/huggingartists).

<img src="https://raw.githubusercontent.com/AlekseyKorshuk/huggingartists/master/img/logo.jpg" width="520" align="center">

## How to use

To start the demo, open the "Runtime" menu at the top and click "Run all", or press "Ctrl" + "F9".

Contact me if something doesn't work: [link](https://github.com/AlekseyKorshuk).


In [1]:
#@title Settings
#@markdown Enter artist name:
artist_name = "Eminem" #@param {type:"string"}
#@markdown Check existing dataset first (it will save 1-2 min but will not add new songs):
check_dataset = False #@param {type:"boolean"} 
#@markdown Total number of training epochs to perform (more epochs -> better result -> more time):
num_train_epochs = 1 #@param {type:"integer"}
!nvidia-smi

Fri Jul 30 12:10:47 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   69C    P8    11W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
#@title Install dependencies

%%capture
!pip install transformers
!pip install datasets
!pip install torch
!pip install wandb
!pip install lyricsgenius
!pip install aiohttp
!pip install langdetect
!pip install --upgrade jax jaxlib 
!pip install --upgrade git+https://github.com/google/flax.git
!pip install tqdm --upgrade
!pip install hf-lfs
!git config --global user.email "ale-kor02@mail.ru"
!git config --global user.name "Aleksey Korshuk"
!git lfs install



## Pre-process data

In [3]:
#@title Attach needed code
# %%capture

import logging, sys
logging.disable(sys.maxsize)

from transformers.hf_api import HfApi
hfapi = HfApi()
user, namespace = 'huggingartists-app', 'huggingartists'
token = hfapi.login(user, namespace)
assert hfapi.whoami(token)[0] == user, "Could not log into Hugging Face"
!mkdir /root/.huggingface -p
text_file = open("/root/.huggingface/token", "w+")
text_file.write(token)
text_file.close() 

TOKEN = "q_JK_BFy9OMiG7fGTzL-nUto9JDv3iXI24aYRrQnkOvjSCSbY4BuFIindweRsr5I"
DATASET_LOAD_SCRIPT_URL = "https://raw.githubusercontent.com/AlekseyKorshuk/huggingartists/main/datasets/dataset.py"
DATASET_CARD_URL = "https://raw.githubusercontent.com/AlekseyKorshuk/huggingartists/main/datasets/README.md"
MODEL_CARD_URL = "https://raw.githubusercontent.com/AlekseyKorshuk/huggingartists/main/models/README.md"

from IPython.display import display, HTML, Javascript, clear_output
import lyricsgenius
from tqdm.notebook import tqdm as bar
import requests
from bs4 import BeautifulSoup
import re
from datasets import Dataset, DatasetDict
import numpy as np
import time
import os
import json
import langdetect
import datetime
from pathlib import Path
import wandb
import pathlib
import nest_asyncio
nest_asyncio.apply()

import asyncio
from concurrent.futures import ProcessPoolExecutor
import sys
sys.setrecursionlimit(99999)
import aiohttp

parser = requests.get("https://raw.githubusercontent.com/AlekseyKorshuk/huggingartists/main/datasets/parse.py").text
with open('parse.py', 'w+') as f:
  f.write(parser)

genius = lyricsgenius.Genius(TOKEN)
artist = genius.search_artist(artist_name, max_songs=0, get_full_info=False)


clear_output()
from IPython.utils import io


def stylize():
    "Handle dark mode"
    display(HTML('''
    <style>
    :root {
        --table_bg: #EBF8FF;
    }
    html[theme=dark] {
        --colab-primary-text-color: #d5d5d5;
        --table_bg: #2A4365;
    }
    .jupyter-widgets {
        color: var(--colab-primary-text-color);
    }
    table {
        border-collapse: collapse !important;
    }
    td {
        text-align:left !important;
        border: solid var(--table_bg) !important;
        border-width: 1px 0 !important;
        padding: 6px !important;
    }
    tr:nth-child(even) {
        background-color: var(--table_bg) !important;
    }
    .table_odd {
        background-color: var(--table_bg) !important;
        margin: 0 !important;
    }
    .table_even {
        border: solid var(--table_bg) !important;
        border-width: 1px 0 !important;
        margin: 0 !important;
    }
    .jupyter-widgets {
        margin: 6px;
    }
    .widget-html-content {
        font-size: var(--colab-chrome-font-size) !important;
        line-height: 1.24 !important;
    }
    </style>'''))


def create_dataset_load_script(model_name):
  response = requests.get(DATASET_LOAD_SCRIPT_URL)
  text = str(response.text)
  text = text.replace("MODEL_NAME", model_name)
  with open(f'{model_name}/{model_name}.py', 'w+') as f:
    f.write(text)

def create_dataset_card(model_name, settings):
  response = requests.get(DATASET_CARD_URL)
  text = str(response.text)
  for key in settings.keys():
    text = text.replace(key, str(settings[key]))
  with open(f'{model_name}/README.md', 'w+') as f:
    f.write(text)

def create_model_card(model_name, settings):
  response = requests.get(MODEL_CARD_URL)
  text = str(response.text)
  for key in settings.keys():
    text = text.replace(key, str(settings[key]))
  with open(f'{model_name}/README.md', 'w+') as f:
    f.write(text)

def artist_songs(artist_id, per_page=50, page=None, sort='popularity'):
  url = f'https://api.genius.com/artists/{artist_id}/songs?sort={sort}&per_page={per_page}&page={page}'
  headers = {
      'Authorization': f'Bearer {TOKEN}'
  }
  data = requests.get(
      url,
      headers=headers, 
      stream=True
  ).json()
  return data['response']


def get_artist_song_urls(artist_id):
  
  urls = []
  next_page = 1
  with bar(total=None) as pbar:
    pbar.set_description("⏳ Searching songs")
    while next_page is not None:

      data = artist_songs(artist_id, per_page=50, page=next_page)
      next_page = data['next_page']
      
      for song in data['songs']:
        urls.append(song['url'])
      pbar.update(len(data['songs']))   

    pbar.set_description("✅ Done")
  return urls

async def get_song_urls(artist_id):
  access_token = 'Bearer ' + TOKEN
  authorization_header = {'authorization': access_token}
  urls = []
  async with aiohttp.ClientSession(headers=authorization_header) as session:
    with bar(total=None) as pbar:
      pbar.set_description("⏳ Searching songs...")
      next_page = 1
      while next_page is not None:
        async with session.get(f"https://api.genius.com/artists/{artist_id}/songs?sort=popularity&per_page=50&page={next_page}", timeout=999) as resp:
          response = await resp.json()
          response = response['response']
        next_page = response['next_page']
        
        for song in response['songs']:
          urls.append(song['url'])
        pbar.update(len(response['songs']))
      pbar.set_description("✅ Done")
  return urls


def _get_lyrics(song_url):
    text = requests.get(song_url, stream=True).text
    
    html = BeautifulSoup(text.replace('<br/>', '\n'), 'html.parser')
    div = html.find("div", class_=re.compile("^lyrics$|Lyrics__Root"))
    if div is None:
      return None

    lyrics = div.get_text()

    lyrics = re.sub(r'(\[.*?\])*', '', lyrics)
    lyrics = re.sub('\n{2}', '\n', lyrics)  # Gaps between verses
    
    lyrics = str(lyrics.strip("\n"))
    lyrics = lyrics.replace("EmbedShare URLCopyEmbedCopy", "").replace("'", "")
    lyrics = re.sub(r"[\(\[].*?[\)\]]", "", lyrics)
    lyrics = re.sub(r'\d+$', '', lyrics)
    lyrics = str(lyrics).lstrip().rstrip()
    lyrics = str(lyrics).replace("\n\n", "\n")
    lyrics = str(lyrics).replace("\n\n", "\n")
    lyrics = re.sub(' +', ' ', lyrics)
    lyrics = str(lyrics).replace('"', "")
    lyrics = str(lyrics).replace("'", "")
    lyrics = str(lyrics).replace("*", "")
    return str(lyrics)


def get_lyrics(url):
  return _get_lyrics(url)


def process_page(html):
    '''Meant for CPU-bound workload'''
    html = BeautifulSoup(html.replace('<br/>', '\n'), 'html.parser')
    div = html.find("div", class_=re.compile("^lyrics$|Lyrics__Root"))
    if div is None:
      lyrics = ""
    else:
      lyrics = div.get_text()
    
    lyrics = re.sub(r'(\[.*?\])*', '', lyrics)
    lyrics = re.sub('\n{2}', '\n', lyrics)  # Gaps between verses
    
    lyrics = str(lyrics.strip("\n"))
    lyrics = lyrics.replace("EmbedShare URLCopyEmbedCopy", "").replace("'", "")
    lyrics = re.sub(r"[\(\[].*?[\)\]]", "", lyrics)
    lyrics = re.sub(r'\d+$', '', lyrics)
    lyrics = str(lyrics).lstrip().rstrip()
    lyrics = str(lyrics).replace("\n\n", "\n")
    lyrics = str(lyrics).replace("\n\n", "\n")
    lyrics = re.sub(' +', ' ', lyrics)

    lyrics = str(lyrics).replace('"', "")
    # lyrics = str(lyrics).replace("'", "")
    lyrics = str(lyrics).replace("*", "")
    return lyrics #, re.compile("^lyrics$|Lyrics__Root")


async def fetch_page(url, session):
    '''Meant for IO-bound workload'''
    async with session.get(url, timeout=999) as res:
      return await res.text()


async def process(url, session, pool, pbar):
    html = await fetch_page(url, session)
    pbar.update(1)
    return await asyncio.wrap_future(pool.submit(process_page, html))


async def dispatch(urls, pbar):
    print('\n')
    pool = ProcessPoolExecutor()
    async with aiohttp.ClientSession() as session:
        coros = (process(url, session, pool, pbar) for url in urls)
        lyrics = await asyncio.gather(*coros)
    return lyrics

def create_dataset(lyrics):
  train_percentage = 0.9
  validation_percentage = 0.07
  test_percentage = 0.03

  dataset = {}

  dataset['train'] = Dataset.from_dict({'text': list(lyrics)})

  # train, validation, test = np.split(lyrics, [int(len(lyrics)*train_percentage), int(len(lyrics)*(train_percentage + validation_percentage))])
  # if len(list(train)) != 0:
  #   dataset['train'] = Dataset.from_dict({'text': list(train)})
  # if len(list(validation)) != 0:
  #   dataset['validation'] = Dataset.from_dict({'text': list(validation)})
  # if len(list(test)) != 0:
  #   dataset['test'] = Dataset.from_dict({'text': list(test)})
  # del train
  # del validation
  # del test

  datasets = DatasetDict(dataset)
  del dataset
  return datasets


def commit_files(model_name, message):
  %cd $model_name
  !git add .
  !git commit -m "{message}"
  !git push
  %cd ..

def parse_dataset(model_name, namespace, artist_id):
  
  with io.capture_output() as captured:
    url = f"https://huggingface.co/datasets/{namespace}/{model_name}/tree/main"
    data = requests.get(url).text
    if data == "Not Found":
      !huggingface-cli repo create $model_name --type dataset --organization $namespace -y
    !rm -rf $model_name
    !git clone https://$user:$token@huggingface.co/datasets/$namespace/$model_name
    
  save_path = f"{model_name}/datasets.json"
  !python parse.py \
      --artist_id=$artist_id \
      --token=$TOKEN \
      --save_path=$save_path
  
  with io.capture_output() as captured:
    %cd $model_name
    # !git lfs untrack "*.json"
    !git lfs track "*.json"
    # !rm -rf
    %cd ..
    # !rm -rf $model_name
    # !mkdir $model_name
    create_dataset_load_script(model_name)
    global artist_url
    root_directory = Path(model_name)
    size = sum(f.stat().st_size for f in root_directory.glob('**/*') if f.is_file()) / 1000000
    
    with open(save_path) as f:
      data = json.load(f)
    
    settings = {
        'LANGUAGE': 'en',
        'USER_HANDLE': model_name,
        'YEAR': datetime.datetime.now().year,
        'USER_NAME': artist.name,
        'USER_PROFILE': artist.image_url,
        'TRAIN_SIZE': len(data['train']),
        'SIZE': str(size)
    }
    create_dataset_card(model_name, settings)
    commit_files(model_name, namespace)
    !rm -rf $model_name

from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML


def show_random_elements(dataset, num_examples=3):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    print(dataset[picks]['text'])
    print(len(dataset[picks]['text']))
    df = pd.DataFrame(dataset[picks]['text'])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))


def tokenize_function(examples):
    return tokenizer(examples["text"])


def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder. We could add padding instead if the model supported it;
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


def show_result(dataset, num_examples=3):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    data = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
        data.append(str(dataset[pick]))  
    df = pd.DataFrame(data)
    # for column, typ in dataset.features.items():
    #     if isinstance(typ, ClassLabel):
    #         df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html(index=False).replace("\\n","<br>")))


def post_process(output_sequences):
    predictions = []
    generated_sequences = []

    max_repeat = 2

    # decode prediction
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
        generated_sequences.append(text.strip())
                    
    for g in generated_sequences:
        res = str(g).replace('\n\n\n', '\n').replace('\n\n', '\n')
        lines = res.split('\n')
        # Drop any line that would be the third identical line in a row.
        i = max_repeat
        while i < len(lines):
          remove_count = 0
          for index in range(0, max_repeat):
            if lines[i - index - 1] == lines[i - index]:
              remove_count += 1
          if remove_count == max_repeat:
            # The next line shifts into position i, so i stays where it is.
            lines.pop(i)
          else:
            i += 1
        predictions.append('\n'.join(lines))

    return predictions

def get_table(table_data):
  html = ("</head>\r\n"
    "<body>\r\n\r\n"
    "<h2></h2>"
    "\r\n\r\n"
    "<table>\r\n"
    "    <colgroup>\r\n"
    "       <col span=\"1"
    "\" style=\"width: 10"
    "%;\">\r\n"
    "       <col span=\"1"
    "\" style=\"width: 10"
    "0%;\">\r\n"
    "    </colgroup>\r\n"
    f"{' '.join(table_data)}"
    "</table>\r\n\r\n"
    "</body>\r\n"
    "</html>")
  
  return html

def get_share_button(url):
    return f'''
            <div style="width: 76px;">
                <a target="_blank" href="{url}" style='background-color:rgb(27, 149, 224);border-bottom-left-radius:4px;border-bottom-right-radius:4px;border-top-left-radius:4px;border-top-right-radius:4px;box-sizing:border-box;color:rgb(255, 255, 255);cursor:pointer;display:inline-block;font-family:"Helvetica Neue", Arial, sans-serif;font-size:13px;font-stretch:100%;font-style:normal;font-variant-caps:normal;font-variant-east-asian:normal;font-variant-ligatures:normal;font-variant-numeric:normal;font-weight:500;height:28px;line-height:26px;outline-color:rgb(255, 255, 255);outline-style:none;outline-width:0px;padding-bottom:1px;padding-left:9px;padding-right:10px;padding-top:1px;position:relative;text-align:left;text-decoration-color:rgb(255, 255, 255);text-decoration-line:none;text-decoration-style:solid;text-decoration-thickness:auto;user-select:none;vertical-align:top;white-space:nowrap;zoom:1;'>
                <i style='background-attachment:scroll;background-clip:border-box;background-color:rgba(0,0,0,0);background-image:url(data:image/svg+xml,%3Csvg%20xmlns%3D%22http%3A%2F%2Fwww.w3.org%2F2000%2Fsvg%22%20viewBox%3D%220%200%2072%2072%22%3E%3Cpath%20fill%3D%22none%22%20d%3D%22M0%200h72v72H0z%22%2F%3E%3Cpath%20class%3D%22icon%22%20fill%3D%22%23fff%22%20d%3D%22M68.812%2015.14c-2.348%201.04-4.87%201.744-7.52%202.06%202.704-1.62%204.78-4.186%205.757-7.243-2.53%201.5-5.33%202.592-8.314%203.176C56.35%2010.59%2052.948%209%2049.182%209c-7.23%200-13.092%205.86-13.092%2013.093%200%201.026.118%202.02.338%202.98C25.543%2024.527%2015.9%2019.318%209.44%2011.396c-1.125%201.936-1.77%204.184-1.77%206.58%200%204.543%202.312%208.552%205.824%2010.9-2.146-.07-4.165-.658-5.93-1.64-.002.056-.002.11-.002.163%200%206.345%204.513%2011.638%2010.504%2012.84-1.1.298-2.256.457-3.45.457-.845%200-1.666-.078-2.464-.23%201.667%205.2%206.5%208.985%2012.23%209.09-4.482%203.51-10.13%205.605-16.26%205.605-1.055%200-2.096-.06-3.122-.184%205.794%203.717%2012.676%205.882%2020.067%205.882%2024.083%200%2037.25-19.95%2037.25-37.25%200-.565-.013-1.133-.038-1.693%202.558-1.847%204.778-4.15%206.532-6.774z%22%2F%3E%3C%2Fsvg%3E);background-origin:padding-box;background-position-x:0px;background-position-y:0px;background-repeat-x;background-repeat-y;background-size:auto;color:rgb(255,255,255);cursor:pointer;display:inline-block;font-family:"Helvetica Neue",Arial,sans-serif;font-size:13px;font-stretch:100%;font-style:italic;font-variant-caps:normal;font-variant-east-asian:normal;font-variant-ligatures:normal;font-variant-numeric:normal;font-weight:500;height:18px;line-height:26px;position:relative;text-align:left;text-decoration-thickness:auto;top:4px;user-select:none;white-space:nowrap;width:18px;'></i>
                <span style='color:rgb(255,255,255);cursor:pointer;display:inline-block;font-family:"Helvetica Neue",Arial,sans-serif;font-size:13px;font-stretch:100%;font-style:normal;font-variant-caps:normal;font-variant-east-asian:normal;font-variant-ligatures:normal;font-variant-numeric:normal;font-weight:500;line-height:26px;margin-left:4px;text-align:left;text-decoration-thickness:auto;user-select:none;vertical-align:top;white-space:nowrap;zoom:1;'>Tweet</span>
            </a>
            </div>
            '''

def share_model_table(artist_name, model_name):
  url = f"https://twitter.com/intent/tweet?text=I created an AI bot of {artist_name} with %23huggingartists!%0APlay with my model or create your own! &url=https://huggingface.co/huggingartists/{model_name}"

  share_button = get_share_button(url)
  table_data = [
        f'<tr><td>{share_button}</td><td>🎉 Share {artist_name} model: <a href="https://huggingface.co/huggingartists/{model_name}">https://huggingface.co/huggingartists/{model_name}</a></td></tr>'    
  ]
  return get_table(table_data)

def get_share_lyrics_url(artist_name, model_name, lyrics):
   return "https://twitter.com/intent/tweet?text=I created an AI bot of " + artist_name + " with %23huggingartists!%0A%0ABrand new song:%0A" + lyrics.replace('\n', '%0A').replace('"', '%22') + "%0A%0APlay with my model or create your own! &url=https://huggingface.co/huggingartists/" + model_name
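
The text cleanup that `process_page` applies to each scraped page can be tried offline. The sketch below is a simplified, standalone version of those steps: the `clean_lyrics` name and the sample text are made up for illustration, and the HTML parsing is left out.

```python
import re

def clean_lyrics(raw: str) -> str:
    """Simplified, offline version of the cleanup in process_page (hypothetical helper)."""
    text = re.sub(r"[\(\[].*?[\)\]]", "", raw)   # drop [Verse 1], (ad-libs), etc.
    text = re.sub(r"\d+$", "", text.strip())      # drop a trailing number at the very end
    text = re.sub(r"\n{2,}", "\n", text)          # collapse blank lines between verses
    text = re.sub(r" +", " ", text)               # collapse runs of spaces
    text = text.replace('"', "").replace("*", "")  # remove quotes and asterisks
    return text.strip()

# Made-up input that exercises every rule above.
sample = '[Verse 1]\nHello (yeah)   "world"\n\n\n[Chorus]\nSing it *loud*\n42'
cleaned = clean_lyrics(sample)
print(cleaned)
```

Running it on a few real pages before a full scrape is a cheap way to check that section markers and widget text do not leak into the training data.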

In [4]:
#@title Collect data
if artist is not None:
  time.sleep(0.1)
  artist_dict = genius.artist(artist.id)['artist']
  artist_url = str(artist_dict['url'])
  model_name = artist_url[artist_url.rfind('/') + 1:].lower()

  datasets = None
  if check_dataset:
    print("Check existing dataset first...")
    url = f"https://huggingface.co/datasets/{namespace}/{model_name}/tree/main"
    data = requests.get(url).text
    if data != "Not Found":
      from datasets import load_dataset
      datasets = load_dataset(f"{namespace}/{model_name}")
      print("Dataset downloaded!")
  
  if datasets is None:
    if check_dataset:
      print("Dataset does not exist!")
    parse_dataset(model_name, namespace, int(artist.id))
    from datasets import load_dataset
    datasets = load_dataset(f"{namespace}/{model_name}")
  # show_random_elements(datasets["train"], num_examples=3)
else:
  raise Exception("Artist does not exist!")

⏳ Searching songs...: : 1228it [01:17, 15.75it/s]
⏳ Parsing lyrics...: 100% 1228/1228 [00:29<00:00, 41.62it/s] 


Downloading:   0%|          | 0.00/4.08k [00:00<?, ?B/s]

Downloading and preparing dataset lyrics_dataset/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/lyrics_dataset/default/1.0.0/c7899943cc5a9b2773cbe4ee63397b67826bf40b759a56d0e2040e88aac9a721...


Downloading:   0%|          | 0.00/3.96M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

Dataset lyrics_dataset downloaded and prepared to /root/.cache/huggingface/datasets/lyrics_dataset/default/1.0.0/c7899943cc5a9b2773cbe4ee63397b67826bf40b759a56d0e2040e88aac9a721. Subsequent calls will reuse this data.


## Clone or create repository 

In [5]:
#@title Clone or create repository
with io.capture_output() as captured:
  !rm -rf $model_name
  repo_url = hfapi.create_repo(token, name=model_name, organization=namespace, exist_ok=True)
  !git clone https://$user:$token@huggingface.co/$namespace/$model_name

## Map datasets and set up the Trainer

In [6]:
#@title Download model and tokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments

try:
  # Continue from the artist's existing model on the Hub, if there is one
  tokenizer = AutoTokenizer.from_pretrained(f"{namespace}/{model_name}")
  model = AutoModelForCausalLM.from_pretrained(f"{namespace}/{model_name}", cache_dir=pathlib.Path('cache').resolve())
except Exception:
  # Otherwise start from the base GPT-2 checkpoint
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
  model = AutoModelForCausalLM.from_pretrained("gpt2", cache_dir=pathlib.Path('cache').resolve())

tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=1, remove_columns=["text"])

# block_size = tokenizer.model_max_length
block_size = int(tokenizer.model_max_length / 4)

lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=1,
)

Downloading:   0%|          | 0.00/253 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/945 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/510M [00:00<?, ?B/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]
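
To see what `group_texts` does to the tokenized songs, here is a minimal sketch of the same idea on plain Python lists. The `chunk_token_ids` name and the toy token ids are illustrative assumptions, and the block size is kept tiny for readability (the notebook uses a quarter of `tokenizer.model_max_length`).

```python
def chunk_token_ids(sequences, block_size):
    """Concatenate token-id lists and cut them into equal blocks, dropping the remainder."""
    flat = [tok for seq in sequences for tok in seq]
    usable = (len(flat) // block_size) * block_size
    blocks = [flat[i:i + block_size] for i in range(0, usable, block_size)]
    # For causal language modelling the labels are a copy of the inputs;
    # the one-position shift happens inside the model's loss computation.
    return {"input_ids": blocks, "labels": [block[:] for block in blocks]}

toy = [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]  # made-up token ids for three "songs"
result = chunk_token_ids(toy, block_size=4)
print(result["input_ids"])
```

Song boundaries are not preserved: a block can start in the middle of one song and end in the next, and the last two tokens here are discarded because they do not fill a block.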

In [7]:
#@title Set up the Trainer

trainer_state_path = f'{model_name}/trainer_state.json'
if os.path.isfile(trainer_state_path):
  with open(trainer_state_path, "r") as f:
    trainer_state = json.load(f)
  epoch = trainer_state['epoch']
  num_train_epochs += epoch

seed_data = random.randint(0,2**32-1)
# Set-up Trainer
os.environ['WANDB_WATCH'] = 'false'  # used in Trainer
training_args = TrainingArguments(
    f"output/{model_name}",
    overwrite_output_dir=True,
    # evaluation_strategy = "epoch",
    learning_rate=1.372e-4,
    weight_decay=0.01,
    num_train_epochs=num_train_epochs,
    save_total_limit=1,
    save_strategy='epoch',
    save_steps=1,
    report_to=None,
    seed=seed_data,
    logging_steps=5,
    # disable_tqdm=True
    # load_best_model_at_end=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    
)

from transformers import get_cosine_schedule_with_warmup
train_dataloader = trainer.get_train_dataloader()
num_train_steps = len(train_dataloader)
trainer.create_optimizer_and_scheduler(num_train_steps)
trainer.lr_scheduler = get_cosine_schedule_with_warmup(
      trainer.optimizer,
      num_warmup_steps=0,
      num_training_steps=num_train_steps
)

trainer.model.config.task_specific_params['text-generation'] = {
                    'do_sample': True,
                    'min_length': 100,
                    'max_length': 200,
                    'temperature': 1.,
                    'top_p': 0.95,
                    # 'prefix': '<|endoftext|>',
                    }
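
The shape of the learning-rate curve set up above can be sketched in plain Python. This is an approximation of a warmup-then-cosine schedule, assuming the default half-cycle setting of `get_cosine_schedule_with_warmup`; the `cosine_lr` helper and the step count of 100 are hypothetical, and the sketch only covers steps up to `total_steps`.

```python
import math

def cosine_lr(step, total_steps, base_lr=1.372e-4, warmup_steps=0):
    """Learning rate at `step`: linear warmup, then half-cosine decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

total = 100  # hypothetical number of optimizer steps
for step in (0, 25, 50, 75, 100):
    print(step, f"{cosine_lr(step, total):.3e}")
```

With `num_warmup_steps=0`, as in the cell above, training starts at the full learning rate and decays smoothly until `num_training_steps` is reached.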

## Train and upload the model

In [8]:
#@title W&B Setup

from torch import __version__ as torch_version
from transformers import __version__ as transformers_version

with io.capture_output() as captured:
  wandb.login(key='cd33331f97be3145253704fc38efef090ffe8151') # huggingartists service key

WANDB_PROJECT = 'huggingartists'
WANDB_NOTES = "Github repo: https://github.com/AlekseyKorshuk/huggingartists"
WANDB_ENTITY = 'huggingartists'
VERSION = 1.0
os.environ['WANDB_NOTEBOOK_NAME'] = 'huggingartists-demo.ipynb'  # used in wandb cli
model_card_settings = {}

def setup_wandb():
  global model_card_settings
  run = wandb.init(name=f"{model_name}-preprocess",
          job_type='preprocess',
          config={'huggingartists version':VERSION,
              'handle':model_name,
              'seed data':seed_data},
          project = WANDB_PROJECT,
          entity = WANDB_ENTITY,
          notes = WANDB_NOTES,
          reinit=True)
    
  # log raw lyrics as input
  global metadata
  metadata={'handle':model_name,
        'huggingartists version': VERSION}
  artifact_input = wandb.Artifact(
      f"lyrics-{model_name}",
      type='raw-dataset',
      description=f"Lyrics from {model_name} downloaded with Genius",                            
      metadata=metadata)
  with artifact_input.new_file('lyrics.txt') as f:
    json.dump(datasets['train'].to_dict(), f, indent=0, ensure_ascii=False)
  run.use_artifact(artifact_input)
  # log dataset as output                        
  metadata={'handle':model_name,
        'seed data': seed_data,
        'epochs': num_train_epochs,
        'huggingartists version': VERSION}
  global artifact_dataset
  artifact_dataset = wandb.Artifact(
    f"dataset-{model_name}",
    type='train-dataset',
    description=f"Dataset created from lyrics of {model_name}",
    metadata=metadata)
  with open(f"data_{model_name}_train.txt", 'w', encoding='utf-8') as f:
    f.write('\n\n\n'.join(datasets['train']['text']))
  artifact_dataset.add_file(f"data_{model_name}_train.txt")
  run.log_artifact(artifact_dataset)

  # keep track of url
  wandb_url = wandb.run.get_url()
  model_card_settings['WANDB_PREPROCESS'] = str(wandb_url)

  combined_dict = {**model.config.to_dict(), **training_args.to_sanitized_dict()}
  run = wandb.init(name=f"{model_name}-train",
          job_type='train',
          config={'huggingartists version':VERSION,
              'pytorch version': torch_version,
              'transformers version': transformers_version,
              'handle':model_name,
              **combined_dict},
          project = WANDB_PROJECT,
          entity = WANDB_ENTITY,
          notes = WANDB_NOTES,
          reinit=True)


  # keep track of url
  wandb_url = wandb.run.get_url()
  model_card_settings['WANDB_TRAIN'] = wandb_url



  # log dataset and pretrained model
  artifact_dataset.wait()
  run.use_artifact(artifact_dataset)
  artifact_gpt2 = wandb.Artifact(
    f'gpt2',
    type='pretrained-model',
    description=f'Pretrained model from OpenAI downloaded from 🤗 Transformers: https://huggingface.co/gpt2',
    metadata={'huggingartists version': VERSION})
  artifact_gpt2.add_dir('cache', name='gpt2')
  run.use_artifact(artifact_gpt2)
  return run

with io.capture_output() as captured:
  run = setup_wandb()

In [9]:
#@title W&B Visualization
from IPython.display import HTML

url = wandb.run._get_run_url() + "?jupyter=true"

html = HTML(
    """<iframe src="%s" style="border:none;width:100%%;height:900px">
                </iframe>"""
                % url
    )

display(
    html
)

In [10]:
#@title Train the model

# from IPython.display import HTML

# display(
#     html
# )

import torch
torch.cuda.empty_cache()
if os.path.isfile(trainer_state_path):
  data = trainer.train(resume_from_checkpoint=model_name)
else:
  data = trainer.train()
# print(data)

0it [00:00, ?it/s]

Step,Training Loss
32925,0.2539
32930,0.2785
32935,0.3592
32940,0.2371
32945,0.2547
32950,0.2917
32955,0.2698
32960,0.2711
32965,0.2781
32970,0.297


In [11]:
#@title Create model card

from os import listdir
from os.path import isfile, join, isdir
mypath = f'output/{model_name}'
checkpoints = [join(mypath, f) for f in listdir(mypath) if isdir(join(mypath, f)) and 'checkpoint' in join(mypath, f)]
checkpoint = checkpoints[0]
tokenizer.save_pretrained(checkpoint)

!mv $checkpoint/* $model_name
model_card_settings['LANGUAGE'] = 'en'
model_card_settings['USER_HANDLE'] = model_name
model_card_settings['YEAR'] = datetime.datetime.now().year
model_card_settings['USER_NAME'] = artist.name
model_card_settings['USER_PROFILE'] = artist.image_url

create_model_card(model_name, model_card_settings)

In [12]:
#@title W&B Artifacts
global metadata
metadata={'model url':f"https://huggingface.co/huggingartists/{model_name}",
		  'seed trainer':seed_data,
		  **metadata}
artifact_trained = wandb.Artifact(
	model_name,
	type='finetuned-model',
	description=f"Model fine-tuned on lyrics from {model_name}",
	metadata=metadata)

model_path = pathlib.Path(model_name)
hf_urls = [f'https://huggingface.co/huggingartists/{model_name}/resolve/main/{f.name}' for f in model_path.glob('*') if f.suffix]
for hf_url in hf_urls:
	artifact_trained.add_reference(hf_url, checksum = False)
global run
run.log_artifact(artifact_trained)
print("")





In [13]:
#@title Upload the model

url = f"https://huggingface.co/{namespace}/{model_name}/raw/main/trainer_state.json"
data = requests.get(url).text
try:
  data = json.loads(data)
except ValueError:  # the response was not valid JSON
  data = None
if data is None or data['epoch'] <= num_train_epochs:
  with io.capture_output() as captured:
    from transformers import FlaxAutoModelForCausalLM
    model_flax = FlaxAutoModelForCausalLM.from_pretrained(model_name, from_pt=True)
    model_flax.save_pretrained(model_name)
    commit_files(model_name, namespace)

!rm -rf output/$model_name
!rm -rf $model_name

## Generate lyrics

In [14]:
# @title Generate
#@markdown Enter starting sentence:
start = "I am" #@param {type:"string"}
#@markdown Amount of generated texts:
num_sequences = 10 #@param {type:"integer"}
#@markdown Generation settings:
min_length = 100 #@param {type:"integer"}
max_length = 160 #@param {type:"integer"}
temperature = 1 #@param {type:"slider", min:0, max:3, step:0.01}
top_p = 0.95 #@param {type:"slider", min:0, max:1, step:0.01}
top_k = 50 #@param {type:"integer"}
repetition_penalty = 1.0 #@param {type:"number"}

encoded_prompt = tokenizer(start, add_special_tokens=False, return_tensors="pt").input_ids
encoded_prompt = encoded_prompt.to(trainer.model.device)
# prediction
output_sequences = trainer.model.generate(
                        input_ids=encoded_prompt,
                        max_length=max_length,
                        min_length=min_length,
                        temperature=float(temperature),
                        top_p=float(top_p),
                        top_k=int(top_k),
                        do_sample=True,
                        repetition_penalty=repetition_penalty,
                        num_return_sequences=num_sequences
                        )
# Post-processing
predictions = post_process(output_sequences)

wandb.log({'examples': wandb.Table(data=[(start, result) for result in predictions], columns=['Input', 'Prediction'])})
stylize()
table_data = []
for result in predictions:
  table_data.append('<tr><td>' + get_share_button(get_share_lyrics_url(artist.name, model_name, result)) + '</td><td>' + result.replace("\n", "<br>") + '</td></tr>')
display(HTML(share_model_table(artist.name, model_name)))
display(HTML(get_table(table_data)))

0,1
Tweet,🎉 Share Eminem model: https://huggingface.co/huggingartists/eminem


0,1
Tweet,"I am depressed, Im disfigured I need a new doctor, I just know Ill stare him dead in the mouth With a plastic spoon and a squeegee, Im pissed I should eat a pill, but the side effects Are nasty to my family, like all the pain they could take If they ate every carats, its one million shekels And I dont need an abortion, cause I just take the risk Shoot up the playground and teach the dogs that Brain damage, when I pop the pills, take the heart out the baby Doctor says pop the brain pills cause its just insane, brain out the baby Is what I think? Brain out the baby? Brain out the baby? Yeah, brain out the baby, yeah"
Tweet,"I am a criminal! I dont gotta say a word, I just flip em the bird And keep goin, I dont take shit from no one Im a criminal! Coz I got a gun, do you read me? If you fuck me over some worrisome There aint nothin that I can do to you Im doing a hundred-yard dash just to slash you From head to toe, you makezer I aint never seen a foreseen a murder scene You might ha Run up on me, you be dead period Fore you collapse in your car, its gonna be nothing Fore you jump out of my freaking Pinto You may think thats the first place But thats just the way I live Sometimes I just feel"
Tweet,"I am I am bipolar with a hard cock I grab a pencil and squeeze it The thoughts scary, ask me is What kind of shit would he take me to jump off a bridge Straight off the terrace, cut through the side of your building Take your wife, hit her with a brick While youre out eating breakfast Grab your meat and wont hesitate to give it to you Im a sight to see, so just in case you aint Biting a scene, you taking pictures of me with Jesus All you naysayers, whoa.. Half of you dont even know what terrorism is And half of you dont even know the names of my family My family, my wife, and all my friends All"
Tweet,"I am not your Superman, I am not your Superman I am not your Superman, I am not your Superman Crazy as I am, cant nobody tame me Im on fire, so youre just burning me I am not your Superman Cant anything compare to these blue and yellow purple pills I’m high like a giant, swimming in water I’m ready to give up bodies, so you can ride around in Cant stop me from killing you I am not your Superman Cant anything compare to these blue and yellow purple pills Hello, Good morning You got something on your schedule You can bring me words to break the ice Steve, Im sorry, I couldn’t find my schedule So we start six"
Tweet,"I am! Because I am-Because I am, am crazy! And Im enraged, its my fault I said this shit I never had a motherfucking assight Mom, please dont call me! Bitch, I dont take�em, I take the motto I am whatever you say Cause I am whatever you say Cause I am phenomenal Im just eastside, dont be a downer Im not under the bus, I am the top of the charts Time to show you who I am So sit you down, I ask you this You gotta take off, you gotta take off Mean things, dont be a punk I say things, my thoughts, my thoughts I am phenomenal But Im not throwin subs"
Tweet,"I am Yeah, thats it Good, I love you, girl Would you marry me, still want to be a man All my life I want a bitch, girl You can blow me, all day, baby You can have all the girls on your arm Round and round, round the world You can have a big ol mouth, girl I can make her wet, take a sip Wake up and turn it up, make her sad Put me in the position to make you sad So baby can dance Dancing with the colors, dancing Dancing with the colors, dancing Make you sad You cant take a bath or a plow You cant leave the house You aint got no security cause youre fucking"
Tweet,"I am! And if I am, then Im just a product of my environment If I am, then Im a product of my environment, huh My skin is like a Rubiks Cube If I have the world around me, I belong inside a ball of bubble Theres billions of Rubiks Cube inside of my head You can sit and argue with me, but it is my brain that works And Im like a cocaine Alto Cause I can show you who the best is, yes, I guess And I dont even wanna mingle, but I just wanna Just keep droppin new shit and keep hoping it get released But, Yaowa, I was born dead—eight years ago It all started when my mother gave me an A"
Tweet,"I am not hip-hop vet, I am not dealin with Regardless of who Im tryna rap with I bob up and down in a cypher, tryna get To the spot, yeah, you got your chance to move As you make your move to the forefront, huh, remind me Of when I was still the one they was thinkin about Now the nigga makin the news, tryna make me a guest spot Nigga, notice the stress in my voice, Im releasin weight I gotta cancel out the cring, Im the only one your boys can taste It is the critically acclaimed, lyrically inclined I reach back to the raps, I start drama All day long, these niggas is sp"
Tweet,"I am not Jasmine, I am Aladdin So come see me in your Dennys world Even though Im here to save you, girl به باهی کارایی ها انگشترو دانه بخره Now Shady, I understand youre all in it for me, pretty ئیست میکنم همه که شڵیجاه میکنم But I just want you to me, my darling مینو که انگشترم"
Tweet,"I amma stand up and put one of these fingers on it And make it stop, just so sick, and ugly to look at, and think of, and not accept it. I hope theyre doing this stupid of a fashion, but this is not healthy, the sickest shit I could ever take away from yo, I got to watch this player keep walkin the line, this is sick, the way I got people locked up for this song, sick, the way I got people locked up for this song. Alright, maybe I need a song to cut away The tears come in the wind, the water flows slow in the wind, and it is not an easy thing for me to be here but.. I.. I am not gonna cry... Because I"
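
To show what the `temperature`, `top_k`, and `top_p` settings in the Generate cell do, here is a simplified pure-Python sketch of sampling a single token. It is not the `transformers` implementation. `sample_next_token` and its arguments are hypothetical names, and the order of steps (temperature, then top-k, then top-p over the renormalised top-k distribution) is an assumption about how such filters are commonly chained.

```python
import math
import random


def sample_next_token(logits, temperature=1.0, top_k=50, top_p=0.95, rng=random):
    """Sample one token index from raw scores (illustrative sketch only)."""
    # Temperature rescales the scores: values < 1 sharpen, values > 1 flatten.
    scaled = [x / temperature for x in logits]

    # Top-k: keep only the k highest-scoring tokens (0 disables the filter).
    ranked = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        ranked = ranked[:top_k]

    # Softmax over the surviving tokens, subtracting the maximum for stability.
    peak = scaled[ranked[0]]
    weights = [math.exp(scaled[i] - peak) for i in ranked]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, kept_probs, cumulative = [], [], 0.0
    for i, p in zip(ranked, probs):
        kept.append(i)
        kept_probs.append(p)
        cumulative += p
        if cumulative >= top_p:
            break

    # random.choices normalises the weights, so no explicit renormalisation is needed.
    return rng.choices(kept, weights=kept_probs, k=1)[0]
```

In this sketch `temperature` must be greater than 0 because the scores are divided by it. Lower `top_k` or `top_p` values restrict sampling to the most likely tokens, which usually gives more conservative text.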


## About

*Built by Aleksey Korshuk*

[![Follow](https://img.shields.io/github/followers/AlekseyKorshuk?style=social)](https://github.com/AlekseyKorshuk)

🚀 If you want to contribute to this project or build something cool together, contact me: [link](https://github.com/AlekseyKorshuk)

For more details, visit the project repository:

[![GitHub stars](https://img.shields.io/github/stars/AlekseyKorshuk/huggingartists?style=social)](https://github.com/AlekseyKorshuk/huggingartists)

**Disclaimer: this project must not be used to publish false generated information or offensive text. It is intended for research on Natural Language Generation.**

## Resources
* Inspired by [HuggingTweets](https://github.com/borisdayma/huggingtweets)
* [Explore the W&B report](https://wandb.ai/huggingartists/huggingartists/reportlist) to understand how the model works
* [HuggingFace and W&B integration documentation](https://docs.wandb.com/library/integrations/huggingface)

## Got questions about W&B?
If you have any questions about using W&B to track your model performance and predictions, please reach out to the [Slack community](https://wb-forum.slack.com/signup#/).