In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsummary
import torchinfo
from torchvision import transforms, utils, datasets
import os
import torch.optim as optim
from tqdm import tqdm

In [2]:
# https://github.com/clemkoa/u-net/blob/master/unet/unet.py

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class Up(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Up, self).__init__()
        self.up_scale = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)

    def forward(self, x1, x2):
        x2 = self.up_scale(x2)

        diffY = x1.size()[2] - x2.size()[2]
        diffX = x1.size()[3] - x2.size()[3]

        x2 = F.pad(x2, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return x


class DownLayer(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DownLayer, self).__init__()
        self.pool = nn.MaxPool2d(2, stride=2, padding=0)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(self.pool(x))
        return x


class UpLayer(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(UpLayer, self).__init__()
        self.up = Up(in_ch, out_ch)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        a = self.up(x1, x2)
        x = self.conv(a)
        return x


class UNet(nn.Module):
    def __init__(self, dimensions=2):
        super(UNet, self).__init__()
        self.conv1 = DoubleConv(1, 64)
        self.down1 = DownLayer(64, 128)
        self.down2 = DownLayer(128, 256)
        self.down3 = DownLayer(256, 512)
        self.down4 = DownLayer(512, 1024)
        self.up1 = UpLayer(1024, 512)
        self.up2 = UpLayer(512, 256)
        self.up3 = UpLayer(256, 128)
        self.up4 = UpLayer(128, 64)
        self.last_conv = nn.Conv2d(64, dimensions, 1)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x1_up = self.up1(x4, x5)
        x2_up = self.up2(x3, x1_up)
        x3_up = self.up3(x2, x2_up)
        x4_up = self.up4(x1, x3_up)
        output = self.last_conv(x4_up)
        return output

if __name__ == "__main__":
    # Example usage
    model = UNet(dimensions=3)
    x = torch.randn(1, 1, 256, 256)  # Batch size of 1, 1 channel, 256x256 image
    output = model(x)
    print(output.shape)  # Should be (1, 3, 256, 256)
    

torch.Size([1, 3, 256, 256])


In [3]:
# Example usage
model = UNet(dimensions=3)
x = torch.randn(1, 1, 256, 256)  # Batch size of 1, 1 channel, 256x256 image
output = model(x)
print(output.shape)  # Should be (1, 3, 256, 256)

torch.Size([1, 3, 256, 256])


In [4]:
torchinfo.summary(model, input_size=(1, 1, 256, 256))

Layer (type:depth-idx)                   Output Shape              Param #
UNet                                     [1, 3, 256, 256]          --
├─DoubleConv: 1-1                        [1, 64, 256, 256]         --
│    └─Sequential: 2-1                   [1, 64, 256, 256]         --
│    │    └─Conv2d: 3-1                  [1, 64, 256, 256]         640
│    │    └─BatchNorm2d: 3-2             [1, 64, 256, 256]         128
│    │    └─ReLU: 3-3                    [1, 64, 256, 256]         --
│    │    └─Conv2d: 3-4                  [1, 64, 256, 256]         36,928
│    │    └─BatchNorm2d: 3-5             [1, 64, 256, 256]         128
│    │    └─ReLU: 3-6                    [1, 64, 256, 256]         --
├─DownLayer: 1-2                         [1, 128, 128, 128]        --
│    └─MaxPool2d: 2-2                    [1, 64, 128, 128]         --
│    └─DoubleConv: 2-3                   [1, 128, 128, 128]        --
│    │    └─Sequential: 3-7              [1, 128, 128, 128]        221,952
├─D

In [5]:
torchsummary.summary(model, input_size=(1, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             640
       BatchNorm2d-2         [-1, 64, 256, 256]             128
              ReLU-3         [-1, 64, 256, 256]               0
            Conv2d-4         [-1, 64, 256, 256]          36,928
       BatchNorm2d-5         [-1, 64, 256, 256]             128
              ReLU-6         [-1, 64, 256, 256]               0
        DoubleConv-7         [-1, 64, 256, 256]               0
         MaxPool2d-8         [-1, 64, 128, 128]               0
            Conv2d-9        [-1, 128, 128, 128]          73,856
      BatchNorm2d-10        [-1, 128, 128, 128]             256
             ReLU-11        [-1, 128, 128, 128]               0
           Conv2d-12        [-1, 128, 128, 128]         147,584
      BatchNorm2d-13        [-1, 128, 128, 128]             256
             ReLU-14        [-1, 128, 1

In [6]:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params:,}")

Total trainable parameters: 31,042,499


In [7]:
transform = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor(), transforms.Grayscale()])
dataset = datasets.VOCSegmentation(
    "../data",
    year="2007",
    download=True,
    image_set="train",
    transform=transform,
    target_transform=transform,
)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cuda"
model_path = "unet_model.pth"
epoch_number = 10
saving_interval = 5

print(f"Using device: {device}")

def train():
    cell_dataset = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False)

    model = UNet(dimensions=22)
    model.to(device)
    if os.path.isfile(model_path):
        model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
    optimizer = optim.RMSprop(
        model.parameters(), lr=0.0001, weight_decay=1e-8, momentum=0.9
    )
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epoch_number):
        print(f"Epoch {epoch}")
        losses = []
        for i, batch in tqdm(enumerate(cell_dataset), total=len(cell_dataset)):
            input, target = batch
            input = input.to(device)
            target = target.type(torch.LongTensor).to(device)
            # HACK to skip the last item that has a batch size of 1, not working with the cross entropy implementation
            if input.shape[0] < 2:
                continue
            optimizer.zero_grad()
            output = model(input)
            loss = criterion(output, target.squeeze())
            # step_loss = loss.item()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        # print the average loss for that epoch.
        print(sum(losses) /len(losses))
        if (epoch + 1) % saving_interval == 0:
            print("Saving model")

        torch.save(model.state_dict(), model_path)
    torch.save(model.state_dict(), model_path)
    return

Using device: cuda


In [8]:
train()

Epoch 0


100%|██████████| 105/105 [00:31<00:00,  3.36it/s]


0.689219132065773
Epoch 1


100%|██████████| 105/105 [00:29<00:00,  3.50it/s]


0.22173123835371092
Epoch 2


100%|██████████| 105/105 [00:27<00:00,  3.77it/s]


0.20926615667457765
Epoch 3


100%|██████████| 105/105 [00:31<00:00,  3.30it/s]


0.20408104324283508
Epoch 4


100%|██████████| 105/105 [00:36<00:00,  2.86it/s]


0.20092592073174623
Saving model
Epoch 5


100%|██████████| 105/105 [00:37<00:00,  2.83it/s]


0.1992411929397629
Epoch 6


100%|██████████| 105/105 [00:36<00:00,  2.86it/s]


0.1966919664723369
Epoch 7


100%|██████████| 105/105 [00:37<00:00,  2.82it/s]


0.19570810806292754
Epoch 8


100%|██████████| 105/105 [00:36<00:00,  2.84it/s]


0.19642507141599289
Epoch 9


100%|██████████| 105/105 [00:36<00:00,  2.85it/s]


0.1947769453175939
Saving model


In [None]:
data_folder = "../data"
model_path = "model/unet-voc.pt"

shuffle_data_loader = False

transform = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor(), transforms.Grayscale()])
dataset = datasets.VOCSegmentation(
    data_folder,
    year="2007",
    download=True,
    image_set="train",
    transform=transform,
    target_transform=transform,
)


def predict():
    model = UNet(dimensions=22)
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    cell_dataset = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=shuffle_data_loader)
    model.load_state_dict(checkpoint)
    model.eval()
    for i, batch in enumerate(cell_dataset):
        input, _ = batch
        output = model(input).detach()
        input_array = input.squeeze().detach().numpy()
        output_array = output.argmax(dim=1)
        # Simple conversion to black and white.
        # Everything class 0 is background, make everything else white.
        # This is bad for images with several classes.
        output_array = torch.where(output_array > 0, 255, 0)
        input_img = Image.fromarray(input_array * 255)
        input_img.show()
        output_img = Image.fromarray(output_array.squeeze().numpy().astype(dtype=np.uint16)).convert("L")
        output_img.show()
        # Just showing first ten images. Change as you wish!
        if i > 10:
            break
    return