In [2]:
import torch
from einops import rearrange
import json
import os
import sys
import inspect
from activations.torch import Rational
from clip.model import Bottleneck


from pathlib import Path
# from .find_init_weights import find_weights
# from .warnings import RationalImportError
import numpy as np
from termcolor import colored
def _relu(x):
    """ReLU for a single value: anything below zero becomes 0."""
    if x < 0:
        return 0
    return x


def _leaky_relu(x):
    """Leaky ReLU for a single value, with a negative slope of 1/100."""
    if x < 0:
        return x/100
    return x


def _normal(x):
    """Density of the standard normal distribution at x."""
    return 1/np.sqrt(2*np.pi) * np.exp(-.5*x**2)


# Reference functions looked up by name; "lrelu" is an alias of "leaky_relu".
known_functions = {
    "relu": _relu,
    "leaky_relu": _leaky_relu,
    "lrelu": _leaky_relu,
    "normal": _normal,
}
from magma.image_prefix import get_image_encoder

Torch distributed Default Port: 29500


In [48]:
# Build the "clip_RN50" image encoder with convert_to_rational=True, then move
# it to the GPU with index 5.
enc = get_image_encoder("clip_RN50", convert_to_rational=True)
enc = enc.to("cuda:5")

In [4]:
def freeze_rational_clip(enc):
    """Freeze `enc` except for Rational / BatchNorm2d layers inside Bottlenecks.

    Walks the direct children of `enc`:

    * a child that is not a ``torch.nn.Sequential`` is frozen entirely;
    * inside a ``Sequential``, an entry that is not a ``Bottleneck`` is frozen
      entirely;
    * inside a ``Bottleneck``, each direct child that is a ``Rational`` or a
      ``torch.nn.BatchNorm2d`` gets ``requires_grad = True`` on all of its
      parameters, and every other direct child is frozen (including any
      container child, together with whatever it contains).

    Parameters registered directly on `enc` (not on a child module) are left
    untouched.

    Args:
        enc: the module to modify in place.

    Returns:
        The same `enc` object, so the call can be chained.
    """
    def _set_requires_grad(module, flag):
        # Applies to every parameter below `module`, recursively.
        for weight in module.parameters():
            weight.requires_grad = flag

    for stage in enc.children():
        if not isinstance(stage, torch.nn.Sequential):
            _set_requires_grad(stage, False)
            continue

        for block in stage.children():
            if not isinstance(block, Bottleneck):
                _set_requires_grad(block, False)
                continue

            for layer in block.children():
                keep_trainable = isinstance(layer, (Rational, torch.nn.BatchNorm2d))
                _set_requires_grad(layer, keep_trainable)

    return enc

In [43]:
# Apply the freezing policy of freeze_rational_clip (which returns the same
# encoder), then cast the whole encoder to half precision.
enc = freeze_rational_clip(enc)
enc = enc.to(dtype=torch.float16)

In [49]:
# Print the dtype of every parameter to see which ones are float16 and which
# are float32.
# TODO idea: swap the BatchNorm2d layers for LayerNorm; that needs a walk over
# enc.named_modules(), since named_parameters() only yields tensors.
for name, param in enc.named_parameters():
    print(name, param.dtype)
       

conv1.weight torch.float16
bn1.weight torch.float32
bn1.bias torch.float32
conv2.weight torch.float16
bn2.weight torch.float32
bn2.bias torch.float32
conv3.weight torch.float16
bn3.weight torch.float32
bn3.bias torch.float32
layer1.0.conv1.weight torch.float16
layer1.0.bn1.weight torch.float32
layer1.0.bn1.bias torch.float32
layer1.0.relu1.numerator torch.float32
layer1.0.relu1.denominator torch.float32
layer1.0.conv2.weight torch.float16
layer1.0.bn2.weight torch.float32
layer1.0.bn2.bias torch.float32
layer1.0.relu2.numerator torch.float32
layer1.0.relu2.denominator torch.float32
layer1.0.conv3.weight torch.float16
layer1.0.bn3.weight torch.float32
layer1.0.bn3.bias torch.float32
layer1.0.relu3.numerator torch.float32
layer1.0.relu3.denominator torch.float32
layer1.0.downsample.0.weight torch.float16
layer1.0.downsample.1.weight torch.float32
layer1.0.downsample.1.bias torch.float32
layer1.1.conv1.weight torch.float16
layer1.1.bn1.weight torch.float32
layer1.1.bn1.bias torch.float32


In [45]:
# Put the encoder on GPU 5 and create one random 3x224x224 input. The noise is
# drawn on the CPU first and only then moved/cast to float16 on the GPU.
enc = enc.to("cuda:5")
x = torch.randn(1, 3, 224, 224)
x = x.to("cuda:5", dtype=torch.float16)


In [46]:
# Forward pass on the random half-precision input; the recorded output of this
# cell is entirely NaN.
enc(x)

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:5',
       dtype=torch.float16, grad_fn=<ReshapeAliasBackward0>)

In [None]:
def get_parameters(rational_version, degrees, approx_func, k=None):
    """Return the initial (numerator, denominator) weights of a rational function.

    "identity" and "ones" are answered directly. Every other `approx_func` is
    looked up in ``../rationals_config.json`` (relative to this file's
    directory, or to the current working directory when ``__file__`` is not
    defined, as in a notebook kernel) under the key
    ``Rational_version_<VERSION><nd>/<dd>`` (plus ``_k_<k>`` for version
    "rare"). If the entry is missing and `approx_func` is one of
    `known_functions`, `find_weights` is called once to compute it and the
    config file is read again.

    Args:
        rational_version: version name; upper-cased to build the config key.
        degrees: pair ``(nd, dd)`` of numerator and denominator degrees.
        approx_func: name of the function to approximate.
        k: only used in the config key when `rational_version` is "rare".

    Returns:
        Tuple ``(init_w_numerator, init_w_denominator)``.

    Raises:
        KeyError: if the weights are not in the config file and cannot be
            computed because `approx_func` is not in `known_functions`, or if
            they are still missing after `find_weights` ran.
    """
    nd, dd = degrees
    if approx_func == "identity":
        return [0., 1.] + [0.] * (nd - 1), [0.] * dd
    elif approx_func == "ones":
        return [1.] * (nd + 1), [1.] * dd

    rational_full_name = f"Rational_version_{rational_version.upper()}{nd}/{dd}"
    if rational_version.lower() == "rare":
        rational_full_name += f"_k_{k}"

    config_file = '../rationals_config.json'
    try:
        config_file_dir = str(Path(os.path.abspath(__file__)).parent)
    except NameError:
        # No __file__ here (e.g. code pasted into a notebook cell): resolve the
        # config relative to the working directory instead of crashing.
        config_file_dir = os.getcwd()
    config_path = os.path.join(config_file_dir, config_file)
    url = "https://rational-activations.readthedocs.io/en/latest/tutorials/tutorials.1_find_weights_for_initialization.html"

    def load_config():
        with open(config_path) as json_file:
            return json.load(json_file)

    def has_entry(config):
        return rational_full_name in config and approx_func in config[rational_full_name]

    rationals_dict = load_config()
    if not has_entry(rationals_dict):
        func_key = approx_func.lower()
        is_known = func_key in known_functions

        # The padding reproduces the text the messages have always contained.
        if is_known:
            msg = (f"Found {approx_func} but haven't computed its rational approximation yet for degrees {degrees}."
                   "            \nLet's do do it now:. \n--> More info:")
        elif rational_full_name not in rationals_dict:
            msg = (f"{rational_full_name} approximating \"{approx_func}\" not found in {config_file}."
                   "            \nWe need to add it.\nLet's do do it now. \n--> More info:")
        else:
            msg = (f"{rational_full_name} approximating {approx_func} not found in {config_file}."
                   "            \nWe need to add it.\nLet's do do it now. \n--> More info:")
        print(colored(msg, "yellow"))
        print(colored(url, "blue"))

        if not is_known:
            # There is no reference implementation to fit, so fail with an
            # explanation instead of an opaque lookup error.
            raise KeyError(
                f"No initial weights for \"{approx_func}\" under {rational_full_name} "
                f"in {config_file}, and \"{approx_func}\" is not one of the known "
                f"functions {sorted(known_functions)}. Compute them first, see {url}"
            )

        find_weights(known_functions[func_key], function_name=func_key, degrees=degrees, version=rational_version)
        rationals_dict = load_config()
        if not has_entry(rationals_dict):
            raise KeyError(
                f"find_weights ran but \"{approx_func}\" is still missing under "
                f"{rational_full_name} in {config_file}"
            )

    params = rationals_dict[rational_full_name][approx_func]
    return params["init_w_numerator"], params["init_w_denominator"]

In [None]:
def Rational_PYTORCH_A_F(x, weight_numerator, weight_denominator, device, pre, post):
    """Evaluate a rational function P(x) / Q(x) element-wise on `x`.

    With ``N = max(len(weight_numerator), len(weight_denominator))`` and the
    powers ``x^0 .. x^(N-1)`` of every element:

        P(x) = sum_i weight_numerator[i] * x^i
        Q(x) = sum_i | d[i] * x^i |,   d = cat(pre, weight_denominator, post)

    When `pre` is ``[1.]`` this is the form
    ``1 + |b_1 * x| + ... + |b_m * x^m|``.

    Args:
        x: input tensor of any shape; it does not have to be contiguous.
        weight_numerator: 1-D tensor of numerator coefficients, lowest power
            first. Must broadcast against the N power columns.
        weight_denominator: 1-D tensor of denominator coefficients.
        device: unused; kept so existing callers keep working.
        pre: 1-D tensor placed before `weight_denominator`.
        post: 1-D tensor placed after `weight_denominator`; together the three
            pieces must broadcast against the N power columns.

    Returns:
        Tensor with the same shape as `x`.
    """
    len_num, len_deno = len(weight_numerator), len(weight_denominator)

    # reshape instead of view: also accepts non-contiguous inputs (transposed,
    # permuted, sliced) and is identical for contiguous ones.
    z = x.reshape(-1)

    # Row j holds z[j]**0, z[j]**1, ..., z[j]**(N-1).
    xps = torch.vander(z, N=max(len_num, len_deno), increasing=True)

    numerator = xps.mul(weight_numerator).sum(-1)

    expanded_dw = torch.cat([pre, weight_denominator, post])
    denominator = xps.mul(expanded_dw).abs().sum(-1)

    return numerator.div(denominator).view(x.shape)

In [None]:
# Small example input: the values 1..9 laid out as a 1x1x3x3 float tensor,
# plus three numerator coefficients 1, 2, 3.
x = torch.arange(1.0, 10.0).reshape(1, 1, 3, 3)
weight_numerator = torch.arange(1.0, 4.0)