In [None]:
import torch
import torch.nn as nn
from torchvision import models

class UNetWithResNet(nn.Module):
    """U-Net style segmentation network built on a torchvision ResNet-34 encoder.

    The encoder reuses the ResNet-34 stem and its four residual stages (the
    classification head is dropped). The decoder upsamples the deepest
    feature map step by step and merges encoder features through additive
    skip connections. The returned logits have the same spatial size as the
    input.

    Args:
        pretrained: If true, initialise the encoder with ImageNet weights.
        out_channels: Number of output channels (1 for a binary mask).
    """

    def __init__(self, pretrained: bool = True, out_channels: int = 1):
        super().__init__()

        resnet = self._build_resnet34(pretrained)

        # Encoder: ResNet-34 without avgpool / fc. Kept as an indexable
        # nn.Sequential so parameter names (``encoder.0.weight`` ...) are
        # stable; ``forward`` taps intermediate outputs by index.
        self.encoder = nn.Sequential(
            resnet.conv1,    # 0: initial 7x7 convolution
            resnet.bn1,      # 1: batch norm
            resnet.relu,     # 2: ReLU
            resnet.maxpool,  # 3: max pooling
            resnet.layer1,   # 4: residual stage 1 (64 channels)
            resnet.layer2,   # 5: residual stage 2 (128 channels)
            resnet.layer3,   # 6: residual stage 3 (256 channels)
            resnet.layer4,   # 7: residual stage 4 (512 channels)
        )

        # Decoder: each block doubles the resolution and maps to the channel
        # count of the encoder feature it is subsequently added to.
        self.upconv4 = self.decoder_block(512, 256)
        self.upconv3 = self.decoder_block(256, 128)
        self.upconv2 = self.decoder_block(128, 64)
        self.upconv1 = self.decoder_block(64, 64)

        # Final 1x1 projection to the requested number of mask channels.
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _build_resnet34(pretrained):
        """Create a torchvision ResNet-34, optionally with ImageNet weights."""
        weights_enum = getattr(models, "ResNet34_Weights", None)
        if weights_enum is None:
            # Older torchvision releases only understand the boolean flag.
            return models.resnet34(pretrained=pretrained)
        # Newer releases deprecate ``pretrained``; IMAGENET1K_V1 is the
        # checkpoint that ``pretrained=True`` used to select.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet34(weights=weights)

    @staticmethod
    def _add_skip(upsampled, skip):
        """Add an encoder feature map to a decoder feature map.

        When the input size is not a multiple of 32, the transposed
        convolution can overshoot the encoder feature by a pixel; in that
        case the decoder tensor is resized to the skip tensor first.
        """
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = nn.functional.interpolate(
                upsampled,
                size=skip.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )
        return upsampled + skip

    def decoder_block(self, in_channels, out_channels):
        """Build one decoder stage: 2x transposed-conv upsampling + 3x3 conv.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by the stage.

        Returns:
            An ``nn.Sequential`` that doubles the spatial resolution.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute segmentation logits for a batch of images.

        Args:
            x: Input batch of shape ``(N, C, H, W)``; ``C`` must match the
                encoder's first convolution (3 for the stock ResNet-34).

        Returns:
            Logits of shape ``(N, out_channels, H, W)``.
        """
        # Encoder. The stem must go through BN + ReLU, and the max-pool must
        # precede layer1, otherwise the pretrained weights see the wrong
        # activations and the feature pyramid loses one downsampling step.
        stem = self.encoder[2](self.encoder[1](self.encoder[0](x)))  # 1/2
        enc1 = self.encoder[4](self.encoder[3](stem))                # 1/4
        enc2 = self.encoder[5](enc1)                                 # 1/8
        enc3 = self.encoder[6](enc2)                                 # 1/16
        bottleneck = self.encoder[7](enc3)                           # 1/32

        # Decoder: upsample, then add the encoder feature of the same scale.
        dec = self._add_skip(self.upconv4(bottleneck), enc3)
        dec = self._add_skip(self.upconv3(dec), enc2)
        dec = self._add_skip(self.upconv2(dec), enc1)
        dec = self._add_skip(self.upconv1(dec), stem)

        # The decoder ends at half resolution (it mirrors four of the five
        # stride-2 steps), so bring the logits back to the input size.
        logits = self.final_conv(dec)
        return nn.functional.interpolate(
            logits,
            size=x.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )

