# **Exercise 3: Representation learning for bone fractures**

## Overview

In this assignment you are required to implement a bone fracture X-ray classification task using an SSL (self-supervised learning) approach with the following dataset: https://stanfordmlgroup.github.io/competitions/mura/
"MURA is a dataset of musculoskeletal radiographs consisting of 14,863 studies from 12,173 patients, with a total of 40,561 multi-view radiographic images. Each belongs to one of seven standard upper extremity radiographic study types: elbow, finger, forearm, hand, humerus, shoulder, and wrist. Each study was manually labeled as normal or abnormal by board-certified radiologists from the Stanford Hospital.
To evaluate models and get a robust estimate of radiologist performance, we collected additional labels from six board-certified Stanford radiologists on the test set, consisting of 207 musculoskeletal studies."

<img src="https://github.com/HadarPur/RU-HC-RepresentationLearningforBoneFractures/blob/main/figures/radiologist_result_example.png?raw=true" alt="Image" style="max-width: 500px;" />

## Steps
1. Please perform data exploration and create a naïve baseline (e.g. based on the paper https://arxiv.org/abs/1712.06957, or any other approach you wish).
All steps must include a description of the data exploration: data distribution, visualization, thorough evaluation, visualization of results, and a demonstration of good and bad results.
You can focus on three bones, for example Elbow, Hand, and Shoulder, as was done in https://github.com/Alkoby/Bone-Fracture-Detection:
- <img src="https://github.com/HadarPur/RU-HC-RepresentationLearningforBoneFractures/blob/main/figures/visualization_example.png?raw=true" alt="Image" style="max-width: 300px;" />

2.  Implement one of the representation learning approaches listed below and provide a detailed explanation of your approach compared to the baseline (e.g. compare the results when using 1%, 10%, and 100% of the labeled data, as done in https://arxiv.org/pdf/2006.10029.pdf).
  * SimCLR Chen et al. https://github.com/google-research/simclr
  * BYOL Grill et al. https://papers.nips.cc/paper/2020/file/f3ada80d5c4ee70142b17b8192b2958e-Paper.pdf
  * MoCo He et al. https://arxiv.org/pdf/1911.05722.pdf
  * SimSiam Chen et al. https://arxiv.org/abs/2011.10566
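
For the data-distribution part of step 1, a minimal sketch of counting normal and abnormal studies per body part may help. It assumes the MURA folder layout `<split>/<XR_PART>/<patient>/<studyN_positive|negative>/`; the helper name `count_studies` and the tiny synthetic tree are illustrative only and are not part of the assignment.

```python
from collections import Counter
from pathlib import Path
import tempfile


def count_studies(split_dir):
    """Count (body_part, label) pairs over the study folders under split_dir."""
    counts = Counter()
    # Assumed layout: <split_dir>/<XR_PART>/<patient>/<studyN_label>/
    for study in Path(split_dir).glob('*/*/*'):
        if not study.is_dir():
            continue
        if study.name.endswith('_positive'):
            label = 'abnormal'
        elif study.name.endswith('_negative'):
            label = 'normal'
        else:
            continue  # folder name carries no label
        body_part = study.parts[-3]
        counts[(body_part, label)] += 1
    return counts


# Tiny synthetic tree so the sketch runs without the real dataset.
with tempfile.TemporaryDirectory() as root:
    for rel in ['XR_ELBOW/patient00001/study1_positive',
                'XR_ELBOW/patient00002/study1_negative',
                'XR_HAND/patient00003/study1_negative']:
        (Path(root) / rel).mkdir(parents=True)
    demo_counts = count_studies(root)

print(demo_counts)
```

On the real data the same function could be pointed at `/content/MURA-v1.1/train` once the dataset has been unzipped, and the resulting counts plotted as a bar chart.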


# Submitted

*   Shir Nitzan
*   Timor Baruch
*   Hadar Pur

## Imports

In [None]:
!pip install torch torchvision pytorch-lightning

In [None]:
import os
import multiprocessing

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.ndimage
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
import tensorflow as tf
import tensorflow_datasets as tfds

from collections import Counter
from google.colab import drive
from PIL import Image
from psutil import virtual_memory
from sklearn.model_selection import train_test_split
from tabulate import tabulate
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy, Precision, Recall, F1Score
from torchvision import transforms
from torchvision.models import densenet169
from torchvision.transforms.functional import pad
from tqdm import tqdm


In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
mura_v11_path = '/content/MURA-v1.1'
if not os.path.exists(mura_v11_path):
  !gdown 1XjMNPle9fO2NATeXtrIgz6h03LCrwOvN
  !unzip -q '/content/MURA-v1.1.zip'
  print("Done unzipping")
else:
  print("Data exists, continuing")

In [None]:
print(torch.__version__, torch.cuda.is_available())

In [None]:
# Rename the official MURA 'valid' folder to 'test'; a separate validation split is taken from 'train' later
!mv /content/MURA-v1.1/valid /content/MURA-v1.1/test

## Memory

In [None]:
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

## GPU

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
max_workers = multiprocessing.cpu_count()
print("Maximum number of workers:", max_workers)

## Prepare data for training

In [None]:
import os
import random
from sklearn.model_selection import train_test_split
import shutil

def split_data(mode):
    # Folder holding the training patients for this body part (e.g. XR_ELBOW)
    train_folder = "/content/MURA-v1.1/train" + "/" + mode

    # Folder that receives the held-out validation patients
    validation_output_folder = "/content/MURA-v1.1/valid" + "/" + mode
    os.makedirs(validation_output_folder, exist_ok=True)

    # Each sub-folder is one patient, so the split is done per patient.
    # Sorting makes the seeded split independent of the listing order.
    patient_dirs = sorted(f for f in os.listdir(train_folder) if os.path.isdir(os.path.join(train_folder, f)))

    # Hold out 20% of the patients for validation
    if len(patient_dirs) > 0:
        train_dirs, validation_dirs = train_test_split(patient_dirs, test_size=0.2, random_state=42)

        # Move the validation patients out of the train folder
        for patient in validation_dirs:
            src = os.path.join(train_folder, patient)
            dst = os.path.join(validation_output_folder, patient)
            shutil.move(src, dst)
    else:
        print("No patient folders found in the train folder.")



In [None]:
split_data('XR_ELBOW')
split_data('XR_HAND')
split_data('XR_SHOULDER')
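
Because the split moves whole patient folders, no patient should end up in both train and validation. A minimal sketch of a check for that is given here; the helper name `shared_patients` and the temporary folders are hypothetical and only illustrate the idea.

```python
import os
import tempfile


def shared_patients(train_dir, valid_dir):
    """Return the patient folder names that appear in both directories."""
    def patient_ids(folder):
        return {d for d in os.listdir(folder)
                if os.path.isdir(os.path.join(folder, d))}
    return patient_ids(train_dir) & patient_ids(valid_dir)


# Synthetic example: first a clean split, then one leaked patient.
with tempfile.TemporaryDirectory() as root:
    train_dir = os.path.join(root, 'train')
    valid_dir = os.path.join(root, 'valid')
    for pid in ['patient00001', 'patient00002']:
        os.makedirs(os.path.join(train_dir, pid))
    os.makedirs(os.path.join(valid_dir, 'patient00003'))
    overlap_clean = shared_patients(train_dir, valid_dir)

    os.makedirs(os.path.join(valid_dir, 'patient00001'))
    overlap_leaky = shared_patients(train_dir, valid_dir)

print(overlap_clean, overlap_leaky)
```

In the notebook, the same check could be run on `/content/MURA-v1.1/train/XR_ELBOW` and `/content/MURA-v1.1/valid/XR_ELBOW`; an empty set means the two splits share no patients.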

In [None]:
def split_data_pct(mode, pct):
    # Folder holding the training patients for this body part
    train_folder = "/content/MURA-v1.1/train" + "/" + mode

    # Each sub-folder is one patient
    patient_dirs = [f for f in os.listdir(train_folder) if os.path.isdir(os.path.join(train_folder, f))]

    # Number of patients that corresponds to the requested percentage
    num_dirs = len(patient_dirs)
    num_dirs_pct = int(num_dirs * (pct / 100.0))

    print(f"Total number of patient folders is {num_dirs}")

    # Randomly select the patients (unseeded, so the subset differs between runs)
    selected_dirs = random.sample(patient_dirs, num_dirs_pct)

    # Copy the selected patients to the output folder
    output_folder = f"/content/MURA-v1.1/train_{pct}" + "/" + mode
    os.makedirs(output_folder, exist_ok=True)
    for patient in selected_dirs:
        src = os.path.join(train_folder, patient)
        dst = os.path.join(output_folder, patient)
        # dirs_exist_ok (Python 3.8+) lets the cell be re-run without errors
        shutil.copytree(src, dst, dirs_exist_ok=True)

    # Print the summary
    print(f"Split {pct}%: {num_dirs_pct} patient folders copied to {output_folder}")

In [None]:
# Split the data into different percentages
split_data_pct("XR_ELBOW", 1)
split_data_pct("XR_ELBOW", 10)
split_data_pct("XR_ELBOW", 100)

In [None]:
# Split the data into different percentages
split_data_pct("XR_HAND", 1)
split_data_pct("XR_HAND", 10)
split_data_pct("XR_HAND", 100)

In [None]:
# Split the data into different percentages
split_data_pct("XR_SHOULDER", 1)
split_data_pct("XR_SHOULDER", 10)
split_data_pct("XR_SHOULDER", 100)
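
The percentage subsets above are drawn independently, so the 1% subset is not necessarily contained in the 10% subset. If nested, reproducible subsets are wanted, one option is to shuffle once with a fixed seed and take prefixes. This is a sketch of an alternative, not the method used in this notebook; `nested_subsets` and the patient IDs are made up for illustration.

```python
import random


def nested_subsets(items, percentages=(1, 10, 100), seed=42):
    """Return {pct: subset} where each smaller subset is a prefix of the larger ones."""
    shuffled = sorted(items)                 # fixed starting order
    random.Random(seed).shuffle(shuffled)    # one seeded shuffle shared by all subsets
    subsets = {}
    for pct in percentages:
        k = int(len(shuffled) * pct / 100)   # same rounding-down rule as split_data_pct
        subsets[pct] = shuffled[:k]
    return subsets


patients = [f"patient{i:05d}" for i in range(200)]  # made-up patient IDs
subsets = nested_subsets(patients)
print({pct: len(s) for pct, s in subsets.items()})
```

With nested subsets, a change in results between 1% and 10% reflects the added labels rather than a different random draw.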

## SimCLR
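
SimCLR pretrains the encoder with a contrastive objective (NT-Xent): two augmented views of the same image should be more similar to each other than to any other image in the batch. The following is a minimal pure-Python sketch of that loss for intuition only; it is not the implementation in the cloned repository, and the function name `nt_xent` and the toy vectors are assumptions. It expects non-zero embedding vectors.

```python
import math


def nt_xent(z, temperature=0.1):
    """NT-Xent loss for 2N embeddings, where z[i] and z[i + N] are views of one image."""
    total = len(z)
    assert total % 2 == 0, "expects two views per image"
    n = total // 2

    # L2-normalize so that the dot product equals cosine similarity.
    unit = []
    for v in z:
        norm = math.sqrt(sum(x * x for x in v))
        unit.append([x / norm for x in v])

    loss = 0.0
    for i in range(total):
        pos = (i + n) % total  # index of the other view of the same image
        sims = {k: sum(a * b for a, b in zip(unit[i], unit[k])) / temperature
                for k in range(total) if k != i}
        # Cross-entropy of the positive against all other samples (stable log-sum-exp).
        m = max(sims.values())
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in sims.values()))
        loss += log_sum_exp - sims[pos]
    return loss / total


aligned = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]     # views agree
misaligned = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]  # views swapped
print(nt_xent(aligned), nt_xent(misaligned))
```

The loss is close to zero when each view matches its partner and grows when the partner looks like a different image, which is the signal the encoder is trained on.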

In [None]:
!git clone https://github.com/google-research/simclr.git

In [None]:
!rm -rf /root/tensorflow_datasets

In [None]:
!pip install tensorflow_datasets

In [None]:
!tfds --version

## Preparing Data
The body parts are ['ELBOW', 'HAND', 'SHOULDER']
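
The dataset builders in this section differ only in the body part and in which training folder they read (`train`, `train_10`, or `train_1`). A small sketch of that mapping is given here; `train_dir_for` is a hypothetical helper and is not used by the builders themselves.

```python
from pathlib import PurePosixPath

MURA_ROOT = PurePosixPath('/content/MURA-v1.1')


def train_dir_for(body_part, pct=100):
    """Return the training folder a builder reads for a body part and label percentage."""
    # The 100% builders read the plain 'train' folder.
    folder = 'train' if pct == 100 else f'train_{pct}'
    return MURA_ROOT / folder / f'XR_{body_part}'


for part in ['ELBOW', 'HAND', 'SHOULDER']:
    for pct in [1, 10, 100]:
        print(part, pct, train_dir_for(part, pct))
```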

### Pre-Processing


In [None]:
%cd /content/simclr

#### Elbow

###### Elbow 100

In [None]:
!rm -rf /content/simclr/elbow_custom

In [None]:
%cd /content/simclr

In [None]:
!tfds new elbow_custom

In [None]:
%cd /content/simclr/elbow_custom

In [None]:
%%writefile elbow_custom_dataset_builder.py

import tensorflow_datasets as tfds
import csv
import tensorflow as tf
from pathlib import Path
import numpy as np
import os
import random

class ElbowBuilder(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for elbow_custom dataset."""
  MANUAL_DOWNLOAD_INSTRUCTIONS = "/content/simclr"

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Initial release.'}

  # Implement the required abstract methods
  def _info(self):
      # Define the dataset information
      features = tfds.features.FeaturesDict({
          'image': tfds.features.Image(shape=(None, None, 3)),
          'label': tfds.features.ClassLabel(num_classes=2),
      })
      return tfds.core.DatasetInfo(
          builder=self,
          description='My custom elbow dataset',
          features=features,
          supervised_keys=('image', 'label'),
      )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
      # Specify the splits
      return {
          'train': self._generate_examples(Path('/content/MURA-v1.1/train/XR_ELBOW')),
          'test': self._generate_examples(Path('/content/MURA-v1.1/test/XR_ELBOW')),
          'valid': self._generate_examples(Path('/content/MURA-v1.1/valid/XR_ELBOW'))
      }

  def _generate_examples(self, path):
      for file_name in path.glob('*/*/*.png'):
          if 'negative' in str(file_name):
              label = 0
          elif 'positive' in str(file_name):
              label = 1
          else:
              # Skip unlabeled paths instead of reusing the previous label
              continue

          yield str(file_name), {
              'image': file_name,
              'label': label,
          }

In [None]:
!tfds build
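
The labeling rule in `_generate_examples` relies on the study folder name, which in MURA ends with `_positive` (abnormal) or `_negative` (normal). A stricter variant that looks only at the study folder, rather than at the whole path string, can be sketched as follows; `label_from_path` is a hypothetical helper, not part of the builder.

```python
from pathlib import PurePosixPath


def label_from_path(image_path):
    """Return 1 for abnormal, 0 for normal, or None if the study folder has no label."""
    study_folder = PurePosixPath(image_path).parent.name
    if study_folder.endswith('_negative'):
        return 0
    if study_folder.endswith('_positive'):
        return 1
    return None


print(label_from_path('/content/MURA-v1.1/train/XR_ELBOW/patient00011/study1_negative/image1.png'))
print(label_from_path('/content/MURA-v1.1/train/XR_ELBOW/patient00012/study2_positive/image1.png'))
```

Checking only the parent folder avoids a false match if some other part of the path happens to contain the word "positive" or "negative".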

###### Elbow 10

In [None]:
!rm -rf /content/simclr/elbow_custom_10

In [None]:
%cd /content/simclr

In [None]:
!tfds new elbow_custom_10

In [None]:
%cd /content/simclr/elbow_custom_10

In [None]:
%%writefile elbow_custom_10_dataset_builder.py

import tensorflow_datasets as tfds
import csv
import tensorflow as tf
from pathlib import Path
import numpy as np
import os
import random

class ElbowBuilder10(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for elbow_custom_10 dataset."""
  MANUAL_DOWNLOAD_INSTRUCTIONS = "/content/simclr"

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Initial release.'}

  # Implement the required abstract methods
  def _info(self):
      # Define the dataset information
      features = tfds.features.FeaturesDict({
          'image': tfds.features.Image(shape=(None, None, 3)),
          'label': tfds.features.ClassLabel(num_classes=2),
      })
      return tfds.core.DatasetInfo(
          builder=self,
          description='My custom elbow dataset',
          features=features,
          supervised_keys=('image', 'label'),
      )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
      # Specify the splits
      return {
          'train': self._generate_examples(Path('/content/MURA-v1.1/train_10/XR_ELBOW')),
          'test': self._generate_examples(Path('/content/MURA-v1.1/test/XR_ELBOW')),
          'valid': self._generate_examples(Path('/content/MURA-v1.1/valid/XR_ELBOW'))
      }

  def _generate_examples(self, path):
      for file_name in path.glob('*/*/*.png'):
          if 'negative' in str(file_name):
              label = 0
          elif 'positive' in str(file_name):
              label = 1
          else:
              # Skip unlabeled paths instead of reusing the previous label
              continue

          yield str(file_name), {
              'image': file_name,
              'label': label,
          }

In [None]:
!tfds build

###### Elbow 1

In [None]:
!rm -rf /content/simclr/elbow_custom_1

In [None]:
%cd /content/simclr

In [None]:
!tfds new elbow_custom_1

In [None]:
%cd /content/simclr/elbow_custom_1

In [None]:
%%writefile elbow_custom_1_dataset_builder.py

import tensorflow_datasets as tfds
import csv
import tensorflow as tf
from pathlib import Path
import numpy as np
import os
import random

class ElbowBuilder1(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for elbow_custom_1 dataset."""
  MANUAL_DOWNLOAD_INSTRUCTIONS = "/content/simclr"

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Initial release.'}

  # Implement the required abstract methods
  def _info(self):
      # Define the dataset information
      features = tfds.features.FeaturesDict({
          'image': tfds.features.Image(shape=(None, None, 3)),
          'label': tfds.features.ClassLabel(num_classes=2),
      })
      return tfds.core.DatasetInfo(
          builder=self,
          description='My custom elbow dataset',
          features=features,
          supervised_keys=('image', 'label'),
      )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
      # Specify the splits
      return {
          'train': self._generate_examples(Path('/content/MURA-v1.1/train_1/XR_ELBOW')),
          'test': self._generate_examples(Path('/content/MURA-v1.1/test/XR_ELBOW')),
          'valid': self._generate_examples(Path('/content/MURA-v1.1/valid/XR_ELBOW'))
      }

  def _generate_examples(self, path):
      for file_name in path.glob('*/*/*.png'):
          if 'negative' in str(file_name):
              label = 0
          elif 'positive' in str(file_name):
              label = 1
          else:
              # Skip unlabeled paths instead of reusing the previous label
              continue

          yield str(file_name), {
              'image': file_name,
              'label': label,
          }

In [None]:
!tfds build

#### Hand

###### Hand 100

In [None]:
!rm -rf /content/simclr/hand_custom

In [None]:
%cd /content/simclr

In [None]:
!tfds new hand_custom

In [None]:
%cd /content/simclr/hand_custom

In [None]:
%%writefile hand_custom_dataset_builder.py

import tensorflow_datasets as tfds
import csv
import tensorflow as tf
from pathlib import Path
import numpy as np
import os

class HandBuilder(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for hand_custom dataset."""
  MANUAL_DOWNLOAD_INSTRUCTIONS = "/content/simclr"

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Initial release.'}

  # Implement the required abstract methods
  def _info(self):
      # Define the dataset information
      features = tfds.features.FeaturesDict({
          'image': tfds.features.Image(shape=(None, None, 3)),
          'label': tfds.features.ClassLabel(num_classes=2),
      })
      return tfds.core.DatasetInfo(
          builder=self,
          description='My custom hand dataset',
          features=features,
          supervised_keys=('image', 'label'),
      )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
      # Specify the splits
      return {
          'train': self._generate_examples(Path('/content/MURA-v1.1/train/XR_HAND')),
          'test': self._generate_examples(Path('/content/MURA-v1.1/test/XR_HAND')),
          'valid': self._generate_examples(Path('/content/MURA-v1.1/valid/XR_HAND'))
      }

  def _generate_examples(self, path):
      for file_name in path.glob('*/*/*.png'):
          if 'negative' in str(file_name):
              label = 0
          elif 'positive' in str(file_name):
              label = 1
          else:
              # Skip unlabeled paths instead of reusing the previous label
              continue

          yield str(file_name), {
              'image': file_name,
              'label': label,
          }

In [None]:
!tfds build

###### Hand 10

In [None]:
!rm -rf /content/simclr/hand_custom_10

In [None]:
%cd /content/simclr

In [None]:
!tfds new hand_custom_10

In [None]:
%cd /content/simclr/hand_custom_10

In [None]:
%%writefile hand_custom_10_dataset_builder.py

import tensorflow_datasets as tfds
import csv
import tensorflow as tf
from pathlib import Path
import numpy as np
import os

class HandBuilder10(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for hand_custom_10 dataset."""
  MANUAL_DOWNLOAD_INSTRUCTIONS = "/content/simclr"

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Initial release.'}

  # Implement the required abstract methods
  def _info(self):
      # Define the dataset information
      features = tfds.features.FeaturesDict({
          'image': tfds.features.Image(shape=(None, None, 3)),
          'label': tfds.features.ClassLabel(num_classes=2),
      })
      return tfds.core.DatasetInfo(
          builder=self,
          description='My custom hand dataset',
          features=features,
          supervised_keys=('image', 'label'),
      )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
      # Specify the splits
      return {
          'train': self._generate_examples(Path('/content/MURA-v1.1/train_10/XR_HAND')),
          'test': self._generate_examples(Path('/content/MURA-v1.1/test/XR_HAND')),
          'valid': self._generate_examples(Path('/content/MURA-v1.1/valid/XR_HAND'))
      }

  def _generate_examples(self, path):
      for file_name in path.glob('*/*/*.png'):
          if 'negative' in str(file_name):
              label = 0
          elif 'positive' in str(file_name):
              label = 1
          else:
              # Skip unlabeled paths instead of reusing the previous label
              continue

          yield str(file_name), {
              'image': file_name,
              'label': label,
          }

In [None]:
!tfds build

###### Hand 1

In [None]:
!rm -rf /content/simclr/hand_custom_1

In [None]:
%cd /content/simclr

In [None]:
!tfds new hand_custom_1

In [None]:
%cd /content/simclr/hand_custom_1

In [None]:
%%writefile hand_custom_1_dataset_builder.py

import tensorflow_datasets as tfds
import csv
import tensorflow as tf
from pathlib import Path
import numpy as np
import os

class HandBuilder1(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for hand_custom_1 dataset."""
  MANUAL_DOWNLOAD_INSTRUCTIONS = "/content/simclr"

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Initial release.'}

  # Implement the required abstract methods
  def _info(self):
      # Define the dataset information
      features = tfds.features.FeaturesDict({
          'image': tfds.features.Image(shape=(None, None, 3)),
          'label': tfds.features.ClassLabel(num_classes=2),
      })
      return tfds.core.DatasetInfo(
          builder=self,
          description='My custom hand dataset',
          features=features,
          supervised_keys=('image', 'label'),
      )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
      # Specify the splits
      return {
          'train': self._generate_examples(Path('/content/MURA-v1.1/train_1/XR_HAND')),
          'test': self._generate_examples(Path('/content/MURA-v1.1/test/XR_HAND')),
          'valid': self._generate_examples(Path('/content/MURA-v1.1/valid/XR_HAND'))
      }

  def _generate_examples(self, path):
      for file_name in path.glob('*/*/*.png'):
          if 'negative' in str(file_name):
              label = 0
          elif 'positive' in str(file_name):
              label = 1
          else:
              # Skip unlabeled paths instead of reusing the previous label
              continue

          yield str(file_name), {
              'image': file_name,
              'label': label,
          }

In [None]:
!tfds build

#### Shoulder

###### Shoulder 100

In [None]:
!rm -rf /content/simclr/shoulder_custom

In [None]:
%cd /content/simclr

In [None]:
!tfds new shoulder_custom

In [None]:
%cd /content/simclr/shoulder_custom

In [None]:
%%writefile shoulder_custom_dataset_builder.py

import tensorflow_datasets as tfds
import csv
import tensorflow as tf
from pathlib import Path
import numpy as np
import os

class ShoulderBuilder(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for shoulder_custom dataset."""
  MANUAL_DOWNLOAD_INSTRUCTIONS = "/content/simclr"

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Initial release.'}

  # Implement the required abstract methods
  def _info(self):
      # Define the dataset information
      features = tfds.features.FeaturesDict({
          'image': tfds.features.Image(shape=(None, None, 3)),
          'label': tfds.features.ClassLabel(num_classes=2),
      })
      return tfds.core.DatasetInfo(
          builder=self,
          description='My custom shoulder dataset',
          features=features,
          supervised_keys=('image', 'label'),
      )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
      # Specify the splits
      return {
          'train': self._generate_examples(Path('/content/MURA-v1.1/train/XR_SHOULDER')),
          'test': self._generate_examples(Path('/content/MURA-v1.1/test/XR_SHOULDER')),
          'valid': self._generate_examples(Path('/content/MURA-v1.1/valid/XR_SHOULDER'))
      }

  def _generate_examples(self, path):
      for file_name in path.glob('*/*/*.png'):
          if 'negative' in str(file_name):
              label = 0
          elif 'positive' in str(file_name):
              label = 1
          else:
              # Skip unlabeled paths instead of reusing the previous label
              continue

          yield str(file_name), {
              'image': file_name,
              'label': label,
          }

In [None]:
!tfds build

###### Shoulder 10

In [None]:
!rm -rf /content/simclr/shoulder_custom_10

In [None]:
%cd /content/simclr

In [None]:
!tfds new shoulder_custom_10

In [None]:
%cd /content/simclr/shoulder_custom_10

In [None]:
%%writefile shoulder_custom_10_dataset_builder.py

import tensorflow_datasets as tfds
import csv
import tensorflow as tf
from pathlib import Path
import numpy as np
import os

class ShoulderBuilder10(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for shoulder_custom_10 dataset."""
  MANUAL_DOWNLOAD_INSTRUCTIONS = "/content/simclr"

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Initial release.'}

  # Implement the required abstract methods
  def _info(self):
      # Define the dataset information
      features = tfds.features.FeaturesDict({
          'image': tfds.features.Image(shape=(None, None, 3)),
          'label': tfds.features.ClassLabel(num_classes=2),
      })
      return tfds.core.DatasetInfo(
          builder=self,
          description='My custom shoulder dataset',
          features=features,
          supervised_keys=('image', 'label'),
      )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
      # Specify the splits
      return {
          'train': self._generate_examples(Path('/content/MURA-v1.1/train_10/XR_SHOULDER')),
          'test': self._generate_examples(Path('/content/MURA-v1.1/test/XR_SHOULDER')),
          'valid': self._generate_examples(Path('/content/MURA-v1.1/valid/XR_SHOULDER'))
      }

  def _generate_examples(self, path):
      for file_name in path.glob('*/*/*.png'):
          if 'negative' in str(file_name):
              label = 0
          elif 'positive' in str(file_name):
              label = 1
          else:
              # Skip unlabeled paths instead of reusing the previous label
              continue

          yield str(file_name), {
              'image': file_name,
              'label': label,
          }

In [None]:
!tfds build

###### Shoulder 1

In [None]:
!rm -rf /content/simclr/shoulder_custom_1

In [None]:
%cd /content/simclr

In [None]:
!tfds new shoulder_custom_1

In [None]:
%cd /content/simclr/shoulder_custom_1

In [None]:
%%writefile shoulder_custom_1_dataset_builder.py

import tensorflow_datasets as tfds
import csv
import tensorflow as tf
from pathlib import Path
import numpy as np
import os

class ShoulderBuilder1(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for shoulder_custom_1 dataset."""
  MANUAL_DOWNLOAD_INSTRUCTIONS = "/content/simclr"

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {'1.0.0': 'Initial release.'}

  # Implement the required abstract methods
  def _info(self):
      # Define the dataset information
      features = tfds.features.FeaturesDict({
          'image': tfds.features.Image(shape=(None, None, 3)),
          'label': tfds.features.ClassLabel(num_classes=2),
      })
      return tfds.core.DatasetInfo(
          builder=self,
          description='My custom shoulder dataset',
          features=features,
          supervised_keys=('image', 'label'),
      )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
      # Specify the splits
      return {
          'train': self._generate_examples(Path('/content/MURA-v1.1/train_1/XR_SHOULDER')),
          'test': self._generate_examples(Path('/content/MURA-v1.1/test/XR_SHOULDER')),
          'valid': self._generate_examples(Path('/content/MURA-v1.1/valid/XR_SHOULDER'))
      }

  def _generate_examples(self, path):
      for file_name in path.glob('*/*/*.png'):
          if 'negative' in str(file_name):
              label = 0
          elif 'positive' in str(file_name):
              label = 1
          else:
              # Skip unlabeled paths instead of reusing the previous label
              continue

          yield str(file_name), {
              'image': file_name,
              'label': label,
          }

In [None]:
!tfds build

## SimCLR Adjustment
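
The `learning_rate` flag in `run.py` is documented as the rate "per batch size of 256", and `learning_rate_scaling` selects how it is adjusted for the actual batch size. A minimal sketch of the two rules follows, assuming they match the upstream SimCLR repository (linear: multiply by `batch_size / 256`; sqrt: multiply by `sqrt(batch_size)`). The function name `scaled_learning_rate` is hypothetical, and the exact formulas should be checked against the cloned code before relying on the numbers.

```python
import math


def scaled_learning_rate(base_lr, batch_size, scaling='linear'):
    """Scale a base learning rate for a given batch size (assumed upstream rules)."""
    if scaling == 'linear':
        return base_lr * batch_size / 256.0
    if scaling == 'sqrt':
        return base_lr * math.sqrt(batch_size)
    raise ValueError(f"Unknown learning rate scaling: {scaling}")


# Defaults from the flags: learning_rate=0.3, train_batch_size=512.
print(scaled_learning_rate(0.3, 512, 'linear'))
print(scaled_learning_rate(0.3, 512, 'sqrt'))
```

On a single Colab GPU the batch size is usually much smaller than 512, so under the linear rule the effective learning rate shrinks proportionally.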

In [None]:
%cd /content/simclr

In [None]:
%%writefile run.py

# coding=utf-8
# Copyright 2020 The SimCLR Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The main training pipeline."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import math
import os
from absl import app
from absl import flags

import resnet
import data as data_lib
import model as model_lib
import model_util as model_util

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
import tensorflow_datasets as tfds
import tensorflow_hub as hub

from elbow_custom.elbow_custom_dataset_builder import ElbowBuilder
from elbow_custom_10.elbow_custom_10_dataset_builder import ElbowBuilder10
from elbow_custom_1.elbow_custom_1_dataset_builder import ElbowBuilder1

from hand_custom.hand_custom_dataset_builder import HandBuilder
from hand_custom_10.hand_custom_10_dataset_builder import HandBuilder10
from hand_custom_1.hand_custom_1_dataset_builder import HandBuilder1

from shoulder_custom.shoulder_custom_dataset_builder import ShoulderBuilder
from shoulder_custom_10.shoulder_custom_10_dataset_builder import ShoulderBuilder10
from shoulder_custom_1.shoulder_custom_1_dataset_builder import ShoulderBuilder1

FLAGS = flags.FLAGS

name2builder= {"ELBOW" : ElbowBuilder,
               "ELBOW/1" : ElbowBuilder1,
               "ELBOW/10" : ElbowBuilder10,
               "ELBOW/100" : ElbowBuilder,
               "HAND" : HandBuilder,
               "HAND/1" : HandBuilder1,
               "HAND/10" : HandBuilder10,
               "HAND/100" : HandBuilder,
               "SHOULDER" : ShoulderBuilder,
               "SHOULDER/1" : ShoulderBuilder1,
               "SHOULDER/10" : ShoulderBuilder10,
               "SHOULDER/100" : ShoulderBuilder}

flags.DEFINE_float(
    'learning_rate', 0.3,
    'Initial learning rate per batch size of 256.')

flags.DEFINE_enum(
    'learning_rate_scaling', 'linear', ['linear', 'sqrt'],
    'How to scale the learning rate as a function of batch size.')

flags.DEFINE_float(
    'warmup_epochs', 10,
    'Number of epochs of warmup.')

flags.DEFINE_float(
    'weight_decay', 1e-4,
    'Amount of weight decay to use.')

flags.DEFINE_float(
    'batch_norm_decay', 0.9,
    'Batch norm decay parameter.')

flags.DEFINE_integer(
    'train_batch_size', 512,
    'Batch size for training.')

flags.DEFINE_string(
    'train_split', 'train',
    'Split for training.')

flags.DEFINE_integer(
    'train_epochs', 100,
    'Number of epochs to train for.')

flags.DEFINE_integer(
    'train_steps', 0,
    'Number of steps to train for. If provided, overrides train_epochs.')

flags.DEFINE_integer(
    'eval_batch_size', 256,
    'Batch size for eval.')

flags.DEFINE_integer(
    'train_summary_steps', 100,
    'Steps before saving training summaries. If 0, will not save.')

flags.DEFINE_integer(
    'checkpoint_epochs', 1,
    'Number of epochs between checkpoints/summaries.')

flags.DEFINE_integer(
    'checkpoint_steps', 0,
    'Number of steps between checkpoints/summaries. If provided, overrides '
    'checkpoint_epochs.')

flags.DEFINE_string(
    'eval_split', 'validation',
    'Split for evaluation.')

flags.DEFINE_string(
    'dataset', 'imagenet2012',
    'Name of a dataset.')

flags.DEFINE_bool(
    'cache_dataset', False,
    'Whether to cache the entire dataset in memory. If the dataset is '
    'ImageNet, this is a very bad idea, but for smaller datasets it can '
    'improve performance.')

flags.DEFINE_enum(
    'mode', 'train', ['train', 'eval', 'train_then_eval'],
    'Whether to perform training or evaluation.')

flags.DEFINE_enum(
    'train_mode', 'pretrain', ['pretrain', 'finetune'],
    'The train mode controls different objectives and trainable components.')

flags.DEFINE_string(
    'checkpoint', None,
    'Loading from the given checkpoint for continued training or fine-tuning.')

flags.DEFINE_string(
    'variable_schema', '?!global_step',
    'This defines whether some variable from the checkpoint should be loaded.')

flags.DEFINE_bool(
    'zero_init_logits_layer', False,
    'If True, zero initialize layers after avg_pool for supervised learning.')

flags.DEFINE_integer(
    'fine_tune_after_block', -1,
    'The layers after which block that we will fine-tune. -1 means fine-tuning '
    'everything. 0 means fine-tuning after stem block. 4 means fine-tuning '
    'just the linear head.')

flags.DEFINE_string(
    'master', None,
    'Address/name of the TensorFlow master to use. By default, use an '
    'in-process master.')

flags.DEFINE_string(
    'model_dir', None,
    'Model directory for training.')

flags.DEFINE_string(
    'data_dir', None,
    'Directory where dataset is stored.')

flags.DEFINE_bool(
    'use_tpu', True,
    'Whether to run on TPU.')

tf.flags.DEFINE_string(
    'tpu_name', None,
    'The Cloud TPU to use for training. This should be either the name '
    'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
    'url.')

tf.flags.DEFINE_string(
    'tpu_zone', None,
    '[Optional] GCE zone where the Cloud TPU is located in. If not '
    'specified, we will attempt to automatically detect the GCE project from '
    'metadata.')

tf.flags.DEFINE_string(
    'gcp_project', None,
    '[Optional] Project name for the Cloud TPU-enabled project. If not '
    'specified, we will attempt to automatically detect the GCE project from '
    'metadata.')

flags.DEFINE_enum(
    'optimizer', 'lars', ['momentum', 'adam', 'lars'],
    'Optimizer to use.')

flags.DEFINE_float(
    'momentum', 0.9,
    'Momentum parameter.')

flags.DEFINE_string(
    'eval_name', None,
    'Name for eval.')

flags.DEFINE_integer(
    'keep_checkpoint_max', 5,
    'Maximum number of checkpoints to keep.')

flags.DEFINE_integer(
    'keep_hub_module_max', 1,
    'Maximum number of Hub modules to keep.')

flags.DEFINE_float(
    'temperature', 0.1,
    'Temperature parameter for contrastive loss.')

flags.DEFINE_boolean(
    'hidden_norm', True,
    'Whether to L2-normalize the hidden vector before the contrastive loss.')

flags.DEFINE_enum(
    'proj_head_mode', 'nonlinear', ['none', 'linear', 'nonlinear'],
    'How the head projection is done.')

flags.DEFINE_integer(
    'proj_out_dim', 128,
    'Number of head projection dimension.')

flags.DEFINE_integer(
    'num_proj_layers', 3,
    'Number of non-linear head layers.')

flags.DEFINE_integer(
    'ft_proj_selector', 0,
    'Which layer of the projection head to use during fine-tuning. '
    '0 means throwing away the projection head, and -1 means the final layer.')

flags.DEFINE_boolean(
    'global_bn', True,
    'Whether to aggregate BN statistics across distributed cores.')

flags.DEFINE_integer(
    'width_multiplier', 1,
    'Multiplier to change width of network.')

flags.DEFINE_integer(
    'resnet_depth', 50,
    'Depth of ResNet.')

flags.DEFINE_float(
    'sk_ratio', 0.,
    'If it is bigger than 0, it will enable SK. Recommendation: 0.0625.')

flags.DEFINE_float(
    'se_ratio', 0.,
    'If it is bigger than 0, it will enable SE.')

flags.DEFINE_integer(
    'image_size', 224,
    'Input image size.')

flags.DEFINE_float(
    'color_jitter_strength', 1.0,
    'The strength of color jittering.')

flags.DEFINE_boolean(
    'use_blur', True,
    'Whether or not to use Gaussian blur for augmentation during pretraining.')


def build_hub_module(model, num_classes, global_step, checkpoint_path):
  """Create TF-Hub module."""

  tags_and_args = [
      # The default graph is built with batch_norm, dropout etc. in inference
      # mode. This graph version is good for inference, not training.
      ([], {'is_training': False}),
      # A separate "train" graph builds batch_norm, dropout etc. in training
      # mode.
      (['train'], {'is_training': True}),
  ]

  def module_fn(is_training):
    """Function that builds TF-Hub module."""
    endpoints = {}
    inputs = tf.placeholder(
        tf.float32, [None, None, None, 3])
    with tf.variable_scope('base_model', reuse=tf.AUTO_REUSE):
      hiddens = model(inputs, is_training)
      for v in ['initial_conv', 'initial_max_pool', 'block_group1',
                'block_group2', 'block_group3', 'block_group4',
                'final_avg_pool']:
        endpoints[v] = tf.get_default_graph().get_tensor_by_name(
            'base_model/{}:0'.format(v))
    if FLAGS.train_mode == 'pretrain':
      hiddens_proj = model_util.projection_head(hiddens, is_training)
      endpoints['proj_head_input'] = hiddens
      endpoints['proj_head_output'] = hiddens_proj
    else:
      logits_sup = model_util.supervised_head(
          hiddens, num_classes, is_training)
      endpoints['logits_sup'] = logits_sup
    hub.add_signature(inputs=dict(images=inputs),
                      outputs=dict(endpoints, default=hiddens))

  # Drop the non-supported non-standard graph collection.
  drop_collections = ['trainable_variables_inblock_%d'%d for d in range(6)]
  spec = hub.create_module_spec(module_fn, tags_and_args, drop_collections)
  hub_export_dir = os.path.join(FLAGS.model_dir, 'hub')
  checkpoint_export_dir = os.path.join(hub_export_dir, str(global_step))
  if tf.io.gfile.exists(checkpoint_export_dir):
    # Remove any existing export for this step so it can be written afresh.
    tf.io.gfile.rmtree(checkpoint_export_dir)
  spec.export(
      checkpoint_export_dir,
      checkpoint_path=checkpoint_path,
      name_transform_fn=None)

  if FLAGS.keep_hub_module_max > 0:
    # Delete old exported Hub modules.
    exported_steps = []
    for subdir in tf.io.gfile.listdir(hub_export_dir):
      if not subdir.isdigit():
        continue
      exported_steps.append(int(subdir))
    exported_steps.sort()
    for step_to_delete in exported_steps[:-FLAGS.keep_hub_module_max]:
      tf.io.gfile.rmtree(os.path.join(hub_export_dir, str(step_to_delete)))


def perform_evaluation(estimator, input_fn, eval_steps, model, num_classes,
                       checkpoint_path=None):
  """Perform evaluation.

  Args:
    estimator: TPUEstimator instance.
    input_fn: Input function for estimator.
    eval_steps: Number of steps for evaluation.
    model: Instance of transfer_learning.models.Model.
    num_classes: Number of classes to build model for.
    checkpoint_path: Path of checkpoint to evaluate.

  Returns:
    result: A Dict of metrics and their values.
  """
  if not checkpoint_path:
    checkpoint_path = estimator.latest_checkpoint()
  result = estimator.evaluate(
      input_fn, eval_steps, checkpoint_path=checkpoint_path,
      name=FLAGS.eval_name)

  # Record results as JSON.
  result_json_path = os.path.join(FLAGS.model_dir, 'result.json')
  with tf.io.gfile.GFile(result_json_path, 'w') as f:
    json.dump({k: float(v) for k, v in result.items()}, f)
  result_json_path = os.path.join(
      FLAGS.model_dir, 'result_%d.json'%result['global_step'])
  with tf.io.gfile.GFile(result_json_path, 'w') as f:
    json.dump({k: float(v) for k, v in result.items()}, f)
  flag_json_path = os.path.join(FLAGS.model_dir, 'flags.json')

  def json_serializable(val):
    try:
      json.dumps(val)
      return True
    except TypeError:
      return False

  with tf.io.gfile.GFile(flag_json_path, 'w') as f:
    serializable_flags = {}
    for key, val in FLAGS.flag_values_dict().items():
      # Some flag value types e.g. datetime.timedelta are not json serializable,
      # filter those out.
      if json_serializable(val):
        serializable_flags[key] = val
    json.dump(serializable_flags, f)

  # Save Hub module.
  build_hub_module(model, num_classes,
                   global_step=result['global_step'],
                   checkpoint_path=checkpoint_path)

  return result


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Enable training summary.
  if FLAGS.train_summary_steps > 0:
    tf.config.set_soft_device_placement(True)

  # builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
  # builder.download_and_prepare()

  part = FLAGS.dataset
  builder = name2builder.get(part)

  if builder is None:
    raise ValueError('Unknown dataset: {}'.format(part))

  builder = builder()
  num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
  num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
  num_classes = builder.info.features['label'].num_classes

  train_steps = model_util.get_train_steps(num_train_examples)
  eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size))
  epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))

  resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay
  model = resnet.resnet_v1(
      resnet_depth=FLAGS.resnet_depth,
      width_multiplier=FLAGS.width_multiplier,
      cifar_stem=FLAGS.image_size <= 32)

  checkpoint_steps = (
      FLAGS.checkpoint_steps or (FLAGS.checkpoint_epochs * epoch_steps))

  cluster = None
  if FLAGS.use_tpu and FLAGS.master is None:
    if FLAGS.tpu_name:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
          FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
      cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
      tf.config.experimental_connect_to_cluster(cluster)
      tf.tpu.experimental.initialize_tpu_system(cluster)

  default_eval_mode = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V1
  sliced_eval_mode = tf_estimator.tpu.InputPipelineConfig.SLICED
  run_config = tf_estimator.tpu.RunConfig(
      tpu_config=tf_estimator.tpu.TPUConfig(
          iterations_per_loop=checkpoint_steps,
          eval_training_input_configuration=sliced_eval_mode
          if FLAGS.use_tpu else default_eval_mode),
      model_dir=FLAGS.model_dir,
      save_summary_steps=checkpoint_steps,
      save_checkpoints_steps=checkpoint_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      master=FLAGS.master,
      cluster=cluster)
  estimator = tf_estimator.tpu.TPUEstimator(
      model_lib.build_model_fn(model, num_classes, num_train_examples),
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      use_tpu=FLAGS.use_tpu)

  if FLAGS.mode == 'eval':
    for ckpt in tf.train.checkpoints_iterator(
        run_config.model_dir, min_interval_secs=15):
      try:
        result = perform_evaluation(
            estimator=estimator,
            input_fn=data_lib.build_input_fn(builder, False),
            eval_steps=eval_steps,
            model=model,
            num_classes=num_classes,
            checkpoint_path=ckpt)
      except tf.errors.NotFoundError:
        continue
      if result['global_step'] >= train_steps:
        return
  else:
    estimator.train(
        data_lib.build_input_fn(builder, True), max_steps=train_steps)
    if FLAGS.mode == 'train_then_eval':
      perform_evaluation(
          estimator=estimator,
          input_fn=data_lib.build_input_fn(builder, False),
          eval_steps=eval_steps,
          model=model,
          num_classes=num_classes)


if __name__ == '__main__':
  tf.disable_v2_behavior()  # Disable eager mode when running with TF2.
  app.run(main)


In [None]:
%%writefile model.py

# coding=utf-8
# Copyright 2020 The SimCLR Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model specification for SimCLR."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags

import data_util as data_util
import model_util as model_util
import objective as obj_lib

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
import tensorflow.compat.v2 as tf2

FLAGS = flags.FLAGS


def build_model_fn(model, num_classes, num_train_examples):
  """Build model function."""
  def model_fn(features, labels, mode, params=None):
    """Build model and optimizer."""
    is_training = mode == tf_estimator.ModeKeys.TRAIN

    # Check training mode.
    if FLAGS.train_mode == 'pretrain':
      num_transforms = 2
      if FLAGS.fine_tune_after_block > -1:
        raise ValueError('Does not support layer freezing during pretraining, '
                         'should set fine_tune_after_block<=-1 for safety.')
    elif FLAGS.train_mode == 'finetune':
      num_transforms = 1
    else:
      raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))

    # Split channels, and optionally apply extra batched augmentation.
    features_list = tf.split(
        features, num_or_size_splits=num_transforms, axis=-1)
    if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
      features_list = data_util.batch_random_blur(
          features_list, FLAGS.image_size, FLAGS.image_size)
    features = tf.concat(features_list, 0)  # (num_transforms * bsz, h, w, c)

    # Base network forward pass.
    with tf.variable_scope('base_model'):
      if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4:
        # Finetune just supervised (linear) head will not update BN stats.
        model_train_mode = False
      else:
        # Pretrain or finetune anything else will update BN stats.
        model_train_mode = is_training
      hiddens = model(features, is_training=model_train_mode)

    # Add head and loss.
    if FLAGS.train_mode == 'pretrain':
      tpu_context = params['context'] if 'context' in params else None
      hiddens_proj = model_util.projection_head(hiddens, is_training)
      contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
          hiddens_proj,
          hidden_norm=FLAGS.hidden_norm,
          temperature=FLAGS.temperature,
          tpu_context=tpu_context if is_training else None)
      logits_sup = tf.zeros([params['batch_size'], num_classes])
    else:
      contrast_loss = tf.zeros([])
      logits_con = tf.zeros([params['batch_size'], 10])
      labels_con = tf.zeros([params['batch_size'], 10])
      hiddens = model_util.projection_head(hiddens, is_training)
      logits_sup = model_util.supervised_head(
          hiddens, num_classes, is_training)
      obj_lib.add_supervised_loss(
          labels=labels['labels'],
          logits=logits_sup,
          weights=labels['mask'])

    # Add weight decay to loss, for non-LARS optimizers.
    model_util.add_weight_decay(adjust_per_optimizer=True)
    loss = tf.losses.get_total_loss()

    if FLAGS.train_mode == 'pretrain':
      variables_to_train = tf.trainable_variables()
    else:
      collection_prefix = 'trainable_variables_inblock_'
      variables_to_train = []
      for j in range(FLAGS.fine_tune_after_block + 1, 6):
        variables_to_train += tf.get_collection(collection_prefix + str(j))
      assert variables_to_train, 'variables_to_train shouldn\'t be empty!'

    tf.logging.info('===============Variables to train (begin)===============')
    tf.logging.info(variables_to_train)
    tf.logging.info('================Variables to train (end)================')

    learning_rate = model_util.learning_rate_schedule(
        FLAGS.learning_rate, num_train_examples)

    if is_training:
      if FLAGS.train_summary_steps > 0:
        # Compute stats for the summary.
        prob_con = tf.nn.softmax(logits_con)
        entropy_con = - tf.reduce_mean(
            tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1))

        summary_writer = tf2.summary.create_file_writer(FLAGS.model_dir)
        # TODO(iamtingchen): remove this control_dependencies in the future.
        with tf.control_dependencies([summary_writer.init()]):
          with summary_writer.as_default():
            should_record = tf.math.equal(
                tf.math.floormod(tf.train.get_global_step(),
                                 FLAGS.train_summary_steps), 0)
            with tf2.summary.record_if(should_record):
              contrast_acc = tf.equal(
                  tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1))
              contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))
              label_acc = tf.equal(
                  tf.argmax(labels['labels'], 1), tf.argmax(logits_sup, axis=1))
              label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))

              # Write to a file
              #with open('/content/simclr/simclr_pretrain/debug_file.txt', 'w') as f:
               # f.write(f"step: {tf.train.get_global_step() // 30}, current loss: {loss}, contrast acc: {contrast_acc}, label acc: {label_acc}\n")

              tf2.summary.scalar(
                  'train_contrast_loss',
                  contrast_loss,
                  step=tf.train.get_global_step())
              tf2.summary.scalar(
                  'train_contrast_acc',
                  contrast_acc,
                  step=tf.train.get_global_step())
              tf2.summary.scalar(
                  'train_label_accuracy',
                  label_acc,
                  step=tf.train.get_global_step())
              tf2.summary.scalar(
                  'contrast_entropy',
                  entropy_con,
                  step=tf.train.get_global_step())
              tf2.summary.scalar(
                  'learning_rate', learning_rate,
                  step=tf.train.get_global_step())

      optimizer = model_util.get_optimizer(learning_rate)
      control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      if FLAGS.train_summary_steps > 0:
        control_deps.extend(tf.summary.all_v2_summary_ops())
      with tf.control_dependencies(control_deps):
        train_op = optimizer.minimize(
            loss, global_step=tf.train.get_or_create_global_step(),
            var_list=variables_to_train)

      if FLAGS.checkpoint:
        def scaffold_fn():
          """Scaffold function to restore non-logits vars from checkpoint."""
          tf.train.init_from_checkpoint(
              FLAGS.checkpoint,
              {v.op.name: v.op.name
               for v in tf.global_variables(FLAGS.variable_schema)})

          if FLAGS.zero_init_logits_layer:
            # Init op that initializes output layer parameters to zeros.
            output_layer_parameters = [
                var for var in tf.trainable_variables() if var.name.startswith(
                    'head_supervised')]
            tf.logging.info('Initializing output layer parameters %s to zero',
                            [x.op.name for x in output_layer_parameters])
            with tf.control_dependencies([tf.global_variables_initializer()]):
              init_op = tf.group([
                  tf.assign(x, tf.zeros_like(x))
                  for x in output_layer_parameters])
            return tf.train.Scaffold(init_op=init_op)
          else:
            return tf.train.Scaffold()
      else:
        scaffold_fn = None

      return tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode, train_op=train_op, loss=loss, scaffold_fn=scaffold_fn)
    else:

      def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask,
                    **kws):
        """Inner metric function."""
        metrics = {k: tf.metrics.mean(v, weights=mask)
                   for k, v in kws.items()}
        metrics['label_accuracy'] = tf.metrics.accuracy(
            tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1),
            weights=mask)

        metrics['contrastive_accuracy'] = tf.metrics.accuracy(
            tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1),
            weights=mask)

        metrics['recall'] = tf.metrics.recall(
            tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1),
            weights=mask)

        metrics['precision'] = tf.metrics.precision(
            tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1),
            weights=mask)
        return metrics

      metrics = {
          'logits_sup': logits_sup,
          'labels_sup': labels['labels'],
          'logits_con': logits_con,
          'labels_con': labels_con,
          'mask': labels['mask'],
          'contrast_loss': tf.fill((params['batch_size'],), contrast_loss),
          'regularization_loss': tf.fill((params['batch_size'],),
                                         tf.losses.get_regularization_loss()),
      }

      return tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(metric_fn, metrics),
          scaffold_fn=None)

  return model_fn


## SimCLR Experiments
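
The pretraining runs in this section pass `--temperature=0.1` and rely on the default `hidden_norm=True`. To illustrate what these two flags control, here is a minimal pure-Python sketch of the NT-Xent loss in the form given in the SimCLR paper. It is not the `objective.py` implementation that `model.py` calls; `l2_normalize` and `nt_xent_loss` are hypothetical helper names used only for this illustration.

```python
import math


def l2_normalize(vector):
    """Scale a vector to unit length (what hidden_norm=True is assumed to do)."""
    norm = math.sqrt(sum(x * x for x in vector)) or 1.0
    return [x / norm for x in vector]


def nt_xent_loss(z1, z2, temperature=0.1, hidden_norm=True):
    """NT-Xent loss for two lists of projections.

    z1[i] and z2[i] are the projections of two augmented views of image i.
    """
    if hidden_norm:
        z1 = [l2_normalize(v) for v in z1]
        z2 = [l2_normalize(v) for v in z2]
    z = z1 + z2
    n = len(z1)
    total = 0.0
    for i in range(2 * n):
        positive = (i + n) % (2 * n)  # index of the other view of the same image
        # Similarities to every other sample, scaled by the temperature.
        logits = {
            j: sum(a * b for a, b in zip(z[i], z[j])) / temperature
            for j in range(2 * n) if j != i
        }
        # Numerically stable log-sum-exp over all candidates.
        m = max(logits.values())
        log_denominator = m + math.log(
            sum(math.exp(v - m) for v in logits.values()))
        total += log_denominator - logits[positive]
    return total / (2 * n)
```

A lower temperature sharpens the softmax over candidates, so hard negatives contribute more to the loss.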

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np

def plot_statistics(label):
  file_paths = [
      f"/content/simclr/simclr_finetune/{label}/log_1_lr_0.1/result.json",
      f"/content/simclr/simclr_finetune/{label}/log_10_lr_0.1/result.json",
      f"/content/simclr/simclr_finetune/{label}/log_100_lr_0.1/result.json"
  ]

  # Percentage of images for training
  pct_size = [1, 10, 100]

  # Create separate lists for precision, recall, and label accuracy
  precisions = []
  recalls = []
  label_accuracies = []

  # Read each JSON file and extract the desired values
  for file_path in file_paths:
      with open(file_path, "r") as file:
          data = json.load(file)
          precision = round(data["precision"], 4)
          recall = round(data["recall"], 4)
          label_accuracy = round(data["label_accuracy"], 4)
          precisions.append(precision)
          recalls.append(recall)
          label_accuracies.append(label_accuracy)

  ############### Label Accuracy #####################

  # Plot label accuracy
  fig = plt.figure(figsize=(5, 3))
  plt.plot(pct_size, label_accuracies, '--', color="#000", marker="*", markeredgecolor="#000", markerfacecolor="y", markersize=16)
  plt.xscale("log")
  plt.xticks(pct_size, labels=pct_size)
  plt.title(f"MURA-v1.1 {label} Label Classification over dataset size", fontsize=11)
  plt.xlabel("Percentage of data")
  plt.ylabel("Label Accuracy")
  plt.minorticks_off()
  plt.show()

  for k, score in zip(pct_size, label_accuracies):
      print(f'Label accuracy for {k}% of data: {100 * score:4.2f}%')
  print('\n')


  ############### Precision #####################

  # Plot Precision
  fig = plt.figure(figsize=(5, 3))
  plt.plot(pct_size, precisions, '--', color="#000", marker="*", markeredgecolor="#000", markerfacecolor="y", markersize=16)
  plt.xscale("log")
  plt.xticks(pct_size, labels=pct_size)
  plt.title(f"MURA-v1.1 {label} Precision over dataset size", fontsize=11)
  plt.xlabel("Percentage of data")
  plt.ylabel("Precision")
  plt.minorticks_off()
  plt.show()

  for k, score in zip(pct_size, precisions):
      print(f'Precision for {k:3}% of data: {100*score:4.2f}%')
  print('\n')


  ############### Recall #####################

  # Plot Recall
  fig = plt.figure(figsize=(5, 3))
  plt.plot(pct_size, recalls, '--', color="#000", marker="*", markeredgecolor="#000", markerfacecolor="y", markersize=16)
  plt.xscale("log")
  plt.xticks(pct_size, labels=pct_size)
  plt.title(f"MURA-v1.1 {label} Recall over dataset size", fontsize=11)
  plt.xlabel("Percentage of data")
  plt.ylabel("Recall")
  plt.minorticks_off()
  plt.show()

  for k, score in zip(pct_size, recalls):
      print(f'Recall for {k:3}% of data: {100*score:4.2f}%')

#### Elbow

###### Pretrain

In [None]:
!rm -rf '/content/simclr/simclr_pretrain/elbow'

In [None]:
!python run.py --train_mode=pretrain \
  --train_batch_size=8 --train_epochs=50 --temperature=0.1 \
  --learning_rate=0.075 --learning_rate_scaling=sqrt --weight_decay=1e-4 \
  --dataset=ELBOW --eval_split=valid --resnet_depth=50 \
  --model_dir=/content/simclr/simclr_pretrain/elbow \
  --use_tpu=False --train_summary_steps=100 --eval_batch_size=16
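
The pretrain command combines `--learning_rate=0.075` with `--learning_rate_scaling=sqrt` and a batch size of 8. The sketch below shows how the base rate might translate into the peak rate the optimizer sees. It assumes that `model_util.learning_rate_schedule` follows the upstream SimCLR scaling rules (`base * batch_size / 256` for `linear`, `base * sqrt(batch_size)` for `sqrt`) and it ignores warmup and decay. `scaled_learning_rate` and `steps_per_epoch` are hypothetical names; the second one mirrors the `epoch_steps` expression in `run.py`.

```python
import math


def scaled_learning_rate(base_lr, batch_size, scaling="sqrt"):
    """Peak learning rate under the assumed upstream SimCLR scaling rules."""
    if scaling == "linear":
        return base_lr * batch_size / 256.0
    if scaling == "sqrt":
        return base_lr * math.sqrt(batch_size)
    raise ValueError("Unknown learning rate scaling: {}".format(scaling))


def steps_per_epoch(num_train_examples, batch_size):
    """Same rounding as `epoch_steps` in run.py."""
    return int(round(num_train_examples / batch_size))


# Pretraining settings used in this notebook.
pretrain_lr = scaled_learning_rate(0.075, 8, scaling="sqrt")
```

With such a small batch size the two rules give very different rates, so the scaling flag is worth reporting next to the base learning rate.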

In [None]:
# Start tensorboard.
%reload_ext tensorboard
%tensorboard --logdir /content/simclr/simclr_pretrain/elbow

###### Finetune on 1% of the labeled data

###### 1.1 lr=0.3

In [None]:
!rm -rf '/content/simclr/simclr_finetune/elbow/log_1_lr_0.3'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.3 \
  --weight_decay=0.0 \
  --train_epochs=20 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=ELBOW/1 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/elbow \
  --model_dir=/content/simclr/simclr_finetune/elbow/log_1_lr_0.3 \
  --use_tpu=False \
  --eval_batch_size=16
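
The `--fine_tune_after_block=4` flag decides which variables are updated during fine-tuning. The following sketch restates the selection logic from `model_fn` in `model.py` as two small pure functions so the effect of a given value is easy to check. `trainable_collections` and `updates_batch_norm` are hypothetical names, and the second function only covers `train_mode=finetune`.

```python
COLLECTION_PREFIX = "trainable_variables_inblock_"


def trainable_collections(fine_tune_after_block):
    """Names of the variable collections that model_fn would train."""
    return [COLLECTION_PREFIX + str(j)
            for j in range(fine_tune_after_block + 1, 6)]


def updates_batch_norm(fine_tune_after_block, is_training=True):
    """Whether the base network runs in training mode during fine-tuning.

    model_fn switches the base model to inference mode when
    fine_tune_after_block >= 4, so BN statistics stay frozen.
    """
    return is_training and fine_tune_after_block < 4
```

With the value 4 used in these runs, only `trainable_variables_inblock_5` is trained, which the comment in `model.py` describes as training just the supervised head, and the batch-norm statistics of the pretrained encoder are left untouched.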

###### 1.2 lr=0.1

In [None]:
!rm -rf '/content/simclr/simclr_finetune/elbow/log_1_lr_0.1'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.1 \
  --weight_decay=0.0 \
  --train_epochs=20 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=ELBOW/1 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/elbow \
  --model_dir=/content/simclr/simclr_finetune/elbow/log_1_lr_0.1 \
  --use_tpu=False \
  --eval_batch_size=16
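
The `--variable_schema` regular expression controls which variables `scaffold_fn` restores from the pretraining checkpoint. The sketch below applies the same pattern to a few variable names, assuming TensorFlow filters the collection with `re.match` against each name. The names in `example_names` are illustrative, not taken from an actual checkpoint, and `restored_from_checkpoint` is a hypothetical helper.

```python
import re

VARIABLE_SCHEMA = r"(?!global_step|(?:.*/|^)Momentum|head)"


def restored_from_checkpoint(variable_names, schema=VARIABLE_SCHEMA):
    """Return the names the schema would select for restoring."""
    pattern = re.compile(schema)
    return [name for name in variable_names if pattern.match(name)]


# Hypothetical variable names, for illustration only.
example_names = [
    "base_model/conv2d/kernel",                   # encoder weight
    "base_model/conv2d/kernel/Momentum",          # optimizer slot
    "head_supervised/linear_layer/dense/kernel",  # classification head
    "global_step",
]
selected = restored_from_checkpoint(example_names)
```

Under these assumptions only the encoder weight is selected: the step counter, the optimizer slots and everything whose name starts with `head` are initialized fresh for fine-tuning.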

###### Finetune on 10% of the labeled data

###### 2.1 lr=0.3

In [None]:
!rm -rf '/content/simclr/simclr_finetune/elbow/log_10_lr_0.3'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.3 \
  --weight_decay=0.0 \
  --train_epochs=20 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=ELBOW/10 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/elbow \
  --model_dir=/content/simclr/simclr_finetune/elbow/log_10_lr_0.3 \
  --use_tpu=False \
  --eval_batch_size=16

###### 2.2 lr=0.1

In [None]:
!rm -rf '/content/simclr/simclr_finetune/elbow/log_10_lr_0.1'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.1 \
  --weight_decay=0.0 \
  --train_epochs=20 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=ELBOW/10 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/elbow \
  --model_dir=/content/simclr/simclr_finetune/elbow/log_10_lr_0.1 \
  --use_tpu=False \
  --eval_batch_size=16

###### Finetune on 100% of the labeled data

###### 3.1 lr=0.3

In [None]:
!rm -rf '/content/simclr/simclr_finetune/elbow/log_100_lr_0.3'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.3 \
  --weight_decay=0.0 \
  --train_epochs=30 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=ELBOW/100 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/elbow \
  --model_dir=/content/simclr/simclr_finetune/elbow/log_100_lr_0.3 \
  --use_tpu=False \
  --eval_batch_size=16

###### 3.2 lr=0.1

In [None]:
!rm -rf '/content/simclr/simclr_finetune/elbow/log_100_lr_0.1'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.1 \
  --weight_decay=0.0 \
  --train_epochs=30 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=ELBOW/100 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/elbow \
  --model_dir=/content/simclr/simclr_finetune/elbow/log_100_lr_0.1 \
  --use_tpu=False \
  --eval_batch_size=16

###### Graphs and Evals

In [None]:
plot_statistics("elbow")
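
`plot_statistics` reads only the `lr_0.1` runs, although each data fraction was also fine-tuned with `lr=0.3`. A small sketch for comparing the two learning rates from the saved `result.json` files could look like this. `best_learning_rate` is a hypothetical helper; it assumes the `log_<pct>_lr_<lr>` directory layout used by the commands above and skips runs whose result file is missing.

```python
import json
import os


def best_learning_rate(finetune_dir, label, pct,
                       learning_rates=(0.1, 0.3), metric="label_accuracy"):
    """Return (best_lr, scores) for one body part and data percentage."""
    scores = {}
    for lr in learning_rates:
        path = os.path.join(
            finetune_dir, label, f"log_{pct}_lr_{lr}", "result.json")
        if not os.path.exists(path):
            continue  # this run was not executed or has not finished
        with open(path, "r") as f:
            scores[lr] = json.load(f)[metric]
    if not scores:
        return None, scores
    best_lr = max(scores, key=scores.get)
    return best_lr, scores
```

For example, `best_learning_rate("/content/simclr/simclr_finetune", "elbow", 10)` would compare the two 10% elbow runs. Since the runs are evaluated with `--eval_split=test`, choosing a learning rate this way uses the test split for model selection, which is worth keeping in mind when reporting the numbers.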

#### Hand

###### Pretrain

In [None]:
!rm -rf '/content/simclr/simclr_pretrain/hand'

In [None]:
!python run.py --train_mode=pretrain \
  --train_batch_size=8 --train_epochs=50 --temperature=0.1 \
  --learning_rate=0.075 --learning_rate_scaling=sqrt --weight_decay=1e-4 \
  --dataset=HAND --eval_split=valid --resnet_depth=50 \
  --model_dir=/content/simclr/simclr_pretrain/hand \
  --use_tpu=False --train_summary_steps=100 --eval_batch_size=16

In [None]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir /content/simclr/simclr_pretrain/hand --port=8010

###### Finetune on 1% of the labeled data

###### 1.1 lr=0.3

In [None]:
!rm -rf '/content/simclr/simclr_finetune/hand/log_1_lr_0.3'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.3 \
  --weight_decay=0.0 \
  --train_epochs=20 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=HAND/1 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/hand \
  --model_dir=/content/simclr/simclr_finetune/hand/log_1_lr_0.3 \
  --use_tpu=False \
  --eval_batch_size=16

###### 1.2 lr=0.1

In [None]:
!rm -rf '/content/simclr/simclr_finetune/hand/log_1_lr_0.1'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.1 \
  --weight_decay=0.0 \
  --train_epochs=20 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=HAND/1 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/hand \
  --model_dir=/content/simclr/simclr_finetune/hand/log_1_lr_0.1 \
  --use_tpu=False \
  --eval_batch_size=16

###### Finetune on 10% of the labeled data

###### 2.1 lr=0.3

In [None]:
!rm -rf '/content/simclr/simclr_finetune/hand/log_10_lr_0.3'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.3 \
  --weight_decay=0.0 \
  --train_epochs=20 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=HAND/10 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/hand \
  --model_dir=/content/simclr/simclr_finetune/hand/log_10_lr_0.3 \
  --use_tpu=False \
  --eval_batch_size=16

###### 2.2 lr=0.1

In [None]:
!rm -rf '/content/simclr/simclr_finetune/hand/log_10_lr_0.1'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.1 \
  --weight_decay=0.0 \
  --train_epochs=20 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=HAND/10 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/hand \
  --model_dir=/content/simclr/simclr_finetune/hand/log_10_lr_0.1 \
  --use_tpu=False \
  --eval_batch_size=16

###### Finetune on 100% of the labeled data

###### 3.1 lr=0.3

In [None]:
!rm -rf '/content/simclr/simclr_finetune/hand/log_100_lr_0.3'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.3 \
  --weight_decay=0.0 \
  --train_epochs=20 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=HAND/100 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/hand \
  --model_dir=/content/simclr/simclr_finetune/hand/log_100_lr_0.3 \
  --use_tpu=False \
  --eval_batch_size=16

###### 3.2 lr=0.1

In [None]:
!rm -rf '/content/simclr/simclr_finetune/hand/log_100_lr_0.1'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.1 \
  --weight_decay=0.0 \
  --train_epochs=10 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=HAND/100 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/hand \
  --model_dir=/content/simclr/simclr_finetune/hand/log_100_lr_0.1 \
  --use_tpu=False \
  --eval_batch_size=16

###### Graphs and Evals

In [None]:
plot_statistics("hand")
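
`plot_statistics` reports precision and recall separately. When a single number is more convenient, the two can be combined into an F1 score. This is a small sketch, not part of the SimCLR code: `f1_score` is a hypothetical helper that takes the `precision` and `recall` values stored in `result.json`, which `metric_fn` computes with class index 1 as the positive class.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)
```

Because F1 is a harmonic mean, it stays low whenever either precision or recall is low, which makes it a stricter summary than accuracy when the normal and abnormal classes are imbalanced.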

#### Shoulder

###### Pretrain

In [None]:
!rm -rf '/content/simclr/simclr_pretrain/shoulder'

In [None]:
!python run.py --train_mode=pretrain \
  --train_batch_size=8 --train_epochs=50 --temperature=0.1 \
  --learning_rate=0.075 --learning_rate_scaling=sqrt --weight_decay=1e-4 \
  --dataset=SHOULDER --eval_split=valid --resnet_depth=50 \
  --model_dir=/content/simclr/simclr_pretrain/shoulder \
  --use_tpu=False --train_summary_steps=100 --eval_batch_size=16

In [None]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir /content/simclr/simclr_pretrain/shoulder --port=8017

###### Finetune on 1% of the labeled data

###### 1.1 lr=0.3

In [None]:
!rm -rf '/content/simclr/simclr_finetune/shoulder/log_1_lr_0.3'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.3 \
  --weight_decay=0.0 \
  --train_epochs=20 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=SHOULDER/1 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/shoulder \
  --model_dir=/content/simclr/simclr_finetune/shoulder/log_1_lr_0.3 \
  --use_tpu=False \
  --eval_batch_size=16

###### 1.2 lr=0.1

In [None]:
!rm -rf '/content/simclr/simclr_finetune/shoulder/log_1_lr_0.1'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.1 \
  --weight_decay=0.0 \
  --train_epochs=20 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=SHOULDER/1 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/shoulder \
  --model_dir=/content/simclr/simclr_finetune/shoulder/log_1_lr_0.1 \
  --use_tpu=False \
  --eval_batch_size=16

###### Finetune on 10 percent of the labels

###### 2.1 lr=0.3

In [None]:
!rm -rf '/content/simclr/simclr_finetune/shoulder/log_10_lr_0.3'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.3 \
  --weight_decay=0.0 \
  --train_epochs=20 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=SHOULDER/10 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/shoulder \
  --model_dir=/content/simclr/simclr_finetune/shoulder/log_10_lr_0.3 \
  --use_tpu=False \
  --eval_batch_size=16

###### 2.2 lr=0.1

In [None]:
!rm -rf '/content/simclr/simclr_finetune/shoulder/log_10_lr_0.1'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.1 \
  --weight_decay=0.0 \
  --train_epochs=20 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=SHOULDER/10 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/shoulder \
  --model_dir=/content/simclr/simclr_finetune/shoulder/log_10_lr_0.1 \
  --use_tpu=False \
  --eval_batch_size=16

###### Finetune on 100 percent of the labels

###### 3.1 lr=0.3

In [None]:
!rm -rf '/content/simclr/simclr_finetune/shoulder/log_100_lr_0.3'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.3 \
  --weight_decay=0.0 \
  --train_epochs=30 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=SHOULDER/100 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/shoulder \
  --model_dir=/content/simclr/simclr_finetune/shoulder/log_100_lr_0.3 \
  --use_tpu=False \
  --eval_batch_size=16

###### 3.2 lr=0.1

In [None]:
!rm -rf '/content/simclr/simclr_finetune/shoulder/log_100_lr_0.1'

In [None]:
!python run.py \
  --mode=train_then_eval \
  --train_mode=finetune \
  --fine_tune_after_block=4 \
  --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False \
  --optimizer=momentum \
  --learning_rate=0.1 \
  --weight_decay=0.0 \
  --train_epochs=30 \
  --train_batch_size=8 \
  --warmup_epochs=0 \
  --dataset=SHOULDER/100 \
  --eval_split=test \
  --resnet_depth=50 \
  --checkpoint=/content/simclr/simclr_pretrain/shoulder \
  --model_dir=/content/simclr/simclr_finetune/shoulder/log_100_lr_0.1 \
  --use_tpu=False \
  --eval_batch_size=16
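The six shoulder fine-tuning commands above differ only in the label fraction, the learning rate, and the epoch count (30 for the 100% split, 20 otherwise). As a rough sketch, the same argument lists could be generated in a loop. The helper `finetune_args` and the two path constants are hypothetical names introduced for illustration; the flag values are copied from the cells above.

```python
# Paths copied from the shoulder cells above.
PRETRAIN_DIR = "/content/simclr/simclr_pretrain/shoulder"
FINETUNE_DIR = "/content/simclr/simclr_finetune/shoulder"


def finetune_args(percent, lr):
    """Build the run.py argument list for one label fraction and learning rate."""
    epochs = 30 if percent == 100 else 20  # the 100% cells train for 30 epochs
    flags = {
        "mode": "train_then_eval",
        "train_mode": "finetune",
        "fine_tune_after_block": 4,
        "zero_init_logits_layer": True,
        "variable_schema": "(?!global_step|(?:.*/|^)Momentum|head)",
        "global_bn": False,
        "optimizer": "momentum",
        "learning_rate": lr,
        "weight_decay": 0.0,
        "train_epochs": epochs,
        "train_batch_size": 8,
        "warmup_epochs": 0,
        "dataset": f"SHOULDER/{percent}",
        "eval_split": "test",
        "resnet_depth": 50,
        "checkpoint": PRETRAIN_DIR,
        "model_dir": f"{FINETUNE_DIR}/log_{percent}_lr_{lr}",
        "use_tpu": False,
        "eval_batch_size": 16,
    }
    return ["python", "run.py"] + [f"--{name}={value}" for name, value in flags.items()]


# One entry per (label fraction, learning rate) pair, in the order used above.
runs = [finetune_args(p, lr) for p in (1, 10, 100) for lr in (0.3, 0.1)]
for args in runs:
    print(" ".join(args))
    # To launch a run: import subprocess; subprocess.run(args, check=True)
```

Passing the list to `subprocess.run` avoids shell quoting of the `--variable_schema` regular expression. The `rm -rf` cleanup of each `model_dir` is left out of the sketch and would still need to run before each launch.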

###### Graphs and Evals

In [None]:
plot_statistics("shoulder")