In [None]:
from google.colab import drive
import os
import json
import nltk
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from nltk.corpus import wordnet
from nltk.tokenize import RegexpTokenizer
from nltk import pos_tag

In [None]:
drive.mount('/content/drive')
datadir = "/content/drive/My Drive/CS546Data/"

Mounted at /content/drive


In [None]:
import re

def parse_entities_from_file(file_path):
    entities = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            # Skip lines that don't contain entity definitions
            if line.startswith("<") or line.strip() == '':
                continue
            # Extract entity, ignoring hierarchy indentation
            entity = line.strip().split('\t')[-1]
            entities.append(entity)
    return entities

paths_to_ontology_files = [
    # datadir + 'Ontology/hierarchy.txt',
    datadir + 'Ontology/listEntities.txt'
    # datadir + 'Ontology/listEvents.txt',
    # datadir + 'Ontology/listPredicates.txt',
    # datadir + 'Ontology/listFacts.txt'
]
all_entities = set()

for file_path in paths_to_ontology_files:
    entities_in_file = parse_entities_from_file(file_path)
    all_entities.update(entities_in_file)

# Convert the set to a list and sort it
entities_list = sorted(list(all_entities))

# Write the entities to a text file
with open('entities_list.txt', 'w', encoding='utf-8') as f:
    for entity in entities_list:
        f.write(f"{entity}\n")

In [None]:
def load_entities(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        entities = [line.strip() for line in file.readlines()]
    return entities

entities_list = load_entities('./entities_list.txt')

In [None]:
def tag_entities_in_text(text, entities):
    for entity in entities:
        # Check if the entity is already tagged
        tagged_entity = f'[ENTITY]{entity}[/ENTITY]'
        if tagged_entity in text:
            continue  # Skip already tagged entities

        # Create a pattern that matches the entity
        pattern = re.compile(r'\b' + re.escape(entity) + r'\b', re.IGNORECASE)

        # Replace with the tagged version
        text = pattern.sub(tagged_entity, text)

    # Correct the incorrectly applied tags
    text = text.replace('[[ENTITY]entity[/ENTITY]]', '[ENTITY]').replace('[/[ENTITY]entity[/ENTITY]]', '[/ENTITY]')

    return text

In [None]:
def read_files_to_json(folder_paths, entity_list):
    data_dict = {}

    if os.path.exists('/content/minecraft_data.json'):
        with open('/content/minecraft_data.json', 'r', encoding='utf-8') as json_file:
            data_dict = json.load(json_file)

    for folder_path in folder_paths:
        for filename in os.listdir(folder_path):
            if filename.endswith(".txt"):
                topic = filename.replace(".txt", "")
                with open(os.path.join(folder_path, filename), 'r', encoding='utf-8') as file:
                    content = file.read()
                    # Tag the entities in the content
                    content_with_entities = tag_entities_in_text(content, entity_list)

                data_dict[topic] = {
                    'description': content_with_entities  # Store the content with entities tagged
                }

    data_json = json.dumps(data_dict, indent=4)

    with open('/content/minecraft_data.json', 'w', encoding='utf-8') as json_file:
        json_file.write(data_json)

    return data_dict

# Now call the function with the entity list
folder_paths = [
    datadir + 'KnowledgeDatabase/GamepediaTxt/Blocks',
    datadir + 'KnowledgeDatabase/GamepediaTxt/Entity',
    datadir + 'KnowledgeDatabase/GamepediaTxt/Items',
    datadir + 'KnowledgeDatabase/GamepediaTxt/Other'
]
minecraft_data = read_files_to_json(folder_paths, entities_list)


In [None]:
import re
from nltk import pos_tag
from nltk.corpus import stopwords
from nltk.tokenize import RegexpTokenizer
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet

nltk.download('punkt')
nltk.download('wordnet')
nltk.download('stopwords')
nltk.download('averaged_perceptron_tagger')

# Function to extract entities from ontology files
def parse_entities_from_file(file_path):
    entities = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            # Skip lines that don't contain entity definitions
            if line.startswith("<") or line.strip() == '':
                continue
            # Extract entity, ignoring hierarchy indentation
            entity = line.strip().split('\t')[-1]
            entities.append(entity)
    return entities

# Function to get the wordnet POS tag
def get_wordnet_pos(treebank_tag):
    if treebank_tag.startswith('J'):
        return wordnet.ADJ
    elif treebank_tag.startswith('V'):
        return wordnet.VERB
    elif treebank_tag.startswith('N'):
        return wordnet.NOUN
    elif treebank_tag.startswith('R'):
        return wordnet.ADV
    else:
        return wordnet.NOUN  # Default to noun

# Function to preprocess text
def preprocess_text(text, entity_list):
    lemmatizer = WordNetLemmatizer()
    tokenizer = RegexpTokenizer(r'\w+|[ENTITY][^/]+[/ENTITY]')  # Adjust the tokenizer to capture entity tags as single tokens
    tokens = tokenizer.tokenize(text)

    processed_tokens = []
    for token in tokens:
        if token.startswith('[ENTITY]'):
            processed_tokens.append(token)  # Keep the entity as is
        else:
            # Lowercase non-entity tokens
            token = token.lower()
            if token not in stopwords.words('english'):
                # Get POS tags for lemmatization
                pos = pos_tag([token])[0][1]
                wordnet_pos = get_wordnet_pos(pos)  # Convert the tag to wordnet format
                # Lemmatize the token
                lemmatized_token = lemmatizer.lemmatize(token, wordnet_pos)
                processed_tokens.append(lemmatized_token)
    return ' '.join(processed_tokens)

# Example usage
# entities_list = parse_entities_from_file('ontology.txt')  # Assuming 'ontology.txt' contains your ontology entities
sample_text = "To create a [ENTITY]Nether Portal[/ENTITY], you need [ENTITY]Obsidian[/ENTITY] which can be mined with a [ENTITY]Diamond Pickaxe[/ENTITY]."
preprocessed_sample = preprocess_text(sample_text, entities_list)
print(preprocessed_sample)

def load_json_data(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

# Function to save data to a JSON file
def save_json_data(data, file_path):
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, indent=4)

minecraft_data = load_json_data('./minecraft_data.json')

for key, value in minecraft_data.items():
    if isinstance(value, dict) and 'description' in value:
        value['description'] = preprocess_text(value['description'], entities_list)

# Save the preprocessed data to a new file
save_json_data(minecraft_data, './preprocessed_minecraft_data.json')

print("Preprocessing complete. Data saved to preprocessed_minecraft_data.json")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


create entity nether portal entity need entity obsidian entity mine entity diamond pickaxe entity
Preprocessing complete. Data saved to preprocessed_minecraft_data.json


In [None]:
!pip install cohere
!pip install tiktoken
!pip install openai



In [None]:
from openai import OpenAI

client = OpenAI(
)

def generate_qa_pairs(context, max_questions=3):
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    # "content": f"Generate {max_questions} question, answer, and answer start index pairs based on the following text:\n\n{context}\n\nFor each answer, provide the starting index of the answer in the context, and the answer should be derived from the text directly without any modification.",
                    "content": f"Generate {max_questions} question, answer, and answer start index pairs based on the following text:\n\n{context}\n\nFor each answer, provide the starting index of the answer in the context, and the answer should be derived from the text directly without any modification.\nBe advised, some words might be surrounded by 'entity', it might look like this: entity cauldron entity. When you encounter this in the text, treat them as is and put 'entity' in the answer, this is a tagging step I applied to the text that I needed later",
                }
            ],
            model="gpt-4-1106-preview"
        )
        return chat_completion.choices[0].message.content.strip()
    except Exception as e:
        print(f"An error occurred: {e}")
        return None


# Example context from your dataset
context = "entity cauldron entity entity block entity hold water cauldron mine use pickaxe mine without pickaxe drop nothing entity cauldron entity destroyed water inside lose entity cauldron entity craft iron ingot single empty entity cauldron entity generate entity witch entity hut fill entity cauldron entity water press use entity cauldron entity water bucket entity cauldron entity also chance fill water rain upon water entity cauldron entity use fill entity glass entity bottle turn water bottle wash dye entity leather entity entity armor entity remove top pattern layer banner use entity cauldron entity press use entity cauldron entity entity glass entity bottle entity leather entity entity armor entity banner entity cauldron entity extinguish mob entity fire entity include player fall use include extinguish mob cause water level entity cauldron entity decrease one third use three time empty must refill additional us endermen fill entity cauldron entity take damage water entity cauldron entity cannot use fill empty bucket water bottle cannot use refill entity cauldron entity entity cauldron entity fill water bucket nether entity cauldron entity act power source redstone comparator entity cauldron entity behind possibly separate unpowered solid entity block entity comparator output signal strength proportional full entity cauldron entity 0 empty 1 one third full 2 two third full 3 completely full entity cauldron entity fullness define entity block entity data future plan replace entity block entity data entity block entity state arrow entity stick entity water entity cauldron entity try sneak entity cauldron entity still fall inside entity cauldron entity 0 3125 5 16 entity block entity tall entity cauldron entity contains water broken water show crack entity block entity entity item entity thrown entity cauldron entity entity hopper entity underneath entity hopper entity receive entity item entity potion brew beta 1 9 pre release 2 enable mod cauldron look three different level water empty comparison naturally occur entity cauldron entity inside entity witch entity hut example entity cauldron entity use redstone circuit"
qa_pairs = generate_qa_pairs(context)
print(qa_pairs)

Question 1: How do you destroy an entity cauldron entity, and what happens to the water inside it?
Answer: mine use pickaxe mine without pickaxe drop nothing entity cauldron entity destroyed water inside lose
Answer Start Index: 61

Question 2: What happens when a cauldron is used to wash dyed leather entity armor entity or remove a pattern from a banner?
Answer: turn water bottle wash dye entity leather entity armor entity remove top pattern layer banner
Answer Start Index: 426

Question 3: Can you fill an entity cauldron entity with a water bucket in the Nether?
Answer: cannot use fill empty bucket water bottle cannot use refill entity cauldron entity entity cauldron entity fill water bucket nether
Answer Start Index: 925


In [None]:
qas = []
for qa in qa_pairs.split('\n\n'):
    parts = qa.split('\n')
    if len(parts) >= 3:
        # Clean up the format of question and answer
        # question = parts[0].replace('Question:', '').strip()
        # answer = parts[1].replace('Answer:', '').strip()
        question = parts[0].split(': ')[1]  # Splitting on ': ' and taking the second part
        answer = parts[1].split(': ')[1]    # Similarly for the answer
        print(question)
        print(answer)
        answer_start_index_line = parts[2]
        answer_start_index = int(answer_start_index_line.split(': ')[1])

        # Use original context to find answer start index
        # answer_start = find_answer_start(original_context, answer)
        answer_start = answer_start_index

        qas.append({
            "question": question,
            "id": f"q{len(qas)+1}",
            "answers": [
                {
                    "text": answer,
                    "answer_start": answer_start
                }
            ]
        })
print(qas)

How do you destroy an entity cauldron entity, and what happens to the water inside it?
mine use pickaxe mine without pickaxe drop nothing entity cauldron entity destroyed water inside lose
What happens when a cauldron is used to wash dyed leather entity armor entity or remove a pattern from a banner?
turn water bottle wash dye entity leather entity armor entity remove top pattern layer banner
Can you fill an entity cauldron entity with a water bucket in the Nether?
cannot use fill empty bucket water bottle cannot use refill entity cauldron entity entity cauldron entity fill water bucket nether
[{'question': 'How do you destroy an entity cauldron entity, and what happens to the water inside it?', 'id': 'q1', 'answers': [{'text': 'mine use pickaxe mine without pickaxe drop nothing entity cauldron entity destroyed water inside lose', 'answer_start': 61}]}, {'question': 'What happens when a cauldron is used to wash dyed leather entity armor entity or remove a pattern from a banner?', 'id':

In [None]:
import json
from openai import OpenAI

client = OpenAI(
    api_key="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",
)

def load_data(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def find_answer_start(context, answer):
    return context.find(answer)

def generate_qa_pairs(context, max_questions=3):
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    # "content": f"Generate {max_questions} question, answer, and answer start index pairs based on the following text:\n\n{context}\n\nFor each answer, provide the starting index of the answer in the context, and the answer should be derived from the text directly without any modification.",
                    "content": f"Generate {max_questions} question, answer, and answer start index pairs based on the following text:\n\n{context}\n\nFor each answer, provide the starting index of the answer in the context, and the answer should be strictly derived from the text directly without any modification, and should not be a numerical answer.\nBe advised, some words might be surrounded by 'entity', it might look like this: entity cauldron entity. When you encounter this in the text, treat them as is and put 'entity' in the answer, this is a tagging step I applied to the text that I needed later.\n The questions should look like this: 'Question 1: (Question here)', and answer should look like this: 'Answer 1: (Answer here)'",
                }
            ],
            model="gpt-4-1106-preview"
        )
        return chat_completion.choices[0].message.content.strip()
    except Exception as e:
        print(f"An error occurred: {e}")
        return None


preprocessed_data = load_data('./preprocessed_minecraft_data.json')
original_data = load_data('./original_minecraft_data.json')

final_data = {"data": []}

for title, content in preprocessed_data.items():
    # Use original context for generating Q&A pairs

    preprocessed_context = content['description']
    qa_pairs = generate_qa_pairs(preprocessed_context, max_questions=3)

    qas = []
    for qa in qa_pairs.split('\n\n'):
        parts = qa.split('\n')
        if len(parts) >= 3:

            # Clean up the format of question and answer
            question = parts[0].split(': ')[1]  # Splitting on ': ' and taking the second part
            answer = parts[1].split(': ')[1]    # Similarly for the answer
            print(question)
            # print(answer)
            answer_start_index_line = parts[2]
            answer_start_index = int(answer_start_index_line.split(': ')[1])

            # Use original context to find answer start index
            # answer_start = find_answer_start(original_context, answer)
            answer_start = answer_start_index

            qas.append({
                "question": question,
                "id": f"q{len(qas)+1}",
                "answers": [
                    {
                        "text": answer,
                        "answer_start": answer_start
                    }
                ]
            })

    # Use preprocessed context in the final dataset
    preprocessed_context = preprocessed_data[title]['description']
    final_data["data"].append({
        "title": title,
        "paragraphs": [
            {
                "context": preprocessed_context,
                "qas": qas
            }
        ]
    })
    # print(final_data)

# Save the final dataset
with open('final_minecraft_dataset.json', 'w', encoding='utf-8') as outfile:
    json.dump(final_data, outfile, indent=4)
print("finished")

How can an entity cauldron entity be destroyed, and what happens to the water inside?
What happens when you use an entity cauldron entity filled with water on an entity leather entity entity armor entity or a banner?
How does the entity cauldron entity act as a power source for a redstone comparator, and how is the signal strength determined?
How many basic colors can a blank flag have in Minecraft?
How many patterns can be overlaid on a Minecraft banner?
How is the visibility of the base texture for banners determined in Minecraft?
What tool is recommended for mining andesite?
What happens if andesite is mined without a pickaxe?
Where is andesite typically found?
What is block of coal used for besides storing coal in a compact fashion?
How long does one block of coal last when used as fuel in a furnace?
Why can't a block of coal be crafted into or used as charcoal?
What must be used to mine a block of iron to ensure it drops an iron ingot?
What structure can blocks of iron be used in 

In [None]:
!pip install transformers torch



In [None]:
!pip install pandas



In [None]:
import json
import pandas as pd

with open('./final_minecraft_dataset.json', 'r') as file:
    data = json.load(file)

qa_data = []
for entry in data['data']:
    title = entry['title']
    for paragraph in entry['paragraphs']:
        context = paragraph['context']
        for qa in paragraph['qas']:
            question = qa['question']
            id = qa['id']
            for answer in qa['answers']:
                answer_text = answer['text']
                answer_start = answer['answer_start']
                qa_data.append({'title': title, 'context': context, 'question': question, 'id': id, 'answer_text': answer_text, 'answer_start': answer_start})

df = pd.DataFrame(qa_data)

In [None]:
df

Unnamed: 0,title,context,question,id,answer_text,answer_start
0,Cauldron,entity cauldron entity entity block entity hol...,How can an entity cauldron entity be destroyed...,q1,mine use pickaxe mine without pickaxe drop not...,98
1,Cauldron,entity cauldron entity entity block entity hol...,What happens when you use an entity cauldron e...,q2,turn water bottle wash dye entity leather enti...,433
2,Cauldron,entity cauldron entity entity block entity hol...,How does the entity cauldron entity act as a p...,q3,act power source redstone comparator entity ca...,1132
3,Banner,banner flag tall decorative block feature fiel...,How many basic colors can a blank flag have in...,q1,16 color,235
4,Banner,banner flag tall decorative block feature fiel...,How many patterns can be overlaid on a Minecra...,q2,six pattern,675
...,...,...,...,...,...,...
965,Shears,shear entity tool entity shepherd villager sel...,"Besides wool, what can shears also be used to ...",q2,harvest leaf tall grass fern normal 2 high dea...,268
966,Shears,shear entity tool entity shepherd villager sel...,What happens when a mooshroom is sheared?,q3,drop 5 red mushroom turn normal cow,658
967,Tools,tool item use entity player entity held perfor...,What are some examples of tools that do not st...,q1,"tool include hoe, bow, fishing rod, carrot on ...",271
968,Tools,tool item use entity player entity held perfor...,"What is an exception to the tool use, for enti...",q2,"exception entity clock, entity compass",159


In [None]:
from transformers import BertTokenizer


tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Add the special token to the tokenizer
special_tokens_dict = {'additional_special_tokens': ['[ENTITY]']}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

def sliding_window(context, question, max_len=512, stride=128):
    question_tokens = custom_tokenize(question, tokenizer)
    context_tokens = custom_tokenize(context, tokenizer)
    # context_tokens = tokenizer.encode(context, add_special_tokens=False)

    max_context_length = max_len - len(question_tokens) - 3  # Account for [CLS] and two [SEP] tokens

    chunks = []
    start = 0
    while start < len(context_tokens):
        end = start + max_context_length
        end = min(end, len(context_tokens))

        chunk = context_tokens[start:end]
        input_ids = [tokenizer.cls_token_id] + question_tokens + [tokenizer.sep_token_id] + chunk + [tokenizer.sep_token_id]

        if len(input_ids) > max_len:
            print(f"Warning: Chunk exceeded max length with {len(input_ids)} tokens.")

        chunks.append(input_ids)
        start += stride

    return chunks

# sample_context = df.iloc[0]['context']
# sample_question = df.iloc[0]['question']
# chunks = sliding_window(sample_context, sample_question)

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [None]:
def adjust_answer_positions(context_tokens, chunks, answer_start, answer_end):
    adjusted_positions = []
    context_start_pos = 0

    for chunk in chunks:
        chunk_end_pos = context_start_pos + len(chunk)

        if context_start_pos <= answer_start < chunk_end_pos or context_start_pos < answer_end <= chunk_end_pos:
            new_start = max(0, answer_start - context_start_pos)
            new_end = min(len(chunk), answer_end - context_start_pos)
            adjusted_positions.append((new_start, new_end))
        else:
            adjusted_positions.append((None, None))

        context_start_pos += len(chunk)

    return adjusted_positions

In [None]:
def custom_tokenize(text, tokenizer):
    # Replace 'entity' with the special token
    text = text.replace('entity', '[ENTITY]')

    # Use the tokenizer to tokenize the text
    return tokenizer.encode(text, add_special_tokens=True, truncation=True)

In [None]:
def find_answer_positions_in_tokenized(tokenized_context, answer_text, tokenizer):
    # Tokenize the answer text
    answer_tokens = tokenizer.encode(answer_text, add_special_tokens=False)

    # Search for the first occurrence of the answer token sequence in the context
    for start_position in range(len(tokenized_context) - len(answer_tokens) + 1):
        if tokenized_context[start_position:start_position + len(answer_tokens)] == answer_tokens:
            end_position = start_position + len(answer_tokens) - 1
            return start_position, end_position

    # If the answer isn't found in the context, return None
    return None, None


In [None]:


def prepare_data_with_sliding_window(df, max_len=512, stride=256):
    prepared_data = []

    for index, row in df.iterrows():
        # context = row['context'].replace('entity', '[ENTITY]')
        # question = row['question'].replace('entity', '[E]')
        # answer_text = row['answer_text'].replace('entity', '[E]')
        context = row['context']
        question = row['question']
        answer_text = row['answer_text']
        answer_start = row['answer_start']
        answer_end = answer_start + len(answer_text)

        # Tokenize the context
        # context_tokens = tokenizer.encode(context, add_special_tokens=False)
        context_tokens = custom_tokenize(context, tokenizer)


        # Apply sliding window
        chunks = sliding_window(context, question, max_len, stride)

        # Adjust answer positions for each chunk
        adjusted_positions = adjust_answer_positions(context_tokens, chunks, answer_start, answer_end)

        for chunk, (new_start, new_end) in zip(chunks, adjusted_positions):
            if new_start is not None and new_end is not None:
                prepared_data.append({
                    'input_ids': chunk,
                    'answer_start': new_start,
                    'answer_end': new_end,
                })

    return prepared_data

In [None]:
prepared_data = prepare_data_with_sliding_window(df)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

In [None]:
class QADataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        input_ids = torch.tensor(item['input_ids'], dtype=torch.long)
        length = len(input_ids)  # Length of the sequence
        start_positions = torch.tensor(item['answer_start'], dtype=torch.long)
        end_positions = torch.tensor(item['answer_end'], dtype=torch.long)

        # Debugging: print the structure of the returned data
        # print(f"__getitem__ returned: input_ids={input_ids.shape}, length={length}, start_positions={start_positions}, end_positions={end_positions}")

        return input_ids, length, start_positions, end_positions

In [None]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    input_ids, lengths, start_positions, end_positions = zip(*batch)

    # Pad the input sequences
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)

    # Create attention masks
    attention_masks = (input_ids_padded != tokenizer.pad_token_id).long()

    # Convert start_positions and end_positions to tensors
    start_positions = torch.tensor(start_positions)
    end_positions = torch.tensor(end_positions)

    return input_ids_padded, attention_masks, start_positions, end_positions

In [None]:
qa_dataset = QADataset(prepared_data)

In [None]:
qa_dataset[0]

(tensor([  101,   101,  2129,  2064,  2019, 30522,  6187, 21285,  4948, 30522,
          2022,  3908,  1010,  1998,  2054,  6433,  2000,  1996,  2300,  2503,
          1029,   102,   102,   101, 30522,  6187, 21285,  4948, 30522, 30522,
          3796, 30522,  2907,  2300,  6187, 21285,  4948,  3067,  2224,  4060,
          8528,  2063,  3067,  2302,  4060,  8528,  2063,  4530,  2498, 30522,
          6187, 21285,  4948, 30522,  3908,  2300,  2503,  4558, 30522,  6187,
         21285,  4948, 30522,  7477,  3707, 13749,  4140,  2309,  4064, 30522,
          6187, 21285,  4948, 30522,  9699, 30522,  6965, 30522, 12570,  6039,
         30522,  6187, 21285,  4948, 30522,  2300,  2811,  2224, 30522,  6187,
         21285,  4948, 30522,  2300, 13610, 30522,  6187, 21285,  4948, 30522,
          2036,  3382,  6039,  2300,  4542,  2588,  2300, 30522,  6187, 21285,
          4948, 30522,  2224,  6039, 30522,  3221, 30522,  5835,  2735,  2300,
          5835,  9378, 18554, 30522,  5898, 30522, 3

In [None]:
dataloader = DataLoader(qa_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [None]:
import torch
from torch.optim import AdamW
from transformers import BertForQuestionAnswering

accumulation_steps = 2  # For example, to mimic a batch size twice as large as currently fits in memory
# Load your model
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')

model.resize_token_embeddings(len(tokenizer))

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define the optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Number of training epochs
num_epochs = 10

# Training loop
model.train()
optimizer.zero_grad()  # Zero out gradients initially

# for epoch in range(num_epochs):
#     total_loss = 0
#     for step, batch in enumerate(dataloader):
#         input_ids, _, start_positions, end_positions = [b.to(device) for b in batch]

#         # Forward pass
#         outputs = model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)
#         loss = outputs.loss / accumulation_steps  # Scale loss
#         total_loss += loss.item()

#         # Backward pass
#         loss.backward()

#         # Perform an optimization step every 'accumulation_steps' mini-batches
#         if (step + 1) % accumulation_steps == 0:
#             optimizer.step()
#             optimizer.zero_grad()  # Zero gradients after optimization

#     print(f"Epoch {epoch + 1}: Loss = {total_loss / len(dataloader)}")
for epoch in range(num_epochs):
    total_loss = 0
    for step, batch in enumerate(dataloader):
        # Unpack the batch data and move to the correct device
        input_ids, attention_masks, start_positions, end_positions = [b.to(device) for b in batch]

        # Forward pass: pass attention_mask to the model as well
        outputs = model(input_ids=input_ids, attention_mask=attention_masks, start_positions=start_positions, end_positions=end_positions)

        # Scale loss to account for gradient accumulation
        loss = outputs.loss / accumulation_steps
        total_loss += loss.item()

        # Backward pass
        loss.backward()

        # Perform an optimization step every 'accumulation_steps' mini-batches
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()  # Zero gradients after optimization

    print(f"Epoch {epoch + 1}: Loss = {total_loss / len(dataloader)}")


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1: Loss = 2.9047160064964963
Epoch 2: Loss = 2.599499451486688
Epoch 3: Loss = 2.4237980612537315
Epoch 4: Loss = 2.0326933170619763
Epoch 5: Loss = 1.454059386462496
Epoch 6: Loss = 0.9901170087488074
Epoch 7: Loss = 0.6028671160078886
Epoch 8: Loss = 0.4659499442368223
Epoch 9: Loss = 0.3861115012775388
Epoch 10: Loss = 0.2892358948133494


In [None]:
model_save_path = './saved_model_directory'
model.save_pretrained(model_save_path)

tokenizer_save_path = './saved_tokenizer_directory'
tokenizer.save_pretrained(tokenizer_save_path)

('./saved_tokenizer_directory/tokenizer_config.json',
 './saved_tokenizer_directory/special_tokens_map.json',
 './saved_tokenizer_directory/vocab.txt',
 './saved_tokenizer_directory/added_tokens.json')

In [None]:
from transformers import BertTokenizer

# Load the trained model and tokenizer
model = BertForQuestionAnswering.from_pretrained('./saved_model_directory')
tokenizer = BertTokenizer.from_pretrained('./saved_tokenizer_directory')
model.to(device)
model.eval()

# Function to ask a question to the model
def ask_question(context, question):
    # Replace 'entity' with the special token in context and question
    context = context.replace('entity', '[ENTITY]')
    question = question.replace('entity', '[ENTITY]')

    # Tokenize the question and context
    inputs = tokenizer.encode_plus(question, context, return_tensors='pt', max_length=512, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}  # Move to the correct device

    # Model prediction
    with torch.no_grad():
        outputs = model(**inputs)

    answer_start_scores, answer_end_scores = outputs.start_logits, outputs.end_logits

    # Get the most likely start and end of answer
    answer_start = torch.argmax(answer_start_scores)
    answer_end = torch.argmax(answer_end_scores) + 1  # Add 1 for inclusive end index

    # Convert token indices to actual tokens
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    answer = tokenizer.convert_tokens_to_string(tokens[answer_start:answer_end])

    return answer

In [None]:
context = "entity cauldron entity entity block entity hold water cauldron mine use pickaxe mine without pickaxe drop nothing entity cauldron entity destroyed water inside lose entity cauldron entity craft iron ingot single empty entity cauldron entity generate entity witch entity hut fill entity cauldron entity water press use entity cauldron entity water bucket entity cauldron entity also chance fill water rain upon water entity cauldron entity use fill entity glass entity bottle turn water bottle wash dye entity leather entity entity armor entity remove top pattern layer banner use entity cauldron entity press use entity cauldron entity entity glass entity bottle entity leather entity entity armor entity banner entity cauldron entity extinguish mob entity fire entity include player fall use include extinguish mob cause water level entity cauldron entity decrease one third use three time empty must refill additional us endermen fill entity cauldron entity take damage water entity cauldron entity cannot use fill empty bucket water bottle cannot use refill entity cauldron entity entity cauldron entity fill water bucket nether entity cauldron entity act power source redstone comparator entity cauldron entity behind possibly separate unpowered solid entity block entity comparator output signal strength proportional full entity cauldron entity 0 empty 1 one third full 2 two third full 3 completely full entity cauldron entity fullness define entity block entity data future plan replace entity block entity data entity block entity state arrow entity stick entity water entity cauldron entity try sneak entity cauldron entity still fall inside entity cauldron entity 0 3125 5 16 entity block entity tall entity cauldron entity contains water broken water show crack entity block entity entity item entity thrown entity cauldron entity entity hopper entity underneath entity hopper entity receive entity item entity potion brew beta 1 9 pre release 2 enable mod cauldron look three different level water empty comparison naturally occur entity cauldron entity inside entity witch entity hut example entity cauldron entity use redstone circuit"
question = "How can an entity cauldron entity be destroyed, and what happens to the water inside?"
predicted_answer = ask_question(context, question)
print("Q:", question)
print("A:", predicted_answer)

Q: How can an entity cauldron entity be destroyed, and what happens to the water inside?
A: ##ron [ENTITY] also chance fill water rain upon water [ENTITY] cauldron [ENTITY] use fill [ENTITY] glass [ENTITY] bottle turn water bottle wash dye [ENTITY] leather [ENTITY] [ENTITY] armor [ENTITY] remove top pattern layer banner use [ENTITY] cauldron [ENTITY] press use [ENTITY] cauldron [ENTITY] [ENTITY] glass [ENTITY] bottle [ENTITY] leather [ENTITY] [ENTITY] armor [ENTITY] banner [ENTITY] cauldron [ENTITY] extinguish mob [ENTITY] fire [ENTITY] include player fall use include extinguish mob cause water level [ENTITY] cauldron [ENTITY] decrease one third use three time empty must refill additional
