Skip to content

Commit

Permalink
cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
talrid committed Feb 4, 2021
1 parent 1932e1f commit eb52197
Showing 1 changed file with 12 additions and 17 deletions.
29 changes: 12 additions & 17 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,23 @@

def main():
args = parser.parse_args()
args.batch_size = args.batch_size
args.do_bottleneck_head = False

# setup model
# Setup model
print('creating model...')
model = create_model(args).cuda()
if args.model_path:
if args.model_path: # make sure to load pretrained ImageNet model
state = torch.load(args.model_path, map_location='cpu')
filtered_dict = {k: v for k, v in state['model'].items() if
(k in model.state_dict() and 'head.fc' not in k)}
model.load_state_dict(filtered_dict, strict=False)
print('done\n')

# Data loading code
# COCO Data loading
instances_path_val = os.path.join(args.data, 'annotations/instances_val2014.json')
instances_path_train = os.path.join(args.data, 'annotations/instances_train2014.json')
# data_path_val = os.path.join(args.data, 'val2014')
# data_path_train = os.path.join(args.data, 'train2014')
data_path_val = args.data
data_path_train = args.data

val_dataset = CocoDetection(data_path_val,
instances_path_val,
transforms.Compose([
Expand All @@ -70,29 +66,29 @@ def main():
transforms.ToTensor(),
# normalize,
]))

print("len(val_dataset)): ", len(val_dataset))
print("len(train_dataset)): ", len(train_dataset))

val_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=False)

# Pytorch Data loader
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=True,
num_workers=args.workers, pin_memory=True)

val_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=False)

train_multi_label_coco(model, train_loader, val_loader, args.lr, gamma_neg=4, gamma_pos=0, clip=0.05)
# Actuall Training
train_multi_label_coco(model, train_loader, val_loader, args.lr)


def train_multi_label_coco(model, train_loader, val_loader, lr=2e-4, gamma_neg=4, gamma_pos=0, clip=0.05):
def train_multi_label_coco(model, train_loader, val_loader, lr):
ema = ModelEma(model, 0.9997) # 0.9997^641=0.82

# set optimizer
Epochs = 40
weight_decay = 1e-4
criterion = AsymmetricLoss(gamma_neg=gamma_neg, gamma_pos=gamma_pos, clip=clip,
disable_torch_grad_focal_loss=True)
criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05, disable_torch_grad_focal_loss=True)
parameters = add_weight_decay(model, weight_decay)
optimizer = torch.optim.Adam(params=parameters, lr=lr, weight_decay=0) # true wd, filter_bias_and_bn
steps_per_epoch = len(train_loader)
Expand All @@ -101,7 +97,6 @@ def train_multi_label_coco(model, train_loader, val_loader, lr=2e-4, gamma_neg=4

highest_mAP = 0
trainInfoList = []
Sig = torch.nn.Sigmoid()
scaler = GradScaler()
for epoch in range(Epochs):
for i, (inputData, target) in enumerate(train_loader):
Expand Down

0 comments on commit eb52197

Please sign in to comment.