In [31]:
import torch
from torch import nn
from torch.nn import functional as F
class DeepLabV3(nn.Module):
    """DeepLabV3-style encoder: a VGG-like backbone followed by an ASPP head.

    Backbone:
        ``conv1``-``conv3`` each end in a 2x2 / stride-2 max-pool
        (e.g. 256x256 -> 32x32). ``conv4`` ends in a 3x3 / stride-1 / padding-1
        max-pool that keeps the spatial size. ``conv5`` uses dilation-2 3x3
        convolutions (spatial size kept).

    ASPP head (applied to the ``conv5`` output):
        A 1x1 branch, three dilated 3x3 branches and a global-average-pooling
        branch run in parallel. They are concatenated along channels and fused
        by a 1x1 conv (``aspp_final``). ``embedding_layer`` then projects the
        result to ``embedding_dim`` channels.

    Every argument defaults to the value the original hard-coded architecture
    used, so ``DeepLabV3()`` builds identical layers: same attribute names,
    same ``state_dict`` keys and shapes, and the same parameter-creation order.

    Args:
        in_channels: channels of the input image (1 = grayscale).
        atrous_rates: dilation rates of the three 3x3 ASPP branches; must
            contain exactly three values.
        aspp_channels: output channels of every ASPP branch and of ``aspp_final``.
        embedding_dim: output channels of ``embedding_layer``.
        dropout: ``Dropout2d`` probability used throughout the ASPP head.

    Raises:
        ValueError: if ``atrous_rates`` does not contain exactly three values.

    NOTE: the pooling branch applies BatchNorm2d to a 1x1 map. In training
    mode it therefore needs a batch of at least 2 samples, otherwise PyTorch
    raises an error.
    """

    def __init__(self, in_channels=1, atrous_rates=(6, 12, 18),
                 aspp_channels=512, embedding_dim=256, dropout=0.5):
        super().__init__()

        if len(atrous_rates) != 3:
            raise ValueError(
                f"atrous_rates must contain exactly 3 values, got {atrous_rates!r}"
            )
        rate_a, rate_b, rate_c = atrous_rates

        # --- Backbone -------------------------------------------------------
        # Sequential indices match the original explicit layer lists
        # (conv at even indices, ReLU at odd ones, pool last).
        self.conv1 = nn.Sequential(
            *self._conv_relu_stack(in_channels, 32, n_convs=2),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),  # H x W -> H/2 x W/2
        )
        self.conv2 = nn.Sequential(
            *self._conv_relu_stack(32, 64, n_convs=2),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),  # -> H/4 x W/4
        )
        self.conv3 = nn.Sequential(
            *self._conv_relu_stack(64, 128, n_convs=3),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),  # -> H/8 x W/8
        )
        self.conv4 = nn.Sequential(
            *self._conv_relu_stack(128, 256, n_convs=3),
            # 3x3 window with stride 1 and padding 1: spatial size is unchanged.
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
        )
        self.conv5 = nn.Sequential(
            *self._conv_relu_stack(256, 256, n_convs=3, dilation=2),
        )

        # --- ASPP head (all branches read the 256-channel conv5 output) -----
        self.aspp_conv1 = self._aspp_branch(256, aspp_channels, None, dropout)
        self.aspp_conv2 = self._aspp_branch(256, aspp_channels, rate_a, dropout)
        self.aspp_conv3 = self._aspp_branch(256, aspp_channels, rate_b, dropout)
        self.aspp_conv4 = self._aspp_branch(256, aspp_channels, rate_c, dropout)

        # Image-level branch: global average pool to 1x1, then 1x1 conv/BN/ReLU.
        self.aspp_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(256, aspp_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(aspp_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
        )

        # Fuse the five concatenated branches back down to aspp_channels.
        self.aspp_final = nn.Sequential(
            nn.Conv2d(aspp_channels * 5, aspp_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(aspp_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
        )

        self.embedding_layer = nn.Conv2d(
            in_channels=aspp_channels, out_channels=embedding_dim, kernel_size=1
        )

    @staticmethod
    def _conv_relu_stack(in_ch, out_ch, n_convs, dilation=1):
        """Return ``n_convs`` (3x3 Conv2d, in-place ReLU) pairs as a flat list.

        The first conv maps ``in_ch -> out_ch`` and the rest ``out_ch -> out_ch``.
        ``padding=dilation`` keeps the spatial size for a 3x3 kernel.
        """
        layers = []
        for idx in range(n_convs):
            layers.append(
                nn.Conv2d(
                    in_channels=in_ch if idx == 0 else out_ch,
                    out_channels=out_ch,
                    kernel_size=3,
                    padding=dilation,
                    dilation=dilation,
                )
            )
            layers.append(nn.ReLU(inplace=True))
        return layers

    @staticmethod
    def _aspp_branch(in_ch, out_ch, dilation, dropout):
        """Build one ASPP branch: bias-free conv -> BatchNorm2d -> ReLU -> Dropout2d.

        ``dilation=None`` gives a 1x1 conv. Otherwise the branch uses a 3x3 conv
        with that dilation and matching padding, which preserves spatial size.
        """
        if dilation is None:
            conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        else:
            conv = nn.Conv2d(
                in_ch, out_ch, kernel_size=3, padding=dilation, dilation=dilation, bias=False
            )
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
        )

    def forward_branch(self, x):
        """Run the backbone and ASPP head on ``x``.

        Args:
            x: a 4-D (N, in_channels, H, W) batch. BatchNorm2d in the ASPP head
                requires 4-D input.

        Returns:
            Tuple ``(conv4_feature, conv5_feature, embedding_feature)``. All
            three share the conv3 output's spatial size (H/8 x W/8 for H, W
            divisible by 8). Channel counts are 256, 256 and ``embedding_dim``.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        conv3_feature = self.conv3(x)
        conv4_feature = self.conv4(conv3_feature)
        conv5_feature = self.conv5(conv4_feature)

        # Parallel ASPP branches on the same conv5 features.
        aspp1 = self.aspp_conv1(conv5_feature)
        aspp2 = self.aspp_conv2(conv5_feature)
        aspp3 = self.aspp_conv3(conv5_feature)
        aspp4 = self.aspp_conv4(conv5_feature)

        # Image-level features: 1x1 map broadcast back to conv5's spatial size.
        pool = self.aspp_pool(conv5_feature)
        pool = F.interpolate(pool, size=conv5_feature.shape[2:], mode='bilinear', align_corners=False)

        aspp_out = torch.cat([aspp1, aspp2, aspp3, aspp4, pool], dim=1)
        aspp_out = self.aspp_final(aspp_out)

        embedding_feature = self.embedding_layer(aspp_out)

        return conv4_feature, conv5_feature, embedding_feature

    def normalize(self, x, scale=1.0, dim=1):
        """L2-normalise ``x`` along ``dim`` and multiply by ``scale``.

        Squared norms are clamped to at least 1e-12 before ``rsqrt``. Zero
        vectors therefore map to zero instead of NaN/inf.
        """
        norm = x.pow(2).sum(dim=dim, keepdim=True).clamp(min=1e-12).rsqrt()
        return scale * x * norm

    def forward(self, x1, mode='train'):
        """Forward pass; identical to ``forward_branch(x1)``.

        ``mode`` is accepted for call-site compatibility but is not used.
        """
        return self.forward_branch(x1)

In [32]:
!which pip

/home/adil/Documents/TUE/ThesisPrepPhase/myProject/.conda/bin/pip


In [33]:
# `!which python` reports the first `python` on the shell's PATH; the kernel
# may be running a different interpreter. `sys.executable` is the path of the
# interpreter this kernel is actually running.
import sys
sys.executable

/home/adil/Documents/TUE/ThesisPrepPhase/myProject/.conda/bin/python


In [36]:
from torchview import draw_graph
import torch

# Fresh, untrained DeepLabV3. Despite the "siamese" file name, the model is
# single-input: forward() takes one tensor and runs a single branch.
model = DeepLabV3()

# Dummy tracing batch: 16 single-channel (grayscale) 256x256 images.
x = torch.randn((16, 1, 256, 256))

# Shallow architecture view (depth=1) with nested modules expanded.
model_graph = draw_graph(
    model,
    input_data=x,
    depth=1,
    expand_nested=True,
)

# Writes images/siamese_simple.png; the cell's value is the rendered file path.
model_graph.visual_graph.render("images/siamese_simple", format="png")


'images/siamese_simple.png'

In [None]:
from torchview import draw_graph
import torch

# Fresh, untrained DeepLabV3. Despite the "siamese" file name, the model is
# single-input: forward() takes one tensor and runs a single branch.
model = DeepLabV3()

# Dummy tracing batch: 16 single-channel (grayscale) 256x256 images.
x = torch.randn((16, 1, 256, 256))

# Detailed architecture view: traverse nested modules down to depth 6.
# NOTE(review): in the saved run this call raised
# "RuntimeError: Failed to run torchgraph see error message", while the same
# call with depth=1 succeeded. The underlying traceback was not saved, so the
# root cause is unknown. TODO: rerun and inspect the chained exception.
model_graph = draw_graph(
    model,
    input_data=x,
    depth=6,
    expand_nested=True,
)

# Writes images/siamese_depth.png; the cell's value is the rendered file path.
model_graph.visual_graph.render("images/siamese_depth", format="png")


RuntimeError: Failed to run torchgraph see error message