In [51]:
import argparse
import os
import numpy as np

from test import evaluation
from test import PearsonCoeff
from utils.dataset import Dataset
from utils.parse_config import parse_config

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter

Load config file and dataset

In [271]:
# Parse the path of the training config file. parse_known_args() returns any
# unrecognised command-line arguments in `unknown` instead of raising on them.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config/train_STSIM.cfg", help="path to data config file")
# parser.add_argument("--config", type=str, default="config/STSIM_mode0_coeff.cfg", help="path to data config file")

args, unknown = parser.parse_known_args()
print(args)
config = parse_config(args.config)
print(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# makedirs also creates missing parent directories (os.mkdir would fail on a
# nested path such as 'weights/STSIM/' when 'weights/' does not exist yet) and
# exist_ok=True makes re-running this cell safe.
os.makedirs(config['weights_path'], exist_ok=True)

# read training data
dataset_dir = config['dataset_dir']
label_file = config['label_file']
dist_img_folder = config['train_img_folder']
train_batch_size = int(config['train_batch_size'])
trainset = Dataset(data_dir=dataset_dir, label_file=label_file, dist_folder=dist_img_folder)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True)

# read validation data (same dataset directory and label file, different image folder)
dataset_dir = config['dataset_dir']
dist_img_folder = config['valid_img_folder']
valid_batch_size = int(config['valid_batch_size'])
validset = Dataset(data_dir=dataset_dir, label_file=label_file, dist_folder=dist_img_folder)
valid_loader = torch.utils.data.DataLoader(validset, batch_size=valid_batch_size)

# Loop controls for the training cell: total iterations, and how often to
# validate / save a checkpoint (both counted in iterations).
epochs = int(config['epochs'])
evaluation_interval = int(config['evaluation_interval'])
checkpoint_interval = int(config['checkpoint_interval'])

Namespace(config='config/train_STSIM.cfg')
{'gpus': '0', 'num_workers': '0', 'model': 'STSIM', 'dataset_dir': 'dataset/jana2012/', 'label_file': 'label.xlsx', 'train_img_folder': 'train', 'valid_img_folder': 'train', 'ref_img_folder': 'original', 'weights_path': 'weights/STSIM/', 'loss': 'MSE', 'mode': '0', 'epochs': '1000', 'lr': '0.0001', 'train_batch_size': '1000', 'valid_batch_size': '1000', 'checkpoint_interval': '25', 'evaluation_interval': '50'}


In [241]:
# Training hyper-parameters taken from the parsed config file.
# mode selects the STSIM_M variant: 0 = STSIM-M, 1 = regression.
mode = int(config['mode'])
# Learning rate handed to the Adam optimizer in the training cell.
learning_rate = float(config['lr'])
# Training objective; the training loop distinguishes 'MSE' and 'Coeff'.
loss_type = config['loss']

In [8]:
# Prepare data for training and validation
if config['model'] == 'STSIM':
    # Only the first batch of each loader is pulled; the training loop below
    # re-uses these tensors on every iteration.
    X1_train, X2_train, Y_train, mask_train = next(iter(train_loader))
    X1_valid, X2_valid, Y_valid, mask_valid = next(iter(valid_loader))

    from steerable.sp3Filters import sp3Filters
    # NOTE(review): the wildcard import is kept because STSIM_M.forward looks up
    # `Metric` in the notebook's global namespace.
    from metrics.STSIM import *
    m = Metric(sp3Filters, device)
    # STSIM-M features are computed once, up front. They are fixed inputs to the
    # learnable model, so they are extracted without autograd tracking, the same
    # way STSIM_M.forward does it for raw images.
    with torch.no_grad():
        X1_train = m.STSIM_M(X1_train.double().to(device))
        X2_train = m.STSIM_M(X2_train.double().to(device))
        X1_valid = m.STSIM_M(X1_valid.double().to(device))
        X2_valid = m.STSIM_M(X2_valid.double().to(device))
    Y_train = Y_train.to(device)
    Y_valid = Y_valid.to(device)
    mask_train = mask_train.to(device)
    mask_valid = mask_valid.to(device)

In [272]:
class STSIM_M(torch.nn.Module):
    '''
    Learnable distance between two images, computed from their STSIM-M features.

    Both modes work on the element-wise absolute feature difference |X1 - X2|:
        mode 0 (STSIM-M): a linear projection followed by a sigmoid; the output
            is the inner product of that projected vector with itself.
        mode 1 (regression): a hidden linear layer with ReLU followed by a
            single sigmoid output unit.
    '''

    def __init__(self, dim ,mode = 0, device=None):
        '''
        Args:
            dim: sequence [input feature size, projection size]; dim[1] is only
                used in mode 0
            mode: STSIM-M(0),regression(1)
            device: device handed to the feature extractor when forward()
                receives raw images; defaults to the CPU
        Raises:
            ValueError: if mode is neither 0 nor 1
        '''
        super().__init__()

        # Reject unknown modes here; otherwise the model would be built without
        # any layers and forward() would silently return None.
        if mode not in (0, 1):
            raise ValueError("mode must be 0 (STSIM-M) or 1 (regression), got " + repr(mode))

        self.device = torch.device('cpu') if device is None else device
        self.mode = mode

        if self.mode == 0: #STSIM_M
            self.linear = nn.Linear(dim[0], dim[1])
        else: #Regression
            self.hidden = torch.nn.Linear(dim[0], dim[0])
            self.predict = torch.nn.Linear(dim[0], 1)

    def forward(self, X1, X2):
        '''
        Args:
            X1: [N, dim[0]] STSIM-M features, or a 4-D batch of raw images
            X2: same layout as X1
        Returns:
            pred: [N, 1] predicted distance
        '''
        if len(X1.shape)==4:
            # the input are raw images, extract STSIM-M features
            from steerable.sp3Filters import sp3Filters
            # Imported locally so the class does not depend on a wildcard import
            # having been executed in another notebook cell.
            from metrics.STSIM import Metric
            m = Metric(sp3Filters, device=self.device)
            # The feature extractor is not trained, so no graph is recorded.
            with torch.no_grad():
                X1 = m.STSIM_M(X1)
                X2 = m.STSIM_M(X2)

        diff = torch.abs(X1 - X2)
        if self.mode == 0: #STSIM_M
            # torch.sigmoid is used in both branches (F.sigmoid is the deprecated alias).
            proj = torch.sigmoid(self.linear(diff))    # [N, dim[1]]
            # Per-sample inner product of proj with itself -> [N, 1].
            return torch.bmm(proj.unsqueeze(1), proj.unsqueeze(-1)).squeeze(-1)
        # mode 1: regression
        hidden = F.relu(self.hidden(diff))
        return torch.sigmoid(self.predict(hidden))

In [258]:
# learnable parameters
model = STSIM_M([X1_train.shape[1], 10],mode,device).double().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.MSELoss()

# Fail early on an unsupported objective instead of reaching backward() with
# `loss` undefined.
if loss_type not in ('MSE', 'Coeff'):
    raise ValueError("unsupported loss type: " + repr(loss_type))

# Every iteration trains on the same pre-computed feature tensors.
for i in range(epochs):
    pred = model(X1_train, X2_train)
    if loss_type == 'MSE':
        # For mode 0, the sigmoid output without the final sqrt gave the best results.
        # The model returns [N, 1]; reshape it to the label layout and match the
        # dtype so MSELoss compares element-wise instead of broadcasting the two
        # tensors against each other. Assumes one label per sample.
        loss = criterion(pred.reshape(Y_train.shape), Y_train.to(pred.dtype))
    else:  # 'Coeff'
        loss = -PearsonCoeff(pred, Y_train, mask_train)  # min neg ==> max
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 25 == 0:
        print('training iter ' + str(i) + ' :', loss.item())
    if i % evaluation_interval == 0:    # validation
        # Validation only needs the forward values, so no graph is recorded.
        with torch.no_grad():
            pred = model(X1_valid, X2_valid)
        val = evaluation(pred, Y_valid, mask_valid)
        print('validation iter ' + str(i) + ' :', val)
    if i % checkpoint_interval == 0:    # save weights
        torch.save(model.state_dict(), os.path.join(config['weights_path'], 'epoch_' + str(i).zfill(4) + '.pt'))

training iter 0 : -0.6648650264283136
validation iter 0 : 0.8623489408413969
training iter 25 : -0.9447467461103607
training iter 50 : -0.9564949923562294
validation iter 50 : 0.9569015594171175
training iter 75 : -0.9654109941209393
training iter 100 : -0.9721345545303084
validation iter 100 : 0.9723673072580347
training iter 125 : -0.9771073972724895
training iter 150 : -0.9805526755332762
validation iter 150 : 0.980658135135432
training iter 175 : -0.982769663137781
training iter 200 : -0.9843979319076992
validation iter 200 : 0.9844546018030856
training iter 225 : -0.9856833622719325
training iter 250 : -0.9867744235302842
validation iter 250 : 0.9868150650265637
training iter 275 : -0.9877295216049641
training iter 300 : -0.9885648975166983
validation iter 300 : 0.9885960637678698
training iter 325 : -0.9893010884702887
training iter 350 : -0.9899633897995017
validation iter 350 : 0.9899886438840266
training iter 375 : -0.9905721266906184
training iter 400 : -0.9910269035562498
va

# Testing

In [15]:
import argparse

from steerable.sp3Filters import sp3Filters
from utils.dataset import Dataset
from utils.parse_config import parse_config

import torch
import torch.nn.functional as F

def PearsonCoeff(X, Y, mask):
    '''
    Mean absolute Pearson correlation between X and Y, averaged over classes.

    Args:
        X: [N, 1] neural prediction for one batch, or [N] some other metric's output
        Y: [N] label
        mask: [N] indicator of correspondent class, e.g. [0,0,1,1] means the
            first two samples are class 0, the remaining two are class 1
    Returns:
        0-dim double tensor. For every distinct value in mask the Pearson
        coefficient of the matching X/Y entries is computed (as numpy.corrcoef()
        would, up to the 1e-10 stabiliser), its absolute value is taken, and
        the per-class values are averaged. The result stays differentiable
        with respect to X.
    Raises:
        ZeroDivisionError: if mask is empty
    '''
    # reshape(-1) flattens [N, 1] to [N] and leaves [N] untouched. squeeze(-1)
    # would turn a single-sample [1] input into a 0-dim tensor that cannot be
    # indexed with the mask.
    X = X.reshape(-1)
    # Distinct class ids, found without a device -> host -> numpy round trip.
    classes = torch.unique(mask)

    coeff = 0
    for c in classes:
        selected = mask == c
        x = X[selected].double()
        x = x - x.mean()
        y = Y[selected].double()
        y = y - y.mean()

        nom = torch.dot(x, y)
        denom = torch.sqrt(torch.sum(x ** 2) * torch.sum(y ** 2))

        # The small constant keeps the ratio finite when x or y is constant.
        coeff += torch.abs(nom / (denom + 1e-10))
    return coeff / len(classes)

def evaluation(pred, Y, mask):
    '''
    Score predictions against labels and return the score as a plain Python number.

    Args:
        pred: predictions, passed through to PearsonCoeff as X
        Y: labels
        mask: per-sample class indicator
    Returns: PearsonCoeff(pred, Y, mask) converted with .item()
    '''
    score = PearsonCoeff(pred, Y, mask)
    return score.item()

In [31]:
# Read config and data for testing
# parse_known_args() returns unrecognised command-line arguments in `unknown`
# instead of raising on them.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config/test.cfg", help="path to data config file")
parser.add_argument("--batch_size", type=int, default=1000, help="size of each image batch")
args, unknown = parser.parse_known_args()
print(args)
# NOTE: this rebinds `config` (and `device`, `dataset_dir`, ... below), replacing
# the values set by the training section.
config = parse_config(args.config)
print(config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# read data
dataset_dir = config['dataset_dir']
label_file = config['label_file']
dist_img_folder = config['dist_img_folder']
testset = Dataset(data_dir=dataset_dir, label_file=label_file, dist_folder=dist_img_folder)
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size)

# Only the first batch is evaluated; samples beyond args.batch_size are not read.
X1, X2, Y, mask = next(iter(test_loader))

Namespace(batch_size=1000, config='config/test.cfg')
{'gpus': '0', 'num_workers': '0', 'model': 'STSIM', 'dataset_dir': 'dataset/jana2012/', 'label_file': 'label.xlsx', 'dist_img_folder': 'test', 'ref_img_folder': 'original', 'weights_path': 'weights/weights_DISTS_finetuned.pt'}


In [32]:
# Testing STSIM-1 and STSIM-2
# The wildcard import provides `Metric`, which is used below.
from metrics.STSIM import *
# Move the test batch to the selected device and cast everything to double.
X1 = X1.to(device).double()
X2 = X2.to(device).double()
Y = Y.to(device).double()
# mask is only compared for equality inside PearsonCoeff, so the cast to
# double does not change the grouping.
mask = mask.to(device).double()
m_g = Metric(sp3Filters, device=device)
pred = m_g.STSIM(X1, X2)
print("STSIM-1 test:", evaluation(pred, Y, mask)) # recorded score: 0.8158

pred = m_g.STSIM2(X1, X2)
print("STSIM-2 test:", evaluation(pred, Y, mask))  # recorded score: 0.8517

STSIM-1 test: 0.8157561768297956
STSIM-2 test: 0.8517013571296325


In [275]:
# Testing model with saved weights
saved_weights_path = 'weights/weights_STSIM_mode0_mse.pt'
# saved_weights_path = 'weights/STSIM/epoch_0975.pt'
mode = 0 # STSIM_M(0) regression(1)
# NOTE(review): the input size is read from X1_train, which is created in the
# training section, so this cell only runs after the training data cell has
# been executed. The projection size 10 matches the value used when the
# training cell builds its model.
model = STSIM_M([X1_train.shape[1], 10],mode,device).double().to(device)

In [276]:
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU can also be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(saved_weights_path, map_location=device))
model.eval()
# Inference only: run the forward pass without recording an autograd graph,
# which yields a gradient-free prediction tensor directly.
with torch.no_grad():
    pred = model(X1, X2)

In [277]:
print("STSIM_M test:", evaluation(pred, Y, mask))  # earlier recorded score: 0.9195754925772788 (differs slightly from the saved output below)

STSIM_M test: 0.919878942383851
