In [1]:
from PIL import Image
import numpy as np
import pandas as pd
import cv2
import os
import random
import matplotlib.pyplot as plt
# import tensorwatch as tw

import torch
import copy
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR

from utils.dataset import *
import torch.utils.data as data

from utils.train import *
from utils.test  import *
from vit_rollout import VITAttentionRollout
import time

In [2]:
# Training settings
####### Settings used when the test split serves as the training set #######
# epochs = 80
# lr = 3e-6
# gamma = 0.7
# step_size = 5
####### Settings used when the train split serves as the training set #######
epochs = 300       # number of iterations of the training loop
lr = 1.25e-4       # learning rate handed to the Adam optimizer
gamma = 0.5        # decay factor handed to the StepLR scheduler
step_size = 5      # step interval handed to the StepLR scheduler

seed = 42          # value passed to seed_everything
device = 'cuda:1'  # device the model is moved to and that train/test receive

# Image root folder and the CSV lists consumed by get_map / MyDataset.
file_Path = '/home/a611/Projects/Datasets/mini-imagenet/images/'
train_name = ['/home/a611/Projects/Datasets/mini-imagenet/pre_train.csv']
test_name = ['/home/a611/Projects/Datasets/mini-imagenet/pre_test.csv']
num_classes = 90   # output width of the replacement classification head
num_input = 3
batch_size = 128
num_workers = 8
########################
# Later cells use paths relative to the 'examples' folder. Only change into it
# when we are not already there, so that re-running this cell is harmless
# instead of failing (or descending into a nested 'examples' directory).
if os.path.basename(os.getcwd()) != 'examples':
    os.chdir('examples')

In [3]:
# from tensorboardX import SummaryWriter
# writer = SummaryWriter('log')  # create a writer used to store logged data

# Fetch the 'deit_tiny_patch16_224' entry point, with pretrained weights,
# from the facebookresearch/deit repository through torch hub.
model = torch.hub.load(
    'facebookresearch/deit:main',
    'deit_tiny_patch16_224',
    pretrained=True,
)

# Swap the classifier for a new linear layer that keeps the original input
# width but emits `num_classes` outputs.
head_in_features = model.head.in_features
model.head = nn.Linear(head_in_features, num_classes, bias=True)

# Move the model to the configured device; the trailing semicolon keeps the
# notebook from echoing the returned module.
model.to(device);

Using cache found in /home/a611/.cache/torch/hub/facebookresearch_deit_main


In [4]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Hash randomisation of the running interpreter is fixed at start-up, so
    # this only takes effect in processes spawned after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmarking enabled cuDNN may pick different convolution
    # algorithms from run to run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False

# Seed all generators before the data pipeline and training cells run.
# NOTE(review): the replacement classification head (nn.Linear) is created in
# an earlier cell, i.e. before this call, so its random initialisation is not
# governed by `seed`.
seed_everything(seed)

## Load Data

In [5]:
def _resize_then_tensor():
    """Return a fresh [Resize((224, 224)), ToTensor()] list of transform steps."""
    return [transforms.Resize((224, 224)), transforms.ToTensor()]


# Training pipeline: random resized crop to 224 and random horizontal flip,
# followed by the common resize + tensor conversion steps.
train_transforms = transforms.Compose(
    [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]
    + _resize_then_tensor()
)

# Validation and test pipelines are identical and contain no random steps.
val_transforms = transforms.Compose(_resize_then_tensor())
test_transforms = transforms.Compose(_resize_then_tensor())

## Load Datasets

In [6]:
# The label mapping is built from the test CSV list and shared by both datasets.
label_map = get_map(test_name)
label_key = [*label_map.keys()]

# Both loaders are configured identically (shuffling is enabled for each).
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=num_workers)

# Training data: train CSV list + training transforms.
train_set = MyDataset(file_Path, train_name, label_map, train_transforms)
train_loader = data.DataLoader(dataset=train_set, **loader_options)

# Validation data: test CSV list + validation transforms.
valid_set = MyDataset(file_Path, test_name, label_map, val_transforms)
valid_loader = data.DataLoader(dataset=valid_set, **loader_options)

In [7]:
# Show the length of the training loader, then of the validation loader,
# one value per line.
print(len(train_loader), len(valid_loader), sep='\n')

352
71


## Efficient Attention

### Training

In [8]:
# Cross-entropy loss object.
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters, starting from the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=lr)
# StepLR schedule driven by the configured `step_size` and `gamma`.
scheduler = StepLR(optimizer=optimizer, step_size=step_size, gamma=gamma)

In [9]:
# Train for `epochs` epochs, evaluating after each one and keeping a copy of
# the model with the highest evaluation accuracy seen so far.
model_ = None
highest_test_acc = 0

# torch.save does not create missing parent folders. Create the checkpoint
# directory up front so the first save cannot abort the run after a full
# epoch of training has already been spent.
checkpoint_path = '../models/ImageNet_Pretrained.model'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for i in range(epochs):
    # print('EPOCH:', i + 1)
    # Fresh iterators over both loaders for this epoch.
    train_iter = iter(train_loader)
    test_iter = iter(valid_loader)
    ########################################
    # train/test each return a (loss, accuracy) pair; they presumably come
    # from the star-imports of utils.train / utils.test.
    train_loss, train_acc = train(model, device, train_iter, optimizer, train_set, batch_size)
    test_loss, test_acc = test(model, device, test_iter, valid_set, batch_size)
    scheduler.step()
    print( 'EPOCH: %03d, train_loss: %3f, train_acc: %3f, test_loss: %3f, test_acc: %3f'
          % (i + 1, train_loss, train_acc, test_loss, test_acc))
    if test_acc > highest_test_acc:
        # New best: snapshot the model and persist it immediately.
        highest_test_acc = test_acc
        model_ = copy.deepcopy(model)
        print('Highest test accuracy: %3f' % highest_test_acc)
        # NOTE(review): this stores the whole module object rather than a
        # state_dict, so loading it back relies on unpickling.
        torch.save(model_, checkpoint_path)
#     print( 'EPOCH: %03d, train_loss: %3f, train_acc: %3f' % (i + 1, train_loss, train_acc))

# test_iter = iter(valid_loader)
# Replace the in-memory model with the checkpoint stored on disk.
# NOTE(review): the checkpoint is written through a path relative to the
# working directory but read back through this absolute path - confirm that
# both refer to the same file.
model = torch.load("/home/a611/Projects/gyc/Local_Features/models/ImageNet_Pretrained.model", map_location=device)

# test_loss, test_acc = test(model, device, test_iter, valid_set, batch_size)
# print( 'test_loss: %3f, test_acc: %3f' % (test_loss, test_acc))

  0%|          | 0/352 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

EPOCH: 001, train_loss: 1.310960, train_acc: 0.686790, test_loss: 0.596098, test_acc: 0.842033
Highest test accuracy: 0.842033


  0%|          | 0/352 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

EPOCH: 002, train_loss: 0.717768, train_acc: 0.808451, test_loss: 0.528536, test_acc: 0.854820
Highest test accuracy: 0.854820


  0%|          | 0/352 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

EPOCH: 003, train_loss: 0.634529, train_acc: 0.829607, test_loss: 0.515335, test_acc: 0.856932
Highest test accuracy: 0.856932


  0%|          | 0/352 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

KeyboardInterrupt: 