# FlyModel demo 🍄

In this notebook, we'll show how to use our implementation of FlyModel, introduced in [Algorithmic insights on continual learning from fruit flies](https://arxiv.org/abs/2107.07617), to classify MNIST digits.

Useful links 👀:
* [paper](https://arxiv.org/pdf/2107.07617.pdf)
* [our repo](https://github.com/Ramos-Ramos/FlyModel/)

**Disclaimers 🚨:**
* we're not the original authors, we just took a crack at implementing it
* we haven't been able to perfectly reproduce their results

## How does it work? 🤔



FlyModel is composed of two layers. The first is an untrainable layer of sparse binary weights. The second is a set of trainable weights initially randomized between 0 and 1.

The first layer projects the $m$-dimensional input $x$ to a $d$-dimensional hidden representation $\psi(x)$. The top $l$ activations of $\psi(x)$ are kept while the rest are set to 0, forming the sparse hidden representation $\Phi(x)$, which is min-max normalized to lie between 0 and 1.
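
To make the first layer concrete, here is a minimal NumPy sketch of it. It assumes each hidden neuron is wired to a fixed number of inputs chosen uniformly at random (mirroring the `input_connections` argument used later) and that inputs are non-negative, like the sigmoid features our encoder produces. `make_projection` and `encode` are hypothetical helpers for illustration, not the `flymodel` package's API.

```python
import numpy as np

def make_projection(m, d, connections, rng):
    """Binary (d, m) matrix: each hidden neuron (row) is wired to `connections` random inputs."""
    proj = np.zeros((d, m))
    for row in range(d):
        cols = rng.choice(m, size=connections, replace=False)
        proj[row, cols] = 1.0
    return proj

def encode(x, proj, l):
    """psi(x) = proj @ x; keep the top-l activations, zero the rest, then min-max scale to [0, 1]."""
    psi = proj @ x
    phi = np.zeros_like(psi)
    top = np.argpartition(psi, -l)[-l:]   # indices of the l largest activations
    phi[top] = psi[top]
    # With non-negative inputs the minimum is one of the suppressed zeros, so they stay at 0.
    lo, hi = phi.min(), phi.max()
    return (phi - lo) / (hi - lo) if hi > lo else phi

# Sizes mirror the FlyModel configuration used later in this notebook.
rng = np.random.default_rng(0)
proj = make_projection(m=84, d=3200, connections=10, rng=rng)
x = rng.random(84)                        # stand-in for an 84-d encoder output in [0, 1)
phi = encode(x, proj, l=320)
print(np.count_nonzero(phi), phi.max())
```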

The second layer maps $\Phi(x)$ to $k$ output neurons, one per class. During training, a class index $j$ is also provided, and it determines which weights are updated: those between the active neurons of $\Phi(x)$ and the $j$-th output neuron.
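
And here is a matching sketch of the second layer. The additive update `W[:, j] += lr * phi` and the multiplicative form of the `decay` (forgetting) term are our assumptions for illustration, not necessarily the exact rule used in the paper or in `flymodel`; `FlyReadout` is a hypothetical name. Prediction picks the output neuron with the largest activation, which matches the `argmax` we take in the test loop later.

```python
import numpy as np

class FlyReadout:
    """Hypothetical sketch of FlyModel's trainable second layer (not the flymodel API)."""

    def __init__(self, d, k, lr=1e-2, decay=0.0, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.random((d, k))   # hidden -> output weights, initialized in [0, 1)
        self.lr = lr
        self.decay = decay

    def predict(self, phi):
        """Return the index of the output neuron with the largest activation phi @ W."""
        return int(np.argmax(phi @ self.W))

    def update(self, phi, j):
        """Assumed rule: shrink all weights by (1 - decay), then W[:, j] += lr * phi."""
        self.W *= 1.0 - self.decay        # assumed form of the forgetting term
        self.W[:, j] += self.lr * phi     # rows where phi == 0 (inactive neurons) are unchanged

readout = FlyReadout(d=8, k=2, lr=0.5)
phi = np.array([0.0, 1.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0])
before = readout.W.copy()
readout.update(phi, j=1)
changed = np.argwhere(~np.isclose(readout.W, before))
print(changed.tolist())                   # → [[1, 1], [3, 1]]
```

Because inactive hidden neurons contribute 0 to the update, each training sample only touches the weights from its active neurons to its own class's output neuron.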

## Installation and imports 🔧

In [1]:
pip install einops gradio pytorch_lightning torchmetrics git+https://github.com/Ramos-Ramos/flymodel

Collecting git+https://github.com/Ramos-Ramos/flymodel
  Cloning https://github.com/Ramos-Ramos/flymodel to /tmp/pip-req-build-o3nfgr99
  Running command git clone -q https://github.com/Ramos-Ramos/flymodel /tmp/pip-req-build-o3nfgr99
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Collecting einops
  Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Collecting gradio
  Downloading gradio-2.2.15-py3-none-any.whl (3.4 MB)
     |████████████████████████████████| 3.4 MB 23.9 MB/s 
Collecting pytorch_lightning
  Downloading pytorch_lightning-1.4.5-py3-none-any.whl (919 kB)
     |████████████████████████████████| 919 kB 41.4 MB/s 
Collecting torchmetrics
  Downloading torchmetrics-0.5.1-py3-none-any.whl (282 kB)
     |████████████████████████████████| 282 kB 52.2 MB/s 
Collecting cupy-cuda101
  Downloading cupy_cuda101-9.4.0-cp37-cp37m-manylinux1_x86_6

In [2]:
from einops import reduce, rearrange
from einops.layers.torch import Rearrange
import gradio as gr
import numpy as np
import pytorch_lightning as pl
from scipy.special import softmax
import torch
from torch import nn
import torch.nn.functional as F
import torchmetrics
from torchvision.datasets import KMNIST, MNIST
import torchvision.transforms as T
from tqdm.notebook import tqdm

from itertools import chain
import pickle

from flymodel import FlyModel

## Training an encoder 👁️


Since FlyModel is concerned with associating inputs with outputs rather than with learning representations, it doesn't take raw inputs; instead, it takes hidden representations already produced by an encoder.

Let's train an encoder. For MNIST and Fashion-MNIST, the authors used a modified LeNet-5 trained on KMNIST, and we can do the same in PyTorch Lightning.

Since this part isn't FlyModel itself, feel free to run these cells without reading them closely.

In [3]:
# define LeNet5 model
class LeNet5(pl.LightningModule):
  
  def __init__(self, classes=10):
    super().__init__()
    self.model = nn.Sequential(
        nn.Conv2d(1, 6, 5, padding=2),
        nn.BatchNorm2d(6),
        nn.Sigmoid(),
        nn.MaxPool2d(2, stride=2),
        nn.Conv2d(6, 16, 5),
        nn.BatchNorm2d(16),
        nn.Sigmoid(),
        nn.MaxPool2d(2, stride=2),
        Rearrange('b c h w -> b (c h w)'),
        nn.LazyLinear(120),
        nn.Sigmoid(),
        nn.Linear(120, 84),
        nn.Sigmoid(),
        nn.Linear(84, classes),
        nn.Sigmoid()
    )
    self.criterion = nn.CrossEntropyLoss()
    self.test_acc = torchmetrics.Accuracy()
    self.val_acc = torchmetrics.Accuracy()

  def forward(self, x):
    return self.model(x)
  
  def forward_features(self, x):
    # drop the final Linear + Sigmoid to return the 84-d penultimate activations fed to FlyModel
    return self.model[:-2](x)

  def shared_step(self, batch, stage):
    inputs, labels = batch
    outputs = self(inputs)
    loss = self.criterion(outputs, labels)
    self.log(f'{stage}_loss', loss, prog_bar=True)
    if stage != 'train':
      acc = getattr(self, f'{stage}_acc')
      acc(outputs, labels)
      self.log(f'{stage}_acc', acc, prog_bar=True)
    return loss
  
  def training_step(self, batch, batch_idx):
    return self.shared_step(batch, 'train')

  def test_step(self, batch, batch_idx):
    return self.shared_step(batch, 'test')
  
  def validation_step(self, batch, batch_idx):
    return self.shared_step(batch, 'val')

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), 1e-3)
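
`nn.LazyLinear(120)` infers its input size on the first forward pass. For 28 × 28 inputs, that size can be worked out by hand with the standard output-size formula for convolutions and pooling; the helper `conv_out` below is ours, for illustration only. The result likely also explains why the model summary printed during training reports only 13.6 K parameters: Lightning warns that it found an `UninitializedParameter`, so the lazy layer's parameters aren't counted yet.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (dilation 1): floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1

n = 28                                  # MNIST/KMNIST images are 28 x 28
n = conv_out(n, 5, padding=2)           # Conv2d(1, 6, 5, padding=2)  -> 28
n = conv_out(n, 2, stride=2)            # MaxPool2d(2, stride=2)      -> 14
n = conv_out(n, 5)                      # Conv2d(6, 16, 5)            -> 10
n = conv_out(n, 2, stride=2)            # MaxPool2d(2, stride=2)      -> 5
flat = 16 * n * n                       # Rearrange('b c h w -> b (c h w)') -> 400 features
lazy_params = flat * 120 + 120          # weights + biases that LazyLinear(120) will create
print(flat, lazy_params)                # → 400 48120
```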

In [4]:
# define KMNIST datamodule
class LitKMNIST(pl.LightningDataModule):

  def __init__(self, root='./', batch_size=64):
    super().__init__()
    self.root = root
    self.batch_size = batch_size

  def prepare_data(self):
    KMNIST(self.root, train=True, download=True)
    KMNIST(self.root, train=False, download=True)

  def setup(self, stage=None):
    if stage == 'fit' or stage is None:
      self.trainset = KMNIST(self.root, train=True, transform=T.ToTensor())
    
    if stage in ('test', 'validate') or stage is None:
      self.testset = KMNIST(self.root, train=False, transform=T.ToTensor())

  def train_dataloader(self):
    return torch.utils.data.DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True, num_workers=2)
  
  def test_dataloader(self):
    return torch.utils.data.DataLoader(self.testset, batch_size=self.batch_size, num_workers=2)

  def val_dataloader(self):
    if not hasattr(self, 'testset'):
      self.setup('validate')
    return torch.utils.data.DataLoader(self.testset, batch_size=self.batch_size, num_workers=2)

In [5]:
# train and save checkpoint
encoder = LeNet5()
dm = LitKMNIST()
trainer = pl.Trainer(max_epochs=25)
trainer.fit(encoder, dm)
trainer.save_checkpoint('kmnist_lenet5.pth')

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz
Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz to ./KMNIST/raw/train-images-idx3-ubyte.gz


Extracting ./KMNIST/raw/train-images-idx3-ubyte.gz to ./KMNIST/raw

Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz
Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz to ./KMNIST/raw/train-labels-idx1-ubyte.gz


Extracting ./KMNIST/raw/train-labels-idx1-ubyte.gz to ./KMNIST/raw

Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz
Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz to ./KMNIST/raw/t10k-images-idx3-ubyte.gz


Extracting ./KMNIST/raw/t10k-images-idx3-ubyte.gz to ./KMNIST/raw

Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz
Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz to ./KMNIST/raw/t10k-labels-idx1-ubyte.gz


Extracting ./KMNIST/raw/t10k-labels-idx1-ubyte.gz to ./KMNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
  "A layer with UninitializedParameter was found. "

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 13.6 K
1 | criterion | CrossEntropyLoss | 0     
2 | test_acc  | Accuracy         | 0     
3 | val_acc   | Accuracy         | 0     
-----------------------------------------------
13.6 K    Trainable params
0         Non-trainable params
13.6 K    Total params
0.055     Total estimated model params size (MB)


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)

## Training and Evaluating FlyModel 💪

Now that we have an encoder, we can train a FlyModel. FlyModel is trained and evaluated on continual learning. The authors train the model on a sequence of classification tasks, each of which is a pair of consecutive classes that doesn't overlap with any other task (0 vs 1, 2 vs 3, etc.). Within each task, the model goes through the classes one at a time for a single epoch (e.g., all 0s, then all 1s).
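
The task schedule and the per-task ordering can be sketched in plain NumPy. `make_tasks` and `task_stream` are hypothetical helpers, not part of the notebook; the notebook itself hardcodes the task list and uses the `MNISTCL` dataset below, which filters and sorts by label in the same spirit.

```python
import numpy as np

def make_tasks(num_classes=10, classes_per_task=2):
    """Split class indices into consecutive, non-overlapping tasks: (0, 1), (2, 3), ..."""
    return [tuple(range(start, start + classes_per_task))
            for start in range(0, num_classes, classes_per_task)]

def task_stream(labels, task):
    """Indices of the samples in `task`, ordered so each class forms one contiguous block."""
    labels = np.asarray(labels)
    idx = np.flatnonzero(np.isin(labels, task))        # keep only this task's classes
    return idx[np.argsort(labels[idx], kind="stable")]  # all of class 0, then all of class 1, ...

print(make_tasks())                                   # → [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
print(task_stream([3, 0, 1, 2, 0, 1, 3], (0, 1)))     # → [1 4 2 5]
```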

FlyModel can be evaluated by measuring its accuracy on all the tasks it has been trained on so far, each time it finishes a new task. The authors also evaluate using a memory loss metric, but we focus on accuracy for this demo.

Since we're using MNIST, let's make a continual learning version of the dataset in PyTorch.

In [6]:
# take only the classes we need and sort them
class MNISTCL(MNIST):
  
  def __init__(self, *, classes, **kwargs):
    super().__init__(**kwargs)
    self.classes = classes
    filter = reduce(
        torch.stack([self.targets.eq(c) for c in classes]), 'c i -> i', 'sum'
    ).bool()
    self.targets, indices = self.targets[filter].sort()
    self.data = self.data[filter][indices]

The train loop goes through each sample in the dataset, encodes the input with the encoder, and feeds the encoded representation together with its class to the model. Note that we need to call `model.train()` so that the weights actually update.

In [7]:
def train(model, encoder, trainloader):
  model.train()
  encoder.eval()
  for (input, label) in tqdm(iter(trainloader)):
    with torch.no_grad():
      input = encoder.forward_features(input).numpy()
    label = label.numpy()
    model(input, label)

The test loop mirrors the train loop, except that we call `model.eval()` to keep the weights from updating (they couldn't update anyway, since we don't pass the classes when evaluating). It returns the accuracy on the test set.

In [8]:
def test(model, encoder, testloader):
  model.eval()
  encoder.eval()
  total_correct = 0
  for (input, label) in tqdm(iter(testloader)):
    with torch.no_grad():
      input = encoder.forward_features(input).numpy()
    label = int(label)
    outputs = model(input)
    pred = outputs.argmax()
    if label == pred:
      total_correct += 1
  # the test loader uses the default batch_size=1, so len(testloader) is the number of samples
  accuracy = total_correct / len(testloader)
  return accuracy

Since we'll create a train or test loader right before almost every round of training or testing, let's write helper functions that build the dataloader and then run the corresponding loop.

In [9]:
def create_trainloader_and_train(model, encoder, classes, root='./'):
  trainset = MNISTCL(classes=classes, root=root, train=True, download=True, transform=T.ToTensor())
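  # DataLoader defaults (batch_size=1, shuffle=False) feed one sample at a time in class-sorted order
  trainloader = torch.utils.data.DataLoader(trainset, num_workers=2)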
  train(model, encoder, trainloader)

def create_testloader_and_test(model, encoder, classes, root='./'):
  testset = MNISTCL(classes=classes, root=root, train=False, download=True, transform=T.ToTensor())
  testloader = torch.utils.data.DataLoader(testset, num_workers=2)
  accuracy = test(model, encoder, testloader)
  return accuracy

Now we can create the FlyModel and the LeNet5 encoder, loading the encoder weights from the checkpoint we saved earlier.

In [10]:
model = FlyModel(
  input_size=84,        # input dimension size (no. of projection neurons)
  hidden_size=3200,     # hidden dimension size (no. of Kenyon cells)
  output_size=10,       # output dimension size (no. of mushroom body output neurons)
  top_activations=320,  # no. of top cells to be left active in hidden layer
  lr=1e-2,              # learning rate (learning is performed internally)
  decay=0,              # forgetting term
  input_connections=10  # number of inputs to connect to for each hidden neuron; alternatively, `input_density` can be specified
)
encoder = LeNet5.load_from_checkpoint('kmnist_lenet5.pth')



In [11]:
tasks = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
for i, classes in enumerate(tasks):
  create_trainloader_and_train(model, encoder, classes)
  print(f'trained on task {i}')

  accuracy = create_testloader_and_test(model, encoder, classes)
  print(f'\taccuracy on task {i}: {accuracy}')

  if i > 0:
    accuracy = create_testloader_and_test(model, encoder, list(chain.from_iterable(tasks[:i+1])))
    print(f'\taccuracy on all tasks so far after training up to task {i}: {accuracy}')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNISTCL/raw/train-images-idx3-ubyte.gz


Extracting ./MNISTCL/raw/train-images-idx3-ubyte.gz to ./MNISTCL/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNISTCL/raw/train-labels-idx1-ubyte.gz


Extracting ./MNISTCL/raw/train-labels-idx1-ubyte.gz to ./MNISTCL/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNISTCL/raw/t10k-images-idx3-ubyte.gz


Extracting ./MNISTCL/raw/t10k-images-idx3-ubyte.gz to ./MNISTCL/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNISTCL/raw/t10k-labels-idx1-ubyte.gz


Extracting ./MNISTCL/raw/t10k-labels-idx1-ubyte.gz to ./MNISTCL/raw



  0%|          | 0/12665 [00:00<?, ?it/s]

trained on task 0


  0%|          | 0/2115 [00:00<?, ?it/s]

	accuracy on task 0: 0.9895981087470449


  0%|          | 0/12089 [00:00<?, ?it/s]

trained on task 1


  0%|          | 0/2042 [00:00<?, ?it/s]

	accuracy on task 1: 0.8227228207639569


  0%|          | 0/4157 [00:00<?, ?it/s]

	accuracy on all tasks so far after training up to task 1: 0.8874188116430118


  0%|          | 0/11263 [00:00<?, ?it/s]

trained on task 2


  0%|          | 0/1874 [00:00<?, ?it/s]

	accuracy on task 2: 0.8014941302027748


  0%|          | 0/6031 [00:00<?, ?it/s]

	accuracy on all tasks so far after training up to task 2: 0.8214226496435085


  0%|          | 0/12183 [00:00<?, ?it/s]

trained on task 3


  0%|          | 0/1986 [00:00<?, ?it/s]

	accuracy on task 3: 0.7880161127895267


  0%|          | 0/8017 [00:00<?, ?it/s]

	accuracy on all tasks so far after training up to task 3: 0.7906947736060871


  0%|          | 0/11800 [00:00<?, ?it/s]

trained on task 4


  0%|          | 0/1983 [00:00<?, ?it/s]

	accuracy on task 4: 0.5229450327786183


  0%|          | 0/10000 [00:00<?, ?it/s]

	accuracy on all tasks so far after training up to task 4: 0.7227


Let's save our model weights for future use.

In [12]:
with open('weights.pkl', 'wb') as file:
  pickle.dump(model.state_dict(), file)

## Interactive demo ✏️

Here we use Gradio to let you draw your own digits and have FlyModel classify them. Make sure you've run all the previous cells, since this cell loads the saved encoder checkpoint (`kmnist_lenet5.pth`) and FlyModel weights (`weights.pkl`).

Have fun!

In [14]:
model = FlyModel(
    input_size=84,
    hidden_size=3200,
    output_size=10,
    top_activations=320,
    lr=1e-2,
    decay=0,
    input_connections=10
)
with open('weights.pkl', 'rb') as file:
  state_dict = pickle.load(file)
model.load_state_dict(state_dict)
encoder = LeNet5.load_from_checkpoint('kmnist_lenet5.pth')
model.eval()
encoder.eval()

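transforms = T.Compose([
    Rearrange('h w -> h w ()'),        # (H, W) sketch -> (H, W, 1) array
    T.ToTensor(),                      # -> (1, H, W) tensor
    Rearrange('c h w -> () c h w'),    # add a batch dimension -> (1, 1, H, W)
    T.Lambda(lambda x: x.float()),
    T.Lambda(lambda x: x / x.max())    # scale to [0, 1]; a blank canvas (max 0) would give NaNs
])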

def recognize_digit(img):
  img = transforms(img)

  with torch.no_grad():
    input = encoder.forward_features(img).numpy()
  outputs = softmax(rearrange(model(input), '() s -> s').tolist())
  
  labels_confs_dict = dict(zip(range(10), outputs))
  return labels_confs_dict

gr.Interface(fn=recognize_digit, inputs="sketchpad", outputs="label").launch()



Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`
This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)
Running on External URL: https://16906.gradio.app
Interface loading below...


(<Flask 'gradio.networking'>,
 'http://127.0.0.1:7861/',
 'https://16906.gradio.app')