Skip to content

Commit

Permalink
Merge pull request #5 from THEFASHIONGEEK/master
Browse files Browse the repository at this point in the history
Added Multi-GPU Training For RetinaNet. Verified, tested and added to pipeline
  • Loading branch information
abhi-kumar committed Jan 19, 2020
2 parents 7f18b70 + 8504267 commit ebda650
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions 5_pytorch_retinanet/lib/train_detector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections
import os

import numpy as np

Expand Down Expand Up @@ -32,6 +33,7 @@ def __init__(self, verbose=1):
self.system_dict["params"]["num_workers"] = 3;
self.system_dict["params"]["use_gpu"] = True;
self.system_dict["params"]["lr"] = 0.0001;
self.system_dict["params"]["gpu_devices"] = [0];
self.system_dict["params"]["num_epochs"] = 10;
self.system_dict["params"]["val_interval"] = 1;
self.system_dict["params"]["print_interval"] = 20;
Expand Down Expand Up @@ -95,7 +97,7 @@ def Val_Dataset(self, root_dir, coco_dir, img_dir, set_dir):
print('Num validation images: {}'.format(len(self.system_dict["local"]["dataset_val"])))


def Model(self, model_name="resnet18"):
def Model(self, model_name="resnet18",gpu_devices=[0]):

num_classes = self.system_dict["local"]["dataset_train"].num_classes();
if model_name == "resnet18":
Expand All @@ -110,8 +112,14 @@ def Model(self, model_name="resnet18"):
retinanet = model.resnet152(num_classes=num_classes, pretrained=True)

if self.system_dict["params"]["use_gpu"]:
retinanet = retinanet.cuda()
retinanet = torch.nn.DataParallel(retinanet).cuda()
self.system_dict["params"]["gpu_devices"] = gpu_devices
if len(self.system_dict["params"]["gpu_devices"])==1:
os.environ["CUDA_VISIBLE_DEVICES"] = str(self.system_dict["params"]["gpu_devices"][0])
else:
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(id) for id in self.system_dict["params"]["gpu_devices"]])
self.system_dict["local"]["device"] = 'cuda' if torch.cuda.is_available() else 'cpu'
retinanet = retinanet.to(self.system_dict["local"]["device"])
retinanet = torch.nn.DataParallel(retinanet).to(self.system_dict["local"]["device"])

retinanet.training = True
retinanet.train()
Expand Down Expand Up @@ -150,7 +158,7 @@ def Train(self, num_epochs=2, output_model_name="final_model.pt"):
try:
self.system_dict["local"]["optimizer"].zero_grad()

classification_loss, regression_loss = self.system_dict["local"]["model"]([data['img'].cuda().float(), data['annot']])
classification_loss, regression_loss = self.system_dict["local"]["model"]([data['img'].to(self.system_dict["local"]["device"]).float(), data['annot'].to(self.system_dict["local"]["device"])])

classification_loss = classification_loss.mean()
regression_loss = regression_loss.mean()
Expand Down Expand Up @@ -192,4 +200,4 @@ def Train(self, num_epochs=2, output_model_name="final_model.pt"):

self.system_dict["local"]["model"].eval()

torch.save(self.system_dict["local"]["model"], output_model_name)
torch.save(self.system_dict["local"]["model"], output_model_name)

0 comments on commit ebda650

Please sign in to comment.