In [1]:
import os
import argparse
import random
import time
import json

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import *
from torch.optim import *
import torch.nn.functional as F

from sklearn.metrics import *
from sklearn.model_selection import KFold

import sys
sys.path.append('.')

from src.modules import *
from src.data_handler import *
from src import logger
from src.class_balanced_loss import *
from typing import NamedTuple
from torchvision.models import efficientnet as efn

from train_glaucoma_fair_fin import train, validation, Identity_Info, quantifiable_efficientnet

from fairlearn.metrics import *

imb_info = Identity_Info()

In [2]:
out_dim = 1
criterion = nn.BCEWithLogitsLoss()
predictor_head = nn.Sigmoid()
in_feat_to_final = 1280
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

fin_mu = 0.01
fin_sigma = 1.
fin_momentum = 0.3
model_type = 'quant'# 'resnext'  # or quant
modality_types = 'rnflt'
task = 'cls'
pretrained_weights = 'results/crosssectional_rnflt_fin_race_ablation_of_sigma/fullysup_quant_rnflt_Taskcls_lr5e-5_bz6_4442_auc0.7311/best_weights.pth'
pretrained_weights = 'results/crosssectional_rnflt_fin_race_ablation_of_sigma/fullysup_resnext_rnflt_Taskcls_lr5e-5_bz6_3162_auc0.8324/best_weights.pth'
if model_type == 'resnext':
    pretrained_weights = 'results/crosssectional_rnflt_fin_race_ablation_of_sigma/fullysup_resnext_rnflt_Taskcls_lr5e-5_bz6_865_auc0.8510/last_weights.pth'
else:
    pretrained_weights = 'results/crosssectional_rnflt_fin_race_ablation_of_sigma/fullysup_quant_rnflt_Taskcls_lr5e-5_bz6_9354_auc0.8495/best_weights.pth'

ag_norm = Fair_Identity_Normalizer(
    3,
    dim=in_feat_to_final,
    mu=fin_mu,
    sigma=fin_sigma,
    momentum=fin_momentum,
)
in_dim = 1
# model = quantifiable_efficientnet(width_mult=1.0, depth_mult=1.0, weights=EfficientNet_B1_Weights.IMAGENET1K_V2)# create_model(model_type=model_type, in_dim=in_dim, out_dim=out_dim, include_final=False)
model = create_model(model_type=model_type, in_dim=in_dim, out_dim=out_dim, include_final=False)
final_layer = nn.Linear(in_features=in_feat_to_final, out_features=out_dim, bias=False)
model = nn.Sequential(model, ag_norm, final_layer)
model = model.to(device)

checkpoint = torch.load(pretrained_weights)

start_epoch = checkpoint['epoch'] + 1
model.load_state_dict(checkpoint['model_state_dict'])
# efnm = create_model(model_type=model_type, in_dim=in_dim, out_dim=out_dim, include_final=False)
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# scaler.load_state_dict(checkpoint['scaler_state_dict'])
# scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

<All keys matched successfully>

In [3]:
data_dir = "../quant_notes/data_cmpr"
image_size = 200
attribute_type = "race"

trn_dataset = EyeFair(
    os.path.join(data_dir, "train"),
    modality_type=modality_types,
    task=task,
    resolution=image_size,
    attribute_type=attribute_type,
    depth=3 if model_type == "resnext" else 1,
)
tst_dataset = EyeFair(
    os.path.join(data_dir, "test"),
    modality_type=modality_types,
    task=task,
    resolution=image_size,
    attribute_type=attribute_type,
    depth=3 if model_type == "resnext" else 1,
)

batch_size = 6
workers = 1

train_dataset_loader = torch.utils.data.DataLoader(
    trn_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=True,
    drop_last=True,
)

validation_dataset_loader = torch.utils.data.DataLoader(
    tst_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
    pin_memory=True,
    drop_last=False,
)

min: -31.9900, max: 2.2700
min: -31.1600, max: 2.5300


In [4]:
for i in train_dataset_loader:
    print(i[0].dtype)
    break

torch.float32


In [5]:
from copy import deepcopy

import torch.ao.quantization
qmodel = deepcopy(model).to('cpu')
if model_type == 'resnext':
    qmodel[0].fuse_model(is_qat=True)
else:
    qmodel[0] = torch.quantization.QuantWrapper(qmodel[0])
qmodel[1].v = False
# qmodel = torch.ao.quantization.fuse_modules(model, ['conv2', 'bn2'])
qmodel[2] = torch.quantization.QuantWrapper(qmodel[2])
qmodel.qconfig = torch.ao.quantization.default_per_channel_qconfig
torch.ao.quantization.prepare_qat(qmodel, inplace=True)
print(qmodel.qconfig)


QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7fba753416c0>}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7fba753416c0>})


In [6]:
# scaler = torch.cuda.amp.GradScaler()
scaler = None

optimizer = AdamW(qmodel.parameters(), lr=5e-5, betas=(0.0, 0.1), weight_decay=0.)

scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(start_epoch, start_epoch+4):
    train_loss, train_acc, train_auc, trn_preds, trn_gts, trn_attrs, trn_pred_gt_by_attrs, trn_other_metrics = train(qmodel, criterion, optimizer, scaler, train_dataset_loader, epoch, None, identity_Info=imb_info, time_window=-1, _device='cpu')
    test_loss, test_acc, test_auc, tst_preds, tst_gts, tst_attrs, tst_pred_gt_by_attrs, tst_other_metrics = validation(qmodel, criterion, optimizer, validation_dataset_loader, epoch, identity_Info=imb_info, _device='cpu')
    scheduler.step()


349train ====> epcoh 2 loss: 0.5350 auc: 0.8055 time: 282.7320
0-attr auc: 0.8228
1-attr auc: 0.7849
2-attr auc: 0.7869
cpu
test <==== epcoh 2 loss: 3.0521 auc: 0.8398
0-attr auc: 0.8567
1-attr auc: 0.7945
2-attr auc: 0.8513
349train ====> epcoh 3 loss: 0.5122 auc: 0.8210 time: 288.8274
0-attr auc: 0.8354
1-attr auc: 0.8068
2-attr auc: 0.8046
cpu
test <==== epcoh 3 loss: 1.5259 auc: 0.8378
0-attr auc: 0.8442
1-attr auc: 0.7906
2-attr auc: 0.8543
349train ====> epcoh 4 loss: 0.4939 auc: 0.8372 time: 274.7201
0-attr auc: 0.8433
1-attr auc: 0.8465
2-attr auc: 0.8073
cpu
test <==== epcoh 4 loss: 4.1779 auc: 0.8296
0-attr auc: 0.8368
1-attr auc: 0.7934
2-attr auc: 0.8346
349train ====> epcoh 5 loss: 0.4710 auc: 0.8503 time: 272.4942
0-attr auc: 0.8549
1-attr auc: 0.8307
2-attr auc: 0.8505
cpu
test <==== epcoh 5 loss: 12.2015 auc: 0.8322
0-attr auc: 0.8356
1-attr auc: 0.7890
2-attr auc: 0.8544


In [7]:
qmodel.eval()
torch.quantization.convert(qmodel, inplace=True)



Sequential(
  (0): QuantWrapper(
    (quant): Quantize(scale=tensor([8.1437]), zero_point=tensor([0]), dtype=torch.quint8)
    (dequant): DeQuantize()
    (module): EfficientNet(
      (features): Sequential(
        (0): Conv2dNormActivation(
          (0): QuantizedConv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), scale=17.555675506591797, zero_point=51, padding=(1, 1), bias=False)
          (1): QuantizedBatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): QuantizedHardswish()
        )
        (1): Sequential(
          (0): QuantizableMBConv(
            (block): Sequential(
              (0): Conv2dNormActivation(
                (0): QuantizedConv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), scale=17.564468383789062, zero_point=53, padding=(1, 1), groups=32, bias=False)
                (1): QuantizedBatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): QuantizedHardswish()
              )
   

In [8]:
test_loss, test_acc, test_auc, tst_preds, tst_gts, tst_attrs, tst_pred_gt_by_attrs, tst_other_metrics = validation(qmodel, criterion, optimizer, validation_dataset_loader, epoch, identity_Info=imb_info, _device='cpu')


cpu
test <==== epcoh 5 loss: 3987.0296 auc: 0.5102
0-attr auc: 0.4984
1-attr auc: 0.5182
2-attr auc: 0.5118


In [13]:
forward_model_with_fin(qmodel, i[0], i[2]), i[1]

((tensor([[ 1571.4341],
          [12075.2295],
          [ 4879.7163],
          [17533.8945],
          [10421.0889],
          [ 6451.1504]]),
  tensor([[ 5.4292e+02, -5.8519e-03,  5.6278e+02,  ...,  5.2520e+02,
            2.0309e-02,  1.4293e+03],
          [ 1.0364e+03,  1.8478e+03,  8.7756e+03,  ...,  8.9955e+02,
            3.5601e+03,  2.4749e+03],
          [-5.5711e-03,  6.1593e+02,  2.1939e+03,  ...,  4.4977e+02,
            1.4241e+03,  1.9352e-02],
          [ 1.0364e+03,  7.3910e+03,  6.5817e+03,  ...,  3.1485e+03,
            2.8481e+03,  4.1247e+03],
          [ 1.0364e+03,  2.4637e+03,  8.7756e+03,  ...,  8.9955e+02,
            1.4241e+03,  2.4749e+03],
          [ 6.2375e+02,  9.3059e+03,  1.9183e+03,  ...,  2.3414e+02,
            2.6367e+03,  1.4994e+03]], grad_fn=<CopySlices>)),
 tensor([1., 1., 0., 0., 1., 1.]))

In [14]:
forward_model_with_fin(model.to('cpu'), i[0], i[2]), i[1]


((tensor([[ 1.3384],
          [ 2.5935],
          [-1.0915],
          [-0.9906],
          [-1.3039],
          [ 1.6806]], grad_fn=<MmBackward0>),
  tensor([[-0.0710,  0.5979,  0.3407,  ...,  0.4270,  0.3172,  0.0049],
          [-0.0826,  0.0116,  1.4590,  ..., -0.1726,  0.0186,  0.1054],
          [ 0.1008, -0.3759, -0.3595,  ...,  0.0357, -0.3483,  0.0250],
          [ 0.3190,  0.1690,  0.0527,  ..., -0.0835, -0.0580, -0.2341],
          [ 0.0390, -0.1434, -0.6564,  ...,  0.5235, -0.2372,  0.0250],
          [ 0.0070,  1.2618,  0.4286,  ..., -0.0641,  1.1384, -0.0683]],
         grad_fn=<AsStridedBackward0>)),
 tensor([1., 1., 0., 0., 1., 1.]))