In [1]:
import torch
import pandas as pd
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering, AdamW
from torch.utils.data import DataLoader, TensorDataset, RandomSampler
from tqdm import tqdm

In [2]:
def get_device():
    """Pick the torch device to run on: CUDA when it is available, else the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


In [3]:
# Load the pretrained 'distilbert-base-uncased' tokenizer; lower-casing is
# requested explicitly via do_lower_case.
tokenizer = DistilBertTokenizer.from_pretrained(
    'distilbert-base-uncased',
    do_lower_case=True,
)


In [4]:
# Load the training examples from a local parquet file (presumably the SQuAD
# train split, going by the file name). The preprocessing loop further down
# reads the 'question', 'context' and 'answers' columns from this frame.
df = pd.read_parquet(path='datasets/squad_train.parquet')


In [5]:
# Empty buffers for raw question / context / answer text.
# NOTE(review): none of the cells visible in this notebook read or fill these;
# confirm whether they are still needed.
questions, contexts, answers = [], [], []

In [6]:
# Fixed length every encoded (question, chunk) pair is padded/truncated to.
max_length = 512

# Accumulators filled by the preprocessing loop: one entry per kept
# (question, context-chunk) pair.
input_ids, attention_masks = [], []
start_positions, end_positions = [], []


In [7]:
# Sliding-window settings used to split each context into overlapping chunks.
# Both values are applied as offsets into the raw context *string* (i.e. they
# count characters), because the preprocessing loop slices `context` directly.
max_chunk_length = 384  # width of each window
stride = 128  # distance between consecutive window starts


In [8]:
def _find_sublist(haystack, needle):
    """Return the index where `needle` first occurs as a contiguous run in `haystack`.

    Returns -1 when `needle` is empty or does not occur.
    """
    span = len(needle)
    if span == 0:
        return -1
    for start in range(len(haystack) - span + 1):
        if haystack[start:start + span] == needle:
            return start
    return -1


# Start from empty accumulators so that re-running this cell (for example after
# an interrupted run) does not append every example a second time.
input_ids = []
attention_masks = []
start_positions = []
end_positions = []

for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing dataset"):
    question = row['question']
    context = row['context']

    answer_texts = row['answers']['text']
    if len(answer_texts) == 0:
        # No answer recorded for this row: nothing to label.
        continue
    answer = answer_texts[0]  # Assuming we take the first answer

    # Tokenize the answer separately so it can be matched against chunk tokens.
    answer_tokens = tokenizer.tokenize(answer)
    if not answer_tokens:
        continue

    # In the encoded pair the chunk tokens come after
    # [CLS] + question tokens + [SEP]; this offset is the same for every chunk
    # of the row, so compute it once.
    answer_offset = len(tokenizer.tokenize(question)) + 2

    # Split context into overlapping character windows (sliding window approach).
    for i in range(0, len(context), stride):
        chunk = context[i:i + max_chunk_length]
        chunk_tokens = tokenizer.tokenize(chunk)

        # Require the *whole* answer token sequence to appear in the chunk;
        # matching only its first token would label unrelated spans.
        match_index = _find_sublist(chunk_tokens, answer_tokens)
        if match_index < 0:
            continue

        token_start_index = match_index + answer_offset
        token_end_index = token_start_index + len(answer_tokens) - 1

        # The last slot of a full-length sequence holds the trailing [SEP], so
        # any label beyond max_length - 2 refers to a token that was truncated.
        if token_end_index > max_length - 2:
            continue

        # Encode only chunks that are actually kept.
        encoding = tokenizer.encode_plus(
            question,
            chunk,
            add_special_tokens=True,
            max_length=max_length,
            truncation='only_second',
            return_tensors='pt',
            padding='max_length',
            return_attention_mask=True
        )

        input_ids.append(encoding['input_ids'].squeeze())
        attention_masks.append(encoding['attention_mask'].squeeze())
        start_positions.append(token_start_index)
        end_positions.append(token_end_index)


Processing dataset:  68%|██████████████████████████████████████                  | 59577/87599 [14:17<06:43, 69.49it/s]


KeyboardInterrupt: 

In [None]:
# Turn the per-example Python lists into batched tensors: the encoded sequences
# are stacked along a new leading dimension, the answer positions become 1-D
# tensors. The list names are rebound to the resulting tensors.
input_ids, attention_masks = (
    torch.stack(sequences) for sequences in (input_ids, attention_masks)
)
start_positions, end_positions = (
    torch.tensor(positions) for positions in (start_positions, end_positions)
)

In [None]:
# Wrap the four aligned tensors in a dataset and draw batches from it in random
# order.
batch_size = 16
dataset = TensorDataset(input_ids, attention_masks, start_positions, end_positions)
train_dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    sampler=RandomSampler(dataset),
)

In [None]:
# Choose the compute device, load the pretrained DistilBERT weights with a
# question-answering head, and move the model onto that device.
device = get_device()
model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')
model.to(device)


In [None]:
# Use PyTorch's own AdamW: the `AdamW` class exported by `transformers` is
# deprecated in favour of `torch.optim.AdamW`.
# `eps` and `weight_decay` are pinned to the defaults of the transformers
# implementation (1e-6 and 0.0) so the hyperparameters stay what they were;
# PyTorch's own defaults are 1e-8 and 1e-2.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-5,
    eps=1e-6,
    weight_decay=0.0,
)


In [None]:
# Fine-tune for a fixed number of epochs, saving the weights after each one.
# If a forward pass runs out of CUDA memory, training falls back to the CPU for
# the rest of the run; any other RuntimeError is re-raised.
epochs = 3
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    model.train()

    for batch in tqdm(train_dataloader, desc="Training"):
        b_input_ids = batch[0].to(device)
        b_attention_mask = batch[1].to(device)
        b_start_positions = batch[2].to(device)
        b_end_positions = batch[3].to(device)

        optimizer.zero_grad()

        try:
            outputs = model(
                input_ids=b_input_ids,
                attention_mask=b_attention_mask,
                start_positions=b_start_positions,
                end_positions=b_end_positions
            )
        except RuntimeError as e:
            if 'CUDA out of memory' not in str(e):
                # Not an OOM: do not swallow it, otherwise `outputs` below would
                # be undefined or left over from the previous batch.
                raise

            print("CUDA out of memory. Switching to CPU...")
            device = torch.device('cpu')
            model.to(device)

            # The optimizer keeps per-parameter state tensors that were created
            # on the old device; move them too so optimizer.step() does not hit
            # a device mismatch.
            for param_state in optimizer.state.values():
                for state_key, state_value in param_state.items():
                    if torch.is_tensor(state_value):
                        param_state[state_key] = state_value.to(device)

            # The current batch was already moved to CUDA above; bring it back.
            b_input_ids = b_input_ids.to(device)
            b_attention_mask = b_attention_mask.to(device)
            b_start_positions = b_start_positions.to(device)
            b_end_positions = b_end_positions.to(device)

            # Release cached GPU blocks that are no longer needed.
            torch.cuda.empty_cache()

            outputs = model(
                input_ids=b_input_ids,
                attention_mask=b_attention_mask,
                start_positions=b_start_positions,
                end_positions=b_end_positions
            )

        loss = outputs.loss
        loss.backward()
        optimizer.step()

    # Save checkpoint after each epoch
    checkpoint_path = f"checkpoint-epoch-{epoch + 1}.pt"
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

print("Training complete!")


In [None]:
# Write the model's state dict (weights only, not the full model object) to
# disk and report where it went.
final_model_path = "distilbert_qa_finetuned.pt"
torch.save(obj=model.state_dict(), f=final_model_path)
print(f"Model saved at {final_model_path}")