# EXPLAINABLE MACHINE LEARNING

Authors: Gabriel Schurr, Ilyesse Hettenbach

### IMPORTS

In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from dotenv import load_dotenv
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm

# Load variables from a local `.env` file into the process environment;
# the configuration cell below reads DATAPATH from it via os.getenv.
load_dotenv()


In [None]:
# Dataset root, read from the environment (typically populated from `.env`
# by load_dotenv). Fail fast when it is missing: wrapping os.getenv in str()
# would silently turn an unset variable into the literal path "None/...",
# which only surfaces later as a confusing "file not found" error.
_datapath_env = os.getenv("DATAPATH")
if not _datapath_env:
    raise RuntimeError(
        "The DATAPATH environment variable is not set; define it "
        "(e.g. in a .env file) to point at the dataset root directory."
    )

# Label file: one "<image_path> <label>" entry per line (parsed in the EDA cell).
P_LABELS = os.path.join(_datapath_env, "images_labels.txt")
# Image root passed to ImageFolder, i.e. one sub-directory per class.
DATAPATH = os.path.join(_datapath_env, "animals")
print(f'Path to images: {DATAPATH}')
print(f'Path to labels: {P_LABELS}')

# Device used throughout the notebook. Pinned to CPU; the commented line
# below is the automatic GPU selection alternative.
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = 'cpu'


### EDA

In [None]:
# Read the label file into a DataFrame with columns `image_path` and `label`.
# Every non-blank line must look like "<image_path> <label>". The label is split
# off at the LAST whitespace run, so image paths that contain spaces still parse.
# Blank lines (e.g. a trailing empty line) are skipped instead of crashing the
# two-value unpacking.
data = []
with open(P_LABELS, 'r', encoding='utf-8') as f:
    for line_no, raw_line in enumerate(f, start=1):
        line = raw_line.strip()
        if not line:
            continue
        parts = line.rsplit(maxsplit=1)
        if len(parts) != 2:
            raise ValueError(
                f"{P_LABELS}, line {line_no}: expected '<image_path> <label>', got {line!r}"
            )
        image_path, label = parts
        data.append({'image_path': image_path, 'label': label})
# Explicit columns keep the expected schema even when the file has no entries.
LABELS = pd.DataFrame(data, columns=['image_path', 'label'])
LABELS.head()


In [None]:
# Summary of the label table. Both columns hold strings parsed from the label
# file, so this reports count / unique / top / freq per column.
LABELS.describe()


In [None]:
# Show a 2x3 grid of randomly chosen images from the label file, each titled
# with its label. Rows are drawn without replacement (no duplicate panels), the
# grid tolerates fewer than 6 labelled images, and every file handle is closed
# once the pixels have been copied out.
fig, axes = plt.subplots(2, 3, figsize=(10, 6))
n_shown = min(axes.size, len(LABELS))
sample_rows = random.sample(range(len(LABELS)), n_shown)

for ax in axes.flat:
    ax.axis('off')

for ax, row_idx in zip(axes.flat, sample_rows):
    row = LABELS.iloc[row_idx]
    with Image.open(row['image_path']) as img:
        # Converting to RGB shows grayscale/palette images in their real
        # colours instead of through imshow's default colormap.
        ax.imshow(np.asarray(img.convert('RGB')))
    ax.set_title(row['label'])

plt.tight_layout()
plt.show()


In [None]:
# Visualize class distribution
plt.figure(figsize=(10, 6))
sns.countplot(x='label', data=LABELS)
plt.title('Class Distribution')
plt.xticks(rotation=60)
plt.show()


### MODEL

In [None]:
class CustomResNet18(nn.Module):
    """Torchvision ResNet-18 whose final ``fc`` layer is replaced by a new
    ``nn.Linear`` that outputs ``num_classes`` logits.

    Args:
        num_classes: Number of output logits; should equal the number of
            classes in the training data.
        weights: Weight specification forwarded to
            ``torchvision.models.resnet18``. Defaults to ``'IMAGENET1K_V1'``
            (as before); pass ``None`` for a randomly initialised backbone.
        freeze_backbone: If True, set ``requires_grad = False`` on all
            backbone parameters so that only the new head is trained. Note
            that gradient-based explanations (e.g. Grad-CAM) may then receive
            no gradients for backbone activations — verify before combining.

    The logits of the most recent forward pass are kept in ``self.output``.
    """

    def __init__(self, num_classes=90, weights='IMAGENET1K_V1', freeze_backbone=False):
        super().__init__()
        self.output = None
        self.resnet = models.resnet18(weights=weights)
        if freeze_backbone:
            for param in self.resnet.parameters():
                param.requires_grad = False
        # Swap the classification head. The new layer is created after the
        # optional freeze above, so its parameters always stay trainable.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return the class logits for ``x`` and cache them in ``self.output``."""
        self.output = self.resnet(x)
        return self.output

# Instantiate the classifier with its default settings and move it to the
# configured device; nn.Module.to returns the same module, so it chains.
model = CustomResNet18().to(DEVICE)


In [None]:
# Preprocessing: resize to 224x224 and normalise with the ImageNet channel
# mean/std that the pretrained ResNet-18 weights were trained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# ImageFolder assigns class indices from the sub-directory names of DATAPATH.
dataset = ImageFolder(root=DATAPATH, transform=transform)

# The model's head was built with a fixed number of outputs before the data was
# loaded; stop here on a mismatch instead of failing (or silently mis-training)
# inside the loss.
n_classes = len(dataset.classes)
if model.resnet.fc.out_features != n_classes:
    raise ValueError(
        f"Model outputs {model.resnet.fc.out_features} classes but the dataset "
        f"has {n_classes}; rebuild the model with num_classes={n_classes}."
    )

# 95 % / 5 % train/validation split with a fixed seed, so the held-out images
# are the same on every run (and across re-executions of this cell).
SPLIT_SEED = 42
n_total = len(dataset)
n_train = int(n_total * 0.95)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [n_train, n_total - n_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 2


### TRAINING

In [None]:
# Train for `num_epochs` epochs. The progress bar advances once per epoch; its
# postfix shows the epoch, the current batch and the running mean batch loss.
# The evaluation and Grad-CAM cells leave the model in eval mode, so switch back
# explicitly; otherwise a re-run would train with BatchNorm statistics frozen.
model.train()
n_batches = len(train_loader)
with tqdm(total=num_epochs, desc='Training') as pbar:
    for epoch in range(1, num_epochs + 1):
        running_loss = 0.0
        for batch_no, (inputs, labels) in enumerate(train_loader, start=1):
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix({
                'epoch': f'{epoch}/{num_epochs}',
                'batch': f'{batch_no}/{n_batches}',
                'loss': f'{running_loss / batch_no:.3f}',
            })

        pbar.update(1)


### EVALUATION

In [None]:
# Top-1 accuracy on the held-out validation split (no gradient tracking).
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        predicted = model(inputs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# `total` counts images; len(val_loader) would be the number of batches.
if total == 0:
    print('The validation set is empty; accuracy is undefined.')
else:
    print(f'Accuracy of the network on the {total} validation images: {100 * correct / total:.2f}%')


### GRAD-CAM

In [None]:
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Explain the model's prediction for one random validation image with Grad-CAM.

# Draw a position inside the validation Subset and map it back to the underlying
# ImageFolder sample. This guarantees the image comes from the held-out split and
# that its ground-truth class uses the same indices as the model's outputs.
val_pos = random.randrange(len(val_dataset))
image_path, target_idx = dataset.samples[val_dataset.indices[val_pos]]
target = dataset.classes[target_idx]

# Convert to RGB (as ImageFolder's default loader does) so the 3-channel
# Normalize in `transform` also accepts grayscale, RGBA or palette files.
with Image.open(image_path) as pil_image:
    image_tensor = transform(pil_image.convert('RGB')).unsqueeze(0).to(DEVICE)

# Eval mode: a train-mode forward would normalise with single-image BatchNorm
# statistics and overwrite the running averages learned during training.
model.eval()
with torch.no_grad():
    pred = model(image_tensor).argmax(dim=1).item()

# GradCAM expects a *list* of layers. Passing layer2 (an nn.Sequential) directly
# gets iterated, which hooks each of its child blocks separately instead.
# The context manager removes the hooks from the model afterwards.
target_layers = [model.resnet.layer2]
with GradCAM(model=model, target_layers=target_layers) as cam:
    grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(pred)])
grayscale_cam = grayscale_cam[0, :]

# Undo the ImageNet normalisation applied by `transform` to recover an RGB
# image in [0, 1]; show_cam_on_image rejects values above 1.
imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
image = image_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
image = np.clip(image * imagenet_std + imagenet_mean, 0.0, 1.0)

# Match the heatmap to the image size (cv2.resize takes (width, height)).
grayscale_cam = cv2.resize(grayscale_cam, (image.shape[1], image.shape[0]))

pred_name = dataset.classes[pred] if pred < len(dataset.classes) else str(pred)
visualization = show_cam_on_image(image, grayscale_cam, use_rgb=True, image_weight=0.5)
plt.imshow(visualization)
plt.title(f'True: {target} | Predicted: {pred_name}')
plt.axis('off')
plt.show()
