<a href="https://colab.research.google.com/github/1kaiser/jax-unet/blob/master/UNet%2B%2B_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Sure, here is an example of how to use the Nested UNet model to segment objects in an image:

In [1]:
import torch
import torch.nn as nn

class conv_block_nested(nn.Module):

    def __init__(self, in_ch, mid_ch, out_ch):
        super(conv_block_nested, self).__init__()
        self.activation = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)
        output = self.activation(x)

        return output

class Nested_UNet(nn.Module):

    def __init__(self, in_ch=3, out_ch=1):
        super(Nested_UNet, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
        self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
        self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
        self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
        self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])

        self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0])
        self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1])
        self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2])
        self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[3], filters[3])

        self.conv0_2 = conv_block_nested(filters[0]*2 + filters[1], filters[0], filters[0])
        self.conv1_2 = conv_block_nested(filters[1]*2 + filters[2], filters[1], filters[1])
        self.conv2_2 = conv_block_nested(filters[2]*2 + filters[3], filters[2], filters[2])

        self.conv0_3 = conv_block_nested(filters[0]*3 + filters[1], filters[0], filters[0])
        self.conv1_3 = conv_block_nested(filters[1]*3 + filters[2], filters[1], filters[1])

        self.conv0_4 = conv_block_nested(filters[0]*4 + filters[1], filters[0], filters[0])

        self.final = nn.Conv2d(filters[0], out_ch, kernel_size=1)

    def forward(self, x):

        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))

        output = self.final(x0_4)
        return output

In [2]:
import torch
import torch.nn as nn

# Load the Nested UNet model
model = Nested_UNet(in_ch=3, out_ch=1)

# Load the input image
image = torch.randn(1, 3, 256, 256)

# Set the model in evaluation mode
model.eval()

# Forward pass the input image through the model
output = model(image)

# Get the segmentation mask
segmentation_mask = output.argmax(1)

# Save the segmentation mask
torch.save(segmentation_mask, "segmentation_mask.pt")

In [None]:
#@title **serial**
import torch
import torchvision
import matplotlib.pyplot as plt

# Load the Nested UNet model
model = Nested_UNet(in_ch=3, out_ch=3)


# Load the dataset
dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=torchvision.transforms.ToTensor())

# Create a data loader
data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

from google.colab.patches import cv2_imshow
import numpy as np
import cv2
# Define a function to plot the predicted images
def plot_predictions(images, labels, predictions):
    print(len(images))
    for i in range(len(images)):
      print(type(predictions[i]), type(images[i]))
      
      cv2_imshow(cv2.cvtColor(np.array(images[i].detach().numpy()).reshape(32, 32, 3), cv2.COLOR_BGR2RGB))
      cv2_imshow(cv2.cvtColor(np.array(predictions[i].detach().numpy()).reshape(32, 32, 3), cv2.COLOR_BGR2RGB))


# Train the model
for epoch in range(10):
    for batch in data_loader:
        images, labels = batch
        predictions = model(images)
        print(images.shape, predictions.shape)
        loss = nn.CrossEntropyLoss()(predictions, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Plot the predicted images
        plot_predictions(images, labels, predictions)

# Evaluate the model
correct = 0
total = 0
for batch in data_loader:
    images, labels = batch
    predictions = model(images)
    _, predicted = torch.max(predictions.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

accuracy = correct / total
print("Accuracy:", accuracy)

# Save the model
torch.save(model.state_dict(), "model.pt")


This code will load the Nested UNet model, load the input image, set the model in evaluation mode, forward pass the input image through the model, get the segmentation mask, and save the segmentation mask.

The segmentation mask can be used to identify the objects in the image. For example, if the input image is a picture of a cat, the segmentation mask will identify the pixels that belong to the cat. The segmentation mask can be used for a variety of tasks, such as object detection, object tracking, and image editing.

In [None]:
!wget https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip
!unzip '*.zip' -d files
!rm -r *.zip

In [2]:
import os

import time
#@title **parallel**
import torch
import torchvision
import matplotlib.pyplot as plt

# Load the Nested UNet model
model = Nested_UNet(in_ch=3, out_ch=3)

if torch.cuda.is_available():
    device = torch.device("cuda")
    model = model.to(device)
    print("Model is now running on GPU!")
else:
    print("GPU is not available!")



from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, Resize
import torchvision.transforms as transforms
# Get the path to the image folder
data_dir = "/content/files"
transform = transforms.Compose([ToTensor(), Resize((128, 128))])
dataset = ImageFolder(data_dir, transform=transform, target_transform=None)
print(len(dataset))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


from google.colab.patches import cv2_imshow
import numpy as np
import cv2
# Define a function to plot the predicted images
def plot_predictions(images, labels, predictions):
    print(len(images))
    for i in range(len(images)):
      print(type(predictions[i]), type(images[i]))
      
      cv2_imshow(cv2.cvtColor(torch.moveaxis(np.array(images[i].cpu().detach().numpy()), 0, -1), cv2.COLOR_BGR2RGB))
      cv2_imshow(cv2.cvtColor(torch.moveaxis(np.array(predictions[i].cpu().detach().numpy()), 0, -1), cv2.COLOR_BGR2RGB))

# Train the model
start_time = time.time()
for epoch in range(1):
    for batch in data_loader:
        images, labels = batch
        images = images.to(device)
        # labels = labels.to(device)
        predictions = model(images)
        loss = nn.CrossEntropyLoss()(predictions, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Plot the predicted images
        # plot_predictions(images, labels, predictions)
        print(time.time() - start_time, "<<< ⌚ for a batch")
    print(time.time() - start_time, "<<< ⌛⌛⌛⌛⌛⌛⌛⌛⌛⌛ for a epoch")
# Evaluate the model
correct = 0
total = 0
for batch in data_loader:
    images, labels = batch
    images = images.to(device)
    # labels = labels.to(device)
    predictions = model(images)
    total += images.size(0)
    correct += (predictions == images).sum().item()

accuracy = correct / total
print("Accuracy:", accuracy)

# Save the model
torch.save(model.state_dict(), "model.pt")

print("The time to completion of the model code is:", time.time() - start_time)


Model is now running on GPU!
2028




2.7060868740081787 <<< ⌚ for a batch
3.4487032890319824 <<< ⌚ for a batch
4.700771331787109 <<< ⌚ for a batch
5.9438207149505615 <<< ⌚ for a batch
7.194634914398193 <<< ⌚ for a batch
8.455477237701416 <<< ⌚ for a batch
9.72258973121643 <<< ⌚ for a batch
11.015422821044922 <<< ⌚ for a batch
12.311810493469238 <<< ⌚ for a batch
13.587861061096191 <<< ⌚ for a batch
14.943670511245728 <<< ⌚ for a batch
16.663172960281372 <<< ⌚ for a batch
17.947860717773438 <<< ⌚ for a batch
19.229580879211426 <<< ⌚ for a batch
20.541556119918823 <<< ⌚ for a batch
21.836486339569092 <<< ⌚ for a batch
23.125173807144165 <<< ⌚ for a batch
24.394511461257935 <<< ⌚ for a batch
25.665920972824097 <<< ⌚ for a batch
26.940340042114258 <<< ⌚ for a batch
28.190704584121704 <<< ⌚ for a batch
29.426841020584106 <<< ⌚ for a batch
30.67162799835205 <<< ⌚ for a batch
31.90028190612793 <<< ⌚ for a batch
33.12438631057739 <<< ⌚ for a batch
34.35887289047241 <<< ⌚ for a batch
35.57278490066528 <<< ⌚ for a batch
36.77126836

OutOfMemoryError: ignored

In [18]:

image_path = '/content/files/annotated_images/out_000000001.png'

import torch
from PIL import Image
import numpy as np
from torchvision import transforms

def predict(image_path):

    # Load the model and put it in eval mode
    model = torch.load('/content/model.pt')
    model.eval()

    # Convert the image to a Torch tensor
    image = Image.open(image_path)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((128, 128))
    ])
    image = transform(image)

    # Feed the image to the model
    predictions = model(image)

    # Get the predictions from the model
    _, predicted = torch.max(predictions.data, 1)

    # Evaluate the predictions
    return predicted

predict(image_path)

AttributeError: ignored

In [None]:
import time
#@title **parallel**
import torch
import torchvision
import matplotlib.pyplot as plt

# Load the Nested UNet model
model = Nested_UNet(in_ch=3, out_ch=3)

# Check if GPU is available
if torch.cuda.is_available():

    # Set the device to GPU
    device = torch.device("cuda")

    # Move the model to GPU
    model = model.to(device)

    # Print a message to the user
    print("Model is now running on GPU!")

else:

    # Print a message to the user
    print("GPU is not available!")

# Load the dataset
dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=torchvision.transforms.ToTensor())

# Create a data loader
data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


from google.colab.patches import cv2_imshow
import numpy as np
import cv2
# Define a function to plot the predicted images
def plot_predictions(images, labels, predictions):
    print(len(images))
    for i in range(len(images)):
      print(type(predictions[i]), type(images[i]))
      
      cv2_imshow(cv2.cvtColor(np.array(images[i].cpu().detach().numpy()).reshape(32, 32, 3), cv2.COLOR_BGR2RGB))
      cv2_imshow(cv2.cvtColor(np.array(predictions[i].cpu().detach().numpy()).reshape(32, 32, 3), cv2.COLOR_BGR2RGB))

# Train the model
start_time = time.time()
for epoch in range(10):
    for batch in data_loader:
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)
        predictions = model(images)
        loss = nn.CrossEntropyLoss()(predictions, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Plot the predicted images
        # plot_predictions(images, labels, predictions)
        print(time.time() - start_time, "<<< ⌚ for a batch")
    print(time.time() - start_time, "<<< ⌛⌛⌛⌛⌛⌛⌛⌛⌛⌛ for a epoch")
# Evaluate the model
correct = 0
total = 0
for batch in data_loader:
    images, labels = batch
    images = images.to(device)
    labels = labels.to(device)
    predictions = model(images)
    _, predicted = torch.max(predictions.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

accuracy = correct / total
print("Accuracy:", accuracy)

# Save the model
torch.save(model.state_dict(), "model.pt")

print("The time to completion of the model code is:", time.time() - start_time)


In [None]:
#@title **parallel**
import torch
import torchvision
import matplotlib.pyplot as plt

# Load the Nested UNet model
model = Nested_UNet(in_ch=3, out_ch=3)

# Check if GPU is available
if torch.cuda.is_available():

    # Set the device to GPU
    device = torch.device("cuda")

    # Move the model to GPU
    model = model.to(device)

    # Print a message to the user
    print("Model is now running on GPU!")

else:

    # Print a message to the user
    print("GPU is not available!")

# Load the dataset
dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=torchvision.transforms.ToTensor())

# Create a data loader
data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


from google.colab.patches import cv2_imshow
import numpy as np
import cv2
# Define a function to plot the predicted images
def plot_predictions(images, labels, predictions):
    print(len(images))
    for i in range(len(images)):
      print(type(predictions[i]), type(images[i]))
      
      cv2_imshow(cv2.cvtColor(np.array(images[i].cpu().detach().numpy()).reshape(32, 32, 3), cv2.COLOR_BGR2RGB))
      cv2_imshow(cv2.cvtColor(np.array(predictions[i].cpu().detach().numpy()).reshape(32, 32, 3), cv2.COLOR_BGR2RGB))

# Train the model
for epoch in range(10):
    for batch in data_loader:
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)
        predictions = model(images)
        print(images.shape, predictions.shape)
        loss = nn.CrossEntropyLoss()(predictions, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Plot the predicted images
        # plot_predictions(images, labels, predictions)

# Evaluate the model
correct = 0
total = 0
for batch in data_loader:
    images, labels = batch
    images = images.to(device)
    labels = labels.to(device)
    predictions = model(images)
    _, predicted = torch.max(predictions.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

accuracy = correct / total
print("Accuracy:", accuracy)

# Save the model
torch.save(model.state_dict(), "model.pt")

