# Application

In [1]:
from ipywidgets import Accordion, Button, Dropdown, FileUpload, HBox, Image, Output, Tab, VBox, Layout, IntSlider, AppLayout
from PIL import Image as PilImage, ImageOps
import PIL.Image
#from PIL import Image
from skimage import color
import io
import os

import IPython.display as display
import numpy as np

import torch
from torchvision import transforms, models, datasets
from torch import nn

# Mount Google Drive into the Colab runtime so the project files are reachable.
from google.colab import drive
drive.mount('/content/drive/')
# IPython magic: switch the working directory to the project folder. The
# relative paths used below ('data/...', 'best_models/...') resolve from here.
%cd drive/My\ Drive/Classes/Covid_19_project/

# Dataset split locations, relative to the project working directory.
train_dir, val_dir, test_dir = 'data/train', 'data/val', 'data/test'

# Split name -> directory lookup used when building the datasets.
dirs = dict(test=test_dir, train=train_dir, val=val_dir)

# Preprocessing pipelines, one per dataset split.
#
# Only the training pipeline applies random augmentation (random crop and
# colour jitter). The validation and test pipelines are deterministic so that
# evaluating the same image twice always feeds the network the same tensor.
# The Normalize mean/std are the ImageNet channel statistics documented for
# torchvision's pretrained models.
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'test': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# Mini-batch size shared by the DataLoader of every split.
batch_size = 16

# One ImageFolder per split, each paired with that split's transform pipeline.
image_datasets = {
    split: datasets.ImageFolder(dirs[split], transform=data_transforms[split])
    for split in ('train', 'val', 'test')
}

# Wrap every dataset in a shuffling, batched DataLoader.
dataloaders = {
    split: torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for split, dataset in image_datasets.items()
}

# Number of samples available in each split.
dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}

# Class labels as discovered by ImageFolder in the test split.
class_names = image_datasets['test'].classes


# ResNet50 backbone. The fine-tuned weights are restored from the checkpoint
# by load_checkpoint() before any prediction is made.
resnet50_model = models.resnet50(pretrained=True)
num_features = resnet50_model.fc.in_features
# Replace the final fully connected layer so the network outputs 2 classes
# (the input width is read from the existing layer rather than hard-coded).
resnet50_model.fc = nn.Linear(num_features, 2)

# Keep every parameter trainable; load_checkpoint() freezes them for inference.
for param in resnet50_model.parameters():
    param.requires_grad = True

# map_location='cpu' lets a checkpoint that was saved from a GPU session be
# deserialized on a CPU-only runtime; inference in this notebook runs on CPU.
checkpoint_covid_net = torch.load('best_models/resnet50.pth', map_location='cpu')
    
def load_checkpoint(model, checkpoint, filepath):
    """Restore trained weights into ``model`` and switch it to inference mode.

    Args:
        model: Network whose architecture matches the saved weights.
        checkpoint: Already-deserialized checkpoint dict holding a
            ``'model_state_dict'`` entry. Pass ``None`` to have it read from
            ``filepath`` instead.
        filepath: Path of the checkpoint file; only read when ``checkpoint``
            is ``None``.

    Returns:
        The same ``model`` object with the weights loaded, every parameter
        frozen (``requires_grad = False``) and ``eval()`` mode enabled.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    if checkpoint is None:
        # Only hit the disk when the caller did not supply the checkpoint.
        # map_location='cpu' lets a GPU-saved file load on a CPU-only host.
        checkpoint = torch.load(filepath, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False
    model.eval()
    return model

def makepred(img):
    """Classify one image and print the predicted class with its confidence.

    Args:
        img: RGB ``PIL.Image`` to classify (the pipeline normalizes three
            channels).

    Returns:
        ``(label, confidence)``: the predicted class name and its softmax
        probability expressed as a percentage. The same two values are
        printed, in that order.
    """
    # Inference preprocessing is deterministic on purpose: random colour
    # jitter here would make the same image yield different confidences
    # (and possibly a different class) from one call to the next.
    image_preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    # Add the leading batch dimension the network expects.
    img_tensor = torch.unsqueeze(image_preprocess(img), 0)

    model = load_checkpoint(resnet50_model, checkpoint_covid_net, 'best_models/resnet50.pth')

    # No gradients are needed for a forward-only prediction.
    with torch.no_grad():
        out = model(img_tensor)

    percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
    index = int(torch.argmax(out, dim=1)[0])

    label = class_names[index]
    confidence = percentage[index].item()
    print(label, confidence)
    return label, confidence

class Covid_prediction:
    """Notebook widget that classifies an uploaded X-ray image.

    The layout is a header row (file-upload button and a "Check disease"
    button) above a tab panel. The first tab collects the printed results;
    each processed upload adds another tab showing the image itself.
    """

    def __init__(self):
        """Build the widgets and wire the button to ``on_apply``."""
        self.file_upload = FileUpload(accept='image/*', multiple=False, description='Upload image')
        self.file_upload.style.button_color = 'lightblue'
        self.apply_button = Button(description='Check disease')
        self.apply_button.style.button_color = 'salmon'
        self.apply_button.on_click(self.on_apply)

        # Everything printed while handling a click is captured here.
        self.output = Output()

        self.tab = Tab(layout=Layout(width='50%', height='100%'))
        self.tab.children = [self.output]
        self.tab.set_title(0, 'Results')

        self.header = HBox([self.file_upload, self.apply_button])
        self.container = VBox([self.header, self.tab])

    def on_apply(self, btn):
        """Run the classifier on the uploaded image and display the result.

        Args:
            btn: The clicked button, as passed by ``Button.on_click``; unused.
        """
        with self.output:
            uploads = self.file_upload.value
            if not uploads:
                # Clicking before choosing a file used to raise IndexError.
                print('Please upload an image first.')
                return

            # ipywidgets 7 exposes a dict keyed by file name, ipywidgets 8 a
            # tuple of per-file dicts; both carry the raw bytes as 'content'.
            if isinstance(uploads, dict):
                upload = next(iter(uploads.values()))
            else:
                upload = uploads[0]
            # bytes() also converts the memoryview delivered by ipywidgets 8.
            img_content = bytes(upload['content'])

            # Decode the uploaded bytes themselves rather than looking for a
            # same-named file on disk, so any uploaded image can be classified.
            test_img = PIL.Image.open(io.BytesIO(img_content)).convert('RGB')
            print('---------------')
            print('Your X-ray image is: ')
            makepred(test_img)

            # Show the uploaded picture in a new tab next to the results.
            img = Image(value=img_content)
            self.tab.children = list(self.tab.children) + [img]
            self.tab.set_title(len(self.tab.children) - 1, 'Your image')

    def get_layout(self):
        """Return the top-level container widget to display in the notebook."""
        return self.container


# Build the upload-and-predict widget; as the last expression of the cell,
# its container is rendered as the cell output.
Covid_prediction().get_layout()

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
/content/drive/My Drive/Classes/Covid_19_project


VBox(children=(HBox(children=(FileUpload(value={}, accept='image/*', description='Upload image', style=ButtonS…