In [1]:
import time
start_time = time.time()

In [2]:
import numpy as np
import pandas as pd
import sys
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser
from PIL import Image
from tqdm import tqdm

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
import itertools

import torchvision
from torchvision import transforms, datasets, models
from torch import Tensor
from torch.utils.data.sampler import SubsetRandomSampler


from torch.utils.tensorboard import SummaryWriter

from sklearn.model_selection import ShuffleSplit

from resnet import resnet32

In [3]:
class medical_dataset(Dataset):
    def __init__(self, covid_path, normal_path, transform=None):
        self.covid_path = covid_path
        self.normal_path = normal_path

        self.transform = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        self.covid_img = listdir(self.covid_path)
        self.normal_img = listdir(self.normal_path)
        
        print("Class imbalance is: {}".format(len(self.normal_img)/len(self.covid_img)))
        
        self.indices=np.arange(len(self.covid_img)+len(self.normal_img))

    def __len__(self):
        return len(self.normal_img) + len(self.covid_img)

    def __getitem__(self, idx):

        if (idx + 1 <= len(self.normal_img)):
            img = self.transform(join(self.normal_path, self.normal_img[idx]))
            label = 0
        else:
            idx = idx - len(self.normal_img)
            img = self.transform(join(self.covid_path, self.covid_img[idx]))
            label = 1

        return img, torch.tensor([label]).float()

In [4]:
dataset = medical_dataset(
    covid_path='./dataset/archive/COVID-19_Radiography_Dataset/COVID/',
    normal_path='./dataset/archive/COVID-19_Radiography_Dataset/Normal/',
)

X_unshuffled = dataset.indices
rs = ShuffleSplit(n_splits=1, test_size=.2, random_state=32)
rs.get_n_splits(X_unshuffled)

train_ind = []
val_ind = []
for train_index, test_index in rs.split(X_unshuffled):
    train_ind.append(train_index)
    val_ind.append(test_index)

Class imbalance is: 2.8185840707964602


In [5]:
train_sampler = SubsetRandomSampler(train_ind[0].tolist())
test_sampler = SubsetRandomSampler(val_ind[0].tolist())

train_loader = DataLoader(dataset, batch_size=32, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=32, sampler=test_sampler)

In [6]:
class MiniONN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv2 = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv5 = nn.Conv2d(64, 64, 3, 1, 1)
        
        self.conv6 = nn.Conv2d(64, 64, 1, 1, 0)
        self.conv7 = nn.Conv2d(64, 16, 1, 1, 0)
        
        self.fc = nn.Linear(1024, 1)
        
        self.avg1 = nn.AvgPool2d(2, 2)
        self.avg2 = nn.AvgPool2d(2, 2)
        
        
    def forward(self, x):
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        
        h = self.avg1(h)
        
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv4(h))
        
        h = self.avg2(h)
        
        h = F.relu(self.conv5(h))
        h = F.relu(self.conv6(h))
        h = F.relu(self.conv7(h))
        
        h = h.view(-1, 1024)
        h = self.fc(h)
        
        return h
        

In [7]:
model_ft = MiniONN()

In [14]:
model_ft.load_state_dict(torch.load("./models/minionn/checkpoint.pt"))

<All keys matched successfully>

In [15]:
torch.save(model_ft.cpu(), "./models/minionn/checkpoint_cpu_cpu.pt")

In [8]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self,
                 patience=7,
                 verbose=False,
                 delta=0,
                 path='./models/minionn/checkpoint.pt',
                 trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(
                f'EarlyStopping counter: {self.counter} out of {self.patience}'
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [9]:

class Trainer:
    def __init__(self,
                 trainloader,
                 vallaoder,
                 model_ft,
                 writer=None,
                 testloader=None,
                 checkpoint_path=None,
                 patience=5,
                 feature_extract=True,
                 print_itr=50):
        self.trainloader = trainloader
        self.valloader = vallaoder
        self.testloader = testloader

#         self.device = torch.device(
#             "cuda:0" if torch.cuda.is_available() else "cpu")

        self.device = torch.device("cpu")

        print("==" * 10)
        print("Training will be done on ", self.device)
        print("==" * 10)

        self.model = model_ft
        self.model = self.model.to(self.device)

        # Observe that all parameters are being optimized
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001)
        self.criterion = nn.BCELoss()

        self.early_stopping = EarlyStopping(patience=patience, verbose=True)
        self.writer = writer
        self.print_itr = print_itr

    def train(self, ep):
        self.model.train()

        running_loss = 0.0
        train_tqdm = tqdm(self.trainloader)

        for en, (x, y) in enumerate(train_tqdm):
            x = x.to(self.device)
            y = y.to(self.device)

            self.optimizer.zero_grad()

            outputs = F.sigmoid(self.model(x))
            loss = self.criterion(outputs, y)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            train_tqdm.set_description("Loss: {}".format(running_loss))
            running_loss = 0

            # print statistics


#             running_loss += loss.item()
#             if (en + 1) % self.print_itr == 0:
#                 print('[%d, %5d] loss: %.3f' %
#                       (ep, en + 1, running_loss / self.print_itr))
#                 running_loss = 0.0

    def validate(self, ep):
        self.model.eval()

        total = 0
        correct = 0
        running_loss = 0.0
        for en, (x, y) in enumerate(tqdm(self.valloader)):

            x = x.to(self.device)
            y = y.to(self.device)

            outputs = F.sigmoid(self.model(x))
            loss = self.criterion(outputs, y)

            predicted = torch.tensor([1 if outputs[i]>=0.5 else 0 for i in range(outputs.shape[0])])

            total += y.size(0)
            correct += (predicted.squeeze() == y.cpu().squeeze()).sum().item()

            # print statistics
            running_loss += loss.item()
        return running_loss / len(self.valloader), correct * 100 / total

    def evaluate(self, ep, dataloader):
        self.model.eval()

        total = 0
        correct = 0
        for en, (x, y) in enumerate(tqdm(dataloader)):

            x = x.to(self.device)
            y = y.to(self.device)

            outputs = F.sigmoid(self.model(x))
            predicted = torch.tensor([1 if outputs[i]>=0.5 else 0 for i in range(outputs.shape[0])])
#             print(predicted.shape)
            total += y.size(0)
            correct += (predicted.squeeze() == y.cpu().squeeze()).sum().item()

        return correct * 100 / total

    def perform_training(self, total_epoch):
        val_loss, acc = self.validate(0)

        print("[Initial Validation results] Loss: {} \t Acc:{}".format(
            val_loss, acc))

        for i in range(total_epoch):
            self.train(i + 1)
            val_loss, acc = self.validate(i + 1)
            #             acc = self.evaluate(i+1, self.valloader)
            print("Epoch:{} \t Accuracy:{}".format(i+1, acc))
            if self.writer:
                self.writer.add_scalar('Validation Loss', val_loss, (i + 1))
                self.writer.add_scalar('Validation Acc', acc, (i + 1))

            self.early_stopping(val_loss, self.model)

            if self.early_stopping.early_stop:
                print("Early stopping")
                break

        print("=" * 20)
        print("Training finished !!")
        print("=" * 20)

In [10]:
writer = SummaryWriter('runs/minionn')
trainer = Trainer(train_loader, test_loader, model_ft, writer=writer)

Training will be done on  cpu


In [11]:
trainer.perform_training(30)

100%|██████████| 87/87 [00:04<00:00, 19.45it/s]
Loss: 0.682716429233551:   0%|          | 1/346 [00:00<00:54,  6.37it/s]

[Initial Validation results] Loss: 0.6851303200612123 	 Acc:74.04055032585083


Loss: 0.5923810005187988: 100%|██████████| 346/346 [00:26<00:00, 13.10it/s] 
100%|██████████| 87/87 [00:04<00:00, 19.86it/s]
Loss: 0.7306787967681885:   1%|          | 2/346 [00:00<00:25, 13.65it/s]

Epoch:1 	 Accuracy:75.74221578566257
Validation loss decreased (inf --> 0.477536).  Saving model ...


Loss: 0.8527958393096924: 100%|██████████| 346/346 [00:22<00:00, 15.06it/s] 
100%|██████████| 87/87 [00:04<00:00, 19.85it/s]
Loss: 0.6016228199005127:   1%|          | 2/346 [00:00<00:23, 14.46it/s]

Epoch:2 	 Accuracy:83.20057929036929
Validation loss decreased (0.477536 --> 0.385155).  Saving model ...


Loss: 0.31926199793815613: 100%|██████████| 346/346 [00:22<00:00, 15.06it/s]
100%|██████████| 87/87 [00:04<00:00, 19.90it/s]
Loss: 0.29629406332969666:   1%|          | 2/346 [00:00<00:23, 14.36it/s]

Epoch:3 	 Accuracy:86.42288196958725
Validation loss decreased (0.385155 --> 0.328742).  Saving model ...


Loss: 0.1743118017911911: 100%|██████████| 346/346 [00:22<00:00, 15.10it/s] 
100%|██████████| 87/87 [00:04<00:00, 19.91it/s]
Loss: 0.3928096294403076:   1%|          | 2/346 [00:00<00:24, 14.22it/s]

Epoch:4 	 Accuracy:86.13323678493845
EarlyStopping counter: 1 out of 5


Loss: 0.6721366047859192: 100%|██████████| 346/346 [00:23<00:00, 14.85it/s] 
100%|██████████| 87/87 [00:04<00:00, 20.07it/s]
Loss: 0.2701786756515503:   1%|          | 2/346 [00:00<00:24, 14.24it/s]

Epoch:5 	 Accuracy:87.21940622737146
Validation loss decreased (0.328742 --> 0.286246).  Saving model ...


Loss: 0.22350013256072998: 100%|██████████| 346/346 [00:23<00:00, 14.80it/s]
100%|██████████| 87/87 [00:04<00:00, 19.87it/s]
Loss: 0.5423309803009033:   1%|          | 2/346 [00:00<00:25, 13.61it/s]

Epoch:6 	 Accuracy:87.69007965242578
Validation loss decreased (0.286246 --> 0.281410).  Saving model ...


Loss: 0.1103893518447876: 100%|██████████| 346/346 [00:23<00:00, 14.65it/s] 
100%|██████████| 87/87 [00:04<00:00, 19.88it/s]
Loss: 0.3503119647502899:   1%|          | 2/346 [00:00<00:25, 13.39it/s]

Epoch:7 	 Accuracy:89.10209992758871
Validation loss decreased (0.281410 --> 0.274759).  Saving model ...


Loss: 0.22305792570114136: 100%|██████████| 346/346 [00:23<00:00, 14.69it/s]
100%|██████████| 87/87 [00:04<00:00, 19.78it/s]
Loss: 0.24596792459487915:   1%|          | 2/346 [00:00<00:24, 13.79it/s]

Epoch:8 	 Accuracy:88.95727733526431
Validation loss decreased (0.274759 --> 0.267993).  Saving model ...


Loss: 0.03144235908985138: 100%|██████████| 346/346 [00:23<00:00, 14.51it/s]
100%|██████████| 87/87 [00:04<00:00, 19.57it/s]
Loss: 0.26645973324775696:   1%|          | 2/346 [00:00<00:24, 14.13it/s]

Epoch:9 	 Accuracy:88.7762490948588
EarlyStopping counter: 1 out of 5


Loss: 0.3038310408592224: 100%|██████████| 346/346 [00:23<00:00, 14.70it/s]  
100%|██████████| 87/87 [00:04<00:00, 20.18it/s]
Loss: 0.07431577891111374:   0%|          | 1/346 [00:00<00:40,  8.59it/s]

Epoch:10 	 Accuracy:89.46415640839972
Validation loss decreased (0.267993 --> 0.253791).  Saving model ...


Loss: 0.10680681467056274: 100%|██████████| 346/346 [00:23<00:00, 14.91it/s]
100%|██████████| 87/87 [00:04<00:00, 19.70it/s]
Loss: 0.24118764698505402:   1%|          | 2/346 [00:00<00:24, 13.86it/s]

Epoch:11 	 Accuracy:89.64518464880521
EarlyStopping counter: 1 out of 5


Loss: 0.37095174193382263: 100%|██████████| 346/346 [00:23<00:00, 14.86it/s]
100%|██████████| 87/87 [00:04<00:00, 19.85it/s]
Loss: 0.17621935904026031:   1%|          | 2/346 [00:00<00:24, 14.03it/s]

Epoch:12 	 Accuracy:90.07965242577842
EarlyStopping counter: 2 out of 5


Loss: 0.19959914684295654: 100%|██████████| 346/346 [00:23<00:00, 14.72it/s]
100%|██████████| 87/87 [00:04<00:00, 19.30it/s]
Loss: 0.26152029633522034:   1%|          | 2/346 [00:00<00:25, 13.58it/s]

Epoch:13 	 Accuracy:88.74004344677769
EarlyStopping counter: 3 out of 5


Loss: 0.13465173542499542: 100%|██████████| 346/346 [00:23<00:00, 14.61it/s]
100%|██████████| 87/87 [00:04<00:00, 20.04it/s]
Loss: 0.28691455721855164:   1%|          | 2/346 [00:00<00:25, 13.58it/s]

Epoch:14 	 Accuracy:89.82621288921072
Validation loss decreased (0.253791 --> 0.248801).  Saving model ...


Loss: 0.07968512922525406: 100%|██████████| 346/346 [00:23<00:00, 14.42it/s]
100%|██████████| 87/87 [00:04<00:00, 20.13it/s]
Loss: 0.22770081460475922:   1%|          | 2/346 [00:00<00:25, 13.41it/s]

Epoch:15 	 Accuracy:90.11585807385953
Validation loss decreased (0.248801 --> 0.242016).  Saving model ...


Loss: 0.4892733097076416: 100%|██████████| 346/346 [00:24<00:00, 14.25it/s] 
100%|██████████| 87/87 [00:04<00:00, 20.03it/s]
Loss: 0.2366361767053604:   1%|          | 2/346 [00:00<00:24, 13.94it/s]

Epoch:16 	 Accuracy:90.26068066618393
Validation loss decreased (0.242016 --> 0.241898).  Saving model ...


Loss: 0.08033476024866104: 100%|██████████| 346/346 [00:23<00:00, 14.49it/s] 
100%|██████████| 87/87 [00:04<00:00, 20.15it/s]
Loss: 0.2778118848800659:   0%|          | 1/346 [00:00<00:41,  8.22it/s] 

Epoch:17 	 Accuracy:90.18826937002173
Validation loss decreased (0.241898 --> 0.236457).  Saving model ...


Loss: 0.010079403407871723: 100%|██████████| 346/346 [00:24<00:00, 14.36it/s]
100%|██████████| 87/87 [00:04<00:00, 19.28it/s]
Loss: 0.10027806460857391:   1%|          | 2/346 [00:00<00:25, 13.55it/s]

Epoch:18 	 Accuracy:89.42795076031861
EarlyStopping counter: 1 out of 5


Loss: 0.01889205537736416: 100%|██████████| 346/346 [00:24<00:00, 14.29it/s]
100%|██████████| 87/87 [00:04<00:00, 19.99it/s]
Loss: 0.18170005083084106:   1%|          | 2/346 [00:00<00:24, 13.83it/s]

Epoch:19 	 Accuracy:89.31933381607531
EarlyStopping counter: 2 out of 5


Loss: 0.1949135810136795: 100%|██████████| 346/346 [00:23<00:00, 14.53it/s]  
100%|██████████| 87/87 [00:04<00:00, 19.54it/s]
Loss: 0.11915119737386703:   1%|          | 2/346 [00:00<00:25, 13.67it/s]

Epoch:20 	 Accuracy:90.80376538740043
Validation loss decreased (0.236457 --> 0.224152).  Saving model ...


Loss: 0.0917714536190033: 100%|██████████| 346/346 [00:23<00:00, 14.87it/s] 
100%|██████████| 87/87 [00:04<00:00, 19.70it/s]
Loss: 0.26502206921577454:   1%|          | 2/346 [00:00<00:24, 13.97it/s]

Epoch:21 	 Accuracy:88.848660391021
EarlyStopping counter: 1 out of 5


Loss: 0.5337105393409729: 100%|██████████| 346/346 [00:23<00:00, 14.87it/s] 
100%|██████████| 87/87 [00:04<00:00, 19.58it/s]
Loss: 0.2321438044309616:   1%|          | 2/346 [00:00<00:25, 13.39it/s]

Epoch:22 	 Accuracy:91.09341057204924
Validation loss decreased (0.224152 --> 0.222573).  Saving model ...


Loss: 0.2796437442302704: 100%|██████████| 346/346 [00:23<00:00, 14.73it/s] 
100%|██████████| 87/87 [00:04<00:00, 19.53it/s]
Loss: 0.3258102238178253:   1%|          | 2/346 [00:00<00:25, 13.66it/s]

Epoch:23 	 Accuracy:91.20202751629255
EarlyStopping counter: 1 out of 5


Loss: 0.021579839289188385: 100%|██████████| 346/346 [00:23<00:00, 14.63it/s]
100%|██████████| 87/87 [00:04<00:00, 19.57it/s]
Loss: 0.10655373334884644:   0%|          | 1/346 [00:00<00:42,  8.18it/s]

Epoch:24 	 Accuracy:91.02099927588704
EarlyStopping counter: 2 out of 5


Loss: 0.18538106977939606: 100%|██████████| 346/346 [00:23<00:00, 14.67it/s] 
100%|██████████| 87/87 [00:04<00:00, 19.67it/s]
Loss: 0.13690230250358582:   1%|          | 2/346 [00:00<00:25, 13.27it/s]

Epoch:25 	 Accuracy:90.91238233164374
EarlyStopping counter: 3 out of 5


Loss: 0.10920313000679016: 100%|██████████| 346/346 [00:23<00:00, 14.57it/s] 
100%|██████████| 87/87 [00:04<00:00, 19.74it/s]
Loss: 0.20582206547260284:   1%|          | 2/346 [00:00<00:25, 13.61it/s]

Epoch:26 	 Accuracy:90.76755973931934
EarlyStopping counter: 4 out of 5


Loss: 0.06214547157287598: 100%|██████████| 346/346 [00:23<00:00, 14.64it/s] 
100%|██████████| 87/87 [00:04<00:00, 20.00it/s]
Loss: 0.20739884674549103:   1%|          | 2/346 [00:00<00:24, 13.98it/s]

Epoch:27 	 Accuracy:91.52787834902244
Validation loss decreased (0.222573 --> 0.218269).  Saving model ...


Loss: 0.19360971450805664: 100%|██████████| 346/346 [00:23<00:00, 14.60it/s] 
100%|██████████| 87/87 [00:04<00:00, 19.96it/s]
Loss: 0.09204600006341934:   1%|          | 2/346 [00:00<00:25, 13.48it/s]

Epoch:28 	 Accuracy:90.80376538740043
EarlyStopping counter: 1 out of 5


Loss: 0.15398919582366943: 100%|██████████| 346/346 [00:23<00:00, 14.61it/s] 
100%|██████████| 87/87 [00:04<00:00, 19.27it/s]
Loss: 0.1832505315542221:   1%|          | 2/346 [00:00<00:25, 13.36it/s]

Epoch:29 	 Accuracy:91.67270094134685
EarlyStopping counter: 2 out of 5


Loss: 0.39062440395355225: 100%|██████████| 346/346 [00:23<00:00, 14.56it/s] 
100%|██████████| 87/87 [00:04<00:00, 19.73it/s]

Epoch:30 	 Accuracy:88.26937002172339
EarlyStopping counter: 3 out of 5
Training finished !!





In [12]:
# trainer.evaluate(0, test_loader)

In [13]:
end_time = time.time()
print('Total execution training time: ', end_time-start_time)

Total execution training time:  847.6523497104645
