In [1]:
import numpy as np
import pandas as pd
import subprocess
import os
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.io import loadmat
import re

In [2]:
test_path = "temp_test2.txt"
test_text_path = "temp_text_test2.txt"
test_pred_path = "annexml-result-example(2)2.txt"
perturbed_test_path = "perturbed_test.txt"
perturbed_interpret_path = "perturbed_interpretable.txt"
perturbed_pred_path = "annexml-result-example2.txt"

explainrows=2
samples_per_row = 5000
pick_probab = 0.6

In [3]:
import xml.etree.ElementTree as ET
import chardet

def get_desc(file_path):
    description=""
    try:
        # Detect file encoding
        with open(file_path, "rb") as file:
            raw_data = file.read(10000)  # Read a portion of the file
            result = chardet.detect(raw_data)
            encoding = result["encoding"] if result["encoding"] else "utf-8"

        # Read file with detected encoding
        with open(file_path, "r", encoding=encoding, errors="replace") as file:
            xml_content = file.read()

        # Parse XML
        root = ET.fromstring(xml_content)
        description = root.find("DESCRIPTION").text

        return re.sub(r'[^A-Za-z\s]', '', description)

    except ET.ParseError:
        print(f"XML Parsing Error in {file_path}")
        return ""
    except UnicodeDecodeError:
        print(f"Encoding issue in {file_path}")
        return ""
    except Exception as e:
        print(f"Unexpected error in {file_path}: {e}")
        return ""

In [4]:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("all-MiniLM-L6-v2")
def get_sentence_embedding(text):
    """
    Generate sentence embeddings for a given text.

    :param text: str, input sentence or text
    :return: list, sentence embedding as a list of float values
    """
    embedding = model.encode(text)
    return embedding.tolist()  # Convert to list for easy handling




In [5]:
def one_hot_to_indices(one_hot_labels):
    """Convert one-hot encoding to label indices."""
    return [list(np.where(row == 1)[0]) for row in one_hot_labels]
    
def get_test_data(count=5):
    test_label_data = loadmat("IAPRTC/IAPRTC-12_TestLabels.mat")
    I_z_te=test_label_data['I_z_te']
    label_test_column_1 = I_z_te[0,0]
    test_labels = one_hot_to_indices(label_test_column_1)
    
    labels=[]
    texts = []
    with open("IAPRTC/iapr_test_list.txt") as file1:
        row=0
        for line in file1:
            path = "iaprtc12/annotations_complete_eng/"+line[:-1]+".eng"
            description = get_desc(path)
            if not description=="":
                # pass
                labels.append(test_labels[row])
                texts.append(description)
            row+=1
            if row>=count:
                break
    return labels,texts

In [6]:
def get_test_files(text_path,test_path, labels, texts):
    with open(text_path,"w") as file1, open(test_path,"w") as file2:
        for row in range(len(labels)):
            description=texts[row]
            text = ",".join(map(str, labels[row])) + " "
            text += description
            file1.write(text+"\n")

            embedding=get_sentence_embedding(description)
            text = ",".join(map(str, labels[row])) + " "
            text += " ".join(f"{i+1}:{100*value:.6f}" for i, value in enumerate(embedding))
            file2.write(text+"\n")

In [7]:
def predict(test_path, result_file):
    import subprocess
    direc = os.path.abspath("AnnexML")
    command = [
        os.path.abspath(os.path.join(direc, "src", "annexml")),  # Absolute path to annexml.exe
        "predict",
        "annexml-example.json",
        f"predict_file={test_path}",
        f"result_file={result_file}"
    ]
    try:
        result = subprocess.run(command, check=True, capture_output=True, text=True, cwd=direc)
        print("Standard Output:\n", result.stdout)
        if result.stderr:
            print("Standard Error:\n", result.stderr)
    except subprocess.CalledProcessError as e:
        pass
    print("Done.")

In [8]:
labels, texts = get_test_data(count=explainrows)
get_test_files(test_text_path, test_path, labels, texts)

In [9]:
def generate_perturbed_files(test_text_path, perturbed_test_path, perturbed_interpret_path):
    with open(test_text_path) as file1, open(perturbed_interpret_path,"w") as file2, open(perturbed_test_path,"w") as file3:
        for line in file1:
            parts = line.split()
            labels = parts[0].split(",")
            words = parts[1:]
            for i in range(samples_per_row):
                pert = (np.random.rand(len(words)) < pick_probab).astype(int)
                text = " ".join(map(str, pert))
                file2.write(text+"\n")
                
                text = " ".join([word for word, val in zip(words, pert) if val == 1])
                embedding=get_sentence_embedding(text)
                text = ",".join(map(str, labels)) + " "
                text += " ".join(f"{i+1}:{100*value:.6f}" for i, value in enumerate(embedding))
                file3.write(text+"\n")

In [10]:
generate_perturbed_files(test_text_path, perturbed_test_path, perturbed_interpret_path)

In [11]:
predict(test_path="../"+test_path, result_file="../"+test_pred_path)
predict(test_path="../"+perturbed_test_path, result_file="../"+perturbed_pred_path)

Done.
Done.


In [12]:
classes = []
with open("IAPRTC/iaprtc12_dictionary.txt") as file:
    for line in file:
        classes.append(line[:-1])
for i in range(len(classes)):  
    print(i, classes[i])

0 adult
1 airplane
2 airport
3 anorak
4 area
5 back
6 backpack
7 bag
8 balcony
9 bank
10 bar
11 base
12 bay
13 beach
14 bed
15 bedcover
16 bedside
17 bell
18 bench
19 bicycle
20 bike
21 bird
22 bit
23 blanket
24 bloom
25 board
26 boat
27 body
28 bone
29 bottle
30 boy
31 branch
32 brick
33 bridge
34 building
35 bus
36 bush
37 cactus
38 camera
39 canyon
40 cap
41 cape
42 car
43 carpet
44 cathedral
45 ceiling
46 centre
47 chair
48 child
49 church
50 city
51 classroom
52 cliff
53 clock
54 cloth
55 clothes
56 cloud
57 coast
58 cobblestone
59 column
60 condor
61 corner
62 corridor
63 couch
64 country
65 couple
66 court
67 courtyard
68 creek
69 cross
70 cup
71 curtain
72 cycling
73 cyclist
74 deck
75 desert
76 desk
77 dirt
78 dog
79 dome
80 door
81 dress
82 dune
83 edge
84 embankment
85 entrance
86 face
87 fence
88 fern
89 field
90 fjord
91 flag
92 flagpole
93 floor
94 flower
95 fog
96 footpath
97 forest
98 formation
99 fountain
100 frame
101 front
102 fruit
103 garden
104 gate
105 giant
106 

In [13]:
class LinearRegression_KLasso:
    def __init__(self,K=-1):
        self.loss = 0
        self.w = np.zeros(10)
        self.w0 = 0
        self.K = K
        
    def fit(self,X,Y,alpha = 2e-3,epsilon = 1e-3, sigma=4):
        if self.K==-1:
            self.K = int(X.shape[1] / 4)
        self.K = int(min(self.K,X.shape[1]))
        iters=0
        self.w = np.zeros(X.shape[1])
        distances = self.get_distance(X,sigma)
        newloss = self.MSE(X,Y,distances)
        while abs(self.loss-newloss)>epsilon:
            self.loss = newloss
            diff = np.dot(X*distances,self.w)+self.w0-Y
            self.w0 -= 2*alpha*np.mean(diff)
            self.w -= 2*alpha*(diff@X)/(X.shape[0])
            
            top_k_indices = np.argsort(np.abs(self.w))[-self.K:]
            new_w = np.zeros_like(self.w)
            new_w[top_k_indices] = self.w[top_k_indices]
            self.w = new_w
            
            newloss = self.MSE(X,Y,distances)
            iters+=1
            if iters==10000:
                break
        print("Iterations:", iters, ", Loss:", self.MSE(X,Y,distances))
            
    def get_distance(self, X, sigma):
        distances = (X.shape[1]-np.sum(X,axis=1))/(sigma*sigma)
        total = sum([np.e**((-1)*i/sigma**2) for i in range(X.shape[1])])
        # distances = (np.e**(distances))/total
        distances = np.e**(distances*(-1))
        return distances[:, np.newaxis]

    def MSE(self,X,Y,distances):
        return np.mean((np.dot(X*distances,self.w)+self.w0 - Y)**2)
    
    def predict(self,X):
        return np.dot(X,self.w)+self.w0

In [14]:
from IPython.display import display, HTML

def generate_colored_text(words, word_scores):
    word_scores = dict(word_scores)  # Convert word_scores to dictionary
    colored_words = []

    for word in words:
        score = word_scores.get(word, 0)  # Default to 0 if word not in dictionary
        blue_opacity = max(0, min(score, 1))  # Ensure opacity is within [0,1]
        red_opacity = max(0,min(score*(-1),1))
        color = f"rgba(0, 0, 255, {blue_opacity})"
        if score<0:
            color = f"rgba(255, 0, 0, {red_opacity})"
        colored_words.append(f'<span style="background-color: {color}; padding: 2px; border-radius: 3px; color: black;">{word}</span>')

    return " ".join(colored_words)

In [15]:
from itertools import islice

def LIME(test_pred_path, test_text_path, perturbed_pred_path, perturbed_interpret_path):
    model = LinearRegression_KLasso(10)
    with open(test_pred_path) as file1, open(test_text_path) as file2, open(perturbed_interpret_path) as file3, open(perturbed_pred_path) as file4:
        row=1
        for line1,line2 in zip(file1,file2):
            print(f"Explaining for row {row}")
            row+=1
            
            parts = line1.split('\t')
            predictions = list(map(int, parts[0].split(',')))
            parts = line2.split()
            words = parts[1:]
            
            ## read 10,000 lines of file3
            x = list(islice(file3, samples_per_row))
            x = np.array([list(map(int, line3.split())) for line3 in x])
            chunk = list(islice(file4, samples_per_row))
            for label in predictions:
                print(f"Explaining for label {classes[label]}:")
            
                # Optimized parsing using list comprehension
                y = np.array([
                    float(pred_labels.get(label, 0))
                    for pred_labels in (
                        {int(k): float(v) for k, v in (item.split(":") for item in line4.split("\t")[1].split(','))}
                        for line4 in chunk
                    )
                ])
                # print(y[:100])
                model.fit(x,y)
                bias,weights = model.w0,model.w
                attributions = dict(zip(words, weights))
                attributions = sorted(attributions.items(), key=lambda item: item[1], reverse=True)
                print(attributions)

                weights_normalized = weights.copy()
                weights_normalized /= np.max(np.abs(weights_normalized))
                attributions = dict(zip(words, weights_normalized))
                attributions = sorted(attributions.items(), key=lambda item: item[1], reverse=True)
                html_output = generate_colored_text(words, attributions)
                display(HTML(f"<p>{html_output}</p>"))

                print()
            print()

In [None]:
LIME(test_pred_path, test_text_path, perturbed_pred_path, perturbed_interpret_path)

Explaining for row 1
Explaining for label building:
Iterations: 6971 , Loss: 588.8107983980458
[('building', 162.21209014079517), ('palm', 45.18570650551939), ('cars', 29.417191855839878), ('large', 21.690020316473497), ('white', 17.09413791788946), ('red', 6.078123693261277), ('umbrellas', 6.032576838345477), ('turning', 5.275970546504185), ('a', 0.0), ('on', 0.0), ('the', 0.0), ('left', 0.0), ('tree', 0.0), ('in', 0.0), ('centre', 0.0), ('of', 0.0), ('picture', 0.0), ('mostly', 0.0), ('street', 0.0), ('at', 0.0), ('junction', 0.0), ('some', 0.0), ('them', 0.0), ('others', 0.0), ('going', 0.0), ('straight', 0.0), ('there', 0.0), ('are', 0.0), ('park', 0.0), ('right', 0.0), ('people', 0.0), ('walking', 0.0), ('through', 0.0), ('crossing', 0.0), ('road', 0.0), ('foreground', 0.0)]



Explaining for label car:
Iterations: 10000 , Loss: 436.446887001983
[('cars', 167.1238488834615), ('street', 51.744241704091145), ('building', 23.10720841977844), ('turning', 12.819273364942125), ('road', 6.623735828010408), ('a', 0.0), ('large', 0.0), ('on', 0.0), ('the', 0.0), ('left', 0.0), ('palm', 0.0), ('in', 0.0), ('centre', 0.0), ('of', 0.0), ('picture', 0.0), ('white', 0.0), ('at', 0.0), ('junction', 0.0), ('some', 0.0), ('them', 0.0), ('others', 0.0), ('going', 0.0), ('straight', 0.0), ('there', 0.0), ('are', 0.0), ('red', 0.0), ('umbrellas', 0.0), ('park', 0.0), ('right', 0.0), ('people', 0.0), ('walking', 0.0), ('through', 0.0), ('crossing', 0.0), ('foreground', 0.0), ('mostly', -36.02420039457894), ('tree', -87.7201199327583)]



Explaining for label centre:
Iterations: 2148 , Loss: 17.623375381957644
[('on', 0.0), ('the', 0.0), ('left', 0.0), ('palm', 0.0), ('in', 0.0), ('centre', 0.0), ('of', 0.0), ('picture', 0.0), ('mostly', 0.0), ('white', 0.0), ('cars', 0.0), ('street', 0.0), ('at', 0.0), ('them', 0.0), ('turning', 0.0), ('others', 0.0), ('going', 0.0), ('straight', 0.0), ('there', 0.0), ('are', 0.0), ('umbrellas', 0.0), ('park', 0.0), ('right', 0.0), ('walking', 0.0), ('through', 0.0), ('crossing', 0.0), ('road', 0.0), ('foreground', 0.0), ('junction', -2.8953953629710836), ('some', -2.901855428128973), ('a', -2.9241915411749155), ('large', -3.008352185514116), ('red', -3.230067225143334), ('building', -3.954932356185402), ('people', -4.576356988386527), ('tree', -4.96933076495661)]



Explaining for label palm:
