In [None]:
# %pip uninstall torch torchvision
# %pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
# %pip install lime
# %pip install scikit-image


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from skimage.segmentation import mark_boundaries
from skimage.segmentation import mark_boundaries, slic
from sklearn.linear_model import LinearRegression, Ridge

import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [None]:
# Seed every random number generator used in this notebook with one value.
seed = 31
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)

In [None]:
class DigitClassifier(nn.Module):
    """Two-layer fully connected classifier for flattened 28x28 images.

    Layout: Linear(784 -> 64) -> ReLU -> Linear(64 -> 10). ``forward`` returns
    the raw 10-way scores; no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys, so they stay as-is.
        self.fc1 = nn.Linear(in_features=28 * 28, out_features=64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Flatten ``x`` into rows of 784 values and return the class scores."""
        flat = x.view(-1, 28 * 28)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

# The only preprocessing step is the ToTensor conversion.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST train and test splits, both rooted at ./data with download enabled.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

In [None]:
def train_model(model, train_loader, criterion, optimizer, epochs=5):
    """Train ``model`` in place and print the loss and accuracy of each epoch.

    Args:
        model: Module called on every batch of images.
        train_loader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer stepping the model parameters.
        epochs: Number of passes over ``train_loader``.

    Returns:
        None; progress is reported with ``print``.
    """
    model.train()
    for epoch_idx in range(epochs):
        running_loss = 0.0
        n_correct = 0
        n_seen = 0

        for batch_images, batch_labels in train_loader:
            optimizer.zero_grad()
            scores = model(batch_images)
            batch_loss = criterion(scores, batch_labels)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

            # Accuracy is measured on the scores computed before this step's update.
            batch_pred = torch.max(scores, 1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (batch_pred == batch_labels).sum().item()

        # Loss is averaged over batches, accuracy over individual samples.
        mean_loss = running_loss / len(train_loader)
        accuracy_pct = 100 * n_correct / n_seen
        print(f'Epoch {epoch_idx + 1}, Loss: {mean_loss:.4f}, Accuracy: {accuracy_pct:.2f}%')

def select_images(model, test_loader, class_type = "correct", k = 5):
    """Evaluate ``model`` on ``test_loader`` and return a random sample of images.

    The model is put in eval mode, every batch is classified without gradient
    tracking, and the overall accuracy is printed.

    Args:
        model: Module called on batches flattened to ``(-1, 784)``.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        class_type: ``"correct"`` samples from the correctly classified images;
            any other value samples from the misclassified ones.
        k: Number of images requested.

    Returns:
        A list of ``(image, label)`` tuples for correct samples, or
        ``(image, label, predicted)`` tuples for incorrect samples. When fewer
        than ``k`` images are available in the requested category, all of them
        are returned (in random order) instead of raising ``ValueError``.
    """
    model.eval()
    correct_classified = []
    incorrect_classified = []

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.view(-1, 28*28))
            _, predicted = torch.max(outputs, 1)

            for image, label, guess in zip(images, labels, predicted):
                if guess == label:
                    correct_classified.append((image, label))
                else:
                    incorrect_classified.append((image, label, guess))

    total = len(correct_classified) + len(incorrect_classified)
    if total:
        accuracy = 100 * len(correct_classified) / total
        print(f'Accuracy: {accuracy:.2f}%')
    else:
        # An empty loader has no defined accuracy; avoid dividing by zero.
        print('Accuracy: n/a (no images evaluated)')

    pool = correct_classified if class_type == "correct" else incorrect_classified
    # random.sample raises ValueError when asked for more items than exist,
    # so the request is capped at the size of the chosen category.
    sample_size = min(k, len(pool))
    selected_indices = random.sample(range(len(pool)), sample_size)
    return [pool[i] for i in selected_indices]



def visualise_digit(image_label):
    """Show ``image_label[0]`` in grayscale with ``image_label[1]`` in the title.

    Only the first two entries of ``image_label`` are read, so longer tuples
    are accepted as well.
    """
    picture = image_label[0]
    label = image_label[1]

    plt.imshow(picture.squeeze(), cmap='gray')
    plt.title(f'Label: {label}')
    plt.axis('off')
    plt.show()


def segment_image(image_label, n_segments=10, compactness=10):
    """Segment the image in ``image_label[0]`` with SLIC.

    Args:
        image_label: Sequence whose first entry holds 784 pixel values.
        n_segments: Passed through to ``slic`` as ``n_segments``.
        compactness: Passed through to ``slic`` as ``compactness``.

    Returns:
        Whatever ``slic`` returns for the 28x28 image (the segment label map).
    """
    grayscale = np.array(image_label[0]).reshape(28, 28)
    # channel_axis=None marks the input as a single-channel image.
    return slic(
        grayscale,
        n_segments=n_segments,
        compactness=compactness,
        channel_axis=None,
    )

def visualize_segments(image_label, segments):
    """Display the image in ``image_label[0]`` with segment boundaries drawn on it.

    Args:
        image_label: Sequence whose first entry holds 784 pixel values and
            whose second entry is the label shown in the title.
        segments: 28x28 segment label map for the image.
    """
    image = np.array(image_label[0]).reshape(28, 28)
    segment_count = np.unique(segments).size

    plt.figure(figsize=(5, 5))
    # mark_boundaries returns an RGB array, for which imshow ignores any
    # colormap, so no cmap is passed.
    plt.imshow(mark_boundaries(image, segments))
    plt.title(f'Segmented Image of {image_label[1]} ({segment_count} segments)')
    plt.axis('off')
    plt.tight_layout()
    plt.show()

def generate_perturbations(image_label, segments, num_samples=1000, keep_prob=0.7):
    """Create perturbed copies of an image by blanking random segments.

    For each sample, every segment is independently kept with probability
    ``keep_prob``; the pixels of segments that are not kept are set to 0.

    Mask columns are mapped to the label values actually present in
    ``segments`` (via ``np.unique``). Comparing the column index directly with
    the label value is wrong whenever labels do not start at 0: the first
    column matches nothing and the highest label is never switched off.

    Args:
        image_label: Sequence whose first entry holds 784 pixel values and
            whose second entry is the label attached to every output sample.
        segments: 28x28 integer segment label map.
        num_samples: Number of perturbed images to generate.
        keep_prob: Probability that a segment is left untouched in a sample.

    Returns:
        List of ``(perturbed_image, label)`` tuples, each image a 28x28 array.
    """
    segment_labels = np.unique(segments)
    keep_masks = np.random.binomial(1, keep_prob, size=(num_samples, segment_labels.size))
    image = np.array(image_label[0]).reshape(28, 28)
    label = image_label[1]

    perturbed_images = []
    for keep_mask in keep_masks:
        # Pixels belonging to any segment drawn as "off" in this sample.
        off_pixels = np.isin(segments, segment_labels[keep_mask == 0])
        perturbed_image = image.copy()
        perturbed_image[off_pixels] = 0
        perturbed_images.append((perturbed_image, label))
    return perturbed_images



def predict_digits(model, image_labels):
    """Classify each image in ``image_labels`` and return the predicted classes.

    The model is switched to eval mode and called once per image, with gradient
    tracking disabled.

    Args:
        model: Module called on ``(1, 784)`` float32 tensors.
        image_labels: Iterable of ``(image, true_label)`` pairs; the label is
            not used.

    Returns:
        List of predicted class indices (Python ints), one per input pair.
    """
    model.eval()

    def _classify(pixels):
        # Each image becomes a single-row float32 batch for the model.
        row = torch.tensor(pixels, dtype=torch.float32).view(1, 28*28)
        return torch.argmax(model(row), dim=1).item()

    with torch.no_grad():
        return [_classify(image) for image, _true_label in image_labels]

def fit_linear_model(image_label, perturbations, predictions, kernel_width=2):
    """Fit a distance-weighted ridge regression on the perturbed images.

    Every perturbed image is flattened into one feature row. Each row is
    weighted by ``exp(-d**2 / kernel_width**2)``, where ``d`` is its Euclidean
    distance to the flattened original image.

    NOTE(review): ``predictions`` is passed to the regressor unchanged. If the
    caller supplies class indices, the surrogate is fitted on the class id as
    a numeric value rather than on a class probability - confirm that this is
    the intended target.

    Args:
        image_label: Sequence whose first entry is the original image.
        perturbations: Sequence of tuples whose first entry is a perturbed image.
        predictions: Regression targets, one per perturbed image.
        kernel_width: Width of the exponential weighting kernel.

    Returns:
        The ``coef_`` array of the fitted ``Ridge(alpha=0.1)`` model.
    """
    # One flattened row per perturbed sample.
    samples = np.array([sample[0] for sample in perturbations])
    samples = samples.reshape(samples.shape[0], -1)

    reference = np.array(image_label[0]).reshape(1, -1)

    # Samples closer to the original image receive larger weights.
    offsets = samples - reference
    distances = np.sqrt(np.sum(offsets ** 2, axis=1))
    weights = np.exp(-(distances ** 2) / kernel_width ** 2)

    surrogate = Ridge(alpha=0.1)
    surrogate.fit(samples, predictions, sample_weight=weights)
    return surrogate.coef_

def visualise_coefficients(original_image, coefficients, figsize=(12, 10)):
    """Plot per-pixel coefficients alongside the original image.

    Two figures are produced:

    1. A three-panel figure: signed coefficients, absolute coefficients, and
       a green/red overlay of the coefficients on the dimmed original image.
    2. An 8x8 figure with coefficient contours drawn over the original image.

    Args:
        original_image: Image holding 784 values; reshaped to 28x28.
        coefficients: 784 coefficient values; reshaped to 28x28.
        figsize: Size of the three-panel figure.
    """
    orig_img = np.asarray(original_image).reshape(28, 28)
    coef_img = np.asarray(coefficients).reshape(28, 28)
    max_abs_coef = np.max(np.abs(coef_img))

    plt.figure(figsize=figsize)

    # Panel 1: signed coefficients. A diverging red-to-blue map with limits
    # symmetric about zero makes the "Blue=Positive, Red=Negative" legend in
    # the title true; a sequential map such as magma has no such meaning.
    plt.subplot(1, 3, 1)
    colour_limit = max_abs_coef + 1e-10  # avoids a zero-width range when all coefficients are 0
    coef_plot = plt.imshow(coef_img, cmap='RdBu', vmin=-colour_limit, vmax=colour_limit)
    plt.title('Coefficient Values\n(Blue=Positive, Red=Negative)')
    plt.colorbar(coef_plot)
    plt.axis('off')

    # Panel 2: magnitude only.
    plt.subplot(1, 3, 2)
    abs_plot = plt.imshow(np.abs(coef_img), cmap='viridis')
    plt.title('Absolute Importance')
    plt.colorbar(abs_plot)
    plt.axis('off')

    # Panel 3: overlay. The image is dimmed to half intensity on all three
    # channels, then positive coefficients brighten the green channel and
    # negative coefficients brighten the red channel.
    plt.subplot(1, 3, 3)
    norm_coef = coef_img / (max_abs_coef + 1e-10)

    rgb_img = np.zeros((28, 28, 3))
    rgb_img[:] = (orig_img / 2.0)[:, :, np.newaxis]

    green_mask = norm_coef > 0
    rgb_img[:, :, 1][green_mask] += norm_coef[green_mask] * 0.5
    red_mask = norm_coef < 0
    rgb_img[:, :, 0][red_mask] += -norm_coef[red_mask] * 0.5
    rgb_img = np.clip(rgb_img, 0, 1)
    plt.imshow(rgb_img)
    plt.title('Coefficient Overlay\n(Green=Positive, Red=Negative)')
    plt.axis('off')

    # Lay out the three-panel figure before the next figure becomes current.
    plt.tight_layout()

    # Second figure: coefficient contours over the original image.
    plt.figure(figsize=(8, 8))
    plt.imshow(orig_img, cmap='gray')
    contour = plt.contour(coef_img, levels=10, colors='r', alpha=0.8)
    plt.clabel(contour, inline=True, fontsize=8)
    plt.title('Coefficient Contours on Original Image')
    plt.axis('off')
    plt.tight_layout()
    plt.show()
    



# Feed Forward NN for MNIST

In [None]:

# Classifier, cross-entropy loss, and an Adam optimizer over the classifier's parameters.
model = DigitClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


In [None]:
# Five training epochs over the training loader.
train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=5,
)


In [None]:
# Select correctly classified images from the test set.
correctly_classified_images = select_images(model=model, test_loader=test_loader)


# LIME

In [None]:
def explain_with_lime(selected_images = [], n_segments = 10, kernel_width = 10):
    """Run the segment / perturb / predict / fit / plot pipeline on each image.

    For every entry of ``selected_images`` the digit is shown, segmented,
    perturbed, classified with the module-level ``model``, and the fitted
    coefficients are plotted. The first four perturbed samples are displayed
    as a preview.

    Args:
        selected_images: Iterable of image/label tuples to explain.
        n_segments: Forwarded to ``segment_image``.
        kernel_width: Forwarded to ``fit_linear_model``.
    """
    for position, image_label in enumerate(selected_images, start=1):
        print(f'{position}. Selected digit')
        visualise_digit(image_label)

        segments = segment_image(image_label, n_segments=n_segments)
        visualize_segments(image_label, segments)

        print("Generating perturbations : ")
        perturbed_images = generate_perturbations(image_label, segments)

        print("Perturbed Images are : ")
        # Preview only the first four perturbed samples.
        preview = perturbed_images[:4]
        for preview_idx in range(len(preview)):
            print(f'{preview_idx}')
            visualise_digit(preview[preview_idx])

        # Labels predicted for all perturbed samples.
        predicted_digits = predict_digits(model, perturbed_images)

        # Weighted linear surrogate fitted on the perturbed samples.
        coefficients = fit_linear_model(
            image_label,
            perturbed_images,
            predicted_digits,
            kernel_width=kernel_width,
        )

        visualise_coefficients(image_label[0], coefficients)





## Experimenting with kernel widths

In [None]:
# Kernel width 0.1 with 10 requested segments.
explain_with_lime(
    selected_images=correctly_classified_images,
    n_segments=10,
    kernel_width=0.1,
)


In [None]:
# Kernel width 1 with 10 requested segments.
explain_with_lime(
    selected_images=correctly_classified_images,
    n_segments=10,
    kernel_width=1,
)


In [None]:
# Kernel width 10 with 10 requested segments.
explain_with_lime(
    selected_images=correctly_classified_images,
    n_segments=10,
    kernel_width=10,
)


In [None]:
# Kernel width 100 with 10 requested segments.
explain_with_lime(
    selected_images=correctly_classified_images,
    n_segments=10,
    kernel_width=100,
)


## Experimenting with different segment numbers

In [None]:
# 4 requested segments with kernel width 10.
explain_with_lime(
    selected_images=correctly_classified_images,
    n_segments=4,
    kernel_width=10,
)


In [None]:
# 10 requested segments with kernel width 10.
explain_with_lime(
    selected_images=correctly_classified_images,
    n_segments=10,
    kernel_width=10,
)


In [None]:
# 30 requested segments with kernel width 10.
explain_with_lime(
    selected_images=correctly_classified_images,
    n_segments=30,
    kernel_width=10,
)
