In [1]:
import torch
import torch.random as random
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import IterableDataset, DataLoader
from tqdm.auto import tqdm

In [2]:
import numpy as np


class PermutedMNIST(IterableDataset):
    """Endless stream of MNIST training digits under a gradually drifting pixel permutation.

    Each task is identified by a ``permutation_id`` in ``[0, num_permutations)``
    and corresponds to one fixed, deterministically seeded shuffle of the
    28 * 28 pixels. The stream moves from the current task to the next one over
    ``transition_steps`` samples: at every step the sample is drawn from the next
    task with probability ``self.transition`` and from the current task
    otherwise. After the last task the schedule wraps around to task 0.

    Image indices and the current/next task choice are drawn from NumPy's
    global RNG (``np.random``); seed it for a reproducible stream.

    Args:
        num_permutations: Number of distinct permutations (tasks); must be >= 1.
        transition_steps: Number of samples over which one task blends into the
            next; must be > 0.
        root: Directory handed to ``datasets.MNIST``.
        download: Forwarded to ``datasets.MNIST``.

    Yields:
        ``(img, target, permutation_id)`` where ``img`` is a tensor of shape
        ``(28, 28)`` holding the permuted pixels, and ``target`` and
        ``permutation_id`` are 0-dim ``torch.long`` tensors.

    Raises:
        ValueError: If ``num_permutations`` or ``transition_steps`` is not positive.
    """

    def __init__(self, num_permutations=10, transition_steps=1000, root="./data", download=True):
        if num_permutations < 1:
            raise ValueError(f"num_permutations must be >= 1, got {num_permutations}")
        if transition_steps <= 0:
            raise ValueError(f"transition_steps must be > 0, got {transition_steps}")
        self.root = root
        self.transform = transforms.ToTensor()
        self.download = download
        self.dataset = datasets.MNIST(root=self.root, train=True, download=self.download)
        self.permutation_id = 0
        # Probability of sampling from the *next* permutation; always derived
        # from the integer counter below.
        self.transition = 0.0
        self.transition_steps = transition_steps
        self.num_permutations = num_permutations
        # Integer position inside the current transition window. Accumulating
        # 1.0 / transition_steps in floating point drifts (ten additions of 0.1
        # give 0.9999999999999999), which delayed the task switch by one sample.
        self._transition_step = 0
        # permutation id -> index tensor; a permutation is fully determined by
        # its id, so it only has to be generated once.
        self._permutations = {}

    def _get_permutation(self, permutation_id: int):
        """Return the pixel permutation (index tensor of length 28 * 28) for a task id.

        The returned tensor is cached and shared between calls; callers must
        not modify it in place.
        """
        key = int(permutation_id)
        perm = self._permutations.get(key)
        if perm is None:
            # fork_rng restores the global torch RNG state on exit, so seeding
            # here does not affect random draws made elsewhere.
            with random.fork_rng():
                torch.manual_seed(key + 0xdeadbeef)
                perm = torch.randperm(28 * 28)
            self._permutations[key] = perm
        return perm

    def __iter__(self):
        while True:
            next_id = (self.permutation_id + 1) % self.num_permutations
            # Simulate a gradual transition: pick the next task with
            # probability `transition`, the current one otherwise.
            permutation_id = int(
                np.random.choice(
                    [self.permutation_id, next_id],
                    p=[1 - self.transition, self.transition],
                )
            )
            # Advance the schedule; switch tasks exactly every `transition_steps` samples.
            self._transition_step += 1
            if self._transition_step >= self.transition_steps:
                self._transition_step = 0
                self.permutation_id = next_id
            self.transition = self._transition_step / self.transition_steps

            perm = self._get_permutation(permutation_id)
            img_index = np.random.randint(0, len(self.dataset))
            img, target = self.dataset[img_index]
            # Flatten, shuffle the pixels, and restore the 28x28 layout.
            img = self.transform(img).view(-1)[perm].view(28, 28)
            target = torch.tensor(target, dtype=torch.long)
            yield img, target, torch.tensor(permutation_id, dtype=torch.long)

In [3]:
class TaskAwareMLP(nn.Module):
    """MLP classifier conditioned on a learned per-task embedding.

    The flattened input is concatenated with the embedding row of its task and
    passed through ``depth`` blocks of ``LayerNorm -> Linear``, with ``ReLU``
    after every block except the last. Embedding rows are created lazily: the
    first time ``forward`` sees a task id that has no row yet, rows are appended
    and the optimizer is rebuilt (see ``add_new_task``).

    The module owns its optimizer (``self.optimizer``, Adam with lr=1e-4).

    Args:
        input_dim: Number of input features after flattening.
        output_dim: Number of output classes.
        task_embedding_dim: Width of each task embedding row.
        depth: Number of ``LayerNorm -> Linear`` blocks.
        hidden_dim: Width of the hidden layers.
        output_softmax: If True (default, the historical behavior) the network
            ends in ``nn.Softmax(dim=1)`` and returns probabilities. Pass False
            to return raw logits, which is what logit-based losses such as
            ``nn.CrossEntropyLoss`` expect.
    """

    def __init__(self, input_dim, output_dim, task_embedding_dim, depth=3, hidden_dim=128, output_softmax=True):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.task_embedding_dim = task_embedding_dim
        self.depth = depth
        self.hidden_dim = hidden_dim
        self.output_softmax = output_softmax
        # One row per known task; starts with a single task (id 0).
        self.task_embedding = nn.Parameter(torch.randn((1, task_embedding_dim)))
        self.flatten = nn.Flatten()
        self.layers = nn.Sequential()
        for i in range(depth):
            is_last = i == depth - 1
            in_dim = input_dim + task_embedding_dim if i == 0 else hidden_dim
            out_dim = output_dim if is_last else hidden_dim
            self.layers.append(nn.LayerNorm(in_dim))
            self.layers.append(nn.Linear(in_dim, out_dim))
            if not is_last:
                self.layers.append(nn.ReLU())
            elif output_softmax:
                self.layers.append(nn.Softmax(dim=1))
        self.optimizer: torch.optim.Optimizer
        self.reset_optimizer()

    def reset_optimizer(self):
        """Create a fresh Adam optimizer over the current parameters.

        Any state held by the previous optimizer is discarded.
        """
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)

    def add_new_task(self):
        """Append one randomly initialised embedding row and rebuild the optimizer.

        ``task_embedding`` is replaced by a new ``nn.Parameter``, so the old
        optimizer would keep pointing at the stale tensor; hence the reset.
        """
        existing = self.task_embedding.data
        # Create the new row on the same device/dtype as the existing table so
        # the concatenation also works after the model has been moved or cast.
        new_embedding = torch.randn(
            (1, self.task_embedding_dim), device=existing.device, dtype=existing.dtype
        )
        self.task_embedding = nn.Parameter(torch.cat((existing, new_embedding), dim=0))
        self.reset_optimizer()

    def forward(self, x: torch.Tensor, task_ids: torch.Tensor):
        """Classify ``x`` given the task each sample belongs to.

        Args:
            x: Either a single 2-D sample (with a 0-dim ``task_ids``) or a batch
                of 2-D samples of shape ``(batch, H, W)`` (with ``task_ids`` of
                shape ``(batch,)``).
            task_ids: Integer task ids with two fewer dimensions than ``x``.

        Returns:
            Tensor of shape ``(batch, output_dim)``, or ``(output_dim,)`` for an
            unbatched input. Probabilities when ``output_softmax`` is True,
            logits otherwise.

        Side effects:
            Adds embedding rows (and resets ``self.optimizer``) when a task id
            without an embedding row is encountered.
        """
        assert len(task_ids.shape) == len(x.shape) - 2, f"task_ids must be 2 dimensions less than the input, instead got shape {len(task_ids.shape)} and {len(x.shape)}"
        needs_unbatching = False
        if len(x.shape) == 2:
            # Single sample: add a batch dimension for the layers, remove it at the end.
            task_ids = task_ids.unsqueeze(0)
            x = x.unsqueeze(0)
            needs_unbatching = True
        assert x.shape[0] == task_ids.shape[0], f"Batch size of x and task_ids must match, instead got {x.shape[0]} and {task_ids.shape[0]}"

        # Grow the embedding table until the largest requested id has a row.
        if task_ids.numel() > 0:
            missing = int(task_ids.max()) + 1 - self.task_embedding.shape[0]
            for _ in range(missing):
                self.add_new_task()
        task_embedding = self.task_embedding[task_ids].expand(x.shape[0], -1)
        x = self.flatten(x)
        x = torch.cat((x, task_embedding), dim=1)
        x = self.layers(x)
        if needs_unbatching:
            x = x.squeeze(0)
        return x

In [4]:
def fit(model: TaskAwareMLP, dataset: PermutedMNIST, criterion: nn.Module, max_steps=None, smoothing=0.1):
    """Train ``model`` online, one sample at a time, on the task stream ``dataset``.

    Args:
        model: Network to train; its own ``model.optimizer`` is used for the updates.
        dataset: Iterable yielding ``(data, target, task_id)`` samples and
            exposing ``num_permutations`` and ``permutation_id``.
        criterion: Loss called as ``criterion(output, target)``.
        max_steps: Stop after this many samples. ``None`` (default) consumes the
            dataset until it is exhausted, which for an endless stream means
            training runs until interrupted.
        smoothing: Weight of the newest loss in the per-task exponential moving average.

    Returns:
        List with one smoothed loss per task (0 for tasks that were never seen).
    """
    model.train()
    avg_loss_per_task = [0 for _ in range(dataset.num_permutations)]
    if max_steps is not None and max_steps <= 0:
        return avg_loss_per_task
    # Track which tasks already have an average, rather than treating a stored
    # value of 0 as "unseen".
    task_seen = [False] * dataset.num_permutations

    # The context manager closes the progress bar even when the loop is left early.
    with tqdm(dataset, desc="Training", total=max_steps) as loading_bar:
        for step, (data, target, task_id) in enumerate(loading_bar, start=1):
            model.optimizer.zero_grad()
            output: torch.Tensor = model(data, task_id)
            loss: torch.Tensor = criterion(output, target)
            loss.backward()
            # Read model.optimizer again here: the forward pass may have added a
            # task and replaced the optimizer.
            model.optimizer.step()

            loss_value = loss.item()
            task_index = int(task_id)
            if task_seen[task_index]:
                avg_loss_per_task[task_index] = (
                    avg_loss_per_task[task_index] * (1 - smoothing) + loss_value * smoothing
                )
            else:
                avg_loss_per_task[task_index] = loss_value
                task_seen[task_index] = True

            loading_bar.set_postfix(losses=np.array(avg_loss_per_task), task=dataset.permutation_id)
            # Check at the end of the body so no extra sample is pulled from the dataset.
            if max_steps is not None and step >= max_steps:
                break
    return avg_loss_per_task

In [5]:
# Seed every RNG the experiment draws from so a fresh "Restart & Run All"
# reproduces the same run: torch drives weight/embedding initialisation,
# NumPy drives PermutedMNIST's sample and task selection.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

model = TaskAwareMLP(input_dim=28*28, output_dim=10, task_embedding_dim=10, depth=4, hidden_dim=256)
train_dataset = PermutedMNIST(num_permutations=10, download=True)
# NOTE(review): the model built above ends in nn.Softmax, while
# nn.CrossEntropyLoss expects unnormalised logits and applies log-softmax
# itself, so softmax is effectively applied twice. Confirm whether this is intended.
criterion = nn.CrossEntropyLoss()

In [6]:
# PermutedMNIST iterates in an endless `while True` loop, so this call keeps
# training until the kernel is interrupted.
fit(model=model, dataset=train_dataset, criterion=criterion)

Training: 0it [00:00, ?it/s]

KeyboardInterrupt: 