In [None]:
# import dependencies
import os 
import numpy as np 
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
from torch.utils.data import  random_split
import torch
import torchvision.transforms as transforms
import segmentation_models_pytorch as smp
import torch.nn.functional as F
import albumentations as albu


In [None]:
# Root directory of the train/val/test folder layout built below
DATA_DIR = './data/'
# Synthetic dataset directories (raw images / tag masks); left empty in this notebook
synthetic_raw_dir = ""
synthetic_tag_dir = ""
# Real-image test dataset directories (raw images / tag masks); left empty in this notebook
real_raw_dir  = ""
real_tag_dir = "" 

# Image (x_*) and annotation (y_*) folders for each split, all under DATA_DIR
x_train_dir = os.path.join(DATA_DIR, 'train')
y_train_dir = os.path.join(DATA_DIR, 'trainannot')

x_valid_dir = os.path.join(DATA_DIR, 'val')
y_valid_dir = os.path.join(DATA_DIR, 'valannot')

x_test_dir = os.path.join(DATA_DIR, 'test')
y_test_dir = os.path.join(DATA_DIR, 'testannot')
# Gazebo corn raw images: mean and std
# (three values each -- presumably one per colour channel; channel order not shown here)
gazebo_corn_raw_mean = [0.22065574, 0.29405924, 0.34150184]
gazebo_corn_raw_std = [0.1133429,  0.10545158, 0.11077029]

# Blender corn dataset 1: mean and std
blender_corn_raw_mean1 = [0.4455225,  0.4115633,  0.31220761]
blender_corn_raw_std1 = [0.13111327, 0.12116198, 0.11744572]

# Blender single corn, partial (close-up) views: mean and std
blender_corn_raw_mean2 = [0.79994568, 0.7940369,  0.78530365]
blender_corn_raw_std2 = [0.18375659, 0.17886415, 0.20934963]
# Blender single corn, whole-plant views: mean and std
blender_corn_raw_mean3 = [0.94009896, 0.94333034, 0.93762586]
blender_corn_raw_std3 = [0.14359212, 0.13377417, 0.15746053]

# Blender multi corn (recorded for reference only, not assigned to a variable)
# Mean: [0.18631026 0.2225844  0.13032862]
# Std: [0.1057228  0.11291125 0.08504182]

# Real corn field (recorded for reference only, not assigned to a variable)
# Mean: [0.22904675 0.315527   0.21870158]
# Std: [0.21330218 0.24271984 0.20648903]

In [None]:
# Helper functions 
## helper function for data visualization
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword argument becomes one panel. The keyword, with underscores
    replaced by spaces and title-cased, is used as the panel title and the
    value is passed to ``plt.imshow``.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, panel_count, position)
        # Hide the axis ticks; only the picture and its title are shown
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()
## helper function for preprocessing corn mask 
def preprocess_gazebo_corn_mask(im):
    """Binarize a Gazebo corn tag image in place.

    Every pixel exactly equal to (0, 255, 255) is rewritten as white
    (255, 255, 255); every other pixel is rewritten as black (0, 0, 0).
    The image is modified through ``putpixel`` and nothing is returned.
    Each pixel must unpack into exactly three values.
    """
    white, black = (255, 255, 255), (0, 0, 0)
    for col in range(im.width):
        for row in range(im.height):
            position = (col, row)
            red, green, blue = im.getpixel(position)
            # Only the exact tag colour (0, 255, 255) counts as foreground
            is_tag = (red, green, blue) == (0, 255, 255)
            im.putpixel(position, white if is_tag else black)

In [None]:
# preprocess gazebo corn mask 
def gazebo_corn_mask_preprocess(im): 
    """Binarize a Gazebo corn tag image and return it in mode 'L'.

    A pixel becomes white (255, 255, 255) when any of its three channels is
    greater than 50, and black (0, 0, 0) otherwise. The input image is
    rewritten in place through ``putpixel``; the value returned is
    ``im.convert('L')``. Each pixel must unpack into exactly three values.
    """
    white, black = (255, 255, 255), (0, 0, 0)
    for col in range(im.width):
        for row in range(im.height):
            position = (col, row)
            red, green, blue = im.getpixel(position)
            # "Any channel above 50" is the same as "the brightest channel is above 50"
            foreground = max(red, green, blue) > 50
            im.putpixel(position, white if foreground else black)
    return im.convert('L')
# Raw-image preprocessing for the Gazebo corn dataset
def gazebo_corn_raw_preprocess(im): 
    """Return the raw Gazebo corn image unchanged (no preprocessing is applied)."""
    return im

In [None]:
# gazebo tree leaf preprocessing 
def gazebo_tree_mask_preprocess(im):
    """Binarize a Gazebo tree-leaf tag image and return it in mode 'L'.

    A pixel becomes white (255, 255, 255) when its red channel is 0 OR its
    green channel is 255 OR its blue channel is 255; otherwise it becomes
    black (0, 0, 0). The input image is rewritten in place through
    ``putpixel``; the value returned is ``im.convert('L')``.

    NOTE(review): because the three tests are joined with ``or``, any pixel
    with red == 0 -- including pure black (0, 0, 0) -- is mapped to white.
    If the intent was to match the exact tag colour (0, 255, 255), as
    ``preprocess_gazebo_corn_mask`` does, this should be ``and``. Confirm
    against the tag images before changing.
    """
    white, black = (255, 255, 255), (0, 0, 0)
    for col in range(im.width):
        for row in range(im.height):
            position = (col, row)
            red, green, blue = im.getpixel(position)
            # Background only when none of the three channel tests matches
            background = red != 0 and green != 255 and blue != 255
            im.putpixel(position, black if background else white)
    return im.convert('L')

In [None]:
# preprocess blender corn mask 
def blender_corn_mask_preprocess(im): 
    """Turn a Blender corn tag image into a 480x320 binary mask in mode 'L'.

    Steps:
      1. Convert to RGB (drops any alpha channel; ``convert`` returns a new
         image, so the caller's image is not modified).
      2. Center-crop a 960x640 window and resize it to 480x320.
      3. Threshold: a pixel is 255 when all three channels are greater
         than 200, otherwise 0.

    Args:
        im: PIL image holding the tag rendering.

    Returns:
        A new single-channel (mode 'L') PIL image containing only 0 and 255.
    """
    im = im.convert('RGB')
    width, height = im.size
    # Center crop. This arithmetic is deliberately identical to
    # blender_corn_raw_preprocess so the mask stays pixel-aligned with its
    # raw image (PIL rounds the float box; changing it here could shift by 1px).
    left = (width - 960) / 2
    top = (height - 640) / 2
    right = (width + 960) / 2
    bottom = (height + 640) / 2
    im = im.crop((left, top, right, bottom))
    im = im.resize((480,320))
    # Vectorized threshold instead of a per-pixel getpixel/putpixel loop:
    # foreground where R, G and B are all above 200.
    rgb = np.asarray(im)
    foreground = (rgb > 200).all(axis=-1)
    # A 2-D uint8 array is turned into a mode 'L' image by Image.fromarray
    return Image.fromarray(foreground.astype(np.uint8) * 255)
# Raw-image preprocessing for the Blender corn datasets
def blender_corn_raw_preprocess(im): 
    """Convert a Blender corn rendering to RGB, center-crop 960x640, resize to 480x320.

    Returns the resized image; every step (``convert``, ``crop``, ``resize``)
    produces a new image object.
    """
    rgb = im.convert('RGB')
    width, height = rgb.size
    # Center crop: a 960x640 window around the middle of the image
    left, right = (width - 960) / 2, (width + 960) / 2
    top, bottom = (height - 640) / 2, (height + 640) / 2
    return rgb.crop((left, top, right, bottom)).resize((480,320))

In [None]:
# Preprocessing for the Blender "rand" dataset
# Mask preprocessing (no crop/resize, threshold only)
def blender_rand_mask_preprocess(im): 
    """Turn a Blender "rand" tag image into a binary mask in mode 'L'.

    The image is converted to RGB (which drops any alpha channel and returns
    a new image, so the caller's image is not modified) and thresholded:
    a pixel is 255 when all three channels are greater than 200, otherwise 0.
    The image size is left unchanged.

    Args:
        im: PIL image holding the tag rendering.

    Returns:
        A new single-channel (mode 'L') PIL image containing only 0 and 255.
    """
    # Vectorized threshold instead of a per-pixel getpixel/putpixel loop
    rgb = np.asarray(im.convert('RGB'))
    foreground = (rgb > 200).all(axis=-1)
    # A 2-D uint8 array is turned into a mode 'L' image by Image.fromarray
    return Image.fromarray(foreground.astype(np.uint8) * 255)
# Raw-image preprocessing for the Blender "rand" dataset
def blender_rand_raw_preprocess(im): 
    """Return the image converted to RGB; no crop or resize is applied."""
    return im.convert('RGB')

In [None]:
# Dataset
class Dataset(BaseDataset):
    """Paired raw-image / tag-mask dataset for binary segmentation.

    Every file in ``images_dir`` named ``<id>_raw.png`` is paired with
    ``<id>_tag.png`` in ``masks_dir``. Files that do not end in ``_raw.png``
    are ignored, and ids are sorted so the sample order does not depend on
    the order ``os.listdir`` happens to return.

    Args:
        images_dir: directory containing the ``*_raw.png`` images.
        masks_dir: directory containing the matching ``*_tag.png`` masks.
        augmentation: optional callable invoked as
            ``augmentation(image=ndarray, mask=ndarray)`` that returns a
            mapping with ``'image'`` and ``'mask'`` entries.
        raw_preprocessing: optional callable applied to the opened raw image.
        mask_preprocessing: optional callable applied to the opened mask image.
        mean: normalization mean passed to ``transforms.Normalize``
            (the defaults are the usual ImageNet values).
        std: normalization std passed to ``transforms.Normalize``.
    """

    def __init__(
            self, 
            images_dir, 
            masks_dir, 
            augmentation=None, 
            raw_preprocessing=None,
            mask_preprocessing = None,
            mean = [0.485, 0.456, 0.406],
            std = [0.229, 0.224, 0.225]
    ):
        raw_suffix = "_raw.png"
        tag_suffix = "_tag.png"
        # Keep only real "<id>_raw.png" files: stray files (checkpoints, hidden
        # files, ...) would otherwise yield bogus ids and fail later on open.
        # Sorting makes indices stable across machines and runs.
        self.ids = sorted(
            filename[:-len(raw_suffix)]
            for filename in os.listdir(images_dir)
            if filename.endswith(raw_suffix)
        )
        self.images_fps = [os.path.join(images_dir, f"{image_id}{raw_suffix}") for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, f"{image_id}{tag_suffix}") for image_id in self.ids]
        self.augmentation = augmentation
        self.raw_preprocessing = raw_preprocessing
        self.mask_preprocessing = mask_preprocessing
        # Copy list/tuple statistics so that editing one dataset's mean/std
        # can never leak into the shared default argument or another instance.
        self.mean = list(mean) if isinstance(mean, (list, tuple)) else mean
        self.std = list(std) if isinstance(std, (list, tuple)) else std

    def __getitem__(self, i):
        """Load, preprocess, augment and tensorize sample ``i``.

        Returns:
            ``(image, mask)``: the normalized image tensor and the mask
            tensor, both produced by ``transforms.ToTensor``.
        """
        # Read and preprocess the raw image
        image = Image.open(self.images_fps[i])
        if self.raw_preprocessing:
            image = self.raw_preprocessing(image)

        # Read and preprocess the mask image
        mask = Image.open(self.masks_fps[i])
        if self.mask_preprocessing:
            mask = self.mask_preprocessing(mask)

        # Augment image and mask together so geometric changes stay aligned;
        # from here on both are numpy arrays instead of PIL images
        if self.augmentation:
            sample = self.augmentation(image=np.array(image), mask=np.array(mask))
            image, mask = sample['image'], sample['mask']

        # Built per call so changes to self.mean / self.std take effect
        raw_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean,
                                std=self.std)
        ])
        image = raw_transform(image)
        # The mask is only converted to a tensor, never normalized
        mask_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        mask = mask_transform(mask)
        return image, mask 
    
    def __len__(self):
        """Number of raw/mask pairs found in ``images_dir``."""
        return len(self.ids)

In [None]:
def get_training_augmentation():
    """Build the full training augmentation pipeline.

    The pipeline is a horizontal flip (p=0.5) followed by three ``OneOf``
    groups, each applied with p=0.9: a colour/contrast group, a blur/noise
    group and a geometric-distortion group.
    """
    flip = albu.HorizontalFlip(p=0.5)

    # Colour / contrast changes
    colour = albu.OneOf(
        [
            albu.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
            albu.RandomGamma(gamma_limit=(80, 120), p=1),
            albu.CLAHE(p=1),
            albu.HueSaturationValue(p=1),
        ],
        p=0.9,
    )

    # Blur / sensor noise
    degrade = albu.OneOf(
        [
            albu.Blur(blur_limit=3, p=1),
            albu.MedianBlur(blur_limit=3, p=1),
            albu.GaussNoise(var_limit=(10.0, 50.0), p=1),
            albu.MotionBlur(blur_limit=3, p=1),
        ],
        p=0.9,
    )

    # Geometric distortions
    warp = albu.OneOf(
        [
            albu.GridDistortion(p=1),
            albu.OpticalDistortion(distort_limit=1, shift_limit=0.5, p=1),
            albu.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1),
        ],
        p=0.9,
    )

    return albu.Compose([flip, colour, degrade, warp])

# 1280 720 => 480 320 
def get_training_augmentation1():
    """Build the lighter training augmentation pipeline.

    Same layout as ``get_training_augmentation`` (horizontal flip plus three
    ``OneOf`` groups applied with p=0.9) but with MotionBlur,
    OpticalDistortion and ElasticTransform left out and different
    per-transform ``p`` values inside the groups.
    """
    flip = albu.HorizontalFlip(p=0.5)

    # Colour / contrast changes
    colour = albu.OneOf(
        [
            albu.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
            albu.RandomGamma(gamma_limit=(80, 120), p=1),
            albu.CLAHE(p=0.5),
            albu.HueSaturationValue(p=1),
        ],
        p=0.9,
    )

    # Blur / sensor noise (MotionBlur intentionally disabled)
    degrade = albu.OneOf(
        [
            albu.Blur(blur_limit=3, p=0.5),
            albu.MedianBlur(blur_limit=3, p=0.3),
            albu.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
        ],
        p=0.9,
    )

    # Geometric distortion (OpticalDistortion / ElasticTransform intentionally disabled)
    warp = albu.OneOf(
        [
            albu.GridDistortion(p=1),
        ],
        p=0.9,
    )

    return albu.Compose([flip, colour, degrade, warp])



In [None]:
trans1 = get_training_augmentation1()

# Settings shared by all four Blender corn datasets
blender_corn_kwargs = dict(
    mask_preprocessing=blender_corn_mask_preprocess,
    raw_preprocessing=blender_corn_raw_preprocess,
    augmentation=trans1,
)

# Blender corn dataset 1: sparse planting
blender_corn_dataset1 = Dataset(
    images_dir="/home/ps/leaf_seg/blender_dataset/blender_corn/raw",
    masks_dir="/home/ps/leaf_seg/blender_dataset/blender_corn/tag",
    **blender_corn_kwargs,
)

# Blender corn dataset 2: partial (close-up) views of a single plant
# NOTE(review): images_dir and masks_dir live under different roots here -- confirm this is intended
blender_corn_dataset2 = Dataset(
    images_dir="/home/ps/leaf_seg/blender_dataset/blender_single_corn/raw",
    masks_dir="/datadisk/yang/grap_dataset/blender_dataset/blender_single_corn/tag/",
    **blender_corn_kwargs,
)

# Blender corn dataset 3: whole single plant
blender_corn_dataset3 = Dataset(
    images_dir="/datadisk/yang/grap_dataset/blender_dataset/blender_single_corn_2/raw/",
    masks_dir="/datadisk/yang/grap_dataset/blender_dataset/blender_single_corn_2/tag/",
    **blender_corn_kwargs,
)

# Blender corn dataset 4: high-density planting
blender_corn_dataset4 = Dataset(
    images_dir="/datadisk/yang/grap_dataset/blender_dataset/blender_multi_corn/raw/",
    masks_dir="/datadisk/yang/grap_dataset/blender_dataset/blender_multi_corn/tag/",
    **blender_corn_kwargs,
)


In [None]:
# Look at the first sample of each Blender corn dataset.
# One loop replaces four copy-pasted stanzas, so every dataset now gets the
# same checks in the same order (previously dataset 2 skipped the value-range
# print and dataset 1 printed it after the figure).
to_pil_image = transforms.ToPILImage()
for sample_dataset in (
    blender_corn_dataset1,
    blender_corn_dataset2,
    blender_corn_dataset3,
    blender_corn_dataset4,
):
    image1, mask1 = sample_dataset[0]
    print (image1.shape, mask1.shape)
    # Distinct values present in the mask tensor
    uni_v = mask1.unique()
    print(uni_v)
    # Value range of the normalized image tensor
    print("333", torch.min(image1), torch.max(image1))
    mask1 = to_pil_image(mask1.squeeze().cpu())
    visualize(mask = mask1)


In [None]:
# Split each Blender dataset 80/20 into train and validation subsets.
SPLIT_SEED = 42


def split_train_val(dataset, train_fraction=0.8, seed=SPLIT_SEED):
    """Randomly split ``dataset`` into (train, validation) subsets.

    The split is driven by a generator seeded with ``seed`` so that
    re-running the notebook produces the same partition. Without this, a
    checkpoint saved in an earlier run could be validated on samples it was
    trained on.

    Args:
        dataset: any sized dataset accepted by ``random_split``.
        train_fraction: share of samples assigned to the training subset;
            the training length is ``int(train_fraction * len(dataset))``.
        seed: seed for the ``torch.Generator`` used by ``random_split``.

    Returns:
        ``(train_subset, val_subset)``.
    """
    train_len = int(train_fraction * len(dataset))
    val_len = len(dataset) - train_len
    generator = torch.Generator().manual_seed(seed)
    train_subset, val_subset = random_split(dataset, [train_len, val_len], generator=generator)
    return train_subset, val_subset


# NOTE(review): both subsets wrap the same Dataset object, so the training
# augmentation given to that Dataset is also applied to validation samples.
train_set1, val_set1 = split_train_val(blender_corn_dataset1)
print(len(train_set1), len(val_set1))

train_set2, val_set2 = split_train_val(blender_corn_dataset2)
print(len(train_set2), len(val_set2))

train_set3, val_set3 = split_train_val(blender_corn_dataset3)
print(len(train_set3), len(val_set3))

train_set4, val_set4 = split_train_val(blender_corn_dataset4)
print(len(train_set4), len(val_set4))


In [None]:
# Data loaders: one shuffled training loader and one ordered validation loader per dataset
train_loader_kwargs = dict(batch_size=12, shuffle=True, num_workers=12)
valid_loader_kwargs = dict(batch_size=2, shuffle=False, num_workers=4)

train_loader1 = DataLoader(train_set1, **train_loader_kwargs)
valid_loader1 = DataLoader(val_set1, **valid_loader_kwargs)
train_loader2 = DataLoader(train_set2, **train_loader_kwargs)
valid_loader2 = DataLoader(val_set2, **valid_loader_kwargs)
train_loader3 = DataLoader(train_set3, **train_loader_kwargs)
valid_loader3 = DataLoader(val_set3, **valid_loader_kwargs)
train_loader4 = DataLoader(train_set4, **train_loader_kwargs)
valid_loader4 = DataLoader(val_set4, **valid_loader_kwargs)

In [None]:
ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
# Use the GPU when one is visible; fall back to the CPU so the notebook still
# runs (slowly) on machines without CUDA instead of failing when the model is moved.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# create segmentation model with pretrained encoder
model = smp.UnetPlusPlus(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    activation=ACTIVATION,
)

# Encoder-specific input preprocessing function.
# NOTE(review): not referenced in the visible cells -- Dataset normalizes with
# its own mean/std instead. Confirm whether it is still needed.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
import segmentation_models_pytorch.utils.metrics

# Training objective: Dice loss
loss = smp.utils.losses.DiceLoss()
# Tracked metric: IoU with predictions thresholded at 0.5
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

# Adam with a single parameter group; train1Round overwrites this group's lr
optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': 0.00008}])

In [None]:
# Epoch runners: each one iterates once over the samples of a dataloader.
# Both share the same model, loss, metrics, device and verbosity; only the
# training runner additionally receives the optimizer.
epoch_runner_kwargs = dict(
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **epoch_runner_kwargs)

valid_epoch = smp.utils.train.ValidEpoch(model, **epoch_runner_kwargs)

In [None]:
# Module-level training defaults.
# NOTE: train1Round tracks its best score in a local variable of its own, so
# this global is not updated by the training rounds.
max_score = 0

# Number of epochs used for rounds "1" and "2" in the training cell
round1_epoch = 15

def train1Round(round, epoch, train_loader, valid_loader, lr = 0.00008, lr_decay_epoch = 15, decayed_lr = 1e-5):
    """Train the global ``model`` for one round and checkpoint its best epoch.

    Uses the module-level ``optimizer``, ``train_epoch``, ``valid_epoch`` and
    ``model``. After every epoch the validation ``iou_score`` is compared with
    the best value seen in this round; on improvement the whole model object
    is saved to ``./round_<round>_best_model.pth``.

    Args:
        round: label of the round, used in the log line and the checkpoint
            file name. Any value that can be formatted into a string works
            (previously an integer raised TypeError). The name shadows the
            built-in ``round`` inside this function; it is kept for callers.
        epoch: number of epochs to run.
        train_loader: dataloader passed to ``train_epoch.run``.
        valid_loader: dataloader passed to ``valid_epoch.run``.
        lr: learning rate written to the optimizer's first parameter group
            at the start of the round.
        lr_decay_epoch: zero-based epoch index after which the learning rate
            is lowered. With the default of 15 this only triggers when
            ``epoch`` is at least 16.
        decayed_lr: learning rate applied once ``lr_decay_epoch`` is reached.

    Returns:
        The best validation ``iou_score`` reached in this round (0 if no
        epoch improved on 0).
    """
    best_score = 0
    checkpoint_path = f'./round_{round}_best_model.pth'
    print("training round" ,round)
    optimizer.param_groups[0]['lr'] = lr
    for i in range(0, epoch):
        print('\nEpoch: {}'.format(i))
        train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)
        # Checkpoint whenever validation IoU improves within this round
        if best_score < valid_logs['iou_score']:
            best_score = valid_logs['iou_score']
            torch.save(model, checkpoint_path)
            print('This is the best ever model ?????')
        if i == lr_decay_epoch:
            optimizer.param_groups[0]['lr'] = decayed_lr
            print('Decrease learning rate to {}!'.format(decayed_lr))
    return best_score

In [None]:
# Ask PyTorch to release cached GPU memory it is not currently using
torch.cuda.empty_cache()

In [None]:
# Save the whole model object before any training round below has run
# (presumably kept as an untrained baseline checkpoint)
torch.save(model, 'never_train.pth')

In [None]:
# Blender "rand" dataset: build it, inspect one sample, split 80/20, create loaders
dataset_rand = Dataset(
    images_dir="/datadisk/yang/grap_dataset/blender_rand/raw/",
    masks_dir="/datadisk/yang/grap_dataset/blender_rand/tag/",
    mask_preprocessing=blender_rand_mask_preprocess,
    raw_preprocessing=blender_rand_raw_preprocess,
    augmentation=trans1,
)
train_len = int(0.8 * len(dataset_rand))
val_len = len(dataset_rand) - train_len

# Display the first sample: tensor shapes, distinct mask values, and the mask itself
to_pil_image = transforms.ToPILImage()
image1, mask1  = dataset_rand[0]
print (image1.shape, mask1.shape)
uni_v = mask1.unique()
print(uni_v)
mask1 = to_pil_image(mask1.squeeze().cpu())
visualize(mask = mask1)

# Split with a seeded generator so the train/validation partition is the same
# on every run; otherwise a checkpoint saved earlier could be validated on
# samples it was trained on.
train_rand, val_rand = random_split(
    dataset_rand,
    [train_len, val_len],
    generator=torch.Generator().manual_seed(42),
)
print(len(train_rand), len(val_rand))

# Data loaders
trainloader_rand = DataLoader(train_rand, batch_size=12, shuffle=True, num_workers=12)
validloader_rand = DataLoader(val_rand, batch_size=2, shuffle=False, num_workers=4)

#model = torch.load('')
#train1Round("4", 10, train_loader= train_loader4, valid_loader= valid_loader4, lr = 0.00002)

In [None]:
#train1Round("no_train", 6, train_loader=train_loader1, valid_loader=valid_loader1)


In [None]:
# Training schedule: (round label, epochs, train loader, validation loader),
# executed in this order on the same model
training_schedule = [
    ("1", round1_epoch, train_loader2, valid_loader2),
    ("2", round1_epoch, train_loader3, valid_loader3),
    ("3", 15, train_loader1, valid_loader1),
    ("4", 15, train_loader4, valid_loader4),
    ("5", 15, trainloader_rand, validloader_rand),
]
for round_label, round_epochs, round_train_loader, round_valid_loader in training_schedule:
    train1Round(round_label, round_epochs, train_loader=round_train_loader, valid_loader=round_valid_loader)