
# Fine-Tune BioBERT on PubMed Abstracts

This notebook fine-tunes BioBERT (`dmis-lab/biobert-base-cased-v1.1`) on PubMed abstracts with a Masked Language Modeling (MLM) objective. The goal is a model better suited to biomedical NLP tasks such as spellchecking medical text (e.g., correcting 'arbitysratsddion' to 'arteries'). Abstracts are downloaded from NCBI's PubMed baseline FTP server, preprocessed, and used for training in Google Colab on a GPU runtime. The fine-tuned model is saved to Google Drive so it can be used locally in a spellchecker project.

**Setup**: Google Colab (GPU), PyTorch, NCBI PubMed abstracts.
**Output**: Fine-tuned BioBERT model and tokenizer saved to `/content/drive/MyDrive/biobert-finetuned`.
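To make the MLM objective concrete, here is a minimal pure-Python sketch of BERT-style masking on word-level tokens. It is an illustration only: the function name `mask_tokens` and the toy vocabulary are assumptions, and the real masking in this notebook is done on WordPiece ids by `DataCollatorForLanguageModeling`. The 80/10/10 split (mask, random token, unchanged) is the standard BERT recipe.

```python
import random

def mask_tokens(tokens, mask_token="[MASK]", vocab=None, mlm_probability=0.15, rng=None):
    """Return (masked_tokens, labels) using a BERT-style 80/10/10 masking rule."""
    rng = rng or random.Random()
    vocab = vocab or tokens  # toy vocabulary (assumption): reuse the input tokens
    masked, labels = [], []
    for tok in tokens:
        if rng.random() < mlm_probability:
            labels.append(tok)  # the model must predict the original token here
            roll = rng.random()
            if roll < 0.8:
                masked.append(mask_token)         # 80%: replace with [MASK]
            elif roll < 0.9:
                masked.append(rng.choice(vocab))  # 10%: replace with a random token
            else:
                masked.append(tok)                # 10%: keep the token unchanged
        else:
            labels.append(None)  # position is ignored by the loss
            masked.append(tok)
    return masked, labels

tokens = "the proteolytic activity of pepsin is reduced by bisabolol".split()
masked, labels = mask_tokens(tokens, rng=random.Random(0))
print(masked)
print(labels)
```

Only positions with a non-`None` label contribute to the training loss, which is why the loss values reported during fine-tuning describe masked positions rather than whole abstracts.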

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip uninstall -y numpy torch transformers datasets sentencepiece lxml nltk torchaudio torchvision sentence-transformers accelerate

Found existing installation: numpy 1.23.5
Uninstalling numpy-1.23.5:
  Successfully uninstalled numpy-1.23.5
Found existing installation: torch 2.0.0
Uninstalling torch-2.0.0:
  Successfully uninstalled torch-2.0.0
Found existing installation: transformers 4.41.0
Uninstalling transformers-4.41.0:
  Successfully uninstalled transformers-4.41.0
Found existing installation: datasets 2.12.0
Uninstalling datasets-2.12.0:
  Successfully uninstalled datasets-2.12.0
Found existing installation: sentencepiece 0.1.97
Uninstalling sentencepiece-0.1.97:
  Successfully uninstalled sentencepiece-0.1.97
Found existing installation: lxml 4.9.2
Uninstalling lxml-4.9.2:
  Successfully uninstalled lxml-4.9.2
Found existing installation: nltk 3.9.1
Uninstalling nltk-3.9.1:
  Successfully uninstalled nltk-3.9.1
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Unins

In [None]:
!pip cache purge

Files removed: 143


In [None]:
!pip install numpy==1.22.4

Collecting numpy==1.22.4
  Downloading numpy-1.22.4.zip (11.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.5/11.5 MB[0m [31m107.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: numpy
  Building wheel for numpy (pyproject.toml) ... [?25l[?25hdone
  Created wheel for numpy: filename=numpy-1.22.4-cp311-cp311-linux_x86_64.whl size=17329059 sha256=5ef7159f59a51c6fdfe99a4d5b966071fb3d4b47af8dcd6af0ca933b3d4e7908
  Stored in directory: /root/.cache/pip/wheels/8e/c0/7e/1583fa989ccf57e2059824c8783691f4927f2ce7b77cec9da2
Successfully built numpy
Installing collected packages: numpy
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
peft 0.15.2 

In [None]:
!pip install torch==2.0.0
!pip install transformers==4.41.0
!pip install datasets==2.12.0
!pip install sentencepiece==0.1.97
!pip install lxml==4.9.2

Collecting torch==2.0.0
  Downloading torch-2.0.0-cp311-cp311-manylinux1_x86_64.whl.metadata (24 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.7.99 (from torch==2.0.0)
  Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu11==11.7.99 (from torch==2.0.0)
  Downloading nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cuda-cupti-cu11==11.7.101 (from torch==2.0.0)
  Downloading nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu11==8.5.0.96 (from torch==2.0.0)
  Downloading nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu11==11.10.3.66 (from torch==2.0.0)
  Downloading nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cufft-cu11==10.9.0.58 (from torch==2.0.0)
  Downloading nvidia_cufft_cu11-10.9.0.58-py3-none-man

In [None]:
!pip install numpy==1.23.5

Collecting numpy==1.23.5
  Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m82.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.23.5 which is incompatible.
torchvision 0.21.0+cu124 requires torch==2.6.0, but you have torch 2.0.0 which is incompatible.
jax 0.5.2 requires numpy>=1.25, but you have numpy 1.23.5 which is incompatible.
treescope 0.1.9 req

In [1]:
from google.colab import drive
drive.mount('/content/drive')

# Set output directory
output_dir = '/content/drive/MyDrive/biobert_finetuned'
import os
os.makedirs(output_dir, exist_ok=True)

Mounted at /content/drive


In [2]:
import urllib.request
import gzip
import xml.etree.ElementTree as ET
import re
from nltk.tokenize import TreebankWordTokenizer

# Download a single PubMed baseline file
url = 'https://ftp.ncbi.nlm.nih.gov/pubmed/baseline/pubmed25n0001.xml.gz'
local_file = '/content/pubmed25n0001.xml.gz'

try:
    print(f"Downloading {url}...")
    urllib.request.urlretrieve(url, local_file)
except Exception as e:
    print(f"Error downloading file: {e}")
    raise

# Extract and preprocess abstracts
def extract_abstracts(xml_gz_file):
    abstracts = []
    tokenizer = TreebankWordTokenizer()
    with gzip.open(xml_gz_file, 'rt', encoding='utf-8') as f:
        tree = ET.parse(f)
        root = tree.getroot()
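        for article in root.findall('.//Article'):
            # Note: find() returns only the first <AbstractText>. Later sections of a
            # structured abstract, and text following inline tags such as <i>, are not captured.
            abstract = article.find('.//Abstract/AbstractText')
            if abstract is not None and abstract.text:
                # Text is lowercased here (see the printed sample in the output),
                # although the BioBERT checkpoint used later is a cased model.
                text = abstract.text.lower()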
                text = re.sub(r'[\r\n]+', ' ', text)
                text = re.sub(r'[^\x00-\x7F]+', ' ', text)
                tokenized = tokenizer.tokenize(text)
                text = ' '.join(tokenized)
                text = re.sub(r"\s's\b", "'s", text)
                abstracts.append(text)
    return abstracts

abstracts = extract_abstracts(local_file)
print(f"Extracted {len(abstracts)} abstracts")

# Save abstracts for inspection
with open('/content/pubmed_abstracts.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(abstracts))

# Create a dataset
from datasets import Dataset
dataset = Dataset.from_dict({'text': abstracts})
print(dataset)
print(dataset[0]['text'] if len(dataset) > 0 else 'No abstracts')

Downloading https://ftp.ncbi.nlm.nih.gov/pubmed/baseline/pubmed25n0001.xml.gz...
Extracted 15377 abstracts
Dataset({
    features: ['text'],
    num_rows: 15377
})
( -- ) -alpha-bisabolol has a primary antipeptic action depending on dosage , which is not caused by an alteration of the ph-value. the proteolytic activity of pepsin is reduced by 50 percent through addition of bisabolol in the ratio of 1/0.5. the antipeptic action of bisabolol only occurs in case of direct contact. in case of a previous contact with the substrate , the inhibiting effect is lost .
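The extraction above reads only the leading text of the first `AbstractText` element per article. As a hedged alternative, the sketch below joins every `AbstractText` section and includes text nested inside inline tags by using `itertext()`, and it streams the file with `iterparse` instead of loading the whole tree. The function name `iter_abstracts` and the tiny in-memory sample are assumptions for illustration; it has not been run against the real baseline file, and the abstract count may differ from the 15377 reported above.

```python
import gzip
import io
import xml.etree.ElementTree as ET

def iter_abstracts(xml_gz_file):
    """Yield one abstract string per <Article>, joining all <AbstractText> sections."""
    with gzip.open(xml_gz_file, "rb") as f:
        for _event, elem in ET.iterparse(f, events=("end",)):
            if elem.tag != "Article":
                continue
            parts = []
            for section in elem.iterfind("./Abstract/AbstractText"):
                # itertext() also returns text inside and after inline child tags
                text = "".join(section.itertext()).strip()
                if text:
                    parts.append(text)
            if parts:
                yield " ".join(parts)
            elem.clear()  # release the article's children once processed

# Tiny in-memory sample (hypothetical) standing in for a PubMed baseline file
sample = b"""<PubmedArticleSet><PubmedArticle><MedlineCitation><Article>
<Abstract><AbstractText Label="BACKGROUND">Effect of <i>E. coli</i> on growth.</AbstractText>
<AbstractText Label="RESULTS">Growth slowed.</AbstractText></Abstract>
</Article></MedlineCitation></PubmedArticle>
<PubmedArticle><MedlineCitation><Article><ArticleTitle>No abstract</ArticleTitle></Article></MedlineCitation></PubmedArticle>
</PubmedArticleSet>"""
demo = list(iter_abstracts(io.BytesIO(gzip.compress(sample))))
print(demo)
```

`gzip.open` accepts either a path or a file object, so the same function should work with `/content/pubmed25n0001.xml.gz`.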


In [3]:
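import random

def add_misspellings(text):
    """Double one random character in about 10% of the words longer than two characters."""
    words = text.split()
    for i in range(len(words)):
        if random.random() < 0.1:
            word = words[i]
            if len(word) > 2:
                # Duplicate the character at a random position, e.g. 'pepsin' -> 'peppsin'
                pos = random.randint(0, len(word)-1)
                words[i] = word[:pos] + word[pos]*2 + word[pos+1:]
    return ' '.join(words)

# Synthetic noisy copies of the abstracts; these are what the model is trained on
misspelled_abstracts = [add_misspellings(abstract) for abstract in abstracts]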
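The cell above introduces only one kind of error (a doubled character) and is not seeded, so each run produces different noise. A possible variation is sketched below, with a fixed seed and four edit types (duplicate, delete, swap, substitute). The names `corrupt_word` and `add_misspellings_seeded` are hypothetical, and whether this noise model improves the spellchecker has not been evaluated here.

```python
import random
import string

def corrupt_word(word, rng):
    """Apply one random character-level edit: duplicate, delete, swap, or substitute."""
    pos = rng.randrange(len(word))
    op = rng.choice(["duplicate", "delete", "swap", "substitute"])
    if op == "duplicate":
        return word[:pos] + word[pos] + word[pos:]
    if op == "delete":
        return word[:pos] + word[pos + 1:]
    if op == "swap" and pos < len(word) - 1:
        return word[:pos] + word[pos + 1] + word[pos] + word[pos + 2:]
    # substitute (also the fallback when a swap lands on the last character)
    return word[:pos] + rng.choice(string.ascii_lowercase) + word[pos + 1:]

def add_misspellings_seeded(text, rate=0.1, seed=0):
    """Corrupt alphabetic words longer than two characters with probability `rate`."""
    rng = random.Random(seed)  # a private generator keeps runs reproducible
    words = text.split()
    for i, word in enumerate(words):
        if len(word) > 2 and word.isalpha() and rng.random() < rate:
            words[i] = corrupt_word(word, rng)
    return " ".join(words)

sentence = "the antipeptic action of bisabolol only occurs in case of direct contact"
print(add_misspellings_seeded(sentence, rate=0.5, seed=1))
```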

In [4]:
from transformers import BertTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling
from datasets import Dataset
import nltk
nltk.download('punkt')

tokenizer = BertTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1')
model = BertForMaskedLM.from_pretrained('dmis-lab/biobert-base-cased-v1.1')

# Create dataset
dataset = Dataset.from_dict({'text': misspelled_abstracts})

# Tokenize without masking (we'll use DataCollator for masking)
def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=128, return_special_tokens_mask=True)

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=['text'])

# Data collator for dynamic masking
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15  # Mask 15% of tokens
)

# Verify dataset
print(tokenized_dataset[0].keys())  # Should include input_ids, token_type_ids, special_tokens_mask, attention_mask

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Map:   0%|          | 0/15377 [00:00<?, ? examples/s]

dict_keys(['input_ids', 'token_type_ids', 'special_tokens_mask', 'attention_mask'])
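With `truncation=True` and `max_length=128`, everything past the first 128 tokens of an abstract is discarded. A common alternative in MLM training is to concatenate the tokenized texts and cut them into fixed-size blocks. The sketch below shows the idea on plain lists of integers. `chunk_token_ids` is a hypothetical helper, and it assumes the texts were tokenized without padding.

```python
def chunk_token_ids(token_id_lists, block_size=128):
    """Concatenate tokenized documents and cut them into equal-sized blocks."""
    flat = [tid for ids in token_id_lists for tid in ids]
    usable = (len(flat) // block_size) * block_size  # drop the ragged tail
    return [flat[i:i + block_size] for i in range(0, usable, block_size)]

# Three toy "documents" of 5, 7 and 2 token ids, cut into blocks of 4
docs = [list(range(5)), list(range(5, 12)), list(range(12, 14))]
blocks = chunk_token_ids(docs, block_size=4)
print(blocks)
```

Blocks produced this way can span document boundaries, which is usually acceptable for MLM but is a design choice worth checking against the downstream spellchecking task.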


In [5]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./biobert-finetuned',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    save_steps=10_000,
    save_total_limit=2,
    logging_dir='./logs',
    logging_steps=500,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

trainer.train()

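# Save model and tokenizer. Note: this path uses a hyphen ('biobert-finetuned')
# and is not the same folder as output_dir ('biobert_finetuned') defined earlier.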
model.save_pretrained('/content/drive/MyDrive/biobert-finetuned')
tokenizer.save_pretrained('/content/drive/MyDrive/biobert-finetuned')



wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: 2023ad05044 (2023ad05044-bits-pilani) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Step,Training Loss
500,3.9026
1000,2.3597
1500,2.1077
2000,1.943
2500,1.8033
3000,1.7494
3500,1.697
4000,1.6648
4500,1.5822
5000,1.5436


('/content/drive/MyDrive/biobert-finetuned/tokenizer_config.json',
 '/content/drive/MyDrive/biobert-finetuned/special_tokens_map.json',
 '/content/drive/MyDrive/biobert-finetuned/vocab.txt',
 '/content/drive/MyDrive/biobert-finetuned/added_tokens.json')
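To put the training log in context, the arithmetic below estimates the number of optimizer steps and converts a logged loss into a pseudo-perplexity. It assumes a single GPU, no gradient accumulation, and that the last partial batch is kept. The helper name `pseudo_perplexity` is hypothetical. Because the MLM loss is a mean cross-entropy over masked tokens of the training data, `exp(loss)` is only a rough indicator and not a held-out evaluation.

```python
import math

num_examples = 15377  # from the "Extracted 15377 abstracts" output
batch_size = 8        # per_device_train_batch_size
num_epochs = 3        # num_train_epochs

# Assumption: one device, gradient_accumulation_steps=1, last partial batch kept
steps_per_epoch = math.ceil(num_examples / batch_size)  # 1923
total_steps = steps_per_epoch * num_epochs              # 5769
print(steps_per_epoch, total_steps)

def pseudo_perplexity(loss):
    """exp of the mean masked-token cross-entropy loss."""
    return math.exp(loss)

# First and last loss values from the training log above
for step, loss in [(500, 3.9026), (5000, 1.5436)]:
    print(step, round(pseudo_perplexity(loss), 2))
```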

In [None]:
"""from transformers import BertTokenizer, DataCollatorForLanguageModeling

# Load BioBERT tokenizer
tokenizer = BertTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1')

# Preprocess function
def preprocess_function(examples):
    encodings = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=128, return_tensors='pt')
    return encodings

# Apply preprocessing
encoded_dataset = dataset.map(preprocess_function, batched=True, remove_columns=['text'])

# Data collator for MLM
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)"""

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

Map:   0%|          | 0/15377 [00:00<?, ? examples/s]

In [None]:
"""from transformers import BertForMaskedLM, Trainer, TrainingArguments

# Load BioBERT model
model = BertForMaskedLM.from_pretrained('dmis-lab/biobert-base-cased-v1.1')

# Training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir=f'{output_dir}/logs',
    logging_steps=100,
    save_steps=1000,
    save_total_limit=2
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset,
    data_collator=data_collator,
)

# Train
trainer.train()

# Save model and tokenizer
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f'Model saved to {output_dir}')"""



wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: 2023ad05044 (2023ad05044-bits-pilani) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Step,Training Loss
100,9.4438
200,6.0931
300,4.6039
400,3.6118
500,2.9774
600,2.5045
700,2.2584
800,2.0629
900,1.9862
1000,1.876



Model saved to /content/drive/MyDrive/biobert_finetuned


In [6]:
from transformers import pipeline

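# Test the fine-tuned model.
# Cell 5 saved to '.../biobert-finetuned' (hyphen), while output_dir points to
# '.../biobert_finetuned' (underscore), which contains no weights. Loading from
# output_dir raises the "no file named pytorch_model.bin" error recorded below.
model_dir = '/content/drive/MyDrive/biobert-finetuned'
fill_mask = pipeline('fill-mask', model=model_dir, tokenizer=model_dir)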
test_sentence = 'Hypertension is a [MASK] condition.'
results = fill_mask(test_sentence)
for result in results:
    print(f"Token: {result['token_str']}, Score: {result['score']:.4f}")

# Test spellchecker-relevant sentence
test_spell = 'The patient has [MASK] blood pressure.'
results_spell = fill_mask(test_spell)
for result in results_spell:
    print(f"Token: {result['token_str']}, Score: {result['score']:.4f}")

ValueError: Could not load model /content/drive/MyDrive/biobert_finetuned with any of the following classes: (<class 'transformers.models.auto.modeling_auto.AutoModelForMaskedLM'>, <class 'transformers.models.auto.modeling_tf_auto.TFAutoModelForMaskedLM'>). See the original errors:

while loading with AutoModelForMaskedLM, an error is thrown:
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/transformers/pipelines/base.py", line 292, in infer_framework_load_model
    model = model_class.from_pretrained(model, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py", line 571, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py", line 309, in _wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py", line 4422, in from_pretrained
    checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files(
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py", line 976, in _get_resolved_checkpoint_files
    raise EnvironmentError(
OSError: Error no file named pytorch_model.bin, model.safetensors, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory /content/drive/MyDrive/biobert_finetuned.

while loading with TFAutoModelForMaskedLM, an error is thrown:
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/transformers/pipelines/base.py", line 292, in infer_framework_load_model
    model = model_class.from_pretrained(model, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py", line 571, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/modeling_tf_utils.py", line 2777, in from_pretrained
    raise EnvironmentError(
OSError: Error no file named tf_model.h5, model.safetensors or pytorch_model.bin found in directory /content/drive/MyDrive/biobert_finetuned.
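The `OSError` above lists the weight filenames that `from_pretrained` looks for in a local directory. A quick check for those files before building the pipeline can turn a long traceback into a one-line message. The sketch below uses only the standard library. `find_weight_files` is a hypothetical helper, and the filename list is copied from the error message, not from the library source.

```python
import tempfile
from pathlib import Path

# Filenames taken from the OSError message above
WEIGHT_FILES = (
    "model.safetensors",
    "pytorch_model.bin",
    "tf_model.h5",
    "model.ckpt.index",
    "flax_model.msgpack",
)

def find_weight_files(model_dir):
    """Return the known weight files present in model_dir (empty list if none)."""
    directory = Path(model_dir)
    if not directory.is_dir():
        return []
    return sorted(name for name in WEIGHT_FILES if (directory / name).is_file())

# Demo on a temporary directory standing in for the Drive folder
with tempfile.TemporaryDirectory() as tmp:
    before = find_weight_files(tmp)
    (Path(tmp) / "model.safetensors").write_bytes(b"")
    after = find_weight_files(tmp)
print(before, after)
```

In the notebook, calling `find_weight_files(output_dir)` before `pipeline(...)` would show whether the directory holds a saved model at all.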




In [8]:
# Zip the model for download
!zip -r /content/biobert_finetuned.zip /content/drive/MyDrive/biobert-finetuned

from google.colab import files
files.download('/content/biobert_finetuned.zip')
print('Download the zip file to your local machine')

  adding: content/drive/MyDrive/biobert-finetuned/ (stored 0%)
  adding: content/drive/MyDrive/biobert-finetuned/config.json (deflated 48%)
  adding: content/drive/MyDrive/biobert-finetuned/model.safetensors (deflated 7%)
  adding: content/drive/MyDrive/biobert-finetuned/tokenizer_config.json (deflated 74%)
  adding: content/drive/MyDrive/biobert-finetuned/special_tokens_map.json (deflated 42%)
  adding: content/drive/MyDrive/biobert-finetuned/vocab.txt (deflated 49%)


Download the zip file to your local machine
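As the `adding:` lines show, zipping an absolute path stores every entry under `content/drive/MyDrive/...`, so the archive unpacks into nested folders. A hedged alternative is `shutil.make_archive` with `root_dir` and `base_dir`, which keeps only the model folder name inside the archive. The helper name `zip_model_dir` and the temporary demo directory are assumptions; in Colab the arguments would be the Drive folder and a path under `/content`.

```python
import shutil
import tempfile
import zipfile
from pathlib import Path

def zip_model_dir(model_dir, archive_base):
    """Zip model_dir so the archive contains only '<dir name>/...' entries."""
    model_dir = Path(model_dir)
    return shutil.make_archive(
        str(archive_base), "zip",
        root_dir=str(model_dir.parent),  # archive paths are relative to this folder
        base_dir=model_dir.name,         # only this folder is archived
    )

# Demo on a throwaway directory standing in for the Drive folder
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp) / "MyDrive" / "biobert-finetuned"
    model_dir.mkdir(parents=True)
    (model_dir / "config.json").write_text("{}")
    out_dir = Path(tmp) / "out"
    out_dir.mkdir()
    archive = zip_model_dir(model_dir, out_dir / "biobert_finetuned")
    with zipfile.ZipFile(archive) as zf:
        archive_names = zf.namelist()
print(archive_names)
```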


In [6]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [5]:
import nltk
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [9]:
import json
import torch
from transformers import BertTokenizer, BertForMaskedLM, GPT2Tokenizer, GPT2LMHeadModel
from torch.nn.functional import softmax
import nltk
from nltk.corpus import words
from nltk.tokenize import word_tokenize
from nltk.metrics.distance import edit_distance
import re

# Step 1: Download NLTK data
nltk.download('words', quiet=True)
nltk.download('punkt')

# Step 2: Load domain configuration
def load_config(domain='medical'):
    # Now access your file
    config_file = '/content/drive/MyDrive/domain_config.json'
    try:
        with open(config_file, 'r') as f:
            config = json.load(f)
        return set(config[domain]['terms'])
    except Exception as e:
        print(f"Error loading config: {e}")
        raise

# Step 3: Initialize models
def initialize_models():
    try:
        bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
        bert_model.eval()

        gpt_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        gpt_model = GPT2LMHeadModel.from_pretrained('gpt2')
        gpt_model.eval()

        return bert_tokenizer, bert_model, gpt_tokenizer, gpt_model
    except Exception as e:
        print(f"Error initializing models: {e}")
        raise

# Step 4: Error detection
def detect_errors(text, domain_terms, standard_dict):
    tokens = word_tokenize(text)
    errors = []
    domain_terms_lower = {term.lower() for term in domain_terms}
    for i, token in enumerate(tokens):
        token_lower = token.lower()
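        # Skip exact matches, numbers, and tokens without letters (punctuation such as ',' or '(')
        if (token_lower in domain_terms_lower or token_lower in standard_dict
                or token.isdigit() or not any(ch.isalpha() for ch in token)):
            continue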
        # Flag near-domain terms or general misspellings
        for term in domain_terms:
            if len(term) > 2 and abs(len(token_lower) - len(term.lower())) <= 2 and edit_distance(token_lower, term.lower()) <= 2:
                errors.append((i, token, term))
                break
        else:
            errors.append((i, token))
    return errors

# Step 5: Context-aware correction
def correct_with_bert(text, error_pos, error_token, bert_tokenizer, bert_model, domain_terms, standard_dict):
    tokens = word_tokenize(text)
    tokens[error_pos] = '[MASK]'
    masked_text = ' '.join(tokens)

    inputs = bert_tokenizer(masked_text, return_tensors='pt', truncation=True, max_length=128)
    mask_indices = (inputs['input_ids'] == bert_tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    if len(mask_indices) == 0:
        return error_token, 0.0
    mask_index = mask_indices[0].item()

    with torch.no_grad():
        outputs = bert_model(**inputs)
    predictions = softmax(outputs.logits[0, mask_index], dim=-1)
    top_k_probs, top_k_indices = torch.topk(predictions, k=5)
    candidates = [(bert_tokenizer.decode([idx.item()]).strip(), prob.item()) for idx, prob in zip(top_k_indices, top_k_probs)]

    # Match case
    for candidate, confidence in candidates:
        candidate_lower = candidate.lower()
        if candidate_lower in {term.lower() for term in domain_terms} or candidate_lower in standard_dict:
            if error_token[0].isupper():
                return candidate.capitalize(), confidence
            return candidate, confidence
    return candidates[0][0] if candidates else error_token, candidates[0][1] if candidates else 0.0

def correct_with_gpt(text, error_pos, error_token, gpt_tokenizer, gpt_model, domain_terms, standard_dict):
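    tokens = word_tokenize(text)
    prefix = ' '.join(tokens[:error_pos])
    if not prefix:
        # An error at position 0 has no left context. GPT-2 then receives an empty
        # input and fails with "cannot reshape tensor of 0 elements", so skip it.
        return error_token, 0.0
    inputs = gpt_tokenizer(prefix, return_tensors='pt', truncation=True, max_length=128)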

    with torch.no_grad():
        outputs = gpt_model(**inputs)

    next_token_logits = outputs.logits[:, -1, :]
    top_k_probs, top_k_indices = torch.topk(softmax(next_token_logits, dim=-1), k=5)
    candidates = [(gpt_tokenizer.decode(idx.item()).strip(), prob.item()) for idx, prob in zip(top_k_indices[0], top_k_probs[0])]

    # Match case
    for candidate, confidence in candidates:
        candidate_lower = candidate.lower()
        if candidate_lower in {term.lower() for term in domain_terms} or candidate_lower in standard_dict:
            if error_token[0].isupper():
                return candidate.capitalize(), confidence
            return candidate, confidence
    return candidates[0][0] if candidates else error_token, candidates[0][1] if candidates else 0.0

# Step 6: Spellcheck function
def spellcheck(text, domain='medical'):
    domain_terms = load_config(domain)
    standard_dict = set(words.words())
    bert_tokenizer, bert_model, gpt_tokenizer, gpt_model = initialize_models()

    errors = detect_errors(text, domain_terms, standard_dict)
    corrected_text = word_tokenize(text)
    corrections = []

    for error in errors:
        pos = error[0]
        token = error[1]
        suggestion = None
        confidence = 0.0

        bert_suggestion, bert_confidence = correct_with_bert(text, pos, token, bert_tokenizer, bert_model, domain_terms, standard_dict)
        if bert_suggestion.lower() in {term.lower() for term in domain_terms} or bert_suggestion.lower() in standard_dict:
            suggestion = bert_suggestion
            confidence = bert_confidence
        else:
            gpt_suggestion, gpt_confidence = correct_with_gpt(text, pos, token, gpt_tokenizer, gpt_model, domain_terms, standard_dict)
            if gpt_suggestion.lower() in {term.lower() for term in domain_terms} or gpt_suggestion.lower() in standard_dict:
                suggestion = gpt_suggestion
                confidence = gpt_confidence

        if suggestion and suggestion != token:
            corrected_text[pos] = suggestion
            corrections.append((token, suggestion, pos, confidence))

    return ' '.join(corrected_text), corrections

# Step 7 & 8: Test and output
if __name__ == "__main__":
    text = "PubMed is an openley acessible, free databse which includes primarilly the MEDLINE databse of referances and abstrats on life scienses and biomdical topics. The United States National Library of Medicin (NLM) at the National Instituts of Health mantains the databse as part of the Entrez systm of informtion retrival. Please incld and elablorate."
    corrected_text, corrections = spellcheck(text, domain='medical')
    print(f"Original: {text}")
    print(f"Corrected: {corrected_text}")
    print("Corrections:", [(orig, corr, pos, f"{conf:.2f}") for orig, corr, pos, conf in corrections])

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

RuntimeError: cannot reshape tensor of 0 elements into shape [-1, 0] because the unspecified dimension size -1 can be any value and is ambiguous
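The detection step relies on edit distance to decide whether a token is close to a known term. The sketch below shows that idea without NLTK or any model: a plain Levenshtein distance plus a ranking of vocabulary terms within a distance threshold. The function names and the small vocabulary are hypothetical, and this is a stand-in for `nltk.metrics.distance.edit_distance`, not a replacement for the BERT/GPT-2 correction step.

```python
def levenshtein(a, b):
    """Classic dynamic-programming edit distance (insert, delete, substitute)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]

def rank_candidates(token, vocabulary, max_distance=2):
    """Return vocabulary terms within max_distance edits, closest first."""
    token = token.lower()
    scored = []
    for term in vocabulary:
        if abs(len(term) - len(token)) > max_distance:
            continue  # the length gap alone already exceeds the threshold
        distance = levenshtein(token, term.lower())
        if distance <= max_distance:
            scored.append((distance, term))
    return [term for _distance, term in sorted(scored)]

# Hypothetical mini-vocabulary built from words in the test sentence
vocab = {"database", "abstracts", "medicine", "maintains", "retrieval"}
print(rank_candidates("databse", vocab))
print(rank_candidates(",", vocab))
```

Candidates ranked this way could be used to restrict the language-model suggestions to words that are also orthographically close to the misspelled token.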