diff --git a/admin-web/package-lock.json b/admin-web/package-lock.json index 9a3cb11..7ef2f6f 100644 --- a/admin-web/package-lock.json +++ b/admin-web/package-lock.json @@ -1,12 +1,12 @@ { "name": "admin-web", - "version": "0.0.1", + "version": "0.0.3", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "admin-web", - "version": "0.0.1", + "version": "0.0.3", "dependencies": { "@reduxjs/toolkit": "^1.8.6", "@testing-library/jest-dom": "^5.16.5", diff --git a/admin-web/src/api/authAPI.ts b/admin-web/src/api/authAPI.ts index 2e4e45b..24b6700 100644 --- a/admin-web/src/api/authAPI.ts +++ b/admin-web/src/api/authAPI.ts @@ -1,4 +1,4 @@ -import http from "./http"; +import http, {ID} from "./http"; export function checkSampleTokenCorrect(token: string) { return http.post('/auth/st/check', { @@ -7,7 +7,7 @@ export function checkSampleTokenCorrect(token: string) { } -export function getSSOInviteCode(networkId:number):Promise { +export function getSSOInviteCode(networkId:ID):Promise { // @ts-ignore return http.get(`/auth/oauth/${networkId}/device_code`, { // @ts-ignore diff --git a/admin-web/src/api/http.ts b/admin-web/src/api/http.ts index ff3f090..29b408c 100644 --- a/admin-web/src/api/http.ts +++ b/admin-web/src/api/http.ts @@ -7,10 +7,13 @@ const instance = axios.create({ timeout: 5 * 1000, }) + +export type ID = string; export interface CreatedSuccess { - id: number + id: ID } + export interface Page { total: number list: T[] diff --git a/admin-web/src/api/networkAPI.ts b/admin-web/src/api/networkAPI.ts index dd3413a..d52c029 100644 --- a/admin-web/src/api/networkAPI.ts +++ b/admin-web/src/api/networkAPI.ts @@ -1,15 +1,19 @@ -import http, {CreatedSuccess, Page} from "./http"; +import http, {CreatedSuccess, ID, Page} from "./http"; +export enum NetworkProtocol { + TCP, UDP +} export interface NetworkSetting { mtu: number, keepAlive: number, dns?: string, port: number, + protocol: NetworkProtocol, } export interface Network { - id: number, + id: ID, name: string, setting: NetworkSetting, createdAt: string, @@ -31,25 +35,26 @@ export function networkList(name: string | null, page: number = 1, pageSize: num export interface CreateNetwork { name: string, addressRange: string, + protocol: NetworkProtocol, } export function createNetwork(data: CreateNetwork) { return http.post('/network', data) } -export function updateNetwork(id: number, data: Network) { +export function updateNetwork(id: ID, data: Network) { return http.put(`/network/${id}`, data) } -export function getNetwork(id: number) { +export function getNetwork(id: ID) { return http.get(`/network/${id}`).then(r => r.data) } -export function getNetworkInviteCode(networkId: number) { +export function getNetworkInviteCode(networkId: ID) { return http.get(`/network/${networkId}/invite_code`).then(r => r.data) } -export function deleteNetwork(id:number) { +export function deleteNetwork(id:ID) { return http.delete(`/network/${id}`) } diff --git a/admin-web/src/api/nodeAPI.ts b/admin-web/src/api/nodeAPI.ts index 493a103..a9b358c 100644 --- a/admin-web/src/api/nodeAPI.ts +++ b/admin-web/src/api/nodeAPI.ts @@ -1,4 +1,4 @@ -import http, {CreatedSuccess} from "./http"; +import http, {CreatedSuccess, ID} from "./http"; export enum NodeStatus { @@ -10,11 +10,11 @@ export enum NodeType { } export interface Node { - id: number, + id: ID, nodeType: NodeType, status: NodeStatus, setting: NodeSetting, - networkId: number, + networkId: ID, name: string, } @@ -34,7 +34,7 @@ export interface UpdateNode { setting: NodeSetting, } -export function getNodeList(networkId: number, page: number = 1, pageSize: number = 10) { +export function getNodeList(networkId: ID, page: number = 1, pageSize: number = 10) { return http.get(`/node/${networkId}`, { params: { page, @@ -50,22 +50,22 @@ export interface CreateNode { setting: NodeSetting, } -export function createNode(networkId: number, data: CreateNode) { +export function createNode(networkId: ID, data: CreateNode) { return http.post(`/node/${networkId}`, data) } -export function getNode(networkId: number, nodeId: number) { +export function getNode(networkId: ID, nodeId: ID) { return http.get(`/node/${networkId}/${nodeId}`).then(r => r.data) } -export function updateNode(networkId: number, nodeId: number, data: UpdateNode) { +export function updateNode(networkId: ID, nodeId: ID, data: UpdateNode) { return http.put(`/node/${networkId}/${nodeId}`, data) } -export function getNodeActiveCode(networkId: number, nodeId: number) { +export function getNodeActiveCode(networkId: ID, nodeId: ID) { return http.get(`/node/${networkId}/${nodeId}/active_code`).then(r => r.data) } -export function updateNodeStatus(networkId: number, nodeId: number, status: NodeStatus.Forbid | NodeStatus.Normal) { +export function updateNodeStatus(networkId: ID, nodeId: ID, status: NodeStatus.Forbid | NodeStatus.Normal) { return http.put(`/node/${networkId}/${nodeId}/status`, {status}) } diff --git a/admin-web/src/view/network/CreateNetworkPage.tsx b/admin-web/src/view/network/CreateNetworkPage.tsx index 1e79ef4..6994ef3 100644 --- a/admin-web/src/view/network/CreateNetworkPage.tsx +++ b/admin-web/src/view/network/CreateNetworkPage.tsx @@ -1,9 +1,9 @@ -import {Button, Form, Input, message} from "antd"; +import {Button, Form, Input, message, Select} from "antd"; import {useIntl} from "react-intl"; -import {CreateNetwork, createNetwork} from "../../api/networkAPI"; +import {CreateNetwork, createNetwork, NetworkProtocol} from "../../api/networkAPI"; import {useNavigate} from "react-router-dom"; - +const {Option} = Select; export function CreateNetworkPage() { const [form] = Form.useForm() const navi = useNavigate() @@ -23,6 +23,9 @@ export function CreateNetworkPage() { form={form} labelCol={{span: 8}} wrapperCol={{span: 16}} + initialValues={{ + protocol: NetworkProtocol.UDP + }} > + + + diff --git a/admin-web/src/view/network/NetworkDetailPage.tsx b/admin-web/src/view/network/NetworkDetailPage.tsx index c043c9e..bd099cc 100644 --- a/admin-web/src/view/network/NetworkDetailPage.tsx +++ b/admin-web/src/view/network/NetworkDetailPage.tsx @@ -1,10 +1,11 @@ import {useEffect} from "react"; -import {getNetwork, Network, updateNetwork} from "../../api/networkAPI"; +import {getNetwork, Network, NetworkProtocol, updateNetwork} from "../../api/networkAPI"; import {useParams} from "react-router-dom"; -import {Button, Col, Form, Input, InputNumber, Row} from "antd"; +import {Button, Col, Form, Input, InputNumber, message, Row, Select} from "antd"; import {useIntl} from "react-intl"; import {useForm} from "antd/es/form/Form"; +const {Option} = Select; export default function NetworkDetailPage() { const {networkId} = useParams<{ networkId: string }>() @@ -12,14 +13,15 @@ export default function NetworkDetailPage() { const [form] = useForm() useEffect(() => { - getNetwork(parseInt(networkId!)).then((d) => { + getNetwork(networkId!).then((d) => { form.setFieldsValue(d) }) }, [networkId, form]) const submit = async () => { const data = await form.validateFields() - await updateNetwork(parseInt(networkId!), data) + await updateNetwork(networkId!, data) + message.info(intl.formatMessage({id: 'result.updateSuccess'}, {'0': intl.formatMessage({id: 'nav.network'})})) } return ( <> @@ -62,6 +64,15 @@ export default function NetworkDetailPage() { max={600}/> + + + + +
@@ -70,6 +81,7 @@ export default function NetworkDetailPage() { ) } +// add Nodes Navigator, Invite Code // // // \ No newline at end of file diff --git a/admin-web/src/view/network/NetworkListPage.tsx b/admin-web/src/view/network/NetworkListPage.tsx index 3e07008..f9cd9df 100644 --- a/admin-web/src/view/network/NetworkListPage.tsx +++ b/admin-web/src/view/network/NetworkListPage.tsx @@ -5,7 +5,7 @@ import {deleteNetwork, getNetworkInviteCode, Network, networkList} from "../../a import {useEffect, useState} from "react"; import DayjsFormat from "../../component/DayjsFormat"; import {Link, useSearchParams} from "react-router-dom"; -import {defaultPage, Page} from "../../api/http"; +import {defaultPage, ID, Page} from "../../api/http"; import {QRCodeCanvas} from "qrcode.react"; import copy from "copy-to-clipboard"; import {getPersistenceToken, getSSOInviteCode} from "../../api/authAPI"; @@ -23,13 +23,13 @@ export default function NetworkListPage() { networkList(name, page).then((d) => setData(d)) }, [name, page]) - const showInviteModel = (id: number) => { + const showInviteModel = (id: ID) => { if((getPersistenceToken()??'').startsWith('Bearer')) { getSSOInviteCode(id).then(r => setShowSSO(r)) } getNetworkInviteCode(id).then(setShowModal) } - const deleteNetworkAction = async (id: number) => { + const deleteNetworkAction = async (id: ID) => { await deleteNetwork(id) message.info(intl.formatMessage({id: 'result.deleteSuccess'}, {'0': intl.formatMessage({id: 'nav.network'})})) diff --git a/admin-web/src/view/node/CreateNodePage.tsx b/admin-web/src/view/node/CreateNodePage.tsx index 67452b1..600a5bc 100644 --- a/admin-web/src/view/node/CreateNodePage.tsx +++ b/admin-web/src/view/node/CreateNodePage.tsx @@ -16,7 +16,7 @@ export function CreateNodePage() { const intl = useIntl() useEffect(() => { if (networkId) { - getNetwork(parseInt(networkId)).then(r => setNetwork(r)) + getNetwork(networkId).then(r => setNetwork(r)) } }, [networkId]) @@ -36,7 +36,7 @@ export function CreateNodePage() { delete data.ip } - const {id} = (await createNode(parseInt(networkId as string), data)).data + const {id} = (await createNode(networkId!, data)).data message.info(intl.formatMessage({id: 'result.createSuccess'}, {'0': intl.formatMessage({id: 'nav.node'})})) navi(`/network/${networkId}/node/${id}`, {replace: true}) } diff --git a/admin-web/src/view/node/NodeDetailPage.tsx b/admin-web/src/view/node/NodeDetailPage.tsx index ee09a63..9e5ea93 100644 --- a/admin-web/src/view/node/NodeDetailPage.tsx +++ b/admin-web/src/view/node/NodeDetailPage.tsx @@ -71,8 +71,7 @@ export default function NodeDetailPage() { useEffect(() => { if (networkId && nodeId) { - const networkIdNum = parseInt(networkId) - Promise.all([getNetwork(networkIdNum), getNode(networkIdNum, parseInt(nodeId))]).then(r => { + Promise.all([getNetwork(networkId), getNode(networkId, nodeId)]).then(r => { setNetwork(r[0]) form.setFieldsValue(r[1]) }) @@ -91,7 +90,7 @@ export default function NodeDetailPage() { name: data.name, setting: data.setting, } - await updateNode(parseInt(networkId as string), parseInt(nodeId as string), updateData) + await updateNode(networkId!, nodeId!, updateData) message.info(intl.formatMessage({id: 'result.updateSuccess'}, {'0': intl.formatMessage({id: 'nav.node'})})) } diff --git a/admin-web/src/view/node/NodeListPage.tsx b/admin-web/src/view/node/NodeListPage.tsx index a13cad5..da5c5c4 100644 --- a/admin-web/src/view/node/NodeListPage.tsx +++ b/admin-web/src/view/node/NodeListPage.tsx @@ -7,6 +7,7 @@ import {useIntl} from "react-intl"; import {enumToDesc} from "../../local/intl"; import {QRCodeCanvas} from "qrcode.react"; import copy from "copy-to-clipboard"; +import {ID} from "../../api/http"; export default function NodeListPage() { const {networkId} = useParams<{ networkId: string }>() @@ -14,14 +15,14 @@ export default function NodeListPage() { const [data, setData] = useState([]) const intl = useIntl() useEffect(() => { - getNodeList(parseInt(networkId!)).then((d) => setData(d.data)) + getNodeList(networkId!).then((d) => setData(d.data)) }, [networkId]) - const showActiveModalAction = async (networkId: number, nodeId: number) => { + const showActiveModalAction = async (networkId: ID, nodeId: ID) => { const activeCode = await getNodeActiveCode(networkId, nodeId) setShowActiveCode(activeCode) } - const updateNodeStatusAction = async (networkId: number, nodeId: number, status: NodeStatus.Forbid | NodeStatus.Normal) => { + const updateNodeStatusAction = async (networkId: ID, nodeId: ID, status: NodeStatus.Forbid | NodeStatus.Normal) => { await updateNodeStatus(networkId, nodeId, status) message.info(intl.formatMessage({id: 'result.updateSuccess'}, {'0': intl.formatMessage({id: 'status'})})) getNodeList(networkId).then((d) => setData(d.data)) diff --git a/backend/README.md b/backend/README.md index 7274703..7923252 100644 --- a/backend/README.md +++ b/backend/README.md @@ -2,4 +2,5 @@ If Scala 3 is a good choice? ### TODO -- [ ] create IpArrange util to handle all, should be careful use. \ No newline at end of file +- [ ] create IpArrange util to handle all, should be careful use. +- [ ] needs async runtime(thread pool) to handle push. \ No newline at end of file diff --git a/backend/build.sbt b/backend/build.sbt index 916fbf6..ab13789 100644 --- a/backend/build.sbt +++ b/backend/build.sbt @@ -13,6 +13,9 @@ Compile / PB.targets := Seq( ) Compile / PB.protoSources += file("../protobuf") +// zio-json default value needs this +ThisBuild / scalacOptions ++= Seq("-Yretain-trees") + import Dependencies._ lazy val app = project diff --git a/backend/src/main/resources/application.conf b/backend/src/main/resources/application.conf index 8525e6d..e40a8b8 100644 --- a/backend/src/main/resources/application.conf +++ b/backend/src/main/resources/application.conf @@ -31,6 +31,7 @@ mqtt { # should set keycloak or simple. +# please set this in private.conf wehn develop. auth { # ref from keycloak config, you can download it from keycloak/realm/client #keycloak { @@ -47,5 +48,5 @@ auth { # userId: "admin" #} } - +# you can set your private config in private.conf file include "private.conf" \ No newline at end of file diff --git a/backend/src/main/scala/com/timzaak/fornet/controller/AuthController.scala b/backend/src/main/scala/com/timzaak/fornet/controller/AuthController.scala index b018164..978e345 100644 --- a/backend/src/main/scala/com/timzaak/fornet/controller/AuthController.scala +++ b/backend/src/main/scala/com/timzaak/fornet/controller/AuthController.scala @@ -3,7 +3,7 @@ package com.timzaak.fornet.controller import com.google.common.base.Charsets import com.timzaak.fornet.config.AppConfig import com.timzaak.fornet.controller.auth.AppAuthSupport -import com.timzaak.fornet.dao.{ NetworkDao, NetworkStatus } +import com.timzaak.fornet.dao.{NetworkDao, NetworkStatus} import com.timzaak.fornet.di.DI.hashId import com.typesafe.config.Config import org.hashids.Hashids diff --git a/backend/src/main/scala/com/timzaak/fornet/controller/NetworkController.scala b/backend/src/main/scala/com/timzaak/fornet/controller/NetworkController.scala index d449cc1..ad83167 100644 --- a/backend/src/main/scala/com/timzaak/fornet/controller/NetworkController.scala +++ b/backend/src/main/scala/com/timzaak/fornet/controller/NetworkController.scala @@ -3,7 +3,8 @@ package com.timzaak.fornet.controller import com.google.common.net.InetAddresses import com.timzaak.fornet.config.AppConfig import com.timzaak.fornet.controller.auth.AppAuthSupport -import com.timzaak.fornet.dao.{ DB, Network, NetworkDao, NetworkSetting } +import com.timzaak.fornet.dao.{DB, Network, NetworkDao, NetworkProtocol, NetworkSetting} +import com.timzaak.fornet.pubsub.NodeChangeNotifyService import com.typesafe.config.Config import org.hashids.Hashids import very.util.security.IntID @@ -22,7 +23,7 @@ import zio.json.{ DeriveJsonDecoder, JsonDecoder } import java.time.OffsetDateTime -case class CreateNetworkReq(name: String, addressRange: String) +case class CreateNetworkReq(name: String, addressRange: String, protocol:NetworkProtocol) given JsonDecoder[CreateNetworkReq] = DeriveJsonDecoder.gen case class UpdateNetworkReq( name: String, @@ -32,6 +33,7 @@ case class UpdateNetworkReq( given JsonDecoder[UpdateNetworkReq] = DeriveJsonDecoder.gen trait NetworkController( networkDao: NetworkDao, + nodeChangeNotifyService: NodeChangeNotifyService, appConfig: AppConfig, )(using quill: DB, config: Config, hashId: Hashids) extends Controller @@ -53,9 +55,12 @@ trait NetworkController( .filter(_.groupId == lift(groupId)) )(_.sortBy(_.id)(Ord.desc)) case _ => - pageWithCount( + val r= pageWithCount( query[Network].filter(_.status == lift(NetworkStatus.Normal)) )(_.sortBy(_.id)(Ord.desc)) + import zio.json.* + val j = r.toJson + r } } @@ -75,7 +80,7 @@ trait NetworkController( .insert( _.name -> lift(req.name), _.addressRange -> lift(req.addressRange), - _.setting -> lift(NetworkSetting()), + _.setting -> lift(NetworkSetting(protocol = req.protocol)), _.groupId -> lift(groupId), ) .returning(_.id) @@ -113,17 +118,22 @@ trait NetworkController( for { _ <- ipV4Range(data.addressRange) } yield { - quill.run { - quote { - query[Network] - .filter(n => n.id == lift(id) && n.groupId == lift(groupId)) - .update( - _.name -> lift(data.name), - _.addressRange -> lift(data.addressRange), - _.setting -> lift(data.setting), - _.updatedAt -> lift(OffsetDateTime.now()), - ) - } + networkDao.findById(id) match { + case Some(oldNetwork) => + quill.run { + quote { + query[Network] + .filter(n => n.id == lift(id) && n.groupId == lift(groupId)) + .update( + _.name -> lift(data.name), + _.addressRange -> lift(data.addressRange), + _.setting -> lift(data.setting), + _.updatedAt -> lift(OffsetDateTime.now()), + ) + } + } + nodeChangeNotifyService.networkSettingChange(oldNetwork, networkDao.findById(id).get) + case _ => } Accepted() } @@ -142,9 +152,7 @@ trait NetworkController( } } if (changeCount > 0) { - // TODO: - // kickoff all nodes in the network - // change all node status to Deleted + nodeChangeNotifyService.networkDeleteNotify(networkId) } Accepted() } diff --git a/backend/src/main/scala/com/timzaak/fornet/controller/NodeController.scala b/backend/src/main/scala/com/timzaak/fornet/controller/NodeController.scala index 6f117f8..e91a675 100644 --- a/backend/src/main/scala/com/timzaak/fornet/controller/NodeController.scala +++ b/backend/src/main/scala/com/timzaak/fornet/controller/NodeController.scala @@ -1,7 +1,7 @@ package com.timzaak.fornet.controller import com.timzaak.fornet.config.AppConfig -import com.timzaak.fornet.controller.auth.{ AppAuthSupport, User } +import com.timzaak.fornet.controller.auth.{AppAuthSupport, User} import com.timzaak.fornet.dao.* import com.timzaak.fornet.grpc.convert.EntityConvert import com.timzaak.fornet.pubsub.NodeChangeNotifyService @@ -110,9 +110,9 @@ trait NodeController( } } - if (oldNode.setting != req.setting && oldNode.status == NodeStatus.Normal) { - // notify self change - nodeChangeNotifyService.nodeInfoChangeNotify(oldNode, req.setting) + val network = networkDao.findById(oldNode.networkId).get + if (oldNode.setting != req.setting && oldNode.realStatus(network.status) == NodeStatus.Normal) { + nodeChangeNotifyService.nodeInfoChangeNotify(oldNode, req.setting, network) } Accepted() } @@ -140,26 +140,29 @@ trait NodeController( } } if (changeNumber > 0) { - nodeChangeNotifyService.nodeStatusChangeNotify( - oldNode, - oldNode.status, - req.status - ) + val network = networkDao.findById(networkId).get + if (network.status == NetworkStatus.Normal) { + nodeChangeNotifyService.nodeStatusChangeNotify( + oldNode, + oldNode.status, + req.status + ) + } } Accepted() } } get("/:networkId/:nodeId/active_code") { + val (_, networkId) = checkAuth + val nodeId = _nodeId nodeDao - .findById(_networkId, _nodeId) + .findById(networkId, nodeId) .filter(_.status == NodeStatus.Waiting) .map { _ => String( Base64.getEncoder.encode( - s"1|${config.getString("server.grpc.endpoint")}|${hashId.encode( - params("networkId").toLong - )}|${hashId.encode(params("nodeId").toLong)}".getBytes + s"1|${config.getString("server.grpc.endpoint")}|${networkId.secretId}|${nodeId.secretId}".getBytes ) ) } diff --git a/backend/src/main/scala/com/timzaak/fornet/dao/DB.scala b/backend/src/main/scala/com/timzaak/fornet/dao/DB.scala index 510e3c6..dce1cc8 100644 --- a/backend/src/main/scala/com/timzaak/fornet/dao/DB.scala +++ b/backend/src/main/scala/com/timzaak/fornet/dao/DB.scala @@ -1,20 +1,14 @@ package com.timzaak.fornet.dao -import com.timzaak.fornet.dao.Network import io.getquill.* -import io.getquill.context.jdbc.{Decoders, Encoders} -import very.util.persistence.quill.IDSupport -//import org.json4s.Extraction -//import org.json4s.JsonAST.JValue -import very.util.persistence.quill.ZIOJsonSupport -import very.util.web.Pagination +import io.getquill.context.jdbc.{ Decoders, Encoders } +import very.util.persistence.quill.{ IDSupport, PageSupport, ZIOJsonSupport } +import very.util.entity.Pagination -import java.time.{LocalDateTime, OffsetDateTime} +import java.time.{ LocalDateTime, OffsetDateTime } import java.util.Calendar -class DB - extends PostgresJdbcContext(SnakeCase, "database") - with ZIOJsonSupport with IDSupport { +class DB extends PostgresJdbcContext(SnakeCase, "database") with ZIOJsonSupport with IDSupport with PageSupport[SnakeCase] { given encodeOffsetDateTime: Encoder[OffsetDateTime] = encoder( @@ -25,8 +19,6 @@ class DB given decodeOffsetDateTime: Decoder[OffsetDateTime] = decoder((index, row, _) => row.getObject(index, classOf[OffsetDateTime])) - // import org.json4s.jvalue2extractable - // private inline def encodeJValueEntity[T]:MappedEncoding[T,JValue] = MappedEncoding[T,JValue](v => Extraction.decompose(v)(formats)) // private inline def decodeJValueEntity[T](implicit mf:scala.reflect.Manifest[T]):MappedEncoding[JValue, T] = MappedEncoding[JValue,T](_.extract[T]) // @@ -66,28 +58,7 @@ class DB MappedEncoding(_.ordinal) given decodeNetworkStatus: MappedEncoding[Int, NetworkStatus] = MappedEncoding(NetworkStatus.fromOrdinal) - - extension [T](inline q: Query[T]) { - inline def page(using pagination: Pagination) = { - q.drop(lift(pagination.offset)).take(lift(pagination.pageSize)) - } - - // warning: sortBy should be split, because PG would report error for count(*) - inline def pageWithCount(using pagination: Pagination) = { - (this.run(quote(q.page)), this.run(quote(q.size))) - } - - inline def pageWithCount( - sort: Query[T] => Query[T] - )(using pagination: Pagination) = { - (this.run(quote(sort(q).page)), this.run(quote(q.size))) - } - - // inline def pageWithPram(param:T => Boolean)(using pagination:Pagination) = { -// q.filter(param).page -// } - inline def single = q.take(1) - } + } //@main def testQuill = { diff --git a/backend/src/main/scala/com/timzaak/fornet/dao/Network.scala b/backend/src/main/scala/com/timzaak/fornet/dao/Network.scala index 150dc36..6079509 100644 --- a/backend/src/main/scala/com/timzaak/fornet/dao/Network.scala +++ b/backend/src/main/scala/com/timzaak/fornet/dao/Network.scala @@ -2,6 +2,7 @@ package com.timzaak.fornet.dao // import io.getquill.{UpdateMeta, updateMeta} +import com.timzaak.fornet.dao.NetworkProtocol.TCP import org.hashids.Hashids import very.util.persistence.quill.DBSerializer import very.util.security.IntID @@ -23,6 +24,38 @@ object NetworkStatus { } } } + +enum NetworkProtocol { + case TCP, UDP + + import com.timzaak.fornet.protobuf.config.Protocol as PProtocol + def gRPCProtocol:PProtocol = { + this match { + case TCP => PProtocol.Protocol_TCP + case UDP => PProtocol.Protocol_UDP + } + } + + given JsonEncoder[NetworkProtocol] = JsonEncoder[Int].contramap(_.ordinal) + + given JsonDecoder[NetworkProtocol] = JsonDecoder[Int].mapOrFail { e => + Try(NetworkProtocol.fromOrdinal(e)) match { + case Success(v) => Right(v) + case Failure(_) => Left("no matching NodeType enum value") + } + } +} +object NetworkProtocol { + given JsonEncoder[NetworkProtocol] = JsonEncoder[Int].contramap(_.ordinal) + + given JsonDecoder[NetworkProtocol] = JsonDecoder[Int].mapOrFail { e => + Try(NetworkProtocol.fromOrdinal(e)) match { + case Success(v) => Right(v) + case Failure(_) => Left("no matching NetworkProtocol enum value") + } + } +} + case class Network( id: IntID, name: String, @@ -31,7 +64,7 @@ case class Network( setting: NetworkSetting, status: NetworkStatus, createdAt: OffsetDateTime, - updatedAt: OffsetDateTime + updatedAt: OffsetDateTime, ) //object Network { // given networkUpdateMeta:UpdateMeta[Network] = updateMeta[Network](_.id) @@ -40,6 +73,7 @@ case class NetworkSetting( port: Int = 51820, keepAlive: Int = 30, mtu: Int = 1420, + protocol:NetworkProtocol = NetworkProtocol.UDP, dns: Option[Seq[String]] = None, ) extends DBSerializer @@ -51,9 +85,12 @@ object NetworkSetting { given JsonCodec[NetworkSetting] = DeriveJsonCodec.gen } + import io.getquill.* -class NetworkDao(using quill: DB) { - import quill.{ *, given } +import org.hashids.Hashids + +class NetworkDao(using quill: DB, hashIds:Hashids) { + import quill.{*, given} def findById(id: IntID): Option[Network] = { quill.run(quote(query[Network]).filter(_.id == lift(id)).single).headOption diff --git a/backend/src/main/scala/com/timzaak/fornet/dao/Node.scala b/backend/src/main/scala/com/timzaak/fornet/dao/Node.scala index 9d62ea8..379f618 100644 --- a/backend/src/main/scala/com/timzaak/fornet/dao/Node.scala +++ b/backend/src/main/scala/com/timzaak/fornet/dao/Node.scala @@ -13,6 +13,14 @@ enum NodeType { // Normal: Fornet Client case Client, Relay + + import com.timzaak.fornet.protobuf.config.NodeType as PNodeType + def gRPCNodeType: PNodeType = { + this match { + case NodeType.Client => PNodeType.NODE_CLIENT + case NodeType.Relay => PNodeType.NODE_RELAY + } + } } object NodeType { @@ -78,16 +86,31 @@ case class Node( } } - def peerAddress: String = { + def realStatus(networkStatus: NetworkStatus): NodeStatus = { + if (networkStatus == NetworkStatus.Delete) { + NodeStatus.Delete + } else { + status + } + } + + def peerAllowedIp: String = { nodeType match { case NodeType.Relay => ip case NodeType.Client => s"$ip/32" } } + + def peerAddress: String = { + nodeType match { + case NodeType.Relay => ip.split("/").head + case NodeType.Client => ip + } + } } object Node { - import very.util.web.json.{intIDDecoder, intIDEncoder} + import very.util.web.json.{ intIDDecoder, intIDEncoder } given nodeCCodec(using hashId: Hashids): JsonCodec[Node] = DeriveJsonCodec.gen } @@ -109,7 +132,7 @@ object NodeSetting { import io.getquill.* -class NodeDao(using quill: DB) { +class NodeDao(using quill: DB, hashids: Hashids) { import quill.{ *, given } @@ -158,6 +181,11 @@ class NodeDao(using quill: DB) { } } + def getAllAvailableNodes(networkId: IntID): Seq[Node] = quill.run { + quote { + query[Node].filter(n => n.networkId == lift(networkId) && n.status == lift(NodeStatus.Normal)) + } + } def getAllAvailableNodes( networkId: IntID, exceptNodeId: IntID, diff --git a/backend/src/main/scala/com/timzaak/fornet/di/DI.scala b/backend/src/main/scala/com/timzaak/fornet/di/DI.scala index a976ef5..450a3c1 100644 --- a/backend/src/main/scala/com/timzaak/fornet/di/DI.scala +++ b/backend/src/main/scala/com/timzaak/fornet/di/DI.scala @@ -7,18 +7,12 @@ import com.timzaak.fornet.mqtt.MqttCallbackController import com.timzaak.fornet.mqtt.api.RMqttApiClient import com.timzaak.fornet.pubsub.{MqttConnectionManager, NodeChangeNotifyService} import com.timzaak.fornet.service.* -import com.typesafe.config.{Config, ConfigFactory} -import org.hashids.Hashids import very.util.keycloak.{JWKPublicKeyLocator, JWKTokenVerifier, KeycloakJWTAuthStrategy} import very.util.web.auth.{AuthStrategy, AuthStrategyProvider, SingleUserAuthStrategy} object DI extends DaoDI { di => - given config: Config = ConfigFactory.load() object appConfig extends AppConfigImpl(config) - object hashId extends Hashids(config.getString("server.hashId"), 5) - given Hashids = hashId - // connection Manager // object connectionManager extends ConnectionManager object connectionManager @@ -75,7 +69,9 @@ object DI extends DaoDI { di => extends NetworkController( networkDao = di.networkDao, appConfig = di.appConfig, + nodeChangeNotifyService = di.nodeChangeNotifyService, ) + object nodeController extends NodeController( nodeDao = di.nodeDao, diff --git a/backend/src/main/scala/com/timzaak/fornet/di/DaoDI.scala b/backend/src/main/scala/com/timzaak/fornet/di/DaoDI.scala index aa01dd5..beedf90 100644 --- a/backend/src/main/scala/com/timzaak/fornet/di/DaoDI.scala +++ b/backend/src/main/scala/com/timzaak/fornet/di/DaoDI.scala @@ -1,12 +1,22 @@ package com.timzaak.fornet.di import com.timzaak.fornet.dao.{DB, NetworkDao, NodeDao} +import com.typesafe.config.{Config, ConfigFactory} +import org.hashids.Hashids + //import org.json4s.Formats trait DaoDI { //given formats: Formats = org.json4s.DefaultFormats + org.json4s.ext.JOffsetDateTimeSerializer + given config: Config = ConfigFactory.load() + + object hashId extends Hashids(config.getString("server.hashId"), 5) + + given Hashids = hashId + + object db extends DB given DB = db diff --git a/backend/src/main/scala/com/timzaak/fornet/grpc/AuthGRPCController.scala b/backend/src/main/scala/com/timzaak/fornet/grpc/AuthGRPCController.scala index 38890d2..2228560 100644 --- a/backend/src/main/scala/com/timzaak/fornet/grpc/AuthGRPCController.scala +++ b/backend/src/main/scala/com/timzaak/fornet/grpc/AuthGRPCController.scala @@ -44,15 +44,24 @@ class AuthGRPCController( private val mqttClientUrl = config.get[String]("mqtt.clientUrl") import quill.{ *, given } + private def errorResponse(message: String) = ActionResponse(ActionResponse.Response.Error(message)) + private def successResponse(secretId:String) = ActionResponse( + ActionResponse.Response.Success( + com.timzaak.fornet.protobuf.auth.SuccessResponse(mqttClientUrl, secretId) + ) + ) override def inviteConfirm( request: InviteConfirmRequest ): Future[ActionResponse] = { + + var params = Seq(request.networkId) if (request.nodeId.nonEmpty) { params = params.appended(request.nodeId.get) } - if (request.encrypt.exists(v => nodeAuthService.validate2(v, params))) { + + if (request.encrypt.exists(v => nodeAuthService.validate(v, params))) { val networkId = IntID(request.networkId) val publicKey = request.encrypt.get.publicKey @@ -85,21 +94,22 @@ class AuthGRPCController( NodeStatus.Waiting, NodeStatus.Normal ) - ActionResponse(true, mqttUrl = Some(mqttClientUrl)) + successResponse(node.id.secretId) + } else { - ActionResponse(message = Some("already active or error response")) + errorResponse("already active or error response") } case None => createNode(networkId, publicKey) match { - case Some(value) => ActionResponse(message = Some(value)) - case None => - ActionResponse(isOk = true, mqttUrl = Some(mqttClientUrl)) + case Left(value) => errorResponse(value) + case Right(id) => + successResponse(id.secretId) } } Future.successful(response) } else { Future.successful( - ActionResponse(message = Some("Illegal Arguments")) + errorResponse("Illegal Arguments") ) } } @@ -110,7 +120,7 @@ class AuthGRPCController( request: OAuthDeviceCodeRequest ): Future[ActionResponse] = { val params = Seq(request.accessToken, request.deviceCode, request.networkId) - if (!appConfig.enableSAAS && request.encrypt.exists(v => nodeAuthService.validate2(v, params))) { + if (!appConfig.enableSAAS && request.encrypt.exists(v => nodeAuthService.validate(v, params))) { if (config.hasPath("auth.keycloak")) { val authResult = authStrategyProvider .getStrategy(KeycloakJWTAuthStrategy.name) @@ -119,35 +129,34 @@ class AuthGRPCController( authResult match { case Left(value) => - Future.successful(ActionResponse(message = Some(value))) + Future.successful(errorResponse(value)) case Right(userId) => val publicKey = request.encrypt.get.publicKey val networkId = IntID(request.networkId) logger.info( - s"user:${userId},networkId:${networkId}, publicKey:${request.encrypt.get.publicKey} register device with code:${request.deviceCode}" + s"user:${userId},networkId:${request.networkId}, publicKey:${request.encrypt.get.publicKey} register device with code:${request.deviceCode}" ) Future.successful( createNode(networkId, publicKey) match { - case Some(value) => ActionResponse(message = Some(value)) - case None => - ActionResponse(isOk = true, mqttUrl = Some(mqttClientUrl)) + case Left(value) => errorResponse(value) + case Right(id) => successResponse(id.secretId) } ) } } else { Future.successful( - ActionResponse(message = Some("do not support keycloak now")) + errorResponse("do not support keycloak now") ) } } else { Future.successful( - ActionResponse(message = Some("Illegal Arguments")) + errorResponse("Illegal Arguments") ) } } - private def createNode(networkId: IntID, publicKey: String) = { + private def createNode(networkId: IntID, publicKey: String):Either[String, IntID] = { val network = networkDao.findById(networkId).get // network create node val usedIp = nodeDao @@ -186,11 +195,11 @@ class AuthGRPCController( } } logger.info( - s"new client:$id(${publicKey}) join network ${network.id}" + s"new client:${id.id}(${publicKey}) join network ${network.id}" ) - None + Right(id) case None => - Some("Network has no available IP") + Left("Network has no available IP") } } diff --git a/backend/src/main/scala/com/timzaak/fornet/grpc/convert/EntityConvert.scala b/backend/src/main/scala/com/timzaak/fornet/grpc/convert/EntityConvert.scala index 01ab305..5fcdac4 100644 --- a/backend/src/main/scala/com/timzaak/fornet/grpc/convert/EntityConvert.scala +++ b/backend/src/main/scala/com/timzaak/fornet/grpc/convert/EntityConvert.scala @@ -16,11 +16,10 @@ object EntityConvert { val defaultKeepAlive = network.setting.keepAlive Peer( - endpoint = nodeSetting.endpoint.map(v => - s"$v:${nodeSetting.port.getOrElse(defaultPort)}" - ), - allowedIp = Seq(node.peerAddress), + endpoint = nodeSetting.endpoint.map(v => s"$v:${nodeSetting.port.getOrElse(defaultPort)}"), + allowedIp = Seq(node.peerAllowedIp), publicKey = node.publicKey, + address = Seq(node.peerAddress), persistenceKeepAlive = nodeSetting.keepAlive.getOrElse(defaultKeepAlive), ) } @@ -43,9 +42,11 @@ object EntityConvert { mtu = Some(setting.mtu.getOrElse(nSetting.mtu)), postUp = setting.postUp, postDown = setting.postDown, + protocol = nSetting.protocol.gRPCProtocol, ), ), - peers = toPeers(relativeNodes.filter(_.id != node.id), network) + peers = toPeers(relativeNodes.filter(_.id != node.id), network), + `type` = node.nodeType.gRPCNodeType, ) } } diff --git a/backend/src/main/scala/com/timzaak/fornet/mqtt/MqttCallbackController.scala b/backend/src/main/scala/com/timzaak/fornet/mqtt/MqttCallbackController.scala index ac56f52..313089f 100644 --- a/backend/src/main/scala/com/timzaak/fornet/mqtt/MqttCallbackController.scala +++ b/backend/src/main/scala/com/timzaak/fornet/mqtt/MqttCallbackController.scala @@ -1,6 +1,6 @@ package com.timzaak.fornet.mqtt -import com.timzaak.fornet.dao.{DB, NetworkDao, NodeDao, NodeStatus} +import com.timzaak.fornet.dao.* import com.timzaak.fornet.entity.PublicKey import com.timzaak.fornet.grpc.convert.EntityConvert import com.timzaak.fornet.mqtt.api.RMqttApiClient @@ -8,13 +8,20 @@ import com.timzaak.fornet.protobuf.config.ClientMessage import com.timzaak.fornet.pubsub.MqttConnectionManager import com.timzaak.fornet.service.NodeService import com.typesafe.config.Config +import com.typesafe.scalalogging.LazyLogging +import inet.ipaddr.IPAddress.IPVersion +import inet.ipaddr.IPAddressString +import inet.ipaddr.ipv4.IPv4Address import org.hashids.Hashids -import org.scalatra.{Forbidden, Ok, ScalatraServlet} +import org.scalatra.* +import very.util.security.IntID.toIntID import very.util.web.LogSupport import very.util.web.json.{JsonResponse, ZIOJsonSupport} +import very.util.web.validate.ValidationExtra import zio.json.{DeriveJsonDecoder, JsonDecoder, jsonField} -import scala.util.{Failure, Try} +import scala.util.matching.Regex +import scala.util.{Failure, Success, Try} case class AuthRequest( clientId: String, // publicKey @@ -29,9 +36,23 @@ case class WebHookCallbackRequest( @jsonField("clientid") clientId: String, topic: String, + username: String, ) given JsonDecoder[WebHookCallbackRequest] = DeriveJsonDecoder.gen +case class AclRequest( + // 1 = sub, 2 = pub + access: String, + username: String, + ipaddr: String, + @jsonField("clientid") + clientId: String, + topic: String +) + +given JsonDecoder[AclRequest] = DeriveJsonDecoder.gen + +private val networkTopicPattern = """^network/(\w+)$""".r class MqttCallbackController( nodeDao: NodeDao, networkDao: NetworkDao, @@ -39,21 +60,22 @@ class MqttCallbackController( mqttConnectionManager: MqttConnectionManager, )(using hashId: Hashids) extends ScalatraServlet - with LogSupport + with LazyLogging with ZIOJsonSupport { jPost("/auth") { (req: AuthRequest) => import req.* val data = password.split('|') - val isOk = if (data.length != 3) { + val isOk = if (data.length == 3) { val signature = data.last - val plainText = data.dropRight(1).mkString("-") + val plainText = data.dropRight(1).mkString("|") PublicKey(clientId).validate(plainText, signature) && nodeDao .findByPublicKey(clientId) - .nonEmpty + .exists(_.id.secretId == username) } else { false } + logger.debug(s"userName:${req.username}, ${req.clientId} auth ${isOk}") if (isOk) { Ok() } else { @@ -69,7 +91,7 @@ class MqttCallbackController( // {"action":"client_subscribe","clientid":"C5yG28uwzTumy6PpBEGqvvEWLJ8dYzF1uSFGziJG6Q8Jl+DPCRZZX05MPXb/s9GWsuO2JXzADAHz70WVbD2lew==","ipaddress":"127.0.0.1:56588","node":1,"opts":{"qos":1},"topic":"client","username":"undefined"} Try { - if (action == "client_subscribe" && topic == "client") { + if (action == "client_subscribe" && topic == s"client/${req.username}") { // send wr config val nodes = nodeDao @@ -85,19 +107,23 @@ class MqttCallbackController( .toMap } nodes.foreach { node => - val notifyNodes = nodeService.getAllRelativeNodes(node) - val network = networks(node.networkId) - mqttConnectionManager.sendMessage( - networkId = node.networkId, - node.id, - clientId, - ClientMessage( - networkId = node.networkId.secretId, - ClientMessage.Info.Config( - EntityConvert.nodeToWRConfig(node, network, notifyNodes) + + val network: Network = networks(node.networkId) + if (node.realStatus(network.status) == NodeStatus.Normal) { + val notifyNodes = nodeService.getAllRelativeNodes(node) + val network = networks(node.networkId) + mqttConnectionManager.sendClientMessage( + networkId = node.networkId, + node.id, + clientId, + ClientMessage( + networkId = node.networkId.secretId, + ClientMessage.Info.Config( + EntityConvert.nodeToWRConfig(node, network, notifyNodes) + ), ) ) - ) + } } } } match { @@ -107,13 +133,46 @@ class MqttCallbackController( Ok() } - post("/superuser") { - logger.debug(s"mqtt super user does not implement ${request.body}") - Forbidden() - } + jPost("/acl") { (req: AclRequest) => + logger.debug(s"mqtt acl: ${request.body}") + // pub + val result: ActionResult = if (req.access == "2") { + val isPrivateIP = + Try(IPAddressString(req.ipaddr).toAddress(IPVersion.IPV4).asInstanceOf[IPv4Address].isPrivate) match { + case Success(v) => v + case _ => false + } + if (isPrivateIP) { + Ok("allow") + } else { + Forbidden("deny") + } + // sub + } else if (req.access == "1") { + Try(req.username.toIntID).fold( + _ => Forbidden("allow"), + { id => + req.topic match { + case networkTopicPattern(secretId) => + Try(secretId.toIntID).fold( + _ => Forbidden(), + { networkId => + if (nodeDao.findById(networkId, id).nonEmpty) { + Ok("allow") + } else { + Forbidden("deny") + } + } + ) - post("/acl") { - logger.debug(s"mqtt acl does not implement,body: ${request.body}") - Ok() + case s"client/${id.secretId}" => Ok("allow") + case _ => Forbidden("deny") + } + } + ) + } else { + Forbidden("deny") + } + result } } diff --git a/backend/src/main/scala/com/timzaak/fornet/pubsub/MqttConnectionManager.scala b/backend/src/main/scala/com/timzaak/fornet/pubsub/MqttConnectionManager.scala index 7614da1..b758e8b 100644 --- a/backend/src/main/scala/com/timzaak/fornet/pubsub/MqttConnectionManager.scala +++ b/backend/src/main/scala/com/timzaak/fornet/pubsub/MqttConnectionManager.scala @@ -18,19 +18,24 @@ class MqttConnectionManager( private def encodeMessage(message: GeneratedMessage) = Base64.getEncoder.encodeToString(message.toByteArray) - def sendMessage(networkId: IntID, message: NetworkMessage): Try[Boolean] = { - logTry(s"send message[Network:$networkId] failure")( + def sendNetworkMessage( + networkId: IntID, + message: NetworkMessage, + retain: Option[Boolean] = Some(false) + ): Try[Boolean] = { + logTry(s"send message[Network:${networkId.id}] failure")( mqttApiClient.publish( PublishRequest( payload = encodeMessage(message), qos = Some(1), encoding = Some("base64"), - topic = s"network/${networkId.secretId}" + topic = s"network/${networkId.secretId}", + retain = retain, ) ) ) } - def sendMessage( + def sendClientMessage( networkId: IntID, nodeId: IntID, publicKey: String, @@ -44,7 +49,7 @@ class MqttConnectionManager( clientId = Some(publicKey), qos = Some(1), encoding = Some("base64"), - topic = "client", + topic = s"client/${nodeId.secretId}", retain, ) ) diff --git a/backend/src/main/scala/com/timzaak/fornet/pubsub/NodeChangeNotifyService.scala b/backend/src/main/scala/com/timzaak/fornet/pubsub/NodeChangeNotifyService.scala index fd52745..dff44a5 100644 --- a/backend/src/main/scala/com/timzaak/fornet/pubsub/NodeChangeNotifyService.scala +++ b/backend/src/main/scala/com/timzaak/fornet/pubsub/NodeChangeNotifyService.scala @@ -1,10 +1,11 @@ package com.timzaak.fornet.pubsub -import com.timzaak.fornet.dao.* +import com.timzaak.fornet.dao.{NetworkDao, *} import com.timzaak.fornet.grpc.convert.EntityConvert -import com.timzaak.fornet.protobuf.config.{ NodeStatus as PNodeStatus, * } +import com.timzaak.fornet.protobuf.config.{NetworkStatus as PNetworkStatus, NodeStatus as PNodeStatus, NodeType as PNodeType, *} import com.timzaak.fornet.service.NodeService import org.hashids.Hashids +import very.util.security.IntID class NodeChangeNotifyService( nodeDao: NodeDao, @@ -16,10 +17,8 @@ class NodeChangeNotifyService( import quill.{ *, given } - def nodeInfoChangeNotify(oldNode: Node, setting: NodeSetting) = { + def nodeInfoChangeNotify(oldNode: Node, setting: NodeSetting, network: Network) = { // TODO: FIXIT - - val network = networkDao.findById(oldNode.networkId).get val networkId = network.id.secretId val relativeNodes = nodeService.getAllRelativeNodes(oldNode) @@ -28,7 +27,7 @@ class NodeChangeNotifyService( val wrConfig: WRConfig = EntityConvert.nodeToWRConfig(fixedNode, network, relativeNodes) - connectionManager.sendMessage( + connectionManager.sendClientMessage( oldNode.networkId, oldNode.id, oldNode.publicKey, @@ -40,7 +39,7 @@ class NodeChangeNotifyService( // only keep alive matter case NodeType.Relay => // notify other nodes in network that relay change. - connectionManager.sendMessage( + connectionManager.sendNetworkMessage( fixedNode.networkId, NetworkMessage( networkId = networkId, @@ -54,6 +53,36 @@ class NodeChangeNotifyService( } } + // network must be in normal status + def networkSettingChange(oldNetwork: Network, newNetwork: Network): Unit = { + // only care about protocol, others will trigger push in future version.(after solved async push) + if (oldNetwork.setting.protocol != newNetwork.setting.protocol && newNetwork.status == NetworkStatus.Normal) { + val nodes = nodeDao.getAllAvailableNodes(oldNetwork.id).toList + for ((node, relativeNodes) <- nodeService.getNetworkAllRelativeNodes(nodes)) { + val wrConfig = EntityConvert.nodeToWRConfig(node, newNetwork, relativeNodes) + // this would trigger all nodes restart. + connectionManager.sendClientMessage( + node.networkId, + node.id, + node.publicKey, + ClientMessage(networkId = newNetwork.id.secretId, ClientMessage.Info.Config(wrConfig)) + ) + } + + } + } + + // PS: Network would never recover from delete status + def networkDeleteNotify(networkId: IntID): Unit = { + connectionManager.sendNetworkMessage( + networkId, + NetworkMessage( + networkId = networkId.secretId, + NetworkMessage.Info.Status(PNetworkStatus.NETWORK_DELETE) + ) + ) + } + def nodeStatusChangeNotify( node: Node, oldStatus: NodeStatus, @@ -62,19 +91,19 @@ class NodeChangeNotifyService( import NodeStatus.* val networkId = node.networkId.secretId // notify self node status change - connectionManager.sendMessage( + connectionManager.sendClientMessage( node.networkId, node.id, node.publicKey, ClientMessage( networkId = networkId, - ClientMessage.Info.Status(status.gRPCNodeStatus) + ClientMessage.Info.Status(status.gRPCNodeStatus), ) ) (oldStatus, status) match { case (Normal, _) => - connectionManager.sendMessage( + connectionManager.sendNetworkMessage( node.networkId, NetworkMessage( networkId = networkId, @@ -90,7 +119,7 @@ class NodeChangeNotifyService( val network = networkDao.findById(node.networkId).get val peer = EntityConvert.toPeer(node, network) - connectionManager.sendMessage( + connectionManager.sendNetworkMessage( node.networkId, NetworkMessage( networkId = networkId, @@ -102,7 +131,7 @@ class NodeChangeNotifyService( val notifyNodes = nodeService.getAllRelativeNodes(node) - connectionManager.sendMessage( + connectionManager.sendClientMessage( node.networkId, node.id, node.publicKey, @@ -110,12 +139,11 @@ class NodeChangeNotifyService( networkId = networkId, ClientMessage.Info.Config( EntityConvert.nodeToWRConfig(node, network, notifyNodes) - ) + ), ) ) case _ => // do nothing. } - } } diff --git a/backend/src/main/scala/com/timzaak/fornet/service/NodeAuthService.scala b/backend/src/main/scala/com/timzaak/fornet/service/NodeAuthService.scala index 028b06c..212f157 100644 --- a/backend/src/main/scala/com/timzaak/fornet/service/NodeAuthService.scala +++ b/backend/src/main/scala/com/timzaak/fornet/service/NodeAuthService.scala @@ -22,40 +22,9 @@ object GRPCAuthRequest { case class GRPCAuth(publicKey: PublicKey, networkId: Int) +//TODO: This should not be service, change it to object class NodeAuthService(using hashId: Hashids) { - - // import quill.{ given, * } - - /* def validate(grpcAuth: GRPCAuth): Either[Status, NodeIdentity] = { - if (NodeAuthService.validate(grpcAuth)) { - nodeDao - .findIdByPublicKey(grpcAuth.publicKey.key, grpcAuth.networkId) - .map(NodeIdentity(grpcAuth.networkId, _)) - .toRight( - Status.NOT_FOUND.withDescription("Could not find Node") - ) - } else { - Left(Status.INVALID_ARGUMENT.withDescription("Invalid auth")) - } - } */ - - @deprecated - def validate(grpcAuth: GRPCAuthRequest): Option[Int] = { - import grpcAuth.* - val plainText = s"$timestamp-$networkId-$nonce" - if (publicKey.validate(plainText, sign)) { - val hashIds = hashId.decode(networkId) - if (hashIds.size == 1) { - Some(hashIds.head.toInt) - } else { - None - } - } else { - None - } - } - - def validate2( + def validate( encrypt: EncryptRequest, params: Seq[String], ): Boolean = { @@ -67,3 +36,4 @@ class NodeAuthService(using hashId: Hashids) { }.getOrElse(false) } } + diff --git a/backend/src/main/scala/com/timzaak/fornet/service/NodeService.scala b/backend/src/main/scala/com/timzaak/fornet/service/NodeService.scala index 5c88b93..2f1f2eb 100644 --- a/backend/src/main/scala/com/timzaak/fornet/service/NodeService.scala +++ b/backend/src/main/scala/com/timzaak/fornet/service/NodeService.scala @@ -28,5 +28,20 @@ class NodeService(nodeDao: NodeDao)(using quill: DB, hashId: Hashids) { }) } + def getNetworkAllRelativeNodes(nodes:List[Node]):List[(Node,List[Node])] = { + val relayNodes = nodes.filter(_.nodeType == NodeType.Relay) + val clientNodes = nodes.filter(_.nodeType == NodeType.Client) + + nodes.map{ node => + val nodeIp = IPAddressString(node.ip) + node -> ((relayNodes.filter(rNode => IPAddressString(rNode.ip).prefixContains(nodeIp)) ++ ( + node.nodeType match { + case NodeType.Relay => + clientNodes.filter(cNode => nodeIp.prefixContains(IPAddressString(cNode.ip))) + case NodeType.Client => + List.empty + })).filter(_.id != node.id)) + } + } } diff --git a/backend/src/main/scala/very/util/entity/Pagination.scala b/backend/src/main/scala/very/util/entity/Pagination.scala new file mode 100644 index 0000000..fd99531 --- /dev/null +++ b/backend/src/main/scala/very/util/entity/Pagination.scala @@ -0,0 +1,9 @@ +package very.util.entity + +case class Pagination(page: Int, pageSize: Int) { + assert(pageSize <= 50) + assert(page > 0) + def offset: Int = (page - 1) * pageSize + def limit: Int = pageSize + +} diff --git a/backend/src/main/scala/very/util/persistence/quill/IDSupport.scala b/backend/src/main/scala/very/util/persistence/quill/IDSupport.scala index b619ea4..0f9930e 100644 --- a/backend/src/main/scala/very/util/persistence/quill/IDSupport.scala +++ b/backend/src/main/scala/very/util/persistence/quill/IDSupport.scala @@ -9,7 +9,6 @@ trait IDSupport { this: JdbcContextTypes[PostgresDialect, _] => given intIDEncode: MappedEncoding[IntID, Int] = MappedEncoding(_.id) - given intIDDecode(using hashId: Hashids): MappedEncoding[Int, IntID] = MappedEncoding(IntID.apply) given intIDListEncoder: MappedEncoding[List[IntID], List[Int]] = MappedEncoding(_.map(_.id)) diff --git a/backend/src/main/scala/very/util/persistence/quill/PageSupport.scala b/backend/src/main/scala/very/util/persistence/quill/PageSupport.scala new file mode 100644 index 0000000..d7d70fd --- /dev/null +++ b/backend/src/main/scala/very/util/persistence/quill/PageSupport.scala @@ -0,0 +1,33 @@ +package very.util.persistence.quill + +import io.getquill.context.jdbc.JdbcContextTypes +import io.getquill.* +import very.util.entity.Pagination + +trait PageSupport[+N <: NamingStrategy] { + this: PostgresJdbcContext[N] => + + + extension[T] (inline q: Query[T]) { + inline def page(using pagination: Pagination) = { + q.drop(lift(pagination.offset)).take(lift(pagination.pageSize)) + } + + // warning: sortBy should be split, because PG would report error for count(*) + inline def pageWithCount(using pagination: Pagination) = { + (this.run(quote(q.page)), this.run(quote(q.size))) + } + + inline def pageWithCount( + sort: Query[T] => Query[T] + )(using pagination: Pagination) = { + (this.run(quote(sort(q).page)), this.run(quote(q.size))) + } + + // inline def pageWithPram(param:T => Boolean)(using pagination:Pagination) = { + // q.filter(param).page + // } + inline def single = q.take(1) + } + +} diff --git a/backend/src/main/scala/very/util/persistence/quill/ZIOJsonSupport.scala b/backend/src/main/scala/very/util/persistence/quill/ZIOJsonSupport.scala index fe61f29..3ae8f7b 100644 --- a/backend/src/main/scala/very/util/persistence/quill/ZIOJsonSupport.scala +++ b/backend/src/main/scala/very/util/persistence/quill/ZIOJsonSupport.scala @@ -22,9 +22,8 @@ trait ZIOJsonSupport { } given decodeJsonb[T<:DBSerializer](using JsonDecoder[T]):Decoder[T] = - decoder{(index,row, _) => + decoder{(index,row, session) => val data = row.getString(index) - // println(s"data convert:${data},${data.fromJson[T]}") data.fromJson[T].toOption.get } } diff --git a/backend/src/main/scala/very/util/web/Controller.scala b/backend/src/main/scala/very/util/web/Controller.scala index ce6c39f..d5124a0 100644 --- a/backend/src/main/scala/very/util/web/Controller.scala +++ b/backend/src/main/scala/very/util/web/Controller.scala @@ -1,5 +1,6 @@ package very.util.web +import com.typesafe.scalalogging.LazyLogging import org.scalatra.json.JacksonJsonSupport import org.scalatra.* //import org.json4s.Formats @@ -14,7 +15,7 @@ class Controller //(using val jsonFormats: Formats) with I18nSupport with ValidationExtra with PaginationSupport - with LogSupport { + with LazyLogging { override def defaultFormat: Symbol = Symbol("txt") def badResponse(msg: String): ActionResult = { contentType = formats("txt") @@ -39,12 +40,14 @@ class Controller //(using val jsonFormats: Formats) super.renderPipeline(info) }: RenderPipeline) orElse super.renderPipeline + /* def created(id: Long): ActionResult = { contentType = formats("json") Created(s"""{"id":$id}""") } + */ def created(id: very.util.security.ID[_]): ActionResult = { contentType = formats("json") - Created(s"""{"id":${id.secretId}}""") + Created(s"""{"id":"${id.secretId}"}""") } } diff --git a/backend/src/main/scala/very/util/web/PaginationSupport.scala b/backend/src/main/scala/very/util/web/PaginationSupport.scala index 8594be8..e19b04d 100644 --- a/backend/src/main/scala/very/util/web/PaginationSupport.scala +++ b/backend/src/main/scala/very/util/web/PaginationSupport.scala @@ -1,14 +1,6 @@ package very.util.web -import com.timzaak.fornet.dao.DB - -case class Pagination(page: Int, pageSize: Int) { - assert(pageSize <= 50) - assert(page > 0) - def offset: Int = (page - 1) * pageSize - def limit: Int = pageSize - -} +import very.util.entity.Pagination trait PaginationSupport { this: org.scalatra.ScalatraBase => @@ -19,12 +11,12 @@ trait PaginationSupport { this: org.scalatra.ScalatraBase => given pagination: Pagination = Pagination(page, pageSize) inline def search[T]( - arguments: Map[String, String => T => Boolean] - ): Iterable[T => Boolean] = { + arguments: Map[String, String => T => Boolean] + ): Iterable[T => Boolean] = { for { (k, func) <- arguments value <- params.get(k) if value.nonEmpty } yield func(value) - + } } diff --git a/backend/src/main/scala/very/util/web/json/ZIOJsonSupport.scala b/backend/src/main/scala/very/util/web/json/ZIOJsonSupport.scala index 131283f..399f1d0 100644 --- a/backend/src/main/scala/very/util/web/json/ZIOJsonSupport.scala +++ b/backend/src/main/scala/very/util/web/json/ZIOJsonSupport.scala @@ -98,4 +98,4 @@ given intIDDecoder(using hashId: Hashids): JsonDecoder[IntID] = JsonDecoder[Stri case Failure(_) => Left("Invalid ID") } } -given intIDEncoder(using hashId: Hashids): JsonEncoder[IntID] = JsonEncoder.int.contramap(_.id) +given intIDEncoder: JsonEncoder[IntID] = JsonEncoder.string.contramap(_.secretId) diff --git a/backend/src/test/scala/very/util/practice/JsonSuite.scala b/backend/src/test/scala/very/util/practice/JsonSuite.scala index 9961b3d..86394f3 100644 --- a/backend/src/test/scala/very/util/practice/JsonSuite.scala +++ b/backend/src/test/scala/very/util/practice/JsonSuite.scala @@ -3,7 +3,7 @@ package very.util.practice //import com.fasterxml.jackson.annotation.JsonFormat import com.timzaak.fornet.dao.{NetworkSetting, NodeSetting} import munit.FunSuite -import org.json4s.Extraction + import zio.json.* import zio.json.ast.Json @@ -23,10 +23,10 @@ class JsonSuite extends FunSuite { DeriveJsonDecoder.gen[NodeSetting] given JsonEncoder[NodeSetting] = DeriveJsonEncoder.gen[NodeSetting] - test("jsonFormat, encode") { - val a = Extraction.decompose(NodeSetting())(org.json4s.DefaultFormats) - println(a) - } +// test("jsonFormat, encode") { +// val a = Extraction.decompose(NodeSetting())(org.json4s.DefaultFormats) +// println(a) +// } test("zio-json") { val a = NodeSetting(port = Some(1)).toJson @@ -43,6 +43,11 @@ class JsonSuite extends FunSuite { println(a.fromJson[TestEnum]) } + test("json default value") { + val str = """{"mtu": 1420, "port": 51820, "keepAlive": 30}""" + println(s"result: ${str.fromJson[NetworkSetting]}") + } + /* test("jresponse") { val data = PageJResponse(10,List("1","2")) println(data.toJson) diff --git a/backend/src/test/scala/very/util/practice/RegexSuite.scala b/backend/src/test/scala/very/util/practice/RegexSuite.scala new file mode 100644 index 0000000..4971b26 --- /dev/null +++ b/backend/src/test/scala/very/util/practice/RegexSuite.scala @@ -0,0 +1,25 @@ +package very.util.practice + +import com.typesafe.scalalogging.LazyLogging +import munit.FunSuite + +import scala.util.matching.Regex + +class RegexSuite extends FunSuite with LazyLogging { + private val networkTopicPattern = """^network/(\w+)$""".r + test("regex") { + val ID = "ewf014XF" + val topic = s"network/xxx" + topic match { + case s"network/${ID}" => println(s"xx:$ID") + case networkTopicPattern(secretId) => + println(secretId) + case _ => println("should not come here") + } + } + + test("stripSuffix") { + logger.info("bbb") + println("com.timzaak.test$controller".stripSuffix("$")) + } +} diff --git a/client/README.md b/client/README.md index 7e91178..8926e53 100644 --- a/client/README.md +++ b/client/README.md @@ -1,5 +1,5 @@ -some code in device is from `BoringTun`. I reimplement it for Tokio. +some code in device is from `BoringTun`. I refactor it with Tokio. diff --git a/client/lib/src/api/mod.rs b/client/lib/src/api/mod.rs index 999302c..faf51b6 100644 --- a/client/lib/src/api/mod.rs +++ b/client/lib/src/api/mod.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::env; use std::path::PathBuf; use std::sync::Arc; @@ -6,19 +5,19 @@ use anyhow::{anyhow, bail}; use serde_derive::{Deserialize, Serialize}; use tokio::io::AsyncWriteExt; use tokio::sync::mpsc::Sender; -use crate::config::{Config, Identity, ServerConfig}; +use crate::config::{Config, Identity, NodeInfo, ServerConfig}; use crate::sc_manager::SCManager; -use crate::protobuf::auth::{auth_client::AuthClient, InviteConfirmRequest, OAuthDeviceCodeRequest, SsoLoginInfoRequest}; +use crate::protobuf::auth::{auth_client::AuthClient, InviteConfirmRequest, OAuthDeviceCodeRequest, SsoLoginInfoRequest, SuccessResponse}; use crate::server_api::APISocket; use crate::server_manager::{ServerManager, ServerMessage}; use std::time::Duration; -use auto_launch_extra::AutoLaunchBuilder; use cfg_if::cfg_if; use tonic::{ transport::Channel, Request, }; use crate::{APP_NAME, MAC_OS_PACKAGE_NAME}; +use crate::protobuf::auth::action_response::Response; pub mod command_api; @@ -69,14 +68,12 @@ async fn join_network(server_manager: &mut ServerManager, invite_code: &str, str ) .await { - Ok(mqtt_url) => { - let mut map = HashMap::new(); - map.insert(invite_token.network_id, mqtt_url); + Ok(resp) => { //This must be success let server_config = ServerConfig { server: invite_token.endpoint, - mqtt: map, + info: vec![NodeInfo {network_id: invite_token.network_id, mqtt_url:resp.mqtt_url, node_id: resp.client_id}], }; server_config.save_config(&config_dir)?; change_config_and_init_sync_manger(); @@ -89,12 +86,10 @@ async fn join_network(server_manager: &mut ServerManager, invite_code: &str, str } else if version == 2u32 { // keycloak login let (mut client,sso_login) = SSOLogin::get_login_info(data).await?; match handle_oauth(identity, &mut client, &sso_login, stream).await { - Ok(mqtt_url) => { - let mut map = HashMap::new(); - map.insert(sso_login.network_id, mqtt_url); + Ok(resp) => { let server_config = ServerConfig { server: sso_login.endpoint, - mqtt: map, + info: vec![NodeInfo {network_id: sso_login.network_id.clone(), mqtt_url: resp.mqtt_url, node_id: resp.client_id}], }; server_config.save_config(&config_dir)?; change_config_and_init_sync_manger(); @@ -219,7 +214,7 @@ struct OAuthDeviceJWToken { } // https://github.com/keycloak/keycloak-community/blob/main/design/oauth2-device-authorization-grant.md -async fn handle_oauth(identity: Identity, client:&mut AuthClient, sso_login: &SSOLogin, stream: &mut APISocket) -> anyhow::Result { +async fn handle_oauth(identity: Identity, client:&mut AuthClient, sso_login: &SSOLogin, stream: &mut APISocket) -> anyhow::Result { let network_id = sso_login.network_id.clone(); @@ -252,11 +247,11 @@ async fn handle_oauth(identity: Identity, client:&mut AuthClient, sso_l network_id, encrypt:Some(encrypt), }); - let response= client.oauth_device_code_confirm(request).await?.into_inner(); - return if response.is_ok { - Ok(response.mqtt_url.unwrap().clone()) - } else { - Err(anyhow!(response.message.unwrap())) + let response= client.oauth_device_code_confirm(request).await?.into_inner().response; + return match response { + Some(Response::Error(message)) => Err(anyhow!(message)), + Some(Response::Success(resp))=> Ok(resp), + _ => Err(anyhow!("analyse auth response error")), } } else { tracing::debug!("check login status: not login, will try to check after {} seconds...", response.interval + 1); @@ -271,7 +266,7 @@ async fn server_invite_confirm( endpoint: &String, network_id: &String, node_id: Option, -) -> anyhow::Result { +) -> anyhow::Result { tracing::debug!("endpoint: {endpoint}"); let channel = Channel::from_shared(endpoint.clone())?.connect().await?; let mut client = AuthClient::new(channel); @@ -291,11 +286,10 @@ async fn server_invite_confirm( }); let response = client.invite_confirm(request).await?; - let response = response.into_inner(); - if response.is_ok { - Ok(response.mqtt_url.unwrap().clone()) - } else { - Err(anyhow!(response.message.unwrap())) + match response.into_inner().response { + Some(Response::Error(message)) => Err(anyhow!(message)), + Some(Response::Success(resp))=> Ok(resp), + _ => Err(anyhow!("analyse auth response error")), } } diff --git a/client/lib/src/config.rs b/client/lib/src/config.rs index 90c9826..a32e965 100644 --- a/client/lib/src/config.rs +++ b/client/lib/src/config.rs @@ -235,11 +235,18 @@ impl Debug for Identity { ) } } +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct NodeInfo { + pub network_id: String, + pub mqtt_url: String, + pub node_id: String, +} #[derive(Deserialize, Serialize, Debug)] pub struct ServerConfig { pub server: String, - pub mqtt: HashMap + //networkId, mqttUrl, clientId + pub info: Vec } /* impl default for serverconfig { diff --git a/client/lib/src/device/mod.rs b/client/lib/src/device/mod.rs index 5c2c4b7..e7dc2db 100644 --- a/client/lib/src/device/mod.rs +++ b/client/lib/src/device/mod.rs @@ -1,6 +1,6 @@ mod allowed_ips; pub mod peer; -mod udp_network; +mod tunnel; mod tun; pub mod auto_launch; pub mod script_run; @@ -17,31 +17,36 @@ cfg_if! { } use std::collections::HashMap; +use std::mem; use cfg_if::cfg_if; use rand::RngCore; use rand::rngs::OsRng; -use std::net::SocketAddr; -use std::sync::Arc; -use std::time::Duration; +use std::net::{IpAddr, SocketAddr}; +use std::sync::{Arc}; +use std::time::{Duration, SystemTime}; use boringtun::noise::errors::WireGuardError; use boringtun::noise::rate_limiter::RateLimiter; use boringtun::noise::{Packet, Tunn, TunnResult}; use boringtun::noise::handshake::parse_handshake_anon; use prost::bytes::BufMut; -use tokio::net::UdpSocket; +use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio::sync::{Mutex, RwLock}; use tokio::time; -use tokio::io::AsyncWriteExt;//keep +use tokio::io::{AsyncReadExt, AsyncWriteExt};//keep +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use allowed_ips::AllowedIps; use peer::{AllowedIP, Peer}; use script_run::Scripts; +use crate::device::peer::TcpConnection; use crate::device::script_run::run_opt_script; +use crate::protobuf::config::NodeType; use self::tun::WritePart; const HANDSHAKE_RATE_LIMIT: u64 = 100; // The number of handshakes per second we can tolerate before using cookies const MAX_UDP_SIZE: usize = (1 << 16) - 1; +const MAX_TCP_SIZE: usize = (1 << 16) -1; // const MAX_ITR: usize = 100; // Number of packets to handle per handler call #[derive(Debug)] @@ -177,6 +182,7 @@ impl DeviceData { endpoint: Option, allowed_ips: &[AllowedIP], keepalive: Option, + ip: IpAddr, preshared_key: Option<[u8; 32]>, ) { // Update an existing peer @@ -198,7 +204,7 @@ impl DeviceData { ) .unwrap(); - let peer = Peer::new(tunn, next_index, endpoint, allowed_ips, preshared_key); + let peer = Peer::new(tunn, next_index, endpoint, allowed_ips, ip, preshared_key); let peer = Arc::new(Mutex::new(peer)); let mut peers = self.peers.write().await; @@ -249,6 +255,34 @@ pub async fn tun_read_handle(peers: &Arc>, udp4: &UdpSocket, udp6: } else { tracing::error!("No endpoint"); } + //TODO: get tcp socket from peers and send + } + _ => panic!("Unexpected result from encapsulate"), + }; + } + } +} + +pub async fn tun_read_tcp_handle(peers: &Arc>, src_buf: &[u8], dst_buf: &mut [u8]) { + //tracing::debug!("tun read:{:x?},{:?}", Tunn::dst_address(src_buf), src_buf); + if let Some(dst_addr) = Tunn::dst_address(src_buf) { + if let Some(peer) = peers.read().await.by_ip.find(dst_addr) { + let mut peer = peer.lock().await; + match peer.tunnel.encapsulate(src_buf, &mut dst_buf[..]) { + TunnResult::Done => { + // tracing::debug!("done"); + } + TunnResult::Err(e) => { + tracing::error!(message = "Encapsulate error", error = ?e) + } + TunnResult::WriteToNetwork(packet) => { + let endpoint = peer.endpoint(); + if let TcpConnection::Connected(conn) = &mut endpoint.tcp_conn { + //TODO: error detect + let _ = conn.write_all(packet).await; + } else { + tracing::info!("no endpoint of {:?}", endpoint.addr); + } } _ => panic!("Unexpected result from encapsulate"), }; @@ -285,6 +319,7 @@ pub async fn peers_timer(peers: &Arc>, udp4: &UdpSocket, udp6: &Ud } TunnResult::Err(e) => tracing::error!(message = "Timer error", error = ?e), TunnResult::WriteToNetwork(packet) => { + let _ = match endpoint_addr { SocketAddr::V4(_) => udp4.send_to(packet, endpoint_addr).await, SocketAddr::V6(_) => udp6.send_to(packet, endpoint_addr).await, @@ -296,6 +331,70 @@ pub async fn peers_timer(peers: &Arc>, udp4: &UdpSocket, udp6: &Ud } } +pub async fn tcp_peers_timer( + ip: &IpAddr, + peers: &Arc>, + key_pair: Arc<(x25519_dalek::StaticSecret, x25519_dalek::PublicKey)>, + rate_limiter: Arc, + iface: Arc>, + pi: bool, + node_type: NodeType, +) { + let mut interval = time::interval(Duration::from_millis(250)); + let mut dst_buf: Vec= vec![0; MAX_UDP_SIZE]; + + loop { + interval.tick().await; + let peer_map = &peers.read().await.by_key; + for peer in peer_map.values() { + let mut p = peer.lock().await; + let endpoint_addr = match p.endpoint().addr { + Some(addr) => addr, + None => continue, + }; + match &mut p.endpoint.tcp_conn { + TcpConnection::Nothing | TcpConnection::ConnectedFailure(_) => { + if node_type == NodeType::NodeClient || ip < &p.ip { + p.endpoint.tcp_conn = TcpConnection::Connecting(SystemTime::now()); + match TcpStream::connect(&endpoint_addr).await { + Ok(conn) => { + let (reader, writer) = conn.into_split(); + p.endpoint.tcp_conn = TcpConnection::Connected(writer); + tcp_handler(reader, WriterState::PeerWriter(peer.clone()), endpoint_addr, key_pair.clone(), rate_limiter.clone(), peers.clone(), iface.clone(), pi); + }, + Err(error) => { + tracing::debug!("connect {endpoint_addr:?} failure, error: {error:?}"); + p.endpoint.tcp_conn = TcpConnection::ConnectedFailure(error) + } + }; + } + continue; + } + TcpConnection::Connecting(_) => { + //TODO: add check of time, and reconnect + continue; + } + _ => {} + }; + match p.update_timers(&mut dst_buf) { + TunnResult::Done => {} + TunnResult::Err(WireGuardError::ConnectionExpired) => { + tracing::debug!("connection expired, should shutdown this endpoint"); + p.shutdown_endpoint(); + } + TunnResult::Err(e) => tracing::error!(message = "Timer error", error = ?e), + TunnResult::WriteToNetwork(packet) => { + if let TcpConnection::Connected(connection) = &mut p.endpoint.tcp_conn { + let _ = connection.write_all(packet).await; + } + + } + _ => tracing::warn!("Unexpected result from update_timers"), + }; + } + } +} + pub async fn udp_handler(udp: &UdpSocket, key_pair: &(x25519_dalek::StaticSecret, x25519_dalek::PublicKey), @@ -407,7 +506,7 @@ pub async fn udp_handler(udp: &UdpSocket, while let TunnResult::WriteToNetwork(packet) = p.tunnel.decapsulate(None, &[], &mut dst_buf[..]) - { + { let _ = udp.send_to(packet, addr).await; } @@ -417,6 +516,182 @@ pub async fn udp_handler(udp: &UdpSocket, } } +pub enum WriterState { + PureWriter(OwnedWriteHalf), + PeerWriter(Arc>), +} + +pub async fn tcp_listener_handler( + listener: &TcpListener, + key_pair: Arc<(x25519_dalek::StaticSecret, x25519_dalek::PublicKey)>, + rate_limiter: Arc, + peers: Arc>, + iface: Arc>, + pi: bool, +) ->anyhow::Result<()> { + loop { + let (socket, addr) = listener.accept().await?; + let key_pair = key_pair.clone(); + let rate_limiter = rate_limiter.clone(); + let peers = peers.clone(); + let iface = iface.clone(); + let (reader, writer ) = socket.into_split(); + tcp_handler(reader, WriterState::PureWriter(writer), addr,key_pair, rate_limiter, peers, iface, pi); + } + //Ok(()) +} +pub fn tcp_handler( + //socket: TcpStream, + reader: OwnedReadHalf, + writer: WriterState, + addr: SocketAddr, + key_pair: Arc<(x25519_dalek::StaticSecret, x25519_dalek::PublicKey)>, + rate_limiter: Arc, + peers: Arc>, + iface: Arc>, + pi: bool, +) { + tokio::spawn(async move { + let (private_key, public_key) = key_pair.as_ref(); + let mut writer = writer; + let mut reader = reader; + //let (mut reader, writer ) = socket.into_split(); + //let mut writer = WriterState::PureWriter(writer); + let mut src_buf: Vec = vec![0; MAX_UDP_SIZE]; + let mut dst_buf: Vec = vec![0; MAX_UDP_SIZE]; + while let Ok(size) = reader.read(&mut src_buf).await { + if size > 0 { + let parsed_packet = + match rate_limiter.as_ref().verify_packet(Some(addr.ip()), &src_buf[..size], &mut dst_buf) { + Ok(packet) => packet, + Err(TunnResult::WriteToNetwork(cookie)) => { + match &mut writer { + WriterState::PureWriter(writer) => { + let _ = writer.write_all(cookie).await; + }, + WriterState::PeerWriter(peer)=> { + let mut p = peer.lock().await; + if let TcpConnection::Connected(w) = &mut p.endpoint.tcp_conn { + let _ = w.write_all(cookie).await; + }else { + tracing::warn!("should not come here"); + } + } + } + continue; + } + Err(_) => continue, + }; + let peer = match &parsed_packet { + Packet::HandshakeInit(p) => { + if let Ok(hh) = parse_handshake_anon(private_key, public_key, p) { + let by_key = &peers.read().await.by_key; + by_key.get(&x25519_dalek::PublicKey::from(hh.peer_static_public)).map(Arc::clone) + } else { + None + } + } + Packet::HandshakeResponse(p) => peers.read().await.by_idx.get(&(p.receiver_idx >> 8)).map(Arc::clone), + Packet::PacketCookieReply(p) => peers.read().await.by_idx.get(&(p.receiver_idx >> 8)).map(Arc::clone), + Packet::PacketData(p) => peers.read().await.by_idx.get(&(p.receiver_idx >> 8)).map(Arc::clone), + }; + let peer = match peer { + None => continue, + Some(peer) => peer, + }; + + let mut p = peer.lock().await; + if let TcpConnection::Nothing | TcpConnection::ConnectedFailure(_) = p.endpoint.tcp_conn { + if let WriterState::PureWriter(_) = &mut writer { + let pure_writer = mem::replace(&mut writer,WriterState::PeerWriter(peer.clone())); + if let WriterState::PureWriter(_writer) = pure_writer { + p.endpoint.tcp_conn = TcpConnection::Connected(_writer); + } + } + } + // We found a peer, use it to decapsulate the message+ + let mut flush = false; // Are there packets to send from the queue? + match p + .tunnel + .handle_verified_packet(parsed_packet, &mut dst_buf[..]) + { + TunnResult::Done => {} + TunnResult::Err(_) => continue, + TunnResult::WriteToNetwork(packet) => { + flush = true; + + if let TcpConnection::Connected(conn) = &mut p.endpoint.tcp_conn { + let _ = conn.write_all(packet).await; + } + } + TunnResult::WriteToTunnelV4(packet, addr) => { + // tracing::debug!("{addr:?}"); + if p.is_allowed_ip(addr) { + if pi { + let mut buf: Vec = Vec::new(); + buf.put_slice(&IP4_HEADER); + buf.put_slice(&packet); + cfg_if! { + if #[cfg(target_os="windows")] { + let _ = iface.lock().await.write(&buf); + } else { + let _ = iface.lock().await.write(&buf).await; + } + } + } else { + cfg_if! { + if #[cfg(target_os="windows")] { + let _ = iface.lock().await.write(&packet); + } else { + let _ = iface.lock().await.write(&packet).await; + } + } + } + } else {} + } + TunnResult::WriteToTunnelV6(packet, addr) => { + if p.is_allowed_ip(addr) { + if pi { + let mut buf: Vec = Vec::new(); + buf.put_slice(&IP6_HEADER); + buf.put_slice(&packet); + cfg_if! { + if #[cfg(target_os="windows")] { + let _ = iface.lock().await.write(&buf); + } else { + let _ = iface.lock().await.write(&buf).await; + } + } + } else { + cfg_if! { + if #[cfg(target_os="windows")] { + let _ = iface.lock().await.write(packet); + } else { + let _ = iface.lock().await.write(packet).await; + } + } + }; + } + } + }; + + if flush { + // Flush pending queue + while let TunnResult::WriteToNetwork(packet) = + p.tunnel.decapsulate(None, &[], &mut dst_buf[..]) + { + if let TcpConnection::Connected(conn) = &mut p.endpoint.tcp_conn { + let _ = conn.write_all(packet).await; + } + } + } + } + } + tracing::info!("tcp: {addr:?} close"); + }); +} + + pub struct Peers { pub by_key: HashMap>>, diff --git a/client/lib/src/device/peer.rs b/client/lib/src/device/peer.rs index 4e073fc..a77e038 100644 --- a/client/lib/src/device/peer.rs +++ b/client/lib/src/device/peer.rs @@ -8,16 +8,25 @@ use std::net::{IpAddr}; use std::net::SocketAddr; use std::str::FromStr; use std::sync::Arc; +use std::time::SystemTime; use boringtun::noise::{Tunn, TunnResult}; -use tokio::net::UdpSocket; +use tokio::net::{UdpSocket}; +use tokio::net::tcp::OwnedWriteHalf; use crate::device::allowed_ips::AllowedIps; - -#[derive(Default, Debug)] +#[derive(Debug)] +pub enum TcpConnection { + Nothing, + Connecting(SystemTime), + Connected(OwnedWriteHalf), + ConnectedFailure(std::io::Error) +} +#[derive(Debug)] pub struct Endpoint { pub addr: Option, - pub conn: Option>, + pub udp_conn: Option>, + pub tcp_conn: TcpConnection, } pub struct Peer { @@ -25,8 +34,9 @@ pub struct Peer { pub(crate) tunnel: Tunn, /// The index the tunnel uses index: u32, - endpoint: Endpoint, + pub endpoint: Endpoint, allowed_ips: AllowedIps<()>, + pub ip: IpAddr, preshared_key: Option<[u8; 32]>, } @@ -66,6 +76,7 @@ impl Peer { index: u32, endpoint: Option, allowed_ips: &[AllowedIP], + ip:IpAddr, preshared_key: Option<[u8; 32]>, ) -> Peer { Peer { @@ -73,8 +84,10 @@ impl Peer { index, endpoint: Endpoint { addr: endpoint, - conn: None, + udp_conn: None, + tcp_conn: TcpConnection::Nothing, }, + ip, allowed_ips: allowed_ips.iter().map(|ip| (ip, ())).collect(), preshared_key, } @@ -89,23 +102,20 @@ impl Peer { } pub fn shutdown_endpoint(&mut self) { - if let Some(conn) = self.endpoint.conn.take() { - tracing::info!("Disconnecting from endpoint"); - drop(conn) + if let Some(_) = &mut self.endpoint.udp_conn.take() { + tracing::info!("disconnecting from endpoint"); + } + if let TcpConnection::Connected(_) = &mut self.endpoint.tcp_conn { + tracing::info!("disconnecting tcp connection"); } + self.endpoint.tcp_conn = TcpConnection::Nothing; } pub fn set_endpoint(&mut self, addr: SocketAddr) { if self.endpoint.addr != Some(addr) { // We only need to update the endpoint if it differs from the current one - if let Some(conn) = self.endpoint.conn.take() { - drop(conn) - // conn.shutdown(); - } - self.endpoint = Endpoint { - addr: Some(addr), - conn: None, - } + self.shutdown_endpoint(); + self.endpoint.addr = Some(addr); }; } @@ -137,6 +147,7 @@ impl Peer { #[cfg(test)] mod tests { + use std::net::{IpAddr, SocketAddr}; use crate::device::peer::AllowedIP; #[test] @@ -145,4 +156,14 @@ mod tests { assert_eq!(ip_v4.to_string(), String::from("10.0.0.0/32")); assert_eq!(ip_v4.addr.to_string(), String::from("10.0.0.0")); } + + #[test] + fn ip_compare() { + let ip = "10.0.0.1".parse::(); + //println!("123 {:?}", ip); + let ip1:IpAddr = "10.0.0.1".parse().unwrap(); + let ip2:IpAddr = "10.0.0.2".parse().unwrap(); + println!("should be false {}", ip1 == ip2); + println!("should be true {}", ip1 < ip2); + } } diff --git a/client/lib/src/device/script_run.rs b/client/lib/src/device/script_run.rs index 8e0ca7b..a1e3206 100644 --- a/client/lib/src/device/script_run.rs +++ b/client/lib/src/device/script_run.rs @@ -1,4 +1,4 @@ -use shell_candy::{ShellTask, ShellTaskBehavior, ShellTaskLog, ShellTaskOutput}; +use shell_candy::{ShellTask, ShellTaskBehavior, ShellTaskLog}; use crate::protobuf::config::Interface; #[derive(Default,Debug)] diff --git a/client/lib/src/device/tunnel.rs b/client/lib/src/device/tunnel.rs new file mode 100644 index 0000000..4537397 --- /dev/null +++ b/client/lib/src/device/tunnel.rs @@ -0,0 +1,73 @@ +use std::net::{SocketAddr}; +use tokio::net::{TcpListener, ToSocketAddrs, UdpSocket}; +use socket2::{Type, Protocol, Domain}; + +pub fn create_udp_socket(port: Option, domain: Domain, mark:Option) -> anyhow::Result { + let socket = socket2::Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_nonblocking(true)?; + + #[cfg(target_os = "linux")] + { + socket.set_reuse_address(true)?; // On Linux SO_REUSEPORT won't prefer a connected IPv6 socket + if let Some(mark) = mark { + socket.set_mark(mark)?; + } + } + #[cfg(not(any(target_os = "linux", target_os = "windows")))] + socket.set_reuse_port(true)?; + + let port = port.unwrap_or(0); + + let address: SocketAddr = match domain { + Domain::IPV4 => + format!("0.0.0.0:{}", port), + Domain::IPV6 => + format!("[::]:{}", port), + _ => panic!("udp client don't support Domain::Unix") + }.parse()?; + socket.bind(&address.into())?; + Ok(UdpSocket::from_std(socket.into())?) +} +//TODO: how to bind same port of IPv6 IPv4 +pub fn create_tcp_server(port: Option, domain: Domain, mark:Option) ->anyhow::Result{ + let socket = socket2::Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; + #[cfg(target_os = "linux")] + { + if let Some(mark) = mark { + socket.set_mark(mark)?; + } + } + + let port = port.unwrap_or(0); + let address: SocketAddr = match domain { + Domain::IPV4 => + format!("0.0.0.0:{}", port), + Domain::IPV6 => { + socket.set_only_v6(false)?; + format!("[::]:{}", port) + }, + _ => panic!("tcp server don't support Domain::Unix") + }.parse()?; + + socket.set_nonblocking(true)?; + socket.bind(&address.into())?; + socket.listen(128)?; + + let tcp_listener = socket.into(); + let tcp_listener = TcpListener::from_std(tcp_listener)?; + Ok(tcp_listener) +} + +#[cfg(test)] +mod test { + use socket2::Domain; + use crate::device::tunnel::create_tcp_server; + + #[tokio::test] + async fn test_tcp_bind() { + let ip4_server = create_tcp_server(None, Domain::IPV4, None).unwrap(); + let ip6_server = create_tcp_server(Some(ip4_server.local_addr().unwrap().port()), Domain::IPV6, None).unwrap(); + + println!("init ip4/ip6 in same port ok"); + } +} \ No newline at end of file diff --git a/client/lib/src/device/udp_network.rs b/client/lib/src/device/udp_network.rs deleted file mode 100644 index f5f21da..0000000 --- a/client/lib/src/device/udp_network.rs +++ /dev/null @@ -1,30 +0,0 @@ -use std::net::SocketAddr; -use tokio::net::UdpSocket; -use socket2::{Type, Protocol, Domain}; - -pub fn create_udp_socket(port: Option, domain: Domain, mark:Option) -> anyhow::Result { - let socket = socket2::Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; - socket.set_nonblocking(true)?; - - #[cfg(target_os = "linux")] - { - socket.set_reuse_address(true)?; // On Linux SO_REUSEPORT won't prefer a connected IPv6 socket - if let Some(mark) = mark { - socket.set_mark(mark)?; - } - } - #[cfg(not(any(target_os = "linux", target_os = "windows")))] - socket.set_reuse_port(true)?; - - let port = port.unwrap_or(0); - - let address: SocketAddr = match domain { - Domain::IPV4 => - format!("0.0.0.0:{}", port), - Domain::IPV6 => - format!("[::]:{}", port), - _ => panic!("udp client don't support Domain::Unix") - }.parse()?; - socket.bind(&address.into())?; - Ok(UdpSocket::from_std(socket.into())?) -} \ No newline at end of file diff --git a/client/lib/src/device/unix_device.rs b/client/lib/src/device/unix_device.rs index 4500ded..b5f0a54 100644 --- a/client/lib/src/device/unix_device.rs +++ b/client/lib/src/device/unix_device.rs @@ -12,13 +12,15 @@ use crate::device::{DeviceData, Peers, HANDSHAKE_RATE_LIMIT, MAX_UDP_SIZE}; use crate::device::peer::AllowedIP; use crate::device::script_run::{run_opt_script, Scripts}; use crate::device::tun::create_async_tun; -use crate::device::udp_network::create_udp_socket; +use crate::device::tunnel::{create_tcp_server, create_udp_socket}; use nix::unistd::Uid; +use crate::protobuf::config::{Protocol, NodeType}; pub struct Device { pub device_data:DeviceData, task:JoinHandle<()>, + protocol:Protocol, } impl Device { @@ -29,29 +31,32 @@ impl Device { key_pair: (x25519_dalek::StaticSecret, x25519_dalek::PublicKey), port: Option, mtu: u32, - pub_key: String, scripts:Scripts, + protocol: Protocol, + node_type: NodeType, ) -> anyhow::Result { run_opt_script(&scripts.pre_up)?; tracing::debug!("begin to create tun"); let (mut iface_reader, iface_writer,pi, name) = create_async_tun(name, mtu, address)?; + tracing::debug!("finish to create tun"); let iface_writer = Arc::new(Mutex::new(iface_writer)); - let udp4 = create_udp_socket(port, Domain::IPV4, None)?; - let port = udp4.local_addr()?.port(); - let udp6 = create_udp_socket(Some(port), Domain::IPV6, None)?; let rate_limiter = Arc::new(RateLimiter::new(&key_pair.1, HANDSHAKE_RATE_LIMIT)); let peers: Arc> = Arc::new(RwLock::new(Peers::default())); - let mut tun_src_buf: Vec = vec![0; MAX_UDP_SIZE]; let mut tun_dst_buf: Vec = vec![0; MAX_UDP_SIZE]; let key_pair1 = key_pair.clone(); let peers1 = peers.clone(); - - let task:JoinHandle<()> = tokio::spawn(async move { - loop { - tokio::select! { + // create tcp/udp server + let (port,task) = match protocol { + Protocol::Udp => { + let udp4 = create_udp_socket(port, Domain::IPV4, None)?; + let port = udp4.local_addr()?.port(); + let udp6 = create_udp_socket(Some(port), Domain::IPV6, None)?; + let task:JoinHandle<()> = tokio::spawn(async move { + loop { + tokio::select! { _ = device::rate_limiter_timer(&rate_limiter) => {} _ = device::peers_timer(&peers,&udp4, &udp6) => {} // iface listen @@ -64,16 +69,55 @@ impl Device { device::tun_read_handle(&peers, &udp4, &udp6, src_buf, &mut tun_dst_buf).await; } // udp listen - _ = device::udp_handler(&udp4, &key_pair,&rate_limiter, Arc::clone(&peers), Arc::clone(&iface_writer), pi) => break, - _ = device::udp_handler(&udp6, &key_pair,&rate_limiter, Arc::clone(&peers), Arc::clone(&iface_writer), pi) => break, - + _ = device::udp_handler(&udp4, &key_pair, rate_limiter.as_ref(), Arc::clone(&peers), Arc::clone(&iface_writer), pi) => break, + _ = device::udp_handler(&udp6, &key_pair, rate_limiter.as_ref(), Arc::clone(&peers), Arc::clone(&iface_writer), pi) => break, } + } + + }); + (port, task) + } + Protocol::Tcp => { + let ip = address[0].addr.clone(); + let tcp6 = create_tcp_server(port, Domain::IPV6, None)?; + let port = tcp6.local_addr()?.port(); + let key_pair = Arc::new(key_pair); + + let task:JoinHandle<()> = tokio::spawn(async move { + loop { + tokio::select! { + _ = device::rate_limiter_timer(&rate_limiter) => {} + _ = device::tcp_peers_timer( + &ip, + &peers, + key_pair.clone(), + rate_limiter.clone(), + iface_writer.clone(), + pi, + node_type, + ) => {} + // iface listen + Ok(len) = iface_reader.read(&mut tun_src_buf) => { + let src_buf = if pi { + &tun_src_buf[4..(len+4)] + } else { + &tun_src_buf[0..len] + }; + device::tun_read_tcp_handle(&peers, src_buf, &mut tun_dst_buf).await; + } + //_ = device::tcp_listener_handler(&tcp4, key_pair.clone(), rate_limiter.clone(), Arc::clone(&peers), Arc::clone(&iface_writer), pi) => {break} + _ = device::tcp_listener_handler(&tcp6, key_pair.clone(), rate_limiter.clone(), Arc::clone(&peers), Arc::clone(&iface_writer), pi) => {break} + } + } + }); + (port, task) } + }; - }); let device = Device { device_data: DeviceData::new(name,peers1, key_pair1, port, scripts), task, + protocol, }; //run_opt_script(&Some("iptables -A FORWARD -i for0 -j ACCEPT".to_owned()))?; diff --git a/client/lib/src/device/windows_device.rs b/client/lib/src/device/windows_device.rs index fd48fbe..cc8c489 100644 --- a/client/lib/src/device/windows_device.rs +++ b/client/lib/src/device/windows_device.rs @@ -11,14 +11,15 @@ use crate::device::{HANDSHAKE_RATE_LIMIT, MAX_UDP_SIZE}; use crate::device::peer::AllowedIP; use crate::device::tun::{create_async_tun, ReadPart, WritePart}; use crate::device::script_run::{run_opt_script, Scripts}; -use crate::device::udp_network::create_udp_socket; - +use crate::device::tunnel::create_udp_socket; +use crate::protobuf::config::Protocol; pub struct Device { pub device_data: DeviceData, read_task:JoinHandle<()>, write_task:JoinHandle<()>, + pub protocol: Protocol, } impl Device { @@ -29,13 +30,14 @@ impl Device { key_pair: (x25519_dalek::StaticSecret, x25519_dalek::PublicKey), port: Option, mtu: u32, - pub_key: String, scripts:Scripts, + protocol: Protocol, ) -> anyhow::Result{ run_opt_script(&scripts.pre_up)?; let (mut iface_reader, iface_writer, name) = create_async_tun(name, mtu, address)?; + let udp4 = Arc::new(create_udp_socket(port, Domain::IPV4, None)?); let port = udp4.local_addr()?.port(); @@ -51,6 +53,7 @@ impl Device { device_data:DeviceData::new(name, peers, key_pair, port, scripts), read_task, write_task, + protocol, }; run_opt_script(&device.scripts.post_up)?; Ok(device) diff --git a/client/lib/src/flutter_api.rs b/client/lib/src/flutter_api.rs index be35743..8a65c8e 100644 --- a/client/lib/src/flutter_api.rs +++ b/client/lib/src/flutter_api.rs @@ -8,7 +8,7 @@ use std::str::FromStr; use once_cell::sync::OnceCell; -use tokio::runtime::{Handle, Runtime}; +use tokio::runtime::Runtime; use tracing::Level; use crate::{default_config_path, server_manager}; use crate::server_manager::StartMethod; diff --git a/client/lib/src/sc_manager.rs b/client/lib/src/sc_manager.rs index 50219f7..60b9ffc 100644 --- a/client/lib/src/sc_manager.rs +++ b/client/lib/src/sc_manager.rs @@ -1,19 +1,15 @@ -use std::convert::identity; use std::sync::Arc; -use std::time::{Duration, SystemTime}; +use std::time::Duration; use paho_mqtt as mqtt; -use paho_mqtt::SslVersion::Default; use prost::Message; use tokio::sync::mpsc::Sender; use tokio_stream::StreamExt; -use tonic::metadata::{Ascii, MetadataValue}; -use tonic::Request; -use tonic::transport::Channel; +use crate::config::{NodeInfo, Config as AppConfig}; -use crate::protobuf::config::{ClientMessage, NetworkMessage, NodeStatus, PeerChange, WrConfig}; +use crate::protobuf::config::{ClientMessage, NetworkMessage, NetworkStatus, NodeStatus, WrConfig}; use crate::protobuf::config::client_message::Info::{Config, Status}; -use crate::protobuf::config::network_message::Info::Peer; +use crate::protobuf::config::network_message::Info::{Peer, Status as NStatus}; use crate::server_manager::ServerMessage; //Sync Config Manager @@ -34,107 +30,131 @@ impl SCManager { } } - pub async fn mqtt_connect(&mut self, config: Arc) -> anyhow::Result<()> { - let mut deduplication = Duplication { - wr_config: None, - status: None, - }; - - for (_, mqtt_url) in &config.server_config.mqtt { - let mut client = mqtt::CreateOptionsBuilder::new() - .server_uri(mqtt_url) - .client_id( - &config.identity.pk_base64, - ).create_client()?; - let mut stream = client.get_stream(25); - - let encrypt = config.identity.sign2(Vec::new())?; - let password = format!("{}|{}|{}", encrypt.nonce, encrypt.timestamp, encrypt.signature); - - let conn_ops = mqtt::ConnectOptionsBuilder::new_v5() - .properties(mqtt::properties![mqtt::PropertyCode::SessionExpiryInterval => 3600]) - .password(password) - .finalize(); - //client - - //tokio spawn - - client.connect(conn_ops).await?; - let mut topics = config.server_config.mqtt.iter().map(|(key,_)| format!("network/{key}")).collect::>(); - topics.push("client".to_owned()); - let sub_opts = vec![mqtt::SubscribeOptions::with_retain_as_published(); topics.len()]; - - let qos = vec![1i32; topics.len()]; - - client.subscribe_many_with_options(&topics, &qos, &sub_opts, None) - .await?; - - while let Some(msg_opt) = stream.next().await { - if let Some(msg) = msg_opt { - tracing::debug!("receive message, topic: {}",msg.topic()); - match msg.topic() { - "client" => { - if let Ok(client_message) = ClientMessage::decode(msg.payload()) { - if let Some(info) = client_message.info { - match info { - Config(wr_config) => { - if deduplication.wr_config == Some(wr_config.clone()) { + async fn mqtt_reconnect(sender:Sender, node_info: &NodeInfo, config:Arc, deduplication:&mut Duplication) -> anyhow::Result<()> { + let mut client = mqtt::CreateOptionsBuilder::new() + .server_uri(&node_info.mqtt_url) + .client_id( + &config.identity.pk_base64, + ).create_client()?; + let mut stream = client.get_stream(25); + + let encrypt = config.identity.sign2(Vec::new())?; + let password = format!("{}|{}|{}", encrypt.nonce, encrypt.timestamp, encrypt.signature); + + let conn_ops = mqtt::ConnectOptionsBuilder::new_v5() + .properties(mqtt::properties![mqtt::PropertyCode::SessionExpiryInterval => 3600]) + .user_name(&node_info.node_id) + .password(password) + .finalize(); + //client + + //tokio spawn + + client.connect(conn_ops).await?; + let client_topic = format!("client/{}",&node_info.node_id); + let network_topic = format!("network/{}", &node_info.network_id); + let topics = vec!(&client_topic, &network_topic); + let sub_opts = vec![mqtt::SubscribeOptions::with_retain_as_published(); topics.len()]; + + let qos = vec![1i32; topics.len()]; + + client.subscribe_many_with_options(&topics, &qos, &sub_opts, None) + .await?; + + while let Some(msg_opt) = stream.next().await { + if let Some(msg) = msg_opt { + tracing::debug!("receive message, topic: {}",msg.topic()); + match msg.topic() { + topic if topic == &client_topic => { + if let Ok(client_message) = ClientMessage::decode(msg.payload()) { + if let Some(info) = client_message.info { + match info { + Config(wr_config) => { + if deduplication.wr_config == Some(wr_config.clone()) { + continue; + } + + let _ = sender.send(ServerMessage::SyncConfig(wr_config.clone())).await; + deduplication.wr_config = Some(wr_config); + } + Status(status) => { + if let Some(node_status) = NodeStatus::from_i32(status) { + if deduplication.status == Some(node_status) { continue; } - - let _ = self.sender.send(ServerMessage::SyncConfig(wr_config.clone())).await; - deduplication.wr_config = Some(wr_config); - } - Status(status) => { - if let Some(node_status) = NodeStatus::from_i32(status) { - if deduplication.status == Some(node_status) { - continue; + match node_status { + NodeStatus::NodeForbid => { + let _ = sender.send( + ServerMessage::StopWR("node has been forbid or delete".to_owned()) + ).await; } - match node_status { - NodeStatus::NodeForbid => { - let _ = self.sender.send( - ServerMessage::StopWR("node has been forbid or delete".to_owned()) - ).await; - } - _ => { - // this would conflict with Info::Config message, so ignore this. - } + _ => { + // this would conflict with Info::Config message, so ignore this. } - deduplication.status = Some(node_status) } + deduplication.status = Some(node_status) } } } - } else { - tracing::warn!("client message can not decode, may should update software"); } + } else { + tracing::warn!("client message can not decode, may should update software"); } - "network" => { - if let Ok(network_message) = NetworkMessage::decode(msg.payload()) { - if let Some(info) = network_message.info { - match info { - Peer(peer_change) => { - let _ = self.sender.send(ServerMessage::SyncPeers(peer_change)).await; + } + topic if topic == &network_topic => { + if let Ok(network_message) = NetworkMessage::decode(msg.payload()) { + if let Some(info) = network_message.info { + match info { + Peer(peer_change) => { + let _ = sender.send(ServerMessage::SyncPeers(peer_change)).await; + } + NStatus(status) => { + if let Some(NetworkStatus::NetworkDelete) = NetworkStatus::from_i32(status) { + let _ = sender.send( + ServerMessage::StopWR("network has been delete".to_owned()) + ).await; } } } - } else { - tracing::warn!("network message can not decode, may should update software"); } - } - _ => { - tracing::warn!("topic:{} message can not decode, may should update software", msg.topic()); + } else { + tracing::warn!("network message can not decode, may should update software"); } } - } else { - // A "None" means we were disconnected. Try to reconnect... - while let Err(err) = client.reconnect().await { - tracing::debug!("mqtt reconnect error: {}", err); - tokio::time::sleep(Duration::from_secs(2)).await; + _ => { + tracing::warn!("topic:{} message can not decode, may should update software", msg.topic()); } } + } else { + // A "None" means we were disconnected. Try to reconnect... + while let Err(err) = client.reconnect().await { + tracing::debug!("mqtt reconnect error: {}", err); + tokio::time::sleep(Duration::from_secs(2)).await; + } } - break; + } + Ok(()) + } + + pub async fn mqtt_connect(&mut self, config: Arc) -> anyhow::Result<()> { + for node_info in &config.server_config.info { + let mut deduplication = Duplication { + wr_config: None, + status: None, + }; + let node_info = node_info.clone(); + let _config = config.clone(); + let _sender = self.sender.clone(); + tokio::spawn(async move{ + loop { + let _config = _config.clone(); + let _sender = _sender.clone(); + let _ = SCManager::mqtt_reconnect(_sender, &node_info, _config, &mut deduplication).await; + tracing::debug!("mqtt connect error"); + tokio::time::sleep(Duration::from_secs(10)).await; + } + }); + } Ok(()) } diff --git a/client/lib/src/server_manager.rs b/client/lib/src/server_manager.rs index 66a9866..76b90b6 100644 --- a/client/lib/src/server_manager.rs +++ b/client/lib/src/server_manager.rs @@ -33,10 +33,8 @@ impl ServerManager { let mut sc_manager = SCManager::new(tx.clone()); let config = config.clone(); let _ = tokio::spawn(async move { - match sc_manager.mqtt_connect(config).await { - Ok(()) => tracing::warn!("sync config manager close, now can not receive any update from server"), - Err(e) => tracing::error!("sync config manager connect server result:{:?}", e), - }; + tracing::debug!("mqtt connect"); + let _ = sc_manager.mqtt_connect(config).await; }); } else { if start_method == StartMethod::CommandLine { @@ -82,15 +80,43 @@ impl ServerManager { } ServerMessage::SyncPeers(peer_change_message) => { if let Some(public_key) = peer_change_message.remove_public_key { - match Identity::get_pub_identity_from_base64(&public_key) { - Ok((x_pub_key, _)) => { - server_manager.wr_manager.remove_peer(&x_pub_key).await; - } - Err(_) => { - tracing::warn!("peer identity parse error") + if server_manager.config.map(|x|x.identity.pk_base64 != public_key).unwrap_or_else(true) { + match Identity::get_pub_identity_from_base64(&public_key) { + Ok((x_pub_key, _)) => { + server_manager.wr_manager.remove_peer(&x_pub_key).await; + } + Err(_) => { + tracing::warn!("peer identity parse error") + } } } } + if let Some(peer) = peer_change_message.add_peer { + let ip:IpAddr = peer.address.first().unwrap().parse().unwrap(); + let allowed_ip:Vec = peer.allowed_ip.into_iter().map(|ip| AllowedIP::from_str(&ip).unwrap()).collect(); + server_manager.wr_manager.add_peer( + peer.public_key, + false, + peer.endpoint, + &allowed_ip, + ip, + Some(peer.persistence_keep_alive as u16), + ).await; + } + if let Some(peer) = peer_change_message.change_peer { + if server_manager.config.map(|x|x.identity.pk_base64 != public_key).unwrap_or_else(true) { + let ip:IpAddr = peer.address.first().unwrap().parse().unwrap(); + let allowed_ip:Vec = peer.allowed_ip.into_iter().map(|ip| AllowedIP::from_str(&ip).unwrap()).collect(); + server_manager.wr_manager.add_peer( + peer.public_key, + false, + peer.endpoint, + &allowed_ip, + ip, + Some(peer.persistence_keep_alive as u16), + ).await; + } + } } }; } diff --git a/client/lib/src/wr_manager.rs b/client/lib/src/wr_manager.rs index 4bffa6d..50e78fb 100644 --- a/client/lib/src/wr_manager.rs +++ b/client/lib/src/wr_manager.rs @@ -1,11 +1,10 @@ -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; -use std::time::Duration; use anyhow::anyhow; use serde_derive::{Deserialize, Serialize}; use crate::config::{Config, Identity}; use crate::device::peer::AllowedIP; -use crate::protobuf::config::WrConfig; +use crate::protobuf::config::{Protocol, WrConfig, NodeType}; use crate::device::Device; use crate::device::script_run::Scripts; @@ -34,6 +33,7 @@ impl WRManager { pub_key: x25519_dalek::PublicKey, endpoint: Option, allowed_ips: &[AllowedIP], + ip:IpAddr, keepalive: Option) { if let Some(device) = &mut self.device { device.update_peer( @@ -42,6 +42,7 @@ impl WRManager { endpoint, allowed_ips, keepalive, + ip, None, ).await; } else { @@ -54,6 +55,7 @@ impl WRManager { let interface = wr_config.interface.unwrap(); //let address = AllowedIP::from_str(interface.address.as_str()).map_err(|e| anyhow!(e))?; let mut address: Vec = Vec::new(); + for addr in &interface.address { address.push(AllowedIP::from_str(addr).map_err(|e| anyhow!(e))?); } @@ -64,13 +66,20 @@ impl WRManager { self.close().await; tracing::info!("close device before restart"); let tun_name = config.get_tun_name(); + let protocol = Protocol::from_i32(interface.protocol).unwrap_or(Protocol::Udp); + let node_type = NodeType::from_i32(wr_config.r#type).unwrap(); let scripts = Scripts::load_from_interface(&interface); let key_pair = (config.identity.x25519_sk.clone(), config.identity.x25519_pk.clone()); - let wr_interface = Device::new(&tun_name, &address, key_pair, Some(interface.listen_port as u16), - interface.mtu.unwrap_or(1420) as u32, - config.identity.pk_base64.clone(), - scripts, + let wr_interface = Device::new( + &tun_name, + &address, + key_pair, + Some(interface.listen_port as u16), + interface.mtu.unwrap_or(1420) as u32, + scripts, + protocol, + node_type, )?; self.device = Some(wr_interface); @@ -78,10 +87,12 @@ impl WRManager { let (x_pub_key,_) = Identity::get_pub_identity_from_base64(&peer.public_key)?; let endpoint = peer.endpoint.map(|v| SocketAddr::from_str(&v).unwrap()); let allowed_ip:Vec = peer.allowed_ip.into_iter().map(|ip| AllowedIP::from_str(&ip).unwrap()).collect(); + let ip:IpAddr = peer.address.first().unwrap().parse().unwrap(); self.add_peer( x_pub_key, endpoint, allowed_ip.as_slice(), + ip, Some(peer.persistence_keep_alive as u16), ).await; tracing::debug!("peer: {} join network", peer.public_key); @@ -90,6 +101,7 @@ impl WRManager { } pub fn is_alive(&self) -> bool { self.device.is_some() } + pub async fn close(&mut self) { if let Some(ref mut device) = self.device.take() { device.close().await @@ -97,7 +109,6 @@ impl WRManager { } pub fn device_info(&self) -> Vec { - self.device.as_ref().map_or(vec![], |device| { vec![DeviceInfoResp { name: device.name.clone() diff --git a/command/docker-compose/simple/config/mqtt/plugin/rmqtt-acl.toml b/command/docker-compose/simple/config/mqtt/plugin/rmqtt-acl.toml index 70472a5..a96c60b 100644 --- a/command/docker-compose/simple/config/mqtt/plugin/rmqtt-acl.toml +++ b/command/docker-compose/simple/config/mqtt/plugin/rmqtt-acl.toml @@ -2,6 +2,8 @@ ## rmqtt-acl ##-------------------------------------------------------------------- +disconnect_if_pub_rejected = true + rules = [ ["allow", { user = "dashboard" }, "subscribe", ["$SYS/#"]], ["allow", { ipaddr = "127.0.0.1" }, "pubsub", ["$SYS/#", "#"]], diff --git a/command/docker-compose/simple/config/mqtt/plugin/rmqtt-auth-http.toml b/command/docker-compose/simple/config/mqtt/plugin/rmqtt-auth-http.toml index a7f3b8a..57544d1 100644 --- a/command/docker-compose/simple/config/mqtt/plugin/rmqtt-auth-http.toml +++ b/command/docker-compose/simple/config/mqtt/plugin/rmqtt-auth-http.toml @@ -5,7 +5,7 @@ http_timeout = "5s" http_headers.accept = "*/*" http_headers.Cache-Control = "no-cache" -http_headers.User-Agent = "RMQTT/0.1.1" +http_headers.User-Agent = "RMQTT/0.2.11" http_headers.Connection = "keep-alive" #Stop the hook chain after successful authentication, including auth, pub-acl and sub-acl @@ -35,25 +35,6 @@ http_auth_req.headers.content-type="application/json" http_auth_req.params = { clientId = "%c", username = "%u", password = "%P" } -##-------------------------------------------------------------------- -## Superuser request. -## -## Variables: -## - %u: username -## - %c: clientid -## - %a: ipaddress -## - %r: protocol -## - %P: password -## -## Value: URL -http_super_req.url = "http://backend/mqtt/superuser" -## Value: post | get | put -http_super_req.method = "post" -http_super_req.headers.content-type="application/json" -## Value: Params -http_super_req.params = { clientid = "%c", username = "%u" } - - ##-------------------------------------------------------------------- ## ACL request. ## diff --git a/command/docker-compose/simple/config/mqtt/plugin/rmqtt-web-hook.toml b/command/docker-compose/simple/config/mqtt/plugin/rmqtt-web-hook.toml index e87363c..fee9c7c 100644 --- a/command/docker-compose/simple/config/mqtt/plugin/rmqtt-web-hook.toml +++ b/command/docker-compose/simple/config/mqtt/plugin/rmqtt-web-hook.toml @@ -32,7 +32,7 @@ retry_multiplier = 2.5 #rule.client_connack = [{action = "client_connack", urls = ["http://127.0.0.1:5656/mqtt/webhook", "http://127.0.0.1:5656/mqtt/webhook"] } ] #rule.client_connected = [{action = "client_connected" } ] #rule.client_disconnected = [{action = "client_disconnected" } ] -rule.client_subscribe = [{action = "client_subscribe", topics=["client"]} ] +rule.client_subscribe = [{action = "client_subscribe", topics=["client/#"]} ] #rule.client_unsubscribe = [{action = "client_unsubscribe", topics=["x/y/z", "foo/#"] } ] #rule.message_publish = [{action = "message_publish" }] diff --git a/command/docker-compose/simple/docker-compose.yml b/command/docker-compose/simple/docker-compose.yml index ed2336c..03a7710 100644 --- a/command/docker-compose/simple/docker-compose.yml +++ b/command/docker-compose/simple/docker-compose.yml @@ -3,7 +3,7 @@ version: "3" services: mqtt: - image: rmqtt/rmqtt:latest + image: rmqtt/rmqtt:0.2.11 container_name: mqtt ports: - 1883:1883 diff --git a/command/docker/mqtt/config/plugin/rmqtt-acl.toml b/command/docker/mqtt/config/plugin/rmqtt-acl.toml index 70472a5..a96c60b 100644 --- a/command/docker/mqtt/config/plugin/rmqtt-acl.toml +++ b/command/docker/mqtt/config/plugin/rmqtt-acl.toml @@ -2,6 +2,8 @@ ## rmqtt-acl ##-------------------------------------------------------------------- +disconnect_if_pub_rejected = true + rules = [ ["allow", { user = "dashboard" }, "subscribe", ["$SYS/#"]], ["allow", { ipaddr = "127.0.0.1" }, "pubsub", ["$SYS/#", "#"]], diff --git a/command/docker/mqtt/config/plugin/rmqtt-auth-http.toml b/command/docker/mqtt/config/plugin/rmqtt-auth-http.toml index 038b72e..88c4b1f 100644 --- a/command/docker/mqtt/config/plugin/rmqtt-auth-http.toml +++ b/command/docker/mqtt/config/plugin/rmqtt-auth-http.toml @@ -5,14 +5,15 @@ http_timeout = "5s" http_headers.accept = "*/*" http_headers.Cache-Control = "no-cache" -http_headers.User-Agent = "RMQTT/0.1.1" +http_headers.User-Agent = "RMQTT/0.2.11" http_headers.Connection = "keep-alive" -#Stop the hook chain after successful authentication, including auth, pub-acl and sub-acl -break_if_allow = true #Disconnect if publishing is rejected disconnect_if_pub_rejected = true +#Return 'Deny' if http request error otherwise 'Ignore' +deny_if_error = true + ##-------------------------------------------------------------------- ## Authentication request. ## @@ -35,24 +36,6 @@ http_auth_req.headers.content-type="application/json" http_auth_req.params = { clientId = "%c", username = "%u", password = "%P" } -##-------------------------------------------------------------------- -## Superuser request. -## -## Variables: -## - %u: username -## - %c: clientid -## - %a: ipaddress -## - %r: protocol -## - %P: password -## -## Value: URL -http_super_req.url = "http://dev.fornetcode.com/mqtt/superuser" -## Value: post | get | put -http_super_req.method = "post" -http_super_req.headers.content-type="application/json" -## Value: Params -http_super_req.params = { clientid = "%c", username = "%u" } - ##-------------------------------------------------------------------- ## ACL request. diff --git a/command/docker/mqtt/config/plugin/rmqtt-web-hook.toml b/command/docker/mqtt/config/plugin/rmqtt-web-hook.toml index 85d8219..8cdb82d 100644 --- a/command/docker/mqtt/config/plugin/rmqtt-web-hook.toml +++ b/command/docker/mqtt/config/plugin/rmqtt-web-hook.toml @@ -32,7 +32,7 @@ retry_multiplier = 2.5 #rule.client_connack = [{action = "client_connack", urls = ["http://127.0.0.1:5656/mqtt/webhook", "http://127.0.0.1:5656/mqtt/webhook"] } ] #rule.client_connected = [{action = "client_connected" } ] #rule.client_disconnected = [{action = "client_disconnected" } ] -rule.client_subscribe = [{action = "client_subscribe", topics=["client"]} ] +rule.client_subscribe = [{action = "client_subscribe", topics=["client/#"]} ] #rule.client_unsubscribe = [{action = "client_unsubscribe", topics=["x/y/z", "foo/#"] } ] #rule.message_publish = [{action = "message_publish" }] diff --git a/command/docker/mqtt/run.sh b/command/docker/mqtt/run.sh old mode 100644 new mode 100755 index cbaddbb..495a4a8 --- a/command/docker/mqtt/run.sh +++ b/command/docker/mqtt/run.sh @@ -2,5 +2,5 @@ # 1883(mqtt) 6060(http) 5363(grpc) # docker rm -f mqtt -docker run -d --name mqtt --network=host -v $(pwd)/log:/var/log/rmqtt -v $(pwd)/config/rmqtt.toml:/app/rmqtt/rmqtt.toml -v $(pwd)/config/plugin:/app/rmqtt/rmqtt-plugins rmqtt/rmqtt:latest +docker run -d --name mqtt --network=host -v $(pwd)/log:/var/log/rmqtt -v $(pwd)/config/rmqtt.toml:/app/rmqtt/rmqtt.toml -v $(pwd)/config/plugin:/app/rmqtt/plugin rmqtt/rmqtt:0.2.11 # docker logs -f --tail 50 mqtt \ No newline at end of file diff --git a/command/docker/proxy/run_dev.sh b/command/docker/proxy/run_dev.sh old mode 100644 new mode 100755 diff --git a/command/docker/proxy/run_test.sh b/command/docker/proxy/run_test.sh old mode 100644 new mode 100755 diff --git a/protobuf/auth.proto b/protobuf/auth.proto index 195005b..132cd81 100644 --- a/protobuf/auth.proto +++ b/protobuf/auth.proto @@ -4,13 +4,18 @@ package auth; option java_package = "com.timzaak.fornet.protobuf"; -import "google/protobuf/empty.proto"; +//import "google/protobuf/empty.proto"; +message SuccessResponse { + string mqtt_url = 1; + string client_id = 2; +} message ActionResponse { - bool isOk = 1; - optional string mqtt_url = 2; - optional string message = 3; + oneof response { + string error = 1; + SuccessResponse success = 2; + } } message EncryptRequest { diff --git a/protobuf/config.proto b/protobuf/config.proto index 5d26b43..9798e6f 100644 --- a/protobuf/config.proto +++ b/protobuf/config.proto @@ -2,21 +2,27 @@ syntax = "proto3"; package config; -import "google/protobuf/empty.proto"; +//import "google/protobuf/empty.proto"; option java_package = "com.timzaak.fornet.protobuf"; +enum Protocol { + Protocol_TCP = 0; + Protocol_UDP = 1; +} + message Interface { optional string name = 1; repeated string address = 2; int32 listen_port = 3; - optional string private_key = 4; + // optional string private_key = 4; // this is no needed now, we may support it in future version repeated string dns = 5; optional uint32 mtu = 6; optional string pre_up = 7; optional string post_up = 8; optional string pre_down = 9; optional string post_down = 10; + Protocol protocol = 11; } message Peer { @@ -24,6 +30,8 @@ message Peer { repeated string allowed_ip = 2; string public_key = 3; uint32 persistence_keep_alive = 4; + // This is for tcp + repeated string address = 5; } @@ -36,6 +44,7 @@ message PeerChange { message WRConfig { Interface interface = 1; repeated Peer peers = 2; + NodeType type = 3; } enum NodeStatus { @@ -43,6 +52,13 @@ enum NodeStatus { NODE_NORMAL = 1; NODE_FORBID = 2; } +enum NetworkStatus { + NETWORK_DELETE = 0; +} +enum NodeType { + NODE_CLIENT = 0; + NODE_RELAY = 1; +} message ClientMessage { string network_id = 1; @@ -55,5 +71,6 @@ message NetworkMessage { string network_id = 1; oneof info { PeerChange peer = 2; + NetworkStatus status = 3; } }