In [None]:
import os
import math
import torch
from torch.optim import Adam
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, random_split, Dataset
from PIL import Image
from training_loop import train_model
from inceptionresnetv2 import InceptionResNetV2

In [None]:
# Mini-batch size shared by the training and validation DataLoaders.
BATCH_SIZE = 10
# Label file read by XRayDataset: one row per image, first column is the file
# name and the second column is a '|'-separated list of labels.
CSV_FILE = os.path.join('sample_data', 'sample_labels.csv')

In [None]:
import csv

class XRayDataset(Dataset):
    """Multi-label image dataset described by a CSV file.

    Every CSV row after the header supplies an image file name (first column)
    and a ``|``-separated list of labels (second column); any further columns
    are ignored. Indexing yields ``(image, target)`` where ``target`` is a
    float multi-hot vector with one entry per name in ``self.classes``.

    Attributes (set in ``__init__``):
        transform: callable applied to the opened PIL image, or a falsy value
            to return the PIL image untouched.
        image_dir: directory the image file names are resolved against.
        files: list of ``(filename, labels)`` tuples in CSV order.
        classes: sorted list of every distinct label found in the CSV; the
            position of a label in this list is its index in the target.

    NOTE(review): the default CSV lives under 'sample_data' while the default
    image directory is 'data/processed_images' -- confirm this is intended.
    """

    def __init__(self, transform, csv_file=None, image_dir=None):
        """Read the label CSV and build the class list.

        Args:
            transform: callable applied to each PIL image (may be ``None``).
            csv_file: path of the label CSV. Defaults to the module-level
                ``CSV_FILE``.
            image_dir: directory containing the images. Defaults to
                ``data/processed_images``.

        Raises:
            FileNotFoundError: if the CSV file does not exist.
            ValueError: if a non-empty row has fewer than two columns.
        """
        if csv_file is None:
            csv_file = CSV_FILE
        if image_dir is None:
            image_dir = os.path.join('data', 'processed_images')

        self.transform = transform
        self.image_dir = image_dir
        self.files = []

        if not os.path.exists(csv_file):
            raise FileNotFoundError(
                'missing csv data file {}, please download data as described in README.md'.format(csv_file)
            )

        found_labels = set()
        # newline='' is what the csv module expects so that it can handle
        # line endings (including ones embedded in quoted fields) itself.
        with open(csv_file, newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',')
            next(reader, None)  # skip header; tolerate a completely empty file
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines; nothing to record.
                    continue
                filename, labels, *_ = row
                labels = labels.split('|')
                self.files.append((filename, labels))
                found_labels.update(labels)

        # Sort to get a stable class order. Iterating a set of strings is NOT
        # stable between interpreter runs (string hashes are randomized per
        # process), which would silently permute the meaning of each output
        # unit from one run to the next.
        self.classes = sorted(found_labels)

    def _encode_labels(self, labels):
        """Return the float multi-hot vector for ``labels`` over ``self.classes``."""
        present = set(labels)
        return torch.tensor(
            [1.0 if name in present else 0.0 for name in self.classes],
            dtype=torch.float,
        )

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        The image is opened from ``self.image_dir`` and passed through
        ``self.transform`` when one is set; ``target`` is the multi-hot label
        vector produced by ``_encode_labels``.
        """
        filename, labels = self.files[index]
        image = Image.open(os.path.join(self.image_dir, filename))
        if self.transform:
            image = self.transform(image)
        return image, self._encode_labels(labels)

    def __len__(self):
        """Return the number of samples read from the CSV."""
        return len(self.files)

In [None]:
def load_dataset(validation_fraction=0.2, seed=None):
    """Build the X-ray dataset and split it into training/validation loaders.

    The split sizes are derived from the dataset length instead of being
    hard-coded, so the function works for any CSV size. With the default
    fraction a 5606-row dataset is split 4485 / 1121.

    Args:
        validation_fraction: share of samples held out for validation; must
            be strictly between 0 and 1.
        seed: optional integer seed for the random split. ``None`` (default)
            uses PyTorch's global RNG, so the split differs between runs.

    Returns:
        Tuple ``(num_classes, train_loader, validation_loader)``.

    Raises:
        ValueError: if ``validation_fraction`` is not in the open interval
            (0, 1).
    """
    if not 0.0 < validation_fraction < 1.0:
        raise ValueError(
            'validation_fraction must be between 0 and 1, got {}'.format(validation_fraction)
        )

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    dataset = XRayDataset(transform=transform)

    # Truncate the validation share and give the remainder to training so the
    # two sizes always add up to len(dataset), as random_split requires.
    total = len(dataset)
    validation_size = int(total * validation_fraction)
    train_size = total - validation_size
    lengths = [train_size, validation_size]

    if seed is None:
        train, validate = random_split(dataset, lengths)
    else:
        split_rng = torch.Generator().manual_seed(seed)
        train, validate = random_split(dataset, lengths, generator=split_rng)

    # Reshuffle the training samples every epoch; validation order is
    # irrelevant to the metrics, so it stays sequential.
    loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
    validation_loader = DataLoader(validate, batch_size=BATCH_SIZE)

    return len(dataset.classes), loader, validation_loader

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# load_dataset() returns the number of classes plus the two DataLoaders.
num_classes, loader, validation_loader = load_dataset()

# One output unit per class; the model is moved to the selected device.
model = InceptionResNetV2(num_classes=num_classes).to(device)

# The dataset yields float multi-hot targets, hence a per-class sigmoid/BCE
# loss that takes raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = Adam(model.parameters(), lr=0.0001)

# 'basic' is train_model's first positional argument -- presumably a run
# name; confirm against training_loop.train_model.
train_model(
    'basic',
    model,
    dict(train=loader, val=validation_loader),
    criterion,
    optimizer,
    device,
    num_epochs=1000,
)