In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from lib.actor.LeeNet import LeeNetActor
from lib.models.backbone.plscore_RMT_sigma import PLScoreRMT
from lib.models.LeeNet.score_RMT_sigma_center import ScorePureRMTCENTER
from lib.models.layer.RMT import PatchMerging
from lib.models.head.center_predictor import CenterPredictor
from lib.trainer.LeeNet_trainer import LeeNetTrainer
from lib.utils.base_funtion import build_dataloaders, get_optimizer_scheduler
from lib.config.cfg_loader import env_setting
from torch.nn.functional import l1_loss
from torch.nn import BCEWithLogitsLoss
from lib.utils.box_ops import giou_loss
from lib.utils.focal_loss import FocalLoss
import torch


def build_model(cfg):
    backbone = PLScoreRMT(down_sample=PatchMerging,cfg=cfg)
    
    stride = cfg.model.backbone.stride
    feat_sz = int(cfg.data.search.size / stride)
    head = CenterPredictor(inplanes=cfg.model.pureRMT.embed_dim[-1], channel=cfg.model.head.num_channels, feat_sz=feat_sz, stride=stride)
    model = ScorePureRMTCENTER(backbone, head, cfg)
    return model


cfg = env_setting(cfg_name='plscore_pureRMT_sigma_center.yaml')
loader_train, loader_val = build_dataloaders(cfg)
data = None
# for i, _data in enumerate(loader_train, 1):
#     data = _data
#     break

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# net = build_model(cfg)
# 
# focal_loss = FocalLoss()
# objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss, 'cls': BCEWithLogitsLoss()}
# loss_weight = {'giou': cfg.train.GIOU_weight, 'l1': cfg.train.L1_weight, 'focal': 1., 'cls': 1.0}
# actor = LeeNetActor(net=net, objective=objective, loss_weight=loss_weight, cfg=cfg)

In [4]:
# data = data.to(cfg.train.device)
# actor(data)

In [5]:
checkpoint_path = "/media/star/data/Leezed/workspace/LeeNet/checkpoints/LeeNet_plScore_RMT_CENTER00001/ScorePureRMTCENTER_ep2600.pth.tar"
# checkpoint_dict = torch.load(checkpoint_path,map_location={"cuda:1":"cuda:1"})['net']
checkpoint_dict = torch.load(checkpoint_path)['net']
print(checkpoint_dict.keys())

net = build_model(cfg)

model_dict = net.state_dict()
print(model_dict.keys())
state_dict = {k: v for k, v in checkpoint_dict.items() if k in model_dict.keys()}
print(state_dict.keys())
del state_dict['mlp.layers.0.weight']
model_dict.update(state_dict)
net.load_state_dict(model_dict,strict=False)

odict_keys(['backbone.patch_embed.proj.weight', 'backbone.patch_embed.proj.bias', 'backbone.score.conv1.conv.weight', 'backbone.score.conv1.bn.weight', 'backbone.score.conv1.bn.bias', 'backbone.score.conv1.bn.running_mean', 'backbone.score.conv1.bn.running_var', 'backbone.score.conv1.bn.num_batches_tracked', 'backbone.score.conv2.conv.weight', 'backbone.score.conv2.bn.weight', 'backbone.score.conv2.bn.bias', 'backbone.score.conv2.bn.running_mean', 'backbone.score.conv2.bn.running_var', 'backbone.score.conv2.bn.num_batches_tracked', 'backbone.score.confident_conv1.conv.weight', 'backbone.score.confident_conv1.bn.weight', 'backbone.score.confident_conv1.bn.bias', 'backbone.score.confident_conv1.bn.running_mean', 'backbone.score.confident_conv1.bn.running_var', 'backbone.score.confident_conv1.bn.num_batches_tracked', 'backbone.score.confident_conv2.conv.weight', 'backbone.score.confident_conv2.bn.weight', 'backbone.score.confident_conv2.bn.bias', 'backbone.score.confident_conv2.bn.running

<All keys matched successfully>

In [6]:


focal_loss = FocalLoss()
objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss, 'cls': BCEWithLogitsLoss()}
loss_weight = {'giou': cfg.train.GIOU_weight, 'l1': cfg.train.L1_weight, 'focal': 1., 'cls': 1.0}
actor = LeeNetActor(net=net, objective=objective, loss_weight=loss_weight, cfg=cfg)

optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)

trainer = LeeNetTrainer(actor=actor, loaders=[loader_train, loader_val], optimizer=optimizer, lr_scheduler=lr_scheduler, cfg=cfg)

trainer.train(cfg.train.epoch)

Epoch Time: 0:01:41.351480
Avg Data Time: 97.82503
Avg GPU Trans Time: 0.28029
Avg Forward Time: 3.24616
Epoch Time: 0:01:48.794009
Avg Data Time: 51.51986
Avg GPU Trans Time: 0.26557
Avg Forward Time: 2.61157
Epoch Time: 0:01:51.387578
Avg Data Time: 34.41742
Avg GPU Trans Time: 0.25159
Avg Forward Time: 2.46019
Epoch Time: 0:01:53.789607
Avg Data Time: 25.85925
Avg GPU Trans Time: 0.24639
Avg Forward Time: 2.34177
Epoch Time: 0:01:56.163470
Avg Data Time: 20.71485
Avg GPU Trans Time: 0.24574
Avg Forward Time: 2.27211
Epoch Time: 0:01:58.551765
Avg Data Time: 17.28722
Avg GPU Trans Time: 0.24376
Avg Forward Time: 2.22765
Epoch Time: 0:02:01.061109
Avg Data Time: 14.85253
Avg GPU Trans Time: 0.24254
Avg Forward Time: 2.19937
Epoch Time: 0:02:03.524526
Avg Data Time: 13.02025
Avg GPU Trans Time: 0.24244
Avg Forward Time: 2.17788
Epoch Time: 0:02:06.045780
Avg Data Time: 11.59794
Avg GPU Trans Time: 0.24401
Avg Forward Time: 2.16314
[train: 1, 10 / 234] FPS: 19.9 (102.8)  ,  DataTime: 10

KeyboardInterrupt: 