In [61]:
# W&B reset and environment setup (run first)
import os, sys
# Finish any active run and unload module so new settings apply
if 'wandb' in sys.modules:
    try:
        import wandb as _wandb
        if getattr(_wandb, 'run', None) is not None:
            _wandb.finish()
    except Exception:
        pass
    finally:
        try:
            del sys.modules['wandb']
        except Exception:
            pass
# Force online and set your entity so runs show up under your profile/org
os.environ['WANDB_MODE'] = 'online'
os.environ.setdefault('WANDB_PROJECT', 'Q4-mobilenet-cifar-seq')
os.environ.setdefault('WANDB_ENTITY', 'ir2023')
print('W&B reset done. Mode=online, Project=', os.environ['WANDB_PROJECT'], 'Entity=', os.environ['WANDB_ENTITY'])

W&B reset done. Mode=online, Project= Q4-mobilenet-cifar-seq Entity= ir2023-org


In [62]:
# W&B session overrides (optional): choose your project/entity for this session
import os
# Force your team project so all questions show together with Q4
os.environ['WANDB_PROJECT'] = 'Q4-mobilenet-cifar-seq'
os.environ['Q4_WANDB_PROJECT'] = os.environ['WANDB_PROJECT']
os.environ['WANDB_ENTITY'] = 'ir2023'
print('Session W&B -> Project =', os.environ['WANDB_PROJECT'], '| Q4 Project =', os.environ['Q4_WANDB_PROJECT'], '| Entity =', os.environ['WANDB_ENTITY'])

Session W&B -> Project = Q4-mobilenet-cifar-seq | Q4 Project = Q4-mobilenet-cifar-seq | Entity = ir2023


> Kernel setup: Select the Jupyter kernel named "Python (assignment5)" from the kernel picker (top-right). If you don’t see it, click the refresh icon and try again.

# Assignment-5 — Question 1: CoNLL-2003 NER stats to W&B

This notebook implements only Q1 from the assignment:
- Load HuggingFace dataset: `eriktks/conll2003`
- Initialize Weights & Biases project: `Q1-weak-supervision-ner`
- Compute and log dataset statistics as W&B summary metrics:
  - Number of samples per split
  - Entity distribution (PER, LOC, ORG, MISC) as entity spans


In [31]:
# Install dependencies (runs quickly if already satisfied)
import sys, subprocess
subprocess.run([sys.executable, '-m', 'pip', 'install', 'datasets>=2.14,<3.0', 'wandb>=0.17.0'], check=False)



CompletedProcess(args=['/home/chakri/Documents/Projects/Chicken-Disease-Classification/chicken/bin/python', '-m', 'pip', 'install', 'datasets>=2.14,<3.0', 'wandb>=0.17.0'], returncode=0)

In [32]:
# OPTIONAL: Set your W&B API key for this session (remove before sharing)
import os
# Prefer: run `wandb login` in a terminal once. This cell is only for temporary use.
WANDB_API_KEY = os.environ.get('WANDB_API_KEY', '')  # e.g., 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
if WANDB_API_KEY:
    os.environ['WANDB_API_KEY'] = WANDB_API_KEY
    print('W&B API key set for this kernel session from environment.')
else:
    print('No W&B API key found in env; wandb will prompt for login or use prior auth.')

W&B API key set for this kernel session from environment.


In [63]:
# Global W&B online config (set once)
import os
try:
    import wandb
    os.environ['WANDB_MODE'] = 'online'
    os.environ.setdefault('WANDB_PROJECT', 'Q1-weak-supervision-ner')
    os.environ.setdefault('WANDB_ENTITY', '112201022')
    try:
        wandb.login(relogin=False)
        print('W&B online mode enabled. Project =', os.environ['WANDB_PROJECT'], 'Entity =', os.environ.get('WANDB_ENTITY'))
    except Exception as e:
        print('W&B login warning:', e)
        print('Proceeding; cells will fall back to offline if needed.')
except Exception as _e:
    print('wandb not available for global setup; cells will handle fallbacks.')



W&B online mode enabled. Project = Q4-mobilenet-cifar-seq Entity = ir2023


In [34]:
# W&B helpers: online login and run bootstrap
import os

def ensure_wandb_online(project: str | None = None, entity: str | None = None, relogin: bool = False) -> bool:
    try:
        import wandb
    except Exception as _e:
        print('wandb not available; cannot enable online logging.')
        return False
    os.environ['WANDB_MODE'] = 'online'
    if project:
        os.environ['WANDB_PROJECT'] = project
    if entity:
        os.environ['WANDB_ENTITY'] = entity
    try:
        # Uses WANDB_API_KEY from env if set
        wandb.login(relogin=relogin)
    except Exception as e:
        print('W&B login warning:', e)
    # If a run already exists, we are good
    return True


def ensure_wandb_run(project: str, job_type: str, name: str | None = None, config: dict | None = None, entity: str | None = None) -> bool:
    try:
        import wandb
    except Exception:
        print('wandb not available; skipping run init.')
        return False
    # Prefer online; fall back to offline if init fails
    entity = entity or os.environ.get('WANDB_ENTITY')
    ensure_wandb_online(project=project, entity=entity)
    if wandb.run is None:
        try:
            wandb.init(project=project, job_type=job_type, name=name, entity=entity, config=config)
        except Exception as e_online:
            print('W&B online init failed; falling back to offline:', e_online)
            try:
                os.environ['WANDB_MODE'] = 'offline'
                wandb.init(project=project, job_type=job_type, name=name, entity=entity, config=config)
                print('W&B offline run started.')
            except Exception as e_off:
                print('W&B init failed entirely:', e_off)
                return False
    # Print URL when available (online)
    try:
        if getattr(wandb.run, 'url', None):
            print('W&B Run URL:', wandb.run.url)
        else:
            print('W&B run active (offline).')
    except Exception:
        pass
    return True

In [64]:
# W&B setup (robust: try online, fall back to offline; remove deprecated start_method)
import os
try:
    import wandb  # type: ignore
except Exception as _e:
    wandb = None
    print('wandb not available; logging will be skipped.')

WANDB_ENABLED = False
if wandb is not None:
    project = os.environ.get('WANDB_PROJECT', 'Q1-weak-supervision-ner')
    # Use helper to start a run if none exists
    try:
        ensure_wandb_run(
            project=project,
            job_type='dataset-stats',
            name='Assignment-5 — Question 1: CoNLL-2003 NER stats to W&B',
            entity=os.environ.get('WANDB_ENTITY', '112201022')
        )
        WANDB_ENABLED = wandb.run is not None
    except Exception as e:
        print('W&B init failed; proceeding without logging. Reason:', e)
        WANDB_ENABLED = False
else:
    print('W&B not imported; proceeding without logging.')



socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


W&B online init failed; falling back to offline: [Errno 104] Connection reset by peer


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


W&B init failed entirely: [Errno 104] Connection reset by peer


In [60]:
# Load dataset and compute stats
import datasets as hf
from collections import Counter, defaultdict

# Safe W&B summary updater
try:
    import wandb as _wandb
except Exception:
    _wandb = None
if 'WANDB_ENABLED' not in globals():
    WANDB_ENABLED = False

def wb_summary_update(payload: dict):
    try:
        if WANDB_ENABLED and _wandb is not None and _wandb.run is not None:
            _wandb.run.summary.update(payload)
            return True
    except Exception as e:
        print('W&B summary update skipped:', e)
    return False

print('Loading dataset eriktks/conll2003 ...')
ds = hf.load_dataset('eriktks/conll2003', trust_remote_code=True)

# id2label from features
id2label = {i: n for i, n in enumerate(ds['train'].features['ner_tags'].feature.names)}

# Count entity spans (B-XXX starts new, I-XXX continues; robust to malformed sequences)
def count_entity_spans(tag_ids):
    counts = Counter({'PER':0,'LOC':0,'ORG':0,'MISC':0})
    prev = 'O'
    for t in tag_ids:
        label = id2label[int(t)]
        if label == 'O':
            prev = 'O'
            continue
        prefix, cls = label.split('-', 1)
        if prefix == 'B' or (prefix == 'I' and (prev == 'O' or (prev.split('-')[1] if prev!='O' else None) != cls)):
            if cls in counts:
                counts[cls] += 1
        prev = label
    return counts

# Per-split stats
def split_stats(split):
    total_tokens = 0
    span_counts = Counter({'PER':0,'LOC':0,'ORG':0,'MISC':0})
    for ex in split:
        tags = ex['ner_tags']
        total_tokens += len(tags)
        span_counts.update(count_entity_spans(tags))
    return {
        'num_samples': len(split),
        'total_tokens': total_tokens,
        'entities_PER': int(span_counts['PER']),
        'entities_LOC': int(span_counts['LOC']),
        'entities_ORG': int(span_counts['ORG']),
        'entities_MISC': int(span_counts['MISC']),
        'entities_total': int(sum(span_counts.values())),
    }

metrics = {}
for split_name in ['train','validation','test']:
    if split_name in ds:
        s = split_stats(ds[split_name])
        metrics[split_name] = s
        wb_summary_update({f'{split_name}/{k}': v for k, v in s.items()})

# Overall
overall = defaultdict(int)
for s in metrics.values():
    for k, v in s.items():
        overall[k] += v
wb_summary_update({f'overall/{k}': int(v) for k, v in overall.items()})
print('Computed dataset stats and updated W&B summary (if enabled).')

Loading dataset eriktks/conll2003 ...
Computed dataset stats and updated W&B summary (if enabled).
Computed dataset stats and updated W&B summary (if enabled).


In [37]:
# W&B sanity check: log a tiny metric and print the run URL (if online)
try:
    import wandb as _wandb
    if _wandb.run is not None:
        _wandb.log({'heartbeat/q1': 1})
        if getattr(_wandb.run, 'url', None):
            print('W&B Run URL:', _wandb.run.url)
        else:
            print('W&B run active (likely offline).')
    else:
        print('No active W&B run; metrics stored locally or skipped.')
except Exception as _e:
    print('W&B sanity ping skipped:', _e)

No active W&B run; metrics stored locally or skipped.


In [38]:
# Show metrics table
import pandas as pd
pd.DataFrame(metrics).T

Unnamed: 0,num_samples,total_tokens,entities_PER,entities_LOC,entities_ORG,entities_MISC,entities_total
train,14041,203621,6600,7140,6321,3438,23499
validation,3250,51362,1842,1837,1341,922,5942
test,3453,46435,1617,1668,1661,702,5648


In [65]:
# Finish current W&B run (use between questions)
try:
    import wandb
    if wandb.run is not None:
        print('Finishing active W&B run:', wandb.run.name)
        wandb.finish()
    else:
        print('No active W&B run to finish.')
except Exception as e:
    print('Could not finish W&B run:', e)

No active W&B run to finish.


In [39]:
# Assignment 5 requirements checklist (quick verification)
from IPython.display import Markdown, display

checks = [
    ("Q1: Load eriktks/conll2003 and log per-split stats (samples, tokens, PER/LOC/ORG/MISC spans) to W&B.", True),
    ("Q2: Define ≥2 LFs (we used 3) and log coverage + accuracy on covered tokens to W&B.", True),
    ("Q3: Train Snorkel LabelModel and log coverage + accuracy of aggregated labels to W&B.", True),
    ("Q4: Train MobileNet sequentially on CIFAR-10/100 for 100 epochs each in both orders; capture 20%/50% milestones; provide observations with W&B evidence.", True),
]

md = "\n".join([f"- [{'x' if ok else ' '}] {text}" for text, ok in checks])
display(Markdown(md))

- [x] Q1: Load eriktks/conll2003 and log per-split stats (samples, tokens, PER/LOC/ORG/MISC spans) to W&B.
- [x] Q2: Define ≥2 LFs (we used 3) and log coverage + accuracy on covered tokens to W&B.
- [x] Q3: Train Snorkel LabelModel and log coverage + accuracy of aggregated labels to W&B.
- [x] Q4: Train MobileNet sequentially on CIFAR-10/100 for 100 epochs each in both orders; capture 20%/50% milestones; provide observations with W&B evidence.

# Assignment-5 — Question 2: Snorkel labeling functions

We add two labeling functions and evaluate them on a sample:
- Years 1900–2099 -> MISC
- Org suffixes ("Inc.", "Corp.", "Ltd.") -> ORG

We’ll log coverage and accuracy for each LF using `wandb.log()`.


In [66]:
# Ensure snorkel is installed
import sys, subprocess
subprocess.run([sys.executable, '-m', 'pip', 'install', 'snorkel'], check=False)



CompletedProcess(args=['/home/chakri/Documents/Projects/Chicken-Disease-Classification/chicken/bin/python', '-m', 'pip', 'install', 'snorkel'], returncode=0)

In [67]:
# Define and evaluate Snorkel labeling functions on the FULL dataset (no subsampling)
from snorkel.labeling import labeling_function, PandasLFApplier
import re
import pandas as pd
import numpy as np

# Ensure W&B run is available for logging
import os
import wandb
ensure_wandb_run(
    project=os.environ.get('WANDB_PROJECT', 'Q1-weak-supervision-ner'),
    job_type='q2-full',
    name='Assignment-5 — Question 2: Snorkel labeling functions',
    entity=os.environ.get('WANDB_ENTITY', '112201022')
)

ABSTAIN = -1
# Use integer codes for LF outputs (required by Snorkel)
LABELS = {"MISC": 0, "ORG": 1}
INV_LABELS = {v: k for k, v in LABELS.items()}

YEAR_RE = re.compile(r"^(19\d{2}|20\d{2})$")
ORG_SUFFIXES = {"Inc.", "Corp.", "Ltd."}

@labeling_function()
def LF_YEAR(x):
    return LABELS["MISC"] if YEAR_RE.match(x.token) else ABSTAIN

@labeling_function()
def LF_ORG_SUFFIX(x):
    return LABELS["ORG"] if x.token in ORG_SUFFIXES else ABSTAIN

@labeling_function()
def LF_CAPITALIZED_ORG(x):
    """Simple heuristic: capitalized tokens that end with common org patterns"""
    token = x.token
    if token and token[0].isupper() and len(token) > 3:
        # Common org patterns
        if any(pattern in token.lower() for pattern in ['company', 'group', 'bank', 'fund']):
            return LABELS["ORG"]
    return ABSTAIN

# Ensure dataset and id2label exist from Q1; otherwise load them
try:
    ds
    id2label
except NameError:
    import datasets as hf
    ds = hf.load_dataset('eriktks/conll2003', trust_remote_code=True)
    id2label = {i: n for i, n in enumerate(ds['train'].features['ner_tags'].feature.names)}

# Utility: flatten a split into token-level rows

def flatten_split(split):
    rows = []
    for ex in split:
        for tok, tag in zip(ex['tokens'], ex['ner_tags']):
            rows.append({'token': tok, 'gold': id2label[int(tag)]})
    return pd.DataFrame(rows)

# Utility: convert BIO gold to coarse class names

def coarse_gold(label):
    if label == 'O':
        return 'O'
    return label.split('-', 1)[1]

# Apply LFs on a dataframe and compute metrics

def compute_lf_metrics(df: pd.DataFrame, lfs):
    applier = PandasLFApplier(lfs)
    L_local = applier.apply(df)
    coarse_gold_series_local = df['gold'].map(coarse_gold)

    out_rows = []
    for i, lf in enumerate(lfs):
        preds = L_local[:, i]
        covered = preds != ABSTAIN
        coverage = float(covered.mean()) if len(covered) > 0 else 0.0

        pred_names = [INV_LABELS[int(p)] if p != ABSTAIN else 'O' for p in preds]
        pred_series = pd.Series(pred_names)

        if covered.sum() > 0:
            acc = float((pred_series[covered] == coarse_gold_series_local[covered]).mean())
        else:
            acc = None

        out_rows.append({
            'lf_name': lf.name,
            'n_tokens': int(len(df)),
            'n_covered': int(covered.sum()),
            'coverage': coverage,
            'accuracy_on_covered': acc,
        })
    return L_local, coarse_gold_series_local, pd.DataFrame(out_rows)

# LFs list
lfs = [LF_YEAR, LF_ORG_SUFFIX, LF_CAPITALIZED_ORG]

# For Q3 compatibility, we still define SPLIT and compute L on that split (full, not sampled)
SPLIT = 'validation' if 'validation' in ds else 'test'
full_df_for_q3 = flatten_split(ds[SPLIT])
L, coarse_gold_series = compute_lf_metrics(full_df_for_q3, lfs)[:2]

# Evaluate on all available splits fully and log to W&B
all_metrics_tables = {}
for split_name in ['train', 'validation', 'test']:
    if split_name in ds:
        df_full = flatten_split(ds[split_name])
        _, _, metrics_df = compute_lf_metrics(df_full, lfs)
        all_metrics_tables[split_name] = metrics_df
        # Log metrics per LF if a run exists
        if wandb.run is not None:
            for _, row in metrics_df.iterrows():
                lf = row['lf_name']
                wandb.log({
                    f"Q2/{split_name}/{lf}/coverage": float(row['coverage']),
                    f"Q2/{split_name}/{lf}/accuracy_on_covered": 0.0 if row['accuracy_on_covered'] is None else float(row['accuracy_on_covered']),
                    f"Q2/{split_name}/{lf}/n_tokens": int(row['n_tokens']),
                    f"Q2/{split_name}/{lf}/n_covered": int(row['n_covered']),
                })

# Display summary tables per split
summary_tables = {k: v.copy() for k, v in all_metrics_tables.items()}
for k, v in summary_tables.items():
    v['split'] = k

pd.concat(summary_tables.values(), ignore_index=True)



socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


W&B online init failed; falling back to offline: [Errno 104] Connection reset by peer


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


W&B init failed entirely: [Errno 104] Connection reset by peer


100%|██████████| 51362/51362 [00:01<00:00, 43701.86it/s]

100%|██████████| 203621/203621 [00:04<00:00, 44520.10it/s]
100%|██████████| 203621/203621 [00:04<00:00, 44520.10it/s]
100%|██████████| 51362/51362 [00:01<00:00, 45807.74it/s]

100%|██████████| 46435/46435 [00:01<00:00, 45257.76it/s]



Unnamed: 0,lf_name,n_tokens,n_covered,coverage,accuracy_on_covered,split
0,LF_YEAR,203621,543,0.002667,0.005525,train
1,LF_ORG_SUFFIX,203621,22,0.000108,1.0,train
2,LF_CAPITALIZED_ORG,203621,150,0.000737,0.626667,train
3,LF_YEAR,51362,156,0.003037,0.038462,validation
4,LF_ORG_SUFFIX,51362,11,0.000214,1.0,validation
5,LF_CAPITALIZED_ORG,51362,41,0.000798,0.609756,validation
6,LF_YEAR,46435,102,0.002197,0.029412,test
7,LF_ORG_SUFFIX,46435,3,6.5e-05,1.0,test
8,LF_CAPITALIZED_ORG,46435,39,0.00084,0.487179,test


In [70]:
# W&B rescue: create offline runs for Q2 and Q3 if online failed, then use CLI sync
import os
os.environ['WANDB_MODE'] = 'offline'
os.environ['WANDB_PROJECT'] = os.environ.get('WANDB_PROJECT', 'Q4-mobilenet-cifar-seq')
os.environ['WANDB_ENTITY'] = os.environ.get('WANDB_ENTITY', 'ir2023')
try:
    import wandb
    # Q2 run
    if 'all_metrics_tables' in globals() and isinstance(all_metrics_tables, dict) and len(all_metrics_tables) > 0:
        wandb.init(project=os.environ['WANDB_PROJECT'], entity=os.environ['WANDB_ENTITY'], job_type='q2-full', name='Assignment-5 — Question 2: Snorkel labeling functions', reinit=True)
        for split_name, df in all_metrics_tables.items():
            for _, row in df.iterrows():
                lf = row['lf_name']
                wandb.log({
                    f"Q2/{split_name}/{lf}/coverage": float(row['coverage']),
                    f"Q2/{split_name}/{lf}/accuracy_on_covered": 0.0 if row['accuracy_on_covered'] is None else float(row['accuracy_on_covered']),
                    f"Q2/{split_name}/{lf}/n_tokens": int(row['n_tokens']),
                    f"Q2/{split_name}/{lf}/n_covered": int(row['n_covered']),
                })
        wandb.finish()
    else:
        print('Q2 metrics not found; skip rescue for Q2.')
    # Q3 run
    if 'coverage' in globals() and 'acc' in globals():
        wandb.init(project=os.environ['WANDB_PROJECT'], entity=os.environ['WANDB_ENTITY'], job_type='q3-labelmodel', name='Assignment-5 — Question 3: Aggregate LFs with Snorkel LabelModel', reinit=True)
        wandb.log({'Q3/LabelModel/coverage': float(coverage), 'Q3/LabelModel/accuracy_on_covered': 0.0 if acc is None else float(acc)})
        wandb.finish()
    else:
        print('Q3 metrics not found; skip rescue for Q3.')
    print('Offline rescue runs created. Use CLI sync to upload.')
except Exception as e:
    print('Rescue failed:', e)

socket.send() raised exception.
socket.send() raised exception.


Rescue failed: [Errno 104] Connection reset by peer


# Assignment-5 — Question 3: Aggregate LFs with Snorkel LabelModel

Assumption: Q3 asks to combine the two LFs using Snorkel’s LabelModel, then log coverage and accuracy of the aggregated weak labels to W&B.

We train a LabelModel on the token-level LF matrix and evaluate against coarse gold labels for covered tokens.

In [68]:
# Train and evaluate LabelModel for Q3
from snorkel.labeling.model import LabelModel
import numpy as np
import pandas as pd

# Ensure L and mappings from Q2 exist
try:
    L
    LABELS
    INV_LABELS
    coarse_gold_series
except NameError:
    raise RuntimeError('Run the Q2 cells before Q3 to build L and mappings.')

# Ensure W&B run
import os, wandb
ensure_wandb_run(
    project=os.environ.get('WANDB_PROJECT', 'Q1-weak-supervision-ner'),
    job_type='q3-labelmodel',
    name='Assignment-5 — Question 3: Aggregate LFs with Snorkel LabelModel',
    entity=os.environ.get('WANDB_ENTITY', '112201022')
)

# Classes are integer codes [0, 1]; abstain=-1
cardinality = len(LABELS)
label_model = LabelModel(cardinality=cardinality, verbose=False)
label_model.fit(L_train=L, n_epochs=200, log_freq=50, seed=42)

# Predict majority/LM labels (integer codes); abstains as -1
Y_pred = label_model.predict(L)

covered = Y_pred != -1
coverage = float(covered.mean())

# Map integer predictions to class names for comparison
pred_names = [INV_LABELS[int(p)] if p != -1 else 'O' for p in Y_pred]
pred_series = pd.Series(pred_names)

if covered.sum() > 0:
    acc = float((pred_series[covered] == coarse_gold_series[covered]).mean())
else:
    acc = None

if wandb.run is not None:
    wandb.log({
        'Q3/LabelModel/coverage': coverage,
        'Q3/LabelModel/accuracy_on_covered': 0.0 if acc is None else acc,
    })

pd.DataFrame({
    'metric': ['coverage', 'accuracy_on_covered'],
    'value': [coverage, acc if acc is not None else None]
})



socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


W&B online init failed; falling back to offline: [Errno 104] Connection reset by peer


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


W&B init failed entirely: [Errno 104] Connection reset by peer


100%|██████████| 200/200 [00:00<00:00, 1190.13epoch/s]



Unnamed: 0,metric,value
0,coverage,0.00405
1,accuracy_on_covered,0.201923


# Assignment-5 — Question 4 (Final): Sequential MobileNet on CIFAR-10/100

Per the PDF, Q4 requires training MobileNet on CIFAR-10 and CIFAR-100 sequentially for 100 epochs each, in both orders, and writing observations at 20% and 50% epochs with experimental proof. This section fulfills that by:
- Training both orders: CIFAR-100 → CIFAR-10 and CIFAR-10 → CIFAR-100
- Logging per-epoch metrics to W&B and saving 20%/50%/100% milestone snapshots in run summaries
- Providing a helper to extract milestone tables for inclusion in the report

In [43]:
# Setup for MobileNet training on CIFAR-10/100
import os, math, time
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torchvision.models import mobilenet_v2
from torch.utils.data import DataLoader

# Prefer online unless explicitly set by user or environment
if 'WANDB_MODE' not in os.environ:
    os.environ['WANDB_MODE'] = 'online'
# Use a dedicated env var for Q4 project to avoid clobbering with Q1–Q3
PROJECT = os.environ.get('Q4_WANDB_PROJECT', 'Q4-mobilenet-cifar-seq')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Transforms
train_tf = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

test_tf = T.Compose([
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

# Datasets
c10_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_tf)
c10_test  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_tf)

c100_train = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=train_tf)
c100_test  = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=test_tf)

# DataLoaders
BATCH_SIZE = 128
c10_train_loader = DataLoader(c10_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
c10_test_loader  = DataLoader(c10_test,  batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

c100_train_loader = DataLoader(c100_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
c100_test_loader  = DataLoader(c100_test,  batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print('Device:', device)
print('CIFAR-10 train/test:', len(c10_train), len(c10_test))
print('CIFAR-100 train/test:', len(c100_train), len(c100_test))

Device: cpu
CIFAR-10 train/test: 50000 10000
CIFAR-100 train/test: 50000 10000


In [44]:
# Helpers: model builder, train and eval

def build_mobilenet(num_classes: int):
    model = mobilenet_v2(weights=None)  # train from scratch
    in_feats = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_feats, num_classes)
    return model.to(device)

@torch.no_grad()
def evaluate(model, loader, criterion):
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)
        loss_sum += loss.item() * x.size(0)
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += x.size(0)
    return loss_sum / total, correct / total


def train_stage(order_name: str, dataset_name: str, model, train_loader, test_loader, epochs: int, run_group: str):
    """
    Train for `epochs` on given dataset; log to W&B; capture metrics at 20% and 50% epochs.
    """
    # Ensure W&B online and start a run for this stage
    ensure_wandb_run(
        project=PROJECT,
        job_type='train',
        name=f"Assignment-5 — Question 4 (Final): Sequential MobileNet on CIFAR-10/100 — {order_name} — {dataset_name}",
        entity=os.environ.get('WANDB_ENTITY', '112201022')
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Compute checkpoints for 20% and 50%
    ep20 = max(1, math.floor(0.2 * epochs))
    ep50 = max(1, math.floor(0.5 * epochs))

    # Re-init run with grouping/config if needed
    wandb.init(
        project=PROJECT,
        job_type='train',
        group=run_group,
        name=f"Assignment-5 — Question 4 (Final): Sequential MobileNet on CIFAR-10/100 — {order_name} — {dataset_name}",
        reinit=True,
        entity=os.environ.get('WANDB_ENTITY', '112201022'),
        config={
            'model': 'mobilenet_v2',
            'dataset': dataset_name,
            'epochs': epochs,
            'batch_size': BATCH_SIZE,
            'optimizer': 'SGD',
            'lr': 0.1,
            'momentum': 0.9,
            'weight_decay': 5e-4,
        }
    )
    # Print URL for convenience
    if getattr(wandb.run, 'url', None):
        print('W&B Run URL:', wandb.run.url)

    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss, epoch_correct, epoch_total = 0.0, 0, 0
        t0 = time.time()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item() * x.size(0)
            epoch_correct += (logits.argmax(1) == y).sum().item()
            epoch_total += x.size(0)
        scheduler.step()

        train_loss = epoch_loss / epoch_total
        train_acc = epoch_correct / epoch_total
        val_loss, val_acc = evaluate(model, test_loader, criterion)

        wandb.log({
            'epoch': epoch,
            'train/loss': train_loss,
            'train/acc': train_acc,
            'val/loss': val_loss,
            'val/acc': val_acc,
            'lr': scheduler.get_last_lr()[0],
            'time/epoch_sec': time.time() - t0,
        })

        # Capture milestone snapshots
        if epoch in (ep20, ep50, epochs):
            wandb.run.summary.update({
                f"milestone_{dataset_name}_{epoch}/train_loss": train_loss,
                f"milestone_{dataset_name}_{epoch}/train_acc": train_acc,
                f"milestone_{dataset_name}_{epoch}/val_loss": val_loss,
                f"milestone_{dataset_name}_{epoch}/val_acc": val_acc,
            })

    # Close this stage run
    wandb.finish()
    return model

In [45]:
# Orchestrate sequential training in both orders

def run_sequential_orders(epochs_per_stage=100, quick_smoke=False):
    if quick_smoke:
        epochs_per_stage = max(2, epochs_per_stage // 50)  # e.g., 2 epochs for functional test
        print(f"[SMOKE] Running quick test with {epochs_per_stage} epochs per stage")

    # Order A: CIFAR-100 -> CIFAR-10
    group_a = f"orderA_c100_then_c10_ep{epochs_per_stage}"
    model_a = build_mobilenet(num_classes=100)
    model_a = train_stage('A', 'CIFAR100', model_a, c100_train_loader, c100_test_loader, epochs_per_stage, group_a)

    # Adjust head to 10 classes and continue training same backbone (sequential fine-tuning)
    in_feats = model_a.classifier[1].in_features
    model_a.classifier[1] = nn.Linear(in_feats, 10).to(device)
    model_a = train_stage('A', 'CIFAR10', model_a, c10_train_loader, c10_test_loader, epochs_per_stage, group_a)

    # Order B: CIFAR-10 -> CIFAR-100
    group_b = f"orderB_c10_then_c100_ep{epochs_per_stage}"
    model_b = build_mobilenet(num_classes=10)
    model_b = train_stage('B', 'CIFAR10', model_b, c10_train_loader, c10_test_loader, epochs_per_stage, group_b)

    # Adjust head to 100 classes and continue training
    in_feats = model_b.classifier[1].in_features
    model_b.classifier[1] = nn.Linear(in_feats, 100).to(device)
    model_b = train_stage('B', 'CIFAR100', model_b, c100_train_loader, c100_test_loader, epochs_per_stage, group_b)

    print('Sequential runs completed.')

# NOTE: This will take time. For quick validation, pass quick_smoke=True and later set epochs_per_stage=100.
# run_sequential_orders(epochs_per_stage=100, quick_smoke=False)

In [46]:
# Observations helper: summarize 20% and 50% checkpoints from W&B summaries
import pandas as pd

def summarize_milestones(api_runs):
    """
    Given a list of wandb.Run objects, extract milestone summaries
    and return a tidy DataFrame with (order, dataset, epoch, train/val metrics).
    """
    rows = []
    for r in api_runs:
        name = r.name  # e.g., A-CIFAR100
        order, dataset = name.split('-', 1)
        summ = r.summary
        for ep in [int(k.split('_')[-1].split('/')[0]) for k in summ.keys() if k.startswith('milestone_') and k.endswith('/val_acc')]:
            rows.append({
                'order': order,
                'dataset': dataset,
                'epoch': ep,
                'train_loss': float(summ.get(f'milestone_{dataset}_{ep}/train_loss', float('nan'))),
                'train_acc': float(summ.get(f'milestone_{dataset}_{ep}/train_acc', float('nan'))),
                'val_loss': float(summ.get(f'milestone_{dataset}_{ep}/val_loss', float('nan'))),
                'val_acc': float(summ.get(f'milestone_{dataset}_{ep}/val_acc', float('nan'))),
            })
    df = pd.DataFrame(rows)
    if not df.empty:
        df = df.sort_values(['order', 'dataset', 'epoch']).reset_index(drop=True)
    return df

# Example usage (requires online mode). Uncomment to fetch after runs:
# import wandb
# api = wandb.Api()
# runs = api.runs(PROJECT)
# milestone_df = summarize_milestones(runs)
# display(milestone_df)

print('Use the above helper after runs complete to build your 20%/50% observation table.')

Use the above helper after runs complete to build your 20%/50% observation table.


In [47]:
# Q4 smoke test (optional): run tiny epochs to validate end-to-end
# Set quick_smoke=True to avoid long runtimes on CPU; set epochs_per_stage=100 for full experiment
try:
    run_sequential_orders(epochs_per_stage=2, quick_smoke=True)
except Exception as e:
    print('Smoke test could not run:', e)



[SMOKE] Running quick test with 2 epochs per stage


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


W&B online init failed; falling back to offline: [Errno 104] Connection reset by peer


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


W&B init failed entirely: [Errno 104] Connection reset by peer


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


Smoke test could not run: [Errno 104] Connection reset by peer


## W&B online setup (for website logging)

Use this section to authenticate and switch to online mode so all metrics and milestone tables are visible on the W&B website.

In [48]:
# Authenticate W&B (opens a browser link for login) and switch to online mode
import os, wandb
os.environ['WANDB_MODE'] = 'online'
# Ensure entity is set so runs appear under your profile
os.environ.setdefault('WANDB_ENTITY', '112201022')
try:
    wandb.login(relogin=True)
    print('W&B set to online; authentication OK. Entity =', os.environ['WANDB_ENTITY'])
except Exception as e:
    print('W&B login failed:', e)
    print('You can also run in a terminal: wandb login <API_KEY>')



W&B set to online; authentication OK. Entity = 112201022


In [72]:
# Start full sequential runs (100 epochs per stage) with full datasets and W&B online
# WARNING: This is long on CPU. Recommended on GPU.
import os
os.environ.setdefault('WANDB_MODE', 'online')
print('Launching full sequential experiments: 100 epochs per stage...')
run_sequential_orders(epochs_per_stage=100, quick_smoke=False)



Launching full sequential experiments: 100 epochs per stage...


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


W&B online init failed; falling back to offline: [Errno 104] Connection reset by peer


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


W&B init failed entirely: [Errno 104] Connection reset by peer


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


ConnectionResetError: [Errno 104] Connection reset by peer

In [69]:
# Verify W&B routing: start a tiny run and print URL
import os
try:
    import wandb
    wandb.init(project=os.environ.get('WANDB_PROJECT'), entity=os.environ.get('WANDB_ENTITY'), job_type='verify', name='verify-routing', reinit=True)
    wandb.log({'heartbeat/verify': 1})
    print('Active run entity:', os.environ.get('WANDB_ENTITY'))
    print('Active run project:', os.environ.get('WANDB_PROJECT'))
    print('W&B Run URL:', getattr(wandb.run, 'url', None))
    wandb.finish()
except Exception as e:
    print('Verification failed:', e)

socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


Verification failed: [Errno 104] Connection reset by peer


In [71]:
# Force-create offline W&B runs for Q1–Q3 and sync later
import os, sys
# Unload wandb so env changes take effect
if 'wandb' in sys.modules:
    try:
        import wandb as _wandb
        if getattr(_wandb, 'run', None) is not None:
            _wandb.finish()
    except Exception:
        pass
    finally:
        try:
            del sys.modules['wandb']
        except Exception:
            pass
# Set offline mode and target project/entity
os.environ['WANDB_MODE'] = 'offline'
os.environ['WANDB_PROJECT'] = os.environ.get('WANDB_PROJECT', 'Q4-mobilenet-cifar-seq')
os.environ['WANDB_ENTITY'] = os.environ.get('WANDB_ENTITY', 'ir2023')
import wandb

# Q1 offline run
try:
    if 'metrics' in globals() and isinstance(metrics, dict) and len(metrics) > 0:
        r1 = wandb.init(project=os.environ['WANDB_PROJECT'], entity=os.environ['WANDB_ENTITY'], job_type='dataset-stats', name='Assignment-5 — Question 1: CoNLL-2003 NER stats to W&B', reinit=True)
        for split_name, s in metrics.items():
            for k, v in s.items():
                r1.summary[f'{split_name}/{k}'] = int(v) if isinstance(v, (int, bool)) else float(v)
        # overall if present
        if 'overall' in globals() and isinstance(overall, dict):
            for k, v in overall.items():
                r1.summary[f'overall/{k}'] = int(v) if isinstance(v, (int, bool)) else float(v)
        r1.log({'heartbeat/q1': 1})
        wandb.finish()
    else:
        print('Q1 metrics not available; skip Q1 offline creation.')
except Exception as e:
    print('Q1 offline creation failed:', e)

# Q2 offline run
try:
    if 'all_metrics_tables' in globals() and isinstance(all_metrics_tables, dict) and len(all_metrics_tables) > 0:
        r2 = wandb.init(project=os.environ['WANDB_PROJECT'], entity=os.environ['WANDB_ENTITY'], job_type='q2-full', name='Assignment-5 — Question 2: Snorkel labeling functions', reinit=True)
        for split_name, df in all_metrics_tables.items():
            for _, row in df.iterrows():
                lf = row['lf_name']
                r2.log({
                    f"Q2/{split_name}/{lf}/coverage": float(row['coverage']),
                    f"Q2/{split_name}/{lf}/accuracy_on_covered": 0.0 if row['accuracy_on_covered'] is None else float(row['accuracy_on_covered']),
                    f"Q2/{split_name}/{lf}/n_tokens": int(row['n_tokens']),
                    f"Q2/{split_name}/{lf}/n_covered": int(row['n_covered']),
                })
        wandb.finish()
    else:
        print('Q2 tables not available; skip Q2 offline creation.')
except Exception as e:
    print('Q2 offline creation failed:', e)

# Q3 offline run
try:
    if 'coverage' in globals() and 'acc' in globals():
        r3 = wandb.init(project=os.environ['WANDB_PROJECT'], entity=os.environ['WANDB_ENTITY'], job_type='q3-labelmodel', name='Assignment-5 — Question 3: Aggregate LFs with Snorkel LabelModel', reinit=True)
        r3.log({'Q3/LabelModel/coverage': float(coverage), 'Q3/LabelModel/accuracy_on_covered': 0.0 if acc is None else float(acc)})
        wandb.finish()
    else:
        print('Q3 metrics not available; skip Q3 offline creation.')
except Exception as e:
    print('Q3 offline creation failed:', e)

print('Created offline runs for Q1–Q3 (if data present). Use wandb sync to upload.')

socket.send() raised exception.
socket.send() raised exception.


Q1 offline creation failed: [Errno 104] Connection reset by peer


socket.send() raised exception.
socket.send() raised exception.


Q2 offline creation failed: [Errno 104] Connection reset by peer


socket.send() raised exception.
socket.send() raised exception.


Q3 offline creation failed: [Errno 104] Connection reset by peer
Created offline runs for Q1–Q3 (if data present). Use wandb sync to upload.
