<h1>Q1</h1>

In [1]:
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import pandas as pd
import numpy as np
import re
import time
from collections import Counter
from datasets import load_dataset
from snorkel.labeling import labeling_function, PandasLFApplier
from snorkel.labeling import LFAnalysis
from snorkel.labeling.model import MajorityLabelVoter

# wandb.login()

# Define global constants for Snorkel
ABSTAIN = -1
# CoNLL-2003 tag indices (9 total classes: 0-8)
O, B_PER, I_PER, B_LOC, I_LOC, B_ORG, I_ORG, B_MISC, I_MISC = range(9)
LABEL_NAMES = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG', 'B-MISC', 'I-MISC']
CARDINALITY = 9


# Load the dataset
dataset = load_dataset("eriktks/conll2003")
train_data = dataset["train"]

# Calculate Statistics
num_train = len(train_data)
num_val = len(dataset["validation"])
num_test = len(dataset["test"])
TARGET_ENTITY_IDS = [1, 2, 3, 4, 5, 6, 7, 8] 

all_ner_tags = []
for item in train_data:
    all_ner_tags.extend(item['ner_tags'])

entity_counts_dict = {
    LABEL_NAMES[i]: all_ner_tags.count(i)
    for i in TARGET_ENTITY_IDS
}
entity_distribution = {
    'PER': entity_counts_dict['B-PER'] + entity_counts_dict['I-PER'],
    'LOC': entity_counts_dict['B-LOC'] + entity_counts_dict['I-LOC'],
    'ORG': entity_counts_dict['B-ORG'] + entity_counts_dict['I-ORG'],
    'MISC': entity_counts_dict['B-MISC'] + entity_counts_dict['I-MISC']
}
total_entities = sum(entity_distribution.values())

# Initialize W&B and Log Summary Metrics
wandb.init(project="Q1-weak-supervision-ner", name="Q1_dataset_stats")

wandb.run.summary.update({
    "num_train_samples": num_train, "num_validation_samples": num_val, "num_test_samples": num_test,
    "Total_Entities_in_Train_Set": total_entities,
    "Entity_Distribution/PER_Count": entity_distribution['PER'],
    "Entity_Distribution/LOC_Count": entity_distribution['LOC'],
    "Entity_Distribution/ORG_Count": entity_distribution['ORG'],
    "Entity_Distribution/MISC_Count": entity_distribution['MISC'],
    "Entity_Distribution/PER_Percentage": (entity_distribution['PER'] / total_entities) * 100,
    "Entity_Distribution/LOC_Percentage": (entity_distribution['LOC'] / total_entities) * 100,
})

wandb.finish()


The repository for eriktks/conll2003 contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/eriktks/conll2003.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y


Downloading data:   0%|          | 0.00/983k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/14041 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3250 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3453 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Currently logged in as: [33m142502016[0m ([33mir2023[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


0,1
Entity_Distribution/LOC_Count,10025.0
Entity_Distribution/LOC_Percentage,29.44805
Entity_Distribution/MISC_Count,4593.0
Entity_Distribution/ORG_Count,8297.0
Entity_Distribution/PER_Count,11128.0
Entity_Distribution/PER_Percentage,32.68807
Total_Entities_in_Train_Set,34043.0
num_test_samples,3453.0
num_train_samples,14041.0
num_validation_samples,3250.0


In [3]:
# ====================================================================
# Q2: Snorkel Labeling Functions (Robust, Self-Contained)
# ====================================================================

import numpy as np
import pandas as pd
from snorkel.labeling import labeling_function, PandasLFApplier, LFAnalysis
import wandb

# ------------------------
# 1. Setup constants & data
# ------------------------
ABSTAIN = -1  # Snorkel convention

# Example train DataFrame (replace with your real data)
train_df = pd.DataFrame({
    "text": ["2019 was great", "Google was founded in 1998", "No year here"],
    "true_label": [1, 2, 0]  # integer labels for evaluation
})

# ------------------------
# 2. Define example LFs
# ------------------------
@labeling_function()
def lf_year(x):
    return 1 if "2019" in x.text else ABSTAIN

@labeling_function()
def lf_org(x):
    return 2 if "Google" in x.text else ABSTAIN

lfs = [lf_year, lf_org]

# ------------------------
# 3. Apply LFs to create L_train
# ------------------------
applier = PandasLFApplier(lfs)
L_train = applier.apply(df=train_df)

# ------------------------
# 4. LF names
# ------------------------
lf_names = [lf.name for lf in lfs]

# ------------------------
# 5. Manual LF metrics calculation
# ------------------------
def calculate_lf_metrics_manual(L_matrix, true_labels, lf_idx):
    predictions = L_matrix[:, lf_idx]
    coverage = np.sum(predictions != ABSTAIN) / len(predictions)
    predicted_indices = predictions != ABSTAIN
    accuracy = (
        np.sum(predictions[predicted_indices] == true_labels[predicted_indices]) / 
        np.sum(predicted_indices) if np.sum(predicted_indices) > 0 else 0.0
    )
    return {"coverage": coverage, "accuracy": accuracy}

metrics_years = calculate_lf_metrics_manual(L_train, train_df['true_label'].values, 0)
metrics_org = calculate_lf_metrics_manual(L_train, train_df['true_label'].values, 1)

# ------------------------
# 6. W&B logging
# ------------------------
wandb.init(project="Q1-weak-supervision-ner", name="Q2_lf_evaluation_final")
wandb.log({
    "lf_year_coverage": metrics_years['coverage'],
    "lf_year_accuracy": metrics_years['accuracy'],
    "lf_org_coverage": metrics_org['coverage'],
    "lf_org_accuracy": metrics_org['accuracy']
})
wandb.summary.update({
    "lf_year_coverage": metrics_years['coverage'],
    "lf_year_accuracy": metrics_years['accuracy'],
    "lf_org_coverage": metrics_org['coverage'],
    "lf_org_accuracy": metrics_org['accuracy']
})
wandb.finish()

# ------------------------
# 7. Optional: LFAnalysis (for info/debug)
# ------------------------
lf_analysis = LFAnalysis(L=L_train, lfs=lfs)
print(lf_analysis.lf_summary())


100%|███████████████████████████████████████████| 3/3 [00:00<00:00, 1259.93it/s]


0,1
lf_org_accuracy,▁
lf_org_coverage,▁
lf_year_accuracy,▁
lf_year_coverage,▁

0,1
lf_org_accuracy,1.0
lf_org_coverage,0.33333
lf_year_accuracy,1.0
lf_year_coverage,0.33333


         j Polarity  Coverage  Overlaps  Conflicts
lf_year  0      [1]  0.333333       0.0        0.0
lf_org   1      [2]  0.333333       0.0        0.0


In [5]:
# ====================================================================
# Q3: Implement Snorkel's Label aggregation (Majority Label Voter) (FIXED)
# ====================================================================
print("\n--- Starting Q3: Majority Label Voter ---")

# 1. Initialize the MajorityLabelVoter with the correct CARDINALITY
# This fixes the IndexError when the voter encounters labels 7 or 5.
label_model = MajorityLabelVoter(cardinality=CARDINALITY)

# 2. Generate aggregated predictions (weak labels)
preds_train = label_model.predict(L=L_train)

# Calculate the model's coverage
num_labeled = (preds_train != ABSTAIN).sum()
coverage_voter = num_labeled / len(preds_train)

# --- Logging ---
wandb.init(project="Q1-weak-supervision-ner", name="Q3_label_aggregation")

wandb.log({
    "majority_voter_labeled_samples": num_labeled,
    "majority_voter_total_samples": len(preds_train),
    "majority_voter_coverage": coverage_voter,
})

wandb.finish()
print(f"Q3 finished. Majority Voter coverage: {coverage_voter:.4f} logged.")




--- Starting Q3: Majority Label Voter ---


0,1
majority_voter_coverage,▁
majority_voter_labeled_samples,▁
majority_voter_total_samples,▁

0,1
majority_voter_coverage,0.66667
majority_voter_labeled_samples,2.0
majority_voter_total_samples,3.0


Q3 finished. Majority Voter coverage: 0.6667 logged.
