<a href="https://colab.research.google.com/github/Arteeemiy/-Segmentation-NN-Spine/blob/main/Segmentation_NN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import numpy as np
import torchvision
import os
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib
import torch.nn.functional as F
from torch import nn
import matplotlib.patches as mpatches
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
!kaggle datasets download -d "newra008/chest-segmentation-image"

Dataset URL: https://www.kaggle.com/datasets/newra008/chest-segmentation-image
License(s): CC0-1.0
Downloading chest-segmentation-image.zip to /content
 99% 1.13G/1.15G [00:11<00:00, 110MB/s]
100% 1.15G/1.15G [00:11<00:00, 104MB/s]


In [None]:
%%capture
!unzip chest-segmentation-image.zip

In [None]:
!rm chest-segmentation-image.zip

In [None]:
class SegmentationDataset(Dataset):
  """Image/mask pairs read from two parallel directories.

  The i-th file of ``image_dir`` (in sorted filename order) is paired with the
  i-th file of ``mask_dir``. Both directories must therefore hold the same
  number of files, and their names must sort into corresponding order
  (e.g. matching file names).

  Args:
    image_dir: directory of input images; each is loaded as RGB.
    mask_dir: directory of masks; each is loaded as 8-bit grayscale ("L").
    transform: optional callable applied separately to the image and to the
      mask. Because it is called twice, a random transform would draw
      different parameters for an image and its mask.

  Raises:
    ValueError: if the two directories contain different numbers of files.
  """
  def __init__(self, image_dir, mask_dir, transform=None):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    # os.listdir returns entries in arbitrary, filesystem-dependent order.
    # Without sorting, image i and mask i may belong to different samples.
    self.images = sorted(os.listdir(image_dir))
    self.masks = sorted(os.listdir(mask_dir))
    if len(self.images) != len(self.masks):
      raise ValueError(
          f"image_dir has {len(self.images)} files but mask_dir has "
          f"{len(self.masks)}; images and masks must pair one-to-one"
      )

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    """Return the ``(image, mask)`` pair at position ``idx``."""
    img_path = os.path.join(self.image_dir, self.images[idx])
    mask_path = os.path.join(self.mask_dir, self.masks[idx])
    # convert() returns a new in-memory image, so the source files can be
    # closed right away instead of waiting for garbage collection.
    with Image.open(img_path) as img_file:
      image = img_file.convert("RGB")
    with Image.open(mask_path) as mask_file:
      mask = mask_file.convert("L")

    if self.transform:
      image = self.transform(image)
      mask = self.transform(mask)

    return image, mask

In [None]:
# Every sample (image and mask alike) is resized to 256x256 and converted to a
# float tensor by ToTensor.
# NOTE(review): the same Resize is applied to the masks; with its default
# interpolation, mask edges may get intermediate (non-binary) values —
# consider a separate nearest-neighbour resize for masks.
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

# Presumably the layout produced by unzipping the Kaggle archive into /content.
train_image_dir, train_mask_dir = '/content/image', '/content/mask'
test_image_dir, test_mask_dir = '/content/test_image', '/content/test_mask'

train_dataset = SegmentationDataset(train_image_dir, train_mask_dir, transform=transform)
test_dataset = SegmentationDataset(test_image_dir, test_mask_dir, transform=transform)

# Only the training split is shuffled; evaluation order stays fixed.
train_dataloader, test_dataloader = (
    DataLoader(train_dataset, batch_size=4, shuffle=True),
    DataLoader(test_dataset, batch_size=4, shuffle=False),
)

In [None]:
class UNetConvBlock(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    ``padding=1`` keeps height and width unchanged; only the channel count
    goes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first conv maps in->out channels, the second keeps out channels.
        for stage_in in (in_channels, out_channels):
            stages.extend((
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ))
        # A single Sequential named ``double_conv`` keeps state_dict keys of
        # the form ``double_conv.<index>.<param>``.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)

class UNetUpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concat, double conv.

    The upsampled tensor is zero-padded to the skip tensor's height/width when
    they differ. The skip tensor comes first in the channel concatenation.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # The conv input is the concat of skip + upsampled channels; the UNet
        # in this file wires it so that total equals ``in_channels``.
        self.conv = UNetConvBlock(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        dy = x2.shape[2] - upsampled.shape[2]
        dx = x2.shape[3] - upsampled.shape[3]
        top, left = dy // 2, dx // 2
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        upsampled = F.pad(upsampled, [left, dx - left, top, dy - top])
        return self.conv(torch.cat((x2, upsampled), dim=1))

class UNet(nn.Module):
    """Four-level U-Net with channel widths 64-128-256-512-1024.

    Encoder stages are max-pooled by 2 before each ``downN`` block; decoder
    stages ``upN`` consume the matching encoder output as a skip connection.
    ``forward`` returns raw per-pixel logits with ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        widths = (64, 128, 256, 512, 1024)
        # Submodules are registered in the order in_conv, down1..4, up1..4,
        # out_conv; the attribute names form the state_dict keys.
        self.in_conv = UNetConvBlock(in_channels, widths[0])
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"down{idx}", UNetConvBlock(c_in, c_out))
        reversed_widths = widths[::-1]
        for idx, (c_in, c_out) in enumerate(zip(reversed_widths, reversed_widths[1:]), start=1):
            setattr(self, f"up{idx}", UNetUpBlock(c_in, c_out))
        self.out_conv = nn.Conv2d(widths[0], out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder: keep every stage's output for the skip connections.
        skips = [self.in_conv(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(F.max_pool2d(skips[-1], 2)))
        # Decoder: start from the bottleneck and pair each up-block with the
        # deepest remaining encoder output.
        features = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            features = up(features, skips.pop())
        return self.out_conv(features)

In [None]:
# Place the model on DEVICE before building the optimizer. PyTorch's optim
# docs recommend moving a model before constructing its optimizer, so the
# optimizer references the parameters in their final location.
model = UNet(in_channels=3, out_channels=1).to(DEVICE)
# The model emits one channel of raw logits; BCEWithLogitsLoss applies the
# sigmoid internally.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train(model, train_dataloader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Puts ``model`` in training mode and performs one optimizer step per batch
    of ``(images, masks)`` from ``train_dataloader``. Every batch contributes
    equally to the returned average, whatever its size.
    """
    model.train()
    batch_losses = []
    for images, masks in train_dataloader:
        images = images.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), masks)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(train_dataloader)

def evaluate(model, test_dataloader, criterion, device, visualize_batches=None):
    """Compute the mean per-batch loss of ``model`` on ``test_dataloader``.

    The model is put in eval mode and run without gradient tracking.

    Args:
        model: network mapping an image batch to predictions for ``criterion``.
        test_dataloader: iterable of ``(images, masks)`` batches.
        criterion: loss callable ``criterion(outputs, masks)``.
        device: device the batches are moved to before the forward pass.
        visualize_batches: number of leading batches passed to
            ``visualize_segmentation``. ``None`` (the default) plots every
            batch; ``0`` disables plotting.

    Returns:
        The loss averaged over batches, each batch weighted equally.
    """
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(test_dataloader):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            running_loss += criterion(outputs, masks).item()
            # Plotting every batch on every epoch is slow and produces one
            # figure per batch; callers can cap it with visualize_batches.
            if visualize_batches is None or batch_idx < visualize_batches:
                visualize_segmentation(images.cpu(), masks.cpu(), outputs.cpu())

    return running_loss / len(test_dataloader)

def visualize_segmentation(images, masks, outputs):
    """Plot input image, ground-truth mask and predicted mask for each sample.

    Draws one row of three panels per sample, shows the figure and then
    closes it. Does nothing for an empty batch.

    Args:
        images: batch of image tensors; each ``images[i]`` is permuted from
            (C, H, W) to (H, W, C) for ``imshow``.
        masks: batch of ground-truth masks, squeezed to 2-D for display.
        outputs: raw model logits; ``torch.sigmoid`` maps them to [0, 1]
            before display.
    """
    n = len(images)
    if n == 0:
        # plt.subplots(0, 3) would raise; there is nothing to draw anyway.
        return
    # squeeze=False always returns a 2-D axes array, even for a single row.
    fig, axs = plt.subplots(n, 3, figsize=(15, 5 * n), squeeze=False)
    for row, (image, mask, logits) in enumerate(zip(images, masks, outputs)):
        # detach() lets imshow convert tensors that still require grad.
        axs[row][0].imshow(image.detach().permute(1, 2, 0))
        axs[row][0].set_title('Original Image')
        axs[row][1].imshow(mask.detach().squeeze(), cmap='gray')
        axs[row][1].set_title('Ground Truth')
        axs[row][2].imshow(torch.sigmoid(logits.detach()).squeeze(), cmap='gray')
        axs[row][2].set_title('Predicted Mask')
    plt.show()
    # Release the figure. With non-inline backends, repeated calls (one per
    # eval batch) otherwise accumulate open figures.
    plt.close(fig)

# Reuse the device chosen in the first cell (same cuda-if-available rule).
device = DEVICE
model.to(device)

num_epochs = 10
for epoch in range(num_epochs):
    # One pass over the training data, then a full evaluation pass.
    train_loss = train(model, train_dataloader, criterion, optimizer, device)
    val_loss = evaluate(model, test_dataloader, criterion, device)
    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')