## Introduction

In [1]:
# Imports
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, AutoModelForCausalLM, AutoTokenizer
import json
from sklearn.model_selection import train_test_split # scikit-learn
from torch.nn import functional as F

# Constants
RANDOM_SEED = 69
DATASET_FILENAME = '../data/clean/customer_support_twitter_full.json'
CLASSIFIER_MODEL_NAME = 'distilbert-base-uncased'
GENERATOR_MODEL_NAME = 'microsoft/DialoGPT-small'
SEQUENCE_LENGTH = 2

# Setup
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x1d2dcd6de70>

In [2]:
class MultiTaskDataModule(pl.LightningDataModule):

    def __init__(self, filename: str, batch_size=4):
        super().__init__()
        self.filename = filename
        self.batch_size = batch_size
        self.conversations = self._load_conversations(filename)
        self.labels = self._determine_labels(self.conversations)
        self.classifier_tokenizer = DistilBertTokenizerFast.from_pretrained(CLASSIFIER_MODEL_NAME)
        self.generator_tokenizer = self._init_generator_tokenizer()

        data = []
        for conversation in self.conversations[:3]:
            for j, (input_message, target_message) in enumerate(zip(conversation, conversation[1:])):
                if target_message.get('authored'):
                    data.append({
                        'classifier': self._get_classifier_data(input_message),
                        'predictor': self._get_predictor_data(conversation[:j], target_message),
                        'generator': self._get_generator_data(input_message, target_message),
                    })
        print(data[:1])

        # Split the data into 80% train, 10% validation, and 10% test
        self.train_data, temp_data = train_test_split(data, test_size=0.2, random_state=RANDOM_SEED)
        self.val_data, self.test_data = train_test_split(temp_data, test_size=0.5, random_state=RANDOM_SEED)

    @staticmethod
    def _load_conversations(filename):
        with open(filename) as file:
            conversations = json.load(file)
        return conversations

    @staticmethod
    def _determine_labels(conversations):
        labels = set()
        for conversation in conversations:
            for message in conversation:
                for intent in message.get('intents'):
                    labels.add(intent)
        return sorted(list(labels))

    @staticmethod
    def _init_generator_tokenizer():
        tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL_NAME)
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        return tokenizer

    def _get_label(self, intents: list[str]):
        label = torch.zeros(len(self.labels))
        for intent in intents:
            label[self.labels.index(intent)] = 1
        return label

    def _get_classifier_data(self, message):
        inputs = self.classifier_tokenizer(
            message.get('text'), 
            padding='max_length',
            max_length=50,
            truncation=True,
            return_tensors='pt',
        )
        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'labels': self._get_label(message.get('intents')),
        }

    def _get_predictor_data(self, sub_conversation, target_message):
        labels = torch.stack(([self._get_label([])]*SEQUENCE_LENGTH)+[self._get_label(message.get('intents')) for message in sub_conversation])
        latest_labels = labels[-SEQUENCE_LENGTH:]
        return latest_labels, self._get_label(target_message.get('intents'))

    def _get_generator_data(self, input_message, target_message):
        text = f"{input_message.get('text')}{self.generator_tokenizer.eos_token}{target_message.get('text')}"
        tokens = self.generator_tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=50,
            return_tensors='pt',
        )
        input_ids = tokens['input_ids'].squeeze()
        start_idx = (input_ids == self.generator_tokenizer.eos_token_id).nonzero(as_tuple=True)[0][0]
        labels = input_ids.clone()
        labels[:start_idx+1] = -100
        return {
            'input_ids': input_ids,
            'attention_mask': tokens['attention_mask'].squeeze(),
            'labels': labels,
        }

    @property
    def stats(self):
        return '\n'.join([
            f'Conversation Count: {len(self.conversations)}',
            f'Label Counts: {self.labels}',
        ])

    @property
    def n_labels(self):
        return len(self.labels)

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

dm = MultiTaskDataModule(DATASET_FILENAME)
print(dm.stats)

[{'classifier': {'input_ids': tensor([  101,  1030, 18108,  6279,  6442, 24471,  2140,   102,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0]), 'labels': tensor([0., 0., 0., 0., 0., 0., 0., 0., 1.])}, 'predictor': (tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([1., 0., 1., 0., 1., 1., 0., 0., 0.])), 'generator': {'input_ids': tensor([   31, 16108, 15514,   220, 10289, 50256, 29904, 20608,   775,   821,
          994,   329,   345,    13,  9022,  2196,   286,   262,  89

In [7]:
class MultiTaskModel(pl.LightningModule):

    def __init__(self, n_labels):
        super().__init__()
        self.n_labels = n_labels

        # DistilBERT Classifier
        self.classifier = DistilBertForSequenceClassification.from_pretrained(
            CLASSIFIER_MODEL_NAME,
            num_labels=n_labels,
            problem_type='multi_label_classification',
            return_dict=True,
            output_hidden_states=True,
        )

        # Next Intent Predictor
        self.gru = torch.nn.GRU(
            input_size=n_labels, 
            hidden_size=self.classifier.config.dim,
            batch_first=True,
        )
        self.predictor_fc = torch.nn.Linear(self.classifier.config.dim, n_labels)

        # Text Generator
        self.generator = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL_NAME)
        self.generator.resize_token_embeddings(len(dm.generator_tokenizer))

    def forward(self, inputs):
        # Classification
        classifier_output = self.classifier(**inputs['classifier'])

        # Predictor
        classifier_hidden_state = classifier_output.hidden_states[-1][:,0,:]
        h_0 = classifier_hidden_state.unsqueeze(dim=0).contiguous()
        gru_output, _ = self.gru(inputs['predictor'][0], h_0)
        # We're interested in the last output for prediction,
        # which is the contextually richest. If x has shape (batch_size, seq_len, input_dim),
        # out will have shape (batch_size, seq_len, hidden_dim).
        # Thus, we select out[:,-1,:] to get a shape of (batch_size, hidden_dim)
        gru_output = gru_output[:, -1, :]
        fc_output = self.predictor_fc(gru_output)

        # Generator
        generator_output = self.generator(**inputs['generator'])
        
        return classifier_output, fc_output, generator_output

    # def _common_step(self, batch, batch_idx):
    #     classifier_output, predictor_output = self(batch)
    #     return classifier_output

    def _common_log(self, classifier_output, predictor_loss, generator_output):
        self.log_dict({
            'train_class_loss': classifier_output.loss,
            'train_pred_loss': predictor_loss,
            'train_gen_loss': generator_output.loss,
        }, prog_bar=True)

    def training_step(self, batch, batch_idx):
        classifier_output, predictor_output, generator_output = self(batch)
        pred_loss = F.binary_cross_entropy_with_logits(predictor_output, batch['predictor'][1])
        self._common_log(classifier_output, pred_loss, generator_output)

        total_loss = classifier_output.loss + pred_loss + generator_output.loss
        return total_loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-5)
        

model = MultiTaskModel(dm.n_labels)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias', 'pre_classifier

In [8]:
trainer = pl.Trainer(
    max_epochs=2,
    callbacks=[pl.callbacks.RichProgressBar(leave=True)],
)
trainer.fit(model, datamodule=dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

  rank_zero_warn(
  rank_zero_warn(


Output()

`Trainer.fit` stopped: `max_epochs=2` reached.


In [65]:
# print(model.classifier_tokenizer(d['input_text']))
# X = [{'input_ids': model.classifier_tokenizer(d['input_text'])} for d in data]
# data = [{'input_ids': [1,2,3]}, {'input_ids': [2,3,4]}]
# data = [{'class': {'input_ids': [1,2,3]}}, {'class': {'input_ids': [2,3,4]}}]
data = dm.train_data.copy()
loader = DataLoader(data, batch_size=2)
for batch in loader:
    print(batch)

{'classifier': {'input_ids': tensor([[  101,  1030, 18108,  6279,  6442,  2023,  2003,  2054,  2009,  3504,
          2066, 24471,  2140,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1030, 18108,  6279,  6442,  2023,  2003,  2054,  2003,  6230,
          1012,  1012,  1012, 24471,  2140,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [18]:
DistilBertForSequenceClassification.from_pretrained(CLASSIFIER_MODEL_NAME, output_hidden_states=True)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias', 'pre_classifier

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
       

In [74]:
model.classifier.config.dim

768