-
-
Notifications
You must be signed in to change notification settings - Fork 65
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pyrofork: Add Mongodb Session Storage
Signed-off-by: wulan17 <wulan17@nusantararom.org> Co-authored-by: wulan17 <wulan17@nusantararom.org>
- Loading branch information
Showing
5 changed files
with
198 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import asyncio | ||
import inspect | ||
import time | ||
from typing import List, Tuple, Any | ||
|
||
from motor.motor_asyncio import AsyncIOMotorClient | ||
from pymongo import UpdateOne | ||
from pyrogram.storage.storage import Storage | ||
from pyrogram.storage.sqlite_storage import get_input_peer | ||
|
||
|
||
class MongoStorage(Storage): | ||
""" | ||
config (``dict``) | ||
Mongodb config as dict, e.g.: *dict(uri="mongodb://...", db_name="pyrofork-session", remove_peers=False)*. | ||
Only applicable for new sessions. | ||
""" | ||
lock: asyncio.Lock | ||
USERNAME_TTL = 8 * 60 * 60 | ||
|
||
def __init__(self, config: dict): | ||
super().__init__('') | ||
db_name = "pyrofork-session" | ||
db_uri = config["uri"] | ||
remove_peers = False | ||
if "db_name" in config: | ||
db_name = config["db_name"] | ||
if "remove_peers" in config: | ||
remove_peers = config["remove_peers"] | ||
database = AsyncIOMotorClient(db_uri)[db_name] | ||
self.lock = asyncio.Lock() | ||
self.database = database | ||
self._peer = database['peers'] | ||
self._session = database['session'] | ||
self._remove_peers = remove_peers | ||
|
||
async def open(self): | ||
""" | ||
dc_id INTEGER PRIMARY KEY, | ||
api_id INTEGER, | ||
test_mode INTEGER, | ||
auth_key BLOB, | ||
date INTEGER NOT NULL, | ||
user_id INTEGER, | ||
is_bot INTEGER | ||
""" | ||
if await self._session.find_one({'_id': 0}, {}): | ||
return | ||
await self._session.insert_one( | ||
{ | ||
'_id': 0, | ||
'dc_id': 2, | ||
'api_id': None, | ||
'test_mode': None, | ||
'auth_key': b'', | ||
'date': 0, | ||
'user_id': 0, | ||
'is_bot': 0, | ||
|
||
} | ||
) | ||
|
||
async def save(self): | ||
pass | ||
|
||
async def close(self): | ||
pass | ||
|
||
async def delete(self): | ||
try: | ||
await self._session.delete_one({'_id': 0}) | ||
if self._remove_peers: | ||
await self._peer.remove({}) | ||
except Exception as _: | ||
return | ||
|
||
async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): | ||
"""(id, access_hash, type, username, phone_number)""" | ||
s = int(time.time()) | ||
bulk = [ | ||
UpdateOne( | ||
{'_id': i[0]}, | ||
{'$set': { | ||
'access_hash': i[1], | ||
'type': i[2], | ||
'username': i[3], | ||
'phone_number': i[4], | ||
'last_update_on': s | ||
}}, | ||
upsert=True | ||
) for i in peers | ||
] | ||
if not bulk: | ||
return | ||
await self._peer.bulk_write( | ||
bulk | ||
) | ||
|
||
async def get_peer_by_id(self, peer_id: int): | ||
# id, access_hash, type | ||
r = await self._peer.find_one({'_id': peer_id}, {'_id': 1, 'access_hash': 1, 'type': 1}) | ||
if not r: | ||
raise KeyError(f"ID not found: {peer_id}") | ||
return get_input_peer(*r.values()) | ||
|
||
async def get_peer_by_username(self, username: str): | ||
# id, access_hash, type, last_update_on, | ||
r = await self._peer.find_one({'username': username}, | ||
{'_id': 1, 'access_hash': 1, 'type': 1, 'last_update_on': 1}) | ||
|
||
if r is None: | ||
raise KeyError(f"Username not found: {username}") | ||
|
||
if abs(time.time() - r['last_update_on']) > self.USERNAME_TTL: | ||
raise KeyError(f"Username expired: {username}") | ||
|
||
return get_input_peer(*list(r.values())[:3]) | ||
|
||
async def get_peer_by_phone_number(self, phone_number: str): | ||
|
||
# _id, access_hash, type, | ||
r = await self._peer.find_one({'phone_number': phone_number}, | ||
{'_id': 1, 'access_hash': 1, 'type': 1}) | ||
|
||
if r is None: | ||
raise KeyError(f"Phone number not found: {phone_number}") | ||
|
||
return get_input_peer(*r) | ||
|
||
async def _get(self): | ||
attr = inspect.stack()[2].function | ||
d = await self._session.find_one({'_id': 0}, {attr: 1}) | ||
if not d: | ||
return | ||
return d[attr] | ||
|
||
async def _set(self, value: Any): | ||
attr = inspect.stack()[2].function | ||
await self._session.update_one({'_id': 0}, {'$set': {attr: value}}, upsert=True) | ||
|
||
async def _accessor(self, value: Any = object): | ||
return await self._get() if value == object else await self._set(value) | ||
|
||
async def dc_id(self, value: int = object): | ||
return await self._accessor(value) | ||
|
||
async def api_id(self, value: int = object): | ||
return await self._accessor(value) | ||
|
||
async def test_mode(self, value: bool = object): | ||
return await self._accessor(value) | ||
|
||
async def auth_key(self, value: bytes = object): | ||
return await self._accessor(value) | ||
|
||
async def date(self, value: int = object): | ||
return await self._accessor(value) | ||
|
||
async def user_id(self, value: int = object): | ||
return await self._accessor(value) | ||
|
||
async def is_bot(self, value: bool = object): | ||
return await self._accessor(value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
aiosqlite>=0.17.0,<0.19.0 | ||
motor==3.1.2 | ||
pyaes==1.6.1 | ||
pymediainfo==6.0.1 | ||
pymongo==4.3.3 | ||
pysocks==1.7.1 | ||
aiosqlite>=0.17.0,<0.19.0 |