# In [None]:
import torch
import torch.nn.functional as F
from models.TransE import TransE
from models.DistMult import DistMult
from models.FactorizationMachineModel import FactorizationMachineModel
import torch_geometric
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class Model(torch.nn.Module):
    """Edge-scoring model for (people, courses_and_programs) pairs.

    ``forward`` scores each candidate edge with the factorization machine
    ``fm`` applied to the concatenated raw features of both endpoints, and
    returns a margin ranking loss. A KGE head (TransE or DistMult), a node-type
    embedding and several linear layers are also built, but ``forward`` does
    not currently use them.

    Args:
        fm: Module mapping the concatenated [people, courses_and_programs]
            feature rows to one score per row (presumably a
            FactorizationMachineModel -- TODO confirm).
        head: Name of the KGE head to build: ``'TransE'`` or ``'DistMult'``.
        node_types: Node-type names; only their count is used.
        edge_types: Edge-type triples ``(src, relation, dst)``. Triples whose
            relation starts with ``'rev_'`` are dropped; the caller's list is
            left unmodified.
        ggn_output_dim: Width of the node-type embedding and of the KGE head;
            should equal the GNN output dimension.
        pnorm: Norm order forwarded to the KGE head as ``p_norm``.
        num_supervisors: Stored on the instance; unused by the active path.
        num_organizations: Stored on the instance; unused by the active path.

    Raises:
        NotImplementedError: If ``head`` is not a supported KGE head.
    """

    def __init__(self, fm : torch.nn.Module, head, node_types, edge_types, ggn_output_dim, pnorm=1, num_supervisors=0, num_organizations=0):
        super().__init__()
        self.node_type_embedding = torch.nn.Embedding(len(node_types), ggn_output_dim)  # width should be the GNN output dim
        self.num_supervisors = num_supervisors
        self.num_organizations = num_organizations

        # Keep only forward relations. A fresh list is built: removing items
        # from the list being iterated skips the element after each removal
        # (so some 'rev_' edges survived) and mutated the caller's list.
        self.edge_types = self._forward_edge_types(edge_types)

        # Edge type -> integer relation id (0-d tensor).
        self.edgeindex_lookup = {edge_type: torch.tensor(i) for i, edge_type in enumerate(self.edge_types)}

        num_node_types = len(node_types)
        num_relations = len(self.edge_types)
        if head == 'TransE':
            self.head = TransE(num_node_types, num_relations, ggn_output_dim, p_norm=pnorm)  # KGE head with loss function
        elif head == 'DistMult':
            self.head = DistMult(num_node_types, num_relations, ggn_output_dim, p_norm=pnorm)  # KGE head with loss function
        else:
            raise NotImplementedError(f"Unsupported KGE head: {head!r}")

        self.fm = fm

        # Layers of a disabled hybrid path (GNN embeddings + FM score). They
        # are not used by forward() but are kept so parameter / state_dict
        # keys stay unchanged. 3167 was the input width expected for the
        # concatenated people/learning embeddings -- NOTE(review): confirm
        # before re-enabling that path.
        self.layer1 = torch.nn.Linear(3167, 512)
        self.layer2 = torch.nn.Linear(512, 256)
        self.layer3 = torch.nn.Linear(257, 256)
        self.fc_output = torch.nn.Linear(256, 256)

    @staticmethod
    def _forward_edge_types(edge_types):
        """Return a new list of the edge types whose relation does not start with 'rev_'."""
        return [edge_type for edge_type in edge_types if not edge_type[1].startswith('rev_')]

    def forward(self, hetero_data, hetero_data_embeddings, edge_label_index, edge_label):
        """Return the margin ranking loss for a batch of labelled candidate edges.

        Args:
            hetero_data: Graph whose ``'people'`` and ``'courses_and_programs'``
                entries expose raw node features as ``.x``.
            hetero_data_embeddings: Currently unused (only the disabled
                GNN-embedding path consumed it); kept for interface
                compatibility.
            edge_label_index: Row 0 holds people node indices, row 1 holds
                courses_and_programs node indices.
            edge_label: Per-edge labels; 1 marks a positive edge, 0 a negative.

        Returns:
            ``F.margin_ranking_loss`` (margin 0.2, target +1), which penalises
            every positive score that does not exceed the matching negative
            score by at least the margin.
        """
        people = hetero_data['people'].x[edge_label_index[0, :], :]
        learnings = hetero_data['courses_and_programs'].x[edge_label_index[1, :], :]

        scores = self.fm(torch.cat((people, learnings), dim=1))

        pos_scores = scores[edge_label == 1]
        neg_scores = scores[edge_label == 0]
        # Positives and negatives are compared element-wise, so the batch is
        # assumed to hold paired positives/negatives in matching order --
        # TODO confirm with the sampler.
        return F.margin_ranking_loss(
            pos_scores,
            neg_scores,
            target=torch.ones_like(pos_scores),  # +1: first input should rank higher
            margin=0.2,
        )
        
    
# ---- Hyperparameters ---------------------------------------------------------
out_channels = 1       # GNN output width; passed to Model as ggn_output_dim
hidden_channels = 16   # embedding width of the factorization machine
num_heads = 0          # only referenced by the commented-out HGT line below
num_layers = 0         # only referenced by the commented-out HGT line below
pnorm = 2              # norm order forwarded to the KGE head
head = 'TransE'        # KGE head: 'TransE' or 'DistMult'
#gnn = HGT(hidden_channels=out_channels, out_channels=out_channels, num_heads=num_heads, num_layers=num_layers, node_types=train_data.node_types, data_metadata=metadata)

# ---- Graph schema ------------------------------------------------------------
# The saved graph is loaded only to read node counts and the
# (node_types, edge_types) metadata, then discarded.
# - HeteroData is reached through the imported torch_geometric package because
#   this cell never imports the bare name.
# - map_location='cpu' lets a file saved with GPU tensors load on a CPU-only host.
# - ROOT_FOLDER is defined outside this cell and concatenated directly, so it is
#   expected to end with a path separator -- NOTE(review): confirm.
filename = 'HeteroData_Learnings_normalized_triangles_withadditionaldata_v1.pt'
data_forlookup = torch_geometric.data.HeteroData.from_dict(
    torch.load(ROOT_FOLDER + filename, map_location='cpu')
)
num_supervisors = data_forlookup['people'].num_nodes  # presumably supervisors are people nodes -- confirm
num_organizations = data_forlookup['organizations'].num_nodes
metadata = data_forlookup.metadata()
# Add a self-loop relation for every node type.
metadata[1].extend((node_type, 'self_loop', node_type) for node_type in data_forlookup.node_types)
del data_forlookup

# ---- Factorization machine ---------------------------------------------------
# train_data comes from an earlier cell. The FM field dims are the label-encoding
# tensors of both node types concatenated along dim 0, cast to int64.
print(train_data['people'].labelencoding.shape, train_data['courses_and_programs'].labelencoding.shape)
field_dims = torch.cat((train_data['people'].labelencoding, train_data['courses_and_programs'].labelencoding), dim=0)
print(field_dims)
field_dims = field_dims.to(torch.int64)
fm = FactorizationMachineModel(
    field_dims=field_dims,
    embed_dim=hidden_channels,
)

# ---- Full model --------------------------------------------------------------
model = Model(fm, head=head, node_types=metadata[0], edge_types=metadata[1], ggn_output_dim=out_channels, pnorm=pnorm, num_supervisors=num_supervisors, num_organizations=num_organizations)
#torch_geometric.compile(model, dynamic=True)
model.to(device)

