In [1]:
import pandas as pd
from tqdm import tqdm
import torch
import pytorch_utils

[32m[PyTorch-Utils][0m: Loading anchor candidate data...
[32m[PyTorch-Utils][0m: Loading wikipedia title embedings...
[32m[PyTorch-Utils][0m: Loading KB explanations...
[32m[PyTorch-Utils][0m: Loading wikipedia items...


In [2]:
# is cuda available?
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
dataset = pytorch_utils.EntityDataset(device='cuda')
test_dataset = pytorch_utils.EntityDataset(train=False, device='cuda')

[32m[PyTorch-Utils][0m: Loading train set...
[32m[PyTorch-Utils][0m: Loading train_context_text_150.pkl from data/pkl...
[32m[PyTorch-Utils][0m: Now generating entity embeddings...


Batches:   0%|          | 0/572 [00:00<?, ?it/s]

Entity Length: 18288
Entity shape: torch.Size([18288, 384])


Batches:   0%|          | 0/572 [00:00<?, ?it/s]

[32m[PyTorch-Utils][0m: Now computing syntax candidates for each entity...
[32m[PyTorch-Utils][0m: Now computing OLD syntax candidates for each entity...
[32m[PyTorch-Utils][0m: Now generating inputs and labels...
[32m[PyTorch-Utils][0m: Loading train_tokenized_inputs_attention_mask.pt and train_tokenized_inputs_input_ids.pt from data/pkl...
[32m[PyTorch-Utils][0m: Loading test set...
[32m[PyTorch-Utils][0m: Loading test_context_text_150.pkl from data/pkl...
[32m[PyTorch-Utils][0m: Now generating entity embeddings...


Batches:   0%|          | 0/287 [00:00<?, ?it/s]

Entity Length: 9166
Entity shape: torch.Size([9166, 384])


Batches:   0%|          | 0/287 [00:00<?, ?it/s]

[32m[PyTorch-Utils][0m: Now computing syntax candidates for each entity...
[32m[PyTorch-Utils][0m: Now computing OLD syntax candidates for each entity...
[32m[PyTorch-Utils][0m: Now generating inputs and labels...
[32m[PyTorch-Utils][0m: Now tokenizing...


In [4]:
pytorch_utils.delete_corpus_embeds()

[32m[PyTorch-Utils][0m: Deleting corpus embeddings...


In [5]:
# garbage collect
import gc

gc.collect()

4192

In [6]:
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=100, shuffle=True, collate_fn=lambda x: pytorch_utils.EntityDataset.collate_fn_train(x, device=device))
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=lambda x: pytorch_utils.EntityDataset.collate_fn_test(x, device=device))

In [7]:
model = pytorch_utils.EntityClassifier(transformer_model='distilbert-base-uncased', hidden_size=256, device=device)
# does distilbert_model.pt exists? If so load the weights
import os
if os.path.exists('distilbert_model_with_weights_more_trained.pt'):
    model.load_state_dict(torch.load('distilbert_model_with_weights_more_trained.pt'))

In [8]:
"""
Freeze everything except for:
transformer.transformer.layer.5.attention.q_lin.weight
transformer.transformer.layer.5.attention.q_lin.bias
transformer.transformer.layer.5.attention.k_lin.weight
transformer.transformer.layer.5.attention.k_lin.bias
transformer.transformer.layer.5.attention.v_lin.weight
transformer.transformer.layer.5.attention.v_lin.bias
transformer.transformer.layer.5.attention.out_lin.weight
transformer.transformer.layer.5.attention.out_lin.bias
transformer.transformer.layer.5.sa_layer_norm.weight
transformer.transformer.layer.5.sa_layer_norm.bias
transformer.transformer.layer.5.ffn.lin1.weight
transformer.transformer.layer.5.ffn.lin1.bias
transformer.transformer.layer.5.ffn.lin2.weight
transformer.transformer.layer.5.ffn.lin2.bias
transformer.transformer.layer.5.output_layer_norm.weight
transformer.transformer.layer.5.output_layer_norm.bias
classifier.0.weight
classifier.0.bias
classifier.3.weight
classifier.3.bias

Basically, we want to fine-tune the last layer of the transformer and the classifier
"""
for name, param in model.named_parameters():
    if 'transformer' in name:
        if 'layer.5' in name:
            param.requires_grad = True
        elif 'pooler' in name:
            param.requires_grad = True
        elif 'layer.4' in name:
            param.requires_grad = False
        else:
            param.requires_grad = False
    elif 'classifier' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

# Check if all parameters are frozen except for the ones we want
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)


transformer.transformer.layer.5.attention.q_lin.weight
transformer.transformer.layer.5.attention.q_lin.bias
transformer.transformer.layer.5.attention.k_lin.weight
transformer.transformer.layer.5.attention.k_lin.bias
transformer.transformer.layer.5.attention.v_lin.weight
transformer.transformer.layer.5.attention.v_lin.bias
transformer.transformer.layer.5.attention.out_lin.weight
transformer.transformer.layer.5.attention.out_lin.bias
transformer.transformer.layer.5.sa_layer_norm.weight
transformer.transformer.layer.5.sa_layer_norm.bias
transformer.transformer.layer.5.ffn.lin1.weight
transformer.transformer.layer.5.ffn.lin1.bias
transformer.transformer.layer.5.ffn.lin2.weight
transformer.transformer.layer.5.ffn.lin2.bias
transformer.transformer.layer.5.output_layer_norm.weight
transformer.transformer.layer.5.output_layer_norm.bias
classifier.0.weight
classifier.0.bias
classifier.2.weight
classifier.2.bias


In [9]:
# loss and optimizer
from torch import optim
# Since for each mention, there are 10 candidates, and only 1 of the candidates is correct, we can use binary cross entropy loss
# As we classify each (mention, candidate) pair as either correct or incorrect
# But because of this we have class imbalance, as there are way more incorrect pairs than correct pairs
# To solve this, we can use weighted binary cross entropy loss
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(8.0).to(device))
optimizer = optim.Adam(model.parameters(), lr=0.001)

# training loop
from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_notebook

desc = f'Epoch {1} loss: ?? Avg shape: ??'
epochs = 0
for epoch in range(epochs):
    running_loss = 0.0
    running_loss_divider = 0
    if epoch == epochs - 1:
        pbar2 = tqdm(enumerate(dataloader), leave=True, total=len(dataloader), desc=desc)
    else:
        pbar2 = tqdm(enumerate(dataloader), leave=False, total=len(dataloader), desc=desc)
    for i, data in pbar2:
        # get the inputs; data is a list of [inputs, labels]
        tokenized_input_id_batch, tokenized_attention_mask_batch, indexes, labels = data
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(tokenized_input_id_batch, tokenized_attention_mask_batch, '')
        # outputs is of shape (batch_size, 1)
        # labels is of shape (batch_size)
        loss = criterion(outputs.squeeze(), labels)
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        running_loss_divider += 1
        running_loss_divided_formatted = f'{(running_loss / running_loss_divider):.4f}'
        desc = f'Epoch {epoch+1} loss: {running_loss_divided_formatted}'
        pbar2.set_description(desc)
        pbar2.update(1)

In [10]:
model.eval()  # Put the model in evaluation mode
mention_to_candidates = {}  # Initialize a dictionary to hold results

with torch.no_grad():  # Disable gradient calculation for inference
    for batch in tqdm(test_dataloader):
        tokenized_inputs_input_ids, tokenized_inputs_attention_mask, indexes, candidate_ids_batch = batch
        probabilities = model(tokenized_inputs_input_ids, tokenized_inputs_attention_mask, '').squeeze()
        
        for i, prob in enumerate(probabilities):
            mention = indexes[i]
            # detach and move to cpu
            mention = mention.detach().cpu().numpy().item()
            candidate_id = candidate_ids_batch[i]  # The candidate ID for this input
            # detach and move to cpu
            candidate_id = candidate_id.detach().cpu().numpy().item()
            
            if mention not in mention_to_candidates:
                mention_to_candidates[mention] = []
            
            mention_to_candidates[mention].append((candidate_id, prob.item()))

100%|██████████| 589/589 [02:39<00:00,  3.70it/s]


In [22]:
len(mention_to_candidates)

9053

In [11]:
# Now select the candidate with the highest probability for each mention
final_disambiguations = {}
for mention, candidates in mention_to_candidates.items():
    # Sort the candidates by probability, select the one with the highest
    best_candidate = max(candidates, key=lambda x: x[1])[0]
    final_disambiguations[mention] = best_candidate

# final_disambiguations now holds the selected candidate ID for each mention

In [12]:
final_disambiguations

{0: 3229147,
 1: 84,
 2: 912881,
 3: 6827914,
 4: 3229147,
 5: 3489807,
 6: 3058749,
 7: 1788018,
 8: 2725632,
 9: 3195284,
 10: 3344790,
 11: 3489807,
 12: 5591359,
 13: 3229147,
 14: 1321565,
 15: 4756511,
 16: 3489807,
 17: 6827914,
 18: 3058749,
 19: 3520174,
 20: 6836032,
 21: 3177444,
 22: 21061609,
 23: 3520174,
 24: 1321565,
 25: 3058749,
 26: 3177444,
 27: 6836032,
 28: 1155836,
 29: 2725632,
 30: 5107238,
 31: 21,
 32: 2725705,
 33: 21,
 34: 6766945,
 35: 2725632,
 36: 1788018,
 37: 2725667,
 38: 408,
 39: 3537597,
 40: 5107238,
 41: 1788018,
 42: 3195284,
 43: 3344790,
 44: 45997,
 45: 6584108,
 46: 3195284,
 47: 1137275,
 48: 84,
 49: 21,
 50: 83065,
 51: 3229147,
 52: 3489807,
 53: 3489807,
 54: 3229147,
 55: 3229147,
 56: 3489807,
 57: 1337792,
 58: 3107254,
 59: 3041569,
 60: 665489,
 61: 3344790,
 62: 3195284,
 63: 84,
 64: 1155836,
 65: 2725705,
 66: 2725632,
 67: 989616,
 68: 3505744,
 69: 2725657,
 70: 72259,
 71: 2725694,
 72: 3126578,
 73: 2725667,
 74: 1788018,
 7

In [14]:
wiki_items = pd.read_csv(DATA_DIR + 'wiki_lite/wiki_items.csv')
# index wiki_items by id
wiki_items = wiki_items.set_index('item_id')
# Create item_id to wikipedia_title map
item_id_to_title = wiki_items['wikipedia_title'].to_dict()

In [15]:
enwiki_redirects = pd.read_csv(DATA_DIR + 'wiki_lite/enwiki_redirects.tsv', sep='\t', header=None, names=['source', 'target'])
# index enwiki_redirects by source
enwiki_redirects = enwiki_redirects.set_index('source')
# create source to target map
source_to_target = enwiki_redirects['target'].to_dict()

In [16]:
wiki_urls = []
not_found = 0
found = 0
redirection = 0
# Now we will map these into wikipedia_urls
for mention, candidate in tqdm(final_disambiguations.items()):
    wikipedia_title = item_id_to_title[candidate]
    # does this wikipedia title exist in the redirects?
    if wikipedia_title in source_to_target:
        # if it does, we will replace it with the redirect
        wikipedia_title = source_to_target[wikipedia_title]
        redirection += 1
    # Now replace the spaces with underscores
    wikipedia_title = wikipedia_title.replace(' ', '_')
    # And add the wikipedia url
    wiki_urls.append(f'http://en.wikipedia.org/wiki/{wikipedia_title}')
    found += 1

print(f'Found {found} wikipedia urls')
print(f'Not found {not_found} wikipedia urls')
print(f'Percentage of wikipedia urls found: {found / (found + not_found)}')
print(f'Percentage of wikipedia urls redirected: {redirection / found}')

100%|██████████| 9053/9053 [00:00<00:00, 934533.59it/s]

Found 9053 wikipedia urls
Not found 0 wikipedia urls
Percentage of wikipedia urls found: 1.0
Percentage of wikipedia urls redirected: 0.003313818623660665





In [32]:
wiki_urls = []
for i in range(test_dataset.entity_df.shape[0]):
    if i in final_disambiguations:
        candidate = final_disambiguations[i]
        wikipedia_title = item_id_to_title[candidate]
        # does this wikipedia title exist in the redirects?
        if wikipedia_title in source_to_target:
            # if it does, we will replace it with the redirect
            wikipedia_title = source_to_target[wikipedia_title]
        # Now replace the spaces with underscores
        wikipedia_title = wikipedia_title.replace(' ', '_')
        # And add the wikipedia url
        wiki_urls.append(f'http://en.wikipedia.org/wiki/{wikipedia_title}')
    else:
        wiki_urls.append('NOT_FOUND')


In [27]:
# load validation prepped csv
import pandas as pd
DATA_DIR = 'data/'
validation_prepped = pd.read_csv(DATA_DIR + 'validation_prepped.csv')
true_wiki_urls = validation_prepped['2'].to_list()

old_submission = pd.read_csv('submission_10_epochs.csv')

# join old_submission and validation_prepped on column called 'id'
old_submission_joined = old_submission.merge(validation_prepped, on='id', how='right')


In [28]:
old_submission_joined

Unnamed: 0,id,wiki_url_x,token,entity_tag,full_mention,wiki_url_y,doc_id,entity_loc,2
0,3,http://en.wikipedia.org/wiki/Leicestershire_Co...,LEICESTERSHIRE,B,LEICESTERSHIRE,?,1.0,2,http://en.wikipedia.org/wiki/Leicestershire_Co...
1,13,http://en.wikipedia.org/wiki/London,LONDON,B,LONDON,?,1.0,11,http://en.wikipedia.org/wiki/London
2,16,http://en.wikipedia.org/wiki/West_Indies_crick...,West,B,West Indian,?,1.0,13,http://en.wikipedia.org/wiki/West_Indies_crick...
3,19,http://en.wikipedia.org/wiki/Phil_Simmons,Phil,B,Phil Simmons,?,1.0,16,http://en.wikipedia.org/wiki/Phil_Simmons
4,28,http://en.wikipedia.org/wiki/Leicestershire_Co...,Leicestershire,B,Leicestershire,?,1.0,25,http://en.wikipedia.org/wiki/Leicestershire_Co...
...,...,...,...,...,...,...,...,...,...
9161,104839,http://en.wikipedia.org/wiki/England_national_...,England,B,England,?,447.0,221,http://en.wikipedia.org/wiki/England_national_...
9162,104851,http://en.wikipedia.org/wiki/Leeds_United_F.C.,Leeds,B,Leeds United,?,447.0,232,http://en.wikipedia.org/wiki/Leeds_United_A.F.C.
9163,104858,http://en.wikipedia.org/wiki/England_national_...,England,B,England,?,447.0,239,http://en.wikipedia.org/wiki/England_national_...
9164,104877,http://en.wikipedia.org/wiki/1966_FIFA_World_Cup,1966,B,1966 World Cup,?,447.0,258,http://en.wikipedia.org/wiki/1966_FIFA_World_Cup


In [29]:
len(old_submission_joined['2'].to_list()), len(old_submission_joined['wiki_url_x'].to_list())

(9166, 9166)

In [30]:
# calculate f1 score between wiki_url_left and wiki_url_right
from sklearn.metrics import f1_score
print(f1_score(old_submission_joined['2'].to_list(), old_submission_joined['wiki_url_x'].to_list(), average='micro'))

0.7794021383373335


In [31]:
print(f1_score(true_wiki_urls, wiki_urls, average='micro'))

0.7794021383373335


In [None]:
# not_found = 0
# found = 0
# train_wiki_urls = []
# # Now we will map these into wikipedia_urls
# for i in tqdm(range(len(predictions_train))):
#     if predictions_train[i] == 0:
#         # if the prediction is 0, we will append a blank url
#         train_wiki_urls.append('NOT_FOUND')
#         not_found += 1
#         continue
#     wikipedia_title = item_id_to_title[predictions_train[i]]
#     # does this wikipedia title exist in the redirects?
#     if wikipedia_title in source_to_target:
#         # if it does, we will replace it with the redirect
#         new_title = source_to_target[wikipedia_title]
#     # Now replace the spaces with underscores
#     wikipedia_title = wikipedia_title.replace(' ', '_')
#     # And add the wikipedia url
#     train_wiki_urls.append(f'http://en.wikipedia.org/wiki/{wikipedia_title}')
#     found += 1

# print(f'Found {found} wikipedia urls')
# print(f'Not found {not_found} wikipedia urls')
# print(f'Percentage of wikipedia urls found: {found / (found + not_found)}')

In [34]:
test = pd.read_csv(DATA_DIR + 'test.csv')
# train = pd.read_csv(DATA_DIR + 'train.csv')

In [35]:
not_nan = test['wiki_url'].notna()
not_nme = test['wiki_url'] != '--NME--'
# train_not_nan = train['wiki_url'].notna()
# train_not_nme = train['wiki_url'] != '--NME--'
test.loc[(not_nan & not_nme) & (test.id == 65002)]

Unnamed: 0,id,token,entity_tag,full_mention,wiki_url
65002,65002,Dejan,B,Dejan Koturovic,?


In [36]:
test.loc[not_nan & not_nme, 'wiki_url'] = wiki_urls
# train.loc[train_not_nan & train_not_nme, 'wiki_url'] = train_wiki_urls

In [37]:
# replace NaN or --NME-- with NOT_FOUND
test['wiki_url'] = test['wiki_url'].fillna('NOT_FOUND')
test['wiki_url'] = test['wiki_url'].replace('--NME--', 'NOT_FOUND')
# train['wiki_url'] = train['wiki_url'].fillna('NOT_FOUND')
# train['wiki_url'] = train['wiki_url'].replace('--NME--', 'NOT_FOUND')

In [38]:
test[test.id == 65002]

Unnamed: 0,id,token,entity_tag,full_mention,wiki_url
65002,65002,Dejan,B,Dejan Koturovic,http://en.wikipedia.org/wiki/Dejan_Koturović


In [39]:
' '.join(test.token.fillna('', inplace=False).to_list()[:250])

"-DOCSTART- (947testa CRICKET) CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .  LONDON 1996-08-30  West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship .  Their stay on top , though , may be short-lived as title rivals Essex , Derbyshire and Surrey all closed in on victory while Kent made up for lost time in their rain-affected match against Nottinghamshire .  After bowling Somerset out for 83 on the opening morning at Grace Road , Leicestershire extended their first innings by 94 runs before being bowled out for 296 with England discard Andy Caddick taking three for 83 .  Trailing by 213 , Somerset got a solid start to their second innings before Simmons stepped in to bundle them out for 174 .  Essex , however , look certain to regain their top spot after Nasser Hussain and Peter Such gave them a firm grip on their match against Yorkshi

In [40]:
# # One problem, Simmons does not retrieve the cricket player, but querying for Phil Simmons does.
# # TODO: Fix this, one idea is to group by per doc id, check if this token has a better previous full mention
# display(wiki_items[wiki_items.index.isin(test_dataset.anchor_to_candidate['simmons'])])

# print('____')

# display(wiki_items[wiki_items.index.isin(test_dataset.anchor_to_candidate['phil simmons'])])

In [41]:
# now create a .csv file from id, wiki_url
test[['id', 'wiki_url']].to_csv('submission_distilbert_epoch10.csv', index=False)
# train[['id', 'wiki_url']].to_csv('train_with_doc.csv', index=False)