In [None]:
import os
import math
import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import DataLoader, random_split, Dataset
from PIL import Image
from training_loop_single import train_model
from torchvision import models
from inceptionresnetv2 import InceptionResNetV2

In [None]:
# Label metadata CSV read by XRayDataset: a header row, then one row per image
# with the image filename in the first column and '|'-separated labels in the
# second. An alternative label file referenced earlier in development:
# os.path.join('sample_data', 'sample_labels.csv')
CSV_FILE = os.path.join('data', 'Data_Entry_2017.csv')

# Images per batch, used for both the training and validation DataLoaders.
BATCH_SIZE = 15

In [None]:
import csv
from collections import defaultdict

class XRayDataset(Dataset):
    """Single-label chest X-ray dataset built from the label CSV in ``CSV_FILE``.

    An image is kept only if it is listed in the split file
    (``data/test_list.txt`` when ``validation`` is true, otherwise
    ``data/train_val_list.txt``), carries exactly one label, and that label is
    not ``'No Finding'``. Images are read from ``data299/processed_images``.

    Args:
        transform: callable applied to each loaded RGB PIL image.
        validation: use the test split list instead of the train/val list.

    Attributes:
        files: list of ``(filename, label)`` pairs that passed the filters.
        classes: sorted list of label names seen in this split; a sample's
            target is the index of its label in this list.
        class_counts: label -> number of kept images with that label.
        class_weights: float tensor of ``1 / count`` per class, ordered like
            ``classes`` (inverse-frequency weights for a weighted loss).

    Raises:
        FileNotFoundError: if ``CSV_FILE`` does not exist (an ``Exception``
            subclass, so generic ``except Exception`` handlers still match).
    """

    def __init__(self, transform, validation=False):
        self.transform = transform
        self.files = []
        if not os.path.exists(CSV_FILE):
            raise FileNotFoundError('missing csv data file {}, please download data as described in README.md'.format(CSV_FILE))

        self.classes = set()
        self.class_counts = defaultdict(int)

        split_file = 'data/test_list.txt' if validation else 'data/train_val_list.txt'
        with open(split_file) as f:
            filenames = {line.strip() for line in f}

        # newline='' is what the csv module documents for files given to csv.reader.
        with open(CSV_FILE, newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',')
            next(reader)  # skip header
            for row in reader:
                filename, labels, *_ = row
                if filename not in filenames:
                    continue

                labels = labels.split('|')
                # only use images with a single class, excluding 'No Finding'
                if len(labels) != 1 or labels[0] == 'No Finding':
                    continue
                self.files.append((filename, labels[0]))
                self.classes.update(labels)
                self.class_counts[labels[0]] += 1

        # Sort so the label -> index mapping is deterministic; iteration order
        # of a set of strings is not guaranteed to be stable across runs.
        self.classes = sorted(self.classes)
        # Every class in the list was counted at least once, so no division by zero.
        self.class_weights = torch.tensor(
            [1 / self.class_counts[class_] for class_ in self.classes],
            dtype=torch.float,
        )

    def __getitem__(self, index):
        filename, label = self.files[index]
        path = os.path.join('data299', 'processed_images', filename)
        # Load inside ``with`` so the file handle is released immediately, and
        # force 3-channel RGB: 3-channel normalization and ImageNet-pretrained
        # backbones expect three channels, and source files are not guaranteed
        # to be stored as RGB (convert is a copy for images that already are).
        with Image.open(path) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        # Index lookup (rather than a cached dict) keeps this correct if
        # ``classes`` is reassigned after construction.
        return image, torch.tensor(self.classes.index(label), dtype=torch.long)

    def __len__(self):
        return len(self.files)

In [None]:
def load_dataset():
    """Build the training and validation loaders for the single-label X-ray data.

    Returns:
        tuple: ``(num_classes, train_loader, validation_loader, class_weights)``,
        where ``class_weights`` is the training split's per-class
        inverse-frequency tensor, ordered like the training class list.
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Random augmentation belongs to the training split only; validation must be
    # deterministic so its metrics are comparable from epoch to epoch.
    # (RandomHorizontalFlip was left disabled in the original pipeline.)
    train_transform = transforms.Compose([
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.0),
        transforms.ToTensor(),
        normalize,
    ])
    validation_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train = XRayDataset(transform=train_transform)
    validation = XRayDataset(transform=validation_transform, validation=True)

    # Each XRayDataset builds its class list from its own split, so if a label is
    # absent from one split the two label->index mappings disagree. Re-index the
    # validation set against the training classes and drop samples whose label
    # the model has no output for.
    train_class_set = set(train.classes)
    validation.files = [
        (filename, label) for filename, label in validation.files
        if label in train_class_set
    ]
    validation.classes = train.classes

    # Shuffle training batches: without it every epoch sees images in the same
    # CSV order, giving correlated batches.
    loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
    validation_loader = DataLoader(validation, batch_size=BATCH_SIZE)
    return len(train.classes), loader, validation_loader, train.class_weights

In [None]:
# Seed torch's RNGs (classifier init, DataLoader shuffling) for reproducible runs.
torch.manual_seed(0)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_classes, loader, validation_loader, weights = load_dataset()

model = models.densenet121(pretrained=True)
num_ftrs = model.classifier.in_features
# No Sigmoid before the loss: nn.CrossEntropyLoss applies log-softmax itself and
# expects raw, unnormalized logits. Squashing them into (0, 1) first caps the
# logit gap between classes and weakens the gradients. Argmax predictions are
# unaffected because Sigmoid is monotonic.
# NOTE(review): if train_model thresholds outputs as probabilities, apply
# torch.softmax there instead — confirm.
# nn.Sequential is kept so state_dict keys (classifier.0.*) stay the same shape.
model.classifier = nn.Sequential(nn.Linear(num_ftrs, num_classes))

# Alternative backbone:
#model = InceptionResNetV2(num_classes=num_classes)

model = model.to(device)

# Per-class weights from the training split counter class imbalance.
criterion = nn.CrossEntropyLoss(weight=weights.to(device))
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)
# Alternative optimizer:
#optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

train_model(
    # Run name reflects the model actually trained (DenseNet-121, not InceptionResNetV2).
    'densenet121_single_full_nonofindings_weighted_augmented',
    model,
    {'train': loader, 'val': validation_loader},
    criterion,
    optimizer,
    device,
    num_epochs=1000
)