In [1]:
from sentence_transformers import SentenceTransformer
import torch
from torch_geometric.nn import GCNConv  # or GraphSAGE, GAT, etc.
from torch_geometric.data import Data
import numpy as np
import pandas as pd

import pickle

import json
import os

from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda:1"

In [38]:
from sentence_transformers import SentenceTransformer


# 1. Load a pretrained Sentence Transformer model
sentTransformer = SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb').to("cuda:1")

  return self.fget.__get__(instance, owner)()


In [3]:
df = pd.read_pickle("./publications_clustered_kmeans.pkl")

In [None]:
embeddings = []

for index, row in df.iterrows():
    # pubIdToCommunity[int(row["node_id"])] = int(row["community_id"])
    sentence = row["title"] + ". " + str(row["abstract"])
    embeddings.append(sentTransformer.encode([sentence]))


In [6]:
embeddings = np.array(embeddings).squeeze(axis=1)
np.save("./encodedTrainFeatures.npy", embeddings)

In [4]:
embeddings = np.load("./encodedTrainFeatures.npy")

In [5]:
allIds = []


for index, row in df.iterrows():
    allIds.append(int(row["publication_ID"]))

sortedIds = sorted(allIds)
pubIdToId = {id: idx for idx, id in enumerate(sortedIds)}

In [6]:
allCitations = []

for index, row in df.iterrows():
    # print(doc["title"])
    # print(doc["abstract"])
    # print(idx)

    # print(doc["Citations"])
    # print(idx)
    if type(row["Citations"]) == int:
        continue
    citations = row["Citations"].split(";")
    docId = int(row["publication_ID"])
    # print(citations)
    for citation in citations:
        if citation == "nan" or int(citation) not in pubIdToId.keys():
            continue
        tup = (pubIdToId[docId], pubIdToId[int(citation)])
        # print(pubIdToId[docId])
        # print(pubIdToId[citation])
        # print(tup)
        allCitations.append((pubIdToId[docId], pubIdToId[int(citation)]))

In [7]:
allCommunities = df["cluster"].to_list()
    

In [8]:
y = torch.tensor(allCommunities, dtype=torch.long)

In [9]:
data = Data(
    x=torch.tensor(embeddings),         # [num_nodes, embedding_dim]
    edge_index=torch.tensor(allCitations).T,  # [2, num_edges]
    y=y,                    # [num_nodes]
).to(device)

In [None]:
class GNNModel(torch.nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.gcn1 = GCNConv(768, 256)
        self.gcn2 = GCNConv(256, 128)
        self.classifier = torch.nn.Linear(128, num_classes)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(0.3)

    def forward(self, x, edge_index):
        x = self.relu(self.gcn1(x, edge_index))
        x = self.dropout(x)
        x = self.gcn2(x, edge_index)
        return self.classifier(x)

In [11]:
model = GNNModel(200).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Assuming `data` is your PyG Data object (with .x, .edge_index, .y, .train_mask)
model.train()
for epoch in range(100):
    optimizer.zero_grad()

    out = model(data.x, data.edge_index)             # forward pass
    loss = torch.nn.functional.cross_entropy(out, data.y)  # compute loss
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


Epoch 0, Loss: 5.3493
Epoch 1, Loss: 5.2673
Epoch 2, Loss: 5.1990
Epoch 3, Loss: 5.1323
Epoch 4, Loss: 5.0644
Epoch 5, Loss: 4.9896
Epoch 6, Loss: 4.9067
Epoch 7, Loss: 4.8168
Epoch 8, Loss: 4.7192
Epoch 9, Loss: 4.6135
Epoch 10, Loss: 4.5044
Epoch 11, Loss: 4.3912
Epoch 12, Loss: 4.2680
Epoch 13, Loss: 4.1537
Epoch 14, Loss: 4.0248
Epoch 15, Loss: 3.9033
Epoch 16, Loss: 3.7802
Epoch 17, Loss: 3.6558
Epoch 18, Loss: 3.5314
Epoch 19, Loss: 3.4143
Epoch 20, Loss: 3.3005
Epoch 21, Loss: 3.1897
Epoch 22, Loss: 3.0818
Epoch 23, Loss: 2.9803
Epoch 24, Loss: 2.8826
Epoch 25, Loss: 2.7881
Epoch 26, Loss: 2.7016
Epoch 27, Loss: 2.6131
Epoch 28, Loss: 2.5364
Epoch 29, Loss: 2.4617
Epoch 30, Loss: 2.4013
Epoch 31, Loss: 2.3368
Epoch 32, Loss: 2.2829
Epoch 33, Loss: 2.2210
Epoch 34, Loss: 2.1692
Epoch 35, Loss: 2.1314
Epoch 36, Loss: 2.0780
Epoch 37, Loss: 2.0416
Epoch 38, Loss: 2.0011
Epoch 39, Loss: 1.9699
Epoch 40, Loss: 1.9395
Epoch 41, Loss: 1.9134
Epoch 42, Loss: 1.8781
Epoch 43, Loss: 1.853

In [21]:
out.shape

torch.Size([42000, 128])

In [13]:
with torch.no_grad():
    x = model.relu(model.gcn1(data.x, data.edge_index))
    newEmbeddings = model.gcn2(x, data.edge_index)

In [14]:
newEmbeddings.shape

torch.Size([42000, 128])

In [None]:
import torch.nn.functional as F


In [40]:
query_embedding = newEmbeddings[0].unsqueeze(0)  # shape: [1, 128]

cosine_sim = F.cosine_similarity(query_embedding, newEmbeddings, dim=1)  # shape: [N]

topk = torch.topk(cosine_sim, k=5)  # top 5 most similar articles
topk_indices = topk.indices
topk_scores = topk.values


In [47]:
print(topk_scores)
topk_indices

tensor([1.0000, 0.9317, 0.9231, 0.9077, 0.8981])


tensor([    0, 27424,  4370, 19191, 29419])

In [42]:
df.iloc[0]

publication_ID                                             17396995
Citations         17957262;21818356;24164861;21818356;24164861;2...
pubDate                                                  2007 May 1
language                                                        eng
title             Herpes simplex virus type 2 infection does not...
journal                          The Journal of infectious diseases
abstract          We sought to compare baseline and longitudinal...
keywords          Adult;California;epidemiology;Cohort Studies;H...
authors           Edward R Cachay; Simon D W Frost; Douglas D Ri...
venue                {'name': 'The Journal of infectious diseases'}
doi                                                  10.1086/513568
combined_text     Herpes simplex virus type 2 infection does not...
embedding         [0.025286998599767685, -0.01651591807603836, 0...
cluster                                                           6
Name: 0, dtype: object

In [46]:
idx = 0
print(df.iloc[idx]["title"])
print(df.iloc[idx]["abstract"])
print(df.iloc[idx]["cluster"])



Herpes simplex virus type 2 infection does not influence viral dynamics during early HIV 1 infection
We sought to compare baseline and longitudinal plasma HIV-1 loads between herpes simplex virus type 2 (HSV-2)-seropositive and -seronegative individuals who are enrolled in a primary HIV-1 infection cohort in San Diego, California.
6


In [50]:
idx = 4370

print(df.iloc[idx]["title"])
print(df.iloc[idx]["abstract"])
print(df.iloc[idx]["cluster"])


Herpes simplex virus type 2 acquisition during recent HIV infection does not influence plasma HIV levels
We assessed the effect of herpes simplex virus type 2 (HSV-2) acquisition on the plasma HIV RNA and CD4 cell levels among individuals with primary HIV infection using a retrospective cohort analysis. We studied 119 adult, antiretroviral-naive, recently HIV-infected men with a negative HSV-2-specific enzyme immunoassay (EIA) result at enrollment. HSV-2 acquisition was determined by seroconversion on HSV-2 EIA, confirmed by Western blot analysis. Ten men acquired HSV-2 infection a median of 1.3 years after HIV infection (HSV-2 incidence rate of 7.4 per 100 person-years of follow-up). The median time of follow-up after acquiring HSV-2 infection was 303 days. All men except 1 were asymptomatic during HSV-2 acquisition, and only 1 HSV-2 seroconverter, who was asymptomatic, had a transient increase in blood HIV load (0.5 log10 copies/mL over 11 days). The HSV-2 incidence rate was high in 

In [34]:
def info_nce_loss(proj, gcn, temperature=0.1):
    proj = torch.nn.functional.normalize(proj, dim=1)
    gcn = torch.nn.functional.normalize(gcn, dim=1)
    logits = torch.matmul(proj, gcn.T) / temperature
    labels = torch.arange(len(proj)).to(proj.device)
    return torch.nn.functional.cross_entropy(logits, labels)

def mse_alignment_loss(gcn_features, projected_features):
    return torch.nn.functional.mse_loss(projected_features, gcn_features)

def cosine_alignment_loss(gcn_features, projected_features):
    gcn_features = torch.nn.functional.normalize(gcn_features, dim=1)
    projected_features = torch.nn.functional.normalize(projected_features, dim=1)
    return 1 - (gcn_features * projected_features).sum(dim=1).mean()

In [35]:
projector = torch.nn.Sequential(
    torch.nn.Linear(768, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 128),
).to(device)

In [30]:
embeddingDataset = TensorDataset(torch.tensor(embeddings), newEmbeddings)
embeddingsLoader = DataLoader(embeddingDataset, 128)

In [31]:
embeddings.shape

(42000, 768)

In [36]:
projectorOptim = torch.optim.Adam(projector.parameters(), lr=1e-3)


In [37]:
# for epoch in tqdm(range(100)):
#     for sentenceFeat, gnnFeat in embeddingsLoader:
#         sentenceFeat, gnnFeat = sentenceFeat.to(device), gnnFeat.to(device)

#         projected_features = projector(sentenceFeat)

#         # Choose your loss
#         loss = cosine_alignment_loss(gnnFeat, projected_features)

#         projectorOptim.zero_grad()
#         loss.backward()
#         optimizer.step()

#     print(f"Epoch {epoch}: {loss.item()}")

import torch.nn.functional as F

for epoch in range(100):
    epoch_loss = 0.0
    epoch_mse = 0.0
    epoch_cos_sim = 0.0
    total_batches = 0

    projector.train()

    loop = tqdm(embeddingsLoader, desc=f"Epoch {epoch}", leave=False)
    for sentenceFeat, gnnFeat in loop:
        sentenceFeat, gnnFeat = sentenceFeat.to(device), gnnFeat.to(device)

        projected_features = projector(sentenceFeat)

        # Alignment loss (e.g., cosine loss or InfoNCE)
        loss = cosine_alignment_loss(gnnFeat, projected_features)

        # Metric 1: Cosine similarity
        cos_sim = F.cosine_similarity(
            F.normalize(projected_features, dim=1),
            F.normalize(gnnFeat, dim=1),
            dim=1
        ).mean().item()

        # Metric 2: MSE
        mse = F.mse_loss(projected_features, gnnFeat).item()

        projectorOptim.zero_grad()
        loss.backward()
        projectorOptim.step()

        # Track stats
        epoch_loss += loss.item()
        epoch_cos_sim += cos_sim
        epoch_mse += mse
        total_batches += 1

        loop.set_postfix(loss=loss.item(), cosine=cos_sim, mse=mse)

    # Averages for the epoch
    avg_loss = epoch_loss / total_batches
    avg_cos = epoch_cos_sim / total_batches
    avg_mse = epoch_mse / total_batches

    print(f"Epoch {epoch:03d} | Loss: {avg_loss:.4f} | Cosine Sim: {avg_cos:.4f} | MSE: {avg_mse:.4f}")

                                                                                                

Epoch 000 | Loss: 0.0816 | Cosine Sim: 0.9184 | MSE: 4.1743


                                                                                                

Epoch 001 | Loss: 0.0392 | Cosine Sim: 0.9608 | MSE: 3.7519


                                                                                                

Epoch 002 | Loss: 0.0369 | Cosine Sim: 0.9631 | MSE: 3.4944


                                                                                                

Epoch 003 | Loss: 0.0355 | Cosine Sim: 0.9645 | MSE: 3.2530


                                                                                                

Epoch 004 | Loss: 0.0345 | Cosine Sim: 0.9655 | MSE: 3.0208


                                                                                                 

Epoch 005 | Loss: 0.0337 | Cosine Sim: 0.9663 | MSE: 2.7965


                                                                                                 

Epoch 006 | Loss: 0.0330 | Cosine Sim: 0.9670 | MSE: 2.5764


                                                                                                 

Epoch 007 | Loss: 0.0325 | Cosine Sim: 0.9675 | MSE: 2.3644


                                                                                                 

Epoch 008 | Loss: 0.0319 | Cosine Sim: 0.9681 | MSE: 2.1602


                                                                                                 

Epoch 009 | Loss: 0.0314 | Cosine Sim: 0.9686 | MSE: 1.9662


                                                                                                  

Epoch 010 | Loss: 0.0309 | Cosine Sim: 0.9691 | MSE: 1.7834


                                                                                                  

Epoch 011 | Loss: 0.0304 | Cosine Sim: 0.9696 | MSE: 1.6136


                                                                                                  

Epoch 012 | Loss: 0.0300 | Cosine Sim: 0.9700 | MSE: 1.4542


                                                                                                   

Epoch 013 | Loss: 0.0296 | Cosine Sim: 0.9704 | MSE: 1.3059


                                                                                                   

Epoch 014 | Loss: 0.0292 | Cosine Sim: 0.9708 | MSE: 1.1677


                                                                                                   

Epoch 015 | Loss: 0.0288 | Cosine Sim: 0.9712 | MSE: 1.0428


                                                                                                   

Epoch 016 | Loss: 0.0285 | Cosine Sim: 0.9715 | MSE: 0.9299


                                                                                                   

Epoch 017 | Loss: 0.0281 | Cosine Sim: 0.9719 | MSE: 0.8292


                                                                                                   

Epoch 018 | Loss: 0.0278 | Cosine Sim: 0.9722 | MSE: 0.7380


                                                                                                   

Epoch 019 | Loss: 0.0275 | Cosine Sim: 0.9725 | MSE: 0.6575


                                                                                                   

Epoch 020 | Loss: 0.0272 | Cosine Sim: 0.9728 | MSE: 0.5870


                                                                                                   

Epoch 021 | Loss: 0.0270 | Cosine Sim: 0.9730 | MSE: 0.5263


                                                                                                   

Epoch 022 | Loss: 0.0267 | Cosine Sim: 0.9733 | MSE: 0.4760


                                                                                                    

Epoch 023 | Loss: 0.0263 | Cosine Sim: 0.9737 | MSE: 0.4394


                                                                                                    

Epoch 024 | Loss: 0.0259 | Cosine Sim: 0.9741 | MSE: 0.4126


                                                                                                    

KeyboardInterrupt: 

In [53]:
query = sentTransformer.encode("Does HSV-2 infection affect HIV-1 viral load during early infection?", convert_to_tensor=True).unsqueeze(0)

In [46]:
projector(query)

tensor([[ 0.9475,  0.1267, -0.7355,  2.3193, -1.0245,  5.4224, -0.2830,  5.3337,
          1.1157, -1.3054,  0.9092,  3.5090,  0.9092, -1.9904,  5.9561,  1.8709,
         -0.8519,  4.1124, -5.8601, -3.2842,  0.8364, -1.7003, -2.5755,  2.5388,
          0.1860, -4.8581, -0.0871, -0.0883, -4.0005,  0.4310,  3.0253,  0.5420,
          1.1632,  0.6783, -1.9242, -0.5274, -2.2086, -2.8591, -1.2566,  3.7541,
          1.6410, -5.4824, -4.4069, -2.0614, -1.3832, -0.8065, -2.8015, -4.5868,
          0.4555,  2.3108,  0.7342, -0.4800,  4.1041,  0.5223,  4.1886, -1.6070,
         -1.0340, -0.6285, -3.6189, -2.9228, -2.0556,  0.5181,  3.5922, -3.6828,
         -2.5354, -0.7375, -2.1428,  1.0271,  2.0640, -0.4872,  0.1799,  2.5157,
         -0.1506, -3.8132, -3.9102,  3.4707, -1.8599,  0.8882,  0.1442, -1.5927,
         -1.5980,  0.3055,  1.0181, -4.4591, -0.9713, -0.6892, -0.2785, -1.7720,
         -0.9992,  1.6119, -1.3113,  2.3926, -1.3130,  1.7119, -0.3259,  1.3891,
          0.5427, -1.0244, -

In [57]:
cosine_sim = F.cosine_similarity(projector(query), newEmbeddings, dim=1)  # shape: [N]

topk = torch.topk(cosine_sim, k=5)  # top 5 most similar articles
topk_indices = topk.indices
topk_scores = topk.values


In [55]:
print(topk_scores)
topk_indices

tensor([0.9356, 0.9348, 0.9341, 0.9319, 0.9285], device='cuda:1',
       grad_fn=<TopkBackward0>)


tensor([32237, 16548,     0,  4370,  3334], device='cuda:1')

In [58]:
print(topk_scores)
topk_indices

tensor([0.9356, 0.9348, 0.9341, 0.9319, 0.9285], device='cuda:1',
       grad_fn=<TopkBackward0>)


tensor([32237, 16548,     0,  4370,  3334], device='cuda:1')

In [56]:
idx = 32237
print(df.iloc[idx]["title"])
print(df.iloc[idx]["abstract"])
print(df.iloc[idx]["cluster"])



HIV 1 HSV 2 co infected adults in early HIV 1 infection have elevated CD4 T cell counts
Introduction HIV-1 is often acquired in the presence of pre-existing co-infections, such as Herpes Simplex Virus 2 (HSV-2). We examined the impact of HSV-2 status at the time of HIV-1 acquisition for its impact on subsequent clinical course, and total CD4+ T cell phenotypes. Methods We assessed the relationship of HSV-1/HSV-2 co-infection status on CD4+ T cell counts and HIV-1 RNA levels over time prior in a cohort of 186 treatment naÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂ¯ve adults identified during early HIV-1 infection. We assessed the activation and differentiation state of total CD4+ T cells at study entry by HSV-2 status. Results Of 186 recently HIV-1 infected persons, 101 (54 %) were sero-positive for HSV-2. There was no difference in initial CD8+ T cell count, or differences between the groups for age, gender, or race based on HSV-2 status. Persons with HIV-1/HSV-2 co

In [70]:
pubIdToId[9738515]

1478

In [None]:
for val in topk_indices:
    pubID = df.iloc[val.cpu().item()]["publication_ID"]
    title = df.iloc[val.cpu().item()]["title"]
    
    print(f"{pubID}: {title}")


17957262: HIV 1 HSV 2 co infected adults in early HIV 1 infection have elevated CD4 T cell counts
16973556: Selection on the human immunodeficiency virus type 1 proteome following primary infection
17396995: Herpes simplex virus type 2 infection does not influence viral dynamics during early HIV 1 infection
18197122: Herpes simplex virus type 2 acquisition during recent HIV infection does not influence plasma HIV levels
18936487: HIV rebounds from latently infected cells rather than from continuing low level replication


In [None]:
1478

In [74]:
df[df["publication_ID"] == 9738515]

Unnamed: 0,publication_ID,Citations,pubDate,language,title,journal,abstract,keywords,authors,venue,doi,combined_text,embedding,cluster
888,9738515,22408488;21943363;20878774;30075569;31996227;3...,1998 Sep,eng,Early changes in biochemical markers of bone t...,Journal of bone and mineral research : the off...,Although the antiresorptive agent alendronate ...,Aged;Alendronate;administration & dosage;thera...,S L Greenspan; R A Parker; L Ferguson; H N Ros...,{'name': 'Journal of bone and mineral research...,10.1359/jbmr.1998.13.9.1431,Early changes in biochemical markers of bone t...,"[0.0006918551516719162, -0.038948796689510345,...",67


In [75]:
df.iloc[888]["title"]

'Early changes in biochemical markers of bone turnover predict the long term response to alendronate therapy in representative elderly women a randomized clinical trial'