In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
from dotmap import DotMap
import wandb
from wandb.wandb_run import Run
import pickle
from PIL import Image
from pathlib import Path
from typing import List, Tuple
from datetime import datetime
from einops import rearrange
from sklearn.linear_model import Ridge
from scipy import stats
import argparse
from tqdm import tqdm
import os
import torchvision.transforms as transforms

from data import SCINDataset

In [2]:
def train_linear_regressor_on_dataset(model: nn.Module,
                                      dataset: Dataset,
                                      batch_size: int,
                                      num_workers: int,
                                      device: torch.device,
                                      save_dir: str,
                                      num_splits: int = 10,
                                      alpha: "float | None" = None):
    """Fit a Ridge regressor mapping ARNIQA embeddings to the dataset's MOS values.

    Embeddings are extracted once for the whole dataset (see ``get_features_scores``).
    The Ridge ``alpha`` is chosen by grid search over ``num_splits`` predefined
    train/val splits, unless ``alpha`` is given. A final regressor is then fitted on
    every crop of every image. The fitted regressor is pickled to
    ``scin_regressor_{alpha}.pkl`` in the current working directory.

    Args:
        model: ARNIQA model used to compute embeddings.
        dataset: Dataset yielding ``img``/``img_ds``/``mos`` batches and exposing
            ``get_split_indices(split, phase)`` (used by the grid search).
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.
        device: Device the images are moved to before the forward pass.
        save_dir: Directory where the extracted features/scores are saved as .npy.
        num_splits: Number of predefined splits evaluated by the grid search
            (defaults to 10, the previously hard-coded value).
        alpha: If given, use this Ridge regularisation strength and skip the (slow)
            grid search.

    Returns:
        The fitted ``sklearn.linear_model.Ridge`` instance.

    Raises:
        ValueError: If the number of extracted feature rows does not equal
            ``len(dataset) * 5``.
    """
    num_crops = 5  # rows per image; must match the per-image crops emitted by get_features_scores

    # shuffle=False keeps feature rows in dataset order, which the row arithmetic below relies on.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)

    features, scores = get_features_scores(model, dataloader, device, save_dir)

    if alpha is None:
        # Cross-validate alpha on the dataset's predefined train/val splits.
        best_alpha = alpha_grid_search(dataset=dataset, features=features, scores=scores,
                                       num_splits=num_splits)
        print(f'Best alpha found: {best_alpha}')
    else:
        best_alpha = alpha
        print(f'Using fixed alpha: {best_alpha}')

    # Crops of image k occupy rows [k * num_crops, (k + 1) * num_crops), so training on
    # every crop of every image means using all len(dataset) * num_crops rows.
    num_rows = len(dataset) * num_crops
    if features.shape[0] != num_rows:
        raise ValueError(
            f"Expected {num_rows} feature rows ({len(dataset)} images x {num_crops} crops), "
            f"got {features.shape[0]}")

    regressor = Ridge(alpha=best_alpha).fit(features, scores)

    # Saved to the current working directory (not save_dir): later cells reload it
    # from there by file name.
    with open(f"scin_regressor_{best_alpha}.pkl", "wb") as f:
        pickle.dump(regressor, f)

    return regressor

In [6]:
def get_features_scores(model: nn.Module,
                        dataloader: DataLoader,
                        device: torch.device,
                        save_dir: str,
                        use_cache: bool = False) -> Tuple[np.ndarray, np.ndarray]:
    """Extract ARNIQA embeddings and MOS targets for every crop in ``dataloader``.

    Each batch provides ``img`` and ``img_ds`` tensors laid out as
    (batch, crops, C, H, W), which the ``rearrange`` below requires, plus a
    per-image ``mos``. Crops are flattened into the batch dimension. The returned
    arrays therefore have one row per crop, the crops of each image are stored
    contiguously, and each image's MOS is repeated once per crop.

    Both arrays are written to ``save_dir/features.npy`` and ``save_dir/scores.npy``.

    Args:
        model: ARNIQA model, called as ``model(img, img_ds, return_embedding=True)``;
            ``model.encoder.feat_dim`` is used for the empty-result shape.
        dataloader: Loader over the dataset. It should not shuffle if rows are later
            matched to dataset indices.
        device: Device the images are moved to.
        save_dir: Directory for the .npy outputs (and the cache when enabled).
        use_cache: If True and both .npy files exist, load and return them instead of
            recomputing. Defaults to False (always recompute), matching the previous
            behaviour where the cache check was disabled.

    Returns:
        ``(features, scores)`` as float64 arrays of shape
        ``(total_crops, model.encoder.feat_dim * 2)`` and ``(total_crops,)``.
    """
    feats_file = os.path.join(save_dir, "features.npy")
    scores_file = os.path.join(save_dir, "scores.npy")

    if use_cache and os.path.exists(feats_file) and os.path.exists(scores_file):
        feats = np.load(feats_file)
        scores = np.load(scores_file)
        print('Loaded features and scores from saved file in SCIN.')
        return feats, scores

    # Accumulate per-batch chunks and concatenate once: concatenating inside the
    # loop re-copies everything seen so far on every batch (quadratic).
    feat_chunks: List[np.ndarray] = []
    score_chunks: List[np.ndarray] = []

    # One progress step per batch, so the total is the number of batches.
    # (It was len(dataloader) * batch_size before, which the bar never reached.)
    with tqdm(total=len(dataloader), desc="Extracting features", leave=False) as progress_bar:
        for batch in dataloader:
            img_orig = batch["img"].to(device)
            img_ds = batch["img_ds"].to(device)
            num_crops = img_orig.shape[1]

            # (b, n, c, h, w) -> (b*n, c, h, w): the crops of one image stay adjacent.
            img_orig = rearrange(img_orig, "b n c h w -> (b n) c h w")
            img_ds = rearrange(img_ds, "b n c h w -> (b n) c h w")
            mos = batch["mos"].repeat_interleave(num_crops)  # one target per crop

            with torch.cuda.amp.autocast(), torch.no_grad():
                _, f = model(img_orig, img_ds, return_embedding=True)

            # float64 matches the dtype the previous np.zeros-based accumulator produced.
            feat_chunks.append(f.cpu().numpy().astype(np.float64))
            score_chunks.append(mos.numpy().astype(np.float64))
            progress_bar.update(1)

    if feat_chunks:
        feats = np.concatenate(feat_chunks, axis=0)
        scores = np.concatenate(score_chunks, axis=0)
    else:
        # Empty loader: keep the documented 2-D / 1-D shapes.
        feats = np.zeros((0, model.encoder.feat_dim * 2))  # original + downsampled features
        scores = np.zeros(0)

    np.save(feats_file, feats)
    np.save(scores_file, scores)
    return feats, scores

In [4]:
# Exploration only: a wide alpha grid given as (min_alpha, max_alpha, num_points) for np.geomspace.
# NOTE(review): this global is not read by alpha_grid_search, which uses its own local range.
grid_search_range = [
    1e-3,  # smallest alpha
    1e3,   # largest alpha
    100,   # number of geometrically spaced alphas
]
grid_search_range

[0.001, 1000.0, 100]

In [5]:
# Exploration only: with num_points=1, np.geomspace yields just the start value (alpha=0.1).
# NOTE(review): this global is not read by alpha_grid_search, which uses its own local range.
grid_search_range = [
    1e-1,  # smallest alpha
    1e1,   # largest alpha
    1,     # number of geometrically spaced alphas
]
grid_search_range

[0.1, 10.0, 1]

In [4]:
def alpha_grid_search(dataset: Dataset,
                      features: np.ndarray,
                      scores: np.ndarray,
                      num_splits: int,
                      grid_search_range: Tuple[float, float, int] = (1e-2, 1e2, 10),
                      num_crops: int = 5) -> float:
    """Pick the Ridge ``alpha`` with the highest median validation SROCC across splits.

    For each split ``i`` in ``range(num_splits)``:

    1. The dataset's predefined train/val image indices are expanded to per-crop rows.
    2. A Ridge regressor is fitted on the training crops for each candidate alpha.
    3. Each validation image is scored by the mean prediction over its crops.

    Spearman's rank correlation (SROCC) between those scores and the MOS is collected
    per alpha. The alpha with the highest median SROCC is returned.

    Args:
        dataset: Must provide ``get_split_indices(split=i, phase="train" | "val")``.
        features: Per-crop embeddings; the crops of image k are in rows
            ``k * num_crops`` .. ``k * num_crops + num_crops - 1``.
        scores: Per-crop MOS values, same row layout as ``features``.
        num_splits: Number of splits to evaluate.
        grid_search_range: ``(min_alpha, max_alpha, num_alphas)`` passed to
            ``np.geomspace`` (endpoints included). Other ranges tried earlier were
            ``(1e-3, 1e3, 100)`` and ``(1e-1, 1e1, 1)``.
        num_crops: Crops per image, i.e. rows per image in ``features``/``scores``.

    Returns:
        The selected alpha. Splits whose SROCC is undefined (NaN, e.g. constant
        predictions) are ignored. An alpha with no valid split is ranked last.
    """
    alphas = np.geomspace(*grid_search_range, endpoint=True)
    crop_offsets = np.arange(num_crops)

    def to_crop_rows(image_indices) -> np.ndarray:
        # Image index k -> rows k*num_crops + [0, num_crops), preserving image order.
        # asarray guards against list indices, where `* num_crops` would repeat the list.
        image_indices = np.asarray(image_indices)
        return (image_indices[:, None] * num_crops + crop_offsets).ravel()

    def median_ignoring_nan(values) -> float:
        # np.argmax treats NaN as the maximum, so a NaN median would win the search.
        arr = np.asarray(values, dtype=float)
        arr = arr[~np.isnan(arr)]
        return float(np.median(arr)) if arr.size else -np.inf

    srocc_all = [[] for _ in alphas]
    with tqdm(total=num_splits * len(alphas), desc='Grid Search', unit='split') as pbar:
        for split in range(num_splits):
            train_rows = to_crop_rows(dataset.get_split_indices(split=split, phase="train"))
            val_rows = to_crop_rows(dataset.get_split_indices(split=split, phase="val"))

            train_features = features[train_rows]
            train_scores = scores[train_rows]
            val_features = features[val_rows]
            # The MOS is repeated for every crop; keep one value per image.
            val_scores = scores[val_rows][::num_crops]

            for idx, alpha in enumerate(alphas):
                regressor = Ridge(alpha=alpha).fit(train_features, train_scores)
                preds = regressor.predict(val_features)
                # Average the predictions of the crops of the same image.
                preds = preds.reshape(-1, num_crops).mean(axis=1)
                srocc_all[idx].append(stats.spearmanr(preds, val_scores)[0])
                pbar.update(1)

    srocc_all_median = np.array([median_ignoring_nan(srocc) for srocc in srocc_all])
    best_alpha = alphas[int(np.argmax(srocc_all_median))]

    return best_alpha

In [5]:
# Always a torch.device. Previously the CPU fallback was the plain string "cpu",
# which disagreed with the `device: torch.device` annotations used above.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained ARNIQA from the miccunifi/ARNIQA GitHub repo via torch.hub.
# NOTE: torch.hub executes code from that repository; a local copy is reused
# when cached (see the "Using cache found" message below).
arniqa = torch.hub.load(repo_or_dir="miccunifi/ARNIQA", source="github", model="ARNIQA")
arniqa.eval().to(DEVICE)
next(arniqa.parameters()).is_cuda  # sanity check: True when the weights live on the GPU

Using cache found in /home/jovyan/.cache/torch/hub/miccunifi_ARNIQA_main


True

In [7]:
# --- Configuration ------------------------------------------------------------
# Data
DATA_PATH = "SCIN_v2"   # dataset root; features.npy / scores.npy are also saved here
CROP_SIZE = 224         # passed to SCINDataset as crop_size

# Loading
BATCH_SIZE = 4          # DataLoader batch size
NUM_WORKERS = 1         # DataLoader worker processes

# Regressor fitting
NUM_SPLITS = 10         # NOTE(review): not passed to any call in the visible cells (num_splits=10 is hard-coded)
GRID_SEARCH = True      # NOTE(review): not read by any visible cell
# ALPHA = 0.151991108295293 # For simplicity, a fixed alpha value is used

In [8]:
# Dataset: every SCIN image (phase="all"). The train/val subsets are picked
# later from this single dataset via get_split_indices during the grid search.
scin_dataset = SCINDataset(
    root=DATA_PATH,
    phase="all",
    crop_size=CROP_SIZE,
)

In [None]:
# Extract embeddings, grid-search alpha and fit the final Ridge regressor.
# Slow: a previous run's progress bar estimated roughly 40 minutes for the grid search alone.
regressor = train_linear_regressor_on_dataset(
    model=arniqa,
    dataset=scin_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    device=DEVICE,
    save_dir=DATA_PATH,
)

Grid Search:  10%|█         | 10/100 [03:59<36:31, 24.35s/split]        

# Test Image

In [6]:
# Reload a regressor pickled by train_linear_regressor_on_dataset; the file name
# embeds the alpha selected by the grid search.
# NOTE: unpickling executes arbitrary code, so only load files you produced yourself.
regressor = pickle.loads(Path("scin_regressor_0.21544346900318834.pkl").read_bytes())

In [8]:
def display_image(image_path):
    """Print an image's pixel dimensions and render it inline in the notebook.

    Args:
        image_path: Either an already-opened PIL image or a filesystem path
            (``str`` / ``os.PathLike``) to an image file, which is opened with PIL
            first. The original version only worked with PIL images despite the
            parameter name; passing a path raised ``AttributeError``.
    """
    if isinstance(image_path, (str, os.PathLike)):
        image = Image.open(image_path)
    else:
        image = image_path
    width, height = image.size  # PIL reports (width, height)
    print(f"Image size: {width}x{height} pixels")
    display(image)  # IPython's rich display (available as a builtin inside Jupyter)

In [7]:
# Load one test image at full and half resolution; ARNIQA is called with both scales.
img_path = "../test_images/fig5c.jpg"  # alternative: "../test_images/fig4e.png"

# Tensor conversion + normalisation with the standard ImageNet channel mean/std.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

pil_img = Image.open(img_path).convert("RGB")
width, height = pil_img.size                                         # PIL: (width, height)
pil_img_ds = transforms.Resize((height // 2, width // 2))(pil_img)   # Resize: (height, width)

# Add a batch dimension and move both scales to the model's device.
img = preprocess(pil_img).unsqueeze(0).to(DEVICE)
img_ds = preprocess(pil_img_ds).unsqueeze(0).to(DEVICE)

In [8]:
# Scoring mode for the test image:
#   0 -> SCIN-trained Ridge regressor applied to ARNIQA embeddings
#   1 -> the model's own score (return_embedding=False)
# NOTE(review): mode 1 uses whatever arniqa.regressor currently is. A later cell
# replaces it with a CustomRegressor, so the result depends on execution order.
test = 0

if test == 0:
    with torch.no_grad(), torch.cuda.amp.autocast():
        _, features = arniqa(img, img_ds, return_embedding=True, scale_score=True)
        features = features.cpu().numpy()
    predicted_score = regressor.predict(features)
    print("Predicted MOS Score:", predicted_score)
elif test == 1:
    with torch.no_grad(), torch.cuda.amp.autocast():
        predicted_score = arniqa(img, img_ds, return_embedding=False, scale_score=True)
    print(f"Image quality score: {predicted_score.item()}")
else:
    # Previously an unknown mode silently printed nothing.
    raise ValueError(f"test must be 0 or 1, got {test!r}")

Predicted MOS Score: [2.66327731]


In [11]:
# Learned Ridge weights, one per embedding dimension (the targets are a single MOS column).
regressor.coef_

array([-2.47138855, -1.53796235, -0.64137025, ...,  0.46897656,
        2.30539767,  2.18712411])

In [12]:
# Learned Ridge bias term.
regressor.intercept_

3.025981603865405

In [13]:
# Regularisation strength the loaded regressor was fitted with (also encoded in its pickle file name).
regressor.alpha

0.21544346900318834

In [10]:
# Inspect ARNIQA's original regressor head before it is swapped for the SCIN-trained one further down.
arniqa.regressor

RecursiveScriptModule(original_name=TorchLinearRegression)

In [22]:
class CustomRegressor(nn.Module):
    """Wrap a fitted scikit-learn regressor so it can be used as a ``torch.nn.Module``.

    The wrapped object only needs a ``predict(np.ndarray) -> np.ndarray`` method
    (e.g. ``sklearn.linear_model.Ridge``). The module has no trainable parameters
    and is not differentiable: inputs are detached and evaluated through NumPy.
    """

    def __init__(self, regressor):
        super().__init__()
        self.regressor = regressor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``regressor.predict(x)`` as a tensor on the same device as ``x``.

        Args:
            x: Feature tensor in the layout ``predict`` expects, presumably
               (batch, features). NOTE(review): confirm ARNIQA passes 2-D embeddings.

        Returns:
            The predictions, with the dtype produced by ``predict`` (float64 for a
            Ridge fitted on float64 features), placed on ``x.device``. The previous
            version always returned a CPU tensor, so it broke callers that combine
            the score with GPU tensors.
        """
        # detach(): Tensor.numpy() refuses tensors that require grad.
        features = x.detach()
        if features.dtype in (torch.float16, torch.bfloat16):
            # bfloat16 has no NumPy dtype; upcasting either half type to float32 is exact.
            features = features.float()
        preds = self.regressor.predict(features.cpu().numpy())
        return torch.as_tensor(preds, device=x.device)

# Wrap the SCIN-trained scikit-learn Ridge in a Module and install it in place of
# ARNIQA's built-in regressor head (the cell below checks the swap).
custom_regressor = CustomRegressor(regressor=regressor)
arniqa.regressor = custom_regressor

In [23]:
# Confirm the swap: the regressor head should now be the CustomRegressor wrapper.
arniqa.regressor

CustomRegressor()