In [2]:
import numpy as np
import pandas as pd
import torch
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F 
from torch import nn,optim
from torchvision import transforms as T,datasets,models
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from collections import OrderedDict
from tqdm import tqdm
pd.options.plotting.backend = "plotly"

In [3]:
# Root folder of the dataset and the names of the split sub-folders joined onto it.
data_dir = "../input/covid19-chest-xray-detection/covid_update"
TRAIN, TEST = 'Train', 'Test'


In [4]:
def data_transforms(phase = None):
    """Build the torchvision preprocessing pipeline for a dataset split.

    Parameters
    ----------
    phase : str, optional
        Split identifier. ``TRAIN`` selects the augmenting pipeline; every
        other value (``TEST``, a validation split name, or ``None``) selects
        the deterministic evaluation pipeline.

    Returns
    -------
    torchvision.transforms.Compose
        Transform that converts an image to a normalized tensor with a
        224x224 spatial size.
    """
    # Same per-channel mean / std for both pipelines.
    normalize = T.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])

    if phase == TRAIN:
        # Augmenting pipeline: resize larger than the target, apply a random
        # rotation of up to 20 degrees, then keep the central 224x224 crop.
        return T.Compose([
            T.Resize(size = (256,256)),
            T.RandomRotation(degrees = (-20,+20)),
            T.CenterCrop(size=224),
            T.ToTensor(),
            normalize,
        ])

    # Every non-training phase shares the deterministic pipeline. This avoids
    # depending on a VAL constant that this notebook never defines and
    # guarantees that a transform is always returned.
    return T.Compose([
        T.Resize(size = (224,224)),
        T.ToTensor(),
        normalize,
    ])

In [5]:
# Wrap each split folder in an ImageFolder; the split name also selects the
# matching transform pipeline.
trainset = datasets.ImageFolder(
    root = os.path.join(data_dir, TRAIN),
    transform = data_transforms(TRAIN),
)
testset = datasets.ImageFolder(
    root = os.path.join(data_dir, TEST),
    transform = data_transforms(TEST),
)


In [6]:
# Class names of the training set, printed together with their integer encoding.
class_names = trainset.classes
print(class_names, trainset.class_to_idx, sep = "\n")

['covid', 'non_covid']
{'covid': 0, 'non_covid': 1}


In [7]:
def plot_class_count(classes,name = None):
    """Show a bar chart with the number of samples per class label.

    `classes` holds one class label per sample; `name` is used both as the
    column name of the intermediate frame and as the chart title.

    The final ``.show()`` call relies on the plotting backend configured for
    pandas at the top of the notebook (plotly), whose figures provide it.
    """
    frame = pd.DataFrame(classes, columns = [name])
    # Group the rows by the label sequence itself and count each group.
    counts = frame.groupby([classes]).size()
    figure = counts.plot(kind = 'bar', title = name)
    figure.show()

In [None]:
def get_class_count(dataset,name = None):
    """Return the class name of every sample in `dataset`.

    Parameters
    ----------
    dataset : iterable of (sample, label) pairs
        Labels are integer indices into the module-level `class_names`.
        When the dataset exposes a `targets` sequence and has no
        `target_transform`, the labels are read from it directly.
    name : str, optional
        Unused; kept so existing calls keep working.

    Returns
    -------
    list of str
        ``class_names[label]`` for each sample, in dataset order.
    """
    # Fast path: reading `targets` avoids loading and transforming every
    # image just to look at its label. It is skipped when a target_transform
    # is set, because iteration would then yield transformed labels.
    targets = getattr(dataset, 'targets', None)
    if targets is not None and getattr(dataset, 'target_transform', None) is None:
        labels = targets
    else:
        labels = (label for _, label in dataset)

    # Works for any number of classes, not only labels 0 and 1.
    return [class_names[int(label)] for label in labels]

In [None]:
# Collect the class name of every training sample, then chart the per-class counts.
trainset_class_count = get_class_count(trainset, 'trainset_classes_count')
plot_class_count(trainset_class_count, 'trainset_classes_count')

In [None]:
# Collect the class name of every test sample, then chart the per-class counts.
testset_class_count = get_class_count(testset, 'testset_classes_count')
plot_class_count(testset_class_count, 'testset_classes_count')

In [None]:
# Mini-batch iterators: 16 images per training batch, 8 per test batch.
# Both loaders are created with shuffling enabled.
trainloader = DataLoader(dataset = trainset, shuffle = True, batch_size = 16)
testloader = DataLoader(dataset = testset, shuffle = True, batch_size = 8)

In [None]:
def show_image(image,title = None,get_denormalize = False):
    """Undo the dataset normalization of an image tensor, then plot or return it.

    Parameters
    ----------
    image : torch.Tensor
        Channel-first image tensor with 3 channels (it is permuted to
        channel-last and combined with 3-element mean / std vectors).
    title : optional
        Passed to ``plt.title`` when the image is plotted.
    get_denormalize : bool, optional
        When true, return the denormalized image instead of plotting it.

    Returns
    -------
    torch.Tensor or None
        Channel-last CPU tensor with values clamped to [0, 1] when
        `get_denormalize` is true; otherwise ``None`` after plotting.
    """
    # Same mean / std values as the Normalize step of the dataset transforms.
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])

    # detach + cpu make the function safe for tensors that track gradients
    # or live on the GPU; the statistics above are CPU tensors.
    image = image.detach().cpu().permute(1,2,0)

    # Invert the normalization and clamp with torch instead of handing a
    # tensor to np.clip.
    image = torch.clamp(image*std + mean, 0, 1)

    if get_denormalize:
        return image

    plt.figure(figsize=[15, 15])
    plt.imshow(image.numpy())

    if title is not None:
        plt.title(title)

In [None]:
# Pull one batch from the training loader and preview it as an image grid.
dataiter = iter(trainloader)
# Use the built-in next(): DataLoader iterators in current PyTorch releases
# do not provide a .next() method.
images,labels = next(dataiter)

# Tile the batch, 4 images per row.
out = make_grid(images,nrow=4)

# The title lists the class name of every image in the batch.
show_image(out, title=[class_names[x] for x in labels])

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Instantiate the VGG-16 architecture, then fill it with the weights stored in
# the local checkpoint file.
# NOTE(review): torch.load unpickles the file, so the checkpoint must come
# from a trusted source.
model = models.vgg16()
model.load_state_dict(
    torch.load("../input/pretrained-model-weights-pytorch/vgg16-397923af.pth")
)

In [None]:
# Freeze every parameter currently in the network so that only layers added
# afterwards are trainable.
model.requires_grad_(False)




# Classification head ending in 2 log-probability outputs.
# Every layer needs a unique name: an OrderedDict keeps a single entry per key,
# so reusing 'relu' for both activations would silently drop the one after fc2.
classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(25088, 4096)),
                                         ('relu', nn.ReLU()),
                                         ('dropout',nn.Dropout(0.3)),
                                         ('fc2', nn.Linear(4096, 4096)),
                                         ('relu2', nn.ReLU()),
                                         ('drop', nn.Dropout(0.3)),
                                         ('fc3', nn.Linear(4096, 2)),
                                         ('output', nn.LogSoftmax(dim = 1))]))

# Replace the network's classifier with the new head defined above.
model.classifier = classifier
# Move the whole network to the selected device; as the last expression of the
# cell, the returned module is also displayed.
model.to(device)

In [None]:
# Negative log-likelihood loss: pairs with the LogSoftmax output of the classifier.
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(),lr = 0.001)
# Multiply the learning rate by 0.1 when the monitored value stops improving.
# NOTE(review): patience (5) is not smaller than the number of epochs (5), so
# the scheduler presumably never triggers in this run; confirm the intent.
schedular = optim.lr_scheduler.ReduceLROnPlateau(optimizer,factor = 0.1,patience = 5)
epochs = 5
# Best validation loss seen so far. `np.inf` is the supported spelling; the
# `np.Inf` alias was removed in NumPy 2.0.
valid_loss_min = np.inf

In [None]:
def accuracy(y_pred,y_true):
    """Fraction of samples whose highest-scoring class equals the target.

    `y_pred` holds one row of class scores per sample (the model emits
    log-probabilities, which are exponentiated here) and `y_true` holds the
    integer class index of each sample. Returns a 0-dim float tensor.
    """
    probabilities = y_pred.exp()
    # Index of the single most likely class per row.
    _, predicted = probabilities.topk(1, dim = 1)
    # Reshape the targets to the prediction's shape before comparing.
    hits = predicted == y_true.view(*predicted.shape)
    return hits.type(torch.FloatTensor).mean()

In [None]:
# Training driver: one pass over the training loader followed by one pass over
# the evaluation loader per epoch; a checkpoint is written whenever the
# evaluation loss does not exceed the best value seen so far.
# NOTE(review): the "validation" pass iterates `testloader`, the same loader
# used for the final test evaluation, so checkpoint selection sees the test
# data. Confirm whether a separate validation split is intended.
for i in range(epochs):
    
    # Per-epoch accumulators of the per-batch loss and accuracy.
    train_loss = 0.0
    valid_loss = 0.0
    train_acc = 0.0
    valid_acc = 0.0 
    
    
    # Training mode (dropout layers active).
    model.train()
    
    for images,labels in tqdm(trainloader):
        
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass: `ps` holds the log-probabilities emitted by the model.
        ps = model(images)
        loss = criterion(ps,labels)
        
        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_acc += accuracy(ps,labels)
        train_loss += loss.item()
        
    # Means over batches: every batch counts equally, whatever its size.
    avg_train_acc = train_acc / len(trainloader)
    avg_train_loss = train_loss / len(trainloader)
        
    # Evaluation mode, with autograd disabled for the whole pass.
    model.eval()
    with torch.no_grad():
        
        for images,labels in tqdm(testloader):
            
            images = images.to(device)
            labels = labels.to(device)
            
            ps = model(images)
            loss = criterion(ps,labels)
            
            valid_acc += accuracy(ps,labels)
            valid_loss += loss.item()
            
            
        avg_valid_acc = valid_acc / len(testloader)
        avg_valid_loss = valid_loss / len(testloader)
        
        # The plateau scheduler monitors the mean evaluation loss.
        schedular.step(avg_valid_loss)
        
        # Save model and optimizer state when the loss is at least as good as
        # the best so far, then record the new best.
        if avg_valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).   Saving model ...'.format(valid_loss_min,avg_valid_loss))
            torch.save({
                'epoch' : i,
                'model_state_dict' : model.state_dict(),
                'optimizer_state_dict' : optimizer.state_dict(),
                'valid_loss_min' : avg_valid_loss
            },'Pneumonia_model.pt')
            
            valid_loss_min = avg_valid_loss
            
            
    # Epoch summary (epochs are reported 1-based).
    print("Epoch : {} Train Loss : {:.6f} Train Acc : {:.6f}".format(i+1,avg_train_loss,avg_train_acc))
    print("Epoch : {} Valid Loss : {:.6f} Valid Acc : {:.6f}".format(i+1,avg_valid_loss,avg_valid_acc))

In [None]:
# Final evaluation pass over the test loader.
model.eval()

test_loss = 0
test_acc = 0

# Inference only: without no_grad() every forward pass would record an
# autograd graph that is never used, wasting memory and time.
with torch.no_grad():
    for images,labels in testloader:

        images = images.to(device)
        labels = labels.to(device)

        pred = model(images)
        loss = criterion(pred,labels)

        test_loss += loss.item()
        test_acc += accuracy(pred,labels)

# Means over batches: every batch counts equally, whatever its size.
avg_test_loss = test_loss/len(testloader)
avg_test_acc = test_acc/len(testloader)

print("Test Loss : {:.6f} Test Acc : {:.6f}".format(avg_test_loss,avg_test_acc))

In [None]:
def view_classify(img,ps,label):
    """Plot an image next to a bar chart of its predicted class probabilities.

    Parameters
    ----------
    img : torch.Tensor
        Normalized image tensor; it is denormalized through `show_image`.
    ps : torch.Tensor
        Class probabilities for the image; squeezed before plotting, so a
        leading batch dimension of size 1 is accepted.
    label : int
        Ground-truth class index, used to look up the name shown in the title.

    Returns
    -------
    None
    """
    # Take the names from the dataset's own index -> name mapping so that the
    # title and the bars agree with the label encoding (covid = 0,
    # non_covid = 1). A hard-coded ['NORMAL', 'COVID'] list has the opposite
    # order and mislabels both.
    classes = np.array(class_names)

    ps = ps.detach().cpu().numpy().squeeze()
    img = show_image(img,get_denormalize = True)

    _, (ax1, ax2) = plt.subplots(figsize=(8,12), ncols=2)

    # Left panel: the image with its ground-truth class.
    ax1.imshow(img)
    ax1.set_title('Ground Truth : {}'.format(class_names[label]))
    ax1.axis('off')

    # Right panel: one horizontal bar per class.
    ax2.barh(classes, ps)
    ax2.set_aspect(0.1)
    ax2.set_yticks(classes)
    ax2.set_yticklabels(classes)
    ax2.set_title('Predicted Class')
    ax2.set_xlim(0, 1.1)

    plt.tight_layout()

In [None]:
image,label = testset[0]

# Inference only: disable autograd for the forward pass. exp() turns the
# model's log-probabilities into probabilities.
with torch.no_grad():
    ps = torch.exp(model(image.to(device).unsqueeze(0)))
view_classify(image,ps,label)

In [None]:
image,label = testset[135]

# Inference only: disable autograd for the forward pass. exp() turns the
# model's log-probabilities into probabilities.
with torch.no_grad():
    ps = torch.exp(model(image.to(device).unsqueeze(0)))
view_classify(image,ps,label)

In [None]:
image,label = testset[160]

# Inference only: disable autograd for the forward pass. exp() turns the
# model's log-probabilities into probabilities.
with torch.no_grad():
    ps = torch.exp(model(image.to(device).unsqueeze(0)))
view_classify(image,ps,label)

In [None]:
image,label = testset[35]

# Inference only: disable autograd for the forward pass. exp() turns the
# model's log-probabilities into probabilities.
with torch.no_grad():
    ps = torch.exp(model(image.to(device).unsqueeze(0)))
view_classify(image,ps,label)

In [None]:
image,label = testset[78]

# Inference only: disable autograd for the forward pass. exp() turns the
# model's log-probabilities into probabilities.
with torch.no_grad():
    ps = torch.exp(model(image.to(device).unsqueeze(0)))
view_classify(image,ps,label)

In [None]:
image,label = testset[180]

# Inference only: disable autograd for the forward pass. exp() turns the
# model's log-probabilities into probabilities.
with torch.no_grad():
    ps = torch.exp(model(image.to(device).unsqueeze(0)))
view_classify(image,ps,label)

In [None]:
image,label = testset[45]

# Inference only: disable autograd for the forward pass. exp() turns the
# model's log-probabilities into probabilities.
with torch.no_grad():
    ps = torch.exp(model(image.to(device).unsqueeze(0)))
view_classify(image,ps,label)

In [None]:
# NOTE(review): same sample index (45) as the preceding cell.
image,label = testset[45]

# Inference only: disable autograd for the forward pass. exp() turns the
# model's log-probabilities into probabilities.
with torch.no_grad():
    ps = torch.exp(model(image.to(device).unsqueeze(0)))
view_classify(image,ps,label)

In [None]:
# NOTE(review): sample index 35 is also shown in an earlier cell.
image,label = testset[35]

# Inference only: disable autograd for the forward pass. exp() turns the
# model's log-probabilities into probabilities.
with torch.no_grad():
    ps = torch.exp(model(image.to(device).unsqueeze(0)))
view_classify(image,ps,label)