In [1]:
import sys
import torch

import numpy as np

from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

# Presumably makes the project-local `models` and `pytorch_nn` packages
# (imported in later cells) importable; guarded so re-running this cell does
# not keep appending '..' to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

batch_size = 64
# NOTE(review): not referenced by any visible cell — confirm it is still needed.
best_valid_loss = np.inf

# Seed the RNGs the run depends on (model weight init, DataLoader shuffling)
# so Restart & Run All reproduces the same run as far as the backend allows.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
# Only ToTensor is applied: no resizing or normalization of the CelebA images.
transform = transforms.Compose([transforms.ToTensor()])

train_data = datasets.CelebA(root='data', split='train', download=True, transform=transform)
valid_data = datasets.CelebA(root='data', split='valid', download=True, transform=transform)
test_data = datasets.CelebA(root='data', split='test', download=True, transform=transform)

# Subset of CelebA attributes used as the multi-label targets.
features = ['Attractive', 'Eyeglasses', 'No_Beard', 'Male', 'Black_Hair', 'Blond_Hair', 'Mustache', 'Young', 'Smiling', 'Bald']
names_data = train_data.attr_names
# list.index raises ValueError if a feature name is misspelled.
idx = [names_data.index(x) for x in features]

# Keep only the selected attribute columns in every split (all splits share the
# same attribute header, so the same column indices apply), and keep
# `attr_names` aligned with the sliced `attr` so name-based lookups stay correct.
for split_data in (train_data, valid_data, test_data):
    split_data.attr = split_data.attr[:, idx]
    split_data.attr_names = list(features)
    assert split_data.attr.shape[1] == len(features)

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Evaluation splits are not shuffled: a fixed order keeps validation runs comparable.
valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
from models.MultilabelResnetClassifier import MultilabelResnetClassifier

In [4]:
from pytorch_nn import NNUtil

# Multi-label classifier with one output per selected attribute, trained with
# binary cross-entropy. BCELoss expects probabilities in [0, 1], so this assumes
# the classifier already applies a sigmoid — TODO confirm in
# models/MultilabelResnetClassifier.
model = MultilabelResnetClassifier(n_classes=len(features))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.BCELoss()

dataloaders = dict(
    train=train_dataloader,
    valid=valid_dataloader,
    test=test_dataloader,
)
trainer = NNUtil(
    model=model,
    dataloaders=dataloaders,
    loss_fn=loss_fn,
    optimizer=optimizer,
)

In [5]:
# Run the training loop implemented by the project-local NNUtil. The logged
# output reports loss/acc/precision/recall every 6400 samples (100 batches at
# batch_size=64) and a validation summary per epoch. The number of epochs is
# presumably configured inside NNUtil — TODO confirm.
trainer.run_training()

Epoch 1
-------------------------------
loss: 0.798202  [    0/162770] time: 1.0570011138916016 acc: 0.4765625 precision: 0.3478260934352875 recall: 0.6296296119689941
loss: 0.186934  [ 6400/162770] time: 16.13055729866028 acc: 0.9234374761581421 precision: 0.9122806787490845 recall: 0.8776371479034424
loss: 0.163336  [12800/162770] time: 16.286580801010132 acc: 0.934374988079071 precision: 0.931034505367279 recall: 0.8925619721412659
loss: 0.163681  [19200/162770] time: 16.218785047531128 acc: 0.934374988079071 precision: 0.9166666865348816 recall: 0.9090909361839294
loss: 0.178405  [25600/162770] time: 16.258424758911133 acc: 0.9234374761581421 precision: 0.8986175060272217 recall: 0.8783783912658691
loss: 0.177946  [32000/162770] time: 16.301422119140625 acc: 0.917187511920929 precision: 0.9035087823867798 recall: 0.8691983222961426
loss: 0.162866  [38400/162770] time: 16.35820770263672 acc: 0.9312499761581421 precision: 0.8771186470985413 recall: 0.9324324131011963
loss: 0.166863  

loss: 0.109885  [70400/162770] time: 16.285614728927612 acc: 0.9546874761581421 precision: 0.9636363387107849 recall: 0.9098712205886841
loss: 0.100971  [76800/162770] time: 16.304107904434204 acc: 0.949999988079071 precision: 0.9601770043373108 recall: 0.9041666388511658
loss: 0.138530  [83200/162770] time: 16.24916410446167 acc: 0.953125 precision: 0.9070796370506287 recall: 0.9579439163208008
loss: 0.133780  [89600/162770] time: 16.294527530670166 acc: 0.9390624761581421 precision: 0.8893616795539856 recall: 0.9414414167404175
loss: 0.131332  [96000/162770] time: 16.371323585510254 acc: 0.932812511920929 precision: 0.890350878238678 recall: 0.918552041053772
loss: 0.118314  [102400/162770] time: 16.292105436325073 acc: 0.9546874761581421 precision: 0.9603524208068848 recall: 0.9159663915634155
loss: 0.131064  [108800/162770] time: 16.250661373138428 acc: 0.940625011920929 precision: 0.9200000166893005 recall: 0.9118942618370056
loss: 0.119914  [115200/162770] time: 16.33506011962890

loss: 0.142113  [121600/162770] time: 16.22463369369507 acc: 0.9312499761581421 precision: 0.9066666960716248 recall: 0.8986784219741821
loss: 0.097096  [128000/162770] time: 16.296155214309692 acc: 0.957812488079071 precision: 0.9099099040031433 recall: 0.9665071964263916
loss: 0.126117  [134400/162770] time: 16.2146315574646 acc: 0.9390624761581421 precision: 0.9354838728904724 recall: 0.890350878238678
loss: 0.137261  [140800/162770] time: 16.25610065460205 acc: 0.9296875 precision: 0.90625 recall: 0.8942731022834778
loss: 0.116505  [147200/162770] time: 16.192627429962158 acc: 0.932812511920929 precision: 0.8918918967247009 recall: 0.9124423861503601
loss: 0.095605  [153600/162770] time: 16.205593824386597 acc: 0.9609375 precision: 0.9375 recall: 0.9502262473106384
loss: 0.156612  [160000/162770] time: 16.24087619781494 acc: 0.934374988079071 precision: 0.8878923654556274 recall: 0.9209302067756653
Valid | Error: 
 Accuracy: 0.932275, Precision: 0.920943, Recall: 0.882534, Avg loss

Valid | Error: 
 Accuracy: 0.934234, Precision: 0.910748, Recall: 0.900396, Avg loss: 0.178436 

Epoch 8
-------------------------------
loss: 0.063598  [    0/162770] time: 0.17516446113586426 acc: 0.971875011920929 precision: 0.9819819927215576 recall: 0.9396551847457886
loss: 0.042954  [ 6400/162770] time: 16.49529480934143 acc: 0.9781249761581421 precision: 0.9647576808929443 recall: 0.9733333587646484
loss: 0.039503  [12800/162770] time: 16.35405158996582 acc: 0.984375 precision: 0.9779735803604126 recall: 0.9779735803604126
loss: 0.052340  [19200/162770] time: 16.2877140045166 acc: 0.981249988079071 precision: 0.9700854420661926 recall: 0.9784482717514038
loss: 0.048510  [25600/162770] time: 16.280375242233276 acc: 0.981249988079071 precision: 0.9830508232116699 recall: 0.9666666388511658
loss: 0.047752  [32000/162770] time: 16.29220700263977 acc: 0.9828125238418579 precision: 0.9732142686843872 recall: 0.9775784611701965
loss: 0.060462  [38400/162770] time: 16.461535215377808 ac

loss: 0.028769  [44800/162770] time: 16.40094304084778 acc: 0.987500011920929 precision: 0.9872881174087524 recall: 0.9789915680885315
loss: 0.047316  [51200/162770] time: 16.318308115005493 acc: 0.984375 precision: 0.9736841917037964 recall: 0.982300877571106
loss: 0.047358  [57600/162770] time: 16.224337577819824 acc: 0.979687511920929 precision: 0.9656652212142944 recall: 0.97826087474823
loss: 0.027161  [64000/162770] time: 16.14490818977356 acc: 0.9937499761581421 precision: 0.9954751133918762 recall: 0.9865471124649048
loss: 0.035526  [70400/162770] time: 16.355412483215332 acc: 0.984375 precision: 0.9746835231781006 recall: 0.9829787015914917
loss: 0.048794  [76800/162770] time: 16.2732093334198 acc: 0.9828125238418579 precision: 0.960869550704956 recall: 0.9910314083099365
loss: 0.035678  [83200/162770] time: 16.362791299819946 acc: 0.9859374761581421 precision: 0.9831932783126831 recall: 0.9790794849395752
loss: 0.037051  [89600/162770] time: 16.373429775238037 acc: 0.98281252

In [11]:
# Persist the trained weights (state_dict only, not the full module).
# NOTE(review): reads NNUtil's name-mangled private attribute; `model` from the
# setup cell is presumably the same object — confirm, or add a public accessor
# on NNUtil.
trained_model = trainer._NNUtil__model
checkpoint_path = "resnet34_celeba10attr_10e.pt"
torch.save(trained_model.state_dict(), checkpoint_path)

In [10]:
# Display the trained model's module tree.
# NOTE(review): relies on NNUtil's name-mangled private attribute `__model`.
trainer._NNUtil__model

MultilabelResnetClassifier(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affi