In [10]:
import torch 
from pathlib import Path 
from mmengine import Config 

In [11]:
import os
os.getcwd()

'c:\\Users\\EDELLAA6Y\\multitask\\refactoring2\\burned\\notebooks'

In [12]:
import warnings
from torch.nn import functional as F


def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=True):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > output_h:
                if (
                    (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
                    and (output_h - 1) % (input_h - 1)
                    and (output_w - 1) % (input_w - 1)
                ):
                    warnings.warn(
                        f"When align_corners={align_corners}, "
                        "the output would more aligned if "
                        f"input size {(input_h, input_w)} is `x+1` and "
                        f"out size {(output_h, output_w)} is `nx+1`"
                    )
    return F.interpolate(input, size, scale_factor, mode, align_corners)


In [13]:
from mmseg.models.segmentors.encoder_decoder import EncoderDecoder
from mmseg.registry import MODELS
from mmseg.utils import OptSampleList
from torch import Tensor
from torch.nn import functional as F


@MODELS.register_module()
class CustomEncoderDecoder(EncoderDecoder):
    def _forward(self, inputs: Tensor, data_samples: OptSampleList = None) -> Tensor:
        """Network forward process.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            Tensor: Forward output of model without any post-processes.
        """
        x = self.extract_feat(inputs)
        feat = self.decode_head(x)
        out = self.decode_head.cls_seg(feat)
        out = F.interpolate(out, size=inputs.shape[2:], mode="bilinear", align_corners=True)

        if self.decode_head.has_aux_output():
            aux = self.decode_head.cls_seg_aux(feat)
            aux = F.interpolate(aux, size=inputs.shape[2:], mode="bilinear", align_corners=True)
            return out, aux

        return out


In [14]:
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
import warnings
from abc import ABCMeta, abstractmethod

from mmengine.model import BaseModule
from mmseg.models.builder import build_loss
from mmseg.structures import build_pixel_sampler


class CustomBaseDecodeHead(BaseModule, metaclass=ABCMeta):
    """Custom class for BaseDecodeHead to simply remove the loss from the head."""

    def __init__(
        self,
        in_channels,
        channels,
        *,
        num_classes,
        aux_classes=None,
        aux_factor=None,
        out_channels=None,
        threshold=None,
        dropout_ratio=0.1,
        conv_cfg=None,
        norm_cfg=None,
        act_cfg=dict(type="ReLU"),
        in_index=-1,
        input_transform=None,
        loss_decode=None,
        ignore_index=255,
        sampler=None,
        align_corners=False,
        init_cfg=dict(type="Normal", std=0.01, override=dict(name="conv_seg")),
    ):
        super().__init__(init_cfg)
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index

        self.ignore_index = ignore_index
        self.align_corners = align_corners

        if out_channels is None:
            if num_classes == 2:
                warnings.warn(
                    "For binary segmentation, we suggest using"
                    "`out_channels = 1` to define the output"
                    "channels of segmentor, and use `threshold`"
                    "to convert `seg_logits` into a prediction"
                    "applying a threshold"
                )
            out_channels = num_classes

        if out_channels != num_classes and out_channels != 1:
            raise ValueError(
                "out_channels should be equal to num_classes,"
                "except binary segmentation set out_channels == 1 and"
                f"num_classes == 2, but got out_channels={out_channels}"
                f"and num_classes={num_classes}"
            )

        if out_channels == 1 and threshold is None:
            threshold = 0.5
            warnings.warn("threshold is not defined for binary, and defaults" "to 0.5")
        self.num_classes = num_classes
        self.out_channels = out_channels
        self.threshold = threshold

        if isinstance(loss_decode, dict):
            self.loss_decode = build_loss(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(build_loss(loss))
        else:
            warnings.warn("Loss not instantiated, use manual .forward() calls")
            self.loss_decode = None

        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
        if aux_classes is not None:
            self.conv_seg_aux = nn.Conv2d(channels, aux_classes, kernel_size=1)
            self.aux_factor = aux_factor
        else:
            self.conv_seg_aux = None
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only single feature map
        will be selected. So in_channels and in_index must be of type int.
        When input_transform

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resize to the
                    same size as first one and than concat together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundle into
                    a list and passed into decode head.
                None: Only one select feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ["resize_concat", "multiple_select"]
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == "resize_concat":
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """
        if self.input_transform == "resize_concat":
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners)
                for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == "multiple_select":
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @abstractmethod
    def forward(self, inputs, return_feat: bool = False):
        """Placeholder of forward function."""
        pass

    def has_aux_output(self):
        """Whether the head has auxiliary output."""
        return self.conv_seg_aux is not None

    def cls_seg(self, feat: torch.Tensor) -> torch.Tensor:
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    def cls_seg_aux(self, feat: torch.Tensor) -> torch.Tensor:
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg_aux(feat)
        return output


class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, act_cfg, align_corners, **kwargs):
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for pool_scale in pool_scales:
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(
                        self.in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg,
                        **kwargs,
                    ),
                )
            )

    def forward(self, x):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            upsampled_ppm_out = resize(ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners)
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs


@MODELS.register_module()
class CustomUPerHead(CustomBaseDecodeHead):
    """Unified Perceptual Parsing for Scene Understanding.

    This head is the implementation of `UPerNet
    <https://arxiv.org/abs/1807.10221>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super().__init__(input_transform="multiple_select", **kwargs)
        # PSP Module
        self.psp_modules = PPM(
            pool_scales,
            self.in_channels[-1],
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners,
        )
        self.bottleneck = ConvModule(
            self.in_channels[-1] + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
        )
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels[:-1]:  # skip the top layer
            l_conv = ConvModule(
                in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False,
            )
            fpn_conv = ConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False,
            )
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        self.fpn_bottleneck = ConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
        )

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        inputs = self._transform_inputs(inputs)

        # build laterals
        laterals = [lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)]

        laterals.append(self.psp_forward(inputs))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + resize(
                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
            )

        # build outputs
        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
        # append psp feature
        fpn_outs.append(laterals[-1])

        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = resize(
                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
            )
        fpn_outs = torch.cat(fpn_outs, dim=1)
        feats = self.fpn_bottleneck(fpn_outs)
        return feats

    def forward(self, inputs):
        """Forward function."""
        return self._forward_feature(inputs)


In [15]:
from mmseg.registry import MODELS

In [16]:
cfg_path = Path("..\configs\single\pretrained\ems_upernet-rn50_single_50ep.py")
config = Config.fromfile(cfg_path)
model_config = config["model"]


In [17]:
resnet = MODELS.build(model_config)



In [18]:
resnet

CustomEncoderDecoder(
  (data_preprocessor): BaseDataPreprocessor()
  (backbone): ResNet(
    (conv1): Conv2d(12, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): ResLayer(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU

In [19]:
resnet.decode_head

CustomUPerHead(
  (conv_seg): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
  (dropout): Dropout2d(p=0.1, inplace=False)
  (psp_modules): PPM(
    (0): Sequential(
      (0): AdaptiveAvgPool2d(output_size=1)
      (1): ConvModule(
        (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (1): Sequential(
      (0): AdaptiveAvgPool2d(output_size=2)
      (1): ConvModule(
        (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (2): Sequential(
      (0): AdaptiveAvgPool2d(output_size=3)
      (1): ConvModule(
        (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum

In [20]:
import warnings
from typing import Any, Callable, Optional

from pytorch_lightning import LightningModule
from torch import nn
from torch.optim import AdamW
from torchmetrics import F1Score, JaccardIndex
from mmseg.registry import MODELS

class BaseModule(LightningModule):
    def __init__(
        self,
        config: dict,
        tiler: Optional[Callable] = None,
        predict_callback: Optional[Callable] = None,
    ):
        super().__init__()
        self.model = MODELS.build(config)
        self.model.cfg = config
        self.tiler = tiler
        self.predict_callback = predict_callback
        self.train_metrics = nn.ModuleDict(
            {
                "train_f1": F1Score(task="binary", ignore_index=255, average="macro"),
                "train_iou": JaccardIndex(task="binary", ignore_index=255, average="macro"),
            }
        )
        self.val_metrics = nn.ModuleDict(
            {
                "val_f1": F1Score(task="binary", ignore_index=255, average="macro"),
                "val_iou": JaccardIndex(task="binary", ignore_index=255, average="macro"),
            }
        )
        self.test_metrics = nn.ModuleDict(
            {
                "test_f1": F1Score(task="binary", ignore_index=255, average="macro"),
                "test_iou": JaccardIndex(task="binary", ignore_index=255, average="macro"),
            }
        )

    def init_pretrained(self) -> None:
        assert self.model.cfg, "Model config is not set"
        config = self.model.cfg.backbone
        if "pretrained" not in config or config.pretrained is None:
            warnings.warn("No pretrained weights are specified")
            return
        self.model.backbone.load_state_dict(torch.load(config.pretrained), strict=False)
        for param in self.model.backbone.parameters():
            param.requires_grad = False

    def configure_optimizers(self) -> Any:
        return AdamW(self.parameters(), lr=1e-4, weight_decay=1e-4)


class SingleTaskModule(BaseModule):
    def __init__(
        self,
        config: dict,
        tiler: Callable[..., Any] | None = None,
        predict_callback: Callable[..., Any] | None = None,
        loss: str = "bce",
    ):
        super().__init__(config, tiler, predict_callback)
        if loss == "bce":
            self.criterion_decode = SoftBCEWithLogitsLoss(ignore_index=255, pos_weight=torch.tensor(3.0))
        else:
            self.criterion_decode = DiceLoss(mode="binary", from_logits=True, ignore_index=255)

    def training_step(self, batch: Any, batch_idx: int):
        x = batch["S2L2A"]
        y_del = batch["DEL"]

        # lc = batch["ESA_LC"]
        # x = torch.cat([x, lc.unsqueeze(1)], dim=1)
        decode_out = self.model(x)
        loss_decode = self.criterion_decode(decode_out.squeeze(1), y_del.float())
        loss = loss_decode

        self.log("train_loss", loss, on_step=True, prog_bar=True)
        for metric_name, metric in self.train_metrics.items():
            metric(decode_out.squeeze(1), y_del.float())
            self.log(metric_name, metric, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch: Any, batch_idx: int):
        x = batch["S2L2A"]
        y_del = batch["DEL"]
        # lc = batch["ESA_LC"]
        # x = torch.cat([x, lc.unsqueeze(1)], dim=1)
        decode_out = self.model(x)
        loss_decode = self.criterion_decode(decode_out.squeeze(1), y_del.float())
        loss = loss_decode

        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        for metric_name, metric in self.val_metrics.items():
            metric(decode_out.squeeze(1), y_del.float())
            self.log(metric_name, metric, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch: Any, batch_idx: int):
        x = batch["S2L2A"]
        y_del = batch["DEL"]
        # lc = batch["ESA_LC"]
        # x = torch.cat([x, lc.unsqueeze(1)], dim=1)
        decode_out = self.model(x)
        loss_decode = self.criterion_decode(decode_out.squeeze(1), y_del.float())
        loss = loss_decode

        self.log("test_loss", loss, on_epoch=True, logger=True)
        for metric_name, metric in self.test_metrics.items():
            metric(decode_out.squeeze(1), y_del.float())
            self.log(metric_name, metric, on_epoch=True, logger=True)
        return loss

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        full_image = batch["S2L2A"]

        def callback(batch: Any):
            del_out = self.model(batch)  # [b, 1, h, w]
            return del_out.squeeze(1)  # [b, h, w]

        full_pred = self.tiler(full_image[0], callback=callback)
        batch["pred"] = torch.sigmoid(full_pred)
        return batch

    def on_predict_batch_end(self, outputs: Any | None, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        self.predict_callback(batch)


In [21]:
module = SingleTaskModule(model_config, loss = "bce")
module.init_pretrained()



NameError: name 'SoftBCEWithLogitsLoss' is not defined

In [23]:
cfg_path = Path("..\configs\single\pretrained\ems_upernet-rn50_single_50ep.py")
config = Config.fromfile(cfg_path)
model_config = config["model"]
resnet = MODELS.build(model_config)
x = torch.randn(2,12,512,512)
y = resnet(x)
print()
print("x.shape:",x.shape)
print("y.shape:",y.shape)

xb = resnet.backbone(x)
print("backbone res:",[xbi.shape for xbi in xb])
feat = resnet.decode_head(xb)
print("feat:",feat.shape)
out = resnet.decode_head.cls_seg(feat)
print("out:",out.shape)
out = F.interpolate(out, size=x.shape[2:], mode="bilinear", align_corners=True) #2,1,512,512
print("out:",out.shape)




x.shape: torch.Size([2, 12, 512, 512])
y.shape: torch.Size([2, 1, 512, 512])
backbone res: [torch.Size([2, 256, 128, 128]), torch.Size([2, 512, 64, 64]), torch.Size([2, 1024, 32, 32]), torch.Size([2, 2048, 16, 16])]
feat: torch.Size([2, 512, 128, 128])
out: torch.Size([2, 1, 128, 128])
out: torch.Size([2, 1, 512, 512])


In [27]:
cfg_path = Path("..\configs\single\pretrained\ems_upernet-rn50_single_50ep_copy.py")
config = Config.fromfile(cfg_path)
model_config = config["model"]
resnet = MODELS.build(model_config)
x = torch.randn(2,12,512,512)
y = resnet(x)
print()
print("x.shape:",x.shape)
print("y.shape:",y.shape)

xb = resnet.backbone(x)
print("backbone res:",[xbi.shape for xbi in xb])
feat = resnet.decode_head(xb)
print("feat:",feat.shape)
out = resnet.decode_head.cls_seg(feat)
print("out:",out.shape)
out = F.interpolate(out, size=x.shape[2:], mode="bilinear", align_corners=True) #2,1,512,512
print("out:",out.shape)




x.shape: torch.Size([2, 12, 512, 512])
y.shape: torch.Size([2, 1, 512, 512])
backbone res: [torch.Size([2, 96, 128, 128]), torch.Size([2, 192, 64, 64]), torch.Size([2, 384, 32, 32]), torch.Size([2, 768, 16, 16])]
feat: torch.Size([2, 512, 128, 128])
out: torch.Size([2, 1, 128, 128])
out: torch.Size([2, 1, 512, 512])


In [6]:
cfg_path = Path("..\configs\single\pretrained\swin\swin2.py")
config = Config.fromfile(cfg_path)
model_config = config["model"]
swin = MODELS.build(model_config)
x = torch.randn(2,12,512,512)
y = swin(x)
print()
print("x.shape:",x.shape)
print("y.shape:",y.shape)

xb = swin.backbone(x)
print("backbone res:",[xbi.shape for xbi in xb])
feat = swin.decode_head(xb)
print("feat:",feat.shape)
out = swin.decode_head.cls_seg(feat)
print("out:",out.shape)
out = F.interpolate(out, size=x.shape[2:], mode="bilinear", align_corners=True) #2,1,512,512
print("out:",out.shape)




x.shape: torch.Size([2, 12, 512, 512])
y.shape: torch.Size([2, 1, 128, 128])
backbone res: [torch.Size([2, 96, 128, 128]), torch.Size([2, 192, 64, 64]), torch.Size([2, 384, 32, 32]), torch.Size([2, 768, 16, 16])]
feat: torch.Size([2, 1, 128, 128])


RuntimeError: Given groups=1, weight of size [1, 512, 1, 1], expected input[2, 1, 128, 128] to have 512 channels, but got 1 channels instead

In [25]:
resnet

CustomEncoderDecoder(
  (data_preprocessor): BaseDataPreprocessor()
  (backbone): ResNet(
    (conv1): Conv2d(12, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): ResLayer(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU

In [8]:
swin.decode_head

UPerHead(
  input_transform=multiple_select, ignore_index=255, align_corners=False
  (loss_decode): CrossEntropyLoss(avg_non_ignore=False)
  (conv_seg): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
  (dropout): Dropout2d(p=0.1, inplace=False)
  (psp_modules): PPM(
    (0): Sequential(
      (0): AdaptiveAvgPool2d(output_size=1)
      (1): ConvModule(
        (conv): Conv2d(768, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (1): Sequential(
      (0): AdaptiveAvgPool2d(output_size=2)
      (1): ConvModule(
        (conv): Conv2d(768, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (2): Sequential(
      (0): AdaptiveAvgPool2d(output_size=3)
      (1): ConvModule(
        (c

In [5]:
swin.decode_head.cls_seg

<bound method BaseDecodeHead.cls_seg of UPerHead(
  input_transform=multiple_select, ignore_index=255, align_corners=False
  (loss_decode): CrossEntropyLoss(avg_non_ignore=False)
  (conv_seg): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
  (dropout): Dropout2d(p=0.1, inplace=False)
  (psp_modules): PPM(
    (0): Sequential(
      (0): AdaptiveAvgPool2d(output_size=1)
      (1): ConvModule(
        (conv): Conv2d(768, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (1): Sequential(
      (0): AdaptiveAvgPool2d(output_size=2)
      (1): ConvModule(
        (conv): Conv2d(768, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (2): Sequential(
      (0): AdaptiveAvgPool2d(output_s