In [1]:
import argparse
import logging
import math
import os
import random

import datasets
from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm

import transformers
from accelerate import Accelerator
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    PretrainedConfig,
    SchedulerType,
    default_data_collator,
    get_scheduler,
    set_seed,
)
from transformers.utils.versions import require_version

# My custom model
from models import BertForSequenceClassification
from models import BertConfig
import torch
import matplotlib.pyplot as plt

In [2]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# change model pretrained path here
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [3]:
inputs = tokenizer(["Hello, my dog is cute and I am the biggest person in the world", "Yo yo"], 
                   max_length = 128, padding='max_length', truncation=True, return_tensors="pt")
labels = torch.tensor([1, 0]).unsqueeze(0)  # Batch size 1

In [4]:
# first phrase
outputs = model(**inputs, labels=labels)

In [9]:
inputs = tokenizer(["Hello, my dog is cute and I am the biggest person in the world"], 
                   max_length = 128, padding='max_length', truncation=True, return_tensors="pt")
labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1

In [10]:
# second phrase
outputs = model.exit_forward(**inputs, labels=labels)

In [None]:
model.eval()

In [11]:
outputs = model.exit_inference(**inputs)

In [12]:
outputs

SequenceClassifierOutput(loss=None, logits=tensor([[-0.5716, -0.3170]], grad_fn=<AddmmBackward>), hidden_states=(tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [ 3.7386e-01, -1.5575e-02, -2.4561e-01,  ..., -3.1657e-02,
           5.5144e-01, -5.2406e-01],
         [ 4.6709e-04,  1.6225e-01, -6.4443e-02,  ...,  4.9443e-01,
           6.9413e-01,  3.6286e-01],
         ...,
         [ 2.6543e-02, -1.7981e-01,  4.8942e-01,  ..., -5.3213e-01,
          -2.4651e-01,  4.6587e-01],
         [-1.1945e-01, -2.3242e-01,  2.4277e-01,  ..., -4.7397e-01,
          -1.3933e-01,  3.1804e-01],
         [ 1.1076e-01, -7.5039e-02,  3.2258e-01,  ..., -7.3081e-02,
          -5.1369e-01,  2.2897e-01]]], grad_fn=<NativeLayerNormBackward>), tensor([[[ 0.0756,  0.0418, -0.2009,  ...,  0.1857, -0.0269,  0.0443],
         [ 0.4123,  0.1266,  0.3189,  ...,  0.4447,  0.9412, -0.3549],
         [-0.1503,  0.3128, -0.0155,  ..., -0.0342,  0.7499,  0

In [8]:
outputs.loss

tensor(7.9250, grad_fn=<AddBackward0>)

In [13]:
for name, param in model.named_parameters():
    if "exit_port" not in name:
        print(name)

bert.embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.weight
bert.embeddings.LayerNorm.bias
bert.encoder.layer.0.attention.self.query.weight
bert.encoder.layer.0.attention.self.query.bias
bert.encoder.layer.0.attention.self.key.weight
bert.encoder.layer.0.attention.self.key.bias
bert.encoder.layer.0.attention.self.value.weight
bert.encoder.layer.0.attention.self.value.bias
bert.encoder.layer.0.attention.output.dense.weight
bert.encoder.layer.0.attention.output.dense.bias
bert.encoder.layer.0.attention.output.LayerNorm.weight
bert.encoder.layer.0.attention.output.LayerNorm.bias
bert.encoder.layer.0.intermediate.dense.weight
bert.encoder.layer.0.intermediate.dense.bias
bert.encoder.layer.0.output.dense.weight
bert.encoder.layer.0.output.dense.bias
bert.encoder.layer.0.output.LayerNorm.weight
bert.encoder.layer.0.output.LayerNorm.bias
bert.encoder.layer.1.attention.self.query.weight
bert.enc