Crowd-Kit Learning #47

Merged · 13 commits · Nov 18, 2022
2 changes: 1 addition & 1 deletion Pipfile
@@ -4,7 +4,7 @@ verify_ssl = true
name = "pypi"

[packages]
crowd-kit = {editable = true, path = "."}
crowd-kit = {editable = true, path = ".", extras = ["learning"]}

[dev-packages]
mypy = "*"
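The new extras = ["learning"] entry ties the development environment to the optional learning extra introduced by this PR. A minimal sanity check, assuming the extra is installed (for example via pip install crowd-kit[learning], the standard pip syntax for extras; the exact published extra name is an assumption, not shown by this diff):

# Assumes Crowd-Kit is installed with the optional "learning" extra, which is
# expected to provide the PyTorch and transformers dependencies used by the
# new crowdkit.learning subpackage.
from crowdkit.learning import CoNAL, CrowdLayer, TextSummarization

print(CoNAL.__name__, CrowdLayer.__name__, TextSummarization.__name__)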
5 changes: 0 additions & 5 deletions crowdkit/aggregation/__init__.py
@@ -56,7 +56,6 @@
'SegmentationRASA',
'TextHRRASA',
'TextRASA',
'TextSummarization',
'Wawa',
'ZeroBasedSkill',
'BinaryRelevance'
@@ -69,7 +68,3 @@ def is_arcadia() -> bool:
return cast(bool, __res == __res)
except ImportError:
return False


if not is_arcadia():
from .texts.text_summarization import TextSummarization
9 changes: 9 additions & 0 deletions crowdkit/learning/__init__.py
@@ -0,0 +1,9 @@
from .conal import CoNAL
from .crowd_layer import CrowdLayer
from .text_summarization import TextSummarization

__all__ = [
'CoNAL',
'CrowdLayer',
'TextSummarization'
]
142 changes: 142 additions & 0 deletions crowdkit/learning/conal.py
@@ -0,0 +1,142 @@
# Adapted from:
# https://github.com/zdchu/CoNAL/blob/main/conal.py
__all__ = [
'CoNAL',
]

from typing import Optional, Tuple, Union
from numpy.typing import NDArray

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from crowdkit.learning.utils import differentiable_ds


def _identity_init(shape: Union[Tuple[int, int], Tuple[int, int, int]]) -> torch.Tensor:
"""
Creates a tensor of scaled identity matrices (diagonal entries are set to 2.0).

Args:
shape (Tuple[int]): Tuple of ints representing the shape of the tensor.

Returns:
torch.Tensor: Tensor of scaled identity matrices.
"""
out = np.zeros(shape, dtype=np.float32)
if len(shape) == 3:
for r in range(shape[0]):
for i in range(shape[1]):
out[r, i, i] = 2.0
elif len(shape) == 2:
for i in range(shape[1]):
out[i, i] = 2.0
return torch.Tensor(out)


class CoNAL(nn.Module): # type: ignore
"""
Common Noise Adaptation Layers (CoNAL). This method distinguishes two sources of annotation noise: worker-specific
and global (common) confusion, each parameterized by its own confusion matrix. The mixing weight between the two
is produced by the common noise adaptation layer, a trainable function that takes the instance embedding and the
worker ID as input and outputs a scalar value between 0 and 1.

Zhendong Chu, Jing Ma, and Hongning Wang. Learning from Crowds by Modeling Common Confusions.
*Proceedings of the AAAI Conference on Artificial Intelligence*, 35(7), 5832-5840, 2021.
https://doi.org/10.1609/aaai.v35i7.16730

Examples:
>>> from crowdkit.learning import CoNAL
>>> import torch
>>> input = torch.randn(3, 5)
>>> workers = torch.tensor([0, 1, 0])
>>> embeddings = torch.randn(3, 5)
>>> conal = CoNAL(5, 2)
>>> conal(embeddings, input, workers)
"""

def __init__(
self,
num_labels: int,
n_workers: int,
com_emb_size: int = 20,
user_feature: Optional[NDArray[np.float32]] = None,
):
"""
Initializes the CoNAL module.

Args:
num_labels (int): Number of classes.
n_workers (int): Number of annotators.
com_emb_size (int): Embedding size of the common noise module.
user_feature (np.ndarray): User feature vector.
"""
super().__init__()
self.n_workers = n_workers
self.annotator_confusion_matrices = nn.Parameter(
_identity_init((n_workers, num_labels, num_labels)),
requires_grad=True,
)

self.common_confusion_matrix = nn.Parameter(
_identity_init((num_labels, num_labels)), requires_grad=True
)

user_feature = user_feature if user_feature is not None else np.eye(n_workers, dtype=np.float32)
self.user_feature_vec = nn.Parameter(
torch.from_numpy(user_feature).float(), requires_grad=False
)
self.diff_linear_1 = nn.LazyLinear(128)
self.diff_linear_2 = nn.Linear(128, com_emb_size)
self.user_feature_1 = nn.Linear(self.user_feature_vec.size(1), com_emb_size)

def simple_common_module(
self, input: torch.Tensor, workers: torch.Tensor
) -> torch.Tensor:
"""
Common noise adaptation module.

Args:
input (torch.Tensor): Tensor of shape (batch_size, embedding_size)
workers (torch.Tensor): Tensor of shape (batch_size,) containing the worker IDs.

Returns:
torch.Tensor: Tensor of shape (batch_size, 1) containing the common noise rate.
"""
instance_difficulty = self.diff_linear_1(input)
instance_difficulty = self.diff_linear_2(instance_difficulty)

instance_difficulty = F.normalize(instance_difficulty)
user_feature = self.user_feature_1(self.user_feature_vec[workers])
user_feature = F.normalize(user_feature)
common_rate = torch.sum(instance_difficulty * user_feature, dim=1)
common_rate = torch.sigmoid(common_rate).unsqueeze(1)
return common_rate

def forward(
self, embeddings: torch.Tensor, logits: torch.Tensor, workers: torch.Tensor
) -> torch.Tensor:
"""
Forward pass of the CoNAL module.

Args:
embeddings (torch.Tensor): Tensor of shape (batch_size, embedding_size)
logits (torch.Tensor): Tensor of shape (batch_size, num_classes)
workers (torch.Tensor): Tensor of shape (batch_size,) containing the worker IDs.

Returns:
torch.Tensor: Tensor of shape (batch_size, num_labels) containing the predicted output probabilities.
"""
x = embeddings.view(embeddings.size(0), -1)
common_rate = self.simple_common_module(x, workers)
common_prob = torch.einsum(
"ij,jk->ik", (F.softmax(logits, dim=-1), self.common_confusion_matrix)
)
batch_confusion_matrices = self.annotator_confusion_matrices[workers]
indivi_prob = differentiable_ds(logits, batch_confusion_matrices)
crowd_out: torch.Tensor = (
common_rate * common_prob + (1 - common_rate) * indivi_prob
) # single instance
return crowd_out
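To make the CoNAL mixing behaviour concrete, here is a minimal training sketch. The encoder, head, data shapes, and the cross-entropy objective are illustrative assumptions standing in for the loss used in the paper, not part of this PR; only CoNAL itself is used exactly as defined above.

# Hypothetical end-to-end sketch: a toy classifier produces embeddings and
# logits, and CoNAL turns them into per-worker predictions for training on
# noisy crowd labels. All names except CoNAL are illustrative assumptions.
import torch
from torch import nn

from crowdkit.learning import CoNAL

NUM_LABELS, N_WORKERS = 5, 2

encoder = nn.Sequential(nn.Linear(10, 16), nn.ReLU())  # toy embedding network
head = nn.Linear(16, NUM_LABELS)                        # toy classification head
conal = CoNAL(NUM_LABELS, N_WORKERS)

criterion = nn.CrossEntropyLoss()  # illustrative stand-in for the paper's objective

# One toy batch of crowd-labelled data: features, worker IDs, noisy labels.
x = torch.randn(8, 10)
workers = torch.randint(0, N_WORKERS, (8,))
noisy_labels = torch.randint(0, NUM_LABELS, (8,))

embeddings = encoder(x)
logits = head(embeddings)
# crowd_out mixes the common and worker-specific confusions per instance.
# This first forward pass also materializes CoNAL's LazyLinear layer, so the
# optimizer is built only afterwards.
crowd_out = conal(embeddings, logits, workers)

params = list(encoder.parameters()) + list(head.parameters()) + list(conal.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

loss = criterion(crowd_out, noisy_labels)
loss.backward()
optimizer.step()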
175 changes: 175 additions & 0 deletions crowdkit/learning/crowd_layer.py
@@ -0,0 +1,175 @@
__all__ = [
'CrowdLayer',
]

from typing import Optional

import torch
from torch import nn

from crowdkit.learning.utils import batch_identity_matrices


def crowd_layer_mw(
outputs: torch.Tensor, workers: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
"""
CrowdLayer MW transformation. Defined as multiplication by the worker's square confusion matrix.
This complies with the Dawid-Skene model.

Args:
outputs (torch.Tensor): Tensor of shape (batch_size, input_dim)
workers (torch.Tensor): Tensor of shape (batch_size,) containing the worker IDs.
weight (torch.Tensor): Tensor of shape (n_workers, num_labels, num_labels) containing the workers' confusion matrices.

Returns:
torch.Tensor: Tensor of shape (batch_size, input_dim)
"""
return torch.einsum(
"lij,ljk->lik", weight[workers], outputs.unsqueeze(-1)
).squeeze(-1)


def crowd_layer_vw(
outputs: torch.Tensor, workers: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
"""
CrowdLayer VW transformation. A linear transformation of the input without the bias.

Args:
outputs (torch.Tensor): Tensor of shape (batch_size, input_dim)
workers (torch.Tensor): Tensor of shape (batch_size,) containing the worker IDs.
weight (torch.Tensor): Tensor of shape (n_workers, num_labels) containing the worker-specific weights.

Returns:
torch.Tensor: Tensor of shape (batch_size, input_dim)
"""
return weight[workers] * outputs


def crowd_layer_vb(
outputs: torch.Tensor, workers: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
"""
CrowdLayer Vb transformation. Adds a worker-specific bias to the input.

Args:
outputs (torch.Tensor): Tensor of shape (batch_size, input_dim)
workers (torch.Tensor): Tensor of shape (batch_size,) containing the worker IDs.
weight (torch.Tensor): Tensor of shape (n_workers, num_labels) containing the worker-specific biases.

Returns:
torch.Tensor: Tensor of shape (batch_size, input_dim)
"""
return outputs + weight[workers]


def crowd_layer_vw_b(
outputs: torch.Tensor,
workers: torch.Tensor,
scale: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
"""
CrowdLayer VW + b transformation. A linear transformation of the input with the bias.

Args:
outputs (torch.Tensor): Tensor of shape (batch_size, input_dim)
workers (torch.Tensor): Tensor of shape (batch_size,) containing the worker IDs.
scale (torch.Tensor): Tensor of shape (n_workers, num_labels) containing the worker-specific weights.
bias (torch.Tensor): Tensor of shape (n_workers, num_labels) containing the worker-specific biases.

Returns:
torch.Tensor: Tensor of shape (batch_size, input_dim)
"""
return scale[workers] * outputs + bias[workers]


class CrowdLayer(nn.Module): # type: ignore
"""
CrowdLayer module for classification tasks.

This method applies a worker-specific transformation of the logits. There are four types of transformations:
- MW: Multiplication by the worker's confusion matrix.
- VW: Element-wise multiplication with the worker's weight vector.
- VB: Element-wise addition with the worker's bias vector.
- VW + b: Combination of VW and VB: VW * logits + b.

Filipe Rodrigues and Francisco Pereira. Deep Learning from Crowds.
*Proceedings of the AAAI Conference on Artificial Intelligence*, 32(1), 2018.
https://doi.org/10.1609/aaai.v32i1.11506

Examples:
>>> from crowdkit.learning import CrowdLayer
>>> import torch
>>> input = torch.randn(3, 5)
>>> workers = torch.tensor([0, 1, 0])
>>> cl = CrowdLayer(5, 2, conn_type="mw")
>>> cl(input, workers)
"""

def __init__(
self,
num_labels: int,
n_workers: int,
conn_type: str = "mw",
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
"""
Args:
num_labels (int): Number of classes.
n_workers (int): Number of workers.
conn_type (str): Connection type. One of 'mw', 'vw', 'vb', 'vw+b'.
device (torch.device): Device to use.
dtype (torch.dtype): Data type to use.
Raises:
ValueError: If conn_type is not one of 'mw', 'vw', 'vb', 'vw+b'.
"""
super(CrowdLayer, self).__init__()
self.conn_type = conn_type

self.n_workers = n_workers
if conn_type == "mw":
self.weight = nn.Parameter(
batch_identity_matrices(n_workers, num_labels, dtype=dtype, device=device)
)
elif conn_type == "vw":
self.weight = nn.Parameter(
torch.ones(n_workers, num_labels, dtype=dtype, device=device)
)
elif conn_type == "vb":
self.weight = nn.Parameter(
torch.zeros(n_workers, num_labels, dtype=dtype, device=device)
)
elif conn_type == "vw+b":
self.scale = nn.Parameter(
torch.ones(n_workers, num_labels, dtype=dtype, device=device)
)
self.bias = nn.Parameter(
torch.zeros(n_workers, num_labels, dtype=dtype, device=device)
)
else:
raise ValueError("Unknown connection type for CrowdLayer.")

def forward(self, outputs: torch.Tensor, workers: torch.Tensor) -> torch.Tensor:
"""
Forward pass.

Args:
outputs (torch.Tensor): Tensor of shape (batch_size, input_dim)
workers (torch.Tensor): Tensor of shape (batch_size,) containing the worker IDs.

Returns:
torch.Tensor: Tensor of shape (batch_size, num_labels)
"""
if self.conn_type == "mw":
return crowd_layer_mw(outputs, workers, self.weight)
elif self.conn_type == "vw":
return crowd_layer_vw(outputs, workers, self.weight)
elif self.conn_type == "vb":
return crowd_layer_vb(outputs, workers, self.weight)
elif self.conn_type == "vw+b":
return crowd_layer_vw_b(outputs, workers, self.scale, self.bias)
else:
raise ValueError("Unknown connection type for CrowdLayer.")
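A minimal usage sketch for CrowdLayer follows. The base model, the data, and the training loop are illustrative assumptions; only the CrowdLayer calls mirror the module defined above. After training, the worker-specific layer is dropped and predictions come from the base model alone.

# Hypothetical training sketch: the crowd layer re-maps the base model's
# logits into per-worker logits that are trained against noisy crowd labels.
# base_model, the data shapes, and the loss are illustrative assumptions.
import torch
from torch import nn

from crowdkit.learning import CrowdLayer

NUM_LABELS, N_WORKERS = 5, 2

base_model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, NUM_LABELS))
crowd_layer = CrowdLayer(NUM_LABELS, N_WORKERS, conn_type="mw")
optimizer = torch.optim.Adam(
    list(base_model.parameters()) + list(crowd_layer.parameters()), lr=1e-3
)
criterion = nn.CrossEntropyLoss()

# One toy batch of crowd-labelled data: features, worker IDs, noisy labels.
x = torch.randn(8, 10)
workers = torch.randint(0, N_WORKERS, (8,))
noisy_labels = torch.randint(0, NUM_LABELS, (8,))

optimizer.zero_grad()
logits = base_model(x)
# With conn_type="mw", each instance's logits are multiplied by its worker's
# confusion matrix before the loss is computed.
worker_logits = crowd_layer(logits, workers)
loss = criterion(worker_logits, noisy_labels)
loss.backward()
optimizer.step()

# At inference time, the unmodified base model is used for prediction.
predictions = base_model(x).argmax(dim=-1)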
1 change: 1 addition & 0 deletions crowdkit/learning/py.typed
@@ -0,0 +1 @@
inline
crowdkit/aggregation/texts/text_summarization.py → crowdkit/learning/text_summarization.py
@@ -1,5 +1,5 @@
__all__ = [
'TextSummarization'
'TextSummarization',
]

import itertools
@@ -10,7 +10,7 @@
import pandas as pd
from transformers import PreTrainedTokenizer, PreTrainedModel # type: ignore

from ..base import BaseTextsAggregator
from crowdkit.aggregation.base import BaseTextsAggregator


@attr.s
@@ -51,7 +51,7 @@ class TextSummarization(BaseTextsAggregator):
Example:
>>> import torch
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
>>> from crowdkit.aggregation import TextSummarization
>>> from crowdkit.learning import TextSummarization
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
>>> mname = "toloka/t5-large-for-text-aggregation"
>>> tokenizer = AutoTokenizer.from_pretrained(mname)