In [1]:
import random
import itertools
import pandas as pd
import numpy as np
import math

from math import sqrt as msqrt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score, classification_report
import torch
import torch.functional as F
from torch import nn
from torch.optim import Adadelta
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm

# Data Preparation

## Training Set-Concept 

In [2]:
# Read a concept file and return prefixed and raw intents/extents
def get_intents_extents(filename):
    '''
    Parse a concept file whose lines look like "<intent>    <extent>".

    The two fields are separated by exactly four spaces; any line that does not
    split into exactly two fields on that separator is ignored.

    Returns
    -------
    (modified_intents, modified_extents, intents, extents)
        intents / extents are the stripped raw field strings.
        modified_intents prefixes every whitespace-separated token with 'a';
        modified_extents prefixes every token with 'o'.
    '''
    raw_intents, raw_extents = [], []
    with open(filename, 'r', encoding='utf-8') as handle:
        for row in handle:
            fields = row.split('    ')
            if len(fields) != 2:
                continue
            raw_intents.append(fields[0].strip())
            raw_extents.append(fields[1].strip())

    def _prefix_tokens(items, prefix):
        # "12 57" -> "<prefix>12 <prefix>57"
        return [' '.join(prefix + tok for tok in item.split()) for item in items]

    return _prefix_tokens(raw_intents, 'a'), _prefix_tokens(raw_extents, 'o'), raw_intents, raw_extents

# The function to process formal context file (one "<object> <attribute>" pair per line)
def process_context(file_name):
    '''
    Read object/attribute incidences from a formal-context file.

    Only lines with exactly two whitespace-separated fields are used.

    Returns
    -------
    (extents, intents, modified_extents, modified_intents)
        extents / intents are the raw object / attribute strings in file order;
        the modified lists carry the same values prefixed with 'o' / 'a'.
    '''
    extents, intents = [], []
    with open(file_name, 'r') as file:
        for line in file:
            fields = line.strip().split()
            if len(fields) != 2:
                continue
            obj, att = fields
            extents.append(obj)
            intents.append(att)
    modified_extents = ['o' + obj for obj in extents]
    modified_intents = ['a' + att for att in intents]
    return extents, intents, modified_extents, modified_intents

# Truncate concept: drop concepts that have more than `max_len` tokens
def truncate_concept(concept, max_len):
    '''
    Filter concepts by token count and report the longest one.

    Parameters
    ----------
    concept : list of str
        Space-separated token strings (e.g. "a12 a57").
    max_len : int
        Maximum number of tokens a concept may have to be kept.

    Returns
    -------
    (truncated_concept, max_length)
        truncated_concept : the concepts with at most `max_len` tokens, in input order.
        max_length : token count of the longest concept in the input (0 for empty input).

    Note: the filter previously compared the *character* length of each string
    with `max_len`, while `max_length` (which callers use to clip their length
    bounds) counts tokens. Both now count tokens.
    '''
    token_counts = [len(sequence.split()) for sequence in concept]
    max_length = max(token_counts, default=0)
    print("Length of the longest sequence:", max_length)

    truncated_concept = [
        sequence for sequence, count in zip(concept, token_counts) if count <= max_len
    ]
    print("The number of concepts with limited length is:", len(truncated_concept))

    return truncated_concept, max_length

# Use objects and attributes to create the index dictionary
def create_index_dic(filename, obj_offset=2, att_offset=10003):
    '''
    Build the token -> index vocabulary from a formal-context file.

    Every line with exactly two whitespace-separated integer fields
    "<object> <attribute>" contributes one object id and one attribute id.

    Index layout:
      '[PAD]' -> 0, '[CLS]' -> 1,
      'o<id>' -> id + obj_offset   (objects, ascending id order),
      'a<id>' -> id + att_offset   (attributes, ascending id order).

    The file is parsed here directly rather than through `process_context`,
    because a later cell redefines `process_context` with a different return
    shape, which made this function fail if its cell was re-run afterwards.

    Returns
    -------
    (index, object2idx, attribute2idx)

    Raises
    ------
    ValueError
        If two tokens would receive the same index (e.g. object ids large
        enough to reach the attribute offset), since they would then share an
        embedding row.
    '''
    object_ids = set()
    attribute_ids = set()
    with open(filename, 'r') as file:
        for line in file:
            parts = line.strip().split()
            if len(parts) == 2:
                object_ids.add(int(parts[0]))
                attribute_ids.add(int(parts[1]))

    special_tokens = {'[PAD]': 0, '[CLS]': 1}
    object2idx = {'o' + str(obj): obj + obj_offset for obj in sorted(object_ids)}
    attribute2idx = {'a' + str(att): att + att_offset for att in sorted(attribute_ids)}

    index = {}
    index.update(special_tokens)
    index.update(object2idx)
    index.update(attribute2idx)

    if len(set(index.values())) != len(index):
        raise ValueError(
            'Index collision: object/attribute offsets produce overlapping indices; '
            'adjust obj_offset / att_offset.')

    return index, object2idx, attribute2idx


In [3]:
# Load the training concepts; the raw (unprefixed) intents/extents are kept alongside
(train_intents,
 train_extents,
 original_train_intents,
 original_train_extents) = get_intents_extents('BMS-POS-with-missing-part_concepts.txt')

print('The number of training intents is', len(train_intents))
print('The number of training extents is', len(train_extents))

The number of training intents is 5247
The number of training extents is 5247


In [4]:
# Upper bounds on concept length passed to truncate_concept
max_len_int = 10240
max_len_ext = 10240
max_len = 1024

# Filter over-long concepts and record the longest observed length of each part
truncated_intent, original_max_len_int = truncate_concept(train_intents, max_len_int)
truncated_extent, original_max_len_ext = truncate_concept(train_extents, max_len_ext)

# Never pad beyond what the data actually needs; the common width is the larger part
max_len_int = min(max_len_int, original_max_len_int)
max_len_ext = min(max_len_ext, original_max_len_ext)
max_len = max(max_len_ext, max_len_int)

print('max_len_ext =',max_len_ext)
print('max_len_int =',max_len_int)
print('max_len =',max_len)

Length of the longest sequence: 6
The number of concepts with limited length is: 5247
Length of the longest sequence: 200
The number of concepts with limited length is: 5247
max_len_ext = 200
max_len_int = 6
max_len = 200


In [5]:
# Build the token -> index vocabulary from the training formal context
index_dic, obj_dic, att_dic = create_index_dic('BMS-POS-with-missing-part.txt')

print('The number of indices is ', len(index_dic))

The number of indices is  10839


In [6]:
# obj_dic / att_dic are filled in ascending id order, so their final entries hold the largest indices
last_item_obj = list(obj_dic.items())[-1]
max_index_obj = last_item_obj[-1]
print("The largest number in the object index is ",max_index_obj)

last_item_att = list(att_dic.items())[-1]
max_index_att = last_item_att[-1]
print("The largest number in the attribute index is ",max_index_att)

The largest number in the object index is  10002
The largest number in the attribute index is  11007


In [7]:
# Map every intent token to its vocabulary index; tokens missing from index_dic are dropped
org_intent_token_list = [
    [index_dic[token] for token in sequence.split() if token in index_dic]
    for sequence in truncated_intent
]

print(len(org_intent_token_list))

5247


In [8]:
# Map every extent token to its vocabulary index; tokens missing from index_dic are dropped
org_extent_token_list = [
    [index_dic[token] for token in sequence.split() if token in index_dic]
    for sequence in truncated_extent
]

print(len(org_extent_token_list))

5247


In [9]:
# The index sequences serve directly as positive samples (aliases, not copies)
org_ext_positive, org_int_positive = org_extent_token_list, org_intent_token_list
print(len(org_int_positive))

5247


In [10]:
def generate_negative_sequence(sequence, min_index, max_index, replace_ratio=.15):
    '''
    Return a corrupted copy of `sequence` to use as a negative sample.

    max(1, int(len(sequence) * replace_ratio)) *distinct* positions (capped at
    the sequence length) are chosen at random, and each is overwritten with a
    random integer in [min_index, max_index] (inclusive). The input is not
    modified.

    Positions are now drawn without replacement. Previously they were drawn
    with replacement, so the same position could be hit twice and fewer than
    the intended 15% of elements changed.

    Note: a replacement value may coincidentally equal the value it overwrites,
    so callers should still check that the result differs from the positives.

    An empty sequence yields an empty list (previously randint(0, -1) raised).
    '''
    new_sequence = list(sequence)
    length = len(new_sequence)
    if length == 0:
        return new_sequence

    num_elements_to_replace = min(length, max(1, int(length * replace_ratio)))
    for position in random.sample(range(length), num_elements_to_replace):
        new_sequence[position] = random.randint(min_index, max_index)

    return new_sequence

In [11]:
# Generate corrupted ("negative") extents: negatives_per_positive for every positive
# extent that has at least 2 objects. A candidate is redrawn while it equals any
# positive or already-accepted negative; those checks use a set of tuples
# (the previous list-of-lists `in` checks were O(n) per probe, O(n^2) overall).
negatives_per_positive = 1
# NOTE(review): hard-coded lower bound for replacement object indices; the
# smallest value in obj_dic may differ — confirm the intended range.
min_index_obj = 4
org_ext_negative = []
seen_ext_sequences = {tuple(seq) for seq in org_ext_positive}

for sequence in org_ext_positive:
    if len(sequence) < 2:
        continue
    for _rep in range(negatives_per_positive):
        new_sequence = generate_negative_sequence(sequence, min_index_obj, max_index_obj)
        while tuple(new_sequence) in seen_ext_sequences:
            new_sequence = generate_negative_sequence(sequence, min_index_obj, max_index_obj)
        seen_ext_sequences.add(tuple(new_sequence))
        org_ext_negative.append(new_sequence)
        

In [12]:
# Generate corrupted ("negative") intents: negatives_per_positive for every positive
# intent that has at least 2 attributes. Replacement values come from the attribute
# index range (just above the largest object index up to the largest attribute index).
# A candidate is redrawn while it equals any positive or already-accepted negative;
# those checks use a set of tuples instead of O(n) list-of-lists membership tests.
negatives_per_positive = 1
org_int_negative = []
min_index_att = max_index_obj + 1
seen_int_sequences = {tuple(seq) for seq in org_int_positive}

for sequence in org_int_positive:
    if len(sequence) < 2:
        continue
    for _rep in range(negatives_per_positive):
        new_sequence = generate_negative_sequence(sequence, min_index_att, max_index_att)
        while tuple(new_sequence) in seen_int_sequences:
            new_sequence = generate_negative_sequence(sequence, min_index_att, max_index_att)
        seen_int_sequences.add(tuple(new_sequence))
        org_int_negative.append(new_sequence)

print(len(org_int_negative))

print(org_int_negative[:1])

5076
[[10096, 10351]]


In [13]:
# Top up the intent negatives with [0] placeholders until they match the positives in count.
# NOTE(review): a [0] placeholder pads to an all-[PAD] row.
missing = len(org_int_positive) - len(org_int_negative)
org_int_negative.extend([0] for _ in range(missing))
print(len(org_int_negative))

5247


In [14]:
def pad_sequences_with_zeros(token_list, max_len):
    '''
    Right-pad each list in `token_list` with 0 ([PAD]) up to length `max_len`.

    Returns new lists; sequences already at or beyond `max_len` are copied
    unchanged (no truncation).
    '''
    return [seq + [0] * (max_len - len(seq)) for seq in token_list]

In [15]:
# Right-pad all four sample groups to the common width max_len.
# NOTE: relies on the row-wise pad_sequences_with_zeros defined just above; a later
# cell redefines that name with pair semantics, so re-running this cell after it
# produces different rows.
int_pos, int_neg, ext_pos, ext_neg = (
    pad_sequences_with_zeros(group, max_len)
    for group in (org_int_positive, org_int_negative, org_ext_positive, org_ext_negative)
)

for padded_group in (int_pos, int_neg, ext_pos, ext_neg):
    print(len(padded_group))

5247
5247
5247
5247


In [16]:
def generate_training_test_data(positive_token_list, negative_sample_list):
    '''
    Concatenate positives and negatives into one sample list with 1/0 labels.

    Returns (samples, labels): samples is positives followed by negatives
    (a new outer list sharing the same inner sequences); labels holds 1 for
    each positive and 0 for each negative, in the same order.
    '''
    samples = positive_token_list + negative_sample_list
    labels = [1 for _ in positive_token_list] + [0 for _ in negative_sample_list]

    # Sanity check: one label per sample
    assert len(samples) == len(labels)

    return samples, labels

In [17]:
# Attach 1/0 labels to the padded positive/negative extents and intents
train_extent_set, train_extent_labels = generate_training_test_data(ext_pos, ext_neg)
train_intent_set, train_intent_labels = generate_training_test_data(int_pos, int_neg)

In [18]:
# Concatenate the i-th extent row with the i-th intent row into one concept row.
# NOTE(review): negatives were generated separately for extents and intents (and short
# intents were skipped, then topped up with [0]), so the i-th negative extent and
# intent need not come from the same concept — confirm this pairing is intended.
train_concept_set = []
for i, intent_row in enumerate(train_intent_set):
    train_concept_set.append(train_extent_set[i] + intent_row)

In [19]:
# Concept labels reuse the extent labels (they coincide with the intent labels
# whenever both parts have equal positive/negative counts)
train_concept_labels = train_extent_labels
for collection in (train_concept_set, train_concept_labels):
    print(len(collection))

10494
10494


## Training Set-Context 

In [20]:
# The function to process formal context file.
# NOTE: this redefines the earlier process_context (which returned four lists); from
# this cell on, the name returns a single list of "o<object> a<attribute>" strings.
def process_context(file_name):
    '''
    Read incidences from a formal-context file.

    Each line with exactly two whitespace-separated fields "<obj> <att>"
    becomes the string "o<obj> a<att>"; other lines are skipped.
    Returns the strings in file order.
    '''
    with open(file_name, 'r') as file:
        rows = [line.strip().split() for line in file]
    return ['o{} a{}'.format(*row) for row in rows if len(row) == 2]

# Function to generate negative samples
def generate_negative_samples(number_of_samples, existing_token_list, obj_dic, att_dic, index_dic):
    '''
    Draw `number_of_samples` distinct random [object index, attribute index]
    pairs that do not occur in `existing_token_list`.

    Parameters
    ----------
    number_of_samples : int
        How many negative pairs to return.
    existing_token_list : iterable of sequences
        Pairs that must not be produced (compared as tuples).
    obj_dic, att_dic : dict
        token -> index maps; candidates are drawn from their values.
    index_dic : dict
        Unused; kept so existing calls keep working.

    Returns
    -------
    list of [obj_index, att_index] lists, with no duplicates.

    Raises
    ------
    ValueError
        If fewer than `number_of_samples` unseen pairs exist; the sampling loop
        could otherwise never terminate.
    '''
    # Hoisted out of the loop (they were rebuilt on every iteration before)
    object_indices = list(obj_dic.values())
    attribute_indices = list(att_dic.values())

    # Create a set of existing sequences for faster lookup
    existing_sequences = set(tuple(seq) for seq in existing_token_list)

    # Size of the candidate space that is still free
    distinct_objects = set(object_indices)
    distinct_attributes = set(attribute_indices)
    blocked = sum(
        1 for seq in existing_sequences
        if len(seq) == 2 and seq[0] in distinct_objects and seq[1] in distinct_attributes
    )
    available = len(distinct_objects) * len(distinct_attributes) - blocked
    if number_of_samples > available:
        raise ValueError(
            'Requested {} negative samples but only {} unseen pairs exist'.format(
                number_of_samples, available))

    negative_sample_list = []
    # Accepted negatives as tuples. The old check compared a tuple against a list
    # of lists, which never matched, so duplicate negatives slipped through.
    accepted = set()
    while len(negative_sample_list) < number_of_samples:
        pair = (random.choice(object_indices), random.choice(attribute_indices))
        if pair in existing_sequences or pair in accepted:
            continue
        accepted.add(pair)
        negative_sample_list.append(list(pair))

    return negative_sample_list

def generate_training_test_data(positive_token_list, negative_sample_list):
    '''
    Concatenate positives and negatives into one sample list with 1/0 labels.

    NOTE: identical to the definition in an earlier cell; this copy simply
    re-binds the same behaviour.

    Returns (samples, labels): positives followed by negatives, and a label
    list with 1 per positive and 0 per negative in the same order.
    '''
    samples = positive_token_list + negative_sample_list
    labels = [1 for _ in positive_token_list] + [0 for _ in negative_sample_list]

    # Sanity check: one label per sample
    assert len(samples) == len(labels)

    return samples, labels

In [21]:
# Read the training context as "o<obj> a<att>" incidence strings
FT_incidence = process_context('BMS-POS-with-missing-part.txt')

print('The number of incidences in training set is :',len(FT_incidence))

The number of incidences in training set is : 63421


In [22]:
# Count the distinct object/attribute tokens that appear in the training incidences
unique_elements = {token for seq in FT_incidence for token in seq.split()}

num_unique_elements = len(unique_elements)
print("Number of different elements:", num_unique_elements)

Number of different elements: 10837


In [23]:
# Convert each training incidence into an [object index, attribute index] pair,
# keeping only pairs whose two tokens are both in the vocabulary
FT_positive_token_list = []
for incidence in FT_incidence:
    tokens = incidence.split()
    obj_token, att_token = tokens[0], tokens[1]
    if obj_token in index_dic and att_token in index_dic:
        FT_positive_token_list.append([index_dic[obj_token], index_dic[att_token]])

print('The number of positive data is :', len(FT_positive_token_list))

The number of positive data is : 63421


In [24]:
# Draw as many random unseen (object, attribute) pairs as there are positives
FT_negative_sample_list = generate_negative_samples(
    len(FT_positive_token_list), FT_positive_token_list, obj_dic, att_dic, index_dic)
print('The number of negative data is :', len(FT_negative_sample_list))

# Merge positives and negatives with 1/0 labels
org_train_context_set, train_context_labels = generate_training_test_data(
    FT_positive_token_list, FT_negative_sample_list)
print("Length of train_context_set:", len(org_train_context_set))
print("Length of train_context_labels:", len(train_context_labels))

The number of negative data is : 63421
Length of train_context_set: 126842
Length of train_context_labels: 126842


In [25]:
def pad_sequences_with_zeros(context_token, max_len):
    '''
    Pad [obj, att, ...] sequences into two zero-filled halves.

    NOTE: this redefines the earlier row-wise pad_sequences_with_zeros with
    different semantics.

    Each sequence becomes
        seq[:1] + [0] * (max_len - 1) + seq[1:] + [0] * (max_len - 1)
    so a pair [obj, att] turns into a row of length 2 * max_len: the object
    followed by padding, then the attribute followed by padding.
    '''
    filler = [0] * (max_len - 1)
    return [seq[:1] + filler + seq[1:] + filler for seq in context_token]

In [26]:
# Lay out each [obj, att] pair as an object half and an attribute half (pair padding above)
train_context_set = pad_sequences_with_zeros(org_train_context_set, max_len)

# Only the context pairs are used for training; the concept samples built earlier are not mixed in
train_set = train_context_set
train_labels = train_context_labels

print("Length of train set:", len(train_set))
print("Length of train labels:", len(train_labels))

Length of train set: 126842
Length of train labels: 126842


## TEST Set-Context

In [27]:
# Read the test context (BMS-POS.txt) as "o<obj> a<att>" incidence strings
TEST_incidence = process_context('BMS-POS.txt')

# Convert test incidences to index pairs, keeping only those whose object and
# attribute both exist in the vocabulary built from the training context
test_token_list = []
for incidence in TEST_incidence:
    tokens_test = incidence.split()
    obj_token, att_token = tokens_test[0], tokens_test[1]
    if obj_token in index_dic and att_token in index_dic:
        test_token_list.append([index_dic[obj_token], index_dic[att_token]])

In [28]:
# Hashable views of the training and test pairs
FT_token_set = set(tuple(seq) for seq in FT_positive_token_list)
test_token_set = set(tuple(seq) for seq in test_token_list)

# Test positives: pairs present in the test context but absent from the training context
test_positive_samples = [list(seq) for seq in (test_token_set - FT_token_set)]
# Every pair of the test context is excluded when drawing negatives
all_positive_samples = [list(seq) for seq in test_token_set]

test_negative_samples = generate_negative_samples(
    len(test_positive_samples), all_positive_samples, obj_dic, att_dic, index_dic)

org_test_set, test_labels = generate_training_test_data(test_positive_samples, test_negative_samples)

test_set = pad_sequences_with_zeros(org_test_set, max_len)
print("Length of test_set:", len(org_test_set))
print("Length of test_labels:", len(test_labels))

Length of test_set: 13832
Length of test_labels: 13832


## Predictive Model

## Transformer Encoder

In [29]:
# Attention geometry
n_heads = 12                 # number of attention heads
d_k = d_v = 64               # per-head width of keys/queries (shared) and values
d_model = n_heads * d_k      # embedding width: 12 * 64 = 768
d_ff = d_model * 4           # hidden width of the feed-forward sub-layer

# Stack configuration
n_layers = 9                 # number of encoder layers
n_segs = 1                   # number of input segments (sentences)
p_dropout = .1

$$
\displaylines{
\operatorname{GELU}(x)=x P(X \leq x)= x \Phi(x)=x \cdot \frac{1}{2}[1+\operatorname{erf}(x / \sqrt{2})] \\
 \text{or} \\
0.5 x\left(1+\tanh \left[\sqrt{2 / \pi}\left( x+ 0.044715 x^{3}\right)\right]\right)
}
$$

In [30]:
def gelu(x):
    '''
    Exact GELU activation: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))).

    A common alternative is the tanh approximation
        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)));
    this function uses the erf form.
    '''
    inv_sqrt2_scaled = x / msqrt(2.)
    return .5 * x * (1. + torch.erf(inv_sqrt2_scaled))

# Build an attention mask that hides padding tokens in a batch of sequences
def get_pad_mask(tokens, pad_idx=0):
    '''
    tokens: [batch, seq_len] index tensor; [PAD] is assumed to be `pad_idx` (0).

    Returns a bool tensor of shape [batch, seq_len, seq_len] where entry
    [b, i, j] is True when key position j of sequence b is padding; the same
    row is repeated for every query position i (via expand, no copy).
    '''
    batch, seq_len = tokens.size()
    is_pad = tokens.data.eq(pad_idx)                     # [batch, seq_len]
    return is_pad.unsqueeze(1).expand(batch, seq_len, seq_len)

In [31]:
# Turn input token indices into dense vectors before they enter the encoder stack
class Embeddings(nn.Module):
    '''
    Token-embedding front end: word embedding -> LayerNorm -> dropout.

    seg_emb and pos_emb are created (so they appear in parameters() and the
    state_dict), but forward() uses only the word embedding.
    '''
    def __init__(self,max_vocab, max_len):
        super().__init__()
        self.seg_emb = nn.Embedding(n_segs, d_model)
        # One row per vocabulary index; objects and attributes share this table
        self.word_emb = nn.Embedding(max_vocab, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x):
        '''
        x: [batch, seq_len] indices -> [batch, seq_len, d_model].
        '''
        embedded = self.word_emb(x)
        return self.dropout(self.norm(embedded))
        # return: [batch, seq_len, d_model]

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$

$$
\begin{aligned}
\operatorname{MultiHead}(Q, K, V) &= \operatorname{Concat}(\text{head}_1, \text{head}_2, \dots, \text{head}_h)W^O \\
\text{where } \text{head}_i &= \operatorname{Attention}(QW^Q_i, KW^K_i, VW^V_i)
\end{aligned}
$$

In [32]:
class ScaledDotProductAttention(nn.Module):
    '''
    softmax(Q K^T / sqrt(d_k)) V, where masked score positions are set to -1e9
    before the softmax so they receive (near) zero weight.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask):
        scaled_keys_t = K.transpose(-1, -2) / msqrt(d_k)
        scores = torch.matmul(Q, scaled_keys_t)        # [batch, n_heads, seq_len, seq_len]
        scores.masked_fill_(attn_mask, -1e9)
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, V)                # [batch, n_heads, seq_len, d_v]

class MultiHeadAttention(nn.Module):
    '''
    Multi-head attention: project Q/K/V, attend per head, merge heads, project back.
    '''
    def __init__(self):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

    def forward(self, Q, K, V, attn_mask):
        '''
        Q, K, V: [batch, seq_len, d_model]
        attn_mask: [batch, seq_len, seq_len]
        returns: [batch, seq_len, d_model]
        '''
        batch = Q.size(0)

        def to_heads(projected, width):
            # [batch, seq_len, n_heads * width] -> [batch, n_heads, seq_len, width]
            return projected.view(batch, -1, n_heads, width).transpose(1, 2)

        per_Q = to_heads(self.W_Q(Q), d_k)
        per_K = to_heads(self.W_K(K), d_k)
        per_V = to_heads(self.W_V(V), d_v)

        # Same mask for every head: [batch, n_heads, seq_len, seq_len]
        head_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
        context = ScaledDotProductAttention()(per_Q, per_K, per_V, head_mask)

        # Merge heads back: [batch, seq_len, n_heads * d_v]
        merged = context.transpose(1, 2).contiguous().view(batch, -1, n_heads * d_v)
        return self.fc(merged)

$$\operatorname{FFN}(x)=\operatorname{GELU}(xW_1+b_1)W_2+b_2$$

In [33]:
class FeedForwardNetwork(nn.Module):
    '''
    Position-wise feed-forward sub-layer: fc1 -> dropout -> GELU -> fc2
    (d_model -> d_ff -> d_model). Dropout is applied before the activation.
    '''
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(p_dropout)
        self.gelu = gelu

    def forward(self, x):
        hidden = self.gelu(self.dropout(self.fc1(x)))
        return self.fc2(hidden)

In [34]:
# Encoder

class EncoderLayer(nn.Module):
    '''
    Pre-norm transformer encoder layer:
        x = x + SelfAttention(LayerNorm(x))
        x = x + FFN(LayerNorm(x))
    '''
    def __init__(self):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.enc_attn = MultiHeadAttention()
        self.ffn = FeedForwardNetwork()

    def forward(self, x, pad_mask):
        '''
        x: [batch, seq_len, d_model]; pad_mask: [batch, seq_len, seq_len]
        '''
        normed = self.norm1(x)
        x = x + self.enc_attn(normed, normed, normed, pad_mask)
        return x + self.ffn(self.norm2(x))

In [35]:
# Pooling head used for sequence-level classification (next-sentence-prediction style).
# A fully connected layer here improves the results but makes the model harder to train.
class Pooler(nn.Module):
    '''
    Map a [batch, d_model] summary vector through Linear(d_model, d_model) and tanh.
    '''
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(d_model, d_model)
        self.tanh = nn.Tanh()

    def forward(self, x):
        '''
        x: [batch, d_model] -> [batch, d_model]
        '''
        return self.tanh(self.fc(x))

In [43]:
class Encoder(nn.Module):
    '''
    Embeddings -> n_layers pre-norm encoder layers -> mean over positions -> Pooler.

    forward(tokens: [batch, seq_len]) returns [batch, d_model].
    '''
    def __init__(self, n_layers, max_vocab, max_len):
        super().__init__()
        self.embedding = Embeddings(max_vocab, max_len)
        self.encoders = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.pooler = Pooler()

    def forward(self, tokens):
        hidden = self.embedding(tokens)
        pad_mask = get_pad_mask(tokens)
        for layer in self.encoders:
            hidden = layer(hidden, pad_mask)
        # hidden: [batch, seq_len, d_model]
        # Sequence summary = mean over all positions (padding positions included),
        # not the [CLS] position.
        return self.pooler(hidden.mean(dim=1))

## Input Data Preparation

In [45]:
def make_data(concepts, index_dic, max_vocab, max_len):
    '''
    Prefix every token sequence with the [CLS] index and shuffle the result.

    Each entry is wrapped in a one-element list: [[CLS, t1, t2, ...]].
    max_vocab and max_len are accepted but not used. The shuffle uses the
    module-level `random` generator.
    '''
    batch_data = [[[index_dic['[CLS]']] + concept_tokens] for concept_tokens in concepts]
    random.shuffle(batch_data)
    return batch_data


In [46]:
# Vocabulary size = largest index + 1 (indices run from 0 up to the largest index)
max_index = max(index_dic.values())
print("The largest number in the index is ",max_index)

max_vocab = max_index + 1

# One extra position (presumably for the [CLS] token that make_data prepends).
# Recomputed from the per-part maxima instead of `max_len = max_len + 1`, so
# re-running this cell no longer grows max_len each time.
max_len = max(max_len_ext, max_len_int) + 1

The largest number in the index is  11007


In [47]:
# Run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

##  MLP for classification task

In [48]:
# MLP head: encode objects and attributes separately, concatenate, classify
class MLP(nn.Module):
    '''
    Binary classifier over an (objects, attributes) pair.

    encoder_model1 encodes `objs` and encoder_model2 encodes `attrs`; their
    outputs are concatenated on dim 1 (presumably [batch, embedding_size] each),
    then passed through Linear -> ReLU -> Dropout -> Linear -> Sigmoid.
    Returns probabilities of shape [batch, output_size].
    '''
    def __init__(self, encoder_model1, encoder_model2, embedding_size, hidden_size, output_size, dropout_rate = .1):
        super().__init__()

        self.encoder1 = encoder_model1
        self.encoder2 = encoder_model2

        self.fc1 = nn.Linear(embedding_size * 2, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, objs, attrs):
        pair_repr = torch.cat((self.encoder1(objs), self.encoder2(attrs)), dim=1)
        hidden = self.dropout(self.relu(self.fc1(pair_repr)))
        return self.sigmoid(self.fc2(hidden))

# Training & Test

In [49]:
# Set parameters
hidden_size = 512
output_size = 1
learning_rate = 9e-6
num_epochs = 120
batch_size = 32

# Two freshly initialised encoders (one for objects, one for attributes);
# nothing is loaded from disk here.
model1 = Encoder(n_layers, max_vocab, max_len)
model2 = Encoder(n_layers, max_vocab, max_len)

model1.train().to(device)
model2.train().to(device)

# Instantiate the model, loss function, and optimizer
MLP_model = MLP(model1, model2, d_model, hidden_size, output_size, dropout_rate=0.1).to(device)
criterion = nn.BCELoss()
optimizer = Adam(MLP_model.parameters(), lr=learning_rate)

# Tensors get their own names so the train_labels / test_labels lists are not
# overwritten; the cell can therefore be re-run safely.
train_inputs = torch.tensor(train_set).to(device)
train_label_tensor = torch.tensor(train_labels).to(device)

train_dataset = TensorDataset(train_inputs, train_label_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_inputs = torch.tensor(test_set).to(device)
test_label_tensor = torch.tensor(test_labels).to(device)

# Evaluation order does not affect the metrics, so the test loader is not shuffled
test_dataset = TensorDataset(test_inputs, test_label_tensor)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [51]:
def split_obj_att(batch):
    '''
    Split padded pair rows into (object half, attribute half).

    Rows produced by the pair version of pad_sequences_with_zeros look like
    [obj, 0, ..., 0, att, 0, ..., 0], with both halves equally wide, so the
    boundary is half the row width.

    The previous split point `max_len + 1` was wrong: max_len is incremented
    after the rows are padded. As a result the attribute token ended up in the
    object half, and the attribute encoder only ever saw padding.
    '''
    half = batch.size(1) // 2
    return batch[:, :half], batch[:, half:]


for epoch in range(num_epochs):
    # ======================== Training =====================================
    # .train() propagates to the submodules, including encoder1 / encoder2
    MLP_model.train()

    running_loss = 0.0
    n_seen = 0
    for inputs, labels in train_loader:
        optimizer.zero_grad()

        objs, attrs = split_obj_att(inputs)
        outputs = MLP_model(objs, attrs)
        loss = criterion(outputs, labels.unsqueeze(1).float())
        loss.backward()
        optimizer.step()

        # Sample-weighted running sum, so the printed value is the epoch mean
        # rather than the last batch's loss
        running_loss += loss.item() * inputs.size(0)
        n_seen += inputs.size(0)

    epoch_loss = running_loss / max(n_seen, 1)
    print(f'Epoch:{epoch + 1} \t loss: {epoch_loss:.3f}')

    # ======================== Running test case =====================================
    # .eval() also switches encoder1 / encoder2 to evaluation mode
    MLP_model.eval()

    all_scores = []
    all_labels = []

    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            # Use the *test* batch (the old code split the last training batch `inputs`)
            objs, attrs = split_obj_att(batch_inputs)

            # Keep the probabilities: AUC / AUPR need continuous scores
            test_outputs = MLP_model(objs, attrs)
            all_scores.extend(test_outputs.squeeze(1).cpu().numpy())
            all_labels.extend(batch_labels.cpu().numpy())

    all_scores = np.array(all_scores)
    all_labels = np.array(all_labels)

    # Threshold only for the label-based metrics
    predictions_binary = (all_scores > 0.5).astype(int)

    accuracy = accuracy_score(all_labels, predictions_binary)
    precision = precision_score(all_labels, predictions_binary, zero_division=0)
    recall = recall_score(all_labels, predictions_binary)
    f1 = f1_score(all_labels, predictions_binary, zero_division=0)
    auc = roc_auc_score(all_labels, all_scores)
    aupr = average_precision_score(all_labels, all_scores)

    # Print the results
    print(f'Test Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}, AUC: {auc:.3f}, AUPR: {aupr:.3f}')