In [1]:
import os

import ipywidgets as widgets
import torch
# `IPython.display` is the supported import path; `IPython.core.display` is deprecated
# and emits a warning when `display` is imported from it.
from IPython.display import display
from torchvision.models import alexnet, vgg16, googlenet, inception_v3, resnet18, densenet161

from pcam import get_dataloaders, train, test


  from IPython.core.display import display


In [2]:
## Check if GPU is used

# Run on CUDA whenever PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)


cuda


In [3]:
## Data

# Build the train / validation / test loaders with a batch size of 32.
# 'data' is presumably the local dataset root expected by pcam.get_dataloaders — confirm.
train_loader, val_loader, test_loader = get_dataloaders(
    'data',
    batch_size=32,
)


In [4]:
## Model Selection

# Each option pairs a display label with the torchvision builder that the next
# cell calls; ResNet-18 is preselected.
# NOTE(review): the chosen architecture lives only in widget state, so a fresh
# "Run All" uses the default unless the widget is changed by hand.
model_widget = widgets.Select(
    options=[
        ('AlexNet', alexnet),
        ('VGG-16', vgg16),
        ('GoogleNet', googlenet),
        ('Inception-v3', inception_v3),
        ('ResNet-18', resnet18),
        ('DenseNet-161', densenet161),
    ],
    value=resnet18,
    description='Model:',
    disabled=False,
)
display(model_widget)

Select(description='Model:', index=4, options=(('AlexNet', <function alexnet at 0x7f5e204e7b80>), ('VGG-16', <…

In [5]:
## Model Initialization

def replace_classifier_head(net, n_outputs):
    """Replace the final Linear layer of a torchvision classifier with a fresh one.

    Supports the head layouts of the architectures offered in the model picker:
    ``net.fc`` (ResNet, GoogLeNet, Inception-v3), ``net.classifier`` as a single
    Linear (DenseNet), and ``net.classifier`` as a Sequential ending in a Linear
    (AlexNet, VGG).

    Args:
        net: torchvision model whose head is replaced in place.
        n_outputs: number of output logits of the new head.

    Returns:
        The newly created ``torch.nn.Linear`` head. Its parameters are trainable
        (new modules default to ``requires_grad=True``) and it is created on the
        default (CPU) device, so the model must be moved to ``device`` afterwards.

    Raises:
        TypeError: if none of the supported head layouts is found.
    """
    head = getattr(net, 'fc', None)
    if isinstance(head, torch.nn.Linear):
        net.fc = torch.nn.Linear(head.in_features, n_outputs)
        return net.fc
    head = getattr(net, 'classifier', None)
    if isinstance(head, torch.nn.Linear):
        net.classifier = torch.nn.Linear(head.in_features, n_outputs)
        return net.classifier
    if isinstance(head, torch.nn.Sequential) and isinstance(head[-1], torch.nn.Linear):
        head[-1] = torch.nn.Linear(head[-1].in_features, n_outputs)
        return head[-1]
    raise TypeError(f'Unsupported classifier head on {type(net).__name__}')


print(f'Selected Model: {model_widget.value.__name__}')
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in favour
# of `weights=...`; kept here for compatibility with older installs.
model = model_widget.value(pretrained=True)

# Freeze all pretrained layers; only the head created below will be trained.
for param in model.parameters():
    param.requires_grad = False

# Create classification layer (its fresh parameters are trainable by default).
num_classes = 2
classifier_head = replace_classifier_head(model, num_classes)
# TODO: Inception-v3 still cannot run here. Presumably because it expects 299x299 inputs
# and, in train mode, returns auxiliary logits alongside the main output — confirm.

# Move to the target device only AFTER the head is replaced; otherwise the new layer
# stays on the CPU while the backbone sits on `device`.
model.to(device)

## Optimizer
# Only the new head is optimized; the frozen backbone is excluded.
optimizer = torch.optim.SGD(classifier_head.parameters(), lr=0.01, momentum=0.9)

## Loss Function
loss_fun = torch.nn.CrossEntropyLoss()


Selected Model: resnet18




In [6]:
# Fine-tune for 5 epochs. Pass the same `num_classes` used to build the classification
# head (and by the test cell) instead of a literal, so the trainer and the model agree.
# NOTE(review): the recorded run was interrupted (KeyboardInterrupt) after 1/8192 steps of epoch 1.
train(model, train_loader, val_loader, loss_fun, optimizer, num_epochs=5, num_classes=num_classes, device=device)

Epoch 1/5, Training:   0%|          | 1/8192 [00:01<2:38:18,  1.16s/it]


KeyboardInterrupt: 

In [6]:
# Evaluate on the test split. `load_ckpt_path` presumably makes pcam.test load these
# weights before evaluating — confirm; results are written next to the checkpoint as CSV.
# NOTE(review): the checkpoint name is hard-coded to a ResNet run and does not follow
# the architecture picked in the selection widget.
ckpt_path = os.path.join('models', 'ResNet_lr01_epoch5.pt')
test(model, test_loader, loss_fun, num_classes, device, load_ckpt_path=ckpt_path)

Testing: 100%|██████████| 1024/1024 [00:28<00:00, 35.58it/s]

GFLOPS: 0.6692, Test Loss: 0.8939, Test Acc: 0.7530, Test AUC: 0.8197
Saved results at: models/ResNet_lr01_epoch5_1.csv



