In [1]:
# !pip install flask_caching

In [2]:
import os
import time
import uuid
from copy import deepcopy
import csv
import sys
import pathlib
from jupyter_dash import JupyterDash
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State
from flask_caching import Cache
import sys
import dash_reusable_components as drc
import utils
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

In [3]:
from torchvision.transforms import ToTensor

# Load the model

In [4]:
class ImageClassificationBase(nn.Module):
    """Training/validation helpers shared by image-classification models.

    Subclasses supply ``forward``. ``training_step`` and ``validation_step``
    look up ``DiffAugment`` and ``accuracy`` as module-level names when they
    run. NOTE(review): neither is defined in the cells shown here; confirm
    they exist before training. Inference through ``forward`` does not need
    them.
    """

    def training_step(self, batch):
        """Return the cross-entropy loss of one (augmented) training batch."""
        inputs, targets = batch
        # DiffAugment is applied to the inputs before the forward pass.
        augmented = DiffAugment(inputs, policy='color,translation')
        return F.cross_entropy(self(augmented), targets)

    def validation_step(self, batch):
        """Return the detached loss and the accuracy of one validation batch."""
        inputs, targets = batch
        logits = self(inputs)
        return {
            'val_loss': F.cross_entropy(logits, targets).detach(),
            'val_acc': accuracy(logits, targets),
        }

    def validation_epoch_end(self, outputs):
        """Average per-batch validation metrics into plain Python floats."""
        mean_loss = torch.stack([step['val_loss'] for step in outputs]).mean()
        mean_acc = torch.stack([step['val_acc'] for step in outputs]).mean()
        return {'val_loss': mean_loss.item(), 'val_acc': mean_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line summary of an epoch's train/validation metrics."""
        template = "Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
        print(template.format(
            epoch, result['train_loss'], result['val_loss'], result['val_acc']))
        
class Net(ImageClassificationBase):
    """DenseNet-161 feature extractor with a linear classification head.

    Args:
        num_classes: number of output logits.
        num_channels: input channels accepted by the replacement stem conv.
        pretrained: forwarded to ``torchvision.models.densenet161``. Defaults
            to True (the original behavior). Pass False when a full state dict
            is loaded afterwards: the ImageNet weights would be overwritten
            anyway, so downloading them is wasted work and needs network access.
    """

    def __init__(self, num_classes=2, num_channels=3, pretrained=True):
        super().__init__()
        backbone = torchvision.models.densenet161(pretrained=pretrained)
        self.features = backbone.features
        # Replace the stem conv so the input channel count is configurable;
        # this layer starts from a fresh initialization either way.
        self.features.conv0 = nn.Conv2d(num_channels, 96, 7, 2, 3)
        # Width of the feature extractor's output (2208 for densenet161),
        # read from the backbone instead of hard-coding it.
        in_features = backbone.classifier.in_features
        self.classifier = nn.Linear(in_features, num_classes, bias=True)
        # Not applied in forward(), but kept so checkpoints that contain
        # bn.* entries still load with strict key matching.
        self.bn = nn.BatchNorm1d(in_features)
        del backbone

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        features = self.features(x)
        out = F.relu(features, inplace=True)
        # Global max pool to (N, C, 1, 1), then flatten to (N, C).
        out = F.adaptive_max_pool2d(out, (1, 1)).view(features.size(0), -1)
        # Batch-norm on the pooled features is intentionally disabled:
        # out = self.bn(out)
        return self.classifier(out)

def predict_image(img, model, classes=('autistic', 'non autistic')):
    """Classify a single image and return its predicted label.

    Args:
        img: a PIL image (or anything ``ToTensor`` accepts). Images whose
            ``mode`` is not ``'RGB'`` (e.g. RGBA PNGs, grayscale, palette)
            are converted to RGB first, because ``Net`` builds its stem conv
            for ``num_channels=3`` by default.
        model: a classifier returning logits of shape (1, len(classes)).
        classes: label for each output index; defaults to the original labels.

    Returns:
        The entry of ``classes`` at the index of the highest logit.
    """
    # Non-RGB uploads would otherwise yield 1- or 4-channel tensors and crash
    # the 3-channel first convolution.
    if getattr(img, 'mode', 'RGB') != 'RGB':
        img = img.convert('RGB')

    # Send the input to wherever the weights live. The global `device` can be
    # CUDA while the weights were loaded with map_location='cpu', which made
    # the forward pass fail with a device mismatch.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else torch.device('cpu')
    xb = ToTensor()(img).unsqueeze(0).to(target)  # batch of 1

    # Inference must run in eval mode: in train mode BatchNorm normalizes with
    # this single image's statistics and also updates its running stats.
    # Restore the caller's mode afterwards so the model is not left altered.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # no autograd graph needed for prediction
            logits = model(xb)
    finally:
        model.train(was_training)

    _, preds = torch.max(logits, dim=1)
    return classes[preds[0].item()]

# Instantiate the model with default arguments. NOTE: this rebinds the name
# `Net` from the class to the instance; later cells call
# `Net.load_state_dict(...)` and pass `Net` to `predict_image`, so the class
# is no longer reachable under this name afterwards.
Net = Net()

In [5]:
def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
def to_device(data, device):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, to ``device``.

    Lists and tuples always come back as lists. Transfers are requested with
    ``non_blocking=True``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

In [6]:
# Target device for tensors in later cells: CUDA when available, else CPU.
device = get_default_device()

In [7]:
# Restore the trained weights. map_location loads every tensor onto the CPU
# first, so a checkpoint saved on a GPU machine also opens on a CPU-only one.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
MODEL_PATH = './autism_best_model.pt'
load_result = Net.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
# Put the weights on the same device the inputs are sent to, and switch to
# inference mode so BatchNorm uses its running statistics rather than the
# statistics of each single uploaded image.
Net.to(device)
Net.eval()
load_result  # display the key-matching report from load_state_dict

<All keys matched successfully>

In [8]:

# Dash app that runs from inside the notebook kernel (via JupyterDash).
app = JupyterDash(__name__)
# Handle on the app's underlying web server object.
server = app.server

def serve_layout():
    """Build the app's root layout.

    Structure:
      * ``app-container``: the banner (logo + title) and the
        ``div-interactive-image`` area. That area starts with
        ``utils.GRAPH_PLACEHOLDER`` and is replaced by the image callback.
      * ``sidebar``: the ``prediction`` readout, filled in by the prediction
        callback, and the ``upload-image`` drop zone (images only).
    """
    return html.Div(
        id="root",
        children=[
            # Main body
            html.Div(
                id="app-container",
                children=[
                    # Banner display
                    html.Div(
                        id="banner",
                        children=[
                            html.Img(
                                id="logo", src=app.get_asset_url("taif-logo.png")
                            ),
                            html.H2("Taif: ASD Diagnosis System", id="title"),
                        ],
                    ),
                    html.Div(
                        id="image",
                        children=[
                            # Holds the displayed image: a placeholder graph
                            # until an image is uploaded.
                            html.Div(
                                id="div-interactive-image",
                                children=[utils.GRAPH_PLACEHOLDER],
                            )
                        ],
                    ),
                ],
            ),
            # Sidebar
            html.Div(
                id="sidebar",
                children=[
                    # Prediction readout; the H3 text is set by a callback.
                    html.Div([
                        html.H1("Prediction:"),
                        html.H3(id="prediction"),
                    ]),
                    drc.Card(
                        [
                            dcc.Upload(
                                id="upload-image",
                                children=[
                                    "Drag and Drop or ",
                                    html.A(children="Select an Image"),
                                ],
                                # No CSS alternative here. Dash style dicts use
                                # camelCase keys (React rejects "margin-bottom").
                                style={
                                    "color": "darkgray",
                                    "width": "100%",
                                    "height": "50px",
                                    "lineHeight": "50px",
                                    "borderWidth": "1px",
                                    "borderStyle": "dashed",
                                    "borderRadius": "5px",
                                    "borderColor": "darkgray",
                                    "textAlign": "center",
                                    "padding": "2rem 0",
                                    "marginBottom": "2rem",
                                },
                                accept="image/*",
                            ),
                        ]
                    ),
                ],
            ),
        ],
    )


# Assign the function itself (not its result) so the layout is rebuilt on
# each page load.
app.layout = serve_layout



In [9]:
@app.callback(
    Output("div-interactive-image", "children"),
    [Input("upload-image", "contents")],
    [State("upload-image", "filename")],
)
def update_graph_interactive_image(content, new_filename):
    """Show the uploaded image in the interactive image area.

    Args:
        content: the upload's data-URL string (``...;base64,<payload>``), or
            None/empty before anything has been uploaded.
        new_filename: the uploaded file's name (not used).

    Returns:
        A one-element children list holding the interactive image component.

    Raises:
        PreventUpdate: when there is no upload yet. The placeholder graph is
            left in place instead of failing on ``None.split``.
    """
    from dash.exceptions import PreventUpdate

    # Dash also invokes callbacks on initial page load, with contents=None.
    if not content:
        raise PreventUpdate

    # Drop the data-URL header and decode the base64 payload into an image.
    b64_payload = content.split(";base64,")[-1]
    im_pil = drc.b64_to_pil(b64_payload)

    return [
        drc.InteractiveImagePIL(
            image_id="interactive-image",
            image=im_pil,
        )
    ]


In [10]:
# Callback: classify each newly uploaded image and write the result into the
# "prediction" component of the sidebar.

@app.callback(
    Output("prediction", "children"),
    [Input("upload-image", "contents")],
    [State("upload-image", "filename")],
)
def update_prediction(content, new_filename):
    """Run the classifier on the uploaded image and return the result text.

    Args:
        content: the upload's data-URL string (``...;base64,<payload>``), or
            None/empty before anything has been uploaded.
        new_filename: the uploaded file's name (not used).

    Returns:
        The sentence shown under "Prediction:".

    Raises:
        PreventUpdate: when there is no upload yet, instead of failing on
            ``None.split`` during Dash's initial callback call.
    """
    from dash.exceptions import PreventUpdate

    if not content:
        raise PreventUpdate

    # Drop the data-URL header and decode the base64 payload into an image.
    b64_payload = content.split(";base64,")[-1]
    im_pil = drc.b64_to_pil(b64_payload)
    # `Net` is the model instance created and loaded in earlier cells.
    pred = predict_image(im_pil, Net)
    return 'We think your child is {}'.format(pred)


# The callback above classifies the upload with predict_image (defined earlier).

In [11]:
# Running the server
if __name__ == "__main__":
    # Start the Dash server (default http://127.0.0.1:8050/, per the output below).
    # NOTE(review): debug=True enables Dash's debug mode; turn it off for any
    # deployment beyond local development.
    app.run_server(debug=True)

Dash app running on http://127.0.0.1:8050/
