In [1]:
from PIL import Image
import numpy as np
import pandas as pd
import cv2
import os
import random
import matplotlib.pyplot as plt
# import tensorwatch as tw

import torch
import copy
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR

from utils.dataset import *
import torch.utils.data as data

from utils.train import *
from utils.test  import *
from vit_rollout import VITAttentionRollout

In [2]:
# Training settings
####### Setting used when the "test" split is the training set #######
# epochs = 80
# lr = 3e-6
# gamma = 0.7
# step_size = 5
###### Setting used when the "train" split is the training set (active) #######
epochs = 20
lr = 3e-5
gamma = 0.8
step_size = 5

seed = 42
# Prefer the second GPU as before, but fall back so the notebook still runs on
# single-GPU or CPU-only machines instead of failing at model.to(device).
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

file_Path = '/home/a611/Projects/Datasets/CSE_v1/images/'
train_name = ['/home/a611/Projects/Datasets/CSE_v1/labels/train_aver_75.csv']
# NOTE: despite the name, this is the held-out 25% split used for validation.
test_name = ['/home/a611/Projects/Datasets/CSE_v1/labels/train_aver_25.csv']
num_classes = 11
num_input = 3
batch_size = 16
num_workers = 8
########################
# Later cells use paths relative to examples/ (inputs/, outputs/, ../models/).
# Guard the chdir so re-running this cell neither fails nor descends into
# examples/examples.
if os.path.basename(os.path.normpath(os.getcwd())) != 'examples':
    os.chdir('examples')

In [3]:
def show_mask_on_image(img, mask):
    """Overlay a JET-colored mask on an image and return the blended result.

    Args:
        img: image array with values in 0..255. cv2.applyColorMap yields BGR
            colors, so ``img`` should be BGR for the overlay to look right.
        mask: 2-D float array with the same height/width as ``img``; values
            are expected in [0, 1] and are clipped to that range.

    Returns:
        uint8 array of heatmap + image, rescaled so the brightest value is 255.
    """
    img = np.float32(img) / 255
    # Clip before the uint8 cast: values outside [0, 1] would otherwise
    # overflow the cast and be rendered with the wrong color.
    mask_u8 = np.uint8(255 * np.clip(mask, 0.0, 1.0))
    heatmap = np.float32(cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)) / 255
    cam = heatmap + img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

In [4]:
# Build DeiT-Tiny (patch 16, 224 px) from torch.hub with random weights
# (pretrained=False) and replace its head with a num_classes-way classifier.
# Seed torch first: seed_everything() only runs in a later cell, so without
# this the random weight initialization would differ on every run.
torch.manual_seed(seed)
model = torch.hub.load('facebookresearch/deit:main',
                       'deit_tiny_patch16_224', pretrained=False)
model.head = nn.Linear(in_features=model.head.in_features,
                       out_features=num_classes, bias=True)
model.to(device);

Using cache found in /home/a611/.cache/torch/hub/facebookresearch_deit_main


In [5]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU + CUDA) RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every RNG.

    Note:
        Setting PYTHONHASHSEED here does not change str hashing in the running
        interpreter; it only affects Python subprocesses started afterwards.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning selects kernels by timing, which can vary between runs;
    # it must be disabled as well for deterministic results.
    torch.backends.cudnn.benchmark = False

# Seed every RNG (Python, NumPy, torch/CUDA) before the data loaders are built.
seed_everything(seed)

## Data Transforms

In [6]:
def _resize_to_tensor(size=(224, 224)):
    """Resize a PIL image to ``size`` and convert it to a float tensor in [0, 1]."""
    return transforms.Compose([transforms.Resize(size), transforms.ToTensor()])


# Train, validation and test share the same deterministic pipeline (no
# augmentation, no normalization); each split still gets its own Compose.
train_transforms = _resize_to_tensor()
val_transforms = _resize_to_tensor()
test_transforms = _resize_to_tensor()

## Load Datasets

In [7]:
# Class-name -> index mapping shared by both splits.
# NOTE(review): the mapping is built from the validation CSV (test_name) only.
# If get_map enumerates labels from the files it is given, a class that only
# appears in the training CSV would be missing — confirm, or build it from
# train_name as well.
label_map = get_map(test_name)
# Presumably ordered by class index: later cells use label_key[i] as the name
# of predicted class i.
label_key = list(label_map.keys())

train_set = MyDataset(file_Path, train_name, label_map, train_transforms)
train_loader = data.DataLoader(
    dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

valid_set = MyDataset(file_Path, test_name, label_map, val_transforms)
# Evaluation does not need shuffling; a fixed order keeps validation passes
# deterministic and comparable across epochs.
valid_loader = data.DataLoader(
    dataset=valid_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [8]:
# Number of mini-batches per epoch for the training and validation splits.
print(len(train_loader), len(valid_loader), sep='\n')

557
186


## Efficient Attention

### Training

In [9]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

In [10]:
# Train for `epochs` epochs, evaluate on the validation split after each one,
# and checkpoint the model whenever validation accuracy improves.
CHECKPOINT_PATH = '../models/CSE_train.model'
# torch.save does not create missing directories; without this the first
# improvement would crash the loop after a full epoch of training.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

model_ = None          # deep copy of the best model so far
highest_test_acc = 0   # best validation accuracy so far
for i in range(epochs):
    # Fresh iterators every epoch; train()/test() consume them.
    train_iter = iter(train_loader)
    test_iter = iter(valid_loader)
    train_loss, train_acc = train(model, device, train_iter, optimizer, train_set, batch_size)
    test_loss, test_acc = test(model, device, test_iter, valid_set, batch_size)
    scheduler.step()  # StepLR advances once per epoch
    print('EPOCH: %03d, train_loss: %.6f, train_acc: %.6f, test_loss: %.6f, test_acc: %.6f'
          % (i + 1, train_loss, train_acc, test_loss, test_acc))
    if test_acc > highest_test_acc:
        highest_test_acc = test_acc
        model_ = copy.deepcopy(model)
        print('Highest test accuracy: %.6f' % highest_test_acc)
        # Pickles the whole module (not just a state_dict), so loading it
        # requires the model's class to be importable.
        torch.save(model_, CHECKPOINT_PATH)



  0%|          | 0/557 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

EPOCH: 001, train_loss: 1.078187, train_acc: 0.599304, test_loss: 0.359867, test_acc: 0.892137
Highest test accuracy: 0.892137


  0%|          | 0/557 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

EPOCH: 002, train_loss: 0.128822, train_acc: 0.960390, test_loss: 0.072282, test_acc: 0.978965
Highest test accuracy: 0.978965


  0%|          | 0/557 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

EPOCH: 003, train_loss: 0.052340, train_acc: 0.983505, test_loss: 0.037363, test_acc: 0.991532
Highest test accuracy: 0.991532


  0%|          | 0/557 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

EPOCH: 004, train_loss: 0.026039, train_acc: 0.993268, test_loss: 0.090712, test_acc: 0.979167


  0%|          | 0/557 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

EPOCH: 005, train_loss: 0.026961, train_acc: 0.993716, test_loss: 0.004210, test_acc: 1.000000
Highest test accuracy: 1.000000


  0%|          | 0/557 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

EPOCH: 006, train_loss: 0.008207, train_acc: 0.998205, test_loss: 0.003884, test_acc: 1.000000


  0%|          | 0/557 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
def draw(x, output, input_type):
    """Bar-plot the first sample's class scores and save the figure.

    Args:
        x: sequence of class names, one per column of ``output``.
        output: model output tensor; only row 0 is plotted (presumably raw
            logits of shape (batch, num_classes) — confirm).
        input_type: tag used in the saved file name ``../results_<tag>.jpg``.

    Returns:
        Tensor of argmax class indices over dim 1 (one per row of ``output``).
    """
    output_max = output.argmax(dim=1)
    scores = output.cpu().detach().numpy()[0]
    fig, ax = plt.subplots(figsize=(12, 8))
    bars = ax.bar(x=x, height=scores)
    # Title with the prediction for the plotted (first) sample.
    ax.set_title('Item: %s.' % x[int(output_max[0])], fontsize=15)
    # Label each bar with its value, anchored at the bar's center and top.
    for bar, score in zip(bars, scores):
        ax.annotate(
            text='%03f' % score,  # text to add
            xy=(bar.get_x() + bar.get_width() / 2, score),  # anchor position
            fontsize=8,  # label size
            color="red",  # label color
            ha="center",  # horizontal alignment
            va="baseline",  # vertical alignment
        )
    # Save once, after every annotation has been added.
    fig.savefig("../results_%s.jpg" % input_type)
    return output_max

def result(x, output, input_type):
    """Print and return the predicted class for ``output``.

    Args:
        x: sequence of class names indexed by class id.
        output: model output tensor; the prediction is its argmax over dim 1.
        input_type: tag included in the printed line.

    Returns:
        Tensor of argmax indices over dim 1.
    """
    predicted = output.argmax(dim=1)
    label = x[int(predicted)]
    print('Input: %s; Item: %s.' % (input_type, label))
    return predicted

In [None]:
# Explain one example image: classify it, compute its attention-rollout mask,
# re-classify the mask-weighted image, and save the visualizations.
model.eval()
discard_ratio = 0.9  # forwarded to VITAttentionRollout
# How to fuse the attention heads for attention rollout. Can be mean/max/min.
head_fusion = 'max'
# NOTE(review): after training, `model` holds the most recent weights, not the
# best ones the training loop kept in `model_` — confirm which to explain.

#########################################
# for input_file_name in os.listdir('inputs'):
input_file_name = 'building_1.jpg'
image_name = os.path.join('inputs', input_file_name)
# Force 3-channel RGB so grayscale/RGBA/palette files give a 3-channel tensor
# and the [:, :, ::-1] BGR flip below stays valid.
image = Image.open(image_name).convert('RGB')
input = test_transforms(image).unsqueeze(0).to(device)
# Plain inference: no autograd graph is needed for this forward pass.
with torch.no_grad():
    output = model(input)
output_max = result(label_key, output, 'origin')
output_max = draw(label_key, output, 'origin')

attention_rollout = VITAttentionRollout(model, head_fusion=head_fusion, discard_ratio=discard_ratio)
mask = attention_rollout(input)
# splitext keeps dots inside the file stem (split('.')[0] would truncate them).
stem = os.path.splitext(input_file_name)[0]
predicted_label = label_key[int(output_max)]
output_name = "outputs/CSE_test/{}-{}.png".format(stem, predicted_label)
heatmap_name = "heatmaps/CSE_test/{}-{}.png".format(stem, predicted_label)
weight_name = "weights/CSE_test/{}-{}.png".format(stem, predicted_label)
#########################################

np_img = np.array(image)[:, :, ::-1]  # RGB -> BGR for OpenCV
mask = cv2.resize(mask, (np_img.shape[1], np_img.shape[0]))  # to full image size
mask_ = show_mask_on_image(np_img, mask)

# Scale every channel by the mask; the product is computed in the promoted
# dtype of np_img * mask and stored as float64. Still BGR, for cv2.imwrite.
np_img_ = (np_img * mask[:, :, np.newaxis]).astype(np.float64)

# Re-classify the mask-weighted image. np_img_ is BGR, so flip it back to RGB
# to match the channel order of the first forward pass (PIL image -> transform).
input_ = Image.fromarray(np.uint8(np_img_[:, :, ::-1]))
input_ = test_transforms(input_).unsqueeze(0).to(device)
with torch.no_grad():
    output_ = model(input_)
output_max_ = result(label_key, output_, 'weight')
output_max_ = draw(label_key, output_, 'weight')

# cv2.imwrite neither creates directories nor raises on failure (it returns
# False), so create the targets and report any write that does not succeed.
for out_path, out_img in [
    ('../input.png', np_img),
    ('../heatmap.png', mask * 255),
    (heatmap_name, mask * 255),
    ('../output.png', mask_),
    (output_name, mask_),
    ('../weight.png', np_img_),
    (weight_name, np_img_),
]:
    os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
    if not cv2.imwrite(out_path, out_img):
        print('WARNING: could not write %s' % out_path)