# Welcome to AI in Medicine Workshop 2 - Image Processing with FastAI

FastAI is not just a Python library; it is also a non-profit organisation for developers and deep learning practitioners.

They have a great free course that is an excellent way to start your deep learning journey.

Useful links:

- [FastAI Documentation](https://docs.fast.ai)
- [Practical Deep Learning for Coders Course](https://www.youtube.com/watch?v=8SF_h3xF3cE&list=PLfYUBJiXbdtSvpQjSnJJ_PmDQB_VyT5iU)
- [Deep Learning for Coders with fastai and PyTorch (book)](https://www.amazon.co.uk/dp/1492045527)

### Workshop Lead: Aria Torkpour

- [Email](mailto:at1120@ic.ac.uk)

- [LinkedIn](https://www.linkedin.com/in/aria-torkpour)

- [Github](https://github.com/ArT0r0)

#### [Imperial College Medics' Coding Society](https://www.imperialcollegeunion.org/activities/a-to-z/coding-icsm) — Beyond the Stethoscope: AI in Medicine Conference

In [None]:
# Install/upgrade fastbook quietly (-U: upgrade, -qq: very quiet output).
# `!` is an IPython shell escape: this line only works inside Jupyter/Colab.
!pip install -Uqq fastbook

## Data collection

In [None]:
from fastbook import *

In [None]:
# Search for image URLs matching 'jaundice eye' (max_images=100 caps the request).
ims = search_images_ddg('jaundice eye', max_images=100)
# Displayed by Jupyter: number of URLs returned and the first URL.
# ims[0] raises IndexError if the search came back empty.
len(ims),ims[0]

In [None]:
# Save the first search hit as the local sample image 'jaundice.jpg', which later
# cells use for predictions. Skipped if the file exists, so re-running is cheap.
destination = Path('jaundice.jpg')
if not destination.exists():
    download_url(ims[0], destination)

In [None]:
# Open the downloaded sample and display it as a small thumbnail
# (the final expression of the cell is rendered by Jupyter).
im = Image.open(destination)
im.to_thumb(255,255)

In [None]:
# Build a two-class image dataset: one sub-folder of `path` per search term, so the
# folder name later serves as the label (parent_label).
searches = 'jaundice eye', 'normal eye'
path = Path('data')

# Start from an empty folder every run so images from earlier runs don't linger.
shutil.rmtree(path, ignore_errors=True)

# Always true right after the reset above; it keeps the download cached if the
# rmtree line is ever removed.
if not path.exists():
    for term in searches:
        destination = path/term
        destination.mkdir(exist_ok=True, parents=True)
        # One search per term, queried with the plain search string; the URLs
        # are reused for the download rather than searching a second time.
        results = search_images_ddg(term)
        download_images(destination, urls=results)
        # Shrink in place so no image exceeds 400 px on its longest side.
        resize_images(destination, max_size=400, dest=destination)

In [None]:
# Check every downloaded file and delete the ones verify_images reports as failed
# (e.g. truncated or non-image downloads).
failed = verify_images(get_image_files(path))
failed.map(Path.unlink)
# Displayed by Jupyter: how many files were removed.
len(failed)

In [None]:
# DataLoaders for the two-class eye dataset under `path`.
dls = DataBlock(
    blocks=(ImageBlock, CategoryBlock),                   # input: image, target: one category
    get_items=get_image_files,                            # every image file under `path`
    splitter=RandomSplitter(valid_pct=0.2, seed=42),      # reproducible 80/20 train/valid split
    get_y=parent_label,                                   # label = name of the containing folder
    item_tfms=Resize(192, method=ResizeMethod.Squish),    # stretch to 192x192, aspect ratio not kept
).dataloaders(path)

# Preview a batch of labelled images.
dls.show_batch(max_n=10)

## Model Training

In [None]:
# Image classifier on a ResNet-34 backbone; error_rate is reported on the
# validation split after each epoch.
learn = vision_learner(dls, resnet34, metrics=error_rate)
# Fine-tune for 9 epochs (see fastai `Learner.fine_tune` for its frozen-then-unfrozen schedule).
learn.fine_tune(9)

## Model Evaluation

In [None]:
# Class labels in index order; the probability vector from learn.predict follows this order.
print(dls.vocab)

In [None]:
# Classify the sample image; learn.predict returns (label, class index, probabilities).
is_jaundiced_eye, _, probs = learn.predict(PILImage.create('jaundice.jpg'))
# Find the jaundice class by name instead of assuming it sits at index 0 of the vocab.
jaundice_idx = list(learn.dls.vocab).index('jaundice eye')

print(f"This is a: {is_jaundiced_eye}.")
print(f"Probability it's a jaundice eye: {probs[jaundice_idx]:.4f}")

In [None]:
# Collect predictions/losses for the current learner (validation set by default)
# and plot the confusion matrix of actual vs predicted classes.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

In [None]:
# Show the 5 images with the highest loss: the most confidently wrong or least certain ones.
interp.plot_top_losses(5, nrows=2, figsize=(20,4))
# Each image title reads: prediction/actual/loss/probability

In [None]:
# Install/upgrade fastai and import the vision, fp16 (mixed-precision) and widget APIs.
# NOTE(review): fastai is already in use above (via `from fastbook import *`); an upgrade
# made mid-session presumably only takes effect after a kernel restart — confirm.
# NOTE(review): nothing in this notebook appears to use the fp16 callbacks.
!pip install -Uqq fastai
from fastai.vision.all import *
from fastai.callback.fp16 import *
# Provides ImageClassifierCleaner, used in the Data Cleaning cells.
from fastai.vision.widgets import *
import torch
# Report the installed PyTorch version.
print(torch.__version__)

## Data Cleaning

In [None]:
# Interactive widget for reviewing the dataset with the trained learner and marking
# images to delete or relabel. The bare `cleaner` expression renders the widget.
cleaner = ImageClassifierCleaner(learn)
cleaner

In [None]:
# Apply the choices made in the cleaner widget: delete files marked for removal,
# then move relabelled files into the folder of their new category (folder = label).
# NOTE(review): shutil.move likely fails if a file with the same name already
# exists in the target folder — confirm if relabelling errors appear.
for idx in cleaner.delete(): cleaner.fns[idx].unlink()
for idx,cat in cleaner.change(): shutil.move(str(cleaner.fns[idx]), path/cat)

In [None]:
# Rebuild the DataLoaders from the cleaned folders, this time padding instead of squishing.
dls = DataBlock(
    blocks=(ImageBlock, CategoryBlock),                    # input: image, target: one category
    get_items=get_image_files,                             # every image file under `path`
    splitter=RandomSplitter(valid_pct=0.2, seed=42),       # same reproducible 80/20 split
    get_y=parent_label,                                    # label = name of the containing folder
    item_tfms=Resize(192, method=ResizeMethod.Pad, pad_mode=PadMode.Zeros),  # 192x192, black borders
).dataloaders(path)

# Preview a batch of labelled images.
dls.show_batch(max_n=10)

## Model Retraining and Re-evaluation

In [None]:
# A fresh ResNet-34 learner trained on the cleaned, zero-padded DataLoaders
# (same recipe as before, so the results are directly comparable).
learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(9)

In [None]:
# Classify the sample image with the retrained model.
is_jaundiced_eye, _, probs = learn.predict(PILImage.create('jaundice.jpg'))
# Find the jaundice class by name instead of assuming it sits at index 0 of the vocab.
jaundice_idx = list(learn.dls.vocab).index('jaundice eye')

print(f"This is a: {is_jaundiced_eye}.")
print(f"Probability it's a jaundice eye: {probs[jaundice_idx]:.4f}")

In [None]:
# Confusion matrix for the retrained model, to compare with the one before cleaning.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

### [Useful link: Comparing different pre-trained models for your backbone](https://www.kaggle.com/code/jhoward/which-image-models-are-best/code)

In [None]:
# Install timm so its architectures can be used as backbones by name (see below).
# `!` is an IPython shell escape: this line only works inside Jupyter/Colab.
!pip install timm
import timm

In [None]:
path = Path('data')
# from_folder labels each image with its parent folder name by default, so no
# labelling function is passed (from_folder has no `label_func` parameter).
# seed=42 makes the random 80/20 split reproducible, matching the earlier DataBlocks.
dls = ImageDataLoaders.from_folder(path, valid_pct=0.2, seed=42,
                                   item_tfms=Resize(224))
dls.show_batch(max_n=10)

In [None]:
# Names of every architecture timm provides; any of these strings can be passed to
# vision_learner (as 'convnext_tiny' is below). The list is long — a wildcard filter
# such as timm.list_models('convnext*') narrows it.
timm.list_models()

## Using a different Backbone

In [None]:
# Same training recipe, but with timm's 'convnext_tiny' backbone referenced by name.
learn = vision_learner(dls, 'convnext_tiny', metrics=error_rate)
learn.fine_tune(9)

In [None]:
# Confusion matrix for the ConvNeXt model, to compare against the ResNet-34 runs.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

In [None]:
# Classify the sample image with the ConvNeXt model.
is_jaundiced_eye, _, probs = learn.predict(PILImage.create('jaundice.jpg'))
# Find the jaundice class by name instead of assuming it sits at index 0 of the vocab.
jaundice_idx = list(learn.dls.vocab).index('jaundice eye')

print(f"This is a: {is_jaundiced_eye}.")
print(f"Probability it's a jaundice eye: {probs[jaundice_idx]:.4f}")

## What is the model looking at?

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from fastai.vision.all import first, PILImage

def saliency_map(learn, img_path, target_label='jaundice eye'):
    """
    Generate a gradient saliency map for a single image.

    The image is run through ``learn.dls.test_dl`` so it receives the same
    transforms as inference data. The gradient of the model's raw output
    (logit) for ``target_label`` is then taken with respect to the input
    pixels. The saliency map is the absolute gradient, averaged over colour
    channels and rescaled to [0, 1].

    Args:
        learn: trained fastai image-classification Learner.
        img_path: image path (anything ``PILImage.create`` accepts).
        target_label: vocab entry whose score is explained
            (default 'jaundice eye').

    Returns:
        (img, prediction, probability, saliency):
        - img: the loaded PILImage.
        - prediction: the label predicted by ``learn.predict``.
        - probability: the predicted probability of ``target_label``.
        - saliency: the 2-D saliency map as a numpy array.

    Raises:
        ValueError: if ``target_label`` is not in ``learn.dls.vocab``.
    """
    # Resolve the class by name; the vocab order decides which output column it is.
    vocab = list(learn.dls.vocab)
    if target_label not in vocab:
        raise ValueError(f"{target_label!r} is not one of the model's classes: {vocab}")
    target_idx = vocab.index(target_label)

    img = PILImage.create(img_path)
    x, = first(learn.dls.test_dl([img]))

    print(f"Initial tensor shape: {x.shape}")

    # input shape should be (batch_size, channels, height, width)
    if len(x.shape) > 4:
        x = x.reshape(-1, *x.shape[-3:])
    print(f"Reshaped tensor shape: {x.shape}")

    # Fresh leaf tensor on the model's device so that x.grad gets populated.
    device = next(learn.model.parameters()).device
    x = x.detach().to(device).requires_grad_()

    learn.model.eval()
    with torch.enable_grad():
        output = learn.model(x)
        # Explain the target class's raw score (logit, before softmax).
        class_score = output[0, target_idx]
        class_score.backward()

    # |d score / d pixel|, averaged over channels, for the single image in the batch.
    saliency = x.grad.abs().mean(dim=1)[0]
    saliency = saliency.detach().cpu().numpy()
    saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)

    # backward() also accumulated gradients on the model's parameters; clear them
    # so they don't linger on the model (e.g. into a later training step).
    learn.model.zero_grad()

    pred_label, _, probs = learn.predict(img)

    return img, pred_label, probs[target_idx], saliency

def plot_saliency_results(img, prediction, probability, saliency):
    """Draw an image and its saliency map side by side.

    Args:
        img: the input image, shown unchanged in the left panel.
        prediction: predicted label, shown in the left panel's title.
        probability: probability shown (4 decimal places) in the left panel's title.
        saliency: 2-D map drawn with the 'hot' colormap in the right panel.

    Returns None; the figure is rendered with ``plt.show()``.
    """
    _, axes = plt.subplots(1, 2, figsize=(12, 4))

    image_title = f'Original Image\nPrediction: {prediction}\nProbability: {probability:.4f}'
    panels = [
        (img, {}, image_title),
        (saliency, {'cmap': 'hot'}, 'Saliency Map'),
    ]

    for ax, (data, imshow_kwargs, title) in zip(axes, panels):
        ax.imshow(data, **imshow_kwargs)
        ax.axis('off')
        ax.set_title(title)

    plt.tight_layout()
    plt.show()

# Output: compute and plot the saliency map for the sample image.
try:
    img, prediction, prob, saliency = saliency_map(learn, 'jaundice.jpg')
except (FileNotFoundError, ValueError) as e:
    # Expected, user-fixable problems: sample image missing or unknown target label.
    # Anything else (e.g. a NameError) propagates with a full traceback.
    print(f"Error occurred: {str(e)}")
else:
    plot_saliency_results(img, prediction, prob, saliency)

## Image transformations

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from fastai.vision.all import get_image_files, PILImage
from PIL import Image
import ipywidgets as widgets
from IPython.display import clear_output, display

def image_transform_widget(path, n_images=4):
    """Create an interactive widget to compare image resizing strategies.

    Controls:
    - a dropdown of the first ``n_images`` image files under ``path``;
    - a size slider kept in sync with a number box (32-1024 px);
    - toggle buttons for four methods:
        * squish - resize straight to size x size, ignoring aspect ratio;
        * crop   - scale the shorter side to ``size``, then take the centred square;
        * pad    - scale the longer side to ``size``, then pad with zeros (black);
        * mirror - as pad, but fill the border by reflecting the image.

    Any change to the controls redraws the original image next to the transformed
    one in an output area below the controls.

    Args:
        path: folder passed to ``get_image_files``.
        n_images: number of files offered in the dropdown (default 4).
    """
    image_files = get_image_files(path)[:n_images]
    if len(image_files) == 0:
        print(f"No image files found under {path}")
        return
    methods = ['squish', 'crop', 'pad', 'mirror']

    image_dropdown = widgets.Dropdown(
        options=[(f.name, f) for f in image_files],
        description='Image:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='50%')
    )

    size_slider = widgets.IntSlider(
        value=224,
        min=32,
        max=1024,
        step=32,
        description='Size:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='50%'),
        continuous_update=False
    )

    size_input = widgets.BoundedIntText(
        value=224,
        min=32,
        max=1024,
        step=1,
        description='Custom size:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='50%')
    )

    method_buttons = widgets.ToggleButtons(
        options=methods,
        description='Method:',
        style={'description_width': 'initial'},
        button_style='info'
    )

    # Plots and prints from the callbacks are captured here, so they appear under
    # the controls without clearing/redisplaying the controls themselves.
    output = widgets.Output()

    def resize_and_pad(img, target_size, method='pad_zeros'):
        """Fit ``img`` inside a target_size x target_size square and pad the rest.

        The longer side is scaled to ``target_size`` with the aspect ratio kept
        (each side at least 1 px). method == 'mirror' fills the border by
        reflection; any other value fills it with zeros (black).
        Returns the padded image as a numpy array.
        """
        img = np.asarray(img)
        h, w = img.shape[:2]
        aspect = w / h

        if w > h:
            new_w, new_h = target_size, max(1, int(target_size / aspect))
        else:
            new_w, new_h = max(1, int(target_size * aspect)), target_size

        resized = np.array(
            Image.fromarray(img).resize((new_w, new_h), Image.Resampling.LANCZOS)
        )

        # Split the missing pixels between both sides (extra pixel goes bottom/right).
        pad_h = target_size - new_h
        pad_w = target_size - new_w
        top, left = pad_h // 2, pad_w // 2
        padding = [(top, pad_h - top), (left, pad_w - left)]
        if resized.ndim == 3:
            padding.append((0, 0))  # never pad the channel axis

        if method == 'mirror':
            return np.pad(resized, padding, mode='reflect')
        return np.pad(resized, padding, mode='constant', constant_values=0)

    def apply_transforms(image_path, size, method):
        """Plot the original image beside its ``method``-transformed version."""
        img = PILImage.create(image_path)
        print(f"Applying {method} transform to size {size}x{size}")
        img_array = np.array(img)

        fig, axes = plt.subplots(1, 2, figsize=(15, 7))

        # Original image
        axes[0].imshow(img)
        axes[0].set_title(f'Original {img_array.shape[:2]}')
        axes[0].axis('off')

        if method == 'squish':
            squished = Image.fromarray(img_array).resize((size, size), Image.Resampling.LANCZOS)
            transformed = np.array(squished)
            title = f'Squished ({size}x{size})'
        elif method == 'crop':
            h, w = img_array.shape[:2]
            aspect = w / h
            # Scale so the shorter side equals `size`, then cut out the centred square.
            if w > h:
                new_w, new_h = int(size * aspect), size
            else:
                new_w, new_h = size, int(size / aspect)
            resized = Image.fromarray(img_array).resize((new_w, new_h), Image.Resampling.LANCZOS)
            left = (new_w - size) // 2
            top = (new_h - size) // 2
            transformed = np.array(resized.crop((left, top, left + size, top + size)))
            title = f'Cropped ({size}x{size})'
        else:
            transformed = resize_and_pad(img, size, method)
            title = f'{method.replace("_", " ").title()} ({size}x{size})'

        # Transformed image
        axes[1].imshow(transformed)
        axes[1].set_title(title)
        axes[1].axis('off')

        plt.tight_layout()
        plt.show()

        print("\nDimension Details:")
        print(f"Original dimensions: {img_array.shape[:2]}")
        print(f"Target size: {size}x{size}")
        print(f"Final dimensions: {transformed.shape[:2]}")

    def on_size_change(change):
        """Keep the slider and the number box showing the same size."""
        if change['owner'] is size_slider:
            size_input.value = change['new']
        else:
            size_slider.value = change['new']

    def on_change(change=None):
        """Redraw the comparison for the current image, size and method."""
        output.clear_output(wait=True)
        with output:
            apply_transforms(
                image_dropdown.value,
                size_input.value,
                method_buttons.value
            )

    # Listeners
    image_dropdown.observe(on_change, names='value')
    method_buttons.observe(on_change, names='value')
    size_slider.observe(on_size_change, names='value')
    size_input.observe(on_size_change, names='value')
    # Redraw on size changes too. Slider moves are forwarded into size_input by
    # on_size_change, so observing size_input alone covers both size controls
    # and draws once per change.
    size_input.observe(on_change, names='value')

    display(widgets.VBox([image_dropdown,
                          widgets.HBox([size_slider, size_input]),
                          method_buttons,
                          output]))
    on_change()

# Launch the interactive resize-method explorer on the dataset folder
# (`path` was set to Path('data') in the earlier DataLoaders cell).
image_transform_widget(path)