In [1]:
from PIL import Image
import numpy as np
import cv2
import os
import random
# import tensorwatch as tw

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt

from utils.dataset import *
import torch.utils.data as data

from vit_rollout import VITAttentionRollout

from utils.train import *
from utils.test  import *
from utils.CLS2IDX import CLS2IDX

In [2]:
# Training settings
epochs = 1
lr = 3e-6
gamma = 0.7
seed = 42
device = 'cpu'

file_Path = '/home/a611/Projects/Datasets/CSE_v1/images/'
train_name = ['/home/a611/Projects/Datasets/CSE_v1/labels/train.csv']
test_name = ['/home/a611/Projects/Datasets/CSE_v1/labels/test.csv']
num_classes = 11
num_input = 3
batch_size = 128
num_workers = 8
########################
os.chdir('examples')

In [3]:
def show_mask_on_image(img, mask):
    img = np.float32(img) / 255
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

In [4]:
# from tensorboardX import SummaryWriter
# writer = SummaryWriter('log') #建立一个保存数据用的东西
model = torch.hub.load('facebookresearch/deit:main', 
    'deit_tiny_patch16_224', pretrained=True)
# model = torch.load('../models/CSE_train.model')
# model.head = nn.Linear(in_features = model.head.in_features, out_features = num_classes, bias = True)
model.to(device);

Using cache found in /home/a611/.cache/torch/hub/facebookresearch_deit_main


## Load Data

In [5]:
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)


test_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

In [6]:
def draw(x, output):
    output_max = output.argmax(dim=1)
    output_numpy = output.cpu().detach().numpy()
    fig, ax = plt.subplots(figsize=(12, 8))
    y = output_numpy[0]
    ax.bar(x = x, height = y)
    ax.set_title('Item: %s.' % x[int(output_max)], fontsize=15);
    xticks = ax.get_xticks()
#     for i in range(len(y)):
#         xy = (xticks[i], y[i])
#         s = '%03f' % y[i]
#         ax.annotate(
#             text=s,  # 要添加的文本
#             xy=xy,  # 将文本添加到哪个位置
#             fontsize=8,  # 标签大小
#             color="red",  # 标签颜色
#             ha="center",  # 水平对齐
#             va="baseline"  # 垂直对齐
#         )
    return output_max

def result(x, output, input_type):
    output_max = output.argmax(dim=1)
    print('Input: %s; Item: %s.' % (input_type, x[int(output_max)]))
    return output_max

In [7]:
model.eval()
discard_ratio = 0.9
head_fusion = 'max'

#########################################
for input_file_name in os.listdir('inputs'):
#     input_file_name = 'iPod_1.jpg'
    image_name = os.path.join('inputs',input_file_name)
    image = Image.open(image_name)
    img = image.resize((224, 224))
    input = test_transforms(image).unsqueeze(0)
    input = input.to(device)
    output = model(input)
    # output_max = draw(list(CLS2IDX.values()), output)
    output_max = result(list(CLS2IDX.values()), output, 'origin')

    #########################################
    # How to fuse the attention heads for attention rollout. Can be mean/max/min.
    attention_rollout = VITAttentionRollout(model, head_fusion = head_fusion, discard_ratio = discard_ratio)
    mask = attention_rollout(input)
    # name = "outputs/{}_{}_A-R_{:.3f}_{}.png".format(input_file_name.split('.')[0], label_key[int(output_max)], discard_ratio, head_fusion)
    output_name = "outputs/Image-Net/{}:{}.png".format(input_file_name.split('.')[0], CLS2IDX[int(output_max)])
    heatmap_name = "heatmaps/Image-Net/{}:{}.png".format(input_file_name.split('.')[0], CLS2IDX[int(output_max)])
    weight_name = "weights/Image-Net/{}:{}.png".format(input_file_name.split('.')[0], CLS2IDX[int(output_max)])
    #########################################

    np_img = np.array(image)[:, :, ::-1]
    mask = cv2.resize(mask, (np_img.shape[1], np_img.shape[0]))
    mask_ = show_mask_on_image(np_img, mask)

    np_img_ = np.zeros(np_img.shape)
    for t in range(np_img.shape[2]):
        np_img_[:,:,t] = np_img[:,:,t] * mask
        
    input_ = Image.fromarray(np.uint8(np_img_))
    input_ = test_transforms(input_).unsqueeze(0)
    input_ = input_.to(device)
    output_ = model(input_)
    output_max_ = result(list(CLS2IDX.values()), output_, 'weight')

    cv2.imwrite('../input.png', np_img);
    cv2.imwrite('../heatmap.png', mask * 255);
    cv2.imwrite(heatmap_name, mask * 255);
    cv2.imwrite('../output.png', mask_);
    cv2.imwrite(output_name, mask_);
    cv2.imwrite('../weight.png', np_img_);
    cv2.imwrite(weight_name, np_img_);

Input: origin; Item: coffee mug.
Input: weight; Item: coffee mug.
Input: origin; Item: beagle.
Input: weight; Item: Brittany spaniel.
Input: origin; Item: toilet tissue, toilet paper, bathroom tissue.
Input: weight; Item: toilet tissue, toilet paper, bathroom tissue.
Input: origin; Item: coffeepot.
Input: weight; Item: coffee mug.
Input: origin; Item: iPod.
Input: weight; Item: pill bottle.
Input: origin; Item: pomegranate.
Input: weight; Item: Granny Smith.
Input: origin; Item: aircraft carrier, carrier, flattop, attack aircraft carrier.
Input: weight; Item: table lamp.
Input: origin; Item: traffic light, traffic signal, stoplight.
Input: weight; Item: cab, hack, taxi, taxicab.
Input: origin; Item: lemon.
Input: weight; Item: nematode, nematode worm, roundworm.
Input: origin; Item: palace.
Input: weight; Item: palace.
Input: origin; Item: iPod.
Input: weight; Item: iPod.
Input: origin; Item: banana.
Input: weight; Item: nematode, nematode worm, roundworm.
Input: origin; Item: trench c