In [1]:
import torch
from torch import nn
from torchvision.models import resnet50
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import time
import numpy as np
import matplotlib.pyplot as plt
from torchvision.ops import FeaturePyramidNetwork
import torch.nn.functional as F
from torchvision.ops import FeaturePyramidNetwork

class SelfAttention(nn.Module):
    """Spatial self-attention over a 4-D feature map with a learned residual gate.

    Every spatial position of the input attends to every other position. The
    attended features are scaled by the learnable scalar ``gamma`` and added
    back to the input. ``gamma`` starts at zero, so a freshly constructed
    module returns its input unchanged.

    Args:
        in_dim: Number of channels of the input feature map. Queries and keys
            are projected to ``in_dim // 8`` channels; values keep ``in_dim``.
    """

    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # The score matrix is laid out as (batch, query, key) and the output for
        # query i is sum_j value_j * attention[i, j]; the weights must therefore
        # be normalised over the key axis (the last one) so that each output
        # position receives a convex combination of the values.
        self.softmax = nn.Softmax(dim=-1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` with ``C == in_dim``.

        Returns:
            A tuple ``(out, attention)`` where ``out`` has the same shape as
            ``x`` and ``attention`` has shape ``(B, H*W, H*W)``;
            ``attention[b, i, j]`` is the weight query position ``i`` puts on
            key position ``j`` and each row sums to 1.
        """
        batch = x.shape[0]
        num_positions = x.shape[2] * x.shape[3]

        # reshape (rather than view) also accepts non-contiguous conv outputs.
        proj_query = self.query_conv(x).reshape(batch, -1, num_positions).permute(0, 2, 1)  # (B, N, C//8)
        proj_key = self.key_conv(x).reshape(batch, -1, num_positions)                       # (B, C//8, N)
        attention = self.softmax(torch.bmm(proj_query, proj_key))                           # (B, N, N)

        proj_value = self.value_conv(x).reshape(batch, -1, num_positions)                   # (B, C, N)
        # out[:, :, i] = sum_j value[:, :, j] * attention[:, i, j]
        out = torch.bmm(proj_value, attention.permute(0, 2, 1)).view(x.shape)

        # Residual connection gated by gamma.
        out = self.gamma * out + x
        return out, attention

class AttentionResNet(nn.Module):
    """ResNet-50 feature extractor with a self-attention block and a linear head.

    The torchvision ResNet-50 (constructed with ``pretrained=True``) is stored
    whole as ``self.resnet``; the forward pass uses its stem and its four
    residual stages, and does not call the backbone's own ``avgpool`` / ``fc``.
    The stage-4 feature map goes through ``SelfAttention``, is average-pooled
    to 1x1, flattened, and classified by ``self.fc``.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = resnet50(pretrained=True)
        # Channel width of the backbone's final stage, read from its classifier.
        feature_dim = backbone.fc.in_features

        self.resnet = backbone
        self.self_attention = SelfAttention(feature_dim)
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return ``(logits, attention_map)`` for an input batch ``x``.

        ``logits`` has shape ``(B, num_classes)``; ``attention_map`` is the
        second value returned by the ``SelfAttention`` block.
        """
        backbone = self.resnet

        # Stem followed by the four residual stages, applied in order.
        pipeline = (
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
        )
        for module in pipeline:
            x = module(x)

        attended, attention_map = self.self_attention(x)

        # Global average pool to 1x1, then drop the spatial dimensions.
        pooled = F.adaptive_avg_pool2d(attended, (1, 1))
        flat = pooled.view(pooled.size(0), -1)
        logits = self.fc(flat)

        return logits, attention_map

In [2]:
# Visualisation cell: build the model, push one random image-shaped tensor
# through it, and have torchviz render a graph of the output to a PNG file.
from torchviz import make_dot

# AttentionResNet.__init__ calls resnet50(pretrained=True), so constructing the
# model here requests the pretrained ResNet-50 weights.
model = AttentionResNet(num_classes=10)
# A single unseeded random input of shape (1, 3, 224, 224).
x = torch.randn(1, 3, 224, 224)
# forward returns (logits, attention_map); only the logits feed the graph.
y, _ = model(x)
# The value returned by render() is the cell's displayed output
# ('AttentionResNet.png').
make_dot(y).render("AttentionResNet", format="png")




'AttentionResNet.png'