In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision.models import resnet34, resnet18
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import os
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
# Both pipelines resize to 224x224, convert to a tensor and normalise with the
# per-channel mean/std below. (No augmentation is applied to the training set.)
train_preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)
test_preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

# One sub-directory per class under each root (ImageFolder convention).
dataset_train = torchvision.datasets.ImageFolder(
    transform=train_preprocess,
    root='./data/train',
)
dataset_test = torchvision.datasets.ImageFolder(
    transform=test_preprocess,
    root='./data/test',
)

# Both loaders shuffle; batches are loaded in the main process (no workers).
loader_train = torch.utils.data.DataLoader(dataset_train, shuffle=True, batch_size=64)
loader_test = torch.utils.data.DataLoader(dataset_test, shuffle=True, batch_size=64)

In [3]:
class MultiAttentionNetwork(nn.Module):
    """Image classifier with several learned spatial attention masks.

    A pretrained ResNet-18 trunk produces a feature map; a 1x1 convolution
    followed by a sigmoid turns it into ``num_masks`` attention maps. The
    feature map is weighted by each mask, average-pooled, and the pooled
    vectors of all masks are concatenated and fed to a small MLP head.

    The masks of the most recent forward pass are kept in ``self.mask_`` so
    that ``divergence_loss`` and ``save_attention_mask`` can use them.

    Args:
        num_classes: Number of output logits.
        num_masks: Number of attention masks (default 2).
    """

    def __init__(self, num_classes, num_masks=2):
        super().__init__()

        base_model = resnet18(pretrained=True)
        # Keep everything except the last two child modules of the ResNet
        # (its global pooling and classification layers).
        self.features = nn.Sequential(*list(base_model.children())[:-2])
        self.attn_conv = nn.Conv2d(512, num_masks, 1, bias=False)
        nn.init.xavier_uniform_(self.attn_conv.weight)
        self.fc = nn.Sequential(
            nn.Linear(512 * num_masks, 256),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )
        # Attention masks of the latest forward pass, shape [B, M, H, W].
        self.mask_ = None
        self.num_masks = num_masks

    def forward(self, x):
        """Return class logits for a batch and cache the attention masks.

        Args:
            x: Input batch passed to the convolutional trunk.

        Returns:
            The output of ``self.fc`` applied to the concatenated,
            mask-weighted pooled features (one row per sample).
        """
        x = self.features(x)  # [B, C, H, W]

        attn = torch.sigmoid(self.attn_conv(x))  # [B, M, H, W]
        B, M, H, W = attn.shape
        C = x.shape[1]  # read from the feature map instead of hard-coding 512
        self.mask_ = attn

        x = x.reshape(B, 1, C, H, W)
        attn = attn.reshape(B, M, 1, H, W)

        x = x * attn  # broadcast to [B, M, C, H, W]
        x = x.reshape(B * M, C, H, W)  # [BM, C, H, W]
        x = F.adaptive_avg_pool2d(x, (1, 1))  # [BM, C, 1, 1]

        x = x.reshape(B, -1)  # [B, M*C]

        return self.fc(x)

    def divergence_loss(self):
        """Penalty that grows when different masks attend to the same places.

        Each mask is flattened and divided by (its maximum + 0.01); the loss is
        the mean of the squared pairwise inner products between *different*
        masks of the same sample (the diagonal is zeroed out).

        Returns:
            A scalar tensor.

        Raises:
            RuntimeError: If no forward pass has populated ``self.mask_`` yet.
        """
        mask = self.mask_  # [B, M, H, W]
        if mask is None:
            raise RuntimeError(
                "divergence_loss() needs the attention masks of a forward pass; "
                "call the model on a batch first"
            )
        B, M, H, W = mask.shape

        flatten_mask = mask.reshape(B, M, -1)
        # 1 off the diagonal, 0 on it: removes each mask's product with itself.
        off_diag = 1 - torch.eye(M, device=mask.device).unsqueeze(0)  # [1, M, M]

        max_val, _ = flatten_mask.max(dim=2, keepdim=True)
        flatten_mask = flatten_mask / (max_val + 1e-2)

        div_loss = torch.bmm(flatten_mask, flatten_mask.transpose(1, 2)) * off_diag  # [B, M, M]
        return (div_loss.reshape(-1) ** 2).mean()

    def make_cam(self, img, mask):
        """Overlay a JET-coloured attention heatmap on an image.

        Args:
            img: Float array of shape [3, H, W]. Assumed to be in RGB channel
                order, as produced by the torchvision pipeline in this notebook.
                It is not modified.
            mask: 2-D float array; it is resized to the image's H x W.

        Returns:
            ``uint8`` array of shape [H, W, 3] in RGB order, scaled so that its
            largest value is 255.
        """
        height, width = img.shape[1], img.shape[2]
        cam = cv2.resize(mask, (width, height))  # cv2 expects (width, height)

        # Min-max normalise the mask; a constant mask would divide by zero.
        low = np.min(cam)
        span = np.max(cam) - low
        heatmap = (cam - low) / span if span > 0 else np.zeros_like(cam)

        image = img.transpose(1, 2, 0)
        # Out-of-place on purpose: `img` may share memory with a torch tensor.
        image = image - np.min(image)
        # Clip *before* casting so values above 255 saturate instead of wrapping.
        image = np.uint8(np.clip(255 * image, 0, 255))

        # applyColorMap returns BGR; flip it to RGB so it matches `image`.
        colored = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)[:, :, ::-1]
        overlay = np.float32(colored) + np.float32(image)
        peak = np.max(overlay)
        if peak > 0:
            overlay = 255 * overlay / peak
        return np.uint8(overlay)

    def save_attention_mask(self, x, path, head=4):
        """Plot inputs, attention masks and overlays for a batch and save them.

        One figure row per sample (at most ``head`` rows). Each row shows the
        de-normalised input followed, for every mask, by the raw mask and the
        heatmap overlay from ``make_cam``.

        The forward pass runs in eval mode with autograd disabled, so plotting
        neither builds a graph nor updates normalisation statistics; the
        previous train/eval mode is restored afterwards.

        Args:
            x: Normalised input batch, on the same device as the model.
            path: Destination passed to ``Figure.savefig``.
            head: Maximum number of samples to plot.

        Returns:
            The attention masks of the batch as a detached CPU tensor.
        """
        B = x.shape[0]
        n_rows = min(B, head)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                self.forward(x)
        finally:
            self.train(was_training)
        mask = self.mask_.detach().cpu()

        # Undo the input normalisation (x * std + mean) and clamp float error
        # away so that imshow receives values in [0, 1].
        std = torch.Tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
        mean = torch.Tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
        x = (x.detach().cpu() * std + mean).clamp(0, 1)

        fig, axs = plt.subplots(
            n_rows, self.num_masks * 2 + 1, figsize=(16, 2 * n_rows), squeeze=False
        )
        # plt.axis('off') would only affect the last axes; hide all of them.
        for ax in axs.ravel():
            ax.axis('off')

        for i in range(n_rows):
            axs[i, 0].imshow(x[i].permute(1, 2, 0).numpy())
            for j in range(self.num_masks):
                axs[i, j * 2 + 1].imshow(mask[i, j].numpy(), vmin=0, vmax=1)
                cam = self.make_cam(x[i].numpy(), mask[i, j].numpy())
                axs[i, j * 2 + 2].imshow(cam)
        fig.savefig(path)
        plt.close(fig)
        return mask

In [4]:
# Classification criterion shared by the training and validation loops.
clf_loss_func = nn.CrossEntropyLoss()

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
gpu_flag = torch.cuda.is_available()
print(gpu_flag)
device = torch.device('cuda:0' if gpu_flag else 'cpu')

True


In [None]:
def train(model, loader, optimizer, lambda_divergence):
    """Run one training epoch and report its mean loss and accuracy.

    Relies on the notebook-level ``device`` and ``clf_loss_func``.

    Args:
        model: Network exposing ``divergence_loss()`` for the latest forward pass.
        loader: Iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer that updates the model's parameters.
        lambda_divergence: Weight of the divergence penalty added to the
            classification loss.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch total losses and
        the fraction of correctly classified samples.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.train()
    correct = 0
    total = 0
    losses = []
    for X, y in tqdm(loader):
        X, y = X.to(device), y.to(device)
        clf = model(X)
        # Total objective: classification loss plus weighted mask-overlap penalty.
        loss = clf_loss_func(clf, y) + lambda_divergence * model.divergence_loss()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        predict = clf.argmax(dim=1)
        # .item() keeps the counter a plain Python int (as valid() does)
        # rather than accumulating a device tensor.
        correct += (predict == y).sum().item()
        total += len(y)

    return np.mean(losses), float(correct) / total

In [None]:
def valid(model, loader):
    """Evaluate the model on a loader without updating any weights.

    Relies on the notebook-level ``device`` and ``clf_loss_func``. Only the
    classification loss is measured here.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        fraction of correctly classified samples.
    """
    model.eval()

    batch_losses, n_hits, n_seen = [], 0, 0
    with torch.no_grad():
        for inputs, labels in tqdm(loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            batch_losses.append(clf_loss_func(logits, labels).item())

            n_hits += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += len(labels)

    return np.mean(batch_losses), float(n_hits) / n_seen

In [None]:
# Two-class classifier with the default number of attention masks, placed on
# the selected device.
multi_attention_model = MultiAttentionNetwork(num_classes=2).to(device)

In [None]:
# Hyper-parameters and bookkeeping for the run.
lambda_divergence = 5e-04
best_loss = 1e+10
earlystop_counter = 0  # never updated below; no early stopping is implemented
optimizer = torch.optim.SGD(
    multi_attention_model.parameters(),
    momentum=0.9,
    lr=0.001,
)

for epoch in range(5):
    train_loss, train_acc = train(multi_attention_model, loader_train, optimizer, lambda_divergence)
    val_loss, val_acc = valid(multi_attention_model, loader_test)

    # Track the lowest validation loss seen so far. (The weights themselves
    # are not snapshotted; that step was left disabled in the original cell.)
    best_loss = min(best_loss, val_loss)

    print(
        'Epoch: {}'.format(epoch),
        "train loss: {:.2f}, train acc: {:.2f}%".format(train_loss, train_acc*100.),
        "val loss: {:.2f}, val acc: {:.2f}%".format(val_loss, val_acc*100.),
        sep='\n',
    )

100%|██████████| 79/79 [00:29<00:00,  2.64it/s]
100%|██████████| 16/16 [00:05<00:00,  3.13it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 0
train loss: 0.49, train acc: 78.40%
val loss: 0.14, val acc: 96.60%


100%|██████████| 79/79 [00:29<00:00,  2.65it/s]
100%|██████████| 16/16 [00:04<00:00,  3.40it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 1
train loss: 0.13, train acc: 96.92%
val loss: 0.07, val acc: 97.80%


100%|██████████| 79/79 [00:29<00:00,  2.69it/s]
100%|██████████| 16/16 [00:04<00:00,  3.31it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 2
train loss: 0.07, train acc: 98.20%
val loss: 0.05, val acc: 98.00%


100%|██████████| 79/79 [00:29<00:00,  2.72it/s]
100%|██████████| 16/16 [00:04<00:00,  3.21it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 3
train loss: 0.04, train acc: 99.00%
val loss: 0.05, val acc: 98.30%


 94%|█████████▎| 74/79 [00:27<00:01,  2.77it/s]

In [None]:
# Render the attention maps of the first test batch only and write them to
# ./out.png; nothing happens if the loader is empty.
first_batch = next(iter(loader_test), None)
if first_batch is not None:
    img, target = first_batch
    img = img.to(device)
    mask = multi_attention_model.save_attention_mask(img, os.path.join('.', 'out.png'))