In [1]:
# ! pip install wandb

In [1]:

import time
import torch
import math
import argparse
import wandb
import os

import json
import torch.nn as nn
import torch.nn.functional as F


In [2]:

parser = argparse.ArgumentParser()
parser.add_argument('--no_cuda', action='store_true', default=False, help='disables CUDA training (eg. no nvidia GPU)')
parser.add_argument('--epochs', type=int, default=1000, help='number of epochs to train')
# model parameters
parser.add_argument('--model', type=str, default='softbox', help='model type: choose from softbox, gumbel')
parser.add_argument('--box_embedding_dim', type=int, default=40, help='box embedding dimension')
parser.add_argument('--softplus_temp', type=float, default=1.0, help='beta of softplus function')
# gumbel box parameters
parser.add_argument('--gumbel_beta', type=float, default=1.0, help='beta value for gumbel distribution')
parser.add_argument('--scale', type=float, default=1.0, help='scale value for gumbel distribution')

parser.add_argument('--dataset', type=str, default='GALEN', help='dataset')
parser.add_argument('--using_rbox', type=int, default=1, help='using_rbox')
parser.add_argument('--gpu', type=int, default=0, help='gpu')

parser.add_argument('--dimension', type=int, default=50, help='dimension')
parser.add_argument('--learning_rate', type=int, default=0.001, help='learning_rate')
parser.add_argument('--batch_size', type=int, default=256, help='batch_size')
parser.add_argument('--seed', type=int, default=1111, help='seed')

args = parser.parse_args(args=['--no_cuda'] )
args.save_to = "./checkpoints/" + args.model

gpu = args.gpu
dimension = args.dimension
learning_rate = args.learning_rate
batch_size = args.batch_size
seed = args.seed
dataset = args.dataset
using_rbox = args.using_rbox
# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device = 'cpu'


In [90]:

import time
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import uniform

eps = 1e-8

def l2_side_regularizer(box, log_scale: bool = True):
    """Applies l2 regularization on all sides of all boxes and returns the sum.
    """
    min_x = box.min_embed 
    delta_x = box.delta_embed  

    if not log_scale:
        return torch.mean(delta_x ** 2)
    else:
        return torch.mean(F.relu(min_x + delta_x - 1 + eps )) +  torch.mean(F.relu(- min_x - eps)) #+ F.relu(torch.norm(min_x, p=2)-1)

# def l2_volume_regularizer(box, log_scale: bool = True):
#     """Applies l2 regularization on all sides of all boxes and returns the sum.
#     """
#     min_x = box.min_embed 
#     delta_x = box.delta_embed  

#     if not log_scale:
#         return torch.mean(delta_x ** 2)
#     else:
#         return torch.mean(F.relu(min_x + delta_x - 1 + eps )) +  torch.mean(F.relu(- min_x - eps)) #+ F.relu(torch.norm(min_x, p=2)-1)

    
class Box:
    def __init__(self, min_embed, max_embed):
        self.min_embed = min_embed
        self.max_embed = max_embed
        self.delta_embed = max_embed - min_embed
        
class BoxEL(nn.Module):
    def __init__(self, vocab_size, relation_size, embed_dim, min_init_value, delta_init_value, relation_init_value, scaling_init_value, args):
        super(BoxEL, self).__init__()
        min_embedding = self.init_concept_embedding(vocab_size, embed_dim, min_init_value)
        delta_embedding = self.init_concept_embedding(vocab_size, embed_dim, delta_init_value)
        relation_embedding = self.init_concept_embedding(relation_size, embed_dim, relation_init_value)
        scaling_embedding = self.init_concept_embedding(relation_size, embed_dim, scaling_init_value)
#         
        self.temperature = args.softplus_temp
        self.min_embedding = nn.Parameter(min_embedding)
        self.delta_embedding = nn.Parameter(delta_embedding)
        self.relation_embedding = nn.Parameter(relation_embedding)
        self.scaling_embedding = nn.Parameter(scaling_embedding)
        self.min_embedding = nn.Parameter(min_embedding)
        
        self.gumbel_beta = args.gumbel_beta
        self.scale = args.scale

    def forward(self, data):
        
        nf1_min = self.min_embedding[data[0][:,[0,1]].long()-1000]
        nf1_delta = self.delta_embedding[data[0][:,[0,1]].long()-1000]
        nf1_max = nf1_min+torch.exp(nf1_delta)
        lower_bound = data[0][:,[2]]
        upper_bound = data[0][:,[2]]
        
        boxes1 = Box(nf1_min[:, 0, :], nf1_max[:, 0, :])
        boxes2 = Box(nf1_min[:, 1, :], nf1_max[:, 1, :])

        nf1_loss, reg_loss = self.nf1_loss(boxes1, boxes2, lower_bound, upper_bound)
        
        nf2_min = self.min_embedding[data[1][:,[0,1,2]].long()-1000]
        nf2_delta = self.delta_embedding[data[1][:,[0,1,2]].long()-1000]
        nf2_max = nf2_min+torch.exp(nf2_delta)
        
        boxes1 = Box(nf2_min[:, 0, :], nf2_max[:, 0, :])
        boxes2 = Box(nf2_min[:, 1, :], nf2_max[:, 1, :])
        boxes3 = Box(nf2_min[:, 2, :], nf2_max[:, 2, :])
        
        lower_bound = data[1][:,[3]]
        upper_bound = data[1][:,[3]]
        
        nf2_loss, _ = self.nf2_loss(boxes1, boxes2, boxes2, lower_bound, upper_bound)
        
        
        return nf1_loss.sum(), nf2_loss.sum(), reg_loss

    def volumes(self, boxes):
        return boxes.delta_embed.prod(1, keepdim=True)

    def intersection(self, boxes1, boxes2):
        intersections_min = torch.max(boxes1.min_embed, boxes2.min_embed)
        intersections_max = torch.min(boxes1.max_embed, boxes2.max_embed)
        intersection_box = Box(intersections_min, intersections_max)
        return intersection_box
    
    def inclusion(self, boxes1, boxes2):
        log_intersection = torch.log(torch.clamp(self.volumes(self.intersection(boxes1, boxes2)), 1e-10, 1e4))
        log_box1 = torch.log(torch.clamp(self.volumes(boxes1), 1e-10, 1e4))
        return torch.exp(log_intersection-log_box1)
    
    def nf1_loss(self, boxes1, boxes2, lower_bound,upper_bound):
        inters_box = self.intersection(boxes1, boxes2)
        inters_logvol = torch.log(torch.clamp(self.volumes(inters_box), 1e-5, 1e4))
        box1_logvol = torch.log(torch.clamp(self.volumes(boxes1), 1e-5, 1e4))
        
        upper_loss = F.relu( torch.exp(inters_logvol - box1_logvol) - upper_bound ) 
        lower_loss = F.relu( lower_bound -  torch.exp(box1_logvol-inters_logvol) )
        return upper_loss + lower_loss, l2_side_regularizer(boxes1, log_scale=True) + l2_side_regularizer(boxes2, log_scale=True)
        
    def nf2_loss(self, boxes1, boxes2, boxes3, lower_bound,upper_bound):
        inter_boxes = self.intersection(boxes1, boxes2)
        return self.nf1_loss(inter_boxes, boxes3, lower_bound, upper_bound)
    
    def nf2_loss(self, boxes1, boxes2, boxes3, lower_bound,upper_bound):
        inter_boxes = self.intersection(boxes1, boxes2)
        return self.nf1_loss(inter_boxes, boxes3, lower_bound, upper_bound)

    def init_concept_embedding(self, vocab_size, embed_dim, init_value):
        distribution = uniform.Uniform(init_value[0], init_value[1])
        box_embed = distribution.sample((vocab_size, embed_dim))
        return box_embed




In [91]:

import pandas as pd

train_data = []
nf1 = torch.tensor(pd.read_csv('./data/nf1.csv',sep=',',header=None).values)[:,1:]
nf2 = torch.tensor(pd.read_csv('./data/nf2.csv',sep=',',header=0).values)[:,1:]
nf3 = torch.tensor(pd.read_csv('./data/nf3.csv',sep=',',header=0).values)[:,1:]
nf4 = torch.tensor(pd.read_csv('./data/nf4.csv',sep=',',header=0).values)[:,1:]

train_data.append(nf1)
train_data.append(nf2)
train_data.append(nf3)
train_data.append(nf4)

train_data 



[tensor([[1.0040e+03, 1.0270e+03, 1.8868e-02],
         [1.0040e+03, 1.0380e+03, 1.4151e-02],
         [1.0040e+03, 1.0500e+03, 9.4340e-03],
         ...,
         [1.1080e+03, 1.1170e+03, 0.0000e+00],
         [1.1090e+03, 1.1170e+03, 0.0000e+00],
         [1.0590e+03, 1.0600e+03, 0.0000e+00]], dtype=torch.float64),
 tensor([[1004., 1027., 1000.,    0.],
         [1004., 1027., 1001.,    0.],
         [1004., 1027., 1002.,    0.],
         ...,
         [1121., 1126., 1127.,    0.],
         [1121., 1126., 1128.,    0.],
         [1121., 1126., 1129.,    0.]], dtype=torch.float64),
 tensor([[1.0030e+03, 1.1600e+03, 1.0090e+03, 8.6392e-01],
         [1.0030e+03, 1.1600e+03, 1.0100e+03, 4.7468e-01],
         [1.0040e+03, 1.1420e+03, 1.0190e+03, 4.0094e-01],
         [1.0050e+03, 1.1460e+03, 1.0220e+03, 4.9383e-01],
         [1.0050e+03, 1.1460e+03, 1.0240e+03, 5.4321e-01],
         [1.0050e+03, 1.1460e+03, 1.0260e+03, 3.4568e-01],
         [1.0050e+03, 1.1460e+03, 1.0320e+03, 4.9383e-01


## Loss term for (C|D)[l,u]

The intuition is to let Vol(C∧D)/Vol(D) to be in [l,u]

upper conditional loss: max(0, Vol(C∧D) - u*Vol(D) )

lower conditional loss: max(0, l*Vol(D) - Vol(C∧D) )



## Example Ontology


(Male|Person)[0.45,0.55]

(Person|Male)[1]

(Female|Person)[0.45,0.55]

(Person|Female)[1]

(Male|Female)[0]

(Female|Male)[0]

(MentalHealthProblem|Depression)[1]

(MentalHealthProblem|OCD)[1]

(Depression|OCD)[0.4,0.6]

(OCD|Depression)[0.4,0.6]

(Person|MentalHealthProblem)[0.0]

(MentalHealthProblem|Person)[0.0]




In [85]:

# wandb.init(project="basic_box",  reinit=False, config=args)
# train_data = []
# # person (0), male (1),  female (2), MentalHealthProblem (3), Depression (4), OCD(5)
# train_data.append(torch.Tensor([[1,0,0,0.45,0.55],[2,0,0,0.45,0.55],[0,0,1,1.0,1.0],[0,0,2,1.0,1.0], [2,0,1,0.0,0.0],[1,0,2,0.0,0.0], ]).to(device)) #nf1 c in d #
# # [3,0,5,1.0,1.0],[3,0,4,1.0,1.0],[4,0,5,0.4,0.6],[5,0,4,0.4,0.6],[0,0,3,0.0,0.0],[3,0,0,0.0,0.0]

# train_data

In [92]:

torch.manual_seed(888)
model = BoxEL(1000, 1, 2, [1e-4,0.2], [-0.1, 0], [-0.1,0.1], [0.9, 1.1], args).to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
  


In [93]:
model.train()

for epoch in range(2000):

    nf1_loss, nf2_loss, reg_loss = model(train_data)
    loss = nf1_loss + nf2_loss + reg_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print(nf1_loss.item(),nf2_loss.item())
    
    if epoch % 10 ==0:
        min_emb, delta_emb, max_emb, rel_emb, scaling_emb = model.min_embedding, model.delta_embedding, model.min_embedding +torch.exp(model.delta_embedding), model.relation_embedding, model.scaling_embedding
        boxes = Box(min_emb, max_emb)
#         plot_box(boxes, rel_emb.detach().numpy(), scaling_emb.detach().numpy(), epoch, loss.item())




    

283.6028632302873 5632.0
229.0883129914396 5632.0
182.52684696330996 5632.0
142.61603044535957 5632.0
108.16741445850572 5632.0
78.50037663075588 5632.0
54.729888426122805 5632.0
43.65322057300213 5632.0
39.32057056948588 5632.0
36.904380812406764 5632.0
35.56452607216562 5632.0
34.71168449437355 5632.0
34.05440987688279 5632.0
33.746475188017364 5632.0
32.60049498329884 5632.0
30.281641728807905 5632.0
27.603959181702365 5632.0
24.784321281101555 5632.0
22.068727554926888 5632.0
19.72179910354618 5632.0
17.89862804754765 5632.0
16.429058028885642 5632.0
15.033705481870506 5632.0
13.79097970926447 5632.0
12.859035421740295 5632.0
12.014793899791812 5632.0
11.22170240134301 5632.0
10.535494046893199 5632.0
9.960349957313795 5632.0
9.460598795861902 5632.0
9.043375226842077 5632.0
8.667781848261153 5632.0
8.338307539576942 5632.0
7.9892132776162175 5632.0
7.601722028523142 5632.0
7.280355180467301 5632.0
6.94951269281738 5632.0
6.642279279584837 5632.0
6.311424429225436 5632.0
6.00763077

0.08837307969002589 5632.0
0.07305109176331825 5632.0
0.0882406571224692 5632.0
0.06383190102337721 5632.0
0.08953989924953021 5632.0
0.08081904690880037 5632.0
0.06277636727463687 5632.0
0.09013824032756526 5632.0
0.10532571193152762 5632.0
0.11296426147428354 5632.0
0.10263535323474571 5632.0
0.12227823349462597 5632.0
0.11178008273805062 5632.0
0.10036936552091902 5632.0
0.08999925515604446 5632.0
0.096529263063303 5632.0
0.09776026712961539 5632.0
0.10612405893812138 5632.0
0.08793545014827941 5632.0
0.12481630872684946 5632.0
0.11020228302368196 5632.0
0.13704798453886724 5632.0
0.16150766612531697 5632.0
0.12705097798243514 5632.0
0.12549128576256408 5632.0
0.11703751886670943 5632.0
0.12034093437523552 5632.0
0.10534484776368294 5632.0
0.0992301373664759 5632.0
0.098025320625311 5632.0
0.10499559453894798 5632.0
0.1006787948196255 5632.0
0.08786415997072035 5632.0
0.09957334546318415 5632.0
0.09608830225520423 5632.0
0.09111527000709686 5632.0
0.08119815117856888 5632.0
0.070009

0.08649740963392105 5632.0
0.08726104621564446 5632.0
0.08520499122778347 5632.0
0.09672721350170832 5632.0
0.1065382181179757 5632.0
0.06695117701247 5632.0
0.08279801261642206 5632.0
0.06541825308772786 5632.0
0.08185018318840775 5632.0
0.07172803284879592 5632.0
0.07325880755070102 5632.0
0.08657492294651092 5632.0
0.0957193154982256 5632.0
0.06095515665624421 5632.0
0.04407834950782905 5632.0
0.04200010558136689 5632.0
0.05453595116705401 5632.0
0.08052827156185724 5632.0
0.08542846859836573 5632.0
0.050324955250061976 5632.0
0.0883263440139217 5632.0
0.08091210187853903 5632.0
0.06409724096817293 5632.0
0.04966148366432836 5632.0
0.05690983058376909 5632.0
0.09395072850588804 5632.0
0.0645390462807427 5632.0
0.04805832411755143 5632.0
0.062029085335552736 5632.0
0.07211541325159487 5632.0
0.08470733169189437 5632.0
0.05170907193610219 5632.0
0.06540154125605113 5632.0
0.07006406814207367 5632.0
0.08482539442798043 5632.0
0.15499762156969155 5632.0
0.060858014978778165 5632.0
0.060

0.16659127195043766 5632.0
0.11259488128234807 5632.0
0.05506082065630835 5632.0
0.2854875213033665 5632.0
0.0798028458589215 5632.0
0.17165777272546853 5632.0
0.16723738572409275 5632.0
0.13993594890962413 5632.0
0.20832583865103516 5632.0
0.15073029075165323 5632.0
0.15047426483170057 5632.0
0.34699043139971764 5632.0
0.12124316570998417 5632.0
0.0884883659958661 5632.0
0.07087421866981458 5632.0
0.12888656583936609 5632.0
0.04190325602939993 5632.0
0.04617451618059931 5632.0
0.11367433895520662 5632.0
0.08457720875367158 5632.0
0.03863872064266616 5632.0
0.05931462628996087 5632.0
0.15140003642727606 5632.0
0.17595532159087643 5632.0
0.21542015996828923 5632.0
0.18200215013030174 5632.0
0.0943055813590945 5632.0
0.3266862354375917 5632.0
0.09803876175374171 5632.0
0.08875264175640041 5632.0
0.08338063496512405 5632.0
0.07162870502133956 5632.0
0.10977244062405589 5632.0
0.16543747235891715 5632.0
0.5290553976217858 5632.0
0.09293393358439062 5632.0
0.09741175616668732 5632.0
0.18050

0.1076840654766329 5632.0
0.07758907230709156 5632.0
0.07778311234415014 5632.0
0.1094209192483504 5632.0
0.07951949204208818 5632.0
0.09500842464376547 5632.0
0.08792776145037351 5632.0
0.12349901040761324 5632.0
0.20155495255858114 5632.0
0.1503297286030829 5632.0
0.13556592030818138 5632.0
0.17059902570463237 5632.0
0.08897906493131545 5632.0
0.07944556295251459 5632.0
0.17038541512692973 5632.0
0.2272672805588627 5632.0
0.10226922705396646 5632.0
0.08729152015425741 5632.0
0.08430829491680925 5632.0
0.26072579498941195 5632.0
0.12752288085812324 5632.0
0.06984487725776489 5632.0
0.12476916114474079 5632.0
0.13369552759650105 5632.0
0.12127770669803795 5632.0
0.1356702518273778 5632.0
0.1423756106376004 5632.0
0.17371725857060483 5632.0
0.13563791206206588 5632.0
0.1214589958808574 5632.0
0.14155943691069606 5632.0
0.08754243396151651 5632.0
0.16707495501545133 5632.0
0.11385706362927236 5632.0
0.10898374543245778 5632.0
0.10509130515160821 5632.0
0.10801712919405788 5632.0
0.102022

0.19530773838550886 5632.0
0.18293459076858198 5632.0
0.1806865222692977 5632.0
0.197037844965962 5632.0
0.1964989074931509 5632.0
0.1736534478645808 5632.0
0.17018246915904456 5632.0
0.2344854751136154 5632.0
0.1622095206726044 5632.0
0.15832214499914699 5632.0
0.15519062647126702 5632.0
0.18111020374681175 5632.0
0.20161663416729425 5632.0
0.15771470195386428 5632.0
0.1434682558001441 5632.0
0.1407286687940541 5632.0
0.13820224468372544 5632.0
0.13578358481299801 5632.0
0.13347389000409748 5632.0
0.13132133923545553 5632.0
0.1445836150787727 5632.0
0.12718307084287517 5632.0
0.1324227653429745 5632.0
0.12353323803836247 5632.0
0.12313536716083812 5632.0
0.1922280804055747 5632.0
0.20129148464274227 5632.0
0.17177417386482308 5632.0
0.17409695236982703 5632.0
0.16939280760172856 5632.0
0.1882003123464333 5632.0
0.18215777062850774 5632.0
0.17280156005863118 5632.0
0.16722806985398098 5632.0
0.1641759860972245 5632.0
0.16178550365610977 5632.0
0.21756538294220196 5632.0
0.1564854077655

0.15666464333116892 5632.0
0.1527414873539783 5632.0
0.13232008545082863 5632.0
0.13624365089162893 5632.0
0.13781866460340098 5632.0
0.15105647656127985 5632.0
0.1371419109091221 5632.0
0.12773722855126834 5632.0
0.12198899754048398 5632.0
0.14117562723777155 5632.0
0.11582823218441263 5632.0
0.13568938423395593 5632.0
0.11962713179127604 5632.0
0.11839875363511965 5632.0
0.11881913592833371 5632.0
0.18707506617556646 5632.0
0.12428470411896342 5632.0
0.12582804335352193 5632.0
0.1266247953567472 5632.0
0.25485305307256567 5632.0
0.11738296789098968 5632.0
1.4062266159835417 5632.0
0.11402062714478234 5632.0
0.11878397749114811 5632.0
0.12465691290617542 5632.0
0.11953637057013111 5632.0
0.1346261214939659 5632.0
0.12788069291036663 5632.0
0.4298586529575914 5632.0
0.21262949165929967 5632.0
0.18156446547163796 5632.0
0.14408866311987367 5632.0
0.1755594727483185 5632.0
0.2302913892840479 5632.0
0.18225684390336028 5632.0
0.14828462309378665 5632.0
0.2134007066056256 5632.0
0.13893989

In [39]:

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from IPython.display import clear_output

# from matplotlib.pyplot import fig

def plot_box(boxes, relation, scaling_emb, epoch, loss):
    clear_output(wait=True)
    plt.figure(figsize=(5,4),dpi=100)
    # plt.figure(figsize=(10, 8), dpi=80)
    # plt.subplot(121)
    
    labels = ['Person','Male','Female'] #,'MentalProblem','Depression','OCD']
    color = ['r','b','g'] # 'c','m','black']
    
    # Plot concept embedding
    for i in range(boxes.min_embed.shape[0]):
        x1,x2,w,h = boxes.min_embed[i][0].detach().numpy(),boxes.min_embed[i][1].detach().numpy(),boxes.max_embed[i][0].detach().numpy()-boxes.min_embed[i][0].detach().numpy(), boxes.max_embed[i][1].detach().numpy()-boxes.min_embed[i][1].detach().numpy()
        vol = (F.softplus(torch.tensor([w]))*F.softplus(torch.tensor([h])) ).item()
        rect=mpatches.Rectangle((x1,x2),w,h, fill = False, color = color[i],linewidth = 2, alpha=0.6, facecolor=color[i])
        plt.gca().add_patch(rect)
        plt.text(x1+w-0.14,x2+h-0.05, labels[i], color = color[i]) #+str(vol)
        
    plt.xlim(-0.5,1.5)
    plt.ylim(-0.5,1.5)

    plt.show()

