In [1]:
import torch
from transformers import AutoTokenizer
from reduced_encoders import (
    BertReducedConfig,                          # PASS
    MPNetReducedConfig,                         # PASS
    DimReshape,                                 # PASS
    DimReduce,                                  # PASS
    DimExpand,                                  # PASS
    BertReducedModel,                           # PASS
    BertReducedForPreTraining,                  # PASS
    BertReducedForSequenceClassification,       # PASS
    MPNetReducedModel,                          # PASS
    MPNetReducedForSequenceClassification,      # PASS
    MPNetCompressedModel,                       # PASS
    MPNetCompressedForPreTraining,              # PASS
    MPNetCompressedForSequenceClassification,   # PASS
)

In [2]:
# Seed torch's global RNG once, up front, so the random inputs, the randomly
# initialised model weights and the dropout masks used by the checks below are
# identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
# Prefer the GPU when one is visible; otherwise everything runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Three short sentences fed to every tokenizer below; the second one is a
# longer variant of the first, the third is unrelated.
text = [
    "This is a test sentence to pass through my models",
    "This is a second test sentence to pass through my models that is similar",
    "The quick brown fox jumped over the lazy dog",
]

In [4]:
# Tokenize the shared sentences for the BERT-based models and move each
# returned tensor onto `device`.
bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert_inputs = dict(
    (name, tensor.to(device))
    for name, tensor in bert_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True
    ).items()
)
# The output flags ride along in the same dict so later cells can simply call
# `model(**bert_inputs)`.
bert_inputs.update(output_hidden_states=True, output_attentions=True)
# One integer label per sentence, consumed by the loss checks further down.
bert_labels = torch.tensor([0, 1, 0], device=device)



In [5]:
# Tokenize the shared sentences for the MPNet-based models and move each
# returned tensor onto `device`.
mpnet_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
mpnet_inputs = dict(
    (name, tensor.to(device))
    for name, tensor in mpnet_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True
    ).items()
)
# The output flags ride along in the same dict so later cells can simply call
# `model(**mpnet_inputs)`.
mpnet_inputs.update(output_hidden_states=True, output_attentions=True)
# One integer label per sentence, consumed by the loss checks further down.
mpnet_labels = torch.tensor([0, 1, 0], device=device)

## Test Configs

#### Initialization

In [6]:
# PASS(1/2): No errors
# Build a BertReducedConfig with no arguments and display it so the default
# values (e.g. `reduction_sizes`, `reduced_size`) can be inspected.
# `bert_config` is reused by the DimReshape / DimReduce / DimExpand cells below.
bert_config = BertReducedConfig()
bert_config

BertReducedConfig {
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert_reduced",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "reduced_size": 48,
  "reduction_sizes": [
    512,
    256,
    128,
    68,
    48
  ],
  "transformers_version": "4.40.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [7]:
# PASS(2/2): No errors
# Build an MPNetReducedConfig with no arguments and display it so the default
# values (e.g. `pooling_mode`, `reduction_sizes`) can be inspected.
# `mpnet_config` is reused by the DimReshape / DimReduce / DimExpand cells below.
mpnet_config = MPNetReducedConfig()
mpnet_config

MPNetReducedConfig {
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "mpnet_reduced",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "pooling_mode": "mean",
  "reduced_size": 48,
  "reduction_sizes": [
    512,
    256,
    128,
    68,
    48
  ],
  "relative_attention_num_buckets": 32,
  "transformers_version": "4.40.2",
  "vocab_size": 30527
}

#### Custom Initialization

In [8]:
# PASS(1/2): No errors, and custom variables were correctly set
# The overrides are asserted instead of only being eyeballed in the repr, so
# the cell fails loudly if a keyword argument stops being honoured.
custom_bert_config = BertReducedConfig(reduction_sizes=[256, 64, 32], hidden_act="relu")
assert list(custom_bert_config.reduction_sizes) == [256, 64, 32]
# In the displayed configs `reduced_size` equals the last entry of `reduction_sizes`.
assert custom_bert_config.reduced_size == 32
assert custom_bert_config.hidden_act == "relu"
custom_bert_config

BertReducedConfig {
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "relu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert_reduced",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "reduced_size": 32,
  "reduction_sizes": [
    256,
    64,
    32
  ],
  "transformers_version": "4.40.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [9]:
# PASS(2/2): No errors, and custom variables were correctly set
# The overrides are asserted instead of only being eyeballed in the repr, so
# the cell fails loudly if a keyword argument stops being honoured.
custom_mpnet_config = MPNetReducedConfig(reduction_sizes=[256, 64, 32], hidden_act="relu", pooling_mode="cls")
assert list(custom_mpnet_config.reduction_sizes) == [256, 64, 32]
# In the displayed configs `reduced_size` equals the last entry of `reduction_sizes`.
assert custom_mpnet_config.reduced_size == 32
assert custom_mpnet_config.hidden_act == "relu"
assert custom_mpnet_config.pooling_mode == "cls"
custom_mpnet_config

MPNetReducedConfig {
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "hidden_act": "relu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "mpnet_reduced",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "pooling_mode": "cls",
  "reduced_size": 32,
  "reduction_sizes": [
    256,
    64,
    32
  ],
  "relative_attention_num_buckets": 32,
  "transformers_version": "4.40.2",
  "vocab_size": 30527
}

## Test DimReshape

#### Initialization

In [10]:
# PASS(1/2): Module initialized without error
# Build a DimReshape that maps 100 input features down to 5 (see the
# LayerNorm / Linear sizes in the displayed structure), configured from
# `bert_config`. Kept as `reshape_down` for the forward-pass check below.
reshape_down = DimReshape(100, 5, bert_config)
reshape_down

DimReshape(
  (layernorm): LayerNorm((100,), eps=1e-12, elementwise_affine=True)
  (activation): GELUActivation()
  (dropout): Dropout(p=0.1, inplace=False)
  (dense): Linear(in_features=100, out_features=5, bias=True)
)

In [11]:
# PASS(2/2): Module initialized without error
# Build the opposite direction: 5 input features up to 100 (see the
# LayerNorm / Linear sizes in the displayed structure). Kept as `reshape_up`
# for the forward-pass check below.
reshape_up = DimReshape(5, 100, mpnet_config)   # Try with different config
reshape_up

DimReshape(
  (layernorm): LayerNorm((5,), eps=1e-12, elementwise_affine=True)
  (activation): GELUActivation()
  (dropout): Dropout(p=0.1, inplace=False)
  (dense): Linear(in_features=5, out_features=100, bias=True)
)

#### Forward Pass

In [12]:
# PASS(1/2): No errors, and output shape is correct
# A random batch of 3 vectors with 100 features must come out with 5 features.
# The expectation is asserted so a wrong shape fails the cell instead of
# relying on someone reading the displayed value.
reshape_down_shape = reshape_down(torch.randn(3, 100)).shape
assert reshape_down_shape == (3, 5), f"unexpected output shape: {reshape_down_shape}"
reshape_down_shape

torch.Size([3, 5])

In [13]:
# PASS(2/2): No errors, and output shape is correct
# A random batch of 3 vectors with 5 features must come out with 100 features.
# The expectation is asserted so a wrong shape fails the cell instead of
# relying on someone reading the displayed value.
reshape_up_shape = reshape_up(torch.randn(3, 5)).shape
assert reshape_up_shape == (3, 100), f"unexpected output shape: {reshape_up_shape}"
reshape_up_shape

torch.Size([3, 100])

## Test DimReduce

#### Initialization

In [14]:
# PASS (1/2): Module initialized without error, and structure looks correct
# Build the reduction stack from `bert_config` and display it; the repr shows
# a chain of DimReshape layers stepping 768 -> 512 -> 256 -> 128 -> 68 -> ...
# NOTE: `reduce` is overwritten by the next cell (same check with `mpnet_config`).
reduce = DimReduce(config=bert_config)
reduce

DimReduce(
  (0): DimReshape(
    (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=768, out_features=512, bias=True)
  )
  (1): DimReshape(
    (layernorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=512, out_features=256, bias=True)
  )
  (2): DimReshape(
    (layernorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=256, out_features=128, bias=True)
  )
  (3): DimReshape(
    (layernorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=128, out_features=68, bias=True)
  )
  (4): DimReshape(
    (layern

In [15]:
# PASS (2/2): Module initialized without error, and structure looks correct
# Same check as the previous cell but driven by `mpnet_config`. This rebinding
# of `reduce` is the instance used by the forward-pass cell below.
reduce = DimReduce(config=mpnet_config)     # Try with different config
reduce

DimReduce(
  (0): DimReshape(
    (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=768, out_features=512, bias=True)
  )
  (1): DimReshape(
    (layernorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=512, out_features=256, bias=True)
  )
  (2): DimReshape(
    (layernorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=256, out_features=128, bias=True)
  )
  (3): DimReshape(
    (layernorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=128, out_features=68, bias=True)
  )
  (4): DimReshape(
    (layern

#### Forward Pass

In [16]:
# PASS: One tensor is returned per reduction step, and all of the reduction shapes are as expected
# Only the forward call needs gradient tracking switched off; the report below
# just reads shapes from the returned sequence.
with torch.no_grad():
    output = reduce(torch.randn(3, 768))

print("Final Reduced Embedding shape:", output[-1].shape)
print("All Embedding shapes:", [step.shape for step in output])
print("Number of Reductions:", len(output))

Final Reduced Embedding shape: torch.Size([3, 48])
All Embedding shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Number of Reductions: 5


## Test DimExpand

#### Initialization

In [17]:
# PASS (1/2): Module initialized without error, and structure looks correct
# Build the expansion stack from `mpnet_config` and display it; the repr shows
# a chain of DimReshape layers stepping 48 -> 68 -> 128 -> 256 -> 512 -> ...
# NOTE: `expand` is overwritten by the next cell (same check with `bert_config`).
expand = DimExpand(config=mpnet_config)
expand

DimExpand(
  (0): DimReshape(
    (layernorm): LayerNorm((48,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=48, out_features=68, bias=True)
  )
  (1): DimReshape(
    (layernorm): LayerNorm((68,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=68, out_features=128, bias=True)
  )
  (2): DimReshape(
    (layernorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=128, out_features=256, bias=True)
  )
  (3): DimReshape(
    (layernorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=256, out_features=512, bias=True)
  )
  (4): DimReshape(
    (layernorm)

In [18]:
# PASS (2/2): Module initialized without error, and structure looks correct
# Same check as the previous cell but driven by `bert_config`. This rebinding
# of `expand` is the instance used by the forward-pass cell below.
expand = DimExpand(config=bert_config)
expand

DimExpand(
  (0): DimReshape(
    (layernorm): LayerNorm((48,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=48, out_features=68, bias=True)
  )
  (1): DimReshape(
    (layernorm): LayerNorm((68,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=68, out_features=128, bias=True)
  )
  (2): DimReshape(
    (layernorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=128, out_features=256, bias=True)
  )
  (3): DimReshape(
    (layernorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
    (activation): GELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
    (dense): Linear(in_features=256, out_features=512, bias=True)
  )
  (4): DimReshape(
    (layernorm)

#### Forward Pass

In [19]:
# PASS: One tensor is returned per expansion step, and all of the expansion shapes are as expected
# The labels below say "Expanded"/"Expansions": this cell exercises DimExpand,
# not DimReduce, and the last element is the fully expanded embedding.
with torch.no_grad():
    output = expand(torch.randn(3,48))
    print("Final Expanded Embedding shape:", output[-1].shape)
    print("All Embedding shapes:", [layer.shape for layer in output])
    print("Number of Expansions:", len(output))

Final Reduced Embedding shape: torch.Size([3, 768])
All Embedding shapes: [torch.Size([3, 68]), torch.Size([3, 128]), torch.Size([3, 256]), torch.Size([3, 512]), torch.Size([3, 768])]
Number of Expansion: 5


## Test BertReducedModel

#### Initialization

In [20]:
# PASS: Module initialized without error
# Default config -> model built directly from the config (no pretrained
# weights are loaded here), moved to `device`. `.to()` hands back the module
# itself, so `model` is the value displayed by the cell.
config = BertReducedConfig()
model = BertReducedModel(config).to(device)
model

BertReducedModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

#### Forward Pass

In [21]:
# PASS: Inputs were passed through the model, and the output shapes are as expected
# Only the forward call needs gradient tracking switched off; the report below
# just reads attributes of the returned output object.
with torch.no_grad():
    outputs = model(**bert_inputs)

print("Output keys:", vars(outputs).keys())
print("Reduced Embedding shape:", outputs.last_reduced_hidden_state.shape)
print("Full Embedding shape:", outputs.last_full_hidden_state.shape)
print("Reduced Pooled shape:", outputs.reduced_pooler_output.shape)
print("Full Pooled shape:", outputs.full_pooler_output.shape)
# `reduced_hidden_states` is indexed as a pair: [0] per-token, [1] pooled.
print(
    "Intermediate Embedding shapes:",
    [hidden.shape for hidden in outputs.reduced_hidden_states[0]],
)
print(
    "Intermediate Pooled shapes:",
    [pooled.shape for pooled in outputs.reduced_hidden_states[1]],
)
print(
    "Included hidden states and attentions:",
    bool(outputs.full_hidden_states) and bool(outputs.attentions),
)

Output keys: dict_keys(['last_reduced_hidden_state', 'last_full_hidden_state', 'reduced_pooler_output', 'full_pooler_output', 'reduced_hidden_states', 'full_hidden_states', 'past_key_values', 'attentions', 'cross_attentions'])
Reduced Embedding shape: torch.Size([3, 16, 48])
Full Embedding shape: torch.Size([3, 16, 768])
Reduced Pooled shape: torch.Size([3, 48])
Full Pooled shape: torch.Size([3, 768])
Intermediate Embedding shapes: [torch.Size([3, 16, 512]), torch.Size([3, 16, 256]), torch.Size([3, 16, 128]), torch.Size([3, 16, 68]), torch.Size([3, 16, 48])]
Intermediate Pooled shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Included hidden states and attentions: True


## Test BertReducedForPreTraining

#### Initialization

In [22]:
# PASS: Module initialized without error and the structure looks correct
# Default config -> model built directly from the config (no pretrained
# weights are loaded here), moved to `device`. `.to()` hands back the module
# itself, so `model` is the value displayed by the cell.
config = BertReducedConfig()
model = BertReducedForPreTraining(config).to(device)
model

BertReducedForPreTraining(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, ele

#### Forward Pass

In [23]:
# PASS(1/2): Inputs were passed through the model, and the output shapes are as expected
# No labels are supplied here, so the reported loss is expected to be None.
with torch.no_grad():
    outputs = model(**bert_inputs)

print("Output keys:", vars(outputs).keys())
print("Loss (Should be None):", outputs.loss)
print("MLM logits shape:", outputs.prediction_logits.shape)
print("Sequence logits:", outputs.seq_relationship_logits.shape)
# `reduced_hidden_states` is indexed as a pair: [0] per-token, [1] pooled.
print(
    "Intermediate Embedding shapes:",
    [hidden.shape for hidden in outputs.reduced_hidden_states[0]],
)
print(
    "Intermediate Pooled shapes:",
    [pooled.shape for pooled in outputs.reduced_hidden_states[1]],
)
print(
    "Included hidden states and attentions:",
    bool(outputs.full_hidden_states) and bool(outputs.attentions),
)

Output keys: dict_keys(['loss', 'prediction_logits', 'seq_relationship_logits', 'reduced_hidden_states', 'full_hidden_states', 'attentions'])
Loss (Should be None): None
MLM logits shape: torch.Size([3, 16, 30522])
Sequence logits: torch.Size([3, 2])
Intermediate Embedding shapes: [torch.Size([3, 16, 512]), torch.Size([3, 16, 256]), torch.Size([3, 16, 128]), torch.Size([3, 16, 68]), torch.Size([3, 16, 48])]
Intermediate Pooled shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Included hidden states and attentions: True


In [24]:
# PASS(2/2): The full loss is computed without raising an error
# The input ids double as the MLM labels and `bert_labels` as the
# next-sentence labels, which is enough to exercise the loss path.
# NOTE(review): padded positions are passed in `labels` unchanged — confirm
# whether they should be masked out of the MLM loss for a realistic value.
with torch.no_grad():
    outputs = model(
        **bert_inputs,
        labels=bert_inputs['input_ids'],
        next_sentence_label=bert_labels,
    )

print("Output keys:", vars(outputs).keys())
print("Loss:", outputs.loss)
print("MLM logits shape:", outputs.prediction_logits.shape)
print("Sequence logits:", outputs.seq_relationship_logits.shape)
print(
    "Intermediate Embedding shapes:",
    [hidden.shape for hidden in outputs.reduced_hidden_states[0]],
)
print(
    "Intermediate Pooled shapes:",
    [pooled.shape for pooled in outputs.reduced_hidden_states[1]],
)
print(
    "Included hidden states and attentions:",
    bool(outputs.full_hidden_states) and bool(outputs.attentions),
)

Output keys: dict_keys(['loss', 'prediction_logits', 'seq_relationship_logits', 'reduced_hidden_states', 'full_hidden_states', 'attentions'])
Loss: tensor(11.1383, device='cuda:0')
MLM logits shape: torch.Size([3, 16, 30522])
Sequence logits: torch.Size([3, 2])
Intermediate Embedding shapes: [torch.Size([3, 16, 512]), torch.Size([3, 16, 256]), torch.Size([3, 16, 128]), torch.Size([3, 16, 68]), torch.Size([3, 16, 48])]
Intermediate Pooled shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Included hidden states and attentions: True


## Test BertReducedForSequenceClassification

#### Initialization

In [25]:
# PASS: Module initialized without error
# Default config -> model built directly from the config (no pretrained
# weights are loaded here), moved to `device`. `.to()` hands back the module
# itself, so `model` is the value displayed by the cell.
config = BertReducedConfig()
model = BertReducedForSequenceClassification(config).to(device)
model

BertReducedForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps

#### Forward Pass

In [26]:
# PASS(1/2): Inputs were passed through the model, and the output shapes are as expected
# No labels are supplied here, so the reported loss is expected to be None.
with torch.no_grad():
    outputs = model(**bert_inputs)

print("Output keys:", vars(outputs).keys())
print("Loss (Should be None):", outputs.loss)
print("Logits shape:", outputs.logits.shape)
print(
    "Intermediate Pooled shapes:",
    [pooled.shape for pooled in outputs.reduced_hidden_states],
)
print(
    "Included hidden states and attentions:",
    bool(outputs.full_hidden_states) and bool(outputs.attentions),
)

Output keys: dict_keys(['loss', 'logits', 'reduced_hidden_states', 'full_hidden_states', 'attentions'])
Loss (Should be None): None
Logits shape: torch.Size([3, 2])
Intermediate Pooled shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Included hidden states and attentions: True


In [27]:
# PASS(2/2): Loss is computed without raising an error
# Same forward pass as the previous cell, now with one class label per sentence.
with torch.no_grad():
    outputs = model(**bert_inputs, labels=bert_labels)

print("Output keys:", vars(outputs).keys())
print("Loss:", outputs.loss)
print("Logits shape:", outputs.logits.shape)
print(
    "Intermediate Pooled shapes:",
    [pooled.shape for pooled in outputs.reduced_hidden_states],
)
print(
    "Included hidden states and attentions:",
    bool(outputs.full_hidden_states) and bool(outputs.attentions),
)

Output keys: dict_keys(['loss', 'logits', 'reduced_hidden_states', 'full_hidden_states', 'attentions'])
Loss: tensor(0.7110, device='cuda:0')
Logits shape: torch.Size([3, 2])
Intermediate Pooled shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Included hidden states and attentions: True


## Test MPNetReducedModel

#### Initialization

In [28]:
# PASS: Module initialized without error and the structure looks correct
# Default config -> model built directly from the config (no pretrained
# weights are loaded here), moved to `device`. `.to()` hands back the module
# itself, so `model` is the value displayed by the cell.
config = MPNetReducedConfig()
model = MPNetReducedModel(config).to(device)
model

MPNetReducedModel(
  (mpnet): MPNetModel(
    (embeddings): MPNetEmbeddings(
      (word_embeddings): Embedding(30527, 768, padding_idx=1)
      (position_embeddings): Embedding(512, 768, padding_idx=1)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): MPNetEncoder(
      (layer): ModuleList(
        (0-11): 12 x MPNetLayer(
          (attention): MPNetAttention(
            (attn): MPNetSelfAttention(
              (q): Linear(in_features=768, out_features=768, bias=True)
              (k): Linear(in_features=768, out_features=768, bias=True)
              (v): Linear(in_features=768, out_features=768, bias=True)
              (o): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
  

#### Forward Pass

In [29]:
# PASS: Inputs were passed through the model, and the output shapes are as expected
# Only the forward call needs gradient tracking switched off; the report below
# just reads attributes of the returned output object.
with torch.no_grad():
    outputs = model(**mpnet_inputs)

print("Output keys:", vars(outputs).keys())
print("Reduced Embedding shape:", outputs.last_reduced_hidden_state.shape)
print("Full Embedding shape:", outputs.last_full_hidden_state.shape)
print("Reduced Pooled shape:", outputs.reduced_pooler_output.shape)
print("Full Pooled shape:", outputs.full_pooler_output.shape)
# `reduced_hidden_states` is indexed as a pair: [0] per-token, [1] pooled.
print(
    "Intermediate Embedding shapes:",
    [hidden.shape for hidden in outputs.reduced_hidden_states[0]],
)
print(
    "Intermediate Pooled shapes:",
    [pooled.shape for pooled in outputs.reduced_hidden_states[1]],
)
print(
    "Included hidden states and attentions:",
    bool(outputs.full_hidden_states) and bool(outputs.attentions),
)

Output keys: dict_keys(['last_reduced_hidden_state', 'last_full_hidden_state', 'reduced_pooler_output', 'full_pooler_output', 'reduced_hidden_states', 'full_hidden_states', 'attentions'])
Reduced Embedding shape: torch.Size([3, 16, 48])
Full Embedding shape: torch.Size([3, 16, 768])
Reduced Pooled shape: torch.Size([3, 48])
Full Pooled shape: torch.Size([3, 768])
Intermediate Embedding shapes: [torch.Size([3, 16, 512]), torch.Size([3, 16, 256]), torch.Size([3, 16, 128]), torch.Size([3, 16, 68]), torch.Size([3, 16, 48])]
Intermediate Pooled shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Included hidden states and attentions: True


## Test MPNetReducedForSequenceClassification

#### Initialization

In [30]:
# PASS: Module initialized without error and the structure looks correct
# Default config -> model built directly from the config (no pretrained
# weights are loaded here), moved to `device`. `.to()` hands back the module
# itself, so `model` is the value displayed by the cell.
config = MPNetReducedConfig()
model = MPNetReducedForSequenceClassification(config).to(device)
model

MPNetReducedForSequenceClassification(
  (mpnet): MPNetModel(
    (embeddings): MPNetEmbeddings(
      (word_embeddings): Embedding(30527, 768, padding_idx=1)
      (position_embeddings): Embedding(512, 768, padding_idx=1)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): MPNetEncoder(
      (layer): ModuleList(
        (0-11): 12 x MPNetLayer(
          (attention): MPNetAttention(
            (attn): MPNetSelfAttention(
              (q): Linear(in_features=768, out_features=768, bias=True)
              (k): Linear(in_features=768, out_features=768, bias=True)
              (v): Linear(in_features=768, out_features=768, bias=True)
              (o): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=F

#### Forward Pass

In [31]:
# PASS(1/2): Inputs were passed through the model, and the output shapes are as expected
# No labels are supplied here, so the reported loss is expected to be None.
with torch.no_grad():
    outputs = model(**mpnet_inputs)

print("Output keys:", vars(outputs).keys())
print("Loss (Should be None):", outputs.loss)
print("Logits shape:", outputs.logits.shape)
print(
    "Intermediate Pooled shapes:",
    [pooled.shape for pooled in outputs.reduced_hidden_states],
)
print(
    "Included hidden states and attentions:",
    bool(outputs.full_hidden_states) and bool(outputs.attentions),
)

Output keys: dict_keys(['loss', 'logits', 'reduced_hidden_states', 'full_hidden_states', 'attentions'])
Loss (Should be None): None
Logits shape: torch.Size([3, 2])
Intermediate Pooled shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Included hidden states and attentions: True


In [32]:
# PASS(2/2): Loss is computed without raising an error
# Same forward pass as the previous cell, now with one class label per sentence.
with torch.no_grad():
    outputs = model(**mpnet_inputs, labels=mpnet_labels)

print("Output keys:", vars(outputs).keys())
print("Loss:", outputs.loss)
print("Logits shape:", outputs.logits.shape)
print(
    "Intermediate Pooled shapes:",
    [pooled.shape for pooled in outputs.reduced_hidden_states],
)
print(
    "Included hidden states and attentions:",
    bool(outputs.full_hidden_states) and bool(outputs.attentions),
)

Output keys: dict_keys(['loss', 'logits', 'reduced_hidden_states', 'full_hidden_states', 'attentions'])
Loss: tensor(0.7776, device='cuda:0')
Logits shape: torch.Size([3, 2])
Intermediate Pooled shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Included hidden states and attentions: True


## Test MPNetCompressedModel

#### Initialization

In [33]:
# PASS: Module initialized without error and the structure looks correct
# Default config -> model built directly from the config (no pretrained
# weights are loaded here), moved to `device`. `.to()` hands back the module
# itself, so `model` is the value displayed by the cell.
config = MPNetReducedConfig()
model = MPNetCompressedModel(config).to(device)
model

MPNetCompressedModel(
  (mpnet): MPNetModel(
    (embeddings): MPNetEmbeddings(
      (word_embeddings): Embedding(30527, 768, padding_idx=1)
      (position_embeddings): Embedding(512, 768, padding_idx=1)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): MPNetEncoder(
      (layer): ModuleList(
        (0-11): 12 x MPNetLayer(
          (attention): MPNetAttention(
            (attn): MPNetSelfAttention(
              (q): Linear(in_features=768, out_features=768, bias=True)
              (k): Linear(in_features=768, out_features=768, bias=True)
              (v): Linear(in_features=768, out_features=768, bias=True)
              (o): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )

#### Forward Pass

In [34]:
# PASS: Inputs were passed through the model, and the output shapes are as expected
# Only the forward call needs gradient tracking switched off; the report below
# just reads attributes of the returned output object.
with torch.no_grad():
    outputs = model(**mpnet_inputs)

print("Output keys:", vars(outputs).keys())
print("Reduced Embedding shape:", outputs.last_reduced_hidden_state.shape)
print("Full Embedding shape:", outputs.last_full_hidden_state.shape)
print(
    "Intermediate Reduced Embedding shapes:",
    [reduced.shape for reduced in outputs.reduced_hidden_states],
)
print(
    "Included hidden states and attentions:",
    bool(outputs.full_hidden_states) and bool(outputs.attentions),
)

Output keys: dict_keys(['last_reduced_hidden_state', 'last_full_hidden_state', 'reduced_hidden_states', 'full_hidden_states', 'attentions'])
Reduced Embedding shape: torch.Size([3, 48])
Full Embedding shape: torch.Size([3, 768])
Intermediate Reduced Embedding shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Included hidden states and attentions: True


## Test MPNetCompressedForPreTraining

#### Initialization

In [35]:
# PASS(1/2): Module initialized without error and the structure looks correct
# Default config and default constructor flags -> model built directly from
# the config, moved to `device`. `.to()` hands back the module itself, so
# `model` is the value displayed by the cell.
config = MPNetReducedConfig()
model = MPNetCompressedForPreTraining(config).to(device)
model

MPNetCompressedForPreTraining(
  (mpnet): MPNetModel(
    (embeddings): MPNetEmbeddings(
      (word_embeddings): Embedding(30527, 768, padding_idx=1)
      (position_embeddings): Embedding(512, 768, padding_idx=1)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): MPNetEncoder(
      (layer): ModuleList(
        (0-11): 12 x MPNetLayer(
          (attention): MPNetAttention(
            (attn): MPNetSelfAttention(
              (q): Linear(in_features=768, out_features=768, bias=True)
              (k): Linear(in_features=768, out_features=768, bias=True)
              (v): Linear(in_features=768, out_features=768, bias=True)
              (o): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [36]:
# PASS(2/2): Error is raised when initializing without both losses
# Disabling both objectives must be rejected with a ValueError. The message is
# printed for inspection; if construction unexpectedly succeeds, the `else`
# branch fails the cell instead of letting the check pass silently.
config = MPNetReducedConfig()
try:
    MPNetCompressedForPreTraining(config, do_contrast=False, do_reconstruction=False)
except ValueError as e:
    print(e)
else:
    raise AssertionError(
        "Expected ValueError when both do_contrast and do_reconstruction are False"
    )

At least one of do_contrast and do_reconstruction must be True


#### Forward Pass

In [37]:
# PASS(1/4): The full loss is computed without raising an error
# Both objectives are switched on explicitly, as the following checks do for
# their own combinations, so this cell still reports both losses when it is
# re-run after a cell that disabled one of them.
model.do_contrast = True
model.do_reconstruction = True
with torch.no_grad():
    outputs = model(**mpnet_inputs)
    print("Output keys:", outputs.__dict__.keys())
    print("Total Loss:", outputs.loss)
    print("Contrastive Loss:", outputs.contrastive_loss)
    print("Reconstruction Loss:", outputs.reconstruction_loss)
    print("Intermediate Reduced Embedding shapes:", [layer.shape for layer in outputs.reduced_hidden_states])
    print("Intermediate Reconstruction shapes:", [layer.shape for layer in outputs.reconstructed_hidden_states])
    print("Included hidden states and attentions:", bool(outputs.full_hidden_states and outputs.attentions))

Output keys: dict_keys(['loss', 'contrastive_loss', 'reconstruction_loss', 'reduced_hidden_states', 'reconstructed_hidden_states', 'full_hidden_states', 'attentions'])
Total Loss: tensor(1.9267, device='cuda:0')
Contrastive Loss: tensor(0.1961, device='cuda:0')
Reconstruction Loss: tensor(1.7306, device='cuda:0')
Intermediate Reduced Embedding shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Intermediate Reconstruction shapes: [torch.Size([3, 68]), torch.Size([3, 128]), torch.Size([3, 256]), torch.Size([3, 512]), torch.Size([3, 768])]
Included hidden states and attentions: True


In [38]:
# PASS(2/4): The contrastive loss is computed alone without raising an error
#            Total loss matches the contrastive loss, reconstruction loss is None
with torch.no_grad():
    # Switch the objectives on the existing model instance.
    # NOTE: these flags stay set on `model` after this cell has run.
    model.do_contrast = True
    model.do_reconstruction = False
    outputs = model(**mpnet_inputs)
    print("Output keys:", outputs.__dict__.keys())
    print("Total Loss:", outputs.loss)
    print("Contrastive Loss:", outputs.contrastive_loss)
    print("Reconstruction Loss (None):", outputs.reconstruction_loss)
    print("Intermediate Reduced Embedding shapes:", [layer.shape for layer in outputs.reduced_hidden_states])
    print("Intermediate Reconstruction shapes (None):", outputs.reconstructed_hidden_states)
    print("Included hidden states and attentions:", bool(outputs.full_hidden_states and outputs.attentions))

    # Enforce the claims made in the header comment instead of relying on a
    # visual check of the printed values.
    assert outputs.reconstruction_loss is None, "reconstruction loss should be None when do_reconstruction=False"
    assert outputs.reconstructed_hidden_states is None, "no reconstructions expected when do_reconstruction=False"
    assert torch.allclose(outputs.loss, outputs.contrastive_loss), "total loss should equal the contrastive loss"

Output keys: dict_keys(['loss', 'contrastive_loss', 'reconstruction_loss', 'reduced_hidden_states', 'reconstructed_hidden_states', 'full_hidden_states', 'attentions'])
Total Loss: tensor(0.2221, device='cuda:0')
Contrastive Loss: tensor(0.2221, device='cuda:0')
Reconstruction Loss (None): None
Intermediate Reduced Embedding shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Intermediate Reconstruction shapes (None): None
Included hidden states and attentions: True


In [39]:
# PASS(3/4): The reconstruction loss is computed alone without raising an error
#            Total loss matches the reconstruction loss, contrastive loss is None
with torch.no_grad():
    # Switch the objectives on the existing model instance.
    # NOTE: these flags stay set on `model` after this cell has run.
    model.do_contrast = False
    model.do_reconstruction = True
    outputs = model(**mpnet_inputs)
    print("Output keys:", outputs.__dict__.keys())
    print("Total Loss:", outputs.loss)
    print("Contrastive Loss (None):", outputs.contrastive_loss)
    print("Reconstruction Loss:", outputs.reconstruction_loss)
    print("Intermediate Reduced Embedding shapes:", [layer.shape for layer in outputs.reduced_hidden_states])
    print("Intermediate Reconstruction shapes:", [layer.shape for layer in outputs.reconstructed_hidden_states])
    print("Included hidden states and attentions:", bool(outputs.full_hidden_states and outputs.attentions))

    # Enforce the claims made in the header comment instead of relying on a
    # visual check of the printed values.
    assert outputs.contrastive_loss is None, "contrastive loss should be None when do_contrast=False"
    assert torch.allclose(outputs.loss, outputs.reconstruction_loss), "total loss should equal the reconstruction loss"

Output keys: dict_keys(['loss', 'contrastive_loss', 'reconstruction_loss', 'reduced_hidden_states', 'reconstructed_hidden_states', 'full_hidden_states', 'attentions'])
Total Loss: tensor(1.7378, device='cuda:0')
Contrastive Loss (None): None
Reconstruction Loss: tensor(1.7378, device='cuda:0')
Intermediate Reduced Embedding shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Intermediate Reconstruction shapes: [torch.Size([3, 68]), torch.Size([3, 128]), torch.Size([3, 256]), torch.Size([3, 512]), torch.Size([3, 768])]
Included hidden states and attentions: True


In [40]:
# PASS(4/4): With compute_full_loss=False, the total loss is computed without raising an error
with torch.no_grad():
    # Enable both objectives on the shared model before the forward pass.
    model.do_contrast = True
    model.do_reconstruction = True
    outputs = model(**mpnet_inputs, compute_full_loss=False)
    print("Output keys:", outputs.__dict__.keys())
    print("Total Loss:", outputs.loss)
    # Both objectives are enabled here, so the contrastive loss is a tensor;
    # the former "(None)" suffix on this label was a copy-paste leftover.
    print("Contrastive Loss:", outputs.contrastive_loss)
    print("Reconstruction Loss:", outputs.reconstruction_loss)
    print("Intermediate Reduced Embedding shapes:", [layer.shape for layer in outputs.reduced_hidden_states])
    print("Intermediate Reconstruction shapes:", [layer.shape for layer in outputs.reconstructed_hidden_states])
    print("Included hidden states and attentions:", bool(outputs.full_hidden_states and outputs.attentions))

Output keys: dict_keys(['loss', 'contrastive_loss', 'reconstruction_loss', 'reduced_hidden_states', 'reconstructed_hidden_states', 'full_hidden_states', 'attentions'])
Total Loss: tensor(0.4758, device='cuda:0')
Contrastive Loss (None): tensor(0.0037, device='cuda:0')
Reconstruction Loss: tensor(0.4721, device='cuda:0')
Intermediate Reduced Embedding shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Intermediate Reconstruction shapes: [torch.Size([3, 68]), torch.Size([3, 128]), torch.Size([3, 256]), torch.Size([3, 512]), torch.Size([3, 768])]
Included hidden states and attentions: True


## Test MPNetCompressedForSequenceClassification

#### Initialization

In [41]:
# PASS: Module initialized without error and the structure looks correct
# Build the classifier from a default-constructed config and move it to the
# target device; the trailing bare expression renders the module structure.
config = MPNetReducedConfig()
model = MPNetCompressedForSequenceClassification(config).to(device)
model

MPNetCompressedForSequenceClassification(
  (mpnet): MPNetModel(
    (embeddings): MPNetEmbeddings(
      (word_embeddings): Embedding(30527, 768, padding_idx=1)
      (position_embeddings): Embedding(512, 768, padding_idx=1)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): MPNetEncoder(
      (layer): ModuleList(
        (0-11): 12 x MPNetLayer(
          (attention): MPNetAttention(
            (attn): MPNetSelfAttention(
              (q): Linear(in_features=768, out_features=768, bias=True)
              (k): Linear(in_features=768, out_features=768, bias=True)
              (v): Linear(in_features=768, out_features=768, bias=True)
              (o): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplac

#### Forward Pass

In [42]:
# PASS(1/2): Inputs were passed through the model, and the output shapes are as expected
# No labels are supplied in this call, so the reported loss is expected to be None.
with torch.no_grad():
    outputs = model(**mpnet_inputs)

    pooled_shapes = [pooled.shape for pooled in outputs.reduced_hidden_states]
    has_states_and_attentions = bool(outputs.full_hidden_states and outputs.attentions)

    # Each tuple is unpacked straight into print(), giving "label value" per line.
    for line in (
        ("Output keys:", outputs.__dict__.keys()),
        ("Loss (Should be None):", outputs.loss),
        ("Logits shape:", outputs.logits.shape),
        ("Intermediate Pooled shapes:", pooled_shapes),
        ("Included hidden states and attentions:", has_states_and_attentions),
    ):
        print(*line)

Output keys: dict_keys(['loss', 'logits', 'reduced_hidden_states', 'full_hidden_states', 'attentions'])
Loss (Should be None): None
Logits shape: torch.Size([3, 2])
Intermediate Pooled shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Included hidden states and attentions: True


In [43]:
# PASS(2/2): Loss is computed without raising an error
# Same forward pass as the unlabeled check, but with labels supplied so a loss is returned.
with torch.no_grad():
    outputs = model(**mpnet_inputs, labels=mpnet_labels)

    # Labels and values are kept in parallel sequences and paired up for printing.
    captions = (
        "Output keys:",
        "Loss:",
        "Logits shape:",
        "Intermediate Pooled shapes:",
        "Included hidden states and attentions:",
    )
    values = (
        outputs.__dict__.keys(),
        outputs.loss,
        outputs.logits.shape,
        [pooled.shape for pooled in outputs.reduced_hidden_states],
        bool(outputs.full_hidden_states and outputs.attentions),
    )
    for caption, value in zip(captions, values):
        print(caption, value)

Output keys: dict_keys(['loss', 'logits', 'reduced_hidden_states', 'full_hidden_states', 'attentions'])
Loss: tensor(0.8921, device='cuda:0')
Logits shape: torch.Size([3, 2])
Intermediate Pooled shapes: [torch.Size([3, 512]), torch.Size([3, 256]), torch.Size([3, 128]), torch.Size([3, 68]), torch.Size([3, 48])]
Included hidden states and attentions: True
