In [1]:
#!pip install snorkel

In [2]:
import pandas as pd

In [3]:
from conll2003 import Conll2003

dataset_builder = Conll2003()
dataset_builder.download_and_prepare()
dataset = dataset_builder.as_dataset()

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import re
import wandb
import numpy as np
import pandas as pd
from datasets import load_dataset
from snorkel.labeling import labeling_function, PandasLFApplier, LFAnalysis

In [5]:
# --------------------------------------------
# Step 1: Load dataset (subset for demo)
# --------------------------------------------
# dataset = load_dataset("conll2003", trust_remote_code=True)
df = dataset["train"].to_pandas()#.head(500)  # use only first 500 samples

# Flatten into token-level dataset for simplicity
tokens, ner_tags = [], []
for row in df.itertuples(index=False):
    # print(row)
    tokens.extend(row.tokens)
    ner_tags.extend(row.ner_tags)

token_df = pd.DataFrame({"token": tokens, "label": ner_tags})

# Label mapping from the dataset
label_feature = dataset["train"].features["ner_tags"].feature
token_df["label_name"] = token_df["label"].map(label_feature.int2str)
print(f"Loaded {len(token_df)} token examples from Hugging Face CoNLL-2003")

Loaded 203621 token examples from Hugging Face CoNLL-2003


In [6]:
# --------------------------------------------
# Step 2: Define constants for labeling
# --------------------------------------------
ABSTAIN = -1
DATE = 7  # using MISC index as per CoNLL-2003
ORG = 3   # B-ORG index (depends on dataset schema)

# --------------------------------------------
# Step 3: Define labeling functions
# --------------------------------------------

@labeling_function()
def lf_detect_year(x):
    """Heuristic: detect years between 1900 and 2099."""
    return DATE if re.fullmatch(r"(19|20)\d{2}", x.token) else ABSTAIN


@labeling_function()
def lf_detect_org_suffix(x):
    """Pattern: detect organization suffixes like 'Inc.', 'Corp.', 'Ltd.', 'LLC'."""
    return ORG if re.search(r"(Inc\.|Corp\.|Ltd\.|LLC)", x.token) else ABSTAIN


# List of labeling functions
lfs = [lf_detect_year, lf_detect_org_suffix]

# --------------------------------------------
# Step 4: Apply labeling functions
# --------------------------------------------
applier = PandasLFApplier(lfs=lfs)
L_train = applier.apply(df=token_df)

# --------------------------------------------
# Step 5: Analyze results
# --------------------------------------------
analysis = LFAnalysis(L=L_train, lfs=lfs)
results = analysis.lf_summary()

print("\n=== Labeling Function Summary ===")
print(results)


100%|██████████| 203621/203621 [00:01<00:00, 158247.57it/s]



=== Labeling Function Summary ===
                      j Polarity  Coverage  Overlaps  Conflicts
lf_detect_year        0      [7]  0.002667       0.0        0.0
lf_detect_org_suffix  1      [3]  0.000108       0.0        0.0


In [7]:
# --------------------------------------------
# Step 6: Compute coverage and accuracy
# --------------------------------------------
labels = token_df["label"].to_numpy()
lf_metrics = []
for i, lf in enumerate(lfs):
    mask = L_train[:, i] != ABSTAIN
    coverage = float(mask.mean())
    accuracy = float((L_train[mask, i] == labels[mask]).mean()) if mask.any() else float("nan")
    lf_metrics.append({
        "lf_index": i,
        "lf_name": lf.name,
        "coverage": coverage,
        "accuracy": accuracy,
        "triggers": int(mask.sum())
    })

metrics_df = pd.DataFrame(lf_metrics)
print("\n=== Labeling Function Metrics ===")
print(metrics_df)


=== Labeling Function Metrics ===
   lf_index               lf_name  coverage  accuracy  triggers
0         0        lf_detect_year  0.002667  0.005525       543
1         1  lf_detect_org_suffix  0.000108  0.000000        22


In [8]:
# # --------------------------------------------
# # Step 7: Log metrics to Weights & Biases
# # --------------------------------------------
# wandb.init(project="Q1-weak-supervision-ner", name="snorkel_labeling_functions", reinit=True)

# for _, row in metrics_df.iterrows():
#     wandb.log({
#         "lf_name": row["lf_name"],
#         "coverage": row["coverage"],
#         "accuracy": row["accuracy"],
#         "triggers": row["triggers"]
#     })

# wandb.finish()

In [9]:
import matplotlib.pyplot as plt
import wandb

wandb.init(project="Q1-weak-supervision-ner", name="cov_acc")

# Example: coverage bar chart
# plt.figure(figsize=(6,4))
# plt.bar(metrics_df["lf_name"], metrics_df["coverage"])
# plt.ylabel("Coverage (%)")
# plt.title("LF Coverage")
# plt.tight_layout()
# plt.savefig("plots/q2_coverage.png")  # save locally
wandb.log({"coverage_chart": wandb.Image("plots/q2_coverage.png"),
            "accuracy_chart": wandb.Image("plots/q2_accuracy.png")})

# Similarly for accuracy
# plt.figure(figsize=(6,4))
# plt.bar(metrics_df["lf_name"], metrics_df["accuracy"])
# plt.ylabel("Accuracy (%)")
# plt.title("LF Accuracy")
# plt.tight_layout()
# plt.savefig("plots/q2_accuracy.png
wandb.finish()


[34m[1mwandb[0m: Currently logged in as: [33m142201020[0m ([33m142201020-indian-institute-of-technology-palakkad[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
#!pip install scikit-learn

In [12]:
from snorkel.labeling.model.baselines import MajorityLabelVoter
from snorkel.labeling import PandasLFApplier, LabelingFunction, LFAnalysis
import numpy as np
import wandb

# Apply LFs
applier = PandasLFApplier(lfs=lfs)
L = applier.apply(token_df)

# Cardinality = number of classes
cardinality = len(token_df['label'].unique())
majority_model = MajorityLabelVoter(cardinality=cardinality)

# Predict aggregated labels
Y_majority = majority_model.predict(L)

# Optional: compare to ground truth
from sklearn.metrics import classification_report
print(classification_report(token_df["label"], Y_majority))

# Log to W&B
wandb.init(project="Q1-weak-supervision-ner", name="majority_label_aggregation")
coverage = np.mean(Y_majority != -1)
accuracy = np.mean(Y_majority == token_df["label"])
wandb.log({"majority_label_coverage": coverage, "majority_label_accuracy": accuracy})
wandb.finish()


100%|██████████| 203621/203621 [00:01<00:00, 133134.83it/s]


              precision    recall  f1-score   support

          -1       0.00      0.00      0.00         0
           0       0.00      0.00      0.00    169578
           1       0.00      0.00      0.00      6600
           2       0.00      0.00      0.00      4528
           3       0.00      0.00      0.00      6321
           4       0.00      0.00      0.00      3704
           5       0.00      0.00      0.00      7140
           6       0.00      0.00      0.00      1157
           7       0.01      0.00      0.00      3438
           8       0.00      0.00      0.00      1155

    accuracy                           0.00    203621
   macro avg       0.00      0.00      0.00    203621
weighted avg       0.00      0.00      0.00    203621



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


0,1
majority_label_accuracy,▁
majority_label_coverage,▁

0,1
majority_label_accuracy,1e-05
majority_label_coverage,0.00277
