### **Neural Networks Project** 
Sequencer: Deep LSTM for Image Classification
###### Vincenzo Guarino - 1742728

### Paper Introduction

Sequencer: Deep LSTM for Image Classification proposes a novel architecture for image classification, called Sequencer, based on the idea of replacing the self-attention layer presented in the Vision Transformer (ViT) with Long Short-Term Memory (LSTM) networks. 

The goal of this change is to improve memory efficiency while preserving the ability to learn long-range dependencies.

The paper also introduces a two-dimensional version of Sequencer, called Sequencer2D (the variant implemented in this project). It uses two bidirectional LSTMs to process the vertical and horizontal axes of the patch grid in parallel, which improves performance and shortens the sequences each LSTM has to process.

#### **Sequencer2D architecture**

| ![sequencer architecture](https://i.imgur.com/ZEbPTmD.png)|
|:--:| 
| *Sequencer architecture* | 

The overall structure of the Sequencer architecture is represented in the image above.

As we can see, it takes as input non-overlapping patches of an image which are then processed by the main (repeating) component: the Sequencer2D block, represented here:

| ![sequencerblock architecture](https://i.imgur.com/cCQfuz3.png)|
|:--:| 
| *Sequencer2D Block* |

The Sequencer2D block has two sub-components: an MLP layer for channel mixing and a BiLSTM2D layer for token mixing.

The MLP layer is derived from the ViT architecture: it consists of two linear transformations with a GELU activation function in between.

The BiLSTM2D layer is structured as follows:

| ![bilstm2d architecture](https://i.imgur.com/Zkj1G8g.png)|
|:--:| 
| *BiLSTM2D layer* |

The BiLSTM2D layer consists of two plain BiLSTMs: a vertical one, which processes each column of patches as a sequence, and a horizontal one, which does the same for each row.

The outputs of both BiLSTMs are then concatenated and processed point-wise by a linear layer to obtain the final output of the layer.
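
As a minimal illustration (plain Python with the channel dimension omitted; the helper `grid_to_sequences` is hypothetical and not part of the project code), the snippet below shows which patches end up in the same sequence for each direction. The `BiLSTM2D` class in Part 4 performs the same split on tensors with `permute`/`reshape`, so that all columns (or rows) of a batch go through a single LSTM call.

```python
def grid_to_sequences(grid):
    """Return (vertical, horizontal) sequences for a 2-D grid given as a list of rows."""
    height = len(grid)
    width = len(grid[0])
    # Vertical BiLSTM input: one sequence per column, each of length H
    vertical = [[grid[r][c] for r in range(height)] for c in range(width)]
    # Horizontal BiLSTM input: one sequence per row, each of length W
    horizontal = [list(row) for row in grid]
    return vertical, horizontal

# A 2 x 3 grid of patch ids (H = 2, W = 3)
grid = [["p00", "p01", "p02"],
        ["p10", "p11", "p12"]]
vertical, horizontal = grid_to_sequences(grid)
print(vertical)    # [['p00', 'p10'], ['p01', 'p11'], ['p02', 'p12']]
print(horizontal)  # [['p00', 'p01', 'p02'], ['p10', 'p11', 'p12']]
```

Each patch appears in exactly one vertical and one horizontal sequence, which is why the two BiLSTM outputs can be concatenated position by position.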

Compared to the multi-head attention layer in ViT, the BiLSTM2D layer scales better to high-resolution images: its memory complexity is $(WC + HC)/2$, compared to $h(HW)^2$ for multi-head attention, where $h$ is the number of heads of the multi-head attention layer, $H$ is the number of tokens in the vertical direction, $W$ is the number of tokens in the horizontal direction, and $C$ is the channel dimension.

There is also an advantage in throughput: assuming $W = H$ for simplicity, the computational complexity of self-attention is $O(W^4C)$, compared to $O(WC^2)$ for the BiLSTM.
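
To see how the two memory terms quoted above scale, the sketch below simply evaluates them for a square token grid ($H = W$). The channel width $C = 192$ (the first-stage width used in Part 2) and the head count $h = 6$ are illustrative assumptions, and the results are proportional terms, not byte counts.

```python
def bilstm2d_memory(H, W, C):
    # Memory term for BiLSTM2D as stated above: (WC + HC) / 2
    return (W * C + H * C) / 2

def mhsa_memory(H, W, h):
    # Memory term for multi-head attention as stated above: h * (HW)^2
    return h * (H * W) ** 2

C, h = 192, 6  # illustrative channel width and (assumed) number of heads
for side in (32, 64):
    print(side, bilstm2d_memory(side, side, C), mhsa_memory(side, side, h))
# 32 6144.0 6291456
# 64 12288.0 100663296
```

Doubling the grid side doubles the BiLSTM2D term but multiplies the attention term by 16, since $(HW)^2$ grows with the fourth power of the side length.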

### Code implementation

### Part 1: Imports and initialization

In [1]:
import torch
import pytorch_lightning as pylight
import torch.nn.functional as F
from torch import nn, optim

torch.set_float32_matmul_precision("high")

### Part 2: Hyperparameters definition

In [2]:
# I decided to go with a class in order to conveniently organize all the hyperparameters 
class SequencerParams():

    def __init__(self, variant, dataset):
        # stage params as described in table 4 of the paper
        if variant.lower() == "xs":
            self.layers = [2, 2, 4, 3]
            self.drop_path = 0.1
        elif variant.lower() == "s":
            self.layers = [4, 3, 8, 3]
            self.drop_path = 0.1
        elif variant.lower() == "m":
            self.layers = [4, 3, 14, 3]
            self.drop_path = 0.2
        elif variant.lower() == "l":
            self.layers = [8, 8, 16, 4]
            self.drop_path = 0.4
        else:
            raise ValueError("Invalid variant: {}".format(variant))
        self.dropout = 0.0
        self.mlp_ratio = 3
        self.stage_num = 4
        self.mixup_alpha = 0.8
        self.cutmix_alpha = 1.0
        self.label_smoothing = 0.1
        self.weight_decay = 0.05
        self.cycle_decay = 0.5
        self.lr_min = 1e-6
        self.random_erasing = 0.25
        self.crop_pct = 0.875
        self.auto_augment = "rand-m9-mstd0.5-inc1"
        
        if dataset == "imagenet":
            self.input_size = [3, 224, 224]
            self.num_classes = 1000  # ImageNet-1K
            self.dataset_name = "imagenet"
        elif dataset == "tiny-imagenet-200":
            self.input_size = [3, 64, 64]
            self.num_classes = 200
            self.dataset_name = "tiny-imagenet-200"
        elif dataset == "cifar10":
            self.input_size = [3, 32, 32]
            self.num_classes = 10
            self.dataset_name = "cifar10"
        else:
            raise ValueError("Invalid dataset name: {}".format(dataset))
        
        # opt and training params
        self.epochs = 300
        self.cooldown_epochs = 10
        self.warmup_epochs = 20
        self.batch_size = 256
        self.img_size = self.input_size[1]
        self.embed_dims = [192, 384, 384, 384]
        self.hidden_dims = [48, 96, 96, 96]
        self.patch_sizes = [7, 2, 1, 1]
        self.base_lr = 2e-3

seq_params = SequencerParams("S", "cifar10")

### Part 3a: Prepare Tiny-Imagenet 

In [3]:
# code edited from https://github.com/pytorch/vision/issues/6127#issuecomment-1555049003
import os
import re
def create_dir(base_path, classname):
    path = os.path.join(base_path, classname)
    if not os.path.exists(path):
        os.mkdir(path)

def reorg(filename, base_path, wordmap):
    # Move each validation image from <val>/images/ into <val>/<classname>/
    with open(filename) as annotations:
        for line in annotations:
            fields = line.split()
            imagename = fields[0]
            classname = wordmap[fields[1]]
            old_path = os.path.join(base_path, 'images', imagename)
            new_path = os.path.join(base_path, classname, imagename)
            if os.path.exists(old_path):
                os.rename(old_path, new_path)
    os.rmdir(os.path.join(base_path, 'images'))

def prepare_tiny_imagenet(imagenet_path):
    if not os.path.exists(os.path.join(imagenet_path, 'val', 'images')):
        return

    wordmap = {}
    words_file = os.path.join(imagenet_path, 'words.txt')
    wnids_file = os.path.join(imagenet_path, 'wnids.txt')
    with open(words_file) as words, open(wnids_file) as wnids:
        for line in wnids:
            vals = line.split()
            wordmap[vals[0]] = ""
        for line in words:
            vals = line.split()
            if vals[0] in wordmap:
                single_words = vals[1:]
                classname = re.sub(",", "", single_words[0])
                if len(single_words) >= 2:
                    classname += '_' + re.sub(",", "", single_words[1])
                wordmap[vals[0]] = classname
                create_dir(os.path.join(imagenet_path, 'val'), classname)
                old_train_dir = os.path.join(imagenet_path, 'train', vals[0])
                new_train_dir = os.path.join(imagenet_path, 'train', classname)
                if os.path.exists(old_train_dir):
                    os.rename(old_train_dir, new_train_dir)

    val_annotations_file = os.path.join(imagenet_path, 'val', 'val_annotations.txt')
    val_dir = os.path.join(imagenet_path, 'val')
    reorg(val_annotations_file, val_dir, wordmap)
    
    train_dir = os.path.join(imagenet_path, "train")
    classes = os.listdir(train_dir)
    for classname in classes:
        source_dir = os.path.join(train_dir, classname, "images")
        dest_dir = os.path.join(train_dir, classname)
        for file in os.listdir(source_dir):
            source_file = os.path.join(source_dir, file)
            dest_file = os.path.join(dest_dir, file)
            os.rename(source_file, dest_file)
        os.rmdir(source_dir)
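
After `prepare_tiny_imagenet` has run, `train/` and `val/` should both follow the one-folder-per-class layout that timm's folder dataset expects. The hypothetical helper below (not part of the original code) is a quick sanity check of that layout: it counts the image files in each class folder and skips stray files such as `val_annotations.txt` or the `*_boxes.txt` files. For the standard Tiny ImageNet release one would expect 200 classes, with 500 training and 50 validation images each.

```python
import os
from collections import Counter

IMAGE_EXTENSIONS = (".jpeg", ".jpg", ".png")

def count_images_per_class(split_dir):
    """Count image files in each class sub-folder of an ImageFolder-style directory."""
    counts = Counter()
    for classname in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, classname)
        if not os.path.isdir(class_dir):
            continue  # skip files that sit next to the class folders
        counts[classname] = sum(
            1 for f in os.listdir(class_dir) if f.lower().endswith(IMAGE_EXTENSIONS))
    return counts

# Hypothetical usage after prepare_tiny_imagenet("./tiny-imagenet-200/"):
# train_counts = count_images_per_class("./tiny-imagenet-200/train")
# val_counts = count_images_per_class("./tiny-imagenet-200/val")
# print(len(train_counts), set(train_counts.values()), set(val_counts.values()))
```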

### Part 3b: Dataset preparation 

In [4]:
from timm.data import create_dataset, create_loader
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split

dataset_path = "./" + seq_params.dataset_name + "/"

if seq_params.dataset_name == "imagenet":
    traindir = os.path.join(dataset_path, 'train')
    valdir = os.path.join(dataset_path, 'val')

    os.makedirs(traindir, exist_ok=True)
    os.makedirs(valdir, exist_ok=True)

    # Create the train and val datasets
    train_dataset = create_dataset(name="",
        root=traindir, split="train", is_training=True,
        batch_size=seq_params.batch_size)

    val_dataset = create_dataset(name="",
        root=valdir, split="val", is_training=False,
        batch_size=seq_params.batch_size)
        
elif seq_params.dataset_name == "tiny-imagenet-200":
    traindir = os.path.join(dataset_path, 'train')
    valdir = os.path.join(dataset_path, 'val')

    os.makedirs(traindir, exist_ok=True)
    os.makedirs(valdir, exist_ok=True)

    prepare_tiny_imagenet(dataset_path)

    # Create the train and val datasets
    train_dataset = create_dataset(name="",
        root=traindir, split="train", is_training=True,
        batch_size=seq_params.batch_size)

    val_dataset = create_dataset(name="",
        root=valdir, split="val", is_training=False,
        batch_size=seq_params.batch_size)

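elif seq_params.dataset_name == "cifar10":
    # Two separate dataset objects over the same 50,000 training images, so that
    # create_loader (Part 5) can attach a different transform to each of them
    train_base = create_dataset(name="torch/cifar10",
        root=dataset_path, split="train", download=True, is_training=True,
        batch_size=seq_params.batch_size)
    val_base = create_dataset(name="torch/cifar10",
        root=dataset_path, split="train", download=True, is_training=False,
        batch_size=seq_params.batch_size)

    # Stratified 90/10 split of the training images into train and val indices
    train_indices, val_indices = train_test_split(
        list(range(len(train_base))),
        stratify=train_base.targets,
        test_size=0.1,
        random_state=42,
    )

    # Restrict each torchvision CIFAR10 object to its own indices (images are in
    # .data, labels in .targets), so the two splits never overlap
    train_base.data = train_base.data[train_indices]
    train_base.targets = [train_base.targets[i] for i in train_indices]
    val_base.data = val_base.data[val_indices]
    val_base.targets = [val_base.targets[i] for i in val_indices]

    # Subset wrappers keep the .dataset attribute used by the loaders in Part 5
    train_dataset = Subset(train_base, range(len(train_base)))
    val_dataset = Subset(val_base, range(len(val_base)))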
else:
    # Raise an error if the dataset name is not valid
    raise ValueError("Invalid dataset name: {}".format(seq_params.dataset_name))

Files already downloaded and verified
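
A train/validation split is only valid if the two index lists are disjoint and both refer to positions in the same, full training set. The sketch below is a plain-Python stand-in for `train_test_split(..., stratify=...)` that makes both properties easy to check; the helper `stratified_split` is hypothetical, and its per-class rounding may differ from scikit-learn's.

```python
import random
from collections import defaultdict

def stratified_split(targets, val_fraction=0.1, seed=42):
    """Split indices 0..len(targets)-1 into disjoint (train, val) lists,
    keeping roughly the same class proportions in both."""
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)
    rng = random.Random(seed)
    train_idx, val_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_val = round(len(indices) * val_fraction)
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)

# Toy labels: 20 samples of class 0 and 10 samples of class 1
targets = [0] * 20 + [1] * 10
train_idx, val_idx = stratified_split(targets, val_fraction=0.1)
print(len(train_idx), len(val_idx))  # 27 3
```

Both lists index the original labels, so they must be applied to the full dataset, never to a dataset that has already been subsetted.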


### Part 4: Model building

In [5]:
# These classes are derived from the Vision Transformer and the paper gives no implementation details for them,
# so I adapted them from the official GitHub implementation

class PatchEmbedding(pylight.LightningModule):
    def __init__(self, patch_size, embed_dim, flatten=False):
        super().__init__()
        self.flatten = flatten
        self.conv = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.conv(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2) 
        else:
            x = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
        return x

class DownsamplePatch(pylight.LightningModule):
    def __init__(self, input_dim, output_dim, patch_size):
        super().__init__()
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)  # BHWC -> BCHW for the convolution
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
        return x
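
Both modules use a convolution whose kernel size equals its stride, so each one shrinks the token grid by its patch size (an unpadded convolution drops any leftover border pixels). The sketch below, a hypothetical helper built only on the standard convolution output-size formula, computes the side of the token grid at each stage for the `patch_sizes` of Part 2 and the three input resolutions considered in this project.

```python
def conv_out(size, kernel, stride):
    # Output length of an unpadded convolution: floor((size - kernel) / stride) + 1
    return (size - kernel) // stride + 1

def stage_grids(img_size, patch_sizes):
    """Side of the square token grid at each stage, with kernel == stride == patch size."""
    grids = []
    size = img_size
    for p in patch_sizes:
        size = conv_out(size, p, p)
        grids.append(size)
    return grids

patch_sizes = [7, 2, 1, 1]  # values from Part 2
for img_size in (224, 64, 32):
    print(img_size, stage_grids(img_size, patch_sizes))
# 224 [32, 16, 16, 16]
# 64 [9, 4, 4, 4]
# 32 [4, 2, 2, 2]
```

On 32x32 CIFAR10 images the BiLSTMs therefore see sequences of length 4 in the first stage and 2 afterwards, compared to 32 and 16 at 224x224, which is consistent with the concern raised in Part 7 about low-resolution inputs.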

In [6]:
import torchmetrics
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.models.layers import DropPath
from torchmetrics import Accuracy


class BiLSTM2D(pylight.LightningModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.rnn_v = torch.nn.LSTM(
            input_size, hidden_size, num_layers=1, batch_first=True, bias=True, bidirectional=True)
        self.rnn_h = torch.nn.LSTM(
            input_size, hidden_size, num_layers=1, batch_first=True, bias=True, bidirectional=True)
        self.fc = torch.nn.Linear(4 * hidden_size, input_size)

    def forward(self, x):
        B, H, W, C = x.shape

        v, _ = self.rnn_v(x.permute(0, 2, 1, 3).reshape(-1, H, C))
        v = v.reshape(B, W, H, -1).permute(0, 2, 1, 3)
        h, _ = self.rnn_h(x.reshape(-1, W, C))
        h = h.reshape(B, H, W, -1)
        x = torch.cat([v, h], dim=-1)
        x = self.fc(x)
        return x


class Sequencer2DBlock(pylight.LightningModule):
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.bilstm2d = BiLSTM2D(embed_dim, hidden_dim)
        self.drop_path = DropPath(seq_params.drop_path)
        self.norm2 = nn.LayerNorm(embed_dim)

        mlp_in = embed_dim
        mlp_out = embed_dim
        mlp_hidden = int(embed_dim * seq_params.mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(mlp_in, mlp_hidden),
            nn.GELU(),
            nn.Dropout(seq_params.dropout),
            nn.Linear(mlp_hidden, mlp_out),
            nn.Dropout(seq_params.dropout),
        )

    def forward(self, x):
        # x: (B, H, W, C), channels-last
        # Token mixing with BiLSTM2D
        x = x + self.drop_path(self.bilstm2d(self.norm1(x)))
        # Channel mixing with MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class Sequencer2D(pylight.LightningModule):
    def __init__(self):
        super().__init__()
        self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=seq_params.num_classes)

        stage_list = []
        for stage in range(seq_params.stage_num):
            if stage == 0:
                stage_list.append(PatchEmbedding(patch_size=seq_params.patch_sizes[0],
                                                 embed_dim=seq_params.embed_dims[0]))
            for layer in range(seq_params.layers[stage]):
                stage_list.append(Sequencer2DBlock(embed_dim=seq_params.embed_dims[stage],
                                                   hidden_dim=seq_params.hidden_dims[stage]))
            if stage < len(seq_params.embed_dims) - 1:
                stage_list.append(DownsamplePatch(input_dim=seq_params.embed_dims[stage],
                                                  output_dim=seq_params.embed_dims[stage+1],
                                                  patch_size=seq_params.patch_sizes[stage+1]))
        self.stages = nn.Sequential(*stage_list)

        self.norm = nn.LayerNorm(seq_params.embed_dims[-1])
        self.fc = nn.Linear(seq_params.embed_dims[-1], seq_params.num_classes)
        

    def forward(self, x):
        x = self.stages(x)
        x = self.norm(x)
        # global average pooling
        x = x.mean(dim=(1, 2))
        x = self.fc(x)
        return x

    def configure_optimizers(self):
        opt = optim.AdamW(self.parameters(), lr=seq_params.base_lr, weight_decay=seq_params.weight_decay)
        scheduler = CosineLRScheduler(opt, t_initial=seq_params.epochs,
                                      lr_min=seq_params.lr_min, 
                                      cycle_decay=seq_params.cycle_decay,
                                      warmup_t=seq_params.warmup_epochs,
                                      warmup_lr_init=seq_params.lr_min)
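        # Note: the Trainer in Part 6 reads max_epochs when it is constructed, before this
        # method runs, so this update does not extend training beyond the initial seq_params.epochs
        seq_params.epochs = scheduler.get_cycle_length() + seq_params.cooldown_epochs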
        return [opt], [{
            "scheduler": scheduler,
            "interval": "epoch"
        }]

    def lr_scheduler_step(self, scheduler, metric):
        scheduler.step(epoch=self.current_epoch)

    # self.log logs to tensorboard
    def training_step(self, batch, _):
        x, y = batch
        preds = self(x)
        # y already holds Mixup/CutMix soft targets with label smoothing applied by
        # FastCollateMixup (Part 5), so no extra smoothing is applied here
        loss = F.cross_entropy(preds, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, _):
        x, y = batch
        logits = self(x)
        val_loss = F.cross_entropy(logits, y)
        self.log("val_loss", val_loss, prog_bar=True)
        self.valid_acc(logits, y)
        self.log('valid_acc', self.valid_acc, prog_bar=True, on_step=True, on_epoch=True)
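
As a worked example of the layer sizes, the sketch below counts the parameters of one `BiLSTM2D` layer from the parameter shapes of `torch.nn.LSTM` (with `bias=True`, each direction stores `weight_ih`, `weight_hh`, `bias_ih` and `bias_hh`) and of the final `Linear`. The helper names are hypothetical; the result should match `sum(p.numel() for p in BiLSTM2D(192, 48).parameters())`, although that cross-check is not run here.

```python
def lstm_params(input_size, hidden_size, bidirectional=True):
    # Per direction: weight_ih (4h x in) + weight_hh (4h x h) + bias_ih (4h) + bias_hh (4h)
    per_direction = 4 * hidden_size * (input_size + hidden_size + 2)
    return per_direction * (2 if bidirectional else 1)

def bilstm2d_params(input_size, hidden_size):
    # Vertical and horizontal BiLSTMs plus the fusing Linear(4h -> input_size) with bias
    fc = 4 * hidden_size * input_size + input_size
    return 2 * lstm_params(input_size, hidden_size) + fc

print(bilstm2d_params(192, 48))  # first stage of Sequencer2D-S: 222912
print(bilstm2d_params(384, 96))  # later stages: 888192
```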

### Part 5: Data loading
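
The training loader in this part uses `FastCollateMixup`, so the labels reaching `training_step` are soft probability vectors rather than class indices. The plain-Python sketch below shows how such a target is built for Mixup: a label-smoothed one-hot vector for each sample is blended, with weight `lam`, with the one of the sample at the mirrored position in the batch. As far as I know this mirrors timm's `mixup_target`, but it is a simplified illustration: CutMix boxes and the random choice between Mixup and CutMix are left out, and the helper names are hypothetical.

```python
def one_hot_smoothed(label, num_classes, smoothing):
    # Off value = smoothing / K, on value = 1 - smoothing + off, so each row sums to 1
    off = smoothing / num_classes
    on = 1.0 - smoothing + off
    return [on if k == label else off for k in range(num_classes)]

def mixup_targets(labels, num_classes, lam, smoothing):
    """Soft targets for a batch mixed with its reversed copy (sample i with sample B-1-i)."""
    mixed = []
    for a, b in zip(labels, reversed(labels)):
        ya = one_hot_smoothed(a, num_classes, smoothing)
        yb = one_hot_smoothed(b, num_classes, smoothing)
        mixed.append([lam * p + (1.0 - lam) * q for p, q in zip(ya, yb)])
    return mixed

targets = mixup_targets([0, 1, 2, 3], num_classes=10, lam=0.7, smoothing=0.1)
print([round(v, 3) for v in targets[0]])
# [0.64, 0.01, 0.01, 0.28, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]
```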

In [7]:
from timm.data import FastCollateMixup

collate = FastCollateMixup(mixup_alpha=seq_params.mixup_alpha, cutmix_alpha=seq_params.cutmix_alpha,
    label_smoothing=seq_params.label_smoothing, num_classes=seq_params.num_classes)

train_loader = create_loader(
        train_dataset.dataset if seq_params.dataset_name == "cifar10" else train_dataset,
        input_size=seq_params.input_size,
        batch_size=seq_params.batch_size,
        is_training=True,
        re_prob=seq_params.random_erasing,
        re_mode="pixel",
        scale=[0.08, 1.0],
        ratio=[0.75, 1.33],
        auto_augment=seq_params.auto_augment,
        interpolation="random",
        num_workers=8,
        collate_fn=collate,
        pin_memory=True
    )

val_loader = create_loader(
        val_dataset.dataset if seq_params.dataset_name == "cifar10" else val_dataset,
        input_size=seq_params.input_size,
        batch_size=seq_params.batch_size,
        is_training=False,
        interpolation="bicubic",
        crop_pct=seq_params.crop_pct,
        num_workers=8,
        pin_memory=True
    )
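
For validation, `crop_pct` controls how much of the resized image is kept: the image is first resized so that its shorter side is `floor(img_size / crop_pct)` and then center-cropped to `img_size`. The sketch below only reproduces this arithmetic; the helper is hypothetical, and the exact rounding may differ between timm versions.

```python
import math

def eval_resize_size(img_size, crop_pct):
    # Shorter side is resized to floor(img_size / crop_pct) before the center crop
    return math.floor(img_size / crop_pct)

for img_size in (224, 64, 32):
    print(img_size, eval_resize_size(img_size, 0.875))
# 224 256
# 64 73
# 32 36
```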

### Part 6: Training
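
`configure_optimizers` (Part 4) combines AdamW with timm's `CosineLRScheduler`, stepped once per epoch. Assuming the scheduler's defaults for the options not set there (a single cycle, no warmup prefix), the per-epoch learning rate can be approximated by the plain-Python function below, which uses the values from Part 2. With a single cycle, `cycle_decay` has no effect. The function is an illustrative re-derivation, not timm's code.

```python
import math

def cosine_lr(epoch, base_lr=2e-3, lr_min=1e-6, t_initial=300, warmup_t=20, warmup_lr_init=1e-6):
    """Approximate per-epoch LR of a cosine schedule with linear warmup and no restarts."""
    if epoch < warmup_t:
        # Linear warmup from warmup_lr_init towards base_lr
        return warmup_lr_init + epoch * (base_lr - warmup_lr_init) / warmup_t
    if epoch >= t_initial:
        # After the single cycle the LR stays at lr_min
        return lr_min
    # Cosine decay over the whole cycle; the warmup epochs count towards t_initial
    return lr_min + 0.5 * (base_lr - lr_min) * (1 + math.cos(math.pi * epoch / t_initial))

for epoch in (0, 10, 20, 150, 299):
    print(epoch, cosine_lr(epoch))
```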

In [None]:
seq_model = Sequencer2D()

# bf16-mixed precision speeds up training on my local machine (RTX 3070)
trainer = pylight.Trainer(precision="bf16-mixed", max_epochs=seq_params.epochs)
trainer.fit(seq_model, train_loader, val_loader)

### Part 7: Results

The authors of the paper used the ImageNet-1K dataset, which contains 1,281,167 training images, and trained at a resolution of 224x224 pixels. However, I could not train the model on such a large dataset with my local machine. \
I attempted to reduce the model size by creating an "XS" version with fewer layers, but a full ImageNet-1K training run would still have taken more than 24 hours. Therefore, I decided to use two smaller datasets for this project: CIFAR10 and Tiny-Imagenet. \
CIFAR10 has 60,000 images of 32x32 pixels in 10 classes, with 6,000 images per class. Tiny-Imagenet has 100,000 training images of 64x64 pixels in 200 classes, with 500 images per class. \
Unfortunately, training Sequencer on such low-resolution data probably does not fully exploit the LSTMs' ability to capture long-range dependencies, which could explain why I only reached a validation accuracy of about 60% on both datasets. \
Below are the TensorBoard validation curves of both runs after 300 epochs:

CIFAR10:

| ![CIFAR10 validation accuracy](https://i.imgur.com/Ho90Shd.png)| ![CIFAR10 validation loss](https://i.imgur.com/LcDK7Fk.png) |
|:--:| :--: |
| *CIFAR10 validation accuracy* | *CIFAR10 validation loss* |

Tiny-Imagenet:

| ![Tiny-Imagenet validation accuracy](https://i.imgur.com/VxpnNct.png)| ![Tiny-Imagenet validation loss](https://i.imgur.com/EV7mYys.png) |
|:--:| :--: |
| *Tiny-Imagenet validation accuracy* | *Tiny-Imagenet validation loss* |

I could probably reach a higher accuracy by tuning the hyperparameters further (I mostly used the values reported in the paper), but this proved difficult, since each training run took hours to complete.