import torch

from unet import UNet

def test_unet():
    """Smoke-test that UNet returns a map with the input's spatial size.

    Builds a 3-channel-in / 1-channel-out UNet, pushes one random
    256x256 sample through it and asserts that the result has shape
    (batch, out_channels, height, width). Prints "passed" on success.
    """
    n_in, n_out = 3, 1
    sample_shape = (1, n_in, 256, 256)  # (batch, channels, height, width)

    # Build the network before drawing the random input, so the order in
    # which the global RNG is consumed stays the same.
    net = UNet(in_channels=n_in, out_channels=n_out)
    prediction = net(torch.randn(*sample_shape))

    # Same batch and spatial dims as the input; only the channel count changes.
    expected_shape = (sample_shape[0], n_out) + sample_shape[2:]
    assert prediction.shape == expected_shape
    print("passed")


# Run the check as soon as this script is executed.
test_unet()

In [4]:
import torch
# Ask torch.hub for the 'unet' entry point of the
# mateuszbuda/brain-segmentation-pytorch repository: 3 input channels,
# 1 output channel, 32 initial feature maps, pretrained weights requested.
model = torch.hub.load(
    'mateuszbuda/brain-segmentation-pytorch',
    'unet',
    in_channels=3,
    out_channels=1,
    init_features=32,
    pretrained=True,
)

Using cache found in /Users/nshravanreddy/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


In [6]:
# Download an example image
import urllib
url, filename = ("https://github.com/mateuszbuda/brain-segmentation-pytorch/raw/master/assets/TCGA_CS_4944.png", "TCGA_CS_4944.png")
# Pick the right urlretrieve for the running interpreter up front instead of
# using a bare `except:` as control flow. Only a missing module selects the
# fallback, so genuine download errors (network, HTTP, disk) surface as-is
# rather than being swallowed or masked by a second failing attempt.
try:
    from urllib.request import urlretrieve  # Python 3
except ImportError:
    from urllib import urlretrieve  # Python 2
urlretrieve(url, filename)

In [7]:
import numpy as np
from PIL import Image
from torchvision import transforms

# Open the image file downloaded above.
input_image = Image.open(filename)
# Mean / std reduced over the first two axes of the image array
# (presumably height and width, leaving one value per channel).
m, s = np.mean(input_image, axis=(0, 1)), np.std(input_image, axis=(0, 1))
# NOTE(review): m and s are computed on the raw image values, while
# ToTensor() may rescale pixel values before Normalize() is applied.
# Confirm the two steps operate on the same scale.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=m, std=s),
])
input_tensor = preprocess(input_image)
# Add a leading batch dimension: the model is called on a batch of one.
input_batch = input_tensor.unsqueeze(0)

# Run on the GPU when one is available; input and model must share a device.
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model = model.to('cuda')

# Put the model in evaluation mode before inference so that any
# train-time-only layers (dropout, batch normalisation) use their inference
# behaviour and no running statistics are updated by this forward pass.
model.eval()

# No gradients are needed for a pure forward pass.
with torch.no_grad():
    output = model(input_batch)

# Show the first (only) prediction of the batch, rounded element-wise.
print(torch.round(output[0]))

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])


In [1]:
import torch
from torch import nn

class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input
    is preserved; only the channel count changes, from ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.

    Note:
        The first convolution keeps its ``conv.0`` position inside the
        ``nn.Sequential``; the second one is registered as ``conv.2``.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # Second convolution: the block previously applied only one
            # conv + ReLU despite being a *double* convolution.
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU to ``x`` and return the result."""
        return self.conv(x)
    
class DownSample(nn.Module):
    """Encoder stage: a DoubleConv block followed by 2x2 max pooling.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution block.
    """

    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the convolution block's output at the input
        resolution; ``pooled`` is that same tensor after the 2x2,
        stride-2 max pool.
        """
        features = self.conv(x)
        return features, self.pool(features)

class UpSample(nn.Module):
    """Decoder stage: upsample, align with the skip tensor, concatenate, convolve.

    Args:
        in_channels: channels of the tensor to upsample; also the channel
            count the convolution block receives after concatenation.
        out_channels: channels produced by the transposed convolution and
            by the convolution block.

    The skip tensor therefore has to carry ``in_channels - out_channels``
    channels for the concatenation to match what ``self.conv`` expects.
    """

    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        self.up_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size and fuse the two.

        Args:
            x1: lower-resolution tensor coming from the previous stage.
            x2: skip-connection tensor from the matching encoder stage.

        Returns:
            The convolution block applied to ``cat([x2, upsampled x1], dim=1)``.
        """
        x1 = self.up_conv(x1)
        # Dim 2 is height and dim 3 is width for the 2-D conv layers used here.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        # F.pad lists padding starting from the LAST dimension:
        # [w_left, w_right, h_top, h_bottom]. The width difference must come
        # first; passing the height difference first pads the wrong axis and
        # makes the concatenation fail when the two differences are unequal.
        x1 = nn.functional.pad(
            x1,
            [diff_w // 2, diff_w - diff_w // 2,
             diff_h // 2, diff_h - diff_h // 2],
        )
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class UNet(nn.Module):
    """U-shaped encoder/decoder network with four skip connections.

    Four DownSample stages (64, 128, 256, 512 channels) feed a
    1024-channel bottleneck; four UpSample stages mirror them back down
    to 64 channels, and a 1x1 convolution maps to ``out_channels``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the returned tensor.
    """

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        # Contracting path.
        self.down_conv_1 = DownSample(in_channels, 64)
        self.down_conv_2 = DownSample(64, 128)
        self.down_conv_3 = DownSample(128, 256)
        self.down_conv_4 = DownSample(256, 512)

        # Deepest stage: no pooling, no skip connection.
        self.bottle_neck = DoubleConv(512, 1024)

        # Expanding path.
        self.up_conv_1 = UpSample(1024, 512)
        self.up_conv_2 = UpSample(512, 256)
        self.up_conv_3 = UpSample(256, 128)
        self.up_conv_4 = UpSample(128, 64)

        # 1x1 projection to the requested number of output channels.
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run ``x`` through encoder, bottleneck and decoder; return the 1x1-conv output."""
        encoder_stages = (
            self.down_conv_1,
            self.down_conv_2,
            self.down_conv_3,
            self.down_conv_4,
        )
        decoder_stages = (
            self.up_conv_1,
            self.up_conv_2,
            self.up_conv_3,
            self.up_conv_4,
        )

        # Each encoder stage yields (pre-pool features, pooled tensor); the
        # features are stacked so the decoder can consume them deepest-first.
        skips = []
        current = x
        for stage in encoder_stages:
            skip, current = stage(current)
            skips.append(skip)

        current = self.bottle_neck(current)

        # Popping pairs the first decoder stage with the last encoder stage.
        for stage in decoder_stages:
            current = stage(current, skips.pop())

        return self.out(current)