In [1]:
from PIL import Image
import numpy as np
import pandas as pd
import cv2
import os
import random
import matplotlib.pyplot as plt
# import tensorwatch as tw

import torch
import copy
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR

from utils.dataset import *
import torch.utils.data as data

from utils.train import *
from utils.test  import *
from vit_rollout import VITAttentionRollout
import time

In [2]:
# Training settings
####### test作训练集 #######
# epochs = 80
# lr = 3e-6
# gamma = 0.7
# step_size = 5
###### train作训练集 #######
epochs = 20
lr = 6e-5
gamma = 0.9
step_size = 5

seed = 42
device = 'cuda:1'

file_Path = '/home/a611/Projects/Datasets/CSE_v1/images/'
train_name = ['/home/a611/Projects/Datasets/CSE_v1/labels/train_aver_75.csv']
test_name = ['/home/a611/Projects/Datasets/CSE_v1/labels/train_aver_25.csv']
num_classes = 11
num_input = 3
batch_size = 128
num_workers = 8
########################
os.chdir('examples')

In [3]:
def show_mask_on_image(img, mask):
    img = np.float32(img) / 255
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

In [4]:
# from tensorboardX import SummaryWriter
# writer = SummaryWriter('log') #建立一个保存数据用的东西
model = torch.hub.load('facebookresearch/deit:main', 
    'deit_tiny_patch16_224', pretrained=False)
model.head = nn.Linear(in_features = model.head.in_features, out_features = num_classes, bias = True)
model.to(device);

Using cache found in /home/a611/.cache/torch/hub/facebookresearch_deit_main


In [5]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

## Load Data

In [6]:
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)


test_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

## Load Datasets

In [7]:
label_map = get_map(test_name)
label_key = list(label_map.keys())
train_set = MyDataset(file_Path, train_name, label_map,
                            train_transforms)
train_loader = data.DataLoader(
    dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

valid_set = MyDataset(file_Path, test_name, label_map,
                            val_transforms)
valid_loader = data.DataLoader(
    dataset=valid_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [8]:
print(len(train_loader))
print(len(valid_loader))

70
24


## Effecient Attention

### Training

In [9]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

In [10]:
model_ = None
highest_test_acc = 0
for i in range(epochs):
    # print('EPOCH:', i + 1)
    train_iter = iter(train_loader)
    test_iter = iter(valid_loader)
    ########################################
    train_loss, train_acc = train(model, device, train_iter, optimizer, train_set, batch_size)
    test_loss, test_acc = test(model, device, test_iter, valid_set, batch_size)
    scheduler.step()
    print( 'EPOCH: %03d, train_loss: %3f, train_acc: %3f, test_loss: %3f, test_acc: %3f'
          % (i + 1, train_loss, train_acc, test_loss, test_acc))
    if test_acc > highest_test_acc:
        highest_test_acc = test_acc
        model_ = copy.deepcopy(model)
        print('Highest test accuracy: %3f' % highest_test_acc)
        torch.save(model_, '../models/CSE_train.model')
#     print( 'EPOCH: %03d, train_loss: %3f, train_acc: %3f' % (i + 1, train_loss, train_acc))

# test_iter = iter(valid_loader)
model = torch.load("/home/a611/Projects/gyc/Local_Features/models/CSE_train.model", map_location=device)

# test_loss, test_acc = test(model, device, test_iter, valid_set, batch_size)
# print( 'test_loss: %3f, test_acc: %3f' % (test_loss, test_acc))

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 001, train_loss: 2.371268, train_acc: 0.146743, test_loss: 2.439225, test_acc: 0.100511
Highest test accuracy: 0.100511


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 002, train_loss: 1.909111, train_acc: 0.325280, test_loss: 1.462533, test_acc: 0.446665
Highest test accuracy: 0.446665


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 003, train_loss: 1.371872, train_acc: 0.522568, test_loss: 1.044037, test_acc: 0.638622
Highest test accuracy: 0.638622


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 004, train_loss: 0.990020, train_acc: 0.658757, test_loss: 0.765165, test_acc: 0.735527
Highest test accuracy: 0.735527


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 005, train_loss: 0.763795, train_acc: 0.746806, test_loss: 0.576579, test_acc: 0.793394
Highest test accuracy: 0.793394


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 006, train_loss: 0.577748, train_acc: 0.813324, test_loss: 0.523372, test_acc: 0.814353
Highest test accuracy: 0.814353


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 007, train_loss: 0.475132, train_acc: 0.851437, test_loss: 0.265314, test_acc: 0.916416
Highest test accuracy: 0.916416


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 008, train_loss: 0.385744, train_acc: 0.883013, test_loss: 0.239196, test_acc: 0.923603
Highest test accuracy: 0.923603


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 009, train_loss: 0.309006, train_acc: 0.905406, test_loss: 0.121582, test_acc: 0.970703
Highest test accuracy: 0.970703


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 010, train_loss: 0.340245, train_acc: 0.891423, test_loss: 0.207838, test_acc: 0.938176


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 011, train_loss: 0.278007, train_acc: 0.911481, test_loss: 0.106725, test_acc: 0.974634
Highest test accuracy: 0.974634


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 012, train_loss: 0.250029, train_acc: 0.918472, test_loss: 0.077222, test_acc: 0.981771
Highest test accuracy: 0.981771


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 013, train_loss: 0.223556, train_acc: 0.927807, test_loss: 0.097006, test_acc: 0.968750


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 014, train_loss: 0.223225, train_acc: 0.929370, test_loss: 0.065145, test_acc: 0.985702
Highest test accuracy: 0.985702


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 015, train_loss: 0.206712, train_acc: 0.932647, test_loss: 0.075130, test_acc: 0.983073


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 016, train_loss: 0.194484, train_acc: 0.935620, test_loss: 0.051056, test_acc: 0.988281
Highest test accuracy: 0.988281


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 017, train_loss: 0.184541, train_acc: 0.938593, test_loss: 0.063383, test_acc: 0.983098


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 018, train_loss: 0.189590, train_acc: 0.938897, test_loss: 0.063067, test_acc: 0.980794


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 019, train_loss: 0.178787, train_acc: 0.941375, test_loss: 0.039754, test_acc: 0.992513
Highest test accuracy: 0.992513


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

EPOCH: 020, train_loss: 0.165653, train_acc: 0.943615, test_loss: 0.029571, test_acc: 0.996419
Highest test accuracy: 0.996419


In [11]:
def draw(x, output, input_type):
    output_max = output.argmax(dim=1)
    output_numpy = output.cpu().detach().numpy()
    fig, ax = plt.subplots(figsize=(12, 8))
    y = output_numpy[0]
    ax.bar(x = x, height = y)
    ax.set_title('Item: %s.' % x[int(output_max)], fontsize=15);
    xticks = ax.get_xticks()
    for i in range(len(y)):
        xy = (xticks[i], y[i])
        s = '%03f' % y[i]
        ax.annotate(
            text=s,  # 要添加的文本
            xy=xy,  # 将文本添加到哪个位置
            fontsize=8,  # 标签大小
            color="red",  # 标签颜色
            ha="center",  # 水平对齐
            va="baseline"  # 垂直对齐
        )
        plt.savefig("../results_%s.jpg" % input_type)
    return output_max

def result(x, output, input_type):
    output_max = output.argmax(dim=1)
    print('Input: %s; Item: %s.' % (input_type, x[int(output_max)]))
    return output_max

In [13]:
model.eval()
discard_ratio = 0.9
head_fusion = 'max'

# lists = pd.read_csv(train_name[0], header = None)
# lists = lists[0].sample(frac = 1).reset_index(drop = True)
# print(lists)

file_Path = '../sample_pairs/same_rotation_diff_background/'
for input_file_name in os.listdir(file_Path):

    image_name = os.path.join(file_Path,input_file_name)
    image = Image.open(image_name)
    img = image.resize((224, 224))
    input = test_transforms(image).unsqueeze(0)
    input = input.to(device)
    output = model(input)
    # output_max = draw(label_key, output)
    output_max = result(label_key, output, input_file_name)
#     output_max = draw(label_key, output, 'after')

    attention_rollout = VITAttentionRollout(model, head_fusion = head_fusion, discard_ratio = discard_ratio)
    mask = attention_rollout(input)
#     output_name = "outputs/CSE_test/{}-{}.png".format(input_file_name.split('.')[0], label_key[int(output_max)])
#     heatmap_name = "heatmaps/CSE_test/{}-{}.png".format(input_file_name.split('.')[0], label_key[int(output_max)])
#     weight_name = "weights/CSE_test/{}-{}.png".format(input_file_name.split('.')[0], label_key[int(output_max)])
#     #########################################

    np_img = np.array(image)[:, :, ::-1]
    mask = cv2.resize(mask, (np_img.shape[1], np_img.shape[0]))
    mask_ = show_mask_on_image(np_img, mask)

    np_img_ = np.zeros(np_img.shape)
    for t in range(np_img.shape[2]):
        np_img_[:,:,t] = np_img[:,:,t] * mask


    cv2.imwrite('../input.png', np_img);
    cv2.imwrite('../heatmap.png', mask * 255);
#     cv2.imwrite(heatmap_name, mask * 255);
    cv2.imwrite('../output.png', mask_);
#     cv2.imwrite(output_name, mask_);
    cv2.imwrite('../weight.png', np_img_);
#     cv2.imwrite(weight_name, np_img_);
    
    time.sleep(0.7)

Input: plant_20.jpg; Item: headset.
Input: bottle_89.jpg; Item: fleet.
Input: car_80.jpg; Item: headset.
Input: headset_62.jpg; Item: fleet.
Input: pepper_81.jpg; Item: plant.
Input: milk_41.jpg; Item: doll.
Input: doll_86.jpg; Item: doll.
Input: fleet_24.jpg; Item: fleet.
Input: cup_61.jpg; Item: doll.
Input: apple_78.jpg; Item: apple.
Input: container_49.jpg; Item: container.
