This repository has been archived by the owner on Mar 29, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.go
190 lines (170 loc) · 4.85 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
package main
import (
"bytes"
"errors"
"fmt"
"image"
"image/png"
"log"
"os"
"strings"
"github.com/CollActionteam/collaction_backend/utils"
"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambda"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/rekognition"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/nfnt/resize"
)
const (
	// moderationConfidenceThreshold is the minimum confidence at which a
	// Rekognition moderation label is acted upon. Rekognition reports
	// confidence as a percentage (0-100), so this is 90%, not 0.9.
	moderationConfidenceThreshold = 90.0
	// maxFileSize is the largest accepted upload, in bytes (5 MiB).
	maxFileSize = 5 * 1024 * 1024
	// minWidth and maxWidth bound the accepted width, in pixels, of the
	// (square) input image.
	minWidth = 250
	maxWidth = 1024
	// preferedSize is the width/height bound, in pixels, passed to
	// resize.Thumbnail when producing the output image.
	preferedSize = 300
)
// checkContent asks Amazon Rekognition to moderate the image stored at
// bucketName/key and returns a non-nil error if the image must be rejected.
//
// The image is rejected when any returned moderation label with a confidence
// of at least moderationConfidenceThreshold is itself, or has as its parent,
// one of the unacceptable categories listed below. An error from the
// Rekognition call is returned unchanged.
func checkContent(clientRekognition *rekognition.Rekognition, bucketName *string, key *string) error {
	mlRes, err := clientRekognition.DetectModerationLabels(&rekognition.DetectModerationLabelsInput{
		Image: &rekognition.Image{
			S3Object: &rekognition.S3Object{
				Bucket: bucketName,
				Name:   key,
			},
		},
	})
	if err != nil {
		return err
	}

	// Refer to https://docs.aws.amazon.com/rekognition/latest/dg/moderation.html#moderation-api
	isUnacceptable := func(category string) bool {
		switch category {
		case "Explicit Nudity", "Violence", "Visually Disturbing", "Hate Symbols":
			return true
		}
		return false
	}

	for _, label := range mlRes.ModerationLabels {
		// The SDK models these fields as pointers; read them nil-safely so a
		// missing field cannot panic the Lambda.
		confidence := aws.Float64Value(label.Confidence)
		if confidence < moderationConfidenceThreshold {
			continue
		}
		name := aws.StringValue(label.Name)
		if isUnacceptable(name) || isUnacceptable(aws.StringValue(label.ParentName)) {
			return fmt.Errorf("rejected file %s because it might contain %s (%f%% confidence)",
				aws.StringValue(key), name, confidence)
		}
	}
	return nil
}
// processImage validates an encoded image and returns it re-encoded as PNG,
// passed through resize.Thumbnail with a preferedSize x preferedSize bound.
//
// The input must be decodable by package image (this file registers only the
// PNG decoder via its image/png import), must be square, and must be between
// minWidth and maxWidth pixels wide (inclusive). Any decode, validation or
// encode failure is returned as an error and no bytes are produced.
func processImage(imgBytes []byte) ([]byte, error) {
	imgCfg, _, err := image.DecodeConfig(bytes.NewReader(imgBytes))
	if err != nil {
		// Without this check a malformed file yields a zero Config and a
		// misleading size error below.
		return nil, fmt.Errorf("decoding image header: %w", err)
	}
	if imgCfg.Width != imgCfg.Height {
		return nil, errors.New("image does not have an aspect ratio of 1.00")
	}
	if imgCfg.Width < minWidth || imgCfg.Width > maxWidth {
		return nil, fmt.Errorf("image is not between %d and %d pixels wide", minWidth, maxWidth)
	}

	img, _, err := image.Decode(bytes.NewReader(imgBytes))
	if err != nil {
		// Must be checked before resizing: img is nil on failure.
		return nil, fmt.Errorf("decoding image: %w", err)
	}
	img = resize.Thumbnail(preferedSize, preferedSize, img, resize.Lanczos3)

	var out bytes.Buffer
	if err := png.Encode(&out, img); err != nil {
		return nil, fmt.Errorf("encoding png: %w", err)
	}
	return out.Bytes(), nil
}
// handler is the Lambda entry point for S3 event notifications.
//
// For every record whose key ends in ".png" it:
//  1. rejects the upload if checkContent (Rekognition moderation) fails,
//  2. downloads it and rejects it if it exceeds maxFileSize bytes,
//  3. validates and resizes it with processImage,
//  4. uploads the result as a public-read PNG to OUTPUT_BUCKET_NAME under
//     the configured key prefix followed by the object's base name,
//  5. invalidates that path on CLOUDFRONT_DISTRIBUTION, if one is set.
//
// The original upload is deleted from the input bucket afterwards whether or
// not processing succeeded. Failures are logged rather than returned.
func handler(e events.S3Event) {
	outputBucketName := os.Getenv("OUTPUT_BUCKET_NAME")
	sess := session.Must(session.NewSession())
	clientS3 := s3.New(sess)
	clientRekognition := rekognition.New(sess)

	// processObject handles one uploaded object. inputKey is the object's full
	// key in the input bucket; outputName is the name used for the output key.
	processObject := func(inputBucketName, inputKey, outputName string) {
		defer func() {
			// Delete user uploaded image, regardless of the outcome.
			if _, delErr := clientS3.DeleteObject(&s3.DeleteObjectInput{
				Bucket: aws.String(inputBucketName),
				Key:    aws.String(inputKey),
			}); delErr != nil {
				log.Println(delErr.Error())
			}
		}()

		// Analyze content
		if err := checkContent(clientRekognition, &inputBucketName, &inputKey); err != nil {
			log.Println(err.Error())
			return
		}

		// Download image
		dlRes, err := clientS3.GetObject(&s3.GetObjectInput{
			Bucket: aws.String(inputBucketName),
			Key:    aws.String(inputKey),
		})
		if err != nil {
			log.Println(err.Error())
			return
		}
		// Close the body on every path, including the size rejection below.
		defer dlRes.Body.Close()

		if aws.Int64Value(dlRes.ContentLength) > int64(maxFileSize) {
			log.Printf("Size of file %s exceeds %d bytes!\n", inputKey, maxFileSize)
			return
		}
		var b bytes.Buffer
		if _, err := b.ReadFrom(dlRes.Body); err != nil {
			log.Println(err.Error())
			return
		}
		// Re-check after reading in case ContentLength was not reported.
		if b.Len() > maxFileSize {
			log.Printf("Size of file %s exceeds %d bytes!\n", inputKey, maxFileSize)
			return
		}

		// Process image
		imgBytes, err := processImage(b.Bytes())
		if err != nil {
			log.Println(err.Error())
			return
		}

		// Upload image
		keyPrefix := os.Getenv("KEY_PREFIX")
		if keyPrefix == "" {
			// Fall back to the historically misspelled variable name so
			// existing deployments that set KEY_PREIFX keep working.
			keyPrefix = os.Getenv("KEY_PREIFX")
		}
		path := keyPrefix + outputName
		if _, err := clientS3.PutObject(&s3.PutObjectInput{
			Body:        bytes.NewReader(imgBytes),
			Bucket:      aws.String(outputBucketName),
			Key:         aws.String(path),
			ContentType: aws.String("image/png"),
			ACL:         aws.String("public-read"),
		}); err != nil {
			log.Println(err.Error())
			// Nothing new was published, so there is nothing to invalidate.
			return
		}

		// Invalidate CDN cache
		cloudfrontDist := os.Getenv("CLOUDFRONT_DISTRIBUTION")
		if cloudfrontDist == "" {
			log.Println("No CloudFront distribution to invalidate")
			return
		}
		if !strings.HasPrefix(path, "/") {
			// Path must start with "/"
			path = "/" + path
		}
		if err := utils.InvalidateCache(cloudfrontDist, path); err != nil {
			log.Println(err.Error())
		}
	}

	for _, r := range e.Records {
		// TODO Maybe check if the bucket is the correct one
		// TODO Should it check that the target and source bucket are not the same?
		key := r.S3.Object.Key
		if !strings.HasSuffix(key, ".png") {
			continue
		}
		// Only the base name is used for the output key; the input object
		// itself must still be addressed by its full key.
		outputName := key[strings.LastIndex(key, "/")+1:]
		processObject(r.S3.Bucket.Name, key, outputName)
	}
}
// main hands control to the AWS Lambda Go runtime, registering handler as
// the function invoked for each incoming S3 event.
func main() {
	lambda.Start(handler)
}