In [23]:
# data.py

from os.path import join

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
import pytorch_lightning as pl

import pandas as pd


def lymph_collate_fn(x):
    image, tabular = x
    

class lymph_dataset(Dataset):
    """Wrapper around an imagefolder dataset to provide with tabular data"""
    def __init__(self, path, csv_path, transforms):
        super().__init__()
        self.path = path
        self.csv_path = csv_path
        self.transforms = transforms
        self.images = torchvision.datasets.ImageFolder(root=self.path, transform=self.transforms)
        self.tabular = pd.read_csv(self.csv_path, index_col=0)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image, lab_idx = self.images[idx]
        label = self.images.classes[lab_idx]
        tab_data = self.tabular.loc[label]
        return image, tab_data.LABEL


class lymph_datamodule(pl.LightningDataModule):
    def __init__(self, path, batch_size=32, train_prop=0.8):
        super().__init__()
        self.train_path = join(path, "trainset")
        self.train_csv_path = join(self.train_path, "trainset_true.csv")
        self.test_path = join(path, "testset")
        self.test_csv_path = join(self.test_path, "testset_data.csv")
        self.batch_size = batch_size
        self.train_prop = train_prop

        # TODO data augment
        # TODO normalize?
        self.train_transforms = transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor()
        ])
        self.test_transforms = transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.ToTensor()
        ])

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            self.full_dataset = lymph_dataset(self.train_path, self.train_csv_path, self.train_transforms)
            train_size = int(len(self.full_dataset) * self.train_prop)
            val_size = len(self.full_dataset) - train_size
            self.train_dataset, self.val_dataset = random_split(self.full_dataset, [train_size, val_size])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.test_dataset = lymph_dataset(self.test_path, self.test_csv_path, self.test_transform)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset,
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers=8,
                                           pin_memory=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset,
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers=8,
                                           pin_memory=True)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_dataset,
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers=8,
                                           pin_memory=True)

In [38]:
# model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.optim import SGD, Adam
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import pytorch_lightning as pl


# Some code reused from https://github.com/Stevellen/ResNet-Lightning/blob/master/resnet_classifier.py
class TransferResNet(pl.LightningModule):
    def __init__(self,
                 resnet_version=18,
                 optimizer='adam',
                 lr=1e-3):
        super().__init__()
        self.__dict__.update(locals())

        resnets = {
            18: models.resnet18, 34: models.resnet34,
            50: models.resnet50, 101: models.resnet101,
            152: models.resnet152
        }
        optimizers = {'adam': Adam, 'sgd': SGD}

        self.optimizer = optimizers[optimizer]
        
        self.criterion = nn.BCEWithLogitsLoss()

        self.resnet = resnets[resnet_version](pretrained=True)
        linear_size = list(self.resnet.children())[-1].in_features
        self.resnet.fc = nn.Linear(linear_size, 1)

        for child in list(self.resnet.children())[:-1]:
            for param in child.parameters():
                param.requires_grad = False
        self.resnet.eval()

    def forward(self, im):
        return self.resnet(im).squeeze()

    def configure_optimizers(self):
        return self.optimizer(self.parameters(), lr=self.lr)
    
    def training_step(self, batch, batch_idx):
        image, label = batch
        score = self(image)
        loss = self.criterion(score, label.float())
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        image, label = batch
        score = self(image)
        loss = self.criterion(score, label.float())
        pred = (score > 0.5).int()
        acc = (pred == label).sum() / pred.shape[0]
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss


In [39]:
# train.py

from argparse import ArgumentParser
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint


pl.seed_everything(42)
    
# Handle the data
dm = lymph_datamodule("data", batch_size=16)

# Define model
model = TransferResNet()

# Exp logger
logger = TensorBoardLogger('logs/tensorboard_logs')

# Define training
checkpointer = ModelCheckpoint(monitor='val_loss',
                               save_top_k=3,
                               mode='min',
                               save_last=True,
                               filename='{epoch}-{val_loss:.2f}-{train_loss:.2f}')
trainer = pl.Trainer(gpus=1,
                     max_epochs=10,
                     #callbacks=[checkpointer],
                     logger=logger,
                     val_check_interval=0.5)

# Train
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params
------------------------------------------------
0 | criterion | BCEWithLogitsLoss | 0     
1 | resnet    | ResNet            | 11.2 M
------------------------------------------------
513       Trainable params
11.2 M    Non-trainable params
11.2 M    Total params


Epoch 0:  33%|███▎      | 336/1011 [01:16<02:33,  4.39it/s, loss=0.352, v_num=20]
Validating: 0it [00:00, ?it/s][A
Epoch 0:  33%|███▎      | 338/1011 [01:18<02:36,  4.31it/s, loss=0.352, v_num=20]
Epoch 0:  34%|███▍      | 346/1011 [01:19<02:31,  4.38it/s, loss=0.352, v_num=20]
Validating:   8%|▊         | 13/169 [00:02<02:28,  1.05it/s][A
Validating:   9%|▉         | 15/169 [00:02<01:52,  1.37it/s][A
Epoch 0:  35%|███▌      | 354/1011 [01:19<02:28,  4.43it/s, loss=0.352, v_num=20]
Validating:  11%|█         | 18/169 [00:03<01:11,  2.11it/s][A
Validating:  11%|█         | 19/169 [00:03<00:58,  2.56it/s][A
Validating:  12%|█▏        | 20/169 [00:04<00:52,  2.81it/s][A
Validating:  12%|█▏        | 21/169 [00:04<00:45,  3.25it/s][A
Validating:  13%|█▎        | 22/169 [00:04<00:40,  3.64it/s][A
Validating:  14%|█▎        | 23/169 [00:04<00:36,  3.99it/s][A
Validating:  14%|█▍        | 24/169 [00:04<00:37,  3.88it/s][A
Epoch 0:  36%|███▌      | 362/1011 [01:21<02:26,  4.43it/s, lo

Validating:  79%|███████▊  | 133/169 [00:31<00:10,  3.32it/s][A
Validating:  79%|███████▉  | 134/169 [00:31<00:10,  3.22it/s][A
Validating:  80%|███████▉  | 135/169 [00:31<00:10,  3.32it/s][A
Validating:  80%|████████  | 136/169 [00:32<00:09,  3.53it/s][A
Epoch 0:  47%|████▋     | 474/1011 [01:49<02:03,  4.35it/s, loss=0.352, v_num=20]
Validating:  82%|████████▏ | 138/169 [00:32<00:09,  3.19it/s][A
Validating:  82%|████████▏ | 139/169 [00:33<00:09,  3.11it/s][A
Validating:  83%|████████▎ | 140/169 [00:33<00:09,  3.00it/s][A
Validating:  83%|████████▎ | 141/169 [00:33<00:09,  2.98it/s][A
Validating:  84%|████████▍ | 142/169 [00:34<00:09,  2.97it/s][A
Validating:  85%|████████▍ | 143/169 [00:34<00:08,  2.95it/s][A
Validating:  85%|████████▌ | 144/169 [00:34<00:08,  2.94it/s][A
Epoch 0:  48%|████▊     | 482/1011 [01:51<02:02,  4.31it/s, loss=0.352, v_num=20]
Validating:  86%|████████▋ | 146/169 [00:35<00:07,  3.24it/s][A
Validating:  87%|████████▋ | 147/169 [00:35<00:06,  3.64

Epoch 0:  93%|█████████▎| 938/1011 [04:09<00:19,  3.76it/s, loss=0.374, v_num=20]
Validating:  57%|█████▋    | 96/169 [00:23<00:19,  3.70it/s][A
Validating:  57%|█████▋    | 97/169 [00:24<00:19,  3.67it/s][A
Validating:  58%|█████▊    | 98/169 [00:24<00:19,  3.63it/s][A
Validating:  59%|█████▊    | 99/169 [00:24<00:19,  3.65it/s][A
Validating:  59%|█████▉    | 100/169 [00:24<00:18,  3.67it/s][A
Validating:  60%|█████▉    | 101/169 [00:25<00:18,  3.68it/s][A
Validating:  60%|██████    | 102/169 [00:25<00:18,  3.70it/s][A
Epoch 0:  94%|█████████▎| 946/1011 [04:11<00:17,  3.76it/s, loss=0.374, v_num=20]
Validating:  62%|██████▏   | 104/169 [00:25<00:17,  3.72it/s][A
Validating:  62%|██████▏   | 105/169 [00:26<00:17,  3.74it/s][A
Validating:  63%|██████▎   | 106/169 [00:26<00:17,  3.68it/s][A
Validating:  63%|██████▎   | 107/169 [00:26<00:16,  3.71it/s][A
Validating:  64%|██████▍   | 108/169 [00:27<00:16,  3.71it/s][A
Validating:  64%|██████▍   | 109/169 [00:27<00:16,  3.72it/s

Validating:  34%|███▎      | 57/169 [00:11<00:27,  4.02it/s][A
Validating:  34%|███▍      | 58/169 [00:11<00:28,  3.90it/s][A
Validating:  35%|███▍      | 59/169 [00:11<00:27,  4.06it/s][A
Validating:  36%|███▌      | 60/169 [00:12<00:27,  3.98it/s][A
Validating:  36%|███▌      | 61/169 [00:12<00:27,  3.86it/s][A
Validating:  37%|███▋      | 62/169 [00:12<00:28,  3.81it/s][A
Epoch 1:  40%|███▉      | 400/1011 [01:55<02:56,  3.47it/s, loss=0.419, v_num=20]
Validating:  38%|███▊      | 64/169 [00:13<00:28,  3.74it/s][A
Validating:  38%|███▊      | 65/169 [00:13<00:27,  3.75it/s][A
Validating:  39%|███▉      | 66/169 [00:13<00:27,  3.73it/s][A
Validating:  40%|███▉      | 67/169 [00:13<00:27,  3.70it/s][A
Validating:  40%|████      | 68/169 [00:14<00:27,  3.69it/s][A
Validating:  41%|████      | 69/169 [00:14<00:27,  3.67it/s][A
Validating:  41%|████▏     | 70/169 [00:14<00:27,  3.64it/s][A
Epoch 1:  40%|████      | 408/1011 [01:57<02:53,  3.47it/s, loss=0.419, v_num=20]
Vali

Validating:  11%|█         | 19/169 [00:02<00:34,  4.29it/s][A
Validating:  12%|█▏        | 20/169 [00:02<00:33,  4.50it/s][A
Epoch 1:  85%|████████▌ | 864/1011 [04:08<00:42,  3.48it/s, loss=0.352, v_num=20]
Validating:  13%|█▎        | 22/169 [00:03<00:36,  4.05it/s][A
Validating:  14%|█▎        | 23/169 [00:03<00:34,  4.17it/s][A
Validating:  14%|█▍        | 24/169 [00:03<00:35,  4.10it/s][A
Validating:  15%|█▍        | 25/169 [00:04<00:33,  4.31it/s][A
Validating:  15%|█▌        | 26/169 [00:04<00:36,  3.91it/s][A
Validating:  16%|█▌        | 27/169 [00:04<00:36,  3.88it/s][A
Validating:  17%|█▋        | 28/169 [00:04<00:37,  3.75it/s][A
Epoch 1:  86%|████████▋ | 872/1011 [04:10<00:39,  3.48it/s, loss=0.352, v_num=20]
Validating:  18%|█▊        | 30/169 [00:05<00:36,  3.82it/s][A
Validating:  18%|█▊        | 31/169 [00:05<00:36,  3.78it/s][A
Validating:  19%|█▉        | 32/169 [00:06<00:36,  3.76it/s][A
Validating:  20%|█▉        | 33/169 [00:06<00:35,  3.80it/s][A
Vali

1

In [None]:
dm = lymph_datamodule("data", batch_size=16)
dm.setup("fit")
print(type(dm.train_dataset.dataset.tabular.loc["P0"].values))

In [None]:
torchvision.datasets.ImageFolder("data/trainset", transform=torchvision.transforms.ToTensor())