In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import alexnet

from lib.datasets import datasets
from lib.utils import train, alex_classifier, pac_label_to_string, save_model, load_model, list_models, compute_error_rate

# True when a CUDA device is usable; forwarded to train(...) via its `cuda` argument below.
cuda = torch.cuda.is_available()

# Execute the sibling HEX notebook in this session. `HEX` is used further down
# but is not imported above, so it presumably comes from this notebook — TODO confirm.
%run HEX.ipynb

# PACS

In [None]:
# Training schedule; both values are consumed by the train(...) call in the next cell.
epoch = 100
log_every = 10

# Build the PACS dataset wrapper configured for 'art_painting'.
# NOTE(review): whether 'art_painting' is the training or the held-out domain,
# and what `pacs_heuristic` toggles, is decided inside lib.datasets — confirm there.
ds = datasets()
ds.create_dataset('pacs', pacs='art_painting', pacs_heuristic=True)

# Loader handed to train(...); 256 is presumably the minibatch size — TODO confirm.
batch_loader = ds.batch_loader(256)

In [None]:
# Rebuild the AlexNet backbone with the project's classifier head and load the
# checkpoint stored under 'pacs_art_painting' into it.
# NOTE(review): load_model is assumed to restore weights in place — confirm in lib.utils.
alex = alexnet(pretrained=False)
alex.classifier = alex_classifier(8)
load_model(alex, 'pacs_art_painting')

# HEX model seeded with the AlexNet weights loaded above.
hexnet = HEX(dim=224, num_classes=8, alex_pretrained=False, alex_params=alex.state_dict())

# Freeze the convolutional feature extractor; the remaining parameters stay trainable.
for param in hexnet.cnn.features.parameters():
    param.requires_grad = False

criterion = nn.CrossEntropyLoss()

# Hand the optimizer only the parameters that are still trainable, so the
# frozen feature-extractor weights are not tracked by Adam at all.
trainable_params = [p for p in hexnet.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.001)

train(hexnet, batch_loader, optimizer, criterion, num_epochs=epoch, cuda=cuda, log_every=log_every)

Training the model!
Params to learn:
	 nglcm.0.a
	 nglcm.0.b
	 nglcm.1.weight
	 nglcm.1.bias
	 cnn.classifier.1.weight
	 cnn.classifier.1.bias
	 cnn.classifier.4.weight
	 cnn.classifier.4.bias
	 classifier.weight
	 classifier.bias
You can interrupt it at any time.
Minibatch     10  | loss  0.88 | err rate 28.00%
Minibatch     20  | loss  0.87 | err rate 25.00%
----------------------------------------------------------
After epoch  1 | valid err rate: 67.63% | doing 100 epochs
----------------------------------------------------------
Minibatch     30  | loss  0.77 | err rate 23.00%
Minibatch     40  | loss  0.84 | err rate 27.00%
Minibatch     50  | loss  1.33 | err rate 41.00%
----------------------------------------------------------
After epoch  2 | valid err rate: 67.63% | doing 100 epochs
----------------------------------------------------------
Minibatch     60  | loss  4.02 | err rate 55.00%
Minibatch     70  | loss  1.11 | err rate 39.00%
Minibatch     80  | loss  1.13 | err r

----------------------------------------------------------
After epoch 26 | valid err rate: 80.03% | doing 100 epochs
----------------------------------------------------------
Minibatch    730  | loss  1.59 | err rate 57.00%
Minibatch    740  | loss  1.65 | err rate 60.00%
Minibatch    750  | loss  1.62 | err rate 56.00%
----------------------------------------------------------
After epoch 27 | valid err rate: 76.32% | doing 100 epochs
----------------------------------------------------------
Minibatch    760  | loss  1.57 | err rate 50.00%
Minibatch    770  | loss  1.53 | err rate 50.00%
Minibatch    780  | loss  1.58 | err rate 55.00%
----------------------------------------------------------
After epoch 28 | valid err rate: 89.84% | doing 100 epochs
----------------------------------------------------------
Minibatch    790  | loss  1.77 | err rate 61.00%
Minibatch    800  | loss  1.71 | err rate 64.00%
Minibatch    810  | loss  1.83 | err rate 71.00%
----------------------------

In [None]:
# Persist the HEX model trained above. The original call passed `net`, a name
# never defined in this notebook; `hexnet` is the model that was trained.
save_model(hexnet, 'HEX_art_painting')