In [2]:
import os
import sys
import time

import h5py
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchmetrics.classification import BinaryAccuracy
# Make the parent directory (home of the `models` package) and the
# ../stylegan3 checkout importable.
# NOTE(review): both paths are relative to the notebook's working directory;
# presumably ../stylegan3 is needed by the model's StyleGAN3 dependencies — confirm.
sys.path.extend(['..', '../stylegan3'])

from models.LatentFeatureExtractor import LatentFeatureExtractor

# Show tensor values in fixed-point instead of scientific notation.
torch.set_printoptions(sci_mode=False)

In [3]:
torch.manual_seed(0)
# Location of the latent-vector HDF5 file. Set the LATENT_DATA_PATH environment
# variable to point elsewhere instead of editing the notebook.
data_path = os.environ.get(
    'LATENT_DATA_PATH',
    '/home/robert/data/diploma-thesis/datasets/stylegan3/tpsi_1/latents/sample_z.h5',
)
n_classes = 10
# Index of the single positive label in the synthetic targets built below.
target_class = 1

# Read the whole 'z' dataset into memory (the next cell displays its shape).
with h5py.File(data_path, 'r') as f:
    data = f['z'][:]

latents = torch.as_tensor(data, dtype=torch.float32)

# Sanity-check targets: every sample gets the same one-hot vector with
# `target_class` set. The width comes from `n_classes` (also passed to the
# model), so the two can never disagree.
# Alternative experiment (random multi-hot labels):
#   torch.randint(0, 2, (len(data), n_classes)).to(torch.float32)
targets = torch.zeros(len(data), n_classes, dtype=torch.float32)
targets[:, target_class] = 1.0

dataset = TensorDataset(latents, targets)
# 80/10/10 split. It is reproducible because the global seed is set at the top
# of this cell and nothing in between draws random numbers.
train_data, valid_data, test_data = random_split(dataset, [0.8, 0.1, 0.1])

In [4]:
# Sanity check: (rows fed to the TensorDataset, values per row).
np.shape(data)

(256000, 512)

In [5]:
# ---- hyperparameters ----
batch_size = 8
num_epochs = 10
# NOTE(review): 1e-7 is an unusually small Adam step size; the training loss
# logged below stays roughly flat (~0.32-0.45) over thousands of steps.
# Consider a larger value if the model is expected to learn.
learning_rate = 1e-7

# ---- data loaders (only the test split keeps a fixed order) ----
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Element-wise binary accuracy. Created without `.cuda()`, so callers pass
# CPU tensors to it.
accuracy = BinaryAccuracy()

# ---- model / objective / optimizer ----
model = LatentFeatureExtractor(n_classes=n_classes).cuda()
# NOTE(review): MSE against 0/1 targets. If the model emits logits,
# nn.BCEWithLogitsLoss may fit better — confirm the model's output activation.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Checkpoint written whenever the validation loss improves.
save_filename = 'latent_feature_extractor_a.pt'

# Per-epoch metric histories, keyed by split name.
loss_hist = {split: [] for split in ('train', 'valid', 'test')}
accuracy_hist = {split: [] for split in ('train', 'valid', 'test')}
best_valid_loss = np.inf

In [6]:
# Train for `num_epochs`, evaluate on the validation split after every epoch,
# and checkpoint the model whenever the mean validation loss improves.
#
# Uses names from earlier cells: model, criterion, optimizer, accuracy,
# train_dataloader, valid_dataloader, num_epochs, save_filename, loss_hist,
# accuracy_hist, best_valid_loss.

# Module-level so later cells can read the final epoch's values.
train_loss, train_acc = 0, 0
valid_loss, valid_acc = 0, 0
test_loss, test_acc = 0, 0

# Total number of training samples; only used in the progress display.
n_train_samples = len(train_dataloader.dataset)

for epoch in range(num_epochs):
    print(f'Epoch {epoch}:')

    # ---- training ----
    # Reset accumulators every epoch so each mean covers only this epoch's batches.
    train_loss, train_acc = 0.0, 0.0
    num_batches = len(train_dataloader)
    model.train()
    start = time.time()
    # Unpack batches directly rather than binding them to `data`, which would
    # overwrite the raw latent array loaded earlier in the notebook.
    for batch, (x, y) in enumerate(train_dataloader):
        x, y = x.cuda(), y.cuda()

        # NOTE(review): targets are passed into the model's forward — confirm
        # this cannot leak labels into the predictions.
        preds = model(x, y)

        loss = criterion(preds, y)
        # `accuracy` was created without `.cuda()`, so compare on the CPU;
        # keep a plain float for the running sum and history.
        acc = accuracy(preds.cpu(), y.cpu()).item()
        loss_value = loss.item()

        train_loss += loss_value
        train_acc += acc

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 10 == 0:
            end = time.time()
            current = batch * len(x)
            print(f'loss: {loss_value:>7f} acc: {acc} [{current:>5d}/{n_train_samples:>5d}] time: {end-start}')
            start = time.time()

    train_loss /= num_batches
    train_acc /= num_batches
    loss_hist['train'].append(train_loss)
    accuracy_hist['train'].append(train_acc)

    # ---- validation ----
    valid_loss, valid_acc = 0.0, 0.0
    num_batches = len(valid_dataloader)
    model.eval()
    with torch.no_grad():
        for x, y in valid_dataloader:
            x, y = x.cuda(), y.cuda()

            preds = model(x, y)

            valid_loss += criterion(preds, y).item()
            valid_acc += accuracy(preds.cpu(), y.cpu()).item()

    valid_loss /= num_batches
    valid_acc /= num_batches
    loss_hist['valid'].append(valid_loss)
    accuracy_hist['valid'].append(valid_acc)

    # Keep only the checkpoint with the lowest mean validation loss so far.
    if valid_loss < best_valid_loss:
        torch.save(model.state_dict(), save_filename)
        best_valid_loss = valid_loss

    print(f'epoch [{epoch + 1}/{num_epochs}], loss:{train_loss}, valid_loss:{valid_loss}, acc:{train_acc}, valid_acc:{valid_acc}')

Epoch 0:
Setting up PyTorch plugin "bias_act_plugin"... Done.
Setting up PyTorch plugin "filtered_lrelu_plugin"... Done.
loss: 0.373710 acc: 0.6000000238418579 [    0/204800] time: 1.4310145378112793
loss: 0.362383 acc: 0.625 [   80/204800] time: 6.316206693649292
loss: 0.385343 acc: 0.6000000238418579 [  160/204800] time: 6.20276665687561
loss: 0.393765 acc: 0.5874999761581421 [  240/204800] time: 6.262559413909912
loss: 0.364748 acc: 0.612500011920929 [  320/204800] time: 6.2930192947387695
loss: 0.348779 acc: 0.637499988079071 [  400/204800] time: 6.307431936264038
loss: 0.334367 acc: 0.612500011920929 [  480/204800] time: 6.281689882278442
loss: 0.397678 acc: 0.574999988079071 [  560/204800] time: 6.272778034210205
loss: 0.334579 acc: 0.637499988079071 [  640/204800] time: 6.329864025115967
loss: 0.365992 acc: 0.637499988079071 [  720/204800] time: 6.3133704662323
loss: 0.364078 acc: 0.6000000238418579 [  800/204800] time: 6.3413615226745605
loss: 0.388613 acc: 0.6000000238418579 [

loss: 0.380325 acc: 0.5874999761581421 [ 8800/204800] time: 6.452470302581787
loss: 0.359598 acc: 0.625 [ 8880/204800] time: 6.474533557891846
loss: 0.366207 acc: 0.625 [ 8960/204800] time: 6.433427333831787
loss: 0.361845 acc: 0.625 [ 9040/204800] time: 6.4745988845825195
loss: 0.382591 acc: 0.6000000238418579 [ 9120/204800] time: 6.545551538467407
loss: 0.376134 acc: 0.612500011920929 [ 9200/204800] time: 6.512231111526489
loss: 0.390469 acc: 0.574999988079071 [ 9280/204800] time: 6.532720565795898
loss: 0.367139 acc: 0.612500011920929 [ 9360/204800] time: 6.528401613235474
loss: 0.363170 acc: 0.612500011920929 [ 9440/204800] time: 6.496868371963501
loss: 0.366393 acc: 0.6000000238418579 [ 9520/204800] time: 6.463238716125488
loss: 0.364799 acc: 0.612500011920929 [ 9600/204800] time: 6.504942417144775
loss: 0.358267 acc: 0.625 [ 9680/204800] time: 6.424196243286133
loss: 0.381752 acc: 0.6000000238418579 [ 9760/204800] time: 6.4021055698394775
loss: 0.395698 acc: 0.574999988079071 [ 9

loss: 0.388109 acc: 0.612500011920929 [17600/204800] time: 6.419998645782471
loss: 0.367845 acc: 0.612500011920929 [17680/204800] time: 6.369235277175903
loss: 0.354056 acc: 0.637499988079071 [17760/204800] time: 6.427839517593384
loss: 0.377225 acc: 0.5625 [17840/204800] time: 6.404646635055542
loss: 0.388173 acc: 0.574999988079071 [17920/204800] time: 6.4379777908325195
loss: 0.389613 acc: 0.6000000238418579 [18000/204800] time: 6.411391496658325
loss: 0.395140 acc: 0.5874999761581421 [18080/204800] time: 6.370588779449463
loss: 0.445952 acc: 0.5249999761581421 [18160/204800] time: 6.384347200393677
loss: 0.321827 acc: 0.6625000238418579 [18240/204800] time: 6.445890188217163
loss: 0.407723 acc: 0.5625 [18320/204800] time: 6.4656572341918945
loss: 0.369491 acc: 0.5874999761581421 [18400/204800] time: 6.380059719085693
loss: 0.385759 acc: 0.5625 [18480/204800] time: 6.439954042434692
loss: 0.384557 acc: 0.6000000238418579 [18560/204800] time: 6.525731325149536
loss: 0.405140 acc: 0.57

loss: 0.377968 acc: 0.612500011920929 [26240/204800] time: 6.436474561691284
loss: 0.374135 acc: 0.6000000238418579 [26320/204800] time: 6.4141364097595215
loss: 0.359842 acc: 0.612500011920929 [26400/204800] time: 6.382351398468018
loss: 0.325520 acc: 0.6499999761581421 [26480/204800] time: 7.095299243927002
loss: 0.393953 acc: 0.5874999761581421 [26560/204800] time: 7.003962755203247
loss: 0.367993 acc: 0.612500011920929 [26640/204800] time: 6.453750133514404
loss: 0.380941 acc: 0.6000000238418579 [26720/204800] time: 6.4413087368011475
loss: 0.375107 acc: 0.5874999761581421 [26800/204800] time: 6.46342134475708
loss: 0.368962 acc: 0.612500011920929 [26880/204800] time: 6.474572420120239
loss: 0.383902 acc: 0.5874999761581421 [26960/204800] time: 6.420130252838135
loss: 0.405285 acc: 0.5874999761581421 [27040/204800] time: 6.392674684524536
loss: 0.357861 acc: 0.6000000238418579 [27120/204800] time: 6.395679235458374
loss: 0.346248 acc: 0.625 [27200/204800] time: 6.390385389328003
lo

loss: 0.373603 acc: 0.6000000238418579 [34960/204800] time: 6.426630973815918
loss: 0.389397 acc: 0.5625 [35040/204800] time: 6.45022988319397
loss: 0.385817 acc: 0.6000000238418579 [35120/204800] time: 6.407751798629761
loss: 0.393480 acc: 0.5625 [35200/204800] time: 6.383366823196411
loss: 0.379115 acc: 0.612500011920929 [35280/204800] time: 6.81754469871521
loss: 0.348228 acc: 0.625 [35360/204800] time: 6.493281126022339
loss: 0.377146 acc: 0.5874999761581421 [35440/204800] time: 6.410213470458984
loss: 0.369965 acc: 0.612500011920929 [35520/204800] time: 6.4350714683532715
loss: 0.400963 acc: 0.574999988079071 [35600/204800] time: 6.4189231395721436
loss: 0.365239 acc: 0.612500011920929 [35680/204800] time: 6.400099515914917
loss: 0.401655 acc: 0.5874999761581421 [35760/204800] time: 6.457127094268799
loss: 0.315174 acc: 0.675000011920929 [35840/204800] time: 6.418212413787842
loss: 0.354641 acc: 0.625 [35920/204800] time: 6.393752098083496
loss: 0.397143 acc: 0.574999988079071 [36

loss: 0.395568 acc: 0.6000000238418579 [43680/204800] time: 6.396007537841797
loss: 0.412084 acc: 0.574999988079071 [43760/204800] time: 6.362169981002808
loss: 0.406732 acc: 0.550000011920929 [43840/204800] time: 6.428506135940552
loss: 0.374915 acc: 0.612500011920929 [43920/204800] time: 6.60883903503418
loss: 0.350791 acc: 0.637499988079071 [44000/204800] time: 6.442534446716309
loss: 0.346409 acc: 0.625 [44080/204800] time: 6.394577741622925
loss: 0.387921 acc: 0.6000000238418579 [44160/204800] time: 6.442714691162109
loss: 0.384256 acc: 0.574999988079071 [44240/204800] time: 6.397690057754517
loss: 0.392847 acc: 0.574999988079071 [44320/204800] time: 6.409817457199097
loss: 0.384289 acc: 0.5874999761581421 [44400/204800] time: 6.414618492126465
loss: 0.382803 acc: 0.6000000238418579 [44480/204800] time: 6.417996644973755
loss: 0.376664 acc: 0.6000000238418579 [44560/204800] time: 6.464407205581665
loss: 0.428954 acc: 0.5375000238418579 [44640/204800] time: 6.70520544052124
loss: 0

loss: 0.345688 acc: 0.625 [52400/204800] time: 6.401997089385986
loss: 0.360406 acc: 0.612500011920929 [52480/204800] time: 6.430078744888306
loss: 0.356724 acc: 0.625 [52560/204800] time: 6.509523391723633
loss: 0.386305 acc: 0.5625 [52640/204800] time: 6.470641851425171
loss: 0.415485 acc: 0.5625 [52720/204800] time: 6.409593105316162
loss: 0.391613 acc: 0.6000000238418579 [52800/204800] time: 6.389594316482544
loss: 0.327883 acc: 0.6499999761581421 [52880/204800] time: 6.370729923248291
loss: 0.343443 acc: 0.625 [52960/204800] time: 6.3907952308654785
loss: 0.364593 acc: 0.612500011920929 [53040/204800] time: 6.4390175342559814
loss: 0.357446 acc: 0.6000000238418579 [53120/204800] time: 6.479965925216675
loss: 0.366297 acc: 0.612500011920929 [53200/204800] time: 6.381576776504517
loss: 0.353267 acc: 0.625 [53280/204800] time: 6.473893404006958
loss: 0.369441 acc: 0.6000000238418579 [53360/204800] time: 6.416494607925415
loss: 0.377102 acc: 0.6000000238418579 [53440/204800] time: 6.4

loss: 0.389106 acc: 0.5874999761581421 [61200/204800] time: 6.590834379196167
loss: 0.425428 acc: 0.550000011920929 [61280/204800] time: 6.477105617523193
loss: 0.440298 acc: 0.512499988079071 [61360/204800] time: 6.766354084014893
loss: 0.357776 acc: 0.625 [61440/204800] time: 6.732240915298462
loss: 0.369133 acc: 0.612500011920929 [61520/204800] time: 6.618251085281372
loss: 0.347333 acc: 0.625 [61600/204800] time: 6.621847867965698
loss: 0.397961 acc: 0.5874999761581421 [61680/204800] time: 7.204516887664795
loss: 0.364256 acc: 0.625 [61760/204800] time: 6.735236167907715
loss: 0.350227 acc: 0.625 [61840/204800] time: 6.51774525642395
loss: 0.363343 acc: 0.612500011920929 [61920/204800] time: 6.584428787231445
loss: 0.378671 acc: 0.5874999761581421 [62000/204800] time: 6.440995931625366
loss: 0.354253 acc: 0.625 [62080/204800] time: 6.41595721244812
loss: 0.376148 acc: 0.6000000238418579 [62160/204800] time: 6.511575222015381
loss: 0.339864 acc: 0.637499988079071 [62240/204800] time

loss: 0.361838 acc: 0.612500011920929 [70000/204800] time: 6.5042054653167725
loss: 0.368026 acc: 0.6000000238418579 [70080/204800] time: 6.737061977386475
loss: 0.418165 acc: 0.5625 [70160/204800] time: 6.400595664978027
loss: 0.341954 acc: 0.6499999761581421 [70240/204800] time: 6.418655633926392
loss: 0.316669 acc: 0.6875 [70320/204800] time: 6.38616943359375
loss: 0.368537 acc: 0.5874999761581421 [70400/204800] time: 6.409510135650635
loss: 0.323232 acc: 0.6625000238418579 [70480/204800] time: 6.4643330574035645
loss: 0.362318 acc: 0.625 [70560/204800] time: 6.415278673171997
loss: 0.395850 acc: 0.5874999761581421 [70640/204800] time: 6.4275267124176025
loss: 0.366122 acc: 0.6000000238418579 [70720/204800] time: 6.435744285583496
loss: 0.393473 acc: 0.5625 [70800/204800] time: 6.460028171539307
loss: 0.418699 acc: 0.5625 [70880/204800] time: 6.556744575500488
loss: 0.369612 acc: 0.6000000238418579 [70960/204800] time: 6.400927305221558
loss: 0.391192 acc: 0.5625 [71040/204800] time

loss: 0.362763 acc: 0.625 [78800/204800] time: 6.55649995803833
loss: 0.371837 acc: 0.6000000238418579 [78880/204800] time: 6.573601722717285
loss: 0.335959 acc: 0.625 [78960/204800] time: 6.5149452686309814
loss: 0.358293 acc: 0.637499988079071 [79040/204800] time: 7.064599514007568
loss: 0.363927 acc: 0.637499988079071 [79120/204800] time: 7.6399760246276855
loss: 0.369534 acc: 0.612500011920929 [79200/204800] time: 7.815024137496948
loss: 0.347881 acc: 0.637499988079071 [79280/204800] time: 7.56390118598938
loss: 0.394091 acc: 0.5874999761581421 [79360/204800] time: 6.713524103164673
loss: 0.369454 acc: 0.5874999761581421 [79440/204800] time: 6.52350926399231
loss: 0.368370 acc: 0.612500011920929 [79520/204800] time: 7.885639429092407
loss: 0.388032 acc: 0.5874999761581421 [79600/204800] time: 6.852508783340454
loss: 0.389381 acc: 0.6000000238418579 [79680/204800] time: 6.400658369064331
loss: 0.400615 acc: 0.5625 [79760/204800] time: 6.4244372844696045
loss: 0.365947 acc: 0.6125000

loss: 0.376499 acc: 0.6000000238418579 [87520/204800] time: 6.436397552490234
loss: 0.341503 acc: 0.637499988079071 [87600/204800] time: 6.395236492156982
loss: 0.410533 acc: 0.5625 [87680/204800] time: 6.398730278015137
loss: 0.332996 acc: 0.637499988079071 [87760/204800] time: 6.420785903930664
loss: 0.420976 acc: 0.5375000238418579 [87840/204800] time: 6.444481611251831
loss: 0.376205 acc: 0.5874999761581421 [87920/204800] time: 6.404249668121338
loss: 0.400826 acc: 0.5874999761581421 [88000/204800] time: 6.434116840362549
loss: 0.362686 acc: 0.612500011920929 [88080/204800] time: 6.4402172565460205
loss: 0.378445 acc: 0.612500011920929 [88160/204800] time: 6.445719480514526
loss: 0.415460 acc: 0.550000011920929 [88240/204800] time: 6.4043354988098145
loss: 0.354841 acc: 0.625 [88320/204800] time: 6.515537261962891
loss: 0.343937 acc: 0.625 [88400/204800] time: 6.410535573959351
loss: 0.395173 acc: 0.5874999761581421 [88480/204800] time: 6.494720697402954
loss: 0.355759 acc: 0.625 [

KeyboardInterrupt: 