# Continual Learning for personalized pictogram recommendation using PictoBERT

A pictogram is a picture with a label that denotes an action, object, person, animal, or place. Predicting the next pictogram to be set in a sentence in construction is an essential feature for AAC boards to facilitate communication.

PictoBERT, an adaptation of BERT for the next pictogram prediction task, with changed input embeddings to allow word-sense usage instead of words, considering that a word-sense represents a pictogram.

Continual learning (CL) aims to adaptively learn across time by leveraging previously learned data to improve generalization for future data.

In [1]:
!pip install gdown

Collecting gdown
  Downloading gdown-5.2.0-py3-none-any.whl.metadata (5.8 kB)
Downloading gdown-5.2.0-py3-none-any.whl (18 kB)
Installing collected packages: gdown
Successfully installed gdown-5.2.0


In [2]:
# For data downloads
import gdown
import pickle
import json
import pandas as pd

# For ARASAAC API
import requests

# For model
import torch
# from transformers import IdeficsForVisionText2Text, AutoProcessor, BitsAndBytesConfig

# For text processing
import string

# For dataset creation
# from datasets import Dataset
import random

import gc
from IPython.display import Markdown, display

def display_markdown(markdown_text: str):
    display(Markdown(markdown_text))


## Data SemChilds

The data is the automatically annotated the North American English part of the Child Language Data Exchange System (CHILDES) corpus (MacWhinney, 2014) with word-senses to used as a training corpus for PictoBert, called the SemCHILDES.

In [3]:
train_data_url = 'https://drive.google.com/uc?id=1cGD5317jxdVFLk35KGkaTErYnSyGrJDr'
train_data_output = 'train_data.pt'
gdown.download(train_data_url, train_data_output, quiet=False)
with open(train_data_output, 'rb') as file:
    # Load the object from the file
    train_data = pickle.load(file)

eval_data_url = 'https://drive.google.com/uc?id=1a4crU1Vq6ujRXmcceVeK0qDXu-H31Ml2'
eval_data_output = 'eval_data.pt'
gdown.download(eval_data_url, eval_data_output, quiet=False)
with open(eval_data_output, 'rb') as file:
    # Load the object from the file
    eval_data = pickle.load(file)

data_url = 'https://drive.google.com/uc?id=1qM4ZeSs51QO85CAlSBgjqYAnyL8ySq6d'
data_output = 'test_data.pt'
gdown.download(data_url, data_output, quiet=False)
with open(data_output, 'rb') as file:
    # Load the object from the file
    test_data = pickle.load(file)


Downloading...
From (original): https://drive.google.com/uc?id=1cGD5317jxdVFLk35KGkaTErYnSyGrJDr
From (redirected): https://drive.google.com/uc?id=1cGD5317jxdVFLk35KGkaTErYnSyGrJDr&confirm=t&uuid=b6eef3c1-7250-476d-9728-c234981b97ac
To: /kaggle/working/train_data.pt
100%|██████████| 247M/247M [00:01<00:00, 162MB/s] 
Downloading...
From: https://drive.google.com/uc?id=1a4crU1Vq6ujRXmcceVeK0qDXu-H31Ml2
To: /kaggle/working/eval_data.pt
100%|██████████| 2.52M/2.52M [00:00<00:00, 56.3MB/s]
Downloading...
From: https://drive.google.com/uc?id=1qM4ZeSs51QO85CAlSBgjqYAnyL8ySq6d
To: /kaggle/working/test_data.pt
100%|██████████| 2.52M/2.52M [00:00<00:00, 158MB/s]


In [5]:
keys = train_data.keys()
keys

dict_keys(['input_ids', 'attention_mask', 'special_tokens_mask', 'ngrams'])

### Data Visualization

In [6]:
for j in ['input_ids', 'attention_mask', 'special_tokens_mask']:
    for i in train_data[j][:1]:
        print(f"{j}: ", i)

input_ids:  [13580, 190, 7, 11, 6, 1302, 205, 1, 13579, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581]
attention_mask:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
special_tokens_mask:  [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [7]:
len(train_data['input_ids'])

936379

### Using half the data for training to reduce the training time and due to memory bounds

In [4]:
set_size = len(train_data['input_ids']) // 2

# Reduce dataset size by half
reduced_dataset = {
    "input_ids": train_data["input_ids"][:set_size],
    "attention_mask": train_data["attention_mask"][:set_size],
    "special_tokens_mask": train_data["special_tokens_mask"][:set_size]
}

len(reduced_dataset['input_ids'])

468189

### Specifing the parameters

In [5]:
MAX_EPOCHS = 3
WARMUP_STEPS = int(MAX_EPOCHS * 0.15)
BATCH_SIZE = 32
NUM_WORKERS = 2
GPUS = 1
LEARNING_RATE = 1e-06
ACCUMULATE_GRAD_BATCHES = 4
LOGGER_VERSION = '1e06'
LOGGER_INFO = "first_run"
FREEZE_TO = None
MLM_PROBABILITY= 0.15

### CHILDES Tokenizer

In [7]:
from transformers import PreTrainedTokenizerFast
TOKENIZER_PATH = "/kaggle/input/tokenizer/childes_all_new.json"
f = open(TOKENIZER_PATH)
semchilds_tokenizer = json.load(f)
#TOKENIZER_PATH = "/content/drive/MyDrive/TESI/DOCS/tokenizer_arasaac.json"
#TOKENIZER_PATH = "tokenizer_arasaac.json"
loaded_tokenizer = PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_PATH)
loaded_tokenizer.pad_token = "[PAD]"
loaded_tokenizer.sep_token = "[SEP]"
loaded_tokenizer.mask_token = "[MASK]"
loaded_tokenizer.cls_token = "[CLS]"
loaded_tokenizer.unk_token = "[UNK]"

Using the CHILDES dictionary to create a mapping to decode the tokenized sentences.

In [7]:
word_sense_to_token_dict = semchilds_tokenizer["model"]["vocab"]

token_to_word_sense_dict = {v: k for k, v in word_sense_to_token_dict.items()}

special_tokens_dict = {i["id"]:i["content"] for i in semchilds_tokenizer["added_tokens"]}
special_tokens_dict

{0: '[UNK]', 13579: '[SEP]', 13580: '[CLS]', 13581: '[PAD]', 13582: '[MASK]'}

In [16]:
sentences = []
for i in train_data["input_ids"][2:5]:
  print(i)
  sentences.append(i)

for i in sentences:
  for j in i:
    if j in token_to_word_sense_dict.keys():
      print(token_to_word_sense_dict[j])
    elif j in special_tokens_dict.keys() and j != 13581:
      print(special_tokens_dict[j])
  print("")

[13580, 49, 256, 1277, 8, 834, 110, 1, 13579, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581]
[13580, 6, 149, 484, 76, 327, 1, 13579, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581]
[13580, 5, 8, 26, 639, 7, 23, 132, 157, 1, 13579, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581, 13581]
[CLS]
okay%5:00:00:satisfactory:00
next%5:00:00:succeeding:00
day%1:28:01::
be%2:42:06::
sunday%1:28:00::
again%4:02:00::
.
[SEP]

[CLS]
a
horse%1:05:00::
inside%4:02:00::
or
outside%4:02:00::
.
[SEP]

[CLS]
i
be%2:42:06::
go_to%2:42:00::
try%2:37:00::
it
in%4:02:01::
chair%1:06:00::
dad%1:18:00::
.
[SEP]



In [17]:
loaded_tokenizer.mask_token_id

13582

### Loading the data and defining the dataloader

In [8]:
from torch.utils.data import Dataset, Subset
from torch import tensor
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
import pickle

class MyDataset(Dataset):
    def __init__(self, examples):

        self.input_ids = examples['input_ids']
        self.attention_mask = examples['attention_mask']
        self.special_tokens_mask = examples['special_tokens_mask']
        self.labels = None
        if 'labels' in examples:
            self.labels = examples['labels']
        self.pad_token_id = loaded_tokenizer.pad_token_id
        self.max_len = 32

    def __len__(self):
        return len(self.input_ids)

    def pad(self, sequence, pad_length, pad_value):
        """
        Pads a sequence with a specific value to a desired length.

        Args:
          sequence: The sequence to be padded.
          pad_length: The desired length after padding.
          pad_value: The value to use for padding.

        Returns:
          The padded sequence.
        """
        return sequence + [pad_value] * pad_length


    def __getitem__(self, idx):
        input_ids = tensor(self.input_ids[idx])
        attention_mask = tensor(self.attention_mask[idx])
        special_tokens_mask = tensor(self.special_tokens_mask[idx])

        out_dict = {
          "input_ids":input_ids,
          "attention_mask":attention_mask,
          "special_tokens_mask":special_tokens_mask
        }

        if self.labels is not None:
            out_dict['labels'] = self.labels[idx]

        return out_dict


train_dataset = MyDataset(reduced_dataset)

val_dataset = MyDataset(eval_data)

test_dataset = MyDataset(test_data)



In [9]:
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(tokenizer=loaded_tokenizer, mlm_probability=MLM_PROBABILITY)

def data_collator_new(examples):

    batch = {
      "input_ids" : torch.stack([example['input_ids'] for example in examples]),
      "attention_mask": torch.stack([example['attention_mask'] for example in examples]),
  }
  # Clone input_ids to labels
    labels = batch["input_ids"].clone()
  # Initially set all labels to -100 to ignore them in the loss computation
    labels.fill_(-100)

  # Iterate over each sequence in the batch
    for idx, sequence in enumerate(batch["input_ids"]):
      # Find the last occurrence of the specific token
        token_positions = (sequence == token_id).nonzero(as_tuple=True)[0]
        if len(token_positions) > 0:
            last_token_position = token_positions[-1]
          # Set labels for tokens after the specific token to their actual values
            if last_token_position + 1 < sequence.size(0):  # Check if there is a next token
                labels[idx, last_token_position + 1:] = batch["input_ids"][idx, last_token_position + 1:]

  # Update the inputs dictionary to include the adjusted labels
    batch['labels'] = labels

    return batch

In [10]:
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=data_collator,
    drop_last = True,
    shuffle=True
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=data_collator,
    drop_last = True
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=data_collator,
    pin_memory=True,
    drop_last = True
)

In [12]:
for batch in train_dataloader:
    print("Input ids: ", batch['input_ids'], batch['input_ids'].shape)
    print("labels: ", batch['labels'], batch['labels'].shape)
    break

Input ids:  tensor([[13580,    15,    22,  ..., 13581, 13581, 13581],
        [13580,     5, 13582,  ..., 13581, 13581, 13581],
        [13580, 13582,    19,  ..., 13581, 13581, 13581],
        ...,
        [13580,    40,    72,  ..., 13581, 13581, 13581],
        [13580,     5,    21,  ..., 13581, 13581, 13581],
        [13580,    17,    19,  ..., 13581, 13581, 13581]]) torch.Size([32, 32])
labels:  tensor([[-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100,  442,  ..., -100, -100, -100],
        [-100,    5, -100,  ..., -100, -100, -100],
        ...,
        [-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100]]) torch.Size([32, 32])


## Loading the Pretrained Model PictoBERT

In [13]:
from transformers import BertForMaskedLM
# from torchsummary import summary

pictobert = BertForMaskedLM.from_pretrained("/kaggle/input/pictobert-model/SemChilds_model", output_hidden_states=True)
pictobert.config.vocab_size

13583

# Model 1 : Experience Replay

* Model Overview:
   The model, LitBertClassifier, is built using PyTorch Lightning and utilizes a pretrained BERT-based architecture (BertForMaskedLM: PictoBERT) for next word prediction tasks.
   It includes an external memory buffer mechanism to store and sample data for continual learning.
   - The `memory buffer` is designed to store past training samples to facilitate experience replay (ER), a technique commonly used in continual learning.
   - ER helps `mitigate catastrophic forgetting`, a common problem in neural networks when they are trained incrementally on new tasks. 
   - The buffer allows the model to remember and rehearse past data, effectively balancing learning between new and old data.


* Key Components:

  - <b>BERT Architecture</b>: Uses a pre-trained BERT model (pictobert) to encode input text, which is further fine-tuned for specific tasks.
  - <b>Memory Buffer</b>: Implements an experience replay (ER) mechanism with a fixed memory size (`mem_sz = 1000`) to store and retrieve batches during training.
  - <b>Custom Memory Management</b>: Contains methods to initialize, update, sample from, and combine memory batches (<i>update_memory, sample_from_memory, combine_batches</i>).

* Training and Optimization:

  - <b>Optimizer</b>: Uses the `AdamW optimizer` with a polynomial decay learning rate scheduler (get_polynomial_decay_schedule_with_warmup).
  - <b>Freezing Layers</b>: Option to freeze specific layers of the BERT model for more efficient fine-tuning.

* Evaluation:
  
  - The use of experience replay in the new model has resulted in a substantial `87.43% decrease in loss` and a `13.31% decrease in perplexity`, demonstrating significantly enhanced performance, indicating significantly better convergence and optimization.


In [14]:
import random
import torch
import torch.nn as nn
from torch import optim
from torch.nn import functional as F
from sklearn.metrics import f1_score
from transformers import BertForMaskedLM, AdamW, BertTokenizer
from torch.utils.data import DataLoader, ConcatDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import get_polynomial_decay_schedule_with_warmup
from scipy import stats
import numpy as np


class ER_PictoBERT(pl.LightningModule):
    def __init__(self, pretrained_model_name='bert-large-uncased', mem_sz=1000, alpha=1e-3, beta=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.batch_size = 8
        self.lr = LEARNING_RATE
        self.validation_step_outputs = []
        self.validation_step_targets = []
        self.mem_sz = mem_sz
        self.ptr = 0  # Pointer to the current memory index
        self.size = 0  # Current size of the memory buffer
        self.alpha = alpha
        self.beta = beta
        self.memory = {
            "input_ids": None,
            "attention_mask": None,
            "labels": None
        }  # Initialize the memory buffer for ER
        self.bert = pictobert


    def freeze_to(self, layers):
        for param in self.bert.bert.encoder.layer[:layers].parameters():
            param.requires_grad = False

    def forward(self, input_ids, attention_mask, labels=None):
        # Check for invalid values in input_ids
        if torch.any(input_ids < 0) or torch.any(input_ids >= self.bert.config.vocab_size):
            print(self.bert.config.vocab_size)
            print("Invalid input_ids detected!")
            print("Min value:", torch.min(input_ids))
            print("Max value:", torch.max(input_ids))
        if labels is None:
            output = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
        output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        return output

    def combine_batches(self, Bn, BM):
            # Get the device of the current batch (Bn)
        device = Bn["input_ids"].device

        # Move BM (memory batch) to the same device as Bn if they are not already on the same device
        BM = {key: value.to(device) for key, value in BM.items()}
        
        return {"input_ids": torch.cat([Bn["input_ids"], BM["input_ids"]], dim=0),
                             "attention_mask": torch.cat([Bn["attention_mask"], BM["attention_mask"]], dim=0),
                             "labels": torch.cat([Bn["labels"], BM["labels"]], dim=0)}

    def sample_from_memory(self, sample_size):
        if self.is_empty():
            return None

        # Sample random indices from memory
        indices = torch.randperm(self.memory["input_ids"].size(0))[:sample_size]

        sampled_batch = {
            "input_ids": self.memory["input_ids"][indices],
            "attention_mask": self.memory["attention_mask"][indices],
            "labels": self.memory["labels"][indices]
        }
        return sampled_batch

    def training_step(self, batch, batch_idx):
        # Current batch from general training data

        if self.is_empty():
            print(self.is_empty())
            combined_batch = batch
        # Combine current batch and personalized batch (with potential memory data)
        else : 
            BM = self.sample_from_memory(min(len(self.memory), self.batch_size))
            combined_batch = self.combine_batches(batch, BM)
            
        # Train on the combined batches
        # for combined_batch in combined_dataloader:
        outputs = self._shared_step(combined_batch, batch_idx)
        loss = outputs[0]
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)

        # Update memory with personalized data
        self.update_memory(batch)
        return loss


    def meta_train_step(self, X, Y, W, alpha):
        self.train()
        optimizer = AdamW(W.parameters(), lr=alpha)
        optimizer.zero_grad()
        outputs = self.forward(X, attention_mask=None, labels=Y)
        loss = outputs.loss
        print("Meta-training loss:", loss)
        loss.backward()
        optimizer.step()
        return W

    def is_empty(self):
        return self.memory["input_ids"] is None
    
    def update_memory(self, batch):
        batch_size = batch["input_ids"].size(0)

        if self.memory["input_ids"] is None:
            # Pre-allocate memory buffer with fixed size
            self.memory["input_ids"] = torch.zeros((self.mem_sz, *batch["input_ids"].shape[1:]), dtype=batch["input_ids"].dtype)
            self.memory["attention_mask"] = torch.zeros((self.mem_sz, *batch["attention_mask"].shape[1:]), dtype=batch["attention_mask"].dtype)
            self.memory["labels"] = torch.zeros((self.mem_sz, *batch["labels"].shape[1:]), dtype=batch["labels"].dtype)

        # Calculate the end index for insertion
        end_ptr = (self.ptr + batch_size) % self.mem_sz

        if end_ptr > self.ptr:
            # Case where we don't wrap around the buffer
            self.memory["input_ids"][self.ptr:end_ptr] = batch["input_ids"]
            self.memory["attention_mask"][self.ptr:end_ptr] = batch["attention_mask"]
            self.memory["labels"][self.ptr:end_ptr] = batch["labels"]
        else:
            # Case where we wrap around the buffer
            part1_len = self.mem_sz - self.ptr
            self.memory["input_ids"][self.ptr:] = batch["input_ids"][:part1_len]
            self.memory["attention_mask"][self.ptr:] = batch["attention_mask"][:part1_len]
            self.memory["labels"][self.ptr:] = batch["labels"][:part1_len]
            self.memory["input_ids"][:end_ptr] = batch["input_ids"][part1_len:]
            self.memory["attention_mask"][:end_ptr] = batch["attention_mask"][part1_len:]
            self.memory["labels"][:end_ptr] = batch["labels"][part1_len:]

        # Update pointer and size
        self.ptr = end_ptr
        self.size = min(self.size + batch_size, self.mem_sz)


    def _shared_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        outputs = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        # print("Outputs:",outputs)
        return outputs

    def train_dataloader(self, train_dataset):
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            collate_fn=data_collator,
            drop_last=True,
            shuffle=True
        )
        return train_dataloader

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            result = self._shared_step(batch, batch_idx)
            val_loss = result[0].detach()
            logits = result[1].detach()

            predictions = torch.argmax(logits, dim=-1)
            labels = batch['labels']

            self.log("val_loss", val_loss, on_epoch=True, prog_bar=True)

            perplexity = torch.exp(val_loss)
            self.log("val_ppl", perplexity, on_epoch=True, prog_bar=True)

            # Save predictions and targets for F1 score calculation
            self.validation_step_outputs.append(predictions.cpu())  # Move to CPU to avoid memory issues
            self.validation_step_targets.append(labels.cpu())  # Move to CPU to avoid memory issues

            # Return loss, labels and perplexity calculation
            return {
                "val_loss": val_loss,
                "predictions": predictions,
                "labels": labels
            }


    def on_validation_epoch_end(self):
        # Concatenate all predictions and targets
        all_preds = torch.cat(self.validation_step_outputs).numpy()
        all_targets = torch.cat(self.validation_step_targets).numpy()

         # Calculate F1 score
        f1 = f1_score(all_targets.flatten(), all_preds.flatten(), average='weighted')

        # Log F1 score
        self.log('val_f1', f1, on_epoch=True, prog_bar=True)

        # Clear stored predictions and targets
        self.validation_step_outputs.clear()
        self.validation_step_targets.clear()


    def test_step(self, batch, batch_idx):
        with torch.no_grad():
            result = self._shared_step(batch, batch_idx)
            loss = result[0].detach()
            logits = result[1].detach()

            predictions = torch.argmax(logits, dim=-1)
            labels = batch['labels']

            perplexity = torch.exp(loss)
            self.log("test_ppl", perplexity, on_epoch=True, prog_bar=True)
            self.log("test_loss", loss, on_epoch=True, prog_bar=True)

            return {
                "test_ppl": perplexity,
                "test_loss": loss
            }


    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
        scheduler = {
            'scheduler': get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS,
                                                                   num_training_steps=MAX_EPOCHS, lr_end=1e-09),
            'name': 'lr'
        }
        return [optimizer], [scheduler]

    def backward(self, loss):
        loss.backward()

# Callbacks and logger setup
LOGS_PATH = "./logs/ARASAAC-contextual-ft"
CHECKPOINTS_PATH = "./checkpoints/ARASAAC-contextual-ft"

tb_logger = TensorBoardLogger(LOGS_PATH, name='logger', version='version')
lr_monitor = LearningRateMonitor(logging_interval='epoch')

checkpoint_callback = ModelCheckpoint(
    dirpath=CHECKPOINTS_PATH,
    filename='bert-large-{epoch:02d}-{train_loss:.2f}-{val_loss:.2f}',
    mode='min',
    monitor="val_loss",
    save_top_k=5
)

trainer = pl.Trainer(
    max_epochs= MAX_EPOCHS,
    logger=tb_logger,
    callbacks=[checkpoint_callback, lr_monitor],
    precision="16-mixed",
    accelerator="gpu"
)


to_train = ER_PictoBERT()


In [15]:
trainer.fit(to_train, train_dataloader, val_dataloader)



Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Training: |          | 0/? [00:00<?, ?it/s]

True


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


In [17]:
# After training
display_markdown("## ER Model: Validation Set Metrics")
# print(f"Final Training Loss: {trainer.logged_metrics["train_loss"]}")
print(f"Final val_ppl : {trainer.logged_metrics['val_ppl']}")
print(f"Final val_loss: {trainer.logged_metrics['val_loss']}")

## ER Model: Validation Set Metrics

Final val_ppl : 10.19471549987793
Final val_loss: 2.1363754272460938


In [18]:
trainer.test(to_train, dataloaders=test_dataloader)

  self.pid = os.fork()


Testing: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


[{'test_ppl': 10.837124824523926, 'test_loss': 2.2029049396514893}]

In [19]:
trainer.test(to_train, dataloaders=test_dataloader)

Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_ppl': 10.414925575256348, 'test_loss': 2.141904354095459}]

# Model 2: Meta PictoBERT

+ <b>Base Model</b>: Utilizes the pre-trained BERT model (pictoBERT) designed for masked language modeling, which serves as the foundation for downstream tasks such as understanding and generating natural language text.

+ <b>Meta-Learning Framework</b>: The model leverages meta-learning techniques to adapt quickly to new tasks by simulating a scenario where it learns from a smaller dataset (trajectory data) and then evaluates on another set (meta-test data). This approach is designed to improve generalization to new, unseen data.
  - using a meta-learning approach, the model can quickly adapt to new tasks, demonstrating strong generalization capabilities

* Custom Layers:

  - <b>Linear Layers</b>: Custom linear layers are added on top of BERT to learn specific task representations.
  - <b>LSTM Layer</b>: Provides optional recurrent neural network capability for capturing sequential dependencies in the data if needed.
  - <b>Custom Weights</b>: Implements a mechanism to apply custom weights dynamically, allowing for more fine-tuned control over the learning process.

* Meta-Learning Update Mechanism:

  1. <b>Inner Update</b>: 
     + Implements an inner-loop optimization step where gradients are calculated and weights are updated using a `meta-learning rate (alpha)`, simulating multiple learning episodes within a single training step.
     + the inner loop updates a temporary set of model parameters, called `fast_weights`, using gradient descent based on a task-specific loss (cross-entropy loss). This step mimics learning on a small dataset from scratch or with few examples.

  2. <b>Outer Loop (Meta-Learning)</b>:
     + The outer loop evaluates the performance of the inner-loop updates on new tasks (meta-test data). It aims to find an optimal initialization or model configuration that, when fine-tuned, can quickly adapt to various tasks.
     + The Meta_PictoBERT model computes a `meta-loss` using the parameters learned in the inner loop (fast_weights) and updates the main model parameters to improve the ability to generalize across tasks.

* Optimization Strategy:

  - <b>Optimizer</b>: Uses the `AdamW optimizer`, which is well-suited for training large models like BERT due to its adaptive learning rate and weight decay capabilities.
  - <b>Learning Rate Scheduler</b>: Utilizes a polynomial decay schedule with warmup to adjust the learning rate dynamically during training, which helps prevent the model from overshooting the optimal solution and ensures steady convergence.

* Evaluation:
  
  - The use of meta learning, resulted in a substantial `decrease in loss` and a `lower perplexity`, even with `training for only 3 epochs` (due to memory constraints).  
  - The use of meta-learning in the Meta_PictoBERT model enables it to rapidly adapt to new tasks with minimal data, improve generalization across diverse domains, and retain knowledge effectively over time. This approach provides significant advantages in environments with continuous learning requirements.

In [15]:
import random
import torch
import torch.nn as nn
from torch import optim
from torch.nn import functional as F
from sklearn.metrics import f1_score
from transformers import BertForMaskedLM, AdamW, BertTokenizer
from torch.utils.data import DataLoader, ConcatDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import get_polynomial_decay_schedule_with_warmup
from scipy import stats
import numpy as np
from pytorch_lightning.callbacks import RichProgressBar


class Meta_PictoBERT(pl.LightningModule):
    def __init__(self, pretrained_model_name='bert-large-uncased',num_lstm_layers: int = 4,
                 lstm_hidden_size: int = 128,
                 ffnn_hidden_size: int = 64, mem_sz=1000, alpha=3e-4, beta=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.batch_size = 32
        self.lr = LEARNING_RATE
        self.train_dataset = train_dataset  # Global training dataset
        self.mem_sz = mem_sz
        self.alpha = alpha
        self.meta_lr = beta
        self.validation_step_outputs = []
        self.validation_step_targets = []
        self.memory = []  # Initialize the memory buffer for ER
        self.bert = pictobert #BertForMaskedLM.from_pretrained(pretrained_model_name)
        
        #self.lstm = nn.LSTM(input_size=self.bert.config.hidden_size,
#                             hidden_size=lstm_hidden_size,
#                             num_layers=num_lstm_layers,
#                             batch_first=True,
#                             bidirectional=True)
        
        self.W = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128, self.bert.config.vocab_size)
        )

    def freeze_to(self, layers):
        for param in self.bert.bert.encoder.layer[:layers].parameters():
            param.requires_grad = False


    def forward(self, input_ids, attention_mask, fast_weights = None, labels=None, train = False):
        # Check for invalid values in input_ids
        if torch.any(input_ids < 0) or torch.any(input_ids >= self.bert.config.vocab_size):
            print(self.bert.config.vocab_size)
            print("Invalid input_ids detected!")
            print("Min value:", torch.min(input_ids))
            print("Max value:", torch.max(input_ids))
        # Forward pass through BERT
        bert_output = self.bert(input_ids=input_ids,
            attention_mask=attention_mask,
            labels = labels).hidden_states[-1]
        
#         lstm_output,_ = self.lstm(bert_output)
        
        # If no fast_weights are provided, use the model's weights
        if train :
            logits = self.apply_custom_weights(bert_output, fast_weights)
            # return logits, bert_output[0]
        else :
            logits = self.W(bert_output)

        return logits


    def inner_update(self, x, att_mask, fast_weights, y):
        if fast_weights is None:
            fast_weights = list(self.W.parameters())

        # Forward pass through BERT and meta model using fast_weights
        logits = self.forward(x, att_mask, fast_weights, y, train =True)

        loss = F.cross_entropy(logits.view(-1, self.bert.config.vocab_size), y.view(-1), ignore_index=-100)
        grad = torch.autograd.grad(loss, fast_weights, create_graph=True)

        new_weights = []
        for param, g in zip(fast_weights, grad):
            new_weights.append(param - self.alpha * g)

        return new_weights

    def apply_custom_weights(self, x, weights):
        idx = 0
        for name, layer in self.W.named_modules():
            if isinstance(layer, nn.Linear):
                w, b = weights[idx], weights[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif isinstance(layer, nn.Conv2d):
                w, b = weights[idx], weights[idx + 1]
                x = F.conv2d(x, w, b, stride=layer.stride, padding=layer.padding)
                idx += 2
            elif isinstance(layer, nn.ReLU):
                x = F.relu(x)

        return x

    def meta_loss(self, x, att_mask, fast_weights, y):

        # Forward pass through BERT and meta model using fast_weights
        logits = self.forward(x, att_mask, fast_weights, y, train =True)

        loss = F.cross_entropy(logits.view(-1, self.bert.config.vocab_size), y.view(-1), ignore_index=-100)

        return loss, logits

    def eval_accuracy(self, logits, y):
        pred_q = F.softmax(logits, dim=1).argmax(dim=1)
        correct = torch.eq(pred_q, y).sum().item()
        return correct

   
    def training_step(self, batch, batch_idx):
        X, Y = batch['input_ids'], batch['labels']
        mask = batch['attention_mask']

        # Perform a single random sampling and split into trajectory and meta-test
        total_indices = list(range(len(X)))
        # random.shuffle(total_indices)

        # Split indices: 50% for trajectory, 50% for meta-test (adjust the split ratio if needed)
        split_idx = len(X) // 2
        traj_indices = total_indices[:split_idx]
        meta_test_indices = total_indices[split_idx:]

        # Get trajectory and meta-test data
        X_traj = X[traj_indices]
        mask_traj = mask[traj_indices]
        Y_traj = Y[traj_indices]

        X_meta = X[meta_test_indices]
        mask_meta = mask[meta_test_indices]
        Y_meta = Y[meta_test_indices]

        # Meta-learning loop
        fast_weights = None
        # for j in range(len(X_traj)):
        #     fast_weights = self.inner_update(X_traj[j].unsqueeze(0), mask_traj[j].unsqueeze(0), fast_weights, Y_traj[j].unsqueeze(0))
        fast_weights = self.inner_update(X_traj, mask_traj, fast_weights, Y_traj)

        # Compute meta-loss on meta-test set
        meta_loss, _ = self.meta_loss(X_meta, mask_meta, fast_weights, Y_meta)

        self.log("train_loss", meta_loss, on_epoch=True, prog_bar=True)
        return meta_loss

    def _shared_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        logits = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = F.cross_entropy(logits.view(-1, self.bert.config.vocab_size), labels.view(-1), ignore_index=-100)
        return loss, logits

    def train_dataloader(self):
        train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            collate_fn=data_collator,
            drop_last=True,
            shuffle=True
        )
        return train_dataloader

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            result = self._shared_step(batch, batch_idx)
            val_loss = result[0].detach()
            logits = result[1].detach()

            predictions = torch.argmax(logits, dim=-1)
            labels = batch['labels']

            self.log("val_loss", val_loss, on_epoch=True, prog_bar=True)
            
            perplexity = torch.exp(val_loss)
            self.log("val_ppl", perplexity, on_epoch=True, prog_bar=True)

            # Save predictions and targets for F1 score calculation
            self.validation_step_outputs.append(predictions.cpu())  # Move to CPU to avoid memory issues
            self.validation_step_targets.append(labels.cpu())  # Move to CPU to avoid memory issues

            # Return loss and perplexity calculation
            return {
                "val_loss": val_loss,
                "val_ppl":  perplexity
            }

    def on_validation_epoch_end(self):

        # Concatenate all predictions and targets
        all_preds = torch.cat(self.validation_step_outputs).numpy()
        all_targets = torch.cat(self.validation_step_targets).numpy()

        # Calculate F1 score
        f1 = f1_score(all_targets.flatten(), all_preds.flatten(), average='weighted')

        # Log F1 score
        self.log('val_f1', f1, on_epoch=True, prog_bar=True)

        # Clear stored predictions and targets
        self.validation_step_outputs.clear()
        self.validation_step_targets.clear()

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
        scheduler = {
            'scheduler': get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS,
                                                                   num_training_steps=MAX_EPOCHS, lr_end=1e-09),
            'name': 'lr'
        }
        return [optimizer], [scheduler]

    def test_step(self, batch, batch_idx):
        with torch.no_grad():
            result = self._shared_step(batch, batch_idx)
            loss = result[0].detach()
            logits = result[1].detach()

            predictions = torch.argmax(logits, dim=-1)
            labels = batch['labels']

            perplexity = torch.exp(loss)
            self.log("test_ppl", perplexity, on_epoch=True, prog_bar=True)
            self.log("test_loss", loss, on_epoch=True, prog_bar=True)


            return {
                "test_ppl": perplexity,
                "test_loss": loss
            }

    def backward(self, loss):
        loss.backward()

# # Callbacks and logger setup
# LOGS_PATH = "./logs/ARASAAC-contextual-ft"
# CHECKPOINTS_PATH = "./checkpoints/ARASAAC-contextual-ft"

# tb_logger = TensorBoardLogger(LOGS_PATH, name='logger', version='version')
# lr_monitor = LearningRateMonitor(logging_interval='epoch')

# checkpoint_callback = ModelCheckpoint(
#     dirpath=CHECKPOINTS_PATH,
#     filename='bert-large-{epoch:02d}-{train_loss:.2f}-{val_loss:.2f}',
#     mode='min',
#     monitor="val_loss",
#     save_top_k=5
# )

trainer = pl.Trainer(
    max_epochs=10,
    logger=False,
    callbacks=False, #[checkpoint_callback, lr_monitor],
    precision="16-mixed",
    accelerator="gpu"
)


to_train = Meta_PictoBERT()



In [16]:
trainer.fit(to_train, train_dataloader, val_dataloader)



Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Training: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


In [17]:
trainer.test(to_train, dataloaders=test_dataloader)

  self.pid = os.fork()


Testing: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


[{'test_ppl': 43.75568389892578, 'test_loss': 3.559744358062744}]

# Model 3: Meta_ER PictoBERT

1. <b>Core Components of MetaER_pictobert</b>:

   - <b>Pre-trained BERT for Masked Language Modeling (pictobert)</b>: The model utilizes a pre-trained BERT (pictobert) as its base for encoding text. This provides a powerful foundation of language understanding, which can be fine-tuned for specific tasks.
   - <b>Experience Replay (ER) Memory Buffer</b>:
     + A memory buffer stores past examples from the personalized dataset to continuously retain and replay them during training. This helps the model retain knowledge over time.
     + <i>Memory Buffer Details</i>: The buffer uses a fixed size (`mem_sz = 1000`) and is updated regularly, replacing old samples with new ones as the buffer reaches capacity.

   - <b>Meta-Learning Loop (MAML-Inspired)</b>:
     + The meta-learning loop enables the model to learn how to adapt quickly to new tasks by simulating learning on small datasets (<i>support set</i>) and evaluating on new examples (<i>query set</i>).
     + <b>Inner Loop (Task-Specific Learning)</b>: Updates fast weights (task-specific parameters) using a small subset of data. This loop mimics how the model would learn a new task from limited examples.
     + <b>Outer Loop (Meta-Optimization)</b>: Evaluates the effectiveness of the inner loop’s updates and adjusts the base model parameters (MetaER_pictobert) to enhance adaptability across tasks.
   - <b>Custom Weight Application</b>: A mechanism to dynamically apply custom weights for each layer during the meta-learning process, allowing for flexible updates and fine-tuning.
   - <b>Dynamic Learning Rate and Optimizer</b>: Uses `AdamW optimizer` with `polynomial decay scheduling` to fine-tune the learning process, ensuring better convergence over time.

2. <b>Model Training Strategy</b>:

   <b>Experience Replay and Meta-Learning Combined</b>:
     + <b>Combining Mini-Batches</b>: The model dynamically `combines current mini-batches with samples drawn from the memory buffer`, ensuring that the model is trained on both new and historical data. This balances learning between recent and past knowledge.
     + <b>Dynamic Sampling</b>: During training, a `random sampling process` splits data into two halves: a trajectory set for inner-loop meta-updates and a meta-test set for evaluating outer-loop optimization. This dual-phase training optimizes both specific and generalized learning capabilities.
 
3. <b>Memory Buffer Management</b>:

   The memory buffer uses a `pointer mechanism (ptr)` to manage storage efficiently, allowing for the `cyclic replacement` of old examples. This mechanism ensures that the model maintains a diverse set of training examples without increasing memory usage indefinitely.

4. <b>Custom Forward Pass and Inner Update</b>:

   - The forward method allows flexibility by applying different layers and their weights dynamically, enabling more granular control over the learning process.
   - <b>Inner Update Mechanism</b>: Calculates the gradients using `torch.autograd.grad`, which creates graph-based gradients and allows the model to update itself in a task-specific manner, optimizing the fast weights for quick adaptation.
   - <b>Training Step</b>: Involves a combination of experience replay and meta-learning. After a batch is processed, the loss is calculated and used to update both the model weights and the memory buffer.


In [16]:
import random
import torch
import torch.nn as nn
from torch import optim
from torch.nn import functional as F
from sklearn.metrics import f1_score
from transformers import BertForMaskedLM, AdamW, BertTokenizer
from torch.utils.data import DataLoader, ConcatDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import get_polynomial_decay_schedule_with_warmup
from scipy import stats
import numpy as np
from pytorch_lightning.callbacks import RichProgressBar


class MetaER_pictobert(pl.LightningModule):
    def __init__(self, pretrained_model_name='bert-large-uncased',num_lstm_layers: int = 4,
                 lstm_hidden_size: int = 128,
                 ffnn_hidden_size: int = 64, mem_sz=1000, alpha=3e-4, beta=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.batch_size = 8
        self.lr = LEARNING_RATE
        self.train_dataset = train_dataset  # Global training dataset
        self.mem_sz = mem_sz
        self.ptr = 0  # Pointer to the current memory index
        self.size = 0  # Current size of the memory buffer
        self.alpha = alpha
        self.beta = beta
        self.memory = {
            "input_ids": None,
            "attention_mask": None,
            "labels": None
        } 
        self.validation_step_outputs = []
        self.validation_step_targets = []
        self.bert = pictobert #BertForMaskedLM.from_pretrained(pretrained_model_name)
        
        #self.lstm = nn.LSTM(input_size=self.bert.config.hidden_size,
#                             hidden_size=lstm_hidden_size,
#                             num_layers=num_lstm_layers,
#                             batch_first=True,
#                             bidirectional=True)
        
        self.W = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128, self.bert.config.vocab_size)
        )

    def freeze_to(self, layers):
        for param in self.bert.bert.encoder.layer[:layers].parameters():
            param.requires_grad = False


    def forward(self, input_ids, attention_mask, fast_weights = None, labels=None, train = False):
        # Check for invalid values in input_ids
        if torch.any(input_ids < 0) or torch.any(input_ids >= self.bert.config.vocab_size):
            print(self.bert.config.vocab_size)
            print("Invalid input_ids detected!")
            print("Min value:", torch.min(input_ids))
            print("Max value:", torch.max(input_ids))
        # Forward pass through BERT
        bert_output = self.bert(input_ids=input_ids,
            attention_mask=attention_mask,
            labels = labels).hidden_states[-1]
        
#         lstm_output,_ = self.lstm(bert_output)
        
        # If no fast_weights are provided, use the model's weights
        if train :
            logits = self.apply_custom_weights(bert_output, fast_weights)
            # return logits, bert_output[0]
        else :
            logits = self.W(bert_output)

        return logits


    def inner_update(self, x, att_mask, fast_weights, y):
        if fast_weights is None:
            fast_weights = list(self.W.parameters())


        # Forward pass through BERT and meta model using fast_weights
        logits = self.forward(x, att_mask, fast_weights, y, train =True)

        loss = F.cross_entropy(logits.view(-1, self.bert.config.vocab_size), y.view(-1), ignore_index=-100)
        grad = torch.autograd.grad(loss, fast_weights, create_graph=True)

        new_weights = []
        for param, g in zip(fast_weights, grad):
            new_weights.append(param - self.alpha * g)

        return new_weights

    def apply_custom_weights(self, x, weights):
        idx = 0
        for name, layer in self.W.named_modules():
            if isinstance(layer, nn.Linear):
                w, b = weights[idx], weights[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif isinstance(layer, nn.Conv2d):
                w, b = weights[idx], weights[idx + 1]
                x = F.conv2d(x, w, b, stride=layer.stride, padding=layer.padding)
                idx += 2
            elif isinstance(layer, nn.ReLU):
                x = F.relu(x)
            # Add more conditions if you have other types of layers
        return x

    def meta_loss(self, x, att_mask, fast_weights, y):

        # Forward pass through BERT and meta model using fast_weights
        logits = self.forward(x, att_mask, fast_weights, y, train =True)

        loss = F.cross_entropy(logits.view(-1, self.bert.config.vocab_size), y.view(-1), ignore_index=-100)
        # logits = self.W(self.bert(x, params=fast_weights)[0])
        # loss_q = F.cross_entropy(logits, y)
        return loss, logits

    def eval_accuracy(self, logits, y):
        pred_q = F.softmax(logits, dim=1).argmax(dim=1)
        correct = torch.eq(pred_q, y).sum().item()
        return correct

    
    def combine_batches(self, Bn, BM):
            # Get the device of the current batch (Bn)
        device = Bn["input_ids"].device

        # Move BM (memory batch) to the same device as Bn if they are not already on the same device
        BM = {key: value.to(device) for key, value in BM.items()}
        
        return {"input_ids": torch.cat([Bn["input_ids"], BM["input_ids"]], dim=0),
                             "attention_mask": torch.cat([Bn["attention_mask"], BM["attention_mask"]], dim=0),
                             "labels": torch.cat([Bn["labels"], BM["labels"]], dim=0)}

    def sample_from_memory(self, sample_size):
        if self.is_empty():
            return None

        # Sample random indices from memory
        indices = torch.randperm(self.memory["input_ids"].size(0))[:sample_size]

        sampled_batch = {
            "input_ids": self.memory["input_ids"][indices],
            "attention_mask": self.memory["attention_mask"][indices],
            "labels": self.memory["labels"][indices]
        }
        return sampled_batch

    def training_step(self, batch, batch_idx):
        
        if self.is_empty():
            print(self.is_empty())
            combined_batch = batch
        # Experience Replay with personalized data
        else : 
            BM = self.sample_from_memory(min(len(self.memory), self.batch_size))
            combined_batch = self.combine_batches(batch, BM)
        
        X, Y = combined_batch['input_ids'], combined_batch['labels']
        mask = combined_batch['attention_mask']

        # Perform a single random sampling and split into trajectory and meta-test
        total_indices = list(range(len(X)))
        # random.shuffle(total_indices)

        # Split indices: 50% for trajectory, 50% for meta-test (adjust the split ratio if needed)
        split_idx = len(X) // 2
        traj_indices = total_indices[:split_idx]
        meta_test_indices = total_indices[split_idx:]

        # Get trajectory and meta-test data
        X_traj = X[traj_indices]
        mask_traj = mask[traj_indices]
        Y_traj = Y[traj_indices]

        X_meta = X[meta_test_indices]
        mask_meta = mask[meta_test_indices]
        Y_meta = Y[meta_test_indices]

        # Meta-learning loop
        fast_weights = None
        # for j in range(len(X_traj)):
        #     fast_weights = self.inner_update(X_traj[j].unsqueeze(0), mask_traj[j].unsqueeze(0), fast_weights, Y_traj[j].unsqueeze(0))
        fast_weights = self.inner_update(X_traj, mask_traj, fast_weights, Y_traj)

        # Compute meta-loss on meta-test set
        meta_loss, _ = self.meta_loss(X_meta, mask_meta, fast_weights, Y_meta)

        # Update memory buffer
        self.log("train_loss", meta_loss, on_epoch=True, prog_bar=True)
        
        self.update_memory(batch)
        
        return meta_loss

    def is_empty(self):
        return self.memory["input_ids"] is None
    
    def update_memory(self, batch):
        batch_size = batch["input_ids"].size(0)

        if self.memory["input_ids"] is None:
            # Pre-allocate memory buffer with fixed size
            self.memory["input_ids"] = torch.zeros((self.mem_sz, *batch["input_ids"].shape[1:]), dtype=batch["input_ids"].dtype)
            self.memory["attention_mask"] = torch.zeros((self.mem_sz, *batch["attention_mask"].shape[1:]), dtype=batch["attention_mask"].dtype)
            self.memory["labels"] = torch.zeros((self.mem_sz, *batch["labels"].shape[1:]), dtype=batch["labels"].dtype)

        # Calculate the end index for insertion
        end_ptr = (self.ptr + batch_size) % self.mem_sz

        if end_ptr > self.ptr:
            # Case where we don't wrap around the buffer
            self.memory["input_ids"][self.ptr:end_ptr] = batch["input_ids"]
            self.memory["attention_mask"][self.ptr:end_ptr] = batch["attention_mask"]
            self.memory["labels"][self.ptr:end_ptr] = batch["labels"]
        else:
            # Case where we wrap around the buffer
            part1_len = self.mem_sz - self.ptr
            self.memory["input_ids"][self.ptr:] = batch["input_ids"][:part1_len]
            self.memory["attention_mask"][self.ptr:] = batch["attention_mask"][:part1_len]
            self.memory["labels"][self.ptr:] = batch["labels"][:part1_len]
            self.memory["input_ids"][:end_ptr] = batch["input_ids"][part1_len:]
            self.memory["attention_mask"][:end_ptr] = batch["attention_mask"][part1_len:]
            self.memory["labels"][:end_ptr] = batch["labels"][part1_len:]

        # Update pointer and size
        self.ptr = end_ptr
        self.size = min(self.size + batch_size, self.mem_sz)

    def _shared_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        logits = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = F.cross_entropy(logits.view(-1, self.bert.config.vocab_size), labels.view(-1), ignore_index=-100)
        return loss, logits

    def train_dataloader(self):
        train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            collate_fn=data_collator,
            drop_last=True,
            shuffle=True
        )
        return train_dataloader

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            result = self._shared_step(batch, batch_idx)
            val_loss = result[0].detach()
            logits = result[1].detach()

            predictions = torch.argmax(logits, dim=-1)
            labels = batch['labels']

            self.log("val_loss", val_loss, on_epoch=True, prog_bar=True)
            
            perplexity = torch.exp(val_loss)
            self.log("val_ppl", perplexity, on_epoch=True, prog_bar=True)

            # Save predictions and targets for F1 score calculation
            self.validation_step_outputs.append(predictions.cpu())  # Move to CPU to avoid memory issues
            self.validation_step_targets.append(labels.cpu())  # Move to CPU to avoid memory issues

            # Return loss and perplexity calculation
            return {
                "val_loss": val_loss,
                "val_ppl":  perplexity
            }

    def on_validation_epoch_end(self):

        # Concatenate all predictions and targets
        all_preds = torch.cat(self.validation_step_outputs).numpy()
        all_targets = torch.cat(self.validation_step_targets).numpy()

        # Calculate F1 score
        f1 = f1_score(all_targets.flatten(), all_preds.flatten(), average='weighted')


        # Log F1 score
        self.log('val_f1', f1, on_epoch=True, prog_bar=True)

        # Clear stored predictions and targets
        self.validation_step_outputs.clear()
        self.validation_step_targets.clear()

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
        scheduler = {
            'scheduler': get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS,
                                                                   num_training_steps=MAX_EPOCHS, lr_end=1e-09),
            'name': 'lr'
        }
        return [optimizer], [scheduler]

    def test_step(self, batch, batch_idx):
        with torch.no_grad():
            result = self._shared_step(batch, batch_idx)
            loss = result[0].detach()
            logits = result[1].detach()

            predictions = torch.argmax(logits, dim=-1)
            labels = batch['labels']

            perplexity = torch.exp(loss)
            self.log("test_ppl", perplexity, on_epoch=True, prog_bar=True)
            self.log("test_loss", loss, on_epoch=True, prog_bar=True)


            return {
                "test_ppl": perplexity,
                "test_loss": loss
            }

    def backward(self, loss):
        loss.backward()


trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    logger=False,
    callbacks=False, #[checkpoint_callback, lr_monitor],
    precision="16-mixed",
    accelerator="gpu"
)


to_train = MetaER_pictobert()



In [17]:
trainer.fit(to_train, train_dataloader, val_dataloader)



Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Training: |          | 0/? [00:00<?, ?it/s]

True


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


In [18]:
# After training
display_markdown("## MetaER Model: Validation Set Metrics")
# print(f"Final Training Loss: {trainer.logged_metrics["train_loss"]}")
print(f"Final val_ppl : {trainer.logged_metrics['val_ppl']}")
print(f"Final val_loss: {trainer.logged_metrics['val_loss']}")

## MetaER Model: Validation Set Metrics

Final val_ppl : 49.16434097290039
Final val_loss: 3.6612133979797363


In [20]:
display_markdown("## MetaER Model: Test Set Metrics")
trainer.test(to_train, dataloaders=test_dataloader)

## MetaER Model: Test Set Metrics

Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_ppl': 48.77863311767578, 'test_loss': 3.6606292724609375}]

# Inference

In [25]:
import torch.nn.functional as F
def get_top_k(sentence, k):
    text = " ".join(sentence+['[MASK]','.'])
  # return text
    tokenized = loaded_tokenizer(text, return_tensors="pt")

    input_ids, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
  # input_ids = tensor([loaded_tokenizer.convert_tokens_to_ids(sentence+['[MASK]','.'])])
    outputs = to_train.forward(input_ids, attention_mask)
    predictions = F.softmax(outputs, dim=-1)

    mask_idx = input_ids.tolist()[0].index(loaded_tokenizer.mask_token_id)
    probs = predictions[0, mask_idx, :]
    return loaded_tokenizer.convert_ids_to_tokens(probs.topk(k)[1])

get_top_k(['mommy%1:18:00::', 'be%2:42:03::'],10)

['do%2:41:01::',
 'go_to%2:42:00::',
 'go%2:38:00::',
 'not%4:02:00::',
 'nice%3:00:00::',
 'be%2:42:06::',
 'will%2:32:00::',
 'like',
 'at',
 'know%2:31:01::']

# How Meta_ER PictoBERT Helps in the Continual Learning Task for Personalized Pictogram Recommendations:

A continual learning agent should be able to build on top of existing knowledge to learn on new data quickly while minimizing forgetting.

The combined approach of MAML and Experience Replay in MetaER_pictobert equips the model with the ability to rapidly adapt to new user preferences while retaining past knowledge.
 1. Personalization: It can quickly personalize pictogram recommendations based on minimal user interactions.
 
 2. Generalization: It retains a generalized understanding of various user preferences, making it robust across diverse user groups.

A major challenge with Continual learning is to mitigate `Catastrophic forgetting` which occurs when a model trained sequentially on multiple tasks forgets previous tasks as it learns new ones. MetaER_pictobert addresses this challenge in several ways:

 1. Experience Replay (ER):
    - <i>Memory Buffer for Old Examples</i>: The model maintains a memory buffer that stores examples from past tasks (or user preferences) and regularly replays them during training.
    - <i>Periodic Sampling from Memory</i>: During each training step, the model randomly samples from the memory buffer and combines these samples with new data. This ensures that the model retains knowledge of past tasks while learning from new data.
    - <i>How This Prevents Forgetting</i>: By continually revisiting old examples, the model prevents the parameters from drifting too far from what was optimal for past tasks, mitigating catastrophic forgetting.

 2. MAML-Inspired Meta-Learning:
    - <i>Learned Initialization</i>: The MAML approach learns an initialization of parameters that are effective for rapid adaptation to a wide range of tasks. It trains the model’s parameters such that a small number of gradient updates will lead to fast learning on a new task.
    - <i>Inner Loop Updates (Task-Specific Fine-Tuning)</i>: When learning new tasks, the model only fine-tunes a small number of gradient steps, which prevents excessive overwriting of the base parameters.
    - <i>Outer Loop Updates (Meta-Optimization)</i>: Meta-optimization ensures that the base parameters remain robust and general, preserving knowledge that is broadly useful across tasks.

 3. Balanced Training Strategy:

    - <i>Combining Mini-Batches</i>: The model combines current mini-batches with samples drawn from the memory buffer, ensuring that it is trained on both new and historical data. This reduces the risk of overfitting to recent data or losing valuable information learned earlier.
    - <i>Task-Specific Learning (Inner Update)</i>: The task-specific learning process ensures that each new task is adapted with minimal changes to the core parameters, which are meta-optimized to be adaptable yet stable.

 4. Custom Weight Application and Dynamic Updates:

    - <i>Dynamic Weight Application</i>: By dynamically applying weights during the meta-learning loop, the model achieves finer control over which parameters are updated, further reducing the risk of forgetting.
    - <i>Efficient Memory Management</i>: The cyclic replacement strategy in the memory buffer allows the model to maintain a diverse set of training examples over time, which is crucial for preventing forgetting in a continual learning scenario.