In [150]:
import torch
from transformers import TrainingArguments, Trainer, default_data_collator
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
from datasets import Dataset, load_metric
import json
import pandas as pd
import numpy as np
import accelerate
from sklearn.model_selection import train_test_split

In [151]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS is available. Using MPS device.")
else:
    device = torch.device("cpu")
    print("MPS device not found. Using CPU.")

MPS is available. Using MPS device.


In [152]:
# Tokenizer and model initialization for DistilBERT
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad")


In [153]:
# Read context data and questions
context_data_files = [
    "../NLP Processing/after_scraping/Context-Data/fine-tuning-traveltriangle-goa.json",
    "../NLP Processing/after_scraping/Context-Data/fine-tuning-traveltriangle-japan.json",
    "../NLP Processing/after_scraping/Context-Data/fine-tuning-traveltriangle-vietnam.json"
]
dataset_files = [
    "../NLP Processing/after_scraping/four_qns/fine-tuning-dataset-traveltriangle-goa.json",
    "../NLP Processing/after_scraping/four_qns/fine-tuning-dataset-traveltriangle-japan.json",
    "../NLP Processing/after_scraping/four_qns/fine-tuning-dataset-traveltriangle-vietnam.json"
]

contexts = []
questions_dataset = []
answers_text = []
answers_start = []

# Load context data
context_data = {}
for i, file_path in enumerate(context_data_files):
    with open(file_path, "r") as file:
        context_data[i] = json.load(file)

# Define questions
questions = [
    "What is the name of the attraction?",
    "What is the location of the attraction?",
    "Describe the attraction in detail.",
    "What type of attraction is it? (e.g. historical, natural, amusement, beach)"
]

# Read dataset files
for i, file_path in enumerate(dataset_files):
    with open(file_path, "r") as file:
        dataset = json.load(file)
        for entry in dataset:
            id = entry['context_index']
            for question in questions:
                if question == entry['question'] and str(id) in context_data[i].keys():
                    contexts.append(context_data[i][str(id)])
                    questions_dataset.append(entry["question"])
                    answers_text.append(entry["answer"])
                    answers_start.append(0)

# Create DataFrame
df = pd.DataFrame({
    'context': contexts,
    'question': questions_dataset,
    'answers_text': answers_text,
    'answers_start': answers_start
})

print(df.head())

def tokenize_function(examples):
    tokenized_input = tokenizer(
        examples['context'], 
        examples['question'],
        truncation=True, 
        padding='max_length',  # Ensure all examples are padded to max length
        max_length=256  # Adjust max_length as per your model's requirements
    )

    tokenized_input["start_positions"] = examples["answers_start"]
    tokenized_input["answers_text"] = examples["answers_text"]

    if len(tokenized_input.input_ids) == 0:
       pass
    return tokenized_input


# Map tokenization function to dataset
train, test = train_test_split(df, test_size=0.2)
print(train.shape)
print(test.shape)
train_dataset = Dataset.from_pandas(train).map(tokenize_function, batched=True)
test_dataset = Dataset.from_pandas(test).map(tokenize_function, batched=True)

print("Train dataset type:", type(train_dataset))
print("Test dataset type:", type(test_dataset))

print(train_dataset['input_ids'][0])
print(train_dataset['question'][0])
print(train_dataset['start_positions'][0])
print(train_dataset['answers_text'][0])

                                             context  \
0   Aguada Fort: Beautiful Ambiance  Image Source...   
1   Aguada Fort: Beautiful Ambiance  Image Source...   
2   Aguada Fort: Beautiful Ambiance  Image Source...   
3   Aguada Fort: Beautiful Ambiance  Image Source...   
4   Chapora Fort: For Selfie Lovers  Image Source...   

                                            question  \
0                What is the name of the attraction?   
1            What is the location of the attraction?   
2                 Describe the attraction in detail.   
3  What type of attraction is it? (e.g. historica...   
4                What is the name of the attraction?   

                                        answers_text  answers_start  
0                                        Aguada Fort              0  
1  Fort Aguada Rd, Aguada Fort Area, Candolim, Go...              0  
2  Sightseeing in Goa is incomplete without a vis...              0  
3                                         Hist


[ABe aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence

Train dataset type: <class 'datasets.arrow_dataset.Dataset'>
Test dataset type: <class 'datasets.arrow_dataset.Dataset'>
[101, 24113, 1516, 10889, 8702, 3746, 3120, 2023, 2103, 3658, 2090, 4057, 20996, 22426, 1998, 1996, 2712, 2348, 2045, 1521, 1055, 1037, 2843, 2000, 2156, 1998, 3325, 2182, 2021, 24113, 2003, 7687, 2124, 2005, 2049, 12090, 12486, 3568, 1010, 13063, 1037, 15890, 1999, 2028, 1997, 2049, 7884, 2030, 23812, 2003, 1037, 2442, 999, 1996, 2103, 3310, 1999, 1996, 2126, 2043, 2017, 1521, 2128, 8932, 2013, 13000, 2000, 20168, 3081, 7960, 3345, 2065, 2017, 2215, 2000, 7409, 1996, 11084, 1997, 2023, 2103, 1010, 2017, 2323, 2562, 2070, 4469, 2051, 1999, 2192, 2043, 8932, 2090, 1996, 2048, 3655, 2023, 2003, 2426, 1996, 2190, 3655, 2000, 3942, 1999, 2900, 999, 2327, 13051, 1024, 11333, 27052, 11149, 2226, 2380, 10488, 2863, 2006, 5054, 24113, 1051, 4478, 9201, 2190, 2477, 2000, 2079, 1024, 2175, 2005, 1037, 12257, 3328, 2006, 9875, 6182, 11928, 4801, 2080, 2958, 3046, 1996, 8040, 68

In [154]:
# Training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

# Load metric
metric = load_metric("squad")

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    compute_metrics=metric.compute
)




In [155]:
# Train the model
def compute_loss(model, inputs):
    outputs = model(**inputs)
    
    # Extract necessary outputs
    start_logits = outputs.start_logits
    
    # Get start positions from inputs
    start_positions = inputs.get("start_positions")
    
    # Compute the CrossEntropy loss for start positions
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
    start_loss = loss_fct(start_logits.view(-1, 2), start_positions.view(-1))
    
    return start_loss
print("Train dataset batch size:", len(train_dataset))
print("Test dataset batch size:", len(test_dataset))

trainer.compute_loss = compute_loss
trainer.train()


# Save the fine-tuned model
trainer.save_model("fine-tuned-distilbert-model")


  0%|          | 0/81 [00:41<?, ?it/s]


Train dataset batch size: 428
Test dataset batch size: 108


  0%|          | 0/81 [00:00<?, ?it/s]

ValueError: Expected input batch_size (2048) to match target batch_size (16).