# Model Training

In this notebook, we will train a deeplabv3 semantic segmentation model based on the resnet50 network structure, using the previous dataset and labeling results for training.

#1 Load Library and data preprocessing

In [1]:
from google.colab import drive
drive.mount('/content/drive')

repo = "/content/drive/MyDrive/Dissertation/MyWork/Dataset/" #will be useful to loop over the files later on.

Mounted at /content/drive


In [2]:
import matplotlib.pyplot as plt
import os
import numpy as np
import cv2
import torchvision.models.segmentation
import torch
import torchvision.transforms as tf
from torchvision import transforms, datasets, models
from PIL import Image
import os
import shutil


Image preprocessing and training/testing dataset division

In [None]:
# # define file path
# input_folder = repo + 'GT1024'
# output_folder = repo + 'GT'

# for filename in os.listdir(input_folder):
#     if filename.endswith('.png'):
#         input_path = os.path.join(input_folder, filename)
#         output_path = os.path.join(output_folder, filename)

#         # resize to 64x64
#         img = Image.open(input_path)
#         img = img.resize((64, 64), Image.ANTIALIAS)

#         # save the image
#         img.save(output_path)
#         print( filename + 'done。')
# print('all image is done。')


In [None]:
# # check the number of image file
# lll = "/content/drive/MyDrive/Dissertation/MyWork/Dataset/GT"
# file_count = len(os.listdir(lll))

# print(f'file number is: {file_count} ')


Split 1000 pictures as a test set

In [None]:
# for i in range(9001, 10001):
#     name1 = output_folder + "/Sample_" + str(i) + "_gt.png"
#     name2 = repo + "Blood_Cancer/Sample_" + str(i) + ".tiff"
#     path1 = repo + "TestGT/Sample_" + str(i) + "_gt.png"
#     path2 = repo + "TestBC/Sample_" + str(i) + ".tiff"

#     shutil.move(name1, path1)
#     shutil.move(name2, path2)
#     print(str(i) + ' done。')

# print('data set split is done')

# 2 Train the model

Define basic parameters

In [3]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("using {} device.".format(device))

using cuda device.


In [4]:
Learning_Rate=1e-5
width=height=800 # image width and height
batchSize=3

In [5]:
train_folder= repo + "Blood_Cancer"
gt_folder = repo + "GT"
list_img = os.listdir(train_folder)
list_gt = os.listdir(gt_folder)

In [26]:
transformImg = tf.Compose([tf.ToPILImage(),tf.Resize((height,width)), tf.ToTensor(),tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
transformAnn = tf.Compose([tf.ToPILImage(),tf.Resize((height,width)), tf.ToTensor()])

Define some functions to assist training

1, ReadRandomImage and LoadBatch: Randomly select a few pictures to load the pre-trained model, and save the model locally after a few simple training steps

2, ReadNextImage and LoadBatchAll: Load all data set for fully train the model

In [7]:
def ReadRandomImage():
    idx = np.random.randint(0,len(list_img)) # Pick random image
    img = cv2.imread(os.path.join(train_folder,list_img[idx]))
    img_GT = cv2.imread(os.path.join(gt_folder,list_img[idx].replace(".tiff","_gt.png")),0)

    ann_map = np.zeros(img_GT.shape[0:2],np.float32) # Segmentation map
    ann_map[ img_GT == 255 ] = 1

    img=transformImg(img)
    ann_map=transformAnn(ann_map)

    return img, ann_map

In [8]:
def ReadNextImage(idx):
    img = cv2.imread(os.path.join(train_folder,list_img[idx]))
    img_GT = cv2.imread(os.path.join(gt_folder,list_img[idx].replace(".tiff","_gt.png")),0)

    ann_map = np.zeros(img_GT.shape[0:2],np.float32) # Segmentation map
    ann_map[ img_GT == 255 ] = 1

    img=transformImg(img)
    ann_map=transformAnn(ann_map)

    return img, ann_map

In [9]:
def LoadBatch(): # Load batch of images
    images = torch.zeros([batchSize,3,height,width])
    ann = torch.zeros([batchSize, height, width])

    for i in range(batchSize):
        images[i],ann[i]=ReadRandomImage()

    return images, ann

In [10]:
def LoadBatchAll(idx): # Load batch of images
    images = torch.zeros([batchSize,3,height,width])
    ann = torch.zeros([batchSize, height, width])

    for i in range(batchSize):
        images[i],ann[i]=ReadNextImage(idx)
        idx = idx + 1

    return images, ann

Randomly train, read the pre-trained weights

In [25]:
Net = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)

Net.classifier[4] = torch.nn.Conv2d(256, 2, kernel_size=(1, 1), stride=(1, 1)) # Change final layer to 3 classes
Net=Net.to(device)
optimizer=torch.optim.Adam(params=Net.parameters(),
                           lr=Learning_Rate)
                          #  ,weight_decay=Weight_Decay) # Create adam optimizer
criterion = torch.nn.CrossEntropyLoss() # Set loss function

In [13]:
Best_loss = 1
Best_loss = [Best_loss]
model_path = repo + "weight.torch"
Net = Net.to(device)

for itr in range(5): # Training loop
    images,ann=LoadBatch() # Load taining batch
    images=torch.autograd.Variable(images,requires_grad=False).to(device) # Load image
    ann = torch.autograd.Variable(ann, requires_grad=False).to(device) # Load annotation
    Pred=Net(images)['out'] # make prediction
    Net.zero_grad()

    Loss=criterion(Pred,ann.long()) # Calculate cross entropy loss
    Loss.backward() # Backpropogate loss
    optimizer.step() # Apply gradient descent change to weight
    seg = torch.argmax(Pred[0], 0).cpu().detach().numpy()  # Get  prediction classes
    print(itr,") Loss=",Loss.data.cpu().numpy())

    if Loss.data.cpu().numpy() < Best_loss[0]:
        Best_loss[0] = Loss.data.cpu().numpy()
        np.savetxt(repo + 'Best_loss.txt',Best_loss,fmt = '%f')
        print("best loss is saved:" + str(Best_loss[0]))
        torch.save(Net.state_dict(), model_path)
        print("model is saved:")

0 ) Loss= 0.69061303
best loss is saved:0.69061303
model is saved:
1 ) Loss= 0.6734091
best loss is saved:0.6734091
model is saved:
2 ) Loss= 0.6780588
3 ) Loss= 0.65889704
best loss is saved:0.65889704
model is saved:
4 ) Loss= 0.673522


Read all data for training

In [23]:
Best_loss = np.loadtxt(repo + 'Best_loss.txt').tolist()
Best_loss = [Best_loss]
Learning_Rate=1e-6
model_path = repo + "weight.torch"
Net.load_state_dict(torch.load(model_path))
Net = Net.to(device)
Best_loss
width=height=800 # image width and height

In [None]:
width=height=800 # image width and height
for itr in range(1): # Training loop
    start = 1
    end = 9000
    step = batchSize
    for i in range(start, end + 1, step):
        if i+2 <= end:
            images,ann=LoadBatchAll(i) # Load taining batch
            images=torch.autograd.Variable(images,requires_grad=False).to(device) # Load image
            ann = torch.autograd.Variable(ann, requires_grad=False).to(device) # Load annotation
            Pred=Net(images)['out'] # make prediction
            Net.zero_grad()

            Loss=criterion(Pred,ann.long()) # Calculate cross entropy loss
            Loss.backward() # Backpropogate loss
            optimizer.step() # Apply gradient descent change to weight
            seg = torch.argmax(Pred[0], 0).cpu().detach().numpy()  # Get  prediction classes
            print(itr,") " + str(i) + " Loss=",Loss.data.cpu().numpy())

            if Loss.data.cpu().numpy() < Best_loss[0]:
                Best_loss[0] = Loss.data.cpu().numpy()
                np.savetxt(repo + 'Best_loss.txt',Best_loss,fmt = '%f')
                print("best loss is saved:" + str(Best_loss[0]))
                torch.save(Net.state_dict(), model_path)
                print("model is saved:")