## Import

In [9]:
import cv2 as cv
import numpy as np
from os.path import join, exists
from os import mkdir
import torchvision.models as models
import torch
from tqdm import tqdm
from torch.optim import AdamW
from torch.utils.data import TensorDataset, DataLoader, random_split
from collections import abc

## Preprocess

### Set up dataset

In [10]:
dataset_dir = "COVID-19_Radiography_Dataset"
covid_dir = "COVID/images"
covid_size = 3616
lung_dir = "Lung_Opacity/images"
lung_size = 6012
normal_dir = "Normal/images"
normal_size = 10192
pneumo_dir = "Viral Pneumonia/images"
pneumo_size = 1345

In [11]:
if not exists(dataset_dir) : raise FileNotFoundError("The dataset directory is wrong")

### Data tensors

In [12]:
def create_tensors(class_name:str, data_dir:str, size:int) :
    imgs = []
    for k in range(1,size+1) :
        fname = f"{class_name}-{k}.png"
        img_path = join(dataset_dir, data_dir, fname)
        img = cv.imread(img_path, cv.IMREAD_GRAYSCALE)
        imgs.append(img)
    data_pt = torch.from_numpy(np.asarray(imgs)).unsqueeze(1)
    labels_pt = torch.from_numpy(np.asarray([np.eye(4)[0] for k in range(size)]))
    return data_pt, labels_pt

In [13]:
covid_pt, covid_labels_pt = create_tensors("COVID", covid_dir, covid_size)
assert covid_pt.shape == (covid_size,1,299,299)
assert covid_pt.shape[0] == covid_labels_pt.shape[0]
assert covid_labels_pt[1] == 4

In [21]:
covid_imgs = []
for k in range(1,covid_size+1) :
    fname = f"COVID-{k}.png"
    img_path = join(dataset_dir, covid_dir, fname)
    img = cv.imread(img_path, cv.IMREAD_GRAYSCALE)
    covid_imgs.append(img)
covid_pt = torch.from_numpy(np.asarray(covid_imgs)).unsqueeze(1)
covid_labels_pt = torch.from_numpy(np.asarray([np.eye(4)[0] for k in range(covid_size)]))

In [22]:
normal_imgs = []
for k in range(1,normal_size+1) :
    fname = f"Normal-{k}.png"
    img_path = join(dataset_dir, normal_dir, fname)
    img = cv.imread(img_path, cv.IMREAD_GRAYSCALE)
    normal_imgs.append(img)
normal_pt = torch.from_numpy(np.asarray(normal_imgs)).unsqueeze(1)
normal_labels_pt = torch.from_numpy(np.asarray([np.eye(4)[1] for k in range(normal_size)]))

In [23]:
lung_imgs = []
for k in range(1,lung_size+1) :
    fname = f"Lung_Opacity-{k}.png"
    img_path = join(dataset_dir, lung_dir, fname)
    img = cv.imread(img_path, cv.IMREAD_GRAYSCALE)
    lung_imgs.append(img)
lung_pt = torch.from_numpy(np.asarray(lung_imgs)).unsqueeze(1)
lung_labels_pt = torch.from_numpy(np.asarray([np.eye(4)[2] for k in range(lung_size)]))

In [24]:
pneumo_imgs = []
for k in range(1,pneumo_size+1) :
    fname = f"Viral Pneumonia-{k}.png"
    img_path = join(dataset_dir, pneumo_dir, fname)
    img = cv.imread(img_path, cv.IMREAD_GRAYSCALE)
    pneumo_imgs.append(img)
pneumo_pt = torch.from_numpy(np.asarray(pneumo_imgs)).unsqueeze(1)
pneumo_labels_pt = torch.from_numpy(np.asarray([np.eye(4)[3] for k in range(pneumo_size)]))

In [None]:
permutation = torch.randperm()

In [25]:
imgs = torch.cat([covid_pt, normal_pt, lung_pt, pneumo_pt])
labels = torch.cat([covid_labels_pt, normal_labels_pt, lung_labels_pt, pneumo_labels_pt])
train_set = TensorDataset(imgs, labels)

In [27]:
torch.save({"imgs":imgs,"labels":labels}, "dataset.pt")

### Load dataset

If you have already pre-load the TensorDataset :

In [2]:
data = torch.load("dataset.pt")

  data = torch.load("dataset.pt")


In [None]:
train_set = TensorDataset(data["imgs"], data["labels"])

### DataLoader

In [None]:
sampler = .

In [4]:
train_dataloader = DataLoader(train_set, batch_size=16, sampler=sampler)

## Set model

In [5]:
resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

In [6]:
nb_input_fc = resnet.fc.in_features

In [7]:
fc = torch.nn.Linear(nb_input_fc, 4)
resnet._modules["fc"] = fc
resnet.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [8]:
optimizer = AdamW(resnet.parameters(), 10e-5)
loss = torch.nn.CrossEntropyLoss()

## Training loop

In [18]:
nb_epochs = 3

In [19]:
loss_list = []

In [20]:
for epoch in tqdm(range(nb_epochs)) :
   loss_value = 0
   for it, batch in tqdm(enumerate(train_dataloader)) :
      images = batch[0].float()
      labels = batch[1].float()

      optimizer.zero_grad()
      outputs = resnet(images)
      loss_pt = loss(outputs, labels)
      loss_pt.backward()
      optimizer.step()
      loss_value+=loss_pt.item()

   loss_list.append(loss_value)

  0%|          | 0/3 [00:00<?, ?it/s]

In [17]:
loss_list

[1.2130813598632812]

## Inference loop

## Test metrics