# Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%reload_ext autoreload

In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
import time
from itertools import islice
from dataclasses import dataclass
import torchvision
from torchvision.models import densenet161, DenseNet161_Weights, vit_b_16, ViT_B_16_Weights, densenet121, DenseNet121_Weights
import os
import sys
from pathlib import Path
from torchinfo import summary

In [4]:
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

In [5]:
from CheXpert.race_prediction.dataset import CheXpertRaceDataset
from CheXpert.disease_prediction.dataset import CheXpertDiseaseDataset
from shared_utils import vprint, to_gpu, add_mean_to_list, Mode, SharedConfigs
import shared_utils
from CheXpert.disease_prediction.utils import Configs as DiseaseConfigs
from CheXpert.race_prediction.utils import Configs as RaceConfigs
from MIMIC_CXR.dataset import CXRDataset
from MIMIC_CXR.utils import Configs as CXRConfigs

# Configs

In [63]:
@dataclass
class Configs(SharedConfigs):
    CXR_DATA_DIR = os.path.join("data", "MIMIC-CXR-JPG")
    CXR_VALID_LABELS_FILENAME = "valid_400_no_u_no_other.csv"
    CXR_FILENAMES = CXRConfigs.CXR_FILENAMES
    CHEXPERT_DATA_DIR = os.path.join("data", "CheXpert", "CheXpert-v1.0-small")
    CHEXPERT_DISEASE_TRAINED_MODELS_DIR = os.path.join("CheXpert", "disease_prediction", "trained_models")
    CHEXPERT_RACE_TRAINED_MODELS_DIR = os.path.join("CheXpert", "race_prediction", "trained_models")
    CHEXPERT_VALID_LABELS_FILENAME = "valid_demo30_no_u_sampled30.csv"
    CHEXPERT_DEMO_FILENAME = "CHEXPERT DEMO.csv"
    DISEASE_ANNOTATIONS_COLUMNS = DiseaseConfigs.ANNOTATIONS_COLUMNS
    CHALLENGE_ANNOTATIONS_COLUMNS = DiseaseConfigs.CHALLENGE_ANNOTATIONS_COLUMNS
    RACE_ANNOTATIONS_COLUMNS = RaceConfigs.ANNOTATIONS_COLUMNS
    NUM_DISEASE_CLASSES = DiseaseConfigs.NUM_CLASSES
    NUM_RACE_CLASSES = RaceConfigs.NUM_CLASSES
    CHEXPERT_RACE_DICT = RaceConfigs.RACE_DICT
    BATCH_SIZE = 4
    VALID_SIZE_DEBUG = 10**10

In [7]:
shared_utils.set_seed(Configs.SEED)

In [8]:
if torch.cuda.is_available():
    vprint(f"Memory info: {torch.cuda.mem_get_info()[0]/10e8:.1f} GB free GPU.", Configs)
else: 
    vprint(f"No GPU Memory.", Configs)

2022-08-20 10:45: Memory info: 2.9 GB free GPU.


In [9]:
valid_transform = transforms.Compose([
    transforms.Resize((320,320)),
    transforms.ToTensor(), 
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Disease Prediction 

## Validaiton Dataloaders

In [10]:
cxp_disease_valid_dataset = CheXpertDiseaseDataset(data_dir=Configs.CHEXPERT_DATA_DIR, 
                                                   labels_filename=Configs.CHEXPERT_VALID_LABELS_FILENAME,
                                                   transform=valid_transform)
cxp_disease_valid_dataset.df_labels = cxp_disease_valid_dataset.df_labels[:Configs.VALID_SIZE_DEBUG]
cxp_disease_valid_dataloader = DataLoader(cxp_disease_valid_dataset, batch_size=Configs.BATCH_SIZE, shuffle=False)
len(cxp_disease_valid_dataset)

960

In [11]:
if set(['race', 'gender', 'age']).issubset(cxp_disease_valid_dataset.df_labels.columns):
    display(cxp_disease_valid_dataset.df_labels.groupby(['race', 'gender', 'age']).size())

race      gender  age  
Asian     Female  20-40    40
                  40-70    40
                  70-90    40
          Male    20-40    40
                  40-70    40
                  70-90    40
Black     Female  20-40    40
                  40-70    40
                  70-90    40
          Male    20-40    40
                  40-70    40
                  70-90    40
Hispanic  Female  20-40    40
                  40-70    40
                  70-90    40
          Male    20-40    40
                  40-70    40
                  70-90    40
White     Female  20-40    40
                  40-70    40
                  70-90    40
          Male    20-40    40
                  40-70    40
                  70-90    40
dtype: int64

In [12]:
cxp_disease_valid_dataset.df_labels

Unnamed: 0,original_path,Sex,Age,Frontal/Lateral,AP/PA,No Finding,Enlarged Cardiomediastinum,Cardiomegaly,Lung Opacity,Lung Lesion,Edema,Consolidation,Pneumonia,Atelectasis,Pneumothorax,Pleural Effusion,Pleural Other,Fracture,Support Devices,patient_id,PATIENT,GENDER,AGE_AT_CXR,PRIMARY_RACE,ETHNICITY,race,Asian,Black,Hispanic,White,age,gender,img_path,study,view
0,CheXpert-v1.0-small/train/patient27841/study2/...,Female,31,Frontal,PA,,,0.0,1.0,1.0,0.0,0.0,,0.0,0.0,1.0,,,1.0,patient27841,patient27841,Female,31.0,Asian,Non-Hispanic/Non-Latino,Asian,1.0,0.0,0.0,0.0,20-40,Female,/home/student/MLH/debiasing-racial-effect-in-m...,study2,view1_frontal.jpg
1,CheXpert-v1.0-small/train/patient34905/study1/...,Female,39,Frontal,AP,,,0.0,,,0.0,0.0,,0.0,0.0,0.0,,0.0,,patient34905,patient34905,Female,43.0,Asian,Non-Hispanic/Non-Latino,Asian,1.0,0.0,0.0,0.0,20-40,Female,/home/student/MLH/debiasing-racial-effect-in-m...,study1,view1_frontal.jpg
2,CheXpert-v1.0-small/train/patient11116/study1/...,Female,29,Frontal,PA,,,0.0,1.0,,0.0,0.0,,0.0,,0.0,,,,patient11116,patient11116,Female,29.0,Asian,Non-Hispanic/Non-Latino,Asian,1.0,0.0,0.0,0.0,20-40,Female,/home/student/MLH/debiasing-racial-effect-in-m...,study1,view1_frontal.jpg
3,CheXpert-v1.0-small/train/patient28123/study1/...,Female,29,Frontal,PA,1.0,,0.0,,,0.0,0.0,,0.0,0.0,0.0,,,,patient28123,patient28123,Female,29.0,Asian,Non-Hispanic/Non-Latino,Asian,1.0,0.0,0.0,0.0,20-40,Female,/home/student/MLH/debiasing-racial-effect-in-m...,study1,view1_frontal.jpg
4,CheXpert-v1.0-small/train/patient43253/study2/...,Female,34,Frontal,AP,,,0.0,,,1.0,0.0,,0.0,,0.0,,,,patient43253,patient43253,Female,34.0,Asian,Non-Hispanic/Non-Latino,Asian,1.0,0.0,0.0,0.0,20-40,Female,/home/student/MLH/debiasing-racial-effect-in-m...,study2,view1_frontal.jpg
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
955,CheXpert-v1.0-small/train/patient45732/study1/...,Male,79,Lateral,,1.0,,0.0,,,0.0,0.0,,0.0,0.0,0.0,,,-1.0,patient45732,patient45732,Male,79.0,White,Non-Hispanic/Non-Latino,White,0.0,0.0,0.0,1.0,70-90,Male,/home/student/MLH/debiasing-racial-effect-in-m...,study1,view2_lateral.jpg
956,CheXpert-v1.0-small/train/patient28971/study3/...,Male,77,Frontal,AP,,,0.0,1.0,,0.0,0.0,,0.0,1.0,1.0,,,1.0,patient28971,patient28971,Male,77.0,White,Non-Hispanic/Non-Latino,White,0.0,0.0,0.0,1.0,70-90,Male,/home/student/MLH/debiasing-racial-effect-in-m...,study3,view1_frontal.jpg
957,CheXpert-v1.0-small/train/patient13442/study1/...,Male,88,Lateral,,,,0.0,,,0.0,1.0,-1.0,1.0,,0.0,,,,patient13442,patient13442,Male,88.0,"White, non-Hispanic",Unknown,White,0.0,0.0,0.0,1.0,70-90,Male,/home/student/MLH/debiasing-racial-effect-in-m...,study1,view2_lateral.jpg
958,CheXpert-v1.0-small/train/patient31369/study2/...,Male,85,Frontal,AP,,,0.0,,,1.0,0.0,1.0,0.0,,0.0,,,,patient31369,patient31369,Male,80.0,White,Non-Hispanic/Non-Latino,White,0.0,0.0,0.0,1.0,70-90,Male,/home/student/MLH/debiasing-racial-effect-in-m...,study2,view1_frontal.jpg


In [13]:
# group_sample_size = {
#     "Black_40-70_F": 100,
#     "Hispanic_40-70_F": 100
# }

In [14]:
# cxr_disease_valid_dataset = CXRDataset.download_dataset(400, Mode.Disease, Configs.CXR_DATA_DIR,
#                                                         Configs.CXR_VALID_LABELS_FILENAME, **Configs.CXR_FILENAMES,
#                                                         transform=valid_transform, target_transform=None)
cxr_disease_valid_dataset = CXRDataset(Mode.Disease, Configs.CXR_DATA_DIR, Configs.CXR_VALID_LABELS_FILENAME,
                                       transform=valid_transform)
cxr_disease_valid_dataset.df_labels = cxr_disease_valid_dataset.df_labels[:Configs.VALID_SIZE_DEBUG]
cxr_disease_valid_dataloader = DataLoader(cxr_disease_valid_dataset, batch_size=Configs.BATCH_SIZE, shuffle=False)
len(cxr_disease_valid_dataset)

9600

In [15]:
cxr_disease_valid_dataset.df_labels.groupby(['race', 'gender', 'age']).size()

race      gender  age  
Asian     F       20-40    400
                  40-70    400
                  70-90    400
          M       20-40    400
                  40-70    400
                  70-90    400
Black     F       20-40    400
                  40-70    400
                  70-90    400
          M       20-40    400
                  40-70    400
                  70-90    400
Hispanic  F       20-40    400
                  40-70    400
                  70-90    400
          M       20-40    400
                  40-70    400
                  70-90    400
White     F       20-40    400
                  40-70    400
                  70-90    400
          M       20-40    400
                  40-70    400
                  70-90    400
dtype: int64

## Pretrained Models 

In [16]:
_, _, files = next(os.walk(Configs.CHEXPERT_DISEASE_TRAINED_MODELS_DIR))
disease_trained_models = [os.path.join(Configs.CHEXPERT_DISEASE_TRAINED_MODELS_DIR, file) for file in files]
len(disease_trained_models)

2

In [17]:
disease_trained_models = disease_trained_models[1:]
disease_trained_models

['CheXpert/disease_prediction/trained_models/2022_08_20-04_18__densenet121_disease_demo30_no_uV2__epoch-4__iter-13145__batch_size-16__trainLastLoss-0.3823__validAUC-0.8581__orgValidAUC-0.8922.dict']

In [18]:
disease_model = densenet121()
num_features = disease_model.classifier.in_features
disease_model.classifier = nn.Sequential(
    nn.Linear(num_features, num_features, bias=True),
    nn.ReLU(),
    nn.Dropout(p=0.1),
    nn.Linear(in_features=num_features, out_features=Configs.NUM_DISEASE_CLASSES, bias=True)
)
disease_model.eval()
not disease_model.training

True

In [19]:
disease_model, results, _, _ = shared_utils.load_statedict(disease_model, disease_trained_models[0], Configs)
disease_model = to_gpu(disease_model)

2022-08-20 10:45: Loading model - CheXpert/disease_prediction/trained_models/2022_08_20-04_18__densenet121_disease_demo30_no_uV2__epoch-4__iter-13145__batch_size-16__trainLastLoss-0.3823__validAUC-0.8581__orgValidAUC-0.8922.dict


## Predictions

In [20]:
df_res_disease = pd.DataFrame(columns=Configs.DISEASE_ANNOTATIONS_COLUMNS + ['Mean'])
df_res_disease

Unnamed: 0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,Mean


In [69]:
cxp_disease_labels, cxp_disease_outputs = shared_utils.get_metric_tensors(disease_model, cxp_disease_valid_dataloader, Configs,
                                                  apply_on_outputs=lambda x: torch.sigmoid(x),
                                                  by_study=False, challenge_ann_only=None)
df_res_disease.loc['CXP'] = add_mean_to_list(shared_utils.auc_score(cxp_disease_labels, cxp_disease_outputs, per_class=True))

In [22]:
cxr_disease_labels, cxr_disease_outputs = shared_utils.get_metric_tensors(disease_model, cxr_disease_valid_dataloader, Configs,
                                                  apply_on_outputs=lambda x: torch.sigmoid(x),
                                                  by_study=False, challenge_ann_only=None)
df_res_disease.loc['CXR'] = add_mean_to_list(shared_utils.auc_score(cxr_disease_labels, cxr_disease_outputs, per_class=True))

In [70]:
df_res_disease.sort_values(by="Mean", ascending=False, inplace=True)
df_res_disease = df_res_disease.round(2)

In [71]:
print(df_res_disease)

     Atelectasis  Cardiomegaly  Consolidation  Edema  Pleural Effusion  Mean
CXR         0.83          0.81           0.83   0.92              0.93  0.87
CXP         0.49          0.50           0.50   0.51              0.48  0.50


Bad pipe message: %s [b'}\x97\xa8ny\x18\xdb\xbd\xd7\x17tu|\xcd5\xb6\x1b\r\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac']
Bad pipe message: %s [b"H\r\x01{\xfa\x18\xf6\xbe\x8c\xa6\x7f\x11\xe7\x12\x946\xb5P\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0s\xc0w\x00\xc4\x00\xc3\xc0#\xc0'\x00g\x00@\xc0r\xc0v\x00\xbe\x00\xbd\xc0\n\xc0\x14\x009\x008\x00\x88\x00\x87\xc0\t\xc0\x13\x003\x002\x00\x9a\x00\x99\x00E\x00D\xc0\x07\xc0\x11\xc0\x08\xc0\x12\x00\x16\x00\x13\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c\xc0P\x00=\x00\xc0\x00<\x00\xba\x005\x00\x84\x00/\x00\x96\x00A\x00\x05\x00\n\x00\xff\x01\x00\x00j\x00\x00\x00\x0e\x00\x0c"]
Bad pipe message: %s [b'-\xbf\xddN\x8b\xeeLj\xb4\x10\xb4\x1f\x8a\x96Lr\xc7\xae\x00\x00\xa

     Atelectasis  Cardiomegaly  Consolidation  Edema  Pleural Effusion  Mean
CXR         0.83          0.81           0.83   0.92              0.93  0.87
CXP         0.76          0.87           0.82   0.89              0.93  0.85


In [25]:
#    Atelectasis  Cardiomegaly  Consolidation  Edema  Pleural Effusion  Mean
# CXP         0.75          0.88           0.84   0.89              0.93  0.86
#      Atelectasis  Cardiomegaly  Consolidation  Edema  Pleural Effusion  Mean
# CXP         0.75          0.91           0.86    0.9              0.94  0.87

# CXP         0.75          0.91           0.86   0.90              0.94  0.87
# CXR         0.83          0.81           0.84   0.93              0.93  0.87

## Performance Per Protected Groups 

In [26]:
# race, age group, and gender to ChexPertDiseaseDatatset
if set(['race', 'gender', 'age']).issubset(cxp_disease_valid_dataset.df_labels.columns):
    cxp_disease_df_labels = cxp_disease_valid_dataset.df_labels
else:
    cxp_demo_df = CheXpertRaceDataset.generate_race_dummies(pd.read_csv(os.path.join(Configs.CHEXPERT_DATA_DIR,
                                                                                     Configs.CHEXPERT_DEMO_FILENAME)),
                                                           'PRIMARY_RACE', Configs.CHEXPERT_RACE_DICT)
    cxp_race_df = cxp_demo_df[['PATIENT', 'race'] + Configs.RACE_ANNOTATIONS_COLUMNS].drop_duplicates()
    cxp_disease_df_labels = cxp_disease_valid_dataset.df_labels.merge(cxp_race_df, how='left',
                                                                                    left_on='patient_id', right_on='PATIENT')
    cxp_disease_df_labels.race.fillna('Other', inplace=True)
    cxp_disease_df_labels['age'] = cxp_disease_df_labels.Age.apply(shared_utils.age_to_age_group)
    cxp_disease_df_labels['gender'] = cxp_disease_df_labels.Sex
cxp_disease_df_labels.head(2)

Unnamed: 0,original_path,Sex,Age,Frontal/Lateral,AP/PA,No Finding,Enlarged Cardiomediastinum,Cardiomegaly,Lung Opacity,Lung Lesion,Edema,Consolidation,Pneumonia,Atelectasis,Pneumothorax,Pleural Effusion,Pleural Other,Fracture,Support Devices,patient_id,PATIENT,GENDER,AGE_AT_CXR,PRIMARY_RACE,ETHNICITY,race,Asian,Black,Hispanic,White,age,gender,img_path,study,view
0,CheXpert-v1.0-small/train/patient27841/study2/...,Female,31,Frontal,PA,,,0.0,1.0,1.0,0.0,0.0,,0.0,0.0,1.0,,,1.0,patient27841,patient27841,Female,31.0,Asian,Non-Hispanic/Non-Latino,Asian,1.0,0.0,0.0,0.0,20-40,Female,/home/student/MLH/debiasing-racial-effect-in-m...,study2,view1_frontal.jpg
1,CheXpert-v1.0-small/train/patient34905/study1/...,Female,39,Frontal,AP,,,0.0,,,0.0,0.0,,0.0,0.0,0.0,,0.0,,patient34905,patient34905,Female,43.0,Asian,Non-Hispanic/Non-Latino,Asian,1.0,0.0,0.0,0.0,20-40,Female,/home/student/MLH/debiasing-racial-effect-in-m...,study1,view1_frontal.jpg


In [27]:
cxr_disease_df_labels = cxr_disease_valid_dataset.df_labels.copy()
cxr_disease_df_labels.gender.replace({"M": "Male", "F": "Female"}, inplace=True)
cxr_disease_df_labels.head(2)

Unnamed: 0,subject_id,study_id,split,dicom_id,ethnicity,race,age,gender,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,folder_number,img_path
0,18855302,53537225,train,bafbef15-550a6520-fbcd3b2c-81552b65-3050c322,ASIAN,Asian,20-40,Female,0.0,0.0,0.0,0.0,0.0,18,data/MIMIC-CXR-JPG/physionet.org/files/mimic-c...
1,10296904,51710336,train,16c1bc0c-90207d40-b93a861d-ad0bc18e-ac97afb2,ASIAN,Asian,20-40,Female,0.0,0.0,0.0,0.0,0.0,10,data/MIMIC-CXR-JPG/physionet.org/files/mimic-c...


In [28]:
shared_utils.auc_per_protected_group(cxp_disease_df_labels, Mode.Disease, Configs, cxp_disease_labels,
                                     cxp_disease_outputs, protected_groups=['race'])

Unnamed: 0_level_0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,Mean
race,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Asian,0.75,0.95,0.93,0.89,0.93,0.89
Black,0.74,0.88,0.81,0.9,0.92,0.85
Hispanic,0.74,0.86,0.83,0.85,0.94,0.844
White,0.76,0.67,0.78,0.92,0.94,0.814


In [61]:
t = cxp_disease_df_labels.groupby(['race', 'gender', 'age'])[Configs.DISEASE_ANNOTATIONS_COLUMNS].sum()
t['sum'] = t.sum(axis=1)
t

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,sum
race,gender,age,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Asian,Female,20-40,7.0,2.0,3.0,4.0,11.0,27.0
Asian,Female,40-70,4.0,4.0,2.0,5.0,14.0,29.0
Asian,Female,70-90,7.0,9.0,1.0,8.0,11.0,36.0
Asian,Male,20-40,5.0,3.0,2.0,0.0,9.0,19.0
Asian,Male,40-70,2.0,3.0,0.0,6.0,11.0,22.0
Asian,Male,70-90,12.0,10.0,4.0,6.0,17.0,49.0
Black,Female,20-40,3.0,7.0,3.0,6.0,5.0,24.0
Black,Female,40-70,6.0,2.0,1.0,6.0,6.0,21.0
Black,Female,70-90,7.0,9.0,5.0,5.0,7.0,33.0
Black,Male,20-40,0.0,1.0,2.0,0.0,3.0,6.0


In [60]:
t = cxr_disease_df_labels.groupby('race')[Configs.DISEASE_ANNOTATIONS_COLUMNS].sum()
t['sum'] = t.sum(axis=1)
t

Unnamed: 0_level_0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,sum
race,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Asian,233.0,271.0,45.0,95.0,237.0,881.0
Black,189.0,282.0,34.0,70.0,112.0,687.0
Hispanic,239.0,248.0,49.0,96.0,162.0,794.0
White,266.0,224.0,43.0,98.0,262.0,893.0


In [29]:
shared_utils.auc_per_protected_group(cxp_disease_df_labels, Mode.Disease, Configs, cxp_disease_labels,
                                     cxp_disease_outputs, protected_groups=['gender'])

Unnamed: 0_level_0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,Mean
gender,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Female,0.73,0.82,0.86,0.89,0.93,0.846
Male,0.78,0.94,0.77,0.9,0.94,0.866


In [30]:
shared_utils.auc_per_protected_group(cxp_disease_df_labels, Mode.Disease, Configs, cxp_disease_labels,
                                     cxp_disease_outputs, protected_groups=['age'])

Unnamed: 0_level_0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,Mean
age,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
20-40,0.79,0.87,0.93,0.89,0.96,0.888
40-70,0.78,0.85,0.81,0.9,0.94,0.856
70-90,0.69,0.86,0.72,0.86,0.91,0.808


In [56]:
shared_utils.auc_per_protected_group(cxp_disease_df_labels, Mode.Disease, Configs, cxp_disease_labels,
                                     cxp_disease_outputs, protected_groups=['race', 'age','gender'])

Only one class present in y_true. ROC AUC score is not defined in that case.
Only one class present in y_true. ROC AUC score is not defined in that case.
Only one class present in y_true. ROC AUC score is not defined in that case.
Only one class present in y_true. ROC AUC score is not defined in that case.
Only one class present in y_true. ROC AUC score is not defined in that case.
Only one class present in y_true. ROC AUC score is not defined in that case.
Only one class present in y_true. ROC AUC score is not defined in that case.
Only one class present in y_true. ROC AUC score is not defined in that case.
Only one class present in y_true. ROC AUC score is not defined in that case.


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,Mean
race,age,gender,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Asian,20-40,Female,0.59,0.74,0.98,0.96,0.92,0.838
Asian,20-40,Male,0.82,0.99,0.91,,0.97,0.9225
Asian,40-70,Female,0.88,0.99,0.87,0.86,1.0,0.92
Asian,40-70,Male,0.88,0.99,,0.84,0.92,0.9075
Asian,70-90,Female,0.77,0.9,1.0,0.85,0.92,0.888
Asian,70-90,Male,0.68,0.96,0.99,0.9,0.87,0.88
Black,20-40,Female,0.84,0.83,0.97,0.86,0.97,0.894
Black,20-40,Male,,0.95,0.76,,0.91,0.873333
Black,40-70,Female,0.56,0.96,0.92,0.86,0.8,0.82
Black,40-70,Male,0.75,0.96,0.72,0.99,0.93,0.87


In [31]:
shared_utils.auc_per_protected_group(cxr_disease_df_labels, Mode.Disease, Configs, cxr_disease_labels,
                                     cxr_disease_outputs, protected_groups=['race'])

Unnamed: 0_level_0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,Mean
race,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Asian,0.84,0.8,0.79,0.92,0.91,0.852
Black,0.82,0.8,0.89,0.92,0.92,0.87
Hispanic,0.82,0.83,0.81,0.93,0.93,0.864
White,0.83,0.83,0.85,0.91,0.94,0.872


In [32]:
shared_utils.auc_per_protected_group(cxr_disease_df_labels, Mode.Disease, Configs, cxr_disease_labels,
                                     cxr_disease_outputs, protected_groups=['gender'])

Unnamed: 0_level_0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,Mean
gender,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Female,0.82,0.82,0.84,0.93,0.94,0.87
Male,0.83,0.8,0.83,0.91,0.92,0.858


In [33]:
shared_utils.auc_per_protected_group(cxr_disease_df_labels, Mode.Disease, Configs, cxr_disease_labels,
                                     cxr_disease_outputs, protected_groups=['age'])

Unnamed: 0_level_0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,Mean
age,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
20-40,0.86,0.8,0.84,0.95,0.94,0.878
40-70,0.83,0.79,0.87,0.93,0.94,0.872
70-90,0.75,0.77,0.77,0.88,0.89,0.812


In [34]:
shared_utils.auc_per_protected_group(cxr_disease_df_labels, Mode.Disease, Configs, cxr_disease_labels,
                                     cxr_disease_outputs, protected_groups=['gender', 'age','race'])

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,Mean
gender,age,race,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Female,20-40,Asian,0.83,0.69,0.68,0.95,0.96,0.822
Female,20-40,Black,0.82,0.79,0.98,0.99,0.97,0.91
Female,20-40,Hispanic,0.83,0.85,0.94,0.98,0.98,0.916
Female,20-40,White,0.84,0.87,0.97,0.92,0.97,0.914
Female,40-70,Asian,0.79,0.73,0.89,0.95,0.92,0.856
Female,40-70,Black,0.77,0.77,0.82,0.99,0.9,0.85
Female,40-70,Hispanic,0.8,0.82,0.75,0.97,0.96,0.86
Female,40-70,White,0.83,0.87,0.97,0.81,0.97,0.89
Female,70-90,Asian,0.83,0.79,0.74,0.93,0.91,0.84
Female,70-90,Black,0.71,0.72,0.88,0.85,0.86,0.804


# Race Prediction

## Validation Dataloaders

In [35]:
cxp_race_valid_dataset = CheXpertRaceDataset(data_dir=Configs.CHEXPERT_DATA_DIR, demo_filename=Configs.CHEXPERT_DEMO_FILENAME, 
                                             labels_filename=Configs.CHEXPERT_VALID_LABELS_FILENAME, transform=valid_transform,
                                             label_transform=False)
cxp_race_valid_dataset.df_labels = cxp_race_valid_dataset.df_labels[:Configs.VALID_SIZE_DEBUG]
cxp_race_valid_dataloader = DataLoader(cxp_race_valid_dataset, batch_size=Configs.BATCH_SIZE, shuffle=False)
len(cxp_race_valid_dataset)

960

In [36]:
# cxr_disease_valid_dataset = CXRDataset.download_dataset(10, cxr_mode.Race, Configs.CXR_DATA_DIR,
#                                                         Configs.CXR_VALID_LABELS_FILENAME, **Configs.CXR_FILENAMES,
#                                                         transform=valid_transform, target_transform=None)
cxr_race_valid_dataset = CXRDataset(Mode.Race, Configs.CXR_DATA_DIR, Configs.CXR_VALID_LABELS_FILENAME,
                                    transform=valid_transform)
cxr_race_valid_dataset.df_labels = cxr_race_valid_dataset.df_labels[:Configs.VALID_SIZE_DEBUG]
cxr_race_valid_dataloader = DataLoader(cxr_race_valid_dataset, batch_size=Configs.BATCH_SIZE, shuffle=False)
len(cxr_race_valid_dataset)

9600

## Pretrained Models 

In [37]:
_, _, files = next(os.walk(Configs.CHEXPERT_RACE_TRAINED_MODELS_DIR))
race_trained_models = [os.path.join(Configs.CHEXPERT_RACE_TRAINED_MODELS_DIR, file) for file in files]
race_model_versions = [p.split('__')[1] for p in race_trained_models]
len(race_trained_models)

8

In [38]:
models_dict = {}
for model_version, model_path in zip(race_model_versions, race_trained_models):
    model = densenet121()
    if "shallow" in model_version:
        shallow_denseblock = int(model_version.split('_')[2][10:])
        layer_offset = 3 + 2 * shallow_denseblock
        num_features = model.features[layer_offset].norm.num_features
        model = model.features[:layer_offset]
        classifier_module = nn.Sequential(
            nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),   
            nn.Flatten(start_dim=1),
            nn.Linear(in_features=num_features, out_features=num_features, bias=True),
            nn.Dropout(p=0.1),
            nn.Linear(in_features=num_features, out_features=Configs.NUM_RACE_CLASSES, bias=True))
        model.add_module('classifier', classifier_module)
    else:
        num_features = model.classifier.in_features
        model.classifier = nn.Sequential(
            nn.Linear(num_features, num_features, bias=True),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(in_features=num_features, out_features=Configs.NUM_RACE_CLASSES, bias=True))
    model.eval()
    model, results, _, _ = shared_utils.load_statedict(model, model_path, Configs, device='cpu')
    models_dict[model_version] = model

2022-08-20 10:57: Loading model - CheXpert/race_prediction/trained_models/2022_07_19-15_29__densenet121_race_denseblock1_freezed__epoch-9__iter-10266__batch_size-16__trainLastLoss-0.2805__validAUC-0.9178.dict
2022-08-20 10:57: Loading model - CheXpert/race_prediction/trained_models/2022_07_19-02_27__densenet121_race_denseblock2_freezed__epoch-9__iter-10266__batch_size-16__trainLastLoss-0.2811__validAUC-0.9094.dict
2022-08-20 10:57: Loading model - CheXpert/race_prediction/trained_models/2022_07_18-10_11__densenet121_race_classifier_freezed__epoch-9__iter-10266__batch_size-16__trainLastLoss-0.732__validAUC-0.6978.dict
2022-08-20 10:57: Loading model - CheXpert/race_prediction/trained_models/2022_07_21-13_21__densenet121_race_denseblock2_shallow__epoch-9__iter-10266__batch_size-16__trainLastLoss-0.3791__validAUC-0.9075.dict
2022-08-20 10:57: Loading model - CheXpert/race_prediction/trained_models/2022_07_18-21_46__densenet121_race_denseblock3_freezed__epoch-9__iter-10266__batch_size-16__

## Predictions 

In [39]:
df_res_race = pd.DataFrame(columns=Configs.RACE_ANNOTATIONS_COLUMNS + ['Mean'])
df_res_race

Unnamed: 0,Asian,Black,Hispanic,White,Mean


In [40]:
for model_version, model in tqdm(models_dict.items()):
    model = to_gpu(model)
    cxp_race_labels, cxp_race_outputs = shared_utils.get_metric_tensors(model, cxp_race_valid_dataloader, Configs,
                                                                        apply_on_outputs=lambda x: torch.softmax(x, dim=1),
                                                                        by_study=False, challenge_ann_only=None)
    cxr_race_labels, cxr_race_outputs = shared_utils.get_metric_tensors(model, cxr_race_valid_dataloader, Configs,
                                                                        apply_on_outputs=lambda x: torch.softmax(x, dim=1),
                                                                        by_study=False, challenge_ann_only=None)
    df_res_race.loc[f"CXP_{model_version}"] = add_mean_to_list(shared_utils.auc_score(cxp_race_labels, cxp_race_outputs, per_class=True))
    df_res_race.loc[f"CXR_{model_version}"] = add_mean_to_list(shared_utils.auc_score(cxr_race_labels, cxr_race_outputs, per_class=True))
    model.cpu()

  0%|          | 0/8 [00:00<?, ?it/s]

In [41]:
cxp_race_valid_dataloader.dataset.df_labels#.loc[idx]

Unnamed: 0,original_path,Sex,Age,Frontal/Lateral,AP/PA,No Finding,Enlarged Cardiomediastinum,Cardiomegaly,Lung Opacity,Lung Lesion,Edema,Consolidation,Pneumonia,Atelectasis,Pneumothorax,Pleural Effusion,Pleural Other,Fracture,Support Devices,patient_id,PATIENT,GENDER,AGE_AT_CXR,PRIMARY_RACE,ETHNICITY,race,Asian,Black,Hispanic,White,age,gender,img_path,study,view
0,CheXpert-v1.0-small/train/patient27841/study2/...,Female,31,Frontal,PA,,,0.0,1.0,1.0,0.0,0.0,,0.0,0.0,1.0,,,1.0,patient27841,patient27841,Female,31.0,Asian,Non-Hispanic/Non-Latino,Asian,1.0,0.0,0.0,0.0,20-40,Female,/home/student/MLH/debiasing-racial-effect-in-m...,study2,view1_frontal.jpg
1,CheXpert-v1.0-small/train/patient34905/study1/...,Female,39,Frontal,AP,,,0.0,,,0.0,0.0,,0.0,0.0,0.0,,0.0,,patient34905,patient34905,Female,43.0,Asian,Non-Hispanic/Non-Latino,Asian,1.0,0.0,0.0,0.0,20-40,Female,/home/student/MLH/debiasing-racial-effect-in-m...,study1,view1_frontal.jpg
2,CheXpert-v1.0-small/train/patient11116/study1/...,Female,29,Frontal,PA,,,0.0,1.0,,0.0,0.0,,0.0,,0.0,,,,patient11116,patient11116,Female,29.0,Asian,Non-Hispanic/Non-Latino,Asian,1.0,0.0,0.0,0.0,20-40,Female,/home/student/MLH/debiasing-racial-effect-in-m...,study1,view1_frontal.jpg
3,CheXpert-v1.0-small/train/patient28123/study1/...,Female,29,Frontal,PA,1.0,,0.0,,,0.0,0.0,,0.0,0.0,0.0,,,,patient28123,patient28123,Female,29.0,Asian,Non-Hispanic/Non-Latino,Asian,1.0,0.0,0.0,0.0,20-40,Female,/home/student/MLH/debiasing-racial-effect-in-m...,study1,view1_frontal.jpg
4,CheXpert-v1.0-small/train/patient43253/study2/...,Female,34,Frontal,AP,,,0.0,,,1.0,0.0,,0.0,,0.0,,,,patient43253,patient43253,Female,34.0,Asian,Non-Hispanic/Non-Latino,Asian,1.0,0.0,0.0,0.0,20-40,Female,/home/student/MLH/debiasing-racial-effect-in-m...,study2,view1_frontal.jpg
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
955,CheXpert-v1.0-small/train/patient45732/study1/...,Male,79,Lateral,,1.0,,0.0,,,0.0,0.0,,0.0,0.0,0.0,,,-1.0,patient45732,patient45732,Male,79.0,White,Non-Hispanic/Non-Latino,White,0.0,0.0,0.0,1.0,70-90,Male,/home/student/MLH/debiasing-racial-effect-in-m...,study1,view2_lateral.jpg
956,CheXpert-v1.0-small/train/patient28971/study3/...,Male,77,Frontal,AP,,,0.0,1.0,,0.0,0.0,,0.0,1.0,1.0,,,1.0,patient28971,patient28971,Male,77.0,White,Non-Hispanic/Non-Latino,White,0.0,0.0,0.0,1.0,70-90,Male,/home/student/MLH/debiasing-racial-effect-in-m...,study3,view1_frontal.jpg
957,CheXpert-v1.0-small/train/patient13442/study1/...,Male,88,Lateral,,,,0.0,,,0.0,1.0,-1.0,1.0,,0.0,,,,patient13442,patient13442,Male,88.0,"White, non-Hispanic",Unknown,White,0.0,0.0,0.0,1.0,70-90,Male,/home/student/MLH/debiasing-racial-effect-in-m...,study1,view2_lateral.jpg
958,CheXpert-v1.0-small/train/patient31369/study2/...,Male,85,Frontal,AP,,,0.0,,,1.0,0.0,1.0,0.0,,0.0,,,,patient31369,patient31369,Male,80.0,White,Non-Hispanic/Non-Latino,White,0.0,0.0,0.0,1.0,70-90,Male,/home/student/MLH/debiasing-racial-effect-in-m...,study2,view1_frontal.jpg


In [42]:
df_res_race.sort_values(by="Mean", ascending=False, inplace=True)
df_res_race = df_res_race.round(2)

In [43]:
df_res_race

Unnamed: 0,Asian,Black,Hispanic,White,Mean
CXP_densenet121_race_denseblock2_freezed,0.95,0.95,0.77,0.9,0.9
CXP_densenet121_race,0.95,0.95,0.77,0.89,0.89
CXP_densenet121_race_denseblock1_freezed,0.94,0.94,0.72,0.9,0.88
CXP_densenet121_race_denseblock3_freezed,0.93,0.93,0.75,0.89,0.88
CXP_densenet121_race_denseblock4_freezed,0.92,0.9,0.77,0.83,0.85
CXP_densenet121_race_denseblock2_shallow,0.93,0.93,0.59,0.87,0.83
CXR_densenet121_race,0.9,0.87,0.66,0.87,0.83
CXR_densenet121_race_denseblock1_freezed,0.9,0.88,0.61,0.88,0.81
CXR_densenet121_race_denseblock2_freezed,0.87,0.88,0.61,0.86,0.81
CXR_densenet121_race_denseblock2_shallow,0.88,0.86,0.58,0.86,0.8


## Performance Per Protected Groups

In [44]:
# race, age group, and gender to ChexPertDiseaseDatatset
cxp_race_df_labels = cxp_race_valid_dataset.df_labels[['PATIENT', 'race', 'AGE_AT_CXR', 'GENDER'] + Configs.RACE_ANNOTATIONS_COLUMNS].drop_duplicates()
cxp_race_df_labels['age'] = cxp_race_df_labels.AGE_AT_CXR.apply(shared_utils.age_to_age_group)
cxp_race_df_labels['gender'] = cxp_race_df_labels.GENDER
cxp_race_df_labels.head(2)

Unnamed: 0,PATIENT,race,AGE_AT_CXR,GENDER,Asian,Black,Hispanic,White,age,gender
0,patient27841,Asian,31.0,Female,1.0,0.0,0.0,0.0,20-40,Female
1,patient34905,Asian,43.0,Female,1.0,0.0,0.0,0.0,40-70,Female


In [45]:
cxr_race_df_labels = cxr_race_valid_dataset.df_labels.copy()
cxr_race_df_labels.gender.replace({"M": "Male", "F": "Female"}, inplace=True)
cxr_race_df_labels.head(2)

Unnamed: 0,subject_id,study_id,split,dicom_id,ethnicity,race,age,gender,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,folder_number,img_path,Asian,Black,Hispanic,White
0,18855302,53537225,train,bafbef15-550a6520-fbcd3b2c-81552b65-3050c322,ASIAN,Asian,20-40,Female,0.0,0.0,0.0,0.0,0.0,18,data/MIMIC-CXR-JPG/physionet.org/files/mimic-c...,1,0,0,0
1,10296904,51710336,train,16c1bc0c-90207d40-b93a861d-ad0bc18e-ac97afb2,ASIAN,Asian,20-40,Female,0.0,0.0,0.0,0.0,0.0,10,data/MIMIC-CXR-JPG/physionet.org/files/mimic-c...,1,0,0,0


In [46]:
model = models_dict['densenet121_race']
model = to_gpu(model)
cxp_race_labels, cxp_race_outputs = shared_utils.get_metric_tensors(model, cxp_race_valid_dataloader, Configs,
                                                                    apply_on_outputs=lambda x: torch.softmax(x, dim=1),
                                                                    by_study=False, challenge_ann_only=None)
cxr_race_labels, cxr_race_outputs = shared_utils.get_metric_tensors(model, cxr_race_valid_dataloader, Configs,
                                                                    apply_on_outputs=lambda x: torch.softmax(x, dim=1),
                                                                    by_study=False, challenge_ann_only=None)

In [47]:
shared_utils.auc_per_protected_group(cxp_race_df_labels, Mode.Race, Configs, cxp_race_labels,
                                     cxp_race_outputs, protected_groups=['gender'])

Unnamed: 0_level_0,Asian,Black,Hispanic,White,Mean
gender,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Female,0.96,0.96,0.8,0.91,0.9075
Male,0.95,0.96,0.78,0.92,0.9025


In [48]:
shared_utils.auc_per_protected_group(cxp_race_df_labels, Mode.Race, Configs, cxp_race_labels,
                                     cxp_race_outputs, protected_groups=['age'])

Unnamed: 0_level_0,Asian,Black,Hispanic,White,Mean
age,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
20-40,0.93,0.98,0.81,0.91,0.9075
40-70,0.97,0.95,0.78,0.91,0.9025
70-90,0.96,0.97,0.81,0.95,0.9225


In [49]:
shared_utils.auc_per_protected_group(cxr_race_df_labels, Mode.Race, Configs, cxr_race_labels,
                                     cxr_race_outputs, protected_groups=['gender'])

Unnamed: 0_level_0,Asian,Black,Hispanic,White,Mean
gender,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Female,0.9,0.86,0.66,0.86,0.82
Male,0.91,0.88,0.66,0.87,0.83


In [50]:
shared_utils.auc_per_protected_group(cxr_race_df_labels, Mode.Race, Configs, cxr_race_labels,
                                     cxr_race_outputs, protected_groups=['age'])

Unnamed: 0_level_0,Asian,Black,Hispanic,White,Mean
age,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
20-40,0.87,0.88,0.65,0.84,0.81
40-70,0.93,0.89,0.67,0.89,0.845
70-90,0.92,0.87,0.66,0.87,0.83


In [51]:
shared_utils.auc_per_protected_group(cxr_race_df_labels, Mode.Race, Configs, cxr_race_labels,
                                     cxr_race_outputs, protected_groups=['gender', 'age'])

Unnamed: 0_level_0,Unnamed: 1_level_0,Asian,Black,Hispanic,White,Mean
gender,age,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Female,20-40,0.88,0.88,0.64,0.85,0.8125
Female,40-70,0.93,0.88,0.67,0.89,0.8425
Female,70-90,0.9,0.86,0.67,0.85,0.82
Male,20-40,0.87,0.88,0.66,0.83,0.81
Male,40-70,0.93,0.89,0.68,0.88,0.845
Male,70-90,0.94,0.88,0.66,0.89,0.8425


# Grad-Cam

In [52]:
model = models_dict['densenet121_race']
model = to_gpu(model)
model

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu

In [53]:
imgs, _ = next(iter(cxp_race_valid_dataloader))
imgs = to_gpu(imgs)
pred = model(imgs)

In [54]:
img = imgs[0]
pred = pred[0]

In [55]:
pred.argmax(dim=1)

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [None]:
pred = model(img).argmax(dim=1)

In [None]:
class DenseNet(nn.Module):
    def __init__(self):
        super(DenseNet, self).__init__()
        
        # get the pretrained DenseNet201 network
        self.densenet = densenet201(pretrained=True)
        
        # disect the network to access its last convolutional layer
        self.features_conv = self.densenet.features
        
        # add the average global pool
        self.global_avg_pool = nn.AvgPool2d(kernel_size=7, stride=1)
        
        # get the classifier of the vgg19
        self.classifier = self.densenet.classifier
        
        # placeholder for the gradients
        self.gradients = None
    
    # hook for the gradients of the activations
    def activations_hook(self, grad):
        self.gradients = grad
        
    def forward(self, x):
        x = self.features_conv(x)
        
        # register the hook
        h = x.register_hook(self.activations_hook)
        
        # don't forget the pooling
        x = self.global_avg_pool(x)
        x = x.view((1, 1920))
        x = self.classifier(x)
        return x
    
    def get_activations_gradient(self):
        return self.gradients
    
    def get_activations(self, x):
        return self.features_conv(x)