-
Notifications
You must be signed in to change notification settings - Fork 0
/
websocket.go
151 lines (138 loc) · 4.21 KB
/
websocket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package tlsaudit
import (
"fmt"
"math/rand"
"net/http"
"sort"
"strings"
"time"
"github.com/adedayo/cidr"
"github.com/adedayo/tlsaudit/pkg/model"
"github.com/gorilla/websocket"
)
var (
	// allowedOrigins lists the host:port values that upgrader.CheckOrigin
	// accepts when deciding whether to upgrade a request.
	allowedOrigins = []string{
		"auditmate.local:12345",
	}

	// upgrader turns plain HTTP requests into websocket connections using
	// 1 KiB read and write buffers.
	//
	// CheckOrigin accepts a request only when its Host header exactly matches
	// an entry of allowedOrigins.
	// NOTE(review): the comparison uses r.Host, not the Origin request header —
	// confirm that this is the intended check.
	upgrader = websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin: func(r *http.Request) bool {
			host := r.Host
			for i := range allowedOrigins {
				if allowedOrigins[i] == host {
					return true
				}
			}
			return false
		},
	}
)
// mydata pairs a name with a raw byte payload.
type mydata struct {
	// Name labels the payload.
	Name string
	// Data holds the payload bytes.
	Data []byte
}
// RealtimeScan runs a scan asynchronously and streams results over a websocket.
//
// The HTTP request is upgraded to a websocket and the handler returns
// immediately; the scan itself runs in a goroutine that owns the connection
// and closes it when it finishes. The first JSON message read from the
// connection is the scan request:
//
//   - an empty ScanID starts a fresh scan: the requested CIDRs are expanded
//     into hosts, shuffled, and the request is persisted;
//   - a non-empty ScanID resumes the scan persisted under (Day, ScanID);
//     hosts whose index is below the persisted Progress are skipped.
//
// Progress and results are written to the websocket as ScanProgress messages.
// Upgrade, read and load failures are reported with println and end the scan.
func RealtimeScan(w http.ResponseWriter, req *http.Request) {
	conn, err := upgrader.Upgrade(w, req, nil)
	if err != nil {
		println(err.Error())
		return
	}

	go func() {
		// The goroutine owns the connection: release it on every exit path so
		// finished or failed scans do not leak sockets.
		defer conn.Close()

		var request tlsmodel.ScanRequest
		if err := conn.ReadJSON(&request); err != nil {
			println(err.Error())
			return
		}

		psr := tlsmodel.PersistedScanRequest{}
		if request.ScanID == "" { // start a fresh scan
			request.ScanID = GetNextScanID()

			// Kept non-nil (as before) so an empty host list stays an empty
			// list rather than becoming nil when persisted.
			hosts := []string{}
			for _, target := range request.CIDRs {
				hosts = append(hosts, expandRealtimeScanTarget(target)...)
			}
			// shuffle hosts randomly
			rand.Shuffle(len(hosts), func(i, j int) {
				hosts[i], hosts[j] = hosts[j], hosts[i]
			})

			psr.Hosts = hosts
			psr.ScanStart = time.Now()
			request.Day = psr.ScanStart.Format(dayFormat)
			psr.Request = request
			PersistScanRequest(psr)
		} else { // resume an existing scan
			var loadErr error
			psr, loadErr = LoadScanRequest(request.Day, request.ScanID)
			if loadErr != nil {
				println(loadErr.Error())
				return
			}
		}

		scanID := psr.Request.ScanID

		// callback function to stream results over a websocket
		callback := func(position int, results []tlsmodel.ScanResult, narrative string) {
			res := make([]tlsmodel.HumanScanResult, 0, len(results))
			for _, r := range results {
				res = append(res, r.ToStringStruct())
			}
			out := tlsmodel.ScanProgress{
				ScanID:      scanID,
				Progress:    100 * float32(position) / float32(len(psr.Hosts)),
				ScanResults: res,
				Narrative:   narrative,
			}
			// Best effort: a failed write does not stop the scan.
			// NOTE(review): presumably deliberate, so that results keep being
			// persisted and can be replayed by a later resume — confirm.
			_ = conn.WriteJSON(out)
		}

		streamExistingResult(psr, callback)

		for index, host := range psr.Hosts {
			if index < psr.Progress {
				continue // already covered by the persisted progress
			}
			position := index + 1

			// De-duplicate results per server+port, streaming each new one as
			// soon as it arrives.
			scan := make(map[string]tlsmodel.ScanResult)
			results := []<-chan tlsmodel.ScanResult{ScanCIDRTLS(host, request.Config)}
			for result := range MergeResultChannels(results...) {
				key := result.Server + result.Port
				if _, present := scan[key]; present {
					continue
				}
				scan[key] = result
				narrative := fmt.Sprintf("Partial scan of %s. Progress %f%% %d hosts of a total of %d in %f seconds\n",
					result.Server, 100*float32(position)/float32(len(psr.Hosts)), position, len(psr.Hosts), time.Since(psr.ScanStart).Seconds())
				callback(position, []tlsmodel.ScanResult{result}, narrative)
			}

			// Record progress before persisting this host's results.
			psr.Progress = position
			psr.ScanEnd = time.Now()
			PersistScanRequest(psr)

			var scanResults []tlsmodel.ScanResult
			for k := range scan {
				scanResults = append(scanResults, scan[k])
			}
			sort.Sort(tlsmodel.ScanResultSorter(scanResults))
			PersistScans(psr, host, scanResults)

			narrative := fmt.Sprintf("Finished scan of %s. Progress %f%% %d hosts of a total of %d in %f seconds\n",
				host, 100*float32(position)/float32(len(psr.Hosts)), position, len(psr.Hosts), psr.ScanEnd.Sub(psr.ScanStart).Seconds())
			callback(position, scanResults, narrative)
		}
	}()
}

// expandRealtimeScanTarget expands one requested target into individual host
// entries via cidr.Expand, re-attaching the port specification (if any) to
// every expanded entry as "host:ports/prefix".
func expandRealtimeScanTarget(target string) []string {
	cidrSpec, ports := splitRealtimeScanTarget(target)
	hs := cidr.Expand(cidrSpec)
	if ports == "" {
		return hs
	}
	for i, h := range hs {
		hh := strings.Split(h, "/")
		if len(hh) < 2 {
			// Entry without a prefix part: leave it untouched instead of
			// indexing past the end of hh.
			continue
		}
		hs[i] = fmt.Sprintf("%s:%s/%s", hh[0], ports, hh[1])
	}
	return hs
}

// splitRealtimeScanTarget splits a target of the form host[:ports][/prefix]
// into the CIDR string "host/prefix" (prefix defaults to "32") and the port
// specification ("" when none is given).
//
// The ports may also follow the prefix ("host/prefix:ports"); that form used
// to trigger an index-out-of-range panic and is now parsed the same way.
func splitRealtimeScanTarget(target string) (cidrSpec, ports string) {
	hostPart, prefix := target, "32"
	if parts := strings.Split(target, "/"); len(parts) > 1 {
		hostPart, prefix = parts[0], parts[1]
	}

	host := hostPart
	if fields := strings.Split(hostPart, ":"); len(fields) > 1 {
		host, ports = fields[0], fields[1]
	} else if fields := strings.Split(prefix, ":"); len(fields) > 1 {
		prefix, ports = fields[0], fields[1]
	}
	return host + "/" + prefix, ports
}