This short notebook loads a pretrained word2vec (w2v) model and continues training it on general GitHub issue data. The vocabulary is built and the model is initialized in the `src/data/build_w2v_vocab` notebook.

In [32]:
from google.cloud import bigquery
from google.oauth2 import service_account
import re
import datetime
import numpy as np
import os
from nltk.tokenize import word_tokenize
import nltk
from gensim.models import Word2Vec
from dotenv import find_dotenv, load_dotenv
import boto3
import sys

# Make the shared preprocessing helpers in ../src/data importable from this notebook.
vocab_path = "../src/data"
if vocab_path not in sys.path:
    sys.path.insert(1, vocab_path)

from w2v_preprocess import remove_quotes, is_bot, is_english, preprocess, is_punc # noqa

# word_tokenize relies on NLTK's Punkt tokenizer data. NLTK >= 3.8.2 loads it
# from the ``punkt_tab`` resource instead of the older pickled ``punkt`` one,
# so fetch both to keep tokenization working across NLTK versions.
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to /home/atersaak/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [15]:
# Populate os.environ from the nearest .env file (credentials, bucket names, ...).
load_dotenv(dotenv_path=find_dotenv())

True

In [16]:
# Storage switch: True keeps the model and the progress file on ceph
# (S3-compatible object storage); False keeps everything on the local disk.
use_ceph = True

if use_ceph:
    # Connection settings come from environment variables.
    s3_endpoint_url = os.environ["OBJECT_STORAGE_ENDPOINT_URL"]
    s3_access_key = os.environ["AWS_ACCESS_KEY_ID"]
    s3_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
    s3_bucket = os.environ["OBJECT_STORAGE_BUCKET_NAME"]

    # Low-level S3 client pointed at the ceph endpoint.
    s3 = boto3.client(
        "s3",
        endpoint_url=s3_endpoint_url,
        aws_access_key_id=s3_access_key,
        aws_secret_access_key=s3_secret_key,
    )

Below we set up the BigQuery client and copy over a couple of functions from the vocab-building notebook that could not easily be imported.

In [17]:
# The service-account key .json belongs in the github-labeler root, which is
# reached here as ``../`` relative to the working directory. The key's project
# must match ``project_id`` below.
credentials = service_account.Credentials.from_service_account_file(
    '../github-issue-data-extraction-key.json'
)

project_id = 'github-issue-data-extraction'
client = bigquery.Client(project=project_id, credentials=credentials)

In [18]:
def get_data_for_day(day):
    """Query the GitHub Archive for every issue opened on ``day``.

    ``day`` only needs a ``strftime`` method (a ``date`` or ``datetime``); it
    selects the ``githubarchive.day.YYYYMMDD`` table. The result is returned
    as a DataFrame with ``title``, ``body``, ``url`` and ``actor`` columns that
    hold the raw ``JSON_EXTRACT`` output (cleaned later by ``process_df``).
    """
    date = day.strftime('%Y%m%d')
    return client.query(f"""SELECT JSON_EXTRACT(payload, '$.issue.title') as title,
                                JSON_EXTRACT(payload, '$.issue.body') as body,
                                JSON_EXTRACT(payload, '$.issue.html_url') as url,
                                JSON_EXTRACT(payload, '$.issue.user.login') as actor
                                FROM githubarchive.day.{date}
                                WHERE type = 'IssuesEvent' AND JSON_EXTRACT(payload, '$.action') = '"opened"'
                                """).to_dataframe()


def process_df(df):
    """Strip JSON quoting from every column and drop issues opened by bots.

    Parameters
    ----------
    df : pandas.DataFrame
        Output of ``get_data_for_day`` (``title``, ``body``, ``url``, ``actor``).
        It is not modified.

    Returns
    -------
    pandas.DataFrame
        A cleaned copy without the rows whose author ``is_bot`` flags.
    """
    # Work on a copy so column assignment never targets a filtered view
    # (SettingWithCopyWarning) and the caller's frame is left untouched.
    cleaned = df.copy()
    for col in cleaned.columns:
        cleaned[col] = cleaned[col].apply(remove_quotes)
    if len(cleaned.columns) == 0:
        return cleaned
    # Bot detection concerns who opened the issue, so only the author column is
    # tested. Running it on title/body/url would drop human issues that merely
    # mention a bot. Fall back to the last column, which the old post-loop
    # check used.
    bot_col = 'actor' if 'actor' in cleaned.columns else cleaned.columns[-1]
    is_bot_row = cleaned[bot_col].apply(is_bot).astype(bool)
    return cleaned[~is_bot_row]

First we load the previously saved model. It is saved as multiple files, so we must fetch all of them.

In [20]:
# Collect the file names of every object stored under the w2v prefix on ceph.
# When working locally there is nothing to list, so ``keys`` stays empty.
pattern = re.compile('github-labeler/w2v/.*')

keys = []

if use_ceph:
    buck = boto3.resource(
        service_name="s3",
        aws_access_key_id=s3_access_key,
        aws_secret_access_key=s3_secret_key,
        endpoint_url=s3_endpoint_url,
    )
    # Let the server filter by prefix instead of listing the whole bucket;
    # the regex is still applied so the selection rule is unchanged.
    for obj in buck.Bucket(s3_bucket).objects.filter(Prefix='github-labeler/w2v/'):
        if pattern.match(obj.key):
            name = os.path.basename(obj.key)
            # A bare "directory" key ending in '/' has an empty basename and
            # cannot be downloaded to a file, so skip it.
            if name:
                keys.append(name)

In [8]:
# Fetch every model file listed above into ../models, then load the model.
# All files are needed because the model is saved as multiple files.
if use_ceph:
    # The target directory may not exist on a fresh checkout.
    os.makedirs('../models', exist_ok=True)
    for key in keys:
        # download_file streams the object straight to disk, replacing the
        # manual read/write loop over the response body.
        s3.download_file(
            Bucket=s3_bucket,
            Key=f"github-labeler/w2v/{key}",
            Filename=f'../models/{key}',
        )

w = Word2Vec.load('../models/w2v.model')

Next we define helper functions: one that maps out-of-vocabulary words to `_unknown_`, one that saves the w2v model, and one that trains the model on a given day's issues.

In [21]:
def in_set(word):
    """Return ``word`` if the global model ``w`` knows it, else ``'_unknown_'``."""
    return word if word in w.wv else '_unknown_'

In [22]:
def save_w2v(w):
    """Save the w2v model to ``../models`` and, if ``use_ceph`` is set, move it to ceph.

    With ``use_ceph`` enabled, every local file belonging to the model is
    uploaded under ``github-labeler/w2v/`` and then deleted locally, so ceph
    holds the only copy afterwards. Otherwise the model stays in ``../models``.

    Parameters
    ----------
    w : gensim.models.Word2Vec
        The model to persist (anything with a ``save(path)`` method works).

    Returns
    -------
    bool
        Always ``True``.
    """
    w.save('../models/w2v.model')
    if use_ceph:
        for file in sorted(os.listdir('../models/')):
            # Only the model's own files: ``w2v.model`` and any extra files
            # saved alongside it, presumably named ``w2v.model.*``. A plain
            # substring test would also match unrelated files such as
            # ``old_w2v.model.bak`` and then delete them.
            if not (file == 'w2v.model' or file.startswith('w2v.model.')):
                continue
            s3.upload_file(
                Bucket=s3_bucket,
                Key=f"github-labeler/w2v/{file}",
                Filename=f'../models/{file}',
            )
            os.remove(f'../models/{file}')
    return True

In [23]:
def train_w2v_on_day(w, day):
    """Continue training ``w`` for one epoch on the issues opened on ``day``.

    Each issue becomes one document, built from ``title + ' SEP ' + body``:
    the text is run through ``preprocess``, kept only if ``is_english``,
    lower-cased and tokenized with ``word_tokenize``. Punctuation tokens
    (``is_punc``) are dropped, and words missing from ``w.wv`` are replaced by
    ``'_unknown_'``.

    Parameters
    ----------
    w : gensim.models.Word2Vec
        Model updated in place.
    day : datetime.date or datetime.datetime
        Day whose issues are fetched with ``get_data_for_day``.
    """
    issues = process_df(get_data_for_day(day))
    text = issues['title'].fillna(' ') + ' SEP ' + issues['body'].fillna(' ')
    text = text.apply(preprocess)
    # astype(bool) keeps the mask boolean even when the Series is empty.
    text = text[text.apply(is_english).astype(bool)]
    text = text.apply(lambda x: x.lower())
    corpus = [
        # Look words up in the model being trained (the ``w`` argument), not in
        # whatever global ``w`` the ``in_set`` helper happens to see.
        [word if word in w.wv else '_unknown_' for word in tokens if not is_punc(word)]
        for tokens in text.apply(word_tokenize)
    ]
    if not corpus:
        # No usable issues for this day; skip the pointless training call.
        return
    w.train(corpus, total_examples=len(corpus), epochs=1)

We train the model on days drawn at random from a fixed interval. The process will likely be interrupted before it finishes, especially when run locally. We therefore periodically save the list of days that remain to be trained on, so a later run can resume where it stopped. The cell below saves and retrieves that list either locally or from ceph.

In [24]:
# Work out which days are still left to train on. The list is persisted as
# ``w2v_dates.txt`` (one YYYY-MM-DD per line): on ceph when ``use_ceph`` is
# set, otherwise in the working directory.

dates_exist = True

if use_ceph:
    try:
        # Bind the response here. Reusing ``response`` from an earlier cell would
        # write the wrong object's bytes into the dates file.
        response = s3.get_object(
            Bucket=s3_bucket,
            Key="github-labeler/w2v_dates.txt",
        )
        with open('w2v_dates.txt', 'wb') as f:
            f.write(response['Body'].read())
    except s3.exceptions.NoSuchKey:
        dates_exist = False

else:
    if not os.path.isfile('w2v_dates.txt'):
        dates_exist = False

if dates_exist:
    # Resume from the saved list.
    with open('w2v_dates.txt', 'r') as f:
        # Strip before filtering so blank lines (e.g. a trailing newline) are skipped.
        dates = [line.strip() for line in f]
    # ``.date()`` keeps the type consistent with the fresh-start branch below.
    dates = [datetime.datetime.strptime(d, '%Y-%m-%d').date() for d in dates if d]
    # The training loop consumes ``all_days``, so the resumed list must land there.
    all_days = dates

else:
    # Fresh start: begin 10 days ago and go back 2 years.
    interval_end = datetime.datetime.today().date() - datetime.timedelta(days = 10)
    days = 365*2
    all_days = [interval_end - datetime.timedelta(days = num) for num in range(days)]


def save_dates(days):
    """Write the remaining training days to ``w2v_dates.txt``, one YYYY-MM-DD per line.

    The file is written in the working directory. When ``use_ceph`` is set it
    is also uploaded to the bucket, so an interrupted run can resume from
    either location.
    """
    with open('w2v_dates.txt', 'w') as out:
        out.write('\n'.join(datetime.datetime.strftime(day, '%Y-%m-%d') for day in days))
    if not use_ceph:
        return
    s3.upload_file(
        Bucket=s3_bucket,
        Key="github-labeler/w2v_dates.txt",
        Filename='w2v_dates.txt',
    )

In [None]:
initial = len(all_days)

# Draw days at random without replacement. After every third day (and once the
# list is empty) checkpoint the model and the remaining days, so an
# interrupted run can resume.
while len(all_days) > 0:
    random_day = np.random.choice(all_days)
    all_days.remove(random_day)
    train_w2v_on_day(w, random_day)
    if len(all_days) % 3:
        continue
    print(f'{initial - len(all_days)} days trained on')
    save_w2v(w)
    save_dates(all_days)