In [1]:
import cv2 as cv
import numpy as np
np.set_printoptions(suppress=True)
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.v2 as v2
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from tqdm import tqdm

import random
import glob
from os import cpu_count

ROOT='/kaggle/input/covidqu/Infection Segmentation Data/Infection Segmentation Data'

In [2]:
def imread(path):
    img = cv.imread(path, cv.IMREAD_GRAYSCALE)
    return img

In [3]:
%%capture
xray = imread(f'{ROOT}/Train/COVID-19/images/covid_1.png')
lung = imread(f'{ROOT}/Train/COVID-19/lung masks/covid_1.png')
infect = imread(f'{ROOT}/Train/COVID-19/infection masks/covid_1.png')

plt.imshow(xray, cmap='gray')
plt.xticks([])
plt.yticks([])
plt.savefig('xray')

plt.imshow(lung, cmap='gray')
plt.xticks([])
plt.yticks([])
plt.savefig('lung')

plt.imshow(infect, cmap='gray')
plt.xticks([])
plt.yticks([])
plt.savefig('infect')

In [4]:
!ls /kaggle/input/covidqu/Infection\ Segmentation\ Data/Infection\ Segmentation\ Data/Test/COVID-19/images | wc -l
!ls /kaggle/input/covidqu/Infection\ Segmentation\ Data/Infection\ Segmentation\ Data/Test/Non-COVID/images | wc -l
!ls /kaggle/input/covidqu/Infection\ Segmentation\ Data/Infection\ Segmentation\ Data/Test/Normal/images | wc -l

583
292
291


In [5]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.double_conv(x)
    
class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.down_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.down_conv(x)

class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels//2, 2, 2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x1,x2], dim=1)
        return self.conv(x)
    
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.conv = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.up1 = Up(512, 256)
        self.up2 = Up(256, 128)
        self.up3 = Up(128, 64)
        self.out = nn.Conv2d(64, n_classes, 1)
        
    # NOTE: resize so that height and width divided by 8
    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        
        x = self.up1(x4, x3)
        del x4, x3
        x = self.up2(x, x2)
        del x2
        x = self.up3(x, x1)
        del x1
        
        return self.out(x)

In [6]:
def get_paths(split):
    if split not in ['train', 'test']:
        return print("split can only be train or test!")
    if split == 'train':
        xray_paths = (
            glob.glob(f'{ROOT}/Train/*/images/*') +
            glob.glob(f'{ROOT}/Val/*/images/*')
        )
        mask_paths = (
            glob.glob(f'{ROOT}/Train/*/infection masks/*') +
            glob.glob(f'{ROOT}/Val/*/infection masks/*')
        )
    else:
         xray_paths = glob.glob(f'{ROOT}/Test/*/images/*')
         mask_paths = glob.glob(f'{ROOT}/Test/*/infection masks/*')
            
    return xray_paths, mask_paths

In [7]:
def transform(img, flip:bool):
    img = torch.Tensor(img).float().unsqueeze(0)
    img = v2.RandomHorizontalFlip(flip)(img)
    return img
    
class XRAY:
    def __init__(self, split, flip):
        if split not in ['train', 'test']:
            return print("split can only be train or test!")

        self.xray_paths, self.mask_paths = get_paths(split)
        self.flip = flip
        
    def __len__(self):
        return len(self.xray_paths)
    
    def __getitem__(self, idx):
        xray_path = self.xray_paths[idx]
        mask_path = self.mask_paths[idx]
        
        xray = imread(xray_path)
        mask = imread(mask_path)
        
        flip = random.choice([True, False]) if self.flip else False
        xray = transform(xray, flip)
        mask = transform(mask, flip)
        
        return xray, mask/255

In [8]:
train = XRAY('train', False)
test = XRAY('test', False)

NUM_CPU = cpu_count()
test_loader = DataLoader(test, batch_size=16, shuffle=False, num_workers=NUM_CPU)
train_loader = DataLoader(train, batch_size=16, shuffle=True, num_workers=NUM_CPU)

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = UNet(1,1).to(device)
criterion = nn.BCEWithLogitsLoss().to(device)
optim_fn = optim.Adam(model.parameters(), lr=1e-4)

n_epochs = 10
best_loss = 1e5
for epoch in range(1, n_epochs+1):
    model.train()
    losses = []
    for xray, mask in tqdm(train_loader, desc=f'Epoch {epoch}'):
        xray = xray.to(device)
        mask = mask.to(device)
        
        pred = model(xray)
        loss = criterion(pred, mask)
        losses.append(loss)
        
        loss.backward()
        optim_fn.step()
        
    mean_loss = sum(losses)/len(losses)
    print(f'Train loss: {mean_loss:.5f}')

    # EVAL
    model.eval()
    losses = []
    for xray, mask in tqdm(test_loader, desc=f'Validate'):
        xray = xray.to(device)
        mask = mask.to(device)
        
        with torch.no_grad():
            pred = model(xray)
            loss = criterion(pred, mask)
            losses.append(loss)
    
    mean_loss = sum(losses)/len(losses)
    print(f'Test loss: {mean_loss:.5f}')
    if mean_loss < best_loss:
        torch.save(model.state_dict(), 'best_model.pt')

Epoch 1: 100%|██████████| 292/292 [01:37<00:00,  3.00it/s]


Train loss: 0.38326


Validate: 100%|██████████| 73/73 [00:08<00:00,  8.90it/s]


Test loss: 0.27797


Epoch 2: 100%|██████████| 292/292 [01:36<00:00,  3.02it/s]


Train loss: 0.24833


Validate: 100%|██████████| 73/73 [00:08<00:00,  8.92it/s]


Test loss: 0.22347


Epoch 3: 100%|██████████| 292/292 [01:36<00:00,  3.02it/s]


Train loss: 0.21499


Validate: 100%|██████████| 73/73 [00:08<00:00,  8.89it/s]


Test loss: 0.21772


Epoch 4: 100%|██████████| 292/292 [01:36<00:00,  3.02it/s]


Train loss: 0.20923


Validate: 100%|██████████| 73/73 [00:08<00:00,  8.90it/s]


Test loss: 0.21059


Epoch 5: 100%|██████████| 292/292 [01:36<00:00,  3.02it/s]


Train loss: 0.20190


Validate: 100%|██████████| 73/73 [00:08<00:00,  8.90it/s]


Test loss: 0.19734


Epoch 6: 100%|██████████| 292/292 [01:36<00:00,  3.02it/s]


Train loss: 0.19583


Validate: 100%|██████████| 73/73 [00:08<00:00,  8.93it/s]


Test loss: 0.19547


Epoch 7: 100%|██████████| 292/292 [01:36<00:00,  3.02it/s]


Train loss: 0.19201


Validate: 100%|██████████| 73/73 [00:08<00:00,  8.84it/s]


Test loss: 0.19605


Epoch 8: 100%|██████████| 292/292 [01:36<00:00,  3.02it/s]


Train loss: 0.19294


Validate: 100%|██████████| 73/73 [00:08<00:00,  8.91it/s]


Test loss: 0.19287


Epoch 9: 100%|██████████| 292/292 [01:36<00:00,  3.02it/s]


Train loss: 0.19032


Validate: 100%|██████████| 73/73 [00:08<00:00,  8.89it/s]


Test loss: 0.18926


Epoch 10: 100%|██████████| 292/292 [01:36<00:00,  3.02it/s]


Train loss: 0.18877


Validate: 100%|██████████| 73/73 [00:08<00:00,  8.87it/s]


Test loss: 0.19267


In [10]:
model = UNet(1,1).to(device)
model.load_state_dict(torch.load('best_model.pt'))
model.eval()

rmses = []
for xray, mask in tqdm(test_loader, desc=f'Validate'):
    xray = xray.to(device)
    with torch.no_grad():
        pred = model(xray)

    pred = pred.detach().cpu()
    pred = (torch.sigmoid(pred) > 0.5).int()

    mask = mask.detach().cpu().bool().int()
    RMSE = torch.sqrt((pred - mask)**2).mean(dim=[2,3])
    rmses += RMSE
        
print(f'Metric: {sum(rmses)/len(rmses)}')

Validate: 100%|██████████| 73/73 [00:09<00:00,  7.58it/s]

Metric: tensor([0.0695])



