
Commit 79de09c

linter
AUTOMATIC1111 committed Jun 16, 2024
1 parent 5b2a60b commit 79de09c
Showing 2 changed files with 16 additions and 13 deletions.
19 changes: 10 additions & 9 deletions modules/models/sd3/other_impls.py
@@ -1,6 +1,7 @@
 ### This file contains impls for underlying related models (CLIP, T5, etc)

-import torch, math
+import torch
+import math
 from torch import nn
 from transformers import CLIPTokenizer, T5TokenizerFast

@@ -14,7 +15,7 @@ def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
     b, _, dim_head = q.shape
     dim_head //= heads
-    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
+    q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
     return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
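
The map/lambda to list-comprehension rewrites in this commit are behaviour-preserving: both forms produce the same tensors, and the comprehension is simply the spelling that comprehension-style lint rules prefer (flake8-comprehensions / ruff C417, assuming that is the rule behind this cleanup). A minimal standalone sketch of the equivalence, using illustrative shapes rather than anything from the repository:

    import torch

    b, heads, dim_head = 2, 4, 8
    q = torch.randn(b, 10, heads * dim_head)
    k = torch.randn(b, 10, heads * dim_head)
    v = torch.randn(b, 10, heads * dim_head)

    # Old style: map + lambda returns a lazy map object that is unpacked on assignment.
    q1, k1, v1 = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
    # New style: an eager list comprehension doing the same reshape/transpose.
    q2, k2, v2 = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]

    assert torch.equal(q1, q2) and torch.equal(k1, k2) and torch.equal(v1, v2)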

@@ -89,8 +90,8 @@ def forward(self, x, mask=None, intermediate_output=None):
             if intermediate_output < 0:
                 intermediate_output = len(self.layers) + intermediate_output
         intermediate = None
-        for i, l in enumerate(self.layers):
-            x = l(x, mask)
+        for i, layer in enumerate(self.layers):
+            x = layer(x, mask)
             if i == intermediate_output:
                 intermediate = x.clone()
         return x, intermediate
@@ -215,7 +216,7 @@ def tokenize_with_weights(self, text:str):

 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
+        tokens = [a[0] for a in token_weight_pairs[0]]
         out, pooled = self([tokens])
         if pooled is not None:
             first_pooled = pooled[0:1].cpu()
@@ -229,7 +230,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = ["last", "pooled", "hidden"]
     def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, return_projected_pooled=True):
+                 special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
         super().__init__()
         assert layer in self.LAYERS
         self.transformer = model_class(textmodel_json_config, dtype, device)
@@ -240,7 +241,7 @@ def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, te
             param.requires_grad = False
         self.layer = layer
         self.layer_idx = None
-        self.special_tokens = special_tokens
+        self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
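
The `special_tokens` change replaces a mutable default argument with a `None` sentinel, the usual fix for what flake8-bugbear/ruff report as B006 (assuming that is the warning raised here). A dict default is built once, when the function is defined, and is then shared by every call; resolving the fallback inside the body gives each call its own dict. A small illustration with hypothetical helpers, not code from the repository:

    def add_token_bad(token, registry={}):           # one dict, created at definition time
        registry[token] = len(registry)
        return registry

    def add_token_good(token, registry=None):        # fresh dict per call unless one is passed in
        registry = registry if registry is not None else {}
        registry[token] = len(registry)
        return registry

    print(add_token_bad("a"), add_token_bad("b"))    # {'a': 0, 'b': 1} {'a': 0, 'b': 1} - state leaks between calls
    print(add_token_good("a"), add_token_good("b"))  # {'a': 0} {'b': 0} - calls stay independent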
@@ -465,8 +466,8 @@ def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermed
         intermediate = None
         x = self.embed_tokens(input_ids)
         past_bias = None
-        for i, l in enumerate(self.block):
-            x, past_bias = l(x, past_bias)
+        for i, layer in enumerate(self.block):
+            x, past_bias = layer(x, past_bias)
             if i == intermediate_output:
                 intermediate = x.clone()
         x = self.final_layer_norm(x)
10 changes: 6 additions & 4 deletions modules/models/sd3/sd3_impls.py
@@ -1,6 +1,8 @@
 ### Impls of the SD3 core diffusion model and VAE

-import torch, math, einops
+import torch
+import math
+import einops
 from modules.models.sd3.mmdit import MMDiT
 from PIL import Image

@@ -214,7 +216,7 @@ def forward(self, x):
         k = self.k(hidden)
         v = self.v(hidden)
         b, c, h, w = q.shape
-        q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
+        q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]
         hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
         hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
         hidden = self.proj_out(hidden)
@@ -259,7 +261,7 @@ def __init__(self, ch=128, ch_mult=(1,2,4,4), num_res_blocks=2, in_channels=3, z
             attn = torch.nn.ModuleList()
             block_in = ch*in_ch_mult[i_level]
             block_out = ch*ch_mult[i_level]
-            for i_block in range(num_res_blocks):
+            for _ in range(num_res_blocks):
                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                 block_in = block_out
             down = torch.nn.Module()
@@ -318,7 +320,7 @@ def __init__(self, ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, res
         for i_level in reversed(range(self.num_resolutions)):
             block = torch.nn.ModuleList()
             block_out = ch * ch_mult[i_level]
-            for i_block in range(self.num_res_blocks + 1):
+            for _ in range(self.num_res_blocks + 1):
                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                 block_in = block_out
             up = torch.nn.Module()
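
The remaining renames are purely cosmetic: unused loop counters become `_`, the conventional way to tell a linter the value is deliberately ignored, and the single-letter `l` in the earlier hunks becomes `layer`, avoiding the ambiguous-variable-name warning (E741) for identifiers easily confused with `1` and `I`. A short sketch of both patterns with made-up modules, not repository code:

    import torch

    # Unused loop variable: '_' documents that only the repetition count matters.
    blocks = torch.nn.ModuleList()
    for _ in range(2):
        blocks.append(torch.nn.Linear(8, 8))

    # Descriptive name instead of 'l' when iterating over layers.
    x = torch.randn(1, 8)
    for layer in blocks:
        x = layer(x)
    print(x.shape)  # torch.Size([1, 8])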
