In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Ask PyTorch's CUDA caching allocator to release cached-but-unoccupied GPU memory
# (e.g. left over from an earlier run in the same notebook kernel).
torch.cuda.empty_cache()

In [None]:
# Use the current CUDA device when PyTorch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class ResBlock(nn.Module):
    """Pre-activation residual block with a learned 1x1 shortcut.

    Main path: GroupNorm -> ReLU -> 3x3 conv, applied twice. Shortcut: a 1x1 conv
    that projects the input to ``out_channels``. All convolutions use stride 1 and
    the 3x3 ones use padding 1, so the spatial size is preserved.

    Args:
        in_channels: channels of the input tensor (must be divisible by 8, the
            GroupNorm group count).
        out_channels: channels of the output tensor (must be divisible by 8).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Registration order is kept stable so state_dict keys/parameter order match.
        self.residual = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
        self.gn1 = nn.GroupNorm(8, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
        self.gn2 = nn.GroupNorm(8, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)

    def forward(self, x):
        """Return ``main_path(x) + residual(x)`` with ``out_channels`` channels."""
        shortcut = self.residual(x)
        hidden = self.conv1(F.relu(self.gn1(x)))
        hidden = self.conv2(F.relu(self.gn2(hidden)))
        return hidden + shortcut
    
class VesselSegmentationModel(nn.Module):
    """U-Net-style encoder/decoder that outputs a per-pixel probability map.

    The encoder downsamples three times with
    ``MaxPool2d(kernel_size=3, stride=2, padding=1)`` (a size ``n`` becomes
    ``ceil(n / 2)``). Each segmentation-decoder stage upsamples by exactly 2 with
    ``ConvTranspose2d(kernel_size=3, stride=2, padding=1, output_padding=1)`` and
    adds the matching encoder feature map. The skip-connection additions therefore
    only line up when the input height and width are divisible by 8.

    Args:
        input_shape: ``(channels, height, width)`` of one sample. Only ``channels``
            is used to build layers; the spatial size is neither used nor checked.
        output_channels: number of channels of the output map.
    """

    # TODO: output_channels can be changed to 25 channels for vessel classification.
    def __init__(self, input_shape=(2, 512, 512), output_channels=1):
        super().__init__()
        c, _, _ = input_shape  # spatial dims are not needed to build the layers

        # Encoder Part
        self.enc1 = nn.Conv2d(c, 32, kernel_size=3, stride=1, padding=1)
        self.drop1 = nn.Dropout2d(0.2)
        self.green1 = ResBlock(32, 32)
        # TODO: consider a strided Conv2d for downsampling instead of max-pooling.
        self.down1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.green2_1 = ResBlock(32, 64)
        self.green2_2 = ResBlock(64, 64)
        self.down2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.green3_1 = ResBlock(64, 128)
        self.green3_2 = ResBlock(128, 128)
        self.down3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.green4_1 = ResBlock(128, 256)
        self.green4_2 = ResBlock(256, 256)
        self.green4_3 = ResBlock(256, 256)
        self.green4_4 = ResBlock(256, 256)

        # Segmentation Decoder Part: each up-conv doubles H and W and halves channels.
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.green5 = ResBlock(128, 128)

        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.green6 = ResBlock(64, 64)

        self.up3 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.green7 = ResBlock(32, 32)

        # Reconstruction Decoder Part
        # NOTE(review): these layers are constructed (and counted by parameters()) but
        # are not used in forward() yet. bottleneck's in_features of 128 * 16 * 16 is
        # hard-coded; it presumably assumes a 512x512 input (64x64 after the encoder,
        # halved twice more by down_rec_*) -- confirm when wiring up this branch.
        self.rec_green_1 = ResBlock(256, 128)
        self.down_rec_1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.rec_green_2 = ResBlock(128, 128)
        self.down_rec_2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.bottleneck = nn.Linear(128 * 16 * 16, 256)

        # Final 1x1 projection of the 32-channel decoder features to output_channels.
        self.out_conv = nn.Conv2d(32, output_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Return sigmoid probabilities of shape ``(N, output_channels, H, W)``.

        Args:
            x: input batch of shape ``(N, c, H, W)``; ``H`` and ``W`` must be
                divisible by 8 so the skip-connection additions match in size.
        """
        # Encoder: stem conv + dropout, then ResBlocks separated by 3 max-pools.
        x1 = self.enc1(x)
        x1 = self.drop1(x1)
        x1 = self.green1(x1)                # 32 ch, H x W
        x_down1 = self.down1(x1)

        x2 = self.green2_1(x_down1)
        x2 = self.green2_2(x2)              # 64 ch, H/2 x W/2
        x_down2 = self.down2(x2)

        x3 = self.green3_1(x_down2)
        x3 = self.green3_2(x3)              # 128 ch, H/4 x W/4
        x_down3 = self.down3(x3)

        x4 = self.green4_1(x_down3)
        x4 = self.green4_2(x4)
        x4 = self.green4_3(x4)
        x4 = self.green4_4(x4)              # 256 ch, H/8 x W/8

        # Segmentation decoder: upsample x2, add the same-resolution encoder map, refine.
        x_up1 = self.up1(x4)
        x_up1 = x_up1 + x3
        x_up1 = self.green5(x_up1)

        x_up2 = self.up2(x_up1)
        x_up2 = x_up2 + x2
        x_up2 = self.green6(x_up2)

        x_up3 = self.up3(x_up2)
        x_up3 = x_up3 + x1
        x_up3 = self.green7(x_up3)

        out = self.out_conv(x_up3)
        out_seg = torch.sigmoid(out)

        # TODO: reconstruction decoder forward pass (rec_green_*, down_rec_*,
        # bottleneck) is not implemented yet.

        return out_seg

# Example usage: build the model on `device` (GPU if available, else CPU) and
# sanity-check the output shape and the parameter count.
input_shape = (2, 512, 512)
output_channels = 1
model = VesselSegmentationModel(input_shape=input_shape, output_channels=output_channels).to(device)

# Dummy input for testing: a batch of 9 random samples of the configured shape.
x = torch.randn((9, *input_shape)).to(device)

# Only the shape is inspected, so skip autograd bookkeeping to avoid keeping
# every intermediate activation of a 9x512x512 batch alive.
with torch.no_grad():
    output = model(x)

print(output.shape)  # Expected output torch.Size([9, 1, 512, 512])
print(sum(p.numel() for p in model.parameters()))  # Total parameter count (all parameters)