In [1]:
# comet_ml must be imported before torch so its automatic logging can hook into the framework.
from comet_ml import Experiment
from comet_ml.integration.pytorch import log_model

import os
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

# Make project-local packages (models/, pytorch_nn) importable from the parent directory.
sys.path.append('..')

In [2]:
# Training configuration.
SEED = 42
batch_size = 64
learning_rate = 0.001
best_valid_loss = np.inf

# Model weight initialisation (next cell) and DataLoader shuffling both draw from torch's
# global RNG, so seed it here to make runs reproducible.
torch.manual_seed(SEED)

# Images are only converted to tensors; no resizing or normalisation is applied.
transform = transforms.Compose([transforms.ToTensor()])

train_data = datasets.CelebA(root='data', split='train', download=True, transform=transform)
valid_data = datasets.CelebA(root='data', split='valid', download=True, transform=transform)
test_data = datasets.CelebA(root='data', split='test', download=True, transform=transform)

# Subset of CelebA attributes used as the multi-label targets.
features = ['Attractive', 'Eyeglasses', 'No_Beard', 'Male', 'Black_Hair', 'Blond_Hair', 'Mustache', 'Young', 'Smiling', 'Bald']
names_data = train_data.attr_names
idx = [names_data.index(name) for name in features]

# Keep only the selected attribute columns, in `features` order, for every split.
for split_data in (train_data, valid_data, test_data):
    split_data.attr = split_data.attr[:, idx]

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Shuffling the validation set only changes batch composition, which adds run-to-run noise to
# batch-averaged metrics; iterate it in a fixed order, like the test set.
valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
from models.MultilabelResnetClassifier import MultilabelResnetClassifier
from pytorch_nn import NNUtil

model = MultilabelResnetClassifier(n_classes=len(features))
# BCELoss expects probabilities in [0, 1], so this assumes the model ends with a sigmoid (TODO confirm).
# If the model returns raw logits, use BCEWithLogitsLoss instead; it is also more numerically stable.
loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
save_filename = 'resnet34_celeba10attr_10e_2.pt'

# Never hardcode credentials in a notebook. Set COMET_API_KEY in the environment; when it is
# unset (api_key=None), comet_ml falls back to its own configuration lookup.
# NOTE(review): the key previously committed here must be treated as leaked and rotated.
experiment = Experiment(
    api_key=os.environ.get('COMET_API_KEY'),
    project_name='pytorch-resnet34-classifier',
    workspace='wicwik',
    log_code=True,
)

hyper_params = {
    'learning_rate': learning_rate,
    'batch_size': batch_size,
    # len(DataLoader) is already the number of batches (optimizer steps) per epoch.
    # Dividing it by batch_size again under-reported the step count by that factor.
    'steps': len(train_dataloader),
    'loss': 'BCELoss',
    'optimizer': "Adam",
    'save_filename': save_filename,
}

experiment.set_name('resnet_drop02_fc10_2')
experiment.log_parameters(hyper_params)
experiment.set_model_graph(model, overwrite=False)

dataloaders = {'train': train_dataloader, 'valid': valid_dataloader, 'test': test_dataloader}
trainer = NNUtil(model=model, dataloaders=dataloaders, loss_fn=loss_fn, optimizer=optimizer, save_filename=save_filename, experiment=experiment)

trainer.run_classifier_training()

# log_model serialises the model when it is called. Calling it before training uploaded only the
# randomly initialised weights, so log it once training has finished.
log_model(experiment, model, model_name='MultilabelResnet34-CelebA-10attributes')

COMET INFO: Experiment is live on comet.com https://www.comet.com/wicwik/pytorch-resnet34-classifier/af5d78196ae74640b10b8b6e0e9df240



Epoch 1
-------------------------------
loss: 0.844471  [    0/162770] time: 1.3008537292480469 acc: 0.37187498807907104 precision: 0.34300729632377625 recall: 0.6830494403839111, f1_macro: 0.3440707325935364
loss: 0.206048  [ 6400/162770] time: 18.91797137260437 acc: 0.9203125238418579 precision: 0.8298611640930176 recall: 0.7749924063682556, f1_macro: 0.7920374274253845
loss: 0.203863  [12800/162770] time: 18.531570434570312 acc: 0.9046874642372131 precision: 0.7533761262893677 recall: 0.7091480493545532, f1_macro: 0.7005952000617981
loss: 0.206746  [19200/162770] time: 18.687729835510254 acc: 0.9140625596046448 precision: 0.850766658782959 recall: 0.8040753602981567, f1_macro: 0.821570634841919
loss: 0.196587  [25600/162770] time: 18.514316082000732 acc: 0.9234374165534973 precision: 0.6608806252479553 recall: 0.7095963954925537, f1_macro: 0.6732354760169983
loss: 0.177726  [32000/162770] time: 18.762848377227783 acc: 0.9218749403953552 precision: 0.7724224925041199 recall: 0.807484

In [4]:
# End the Comet experiment explicitly: in a notebook the kernel keeps running, so Comet needs this
# call to finish the run. Comet then prints the experiment summary shown below.
experiment.end()

COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.com/wicwik/pytorch-resnet34-classifier/af5d78196ae74640b10b8b6e0e9df240
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     avg_test_accuracy_per_epoch        : 0.9375636577606201
COMET INFO:     avg_test_f1_macro_per_epoch        : 0.8089115619659424
COMET INFO:     avg_test_loss_per_epoch            : 0.14630392611695406
COMET INFO:     avg_test_precision_per_epoch       : 0.8298677206039429
COMET INFO:     avg_test_recall_per_epoch          : 0.8086774349212646
COMET INFO:     avg_train_accuracy_per_epoch [10]  : (0.9281567931175232, 0.9859073758125305)
COMET INFO:     avg_train_f1_macro_per_epoch [10]  : (0.785646140575408