In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

In [3]:
class WarwickDataset(Dataset):
    def __init__(self, root_dir, train=True, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images_dir = os.path.join(root_dir, 'Train' if train else 'Test', 'images')
        self.masks_dir = os.path.join(root_dir, 'Train' if train else 'Test', 'masks')

        self.ids = [os.path.splitext(file)[0] for file in os.listdir(self.images_dir)
                    if not file.startswith('.')]
        
    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        img_path = os.path.join(self.images_dir, img_id + '.png')
        mask_path = os.path.join(self.masks_dir, img_id + '_anno.png')

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")
        
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)
        
        return image, mask


In [4]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.down1 = ConvBlock(3, 64)
        self.down2 = ConvBlock(64, 128)
        self.down3 = ConvBlock(128, 256)
        self.down4 = ConvBlock(256, 512)

        self.up1 = ConvBlock(512, 256)
        self.up2 = ConvBlock(256, 128)
        self.up3 = ConvBlock(128, 64)
        
        self.trans1 = nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1)
        self.trans2 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
        self.trans3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        
        self.out = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        
        x = self.trans1(x4)
        x = self.up1(x + x3)
        x = self.trans2(x)
        x = self.up2(x + x2)
        x = self.trans3(x)
        x = self.up3(x + x1)
        x = self.out(x)
        
        return x


In [5]:
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import os
from PIL import Image

class SegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transforms=None):
        self.root_dir = root_dir
        self.transforms = transforms
        self.images = [os.path.join(root_dir, file) for file in os.listdir(root_dir) if file.endswith('.png')]
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_path = self.images[idx]
        mask_path = img_path.replace('.png', '_mask.png')
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")  # Convert mask to grayscale

        if self.transforms:
            image = self.transforms(image)
            mask = self.transforms(mask)

        return image, mask

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load datasets
train_dataset = SegmentationDataset('C:/Users/HP/Downloads/WARWICK/Train', transforms=transform)
test_dataset = SegmentationDataset('C:/Users/HP/Downloads/WARWICK/Test', transforms=transform)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)


In [6]:
import torch.nn as nn
import torch.nn.functional as F

class SegNet(nn.Module):
    def __init__(self):
        super(SegNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = SegNet()


In [12]:
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(dataloader.dataset)
    return epoch_loss

# training
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

for epoch in range(10):  #  epoch count
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss}')

train_model(model, criterion, optimizer, num_epochs=10)

Epoch 1/10, Train Loss: 0.2426, Train Acc: 0.9271, Test Loss: 0.0662, Test Acc: 0.9790
Epoch 2/10, Train Loss: 0.0656, Train Acc: 0.9797, Test Loss: 0.0494, Test Acc: 0.9832
Epoch 3/10, Train Loss: 0.0479, Train Acc: 0.9848, Test Loss: 0.0453, Test Acc: 0.9845
Epoch 4/10, Train Loss: 0.0393, Train Acc: 0.9876, Test Loss: 0.0357, Test Acc: 0.9890
Epoch 5/10, Train Loss: 0.0313, Train Acc: 0.9905, Test Loss: 0.0362, Test Acc: 0.9876
Epoch 6/10, Train Loss: 0.0263, Train Acc: 0.9915, Test Loss: 0.0384, Test Acc: 0.9877
Epoch 7/10, Train Loss: 0.0230, Train Acc: 0.9923, Test Loss: 0.0346, Test Acc: 0.9885
Epoch 8/10, Train Loss: 0.0200, Train Acc: 0.9934, Test Loss: 0.0360, Test Acc: 0.9889
Epoch 9/10, Train Loss: 0.0171, Train Acc: 0.9942, Test Loss: 0.0332, Test Acc: 0.9884
Epoch 10/10, Train Loss: 0.0146, Train Acc: 0.9953, Test Loss: 0.0356, Test Acc: 0.9886
