<a href="https://colab.research.google.com/github/Siasmaan/Brain_tumor_DL/blob/colab/resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install super-gradients &> /dev/null
!pip install torchinfo &> /dev/null
!pip install imutils &> /dev/null

In [None]:
%%capture
import os
import requests
import zipfile
import random
import numpy as np
import torchvision
import pprint
import torch
import pathlib

from matplotlib import pyplot as plt
from torchinfo import summary
from pathlib import Path, PurePath
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from PIL import Image
from typing import List, Tuple
import super_gradients
from super_gradients.training import models
from super_gradients.training import dataloaders
from super_gradients.training import Trainer
from super_gradients.training import training_hyperparams

In [None]:
class config:
    """Central hyper-parameters and paths for the ResNet fine-tuning experiment.

    All values are class attributes and are read directly as ``config.NAME``
    throughout the notebook.
    """

    # experiment / model identifiers (passed to Trainer and models.get)
    EXPERIMENT_NAME = 'resnet_in_action'
    MODEL_NAME = 'resnet50'
    CHECKPOINT_DIR = 'checkpoints'

    # specify the paths to training and validation set
    TRAIN_DIR = 'datasets/Datasets/Training'
    VAL_DIR = 'datasets/Datasets/Testing'

    # set the input height and width
    INPUT_HEIGHT = 224
    INPUT_WIDTH = 224

    # ImageNet per-channel (RGB) mean and std used to normalize inputs for
    # the ImageNet-pretrained backbone
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    # os.cpu_count() returns None when the CPU count cannot be determined;
    # DataLoader requires an int, so fall back to 0 (load in the main process).
    NUM_WORKERS = os.cpu_count() or 0

    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    # augmentation / batching knobs
    FLIP_PROB = 0.25      # probability of a random horizontal flip
    ROTATION_DEG = 15     # currently not referenced by the augmentation pipeline
    JITTER_PARAM = 0.25   # currently not referenced by the augmentation pipeline
    BATCH_SIZE = 64

In [None]:
# super-gradients Trainer driving training and evaluation for this experiment
trainer = Trainer(
    experiment_name=config.EXPERIMENT_NAME,
    ckpt_root_dir=config.CHECKPOINT_DIR,
    device=config.DEVICE,
)

KeyError: ignored

In [None]:
def create_dataloaders(
    train_dir: str,
    val_dir: str,
    train_transform: transforms.Compose,
    val_transform:  transforms.Compose,
    batch_size: int,
    num_workers: int=config.NUM_WORKERS
):
    """Create training and validation DataLoaders from ImageFolder directories.

    Both directories must use the ``root/<class_name>/<image>`` layout read by
    ``torchvision.datasets.ImageFolder``.

    Args:
        train_dir: Path to training data.
        val_dir: Path to validation data.
        train_transform: Transformation pipeline applied to training images.
        val_transform: Transformation pipeline applied to validation images.
        batch_size: Number of samples per batch in each of the DataLoaders.
        num_workers: Number of worker processes per DataLoader.

    Returns:
        A tuple of (train_dataloader, val_dataloader, class_names), where
        class_names is the training dataset's list of class folder names.

    Raises:
        ValueError: If the training and validation directories do not contain
            the same class folders (their label indices would then disagree).
    """
    # Use ImageFolder to create dataset
    train_data = datasets.ImageFolder(train_dir, transform=train_transform)
    val_data = datasets.ImageFolder(val_dir, transform=val_transform)

    print(f"[INFO] training dataset contains {len(train_data)} samples...")
    print(f"[INFO] validation dataset contains {len(val_data)} samples...")

    # Get class names
    class_names = train_data.classes
    # ImageFolder derives label indices from each root's own folder names, so a
    # missing or extra class in one split would silently shift the other's labels.
    if val_data.classes != class_names:
        raise ValueError(
            f"class folders differ between {train_dir!r} {class_names} "
            f"and {val_dir!r} {val_data.classes}"
        )
    print(f"[INFO] dataset contains {len(class_names)} labels...")

    # Turn images into data loaders
    print("[INFO] creating training and validation set dataloaders...")
    # pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only runs
    pin_memory = torch.cuda.is_available()
    train_dataloader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_dataloader = DataLoader(
        val_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_dataloader, val_dataloader, class_names

In [None]:
# initialize our data augmentation functions
resize = transforms.Resize(size=(config.INPUT_HEIGHT, config.INPUT_WIDTH))

horizontal_flip = transforms.RandomHorizontalFlip(p=config.FLIP_PROB)

# `resize` already yields exactly INPUT_HEIGHT x INPUT_WIDTH images, so a plain
# RandomCrop of that same size would always return the whole image (a no-op).
# RandomResizedCrop takes a random patch covering 80-100% of the area and
# rescales it back to the target size, which gives real spatial augmentation.
random_crop = transforms.RandomResizedCrop(
    size=(config.INPUT_HEIGHT, config.INPUT_WIDTH), scale=(0.8, 1.0)
)

norm = transforms.Normalize(mean=config.IMAGENET_MEAN, std=config.IMAGENET_STD)

make_tensor = transforms.ToTensor()

# initialize our training and validation set data augmentation pipeline:
# training applies random geometric augmentation on the PIL image before
# tensor conversion; validation is deterministic so evaluation is repeatable
train_transforms = transforms.Compose([resize, random_crop, horizontal_flip, make_tensor, norm])
val_transforms = transforms.Compose([resize, make_tensor, norm])

In [None]:
# build the ImageFolder-backed loaders for both splits
train_dataloader, valid_dataloader, class_names = create_dataloaders(
    train_dir=config.TRAIN_DIR,
    val_dir=config.VAL_DIR,
    train_transform=train_transforms,
    val_transform=val_transforms,
    batch_size=config.BATCH_SIZE,
)
# number of labels, used as num_classes when building the model
NUM_CLASSES = len(class_names)

FileNotFoundError: ignored

In [None]:
def show_img(img, mean=None, std=None):
    """Display a normalized CHW image tensor (e.g. a ``make_grid`` output).

    The data pipeline normalizes images with the ImageNet mean/std, so the
    tensor is mapped back to pixel space with ``pixel = x * std + mean`` and
    clipped to [0, 1] before plotting.

    Args:
        img: Normalized image tensor laid out as (C, H, W).
        mean: Per-channel mean used for normalization; defaults to
            ``config.IMAGENET_MEAN``.
        std: Per-channel std used for normalization; defaults to
            ``config.IMAGENET_STD``.
    """
    if mean is None:
        mean = config.IMAGENET_MEAN
    if std is None:
        std = config.IMAGENET_STD

    plt.figure(figsize=(20, 16))
    npimg = img.detach().cpu().numpy()
    # broadcast the per-channel statistics over the H and W axes
    npimg = npimg * np.asarray(std).reshape(-1, 1, 1) + np.asarray(mean).reshape(-1, 1, 1)
    npimg = np.clip(npimg, 0., 1.)
    # matplotlib expects (H, W, C)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Preview one shuffled training batch. DataLoader iterators implement the
# iterator protocol, so use the built-in next(); the old `.next()` method
# alias no longer exists in recent PyTorch releases.
data = iter(train_dataloader)
images, labels = next(data)
show_img(torchvision.utils.make_grid(images))

NameError: ignored

In [None]:
# ResNet-50 initialized from ImageNet weights, requesting NUM_CLASSES outputs
resnet50_imagenet_model = models.get(
    model_name=config.MODEL_NAME,
    num_classes=NUM_CLASSES,
    pretrained_weights="imagenet",
)

In [None]:
# Layer-by-layer overview before freezing. The dummy input uses the same
# spatial size the data pipeline produces (taken from config rather than
# hard-coded); the batch dimension only affects the printed shapes.
summary(model=resnet50_imagenet_model,
        input_size=(32, 3, config.INPUT_HEIGHT, config.INPUT_WIDTH),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

In [None]:
# Freeze the stem (conv1 + bn1) and the first two residual stages; parameters
# of all other modules keep their current requires_grad setting.
# NOTE(review): requires_grad=False only stops weight updates — BatchNorm
# running statistics inside these modules still update in train mode.
layers_to_freeze = (
    resnet50_imagenet_model.conv1,
    resnet50_imagenet_model.bn1,
    resnet50_imagenet_model.layer1,
    resnet50_imagenet_model.layer2,
)

for frozen_module in layers_to_freeze:
    # sets requires_grad=False on every parameter of the module, recursively
    frozen_module.requires_grad_(False)

In [None]:
# Same overview after freezing: the "trainable" column should now be False
# for conv1, bn1, layer1 and layer2. Input size is taken from config so it
# always matches the data pipeline.
summary(model=resnet50_imagenet_model,
        input_size=(32, 3, config.INPUT_HEIGHT, config.INPUT_WIDTH),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

In [None]:
%%capture
# start from super-gradients' ImageNet ResNet-50 training recipe
recipe_name = "training_hyperparams/imagenet_resnet50_train_params"
training_params = training_hyperparams.get(recipe_name)

In [None]:
# Plain print for the heading: pprint would echo the string's repr (quoted).
print("Training parameters")
pprint.pprint(training_params)

In [None]:
# Short fine-tuning run on top of the ImageNet recipe.
training_params["max_epochs"] = 4
training_params["ema"] = True
training_params["criterion_params"] = {'smooth_eps': 0.1} # Enable label-smoothing cross-entropy

# Top-5 accuracy needs at least 5 classes (torch.topk with k=5 fails on fewer
# logits). The ImageNet recipe presumably lists "Top5" in its metric lists
# (check the pprint output above), so drop it for smaller label sets.
if NUM_CLASSES < 5:
    for metrics_key in ("train_metrics_list", "valid_metrics_list"):
        if metrics_key in training_params:
            training_params[metrics_key] = [
                metric for metric in training_params[metrics_key] if metric != "Top5"
            ]

In [None]:
# Fine-tune the partially frozen ResNet-50 with the adjusted recipe.
trainer.train(
    model=resnet50_imagenet_model,
    training_params=training_params,
    train_loader=train_dataloader,
    valid_loader=valid_dataloader,
)

In [None]:
# Rebuild the architecture and load the "ckpt_best.pth" checkpoint from the
# trainer's checkpoint directory.
best_ckpt_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_best.pth")
best_model = models.get(
    config.MODEL_NAME,
    num_classes=NUM_CLASSES,
    checkpoint_path=best_ckpt_path,
)

In [None]:
# Evaluate the best checkpoint on the held-out split. Top-5 accuracy is only
# meaningful with at least 5 classes (torch.topk with k=5 fails otherwise).
test_metrics = ['Accuracy', 'Top5'] if NUM_CLASSES >= 5 else ['Accuracy']
trainer.test(model=best_model, test_loader=valid_dataloader, test_metrics_list=test_metrics)

In [None]:
# Plot predictions of the best model on a random sample of validation images.
# NOTE: relies on `pred_and_plot_image`, which is defined in the following
# cell — run that cell before this one.
num_images_to_plot = 30
num_cols = 5
image_suffixes = {".jpg", ".jpeg", ".png"}

# all image files one level below the class folders (ImageFolder layout)
test_image_path_list = sorted(
    p for p in Path(config.VAL_DIR).glob("*/*")
    if p.is_file() and p.suffix.lower() in image_suffixes
)
# random.sample raises ValueError if k exceeds the population, so cap k
test_image_path_sample = random.sample(population=test_image_path_list,
                                       k=min(num_images_to_plot, len(test_image_path_list)))

if not test_image_path_sample:
    print(f"[INFO] no images found under {config.VAL_DIR}")
else:
    # set up subplots sized to the number of images actually sampled
    num_rows = int(np.ceil(len(test_image_path_sample) / num_cols))
    fig, ax = plt.subplots(num_rows, num_cols, figsize=(15, num_rows * 3))

    # Make predictions on and plot the images
    for i, image_path in enumerate(test_image_path_sample):
        pred_and_plot_image(model=best_model,
                            image_path=image_path,
                            class_names=class_names,
                            subplot=(num_rows, num_cols, i + 1),  # subplot tuple for `subplot()` function
                            image_size=(config.INPUT_HEIGHT, config.INPUT_WIDTH))

    # adjust spacing between subplots
    plt.subplots_adjust(wspace=1)
    plt.show()


In [None]:
import io
import textwrap
from typing import Callable, Optional


def pred_and_plot_image(model: torch.nn.Module,
                        image_path: str,
                        class_names: List[str],
                        subplot: Tuple[int, int, int],  # subplot tuple for `subplot()` function
                        image_size: Tuple[int, int] = (config.INPUT_HEIGHT, config.INPUT_WIDTH),
                        transform: Optional[Callable] = None,
                        device: torch.device=config.DEVICE):
    """Predict the class of a single image and draw it with its prediction.

    Args:
        model: Trained classifier. It is moved to ``device`` and switched to
            eval mode; both changes persist after the call.
        image_path: Local file path (``str`` or ``pathlib.Path``) or an
            ``http(s)://`` URL. For local files the parent folder name is
            shown as the ground-truth label (ImageFolder layout).
        class_names: Label names indexed by the model's output index.
        subplot: ``(nrows, ncols, index)`` forwarded to ``plt.subplot``.
        image_size: ``(height, width)`` used by the default transform.
        transform: Optional callable mapping a PIL image to a CHW tensor. When
            None, Resize + ToTensor + ImageNet Normalize is used, mirroring the
            validation pipeline.
        device: Device the model and the image tensor are moved to.

    Raises:
        requests.HTTPError: If downloading a URL image returns an error status.
    """
    is_url = isinstance(image_path, str) and image_path.startswith(("http://", "https://"))
    if is_url:
        response = requests.get(image_path, timeout=30)
        response.raise_for_status()
        source = io.BytesIO(response.content)
    else:
        source = image_path

    # ImageFolder loads images as RGB; do the same so grayscale or RGBA files
    # yield the 3 channels the model and the 3-value Normalize expect.
    # convert() returns a fully loaded copy, so the file can be closed here.
    with Image.open(source) as raw_img:
        img = raw_img.convert("RGB")

    # create transformation for image (if one doesn't exist)
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=config.IMAGENET_MEAN,
                                 std=config.IMAGENET_STD),
        ])
    transformed_image = transform(img)

    # make sure the model is on the target device
    model.to(device)

    # turn on model evaluation mode and inference mode
    model.eval()
    with torch.inference_mode():
        # add a batch dimension: the model takes [batch_size, channels, height, width]
        batch = transformed_image.unsqueeze(dim=0).to(device)
        logits = model(batch)

    # convert logits -> prediction probabilities (multi-class softmax)
    probs = torch.softmax(logits, dim=1)
    pred_index = torch.argmax(probs, dim=1).item()
    pred_prob = probs.max().item()

    # plot image with predicted label and probability
    plt.subplot(*subplot)
    plt.imshow(img)
    if is_url:
        title = f"Pred: {class_names[pred_index]} | Prob: {pred_prob:.3f}"
    else:
        # actual label = name of the class folder containing the file
        ground_truth = PurePath(image_path).parent.name
        title = f"Ground Truth: {ground_truth} | Pred: {class_names[pred_index]} | Prob: {pred_prob:.3f}"
    plt.title("\n".join(textwrap.wrap(title, width=20)))  # wrap long titles onto several lines
    plt.axis(False)