In [None]:
# If using Google Colab, mount Google Drive and set up the path
from google.colab import drive
drive.flush_and_unmount()
drive.mount('/content/drive', force_remount=True)
import sys
!ls /content/drive/MyDrive/'Colab Notebooks'/conformal-prediction-introduction
sys.path.append('/content/drive/MyDrive/Colab Notebooks/conformal-prediction-introduction')


In [2]:
import torch
from torchvision import datasets, transforms, models
from src.data import IndexedDataset
from src.train import train_model, evaluate_and_save
 
import torch.nn as nn
from functools import partial


print("Running with",  "cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_folder = "dataset"
results_folder = "results"
num_workers = 6
val_frac = 0.2
holdout_frac = 0.2
epochs = 10
lr = 1e-3

# Data transforms
transform_train = transforms.Compose(
    [
        transforms.Resize(224), # Resize image
        transforms.RandomHorizontalFlip(), # Randomly mirror images
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)
transform_test = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

# Datasets and loaders
full_trainset = IndexedDataset(
    datasets.CIFAR10(
        root=data_folder, train=True, download=True, transform=transform_train
    )
)
val_size = int(val_frac * len(full_trainset))
holdout_size = int(holdout_frac * len(full_trainset))
train_size = len(full_trainset) - val_size - holdout_size
trainset, valset, holdoutset = torch.utils.data.random_split(
    full_trainset, [train_size, val_size, holdout_size]
)
testset = IndexedDataset(
    datasets.CIFAR10(
        root=data_folder, train=False, download=True, transform=transform_test
    )
)

dataloader_settings = partial(
    torch.utils.data.DataLoader,
    batch_size=64,
    num_workers=num_workers,
    pin_memory=True if device.type == "cuda" else False,
    persistent_workers=True,
)

trainloader = dataloader_settings(trainset, shuffle=True, drop_last=True)
valloader = dataloader_settings(valset, shuffle=False)
testloader = dataloader_settings(testset, shuffle=False)
holdoutloader = dataloader_settings(holdoutset, shuffle=False)

# Model setup
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)



# Train the model
model = train_model(model, trainloader, valloader, device, epochs=epochs, lr=lr)

# Evaluation on validation set
evaluate_and_save(model, valloader, device, results_folder, "val_predictions.pth")

# Evaluation on test set
evaluate_and_save(model, testloader, device, results_folder, "test_predictions.pth")

# Save the model and predictions on the holdout set
torch.save(model.state_dict(), results_folder + "/cifar10_resnet18.pth")
evaluate_and_save(model, holdoutloader, device, results_folder, "holdout_predictions.pth")


Running with cuda
Files already downloaded and verified
Files already downloaded and verified


 10%|█         | 1/10 [03:13<29:04, 193.87s/it]

Epoch 1/10, Loss: 0.6383, Validation Accuracy: 84.00%


 20%|██        | 2/10 [04:26<16:20, 122.60s/it]

Epoch 2/10, Loss: 0.3998, Validation Accuracy: 85.01%


 30%|███       | 3/10 [05:39<11:39, 99.94s/it] 

Epoch 3/10, Loss: 0.2981, Validation Accuracy: 88.33%


 40%|████      | 4/10 [06:52<08:56, 89.44s/it]

Epoch 4/10, Loss: 0.2402, Validation Accuracy: 89.08%


 50%|█████     | 5/10 [08:06<06:59, 83.90s/it]

Epoch 5/10, Loss: 0.1987, Validation Accuracy: 88.36%


 60%|██████    | 6/10 [09:20<05:21, 80.38s/it]

Epoch 6/10, Loss: 0.1626, Validation Accuracy: 87.88%


 70%|███████   | 7/10 [10:33<03:53, 77.83s/it]

Epoch 7/10, Loss: 0.1524, Validation Accuracy: 88.69%


 80%|████████  | 8/10 [11:46<02:32, 76.29s/it]

Epoch 8/10, Loss: 0.1264, Validation Accuracy: 89.29%


 90%|█████████ | 9/10 [12:59<01:15, 75.23s/it]

Epoch 9/10, Loss: 0.1039, Validation Accuracy: 89.41%


100%|██████████| 10/10 [14:11<00:00, 85.19s/it]

Epoch 10/10, Loss: 0.0991, Validation Accuracy: 89.77%





Accuracy on val set: 89.22%
Accuracy on test set: 88.89%
Accuracy on holdout set: 89.21%
