In [1]:
# !pip install wilds

In [2]:
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
from wilds.common.data_loaders import get_eval_loader
import torchvision.transforms as transforms
from torchvision.models import DenseNet
import torch
import pytorch_lightning as pl
import numpy as np
import pandas as pd

In [3]:
lr = 0.0001
batch_size = 256
k = 75  # DenseNet growth rate ("width"); also part of the CSV log-run name
n_cls = 2
# Label-noise rate of the dataset loaded below (data_dir='noisy_data_0.15').
# It names the CSV log run, so it must match that directory or results get
# filed under the wrong noise level.
label_noise = 0.15
max_epoch = 1

# Seed torch (RandAugment / _sample_uniform and DataLoader worker seeds derive
# from torch's RNG) and numpy so runs are reproducible.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

In [4]:
# Adapted from https://github.com/YBZh/Bridging_UDA_SSL

from PIL import Image, ImageOps, ImageEnhance, ImageDraw


def AutoContrast(img, _):
    """Maximise contrast with ImageOps.autocontrast; the magnitude argument is unused."""
    stretched = ImageOps.autocontrast(img)
    return stretched


def Brightness(img, v):
    """Apply ImageEnhance.Brightness with enhancement factor v (must be non-negative)."""
    assert 0.0 <= v
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(v)


def Color(img, v):
    """Apply ImageEnhance.Color (saturation) with enhancement factor v (must be non-negative)."""
    assert 0.0 <= v
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(v)


def Contrast(img, v):
    """Apply ImageEnhance.Contrast with enhancement factor v (must be non-negative)."""
    assert 0.0 <= v
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(v)


def Equalize(img, _):
    """Histogram-equalize img with ImageOps.equalize; the magnitude argument is unused."""
    equalized = ImageOps.equalize(img)
    return equalized


def Invert(img, _):
    """Invert img's pixel values with ImageOps.invert; the magnitude argument is unused."""
    inverted = ImageOps.invert(img)
    return inverted


def Identity(img, v):
    """No-op augmentation: hand back the very same img object, ignoring v."""
    unchanged = img
    return unchanged


def Posterize(img, v):  # [4, 8]
    """Posterize to max(1, int(v)) bits per channel via ImageOps.posterize."""
    bits = max(1, int(v))
    return ImageOps.posterize(img, bits)


def Rotate(img, v):  # [-30, 30]
    """Rotate img by angle v, passed straight to PIL's Image.rotate (degrees)."""
    angle = v
    return img.rotate(angle)


def Sharpness(img, v):
    """Apply ImageEnhance.Sharpness with enhancement factor v (must be non-negative)."""
    assert 0.0 <= v
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(v)


def ShearX(img, v):  # [-0.3, 0.3]
    """Shear along x with shear factor v, using an affine transform of the same size."""
    coeffs = (1, v, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs)


def ShearY(img, v):  # [-0.3, 0.3]
    """Shear along y with shear factor v, using an affine transform of the same size."""
    coeffs = (1, 0, 0, v, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs)


def TranslateX(img, v):
    """Translate along x by v * width pixels (v is a fraction of the image width)."""
    width = img.size[0]
    offset = v * width
    return img.transform(img.size, Image.AFFINE, (1, 0, offset, 0, 1, 0))


def TranslateXabs(img, v):
    """Translate along x by v pixels (absolute offset, not scaled by image size)."""
    coeffs = (1, 0, v, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs)


def TranslateY(img, v):
    """Translate along y by v * height pixels (v is a fraction of the image height)."""
    height = img.size[1]
    offset = v * height
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, offset))


def TranslateYabs(img, v):
    """Translate along y by v pixels (absolute offset, not scaled by image size)."""
    coeffs = (1, 0, 0, 0, 1, v)
    return img.transform(img.size, Image.AFFINE, coeffs)


def Solarize(img, v):
    """Solarize img via ImageOps.solarize with threshold v, restricted to 0 <= v <= 256."""
    assert 0 <= v <= 256
    threshold = v
    return ImageOps.solarize(img, threshold)


def Cutout(img, v):
    """Cutout whose side is the fraction v of the image width (0 <= v <= 0.5).

    Converts v to pixels and delegates to CutoutAbs.
    """
    assert 0.0 <= v <= 0.5
    side_px = v * img.size[0]
    return CutoutAbs(img, side_px)


def CutoutAbs(img, v):  # [0, 60] => percentage: [0, 0.2]
    """Paint a grey square of side about v pixels at a random position.

    The centre is drawn uniformly over the image with _sample_uniform, and the
    corners are clamped to the image bounds. The input image is not modified;
    a drawn-on copy is returned.

    Args:
        img: PIL image. The fill colour is an RGB triple, so this presumably
            expects an RGB image -- TODO confirm for other modes.
        v: side length in pixels. Non-positive values return img unchanged.
    """
    # `<= 0` rather than `< 0`: with v == 0 the rectangle below would still
    # paint one pixel, because ImageDraw rectangles include both corner points.
    if v <= 0:
        return img
    w, h = img.size
    x_center = _sample_uniform(0, w)
    y_center = _sample_uniform(0, h)

    x0 = int(max(0, x_center - v / 2.0))
    y0 = int(max(0, y_center - v / 2.0))
    x1 = min(w, x0 + v)
    y1 = min(h, y0 + v)

    xy = (x0, y0, x1, y1)
    color = (125, 123, 114)
    img = img.copy()  # draw on a copy so the caller's image is untouched
    ImageDraw.Draw(img).rectangle(xy, fill=color)
    return img


# RandAugment op table: each entry is (op, min magnitude, max magnitude).
# RandAugment picks entries by a random index and draws a magnitude uniformly
# between min and max; AutoContrast, Equalize and Identity ignore it.
# Order matters: indices come from torch's RNG, so reordering entries changes
# which ops a given seed produces.
# Invert, TranslateXabs and TranslateYabs are defined above but not used here.
FIX_MATCH_AUGMENTATION_POOL = [
    (AutoContrast, 0, 1),
    (Brightness, 0.05, 0.95),
    (Color, 0.05, 0.95),
    (Contrast, 0.05, 0.95),
    (Equalize, 0, 1),
    (Identity, 0, 1),
    (Posterize, 4, 8),
    (Rotate, -30, 30),
    (Sharpness, 0.05, 0.95),
    (ShearX, -0.3, 0.3),
    (ShearY, -0.3, 0.3),
    (Solarize, 0, 256),
    (TranslateX, -0.3, 0.3),  # fraction of image width
    (TranslateY, -0.3, 0.3),  # fraction of image height
]


def _sample_uniform(a, b):
    return torch.empty(1).uniform_(a, b).item()


class RandAugment:
    """FixMatch-style RandAugment: apply ``n`` randomly chosen ops, then a Cutout.

    Args:
        n: number of ops drawn (independently, with replacement) from
            ``augmentation_pool`` on every call; must be >= 1.
        augmentation_pool: sequence of ``(op, min_val, max_val)`` triples; each
            op is called as ``op(img, magnitude)``.
    """

    def __init__(self, n, augmentation_pool):
        assert n >= 1, "RandAugment N has to be a value greater than or equal to 1."
        self.n = n
        self.augmentation_pool = augmentation_pool

    def __call__(self, img):
        pool = self.augmentation_pool
        # All op indices are drawn before any magnitude, so the RNG draw order
        # is fixed: n randint draws, then n uniforms, then the cutout's.
        chosen = [pool[torch.randint(len(pool), (1,))] for _ in range(self.n)]
        for op, lo, hi in chosen:
            magnitude = lo + float(hi - lo) * _sample_uniform(0, 1)
            img = op(img, magnitude)
        # A Cutout covering up to half the image width is always applied last.
        return Cutout(img, _sample_uniform(0, 1) * 0.5)

In [5]:
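# # Add label noise to the training split
# Provenance: this (disabled) cell produced the noisy-label metadata. Each
# training-split `tumor` label is flipped independently with probability p.
# NOTE(review): it reads and overwrites the same metadata.csv, so re-running it
# compounds the noise. It is unseeded, so the flips are not reproducible. Its
# path ('noisy_data/...') differs from the data_dir loaded below
# ('noisy_data_0.15') -- confirm which copy was used. `noisy_y` is a boolean
# array; confirm the written `tumor` column stays 0/1 (consider .astype(int)).
# metadata = pd.read_csv('noisy_data/camelyon17_v1.0/metadata.csv', 
#                         index_col=0,
#                         dtype={'patient': 'str'})
# y_arr = metadata[metadata['split'] == 0]['tumor'].to_numpy() # training split
# p = 0.15
# flip_loc = np.random.choice([True,False], len(y_arr), p=[p, 1-p])
# noisy_y = np.logical_not(y_arr, where = flip_loc, out=y_arr.copy())
# metadata.loc[metadata['split']==0, 'tumor'] = noisy_y
# metadata.to_csv('noisy_data/camelyon17_v1.0/metadata.csv')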

In [6]:
# Load the label-noised copy of Camelyon17 from a local directory (no download).
# For the clean dataset use: get_dataset(dataset="camelyon17", download=True)
data_dir = 'noisy_data_0.15'
dataset = get_dataset(
    dataset="camelyon17",
    root_dir=data_dir,
    download=False,
)

In [7]:
# Training images: RandAugment (2 ops + Cutout) on the PIL image, then tensor
# conversion. Evaluation images: tensor conversion only.
train_trans = transforms.Compose([
    RandAugment(2, FIX_MATCH_AUGMENTATION_POOL),
    transforms.ToTensor(),
])
trans = transforms.Compose([transforms.ToTensor()])

train_data = dataset.get_subset("train", frac=1, transform=train_trans)

# Evaluation subsets, fetched in the order test, val, id_val.
test_data, val_data, id_val_data = (
    dataset.get_subset(split, transform=trans)
    for split in ("test", "val", "id_val")
)

# Names are aligned index-by-index with test_datasets and label logged metrics.
test_datasets = [test_data, val_data, id_val_data]
test_split_names = ['test', 'val', 'idval']

In [8]:
class LitDensenet(pl.LightningModule):
    ''' torchvision DenseNet (default block config, i.e. the DenseNet-121 layout)
    with growth rate ``k``, trained and evaluated on WILDS subsets.

    Args:
        trainset: training subset; passed to WILDS ``get_train_loader``, and its
            ``eval(y_pred, y_true, metadata)`` scores each training epoch.
        testsets: evaluation subsets, one validation dataloader each; list order
            defines ``dataloader_idx``.
        testset_names: one name per entry of ``testsets``; appended to every
            metric key logged for that subset.
        k: DenseNet growth rate.
        num_classes: number of output logits.
        lr: Adam learning rate.
        train_batch_size: batch size of the training loader.
        test_batch_size: batch size of each evaluation loader.
        num_workers: worker processes per dataloader (default 8).

    Raises:
        ValueError: if ``testsets`` and ``testset_names`` differ in length.
    '''
    def __init__(self, trainset, testsets, testset_names,
                 k=32, num_classes=10, lr=1e-4, train_batch_size=256, test_batch_size=256,
                 num_workers=8):
        super().__init__()
        if len(testsets) != len(testset_names):
            raise ValueError(
                f"testsets ({len(testsets)}) and testset_names "
                f"({len(testset_names)}) must have the same length")
        self.trainset = trainset
        self.testsets = testsets
        self.testset_names = testset_names
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers
        self.model = DenseNet(growth_rate=k, num_classes=num_classes)
        self.lr = lr

    def train_dataloader(self):
        return get_train_loader("standard", self.trainset,
                                batch_size=self.train_batch_size,
                                num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self):
        return [get_eval_loader("standard", dataset,
                                batch_size=self.test_batch_size,
                                num_workers=self.num_workers, pin_memory=True)
                for dataset in self.testsets]

    def _forward_batch(self, batch):
        '''Run one (inputs, labels, metadata) batch; return the loss plus CPU
        copies of predictions, labels and metadata for epoch-level evaluation.'''
        # Lightning has already moved the batch to this module's device; explicit
        # .cuda() calls are redundant on GPU and break CPU runs.
        inputs, labels, metadata = batch
        outputs = self.model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        y_pred = outputs.detach().argmax(dim=1)
        return {'loss': loss,
                'y_pred': y_pred.cpu(),
                'labels': labels.cpu(),
                'metadata': metadata.cpu()}

    def _log_split_metrics(self, subset, step_outputs, split_name, **log_kwargs):
        '''Concatenate collected step outputs, score them with ``subset.eval``
        and log every metric as ``<metric>_<split_name>``. No-op when empty.'''
        if not step_outputs:
            return
        y_pred = torch.cat([x['y_pred'] for x in step_outputs])
        y_true = torch.cat([x['labels'] for x in step_outputs])
        metadata = torch.cat([x['metadata'] for x in step_outputs])
        results = subset.eval(y_pred, y_true, metadata)[0]
        for key, value in results.items():
            # Every key gets the split suffix, not only avg/wg ones; otherwise
            # per-group metrics from different splits are logged under the same
            # name and collide.
            self.log(f'{key}_{split_name}', value, **log_kwargs)

    def training_step(self, batch, batch_idx):
        return self._forward_batch(batch)

    def training_epoch_end(self, train_step_outputs):
        self._log_split_metrics(self.trainset, train_step_outputs, 'train', on_epoch=True)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        step_output = self._forward_batch(batch)
        step_output['test_idx'] = dataloader_idx
        return step_output

    def validation_epoch_end(self, val_step_outputs) -> None:
        '''Score and log each evaluation subset's collected predictions.'''
        # Lightning passes one list per dataloader when there are several, but a
        # flat list of step outputs when there is only one.
        if len(self.testsets) == 1:
            val_step_outputs = [val_step_outputs]
        for testset, testset_name, step_outputs in zip(
                self.testsets, self.testset_names, val_step_outputs):
            self._log_split_metrics(testset, step_outputs, testset_name)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [9]:
# The config cell's batch_size drives both training and evaluation loaders;
# without passing it, the class defaults apply and editing batch_size does nothing.
model = LitDensenet(
    train_data, test_datasets, test_split_names,
    k=k, num_classes=n_cls, lr=lr,
    train_batch_size=batch_size, test_batch_size=batch_size,
)

In [10]:
# One CSV log directory per (growth rate, label-noise rate) configuration.
logger = pl.loggers.CSVLogger(
    'logs',
    name=f"densenet_width{k}_noise{label_noise}",
)

In [11]:
# Single-GPU run with 16-bit mixed precision, logging to the CSV logger above.
trainer = pl.Trainer(
    logger=logger,
    accelerator="gpu",
    devices=1,
    precision=16,
    max_epochs=max_epoch,
)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [12]:
# The LightningModule supplies its own train and validation dataloaders.
trainer.fit(model=model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type     | Params
-----------------------------------
0 | model | DenseNet | 37.2 M
-----------------------------------
37.2 M    Trainable params
0         Non-trainable params
37.2 M    Total params
74.443    Total estimated model params size (MB)


Validation sanity check: 100%|██████████| 6/6 [00:04<00:00,  2.09it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# NOTE(review): scratch arithmetic unrelated to the experiment above; delete
# before sharing. (numpy is already imported as np in the imports cell.)
np.array([15, 26, 27, 16, 24, 25, 13, 18, 19, 22, 17, 21]) * 2

array([30, 52, 54, 32, 48, 50, 26, 36, 38, 44, 34, 42])

In [None]:
# NOTE(review): scratch values with no consumer; delete before sharing.
(30, 52, 58, 32, 54, 50, 42, 46, 48, 40, 64)