# Model training for WSCNet2

In [None]:
import os
import cv2
import time
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from dropDataset import DropDataset, WSCDataset
from dropNets import DropCondNet, WSCNet, WSCLoss

In [None]:
# Parameters
# Path to the top directory of the data: the source images, their `textResult`
# annotation files and the generated `drops` crops all live under it.
# e.g. target_dir = 'D:/data/droplets'
target_dir = 'G:/images/dropDatasetDrops/c3ra43624a_4/train'

# Name of the network to be used; must be one of net_name_list.
net_name_list = ['Resnet18', 'Resnet50', 'LeNet', 'MobileNet', 'WSCNet']
net_name = 'WSCNet'

# Path to a pretrained model (state_dict), '' if not using a pretrained model.
pretrained_model_path = ''

batch_size = 64  # Batch size for training

# Raise instead of assert: assert statements are stripped when Python runs with -O.
if net_name not in net_name_list:
    raise ValueError('net_name must be chosen from ["Resnet18", "Resnet50", "LeNet", "MobileNet", "WSCNet"]')

## 1. Droplet segmentation

Use the modified detection results (`<target_dir>/textResult/<image name>_modified.txt`) to crop one image per droplet. The crops are saved in `<target_dir>/drops/<class>`.
These images are loaded for model training in the next step.

In [None]:
# Create one output directory per class (0-3) under <target_dir>/drops and
# clear files left over from a previous run, so stale crops are not mixed
# into the new training set.
for class_idx in range(4):
    class_dir = os.path.join(target_dir, 'drops', str(class_idx))
    os.makedirs(class_dir, exist_ok=True)
    for file in os.listdir(class_dir):
        file_path = os.path.join(class_dir, file)
        if os.path.isfile(file_path):  # os.remove cannot delete directories
            os.remove(file_path)

# Search image files in target_dir (extension match is case-insensitive)
img_ext_list = ['.jpg', '.png', '.bmp']
img_name_list = sorted(
    file for file in os.listdir(target_dir)
    if os.path.splitext(file)[1].lower() in img_ext_list
)
print('Number of images: ', len(img_name_list))

# Load modified results and save one cropped image per droplet
drop_num_list = [0, 0, 0, 0]
for img_name in img_name_list:
    img_stem, img_ext = os.path.splitext(img_name)
    modified_text_path = os.path.join(target_dir, 'textResult', img_stem + '_modified.txt')
    if not os.path.exists(modified_text_path):
        continue  # no modified results for this image

    src_img = cv2.imread(os.path.join(target_dir, img_name))
    if src_img is None:  # cv2.imread returns None instead of raising
        print('Failed to read image, skipped: ', img_name)
        continue

    # The first four columns are unpacked as x, y, r, class below.
    # ndmin=2 keeps a single-line file 2-D so the column slice still works;
    # np.bytes_ is used because np.string_ was removed in NumPy 2.0.
    modified_results = np.loadtxt(modified_text_path, delimiter='\t', dtype=np.bytes_, ndmin=2)[:, :4].astype(np.float32)
    for drop_idx, (x, y, r, _class) in enumerate(modified_results):
        # NOTE(review): the +1 shift suggests stored classes start at -1 — confirm
        _class = int(_class) + 1

        # Bounding square of the circle, clipped to the image borders
        x1 = max(int(x - r), 0)
        x2 = min(int(x + r), src_img.shape[1])
        y1 = max(int(y - r), 0)
        y2 = min(int(y + r), src_img.shape[0])

        drop_img = src_img[y1:y2, x1:x2]
        if drop_img.size == 0:  # circle lies outside the image; cv2.imwrite rejects empty images
            continue
        # drop_idx makes every crop's filename unique, so droplets of the same
        # image and class do not overwrite each other.
        drop_name = img_stem + '_drops_' + str(drop_idx) + img_ext
        cv2.imwrite(os.path.join(target_dir, 'drops', str(_class), drop_name), drop_img)
        drop_num_list[_class] += 1
print('Number of droplets in each class: ', drop_num_list)

## 2. Training preparation

In [None]:
# Collect (file_path, class_index) samples from the per-class crop
# directories <target_dir>/drops/<class>; the directory name is the label.
data = []
for class_idx in range(4):
    class_dir = os.path.join(target_dir, 'drops', str(class_idx))
    data.extend((os.path.join(class_dir, file), class_idx) for file in sorted(os.listdir(class_dir)))
if not data:
    raise RuntimeError('No droplet images found in ' + os.path.join(target_dir, 'drops'))

# Random train/validation split
np.random.shuffle(data)
valid_rate = 0.2
split_idx = int(len(data) * (1 - valid_rate))
train_data = data[:split_idx]
valid_data = data[split_idx:]

train_mode = 'WSCNet' if net_name == 'WSCNet' else 'train'
train_dataset = DropDataset(train_data, mode=train_mode)
valid_dataset = DropDataset(valid_data, mode='valid')
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
print('Number of training droplets: ', len(train_dataset))
print('Number of validation droplets: ', len(valid_dataset))

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device: ', device)

net = (WSCNet() if net_name == 'WSCNet' else DropCondNet(net_name)).to(device)

if pretrained_model_path:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    net.load_state_dict(torch.load(pretrained_model_path, map_location=device))
    print('Pretrained model: ', pretrained_model_path)

if net_name == 'WSCNet':
    criterion = WSCLoss(device)
else:  # net_name in ['Resnet18', 'Resnet50', 'LeNet', 'MobileNet']
    criterion = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

## 3. Model training

In [None]:
writer = SummaryWriter('runs/' + net_name + time.strftime('_%m%d%H%M', time.localtime(time.time())))


def _compute_loss(outputs, labels):
    """Apply the criterion; WSCNet's outputs are unpacked into WSCLoss as separate arguments."""
    if net_name == 'WSCNet':
        return criterion(*outputs, labels)
    return criterion(outputs, labels)


print('Start training...')
max_epoch = 1000
patience = 50  # epochs without validation improvement before early stop
min_valid_loss = float('inf')
es_num = 0  # early-stop counter: epochs since the last improvement
model_save_path = os.path.join(target_dir, 'Drop_' + net_name + '.pt')
# Guard against division by zero when a split is empty
train_size = max(len(train_dataset), 1)
valid_size = max(len(valid_dataset), 1)

for epoch in range(max_epoch):
    ## Train
    running_loss = 0.0
    net.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = _compute_loss(net(inputs), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    train_loss = 1000 * running_loss / train_size
    print('[%d] train loss: %.3f' % (epoch + 1, train_loss))
    writer.add_scalar('train_loss', train_loss, epoch)

    ## Valid
    running_loss = 0.0
    net.eval()
    correct_num = [0, 0, 0, 0]
    total_num = [0, 0, 0, 0]
    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            running_loss += _compute_loss(outputs, labels).item()
            if net_name != 'WSCNet':
                # NOTE(review): labels are assumed one-hot (argmax over dim 1) — confirm against DropDataset
                pred_labels = torch.argmax(outputs, dim=1).cpu().numpy()
                true_labels = torch.argmax(labels, dim=1).cpu().numpy()
                for pred, true in zip(pred_labels, true_labels):
                    total_num[true] += 1
                    if pred == true:
                        correct_num[true] += 1

    valid_loss = 1000 * running_loss / valid_size
    print('[%d] valid loss : %.3f' % (epoch + 1, valid_loss))
    writer.add_scalar('valid_loss', valid_loss, epoch)
    if net_name != 'WSCNet' and sum(total_num) > 0:
        valid_acc = 100 * sum(correct_num) / sum(total_num)
        print('[%d] valid acc : %.3f' % (epoch + 1, valid_acc))
        writer.add_scalar('valid_acc_total', valid_acc, epoch)
        for class_idx in range(4):
            if total_num[class_idx] > 0:
                writer.add_scalar('valid_acc_' + str(class_idx), 100 * correct_num[class_idx] / total_num[class_idx], epoch)
    print('----------------------------------------------')

    ## Save best model (selection uses the summed validation loss)
    if running_loss < min_valid_loss:
        min_valid_loss = running_loss
        es_num = 0
        torch.save(net.state_dict(), model_save_path)
    else:
        es_num += 1

    ## Early stop
    if es_num >= patience:
        print('early stop')
        break

writer.close()  # flush pending TensorBoard events to disk

## 4. Generate a traced model for use in the software
If you want to run inference with the trained model in WSCNet2.exe, don't skip this step: WSCNet2.exe can only load traced models.

In [None]:
# Reload the best checkpoint, using the same path the training loop saves to
# ('Drop_' + net_name + '.pt').
model_save_path = os.path.join(target_dir, 'Drop_' + net_name + '.pt')
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
net.load_state_dict(torch.load(model_save_path, map_location=device))
net.eval()

# torch.jit.trace records the operations executed on this example input.
# NOTE(review): assumes the network takes 3-channel 32x32 crops — confirm against DropDataset.
example_input = torch.randn(1, 3, 32, 32).to(device)
traced_script_module = torch.jit.trace(net, example_input)
traced_script_module.save(os.path.join(target_dir, 'traced_' + net_name + '.pth'))