# Attention U-shaped Pyramid Segmentation Network

## import packages

In [9]:
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import os
import torch.nn.functional as F
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms

import numpy as np
import glob

from data_loader import Rescale
from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET
from model import U2NETP

from model import UPSPNet

import matplotlib
matplotlib.use('AGG')
import matplotlib.pyplot as plt

import eval

## set parameters

In [10]:
# --- training schedule ---
epoch_num = 500          # total number of passes over the training set
batch_size_train = 36    # images per training batch
batch_size_val = 1       # validation batch size (not used by the cells shown here)
save_epoch = 5           # write a checkpoint every `save_epoch` epochs

# --- counters / accumulators updated by the training loop ---
train_num, val_num = 0, 0
ite_num = 0              # global iteration count across all epochs
ite_num4val = 0          # iterations since the last checkpoint reset
running_loss = 0.0       # summed fused loss since the last reset
running_tar_loss = 0.0   # summed d0 loss since the last reset
Loss_list = []           # average running loss appended at each checkpoint

## Define the loss function

In [11]:
# Binary cross-entropy averaged over all elements. `reduction='mean'` is the
# current spelling of the deprecated `size_average=True`; the value is identical.
bce_loss = nn.BCELoss(reduction='mean')


def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    """Deep-supervision loss: sum of the BCE of seven outputs against one label.

    Each of ``d0`` .. ``d6`` is scored against the same ``labels_v`` with
    ``bce_loss``. ``nn.BCELoss`` does not apply a sigmoid, so every output must
    already be a probability map in [0, 1] with the same shape as ``labels_v``.

    Prints the seven individual losses as a side effect.

    Returns:
        (loss0, loss): the BCE of ``d0`` alone (tracked by the training loop as
        the "target" loss) and the sum of all seven BCE terms (back-propagated).
    """
    side_losses = [bce_loss(d, labels_v) for d in (d0, d1, d2, d3, d4, d5, d6)]

    # Summed left to right, same order as loss0 + loss1 + ... + loss6.
    loss = side_losses[0]
    for term in side_losses[1:]:
        loss = loss + term

    # .item() yields plain Python floats for logging (works for CPU and CUDA).
    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"
          % tuple(term.item() for term in side_losses))

    return side_losses[0], loss

## Set the training dataset directories

In [12]:
model_name = 'upspnet'  # alternative: 'u2netp'

# Training data root. This is an absolute path, so it does not depend on the
# current working directory.
data_dir = '/mnt/DATA_512/Train' + os.sep
tra_image_dir = 'src' + os.sep   # images: <data_dir>src/*<image_ext>
tra_label_dir = 'gt' + os.sep    # masks:  <data_dir>gt/<stem><label_ext>

image_ext = '.jpg'
label_ext = '.png'

# Checkpoints are written here by the training loop. torch.save does not create
# missing parent directories, so create the folder up front.
model_dir = os.path.join(os.getcwd(), 'saved_models', "auspnet_512" + os.sep)
os.makedirs(model_dir, exist_ok=True)

In [13]:
# Pair every training image <data_dir>src/<stem><image_ext> with its mask
# <data_dir>gt/<stem><label_ext>. sorted() makes the order reproducible (glob
# order depends on the filesystem); the DataLoader shuffles anyway.
tra_img_name_list = sorted(glob.glob(data_dir + tra_image_dir + '*' + image_ext))

if not tra_img_name_list:
    # Fail early with a clear message instead of an obscure sampler error below.
    raise FileNotFoundError("no '*%s' training images found in %s"
                            % (image_ext, data_dir + tra_image_dir))

# splitext strips only the final extension, so "a.b.jpg" maps to "a.b.png".
tra_lbl_name_list = [
    data_dir + tra_label_dir + os.path.splitext(os.path.basename(img_path))[0] + label_ext
    for img_path in tra_img_name_list
]

print("---")
print("train images: ", len(tra_img_name_list))
print("train labels: ", len(tra_lbl_name_list))
missing_lbl = [p for p in tra_lbl_name_list if not os.path.isfile(p)]
if missing_lbl:
    print("WARNING: %d label files do not exist, e.g. %s" % (len(missing_lbl), missing_lbl[0]))
print("---")

train_num = len(tra_img_name_list)

salobj_dataset = SalObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([
        RescaleT(320),
        RandomCrop(288),
        ToTensorLab(flag=0)]))
salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)

---
train images:  4800
train labels:  4800
---


## define model

In [14]:
# presumably (in_ch=3, out_ch=1) -- confirm against model/UPSPNet.
net = UPSPNet.UPSPNET_RSU(3, 1)

# Single-device training; wrap with nn.DataParallel(net) here for multi-GPU.
if torch.cuda.is_available():
    net = net.cuda()  # Module.cuda() moves parameters in place and returns self

## define optimizer

In [15]:
# Adam over all network parameters; these values match torch.optim.Adam's defaults.
optimizer = optim.Adam(
    net.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)

## training process

In [None]:
print("---start training...")

for epoch in range(0, epoch_num):
    net.train()

    for i, data in enumerate(salobj_dataloader):
        ite_num = ite_num + 1
        ite_num4val = ite_num4val + 1

        # Cast to float32 and move to the GPU when available. Plain tensors
        # replace the deprecated autograd.Variable wrapper; inputs and labels
        # never require gradients.
        inputs = data['image'].type(torch.FloatTensor)
        labels = data['label'].type(torch.FloatTensor)
        if torch.cuda.is_available():
            inputs_v, labels_v = inputs.cuda(), labels.cuda()
        else:
            inputs_v, labels_v = inputs, labels

        optimizer.zero_grad()

        # forward + backward + optimize (the net returns seven outputs)
        d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
        loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)

        loss.backward()
        optimizer.step()

        # Accumulate Python floats, not tensors. The running averages, Loss_list
        # and checkpoint names then hold plain numbers that neither pin GPU
        # memory nor break matplotlib.
        running_loss += loss.item()
        running_tar_loss += loss2.item()

        # Release every output of this batch before the next forward pass.
        del d0, d1, d2, d3, d4, d5, d6, loss2, loss

        # The last batch may be partial; don't report more samples than exist.
        seen = min((i + 1) * batch_size_train, train_num)
        print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
            epoch + 1, epoch_num, seen, train_num, ite_num,
            running_loss / ite_num4val, running_tar_loss / ite_num4val))

    if (epoch + 1) % save_epoch == 0:
        # torch.save does not create missing directories.
        os.makedirs(model_dir, exist_ok=True)
        torch.save(net.state_dict(), model_dir + model_name + "_bce_itr_%d_train_%3f_tar_%3f.pth" % (
            ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
        Loss_list.append(running_loss / ite_num4val)
        # Start a fresh averaging window for the next checkpoint.
        running_loss = 0.0
        running_tar_loss = 0.0
        ite_num4val = 0

---start training...




l0: 0.671697, l1: 0.790183, l2: 0.687988, l3: 0.734891, l4: 0.687826, l5: 0.736316, l6: 0.711290

[epoch:   1/500, batch:    36/ 4800, ite: 1] train loss: 5.020191, tar: 0.671697 
l0: 0.765791, l1: 0.683958, l2: 0.703052, l3: 0.697804, l4: 0.917278, l5: 1.347346, l6: 1.187292

[epoch:   1/500, batch:    72/ 4800, ite: 2] train loss: 5.661357, tar: 0.718744 
l0: 0.629138, l1: 0.706826, l2: 0.706599, l3: 0.740176, l4: 0.762037, l5: 1.149707, l6: 0.662870

[epoch:   1/500, batch:   108/ 4800, ite: 3] train loss: 5.560022, tar: 0.688876 
l0: 0.721754, l1: 0.671377, l2: 0.670369, l3: 0.688932, l4: 0.903362, l5: 1.065774, l6: 0.776250

[epoch:   1/500, batch:   144/ 4800, ite: 4] train loss: 5.544471, tar: 0.697095 
l0: 0.716175, l1: 0.624220, l2: 0.639525, l3: 0.654826, l4: 0.695071, l5: 0.704586, l6: 0.879921

[epoch:   1/500, batch:   180/ 4800, ite: 5] train loss: 5.418441, tar: 0.700911 
l0: 0.660246, l1: 0.619649, l2: 0.637464, l3: 0.640643, l4: 0.671824, l5: 0.731769, l6: 0.647085

[e

## plot loss

In [None]:
# Plot the average training loss recorded at each checkpoint (one point every
# `save_epoch` epochs). float() accepts Python numbers and 0-dim tensors
# (including CUDA ones), which matplotlib cannot convert directly.
y = [float(v) for v in Loss_list]
x = range(len(y))
plt.plot(x, y, '.-')
plt.xlabel('checkpoint index (every %d epochs)' % save_epoch)
plt.ylabel('average train loss')
os.makedirs("loss", exist_ok=True)  # savefig does not create directories
plt.savefig("loss/loss_{}.png".format(str(epoch + 1)))
plt.close()  # re-running the cell would otherwise draw onto the same axes