In [1]:
import os
import argparse
import random
import numpy as np

In [2]:
import torch
import torchvision
import folders

class DataLoader(object):
    """Dataset class for IQA databases"""

    def __init__(self, dataset, path, img_indx, patch_size, patch_num, batch_size=1, istrain=True):

        self.batch_size = batch_size
        self.istrain = istrain

        if (dataset == 'live') | (dataset == 'csiq') | (dataset == 'tid2013') | (dataset == 'livec'):
            # Train transforms
            if istrain:
                transforms = torchvision.transforms.Compose([
                    torchvision.transforms.RandomHorizontalFlip(),
                    torchvision.transforms.RandomCrop(size=patch_size),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                     std=(0.229, 0.224, 0.225))
                ])
            # Test transforms
            else:
                transforms = torchvision.transforms.Compose([
                    torchvision.transforms.RandomCrop(size=patch_size),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                     std=(0.229, 0.224, 0.225))
                ])
        elif dataset == 'koniq-10k':
            if istrain:
                transforms = torchvision.transforms.Compose([
                    torchvision.transforms.RandomHorizontalFlip(),
                    torchvision.transforms.Resize((512, 384)),
                    torchvision.transforms.RandomCrop(size=patch_size),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                     std=(0.229, 0.224, 0.225))])
            else:
                transforms = torchvision.transforms.Compose([
                    torchvision.transforms.Resize((512, 384)),
                    torchvision.transforms.RandomCrop(size=patch_size),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                     std=(0.229, 0.224, 0.225))])
        elif dataset == 'bid':
            if istrain:
                transforms = torchvision.transforms.Compose([
                    torchvision.transforms.RandomHorizontalFlip(),
                    torchvision.transforms.Resize((512, 512)),
                    torchvision.transforms.RandomCrop(size=patch_size),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                     std=(0.229, 0.224, 0.225))])
            else:
                transforms = torchvision.transforms.Compose([
                    torchvision.transforms.Resize((512, 512)),
                    torchvision.transforms.RandomCrop(size=patch_size),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                     std=(0.229, 0.224, 0.225))])

        if dataset == 'live':
            self.data = folders.LIVEFolder(
                root=path, index=img_indx, transform=transforms, patch_num=patch_num)
        elif dataset == 'livec':
            self.data = folders.LIVEChallengeFolder(
                root=path, index=img_indx, transform=transforms, patch_num=patch_num)
        elif dataset == 'csiq':
            self.data = folders.CSIQFolder(
                root=path, index=img_indx, transform=transforms, patch_num=patch_num)
        elif dataset == 'koniq-10k':
            self.data = folders.Koniq_10kFolder(
                root=path, index=img_indx, transform=transforms, patch_num=patch_num)
        elif dataset == 'tid2013':
            self.data = folders.TID2013Folder(
                root=path, index=img_indx, transform=transforms, patch_num=patch_num)
        
        elif dataset == 'bid':
            self.data = folders.CustomDataSet(
                root=path, index=img_indx, transform=transforms, patch_num=patch_num)

    def get_data(self):
        if self.istrain:
            dataloader = torch.utils.data.DataLoader(
                self.data, batch_size=self.batch_size, shuffle=True)
        else:
            dataloader = torch.utils.data.DataLoader(
                self.data, batch_size=1, shuffle=False)
        return dataloader

In [3]:
# data_train = '../Data/train/'
# data_test = '../Data/test/'
# data_val = '../Data/valid/'

# patch_size = 224

# transforms = torchvision.transforms.Compose([
#                     torchvision.transforms.RandomHorizontalFlip(),
#                     torchvision.transforms.Resize((512, 512)),
#                     torchvision.transforms.RandomCrop(size=patch_size),
#                     torchvision.transforms.ToTensor(),
#                     torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
#                                                      std=(0.229, 0.224, 0.225))])

In [4]:
# import folders

# data = folders.CustomDataSet(root=data_train, index=0, transform=transforms, patch_num=25)
# data.__len__()

In [5]:
import torch
from scipy import stats
import numpy as np
import models
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error as mae
from tqdm import tqdm


class HyperIQASolver(object):
    """Solver for training and testing hyperIQA"""
    def __init__(self, config, path, train_idx, test_idx):

        self.epochs = config.epochs
        self.test_patch_num = config.test_patch_num

        self.model_hyper = models.HyperNet(16, 112, 224, 112, 56, 28, 14, 7).to(device)
        self.model_hyper.train(True)

        self.l1_loss = torch.nn.L1Loss().to(device)

        backbone_params = list(map(id, self.model_hyper.res.parameters()))
        self.hypernet_params = filter(lambda p: id(p) not in backbone_params, self.model_hyper.parameters())
        self.lr = config.lr
        self.lrratio = config.lr_ratio
        self.weight_decay = config.weight_decay
        paras = [{'params': self.hypernet_params, 'lr': self.lr * self.lrratio},
                 {'params': self.model_hyper.res.parameters(), 'lr': self.lr}
                 ]
        self.solver = torch.optim.Adam(paras, weight_decay=self.weight_decay)

        train_loader = DataLoader(config.dataset, config.train_path, train_idx, config.patch_size, config.train_patch_num, batch_size=config.batch_size, istrain=True)
        test_loader = DataLoader(config.dataset, config.test_path, test_idx, config.patch_size, config.test_patch_num, istrain=False)
        self.train_data = train_loader.get_data()
        self.test_data = test_loader.get_data()

    def train(self):
        """Training"""
        best_srcc = 0.0
        best_plcc = 0.0
        print('Epoch\tTrain_Loss\tTrain_SRCC\tTest_SRCC\tTest_PLCC\tTest_MAE\tTest_MSE')
        for t in range(self.epochs):
            epoch_loss = []
            pred_scores = []
            gt_scores = []

            for img, label in tqdm(self.train_data):
                img = torch.tensor(img.to(device))
                label = torch.tensor(label.to(device))

                self.solver.zero_grad()

                # Generate weights for target network
                paras = self.model_hyper(img)  # 'paras' contains the network weights conveyed to target network

                # Building target network
                model_target = models.TargetNet(paras).to(device)
                for param in model_target.parameters():
                    param.requires_grad = False

                # Quality prediction
                pred = model_target(paras['target_in_vec'])  # while 'paras['target_in_vec']' is the input to target net
                pred_scores = pred_scores + pred.cpu().tolist()
                gt_scores = gt_scores + label.cpu().tolist()

                loss = self.l1_loss(pred.squeeze(), label.float().detach())
                epoch_loss.append(loss.item())
                loss.backward()
                self.solver.step()

            train_srcc, _ = stats.spearmanr(pred_scores, gt_scores)

            test_srcc, test_plcc, test_mae, test_mse = self.test(self.test_data)
            if test_srcc > best_srcc:
                best_srcc = test_srcc
                best_plcc = test_plcc
            
            # save model
            weights_file = "./weight/epoch_{}.pth".format(t+1)
            torch.save({
                'epoch': t,
                'model_hyper_state_dict': self.model_hyper.state_dict(),
                'model_target_state_dict': model_target.state_dict(),
                'optimizer_state_dict': self.solver.state_dict(),
                'loss': sum(epoch_loss) / len(epoch_loss)
            }, weights_file)

            print('%d\t%4.3f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f\t\t%4.4f' %
                  (t + 1, sum(epoch_loss) / len(epoch_loss), train_srcc, test_srcc, test_plcc, test_mae, test_mse))

            # Update optimizer
            lr = self.lr / pow(10, (t // 6))
            if t > 8:
                self.lrratio = 1
            self.paras = [{'params': self.hypernet_params, 'lr': lr * self.lrratio},
                          {'params': self.model_hyper.res.parameters(), 'lr': self.lr}
                          ]
            self.solver = torch.optim.Adam(self.paras, weight_decay=self.weight_decay)

        print('Best test SRCC %f, PLCC %f' % (best_srcc, best_plcc))

        return best_srcc, best_plcc

    def test(self, data):
        """Testing"""
        self.model_hyper.train(False)
        pred_scores = []
        gt_scores = []

        for img, label in data:
            # Data.
            img = torch.tensor(img.to(device))
            label = torch.tensor(label.to(device))

            paras = self.model_hyper(img)
            model_target = models.TargetNet(paras).to(device)
            model_target.train(False)
            pred = model_target(paras['target_in_vec'])

            pred_scores.append(float(pred.item()))
            gt_scores = gt_scores + label.cpu().tolist()

        pred_scores = np.mean(np.reshape(np.array(pred_scores), (-1, self.test_patch_num)), axis=1)
        gt_scores = np.mean(np.reshape(np.array(gt_scores), (-1, self.test_patch_num)), axis=1)
        test_srcc, _ = stats.spearmanr(pred_scores, gt_scores)
        test_plcc, _ = stats.pearsonr(pred_scores, gt_scores)
        test_mse = mean_squared_error(pred_scores, gt_scores)
        test_mae = mae(pred_scores, gt_scores)

        self.model_hyper.train(True)
        return test_srcc, test_plcc, test_mae, test_mse

In [None]:
import pandas as pd
import argparse

class Config:
    dataset = 'bid'
    train_patch_num = 3
    test_patch_num = 3
    lr = 2e-5
    weight_decay = 5e-4
    lr_ratio = 10
    batch_size = 20
    epochs = 5
    patch_size = 224
    train_test_num = 10
    train_path = '../Data/train/'
    test_path = '../Data/test/'
    
config = Config()
config.train_patch_num

In [9]:
device = torch.device("cuda:3")

In [None]:

srcc_all = np.zeros(config.train_test_num, dtype=np.float)
plcc_all = np.zeros(config.train_test_num, dtype=np.float)
sel_num = list(range(0, 29))
print('Training and testing on %s dataset for %d rounds...' % (config.dataset, config.train_test_num))
for i in range(config.train_test_num):
    print('Round %d' % (i+1))
    # Randomly select 80% images for training and the rest for testing
    random.shuffle(sel_num)
    train_index = sel_num[0:int(round(0.8 * len(sel_num)))]
    test_index = sel_num[int(round(0.8 * len(sel_num))):len(sel_num)]

    solver = HyperIQASolver(config, ' ', train_index, test_index)
    srcc_all[i], plcc_all[i] = solver.train()

# print(srcc_all)
# print(plcc_all)
srcc_med = np.median(srcc_all)
plcc_med = np.median(plcc_all)

print('Testing median SRCC %4.4f,\tmedian PLCC %4.4f' % (srcc_med, plcc_med))

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  srcc_all = np.zeros(config.train_test_num, dtype=np.float)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  plcc_all = np.zeros(config.train_test_num, dtype=np.float)


Training and testing on bid dataset for 10 rounds...
Round 1
Epoch	Train_Loss	Train_SRCC	Test_SRCC	Test_PLCC	Test_MAE	Test_MSE


  img = torch.tensor(img.to(device))
  label = torch.tensor(label.to(device))
  0%|▍                                                                                                                                               | 75/26206 [00:23<1:59:37,  3.64it/s]