In [None]:
# ! pip install datasets evaluate transformers[sentencepiece]
# ! pip install accelerate 

In [140]:
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_checkpoint = "distilbert-base-uncased"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [141]:
text = "This is a great [MASK]."

In [142]:
import torch

inputs = tokenizer(text, return_tensors="pt")
token_logits = model(**inputs).logits

mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")

'>>> This is a great deal.'
'>>> This is a great success.'
'>>> This is a great adventure.'
'>>> This is a great idea.'
'>>> This is a great feat.'


In [143]:
with open("dataset_gl10_lt15.txt", 'r', encoding='utf-8') as f:
    sent = [line.strip() for line in f]

print(type(sent))
print(sent[:3])

<class 'list'>
['Diagnosis You may not know you have atrial fibrillation AFib', 'The condition may be found when a health checkup is done for another reason', 'This quick and painless test measures the electrical activity of the heart']


In [85]:
# from datasets import Dataset
# from datasets import DatasetDict
# from sklearn.model_selection import train_test_split

# train, val, _, _ = train_test_split(sent, sent, train_size=0.8, random_state=1)

# train = [{"text": sentence} for sentence in train]
# train = Dataset.from_list(train)

# val = [{"text": sentence} for sentence in val]
# val = Dataset.from_list(val)

# dataset = DatasetDict({
#     "train": train,
#     "validation": val
# })

# print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 35158
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 8790
    })
})


In [88]:
from datasets import Dataset
from datasets import DatasetDict

train = train = [{"text": sentence} for sentence in sent]
train = Dataset.from_list(train)

dataset = DatasetDict({
    "train": train,
    # "validation": val
})

print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 43948
    })
})


In [89]:
sample = dataset["train"].shuffle(seed=42).select(range(3))

for row in sample:
    print(f">>> {row['text']}'")

>>> Does any type of activity ease the pain or worsen it'
>>> If pericardial effusion signs and symptoms do occur, they might include'
>>> If you miss a dose of levothyroxine, take two pills the next day'


In [90]:
def tokenize_function(examples):
    result = tokenizer(examples["text"])
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result


# Use batched=True to activate fast multithreading!
tokenized_datasets = dataset.map(
    tokenize_function, batched=True, remove_columns=["text"]
)
tokenized_datasets

Map:   0%|          | 0/43948 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids'],
        num_rows: 43948
    })
})

In [91]:
tokenizer.model_max_length

512

In [92]:
chunk_size = 128

In [93]:
tokenized_samples = tokenized_datasets["train"][:3]

for idx, sample in enumerate(tokenized_samples["input_ids"]):
    print(f"'>>> sentence {idx} length: {len(sample)}'")

'>>> sentence 0 length: 16'
'>>> sentence 1 length: 17'
'>>> sentence 2 length: 15'


In [94]:
concatenated_examples = {
    k: sum(tokenized_samples[k], []) for k in tokenized_samples.keys()
}
total_length = len(concatenated_examples["input_ids"])
print(f"'>>> Concatenated sentences length: {total_length}'")

'>>> Concatenated sentences length: 48'


In [95]:
chunks = {
    k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
    for k, t in concatenated_examples.items()
}

for chunk in chunks["input_ids"]:
    print(f"'>>> Chunk length: {len(chunk)}'")

'>>> Chunk length: 48'


In [96]:
def group_texts(examples):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute length of concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than chunk_size
    total_length = (total_length // chunk_size) * chunk_size
    # Split by chunks of max_len
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result

In [97]:
lm_datasets = tokenized_datasets.map(group_texts, batched=True)
lm_datasets

Map:   0%|          | 0/43948 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 5772
    })
})

In [98]:
tokenizer.decode(lm_datasets["train"][1]["input_ids"])

'device shows how the heart is beating while you do your daily activities [SEP] [CLS] it may be used to see how often you have an afib episode [SEP] [CLS] for example, you may need one if you ve had an unexplained stroke echocardiogram [SEP] [CLS] sound waves are used to create images of the beating heart [SEP] [CLS] this test can show how blood flows through the heart and heart valves [SEP] [CLS] a chest x ray shows the condition of the lungs and heart [SEP] [CLS] more information atrial fibrillation care at echocardiogramelectrocardiogram ecg or ekg ep studyholter monitorx rays'

In [99]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

In [100]:
samples = [lm_datasets["train"][i] for i in range(2)]
for sample in samples:
    _ = sample.pop("word_ids")

for chunk in data_collator(samples)["input_ids"]:
    print(f">>> {tokenizer.decode(chunk)}'")

>>> [CLS] diagnosis you may not know you have atrial [MASK]brillation afib [SEP] [CLS] [MASK] [MASK] [MASK] [MASK] found when a health checkup is done for another reason [SEP] [CLS] this quick and painless test [MASK] the [MASK] activity of the heart [SEP] [CLS] sticky patches called electrodes are placed on the chestologist sometimes the [MASK] and legs [SEP] [CLS] wires connect the electrodes to a computer, which [MASK] or displays [MASK] test results [SEP] [CLS] it s worn [MASK] a day [MASK] two while you do your regular activities [SEP] [CLS] some [MASK] automatically record when an irregular [MASK] rhythm is detected [SEP] [CLS] [MASK] device records [MASK] heartbeat continuously for up to three years [SEP] [CLS] the'
>>> device shows how the heart [MASK] beating while you do your daily activities [SEP] [CLS] [MASK] may [MASK] used [MASK] see how often you [MASK] an afib episode [SEP] [CLS] for example, you may need one if you ve had an unexplained stroke echocard [MASK]gram [SEP]

In [101]:
import collections
import numpy as np

from transformers import default_data_collator

wwm_probability = 0.2


def whole_word_masking_data_collator(features):
    for feature in features:
        word_ids = feature.pop("word_ids")

        # Create a map between words and corresponding token indices
        mapping = collections.defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)

        # Randomly mask words
        mask = np.random.binomial(1, wwm_probability, (len(mapping),))
        input_ids = feature["input_ids"]
        labels = feature["labels"]
        new_labels = [-100] * len(labels)
        for word_id in np.where(mask)[0]:
            word_id = word_id.item()
            for idx in mapping[word_id]:
                new_labels[idx] = labels[idx]
                input_ids[idx] = tokenizer.mask_token_id
        feature["labels"] = new_labels

    return default_data_collator(features)

In [102]:
samples = [lm_datasets["train"][i] for i in range(2)]
batch = whole_word_masking_data_collator(samples)

for chunk in batch["input_ids"]:
    print(f">>> {tokenizer.decode(chunk)}")

>>> [CLS] diagnosis [MASK] may not know you have [MASK] [MASK] fibrillation afib [SEP] [CLS] the [MASK] [MASK] be found when a health checkup is done for another reason [SEP] [CLS] this quick and painless test measures the electrical activity [MASK] [MASK] [MASK] [SEP] [CLS] sticky patches called electrodes are placed on [MASK] chest and sometimes the arms and [MASK] [SEP] [CLS] wires [MASK] [MASK] electrodes [MASK] a computer, which prints or [MASK] the test results [SEP] [CLS] it [MASK] worn [MASK] a day or [MASK] while you do your regular activities [SEP] [CLS] [MASK] devices automatically record when an irregular heart [MASK] is detected [SEP] [CLS] this device records the heartbeat continuously for up to three years [SEP] [CLS] the
>>> device [MASK] how the heart is beating while you [MASK] [MASK] daily [MASK] [SEP] [CLS] [MASK] may [MASK] used [MASK] [MASK] how often you have [MASK] [MASK] [MASK] episode [SEP] [CLS] for example, you may [MASK] one if you ve [MASK] [MASK] unexplai

In [103]:
train_size = 0.9
test_size = 0.1

downsampled_dataset = lm_datasets["train"].train_test_split(
    train_size=train_size, test_size=test_size, seed=42
)
downsampled_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 5194
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 578
    })
})

In [104]:
from transformers import TrainingArguments

batch_size = 64
# Show the training loss with every epoch
logging_steps = len(downsampled_dataset["train"]) // batch_size
model_name = model_checkpoint.split("/")[-1]

training_args = TrainingArguments(
    output_dir=f"{model_name}-finetuned",
    # overwrite_output_dir=True,
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
)

In [105]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=downsampled_dataset["train"],
    eval_dataset=downsampled_dataset["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [107]:
trainer.train()

  0%|          | 0/246 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 2.0009605884552, 'eval_runtime': 2.3043, 'eval_samples_per_second': 250.835, 'eval_steps_per_second': 4.34, 'epoch': 1.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.9535224437713623, 'eval_runtime': 1.9858, 'eval_samples_per_second': 291.072, 'eval_steps_per_second': 5.036, 'epoch': 2.0}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 1.937943935394287, 'eval_runtime': 2.0143, 'eval_samples_per_second': 286.952, 'eval_steps_per_second': 4.965, 'epoch': 3.0}
{'train_runtime': 293.115, 'train_samples_per_second': 53.16, 'train_steps_per_second': 0.839, 'train_loss': 2.0671412770341084, 'epoch': 3.0}


TrainOutput(global_step=246, training_loss=2.0671412770341084, metrics={'train_runtime': 293.115, 'train_samples_per_second': 53.16, 'train_steps_per_second': 0.839, 'train_loss': 2.0671412770341084, 'epoch': 3.0})

In [108]:
from transformers import pipeline

mask_filler = pipeline(
    "fill-mask", model=model, tokenizer=tokenizer
)

In [123]:
# Diagnosis You may not know you have atrial fibrillation AFib
test_sentences =["Diagnosis [MASK] may not know you have atrial fibrillation AFib",
                 "Diagnosis You [MASK] not know you have atrial fibrillation AFib",
                 "Diagnosis You may not [MASK] you have atrial fibrillation AFib",
                 "Diagnosis You may not know you [MASK] atrial fibrillation AFib",
                 "Diagnosis You may not know you have [MASK] fibrillation AFib",
                 "Diagnosis You may not know you have atrial [MASK] AFib"]

for text in test_sentences:
    preds = mask_filler(text)
    print("-----")
    # for pred in preds:
    #     print(f"> {pred['sequence']}")
    print(preds[0]['sequence'])


-----
diagnosis you may not know you have atrial fibrillation afib
-----
diagnosis you may not know you have atrial fibrillation afib
-----
diagnosis you may not know you have atrial fibrillation afib
-----
diagnosis you may not know you have atrial fibrillation afib
-----
diagnosis you may not know you have a fibrillation afib
-----
diagnosis you may not know you have atrial valve afib


In [139]:
text = "Cardioversion therapyIf atrial fibrillation symptoms are bothersome or if this is the first AFib episode, a doctor may try to reset the heart rhythm using a procedure called cardioversion."
masked = "Cardioversion therapyIf atrial fibrillation symptoms are bothersome or if this is the first AFib episode, a [MASK] may try to reset the [MASK] rhythm using a procedure called cardioversion."

preds = mask_filler(masked)


print(f"Original text: \t {text}\n")

for pred in preds:
    # print(f"\t >{pred['sequence']}")
    print(pred)


Original text: 	 Cardioversion therapyIf atrial fibrillation symptoms are bothersome or if this is the first AFib episode, a doctor may try to reset the heart rhythm using a procedure called cardioversion.

[{'score': 0.6329159140586853, 'token': 3460, 'token_str': 'doctor', 'sequence': '[CLS] cardioversion therapyif atrial fibrillation symptoms are bothersome or if this is the first afib episode, a doctor may try to reset the [MASK] rhythm using a procedure called cardioversion. [SEP]'}, {'score': 0.09803592413663864, 'token': 9431, 'token_str': 'surgeon', 'sequence': '[CLS] cardioversion therapyif atrial fibrillation symptoms are bothersome or if this is the first afib episode, a surgeon may try to reset the [MASK] rhythm using a procedure called cardioversion. [SEP]'}, {'score': 0.06754632294178009, 'token': 2711, 'token_str': 'person', 'sequence': '[CLS] cardioversion therapyif atrial fibrillation symptoms are bothersome or if this is the first afib episode, a person may try to res