In [None]:
!pip install transformers



In [5]:
import os
import json
import torch
import numpy as np
import pandas as pd
from torch import nn
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    AdamW,
    get_linear_schedule_with_warmup
)

In [None]:
import requests
def download_squad_v2():
    url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json"

    print("Downloading SQuAD v2.0 dataset...")
    response = requests.get(url)
    os.makedirs('data', exist_ok=True)
    with open('data/train-v2.0.json', 'wb') as f:
        f.write(response.content)

    print("Download complete. File saved to data/train-v2.0.json")
    return 'data/train-v2.0.json'
squad_train_path = download_squad_v2()

In [None]:
def load_squad_v2(path_to_file):
    """Load SQuAD v2 dataset from file"""
    with open(path_to_file, 'r', encoding='utf-8') as f:
        squad_dict = json.load(f)
    contexts = []
    questions = []
    answers = []
    is_impossible = []

    for group in squad_dict['data']:
        for passage in group['paragraphs']:
            context = passage['context']

            for qa in passage['qas']:
                question = qa['question']
                if not qa['is_impossible']:
                    for answer in qa['answers']:
                        contexts.append(context)
                        questions.append(question)
                        answers.append({
                            'text': answer['text'],
                            'answer_start': answer['answer_start']
                        })
                        is_impossible.append(False)
                else:
                    contexts.append(context)
                    questions.append(question)
                    answers.append({
                        'text': '',
                        'answer_start': -1
                    })
                    is_impossible.append(True)

    return contexts, questions, answers, is_impossible
squad_train_path="/kaggle/input/squad-v2/train-v2.0.json"
contexts, questions, answers, is_impossible = load_squad_v2(squad_train_path)

print(f"Total examples: {len(contexts)}")
print(f"Answerable questions: {sum(not imp for imp in is_impossible)}")
print(f"Unanswerable questions: {sum(is_impossible)}")
max_samples = 15000
if len(contexts) > max_samples:
    indices = list(range(len(contexts)))
    train_indices, _ = train_test_split(
        indices,
        test_size=len(indices) - max_samples,
        stratify=is_impossible,
        random_state=42
    )

    contexts = [contexts[i] for i in train_indices]
    questions = [questions[i] for i in train_indices]
    answers = [answers[i] for i in train_indices]
    is_impossible = [is_impossible[i] for i in train_indices]
train_contexts, val_contexts, train_questions, val_questions, train_answers, val_answers = train_test_split(
    contexts, questions, answers, test_size=0.1, random_state=42
)

print(f"Train samples: {len(train_contexts)}")
print(f"Validation samples: {len(val_contexts)}")

Total examples: 130319
Answerable questions: 86821
Unanswerable questions: 43498
Train samples: 13500
Validation samples: 1500


In [None]:
class SQuADDataset(Dataset):
    def __init__(self, contexts, questions, answers, tokenizer, max_length=384):
        self.contexts = contexts
        self.questions = questions
        self.answers = answers
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.contexts)

    def __getitem__(self, idx):
        context = self.contexts[idx]
        question = self.questions[idx]
        answer = self.answers[idx]
        inputs = self.tokenizer(
            question,
            context,
            max_length=self.max_length,
            truncation="only_second",
            stride=128,
            return_overflowing_tokens=False,
            return_offsets_mapping=True,
            padding="max_length",
            return_tensors="pt"
        )
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        offset_mapping = inputs.pop("offset_mapping").tolist()
        start_positions = 0
        end_positions = 0

        if answer['text'] and answer['answer_start'] >= 0:
            answer_text = answer['text']
            start_char = answer['answer_start']
            end_char = start_char + len(answer_text)
            token_start_index = 0
            token_end_index = 0

            for i, offset in enumerate(offset_mapping):
                if offset[0] <= start_char < offset[1]:
                    token_start_index = i
                if offset[0] < end_char <= offset[1]:
                    token_end_index = i
                    break

            start_positions = token_start_index
            end_positions = token_end_index

        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "start_positions": torch.tensor(start_positions, dtype=torch.long),
            "end_positions": torch.tensor(end_positions, dtype=torch.long)
        }

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
import matplotlib.pyplot as plt

def exact_match_score(predictions, references):
    assert len(predictions) == len(references), "Lists must have the same length"
    matches = sum(p == r for p, r in zip(predictions, references))
    return matches / len(references) * 100

def train_epoch(model, dataloader, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    
    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        start_positions = batch["start_positions"].to(device)
        end_positions = batch["end_positions"].to(device)
        optimizer.zero_grad()

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            start_positions=start_positions,
            end_positions=end_positions
        )
        if isinstance(outputs, dict) and "loss" in outputs:
            loss = outputs["loss"]
        else:
            loss = outputs.loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    return avg_loss

def evaluate(model, dataloader, tokenizer, device):
    model.eval()
    predictions = []
    references = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            
            if isinstance(outputs, dict) and "start_positions" in outputs:
                start_indices = outputs["start_positions"]
                end_indices = outputs["end_positions"]
            else:
                start_logits = outputs.start_logits
                end_logits = outputs.end_logits
                start_indices = torch.argmax(start_logits, dim=1)
                end_indices = torch.argmax(end_logits, dim=1)
            
            for i in range(input_ids.size(0)):
                input_id = input_ids[i].tolist()
                start_idx = start_indices[i].item()
                end_idx = end_indices[i].item()
                
                if end_idx < start_idx:
                    end_idx = start_idx
                
                if start_idx > 0 and end_idx > 0:
                    answer = tokenizer.decode(input_id[start_idx:end_idx+1], skip_special_tokens=True)
                else:
                    answer = ""
                predictions.append(answer)
                if "start_positions" in batch and "end_positions" in batch:
                    start_ref = batch["start_positions"][i].item()
                    end_ref = batch["end_positions"][i].item()
                    
                    if start_ref > 0 and end_ref > 0:
                        reference = tokenizer.decode(input_id[start_ref:end_ref+1], skip_special_tokens=True)
                    else:
                        reference = ""  # No answer (unanswerable question)!!!
                    
                    references.append(reference)
    if references:
        em_score = exact_match_score(predictions, references)
        return em_score, predictions
    else:
        return None, predictions

def train_model(model_type, train_dataset, val_dataset, tokenizer, num_epochs=6, batch_size=8, learning_rate=1e-5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    
    if model_type == "spanbert":
        model = SpanBERTModel()
    elif model_type == "spanbert-crf":
        model = SpanBERTCRFModel()
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    
    model.to(device)
    
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=0,
        num_training_steps=total_steps
    )
    
    train_losses = []
    val_scores = []
    best_score = 0
    
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        
        train_loss = train_epoch(model, train_loader, optimizer, scheduler, device)
        train_losses.append(train_loss)
        print(f"Training loss: {train_loss:.4f}")
        
        val_score, _ = evaluate(model, val_loader, tokenizer, device)
        val_scores.append(val_score)
        print(f"Validation EM score: {val_score:.2f}%")

        if val_score > best_score:
            best_score = val_score
            torch.save(model.state_dict(), f"best_{model_type}_model.pt")
            print(f"New best model saved with EM score: {best_score:.2f}%")
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(train_losses)
    plt.title(f"{model_type} Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    
    plt.subplot(1, 2, 2)
    plt.plot(val_scores)
    plt.title(f"{model_type} Validation EM Score")
    plt.xlabel("Epoch")
    plt.ylabel("EM Score (%)")
    
    plt.tight_layout()
    plt.savefig(f"{model_type}_training_plots.png")
    plt.show()
    
    return model, train_losses, val_scores, best_scoreA

In [None]:
# Train SpanBERT-CRF
spanbert_crf_model, spanbert_crf_train_losses, spanbert_crf_val_scores, spanbert_crf_best_score = train_model(
    "spanbert-crf", train_dataset, val_dataset, tokenizer, num_epochs=10
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoModelForQuestionAnswering

class CRF(nn.Module):
    def __init__(self, tagset_size, batch_first=True):
        super(CRF, self).__init__()
        self.tagset_size = tagset_size
        self.batch_first = batch_first
        self.transitions = nn.Parameter(torch.randn(tagset_size, tagset_size))

        self.transitions.data[:, 0] = -10000.0
        self.transitions.data[1, :] = -10000.0
        if tagset_size >= 5:
            self.transitions.data[2, 4] = -5000.0
    
    def forward(self, emissions, tags, mask=None):
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.bool)
            
        gold_score = self._score_sentence(emissions, tags, mask)
        forward_score = self._forward_alg(emissions, mask)
        
        # Negative log-likelihood loss: sum over batch!!!!!!
        return torch.mean(forward_score - gold_score)
    
    def _score_sentence(self, emissions, tags, mask):
        batch_size, seq_length = tags.size()
        score = torch.zeros(batch_size, device=emissions.device)
        if not self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        for i in range(seq_length):
            valid_mask = mask[:, i]
            emit_scores = emissions[:, i].gather(1, tags[:, i].unsqueeze(1)).squeeze(1)
            score += emit_scores * valid_mask.float()
            if i > 0:
                trans_scores = self.transitions[tags[:, i-1], tags[:, i]]
                score += trans_scores * valid_mask.float()
                
        return score
    
    def _forward_alg(self, emissions, mask):
        batch_size, seq_length, tagset_size = emissions.size()
        if not self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)
        alpha = torch.full((batch_size, tagset_size), -10000.0, device=emissions.device)
        alpha[:, 0] = 0
        
        for i in range(seq_length):
            valid_mask = mask[:, i].unsqueeze(1).float()
            emit_scores = emissions[:, i]
            broadcast_alpha = alpha.unsqueeze(2)
            broadcast_transitions = self.transitions.unsqueeze(0)
            next_tag_var = broadcast_alpha + broadcast_transitions
            next_tag_var = torch.logsumexp(next_tag_var, dim=1)
            updated_alpha = next_tag_var + emit_scores
            alpha = alpha * (1 - valid_mask) + updated_alpha * valid_mask
            
        terminal_var = alpha + self.transitions[:, 1].unsqueeze(0)
        terminal_var = torch.logsumexp(terminal_var, dim=1)
        
        return terminal_var
    
    def decode(self, emissions, mask=None):
        if mask is None:
            mask = torch.ones(emissions.size(0), emissions.size(1), dtype=torch.bool, device=emissions.device)
            
        batch_size, seq_length, tagset_size = emissions.size()
        viterbi_scores = torch.full((batch_size, tagset_size), -10000.0, device=emissions.device)
        viterbi_scores[:, 0] = 0 
        
        backpointers = torch.zeros(batch_size, seq_length, tagset_size, dtype=torch.long, device=emissions.device)
        for i in range(seq_length):
            valid_mask = mask[:, i].unsqueeze(1).float()
            next_viterbi_scores = torch.full((batch_size, tagset_size), -10000.0, device=emissions.device)
            
            for j in range(tagset_size):
                next_tag_scores = viterbi_scores + self.transitions[:, j]
                best_prev_tag = torch.argmax(next_tag_scores, dim=1)
                next_viterbi_scores[:, j] = next_tag_scores.gather(1, best_prev_tag.unsqueeze(1)).squeeze(1)
                backpointers[:, i, j] = best_prev_tag
                
            # Add emission scores and apply mask!!!!
            next_viterbi_scores += emissions[:, i]
            viterbi_scores = viterbi_scores * (1 - valid_mask) + next_viterbi_scores * valid_mask
            
        # Transition to END_TAG!!!
        final_scores = viterbi_scores + self.transitions[:, 1].unsqueeze(0)
        best_final_tag = torch.argmax(final_scores, dim=1)
        
        # Backtrack to get the best path!!!
        best_path_scores = final_scores.gather(1, best_final_tag.unsqueeze(1)).squeeze(1)
        best_paths = torch.zeros(batch_size, seq_length, dtype=torch.long, device=emissions.device)
        best_paths[:, -1] = best_final_tag
        
        for i in range(seq_length - 1, 0, -1):
            best_tags_at_i = best_paths[:, i]
            for batch_idx in range(batch_size):
                best_tag = best_tags_at_i[batch_idx].item()
                best_paths[batch_idx, i-1] = backpointers[batch_idx, i, best_tag]
            
        return best_paths, best_path_scores


class ImprovedSpanBERTCRF(nn.Module):
    def __init__(self, model_name="SpanBERT/spanbert-base-cased", freeze_bert_layers=6):
        super(ImprovedSpanBERTCRF, self).__init__()
        self.spanbert = AutoModel.from_pretrained(model_name)
        self.hidden_dim = self.spanbert.config.hidden_size
        if freeze_bert_layers > 0:
            modules = [self.spanbert.embeddings, *self.spanbert.encoder.layer[:freeze_bert_layers]]
            for module in modules:
                for param in module.parameters():
                    param.requires_grad = False

        self.tagset_size = 5
        self.start_classifier = nn.Linear(self.hidden_dim, self.tagset_size)
        self.end_classifier = nn.Linear(self.hidden_dim, self.tagset_size)
        
        self.answerable_classifier = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(self.hidden_dim//2, 1)
        )
        
        self.output_dropout = nn.Dropout(0.2)
        self.start_crf = CRF(self.tagset_size)
        self.end_crf = CRF(self.tagset_size)
        
        self.max_span_length = 30
    
    def forward(self, input_ids, attention_mask, token_type_ids=None, start_positions=None, end_positions=None):
        outputs = self.spanbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )
        sequence_output = outputs.last_hidden_state 
        pooled_output = sequence_output[:, 0]
        sequence_output = self.output_dropout(sequence_output)
        start_logits = self.start_classifier(sequence_output)  
        end_logits = self.end_classifier(sequence_output)
        
        answerable_logits = self.answerable_classifier(pooled_output).squeeze(-1)
        if start_positions is not None and end_positions is not None:
            batch_size, seq_len = input_ids.size()
            start_tags = torch.full((batch_size, seq_len), 2, dtype=torch.long, device=input_ids.device)
            end_tags = torch.full((batch_size, seq_len), 2, dtype=torch.long, device=input_ids.device)
            
            has_answer = ((start_positions > 0) & (end_positions > 0)).float()
        
            for i in range(batch_size):
                if start_positions[i] > 0:
                    start_tags[i, start_positions[i]] = 3
                    for j in range(start_positions[i] + 1, end_positions[i] + 1):
                        if j < seq_len:
                            start_tags[i, j] = 4
                
                if end_positions[i] > 0:
                    end_tags[i, end_positions[i]] = 3
                    for j in range(start_positions[i], end_positions[i]):
                        if j < seq_len:
                            end_tags[i, j] = 4
            
            start_crf_loss = self.start_crf(start_logits, start_tags, attention_mask)
            end_crf_loss = self.end_crf(end_logits, end_tags, attention_mask)
            answerable_loss = F.binary_cross_entropy_with_logits(answerable_logits, has_answer)
            total_loss = start_crf_loss + end_crf_loss + 2.0 * answerable_loss
            
            return {"loss": total_loss, "answerable_loss": answerable_loss}
        else:
            answerable_probs = torch.sigmoid(answerable_logits)
            is_answerable = answerable_probs > 0.5
        
            start_paths, _ = self.start_crf.decode(start_logits, attention_mask)
            end_paths, _ = self.end_crf.decode(end_logits, attention_mask)
            batch_size = input_ids.size(0)
            start_indices = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)
            end_indices = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)
            
            for i in range(batch_size):
                if not is_answerable[i]:
                    continue
                start_b_positions = (start_paths[i] == 3).nonzero(as_tuple=True)[0]
                start_i_positions = (start_paths[i] == 4).nonzero(as_tuple=True)[0]
                
                end_b_positions = (end_paths[i] == 3).nonzero(as_tuple=True)[0]
                end_i_positions = (end_paths[i] == 4).nonzero(as_tuple=True)[0]
                
                if len(start_b_positions) > 0:
                    start_idx = start_b_positions[0].item()
                elif len(start_i_positions) > 0:
                    start_idx = start_i_positions[0].item()
                else:
                    continue
                if len(end_b_positions) > 0:
                    end_candidates = end_b_positions[end_b_positions >= start_idx]
                    if len(end_candidates) > 0:
                        end_idx = end_candidates[-1].item()
                    else:
                        continue
                elif len(end_i_positions) > 0:
                    end_candidates = end_i_positions[end_i_positions >= start_idx]
                    if len(end_candidates) > 0:
                        end_idx = end_candidates[-1].item()
                    else:
                        continue
                else:
        
                    continue
                
                if end_idx - start_idx > self.max_span_length:
                    end_idx = start_idx + self.max_span_length
    
                start_indices[i] = start_idx
                end_indices[i] = end_idx
            return {
                "start_logits": start_logits,
                "end_logits": end_logits,
                "start_positions": start_indices,
                "end_positions": end_indices,
                "answerable_probs": answerable_probs
            }


def train_improved_model(train_dataset, val_dataset, tokenizer, num_epochs=10, batch_size=8, learning_rate=2e-5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader = torch.utils.data.DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, 
        batch_size=batch_size
    )
    model = ImprovedSpanBERTCRF()
    model.to(device)
    bert_params = list(model.spanbert.parameters())
    crf_params = list(model.start_crf.parameters()) + list(model.end_crf.parameters())
    classifier_params = list(model.start_classifier.parameters()) + \
                        list(model.end_classifier.parameters()) + \
                        list(model.answerable_classifier.parameters())
    
    optimizer = torch.optim.AdamW([
        {'params': bert_params, 'lr': learning_rate},
        {'params': crf_params, 'lr': learning_rate * 5},  # Higher LR for CRF!!!!
        {'params': classifier_params, 'lr': learning_rate * 2}  # Higher LR for classifiers!!!!!
    ], weight_decay=0.01)
    
    # Learning rate scheduler with warmup!!!!!!
    from transformers import get_cosine_schedule_with_warmup
    
    num_warmup_steps = len(train_loader) 
    num_training_steps = len(train_loader) * num_epochs
    
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )
    train_losses = []
    val_scores = []
    best_score = 0
    
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        model.train()
        epoch_loss = 0
        
        for batch in tqdm(train_loader, desc="Training"):
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                token_type_ids=batch.get("token_type_ids"),
                start_positions=batch["start_positions"],
                end_positions=batch["end_positions"]
            )
            loss = outputs["loss"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_loss)
        print(f"Training loss: {avg_loss:.4f}")
        model.eval()
        all_predictions = []
        all_references = []
        
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Evaluating"):
                batch = {k: v.to(device) for k, v in batch.items()}
                
                outputs = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    token_type_ids=batch.get("token_type_ids")
                )
                
                start_indices = outputs["start_positions"]
                end_indices = outputs["end_positions"]
                answerable_probs = outputs["answerable_probs"]
                for i in range(batch["input_ids"].size(0)):
                    input_id = batch["input_ids"][i].tolist()
                    start_idx = start_indices[i].item()
                    end_idx = end_indices[i].item()
                    is_answerable = answerable_probs[i].item() > 0.5
                    if not is_answerable or start_idx == 0 or end_idx == 0:
                        answer = ""
                    else:
                        if end_idx < start_idx:
                            end_idx = start_idx
                        answer = tokenizer.decode(input_id[start_idx:end_idx+1], skip_special_tokens=True)
                    
                    all_predictions.append(answer)
                    start_ref = batch["start_positions"][i].item()
                    end_ref = batch["end_positions"][i].item()
                    
                    if start_ref > 0 and end_ref > 0: 
                        reference = tokenizer.decode(input_id[start_ref:end_ref+1], skip_special_tokens=True)
                    else:
                        reference = ""  
                    
                    all_references.append(reference)
            for i in range(min(5, len(all_predictions))):
                print(f"Example {i+1}:")
                print(f"  Prediction: '{all_predictions[i]}'")
                print(f"  Reference:  '{all_references[i]}'")
                print(f"  Match:      {all_predictions[i] == all_references[i]}")
                print()
            em_score = exact_match_score(all_predictions, all_references)
            val_scores.append(em_score)
            print(f"Validation EM score: {em_score:.2f}%")
            if em_score > best_score:
                best_score = em_score
                torch.save(model.state_dict(), "best_improved_spanbert_crf_model.pt")
                print(f"New best model saved with EM score: {best_score:.2f}%")
    import matplotlib.pyplot as plt
    
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(train_losses)
    plt.title("Improved SpanBERT-CRF Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    
    plt.subplot(1, 2, 2)
    plt.plot(val_scores)
    plt.title("Improved SpanBERT-CRF Validation EM Score")
    plt.xlabel("Epoch")
    plt.ylabel("EM Score (%)")
    
    plt.tight_layout()
    plt.savefig("improved_spanbert_crf_training_plots.png")
    plt.show()
    
    return model, train_losses, val_scores, best_score

In [None]:

model, losses, scores, best_score = train_improved_model(
    train_dataset, 
    val_dataset, 
    tokenizer, 
    num_epochs=10,
    batch_size=8, 
    learning_rate=2e-4
)

Some weights of BertModel were not initialized from the model checkpoint at SpanBERT/spanbert-base-cased and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/10


Training: 100%|██████████| 1688/1688 [30:53<00:00,  1.10s/it]


Training loss: 2849.8808


Evaluating: 100%|██████████| 188/188 [02:32<00:00,  1.24it/s]


Example 1:
  Prediction: ''
  Reference:  'raze'
  Match:      False

Example 2:
  Prediction: 'mediterranean theater'
  Reference:  'mediterranean'
  Match:      False

Example 3:
  Prediction: ''
  Reference:  'warmth, companionship, and even protection'
  Match:      False

Example 4:
  Prediction: ''
  Reference:  'estonian volunteers in finland'
  Match:      False

Example 5:
  Prediction: ''
  Reference:  'david shackleton'
  Match:      False

Validation EM score: 35.40%
New best model saved with EM score: 35.40%
Epoch 2/10


Training: 100%|██████████| 1688/1688 [30:49<00:00,  1.10s/it]


Training loss: 2779.1069


Evaluating: 100%|██████████| 188/188 [02:32<00:00,  1.24it/s]


Example 1:
  Prediction: 'raze'
  Reference:  'raze'
  Match:      True

Example 2:
  Prediction: ''
  Reference:  'mediterranean'
  Match:      False

Example 3:
  Prediction: ''
  Reference:  'warmth, companionship, and even protection'
  Match:      False

Example 4:
  Prediction: ''
  Reference:  'estonian volunteers in finland'
  Match:      False

Example 5:
  Prediction: 'david shackleton'
  Reference:  'david shackleton'
  Match:      True

Validation EM score: 39.53%
New best model saved with EM score: 39.53%
Epoch 3/10


Training: 100%|██████████| 1688/1688 [30:47<00:00,  1.09s/it]


Training loss: 2699.9993


Evaluating: 100%|██████████| 188/188 [02:32<00:00,  1.24it/s]


Example 1:
  Prediction: 'raze'
  Reference:  'raze'
  Match:      True

Example 2:
  Prediction: 'mediterranean theater'
  Reference:  'mediterranean'
  Match:      False

Example 3:
  Prediction: ''
  Reference:  'warmth, companionship, and even protection'
  Match:      False

Example 4:
  Prediction: 'soomepoisid ) was formed out of estonian volunteers in finland'
  Reference:  'estonian volunteers in finland'
  Match:      False

Example 5:
  Prediction: 'david shackleton'
  Reference:  'david shackleton'
  Match:      True

Validation EM score: 29.20%
Epoch 4/10


Training: 100%|██████████| 1688/1688 [30:51<00:00,  1.10s/it]


Training loss: 2602.0010


Evaluating: 100%|██████████| 188/188 [02:33<00:00,  1.23it/s]


Example 1:
  Prediction: 'raze'
  Reference:  'raze'
  Match:      True

Example 2:
  Prediction: 'mediterranean theater'
  Reference:  'mediterranean'
  Match:      False

Example 3:
  Prediction: ''
  Reference:  'warmth, companionship, and even protection'
  Match:      False

Example 4:
  Prediction: ''
  Reference:  'estonian volunteers in finland'
  Match:      False

Example 5:
  Prediction: 'david shackleton'
  Reference:  'david shackleton'
  Match:      True

Validation EM score: 39.53%
Epoch 5/10


Training: 100%|██████████| 1688/1688 [30:52<00:00,  1.10s/it]


Training loss: 2484.5961


Evaluating: 100%|██████████| 188/188 [02:31<00:00,  1.24it/s]


Example 1:
  Prediction: ''
  Reference:  'raze'
  Match:      False

Example 2:
  Prediction: 'mediterranean theater'
  Reference:  'mediterranean'
  Match:      False

Example 3:
  Prediction: ''
  Reference:  'warmth, companionship, and even protection'
  Match:      False

Example 4:
  Prediction: 'soomepoisid'
  Reference:  'estonian volunteers in finland'
  Match:      False

Example 5:
  Prediction: 'david shackleton'
  Reference:  'david shackleton'
  Match:      True

Validation EM score: 42.47%
New best model saved with EM score: 42.47%
Epoch 6/10


Training:   7%|▋         | 118/1688 [02:09<28:48,  1.10s/it]