In [1]:
!pip install wandb pytorch-lightning transformers[sentencepiece] onnx

Collecting wandb
  Downloading wandb-0.16.3-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pytorch-lightning
  Downloading pytorch_lightning-2.2.0-py3-none-any.whl (800 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m800.3/800.3 kB[0m [31m22.5 MB/s[0m eta [36m0:00:00[0m
Collecting onnx
  Downloading onnx-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.7/15.7 MB[0m [31m57.8 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.41-py3-none-any.whl (196 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m196.4/196.4 kB[0m [31m23.7 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.40.3-py2.py3-none-any.whl (257 kB)
[2K     [90m━━━━━━━━━

In [2]:
import wandb
import pytorch_lightning as pl
import torch
from transformers import BertModel, BertConfig
from torchmetrics import Metric
import os
import onnx
import onnx.numpy_helper as numpy_helper


In [3]:
run = wandb.init()
artifact = run.use_artifact('prince_/lit-wandb/model.ckpt:v34', type='model')
artifact_dir = artifact.download()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Downloading large artifact model.ckpt:v34, 417.81MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:5.0


Read the documentation on hot to save the model in torch

In [4]:
pl.LightningModule.to_torchscript??

In [5]:
num_ner_labels, num_intent_labels = 9, 5

In [6]:
class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        self.add_state('total', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('correct', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, logits, labels, num_labels):

        flattened_targets = labels.view(-1) # shape (batch_size, sequence_len)
        active_logits = logits.view(-1, num_labels) # shape (batch_size * sequence_len, num_labels)
        flattened_predictions = torch.argmax(active_logits, axis=1) # shape (batch_size * sequence_len)

        # compute accuracy only at active labels
        active_accuracy = labels.view(-1) != -100 # shape (batch_size, sequnce_len)
        ac_labels = torch.masked_select(flattened_targets, active_accuracy)
        predictions = torch.masked_select(flattened_predictions, active_accuracy)

        self.correct += torch.sum(ac_labels == predictions)
        self.total += torch.numel(ac_labels)

    def compute(self):
        return self.correct.float() / self.total.float()

In [7]:
class MultiTaskBertModel(pl.LightningModule):

    """
    Multi-task Bert model for Named Entity Recognition (NER) and Intent Classification

    Args:
        config (BertConfig): Bert model configuration.
        num_ner_labels (int): The number of labels for NER task.
        num_intent_labels (int): The number of labels for Intent Classification task.
    """

    def __init__(self, config, num_ner_labels, num_intent_labels):
        super().__init__()

        self.num_ner_labels = num_ner_labels
        self.num_intent_labels = num_intent_labels

        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        self.model = BertModel(config=config)

        self.ner_classifier = torch.nn.Linear(config.hidden_size, self.num_ner_labels)
        self.intent_classifier = torch.nn.Linear(config.hidden_size, self.num_intent_labels)

        # log hyperparameters
        self.save_hyperparameters()

        self.accuracy = MyAccuracy()

    def forward(self, input_ids=None, attention_mask=None):

        """
        Perform a forward pass through Multi-task Bert model.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask for input tokens.
            ner_labels (torch.Tensor): Labels for NER task.
            intent_labels (torch.Tensor): Labels for Intent Classification task.

        Returns:
            Tuple[torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor]: NER loss, NER logits, Intent loss, Intent logits.

        Raises:
            ValueError: If ner_labels or intent_labels were not provided.
        """

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        ner_logits = self.ner_classifier(sequence_output)

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        intent_logits = self.intent_classifier(pooled_output)

        return ner_logits, intent_logits

    def training_step(self: pl.LightningModule, batch, batch_idx: int):
        loss, ner_logits, intent_logits, ner_labels, intent_labels = self._common_step(batch, batch_idx)
        accuracy_ner = self.accuracy(ner_logits, ner_labels, num_ner_labels)
        accuracy_intent = self.accuracy(intent_logits, intent_labels, num_intent_labels)
        self.log_dict({'training_loss': loss, 'ner_accuracy': accuracy_ner, 'intent_accuracy': accuracy_intent},
                      on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_validation_epoch_start(self):
        self.validation_step_outputs_ner = []
        self.validation_step_outputs_intent = []

    def validation_step(self, batch, batch_idx):
        loss, ner_logits, intent_logits, ner_labels, intent_labels = self._common_step(batch, batch_idx)
        # self.log('val_loss', loss)
        accuracy_ner = self.accuracy(ner_logits, ner_labels, num_ner_labels)
        accuracy_intent = self.accuracy(intent_logits, intent_labels, num_intent_labels)
        self.log_dict({'validation_loss': loss, 'val_ner_accuracy': accuracy_ner, 'val_intent_accuracy': accuracy_intent},
                      on_step=False, on_epoch=True, prog_bar=True)

        self.validation_step_outputs_ner.append(ner_logits)
        self.validation_step_outputs_intent.append(intent_logits)
        return loss

    def on_validation_epoch_end(self):
        validation_step_outputs_ner = self.validation_step_outputs_ner
        validation_step_outputs_intent = self.validation_step_outputs_intent

        dummy_input = torch.zeros((1, 128), device=self.device, dtype=torch.long)
        model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
        torch.onnx.export(self, dummy_input, model_filename)
        artifact = wandb.Artifact(name="model.ckpt", type="model")
        artifact.add_file(model_filename)
        self.logger.experiment.log_artifact(artifact)

        flattened_logits_ner = torch.flatten(torch.cat(validation_step_outputs_ner))
        flattened_logits_intent = torch.flatten(torch.cat(validation_step_outputs_intent))
        self.logger.experiment.log(
            {"valid/ner_logits": wandb.Histogram(flattened_logits_ner.to('cpu')),
             "valid/intent_logits": wandb.Histogram(flattened_logits_intent.to('cpu')),
             "global_step": self.global_step}
        )

    def _common_step(self, batch, batch_idx):
        ids = batch['input_ids']
        mask = batch['attention_mask']
        ner_labels = batch['ner_labels']
        intent_labels = batch['intent_labels']

        ner_logits, intent_logits = self.forward(ids, mask)

        criterion = torch.nn.CrossEntropyLoss()

        ner_loss = criterion(ner_logits.view(-1, self.num_ner_labels), ner_labels.view(-1))
        intent_loss = criterion(intent_logits.view(-1, self.num_intent_labels), intent_labels.view(-1))

        loss = ner_loss + intent_loss
        return loss, ner_logits, intent_logits, ner_labels, intent_labels

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer

In [8]:
config = BertConfig()

In [9]:
model = MultiTaskBertModel(config, num_ner_labels, num_intent_labels)

In [10]:
onnx_model = onnx.load(os.path.join(artifact_dir, "model_00280.onnx"))

state_dict = {}

# Iterate over the model's intializers, which contain the weights and biases
for initializer in onnx_model.graph.initializer:
    # Convert the initializer's data into Pytorch tensor
    tensor = numpy_helper.to_array(initializer)

    state_dict[initializer.name] = torch.tensor(tensor)

In [11]:
model.load_state_dict(state_dict)

RuntimeError: Error(s) in loading state_dict for MultiTaskBertModel:
	Missing key(s) in state_dict: "model.encoder.layer.0.attention.self.query.weight", "model.encoder.layer.0.attention.self.key.weight", "model.encoder.layer.0.attention.self.value.weight", "model.encoder.layer.0.attention.output.dense.weight", "model.encoder.layer.0.intermediate.dense.weight", "model.encoder.layer.0.output.dense.weight", "model.encoder.layer.1.attention.self.query.weight", "model.encoder.layer.1.attention.self.key.weight", "model.encoder.layer.1.attention.self.value.weight", "model.encoder.layer.1.attention.output.dense.weight", "model.encoder.layer.1.intermediate.dense.weight", "model.encoder.layer.1.output.dense.weight", "model.encoder.layer.2.attention.self.query.weight", "model.encoder.layer.2.attention.self.key.weight", "model.encoder.layer.2.attention.self.value.weight", "model.encoder.layer.2.attention.output.dense.weight", "model.encoder.layer.2.intermediate.dense.weight", "model.encoder.layer.2.output.dense.weight", "model.encoder.layer.3.attention.self.query.weight", "model.encoder.layer.3.attention.self.key.weight", "model.encoder.layer.3.attention.self.value.weight", "model.encoder.layer.3.attention.output.dense.weight", "model.encoder.layer.3.intermediate.dense.weight", "model.encoder.layer.3.output.dense.weight", "model.encoder.layer.4.attention.self.query.weight", "model.encoder.layer.4.attention.self.key.weight", "model.encoder.layer.4.attention.self.value.weight", "model.encoder.layer.4.attention.output.dense.weight", "model.encoder.layer.4.intermediate.dense.weight", "model.encoder.layer.4.output.dense.weight", "model.encoder.layer.5.attention.self.query.weight", "model.encoder.layer.5.attention.self.key.weight", "model.encoder.layer.5.attention.self.value.weight", "model.encoder.layer.5.attention.output.dense.weight", "model.encoder.layer.5.intermediate.dense.weight", "model.encoder.layer.5.output.dense.weight", "model.encoder.layer.6.attention.self.query.weight", "model.encoder.layer.6.attention.self.key.weight", "model.encoder.layer.6.attention.self.value.weight", "model.encoder.layer.6.attention.output.dense.weight", "model.encoder.layer.6.intermediate.dense.weight", "model.encoder.layer.6.output.dense.weight", "model.encoder.layer.7.attention.self.query.weight", "model.encoder.layer.7.attention.self.key.weight", "model.encoder.layer.7.attention.self.value.weight", "model.encoder.layer.7.attention.output.dense.weight", "model.encoder.layer.7.intermediate.dense.weight", "model.encoder.layer.7.output.dense.weight", "model.encoder.layer.8.attention.self.query.weight", "model.encoder.layer.8.attention.self.key.weight", "model.encoder.layer.8.attention.self.value.weight", "model.encoder.layer.8.attention.output.dense.weight", "model.encoder.layer.8.intermediate.dense.weight", "model.encoder.layer.8.output.dense.weight", "model.encoder.layer.9.attention.self.query.weight", "model.encoder.layer.9.attention.self.key.weight", "model.encoder.layer.9.attention.self.value.weight", "model.encoder.layer.9.attention.output.dense.weight", "model.encoder.layer.9.intermediate.dense.weight", "model.encoder.layer.9.output.dense.weight", "model.encoder.layer.10.attention.self.query.weight", "model.encoder.layer.10.attention.self.key.weight", "model.encoder.layer.10.attention.self.value.weight", "model.encoder.layer.10.attention.output.dense.weight", "model.encoder.layer.10.intermediate.dense.weight", "model.encoder.layer.10.output.dense.weight", "model.encoder.layer.11.attention.self.query.weight", "model.encoder.layer.11.attention.self.key.weight", "model.encoder.layer.11.attention.self.value.weight", "model.encoder.layer.11.attention.output.dense.weight", "model.encoder.layer.11.intermediate.dense.weight", "model.encoder.layer.11.output.dense.weight", "ner_classifier.weight". 
	Unexpected key(s) in state_dict: "onnx::MatMul_1288", "onnx::MatMul_1289", "onnx::MatMul_1295", "onnx::MatMul_1310", "onnx::MatMul_1311", "onnx::MatMul_1312", "onnx::MatMul_1313", "onnx::MatMul_1314", "onnx::MatMul_1320", "onnx::MatMul_1335", "onnx::MatMul_1336", "onnx::MatMul_1337", "onnx::MatMul_1338", "onnx::MatMul_1339", "onnx::MatMul_1345", "onnx::MatMul_1360", "onnx::MatMul_1361", "onnx::MatMul_1362", "onnx::MatMul_1363", "onnx::MatMul_1364", "onnx::MatMul_1370", "onnx::MatMul_1385", "onnx::MatMul_1386", "onnx::MatMul_1387", "onnx::MatMul_1388", "onnx::MatMul_1389", "onnx::MatMul_1395", "onnx::MatMul_1410", "onnx::MatMul_1411", "onnx::MatMul_1412", "onnx::MatMul_1413", "onnx::MatMul_1414", "onnx::MatMul_1420", "onnx::MatMul_1435", "onnx::MatMul_1436", "onnx::MatMul_1437", "onnx::MatMul_1438", "onnx::MatMul_1439", "onnx::MatMul_1445", "onnx::MatMul_1460", "onnx::MatMul_1461", "onnx::MatMul_1462", "onnx::MatMul_1463", "onnx::MatMul_1464", "onnx::MatMul_1470", "onnx::MatMul_1485", "onnx::MatMul_1486", "onnx::MatMul_1487", "onnx::MatMul_1488", "onnx::MatMul_1489", "onnx::MatMul_1495", "onnx::MatMul_1510", "onnx::MatMul_1511", "onnx::MatMul_1512", "onnx::MatMul_1513", "onnx::MatMul_1514", "onnx::MatMul_1520", "onnx::MatMul_1535", "onnx::MatMul_1536", "onnx::MatMul_1537", "onnx::MatMul_1538", "onnx::MatMul_1539", "onnx::MatMul_1545", "onnx::MatMul_1560", "onnx::MatMul_1561", "onnx::MatMul_1562", "onnx::MatMul_1563", "onnx::MatMul_1564", "onnx::MatMul_1570", "onnx::MatMul_1585", "onnx::MatMul_1586", "onnx::MatMul_1587", "onnx::MatMul_1588". 

In [None]:
scripted_model = model.to_torchscript(file_path=artifact_dir)

In [None]:
# Step 1: Save the TorchScript model to a .pt file
torch.jit.save(model, "model.pt")

# Step 2: Log the .pt file as an artifact
run = wandb.init(project="prince", job_type="model_upload")
artifact = wandb.Artifact("model", type="model")
artifact.add_file("model.pt")
run.log_artifact(artifact)

# Step 3: Finish the run
run.finish()