# In [None]:
import os
from typing import Optional, Sequence

import escnn
import hydra
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timm
import torch
import tqdm
from escnn import gspaces
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from torchvision.transforms import Resize, ToTensor
from torchvision.transforms import functional as F
# Widen numpy's print line width so the model-output arrays printed below
# (per-angle outputs, mean and std vectors) stay on a single line.
np.set_printoptions(linewidth=10000)

# In [None]:
def test_model_single_image(config, x: torch.Tensor, N: int = 4, k: int = 5):
    """Print model outputs on ``N`` rotated copies of one image, for ``k`` fresh models.

    ``k`` models are instantiated from ``config.model`` via Hydra. Each model is
    evaluated on the image rotated by ``r * 360 / N`` degrees for
    ``r = 0 .. N-1``. The outputs of the first model are printed per angle.
    Afterwards the column-wise mean and standard deviation over all ``k * N``
    outputs are printed.

    Args:
        config: Hydra config with a ``model`` node to instantiate.
        x: channel-first image tensor. It is converted to a PIL ``'RGB'`` image,
            so it presumably must be uint8 with 3 channels -- TODO confirm.
        N: number of evenly spaced rotation angles.
        k: number of independently instantiated models.
    """
    img = Image.fromarray(x.cpu().numpy().transpose(1, 2, 0), mode='RGB')

    # Upsample (smaller edge -> 224 px) before rotating to reduce interpolation
    # artifacts on the rotated copies. The rotated image is fed to the model at
    # this resolution; there is no downsampling step afterwards.
    resize = Resize(224)
    totensor = ToTensor()
    img = resize(img)

    # evaluate the model on N rotated versions of the input image
    print()
    print('##########################################################################################')
    header = 'angle  |  ' + '  '.join(["{:5d}".format(d) for d in range(10)])
    print(header)
    # One row per (model, rotation) pair; the header and this buffer assume the
    # model produces 10 output values.
    results = np.zeros(shape=(k * N, 10))
    with torch.no_grad():
        for model_idx in range(k):
            model = hydra.utils.instantiate(config.model)  # Hydra
            model.eval()
            for r in range(N):
                angle = r * 360. / N
                # The reshape pins the input to 3x224x224, i.e. it assumes a square image.
                x_transformed = totensor(img.rotate(angle, Image.BILINEAR)).reshape(3, 224, 224).unsqueeze(0)

                y = model(x_transformed)
                y = y.numpy().squeeze()
                # Unique row per (model, rotation). The previous index `i * r`
                # collided (every r == 0 wrote row 0) and left most rows at
                # zero, which skewed the mean/std below.
                results[model_idx * N + r, :] = y

                if model_idx == 0:
                    print("{:6.1f} : {}".format(angle, y))
    print(f"mean: {results.mean(axis=0)}")
    print(f"std: {results.std(axis=0)}")
    print('##########################################################################################')
    print()


def _unique_figure_path(directory, stem="compare_equivariance", ext=".png"):
    """Return a path inside ``directory`` whose file name is not taken yet.

    Tries ``<stem><ext>`` first, then ``<stem>1<ext>``, ``<stem>2<ext>``, ...
    ``directory`` is created if it does not exist.
    """
    os.makedirs(directory, exist_ok=True)
    existing = set(os.listdir(directory))
    name = f"{stem}{ext}"
    i = 0
    while name in existing:
        i += 1
        name = f"{stem}{i}{ext}"
    return os.path.join(directory, name)


def compare_two_models_relative_errors(
    x,
    model,
    eq_model,
    W=225,
    save_dir="/home/adrian/pdp/equivariant-map-gen/equivariantgan/storage/figures/",
):
    """Plot and save the rotation-equivariance error of an escnn model vs. a conventional CNN.

    Both models' ``forward_features`` are evaluated on ``x`` and on its
    transforms by the pure rotations ``G.element((0, i))``, ``i = 0 .. N-1``,
    of ``gspaces.flipRot2dOnR2(N)`` (N = 4). For every rotation ``g`` the
    relative error ``||F(g.x) - g.F(x)|| / ||F(x)||`` is computed for each
    model. The two error curves are plotted and saved as a PNG with a
    non-clashing file name in ``save_dir``.

    Args:
        x: input image tensor. It is resized with ``Resize(W)`` and wrapped by
            ``eq_model.in_type``, so presumably (B, C, H, W) -- TODO confirm.
        model: conventional network exposing ``forward_features(tensor)``.
        eq_model: escnn network exposing ``in_type`` and
            ``forward_features(GeometricTensor) -> GeometricTensor``.
        W: size passed to ``Resize`` before evaluation.
        save_dir: directory the figure is written to (created if missing).

    Returns:
        ``(error_equivariant, error_conventional)``: one relative error per rotation.
    """
    N = 4
    r2_act = gspaces.flipRot2dOnR2(N)
    G = r2_act.fibergroup

    x = Resize(W)(x)
    # Placeholder mask (all ones), kept so a centre/disk mask can be dropped in.
    input_center_mask = torch.ones_like(x)
    x = eq_model.in_type(x * input_center_mask)

    # F(x) for both models
    y_equivariant = eq_model.forward_features(x)
    y_conventional = model.forward_features(x.tensor)

    # eq_model returns a GeometricTensor (it is .transform()ed below), so its
    # mask is built from the underlying .tensor. The conventional output gets its
    # own mask because its channel count need not match the equivariant one.
    eq_output_mask = torch.ones_like(y_equivariant.tensor)
    conv_output_mask = torch.ones_like(y_conventional)

    error_equivariant = []
    error_conventional = []

    for i in tqdm.tqdm(range(N)):
        g = G.element((0, i))
        angle_deg = i * 360.0 / N

        # rotate (and mask) the input
        x_transformed = x.transform(g)
        x_transformed.tensor *= input_center_mask

        # F(g.x): feed the transformed input to both models
        y_from_x_transformed_equivariant = eq_model.forward_features(x_transformed).tensor * eq_output_mask
        y_from_x_transformed_conventional = model.forward_features(x_transformed.tensor) * conv_output_mask

        # g.F(x): transform the outputs of both models.
        # torchvision's rotate expects DEGREES (positive = counter-clockwise);
        # the previous radian argument rotated by ~1.6 deg instead of 90 deg.
        # NOTE(review): assumes this matches the direction escnn uses for
        # G.element((0, i)) -- TODO confirm.
        y_transformed_from_x_equivariant = y_equivariant.transform(g).tensor * eq_output_mask
        y_transformed_from_x_conventional = F.rotate(y_conventional, angle_deg) * conv_output_mask

        # relative error of both models
        error_equivariant.append(
            torch.norm(y_from_x_transformed_equivariant - y_transformed_from_x_equivariant).item()
            / torch.norm(y_equivariant.tensor).item()
        )
        error_conventional.append(
            torch.norm(y_from_x_transformed_conventional - y_transformed_from_x_conventional).item()
            / torch.norm(y_conventional).item()
        )

    # plot the error of both models as a function of the rotation angle theta
    fig, ax = plt.subplots(figsize=(10, 6))
    xs = [i / N * 2 * np.pi for i in range(N)]
    ax.plot(xs, error_equivariant, label='SO(2)-Steerable CNN')
    ax.plot(xs, error_conventional, label='Conventional CNN')
    ax.set_title('Equivariant vs Conventional CNNs', fontsize=20)
    ax.set_xlabel(r'$g = r_\theta$', fontsize=20)
    ax.set_ylabel('Equivariance Error', fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=15)
    ax.legend(fontsize=20)
    fig.savefig(_unique_figure_path(save_dir))

    return error_equivariant, error_conventional


def _baseline_regnet_config():
    """Return the Hydra model config of the non-equivariant timm RegNet baseline."""
    # Hydra >= 1.1 (implied by `_recursive_` / `_convert_` below) only recognises
    # the `_target_` key; plain "target" entries would not be instantiated.
    return {
        "_target_": "classification.models.regnety.RegNetTimm",
        "num_classes": 10,
        "class_weights": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        "model": {
            "_target_": "timm.models.regnet.RegNet",
            "cfg": {
                "_target_": "classification.models.e2_regnety_simple.RegNetCfg",
                "w0": 16,
                "wa": 7.5,
                "wm": 1.83,
                "group_size": 8,
                "depth": 16,
                "se_ratio": 0.25
            },
            "in_chans": 3,
            "output_stride": 32,
            "global_pool": "avg",
            "drop_rate": 0.0,
            "drop_path_rate": 0.0,
            "zero_init_last": True,
        }
    }


def _test(config: DictConfig) -> None:
    """
    Quick check, whether the model works.

    Builds the configured (equivariant) model and a conventional RegNet baseline,
    which shares the rest of the lightning config, then runs
    ``compare_two_models_relative_errors`` on the first test-loader sample.
    ``config`` is not modified.
    """
    datamodule = hydra.utils.instantiate(config.datamodule, _recursive_=False)
    test_loader = datamodule.test_dataloader()
    # NOTE(review): squeeze() drops every singleton dim, including a batch dim of
    # size 1, while eq_model.in_type presumably expects (B, C, H, W) -- TODO confirm.
    x = next(iter(test_loader))["input"].squeeze()
    lightning_model = hydra.utils.instantiate(config.lightning_model, _recursive_=False, _convert_="partial")
    eq_model = lightning_model.model.cpu()
    eq_model.eval()

    # Deep, resolved copy: the previous shallow dict() still shared the nested
    # `lr_scheduler` node, so setting t_initial mutated the caller's config.
    lightning_config = OmegaConf.to_container(config.lightning_model, resolve=True)
    lightning_config["model"] = _baseline_regnet_config()
    lightning_config["lr_scheduler"]["t_initial"] = 50
    model = hydra.utils.instantiate(lightning_config, _recursive_=False, _convert_="partial").model.cpu()
    model.eval()

    with torch.no_grad():
        compare_two_models_relative_errors(x, model, eq_model)
    

if __name__ == "__main__":
    # The previous guard misspelled the dunders (`_name_` / "_main_"), which is
    # a NameError, and called `_test()` without its required `config`, which is
    # a TypeError.
    # TODO(review): compose the project's Hydra config (e.g. via hydra.initialize
    # + hydra.compose with the right config_path / config_name) and call _test(cfg).
    raise SystemExit("_test(config) requires a Hydra DictConfig: compose one and call _test(cfg).")