In [1]:
from HistoTransfer.compute_feat import *
from HistoTransfer.dataloader import *
from HistoTransfer.model import *
from HistoTransfer.train import *
from HistoTransfer.eval_model import *
from HistoTransfer.utils import *

import torch.optim as optim

import random

import numpy as np

# Seed every RNG the notebook may touch so the Stage 1 / Stage 2 training runs
# (e.g. weight initialisation, random sampling) are repeatable under Restart & Run All.
# GPU kernels can still be non-deterministic unless cuDNN determinism is also enabled.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Stage 1

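- Read the patch CSV files, with columns:
    - path: file location of each patch
    - wsi: unique identifier of the whole-slide image (WSI) the patch belongs to
    - label: WSI label (binary, 0 or 1)
    - is_valid: whether the WSI belongs to the validation cohort
- Get the ImageNet-pretrained base model (backbone)
- Compute patch features with the frozen backbone
- Train an attention-based WSI classifier on those features
- Use the attention scores to identify the top patches
- Generate a filtered CSV of the top patches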

#### Paths and Files

In [2]:
# Patch-level CSVs for the train/validation cohort and the held-out test cohort.
train_val_split_csv = 'data/11-3-2021 celiac_normal_train_valid_split.csv'
test_split_csv = 'data/11-3-2021 celiac_normal_test_split.csv'

# Output directory handed to compute_feat_wsi, and the Stage 1 checkpoint path
# (passed as fpath to train_csv_model and reloaded for Stage 2 fine-tuning).
feature_csv_path = 'csv/backbone_resnet18_imagenet_norm_imagenet/'
model_csv_path = 'trained_model/12-6-2021 celiac_csv_model_alpha1_resnet18_lr1e3.pt'

df_train_val, df_test = (pd.read_csv(csv_file) for csv_file in (train_val_split_csv, test_split_csv))

#### Get Backbone Model for Feature Generation

In [3]:
def get_backbone(model_name='resnet', truncate_layer='layer4'):
    """Build an ImageNet-pretrained CNN truncated after a named layer and globally pooled.

    Args:
        model_name: 'resnet' (torchvision resnet18) or 'densenet' (torchvision densenet121).
        truncate_layer: name of the last child module to keep.
            - resnet: a direct child of the ResNet, e.g. 'layer3' or 'layer4'. Must exist.
            - densenet: a child of ``densenet.features``, e.g. 'denseblock3'. If the name
              is not found, every feature layer is kept, so the default 'layer4' yields
              the full DenseNet feature extractor.

    Returns:
        nn.Sequential: the truncated backbone followed by a global average pool, i.e. a
        (batch, channels, 1, 1) output for a 4-D image batch.

    Raises:
        ValueError: if ``model_name`` is unsupported, or ``truncate_layer`` is not a child
            of resnet18.
    """
    # Fail fast with a clear message instead of falling through to an unbound `model`.
    if model_name not in ('resnet', 'densenet'):
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected 'resnet' or 'densenet'")

    if model_name == 'resnet':
        resnet = models.resnet18(pretrained=True)
        module_list = []
        for name, module in resnet.named_children():
            module_list.append(module)
            if name == truncate_layer:
                break
        else:
            # Without a match the loop also collects avgpool and fc; fc cannot consume a
            # 4-D feature map, so reject here rather than crash at forward time.
            raise ValueError(
                f"truncate_layer {truncate_layer!r} is not a child module of resnet18")
        # Re-attach ResNet's own global average pool after the cut.
        model = nn.Sequential(*module_list, resnet.avgpool)
    else:  # 'densenet'
        densenet = models.densenet121(pretrained=True)
        module_list = []
        for name, module in densenet.features.named_children():
            module_list.append(module)
            if name == truncate_layer:
                break
        # NOTE(review): torchvision's DenseNet.forward applies a ReLU after `features`
        # before pooling; that ReLU is not reproduced here — confirm this is intended.
        model = nn.Sequential(*module_list, nn.AdaptiveAvgPool2d(output_size=(1, 1)))

    return model

#### Generate Features

In [14]:
# Frozen ImageNet ResNet-18 backbone, truncated after layer4 and globally pooled.
backbone = get_backbone('resnet', 'layer4')
backbone = backbone.to(device)
backbone.eval()  # inference behaviour for BatchNorm/dropout layers

# Data transformation: albumentations' default Normalize, then ToTensorV2.
data_transforms = albumentations.Compose([
    albumentations.Normalize(),
    ToTensorV2(),
])

# Compute features for every WSI (train/validation and test) in one pass.
df_all_patches = pd.concat([df_train_val, df_test])
wsi_to_paths = dict(df_all_patches.groupby('wsi')['path'].apply(list))

# Feature extraction only needs forward passes. Disabling autograd avoids building graphs
# and keeping activations alive; it is harmless if compute_feat_wsi already does this.
with torch.no_grad():
    compute_feat_wsi(wsi_to_paths, backbone, data_transforms, output_path=feature_csv_path)

100%|██████████| 22/22 [03:01<00:00,  8.23s/it]


#### Frozen Feature Model

In [6]:
# Stage 1 model: WSIFeatClassifier operating on the 512-dimensional ResNet-18 features.
model = WSIFeatClassifier(feat_dim=512).to(device)

# criterion_dic maps loss names to loss modules; only cross-entropy ('CE') is supplied.
criterion_ce = nn.CrossEntropyLoss()
criterion_dic = dict(CE=criterion_ce)

# Adam over every parameter of the feature classifier.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [7]:
# Stage 1 training on the precomputed feature CSVs.
# NOTE(review): alpha/beta presumably weight separate loss terms (here alpha=1, beta=0) —
# confirm in HistoTransfer.train. The checkpoint is saved to model_csv_path (fpath).
model = train_csv_model(
    model,
    criterion_dic,
    optimizer,
    df_train_val,
    feature_csv_path,
    alpha=1.0,
    beta=0.0,
    num_epochs=20,
    fpath=model_csv_path,
)

Epoch 0/19
----------
train Loss WSI: 0.7300  Acc: 0.6667
val Loss WSI: 0.6813  Acc: 0.6000
Epoch 1/19
----------
train Loss WSI: 0.6447  Acc: 0.6667
val Loss WSI: 0.6789  Acc: 0.6000
Epoch 2/19
----------
train Loss WSI: 0.6409  Acc: 0.6667
val Loss WSI: 0.6709  Acc: 0.6000
Epoch 3/19
----------
train Loss WSI: 0.6852  Acc: 0.6667
val Loss WSI: 0.7064  Acc: 0.6000
Epoch 4/19
----------
train Loss WSI: 0.6605  Acc: 0.6667
val Loss WSI: 0.6671  Acc: 0.6000
Epoch 5/19
----------
train Loss WSI: 0.6405  Acc: 0.6667
val Loss WSI: 0.6785  Acc: 0.6000
Epoch 6/19
----------
train Loss WSI: 0.6511  Acc: 0.6667
val Loss WSI: 0.6615  Acc: 0.6000
Epoch 7/19
----------
train Loss WSI: 0.6210  Acc: 0.6667
val Loss WSI: 0.6722  Acc: 0.6000
Epoch 8/19
----------
train Loss WSI: 0.6172  Acc: 0.6667
val Loss WSI: 0.6750  Acc: 0.6000
Epoch 9/19
----------
train Loss WSI: 0.6160  Acc: 0.6667
val Loss WSI: 0.6592  Acc: 0.6000
Epoch 10/19
----------
train Loss WSI: 0.6064  Acc: 0.6667
val Loss WSI: 0.6449 

#### Evaluate Stage 1 Model on Test Data

In [8]:
# Stage 1 test-set evaluation; the returned (accuracy, AUC) tuple is shown as cell output.
stage1_test_metrics = validate_csv_model(df_test, model, feature_csv_path)
stage1_test_metrics

100%|██████████| 22/22 [00:00<00:00, 29.47it/s]


Accuracy: 0.8181818181818182
Auc Score: 0.8928571428571428


(0.8181818181818182, 0.8928571428571428)

### Stage 2

- Extract the top-attended patches (top 64) using the Stage 1 attention scores
- Fine-tune the full model end-to-end on those patches

In [9]:
# Reuse the Stage 1 model's attention to select the top-attended patches.
enc_attn = EncAttn(model).to(device)

attn_patches = get_attn_patches(df_train_val, enc_attn, feature_csv_path)

# Attach each selected patch's validation flag. merge() joins on every column the two
# frames share (inner join). NOTE(review): if get_attn_patches already returns
# 'is_valid', the join also matches on it — confirm the intended key is just 'path'.
path_split = df_train_val[['path', 'is_valid']]
df_attn_train_val = attn_patches.merge(path_split)

#### Fine-tuning Model

In [17]:
# Stage 2 model: image-level WSIClassifier with a ResNet-18 backbone.
model = WSIClassifier(base_model='resnet18')
# set_bn_eval is applied to every submodule; by its name it presumably puts BatchNorm
# layers in eval mode. NOTE(review): a later model.train() call would undo that —
# confirm finetune_model preserves it.
model.apply(set_bn_eval)
# Move to the target device once, after all module-level setup.
model = model.to(device)

In [18]:
# Fine-tuning uses the same preprocessing as feature extraction: albumentations'
# default Normalize (ImageNet mean/std), then ToTensorV2 for a torch tensor.
data_transforms = albumentations.Compose(
    [albumentations.Normalize(), ToTensorV2()]
)

In [19]:
# Loss: plain cross-entropy, stored under the 'CE' key of criterion_dic.
criterion_ce = nn.CrossEntropyLoss()
criterion_dic = dict(CE=criterion_ce)

# Fine-tuning passes every parameter of WSIClassifier to Adam, with a small learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=1e-5)

In [20]:
# Initialise the attention/classifier head of the Stage 2 model from the Stage 1 checkpoint.
# Only the layers listed below are copied; all other parameters keep the values
# WSIClassifier was constructed with.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
# NOTE: torch.load unpickles the file — only load checkpoints you produced yourself.
checkpoint = torch.load(model_csv_path, map_location=device)
csv_state_dict = checkpoint['state_dict']

shared_layers = ['tail.0', 'tail.2', 'attention.0', 'attention.2',
                 'classifier.0', 'patch_classifier.0']

model_params = dict(model.named_parameters())
with torch.no_grad():
    for layer in shared_layers:
        for param_name in ('weight', 'bias'):
            key = f'{layer}.{param_name}'
            target, source = model_params[key], csv_state_dict[key]
            # copy_ broadcasts silently, so require identical shapes explicitly.
            if target.shape != source.shape:
                raise ValueError(
                    f'{key}: model shape {tuple(target.shape)} != '
                    f'checkpoint shape {tuple(source.shape)}')
            target.copy_(source)

In [23]:
finetune_model_path = 'trained_model/12-6-2021 celiac_finetune_model_alpha1_resnet18.pt'

# Stage 2: fine-tune on the top-attended patches (df_attn_train_val); the checkpoint is
# saved to finetune_model_path via fpath.
model = finetune_model(
    model,
    criterion_dic,
    optimizer,
    df_attn_train_val,
    data_transforms,
    alpha=1.0,
    beta=0.0,
    num_epochs=20,
    fpath=finetune_model_path,
)

  0%|          | 0/12 [00:00<?, ?it/s]

Epoch 0/19
----------


100%|██████████| 12/12 [00:34<00:00,  2.87s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.5864  Acc: 0.7500


100%|██████████| 10/10 [00:18<00:00,  1.84s/it]


val Loss WSI: 0.5322  Acc: 0.6000


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch 1/19
----------


100%|██████████| 12/12 [00:33<00:00,  2.75s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.5773  Acc: 0.8333


100%|██████████| 10/10 [00:16<00:00,  1.70s/it]


val Loss WSI: 0.5106  Acc: 0.7000


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch 2/19
----------


100%|██████████| 12/12 [00:31<00:00,  2.61s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.4948  Acc: 0.8333


100%|██████████| 10/10 [00:15<00:00,  1.50s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

val Loss WSI: 0.5525  Acc: 0.6000
Epoch 3/19
----------


100%|██████████| 12/12 [00:31<00:00,  2.64s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.5287  Acc: 0.6667


100%|██████████| 10/10 [00:16<00:00,  1.62s/it]


val Loss WSI: 0.4997  Acc: 0.7000


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch 4/19
----------


100%|██████████| 12/12 [00:30<00:00,  2.56s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.4954  Acc: 0.8333


100%|██████████| 10/10 [00:17<00:00,  1.71s/it]


val Loss WSI: 0.4662  Acc: 0.7000


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch 5/19
----------


100%|██████████| 12/12 [00:34<00:00,  2.85s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.5731  Acc: 0.8333


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

val Loss WSI: 0.6020  Acc: 0.6000
Epoch 6/19
----------


100%|██████████| 12/12 [00:31<00:00,  2.63s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.6155  Acc: 0.5833


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


val Loss WSI: 0.5208  Acc: 1.0000


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch 7/19
----------


100%|██████████| 12/12 [00:34<00:00,  2.88s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.4478  Acc: 0.9167


100%|██████████| 10/10 [00:13<00:00,  1.37s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

val Loss WSI: 0.5500  Acc: 0.6000
Epoch 8/19
----------


100%|██████████| 12/12 [00:31<00:00,  2.59s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.4622  Acc: 0.7500


100%|██████████| 10/10 [00:13<00:00,  1.36s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

val Loss WSI: 0.5758  Acc: 0.6000
Epoch 9/19
----------


100%|██████████| 12/12 [00:28<00:00,  2.41s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.4243  Acc: 0.8333


100%|██████████| 10/10 [00:14<00:00,  1.43s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

val Loss WSI: 0.4375  Acc: 0.7000
Epoch 10/19
----------


100%|██████████| 12/12 [00:32<00:00,  2.71s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.4011  Acc: 0.8333


100%|██████████| 10/10 [00:16<00:00,  1.67s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

val Loss WSI: 0.4439  Acc: 0.7000
Epoch 11/19
----------


100%|██████████| 12/12 [00:32<00:00,  2.75s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.4202  Acc: 0.7500


100%|██████████| 10/10 [00:15<00:00,  1.58s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

val Loss WSI: 0.4717  Acc: 0.7000
Epoch 12/19
----------


100%|██████████| 12/12 [00:31<00:00,  2.63s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.4805  Acc: 0.8333


100%|██████████| 10/10 [00:14<00:00,  1.42s/it]


val Loss WSI: 0.3904  Acc: 1.0000


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch 13/19
----------


100%|██████████| 12/12 [00:32<00:00,  2.75s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.3832  Acc: 0.8333


100%|██████████| 10/10 [00:16<00:00,  1.65s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

val Loss WSI: 0.3748  Acc: 0.7000
Epoch 14/19
----------


100%|██████████| 12/12 [00:32<00:00,  2.68s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.3831  Acc: 0.8333


100%|██████████| 10/10 [00:15<00:00,  1.54s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

val Loss WSI: 0.3570  Acc: 0.9000
Epoch 15/19
----------


100%|██████████| 12/12 [00:32<00:00,  2.70s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.3128  Acc: 1.0000


100%|██████████| 10/10 [00:16<00:00,  1.64s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

val Loss WSI: 0.3437  Acc: 0.9000
Epoch 16/19
----------


100%|██████████| 12/12 [00:31<00:00,  2.61s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.2716  Acc: 0.8333


100%|██████████| 10/10 [00:16<00:00,  1.62s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

val Loss WSI: 0.3511  Acc: 0.7000
Epoch 17/19
----------


100%|██████████| 12/12 [00:30<00:00,  2.55s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.2704  Acc: 0.8333


100%|██████████| 10/10 [00:14<00:00,  1.40s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

val Loss WSI: 0.3943  Acc: 0.7000
Epoch 18/19
----------


100%|██████████| 12/12 [00:31<00:00,  2.64s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.2582  Acc: 0.9167


100%|██████████| 10/10 [00:14<00:00,  1.46s/it]
  0%|          | 0/12 [00:00<?, ?it/s]

val Loss WSI: 0.4415  Acc: 0.7000
Epoch 19/19
----------


100%|██████████| 12/12 [00:30<00:00,  2.54s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train Loss WSI: 0.2574  Acc: 0.9167


100%|██████████| 10/10 [00:16<00:00,  1.66s/it]


val Loss WSI: 0.2637  Acc: 1.0000
Training complete in 15m 60s
Best val Acc: 1.000000


#### Evaluate Fine-tuned Model on Test Data

In [26]:
# Evaluate the fine-tuned model on the held-out test WSIs. eval_test prints the test
# accuracy and AUC; its return value (presumably per-WSI predictions) is kept as pred_df.
pred_df = eval_test(
    model,
    df_test,
    data_transforms,
)

100%|██████████| 22/22 [01:28<00:00,  4.04s/it]

Test Accuracy:  0.9090909090909091
AUC Score:  0.9969098284797144



