In [1]:
pip install datasets wandb snorkel


Collecting wandb
  Downloading wandb-0.22.2-py3-none-win_amd64.whl (19.1 MB)
Collecting snorkel
  Downloading snorkel-0.9.9-py3-none-any.whl (103 kB)
Collecting eval-type-backport
  Downloading eval_type_backport-0.2.2-py3-none-any.whl (5.8 kB)
Collecting sentry-sdk>=2.0.0
  Downloading sentry_sdk-2.40.0-py2.py3-none-any.whl (374 kB)
Collecting urllib3<3,>=1.21.1
  Using cached urllib3-2.5.0-py3-none-any.whl (129 kB)
Installing collected packages: urllib3, sentry-sdk, eval-type-backport, wandb, snorkel
  Attempting uninstall: urllib3
    Found existing installation: urllib3 1.26.9
    Uninstalling urllib3-1.26.9:
      Successfully uninstalled urllib3-1.26.9
Successfully installed eval-type-backport-0.2.2 sentry-sdk-2.40.0 snorkel-0.9.9 urllib3-2.5.0 wandb-0.22.2
Note: you may need to restart the kernel to use updated packages.


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
conda-repo-cli 1.0.4 requires pathlib, which is not installed.
anaconda-project 0.10.2 requires ruamel-yaml, which is not installed.
google-api-core 1.25.1 requires google-auth<2.0dev,>=1.21.1, but you have google-auth 2.38.0 which is incompatible.
botocore 1.24.32 requires urllib3<1.27,>=1.25.4, but you have urllib3 2.5.0 which is incompatible.


In [3]:

#  Import libraries
from datasets import load_dataset
import wandb
import pandas as pd
from collections import Counter


# Initialize W&B project

wandb.init(project="Q1-weak-supervision-ner", name="Dataset_Stats")


# Load the CoNLL-2003 dataset

dataset = load_dataset("eriktks/conll2003")

# Print available splits
print(dataset)


# Check data sample

print(dataset["train"][0])


# Compute dataset statistics

# Total samples in each split
train_size = len(dataset["train"])
valid_size = len(dataset["validation"])
test_size = len(dataset["test"])

# Entity label names
label_list = dataset["train"].features["ner_tags"].feature.names
print("Entity labels:", label_list)

# Function to count entity occurrences
def count_entities(split):
    counts = Counter()
    for sample in dataset[split]:
        for tag in sample["ner_tags"]:
            label = label_list[tag]
            if label != "O":  # ignore non-entity tokens
                counts[label] += 1
    return counts

train_entities = count_entities("train")
valid_entities = count_entities("validation")
test_entities = count_entities("test")

# Combine into DataFrame for visualization
df_stats = pd.DataFrame({
    "Split": ["Train", "Validation", "Test"],
    "Samples": [train_size, valid_size, test_size]
})
print(df_stats)

print("\nEntity distribution (train):", train_entities)


# Log metrics to W&B

wandb.log({
    "train_samples": train_size,
    "valid_samples": valid_size,
    "test_samples": test_size,
    "entity_distribution_train": dict(train_entities)
})

# Add dataset statistics as summary metrics
wandb.summary["train_samples"] = train_size
wandb.summary["valid_samples"] = valid_size
wandb.summary["test_samples"] = test_size
wandb.summary["entity_distribution_train"] = dict(train_entities)

wandb.finish()


DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})
{'id': '0', 'tokens': ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7], 'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0], 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}
Entity labels: ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
        Split  Samples
0       Train    14041
1  Validation     3250
2        Test     3453

Entity distribution (train): Counter({'B-LOC': 7140, 'B-PER': 6600, 'B-ORG': 6321, 'I-PER': 4528, 'I-ORG': 3704, 'B-MISC': 3438, 'I-LOC': 1157, 'I-MISC': 1155})


0,1
test_samples,▁
train_samples,▁
valid_samples,▁

0,1
test_samples,3453
train_samples,14041
valid_samples,3250


In [24]:

# Step 2:Labeling Functions with Snorkel
import re
import pandas as pd
import wandb
from snorkel.labeling import labeling_function, PandasLFApplier, LFAnalysis
from datasets import load_dataset


# Initialize W&B

wandb.init(project="Q1-weak-supervision-ner", name="Labeling_Functions_Eval")


# Load dataset

dataset = load_dataset("eriktks/conll2003")
df = pd.DataFrame(dataset['train'])
df['text'] = [' '.join(tokens) for tokens in df['tokens']]


# Define Labeling Functions


# LF 1: Detect years (DATE/MISC)
@labeling_function()
def lf_year(x):
    if re.search(r'\b(19|20)\d{2}\b', x.text):
        return 7  # B-MISC label
    return -1  # Abstain

# LF 2: Detect organizations by common suffixes
@labeling_function()
def lf_org_suffix(x):
    if re.search(r'\b(Inc\.|Corp\.|Ltd\.|LLC)\b', x.text):
        return 3  
    return -1  

# LF 3: Detect locations by common suffixes (optional, improves aggregation)
@labeling_function()
def lf_loc_suffix(x):
    if re.search(r"\b(city|town|village|river|mountain)\b", x.text, re.IGNORECASE):
        return 5  
    return -1  

# List of LFs
lfs = [lf_year, lf_org_suffix, lf_loc_suffix]


# Apply LFs to dataset

applier = PandasLFApplier(lfs=lfs)
L_train = applier.apply(df)  # shape: (num_sentences, num_LFs)


# Flatten labels for token-level analysis

# Flatten L_train and gold labels to token-level
L_tokens = []
gold_labels = []

for i, row in df.iterrows():
    for j, token in enumerate(row['tokens']):
        # Each LF applied on sentence text → propagate to all tokens
        L_tokens.append(L_train[i])
        gold_labels.append(row['ner_tags'][j])

import numpy as np
L_tokens = np.array(L_tokens)
gold_labels = np.array(gold_labels)


# Analyze LF coverage and accuracy

analysis = LFAnalysis(L=L_tokens, lfs=lfs).lf_summary(Y=gold_labels)
print(analysis)


# Log coverage, overlaps, conflicts, accuracy to W&B

for lf_name, row in analysis.iterrows():
    wandb.log({
        f"{lf_name}_coverage": row["Coverage"],
        f"{lf_name}_overlaps": row["Overlaps"],
        f"{lf_name}_conflicts": row["Conflicts"],
        f"{lf_name}_accuracy": row["Emp. Acc."]
    })

wandb.finish()


100%|██████████████████████████████████████████████████████████████████████████| 14041/14041 [00:01<00:00, 7295.16it/s]


               j Polarity  Coverage  Overlaps  Conflicts  Correct  Incorrect  \
lf_year        0      [7]  0.073322   0.00195    0.00195      279      14651   
lf_org_suffix  1       []  0.000000   0.00000    0.00000        0          0   
lf_loc_suffix  2      [5]  0.027846   0.00195    0.00195      266       5404   

               Emp. Acc.  
lf_year         0.018687  
lf_org_suffix   0.000000  
lf_loc_suffix   0.046914  


0,1
lf_loc_suffix_accuracy,▁
lf_loc_suffix_conflicts,▁
lf_loc_suffix_coverage,▁
lf_loc_suffix_overlaps,▁
lf_org_suffix_accuracy,▁
lf_org_suffix_conflicts,▁
lf_org_suffix_coverage,▁
lf_org_suffix_overlaps,▁
lf_year_accuracy,▁
lf_year_conflicts,▁

0,1
lf_loc_suffix_accuracy,0.04691
lf_loc_suffix_conflicts,0.00195
lf_loc_suffix_coverage,0.02785
lf_loc_suffix_overlaps,0.00195
lf_org_suffix_accuracy,0
lf_org_suffix_conflicts,0
lf_org_suffix_coverage,0
lf_org_suffix_overlaps,0
lf_year_accuracy,0.01869
lf_year_conflicts,0.00195


In [25]:

# Step 3: Label Aggregation using LabelModel

import numpy as np
import wandb
from snorkel.labeling.model import LabelModel


# Initialize W&B

wandb.init(project="Q1-weak-supervision-ner", name="LabelModel_Aggregation")


# Fit LabelModel on token-level LF matrix

label_model = LabelModel(cardinality=8,  # number of NER classes including 'O'
                         verbose=True)

label_model.fit(
    L_train=L_tokens,      # token-level LF matrix
    n_epochs=500,
    log_freq=100,
    seed=42
)


# Predict aggregated labels

y_agg = label_model.predict(L=L_tokens)

# Compute aggregated coverage

coverage = np.mean(y_agg != -1)
print(f"Aggregated label coverage: {coverage:.4f}")


# Log coverage to W&B

wandb.log({"labelmodel_coverage": coverage})

wandb.finish()


INFO:root:Computing O...
INFO:root:Estimating \mu...
  0%|                                                                                       | 0/500 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=0.003]
 16%|████████████                                                                 | 78/500 [00:00<00:01, 362.54epoch/s]INFO:root:[100 epochs]: TRAIN:[loss=0.000]
 37%|████████████████████████████                                                | 185/500 [00:00<00:00, 341.74epoch/s]INFO:root:[200 epochs]: TRAIN:[loss=0.000]
 58%|███████████████████████████████████████████▉                                | 289/500 [00:00<00:00, 328.78epoch/s]INFO:root:[300 epochs]: TRAIN:[loss=0.000]
 79%|███████████████████████████████████████████████████████████▉                | 394/500 [00:01<00:00, 341.41epoch/s]INFO:root:[400 epochs]: TRAIN:[loss=0.000]
100%|████████████████████████████████████████████████████████████████████████████| 500/500 [00:01<00:00, 332.38epoch/s]
INFO:root:Finished 

Aggregated label coverage: 0.0992


0,1
labelmodel_coverage,▁

0,1
labelmodel_coverage,0.09922


In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import wandb
from torch.utils.data import Subset
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
subset_size = 5000  # number of samples for each dataset

# Data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# CIFAR-10 dataset subset
trainset_10 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
subset_indices_10 = random.sample(range(len(trainset_10)), subset_size)
trainloader_10 = torch.utils.data.DataLoader(Subset(trainset_10, subset_indices_10), batch_size=128, shuffle=True)

# CIFAR-100 dataset subset
trainset_100 = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
subset_indices_100 = random.sample(range(len(trainset_100)), subset_size)
trainloader_100 = torch.utils.data.DataLoader(Subset(trainset_100, subset_indices_100), batch_size=128, shuffle=True)

# Define a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.classifier = nn.Linear(64*8*8, num_classes)
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

# Training function
def train_model(model, trainloader, epochs=10, dataset_name="Dataset"):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for epoch in range(epochs):
        running_loss, correct, total = 0.0, 0, 0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        
        epoch_loss = running_loss / total
        epoch_acc = correct / total
        print(f"{dataset_name} Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}")
        wandb.log({f"{dataset_name}_loss": epoch_loss, f"{dataset_name}_accuracy": epoch_acc})

# Experiment A: CIFAR-100 then CIFAR-10
wandb.init(project="CIFAR_subset_sequential", name="CIFAR100_then_CIFAR10")
model = SimpleCNN(num_classes=100).to(device)
train_model(model, trainloader_100, epochs=10, dataset_name="CIFAR100")
model.classifier = nn.Linear(64*8*8, 10).to(device)
train_model(model, trainloader_10, epochs=10, dataset_name="CIFAR10")
wandb.finish()

# Experiment B: CIFAR-10 then CIFAR-100
wandb.init(project="CIFAR_subset_sequential", name="CIFAR10_then_CIFAR100")
model = SimpleCNN(num_classes=10).to(device)
train_model(model, trainloader_10, epochs=10, dataset_name="CIFAR10")
model.classifier = nn.Linear(64*8*8, 100).to(device)
train_model(model, trainloader_100, epochs=10, dataset_name="CIFAR100")
wandb.finish()


CIFAR100 Epoch 1/10 - Loss: 4.3377, Acc: 0.0536
CIFAR100 Epoch 2/10 - Loss: 3.6759, Acc: 0.1470
CIFAR100 Epoch 3/10 - Loss: 3.2743, Acc: 0.2254
CIFAR100 Epoch 4/10 - Loss: 2.9061, Acc: 0.3036
CIFAR100 Epoch 5/10 - Loss: 2.5334, Acc: 0.3796
CIFAR100 Epoch 6/10 - Loss: 2.1894, Acc: 0.4526
CIFAR100 Epoch 7/10 - Loss: 1.8500, Acc: 0.5258
CIFAR100 Epoch 8/10 - Loss: 1.5103, Acc: 0.6158
CIFAR100 Epoch 9/10 - Loss: 1.2310, Acc: 0.6896
CIFAR100 Epoch 10/10 - Loss: 0.9467, Acc: 0.7684
CIFAR10 Epoch 1/10 - Loss: 1.7339, Acc: 0.3922
CIFAR10 Epoch 2/10 - Loss: 1.3210, Acc: 0.5432
CIFAR10 Epoch 3/10 - Loss: 1.1664, Acc: 0.5894
CIFAR10 Epoch 4/10 - Loss: 1.0526, Acc: 0.6368
CIFAR10 Epoch 5/10 - Loss: 0.9457, Acc: 0.6788
CIFAR10 Epoch 6/10 - Loss: 0.9033, Acc: 0.6940
CIFAR10 Epoch 7/10 - Loss: 0.8014, Acc: 0.7290
CIFAR10 Epoch 8/10 - Loss: 0.7433, Acc: 0.7450
CIFAR10 Epoch 9/10 - Loss: 0.6702, Acc: 0.7826
CIFAR10 Epoch 10/10 - Loss: 0.6043, Acc: 0.7996


0,1
CIFAR100_accuracy,▁▂▃▃▄▅▆▇▇█
CIFAR100_loss,█▇▆▅▄▄▃▂▂▁
CIFAR10_accuracy,▁▄▄▅▆▆▇▇██
CIFAR10_loss,█▅▄▄▃▃▂▂▁▁

0,1
CIFAR100_accuracy,0.7684
CIFAR100_loss,0.94665
CIFAR10_accuracy,0.7996
CIFAR10_loss,0.60434


CIFAR10 Epoch 1/10 - Loss: 1.9817, Acc: 0.2890
CIFAR10 Epoch 2/10 - Loss: 1.6411, Acc: 0.4238
CIFAR10 Epoch 3/10 - Loss: 1.4929, Acc: 0.4834
CIFAR10 Epoch 4/10 - Loss: 1.3551, Acc: 0.5234
CIFAR10 Epoch 5/10 - Loss: 1.2629, Acc: 0.5558
CIFAR10 Epoch 6/10 - Loss: 1.1777, Acc: 0.5942
CIFAR10 Epoch 7/10 - Loss: 1.1276, Acc: 0.6160
CIFAR10 Epoch 8/10 - Loss: 1.0664, Acc: 0.6340
CIFAR10 Epoch 9/10 - Loss: 1.0003, Acc: 0.6588
CIFAR10 Epoch 10/10 - Loss: 0.9353, Acc: 0.6856
CIFAR100 Epoch 1/10 - Loss: 4.0653, Acc: 0.1008
CIFAR100 Epoch 2/10 - Loss: 2.8901, Acc: 0.3138
CIFAR100 Epoch 3/10 - Loss: 2.1914, Acc: 0.4564
CIFAR100 Epoch 4/10 - Loss: 1.6443, Acc: 0.5810
CIFAR100 Epoch 5/10 - Loss: 1.1663, Acc: 0.7082
CIFAR100 Epoch 6/10 - Loss: 0.7881, Acc: 0.8078
CIFAR100 Epoch 7/10 - Loss: 0.5066, Acc: 0.8822
CIFAR100 Epoch 8/10 - Loss: 0.3464, Acc: 0.9306
CIFAR100 Epoch 9/10 - Loss: 0.1840, Acc: 0.9700
CIFAR100 Epoch 10/10 - Loss: 0.1197, Acc: 0.9830


0,1
CIFAR100_accuracy,▁▃▄▅▆▇▇███
CIFAR100_loss,█▆▅▄▃▂▂▁▁▁
CIFAR10_accuracy,▁▃▄▅▆▆▇▇██
CIFAR10_loss,█▆▅▄▃▃▂▂▁▁

0,1
CIFAR100_accuracy,0.983
CIFAR100_loss,0.11973
CIFAR10_accuracy,0.6856
CIFAR10_loss,0.93534
