# A Simple Example of Continual Finetune

This notebook's purpose is to demonstrate the implementation of the soft-masking concept (refer to the [TSS](https://arxiv.org/abs/2310.09436)). It is not designed to yield effective results in real-world scenarios. Its simplicity lies in the fact that:

*   We avoid using advanced packages, including huggingface.
*   We employ a basic fully connected network instead of any pre-trained language models or LSTM.
*   The data is synthetic, and we do not implement a real tokenizer or task-specific loss


Import the necessary packages

In [136]:
from collections import defaultdict
import random, os
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.autograd as autograd
import math



Construct a basic tokenizer. This tokenizer's vocabulary is created from the provided corpus. It is not suitable for real-world applications, as this simplistic approach cannot manage any words that are not already in the corpus.

In [137]:
def tokenizer(corpus):
  # Build vocabulary

  vocab = defaultdict(int)
  idx = 1 # 0 is used as padding token id
  for text in corpus:
      for word in text.split():
        if word not in vocab:
          vocab[word] = idx
          idx += 1

  # Use vocabulary
  tokenizerd_corpus = []
  for text in corpus:
      tokenized_text = []
      for word in text.split():
          tokenized_text.append(vocab[word])
      tokenizerd_corpus.append(tokenized_text)

  return {'idx': tokenizerd_corpus}




Next, we implement a helper function to assist in tokenizing each instance in the dataset.

In [138]:
def truncate_pad(examples, max_length):

    result = {}
    new_example = []
    for example in examples['idx']:
      if max_length < len(example): # trancate
        new_example.append(example[:max_length])
      else:
        difference = max_length - len(example)
        new_example.append(example + [0] * difference)

    result['idx'] = new_example

    #Lets also give some synthetic label here for pre-training task
    label_ids = [0,1]
    result['labels'] = []
    for idx in result['idx']:
      result['labels'].append(random.sample(label_ids, 1))

    return result

We also need to create a custom PyTorch dataset, since our data is formatted as a dictionary.

In [139]:

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data['idx'])

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        data_tensor = {}
        for key, value in self.data.items():
          data_item = self.data[key][idx]
          data_tensor[key] = torch.tensor(data_item, dtype=torch.float)

        return data_tensor


The following code is inspired by [SupSup](https://github.com/RAIVNLab/supsup/blob/master/mnist.ipynb). We overwrite the ``nn.linear`` function so that the network training is transformed into training for popup scores (see [TSS](https://arxiv.org/abs/2310.09436)).



In [140]:
def set_compute_mask_impt(model, compute_impt):
    for n, m in model.named_modules():
        if isinstance(m, NNSubnetworkSoftmask):
            m.compute_mask_impt = compute_impt

def set_ft_task(model, ft_task):
    for n, m in model.named_modules():
        if isinstance(m, NNSubnetworkSoftmask):
            m.ft_task = ft_task

# Subnetwork forward from hidden networks
class GetSubnet(autograd.Function):
    @staticmethod
    def forward(ctx, scores):
        return (scores >= 0).float() # Use 0 as threshold. this is related to the signed_constant initialization

    @staticmethod
    def backward(ctx, g):
        # Send the gradient g straight-through on the backward pass. so that it is trainable
        return g

class NNSubnetworkSoftmask(nn.Linear):
    def __init__(self, *args, num_tasks=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_tasks = num_tasks
        self.scores = nn.ParameterList(
            [
                nn.Parameter(self.mask_init())
                for _ in range(num_tasks)
            ]
        )
        self.impt_mask =nn.ParameterList(
            [
                nn.Parameter(torch.zeros(self.weight.size())).requires_grad_(False)
                for _ in range(self.num_tasks)
            ]
        )

        # Alphas are used later when we compute the importance of the scores.
        self.alphas =nn.Parameter(torch.ones(self.weight.size()))

        # Keep weights untrained
        self.weight.requires_grad = False
        self.signed_constant()


    def copy_score(self, ft_task):
        with torch.no_grad():
            self.scores[ft_task+1].copy_(self.scores[ft_task].clone())

    def mask_init(self):
        scores = torch.Tensor(self.weight.size())
        nn.init.kaiming_uniform_(scores, a=math.sqrt(5))
        return scores

    def signed_constant(self):
        fan = nn.init._calculate_correct_fan(self.weight, 'fan_in')
        gain = nn.init.calculate_gain('relu')
        std = gain / math.sqrt(fan)
        self.weight.data = self.weight.data.sign() * std


    def forward(self, x):
        if self.compute_mask_impt:  # Whether it is to compute the importance
            selected_mask = self.scores[self.ft_task]

            subnet = GetSubnet.apply(selected_mask)
            w = self.weight * subnet * self.alphas
            x = F.linear(x, w, self.bias)

        else:
            selected_mask = self.scores[self.ft_task]
            subnet = GetSubnet.apply(selected_mask)
            w = self.weight * subnet
            x = F.linear(x, w, self.bias)

        return x


    def __repr__(self):
        return f"NNSubnetworkSoftmask({self.weight.size(0)}, {self.weight.size(1)})"

class NNSoftmask(nn.Module):
    def __init__(self):
        super(NNSoftmask, self).__init__()
        self.word_embeddings = nn.Embedding(300, 50)
        self.fc1 = NNSubnetworkSoftmask(50,30, num_tasks=2, bias=False)
        self.fc2 = NNSubnetworkSoftmask(30,10, num_tasks=2, bias=False)
        self.head = nn.Linear(10,1)
        self.dropout = nn.Dropout(0.2)
        self.sigmoid = nn.Sigmoid()
        self.word_embeddings.weight.requires_grad = False

    def forward(self, x):

        x = self.word_embeddings(x)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.sigmoid(self.head(x).mean(1))
        return x



Now we can initialize our synthetic data and the model.

In [141]:
corpus = [
        '''
        Apparently Prides Osteria had a rough summer as evidenced by the almost empty dining room at 6:30 on a Friday night. However new blood in the kitchen seems to have revitalized the food from other customers recent visits. Waitstaff was warm but unobtrusive. By 8 pm or so when we left the bar was full and the dining room was much more lively than it had been. Perhaps Beverly residents prefer a later seating. After reading the mixed reviews of late I was a little tentative over our choice but luckily there was nothing to worry about in the food department. We started with the fried dough, burrata and prosciutto which were all lovely. Then although they don't offer half portions of pasta we each ordered the entree size and split them. We chose the tagliatelle bolognese and a four cheese filled pasta in a creamy sauce with bacon, asparagus and grana frita. Both were very good. We split a secondi which was the special Berkshire pork secreto, which was described as a pork skirt steak with garlic potato purée and romanesco broccoli (incorrectly described as a romanesco sauce). Some tables received bread before the meal but for some reason we did not. Management also seems capable for when the tenants in the apartment above began playing basketball she intervened and also comped the tables a dessert. We ordered the apple dumpling with gelato and it was also quite tasty. Portions are not huge which I particularly like because I prefer to order courses. If you are someone who orders just a meal you may leave hungry depending on you appetite. Dining room was mostly younger crowd while the bar was definitely the over 40 set. Would recommend that the naysayers return to see the improvement although I personally don't know the former glory to be able to compare. Easy access to downtown Salem without the crowds on this month of October.
        ''',
        '''
        The food is always great here. The service from both the manager as well as the staff is super. Only draw back of this restaurant is it's super loud. If you can, snag a patio table!
        ''',
        '''
        This place used to be a cool, chill place. Now its a bunch of neanderthal bouncers hopped up on steroids acting like the can do whatever they want. There are so many better places in davis square where they are glad you are visiting their business. Sad that the burren is now the worst place in davis.
        '''
        ]


tokenizerd_text = tokenizer(corpus)
max_length = 30
truncate_pad_tokenized_text = truncate_pad(tokenizerd_text,max_length)

my_dataset = CustomDataset(truncate_pad_tokenized_text)
batch_size = 2
data_loader = DataLoader(my_dataset, batch_size=batch_size, shuffle=True)

subnetworrk_softmask = NNSoftmask()


For the first task, we do not need to apply any masking. We train the network (i.e., the scores) in a conventional manner.

In [142]:
ft_task = 0
criterion = nn.BCELoss()
optimizer = optim.Adam(subnetworrk_softmask.parameters(), lr=0.003)
set_compute_mask_impt(subnetworrk_softmask,False)
set_ft_task(subnetworrk_softmask,ft_task)

epochs = 10
for e in range(epochs):
  running_loss = 0
  i = 0
  for step, batch in enumerate(data_loader):
    i += 1
    if i % 100 == 0:
        print(f'Training loss at step {i}: {running_loss/(i*batch_size)}')
    input_ids = batch['idx'].long()
    labels = batch['labels']

    outputs = subnetworrk_softmask(input_ids)

    loss = criterion(outputs, labels)

    loss.backward()
    if e < 1 and step < 1:
      for n, p in subnetworrk_softmask.named_parameters():
        if p.grad is not None:
            print(f'Gradient of param "{n}" with size {tuple(p.size())} detected')


    optimizer.step()
    optimizer.zero_grad()

    running_loss += loss.item()


    print(f'Training loss: {running_loss / (len(data_loader) * batch_size)}')



Gradient of param "fc1.scores.0" with size (30, 50) detected
Gradient of param "fc2.scores.0" with size (10, 30) detected
Gradient of param "head.weight" with size (1, 10) detected
Gradient of param "head.bias" with size (1,) detected
Training loss: 0.12747827172279358
Training loss: 0.25286777317523956
Training loss: 0.11871679127216339
Training loss: 0.2254461646080017
Training loss: 0.10435687005519867
Training loss: 0.21180221438407898
Training loss: 0.09677727520465851
Training loss: 0.19502855837345123
Training loss: 0.09435197710990906
Training loss: 0.1859283596277237
Training loss: 0.0890825018286705
Training loss: 0.17168926447629929
Training loss: 0.08337169885635376
Training loss: 0.1547119840979576
Training loss: 0.08222860842943192
Training loss: 0.1536419317126274
Training loss: 0.0745910257101059
Training loss: 0.1509472206234932
Training loss: 0.06745259463787079
Training loss: 0.1325201839208603


After fine-tuning the first task (``ft_task=0``), we need to calculate the importance of the scores in each layer. This calculation is based on cross-entropy. Once determined using the gradient, we then normalize the importance. Additionally, we copy the trained scores to the next task as an initialization step, allowing knowledge transfer to the subsequent task.

In [146]:
ft_task = 0
set_compute_mask_impt(subnetworrk_softmask, True)
set_ft_task(subnetworrk_softmask,ft_task)

tss_impt_dict = {}

for step, batch in enumerate(data_loader):
  input_ids = batch['idx'].long()
  labels = batch['labels']

  outputs = subnetworrk_softmask(input_ids)
  loss = criterion(outputs, labels)
  loss.backward()

  for n, m in subnetworrk_softmask.named_modules():
      if isinstance(m, NNSubnetworkSoftmask):
          if n in tss_impt_dict:
              tss_impt_dict[n] += m.alphas.grad.clone().detach()
          else:
              tss_impt_dict[n] = m.alphas.grad.clone().detach()


subnetworrk_softmask.zero_grad() # Remove gradients

# Normalize the importance
def impt_norm(impt):
    tanh = torch.nn.Tanh()
    for layer in range(impt.size(0)):
        impt[layer] = (impt[layer] - impt[layer].mean()) / impt[
            layer].std()  # 2D, we need to deal with this for each layer
    impt = tanh(impt).abs()

    return impt

for n, m in subnetworrk_softmask.named_modules():
    if isinstance(m, NNSubnetworkSoftmask):
        with torch.no_grad():
            m.impt_mask[ft_task][impt_norm(tss_impt_dict[n]) >= 0.5] = 1
            print(f'Name and usage: {n}, {(m.impt_mask[ft_task].sum() / m.impt_mask[ft_task].numel()).item()}') # importance mask

# Copy the scores
for n, m in subnetworrk_softmask.named_modules():
    if isinstance(m, NNSubnetworkSoftmask):
        m.copy_score(ft_task)



Name and usage: fc1, 0.6053333282470703
Name and usage: fc2, 0.5266666412353516


Now we can begin training the second task (``ft_task=2``). For simplicity, we use the same data as in the first task. During training, we apply soft-masking to the gradients of the scores to preserve previous knowledge. After training the second task, we need to compute the importance as in the previous steps (not shown in the code).

In [144]:
ft_task = 1
criterion = nn.BCELoss()
optimizer = optim.Adam(subnetworrk_softmask.parameters(), lr=0.003)
set_compute_mask_impt(subnetworrk_softmask,False)
set_ft_task(subnetworrk_softmask,ft_task)

epochs = 10
for e in range(epochs):
  running_loss = 0
  i = 0
  for step, batch in enumerate(data_loader):
    i += 1
    if i % 100 == 0:
        print(f'Training loss at step {i}: {running_loss/(i*batch_size)}')
    input_ids = batch['idx'].long()
    labels = batch['labels']

    outputs = subnetworrk_softmask(input_ids)

    loss = criterion(outputs, labels)

    loss.backward()
    if e < 1 and step < 1:
      for n, p in subnetworrk_softmask.named_parameters():
        if p.grad is not None:
            print(f'Gradient of param "{n}" with size {tuple(p.size())} detected')

    for n, m in subnetworrk_softmask.named_modules():
        if isinstance(m, NNSubnetworkSoftmask):
            m.scores[ft_task].grad *= (1-tss_impt_dict[n])

    optimizer.step()
    optimizer.zero_grad()

    running_loss += loss.item()

    print(f'Training loss: {running_loss / (len(data_loader) * batch_size)}')



Gradient of param "fc1.scores.1" with size (30, 50) detected
Gradient of param "fc2.scores.1" with size (10, 30) detected
Gradient of param "head.weight" with size (1, 10) detected
Gradient of param "head.bias" with size (1,) detected
Training loss: 0.06010271608829498
Training loss: 0.12154484912753105
Training loss: 0.051853425800800323
Training loss: 0.10218348354101181
Training loss: 0.045971132814884186
Training loss: 0.09850922599434853
Training loss: 0.04364640638232231
Training loss: 0.08559411019086838
Training loss: 0.03648237884044647
Training loss: 0.08484786003828049
Training loss: 0.04133094474673271
Training loss: 0.07143541425466537
Training loss: 0.031948160380125046
Training loss: 0.06391788274049759
Training loss: 0.02935045398771763
Training loss: 0.06814335472881794
Training loss: 0.028552938252687454
Training loss: 0.058869631960988045
Training loss: 0.027650360018014908
Training loss: 0.049211161211133
