In [1]:
import os
import time
import copy
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau

from utils.loss_function import SaliencyLoss
from utils.data_process import MyDataset

# Backbone selector: a truthy flag (1) imports TranSalNet_Res, a falsy flag (0) imports TranSalNet_Dense.
flag = 1

if not flag:
    from TranSalNet_Dense import TranSalNet
else:
    from TranSalNet_Res import TranSalNet


  from .autonotebook import tqdm as notebook_tqdm


↑↑↑ Set flag=1 to load TranSalNet_Res, or set flag=0 to load TranSalNet_Dense.

In [2]:
# Each CSV row names the files for one sample (the printed rows below show the
# columns image, map, fixation, text_map).
train_ids = pd.read_csv(r'datasets/train_ids.csv')
val_ids = pd.read_csv(r'datasets/val_ids.csv')

# Spot-check the second row of each split.
for split_ids in (train_ids, val_ids):
    print(split_ids.iloc[1])

# Number of samples per split; used later to average the epoch statistics.
dataset_sizes = {split: len(ids) for split, ids in [('train', train_ids), ('val', val_ids)]}
print(dataset_sizes)

image              100.jpg
map         100_fixMap.jpg
fixation    100_fixPts.png
text_map           100.jpg
Name: 1, dtype: object
image              103.jpg
map         103_fixMap.jpg
fixation    103_fixPts.png
text_map           103.jpg
Name: 1, dtype: object
{'train': 871, 'val': 101}


↑↑↑ Load the image IDs (file names) for the train and validation splits.

In [3]:
batch_size = 8

# Shared preprocessing for both splits: ToTensor, then per-channel normalization
# with the standard ImageNet mean/std values.
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_set = MyDataset(
    ids=train_ids,
    stimuli_dir=r'datasets/train/train_stimuli/',
    saliency_dir=r'datasets/train/train_saliency/',
    fixation_dir=r'datasets/train/train_fixation/',
    text_map_dir=r'datasets/train/train_text_map/',
    transform=image_transform,
)

val_set = MyDataset(
    ids=val_ids,
    stimuli_dir=r'datasets/val/val_stimuli/',
    saliency_dir=r'datasets/val/val_saliency/',
    fixation_dir=r'datasets/val/val_fixation/',
    text_map_dir=r'datasets/val/val_text_map/',
    transform=image_transform,
)

# Only the training loader shuffles; validation order stays fixed.
dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4),
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4),
}


↑↑↑ Set batch_size and load the train/validation datasets.

In [4]:
# Use the first GPU when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# nn.Module.to moves the parameters and returns the module itself.
model = TranSalNet().to(device)
print(device)

cuda:0


In [5]:
# Count every parameter element, trainable or not.
param_sizes = [param.numel() for param in model.parameters()]
total_params = sum(param_sizes)
print(f"Total number of parameters in the model: {total_params}")

Total number of parameters in the model: 66410497


In [6]:
def mean_std(test_list):
    """Return ``(mean, std)`` of a non-empty sequence of numbers.

    ``std`` is the population standard deviation: the summed squared
    deviations are divided by ``len(test_list)``, not ``len(test_list) - 1``.
    An empty sequence raises ``ZeroDivisionError``.
    """
    n = len(test_list)
    mean = sum(test_list) / n
    sq_dev_sum = sum((value - mean) ** 2 for value in test_list)
    std = (sq_dev_sum / n) ** 0.5
    return mean, std

# Train the model below

In [7]:
optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)

# Multiply the learning rate by 0.1 every 2 scheduler steps; the scheduler is
# stepped once per epoch, after the training phase (see below).
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

loss_fn = SaliencyLoss()

# Training
num_epochs = 30
patience = 5               # stop after this many consecutive val epochs without improvement
best_loss = float('inf')   # so the first validation epoch always counts as an improvement
best_model_wts = copy.deepcopy(model.state_dict())
counter = 0                # consecutive val epochs without improvement (must exist before first use)
stop_early = False

for k, v in model.named_parameters():
    print('{}: {}'.format(k, v.requires_grad))

for epoch in range(num_epochs):

    print('Epoch {}/{}'.format(epoch + 1, num_epochs))
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        # Sample-weighted sums; divided by the split size at the end of the phase.
        running_loss = 0.0
        kl_loss = 0.0
        cc_loss = 0.0
        sim_loss = 0.0
        nss_loss = 0.0
        # Per-batch metric values, summarised with mean_std after the phase.
        kl_loss_list = []
        cc_loss_list = []
        sim_loss_list = []
        nss_loss_list = []

        # Iterate over data.
        for sample_batched in tqdm(dataloaders[phase], desc=phase):
            # Cast to float32 directly on `device`; unlike torch.cuda.FloatTensor this
            # also works when `device` fell back to the CPU.
            stimuli = sample_batched['image'].to(device=device, dtype=torch.float32)
            smap = sample_batched['saliency'].to(device=device, dtype=torch.float32)
            fmap = sample_batched['fixation'].to(device=device, dtype=torch.float32)
            tmap = sample_batched['text_map'].to(device=device, dtype=torch.float32)
            n_samples = stimuli.size(0)

            if phase == 'train':
                # zero the parameter gradients
                optimizer.zero_grad()

            # forward; track history only in train
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(stimuli, tmap)
                kl = loss_fn(outputs, smap, loss_type='kldiv')
                cc = loss_fn(outputs, smap, loss_type='cc')
                sim = loss_fn(outputs, smap, loss_type='sim')
                nss = loss_fn(outputs, fmap, loss_type='nss')

                # CC, SIM and NSS enter with a negative sign and KL with a positive one,
                # so minimising `loss` raises CC/SIM/NSS and lowers KL.
                loss = -1 * cc - 1 * sim + 10 * kl - 2 * nss

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # statistics (.item() assumes each metric is a single-element tensor,
            # which the original running sums already relied on)
            kl_value = kl.item()
            cc_value = cc.item()
            sim_value = sim.item()
            nss_value = nss.item()

            kl_loss_list.append(kl_value)
            cc_loss_list.append(cc_value)
            sim_loss_list.append(sim_value)
            nss_loss_list.append(nss_value)

            running_loss += loss.item() * n_samples
            kl_loss += kl_value * n_samples
            cc_loss += cc_value * n_samples
            sim_loss += sim_value * n_samples
            nss_loss += nss_value * n_samples

        if phase == 'train':
            scheduler.step()

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_loss_kl = kl_loss / dataset_sizes[phase]
        epoch_loss_cc = cc_loss / dataset_sizes[phase]
        epoch_loss_sim = sim_loss / dataset_sizes[phase]
        epoch_loss_nss = nss_loss / dataset_sizes[phase]

        print('{} Loss: {:.4f}'.format(phase, epoch_loss))
        print('{} KL : {:.4f}'.format(phase, epoch_loss_kl))
        print('{} CC : {:.4f}'.format(phase, epoch_loss_cc))
        print('{} SIM : {:.4f}'.format(phase, epoch_loss_sim))
        print('{} NSS : {:.4f}'.format(phase, epoch_loss_nss))

        print(phase)
        for metric_name, values in (('kl', kl_loss_list), ('cc', cc_loss_list),
                                    ('sim', sim_loss_list), ('nss', nss_loss_list)):
            m, s = mean_std(values)
            print(f"mean_{metric_name} : {m} , STD: {s}")

        # Early stopping / best-checkpoint tracking on the validation loss.
        if phase == 'val':
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                counter = 0
            else:
                counter += 1
                if counter >= patience:
                    print('early stop!')
                    stop_early = True

    if stop_early:
        break
    print()


print('Best val loss: {:.4f}'.format(best_loss))
model.load_state_dict(best_model_wts)

encoder.encoder.0.weight: True
encoder.encoder.1.weight: True
encoder.encoder.1.bias: True
encoder.encoder.4.0.conv1.weight: True
encoder.encoder.4.0.bn1.weight: True
encoder.encoder.4.0.bn1.bias: True
encoder.encoder.4.0.conv2.weight: True
encoder.encoder.4.0.bn2.weight: True
encoder.encoder.4.0.bn2.bias: True
encoder.encoder.4.0.conv3.weight: True
encoder.encoder.4.0.bn3.weight: True
encoder.encoder.4.0.bn3.bias: True
encoder.encoder.4.0.downsample.0.weight: True
encoder.encoder.4.0.downsample.1.weight: True
encoder.encoder.4.0.downsample.1.bias: True
encoder.encoder.4.1.conv1.weight: True
encoder.encoder.4.1.bn1.weight: True
encoder.encoder.4.1.bn1.bias: True
encoder.encoder.4.1.conv2.weight: True
encoder.encoder.4.1.bn2.weight: True
encoder.encoder.4.1.bn2.bias: True
encoder.encoder.4.1.conv3.weight: True
encoder.encoder.4.1.bn3.weight: True
encoder.encoder.4.1.bn3.bias: True
encoder.encoder.4.2.conv1.weight: True
encoder.encoder.4.2.bn1.weight: True
encoder.encoder.4.2.bn1.bias: T

109it [00:48,  2.25it/s]

train Loss: 3.8485
train KL : 0.8863
train CC : 0.6732
train SIM : 0.5238
train NSS : 1.9087
train
mean_kl : 0.886070189126041 , STD: 0.2588492037247794
mean_cc : 0.6732835672248941 , STD: 0.09381172940839672
mean_sim : 0.5238615793919345 , STD: 0.07635097573237871
mean_nss : 1.9089143816894347 , STD: 0.2860457148945107



13it [00:02,  6.34it/s]

val Loss: 0.6832
val KL : 0.6163
val CC : 0.7343
val SIM : 0.5811
val NSS : 2.0821
val
mean_kl : 0.6200805810781626 , STD: 0.10090686683478292
mean_cc : 0.7321204130466168 , STD: 0.048045951021515126
mean_sim : 0.5791160922784072 , STD: 0.04732619095304672
mean_nss : 2.076773817722614 , STD: 0.20147162050444117
Epoch 2/30
----------



109it [00:47,  2.30it/s]

train Loss: -0.0266
train KL : 0.5842
train CC : 0.7591
train SIM : 0.6163
train NSS : 2.2465
train
mean_kl : 0.5842886616330628 , STD: 0.0758434787130835
mean_cc : 0.7590079887197652 , STD: 0.035851254479135584
mean_sim : 0.6162630636757667 , STD: 0.027050242219136527
mean_nss : 2.2462808954606364 , STD: 0.16013224766822134



13it [00:02,  6.36it/s]

val Loss: 1.3033
val KL : 0.6820
val CC : 0.7102
val SIM : 0.5998
val NSS : 2.1036
val
mean_kl : 0.6892397816364582 , STD: 0.10625931155929158
mean_cc : 0.7067656837976896 , STD: 0.05539135596485143
mean_sim : 0.5974281705342807 , STD: 0.04345238676841346
mean_nss : 2.0951403196041403 , STD: 0.24903805192686448
Epoch 3/30
----------



109it [00:47,  2.30it/s]

train Loss: -1.9344
train KL : 0.4468
train CC : 0.8196
train SIM : 0.6769
train NSS : 2.4532
train
mean_kl : 0.4468014207454996 , STD: 0.05801497093021465
mean_cc : 0.8196231865007942 , STD: 0.02700689461984712
mean_sim : 0.6769589606775056 , STD: 0.02608523214647227
mean_nss : 2.4535165839239 , STD: 0.14062943142553885



13it [00:02,  6.41it/s]

val Loss: -0.3637
val KL : 0.5598
val CC : 0.7653
val SIM : 0.6422
val NSS : 2.2771
val
mean_kl : 0.5654547673005325 , STD: 0.0973299665628839
mean_cc : 0.7627027539106516 , STD: 0.047141177682684436
mean_sim : 0.6399387304599469 , STD: 0.04434486579441145
mean_nss : 2.271057679102971 , STD: 0.23537347150139593
Epoch 4/30
----------



109it [00:47,  2.30it/s]

train Loss: -2.7702
train KL : 0.3872
train CC : 0.8469
train SIM : 0.7062
train NSS : 2.5444
train
mean_kl : 0.3871874013625154 , STD: 0.05154920162081152
mean_cc : 0.8469237597710496 , STD: 0.023337162946795664
mean_sim : 0.7061643797323245 , STD: 0.021980524790346387
mean_nss : 2.5445556137539924 , STD: 0.16957103634424542



13it [00:02,  6.48it/s]

val Loss: -0.4911
val KL : 0.5486
val CC : 0.7706
val SIM : 0.6467
val NSS : 2.2801
val
mean_kl : 0.5536045730113983 , STD: 0.09932636353310337
mean_cc : 0.7685597722346966 , STD: 0.04495019336662334
mean_sim : 0.6447020539870629 , STD: 0.042504739507715
mean_nss : 2.276254708950336 , STD: 0.2231050564031888
Epoch 5/30
----------



109it [00:47,  2.30it/s]

train Loss: -3.3669
train KL : 0.3435
train CC : 0.8668
train SIM : 0.7254
train NSS : 2.6048
train
mean_kl : 0.3435040168259122 , STD: 0.04675785415638952
mean_cc : 0.8668137531761729 , STD: 0.019186848900818818
mean_sim : 0.7253539994222309 , STD: 0.02155506553253832
mean_nss : 2.605317034852614 , STD: 0.12815123670066922



13it [00:01,  6.55it/s]

val Loss: -0.3124
val KL : 0.5682
val CC : 0.7673
val SIM : 0.6463
val NSS : 2.2907
val
mean_kl : 0.573732655781966 , STD: 0.1035517281347695
mean_cc : 0.7652208438286414 , STD: 0.04516250947839038
mean_sim : 0.6443495246080252 , STD: 0.041921270480106465
mean_nss : 2.2864942183861365 , STD: 0.23044224818223422
Epoch 6/30
----------



109it [00:47,  2.32it/s]

train Loss: -3.4673
train KL : 0.3366
train CC : 0.8692
train SIM : 0.7314
train NSS : 2.6165
train
mean_kl : 0.3367020506924445 , STD: 0.04340285941887825
mean_cc : 0.869185762121043 , STD: 0.019319943124793397
mean_sim : 0.7313394595723633 , STD: 0.018617164569609518
mean_nss : 2.6164320302665782 , STD: 0.15277916116428422



13it [00:01,  6.54it/s]

val Loss: -0.2543
val KL : 0.5730
val CC : 0.7663
val SIM : 0.6464
val NSS : 2.2857
val
mean_kl : 0.5783494252424973 , STD: 0.10385197108197948
mean_cc : 0.7642540060556852 , STD: 0.04495063732670184
mean_sim : 0.6444264145997854 , STD: 0.04203448344350473
mean_nss : 2.2816891486828146 , STD: 0.22662989620148816
Epoch 7/30
----------



109it [00:47,  2.32it/s]

train Loss: -3.5534
train KL : 0.3301
train CC : 0.8723
train SIM : 0.7345
train NSS : 2.6240
train
mean_kl : 0.3300663566370623 , STD: 0.053215378197081976
mean_cc : 0.8722826280725111 , STD: 0.020954031803933806
mean_sim : 0.7344722261122607 , STD: 0.02430478044956812
mean_nss : 2.623901647165281 , STD: 0.13406274188772718



13it [00:01,  6.56it/s]

val Loss: -0.2057
val KL : 0.5770
val CC : 0.7660
val SIM : 0.6469
val NSS : 2.2813
val
mean_kl : 0.5824463321612432 , STD: 0.10675336182087486
mean_cc : 0.7639967065591079 , STD: 0.04560408771309737
mean_sim : 0.6449914849721469 , STD: 0.04261676300278605
mean_nss : 2.2774516619168796 , STD: 0.22793701605077138
Epoch 8/30
----------



109it [00:47,  2.32it/s]

train Loss: -3.5653
train KL : 0.3293
train CC : 0.8727
train SIM : 0.7349
train NSS : 2.6256
train
mean_kl : 0.32935089720498534 , STD: 0.04311177796807762
mean_cc : 0.8726561550700337 , STD: 0.018985546691814865
mean_sim : 0.7348475297656628 , STD: 0.019438318334591084
mean_nss : 2.625428805657483 , STD: 0.14295527071554648



13it [00:01,  6.51it/s]

val Loss: -0.1795
val KL : 0.5791
val CC : 0.7649
val SIM : 0.6454
val NSS : 2.2803
val
mean_kl : 0.5844631309692676 , STD: 0.10459664420189176
mean_cc : 0.763000291127425 , STD: 0.04514951376404567
mean_sim : 0.6435319781303406 , STD: 0.04193370793631337
mean_nss : 2.276536904848539 , STD: 0.22860999572888405
Epoch 9/30
----------



109it [00:47,  2.32it/s]

train Loss: -3.5715
train KL : 0.3288
train CC : 0.8726
train SIM : 0.7349
train NSS : 2.6262
train
mean_kl : 0.3288558282709997 , STD: 0.042005923660906865
mean_cc : 0.8726412139901327 , STD: 0.019801011138753318
mean_sim : 0.7348756965147246 , STD: 0.019846753311577453
mean_nss : 2.626243895346965 , STD: 0.15557411093351703



13it [00:01,  6.52it/s]

val Loss: -0.2011
val KL : 0.5789
val CC : 0.7670
val SIM : 0.6488
val NSS : 2.2872
val
mean_kl : 0.5843399052436535 , STD: 0.10586683040389414
mean_cc : 0.7651891800073477 , STD: 0.0437924291413231
mean_sim : 0.6469571773822491 , STD: 0.041246205416815
mean_nss : 2.2839256066542406 , STD: 0.22685751342168167
early stop!
Best val loss: -0.491074





<All keys matched successfully>

# Save the model below

In [8]:
# Persist only the weights (state_dict); restore later with
# model.load_state_dict(torch.load(savepath)).
savepath = r'Ar_sal_model_most_new.pth'
weights = model.state_dict()
torch.save(weights, savepath)