In [1]:
import os
# Move the working directory one level up (presumably to the project root, so
# that the `utils` / `models` packages and "./data/" resolve -- confirm layout).
# Guarded on `parent_path` so that re-running this cell in the same kernel does
# not climb one directory higher on every execution.
if "parent_path" not in globals():
    parent_path = os.path.dirname(os.getcwd())
    os.chdir(parent_path)

In [2]:
from utils.dataloader import GraphTextDataset, GraphDataset, TextDataset
from torch_geometric.data import DataLoader
# from models.Model import BaseModel
from models.model2 import GAT_MLP
# from models.model3_transfert_learning import GAT_MLP_TL
import numpy as np
from transformers import AutoTokenizer
import torch
from torch import optim
import time
import pandas as pd
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils.variables import ROOT_DIR
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors

# Ask PyTorch to release cached GPU memory that is currently unused.
# NOTE(review): on a fresh kernel nothing has been allocated yet, so this
# presumably only matters when the notebook is re-run in the same kernel.
torch.cuda.empty_cache()

In [3]:
# Draw a fixed-seed random subset of 6000 training rows and write it back as a
# headerless TSV so it can be loaded as its own split ('sample_train').
train = pd.read_csv(ROOT_DIR + '/data/train.tsv', header=None, sep='\t')
sample_train = train.sample(n=6000, random_state=42)
sample_train.to_csv(
    ROOT_DIR + '/data/sample_train.tsv',
    index=False,
    header=False,
    sep='\t',
)

# Tokenizer matching the pretrained text encoder.
model_name = 'allenai/scibert_scivocab_uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The .npy file holds a pickled Python object wrapped in an array; `[()]`
# unwraps it (presumably a token -> embedding dict, judging by the file name).
# NOTE(review): allow_pickle=True can execute arbitrary code on load -- only
# use it with a trusted file.
gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]

# Both splits share everything except the split name.
dataset_kwargs = dict(root='./data/', gt=gt, tokenizer=tokenizer)
val_dataset = GraphTextDataset(split='val', **dataset_kwargs)
train_dataset = GraphTextDataset(split='sample_train', **dataset_kwargs)

In [4]:
# Rebuild the GAT_MLP architecture and restore the trained weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location=device lets a checkpoint saved on GPU load on a CPU-only
# machine, consistent with the device fallback above.
# weights_only=True restricts unpickling to tensors and plain containers; the
# checkpoint is passed straight to load_state_dict below, so nothing else is
# needed, and it avoids the FutureWarning about the unsafe default.
checkpoint = torch.load(
    ROOT_DIR + '/logs/model_20241127_171613.pt',
    map_location=device,
    weights_only=True,
)
model = GAT_MLP(model_name=model_name, num_node_features=300, nout=768, nhid=300, graph_hidden_channels=300)
model.load_state_dict(checkpoint)
# Last expression: moves the model to `device` and displays its structure.
model.to(device)

  checkpoint = torch.load(ROOT_DIR + '/logs/model_20241127_171613.pt')
2024-11-30 20:25:47.798755: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1732994747.809716 1295178 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732994747.813018 1295178 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-30 20:25:47.823932: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


GAT_MLP(
  (graph_encoder): GraphEncoder(
    (relu): ReLU()
    (ln): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
    (conv1): GATConv(300, 300, heads=1)
    (conv2): GATConv(300, 300, heads=1)
    (conv3): GATConv(300, 300, heads=1)
    (mol_hidden1): Linear(in_features=300, out_features=300, bias=True)
    (mol_hidden2): Linear(in_features=300, out_features=768, bias=True)
  )
  (text_encoder): TextEncoder(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(31090, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_featur

In [5]:
# Embed every validation sample (graph and text) with the trained model.
#
# shuffle=False is required: the embeddings are re-indexed with `texts.index`
# below, so rows must come out of the loader in dataset order. With shuffling,
# each embedding would be paired with an unrelated description.
data_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

model.eval()

graph_chunks = []
text_chunks = []
# Inference only: no_grad avoids building the autograd graph for each batch.
with torch.no_grad():
    for batch in data_loader:
        # Split the text tensors off the batch so that what remains is the
        # graph-only batch passed as the model's first argument.
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch

        x_graph, x_text = model(graph_batch.to(device),
                                input_ids.to(device),
                                attention_mask.to(device))

        # Keep results on the CPU so GPU memory is freed between batches.
        graph_chunks.append(x_graph.cpu())
        text_chunks.append(x_text.cpu())

graphs_embedding = torch.cat(graph_chunks, dim=0)
texts_embedding = torch.cat(text_chunks, dim=0)

# NOTE(review): assumes the rows of `val_dataset.description` follow the same
# order as the dataset indices -- confirm against the dataset class.
texts = pd.DataFrame(val_dataset.description)
graph_df = pd.DataFrame(graphs_embedding.numpy(), columns=[f'graph_emb_{i}' for i in range(graphs_embedding.shape[1])], index=texts.index)
text_df = pd.DataFrame(texts_embedding.numpy(), columns=[f'text_emb_{i}' for i in range(texts_embedding.shape[1])], index=texts.index)

  data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(self.idx_to_cid[idx])))


In [6]:
# Column-wise z-score: subtract each embedding dimension's mean and divide by
# its standard deviation (pandas' default sample std).
# NOTE(review): a constant column has std 0 and would yield NaN/inf here.
graphs_embedding_normalized = graph_df.sub(graph_df.mean()).div(graph_df.std())
texts_embedding_normalized = text_df.sub(text_df.mean()).div(text_df.std())

In [7]:
# Project the standardised graph embeddings onto their first 30 principal
# components before clustering.
pca = PCA(random_state=42, n_components=30)
pca_graphs = pca.fit_transform(graphs_embedding_normalized)

# Density-based clustering in the reduced space; `eps` and `min_samples` are
# hand-picked values.
eps, min_samples = 13, 50
dbscan = DBSCAN(min_samples=min_samples, eps=eps)
dbscan.fit(pca_graphs)

In [10]:
# One cluster id per row of `pca_graphs`, as assigned by the fitted DBSCAN;
# scikit-learn marks noise points with the label -1.
labels = dbscan.labels_

In [36]:
# Mean silhouette coefficient of the DBSCAN partition in the 30-component PCA
# space (range [-1, 1]; values near 0 indicate overlapping clusters).
# NOTE(review): `labels` still contains DBSCAN noise points (label -1), which
# are scored here as if they formed a cluster of their own -- consider masking
# them out before interpreting this number. TODO confirm with the author.
from sklearn.metrics import silhouette_score
silhouette_score(pca_graphs, labels)

0.006943491

## LDA topic modelling of the text descriptions in one DBSCAN cluster

In [34]:
from gensim.corpora.dictionary import Dictionary
from gensim.models.ldamodel import LdaModel
from nltk.corpus import stopwords

# Text descriptions (column 1 of `texts`) of the samples that DBSCAN assigned
# to cluster 1.
documents = texts.loc[labels == 1, 1]

# Strip punctuation. The pattern must be a raw string (otherwise "\w" is an
# invalid escape sequence) and regex=True must be explicit: since pandas 2.0
# `str.replace` treats the pattern as a literal by default, which would leave
# all punctuation in place.
documents = documents.str.replace(r'[^\w\s]', '', regex=True)

# NLTK's English stop-word list is lowercase, so tokens are lowercased for the
# comparison only; this also catches capitalised stop words at sentence starts.
stop_words = set(stopwords.words('english'))

# Extra tokens to drop verbatim (case-sensitive).
to_remove = ['It']

documents = documents.apply(
    lambda text: ' '.join(
        word for word in text.split()
        if word.lower() not in stop_words and word not in to_remove
    )
)

# Bag-of-words corpus in the format gensim expects.
tokenized_docs = [doc.split() for doc in documents]
dictionary = Dictionary(tokenized_docs)
corpus = [dictionary.doc2bow(doc) for doc in tokenized_docs]

# Train LDA model (2 topics, fixed seed for reproducibility)
lda_model = LdaModel(corpus=corpus, id2word=dictionary, num_topics=2, random_state=42)

  documents = documents.str.replace('[^\w\s]','')


In [35]:
import pyLDAvis
import pyLDAvis.gensim_models as gensimvis
from pyLDAvis import enable_notebook

# Switch pyLDAvis to notebook rendering mode.
enable_notebook()

# Build the interactive topic visualisation from the trained LDA model and
# write it to a standalone HTML file in the current working directory.
vis = gensimvis.prepare(topic_model=lda_model, corpus=corpus, dictionary=dictionary)
lda_html_path = 'lda_vis_cluster1.html'
pyLDAvis.save_html(vis, lda_html_path)