In [None]:
import os
import random
import json
import pickle
from tqdm import tqdm

from collections import defaultdict

## Read raw files

In [None]:
# read train file: one JSON object per line (despite the .tsv extension)
train_set = []

# JSON text is UTF-8 (RFC 8259); being explicit avoids depending on the platform locale.
with open('raw/train.tsv', encoding='utf-8') as f:
    readin = f.readlines()
    for line in tqdm(readin):
        if line.strip():  # skip empty lines (e.g. an extra blank line at EOF) instead of crashing json.loads
            train_set.append(json.loads(line))

In [None]:
# read val file: one JSON object per line (despite the .tsv extension)
val_set = []

# JSON text is UTF-8 (RFC 8259); being explicit avoids depending on the platform locale.
with open('raw/val.tsv', encoding='utf-8') as f:
    readin = f.readlines()
    for line in tqdm(readin):
        if line.strip():  # skip empty lines (e.g. an extra blank line at EOF) instead of crashing json.loads
            val_set.append(json.loads(line))

In [None]:
# read test file: one JSON object per line (despite the .tsv extension)
test_set = []

# JSON text is UTF-8 (RFC 8259); being explicit avoids depending on the platform locale.
with open('raw/test.tsv', encoding='utf-8') as f:
    readin = f.readlines()
    for line in tqdm(readin):
        if line.strip():  # skip empty lines (e.g. an extra blank line at EOF) instead of crashing json.loads
            test_set.append(json.loads(line))

In [None]:
# Preview the size and a few parsed records instead of rendering the entire training set.
len(train_set), train_set[:3]

In [None]:
# Build id -> contiguous index vocabularies for authors, tags and mentions from the
# training split only, plus the neighbour lookups used when writing the node files:
#   poi_neighbour_dict:       poi_id -> tweet_ids of training tweets at that POI
#   tweet_id2text_dict:       tweet_id -> tweet_text
#   {mention,tag,author}_neighbour: id -> texts of the training tweets carrying it
#                             (one entry per occurrence)

author_id2idx_dict = {}
tag_id2idx_dict = {}
mention_id2idx_dict = {}
poi_neighbour_dict = defaultdict(list)
tweet_id2text_dict = {}

mention_neighbour = defaultdict(list)
tag_neighbour = defaultdict(list)
author_neighbour = defaultdict(list)

for record in tqdm(train_set):
    tweet_id, text = record['tweet_id'], record['tweet_text']
    poi_neighbour_dict[record['poi_id']].append(tweet_id)
    tweet_id2text_dict[tweet_id] = text

    author_neighbour[record['author_id']].append(text)
    # setdefault assigns the next free index only the first time an id is seen
    author_id2idx_dict.setdefault(record['author_id'], len(author_id2idx_dict))

    for tag in record['tags']:
        tag_neighbour[tag].append(text)
        tag_id2idx_dict.setdefault(tag, len(tag_id2idx_dict))

    for mention in record['mentions']:
        mention_neighbour[mention].append(text)
        mention_id2idx_dict.setdefault(mention, len(mention_id2idx_dict))

print(f'Author:{len(author_id2idx_dict)}, Tags:{len(tag_id2idx_dict)}, Mentions:{len(mention_id2idx_dict)}, POI:{len(poi_neighbour_dict)}.')

## Generate Train/Val/Test with no filtering on metadata

In [None]:
# generate official train file -> Tweet_text/train_text.tsv
# Output line layout (fields tab-separated; the two nodes are joined by the literal
# 4-character separator \$\$ -- backslash, dollar, backslash, dollar):
#   tweet node: tweet_text, 6 x '' (blank tweet slots), 2 x mention_idx, 3 x tag_idx, author_idx
#   poi node:   poi_text, 6 x neighbour tweet_text ('' padded), 6 x '-1' (blank metadata slots)
# The blank slots give both nodes the same 13 fields. Missing mentions/tags are padded with -1.
# The POI's neighbours are the other training tweets at that POI (the centre tweet is removed).
# NOTE(review): assumes tweet/poi texts contain no tab or newline characters -- confirm upstream.

mention_neighbours = 2
tag_neighbours = 3
tweet_neighbours = 6

blank_for_tweet = [''] * tweet_neighbours
blank_for_poi = ['-1'] * (mention_neighbours + tag_neighbours + 1)

# Same runtime value as the literal '\$\$' (backslash-$-backslash-$), spelled without the
# invalid '\$' escape sequence that Python 3.12+ reports as a SyntaxWarning.
node_sep = '\\$\\$'
rng = random.Random(42)  # fixed seed: the sampled neighbours are reproducible on re-run

with open('Tweet_text/train_text.tsv', 'w', encoding='utf-8') as fout:
    for t in tqdm(train_set):
        # tweet centre node: shuffle, keep the first k, pad with -1 up to k
        m = [mention_id2idx_dict[mm] for mm in t['mentions']]
        rng.shuffle(m)
        sampled_mention_n = [str(mm) for mm in (m + [-1] * mention_neighbours)[:mention_neighbours]]

        tags = [tag_id2idx_dict[tt] for tt in t['tags']]
        rng.shuffle(tags)
        sampled_tag_n = [str(tt) for tt in (tags + [-1] * tag_neighbours)[:tag_neighbours]]

        author = [str(author_id2idx_dict[t['author_id']])]
        tw = [t['tweet_text']] + blank_for_tweet + sampled_mention_n + sampled_tag_n + author

        # poi centre node: other training tweets at the same POI, as texts
        tw_n = list(poi_neighbour_dict[t['poi_id']])
        tw_n.remove(t['tweet_id'])  # drop (the first occurrence of) the centre tweet itself
        tw_n = [tweet_id2text_dict[tid] for tid in tw_n]
        rng.shuffle(tw_n)
        sampled_twitter_n = (tw_n + [''] * tweet_neighbours)[:tweet_neighbours]

        poi = [t['poi_text']] + sampled_twitter_n + blank_for_poi

        fout.write('\t'.join(tw) + node_sep + '\t'.join(poi) + '\n')

In [None]:
# generate official validation set -> Tweet_text/val_text.tsv
# Same 13-field-per-node layout as train_text.tsv (nodes joined by the literal \$\$):
#   tweet node: tweet_text, 6 x '', 2 x mention_idx, 3 x tag_idx, author_idx
#   poi node:   poi_text, 6 x neighbour tweet_text ('' padded), 6 x '-1'
# Mentions/tags not in the training vocabulary are dropped before -1 padding; an unseen
# author is written as -1. The POI's neighbours are the training tweets at that POI
# (all '' if the POI never occurs in train).
# NOTE(review): assumes tweet/poi texts contain no tab or newline characters -- confirm upstream.

mention_neighbours = 2
tag_neighbours = 3
tweet_neighbours = 6

blank_for_tweet = [''] * tweet_neighbours
blank_for_poi = ['-1'] * (mention_neighbours + tag_neighbours + 1)

# Same runtime value as the literal '\$\$' (backslash-$-backslash-$), spelled without the
# invalid '\$' escape sequence that Python 3.12+ reports as a SyntaxWarning.
node_sep = '\\$\\$'
rng = random.Random(42)  # fixed seed: the sampled neighbours are reproducible on re-run

with open('Tweet_text/val_text.tsv', 'w', encoding='utf-8') as fout:
    for t in tqdm(val_set):
        # tweet centre node: shuffle, keep the first k, pad with -1 up to k
        m = [mention_id2idx_dict[mm] for mm in t['mentions'] if mm in mention_id2idx_dict]
        rng.shuffle(m)
        sampled_mention_n = [str(mm) for mm in (m + [-1] * mention_neighbours)[:mention_neighbours]]

        tags = [tag_id2idx_dict[tt] for tt in t['tags'] if tt in tag_id2idx_dict]
        rng.shuffle(tags)
        sampled_tag_n = [str(tt) for tt in (tags + [-1] * tag_neighbours)[:tag_neighbours]]

        author = [str(author_id2idx_dict.get(t['author_id'], -1))]
        tw = [t['tweet_text']] + blank_for_tweet + sampled_mention_n + sampled_tag_n + author

        # poi centre node; dict.get does not insert unseen POIs into the defaultdict
        tw_n = [tweet_id2text_dict[tid] for tid in poi_neighbour_dict.get(t['poi_id'], [])]
        rng.shuffle(tw_n)
        sampled_twitter_n = (tw_n + [''] * tweet_neighbours)[:tweet_neighbours]

        poi = [t['poi_text']] + sampled_twitter_n + blank_for_poi

        fout.write('\t'.join(tw) + node_sep + '\t'.join(poi) + '\n')

In [None]:
# generate official test set -> Tweet_text/test_text.tsv
# Same 13-field-per-node layout as train_text.tsv (nodes joined by the literal \$\$):
#   tweet node: tweet_text, 6 x '', 2 x mention_idx, 3 x tag_idx, author_idx
#   poi node:   poi_text, 6 x neighbour tweet_text ('' padded), 6 x '-1'
# Mentions/tags not in the training vocabulary are dropped before -1 padding; an unseen
# author is written as -1. The POI's neighbours are the training tweets at that POI
# (all '' if the POI never occurs in train).
# NOTE(review): assumes tweet/poi texts contain no tab or newline characters -- confirm upstream.

mention_neighbours = 2
tag_neighbours = 3
tweet_neighbours = 6

blank_for_tweet = [''] * tweet_neighbours
blank_for_poi = ['-1'] * (mention_neighbours + tag_neighbours + 1)

# Same runtime value as the literal '\$\$' (backslash-$-backslash-$), spelled without the
# invalid '\$' escape sequence that Python 3.12+ reports as a SyntaxWarning.
node_sep = '\\$\\$'
rng = random.Random(42)  # fixed seed: the sampled neighbours are reproducible on re-run

with open('Tweet_text/test_text.tsv', 'w', encoding='utf-8') as fout:
    for t in tqdm(test_set):
        # tweet centre node: shuffle, keep the first k, pad with -1 up to k
        m = [mention_id2idx_dict[mm] for mm in t['mentions'] if mm in mention_id2idx_dict]
        rng.shuffle(m)
        sampled_mention_n = [str(mm) for mm in (m + [-1] * mention_neighbours)[:mention_neighbours]]

        tags = [tag_id2idx_dict[tt] for tt in t['tags'] if tt in tag_id2idx_dict]
        rng.shuffle(tags)
        sampled_tag_n = [str(tt) for tt in (tags + [-1] * tag_neighbours)[:tag_neighbours]]

        author = [str(author_id2idx_dict.get(t['author_id'], -1))]
        tw = [t['tweet_text']] + blank_for_tweet + sampled_mention_n + sampled_tag_n + author

        # poi centre node; dict.get does not insert unseen POIs into the defaultdict
        tw_n = [tweet_id2text_dict[tid] for tid in poi_neighbour_dict.get(t['poi_id'], [])]
        rng.shuffle(tw_n)
        sampled_twitter_n = (tw_n + [''] * tweet_neighbours)[:tweet_neighbours]

        poi = [t['poi_text']] + sampled_twitter_n + blank_for_poi

        fout.write('\t'.join(tw) + node_sep + '\t'.join(poi) + '\n')

In [None]:
# Save the vocabulary sizes [author_num, tags_num, mentions_num] and the id -> index dicts.
# `with` closes (and flushes) each file deterministically instead of leaking the handle to GC.
vocab_outputs = {
    'Tweet_text/mta_num.pkl': [len(author_id2idx_dict), len(tag_id2idx_dict), len(mention_id2idx_dict)],
    'Tweet_text/author_id2idx_dict.pkl': author_id2idx_dict,
    'Tweet_text/tag_id2idx_dict.pkl': tag_id2idx_dict,
    'Tweet_text/mention_id2idx_dict.pkl': mention_id2idx_dict,
}
for path, obj in vocab_outputs.items():
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

In [None]:
# Save the neighbour-text lookups (id -> texts of the training tweets carrying that id).
# `with` closes (and flushes) each file deterministically instead of leaking the handle to GC.
# NOTE(review): these are defaultdict(list) objects; indexing an unknown key after loading
# inserts an empty list rather than raising KeyError.
neighbour_outputs = {
    'Tweet_text/mention_neighbour.pkl': mention_neighbour,
    'Tweet_text/tag_neighbour.pkl': tag_neighbour,
    'Tweet_text/author_neighbour.pkl': author_neighbour,
}
for path, obj in neighbour_outputs.items():
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

## Generate filtered train/val/test

In [None]:
# Threshold for the filtered (`*_textf{filter_num}.tsv`) files below. In the filtered training
# file, mentions/tags whose mention_neighbour/tag_neighbour list (one entry per training
# occurrence) has fewer than `filter_num` entries are dropped before -1 padding, and the
# author slot is likewise masked to -1 by a length check against this value.
filter_num = 2

In [None]:
# generate official train file (filtered) -> Tweet_text/train_textf{filter_num}.tsv
# Same 13-field-per-node layout as train_text.tsv (nodes joined by the literal
# 4-character separator \$\$ -- backslash, dollar, backslash, dollar):
#   tweet node: tweet_text, 6 x '', 2 x mention_idx, 3 x tag_idx, author_idx
#   poi node:   poi_text, 6 x neighbour tweet_text ('' padded), 6 x '-1'
# Mentions/tags/authors whose *_neighbour list (one entry per training occurrence) has fewer
# than `filter_num` entries are treated as missing (mentions/tags dropped, author -> -1).
# NOTE(review): assumes tweet/poi texts contain no tab or newline characters -- confirm upstream.

mention_neighbours = 2
tag_neighbours = 3
tweet_neighbours = 6

blank_for_tweet = [''] * tweet_neighbours
blank_for_poi = ['-1'] * (mention_neighbours + tag_neighbours + 1)

# Same runtime value as the literal '\$\$' (backslash-$-backslash-$), spelled without the
# invalid '\$' escape sequence that Python 3.12+ reports as a SyntaxWarning.
node_sep = '\\$\\$'
rng = random.Random(42)  # fixed seed: the sampled neighbours are reproducible on re-run

with open(f'Tweet_text/train_textf{filter_num}.tsv', 'w', encoding='utf-8') as fout:
    for t in tqdm(train_set):
        # tweet centre node: keep frequent-enough ids, shuffle, keep the first k, pad with -1
        m = [mention_id2idx_dict[mm] for mm in t['mentions'] if len(mention_neighbour[mm]) >= filter_num]
        rng.shuffle(m)
        sampled_mention_n = [str(mm) for mm in (m + [-1] * mention_neighbours)[:mention_neighbours]]

        tags = [tag_id2idx_dict[tt] for tt in t['tags'] if len(tag_neighbour[tt]) >= filter_num]
        rng.shuffle(tags)
        sampled_tag_n = [str(tt) for tt in (tags + [-1] * tag_neighbours)[:tag_neighbours]]

        # Bug fix: count the author's own training tweets (author_neighbour). Looking the
        # author up in tag_neighbour masked every author to -1 unless its id happened to equal
        # a tag, and silently inserted author ids as empty entries into that defaultdict.
        if len(author_neighbour[t['author_id']]) >= filter_num:
            author = [str(author_id2idx_dict[t['author_id']])]
        else:
            author = ['-1']

        tw = [t['tweet_text']] + blank_for_tweet + sampled_mention_n + sampled_tag_n + author

        # poi centre node: other training tweets at the same POI, as texts
        tw_n = list(poi_neighbour_dict[t['poi_id']])
        tw_n.remove(t['tweet_id'])  # drop (the first occurrence of) the centre tweet itself
        tw_n = [tweet_id2text_dict[tid] for tid in tw_n]
        rng.shuffle(tw_n)
        sampled_twitter_n = (tw_n + [''] * tweet_neighbours)[:tweet_neighbours]

        poi = [t['poi_text']] + sampled_twitter_n + blank_for_poi

        fout.write('\t'.join(tw) + node_sep + '\t'.join(poi) + '\n')

In [None]:
# generate official validation set (filtered variant) -> Tweet_text/val_textf{filter_num}.tsv
# Same 13-field-per-node layout as train_text.tsv (nodes joined by the literal \$\$):
#   tweet node: tweet_text, 6 x '', 2 x mention_idx, 3 x tag_idx, author_idx
#   poi node:   poi_text, 6 x neighbour tweet_text ('' padded), 6 x '-1'
# Mentions/tags not in the training vocabulary are dropped before -1 padding; an unseen
# author is written as -1.
# NOTE(review): filter_num only affects the output filename here; val metadata is filtered
# exactly as in val_text.tsv (training-vocabulary membership only). This may be intentional:
# in the training split, a mention/tag/author with a single occurrence has only the centre
# tweet's own text as its neighbour, which does not apply to val tweets. Confirm.
# NOTE(review): assumes tweet/poi texts contain no tab or newline characters -- confirm upstream.

mention_neighbours = 2
tag_neighbours = 3
tweet_neighbours = 6

blank_for_tweet = [''] * tweet_neighbours
blank_for_poi = ['-1'] * (mention_neighbours + tag_neighbours + 1)

# Same runtime value as the literal '\$\$' (backslash-$-backslash-$), spelled without the
# invalid '\$' escape sequence that Python 3.12+ reports as a SyntaxWarning.
node_sep = '\\$\\$'
rng = random.Random(42)  # fixed seed: the sampled neighbours are reproducible on re-run

with open(f'Tweet_text/val_textf{filter_num}.tsv', 'w', encoding='utf-8') as fout:
    for t in tqdm(val_set):
        # tweet centre node: shuffle, keep the first k, pad with -1 up to k
        m = [mention_id2idx_dict[mm] for mm in t['mentions'] if mm in mention_id2idx_dict]
        rng.shuffle(m)
        sampled_mention_n = [str(mm) for mm in (m + [-1] * mention_neighbours)[:mention_neighbours]]

        tags = [tag_id2idx_dict[tt] for tt in t['tags'] if tt in tag_id2idx_dict]
        rng.shuffle(tags)
        sampled_tag_n = [str(tt) for tt in (tags + [-1] * tag_neighbours)[:tag_neighbours]]

        author = [str(author_id2idx_dict.get(t['author_id'], -1))]
        tw = [t['tweet_text']] + blank_for_tweet + sampled_mention_n + sampled_tag_n + author

        # poi centre node; dict.get does not insert unseen POIs into the defaultdict
        tw_n = [tweet_id2text_dict[tid] for tid in poi_neighbour_dict.get(t['poi_id'], [])]
        rng.shuffle(tw_n)
        sampled_twitter_n = (tw_n + [''] * tweet_neighbours)[:tweet_neighbours]

        poi = [t['poi_text']] + sampled_twitter_n + blank_for_poi

        fout.write('\t'.join(tw) + node_sep + '\t'.join(poi) + '\n')

In [None]:
# generate official test set (filtered variant) -> Tweet_text/test_textf{filter_num}.tsv
# Same 13-field-per-node layout as train_text.tsv (nodes joined by the literal \$\$):
#   tweet node: tweet_text, 6 x '', 2 x mention_idx, 3 x tag_idx, author_idx
#   poi node:   poi_text, 6 x neighbour tweet_text ('' padded), 6 x '-1'
# Mentions/tags not in the training vocabulary are dropped before -1 padding; an unseen
# author is written as -1.
# NOTE(review): filter_num only affects the output filename here; test metadata is filtered
# exactly as in test_text.tsv (training-vocabulary membership only). This may be intentional:
# in the training split, a mention/tag/author with a single occurrence has only the centre
# tweet's own text as its neighbour, which does not apply to test tweets. Confirm.
# NOTE(review): assumes tweet/poi texts contain no tab or newline characters -- confirm upstream.

mention_neighbours = 2
tag_neighbours = 3
tweet_neighbours = 6

blank_for_tweet = [''] * tweet_neighbours
blank_for_poi = ['-1'] * (mention_neighbours + tag_neighbours + 1)

# Same runtime value as the literal '\$\$' (backslash-$-backslash-$), spelled without the
# invalid '\$' escape sequence that Python 3.12+ reports as a SyntaxWarning.
node_sep = '\\$\\$'
rng = random.Random(42)  # fixed seed: the sampled neighbours are reproducible on re-run

with open(f'Tweet_text/test_textf{filter_num}.tsv', 'w', encoding='utf-8') as fout:
    for t in tqdm(test_set):
        # tweet centre node: shuffle, keep the first k, pad with -1 up to k
        m = [mention_id2idx_dict[mm] for mm in t['mentions'] if mm in mention_id2idx_dict]
        rng.shuffle(m)
        sampled_mention_n = [str(mm) for mm in (m + [-1] * mention_neighbours)[:mention_neighbours]]

        tags = [tag_id2idx_dict[tt] for tt in t['tags'] if tt in tag_id2idx_dict]
        rng.shuffle(tags)
        sampled_tag_n = [str(tt) for tt in (tags + [-1] * tag_neighbours)[:tag_neighbours]]

        author = [str(author_id2idx_dict.get(t['author_id'], -1))]
        tw = [t['tweet_text']] + blank_for_tweet + sampled_mention_n + sampled_tag_n + author

        # poi centre node; dict.get does not insert unseen POIs into the defaultdict
        tw_n = [tweet_id2text_dict[tid] for tid in poi_neighbour_dict.get(t['poi_id'], [])]
        rng.shuffle(tw_n)
        sampled_twitter_n = (tw_n + [''] * tweet_neighbours)[:tweet_neighbours]

        poi = [t['poi_text']] + sampled_twitter_n + blank_for_poi

        fout.write('\t'.join(tw) + node_sep + '\t'.join(poi) + '\n')