<a href="https://colab.research.google.com/github/nuwandavek/you/blob/master/Training_You.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

First, connect to a GPU runtime: open Edit > Notebook settings and select GPU as the hardware accelerator. Then run the block below to install the libraries required to fine-tune DistilGPT2 on your WhatsApp history.

In [None]:
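#@title Install libraries
# Install transformers from source so the example scripts used below
# (examples/language-modeling/run_clm.py) match the installed library version.
# Later releases moved these scripts to examples/pytorch/language-modeling.
!git clone https://github.com/huggingface/transformers.git
!pip install ./transformers
!pip install -r ./transformers/examples/language-modeling/requirements.txt
!mkdir -p output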

## Upload files for training

To train the model on your chat history, first export each chat as a .txt file (instructions [here](https://faq.whatsapp.com/android/chats/how-to-save-your-chat-history/?lang=en)). Then run the following block and upload all of your .txt files.

In [None]:
#@title Upload WhatsApp history files
import random
import re
from itertools import chain

from google.colab import files

# Strip the "date, time - " prefix WhatsApp adds to each message. The pattern is
# anchored to the start of a line and non-greedy, so a " - " inside a message is kept.
TIMESTAMP_PATTERN = re.compile(rb'^\d+/\d+/\d+.*?- ', re.MULTILINE)

def RemoveTimestamps(text):
  return TIMESTAMP_PATTERN.sub(b'', text)

def UnicodeString(bytes_string):
  return bytes_string.decode('utf-8')

# Mark the end of every message with '#'.
def AddSeparators(file_text):
  return b'#\n'.join(file_text.split(b'\n'))

# Split a chat into chunks of CHUNK_LENGTH lines, each terminated by <|endoftext|>.
CHUNK_LENGTH = 500
def ChunkFile(file_text):
  lines = file_text.split(b'\n')
  chunks = []
  for line_index in range(0, len(lines), CHUNK_LENGTH):
    chunk = b'\n'.join(lines[line_index:line_index+CHUNK_LENGTH])
    chunk += b'<|endoftext|>'
    chunks.append(chunk)
  return chunks

# Pool the chunks from all chats and shuffle them together.
def MixChunks(chunked_files):
  all_chunks = list(chain.from_iterable(chunked_files))
  random.shuffle(all_chunks)
  return all_chunks

def ConvertChunksToString(chunks):
  return b'\n'.join(chunks)

def GetShuffledAndCleanedTextFromFiles(file_contents):
  file_chunks = []
  for file_content in file_contents:
    file_chunks.append(ChunkFile(AddSeparators(RemoveTimestamps(file_content))))
  return ConvertChunksToString(MixChunks(file_chunks))

# Print a random run of up to num_lines consecutive lines from a text file.
def SampleTextFromFile(file, num_lines=50):
  with open(file, encoding='utf-8') as f:
    file_contents = f.readlines()
  begin = random.randint(0, max(0, len(file_contents) - num_lines))
  for line in file_contents[begin:begin+num_lines]:
    print(line, end='')

# files.upload() returns a dict mapping each uploaded file name to its bytes.
uploaded_files = files.upload()
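
To see what the helpers above do to an export, here is a self-contained sketch that applies the same three steps (strip timestamps, append `#` to each message, cut into chunks ending in `<|endoftext|>`) to a tiny in-memory sample. The sample assumes the Android export format `M/D/YY, H:MM AM - Name: message`; other phones or locales may format the timestamp differently. `clean_export` is a hypothetical stand-in, not one of the notebook's functions, and its timestamp pattern is anchored to the start of a line and non-greedy so that a ` - ` inside a message survives.

```python
import re

# Assumed export format: "M/D/YY, H:MM AM - Name: message" (varies by platform/locale).
TIMESTAMP = re.compile(rb'^\d+/\d+/\d+.*?- ', re.MULTILINE)

def clean_export(raw: bytes, chunk_length: int = 500) -> list:
    """Hypothetical helper mirroring the notebook's cleaning steps."""
    text = TIMESTAMP.sub(b'', raw)              # 1. drop "date, time - " prefixes
    text = b'#\n'.join(text.split(b'\n'))       # 2. mark the end of each message with '#'
    lines = text.split(b'\n')
    chunks = []
    for start in range(0, len(lines), chunk_length):   # 3. fixed-size chunks
        chunks.append(b'\n'.join(lines[start:start + chunk_length]) + b'<|endoftext|>')
    return chunks

sample = (b'1/4/21, 9:15 AM - Vivek: Free for a call?\n'
          b'1/4/21, 9:16 AM - Sreejith2: Yo in 5 mins\n'
          b'1/4/21, 9:20 AM - Vivek: Ping me')
for chunk in clean_export(sample, chunk_length=2):
    print(chunk.decode('utf-8'))
```

Note that the last message of each file gets no `#`, because the separator is only inserted between lines.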

## Construct training data

Next, we clean the data and split it into a training set (the first 90% of lines) and a test set (the last 10%).

In [None]:
#@title Clean data and create train and test splits.
cleaned_text = GetShuffledAndCleanedTextFromFiles(uploaded_files.values())
with open('data.txt', 'wb') as data_file:
  data_file.write(cleaned_text)
# The text has no trailing newline, so n newlines means n + 1 lines.
num_lines = cleaned_text.count(b'\n') + 1
test_size = int(0.1 * num_lines)
train_size = num_lines - test_size
# First train_size lines for training, last test_size lines for evaluation.
!head -n {train_size} data.txt > train.txt
!tail -n {test_size} data.txt > test.txt
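
The split keeps the last 10% of lines for evaluation. A pure-Python sketch of the same arithmetic (`split_lines` is a hypothetical helper, not part of the notebook) makes it easy to check that every line lands in exactly one file; the key detail is that a byte string with n newlines and no trailing newline has n + 1 lines.

```python
def split_lines(text: bytes, test_fraction: float = 0.1):
    """Split text into (train, test) by line, keeping the last test_fraction for test."""
    lines = text.split(b'\n')                  # n newlines -> n + 1 lines
    test_size = int(test_fraction * len(lines))
    train_size = len(lines) - test_size
    train = b'\n'.join(lines[:train_size])     # like `head -n train_size`
    test = b'\n'.join(lines[train_size:])      # like `tail -n test_size`
    return train, test

data = b'\n'.join(b'line %d#' % i for i in range(20))
train, test = split_lines(data)
print(train.count(b'\n') + 1, test.count(b'\n') + 1)
```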

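We can print a random run of lines from the training file to inspect it. Note that a `#` has been appended to the end of each message, and an `<|endoftext|>` token marks the end of each 500-line chunk; because chunks from all chats are shuffled together, the text on either side of `<|endoftext|>` may come from different chats.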
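
Because messages end in `#` and chunks end in `<|endoftext|>`, the training text can be turned back into structured messages for inspection. `parse_chunks` below is a hypothetical sketch, assuming each message line starts with `Name: `; lines without that prefix (for example the continuation lines of a multi-line message) are kept with no speaker.

```python
def parse_chunks(text: str):
    """Split training text into chunks of (speaker, message) pairs (hypothetical helper)."""
    chunks = []
    for raw_chunk in text.split('<|endoftext|>'):
        messages = []
        for line in raw_chunk.split('\n'):
            line = line.strip()
            if line.endswith('#'):
                line = line[:-1]               # drop the end-of-message marker
            if not line:
                continue
            speaker, sep, message = line.partition(': ')
            # Lines without a "Name: " prefix get no speaker.
            messages.append((speaker, message) if sep else (None, line))
        if messages:
            chunks.append(messages)
    return chunks

sample = ("Vivek: Got that?#\nSreejith2: Hey got#\nVivek: Yoyoyo#<|endoftext|>\n"
          "Himaya: hope I have a good day#")
for chunk in parse_chunks(sample):
    print(chunk)
```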

In [None]:
SampleTextFromFile('train.txt')

Vivek: Can transfer after one hour of adding#
Vivek: What verification thing?#
Sreejith2: The bank account should have 40 lakhs thing#
Sreejith2: Keep it and transfer after no?#
Vivek: Yoyo all that is over \m/#
Sreejith2: Wooh!#
Sreejith2: Peace peace#
Vivek: That was required before visa#
Sreejith2: Transfer off then#
Vivek: Now peacemax#
Vivek: 😅#
Vivek: Yoyoyo#
Sreejith2: Hahaha nice nice!#
Vivek: What plans today?#
Vivek: Free for a call?#
Sreejith2: Hey, no plans as such#
Sreejith2: Yo in 5 mins#
Vivek: Yoyo#
Vivek: Ping me#
Sreejith2: Haan#
Vivek: Eyo#
Vivek: I sent 1000 rs#
Vivek: Got that?#
Vivek: Once you confirm I'll transfer the rest#
Sreejith2: Hey got#
Sreejith2: Got 1k#
Vivek: Yoyoyo#<|endoftext|>
Himaya: hope I have a good day#
Mihir London: https://player.vimeo.com/video/427943452#
Mihir London: Wtf#
Sreejith2: Wow that's amazing 🤯#
Rishi Amreeka: https://youtu.be/fZSFNUT6iY8#
Rishi Amreeka: Have you guys seen this one?#
Rishi Amreeka: Pretty insane#
Sreejith2: Wow, wt

## Train model

Next, we fine-tune DistilGPT2 on our training data. Depending on how much chat history you uploaded, this can take 5 to 30 minutes.
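
`run_clm.py` does not train on chat lines one at a time: after tokenizing, its `group_texts` step concatenates all token ids, cuts them into blocks of `--block_size` tokens (256 in the command below), drops the leftover tail, and uses each block as its own label sequence. The sketch below reproduces that idea on plain Python lists; it is a simplified stand-in, not the script's code. It also estimates the number of optimizer steps, assuming one block per step (batch size 1, no gradient accumulation) and the script's default of 3 epochs.

```python
import math

def group_texts(token_lists, block_size=256):
    """Concatenate tokenized lines and cut them into equal blocks (remainder dropped)."""
    concatenated = [tok for tokens in token_lists for tok in tokens]
    total_length = (len(concatenated) // block_size) * block_size
    blocks = [concatenated[i:i + block_size] for i in range(0, total_length, block_size)]
    # For causal LM the labels are the inputs; the model shifts them by one internally.
    return [{'input_ids': b, 'labels': list(b)} for b in blocks]

def estimate_steps(num_blocks, batch_size=1, epochs=3):
    """Optimizer steps for a run without gradient accumulation."""
    return epochs * math.ceil(num_blocks / batch_size)

# Toy example: 10 "lines" of 60 token ids each -> 600 ids -> 2 full blocks of 256.
toy = [list(range(i * 60, (i + 1) * 60)) for i in range(10)]
blocks = group_texts(toy)
print(len(blocks), len(blocks[0]['input_ids']), estimate_steps(len(blocks)))
```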

In [None]:
!python ./transformers/examples/language-modeling/run_clm.py --model_name_or_path distilgpt2 --train_file train.txt --validation_file test.txt --do_train --do_eval --output_dir ./output --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --save_steps 800 --eval_steps 800 --logging_steps 800 --evaluation_strategy steps --overwrite_output_dir --block_size 256

2021-01-05 07:35:25.320305: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
01/05/2021 07:35:27 - INFO - __main__ -   Training/evaluation parameters TrainingArguments(output_dir=./output, overwrite_output_dir=True, do_train=True, do_eval=True, do_predict=False, model_parallel=False, evaluation_strategy=EvaluationStrategy.STEPS, prediction_loss_only=False, per_device_train_batch_size=8, per_device_eval_batch_size=8, gradient_accumulation_steps=1, eval_accumulation_steps=None, learning_rate=5e-05, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0, num_train_epochs=3.0, max_steps=-1, lr_scheduler_type=SchedulerType.LINEAR, warmup_steps=0, logging_dir=runs/Jan05_07-35-27_6201cf33901f, logging_first_step=False, logging_steps=800, save_steps=800, save_total_limit=None, no_cuda=False, seed=42, fp16=False, fp16_opt_level=O1, local_rank=-1, tpu_num_cores=None, tpu_metrics_debug=False, d
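
With `--do_eval`, the evaluation loss is the mean per-token cross-entropy in nats, and `run_clm.py` also reports perplexity, which is simply exp(loss). A quick sanity check on the formula: a model that is uniformly unsure among k tokens has loss ln(k), so its perplexity is k. The `perplexity` helper is illustrative only.

```python
import math

def perplexity(mean_loss_nats: float) -> float:
    """Perplexity from a mean per-token cross-entropy measured in nats."""
    return math.exp(mean_loss_nats)

# Uniform guessing over the whole GPT-2 vocabulary (50257 tokens) has loss ln(50257).
k = 50257
print(perplexity(math.log(k)))
```

Lower is better; a fine-tuned model should land far below the uniform-guessing value.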

## Play with model

In [None]:
from transformers import pipeline

In [None]:
ft_generator = pipeline('text-generation', model='./output')

In [None]:
def PrettyPrintPrediction(text):
  # Print a blank line, then the sample with one message per line ('#' ends a message).
  print()
  text = text.replace('#', '\n')
  print(text)

In [None]:
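# Quick check that the model loads and generates; replace "Vivek:" with a name from your own chats.
ft_generator("Vivek:")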

In [None]:
for text in ft_generator("Vivek: Mihir sucks #Sreejith2: I agree! Tell me more#Vivek: Dude he always makes fun of me#Vivek:", max_length=256, num_return_sequences=3):
  PrettyPrintPrediction(text['generated_text'])
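
The prompt uses the same layout as the training data: `Name: message` turns separated by `#`, ending with the name of the next speaker so the model writes that person's message. A small hypothetical helper makes such prompts less error-prone to type:

```python
def build_prompt(turns, next_speaker, separator='#'):
    """Format (speaker, message) pairs like the training data, ending with an open turn."""
    history = separator.join(f"{speaker}: {message}" for speaker, message in turns)
    return f"{history}{separator}{next_speaker}:" if history else f"{next_speaker}:"

prompt = build_prompt(
    [("Vivek", "Mihir sucks :("),
     ("Sreejith2", "I agree! Tell me more"),
     ("Vivek", "Dude he always makes fun of me")],
    next_speaker="Vivek",
)
print(prompt)
```

This produces exactly the prompt string used in the *Load a saved model* section.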

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Vivek: Mihir sucks \m/
Sreejith2: I agree! Tell me more
Vivek: Dude he always makes fun of me
Vivek: 🤣
Sreejith2: Hey thanks man. There is just one guy in the building who says "hey fuck you bitch"
Sreejith2: https://youtu.be/z6YQtJd8sq
Vivek: He has more jokes than Hitler
Vivek: 🤣
Sreejith2: Oho
Sreejith2: Hey, all peace is here man
Vivek: Peace, will probably find peace here soon
Sreejith2: I think only if you actually feel safe.
Vivek: Yup.
Sreejith2: Can stay in your car next morning
Vivek: There?
Sreejith2: Come to the police station
Sreejith2: Wassup man
Vivek: Hey!
Sreejith2: Whose name you're working on?
Vivek: I want to go to your place
Sreejith2: What time

Vivek: Mihir sucks \m/
Sreejith2: I agree! Tell me more
Vivek: Dude he always makes fun of me
Vivek: He's the same person
Vivek: 🙄
Sreejith2: Hey how did you get your visa?
Vivek: To the US?
Sreejith2: Hahaha
Vivek: I wanted to change the visa
Vivek: I decided to apply for it
Sreejith2: In Bangalore for the last 2 days
Vi

## Download model

Download the model so you can use it with the Chrome extension.

In [None]:
!zip model.zip ./output/*

  adding: output/checkpoint-1600/ (stored 0%)
  adding: output/checkpoint-2400/ (stored 0%)
  adding: output/checkpoint-3200/ (stored 0%)
  adding: output/checkpoint-4000/ (stored 0%)
  adding: output/checkpoint-800/ (stored 0%)
  adding: output/config.json (deflated 51%)
  adding: output/eval_results_clm.txt (stored 0%)
  adding: output/merges.txt (deflated 53%)
  adding: output/pytorch_model.bin (deflated 9%)
  adding: output/special_tokens_map.json (deflated 52%)
  adding: output/tokenizer_config.json (deflated 38%)
  adding: output/trainer_state.json (deflated 70%)
  adding: output/training_args.bin (deflated 46%)
  adding: output/train_results.txt (deflated 10%)
  adding: output/vocab.json (deflated 59%)
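
Because `zip` is run without `-r`, the `checkpoint-*` folders go into the archive only as empty directory entries, and just the top-level files (weights, config, tokenizer files) are packed. If you would rather build the archive from Python, here is a sketch with the standard `zipfile` module that does the same thing explicitly; `zip_final_model` is a hypothetical name.

```python
import os
import zipfile

def zip_final_model(output_dir: str, zip_path: str) -> list:
    """Zip the regular files directly inside output_dir, skipping sub-folders such as checkpoint-*."""
    folder = os.path.basename(os.path.normpath(output_dir))
    added = []
    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
        for name in sorted(os.listdir(output_dir)):
            path = os.path.join(output_dir, name)
            if os.path.isfile(path):
                # Store as <folder>/<name> so the archive unpacks into a single folder.
                archive.write(path, arcname=f"{folder}/{name}")
                added.append(name)
    return added

# In the notebook this would be: zip_final_model('./output', 'model.zip')
```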


In [None]:
ls -l

total 302268
-rw-r--r--  1 root root   1276333 Jan  4 07:30  data.txt
-rw-r--r--  1 root root 305021348 Jan  4 07:58  model.zip
drwxr-xr-x  7 root root      4096 Jan  4 07:38  output/
drwxr-xr-x  3 root root      4096 Jan  4 07:31  runs/
drwxr-xr-x  1 root root      4096 Dec 21 17:29  sample_data/
-rw-r--r--  1 root root    127482 Jan  4 07:30  test.txt
-rw-r--r--  1 root root   1148810 Jan  4 07:30  train.txt
drwxr-xr-x 15 root root      4096 Jan  4 07:28  transformers/
-rw-r--r--  1 root root    188024 Jan  4 07:29 'WhatsApp Chat with 5 Years Time 🌞.txt'
-rw-r--r--  1 root root     96072 Jan  4 07:29 'WhatsApp Chat with Mihir London.txt'
-rw-r--r--  1 root root    493383 Jan  4 07:30 'WhatsApp Chat with Rishi Amreeka.txt'
-rw-r--r--  1 root root    271150 Jan  4 07:29 'WhatsApp Chat with Sreejith2.txt'
-rw-r--r--  1 root root    351486 Jan  4 07:29 'WhatsApp Chat with Sreejith.txt'
-rw-r--r--  1 root root    509144 Jan  4 07:30 'Wha

In [None]:
files.download('model.zip')


## Load a saved model

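Use this section to experiment with a model you downloaded earlier. First unzip `model.zip` on your machine; because the model path below points to a local directory, you will need to connect Colab to a locally running Jupyter runtime.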
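
The downloaded archive unpacks to an `output/` folder. On the local machine, a sketch like the following (standard-library `zipfile`; `extract_model` is a hypothetical helper) extracts it and returns the directory to pass as `model=` to `pipeline`, locating it by the `config.json` file the archive contains.

```python
import os
import zipfile

def extract_model(zip_path: str, dest_dir: str) -> str:
    """Extract model.zip and return the folder that contains config.json."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest_dir)
    for root, _dirs, files in os.walk(dest_dir):
        if 'config.json' in files:
            return root
    raise FileNotFoundError(f"no config.json found under {dest_dir}")

# Example (paths are placeholders): model_dir = extract_model('model.zip', 'model')
```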

In [None]:
from transformers import pipeline

In [None]:
# Path to the unzipped model directory on your local machine.
ft_generator = pipeline('text-generation', model='../../Downloads/output_2')

In [None]:
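prompt = "Vivek: Mihir sucks :(#Sreejith2: I agree! Tell me more#Vivek: Dude he always makes fun of me#Vivek:"
tokenizer = ft_generator.tokenizer
for text in ft_generator(prompt, max_length=100, num_return_sequences=3,
                         do_sample=True, top_k=50, top_p=0.95,
                         # '#' (token id 2 in the GPT-2 vocabulary) ends a message, so stop there.
                         eos_token_id=tokenizer.convert_tokens_to_ids('#'),
                         # Pad with <|endoftext|>, which is dropped when decoding; id 0 would be '!'.
                         pad_token_id=tokenizer.eos_token_id):
  PrettyPrintPrediction(text['generated_text'])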
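
The pipeline returns the prompt followed by the continuation. To keep only the newly generated message, strip the prompt and cut at the first `#`, the end-of-message marker used in training; this also discards anything after the marker, such as padding. `extract_reply` is a hypothetical helper for this:

```python
def extract_reply(generated_text: str, prompt: str, separator: str = '#') -> str:
    """Return the first message the model wrote after the prompt, without the separator."""
    if generated_text.startswith(prompt):
        continuation = generated_text[len(prompt):]
    else:
        continuation = generated_text
    reply, _sep, _rest = continuation.partition(separator)
    return reply.strip()

prompt = "Vivek: Dude he always makes fun of me#Vivek:"
print(extract_reply(prompt + " And what did you mean?#Sreejith2: Hmm", prompt))
```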


Vivek: Mihir sucks :(
Sreejith2: I agree! Tell me more
Vivek: Dude he always makes fun of me
Vivek: I just wanted to know if there is any one thing that he does in life for me no?


Vivek: Mihir sucks :(
Sreejith2: I agree! Tell me more
Vivek: Dude he always makes fun of me
Vivek: And what did you mean?
!!!!!!!!!!!!!!

Vivek: Mihir sucks :(
Sreejith2: I agree! Tell me more
Vivek: Dude he always makes fun of me
Vivek: That was a big problem for me when I was younger 🤣
!!!!!!
