In [None]:
import os
import math
import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import DataLoader, random_split, Dataset
from PIL import Image
from training_loop_single import train_model
from torchvision import models

In [None]:
# Label index read by XRayDataset: one header row, then one row per image with
# the image file name in the first column and '|'-separated labels in the second.
CSV_FILE = os.path.join('data', 'Data_Entry_2017.csv')
# Alternative index file (presumably a smaller sample set -- confirm against README.md);
# uncomment to use it instead of the file above.
#CSV_FILE = os.path.join('sample_data', 'sample_labels.csv')
# Batch size shared by the training and the validation DataLoader.
BATCH_SIZE = 10

In [None]:
import csv

class XRayDataset(Dataset):
    """Single-label image dataset indexed by a CSV file.

    Each CSV row holds an image file name in its first column and a
    '|'-separated list of labels in its second column. Only rows with exactly
    one label that is not 'No Finding' are kept.

    Attributes:
        transform: callable applied to the opened PIL image, or a falsy value
            to return the PIL image unchanged.
        files: list of ``(filename, label)`` tuples in CSV order.
        classes: alphabetically sorted list of the distinct labels; the
            position of a label in this list is its integer class index.
        class_to_idx: dict mapping each label to its index in ``classes``.
        image_dir: directory the image files are loaded from.
    """

    def __init__(self, transform, csv_file=None,
                 image_dir=os.path.join('data', 'processed_images')):
        """Read the CSV index and build the sample list and class mapping.

        Args:
            transform: callable applied to every image in ``__getitem__``.
            csv_file: path of the CSV index; defaults to the module-level
                ``CSV_FILE`` (looked up at call time).
            image_dir: directory containing the image files.

        Raises:
            FileNotFoundError: if the CSV index does not exist.
        """
        self.transform = transform
        self.image_dir = image_dir
        self.files = []

        csv_file = CSV_FILE if csv_file is None else csv_file
        if not os.path.exists(csv_file):
            raise FileNotFoundError('missing csv data file {}, please download data as described in README.md'.format(csv_file))

        found_classes = set()

        # newline='' lets the csv module do its own line-ending handling.
        with open(csv_file, newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',')
            next(reader, None)  # skip header (tolerates an empty file)
            for row in reader:
                # Blank lines come back as [] and cannot be unpacked.
                if len(row) < 2:
                    continue
                filename = row[0]
                labels = row[1].split('|')
                # Keep single-finding images only.
                if len(labels) != 1 or labels[0] == 'No Finding':
                    continue
                self.files.append((filename, labels[0]))
                found_classes.add(labels[0])

        # Sort instead of relying on set iteration order: string hashing is
        # randomized per interpreter run, so list(set) would assign different
        # class indices on every run and break previously saved checkpoints.
        self.classes = sorted(found_classes)
        # Constant-time label lookup instead of list.index() for every sample.
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

    def __getitem__(self, index):
        """Return ``(image, class_index)`` for the sample at ``index``.

        The class index is a 0-dim ``torch.long`` tensor.
        """
        filename, label = self.files[index]
        image = Image.open(os.path.join(self.image_dir, filename))
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(self.class_to_idx[label], dtype=torch.long)

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.files)

In [None]:
def load_dataset():
    """Build the training and validation loaders from the X-ray dataset.

    The samples are split 80/20 at random. Training images get random colour
    jitter and horizontal flips; validation images are only converted to
    tensors so that validation metrics are not perturbed by augmentation.

    Returns:
        tuple: ``(num_classes, train_loader, validation_loader)``.
    """
    import copy  # only needed here; keeps the block self-contained

    train_transform = transforms.Compose([
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.0),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    eval_transform = transforms.ToTensor()

    train_dataset = XRayDataset(transform=train_transform)
    # A shallow copy shares the sample list and class order with the training
    # dataset but carries its own transform, so the same indices address the
    # same images without the random augmentation.
    eval_dataset = copy.copy(train_dataset)
    eval_dataset.transform = eval_transform

    size = len(train_dataset)
    train_size = int(size * 0.8)
    train, validate = random_split(train_dataset, [train_size, size - train_size])
    # Re-point the validation indices at the augmentation-free dataset.
    validate = torch.utils.data.Subset(eval_dataset, validate.indices)

    # Reshuffle the training samples every epoch; validation order is irrelevant.
    loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
    validation_loader = DataLoader(validate, batch_size=BATCH_SIZE)

    return len(train_dataset.classes), loader, validation_loader

In [None]:
# Use the first CUDA device when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_classes, loader, validation_loader = load_dataset()

# Size the classifier head to the dataset. Without num_classes the network
# keeps torchvision's default 1000-way output layer, so it could predict
# indices that do not correspond to any class in the dataset.
model = models.resnet18(pretrained=False, num_classes=num_classes)

# Alternative kept for reference: resnet50 with all existing parameters frozen
# and a freshly created (trainable) final layer.
#model = models.resnet50(pretrained=False)
#for param in model.parameters():
#    param.requires_grad = False
#num_ftrs = model.fc.in_features
#model.fc = nn.Linear(num_ftrs, num_classes)

# Resume from a saved state dict. NOTE(review): the checkpoint's final layer
# must have num_classes outputs; a file saved with the default 1000-way head
# will not load into the model above.
#state_dict = torch.load('resnet18_single_full_0.65617.pth')
#model.load_state_dict(state_dict)

model = model.to(device)

criterion = nn.CrossEntropyLoss()
#optimizer = optim.Adam(params=model.parameters(), lr=0.0001)
optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

# train_model comes from training_loop_single; the first argument is presumably
# the run name used for saved artefacts -- confirm in that module.
train_model(
    'resnet18_single_full_findingsonly_augmentation',
    model,
    {'train': loader, 'val': validation_loader},
    criterion,
    optimizer,
    device,
    num_epochs=1000
)