Skip to content

Commit

Permalink
Updated to work in OSSDC docker, instructions here https://github.com…
Browse files Browse the repository at this point in the history
…/ossdc/self-driving-cars-1/

Inside the Docker container, run make and then test with:

 darknet detector demo cfg/coco.data cfg/yolo.cfg yolo.weights /sharefolder/ossdc-ps4-the-crew-acc-train-20170202_2031-30fps.mp4

Download yolo.weights:

 wget http://pjreddie.com/media/files/yolo.weights

Then use a YouTube downloader to download the video from the link below, and rename it to the filename used in the command above:

 https://www.youtube.com/watch?v=uuQlMCMT71I

This will be part of the OSSDC PS3/PS4 Simulator; more details here:

 https://medium.com/@mslavescu/a-few-updates-on-ai-and-self-driving-cars-df48fdaa0733
  • Loading branch information
mslavescu committed Feb 6, 2017
1 parent 2710d63 commit dc76e56
Show file tree
Hide file tree
Showing 26 changed files with 311 additions and 171 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
*.csv
*.out
*.png
*.jpg
old/
mnist/
data/
caffe/
Expand Down
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
GPU=0
CUDNN=0
OPENCV=0
GPU=1
CUDNN=1
OPENCV=1
DEBUG=0

ARCH= -gencode arch=compute_20,code=[sm_20,sm_21] \
Expand Down
2 changes: 1 addition & 1 deletion cfg/voc.data
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ classes= 20
train = /home/pjreddie/data/voc/train.txt
valid = /home/pjreddie/data/voc/2007_test.txt
names = data/voc.names
backup = /home/pjreddie/backup/
backup = backup

1 change: 1 addition & 0 deletions src/art.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#ifdef OPENCV
#include "opencv2/highgui/highgui_c.h"
#include "opencv2/videoio/videoio_c.h"
image get_image_from_stream(CvCapture *cap);
#endif

Expand Down
28 changes: 28 additions & 0 deletions src/box.c
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,34 @@ int nms_comparator(const void *pa, const void *pb)
return 0;
}

/*
 * Suppress overlapping boxes using the extra score stored at index `classes`
 * of each probs row (one slot past the per-class scores).
 *
 * boxes   - array of `total` candidate boxes.
 * probs   - array of `total` rows; each row must hold at least classes+1
 *           floats, because index `classes` is read and written here.
 * total   - number of candidates; values <= 0 make this a no-op.
 * classes - number of per-class scores per row.
 * thresh  - a later box whose box_iou() with a kept box exceeds this value
 *           has its whole row (indices 0..classes) set to 0.
 *
 * Candidates are visited in the order produced by nms_comparator for class
 * index `classes` (presumably highest score first -- confirm against
 * nms_comparator). Rows whose score at index `classes` is already 0 never
 * suppress others.
 *
 * If the scratch array cannot be allocated the function returns without
 * modifying probs, i.e. no suppression is performed.
 */
void do_nms_obj(box *boxes, float **probs, int total, int classes, float thresh)
{
    int i, j, k;
    sortable_bbox *s;

    /* Nothing to sort or compare; also avoids a zero/negative-sized calloc. */
    if(total <= 0) return;

    s = calloc(total, sizeof *s);
    /* The original dereferenced s unconditionally; bail out instead of crashing. */
    if(!s) return;

    for(i = 0; i < total; ++i){
        s[i].index = i;
        s[i].class = classes;   /* compare on the slot after the class scores */
        s[i].probs = probs;
    }

    qsort(s, total, sizeof *s, nms_comparator);

    for(i = 0; i < total; ++i){
        /* Already suppressed (or never scored): it must not knock out others. */
        if(probs[s[i].index][classes] == 0) continue;
        box a = boxes[s[i].index];
        for(j = i+1; j < total; ++j){
            box b = boxes[s[j].index];
            if (box_iou(a, b) > thresh){
                /* Clear every class score plus the extra slot at `classes`. */
                for(k = 0; k < classes+1; ++k){
                    probs[s[j].index][k] = 0;
                }
            }
        }
    }
    free(s);
}


void do_nms_sort(box *boxes, float **probs, int total, int classes, float thresh)
{
int i, j, k;
Expand Down
1 change: 1 addition & 0 deletions src/box.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ float box_rmse(box a, box b);
dbox diou(box a, box b);
void do_nms(box *boxes, float **probs, int total, int classes, float thresh);
void do_nms_sort(box *boxes, float **probs, int total, int classes, float thresh);
void do_nms_obj(box *boxes, float **probs, int total, int classes, float thresh);
box decode_box(box b, box anchor);
box encode_box(box b, box anchor);

Expand Down
2 changes: 2 additions & 0 deletions src/classifier.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#ifdef OPENCV
#include "opencv2/highgui/highgui_c.h"
#include "opencv2/imgproc/imgproc_c.h"
#include "opencv2/videoio/videoio_c.h"
image get_image_from_stream(CvCapture *cap);
#endif

Expand Down
2 changes: 1 addition & 1 deletion src/coco.c
Original file line number Diff line number Diff line change
Expand Up @@ -384,5 +384,5 @@ void run_coco(int argc, char **argv)
else if(0==strcmp(argv[2], "train")) train_coco(cfg, weights);
else if(0==strcmp(argv[2], "valid")) validate_coco(cfg, weights);
else if(0==strcmp(argv[2], "recall")) validate_coco_recall(cfg, weights);
else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, coco_classes, 80, frame_skip, prefix);
else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, coco_classes, 80, frame_skip, prefix, .5);
}
1 change: 1 addition & 0 deletions src/convolutional_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)

void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
{
//constrain_ongpu(l.outputs*l.batch, 1, l.delta_gpu, 1);
gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);

backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h);
Expand Down
8 changes: 5 additions & 3 deletions src/darknet.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#endif

extern void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int top);
extern void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh);
extern void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, float hier_thresh);
extern void run_voxel(int argc, char **argv);
extern void run_yolo(int argc, char **argv);
extern void run_detector(int argc, char **argv);
Expand Down Expand Up @@ -129,7 +129,9 @@ void oneoff(char *cfgfile, char *weightfile, char *outfile)
network net = parse_network_cfg(cfgfile);
int oldn = net.layers[net.n - 2].n;
int c = net.layers[net.n - 2].c;
net.layers[net.n - 2].n = 9372;
scal_cpu(oldn*c, .1, net.layers[net.n - 2].weights, 1);
scal_cpu(oldn, 0, net.layers[net.n - 2].biases, 1);
net.layers[net.n - 2].n = 9418;
net.layers[net.n - 2].biases += 5;
net.layers[net.n - 2].weights += 5*c;
if(weightfile){
Expand Down Expand Up @@ -383,7 +385,7 @@ int main(int argc, char **argv)
} else if (0 == strcmp(argv[1], "detect")){
float thresh = find_float_arg(argc, argv, "-thresh", .24);
char *filename = (argc > 4) ? argv[4]: 0;
test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh);
test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh, .5);
} else if (0 == strcmp(argv[1], "cifar")){
run_cifar(argc, argv);
} else if (0 == strcmp(argv[1], "go")){
Expand Down
4 changes: 2 additions & 2 deletions src/data.c
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ void fill_truth_region(char *path, float *truth, int classes, int num_boxes, int
h = boxes[i].h;
id = boxes[i].id;

if (w < .01 || h < .01) continue;
if (w < .005 || h < .005) continue;

int col = (int)(x*num_boxes);
int row = (int)(y*num_boxes);
Expand Down Expand Up @@ -317,7 +317,7 @@ void fill_truth_detection(char *path, int num_boxes, float *truth, int classes,
h = boxes[i].h;
id = boxes[i].id;

if ((w < .01 || h < .01)) continue;
if ((w < .005 || h < .005)) continue;

truth[i*5+0] = x;
truth[i*5+1] = y;
Expand Down
22 changes: 17 additions & 5 deletions src/demo.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifdef OPENCV
#include "opencv2/highgui/highgui_c.h"
#include "opencv2/imgproc/imgproc_c.h"
#include "opencv2/videoio/videoio_c.h"
image get_image_from_stream(CvCapture *cap);

static char **demo_names;
Expand All @@ -31,6 +32,7 @@ static image disp = {0};
static CvCapture * cap;
static float fps = 0;
static float demo_thresh = 0;
static float demo_hier_thresh = .5;

static float *predictions[FRAMES];
static int demo_index = 0;
Expand Down Expand Up @@ -63,14 +65,15 @@ void *detect_in_thread(void *ptr)
if(l.type == DETECTION){
get_detection_boxes(l, 1, 1, demo_thresh, probs, boxes, 0);
} else if (l.type == REGION){
get_region_boxes(l, 1, 1, demo_thresh, probs, boxes, 0, 0);
get_region_boxes(l, 1, 1, demo_thresh, probs, boxes, 0, 0, demo_hier_thresh);
} else {
error("Last layer must produce detections\n");
}
if (nms > 0) do_nms(boxes, probs, l.w*l.h*l.n, l.classes, nms);
printf("\033[2J");
printf("\033[1;1H");
printf("\nFPS:%.1f\n",fps);
printf("net.w: %d net.h: %d\n", net.w,net.h);
printf("Objects:\n\n");

images[demo_index] = det;
Expand All @@ -91,7 +94,7 @@ double get_wall_time()
return (double)time.tv_sec + (double)time.tv_usec * .000001;
}

void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix)
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, float hier_thresh)
{
//skip = frame_skip;
image **alphabet = load_alphabet();
Expand All @@ -100,6 +103,7 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch
demo_alphabet = alphabet;
demo_classes = classes;
demo_thresh = thresh;
demo_hier_thresh = hier_thresh;
printf("Demo\n");
net = parse_network_cfg(cfgfile);
if(weightfile){
Expand All @@ -113,6 +117,7 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch
printf("video file: %s\n", filename);
cap = cvCaptureFromFile(filename);
}else{
printf("capture index: %d\n", cam_index);
cap = cvCaptureFromCAM(cam_index);
}

Expand All @@ -121,13 +126,17 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch
layer l = net.layers[net.n-1];
int j;

printf("net.w: %d net.h: %d\n", net.w,net.h);



avg = (float *) calloc(l.outputs, sizeof(float));
for(j = 0; j < FRAMES; ++j) predictions[j] = (float *) calloc(l.outputs, sizeof(float));
for(j = 0; j < FRAMES; ++j) images[j] = make_image(1,1,3);

boxes = (box *)calloc(l.w*l.h*l.n, sizeof(box));
probs = (float **)calloc(l.w*l.h*l.n, sizeof(float *));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = (float *)calloc(l.classes, sizeof(float *));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = (float *)calloc(l.classes, sizeof(float));

pthread_t fetch_thread;
pthread_t detect_thread;
Expand All @@ -154,9 +163,12 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch
if(!prefix){
cvNamedWindow("Demo", CV_WINDOW_NORMAL);
cvMoveWindow("Demo", 0, 0);
cvResizeWindow("Demo", 1352, 1013);
//cvResizeWindow("Demo", 1352, 1013);
//cvResizeWindow("Demo", 640, 480); //TODO MS must parametrize this
cvResizeWindow("Demo", 1280, 720); //TODO MS must parametrize this
}


double before = get_wall_time();

while(1){
Expand Down Expand Up @@ -213,7 +225,7 @@ void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const ch
}
}
#else
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix)
/*
 * Stand-in for demo() compiled when OPENCV is not defined: every argument is
 * ignored and a notice is written to stderr.
 */
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, float hier_thresh)
{
    fputs("Demo needs OpenCV for webcam images.\n", stderr);
}
Expand Down
2 changes: 1 addition & 1 deletion src/demo.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
#define DEMO

#include "image.h"
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix);
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, float hier_thresh);

#endif
35 changes: 20 additions & 15 deletions src/detector.c
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
if(l.random && count++%10 == 0){
printf("Resizing\n");
int dim = (rand() % 10 + 10) * 32;
if (get_current_batch(net)+100 > net.max_batches) dim = 544;
if (get_current_batch(net)+200 > net.max_batches) dim = 608;
//int dim = (rand() % 4 + 16) * 32;
printf("%d\n", dim);
args.w = dim;
Expand Down Expand Up @@ -231,7 +231,7 @@ void print_imagenet_detections(FILE *fp, int id, box *boxes, float **probs, int
}
}

void validate_detector(char *datacfg, char *cfgfile, char *weightfile)
void validate_detector(char *datacfg, char *cfgfile, char *weightfile, char *outfile)
{
int j;
list *options = read_data_cfg(datacfg);
Expand All @@ -251,7 +251,6 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile)
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
srand(time(0));

char *base = "comp4_det_test_";
list *plist = get_paths(valid_images);
char **paths = (char **)list_to_array(plist);

Expand All @@ -265,19 +264,22 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile)
int coco = 0;
int imagenet = 0;
if(0==strcmp(type, "coco")){
snprintf(buff, 1024, "%s/coco_results.json", prefix);
if(!outfile) outfile = "coco_results";
snprintf(buff, 1024, "%s/%s.json", prefix, outfile);
fp = fopen(buff, "w");
fprintf(fp, "[\n");
coco = 1;
} else if(0==strcmp(type, "imagenet")){
snprintf(buff, 1024, "%s/imagenet-detection.txt", prefix);
if(!outfile) outfile = "imagenet-detection";
snprintf(buff, 1024, "%s/%s.txt", prefix, outfile);
fp = fopen(buff, "w");
imagenet = 1;
classes = 200;
} else {
if(!outfile) outfile = "comp4_det_test_";
fps = calloc(classes, sizeof(FILE *));
for(j = 0; j < classes; ++j){
snprintf(buff, 1024, "%s/%s%s.txt", prefix, base, names[j]);
snprintf(buff, 1024, "%s/%s%s.txt", prefix, outfile, names[j]);
fps[j] = fopen(buff, "w");
}
}
Expand Down Expand Up @@ -333,7 +335,7 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile)
network_predict(net, X);
int w = val[t].w;
int h = val[t].h;
get_region_boxes(l, w, h, thresh, probs, boxes, 0, map);
get_region_boxes(l, w, h, thresh, probs, boxes, 0, map, .5);
if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, classes, nms);
if (coco){
print_cocos(fp, path, boxes, probs, l.w*l.h*l.n, classes, w, h);
Expand Down Expand Up @@ -397,7 +399,7 @@ void validate_detector_recall(char *cfgfile, char *weightfile)
image sized = resize_image(orig, net.w, net.h);
char *id = basecfg(path);
network_predict(net, sized.data);
get_region_boxes(l, 1, 1, thresh, probs, boxes, 1, 0);
get_region_boxes(l, 1, 1, thresh, probs, boxes, 1, 0, .5);
if (nms) do_nms(boxes, probs, l.w*l.h*l.n, 1, nms);

char labelpath[4096];
Expand Down Expand Up @@ -436,7 +438,7 @@ void validate_detector_recall(char *cfgfile, char *weightfile)
}
}

void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh)
void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, float hier_thresh)
{
list *options = read_data_cfg(datacfg);
char *name_list = option_find_str(options, "names", "data/names.list");
Expand Down Expand Up @@ -470,14 +472,15 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam

box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(l.classes + 1, sizeof(float *));

float *X = sized.data;
time=clock();
network_predict(net, X);
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
get_region_boxes(l, 1, 1, thresh, probs, boxes, 0, 0);
if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, l.classes, nms);
get_region_boxes(l, 1, 1, thresh, probs, boxes, 0, 0, hier_thresh);
if (l.softmax_tree && nms) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms);
else if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, l.classes, nms);
draw_detections(im, l.w*l.h*l.n, thresh, boxes, probs, names, alphabet, l.classes);
save_image(im, "predictions");
show_image(im, "predictions");
Expand All @@ -498,13 +501,15 @@ void run_detector(int argc, char **argv)
{
char *prefix = find_char_arg(argc, argv, "-prefix", 0);
float thresh = find_float_arg(argc, argv, "-thresh", .24);
float hier_thresh = find_float_arg(argc, argv, "-hier", .5);
int cam_index = find_int_arg(argc, argv, "-c", 0);
int frame_skip = find_int_arg(argc, argv, "-s", 0);
if(argc < 4){
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
return;
}
char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
char *outfile = find_char_arg(argc, argv, "-out", 0);
int *gpus = 0;
int gpu = 0;
int ngpus = 0;
Expand Down Expand Up @@ -533,15 +538,15 @@ void run_detector(int argc, char **argv)
char *cfg = argv[4];
char *weights = (argc > 5) ? argv[5] : 0;
char *filename = (argc > 6) ? argv[6]: 0;
if(0==strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh);
if(0==strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh);
else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear);
else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights);
else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);
else if(0==strcmp(argv[2], "recall")) validate_detector_recall(cfg, weights);
else if(0==strcmp(argv[2], "demo")) {
list *options = read_data_cfg(datacfg);
int classes = option_find_int(options, "classes", 20);
char *name_list = option_find_str(options, "names", "data/names.list");
char **names = get_labels(name_list);
demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix);
demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix, hier_thresh);
}
}
Loading

0 comments on commit dc76e56

Please sign in to comment.