In [None]:
#!pip install pandas
#!pip install gensim
#!pip install pandas
import pandas as pd
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, get_scheduler
#from transformers import AdamW
from torch.utils.data import Dataset, DataLoader
import os
from google.colab import drive
drive.mount('/content/drive')
import ast
from collections import defaultdict
#from gensim.models import Word2Vec
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from tqdm.auto import tqdm
import random
import numpy as np
import datetime
!pip install transformers datasets accelerate

# Run configuration: hyperparameters, input file locations and the pretrained model name.
EMBEDDING_DIM = 32  # Size handed to BioWSDClassifier and used for zero-vector fallbacks; presumably must equal the vector length in the embedding files — TODO confirm
MAX_SEQ_LENGTH = 256  # Max length for tokenized text (passed as max_length to the tokenizer)
BATCH_SIZE = 16 # Batch size for the train and test DataLoaders
#CSV_FILE_PATH = "MeDaL_with_semantic_types_and_Relations.csv"  # Alternative dataset (disabled)

CSV_FILE_PATH = "./data/data.csv"

# Files produced by a preprocessing step that is not part of this notebook.
# They are only read here (via load_vocabularies / load_embeddings_from_file), never written.
SEMANTIC_EMBEDDINGS_PATH = "data_semantic_type_embeddings.txt" # Semantic type embeddings (text format)
SEMANTIC_RELATIONS_PATH = "data_semantic_relation_embeddings_last.txt" # Semantic relation embeddings (text format)
VOCAB_SEMANTIC_TYPES_PATH = "data_semantic_type_vocab_last.txt" # Vocabulary of semantic types, one entry per line
VOCAB_RELATIONS_PATH = "data_semantic_relation_vocab_last.txt" # Vocabulary of relations, one entry per line

# NOTE(review): OPTIMIZER_TYPE is not consulted by the training loop in this notebook,
# which always constructs torch.optim.AdamW.
OPTIMIZER_TYPE = "AdamW"  # Choose "Adam", "SGD", or "AdamW"
LEARNING_RATE = 2.2e-5
EPOCHS = 7  # Upper bound on training epochs (early stopping may end a run sooner)
# NOTE(review): embeddings_2 is not referenced anywhere else in the visible code.
embeddings_2 = {}

# UMLS API details. Neither value is used by the visible code; keep the key out of
# version control if one is ever filled in.
UMLS_API_KEY = "" #replace with your key
UMLS_API_URL = "https://uts-ws.nlm.nih.gov/rest"

#CSV_FILE_PATH = "MSH_All_dataset_with_GoldRelations_withcontext_last.csv"  # Alternative dataset (disabled)


# Hugging Face model id used for both the tokenizer and the encoder.
MODEL_NAME = "GanjinZero/UMLSBert_ENG"
# NOTE(review): wordnet is downloaded here but not used by the visible code
# (the text-augmentation call in the dataset is commented out).
import nltk
from nltk.corpus import wordnet
nltk.download('wordnet')

# Initial load of the dataset; the next cell reads the same file again.
df = pd.read_csv(CSV_FILE_PATH)
print("Data loaded!")

# **Full Model**

In [None]:

# These three assignments repeat the values already set in the configuration cell
# (16 / 256 / 32); change them here only if this cell should override that cell.
BATCH_SIZE = 16  # Batch size for the DataLoaders (lower it if GPU memory is short)

MAX_SEQ_LENGTH = 256  # Tokenizer max_length (lower it if GPU memory is short)

EMBEDDING_DIM = 32  # Size of the semantic type / relation vectors — not the BERT hidden size

# Load Data (re-reads the CSV already loaded in the configuration cell)
df = pd.read_csv(CSV_FILE_PATH)

print("Data Loaded")
print("length of data before augmentation is:", len(df))


print("----Run Info.----")
# Classifier output size = number of distinct GOLD_SEMANTIC_ENCODING values + 1.
# NOTE(review): the dataset skips rows whose label is -1; if -1 occurs in this column it is
# counted here as well. Confirm that the result is always > the largest label value,
# otherwise CrossEntropyLoss will reject out-of-range targets.
num_stypes =df['GOLD_SEMANTIC_ENCODING'].nunique() +1 # Distinct label values in the data, plus one.
num_labels = num_stypes
print(f"No. of unique labels from data: {num_stypes}")
print(f"No. of unique labels assigned to classifier: {num_labels}")

# Timestamp the run so printed results can be matched to an execution.
import datetime
now =datetime.datetime.now()
print("Current execusion time:", now.strftime("%Y-%m-%d %H:%M:%S"))
print("---------")

# Load semantic type and relation embeddings (from txt files)
def load_embeddings(file_path, embedding_dim, vocabulary):
    """Read `<word> <float> <float> ...` lines from a text file into a dict of tensors.

    Args:
        file_path: Path of the UTF-8 embedding file.
        embedding_dim: Length of the zero vector stored when a row's numbers cannot be parsed.
        vocabulary: Container of words; rows whose first token IS in it are skipped with a warning.

    Returns:
        dict mapping the first token of each accepted row to a float32 tensor.
    """
    table = {}
    with open(file_path, 'r', encoding='utf-8') as handle:
        for line_number, line in enumerate(handle, start=1):
            tokens = line.strip().split()

            # A usable row needs the word plus at least one number.
            if len(tokens) <= 1:
                print(f"Warning: len(values)<1 in embedding file Skipping line: {line}. Line number: {line_number}")
                continue

            word, *components = tokens
            if word in vocabulary:
                print(f"Warning: word in vocab:Skipping line: {line}, as it's a semantic type in the vocab file. line number: {line_number}")
                continue

            try:
                table[word] = torch.tensor(list(map(float, components)), dtype=torch.float32)
            except ValueError:
                # Non-numeric component: keep the word but fall back to a zero vector.
                table[word] = torch.zeros(embedding_dim)

    return table


#load vocabularies
def load_vocabularies(file_path):
    """Return the set of non-blank, whitespace-stripped lines of a UTF-8 text file.

    Blank lines are reported on stdout (with their 1-based line number) and ignored.
    """
    entries = set()
    with open(file_path, 'r', encoding='utf-8') as handle:
        for line_number, raw_line in enumerate(handle, start=1):
            stripped = raw_line.strip()
            if not stripped:
                print(f"Warning: Skipping empty line in vocabulary file '{file_path}'. Line Number is: {line_number}")
                continue
            entries.add(stripped)
    return entries

#load embedding from file
def load_embeddings_from_file(file_path, vocabulary):
    """Read embeddings whose key may span several whitespace-separated tokens.

    The first line of the file is skipped. On every other line, tokens are taken as
    part of the key until the first token (after the leading one) that starts with a
    digit or '-'; that token and everything after it are parsed as floats.

    Args:
        file_path: Path of the UTF-8 embedding file.
        vocabulary: Accepted for call compatibility; not consulted.

    Returns:
        dict mapping each space-joined key to a float32 tensor. If the numeric part of
        a line cannot be parsed, the key maps to torch.zeros(EMBEDDING_DIM).
    """
    numeric_starts = set("-0123456789")
    result = {}
    with open(file_path, 'r', encoding='utf-8') as handle:
        for line_number, line in enumerate(handle, start=1):
            if line_number == 1:
                continue

            values = line.strip().split()
            if len(values) <= 1:
                print(f"Warning: len(values)<1 in embedding file Skipping line: {line}. Line number: {line_number}")
                continue

            # Index of the first token that looks numeric; len(values) if there is none.
            split_at = next(
                (pos for pos in range(1, len(values)) if values[pos][0] in numeric_starts),
                len(values),
            )
            word = " ".join(values[:split_at])
            try:
                result[word] = torch.tensor([float(token) for token in values[split_at:]], dtype=torch.float32)
            except ValueError:
                result[word] = torch.zeros(EMBEDDING_DIM)

    return result


# Load vocabularies (sets of non-blank lines from the vocabulary files)
semantic_types_vocabulary = load_vocabularies(VOCAB_SEMANTIC_TYPES_PATH)
semantic_relations_vocabulary = load_vocabularies(VOCAB_RELATIONS_PATH)

# Load embeddings (dicts mapping a name to a float32 tensor).
# NOTE(review): load_embeddings_from_file accepts the vocabulary argument but never reads it,
# so the vocabularies do not filter what is loaded here.
semantic_type_embeddings = load_embeddings_from_file(SEMANTIC_EMBEDDINGS_PATH,semantic_types_vocabulary)
semantic_relation_embeddings = load_embeddings_from_file(SEMANTIC_RELATIONS_PATH, semantic_relations_vocabulary)

print("Semantic and Relation embeddings and vocabularies loaded!")


def set_seed(seed_value):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs with one value.

    Also switches cuDNN to deterministic mode and disables its autotuner so that
    repeated runs with the same seed select the same kernels.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed_value)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Load the tokenizer that belongs to MODEL_NAME. The training script passes this
# module-level `tokenizer` to the datasets, and other code in this notebook may read it
# as a global, so keep the name stable.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

class BioWSDDataset(Dataset):
    """Row-wise dataset over the WSD dataframe.

    Each item combines the tokenized context text with mean-pooled semantic-type and
    relation vectors, computed once for the abbreviation/context columns and once for
    the gold-CUI columns. Rows that cannot be parsed, or whose label is missing or -1,
    yield None; the DataLoader therefore needs a collate function that drops None
    samples (the default collate raises on them).

    Columns read from each row: TEXT, concept_names, ABBREV, abbrev_cui_semantic_types,
    context_semantic_types, abbrev_relations, context_relations,
    GOLD_CUI_semantic_types, gold_relations, GOLD_SEMANTIC_ENCODING.
    """

    def __init__(self, df, tokenizer, max_length, semantic_type_embeddings, semantic_relation_embeddings, semantic_types_vocabulary, semantic_relations_vocabulary, augment_p=None, embedding_dim=None):
        """Store the dataframe, tokenizer and lookup tables.

        Args:
            df: DataFrame holding the columns listed in the class docstring.
            tokenizer: Callable used to tokenize TEXT; its result must expose
                `input_ids` and `attention_mask`.
            max_length: Padding/truncation length passed to the tokenizer.
            semantic_type_embeddings: Mapping from semantic type to vector.
            semantic_relation_embeddings: Mapping from relation key to vector.
            semantic_types_vocabulary: Stored on the instance; not read by this class.
            semantic_relations_vocabulary: Stored on the instance; not read by this class.
            augment_p: Augmentation probability, stored only (the augmentation call is
                disabled). None resolves to the notebook-level DATA_AUGMENTATION_PROB
                if that name exists, else 0.0. Resolving it here instead of in the
                signature keeps the class definable when the constant is not set.
            embedding_dim: Length of the zero vectors used for unknown/missing
                entries. None resolves to the notebook-level EMBEDDING_DIM.
        """
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.semantic_type_embeddings = semantic_type_embeddings
        self.semantic_relation_embeddings = semantic_relation_embeddings
        self.semantic_types_vocabulary = semantic_types_vocabulary
        self.semantic_relations_vocabulary = semantic_relations_vocabulary
        if augment_p is None:
            augment_p = globals().get("DATA_AUGMENTATION_PROB", 0.0)
        self.augment_p = augment_p
        self.embedding_dim = EMBEDDING_DIM if embedding_dim is None else embedding_dim

    def __len__(self):
        """Number of rows in the wrapped dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Build the model inputs for row `idx`.

        Returns:
            None for an unusable row, otherwise the 9-tuple
            (input_ids, semantic_embeddings, relation_embeddings,
             gold_semantic_embeddings, gold_relation_embeddings, label,
             concepts, attention_mask, abbreviation).
        """
        row = self.df.iloc[idx]

        text = row['TEXT']
        concepts = row["concept_names"]
        abbreviation = row['ABBREV']

        # Use the tokenizer given to this dataset, not a notebook-level global.
        encoded_text = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length)
        text_embeddings = encoded_text.input_ids.squeeze(0)  # drop the batch dimension
        attention_mask = encoded_text.attention_mask.squeeze(0)

        try:
            abbrev_cui_semantic_types = self._parse_list_cell(row['abbrev_cui_semantic_types'])
            context_semantic_types = self._parse_list_cell(row['context_semantic_types'])
            abbrev_relations = self._parse_list_cell(row['abbrev_relations'])
            context_relations = self._parse_list_cell(row['context_relations'])
            gold_cui_semantic_types = self._parse_list_cell(row['GOLD_CUI_semantic_types'])
            gold_cui_relations = self._parse_list_cell(row['gold_relations'])
        except (ValueError, TypeError, SyntaxError) as e:
            # ast.literal_eval raises SyntaxError for truncated/malformed literals.
            print(f"Skipping row {idx}. Invalid semantic/relation lists format or missing values. Error is: {e}")
            return None

        semantic_embeddings = self.create_semantic_embeddings(abbrev_cui_semantic_types, context_semantic_types, self.semantic_type_embeddings)
        relation_embeddings = self.create_relation_embeddings(abbrev_relations, context_relations, self.semantic_relation_embeddings)

        # The gold CUI has no "context" counterpart, so the second list is empty.
        gold_semantic_embeddings = self.create_semantic_embeddings(gold_cui_semantic_types, [], self.semantic_type_embeddings)
        gold_relation_embeddings = self.create_relation_embeddings(gold_cui_relations, [], self.semantic_relation_embeddings)

        label_value = row['GOLD_SEMANTIC_ENCODING']
        if pd.isna(label_value):
            return None
        label = torch.tensor(label_value, dtype=torch.long)
        if label.item() == -1:  # -1 marks a sample without a usable label
            return None

        return text_embeddings, semantic_embeddings, relation_embeddings, gold_semantic_embeddings, gold_relation_embeddings, label, concepts, attention_mask, abbreviation

    @staticmethod
    def _parse_list_cell(value):
        """Turn a dataframe cell into a Python object: missing -> [], list/tuple -> list, str -> literal."""
        if isinstance(value, (list, tuple)):
            return list(value)
        if pd.isna(value):
            return []
        return ast.literal_eval(value)

    @staticmethod
    def _lookup(embeddings, key):
        """Return embeddings[key], or None if the key is absent or unhashable."""
        try:
            return embeddings[key] if key in embeddings else None
        except TypeError:
            return None

    def _zero_vector(self):
        """Placeholder vector for unknown or missing entries."""
        return torch.zeros(self.embedding_dim)

    def create_semantic_embeddings(self, semantic_types1, semantic_types2, embeddings):
        """Mean-pool the vectors of two semantic-type lists into a single vector.

        Args:
            semantic_types1: Types of the abbreviation (or gold CUI). Unknown types
                and an empty list each contribute one zero vector.
            semantic_types2: Context types: a list (items may themselves be lists),
                or a string holding a list literal. An empty value contributes one
                zero vector. Inside nested lists, unknown items are skipped.
            embeddings: Mapping from semantic type to vector.

        Returns:
            1-D tensor: the mean over all collected vectors.
        """
        zero = self._zero_vector()
        vectors = []

        if semantic_types1:
            for stype in semantic_types1:
                vec = self._lookup(embeddings, stype)
                if vec is None:
                    print(f"Warning: Semantic type for abbrev not found in embeddings: {stype}")
                    vec = zero
                vectors.append(vec)
        else:
            print("No semantic types for abbrev provided.")
            vectors.append(zero)

        if not semantic_types2:
            vectors.append(zero)
        elif isinstance(semantic_types2, str):
            # A string here means an earlier parse left a list literal unparsed.
            try:
                parsed = ast.literal_eval(semantic_types2)
            except (SyntaxError, ValueError):
                print(f"Error parsing context_semantic_types: {semantic_types2}")
                vectors.append(zero)
            else:
                if isinstance(parsed, list) and all(isinstance(item, str) for item in parsed):
                    for stype in parsed:
                        vec = self._lookup(embeddings, stype)
                        if vec is None:
                            print(f"Warning: Semantic type for context not found in embeddings: {stype}")
                            vec = zero
                        vectors.append(vec)
                else:
                    print(f"Unexpected format for context_semantic_types: {parsed}")
                    vectors.append(zero)
        elif isinstance(semantic_types2, list):
            for stype in semantic_types2:
                if isinstance(stype, list):
                    for item in stype:
                        vec = self._lookup(embeddings, item)
                        if vec is not None:
                            vectors.append(vec)
                    continue
                vec = self._lookup(embeddings, stype) if isinstance(stype, str) else None
                if vec is None:
                    print(f"Warning: Semantic type for context not found in embeddings: {stype}")
                    vec = zero
                vectors.append(vec)

        # `vectors` is never empty: the first list always contributes at least one entry.
        return torch.mean(torch.stack(vectors), dim=0)

    def _collect_relation_vectors(self, relations, embeddings, source_name):
        """Return one vector per (label, concept) pair found in `relations`.

        Entries that are lists are treated as lists of pairs. Pairs missing from
        `embeddings` contribute a zero vector; entries that cannot be unpacked into
        exactly two values are reported and skipped.
        """
        vectors = []
        for relation in relations or []:
            candidates = relation if isinstance(relation, list) else [relation]
            for rel in candidates:
                try:
                    r_label, r_concept = rel
                except (ValueError, TypeError):
                    print(f"Warning: Unexpected format in {source_name}: {rel}")
                    continue
                vec = self._lookup(embeddings, (r_label, r_concept))
                vectors.append(self._zero_vector() if vec is None else vec)
        return vectors

    def create_relation_embeddings(self, relations1, relations2, embeddings):
        """Mean-pool the vectors of two relation lists into a single vector.

        NOTE(review): lookups use the tuple (label, concept) as key. The file loader in
        this notebook builds space-joined string keys, so confirm the relation
        embedding table really is keyed by tuples; otherwise every lookup misses and
        the result is always a zero vector.

        Returns:
            1-D tensor: the mean over all collected vectors, or a zero vector when
            neither list yields a pair.
        """
        vectors = self._collect_relation_vectors(relations1, embeddings, "relations1")
        vectors += self._collect_relation_vectors(relations2, embeddings, "relations2")
        if not vectors:
            return self._zero_vector()
        return torch.mean(torch.stack(vectors), dim=0)

class BioWSDClassifier(nn.Module):
    """Classifier over a BERT encoding of the context plus auxiliary vectors.

    The input to the first linear layer is the concatenation, in this order, of:
    the [CLS] vector of the context text, the four semantic/relation vectors
    (each `embedding_dim` long), the [CLS] vector of the concept prompt and the
    [CLS] vector of the abbreviation.
    """

    def __init__(self, embedding_dim, num_labels, model_name = MODEL_NAME): # num_labels is the number of unique label encodings
        """Load the encoder/tokenizer for `model_name` and build the classification head.

        Args:
            embedding_dim: Length of each semantic / relation vector.
            num_labels: Number of output classes.
            model_name: Hugging Face model id or path for AutoModel / AutoTokenizer.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_labels = num_labels
        self.bert = AutoModel.from_pretrained(model_name)
        # Keep a tokenizer that is guaranteed to match the encoder; forward() uses it
        # for the abbreviation and concept strings.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size  # 768 for BERT-base sized encoders
        # Three [CLS] vectors (text, concepts, abbreviation) + four auxiliary vectors.
        self.fc1 = nn.Linear(3 * hidden_size + 4 * embedding_dim, 256)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, num_labels)
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.to(self.device)

    @staticmethod
    def _concept_prompt(entry):
        """Build the text prompt for one sample's concepts.

        `entry` is either None, a string (a list literal such as "['a', 'b']" or a
        plain name) or any other value. Every sample yields exactly one prompt so
        the prompt list stays aligned with the batch.
        """
        if entry is None:
            return "CUI: no concept"
        if not isinstance(entry, str):
            return f"CUI: {entry} "
        try:
            parsed = ast.literal_eval(entry)
        except (ValueError, SyntaxError):
            parsed = [entry]  # not a literal: treat the raw string as a single concept name
        if not isinstance(parsed, (list, tuple, set)):
            parsed = [parsed]
        names = list(parsed)
        if not names or any(name is None for name in names):
            return "CUI: no concept"
        return "".join(f"CUI: {name} " for name in names)

    @staticmethod
    def _as_list(values, batch_size):
        """Normalize a collated batch field (None, str, tensor or sequence) to a list."""
        if values is None:
            return [None] * batch_size
        if isinstance(values, str):
            return [values]
        if torch.is_tensor(values):
            items = values.tolist()
            return items if isinstance(items, list) else [items]
        return list(values)

    def _encode_cls(self, texts):
        """Tokenize a list of strings and return their [CLS] vectors, shape (len(texts), hidden)."""
        encoded = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
        # Pass the attention mask so padding tokens do not influence the [CLS] vector.
        output = self.bert(
            input_ids=encoded["input_ids"].to(self.device),
            attention_mask=encoded["attention_mask"].to(self.device),
        )
        return output.last_hidden_state[:, 0, :]

    def forward(self, text_embeddings, semantic_embeddings, relation_embeddings, gold_semantic_embeddings, gold_relation_embeddings,concepts, label, attention_mask = None, abbreviation = None):
        """Compute class logits for one batch.

        Args:
            text_embeddings: Token ids of the context text (despite the name, these are
                fed to the encoder as `input_ids`).
            semantic_embeddings, relation_embeddings, gold_semantic_embeddings,
            gold_relation_embeddings: One vector per sample, each `embedding_dim` long.
            concepts: Per-sample concept names as collated by the DataLoader (sequence
                of strings, usually list literals); None yields "CUI: no concept" prompts.
            label: Accepted for call compatibility; not used to compute the logits.
            attention_mask: Attention mask matching `text_embeddings`, or None.
            abbreviation: Per-sample abbreviation strings; None is encoded as empty strings.

        Returns:
            Tensor of logits with one row per sample and `num_labels` columns.
        """
        input_ids = text_embeddings.to(self.device)
        semantic_embeddings = semantic_embeddings.to(self.device)
        relation_embeddings = relation_embeddings.to(self.device)
        gold_semantic_embeddings = gold_semantic_embeddings.to(self.device)
        gold_relation_embeddings = gold_relation_embeddings.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_cls = bert_output.last_hidden_state[:, 0, :]  # [CLS] vector of the context
        batch_size = text_cls.shape[0]

        # Abbreviation strings. The tokenizer inserts the special tokens itself, so no
        # manual "[CLS] ... [SEP]" wrapping is needed.
        abbreviation_texts = ["" if item is None else str(item) for item in self._as_list(abbreviation, batch_size)]
        abbreviation_embeddings = self._encode_cls(abbreviation_texts)

        # Concept prompts ("CUI: <name> ..."), one per sample, encoded with the same encoder.
        concept_prompts = [self._concept_prompt(entry) for entry in self._as_list(concepts, batch_size)]
        concepts_embeddings = self._encode_cls(concept_prompts)

        combined_embeddings = torch.cat(
            (text_cls, semantic_embeddings, relation_embeddings, gold_semantic_embeddings, gold_relation_embeddings, concepts_embeddings, abbreviation_embeddings), dim=1
        )

        x = self.fc1(combined_embeddings)
        x = self.relu1(x)
        x = self.dropout1(x)
        logits = self.fc2(x)
        return logits


def measure_efficiency(model, dataloader, device):
    """Time one inference pass over `dataloader` and report latency, throughput and peak GPU memory.

    Args:
        model: Callable model with an `eval()` method; it is invoked with the keyword
            arguments text_embeddings, semantic_embeddings, relation_embeddings,
            gold_semantic_embeddings, gold_relation_embeddings, concepts, label,
            attention_mask and abbreviation.
        dataloader: Iterable of 9-tuples in the order produced by
            BioWSDDataset.__getitem__; None batches are skipped.
        device: Device the tensor fields are moved to before the call.

    Returns:
        (latency_ms, throughput, peak_memory): milliseconds per sample, samples per
        second and peak allocated CUDA memory in GB. Peak memory is 0.0 when CUDA is
        not available; latency and throughput are 0.0 when no sample was processed.
    """
    import time  # local import: `time` is not imported at the top of the notebook

    use_cuda = torch.cuda.is_available()
    model.eval()

    # Reset the CUDA counters only when CUDA exists; these calls fail on CPU-only hosts.
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    total_samples = 0
    start_time = time.perf_counter()

    with torch.no_grad():
        for batch in dataloader:
            if batch is None:
                continue

            # Field order follows BioWSDDataset.__getitem__.
            text_embeddings, semantic_embeddings, relation_embeddings, \
            gold_semantic_embeddings, gold_relation_embeddings, labels, \
            concepts, attention_mask, abbreviation = batch

            text_embeddings = text_embeddings.to(device)

            model(
                text_embeddings=text_embeddings,
                semantic_embeddings=semantic_embeddings.to(device),
                relation_embeddings=relation_embeddings.to(device),
                gold_semantic_embeddings=gold_semantic_embeddings.to(device),
                gold_relation_embeddings=gold_relation_embeddings.to(device),
                concepts=concepts,
                label=labels.to(device),  # required by the signature, not needed for inference
                attention_mask=attention_mask.to(device),
                abbreviation=abbreviation,
            )

            total_samples += text_embeddings.size(0)

    # CUDA kernels run asynchronously; wait for them before reading the clock.
    if use_cuda:
        torch.cuda.synchronize()
    total_time = time.perf_counter() - start_time

    peak_memory = torch.cuda.max_memory_allocated() / (1024 ** 3) if use_cuda else 0.0

    if total_samples == 0:
        print("Warning: no samples were processed; efficiency metrics are undefined.")
        return 0.0, 0.0, peak_memory

    latency_ms = (total_time * 1000) / total_samples
    throughput = total_samples / total_time if total_time > 0 else float("inf")

    print(f"--- Efficiency Metrics ---")
    print(f"Total Samples: {total_samples}")
    print(f"Total Time: {total_time:.4f}s")
    print(f"Inference Latency: {latency_ms:.4f} ms/sample")
    print(f"Throughput: {throughput:.2f} samples/sec")
    print(f"Peak GPU Memory: {peak_memory:.4f} GB")

    return latency_ms, throughput, peak_memory

import statistics  # used only by the final mean / standard-deviation summary


def _collate_skip_invalid(samples):
    """Collate a batch while dropping samples the dataset rejected (returned as None).

    Returns None when every sample of the batch was rejected; the loops below skip
    such batches. Without this, the default collate raises on None samples.
    """
    kept = [sample for sample in samples if sample is not None]
    if not kept:
        return None
    return torch.utils.data.dataloader.default_collate(kept)


def _forward_batch(model, batch, device):
    """Unpack one collated batch, move its tensor fields to `device` and run the model.

    The concepts and abbreviation fields are passed through exactly as collated;
    only the tensor fields are moved.

    Returns:
        (logits, labels) with labels already on `device`.
    """
    (text_embeddings, semantic_embeddings, relation_embeddings,
     gold_semantic_embeddings, gold_relation_embeddings, labels,
     concepts, attention_mask, abbreviation) = batch

    labels = labels.to(device)
    logits = model(
        text_embeddings.to(device),
        semantic_embeddings.to(device),
        relation_embeddings.to(device),
        gold_semantic_embeddings.to(device),
        gold_relation_embeddings.to(device),
        concepts,
        labels,
        attention_mask=attention_mask.to(device),
        abbreviation=abbreviation,
    )
    return logits, labels


# Run the full train/evaluate cycle once per seed and collect the metrics.
seeds = [42, 43, 44] # try three different seeds
all_results = []
for seed in seeds:
    print(f"\n=== Starting Experiment with Seed: {seed} ===")
    set_seed(seed)
    SEED = seed

    # 1. Model (created right after seeding so weight initialisation is reproducible)
    model = BioWSDClassifier(EMBEDDING_DIM, num_labels)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    # 2. Data: 80/20 train/test split, seeded per run
    train_df, test_df = train_test_split(df, test_size=0.2, random_state=SEED)

    train_dataset = BioWSDDataset(train_df, tokenizer, MAX_SEQ_LENGTH, semantic_type_embeddings, semantic_relation_embeddings, semantic_types_vocabulary, semantic_relations_vocabulary)
    test_dataset = BioWSDDataset(test_df,tokenizer, MAX_SEQ_LENGTH, semantic_type_embeddings, semantic_relation_embeddings, semantic_types_vocabulary, semantic_relations_vocabulary)

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=_collate_skip_invalid)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=_collate_skip_invalid)

    # 3. Optimizer and linear-decay scheduler sized for the maximum number of steps
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    num_training_steps = EPOCHS * len(train_dataloader)
    lr_scheduler = get_scheduler(
        "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )

    print ("Model Successfully Created")
    print ("Data successfully prepared!")

    # 4. Training with early stopping
    loss_fct = nn.CrossEntropyLoss()
    best_val_loss = float('inf')
    patience = 2          # Stop if no improvement after 2 epochs
    patience_counter = 0
    checkpoint_saved = False  # guards against loading a stale checkpoint from an earlier seed

    print(f"Starting training with Early Stopping (Patience={patience})...")

    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch + 1}/{EPOCHS}")

        # --- TRAINING PHASE ---
        model.train()
        train_loss = 0.0
        train_batches = 0

        progress_bar = tqdm(train_dataloader, desc="Training")

        for batch in progress_bar:
            if batch is None: continue

            optimizer.zero_grad()

            logits, labels = _forward_batch(model, batch, device)
            loss = loss_fct(logits, labels)
            loss.backward()

            optimizer.step()
            lr_scheduler.step()

            train_loss += loss.item()
            train_batches += 1
            progress_bar.set_postfix({'loss': loss.item()})

        # Average over the batches that actually ran (skipped batches add no loss).
        avg_train_loss = train_loss / max(train_batches, 1)
        print(f"Average Training Loss: {avg_train_loss:.4f}")

        # --- VALIDATION PHASE (For Early Stopping) ---
        # NOTE(review): the test split doubles as the validation set, so checkpoint
        # selection sees the data the final metrics are reported on. Consider carving
        # a separate validation split out of train_df.
        model.eval()
        val_loss = 0.0
        val_batches = 0

        with torch.no_grad():
            for batch in test_dataloader:
                if batch is None: continue

                logits, labels = _forward_batch(model, batch, device)
                val_loss += loss_fct(logits, labels).item()
                val_batches += 1

        avg_val_loss = val_loss / max(val_batches, 1)
        print(f"Validation Loss: {avg_val_loss:.4f}")

        # --- EARLY STOPPING CHECK ---
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # Save the best model state
            torch.save(model.state_dict(), "best_model_state.pt")
            checkpoint_saved = True
            print("Validation loss decreased. Saving model...")
        else:
            patience_counter += 1
            print(f"EarlyStopping counter: {patience_counter} out of {patience}")
            if patience_counter >= patience:
                print("Early stopping triggered! Stopping training.")
                break

    # Load the best model before final evaluation
    if checkpoint_saved:
        print("Loading best model for final evaluation...")
        model.load_state_dict(torch.load("best_model_state.pt", map_location=device))
    else:
        print("Warning: no checkpoint was saved for this seed; evaluating the last model state.")

    print ("Training Complete")

    # 5. FINAL EVALUATION
    model.eval()
    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for batch in test_dataloader:
            if batch is None: continue

            logits, labels = _forward_batch(model, batch, device)
            predictions = torch.argmax(logits, dim=-1)
            true_labels.extend(labels.tolist())
            predicted_labels.extend(predictions.tolist())

    # --- Metrics Calculation ---
    accuracy = accuracy_score(true_labels, predicted_labels)

    macro_f1 = f1_score(true_labels, predicted_labels, average='macro')
    weighted_f1 = f1_score(true_labels, predicted_labels, average='weighted')

    weighted_precision = precision_score(true_labels, predicted_labels, average='weighted')
    weighted_recall = recall_score(true_labels, predicted_labels, average='weighted')

    print("--------------RESULTS-------------")
    print(f"*****Model Name: {MODEL_NAME} | Seed: {SEED}*****")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  F1-score (Macro):  {macro_f1:.4f}")
    print(f"  F1-score (Weighted):  {weighted_f1:.4f}")
    print(f"  Precision (Weighted):  {weighted_precision:.4f}")
    print(f"  Recall (Weighted):  {weighted_recall:.4f}")

    # Store results
    results = {
        'seed': SEED,
        'accuracy': accuracy,
        'f1_macro': macro_f1,
        'f1_weighted': weighted_f1,
        'precision_weighted': weighted_precision,
        'recall_weighted': weighted_recall,
    }
    all_results.append(results)
    print("---------")

    # --- SPECIAL CHECKS (Efficiency & Error Analysis) ---
    # Only run these on specific seeds to save time
    if(SEED == 42): # Or whatever seed you prefer
         measure_efficiency(model, test_dataloader, device)
         print("---------")
    if(SEED == 44):
         # NOTE(review): evaluate_and_save is not defined in the visible part of this
         # notebook; it must exist before this point. The output file name says
         # "scibert" although MODEL_NAME selects the encoder — confirm the intended name.
         evaluate_and_save(model, test_dataloader, device, "scibert_full_model_results.csv")
         print("---------")

# --- FINAL SUMMARY ---
# Mean and sample standard deviation of the weighted F1 across seeds.
f1_weighted_scores = [res['f1_weighted'] for res in all_results]
mean_score = statistics.mean(f1_weighted_scores)
if len(f1_weighted_scores) > 1:
    std_dev = statistics.stdev(f1_weighted_scores)
else:
    std_dev = 0.0

print(f"\nFinal Result (Weighted F1): {mean_score:.4f} \u00b1 {std_dev:.4f}")
