This is the part of ControlNet-XS in which the control net and the base net are executed together.

---

In [4]:
from abc import abstractmethod

import torch
from torch import nn

In [5]:
class TimestepBlock(nn.Module):
    """Any module whose forward() takes timestep embeddings as a second argument."""

    @abstractmethod
    def forward(self, x, emb):
        """Apply the module to `x` given `emb` timestep embeddings."""
        # nn.Module is not ABCMeta-based, so @abstractmethod alone does not prevent
        # instantiation; raise so a missing override fails loudly instead of
        # silently returning None.
        raise NotImplementedError


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """A sequential container that passes extra inputs to the children that accept them.

    Routing rules, applied to each child in order:
      * ``TimestepBlock`` children are called as ``layer(x, emb)``;
      * children whose class (or any base class) is named ``SpatialTransformer``
        are called as ``layer(x, context)``;
      * every other child is called as ``layer(x)``.

    The keyword arguments ``skip_time_mix``, ``time_context``, ``num_video_frames``,
    ``time_context_cat`` and ``use_crossframe_attention_in_spatial_layers`` are
    accepted for signature compatibility but are not used here.
    """

    @staticmethod
    def _is_spatial_transformer(layer):
        # Match by class name across the MRO instead of isinstance():
        # SpatialTransformer is not defined/imported in the visible cells, so
        # referencing it directly raised NameError for every non-TimestepBlock layer.
        # Walking the MRO keeps isinstance's subclass semantics.
        return any(cls.__name__ == 'SpatialTransformer' for cls in type(layer).__mro__)

    def forward(self, x, emb, context=None, skip_time_mix=False, time_context=None, num_video_frames=None, time_context_cat=None, use_crossframe_attention_in_spatial_layers=False):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif self._is_spatial_transformer(layer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x

In [10]:
class TwoStreamControlNet(nn.Module):
    """Two-stream ControlNet-XS model: a base UNet plus a control UNet.

    The two UNets are coupled through 1x1 convolutions built by ``make_zero_conv``
    (each wrapped in ``zero_module``). Only module construction is shown here.
    Several options of the original implementation are fixed as constants at
    the top of ``__init__`` to keep the code short.

    Args:
        in_channels, model_channels, out_channels, num_res_blocks,
        attention_resolutions, and the remaining UNet options: forwarded
            unchanged to the control UNet and (when built here) the base UNet.
        hint_channels: number of channels of the hint input processed by
            ``input_hint_block``.
        control_model_ratio: forwarded to the control UNet (per the original
            comment, the control model size relative to the base model, in
            [0, 1]); also scales the output width of ``input_hint_block`` to
            ``int(model_channels * control_model_ratio)``.
        base_model: an existing base UNet to reuse; when ``None`` a new
            ``UNetModel`` is built from the same arguments.
        learn_embedding, control_mode: stored as attributes; not used during
            construction.
    """

    def __init__(
            self,
            in_channels, model_channels, out_channels, hint_channels,
            num_res_blocks,
            attention_resolutions,
            dropout=0,
            channel_mult=(1, 2, 4, 8),
            conv_resample=True,
            dims=2,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=-1, num_head_channels=-1, num_heads_upsample=-1,
            use_scale_shift_norm=False,
            resblock_updown=False,
            use_new_attention_order=False,
            adm_in_channels=None,
            use_spatial_transformer=False,  # custom transformer support
            transformer_depth=1,  # custom transformer support
            context_dim=None,  # custom transformer support
            n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
            legacy=False,
            spatial_transformer_attn_type="softmax",
            use_linear_in_transformer=False,
            num_classes=None,
            control_model_ratio=1.0,    # ratio of the control model size compared to the base model. [0, 1]
            base_model=None,
            learn_embedding=False,
            control_mode='canny',
        ):
        # Options that are configurable in the original ControlNet-XS code, fixed
        # here to keep things simple. These must be plain strings: a trailing
        # comma would silently turn them into 1-tuples (e.g. ('encoder',)).
        infusion2control = 'cat'   # how to infuse intermediate information into the control net? {'add', 'cat', None}
        infusion2base = 'add'      # how to infuse intermediate information into the base net? {'add', 'cat'}
        guiding = 'encoder'        # use just encoder for control or the whole encoder + decoder net? {'encoder', 'encoder_double', 'full'}
        two_stream_mode = 'cross'  # mode for the two stream infusion. {'cross', 'sequential'}

        super().__init__()

        self.control_mode = control_mode
        self.learn_embedding = learn_embedding
        self.infusion2control = infusion2control
        self.infusion2base = infusion2base
        # 'cat' concatenates base features onto the control stream, doubling its input width.
        self.in_ch_factor = 1 if infusion2control == 'add' else 2
        self.guiding = guiding
        self.two_stream_mode = two_stream_mode
        self.control_model_ratio = control_model_ratio
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Previously hard-coded to 2, which ignored the `dims` argument that
        # make_zero_conv relies on (input_hint_block already used `dims`).
        self.dims = dims
        self.model_channels = model_channels
        self.no_control = False
        self.control_scale = 1.0

        self.hint_model = None

        ################# start control model variations #################
        if base_model is None:
            base_model = UNetModel(
                adm_in_channels=adm_in_channels, num_classes=num_classes, use_checkpoint=use_checkpoint,
                in_channels=in_channels, out_channels=out_channels, model_channels=model_channels,
                attention_resolutions=attention_resolutions, num_res_blocks=num_res_blocks,
                channel_mult=channel_mult, num_head_channels=num_head_channels, use_spatial_transformer=use_spatial_transformer,
                use_linear_in_transformer=use_linear_in_transformer, transformer_depth=transformer_depth,
                context_dim=context_dim, spatial_transformer_attn_type=spatial_transformer_attn_type,
                legacy=legacy, dropout=dropout,
                conv_resample=conv_resample, dims=dims, use_fp16=use_fp16, num_heads=num_heads,
                num_heads_upsample=num_heads_upsample, use_scale_shift_norm=use_scale_shift_norm,
                resblock_updown=resblock_updown, use_new_attention_order=use_new_attention_order,
                n_embed=n_embed,
            )

        self.control_model = ControlledXLUNetModel(
            adm_in_channels=adm_in_channels, num_classes=num_classes, use_checkpoint=use_checkpoint,
            in_channels=in_channels, out_channels=out_channels, model_channels=model_channels,
            attention_resolutions=attention_resolutions, num_res_blocks=num_res_blocks,
            channel_mult=channel_mult, num_head_channels=num_head_channels, use_spatial_transformer=use_spatial_transformer,
            use_linear_in_transformer=use_linear_in_transformer, transformer_depth=transformer_depth,
            context_dim=context_dim, spatial_transformer_attn_type=spatial_transformer_attn_type,
            legacy=legacy, dropout=dropout,
            conv_resample=conv_resample, dims=dims, use_fp16=use_fp16, num_heads=num_heads,
            num_heads_upsample=num_heads_upsample, use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown, use_new_attention_order=use_new_attention_order,
            n_embed=n_embed,
            infusion2control=infusion2control,
            guiding=guiding, two_stream_mode=two_stream_mode, control_model_ratio=control_model_ratio,
        )

        self.diffusion_model = base_model
        ################# end control model variations #################

        self.enc_zero_convs_out = nn.ModuleList([])
        self.enc_zero_convs_in = nn.ModuleList([])

        self.middle_block_out = nn.ModuleList([])
        self.middle_block_in = nn.ModuleList([])

        self.dec_zero_convs_out = nn.ModuleList([])
        self.dec_zero_convs_in = nn.ModuleList([])

        ################# Gather channel sizes #################
        # Each entry is an (in_channels, out_channels) pair per block.
        # With guiding == 'encoder' the control decoder is unused, so only the
        # base decoder's channels are gathered.
        ch_inout_ctr = {
            'enc': self._encoder_channels(self.control_model.input_blocks),
            'mid': [(self.control_model.middle_block[0].channels, self.control_model.middle_block[-1].out_channels)],
            'dec': [],
        }
        ch_inout_base = {
            'enc': self._encoder_channels(base_model.input_blocks),
            'mid': [(base_model.middle_block[0].channels, base_model.middle_block[-1].out_channels)],
            'dec': self._decoder_channels(base_model.output_blocks),
        }

        self.ch_inout_ctr = ch_inout_ctr
        self.ch_inout_base = ch_inout_base

        ################# Build zero convolutions (two_stream_mode == 'cross') #################
        # infusion2control == 'cat': base encoder outputs are passed, at full
        # width, into the control stream.
        for ch_io_base in ch_inout_base['enc']:
            self.enc_zero_convs_in.append(self.make_zero_conv(
                in_channels=ch_io_base[1], out_channels=ch_io_base[1])
            )

        # infusion2base == 'add': control middle-block output is mapped to the base middle-block width.
        self.middle_block_out = self.make_zero_conv(ch_inout_ctr['mid'][-1][1], ch_inout_base['mid'][-1][1])

        # guiding == 'encoder': control encoder outputs, taken in reverse order,
        # are mapped onto the base middle block followed by the base decoder blocks.
        self.dec_zero_convs_out.append(
            self.make_zero_conv(ch_inout_ctr['enc'][-1][1], ch_inout_base['mid'][-1][1])
        )
        for i in range(1, len(ch_inout_ctr['enc'])):
            self.dec_zero_convs_out.append(
                self.make_zero_conv(ch_inout_ctr['enc'][-(i + 1)][1], ch_inout_base['dec'][i - 1][1])
            )

        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, int(model_channels * control_model_ratio), 3, padding=1))
        )

        # One scale per encoder-out conv, one for the middle block, one per decoder-out conv.
        # NOTE(review): enc_zero_convs_out is never populated in this configuration,
        # so it contributes no entries here.
        scale_list = [1.] * len(self.enc_zero_convs_out) + [1.] + [1.] * len(self.dec_zero_convs_out)
        self.register_buffer('scale_list', torch.tensor(scale_list))

    @staticmethod
    def _encoder_channels(input_blocks):
        """Return an (in_channels, out_channels) pair for each recognised encoder block."""
        pairs = []
        for module in input_blocks:
            first = module[0]
            if isinstance(first, nn.Conv2d):
                pairs.append((first.in_channels, first.out_channels))
            elif isinstance(first, (ResBlock, ResBlock_orig)):
                pairs.append((first.channels, first.out_channels))
            elif isinstance(first, Downsample):
                pairs.append((first.channels, module[-1].out_channels))
        return pairs

    @staticmethod
    def _decoder_channels(output_blocks):
        """Return an (in_channels, out_channels) pair for each recognised decoder block."""
        pairs = []
        for module in output_blocks:
            first = module[0]
            if isinstance(first, nn.Conv2d):
                pairs.append((first.in_channels, first.out_channels))
            elif isinstance(first, (ResBlock, ResBlock_orig)):
                pairs.append((first.channels, first.out_channels))
            elif isinstance(module[-1], Upsample):
                pairs.append((first.channels, module[-1].out_channels))
        return pairs

    def make_zero_conv(self, in_channels, out_channels=None):
        """Build a 1x1 convolution wrapped in ``zero_module`` and ``TimestepEmbedSequential``.

        Args:
            in_channels: input channel count.
            out_channels: output channel count; defaults to ``in_channels``.

        Unlike the previous version, this does not overwrite ``self.in_channels``
        or ``self.out_channels``. That side effect replaced the model's
        configured ``out_channels`` with the last zero conv's width.
        """
        out_channels = out_channels or in_channels
        return TimestepEmbedSequential(
            zero_module(conv_nd(self.dims, in_channels, out_channels, 1, padding=0))
        )

In [11]:
# NOTE(review): in_channels, model_channels, out_channels, hint_channels,
# num_res_blocks and attention_resolutions are required positional arguments,
# so this bare call raises the TypeError shown in the output below. Pass them
# explicitly to actually build the model.
TwoStreamControlNet()

TypeError: TwoStreamControlNet.__init__() missing 6 required positional arguments: 'in_channels', 'model_channels', 'out_channels', 'hint_channels', 'num_res_blocks', and 'attention_resolutions'