In [2]:
import os
import time
import copy
import torch
import random
import pandas as pd
from skimage import io, transform
import numpy as np
from PIL import Image
import albumentations as A
import matplotlib.pyplot as plt
from IPython.display import clear_output
import bitsandbytes as bnb
import torchvision.transforms.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm

from utils.loss_function import SaliencyLoss
from utils.loss_function import AUC
from utils.data_process import MyDataset, MyTransform
from utils.data_process import MyDatasetCNNMerge, MyTransformCNNMerge
from utils.data_process import preprocess_img, postprocess_img
from utils.data_process import compute_metric, compute_metric_CNNMerge
from utils.data_process import count_parameters

from model.Merge_CNN_model import CNNMerge

In [3]:
# Locations of the two input saliency-map sets (train / val splits).
path_maps_1_train = './datasets/train/maps_train_1/'
path_maps_2_train = './datasets/train/maps_train_2/'
path_maps_1_val = './datasets/val/maps_val_1/'
path_maps_2_val = './datasets/val/maps_val_2/'

# Locations of the target saliency maps.
path_maps_train = './datasets/train/train_maps/'
path_maps_val = './datasets/val/val_maps/'

# CSV files listing the sample ids of each split.
path_train_ids = './datasets/train_ids_SALICON_CAT.csv'
path_val_ids = './datasets/val_ids_SALICON_CAT.csv'

# Sanity check: print how many entries each map directory contains,
# one count per line (inputs first, then targets).
for _map_dir in (
    path_maps_1_train,
    path_maps_2_train,
    path_maps_1_val,
    path_maps_2_val,
    path_maps_train,
    path_maps_val,
):
    print(len(os.listdir(_map_dir)))

In [None]:
# Read the id tables of both splits (train first, then val).
train_ids, val_ids = (pd.read_csv(csv_path) for csv_path in (path_train_ids, path_val_ids))

# Show the second row of each table as a quick look at the schema.
for _ids in (train_ids, val_ids):
    print(_ids.iloc[1])

# Number of rows per split; the training cell divides the summed loss by these.
dataset_sizes = {
    split: len(ids) for split, ids in (('train', train_ids), ('val', val_ids))
}
print(dataset_sizes)

In [6]:
# Batch size and the (rows, cols) shape handed to the transform.
batch_size = 64
shape_r, shape_c = 288, 384

# One transform instance is shared by the train and val datasets.
transform = MyTransformCNNMerge(shape_r=shape_r, shape_c=shape_c)

# Each dataset pairs two input map directories with one target map directory.
train_set = MyDatasetCNNMerge(
    ids=train_ids, map1_dir=path_maps_1_train, map2_dir=path_maps_2_train,
    saliency_dir=path_maps_train, transform=transform,
)
val_set = MyDatasetCNNMerge(
    ids=val_ids, map1_dir=path_maps_1_val, map2_dir=path_maps_2_val,
    saliency_dir=path_maps_val, transform=transform,
)

# Only the training loader shuffles; both use two worker processes.
dataloaders = {
    phase: DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(phase == 'train'),
        num_workers=2,
    )
    for phase, dataset in (('train', train_set), ('val', val_set))
}

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the merge network and move it to the selected device.
model = CNNMerge().to(device)
print(f'The model has {count_parameters(model):,} trainable parameters')

# Train the model

In [15]:
# Per-batch history of the combined loss for each phase.
history_loss_train, history_loss_val = [], []

# Per-batch history of the individual metrics during training.
(
    history_loss_train_cc,
    history_loss_train_sim,
    history_loss_train_kldiv,
    history_loss_train_nss,
    history_loss_train_auc,
) = ([] for _ in range(5))

# Per-batch history of the individual metrics during validation.
(
    history_loss_val_cc,
    history_loss_val_sim,
    history_loss_val_kldiv,
    history_loss_val_nss,
    history_loss_val_auc,
) = ([] for _ in range(5))

In [None]:
# Optimizer and LR schedule: AdamW, with the learning rate multiplied by 0.1
# every 10 epochs.
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

loss_fn = SaliencyLoss()

# Training configuration.
num_epochs = 30
patience = 5                      # non-improving val epochs tolerated before stopping
path_to_save = 'path to save'     # placeholder: set to your checkpoint directory
os.makedirs(path_to_save, exist_ok=True)  # fail early instead of after a full epoch

# Best-model tracking / early-stopping state.
best_model_wts = copy.deepcopy(model.state_dict())
best_loss = float('inf')
counter = 0                       # consecutive val epochs without improvement
stop_training = False

# Per-phase history lists, in the order (combined loss, cc, sim, kldiv, nss).
histories = {
    'train': (history_loss_train, history_loss_train_cc, history_loss_train_sim,
              history_loss_train_kldiv, history_loss_train_nss),
    'val': (history_loss_val, history_loss_val_cc, history_loss_val_sim,
            history_loss_val_kldiv, history_loss_val_nss),
}

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch + 1, num_epochs))
    print('-' * 10)

    # Each epoch has a training and a validation phase.
    for phase in ['train', 'val']:
        is_train = phase == 'train'
        if is_train:
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        running_loss = 0.0

        # Wrapping the loader itself lets tqdm show the number of batches.
        for sample_batched in tqdm(dataloaders[phase], desc=phase):
            smap1 = sample_batched['map1'].type(torch.float32).to(device)
            smap2 = sample_batched['map2'].type(torch.float32).to(device)
            smap = sample_batched['saliency'].type(torch.float32).to(device)

            if is_train:
                optimizer.zero_grad()

            # Track gradients only in the training phase.
            with torch.set_grad_enabled(is_train):
                outputs = model(smap1, smap2)

                # Compute each term once; the same values feed both the
                # objective and the history lists.
                loss_cc = loss_fn(outputs, smap, loss_type='cc')
                loss_sim = loss_fn(outputs, smap, loss_type='sim')
                loss_kldiv = loss_fn(outputs, smap, loss_type='kldiv')

                # Combined objective: -2*cc - sim + 10*kldiv.
                loss = -2 * loss_cc - loss_sim + 10 * loss_kldiv

                if is_train:
                    loss.backward()
                    optimizer.step()

            # NSS is only logged, so no autograd graph is needed for it.
            with torch.no_grad():
                loss_nss = loss_fn(outputs, smap, loss_type='nss')

            # Statistics.
            batch_loss = loss.item()
            batch_stats = (batch_loss, loss_cc.item(), loss_sim.item(),
                           loss_kldiv.item(), loss_nss.item())
            for history, value in zip(histories[phase], batch_stats):
                history.append(value)

            running_loss += batch_loss

        if is_train:
            scheduler.step()

        # NOTE(review): the per-batch losses are summed but divided by the number
        # of samples. If SaliencyLoss already averages over the batch, this value
        # is scaled down by roughly the batch size - confirm the reduction.
        epoch_loss = running_loss / dataset_sizes[phase]

        print('{} Loss: {:.4f}'.format(phase, epoch_loss))

        if not is_train:
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                counter = 0
            else:
                counter += 1
                if counter >= patience:
                    print('EARLY STOP!')
                    # A bare `break` here would only leave the phase loop, so
                    # the epoch loop is stopped through this flag instead.
                    stop_training = True

    # Save a checkpoint after every epoch (including the one that stops early).
    savepath = os.path.join(path_to_save, 'CNNMerge_' + str(epoch) + '.pth')
    torch.save(model.state_dict(), savepath)

    print()

    if stop_training:
        break

print('Best val loss: {:.4f}'.format(best_loss))

# Keep the best weights on disk as well, then restore them into the model.
torch.save(best_model_wts, os.path.join(path_to_save, 'CNNMerge_best.pth'))
model.load_state_dict(best_model_wts)

# Show a validation prediction

In [10]:
def make_sub(model, num_pic, shape_r=288, shape_c=384,
             map1_dir='./datasets/maps_val_1/',
             map2_dir='./datasets/maps_val_2/',
             saliency_dir='./datasets/val_maps/'):
    """Predict the merged saliency map for one validation image.

    Loads the two input maps and the target map named
    ``COCO_val2014_<num_pic>``, applies ``MyTransformCNNMerge``, runs ``model``
    on the two inputs and scores the prediction against the target with ``AUC``.

    The model's train/eval mode is not changed here; call ``model.eval()``
    beforehand. Tensors are moved to the module-level ``device``.

    Args:
        model: Network called as ``model(map1_batch, map2_batch)``.
        num_pic: Zero-padded image id string, e.g. ``'000000102837'``.
        shape_r: ``shape_r`` passed to ``MyTransformCNNMerge``.
        shape_c: ``shape_c`` passed to ``MyTransformCNNMerge``.
        map1_dir: Directory with the first input maps (``.jpg``).
        map2_dir: Directory with the second input maps (``.jpg``).
        saliency_dir: Directory with the target saliency maps (``.png``).

    Returns:
        tuple: ``(pic, auc)`` - the prediction as a PIL image and the value
        returned by ``AUC(prediction, target)``.
    """
    file_stem = 'COCO_val2014_' + num_pic

    def _load_gray(path):
        # Read as 8-bit grayscale, scale to [0, 1], add a leading channel axis.
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(path) as img:
            pixels = np.array(img.convert('L')) / 255.
        return torch.from_numpy(np.expand_dims(pixels, axis=0))

    smap1 = _load_gray(os.path.join(map1_dir, file_stem + '.jpg'))
    smap2 = _load_gray(os.path.join(map2_dir, file_stem + '.jpg'))
    smap = _load_gray(os.path.join(saliency_dir, file_stem + '.png'))

    # Build the transform from the arguments so shape_r / shape_c take effect
    # (previously the module-level `transform` was used and they were ignored).
    sub_transform = MyTransformCNNMerge(shape_r=shape_r, shape_c=shape_c)
    smap1, smap2, smap = sub_transform(smap1, smap2, smap)
    smap1 = smap1.type(torch.float32).to(device)
    smap2 = smap2.type(torch.float32).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        pred = model(smap1.unsqueeze(0), smap2.unsqueeze(0))
    pred = pred.squeeze().cpu()

    pic = transforms.ToPILImage()(pred)
    pred_np = pred.numpy()
    smp_np = smap.squeeze().numpy()
    auc = AUC(pred_np, smp_np)

    return pic, auc

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

path_sub_model = 'your path to model here'
# map_location lets a checkpoint saved on a GPU load on whatever `device` is
# in use now (e.g. a CPU-only machine).
model.load_state_dict(torch.load(path_sub_model, map_location=device))
model.eval()

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))

# Validation sample to visualise: source image, target map and prediction.
num_pic = '000000102837'
image = mpimg.imread('./datasets/val_images/COCO_val2014_'+num_pic+'.jpg')
img_true = mpimg.imread('./datasets/val_maps/COCO_val2014_'+num_pic+'.png')
pic, auc = make_sub(model, num_pic)

# One panel per picture, left to right.
for axis, (picture, title) in zip(ax, ((image, 'Image'), (img_true, 'True'), (pic, 'Pred'))):
    axis.imshow(picture)
    axis.set_title(title)

plt.show()
print('AUC = ', auc)

# Compute validation metrics

In [None]:
path_sub_model = 'your path to model here'
# map_location lets a checkpoint saved on a GPU load on whatever `device` is
# in use now (e.g. a CPU-only machine).
model.load_state_dict(torch.load(path_sub_model, map_location=device))
model.eval()

# Evaluate the loaded weights on the validation loader.
# NOTE(review): the meaning of `t=10` is defined by compute_metric_CNNMerge,
# which is not visible here - confirm against utils.data_process.
val_loss, val_loss_cc, val_loss_sim, val_loss_kldiv, val_loss_nss, val_auc = compute_metric_CNNMerge(
    model,
    dataloaders['val'],
    device=device,
    t=10
)

# Report every metric, one per line.
for label, value in (
    ('Loss = ', val_loss),
    ('CC = ', val_loss_cc),
    ('SIM = ', val_loss_sim),
    ('KL = ', val_loss_kldiv),
    ('NSS = ', val_loss_nss),
    ('AUC = ', val_auc),
):
    print(label, value)