In [1]:
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, dataloader, random_split

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import copy
import pandas as pd

import glob

In [2]:
# Report the torch version together with either the properties of CUDA
# device 0 (when CUDA is usable) or the literal 'CPU'.
if torch.cuda.is_available():
    print('Using torch %s %s' % (torch.__version__, torch.cuda.get_device_properties(0)))
else:
    print('Using torch %s %s' % (torch.__version__, 'CPU'))

Using torch 2.0.0 _CudaDeviceProperties(name='NVIDIA GeForce RTX 3060', major=8, minor=6, total_memory=12287MB, multi_processor_count=28)


In [3]:
# Path of the .pth weights file whose state dict is loaded into the model.
pth_file_path = 'C:/Users/liFangzheng/Desktop/completed_AI_program/ResU-Net/save/embryo_seg.pth'
# Directory that is searched for the input *.jpg images.
img_path = 'C:/Users/liFangzheng/Desktop/LI_development/new/elt-1/input/'
# Directory the resulting PNG images are written into.
save_path = 'C:/Users/liFangzheng/Desktop/LI_development/new/elt-1/output/'

In [4]:
# Do not modify!
# NOTE(review): presumably these have to match the settings the saved weights
# were trained with -- confirm before changing.
N_CLASSES       = 1      # default number of output channels of UNet_ResNet
LEARNING_RATE   = 0.002  # learning rate handed to the Adam optimizer in make_model
START_FRAME     = 16     # default `start_fm` of UNet_ResNet (channels of the first encoder stage)

In [5]:
class BatchActivate(nn.Module):
    """Batch normalisation over 2-D feature maps followed by a ReLU.

    Args:
        num_features: number of channels of the incoming feature map.
    """

    def __init__(self, num_features):
        super().__init__()
        # The attribute name is part of the state-dict keys; keep it stable.
        self.norm = nn.BatchNorm2d(num_features)

    def forward(self, x):
        """Return ``relu(batch_norm(x))``."""
        normalized = self.norm(x)
        return F.relu(normalized)

class ConvBlock(nn.Module):
    """2-D convolution optionally followed by batch-norm + ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        kernel: convolution kernel size.
        padding: zero padding applied by the convolution.
        stride: convolution stride.
        activation: when true, the convolution output is passed through
            ``BatchActivate``; when false, the raw convolution output is returned.
    """

    def __init__(self, in_channels, out_channels, kernel=3, padding=1, stride=1, activation=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )
        # Always created, even when `activation` is false, so the module's
        # parameters / state-dict keys do not depend on the flag.
        self.batchnorm = BatchActivate(out_channels)
        self.activation = activation

    def forward(self, x):
        """Apply the convolution, then batch-norm + ReLU if enabled."""
        features = self.conv(x)
        return self.batchnorm(features) if self.activation else features

class DoubleConvBlock(nn.Module):
    """Two ``ConvBlock`` layers applied one after the other.

    The first maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Both receive the same kernel, padding and stride.
    """

    def __init__(self, in_channels, out_channels, kernel=3, padding=1, stride=1):
        super().__init__()
        shared = dict(kernel=kernel, padding=padding, stride=stride)
        self.conv1 = ConvBlock(in_channels, out_channels, **shared)
        self.conv2 = ConvBlock(out_channels, out_channels, **shared)

    def forward(self, x):
        """Return ``conv2(conv1(x))``."""
        return self.conv2(self.conv1(x))

class ResidualBlock(nn.Module):
    """Channel-preserving residual unit: norm -> conv(+BN/ReLU) -> conv, plus skip.

    Args:
        in_channels: number of channels of the input (and of the output).
        batch_activation: when true, the summed output is passed through the
            block's batch-norm layer once more before it is returned.
    """

    def __init__(self, in_channels, batch_activation=False):
        super().__init__()
        self.batch_activation = batch_activation
        # One BatchNorm2d instance serves both the pre-convolution
        # normalisation and the optional normalisation after the skip sum.
        self.norm  = nn.BatchNorm2d(num_features=in_channels)
        self.conv1 = ConvBlock(in_channels, in_channels, kernel=3, stride=1, padding=1)
        self.conv2 = ConvBlock(in_channels, in_channels, kernel=3, stride=1, padding=1, activation=False)

    def forward(self, x):
        """Return the residual sum, optionally re-normalised."""
        shortcut = x

        out = self.norm(x)
        out = self.conv2(self.conv1(out))

        # In-place accumulation onto the convolution output.
        out += shortcut

        if not self.batch_activation:
            return out
        return self.norm(out)

In [6]:
class UNet_ResNet(nn.Module):
    """U-Net shaped encoder/decoder whose stages are built from ResidualBlocks.

    The encoder has four stages, each followed by 2x2 max pooling and
    channel dropout. A middle stage follows, then four decoder stages that
    upsample with stride-2 transposed convolutions and concatenate the
    matching encoder output (skip connection) along the channel axis.

    Because of the four pool / upsample pairs, the input height and width
    must be divisible by 16 for the skip concatenations to line up.

    Args:
        in_channels: channels of the input image.
        n_classes: channels of the output map.
        dropout: drop probability used after every encoder pooling step and
            at the start of every decoder stage.
        start_fm: channels of the first encoder stage; it doubles at each
            deeper stage (up to ``start_fm * 16`` in the middle).

    The forward pass returns the raw output of the final 1x1 convolution
    (no activation is applied).
    """

    def __init__(self, in_channels=1, n_classes=N_CLASSES, dropout=0.1, start_fm=START_FRAME):
        super(UNet_ResNet, self).__init__()

        # Probability for the functional dropout applied in forward().
        self.drop = dropout

        self.pool = nn.MaxPool2d((2,2))

        # Encoder
        self.encoder_1 = self._residual_stage(in_channels, start_fm)
        self.encoder_2 = self._residual_stage(start_fm, start_fm*2)
        self.encoder_3 = self._residual_stage(start_fm*2, start_fm*4)
        self.encoder_4 = self._residual_stage(start_fm*4, start_fm*8)

        self.middle = self._residual_stage(start_fm*8, start_fm*16)

        # Stride-2 transposed convolutions double the spatial size and halve the channels.
        self.deconv_4  = nn.ConvTranspose2d(start_fm*16, start_fm*8, 2, 2)
        self.deconv_3  = nn.ConvTranspose2d(start_fm*8, start_fm*4, 2, 2)
        self.deconv_2  = nn.ConvTranspose2d(start_fm*4, start_fm*2, 2, 2)
        self.deconv_1  = nn.ConvTranspose2d(start_fm*2, start_fm, 2, 2)

        # Decoder: each stage's input is the skip tensor concatenated with
        # the upsampled tensor, hence twice the stage's output channels.
        self.decoder_4 = self._residual_stage(start_fm*16, start_fm*8, nn.Dropout2d(dropout))
        self.decoder_3 = self._residual_stage(start_fm*8, start_fm*4, nn.Dropout2d(dropout))
        self.decoder_2 = self._residual_stage(start_fm*4, start_fm*2, nn.Dropout2d(dropout))
        self.decoder_1 = self._residual_stage(start_fm*2, start_fm, nn.Dropout2d(dropout))

        self.conv_last = nn.Conv2d(start_fm, n_classes, 1)

    @staticmethod
    def _residual_stage(in_ch, out_ch, *leading):
        """Build ``[*leading, 3x3 conv, ResidualBlock, ResidualBlock(batch_activation=True)]``.

        ``leading`` modules are placed first, so the positions inside the
        returned ``nn.Sequential`` (and with them the state-dict keys) are
        the same as when the layers are listed out by hand.
        """
        return nn.Sequential(
            *leading,
            nn.Conv2d(in_ch, out_ch, 3, padding=(1,1)),
            ResidualBlock(out_ch),
            ResidualBlock(out_ch, batch_activation=True),
        )

    def _pool_and_drop(self, features):
        """Max-pool ``features`` and apply channel dropout.

        The functional form is tied to ``self.training``. A freshly built
        ``nn.Dropout2d`` module is always in training mode and would keep
        dropping channels after ``model.eval()``.
        """
        pooled = self.pool(features)
        return F.dropout2d(pooled, p=self.drop, training=self.training)

    def forward(self, x):
        """Run the network on ``x`` and return the un-activated output map.

        The size comments below assume a 128x128 input.
        """
        # Encoder
        conv1 = self.encoder_1(x) #128
        x = self._pool_and_drop(conv1) # 64

        conv2 = self.encoder_2(x) #64
        x = self._pool_and_drop(conv2) # 32

        conv3 = self.encoder_3(x) #32
        x = self._pool_and_drop(conv3) #16

        conv4 = self.encoder_4(x) #16
        x = self._pool_and_drop(conv4) # 8

        # Middle
        x     = self.middle(x) # 8

        # Decoder
        x     = self.deconv_4(x) #16
        x     = torch.cat([conv4, x], dim=1) #16
        x     = self.decoder_4(x)

        x     = self.deconv_3(x) #32
        x     = torch.cat([conv3, x], dim=1)
        x     = self.decoder_3(x)

        x     = self.deconv_2(x) #64
        x     = torch.cat([conv2, x], dim=1)
        x     = self.decoder_2(x)

        x     = self.deconv_1(x) # 128
        x     = torch.cat([conv1, x], dim=1)
        x     = self.decoder_1(x)

        out   = self.conv_last(x) # 128
        return out

In [7]:
def tensor2np(tensor):
    """Convert a tensor to a NumPy array.

    Size-1 dimensions are removed, and the data is detached from the
    autograd graph and moved to host memory first.
    """
    return tensor.detach().squeeze().cpu().numpy()

def normtensor(tensor):
    """Binarise ``tensor``: 0.0 where the value is negative, 1.0 elsewhere.

    The two constants are created on the same device as ``tensor``, so the
    function works for CPU tensors and on machines without CUDA. The result
    has the default float dtype of the constants.

    Args:
        tensor: input tensor of any shape.

    Returns:
        A tensor of the same shape as ``tensor`` holding 0.0 / 1.0.
    """
    zero = torch.zeros(1, device=tensor.device)
    one = torch.ones(1, device=tensor.device)
    return torch.where(tensor < 0., zero, one)

def count_params(model):
    """Return the total number of elements over all parameters of ``model``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

In [8]:
def make_model(prev_model = None):
    """Create (or reuse) the network together with its loss and optimizer.

    Args:
        prev_model: an existing model to reuse. When ``None`` a new
            ``UNet_ResNet`` with default arguments is built and moved to the
            module-level ``device``. A supplied model is used as it is.

    Returns:
        Tuple ``(model, criterion, optimizer)``: the model, a
        ``BCEWithLogitsLoss`` instance and an Adam optimizer over the model's
        parameters with learning rate ``LEARNING_RATE``.

    Side effects:
        Prints the parameter count of the model.
    """
    # Identity test against the None singleton; `==` would dispatch to a
    # user-defined __eq__ if the passed object had one.
    if prev_model is None:
        model = UNet_ResNet().to(device)
    else:
        model = prev_model

    print('Number of parameter:', count_params(model))

    # Loss operates on raw logits; the optimizer covers all model parameters.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    return model, criterion, optimizer

In [9]:
# Use the first CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [10]:
# Build a fresh network; the loss and optimizer returned alongside it are
# not needed here and are discarded.
model, _, _ = make_model(prev_model=None)

Number of parameter: 4896369


In [11]:
# Load the trained weights into the model (best effort: the notebook keeps
# running if this fails, but the reason is reported instead of being hidden).
# `map_location=device` lets a checkpoint saved on a GPU be loaded on a
# machine where `device` is the CPU.
try:
    model.load_state_dict(torch.load(pth_file_path, map_location=device))
except Exception as err:
    # `Exception` rather than a bare `except`, so KeyboardInterrupt and
    # SystemExit are not swallowed.
    print('Can not load weights')
    print('Reason: %r' % (err,))

In [12]:
def predict(model,input,device):
    """Run ``model`` on ``input`` and return a binary mask.

    The model is switched to evaluation mode and executed without gradient
    tracking. Outputs greater than 0 become 1.0; all others become 0.0.

    Args:
        model: the network to evaluate.
        input: the tensor passed to ``model``.
        device: accepted for interface compatibility; not used here.

    Returns:
        A float tensor with the same shape as the model output.
    """
    model.eval()
    with torch.no_grad():
        logits = model(input)
        mask = (logits > 0).type(torch.float)
    return mask

In [13]:
# Inference setup: put the model on the target device in evaluation mode,
# and prepare the preprocessing applied to every input image
# (convert to single-channel grayscale, then to a tensor).
model = model.to(device).eval()
test_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

In [14]:
# Collect every *.jpg file in the input directory. glob returns matches in
# arbitrary order, so sort them to get a reproducible processing/log order.
image_name_path = sorted(glob.glob(os.path.join(img_path, "*.jpg")))

In [15]:
# Segment every input image, keep only the pixels inside the predicted mask,
# and save the result as a PNG with the same base name.
os.makedirs(save_path, exist_ok=True)  # make sure the output directory exists
to_pil = transforms.ToPILImage()
for file in image_name_path:
    # Swap only the extension (whatever its case) for '.png'.
    img_name = os.path.splitext(os.path.basename(file))[0] + '.png'
    # Context manager closes the file handle once the pixels are converted.
    with Image.open(file) as img_PIL:
        input_img = test_transform(img_PIL)
    input_img_data = input_img.unsqueeze(0).to(device)
    predicted_mask = predict(model, input_img_data, device=device)
    # The input batch is already on `device`; forcing it to 'cuda' would
    # fail on CPU-only machines.
    org_image = input_img_data
    # Element-wise product zeroes every pixel outside the predicted mask.
    out_tensor = torch.mul(predicted_mask, org_image)
    out_tensor = out_tensor.squeeze(dim=0)
    enhance_image = to_pil(out_tensor.cpu())
    enhance_image.save(os.path.join(save_path, img_name), "PNG")
    print(img_name)

0.png
150.png
155.png
160.png
165.png
170.png
175.png
180.png
185.png
190.png
195.png
200.png
205.png
210.png
215.png
220.png
225.png
230.png
235.png
240.png
245.png
250.png
255.png
260.png
265.png
270.png
275.png
280.png
285.png
290.png
295.png
300.png
305.png
310.png
315.png
320.png
325.png
330.png
335.png
340.png
345.png
350.png
355.png
360.png
365.png
370.png
375.png
380.png
385.png
390.png
395.png
400.png
405.png
410.png
415.png
420.png
425.png
430.png
435.png
440.png
445.png
450.png
455.png
460.png
465.png
470.png
475.png
480.png
485.png
490.png
495.png
500.png
505.png
510.png
515.png
520.png
525.png
530.png
535.png
540.png
545.png
550.png
555.png
560.png
565.png
570.png
575.png
