In [1]:
import copy
import csv
import os
import random
import re
import time

from PIL import Image
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
# import tensorwatch as tw

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR

# NOTE: keep the order below as-is — the wildcard imports can re-bind names
# imported earlier (and each other's), and the explicit imports after them win.
from utils.dataset import *
import torch.utils.data as data

from utils.train import *
from utils.test  import *
from utils.model_select import model_select

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from tqdm import tqdm_notebook as tqdm

from collections import OrderedDict

import warnings
# Suppress every Python warning for the rest of the kernel session.
# NOTE(review): this also hides deprecation/UserWarnings that may matter;
# consider narrowing with a `category=` argument.
warnings.filterwarnings("ignore")

In [2]:
# Training settings
####### Hyper-parameters used when the *test* split served as the training set #######
# epochs = 80
# lr = 3e-6
# gamma = 0.7
# step_size = 5
####### Hyper-parameters used when the *train* split serves as the training set #######
epochs = 5
lr = 2e-6
gamma = 0.8
step_size = 2

seed = 42
device = 'cpu'

file_Path = '/home/a611/Projects/Datasets/NICO/'
train_name = ['/home/a611/Projects/Datasets/NICO/labels/NICO_animal_train.csv']
test_name = ['/home/a611/Projects/gyc/Local_Features/csvs/NICO_animal_test_with_cage_with_dog_eating_without_monkey.csv']
num_classes = 10
num_input = 3
batch_size = 8
num_workers = 8
########################
# Later cells write to paths relative to `examples/` ('../results', '../logs',
# '../NICO_OOD'). Record the starting directory on the first run only, so that
# re-running this cell does not try to enter examples/examples and fail.
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, 'examples'))

In [3]:
def seed_everything(seed):
    """Seed every RNG this notebook relies on, for reproducible runs.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), and switches cuDNN to deterministic algorithm selection.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # Python processes launched after this point, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible GPU (silently ignored without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN auto-tuning picks kernels by timing, which can vary between runs;
    # it must be off for deterministic behaviour.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before the datasets, loaders and model below are built.
seed_everything(seed)

## Data Transforms

In [4]:
# Training-time augmentation: random resized 224x224 crop plus a random
# horizontal flip, then conversion to a tensor.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Validation: deterministic resize straight to 224x224 (aspect ratio is not kept).
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Single-image inference (used by the Grad-CAM loop): shorter side -> 224,
# centre 224x224 crop, then per-channel normalisation with the standard
# ImageNet mean/std values.
# NOTE(review): train/val pipelines do NOT normalise but this one does —
# confirm which convention the loaded checkpoint was trained with.
test_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

## Load Datasets

In [5]:
# Label mapping built from the training CSV(s); presumably class name -> index.
# The key order of `label_key` is what later cells use to turn a predicted
# index back into a class name.
label_map = get_map(train_name)
label_key = list(label_map.keys())

train_set = MyDataset(file_Path, train_name, label_map, train_transforms)
train_loader = data.DataLoader(
    dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# Validation data comes from the test CSV (`test_name`). Evaluation does not
# benefit from random order, so keep it fixed for reproducible iteration.
valid_set = MyDataset(file_Path, test_name, label_map, val_transforms)
valid_loader = data.DataLoader(
    dataset=valid_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
print(label_key)

['bear', 'bird', 'cat', 'cow', 'dog', 'elephant', 'horse', 'monkey', 'rat', 'sheep']


In [6]:
# Number of mini-batches in each loader (train first, then validation), one per line.
print(len(train_loader), len(valid_loader), sep='\n')

1291
338


## Efficient Attention

### Model setup (checkpoints, loss, optimizer, scheduler)

In [9]:
# Build the backbone with a 90-way head so the pre-training checkpoint loads
# strictly. NOTE(review): 90 is presumably that checkpoint's class count — confirm.
model = model_select('ResNet18', num_input, 90).to(device)
# Wrap in DataParallel before loading: the checkpoints are loaded strictly into
# the wrapper, so their keys presumably carry the `module.` prefix.
model = torch.nn.DataParallel(model)
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
checkpoint = torch.load('/home/a611/Projects/gyc/fewshot/model/pretrained/test_acc_0.766925_epoch_159', map_location=device)
model.load_state_dict(checkpoint)
# Replace the classifier with a `num_classes`-way head, keeping the input width
# of the existing head instead of hard-coding it.
model.module.linear = nn.Linear(model.module.linear.in_features, num_classes).to(device)

######################## Fine-tuned checkpoint that is actually evaluated below ########################
pth_file = '/home/a611/Projects/gyc/Local_Features/models/without_cage_without_dog_eating_without_monkey_ResNet18_random_1_0.01/test_acc_0.918826_epoch_59'
model_state_dict = torch.load(pth_file, map_location=device)
# load_state_dict is strict by default, so this overwrites every parameter and
# persistent buffer. NOTE(review): the pre-training load above therefore has no
# effect on the final weights and may be removable.
model.load_state_dict(model_state_dict)
# Unwrap DataParallel: the Grad-CAM cell addresses `model.layer4` directly.
model = model.module.to(device)

# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler: multiply the learning rate by `gamma` every `step_size` scheduler steps
scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

In [10]:


# print(model)
# print(model.layer4)

In [11]:
def draw(x, output, input_type):
    """Bar-plot the scores of the first row of `output` and save the figure.

    Args:
        x: sequence of class names; ``x[k]`` labels score column ``k``.
        output: score tensor with classes along dim 1. Only row 0 is plotted.
        input_type: tag used as the prefix of the saved file name.

    Returns:
        The ``argmax`` tensor along dim 1 (one index per row of `output`).

    Side effects:
        Writes ``../results/<input_type>_<predicted class>.jpg``, creating
        ``../results`` if needed. The figure is left open so it still renders
        inline in the notebook.
    """
    output_max = output.argmax(dim=1)
    # Prediction for the plotted (first) row. The previous int(output_max)
    # crashed whenever `output` held more than one row.
    pred_idx = int(output_max[0])
    scores = output[0].detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(12, 8))
    bars = ax.bar(x=x, height=scores)
    ax.set_title('Item: %s.' % x[pred_idx], fontsize=15)
    # Anchor each label at its own bar's centre. Tick locations only coincide
    # with bar positions for categorical x.
    for bar, score in zip(bars, scores):
        ax.annotate(
            text='%03f' % score,                               # text to add
            xy=(bar.get_x() + bar.get_width() / 2, score),     # where to place it
            fontsize=8,                                        # label size
            color="red",                                       # label colour
            ha="center",                                       # horizontal alignment
            va="baseline"                                      # vertical alignment
        )
    # Save once, after every label is placed. The old code re-saved the file on
    # every loop iteration.
    os.makedirs('../results', exist_ok=True)
    fig.savefig("../results/%s_%s.jpg" % (input_type, x[pred_idx]))
    return output_max

def result(x, output, input_type):
    """Print and return the predicted class index for a single-row score tensor.

    Args:
        x: sequence of class names indexed by score column.
        output: score tensor with classes along dim 1 (expects exactly one row).
        input_type: free-form tag echoed in the printed line.

    Returns:
        The argmax class index as a Python int.
    """
    best = int(output.argmax(dim=1))
    print('Input: %s; Item: %s.' % (input_type, x[best]))
    return best

In [12]:
# Header-less test manifest. Later cells join column 0 onto `file_Path` as the
# image path and compare column 1 against the predicted class name.
test_list = pd.read_csv(filepath_or_buffer=test_name[0], header=None)

In [13]:
def _dataset_tag(csv_path, prefix='NICO_animal_test'):
    """Short name for an evaluation CSV, used to name the output folder and log.

    Drops the directory, the extension, a leading `prefix` and any separating
    underscores. An empty remainder maps to 'origin'. For the current
    `test_name[0]` this yields the same value the old hard-coded `[61:-4]`
    slice produced, without depending on the directory length.
    """
    stem = os.path.splitext(os.path.basename(csv_path))[0]
    if stem.startswith(prefix):
        stem = stem[len(prefix):]
    return stem.lstrip('_') or 'origin'


# For every image in the test manifest: predict, compute a Grad-CAM map for the
# predicted class, save the overlay under <class>/<subclass>/{right,wrong}/,
# log the prediction, and finally report accuracy.
model.eval()
target_layers = [model.layer4[-1]]
# NOTE(review): newer pytorch-grad-cam releases may not accept `use_cuda`; verify
# against the installed version.
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)

# Geometry-only part of `test_transforms` (Resize(224) + CenterCrop(224)); keep
# the two in sync. The CAM is computed on this 224x224 crop, so it must be
# overlaid on the same crop. Stretching it over the full, uncropped image
# misplaces the heat map for non-square inputs.
cam_view = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224)])

dataset_tag = _dataset_tag(test_name[0])
dataset_type_path = '../NICO_OOD/%s/' % dataset_tag
os.makedirs(dataset_type_path, exist_ok=True)
os.makedirs('../logs', exist_ok=True)

total = 0
acc = 0
with open('../logs/log_%s.csv' % dataset_tag, 'w', newline='') as logfile:
    # csv.writer quotes fields containing commas/quotes. The old hand-built
    # lines produced a broken row for such paths.
    log_writer = csv.writer(logfile, lineterminator='\n')
    log_writer.writerow(['File Path', 'Real Label', 'Predict Label', 'Sub Label'])
    rows = test_list[[0, 1]].itertuples(index=False, name=None)
    for item, real_label in tqdm(rows, total=len(test_list)):
        # Fields 1..3 of the '/'- and '.'-split relative path are taken as the
        # class, subclass and picture name (same parsing as before).
        split = re.split('[/.]', item)
        class_name, subclass_name, pic_name = split[1], split[2], split[3]
        out_dir = '%s%s/%s/' % (dataset_type_path, class_name, subclass_name)
        for verdict in ('wrong', 'right'):
            os.makedirs(out_dir + verdict + '/', exist_ok=True)

        if real_label not in label_key:
            raise ValueError('Unknown label %r for %s (expected one of %s)'
                             % (real_label, item, label_key))

        pil_image = Image.open(os.path.join(file_Path, item)).convert('RGB')
        input_tensor = test_transforms(pil_image).unsqueeze(0).to(device)
        # The plain prediction needs no autograd graph. Grad-CAM runs its own
        # forward/backward pass below.
        with torch.no_grad():
            output = model(input_tensor)
        predict_index = int(output.argmax(dim=1))
        predict_label = label_key[predict_index]

        targets = [ClassifierOutputTarget(predict_index)]
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
        rgb_crop = np.array(cam_view(pil_image), dtype=np.float32) / 255
        # Defensive: match the CAM to the crop size (normally already 224x224).
        grayscale_cam = cv2.resize(grayscale_cam, (rgb_crop.shape[1], rgb_crop.shape[0]))
        visualization = show_cam_on_image(rgb_crop, grayscale_cam, use_rgb=True)

        is_right = predict_label == real_label
        Image.fromarray(visualization).save('%s%s/%s_%s.jpg' % (
            out_dir, 'right' if is_right else 'wrong', pic_name, predict_label))
        acc += int(is_right)
        total += 1
        log_writer.writerow([item, real_label, predict_label, subclass_name])

if total:
    print('---------------------- Acc: {:.2f}%. ----------------------'.format(acc / total * 100))

  0%|          | 0/2701 [00:00<?, ?it/s]

---------------------- Acc: 73.97%. ----------------------


In [None]:
# error_file_name = "/home/a611/Projects/Datasets/NICO/Animal/bear/on snow/15.png"
# error_image = Image.open(error_file_name)
# error_image = np.array(error_image)
# print(error_image.shape)