In [1]:
import torch
import os
import random
import utils
import data_utils
import similarity
import argparse
import datetime
import json

from glm_saga.elasticnet import IndexedTensorDataset, glm_saga
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# Parameter setting
# Every tunable value for the run lives in this cell; later cells only read them.
dataset = "covid"
concept_set = "data/concept_sets/covid_filtered_new.txt"
backbone = "vit"
clip_name = "ViT-B/16"
device = "cuda"
batch_size = 16            # batch size used when extracting and saving activations
saga_batch_size = 256      # batch size of the GLM-SAGA train/val loaders
proj_batch_size = 50000    # upper bound on the projection-layer batch (capped to the dataset size later)
feature_layer = 'norm'
activation_dir = 'saved_activations'
save_dir = 'saved_models'
clip_cutoff = 0.21               # concepts whose mean top-5 CLIP similarity is <= this are dropped
proj_steps = 20000               # maximum number of projection-layer optimisation steps
interpretability_cutoff = 0.15   # concepts whose validation similarity is <= this are dropped
lam = 0.00007                    # regularisation strength handed to glm_saga through metadata['max_reg']
n_iters = 10000
prints = True                    # print every concept that gets deleted by a filtering stage
seed = 0

# Seed the RNGs used further down (random.sample for batch selection, torch for
# layer initialisation) so that Restart Kernel -> Run All reproduces the same run.
random.seed(seed)
torch.manual_seed(seed)

# The parser is kept from the command-line version of this code; the cells shown
# in this notebook never call parse_args() and use `prints` above instead.
parser = argparse.ArgumentParser(description='Settings for creating CBM')
parser.add_argument("--print", action='store_true', help="Print all concepts being deleted in this stage")

_StoreTrueAction(option_strings=['--print'], dest='print', nargs=0, const=True, default=False, type=None, choices=None, required=False, help='Print all concepts being deleted in this stage', metavar=None)

In [3]:
# Make sure the output directory exists. makedirs(exist_ok=True) replaces the
# exists()+mkdir() pair: it has no check-then-create race and also creates
# missing parent directories.
os.makedirs(save_dir, exist_ok=True)
if concept_set is None:
    concept_set = "data/concept_sets/{}_filtered.txt".format(dataset)
similarity_fn = similarity.cos_similarity_cubed_single

# Probe-dataset identifiers for the train and validation splits.
d_train = dataset + "_train"
d_val = dataset + "_val"

# NOTE(review): both files are split on "\n" without dropping empty entries, so a
# trailing newline produces an empty final class/concept. The lists are left
# unfiltered on purpose: later cells index the concept list position-by-position
# against the saved text features, so entries must not be removed here.
cls_file = data_utils.LABEL_FILES[dataset]
with open(cls_file, "r") as f:
    classes = f.read().split("\n")

with open(concept_set) as f:
    concepts = f.read().split("\n")

num_classes = len(classes)
# save activations and get save_paths
for d_probe in [d_train, d_val]:
    utils.save_activations(clip_name = clip_name, target_name = backbone,
                            target_layers = [feature_layer], d_probe = d_probe,
                            concept_set = concept_set, batch_size = batch_size,
                            device = device, pool_mode = "avg", save_dir = activation_dir, dataset=dataset, num_classes=num_classes)


100%|██████████| 3/3 [00:00<00:00,  5.78it/s]
100%|██████████| 27/27 [00:14<00:00,  1.82it/s]
100%|██████████| 27/27 [00:13<00:00,  2.01it/s]
100%|██████████| 21/21 [00:12<00:00,  1.72it/s]
100%|██████████| 21/21 [00:11<00:00,  1.81it/s]


In [4]:
# Look up where the activations for each split were written. Each call returns
# (backbone features path, CLIP image features path, CLIP text features path);
# the text path from the validation call overwrites the one from the train call.
train_save_names = utils.get_save_names(clip_name, backbone, feature_layer,
                                        d_train, concept_set, "avg", activation_dir)
val_save_names = utils.get_save_names(clip_name, backbone, feature_layer,
                                      d_val, concept_set, "avg", activation_dir)
target_save_name, clip_save_name, text_save_name = train_save_names
val_target_save_name, val_clip_save_name, text_save_name = val_save_names

# One path per line, laid out exactly as the original comma-separated print.
print(target_save_name, clip_save_name, text_save_name, sep=" \n ")
print(val_target_save_name, val_clip_save_name, text_save_name, sep=" \n ")

saved_activations/covid_train_vit_norm.pt 
 saved_activations/covid_train_clip_ViT-B16.pt 
 saved_activations/covid_filtered_new_ViT-B16.pt
saved_activations/covid_val_vit_norm.pt 
 saved_activations/covid_val_clip_ViT-B16.pt 
 saved_activations/covid_filtered_new_ViT-B16.pt


In [5]:
def _load_unit_rows(path):
    """Load a saved 2-D feature tensor onto the CPU as float and scale every row to unit L2 norm."""
    feats = torch.load(path, map_location="cpu").float()
    feats.div_(feats.norm(dim=1, keepdim=True))
    return feats


with torch.no_grad():
    # Backbone features are used as stored (no normalisation).
    target_features = torch.load(target_save_name, map_location="cpu").float()
    val_target_features = torch.load(val_target_save_name, map_location="cpu").float()

    # With unit-norm rows on both sides, the matrix product below is the cosine
    # similarity between every image (row) and every concept text (column).
    unit_text = _load_unit_rows(text_save_name)
    clip_features = _load_unit_rows(clip_save_name) @ unit_text.T
    val_clip_features = _load_unit_rows(val_clip_save_name) @ unit_text.T

    # Only the similarity matrices are needed from here on.
    del unit_text

In [6]:
# Remove concepts that CLIP never matches strongly: for each concept (column of
# clip_features) average its 5 largest image similarities and keep the concept
# only when that average exceeds clip_cutoff.
top5_per_concept = torch.topk(clip_features, dim=0, k=5).values
highest = top5_per_concept.mean(dim=0)

if prints:
    for idx in range(len(concepts)):
        score = highest[idx]
        if score <= clip_cutoff:
            print("Deleting {}, CLIP top5:{:.3f}".format(concepts[idx], score))

surviving_concepts = []
for idx, name in enumerate(concepts):
    if highest[idx] > clip_cutoff:
        surviving_concepts.append(name)
concepts = surviving_concepts

Deleting clear, CLIP top5:0.205
Deleting hilar, CLIP top5:0.192
Deleting infiltrate, CLIP top5:0.206


In [7]:
# Free the full train similarity matrix and rebuild it for the surviving
# concepts only, instead of holding both versions in memory at once.
del clip_features
keep_mask = highest > clip_cutoff

with torch.no_grad():
    # Text features: keep the surviving concept rows, then normalise each row.
    text_features = torch.load(text_save_name, map_location="cpu").float()[keep_mask]
    text_features.div_(text_features.norm(dim=1, keepdim=True))

    image_features = torch.load(clip_save_name, map_location="cpu").float()
    image_features.div_(image_features.norm(dim=1, keepdim=True))

    clip_features = image_features @ text_features.T
    del image_features, text_features

# The validation matrix is already in memory, so it is simply column-sliced.
val_clip_features = val_clip_features[:, keep_mask]

In [8]:
 #learn projection layer
# Bias-free linear map from backbone features to one output per remaining
# concept, trained to maximise similarity_fn between its outputs and the CLIP
# image-concept similarities.
proj_layer = torch.nn.Linear(in_features=target_features.shape[1], out_features=len(concepts),
                                 bias=False).to(device)
opt = torch.optim.Adam(proj_layer.parameters(), lr=1e-3)
print(len(concepts))
indices = list(range(len(target_features)))

best_val_loss = float("inf")
best_step = 0
best_weights = None
# A batch cannot be larger than the training set (random.sample would raise).
proj_batch_size = min(proj_batch_size, len(target_features))
for i in range(proj_steps):
    opt.zero_grad()
    batch = torch.LongTensor(random.sample(indices, k=proj_batch_size))
    outs = proj_layer(target_features[batch].to(device).detach())
    # The loss is the negated similarity, so minimising it maximises similarity.
    loss = -similarity_fn(clip_features[batch].to(device).detach(), outs)

    loss = torch.mean(loss)
    loss.backward()
    opt.step()
    # Validate every 50 steps and on the final step.
    if i%50==0 or i==proj_steps-1:
        with torch.no_grad():
            val_output = proj_layer(val_target_features.to(device).detach())
            val_loss = -similarity_fn(val_clip_features.to(device).detach(), val_output)
            val_loss = torch.mean(val_loss)
        if i==0:
            best_val_loss = val_loss
            best_step = i
            # detach() so the snapshot is a plain tensor rather than a
            # grad-tracking clone that stays attached to the live parameter.
            best_weights = proj_layer.weight.detach().clone()
            print("Step:{}, Avg train similarity:{:.4f}, Avg val similarity:{:.4f}".format(best_step, -loss.cpu(),
                                                                                               -best_val_loss.cpu()))

        elif val_loss < best_val_loss:
            best_val_loss = val_loss
            best_step = i
            best_weights = proj_layer.weight.detach().clone()
        else: #stop if val loss starts increasing
            break

# Restore the weights from the best validation checkpoint.
proj_layer.load_state_dict({"weight":best_weights})
print("Best step:{}, Avg val similarity:{:.4f}".format(best_step, -best_val_loss.cpu()))

36
Step:0, Avg train similarity:-0.0099, Avg val similarity:0.0441
Best step:100, Avg val similarity:0.2957


In [9]:
# Delete concepts that are not interpretable: score every concept by
# similarity_fn between the projection outputs and the CLIP similarities on the
# validation split, and keep only those scoring above interpretability_cutoff.
with torch.no_grad():
    proj_layer = proj_layer.to(device)
    outs = proj_layer(val_target_features.to(device).detach())
    sim = similarity_fn(val_clip_features.to(device).detach(), outs)
    interpretable = sim > interpretability_cutoff

if prints:
    for i, concept in enumerate(concepts):
        if sim[i]<=interpretability_cutoff:
            print("Deleting {}, Iterpretability:{:.3f}".format(concept, sim[i]))

concepts = [concepts[i] for i in range(len(concepts)) if interpretable[i]]

# The CLIP similarity matrices are not needed past this point.
del clip_features, val_clip_features

# Keep only the rows of the projection matrix that belong to surviving concepts.
# detach() makes W_c a plain tensor: without it the slice is taken outside
# no_grad and the tensor that is later saved to disk would track gradients.
W_c = proj_layer.weight.detach()[interpretable]
proj_layer = torch.nn.Linear(in_features=target_features.shape[1], out_features=len(concepts), bias=False)
proj_layer.load_state_dict({"weight":W_c})

train_targets = data_utils.get_targets_only(d_train)
val_targets = data_utils.get_targets_only(d_val)

with torch.no_grad():
    # Feature fusion: the input to the final layer is the concept activations
    # concatenated with the raw backbone features. (Using proj_layer(...) alone
    # here would give the concept-only variant.)
    train_c = torch.cat([proj_layer(target_features.detach()), target_features.detach()], dim=1)
    val_c = torch.cat([proj_layer(val_target_features.detach()), val_target_features.detach()],dim=1)

    # Standardise both splits with statistics computed on the training split.
    train_mean = torch.mean(train_c, dim=0, keepdim=True)
    train_std = torch.std(train_c, dim=0, keepdim=True)
    # A constant column has zero std and would turn into inf/NaN when divided;
    # use a scale of 1 for such columns so they simply become all zeros.
    train_std[train_std == 0] = 1.0

    train_c -= train_mean
    train_c /= train_std

    train_y = torch.LongTensor(train_targets)
    indexed_train_ds = IndexedTensorDataset(train_c, train_y)

    val_c -= train_mean
    val_c /= train_std

    val_y = torch.LongTensor(val_targets)

    val_ds = TensorDataset(val_c,val_y)


indexed_train_loader = DataLoader(indexed_train_ds, batch_size=saga_batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=saga_batch_size, shuffle=False)


Deleting clear chest x-ray, Iterpretability:0.061
Deleting diffuse lesions, Iterpretability:0.148
Deleting multifocal lesions, Iterpretability:0.093
Deleting no pneumothorax, Iterpretability:0.045
Deleting perivascular cuffing, Iterpretability:0.138
Deleting the blood vessels, Iterpretability:0.092
Deleting the bronchi, Iterpretability:0.042


In [10]:
# Final linear classifier over the standardised fused features. It keeps
# PyTorch's default initialisation: zero-initialisation is not applied here and
# do_zero=False is passed to glm_saga below.
linear = torch.nn.Linear(train_c.shape[1],len(classes)).to(device)

STEP_SIZE = 0.1
ALPHA = 0.99
metadata = {'max_reg': {'nongrouped': lam}}

# Solve the GLM path
output_proj = glm_saga(linear, indexed_train_loader, STEP_SIZE, n_iters, ALPHA, epsilon=1, k=1,
                      val_loader=val_loader, do_zero=False, metadata=metadata, n_ex=len(target_features), n_classes = len(classes))
# Only the first entry of the returned path is used.
final_fit = output_proj['path'][0]
W_g = final_fit['weight']
b_g = final_fit['bias']


# The run directory is timestamped to the minute. exist_ok=True lets a rerun
# within the same minute reuse (and overwrite) the directory instead of raising
# FileExistsError after the model has already been trained.
save_name = "{}/{}_cbm_{}".format(save_dir, dataset, datetime.datetime.now().strftime("%Y_%m_%d_%H_%M"))
os.makedirs(save_name, exist_ok=True)
torch.save(train_mean, os.path.join(save_name, "proj_mean.pt"))
torch.save(train_std, os.path.join(save_name, "proj_std.pt"))
torch.save(W_c, os.path.join(save_name ,"W_c.pt"))
torch.save(W_g, os.path.join(save_name, "W_g.pt"))
torch.save(b_g, os.path.join(save_name, "b_g.pt"))

# One concept per line with no trailing newline; join() also copes with an
# empty concept list, where indexing concepts[0] would raise IndexError.
with open(os.path.join(save_name, "concepts.txt"), 'w') as f:
    f.write("\n".join(concepts))

# Run settings, every value stored as a string (same layout as before). The
# variable is no longer called `dict`, which shadowed the builtin for the rest
# of the kernel session.
run_args = {
  "clip_name": clip_name,
  "backbone": backbone,
  "device": device,
  "batch_size": batch_size,
  "saga_batch_size": saga_batch_size,
  "dataset": dataset,
  "concept_set": concept_set,
  "feature_layer": feature_layer,
  "activation_dir": activation_dir,
  "save_dir": save_dir,
  "clip_cutoff": clip_cutoff,
  "proj_steps": proj_steps,
  "interpretability_cutoff": interpretability_cutoff,
  "lam": lam,
  "n_iters": n_iters
}
run_args = {key: "{}".format(value) for key, value in run_args.items()}
with open(os.path.join(save_name, "args.txt"), 'w') as f:
    json.dump(run_args, f, indent=2)

# Build the metrics record before opening the file so that a failure while
# collecting it does not leave an empty metrics.txt behind.
out_dict = {}
for key in ('lam', 'lr', 'alpha', 'time'):
    out_dict[key] = float(final_fit[key])
out_dict['metrics'] = final_fit['metrics']
# Weights whose magnitude is at most 1e-5 are counted as zero.
nnz = (W_g.abs() > 1e-5).sum().item()
total = W_g.numel()
out_dict['sparsity'] = {"Non-zero weights":nnz, "Total weights":total, "Percentage non-zero":nnz/total}
with open(os.path.join(save_name, "metrics.txt"), 'w') as f:
    json.dump(out_dict, f, indent=2)

 22%|██▏       | 2224/10000 [00:11<00:41, 189.11it/s]

(0) lambda 0.0001, loss 0.0015, acc 1.0000 [val acc 0.8069] [test acc -1.0000], sparsity 0.548306148055207 [874/1594], time 11.77230191230774, lr 0.1000



