In [1]:
from torchvision.models import efficientnet as efn
import torch
import torch.nn as nn

In [2]:
import os

# from src.modules import *
from src.data_handler import EyeFair

# from train_glaucoma_fair_fin import train, validation, Identity_Info

from fairlearn.metrics import *

# imb_info = Identity_Info()

In [3]:
efn._efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)

([MBConvConfig(expand_ratio=1, kernel=3, stride=1, input_channels=32, out_channels=16, num_layers=2, block=<class 'torchvision.models.efficientnet.MBConv'>),
  MBConvConfig(expand_ratio=6, kernel=3, stride=2, input_channels=16, out_channels=24, num_layers=3, block=<class 'torchvision.models.efficientnet.MBConv'>),
  MBConvConfig(expand_ratio=6, kernel=5, stride=2, input_channels=24, out_channels=40, num_layers=3, block=<class 'torchvision.models.efficientnet.MBConv'>),
  MBConvConfig(expand_ratio=6, kernel=3, stride=2, input_channels=40, out_channels=80, num_layers=4, block=<class 'torchvision.models.efficientnet.MBConv'>),
  MBConvConfig(expand_ratio=6, kernel=5, stride=1, input_channels=80, out_channels=112, num_layers=4, block=<class 'torchvision.models.efficientnet.MBConv'>),
  MBConvConfig(expand_ratio=6, kernel=5, stride=2, input_channels=112, out_channels=192, num_layers=5, block=<class 'torchvision.models.efficientnet.MBConv'>),
  MBConvConfig(expand_ratio=6, kernel=3, stride=1

In [4]:
# efn.efficientnet_b1().features[0][2] = nn.Hardswish(inplace=True)
data_dir = "../quant_notes/data_cmpr"
image_size = 200
attribute_type = 'race' 
modality_types = 'rnflt'
task = 'cls'
trn_dataset = EyeFair(
    os.path.join(data_dir, "train"),
    modality_type=modality_types,
    task=task,
    resolution=image_size,
    attribute_type=attribute_type,
)
batch_size = 6
validation_dataset_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, drop_last=False)

min: -31.9900, max: 2.2700


In [6]:
with torch.no_grad():
        for i, (input, target, attr) in enumerate(validation_dataset_loader):
            input = input
            target = target
            attr = attr
            break

input.shape, target, attr

(torch.Size([6, 1, 200, 200]),
 tensor([1., 1., 1., 1., 0., 1.]),
 tensor([2, 1, 0, 1, 1, 0], dtype=torch.int32))

In [7]:
vf_predictor = efn.efficientnet_b1()
vf_predictor.features[0][0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

In [8]:
vf_predictor(input)

tensor([[-0.0710, -0.0810,  0.2184,  ...,  0.4250,  0.1929,  0.1946],
        [-0.2074,  0.0270,  0.0295,  ...,  0.0024, -0.0803,  0.0370],
        [-0.0092, -0.0492,  0.1602,  ...,  0.3292,  0.2240,  0.5916],
        [-0.1562, -0.0042,  0.0782,  ..., -0.0252,  0.0917,  0.0276],
        [-0.1243,  0.0897, -0.0185,  ...,  0.0509,  0.1678,  0.2858],
        [-0.0150,  0.0763,  0.0207,  ..., -0.0083,  0.0947,  0.4025]],
       grad_fn=<AddmmBackward0>)

In [9]:
def replace_activations(model):
    for name, module in model.named_children():
        if isinstance(module, nn.SiLU):
            setattr(model, name, nn.Hardswish(inplace=module.inplace))
        else:
            replace_activations(module)

replace_activations(vf_predictor)

In [10]:
# vf_predictor

In [11]:
# vf_predictor.features[0][2] = nn.Hardswish(inplace=True)
vf_predictor(input)

tensor([[-0.1488,  0.4292,  0.1477,  ...,  0.3587, -0.2659,  1.0195],
        [-0.1539, -0.0677,  0.0510,  ...,  0.0853,  0.0375, -0.0057],
        [-0.1577,  0.0715,  0.0117,  ..., -0.0156,  0.1207,  0.0635],
        [-0.1107,  0.0660,  0.0141,  ..., -0.0615, -0.0448,  0.0409],
        [ 0.0827,  0.0298, -0.1939,  ...,  0.0200,  0.0266,  0.3319],
        [-0.0578, -0.0629,  0.0364,  ...,  0.0325,  0.0881,  0.0856]],
       grad_fn=<AddmmBackward0>)

In [12]:
# def print_size_of_model(model):
#     torch.save(model.state_dict(), "temp.p")
#     print('Size (MB):', os.path.getsize("temp.p")/1e6)
#     os.remove('temp.p')

# q_pred = nn.Sequential(
#     torch.ao.quantization.QuantStub(), vf_predictor, torch.ao.quantization.DeQuantStub()
# )
# print_size_of_model(q_pred)
# q_pred.eval().to('cpu')
# q_pred.qconfig = torch.ao.quantization.default_per_channel_qconfig
# torch.ao.quantization.prepare(q_pred, inplace=True)
# torch.ao.quantization.convert(q_pred, inplace=True)
# print_size_of_model(q_pred)

# q_pred(input)

In [13]:
from functools import partial
from torchvision.models.quantization.mobilenetv3 import QuantizableSqueezeExcitation
from typing import Callable

class QuantizableMBConv(efn.MBConv):
    def __init__(self,
        cnf: efn.MBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
        se_layer: Callable[..., nn.Module] = QuantizableSqueezeExcitation,):
        super().__init__(cnf=cnf, stochastic_depth_prob=stochastic_depth_prob, norm_layer=norm_layer, se_layer=se_layer)
        # super().__init__()
        # self.skip_mul = nn.quantized.FloatFunctional()
        self.f_add = nn.quantized.FloatFunctional()


    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # return self.skip_mul.mul(self._scale(input), input)
        result = self.block(input)
        if self.use_res_connect:
            result = self.f_add.add(input, self.stochastic_depth(result))
            # result = self.additive.add
            # result += input
        return result

def qeffnet_conf(weights=None, progress=True, **kwargs):
    block = partial(efn.MBConv, se_layer=QuantizableSqueezeExcitation)
    block = partial(QuantizableMBConv, se_layer=QuantizableSqueezeExcitation)
    bneck_conf = partial(efn.MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult"), block=block)
    inverted_residual_setting = [
        bneck_conf(1, 3, 1, 32, 16, 1),
        bneck_conf(6, 3, 2, 16, 24, 2),
        bneck_conf(6, 5, 2, 24, 40, 2),
        bneck_conf(6, 3, 2, 40, 80, 3),
        bneck_conf(6, 5, 1, 80, 112, 3),
        bneck_conf(6, 5, 2, 112, 192, 4),
        bneck_conf(6, 3, 1, 192, 320, 1),
    ]
    last_channel = None

    model = efn._efficientnet(
            inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
        )
    model.features[0][0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    replace_activations(model)
    return model

qnet = qeffnet_conf(width_mult=1.0, depth_mult=1.0).to('cuda:0')

In [17]:
qnet.eval()
qnet(input.to('cuda:0'))

tensor([[-2.3925e-11,  2.5601e-11, -7.5397e-13,  ..., -3.8245e-11,
         -1.9479e-11, -2.2482e-11],
        [-1.1568e-11,  1.7935e-11, -5.0793e-12,  ..., -1.6636e-11,
         -1.0472e-11, -8.8664e-12],
        [-9.6969e-12,  1.8795e-11, -4.2952e-12,  ..., -1.4385e-11,
         -1.0879e-11, -7.5126e-12],
        [-1.1569e-11,  1.8386e-11, -2.3364e-12,  ..., -1.8193e-11,
         -1.1599e-11, -9.5036e-12],
        [-1.4608e-11,  1.9915e-11, -1.8071e-12,  ..., -2.1927e-11,
         -1.2409e-11, -1.1909e-11],
        [-1.4378e-11,  2.0976e-11, -9.7564e-13,  ..., -2.2421e-11,
         -1.4014e-11, -1.0522e-11]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [24]:
qnet

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): Sequential(
      (0): QuantizableMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): Hardswish()
          )
          (1): QuantizableSqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): Hardswish()
            (scale_activation): Hardsigmoid()
            (skip_mul): FloatFunctional

In [21]:
# nn.Sequential(
#     torch.ao.quantization.QuantStub(), qnet.features[:2], torch.ao.quantization.DeQuantStub()
# )(input)
qnet.features[1][0]

QuantizableMBConv(
  (block): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): QuantizableSqueezeExcitation(
      (avgpool): AdaptiveAvgPool2d(output_size=1)
      (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
      (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
      (activation): Hardswish()
      (scale_activation): Hardsigmoid()
      (skip_mul): FloatFunctional(
        (activation_post_process): Identity()
      )
    )
    (2): Conv2dNormActivation(
      (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (stochastic_depth): StochasticDepth(p=0.0, mode=row)
  (f_add): FloatFunctional(
    (activation_post_process): Identity

In [23]:
import copy
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

q_pred = nn.Sequential(
    torch.ao.quantization.QuantStub(), copy.deepcopy(qnet), torch.ao.quantization.DeQuantStub()
)
print_size_of_model(q_pred)
q_pred.eval().to('cpu')
q_pred.qconfig = torch.ao.quantization.default_per_channel_qconfig
q_pred_q = torch.ao.quantization.prepare(q_pred, inplace=True)
q_pred_q(input)
torch.ao.quantization.convert(q_pred_q, inplace=True)
print_size_of_model(q_pred_q)

q_pred_q(input)

Size (MB): 21.431134




Size (MB): 6.387544


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [20]:

q_pred_q(input).std()

tensor(0.)

In [107]:
class Fair_Identity_Normalizer(nn.Module):
    def __init__(self, num_attr=0, dim=0, mu=0.001, sigma=0.1, momentum=0, test=False):
        super().__init__()
        self.num_attr = num_attr
        self.dim = dim

        self.mus = nn.Parameter(torch.randn(self.num_attr, self.dim)*mu)
        self.sigmas = nn.Parameter(torch.randn(self.num_attr, self.dim)*sigma)
        if test:
            self.sigmas = nn.Parameter(torch.ones(self.num_attr, self.dim)*sigma)
        self.eps = 1e-6
        self.momentum = momentum


    def forward(self, x, attr):
        x_clone = x.clone()
        for idx in range(x.shape[0]):
            print(idx)
            x[idx,:] = (x[idx,:] - self.mus[attr[idx], :])/( torch.log(1+torch.exp(self.sigmas[attr[idx], :])) + self.eps)
        x = (1-self.momentum)*x + self.momentum*x_clone

        return x
x = torch.tensor([[-0.1929, -1.3424, -0.7393,  0.4896, -0.8414,  0.7906, -0.1743, -0.9311,
         -1.2933, -1.1389],
        [-0.0192,  1.3003, -1.1609, -0.1955,  0.5325, -0.4887, -0.7268, -1.6017,
          0.6089, -0.5085],
        [-0.3077,  0.8897, -1.5360,  1.6535, -0.9082, -2.7388, -1.1735, -1.2531,
          0.3921, -0.0116],
        [-0.8089, -0.9253,  2.0942,  0.1911, -1.4162,  0.5749,  0.3561,  1.2656,
         -0.6028,  0.3160],
        [-2.0408,  0.0425,  0.6633, -0.8312, -1.0075, -0.3526,  1.2529, -1.2759,
         -0.4371, -0.3171]])  # Example input
attr = torch.randint(0, 3, (5,))  # Example attribute

In [114]:
class Fast_FIN(nn.Module):
    def __init__(self, num_attr=0, dim=0, mu=0.001, sigma=0.1, momentum=0, test=False):
        super().__init__()
        self.num_attr = num_attr
        self.dim = dim

        self.mus = nn.Parameter(torch.randn(self.num_attr, self.dim)*mu)
        self.sigmas = nn.Parameter(torch.randn(self.num_attr, self.dim)*sigma)
        if test:
            self.sigmas = nn.Parameter(torch.ones(self.num_attr, self.dim)*sigma)
        self.eps = 1e-6
        self.momentum = momentum


    def forward(self, x, attr):
        x_clone = x.clone()
        for group in range(self.num_attr):
            mask = attr == group
            mu = self.mus[group]
            sigma = self.sigmas[group]
            x_clone[mask] = (x[mask] - mu) / (torch.log(1+torch.exp(sigma)) + self.eps)
        
        x_clone = (1-self.momentum) * x_clone + self.momentum*x

        return x_clone

class Mult_FINN(Fast_FIN):
    def forward(self, x, attr):
        x_clone = x.clone()
        for group in range(self.num_attr):
            mask = attr == group
            mu = self.mus[group]
            sigma = self.sigmas[group]
            x_clone[mask] = (x[mask] - mu) * torch.qu/ (torch.log(1+torch.exp(sigma)) + self.eps)
        
        x_clone = (1-self.momentum) * x_clone + self.momentum*x

        return x_clone

In [86]:
class FairIdentityNormalizationQuant(nn.Module):
    def __init__(self, num_attr=0, dim=0, mu=0.001, sigma=0.1, momentum=0, test=False):
        super(FairIdentityNormalizationQuant, self).__init__()
        self.num_attr = num_attr
        self.dim = dim

        self.mus = nn.Parameter(torch.randn(self.num_attr, self.dim) * mu)
        self.sigmas = nn.Parameter(torch.randn(self.num_attr, self.dim) * sigma)
        if test:
            self.sigmas = nn.Parameter(torch.ones(self.num_attr, self.dim) * sigma)
        self.eps = 1e-6
        self.momentum = momentum
        
        # Initialize FloatFunctional for quantization
        self.ff = nn.quantized.FloatFunctional()

    def forward(self, x, attr):
        x_clone = x.clone()
        for idx in range(x.shape[0]):
            mu_attr = self.mus[attr[idx], :]
            sigma_attr = torch.log(1 + torch.exp(self.sigmas[attr[idx], :])) + self.eps  # Do this in FP32
            x[idx, :] = self.ff.div(self.ff.sub(x[idx, :], mu_attr), sigma_attr)
        
        # Batch normalization with momentum using FloatFunctional
        x = self.ff.add(
            self.ff.mul_scalar(x, 1 - self.momentum),
            self.ff.mul_scalar(x_clone, self.momentum)
        )

        return x

In [80]:
def set_random_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # if use multi-GPU
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True


In [116]:
set_random_seed(42)
model = Fair_Identity_Normalizer(num_attr=3, dim=10)
set_random_seed(42)
qmodel = Fast_FIN(num_attr=3, dim=10)
model(x.clone(), attr) ==  qmodel(x.clone(), attr)


0
1
2
3
4


tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

In [90]:
import copy
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

q_pred = nn.Sequential(
    torch.ao.quantization.QuantStub(), copy.deepcopy(qmodel), torch.ao.quantization.DeQuantStub()
)
print_size_of_model(q_pred)
q_pred.eval().to('cpu')
q_pred.qconfig = torch.ao.quantization.default_per_channel_qconfig
q_pred_q = torch.ao.quantization.prepare(q_pred, inplace=True)
q_pred_q(x.clone(), attr)
torch.ao.quantization.convert(q_pred_q, inplace=True)
print_size_of_model(q_pred_q)

q_pred_q(x.clone(), attr)

Size (MB): 0.001926


TypeError: Sequential.forward() takes 2 positional arguments but 3 were given

In [92]:
model(x.clone(), attr) ==  qmodel(x.clone(), attr)


0
1
2
3
4


AttributeError: 'FloatFunctional' object has no attribute 'div'