# **Exercise 3: Representation learning for bone fractures**

## Overview

In this assignment you are required to implement a bone fracture X-ray classification task using a self-supervised learning (SSL) approach on the following dataset: https://stanfordmlgroup.github.io/competitions/mura/
"MURA is a dataset of musculoskeletal radiographs consisting of 14,863 studies from 12,173 patients, with a total of 40,561 multi-view radiographic images. Each belongs to one of seven standard upper extremity radiographic study types: elbow, finger, forearm, hand, humerus, shoulder, and wrist. Each study was manually labeled as normal or abnormal by board-certified radiologists from the Stanford Hospital.
To evaluate models and get a robust estimate of radiologist performance, we collected additional labels from six board-certified Stanford radiologists on the test set, consisting of 207 musculoskeletal studies."

<img src="https://github.com/HadarPur/RU-HC-RepresentationLearningforBoneFractures/blob/main/figures/radiologist_result_example.png?raw=true" alt="Image" style="max-width: 500px;" />

## Steps
1. Perform data exploration and create a naïve baseline (e.g. based on the paper https://arxiv.org/abs/1712.06957, or any other approach you wish).
All steps must include a description of the data exploration: data distribution, visualization, thorough evaluation, visualization of results, and a demonstration of good and bad results.
You can focus on three bones, for example elbow, hand and shoulder, as was done in https://github.com/Alkoby/Bone-Fracture-Detection:
- <img src="https://github.com/HadarPur/RU-HC-RepresentationLearningforBoneFractures/blob/main/figures/visualization_example.png?raw=true" alt="Image" style="max-width: 300px;" />

2.  Implement one of the representation learning approaches listed below and provide a detailed explanation of your approach compared to the baseline (e.g. compare the results when using 1%, 10% and 100% of the labeled data, as done in https://arxiv.org/pdf/2006.10029.pdf).
  * SimCLR (Chen et al.): https://github.com/google-research/simclr
  * BYOL (Grill et al.): https://papers.nips.cc/paper/2020/file/f3ada80d5c4ee70142b17b8192b2958e-Paper.pdf
  * MoCo (He et al.): https://arxiv.org/pdf/1911.05722.pdf
  * SimSiam (Chen et al.): https://arxiv.org/abs/2011.10566
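
The 1% / 10% / 100% comparison in step 2 needs labeled subsets of different sizes. The following is a minimal sketch of one way to draw them, assuming a study table with a `label` column like the one built later in this notebook. The helper `sample_labeled_fraction` and the toy table are hypothetical and not part of the assignment; stratifying on the label is a design choice that keeps the normal/abnormal ratio roughly constant across fractions.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def sample_labeled_fraction(df, fraction, label_col="label", seed=0):
    """Return a stratified subset holding roughly `fraction` of the rows of `df`."""
    if fraction >= 1.0:
        return df.reset_index(drop=True)
    subset, _ = train_test_split(
        df, train_size=fraction, stratify=df[label_col], random_state=seed
    )
    return subset.reset_index(drop=True)


# Toy table standing in for a real labeled split (the paths are made up).
toy_df = pd.DataFrame({
    "path": [f"study_{i}" for i in range(200)],
    "label": [0] * 120 + [1] * 80,
})

for fraction in (0.01, 0.10, 1.00):
    subset = sample_labeled_fraction(toy_df, fraction)
    print(f"{fraction:.0%}: {len(subset)} studies, abnormal share = {subset['label'].mean():.2f}")
```

With very small fractions the subset must still contain at least one study per class, so for small body-part splits the 1% setting may need a minimum size.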


 <font color="Burgundy" size=5> 💡 **Important:**</font> if you want to run the notebook, it is best to:
 1. Restart the runtime and then run all cells from the start through the Baseline section (inclusive).
 2. Run only the body part you would like to examine (shoulder, hand or elbow).
 3. After each run, restart the runtime and repeat steps 1 and 2.

# Submitted by

*   Shir Nitzan
*   Timor Baruch
*   Hadar Pur

## Imports

In [None]:
!pip install torch torchvision pytorch-lightning

In [None]:
import os
import multiprocessing

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.ndimage
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl

from collections import Counter

from google.colab import drive
from PIL import Image
from psutil import virtual_memory
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from tabulate import tabulate
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy, Precision, Recall, F1Score
from torchvision import transforms
from torchvision.models import densenet169
from torchvision.transforms.functional import pad
from tqdm import tqdm

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
mura_v11_path = '/content/MURA-v1.1'
# Download and unpack the dataset only if it is not already on disk.
if not os.path.exists(mura_v11_path):
  !gdown 1XjMNPle9fO2NATeXtrIgz6h03LCrwOvN
  !unzip -q '/content/MURA-v1.1.zip'
  print("Done unzipping")
else:
  print("Data already exists, continuing")

In [None]:
print(torch.__version__, torch.cuda.is_available())

## Memory

In [None]:
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

## GPU

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
max_workers = multiprocessing.cpu_count()
print("Maximum number of workers:", max_workers)

## Data

In [None]:
pd.set_option('display.max_colwidth', None)

In [None]:
class DatasetPath:
    def __init__(self, dataset_dir):
        self.dataset_dir = dataset_dir
        self.train_csv_path = f'{dataset_dir}/train_labeled_studies.csv'
        self.test_csv_path = f'{dataset_dir}/valid_labeled_studies.csv'

class ImageProcessor:
    def __init__(self, directory):
        self.directory = directory

    def get_paths(self):
        image_paths = []
        for root, dirs, files in os.walk(self.directory):
            image_paths.extend([os.path.join(root, file) for file in files if file.endswith('.png') and file.startswith('image')])
        return image_paths

class DataTransformer:
    def __init__(self, df):
        self.df = df

    def transform_df(self):
        ts_rows = [{'path': image_path, 'label': row['label']}
                    for _, row in self.df.iterrows()
                    for image_path in ImageProcessor(row['path']).get_paths()]
        return pd.DataFrame(ts_rows, columns=['path', 'label'])

    def transform_ds(self):
        dataset = [(ImageProcessor(row['path']).get_paths(), row['label']) for _, row in self.df.iterrows()]
        label_count = Counter(row['label'] for _, row in self.df.iterrows())
        return dataset, label_count

class BodyPartExtractor:
    @staticmethod
    def extract(path):
        split_path = path.split("/")
        return next((part[3:] for part in split_path if "XR_" in part), None)

class DataframeGenerator:
    def __init__(self):
        pass

    def generate_df(self, path, body_parts, flat=False):
        df = pd.read_csv(path, header=None, names=['path', 'label'])
        if flat:
            df = DataTransformer(df).transform_df()

        df['body_part'] = df['path'].apply(BodyPartExtractor.extract)
        df = df[df['body_part'].isin(body_parts)]
        return df

    def create_body_df(self, path, body_parts):
        datasets = {}

        for body_part in body_parts:
            df = self.generate_df(path, [body_part]).drop(['body_part'], axis=1)
            datasets[body_part] = df

        return datasets[body_parts[0]], datasets[body_parts[1]], datasets[body_parts[2]]
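
The classes above rely on the body part being encoded in each study path as an `XR_<PART>` directory. A minimal, self-contained sketch of that extraction follows; the function `extract_body_part` is a hypothetical stand-in for `BodyPartExtractor.extract`, and the example path is illustrative of the layout rather than copied from the CSV. Unlike the original, it matches with `startswith` instead of a substring test.

```python
def extract_body_part(path):
    """Return the study type encoded in a MURA-style path, or None if absent."""
    for part in path.split("/"):
        if part.startswith("XR_"):
            return part[len("XR_"):]
    return None


# Illustrative path in the style of the MURA CSV files (assumed, not taken from the data).
example_path = "MURA-v1.1/train/XR_ELBOW/patient00011/study1_negative/"
print(extract_body_part(example_path))       # ELBOW
print(extract_body_part("some/other/path"))  # None
```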

## Data distribution

In [None]:
class SummaryGenerator:
    @staticmethod
    def generate_table(df):
        # Count rows per body part and label (0 = normal, 1 = abnormal).
        summary = df.groupby(['body_part', 'label']).size().unstack(fill_value=0)
        summary['Total'] = summary.sum(axis=1)
        # Columns are ordered 0, 1, Total, which assumes both labels occur in df.
        summary.columns = ['Normal', 'Abnormal', 'Total']
        summary = summary.reset_index().rename(columns={'body_part': 'Part'})
        summary = summary.sort_values('Part')

        return summary

    @staticmethod
    def plot_summary_table(df, title):
        summary = SummaryGenerator.generate_table(df)
        melted_df = summary.melt(id_vars='Part', var_name='Label', value_name='Count')

        colors = ['#747FE3', '#8EE35D', '#E37346']
        sns.set_palette(colors)  # Set the color palette

        sns.set_style('darkgrid')
        sns.barplot(data=melted_df, x='Part', y='Count', hue='Label')  # Use the palette colors
        plt.xlabel('Body Part')
        plt.ylabel('Count')
        plt.title(f'Distribution of Labels by Body Part - {title}')
        plt.legend(title='')
        plt.xticks(rotation=15)
        plt.show()

## Data Visualization

In [None]:
class ImageDisplay:
    @staticmethod
    def display(df, body_part, title):
        normal_df = df[(df['label'] == 0) & (df['body_part'] == body_part)].sample(5).reset_index(drop=True)
        abnormal_df = df[(df['label'] == 1) & (df['body_part'] == body_part)].sample(5).reset_index(drop=True)

        fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
        axs[0, 0].set_ylabel("Normal", fontsize=15)
        axs[1, 0].set_ylabel("Abnormal", fontsize=15)

        # Row 0 shows normal samples, row 1 shows abnormal samples.
        for row_idx, sample_df in enumerate((normal_df, abnormal_df)):
            for i, row in sample_df.iterrows():
                image = Image.open(row['path'])
                axs[row_idx, i].imshow(image, cmap='gray')
                axs[row_idx, i].grid(False)

        fig.suptitle(title, fontsize=20)
        plt.tight_layout()

        plt.show()

## Data Augmentation

In [None]:
class ImageAugmentor:
    def __init__(self, do_flip=True, do_rotate=True, do_scale=False, do_translate=True):
        self.do_flip = do_flip
        self.do_rotate = do_rotate
        self.do_scale = do_scale
        self.do_translate = do_translate

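        self.transformer = []

    def augment(self):
        # Start from an empty list so that repeated calls do not stack duplicate transforms.
        self.transformer = []
        if self.do_flip:
            self.transformer.append(self.flip())
        if self.do_rotate:
            self.transformer.append(self.rotate())
        if self.do_scale:
            self.transformer.append(self.scale())
        if self.do_translate:
            self.transformer.append(self.translate())
        return self.transformer

    def flip(self):
        return transforms.RandomHorizontalFlip()

    def rotate(self):
        # Random rotation of up to 10 degrees in either direction.
        return transforms.RandomRotation(10)

    def scale(self):
        # Resize every image to a fixed 128x128 input.
        return transforms.Resize((128, 128))

    def translate(self):
        # Despite the name, this converts the PIL image to a float tensor in [0, 1].
        return transforms.ToTensor()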

## Data Normalization

In [None]:
class ImageNormalizer:
    def __init__(self, normalization_type="none"):
        assert normalization_type in ["zscore", "percentile", "none"], "Invalid normalization type"
        self.normalization_type = normalization_type

    def normalize(self):
        if self.normalization_type == "zscore":
            return self.z_score_normalization()
        elif self.normalization_type == "percentile":
            return self.percentile_normalization()
        elif self.normalization_type == "none":
            return []

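    def z_score_normalization(self):
        # Standardize each channel with the ImageNet mean and std, matching the pretrained DenseNet weights.
        return [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]

    def percentile_normalization(self):
        # Map inputs in [0, 1] to [-1, 1]; no percentiles are computed despite the name.
        return [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]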

## Preparing Data
The body parts are ['ELBOW', 'HAND', 'SHOULDER']

In [None]:
dataset_dir = '/content/MURA-v1.1'
dataset = DatasetPath(dataset_dir)
print(dataset.train_csv_path)
print(dataset.test_csv_path)

In [None]:
body_parts = ['ELBOW', 'HAND', 'SHOULDER']
dataframe_generator = DataframeGenerator()

In [None]:
train_df = dataframe_generator.generate_df(dataset.train_csv_path, body_parts, flat=True)
print(f'train_df = {len(train_df)}')

In [None]:
test_df = dataframe_generator.generate_df(dataset.test_csv_path, body_parts, flat=True)
print(f'test_df = {len(test_df)}')

In [None]:
train_df_elbow, train_df_hand, train_df_shoulder = dataframe_generator.create_body_df(dataset.train_csv_path, body_parts)
test_df_elbow, test_df_hand, test_df_shoulder = dataframe_generator.create_body_df(dataset.test_csv_path, body_parts)

In [None]:
print(len(train_df_elbow))

#### Training Data Visualization

In [None]:
train_df.head()

In [None]:
SummaryGenerator.generate_table(train_df)

In [None]:
SummaryGenerator.plot_summary_table(train_df, 'train')

In [None]:
train_df_elbow.head()

In [None]:
ImageDisplay.display(df = train_df, body_part = body_parts[0], title = "Train Data - Elbow")

In [None]:
train_df_hand.head()

In [None]:
ImageDisplay.display(df = train_df, body_part = body_parts[1], title = "Train Data - Hand")

In [None]:
train_df_shoulder.head()

In [None]:
ImageDisplay.display(df = train_df, body_part = body_parts[2], title = "Train Data - Shoulder")

#### Test Data Visualization

In [None]:
test_df.head()

In [None]:
SummaryGenerator.generate_table(test_df)

In [None]:
SummaryGenerator.plot_summary_table(test_df, 'validation')

In [None]:
test_df_elbow.head()

In [None]:
ImageDisplay.display(df = test_df, body_part = body_parts[0], title = "Test Data - Elbow")

In [None]:
test_df_hand.head()

In [None]:
ImageDisplay.display(df = test_df, body_part = body_parts[1], title = "Test Data - Hand")

In [None]:
test_df_shoulder.head()

In [None]:
ImageDisplay.display(df = test_df, body_part = body_parts[2], title = "Test Data - Shoulder")

## Baseline
To build the naive baseline we used the following references:


*   https://github.com/pyaf/DenseNet-MURA-PyTorch.git
*   https://github.com/Hawk453/MURA-DenseNet-Humerus.git



In [None]:
class Baseline(Dataset):
    def __init__(self, data, transform=None):
        self.transform = transform
        self.dataset, self.label_count = DataTransformer(data).transform_ds()

    def __len__(self):
        return len(self.dataset)

    def get_label_weight(self, label):
        # Fraction of studies in this dataset that carry the given label.
        return self.label_count[label] / len(self.dataset)

    def __getitem__(self, idx):
        image_paths, label = self.dataset[idx]
        images = []
        for image_path in image_paths:
            image = Image.open(image_path).convert("RGB")
            images.append(self.transform(image) if self.transform else image)
        # torch.stack needs tensors, so the transform is expected to end with ToTensor.
        images = torch.stack(images)
        return images, len(image_paths), label

In [None]:
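class Loss(nn.Module):
    def __init__(self, Wt1, Wt0):
        super(Loss, self).__init__()
        self.Wt1 = Wt1  # weight of the abnormal (y = 1) term
        self.Wt0 = Wt0  # weight of the normal (y = 0) term

    def forward(self, y_hat, y):
        # Cast integer labels to float and keep probabilities away from 0 and 1 so log() stays finite.
        y = y.float()
        y_hat = y_hat.clamp(1e-7, 1 - 1e-7)
        loss = torch.mean(-(self.Wt1 * y * y_hat.log() + self.Wt0 * (1 - y) * (1 - y_hat).log()))
        return loss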
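
The loss above is the weighted binary cross-entropy from the MURA paper, where the abnormal term is weighted by the share of normal examples, `w1 = N / (A + N)`, and the normal term by the share of abnormal examples, `w0 = A / (A + N)`, so the rarer class counts more. A minimal sketch on made-up counts and predictions follows; `class_weights` and `weighted_bce` are hypothetical helper names, and the clamp value `eps` is an assumption added for numerical safety.

```python
import torch


def class_weights(n_abnormal, n_normal):
    """Return (w1, w0): the weights of the abnormal and normal loss terms."""
    total = n_abnormal + n_normal
    return n_normal / total, n_abnormal / total


def weighted_bce(y_hat, y, w1, w0, eps=1e-7):
    """Weighted binary cross-entropy averaged over the batch."""
    y = y.float()
    y_hat = y_hat.clamp(eps, 1 - eps)  # avoid log(0)
    return torch.mean(-(w1 * y * y_hat.log() + w0 * (1 - y) * (1 - y_hat).log()))


w1, w0 = class_weights(n_abnormal=40, n_normal=60)
y = torch.tensor([1, 0, 1, 0])
y_hat = torch.tensor([0.9, 0.1, 0.2, 0.0])

print(w1, w0)  # 0.6 0.4
print(weighted_bce(y_hat, y, w1, w0).item())
```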

In [None]:
class DenseNet169(nn.Module):
    def __init__(self):
        super(DenseNet169, self).__init__()
        self.model = densenet169(pretrained=True)
        num_ftrs = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(num_ftrs, 1),
            nn.Sigmoid()
        )

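    def forward(self, views_x, n_views):
        # views_x: (batch, padded_views, channels, height, width)
        # n_views: (batch, padded_views) mask, 1 for real views and 0 for padding
        batch_size, padded_views, c, h, w = views_x.size()
        views_x = views_x.view(-1, c, h, w)

        # One abnormality probability per view
        multi_view_outputs = self.model(views_x)
        multi_view_outputs = multi_view_outputs.view(batch_size, padded_views)

        # Study-level probability: mean over the real (unmasked) views only
        outputs = torch.sum(multi_view_outputs * n_views, dim=1) / torch.sum(n_views, dim=1)

        return outputs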

class LightingDenseNet(pl.LightningModule):
    def __init__(self, loss):
        super(LightingDenseNet, self).__init__()
        self.densenet_model = DenseNet169()
        self.loss = loss
        self.metrics = {
            "accuracy": Accuracy(task="binary"),
            "precision": Precision(task="binary"),
            "recall": Recall(task="binary"),
            "f1_score": F1Score(task="binary"),
        }

    def forward(self, views_x, n_views):
        return self.densenet_model(views_x, n_views)

    def training_step(self, batch, batch_idx):
        views_x, n_views, y = batch

        y_hat = self(views_x, n_views)
        loss = self.loss(y_hat, y)

        self.log('train_loss', loss)
        for metric_name, metric in self.metrics.items():
            self.log(f'train_{metric_name}', metric(y_hat.cpu(), y.cpu()))

        return loss

    def validation_step(self, batch, batch_idx):
        views_x, n_views, y = batch

        y_hat = self(views_x, n_views)
        loss = self.loss(y_hat, y)

        self.log('val_loss', loss)
        for metric_name, metric in self.metrics.items():
            self.log(f'val_{metric_name}', metric(y_hat.cpu(), y.cpu()))

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=2e-5)
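
The study-level prediction in `DenseNet169.forward` is the mean of the per-view probabilities, with padded views excluded by the mask. A minimal sketch of that masked mean on hand-written numbers (the tensors `view_probs` and `view_mask` are made up for illustration):

```python
import torch

# Two studies padded to three views; the second study has only two real views.
view_probs = torch.tensor([[0.9, 0.7, 0.5],
                           [0.2, 0.4, 0.8]])  # 0.8 comes from a padded (all-zero) image
view_mask = torch.tensor([[1, 1, 1],
                          [1, 1, 0]])

# Sum only the real views, then divide by the number of real views per study.
study_probs = (view_probs * view_mask).sum(dim=1) / view_mask.sum(dim=1)
print(study_probs)  # tensor([0.7000, 0.3000])
```

Without the mask, the second study would average to about 0.47 instead of 0.30, because the network still outputs a probability for the padded view.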

In [None]:
class DataLoaderManager:
    def __init__(self, batch_size=8, num_workers=4):
        self.batch_size = batch_size
        self.num_workers = num_workers

    def get_train_transformer(self):
        image_augmentation = ImageAugmentor(do_flip=True, do_rotate=True, do_scale=True, do_translate=True).augment()
        image_normalization = ImageNormalizer(normalization_type = "zscore").normalize()
        return transforms.Compose(image_augmentation + image_normalization)

    def get_test_transformer(self):
        image_augmentation = ImageAugmentor(do_flip=False, do_rotate=False, do_scale=True, do_translate=True).augment()
        image_normalization = ImageNormalizer(normalization_type = "zscore").normalize()
        return transforms.Compose(image_augmentation + image_normalization)

    @staticmethod
    def collate_fn(batch):
        max_views = max([item[0].shape[0] for item in batch])
        images_list, n_views_list, labels_list = [], [], []

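        for item in batch:
            images, n_views, label = item
            # Zero-pad along the view dimension (the last pair in the pad tuple) up to max_views.
            images = nn.functional.pad(images, (0, 0, 0, 0, 0, 0, 0, max_views - images.shape[0]))
            images_list.append(images)
            # Mask: 1 for real views, 0 for padded ones.
            n_views_list.append([1] * n_views + [0] * (max_views - n_views))
            labels_list.append(label)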

        images_tensor = torch.stack(images_list)
        n_views_tensor = torch.tensor(n_views_list)
        labels_tensor = torch.tensor(labels_list)

        return images_tensor, n_views_tensor, labels_tensor
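
Studies have different numbers of views, so `collate_fn` pads each study to the largest view count in the batch and returns a mask marking the real views. The sketch below reproduces that step on small dummy tensors; `pad_views`, `study_a` and `study_b` are hypothetical names used only for this illustration.

```python
import torch
import torch.nn.functional as F


def pad_views(images, max_views):
    """Zero-pad a (views, C, H, W) tensor along the view dimension."""
    # F.pad reads the tuple from the last dimension backwards:
    # (W_left, W_right, H_left, H_right, C_left, C_right, V_left, V_right).
    return F.pad(images, (0, 0, 0, 0, 0, 0, 0, max_views - images.shape[0]))


study_a = torch.ones(3, 3, 4, 4)  # three views
study_b = torch.ones(1, 3, 4, 4)  # one view
studies = (study_a, study_b)
max_views = max(s.shape[0] for s in studies)

batch = torch.stack([pad_views(s, max_views) for s in studies])
mask = torch.tensor([[1] * s.shape[0] + [0] * (max_views - s.shape[0]) for s in studies])

print(batch.shape)  # torch.Size([2, 3, 3, 4, 4])
print(mask)         # rows [1, 1, 1] and [1, 0, 0]
```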

In [None]:
class TrainerManager:
    def __init__(self, max_epochs=10):
        self.max_epochs = max_epochs

    def train(self, model, train_loader, val_loader, dir_path):
        logger = TensorBoardLogger(save_dir=dir_path, name="MURA V1.1")

        self.trainer = pl.Trainer(max_epochs=self.max_epochs, logger=logger)
        self.trainer.fit(model, train_loader, val_loader)

    def validate(self, model, test_loader):
        self.trainer.validate(model, test_loader)

In [None]:
data_loader_manager = DataLoaderManager(batch_size=8, num_workers=4)

train_val_transformer = data_loader_manager.get_train_transformer()
test_transformer = data_loader_manager.get_test_transformer()

### Elbow

#### Training

In [None]:
train_df_elbow, val_df_elbow = train_test_split(train_df_elbow, test_size=0.2)

In [None]:
train_dataset_elbow = Baseline(train_df_elbow, transform=train_val_transformer)
val_dataset_elbow = Baseline(val_df_elbow, transform=test_transformer)
test_dataset_elbow = Baseline(test_df_elbow, transform=test_transformer)

train_loader_elbow = torch.utils.data.DataLoader(train_dataset_elbow, collate_fn=data_loader_manager.collate_fn,
                                           batch_size=8, shuffle=True, num_workers=6)
val_loader_elbow = torch.utils.data.DataLoader(val_dataset_elbow, collate_fn=data_loader_manager.collate_fn,
                                         batch_size=16, num_workers=6)
test_loader_elbow = torch.utils.data.DataLoader(test_dataset_elbow, collate_fn=data_loader_manager.collate_fn,
                                         batch_size=16, num_workers=6)

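# As in the MURA paper, the abnormal (y = 1) term is weighted by the normal fraction and vice versa.
loss_elbow = Loss(train_dataset_elbow.get_label_weight(0), train_dataset_elbow.get_label_weight(1))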

In [None]:
model_elbow = LightingDenseNet(loss_elbow)

In [None]:
elbow_path = '/content/baseline/training/elbow'
trainer_manager_elbow = TrainerManager(max_epochs=15)
trainer_manager_elbow.train(model_elbow, train_loader_elbow, val_loader_elbow, elbow_path)

In [None]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir /content/baseline/training/elbow

#### Test

In [None]:
trainer_manager_elbow.validate(model_elbow, test_loader_elbow)

### Hand

#### Training

In [None]:
train_df_hand, val_df_hand = train_test_split(train_df_hand, test_size=0.2)

In [None]:
train_dataset_hand = Baseline(train_df_hand, transform=train_val_transformer)
val_dataset_hand = Baseline(val_df_hand, transform=test_transformer)
test_dataset_hand = Baseline(test_df_hand, transform=test_transformer)

train_loader_hand = torch.utils.data.DataLoader(train_dataset_hand, collate_fn=data_loader_manager.collate_fn,
                                           batch_size=8, shuffle=True, num_workers=6)
val_loader_hand = torch.utils.data.DataLoader(val_dataset_hand, collate_fn=data_loader_manager.collate_fn,
                                         batch_size=16, num_workers=6)
test_loader_hand = torch.utils.data.DataLoader(test_dataset_hand, collate_fn=data_loader_manager.collate_fn,
                                         batch_size=16, num_workers=6)

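# As in the MURA paper, the abnormal (y = 1) term is weighted by the normal fraction and vice versa.
loss_hand = Loss(train_dataset_hand.get_label_weight(0), train_dataset_hand.get_label_weight(1))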

In [None]:
model_hand = LightingDenseNet(loss_hand)

In [None]:
hand_path = '/content/baseline/training/hand'

trainer_manager_hand = TrainerManager(max_epochs=15)
trainer_manager_hand.train(model_hand, train_loader_hand, val_loader_hand, hand_path)

In [None]:
# Start tensorboard.
%reload_ext tensorboard
%tensorboard --logdir /content/baseline/training/hand

#### Test

In [None]:
trainer_manager_hand.validate(model_hand, test_loader_hand)

### Shoulder

#### Training

In [None]:
train_df_shoulder, val_df_shoulder = train_test_split(train_df_shoulder, test_size=0.2)

In [None]:
train_dataset_shoulder = Baseline(train_df_shoulder, transform=train_val_transformer)
val_dataset_shoulder = Baseline(val_df_shoulder, transform=test_transformer)
test_dataset_shoulder = Baseline(test_df_shoulder, transform=test_transformer)

train_loader_shoulder = torch.utils.data.DataLoader(train_dataset_shoulder, collate_fn=data_loader_manager.collate_fn,
                                           batch_size=8, shuffle=True, num_workers=6)
val_loader_shoulder = torch.utils.data.DataLoader(val_dataset_shoulder, collate_fn=data_loader_manager.collate_fn,
                                         batch_size=16, num_workers=6)
test_loader_shoulder = torch.utils.data.DataLoader(test_dataset_shoulder, collate_fn=data_loader_manager.collate_fn,
                                        batch_size=16, num_workers=6)

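# As in the MURA paper, the abnormal (y = 1) term is weighted by the normal fraction and vice versa.
loss_shoulder = Loss(train_dataset_shoulder.get_label_weight(0), train_dataset_shoulder.get_label_weight(1))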

In [None]:
model_shoulder = LightingDenseNet(loss_shoulder)

In [None]:
shoulder_path = '/content/baseline/training/shoulder'

trainer_manager_shoulder = TrainerManager(max_epochs=15)
trainer_manager_shoulder.train(model_shoulder, train_loader_shoulder, val_loader_shoulder, shoulder_path)

In [None]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir /content/baseline/training/shoulder --port=8017

#### Test

In [None]:
trainer_manager_shoulder.validate(model_shoulder, test_loader_shoulder)
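
The `validate` calls above report only the logged metrics. Step 1 also asks for a thorough evaluation and a demonstration of good and bad results, which could be sketched as follows: threshold the study probabilities at 0.5, build a confusion matrix, compute Cohen's kappa (the agreement statistic used in the MURA paper), and rank studies by prediction error. The arrays `y_true` and `y_prob` are made-up values; in practice they would be collected from the model over the test loader.

```python
import numpy as np
from sklearn.metrics import cohen_kappa_score, confusion_matrix

# Hypothetical study labels and predicted abnormality probabilities.
y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
y_prob = np.array([0.1, 0.6, 0.8, 0.4, 0.9, 0.2, 0.7, 0.3])
y_pred = (y_prob > 0.5).astype(int)

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
kappa = cohen_kappa_score(y_true, y_pred)
print(f"TN={tn} FP={fp} FN={fn} TP={tp} kappa={kappa:.2f}")  # TN=3 FP=1 FN=1 TP=3 kappa=0.50

# Rank studies by how far the probability is from the true label.
errors = np.abs(y_prob - y_true)
worst = np.argsort(-errors)[:2]  # candidates for "bad result" examples
best = np.argsort(errors)[:2]    # candidates for "good result" examples
print("worst:", sorted(worst.tolist()), "best:", sorted(best.tolist()))
```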