## Fine Tune Language Model

Since our questions come from the Stack Exchange network, it's probably useful for us to fine-tune a model against all of the questions and answers on this site.

I've never done this before, so let's try to figure it out.

In [1]:
import os
import glob

import pandas as pd
import xml.etree.ElementTree as ET

from lxml import etree
from tqdm import tqdm
from os import listdir
from pathlib import Path
from bs4 import BeautifulSoup

import cufflinks as cf
cf.go_offline()

In [2]:
stackexchange = Path('data/stackexchange')
all_sites_directories = next(os.walk(stackexchange))[1]
len(all_sites_directories)

312

In [13]:
# Take each site and write its raw text to a .txt file
def save_site_text_as_txt_file(directory:Path):
    
    output_file_path = stackexchange / (directory.name + ".txt")
    if os.path.exists(output_file_path):
        print("Already done: {}".format(output_file_path))
        return
    with open(output_file_path, 'w') as out_file:
        
        # If the Posts-short.xml file exists, use that.      
        full_path = directory/"Posts-short.xml"
        if not os.path.exists(full_path):
            full_path = directory/"Posts.xml"
            
        tree = ET.parse(full_path)
        root = tree.getroot()
        for child in root.getchildren():
            # Read HTML post with BeautifulSoup, strip tags
            soup = BeautifulSoup(child.attrib['Body'])
            text = soup.get_text()

            try:
                out_file.write(text)
                out_file.write('\n')
            except:
                print("Error, skipping post.")


In [14]:
# Get .txt file for every site we've downloaded
for i, directory in enumerate(all_sites_directories):

    # We have to handle this one manually since it's 76 GB :(
    if directory == 'stackoverflow.com-Posts':
        print("Skipping posts for: {}".format(directory))
        continue
        
    print("{} Working on: {}".format(i,directory))
    relative_path = stackexchange/directory
    save_site_text_as_txt_file(relative_path)
    print("{} Completed: {}".format(i,directory))

0 Working on: earthscience.meta.stackexchange.com
Already done: data/stackexchange/earthscience.meta.stackexchange.com.txt
0 Completed: earthscience.meta.stackexchange.com
1 Working on: monero.stackexchange.com
Already done: data/stackexchange/monero.stackexchange.com.txt
1 Completed: monero.stackexchange.com
2 Working on: christianity.stackexchange.com
Already done: data/stackexchange/christianity.stackexchange.com.txt
2 Completed: christianity.stackexchange.com
3 Working on: ebooks.stackexchange.com
Already done: data/stackexchange/ebooks.stackexchange.com.txt
3 Completed: ebooks.stackexchange.com
4 Working on: vegetarianism.meta.stackexchange.com
Already done: data/stackexchange/vegetarianism.meta.stackexchange.com.txt
4 Completed: vegetarianism.meta.stackexchange.com
5 Working on: space.meta.stackexchange.com
Already done: data/stackexchange/space.meta.stackexchange.com.txt
5 Completed: space.meta.stackexchange.com
6 Working on: workplace.stackexchange.com
Already done: data/stacke


This method will be removed in future versions.  Use 'list(elem)' or iteration over elem instead.



261 Completed: math.stackexchange.com
262 Working on: skeptics.stackexchange.com
Already done: data/stackexchange/skeptics.stackexchange.com.txt
262 Completed: skeptics.stackexchange.com
263 Working on: tor.meta.stackexchange.com
Already done: data/stackexchange/tor.meta.stackexchange.com.txt
263 Completed: tor.meta.stackexchange.com
264 Working on: crafts.stackexchange.com
Already done: data/stackexchange/crafts.stackexchange.com.txt
264 Completed: crafts.stackexchange.com
265 Working on: aviation.meta.stackexchange.com
Already done: data/stackexchange/aviation.meta.stackexchange.com.txt
265 Completed: aviation.meta.stackexchange.com
266 Working on: devops.stackexchange.com
Already done: data/stackexchange/devops.stackexchange.com.txt
266 Completed: devops.stackexchange.com
267 Working on: softwarerecs.meta.stackexchange.com
Already done: data/stackexchange/softwarerecs.meta.stackexchange.com.txt
267 Completed: softwarerecs.meta.stackexchange.com
268 Working on: gis.stackexchange.com


So we've got a problem with StackOverflow. StackOverflow is ~76GB which is crashing when we try to load it into memory. 

In [7]:
so_path = stackexchange/'stackoverflow.com-Posts/Posts-short.xml'

In [10]:
def fast_iter(context, func, *args, **kwargs):
    """
    http://lxml.de/parsing.html#modifying-the-tree
    Based on Liza Daly's fast_iter
    http://www.ibm.com/developerworks/xml/library/x-hiperfparse/
    See also http://effbot.org/zone/element-iterparse.htm
    """
    for i, (event, elem) in enumerate(context):
        func(elem, *args, **kwargs)
        # It's safe to call clear() here because no descendants will be
        # accessed
        elem.clear()
        # Also eliminate now-empty references from the root node to elem
        for ancestor in elem.xpath('ancestor-or-self::*'):
            while ancestor.getprevious() is not None:
                del ancestor.getparent()[0]
             
    del context


def process_element(elem, out_file):
    soup = BeautifulSoup(elem.attrib['Body'])
    text = soup.get_text()
    
    # write it to the file
    try:
        out_file.write(text)
        out_file.write('\n')        
    except:
        print("Error, skipping post.")

    
# Open up the XML file for all posts
with open(so_path, 'rb') as xml_file:
    context = etree.iterparse(xml_file, tag='row')
    
    output_file_path = stackexchange / 'stackoverflow.com-Posts.txt'
    with open(output_file_path, 'w') as out_file:
        fast_iter(context, process_element, out_file)        



"https://github.com/OpenOLAT/OpenOLAT" looks like a URL. Beautiful Soup is not an HTTP client. You should probably use an HTTP client like requests to get the document behind the URL, and feed that document to Beautiful Soup.



## Compare to Train and Test Distributions

We may have different quantities of text for each website than our train and test distributions do. We should make some effort to get the magnitudes relatively close (eg. We don't want 75% of our training data to be from StackOverflow)

In [15]:
data = Path('data')
train = data/'train.csv'
test = data/'test.csv'

train_df = pd.read_csv(train)
test_df = pd.read_csv(test)

In [16]:
temp = train_df.loc[train_df['host'] != 'stackoverflow.com']["host"].value_counts()
df = pd.DataFrame({'labels': temp.index,
                   'values': temp.values
                  })
df.iplot(kind='pie',labels='labels',values='values', title='Distribution of hosts')

In [17]:
temp = test_df.loc[test_df['host'] != 'stackoverflow.com']["host"].value_counts()
df = pd.DataFrame({'labels': temp.index,
                   'values': temp.values
                  })
df.iplot(kind='pie',labels='labels',values='values', title='Distribution of hosts')

### Create Training Set

In [18]:
os.getcwd()

'/home/josh/git/GoogleQALabeling'

In [19]:
text_files = glob.glob('data/stackexchange/*.txt')

In [20]:
file_sizes = [(Path(path).name.strip('.txt'), os.path.getsize(path)) for path in text_files] 

In [21]:
df = pd.DataFrame(file_sizes, columns=['labels', 'values'])

In [26]:
df = df.loc[df['labels'] != 'stackoverflow.com-Posts']

In [27]:
df.iplot(kind='pie',labels='labels',values='values', title='Distribution of hosts in Training data')

The differences I note:
    - math.stackexchange.com is over-represented
    - Most other categories seem within a magnitude of their proper distribution
    - There is a long tail of small sites
    

Let's remove a bunch of math.stackexchange.com. We'd like to reduce it from ~2GB to about 700MB.

In [70]:
!wc -l data/stackexchange/math.stackexchange.com.txt

0 data/stackexchange/math.stackexchange.com.txt


In [69]:
# We'd like about 1/3 of the file size
26866122 * (1/3)

8955374.0

In [68]:
# Move to intermediate file while copying
!tail -n 8955374 data/stackexchange/math.stackexchange.com.txt > data/stackexchange/math.stackexchange.com.txt.tmp

In [None]:
# Overwrite original
!mv data/stackexchange/math.stackexchange.com.txt.tmp data/stackexchange/math.stackexchange.com.txt