In [1]:
import cv2
import os
import time
import torch
import numpy as np
from mss import mss
from datetime import datetime
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, dataloader, random_split
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import copy
import pandas as pd

import glob

In [2]:
# Report the installed PyTorch version and what it will compute on: the
# properties of CUDA device 0 when CUDA is available, otherwise the text 'CPU'.
print('Using torch {} {}'.format(
    torch.__version__,
    torch.cuda.get_device_properties(0) if torch.cuda.is_available() else 'CPU',
))

Using torch 1.10.1 CPU


In [3]:
# ===============================
# Model configuration
# ===============================
# Checkpoint file whose state dict is loaded into UNet_ResNet further down.
MODEL_PATH = "/Users/user/Desktop/size_worm/save/worm_epoch076_iou0.7476.pth" 
# Device setting for inference; leave it as "cpu" if you do not have a GPU.
DEVICE = "cpu" # if you do not have a GPU

In [4]:
# 512x512 region of the monitor; this dict is what the main loop passes to
# sct.grab() to capture each frame.
MONITOR = {
    "left": 704,
    "top": 284,
    "width": 512,
    "height": 512
}
# Directory the main loop writes saved frames and masks into ('s' key);
# created up front so the writes do not fail on a missing folder.
SAVE_DIR = "/Users/user/Desktop/size_worm/DIR/realtime_results"
os.makedirs(SAVE_DIR, exist_ok=True)

In [5]:
# Do not modify! N_CLASSES and START_FRAME are the default `n_classes` and
# `start_fm` arguments of UNet_ResNet, so they determine the layer shapes the
# saved checkpoint is loaded into.
# NOTE(review): LEARNING_RATE is not referenced by any code visible in this
# notebook; presumably left over from training - confirm before removing.
N_CLASSES       = 1
LEARNING_RATE   = 0.002
START_FRAME     = 16

In [6]:
class BatchActivate(nn.Module):
    """Batch normalisation over `num_features` channels followed by ReLU."""

    def __init__(self, num_features):
        super().__init__()
        self.norm = nn.BatchNorm2d(num_features)

    def forward(self, x):
        """Return relu(batchnorm(x)); the output has the same shape as `x`."""
        normalised = self.norm(x)
        return F.relu(normalised)

class ConvBlock(nn.Module):
    """A 2-D convolution, optionally followed by batch norm + ReLU.

    The BatchActivate submodule is always constructed, even when
    `activation` is False, so its parameters are always present in the
    module's state dict; it is only *applied* when `activation` is True.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel: convolution kernel size.
        padding: convolution padding.
        stride: convolution stride.
        activation: whether to run batch norm + ReLU after the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel=3, padding=1, stride=1, activation=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )
        self.batchnorm = BatchActivate(out_channels)
        self.activation = activation

    def forward(self, x):
        """Convolve `x`, then normalise and activate if enabled."""
        features = self.conv(x)
        return self.batchnorm(features) if self.activation else features

class DoubleConvBlock(nn.Module):
    """Two ConvBlocks in sequence: in_channels -> out_channels -> out_channels.

    Both blocks share the same kernel, padding and stride settings.
    """

    def __init__(self, in_channels, out_channels, kernel=3, padding=1, stride=1):
        super().__init__()
        self.conv1 = ConvBlock(in_channels, out_channels, kernel, padding, stride)
        self.conv2 = ConvBlock(out_channels, out_channels, kernel, padding, stride)

    def forward(self, x):
        """Apply the first block, then the second, and return the result."""
        return self.conv2(self.conv1(x))

class ResidualBlock(nn.Module):
    """Pre-normalised residual unit that keeps the channel count unchanged.

    Computes ``x + conv2(conv1(norm(x)))`` where conv1 is followed by batch
    norm + ReLU and conv2 is a bare convolution. When `batch_activation` is
    True the sum is passed through the *same* BatchNorm2d instance (`norm`)
    a second time before being returned.

    Args:
        in_channels: number of channels of both the input and the output.
        batch_activation: whether to batch-normalise the summed output.
    """

    def __init__(self, in_channels, batch_activation=False):
        super().__init__()
        self.batch_activation = batch_activation
        self.norm = nn.BatchNorm2d(num_features=in_channels)
        self.conv1 = ConvBlock(in_channels, in_channels, kernel=3, stride=1, padding=1)
        self.conv2 = ConvBlock(in_channels, in_channels, kernel=3, stride=1, padding=1, activation=False)

    def forward(self, x):
        """Return the residual sum, optionally re-normalised."""
        shortcut = x
        branch = self.conv2(self.conv1(self.norm(x)))

        # Accumulate the skip connection into the branch output in place.
        branch += shortcut

        if not self.batch_activation:
            return branch
        return self.norm(branch)

In [7]:
class UNet_ResNet(nn.Module):
    """U-Net whose encoder/decoder stages are built from residual blocks.

    Layout: four encoder stages (each followed by 2x2 max-pooling and
    dropout), a middle stage, and four decoder stages. Each decoder stage
    up-samples with a stride-2 transposed convolution, concatenates the
    matching encoder output along the channel axis, and refines the result.
    A final 1x1 convolution maps to `n_classes` channels of raw logits
    (no sigmoid/softmax is applied here).

    Because the input is halved four times and the skip connections are
    concatenated with the up-sampled maps, the input height and width must
    be divisible by 16.

    Args:
        in_channels: number of channels of the input image.
        n_classes: number of output channels (logit maps).
        dropout: drop probability used after every encoder pooling step and
            at the start of every decoder stage.
        start_fm: number of feature maps of the first stage; it doubles at
            every deeper stage.
    """

    def __init__(self, in_channels=1, n_classes=N_CLASSES, dropout=0.1, start_fm=START_FRAME):
        super(UNet_ResNet, self).__init__()
        # Drop probability for the functional dropout used in forward().
        self.drop = dropout
        # Shared 2x2 max-pooling (has no parameters, so one instance suffices).
        self.pool = nn.MaxPool2d((2, 2))

        # Encoder
        self.encoder_1 = nn.Sequential(
            nn.Conv2d(in_channels, start_fm, 3, padding=(1, 1)),
            ResidualBlock(start_fm),
            ResidualBlock(start_fm, batch_activation=True),
        )

        self.encoder_2 = nn.Sequential(
            nn.Conv2d(start_fm, start_fm * 2, 3, padding=(1, 1)),
            ResidualBlock(start_fm * 2),
            ResidualBlock(start_fm * 2, batch_activation=True),
        )

        self.encoder_3 = nn.Sequential(
            nn.Conv2d(start_fm * 2, start_fm * 4, 3, padding=(1, 1)),
            ResidualBlock(start_fm * 4),
            ResidualBlock(start_fm * 4, batch_activation=True),
        )

        self.encoder_4 = nn.Sequential(
            nn.Conv2d(start_fm * 4, start_fm * 8, 3, padding=(1, 1)),
            ResidualBlock(start_fm * 8),
            ResidualBlock(start_fm * 8, batch_activation=True),
        )

        self.middle = nn.Sequential(
            nn.Conv2d(start_fm * 8, start_fm * 16, 3, padding=3 // 2),
            ResidualBlock(start_fm * 16),
            ResidualBlock(start_fm * 16, batch_activation=True),
        )

        # Transposed convolutions (kernel 2, stride 2) that double the resolution.
        self.deconv_4 = nn.ConvTranspose2d(start_fm * 16, start_fm * 8, 2, 2)
        self.deconv_3 = nn.ConvTranspose2d(start_fm * 8, start_fm * 4, 2, 2)
        self.deconv_2 = nn.ConvTranspose2d(start_fm * 4, start_fm * 2, 2, 2)
        self.deconv_1 = nn.ConvTranspose2d(start_fm * 2, start_fm, 2, 2)

        # Decoder: each stage receives the skip connection concatenated with
        # the up-sampled map, hence twice the stage's output channel count.
        self.decoder_4 = nn.Sequential(
            nn.Dropout2d(dropout),
            nn.Conv2d(start_fm * 16, start_fm * 8, 3, padding=(1, 1)),
            ResidualBlock(start_fm * 8),
            ResidualBlock(start_fm * 8, batch_activation=True),
        )

        self.decoder_3 = nn.Sequential(
            nn.Dropout2d(dropout),
            nn.Conv2d(start_fm * 8, start_fm * 4, 3, padding=(1, 1)),
            ResidualBlock(start_fm * 4),
            ResidualBlock(start_fm * 4, batch_activation=True),
        )

        self.decoder_2 = nn.Sequential(
            nn.Dropout2d(dropout),
            nn.Conv2d(start_fm * 4, start_fm * 2, 3, padding=(1, 1)),
            ResidualBlock(start_fm * 2),
            ResidualBlock(start_fm * 2, batch_activation=True),
        )

        self.decoder_1 = nn.Sequential(
            nn.Dropout2d(dropout),
            nn.Conv2d(start_fm * 2, start_fm, 3, padding=(1, 1)),
            ResidualBlock(start_fm),
            ResidualBlock(start_fm, batch_activation=True),
        )

        # 1x1 convolution producing the per-class logit maps.
        self.conv_last = nn.Conv2d(start_fm, n_classes, 1)

    def _encoder_dropout(self, x):
        """Channel-wise dropout that honours the module's train/eval mode.

        A module created inline (``nn.Dropout2d(p)(x)``) is always in
        training mode, so it would keep dropping channels after
        ``model.eval()``. The functional form tied to ``self.training`` is
        identical while training and a no-op during evaluation.
        """
        return F.dropout2d(x, p=self.drop, training=self.training)

    def forward(self, x):
        """Return raw logits with `n_classes` channels at the input resolution."""
        # Encoder: keep each stage's output for the skip connections.
        conv1 = self.encoder_1(x)                     # full resolution
        x = self._encoder_dropout(self.pool(conv1))   # 1/2

        conv2 = self.encoder_2(x)                     # 1/2
        x = self._encoder_dropout(self.pool(conv2))   # 1/4

        conv3 = self.encoder_3(x)                     # 1/4
        x = self._encoder_dropout(self.pool(conv3))   # 1/8

        conv4 = self.encoder_4(x)                     # 1/8
        x = self._encoder_dropout(self.pool(conv4))   # 1/16

        # Middle
        x = self.middle(x)                            # 1/16

        # Decoder: up-sample, concatenate the skip connection, refine.
        x = self.deconv_4(x)                          # 1/8
        x = torch.cat([conv4, x], dim=1)
        x = self.decoder_4(x)

        x = self.deconv_3(x)                          # 1/4
        x = torch.cat([conv3, x], dim=1)
        x = self.decoder_3(x)

        x = self.deconv_2(x)                          # 1/2
        x = torch.cat([conv2, x], dim=1)
        x = self.decoder_2(x)

        x = self.deconv_1(x)                          # full resolution
        x = torch.cat([conv1, x], dim=1)
        x = self.decoder_1(x)

        out = self.conv_last(x)
        return out

In [8]:
# Device the model is placed on: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [9]:
# Build the network with its default hyper-parameters and move it to `device`
# (nn.Module.to() moves the parameters in place).
model = UNet_ResNet()
model.to(device)

# Restore the trained weights, remapping stored tensors onto `device`.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
trained_weights = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(trained_weights)
del trained_weights

# Switch to inference mode; as the cell's last expression this also displays
# the model summary.
model.eval()

UNet_ResNet(
  (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (encoder_1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ResidualBlock(
      (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): ConvBlock(
        (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (batchnorm): BatchActivate(
          (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (conv2): ConvBlock(
        (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (batchnorm): BatchActivate(
          (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (2): ResidualBlock(
      (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): ConvBlock(
   

In [10]:
# ===============================
# Utility function
# ===============================
def preprocess(frame):
    """Turn a BGR frame into a float32 tensor of shape (1, 1, H, W) in [0, 1].

    The frame is converted to grayscale, divided by 255, and given leading
    batch and channel axes of size one.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    scaled = gray / 255.0
    tensor = torch.from_numpy(scaled).float()
    # Add the channel axis first, then the batch axis.
    return tensor.unsqueeze(0).unsqueeze(0)

def postprocess(pred):
    """Convert the network's logits into a binary uint8 mask.

    Takes the first channel of the first batch item, applies a sigmoid, and
    marks every pixel whose probability is strictly above 0.5 with 255 (all
    other pixels are 0).
    """
    probabilities = torch.sigmoid(pred[0, 0]).detach().cpu().numpy()
    foreground = probabilities > 0.5
    return np.where(foreground, 255, 0).astype(np.uint8)

In [None]:
# ===============================
# main loop
# ===============================
# Feed frames to the device the model's weights actually live on. The model is
# placed on `device`, which is not necessarily the same as the DEVICE config
# value; deriving the target from the model itself rules out a CPU/GPU
# mismatch error inside model(x).
infer_device = next(model.parameters()).device

frame_id = 0
# perf_counter() is monotonic and high-resolution, which suits interval timing.
last_time = time.perf_counter()

print("worm segmentation（512x512）")
print("s = save image, q = quit")

try:
    # The context manager releases the screen-capture handle on exit, even if
    # the loop is left through an exception or a kernel interrupt.
    with mss() as sct:
        while True:
            # Grab the configured screen region (BGRA) and drop the alpha channel.
            img = np.array(sct.grab(MONITOR))
            frame = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

            # Resize to 512x512 (only changes the frame if the grab has a different size).
            frame_512 = cv2.resize(frame, (512, 512))

            # Run the model on every frame.
            x = preprocess(frame_512).to(infer_device)
            with torch.no_grad():
                pred = model(x)
            mask = postprocess(pred)

            # Count the mask pixels (convert to a physical worm size here if needed).
            worm_pixels = int(np.count_nonzero(mask == 255))

            # Overlay: paint the mask pixels red, then blend 70/30 with the frame.
            show = frame_512.copy()
            show[mask == 255] = (0, 0, 255)
            show = cv2.addWeighted(frame_512, 0.7, show, 0.3, 0)

            # Show FPS; clamp the interval so two identical clock reads cannot
            # cause a division by zero.
            now = time.perf_counter()
            fps = 1.0 / max(now - last_time, 1e-6)
            last_time = now
            cv2.putText(show, f"FPS: {fps:.1f}", (10, 25),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

            # Show the pixel count in the window.
            cv2.putText(show, f"Worm Pixels: {worm_pixels}", (10, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

            cv2.imshow("Realtime Segmentation 512x512", show)

            key = cv2.waitKey(1) & 0xFF
            if key == ord('s'):
                ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                # cv2.imwrite reports failure through its return value instead
                # of raising, so check it before claiming the save worked.
                img_saved = cv2.imwrite(os.path.join(SAVE_DIR, f"{ts}_img.png"), frame_512)
                mask_saved = cv2.imwrite(os.path.join(SAVE_DIR, f"{ts}_mask.png"), mask)
                if img_saved and mask_saved:
                    print("saved:", ts)
                else:
                    print("save failed:", ts)
            elif key == ord('q'):
                break

            frame_id += 1
finally:
    # Always close the preview window, including on errors or interrupts.
    cv2.destroyAllWindows()

worm segmentation（512x512）
s = save image, q = quit
