In [1]:
!pip install memorywrap
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from memorywrap import MemoryWrapLayer, BaselineMemory
seed = 0




This notebook shows how to add a Memory Wrap layer to a given architecture.

# Model

Let's start from the definition of the starting model. For example here we have the implementation of MobileNet-v2:

In [2]:
class Block(nn.Module):
    '''expand + depthwise + pointwise'''
    def __init__(self, in_planes, out_planes, expansion, stride):
        super(Block, self).__init__()
        self.stride = stride

        planes = expansion * in_planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out = out + self.shortcut(x) if self.stride==1 else out
        return out
    
class MobileNetV2(nn.Module):
    # (expansion, out_planes, num_blocks, stride)
    cfg = [(1,  16, 1, 1),
           (6,  24, 2, 1), 
           (6,  32, 3, 2),
           (6,  64, 4, 2),
           (6,  96, 3, 1),
           (6, 160, 3, 2),
           (6, 320, 1, 1)]

    def __init__(self, num_classes=10):
        super(MobileNetV2, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_planes=32)
        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(1280)
        self.linear = nn.Linear(1280, num_classes)

    def _make_layers(self, in_planes):
        layers = []
        for expansion, out_planes, num_blocks, stride in self.cfg:
            strides = [stride] + [1]*(num_blocks-1)
            for stride in strides:
                layers.append(Block(in_planes, out_planes, expansion, stride))
                in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layers(out)
        out = F.relu(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In order to add the Memory Wrap to this architecture we have to:
* replace the last layer with a Memory Wrap layer;
* modify the forward function.

#### Replace the last layer
The first edit is straightforward. In this case we have to replace the line of code <br>
```self.linear = nn.Linear(1280, num_classes)``` <br>
with <br>
```self.mw = MemoryWrapLayer(1280,num_classes)``` <br>
where the parameters are the same: the first one is the number of dimension in input, and the second one is the output's dimension.
If you want to use the baseline that uses only the memory to compute the output, please replace MemoryWrapLayer with BaselineMemory.

#### Modify the forward function 

In the second step, first of all, we have to remove the call to self.linear in the forward function, being removed from the code, and rename the forward function. For clearness we rename it in forward_encoder, highlighting the role of this function for our architecture. <br> Therefore, the forward function becomes:
```python
def forward_encoder(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layers(out)
        out = F.relu(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        
        #here we removed the call to self.linear
        
        return out
```

And then we can make our new forward function that includes our new Memory Wrap layer.
```python
    def forward(self, x, ss, return_weights=False):

        # inputs
        out = self.forward_encoder(x)
        out_ss = self.forward_encoder(ss)

        # prediction
        out_mw = self.mw(out,out_ss,return_weights)
        return out_mw
```

What the forward function does is encoding both the input and the memory set and pass them to the Memory Wrap layer. The last argument of the Memory Wrap's call function is a boolean flag controlling the number of outputs returned. If the flag is True, then the layer returns both the output and the sparse attention weight associated to each memory sample; if the flag is False, then the layer return only the output.

Merging all the modification together we can create a new class called MemoryWrapMobileNet.

In [3]:
class MemoryWrapMobileNetV2(nn.Module):
    # (expansion, out_planes, num_blocks, stride)
    cfg = [(1,  16, 1, 1),
           (6,  24, 2, 1),  # NOTE: change stride 2 -> 1 for CIFAR10
           (6,  32, 3, 2),
           (6,  64, 4, 2),
           (6,  96, 3, 1),
           (6, 160, 3, 2),
           (6, 320, 1, 1)]

    def __init__(self, num_classes=10):
        super(MemoryWrapMobileNetV2, self).__init__()
        # NOTE: change conv1 stride 2 -> 1 for CIFAR10
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_planes=32)
        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(1280)
        
        #replaced last layer
        #self.linear = nn.Linear(1280, num_classes)
        self.mw = MemoryWrapLayer(1280,num_classes)

    def _make_layers(self, in_planes):
        layers = []
        for expansion, out_planes, num_blocks, stride in self.cfg:
            strides = [stride] + [1]*(num_blocks-1)
            for stride in strides:
                layers.append(Block(in_planes, out_planes, expansion, stride))
                in_planes = out_planes
        return nn.Sequential(*layers)

    def forward_encoder(self, x):

        #input
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layers(out)
        out = F.relu(self.bn2(self.conv2(out)))
        # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)

        return out

    def forward(self, x, memory_set, return_weights=False):

        #input
        out = self.forward_encoder(x)
        out_ms = self.forward_encoder(memory_set)

        # prediction
        out_mw = self.mw(out,out_ms,return_weights)
        return out_mw

# Dataset

Download the SVHN dataset and randomly extract 2000 samples that will be our training dataset

In [4]:
# CONFIG dataset
len_dataset = 2000 # size training dataset
data_dir = 'datasets/' #directory where dataset is stored

In [5]:
normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        normalize,
])
train_data = torchvision.datasets.SVHN(data_dir, split='train', download=True, transform=transforms)
test_data =  torchvision.datasets.SVHN(data_dir, split='test', download=True, transform=transforms)
train_dataset, _ = torch.utils.data.random_split(train_data,[len_dataset,len(train_data)-len_dataset], generator=torch.Generator().manual_seed(seed))
print("In the training dataset there are {} samples.".format(len(train_dataset)))

Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to datasets/train_32x32.mat


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=182040794.0), HTML(value='')))


Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to datasets/test_32x32.mat


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64275384.0), HTML(value='')))


In the training dataset there are 2000 samples.


Prepare dataloaders for training and testing. Note that both the loaders for the training set and for the memory set share the same dataset.

In [6]:
# Training config
batch_size_train = 128
samples_in_memory = 100
batch_size_test = 128

In [7]:
train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size_train, shuffle=True, drop_last=True)
mem_loader = torch.utils.data.DataLoader(train_dataset, batch_size=samples_in_memory, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size_test, shuffle=False,drop_last=True)


## Training

The following code prepare two models for comparison: the standard MobileNet and the variant with Memory Wrap.

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
# training parameters
num_epochs = 40
dict_optim = {'lr' :1e-1, 'momentum':0.9, 'weight_decay':5e-4, 'nesterov':True}


std_model = MobileNetV2(10)
std_model = std_model.to(device)
std_optimizer = torch.optim.SGD(std_model.parameters(),**dict_optim)
std_scheduler = torch.optim.lr_scheduler.MultiStepLR(std_optimizer,  milestones=[20,30])

mw_model = MemoryWrapMobileNetV2(10)
mw_model = mw_model.to(device)
mw_optimizer = torch.optim.SGD(mw_model.parameters(),**dict_optim)
mw_scheduler = torch.optim.lr_scheduler.MultiStepLR(mw_optimizer,  milestones=[20,30])


loss_criterion = torch.nn.CrossEntropyLoss()




In [10]:
mw_model.train()
std_model.train()
for epoch in range(1, num_epochs + 1):
    for batch_idx, (data, y) in enumerate(train_loader):

        std_optimizer.zero_grad()
        mw_optimizer.zero_grad()
        
        # input
        data = data.to(device)
        y = y.to(device)
        
        # here we randomly extract a new memory set for the given batch
        memory_input, _ = next(iter(mem_loader))
        memory_input = memory_input.to(device)
        
        # note that the memory model takes both the input and the memory set
        mw_outputs  = mw_model(data,memory_input)
        mw_loss = loss_criterion(mw_outputs, y)
        
        std_outputs  = std_model(data)
        std_loss = loss_criterion(std_outputs, y)
        
        std_loss.backward()
        mw_loss.backward()
        
        
        std_optimizer.step()
        mw_optimizer.step()
        
        #log stuff
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [({:.0f}%({})]\t'.format(
            epoch,
            100. * batch_idx / len(train_loader), len(train_loader.dataset)),end='\r')

    std_scheduler.step()# increase scheduler step for each epoch
    mw_scheduler.step()

Train Epoch: 40 [(0%(2000)]	

Now we are ready to evaluate our two models, comparing their accuracy:

In [11]:
std_model.eval()
mw_model.eval()
std_correct = 0
mw_correct = 0
with torch.no_grad():
    for _, (data, target) in enumerate(test_loader):
        data = data.to(device)
        target = target.to(device)
        memory, _ = next(iter(mem_loader))
        memory = memory.to(device)

        mw_outputs  = mw_model(data,memory_input)
        mw_loss = loss_criterion(mw_outputs, y)
        mw_pred = mw_outputs.data.max(1, keepdim=True)[1]
        
        std_outputs  = std_model(data)
        std_loss = loss_criterion(std_outputs, y)
        std_pred = std_outputs.data.max(1, keepdim=True)[1]

        std_correct += std_pred.eq(target.data.view_as(std_pred)).sum().item()
        mw_correct += mw_pred.eq(target.data.view_as(mw_pred)).sum().item()

std_accuracy = 100.*(torch.true_divide(std_correct,len(test_loader.dataset))).item()
mw_accuracy = 100.*(torch.true_divide(mw_correct,len(test_loader.dataset))).item()


In [14]:
print("The standard model correctly classify {:.2f}% of the images".format(std_accuracy))
print("The model with Memory Wrap correctly classify {:.2f}% of the images".format(mw_accuracy))

The standard model correctly classify 70.24% of the images
The model with Memory Wrap correctly classify 77.95% of the images
