In [1]:
from datasets.dataset_factory import get_dataset
%matplotlib inline
import matplotlib.pyplot as plt
import torch
import torchvision
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import random
import time
from tqdm import tqdm
# Resolve the dataset class for PASCAL / ctdet via the factory.
# NOTE(review): this rebinds `get_dataset` from the factory function to its return value
# (presumably a dataset class, since a later cell calls get_dataset(opt, 'train')).
# Re-running this cell without re-running the import above fails because the factory name is gone.
get_dataset=get_dataset('pascal','ctdet')

In [2]:
class Opts:
    """Fixed option bag handed to the dataset class returned by ``get_dataset``.

    Values are hard-coded here rather than parsed. NOTE(review): presumably mirrors the
    subset of CenterNet's command-line options the dataset reads -- confirm against it.
    """

    # Data location.
    # NOTE(review): absolute, machine-specific path; adjust for your checkout.
    data_dir = '/home/lishiqi/obj/CenterNet/src/lib/../../data'

    # Input / output geometry.
    keep_res = False
    input_h = 512
    input_w = 512
    down_ratio = 4

    # Target construction switches.
    mse_loss = False
    dense_wh = False
    cat_spec_wh = False
    reg_offset = True
    draw_ma_gaussian = True

    # Augmentation-related switches (names suggest all augmentation is turned off).
    not_rand_crop = True
    scale = 0
    shift = 0
    flip = 0
    no_color_aug = True

    # Miscellaneous.
    debug = 0

In [3]:
# Instantiate the PASCAL training split with the fixed options above.
opt=Opts()
dataset_pascal=get_dataset(opt,'train')
# Pre-computed object reference lists, keyed 'objs_<obj_res>' -> 'pics_<pic_res>' (as read by PascalDataset).
# NOTE(review): torch.load unpickles the file and can execute code -- only load trusted files.
objs=torch.load('objs_2')

==> initializing pascal trainval0712 data.
loading annotations into memory...
Done (t=0.97s)
creating index...
index created!
Loaded train 16551 samples


In [4]:
class PascalDataset(Dataset):
    """Random fixed-size crops around single objects of a detection dataset.

    Item ``index`` refers to ``objs['objs_<obj_res>']['pics_<pic_res>'][index]``, a dict
    holding a picture index (``'pic_index'``) into ``dataset`` and an object index
    (``'obj_index'``) within that picture. A random ``pic_res`` x ``pic_res`` window that
    contains the object is cut from the picture's ``'input'`` image and ``'hm'`` heatmap,
    and regression targets are built for every object lying fully inside the window.

    Args:
        dataset: indexable dataset whose items are dicts with at least the keys
            ``'input'``, ``'hm'``, ``'wh'`` and ``'bboxs'``.
        obj_res: object-size bucket selecting the ``'objs_<obj_res>'`` entry of ``objs``.
        pic_res: side length of the square crop, in input-image pixels.
        objs: nested dict of object references, keyed as described above.
    """

    # Scale between input-image pixels and the units of 'wh' / 'bboxs' (which are multiplied
    # by it below) and of the returned heatmap. NOTE(review): assumed to equal the dataset's
    # down_ratio (4 in Opts) -- confirm.
    down_ratio = 4

    def __init__(self,dataset,obj_res,pic_res,objs):
        self.objs = objs['objs_' + str(obj_res)]['pics_' + str(pic_res)]
        self.dataset = dataset
        self.obj_res = obj_res
        self.pic_res = pic_res
        # Downsamples the cropped heatmap by down_ratio. Assumes 'hm' has the same spatial
        # size as 'input', since it is cropped with input-pixel coordinates.
        self.avgpool = torch.nn.AvgPool2d(self.down_ratio, self.down_ratio)
        self.maxpool = torch.nn.MaxPool2d(self.down_ratio, self.down_ratio)  # not used in this class
        # Number of target slots per crop; boxes beyond this are ignored.
        self.max_objs = 50

    def __len__(self):
        return len(self.objs)

    def __getitem__(self, index):
        """Return a dict with the crop ('input'), its pooled heatmap ('hm') and the targets of gen_offset."""
        entry = self.objs[index]
        obj_pic = self.dataset[entry['pic_index']]
        obj_index = entry['obj_index']
        ratio = self.down_ratio

        # Bring the object's size and box into input-pixel units.
        wh = obj_pic['wh'][obj_index] * ratio
        ori_pic = obj_pic['input']
        bbox = (obj_pic['bboxs'][obj_index] * ratio).astype(int)
        hms = obj_pic['hm']

        bbox_crop = self.crop_pic(ori_pic, bbox, wh)
        batch_dict = self.gen_offset(obj_pic['bboxs'], bbox_crop)
        x0, y0, x1, y1 = bbox_crop
        crop_pic = ori_pic[:, y0:y1, x0:x1]
        crop_hm = hms[:, y0:y1, x0:x1]
        batch_dict['hm'] = self.avgpool(torch.Tensor(crop_hm))
        batch_dict['input'] = crop_pic
        return batch_dict

    def gen_offset(self,obj_bboxes,bbox):
        """Build per-object regression targets for the objects lying fully inside a crop.

        Args:
            obj_bboxes: sequence of boxes ``[x0, y0, x1, y1]`` in down-sampled units
                (the crop window is divided by ``down_ratio`` before comparison).
            bbox: crop window ``[x0, y0, x1, y1]`` in input pixels, as returned by ``crop_pic``.

        Returns:
            dict with ``'reg'`` (max_objs, 2) sub-cell center offsets, ``'ind'`` (max_objs,)
            flattened center index into the ``w`` x ``h`` down-sampled crop, ``'reg_mask'``
            (max_objs,) set to 1 for valid slots, and ``'wh'`` (max_objs, 2) box sizes.
            Slot ``k`` corresponds to ``obj_bboxes[k]``; only the first ``max_objs`` boxes
            are considered.
        """
        reg = np.zeros((self.max_objs, 2), dtype=np.float32)
        ind = np.zeros((self.max_objs), dtype=np.int64)
        reg_mask = np.zeros((self.max_objs), dtype=np.uint8)
        wh = np.zeros((self.max_objs, 2), dtype=np.float32)

        bbox_output = bbox / self.down_ratio
        out_x0, out_y0, out_x1, out_y1 = bbox_output[0], bbox_output[1], bbox_output[2], bbox_output[3]
        w = out_x1 - out_x0
        h = out_y1 - out_y0

        for index, obj_bbox in enumerate(obj_bboxes):
            if index >= self.max_objs:
                # No slot left for further boxes (previously an IndexError).
                break
            inside = (obj_bbox[0] >= out_x0 and obj_bbox[1] >= out_y0
                      and obj_bbox[2] <= out_x1 and obj_bbox[3] <= out_y1)
            if not inside:
                continue
            # Box relative to the crop's top-left corner.
            obj_bbox_offset = obj_bbox - [out_x0, out_y0, out_x0, out_y0]
            ct = np.array([(obj_bbox_offset[0] + obj_bbox_offset[2]) / 2,
                           (obj_bbox_offset[1] + obj_bbox_offset[3]) / 2], dtype=np.float32)
            ct_int = ct.astype(np.int32)
            reg[index] = ct - ct_int
            wh[index] = [obj_bbox_offset[2] - obj_bbox_offset[0], obj_bbox_offset[3] - obj_bbox_offset[1]]
            # Only centers that fall on the w x h output grid get a valid flattened index.
            # Boxes touching the right/bottom edge can have ct_int == w or h; those stay masked out.
            if 0 <= ct_int[0] < w and 0 <= ct_int[1] < h:
                reg_mask[index] = 1
                ind[index] = ct_int[1] * w + ct_int[0]
        wh_and_offset = dict(reg=reg, ind=ind, reg_mask=reg_mask, wh=wh)
        return wh_and_offset

    def crop_pic(self,pic,bbox,wh):
        """Choose a random ``pic_res`` x ``pic_res`` window of ``pic`` that contains ``bbox``.

        Args:
            pic: image array indexed as (C, H, W).
            bbox: object box ``[x0, y0, x1, y1]`` in pixels of ``pic``.
            wh: object ``(width, height)`` in the same pixels.

        Returns:
            Integer array ``[x0, y0, x1, y1]`` of the crop window.

        Raises:
            ValueError: from ``np.random.randint`` when no valid offset exists
                (e.g. the object does not fit into the crop).
        """
        bbox_crop = np.zeros(4, dtype=int)
        ori_h, ori_w = pic.shape[1], pic.shape[2]
        cut_w = cut_h = self.pic_res

        # Anchor the window on the side of the object with less room to the image border.
        # max_* is how far the window may extend past the object on that side (bounded by the
        # border and by the room left once the object itself is inside the window); min_* is
        # the least it must extend so the opposite edge stays inside the image.
        if bbox[1] < ori_h - bbox[3]:
            max_h, h_l = bbox[1], True
            min_h = (cut_h + bbox[1]) - ori_h
        else:
            max_h, h_l = (ori_h - bbox[3]), False
            min_h = cut_h - bbox[3]
        if bbox[0] < ori_w - bbox[2]:
            max_w, w_l = bbox[0], True
            min_w = (cut_w + bbox[0]) - ori_w
        else:
            max_w, w_l = (ori_w - bbox[2]), False
            min_w = cut_w - bbox[2]
        max_h = min(max_h, cut_h - wh[1])
        max_w = min(max_w, cut_w - wh[0])
        min_h = max(0, min_h)
        min_w = max(0, min_w)

        rand_h = np.random.randint(min_h, max_h + 1)
        rand_w = np.random.randint(min_w, max_w + 1)

        if h_l:
            bbox_crop[1] = bbox[1] - rand_h
            bbox_crop[3] = bbox_crop[1] + cut_h
        else:
            bbox_crop[3] = bbox[3] + rand_h
            bbox_crop[1] = bbox_crop[3] - cut_h
        if w_l:
            bbox_crop[0] = bbox[0] - rand_w
            bbox_crop[2] = bbox_crop[0] + cut_w
        else:
            bbox_crop[2] = bbox[2] + rand_w
            bbox_crop[0] = bbox_crop[2] - cut_w

        return bbox_crop



In [5]:
class DatasetObj(Dataset):
    """Dataset whose items are whole batches drawn, in random order, from several DataLoaders.

    One shuffled DataLoader is built per crop size in ``pic_res`` over
    ``dataset_obj(dataset, obj_res, res, objs)``, with the matching batch size from
    ``loader_bses``. ``len(self)`` is the total number of batches over all loaders. When the
    indices ``0 .. len(self) - 1`` are requested in order from one process, each loader
    contributes exactly its own number of batches per pass, interleaved at random (via the
    ``random`` module). Exhausted loaders are restarted, so multiple passes work.
    """

    def __init__(self,dataset_obj,dataset,objs,obj_res,pic_res,loader_bses):
        self._loaders = []       # DataLoaders, kept so exhausted iterators can be restarted
        self.dataloaders = []    # live iterators over self._loaders
        for res, loader_bs in zip(pic_res, loader_bses):
            loader = DataLoader(dataset_obj(dataset, obj_res, res, objs), batch_size=loader_bs, shuffle=True)
            self._loaders.append(loader)
            self.dataloaders.append(iter(loader))
        self.dataloader_size = [len(loader) for loader in self._loaders]
        self.sum_len = sum(self.dataloader_size)
        # Batches still to be drawn from each loader in the current pass.
        self.random_pool = self.dataloader_size.copy()

    def _next_batch(self, loader_index):
        """Return the next batch of loader ``loader_index``, restarting its iterator when exhausted."""
        try:
            return next(self.dataloaders[loader_index])
        except StopIteration:
            self.dataloaders[loader_index] = iter(self._loaders[loader_index])
            return next(self.dataloaders[loader_index])

    def _available(self):
        """Indices of loaders that still have batches left in the current pass."""
        return [i for i, left in enumerate(self.random_pool) if left > 0]

    def __getitem__(self,index):
        indexs = self._available()
        if not indexs:
            # Pool used up without seeing the last index (e.g. indices not requested in order):
            # start a new pass instead of failing in random.choice([]).
            self.random_pool = self.dataloader_size.copy()
            indexs = self._available()
        loader_index = random.choice(indexs)
        self.random_pool[loader_index] -= 1
        batch = self._next_batch(loader_index)
        if index == self.sum_len - 1:
            self.random_pool = self.dataloader_size.copy()
        return batch

    def __len__(self):
        return self.sum_len

In [6]:
class DatasetObjMuiltRes(Dataset):
    """Dataset whose items are batches drawn in random order from several object-size buckets.

    For each ``res`` in ``obj_res`` a ``DatasetObj`` over ``PascalDataset`` is built with crop
    sizes ``pic_res['pic_<res>']`` and batch sizes ``loader_bses['pic_<res>']``. Each item is
    one batch dict from one bucket, with ``batch['res']`` set to that bucket's ``res``.
    ``len(self)`` is the total number of batches over all buckets. Exhausted buckets are
    restarted, so multiple passes work.

    Args:
        objs: nested object-reference dict forwarded to PascalDataset.
        dataset: underlying detection dataset forwarded to PascalDataset.
        obj_res: list of object-size buckets (default: none).
        pic_res: dict ``'pic_<res>'`` -> list of crop sizes for that bucket.
        loader_bses: dict ``'pic_<res>'`` -> list of batch sizes, parallel to ``pic_res``.
    """

    def __init__(self,objs,dataset,obj_res=None,pic_res=None,loader_bses=None):
        # None defaults avoid sharing mutable default containers between calls.
        obj_res = [] if obj_res is None else obj_res
        pic_res = {} if pic_res is None else pic_res
        loader_bses = {} if loader_bses is None else loader_bses
        self.obj_res = obj_res

        def unwrap_single(batch):
            # The wrapper DataLoader uses batch_size=1 and each item is already a full batch.
            return batch[0]

        self._loaders = []       # DataLoaders, kept so exhausted iterators can be restarted
        self.dataloaders = []    # live iterators over self._loaders
        for res in obj_res:
            key = 'pic_' + str(res)
            dataset_obj = DatasetObj(PascalDataset, dataset, objs, res, pic_res[key], loader_bses[key])
            loader = DataLoader(dataset_obj, num_workers=0, collate_fn=unwrap_single)
            self._loaders.append(loader)
            self.dataloaders.append(iter(loader))
        self.dataloader_size = [len(loader) for loader in self._loaders]
        self.sum_len = sum(self.dataloader_size)
        # Batches still to be drawn from each bucket in the current pass.
        self.random_pool = self.dataloader_size.copy()

    def _next_batch(self, loader_index):
        """Return the next batch of bucket ``loader_index``, restarting its iterator when exhausted."""
        try:
            return next(self.dataloaders[loader_index])
        except StopIteration:
            self.dataloaders[loader_index] = iter(self._loaders[loader_index])
            return next(self.dataloaders[loader_index])

    def _available(self):
        """Indices of buckets that still have batches left in the current pass."""
        return [i for i, left in enumerate(self.random_pool) if left > 0]

    def __getitem__(self,index):
        indexs = self._available()
        if not indexs:
            # Pool used up without seeing the last index (e.g. each DataLoader worker holds its
            # own copy of this state): start a new pass instead of failing in random.choice([]).
            self.random_pool = self.dataloader_size.copy()
            indexs = self._available()
        loader_index = random.choice(indexs)
        self.random_pool[loader_index] -= 1
        batch = self._next_batch(loader_index)
        if index == self.sum_len - 1:
            self.random_pool = self.dataloader_size.copy()
        batch['res'] = self.obj_res[loader_index]
        return batch

    def __len__(self):
        return self.sum_len

In [14]:
# Multi-resolution training data: for each object-size bucket in obj_res, a list of crop sizes
# (pic_res['pic_<res>']) and a parallel list of batch sizes (loader_bses['pic_<res>']).
# Larger crops are paired with smaller batches.
data=DatasetObjMuiltRes(objs,
                        dataset_pascal,
                   obj_res=[32,64,128,256],
                   pic_res=dict(pic_32=[64,128,192,256],
                                pic_64=[128,192,256,384],
                                pic_128=[256,384,512],
                                pic_256=[384,512]),
                   loader_bses=dict(pic_32=[128,48,16,8],
                                    pic_64=[128,64,32,16],
                                    pic_128=[64,32,8],
                                    pic_256=[32,16]))
# Each item of `data` is already a complete batch; strip the batch_size=1 wrapping added by DataLoader.
# NOTE(review): shadows torch's well-known `default_collate` name; the same helper also exists inside DatasetObjMuiltRes.
def default_collate(batch):
    return batch[0]
# NOTE(review): with num_workers=10 each worker process works on its own copy of `data`
# (random_pool, inner iterators), so the per-pass bookkeeping in __getitem__ is not shared
# across workers -- batches may repeat and the "each batch once per pass" schedule is not
# guaranteed. Confirm whether this is intended.
loader=DataLoader(data,num_workers=10,collate_fn=default_collate)

In [15]:
# Sanity check: iterate `loader` once and record batch size, object resolution and crop size per batch.
count=[]
for index,batch in enumerate(loader):
    # batch['input'] is batched crops indexed (B, C, H, W); shape[2] is the crop height.
    count.append(dict(batch_size=len(batch['input']),res=batch['res'],pic_size=batch['input'].shape[2]))
    # Prints: batch number, batch size (first dim of 'wh'), object resolution, crop height.
    print(index,len(batch['wh']),batch['res'],batch['input'].shape[2])


0 8 32 256
1 8 128 512
2 64 64 192
3 32 64 256
4 128 32 64
5 128 32 64
6 64 64 192
7 32 256 384
8 32 256 384
9 64 128 256
10 64 128 256
11 32 128 384
12 16 64 384
13 16 256 512
14 32 128 384
15 16 256 512
16 32 256 384
17 8 32 256
18 128 64 128
19 64 64 192
20 128 64 128
21 32 64 256
22 16 64 384
23 16 64 384
24 16 256 512
25 48 32 128
26 32 256 384
27 16 256 512
28 32 256 384
29 128 32 64
30 128 64 128
31 8 128 512
32 16 256 512
33 16 256 512
34 16 256 512
35 8 32 256
36 16 256 512
37 32 256 384
38 16 256 512
39 32 256 384
40 16 64 384
41 8 32 256
42 128 32 64
43 128 32 64
44 48 32 128
45 16 64 384
46 16 32 192
47 16 32 192
48 32 64 256
49 16 64 384
50 8 128 512
51 128 32 64
52 16 32 192
53 16 256 512
54 8 128 512
55 16 256 512
56 16 256 512
57 16 64 384
58 128 64 128
59 64 128 256
60 128 64 128
61 16 256 512
62 64 64 192
63 8 128 512
64 8 32 256
65 64 128 256
66 32 128 384
67 128 32 64
68 32 64 256
69 64 128 256
70 64 128 256
71 16 32 192
72 64 128 256
73 32 256 384
74 16 256 512
75 

578 32 128 384
579 16 256 512
580 32 64 256
581 8 32 256
582 32 64 256
583 32 256 384
584 48 32 128
585 32 256 384
586 16 256 512
587 16 256 512
588 16 256 512
589 128 64 128
590 16 32 192
591 16 64 384
592 64 128 256
593 128 64 128
594 8 128 512
595 32 256 384
596 64 128 256
597 8 32 256
598 128 64 128
599 8 128 512
600 32 128 384
601 16 32 192
602 8 128 512
603 16 256 512
604 32 64 256
605 32 256 384
606 8 128 512
607 16 64 384
608 32 256 384
609 16 32 192
610 64 128 256
611 64 64 192
612 32 64 256
613 16 256 512
614 48 32 128
615 32 128 384
616 16 256 512
617 32 256 384
618 64 128 256
619 32 128 384
620 32 64 256
621 32 128 384
622 48 32 128
623 64 128 256
624 128 64 128
625 64 128 256
626 128 64 128
627 32 256 384
628 48 32 128
629 16 256 512
630 16 256 512
631 8 32 256
632 64 128 256
633 64 64 192
634 8 128 512
635 48 32 128
636 48 32 128
637 16 256 512
638 128 32 64
639 32 256 384
640 32 128 384
641 32 256 384
642 64 64 192
643 32 256 384
644 64 64 192
645 32 128 384
646 128 64 1

1143 16 64 384
1144 32 256 384
1145 32 256 384
1146 48 32 128
1147 32 256 384
1148 16 256 512
1149 64 64 192
1150 128 32 64
1151 64 128 256
1152 16 256 512
1153 16 256 512
1154 32 128 384
1155 16 32 192
1156 16 32 192
1157 16 32 192
1158 64 64 192
1159 8 32 256
1160 32 256 384
1161 64 128 256
1162 64 64 192
1163 16 32 192
1164 8 128 512
1165 16 256 512
1166 8 128 512
1167 32 256 384
1168 8 32 256
1169 16 256 512
1170 16 256 512
1171 16 64 384
1172 16 256 512
1173 64 128 256
1174 16 64 384
1175 16 256 512
1176 16 256 512
1177 64 128 256
1178 16 256 512
1179 16 256 512
1180 16 64 384
1181 48 32 128
1182 32 64 256
1183 32 64 256
1184 32 128 384
1185 16 256 512
1186 8 128 512
1187 32 128 384
1188 32 256 384
1189 16 32 192
1190 8 32 256
1191 32 64 256
1192 48 32 128


In [16]:
# Running total of crops per crop size, keyed 'pic_<size>'.
count_pic = {'pic_%d' % size: 0 for size in (64, 128, 192, 256, 384, 512)}


In [17]:
# Accumulate the number of crops per crop size. Sizes not pre-registered in count_pic
# start at 0 instead of raising KeyError.
for cou in count:
    key = 'pic_' + str(cou['pic_size'])
    count_pic[key] = count_pic.get(key, 0) + cou['batch_size']

In [22]:
# Accumulate the number of crops per object-size bucket. Buckets missing from the initial
# dict start at 0 instead of raising KeyError.
count_res=dict(res_64=0,res_128=0,res_32=0,res_256=0)
for cou in count:
    key = 'res_' + str(cou['res'])
    count_res[key] = count_res.get(key, 0) + cou['batch_size']

In [23]:
# Number of crops per object-size bucket over the pass recorded in `count`.
print(count_res)

{'res_64': 17888, 'res_128': 9984, 'res_32': 16489, 'res_256': 7056}


In [18]:
# Number of crops per crop size over the pass recorded in `count`.
print(count_pic)

{'pic_64': 10752, 'pic_128': 13824, 'pic_192': 5904, 'pic_256': 9113, 'pic_384': 8320, 'pic_512': 3504}


In [None]:
class FocalLoss(torch.nn.Module):
    """Penalty-reduced focal loss on heatmaps (CenterNet style).

    ``forward(output, hm)`` applies a clamped sigmoid to the raw ``output`` logits and compares
    them with the target heatmap ``hm``. Locations where ``hm == 1`` are positives. All other
    locations are negatives, down-weighted by ``(1 - hm) ** beta``. The summed loss is
    normalized by the number of positives, or returned un-normalized when there are none.

    Args:
        alpha: exponent of the modulating factor on the prediction (default 2).
        beta: exponent of the ``1 - hm`` weight on negatives (default 4).
    """

    def __init__(self, alpha=2, beta=4):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.neg_loss = self._neg_loss

    def forward(self, output, hm):
        return self.neg_loss(self.sigmoid(output), hm)

    def sigmoid(self, x):
        """Sigmoid clamped to [1e-4, 1 - 1e-4] so the logs below stay finite.

        Out-of-place: the caller's tensor is left unchanged (``x.sigmoid_()`` would overwrite it).
        """
        return torch.clamp(torch.sigmoid(x), min=1e-4, max=1 - 1e-4)

    def _neg_loss(self, pred, gt):
        """Focal loss between probabilities ``pred`` and target heatmap ``gt`` (same shape)."""
        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()
        neg_weights = torch.pow(1 - gt, self.beta)

        pos_loss = (torch.log(pred) * torch.pow(1 - pred, self.alpha) * pos_inds).sum()
        neg_loss = (torch.log(1 - pred) * torch.pow(pred, self.alpha) * neg_weights * neg_inds).sum()
        num_pos = pos_inds.sum()

        if num_pos == 0:
            return -neg_loss
        return -(pos_loss + neg_loss) / num_pos
    
class AverageMeter(object):
    """Track the latest value of a metric and its running weighted mean.

    Attributes:
        val: value passed to the most recent ``update``.
        sum: sum of ``val * n`` over all updates since the last reset.
        count: total weight ``n`` since the last reset.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
# --- Training setup ---
acc_log=AverageMeter()  # NOTE(review): never updated anywhere in this notebook
loss_log=AverageMeter()
import math
from tqdm import tqdm
cuda_device=1  # index of the GPU used for training
# NOTE(review): `resnet_mr` is not defined or imported anywhere in this notebook; it must come
# from a deleted cell or external module. Presumably a multi-resolution ResNet called as
# model(x, branch_index, flag) -- confirm.
model=resnet_mr(num_classes=20)
model.cuda(cuda_device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# NOTE(review): the training loop calls adjust_learning_rate before every optimizer.step(),
# which overwrites the lr this scheduler sets, so the scheduler has no effect on training.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40, 60, 80,100], gamma=0.5)
# params=torch.load('params.pth')
# model.load_state_dict(params)
end=time.time()
loss_func=FocalLoss()

# Per-channel mean/std (the standard ImageNet values), shaped (1,1,3) to broadcast over
# HWC images when un-normalizing them for display.
mean = torch.tensor([0.485, 0.456, 0.406]).view(1,1,3)
std  = torch.tensor([0.229, 0.224, 0.225]).view(1,1,3)
def draw_pic_hm(data,batch_heatmap):
    """For every image in the batch, show the un-normalized image, then a 4x5 grid of its class heatmaps.

    Args:
        data: batch of normalized images indexed as (N, 3, H, W).
        batch_heatmap: per-image heatmaps, one 2-D map per class, titled with the names below.
    """
    class_names = (
        "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow",
        "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa",
        "train", "tvmonitor",
    )
    for img, heatmaps in zip(data, batch_heatmap):
        # CHW -> HWC, then undo the per-channel normalization for display.
        plt.imshow(img.permute(1, 2, 0) * std + mean)
        plt.show()
        plt.figure(figsize=[20, 20])
        for slot, channel_map in enumerate(heatmaps, start=1):
            plt.subplot(4, 5, slot)
            plt.title(class_names[slot - 1])
            plt.imshow(channel_map)
        plt.show()
def adjust_learning_rate(optimizer, batch_size,epoch, base_lr=0.01, base_batch_size=256, gamma=0.5, step_epochs=20):
    """Set the learning rate of every parameter group of ``optimizer``.

    ``lr = base_lr * (batch_size / base_batch_size) * gamma ** (epoch // step_epochs)``

    With the defaults this is 0.01, scaled linearly with the batch size relative to 256,
    and halved every 20 epochs.

    Args:
        optimizer: a ``torch.optim`` optimizer; its ``param_groups`` are updated in place.
        batch_size: size of the current batch.
        epoch: current 0-based epoch.
        base_lr, base_batch_size, gamma, step_epochs: schedule constants; the defaults
            reproduce the original fixed schedule.

    Returns:
        The learning rate that was set.
    """
    lr = base_lr * (batch_size / base_batch_size) * (gamma ** (epoch // step_epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
# Train for `num_epochs` epochs. The learning rate actually used is the one set by
# adjust_learning_rate() before every optimizer step (batch-size scaled, halved every
# 20 epochs); it overrides whatever `scheduler` sets.
num_epochs = 100
for e in range(0, num_epochs):
    loss_log.reset()
    # Fresh progress bar per epoch. Rebinding `loader = tqdm(loader)` would nest one more
    # tqdm wrapper around the DataLoader every epoch.
    pbar = tqdm(loader)
    # NOTE(review): this unpacking expects ((input, heatmap), res) batches, but the
    # DatasetObjMuiltRes loader built earlier yields dicts ('input', 'hm', 'res', ...).
    # Confirm which loader this cell is meant to run against.
    for index, ((data, hm), l) in enumerate(pbar):
        # Object resolution 32/64/128/256 -> model branch index 1/2/3/4 (log2 is exact for powers of two).
        num = int(math.log2(l // 32)) + 1
        batch_size = len(data)
        data = data.cuda(cuda_device, non_blocking=True)
        hm = hm.cuda(cuda_device, non_blocking=True)
        output = model(data, num, True)
        loss = loss_func(output, hm)
        optimizer.zero_grad()
        loss.backward()
        adjust_learning_rate(optimizer, batch_size, e)
        optimizer.step()
        loss_log.update(loss.item())
        pbar.set_description('loss:%.3f(%.3f) time:%.3f epoch%d/%d' % (loss_log.val, loss_log.avg, time.time() - end, e, num_epochs))
        end = time.time()
    # Epoch-level scheduler step after the epoch's optimizer steps (the order PyTorch >= 1.1 expects).
    scheduler.step()
    print('epoch:%d loss:%.3f' % (e, loss_log.avg))

loss:3.645(6.042) time:1.609 epoch0/100: 100%|██████████| 1193/1193 [26:32<00:00,  1.52s/it]     
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:0 loss:6.042


loss:3.072(3.492) time:1.576 epoch1/100: 100%|██████████| 1193/1193 [26:26<00:00,  1.40s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:1 loss:3.492


loss:2.957(3.314) time:1.211 epoch2/100: 100%|██████████| 1193/1193 [26:24<00:00,  1.20s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:2 loss:3.314


loss:3.529(3.140) time:1.339 epoch3/100: 100%|██████████| 1193/1193 [26:17<00:00,  1.43s/it] 
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:3 loss:3.140


loss:3.487(3.026) time:1.518 epoch4/100: 100%|██████████| 1193/1193 [26:26<00:00,  1.56s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:4 loss:3.026


loss:3.332(3.044) time:1.350 epoch5/100: 100%|██████████| 1193/1193 [26:53<00:00,  1.44s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:5 loss:3.044


loss:2.737(2.964) time:1.521 epoch6/100: 100%|██████████| 1193/1193 [27:04<00:00,  1.34s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:6 loss:2.964


loss:3.915(2.937) time:0.969 epoch7/100: 100%|██████████| 1193/1193 [26:59<00:00,  1.56s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:7 loss:2.937


loss:2.609(2.825) time:1.354 epoch8/100: 100%|██████████| 1193/1193 [27:06<00:00,  1.29s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:8 loss:2.825


loss:3.149(2.785) time:1.646 epoch9/100: 100%|██████████| 1193/1193 [27:02<00:00,  1.40s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:9 loss:2.785


loss:0.340(2.711) time:0.980 epoch10/100: 100%|██████████| 1193/1193 [27:28<00:00,  1.15s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:10 loss:2.711


loss:1.753(2.677) time:0.815 epoch11/100: 100%|██████████| 1193/1193 [26:45<00:00,  1.17s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:11 loss:2.677


loss:2.451(2.627) time:1.251 epoch12/100: 100%|██████████| 1193/1193 [27:32<00:00,  1.43s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:12 loss:2.627


loss:2.112(2.578) time:1.736 epoch13/100: 100%|██████████| 1193/1193 [26:42<00:00,  1.47s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:13 loss:2.578


loss:0.461(2.559) time:0.982 epoch14/100: 100%|██████████| 1193/1193 [27:12<00:00,  1.19s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:14 loss:2.559


loss:2.932(2.490) time:0.963 epoch15/100: 100%|██████████| 1193/1193 [27:12<00:00,  1.06s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:15 loss:2.490


loss:1.591(2.434) time:1.994 epoch16/100: 100%|██████████| 1193/1193 [27:20<00:00,  1.39s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:16 loss:2.434


loss:2.500(2.365) time:1.600 epoch17/100: 100%|██████████| 1193/1193 [27:19<00:00,  1.47s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:17 loss:2.365


loss:2.856(2.352) time:0.965 epoch18/100: 100%|██████████| 1193/1193 [27:22<00:00,  1.44s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:18 loss:2.352


loss:2.144(2.318) time:2.233 epoch19/100: 100%|██████████| 1193/1193 [27:04<00:00,  1.58s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:19 loss:2.318


loss:2.070(2.217) time:1.002 epoch20/100: 100%|██████████| 1193/1193 [27:17<00:00,  1.41s/it] 
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:20 loss:2.217


loss:1.433(2.152) time:1.049 epoch21/100: 100%|██████████| 1193/1193 [27:19<00:00,  1.18s/it] 
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:21 loss:2.152


loss:0.397(2.135) time:0.965 epoch22/100: 100%|██████████| 1193/1193 [27:04<00:00,  1.32s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:22 loss:2.135


loss:2.223(2.021) time:1.274 epoch23/100: 100%|██████████| 1193/1193 [27:15<00:00,  1.43s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:23 loss:2.021


loss:2.229(2.054) time:2.143 epoch24/100: 100%|██████████| 1193/1193 [27:29<00:00,  1.62s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:24 loss:2.054


loss:1.499(1.982) time:1.003 epoch25/100: 100%|██████████| 1193/1193 [27:11<00:00,  1.29s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:25 loss:1.982


loss:2.107(1.948) time:1.520 epoch26/100: 100%|██████████| 1193/1193 [27:14<00:00,  1.26s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:26 loss:1.948


loss:0.438(1.899) time:0.955 epoch27/100: 100%|██████████| 1193/1193 [27:04<00:00,  1.34s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:27 loss:1.899


loss:3.658(1.851) time:0.952 epoch28/100: 100%|██████████| 1193/1193 [27:10<00:00,  1.37s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:28 loss:1.851


loss:1.642(1.768) time:0.989 epoch29/100: 100%|██████████| 1193/1193 [26:48<00:00,  1.14s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:29 loss:1.768


loss:2.675(1.881) time:1.469 epoch30/100: 100%|██████████| 1193/1193 [26:28<00:00,  1.45s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:30 loss:1.881


loss:2.292(1.798) time:1.324 epoch31/100: 100%|██████████| 1193/1193 [26:12<00:00,  1.15s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:31 loss:1.798


loss:2.020(1.810) time:1.593 epoch32/100: 100%|██████████| 1193/1193 [26:20<00:00,  1.35s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:32 loss:1.810


loss:0.290(1.681) time:0.945 epoch33/100: 100%|██████████| 1193/1193 [26:35<00:00,  1.16s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:33 loss:1.681


loss:1.938(1.635) time:1.445 epoch34/100: 100%|██████████| 1193/1193 [26:29<00:00,  1.40s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:34 loss:1.635


loss:0.597(1.554) time:0.765 epoch35/100: 100%|██████████| 1193/1193 [26:35<00:00,  1.16s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:35 loss:1.554


loss:1.151(1.614) time:1.331 epoch36/100: 100%|██████████| 1193/1193 [26:20<00:00,  1.31s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:36 loss:1.614


loss:1.672(1.514) time:1.587 epoch37/100: 100%|██████████| 1193/1193 [26:30<00:00,  1.40s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:37 loss:1.514


loss:0.811(1.421) time:0.979 epoch38/100: 100%|██████████| 1193/1193 [26:25<00:00,  1.35s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:38 loss:1.421


loss:0.722(1.419) time:0.947 epoch39/100: 100%|██████████| 1193/1193 [26:29<00:00,  1.08s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:39 loss:1.419


loss:1.562(1.223) time:1.327 epoch40/100: 100%|██████████| 1193/1193 [26:37<00:00,  1.49s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:40 loss:1.223


loss:0.731(1.118) time:1.698 epoch41/100: 100%|██████████| 1193/1193 [26:12<00:00,  1.25s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:41 loss:1.118


loss:0.755(1.109) time:1.747 epoch42/100: 100%|██████████| 1193/1193 [26:39<00:00,  1.53s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:42 loss:1.109


loss:0.920(1.050) time:2.105 epoch43/100: 100%|██████████| 1193/1193 [26:21<00:00,  1.64s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:43 loss:1.050


loss:0.694(1.039) time:1.622 epoch44/100: 100%|██████████| 1193/1193 [26:41<00:00,  1.47s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:44 loss:1.039


loss:1.535(0.994) time:0.967 epoch45/100: 100%|██████████| 1193/1193 [26:30<00:00,  1.29s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:45 loss:0.994


loss:0.852(0.968) time:0.972 epoch46/100: 100%|██████████| 1193/1193 [26:29<00:00,  1.15s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:46 loss:0.968


loss:1.067(1.001) time:1.782 epoch47/100: 100%|██████████| 1193/1193 [26:23<00:00,  1.39s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:47 loss:1.001


loss:1.798(0.934) time:2.277 epoch48/100: 100%|██████████| 1193/1193 [26:33<00:00,  1.45s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:48 loss:0.934


loss:0.663(0.902) time:2.069 epoch49/100: 100%|██████████| 1193/1193 [26:32<00:00,  1.54s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:49 loss:0.902


loss:0.733(0.916) time:1.455 epoch50/100: 100%|██████████| 1193/1193 [26:41<00:00,  1.31s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:50 loss:0.916


loss:0.063(0.851) time:0.762 epoch51/100: 100%|██████████| 1193/1193 [26:28<00:00,  1.19s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:51 loss:0.851


loss:0.481(0.862) time:1.720 epoch52/100: 100%|██████████| 1193/1193 [26:12<00:00,  1.49s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:52 loss:0.862


loss:0.484(0.801) time:1.707 epoch53/100: 100%|██████████| 1193/1193 [26:30<00:00,  1.38s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:53 loss:0.801


loss:0.782(0.803) time:1.357 epoch54/100: 100%|██████████| 1193/1193 [26:31<00:00,  1.48s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:54 loss:0.803


loss:1.033(0.768) time:1.192 epoch55/100: 100%|██████████| 1193/1193 [26:32<00:00,  1.10s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:55 loss:0.768


loss:0.998(0.803) time:1.454 epoch56/100: 100%|██████████| 1193/1193 [26:58<00:00,  1.28s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:56 loss:0.803


loss:0.718(0.766) time:1.730 epoch57/100: 100%|██████████| 1193/1193 [26:37<00:00,  1.24s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:57 loss:0.766


loss:0.936(0.732) time:2.101 epoch58/100: 100%|██████████| 1193/1193 [26:58<00:00,  1.50s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:58 loss:0.732


loss:0.428(0.734) time:1.747 epoch59/100: 100%|██████████| 1193/1193 [26:48<00:00,  1.37s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:59 loss:0.734


loss:0.611(0.648) time:1.777 epoch60/100: 100%|██████████| 1193/1193 [26:42<00:00,  1.37s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:60 loss:0.648


loss:0.545(0.607) time:1.460 epoch61/100: 100%|██████████| 1193/1193 [27:16<00:00,  1.58s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:61 loss:0.607


loss:0.346(0.626) time:1.540 epoch62/100: 100%|██████████| 1193/1193 [26:57<00:00,  1.53s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:62 loss:0.626


loss:0.261(0.560) time:1.798 epoch63/100: 100%|██████████| 1193/1193 [26:57<00:00,  1.57s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:63 loss:0.560


loss:0.333(0.556) time:1.643 epoch64/100: 100%|██████████| 1193/1193 [27:07<00:00,  1.23s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:64 loss:0.556


loss:0.337(0.542) time:1.522 epoch65/100: 100%|██████████| 1193/1193 [27:07<00:00,  1.30s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:65 loss:0.542


loss:0.021(0.529) time:0.777 epoch66/100: 100%|██████████| 1193/1193 [27:01<00:00,  1.21s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:66 loss:0.529


loss:0.092(0.536) time:0.783 epoch67/100: 100%|██████████| 1193/1193 [27:03<00:00,  1.07s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:67 loss:0.536


loss:0.330(0.514) time:1.655 epoch68/100: 100%|██████████| 1193/1193 [27:22<00:00,  1.58s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:68 loss:0.514


loss:0.518(0.534) time:2.176 epoch69/100: 100%|██████████| 1193/1193 [26:51<00:00,  1.53s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:69 loss:0.534


loss:0.164(0.490) time:1.484 epoch70/100: 100%|██████████| 1193/1193 [27:08<00:00,  1.35s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:70 loss:0.490


loss:0.185(0.468) time:0.758 epoch71/100: 100%|██████████| 1193/1193 [26:37<00:00,  1.01it/s]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:71 loss:0.468


loss:0.431(0.461) time:1.741 epoch72/100: 100%|██████████| 1193/1193 [27:05<00:00,  1.34s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:72 loss:0.461


loss:3.350(0.465) time:0.263 epoch73/100: 100%|██████████| 1193/1193 [26:41<00:00,  1.00s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:73 loss:0.465


loss:0.503(0.454) time:0.980 epoch74/100: 100%|██████████| 1193/1193 [26:50<00:00,  1.22s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:74 loss:0.454


loss:0.039(0.464) time:0.982 epoch75/100: 100%|██████████| 1193/1193 [26:44<00:00,  1.20s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:75 loss:0.464


loss:0.615(0.458) time:2.106 epoch76/100: 100%|██████████| 1193/1193 [26:48<00:00,  1.51s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:76 loss:0.458


loss:0.678(0.457) time:0.959 epoch77/100: 100%|██████████| 1193/1193 [26:46<00:00,  1.30s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:77 loss:0.457


loss:0.702(0.437) time:0.957 epoch78/100: 100%|██████████| 1193/1193 [27:05<00:00,  1.13s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:78 loss:0.437


loss:0.345(0.426) time:1.651 epoch79/100: 100%|██████████| 1193/1193 [26:35<00:00,  1.58s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:79 loss:0.426


loss:0.242(0.400) time:1.340 epoch80/100: 100%|██████████| 1193/1193 [26:24<00:00,  1.32s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:80 loss:0.400


loss:0.199(0.409) time:0.966 epoch81/100: 100%|██████████| 1193/1193 [26:26<00:00,  1.06s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:81 loss:0.409


loss:0.179(0.392) time:1.339 epoch82/100: 100%|██████████| 1193/1193 [26:21<00:00,  1.55s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:82 loss:0.392


loss:0.305(0.360) time:1.726 epoch83/100: 100%|██████████| 1193/1193 [26:33<00:00,  1.58s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:83 loss:0.360


loss:0.279(0.355) time:1.630 epoch84/100: 100%|██████████| 1193/1193 [26:23<00:00,  1.61s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:84 loss:0.355


loss:0.048(0.382) time:0.949 epoch85/100: 100%|██████████| 1193/1193 [26:21<00:00,  1.18s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:85 loss:0.382


loss:0.173(0.392) time:1.477 epoch86/100: 100%|██████████| 1193/1193 [27:06<00:00,  1.43s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:86 loss:0.392


loss:0.009(0.374) time:0.756 epoch87/100: 100%|██████████| 1193/1193 [26:08<00:00,  1.12s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:87 loss:0.374


loss:0.003(0.351) time:0.962 epoch88/100: 100%|██████████| 1193/1193 [26:22<00:00,  1.08s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:88 loss:0.351


loss:0.291(0.371) time:0.990 epoch89/100: 100%|██████████| 1193/1193 [26:07<00:00,  1.24s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:89 loss:0.371


loss:0.121(0.371) time:0.982 epoch90/100: 100%|██████████| 1193/1193 [26:21<00:00,  1.12s/it] 
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:90 loss:0.371


loss:0.238(0.351) time:1.493 epoch91/100: 100%|██████████| 1193/1193 [26:43<00:00,  1.41s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:91 loss:0.351


loss:1.165(0.381) time:2.124 epoch92/100: 100%|██████████| 1193/1193 [26:27<00:00,  1.32s/it] 
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:92 loss:0.381


loss:0.128(0.344) time:0.965 epoch93/100: 100%|██████████| 1193/1193 [26:18<00:00,  1.21s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:93 loss:0.344


loss:0.047(0.324) time:0.954 epoch94/100: 100%|██████████| 1193/1193 [26:22<00:00,  1.24s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:94 loss:0.324


loss:0.414(0.336) time:1.460 epoch95/100: 100%|██████████| 1193/1193 [26:14<00:00,  1.40s/it]
  0%|          | 0/1193 [00:00<?, ?it/s]

epoch:95 loss:0.336


loss:0.288(0.338) time:1.468 epoch96/100:  60%|██████    | 717/1193 [15:56<10:44,  1.35s/it]

In [None]:
print('loss%d'%(4.5))  # NOTE(review): %d truncates 4.5, so this prints 'loss4'
print(loss_log.avg)
input('s')  # blocks until Enter is pressed, then saves the weights
torch.save(model.state_dict(),'params_focal_lossss.pth')

In [None]:
# --- Evaluation / visualization setup ---
import math
from tqdm import tqdm
cuda_device=1
end=time.time()
loss_func=FocalLoss()
loader=tqdm(loader)  # NOTE(review): rebinds `loader`; re-running this cell nests another tqdm wrapper
model.eval()
# Class names in heatmap-channel order (duplicated in both draw_pic_hm definitions).
classes=["aeroplane", "bicycle", "bird", "boat","bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", 
"horse", "motorbike", "person", "pottedplant", "sheep", "sofa", 
"train", "tvmonitor"]
# Same (1,1,3) display mean/std as in the training setup cell.
mean = torch.tensor([0.485, 0.456, 0.406]).view(1,1,3)
std  = torch.tensor([0.229, 0.224, 0.225]).view(1,1,3)
def draw_pic_hm(data,batch_heatmap,hms):
    """Plot each sample of a batch: input image, predicted heatmaps, label heatmaps.

    For every sample three figures are shown one after another: the de-normalised
    input image, a grid with one panel per predicted channel, and the same grid for
    the label channels.

    Args:
        data: batch of normalised images; each item is permuted with (1, 2, 0), so
            channel-first (C, H, W) items are assumed.
        batch_heatmap: per-sample predicted heatmaps, iterated channel by channel
            (the evaluation loop in this notebook passes sigmoid outputs).
        hms: per-sample label heatmaps, iterated channel by channel.

    Uses the module-level ``mean``/``std`` tensors and ``plt``.
    """
    classes=["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
             "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
             "pottedplant", "sheep", "sofa", "train", "tvmonitor"]

    def _show_channel_grid(maps, title):
        # One panel per channel, five per row. The grid grows with the channel count
        # so a model with more channels than class names (e.g. 21) does not overflow.
        cols=5
        rows=max(1,(len(maps)+cols-1)//cols)
        plt.figure()
        plt.suptitle(title)
        for index,channel in enumerate(maps):
            plt.subplot(rows,cols,index+1)
            plt.title(classes[index] if index<len(classes) else 'channel %d'%index)
            plt.imshow(channel)
        plt.show()

    for pic,heatmap,labels_hm in zip(data,batch_heatmap,hms):
        # Undo normalisation and clamp to [0, 1] so imshow does not clip (and warn
        # about) out-of-range values, then scale to 0..255 integers.
        image=(pic.permute((1,2,0))*std+mean).clamp(0,1)
        plt.imshow((image*255).int())
        plt.show()
        _show_channel_grid(heatmap,'prediction')
        _show_channel_grid(labels_hm,'label')
# Visual evaluation: for each batch run the model, report the focal loss,
# plot predictions against labels and wait for Enter before the next batch.
with torch.no_grad():
    loss_log.reset()
    for index,((data,hm),l) in enumerate(loader):
        # Map l to ResNetMR's `pattern` (32 -> 1, 64 -> 2, 128 -> 3, 256 -> 4);
        # presumably l is the batch's object resolution. log2 is exact on powers of two.
        num=int(math.log2(l//32))+1
        data=data.cuda(cuda_device)
        hm=hm.cuda(cuda_device)
        output=model(data,num,False)
        loss=loss_func(output,hm)
        loss_log.update(loss.item())
        loader.set_description('loss:%.3f(%.3f) time:%.3f batch%d/%d'%(loss_log.val,loss_log.avg,time.time()-end,index,len(loader)))
        draw_pic_hm(data.cpu(),loss_func.sigmoid(output.cpu()),hm.cpu())
        input('s')
        # Restart the timer after the pause so "time" measures one batch, not user wait time.
        end=time.time()

In [None]:
# Checkpoint the current model weights (state_dict only).
weights = model.state_dict()
torch.save(weights, 'params_9.pth')

In [None]:
# Report the epoch counter `e` together with the running average loss.
epoch_summary = 'epoch:%d loss:%.3f' % (e, loss_log.avg)
print(epoch_summary)

In [None]:
# Scratch value: a 4x5 int64 tensor holding 0..19 in row-major order.
a = torch.arange(20).view(4, 5)

In [None]:
# Element-wise sigmoid of `a` after casting to float (displayed as the cell output).
torch.sigmoid(a.float())

In [None]:
# Free-form note kept as a valid statement (the bare text "begin loss=0.5" is a
# SyntaxError in a code cell). Presumably the loss value seen when training began.
begin_loss = 0.5

In [None]:
# Step through the loader one batch at a time: print the shapes of the first three
# batch elements and the fourth element as-is, then wait for Enter.
# The pause result is not stored, because other cells use `a` as the dataset.
for i,j,k,z in loader:
    print(i.shape,j.shape,k.shape,z)
    input('s')

In [None]:
# Demonstrate that list.copy() gives an independent list: changing the copy leaves
# the original intact. Dedicated names avoid overwriting the dataset other cells keep in `a`.
original_list=[1,2,3,4]
copied_list=original_list.copy()
copied_list[1]=4
print(original_list,copied_list)

In [None]:
# Size histogram of the ground-truth boxes in dataset `a`.
# Each (w, h) is multiplied by 4 (presumably undoing the output stride, down_ratio=4)
# and binned by its shorter side (<=40, <=80, <=160, larger). Within the first three
# bins, boxes are further split by their longer side.
# count_8 .. count_256 are initialised but never incremented here.
count_8 = count_16 = count_32 = count_64 = count_128 = count_256 = 0
count_513 = 0
count_40 = count_80 = count_160 = 0
count_pics = count_objs = 0
count_40_64 = count_40_128 = count_40_160 = count_40_256 = count_40_max = 0
count_80_128 = count_80_160 = count_80_256 = count_80_max = 0
count_160_256 = count_160_384 = count_160_512 = count_160_max = 0
for sample in a:
    count_pics += 1
    # objects are counted as reg_mask entries equal to 1
    count_objs += sum(1 for flag in sample['reg_mask'] if flag == 1)
    for w, h in sample['wh']:
        w, h = w * 4, h * 4
        if w == 0 and h == 0:
            # an all-zero size ends the scan of this sample (presumably padding)
            break
        short_side, long_side = min(w, h), max(w, h)
        if short_side <= 40:
            count_40 += 1
            if long_side <= 64:
                count_40_64 += 1
            elif long_side <= 128:
                count_40_128 += 1
            elif long_side <= 160:
                count_40_160 += 1
            elif long_side <= 256:
                count_40_256 += 1
            else:
                count_40_max += 1
        elif short_side <= 80:
            count_80 += 1
            if long_side <= 128:
                count_80_128 += 1
            elif long_side <= 160:
                count_80_160 += 1
            elif long_side <= 256:
                count_80_256 += 1
            else:
                count_80_max += 1
        elif short_side <= 160:
            count_160 += 1
            if long_side <= 256:
                count_160_256 += 1
            elif long_side <= 384:
                count_160_384 += 1
            elif long_side <= 512:
                count_160_512 += 1
            else:
                count_160_max += 1
        else:
            count_513 += 1

In [None]:
# Print the histogram counts: first the totals per shorter-side bin
# (<=40, <=80, <=160, larger), then the <=40, <=80 and <=160 bins split by longer side.
print(count_40,count_80,count_160,count_513)
print(count_40_64,count_40_128,count_40_160,count_40_256,count_40_max)
print(count_80_128,count_80_160,count_80_256,count_80_max)
print(count_160_256,count_160_384,count_160_512,count_160_max)

In [None]:
# Dataset size (not displayed: only a cell's last expression is shown).
len(a)
# Sample 3's input: CHW -> HWC, undo normalisation, then reverse the channel axis.
restored = a[3]['input'].transpose(1, 2, 0) * a.std + a.mean
plt.imshow(restored[..., ::-1])
plt.show()

In [None]:
def show(i):
    """Plot sample ``i`` of dataset ``a``: its input image, then every non-empty heatmap channel.

    The sample is fetched once, so the image and the heatmaps always come from the same
    draw even if the dataset's ``__getitem__`` is randomized.

    Relies on the module-level ``a`` (indexable, with ``std``/``mean`` attributes) and ``plt``.
    NOTE: a later cell redefines ``show`` with a different implementation; the most
    recently executed definition wins.
    """
    sample=a[i]
    pics=sample['hm']
    input_pic=sample['input']
    # CHW -> HWC, undo normalisation, reverse the channel axis for display
    plt.imshow((input_pic.transpose(1,2,0)*a.std+a.mean)[...,::-1])
    plt.show()
    for pic in pics:
        if pic.sum()>0:
            plt.imshow(pic)
            plt.show()

In [None]:
# Interactive viewer: type a sample index to plot it with show(). Any input that is
# not an integer (e.g. an empty line) ends the loop cleanly instead of raising.
while True:
    b=input('s')
    try:
        idx=int(b)
    except ValueError:
        break
    show(idx)

In [None]:
# Step through samples 1..9 of `a`, showing each de-normalised input and pausing for Enter.
for i in range(1,10):
    image=(a[i]['input'].transpose(1,2,0)*a.std+a.mean)[...,::-1]
    plt.imshow(image)
    plt.show()
    b=input('s')

In [None]:
# Shape of sample 0's input array (shown as the cell output).
a[0]['input'].shape

In [None]:
# For sample 0, walk the heatmaps at every scale: print the scale name for each
# channel and plot the channels that are non-empty. Then show the de-normalised input.
for scale in ('hm_32', 'hm_64', 'hm_128', 'hm_256'):
    for pic in a[0]['hm'][scale]:
        print(scale)
        if pic.sum()>0:
            plt.imshow(pic)
            plt.show()
plt.imshow((a[0]['input'].transpose(1,2,0)*a.std+a.mean)[...,::-1])
plt.show()

In [None]:
# Partition every ground-truth object of dataset `a` by size.
# Sizes are wh*4. The shorter side picks the family (objs_32: <=40, objs_64: <=80,
# objs_128: <=160, objs_256: larger) and the longer side picks the 'pics_<limit>'
# bucket: the first limit it does not exceed. objs_32/objs_64 put anything larger
# into 'pics_512'; objs_128/objs_256 drop objects whose longer side exceeds 512.
# Each bucket entry records the sample index and the object index within it.
objs_32=dict(pics_64=[],pics_128=[],pics_192=[],pics_256=[],pics_384=[],pics_512=[])
objs_64=dict(pics_128=[],pics_192=[],pics_256=[],pics_384=[],pics_512=[])
objs_128=dict(pics_256=[],pics_384=[],pics_512=[])
objs_256=dict(pics_384=[],pics_512=[])
count_40 = count_80 = count_160 = count_256 = 0
for index, j in enumerate(a):
    for obj_index, (w, h) in enumerate(j['wh']):
        w, h = w * 4, h * 4
        if w == 0 and h == 0:
            # an all-zero size ends the scan of this sample (presumably padding)
            break
        short_side, long_side = min(w, h), max(w, h)
        record = dict(pic_index=index, obj_index=obj_index)
        if short_side <= 40:
            count_40 += 1
            target, limits, overflow = objs_32, (64, 128, 192, 256, 384), 'pics_512'
        elif short_side <= 80:
            count_80 += 1
            target, limits, overflow = objs_64, (128, 192, 256, 384), 'pics_512'
        elif short_side <= 160:
            count_160 += 1
            target, limits, overflow = objs_128, (256, 384, 512), None
        else:
            count_256 += 1
            target, limits, overflow = objs_256, (384, 512), None
        key = next(('pics_%d' % lim for lim in limits if long_side <= lim), overflow)
        if key is not None:
            target[key].append(record)

In [None]:
import torch
# Bundle the freshly built partitions into the layout PascalDataset reads:
# objs['objs_<obj_res>']['pics_<pic_res>']. Building the dict here, instead of
# saving whatever `objs` happens to be bound to (e.g. the copy loaded with
# torch.load at the top of the notebook), keeps the saved file in sync with them.
objs=dict(objs_32=objs_32,objs_64=objs_64,objs_128=objs_128,objs_256=objs_256)
torch.save(objs,'./objs_2')

In [None]:
def get(num):
    """Return the (w, h) of entry ``num`` of objs_32['pics_64'], multiplied by 4.

    Looks the object up in dataset ``a`` via the entry's pic_index/obj_index.
    """
    entry = objs_32['pics_64'][num]
    sample = a[entry['pic_index']]
    return sample['wh'][entry['obj_index']] * 4

In [None]:
# Browse samples with show(), starting at index 1. Pause after each one; entering
# 'q' stops. Without the pause, this loop plots without end until it is interrupted.
i=0
while True:
    i+=1
    show(i)
    if input('s')=='q':
        break

In [None]:
def show(num):
    """Print and plot item ``num`` of ``dataset`` as five side-by-side panels.

    The item is unpacked into eight parts, and ``obj_pic.shape``, ``bbox_crop``, ``bbox``
    and ``wh`` are printed. The panels are, in order: obj_pic, pic and crop_pic
    (CHW -> HWC, de-normalised with ``a.std``/``a.mean``, channel axis reversed),
    then ``hm`` and ``bg`` squeezed along axis 0.

    NOTE: this redefines ``show`` from an earlier cell; the most recently executed
    definition wins.
    """
    pic,wh,bbox,obj_pic,crop_pic,bbox_crop,hm,bg=dataset[num]
    print(obj_pic.shape,bbox_crop,bbox,wh)

    def to_display(chw):
        # CHW -> HWC, undo normalisation, reverse the channel axis for imshow
        return (chw.transpose(1,2,0)*a.std+a.mean)[...,::-1]

    panels=((obj_pic,True),(pic,True),(crop_pic,True),(hm,False),(bg,False))
    for slot,(array,is_image) in enumerate(panels,1):
        plt.subplot(1,5,slot)
        plt.imshow(to_display(array) if is_image else array.squeeze(0))
    plt.show()

In [None]:
# Build a non-shuffled test DataLoader over `dataset`; the object is only displayed
# as the cell output, not stored.
# NOTE(review): `args` is not defined in the visible part of this notebook — confirm
# where batch_size_test / workers come from before running this cell.
DataLoader(dataset, batch_size=args.batch_size_test, shuffle=False, num_workers=args.workers)

In [None]:
# Draw sample 0's ground-truth boxes (meta['gt_det']) as filled rectangles of value 1
# on its first heatmap channel. Work on a copy, so the sample's own heatmap array is
# left untouched in case the dataset hands out cached arrays.
# NOTE(review): assumes gt_det rows start with (x1, y1, x2, y2) in heatmap coordinates,
# which is how the slicing below uses columns 0-3.
meta=a[0]['meta']
bboxs=meta['gt_det']
pics1=np.array(a[0]['hm'][0],copy=True)
for bbox in bboxs:
    x1,y1,x2,y2=(int(v) for v in bbox[:4])
    pics1[y1:y2,x1:x2]=1
plt.imshow(pics1)
plt.show()

In [2]:
import os
import math
import torch
import torch.nn as nn
import torchvision.models

def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    At stride 1 the spatial size is preserved.
    """
    options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)


class BasicBlock(nn.Module):
    """Residual block: conv-BN-ReLU, conv-BN, add the skip branch, then ReLU.

    ``expansion`` is 1, so the block outputs ``planes`` channels. When ``stride`` != 1
    the first convolution is 4x4 with padding 1 instead of 3x3; for even input sizes
    this halves the spatial size exactly.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to build the skip branch;
            required whenever the main branch changes the shape of the input.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        if stride == 1:
            self.conv1 = conv3x3(inplanes, planes, stride)
        else:
            self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=4, stride=stride,
                                   padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1 reduce, 3x3 (possibly strided), 1x1 expand by ``expansion``.

    Output has ``planes * 4`` channels. The skip branch is the input itself, or
    ``downsample(input)`` when a downsample module is given.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        widened = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out))
        out += x if self.downsample is None else self.downsample(x)
        return self.relu(out)


class ResNetMR(nn.Module):
    """Multi-entry ResNet-style network producing per-class score maps.

    A stride-1 stem and layer1 are shared. ``forward``'s ``pattern`` selects how many
    stride-2 stages run (0 to 3) before a common :meth:`fusion` head, whose input is
    always the output of a 128-planes stage.

    Args:
        block: residual block class with an ``expansion`` attribute and an
            ``(inplanes, planes, stride, downsample)`` constructor (e.g. BasicBlock).
        layers: number of blocks per stage; indices 0-5 are read.
        num_classes: output channels of the final 1x1 classifier.

    NOTE: the hard-coded channel counts (the ``inplanes`` resets, ``layer4_tr``/
    ``layer6_tr`` widths and the 1536-channel ``layer_f`` input) assume
    ``block.expansion == 1``.
    """
    def __init__(self, block, layers, num_classes=21):
        super(ResNetMR, self).__init__()
        # running input-channel count, consumed and updated by _make_layer
        self.inplanes = 32
        # stem: stride-1 3x3 conv keeps the input resolution
        self.conv1 = nn.Conv2d(3,
                               32,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 32, layers[0], stride=1)
        # pattern 1 adapter: widen layer1's output to 128 planes, no downsampling
        self.layer1_4=self._make_layer(block,128,layers[0], stride=1)
        # layer2 branches off layer1's 32-channel output, so rewind the channel count
        self.inplanes=32
        self.layer2 = self._make_layer(block, 64, layers[1], stride=2)
        # pattern 2 adapter: widen layer2's output to 128 planes at the same resolution
        self.layer2_4 = self._make_layer(block, 128, layers[1], stride=1)
        # layer3 branches off layer2's 64-channel output
        self.inplanes = 64
        self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
        # pattern 4 adds one more stride-2 stage on top of layer3
        self.layer3_4 = self._make_layer(block, 128, layers[2], stride=2)
        # fusion trunk: three stride-2 stages
        self.layer4 = self._make_layer(block, 256, layers[3], stride=2)
        self.layer5 = self._make_layer(block, 512, layers[4], stride=2)
        self.layer6 = self._make_layer(block, 1024, layers[5], stride=2)
        # brings layer4's output down to layer5's resolution (strided 4x4 conv)
        self.layer4_tr=nn.Sequential(nn.Conv2d(256,
                                               512,
                                               kernel_size=4,
                                               stride=2,
                                               padding=1,
                                               bias=False),
                                     nn.BatchNorm2d(512),
                                     nn.ReLU(inplace=True))
        # brings layer6's output up to layer5's resolution (4x4 transposed conv)
        self.layer6_tr=nn.Sequential(nn.ConvTranspose2d(1024,
                                                        512,
                                                        kernel_size=4,
                                                        stride=2,
                                                        padding=1,
                                                        bias=False),
                                     nn.BatchNorm2d(512),
                                     nn.ReLU(inplace=True))
        # the three 512-channel maps are concatenated before layer_f
        self.inplanes=1536
        self.layer_f=self._make_layer(block,768,2,stride=1)
        self.final_cls=nn.Conv2d(768,num_classes,1,1)
    def fusion(self,input_middle):
        """Shared head: layer4 -> layer5 -> layer6, align all three to layer5, fuse, classify.

        layer4's output is downsampled by ``layer4_tr`` and layer6's upsampled by
        ``layer6_tr``, which matches layer5's resolution for even spatial sizes. The
        three maps are concatenated along channels, then passed through ``layer_f`` and
        the 1x1 ``final_cls``.
        """
        output_4=self.layer4(input_middle)
        output_5=self.layer5(output_4)
        output_6=self.layer6(output_5)
        output_4=self.layer4_tr(output_4)
        output_6=self.layer6_tr(output_6)
        output=torch.cat([output_4,output_5,output_6],1)
        output=self.layer_f(output)
        output=self.final_cls(output)
        return output
    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` blocks with ``planes`` planes; only the first one uses ``stride``.

        The first block gets a downsample branch when the stride or the channel count
        changes: a 4x4 strided conv + BN if strided, otherwise a 1x1 conv + BN.
        Updates ``self.inplanes`` to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if stride != 1:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes,
                              planes * block.expansion,
                              kernel_size=4,
                              stride=stride,
                              padding=1,
                              bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
                )
            else:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes,
                              planes * block.expansion,
                              kernel_size=1,
                              stride=stride,
                              bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
                )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x,pattern,cal):
        """Run the stem, descend to the stage selected by ``pattern``, then apply ``fusion``.

        Args:
            x: input batch with 3 channels, (N, 3, H, W).
            pattern: 1..4 selects the fusion input:
                1 -> layer1 + layer1_4 (no downsampling),
                2 -> layer2 (stride 2) + layer2_4,
                3 -> layer2 + layer3 (two stride-2 stages),
                4 -> layer2 + layer3 + layer3_4 (three stride-2 stages).
            cal: unused.

        Returns:
            Per-class score map from ``final_cls``.

        Raises:
            ValueError: if ``pattern`` is not 1, 2, 3 or 4. Without this check the
                method would fall through and return None.

        Side effect: intermediate activations are kept on ``self.x1``/``x2``/``x3``/``x4``.
        """
        if pattern not in (1, 2, 3, 4):
            raise ValueError('pattern must be 1, 2, 3 or 4, got %r' % (pattern,))
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x1 = self.layer1(x)
        self.x1=x1
        if pattern==1:
            x4=self.layer1_4(x1)
            fm=self.fusion(x4)
            return fm

        self.x2=self.layer2(self.x1)
        if pattern==2:

            x4=self.layer2_4(self.x2)
            fm=self.fusion(x4)
            return fm

        self.x3=self.layer3(self.x2)
        if pattern==3:
            fm=self.fusion(self.x3)
            return fm
        if pattern==4:
            self.x4=self.layer3_4(self.x3)
            fm=self.fusion(self.x4)
            return fm

def resnet_mr(pretrained=False,**kwargs):
    """Build a ResNetMR from BasicBlock with per-stage block counts [1, 1, 2, 2, 2, 2, 2].

    ``pretrained`` is accepted but ignored: no weights are loaded. Remaining keyword
    arguments (e.g. ``num_classes``) are forwarded to ResNetMR.
    """
    blocks_per_stage = [1, 1, 2, 2, 2, 2, 2]
    return ResNetMR(BasicBlock, blocks_per_stage, **kwargs)