* Create dataset
* Import model and alter it for this problem
* Create custom train and test function
* Visualize performance

**Hyperparameter Tuning**
* Methods to try:
    * Data Augmentation
    * Learning rate schedule
    * Weighted Loss Function
    * Focal Loss
**Try cropping images to remove 'R' label from edge**

In [48]:
import numpy as np, pandas as pd, matplotlib.pyplot
import torch
import torchvision
from torchvision.models import mobilenet_v2
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch.nn as nn
import cv2 as cv
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

In [49]:
# Hyperparameters
class cfg:
    IMG_SIZE = 224
    BATCH = 64
    EPOCHS = 10
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [50]:
# Dataset
class X_Ray_Dataset(Dataset):
    def __init__(self, paths, labels, transforms=None):
        super().__init__()
        self.paths = paths
        self.labels = labels
        self.transforms = transforms
        
    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, idx):
        # get path
        # read and preprocess image
        # return img along with label (maybe with tranformations)
        path, label = self.paths[idx], self.labels[idx]
        img = torchvision.io.read_image(path, torchvision.io.image.ImageReadMode.RGB) / 255.
#         img = cv.cvtColor(cv.imread(path, 0), cv.COLOR_BGR2RGB) / 255.
#         img = cv.resize(img, (cfg.IMG_SIZE, cfg.IMG_SIZE))
        
        if self.transforms:
            img = self.transforms(img)
        
        return (img, label)

In [1]:
train_augs = transforms.Compose([
    transforms.Resize([cfg.IMG_SIZE, cfg.IMG_SIZE]),
    transforms.RandomRotation(degrees=359),
    transforms.RandomVerticalFlip(p=0.7),
#     transforms.RandomAutocontrast()
])

test_augs = transforms.Compose([transforms.Resize([cfg.IMG_SIZE, cfg.IMG_SIZE])])
df = pd.read_csv('X_Ray_Data.csv')
train_paths, test_paths, train_labels, test_labels = train_test_split(df.path.values, df.label.values, test_size=0.2, random_state=0)

train_ds = X_Ray_Dataset(train_paths, train_labels, transforms=train_augs)
test_ds = X_Ray_Dataset(test_paths, test_labels, transforms=test_augs)

train_loader = DataLoader(train_ds, shuffle=True, batch_size=cfg.BATCH)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=cfg.BATCH)

NameError: name 'transforms' is not defined

In [96]:
def fit(model, dataloader, opt, criterion, epochs):
    # send imgs through network
    # compute loss
    # compute gradients and update weights
    # print metrics
    for epoch in range(epochs):
        total_loss = 0.0
        acc = 0
        model.train()
    
        for i, (imgs, labels) in enumerate(dataloader):
            imgs, labels = imgs.to(cfg.device), labels.to(cfg.device) # Place images onto gpu
            output = model(imgs.float()) # Forward pass
            loss = criterion(output, labels) # Calculate loss
            opt.zero_grad() # Zero the previous gradients held by optimizer
            loss.backward() # calculate new gradient of loss wrt current parameters
            opt.step() # Update model weights
            
            total_loss += loss
            _, preds = torch.max(output, 1)
            correct = len(torch.where(preds==labels)[0])
            acc += correct
        
        # print metrics
        print(f'Epoch {epoch+1} | Loss: {total_loss/i+1} | Accuracy: {acc/len(dataloader)}')
        
def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0.0
    acc = 0.0
    P = torch.Tensor()
    
    for i, (imgs, labels) in enumerate(dataloader):
        imgs, labels = imgs.float().to(cfg.device), labels.to(cfg.device)
        
        with torch.no_grad():
            output = model(imgs)
            loss = criterion(output, labels)
        
        _, preds = torch.max(output, 1)
        P = torch.cat([P.to(cfg.device), preds], dim=0)
        acc += len(torch.where(preds==labels)[0])
        total_loss += loss
    
    print(f'Accuracy: {acc/len(dataloader):.1f}% | Loss: {total_loss/i+1:.3f}')
    
    return P.cpu().numpy()

In [57]:
model = mobilenet_v2(pretrained=True)
model.classifier[1] = nn.Linear(in_features=1280, out_features=4, bias=True)
model.to(cfg.device)
opt = torch.optim.Adam(model.parameters(), lr=.0001)
criterion = nn.CrossEntropyLoss()

In [58]:
fit(model, train_loader, opt=opt, criterion=criterion, epochs=cfg.EPOCHS)

Epoch 1 | Loss: 0.4096754789352417 | Accuracy: 0.8513849431818182
Epoch 2 | Loss: 0.2524297535419464 | Accuracy: 0.9103929924242424
Epoch 3 | Loss: 0.22697879374027252 | Accuracy: 0.9212831439393939
Epoch 4 | Loss: 0.198633074760437 | Accuracy: 0.9317589962121212
Epoch 5 | Loss: 0.19276265799999237 | Accuracy: 0.9318773674242424
Epoch 6 | Loss: 0.17544551193714142 | Accuracy: 0.9398082386363636
Epoch 7 | Loss: 0.17425788938999176 | Accuracy: 0.9381510416666666
Epoch 8 | Loss: 0.1672159880399704 | Accuracy: 0.9399266098484849
Epoch 9 | Loss: 0.15397192537784576 | Accuracy: 0.9459635416666666
Epoch 10 | Loss: 0.1499778777360916 | Accuracy: 0.9471472537878788


In [97]:
#model = torch.load('mobilenetV2').to('cuda')
L = []
for imgs, labels in test_loader:
    L.extend(labels.cpu().numpy())

preds = evaluate(model, test_loader, criterion)
print(classification_report(np.array(L), preds, target_names=['normal', 'lung_opacity', 'pneumonia', 'covid-19']))

Accuracy: 59.313 | Loss: 1.168
              precision    recall  f1-score   support

      normal       0.91      0.98      0.94      2039
lung_opacity       0.98      0.84      0.91      1226
   pneumonia       0.98      0.98      0.98       250
    covid-19       0.97      0.96      0.97       718

    accuracy                           0.94      4233
   macro avg       0.96      0.94      0.95      4233
weighted avg       0.94      0.94      0.94      4233



In [95]:
len(np.where(L[:50]==preds[:50])[0])

47

In [91]:
preds[:20]

array([0., 3., 0., 0., 0., 1., 3., 0., 0., 3., 0., 1., 3., 0., 0., 0., 0.,
       0., 3., 1.], dtype=float32)