In [1]:
import json
import torch
import sys,os
sys.path.insert(0,os.getcwd())
from dataset import RachelDataset, collatevisualhash
import copy
from tqdm import tqdm
    

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the study index and key each record by its hash, i.e. the last component
# of the path stored as the record's first element.
with open('fake_data/datajson-prospective.json') as datajson_file:
    datas = json.load(datajson_file)
datadict = {data[0].split('/')[-1]: data for data in datas}
# Tools for figuring out the series and head models
def getseries(h):
    """List the series names of the study whose hash is ``h``.

    ``datadict[h]`` is the raw record from the data json. Its second element is a
    sequence of per-series entries, and the first element of each entry is the
    series name. A fresh list is returned on every call, so callers may mutate it.
    Raises KeyError when ``h`` is not in ``datadict``.
    """
    series_records = datadict[h][1]
    names = []
    for record in series_records:
        names.append(record[0])
    return names
    
def getheadinfo(configjson,taskname):
    """Return the head-config entry for ``taskname`` from the JSON file ``configjson``.

    In the head config used in this notebook an entry looks like
    ``[<.txt path>, [[<head checkpoint path>, <task id>], ...]]``.
    Raises KeyError if ``taskname`` is not a key of the (JSON object) config.
    """
    # Use a context manager so the config file handle is closed deterministically.
    with open(configjson) as config_file:
        infos = json.load(config_file)
    return infos[taskname]

In [3]:
# List the series recorded for one study (displayed below)
getseries(h='BRAIN_FAKE_20752')

['2D_PC_SAG',
 'Lt__Cor',
 'T1_AXIAL',
 'COR_RFRMT',
 'AXIAL__T1',
 'DWI_SAG',
 'coronal',
 'AX_SWI']

In [4]:
getheadinfo(configjson='fake_data/prospective-config.json', taskname='vascular_hemorrhagic_intracranial_hemorrhage')

['fake_data/prospective_classification/vascular_hemorrhagic_intracranial_hemorrhage.txt',
 [['2025-fake-data-heads/bestauc_vascular_hemorrhagic_intracranial_hemorrhage.pt',
   40]]]

In [5]:
# Every series of the study except (the first) 'DWI_SAG'; raises ValueError if it is absent.
series_list = getseries('BRAIN_FAKE_20752')
del series_list[series_list.index('DWI_SAG')]

In [6]:
# Fill in these settings before running the LIME analysis
device = 'cuda:0' # the device to compute the model with
outjsonname = 'lime_output.json' # the output importance dictionary json name
#clipmodelname = 'tempmodelsavesite/scratch/checkpoints712bigvit/154.pt' # the checkpoint for the clip model base to use
clipmodelname = 'ckpts/last.pt' # the checkpoint for the clip model base to use
taskname = 'vascular_hemorrhagic_intracranial_hemorrhage' # the name of the task
headconfigjson = 'fake_data/prospective-config.json' # the head config json listing [head checkpoint, task id] pairs per task
headmodelname = '2025-fake-data-heads/bestauc_vascular_hemorrhagic_intracranial_hemorrhage.pt' # the head model. You can choose this from the head config json file
# The task id (output column of the head) is looked up in the head config json instead of being
# hard-coded, so it always matches headmodelname. Raises KeyError if the head is not listed for the task.
classid = dict(getheadinfo(headconfigjson, taskname)[1])[headmodelname]
datapointhash = 'BRAIN_FAKE_20752' # The hash of the datapoint we want to look at
series_of_interest = 'DWI_SAG' # The series that we want to focus on
# Exclude every other series of the study so only the series of interest is analysed.
# Derived from datapointhash/series_of_interest so it stays in sync when either one changes.
study_series = getseries(datapointhash)
assert series_of_interest in study_series, f'{series_of_interest} not found in {datapointhash}: {study_series}'
exclude_series = [s for s in study_series if s != series_of_interest] # any series we want to exclude in this analysis
#exclude_series = [] # uncomment to keep every series of the study
lime_steps = 1000 # Number of lime steps. The more steps means more accurate lime visualization, but also takes longer


# alternative models
#clipmodelname = 'tempmodelsavesite/scratch/checkpoints720clip/141.pt' # The head config json file is configs/jsons/prospectivecla.json
#clipmodelname = 'tempmodelsavesite/scratch/checkpoints712bigvit/154.pt' # The head config json file is configs/jsons/prospectivecla2.json
#headmodelname = 'tempmodelsavesite/87-61-bigvitcont/bestauc_cyst_epidermoid_cyst.pt'

In [7]:
from lime import lime_base
from scipy import spatial
from sklearn.utils.validation import check_random_state
import numpy as np

class Lime_Explainer:
    """LIME explainer over the visual tokens of one series of a single datapoint.

    Each perturbed sample keeps a random subset of the tokens of the series of
    interest; all other series are left untouched. The classifier score of each
    perturbed input is recorded, and ``lime_base.LimeBase`` fits a weighted
    linear surrogate that attributes the score to individual tokens.
    """

    def __init__(self, kernelfn=None, feature_selection="none", verbose=False):
        """
        Args:
            kernelfn: maps an array of sample distances to sample weights. Defaults to
                ``sqrt(exp(-d**2 / 0.25**2))``, i.e. kernel width 0.25.
            feature_selection: forwarded to ``LimeBase.explain_instance_with_data``.
            verbose: forwarded to ``LimeBase``.
        """
        if kernelfn is None:

            def kernelfn(d):
                return np.sqrt(np.exp(-(d**2) / 0.25**2))

        self.base = lime_base.LimeBase(kernelfn, verbose)
        self.fs = feature_selection

    def explain_instance(
        self, inp, serie_of_interest, classfn, num_samples, seed=0, fracs=1,
        distance_metric="cosine",
    ):
        """Explain ``classfn``'s score on ``inp`` in terms of one series' tokens.

        Args:
            inp: datapoint dict. Reads ``'serienamestr'`` (series names) and the
                per-series ``'visual'`` / ``'coordinates'`` entries, which are indexed
                with a LongTensor of kept token indices. ``inp`` itself is not modified.
            serie_of_interest: name of the series whose tokens are perturbed.
            classfn: callable taking a one-element list of datapoints and returning
                a score convertible to float.
            num_samples: number of perturbed samples. Sample 0 is the unperturbed input.
            seed: passed to ``check_random_state`` to draw the masks.
            fracs: mask granularity. Mask values are ``k / fracs`` for ``k`` in 0..fracs.
                NOTE(review): fractional values (fracs > 1) pass through
                ``torch.LongTensor`` before token selection; confirm the intended semantics.
            distance_metric: metric name for ``scipy.spatial.distance.cdist``. It measures
                each mask's distance from the unperturbed all-ones mask, and these
                distances are fed to the kernel to weight samples. ``None`` gives every
                sample distance 0, i.e. uniform weights.

        Returns:
            The result of ``LimeBase.explain_instance_with_data`` for label 0, with
            one feature per token of the series of interest.

        Raises:
            ValueError: if ``serie_of_interest`` is not in ``inp['serienamestr']`` or
                the series has no tokens.
        """
        samples = num_samples
        randomstate = check_random_state(seed)
        series_ord = inp['serienamestr'].index(serie_of_interest)
        lentokens = len(inp['visual'][series_ord])
        if lentokens == 0:
            raise ValueError(f"series {serie_of_interest!r} has no visual tokens to explain")

        masks = (
            randomstate.randint(0, fracs + 1, lentokens*samples)
            .reshape(samples, lentokens)
            .astype(np.float64)
        )
        masks /= float(fracs)
        # Sample 0 is the unperturbed input (every token kept).
        masks[0] = 1
        # A sample that drops every token is replaced by the unperturbed input
        # (masks are non-negative, so a zero row sum means an all-zero mask).
        masks[masks.sum(axis=1) == 0.0] = 1

        # Distance of each perturbation from the unperturbed input. Previously this was
        # always zero, which made kernelfn irrelevant (all samples got equal weight).
        if distance_metric is None:
            distances = np.zeros(samples)
        else:
            distances = spatial.distance.cdist(
                masks, np.ones((1, lentokens)), metric=distance_metric
            ).ravel()

        llabels = np.zeros((samples, 1))
        for i in tqdm(range(samples)):
            # Deep copy so the caller's datapoint is never altered by token removal.
            newdata = copy.deepcopy(inp)
            tensormask = torch.LongTensor(masks[i])
            indices = torch.nonzero(tensormask)[:,0]
            newdata['visual'][series_ord] = newdata['visual'][series_ord][indices]
            newdata['coordinates'][series_ord] = newdata['coordinates'][series_ord][indices]
            llabels[i,0] = classfn([newdata])

        ret = self.base.explain_instance_with_data(
            masks, llabels, distances, 0, lentokens, feature_selection=self.fs
        )
        return ret

In [8]:
# Build the evaluation dataset without the excluded series and fetch the study to explain.
dataset = RachelDataset(datajson='fake_data/datajson-prospective.json',
            datarootdir='fake_data/data/',
            tokenizer='biomed',
            text_max_len=128,
            is_train=False,
            nosplit=True,
            vqvaename = 'FAKE_TOKENIZER',
            visualhashonly=True,
            percentage = 5,
            novisualaug = True,
            exclude_series = exclude_series
            )

datapoint = dataset.find_by_hash(datapointhash)
posmaps = datapoint['posmap']

# NOTE(review): torch.load unpickles whole model objects, so only load checkpoints from trusted sources.
clipmodel = torch.load(clipmodelname,map_location=device).module
# Inference only: switch to eval mode (e.g. disables dropout) so LIME labels are deterministic.
# This is done before copying the patchifier so the copy is in eval mode as well.
clipmodel.eval()
patchify = copy.deepcopy(clipmodel.patchifier).cpu()


collate = collatevisualhash(patchify, device, puttodevice=True)

visualclip = clipmodel.visual_model
visualclip.make_no_flashattn()
visualclip.patdis = False
head = torch.load(headmodelname,map_location=device)
head.eval()

explainer = Lime_Explainer()

def getlogits(datas):
    """Return the head's score for output column ``classid`` for a list of datapoints."""
    collated = collate(datas)
    # The head also runs under no_grad so no autograd graph is built per LIME sample.
    with torch.no_grad():
        with torch.amp.autocast(device_type='cuda',dtype=torch.float16):
            clipout = visualclip(collated,retpool=True).to(device)
        # The head stays outside autocast, as in the original pipeline.
        finalout = head(clipout)
    outval = finalout[:,classid]
    return outval

limeresult = explainer.explain_instance(datapoint,series_of_interest,getlogits,lime_steps)

# limeresult[1] holds (token index, importance) pairs; sort them back into token order.
results = limeresult[1]
results.sort()


# Key each token's importance by its position in the series of interest (from posmap).
tokenimportancedict = {series_of_interest: {}}
posmap = posmaps[datapoint['serienamestr'].index(series_of_interest)]
assert len(posmap) == len(results)
for j,tok in enumerate(results):
    jj,importance = tok
    assert jj == j
    pos = posmap[j]
    tokenimportancedict[series_of_interest][pos.item()] = float(importance)

# Close (and flush) the output file deterministically.
with open(outjsonname, 'w') as outfile:
    json.dump(tokenimportancedict, outfile, indent=2)



GOT TO HERE


100%|██████████| 1000/1000 [00:35<00:00, 27.88it/s]
