In [1]:
import time
# Wall-clock reference taken when this first cell runs; the last cell subtracts
# it from a later time.time() reading to report the elapsed time.
start_time = time.time()

In [2]:
import numpy as np
import pandas as pd
import sys
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser
from PIL import Image
from tqdm import tqdm

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
import itertools

import torchvision
from torchvision import transforms, datasets, models
from torch import Tensor
from torch.utils.data.sampler import SubsetRandomSampler


from torch.utils.tensorboard import SummaryWriter

from sklearn.model_selection import ShuffleSplit

from resnet import resnet32

In [3]:
class medical_dataset(Dataset):
    """Two-class image dataset built from one folder per class.

    Samples are ordered with every file of ``normal_path`` first (label 0),
    followed by every file of ``covid_path`` (label 1). File names are sorted
    so that a given integer index always refers to the same file, which makes
    index-based train/validation splits reproducible across machines.

    Args:
        covid_path (str): Directory whose files are the positive (label 1) samples.
        normal_path (str): Directory whose files are the negative (label 0) samples.
        transform (callable, optional): Applied to the RGB ``PIL.Image`` of each
            sample. When ``None`` the default pipeline is used: resize to
            32x32, convert to tensor, normalize with the mean/std constants below.

    Attributes:
        covid_img / normal_img: Sorted file names found in each directory.
        indices: ``np.arange(len(self))``, convenient input for index splitters.
    """

    def __init__(self, covid_path, normal_path, transform=None):
        self.covid_path = covid_path
        self.normal_path = normal_path

        # Honor a caller-supplied transform; previously the argument was ignored.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((32, 32)),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
        self.transform = transform

        # listdir() returns entries in arbitrary order; sort for determinism.
        self.covid_img = sorted(listdir(self.covid_path))
        self.normal_img = sorted(listdir(self.normal_path))

        # Guard the ratio so an empty positive folder does not crash construction.
        if self.covid_img:
            imbalance = len(self.normal_img) / len(self.covid_img)
        else:
            imbalance = float('inf')
        print("Class imbalance is: {}".format(imbalance))

        self.indices = np.arange(len(self.covid_img) + len(self.normal_img))

    def __len__(self):
        """Return the total number of samples over both classes."""
        return len(self.normal_img) + len(self.covid_img)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Args:
            idx (int): Sample position; negative values count from the end.

        Returns:
            tuple: The transformed image and a float tensor of shape ``(1,)``
            holding 0.0 (normal) or 1.0 (covid).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            raise IndexError("index {} is out of range for dataset of size {}".format(idx, size))

        n_normal = len(self.normal_img)
        if idx < n_normal:
            img_path = join(self.normal_path, self.normal_img[idx])
            label = 0
        else:
            img_path = join(self.covid_path, self.covid_img[idx - n_normal])
            label = 1

        # convert() returns a fully loaded copy, so the file can be closed here.
        with Image.open(img_path) as handle:
            img = handle.convert('RGB')

        return self.transform(img), torch.tensor([label]).float()

In [4]:
# Build the dataset from the two class folders.
dataset = medical_dataset(
    normal_path='./dataset/archive/COVID-19_Radiography_Dataset/Normal/',
    covid_path='./dataset/archive/COVID-19_Radiography_Dataset/COVID/',
)

# Single seeded 80/20 shuffle split over the sample indices.
X_unshuffled = dataset.indices
rs = ShuffleSplit(n_splits=1, test_size=.2, random_state=32)
rs.get_n_splits(X_unshuffled)  # return value is not used

# One entry per split; with n_splits=1 each list ends up holding one index array.
train_ind, val_ind = [], []
for train_index, test_index in rs.split(X_unshuffled):
    train_ind += [train_index]
    val_ind += [test_index]

Class imbalance is: 2.8185840707964602


In [5]:
# One random sampler per index subset of the (only) split.
train_sampler, test_sampler = (
    SubsetRandomSampler(split_indices[0].tolist())
    for split_indices in (train_ind, val_ind)
)

# Both loaders read from the same dataset object; the sampler picks the subset.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=32, sampler=subset_sampler)
    for subset_sampler in (train_sampler, test_sampler)
)

In [12]:
# Build the ResNet-32 from the local `resnet` module. The argument 1 is
# presumably the number of output units (one logit for the binary task) —
# TODO confirm against resnet.resnet32's signature.
resnet_model = resnet32(1)

CifarResNet : Depth : 32 , Layers for each block : 5


In [13]:
# resnet_model.cuda()

In [14]:
# Restore the weights from the checkpoint file (the same path EarlyStopping
# writes to by default). map_location="cpu" lets a checkpoint saved from a
# CUDA model load on a CPU-only machine.
resnet_model.load_state_dict(
    torch.load("./models/covid_resnet32/checkpoint.pt", map_location="cpu")
)

<All keys matched successfully>

In [15]:
# Move the model to CPU and save the whole module object (not a state_dict)
# to a separate file.
# NOTE(review): loading a fully pickled module later needs the model's class
# definition to be importable — confirm that this format is intended.
torch.save(resnet_model.cpu(), "./models/covid_resnet32/checkpoint_cpu_cpu.pt")

In [7]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self,
                 patience=7,
                 verbose=False,
                 delta=0,
                 path='./models/covid_resnet32/checkpoint.pt',
                 trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: './models/covid_resnet32/checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf rather than np.Inf: the capitalized alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation result and update the stopping state.

        An improvement (first finite loss, or a score at least ``delta`` better
        than the best so far) saves a checkpoint and resets the counter. Any
        other result, including a NaN loss, increments the counter and sets
        ``early_stop`` once it reaches ``patience``.

        Args:
            val_loss (float): Validation loss of the epoch just finished.
            model: Object exposing ``state_dict()``; saved on improvement.
        """
        score = -val_loss

        # NaN compares False against everything, so without this check it would
        # be stored as the best score and every later epoch would then look
        # like an improvement.
        is_nan = score != score
        improved = (not is_nan) and (
            self.best_score is None or score >= self.best_score + self.delta
        )

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
            return

        self.counter += 1
        self.trace_func(
            f'EarlyStopping counter: {self.counter} out of {self.patience}'
        )
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        import os

        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...'
            )
        # Create the target directory on first use so a fresh checkout does not
        # fail at the end of the first epoch.
        directory = os.path.dirname(self.path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [8]:
from tqdm.notebook import tqdm 

class Trainer:
    """Train and validate a single-output binary classifier with early stopping.

    The model is expected to return one raw score per sample; a sigmoid is
    applied here and the result is fed to ``nn.BCELoss``. Predictions are
    obtained by thresholding the sigmoid output at 0.5.

    Args:
        trainloader: Iterable of ``(x, y)`` batches used for optimization.
        vallaoder: Iterable of ``(x, y)`` batches used for validation
            (parameter name kept as-is for backward compatibility).
        model_ft: The ``nn.Module`` to train.
        writer: Optional object with ``add_scalar(tag, value, step)``; receives
            validation loss and accuracy once per epoch.
        testloader: Stored on the instance; not used by the methods below.
        checkpoint_path: Where the best weights are saved. ``None`` keeps the
            default path of ``EarlyStopping``.
        patience: Epochs without improvement tolerated before stopping.
        feature_extract: Accepted for interface compatibility; unused.
        print_itr: Stored on the instance; unused by the current training loop.
    """

    def __init__(self,
                 trainloader,
                 vallaoder,
                 model_ft,
                 writer=None,
                 testloader=None,
                 checkpoint_path=None,
                 patience=5,
                 feature_extract=True,
                 print_itr=50):
        self.trainloader = trainloader
        self.valloader = vallaoder
        self.testloader = testloader

        # Training is pinned to the CPU in this notebook.
        self.device = torch.device("cpu")

        print("==" * 10)
        print("Training will be done on ", self.device)
        print("==" * 10)

        self.model = model_ft.to(self.device)

        # Observe that all parameters are being optimized
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001)
        self.criterion = nn.BCELoss()

        # Forward checkpoint_path when given; it used to be silently ignored.
        stopper_kwargs = {"patience": patience, "verbose": True}
        if checkpoint_path is not None:
            stopper_kwargs["path"] = checkpoint_path
        self.early_stopping = EarlyStopping(**stopper_kwargs)

        self.writer = writer
        self.print_itr = print_itr

    @staticmethod
    def _count_correct(outputs, y):
        """Return how many sigmoid outputs, thresholded at 0.5, equal the labels."""
        predicted = (outputs >= 0.5).to(y.dtype).reshape(-1)
        return (predicted == y.reshape(-1)).sum().item()

    def train(self, ep):
        """Run one optimization pass over ``trainloader``.

        The progress bar description shows the loss of the current batch.

        Args:
            ep (int): Epoch number; accepted for symmetry, not used.
        """
        self.model.train()

        train_tqdm = tqdm(self.trainloader)
        for x, y in train_tqdm:
            x = x.to(self.device)
            y = y.to(self.device)

            self.optimizer.zero_grad()

            outputs = torch.sigmoid(self.model(x))
            loss = self.criterion(outputs, y)
            loss.backward()
            self.optimizer.step()

            train_tqdm.set_description("Loss: {}".format(loss.item()))

    def validate(self, ep):
        """Evaluate on ``valloader``.

        Args:
            ep (int): Epoch number; accepted for symmetry, not used.

        Returns:
            tuple: ``(mean loss per batch, accuracy in percent)``.
        """
        self.model.eval()

        total = 0
        correct = 0
        running_loss = 0.0
        # No gradients are needed here; skipping graph construction saves
        # memory and time.
        with torch.no_grad():
            for x, y in tqdm(self.valloader):
                x = x.to(self.device)
                y = y.to(self.device)

                outputs = torch.sigmoid(self.model(x))
                running_loss += self.criterion(outputs, y).item()

                total += y.size(0)
                correct += self._count_correct(outputs, y)

        return running_loss / len(self.valloader), correct * 100 / total

    def evaluate(self, ep, dataloader):
        """Return the accuracy in percent of the model on ``dataloader``.

        Args:
            ep (int): Epoch number; accepted for symmetry, not used.
            dataloader: Iterable of ``(x, y)`` batches.
        """
        self.model.eval()

        total = 0
        correct = 0
        with torch.no_grad():
            for x, y in tqdm(dataloader):
                x = x.to(self.device)
                y = y.to(self.device)

                outputs = torch.sigmoid(self.model(x))
                total += y.size(0)
                correct += self._count_correct(outputs, y)

        return correct * 100 / total

    def perform_training(self, total_epoch):
        """Train for up to ``total_epoch`` epochs, stopping early on stagnation.

        Validation runs once before training and once after every epoch. The
        best weights are written to disk by ``self.early_stopping``; the
        in-memory model keeps the weights of the last epoch that ran.

        Args:
            total_epoch (int): Maximum number of epochs.
        """
        val_loss, acc = self.validate(0)

        print("[Initial Validation results] Loss: {} \t Acc:{}".format(
            val_loss, acc))

        for i in range(total_epoch):
            epoch = i + 1
            self.train(epoch)
            val_loss, acc = self.validate(epoch)
            print("Epoch:{} \t Accuracy:{}".format(epoch, acc))
            if self.writer:
                self.writer.add_scalar('Validation Loss', val_loss, epoch)
                self.writer.add_scalar('Validation Acc', acc, epoch)

            self.early_stopping(val_loss, self.model)

            if self.early_stopping.early_stop:
                print("Early stopping")
                break

        print("=" * 20)
        print("Training finished !!")
        print("=" * 20)

In [9]:
# TensorBoard event files are written under runs/covid_resnet32.
writer = SummaryWriter(log_dir='runs/covid_resnet32')

# The held-out 20% loader is used as the validation set during training.
trainer = Trainer(
    trainloader=train_loader,
    vallaoder=test_loader,
    model_ft=resnet_model,
    writer=writer,
)

Training will be done on  cpu


In [10]:
# Upper bound of 30 epochs; early stopping may end the run sooner.
trainer.perform_training(total_epoch=30)

  0%|          | 0/87 [00:00<?, ?it/s]



[Initial Validation results] Loss: 21.891431611159753 	 Acc:74.04055032585083


  0%|          | 0/346 [00:00<?, ?it/s]

  0%|          | 0/87 [00:00<?, ?it/s]

Epoch:1 	 Accuracy:83.12816799420709
Validation loss decreased (inf --> 0.388356).  Saving model ...


  0%|          | 0/346 [00:00<?, ?it/s]

  0%|          | 0/87 [00:00<?, ?it/s]

Epoch:2 	 Accuracy:92.61404779145546
Validation loss decreased (0.388356 --> 0.197348).  Saving model ...


  0%|          | 0/346 [00:00<?, ?it/s]

  0%|          | 0/87 [00:00<?, ?it/s]

Epoch:3 	 Accuracy:90.04344677769733
EarlyStopping counter: 1 out of 5


  0%|          | 0/346 [00:00<?, ?it/s]

  0%|          | 0/87 [00:00<?, ?it/s]

Epoch:4 	 Accuracy:88.05213613323679
EarlyStopping counter: 2 out of 5


  0%|          | 0/346 [00:00<?, ?it/s]

  0%|          | 0/87 [00:00<?, ?it/s]

Epoch:5 	 Accuracy:83.3091962346126
EarlyStopping counter: 3 out of 5


  0%|          | 0/346 [00:00<?, ?it/s]

  0%|          | 0/87 [00:00<?, ?it/s]

Epoch:6 	 Accuracy:86.56770456191165
EarlyStopping counter: 4 out of 5


  0%|          | 0/346 [00:00<?, ?it/s]

  0%|          | 0/87 [00:00<?, ?it/s]

Epoch:7 	 Accuracy:92.50543084721217
EarlyStopping counter: 5 out of 5
Early stopping
Training finished !!


In [None]:
# trainer.evaluate(0, test_loader)

In [11]:
# Elapsed wall-clock seconds since start_time was set in the first cell, so the
# figure spans every cell executed in between, not only the training loop.
end_time = time.time()
print('Total execution training time: ', end_time-start_time)

Total execution training time:  325.31664276123047
