In [5]:
import argparse
import os
import numpy as np

from test import evaluation
from test import PearsonCoeff
from utils.dataset import Dataset
from utils.parse_config import parse_config

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter

In [32]:
# Parse CLI options; parse_known_args tolerates extra argv entries (e.g. ones
# added by the notebook kernel) instead of erroring out on them.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config/train_STSIM.cfg", help="path to data config file")

args, unknown = parser.parse_known_args()
print(args)
config = parse_config(args.config)
print(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# makedirs (unlike mkdir) also creates missing parent directories, e.g. 'weights/'
# for 'weights/STSIM/'; exist_ok makes re-running this cell harmless.
os.makedirs(config['weights_path'], exist_ok=True)

dataset_dir = config['dataset_dir']
label_file = config['label_file']
# The config carries 'num_workers'; honour it (0 = load in the main process).
num_workers = int(config.get('num_workers', 0))

if config['valid_img_folder'] == config['train_img_folder']:
    import warnings
    warnings.warn("valid_img_folder equals train_img_folder: validation scores are measured on the training images")

# read training data
dist_img_folder = config['train_img_folder']
train_batch_size = int(config['train_batch_size'])
trainset = Dataset(data_dir=dataset_dir, label_file=label_file, dist_folder=dist_img_folder)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True,
                                           num_workers=num_workers)

# read validation data (same directory and label file, different image folder)
dist_img_folder = config['valid_img_folder']
valid_batch_size = int(config['valid_batch_size'])
validset = Dataset(data_dir=dataset_dir, label_file=label_file, dist_folder=dist_img_folder)
valid_loader = torch.utils.data.DataLoader(validset, batch_size=valid_batch_size, num_workers=num_workers)

epochs = int(config['epochs'])
evaluation_interval = int(config['evaluation_interval'])
checkpoint_interval = int(config['checkpoint_interval'])

Namespace(config='config/train_STSIM.cfg')
{'gpus': '0', 'num_workers': '0', 'model': 'STSIM', 'dataset_dir': 'dataset/jana2012/', 'label_file': 'label.xlsx', 'train_img_folder': 'train', 'valid_img_folder': 'train', 'ref_img_folder': 'original', 'weights_path': 'weights/STSIM/', 'epochs': '1000', 'train_batch_size': '1000', 'valid_batch_size': '1000', 'checkpoint_interval': '25', 'evaluation_interval': '50'}


In [33]:
if config['model'] == 'STSIM':
    from steerable.sp3Filters import sp3Filters
    from metrics.STSIM import *  # expected to provide Metric and STSIM_M used below

    m = Metric(sp3Filters, device)

    def _stsim_m_features(loader):
        '''Extract STSIM-M features for every batch of `loader` and concatenate them.

        Iterating the whole loader (instead of next(iter(loader))) keeps the samples
        beyond the first batch. Returns (features1, features2, labels, masks) on `device`.
        '''
        feats1, feats2, labels, masks = [], [], [], []
        with torch.no_grad():  # features are fixed inputs to the learned layer
            for X1, X2, Y, mask in loader:
                feats1.append(m.STSIM_M(X1.double().to(device)))
                feats2.append(m.STSIM_M(X2.double().to(device)))
                labels.append(Y.to(device))
                masks.append(mask.to(device))
        return torch.cat(feats1), torch.cat(feats2), torch.cat(labels), torch.cat(masks)

    # prepare data (STSIM-M features of the whole train / validation sets)
    X1_train, X2_train, Y_train, mask_train = _stsim_m_features(train_loader)
    X1_valid, X2_valid, Y_valid, mask_valid = _stsim_m_features(valid_loader)

    # learnable parameters
    model = STSIM_M([X1_train.shape[1], 10], device).double().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    log_interval = 25  # print the training loss every log_interval epochs
    for i in range(epochs):
        is_last = i == epochs - 1
        pred = model(X1_train, X2_train)
        # PearsonCoeff is a score to maximise, so minimise its negation.
        # (alternative objective tried earlier: torch.mean((pred - Y_train) ** 2))
        loss = -PearsonCoeff(pred, Y_train, mask_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % log_interval == 0:
            print('training iter ' + str(i) + ' :', loss.item())
        # Also validate / checkpoint on the final epoch so the fully trained weights
        # are not lost when epochs - 1 is not a multiple of the interval.
        if i % evaluation_interval == 0 or is_last:
            with torch.no_grad():  # validation must not build an autograd graph
                pred_valid = model(X1_valid, X2_valid)
            val = evaluation(pred_valid, Y_valid, mask_valid)
            print('validation iter ' + str(i) + ' :', val)
        if i % checkpoint_interval == 0 or is_last:
            torch.save(model.state_dict(), os.path.join(config['weights_path'], 'epoch_' + str(i).zfill(4) + '.pt'))

training iter 0 : -0.4727785392293161
validation iter 0 : 0.7674194513769742
training iter 25 : -0.9364448604712491
training iter 50 : -0.9465968989969985
validation iter 50 : 0.9468765310028425
training iter 75 : -0.9529661097362881
training iter 100 : -0.9600361427655347
validation iter 100 : 0.9603249009879116
training iter 125 : -0.966575100866053
training iter 150 : -0.971195776177751
validation iter 150 : 0.9713498867943622
training iter 175 : -0.9747435491563895
training iter 200 : -0.9781039231214879
validation iter 200 : 0.9782355418456247
training iter 225 : -0.9810774603628811
training iter 250 : -0.98320349723983
validation iter 250 : 0.9832752576498688
training iter 275 : -0.9847703021071144
training iter 300 : -0.9860148876338238
validation iter 300 : 0.9860599614202215
training iter 325 : -0.987048191176593
training iter 350 : -0.9879119681187791
validation iter 350 : 0.9879436010630073
training iter 375 : -0.988647793621395
training iter 400 : -0.9892889219596992
valida

# Testing

In [37]:
import argparse

from steerable.sp3Filters import sp3Filters
from utils.dataset import Dataset
from utils.parse_config import parse_config

import torch
import torch.nn.functional as F

def PearsonCoeff(X, Y, mask):
    '''
    Mean (over classes) of the absolute Pearson correlation between X and Y.

    Args:
        X: [N, 1] neural prediction for one batch, or [N] some other metric's output
        Y: [N] label
        mask: [N] class indicator, e.g. [0, 0, 1, 1] means the first two samples
            belong to class 0 and the last two to class 1
    Returns:
        0-d float64 tensor: average over classes of |r_c|, where r_c is the Pearson
        correlation of X and Y restricted to class c. Per class, r_c matches
        numpy.corrcoef(x, y)[0, 1] up to the 1e-10 stabiliser, but the ABSOLUTE
        value is averaged, so this is not numpy.corrcoef itself (presumably so that
        distance-like metrics, anti-correlated with the labels, score positively —
        confirm). Differentiable w.r.t. X.
    Raises:
        ValueError: if mask is empty (no class to average over).
    '''
    if X.dim() > 1:
        # [N, 1] -> [N]; a 1-D input (even of length 1) is left unchanged.
        X = X.squeeze(-1)
    # torch.unique stays on X's device and works for integer or float masks.
    classes = torch.unique(mask)
    if classes.numel() == 0:
        raise ValueError("PearsonCoeff: mask is empty, there is no class to average over")

    coeff = 0
    for c in classes:
        in_class = mask == c
        x = X[in_class].double()
        x = x - x.mean()
        y = Y[in_class].double()
        y = y - y.mean()

        numerator = torch.dot(x, y)
        denominator = torch.sqrt(torch.sum(x ** 2) * torch.sum(y ** 2))
        # The epsilon avoids 0/0 for constant (or single-sample) classes.
        coeff += torch.abs(numerator / (denominator + 1e-10))
    return coeff / classes.numel()

def evaluation(pred, Y, mask):
    '''
    Score predictions against labels as a plain Python float.

    Args:
        pred, Y, mask: same meaning as for PearsonCoeff.
    Returns:
        float value of PearsonCoeff(pred, Y, mask).
    '''
    # Evaluation never back-propagates; pred may come straight from a model whose
    # parameters require grad, so skip building an autograd graph for the score.
    with torch.no_grad():
        return PearsonCoeff(pred, Y, mask).item()

In [40]:
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config/test.cfg", help="path to data config file")
parser.add_argument("--batch_size", type=int, default=1000, help="size of each image batch")
args, unknown = parser.parse_known_args()
print(args)
config = parse_config(args.config)
print(config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# read data
dataset_dir = config['dataset_dir']
label_file = config['label_file']
dist_img_folder = config['dist_img_folder']
testset = Dataset(data_dir=dataset_dir, label_file=label_file, dist_folder=dist_img_folder)
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size)

# Concatenate every batch: next(iter(test_loader)) would silently evaluate only
# the first args.batch_size samples of the test set.
X1, X2, Y, mask = (torch.cat(parts) for parts in zip(*test_loader))

Namespace(batch_size=1000, config='config/test.cfg')
{'gpus': '0', 'num_workers': '0', 'model': 'STSIM', 'dataset_dir': 'dataset/jana2012/', 'label_file': 'label.xlsx', 'dist_img_folder': 'test', 'ref_img_folder': 'original', 'weights_path': 'weights/weights_DISTS_finetuned.pt'}


In [41]:
from metrics.STSIM import *
# Move inputs, labels and mask to the target device as float64.
X1, X2, Y, mask = (t.to(device).double() for t in (X1, X2, Y, mask))
m_g = Metric(sp3Filters, device=device)

# Hand-crafted STSIM variants; trailing numbers are scores from an earlier run.
baselines = (
    ("STSIM-1 test:", m_g.STSIM),   # 0.8158
    ("STSIM-2 test:", m_g.STSIM2),  # 0.8517
)
for caption, metric_fn in baselines:
    pred = metric_fn(X1, X2)
    print(caption, evaluation(pred, Y, mask))

STSIM-1 test: 0.8157561768297956
STSIM-2 test: 0.8517013571296325


In [42]:
# Checkpoint to evaluate: the training loop writes 'epoch_XXXX.pt' files under
# 'weights/STSIM/'. With epochs=1000 and checkpoint_interval=25 (training config),
# 975 is the last multiple of 25 it saved. The path is not taken from config,
# because config/test.cfg's 'weights_path' points at a DISTS checkpoint.
CHECKPOINT_EPOCH = 975
weights_path = 'weights/STSIM/epoch_{:04d}.pt'.format(CHECKPOINT_EPOCH)

In [75]:
class STSIM_M(torch.nn.Module):
    '''
    Learned STSIM-M distance between two textures.

    The absolute difference of two STSIM-M feature vectors is projected by a
    single linear layer, and the L2 norm of that projection is returned.
    '''

    def __init__(self, dim, device=None):
        '''
        Args:
            dim: pair [in_features, out_features] for the linear projection;
                in_features must equal the STSIM-M feature length.
            device: torch.device used to build the feature extractor when forward
                receives raw images; defaults to CPU. This does NOT move the
                module's parameters — call .to(device) for that.
        '''
        super(STSIM_M, self).__init__()
        self.linear = nn.Linear(dim[0], dim[1])
        self.device = torch.device('cpu') if device is None else device

    def forward(self, X1, X2):
        '''
        Args:
            X1, X2: either 4-D raw image batches, which are converted to STSIM-M
                features here, or already-extracted feature batches
                (assumed [N, in_features] — TODO confirm for other ranks).
        Returns:
            [N, 1] tensor: ||linear(|X1 - X2|)||_2 for each sample.
        '''
        if X1.dim() == 4:
            # Raw images: extract STSIM-M features first. Metric must already be
            # in scope (e.g. via `from metrics.STSIM import *`). The features are
            # treated as fixed inputs, so they are computed without gradients.
            from steerable.sp3Filters import sp3Filters
            m = Metric(sp3Filters, device=self.device)
            with torch.no_grad():
                X1 = m.STSIM_M(X1)
                X2 = m.STSIM_M(X2)
        proj = self.linear(torch.abs(X1 - X2))  # [N, out_features]
        # Row-wise squared norm (the inner product <p, p>), kept as [N, 1].
        sq_norm = (proj * proj).sum(dim=-1, keepdim=True)
        return torch.sqrt(sq_norm)  # [N, 1]

In [76]:
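# To evaluate without running the training cell, build the model explicitly.
# STSIM_M takes (dim, device); there is no weights_path argument, because weights
# are loaded with load_state_dict:
# model = STSIM_M(dim=[82, 10], device=device).double().to(device)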

In [78]:
model.load_state_dict(torch.load(weights_path))
pred = []
pred.append(model(X1,X2))
pred = torch.cat(pred, dim=0).detach()

In [79]:
# Learned STSIM-M score on the test set (earlier run: 0.9555486244612739).
stsim_m_score = evaluation(pred, Y, mask)
print("STSIM_M test:", stsim_m_score)

STSIM_M test: 0.9555486244612739
