<a href="https://colab.research.google.com/github/NINE9-9-9/Flood-Detection/blob/main/FCN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.models.segmentation as segmentation
import pytorch_lightning as pl
!pip install segmentation-models-pytorch==0.1.0
import segmentation_models_pytorch as smp


class FCN(pl.LightningModule):
    """FCN-ResNet101 semantic-segmentation model for 13-channel input.

    Wraps torchvision's pretrained ``fcn_resnet101``. The backbone's first
    conv is replaced so it accepts 13 input channels (that layer is therefore
    freshly initialised, not pretrained), and the head's final conv is
    replaced to emit ``num_classes`` maps.

    Label convention used by the loss helpers: 0 marks an invalid pixel that
    is ignored; 1..num_classes are the real classes.

    Args:
        num_classes: number of output classes. NOTE: ``self.weight`` is a
            hard-coded 3-entry class-weight vector, so any other value also
            requires replacing ``self.weight``.
        learning_rate: Adam learning rate.
    """

    def __init__(self, num_classes=3, learning_rate=1e-4):
        super(FCN, self).__init__()

        # Per-class cross-entropy weights (presumably derived from class
        # frequencies -- TODO confirm). Plain tensor attribute, not a buffer;
        # the loss moves it to the logits' device when used.
        self.weight = torch.Tensor([1.93445299, 36.60054169, 2.19400729])

        self.model = segmentation.fcn_resnet101(pretrained=True)

        # Replace the first conv so the backbone accepts 13 input channels.
        self.model.backbone.conv1 = nn.Conv2d(13, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Replace the head's final 1x1 conv so it outputs num_classes maps.
        self.model.classifier[4] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))

        self.learning_rate = learning_rate

    def forward(self, x):
        """Return the ``'out'`` logits of the wrapped torchvision model."""
        return self.model(x)['out']

    def training_step(self, batch, batch_idx):
        """Compute and log the masked cross-entropy loss for one training batch."""
        # Lightning moves the batch onto this module's device before calling
        # the step, so no hard-coded .cuda() is needed (and CPU runs work).
        images = batch["image"]
        labels = batch["mask"].squeeze(1)
        outputs = self(images)
        # calc_loss_mask_invalid (CE + Dice) is available as an alternative.
        loss = self.cross_entropy_loss_mask_invalid(outputs, labels, weight=self.weight)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the masked cross-entropy loss for one validation batch."""
        images = batch["image"]
        labels = batch["mask"].squeeze(1)
        outputs = self(images)
        loss = self.cross_entropy_loss_mask_invalid(outputs, labels, weight=self.weight)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def cross_entropy_loss_mask_invalid(self, logits: torch.Tensor, target: torch.Tensor, weight=None) -> torch.Tensor:
        """Class-weighted cross-entropy averaged over valid pixels.

        Args:
            logits: (B, C, H, W) unnormalised class scores.
            target: (B, H, W) integer labels; 0 marks an invalid pixel and
                1..C are the classes.
            weight: optional length-C per-class weight tensor; moved to
                ``logits``' device before use.

        Returns:
            Scalar tensor: the sum of per-pixel losses over valid pixels
            divided by the number of valid pixels (+1e-6 to avoid 0/0).
        """
        assert logits.dim() == 4, "Expected logits to have 4 dimensions, but got {} dimensions.".format(logits.dim())
        assert target.dim() == 3, "Expected target to have 3 dimensions, but got {} dimensions.".format(target.dim())

        valid = target != 0
        # Shift labels 1..C down to class indices 0..C-1. Invalid pixels get a
        # placeholder index 0 and are zeroed out after the per-pixel loss.
        target_without_invalids = ((target - 1) * valid).long()

        if weight is not None:
            weight = weight.to(logits.device)
        criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')
        per_pixel = criterion(logits, target_without_invalids) * valid  # (B, H, W)

        return per_pixel.sum() / (valid.sum() + 1e-6)

    def dice_loss_mask_invalid(self, logits: torch.Tensor, target: torch.Tensor, smooth=1.) -> torch.Tensor:
        """Soft Dice loss over valid pixels, averaged over batch and classes.

        Both predictions and one-hot targets are masked, so pixels labelled 0
        (invalid) contribute to neither the intersection nor the union.

        Args:
            logits: (B, C, H, W) unnormalised class scores.
            target: (B, H, W) integer labels; 0 = invalid, 1..C = classes.
            smooth: additive smoothing in numerator and denominator.

        Returns:
            Scalar tensor: mean over (B, C) of ``1 - dice``.
        """
        assert logits.dim() == 4, "Expected logits to have 4 dimensions, but got {} dimensions.".format(logits.dim())
        assert target.dim() == 3, "Expected target to have 3 dimensions, but got {} dimensions.".format(target.dim())

        pred = torch.softmax(logits, dim=1)
        valid = target != 0  # (B, H, W)
        valid_mask = valid.unsqueeze(1).to(pred.dtype)  # (B, 1, H, W)
        target_without_invalids = ((target - 1) * valid).long()

        target_one_hot = torch.nn.functional.one_hot(
            target_without_invalids, num_classes=pred.shape[1]
        ).permute(0, 3, 1, 2).to(pred.dtype)  # (B, C, H, W)

        # Mask BOTH tensors: invalid pixels were mapped to class 0 above and
        # would otherwise inflate that class's target count in the union.
        pred_valid = pred * valid_mask
        target_valid = target_one_hot * valid_mask

        axis_red = (2, 3)  # reduce over H, W
        intersection = (pred_valid * target_valid).sum(dim=axis_red)  # (B, C)
        union = pred_valid.sum(dim=axis_red) + target_valid.sum(dim=axis_red)  # (B, C)

        dice_score = (2. * intersection + smooth) / (union + smooth)
        return (1 - dice_score).mean()

    def calc_loss_mask_invalid(self, logits, target, bce_weight=0.5, weight=None):
        """Return ``bce_weight * CE + (1 - bce_weight) * Dice`` (both masked) as a scalar tensor."""
        bce = self.cross_entropy_loss_mask_invalid(logits, target, weight=weight)
        dice = self.dice_loss_mask_invalid(logits, target)  # scalar
        return bce * bce_weight + dice * (1 - bce_weight)

In [None]:
class FCN2(pl.LightningModule):
    """FCN-ResNet50 semantic-segmentation model for 13-channel input.

    Wraps torchvision's pretrained ``fcn_resnet50``. The backbone's first
    conv is replaced so it accepts 13 input channels (freshly initialised),
    and the head's final conv is replaced to emit ``num_classes`` maps.

    Label convention used by the loss helpers: 0 marks an invalid pixel that
    is ignored; 1..num_classes are the real classes.

    NOTE: training logs masked cross-entropy only, while validation logs the
    CE + Dice blend, so ``train_loss`` and ``val_loss`` are not directly
    comparable.

    Args:
        num_classes: number of output classes. ``self.weight`` has 3 entries,
            so any other value also requires replacing ``self.weight``.
        learning_rate: Adam learning rate.
    """

    def __init__(self, num_classes=3, learning_rate=1e-4):
        super(FCN2, self).__init__()

        # Per-class cross-entropy weights (presumably derived from class
        # frequencies -- TODO confirm). Moved to the logits' device in the loss.
        self.weight = torch.Tensor([1.93445299, 36.60054169, 2.19400729])

        self.model = segmentation.fcn_resnet50(pretrained=True)

        # Replace the first conv so the backbone accepts 13 input channels.
        self.model.backbone.conv1 = nn.Conv2d(13, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Replace the head's final 1x1 conv so it outputs num_classes maps.
        self.model.classifier[4] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))

        self.learning_rate = learning_rate

    def forward(self, x):
        """Return the ``'out'`` logits of the wrapped torchvision model."""
        return self.model(x)['out']

    def training_step(self, batch, batch_idx):
        """Compute and log the masked cross-entropy loss for one training batch."""
        # Lightning moves the batch onto this module's device; no .cuda() needed.
        images = batch["image"]
        labels = batch["mask"].squeeze(1)
        outputs = self(images)
        loss = self.cross_entropy_loss_mask_invalid(outputs, labels, weight=self.weight)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the CE + Dice blended loss for one validation batch."""
        images = batch["image"]
        labels = batch["mask"].squeeze(1)
        outputs = self(images)
        loss = self.calc_loss_mask_invalid(outputs, labels, weight=self.weight)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def cross_entropy_loss_mask_invalid(self, logits: torch.Tensor, target: torch.Tensor, weight=None) -> torch.Tensor:
        """Class-weighted cross-entropy averaged over valid pixels.

        Args:
            logits: (B, C, H, W) unnormalised class scores.
            target: (B, H, W) integer labels; 0 marks an invalid pixel and
                1..C are the classes.
            weight: optional length-C per-class weight tensor; moved to
                ``logits``' device before use.

        Returns:
            Scalar tensor: the sum of per-pixel losses over valid pixels
            divided by the number of valid pixels (+1e-6 to avoid 0/0).
        """
        assert logits.dim() == 4, "Expected logits to have 4 dimensions, but got {} dimensions.".format(logits.dim())
        assert target.dim() == 3, "Expected target to have 3 dimensions, but got {} dimensions.".format(target.dim())

        valid = target != 0
        # Shift labels 1..C down to 0..C-1; invalid pixels get placeholder 0
        # and are zeroed out after the per-pixel loss.
        target_without_invalids = ((target - 1) * valid).long()

        if weight is not None:
            weight = weight.to(logits.device)
        criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')
        per_pixel = criterion(logits, target_without_invalids) * valid  # (B, H, W)

        return per_pixel.sum() / (valid.sum() + 1e-6)

    def dice_loss_mask_invalid(self, logits: torch.Tensor, target: torch.Tensor, smooth=1.) -> torch.Tensor:
        """Soft Dice loss over valid pixels, averaged over batch and classes.

        Both predictions and one-hot targets are masked, so pixels labelled 0
        (invalid) contribute to neither the intersection nor the union.

        Args:
            logits: (B, C, H, W) unnormalised class scores.
            target: (B, H, W) integer labels; 0 = invalid, 1..C = classes.
            smooth: additive smoothing in numerator and denominator.

        Returns:
            Scalar tensor: mean over (B, C) of ``1 - dice``.
        """
        assert logits.dim() == 4, "Expected logits to have 4 dimensions, but got {} dimensions.".format(logits.dim())
        assert target.dim() == 3, "Expected target to have 3 dimensions, but got {} dimensions.".format(target.dim())

        pred = torch.softmax(logits, dim=1)
        valid = target != 0  # (B, H, W)
        valid_mask = valid.unsqueeze(1).to(pred.dtype)  # (B, 1, H, W)
        target_without_invalids = ((target - 1) * valid).long()

        target_one_hot = torch.nn.functional.one_hot(
            target_without_invalids, num_classes=pred.shape[1]
        ).permute(0, 3, 1, 2).to(pred.dtype)  # (B, C, H, W)

        # Mask BOTH tensors: invalid pixels were mapped to class 0 above and
        # would otherwise inflate that class's target count in the union.
        pred_valid = pred * valid_mask
        target_valid = target_one_hot * valid_mask

        axis_red = (2, 3)  # reduce over H, W
        intersection = (pred_valid * target_valid).sum(dim=axis_red)  # (B, C)
        union = pred_valid.sum(dim=axis_red) + target_valid.sum(dim=axis_red)  # (B, C)

        dice_score = (2. * intersection + smooth) / (union + smooth)
        return (1 - dice_score).mean()

    def calc_loss_mask_invalid(self, logits, target, bce_weight=0.5, weight=None):
        """Return ``bce_weight * CE + (1 - bce_weight) * Dice`` (both masked) as a scalar tensor."""
        bce = self.cross_entropy_loss_mask_invalid(logits, target, weight=weight)
        dice = self.dice_loss_mask_invalid(logits, target)  # scalar
        return bce * bce_weight + dice * (1 - bce_weight)

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torchvision import models
from torch.optim import Adam

class DeepLabV3(pl.LightningModule):
    """DeepLabV3-ResNet101 segmentation model for 13-channel input.

    Wraps torchvision's pretrained ``deeplabv3_resnet101``. The backbone's
    first conv is replaced to accept 13 input channels (freshly initialised),
    and the head's final 1x1 conv is replaced to emit ``num_classes`` maps.
    Both training and validation use the CE + Dice blend from
    ``calc_loss_mask_invalid``.

    Label convention used by the loss helpers: 0 marks an invalid pixel that
    is ignored; 1..num_classes are the real classes.

    Args:
        learning_rate: Adam learning rate.
        num_classes: number of output classes (default 3). ``self.weight`` has
            3 entries, so any other value also requires replacing it.
    """

    def __init__(self, learning_rate=1e-4, num_classes=3):
        super(DeepLabV3, self).__init__()

        # Per-class cross-entropy weights (presumably derived from class
        # frequencies -- TODO confirm). Moved to the logits' device in the loss.
        self.weight = torch.Tensor([1.93445299, 36.60054169, 2.19400729])

        # Load the pretrained deeplabv3_resnet101 model.
        self.deeplab = models.segmentation.deeplabv3_resnet101(pretrained=True)

        # Replace the first conv so the backbone accepts 13 input channels.
        self.deeplab.backbone.conv1 = nn.Conv2d(13, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        # Replace the classifier's last layer to match the number of classes.
        self.deeplab.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1))

        self.learning_rate = learning_rate

    def forward(self, x):
        """Return the ``'out'`` logits of the wrapped torchvision model."""
        return self.deeplab(x)['out']

    def training_step(self, batch, batch_idx):
        """Compute and log the CE + Dice blended loss for one training batch."""
        # Lightning moves the batch onto this module's device; no .cuda() needed.
        x, y = batch["image"], batch["mask"].squeeze(1)
        y_hat = self(x)
        loss = self.calc_loss_mask_invalid(y_hat, y, weight=self.weight)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the CE + Dice blended loss for one validation batch."""
        x, y = batch["image"], batch["mask"].squeeze(1)
        y_hat = self(x)
        loss = self.calc_loss_mask_invalid(y_hat, y, weight=self.weight)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.learning_rate``."""
        return Adam(self.parameters(), lr=self.learning_rate)

    def cross_entropy_loss_mask_invalid(self, logits: torch.Tensor, target: torch.Tensor, weight=None) -> torch.Tensor:
        """Class-weighted cross-entropy averaged over valid pixels.

        Args:
            logits: (B, C, H, W) unnormalised class scores.
            target: (B, H, W) integer labels; 0 marks an invalid pixel and
                1..C are the classes.
            weight: optional length-C per-class weight tensor; moved to
                ``logits``' device before use.

        Returns:
            Scalar tensor: the sum of per-pixel losses over valid pixels
            divided by the number of valid pixels (+1e-6 to avoid 0/0).
        """
        assert logits.dim() == 4, "Expected logits to have 4 dimensions, but got {} dimensions.".format(logits.dim())
        assert target.dim() == 3, "Expected target to have 3 dimensions, but got {} dimensions.".format(target.dim())

        valid = target != 0
        # Shift labels 1..C down to 0..C-1; invalid pixels get placeholder 0
        # and are zeroed out after the per-pixel loss.
        target_without_invalids = ((target - 1) * valid).long()

        if weight is not None:
            weight = weight.to(logits.device)
        criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')
        per_pixel = criterion(logits, target_without_invalids) * valid  # (B, H, W)

        return per_pixel.sum() / (valid.sum() + 1e-6)

    def dice_loss_mask_invalid(self, logits: torch.Tensor, target: torch.Tensor, smooth=1.) -> torch.Tensor:
        """Soft Dice loss over valid pixels, averaged over batch and classes.

        Both predictions and one-hot targets are masked, so pixels labelled 0
        (invalid) contribute to neither the intersection nor the union.

        Args:
            logits: (B, C, H, W) unnormalised class scores.
            target: (B, H, W) integer labels; 0 = invalid, 1..C = classes.
            smooth: additive smoothing in numerator and denominator.

        Returns:
            Scalar tensor: mean over (B, C) of ``1 - dice``.
        """
        assert logits.dim() == 4, "Expected logits to have 4 dimensions, but got {} dimensions.".format(logits.dim())
        assert target.dim() == 3, "Expected target to have 3 dimensions, but got {} dimensions.".format(target.dim())

        pred = torch.softmax(logits, dim=1)
        valid = target != 0  # (B, H, W)
        valid_mask = valid.unsqueeze(1).to(pred.dtype)  # (B, 1, H, W)
        target_without_invalids = ((target - 1) * valid).long()

        target_one_hot = torch.nn.functional.one_hot(
            target_without_invalids, num_classes=pred.shape[1]
        ).permute(0, 3, 1, 2).to(pred.dtype)  # (B, C, H, W)

        # Mask BOTH tensors: invalid pixels were mapped to class 0 above and
        # would otherwise inflate that class's target count in the union.
        pred_valid = pred * valid_mask
        target_valid = target_one_hot * valid_mask

        axis_red = (2, 3)  # reduce over H, W
        intersection = (pred_valid * target_valid).sum(dim=axis_red)  # (B, C)
        union = pred_valid.sum(dim=axis_red) + target_valid.sum(dim=axis_red)  # (B, C)

        dice_score = (2. * intersection + smooth) / (union + smooth)
        return (1 - dice_score).mean()

    def calc_loss_mask_invalid(self, logits, target, bce_weight=0.5, weight=None):
        """Return ``bce_weight * CE + (1 - bce_weight) * Dice`` (both masked) as a scalar tensor."""
        bce = self.cross_entropy_loss_mask_invalid(logits, target, weight=weight)
        dice = self.dice_loss_mask_invalid(logits, target)  # scalar
        return bce * bce_weight + dice * (1 - bce_weight)

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

class PSPModule(nn.Module):
    """Pyramid pooling: fuse a feature map with pooled-and-upsampled copies of itself.

    For every bin size ``s``, the input is adaptive-average-pooled to
    ``s x s``, passed through a bias-free 1x1 conv, and bilinearly resized
    back to the input's spatial size. The input and all pyramid levels are
    concatenated along the channel axis, then reduced back to
    ``in_channels`` by a 1x1 conv followed by ReLU.

    Args:
        in_channels: channels of the input (and of the output).
        bin_sizes: sequence of pooled output sizes, one pyramid level each.
    """

    def __init__(self, in_channels, bin_sizes):
        super(PSPModule, self).__init__()
        levels = [self._make_stage(in_channels, bin_size) for bin_size in bin_sizes]
        self.stages = nn.ModuleList(levels)
        fused_channels = in_channels * (len(bin_sizes) + 1)
        self.bottleneck = nn.Conv2d(fused_channels, in_channels, kernel_size=1)
        self.relu = nn.ReLU()

    def _make_stage(self, in_channels, size):
        """Build one pyramid level: adaptive avg-pool to (size, size), then a 1x1 conv."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(size, size)),
            nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False),
        )

    def forward(self, feats):
        spatial = (feats.size(2), feats.size(3))
        upsampled = [
            nn.functional.interpolate(stage(feats), size=spatial, mode='bilinear', align_corners=True)
            for stage in self.stages
        ]
        fused = torch.cat([feats, *upsampled], dim=1)
        return self.relu(self.bottleneck(fused))

class PSPNet(pl.LightningModule):
    """PSPNet-style segmentation model on a pretrained ResNet-101 trunk.

    Pipeline: ResNet-101 conv trunk (avgpool/fc dropped, 2048 channels) ->
    ``PSPModule`` pyramid pooling -> two conv/BN/ReLU/x2-upsample stages ->
    1x1 classifier -> bilinear resize back to the input's spatial size.

    Label convention used by the loss: 0 marks an invalid pixel that is
    ignored; 1..num_classes are the real classes.

    Args:
        in_channels: number of input channels (replaces ResNet's first conv,
            which is therefore freshly initialised).
        num_classes: number of output classes. ``self.weight`` has 3 entries,
            so any other value also requires replacing it.
        bin_sizes: pyramid pooling bin sizes. ``None`` selects
            ``(1, 2, 4, 8, 16)``.
    """

    def __init__(self, in_channels=13, num_classes=3, bin_sizes=None):
        super(PSPNet, self).__init__()

        if bin_sizes is None:
            bin_sizes = (1, 2, 4, 8, 16)

        # Per-class cross-entropy weights (presumably derived from class
        # frequencies -- TODO confirm). Moved to the logits' device in the loss.
        self.weight = torch.Tensor([1.93445299, 36.60054169, 2.19400729])
        resnet = models.resnet101(pretrained=True)
        resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Keep everything except the final avgpool and fc layers.
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.psp = PSPModule(2048, bin_sizes)
        self.up_sampling = nn.Sequential(
            nn.Conv2d(2048, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),

            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
        )
        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        """Return per-pixel class logits at the input's spatial resolution."""
        input_size = tuple(x.shape[-2:])
        x = self.features(x)
        x = self.psp(x)
        x = self.up_sampling(x)
        x = self.classifier(x)
        # Resize to the exact input size so logits line up with the labels
        # for any H, W (a fixed x8 only matches when H, W are multiples of 32).
        return F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)

    def training_step(self, batch, batch_idx):
        """Compute and log the masked cross-entropy loss for one training batch."""
        # Lightning moves the batch onto this module's device; no .cuda() needed.
        x, y = batch["image"], batch["mask"].squeeze(1)
        y_hat = self(x)
        loss = self.cross_entropy_loss_mask_invalid(y_hat, y, weight=self.weight)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the masked cross-entropy loss for one validation batch."""
        x, y = batch["image"], batch["mask"].squeeze(1)
        y_hat = self(x)
        loss = self.cross_entropy_loss_mask_invalid(y_hat, y, weight=self.weight)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        """Optimise all parameters with Adam (lr = 1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def cross_entropy_loss_mask_invalid(self, logits: torch.Tensor, target: torch.Tensor, weight=None) -> torch.Tensor:
        """Class-weighted cross-entropy averaged over valid pixels.

        Args:
            logits: (B, C, H, W) unnormalised class scores.
            target: (B, H, W) integer labels; 0 marks an invalid pixel and
                1..C are the classes.
            weight: optional length-C per-class weight tensor; moved to
                ``logits``' device before use.

        Returns:
            Scalar tensor: the sum of per-pixel losses over valid pixels
            divided by the number of valid pixels (+1e-6 to avoid 0/0).
        """
        assert logits.dim() == 4, "Expected logits to have 4 dimensions, but got {} dimensions.".format(logits.dim())
        assert target.dim() == 3, "Expected target to have 3 dimensions, but got {} dimensions.".format(target.dim())

        valid = target != 0
        # Shift labels 1..C down to 0..C-1; invalid pixels get placeholder 0
        # and are zeroed out after the per-pixel loss.
        target_without_invalids = ((target - 1) * valid).long()

        if weight is not None:
            weight = weight.to(logits.device)
        criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')
        per_pixel = criterion(logits, target_without_invalids) * valid  # (B, H, W)

        return per_pixel.sum() / (valid.sum() + 1e-6)

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchvision.models as models
import torch.nn.functional as F

class SegNet(pl.LightningModule):
    """SegNet-style encoder/decoder on a pretrained VGG16-BN trunk, 13-channel input.

    The five VGG16-BN conv stages act as encoders. Each is followed by a 2x2
    max-pool whose indices are reused by ``max_unpool2d`` in the mirrored
    decoder.

    ``forward`` returns per-pixel class probabilities (softmax over dim 1).
    The training/validation losses instead use the pre-softmax logits from
    ``_forward_logits``, because ``nn.CrossEntropyLoss`` applies log-softmax
    itself; feeding it probabilities would apply softmax twice.

    Label convention used by the loss: 0 marks an invalid pixel that is
    ignored; 1..num_classes are the real classes.

    Args:
        num_classes: number of output classes. ``self.weight`` has 3 entries,
            so any other value also requires replacing it.
    """

    def __init__(self, num_classes=3):
        super(SegNet, self).__init__()
        # Per-class cross-entropy weights (presumably derived from class
        # frequencies -- TODO confirm). Moved to the logits' device in the loss.
        self.weight = torch.Tensor([1.93445299, 36.60054169, 2.19400729])
        # Pretrained VGG16 with batch norm.
        vgg16 = models.vgg16_bn(pretrained=True)

        # Replace VGG's first conv so it accepts 13 input channels.
        vgg16.features[0] = nn.Conv2d(13, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

        batchNorm_momentum = 0.1
        vgg16_bn = list(vgg16.features.children())
        # Slice the feature stack into five stages. The layers at indices 6,
        # 13, 23, 33 and the last one (VGG16-BN's max-pools) are skipped;
        # pooling is done in _forward_logits with return_indices=True instead.
        self.encoder1 = nn.Sequential(*vgg16_bn[0:6])
        self.encoder2 = nn.Sequential(*vgg16_bn[7:13])
        self.encoder3 = nn.Sequential(*vgg16_bn[14:23])
        self.encoder4 = nn.Sequential(*vgg16_bn[24:33])
        self.encoder5 = nn.Sequential(*vgg16_bn[34:-1])

        cbr = self._conv_bn_relu
        m = batchNorm_momentum
        self.decoder5 = nn.Sequential(*cbr(512, 512, m), *cbr(512, 512, m), *cbr(512, 512, m))
        self.decoder4 = nn.Sequential(*cbr(512, 512, m), *cbr(512, 512, m), *cbr(512, 256, m))
        self.decoder3 = nn.Sequential(*cbr(256, 256, m), *cbr(256, 256, m), *cbr(256, 128, m))
        self.decoder2 = nn.Sequential(*cbr(128, 128, m), *cbr(128, 64, m))
        self.decoder1 = nn.Sequential(
            *cbr(64, 64, m),
            *cbr(64, 64, m),
            nn.Conv2d(64, num_classes, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        )

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, momentum):
        """Return [3x3 conv, BatchNorm, in-place ReLU] as a list of layers."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(inplace=True),
        ]

    def _forward_logits(self, x):
        """Run the encoder/decoder and return pre-softmax logits (B, num_classes, H, W)."""
        encoders = (self.encoder1, self.encoder2, self.encoder3, self.encoder4, self.encoder5)
        decoders = (self.decoder5, self.decoder4, self.decoder3, self.decoder2, self.decoder1)

        pool_state = []  # (pool indices, pre-pool size) for each encoder stage
        for encoder in encoders:
            x = encoder(x)
            pre_pool_size = x.size()
            x, indices = F.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), return_indices=True)
            pool_state.append((indices, pre_pool_size))

        # Decode deepest-first, unpooling with the matching stage's indices.
        for decoder, (indices, pre_pool_size) in zip(decoders, reversed(pool_state)):
            x = decoder(F.max_unpool2d(x, indices, kernel_size=(2, 2), stride=(2, 2), output_size=pre_pool_size))
        return x

    def forward(self, x):
        """Return per-pixel class probabilities (softmax over the class dim)."""
        return F.softmax(self._forward_logits(x), dim=1)

    def training_step(self, batch, batch_idx):
        """Compute and log the masked cross-entropy loss (on logits) for one training batch."""
        # Lightning moves the batch onto this module's device; no .cuda() needed.
        x, y = batch["image"], batch["mask"].squeeze(1)
        logits = self._forward_logits(x)
        loss = self.cross_entropy_loss_mask_invalid(logits, y, weight=self.weight)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the masked cross-entropy loss (on logits) for one validation batch."""
        x, y = batch["image"], batch["mask"].squeeze(1)
        logits = self._forward_logits(x)
        loss = self.cross_entropy_loss_mask_invalid(logits, y, weight=self.weight)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        """Optimise all parameters with Adam (lr = 1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def cross_entropy_loss_mask_invalid(self, logits: torch.Tensor, target: torch.Tensor, weight=None) -> torch.Tensor:
        """Class-weighted cross-entropy averaged over valid pixels.

        Args:
            logits: (B, C, H, W) unnormalised class scores (NOT probabilities).
            target: (B, H, W) integer labels; 0 marks an invalid pixel and
                1..C are the classes.
            weight: optional length-C per-class weight tensor; moved to
                ``logits``' device before use.

        Returns:
            Scalar tensor: the sum of per-pixel losses over valid pixels
            divided by the number of valid pixels (+1e-6 to avoid 0/0).
        """
        assert logits.dim() == 4, "Expected logits to have 4 dimensions, but got {} dimensions.".format(logits.dim())
        assert target.dim() == 3, "Expected target to have 3 dimensions, but got {} dimensions.".format(target.dim())

        valid = target != 0
        # Shift labels 1..C down to 0..C-1; invalid pixels get placeholder 0
        # and are zeroed out after the per-pixel loss.
        target_without_invalids = ((target - 1) * valid).long()

        if weight is not None:
            weight = weight.to(logits.device)
        criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')
        per_pixel = criterion(logits, target_without_invalids) * valid  # (B, H, W)

        return per_pixel.sum() / (valid.sum() + 1e-6)

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.nn import functional as F
from torchvision.models import resnet18
from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer

class SegFormer(pl.LightningModule):
    """ResNet-18 encoder + Transformer bottleneck + transposed-conv decoder.

    This is a custom hybrid, not the MiT-based SegFormer from the literature.
    The deepest ResNet-18 feature map (512 channels) is flattened into one
    token per spatial position and passed through a 4-layer
    ``nn.TransformerEncoder``; no positional encoding is added. The result is
    reshaped back and upsampled by five stride-2 transposed convs (each
    doubles H and W), followed by a 1x1 segmentation head.

    Label convention used by the loss: 0 marks an invalid pixel that is
    ignored; 1..num_classes are the real classes.

    Args:
        in_channels: number of input channels (replaces ResNet's first conv,
            which is therefore freshly initialised).
        num_classes: number of output classes. ``self.weight`` has 3 entries,
            so any other value also requires replacing it.
    """

    def __init__(self, in_channels=13, num_classes=3):
        super(SegFormer, self).__init__()

        # Per-class cross-entropy weights (presumably derived from class
        # frequencies -- TODO confirm). Moved to the logits' device in the loss.
        self.weight = torch.Tensor([1.93445299, 36.60054169, 2.19400729])

        self.backbone = resnet18(pretrained=True)
        self.backbone.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        # Transformer encoder. Its default layout is (seq, batch, dim); forward
        # feeds it tokens in that layout.
        encoder_layers = TransformerEncoderLayer(d_model=512, nhead=8)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers=4)

        batchNorm_momentum = 0.1
        inplace = True
        # Decoder: each ConvTranspose2d(k=4, s=2, p=1) doubles H and W.
        self.decoder1 = nn.Sequential(
                     nn.ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
                     nn.BatchNorm2d(512, momentum=batchNorm_momentum),
                     nn.ReLU(inplace))
        self.decoder2 = nn.Sequential(
                     nn.ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
                     nn.BatchNorm2d(256, momentum=batchNorm_momentum),
                     nn.ReLU(inplace))
        self.decoder3 = nn.Sequential(
                     nn.ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
                     nn.BatchNorm2d(128, momentum=batchNorm_momentum),
                     nn.ReLU(inplace))
        self.decoder4 = nn.Sequential(
                     nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
                     nn.BatchNorm2d(64, momentum=batchNorm_momentum),
                     nn.ReLU(inplace))
        self.decoder5 = nn.Sequential(
                     nn.ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
                     nn.BatchNorm2d(32, momentum=batchNorm_momentum),
                     nn.ReLU(inplace))  # extra transposed-conv stage

        # Segmentation head.
        self.segmentation_head = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class logits (B, num_classes, H', W')."""
        # Backbone.
        x1 = self.backbone.conv1(x)
        x1 = self.backbone.bn1(x1)
        x1 = self.backbone.relu(x1)
        x2 = self.backbone.maxpool(x1)

        x2 = self.backbone.layer1(x2)
        x3 = self.backbone.layer2(x2)
        x4 = self.backbone.layer3(x3)
        x5 = self.backbone.layer4(x4)

        # Flatten spatial positions into tokens in the encoder's default
        # (seq, batch, dim) layout, so attention mixes spatial positions
        # within each image rather than mixing images across the batch.
        B, C, H, W = x5.shape
        tokens = x5.flatten(2).permute(2, 0, 1)  # [H*W, B, C]

        tokens = self.transformer_encoder(tokens)

        # Back to a feature map.
        x5 = tokens.permute(1, 2, 0).reshape(B, C, H, W)  # [B, C, H, W]

        # Decoder.
        x = self.decoder1(x5)
        x = self.decoder2(x)
        x = self.decoder3(x)
        x = self.decoder4(x)
        x = self.decoder5(x)

        # Segmentation head.
        return self.segmentation_head(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the masked cross-entropy loss for one training batch."""
        # Lightning moves the batch onto this module's device; no .cuda() needed.
        x, y = batch["image"], batch["mask"].squeeze(1)
        y_hat = self(x)
        loss = self.cross_entropy_loss_mask_invalid(y_hat, y, weight=self.weight)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the masked cross-entropy loss for one validation batch."""
        x, y = batch["image"], batch["mask"].squeeze(1)
        y_hat = self(x)
        loss = self.cross_entropy_loss_mask_invalid(y_hat, y, weight=self.weight)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        """Optimise all parameters with Adam (lr = 1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def cross_entropy_loss_mask_invalid(self, logits: torch.Tensor, target: torch.Tensor, weight=None) -> torch.Tensor:
        """Class-weighted cross-entropy averaged over valid pixels.

        Args:
            logits: (B, C, H, W) unnormalised class scores.
            target: (B, H, W) integer labels; 0 marks an invalid pixel and
                1..C are the classes.
            weight: optional length-C per-class weight tensor; moved to
                ``logits``' device before use.

        Returns:
            Scalar tensor: the sum of per-pixel losses over valid pixels
            divided by the number of valid pixels (+1e-6 to avoid 0/0).
        """
        assert logits.dim() == 4, "Expected logits to have 4 dimensions, but got {} dimensions.".format(logits.dim())
        assert target.dim() == 3, "Expected target to have 3 dimensions, but got {} dimensions.".format(target.dim())

        valid = target != 0
        # Shift labels 1..C down to 0..C-1; invalid pixels get placeholder 0
        # and are zeroed out after the per-pixel loss.
        target_without_invalids = ((target - 1) * valid).long()

        if weight is not None:
            weight = weight.to(logits.device)
        criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')
        per_pixel = criterion(logits, target_without_invalids) * valid  # (B, H, W)

        return per_pixel.sum() / (valid.sum() + 1e-6)

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchvision.models as models

class UNet(pl.LightningModule):
    """U-Net-style decoder over a pretrained ResNet-101 encoder.

    Skip connections are merged by element-wise addition (not
    concatenation). Each decoder stage therefore outputs exactly the channel
    count of the encoder feature it is added to.

    Label convention used by the loss: 0 marks an invalid pixel that is
    ignored; 1..num_classes are the real classes.

    Args:
        in_channels: number of input channels (replaces ResNet's first conv,
            which is therefore freshly initialised).
        num_classes: number of output classes. ``self.weight`` has 3 entries,
            so any other value also requires replacing it.
    """

    def __init__(self, in_channels=13, num_classes=3):
        super(UNet, self).__init__()

        # Per-class cross-entropy weights (presumably derived from class
        # frequencies -- TODO confirm). Moved to the logits' device in the loss.
        self.weight = torch.Tensor([1.93445299, 36.60054169, 2.19400729])

        self.resnet101 = models.resnet101(pretrained=True)
        self.resnet101.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        # Encoder. Comments give (C, H, W) for an input of spatial size
        # (H, W), assuming standard ResNet strides.
        self.enc1 = nn.Sequential(self.resnet101.conv1, self.resnet101.bn1, self.resnet101.relu, self.resnet101.maxpool)  # (64, H/4, W/4)
        self.enc2 = self.resnet101.layer1  # (256, H/4, W/4)
        self.enc3 = self.resnet101.layer2  # (512, H/8, W/8)
        self.enc4 = self.resnet101.layer3  # (1024, H/16, W/16)
        self.enc5 = self.resnet101.layer4  # (2048, H/32, W/32)

        # Decoder.
        self.dec5 = self.conv_block(2048, 1024)
        self.dec4 = self.conv_block(1024, 512)
        self.dec3 = self.conv_block(512, 256)
        self.dec2 = self.conv_block(256, 64)
        self.dec1 = nn.Conv2d(64, num_classes, 1)

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers mapping in_channels to out_channels."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return per-pixel class logits (B, num_classes, H, W)."""
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.enc5(enc4)

        # Each stage: decode, upsample x2, then add the same-resolution skip.
        dec5 = self.up(self.dec5(enc5))
        dec4 = self.up(self.dec4(dec5 + enc4))
        dec3 = self.up(self.dec3(dec4 + enc3))
        dec2 = self.up(self.dec2(dec3 + enc2))
        return self.dec1(self.up(dec2))

    def training_step(self, batch, batch_idx):
        """Compute and log the masked cross-entropy loss for one training batch."""
        # Lightning moves the batch onto this module's device; no .cuda() needed.
        x, y = batch["image"], batch["mask"].squeeze(1)
        y_hat = self(x)
        loss = self.cross_entropy_loss_mask_invalid(y_hat, y, weight=self.weight)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the masked cross-entropy loss for one validation batch."""
        x, y = batch["image"], batch["mask"].squeeze(1)
        y_hat = self(x)
        loss = self.cross_entropy_loss_mask_invalid(y_hat, y, weight=self.weight)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        """Optimise all parameters with Adam (lr = 1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def cross_entropy_loss_mask_invalid(self, logits: torch.Tensor, target: torch.Tensor, weight=None) -> torch.Tensor:
        """Class-weighted cross-entropy averaged over valid pixels.

        Args:
            logits: (B, C, H, W) unnormalised class scores.
            target: (B, H, W) integer labels; 0 marks an invalid pixel and
                1..C are the classes.
            weight: optional length-C per-class weight tensor; moved to
                ``logits``' device before use.

        Returns:
            Scalar tensor: the sum of per-pixel losses over valid pixels
            divided by the number of valid pixels (+1e-6 to avoid 0/0).
        """
        assert logits.dim() == 4, "Expected logits to have 4 dimensions, but got {} dimensions.".format(logits.dim())
        assert target.dim() == 3, "Expected target to have 3 dimensions, but got {} dimensions.".format(target.dim())

        valid = target != 0
        # Shift labels 1..C down to 0..C-1; invalid pixels get placeholder 0
        # and are zeroed out after the per-pixel loss.
        target_without_invalids = ((target - 1) * valid).long()

        if weight is not None:
            weight = weight.to(logits.device)
        criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')
        per_pixel = criterion(logits, target_without_invalids) * valid  # (B, H, W)

        return per_pixel.sum() / (valid.sum() + 1e-6)

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchvision.models as models


class AttentionBlock(nn.Module):
    """Additive attention gate.

    The gating signal ``g`` and the skip features ``x`` are each projected
    to ``F_int`` channels (1x1 conv + BatchNorm) and summed. The sum is then
    mapped to a single-channel gate by a 1x1 conv, BatchNorm and sigmoid.
    The output is ``x`` scaled element-wise by that gate. ``g`` and ``x``
    must share spatial size.

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the skip features ``x``.
        F_int: intermediate channel count.
    """

    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = self._project(F_g, F_int)
        self.W_x = self._project(F_l, F_int)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _project(in_channels, out_channels):
        """Return a 1x1 conv (with bias) followed by BatchNorm."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, g, x):
        gate = self.psi(self.W_g(g) + self.W_x(x))  # one channel, values in (0, 1)
        return x * gate


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Runs parallel branches over the same input and concatenates them along
    channels: one 1x1 conv, one 3x3 dilated conv per entry in ``rates``, and
    an image-level branch (global average pool + 1x1 conv, broadcast back to
    the input's spatial size). The output has
    ``out_channels * (len(rates) + 2)`` channels and the input's spatial size.
    No BatchNorm or activation is applied inside this module.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by each branch.
        rates: dilation rates for the 3x3 branches (any iterable of ints).
    """

    def __init__(self, in_channels, out_channels, rates=(12, 24, 36)):
        super(ASPP, self).__init__()
        self.aspp_blocks = nn.ModuleList()

        # 1x1 convolution branch.
        self.aspp_blocks.append(nn.Conv2d(in_channels, out_channels, 1, bias=False))

        # 3x3 dilated branches; padding == dilation keeps the spatial size.
        for rate in rates:
            self.aspp_blocks.append(
                nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False)
            )

        # Image-level features.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
        )

    def forward(self, x):
        features = [block(x) for block in self.aspp_blocks]
        pooled = self.global_avg_pool(x)
        # Resizing the 1x1 pooled map replicates its value over (H, W); use the
        # functional form instead of building an nn.Upsample on every call.
        pooled = nn.functional.interpolate(pooled, size=x.shape[2:], mode='bilinear', align_corners=True)
        features.append(pooled)
        return torch.cat(features, dim=1)


class SEBlock(nn.Module):
    """Squeeze-and-excitation channel reweighting.

    Global-average-pools each channel of a (B, C, H, W) input, passes the
    (B, C) descriptor through a bottleneck MLP (C -> C // reduction -> C,
    ReLU then sigmoid, no biases) and multiplies every channel of the input
    by its resulting factor in (0, 1).
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        layers = [
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        batch, channels, _height, _width = x.size()
        descriptor = self.avg_pool(x).view(batch, channels)
        scale = self.fc(descriptor).view(batch, channels, 1, 1)
        return x * scale.expand_as(x)


class UNetPlus(pl.LightningModule):
    """U-Net-style segmentation model with a ResNet-101 encoder, an ASPP
    bottleneck and attention-gated skip connections.

    The ResNet stem conv is replaced so the encoder accepts ``in_channels``
    input bands (this new conv is randomly initialised; the rest of the
    encoder is loaded with ``pretrained=True``). The decoder upsamples x2 per
    stage, fuses each encoder feature with an attention-gated decoder feature,
    and finishes with a second ASPP head and a 1x1 classifier.

    Args:
        in_channels: number of input bands.
        num_classes: number of output logit channels.
        learning_rate: Adam learning rate used by ``configure_optimizers``.
    """

    def __init__(self, in_channels=13, num_classes=3, learning_rate=1e-3):
        super(UNetPlus, self).__init__()

        # Per-class loss weights; weight[i] applies to label i + 1 (label 0 is
        # the ignored value, see cross_entropy_loss_mask_invalid). Kept as a
        # plain tensor so checkpoints are unchanged; the loss moves it to the
        # logits' device.
        self.weight = torch.Tensor([1.93445299, 36.60054169, 2.19400729])
        self.learning_rate = learning_rate

        self.resnet101 = models.resnet101(pretrained=True)
        # Replace the stem so the encoder accepts `in_channels` bands.
        self.resnet101.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Encoder
        self.enc1 = nn.Sequential(self.resnet101.conv1, self.resnet101.bn1, self.resnet101.relu, self.resnet101.maxpool)
        self.enc2 = self.resnet101.layer1
        # NOTE: the SE blocks are registered (and stored in checkpoints) but are
        # not used by forward(), so they receive no gradients.
        self.sbblock1 = SEBlock(256)
        self.enc3 = self.resnet101.layer2
        self.sbblock2 = SEBlock(512)
        self.enc4 = self.resnet101.layer3
        self.sbblock3 = SEBlock(1024)
        self.enc5 = self.resnet101.layer4

        # Bottleneck: ASPP gives 5 x 256 = 1280 channels, projected to 2048.
        self.aspp1 = ASPP(2048, 256)
        self.project1 = nn.Sequential(nn.Conv2d(1280, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False),
                      nn.BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                      nn.ReLU(),
                      nn.Dropout(p=0.5, inplace=False))

        # Decoder
        self.attention1 = AttentionBlock(2048, 2048, 2048)
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec1 = self.conv_block(4096, 1024)
        self.attention2 = AttentionBlock(1024, 1024, 1024)
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec2 = self.conv_block(2048, 512)
        self.attention3 = AttentionBlock(512, 512, 512)
        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec3 = self.conv_block(1024, 256)
        self.attention4 = AttentionBlock(256, 256, 256)
        self.up4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec4 = self.conv_block(512, 256)
        self.up5 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Head: ASPP gives 5 x 32 = 160 channels, projected to 128.
        self.aspp2 = ASPP(256, 32)
        self.project2 = nn.Sequential(nn.Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False),
                      nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                      nn.ReLU(),
                      nn.Dropout(p=0.5, inplace=False))

        self.classfication = nn.Conv2d(128, num_classes, kernel_size=(1, 1), stride=(1, 1))

    def conv_block(self, in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return per-pixel class logits of shape (B, num_classes, H, W).

        NOTE(review): the x2 upsampling chain only lines up with the encoder
        features when H and W are divisible by 32 -- confirm the data pipeline
        guarantees this.
        """
        # Encoder (channel counts match the SEBlock/ASPP sizes declared above).
        enc1 = self.enc1(x)     # stem + maxpool, 64 ch
        enc2 = self.enc2(enc1)  # layer1, 256 ch
        enc3 = self.enc3(enc2)  # layer2, 512 ch
        enc4 = self.enc4(enc3)  # layer3, 1024 ch
        enc5 = self.enc5(enc4)  # layer4, 2048 ch

        project1 = self.project1(self.aspp1(enc5))

        # Each stage: gate the decoder feature using the same-resolution encoder
        # feature, concatenate with that encoder feature, then convolve.
        attention1 = self.attention1(enc5, project1)
        dec1 = self.dec1(torch.cat((attention1, enc5), dim=1))

        up2 = self.up1(dec1)
        attention2 = self.attention2(enc4, up2)
        dec2 = self.dec2(torch.cat((attention2, enc4), dim=1))

        up3 = self.up2(dec2)
        attention3 = self.attention3(enc3, up3)
        dec3 = self.dec3(torch.cat((attention3, enc3), dim=1))

        up4 = self.up3(dec3)
        attention4 = self.attention4(enc2, up4)
        dec4 = self.dec4(torch.cat((attention4, enc2), dim=1))

        # Two more x2 upsamples around the second ASPP head restore full size.
        up5 = self.up4(dec4)
        aspp2 = self.aspp2(up5)
        up1 = self.up5(aspp2)
        project2 = self.project2(up1)

        return self.classfication(project2)

    def _unpack_batch(self, batch):
        """Return (image, mask) from a batch dict, on this module's device.

        ``mask`` has its dimension 1 squeezed -- presumably a (B, 1, H, W)
        label map becoming (B, H, W); TODO confirm against the dataset.
        """
        images = batch["image"].to(self.device)
        masks = batch["mask"].squeeze(1).to(self.device)
        return images, masks

    def training_step(self, batch, batch_idx):
        x, y = self._unpack_batch(batch)
        y_hat = self(x)
        loss = self.cross_entropy_loss_mask_invalid(y_hat, y, weight=self.weight)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = self._unpack_batch(batch)
        y_hat = self(x)
        loss = self.cross_entropy_loss_mask_invalid(y_hat, y, weight=self.weight)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        """Adam over all parameters with the constructor's ``learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def cross_entropy_loss_mask_invalid(self, logits: torch.Tensor, target: torch.Tensor, weight=None) -> torch.Tensor:
        """Class-weighted cross-entropy that ignores pixels labelled 0.

        Label 0 marks invalid pixels; labels 1..C are shifted to class indices
        0..C-1. Per-pixel losses of valid pixels are summed and divided by the
        number of valid pixels (plus 1e-6, so an all-invalid batch yields 0).
        All tensors are moved to ``logits.device``.

        Args:
            logits: Raw class scores of shape (B, C, H, W).
            target: Integer labels of shape (B, H, W); 0 means invalid.
            weight: Optional per-class weights of length C; ``None`` means
                unweighted.

        Returns:
            A scalar tensor on ``logits.device``.
        """
        assert logits.dim() == 4, "Expected logits to have 4 dimensions, but got {} dimensions.".format(logits.dim())
        assert target.dim() == 3, "Expected target to have 3 dimensions, but got {} dimensions.".format(target.dim())

        device = logits.device
        target = target.to(device)
        valid = target != 0
        # Invalid pixels get placeholder index 0; their loss is masked below.
        # CrossEntropyLoss needs int64 class indices, hence .long().
        target_without_invalids = (target.long() - 1) * valid

        if weight is not None:
            weight = weight.to(device=device, dtype=logits.dtype)
        # reduction='none' keeps a per-pixel (B, H, W) loss so it can be masked.
        criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')
        per_pixel = criterion(logits, target_without_invalids)
        per_pixel = per_pixel * valid  # mask out invalid pixels

        # Mean over valid pixels; class weights are not in the denominator.
        return torch.sum(per_pixel / (torch.sum(valid) + 1e-6))



In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchvision.models as models


class AttentionBlock(nn.Module):
    """Additive attention gate that rescales ``x`` by a spatial mask.

    ``g`` (gating signal) and ``x`` (features to gate) are each projected to
    ``F_int`` channels with a 1x1 conv + BatchNorm and summed element-wise
    (so their spatial sizes must match). The sum is reduced to one channel by
    a 1x1 conv + BatchNorm + sigmoid, and that (B, 1, H, W) mask multiplies ``x``.

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the gated features ``x``.
        F_int: channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()

        def _conv_bn(in_ch, out_ch):
            # 1x1 conv (with bias) followed by BatchNorm.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_ch),
            )

        self.W_g = _conv_bn(F_g, F_int)
        self.W_x = _conv_bn(F_l, F_int)
        # Same submodule indices as before (psi.0 conv, psi.1 BN, psi.2 sigmoid)
        # so existing checkpoints still load.
        self.psi = nn.Sequential(*_conv_bn(F_int, 1), nn.Sigmoid())

    def forward(self, g, x):
        # NOTE(review): the Attention U-Net gate (Oktay et al., 2018) applies a
        # ReLU to this sum before psi; none is applied here -- confirm intent.
        gate = self.psi(self.W_g(g) + self.W_x(x))
        return x * gate


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Runs parallel branches over the same input and concatenates them along
    channels: one 1x1 conv, one 3x3 dilated conv per entry in ``rates``, and
    an image-level branch (global average pool + 1x1 conv, broadcast back to
    the input's spatial size). The output has
    ``out_channels * (len(rates) + 2)`` channels and the input's spatial size.
    No BatchNorm or activation is applied inside this module.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by each branch.
        rates: dilation rates for the 3x3 branches (any iterable of ints).
    """

    def __init__(self, in_channels, out_channels, rates=(12, 24, 36)):
        super(ASPP, self).__init__()
        self.aspp_blocks = nn.ModuleList()

        # 1x1 convolution branch.
        self.aspp_blocks.append(nn.Conv2d(in_channels, out_channels, 1, bias=False))

        # 3x3 dilated branches; padding == dilation keeps the spatial size.
        for rate in rates:
            self.aspp_blocks.append(
                nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False)
            )

        # Image-level features.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
        )

    def forward(self, x):
        features = [block(x) for block in self.aspp_blocks]
        pooled = self.global_avg_pool(x)
        # Resizing the 1x1 pooled map replicates its value over (H, W); use the
        # functional form instead of building an nn.Upsample on every call.
        pooled = nn.functional.interpolate(pooled, size=x.shape[2:], mode='bilinear', align_corners=True)
        features.append(pooled)
        return torch.cat(features, dim=1)


class SEBlock(nn.Module):
    """Squeeze-and-excitation channel reweighting.

    Global-average-pools each channel of a (B, C, H, W) input, passes the
    (B, C) descriptor through a bottleneck MLP (C -> C // reduction -> C,
    ReLU then sigmoid, no biases) and multiplies every channel of the input
    by its resulting factor in (0, 1).
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        layers = [
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        batch, channels, _height, _width = x.size()
        descriptor = self.avg_pool(x).view(batch, channels)
        scale = self.fc(descriptor).view(batch, channels, 1, 1)
        return x * scale.expand_as(x)


class PSPModule(nn.Module):
    """Pyramid pooling module.

    For every bin size ``s`` the input is adaptive-average-pooled to
    ``s x s``, passed through a bias-free 1x1 conv and bilinearly resized back
    to the input's spatial size. The input and all branches are concatenated
    and fused back to ``in_channels`` by a 1x1 conv followed by ReLU.

    Args:
        in_channels: channels of the input (and of the output).
        bin_sizes: pooled output sizes, one branch per entry.
    """

    def __init__(self, in_channels, bin_sizes):
        super(PSPModule, self).__init__()
        self.stages = nn.ModuleList(self._make_stage(in_channels, size) for size in bin_sizes)
        fused_channels = in_channels * (len(bin_sizes) + 1)
        self.bottleneck = nn.Conv2d(fused_channels, in_channels, kernel_size=1)
        self.relu = nn.ReLU()

    def _make_stage(self, in_channels, size):
        """Build one pyramid branch: pool to (size, size), then a 1x1 conv."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(size, size)),
            nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False),
        )

    def forward(self, feats):
        target_size = (feats.size(2), feats.size(3))
        branches = [feats]
        for stage in self.stages:
            pooled = stage(feats)
            branches.append(
                nn.functional.interpolate(pooled, size=target_size, mode='bilinear', align_corners=True)
            )
        fused = torch.cat(branches, dim=1)
        return self.relu(self.bottleneck(fused))


class PSPDeep(pl.LightningModule):
    """Segmentation model: ResNet-101 encoder, parallel ASPP and pyramid-pooling
    heads on the deepest features, and a skip-free x2-upsampling decoder.

    The ResNet stem conv is replaced so the encoder accepts ``in_channels``
    input bands (this new conv is randomly initialised; the rest of the
    encoder is loaded with ``pretrained=True``).

    Args:
        in_channels: number of input bands.
        num_classes: number of output logit channels.
        learning_rate: Adam learning rate used by ``configure_optimizers``.
    """

    def __init__(self, in_channels=13, num_classes=3, learning_rate=1e-3):
        super(PSPDeep, self).__init__()

        # Per-class loss weights; weight[i] applies to label i + 1 (label 0 is
        # the ignored value, see cross_entropy_loss_mask_invalid). Kept as a
        # plain tensor so checkpoints are unchanged; the loss moves it to the
        # logits' device.
        self.weight = torch.Tensor([1.93445299, 36.60054169, 2.19400729])
        self.learning_rate = learning_rate

        self.resnet101 = models.resnet101(pretrained=True)
        # Replace the stem so the encoder accepts `in_channels` bands.
        self.resnet101.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Encoder
        self.enc1 = nn.Sequential(self.resnet101.conv1, self.resnet101.bn1, self.resnet101.relu, self.resnet101.maxpool)
        self.enc2 = self.resnet101.layer1
        self.enc3 = self.resnet101.layer2
        self.enc4 = self.resnet101.layer3
        self.enc5 = self.resnet101.layer4

        # ASPP branch: 5 x 256 = 1280 channels, projected to 1024.
        self.aspp1 = ASPP(2048, 256)
        self.project1 = nn.Sequential(nn.Conv2d(1280, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False),
                      nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                      nn.ReLU(),
                      nn.Dropout(p=0.5, inplace=False))

        # Pyramid-pooling branch: 2048 channels, projected to 1024.
        # NOTE(review): bins up to 25 may exceed the deepest feature map's size;
        # AdaptiveAvgPool2d then uses overlapping windows -- confirm intended.
        self.psp = PSPModule(2048, [1, 4, 9, 16, 25])
        self.project2 = nn.Sequential(nn.Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False),
                      nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                      nn.ReLU(),
                      nn.Dropout(p=0.5, inplace=False))

        # Decoder: one parameter-free x2 upsample reused at every stage.
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec5 = self.conv_block(2048, 1024)
        self.dec4 = self.conv_block(1024, 512)
        self.dec3 = self.conv_block(512, 256)
        self.dec2 = self.conv_block(256, 64)
        self.classfication = nn.Conv2d(64, num_classes, kernel_size=(1, 1), stride=(1, 1))

    def conv_block(self, in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return per-pixel class logits (five x2 upsamples after the encoder).

        NOTE(review): the output matches the input size only if the encoder
        downsamples by exactly 32 -- confirm H and W are divisible by 32.
        """
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.enc5(enc4)  # only the deepest feature is used below

        project1 = self.project1(self.aspp1(enc5))
        project2 = self.project2(self.psp(enc5))

        # 1024 + 1024 = 2048 channels into the decoder.
        dec5 = self.dec5(self.up1(torch.cat((project1, project2), dim=1)))
        dec4 = self.dec4(self.up1(dec5))
        dec3 = self.dec3(self.up1(dec4))
        dec2 = self.dec2(self.up1(dec3))

        return self.classfication(self.up1(dec2))

    def _unpack_batch(self, batch):
        """Return (image, mask) from a batch dict, on this module's device.

        ``mask`` has its dimension 1 squeezed -- presumably a (B, 1, H, W)
        label map becoming (B, H, W); TODO confirm against the dataset.
        """
        images = batch["image"].to(self.device)
        masks = batch["mask"].squeeze(1).to(self.device)
        return images, masks

    def training_step(self, batch, batch_idx):
        x, y = self._unpack_batch(batch)
        y_hat = self(x)
        loss = self.cross_entropy_loss_mask_invalid(y_hat, y, weight=self.weight)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = self._unpack_batch(batch)
        y_hat = self(x)
        loss = self.cross_entropy_loss_mask_invalid(y_hat, y, weight=self.weight)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        """Adam over all parameters with the constructor's ``learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def cross_entropy_loss_mask_invalid(self, logits: torch.Tensor, target: torch.Tensor, weight=None) -> torch.Tensor:
        """Class-weighted cross-entropy that ignores pixels labelled 0.

        Label 0 marks invalid pixels; labels 1..C are shifted to class indices
        0..C-1. Per-pixel losses of valid pixels are summed and divided by the
        number of valid pixels (plus 1e-6, so an all-invalid batch yields 0).
        All tensors are moved to ``logits.device``.

        Args:
            logits: Raw class scores of shape (B, C, H, W).
            target: Integer labels of shape (B, H, W); 0 means invalid.
            weight: Optional per-class weights of length C; ``None`` means
                unweighted.

        Returns:
            A scalar tensor on ``logits.device``.
        """
        assert logits.dim() == 4, "Expected logits to have 4 dimensions, but got {} dimensions.".format(logits.dim())
        assert target.dim() == 3, "Expected target to have 3 dimensions, but got {} dimensions.".format(target.dim())

        device = logits.device
        target = target.to(device)
        valid = target != 0
        # Invalid pixels get placeholder index 0; their loss is masked below.
        # CrossEntropyLoss needs int64 class indices, hence .long().
        target_without_invalids = (target.long() - 1) * valid

        if weight is not None:
            weight = weight.to(device=device, dtype=logits.dtype)
        # reduction='none' keeps a per-pixel (B, H, W) loss so it can be masked.
        criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')
        per_pixel = criterion(logits, target_without_invalids)
        per_pixel = per_pixel * valid  # mask out invalid pixels

        # Mean over valid pixels; class weights are not in the denominator.
        return torch.sum(per_pixel / (torch.sum(valid) + 1e-6))

