In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, random_split
from torch.cuda.amp import GradScaler, autocast
from torchvision.utils import save_image
from PIL import Image
import os
from PIL import Image
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm


In [2]:
class DualTaskResNet(nn.Module):
    """ResNet-18 encoder with a dilated-conv / transposed-conv decoder and two heads.

    The encoder is torchvision's ``resnet18`` with its last two child modules
    removed.  Three parallel dilated convolutions (dilation 2, 4, 8) are applied
    to each of the three deepest encoder stages; their outputs are concatenated
    and merged into a five-step transposed-convolution decoder.  A shared 1x1
    convolution produces ``num_classes`` channels, after which a task-specific
    1x1 convolution ('mask' or 'skeleton') and a sigmoid give the final output.

    Shape comments in the code assume a 224x224 input, as in the original
    notebook.  The skip concatenations require the encoder and decoder spatial
    sizes to line up, which holds when the input height and width are multiples
    of 32.

    NOTE: the backbone modules are registered both under ``self.resnet18`` and
    under ``self.layer0`` .. ``self.layer4`` (the same module objects, not
    copies), so their tensors appear under both prefixes in ``state_dict()``.
    Both registrations are kept so that checkpoint key names stay unchanged.

    Args:
        num_classes: Number of output channels of each task head.
        pretrained: If true, load the ``IMAGENET1K_V1`` ResNet-18 weights;
            otherwise the backbone is randomly initialised.
    """

    def __init__(self, num_classes=1, pretrained=True):
        super().__init__()

        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        resnet = models.resnet18(weights=weights)

        # Backbone without its final two children; indices below refer to this Sequential.
        self.resnet18 = nn.Sequential(*list(resnet.children())[:-2])
        backbone_children = list(self.resnet18.children())
        self.layer0 = nn.Sequential(*backbone_children[:3])    # children 0-2
        self.layer1 = nn.Sequential(*backbone_children[3:5])   # children 3-4
        self.layer2 = self.resnet18[5]
        self.layer3 = self.resnet18[6]
        self.layer4 = self.resnet18[7]

        # Dilated convolutions for layer2 (128 -> 256 channels each)
        self.dilation_conv1_l2 = self._make_dilated_conv(128, 256, 2)
        self.dilation_conv2_l2 = self._make_dilated_conv(128, 256, 4)
        self.dilation_conv3_l2 = self._make_dilated_conv(128, 256, 8)

        # Dilated convolutions for layer3 (256 -> 512 channels each)
        self.dilation_conv1_l3 = self._make_dilated_conv(256, 512, 2)
        self.dilation_conv2_l3 = self._make_dilated_conv(256, 512, 4)
        self.dilation_conv3_l3 = self._make_dilated_conv(256, 512, 8)

        # Dilated convolutions for layer4 (512 -> 1024 channels each)
        self.dilation_conv1_l4 = self._make_dilated_conv(512, 1024, 2)
        self.dilation_conv2_l4 = self._make_dilated_conv(512, 1024, 4)
        self.dilation_conv3_l4 = self._make_dilated_conv(512, 1024, 8)

        # Upsampling path; input widths account for the concatenated skip features.
        self.upsample1 = self._make_transpose_conv(3072, 512, 2)  # 7x7 -> 14x14
        self.upsample2 = self._make_transpose_conv(2048, 512, 2)  # 14x14 -> 28x28
        self.upsample3 = self._make_transpose_conv(1280, 256, 2)  # 28x28 -> 56x56
        self.upsample4 = self._make_transpose_conv(256, 128, 2)   # 56x56 -> 112x112
        self.upsample5 = self._make_transpose_conv(128, 64, 2)    # 112x112 -> 224x224
        self.convf = nn.Conv2d(64, num_classes, kernel_size=1)

        # Task-specific output layers
        self.mask_output = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.skeleton_output = nn.Conv2d(num_classes, num_classes, kernel_size=1)

    def _make_dilated_conv(self, in_channels, out_channels, dilation):
        """Build a 3x3 dilated conv -> BatchNorm -> ReLU block.

        ``padding`` equals ``dilation``, so the spatial size is preserved.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilation, dilation=dilation),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def _make_transpose_conv(self, in_channels, out_channels, scale_factor):
        """Build a transposed conv (kernel 2, stride ``scale_factor``) -> BatchNorm -> ReLU block.

        The output spatial size is ``(in - 1) * scale_factor + 2``, i.e. exactly
        double the input when ``scale_factor`` is 2.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=scale_factor, padding=0, output_padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, img, task='mask'):
        """Run the network and return per-pixel probabilities for one task.

        Args:
            img: Input batch in channels-first layout; the shape comments below
                assume ``(N, 3, 224, 224)``.
            task: ``'mask'`` or ``'skeleton'``; selects the output head.

        Returns:
            Sigmoid-activated tensor with ``num_classes`` channels and the
            input's spatial size (for inputs whose sides are multiples of 32).

        Raises:
            ValueError: If ``task`` is neither ``'mask'`` nor ``'skeleton'``.
        """
        # Resolve the head before doing any work so that an invalid task fails
        # immediately instead of after a full encoder/decoder pass.
        if task == 'mask':
            head = self.mask_output
        elif task == 'skeleton':
            head = self.skeleton_output
        else:
            raise ValueError("Task must be either 'mask' or 'skeleton'")

        # Encoder (sizes for a 224x224x3 input)
        layer0 = self.layer0(img)     # 112x112x64
        layer1 = self.layer1(layer0)  # 56x56x64
        layer2 = self.layer2(layer1)  # 28x28x128
        layer3 = self.layer3(layer2)  # 14x14x256
        layer4 = self.layer4(layer3)  # 7x7x512

        # Multi-dilation context on layer2 -> 28x28x768
        y = torch.cat(
            [
                self.dilation_conv1_l2(layer2),
                self.dilation_conv2_l2(layer2),
                self.dilation_conv3_l2(layer2),
            ],
            dim=1,
        )

        # Multi-dilation context on layer3 -> 14x14x1536
        z = torch.cat(
            [
                self.dilation_conv1_l3(layer3),
                self.dilation_conv2_l3(layer3),
                self.dilation_conv3_l3(layer3),
            ],
            dim=1,
        )

        # Multi-dilation context on layer4 -> 7x7x3072
        w = torch.cat(
            [
                self.dilation_conv1_l4(layer4),
                self.dilation_conv2_l4(layer4),
                self.dilation_conv3_l4(layer4),
            ],
            dim=1,
        )

        # Decoder with skip concatenations
        x = self.upsample1(w)         # 14x14x512
        x = torch.cat([x, z], dim=1)  # 14x14x2048
        x = self.upsample2(x)         # 28x28x512
        x = torch.cat([x, y], dim=1)  # 28x28x1280
        x = self.upsample3(x)         # 56x56x256
        x = self.upsample4(x)         # 112x112x128
        x = self.upsample5(x)         # 224x224x64
        x = self.convf(x)             # 224x224xnum_classes

        return torch.sigmoid(head(x))

In [3]:
# Test function
def test_dualtask_resnet():
    """Smoke-test ``DualTaskResNet`` on a random batch for both tasks.

    Builds an un-pretrained model on the available device, prints a layer
    summary, runs the 'mask' and 'skeleton' heads on the same random input,
    checks the output shapes, and prints the output ranges and whether the two
    heads produce different results.

    Raises:
        AssertionError: If either head returns an unexpected output shape.
    """
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create model instance.  Inference mode: BatchNorm uses its running
    # statistics instead of being updated by the random test batch.
    model = DualTaskResNet(num_classes=1, pretrained=False).to(device)
    model.eval()

    # Print model summary
    summary(model, (3, 224, 224))

    # Generate random input
    batch_size = 4
    input_tensor = torch.randn(batch_size, 3, 224, 224, device=device)
    expected_shape = (batch_size, 1, 224, 224)

    # No gradients are needed for a shape/range check; skip building the graph.
    with torch.no_grad():
        # Test mask generation
        print("\nTesting mask generation:")
        mask_output = model(input_tensor, task='mask')
        print(f"Mask output shape: {mask_output.shape}")
        print(f"Mask output min: {mask_output.min().item():.4f}, max: {mask_output.max().item():.4f}")
        assert tuple(mask_output.shape) == expected_shape, (
            f"unexpected mask shape {tuple(mask_output.shape)}, expected {expected_shape}"
        )

        # Test skeletonization
        print("\nTesting skeletonization:")
        skeleton_output = model(input_tensor, task='skeleton')
        print(f"Skeleton output shape: {skeleton_output.shape}")
        print(f"Skeleton output min: {skeleton_output.min().item():.4f}, max: {skeleton_output.max().item():.4f}")
        assert tuple(skeleton_output.shape) == expected_shape, (
            f"unexpected skeleton shape {tuple(skeleton_output.shape)}, expected {expected_shape}"
        )

    # Test if outputs are different
    print("\nChecking if mask and skeleton outputs are different:")
    is_different = not torch.allclose(mask_output, skeleton_output)
    print(f"Outputs are different: {is_different}")



In [4]:
# Run the smoke test; the summary table and printed checks that follow are its output.
test_dualtask_resnet()

Using device: cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
            Conv2d-2         [-1, 64, 112, 112]           9,408
       BatchNorm2d-3         [-1, 64, 112, 112]             128
       BatchNorm2d-4         [-1, 64, 112, 112]             128
              ReLU-5         [-1, 64, 112, 112]               0
              ReLU-6         [-1, 64, 112, 112]               0
         MaxPool2d-7           [-1, 64, 56, 56]               0
         MaxPool2d-8           [-1, 64, 56, 56]               0
            Conv2d-9           [-1, 64, 56, 56]          36,864
           Conv2d-10           [-1, 64, 56, 56]          36,864
      BatchNorm2d-11           [-1, 64, 56, 56]             128
      BatchNorm2d-12           [-1, 64, 56, 56]             128
             ReLU-13           [-1, 64, 56, 56]               0
             ReLU-14