In [1]:
import os, sys
# Directories that must be importable: the parent of the current working
# directory and its `modules/Sparse` sub-directory.
project_dir = os.path.join(os.getcwd(), '..')
sparse_dir = os.path.join(project_dir, 'modules/Sparse')

# The list of missing entries is built first and then appended in place, so
# re-running this cell never registers the same directory twice.
sys.path += [path for path in (project_dir, sparse_dir) if path not in sys.path]

import torch
from torch import nn
import numpy as np

In [2]:
from AttentionMap.AttentionMap import LinearAttentionMap, GridAttentionBlock
from AttentionMap.Projector import ProjectorBlock

class Model(nn.Module):
    """Three-stage CNN classifier with optional attention heads.

    Three convolutional stages (``l1``, ``l2``, ``l3``), each ending in
    ``MaxPool2d(2)``, produce feature maps with 64, 128 and 256 channels.
    ``g_descriptor`` then applies a 256-channel convolution whose kernel spans
    ``input_size / 8`` pixels, i.e. the spatial extent that is left after the
    three poolings of an ``input_size`` x ``input_size`` image.

    With ``attention=True`` the descriptor goes through ``ProjectorBlock(256, 64)``
    and is handed, together with the (projected) features of every stage, to a
    ``LinearAttentionMap``; the three descriptors returned by those modules are
    concatenated and classified. With ``attention=False`` the flattened
    256-channel descriptor is classified directly.

    Args:
        input_size: Side length of the input images. Inputs are assumed to be
            square with exactly this side so that the descriptor is 1x1; the
            non-attention head relies on that.
        in_channels: Number of channels of the input images.
        n_classes: Number of output classes (size of the logits vector).
        attention: Whether to build and use the attention heads.
        normalize_attn: Forwarded as ``normalize`` to every ``LinearAttentionMap``.
    """

    def __init__(self, input_size, in_channels, n_classes, attention=True, normalize_attn=True) -> None:
        super(Model, self).__init__()
        self.l1 = nn.Sequential(*[
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.SELU(),
            nn.MaxPool2d(2),
        ])

        self.l2 = nn.Sequential(*[
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.SELU(),
            nn.MaxPool2d(2),
        ])

        self.l3 = nn.Sequential(*[
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.SELU(),
            nn.MaxPool2d(2),
        ])

        # Un-padded convolution covering the whole l3 feature map: three
        # MaxPool2d(2) stages shrink the input side by a factor of 2**3.
        self.g_descriptor = nn.Sequential(*[
            nn.Conv2d(256, 256, kernel_size=int(input_size / 2**3), padding=0)
        ])

        self.attention = attention
        if attention:
            # l1 already yields 64 channels, so it is fed to its attention
            # module without a projector.
            self.attn_l1 = LinearAttentionMap(64, normalize=normalize_attn)

            self.proj_l2 = ProjectorBlock(128, 64)
            self.attn_l2 = LinearAttentionMap(64, normalize=normalize_attn)

            self.proj_l3 = ProjectorBlock(256, 64)
            self.attn_l3 = LinearAttentionMap(64, normalize=normalize_attn)

            self.proj_g = ProjectorBlock(256, 64)

            # Three attention descriptors are concatenated before classifying.
            self.classify = nn.Linear(64 * 3, n_classes, bias=True)
        else:
            # Without attention the classifier sees the raw 256-channel
            # descriptor produced by g_descriptor (flattened in forward()).
            self.classify = nn.Linear(256, n_classes, bias=True)

        # Re-initialise the weights of every Linear / Conv2d sub-module,
        # including any that live inside the attention / projector blocks.
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                torch.nn.init.kaiming_uniform_(module.weight)

    def forward(self, x) -> tuple:
        """Run the network on a batch of images.

        Args:
            x: Batch of images with ``in_channels`` channels.

        Returns:
            A 4-tuple ``(y_hat, c1, c2, c3)``. ``y_hat`` holds the class
            logits. ``c1``..``c3`` are the first outputs of the three
            ``LinearAttentionMap`` modules (presumably the attention maps of
            stages l1..l3 - confirm against AttentionMap), or ``None`` when
            the model was built with ``attention=False``.
        """
        l1_out = self.l1(x)
        l2_out = self.l2(l1_out)
        l3_out = self.l3(l2_out)
        g = self.g_descriptor(l3_out)

        if self.attention:
            g_proj = self.proj_g(g)
            c1, g1 = self.attn_l1(l1_out, g_proj)
            c2, g2 = self.attn_l2(self.proj_l2(l2_out), g_proj)
            c3, g3 = self.attn_l3(self.proj_l3(l3_out), g_proj)

            g = torch.cat((g1, g2, g3), dim=1)  # batch_size x C
            # classification layer
            y_hat = self.classify(g)  # batch_size x num_classes
        else:
            c1, c2, c3 = None, None, None
            # g is (batch, 256, 1, 1) here; drop the spatial dims so that the
            # linear head receives (batch, 256).
            y_hat = self.classify(torch.flatten(g, 1))

        return (y_hat, c1, c2, c3)



# Dataset

In [3]:
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import Compose, ToTensor

# Images are only converted to tensors; no further augmentation is applied.
transform = Compose([ToTensor()])

# CIFAR10 train / test splits, stored under `dataset/` (downloaded on first use).
# For MNIST use instead: MNIST('dataset/', train=True, transform=transform, download=True)
train_dataset = CIFAR10(root='dataset/', train=True, transform=transform, download=True)
test_dataset = CIFAR10(root='dataset/', train=False, transform=transform, download=True)

# Training batches are reshuffled every epoch; test batches keep a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as utils
from AttentionMap.utils import visualize_attention

# Train the attention model on CIFAR10, evaluating on the test split and
# logging attention visualisations to TensorBoard before every epoch.

N_EPOCHS = 30
LOG_EVERY = 10      # log training metrics every LOG_EVERY-th batch
N_VIS_IMAGES = 4    # number of test images shown in the attention panels

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Model(32, 3, 10, normalize_attn=True).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

tb_writer = SummaryWriter('log/LinearAttention')
running_avg_accuracy = 0
step = 0

# Set to False to skip the attention-map images and log scalars only.
log_images = True

for epoch in range(N_EPOCHS):
    # ---- evaluation on the test split (before this epoch's training) ----
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for images_test, labels_test in test_loader:
            images_test, labels_test = images_test.to(device), labels_test.to(device)
            pred_test, c1, c2, c3 = model(images_test)
            predict = torch.argmax(pred_test, 1)
            total += labels_test.size(0)
            correct += torch.eq(predict, labels_test).sum().double().item()

        tb_writer.add_scalar('test/accuracy', correct / total, epoch)

        if log_images:
            n_rows = 2
            activation = 'softmax'
            # The maps come from the last test batch. Each c_k is taken after
            # k MaxPool2d(2) stages, hence the 2**k scale factor (presumably
            # the upsampling factor back to image resolution - confirm in
            # AttentionMap.utils.visualize_attention).
            attention_maps = (
                ('Attention/C1', c1, 2**1),
                ('Attention/C2', c2, 2**2),
                ('Attention/C3', c3, 2**3),
            )
            for tag, attn, scale_factor in attention_maps:
                vis = visualize_attention(
                    n_rows,
                    images_test[:N_VIS_IMAGES],
                    attn[:N_VIS_IMAGES],
                    scale_factor,
                    activation=activation,
                )
                tb_writer.add_image(tag, vis, epoch)

    # ---- one epoch of training ----
    model.train()
    for idx, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        pred = model(inputs)[0]
        loss = criterion(pred, targets)
        loss.backward()
        optimizer.step()

        if idx % LOG_EVERY == 0:
            # Re-score the batch in eval mode, without building a graph ...
            model.eval()
            with torch.no_grad():
                pred = model(inputs)[0]
            # ... and switch back, otherwise every following optimisation step
            # would run with BatchNorm frozen in eval mode.
            model.train()

            predict = torch.argmax(pred, 1)
            batch_total = targets.size(0)
            batch_correct = torch.eq(predict, targets).sum().double().item()
            accuracy = batch_correct / batch_total
            running_avg_accuracy = 0.9 * running_avg_accuracy + 0.1 * accuracy
            tb_writer.add_scalar('train/loss', loss.item(), step)
            tb_writer.add_scalar('train/accuracy', accuracy, step)
            tb_writer.add_scalar('train/running_avg_accuracy', running_avg_accuracy, step)
            step += 1

# Flush pending events to disk.
tb_writer.close()