In [2]:
import random
import re
from datasets import load_dataset

# Function to split text into sentences
def split_into_sentences(text):
    return re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)

# Load the dataset
dataset = load_dataset("zhengyun21/PMC-Patients")

# Initialize lists
brain_mri_indices = []
other_indices = []
# max_others = 6

# Filter the dataset based on conditions
for i, entry in enumerate(dataset['train']):
    patient_text = entry['patient']
    if 'brain' in patient_text and 'MRI' in patient_text:
        brain_mri_indices.append(i)
    else:
        other_indices.append(i)


patients = 500

# Extract patient histories and assign labels
osl = [dataset['train'][i]['patient'] for i in brain_mri_indices[:int(patients/2)]]
pred_1 = [1] * len(osl)

osl += [dataset['train'][i]['patient'] for i in other_indices[:int(patients/2)]]
pred_1 += [0] * int(patients/2)

# Trim patient histories to only include text up to the first "MRI"
for i, text in enumerate(osl):
    sentences = split_into_sentences(text)
    trimmed_text = ""
    for sentence in sentences:
        trimmed_text += sentence
        if 'MRI' in sentence:
            break
    osl[i] = trimmed_text

# Shuffle the lists while maintaining the corresponding labels
combined = list(zip(osl, pred_1))
random.shuffle(combined)
osl_shuffled, pred_1_shuffled = zip(*combined)

# Convert back to lists if needed
osl_shuffled = list(osl_shuffled)
pred_1_shuffled = list(pred_1_shuffled)

# Output the lengths to verify the process
print(len(osl_shuffled))
print(len(pred_1_shuffled))

# osl_shuffled = osl
# pred_1_shuffled = pred_1


500
500


In [25]:
print(osl_shuffled)
print(pred_1_shuffled)

['A 95-year-old lady with a past medical history of heart failure with reduced ejection fraction (HFrEF) and biventricular implantable cardioverter-defibrillator (ICD), hypertension, and asthma presented to the emergency department for evaluation of nausea, vomiting, and a two-month history of intermittent diarrhea, which had been worsening for a few days prior to admission.Diarrhea was associated with severe, diffuse, waxing, and waning abdominal cramps, which were noted to improve after emesis.No correlation was noted with eating habits and no history of recent antibiotic use was reported.On examination, the abdomen was soft but tender on deep palpation, with audible bowel sounds.No organomegaly or costovertebral angle (CVA) tenderness was appreciated.\\nInitial laboratory results revealed hypokalemia (3.3 mEq/L), lipase within normal limits (32 U/L), and normal transaminases (aspartate aminotransferase [AST]: 19 U/L; alanine aminotransferase [ALT]: 10 U/L) and bilirubin (total bilir

In [4]:
import re

# Assuming lis contains the text data
dataset = osl_shuffled
final_list_of_sentences = []

# Compile the regex patterns once for efficiency
sentence_splitter = re.compile(r'\.\s*')
word_tokenizer = re.compile(r'\b[\w-]+\b')

# Process each text in the dataset
for text in dataset:
    # Split text into sentences and remove empty sentences
    sentences = [sentence.strip() + '.' for sentence in sentence_splitter.split(text) if sentence.strip()]
    
    # Tokenize, lowercase, and clean each sentence
    tokenized_sentences = [' '.join(word_tokenizer.findall(sentence.lower())) for sentence in sentences]
    
    # Append the cleaned sentences to the final list
    final_list_of_sentences.append(tokenized_sentences)

print(len(final_list_of_sentences))


500


In [23]:
final_list_of_sentences

[['a 95-year-old lady with a past medical history of heart failure with reduced ejection fraction hfref and biventricular implantable cardioverter-defibrillator icd hypertension and asthma presented to the emergency department for evaluation of nausea vomiting and a two-month history of intermittent diarrhea which had been worsening for a few days prior to admission',
  'diarrhea was associated with severe diffuse waxing and waning abdominal cramps which were noted to improve after emesis',
  'no correlation was noted with eating habits and no history of recent antibiotic use was reported',
  'on examination the abdomen was soft but tender on deep palpation with audible bowel sounds',
  'no organomegaly or costovertebral angle cva tenderness was appreciated',
  'ninitial laboratory results revealed hypokalemia 3',
  '3 meq l lipase within normal limits 32 u l and normal transaminases aspartate aminotransferase ast 19 u l alanine aminotransferase alt 10 u l and bilirubin total bilirubin

In [5]:
for i in range(len(final_list_of_sentences)):
    if not final_list_of_sentences[i]:
        print('Empty')
    final_list_of_sentences[i] = [sentence for sentence in final_list_of_sentences[i] if sentence]


In [6]:
import torch 
import torch.nn as nn
import re
from torchcrf import CRF
from transformers import AutoTokenizer, AutoModel

# Define BiLSTM and MOD classes
class BiLSTM(nn.Module):
    def __init__(self):
        super(BiLSTM, self).__init__()
        self.input_size = 768
        self.hidden_size = 256
        self.num_layers = 2
        self.dropout = 0.3
        self.bilstm = nn.LSTM(input_size=self.input_size, hidden_size=128, 
                             num_layers=self.num_layers, batch_first=True, 
                             dropout=self.dropout, bidirectional=True)

    def forward(self, embeddings):
        lstm_out, _ = self.bilstm(embeddings)
        return lstm_out

class MOD(nn.Module):
    def __init__(self, input_size, num_labels):
        super(MOD, self).__init__()
        self.bilstm = BiLSTM()
        self.linear = nn.Linear(256, num_labels)
        self.crf = CRF(num_labels)

    def forward(self, x):
        lstm_output = self.bilstm(x)
        embeddings = self.linear(lstm_output)
        return embeddings

# Load BioBERT tokenizer and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")
biobert_model = AutoModel.from_pretrained("dmis-lab/biobert-v1.1").to(device)

# Load the trained model
model = MOD(768, 6).to(device)  # Ensure model is on the same device
model.load_state_dict(torch.load(r"/home/rvce/Downloads/HPCC/best_model.pth", map_location=device))
model.eval()

# Initialize lists for storing results
all_symptoms_ALL = []
symptoms_wout_duration_ALL = []
symptom_with_organ_ALL = []
new_dict_ALL = []

# Assuming final_list_of_sentences is already defined
for i in range(len(final_list_of_sentences)):    
    sentences = final_list_of_sentences[i]
    tsl_tokenized = []
    
    for sentence in sentences:
        tokens = re.findall(r'\b[\w-]+\b', sentence.lower())
        tsl_tokenized.append(tokens)
    
    vocab_test = []
    for sentence in tsl_tokenized:
        for token in sentence:
            if token not in vocab_test:
                vocab_test.append(token)
    
    word_embeddings = []
    for word in vocab_test:
        tokens = tokenizer(word, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = biobert_model(**tokens)
        embedding = outputs.last_hidden_state[0][0]
        word_embeddings.append(embedding)

    word_embeddings_test = torch.stack(word_embeddings).to(device)
    sentence_embeddings_test = []
    each_sentence = []
    
    for k in range(len(tsl_tokenized)):
        each_word_embeddings_test = []
        for token in tsl_tokenized[k]:
            word_embedding_tensor_test = word_embeddings_test[vocab_test.index(token)]
            each_word_embeddings_test.append(word_embedding_tensor_test)
        sentence_embeddings_test.extend(each_word_embeddings_test)
        each_sentence.append(each_word_embeddings_test)

    pred = []
    for j, sentence in enumerate(each_sentence):
        if len(sentence) == 0:
            continue
        with torch.no_grad():
            print(i, j)
            # Stack the sentence tensors and move them to the device
            sentence_tensor = torch.stack(sentence).to(device).view(1, -1, 768)
            output = model(sentence_tensor)
            prediction = model.crf.decode(output)
            pred.append(prediction)

    # Create lists for different labels
    lis_1 = []
    lis_3 = []
    lis_4 = []
    lis_5 = []

    for i in range(len(pred)):
        l1 = {}
        l3 = {}
        l4 = {}
        l5 = {}
        prediction = pred[i]
        s = ''
        for k in range(len(prediction)):
            if prediction[k][0] != 0:
                s += tsl_tokenized[i][k] + ' '
            elif k != 0 and prediction[k][0] == 0:
                t = prediction[k-1][0]
                if t != 0:
                    if t == 1:
                        l1[s.strip()] = k
                    elif t == 2:
                        l1[s.strip()] = k
                    elif t == 3:
                        l3[s.strip()] = k
                    elif t == 4:
                        l4[s.strip()] = k
                    elif t == 5:
                        l5[s.strip()] = k
                    s = ''
        if prediction[-1][0] != 0:
            t = prediction[-1][0]
            if t == 1:
                l1[s.strip()] = len(prediction)
            elif t == 2:
                l1[s.strip()] = len(prediction)
            elif t == 3:
                l3[s.strip()] = len(prediction)
            elif t == 4:
                l4[s.strip()] = len(prediction)
            elif t == 5:
                l5[s.strip()] = len(prediction)

        lis_1.append(l1)
        lis_3.append(l3)
        lis_4.append(l4)
        lis_5.append(l5)

    symptoms_wout_duration = []
    for l1 in lis_1:
        if l1:
            for key in l1.keys():
                symptoms_wout_duration.append(key)
    
    new_dict = {}
    for i in range(len(lis_3)):
        if not lis_3[i]:
            #new_dict_ALL.append(new_dict)
            continue
        for key_3, value_3 in lis_3[i].items():
            closest_key = None
            minimum_difference = float('inf')
            for key_1, value_1 in lis_1[i].items():
                difference = abs(value_3 - value_1)
                if difference < minimum_difference:
                    minimum_difference = difference
                    closest_key = key_1
            new_dict[key_3] = closest_key

    all_symptoms_ALL.append(symptoms_wout_duration)
    symptoms_wout_duration_ALL.append(symptoms_wout_duration)
    new_dict_ALL.append(new_dict)

# print(all_symptoms_ALL)
# print(symptoms_wout_duration_ALL)
# print(symptom_with_organ_ALL)
# print(new_dict_ALL)


  model.load_state_dict(torch.load(r"/home/rvce/Downloads/HPCC/best_model.pth", map_location=device))


0 0
0 1
0 2
0 3
0 4
0 5
0 6
0 7
0 8
0 9
0 10
0 11
0 12
0 13
0 14
0 15
0 16
0 17
0 18
0 19
1 0
1 1
1 2
1 3
1 4
1 5
1 6
1 7
1 8
1 9
1 10
1 11
1 12
1 13
1 14
1 15
1 16
1 17
1 18
1 19
1 20
1 21
1 22
1 23
1 24
1 25
1 26
1 27
2 0
2 1
2 2
2 3
2 4
2 5
2 6
2 7
2 8
2 9
2 10
2 11
2 12
2 13
2 14
3 0
3 1
3 2
3 3
3 4
3 5
4 0
4 1
4 2
4 3
4 4
4 5
4 6
4 7
4 8
4 9
4 10
4 11
4 12
4 13
4 14
4 15
4 16
4 17
4 18
4 19
4 20
4 21
4 22
4 23
4 24
4 25
4 26
4 27
4 28
4 29
4 30
4 31
4 32
4 33
4 34
4 35
4 36
4 37
4 38
4 39
4 40
4 41
4 42
4 43
4 44
4 45
4 46
5 0
5 1
5 2
5 3
5 4
5 5
5 6
5 7
6 0
6 1
6 2
6 3
6 4
7 0
7 1
7 2
7 3
7 4
7 5
8 0
8 1
8 2
8 3
8 4
8 5
8 6
8 7
8 8
8 9
8 10
8 11
8 12
8 13
8 14
8 15
8 16
8 17
9 0
9 1
9 2
9 3
9 4
9 5
10 0
10 1
10 2
10 3
10 4
10 5
10 6
10 7
10 8
10 9
10 10
10 11
10 12
10 13
10 14
10 15
10 16
10 17
10 18
11 0
11 1
11 2
11 3
11 4
11 5
11 6
11 7
11 8
11 9
11 10
11 11
11 12
11 13
11 14
11 15
11 16
11 17
11 18
11 19
11 20
12 0
12 1
12 2
13 0
13 1
13 2
13 3
13 4
13 5
13 6
13 7
13 8
13 9
1

In [7]:
l = []
for i in range(len(new_dict_ALL) - 1):
    l.append(new_dict_ALL[i])
    
l.append(new_dict_ALL[len(new_dict_ALL) - 1])

symptom_with_duration_ALL = l


In [8]:
time_units = ['second', 'minute', 'hour', 'day', 'week', 'month', 'year', 'decade', 'seconds', 'minutes', 'hours', 'days', 'weeks', 'months', 'years', 'decades']

def replace_keys(data, time_units):
    new_data = []
    for item in data:
        new_item = {}
        for key, value in item.items():
            matched_word = next((unit for unit in time_units if unit in key.lower()), None)
            if matched_word and value is not None:
                # Singularize the matched word if necessary
                if matched_word.endswith('s'):
                    matched_word = matched_word[:-1]
                new_item[matched_word] = value
        new_data.append(new_item)
    return new_data


# symptom_with_duration_ALL = [{}, {'88 to 96': None}, {}, {}, {'distribution': None}, {'2 weeks': 'cough'}, {'the same year': None}, {}]

updated_data = replace_keys(symptom_with_duration_ALL, time_units)





symptom_with_duration_ALL = updated_data
print(symptom_with_duration_ALL)


[{'day': 'diarrhea'}, {'week': 'scotoma'}, {'month': 'nasal obstruction symptoms'}, {}, {'day': 'anaphylaxis'}, {}, {'week': 'fever'}, {'year': 'neurological deficits', 'day': 'symptoms'}, {'day': 'weaning off'}, {'day': 'symptoms'}, {}, {}, {}, {}, {'week': 'altered mental status nausea'}, {'month': 'unable to attend school', 'week': 'nmultiple'}, {}, {}, {'year': 'swallowing dysfunction', 'month': 'swallowing dysfunction'}, {'month': 'unsteady', 'year': 'sexually'}, {}, {'year': 'remission'}, {'day': 'headache dizziness nausea vomiting'}, {}, {}, {'month': 'recurrence', 'year': 'technical issues'}, {}, {}, {'week': 'taste disorders'}, {}, {'week': 'delirium'}, {'month': 'symptoms'}, {}, {}, {}, {'week': 'dizygotic', 'month': 'macrocephaly'}, {}, {}, {}, {}, {'day': 'swelling'}, {}, {'month': 'arthralgias'}, {'month': 'headache'}, {'month': 'dyspnea'}, {'day': 'hematochezia'}, {'year': 'remission', 'day': 'neutropenia', 'month': 'relapse'}, {}, {'year': 'hearing loss'}, {}, {'month': 

In [9]:
from groq import Groq
import ast
import time
import re
import httpx  # Ensure you have httpx installed

# Initialize the Groq client
client = Groq(api_key="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")

# Constants
REQUESTS_PER_MINUTE = 30
REQUESTS_PER_SECOND = REQUESTS_PER_MINUTE / 60
TOKENS_PER_MINUTE = 30000
TOKENS_PER_SECOND = TOKENS_PER_MINUTE / 60
EXPECTED_TOKENS_PER_REQUEST = 200  # Estimate based on average usage

def organ(medical_text, symptoms):
    prompt = (
        f"Use the following medical text: {medical_text} "
        f"Now make a list of dictionaries with key as symptom and value as organ (the given text should mention a specific organ, if the medical text doesn't have an organ for that symptom make the value 'unspecified'). "
        f"Answer should be ['symptom1':'organ1','symptom2':'organ2',...and so on]. "
        f"Remember to use the symptoms in the list only: {symptoms}. "
        f"The key and value should be strictly present in the above text. No generation of generic organs. "
        f"Give me the final list of dictionaries. I mean a Python list of dictionaries only, not a string. "
        f"Strictly, each value should be in the medical text given or else make it 'unspecified' and the organ should be in one word. "
        f"Respond with only the list of dictionaries. Remove any duplicate dictionaries. Do not include any other text in your response."
    )
    response = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        model="llama3-8b-8192",  # Changed to the correct model
    )
    final_summary = response.choices[0].message.content
    return final_summary

def parse_response(response):
    # Clean the response string
    response = response.strip()
    
    # Remove the leading and trailing brackets if any
    if response.startswith('['):
        response = response[1:]
    if response.endswith(']'):
        response = response[:-1]
    
    # Replace single quotes with double quotes
    response = response.replace("'", '"')
    
    # Regex pattern to match key-value pairs
    pattern = re.compile(r'"([^"]+)":\s*"([^"]+)"')

    # Find all matches
    matches = pattern.findall(response)

    # Build dictionaries from the extracted key-value pairs
    parsed_data = [{key.strip(): value.strip()} for key, value in matches]

    return parsed_data

def write_list_to_file(data_list, file_path):
    """Writes a list to a file in append mode, with each item on a new line."""
    with open(file_path, 'a') as file:  # Open file in append mode
        for item in data_list:
            file.write(f"{item}\n")

def organ_with_retry(medical_text, symptoms, retries=3):
    """Attempts to call the organ function with retry logic."""
    for attempt in range(retries):
        try:
            return organ(medical_text, symptoms)
        except (httpx.ReadTimeout, httpx.ConnectError) as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            time.sleep(2 ** attempt)  # Exponential backoff
    print("Request failed after multiple attempts.")
    return []  # Return an empty list or handle as needed

# Define the file path
file_path = r"/home/rvce/Downloads/HPCC/output4.txt"


# Initialize an empty list to store the results
symptom_with_organ_ALL = []
c = 0

for i in range(len(all_symptoms_ALL)):
    try:
        # Generate response for each dataset entry
        funResponse = organ_with_retry(dataset[i], all_symptoms_ALL[i])
        
        # Convert the response string to a Python list of dictionaries
        symptom_with_organ = parse_response(funResponse)
        print(symptom_with_organ)
        print(c)
        
        # Write the results to a file
        write_list_to_file(symptom_with_organ, file_path)
        c += 1
        
        # Rate limiting: Delay to maintain requests per minute limit
        time.sleep(1 / REQUESTS_PER_SECOND)  # Delay to respect requests per second
        time.sleep(EXPECTED_TOKENS_PER_REQUEST / TOKENS_PER_SECOND)  # Delay to respect token limits

        # Append the parsed data to the result list
        symptom_with_organ_ALL.append(symptom_with_organ)

    except (httpx.ReadTimeout, httpx.ConnectError) as e:
        # Handle Groq API related errors
        print(f"Error calling Groq API for index {i}: {e}")
        continue  # Skip to the next index
    
    except (SyntaxError, ValueError) as e:
        # Handle parsing errors
        print(f"Error parsing response for index {i}: {e}")
        continue  # Skip to the next index

# Initialize an empty list to store the key-value pairs
key_value_pairs = []

# Iterate through the list of dictionaries
for graph in symptom_with_organ_ALL:
    for item in graph:
        if isinstance(item, dict):  # Ensure the item is a dictionary
            for k, v in item.items():
                if v != 'unspecified':
                    key_value_pairs.append(f"{k}: {v}")

# Append the key-value pairs to the file
with open(file_path, 'a') as file:  # Use 'a' mode to append
    for pair in key_value_pairs:
        file.write(pair + '\n')


[{'nausea vomiting': 'gastrointestinal'}, {'diarrhea': 'gastrointestinal'}, {'abdominal cramps': 'abdomen'}, {'hypokalemia': 'kidney'}, {'intussusception': 'colon'}, {'lesion': 'colon'}, {'telescoping': 'colon'}, {'oliguria': 'kidney'}]
0
[{'blind spots': 'eye'}, {'vision loss': 'eye'}, {'fundus abnormalities': 'eye'}, {'scotomas': 'eye'}, {'lesions': 'eye'}, {'scotoma': 'eye'}, {'obstructions': 'eye'}, {'small opaque structures': 'unspecified'}, {'relief': 'unspecified'}]
1
[{'nasal obstruction': 'nasal cavity'}, {'hyposmia': 'nasal cavity'}, {'nasal obstruction symptoms': 'unspecified'}, {'asymptomatic': 'unspecified'}, {'relief': 'unspecified'}, {'trauma': 'unspecified'}, {'anomalies': 'unspecified'}, {'bump': 'nasal cavity'}, {'focus': 'nasal cavity'}, {'convexity to the left': 'nasal cavity'}, {'white': 'unspecified'}, {'supernumerary': 'unspecified'}, {'mucosal thickening': 'sinuses'}]
2
[{'right leg weakness': 'unspecified'}, {'longer able to walk': 'unspecified'}, {'right ankle

In [10]:
import networkx as nx
import matplotlib.pyplot as plt

# Initialize a list to store graphs
graphs = []

# Process each patient's symptom data to create a graph
for i, patient_data in enumerate(symptom_with_organ_ALL):
    # Create a new graph
    G = nx.Graph()
    
    # Add a central node for the patient
    patient_node = f'Patient {i+1}'
    G.add_node(patient_node, label='Patient')
    
    # Add nodes for symptoms and connect them to the patient node
    for item in patient_data:
        for symptom, organ in item.items():
            if organ != 'unspecified':
                G.add_node(symptom, label='Symptom')
                G.add_edge(patient_node, symptom)
                # Add nodes for organs and connect them to the symptoms
                if organ != 'unspecified':
                    G.add_node(organ, label='Organ')
                    G.add_edge(symptom, organ)
            else:
                G.add_node(symptom, label='Symptom')
                G.add_edge(patient_node, symptom)

    # Optional: Set edge weights based on symptom durations
    if i < len(symptom_with_duration_ALL):
        for symptom, duration in symptom_with_duration_ALL[i].items():
            if G.has_edge(patient_node, symptom):
                G[patient_node][symptom]['weight'] = time_units.index(duration)

    # Append the graph for the patient to the list
    graphs.append(G)
def plot_graph(G):
    pos = nx.spring_layout(G, k=0.5, scale=2)  # Adjust k and scale for better spacing
    labels = nx.get_edge_attributes(G, 'label')
    weights = nx.get_edge_attributes(G, 'weight')

    plt.figure(figsize=(12, 12))  # Adjust figure size for better display
    nx.draw(G, pos, with_labels=True, node_size=2000, node_color='skyblue', font_size=10, font_weight='bold', alpha=0.7, edge_color='gray')
    
    # Draw edge labels and weights
    if labels:
        nx.draw_networkx_edge_labels(G, pos, edge_labels=labels, font_color='red', font_size=8)
    if weights:
        nx.draw_networkx_edge_labels(G, pos, edge_labels=weights, font_color='green', font_size=8, label_pos=0.3)  # Adjust label_pos for better placement

    plt.title('Patient Symptom and Organ Network')
    plt.show()

# # Plot each graph in the list
# for G in graphs:
#     plot_graph(G)


In [11]:
from transformers import BertTokenizer, BertModel, BertConfig
import torch
import numpy as np

# Load the tokenizer and model
config_path = r"/home/rvce/Downloads/HPCC/config.json"
vocab_path = r"/home/rvce/Downloads/HPCC/vocab.txt"
model_path = r"/home/rvce/Downloads/HPCC/pytorch_model.bin"

config = BertConfig.from_pretrained(config_path)
tokenizer = BertTokenizer.from_pretrained(vocab_path)
model = BertModel.from_pretrained(model_path, config=config, local_files_only=True)

# Move model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

feature_matrices = []

for G in graphs:
    matrix = [None] * len(G.nodes())  # Initialize a list to store embeddings in the correct order
    
    edge_list = list(G.edges())
    unique_nodes = sorted(set(G.nodes()))  # Ensure all nodes are included and sorted
    
    node_mapping = {node: i for i, node in enumerate(unique_nodes)}  # Mapping nodes to indices

    for node in G.nodes():
        if node not in node_mapping:
            print(f"Node '{node}' not found in node_mapping.")
            continue
        
        print(node)
        text = node  # Text to be tokenized
        inputs = tokenizer(text, return_tensors='pt').to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            embeddings = outputs.last_hidden_state

        # Averaging the embeddings
        averaged_matrix = embeddings[0].mean(dim=0).view(1, -1)

        # Place the embedding in the correct row based on the node_mapping
        matrix[node_mapping[node]] = averaged_matrix

    # Remove None entries (if any) and convert the list to a tensor
    matrix = torch.cat([m for m in matrix if m is not None], dim=0)  # (num_nodes, hidden_size)
    feature_matrices.append(matrix)

# Convert feature_matrices to numpy array if needed
#feature_matrices_np = [matrix.cpu().numpy() for matrix in feature_matrices]

# Print the numpy arrays if needed




Patient 1
nausea vomiting
gastrointestinal
diarrhea
abdominal cramps
abdomen
hypokalemia
kidney
intussusception
colon
lesion
telescoping
oliguria
Patient 2
blind spots
eye
vision loss
fundus abnormalities
scotomas
lesions
scotoma
obstructions
small opaque structures
relief
Patient 3
nasal obstruction
nasal cavity
hyposmia
nasal obstruction symptoms
asymptomatic
relief
trauma
anomalies
bump
focus
convexity to the left
white
supernumerary
mucosal thickening
sinuses
Patient 4
right leg weakness
longer able to walk
right ankle flexion extension eversion
inversion
Patient 5
skin rash
body surface
hypotension
bronchospasm
rash
neck
bradycardia
myocardial hypertrophy
left ventricle
underfilling
anaphylaxis
Patient 6
right hemiparesis
brain
right facial droop
headache
dysphasia
m1 occlusion
occlusion
M1
middle cerebral artery
plaques
ICA
plaque
stenosis
dissection
crescentic hyperintense signal
haemorrhage
Patient 7
poor
vomiting
gastrointestinal
fever
complications
polyuria
kidney
Patient 8
s

In [12]:
len(feature_matrices)

500

In [13]:
# Assuming `device` has been defined as before (CUDA or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

main_edge_list = []
lis = []

for G in graphs:
    matrix = []
    print(G.nodes)

    edge_list = list(G.edges())
    unique_nodes = sorted(set(node for edge in edge_list for node in edge))  # Sort the unique nodes

    # Create a mapping from node labels to integers in a sorted order
    node_mapping = {node: i for i, node in enumerate(unique_nodes)}

    # Convert edge list to numeric representation
    numeric_edge_list = [[node_mapping[u], node_mapping[v]] for u, v in edge_list]

    # Create tensor and move to the appropriate device
    edge_tensor = torch.tensor(numeric_edge_list, dtype=torch.long).to(device)
    lis.append(edge_tensor)

    # Store the edge list in main_edge_list
    main_edge_list.append(edge_list)

# Print the tensors
#for edge_tensor in lis:
    #print(edge_tensor)


['Patient 1', 'nausea vomiting', 'gastrointestinal', 'diarrhea', 'abdominal cramps', 'abdomen', 'hypokalemia', 'kidney', 'intussusception', 'colon', 'lesion', 'telescoping', 'oliguria']
['Patient 2', 'blind spots', 'eye', 'vision loss', 'fundus abnormalities', 'scotomas', 'lesions', 'scotoma', 'obstructions', 'small opaque structures', 'relief']
['Patient 3', 'nasal obstruction', 'nasal cavity', 'hyposmia', 'nasal obstruction symptoms', 'asymptomatic', 'relief', 'trauma', 'anomalies', 'bump', 'focus', 'convexity to the left', 'white', 'supernumerary', 'mucosal thickening', 'sinuses']
['Patient 4', 'right leg weakness', 'longer able to walk', 'right ankle flexion extension eversion', 'inversion']
['Patient 5', 'skin rash', 'body surface', 'hypotension', 'bronchospasm', 'rash', 'neck', 'bradycardia', 'myocardial hypertrophy', 'left ventricle', 'underfilling', 'anaphylaxis']
['Patient 6', 'right hemiparesis', 'brain', 'right facial droop', 'headache', 'dysphasia', 'm1 occlusion', 'occlusi

In [22]:
len(pred_1_shuffled)

500

In [28]:
import pickle

with open(r"/home/rvce/Downloads/HPCC/12k Data/500 test/feature_matrices_test.pkl", 'wb') as f:
    pickle.dump(feature_matrices, f)
with open(r"/home/rvce/Downloads/HPCC/12k Data/500 test/edge_list_all_test.pkl", 'wb') as f:
    pickle.dump(lis, f)
with open(r"/home/rvce/Downloads/HPCC/12k Data/500 test/all_symptoms_ALL_test.pkl", 'wb') as f:
    pickle.dump(all_symptoms_ALL, f)
with open(r"/home/rvce/Downloads/HPCC/12k Data/500 test/symptom_with_organ_ALL_test.pkl", 'wb') as f:
    pickle.dump(symptom_with_organ_ALL, f)
with open(r"/home/rvce/Downloads/HPCC/12k Data/500 test/symptom_with_duration_ALL_test.pkl", 'wb') as f:
    pickle.dump(symptom_with_duration_ALL, f)
with open(r"/home/rvce/Downloads/HPCC/12k Data/500 test/pred_1_shuffled_test.pkl", 'wb') as f:
    pickle.dump(pred_1_shuffled, f)
with open(r"/home/rvce/Downloads/HPCC/12k Data/500 test/osl_shuffled_test.pkl", 'wb') as f:
    pickle.dump(osl_shuffled, f)

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SortAggregation
from torch_geometric.data import Data
import torch.optim as optim

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GCN_SortPool_CNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, k):
        super(GCN_SortPool_CNN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        #self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.sort_pool = SortAggregation(k=k) #prioritizing nodes (dk what k is)
        self.cnn1d = nn.Conv1d(in_channels=hidden_channels, out_channels=32, kernel_size=2)
        
        # Calculate the correct input size for the fully connected layer
        cnn_output_size = 32 * (k - 1)  # Because kernel_size=2 reduces length by 1
        self.fc = nn.Linear(cnn_output_size, out_channels)
        
    def forward(self, x, edge_index):
        # GCN Layers
        x = F.relu(self.conv1(x, edge_index))
        #x = F.relu(self.conv2(x, edge_index))
        
        # SortPooling
        x = self.sort_pool(x, torch.zeros(x.size(0), dtype=torch.long, device=x.device))  # Create a dummy batch tensor
        
        # Reshape for 1D CNN
        x = x.view(x.size(0), -1, self.sort_pool.k)  # Reshape to (batch_size, hidden_channels, k)
        
        # 1D CNN
        x = F.relu(self.cnn1d(x))
        x = x.view(x.size(0), -1)  # Flatten
        
        # Fully Connected + Sigmoid
        x = self.fc(x)
        x = torch.sigmoid(x)  # Apply sigmoid activation
        return x

# Load the model
model = GCN_SortPool_CNN(in_channels=768, hidden_channels=32, out_channels=1, k=14).to(device)  # Move model to device
model.load_state_dict(torch.load('gcn_sortpool_cnn.pth'))
model.eval()

# Create a list of Data objects for the test data
graphs = [Data(x=feature_matrices[i].to(device), edge_index=lis[i].to(device)) for i in range(len(feature_matrices))]
print(len(pred_1))
# Perform inference
with torch.no_grad():
    predictions = []
    targets = []
    c=0
    for i, graph in enumerate(graphs):
        c+=1
        if c>499:
            break
        output = model(graph.x, graph.edge_index.view(2, -1))
        predicted = (output > 0.5).float()  # Convert probabilities to binary predictions
        predictions.append(predicted.cpu().numpy())  # Move output to CPU and convert to numpy array
        targets.append(pred_1_shuffled[i])  # Collect the targets

# Convert lists to numpy arrays for easy comparison
predictions = np.array(predictions).flatten()  # Flatten to match target shape
targets = np.array(targets).flatten()  # Flatten to match prediction shape

# Calculate number of correct predictions
correct_predictions = np.sum(predictions == targets)
total_predictions = len(targets)
accuracy = correct_predictions / total_predictions

# Output number of correct predictions, accuracy, and predictions
print(f"Number of correct predictions: {correct_predictions}")
print(f"Accuracy: {accuracy:.4f}")

# If you need to compare predictions with pred_1
for i, (pred, target) in enumerate(zip(predictions, targets)):
    print(f"Graph {i} - Prediction: {pred}, Target: {target}")


  model.load_state_dict(torch.load('gcn_sortpool_cnn.pth'))


500
Number of correct predictions: 381
Accuracy: 0.7635
Graph 0 - Prediction: 0.0, Target: 0
Graph 1 - Prediction: 0.0, Target: 0
Graph 2 - Prediction: 0.0, Target: 0
Graph 3 - Prediction: 1.0, Target: 1
Graph 4 - Prediction: 0.0, Target: 0
Graph 5 - Prediction: 1.0, Target: 1
Graph 6 - Prediction: 0.0, Target: 1
Graph 7 - Prediction: 1.0, Target: 1
Graph 8 - Prediction: 0.0, Target: 0
Graph 9 - Prediction: 0.0, Target: 0
Graph 10 - Prediction: 1.0, Target: 1
Graph 11 - Prediction: 0.0, Target: 0
Graph 12 - Prediction: 1.0, Target: 0
Graph 13 - Prediction: 0.0, Target: 0
Graph 14 - Prediction: 0.0, Target: 1
Graph 15 - Prediction: 1.0, Target: 0
Graph 16 - Prediction: 0.0, Target: 0
Graph 17 - Prediction: 0.0, Target: 0
Graph 18 - Prediction: 0.0, Target: 0
Graph 19 - Prediction: 1.0, Target: 0
Graph 20 - Prediction: 0.0, Target: 1
Graph 21 - Prediction: 0.0, Target: 0
Graph 22 - Prediction: 1.0, Target: 1
Graph 23 - Prediction: 1.0, Target: 1
Graph 24 - Prediction: 1.0, Target: 1
Grap