In [21]:
import torch
from pytorch_lightning import LightningModule, LightningDataModule
from torch import nn

In [2]:
from transformers import AutoModel

In [13]:
class MLP(LightningModule):
    """Feed-forward binary classifier over 512-dimensional inputs.

    The network is ``n_hidden_layers`` repetitions of
    ``Linear(512, 512) -> ReLU`` followed by a ``Linear(512, 1)`` head that
    emits one raw logit per input row.

    Args:
        n_hidden_layers: Number of ``Linear + ReLU`` blocks placed before the
            output head. Recorded in ``self.hparams``.
        hidden_layer_size: Accepted but currently unused; every layer is
            512 wide regardless of this value, and it is excluded from the
            saved hyperparameters.
            NOTE(review): confirm whether this was meant to set the hidden
            width before wiring it in, since that would change the default
            architecture.
    """

    def __init__(self, n_hidden_layers, hidden_layer_size=256):
        super().__init__()
        self.save_hyperparameters(ignore=["hidden_layer_size"])

        # Input and hidden widths are both fixed at 512.
        width = 512

        # Build the network
        hidden_layers = []
        for _ in range(self.hparams["n_hidden_layers"]):
            hidden_layers.extend(
                [nn.Linear(in_features=width, out_features=width), nn.ReLU()]
            )

        self.net = nn.Sequential(
            *hidden_layers, nn.Linear(in_features=width, out_features=1)
        )

    def forward(self, inputs):
        """Return the raw logits of ``self.net`` for ``inputs``.

        The forward pass is free of side effects: it no longer overwrites
        ``self.n_hidden_layers`` with a constant that disagreed with the
        depth the network was actually built with.
        """
        return self.net(inputs)

    def train_step(self, batch):
        """Compute the binary cross-entropy-with-logits loss for one batch.

        Args:
            batch: An ``(inputs, targets)`` pair. The flattened model output
                must have the same shape as ``targets``.

        Returns:
            The loss tensor (mean reduction, the library default).
        """
        inputs, targets = batch
        # Flatten the network output so it lines up with the targets.
        outputs = self(inputs).reshape(-1)
        # Functional form: same default mean reduction as
        # nn.BCEWithLogitsLoss(), without building a module on every step.
        return nn.functional.binary_cross_entropy_with_logits(outputs, targets)

    def training_step(self, batch, batch_idx=None):
        """Lightning training hook; delegates to :meth:`train_step`.

        The Lightning ``Trainer`` looks for ``training_step``, so the loss
        defined in ``train_step`` would otherwise never be used during
        ``fit``. ``batch_idx`` is accepted to match the hook signature and
        is ignored.
        """
        return self.train_step(batch)

In [14]:
# Load the pretrained "bert-base-cased" checkpoint through AutoModel. The
# warning printed below shows it resolves to a plain BertModel, so the
# checkpoint's `cls.*` pre-training head weights are dropped.
# NOTE(review): `bert` is not referenced by any later cell visible here.
bert = AutoModel.from_pretrained("bert-base-cased")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
# Five hidden Linear+ReLU blocks; hidden_layer_size keeps its default.
model = MLP(n_hidden_layers=5)

In [17]:
# Synthetic smoke-test batch: 16 rows of 512 uniform-random features,
# each paired with a random 0/1 label stored as float32.
inputs = torch.rand((16, 512))
targets = torch.randint(low=0, high=2, size=(16,)).to(torch.float32)
batch = inputs, targets

In [18]:
# Forward pass on the random batch: one raw logit per input row, shown
# below as a (16, 1) tensor.
model(inputs)

tensor([[-0.0435],
        [-0.0429],
        [-0.0422],
        [-0.0429],
        [-0.0425],
        [-0.0440],
        [-0.0442],
        [-0.0424],
        [-0.0423],
        [-0.0438],
        [-0.0423],
        [-0.0424],
        [-0.0452],
        [-0.0438],
        [-0.0405],
        [-0.0428]], grad_fn=<AddmmBackward>)

In [19]:
# Loss for the random batch (BCE with logits); the result still carries
# a grad_fn, as the output below shows.
model.train_step(batch)

tensor(0.6986, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [38]:
from pl_modules import PlLanguageModelForSequenceOrdering

In [48]:
from argparse import Namespace

# Minimal argument namespace for the module: only the checkpoint name is
# set. Passing the field to the constructor replaces the two identical
# setattr calls that were here before.
args = Namespace(model_name_or_path="bert-base-cased")

m = PlLanguageModelForSequenceOrdering(args)

['__class__', '__contains__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_get_args', '_get_kwargs', 'model_name_or_path']


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas