In [0]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [0]:
import os

if "DATABRICKS_RUNTIME_VERSION" in os.environ and not 'installed_libs' in globals():
  #CUDA = 'cu121' 
  installed_libs = True
  
  
  !pip install torch==2.1.0  torchvision==0.16.0 torchtext==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
  import torch
  #os.environ['TORCH'] = torch.__version__
  #print(torch.__version__)
  #torch_version = '2.0.0+cu118'
  
  #!pip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html # torch_spline_conv
  !pip install torch_geometric
  !pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
  #!pip install torch_sparse -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html
  #!pip install torch_scatter -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html
  #!pip install pyg_lib -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html
  !pip install sentence-transformers
  !pip install torcheval
  !pip install matplotlib
  !pip install pandas
  !pip install tensorboard
  
if "DATABRICKS_RUNTIME_VERSION" in os.environ:
  ROOT_FOLDER = '/dbfs/FileStore/GraphNeuralNetworks/'
else:
  ROOT_FOLDER = ''

In [0]:
from learnings_sampler_v1 import (
    get_datasets,
    uniform_hgt_sampler,
    get_minibatch_count,
    add_reverse_edge_original_attributes_and_label_inplace,
    get_hgt_linkloader,
    get_single_minibatch_count,
)

# Load the train / validation / test splits from the serialized graph file
# stored under ROOT_FOLDER (edge attributes disabled, top-k filter of 50).
train_data, val_data, test_data = get_datasets(
    get_edge_attr=False,
    filename=ROOT_FOLDER+'HeteroData_Learnings_normalized_triangles_withadditionaldata_v1.pt',
    filter_top_k=True,
    top_k=50,
)


In [0]:
import torch
from models.TransE import TransE
from models.DistMult import DistMult
from models.HGT import HGT
import torch_geometric
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Model(torch.nn.Module):
    """GNN encoder followed by a knowledge-graph-embedding head for link prediction.

    The GNN produces one embedding per node; the head (TransE or DistMult)
    turns (head embedding, relation, tail embedding) triples into a loss.

    Args:
        gnn: Encoder called as ``gnn(x_dict, edge_index_dict)``; its result is
            indexed by node type.
        head: Name of the KGE head, either ``'TransE'`` or ``'DistMult'``.
        node_types: Sequence of node type names.
        edge_types: Sequence of ``(src, relation, dst)`` triples. Relations
            whose name starts with ``'rev_'`` do not get an embedding of their
            own. The sequence is not modified.
        ggn_output_dim: Output dimension of the GNN, used as embedding size.
        pnorm: Forwarded to the head as ``p_norm``.

    Raises:
        NotImplementedError: If ``head`` is not a supported head name.
    """

    def __init__(self, gnn : torch.nn.Module, head : str, node_types, edge_types, ggn_output_dim, pnorm=1):
        super().__init__()
        # One embedding row per node type, sized like the GNN output.
        # It is created here but not read by forward().
        self.node_type_embedding = torch.nn.Embedding(len(node_types), ggn_output_dim)

        # Forward relations only, plus the (edge type -> relation index) table.
        self.edge_types, self.edgeindex_lookup = self._index_relations(edge_types)

        num_relations = len(self.edge_types)
        if head == 'TransE':
            self.head = TransE(len(node_types), num_relations, ggn_output_dim, p_norm=pnorm)  # KGE head with loss function
        elif head == 'DistMult':
            self.head = DistMult(len(node_types), num_relations, ggn_output_dim, p_norm=pnorm)  # KGE head with loss function
        else:
            raise NotImplementedError(f"unsupported head {head!r}; expected 'TransE' or 'DistMult'")

        self.gnn = gnn

    @staticmethod
    def _index_relations(edge_types):
        """Return ``(forward_edge_types, lookup)`` without mutating ``edge_types``.

        ``forward_edge_types`` keeps every edge type whose relation name does
        not start with ``'rev_'``, in the original order. ``lookup`` maps each
        of them to its position as a scalar tensor. A ``'rev_'`` edge type
        ``(src, 'rev_<rel>', dst)`` is additionally mapped to the index of
        ``(dst, '<rel>', src)`` when that forward edge type exists, so a
        reverse relation can still be used as a prediction target.

        NOTE(review): sharing the forward relation's embedding with its reverse
        is a modelling choice - confirm it matches the intended setup.
        """
        # Build a new list: removing from a list while iterating over it skips
        # the element following every removal and mutates the caller's data.
        forward_edge_types = [et for et in edge_types if not et[1].startswith('rev_')]
        lookup = {et: torch.tensor(i) for i, et in enumerate(forward_edge_types)}

        for edge_type in edge_types:
            relation = edge_type[1]
            if relation.startswith('rev_'):
                counterpart = (edge_type[2], relation[len('rev_'):], edge_type[0])
                if counterpart in lookup:
                    lookup[edge_type] = lookup[counterpart]
        return forward_edge_types, lookup

    def forward(self, hetero_data1, target_edge_type, edge_label_index, edge_label, hetero_data2=None):
        """Encode the sampled subgraph(s) and return the head's loss.

        Args:
            hetero_data1: Object exposing ``x_dict`` and ``edge_index_dict``;
                supplies the head nodes (and the tail nodes when
                ``hetero_data2`` is None).
            target_edge_type: ``(src, relation, dst)`` triple to predict.
            edge_label_index: Row 0 indexes head nodes, row 1 indexes tail nodes.
            edge_label: Labels forwarded to the head's loss.
            hetero_data2: Optional second subgraph supplying the tail nodes;
                only allowed when src and dst node types differ.

        Returns:
            The value returned by ``self.head.loss``.

        Raises:
            KeyError: If ``target_edge_type`` has no relation index.
        """
        src_type, dst_type = target_edge_type[0], target_edge_type[2]

        if hetero_data2 is not None:
            assert src_type != dst_type, 'when passing two data objects, the edge type has to contain two different node types'
        else:
            assert src_type == dst_type, 'when passing one data object, the edge type has to contain the same node types'

        # Fail before the (expensive) GNN pass when the relation is unknown.
        try:
            edgeindex = self.edgeindex_lookup[target_edge_type]
        except KeyError:
            raise KeyError(f'no relation index for edge type {target_edge_type!r}') from None

        if hetero_data2 is not None:
            head_embeddings = self.gnn(hetero_data1.x_dict, hetero_data1.edge_index_dict)[src_type][edge_label_index[0,:]]
            tail_embeddings = self.gnn(hetero_data2.x_dict, hetero_data2.edge_index_dict)[dst_type][edge_label_index[1,:]]
        else:
            embeddings = self.gnn(hetero_data1.x_dict, hetero_data1.edge_index_dict)
            head_embeddings = embeddings[src_type][edge_label_index[0,:]]
            tail_embeddings = embeddings[dst_type][edge_label_index[1,:]]

        # Follow the embeddings' device instead of a notebook-level global.
        loss = self.head.loss(head_embeddings, edgeindex.to(head_embeddings.device), tail_embeddings, edge_label)
        return loss
    
        
# metadata[0] holds the node types and metadata[1] the edge types.
metadata = train_data.metadata()
# Register one self-loop relation per node type in the edge-type list.
# NOTE(review): this only extends the metadata; whether matching self-loop
# edges exist in the data objects is not visible here - confirm.
for node_type in train_data.node_types:
    metadata[1].append((node_type, 'self_loop', node_type))

# Hyperparameters (also embedded in the TensorBoard / checkpoint run name).
out_channels = 256
hidden_channels = 256
num_heads = 8
num_layers = 2
pnorm = 2
head = 'TransE'

# Pass hidden_channels itself (not out_channels) so the value logged in the
# run name is the one the encoder actually uses.
gnn = HGT(hidden_channels=hidden_channels, out_channels=out_channels, num_heads=num_heads, num_layers=num_layers, node_types=train_data.node_types, data_metadata=metadata)

model = Model(gnn, head=head, node_types=metadata[0], edge_types=metadata[1], ggn_output_dim=out_channels, pnorm=pnorm)
#torch_geometric.compile(model, dynamic=True)
model.to(device)


In [0]:
from datetime import datetime

from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

batch_size = 32
num_node_types = len(train_data.node_types)
print('num_node_types', num_node_types)
# Neighbour budget per hop, divided evenly across the node types.
one_hop_neighbors = (25 * batch_size)//num_node_types # per relationship type
two_hop_neighbors = (25 * 10 * batch_size)//num_node_types # per relationship type
three_hop_neighbors = (25 * 10 * 4 * batch_size)//num_node_types # per relationship type
num_neighbors = [one_hop_neighbors, two_hop_neighbors] # three_hop_neighbors
# num_neighbors [36, 363, 1454]

print('num_neighbors', num_neighbors)
print('avg_num_neighbors', [num_neighbors[0]/batch_size,num_neighbors[1]/batch_size,  num_neighbors[2]/batch_size if len(num_neighbors)==3 else 0 ])

# Links are sampled on the reverse relation (people -> courses_and_programs).
# The helper is given the forward edge store and the reverse one; judging by
# its name it copies attributes and labels onto the reverse edge - confirm.
input_edgetype = ('people', 'rev_course_and_programs_student', 'courses_and_programs')
add_reverse_edge_original_attributes_and_label_inplace(train_data['courses_and_programs', 'course_and_programs_student', 'people'], reverse_edge=train_data[input_edgetype] )
add_reverse_edge_original_attributes_and_label_inplace(val_data['courses_and_programs', 'course_and_programs_student', 'people'], reverse_edge=val_data[input_edgetype] )

train_sampler = get_hgt_linkloader(train_data, input_edgetype, batch_size, True, 'triplet', 1, num_neighbors, num_workers=0, prefetch_factor=None, pin_memory=True)
val_sampler = get_hgt_linkloader(val_data, input_edgetype, batch_size, False, 'triplet', 1, num_neighbors, num_workers=0, prefetch_factor=None, pin_memory=True)

learning_rate = 2e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) #2e-15

# One run name shared by the TensorBoard log dir and the checkpoint folder, so
# the printed path, the log dir and the checkpoints can no longer drift apart.
# ('pnrom' is kept as-is so run names stay comparable with earlier runs.)
neighbors = '_'.join([str(n) for n in num_neighbors])
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_name = f'learningpeople_hgt_{timestamp}_pnrom{pnorm}_lr{learning_rate}_bs{batch_size}_neighbors_{neighbors}_head_{head}_hiddenchannels_{hidden_channels}_outchannels_{out_channels}_numheads_{num_heads}_numlayers_{num_layers}'

log_dir = ROOT_FOLDER + 'runs/' + run_name
run_folder = ROOT_FOLDER + 'models/' + run_name

writer = SummaryWriter(log_dir)
print('writer', log_dir)

VAL_EVERY = 300          # validate every N training minibatches
CHECKPOINT_EVERY = 1000  # save a checkpoint every N training minibatches
VAL_BATCHES = 3          # validation minibatches averaged per evaluation


def minibatch_loss(batch):
    """Move one sampler minibatch to `device` and return the model's loss tensor."""
    part1, part2, label_index, label, _input_edge_id, _, _ = batch
    return model(part1.to(device), input_edgetype, label_index.to(device), label.to(device), part2.to(device))


model.train()
total_minibatches = get_single_minibatch_count(train_data, batch_size, input_edgetype)
start_epoch = 1

# Keep the validation iterator separate from the sampler so a fresh iterator
# can be created every time the validation data is exhausted.
val_iterator = iter(val_sampler)

try:
    for epoch in range(start_epoch, start_epoch+1000):
        for i,minibatch in tqdm(enumerate(train_sampler), total=total_minibatches):

            optimizer.zero_grad()
            loss = minibatch_loss(minibatch)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()

            # Global step across epochs: a per-epoch counter would restart at 0
            # every epoch, overlaying the TensorBoard curves and overwriting
            # the intermediate checkpoints of earlier epochs.
            total_samples_seen = ((epoch - 1) * total_minibatches + i) * batch_size
            is_last_minibatch = i == total_minibatches-1

            writer.add_scalar('Loss/train', loss_value, total_samples_seen)

            if is_last_minibatch:
                print(f'{i} loss: {loss_value:.4f}')
                # Logs the loss of the epoch's last minibatch (not an epoch mean).
                writer.add_scalar('Epoch Loss/train', loss_value, total_samples_seen)

            # print loss and minibatch in the same line
            print(f'{i} loss: {loss_value:.4f}', end='\r')

            if i % VAL_EVERY == 0 or is_last_minibatch:
                model.eval()
                with torch.no_grad():
                    val_loss = 0
                    for _ in range(VAL_BATCHES):
                        try:
                            val_minibatch = next(val_iterator)
                        except StopIteration:
                            val_iterator = iter(val_sampler)
                            val_minibatch = next(val_iterator)
                        val_loss += minibatch_loss(val_minibatch).item()

                val_loss /= VAL_BATCHES
                # First minibatch feeds both curves, last one the epoch curve,
                # everything in between the step curve.
                if i == 0:
                    writer.add_scalar('Epoch Loss/val', val_loss, total_samples_seen)
                    writer.add_scalar('Loss/val', val_loss, total_samples_seen)
                elif is_last_minibatch:
                    writer.add_scalar('Epoch Loss/val', val_loss, total_samples_seen)
                else:
                    writer.add_scalar('Loss/val', val_loss, total_samples_seen)

                print(f'val_loss: {val_loss:.4f}', end='\r')
                model.train()

            writer.flush()

            if i % CHECKPOINT_EVERY == 0 or is_last_minibatch:
                os.makedirs(run_folder, exist_ok=True)

                print('saving model to', run_folder)
                # save model and optimizer; end-of-epoch checkpoints get an 'Ep<N>_' prefix
                is_epoch = f'Ep{epoch}_' if is_last_minibatch else ''
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    }, run_folder+f'/{is_epoch}model_samplesseen{total_samples_seen}.pt')
finally:
    # Close the writer even if training fails or is interrupted.
    writer.close()

In [0]:
# Debug check: number of elements in the last `minibatch` left over from the
# training cell (the loop there unpacks it into 7 values). This only works
# after that cell has run. NOTE(review): consider removing this leftover cell.
len(minibatch)