Skip to content

Commit

Permalink
Merge pull request #6 from THEFASHIONGEEK/master
Browse files Browse the repository at this point in the history
Added Multi-GPU training for efficientdet. Tested and reviewed.
  • Loading branch information
abhi-kumar committed Jan 20, 2020
2 parents 2ebdfbb + a68daf8 commit f19f37e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
3 changes: 1 addition & 2 deletions 4_efficientdet/installation/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@ torchvision
efficientnet_pytorch


tensorflow-gpu==1.12
tensorboardX
git+https://github.com/abhi-kumar/cocoapi.git#egg=pycocotools&subdirectory=PythonAPI

jupyter
notebook

dicttoxml
xmltodict
xmltodict
34 changes: 22 additions & 12 deletions 4_efficientdet/lib/train_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, verbose=1):
self.system_dict["params"]["batch_size"] = 8;
self.system_dict["params"]["num_workers"] = 3;
self.system_dict["params"]["use_gpu"] = True;
self.system_dict["params"]["gpu_devices"] = [0];
self.system_dict["params"]["lr"] = 0.0001;
self.system_dict["params"]["num_epochs"] = 10;
self.system_dict["params"]["val_interval"] = 1;
Expand Down Expand Up @@ -98,13 +99,22 @@ def Val_Dataset(self, root_dir, coco_dir, img_dir, set_dir):
**self.system_dict["local"]["val_params"])


def Model(self):
self.system_dict["local"]["model"] = EfficientDet(num_classes=self.system_dict["local"]["training_set"].num_classes())
if(self.system_dict["params"]["use_gpu"]):
if torch.cuda.is_available():
self.system_dict["local"]["model"] = self.system_dict["local"]["model"].cuda();
self.system_dict["local"]["model"] = nn.DataParallel(self.system_dict["local"]["model"]);
self.system_dict["local"]["model"].train();
def Model(self,gpu_devices=[0]):
num_classes = self.system_dict["local"]["training_set"].num_classes();
efficientdet = EfficientDet(num_classes=num_classes)

if self.system_dict["params"]["use_gpu"]:
self.system_dict["params"]["gpu_devices"] = gpu_devices
if len(self.system_dict["params"]["gpu_devices"])==1:
os.environ["CUDA_VISIBLE_DEVICES"] = str(self.system_dict["params"]["gpu_devices"][0])
else:
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(id) for id in self.system_dict["params"]["gpu_devices"]])
self.system_dict["local"]["device"] = 'cuda' if torch.cuda.is_available() else 'cpu'
efficientdet = efficientdet.to(self.system_dict["local"]["device"])
efficientdet= torch.nn.DataParallel(efficientdet).to(self.system_dict["local"]["device"])

self.system_dict["local"]["model"] = efficientdet;
self.system_dict["local"]["model"].train();


def Set_Hyperparams(self, lr=0.0001, val_interval=1, es_min_delta=0.0, es_patience=0):
Expand Down Expand Up @@ -149,7 +159,7 @@ def Train(self, num_epochs=2, model_output_dir="trained/"):
try:
self.system_dict["local"]["optimizer"].zero_grad()
if torch.cuda.is_available():
cls_loss, reg_loss = self.system_dict["local"]["model"]([data['img'].cuda().float(), data['annot'].cuda()])
cls_loss, reg_loss = self.system_dict["local"]["model"]([data['img'].to(self.system_dict["local"]["device"]).float(), data['annot'].to(self.system_dict["local"]["device"])])
else:
cls_loss, reg_loss = self.system_dict["local"]["model"]([data['img'].float(), data['annot']])

Expand Down Expand Up @@ -185,7 +195,7 @@ def Train(self, num_epochs=2, model_output_dir="trained/"):
for iter, data in enumerate(self.system_dict["local"]["test_generator"]):
with torch.no_grad():
if torch.cuda.is_available():
cls_loss, reg_loss = self.system_dict["local"]["model"]([data['img'].cuda().float(), data['annot'].cuda()])
cls_loss, reg_loss = self.system_dict["local"]["model"]([data['img'].to(self.system_dict["local"]["device"]).float(), data['annot'].to(self.system_dict["local"]["device"])])
else:
cls_loss, reg_loss = self.system_dict["local"]["model"]([data['img'].float(), data['annot']])

Expand Down Expand Up @@ -246,7 +256,7 @@ def Train(self, num_epochs=2, model_output_dir="trained/"):
try:
self.system_dict["local"]["optimizer"].zero_grad()
if torch.cuda.is_available():
cls_loss, reg_loss = self.system_dict["local"]["model"]([data['img'].cuda().float(), data['annot'].cuda()])
cls_loss, reg_loss = self.system_dict["local"]["model"]([data['img'].to(self.system_dict["local"]["device"]).float(), data['annot'].to(self.system_dict["local"]["device"])])
else:
cls_loss, reg_loss = self.system_dict["local"]["model"]([data['img'].float(), data['annot']])

Expand Down Expand Up @@ -280,7 +290,7 @@ def Train(self, num_epochs=2, model_output_dir="trained/"):

dummy_input = torch.rand(self.system_dict["params"]["batch_size"], 3, 512, 512)
if torch.cuda.is_available():
dummy_input = dummy_input.cuda()
dummy_input = dummy_input.to(self.system_dict["local"]["device"])
if isinstance(self.system_dict["local"]["model"], nn.DataParallel):
self.system_dict["local"]["model"].module.backbone_net.model.set_swish(memory_efficient=False)

Expand All @@ -297,4 +307,4 @@ def Train(self, num_epochs=2, model_output_dir="trained/"):
self.system_dict["local"]["model"].backbone_net.model.set_swish(memory_efficient=True)


writer.close()
writer.close()

0 comments on commit f19f37e

Please sign in to comment.