-
Notifications
You must be signed in to change notification settings - Fork 0
/
lewd_detector.py
154 lines (123 loc) · 4.08 KB
/
lewd_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from nudenet import NudeDetector
from nudenet import NudeClassifier
import cv2
import argparse
import os
import pprint
def proc_opts(argv=None):
    """Parse the command-line options of the lewd detector.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with attributes ``dir``, ``noviz``, ``outdir``,
        ``mode`` and ``probability``.
    """
    # The text is a description of the tool; the first positional argument of
    # ArgumentParser is ``prog`` (the program name shown in usage lines).
    parser = argparse.ArgumentParser(description='Lewd detector (NudeNet based)')
    # A required positional never uses ``default``; this text is help.
    parser.add_argument('dir', help='Folder to read')
    parser.add_argument('--noviz', default=False, action='store_true')
    parser.add_argument('--outdir', default=None)
    parser.add_argument('--mode', default='base',
                        help='Detection mode forwarded to the NudeNet detector')
    parser.add_argument('--probability', type=float, default=0.1,
                        help='Min probability to consider')
    return parser.parse_args(argv)
class nudeWrapper():
    """Thin wrapper around NudeNet's detector and classifier.

    Adds a non-maximum-suppression pass over the detector output and an
    OpenCV-based helper that draws the detections onto the image.
    """

    # Colours used when drawing (cv2 images are in BGR channel order).
    _TEXT_COLOR = (247, 153, 29)
    _TITLE_COLOR = (200, 153, 29)
    _BOX_COLOR = (114, 205, 238)

    def __init__(self):
        self.detector = NudeDetector()
        self.classifier = NudeClassifier()

    def detect_sketch(self, filename, prob=0.1, mode='base'):
        """Run the NudeNet detector on ``filename`` and prune overlaps.

        Args:
            filename: Path of the image to analyse.
            prob: Minimum detection probability (forwarded as ``min_prob``).
            mode: Detection mode forwarded to ``NudeDetector.detect``.

        Returns:
            The detections that survive :meth:`nms`.
        """
        results = self.detector.detect(filename, mode=mode, min_prob=prob)
        return self.nms(results)

    def classify_sketch(self, filename):
        """Return the raw NudeNet classifier output for ``filename``."""
        return self.classifier.classify(filename)

    def nms(self, boxes, l=0.4):
        """Greedy, label-agnostic non-maximum suppression.

        Detections are visited from highest to lowest ``'score'``. A detection
        is kept only if its IoU with every already-kept detection is at most
        ``l``. Labels are ignored, so overlapping boxes of different labels
        also suppress each other.

        Args:
            boxes: Detections, each a mapping with at least ``'box'``
                (``[x1, y1, x2, y2]``) and ``'score'`` keys. Not modified.
            l: IoU threshold above which the lower-scoring box is dropped.

        Returns:
            A new list of the kept detections, highest score first.
        """
        n_boxes = len(boxes)
        keep = []
        # sorted() copies, so the caller's list is left intact.
        for box in sorted(boxes, key=lambda b: b['score'], reverse=True):
            if all(self.iou(kept['box'], box['box']) <= l for kept in keep):
                keep.append(box)
        print("Pruned {} boxes".format(n_boxes - len(keep)))
        return keep

    def iou(self, box1, box2):
        """Intersection over union of two ``[x1, y1, x2, y2]`` boxes.

        Widths and heights are ``x2 - x1`` / ``y2 - y1`` (no ``+1``). The same
        rule applies to the intersection and to both areas, so well-formed
        boxes give a value in ``[0, 1]``. Returns ``0.0`` when the union is
        empty (both boxes degenerate).
        """
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        inter_area = max(0, x2 - x1) * max(0, y2 - y1)
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = box1_area + box2_area - inter_area
        if union <= 0:
            return 0.0
        return inter_area / union

    def _put_text(self, img, text, origin, color):
        """Draw ``text`` on ``img`` at ``origin`` with the shared font settings."""
        cv2.putText(img, text, origin, cv2.FONT_HERSHEY_COMPLEX, 1, color, 2)

    def put_bounding_box(self, image_file, boxes, unsafe_prob, threshold,
                         viz=True, outfile=None):
        """Draw detections and scores on an image, then show and/or save it.

        Args:
            image_file: Path of the image to annotate.
            boxes: Detections with ``'label'``, ``'score'`` and ``'box'``
                (``[x1, y1, x2, y2]``) keys.
            unsafe_prob: Overall "unsafe" score, drawn as the title line.
            threshold: Detection threshold, drawn under the title.
            viz: If true, show the result in a window and wait for a key.
            outfile: If given, path the annotated image is written to.

        Raises:
            ValueError: If ``image_file`` cannot be read by OpenCV.
        """
        img = cv2.imread(image_file)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError("Could not read image {}".format(image_file))
        for box in boxes:
            print("Detection {} with probability {}".format(box['label'], box['score']))
            # OpenCV drawing functions expect points as tuples of Python ints.
            x1, y1, x2, y2 = (int(v) for v in box['box'][:4])
            self._put_text(img, box['label'].capitalize(), (x1, y1), self._TEXT_COLOR)
            self._put_text(img, "{:.2f}".format(box['score']), (x1, y2), self._TEXT_COLOR)
            cv2.rectangle(img, (x1, y1), (x2, y2), self._BOX_COLOR, 2)
        self._put_text(img, "NSFW score: {:0.4}".format(unsafe_prob), (0, 30),
                       self._TITLE_COLOR)
        self._put_text(img, "Threshold: {:0.4}".format(threshold), (0, 65),
                       self._TEXT_COLOR)
        if viz:
            source_window = 'Lewdness detector'
            cv2.namedWindow(source_window)
            cv2.imshow(source_window, img)
            cv2.waitKey()
        if outfile:
            print("Writing to {}".format(outfile))
            # cv2.imwrite reports failure (e.g. missing folder) via its return value.
            if not cv2.imwrite(outfile, img):
                print("Could not write {}".format(outfile))
if __name__ == "__main__":
    args = proc_opts()
    # os.scandir yields entries in arbitrary order; sort for reproducible runs.
    # TODO: only keep files that are actually images.
    imgs = sorted(entry.name for entry in os.scandir(args.dir) if entry.is_file())
    print("Will process the following files")
    for name in imgs:
        print(name)
    if args.outdir:
        # cv2.imwrite does not create missing folders.
        os.makedirs(args.outdir, exist_ok=True)
    lwd = nudeWrapper()
    for name in imgs:
        # os.path.join works whether or not ``dir`` ends with a separator.
        img = os.path.join(args.dir, name)
        print("Detecting: {}".format(img))
        boxes = lwd.detect_sketch(img, args.probability, args.mode)
        print("Classifying: {}".format(img))
        classification = lwd.classify_sketch(img)
        print(classification)
        # The classification is looked up by the same path passed to the classifier.
        pckg = classification[img]
        # --outdir is optional: without it nothing is written to disk.
        outfile = os.path.join(args.outdir, "processed_" + name) if args.outdir else None
        lwd.put_bounding_box(img, boxes, pckg['unsafe'],
                             args.probability,
                             viz=not args.noviz, outfile=outfile)