# Implementation of Stand Alone Self Attention in Vision Models

![Stand Alone SA module](images/sasa.png)

In [1]:
from fastai.vision.all import *
from sklearn.metrics import auc, roc_curve, precision_recall_curve, classification_report
from sklearn.model_selection import StratifiedKFold
import gc

## Relative Self Attention Module

In [2]:
class RelativeSelfAttention(Module):
    def __init__(self, d_in, d_out, ks, groups, stride=1):
        self.n_c, self.ks, self.groups, self.stride = d_out, ks, groups, stride
        # linear transformation for queries, values and keys
        self.qx, self.kx, self.vx = [ConvLayer(d_in, d_out, ks=1, norm_type=None,
                                               act_cls=None) for _ in range(3)]
        # positional embeddings
        self.row_embeddings = nn.Parameter(torch.randn(d_out//2, ks))
        self.col_embeddings = nn.Parameter(torch.randn(d_out//2, ks))
        
    def calc_out_shape(self, inp_shape, pad):
        out_shape = [(sz + 2*pad - self.ks) // self.stride + 1 for sz in inp_shape]
        return out_shape
    
    def forward(self, x):
        query, keys, values = self.qx(x), self.kx(x), self.vx(x)
        
        pad = (self.ks -1) // 2
        
        # use unfold to extract the memory blocks and their associated queries
        query = F.unfold(query, kernel_size=1, stride=self.stride)
        keys = F.unfold(keys, kernel_size=self.ks, padding=pad, stride=self.stride)
        values = F.unfold(values, kernel_size=self.ks, padding=pad, stride=self.stride)
        
        
        # reshape and permute the dimensions into the appropriate format for matrix multiplication
        query = query.view(query.shape[0], self.groups, self.n_c//self.groups, -1, query.shape[-1]) # bs*G*C//G*1*N
        query = query.permute(0, 4, 1, 2, 3) # bs * N * G * C//G * 1
        keys = keys.view(keys.shape[0], self.groups, self.n_c//self.groups, -1, keys.shape[-1]) # bs*G*C//G*ks^2*N
        keys = keys.permute(0, 4, 1, 2, 3) # bs * N * G * C//G * ks^2
        values = values.view(values.shape[0], self.groups, self.n_c//self.groups, -1, values.shape[-1]) # bs*G*C//G*ks^2*N
        values = values.permute(0, 4, 1, 2, 3) # bs * N * G * C//G * ks^2
        
        # get positional embeddings
        row_embeddings = self.row_embeddings.unsqueeze(-1).expand(-1, -1, self.ks)
        col_embeddings = self.col_embeddings.unsqueeze(-2).expand(-1, self.ks, -1)
        
        embeddings = torch.cat((row_embeddings, col_embeddings)).view(self.groups,
                                self.n_c//self.groups, -1) # G * C//G * ks^2
        # add empty dimensions to match the shape of keys
        embeddings = embeddings[None, None, -1] # 1 * 1 * G * C//G * ks^2
        
        # compute attention map
        att_map = F.softmax(torch.matmul(query.transpose(-2,-1), keys+embeddings).contiguous(), dim=-1)
        # compute final output
        out = torch.matmul(att_map, values.transpose(-2,-1)).contiguous().permute(0, 2, 3, 4, 1)
        
        return out.view(out.shape[0], self.n_c, *self.calc_out_shape(x.shape[-2:], pad)).contiguous()

## Define helper functions and modules

In [3]:
def resnet_stem(*sizes):
    return [
        ConvLayer(sizes[i], sizes[i+1], stride=2 if i==0 else 1)
         for i in range(len(sizes) - 1)
    ] + [nn.MaxPool2d(kernel_size=3, stride=2, padding=1)]

In [4]:
def bottleneck(ni, nf, stride):
    if stride==1:
        layers = [ConvLayer(ni, nf//4, ks=1),
              RelativeSelfAttention(nf//4, nf//4, ks=7, groups=8),
              ConvLayer(nf//4, nf, ks=1, act_cls=None, norm_type=NormType.BatchZero)]
    else:
        layers = [ConvLayer(ni, nf//4, ks=1),
              RelativeSelfAttention(nf//4, nf//4, ks=7, groups=8),
              nn.AvgPool2d(2, ceil_mode=True),
              ConvLayer(nf//4, nf, ks=1, act_cls=None, norm_type=NormType.BatchZero)]
    
    return nn.Sequential(*layers)

In [5]:
class ResNetBlock(Module):
    def __init__(self, ni, nf, stride, sa, expansion=1):
        self.botl = bottleneck(ni, nf, stride)
        self.idconv = noop if ni==nf else ConvLayer(ni, nf, 1, act_cls=None)
        self.pool = noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)
        
    def forward(self, x):
        return F.relu(self.botl(x) + self.idconv(self.pool(x)))

## Define Randomization Layer to Protect Against Adversarial Attacks

In [6]:
class RandHead(Module):
    def __init__(self):
        pass
        
    def forward(self, x):
        rz_shape = torch.randint(210, 224, (1,))
        pad = (224 - rz_shape).item()
        h = torch.randint(0, pad+1, (1,))
        w = torch.randint(0, pad+1, (1,))
        
        # step 1 random resize
        out = F.interpolate(x, [rz_shape]*2)
        # step 2 pad
        return F.pad(out, (h, pad-h, w, pad-w))
        

## Define ResNet module

In [7]:
class xResNet(nn.Sequential):
    def __init__(self, channels, n_out, blocks, sa=True, expansion=1):
        stem = resnet_stem(channels, 32, 32, 64)
        self.group_sizes = [64, 64, 128, 256, 512]
        for i in range(1, len(self.group_sizes)): 
            self.group_sizes[i] *= expansion
        groups = [self._make_group(idx, n_blocks, sa=sa if idx==0 else False) 
                      for idx, n_blocks in enumerate(blocks)]
        
        super().__init__(RandHead(), *stem, *groups,
                         nn.AdaptiveAvgPool2d(1), Flatten(),
                         nn.Linear(self.group_sizes[-1], n_out))
        
    def _make_group(self, idx, n_blocks, sa):
        stride = 1 if idx==1 else 2
        ni, nf = self.group_sizes[idx], self.group_sizes[idx+1]
        return nn.Sequential(*[
            ResNetBlock(ni if i==0 else nf, nf, stride=stride if i==0 else 1,
                        sa=sa if i==n_blocks-1 else False)
             for i in range(n_blocks)
        ])

## Get the data

In [8]:
path = Path('data/')
Path.BASE_PATH = path
path.ls()

(#2) [Path('.ipynb_checkpoints'),Path('cell_images')]

In [9]:
imgs = get_image_files(path/'cell_images/cell_images')

In [10]:
def get_y(o):
    return [o.parent.name]

In [11]:
idxs = []
lbls = []

for img in imgs:
    idxs.append(img.name)
    lbls.append(get_y(img)[0])

In [12]:
def get_dls(size, bs, valid_idx):
    dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
       get_items=get_image_files,
       splitter=IndexSplitter(valid_idx),
       get_y=get_y,
       item_tfms=Resize(size),
       batch_tfms=[*aug_transforms(flip_vert=True, max_zoom=1.2, max_warp=0), Normalize()])
    
    return dblock.dataloaders(path/'cell_images/cell_images', bs=bs)

# Learner

In [13]:
preds = []
y_true = []

In [14]:
skf = StratifiedKFold(n_splits=5, shuffle=True)

for _, valid_idx in skf.split(idxs, lbls):
    dls = get_dls(200, 64, valid_idx)
    model = xResNet(3, dls.c, [2,2,2,2])
    learn= Learner(dls, model, metrics=partial(accuracy_multi, thresh=0.5))
    
    learn.fit_one_cycle(20, 1e-3)
    
    probs, y = learn.get_preds()
    preds.append(probs)
    y_true.append(y)
    
    del learn
    torch.cuda.empty_cache()
    gc.collect()

epoch,train_loss,valid_loss,accuracy_multi,time
0,0.531832,0.464954,0.840893,02:35
1,0.206593,0.169722,0.952286,02:26
2,0.182522,0.158924,0.947388,02:26
3,0.180504,0.133254,0.955733,02:26
4,0.163738,0.130562,0.953284,02:26
5,0.168682,0.170936,0.950835,02:26
6,0.143835,0.136048,0.95791,02:26
7,0.150238,0.141189,0.94902,02:26
8,0.142057,0.126447,0.952014,02:26
9,0.140169,0.116269,0.960268,02:26


epoch,train_loss,valid_loss,accuracy_multi,time
0,0.563511,0.512608,0.775943,02:26
1,0.202331,0.140095,0.953465,02:26
2,0.187205,0.132805,0.955098,02:26
3,0.180078,0.135368,0.956096,02:26
4,0.176826,0.1324,0.958364,02:26
5,0.163576,0.129507,0.959543,02:26
6,0.163538,0.118817,0.956731,02:26
7,0.156214,0.121968,0.955189,02:26
8,0.14601,0.118234,0.960994,02:26
9,0.133884,0.108878,0.962264,02:26


epoch,train_loss,valid_loss,accuracy_multi,time
0,0.502862,0.419112,0.86602,02:26
1,0.196545,0.15568,0.952921,02:26
2,0.175569,0.143203,0.953647,02:26
3,0.162562,0.139553,0.957366,02:26
4,0.175195,0.13335,0.957003,02:26
5,0.153793,0.135653,0.956005,02:26
6,0.14941,0.12398,0.958545,02:26
7,0.143443,0.110695,0.962718,02:26
8,0.141468,0.142956,0.950653,02:26
9,0.144461,0.103184,0.965348,02:26


epoch,train_loss,valid_loss,accuracy_multi,time
0,0.509333,0.441438,0.840773,02:26
1,0.196394,0.152362,0.947559,02:26
2,0.186037,0.190033,0.936763,02:26
3,0.172928,0.148783,0.950372,02:26
4,0.164753,0.157867,0.944656,02:26
5,0.142089,0.136086,0.954727,02:26
6,0.151851,0.126226,0.955362,02:26
7,0.156823,0.132097,0.956904,02:27
8,0.128396,0.116626,0.957721,02:26
9,0.142917,0.145687,0.949465,02:26


epoch,train_loss,valid_loss,accuracy_multi,time
0,0.503224,0.435323,0.849664,02:26
1,0.209914,0.16685,0.945836,02:26
2,0.187666,0.154499,0.946199,02:26
3,0.172909,0.141712,0.950553,02:26
4,0.166059,0.155283,0.947197,02:26
5,0.159721,0.126397,0.956451,02:26
6,0.1502,0.122845,0.956451,02:26
7,0.162586,0.117667,0.9589,02:26
8,0.148136,0.115064,0.961531,02:26
9,0.136763,0.115524,0.961894,02:26


In [15]:
dls = get_dls(200, 64, valid_idx)

In [16]:
import pickle

In [17]:
with open('record.pkl', 'wb') as f:
    pickle.dump([preds, y_true], f)

## Get metrics for each fold

Fold 1

In [18]:
ps = np.argmax(preds[0].numpy(), axis=-1)
y = np.argmax(y_true[0].numpy(), axis=-1)
y.shape, ps.shape

((5512,), (5512,))

In [20]:
classes = list(dls.vocab)
report = classification_report(y, ps, target_names=classes)
print(report)

              precision    recall  f1-score   support

 Parasitized       0.98      0.96      0.97      2756
  Uninfected       0.96      0.98      0.97      2756

    accuracy                           0.97      5512
   macro avg       0.97      0.97      0.97      5512
weighted avg       0.97      0.97      0.97      5512



Fold 2

In [21]:
ps = np.argmax(preds[1].numpy(), axis=-1)
y = np.argmax(y_true[1].numpy(), axis=-1)
y.shape, ps.shape

((5512,), (5512,))

In [22]:
classes = list(dls.vocab)
report = classification_report(y, ps, target_names=classes)
print(report)

              precision    recall  f1-score   support

 Parasitized       0.98      0.96      0.97      2756
  Uninfected       0.96      0.98      0.97      2756

    accuracy                           0.97      5512
   macro avg       0.97      0.97      0.97      5512
weighted avg       0.97      0.97      0.97      5512



**Fold 3**

In [23]:
ps = np.argmax(preds[2].numpy(), axis=-1)
y = np.argmax(y_true[2].numpy(), axis=-1)
y.shape, ps.shape

((5512,), (5512,))

In [24]:
classes = list(dls.vocab)
report = classification_report(y, ps, target_names=classes)
print(report)

              precision    recall  f1-score   support

 Parasitized       0.98      0.95      0.97      2756
  Uninfected       0.95      0.98      0.97      2756

    accuracy                           0.97      5512
   macro avg       0.97      0.97      0.97      5512
weighted avg       0.97      0.97      0.97      5512



**Fold 4**

In [25]:
ps = np.argmax(preds[3].numpy(), axis=-1)
y = np.argmax(y_true[3].numpy(), axis=-1)
y.shape, ps.shape

((5511,), (5511,))

In [26]:
classes = list(dls.vocab)
report = classification_report(y, ps, target_names=classes)
print(report)

              precision    recall  f1-score   support

 Parasitized       0.98      0.95      0.96      2755
  Uninfected       0.95      0.98      0.96      2756

    accuracy                           0.96      5511
   macro avg       0.96      0.96      0.96      5511
weighted avg       0.96      0.96      0.96      5511



**Fold 5**

In [28]:
ps = np.argmax(preds[4].numpy(), axis=-1)
y = np.argmax(y_true[4].numpy(), axis=-1)
y.shape, ps.shape

((5511,), (5511,))

In [29]:
classes = list(dls.vocab)
report = classification_report(y, ps, target_names=classes)
print(report)

              precision    recall  f1-score   support

 Parasitized       0.98      0.96      0.97      2756
  Uninfected       0.96      0.98      0.97      2755

    accuracy                           0.97      5511
   macro avg       0.97      0.97      0.97      5511
weighted avg       0.97      0.97      0.97      5511

