In [1]:
import os
import argparse
import random
import time
import json

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import *
from torch.optim import *
import torch.nn.functional as F

from sklearn.metrics import *
from sklearn.model_selection import KFold

import sys
sys.path.append('.')

from src.modules import *
from src.data_handler import *
from src import logger
from src.class_balanced_loss import *
from typing import NamedTuple
from torchvision.models import efficientnet as efn

from train_glaucoma_fair_fin import train, validation, Identity_Info, quantifiable_efficientnet

from fairlearn.metrics import *

imb_info = Identity_Info()

In [2]:
out_dim = 1
criterion = nn.BCEWithLogitsLoss()
predictor_head = nn.Sigmoid()
in_feat_to_final = 1280
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

fin_mu = 0.01
fin_sigma = 1.
fin_momentum = 0.3
model_type = 'resnext'# 'resnext'  # or quant
modality_types = 'rnflt'
task = 'cls'
pretrained_weights = 'results/crosssectional_rnflt_fin_race_ablation_of_sigma/fullysup_quant_rnflt_Taskcls_lr5e-5_bz6_4442_auc0.7311/best_weights.pth'
pretrained_weights = 'results/crosssectional_rnflt_fin_race_ablation_of_sigma/fullysup_resnext_rnflt_Taskcls_lr5e-5_bz6_3162_auc0.8324/best_weights.pth'
if model_type == 'resnext':
    pretrained_weights = 'results/crosssectional_rnflt_fin_race_ablation_of_sigma/fullysup_resnext_rnflt_Taskcls_lr5e-5_bz6_865_auc0.8510/last_weights.pth'
else:
    pretrained_weights = 'results/crosssectional_rnflt_fin_race_ablation_of_sigma/fullysup_quant_rnflt_Taskcls_lr5e-5_bz6_9354_auc0.8495/best_weights.pth'

ag_norm = Fair_Identity_Normalizer(
    3,
    dim=in_feat_to_final,
    mu=fin_mu,
    sigma=fin_sigma,
    momentum=fin_momentum,
)
in_dim = 1
# model = quantifiable_efficientnet(width_mult=1.0, depth_mult=1.0, weights=EfficientNet_B1_Weights.IMAGENET1K_V2)# create_model(model_type=model_type, in_dim=in_dim, out_dim=out_dim, include_final=False)
model = create_model(model_type=model_type, in_dim=in_dim, out_dim=out_dim, include_final=False)
final_layer = nn.Linear(in_features=in_feat_to_final, out_features=out_dim, bias=False)
model = nn.Sequential(model, ag_norm, final_layer)
model = model.to(device)

checkpoint = torch.load(pretrained_weights)

start_epoch = checkpoint['epoch'] + 1
model.load_state_dict(checkpoint['model_state_dict'])
# efnm = create_model(model_type=model_type, in_dim=in_dim, out_dim=out_dim, include_final=False)
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# scaler.load_state_dict(checkpoint['scaler_state_dict'])
# scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

<All keys matched successfully>

In [3]:
data_dir = "../quant_notes/data_cmpr"
image_size = 200
attribute_type = "race"

trn_dataset = EyeFair(
    os.path.join(data_dir, "train"),
    modality_type=modality_types,
    task=task,
    resolution=image_size,
    attribute_type=attribute_type,
    depth=3 if model_type == "resnext" else 1,
)
tst_dataset = EyeFair(
    os.path.join(data_dir, "test"),
    modality_type=modality_types,
    task=task,
    resolution=image_size,
    attribute_type=attribute_type,
    depth=3 if model_type == "resnext" else 1,
)

batch_size = 6
workers = 1

train_dataset_loader = torch.utils.data.DataLoader(
    trn_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=True,
    drop_last=True,
)

validation_dataset_loader = torch.utils.data.DataLoader(
    tst_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
    pin_memory=True,
    drop_last=False,
)

min: -31.9900, max: 2.2700
min: -31.1600, max: 2.5300


In [4]:
for i in train_dataset_loader:
    print(i[0].dtype)
    break

torch.float32


In [5]:
from copy import deepcopy

import torch.ao.quantization
qmodel = deepcopy(model).to('cpu')
if model_type == 'resnext':
    qmodel[0].fuse_model(is_qat=True)
else:
    qmodel[0] = torch.quantization.QuantWrapper(qmodel[0])
qmodel[1].v = False
# qmodel = torch.ao.quantization.fuse_modules(model, ['conv2', 'bn2'])
qmodel[2] = torch.quantization.QuantWrapper(qmodel[2])
qmodel.qconfig = torch.ao.quantization.default_per_channel_qconfig
torch.ao.quantization.prepare_qat(qmodel, inplace=True)
print(qmodel.qconfig)
qmodel


QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f0664439ea0>}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f0664439ea0>})


In [6]:
# scaler = torch.cuda.amp.GradScaler()
scaler = None

optimizer = AdamW(qmodel.parameters(), lr=5e-5, betas=(0.0, 0.1), weight_decay=0.)

scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(start_epoch, start_epoch+1):
    train_loss, train_acc, train_auc, trn_preds, trn_gts, trn_attrs, trn_pred_gt_by_attrs, trn_other_metrics = train(qmodel, criterion, optimizer, scaler, train_dataset_loader, epoch, None, identity_Info=imb_info, time_window=-1, _device='cpu')
    test_loss, test_acc, test_auc, tst_preds, tst_gts, tst_attrs, tst_pred_gt_by_attrs, tst_other_metrics = validation(qmodel, criterion, optimizer, validation_dataset_loader, epoch, identity_Info=imb_info, _device='cpu')
    scheduler.step()


349train ====> epcoh 4 loss: 0.2247 auc: 0.9806 time: 1769.8557
0-attr auc: 0.9746
1-attr auc: 0.9848
2-attr auc: 0.9816
cpu
test <==== epcoh 4 loss: 7.2842 auc: 0.7964
0-attr auc: 0.8400
1-attr auc: 0.7598
2-attr auc: 0.7760


In [7]:
qmodel.eval()
torch.quantization.convert(qmodel, inplace=True)

Sequential(
  (0): QuantizableResNet(
    (conv1): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=1.6556651592254639, zero_point=0, padding=(3, 3))
    (bn1): Identity()
    (relu): Identity()
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): QuantizableBottleneck(
        (conv1): QuantizedConvReLU2d(64, 256, kernel_size=(1, 1), stride=(1, 1), scale=0.815531849861145, zero_point=0)
        (bn1): Identity()
        (conv2): QuantizedConvReLU2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=1.0254476070404053, zero_point=0, padding=(1, 1), groups=64)
        (bn2): Identity()
        (conv3): QuantizedConv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), scale=1.6215773820877075, zero_point=79)
        (bn3): Identity()
        (relu): ReLU()
        (downsample): Sequential(
          (0): QuantizedConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), scale=2.743269681930542, zero_point=83

In [8]:
test_loss, test_acc, test_auc, tst_preds, tst_gts, tst_attrs, tst_pred_gt_by_attrs, tst_other_metrics = validation(qmodel, criterion, optimizer, validation_dataset_loader, epoch, identity_Info=imb_info, _device='cpu')


cpu
test <==== epcoh 4 loss: 7.3356 auc: 0.4951
0-attr auc: 0.4965
1-attr auc: 0.5019
2-attr auc: 0.4871


In [9]:
forward_model_with_fin(qmodel, i[0], i[2]), i[1]

((tensor([[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]]),
  tensor([[-0.0056,  0.0104, -0.0095,  ..., -0.0075,  0.0050, -0.0149],
          [-0.0056,  0.0104, -0.0095,  ..., -0.0075,  0.0050, -0.0149],
          [ 0.0029, -0.0043,  0.0092,  ..., -0.0040,  0.0090,  0.0016],
          [-0.0056,  0.0104, -0.0095,  ..., -0.0075,  0.0050, -0.0149],
          [-0.0056,  0.0104, -0.0095,  ..., -0.0075,  0.0050, -0.0149],
          [-0.0040, -0.0076,  0.0070,  ...,  0.0041,  0.0212,  0.0049]],
         grad_fn=<CopySlices>)),
 tensor([0., 1., 0., 1., 0., 1.]))

In [10]:
forward_model_with_fin(model.to('cpu'), i[0], i[2]), i[1]


((tensor([[ -6.7776],
          [ 10.7342],
          [-14.3578],
          [ 18.2717],
          [ -6.4111],
          [  6.9140]], grad_fn=<MmBackward0>),
  tensor([[-0.1255, -0.4943,  0.9652,  ...,  0.2935, -0.2616, -0.2475],
          [ 0.2636,  1.0062, -1.7994,  ..., -0.3134,  0.2436,  0.4206],
          [-0.4610, -0.4019,  1.4701,  ...,  0.6462, -2.1826, -0.4101],
          [ 0.2795,  1.5579, -3.1290,  ..., -0.3701,  0.1406,  0.7438],
          [-0.1650, -0.4156,  1.4783,  ...,  0.0969, -0.3455, -0.1474],
          [ 0.3351,  0.2130, -0.3723,  ..., -0.3091, -0.0670,  0.2366]],
         grad_fn=<CopySlices>)),
 tensor([0., 1., 0., 1., 0., 1.]))