In [2]:
import copy
from os import path
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from math import ceil
from sklearn.model_selection import train_test_split
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision.models as models
%matplotlib inline

Matplotlib is building the font cache; this may take a moment.


In [3]:
# Load the label table; its `x` and `y` columns are pulled out as arrays in the next cell.
data = pd.read_csv("labels.csv")

In [4]:
# x: image paths (opened relative to data/covid_data/ by CovidDataset);
# y: class labels (non corona = 0, corona = 1).
X, Y = np.asarray(data.x), np.asarray(data.y)

In [5]:
# Shuffled 80/20 train/dev split; random_state is fixed so the same split is
# produced on every run.
X_train, X_dev, Y_train, Y_dev = train_test_split(
    X, Y, train_size=0.8, random_state=19, shuffle=True
)

In [6]:
# Report how many samples ended up in each split.
print("No of train samples: {}".format(len(X_train)))
print("No of dev samples: {}".format(len(X_dev)))

No of train samples: 224
No of dev samples: 56


In [7]:
def plot_classes(Y_train, Y_dev):
    """Plot side-by-side bar charts of the class counts in the train and dev splits.

    Labels follow the notebook's encoding (non corona = 0, corona = 1). Each
    bar is named and colored by the class it represents (green for non corona,
    red for corona), so a split that happens to contain a single class still
    renders correctly. Any other label value is shown under its string form
    in gray.

    Args:
        Y_train: array-like of training labels.
        Y_dev: array-like of dev labels.
    """
    class_names = {0: "Non Corona", 1: "Corona"}
    class_colors = {0: "g", 1: "r"}

    _, axes = plt.subplots(1, 2, figsize=(9, 3))
    panels = ((Y_train, "Train set"), (Y_dev, "Dev test"))
    for ax, (labels, title) in zip(axes, panels):
        uniques, counts = np.unique(labels, return_counts=True)
        # .tolist() yields plain Python scalars, so the dict lookups below do
        # not depend on the NumPy dtype of the labels (and no object-dtype
        # array is needed to hold the display names).
        values = uniques.tolist()
        names = [class_names.get(value, str(value)) for value in values]
        colors = [class_colors.get(value, "gray") for value in values]

        # `or None` lets matplotlib fall back to its default when there are no bars.
        ax.bar(names, counts, color=colors or None)
        ax.set_xlabel("Category")
        ax.set_ylabel("No of pictures")
        ax.set_title(title)
    plt.show()

In [None]:
# Bar charts of the class counts in the train and dev splits.
plot_classes(Y_train, Y_dev)

In [None]:
class CovidDataset(Dataset):
    """Map-style dataset that pairs image files with their class labels.

    Args:
        X: sequence of image paths, relative to ``root``.
        Y: sequence of labels, aligned index-for-index with ``X``.
        transforms: optional callable applied to the loaded RGB PIL image.
        root: directory the paths in ``X`` are relative to. Defaults to the
            image folder used throughout this notebook.

    Raises:
        ValueError: if ``X`` and ``Y`` differ in length.
    """

    def __init__(self, X, Y, transforms=None, root="data/covid_data/"):
        if len(X) != len(Y):
            raise ValueError(
                "X and Y must have the same length, got {} and {}".format(
                    len(X), len(Y)
                )
            )
        self.X = X
        self.Y = Y
        self.transforms = transforms
        self.root = root

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``self.transforms``
        when one was supplied.
        """
        img_path = path.join(self.root, self.X[idx])

        # Image.open is lazy and keeps the file handle open; convert() reads
        # the pixel data and returns an independent image, so the handle can
        # be released immediately instead of lingering until garbage collection.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        # Explicit None check: a callable container with a length of zero
        # would otherwise be treated as "no transform".
        if self.transforms is not None:
            img = self.transforms(img)

        return img, self.Y[idx]

In [None]:
def get_transformations(for_train=True, resize=(128, 128)):
    """Build the image preprocessing pipeline for one of the two splits.

    Both pipelines end with a resize to ``resize`` followed by conversion to
    a tensor. The training pipeline additionally starts with light random
    augmentation (small rotation, small shear, horizontal flip).

    Args:
        for_train: when truthy return the augmenting training pipeline,
            otherwise the plain test pipeline.
        resize: target size handed to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` instance.
    """
    # Steps shared by the train and test pipelines.
    steps = [transforms.Resize(resize), transforms.ToTensor()]

    if for_train:
        augmentation = [
            transforms.RandomRotation(degrees=(-5, 5)),
            transforms.RandomAffine(degrees=0, shear=(-0.05, 0.05)),
            transforms.RandomHorizontalFlip(0.5),
        ]
        # Augment first, then resize and convert.
        steps = augmentation + steps

    return transforms.Compose(steps)

In [None]:
def show_images(img_path, transforms=None):
    """Display an image and, optionally, its transformed version beside it.

    Args:
        img_path: path of the image file to open.
        transforms: optional callable applied to the RGB PIL image. It is
            assumed to return a channel-first ``(C, H, W)`` tensor, as the
            ``ToTensor``-terminated pipelines in this notebook do.
    """
    # convert() returns an independent image, so the file can be closed here.
    with Image.open(img_path) as raw:
        img = raw.convert("RGB")

    transformed = None
    if transforms is not None:
        # imshow expects (H, W, C). Moving the channel axis last with
        # (1, 2, 0) keeps height and width in order; (2, 1, 0) would also
        # swap them and show the picture transposed.
        transformed = transforms(img).permute(1, 2, 0)

    plt.figure(figsize=(9, 3))
    plt.subplot(121)
    plt.imshow(img)
    plt.title("Normal")
    if transformed is not None:
        plt.subplot(122)
        plt.imshow(transformed)
        plt.title("Transformed")
    plt.show()

In [None]:
# Show the raw image only; no transform is passed, so no second panel is drawn.
show_images(img_path="data/covid_data/Corona2_9.jpg", transforms=None)

In [None]:
# Show the same image next to a sample of the (random) training augmentation.
show_images(
    img_path="data/covid_data/Corona2_9.jpg",
    transforms=get_transformations(for_train=True),
)

In [None]:
# Mini-batch size that train() hands to the train and dev DataLoaders.
BATCH_SIZE = 8

In [None]:
# The training split gets the augmenting pipeline; the dev split is only
# resized and converted to a tensor.
train_dataset = CovidDataset(X_train, Y_train, get_transformations(for_train=True))
dev_dataset = CovidDataset(X_dev, Y_dev, get_transformations(for_train=False))

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def get_mnasnet1_0():
    """Return a pretrained MNASNet 1.0 with a 2-class head, placed on ``device``.

    The model's ``classifier`` attribute is replaced by a single linear layer
    mapping 1280 features to 2 outputs.
    """
    net = models.mnasnet1_0(pretrained=True)
    two_class_head = torch.nn.Linear(in_features=1280, out_features=2)
    net.classifier = two_class_head
    return net.to(device)

In [None]:
def get_alexnet():
    """Return a pretrained AlexNet with a 2-class output layer, placed on ``device``.

    Only entry 6 of the model's ``classifier`` is swapped, for a linear layer
    mapping 4096 features to 2 outputs; the other classifier entries are kept.
    """
    net = models.alexnet(pretrained=True)
    two_class_output = torch.nn.Linear(in_features=4096, out_features=2)
    net.classifier[6] = two_class_output
    return net.to(device)

In [None]:
# AlexNet is the network used for this run (get_mnasnet1_0 is the alternative builder).
model = get_alexnet()

In [None]:
# Cross-entropy over the two class outputs, moved to the same device as the model.
criterion = torch.nn.CrossEntropyLoss()
criterion = criterion.to(device)
# Adam over every model parameter with a learning rate of 1e-5.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [None]:
def train_step(model, inputs, labels, criterion, optimizer):
    """Run one optimization step on a single batch.

    Clears the accumulated gradients, does the forward pass, backpropagates
    the loss and applies one optimizer update.

    Args:
        model: callable producing predictions from ``inputs``.
        inputs: batch fed to ``model``.
        labels: targets passed to ``criterion`` alongside the predictions.
        criterion: loss callable ``criterion(preds, labels)``.
        optimizer: optimizer whose ``zero_grad``/``step`` are invoked.

    Returns:
        Tuple ``(preds, loss)``: the model output for the batch and the loss
        computed from it (before the parameter update).
    """
    # Gradients accumulate across backward() calls, so reset them first.
    optimizer.zero_grad()

    outputs = model(inputs)
    batch_loss = criterion(outputs, labels)
    batch_loss.backward()

    optimizer.step()
    return outputs, batch_loss

In [None]:
def eval_step(model, inputs, labels, criterion):
    """Compute predictions and loss for one batch without updating the model.

    Args:
        model: callable producing predictions from ``inputs``.
        inputs: batch fed to ``model``.
        labels: targets passed to ``criterion`` alongside the predictions.
        criterion: loss callable ``criterion(preds, labels)``.

    Returns:
        Tuple ``(preds, loss)``.
    """
    outputs = model(inputs)
    return outputs, criterion(outputs, labels)

In [None]:
def train_epoch(model, train_dataset, criterion, optimizer, batch_size):
    """Train ``model`` for one pass over ``train_dataset``.

    Batches are drawn in shuffled order, moved to the notebook-level
    ``device`` and fed through ``train_step``.

    Args:
        model: network to train; switched to training mode.
        train_dataset: dataset yielding ``(image, label)`` pairs.
        criterion: loss callable forwarded to ``train_step``.
        optimizer: optimizer forwarded to ``train_step``.
        batch_size: number of samples per batch.

    Returns:
        Tuple ``(accuracy, loss)``: the fraction of correctly classified
        samples and the average loss per sample over the epoch.
    """
    train_loader = DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True
    )

    # With the default "mean" reduction each batch loss is already averaged
    # over the batch, so it has to be weighted by the batch size before being
    # summed; otherwise dividing by the sample count below would shrink the
    # reported loss by roughly a factor of batch_size. A "sum"-reduced loss
    # is already a per-batch total and is added unchanged.
    mean_reduced = getattr(criterion, "reduction", "mean") == "mean"

    correct_count = 0
    total_loss = 0.0

    model.train()
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        preds, loss = train_step(model, imgs, labels, criterion, optimizer)

        predicted_classes = torch.argmax(preds, dim=1)
        correct_count += (predicted_classes == labels).sum().item()

        batch_loss = loss.item()
        if mean_reduced:
            batch_loss *= labels.size(0)
        total_loss += batch_loss

    num_samples = len(train_dataset)
    return correct_count / num_samples, total_loss / num_samples

In [None]:
def eval_epoch(model, dev_dataset, criterion, batch_size):
    """Evaluate ``model`` on every sample of ``dev_dataset``.

    Runs in evaluation mode with gradient tracking disabled; batches are
    moved to the notebook-level ``device`` and fed through ``eval_step``.

    Args:
        model: network to evaluate; switched to evaluation mode.
        dev_dataset: dataset yielding ``(image, label)`` pairs.
        criterion: loss callable forwarded to ``eval_step``.
        batch_size: number of samples per batch.

    Returns:
        Tuple ``(accuracy, loss)``: the fraction of correctly classified
        samples and the average loss per sample over the whole dataset.
    """
    # Order is irrelevant for evaluation, so the loader is not shuffled.
    dev_loader = DataLoader(dataset=dev_dataset, batch_size=batch_size, shuffle=False)

    # With the default "mean" reduction each batch loss is already averaged
    # over the batch, so it is weighted by the batch size before summing;
    # a "sum"-reduced loss is already a per-batch total.
    mean_reduced = getattr(criterion, "reduction", "mean") == "mean"

    correct_count = 0
    total_loss = 0.0

    model.eval()
    with torch.no_grad():
        for imgs, labels in dev_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            preds, loss = eval_step(model, imgs, labels, criterion)

            predicted_classes = torch.argmax(preds, dim=1)
            correct_count += (predicted_classes == labels).sum().item()

            batch_loss = loss.item()
            if mean_reduced:
                batch_loss *= labels.size(0)
            # Accumulate over all batches; assigning here would keep only
            # the last batch's loss.
            total_loss += batch_loss

    num_samples = len(dev_dataset)
    return correct_count / num_samples, total_loss / num_samples

In [None]:
def train(
    model,
    train_dataset,
    dev_dataset,
    criterion,
    optimizer,
    num_epochs=25,
    batch_size=None,
):
    """Train ``model`` and return it with its best-on-dev weights loaded.

    Each epoch runs ``train_epoch`` followed by ``eval_epoch`` and prints the
    accuracy and loss of both. Whenever the dev accuracy strictly improves, a
    copy of the model's state dict is kept; after the last epoch those
    weights are loaded back into the model.

    Args:
        model: network to train (modified in place).
        train_dataset: dataset used for the training pass.
        dev_dataset: dataset used for the evaluation pass.
        criterion: loss callable.
        optimizer: optimizer updating ``model``'s parameters.
        num_epochs: number of epochs to run.
        batch_size: samples per batch for both passes. ``None`` (the default)
            uses the notebook-level ``BATCH_SIZE``.

    Returns:
        ``model``, holding the weights from the epoch with the highest dev
        accuracy (or its initial weights if the dev accuracy never rose
        above 0).
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    # deepcopy: state_dict() returns references to the live tensors, which
    # the optimizer keeps overwriting.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):

        print("Epoch {}/{}".format(epoch + 1, num_epochs))
        print("-" * 10)

        train_acc, train_loss = train_epoch(
            model, train_dataset, criterion, optimizer, batch_size
        )
        print("train_acc: {:.4f}, train_loss: {:.4f}".format(train_acc, train_loss))

        dev_acc, dev_loss = eval_epoch(model, dev_dataset, criterion, batch_size)
        print("dev_acc: {:.4f}, dev_loss: {:.4f}".format(dev_acc, dev_loss))

        if dev_acc > best_acc:
            best_acc = dev_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_wts)
    return model

In [None]:
# Train for 5 epochs; train() hands back the model with the weights of its
# best dev-accuracy epoch loaded.
model = train(model, train_dataset, dev_dataset, criterion, optimizer, 5)