In [1]:
import math
import random
import csv
import sys
import gc
from tqdm import tqdm
# other users followed by particular user. key:userID,value: set of users
user_follow = {}
# other users who follow this user. key: userID,value: set of users.
follow_user = {}
# how many people this user follow
user_follow_num = {}
# how many people follow this user
follow_user_num = {}

In [2]:
# input train set
with open('train.txt', "r") as f:
    for line in f:
        piece = line.strip().split("\t")
        this_user = piece[0]
        who_follow_this_user = set()
        for i in range(1, len(piece)):
            who_follow_this_user.add(piece[i])
            follow_user.setdefault(piece[i], set(this_user)).add(this_user)
            follow_user_num[piece[i]] = follow_user_num.get(piece[i], 0) + 1
        if len(who_follow_this_user)==0:
            print("user {} has an empty list".format(this_user))
        else:
            user_follow[this_user] = who_follow_this_user
            user_follow_num[this_user] = len(who_follow_this_user)
print("train.txt is inputed")

user 2891796 has an empty list
user 349769 has an empty list
user 4599839 has an empty list
user 862544 has an empty list
user 4613611 has an empty list
user 2841393 has an empty list
user 109695 has an empty list
user 3247090 has an empty list
user 3973524 has an empty list
user 1913301 has an empty list
user 1292187 has an empty list
user 939681 has an empty list
user 4177769 has an empty list
user 3128447 has an empty list
user 4659074 has an empty list
user 509343 has an empty list
user 1325954 has an empty list
user 1349278 has an empty list
user 1710589 has an empty list
user 2852811 has an empty list
user 166306 has an empty list
user 3927240 has an empty list
user 1556453 has an empty list
user 4056276 has an empty list
user 1091416 has an empty list
user 244655 has an empty list
user 4615180 has an empty list
user 2203977 has an empty list
user 2259763 has an empty list
user 3326584 has an empty list
user 2171778 has an empty list
user 3745536 has an empty list
user 1146958 ha

In [3]:
test_set = {}
test_result=[]
with open('test-public.txt', "r") as f:
    # read header and pass
    line = f.readline()
    # read data
    line = f.readline().strip()
    while line:
        piece = line.split("\t")
        test_set[piece[1]] = piece[2]
        test_result.append((piece[1],piece[2]))
        line = f.readline().strip()
print("test-public.txt is inputed.")

test-public.txt is inputed.


In [4]:
metrics = ['CN', 'JC', 'SI', 'SC', 'HP', 'HD', 'LHN', 'RA', 'PA']
"""
mode list:
feature explanation: reference from https://arxiv.org/pdf/1411.5118.pdf
follow means source user follow others
follow by means source user followed by others
intersection: source user follow and target user follow by
reverse_intersection:  source user follow by and target user follow
"""
# modes = ['follow', 'follow_by', 'intersection', 'reverse_intersection']
modes = ['follow_by', 'intersection']

In [5]:
features = ['source', 'target', 'exist',
            'user_follow_num_source', 'follow_user_num_source',
            'follow_user_num_target','v_community']
for mode in modes:
    for metric in metrics:
        features.append(mode + '_' + metric)
print("feature is created with dictionary: /n {}".format(features))

feature is created with dictionary: /n ['source', 'target', 'exist', 'user_follow_num_source', 'follow_user_num_source', 'follow_user_num_target', 'v_community', 'follow_by_CN', 'follow_by_JC', 'follow_by_SI', 'follow_by_SC', 'follow_by_HP', 'follow_by_HD', 'follow_by_LHN', 'follow_by_RA', 'follow_by_PA', 'intersection_CN', 'intersection_JC', 'intersection_SI', 'intersection_SC', 'intersection_HP', 'intersection_HD', 'intersection_LHN', 'intersection_RA', 'intersection_PA']


In [6]:
def calculate_metric(mode, source=None, target=None):
    # if mode is 'follow':
    #     first_list = user_follow.get(source, set()) # safe
    #     second_list = user_follow.get(target, set())
    # elif mode is 'follow_by':
    #     first_list = follow_user.get(source, set()) # safe
    #     second_list = follow_user.get(target, set()) # safe
    # elif mode is 'intersection':
    #     first_list = user_follow.get(source, set()) # safe
    #     second_list = follow_user.get(target, set()) # safe
    # else:
    #     first_list = follow_user.get(source, set()) # safe
    #     second_list = user_follow.get(target, set())

    if mode is 'follow_by':
        first_list = follow_user.get(source, set())  # safe
        second_list = follow_user.get(target, set())  # safe
    elif mode is 'intersection':
        first_list = user_follow.get(source, set())  # safe
        second_list = follow_user.get(target, set())  # safe
    else:
        print('Error, unknown mode {}'.format(mode))
        sys.exit(0)
    intersection = first_list & second_list
    union = first_list | second_list
    PA = len(first_list) * len(second_list)
    CN = float(len(intersection))
    JC = CN / len(union)
    SI = CN / (len(first_list) + len(second_list))
    SC = CN / math.sqrt(len(first_list) * len(second_list))
    HP = CN / min(len(first_list), len(second_list))
    HD = CN / max(len(first_list), len(second_list))
    LHN = CN / PA
    CN = int(CN)
    RA = 0
    # calculate RA
    if mode is "follow_by":
        for i in intersection:
            if user_follow_num.get(i):
                RA += float(1) / user_follow_num.get(i)
    elif mode is "intersection":
        # note that source follow more user, less important this intermediary is
        #           more user follow i, less important this intermediary is
        #           i follow more user , less important this intermediary is
        #           more user follow target, less important this intermediary is
        for i in intersection:
            res = user_follow_num.get(source, 0) * \
                  follow_user_num.get(i, 0) * \
                  user_follow_num.get(i, 0) * \
                  follow_user_num.get(target, 0)
            if res != 0:
                RA += float(1) / res
    else:
        for i in intersection:
            # note that target follow more user, less important this intermediary is
            #           more user follow i, less important this intermediary is
            #           i follow more user , less important this intermediary is
            #           more user follow source, less important this intermediary is
            res = user_follow_num.get(target, 0) * \
                  follow_user_num.get(i, 0) * \
                  user_follow_num.get(i, 0) * \
                  follow_user_num.get(source, 0)
            if res != 0:
                RA += float(1) / res
    return [CN, JC, SI, SC, HP, HD, LHN, RA, PA]

In [7]:
f = zip(follow_user_num.values(),follow_user_num.keys())
sorted(f,reverse=True)

[(4841, '20388'),
 (4637, '2740141'),
 (3654, '1054633'),
 (3432, '487639'),
 (3425, '2120801'),
 (3193, '2984603'),
 (3035, '3528096'),
 (2951, '4847918'),
 (2951, '3377456'),
 (2921, '1139433'),
 (2914, '2544271'),
 (2821, '2809458'),
 (2664, '2619379'),
 (2605, '3088232'),
 (2532, '4007276'),
 (2466, '1324992'),
 (2452, '1086605'),
 (2431, '2332893'),
 (2358, '247783'),
 (2294, '1413055'),
 (2239, '163021'),
 (2203, '4813076'),
 (2152, '1638122'),
 (2098, '1511168'),
 (2090, '448823'),
 (2084, '848620'),
 (2058, '1165817'),
 (2035, '3487024'),
 (2022, '4096612'),
 (2014, '1660431'),
 (2002, '2389556'),
 (1999, '2057013'),
 (1914, '3711578'),
 (1882, '2930986'),
 (1881, '3793690'),
 (1870, '2810556'),
 (1841, '487236'),
 (1838, '658596'),
 (1837, '4586866'),
 (1820, '1802489'),
 (1805, '3554524'),
 (1795, '4475906'),
 (1794, '201039'),
 (1789, '1228736'),
 (1783, '1328064'),
 (1769, '1242327'),
 (1765, '3394554'),
 (1752, '717002'),
 (1748, '282738'),
 (1743, '888836'),
 (1681, '4208

In [8]:
def generate_train_set():
    user_list = list(user_follow.keys())
    follow_list = list(follow_user.keys())
    with open("train_set.csv", "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=features)
        writer.writeheader()
        for i in tqdm(range(50000)):
            try:
                source = random.choice(user_list)
                target = random.choice(list(user_follow.get(source)))
                community = 0
                if '20388' in user_follow[source] and '20388' in follow_user[target]:
                    community = 1
                else:
                    community = 0
                feature_for_this = {'source': source,
                                    'target': target,
                                    'exist': 1,
                                    'user_follow_num_source': user_follow_num.get(source),
                                    'follow_user_num_source': follow_user_num.get(source),
                                    'follow_user_num_target': follow_user_num.get(target),
                                    'v_community':community}
                for mode in modes:
                    for one_metrics in list(zip(calculate_metric(mode=mode, source=source, target=target), metrics)):
                        feature_for_this[mode + '_' + one_metrics[1]] = one_metrics[0]
                writer.writerow(feature_for_this)
            except IndexError:
                pass
        print('positive sample generated')
        gc.collect()
        for i in tqdm(range(50000)):
            try:
                source = random.choice(user_list)
                while (True):
                    target = random.choice(follow_list)
                    if target in (user_follow.get(source)) or test_set.get(source) == target:
                        continue
                    else:
                        break
                community = 0
                if '20388' in user_follow[source] and '20388' in follow_user[target]:
                    community = 1
                else:
                    community = 0
                feature_for_this = {'source': source,
                                    'target': target,
                                    'exist': 0,
                                    'user_follow_num_source': user_follow_num.get(source),
                                    'follow_user_num_source': follow_user_num.get(source),
                                    'follow_user_num_target': follow_user_num.get(target),
                                    'v_community': community}
                for mode in modes:
                    for one_metrics in list(zip(calculate_metric(mode=mode, source=source, target=target), metrics)):
                        feature_for_this[mode + '_' + one_metrics[1]] = one_metrics[0]
                writer.writerow(feature_for_this)
            except IndexError:
                pass
        print('negative samples generated')
        gc.collect()
generate_train_set()

100%|██████████| 50000/50000 [00:20<00:00, 2463.50it/s]
  0%|          | 69/50000 [00:00<01:13, 677.38it/s]

positive sample generated


100%|██████████| 50000/50000 [00:13<00:00, 3633.46it/s]


negative samples generated


In [9]:
def generate_test_set():
    with open("test_set.csv", "w", newline="") as f:
        features.remove('exist')
        writer = csv.DictWriter(f, fieldnames=features)
        writer.writeheader()
        for test_sample in tqdm(test_result):
            community = 0
            if '20388' in user_follow[source] and '20388' in follow_user[target]:
                community = 1
            else:
                community = 0
            feature_for_this = {'source': test_sample[0],
                                'target': test_sample[1],
                                'user_follow_num_source': user_follow_num.get(test_sample[0]),
                                'follow_user_num_source': follow_user_num.get(test_sample[0]),
                                'follow_user_num_target': follow_user_num.get(test_sample[1]),
                                'v_community':community}
            for mode in modes:
                for one_metrics in list(zip(calculate_metric(mode=mode, source=test_sample[0], target=test_sample[1]), metrics)):
                    feature_for_this[mode + '_' + one_metrics[1]] = one_metrics[0]
            writer.writerow(feature_for_this)
        print('testcsv is generated')
generate_test_set()

  0%|          | 0/2000 [00:00<?, ?it/s]


NameError: name 'source' is not defined