In [1]:
import os
import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision import models
import torch
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import precision_score, recall_score, f1_score
from torch import nn
from torch.utils.data.dataloader import DataLoader
from matplotlib import pyplot as plt

import random

In [2]:
# Annotation files (one "<image name>,<label>|<label>|..." row per image).
TRAINCSV = 'DATASET_MED/Only_labels/Only_labels_train.csv'
VALCSV = 'DATASET_MED/Only_labels/Only_labels_val.csv'

# Class names; their order defines the layout of the multi-hot target vector.
LABELS = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Effusion',
    'Emphysema',
    'Fibrosis',
    'Hernia',
    'Infiltration',
    'Mass',
    'Nodule',
    'Pleural_Thickening',
    'Pneumonia',
    'Pneumothorax',
]

In [3]:
class XDataset(Dataset):
    """Multi-label image dataset described by a comma-separated annotation file.

    Every non-blank line of the annotation file is parsed as
    ``<image file name>,<label>|<label>|...``; any further comma-separated
    fields are ignored. Image paths are built by joining the directory of the
    annotation file, the sub-directory ``Images`` and the image file name.
    A row whose first label is ``No Finding`` gets an all-zero target.

    Args:
        anno_path: Path to the annotation file.
        transforms: Callable applied to the loaded RGB ``PIL.Image``, or
            ``None`` to return the image unchanged.
        classes: Optional sequence of class names that fixes the order of the
            target vector. Defaults to the module-level ``LABELS``.

    Attributes:
        imgs: List of image paths, one per annotation row.
        annos: List of float ``numpy`` arrays of length ``len(classes)`` with
            1.0 for every class present in the row and 0.0 otherwise.
        classes: The class names used to build ``annos``.

    Raises:
        IndexError: If a non-blank line has no comma-separated label field.
    """

    def __init__(self, anno_path, transforms, classes=None):
        self.transforms = transforms
        self.classes = LABELS if classes is None else classes

        self.imgs = []
        self.annos = []
        print('loading', anno_path)

        # os.path.dirname is used instead of string replacement so the result
        # does not depend on the file name being unique within the path.
        image_dir = os.path.join(os.path.dirname(anno_path), 'Images')

        with open(anno_path) as fp:
            for line in fp:
                line = line.strip()
                if not line:
                    # Tolerate blank lines instead of failing on fields[1].
                    continue
                fields = line.split(',')
                labels = [label.strip() for label in fields[1].split('|')]
                if labels[0] == 'No Finding':
                    labels = []
                self.imgs.append(os.path.join(image_dir, fields[0]))
                self.annos.append(
                    np.array([cls in labels for cls in self.classes], dtype=float)
                )

    def __getitem__(self, item):
        """Return ``(image, target)`` for index ``item``.

        The image is opened, converted to RGB and passed through
        ``self.transforms`` when one was given.
        """
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the converted image.
        with Image.open(self.imgs[item]) as raw:
            img = raw.convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img, self.annos[item]

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.imgs)

In [None]:
# Datasets without any transform, so indexing yields PIL images that can be
# passed straight to matplotlib for a visual sanity check.
dataset_val = XDataset(anno_path=VALCSV, transforms=None)
dataset_train = XDataset(anno_path=TRAINCSV, transforms=None)

def show_sample(img, binary_img_labels):
    """Plot ``img`` with a title listing the names of its positive labels.

    Args:
        img: Image accepted by ``plt.imshow``.
        binary_img_labels: Array of per-class indicators; entries greater
            than zero are looked up in ``dataset_val.classes``.
    """
    class_names = np.array(dataset_val.classes)
    positive_idx = np.argwhere(binary_img_labels > 0)[:, 0]
    title = ', '.join(class_names[positive_idx])

    plt.imshow(img)
    plt.title(title)
    plt.axis('off')
    plt.show()

# Preview the first five validation samples.
for sample_id in range(5):
    sample_img, sample_labels = dataset_val[sample_id]
    show_sample(sample_img, sample_labels)

In [5]:
class Resnet18(nn.Module):
    """ResNet-18 backbone with a multi-label classification head.

    The pretrained network's final fully connected layer is replaced by
    dropout (p=0.2) followed by a linear layer with ``n_classes`` outputs.
    ``forward`` applies a sigmoid, so each output lies in [0, 1].

    Args:
        n_classes: Number of output units.
    """

    def __init__(self, n_classes):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        feature_dim = backbone.fc.in_features
        # Swap the original classification layer for the new head.
        backbone.fc = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=feature_dim, out_features=n_classes),
        )
        self.base_model = backbone
        self.sigm = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid of the backbone output for the batch ``x``."""
        logits = self.base_model(x)
        return self.sigm(logits)

In [6]:
def calculate_metrics(pred, target, threshold=0.5):
    """Compute precision, recall and F1 under micro, macro and samples averaging.

    Args:
        pred: Array of scores; values strictly greater than ``threshold``
            count as positive predictions.
        target: Ground-truth indicator array passed to the scorers as ``y_true``.
        threshold: Decision threshold applied to ``pred``.

    Returns:
        Dict keyed ``'<average>/<metric>'`` (for example ``'micro/f1'``) with
        the corresponding score.
    """
    binarized = np.array(pred > threshold, dtype=float)
    scorers = (
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    )

    metrics = {}
    for average in ('micro', 'macro', 'samples'):
        for metric_name, scorer in scorers:
            key = '{}/{}'.format(average, metric_name)
            metrics[key] = scorer(y_true=target, y_pred=binarized,
                                  average=average, zero_division=True)
    return metrics

In [7]:
# Data loading and optimisation settings.
num_workers = 0  # DataLoader worker processes; 0 loads in the main process
lr = 1e-4  # learning rate handed to the optimizer (1e-3 was an earlier value)
batch_size = 8
save_freq = 1  # a checkpoint is written when epoch % save_freq == 0
test_freq = 800  # validation runs when iteration % test_freq == 0
max_epoch_number = 35  # epoch limit used by the training loop

# Per-channel normalisation statistics passed to transforms.Normalize
# (the standard ImageNet values that torchvision's pretrained models expect).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Use the first GPU when one is available and fall back to the CPU otherwise,
# so the notebook also runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Output directories. The spelling of 'chekpoints/' is kept as-is so that
# checkpoints written by earlier runs are still found in the same place.
save_path = 'chekpoints/'
logdir = 'logs/'

In [8]:
def checkpoint_save(model, save_path, epoch):
    """Write the model's ``state_dict`` to ``<save_path>/checkpoint-<epoch>.pth``.

    The epoch is zero-padded to six digits. When ``model`` exposes a
    ``module`` attribute (as wrapper modules such as ``nn.DataParallel`` do),
    the wrapped module's weights are saved instead of the wrapper's.

    Args:
        model: Module whose weights are saved.
        save_path: Directory that receives the checkpoint file.
        epoch: Integer epoch number used in the file name.
    """
    file_name = 'checkpoint-{:06d}.pth'.format(epoch)
    target = os.path.join(save_path, file_name)

    unwrapped = model.module if 'module' in dir(model) else model
    torch.save(unwrapped.state_dict(), target)

    print('saved checkpoint:', target)

In [9]:
# Preprocessing pipelines: convert the PIL image to a tensor, then normalise
# each channel with `mean` / `std`. Resizing is currently disabled; to enable
# it, prepend `transforms.Resize((256, 256))` to the lists below.
val_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std)]
)

# Training uses the same steps as validation (no augmentation is applied).
train_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std)]
)

In [None]:
# Datasets: the validation CSV doubles as the test split during training.
test_annotations = VALCSV
train_annotations = TRAINCSV
test_dataset = XDataset(test_annotations, val_transform)
train_dataset = XDataset(train_annotations, train_transform)

# Training batches are shuffled and the last incomplete batch is dropped;
# the test loader keeps dataset order and every sample.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True,
                              drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
# Ask the loader for its length so the count honours drop_last=True
# (a ceil-based count over-reports by one when the last batch is dropped).
num_train_batches = len(train_dataloader)

# model
model = Resnet18(len(train_dataset.classes))
model.train()
model = model.to(device)

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Make sure the checkpoint directory exists before the first save.
os.makedirs(save_path, exist_ok=True)

# Loss: binary cross-entropy on the model's sigmoid outputs.
criterion = nn.BCELoss()

# TensorBoard writer for the train/test scalars.
logger = SummaryWriter(logdir)

In [None]:
# Run training for exactly `max_epoch_number` epochs (epochs are numbered
# from 0). `iteration` counts optimisation steps across all epochs and is the
# x-axis of every logged scalar.
iteration = 0
for epoch in range(max_epoch_number):
    batch_losses = []
    for imgs, targets in train_dataloader:
        imgs, targets = imgs.to(device), targets.to(device)

        optimizer.zero_grad()

        model_result = model(imgs)
        loss = criterion(model_result, targets.type(torch.float))

        batch_loss_value = loss.item()
        loss.backward()
        optimizer.step()

        logger.add_scalar('train_loss', batch_loss_value, iteration)
        batch_losses.append(batch_loss_value)

        # detach() drops the autograd graph before converting to numpy; without
        # it the conversion raises when the tensor already lives on the CPU.
        train_metrics = calculate_metrics(model_result.detach().cpu().numpy(),
                                          targets.cpu().numpy())
        for metric in train_metrics:
            logger.add_scalar('train/' + metric, train_metrics[metric], iteration)

        if iteration % test_freq == 0:
            # Periodic evaluation on the whole test loader. Separate names are
            # used so the training batch variables are not overwritten.
            model.eval()
            test_outputs = []
            test_targets = []
            with torch.no_grad():
                for test_imgs, batch_targets in test_dataloader:
                    test_imgs = test_imgs.to(device)
                    model_batch_result = model(test_imgs)
                    test_outputs.extend(model_batch_result.cpu().numpy())
                    test_targets.extend(batch_targets.cpu().numpy())

            result = calculate_metrics(np.array(test_outputs), np.array(test_targets))
            for metric in result:
                logger.add_scalar('test/' + metric, result[metric], iteration)
            print("epoch:{:2d} iter:{:3d} test: "
                  "micro f1: {:.3f} "
                  "macro f1: {:.3f} "
                  "samples f1: {:.3f}".format(epoch, iteration,
                                              result['micro/f1'],
                                              result['macro/f1'],
                                              result['samples/f1']))

            # Back to training mode (re-enables dropout) for the next step.
            model.train()
        iteration += 1

    loss_value = np.mean(batch_losses)
    print("epoch:{:2d} iter:{:3d} train: loss:{:.3f}".format(epoch, iteration, loss_value))
    if epoch % save_freq == 0:
        checkpoint_save(model, save_path, epoch)