In [1]:
from torch_uncertainty import cli_main, init_args
from torch_uncertainty.baselines.classification import ResNet
from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18
from torch_uncertainty.routines.classification import ClassificationSingle
from torch_uncertainty.datamodules import CIFAR10DataModule
import torch_uncertainty.losses
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.nn import CrossEntropyLoss
from pytorch_lightning import Trainer
import torch

import os
import sys
from pathlib import Path
from cli_test_helpers import ArgvContext

In [2]:
# Report whether PyTorch can see a CUDA device in this kernel.
cuda_available = torch.cuda.is_available()
cuda_available

True

### 1 - Models

In [4]:
def _make_cifar10_resnet18(**overrides):
    """Build a "std" ResNet-18 for 10-class, 3-channel CIFAR-10 inputs.

    Every model in this notebook shares the same architecture, loss and
    optimization procedure; only the mixup-related options differ. Any extra
    keyword arguments (e.g. ``mixup=True, mixup_alpha=0.2``) are forwarded
    unchanged to ``ResNet``.

    Returns:
        The constructed ``ResNet`` baseline.
    """
    return ResNet(
        num_classes=10,
        loss=CrossEntropyLoss,
        optimization_procedure=optim_cifar10_resnet18,
        version="std",
        in_channels=3,
        arch=18,
        **overrides,
    )


# Plain ResNet-18 trained with cross-entropy.
baseline = _make_cifar10_resnet18()

# Same network with mixup enabled (mixup_alpha=0.2).
mixup = _make_cifar10_resnet18(mixup=True, mixup_alpha=0.2)

# Same network with RegMixup enabled (mixup_alpha=15).
regmixup = _make_cifar10_resnet18(reg_mixup=True, mixup_alpha=15)

### 2 - Data

In [5]:
# Standalone torchvision datasets/loaders for CIFAR-10.
# NOTE(review): the training cell in this notebook builds its own
# CIFAR10DataModule and does not use these loaders — confirm they are needed.
# ToTensor yields values in [0, 1]; normalising every channel with
# mean = std = 0.5 rescales them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
batch_size = 4

# Both splits are stored under ./data and downloaded if missing.
trainset = CIFAR10(root="./data", train=True, download=True, transform=transform)
testset = CIFAR10(root="./data", train=False, download=True, transform=transform)

# Shuffle only the training split so test iteration order stays fixed.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2
)

# Human-readable label names (presumably in CIFAR-10 label-index order).
classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

Files already downloaded and verified
Files already downloaded and verified


### 3 - Training

In [6]:
# Resolve paths relative to the current working directory.
root = Path.cwd()

# We mock the command-line arguments for the trainer: init_args parses the
# argv that ArgvContext temporarily installs.
# Give each flag only once — argparse silently keeps the last occurrence of
# a repeated option, so duplicates hide the value actually used.
with ArgvContext(
    "file.py",
    "--max_epochs",
    "2",
    "--enable_progress_bar",
    "False",
    # NOTE(review): the models above use version="std"; confirm whether
    # --num_estimators has any effect for that version.
    "--num_estimators",
    "8",
):
    args = init_args(network=ResNet, datamodule=CIFAR10DataModule)

# Experiment name handed to cli_main (presumably the log sub-directory).
net_name = "logs/resnet18-cifar10"

# Datamodule reads CIFAR-10 from ./data under the working directory.
args.root = str(root / "data")
dm = CIFAR10DataModule(**vars(args))


In [None]:
# Run the CLI routine on the baseline only; `mixup` and `regmixup` are
# constructed in the models cell but not run here.
# NOTE(review): the accelerator/devices come from the Trainer options in
# `args`; if GPU training is intended, add e.g. "--accelerator", "gpu",
# "--devices", "1" to the mocked argv — confirm against the installed
# Lightning version's defaults.
results = cli_main(baseline, dm, root, net_name, args=args)