In [1]:
# !pip install spacy
# !pip install wandb
# !python -m spacy download en_core_web_sm
# !pip install datasets
# !pip install transformers[torch]
# !pip install evaluate
# !pip install seqeval

In [2]:
import pandas as pd
import numpy as np
import re
import spacy
import random
from sklearn.metrics import precision_score, recall_score, f1_score, precision_recall_fscore_support
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import silhouette_score
import wandb
api_key = 'e39350c7003ab06462caad3dd36b7f6a14bf8670'
!wandb login e39350c7003ab06462caad3dd36b7f6a14bf8670
from transformers import EvalPrediction
from transformers import BertTokenizer, BertForTokenClassification, Trainer, TrainingArguments
from datasets import load_dataset
import evaluate
from tqdm import tqdm
import torch 
torch.manual_seed(42)
np.random.seed(42)
nlp = spacy.load("en_core_web_sm")

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/aryanrr/.netrc


# Transforming Data

In [3]:
def create_df(file_path, start_index):
    with open(file_path, 'r') as f:
        lines = f.readlines()

    # Initialize lists to store data
    topics = []
    abstracts = []
    texts = []
    entities_list = []
    classes_list = []

    # Iterate through the lines of the file
    i = start_index
    while i < len(lines):
        # Extract topic
    #     topic_match = re.match(r'^(\d+\|t\|.*)\n$', lines[i])
        topic_match = re.match(r'^\d+\|t\|(.*)\n$', lines[i])
        if topic_match:
            topic = topic_match.group(1)
            i += 1
        else:
            break  # Break if no valid topic found

        # Extract abstract
    #     abstract_match = re.match(r'^(\d+\|a\|.*)\n$', lines[i])
        abstract_match = re.match(r'^\d+\|a\|(.*)\n$', lines[i])
        if abstract_match:
            abstract = abstract_match.group(1)
            i += 1
        else:
            break  # Break if no valid abstract found

        text = topic + " " + abstract
        # Initialize lists to store entities and classes for this instance
        entities = []
        classes = []

        # Extract entities and classes
        while i < len(lines) and lines[i] != "\n":
            entity_match = re.match(r'^(\d+)\t(\d+)\t(\d+)\t([^\t]+)\t([^\t]+)\t([^\n]+)\n$', lines[i])
            if entity_match:
                entities.append(entity_match.group(4))
                classes.append(entity_match.group(5))
                i += 1
            else:
                break  # Break if no valid entity found

        # Store data for this instance
        topics.append(topic)
        abstracts.append(abstract)
        texts.append(text)
        entities_list.append(entities)
        classes_list.append(classes)

        # Skip the empty line
        i += 1

    # Create DataFrame
    df = pd.DataFrame({
        'Topic': topics,
        'Abstract': abstracts,
        'Text': texts,
        'Entities': entities_list,
        'Classes': classes_list
    })
    
    return df

In [4]:
train_df = create_df("NCBItrainset_corpus.txt", 1)
dev_df = create_df("NCBIdevelopset_corpus.txt", 1)
test_df = create_df("NCBItestset_corpus.txt", 0)

In [5]:
train_df.head()

Unnamed: 0,Topic,Abstract,Text,Entities,Classes
0,A common human skin tumour is caused by activa...,WNT signalling orchestrates a number of develo...,A common human skin tumour is caused by activa...,"[skin tumour, cancer, colon cancers, adenomato...","[DiseaseClass, DiseaseClass, DiseaseClass, Spe..."
1,HFE mutations analysis in 711 hemochromatosis ...,Hereditary hemochromatosis (HH) is a common au...,HFE mutations analysis in 711 hemochromatosis ...,"[hemochromatosis, hemochromatosis, Hereditary ...","[Modifier, SpecificDisease, SpecificDisease, S..."
2,Germline BRCA1 alterations in a population-bas...,The objective of this study was to provide mor...,Germline BRCA1 alterations in a population-bas...,"[ovarian cancer, breast cancer, ovarian cancer...","[Modifier, Modifier, Modifier, Modifier, Modif..."
3,"Identification of APC2, a homologue of the ade...",The adenomatous polyposis coli (APC) tumour-su...,"Identification of APC2, a homologue of the ade...","[adenomatous polyposis coli tumour, adenomatou...","[Modifier, Modifier, Modifier, Modifier, Speci..."
4,Familial deficiency of the seventh component o...,The serum of a 29-year old woman with a recent...,Familial deficiency of the seventh component o...,[Familial deficiency of the seventh component ...,"[SpecificDisease, DiseaseClass, SpecificDiseas..."


In [6]:
dev_df.head()

Unnamed: 0,Topic,Abstract,Text,Entities,Classes
0,Somatic-cell selection is a major determinant ...,X-chromosome inactivation in mammals is regard...,Somatic-cell selection is a major determinant ...,"[enzyme deficiency, glucose-6-phosphate dehydr...","[DiseaseClass, SpecificDisease, SpecificDiseas..."
1,"The ataxia-telangiectasia gene product, a cons...",The product of the ataxia-telangiectasia gene ...,"The ataxia-telangiectasia gene product, a cons...","[ataxia-telangiectasia, ataxia-telangiectasia,...","[Modifier, Modifier, Modifier]"
2,Molecular basis for Duarte and Los Angeles var...,Human orythrocytes that are homozygous for the...,Molecular basis for Duarte and Los Angeles var...,"[Duarte and Los Angeles variant galactosemia, ...","[CompositeMention, SpecificDisease, SpecificDi..."
3,An intronic mutation in a lariat branchpoint s...,The first step in the splicing of an intron fr...,An intronic mutation in a lariat branchpoint s...,"[inherited human disorder, fish-eye disease, f...","[DiseaseClass, SpecificDisease, SpecificDiseas..."
4,Genetic heterogeneity in hereditary breast can...,The common hereditary forms of breast cancer h...,Genetic heterogeneity in hereditary breast can...,"[hereditary breast cancer, breast cancer, here...","[SpecificDisease, SpecificDisease, SpecificDis..."


In [7]:
test_df.head()

Unnamed: 0,Topic,Abstract,Text,Entities,Classes
0,Genetic mapping of the copper toxicosis locus ...,Abnormal hepatic copper accumulation is recogn...,Genetic mapping of the copper toxicosis locus ...,"[copper toxicosis, hepatic copper accumulation...","[Modifier, SpecificDisease, DiseaseClass, Spec..."
1,Molecular analysis of the APC gene in 205 fami...,BACKGROUND/AIMS The development of colorectal...,Molecular analysis of the APC gene in 205 fami...,"[APC, FAP, APC, colorectal cancer, colorectal ...","[Modifier, SpecificDisease, Modifier, Modifier..."
2,A European multicenter study of phenylalanine ...,Phenylketonuria (PKU) and mild hyperphenylalan...,A European multicenter study of phenylalanine ...,"[phenylalanine hydroxylase deficiency, Phenylk...","[SpecificDisease, SpecificDisease, SpecificDis..."
3,Disruption of splicing regulated by a CUG-bind...,Myotonic dystrophy (DM) is caused by a CTG exp...,Disruption of splicing regulated by a CUG-bind...,"[myotonic dystrophy, Myotonic dystrophy, DM, D...","[SpecificDisease, SpecificDisease, SpecificDis..."
4,Maternal disomy and Prader-Willi syndrome cons...,Maternal uniparental disomy (UPD) for chromoso...,Maternal disomy and Prader-Willi syndrome cons...,"[Maternal disomy, Prader-Willi syndrome, Mater...","[DiseaseClass, SpecificDisease, SpecificDiseas..."


# Preprocess Text

In [8]:
def preprocess_text(text):
    # Tokenization and lemmatization using SpaCy, removing stop words
    doc = nlp(text)
    tokens = [token.text for token in doc if not token.is_stop and token.is_alpha]
    lemmas = [token.lemma_ for token in doc if not token.is_stop and token.is_alpha]
    return tokens, lemmas

In [9]:
# Apply preprocessing to Abstract column for train_df
train_df[['Text_Tokens', 'Text_Lemmas']] = train_df['Text'].apply(lambda x: pd.Series(preprocess_text(x)))

# Apply preprocessing to Abstract column for dev_df
dev_df[['Text_Tokens', 'Text_Lemmas']] = dev_df['Text'].apply(lambda x: pd.Series(preprocess_text(x)))

# Apply preprocessing to Abstract column for test_df
test_df[['Text_Tokens', 'Text_Lemmas']] = test_df['Text'].apply(lambda x: pd.Series(preprocess_text(x)))

In [10]:
dev_df.head()

Unnamed: 0,Topic,Abstract,Text,Entities,Classes,Text_Tokens,Text_Lemmas
0,Somatic-cell selection is a major determinant ...,X-chromosome inactivation in mammals is regard...,Somatic-cell selection is a major determinant ...,"[enzyme deficiency, glucose-6-phosphate dehydr...","[DiseaseClass, SpecificDisease, SpecificDiseas...","[Somatic, cell, selection, major, determinant,...","[somatic, cell, selection, major, determinant,..."
1,"The ataxia-telangiectasia gene product, a cons...",The product of the ataxia-telangiectasia gene ...,"The ataxia-telangiectasia gene product, a cons...","[ataxia-telangiectasia, ataxia-telangiectasia,...","[Modifier, Modifier, Modifier]","[ataxia, telangiectasia, gene, product, consti...","[ataxia, telangiectasia, gene, product, consti..."
2,Molecular basis for Duarte and Los Angeles var...,Human orythrocytes that are homozygous for the...,Molecular basis for Duarte and Los Angeles var...,"[Duarte and Los Angeles variant galactosemia, ...","[CompositeMention, SpecificDisease, SpecificDi...","[Molecular, basis, Duarte, Los, Angeles, varia...","[molecular, basis, Duarte, Los, Angeles, varia..."
3,An intronic mutation in a lariat branchpoint s...,The first step in the splicing of an intron fr...,An intronic mutation in a lariat branchpoint s...,"[inherited human disorder, fish-eye disease, f...","[DiseaseClass, SpecificDisease, SpecificDiseas...","[intronic, mutation, lariat, branchpoint, sequ...","[intronic, mutation, lariat, branchpoint, sequ..."
4,Genetic heterogeneity in hereditary breast can...,The common hereditary forms of breast cancer h...,Genetic heterogeneity in hereditary breast can...,"[hereditary breast cancer, breast cancer, here...","[SpecificDisease, SpecificDisease, SpecificDis...","[Genetic, heterogeneity, hereditary, breast, c...","[genetic, heterogeneity, hereditary, breast, c..."


## BIO Notation

In [11]:
def get_specific_disease_entities(row):
    return [entity for entity, class_ in zip(row['Entities'], row['Classes']) if class_ == "SpecificDisease"]

In [12]:
train_df['SpecificDiseaseEntities'] = train_df.apply(get_specific_disease_entities, axis=1)
dev_df['SpecificDiseaseEntities'] = dev_df.apply(get_specific_disease_entities, axis=1)
test_df['SpecificDiseaseEntities'] = test_df.apply(get_specific_disease_entities, axis=1)

In [13]:
# id2label = {
#     0: "O",
#     1: "B",
#     2: "I",
# }

def generate_bio_ner(row):
    ner_list = []
    specific_disease_entities = row['SpecificDiseaseEntities']
    for token in row['Text_Tokens']:
        matching_entities = [entity for entity in specific_disease_entities if token in entity.split()]
        if matching_entities:
            entity = matching_entities[0]
            if token == entity.split()[0]:
                ner_list.append(1)
            else:
                ner_list.append(2)
        else:
            ner_list.append(0)
    return ner_list

In [14]:
train_df['ner_tags'] = train_df.apply(generate_bio_ner, axis=1)
dev_df['ner_tags'] = dev_df.apply(generate_bio_ner, axis=1)
test_df['ner_tags'] = test_df.apply(generate_bio_ner, axis=1)

# Statistics

In [15]:
# Calculate statistics for train_df
train_stats = {
    'Dataset': 'Train',
    'Instances': len(train_df),
    'Average Text Length': int(train_df['Text'].apply(lambda x: len(x.split())).mean())
}

# Calculate statistics for dev_df
dev_stats = {
    'Dataset': 'Dev',
    'Instances': len(dev_df),
    'Average Text Length': int(dev_df['Text'].apply(lambda x: len(x.split())).mean())
}

# Calculate statistics for test_df
test_stats = {
    'Dataset': 'Test',
    'Instances': len(test_df),
    'Average Text Length': int(test_df['Text'].apply(lambda x: len(x.split())).mean())
}

# Combine statistics into a DataFrame
pd.DataFrame([train_stats, dev_stats, test_stats])


Unnamed: 0,Dataset,Instances,Average Text Length
0,Train,593,190
1,Dev,100,201
2,Test,100,204


In [16]:
# Value counts of each dataframe
value_counts_df = pd.DataFrame({
    'Train': train_df['Classes'].explode().value_counts(),
    'Dev': dev_df['Classes'].explode().value_counts(),
    'Test': test_df['Classes'].explode().value_counts()
})

value_counts_df

Unnamed: 0,Train,Dev,Test
SpecificDisease,2972,412,555
Modifier,1289,214,264
DiseaseClass,769,126,121
CompositeMention,115,35,20


# Baselines

## Entity Extraction

### Baseline 1 - Random Performance

In [17]:
def baseline_entity_extraction(text):
    # Randomly decide the length of entities (in words)
    entity_length = random.randint(1, 3)
    
    # Randomly decide how many entities to extract
    num_entities = random.randint(1, 5)
    
    # Extract entities based on random decisions
    entities = []
    for _ in range(num_entities):
        start_idx = random.randint(0, len(text.split()) - entity_length)
        end_idx = start_idx + entity_length
        entity = " ".join(text.split()[start_idx:end_idx])
        if len(entity) > 0:
            entities.append(entity)
    
    return entities

In [18]:
# Apply baseline entity extraction to test_df
test_df['Baseline_1_Entities'] = test_df['Text'].apply(baseline_entity_extraction)

In [19]:
# Function to compute precision, recall, and F1 score
def evaluate_baseline(true_entities, predicted_entities):
    true_entities = set(true_entities)
    predicted_entities = set(predicted_entities)
    
    # Calculate precision, recall, and F1 score
    precision = len(true_entities.intersection(predicted_entities)) / len(predicted_entities)
    recall = len(true_entities.intersection(predicted_entities)) / len(true_entities)
    f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
    
    return precision, recall, f1

In [20]:
# Compute precision, recall, and F1 score for baseline
precision_scores = []
recall_scores = []
f1_scores = []

for true_entities, predicted_entities in zip(test_df['Entities'], test_df['Baseline_1_Entities']):
    precision, recall, f1 = evaluate_baseline(true_entities, predicted_entities)
    precision_scores.append(precision)
    recall_scores.append(recall)
    f1_scores.append(f1)

# Calculate average scores
average_precision = sum(precision_scores) / len(precision_scores)
average_recall = sum(recall_scores) / len(recall_scores)
average_f1 = sum(f1_scores) / len(f1_scores)

print("Baseline 1 Evaluation Results:")
print(f"Average Precision: {average_precision:.3f}")
print(f"Average Recall: {average_recall:.3f}")
print(f"Average F1 Score: {average_f1:.3f}")

Baseline 1 Evaluation Results:
Average Precision: 0.002
Average Recall: 0.003
Average F1 Score: 0.002


### Baseline 2 - Most Frequent

In [21]:
from collections import Counter

# Calculate the most frequent length of entities in the dataset
all_entities = train_df['Entities'].explode()
entity_lengths = [len(entity.split()) for entity in all_entities]
most_frequent_length = Counter(entity_lengths).most_common(1)[0][0]

def baseline_entity_extraction_most_frequent(text, most_frequent_length):
    # Determine the number of entities to extract
    num_entities = random.randint(1, 3)  # You can adjust the range as needed
    
    # Extract entities with the most frequent length
    entities = []
    for _ in range(num_entities):
        start_idx = random.randint(0, len(text.split()) - most_frequent_length)
        end_idx = start_idx + most_frequent_length
        entity = " ".join(text.split()[start_idx:end_idx])
        if len(entity) > 0:
            entities.append(entity)
    
    return entities

In [22]:
# Apply baseline entity extraction to test_df using the most frequent length
test_df['Baseline_2_Entities'] = test_df['Text'].apply(lambda x: baseline_entity_extraction_most_frequent(x, most_frequent_length))

In [23]:
# Evaluate Baseline 2
precision_scores_most_frequent = []
recall_scores_most_frequent = []
f1_scores_most_frequent = []

for true_entities, predicted_entities in zip(test_df['Entities'], test_df['Baseline_2_Entities']):
    precision, recall, f1 = evaluate_baseline(true_entities, predicted_entities)
    precision_scores_most_frequent.append(precision)
    recall_scores_most_frequent.append(recall)
    f1_scores_most_frequent.append(f1)

# Calculate average scores for Baseline 2
average_precision_most_frequent = sum(precision_scores_most_frequent) / len(precision_scores_most_frequent)
average_recall_most_frequent = sum(recall_scores_most_frequent) / len(recall_scores_most_frequent)
average_f1_most_frequent = sum(f1_scores_most_frequent) / len(f1_scores_most_frequent)

print("Baseline 2 Evaluation Results:")
print(f"Average Precision: {average_precision_most_frequent:.3f}")
print(f"Average Recall: {average_recall_most_frequent:.3f}")
print(f"Average F1 Score: {average_f1_most_frequent:.3f}")

Baseline 2 Evaluation Results:
Average Precision: 0.020
Average Recall: 0.004
Average F1 Score: 0.006


## Clustering

In [24]:
def random_clustering(n_samples, n_clusters):
    np.random.seed(630)
    labels = np.random.randint(0, n_clusters, size=n_samples)
    return labels

def clustering_baseline(test_df, col1, col2):
    specific_disease_entities = test_df.apply(lambda row: [entity for entity, class_ in zip(row[col1], row[col2]) if class_ == 'SpecificDisease'], axis=1)
    unique_entities = set([entity for sublist in specific_disease_entities for entity in sublist])

    vectorizer = TfidfVectorizer()
    X = vectorizer.fit_transform(unique_entities)

    n_clusters = 5
    n_samples = X.shape[0]
    
    random_labels = random_clustering(n_samples, n_clusters)
    silhouette_avg = silhouette_score(X, random_labels)
    return silhouette_avg

### Baseline 1 - Random Performance

In [25]:
# Function to randomly assign classes to entities
def baseline_class_assignment(entities):
    # Randomly generate classes for each entity
    np.random.seed(630)
    random_classes = [random.choice(['DiseaseClass', 'SpecificDisease', 'Modifier', 'CompositeMention']) for _ in entities]
    return random_classes

# Apply baseline class assignment to test_df
test_df['Baseline_1_Classes'] = test_df['Baseline_1_Entities'].apply(baseline_class_assignment)

In [26]:
baseline1_silhouette = clustering_baseline(test_df, 'Baseline_1_Entities', 'Baseline_1_Classes')
print(f'Silhouette Score for Baseline 1: {baseline1_silhouette}')

Silhouette Score for Baseline 1: -0.0159664500410288


### Baseline 2 - Most Frequent

In [27]:
# Calculate most frequent class from the training data
most_frequent_class = train_df['Classes'].explode().value_counts().idxmax()

# Function to assign most frequent class to entities
def baseline_class_assignment_most_frequent(entities):
    # Assign most frequent classes to each entity
    return [most_frequent_class] * len(entities)

# Apply baseline class assignment to test_df
test_df['Baseline_2_Classes'] = test_df['Baseline_2_Entities'].apply(baseline_class_assignment_most_frequent)

In [28]:
baseline2_silhouette = clustering_baseline(test_df, 'Baseline_2_Entities', 'Baseline_2_Classes')
print(f'Silhouette Score for Baseline 2: {baseline2_silhouette}')

Silhouette Score for Baseline 2: -0.01225949881011199


In [29]:
# Create a DataFrame for evaluation results
evaluation_results = pd.DataFrame({
    'Baseline': ['Baseline 1', 'Baseline 2'],
    'Avg Precision': [round(average_precision, 3), round(average_precision_most_frequent, 3)],
    'Avg Recall': [round(average_recall, 3), round(average_recall_most_frequent, 3)],
    'Avg F1 Score': [round(average_f1, 3), round(average_f1_most_frequent, 3)],
    'Silhouette Score': [round(baseline1_silhouette, 3), round(baseline2_silhouette, 3)]
})
evaluation_results

Unnamed: 0,Baseline,Avg Precision,Avg Recall,Avg F1 Score,Silhouette Score
0,Baseline 1,0.002,0.003,0.002,-0.016
1,Baseline 2,0.02,0.004,0.006,-0.012


In [30]:
evaluation_results.T

Unnamed: 0,0,1
Baseline,Baseline 1,Baseline 2
Avg Precision,0.002,0.02
Avg Recall,0.003,0.004
Avg F1 Score,0.002,0.006
Silhouette Score,-0.016,-0.012


In [31]:
test_df.drop(columns = ['Baseline_1_Entities','Baseline_2_Entities'], inplace = True)
test_df.drop(columns = ['Baseline_1_Classes','Baseline_2_Classes'], inplace = True)

In [32]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer, TrainingArguments
from sklearn.preprocessing import LabelEncoder

In [33]:
len(train_df['Entities'].explode().unique())

1691

In [34]:
train_df.head()

Unnamed: 0,Topic,Abstract,Text,Entities,Classes,Text_Tokens,Text_Lemmas,SpecificDiseaseEntities,ner_tags
0,A common human skin tumour is caused by activa...,WNT signalling orchestrates a number of develo...,A common human skin tumour is caused by activa...,"[skin tumour, cancer, colon cancers, adenomato...","[DiseaseClass, DiseaseClass, DiseaseClass, Spe...","[common, human, skin, tumour, caused, activati...","[common, human, skin, tumour, cause, activate,...","[adenomatous polyposis coli, APC, pilomatricom...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,HFE mutations analysis in 711 hemochromatosis ...,Hereditary hemochromatosis (HH) is a common au...,HFE mutations analysis in 711 hemochromatosis ...,"[hemochromatosis, hemochromatosis, Hereditary ...","[Modifier, SpecificDisease, SpecificDisease, S...","[HFE, mutations, analysis, hemochromatosis, pr...","[HFE, mutation, analysis, hemochromatosis, pro...","[hemochromatosis, Hereditary hemochromatosis, ...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, ..."
2,Germline BRCA1 alterations in a population-bas...,The objective of this study was to provide mor...,Germline BRCA1 alterations in a population-bas...,"[ovarian cancer, breast cancer, ovarian cancer...","[Modifier, Modifier, Modifier, Modifier, Modif...","[Germline, alterations, population, based, ser...","[Germline, alteration, population, base, serie...",[ovarian cancer],"[0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,"Identification of APC2, a homologue of the ade...",The adenomatous polyposis coli (APC) tumour-su...,"Identification of APC2, a homologue of the ade...","[adenomatous polyposis coli tumour, adenomatou...","[Modifier, Modifier, Modifier, Modifier, Speci...","[Identification, homologue, adenomatous, polyp...","[identification, homologue, adenomatous, polyp...",[cancer],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,Familial deficiency of the seventh component o...,The serum of a 29-year old woman with a recent...,Familial deficiency of the seventh component o...,[Familial deficiency of the seventh component ...,"[SpecificDisease, DiseaseClass, SpecificDiseas...","[Familial, deficiency, seventh, component, com...","[familial, deficiency, seventh, component, com...",[Familial deficiency of the seventh component ...,"[1, 2, 2, 2, 2, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, ..."


In [35]:
from datasets import Dataset, DatasetDict

# For training data
# train_ner = train_df.rename(columns={'Text_Tokens': 'tokens', 'ner_tags': 'labels'})[['tokens', 'labels']]
train_ner = train_df.rename(columns={'Text_Tokens': 'tokens', 'ner_tags': 'labels'})
train_ner['id'] = train_ner.index.tolist()  # Adding 'id' column
train_ner = train_ner[['id', 'tokens', 'labels']]

# For development data
# dev_ner = dev_df.rename(columns={'Text_Tokens': 'tokens', 'ner_tags': 'labels'})[['tokens', 'labels']]
dev_ner = dev_df.rename(columns={'Text_Tokens': 'tokens', 'ner_tags': 'labels'})
dev_ner['id'] = dev_ner.index.tolist()  # Adding 'id' column
dev_ner = dev_ner[['id', 'tokens', 'labels']]

# For testing data
# test_ner = test_df.rename(columns={'Text_Tokens': 'tokens', 'ner_tags': 'labels'})[['tokens', 'labels']]
test_ner = test_df.rename(columns={'Text_Tokens': 'tokens', 'ner_tags': 'labels'})
test_ner['id'] = test_ner.index.tolist()  # Adding 'id' column
test_ner = test_ner[['id', 'tokens', 'labels']]

train_ner = Dataset.from_pandas(train_ner)
dev_ner = Dataset.from_pandas(dev_ner)
test_ner = Dataset.from_pandas(test_ner)

dataset = DatasetDict({"train": train_ner, "validation": dev_ner, "test": test_ner})
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'labels'],
        num_rows: 593
    })
    validation: Dataset({
        features: ['id', 'tokens', 'labels'],
        num_rows: 100
    })
    test: Dataset({
        features: ['id', 'tokens', 'labels'],
        num_rows: 100
    })
})

In [36]:
dataset['train']['labels'][0]

[0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 2,
 2,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

In [37]:
label_list=["O","B","I"]

id2label = {
    0: "O",
    1: "B",
    2: "I",
}
label2id = {
    "O": 0,
    "B": 1,
    "I": 2,
}

# NCBI Disease

In [38]:
# # Load the NCBI Disease Corpus dataset
# dataset = load_dataset("ncbi_disease")
# # Renaming 'ner_tags' to 'label' in the train dataset
# dataset['train'] = dataset['train'].rename_column("ner_tags", "labels")

# # Renaming 'ner_tags' to 'label' in the validation dataset
# dataset['validation'] = dataset['validation'].rename_column("ner_tags", "labels")

# # Renaming 'ner_tags' to 'label' in the test dataset
# dataset['test'] = dataset['test'].rename_column("ner_tags", "labels")



# dataset

In [39]:
dataset['train']['tokens'][0]

['common',
 'human',
 'skin',
 'tumour',
 'caused',
 'activating',
 'mutations',
 'beta',
 'catenin',
 'WNT',
 'signalling',
 'orchestrates',
 'number',
 'developmental',
 'programs',
 'response',
 'stimulus',
 'cytoplasmic',
 'beta',
 'catenin',
 'encoded',
 'stabilized',
 'enabling',
 'downstream',
 'transcriptional',
 'activation',
 'members',
 'LEF',
 'TCF',
 'family',
 'target',
 'genes',
 'beta',
 'catenin',
 'TCF',
 'encodes',
 'c',
 'MYC',
 'explaining',
 'constitutive',
 'activation',
 'WNT',
 'pathway',
 'lead',
 'cancer',
 'particularly',
 'colon',
 'colon',
 'cancers',
 'arise',
 'mutations',
 'gene',
 'encoding',
 'adenomatous',
 'polyposis',
 'coli',
 'APC',
 'protein',
 'required',
 'ubiquitin',
 'mediated',
 'degradation',
 'beta',
 'catenin',
 'small',
 'percentage',
 'colon',
 'cancers',
 'harbour',
 'beta',
 'catenin',
 'stabilizing',
 'mutations',
 'Recently',
 'discovered',
 'transgenic',
 'mice',
 'expressing',
 'activated',
 'beta',
 'catenin',
 'predisposed',
 '

# BERT Model

In [40]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/electra-large-discriminator")

# def tokenize_function(examples):
#     examples["tokens"] = [" ".join(tokens) for tokens in examples["tokens"]]
#     return tokenizer(examples["tokens"], padding="max_length", truncation=True)

# tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [41]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples["labels"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:  # Set the special tokens to -100.
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

tokenized_datasets =dataset.map(tokenize_and_align_labels, batched=True)

Map:   0%|          | 0/593 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [42]:
example = tokenized_datasets['train'][0]
example

{'id': 0,
 'tokens': ['common',
  'human',
  'skin',
  'tumour',
  'caused',
  'activating',
  'mutations',
  'beta',
  'catenin',
  'WNT',
  'signalling',
  'orchestrates',
  'number',
  'developmental',
  'programs',
  'response',
  'stimulus',
  'cytoplasmic',
  'beta',
  'catenin',
  'encoded',
  'stabilized',
  'enabling',
  'downstream',
  'transcriptional',
  'activation',
  'members',
  'LEF',
  'TCF',
  'family',
  'target',
  'genes',
  'beta',
  'catenin',
  'TCF',
  'encodes',
  'c',
  'MYC',
  'explaining',
  'constitutive',
  'activation',
  'WNT',
  'pathway',
  'lead',
  'cancer',
  'particularly',
  'colon',
  'colon',
  'cancers',
  'arise',
  'mutations',
  'gene',
  'encoding',
  'adenomatous',
  'polyposis',
  'coli',
  'APC',
  'protein',
  'required',
  'ubiquitin',
  'mediated',
  'degradation',
  'beta',
  'catenin',
  'small',
  'percentage',
  'colon',
  'cancers',
  'harbour',
  'beta',
  'catenin',
  'stabilizing',
  'mutations',
  'Recently',
  'discovered

In [43]:
len(tokenized_datasets['train'][0]['labels'])

263

In [44]:
len(tokenized_datasets['train'][0]['input_ids'])

263

In [45]:
seqeval = evaluate.load("seqeval")

In [46]:
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

id2label = {
    0: "O",
    1: "B",
    2: "I",

}
label2id = {
    "O": 0,
    "B": 1,
    "I": 2,

}

In [47]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

model = AutoModelForTokenClassification.from_pretrained(
    "google/electra-large-discriminator", num_labels=3, id2label=id2label, label2id=label2id)

Some weights of ElectraForTokenClassification were not initialized from the model checkpoint at google/electra-large-discriminator and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [48]:
# check if gpu is available
device = 'cpu' 
if torch.backends.mps.is_available():
    device = 'mps'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Using '{device}' device")

model.to(device)

Using 'cuda' device


ElectraForTokenClassification(
  (electra): ElectraModel(
    (embeddings): ElectraEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ElectraEncoder(
      (layer): ModuleList(
        (0-23): 24 x ElectraLayer(
          (attention): ElectraAttention(
            (self): ElectraSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ElectraSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (Laye

In [49]:
training_args = TrainingArguments(
    output_dir="./",
    overwrite_output_dir=True,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=1,
    num_train_epochs=10,
    do_eval=True,
    seed=12345,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_f1",
    greater_is_better=True,
    report_to="wandb", 
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [50]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [51]:
# Set up trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [52]:
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[34m[1mwandb[0m: Currently logged in as: [33mmichaelbrown[0m ([33msi630_hw[0m). Use [1m`wandb login --relogin`[0m to force relogin
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,No log,0.167544,0.465498,0.608883,0.527623,0.930882
2,No log,0.140671,0.538122,0.697708,0.607611,0.942047
3,No log,0.115909,0.630807,0.739255,0.680739,0.958175
4,No log,0.116181,0.635236,0.733524,0.680851,0.957111
5,No log,0.102418,0.691218,0.69914,0.695157,0.965352
6,No log,0.10266,0.725146,0.710602,0.7178,0.968011
7,No log,0.11627,0.683155,0.732092,0.706777,0.964112
8,No log,0.115089,0.74058,0.732092,0.736311,0.968631
9,No log,0.116212,0.741557,0.723496,0.732415,0.969074
10,No log,0.113088,0.744838,0.723496,0.734012,0.969694


TrainOutput(global_step=380, training_loss=0.1180232098228053, metrics={'train_runtime': 431.8838, 'train_samples_per_second': 13.731, 'train_steps_per_second': 0.88, 'total_flos': 2951756444206140.0, 'train_loss': 0.1180232098228053, 'epoch': 10.0})

In [53]:
trainer.evaluate()

{'eval_loss': 0.1150888130068779,
 'eval_precision': 0.7405797101449275,
 'eval_recall': 0.7320916905444126,
 'eval_f1': 0.7363112391930837,
 'eval_accuracy': 0.9686309260079752,
 'eval_runtime': 1.8288,
 'eval_samples_per_second': 54.682,
 'eval_steps_per_second': 54.682,
 'epoch': 10.0}

In [54]:
def get_numerical_predictions(predictions):
    preds = predictions.predictions
    preds = np.argmax(preds, axis=2)
    labels = predictions.label_ids

    true_predictions = [
        [label_list[p] for p, l in zip(prediction, label) if l != -100]
        for prediction, label in zip(preds, labels)
    ]

    numerical_predictions = [
        [label2id[label] for label in sentence_labels]
        for sentence_labels in true_predictions
    ]
    
    return numerical_predictions

In [55]:
train_predictions = trainer.predict(tokenized_datasets["train"])
train_numerical_predictions = get_numerical_predictions(train_predictions)

train_df = dataset['train'].to_pandas()

train_df['predicted_labels'] = train_numerical_predictions

train_df.head()

Unnamed: 0,id,tokens,labels,predicted_labels
0,0,"[common, human, skin, tumour, caused, activati...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,1,"[HFE, mutations, analysis, hemochromatosis, pr...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, ...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 2, 1, 0, 0, ..."
2,2,"[Germline, alterations, population, based, ser...","[0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,3,"[Identification, homologue, adenomatous, polyp...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, ..."
4,4,"[Familial, deficiency, seventh, component, com...","[1, 2, 2, 2, 2, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, ...","[1, 2, 2, 2, 2, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, ..."


In [56]:
def extract_entities(tokens, labels):
    entities = []
    current_entity = []
    for token, label in zip(tokens, labels):
        if label == 1:  # B-label
            if current_entity:
                entities.append(" ".join(current_entity))
                current_entity = []
            current_entity.append(token)
        elif label == 2:  # I-label
            if current_entity:
                current_entity.append(token)
            else:
                current_entity.append(token)
        else:  # O-label
            if current_entity:
                entities.append(" ".join(current_entity))
                current_entity = []
    if current_entity:
        entities.append(" ".join(current_entity))
    return entities

train_df['predicted_entities'] = train_df.apply(lambda row: extract_entities(row['tokens'], row['predicted_labels']), axis=1)
train_df.head()

Unnamed: 0,id,tokens,labels,predicted_labels,predicted_entities
0,0,"[common, human, skin, tumour, caused, activati...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[adenomatous polyposis coli, pilomatricomas, p..."
1,1,"[HFE, mutations, analysis, hemochromatosis, pr...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, ...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 2, 1, 0, 0, ...","[hemochromatosis, hemochromatosis, Hereditary ..."
2,2,"[Germline, alterations, population, based, ser...","[0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, ...","[ovarian cancer, cancer, ovarian cancer, ovari..."
3,3,"[Identification, homologue, adenomatous, polyp...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, ...","[polyposis, polyposis]"
4,4,"[Familial, deficiency, seventh, component, com...","[1, 2, 2, 2, 2, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, ...","[1, 2, 2, 2, 2, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, ...",[Familial deficiency seventh component complem...


In [57]:
val_predictions = trainer.predict(tokenized_datasets["validation"])
val_numerical_predictions = get_numerical_predictions(val_predictions)

val_df = dataset['validation'].to_pandas()

val_df['predicted_labels'] = val_numerical_predictions

val_df['predicted_entities'] = val_df.apply(lambda row: extract_entities(row['tokens'], row['predicted_labels']), axis=1)

val_df.head()

Unnamed: 0,id,tokens,labels,predicted_labels,predicted_entities
0,0,"[Somatic, cell, selection, major, determinant,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 2, ...","[blood, dehydrogenase, enzyme deficiency, bloo..."
1,1,"[ataxia, telangiectasia, gene, product, consti...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[]
2,2,"[Molecular, basis, Duarte, Los, Angeles, varia...","[0, 0, 1, 0, 0, 2, 2, 0, 0, 0, 1, 2, 2, 2, 0, ...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, ...","[galactosemia, galactosemia, galactosemia]"
3,3,"[intronic, mutation, lariat, branchpoint, sequ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, ...","[disease, disease, FED, FED, FED]"
4,4,"[Genetic, heterogeneity, hereditary, breast, c...","[0, 0, 1, 2, 2, 0, 0, 1, 0, 2, 2, 0, 0, 0, 0, ...","[0, 0, 1, 2, 2, 0, 0, 1, 0, 2, 2, 0, 0, 0, 0, ...","[hereditary breast cancer, hereditary, breast ..."


In [58]:
test_predictions = trainer.predict(tokenized_datasets["test"])
test_numerical_predictions = get_numerical_predictions(test_predictions)

test_df = dataset['test'].to_pandas()

test_df['predicted_labels'] = test_numerical_predictions

test_df['predicted_entities'] = test_df.apply(lambda row: extract_entities(row['tokens'], row['predicted_labels']), axis=1)

test_df.head()

Unnamed: 0,id,tokens,labels,predicted_labels,predicted_entities
0,0,"[Genetic, mapping, copper, toxicosis, locus, B...","[0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[copper toxicosis, copper, copper, Wilson dise..."
1,1,"[Molecular, analysis, APC, gene, families, ext...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...","[FAP, colorectal cancer, colorectal cancer, fa..."
2,2,"[European, multicenter, study, phenylalanine, ...","[0, 0, 0, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[phenylalanine hydroxylase deficiency, Phenylk..."
3,3,"[Disruption, splicing, regulated, CUG, binding...","[0, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 0, 0, 0, 0, ...","[myotonic dystrophy, Myotonic dystrophy, DM, D..."
4,4,"[Maternal, disomy, Prader, Willi, syndrome, co...","[1, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, ...","[1, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 2, 2, 1, ...","[Maternal disomy, syndrome, Maternal uniparent..."


In [59]:
# Extract metrics from test_predictions
precision = test_predictions[2]['test_precision']
recall = test_predictions[2]['test_recall']
f1 = test_predictions[2]['test_f1']
accuracy = test_predictions[2]['test_accuracy']

# Print rounded metrics
print("Test Precision:", round(precision, 2))
print("Test Recall:", round(recall, 2))
print("Test F1 Score:", round(f1, 2))
print("Test Accuracy:", round(accuracy, 2))


Test Precision: 0.68
Test Recall: 0.72
Test F1 Score: 0.7
Test Accuracy: 0.95


In [60]:
# Define the input text
input_text = "The patient was diagnosed with stage IV lung adenocarcinoma. Mutation analysis revealed a mutation in the EGFR gene."

# Tokenize the input text
inputs = tokenizer(input_text, return_tensors="pt")

# Move input tensors to the same device as the model
inputs = {key: tensor.to(model.device) for key, tensor in inputs.items()}

# Move the model to the same device as the input tensors
model.to(inputs["input_ids"].device)

# Make predictions
with torch.no_grad():
    outputs = model(**inputs)

# Get the predicted labels
predictions = torch.argmax(outputs.logits, dim=2)

# Map label IDs to labels
label_map = {0: "O", 1: "B", 2: "I"}
predicted_labels = [label_map[label_id] for label_id in predictions[0].tolist()]

# Print tokenized input text along with predicted labels
for token, label in zip(tokenizer.convert_ids_to_tokens(inputs["input_ids"].tolist()[0]), predicted_labels):
    print(token, label)

[CLS] O
the O
patient O
was O
diagnosed O
with O
stage O
iv O
lung O
aden I
##oca I
##rc I
##ino I
##ma I
. O
mutation O
analysis O
revealed O
a O
mutation O
in O
the O
e O
##gf O
##r O
gene O
. O
[SEP] O


In [61]:
train_df.to_csv("630_train_predictions.csv", index=False)
val_df.to_csv("630_val_predictions.csv", index=False)
test_df.to_csv("630_test_predictions.csv", index=False)