In [1]:
!pip install -q datasets pytorch_lightning transformers wandb undecorated sentencepiece

In [2]:
!nvidia-smi

Sat Dec 10 05:23:58 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   70C    P8    16W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule
from pytorch_lightning import Trainer
from transformers import T5ForConditionalGeneration, T5Tokenizer
from undecorated import undecorated
from types import MethodType
import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import random

In [4]:
dataset = load_dataset("qanastek/MASSIVE", "en-US", data_dir='MASSIVE')

_INTENTS = ['audio_volume_other', 'play_music', 'iot_hue_lighton', 'general_greet', 'calendar_set', 'audio_volume_down', 'social_query', 'audio_volume_mute', 'iot_wemo_on', 'iot_hue_lightup', 'audio_volume_up', 'iot_coffee', 'takeaway_query', 'qa_maths', 'play_game', 'cooking_query', 'iot_hue_lightdim', 'iot_wemo_off', 'music_settings', 'weather_query', 'news_query', 'alarm_remove', 'social_post', 'recommendation_events', 'transport_taxi', 'takeaway_order', 'music_query', 'calendar_query', 'lists_query', 'qa_currency', 'recommendation_movies',
            'general_joke', 'recommendation_locations', 'email_querycontact', 'lists_remove', 'play_audiobook', 'email_addcontact', 'lists_createoradd', 'play_radio', 'qa_stock', 'alarm_query', 'email_sendemail', 'general_quirky', 'music_likeness', 'cooking_recipe', 'email_query', 'datetime_query', 'transport_traffic', 'play_podcasts', 'iot_hue_lightchange', 'calendar_remove', 'transport_query', 'transport_ticket', 'qa_factoid', 'iot_cleaning', 'alarm_set', 'datetime_convert', 'iot_hue_lightoff', 'qa_definition', 'music_dislikeness']
_TAGS = ['O', 'B-food_type', 'B-movie_type', 'B-person', 'B-change_amount', 'I-relation', 'I-game_name', 'B-date', 'B-movie_name', 'I-person', 'I-place_name', 'I-podcast_descriptor', 'I-audiobook_name', 'B-email_folder', 'B-coffee_type', 'B-app_name', 'I-time', 'I-coffee_type', 'B-transport_agency', 'B-podcast_descriptor', 'I-playlist_name', 'B-media_type', 'B-song_name', 'I-music_descriptor', 'I-song_name', 'B-event_name', 'I-timeofday', 'B-alarm_type', 'B-cooking_type', 'I-business_name', 'I-color_type', 'B-podcast_name', 'I-personal_info', 'B-weather_descriptor', 'I-list_name', 'B-transport_descriptor', 'I-game_type', 'I-date', 'B-place_name', 'B-color_type', 'B-game_name', 'I-artist_name', 'I-drink_type', 'B-business_name', 'B-timeofday', 'B-sport_type', 'I-player_setting', 'I-transport_agency', 'B-game_type', 'B-player_setting', 'I-music_album', 'I-event_name', 'I-general_frequency', 'I-podcast_name', 'I-cooking_type', 'I-radio_name', 'I-joke_type',
         'I-meal_type', 'I-transport_type', 'B-joke_type', 'B-time', 'B-order_type', 'B-business_type', 'B-general_frequency', 'I-food_type', 'I-time_zone', 'B-currency_name', 'B-time_zone', 'B-ingredient', 'B-house_place', 'B-audiobook_name', 'I-ingredient', 'I-media_type', 'I-news_topic', 'B-music_genre', 'I-definition_word', 'B-list_name', 'B-playlist_name', 'B-email_address', 'I-currency_name', 'I-movie_name', 'I-device_type', 'I-weather_descriptor', 'B-audiobook_author', 'I-audiobook_author', 'I-app_name', 'I-order_type', 'I-transport_name', 'B-radio_name', 'I-business_type', 'B-definition_word', 'B-artist_name', 'I-movie_type', 'B-transport_name', 'I-email_folder', 'B-music_album', 'I-house_place', 'I-music_genre', 'B-drink_type', 'I-alarm_type', 'B-music_descriptor', 'B-news_topic', 'B-meal_type', 'I-transport_descriptor', 'I-email_address', 'I-change_amount', 'B-device_type', 'B-transport_type', 'B-relation', 'I-sport_type', 'B-personal_info']


def index_to_intent(index):
    return _INTENTS[index]


def index_to_tag(index):
    return _TAGS[index]



  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
def get_intent_slots():
    intent_slots = {}
    for data in dataset['train']:
        intent = data['intent']
        if intent not in intent_slots:
            intent_slots[intent] = set(data['ner_tags'])
        else:
            intent_slots[intent] = intent_slots[intent].union(
                set(data['ner_tags']))

    named_data = {}
    for intent, slots in intent_slots.items():
        named_data[intent] = 'slots: ' + ", ".join(set(
            [index_to_tag(slot)[2:] if index_to_tag(slot) != 'O' else 'O' for slot in slots]))
    return named_data


def get_data_per_intent():
    data_per_intent = {}
    for data in dataset['train']:
        intent = data['intent']
        if intent not in data_per_intent:
            data_per_intent[intent] = [data]
        else:
            data_per_intent[intent].append(data)
    return data_per_intent

def get_min_num_of_examples_per_intent():
    data_per_intent = get_data_per_intent()
    min_num_of_examples_per_intent = min(
        [len(data_per_intent[intent]) for intent in data_per_intent])
    return min_num_of_examples_per_intent    

In [6]:
class T5GenerationFineTune(Dataset):
    def __init__(self, dataset, intent_slots, tokenizer, data_per_intent, min_examples, max_length=50):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.intent_slots = intent_slots
        self.data_per_intent = data_per_intent
        self.min_examples = min_examples

    def __len__(self):
        return len(self.dataset)

    # format of data (intent, slots_for_intent, example: text), all_examples_except_current
    def __getitem__(self, index):
        data = self.dataset[index]
        intent = data['intent']
        slots_for_intent = self.intent_slots[intent]
        text = data['utt']
        # all_examples_except_current = self.data_per_intent[intent][:index] + \
        #     self.data_per_intent[intent][index+1:]
        # all_examples_except_current = [example['utt']
        #                                for i,example in enumerate(all_examples_except_current) if i < self.min_examples]
        # print(all_examples_except_current)

        input_text = f"intent: {index_to_intent(intent)}\n{slots_for_intent}\nexample: {text}"
        tokenized_text = self.tokenizer(
            input_text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
        
        label_text = self.dataset[random.randint(0, len(self.dataset)-1)]['utt']
        # print(label_text)
        
        desired_out = self.tokenizer(self.dataset[random.randint(0, len(self.dataset)-1)]['utt'], max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt').input_ids
        # print(desired_out)
        # tokenized_examples = self.tokenizer(all_examples_except_current, max_length=self.max_length, padding='max_length',
        #                                      truncation=True, return_tensors='pt').input_ids


        return tokenized_text.input_ids.squeeze(0), tokenized_text.attention_mask.squeeze(0), desired_out.squeeze(0)

In [7]:
class T5GenerationDataModule(LightningDataModule):
    def __init__(self, dataset, tokenizer, num_workers, batch_size=8, max_length=64):
        super().__init__()
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_length = max_length
        self.num_workers = num_workers

    def setup(self, stage=None):
        self.intent_slots = get_intent_slots()
        self.data_per_intent = get_data_per_intent()
        self.min_examples = get_min_num_of_examples_per_intent()
        self.train_dataset = T5GenerationFineTune(
            self.dataset['train'], self.intent_slots, self.tokenizer, self.data_per_intent, self.min_examples, self.max_length)
        self.val_dataset = T5GenerationFineTune(
            self.dataset['validation'], self.intent_slots, self.tokenizer, self.data_per_intent, self.min_examples, self.max_length)
        self.test_dataset = T5GenerationFineTune(
            self.dataset['test'], self.intent_slots, self.tokenizer, self.data_per_intent, self.min_examples, self.max_length)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

In [8]:
class T5DataGenerator(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model = T5ForConditionalGeneration.from_pretrained(
            self.args['model_name'])

        self.loss = torch.nn.CrossEntropyLoss()

    # have the model generate multiple unique candidate outputs for a given input

    def forward(self, input_ids, attention_mask, labels):
        # print(input_ids.shape)
        # print
        return self.model(
            input_ids=input_ids,
            labels=labels
            )
            # max_length=64,
            # top_k=self.args['top_k'],
            # penalty_alpha=0.6

    # for a given training step, generate a batch of candidate outputs for a given input
    # input_ids are the tokenized examples of the data we want to generate
    # the labels are all of the potential outputs the model could generate for a given intent
    # model out shape (bs * num_return_sequences, max_seq_len)
    def training_step(self, batch, batch_idx):
        input_ids, attention_masks, examples = batch

        outputs = self(input_ids=input_ids, attention_mask=attention_masks, labels=examples)

        loss = outputs.loss

        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids, attention_masks, examples = batch

        outputs = self(input_ids=input_ids, attention_mask=attention_masks, labels=examples)

        loss = outputs.loss

        self.log('validation_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)

In [10]:
model_name = 'google/flan-t5-small'

args = {
    'model_name': model_name,
    'top_k': 5,
    'num_return_sequences': 5,
}

model = T5DataGenerator(args)
tokenizer = T5Tokenizer.from_pretrained(model_name)
data_module = T5GenerationDataModule(dataset, tokenizer, 2, batch_size=128)

logger = WandbLogger(project='t5-data-generator', name=model_name, log_model="all")
checkpoint_callback = ModelCheckpoint(monitor='validation_loss', mode='min')

Downloading:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mandrewmead[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=5, logger=logger, log_every_n_steps=25, callbacks=[checkpoint_callback])
trainer.fit(model, data_module)
wandb.finish()


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 77.0 M
1 | loss  | CrossEntropyLoss           | 0     
-----------------------------------------------------
77.0 M    Trainable params
0         Non-trainable params
77.0 M    Total params
307.845   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


0,1
epoch,▁▁▁▁▃▃▃▃▃▅▅▅▅▆▆▆▆▆█████
train_loss,█▄▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇███
validation_loss,█▂▁▁▁

0,1
epoch,4.0
train_loss,0.61081
trainer/global_step,449.0
validation_loss,0.50816


In [32]:
loader = data_module.train_dataloader()

data = next(iter(loader))
print(data[0].shape)
print(tokenizer.decode(data[0][0]))

dummy_labels = torch.zeros((data[0][0].shape), dtype=torch.long).unsqueeze(0)
out = model(data[0][0].unsqueeze(0), None, dummy_labels).logits
print(out.shape)
out = out.argmax(dim=2).squeeze(0)
print(out.shape)
print(out)
print(tokenizer.decode(out))


torch.Size([128, 64])
intent: news_query slots: slots: O, transport_type, general_frequency, time, news_topic, media_type, person, date, place_name, timeofday, device_type example: tell me the latest trending news of this week</s><pad><pad><pad><pad><pad><pad>
torch.Size([1, 64, 32128])
torch.Size([64])
tensor([125, 125, 125,  54,  54, 817, 125, 125, 817, 754,   3, 125,  54, 125,
         54, 817, 149,   3, 149, 125, 149, 817,   3, 149, 817, 817, 125, 125,
        817, 125, 103, 125, 125, 817, 125, 125, 125, 817,   3, 125, 125, 754,
        125,  54, 817, 125, 125, 125,  54, 125,   3, 817, 125, 149, 125, 125,
        125, 125, 149,   3,  54, 817, 125, 125])
what what what can can tell what what tell please  what can what can tell how  how what how tell  how tell tell what what tell what do what what tell what what what tell  what what please what can tell what what what can what  tell what how what what what what how  can tell what what


In [None]:
# 11B-XXL
# 60M
# 20x larger

# 2 mins/ epoch

# 5x memory
# 8x faster