In [2]:
import os
import time
import copy
import torch
import random
import pandas as pd
from skimage import io, transform
import numpy as np
from PIL import Image
import albumentations as A
import matplotlib.pyplot as plt
from IPython.display import clear_output
import bitsandbytes as bnb
import torchvision.transforms.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models
from torchvision.utils import save_image
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm

from utils.loss_function import SaliencyLoss
from utils.loss_function import AUC
from utils.data_process import MyDataset, MyTransform
from utils.data_process import MyDatasetCNNMerge, MyTransformCNNMerge
from utils.data_process import preprocess_img, postprocess_img
from utils.data_process import compute_metric, compute_metric_CNNMerge
from utils.data_process import count_parameters

from model.Merge_CNN_model import CNNMerge

# Architecture selector:
#   0 -> TranSalNet_Dense, 1 -> TranSalNet_Res, 2 -> TranSalNet_ViT,
#   3 -> TranSalNet_ViT_multidecoder (model returns two saliency maps).
# Later cells branch on `flag == 3` to handle the two-output model.
flag = 3

if flag == 0:
    from model.TranSalNet_Dense import TranSalNet
elif flag == 1:
    from model.TranSalNet_Res import TranSalNet
elif flag == 2:
    from model.TranSalNet_ViT import TranSalNet
elif flag == 3:
    from model.TranSalNet_ViT_multidecoder import TranSalNet
else:
    # Fail fast: otherwise `TranSalNet` is silently left undefined and the
    # error only surfaces later as a NameError when the model is built.
    raise ValueError(f'Unknown flag {flag!r}; expected 0, 1, 2 or 3')

In [3]:
# Dataset locations, relative to the notebook's working directory.
path_images_train = './datasets/train/train_images/'
path_images_val = './datasets/val/val_images/'
path_images_test = './datasets/test/test_images_salicon/'

path_maps_train = './datasets/train/train_maps/'
path_maps_val = './datasets/val/val_maps/'

path_train_ids = './datasets/train_ids_SALICON_CAT.csv'
path_val_ids = './datasets/val_ids_SALICON_CAT.csv'

# Sanity check: file count per directory, printed in the order
# train images, test images, val images, train maps, val maps.
for directory in (path_images_train, path_images_test, path_images_val,
                  path_maps_train, path_maps_val):
    print(len(os.listdir(directory)))

In [None]:
# Tables of sample ids per split; print the row at position 1 of each as a spot check.
train_ids = pd.read_csv(path_train_ids)
val_ids = pd.read_csv(path_val_ids)
for ids_table in (train_ids, val_ids):
    print(ids_table.iloc[1])

# Number of samples per split; the training loop divides the summed loss by these.
dataset_sizes = dict(train=len(train_ids), val=len(val_ids))
print(dataset_sizes)

In [6]:
# Batch size and network input resolution (rows x cols).
batch_size = 4
shape_r, shape_c = 288, 384

# Probability handed to MyTransform; presumably only used for augmentation
# when iftrain=True — TODO confirm in utils.data_process.
p = 0.5
train_transform, val_transform = (
    MyTransform(p=p, shape_r=shape_r, shape_c=shape_c, iftrain=is_train)
    for is_train in (True, False)
)

train_set = MyDataset(ids=train_ids, stimuli_dir=path_images_train,
                      saliency_dir=path_maps_train, transform=train_transform)
val_set = MyDataset(ids=val_ids, stimuli_dir=path_images_val,
                    saliency_dir=path_maps_val, transform=val_transform)

# Only the training split is shuffled; validation order stays fixed.
dataloaders = {
    split: DataLoader(dataset, batch_size=batch_size,
                      shuffle=(split == 'train'), num_workers=2)
    for split, dataset in (('train', train_set), ('val', val_set))
}

In [None]:
# Run on the first CUDA device when one is visible, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# `TranSalNet` is whichever variant the `flag` selector imported.
model = TranSalNet().to(device)
n_params = count_parameters(model)
print(f'The model has {n_params:,} trainable parameters')

# Train the model

In [9]:
# Per-batch curves appended to by the training loop. For the two-decoder
# model (flag == 3) each metric entry is a [decoder_1, decoder_2] pair,
# otherwise a single float. The *_auc lists are currently never appended to.
history_loss_train, history_loss_val = [], []

(history_loss_train_cc, history_loss_train_sim, history_loss_train_kldiv,
 history_loss_train_nss, history_loss_train_auc) = ([] for _ in range(5))

(history_loss_val_cc, history_loss_val_sim, history_loss_val_kldiv,
 history_loss_val_nss, history_loss_val_auc) = ([] for _ in range(5))

In [None]:
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-5)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

loss_fn = SaliencyLoss()

# ---- Training ---------------------------------------------------------------
best_model_wts = copy.deepcopy(model.state_dict())
num_epochs = 30
patience = 5                # val epochs without improvement before early stop
best_loss = float('inf')    # any finite validation loss counts as an improvement
counter = 0                 # must exist before the first non-improving val epoch
stop_training = False
path_to_save = 'path to save'
checkpoint_prefix = 'TranSalNet_ViT_multidecoder_'

# Per-phase views onto the module-level history lists, so the logging code
# below is written once instead of once per (phase, flag) combination.
metric_names = ('cc', 'sim', 'kldiv', 'nss')
history = {
    'train': {'loss': history_loss_train, 'cc': history_loss_train_cc,
              'sim': history_loss_train_sim, 'kldiv': history_loss_train_kldiv,
              'nss': history_loss_train_nss},
    'val': {'loss': history_loss_val, 'cc': history_loss_val_cc,
            'sim': history_loss_val_sim, 'kldiv': history_loss_val_kldiv,
            'nss': history_loss_val_nss},
}


def saliency_objective(pred, target):
    """Training objective for one prediction: -2*CC - SIM + 10*KL (lower is better)."""
    loss = -2 * loss_fn(pred, target, loss_type='cc')
    loss = loss - loss_fn(pred, target, loss_type='sim')
    return loss + 10 * loss_fn(pred, target, loss_type='kldiv')


def checkpoint_path(epoch_idx):
    """File the weights are saved to after 0-based epoch `epoch_idx`."""
    return path_to_save + '/' + checkpoint_prefix + str(epoch_idx) + '.pth'


for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch + 1, num_epochs))
    print('-' * 10)

    # Each epoch has a training and a validation phase.
    for phase in ('train', 'val'):
        is_train = phase == 'train'
        model.train(is_train)   # train() for 'train', eval() for 'val'

        running_loss = 0.0

        for sample_batched in tqdm(dataloaders[phase], desc=phase):
            stimuli = sample_batched['image'].type(torch.float32).to(device)
            smap = sample_batched['saliency'].type(torch.float32).to(device)

            optimizer.zero_grad()

            # Track autograd history only while training.
            with torch.set_grad_enabled(is_train):
                if flag == 3:
                    # Multi-decoder variant: both heads are supervised with the same target.
                    outputs_1, outputs_2 = model(stimuli)
                    preds = (outputs_1, outputs_2)
                    loss = saliency_objective(outputs_1, smap) + saliency_objective(outputs_2, smap)
                else:
                    outputs = model(stimuli)
                    preds = (outputs,)
                    loss = saliency_objective(outputs, smap)

                if is_train:
                    loss.backward()
                    optimizer.step()

            batch_loss = loss.item()

            # Metric logging only; no_grad avoids building autograd graphs here.
            with torch.no_grad():
                log = history[phase]
                log['loss'].append(batch_loss)
                for name in metric_names:
                    values = [loss_fn(pred, smap, loss_type=name).item() for pred in preds]
                    # Multi-decoder: [decoder_1, decoder_2]; single decoder: a float.
                    log[name].append(values if flag == 3 else values[0])

            # `loss` is a per-batch mean: weight it by the batch size so that
            # dividing by the number of samples gives a per-sample average.
            running_loss += batch_loss * stimuli.size(0)

        if is_train:
            scheduler.step()

        epoch_loss = running_loss / dataset_sizes[phase]
        print('{} Loss: {:.4f}'.format(phase, epoch_loss))

        if not is_train:
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                counter = 0
            else:
                counter += 1
                if counter >= patience:
                    print('EARLY STOP!')
                    stop_training = True

    # Save this epoch's weights (this also covers the early-stop epoch).
    torch.save(model.state_dict(), checkpoint_path(epoch))
    print()

    # A `break` inside the phase loop would only leave that loop, so the
    # epoch loop is stopped here instead.
    if stop_training:
        break

print('Best val loss: {:.4f}'.format(best_loss))
torch.save(model.state_dict(), checkpoint_path(epoch))
model.load_state_dict(best_model_wts)

# Show val

In [9]:
def make_sub(model, num_pic, flag, shape_r=288, shape_c=384,
             images_dir='./datasets/val_images/', maps_dir='./datasets/val_maps/',
             transform=None):
    """Predict the saliency map(s) for one COCO val2014 image and score them with AUC.

    Args:
        model: trained TranSalNet variant; it should already be on the global `device`.
        num_pic: zero-padded COCO id string, e.g. '000000010138'.
        flag: architecture selector; 3 means the model returns two maps.
        shape_r, shape_c: not used in this function (resizing is left to the
            transform); kept for backward compatibility.
        images_dir, maps_dir: folders holding COCO_val2014_<id>.jpg / .png.
            NOTE(review): the dataset cells use './datasets/val/val_images/'
            and './datasets/val/val_maps/' — confirm which layout is on disk.
        transform: paired (image, map) transform; defaults to the global `val_transform`.

    Returns:
        flag == 3: ((pil_pred_1, pil_pred_2), (auc_1, auc_2))
        otherwise: (pil_pred, auc)
    """
    if transform is None:
        transform = val_transform

    im_path = images_dir + 'COCO_val2014_' + num_pic + '.jpg'
    smap_path = maps_dir + 'COCO_val2014_' + num_pic + '.png'

    # HWC in [0, 255] -> CHW float in [0, 1]; `with` closes the file handle.
    with Image.open(im_path) as image:
        img = np.array(image.convert('RGB')) / 255.
    img = torch.from_numpy(np.transpose(img, (2, 0, 1)))

    # Ground-truth map with a leading channel axis, scaled to [0, 1].
    with Image.open(smap_path) as saliency:
        smap = np.array(saliency) / 255.
    smap = torch.from_numpy(np.expand_dims(smap, axis=0))

    img, smap = transform(img, smap)
    img = img.type(torch.float32).to(device)
    smp_np = smap.squeeze().numpy()

    to_pil = transforms.ToPILImage()

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        if flag == 3:
            pred_1, pred_2 = model(img.unsqueeze(0))
            preds = (pred_1, pred_2)
        else:
            preds = (model(img.unsqueeze(0)),)

    pics = tuple(to_pil(p.squeeze()) for p in preds)
    aucs = tuple(AUC(p.squeeze().detach().cpu().numpy(), smp_np) for p in preds)

    if flag == 3:
        # Both decoder outputs are returned (the second picture is pred_2's map).
        return pics, aucs
    return pics[0], aucs[0]

In [None]:
import matplotlib.image as mpimg

# Load the checkpoint to inspect. map_location lets GPU-saved weights load
# on a CPU-only machine.
path_sub_model = 'your path to model here'
model.load_state_dict(torch.load(path_sub_model, map_location=device))
model.eval()

num_pic = '000000010138'
img_orig = mpimg.imread('./datasets/val_images/COCO_val2014_' + num_pic + '.jpg')
img_true = mpimg.imread('./datasets/val_maps/COCO_val2014_' + num_pic + '.png')

# Build (axis, picture, title) panels for the chosen architecture.
if flag == 3:
    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 15))
    (pic_1, pic_2), (auc_1, auc_2) = make_sub(model, num_pic, flag=flag)
    panels = [(ax[0][0], img_orig, 'Image'), (ax[0][1], img_true, 'True'),
              (ax[1][0], pic_1, 'Pred 1'), (ax[1][1], pic_2, 'Pred 2')]
else:
    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
    pic, auc = make_sub(model, num_pic, flag=flag)
    panels = [(ax[0], img_orig, 'Image'), (ax[1], pic, 'Pred'), (ax[2], img_true, 'True')]

for axis, picture, title in panels:
    axis.imshow(picture)
    axis.set_title(title)

plt.show()
if flag == 3:
    print('AUC_1 = ', auc_1)
    print('AUC_2 = ', auc_2)
else:
    print('AUC = ', auc)

# Compute val metrics

In [None]:
# Evaluate a saved checkpoint on the whole validation loader.
path_sub_model = 'your path to model here'
# map_location lets GPU-saved weights load on a CPU-only machine.
model.load_state_dict(torch.load(path_sub_model, map_location=device))
model.eval()

val_loss, val_loss_cc, val_loss_sim, val_loss_kldiv, val_loss_nss, val_auc = compute_metric(
    model,
    dataloaders['val'],
    device=device,
    flag=flag,
    t=10,  # forwarded to compute_metric; meaning defined in utils.data_process — TODO document
)

for label, value in (('Loss', val_loss), ('CC', val_loss_cc), ('SIM', val_loss_sim),
                     ('KL', val_loss_kldiv), ('NSS', val_loss_nss), ('AUC', val_auc)):
    print(label + ' = ', value)

# Create maps dataset

In [None]:
# Export both decoder predictions for every training image, one file per image.
if flag != 3:
    # Unpacking a single-output model's batch into two names would fail, or
    # silently split the batch when batch_size == 2.
    raise ValueError('Map export expects the two-decoder model (flag == 3)')

path_to_weight = 'your path to model here'
path_to_save_1 = './datasets/train/maps_train_1/'
path_to_save_2 = './datasets/train/maps_train_2/'
os.makedirs(path_to_save_1, exist_ok=True)
os.makedirs(path_to_save_2, exist_ok=True)

model.load_state_dict(torch.load(path_to_weight, map_location=device))
model.eval()

# dataloaders['train'] shuffles and uses train_transform (iftrain=True), so any
# random augmentation would make the saved maps disagree with the images they
# are named after. Use a deterministic loader with the validation transform.
# Assumes MyTransform(iftrain=False) applies no random augmentation — confirm.
export_loader = DataLoader(
    MyDataset(ids=train_ids, stimuli_dir=path_images_train,
              saliency_dir=path_maps_train, transform=val_transform),
    batch_size=batch_size, shuffle=False, num_workers=2,
)

with torch.no_grad():
    for sample_batched in tqdm(export_loader, desc='export train maps'):
        stimuli = sample_batched['image'].type(torch.float32).to(device)
        outputs_1, outputs_2 = model(stimuli)
        # Save every sample in the batch, not only index 0.
        for k, img_path in enumerate(sample_batched['path']):
            file_name = img_path.split('/')[-1]
            save_image(outputs_1[k], path_to_save_1 + file_name)
            save_image(outputs_2[k], path_to_save_2 + file_name)

In [None]:
# Export both decoder predictions for every validation image, one file per image.
if flag != 3:
    # Unpacking a single-output model's batch into two names would fail, or
    # silently split the batch when batch_size == 2.
    raise ValueError('Map export expects the two-decoder model (flag == 3)')

path_to_weight = 'your path to model here'
path_to_save_1 = './datasets/val/maps_val_1/'
path_to_save_2 = './datasets/val/maps_val_2/'
os.makedirs(path_to_save_1, exist_ok=True)
os.makedirs(path_to_save_2, exist_ok=True)

model.load_state_dict(torch.load(path_to_weight, map_location=device))
model.eval()

# dataloaders['val'] is unshuffled and uses val_transform.
with torch.no_grad():
    for sample_batched in tqdm(dataloaders['val'], desc='export val maps'):
        stimuli = sample_batched['image'].type(torch.float32).to(device)
        outputs_1, outputs_2 = model(stimuli)
        # Save every sample in the batch, not only index 0.
        for k, img_path in enumerate(sample_batched['path']):
            file_name = img_path.split('/')[-1]
            save_image(outputs_1[k], path_to_save_1 + file_name)
            save_image(outputs_2[k], path_to_save_2 + file_name)