Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some deprecated types #12846

Merged
merged 1 commit into from
Sep 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 13 additions & 13 deletions modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import Dict, List, Any
from typing import Any
import piexif
import piexif.helper
from contextlib import closing
Expand Down Expand Up @@ -221,15 +221,15 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
Expand All @@ -242,8 +242,8 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=List[models.ExtensionItem])
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])

if shared.cmd_opts.api_server_stop:
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
Expand Down Expand Up @@ -563,7 +563,7 @@ def get_config(self):

return options

def set_config(self, req: Dict[str, Any]):
def set_config(self, req: dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None)
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
raise RuntimeError(f"model {checkpoint_name!r} not found")
Expand Down
24 changes: 11 additions & 13 deletions modules/api/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import inspect

from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from typing import Any, Optional, Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers, opts, parser
from typing import Dict, List

API_NOT_ALLOWED = [
"self",
Expand Down Expand Up @@ -130,12 +128,12 @@ def generate_model(self):
).generate_model()

class TextToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str

class ImageToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str

Expand Down Expand Up @@ -168,10 +166,10 @@ class FileData(BaseModel):
name: str = Field(title="File name")

class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")

class ExtrasBatchImagesResponse(ExtraBaseResponse):
images: List[str] = Field(title="Images", description="The generated images in base64 format.")
images: list[str] = Field(title="Images", description="The generated images in base64 format.")

class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image")
Expand Down Expand Up @@ -233,8 +231,8 @@ class PreprocessResponse(BaseModel):

class SamplerItem(BaseModel):
name: str = Field(title="Name")
aliases: List[str] = Field(title="Aliases")
options: Dict[str, str] = Field(title="Options")
aliases: list[str] = Field(title="Aliases")
options: dict[str, str] = Field(title="Options")

class UpscalerItem(BaseModel):
name: str = Field(title="Name")
Expand Down Expand Up @@ -285,8 +283,8 @@ class EmbeddingItem(BaseModel):
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")

class EmbeddingsResponse(BaseModel):
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")

class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
Expand All @@ -304,14 +302,14 @@ class ScriptArg(BaseModel):
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")


class ScriptInfo(BaseModel):
name: str = Field(default=None, title="Name", description="Script name")
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")

class ExtensionItem(BaseModel):
name: str = Field(title="Name", description="Extension name")
Expand Down
2 changes: 1 addition & 1 deletion modules/gitpython_hack.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
)
return self._parse_object_header(ret)

def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]:
# Not really streaming, per se; this buffers the entire object in memory.
# Shouldn't be a problem for our use case, since we're only using this for
# object headers (commit objects).
Expand Down
7 changes: 3 additions & 4 deletions modules/prompt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import re
from collections import namedtuple
from typing import List
import lark

# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
Expand Down Expand Up @@ -240,14 +239,14 @@ def get_multicond_prompt_list(prompts: SdConditioning | list[str]):

class ComposableScheduledPromptConditioning:
def __init__(self, schedules, weight=1.0):
self.schedules: List[ScheduledPromptConditioning] = schedules
self.schedules: list[ScheduledPromptConditioning] = schedules
self.weight: float = weight


class MulticondLearnedConditioning:
def __init__(self, shape, batch):
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
self.batch: list[list[ComposableScheduledPromptConditioning]] = batch


def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
Expand Down Expand Up @@ -278,7 +277,7 @@ def shape(self):
return self["crossattn"].shape


def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
is_dict = isinstance(param, dict)

Expand Down
6 changes: 3 additions & 3 deletions modules/script_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import os
from collections import namedtuple
from typing import Optional, Dict, Any
from typing import Optional, Any

from fastapi import FastAPI
from gradio import Blocks
Expand Down Expand Up @@ -255,7 +255,7 @@ def image_grid_callback(params: ImageGridLoopParams):
report_exception(c, 'image_grid')


def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
for c in callback_map['callbacks_infotext_pasted']:
try:
c.callback(infotext, params)
Expand Down Expand Up @@ -446,7 +446,7 @@ def on_infotext_pasted(callback):
"""register a function to be called before applying an infotext.
The callback is called with two arguments:
- infotext: str - raw infotext.
- result: Dict[str, any] - parsed infotext parameters.
- result: dict[str, any] - parsed infotext parameters.
"""
add_callback(callback_map['callbacks_infotext_pasted'], callback)

Expand Down
4 changes: 2 additions & 2 deletions modules/sub_quadratic_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple, List
from typing import Optional, NamedTuple


def narrow_trunc(
Expand Down Expand Up @@ -97,7 +97,7 @@ def chunk_scanner(chunk_idx: int) -> AttnChunk:
)
return summarize_chunk(query, key_chunk, value_chunk)

chunks: List[AttnChunk] = [
chunks: list[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
Expand Down
3 changes: 1 addition & 2 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,7 +1338,6 @@ def versions_html():

def setup_ui_api(app):
from pydantic import BaseModel, Field
from typing import List

class QuicksettingsHint(BaseModel):
name: str = Field(title="Name of the quicksettings field")
Expand All @@ -1347,7 +1346,7 @@ class QuicksettingsHint(BaseModel):
def quicksettings_hint():
return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]

app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=list[QuicksettingsHint])

app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])

Expand Down