In [1]:
# Uncomment to install the WILDS benchmark package into this kernel's environment:
# %pip install wilds

In [2]:
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
from wilds.common.data_loaders import get_eval_loader
import torchvision.transforms as transforms
from torchvision.models import DenseNet
import torch
import pytorch_lightning as pl

In [3]:
# Fetch the complete Camelyon17 dataset from WILDS, downloading it if necessary.
dataset = get_dataset(
    dataset="camelyon17",
    download=True,
)

In [4]:
# The only preprocessing is conversion of each image to a tensor.
trans = transforms.Compose([transforms.ToTensor()])

# Training split; frac=1 requests the whole split.
train_data = dataset.get_subset("train", frac=1, transform=trans)

# Evaluation splits, built in the order they are evaluated later on.
test_datasets = [
    dataset.get_subset(split, transform=trans)
    for split in ("test", 'val', 'id_val')
]
test_data, val_data, id_val_data = test_datasets

# Labels attached to the metrics of each entry in test_datasets (same order).
test_split_names = ['test', 'val', 'idval']

In [5]:
class LitDensenet(pl.LightningModule):
    """LightningModule wrapping a torchvision ``DenseNet`` for WILDS-style datasets.

    The network is built with ``DenseNet(growth_rate=k, num_classes=num_classes)``
    and all other torchvision defaults (the original docstring describes this
    as a Densenet121 with growth parameter ``k``).

    Each batch is expected to unpack into ``(inputs, labels, metadata)``, and
    every dataset must expose ``eval(y_pred, y_true, metadata)`` whose first
    return value is a dict of metric name -> value.

    Args:
        trainset: Training subset used by ``train_dataloader``.
        testsets: Sequence of evaluation subsets; one eval loader is built
            per entry.
        testset_names: One label per entry of ``testsets``; appended to the
            metric names when logging.
        k: DenseNet growth rate.
        num_classes: Number of output classes.
        lr: Adam learning rate.
        train_batch_size: Batch size of the training loader.
        test_batch_size: Batch size of every evaluation loader.
        num_workers: Worker processes per data loader.

    Raises:
        ValueError: If fewer names than evaluation sets are supplied.
    """

    def __init__(self, trainset, testsets, testset_names,
                 k=32, num_classes=10, lr=1e-4, train_batch_size=256, test_batch_size=256,
                 num_workers=8):
        super().__init__()
        # Fail fast: otherwise the mismatch only surfaces as an IndexError
        # at the end of the first validation pass.
        if len(testset_names) < len(testsets):
            raise ValueError(
                f"got {len(testsets)} test sets but only "
                f"{len(testset_names)} test set names"
            )
        self.trainset = trainset
        self.testsets = testsets
        self.testset_names = testset_names
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers
        self.model = DenseNet(growth_rate=k, num_classes=num_classes)
        self.lr = lr

    def train_dataloader(self):
        """Return the WILDS "standard" training loader over ``self.trainset``."""
        return get_train_loader("standard", self.trainset,
                                batch_size=self.train_batch_size,
                                num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self):
        """Return one WILDS "standard" eval loader per entry of ``self.testsets``."""
        return [
            get_eval_loader("standard", testset,
                            batch_size=self.test_batch_size,
                            num_workers=self.num_workers, pin_memory=True)
            for testset in self.testsets
        ]

    def _run_batch(self, batch):
        """Forward one batch; return loss plus CPU copies of predictions, labels, metadata."""
        inputs, labels, metadata = batch
        # Follow the module's device instead of hard-coding CUDA so the same
        # code also runs on CPU; this is a no-op when the batch is already there.
        inputs = inputs.to(self.device)
        labels = labels.to(self.device)
        outputs = self.model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)

        # Hard predictions: index of the largest logit, detached from the graph.
        y_pred = outputs.detach().argmax(dim=1)
        # Kept on the CPU because every step's output is retained until epoch end.
        return {'loss': loss,
                'y_pred': y_pred.cpu(),
                'labels': labels.cpu(),
                'metadata': metadata.cpu()}

    def _evaluate_and_log(self, dataset, step_outputs, split_name):
        """Run ``dataset.eval`` on the concatenated step outputs and log every metric."""
        results = dataset.eval(
            torch.cat([out['y_pred'] for out in step_outputs]),
            torch.cat([out['labels'] for out in step_outputs]),
            torch.cat([out['metadata'] for out in step_outputs]),
        )[0]
        for key, value in results.items():
            # Every metric is tagged with its split. Untagged per-group keys
            # would be shared by train and all eval splits and get mixed.
            self.log(key + '_' + split_name, value, on_epoch=True)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one training batch."""
        return self._run_batch(batch)

    def training_epoch_end(self, train_step_outputs):
        """Evaluate the epoch's training predictions and log them with a ``_train`` suffix."""
        self._evaluate_and_log(self.trainset, train_step_outputs, 'train')

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Compute loss and predictions for one batch of eval loader ``dataloader_idx``.

        ``dataloader_idx`` defaults to 0 so the hook can also be called
        without it, as happens when only one eval loader is configured.
        """
        step_output = self._run_batch(batch)
        step_output['test_idx'] = dataloader_idx
        return step_output

    def validation_epoch_end(self, val_step_outputs) -> None:
        """Evaluate each eval set on its collected outputs and log per-split metrics."""
        # A flat list of step dicts (rather than one list per loader) means
        # there is a single eval loader; wrap it so the loop below is uniform.
        if val_step_outputs and isinstance(val_step_outputs[0], dict):
            val_step_outputs = [val_step_outputs]
        for testset, testset_name, step_outputs in zip(
                self.testsets, self.testset_names, val_step_outputs):
            self._evaluate_and_log(testset, step_outputs, testset_name)

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [6]:
# Run configuration.
max_epoch = 1      # number of training epochs
batch_size = 256   # examples per batch
lr = 1e-4          # optimizer learning rate

# Model shape.
k = 16             # DenseNet growth rate
n_cls = 2          # number of output classes

# Only used in the logger's run name in the cells shown here.
label_noise = 0

In [7]:
# Build the Lightning module. batch_size from the config cell is passed
# explicitly; before, it was silently ignored in favour of the class defaults.
model = LitDensenet(train_data, test_datasets, test_split_names,
                    k=k, num_classes=n_cls, lr=lr,
                    train_batch_size=batch_size,
                    test_batch_size=batch_size)

In [8]:
# Metrics are written as CSV under logs/<run_name>/.
run_name = f"densenet_width{k}_noise{label_noise}"
logger = pl.loggers.CSVLogger('logs', name=run_name)

In [9]:
# One GPU, 16-bit precision, metrics routed to the CSV logger.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision=16,
    max_epochs=max_epoch,
    logger=logger,
)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [10]:
# Train the model; its own train/val dataloader hooks supply the data.
trainer.fit(model=model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type     | Params
-----------------------------------
0 | model | DenseNet | 1.8 M 
-----------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
3.646     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 1784/1784 [02:04<00:00, 14.36it/s, loss=0.0876, v_num=1]


: 