In [1]:
from PIL import Image
import numpy as np
import pandas as pd
import cv2
import os
import random
import matplotlib.pyplot as plt
# import tensorwatch as tw

import torch
import copy
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from utils.dataset import *
import torch.utils.data as data

from utils.train import *
from utils.test  import *
from utils.model_select import model_select

import warnings
warnings.filterwarnings("ignore")

In [2]:
# Training settings
####### test作训练集 #######
# epochs = 80
# lr = 3e-6
# gamma = 0.7
# step_size = 5
###### train作训练集 #######
epochs = 100
lr = 0.008
gamma = 0.1
step_size = 30

seed = 42
device = 'cuda:1'

file_Path = '/home/a611/Projects/Datasets/NICO/'
train_name = ['/home/a611/Projects/gyc/Local_Features/csvs/NICO_animal_train_without_cage_without_dog_eating_IID_train.csv']
test_name = ['/home/a611/Projects/gyc/Local_Features/csvs/NICO_animal_train_without_cage_without_dog_eating_IID_test.csv']
num_classes = 10
num_input = 3
batch_size = 32
num_workers = 8
########################
os.chdir('examples')

In [3]:
# # from tensorboardX import SummaryWriter
# # writer = SummaryWriter('log') #建立一个保存数据用的东西
# model = torch.hub.load('facebookresearch/deit:main', 
#     'deit_tiny_patch16_224', pretrained=False)
# model.head = nn.Linear(in_features = model.head.in_features, out_features = num_classes, bias = True)
# model.to(device);

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

## Load Data

In [5]:
train_transforms = transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)


test_transforms = transforms.Compose([transforms.Resize(224),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

## Load Datasets

In [6]:
label_map = get_map(test_name)
label_key = list(label_map.keys())
train_set = MyDataset(file_Path, train_name, label_map,
                            train_transforms)
train_loader = data.DataLoader(
    dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

valid_set = MyDataset(file_Path, test_name, label_map,
                            val_transforms)
valid_loader = data.DataLoader(
    dataset=valid_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [7]:
print(len(train_loader))
print(len(valid_loader))

291
32


## Effecient Attention

### Training

In [8]:
model = model_select('ResNet18', num_input, num_classes).to('cpu')
# print(model)
model = torch.nn.DataParallel(model)
model_state_disk = torch.load('/home/a611/Projects/gyc/Local_Features/models/test_acc_0.743311_epoch_33', map_location = 'cpu')
model.load_state_dict(model_state_disk)
model = model.module
model.linear = torch.nn.Linear(512, num_classes)

model = model.to(device)
# print(model)

# loss function
criterion = nn.CrossEntropyLoss()
# # optimizer
# optimizer = optim.Adam(model.parameters(), lr=lr)
# # scheduler
# scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9,weight_decay=5e-4)  # 优化器
decay_epoch = [30, 60, 80, 100] 
#decay_epoch = [10, 20, 30, 40]
scheduler = lr_scheduler.MultiStepLR(optimizer,
                                     milestones=decay_epoch, gamma=0.1)

In [9]:
model_ = None
highest_test_acc = 0
for i in range(epochs):
    # print('EPOCH:', i + 1)
    train_iter = iter(train_loader)
    test_iter = iter(valid_loader)
    ########################################
    train_loss, train_acc = train(model, device, train_iter, optimizer, train_set, batch_size)
    test_loss, test_acc = test(model, device, test_iter, valid_set, batch_size)
    scheduler.step()
    print( 'EPOCH: %03d, train_loss: %3f, train_acc: %3f, test_loss: %3f, test_acc: %3f'
          % (i + 1, train_loss, train_acc, test_loss, test_acc))
    if test_acc > highest_test_acc:
        highest_test_acc = test_acc
        model_ = copy.deepcopy(model)
        print('Highest test accuracy: %3f' % highest_test_acc)
        torch.save(model_, '../models/NICO_train_without_cage_without_dog_eating.model')
#     print( 'EPOCH: %03d, train_loss: %3f, train_acc: %3f' % (i + 1, train_loss, train_acc))




  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 001, train_loss: 0.799953, train_acc: 0.788061, test_loss: 1.942819, test_acc: 0.337695
Highest test accuracy: 0.337695


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 002, train_loss: 0.540650, train_acc: 0.832735, test_loss: 1.584399, test_acc: 0.451953
Highest test accuracy: 0.451953


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 003, train_loss: 0.534370, train_acc: 0.831078, test_loss: 1.873239, test_acc: 0.341211


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 004, train_loss: 0.530047, train_acc: 0.833594, test_loss: 1.780200, test_acc: 0.417383


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 005, train_loss: 0.548264, train_acc: 0.823883, test_loss: 1.407111, test_acc: 0.520117
Highest test accuracy: 0.520117


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 006, train_loss: 0.520687, train_acc: 0.829191, test_loss: 2.303482, test_acc: 0.277344


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 007, train_loss: 0.537876, train_acc: 0.825908, test_loss: 1.650673, test_acc: 0.441797


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 008, train_loss: 0.545971, train_acc: 0.823822, test_loss: 1.820421, test_acc: 0.365430


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 009, train_loss: 0.524919, train_acc: 0.829989, test_loss: 1.978316, test_acc: 0.391602


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 010, train_loss: 0.494124, train_acc: 0.837077, test_loss: 1.365616, test_acc: 0.565625
Highest test accuracy: 0.565625


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 011, train_loss: 0.507614, train_acc: 0.835788, test_loss: 1.678012, test_acc: 0.414844


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 012, train_loss: 0.493181, train_acc: 0.837307, test_loss: 1.699867, test_acc: 0.435742


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 013, train_loss: 0.492353, train_acc: 0.839347, test_loss: 1.900395, test_acc: 0.394922


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 014, train_loss: 0.467378, train_acc: 0.848245, test_loss: 1.494960, test_acc: 0.499219


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 015, train_loss: 0.493407, train_acc: 0.840360, test_loss: 1.652470, test_acc: 0.428516


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 016, train_loss: 0.475858, train_acc: 0.845990, test_loss: 1.778831, test_acc: 0.415820


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 017, train_loss: 0.474750, train_acc: 0.845284, test_loss: 1.622630, test_acc: 0.441797


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 018, train_loss: 0.464262, train_acc: 0.849764, test_loss: 1.778514, test_acc: 0.437891


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 019, train_loss: 0.457328, train_acc: 0.852372, test_loss: 1.634171, test_acc: 0.460742


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 020, train_loss: 0.461731, train_acc: 0.850025, test_loss: 1.869103, test_acc: 0.374414


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 021, train_loss: 0.464918, train_acc: 0.847877, test_loss: 1.782679, test_acc: 0.405469


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 022, train_loss: 0.438486, train_acc: 0.857204, test_loss: 1.923220, test_acc: 0.373242


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 023, train_loss: 0.448246, train_acc: 0.852448, test_loss: 1.606714, test_acc: 0.437891


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 024, train_loss: 0.427327, train_acc: 0.863233, test_loss: 1.486638, test_acc: 0.501563


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 025, train_loss: 0.423513, train_acc: 0.864200, test_loss: 2.118377, test_acc: 0.323047


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 026, train_loss: 0.420247, train_acc: 0.858662, test_loss: 1.945549, test_acc: 0.387500


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 027, train_loss: 0.422384, train_acc: 0.860318, test_loss: 2.154645, test_acc: 0.354297


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 028, train_loss: 0.434296, train_acc: 0.859306, test_loss: 1.761313, test_acc: 0.396289


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 029, train_loss: 0.401710, train_acc: 0.873220, test_loss: 1.825788, test_acc: 0.462109


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 030, train_loss: 0.424385, train_acc: 0.863019, test_loss: 1.753773, test_acc: 0.434375


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 031, train_loss: 0.330404, train_acc: 0.898779, test_loss: 1.494504, test_acc: 0.494922


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 032, train_loss: 0.265736, train_acc: 0.916989, test_loss: 1.573546, test_acc: 0.484961


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 033, train_loss: 0.260315, train_acc: 0.921975, test_loss: 1.685773, test_acc: 0.464648


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 034, train_loss: 0.257873, train_acc: 0.920686, test_loss: 1.584178, test_acc: 0.487500


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 035, train_loss: 0.246867, train_acc: 0.923217, test_loss: 1.569760, test_acc: 0.490430


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 036, train_loss: 0.249008, train_acc: 0.923432, test_loss: 1.556685, test_acc: 0.485156


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 037, train_loss: 0.232003, train_acc: 0.926761, test_loss: 1.538324, test_acc: 0.513867


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 038, train_loss: 0.224127, train_acc: 0.930888, test_loss: 1.421499, test_acc: 0.531445


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 039, train_loss: 0.229993, train_acc: 0.933036, test_loss: 1.470702, test_acc: 0.525977


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 040, train_loss: 0.232380, train_acc: 0.927774, test_loss: 1.493377, test_acc: 0.508984


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 041, train_loss: 0.227333, train_acc: 0.931747, test_loss: 1.372178, test_acc: 0.543750


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 042, train_loss: 0.225935, train_acc: 0.932177, test_loss: 1.413282, test_acc: 0.534766


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 043, train_loss: 0.211338, train_acc: 0.935460, test_loss: 1.411176, test_acc: 0.538281


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 044, train_loss: 0.205310, train_acc: 0.936257, test_loss: 1.593140, test_acc: 0.500195


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 045, train_loss: 0.206800, train_acc: 0.935997, test_loss: 1.451998, test_acc: 0.536328


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 046, train_loss: 0.211066, train_acc: 0.934585, test_loss: 1.653322, test_acc: 0.464648


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 047, train_loss: 0.199195, train_acc: 0.938574, test_loss: 1.618964, test_acc: 0.483594


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 048, train_loss: 0.201856, train_acc: 0.938252, test_loss: 1.514698, test_acc: 0.509961


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 049, train_loss: 0.202682, train_acc: 0.938344, test_loss: 1.532901, test_acc: 0.496484


  0%|          | 0/291 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

EPOCH: 050, train_loss: 0.208081, train_acc: 0.935613, test_loss: 1.545610, test_acc: 0.494531


  0%|          | 0/291 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# model = torch.load('../models/CSE_train.model')
# print(model.layer4)

In [None]:
# def draw(x, output, input_type):
#     output_max = output.argmax(dim=1)
#     output_numpy = output.cpu().detach().numpy()
#     fig, ax = plt.subplots(figsize=(12, 8))
#     y = output_numpy[0]
#     ax.bar(x = x, height = y)
#     ax.set_title('Item: %s.' % x[int(output_max)], fontsize=15);
#     xticks = ax.get_xticks()
#     for i in range(len(y)):
#         xy = (xticks[i], y[i])
#         s = '%03f' % y[i]
#         ax.annotate(
#             text=s,  # 要添加的文本
#             xy=xy,  # 将文本添加到哪个位置
#             fontsize=8,  # 标签大小
#             color="red",  # 标签颜色
#             ha="center",  # 水平对齐
#             va="baseline"  # 垂直对齐
#         )
#         plt.savefig("../results/%s_%s.jpg" % (input_type, x[int(output_max)]))
#     return output_max

# def result(x, output, input_type):
#     output_max = output.argmax(dim=1)
#     print('Input: %s; Item: %s.' % (input_type, x[int(output_max)]))
#     return output_max

In [None]:
# model.eval()

# #########################################
# # # for input_file_name in os.listdir('inputs'):
# # input_file_name = 'doll_86.jpg'
# # image_name = os.path.join('inputs',input_file_name)
# # image = Image.open(image_name)
# # img = image.resize((224, 224))
# # input = test_transforms(image).unsqueeze(0)
# # input = input.to(device)
# # output = model(input)
# # # output_max = draw(label_key, output)
# # output_max = result(label_key, output, 'origin')
# # output_max = draw(label_key, output, 'origin')

# root_dir = '../sample_pairs'

# catagories = os.listdir(root_dir)
# catagories.sort()
# for catagory in catagories:
#     acc = 0
#     total = 0
#     print('###################### %s ######################' % catagory)
#     items = os.listdir(os.path.join(root_dir, catagory))
#     items.sort()
#     for item in items:
#         image_name = os.path.join(root_dir, catagory, item)
#         image = Image.open(image_name)
#         img = image.resize((224, 224))
#         input = test_transforms(image).unsqueeze(0)
#         input = input.to(device)
#         output, embedding = model(input)
#         # output_max = draw(label_key, output)
#         output_max = result(label_key, output, '%s, %s.' %(catagory, item))
# #         output_max = draw(label_key, output, '%s, %s.' %(catagory, item))
#         if item.split('_')[0] == label_key[int(output_max)]:
#             acc += 1
#         total += 1
#     if total:
#         print('---------------------- Acc: {}%. ----------------------'.format(format(acc/total * 100, '.2f')))