<a href="https://colab.research.google.com/github/Furkanpusher/U-NET/blob/main/%C4%B0mproved_UNET.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Cbam öncesi shape:torch.Size([1, 64, 56, 56])
Cbam sonrası shape:torch.Size([1, 64, 56, 56])
torch.Size([1, 1000])
Cbam öncesi shape:torch.Size([2, 64, 56, 56])
Cbam sonrası shape:torch.Size([2, 64, 56, 56])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
 AdaptiveAvgPool2d-5             [-1, 64, 1, 1]               0
            Conv2d-6              [-1, 4, 1, 1]             256
              ReLU-7              [-1, 4, 1, 1]               0
            Conv2d-8             [-1, 64, 1, 1]             256
 AdaptiveMaxPool2d-9             [-1, 64, 1, 1]               0
           Conv2d-10              [-1, 4, 1, 1]             256
             ReLU-11   

In [7]:
# ENCODER: MOBILENET
# SKIP CONNECTIONS: CBAM + STARNET
# BOTTLENECK: STARNET
# OUTPUT: STARNET


import torch
import torch.nn as nn
from CBAM import CBAM
from STARNET_2D import STARNet2D
from unet_parts import UpSample
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
import torch.nn.functional as F
from torchsummary import summary


class UNet(nn.Module):
    """U-Net style segmentation network with a pretrained MobileNetV3-Small encoder.

    Structure (as built in ``__init__``):
      * Encoder: ``mobilenet_v3_small(...).features`` with ImageNet weights.
      * Skip connections: four intermediate encoder feature maps, each passed
        through CBAM followed by a STARNet2D block before being fed to the decoder.
      * Bottleneck: STARNet2D mapping the 576-channel encoder output to 1024 channels.
      * Decoder: four ``UpSample`` stages (1024 -> 512 -> 256 -> 128 -> 64 channels).
      * Head: 1x1 conv to ``num_classes``, bilinear resize to the input
        resolution, then a STARNet2D output refiner.

    Args:
        in_channels: Number of channels of the input images. The pretrained
            encoder expects 3; for any other value a learnable 1x1 convolution
            first projects the input to 3 channels.
        num_classes: Number of output channels (one per class).
    """

    # Indices into ``mobilenet_v3_small().features`` whose outputs are tapped as
    # skip connections, in shallow-to-deep order. Their channel counts must match
    # the skip modules below (16, 24, 40, 96) and the UpSample skip widths.
    _SKIP_LAYER_INDICES = (0, 3, 6, 10)
    # Number of input channels the pretrained MobileNetV3 stem accepts.
    _ENCODER_IN_CHANNELS = 3

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # MobileNetV3-Small feature extractor, pretrained on ImageNet.
        self.encoder = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1).features

        # Previously ``in_channels`` was ignored and non-RGB inputs crashed inside
        # the encoder stem. Only build an adapter when it is actually needed, so the
        # 3-channel model keeps exactly the same modules and state_dict keys.
        if in_channels == self._ENCODER_IN_CHANNELS:
            self.input_adapter = None
        else:
            self.input_adapter = nn.Conv2d(in_channels, self._ENCODER_IN_CHANNELS, kernel_size=1)

        # Bottleneck with STARNet2D (576 encoder channels -> 1024).
        self.bottle_neck = STARNet2D(in_channels=576, out_channels=1024, dim=1024, depth=2, use_stem=True, use_head=True)

        # Upsampling path: UpSample(in, out, skip_channels).
        self.up_convolution_1 = UpSample(1024, 512, 96)
        self.up_convolution_2 = UpSample(512, 256, 40)
        self.up_convolution_3 = UpSample(256, 128, 24)
        self.up_convolution_4 = UpSample(128, 64, 16)

        # Output layer: per-pixel class logits.
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

        # Skip-connection refiners: each one is CBAM followed by a STARNet2D block.
        self.cbam1 = nn.Sequential(
            CBAM(in_channels=16),
            STARNet2D(16, out_channels=16, depth=1, use_stem=False, use_head=False)
        )
        self.cbam2 = nn.Sequential(
            CBAM(in_channels=24),
            STARNet2D(24, out_channels=24, depth=1, use_stem=False, use_head=False)
        )
        self.cbam3 = nn.Sequential(
            CBAM(in_channels=40),
            STARNet2D(40, out_channels=40, depth=1, use_stem=False, use_head=False)
        )
        self.cbam4 = nn.Sequential(
            CBAM(in_channels=96),
            STARNet2D(96, out_channels=96, depth=2, use_stem=False, use_head=False)
        )

        # Output refinement: applies small corrections to the prediction before it
        # is returned, and is intended to reduce artifacts introduced by the
        # bilinear interpolation in ``forward``.
        self.output_refiner = STARNet2D(num_classes, out_channels=num_classes, depth=1, dim=64)

    def forward(self, x):
        """Predict per-pixel outputs for a batch of images.

        Args:
            x: Input batch, presumably shaped (N, in_channels, H, W) — TODO confirm.

        Returns:
            The output refiner's result, computed from ``num_classes``-channel
            logits that were bilinearly resized to the input's H x W (this used to
            be hard-coded to 224 x 224, which broke any other input size).
        """
        # Remember the input resolution so the prediction can be resized back to it.
        input_size = tuple(x.shape[-2:])

        if self.input_adapter is not None:
            x = self.input_adapter(x)

        # Encoder: run every feature stage and capture the skip features on the way.
        skips = []
        for i, layer in enumerate(self.encoder):
            x = layer(x)
            if i in self._SKIP_LAYER_INDICES:
                skips.append(x)
        skip_1, skip_2, skip_3, skip_4 = skips

        # Bottleneck
        b = self.bottle_neck(x)

        # Decoder, deepest skip first. Each skip is refined by its CBAM + STARNet
        # module before being merged by the matching UpSample stage.
        down_4 = self.cbam4(skip_4)
        up_1 = self.up_convolution_1(b, down_4)

        down_3 = self.cbam3(skip_3)
        up_2 = self.up_convolution_2(up_1, down_3)

        down_2 = self.cbam2(skip_2)
        up_3 = self.up_convolution_3(up_2, down_2)

        down_1 = self.cbam1(skip_1)
        up_4 = self.up_convolution_4(up_3, down_1)

        # Final output: logits resized to the input resolution, then refined.
        out = self.out(up_4)
        out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=False)
        return self.output_refiner(out)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=3, num_classes=3).to(device)
# Switch to inference mode before the smoke tests. In train mode, BatchNorm
# layers update their running mean/var on every forward pass (torch.no_grad()
# does not prevent this). That would overwrite the pretrained encoder's
# statistics, both here and in torchsummary's internal forward pass.
model.eval()

input_tensor = torch.rand(1, 3, 224, 224).to(device)
with torch.no_grad():
    output = model(input_tensor)
    print(f"Çıkış boyutu: {output.shape}")  # expected: [1, 3, 224, 224]

summary(model, input_size=(3, 224, 224), device=str(device))


# ~31M parameters: the baseline model had 25.8M; adding STARNet raised it to 31M.

Çıkış boyutu: torch.Size([1, 3, 224, 224])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 16, 112, 112]             432
       BatchNorm2d-2         [-1, 16, 112, 112]              32
         Hardswish-3         [-1, 16, 112, 112]               0
            Conv2d-4           [-1, 16, 56, 56]             144
       BatchNorm2d-5           [-1, 16, 56, 56]              32
              ReLU-6           [-1, 16, 56, 56]               0
 AdaptiveAvgPool2d-7             [-1, 16, 1, 1]               0
            Conv2d-8              [-1, 8, 1, 1]             136
              ReLU-9              [-1, 8, 1, 1]               0
           Conv2d-10             [-1, 16, 1, 1]             144
      Hardsigmoid-11             [-1, 16, 1, 1]               0
SqueezeExcitation-12           [-1, 16, 56, 56]               0
           Conv2d-13           [-1, 16, 56, 56]             