In [1]:
%reload_ext autoreload
%autoreload 2

In [25]:
from pprint import pprint

In [3]:
from transformers import AutoModel

In [11]:
from sentform.modeling import SentenceTransformer
from sentform.pooling import MeanPooling
from sentform.utils import pairwise_cosine_similarity

# Sanity-check SentenceTransformer

In [12]:
backbone = AutoModel.from_pretrained("bert-base-uncased")

In [13]:
sentformer = SentenceTransformer(
    backbone=backbone,
    pooling_layer=MeanPooling()
)



In [14]:
sentformer

SentenceTransformer(
  (backbone): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, e

In [15]:
sentformer.embedding_dim

768

In [16]:
sentences = [
    "I love cat.",
    "I prefer dogs.",
    "I didn't like that movie."
]

In [17]:
embeddings = sentformer.encode(sentences)
embeddings.shape

torch.Size([3, 768])

In [18]:
pairwise_cosine_similarity(embeddings)

tensor([[1.0000, 0.7155, 0.6561],
        [0.7155, 1.0000, 0.6478],
        [0.6561, 0.6478, 1.0000]])

# Sanity-Check Multi-Task learner

In [20]:
from sentform.modeling import MultiTaskFormer
from sentform.heads import ClassificationHead, NERHead

In [21]:
multi_tasker = MultiTaskFormer(
    heads=[
        ClassificationHead(
            backbone.config.hidden_size,
            num_classes=3,
            labels=["A", "B", "C"],
            multi_label=False
        ),
        NERHead(
            backbone.config.hidden_size,
            num_tags=3,
            ner_tags=["Entity1", "Entity2", "Entity2"],
            multi_label=False
        )
    ],
    backbone=backbone,
)



In [22]:
outputs = multi_tasker(sentences)

In [23]:
outputs

{'head_0': {'logits': tensor([[-0.1184, -0.1660,  0.2809],
          [-0.0499, -0.0170,  0.3292],
          [ 0.1075, -0.0941,  0.2124]]),
  'predicted_labels': ['C', 'C', 'C']},
 'head_1': {'logits': tensor([[[ 0.0434,  0.2204,  0.5106],
           [ 0.1505,  0.5894,  0.2920],
           [-0.0443,  0.2464,  0.2912],
           [ 0.1045,  0.5564,  0.5694],
           [ 0.0764,  0.6265,  0.1294],
           [ 0.2616, -0.1412, -0.1685],
           [ 0.0261,  0.5208,  0.1622],
           [ 0.1476,  0.6539,  0.2159],
           [ 0.0511,  0.5468,  0.2664],
           [ 0.2583,  0.7474,  0.2863]],
  
          [[ 0.0248, -0.0257,  0.4707],
           [-0.0033,  0.3054,  0.1180],
           [-0.2258, -0.1988,  0.1635],
           [ 0.1792,  0.0022,  0.8497],
           [ 0.1765,  0.3426, -0.2418],
           [ 0.1304, -0.1063, -0.3706],
           [ 0.0858,  0.3002,  0.1983],
           [ 0.1181,  0.3982,  0.2706],
           [ 0.0992,  0.2153,  0.2534],
           [ 0.2846,  0.6430,  0.3579

In [26]:
for idx, out in outputs.items():
    print(idx, out["logits"].shape)
    pprint(out["predicted_labels"])
    print("-"*10)

head_0 torch.Size([3, 3])
['C', 'C', 'C']
----------
head_1 torch.Size([3, 10, 3])
[['Entity2',
  'Entity2',
  'Entity2',
  'Entity2',
  'Entity2',
  'Entity1',
  'Entity2',
  'Entity2',
  'Entity2',
  'Entity2'],
 ['Entity2',
  'Entity2',
  'Entity2',
  'Entity2',
  'Entity2',
  'Entity1',
  'Entity2',
  'Entity2',
  'Entity2',
  'Entity2'],
 ['Entity2',
  'Entity1',
  'Entity1',
  'Entity1',
  'Entity1',
  'Entity2',
  'Entity2',
  'Entity2',
  'Entity1',
  'Entity1']]
----------
