In [6]:
import torch
from torch.utils.data import DataLoader

from torchvision.models import alexnet, vgg16, googlenet, inception_v3, resnet18, densenet161
from torchvision.datasets import PCAM
import torchvision.transforms as transforms

import ipywidgets as widgets
from IPython.core.display import display

from pcam import train


  from IPython.core.display import display


In [7]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)


cpu


In [8]:
## Dataset and data loaders

# PILToTensor only converts the PIL image to a tensor; no rescaling or
# normalization is applied here.
# NOTE(review): presumably `pcam.train` casts/normalizes the inputs before the
# forward pass -- confirm, since the models are loaded with pretrained weights.
transform = transforms.Compose([transforms.PILToTensor()])

# Both splits share the same root directory, download flag and transform.
dataset_kwargs = dict(root='data', download=True, transform=transform)
train_dataset = PCAM(split='train', **dataset_kwargs)
val_dataset = PCAM(split='val', **dataset_kwargs)

# NOTE(review): the validation loader is shuffled as well; shuffling is usually
# only needed for training -- confirm whether `pcam.train` relies on it.
loader_kwargs = dict(batch_size=32, shuffle=True)
train_loader = DataLoader(train_dataset, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)


In [9]:
## Model Selection

# (label, constructor) pairs: the label is shown in the picker, and the
# widget's `value` is the torchvision constructor function itself.
model_choices = [
    ('AlexNet', alexnet),
    ('VGG-16', vgg16),
    ('GoogleNet', googlenet),
    ('Inception-v3', inception_v3),
    ('ResNet-18', resnet18),
    ('DenseNet-161', densenet161),
]

# AlexNet is pre-selected.
model_widget = widgets.Select(
    options=model_choices,
    value=alexnet,
    description='Model:',
    disabled=False,
)
display(model_widget)

Select(description='Model:', options=(('AlexNet', <function alexnet at 0x7f38e678b9d0>), ('VGG-16', <function …

In [11]:
## Model Initialization

def _replace_classifier_head(model, num_classes):
    """Replace the final classification layer of ``model`` with a new ``Linear``.

    Handles the three head layouts used by the models offered in the picker:

    * ``model.fc`` is a ``Linear`` (ResNet, GoogLeNet, Inception-v3),
    * ``model.classifier`` is a ``Linear`` (DenseNet),
    * ``model.classifier`` is a ``Sequential`` ending in a ``Linear``
      (AlexNet, VGG).

    Args:
        model: The network whose head is replaced in place.
        num_classes: Number of output units of the new head.

    Returns:
        The newly created ``torch.nn.Linear`` layer (already attached to
        ``model``), so the caller can hand its parameters to an optimizer.

    Raises:
        ValueError: If no supported head layout is found on ``model``.
    """
    fc = getattr(model, 'fc', None)
    if isinstance(fc, torch.nn.Linear):
        model.fc = torch.nn.Linear(fc.in_features, num_classes)
        return model.fc

    classifier = getattr(model, 'classifier', None)
    if isinstance(classifier, torch.nn.Linear):
        model.classifier = torch.nn.Linear(classifier.in_features, num_classes)
        return model.classifier
    if (
        isinstance(classifier, torch.nn.Sequential)
        and len(classifier) > 0
        and isinstance(classifier[-1], torch.nn.Linear)
    ):
        classifier[-1] = torch.nn.Linear(classifier[-1].in_features, num_classes)
        return classifier[-1]

    raise ValueError(
        f'Cannot locate a Linear classification head on {type(model).__name__}'
    )


print(f'Selected Model: {model_widget.value.__name__}')
model = model_widget.value(pretrained=True)

# Freeze every pretrained parameter. The head created below is a fresh layer,
# so its parameters are trainable (requires_grad defaults to True).
for param in model.parameters():
    param.requires_grad = False

# Create classification layer
num_classes = 2
classifier_head = _replace_classifier_head(model, num_classes)

# Move the model to the target device only AFTER the new head is attached.
# A layer created after `model.to(device)` would stay on the default device
# and cause a device mismatch when running on CUDA.
model.to(device)

# TODO: Inception-v3 still cannot run with this setup.
# NOTE(review): possibly related to its auxiliary classifier and/or its input
# size requirements -- confirm before enabling it in the picker.

## Optimizer
# Only the new head is optimized; the frozen backbone has no gradients.
optimizer = torch.optim.SGD(classifier_head.parameters(), lr=0.01, momentum=0.9)

## Loss Function
loss_fun = torch.nn.CrossEntropyLoss()


Selected Model: resnet18




In [12]:
# Run training through the project-local `pcam.train` helper using the model,
# loaders, loss and optimizer configured in the cells above.
train(
    model,
    train_loader,
    val_loader,
    loss_fun,
    optimizer,
    num_epochs=5,
    num_classes=2,
    device=device,
)

Epoch 1/1, Training: 100%|██████████| 1024/1024 [02:26<00:00,  7.00it/s]
Epoch 1/1, Validation: 100%|██████████| 1024/1024 [02:16<00:00,  7.52it/s]

Train Loss: 0.9495, Train Acc: 0.7553, Train AUC: 0.8230, 
 Val Loss: 1.0507, Val Acc: 0.7192, Val AUC: 0.8100

Saved checkpoint at: models/ResNet_lr01_epoch0.pt



