# Residual Attention Network for Image Classification 


After starting the runtime, you must first load two files from the repository. You can upload them directly from your local device or from Google Drive.
The files are stored in the `layers` folder of [our repo](https://github.com/AlKun25/ResAttNet):
 - [attention_module.py](https://github.com/AlKun25/ResAttNet/blob/main/layers/attention_module.py)
 - [basic_layers.py](https://github.com/AlKun25/ResAttNet/blob/main/layers/basic_layers.py)

Note: Because these two files import from each other, you may get an _import error_. If that happens, adjust the import statements in the files so they match where you placed them.
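Most import errors here come down to Python not finding the uploaded files. The sketch below is a hypothetical helper, not part of the repository: `missing_layer_files` and `make_importable` are illustrative names. It checks that both files are present in a folder and puts that folder on `sys.path`. It assumes you uploaded the files either to the working directory or to a `layers` folder.

```python
import sys
from pathlib import Path

# File names taken from the list above; adjust if you keep them elsewhere.
REQUIRED_FILES = ("attention_module.py", "basic_layers.py")


def missing_layer_files(directory="."):
    """Return the required layer files that are not present in `directory`."""
    folder = Path(directory)
    return [name for name in REQUIRED_FILES if not (folder / name).is_file()]


def make_importable(directory="."):
    """Prepend `directory` to sys.path so `import attention_module` can resolve it."""
    path = str(Path(directory).resolve())
    if path not in sys.path:
        sys.path.insert(0, path)
    return path


if __name__ == "__main__":
    # Assumption: the files were uploaded either to ./layers or to the working directory.
    folder = "layers" if Path("layers").is_dir() else "."
    missing = missing_layer_files(folder)
    if missing:
        print("Upload these files first:", ", ".join(missing))
    else:
        print("Added to sys.path:", make_importable(folder))
```

Putting the folder on `sys.path` lets both `attention_module.py` and `basic_layers.py` import each other by plain module name, which is one way to avoid editing their import lines.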

In [None]:
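!pip install pytorch-lightning==1.1.6  # version used for the run below; later releases renamed or removed several APIs this notebook relies on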

Collecting pytorch-lightning
  Downloading https://files.pythonhosted.org/packages/c3/3d/fffaf4f83633552249a40d3d366f460f44539bce0592c568d8ee20d782fa/pytorch_lightning-1.1.6-py3-none-any.whl (687kB)
Collecting fsspec[http]>=0.8.1
  Downloading https://files.pythonhosted.org/packages/ec/80/72ac0982cc833945fada4b76c52f0f65435ba4d53bc9317d1c70b5f7e7d5/fsspec-0.8.5-py3-none-any.whl (98kB)
Collecting future>=0.17.1
  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
Collecting PyYAML!=5.4.*,>=5.1
  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from torchvision.datasets import CIFAR10 # you can change this to CIFAR100, but remember to make all the corresponding changes in the DataModule
from torch.utils.data import DataLoader
from torchvision import transforms
from basic_layers import ResidualBlock  # residual unit used throughout the model
from attention_module import *  # AttentionModule_stage1, AttentionModule_stage2, AttentionModule_stage3

In [None]:
# Defining a DataModule for smooth use of the dataset in the pytorch_lightning framework

class CIFARDataModule(pl.LightningDataModule):
  '''
  You can use 2 different datasets in this DataModule: CIFAR10 and CIFAR100.
  To use a different dataset, change the import as well as every CIFAR10(...) call in the methods below.
  If you change the dataset, make sure to change the value of n_classes accordingly.
  '''
  def __init__(self, image_size=224):
      super().__init__()
      # this depends on the model you use. It can be 224 or 32
      self.image_size = image_size

  def prepare_data(self):
      # download the CIFAR-10 train and test splits to ./data/ (Lightning runs this once per node)
      CIFAR10(root='./data/', train=True, download=True)
      CIFAR10(root='./data/', train=False, download=True)

  def train_dataloader(self):
      train_transform = transforms.Compose([  transforms.RandomHorizontalFlip(),
                                              transforms.RandomCrop((32, 32), padding=4),   # pad 4 px on every side, then crop a random 32x32 patch
                                              transforms.Resize(self.image_size),
                                              transforms.ToTensor()
                                          ])
      cifar_train = CIFAR10(root='./data/', train=True, download=False, transform=train_transform)
      cifar_train = DataLoader( dataset=cifar_train, 
                                  batch_size=32, 
                                  shuffle=True, 
                                  num_workers=4
                              )
      return cifar_train

  def test_dataloader(self):
      test_transform = transforms.Compose([transforms.Resize(self.image_size),
                                           transforms.ToTensor()])
      cifar_test = CIFAR10(root='./data/', train=False,download=False,transform=test_transform)
      cifar_test = DataLoader( dataset=cifar_test, 
                                  batch_size=20, 
                                  shuffle=False
                              )
      return cifar_test
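The docstring above ties the dataset choice to `n_classes`. As a hedged sketch, a small lookup table keeps the two in sync; `DATASETS` and `n_classes_for` are hypothetical names, not part of the repository. The class counts and split sizes are the standard ones for the torchvision CIFAR datasets.

```python
# Hypothetical lookup table (not part of the repository): facts about the two
# torchvision datasets this DataModule supports.
DATASETS = {
    "CIFAR10": {"n_classes": 10, "train_images": 50_000, "test_images": 10_000},
    "CIFAR100": {"n_classes": 100, "train_images": 50_000, "test_images": 10_000},
}


def n_classes_for(dataset_name):
    """Value to pass as ResidualAttentionModel(n_classes=...) for a given dataset."""
    try:
        return DATASETS[dataset_name]["n_classes"]
    except KeyError:
        raise ValueError(
            f"unsupported dataset {dataset_name!r}; expected one of {sorted(DATASETS)}"
        ) from None


print(n_classes_for("CIFAR100"))
```

To switch the DataModule itself, `getattr(torchvision.datasets, name)` returns the matching dataset class (`CIFAR10` or `CIFAR100`), which could replace the hard-coded `CIFAR10(...)` calls.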

### Model architecture:
The repository's [models](https://github.com/AlKun25/ResAttNet/tree/main/model) folder contains several architectures. Each one expects a specific input size:
  - [Attention-56, image_size = 224](https://github.com/AlKun25/ResAttNet/blob/main/model/RAN_56_224.py)
  - [Attention-92, image_size = 224](https://github.com/AlKun25/ResAttNet/blob/main/model/RAN_92_224.py)
  - [Attention-92, image_size = 32](https://github.com/AlKun25/ResAttNet/blob/main/model/RAN_92_32.py)
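Because each model file only works with one input size, a lookup keyed by file name avoids mismatches between the model and `CIFARDataModule(image_size=...)`. This is a minimal sketch: `MODEL_IMAGE_SIZE` and `image_size_for` are hypothetical names, and the values are copied from the list above.

```python
# image_size expected by each model file listed above (repository's model/ folder).
MODEL_IMAGE_SIZE = {
    "RAN_56_224.py": 224,
    "RAN_92_224.py": 224,
    "RAN_92_32.py": 32,
}


def image_size_for(model_file):
    """Return the value to pass as CIFARDataModule(image_size=...) for a model file."""
    if model_file not in MODEL_IMAGE_SIZE:
        raise ValueError(
            f"unknown model file {model_file!r}; expected one of {sorted(MODEL_IMAGE_SIZE)}"
        )
    return MODEL_IMAGE_SIZE[model_file]


print(image_size_for("RAN_92_32.py"))
```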

In [None]:
# Defining model and its architecture

class ResidualAttentionModel(pl.LightningModule):
    '''
    Place the model definition here (this block implements Attention-92 for 224x224 inputs).
    Replace the code in this block to use a different model.
    When you change the model, remember to change the value of image_size accordingly.
    '''
    def __init__(self, n_classes):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias = False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.residual_block1 = ResidualBlock(64, 256)
        self.attention_module1 = AttentionModule_stage1(256, 256)
        self.residual_block2 = ResidualBlock(256, 512, 2)
        self.attention_module2 = AttentionModule_stage2(512, 512)
        self.attention_module2_2 = AttentionModule_stage2(512, 512)  # tbq add
        self.residual_block3 = ResidualBlock(512, 1024, 2)
        self.attention_module3 = AttentionModule_stage3(1024, 1024)
        self.attention_module3_2 = AttentionModule_stage3(1024, 1024)  # tbq add
        self.attention_module3_3 = AttentionModule_stage3(1024, 1024)  # tbq add
        self.residual_block4 = ResidualBlock(1024, 2048, 2)
        self.residual_block5 = ResidualBlock(2048, 2048)
        self.residual_block6 = ResidualBlock(2048, 2048)
        self.mpool2 = nn.Sequential(
            nn.BatchNorm2d(2048),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=7, stride=1)
        )
        self.fc = nn.Linear(2048,n_classes) # n_classes =10(CIFAR10)/100(CIFAR100)

    def forward(self, x):
        out = self.conv1(x)
        out = self.mpool1(out)
        # print(out.data)
        out = self.residual_block1(out)
        out = self.attention_module1(out)
        out = self.residual_block2(out)
        out = self.attention_module2(out)
        out = self.attention_module2_2(out)
        out = self.residual_block3(out)
        # print(out.data)
        out = self.attention_module3(out)
        out = self.attention_module3_2(out)
        out = self.attention_module3_3(out)
        out = self.residual_block4(out)
        out = self.residual_block5(out)
        out = self.residual_block6(out)
        out = self.mpool2(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        return out

    def cross_entropy_loss(self, logits, labels):
        return F.cross_entropy(logits, labels)

    def training_step(self, train_batch, batch_idx):
        # Lightning moves the batch to the model's device, so no explicit .cuda() calls are needed
        images, labels = train_batch
        logits = self.forward(images)
        loss = self.cross_entropy_loss(logits, labels)
        self.log('train_loss', loss)
        return loss

    def test_step(self,test_batch, batch_idx):
        x, y = test_batch
        y_hat = self.forward(x)
        loss = self.cross_entropy_loss(y_hat, y)
        self.log('test_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=0.0001)
        return optimizer
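The model only works when the input size fits its strides. As a rough check, the sketch below traces the side length of the feature map through every layer that changes it, using the usual output-size formula floor((n + 2p - k) / s) + 1. Assumption: the stride-2 `ResidualBlock`s from `basic_layers.py` downsample with a 3x3 convolution and padding 1, and the attention modules preserve spatial size; verify this against your copy of the file.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Side length after a conv/pool layer: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


# (layer, kernel, stride, padding) for the layers of ResidualAttentionModel that
# change the spatial size. ASSUMPTION: the stride-2 ResidualBlocks downsample
# with a 3x3 conv and padding 1; attention modules and stride-1 blocks keep the size.
DOWNSAMPLING_LAYERS = [
    ("conv1", 7, 2, 3),
    ("mpool1", 3, 2, 1),
    ("residual_block2", 3, 2, 1),
    ("residual_block3", 3, 2, 1),
    ("residual_block4", 3, 2, 1),
    ("mpool2 (AvgPool2d)", 7, 1, 0),
]


def trace_spatial_size(image_size):
    """Return [(layer, side length)] for a square input; raise if a map gets too small."""
    trace = [("input", image_size)]
    for name, kernel, stride, padding in DOWNSAMPLING_LAYERS:
        size = conv_out(trace[-1][1], kernel, stride, padding)
        if size < 1:
            side = trace[-1][1]
            raise ValueError(f"{name}: a {side}x{side} input is too small for its {kernel}x{kernel} kernel")
        trace.append((name, size))
    return trace


for layer, size in trace_spatial_size(224):
    print(f"{layer:20s} {size}x{size}")
```

For `image_size=224` the map shrinks 224, 112, 56, 28, 14, 7 and finally 1, so `mpool2` yields a 2048x1x1 tensor that flattens to the 2048 inputs of `self.fc`. For `image_size=32` the map is already 1x1 before the 7x7 average pool and `trace_spatial_size(32)` raises, which is why 32x32 inputs need the `RAN_92_32` model instead.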

In [None]:
class LitProgressBar(pl.callbacks.ProgressBar): # Progress bar for seeing the progress during train and test phase

    def init_train_tqdm(self):
        bar = super().init_train_tqdm()
        bar.set_description('running training ...')
        bar.leave = True
        return bar

    def init_test_tqdm(self):
        bar = super().init_test_tqdm()
        bar.set_description('running testing ...')
        bar.leave = True
        return bar


In [None]:
bar = LitProgressBar(refresh_rate=50) # keep the refresh rate at 20 or more; lower values can freeze the browser tab in Colab

data_module = CIFARDataModule(image_size=224) # image_size must match the model: 224 or 32, see the list of models above
data_module.prepare_data() # downloads CIFAR10 to ./data/ on the Colab runtime
test_data = data_module.test_dataloader()
train_data = data_module.train_dataloader()

model = ResidualAttentionModel(n_classes=10) # n_classes must match the dataset: 10 for CIFAR10, 100 for CIFAR100

trainer = pl.Trainer(max_epochs=10, gpus=-1, callbacks=[bar], accelerator='ddp', progress_bar_refresh_rate=50)

trainer.fit(model, train_data)

Files already downloaded and verified
Files already downloaded and verified


GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1

   | Name                | Type                   | Params
----------------------------------------------------------------
0  | conv1               | Sequential             | 9.5 K 
1  | mpool1              | MaxPool2d              | 0     
2  | residual_block1     | ResidualBlock          | 74.1 K
3  | attention_module1   | AttentionModule_stage1 | 1.8 M 
4  | residual_block2     | ResidualBlock          | 377 K 
5  | attention_module2   | AttentionModule_stage2 | 5.4 M 
6  | attention_module2_2 | AttentionModule_stage2 | 5.4 M 
7  | residual_block3     | ResidualBlock          | 1.5 M 
8  | attention_module3   | AttentionModule_stage3 | 15.1 M
9  | attention_module3_2 | AttentionModule_stage3 | 15.1 M
10 | attention_module3_3 | AttentionModule_stage3 | 15.1 M
11 | residual_block4     | ResidualBlock          | 6.0 M 
12 | re

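The length of a training run follows from the dataset size and the batch sizes set in the DataModule. A minimal sketch, assuming a single GPU as in the log above (with DDP on several GPUs each process sees only its share of the data) and CIFAR-10's standard 50,000/10,000 train/test split; `batches_per_epoch` is a hypothetical helper mirroring how a `DataLoader` counts batches.

```python
def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch."""
    full, remainder = divmod(n_samples, batch_size)
    return full + (1 if remainder and not drop_last else 0)


TRAIN_IMAGES, TEST_IMAGES = 50_000, 10_000                     # CIFAR-10 split sizes
train_steps = batches_per_epoch(TRAIN_IMAGES, batch_size=32)  # train_dataloader
test_steps = batches_per_epoch(TEST_IMAGES, batch_size=20)    # test_dataloader
total_steps = 10 * train_steps                                # max_epochs=10

print(train_steps, test_steps, total_steps)  # 1563 500 15630
```

The last training batch holds only 16 images because 50,000 is not a multiple of 32; passing `drop_last=True` to the `DataLoader` would discard it and give 1562 steps per epoch.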

In [None]:
# Testing the trained model
trainer.test(test_dataloaders=test_data, verbose=True)

In [None]:
# Saving the trained model as checkpoint
trainer.save_checkpoint('CIFAR10_92.ckpt')

In [None]:
# Loading tensorboard
%load_ext tensorboard
%tensorboard --logdir lightning_logs/