<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/U_Net_for_Image_Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F

class UNet(nn.Module):
    """A compact U-Net for dense (per-pixel) binary segmentation.

    The encoder has four double-3x3-conv stages separated by three 2x max-pool
    downsampling steps. The decoder mirrors it with three *learned* 2x
    transposed-convolution upsampling steps. Each step is followed by
    concatenation with the same-resolution encoder features (skip connection)
    and a double 3x3 conv.

    Args:
        in_channels: Number of channels in the input image (default 1).
        out_channels: Number of output mask channels (default 1).

    Input:  tensor of shape (N, in_channels, H, W). H and W need not be
            divisible by 8, because decoder features are zero-padded to match
            the skip connection when pooling rounded a size down. Each spatial
            dim must be at least 8 so the deepest feature map is non-empty.
    Output: tensor of shape (N, out_channels, H, W) with values in [0, 1]
            (a sigmoid is applied), which pairs with ``nn.BCELoss``.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # Encoder: each stage doubles the channel width.
        self.enc1 = self.contract_block(in_channels, 64, 3, 1)
        self.enc2 = self.contract_block(64, 128, 3, 1)
        self.enc3 = self.contract_block(128, 256, 3, 1)
        self.enc4 = self.contract_block(256, 512, 3, 1)

        # Decoder: a learned 2x upsampler that halves the channels, then a
        # double conv over [upsampled, skip]. The concat doubles the width
        # back, hence e.g. 512 -> 256 for dec4.
        # The upsamplers are registered submodules so their weights are
        # trained, saved in state_dict, and follow .to(device).
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec4 = self.expand_block(512, 256, 3, 1)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = self.expand_block(256, 128, 3, 1)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = self.expand_block(128, 64, 3, 1)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Return per-pixel probabilities of shape (N, out_channels, H, W)."""
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))

        # Decoder: upsample, merge with the matching encoder features, refine.
        dec4 = self.dec4(self._skip_concat(self.upconv4(enc4), enc3))
        dec3 = self.dec3(self._skip_concat(self.upconv3(dec4), enc2))
        dec2 = self.dec2(self._skip_concat(self.upconv2(dec3), enc1))

        return torch.sigmoid(self.final(dec2))

    def contract_block(self, in_channels, out_channels, kernel_size, padding):
        """Build a double (conv -> ReLU) block.

        With kernel_size=3 and padding=1 the spatial size is preserved.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(inplace=True),
        )

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Build the decoder's double (conv -> ReLU) block.

        It has the same structure as ``contract_block``.
        """
        return self.contract_block(in_channels, out_channels, kernel_size, padding)

    def pool(self, x):
        """Halve the spatial size with 2x2 max pooling (stride 2, floor rounding)."""
        return F.max_pool2d(x, kernel_size=2, stride=2)

    def up_sample(self, x):
        """Upsample ``x`` 2x with the registered transposed conv for its width.

        Kept for backward compatibility; ``forward`` calls the ``upconv*``
        modules directly. Dispatches on the channel count (512, 256 or 128)
        to the matching learned upsampler instead of constructing a new,
        untrained layer on every call.

        Raises:
            ValueError: if ``x`` has a channel count with no upsampler.
        """
        upsamplers = {512: self.upconv4, 256: self.upconv3, 128: self.upconv2}
        channels = x.size(1)
        if channels not in upsamplers:
            raise ValueError(
                f"no upsampler for {channels} input channels; "
                f"expected one of {sorted(upsamplers)}"
            )
        return upsamplers[channels](x)

    @staticmethod
    def _skip_concat(up, skip):
        """Concatenate upsampled features with an encoder skip along channels.

        Floor rounding in max pooling can leave ``up`` one pixel smaller than
        ``skip`` in a spatial dim when that dim was odd. In that case ``up`` is
        zero-padded, split as evenly as possible between the two sides, so the
        two tensors can be concatenated.
        """
        diff_h = skip.size(2) - up.size(2)
        diff_w = skip.size(3) - up.size(3)
        if diff_h or diff_w:
            # F.pad takes (left, right, top, bottom) for the last two dims.
            up = F.pad(up, [diff_w // 2, diff_w - diff_w // 2,
                            diff_h // 2, diff_h - diff_h // 2])
        return torch.cat((up, skip), dim=1)

# Create the model; run on the GPU when one is available (e.g. a Colab GPU runtime).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)

# Define loss and optimizer.
# The model already applies a sigmoid, so plain BCELoss (not BCEWithLogitsLoss)
# is the matching criterion.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Example data: one batch of random images with random binary masks so the loop
# runs end to end. Replace `train_batches` with a real DataLoader that yields
# (images, masks) batches.
dummy_data = torch.randn(5, 1, 256, 256)  # Batch of 5 images, 1 channel, 256x256 size
dummy_masks = torch.randint(0, 2, (5, 1, 256, 256)).float()  # Corresponding {0, 1} masks
train_batches = [(dummy_data, dummy_masks)]

num_epochs = 10

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, masks in train_batches:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report the mean over all batches rather than only the last one;
    # max(..., 1) avoids a division by zero if the loader is empty.
    mean_loss = running_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}, Loss: {mean_loss}")