In [None]:
# Clone the project repo only if it is not already present, so the cell can be re-run safely.
!test -d stat430_classification || git clone https://github.com/LongpingZhang/stat430_classification.git

Cloning into 'stat430_classification'...
remote: Enumerating objects: 644, done.[K
remote: Counting objects: 100% (24/24), done.[K
remote: Compressing objects: 100% (16/16), done.[K
remote: Total 644 (delta 5), reused 23 (delta 4), pack-reused 620[K
Receiving objects: 100% (644/644), 191.00 MiB | 26.45 MiB/s, done.
Resolving deltas: 100% (5/5), done.
Checking out files: 100% (616/616), done.


In [None]:
import sys

# Make the cloned repo's `utils` and `model` packages importable.
# Guarded so re-running this cell does not append duplicate sys.path entries.
PROJECT_PATH = 'stat430_classification/'
if PROJECT_PATH not in sys.path:
    sys.path.append(PROJECT_PATH)

In [None]:
from utils.annotation_util import annotation_df
from utils.dataset import CustomImageDataset
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
import glob
from model.model import CustomNetwork
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from copy import deepcopy

In [None]:
# --- Training configuration -------------------------------------------------
batch_size = 10
num_classes = 1  # a single logit, paired with BCEWithLogitsLoss for binary classification
epochs = 25
criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed before anything random happens (including the training DataLoader's
# shuffling below) so the whole sweep is reproducible.
torch.manual_seed(17)

# --- Data -------------------------------------------------------------------
project_dir = "stat430_classification/"
base_dir = os.path.join(project_dir, "data/")
train_dir = os.path.join(base_dir, "Training")
val_dir = os.path.join(base_dir, "Validate")
test_dir = os.path.join(base_dir, "Testing")
df_train, df_val, df_test = annotation_df(train_dir, val_dir, test_dir)

# Only resizing is applied here.
# NOTE(review): no normalization is applied; the very large losses in the sweep
# output suggest checking the pixel value range CustomImageDataset produces.
data_aug = transforms.Compose([transforms.Resize((224, 224))])
train_dataset = CustomImageDataset(df_train, transform=data_aug)
val_dataset = CustomImageDataset(df_val, transform=data_aug)
test_dataset = CustomImageDataset(df_test, transform=data_aug)

# Shuffle the training set every epoch so batches are not served in a fixed
# (possibly class-sorted) order; evaluation order does not affect the metrics.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# Start the sweep from a clean slate: drop checkpoints and TensorBoard summaries from earlier runs.
!rm --recursive --force logs/

In [None]:
def _confusion_metrics(cf_matrix):
    """Compute (accuracy, sensitivity, specificity) from a binary confusion matrix.

    ``cf_matrix.ravel()`` is unpacked as ``tn, fp, fn, tp`` (sklearn's 2x2 layout).
    NOTE(review): assumes validation_step/testing_step return that layout -- confirm.
    A metric whose denominator is zero is reported as NaN instead of raising.
    """
    tn, fp, fn, tp = (float(v) for v in cf_matrix.ravel())

    def _ratio(numerator, denominator):
        return numerator / denominator if denominator else float('nan')

    accuracy = _ratio(tp + tn, tp + tn + fp + fn)
    sensitivity = _ratio(tp, tp + fn)  # true-positive rate (recall)
    specificity = _ratio(tn, tn + fp)  # true-negative rate
    return accuracy, sensitivity, specificity


df_results = []  # one dict per (backbone, learning rate) run; tabulated in a later cell

# Settings shared by every run of the sweep.
log_dir_base = os.path.join(os.getcwd(), 'logs')
experiment_name = '1'
ckpts_til_saving = 5             # write a periodic checkpoint every 5 completed epochs
start_training_from_ckpt = None  # set to a checkpoint path to resume from saved weights

for backbone in ['alexnet', 'vgg11', 'resnet18']:
    for learning_rate in [1e-1, 1e-2, 1e-3]:
        model = CustomNetwork(num_classes=num_classes, loss_fn=criterion, device=device,
                              threshold=0.5, backbone=backbone).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        # Layout: logs/<experiment>/<backbone>/lr_<lr>/{checkpoints,summary}
        log_dir = os.path.join(log_dir_base, experiment_name, backbone, 'lr_' + str(learning_rate))
        model_weights_dir = os.path.join(log_dir, "checkpoints")
        summary_dir = os.path.join(log_dir, "summary")
        for directory in (model_weights_dir, summary_dir):
            Path(directory).mkdir(parents=True, exist_ok=True)
        best_model_path = f'{model_weights_dir}/best_model.pt'

        if start_training_from_ckpt:
            # load_state_dict updates the model in place and returns a report of
            # missing/unexpected keys -- its result must not replace `model`.
            checkpoint = torch.load(start_training_from_ckpt, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])

        writer = SummaryWriter(log_dir=summary_dir, flush_secs=20)
        best_loss = float("inf")
        train_losses, val_losses, epoch_num = [], [], []

        try:
            # Snapshot of the initial weights, for comparison with trained ones.
            torch.save({'model_state_dict': model.state_dict()},
                       f'{model_weights_dir}/epoch0_before_training.pt')

            for epoch in range(epochs):
                # --- Training pass ---
                model.train()
                running_loss = 0.0
                num_batches = 0
                for batch in train_dataloader:
                    optimizer.zero_grad()
                    train_loss = model.training_step(batch)
                    train_loss.backward()
                    optimizer.step()
                    running_loss += train_loss.item()
                    num_batches += 1
                epoch_train_loss = running_loss / num_batches
                train_losses.append(epoch_train_loss)
                epoch_num.append(epoch)

                # --- Validation pass (eval mode, no gradient tracking) ---
                model.eval()
                with torch.no_grad():
                    val_loss, cf_matrix = model.validation_step(val_dataloader)
                val_loss = val_loss.item()  # plain float for logging and comparison
                val_losses.append(val_loss)

                # Tag by backbone so TensorBoard curves are labelled correctly.
                writer.add_scalars(backbone, {'training_loss': epoch_train_loss,
                                              'validation_loss': val_loss}, epoch)

                # Periodic checkpoint after every `ckpts_til_saving` completed epochs.
                if (epoch + 1) % ckpts_til_saving == 0:
                    torch.save({'model_state_dict': model.state_dict()},
                               f'{model_weights_dir}/epoch{epoch}.pt')

                # Keep the weights with the lowest validation loss seen so far.
                if val_loss < best_loss:
                    best_loss = val_loss
                    torch.save({'model_state_dict': model.state_dict()}, best_model_path)

                accuracy, _, _ = _confusion_metrics(cf_matrix)
                print(f'Method: {backbone}, lr: {learning_rate}, epoch: {epoch}, '
                      f'training_loss: {epoch_train_loss}, validation_loss: {val_loss}, '
                      f'accuracy: {accuracy}')
        finally:
            writer.close()  # flush pending TensorBoard events and release the file handle

        # --- Test the best-validation checkpoint, not the last epoch's weights ---
        # If validation loss never improved on +inf (e.g. NaN from divergence),
        # no best checkpoint exists and the final weights are evaluated instead.
        best_model = deepcopy(model)
        if best_loss < float("inf"):
            best_state = torch.load(best_model_path, map_location=device)['model_state_dict']
            best_model.load_state_dict(best_state)
        best_model.eval()
        with torch.no_grad():
            test_loss, cf_matrix = best_model.testing_step(test_dataloader)

        accuracy, sensitivity, specificity = _confusion_metrics(cf_matrix)
        df_results.append({'Method': backbone,
                           'Learning_rate': learning_rate,
                           'test_loss': test_loss.item(),
                           'accuracy': accuracy,
                           'sensitivity': sensitivity,
                           'specificity': specificity})

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Method: alexnet, epcoh: 0, training_loss: 1.0978629016323682e+19, validation_loss: 9317301354496.0, accuracy: 0.5
Method: alexnet, epcoh: 1, training_loss: 6406733927195.385, validation_loss: 560413439688704.0, accuracy: 0.5
Method: alexnet, epcoh: 2, training_loss: 3001854052457.1104, validation_loss: 9578061234176.0, accuracy: 0.5
Method: alexnet, epcoh: 3, training_loss: 412256927097.2308, validation_loss: 40166982942720.0, accuracy: 0.5
Method: alexnet, epcoh: 4, training_loss: 1033384226887.3846, validation_loss: 241449572499456.0, accuracy: 0.5
Method: alexnet, epcoh: 5, training_loss: 1689103214542.7693, validation_loss: 31920415047680.0, accuracy: 0.5
Method: alexnet, epcoh: 6, training_loss: 188749033738.6154, validation_loss: 518939893301248.0, accuracy: 0.5
Method: alexnet, epcoh: 7, training_loss: 2040833007903.4326, validation_loss: 129923532455936.0, accuracy: 0.5
Method: alexnet, epcoh: 8, training_loss: 301087481016.7211, validation_loss: 40076931235840.0, accuracy: 0.5

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Method: alexnet, epcoh: 0, training_loss: 32986796347.1911, validation_loss: 440791392.0, accuracy: 0.5
Method: alexnet, epcoh: 1, training_loss: 2025389353.2157452, validation_loss: 239587968.0, accuracy: 0.5
Method: alexnet, epcoh: 2, training_loss: 21421957.99215933, validation_loss: 557099968.0, accuracy: 0.5
Method: alexnet, epcoh: 3, training_loss: 53909518.76242975, validation_loss: 67381.09375, accuracy: 0.5
Method: alexnet, epcoh: 4, training_loss: 9023.2035108713, validation_loss: 16554.728515625, accuracy: 0.5
Method: alexnet, epcoh: 5, training_loss: 2959.9400349396924, validation_loss: 6471.3173828125, accuracy: 0.5
Method: alexnet, epcoh: 6, training_loss: 1250.3557558549512, validation_loss: 1537.923095703125, accuracy: 0.5
Method: alexnet, epcoh: 7, training_loss: 393.8301139153683, validation_loss: 1250.466064453125, accuracy: 0.5
Method: alexnet, epcoh: 8, training_loss: 177.26471884928358, validation_loss: 144.5393829345703, accuracy: 0.5
Method: alexnet, epcoh: 9, t

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Method: alexnet, epcoh: 0, training_loss: 405.41262411691537, validation_loss: 123338.015625, accuracy: 0.5
Method: alexnet, epcoh: 1, training_loss: 8125.914536971312, validation_loss: 50677.390625, accuracy: 0.5
Method: alexnet, epcoh: 2, training_loss: 3756.9160175759057, validation_loss: 13234.2099609375, accuracy: 0.5
Method: alexnet, epcoh: 3, training_loss: 963.449623936644, validation_loss: 10427.7392578125, accuracy: 0.5
Method: alexnet, epcoh: 4, training_loss: 780.6791548063472, validation_loss: 15115.955078125, accuracy: 0.5
Method: alexnet, epcoh: 5, training_loss: 954.2982205863183, validation_loss: 1851.060791015625, accuracy: 0.5
Method: alexnet, epcoh: 6, training_loss: 159.25068883941108, validation_loss: 1095.473876953125, accuracy: 0.5
Method: alexnet, epcoh: 7, training_loss: 72.06444875287005, validation_loss: 995.644287109375, accuracy: 0.5
Method: alexnet, epcoh: 8, training_loss: 67.61432338092453, validation_loss: 523.9483642578125, accuracy: 0.5
Method: alexn

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Method: vgg11, epcoh: 0, training_loss: 2.722778435886305e+27, validation_loss: 3.190071305678224e+27, accuracy: 0.5
Method: vgg11, epcoh: 1, training_loss: 2.586323947911489e+27, validation_loss: 3.017751267274595e+27, accuracy: 0.5
Method: vgg11, epcoh: 2, training_loss: 2.3289592294962493e+26, validation_loss: 2.7056496236998823e+26, accuracy: 0.5
Method: vgg11, epcoh: 3, training_loss: 8.134612973034071e+25, validation_loss: 8.96538970929793e+25, accuracy: 0.5
Method: vgg11, epcoh: 4, training_loss: 7.653952280118011e+24, validation_loss: 4.2620735249096553e+24, accuracy: 0.4375
Method: vgg11, epcoh: 5, training_loss: 5.433861182391295e+24, validation_loss: 2.61696348891787e+24, accuracy: 0.34375
Method: vgg11, epcoh: 6, training_loss: 4.763290812895028e+24, validation_loss: 1.918759125644358e+24, accuracy: 0.34375
Method: vgg11, epcoh: 7, training_loss: 4.381637977395894e+24, validation_loss: 1.7721733708993178e+24, accuracy: 0.3125
Method: vgg11, epcoh: 8, training_loss: 4.213171

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Method: vgg11, epcoh: 0, training_loss: 452806895428214.25, validation_loss: 2285431685120.0, accuracy: 0.5
Method: vgg11, epcoh: 1, training_loss: 7118491723815.385, validation_loss: 287377489920.0, accuracy: 0.5
Method: vgg11, epcoh: 2, training_loss: 83295991477.11539, validation_loss: 189507141632.0, accuracy: 0.5
Method: vgg11, epcoh: 3, training_loss: 32393032072.923077, validation_loss: 795451457536.0, accuracy: 0.5
Method: vgg11, epcoh: 4, training_loss: 49575290496.61539, validation_loss: 4560942592.0, accuracy: 0.5
Method: vgg11, epcoh: 5, training_loss: 2702976898667.6924, validation_loss: 1383715045376.0, accuracy: 0.5
Method: vgg11, epcoh: 6, training_loss: 150106569235.69232, validation_loss: 15249518362624.0, accuracy: 0.5
Method: vgg11, epcoh: 7, training_loss: 917871821968.0, validation_loss: 63801581568.0, accuracy: 0.5
Method: vgg11, epcoh: 8, training_loss: 4550495393.961538, validation_loss: 11490113536.0, accuracy: 0.5
Method: vgg11, epcoh: 9, training_loss: 34629

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Method: vgg11, epcoh: 0, training_loss: 5860.79249059237, validation_loss: 183882.46875, accuracy: 0.5
Method: vgg11, epcoh: 1, training_loss: 11610.586405313932, validation_loss: 42125.4921875, accuracy: 0.5
Method: vgg11, epcoh: 2, training_loss: 2510.9661285258258, validation_loss: 14926.1455078125, accuracy: 0.5
Method: vgg11, epcoh: 3, training_loss: 1345.584905220912, validation_loss: 36200.8046875, accuracy: 0.5
Method: vgg11, epcoh: 4, training_loss: 1903.4955755380483, validation_loss: 5415.2724609375, accuracy: 0.5
Method: vgg11, epcoh: 5, training_loss: 459.3778746128082, validation_loss: 5885.0791015625, accuracy: 0.5
Method: vgg11, epcoh: 6, training_loss: 464.3881662866513, validation_loss: 10417.59765625, accuracy: 0.5
Method: vgg11, epcoh: 7, training_loss: 672.0488331806201, validation_loss: 9772.900390625, accuracy: 0.5
Method: vgg11, epcoh: 8, training_loss: 786.514889056866, validation_loss: 4605.71484375, accuracy: 0.5
Method: vgg11, epcoh: 9, training_loss: 305.17

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Method: resnet18, epcoh: 0, training_loss: 16.575418563320106, validation_loss: 20.831039428710938, accuracy: 0.5
Method: resnet18, epcoh: 1, training_loss: 5.09189942453304, validation_loss: 165.57723999023438, accuracy: 0.5
Method: resnet18, epcoh: 2, training_loss: 4.192355534654871, validation_loss: 172.0626983642578, accuracy: 0.5
Method: resnet18, epcoh: 3, training_loss: 4.55914563244499, validation_loss: 8.745125770568848, accuracy: 0.5
Method: resnet18, epcoh: 4, training_loss: 4.405631171926002, validation_loss: 4.328950881958008, accuracy: 0.5
Method: resnet18, epcoh: 5, training_loss: 2.820039151435637, validation_loss: 18.067476272583008, accuracy: 0.5
Method: resnet18, epcoh: 6, training_loss: 5.098500187750295, validation_loss: 83.21969604492188, accuracy: 0.5
Method: resnet18, epcoh: 7, training_loss: 5.471969490570446, validation_loss: 4.010058403015137, accuracy: 0.5
Method: resnet18, epcoh: 8, training_loss: 3.9068411493707567, validation_loss: 37.08466720581055, acc

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Method: resnet18, epcoh: 0, training_loss: 1.2929516525295084, validation_loss: 17683.666015625, accuracy: 0.5
Method: resnet18, epcoh: 1, training_loss: 4.212694412476804, validation_loss: 1.1532440185546875, accuracy: 0.5
Method: resnet18, epcoh: 2, training_loss: 1.0190181457079375, validation_loss: 15.427924156188965, accuracy: 0.5
Method: resnet18, epcoh: 3, training_loss: 0.9253433341017137, validation_loss: 17.67323112487793, accuracy: 0.5
Method: resnet18, epcoh: 4, training_loss: 0.8278463213489606, validation_loss: 0.9652955532073975, accuracy: 0.5
Method: resnet18, epcoh: 5, training_loss: 0.7973280720985852, validation_loss: 0.7313010096549988, accuracy: 0.5
Method: resnet18, epcoh: 6, training_loss: 0.7204252790946227, validation_loss: 0.6678750514984131, accuracy: 0.5
Method: resnet18, epcoh: 7, training_loss: 0.7067021045547265, validation_loss: 0.6746976375579834, accuracy: 0.5
Method: resnet18, epcoh: 8, training_loss: 0.7008995703206613, validation_loss: 0.67856502532

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Method: resnet18, epcoh: 0, training_loss: 0.6432342769101245, validation_loss: 5.838571548461914, accuracy: 0.5
Method: resnet18, epcoh: 1, training_loss: 1.4077180109631557, validation_loss: 0.8475325107574463, accuracy: 0.5
Method: resnet18, epcoh: 2, training_loss: 1.1311612576246262, validation_loss: 0.6919293403625488, accuracy: 0.5
Method: resnet18, epcoh: 3, training_loss: 1.0030908223528128, validation_loss: 0.7144464254379272, accuracy: 0.5
Method: resnet18, epcoh: 4, training_loss: 0.9733205431929002, validation_loss: 0.6317487955093384, accuracy: 0.5
Method: resnet18, epcoh: 5, training_loss: 0.9406625826198322, validation_loss: 0.6932646036148071, accuracy: 0.5
Method: resnet18, epcoh: 6, training_loss: 0.9825828745961189, validation_loss: 0.6884805560112, accuracy: 0.5
Method: resnet18, epcoh: 7, training_loss: 0.9374059066176414, validation_loss: 0.6753637790679932, accuracy: 0.5
Method: resnet18, epcoh: 8, training_loss: 0.9100435490791614, validation_loss: 0.6857782006

In [None]:
import pandas as pd

# One row per (backbone, learning rate) run from the sweep above.
results_table = pd.DataFrame.from_records(df_results)
results_table

Unnamed: 0,Method,Learning_rate,test_loss,accuracy,sensitivity,specificity
0,alexnet,0.1,106.89,0.53125,0.625,0.4375
1,alexnet,0.01,0.6898298,0.5,1.0,0.0
2,alexnet,0.001,4.659916,0.703125,0.96875,0.4375
3,vgg11,0.1,3.688562e+24,0.546875,0.8125,0.28125
4,vgg11,0.01,88794540.0,0.5,1.0,0.0
5,vgg11,0.001,3.328768,0.5,1.0,0.0
6,resnet18,0.1,0.7252467,0.5,1.0,0.0
7,resnet18,0.01,0.6870048,0.5,1.0,0.0
8,resnet18,0.001,0.7893061,0.5,1.0,0.0


In [None]:
# Pinned to the version this notebook was run with (see install log); %pip targets the kernel's environment.
%pip install -q gradio==3.13.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gradio
  Downloading gradio-3.13.0-py3-none-any.whl (13.8 MB)
[K     |████████████████████████████████| 13.8 MB 9.5 MB/s 
Collecting websockets>=10.0
  Downloading websockets-10.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (106 kB)
[K     |████████████████████████████████| 106 kB 43.1 MB/s 
[?25hCollecting fastapi
  Downloading fastapi-0.88.0-py3-none-any.whl (55 kB)
[K     |████████████████████████████████| 55 kB 1.7 MB/s 
Collecting orjson
  Downloading orjson-3.8.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (278 kB)
[K     |████████████████████████████████| 278 kB 76.0 MB/s 
Collecting python-multipart
  Downloading python-multipart-0.0.5.tar.gz (32 kB)
Collecting ffmpy
  Downloading ffmpy-0.3.0.tar.gz (4.8 kB)
Collecting pycryptodome
  Downloading pycryptodome-3.16.0-cp35-abi3-manylinux_2_5_x86_64.manyl

In [None]:
from PIL import Image
import gradio as gr
import numpy as np

# Checkpoint served by the demo: the AlexNet run with lr=1e-3 (best test
# accuracy in the sweep results). The path follows the training cell's layout
# logs/<experiment>/<backbone>/lr_<lr>/checkpoints/best_model.pt.
DEMO_CHECKPOINT = os.path.join('logs', '1', 'alexnet', 'lr_0.001', 'checkpoints', 'best_model.pt')
DEMO_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Mirror the training geometry (transforms.Resize((224, 224)), no cropping).
# NOTE(review): ToTensor scales pixels to [0, 1]; confirm CustomImageDataset
# feeds the network the same value range during training.
demo_preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

_demo_models = {}  # cache so the network is built and loaded only once


def _get_demo_model():
    """Build the AlexNet classifier once, load the trained weights, and cache it.

    Raises:
        FileNotFoundError: if the training sweep has not produced DEMO_CHECKPOINT.
    """
    if 'alexnet' not in _demo_models:
        if not os.path.exists(DEMO_CHECKPOINT):
            raise FileNotFoundError(
                f'No trained checkpoint at {DEMO_CHECKPOINT}; run the training sweep first.')
        model = CustomNetwork(num_classes=1,
                              loss_fn=torch.nn.BCEWithLogitsLoss(reduction='mean'),
                              device=DEMO_DEVICE, threshold=0.5, backbone='alexnet').to(DEMO_DEVICE)
        checkpoint = torch.load(DEMO_CHECKPOINT, map_location=DEMO_DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        # eval() turns off train-time layers such as dropout, so predictions are deterministic.
        model.eval()
        _demo_models['alexnet'] = model
    return _demo_models['alexnet']


def classify_object_alexnet(numpy_image):
    """Classify an image as Pneumonia vs Normal with the trained AlexNet model.

    Args:
        numpy_image: image array from the Gradio "image" input; it is cast to
            uint8 and converted to 3-channel RGB before preprocessing.

    Returns:
        dict mapping "Pneumonia" and "Normal" to probabilities that sum to 1,
        the format consumed by Gradio's "label" output.
    """
    img = Image.fromarray(np.uint8(numpy_image)).convert('RGB')
    batch = demo_preprocess(img).unsqueeze(0).to(DEMO_DEVICE)  # add batch dimension

    model = _get_demo_model()
    with torch.no_grad():
        logit = model(batch)
    # Single-logit binary head: sigmoid gives the probability of class 1.
    # NOTE(review): assumes label 1 == pneumonia in annotation_df -- confirm.
    pneumonia_prob = torch.sigmoid(logit).item()

    output_dictionary = {"Pneumonia": pneumonia_prob, "Normal": 1 - pneumonia_prob}
    print("Label - Probability")
    for key, val in output_dictionary.items():
        print(f'{key}: {val}')
    return output_dictionary


gr.Interface(fn=classify_object_alexnet, inputs="image", outputs="label").launch(debug=True)



Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Note: opening Chrome Inspector may crash demo inside Colab notebooks.

To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Label - Probability
Pneumonia: 0.31117573380470276
Normal: 0.6888242661952972
