In [1]:
gpu_id = '1'
resnet_type = '34'

t = 0.5
region_side_length = 4


import warnings
warnings.filterwarnings("ignore")
import os
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id

import torch
import time
from tqdm import tqdm

from torch.utils.data import DataLoader
from data_utils.Dataset import Dataset,TestDataset,ValDataset,Test_FPS_Dataset,TestDataset_real

from tra_val.evaluation import *
from tra_val.loss import *
from tra_val.tra_val import *
from model.FARN import FARN
data_dir= 'HRSC2016/'

In [6]:
train_batch_size = 8 # default:2
test_batch_size = 1 # fixed:1
val_batch_size = 1 # fixed:1
fps_bath_size = 8 # default:1

train_dataset = Dataset(region_side_length,data_dir) #train dataset with bbox
tra_Dataloader = DataLoader(train_dataset, train_batch_size, shuffle=True,
                                               num_workers=8, pin_memory=True)

test_dataset = TestDataset(region_side_length,data_dir) #test dataset with bbox
tes_Dataloader = DataLoader(test_dataset, test_batch_size, shuffle=False,
                                               num_workers=8, pin_memory=True)

val_dataset = ValDataset(region_side_length,data_dir) #val dataset with bbox
val_Dataloader = DataLoader(val_dataset, val_batch_size, shuffle=True,
                                               num_workers=8, pin_memory=True)

test_fps_dataset = Test_FPS_Dataset(region_side_length,data_dir) #FPS test only with image
test_fps_Dataloader = DataLoader(test_fps_dataset, fps_bath_size, shuffle=True,
                                               num_workers=8, pin_memory=True)

testDataset_real = TestDataset_real(region_side_length,data_dir) #test dateset without resize_bbox
tes_realDataloader = DataLoader(testDataset_real, test_batch_size, shuffle=False,
                                               num_workers=8, pin_memory=True)

#pixels coordinates
pixel_coordinates = torch.zeros(128,128,2).cuda()
print('pixel_coordinates constructing')
for i in tqdm(range(128)):
    for ii in range(128):
        pixel_coordinates[i,ii,:] = torch.tensor([ii,i])*4

#model construction
model = FARN(resnet_type='resnet'+resnet_type, boxes_dx_dy = 4 * region_side_length).cuda()
# model.load_state_dict(torch.load('./checkpoints_34/g1_FARN_region44_-set_t0.5_34_07map_0.8943_date_20201016_17_51_43'))

optimizer =torch.optim.Adam(model.parameters(),lr=0.0001)

#training logging
# Log_path = 'output/g'+gpu_id+'_FARN_'+resnet_type+str(time.strftime("_date_%Y%m%d"))+str(time.strftime("_%H_%M_%S.txt"))
Log_path = 'Experiments_log_'+resnet_type+'/g'+gpu_id+'_FARN_region'+str(region_side_length)+'_'+str(t)+'_'+resnet_type+str(time.strftime("_date_%Y%m%d"))+str(time.strftime("_%H_%M_%S.txt"))

 17%|█▋        | 22/128 [00:00<00:00, 214.31it/s]

pixel_coordinates constructing


100%|██████████| 128/128 [00:00<00:00, 226.33it/s]


In [None]:
boxes_dx_dy = 4 * region_side_length
bestmap_07 = 0
bestmap_12 = 0
losses = []

if 1-os.path.exists('Experiments_log_'+resnet_type):
     os.mkdir('Experiments_log_'+resnet_type)
for epoch in range(5000):
    if epoch < 10:
        validate_interval = 10
    else:
        validate_interval = 1
    model.train()
    loss = Train_FARN(model,
               epoch,
               tra_Dataloader,
               optimizer,
               train_batch_size,
               pixel_coordinates,
               bodies_theshold=t,
               coeff_dxdy=boxes_dx_dy,
               visloss_per_iter=tra_Dataloader.__len__()/1,
               if_iou = False,
               S_set=False,
                )
    losses.append(loss)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set(title='PR',
           ylabel='loss', xlabel='epoch')
    plot1 = plt.plot(np.arange(epoch+1), losses, 'r',label='original values')
    plt.show()
    
    if (epoch + 1) % validate_interval == 0:
        model.eval()
        MAP_07, MAP_12 = Validate_FARN(model,
                                       epoch,
                                       tes_realDataloader,
                                       pixel_coordinates,
                                       Log_path,
                                       bestmap_07,
                                       bestmap_12,
                                       fliter_theshold=t,
                                       scores_theshold=0,
                                       nms_theshold=0.5,
                                       nms_saved_images_num=20,
                                       plot_PR=False,
                                      )
        if MAP_07 >= bestmap_07:
            bestmap_07 = MAP_07
            if MAP_07 > 0.80:
                torch.save(model.state_dict(), 'checkpoints_'+resnet_type+'/g'+gpu_id+'_FARN_region'+str(region_side_length)+'_'+str(t)+'_'+resnet_type+'_07map_'+str(round(bestmap_07, 4))+str(time.strftime("_date_%Y%m%d"))+str(time.strftime("_%H_%M_%S")))
        if MAP_12 >= bestmap_12:
            bestmap_12 = MAP_12
#             if MAP_12 > 0.80:
#                 torch.save(model.state_dict(), 'checkpoints/FARN_'+'12map_'+str(round(bestmap_12, 4))+str(time.strftime("_date_%Y%m%d"))+str(time.strftime("_%H_%M_%S")))