# SI 630: Homework 2: Word2Vec

This homework will have you implementing word2vec using PyTorch and let you familiarize yourself with building more complex neural networks and the larger PyTorch development infrastructure.

Broadly, this homework consists of a few major parts:
1. Implement a `Corpus` class that will load the dataset and convert it to a sequence of token ids
2. Implement negative sampling to select tokens to be used as negative examples of words in the context
3. Create your dataset of positive and negative examples per context and load it into PyTorch's `DataLoader` to use for sampling
4. Implement a `Word2Vec` class that is a PyTorch neural network
5. Implement a training loop that samples a _batch_ of target words and their respective positive/negative context words
6. Implement rare word removal and frequent word subsampling
7. Run your model on the full dataset for at least one epoch
8. Do the exploratory parts of the homework
9. Make a copy of this notebook and change your implementation so it learns word vectors with less bias

After Step 5, you should be able to run your word2vec implementation on a small dataset and verify that it's learning correctly. Once you can verify everything is working, proceed with steps 6 and beyond. **Please note that this list is a general sketch and the homework PDF has the full list/description of to-dos and all your deliverables.**

In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from tqdm.auto import tqdm, trange
from collections import Counter
import random
from torch import optim

from torch.utils.tensorboard import SummaryWriter

# Helpful for computing cosine similarity--Note that this is NOT a similarity!
from scipy.spatial.distance import cosine

# Handy command-line argument parsing
import argparse

# Sort of smart tokenization
from nltk.tokenize import RegexpTokenizer
from nltk.tokenize import MWETokenizer

# We'll use this to save our models
from gensim.models import KeyedVectors

#
# IMPORTANT NOTE: Always set your random seeds when dealing with stochastic
# algorithms as it lets your bugs be reproducible and (more importantly) it lets
# your results be reproducible by others.
#
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)

<torch._C.Generator at 0x7ff898f941f0>

## Create a class to hold the data

Before we get to training word2vec, we'll need to process the corpus into some representation. The `Corpus` class will handle much of the functionality for corpus reading and keeping track of which word types belong to which ids. The `Corpus` class will also handle the crucial functionality of generating negative samples for training (i.e., randomly-sampled words that were not in the target word's context).

Some parts of this class can be completed after you've gotten word2vec up and running, so see the notes below and the details in the homework PDF.

In [2]:
class Corpus:
    
    def __init__(self):

        self.tokenizer = RegexpTokenizer(r'\w+')
        
        # These state variables become populated with function calls
        #
        # 1. load_data()
        # 2. generate_negative_sampling_table()
        #
        # See those functions for how the various values get filled in

        self.word_to_index = {} # word to unique-id
        self.index_to_word = {} # unique-id to word

        # How many times each word occurs in our data after filtering
        self.word_counts = Counter()

        # A utility data structure that lets us quickly sample "negative"
        # instances in a context. This table contains unique-ids
        self.negative_sampling_table = []
        
        # The dataset we'll use for training, as a sequence of unqiue word
        # ids. This is the sequence across all documents after tokens have been
        # randomly subsampled by the word2vec preprocessing step
        self.full_token_sequence_as_ids = None
        
    def tokenize(self, text):
        '''
        Tokenize the document and returns a list of the tokens
        '''
        
        return self.tokenizer.tokenize(text)        

    def load_data(self, file_name, min_token_freq):
        '''
        Reads the data from the specified file as long long sequence of text
        (ignoring line breaks) and populates the data structures of this
        word2vec object.
        '''

        # Step 1: Read in the file and create a long sequence of tokens for
        # all tokens in the file
        tokenizer2 = MWETokenizer()
        f=open('bio-mwes.txt','r')
        lines=f.readlines()
        for line in lines:
            arr=line.split(' ')
            arr[-1]=arr[-1][:-1]
#             print(arr)
            tokenizer2.add_mwe(tuple(arr))
        f.close()
        
        all_tokens = []
        print('Reading data and tokenizing')
        
        tokenizer = RegexpTokenizer('\w+|\$[\d\.]+')
        fp = open(file_name, 'r')
        lines=fp.readlines()
        for line in lines:
            tokens=tokenizer.tokenize(line)
            tokens=tokenizer2.tokenize(tokens)
#             print(tokens)
            tokens=[i.lower() for i in tokens]
            all_tokens+=tokens
#         print(all_tokens)
    
        # Step 2: Count how many tokens we have of each type
        print('Counting token frequencies')
        
        all_tokens_dict=dict(Counter(all_tokens))
        

        # Step 3: Replace all tokens below the specified frequency with an <UNK>
        # token. 
        #
        # NOTE: You can do this step later if needed
        print("Performing minimum thresholding")
        
        delete_keys1=set()
        all_tokens_dict['<UNK>']=0
        for k,v in all_tokens_dict.items():
            if v<min_token_freq and k!='<UNK>':
                all_tokens_dict['<UNK>']+=1
                delete_keys1.add(k)
        
        for k in delete_keys1:
            del all_tokens_dict[k]
                
        
        

        # Step 4: update self.word_counts to be the number of times each word
        # occurs (including <UNK>)
        self.word_counts=all_tokens_dict
        
        
        # Step 5: Create the mappings from word to unique integer ID and the
        # reverse mapping.
        idx=0
        for k in all_tokens_dict.keys():
            self.word_to_index[k]=idx
            self.index_to_word[idx]=k
            idx+=1
        
        
        # Step 6: Compute the probability of keeping any particular *token* of a
        # word in the training sequence, which we'll use to subsample. This subsampling
        # avoids having the training data be filled with many overly common words
        # as positive examples in the context
        length=len(all_tokens)
        token_set=set()
        delete_keys=set()
        for k,v in all_tokens_dict.items():
            p=v/length
            keep=(np.sqrt(p/0.001)+1)*0.001/(np.sqrt(p))
#             if k=='the' or k=='there' or k=='monday':
#                 print(k,keep)
#             if keep>=0.0358:
            if keep>=random.random()/50:
                token_set.add(k)
                
            
#         for k in delete_keys:
#             print(k)
#             if k in all_tokens_dict.keys():
#                 del all_tokens_dict[k]
#                 del self.index_to_word[self.word_to_index[k]]
#                 del self.word_to_index[k]
            
            
        print("============")
                        
        # Step 7: process the list of tokens (after min-freq filtering) to fill
        # a new list self.full_token_sequence_as_ids where 
        #
        # (1) we probabilistically choose whether to keep each *token* based on the
        # subsampling probabilities (note that this does not mean we drop
        # an entire word!) and 
        #
        # (2) all tokens are convered to their unique ids for faster training.
        #
        # NOTE: You can skip the subsampling part and just do step 2 to get
        # your model up and running.
            
        # NOTE 2: You will perform token-based subsampling based on the probabilities in
        # word_to_sample_prob. When subsampling, you are modifying the sequence itself 
        # (like deleting an item in a list). This action effectively makes the context
        # window  larger for some target words by removing context words that are common
        # from a particular context before the training occurs (which then would now include
        # other words that were previously just outside the window).

        
        unk_id=self.word_to_index['<UNK>']
        self.full_token_sequence_as_ids=[]
        for token in all_tokens:
            if token in token_set:
                self.full_token_sequence_as_ids.append(self.word_to_index[token])
            else:
                self.full_token_sequence_as_ids.append(unk_id)
            
        

        # Helpful print statement to verify what you've loaded
        print('Loaded all data from %s; saw %d tokens (%d unique)' \
              % (file_name, len(self.full_token_sequence_as_ids),
                 len(self.word_to_index)))

        
    def generate_negative_sampling_table(self, exp_power=0.75, table_size=1e6):
        '''
        Generates a big list data structure that we can quickly randomly index into
        in order to select a negative training example (i.e., a word that was
        *not* present in the context). 
        '''       
        
        # Step 1: Figure out how many instances of each word need to go into the
        # negative sampling table. 
        #
        # HINT: np.power and np.fill might be useful here        
        print("Generating sampling table")

        keep_token_dict={}
        total_p=0
        for k in self.index_to_word.keys():
            keep_token_dict[k]=np.power(self.word_counts[self.index_to_word[k]],exp_power)
            total_p+=keep_token_dict[k]
            
        
        for k in self.index_to_word.keys():
            keep_token_dict[k]=keep_token_dict[k]/total_p
        
        
        ct=0
        for idx, val in keep_token_dict.items():
            i=0
            max_val = round(val*table_size)

            while ct < table_size and i <= max_val:
#                 print(idx)
                self.negative_sampling_table.append(idx)
                ct+=1
                i+=1
        
        
        
        # Step 2: Create the table to the correct size. You'll want this to be a
        # numpy array of type int


        # Step 3: Fill the table so that each word has a number of IDs
        # proportionate to its probability of being sampled.
        #
        # Example: if we have 3 words "a" "b" and "c" with probabilites 0.5,
        # 0.33, 0.16 and a table size of 6 then our table would look like this
        # (before converting the words to IDs):
        #
        # [ "a", "a", "a", "b", "b", "c" ]
        #
        


    def generate_negative_samples(self, cur_context_word_id, num_samples):
        '''
        Randomly samples the specified number of negative samples from the lookup
        table and returns this list of IDs as a numpy array. As a performance
        improvement, avoid sampling a negative example that has the same ID as
        the current positive context word.
        '''

        results = []

        # Create a list and sample from the negative_sampling_table to
        # grow the list to num_samples, avoiding adding a negative example that
        # has the same ID as the current context_word
        


        for i in range(num_samples):
            while len(results)!=num_samples:
                idx=random.randint(0,len(self.negative_sampling_table)-1)
                if self.negative_sampling_table[idx] not in cur_context_word_id:
                    results.append(self.negative_sampling_table[idx])
        
#         print(self.index_to_word[cur_context_word_id])
#         for i in results:
#             print(self.index_to_word[i])
        
        return results

## Create the corpus

Now that we have code to turn the text into training data, let's do so. We've provided several files for you to help:

* `wiki-bios.DEBUG.txt` -- use this to debug your corpus reader
* `wiki-bios.10k.txt` -- use this to debug/verify the whole word2vec works
* `wiki-bios.med.txt` -- use this when everything works to generate your vectors for later parts
* `wiki-bios.HUGE.txt.gz` -- _do not use this_ unless (1) everything works and (2) you really want to test/explore. This file is not needed at all to do your homework.

We recommend startin to debug with the first file, as it is small and fast to load (quicker to find bugs). When debugging, we recommend setting the `min_token_freq` argument to 2 so that you can verify that part of the code is working but you still have enough word types left to test the rest.

You'll use the remaining files later, where they're described.

In the next cell, create your `Corpus`, read in the data, and generate the negative sampling table.

In [3]:
corpus = Corpus()
corpus.load_data('wiki-bios.med.txt', 5)
corpus.generate_negative_sampling_table()
# corpus.generate_negative_samples(325,10)

Reading data and tokenizing
Counting token frequencies
Performing minimum thresholding
Loaded all data from wiki-bios.med.txt; saw 22220944 tokens (110331 unique)
Generating sampling table


## Generate the training data

Once we have the corpus ready, we need to generate our training dataset. Each instance in the dataset is a target word and positive and negative examples of contexts words. Given the target word as input, we'll want to predict (or not predict) these positive and negative context words as outputs using our network. Your task here is to create a python `list` of instances. 

Your final training data should be a list of tuples in the format ([target_word_id], [word_id_1, ...], [predicted_labels]), where each item in the list is a list:
1. The first item is a list consisting only of the target word's ID.
2. The second item is a list of word ids for both context words and negative samples 
3. The third item is a list of labels to predicted for each of the word ids in the second list (i.e., `1` for context words and `0` for negative samples). 

You will feed these tuples into the PyTorch `DatasetLoader` later that will do the converstion to `Tensor` objects. You will need to make sure that all of the lists in each tuple are `np.array` instances and are not plain python lists for this `Tensor` converstion to work.

In [4]:
window_size = 2
num_negative_samples_per_target = 2

training_data = []
    
# Loop through each token in the corpus and generate an instance for each, 
# adding it to training_data
ct=0
unk_id=corpus.word_to_index['<UNK>']

for i in tqdm(range(len(corpus.full_token_sequence_as_ids))):
#     print(i)
    idx=corpus.full_token_sequence_as_ids[i]
    
    if idx==unk_id:
        continue
    elif i==0:
        positive_samples_list=corpus.full_token_sequence_as_ids[1:1+window_size]
        negative_samples_num=2*window_size*num_negative_samples_per_target+2
    elif i==1:
        positive_samples_list=corpus.full_token_sequence_as_ids[:1]+corpus.full_token_sequence_as_ids[2:4]
        negative_samples_num=2*window_size*num_negative_samples_per_target+1
    elif i==len(corpus.full_token_sequence_as_ids)-2:
        positive_samples_list=corpus.full_token_sequence_as_ids[i-2:i]+corpus.full_token_sequence_as_ids[i+1:]
        negative_samples_num=2*window_size*num_negative_samples_per_target+1
    elif i==len(corpus.full_token_sequence_as_ids)-1:
        positive_samples_list=corpus.full_token_sequence_as_ids[i-2:i]
        negative_samples_num=2*window_size*num_negative_samples_per_target+2
    else:
        positive_samples_list=[corpus.full_token_sequence_as_ids[i-2],corpus.full_token_sequence_as_ids[i-1],corpus.full_token_sequence_as_ids[i+1],corpus.full_token_sequence_as_ids[i+2]]
        negative_samples_num=2*window_size*num_negative_samples_per_target
    
    
    negative_samples_list=corpus.generate_negative_samples(positive_samples_list, negative_samples_num)
    predicted_labels=[1]*len(positive_samples_list)+[0]*negative_samples_num
    training_data.append((np.array([idx]),np.array(positive_samples_list+negative_samples_list),np.array(predicted_labels)))
    # For exach target word in our dataset, select context words 
    # within +/- the window size in the token sequence
    
    # For each positive target, we need to select negative examples of
    # words that were not in the context. Use the num_negative_samples_per_target
    # hyperparameter to generate these, using the generate_negative_samples()
    # method from the Corpus class

    # NOTE: this part might not make sense until later when you do the training 
    # so feel free to revisit it to see why it happens.
    #
    # Our training will use batches of instances together (compare that 
    # with HW1's SGD that used one item at a time). PyTorch will require
    # that all instances in a batches have the same size, which creates an issue
    # for us here since the target wordss at the very beginning or end of the corpus
    # have shorter contexts. 
    # 
    # To work around these edge-cases, we need to ensure that each instance has
    # the same size, which means it needs to have the same number of positive
    # and negative examples. Since we are short on positive examples here (due
    # to the edge of the corpus), we can just add more negative samples.
    #
    # YOUR TASK: determine what is the maximum number of context words (positive
    # and negative) for any instance and then, for instances that have fewer than
    # this number of context words, add in negative examples.
    #
    # NOTE: The maximum is fixed, so you can precompute this outside the loop
    # ahead of time.
    

  0%|          | 0/22220944 [00:00<?, ?it/s]

## Create the network

We'll create a new neural network as a subclass of `nn.Module` like we did in Homework 1. However, _unlike_ the network you built in Homework 1, we do not need to used linear layers to implement word2vec. Instead, we will use PyTorch's `Emedding` class, which maps an index (e.g., a word id in this case) to an embedding. 

Roughly speaking, word2vec's network makes a prediction by computing the dot product of the target word's embedding and a context word's embedding and then passing this dot product through the sigmoid function ($\sigma$) to predict the probability that the context word was actually in the context. The homework write-up has lots of details on how this works. Your `forward()` function will have to implement this computation.

In [30]:
class Word2Vec(nn.Module):
    
    def __init__(self, vocab_size, embedding_size):
        super(Word2Vec, self).__init__()
        
        # Save what state you want and create the embeddings for your
        # target and context words
        self.target_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_size)
        
        # Once created, let's fill the embeddings with non-zero random
        # numbers. We need to do this to get the training started. 
        #
        # NOTE: Why do this? Think about what happens if all the embeddings
        # are all zeros initially. What would the predictions look like for
        # word2vec with these embeddings and how would the updated work?
        
        self.init_emb(init_range=0.5/vocab_size)
        
    def init_emb(self, init_range):
        
        # Fill your two embeddings with random numbers uniformly sampled
        # between +/- init_range
        
        self.target_embeddings.weight.data.uniform_(-init_range, init_range)
        self.context_embeddings.weight.data.uniform_(-init_range,init_range)
        
    def forward(self, target_word_id, context_word_ids):
        ''' 
        Predicts whether each context word was actually in the context of the target word.
        The input is a tensor with a single target word's id and a tensor containing each
        of the context words' ids (this includes both positive and negative examples).
        '''
        
        v_t=self.target_embeddings(target_word_id)
        v_c=self.context_embeddings(context_word_ids)

        hid=torch.matmul(v_t,torch.transpose(v_c,2,1))
        out=torch.sigmoid(hid)
        return out
        
        # NOTE 1: This is probably the hardest part of the homework, so you'll
        # need to figure out how to do the dot-product between embeddings and return
        # the sigmoid. Be prepared for lots of debugging. For some reference,
        # our implementation is three lines and really the hard part is just
        # the last line. However, it's usually a matter of figuring out what
        # that one line looks like that ends up being the hard part.
        
        # NOTE 2: In this homework you'll be dealing with *batches* of instances
        # rather than a single instance at once. PyTorch mostly handles this
        # seamlessly under the hood for you (which is very nice) but batching
        # can show in weird ways and create challenges in debugging initially.
        # For one, your inputs will get an extra dimension. So, for example,
        # if you have a batch size of 4, your input for target_word_id will
        # really be 4 x 1. If you get the embeddings of those targets,
        # it then becomes 4x50! The same applies to the context_word_ids, except
        # that was alreayd a list so now you have things with shape 
        #
        #    (batch x context_words x embedding_size)
        #
        # One of your tasks will be to figure out how to get things lined up
        # so everything "just works". When it does, the code looks surprisingly
        # simple, but it might take a lot of debugging (or not!) to get there.
        
        # NOTE 3: We *strongly* discourage you from looking for existing 
        # implementations of word2vec online. Sadly, having reviewed most of the
        # highly-visible ones, they are actually wrong (wow!) or are doing
        # inefficient things like computing the full softmax instead of doing
        # the negative sampling. Looking at these will likely leave you more
        # confused than if you just tried to figure it out yourself.
        
        # NOTE 4: There many ways to implement this, some more efficient
        # than others. You will want to get it working first and then
        # test the timing to see how long it takes. As long as the
        # code works (vector comparisons look good) you'll receive full
        # credit. However, very slow implementations may take hours(!)
        # to converge so plan ahead.
        
        
        # Hint 1: You may want to review the mathematical operations on how
        # to compute the dot product to see how to do these
        
        # Hint 2: the "dim" argument for some operations may come in handy,
        # depending on your implementation
        
        # Hint 3: printing the shape of the tensors can come in very handy when
        # debugging to see where things aren't lining up
           
        # TODO: Implement the forward pass of word2vec
        
    

## Train the network!

Now that you have data in the right format and a neural network designed, it's time to train the network and see if it's all working. The trainin code will look surprisingly similar at times to your pytorch code from Homework 1 since all networks share the same base training setup. However, we'll add a few new elements to get you familiar with more common training techniques. 

For all steps, be sure to use the hyperparameters values described in the write-up.

1. Initialize your optimizer and loss function 
2. Create your network
3. Load your dataset into PyTorch's `DataLoader` class, which will take care of batching and shuffling for us (yay!)
4. Create a new `SummaryWriter` to periodically write our running-sum of the loss to a tensorboard
5. Train your model 

Two new elements show up. First, we'll be using `DataLoader` which is going to sample data for us and put it in a batch (and also convert the data to `Tensor` objects. You can iterate over the batches and each iteration will return all the items eventually, one batch at a time (a full epoch's worth).

The second new part is using `tensorboard`. As you might have noticed in Homework 1, training neural models can take some time. [TensorBoard](https://www.tensorflow.org/tensorboard/) is a handy web-based view that you can check during training to see how the model is doing. We'll use it here and periodically log a running sum of the loss after a set number of steps. The Homework write up has a plot of what this looks like. We'll be doing something simple here with tensorboard but it will come in handy later as you train larger models (for longer) and may want to visually check if your model is converging. TensorBoard was initially written for another deep learning framework, TensorFlow, but proved so useful it was ported to work in PyTorch too and is [easy to integrate](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html).

To start training, we recommend training on the `wiki-bios.10k.txt` dataset. This data is small enough you can get through an epoch in a few minutes (or less) while still being large enough you can test whether the model is learning anything by examining common words. Below this cell we've added a few helper functions that you can use to debug and query your model. In particular, the `get_neighbors()` function is a great way to test: if your model has learned anything, the nearest neighbors for common words should seem reasonable (without having to jump through mental hoops). An easy word to test on the `10k` data is "january" which should return month-related words as being most similar.

**NOTE**: Since we're training biographies, the text itself will be skewed towards words likely to show up biographices--which isn't necessary like "regular" text. You may find that your model has few instances of words you think are common, or that the model learns poor or unusual neighbors for these. When querying the neighbors, it can help to think of which words you think are likely to show up in biographies on Wikipedia and use those as probes to see what the model has learned.

Once you're convinced the model is learning, switch to the `med` data and train your model as specified in the PDF. Once trained, save your model using the `save()` function at the end of the notebook. This function records your data in a common format for word2vec vectors and lets you load the vectors into other libraries that have more advanced functionality. In particular, you can use the [gensim](https://radimrehurek.com/gensim/auto_examples/tutorials/run_word2vec.html) code in other notebook included to explore the vectors and do simple vector analogies.

In [31]:
def get_neighbors(model, word_to_index, target_word):
    """ 
    Finds the top 10 most similar words to a target word
    """
    outputs = []
    for word, index in tqdm(word_to_index.items(), total=len(word_to_index)):
        similarity = compute_cosine_similarity(model, word_to_index, target_word, word)
        result = {"word": word, "score": similarity}
        outputs.append(result)

    # Sort by highest scores
    neighbors = sorted(outputs, key=lambda o: o['score'], reverse=True)
    return neighbors[1:1000]

def compute_cosine_similarity(model, word_to_index, word_one, word_two):
    '''
    Computes the cosine similarity between the two words
    '''
    try:
        word_one_index = word_to_index[word_one]
        word_two_index = word_to_index[word_two]
    except KeyError:
#         print("============")
#         print(word_one,word_two)
        return 0

#     print(word_one_index,word_two_index)
    embedding_one = model.target_embeddings(torch.LongTensor([word_one_index]))
    embedding_two = model.target_embeddings(torch.LongTensor([word_two_index]))
#     print(embedding_one,embedding_two,word_one_index,word_two_index)
    similarity = 1 - abs(float(cosine(embedding_one.detach().numpy(),
                                      embedding_two.detach().numpy())))
    return similarity

### Reduce bias

In [32]:
def bias_evaluation(model):
    man_woman_sim = (compute_cosine_similarity(model, corpus.word_to_index,"man", "woman")\
        +compute_cosine_similarity(model, corpus.word_to_index,"he", "she"))/2
    er_ress_sim = compute_cosine_similarity(model, corpus.word_to_index,"actor", "actress")\
                +compute_cosine_similarity(model, corpus.word_to_index,"actor", "actress")\
        +compute_cosine_similarity(model, corpus.word_to_index,"waiter", "waitress")\
        +compute_cosine_similarity(model, corpus.word_to_index,"actor", "actress")\
        +compute_cosine_similarity(model, corpus.word_to_index,"steward", "stewardess")
    er_ress_sim=er_ress_sim/5
    return 1 - max(max(man_woman_sim, 0),er_ress_sim)


In [33]:
batch_size = 32
min_token_freq = 5
epochs = 1
embedding_size = 50
learning_rate=5e-5
loss_func = nn.CrossEntropyLoss()
model = Word2Vec(len(corpus.word_counts), embedding_size)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

train_loader = DataLoader(dataset=training_data, batch_size=batch_size, shuffle=True)
writer = SummaryWriter()
# HINT: wrapping the epoch/step loops in nested tqdm calls is a great way
# to keep track of how fast things are and how much longer training will take

min_loss = 2000
patience =30
trigger_times = 0
stop_model=False

model.train()

for epoch in range(epochs):

    loss_sum = 0.0
    
    # TODO: use your DataLoader to iterate over the data
    for step, data in enumerate(tqdm(train_loader)):
        
        if stop_model:
            break
        
        # NOTE: since you created the data as a tuple of three np.array instances,
        # these have now been converted to Tensor objects for us
        target_ids, context_ids, labels = data    
        # TODO: Fill in all the training details here
        
        # TODO: Based on the details in the Homework PDF, periodically
        # report the running-sum of the loss to tensorboard. Be sure
        # to reset the running sum after reporting it.
        
        # TODO: it can be helpful to add some early stopping here after
        # a fixed number of steps (e.g., if step > max_steps)
        optimizer.zero_grad()
        y = model(target_ids, context_ids).view(len(target_ids), -1)
        
        # bias evaulation every 20 steps
        if step%20==0:
            loss = loss_func(y, labels.float())+bias_evaluation(model)
        else:
            loss = loss_func(y, labels.float())
        
#         loss = loss_func(y, labels.float())+bias_evaluation(model)
        loss_sum += loss.item()
        loss.backward()
        optimizer.step()
        if step % 100 == 99:
            writer.add_scalar("Loss/train", loss_sum, (step+1)/100)
            print("step: {}, loss: {}".format((step+1)/100, loss_sum))

            # Early Stopping
#             if loss_sum > min_loss:
#                 trigger_times += 1
#                 print('trigger times:', trigger_times)
#                 if trigger_times >= patience:
#                     print('Early stopping!\nStart to test process.')
#                     writer.close()
#                     model.eval()
#                     stop_model=True
#             else:
#                 print('trigger times: 0')
#                 trigger_times = 0
#             min_loss = min(loss_sum,min_loss)
            loss_sum = 0
        
        
        

if not stop_model:          
    writer.close()
# once you finish training, it's good practice to switch to eval.
    model.eval()

  0%|          | 0/680814 [00:00<?, ?it/s]

step: 1.0, loss: 998.1401433944702
step: 2.0, loss: 995.9329824447632
step: 3.0, loss: 994.2385625839233
step: 4.0, loss: 994.141167640686
step: 5.0, loss: 994.3312759399414
step: 6.0, loss: 994.2990665435791
step: 7.0, loss: 994.0350112915039
step: 8.0, loss: 993.7798357009888
step: 9.0, loss: 993.7109670639038
step: 10.0, loss: 993.6447715759277
step: 11.0, loss: 993.6056575775146
step: 12.0, loss: 993.4813098907471
step: 13.0, loss: 993.3919477462769
step: 14.0, loss: 993.2496223449707
step: 15.0, loss: 993.1263790130615
step: 16.0, loss: 993.0352048873901
step: 17.0, loss: 992.9101400375366
step: 18.0, loss: 992.8140020370483
step: 19.0, loss: 992.6486930847168
step: 20.0, loss: 992.4874820709229
step: 21.0, loss: 992.3602924346924
step: 22.0, loss: 992.2697677612305
step: 23.0, loss: 991.9124612808228
step: 24.0, loss: 991.766923904419
step: 25.0, loss: 991.7572374343872
step: 26.0, loss: 991.5710077285767
step: 27.0, loss: 991.5125408172607
step: 28.0, loss: 991.2476806640625
ste

step: 226.0, loss: 969.3612031936646
step: 227.0, loss: 969.143964767456
step: 228.0, loss: 970.1375675201416
step: 229.0, loss: 968.9046154022217
step: 230.0, loss: 969.6152925491333
step: 231.0, loss: 969.399938583374
step: 232.0, loss: 970.3990602493286
step: 233.0, loss: 969.4069957733154
step: 234.0, loss: 969.5409822463989
step: 235.0, loss: 969.4139881134033
step: 236.0, loss: 968.5234308242798
step: 237.0, loss: 968.6493577957153
step: 238.0, loss: 969.040020942688
step: 239.0, loss: 969.0174369812012
step: 240.0, loss: 967.5755271911621
step: 241.0, loss: 969.2416944503784
step: 242.0, loss: 969.0648279190063
step: 243.0, loss: 969.2906551361084
step: 244.0, loss: 968.1310768127441
step: 245.0, loss: 968.3900203704834
step: 246.0, loss: 969.4932661056519
step: 247.0, loss: 968.4427671432495
step: 248.0, loss: 968.566596031189
step: 249.0, loss: 969.6232147216797
step: 250.0, loss: 967.8563852310181
step: 251.0, loss: 968.6176233291626
step: 252.0, loss: 968.3653335571289
step:

step: 449.0, loss: 960.0613126754761
step: 450.0, loss: 959.0369472503662
step: 451.0, loss: 960.4486694335938
step: 452.0, loss: 960.1271924972534
step: 453.0, loss: 959.3364210128784
step: 454.0, loss: 959.1240272521973
step: 455.0, loss: 960.0269594192505
step: 456.0, loss: 958.7825269699097
step: 457.0, loss: 959.3619508743286
step: 458.0, loss: 959.5500965118408
step: 459.0, loss: 958.859787940979
step: 460.0, loss: 959.4976625442505
step: 461.0, loss: 959.4323167800903
step: 462.0, loss: 958.0067157745361
step: 463.0, loss: 959.1380434036255
step: 464.0, loss: 958.8501987457275
step: 465.0, loss: 958.3387250900269
step: 466.0, loss: 957.5407495498657
step: 467.0, loss: 957.7642126083374
step: 468.0, loss: 958.4733276367188
step: 469.0, loss: 959.3065214157104
step: 470.0, loss: 958.5182466506958
step: 471.0, loss: 959.113317489624
step: 472.0, loss: 958.7238903045654
step: 473.0, loss: 957.6771326065063
step: 474.0, loss: 958.5432615280151
step: 475.0, loss: 959.1093521118164
ste

step: 672.0, loss: 953.6053810119629
step: 673.0, loss: 953.3173770904541
step: 674.0, loss: 952.8842344284058
step: 675.0, loss: 952.2291851043701
step: 676.0, loss: 954.2591648101807
step: 677.0, loss: 952.596396446228
step: 678.0, loss: 954.2597513198853
step: 679.0, loss: 954.2100057601929
step: 680.0, loss: 953.2261009216309
step: 681.0, loss: 953.0515537261963
step: 682.0, loss: 953.4767055511475
step: 683.0, loss: 954.3754634857178
step: 684.0, loss: 952.8189611434937
step: 685.0, loss: 951.5963134765625
step: 686.0, loss: 953.8276510238647
step: 687.0, loss: 953.513671875
step: 688.0, loss: 953.254225730896
step: 689.0, loss: 954.1397876739502
step: 690.0, loss: 953.2795391082764
step: 691.0, loss: 952.9120216369629
step: 692.0, loss: 952.6092767715454
step: 693.0, loss: 952.8358097076416
step: 694.0, loss: 954.0168190002441
step: 695.0, loss: 952.851803779602
step: 696.0, loss: 953.1328439712524
step: 697.0, loss: 953.9074544906616
step: 698.0, loss: 953.5862712860107
step: 69

step: 895.0, loss: 948.9782810211182
step: 896.0, loss: 949.3774757385254
step: 897.0, loss: 949.0826120376587
step: 898.0, loss: 949.0274076461792
step: 899.0, loss: 949.6390171051025
step: 900.0, loss: 949.3440885543823
step: 901.0, loss: 949.0588464736938
step: 902.0, loss: 949.6975650787354
step: 903.0, loss: 947.9730939865112
step: 904.0, loss: 949.8659744262695
step: 905.0, loss: 950.133282661438
step: 906.0, loss: 949.3290948867798
step: 907.0, loss: 949.088963508606
step: 908.0, loss: 948.5068368911743
step: 909.0, loss: 949.4069423675537
step: 910.0, loss: 947.788200378418
step: 911.0, loss: 949.1089220046997
step: 912.0, loss: 948.8480815887451
step: 913.0, loss: 948.3013257980347
step: 914.0, loss: 950.5277729034424
step: 915.0, loss: 948.2202377319336
step: 916.0, loss: 949.2486515045166
step: 917.0, loss: 949.2757930755615
step: 918.0, loss: 948.1622190475464
step: 919.0, loss: 947.757342338562
step: 920.0, loss: 948.1527719497681
step: 921.0, loss: 948.3944463729858
step:

step: 1115.0, loss: 946.3772783279419
step: 1116.0, loss: 945.6233987808228
step: 1117.0, loss: 945.9465351104736
step: 1118.0, loss: 945.0832185745239
step: 1119.0, loss: 947.3339710235596
step: 1120.0, loss: 945.8384923934937
step: 1121.0, loss: 945.7051830291748
step: 1122.0, loss: 948.03249168396
step: 1123.0, loss: 944.1587791442871
step: 1124.0, loss: 945.9944677352905
step: 1125.0, loss: 945.6400232315063
step: 1126.0, loss: 946.303032875061
step: 1127.0, loss: 947.5517997741699
step: 1128.0, loss: 945.0510091781616
step: 1129.0, loss: 946.7101202011108
step: 1130.0, loss: 945.2761220932007
step: 1131.0, loss: 946.5724840164185
step: 1132.0, loss: 944.3688020706177
step: 1133.0, loss: 946.4098825454712
step: 1134.0, loss: 944.7318115234375
step: 1135.0, loss: 946.1574230194092
step: 1136.0, loss: 945.2323188781738
step: 1137.0, loss: 945.6682081222534
step: 1138.0, loss: 945.1878461837769
step: 1139.0, loss: 945.8007802963257
step: 1140.0, loss: 946.5837640762329
step: 1141.0, l

step: 1332.0, loss: 943.599573135376
step: 1333.0, loss: 942.437762260437
step: 1334.0, loss: 944.3764896392822
step: 1335.0, loss: 943.8201065063477
step: 1336.0, loss: 943.2728729248047
step: 1337.0, loss: 943.8827142715454
step: 1338.0, loss: 942.3527698516846
step: 1339.0, loss: 942.5061616897583
step: 1340.0, loss: 945.0101737976074
step: 1341.0, loss: 943.0335235595703
step: 1342.0, loss: 941.3506469726562
step: 1343.0, loss: 943.8094329833984
step: 1344.0, loss: 942.8995342254639
step: 1345.0, loss: 942.6790323257446
step: 1346.0, loss: 943.4656715393066
step: 1347.0, loss: 944.1192207336426
step: 1348.0, loss: 942.7970170974731
step: 1349.0, loss: 943.3947048187256
step: 1350.0, loss: 945.8371019363403
step: 1351.0, loss: 944.2690267562866
step: 1352.0, loss: 942.173565864563
step: 1353.0, loss: 943.3468046188354
step: 1354.0, loss: 943.511420249939
step: 1355.0, loss: 942.428258895874
step: 1356.0, loss: 942.8732433319092
step: 1357.0, loss: 941.7286806106567
step: 1358.0, los

step: 1549.0, loss: 941.8350839614868
step: 1550.0, loss: 941.3098735809326
step: 1551.0, loss: 940.552264213562
step: 1552.0, loss: 941.0564737319946
step: 1553.0, loss: 941.0236959457397
step: 1554.0, loss: 940.6225471496582
step: 1555.0, loss: 942.2609148025513
step: 1556.0, loss: 942.333930015564
step: 1557.0, loss: 940.1365871429443
step: 1558.0, loss: 939.6776151657104
step: 1559.0, loss: 940.6400156021118
step: 1560.0, loss: 941.3751544952393
step: 1561.0, loss: 941.2048244476318
step: 1562.0, loss: 941.1020030975342
step: 1563.0, loss: 939.5781707763672
step: 1564.0, loss: 941.3828134536743
step: 1565.0, loss: 943.4130544662476
step: 1566.0, loss: 941.8953485488892
step: 1567.0, loss: 941.2317295074463
step: 1568.0, loss: 940.0960626602173
step: 1569.0, loss: 940.8871841430664
step: 1570.0, loss: 941.4820909500122
step: 1571.0, loss: 941.4120426177979
step: 1572.0, loss: 941.3966865539551
step: 1573.0, loss: 942.0874576568604
step: 1574.0, loss: 939.3990030288696
step: 1575.0, 

step: 1766.0, loss: 939.7711791992188
step: 1767.0, loss: 939.9133911132812
step: 1768.0, loss: 940.6118278503418
step: 1769.0, loss: 938.4441738128662
step: 1770.0, loss: 940.3109045028687
step: 1771.0, loss: 938.4687585830688
step: 1772.0, loss: 938.5443096160889
step: 1773.0, loss: 941.0863199234009
step: 1774.0, loss: 939.1661405563354
step: 1775.0, loss: 940.5425825119019
step: 1776.0, loss: 939.3945941925049
step: 1777.0, loss: 940.0381889343262
step: 1778.0, loss: 939.4267930984497
step: 1779.0, loss: 940.9048452377319
step: 1780.0, loss: 938.6580457687378
step: 1781.0, loss: 938.166934967041
step: 1782.0, loss: 938.0953321456909
step: 1783.0, loss: 939.2094964981079
step: 1784.0, loss: 939.349365234375
step: 1785.0, loss: 939.4165077209473
step: 1786.0, loss: 939.9634485244751
step: 1787.0, loss: 939.5972127914429
step: 1788.0, loss: 938.1565093994141
step: 1789.0, loss: 939.5702428817749
step: 1790.0, loss: 940.1472082138062
step: 1791.0, loss: 939.357419013977
step: 1792.0, l

step: 1983.0, loss: 939.1442928314209
step: 1984.0, loss: 935.4390535354614
step: 1985.0, loss: 937.895058631897
step: 1986.0, loss: 937.6948614120483
step: 1987.0, loss: 937.0303182601929
step: 1988.0, loss: 937.5826101303101
step: 1989.0, loss: 938.2951030731201
step: 1990.0, loss: 937.6707849502563
step: 1991.0, loss: 936.8913154602051
step: 1992.0, loss: 937.5159435272217
step: 1993.0, loss: 937.2296848297119
step: 1994.0, loss: 938.6783380508423
step: 1995.0, loss: 939.38560962677
step: 1996.0, loss: 938.0964336395264
step: 1997.0, loss: 939.3913803100586
step: 1998.0, loss: 936.7647323608398
step: 1999.0, loss: 936.7242374420166
step: 2000.0, loss: 939.4375562667847
step: 2001.0, loss: 939.3379640579224
step: 2002.0, loss: 936.9405546188354
step: 2003.0, loss: 936.8815212249756
step: 2004.0, loss: 936.8543405532837
step: 2005.0, loss: 937.2838687896729
step: 2006.0, loss: 934.549563407898
step: 2007.0, loss: 937.314866065979
step: 2008.0, loss: 936.5820808410645
step: 2009.0, los

step: 2200.0, loss: 936.4956178665161
step: 2201.0, loss: 938.2271003723145
step: 2202.0, loss: 936.1989297866821
step: 2203.0, loss: 937.9594430923462
step: 2204.0, loss: 936.3687715530396
step: 2205.0, loss: 933.9245719909668
step: 2206.0, loss: 936.6122827529907
step: 2207.0, loss: 936.736762046814
step: 2208.0, loss: 937.145435333252
step: 2209.0, loss: 937.1243181228638
step: 2210.0, loss: 938.1030082702637
step: 2211.0, loss: 936.4945755004883
step: 2212.0, loss: 936.5334196090698
step: 2213.0, loss: 934.7787590026855
step: 2214.0, loss: 935.3615789413452
step: 2215.0, loss: 936.8829069137573
step: 2216.0, loss: 934.2920322418213
step: 2217.0, loss: 937.1644086837769
step: 2218.0, loss: 934.9073524475098
step: 2219.0, loss: 937.3508567810059
step: 2220.0, loss: 935.9512519836426
step: 2221.0, loss: 934.456392288208
step: 2222.0, loss: 936.1596212387085
step: 2223.0, loss: 937.0589761734009
step: 2224.0, loss: 936.1974420547485
step: 2225.0, loss: 937.7817478179932
step: 2226.0, l

step: 2417.0, loss: 934.4697141647339
step: 2418.0, loss: 936.6462383270264
step: 2419.0, loss: 934.9831037521362
step: 2420.0, loss: 935.1771516799927
step: 2421.0, loss: 937.2724905014038
step: 2422.0, loss: 935.958794593811
step: 2423.0, loss: 935.0491008758545
step: 2424.0, loss: 934.0579652786255
step: 2425.0, loss: 933.6872463226318
step: 2426.0, loss: 934.948127746582
step: 2427.0, loss: 935.8576745986938
step: 2428.0, loss: 934.3153085708618
step: 2429.0, loss: 934.6751794815063
step: 2430.0, loss: 936.5122690200806
step: 2431.0, loss: 934.472936630249
step: 2432.0, loss: 935.2442111968994
step: 2433.0, loss: 936.0962238311768
step: 2434.0, loss: 935.3274488449097
step: 2435.0, loss: 934.3767738342285
step: 2436.0, loss: 935.8089332580566
step: 2437.0, loss: 933.7266502380371
step: 2438.0, loss: 934.9649066925049
step: 2439.0, loss: 935.7859945297241
step: 2440.0, loss: 936.0210371017456
step: 2441.0, loss: 935.9684448242188
step: 2442.0, loss: 935.0599384307861
step: 2443.0, l

step: 2634.0, loss: 934.3588027954102
step: 2635.0, loss: 933.0270557403564
step: 2636.0, loss: 933.2917032241821
step: 2637.0, loss: 934.3446359634399
step: 2638.0, loss: 934.9290933609009
step: 2639.0, loss: 934.8968706130981
step: 2640.0, loss: 932.2770481109619
step: 2641.0, loss: 934.6261739730835
step: 2642.0, loss: 936.2178945541382
step: 2643.0, loss: 934.5002479553223
step: 2644.0, loss: 933.4315357208252
step: 2645.0, loss: 933.5829963684082
step: 2646.0, loss: 934.4740552902222
step: 2647.0, loss: 934.7699356079102
step: 2648.0, loss: 935.5895986557007
step: 2649.0, loss: 932.8310861587524
step: 2650.0, loss: 935.2654857635498
step: 2651.0, loss: 934.9812240600586
step: 2652.0, loss: 934.8908138275146
step: 2653.0, loss: 934.6364507675171
step: 2654.0, loss: 934.5891342163086
step: 2655.0, loss: 933.8985586166382
step: 2656.0, loss: 935.5271968841553
step: 2657.0, loss: 934.8487682342529
step: 2658.0, loss: 933.4266576766968
step: 2659.0, loss: 932.5295782089233
step: 2660.0

step: 2851.0, loss: 933.5920963287354
step: 2852.0, loss: 932.6431932449341
step: 2853.0, loss: 933.3189725875854
step: 2854.0, loss: 932.0446538925171
step: 2855.0, loss: 932.7365055084229
step: 2856.0, loss: 932.5995922088623
step: 2857.0, loss: 930.4834518432617
step: 2858.0, loss: 932.8815050125122
step: 2859.0, loss: 933.211049079895
step: 2860.0, loss: 934.5097236633301
step: 2861.0, loss: 932.6766958236694
step: 2862.0, loss: 932.6852436065674
step: 2863.0, loss: 934.0045499801636
step: 2864.0, loss: 930.8081045150757
step: 2865.0, loss: 933.8530855178833
step: 2866.0, loss: 934.9052639007568
step: 2867.0, loss: 934.5434427261353
step: 2868.0, loss: 933.7998371124268
step: 2869.0, loss: 933.1909589767456
step: 2870.0, loss: 933.5907859802246
step: 2871.0, loss: 933.0205812454224
step: 2872.0, loss: 933.9000387191772
step: 2873.0, loss: 932.1073570251465
step: 2874.0, loss: 931.1006298065186
step: 2875.0, loss: 933.9249000549316
step: 2876.0, loss: 932.5892848968506
step: 2877.0,

step: 3068.0, loss: 933.347222328186
step: 3069.0, loss: 931.930757522583
step: 3070.0, loss: 933.099531173706
step: 3071.0, loss: 933.8819065093994
step: 3072.0, loss: 933.226634979248
step: 3073.0, loss: 934.5516109466553
step: 3074.0, loss: 931.8333053588867
step: 3075.0, loss: 933.5544500350952
step: 3076.0, loss: 933.8000068664551
step: 3077.0, loss: 933.2422437667847
step: 3078.0, loss: 930.9967823028564
step: 3079.0, loss: 932.8048267364502
step: 3080.0, loss: 932.7478113174438
step: 3081.0, loss: 932.1985788345337
step: 3082.0, loss: 933.1841220855713
step: 3083.0, loss: 930.8227005004883
step: 3084.0, loss: 932.9833469390869
step: 3085.0, loss: 931.6265487670898
step: 3086.0, loss: 933.5678224563599
step: 3087.0, loss: 931.5357351303101
step: 3088.0, loss: 931.8658018112183
step: 3089.0, loss: 933.1269826889038
step: 3090.0, loss: 932.28076171875
step: 3091.0, loss: 931.8474225997925
step: 3092.0, loss: 933.0799627304077
step: 3093.0, loss: 932.4896583557129
step: 3094.0, loss

step: 3285.0, loss: 931.2797679901123
step: 3286.0, loss: 930.6429061889648
step: 3287.0, loss: 931.309362411499
step: 3288.0, loss: 929.2084722518921
step: 3289.0, loss: 933.1454277038574
step: 3290.0, loss: 931.4534034729004
step: 3291.0, loss: 933.0913982391357
step: 3292.0, loss: 932.0553483963013
step: 3293.0, loss: 932.7287979125977
step: 3294.0, loss: 930.0855474472046
step: 3295.0, loss: 931.2230834960938
step: 3296.0, loss: 930.3442516326904
step: 3297.0, loss: 931.9533081054688
step: 3298.0, loss: 932.3346786499023
step: 3299.0, loss: 932.5460157394409
step: 3300.0, loss: 931.186164855957
step: 3301.0, loss: 933.6751260757446
step: 3302.0, loss: 932.1824102401733
step: 3303.0, loss: 931.2944107055664
step: 3304.0, loss: 932.932505607605
step: 3305.0, loss: 930.2734155654907
step: 3306.0, loss: 930.7019319534302
step: 3307.0, loss: 931.9998712539673
step: 3308.0, loss: 930.5406274795532
step: 3309.0, loss: 931.8671321868896
step: 3310.0, loss: 929.4327774047852
step: 3311.0, l

step: 3502.0, loss: 928.5786428451538
step: 3503.0, loss: 932.2247476577759
step: 3504.0, loss: 930.6684675216675
step: 3505.0, loss: 931.2336254119873
step: 3506.0, loss: 931.2948474884033
step: 3507.0, loss: 930.7725076675415
step: 3508.0, loss: 931.6968050003052
step: 3509.0, loss: 930.6842956542969
step: 3510.0, loss: 932.3677978515625
step: 3511.0, loss: 932.966908454895
step: 3512.0, loss: 930.9737958908081
step: 3513.0, loss: 931.7335062026978
step: 3514.0, loss: 931.1499490737915
step: 3515.0, loss: 931.5610218048096
step: 3516.0, loss: 931.8597240447998
step: 3517.0, loss: 932.3764200210571
step: 3518.0, loss: 930.340238571167
step: 3519.0, loss: 930.6244430541992
step: 3520.0, loss: 930.7793321609497
step: 3521.0, loss: 929.9295139312744
step: 3522.0, loss: 931.6950702667236
step: 3523.0, loss: 931.8751935958862
step: 3524.0, loss: 931.7270841598511
step: 3525.0, loss: 931.6190881729126
step: 3526.0, loss: 929.9074077606201
step: 3527.0, loss: 932.0046977996826
step: 3528.0, 

step: 3719.0, loss: 930.059247970581
step: 3720.0, loss: 930.9950037002563
step: 3721.0, loss: 931.9522390365601
step: 3722.0, loss: 929.32483959198
step: 3723.0, loss: 930.7861490249634
step: 3724.0, loss: 931.1714000701904
step: 3725.0, loss: 929.8603315353394
step: 3726.0, loss: 930.5082588195801
step: 3727.0, loss: 930.6551008224487
step: 3728.0, loss: 932.0551834106445
step: 3729.0, loss: 931.0535087585449
step: 3730.0, loss: 931.2315378189087
step: 3731.0, loss: 929.3781576156616
step: 3732.0, loss: 930.3020372390747
step: 3733.0, loss: 932.7938079833984
step: 3734.0, loss: 929.860445022583
step: 3735.0, loss: 930.8994264602661
step: 3736.0, loss: 927.6839599609375
step: 3737.0, loss: 932.3211870193481
step: 3738.0, loss: 928.2170505523682
step: 3739.0, loss: 930.501802444458
step: 3740.0, loss: 929.9315338134766
step: 3741.0, loss: 929.2802972793579
step: 3742.0, loss: 931.980767250061
step: 3743.0, loss: 929.4492454528809
step: 3744.0, loss: 932.4147100448608
step: 3745.0, loss

step: 3936.0, loss: 932.4290981292725
step: 3937.0, loss: 930.9258365631104
step: 3938.0, loss: 929.6519794464111
step: 3939.0, loss: 930.4063339233398
step: 3940.0, loss: 931.1158103942871
step: 3941.0, loss: 932.312611579895
step: 3942.0, loss: 928.6077117919922
step: 3943.0, loss: 929.696270942688
step: 3944.0, loss: 930.5576782226562
step: 3945.0, loss: 928.913182258606
step: 3946.0, loss: 929.3403387069702
step: 3947.0, loss: 930.3846883773804
step: 3948.0, loss: 927.561206817627
step: 3949.0, loss: 927.7757987976074
step: 3950.0, loss: 929.7991762161255
step: 3951.0, loss: 930.203706741333
step: 3952.0, loss: 931.0035715103149
step: 3953.0, loss: 930.5503549575806
step: 3954.0, loss: 929.7762956619263
step: 3955.0, loss: 930.336012840271
step: 3956.0, loss: 930.9998397827148
step: 3957.0, loss: 929.1129684448242
step: 3958.0, loss: 929.4779920578003
step: 3959.0, loss: 931.6118030548096
step: 3960.0, loss: 928.6691274642944
step: 3961.0, loss: 930.3280172348022
step: 3962.0, loss

step: 4153.0, loss: 928.57386302948
step: 4154.0, loss: 930.424711227417
step: 4155.0, loss: 928.3412799835205
step: 4156.0, loss: 929.4496793746948
step: 4157.0, loss: 928.0555782318115
step: 4158.0, loss: 930.7072858810425
step: 4159.0, loss: 929.7327394485474
step: 4160.0, loss: 929.2076206207275
step: 4161.0, loss: 929.6869955062866
step: 4162.0, loss: 928.5763626098633
step: 4163.0, loss: 928.0258350372314
step: 4164.0, loss: 929.4675817489624
step: 4165.0, loss: 930.1134366989136
step: 4166.0, loss: 927.799220085144
step: 4167.0, loss: 929.1633758544922
step: 4168.0, loss: 930.2700929641724
step: 4169.0, loss: 929.847734451294
step: 4170.0, loss: 928.6156139373779
step: 4171.0, loss: 930.0815916061401
step: 4172.0, loss: 931.226616859436
step: 4173.0, loss: 929.5512142181396
step: 4174.0, loss: 929.3357992172241
step: 4175.0, loss: 929.66175365448
step: 4176.0, loss: 926.959846496582
step: 4177.0, loss: 929.1335496902466
step: 4178.0, loss: 928.9241189956665
step: 4179.0, loss: 9

step: 4370.0, loss: 927.8826265335083
step: 4371.0, loss: 929.7273416519165
step: 4372.0, loss: 928.8756866455078
step: 4373.0, loss: 930.1254768371582
step: 4374.0, loss: 929.4228658676147
step: 4375.0, loss: 929.466962814331
step: 4376.0, loss: 929.3548736572266
step: 4377.0, loss: 927.3959455490112
step: 4378.0, loss: 928.8373908996582
step: 4379.0, loss: 927.1326560974121
step: 4380.0, loss: 928.5503025054932
step: 4381.0, loss: 928.4302129745483
step: 4382.0, loss: 928.881462097168
step: 4383.0, loss: 929.2848262786865
step: 4384.0, loss: 929.7127914428711
step: 4385.0, loss: 927.7454957962036
step: 4386.0, loss: 927.1009120941162
step: 4387.0, loss: 927.8554735183716
step: 4388.0, loss: 929.4082708358765
step: 4389.0, loss: 929.1828603744507
step: 4390.0, loss: 928.6387758255005
step: 4391.0, loss: 928.1701440811157
step: 4392.0, loss: 928.8061170578003
step: 4393.0, loss: 928.4895029067993
step: 4394.0, loss: 929.6876497268677
step: 4395.0, loss: 929.4587182998657
step: 4396.0, 

step: 4587.0, loss: 927.5242300033569
step: 4588.0, loss: 928.827091217041
step: 4589.0, loss: 928.5516872406006
step: 4590.0, loss: 929.5521354675293
step: 4591.0, loss: 931.1007232666016
step: 4592.0, loss: 926.8386001586914
step: 4593.0, loss: 928.5071907043457
step: 4594.0, loss: 926.7417316436768
step: 4595.0, loss: 930.9661321640015
step: 4596.0, loss: 929.8210172653198
step: 4597.0, loss: 929.6363735198975
step: 4598.0, loss: 929.1229028701782
step: 4599.0, loss: 928.9426202774048
step: 4600.0, loss: 927.6181287765503
step: 4601.0, loss: 926.4738988876343
step: 4602.0, loss: 928.0473670959473
step: 4603.0, loss: 928.7559461593628
step: 4604.0, loss: 928.1610355377197
step: 4605.0, loss: 929.2332582473755
step: 4606.0, loss: 927.6847553253174
step: 4607.0, loss: 927.6076946258545
step: 4608.0, loss: 927.0197677612305
step: 4609.0, loss: 930.1621732711792
step: 4610.0, loss: 929.2063570022583
step: 4611.0, loss: 928.5281810760498
step: 4612.0, loss: 928.3722047805786
step: 4613.0,

step: 4804.0, loss: 928.6363801956177
step: 4805.0, loss: 927.9791383743286
step: 4806.0, loss: 927.3365364074707
step: 4807.0, loss: 928.7543659210205
step: 4808.0, loss: 928.217568397522
step: 4809.0, loss: 926.7154359817505
step: 4810.0, loss: 926.7655658721924
step: 4811.0, loss: 927.8786745071411
step: 4812.0, loss: 928.5929737091064
step: 4813.0, loss: 928.2302780151367
step: 4814.0, loss: 927.8473434448242
step: 4815.0, loss: 926.7913627624512
step: 4816.0, loss: 926.7819900512695
step: 4817.0, loss: 928.3995027542114
step: 4818.0, loss: 926.0919828414917
step: 4819.0, loss: 926.9299125671387
step: 4820.0, loss: 929.0039691925049
step: 4821.0, loss: 930.0972757339478
step: 4822.0, loss: 929.0681819915771
step: 4823.0, loss: 927.6367702484131
step: 4824.0, loss: 929.5845651626587
step: 4825.0, loss: 929.5541858673096
step: 4826.0, loss: 928.1742105484009
step: 4827.0, loss: 929.3362474441528
step: 4828.0, loss: 930.0488452911377
step: 4829.0, loss: 929.0548067092896
step: 4830.0,

step: 5021.0, loss: 928.3972940444946
step: 5022.0, loss: 927.0612745285034
step: 5023.0, loss: 927.1076650619507
step: 5024.0, loss: 928.4850416183472
step: 5025.0, loss: 926.9273319244385
step: 5026.0, loss: 929.258469581604
step: 5027.0, loss: 925.7146406173706
step: 5028.0, loss: 927.5148668289185
step: 5029.0, loss: 927.6539354324341
step: 5030.0, loss: 926.7701587677002
step: 5031.0, loss: 927.7277793884277
step: 5032.0, loss: 927.0305166244507
step: 5033.0, loss: 925.9129209518433
step: 5034.0, loss: 928.6880226135254
step: 5035.0, loss: 926.2261610031128
step: 5036.0, loss: 928.351879119873
step: 5037.0, loss: 930.9259872436523
step: 5038.0, loss: 928.3978385925293
step: 5039.0, loss: 928.0360040664673
step: 5040.0, loss: 926.0653114318848
step: 5041.0, loss: 929.890832901001
step: 5042.0, loss: 930.0204162597656
step: 5043.0, loss: 927.5476980209351
step: 5044.0, loss: 928.1413526535034
step: 5045.0, loss: 929.2116899490356
step: 5046.0, loss: 925.6990623474121
step: 5047.0, l

step: 5238.0, loss: 928.3800859451294
step: 5239.0, loss: 927.9781303405762
step: 5240.0, loss: 928.3550357818604
step: 5241.0, loss: 928.0846166610718
step: 5242.0, loss: 927.5406503677368
step: 5243.0, loss: 926.002462387085
step: 5244.0, loss: 927.5350036621094
step: 5245.0, loss: 928.0909585952759
step: 5246.0, loss: 928.608829498291
step: 5247.0, loss: 926.7859735488892
step: 5248.0, loss: 927.2096862792969
step: 5249.0, loss: 927.0492372512817
step: 5250.0, loss: 931.632472038269
step: 5251.0, loss: 926.8669862747192
step: 5252.0, loss: 927.4257574081421
step: 5253.0, loss: 924.8555374145508
step: 5254.0, loss: 928.3254642486572
step: 5255.0, loss: 928.2244882583618
step: 5256.0, loss: 927.7565317153931
step: 5257.0, loss: 927.4199028015137
step: 5258.0, loss: 927.1826210021973
step: 5259.0, loss: 927.5138282775879
step: 5260.0, loss: 926.7519302368164
step: 5261.0, loss: 927.1589002609253
step: 5262.0, loss: 926.2559328079224
step: 5263.0, loss: 928.0781526565552
step: 5264.0, l

step: 5455.0, loss: 925.6319341659546
step: 5456.0, loss: 925.7371416091919
step: 5457.0, loss: 927.4452648162842
step: 5458.0, loss: 928.037504196167
step: 5459.0, loss: 926.0127534866333
step: 5460.0, loss: 927.025484085083
step: 5461.0, loss: 928.0972528457642
step: 5462.0, loss: 929.5643043518066
step: 5463.0, loss: 927.1318464279175
step: 5464.0, loss: 926.446364402771
step: 5465.0, loss: 926.5553541183472
step: 5466.0, loss: 925.9247121810913
step: 5467.0, loss: 928.9019622802734
step: 5468.0, loss: 929.6385889053345
step: 5469.0, loss: 925.5175180435181
step: 5470.0, loss: 927.9467439651489
step: 5471.0, loss: 927.0958204269409
step: 5472.0, loss: 926.4108448028564
step: 5473.0, loss: 927.5177478790283
step: 5474.0, loss: 925.3011121749878
step: 5475.0, loss: 927.7882738113403
step: 5476.0, loss: 927.5402841567993
step: 5477.0, loss: 928.2749261856079
step: 5478.0, loss: 927.5007028579712
step: 5479.0, loss: 927.7198534011841
step: 5480.0, loss: 926.5547952651978
step: 5481.0, l

step: 5672.0, loss: 926.3144006729126
step: 5673.0, loss: 926.5551023483276
step: 5674.0, loss: 926.122465133667
step: 5675.0, loss: 925.4101886749268
step: 5676.0, loss: 923.734169960022
step: 5677.0, loss: 926.5509576797485
step: 5678.0, loss: 926.0094652175903
step: 5679.0, loss: 925.3033857345581
step: 5680.0, loss: 925.2205200195312
step: 5681.0, loss: 926.3942918777466
step: 5682.0, loss: 926.1440076828003
step: 5683.0, loss: 926.5036611557007
step: 5684.0, loss: 926.4962882995605
step: 5685.0, loss: 926.9912548065186
step: 5686.0, loss: 926.067831993103
step: 5687.0, loss: 926.3124876022339
step: 5688.0, loss: 927.0545749664307
step: 5689.0, loss: 925.0815677642822
step: 5690.0, loss: 926.6814632415771
step: 5691.0, loss: 928.5292539596558
step: 5692.0, loss: 925.5945262908936
step: 5693.0, loss: 926.5992879867554
step: 5694.0, loss: 927.8399238586426
step: 5695.0, loss: 927.5632476806641
step: 5696.0, loss: 928.0248193740845
step: 5697.0, loss: 927.7856550216675
step: 5698.0, l

step: 5889.0, loss: 925.9449119567871
step: 5890.0, loss: 925.8985433578491
step: 5891.0, loss: 925.9684247970581
step: 5892.0, loss: 926.9565267562866
step: 5893.0, loss: 926.6530513763428
step: 5894.0, loss: 925.8254880905151
step: 5895.0, loss: 927.8571996688843
step: 5896.0, loss: 928.1449136734009
step: 5897.0, loss: 926.7371501922607
step: 5898.0, loss: 925.8038911819458
step: 5899.0, loss: 926.5033235549927
step: 5900.0, loss: 924.5322494506836
step: 5901.0, loss: 925.3994159698486
step: 5902.0, loss: 928.2102928161621
step: 5903.0, loss: 927.1822481155396
step: 5904.0, loss: 927.9933423995972
step: 5905.0, loss: 926.4492111206055
step: 5906.0, loss: 926.3422899246216
step: 5907.0, loss: 928.3331184387207
step: 5908.0, loss: 926.3933343887329
step: 5909.0, loss: 925.3531999588013
step: 5910.0, loss: 926.1817007064819
step: 5911.0, loss: 926.1704530715942
step: 5912.0, loss: 925.564190864563
step: 5913.0, loss: 925.1602792739868
step: 5914.0, loss: 924.8855648040771
step: 5915.0,

step: 6106.0, loss: 926.6923198699951
step: 6107.0, loss: 927.7827730178833
step: 6108.0, loss: 927.0943355560303
step: 6109.0, loss: 927.1796455383301
step: 6110.0, loss: 927.0687828063965
step: 6111.0, loss: 926.4156217575073
step: 6112.0, loss: 925.0986938476562
step: 6113.0, loss: 926.5244874954224
step: 6114.0, loss: 925.6469278335571
step: 6115.0, loss: 924.581051826477
step: 6116.0, loss: 925.0303106307983
step: 6117.0, loss: 924.3282880783081
step: 6118.0, loss: 925.6087331771851
step: 6119.0, loss: 926.0888071060181
step: 6120.0, loss: 925.7141246795654
step: 6121.0, loss: 924.1774396896362
step: 6122.0, loss: 924.9632215499878
step: 6123.0, loss: 925.9814739227295
step: 6124.0, loss: 924.3520240783691
step: 6125.0, loss: 926.4185247421265
step: 6126.0, loss: 926.5818948745728
step: 6127.0, loss: 927.9224824905396
step: 6128.0, loss: 925.0171766281128
step: 6129.0, loss: 924.1602010726929
step: 6130.0, loss: 925.1161737442017
step: 6131.0, loss: 924.9676284790039
step: 6132.0,

step: 6323.0, loss: 924.8445062637329
step: 6324.0, loss: 924.0101919174194
step: 6325.0, loss: 925.3958072662354
step: 6326.0, loss: 926.2946634292603
step: 6327.0, loss: 925.8225975036621
step: 6328.0, loss: 925.472785949707
step: 6329.0, loss: 923.4855241775513
step: 6330.0, loss: 926.9584321975708
step: 6331.0, loss: 924.6883115768433
step: 6332.0, loss: 926.0530405044556
step: 6333.0, loss: 925.0236501693726
step: 6334.0, loss: 926.2546768188477
step: 6335.0, loss: 925.1619253158569
step: 6336.0, loss: 924.6172552108765
step: 6337.0, loss: 926.9287233352661
step: 6338.0, loss: 923.2123613357544
step: 6339.0, loss: 924.8553447723389
step: 6340.0, loss: 924.662483215332
step: 6341.0, loss: 926.1748561859131
step: 6342.0, loss: 927.0042734146118
step: 6343.0, loss: 927.3249959945679
step: 6344.0, loss: 923.5956106185913
step: 6345.0, loss: 925.5728988647461
step: 6346.0, loss: 927.9524259567261
step: 6347.0, loss: 926.0359382629395
step: 6348.0, loss: 925.8848218917847
step: 6349.0, 

step: 6540.0, loss: 923.6131496429443
step: 6541.0, loss: 924.2903347015381
step: 6542.0, loss: 923.9341106414795
step: 6543.0, loss: 923.8497905731201
step: 6544.0, loss: 924.2965421676636
step: 6545.0, loss: 926.9124412536621
step: 6546.0, loss: 924.3603363037109
step: 6547.0, loss: 924.863881111145
step: 6548.0, loss: 925.25794506073
step: 6549.0, loss: 924.8901948928833
step: 6550.0, loss: 922.8823041915894
step: 6551.0, loss: 925.2858572006226
step: 6552.0, loss: 926.4309740066528
step: 6553.0, loss: 925.4223041534424
step: 6554.0, loss: 924.252537727356
step: 6555.0, loss: 924.1106719970703
step: 6556.0, loss: 926.1634435653687
step: 6557.0, loss: 925.482494354248
step: 6558.0, loss: 926.230749130249
step: 6559.0, loss: 924.9860992431641
step: 6560.0, loss: 924.9309206008911
step: 6561.0, loss: 924.4993762969971
step: 6562.0, loss: 925.7169933319092
step: 6563.0, loss: 925.8278312683105
step: 6564.0, loss: 925.582986831665
step: 6565.0, loss: 925.0573329925537
step: 6566.0, loss:

step: 6757.0, loss: 925.7204103469849
step: 6758.0, loss: 924.2942247390747
step: 6759.0, loss: 924.9566287994385
step: 6760.0, loss: 924.173921585083
step: 6761.0, loss: 924.9002656936646
step: 6762.0, loss: 925.0846967697144
step: 6763.0, loss: 925.2010707855225
step: 6764.0, loss: 924.8735752105713
step: 6765.0, loss: 925.413984298706
step: 6766.0, loss: 924.5813302993774
step: 6767.0, loss: 926.2609348297119
step: 6768.0, loss: 922.7355518341064
step: 6769.0, loss: 927.044979095459
step: 6770.0, loss: 926.0031261444092
step: 6771.0, loss: 924.3677520751953
step: 6772.0, loss: 925.2375450134277
step: 6773.0, loss: 924.8982877731323
step: 6774.0, loss: 925.1755485534668
step: 6775.0, loss: 923.8886451721191
step: 6776.0, loss: 926.6180381774902
step: 6777.0, loss: 925.1117076873779
step: 6778.0, loss: 926.1805448532104
step: 6779.0, loss: 924.3604021072388
step: 6780.0, loss: 925.2612199783325
step: 6781.0, loss: 923.9717206954956
step: 6782.0, loss: 924.0715637207031
step: 6783.0, l

## Verify things are working

Once you have an initial model trained, try using the following code to query the model for what are the nearest neighbor of a word. This code is intended to help you debug

In [35]:
get_neighbors(model, corpus.word_to_index, "january")

  0%|          | 0/110331 [00:00<?, ?it/s]

[{'word': 'june', 'score': 0.999844491481781},
 {'word': 'february', 'score': 0.9997945427894592},
 {'word': 'november', 'score': 0.9997267723083496},
 {'word': 'july', 'score': 0.9996899962425232},
 {'word': 'march', 'score': 0.9996562004089355},
 {'word': 'september', 'score': 0.9995689392089844},
 {'word': 'august', 'score': 0.9995659589767456},
 {'word': 'october', 'score': 0.999500036239624},
 {'word': 'april', 'score': 0.9994309544563293},
 {'word': 'december', 'score': 0.9990468621253967},
 {'word': '30', 'score': 0.9873328804969788},
 {'word': 'aged', 'score': 0.9869542121887207},
 {'word': '28', 'score': 0.9868108034133911},
 {'word': '26', 'score': 0.9867387413978577},
 {'word': 'windsor_park', 'score': 0.9863975048065186},
 {'word': 'dec', 'score': 0.9863268733024597},
 {'word': 'pandemic', 'score': 0.9861127734184265},
 {'word': '86', 'score': 0.9848241806030273},
 {'word': '27', 'score': 0.9846280217170715},
 {'word': '84', 'score': 0.9844803810119629},
 {'word': '91', 'sc

# Save your model!

Once you have a fully trained model, save it using the code below. Note that we only save the `target_embeddings` from the model, but you could modify the code if you want to save the context vectors--or even try doing fancier things like saving the concatenation of the two or the average of the two!

In [36]:
def save(model, corpus, filename):
    '''
    Saves the model to the specified filename as a gensim KeyedVectors in the
    text format so you can load it separately.
    '''

    # Creates an empty KeyedVectors with our embedding size
#     kv = KeyedVectors(vector_size=model.embedding_size) 
    kv=KeyedVectors(vector_size=50)
    vectors = []
    words = []
    # Get the list of words/vectors in a consistent order
    for index in trange(model.target_embeddings.num_embeddings):
        word = corpus.index_to_word[index]
        vectors.append(model.target_embeddings(torch.LongTensor([index])).detach().numpy()[0])
        words.append(word)

    # Fills the KV object with our data in the right order
    kv.add_vectors(words, vectors) 
    kv.save_word2vec_format(filename, binary=False)


In [None]:
model.target_embeddings.num_embeddings

In [37]:
save(model,corpus,'word2vec_batch32_debias_med_2.kv')

  0%|          | 0/110331 [00:00<?, ?it/s]

# FINAL PART: DO THIS LAST AND READ CAREFULLY

Before you start this part, you need to have a fully working solution and completed the exploratory part of the assignment.

**Once you are ready, create a copy of your working notebook and call it `Debiased Word2Vec.ipynb`. Do not do this part in your working code for the assignment!!!**

## Seriously, save your code in a new file and then start reading the rest of these instructions there.

Ok, hopefully you're reading these in a new file... For this last part of the assignment, we're going to _change_ how word2vec learns at a fundamental level. 

As you might have noticed in your exploratory analysis, the word2vec model learns to weird and sometimes biased associations between words. In particular, your word2vec model has likely learned some unfortunate gender biases, e.g., that the vector for "nurse" is closer to "woman" than "man". The algorithm itself isn't to blame since it is learning these from a corpus (here, Wikipedia biographies) that contain biases already based on how people write. Wikipedia [is](http://markusstrohmaier.info/documents/2015_icwsm2015_wikipedia_gender.pdf) [well](http://dcs.gla.ac.uk/~mounia/Papers/wiki_gender_bias.pdf) [known](https://www.academia.edu/download/64856696/genderanalysisofWikipediabiostext_self_archived.pdf) for having gender biases in how it writes about men and women.


**Side note**: Some times this bias-learning behavior is useful: We can use word2vec to uncover these biases and analyze their trends, like this PNAS paper did for [looking at bias in news writing along multiple dimensions of identity](https://www.pnas.org/content/pnas/115/16/E3635.full.pdf)

In this last part of the homework, we'll ask how we might try to _prevent_ these biases by modifying the training. You won't need to solve this problem by any means, but the act of trying to reduce the biases will open up a whole new toolbox for how you (the experimenter/practioner) can change how and what models learn.

There are many potential ways to _debias_ word embeddings so that their representations are not skewed along one "latent dimension" like gender. In this homework, you'll be trying one of a few different ideas for how to do it. **You are not expected to solve gender bias! This part of the assignment is to have to start grappling with a hard challenge but there is no penalty for doing less-well!** 

One common technique to have models avoid learning bias is similar to another one you already&mdash;**regularization**. In Logistic Regression, we could use L2 regularization to have our model avoid learning $\beta$ weights that are overfit to specific or low-frequency features by adding a regularizer penalty where the larger the weight, the more penalty the model paid. Recall that this forces the model to only pick the most useful (generalizable) weights, since it has to pay a penalty for any non-zero weight. 

In word2vec, we can adapt the idea to think about whether our model's embeddings are closer or farther to different gender dimensions. For example, if we consider the embedding for "president", ideally, we'd want it to be equally similar to the embeddings for "man" and "woman". One idea then is to penalize the model based on how uneven the similarity is. We can do this by directly modifying the loss:
```
loss = loss_criteron(preds, actual_vals) + some_bias_measuring_function(model)
```
Here, the `some_bias_measuring_function` function takes in your model as input and returns how much bias you found. Continuing our example, we might implement it in pseudocode as
```
def some_bias_measuring_function(model):
    pres_woman_sim = cosine_similarity(model, "president", "woman")
    pres_man_sim = cosine_similarity(model, "president", "man")
    return abs(pres_woman_sim - pres_man_sim)
```
This simple example would penalize the model for learning a representation of "president" that is more simular to one of the two words. Of course, this example is overly simple. Why just "president"? Why just "man" and "woman"? Why not other words or other gender-related words or other gender identities?? 

Another idea might be to just make the vectors for "man" and "woman" be as similar as possible:
```
def some_bias_measuring_function(model):
    # cosine similarity is in [-1,1] but we mostly expect it in [0,1]
    man_woman_sim = cosine_similarity(model, "man", "woman")
    # penalize vectors that are not maximally similar, and avoid the edge case 
    # of negative cosine similarity
    return 1 - max(man_woman_sim, 0)
```

All of this works in practice because PyTorch is fantastic about tracking the gradient with respect to the loss. This ability lets us easily define a loss function so that our word2vec model (1) learns to predict the right context words while (2) avoids learning biases. If we compare this code to the numpy part of Homework 1, it's easy to see how powerful PyTorch can be as an experimenter for helping you control what and how your models learn!

Your task is to expand this general approach by coming up with an extension to word2vec that adds some new term to the `loss` value that penalizes bias in the gender dimension. There is no right way to do this and even some right-looking approaches may not work&mdash;or might word but simultaneously destroy the information in the word vectors (all-zero vectors are unbiased but also uninformative!). 

**Suggestion:** You may need to weight your bias term in the loss function (remember that $\lambda_1 x_1 + \lambda_2 x_2$ interpolation? This is sort of similar) so that your debiasing regularizer doesn't overly penalize your model.

Once you have generated your model, record word vector similarities for the pairs listed on canvas in `word-pair-similarity-predictions.csv` where your file writes a result like
```
word1,word2,sim
dog,puppy,0.91234123
woman,preseident,0.81234
```
You'll record the similarity for each pair of words in the file and upload it to CodaLab, which is kind of like Kaggle but lets use a custom scoring program. We'll evaluate your embeddings based on how unbiased they are and how much information they still capture after debiasing. **Your grade does not depend on how well you do in CodaLab, just that you tried something and submitted.** However, the CodaLab leaderboard will hopefully provide a fun and insightful way of comparing just how much bias we can remove from our embeddings.

The CodaLab link will be posted to Piazza