# Knowledge Graph

To build a knowledge graph, different relationships between products and stores are extracted from the CTC datasets. The current focus is to create a dense representation of the data with a high quality knowledge graph. The graph will later be converted into embeddings. 

## Initializing Notebook

In [2]:
import pandas as pd
import numpy as np
import pickle
from sklearn.model_selection import train_test_split
import torch
from torchkge.data_structures import KnowledgeGraph




import torch
from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping
from ignite.metrics import RunningAverage
from torch.optim import Adam

from torchkge.evaluation import LinkPredictionEvaluator
from torchkge.models import TransRModel
from torchkge.models import TransEModel
from torchkge.sampling import BernoulliNegativeSampler
from torchkge.utils import MarginLoss, DataLoader
from torchkge.utils.datasets import load_fb15k
from torchkge.data_structures import KnowledgeGraph
from tqdm.autonotebook import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [17]:
basket_path = "clean_data/cleaned_basket.csv"
df_basket = pd.read_csv(basket_path).drop(columns=['Unnamed: 0'])
df_basket["product_num"].values.astype(int, copy=False)
df_basket["store_num"].values.astype(int, copy=False)

array([491, 491, 491, ..., 133, 133, 441])

In [18]:
products_path = "clean_data/cleaned_products.csv"
df_products = pd.read_csv(products_path).drop(columns=['Unnamed: 0'])


## Sources and Relevant Links

https://www.kaggle.com/code/nageshsingh/build-knowledge-graph-using-python

https://aws-dglke.readthedocs.io/en/latest/kg.html 

https://arxiv.org/abs/2107.07842 

Idea:
- Nodes: products, stores, and hierarchical categories 
- Relationships (Edges): "bought together" (using basket data), "sold in" (store data), "part of"/"subcategory of" (product hierarchical data) 


Product Knowledge Graph Embedding for E-commerce: https://arxiv.org/pdf/1911.12481.pdf 



# 1. Extract Relationships for the Knowledge Graph

In [4]:
def bought_together_pairs(df,groupby,target_column):
    grouped_df = df.groupby(groupby)
    all_pairs = [] #source,target

    for key, item in grouped_df:
        groups = item[target_column].unique()
        if len(groups)>1:
            for index, source in enumerate(groups):
                for target in groups[index+1:]:
                    all_pairs.append([source,target, "bought together"])
    return all_pairs

### Extract relationships from the Basket Data 

The basket data allows us to see which products have been bought together. Thus, we can extract relationships of the form '[product A, product B, "bought together]'. The triplets are stored in the list called bought_together. Extracting these relationships could take some time, so the data has been pickled so the extraction does not need to occur every time the notebook is run.

In [24]:
#specify "product" in front of product nums and "store" in front of store nums for clarity
df_basket['product_num'] = df_basket['product_num'].apply(lambda x: "{}{}".format('product ', x))
df_basket['store_num'] = df_basket['store_num'].apply(lambda x: "{}{}".format('store ', x))


In [25]:
df_basket

Unnamed: 0,basket_id,transaction_date,store_num,product_num,sales_qty
0,0000004cae04f049fdc627711e8a598c,2021-09-03,store 491,product 1517937,1
1,0000004cae04f049fdc627711e8a598c,2021-09-03,store 491,product 1531230,1
2,0000004cae04f049fdc627711e8a598c,2021-09-03,store 491,product 591626,1
3,0000004cae04f049fdc627711e8a598c,2021-09-03,store 491,product 1531549,2
4,0000004cae04f049fdc627711e8a598c,2021-09-03,store 491,product 1510858,1
...,...,...,...,...,...
27928321,ffffe37cd7a10396dd7c1a6622339949,2022-08-25,store 133,product 390512,1
27928322,ffffe37cd7a10396dd7c1a6622339949,2022-08-25,store 133,product 467554,1
27928323,ffffe37cd7a10396dd7c1a6622339949,2022-08-25,store 133,product 813711,2
27928324,ffffe37cd7a10396dd7c1a6622339949,2022-08-25,store 133,product 8420114,2


Extracting the "bought together" relationship from the basket data takes a while to run. We created a pickle data file to load instead.

In [7]:

'''
bought_together = bought_together_pairs(df_basket,"basket_id","product_num") #source,target
len(bought_together)
'''


'\nbought_together = bought_together_pairs(df_basket,"basket_id","product_num") #source,target\nlen(bought_together)\n'

In [8]:

'''
#pickle the data
with open('embeddings/pickle_bought_together.data', 'wb') as f:
        pickle.dump(bought_together, f)

'''


"\n#pickle the data\nwith open('embeddings/pickle_bought_together.data', 'wb') as f:\n        pickle.dump(bought_together, f)\n\n"

In [16]:
#open the pickled data and save
infile = open('embeddings/pickle_bought_together.data','rb')
bought_together = pickle.load(infile)
infile.close()

### Extract relationships from Product Standard Data

The product standard dataset will be filtered down to the products which appear in the basket data. Three relationships will be extracted to capture products that are part of categories, and categories that are subcategories of larger ones.

[ctr_product_num, merch_bus_cat_nm, "part of"]

[merch_bus_cat_nm, merch_lob_nm, "subcategory of"] 

[merch_lob_nm, merch_division_nm, "subcategory of"] 

In [10]:
df_products

Unnamed: 0,ctr_product_num,ctr_style_name,short_desc,long_desc,merch_division_nm,merch_lob_nm,merch_bus_cat_nm,merch_subcat_nm,merch_fineline_nm,corporate_status_cd,...,ctr_product_profile_cd,ctr_consumer_role_cd,package_depth_qty,package_height_qty,package_width_qty,package_volume_qty,package_weight_qty,national_consumer_price_amt,cold_sensitive_ind,heat_sensitive_ind
0,81282,,LT285/65R20 127Q E B,Goodyear Wrangler MTR Kevlar LT285/65R20 127Q ...,AUTOMOTIVE,TIRES,ALL TERRAIN TIRES,Special Order All Terrain Truck & SUV Tires,SOP Goodyear All Terrain Tires,ACT,...,JOB_JOY,EMERG_DESTINATION,1.0,1.0,1.0,0.000579,1.000,0.00,N,N
1,82263,,LT265/70R17 C WR ATS,,AUTOMOTIVE,TIRES,ALL TERRAIN TIRES,Special Order All Terrain Truck & SUV Tires,SOP Goodyear All Terrain Tires,FD,...,JOB_JOY,EMERG_DESTINATION,31.6,10.4,31.6,6.009852,1.000,268.99,N,N
2,121236,,28941 DF CONVERTER,,AUTOMOTIVE,HEAVY AUTO PARTS,EXHAUST,Emission Control,Ultra (direct Fit) Converters,FD,...,JOB_JOY,DESTINATION,27.4,6.3,6.3,0.629344,7.300,165.35,N,N
3,126085,,349531 EXHAUST PIPE,,AUTOMOTIVE,HEAVY AUTO PARTS,EXHAUST,Exhaust Pipes,Exhaust Pipes,FD,...,JOB_JOY,DESTINATION,40.5,4.5,6.2,0.653906,4.850,77.99,N,N
4,126331,,369826 EXHAUST PIPE,,AUTOMOTIVE,HEAVY AUTO PARTS,EXHAUST,Exhaust Pipes,Exhaust Pipes,FD,...,JOB_JOY,DESTINATION,69.3,8.0,20.0,6.416667,7.400,51.99,N,N
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
824501,8517399,,BLK/BONE SHRUG,BLK/BONE SHRUG,SEASONAL & GARDENING,PARTY CITY SEASONAL,PARTY CITY HALLOWEEN & HARVEST,Party City Halloween Accessories,Party City Halloween Accessories Themed,FD,...,USABLE,EMERG_DESTINATION,13.2,3.0,9.5,0.072569,0.330,13.00,N,N
824502,8520520,,AD LG/XL SHIRT SPONG,SpongeBob Fitted T-Shirt,SEASONAL & GARDENING,PARTY CITY SEASONAL,PARTY CITY HALLOWEEN & HARVEST,Party City Halloween Accessories,Party City Halloween Accessories Licensed,FD,...,USABLE,EMERG_DESTINATION,11.5,1.0,17.0,0.037712,0.367,10.00,N,N
824503,8520752,,D563 AD 18-20 BLESSE,Adult Blessed Babe Nun Costume Plus Size,SEASONAL & GARDENING,PARTY CITY SEASONAL,PARTY CITY HALLOWEEN COSTUMES,Party City Costumes Adult,Party City Costumes Women,FD,...,,,12.0,1.5,9.5,0.049479,0.825,29.99,N,N
824504,8527987,,VAL GLTR FOAM STKRS,Valentines Glitter Foam Heart Stickers 285ct,SEASONAL & GARDENING,PARTY CITY SEASONAL,PARTY CITY MICRO SEASON DCOR,Party City Valentine's Day,Party City Valentine's Day Favours and Cards,TD,...,USABLE,EMERG_DESTINATION,12.6,5.4,9.0,0.059062,0.217,0.00,N,N


In [19]:
#add the word "product" before product num (differentiate from store nums later on)
df_products['ctr_product_num'] = df_products['ctr_product_num'].apply(lambda x: "{}{}".format('product ', x))

In [20]:
#filter products to those that appear in basket data
df_products_basket = df_products[df_products.ctr_product_num.isin(df_basket.product_num.unique())]
len(df_products_basket)

0

In [23]:
df_products_basket

Unnamed: 0,ctr_product_num,ctr_style_name,short_desc,long_desc,merch_division_nm,merch_lob_nm,merch_bus_cat_nm,merch_subcat_nm,merch_fineline_nm,corporate_status_cd,...,ctr_product_profile_cd,ctr_consumer_role_cd,package_depth_qty,package_height_qty,package_width_qty,package_volume_qty,package_weight_qty,national_consumer_price_amt,cold_sensitive_ind,heat_sensitive_ind


Modified the get_pairs function to include the relation that is being captured as an input parameter

In [21]:
def get_pairs_relation(df,groupby,target_column,relation):
    grouped_df = df.groupby(groupby)
    all_pairs = [] #source,target

    for source, item in grouped_df:
        groups = item[target_column].unique()
        for index, target in enumerate(groups):
            all_pairs.append([source, target, relation])
    return all_pairs

In [22]:
#[product num, business category, relation: "part of"]
prod_buscat = get_pairs_relation(df_products_basket, "ctr_product_num", "merch_bus_cat_nm", "part of")
prod_buscat

[]

In [16]:
#[merch bus cat, merch lob nm, relation: "subcategory of"]
buscat_lob = get_pairs_relation(df_products_basket, "merch_bus_cat_nm", "merch_lob_nm", "subcategory of")
buscat_lob

[['AIR CONDITIONING CHEMICALS', 'HEAVY AUTO PARTS', 'subcategory of'],
 ['AIR TOOLS & ACCESSORIES', 'TOOLS', 'subcategory of'],
 ['ALL SEASON TIRES', 'TIRES', 'subcategory of'],
 ['ALL TERRAIN TIRES', 'TIRES', 'subcategory of'],
 ['ALL WEATHER TIRES', 'TIRES', 'subcategory of'],
 ['AMMUNITION', 'HUNTING', 'subcategory of'],
 ['ASOTV', 'AS SEEN ON TV', 'subcategory of'],
 ['AUTO ACCESSORIES', 'CAR CARE & ACCESSORIES', 'subcategory of'],
 ['AUTO BATTERIES', 'LIGHT AUTO PARTS', 'subcategory of'],
 ['AUTO BATTERY ACCESSORIES', 'LIGHT AUTO PARTS', 'subcategory of'],
 ['AUTO BODY REPAIR', 'CAR CARE & ACCESSORIES', 'subcategory of'],
 ['AUTO CLEANING - CHEMICALS', 'CAR CARE & ACCESSORIES', 'subcategory of'],
 ['AUTO CLEANING - TOOLS', 'CAR CARE & ACCESSORIES', 'subcategory of'],
 ['AUTO COMFORT', 'CAR CARE & ACCESSORIES', 'subcategory of'],
 ['AUTO ELECTRONICS', 'CAR CARE & ACCESSORIES', 'subcategory of'],
 ['AUTO FLUIDS', 'AUTO MAINTENANCE', 'subcategory of'],
 ['AUTO HEATING & COOLING PARTS

In [17]:
lob_division = get_pairs_relation(df_products_basket, "merch_lob_nm", "merch_division_nm", "subcategory of")
lob_division

[['AS SEEN ON TV', 'LIVING', 'subcategory of'],
 ['AUTO MAINTENANCE', 'AUTOMOTIVE', 'subcategory of'],
 ['AUTOMOTIVE OUTDOOR ADVENTURE', 'AUTOMOTIVE', 'subcategory of'],
 ['BACKYARD FUN', 'SEASONAL & GARDENING', 'subcategory of'],
 ['BACKYARD LIVING', 'SEASONAL & GARDENING', 'subcategory of'],
 ['CAMPING', 'PLAYING', 'subcategory of'],
 ['CAR CARE & ACCESSORIES', 'AUTOMOTIVE', 'subcategory of'],
 ['CLEANING', 'LIVING', 'subcategory of'],
 ['CYCLING', 'PLAYING', 'subcategory of'],
 ['ELECTRICAL', 'FIXING', 'subcategory of'],
 ['ELECTRONICS', 'LIVING', 'subcategory of'],
 ['EXERCISE', 'PLAYING', 'subcategory of'],
 ['FISHING', 'PLAYING', 'subcategory of'],
 ['FOOD & DRINK', 'LIVING', 'subcategory of'],
 ['FOOTWEAR & APPAREL', 'PLAYING', 'subcategory of'],
 ['GARDENING', 'SEASONAL & GARDENING', 'subcategory of'],
 ['HARDWARE', 'FIXING', 'subcategory of'],
 ['HEAVY AUTO PARTS', 'AUTOMOTIVE', 'subcategory of'],
 ['HOCKEY', 'PLAYING', 'subcategory of'],
 ['HOME DCOR', 'LIVING', 'subcategory 

### Extract store semantics 

In [18]:
in_store = get_pairs_relation(df_basket,"product_num", "store_num", "sold at")

In [19]:
in_store

[['product 100011', 'store 145', 'sold at'],
 ['product 100011', 'store 365', 'sold at'],
 ['product 100011', 'store 110', 'sold at'],
 ['product 100011', 'store 290', 'sold at'],
 ['product 100011', 'store 478', 'sold at'],
 ['product 100011', 'store 339', 'sold at'],
 ['product 100011', 'store 304', 'sold at'],
 ['product 100011', 'store 58', 'sold at'],
 ['product 100011', 'store 351', 'sold at'],
 ['product 100011', 'store 302', 'sold at'],
 ['product 100011', 'store 83', 'sold at'],
 ['product 100011', 'store 333', 'sold at'],
 ['product 100011', 'store 344', 'sold at'],
 ['product 100011', 'store 481', 'sold at'],
 ['product 100011', 'store 406', 'sold at'],
 ['product 100011', 'store 305', 'sold at'],
 ['product 100011', 'store 651', 'sold at'],
 ['product 100011', 'store 111', 'sold at'],
 ['product 100011', 'store 453', 'sold at'],
 ['product 100011', 'store 398', 'sold at'],
 ['product 100011', 'store 288', 'sold at'],
 ['product 100011', 'store 916', 'sold at'],
 ['product 1

# 2. Construct the Knowledge Graph

Combine all extracted relationships and put them into a dataframe for the KG. It must have three columns: "from", "to", and "rel" to describe pairs of nodes and their relationships.

Note: currently, the product and business category relation (prod_buscat) causes a strange error, so it is excluded from this version of the knowledge graph.

In [20]:
#Combine the extracted relationships into one list
all_rels = bought_together  + buscat_lob + lob_division + in_store # prod_buscat
len(all_rels)

59218851

In [21]:
all_rels

[['product 1517937', 'product 1531230', 'bought together'],
 ['product 1517937', 'product 591626', 'bought together'],
 ['product 1517937', 'product 1531549', 'bought together'],
 ['product 1517937', 'product 1510858', 'bought together'],
 ['product 1531230', 'product 591626', 'bought together'],
 ['product 1531230', 'product 1531549', 'bought together'],
 ['product 1531230', 'product 1510858', 'bought together'],
 ['product 591626', 'product 1531549', 'bought together'],
 ['product 591626', 'product 1510858', 'bought together'],
 ['product 1531549', 'product 1510858', 'bought together'],
 ['product 422578', 'product 426805', 'bought together'],
 ['product 422578', 'product 1531669', 'bought together'],
 ['product 426805', 'product 1531669', 'bought together'],
 ['product 670830', 'product 538011', 'bought together'],
 ['product 6533508', 'product 852275', 'bought together'],
 ['product 6533508', 'product 1530397', 'bought together'],
 ['product 852275', 'product 1530397', 'bought toge

In [22]:
# extract subject
source = [i[0] for i in all_rels] 

# extract object
target = [i[1] for i in all_rels] 

# state relationship
relations = [i[2] for i in all_rels] 

In [31]:
#assign to dataframe
kg_df = pd.DataFrame({'from':source, 'to':target, 'rel':relations})
kg_df

Unnamed: 0,from,to,rel
0,product 1517937,product 1531230,bought together
1,product 1517937,product 591626,bought together
2,product 1517937,product 1531549,bought together
3,product 1517937,product 1510858,bought together
4,product 1531230,product 591626,bought together
...,...,...,...
59218846,product 997809,store 218,sold at
59218847,product 997809,store 175,sold at
59218848,product 999765,store 175,sold at
59218849,product 999999,store 224,sold at


In [32]:
#split the df into train val and test 
df_train, df_val, df_test = np.split(kg_df.sample(frac=1, random_state=42), [int(.6*len(kg_df)), int(.8*len(kg_df))])

Creating the knowledge graph takes about 32 minutes. The train, val and test kgs have been pickeld for easier future use. If you wish to redo the KG generation and pickling, uncomment the two cells below.

In [33]:

# Turn into knowledge graph
kg_train = KnowledgeGraph(df=df_train)
kg_val = KnowledgeGraph(df=df_val)
kg_test = KnowledgeGraph(df=df_test)


In [23]:
'''
#pickle the train, val, and test KGs
with open('embeddings/pickle_kg_train.data', 'wb') as f:
        pickle.dump(kg_train, f)

with open('embeddings/pickle_kg_val.data', 'wb') as f:
        pickle.dump(kg_val, f)

with open('embeddings/pickle_kg_test.data', 'wb') as f:
        pickle.dump(kg_test, f)

'''

"\n#pickle the train, val, and test KGs\nwith open('embeddings/pickle_kg_train.data', 'wb') as f:\n        pickle.dump(kg_train, f)\n\nwith open('embeddings/pickle_kg_val.data', 'wb') as f:\n        pickle.dump(kg_val, f)\n\nwith open('embeddings/pickle_kg_test.data', 'wb') as f:\n        pickle.dump(kg_test, f)\n\n"

The pickles could take a few minutes to open

In [6]:
#open the pickles
infile = open('embeddings/pickle_kg_train.data','rb')
kg_train = pickle.load(infile)
infile.close()

infile = open('embeddings/pickle_kg_val.data','rb')
kg_val = pickle.load(infile)
infile.close()

infile = open('embeddings/pickle_kg_test.data','rb')
kg_test = pickle.load(infile)
infile.close()

# 3. Create Knowledge Graph Embedding using torchkge

https://torchkge.readthedocs.io/en/latest/

Training with Ignite, following https://torchkge.readthedocs.io/en/latest/tutorials/training.html



https://kge-tutorial-ecai2020.github.io/ECAI-20_KGE_tutorial.pdf


The conversion of knowledge graphs into embeddings is still in progress. The intuition is to choose a translational model and attempt link prediction between the nodes.

In [31]:
'''
def process_batch(engine, batch):
    h, t, r = batch[0], batch[1], batch[2]
    n_h, n_t = sampler.corrupt_batch(h, t, r)

    optimizer.zero_grad()

    pos, neg = model(h, t, r, n_h, n_t)
    loss = criterion(pos, neg)
    loss.backward()
    optimizer.step()

    return loss.item()


def linkprediction_evaluation(engine):
    model.normalize_parameters()

    loss = engine.state.output

    # validation MRR measure
    if engine.state.epoch % eval_epoch == 0:
        evaluator = LinkPredictionEvaluator(model, kg_val)
        evaluator.evaluate(b_size=256, verbose=False)
        val_mrr = evaluator.mrr()[1]
    else:
        val_mrr = 0

    print('Epoch {} | Train loss: {}, Validation MRR: {}'.format(
        engine.state.epoch, loss, val_mrr))

    try:
        if engine.state.best_mrr < val_mrr:
            engine.state.best_mrr = val_mrr
        return val_mrr

    except AttributeError as e:
        if engine.state.epoch == 1:
            engine.state.best_mrr = val_mrr
            return val_mrr
        else:
            raise e
'''

"\ndef process_batch(engine, batch):\n    h, t, r = batch[0], batch[1], batch[2]\n    n_h, n_t = sampler.corrupt_batch(h, t, r)\n\n    optimizer.zero_grad()\n\n    pos, neg = model(h, t, r, n_h, n_t)\n    loss = criterion(pos, neg)\n    loss.backward()\n    optimizer.step()\n\n    return loss.item()\n\n\ndef linkprediction_evaluation(engine):\n    model.normalize_parameters()\n\n    loss = engine.state.output\n\n    # validation MRR measure\n    if engine.state.epoch % eval_epoch == 0:\n        evaluator = LinkPredictionEvaluator(model, kg_val)\n        evaluator.evaluate(b_size=256, verbose=False)\n        val_mrr = evaluator.mrr()[1]\n    else:\n        val_mrr = 0\n\n    print('Epoch {} | Train loss: {}, Validation MRR: {}'.format(\n        engine.state.epoch, loss, val_mrr))\n\n    try:\n        if engine.state.best_mrr < val_mrr:\n            engine.state.best_mrr = val_mrr\n        return val_mrr\n\n    except AttributeError as e:\n        if engine.state.epoch == 1:\n           

In [32]:
'''
infile = open('embeddings/pickle_bought_together.data','rb')
bought_together_pairs = pickle.load(infile)
infile.close()

# extract subject
source = [i[0] for i in bought_together_pairs]

# extract object
target = [i[1] for i in bought_together_pairs]

# state relationship
relations = ["bought together" for i in bought_together_pairs]

kg_df = pd.DataFrame({'from':source, 'to':target, 'rel':relations})

df_train, df_val, df_test = np.split(kg_df.sample(frac=1, random_state=42), [int(.6*len(kg_df)), int(.8*len(kg_df))])
'''

'\ninfile = open(\'embeddings/pickle_bought_together.data\',\'rb\')\nbought_together_pairs = pickle.load(infile)\ninfile.close()\n\n# extract subject\nsource = [i[0] for i in bought_together_pairs]\n\n# extract object\ntarget = [i[1] for i in bought_together_pairs]\n\n# state relationship\nrelations = ["bought together" for i in bought_together_pairs]\n\nkg_df = pd.DataFrame({\'from\':source, \'to\':target, \'rel\':relations})\n\ndf_train, df_val, df_test = np.split(kg_df.sample(frac=1, random_state=42), [int(.6*len(kg_df)), int(.8*len(kg_df))])\n'

### Iteration 0: Bought Together KG

In [33]:
# Turn into knowledge graph
#kg_train = KnowledgeGraph(df=df_train)


In [34]:
#kg_val = KnowledgeGraph(df=df_val)
#kg_test = KnowledgeGraph(df=df_test)

In [35]:
#with open('embeddings/pickle_merch_pairs_kg_train.data', 'wb') as f:
#        pickle.dump(kg_train, f)

In [36]:
#with open('embeddings/pickle_merch_pairs_kg_val.data', 'wb') as f:
#        pickle.dump(kg_val, f)

In [37]:
#with open('embeddings/pickle_merch_pairs_kg_test.data', 'wb') as f:
#        pickle.dump(kg_test, f)

In [38]:
'''
infile = open('embeddings/pickle_merch_pairs_kg_train.data','rb')
kg_train = pickle.load(infile)
infile.close()

infile = open('embeddings/pickle_merch_pairs_kg_val.data','rb')
kg_val = pickle.load(infile)
infile.close()

infile = open('embeddings/pickle_merch_pairs_kg_test.data','rb')
kg_test = pickle.load(infile)
infile.close()
'''

"\ninfile = open('embeddings/pickle_merch_pairs_kg_train.data','rb')\nkg_train = pickle.load(infile)\ninfile.close()\n\ninfile = open('embeddings/pickle_merch_pairs_kg_val.data','rb')\nkg_val = pickle.load(infile)\ninfile.close()\n\ninfile = open('embeddings/pickle_merch_pairs_kg_test.data','rb')\nkg_test = pickle.load(infile)\ninfile.close()\n"

In [38]:
'''
import os

# prevent memory issue
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = 'max_split_size_mb:512'
'''

'\nimport os\n\n# prevent memory issue\nos.environ["PYTORCH_CUDA_ALLOC_CONF"] = \'max_split_size_mb:512\'\n'

In [4]:

ent_emb_dim = 1000
rel_emb_dim = 3
lr = 0.0005
n_epochs = 100
b_size = 32768
margin = 0.5


In [5]:

# Define the model and criterion
model = TransRModel(ent_emb_dim,rel_emb_dim, kg_train.n_ent, kg_train.n_rel)
criterion = MarginLoss(margin)

# Move everything to CUDA if available
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    model.cuda()
    criterion.cuda()

# Define the torch optimizer to be used
optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)

sampler = BernoulliNegativeSampler(kg_train)
dataloader = DataLoader(kg_train, batch_size=b_size, use_cuda='all')



In [6]:

iterator = tqdm(range(n_epochs), unit='epoch')
for epoch in iterator:
    running_loss = 0.0
    for i, batch in enumerate(dataloader):
        h, t, r = batch[0], batch[1], batch[2]
        n_h, n_t = sampler.corrupt_batch(h, t, r)

        optimizer.zero_grad()

        # forward + backward + optimize
        pos, neg = model(h, t, r, n_h, n_t)
        loss = criterion(pos, neg)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    iterator.set_description(
        'Epoch {} | mean loss: {:.5f}'.format(epoch + 1,
                                              running_loss / len(dataloader)))

model.normalize_parameters()


Epoch 100 | mean loss: 2436.30601: 100%|██████████| 100/100 [4:38:47<00:00, 167.27s/epoch]


In [7]:
torch.save(model.state_dict(), "embeddings/transr_state_dict_version_2.pt")

In [8]:
model

TransRModel(
  (ent_emb): Embedding(143618, 1000)
  (rel_emb): Embedding(3, 3)
  (proj_mat): Embedding(3, 3000)
)

In [7]:
model = TransRModel(ent_emb_dim,rel_emb_dim, kg_train.n_ent, kg_train.n_rel)  
model.load_state_dict(torch.load("embeddings/transr_state_dict_version_2.pt"))
model.eval()

TransRModel(
  (ent_emb): Embedding(143618, 1000)
  (rel_emb): Embedding(3, 3)
  (proj_mat): Embedding(3, 3000)
)

In [13]:
entity_emb,rel_emb, proj_mat = model.get_embeddings()

In [14]:
entity_emb

tensor([[-3.1623e-02, -3.1623e-02, -3.1623e-02,  ..., -3.1623e-02,
          3.1623e-02,  3.1623e-02],
        [-8.9981e-09,  1.9842e-12,  1.0370e-41,  ..., -3.6937e-20,
         -2.0670e-13,  1.1155e-07],
        [-1.8619e-03, -9.3253e-03,  8.6880e-03,  ...,  2.9323e-02,
          1.8425e-02, -6.3541e-04],
        ...,
        [-5.6488e-03,  3.4005e-04,  7.5890e-03,  ..., -1.5659e-04,
         -7.7505e-03,  1.5967e-02],
        [-1.1414e-02,  2.2066e-02, -7.1478e-03,  ..., -1.2597e-02,
         -1.2898e-02, -1.9814e-03],
        [ 1.4412e-02, -3.9612e-03, -1.5794e-02,  ..., -2.1708e-02,
         -2.1330e-02, -5.6367e-03]])

In [None]:
kg_train.evaluate_dicts