In [1]:
!pip install wandb pytorch-lightning transformers[sentencepiece]

Collecting wandb
  Downloading wandb-0.16.3-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pytorch-lightning
  Downloading pytorch_lightning-2.2.0-py3-none-any.whl (800 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m800.3/800.3 kB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.41-py3-none-any.whl (196 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m196.4/196.4 kB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.40.2-py2.py3-none-any.whl (257 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m257.7/257.7 kB[0m [31m21.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setp

In [2]:
import wandb
import pytorch_lightning as pl
import torch
from transformers import BertModel, BertConfig
from torchmetrics import Metric

In [3]:
# Start a W&B run; when no credentials are stored this prompts for an API key.
run = wandb.init()
# Reference version v34 of the `model.ckpt` model artifact from this run.
artifact = run.use_artifact('prince_/lit-wandb/model.ckpt:v34', type='model')
# Fetch the artifact's files locally and keep the returned local path.
artifact_dir = artifact.download()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Downloading large artifact model.ckpt:v34, 417.81MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:20.7


In [10]:
pl.LightningModule.to_torchscript??

In [None]:
# Output sizes of the two classification heads: 9 NER labels, 5 intent labels.
num_ner_labels, num_intent_labels = 9, 5

In [None]:
class MyAccuracy(Metric):
    """Accuracy over label positions, skipping every position labelled ``-100``.

    Two integer counters are kept as metric state: the number of positions
    that were scored (``total``) and how many of them were predicted
    correctly (``correct``). Both are registered with a ``sum`` reduction.
    """

    def __init__(self):
        super().__init__()
        self.add_state('total', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('correct', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, logits, labels, num_labels):
        """Accumulate the counters from one batch.

        Args:
            logits: Scores that can be viewed as ``(-1, num_labels)``.
            labels: Target class ids; entries equal to ``-100`` are ignored.
            num_labels: Size of the class dimension of ``logits``.
        """
        # Flatten both inputs so every row lines up with exactly one target.
        targets = labels.view(-1)                               # shape (N,)
        predicted = logits.view(-1, num_labels).argmax(dim=1)   # shape (N,)

        # Only positions whose target is not the ignore value are scored.
        keep = targets != -100
        kept_targets = targets[keep]
        kept_predicted = predicted[keep]

        self.correct += (kept_targets == kept_predicted).sum()
        self.total += kept_targets.numel()

    def compute(self):
        """Return ``correct / total`` as a float tensor."""
        return self.correct.float() / self.total.float()

In [None]:
class MultiTaskBertModel(pl.LightningModule):

    """
    Multi-task Bert model for Named Entity Recognition (NER) and Intent Classification

    A single Bert encoder feeds two linear heads: one applied to the per-token
    encoder output (NER) and one applied to the pooled output (intent).

    Args:
        config (BertConfig): Bert model configuration.
        num_ner_labels (int): The number of labels for NER task.
        num_intent_labels (int): The number of labels for Intent Classification task.
    """

    def __init__(self, config, num_ner_labels, num_intent_labels):
        super().__init__()

        self.num_ner_labels = num_ner_labels
        self.num_intent_labels = num_intent_labels

        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        self.model = BertModel(config=config)

        self.ner_classifier = torch.nn.Linear(config.hidden_size, self.num_ner_labels)
        self.intent_classifier = torch.nn.Linear(config.hidden_size, self.num_intent_labels)

        # log hyperparameters
        self.save_hyperparameters()

        # NOTE(review): one metric instance is shared by both tasks and by the
        # train/validation stages. The logged values below are whatever calling
        # the metric returns for that batch - confirm this matches the intended
        # torchmetrics forward semantics, or split into one metric per task/stage.
        self.accuracy = MyAccuracy()

    def forward(self, input_ids=None, attention_mask=None):

        """
        Perform a forward pass through Multi-task Bert model.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask for input tokens.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: NER logits, computed from the
            encoder's first output (per-token states), and Intent logits,
            computed from the encoder's second output (pooled state). No loss
            is computed here; see ``_common_step``.
        """

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        ner_logits = self.ner_classifier(sequence_output)

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        intent_logits = self.intent_classifier(pooled_output)

        return ner_logits, intent_logits

    def training_step(self, batch, batch_idx: int):
        """Compute the joint loss for one training batch and log loss and accuracies."""
        loss, ner_logits, intent_logits, ner_labels, intent_labels = self._common_step(batch, batch_idx)
        # Use the label counts stored on the module, not notebook-level globals,
        # so the model works for any label counts and in a fresh process.
        accuracy_ner = self.accuracy(ner_logits, ner_labels, self.num_ner_labels)
        accuracy_intent = self.accuracy(intent_logits, intent_labels, self.num_intent_labels)
        self.log_dict({'training_loss': loss, 'ner_accuracy': accuracy_ner, 'intent_accuracy': accuracy_intent},
                      on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_validation_epoch_start(self):
        """Reset the per-epoch buffers that collect validation logits."""
        self.validation_step_outputs_ner = []
        self.validation_step_outputs_intent = []

    def validation_step(self, batch, batch_idx):
        """Compute the joint loss for one validation batch, log metrics and keep the logits."""
        loss, ner_logits, intent_logits, ner_labels, intent_labels = self._common_step(batch, batch_idx)
        accuracy_ner = self.accuracy(ner_logits, ner_labels, self.num_ner_labels)
        accuracy_intent = self.accuracy(intent_logits, intent_labels, self.num_intent_labels)
        self.log_dict({'validation_loss': loss, 'val_ner_accuracy': accuracy_ner, 'val_intent_accuracy': accuracy_intent},
                      on_step=False, on_epoch=True, prog_bar=True)

        # Kept for the logit histograms logged in on_validation_epoch_end.
        self.validation_step_outputs_ner.append(ner_logits)
        self.validation_step_outputs_intent.append(intent_logits)
        return loss

    def on_validation_epoch_end(self):
        """Export the model to ONNX, upload it as an artifact and log logit histograms."""
        # Export with a dummy batch of one sequence of 128 token ids.
        dummy_input = torch.zeros((1, 128), device=self.device, dtype=torch.long)
        model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
        torch.onnx.export(self, dummy_input, model_filename)
        artifact = wandb.Artifact(name="model.ckpt", type="model")
        artifact.add_file(model_filename)
        self.logger.experiment.log_artifact(artifact)

        flattened_logits_ner = torch.flatten(torch.cat(self.validation_step_outputs_ner))
        flattened_logits_intent = torch.flatten(torch.cat(self.validation_step_outputs_intent))
        self.logger.experiment.log(
            {"valid/ner_logits": wandb.Histogram(flattened_logits_ner.to('cpu')),
             "valid/intent_logits": wandb.Histogram(flattened_logits_intent.to('cpu')),
             "global_step": self.global_step}
        )

        # Drop the references so the collected logits can be freed before the
        # next training epoch; the buffers are recreated at the next epoch start.
        self.validation_step_outputs_ner.clear()
        self.validation_step_outputs_intent.clear()

    def _common_step(self, batch, batch_idx):
        """Run the model on a batch and return the summed loss plus logits and labels.

        Args:
            batch: Mapping with the keys ``input_ids``, ``attention_mask``,
                ``ner_labels`` and ``intent_labels``.
            batch_idx: Index of the batch (unused).

        Returns:
            Tuple of (loss, ner_logits, intent_logits, ner_labels, intent_labels),
            where loss is the NER cross-entropy plus the intent cross-entropy.
        """
        ids = batch['input_ids']
        mask = batch['attention_mask']
        ner_labels = batch['ner_labels']
        intent_labels = batch['intent_labels']

        # Call the module itself rather than .forward() so module hooks run.
        ner_logits, intent_logits = self(ids, mask)

        criterion = torch.nn.CrossEntropyLoss()

        ner_loss = criterion(ner_logits.view(-1, self.num_ner_labels), ner_labels.view(-1))
        intent_loss = criterion(intent_logits.view(-1, self.num_intent_labels), intent_labels.view(-1))

        loss = ner_loss + intent_loss
        return loss, ner_logits, intent_logits, ner_labels, intent_labels

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate 1e-5."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer

In [None]:
# Bert configuration built with no overrides (library defaults).
config = BertConfig()

In [None]:
# NOTE(review): this builds the model from `config` only; nothing in this
# notebook loads the downloaded checkpoint into it - confirm whether the
# trained weights were meant to be restored before exporting.
model = MultiTaskBertModel(config, num_ner_labels, num_intent_labels)

In [None]:
# `file_path` must name the file to write. `artifact_dir` holds the path
# returned by `artifact.download()` (presumably a directory - confirm), so the
# export is written to a file inside it rather than to the directory itself.
model.to_torchscript(file_path=f"{artifact_dir}/model.pt")

In [None]:
# Compile the model to TorchScript. Switch to eval mode first so the exported
# module is not saved with dropout enabled (the model stays in eval mode
# after this cell).
model.eval()
script_model = torch.jit.script(model)

# Step 1: Save the TorchScript model to a .pt file
torch.jit.save(script_model, "model.pt")

# Step 2: Log the .pt file as an artifact
# Step 3: Leaving the `with` block finishes the run, even if logging raises.
with wandb.init(project="your_project_name", job_type="model_upload") as run:
    artifact = wandb.Artifact("model", type="model")
    artifact.add_file("model.pt")
    run.log_artifact(artifact)