-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
135 lines (100 loc) · 3.66 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import asyncio
import logging
import os
import time
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack, RTCIceServer, RTCConfiguration
from aiortc.contrib.media import MediaPlayer, MediaRelay
from av.video.frame import VideoFrame
import numpy as np
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from starlette.requests import Request
from starlette.responses import HTMLResponse
from starlette.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
import cv2
from models.models import Offer, Settings
# Directory containing this module. Static assets and templates are resolved
# against it so the app does not depend on the process working directory.
ROOT = os.path.dirname(__file__)

app = FastAPI()
# NOTE(review): wildcard origins together with allow_credentials=True accepts
# credentialed cross-origin requests from any site; restrict allow_origins if
# the app ever starts relying on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # a sequence of origins, not a bare string
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(ROOT, "static")),
    name="static",
)
templates = Jinja2Templates(directory=os.path.join(ROOT, "templates"))

# NOTE(review): not referenced anywhere else in this module.
relay = MediaRelay()
logger = logging.getLogger("pc")

# Frame adjustments read by VideoTransformTrack.recv on every frame and
# replaced by POST /settings. The defaults are the identity transform:
# contrast multiplier 1, brightness offset 0, saturation offset 0.
CONTRAST = 1
BRIGHTNESS = 0
SATURATION = 0
class VideoTransformTrack(MediaStreamTrack):
    """Video track that re-emits frames of another track with colour tweaks.

    Each frame pulled from the wrapped track has the module-level
    ``CONTRAST``, ``BRIGHTNESS`` and ``SATURATION`` settings applied before
    it is handed on.
    """

    kind = "video"

    def __init__(self, track):
        """Wrap ``track``, the upstream video track frames are pulled from."""
        super().__init__()
        self.track = track

    async def recv(self):
        """Return the next upstream frame with the current adjustments applied.

        Contrast and brightness go through ``cv2.convertScaleAbs`` with
        ``alpha=CONTRAST`` and ``beta=BRIGHTNESS``. Saturation is an additive
        offset on the HSV S channel, clipped to 0..255.
        """
        frame = await self.track.recv()
        img = frame.to_ndarray(format="bgr24")

        img = cv2.convertScaleAbs(img, alpha=CONTRAST, beta=BRIGHTNESS)

        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        # Widen the uint8 channel before adding the offset. Plain uint8
        # arithmetic wraps around (NumPy 1) or raises for out-of-range Python
        # ints (NumPy 2), either of which defeats the clip below.
        shifted = hsv[:, :, 1].astype(np.int32) + SATURATION
        hsv[:, :, 1] = np.clip(shifted, 0, 255)
        img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

        new_frame = VideoFrame.from_ndarray(img, format="bgr24")
        # Carry over the source timing so the outgoing stream keeps its pacing.
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base
        return new_frame
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Serve the landing page rendered from the ``index.html`` template."""
    page_context = {"request": request}
    response = templates.TemplateResponse("index.html", page_context)
    return response
@app.post("/settings")
async def set_settings(settings: Settings):
    """Replace the module-level frame adjustments with the posted values.

    VideoTransformTrack.recv reads these globals on every frame, so the new
    values apply from the next frame each track processes.
    """
    global CONTRAST, BRIGHTNESS, SATURATION
    CONTRAST, BRIGHTNESS, SATURATION = (
        settings.contrast,
        settings.brightness,
        settings.saturation,
    )
    return {"message": "ok"}
def _select_video_file(video_id, video_type):
    """Return the media file name for a video selection, or None if unknown.

    Video id 1 selects from the first video's files; any other id selects
    from the second video's files.
    """
    if video_id == 1:
        choices = (("common", "video1_com.mp4"), ("subtitle", "video1_sub.mp4"))
    else:
        choices = (("common", "video2_com.MP4"), ("epilepsy", "video2_epilepsy.avi"))
    for name, filename in choices:
        if video_type == name:
            return filename
    return None


@app.post("/offer")
async def offer(params: Offer):
    """Answer a WebRTC SDP offer by streaming one of the bundled videos.

    The file is chosen from ``params.video_id`` and ``params.video_type``.
    Its audio track (if any) is forwarded unchanged, and its video track is
    wrapped in ``VideoTransformTrack``.

    Args:
        params: The client's SDP (``sdp``, ``type``) plus the video selection.

    Returns:
        A dict with the local answer's ``sdp`` and ``type``.

    Raises:
        HTTPException: 400 if the id/type combination names no known video.
    """
    from fastapi import HTTPException

    description = RTCSessionDescription(sdp=params.sdp, type=params.type)

    filename = _select_video_file(params.video_id, params.video_type)
    if filename is None:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown video: id={params.video_id!r}, type={params.video_type!r}",
        )
    # Open the media before creating/registering the peer connection, so a
    # missing or unreadable file cannot leave an orphaned connection in `pcs`.
    player = MediaPlayer(os.path.join(ROOT, filename))

    pc = RTCPeerConnection(
        configuration=RTCConfiguration(
            iceServers=[
                RTCIceServer(
                    urls="stun:stun.l.google.com:19302",
                    username="test",
                    credential="test",
                )
            ]
        )
    )
    pcs.add(pc)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        print("Connection state is %s" % pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    try:
        if player.audio:
            pc.addTrack(player.audio)
        if player.video:
            pc.addTrack(VideoTransformTrack(player.video))
        await pc.setRemoteDescription(description)
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
    except Exception:
        # Negotiation failed: release the connection instead of leaking it,
        # then let the error propagate as before.
        await pc.close()
        pcs.discard(pc)
        raise

    return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
# Live peer connections created by /offer; closed and emptied on shutdown.
pcs: set = set()
# NOTE(review): not referenced anywhere else in this module.
args: str = ''
@app.on_event("shutdown")
async def on_shutdown():
    """Close every still-registered peer connection, then forget them all."""
    closing = (connection.close() for connection in pcs)
    await asyncio.gather(*closing)
    pcs.clear()