In [1]:
import re
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
from datasets import Dataset
from sklearn.metrics import precision_recall_fscore_support

In [2]:
# Load the datasets

job_postings_df = pd.read_csv('../datasets/job_postings/computing_desc_job_posting.csv')  # Adjust filename as needed



print("\nJob postings dataset structure:")
print(job_postings_df.head())
print(job_postings_df.columns)


Job postings dataset structure:
      job_id        company_name                        title  \
0   11009123   PGAV Destinations            project architect   
1   69333422     Staffing Theory    product marketing manager   
2  133130219                 NaN            software engineer   
3  175485704                GOYT            software engineer   
4  266825034  Recruitment Design  software support specialist   

                                         description  max_salary pay_period  \
0  PGAV Destinations is seeking a self-motivated ...         NaN        NaN   
1  A leading pharmaceutical company committed to ...         NaN        NaN   
2  Education Bachelor's degree in software, math,...         NaN        NaN   
3  Job Description:GOYT is seeking a skilled and ...         NaN        NaN   
4  Are you driven by the thrill of solving proble...     65000.0     YEARLY   

                        location  company_id  views  med_salary  ...  \
0                   St Louis,

In [3]:
# Load the datasets

resume_df = pd.read_csv('combined_resume_data.csv')  # Adjust filename as needed



print("\nResume dataset structure:")
print(resume_df.head())
print(resume_df.columns)


Resume dataset structure:
                                            raw_text                    role  \
0  C:\Workspace\java\scrape_indeed\dba_part_1\1.h...  Database Administrator   
1  C:\Workspace\java\scrape_indeed\dba_part_1\10....  Database Administrator   
2  C:\Workspace\java\scrape_indeed\dba_part_1\100...  Database Administrator   
3  C:\Workspace\java\scrape_indeed\dba_part_1\100...  Database Administrator   
4  C:\Workspace\java\scrape_indeed\dba_part_1\100...  Database Administrator   

                                          clean_text  \
0  database administrator database administrator ...   
1  database administrator sql microsoft powerpoin...   
2  oracle database administrator oracle database ...   
3  amazon redshift administrator and etl develope...   
4  scrum master oracle database administrator scr...   

                  source_file  
0  Database_Administrator.csv  
1  Database_Administrator.csv  
2  Database_Administrator.csv  
3  Database_Administrator.c

In [4]:
# Add resume_id as primary key
resume_df.reset_index(inplace=True)
resume_df.rename(columns={'index': 'resume_id'}, inplace=True)

In [5]:
ENTITY_TYPES = [
    "PL",      # Programming Language
    "FW",      # Framework
    "DB",      # Database
    "CP",      # Cloud Platform
    "DO",      # DevOps
    "NS",      # Network & Security
    "DAS",     # Data Analysis & Science
    "SWE",     # Software Engineering
    "PM",      # Project Management
    "EC",      # Education Certification
    "SS"       # Soft Skills
]

In [6]:
keyword_mappings = {
        "PL": ["python", "java", "javascript", "c++", "c#", "go", "rust", "php", "ruby", "swift", 
               "typescript", "kotlin", "scala", "perl", "bash", "powershell", "r", "matlab", "dart", "lua"],
        
        "FW": ["react", "angular", "vue", "django", "flask", "spring", "laravel", "express", "symfony", 
               "bootstrap", "jquery", "rails", "asp.net", "node.js", "next.js", "flutter", "xamarin", "qt"],
        
        "DB": ["sql", "mysql", "postgresql", "mongodb", "oracle", "cassandra", "redis", "dynamodb", 
               "sqlite", "mariadb", "couchdb", "neo4j", "firebase", "elasticsearch", "snowflake", "bigtable"],
        
        "CP": ["aws", "azure", "gcp", "google cloud", "heroku", "digitalocean", "alibaba", "ibm cloud", 
               "oracle cloud", "salesforce", "s3", "ec2", "lambda", "kubernetes", "eks", "gke", "aks"],
        
        "DO": ["docker", "kubernetes", "jenkins", "gitlab", "github actions", "circleci", "terraform", 
               "ansible", "puppet", "chef", "prometheus", "grafana", "elk", "ci/cd", "devops", "sre"],
        
        "NS": ["firewall", "vpn", "encryption", "ssl", "tls", "authentication", "authorization", 
               "oauth", "jwt", "kerberos", "ldap", "penetration testing", "security", "cybersecurity"],
        
        "DAS": ["machine learning", "data mining", "pandas", "numpy", "tensorflow", "pytorch", "big data", 
                "hadoop", "spark", "tableau", "power bi", "data science", "statistics", "ai", "nlp"],
        
        "SWE": ["agile", "scrum", "object-oriented", "oop", "design patterns", "architecture", 
                "microservices", "rest api", "soap", "uml", "solid", "tdd", "bdd", "clean code"],
        
        "PM": ["jira", "confluence", "trello", "asana", "scrum master", "product owner", "kanban", 
               "project management", "pmp", "agile", "waterfall", "lean", "sprint planning"],
        
        "EC": ["bachelor", "master", "phd", "certificate", "certification", "aws certified", "cisco", 
               "comptia", "microsoft certified", "google certified", "pmp", "itil", "cisa", "cissp"],
        
        "SS": ["teamwork", "communication", "leadership", "problem-solving", "creativity", "collaboration", 
               "time management", "adaptability", "critical thinking", "interpersonal", "presentation"]
    }

In [7]:

def preprocess_text(text):
    """Clean and prepare text for NER"""
    if not isinstance(text, str):
        return ""
    
    # Remove extra whitespace and normalize text
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def extract_entities(text, ner_pipeline):
    """Extract entities from text using BERT NER and keyword matching"""
    # Define keyword lists for each entity type
    keyword_mappings = {
        "PL": ["python", "java", "javascript", "c++", "c#", "go", "rust", "php", "ruby", "swift", 
               "typescript", "kotlin", "scala", "perl", "bash", "powershell", "r", "matlab", "dart", "lua"],
        
        "FW": ["react", "angular", "vue", "django", "flask", "spring", "laravel", "express", "symfony", 
               "bootstrap", "jquery", "rails", "asp.net", "node.js", "next.js", "flutter", "xamarin", "qt"],
        
        "DB": ["sql", "mysql", "postgresql", "mongodb", "oracle", "cassandra", "redis", "dynamodb", 
               "sqlite", "mariadb", "couchdb", "neo4j", "firebase", "elasticsearch", "snowflake", "bigtable"],
        
        "CP": ["aws", "azure", "gcp", "google cloud", "heroku", "digitalocean", "alibaba", "ibm cloud", 
               "oracle cloud", "salesforce", "s3", "ec2", "lambda", "kubernetes", "eks", "gke", "aks"],
        
        "DO": ["docker", "kubernetes", "jenkins", "gitlab", "github actions", "circleci", "terraform", 
               "ansible", "puppet", "chef", "prometheus", "grafana", "elk", "ci/cd", "devops", "sre"],
        
        "NS": ["firewall", "vpn", "encryption", "ssl", "tls", "authentication", "authorization", 
               "oauth", "jwt", "kerberos", "ldap", "penetration testing", "security", "cybersecurity"],
        
        "DAS": ["machine learning", "data mining", "pandas", "numpy", "tensorflow", "pytorch", "big data", 
                "hadoop", "spark", "tableau", "power bi", "data science", "statistics", "ai", "nlp"],
        
        "SWE": ["agile", "scrum", "object-oriented", "oop", "design patterns", "architecture", 
                "microservices", "rest api", "soap", "uml", "solid", "tdd", "bdd", "clean code"],
        
        "PM": ["jira", "confluence", "trello", "asana", "scrum master", "product owner", "kanban", 
               "project management", "pmp", "agile", "waterfall", "lean", "sprint planning"],
        
        "EC": ["bachelor", "master", "phd", "certificate", "certification", "aws certified", "cisco", 
               "comptia", "microsoft certified", "google certified", "pmp", "itil", "cisa", "cissp"],
        
        "SS": ["teamwork", "communication", "leadership", "problem-solving", "creativity", "collaboration", 
               "time management", "adaptability", "critical thinking", "interpersonal", "presentation"]
    }
    
    # Initialize results dictionary with empty lists for all entity types
    organized_entities = {entity_type: [] for entity_type in ENTITY_TYPES}
    
    # Handle empty or invalid input
    if not isinstance(text, str) or not text.strip():
        return organized_entities
    
    # Preprocess text
    processed_text = preprocess_text(text)
    if not processed_text:
        return organized_entities
    
    try:
        # Apply keyword-based extraction (simpler and more reliable for technical terms)
        lower_text = processed_text.lower()
        for category, keywords in keyword_mappings.items():
            for keyword in keywords:
                # Use word boundary check to avoid partial matches
                pattern = r'\b' + re.escape(keyword) + r'\b'
                if re.search(pattern, lower_text):
                    organized_entities[category].append(keyword)
        
        # Apply NER model for entities that might be missed by keywords
        max_length = 512  # BERT token limit
        words = processed_text.split()
        
        # Process text in chunks if needed
        if len(words) > max_length - 10:
            chunks = []
            for i in range(0, len(words), max_length - 10):
                chunks.append(" ".join(words[i:i+max_length-10]))
        else:
            chunks = [processed_text]
        
        # Process each chunk
        for chunk in chunks:
            try:
                # Get NER results
                ner_results = ner_pipeline(chunk)
                
                # Process each entity
                for entity in ner_results:
                    # Extract the entity word and group
                    word = entity.get("word", "").strip()
                    group = entity.get("entity_group", "")
                    
                    # Skip empty words
                    if not word:
                        continue
                    
                    # Map entity groups to our categories
                    if group == "MISC":
                        # For MISC entities, check which technical category they might belong to
                        word_lower = word.lower()
                        for cat, kw_list in keyword_mappings.items():
                            if any(kw in word_lower for kw in kw_list):
                                organized_entities[cat].append(word)
                                break
                        else:
                            # If no match found, default to Programming Language
                            organized_entities["PL"].append(word)
                    
                    elif group == "ORG":
                        # Organizations could be companies or cloud platforms
                        word_lower = word.lower()
                        if any(kw in word_lower for kw in keyword_mappings["CP"]):
                            organized_entities["CP"].append(word)
                    
                    elif group == "PER" and "years" in chunk.lower()[max(0, entity.get("start", 0)-20):entity.get("end", 0)+20]:
                        # People names near "years" might indicate experience
                        organized_entities["SS"].append(word)
            
            except Exception as e:
                print(f"Error in NER processing: {e}")
                # Continue with other chunks
        
        # Remove duplicates (case-insensitive)
        for entity_type in ENTITY_TYPES:
            seen = set()
            unique_list = []
            for item in organized_entities[entity_type]:
                item_lower = item.lower()
                if item_lower not in seen:
                    seen.add(item_lower)
                    unique_list.append(item)
            organized_entities[entity_type] = unique_list
            
    except Exception as e:
        print(f"Error in extract_entities: {e}")
        import traceback
        traceback.print_exc()
    
    return organized_entities

In [8]:
gt =  pd.read_csv("./final_merged.csv")

In [9]:
# Filter to only use rows marked as ground truth (GT = 1)
ground_truth_df = gt[gt['GT'] == 1]


In [10]:
ground_truth_df.size

6643

In [11]:
def determine_entity_type(skill):
    """
    Maps a given skill to one of the predefined entity types.
    Returns "MISC" if no matching category is found.
    
    Parameters:
    skill (str): The skill to categorize
    
    Returns:
    str: Entity type code (PL, FW, DB, etc.) or "MISC" for unmatched skills
    """
    if not skill:
        return "MISC"
        
    skill_lower = skill.lower()
    
    keyword_mappings = {
        "PL": ["python", "java", "javascript", "c++", "c#", "go", "rust", "php", "ruby", "swift",
               "typescript", "kotlin", "scala", "perl", "bash", "powershell", "r", "matlab", "dart", "lua"],
        
        "FW": ["react", "angular", "vue", "django", "flask", "spring", "laravel", "express", "symfony",
               "bootstrap", "jquery", "rails", "asp.net", "node.js", "next.js", "flutter", "xamarin", "qt"],
        
        "DB": ["sql", "mysql", "postgresql", "mongodb", "oracle", "cassandra", "redis", "dynamodb",
               "sqlite", "mariadb", "couchdb", "neo4j", "firebase", "elasticsearch", "snowflake", "bigtable"],
        
        "CP": ["aws", "azure", "gcp", "google cloud", "heroku", "digitalocean", "alibaba", "ibm cloud",
               "oracle cloud", "salesforce", "s3", "ec2", "lambda", "kubernetes", "eks", "gke", "aks"],
        
        "DO": ["docker", "kubernetes", "jenkins", "gitlab", "github actions", "circleci", "terraform",
               "ansible", "puppet", "chef", "prometheus", "grafana", "elk", "ci/cd", "devops", "sre"],
        
        "NS": ["firewall", "vpn", "encryption", "ssl", "tls", "authentication", "authorization",
               "oauth", "jwt", "kerberos", "ldap", "penetration testing", "security", "cybersecurity"],
        
        "DAS": ["machine learning", "data mining", "pandas", "numpy", "tensorflow", "pytorch", "big data",
                "hadoop", "spark", "tableau", "power bi", "data science", "statistics", "ai", "nlp"],
        
        "SWE": ["agile", "scrum", "object-oriented", "oop", "design patterns", "architecture",
                "microservices", "rest api", "soap", "uml", "solid", "tdd", "bdd", "clean code"],
        
        "PM": ["jira", "confluence", "trello", "asana", "scrum master", "product owner", "kanban",
               "project management", "pmp", "agile", "waterfall", "lean", "sprint planning"],
        
        "EC": ["bachelor", "master", "phd", "certificate", "certification", "aws certified", "cisco",
               "comptia", "microsoft certified", "google certified", "pmp", "itil", "cisa", "cissp"],
        
        "SS": ["teamwork", "communication", "leadership", "problem-solving", "creativity", "collaboration",
               "time management", "adaptability", "critical thinking", "interpersonal", "presentation"]
    }
    
    # Check each category's keywords for matches
    for category, keywords in keyword_mappings.items():
        for keyword in keywords:
            # Check if the keyword is in the skill text (as a whole word or part of a compound word)
            if keyword in skill_lower:
                return category
    
    # If no match is found
    return "MISC"




In [12]:
def prepare_ner_data(df):
    data_list = []
    error_count = 0
    
    for idx, row in df.iterrows():
        try:
            # Get text, with fallback
            if 'raw_text' not in row or pd.isna(row['raw_text']):
                print(f"Warning: Missing text at index {idx}. Skipping this row.")
                continue
            
            text = str(row['raw_text'])  # Ensure text is string
            
            # Get skills list, with validation
            if 'skills' not in row:
                print(f"Warning: No 'skills' column at index {idx}. Skipping this row.")
                continue
                
            skills_list = row['skills']
            
            # Handle non-iterable skills_list (string instead of list)
            if isinstance(skills_list, str):
                # Split the comma-separated string into a list
                skills_list = [s.strip() for s in skills_list.split(',')]
                print(f"Converted string skills to list at index {idx}. Found {len(skills_list)} skills.")
            elif not isinstance(skills_list, (list, tuple, set)) or pd.isna(skills_list):
                print(f"Warning: skills_list is not iterable at index {idx}. Value: {skills_list}. Using empty list.")
                skills_list = []
            
            # Initialize annotations list
            annotations = []
            
            # For each skill, find it in the text and determine its entity type
            for skill in skills_list:
                # Skip None or NaN values
                if pd.isna(skill):
                    continue
                
                # Ensure skill is a string
                skill = str(skill)
                
                # Skip if skill is empty
                if not skill.strip():
                    continue
                
                # Escape special regex characters if present in the skill
                skill_pattern = re.escape(skill)
                
                # Find all occurrences of the skill in the text
                for match in re.finditer(r'\b' + skill_pattern + r'\b', text, re.IGNORECASE):
                    start, end = match.span()
                    skill_type = determine_entity_type(skill)
                    
                    annotations.append({
                        "start": start,
                        "end": end,
                        "label": skill_type
                    })
            
            data_list.append({
                "text": text,
                "annotations": annotations
            })
            
        except Exception as e:
            error_count += 1
            print(f"Error processing row at index {idx}: {e}")
            # Continue with next row instead of stopping
    
    print(f"Processed {len(df)} rows with {error_count} errors. Generated {len(data_list)} valid examples.")
    return data_list

In [13]:
train_df, val_df = train_test_split(ground_truth_df, test_size=0.2, random_state=42)

In [14]:
# Create data with annotations
train_data = prepare_ner_data(train_df)
val_data = prepare_ner_data(val_df)

# Verify the data structure
print(f"Training data: {len(train_data)} examples")
print(f"Validation data: {len(val_data)} examples")
if len(train_data) > 0:
    print("Sample annotation:", train_data[0]['annotations'][:2])

Converted string skills to list at index 34431. Found 48 skills.
Converted string skills to list at index 11254. Found 97 skills.
Converted string skills to list at index 22164. Found 92 skills.
Converted string skills to list at index 10538. Found 192 skills.
Converted string skills to list at index 4305. Found 66 skills.
Converted string skills to list at index 3502. Found 47 skills.
Converted string skills to list at index 28271. Found 18 skills.
Converted string skills to list at index 30884. Found 30 skills.
Converted string skills to list at index 3669. Found 36 skills.
Converted string skills to list at index 34436. Found 57 skills.
Converted string skills to list at index 29321. Found 37 skills.
Converted string skills to list at index 34406. Found 49 skills.
Converted string skills to list at index 13365. Found 64 skills.
Converted string skills to list at index 188. Found 37 skills.
Converted string skills to list at index 4206. Found 46 skills.
Converted string skills to lis

In [15]:
# Create label list from keyword_mappings
labels = ["O"]  # Start with "O" for non-entity tokens
for skill_type in keyword_mappings.keys():
    labels.append(f"B-{skill_type}")
    labels.append(f"I-{skill_type}")

# Add MISC for anything that doesn't match
labels.append("B-MISC")
labels.append("I-MISC")

# Create label mappings
label_to_id = {label: i for i, label in enumerate(labels)}
id_to_label = {i: label for i, label in enumerate(labels)}

print(f"Total labels: {len(labels)}")
print("Label examples:", labels[:10])

Total labels: 25
Label examples: ['O', 'B-PL', 'I-PL', 'B-FW', 'I-FW', 'B-DB', 'I-DB', 'B-CP', 'I-CP', 'B-DO']


In [16]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")

# Convert to HF datasets
def convert_to_hf_dataset(data_list):
    texts = [item["text"] for item in data_list]
    dataset_dict = {"text": texts, "annotations": [item["annotations"] for item in data_list]}
    return Dataset.from_dict(dataset_dict)

train_dataset = convert_to_hf_dataset(train_data)
val_dataset = convert_to_hf_dataset(val_data)

print(f"Train dataset features: {train_dataset.features}")
print(f"Train dataset size: {len(train_dataset)}")

Train dataset features: {'text': Value(dtype='string', id=None), 'annotations': [{'end': Value(dtype='int64', id=None), 'label': Value(dtype='string', id=None), 'start': Value(dtype='int64', id=None)}]}
Train dataset size: 759


In [17]:
# Tokenize and align labels function
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["text"], 
        truncation=True, 
        padding="max_length",
        max_length=256,
        return_offsets_mapping=True
    )
    
    labels = []
    for i, annotations in enumerate(examples["annotations"]):
        label_ids = [label_to_id["O"]] * len(tokenized_inputs["input_ids"][i])
        offset_mapping = tokenized_inputs["offset_mapping"][i]
        
        for annotation in annotations:
            start_char = annotation["start"]
            end_char = annotation["end"]
            label = annotation["label"]
            
            # Find tokens corresponding to this entity
            token_start_index = None
            token_end_index = None
            
            for j, (start, end) in enumerate(offset_mapping):
                if start == end:  # Skip special tokens
                    continue
                    
                if token_start_index is None and start <= start_char < end:
                    token_start_index = j
                    
                if end > 0 and start < end_char <= end:
                    token_end_index = j
            
            if token_start_index is not None and token_end_index is not None:
                # Assign B- label to first token
                label_ids[token_start_index] = label_to_id[f"B-{label}"]
                
                # Assign I- label to remaining tokens
                for j in range(token_start_index + 1, token_end_index + 1):
                    label_ids[j] = label_to_id[f"I-{label}"]
        
        # Set special tokens to -100
        for j, input_id in enumerate(tokenized_inputs["input_ids"][i]):
            if j == 0 or j == len(label_ids) - 1 or input_id == tokenizer.pad_token_id:
                label_ids[j] = -100
        
        labels.append(label_ids)
    
    # Remove offset_mapping from the features
    tokenized_inputs.pop("offset_mapping")
    tokenized_inputs["labels"] = labels
    
    return tokenized_inputs

# Apply tokenization to datasets
tokenized_train_dataset = train_dataset.map(tokenize_and_align_labels, batched=True)
tokenized_val_dataset = val_dataset.map(tokenize_and_align_labels, batched=True)

print(f"Tokenized train features: {tokenized_train_dataset.features}")

Map:   0%|          | 0/759 [00:00<?, ? examples/s]

Map:   0%|          | 0/190 [00:00<?, ? examples/s]

Tokenized train features: {'text': Value(dtype='string', id=None), 'annotations': [{'end': Value(dtype='int64', id=None), 'label': Value(dtype='string', id=None), 'start': Value(dtype='int64', id=None)}], 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}


In [18]:
model = AutoModelForTokenClassification.from_pretrained(
    "dslim/bert-base-NER", 
    num_labels=len(labels),
    id2label=id_to_label,
    label2id=label_to_id,
    ignore_mismatched_sizes=True,
    hidden_dropout_prob=0.2,    # Slightly increase dropout
    attention_probs_dropout_prob=0.2
)
# Define metrics computation
def compute_metrics(pred):
    predictions, labels = pred
    predictions = predictions.argmax(axis=2)
    
    # Only use valid predictions (not -100)
    true_predictions = [
        [id_to_label[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [id_to_label[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    
    # Flatten lists
    flat_predictions = [p for sublist in true_predictions for p in sublist]
    flat_labels = [l for sublist in true_labels for l in sublist]
    
    # Calculate metrics
    results = precision_recall_fscore_support(
        flat_labels,
        flat_predictions,
        average='weighted'
    )
    
    return {
        "precision": results[0],
        "recall": results[1],
        "f1": results[2]
    }

Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at dslim/bert-base-NER and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([9]) in the checkpoint and torch.Size([25]) in the model instantiated
- classifier.weight: found shape torch.Size([9, 768])

In [19]:
# Set up training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    report_to="none"
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [20]:
# Train the model
trainer.train()

# Save the model
model.save_pretrained("./ner_skills_model")
tokenizer.save_pretrained("./ner_skills_model")

# Evaluate the model
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

  0%|          | 0/570 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.3299342393875122, 'eval_precision': 0.898746660627385, 'eval_recall': 0.8889237134081364, 'eval_f1': 0.8922536439055734, 'eval_runtime': 36.9364, 'eval_samples_per_second': 5.144, 'eval_steps_per_second': 1.3, 'epoch': 1.0}


  0%|          | 0/48 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.28387945890426636, 'eval_precision': 0.9063697368999505, 'eval_recall': 0.9027560124532481, 'eval_f1': 0.9032618555675007, 'eval_runtime': 38.4388, 'eval_samples_per_second': 4.943, 'eval_steps_per_second': 1.249, 'epoch': 2.0}
{'loss': 0.3769, 'grad_norm': 1.1018067598342896, 'learning_rate': 2.456140350877193e-06, 'epoch': 2.63}


  0%|          | 0/48 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.28758832812309265, 'eval_precision': 0.9094298787654872, 'eval_recall': 0.9024843812031175, 'eval_f1': 0.9041541609906925, 'eval_runtime': 36.9727, 'eval_samples_per_second': 5.139, 'eval_steps_per_second': 1.298, 'epoch': 3.0}
{'train_runtime': 1875.7932, 'train_samples_per_second': 1.214, 'train_steps_per_second': 0.304, 'train_loss': 0.36332988069768535, 'epoch': 3.0}


  0%|          | 0/48 [00:00<?, ?it/s]

Evaluation results: {'eval_loss': 0.28758832812309265, 'eval_precision': 0.9094298787654872, 'eval_recall': 0.9024843812031175, 'eval_f1': 0.9041541609906925, 'eval_runtime': 37.441, 'eval_samples_per_second': 5.075, 'eval_steps_per_second': 1.282, 'epoch': 3.0}


  _warn_prf(average, modifier, msg_start, len(result))
