# AG-News

Topic classification task.

Demonstrates:
* Implementing a DataReader
* Embedding textual data with byte-pair embeddings in preparation for feeding it to CCP
* Learning soft pseudo-labels as q-vectors via Contrastive Credibility Propagation
* Classifying test examples

Evaluates:
* How well the labeler recovers q-vectors
* How well the classifier performs classification
    * (a) supervised learning with all 120k training examples (best case),
    * (b) supervised learning with missing labels (worst case),
    * (c) supervised learning with a subset of the training labels forgotten and then recovered as soft labels with CCP (CCP case)

This implementation automatically runs on a single GPU device named `cuda:0` if CUDA is available. If unavailable, it defaults to using the CPU.

In [None]:
%load_ext autoreload
%autoreload 2

import logging

logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(levelname)s %(name)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')

## Dataset, embedding that dataset, and loading batches of that dataset

`AgNewsDataReader` is a `DataReader`. If you want to transform the data, you should pass a Callable as the `transform` parameter. If you want the raw strings, you should pass `None`.

In [None]:
from ccp.encoders.byte_pair_embed import BytePairEmbed
from agnews_ccp_datareader import AgNewsDataReader

# Byte-pair embedding settings. FORCED_LENGTH and EMBEDDING_DIMENSIONALITY together
# give the (length, dim) input shape that the text encoders below are built for.
VOCABULARY_SIZE = 50_000
EMBEDDING_DIMENSIONALITY = 100
FORCED_LENGTH = 512

# Projection head f_z sizes. Its PROJECTION_OUTPUT_DIM-sized output is the
# representation used to compute the soft supervised contrastive loss (L_SSC)
# for CCP pseudo-labeling.
PROJECTION_HIDDEN_DIM = 64
PROJECTION_OUTPUT_DIM = 32

# Fraction of training labels that are forgotten before running CCP.
# CCP becomes particularly useful when only a few labels remain
# (e.g. MISSING_DATA_RATE around 0.99 or 0.999).
MISSING_DATA_RATE = 0.2

# A small hand-picked subset of training rows, so this notebook runs quickly on a CPU.
# Set this to `None` to learn for real on the whole training split.
SAMPLE_INDICES = [
    0, 1, 2,
    5000, 5001, 5002,
    10000, 10001,
    15000, 15001,
    20000, 20001,
    25000, 25001,
    30000, 30001,
    100000, 100001,
    110000, 110001,
]

embed = BytePairEmbed(
    vocab_size=VOCABULARY_SIZE,
    embedding_dimensionality=EMBEDDING_DIMENSIONALITY,
    output_size=FORCED_LENGTH,
)

# Training data for CCP: only (1 - MISSING_DATA_RATE) of the labels are kept,
# and `embed` is used as the reader's transform.
ag_data = AgNewsDataReader(
    split="train",
    sample_indices_to_use=SAMPLE_INDICES,
    transform=embed,
    rate_to_keep=1 - MISSING_DATA_RATE,
)


## Training CCP

CCP consists of two models, which share some substructure:
* The soft labeler, `ContrastiveCredibilityLabeller`, which produces q-vectors (soft pseudo-labels) for each sample using an encoder network f_b(x) and a projection head f_z(f_b(x)).
* The classifier, `ContrastiveCredibilityClassifier`, which is a classification model for the target space. It uses the same encoder network f_b(x) and a separate projection head f_g(f_b(x)).

As a computational performance optimization, you can optionally prewarm the network state, and reuse that state to reinitialize between iterations of CCP.

### Configure the soft labeling model

In [None]:
import torch.nn as nn

from ccp.softlabeler import ContrastiveCredibilityLabeller
from ccp.softlabeler.transforms.generic_transforms import Identity, GaussianNoise
from ccp.softlabeler.transforms.text_transforms import ParagraphSwap, BPEmbRandomVectorSwap, BPEmbVectorHide
from ccp.encoders.ccp_text_encoder import TextEncoder

from ccp.device_decision import DEVICE


# Where the labeler writes its outputs; the per-iteration q-vector files are loaded
# back from this directory in the evaluation section.
OUTPUT_DIRECTORY = "ccp_labels"

# Encoder network f_b(x).
encoder = TextEncoder(dim_in=(FORCED_LENGTH, EMBEDDING_DIMENSIONALITY)).to(DEVICE)

# Projection head f_z(f_b(x)): encoder features -> PROJECTION_OUTPUT_DIM vector.
projection_head = nn.Sequential(
    *[
        nn.Linear(encoder.output_dim, PROJECTION_HIDDEN_DIM),
        nn.ReLU(),
        nn.Linear(PROJECTION_HIDDEN_DIM, PROJECTION_OUTPUT_DIM),
    ]
).to(DEVICE)

# Soft labeler over the partially-labeled training data, using the transforms
# listed below (identity, noise, paragraph swap, and two BPEmb vector edits).
ccp_labeler = ContrastiveCredibilityLabeller(
    data_reader=ag_data,
    output_dir=OUTPUT_DIRECTORY,
    encoder_network_f_b=encoder,
    projection_head_f_z=projection_head,
    transforms=[
        Identity(),
        GaussianNoise(),
        ParagraphSwap(),
        BPEmbRandomVectorSwap(embed),
        BPEmbVectorHide(embed),
    ],
    batch_size=64,
)

### Learn the q-labels

In [None]:
from ccp.softlabeler.ccp_labeler import EMALossExitCriteria

# This exit criterion is tuned to finish quickly on a CPU while learning very little.
# To learn for real, use the defaults (EMALossExitCriteria()) instead.
ccp_inner_loop_iteration_exit_criterion = EMALossExitCriteria(
    ema_weight=0.01,
    max_units_since_overwrite=10,
    max_total_units=20,
)

# Prewarm the network before training.
ccp_labeler.prewarm_network(exit_criteria=ccp_inner_loop_iteration_exit_criterion)

# Train by hand. Each iteration receives the "previous_metadata" entry returned by
# the iteration before it (None for the first). It writes its outputs under its own
# prefix, so the q-vectors of every iteration remain available for evaluation.
NUM_ITERATIONS = 3
args = {"previous_metadata": None}
for total_iterations in range(NUM_ITERATIONS):
    previous_metadata = args["previous_metadata"]
    _, args = ccp_labeler.execute_ccp_single_iteration(
        exit_criteria=ccp_inner_loop_iteration_exit_criterion,
        output_prefix=f"iteration_{total_iterations}",
        previous_metadata=previous_metadata,
        print_loss_every_k_batches=5,
    )

# Alternatively, train with automatic stopping condition, overwriting q vectors each time (no tracking of q values).
# ccp_overall_exit_criterion = EMALossExitCriteria(ema_weight = 0.01, max_units_since_overwrite = 10, max_total_units = 20) # reset to defaults to learn for real
# ccp_labeler.ema_train(ccp_overall_exit_criterion,
#                       ccp_inner_loop_iteration_exit_criterion,
#                       print_loss_every_k_batches=5
#                      )


### Learn the classifier

In [None]:
from ccp.classifier.ccp_classifier import ContrastiveCredibilityClassifier
from ccp.param_init import init_weights_ccp

CLASSIFICATION_HIDDEN_DIM = 64
NUM_TARGETS = ag_data.n_distinct_labels


# Reuse the network state produced by `ccp_labeler.prewarm_network` as the classifier's encoder f_b.
encoder = ccp_labeler.prewarmed_encoder()

# Classification head f_g(f_b(x)): encoder features -> one logit per target class.
# Place it on the same device as the encoder's parameters, just as the labeler's
# projection head was moved to DEVICE. Without this, on a CUDA machine the encoder
# and head could end up on different devices.
classifier_projection_head = nn.Sequential(
    nn.Linear(encoder.output_dim, CLASSIFICATION_HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(CLASSIFICATION_HIDDEN_DIM, NUM_TARGETS),
).to(next(encoder.parameters()).device)

# The classifier trains on the soft labels (q-vectors) recovered by the labeler.
ccp_classifier = ContrastiveCredibilityClassifier(encoder_network_f_b=encoder,
                                                  projection_head_f_g=classifier_projection_head,
                                                  q_dataset=ccp_labeler.classification_dataset(),
                                                  batch_size=5,
                                                  network_init_func=init_weights_ccp)

# The parameterization here is designed to run quickly on a CPU while learning very little.
# To learn for real, remove the parameterization to use defaults.
exit_criterion = EMALossExitCriteria(ema_weight = 0.01, max_units_since_overwrite = 10, max_total_units = 20)

# Run the classifier
ccp_classifier.ema_train(exit_criterion)

# Evaluating CCP

## Q-vector evaluation

In [None]:
import torch
import numpy as np
from ccp.typing import TargetLabel

def build_comparison(q_vals, ag_data_all) -> np.typing.NDArray:
    """Pair each example's true label with the label implied by its q-vector.

    Args:
        q_vals: per-example q-vectors, where ``q_vals[i]`` is the row for item ``i``
            of ``ag_data_all``. In this notebook it is a tensor loaded with
            ``torch.load``; a NumPy array also works.
        ag_data_all: dataset supporting ``len()`` and indexing, whose items hold the
            true label at position 1, and which exposes ``UNLABELLED_TARGET``.

    Returns:
        A float array of shape ``(len(ag_data_all), 2)``. Column 0 is the true label.
        Column 1 is the predicted label (``argmax`` of the q-vector), or
        ``ag_data_all.UNLABELLED_TARGET`` when the q-vector is all zeros.
    """
    def _get_predicted_class(q_row) -> "TargetLabel":
        # An all-zero q-vector means no label has been assigned to this example.
        # Use the array/tensor's own reduction rather than Python's all(), which
        # would walk the row element by element in the interpreter.
        if bool((q_row == 0).all()):
            return ag_data_all.UNLABELLED_TARGET
        return q_row.argmax().item()

    true_predicted = np.zeros((len(ag_data_all), 2))
    for item_idx in range(len(ag_data_all)):
        true_class = ag_data_all[item_idx][1]
        true_predicted[item_idx] = (true_class, _get_predicted_class(q_vals[item_idx]))
    return true_predicted

def evaluate(q_vals, ag_data_all):
    """Print, for each true class, how its examples' q-vector labels break down.

    For every class id in ``range(ag_data_all.n_distinct_labels)``, prints the
    percentage of that class's examples whose q-vector label is correct, still
    unlabeled (``ag_data_all.UNLABELLED_TARGET``), or wrong.

    A class with no examples in ``ag_data_all`` is reported as such instead of
    raising ``ZeroDivisionError``. This can happen when only a small
    ``SAMPLE_INDICES`` subset is loaded.
    """
    true_predicted = build_comparison(q_vals, ag_data_all)
    true_col, predicted_col = true_predicted[:, 0], true_predicted[:, 1]

    for true_label in range(ag_data_all.n_distinct_labels):
        predicted_for_class = predicted_col[true_col == true_label]
        total = len(predicted_for_class)
        if total == 0:
            print(f"for class {true_label}, there are no examples in this dataset")
            continue
        num_correct = np.count_nonzero(predicted_for_class == true_label)
        num_unlabeled = np.count_nonzero(predicted_for_class == ag_data_all.UNLABELLED_TARGET)
        num_wrong = total - num_correct - num_unlabeled
        print(f"for class {true_label}, we have {num_correct / total:0.2%} correct labels, {num_unlabeled / total:0.2%} still unlabeled, and {num_wrong / total:0.2%} wrong")
    

# Same rows as we used in training (same SAMPLE_INDICES), but with NO labels
# forgotten, so that we can compare the two labelsets.
# `evaluate` only reads each item's label, so no transform is applied. Passing None
# yields raw strings and avoids byte-pair-embedding every example just to read its label.
ag_data_train_untransformed = AgNewsDataReader(split="train",
                                               rate_to_keep=1,
                                               transform=None,
                                               sample_indices_to_use=SAMPLE_INDICES)

for i in range(NUM_ITERATIONS):
    print(f"\n\n-------- Iteration {i} --------")
    # q-vectors written by the CCP iteration run with output_prefix=f"iteration_{i}".
    q_vals = torch.load(f"{OUTPUT_DIRECTORY}/iteration_{i}_q.pt")
    evaluate(q_vals, ag_data_train_untransformed)
    

## Classifier evaluation & comparison

In [None]:
from typing import Tuple

import torch.nn as nn
from torch.utils.data.dataloader import DataLoader

from ccp.datareaders import DataReader

def evaluate_classifier(classifier: nn.Module, data_reader: "DataReader", batch_size: int = 100) -> Tuple[int, int]:
    """
    Assesses accuracy, which aligns with the CCP paper.

    Runs ``classifier`` over ``data_reader`` in batches of ``batch_size`` and counts
    the argmax predictions that match the true labels.

    Inference details:
    * Runs under ``torch.no_grad()`` with the model in eval mode. The model's
      previous train/eval mode is restored afterwards, even on error.
    * Each input batch is moved to the device of the classifier's parameters, so a
      GPU-resident model can be scored on the reader's batches.
    * Predictions are moved back to CPU before being compared with the labels.

    Output is a tuple of (correct samples, total samples).
    """
    import torch  # local import: this cell itself only imports torch.nn

    # Device the model lives on. A parameter-less model gets no input moves.
    first_param = next(classifier.parameters(), None)
    device = first_param.device if first_param is not None else None

    was_training = classifier.training
    classifier.eval()
    correct_prediction_count = 0
    try:
        with torch.no_grad():
            for batched_data, true_labels in DataLoader(data_reader, batch_size=batch_size):
                if device is not None:
                    batched_data = batched_data.to(device)
                predictions = classifier(batched_data).argmax(dim=1).cpu()
                correct_prediction_count += (true_labels == predictions).sum().item()
    finally:
        classifier.train(was_training)

    return correct_prediction_count, len(data_reader)

# --- Train split ---
# The same rows CCP learned on (SAMPLE_INDICES), but with every label kept
# (rate_to_keep=1), embedded with the same `embed` transform.
ag_data_train_embedded = AgNewsDataReader(
    split="train",
    sample_indices_to_use=SAMPLE_INDICES,
    rate_to_keep=1,
    transform=embed,
)

# --- Test split ---
# All labels kept, same `embed` transform. No sample_indices_to_use is passed,
# so the reader's default for that parameter applies.
ag_data_test_embedded = AgNewsDataReader(
    split="test",
    transform=embed,
    rate_to_keep=1,
)

### Performance of simple model with full data

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader

# We will train on the same data, but without any "forgetting" of labels
ag_full = AgNewsDataReader(split="train",
                           rate_to_keep=1,
                           transform=embed,
                           sample_indices_to_use=SAMPLE_INDICES)

# This is NOT the CCP model -- we encode then predict.
# shuffle=True reshuffles the training examples every epoch, as is standard for SGD.
# Without it, every epoch walks the examples in the same dataset order.
nonccp_loader = DataLoader(ag_full, batch_size=256, shuffle=True)
nonccp_encoder = TextEncoder(dim_in=(FORCED_LENGTH, EMBEDDING_DIMENSIONALITY))
nonccp_model_full = nn.Sequential(
    nonccp_encoder,
    nn.Linear(in_features=nonccp_encoder.output_dim, out_features=ag_full.n_distinct_labels)
)


NUM_EPOCHS = 10

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(nonccp_model_full.parameters(),
                            lr=0.1,
                            weight_decay=1e-2,
                            momentum=0.9)
for epoch in range(NUM_EPOCHS):
    print(f"epoch {epoch} batches:", end=" ")
    for which_batch_within_data, (X_batch, y_batch) in enumerate(nonccp_loader):
        print(which_batch_within_data, end=" ")
        y_batch_pred = nonccp_model_full(X_batch)
        loss = loss_fn(y_batch_pred, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print()
    # Loss of the last batch processed in this epoch.
    print(f"final loss in epoch {epoch} = {loss.item()}")

print()
train_correct, train_total = evaluate_classifier(nonccp_model_full, ag_data_train_embedded)
print(f"Simple model with full data -- train accuracy: {train_correct / train_total:0.2%} ({train_correct}/{train_total})")

test_correct, test_total = evaluate_classifier(nonccp_model_full, ag_data_test_embedded)
print(f"Simple model with full data -- test accuracy: {test_correct / test_total:0.2%} ({test_correct}/{test_total})")


### Performance of simple model with missing data

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader

# We reuse literally the same dataset from CCP, dropping unlabeled
# (DataReader.UNLABELLED_TARGET) examples. But this is NOT the CCP model.
# shuffle=True reshuffles the training examples every epoch, as is standard for SGD.
nonccp_loader = DataLoader(ag_data, batch_size=256, shuffle=True)
nonccp_encoder = TextEncoder(dim_in=(FORCED_LENGTH, EMBEDDING_DIMENSIONALITY))
nonccp_model_partial = nn.Sequential(
    nonccp_encoder,
    # Take the class count from this cell's own dataset, so the cell does not
    # depend on `ag_full` from the full-data cell having been run first.
    nn.Linear(in_features=nonccp_encoder.output_dim, out_features=ag_data.n_distinct_labels)
)


NUM_EPOCHS = 10

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(nonccp_model_partial.parameters(),
                            lr=0.1,
                            weight_decay=1e-2,
                            momentum=0.9)
for epoch in range(NUM_EPOCHS):
    print(f"epoch {epoch} batches:", end=" ")
    loss = None
    for which_batch_within_data, (X_batch, y_batch) in enumerate(nonccp_loader):
        print(which_batch_within_data, end=" ")
        # drop the unlabeled data during training
        labeled = y_batch != DataReader.UNLABELLED_TARGET
        if not labeled.any():
            # A batch without any labels yields a NaN mean cross-entropy and an
            # optimizer step driven only by weight decay/momentum. This is likely at
            # high MISSING_DATA_RATE, so skip the batch.
            continue
        X_batch, y_batch = X_batch[labeled], y_batch[labeled]

        y_batch_pred = nonccp_model_partial(X_batch)
        loss = loss_fn(y_batch_pred, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print()
    if loss is None:
        print(f"no labeled batches in epoch {epoch}")
    else:
        # Loss of the last labeled batch processed in this epoch.
        print(f"final loss in epoch {epoch} = {loss.item()}")

print()
train_correct, train_total = evaluate_classifier(nonccp_model_partial, ag_data_train_embedded)
print(f"Simple model with {MISSING_DATA_RATE:0.0%} missing -- train accuracy: {train_correct / train_total:0.2%} ({train_correct}/{train_total})")

test_correct, test_total = evaluate_classifier(nonccp_model_partial, ag_data_test_embedded)
print(f"Simple model with {MISSING_DATA_RATE:0.0%} missing -- test accuracy: {test_correct / test_total:0.2%} ({test_correct}/{test_total})")


### Performance of CCP on missing data: classification after label recovery

In [None]:
# Score the CCP classifier, trained on the labels CCP recovered, against the fully
# labeled training rows and against the test split.
train_correct, train_total = evaluate_classifier(
    ccp_classifier.classifier(),
    data_reader=ag_data_train_embedded,
)
print(f"CCP with {MISSING_DATA_RATE:0.0%} missing -- train accuracy: {train_correct / train_total:0.2%} ({train_correct}/{train_total})")

test_correct, test_total = evaluate_classifier(
    ccp_classifier.classifier(),
    data_reader=ag_data_test_embedded,
)
print(f"CCP with {MISSING_DATA_RATE:0.0%} missing -- test accuracy: {test_correct / test_total:0.2%} ({test_correct}/{test_total})")
