In [1]:
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchmetrics.classification import BinaryAccuracy

import numpy as np
import torch
import h5py
import time
import sys
# Path setup must run before the local import below.
sys.path.append('..')            # parent directory, where the `models` package is imported from
sys.path.append('../stylegan3')  # StyleGAN3 checkout; presumably needed by LatentFeatureExtractor's dependencies — TODO confirm

from models.LatentFeatureExtractor import LatentFeatureExtractor

# Print tensors in fixed-point rather than scientific notation.
torch.set_printoptions(sci_mode=False)

In [2]:
# Seed torch's global RNG; the random labels and random_split below both draw from it.
torch.manual_seed(0)
data_path = '/home/robert/data/diploma-thesis/datasets/stylegan3/tpsi_1/latents/sample_z.h5'
n_classes = 10

# Read the entire `z` dataset from the HDF5 file into an in-memory NumPy array.
with h5py.File(data_path, 'r') as h5_file:
    data = h5_file['z'][:]

# float32 view/copy of the latents (same semantics as the legacy torch.Tensor(ndarray) constructor).
latents = torch.as_tensor(data, dtype=torch.float32)
# TODO(review): labels are uniformly random 0/1 placeholders (one row of n_classes
# per latent), so there is no real signal to learn yet — swap in the actual labels.
labels = torch.randint(0, 2, (len(data), n_classes)).float()

dataset = TensorDataset(latents, labels)
# 80% / 10% / 10% train / validation / test split.
train_data, valid_data, test_data = random_split(dataset, [0.8, 0.1, 0.1])

In [3]:
# Shape of the loaded latent array; each row is paired with one label row in the TensorDataset.
data.shape

(256000, 512)

In [4]:
batch_size = 8
num_epochs = 10
learning_rate = 1e-4

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Evaluation order does not affect the metrics, so validation/test data are not shuffled.
# A shuffled loader draws a fresh seed from the global torch RNG on every iteration,
# which would make validation non-deterministic and shift the training shuffles.
valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

accuracy = BinaryAccuracy()

model = LatentFeatureExtractor(n_classes=n_classes).cuda()
# NOTE(review): targets are 0/1 (BinaryAccuracy is used); BCE is the usual loss for
# that — MSE is kept here as in the original experiment. Confirm this is intended.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Checkpoint path; overwritten whenever the validation loss improves.
save_filename = 'latent_feature_extractor_a.pt'

# Per-epoch histories, one list per split.
loss_hist = {split: [] for split in ('train', 'valid', 'test')}
accuracy_hist = {split: [] for split in ('train', 'valid', 'test')}
best_valid_loss = np.inf

In [5]:
def train_one_epoch(model, loader, criterion, optimizer, metric, log_every=10):
    """Run one optimisation pass over `loader`.

    Args:
        model: module called as ``model(x, y)``; must already live on the GPU.
        loader: DataLoader yielding ``(x, y)`` batches.
        criterion: loss function ``criterion(preds, y)``.
        optimizer: optimizer over ``model``'s parameters.
        metric: callable ``metric(preds, y)`` returning a 0-dim tensor (e.g. BinaryAccuracy).
        log_every: print batch loss/accuracy and elapsed time every this many batches.

    Returns:
        (mean batch loss, mean batch accuracy) as Python floats.
    """
    model.train()
    num_batches = len(loader)
    num_samples = len(loader.dataset)
    total_loss, total_acc = 0.0, 0.0
    start = time.time()
    for batch, (x, y) in enumerate(loader):
        x, y = x.cuda(), y.cuda()

        # NOTE(review): the labels are fed to the model as a second input —
        # confirm this matches LatentFeatureExtractor.forward.
        preds = model(x, y)
        loss = criterion(preds, y)
        acc = metric(preds.detach().cpu(), y.cpu()).item()

        # Zero first so stale gradients from an interrupted run never leak into a step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        total_acc += acc

        if batch % log_every == 0:
            end = time.time()
            current = batch * len(x)
            print(f'loss: {loss_value:>7f} acc: {acc} [{current:>5d}/{num_samples:>5d}] time: {end-start}')
            start = time.time()

    return total_loss / num_batches, total_acc / num_batches


@torch.no_grad()
def evaluate(model, loader, criterion, metric):
    """Return (mean batch loss, mean batch accuracy) of `model` over `loader` without updating weights."""
    model.eval()
    total_loss, total_acc = 0.0, 0.0
    for x, y in loader:
        x, y = x.cuda(), y.cuda()
        preds = model(x, y)
        total_loss += criterion(preds, y).item()
        total_acc += metric(preds.cpu(), y.cpu()).item()
    num_batches = len(loader)
    return total_loss / num_batches, total_acc / num_batches


train_loss, train_acc = 0, 0
valid_loss, valid_acc = 0, 0
test_loss, test_acc = 0, 0  # not used in this cell; presumably for a later test-set pass

for epoch in range(num_epochs):
    print(f'Epoch {epoch}:')

    # The helpers keep their own accumulators, so each epoch's averages start from zero.
    train_loss, train_acc = train_one_epoch(model, train_dataloader, criterion, optimizer, accuracy)
    loss_hist['train'].append(train_loss)
    accuracy_hist['train'].append(train_acc)

    valid_loss, valid_acc = evaluate(model, valid_dataloader, criterion, accuracy)
    loss_hist['valid'].append(valid_loss)
    accuracy_hist['valid'].append(valid_acc)

    # Keep only the checkpoint with the lowest validation loss seen so far.
    if valid_loss < best_valid_loss:
        torch.save(model.state_dict(), save_filename)
        best_valid_loss = valid_loss

    print(f'epoch [{epoch + 1}/{num_epochs}], loss:{train_loss}, valid_loss:{valid_loss}, acc:{train_acc}, valid_acc:{valid_acc}')

Epoch 0:
Setting up PyTorch plugin "bias_act_plugin"... Done.
Setting up PyTorch plugin "filtered_lrelu_plugin"... Done.
loss: 0.416806 acc: 0.5874999761581421 [    0/204800] time: 1.2996187210083008
loss: 0.536213 acc: 0.42500001192092896 [   80/204800] time: 7.3612377643585205
loss: 0.445213 acc: 0.5249999761581421 [  160/204800] time: 7.323360204696655
loss: 0.407667 acc: 0.5874999761581421 [  240/204800] time: 7.518978118896484
loss: 0.472019 acc: 0.512499988079071 [  320/204800] time: 7.614176273345947
loss: 0.485736 acc: 0.48750001192092896 [  400/204800] time: 7.525393486022949
loss: 0.469791 acc: 0.512499988079071 [  480/204800] time: 7.479475975036621
loss: 0.505205 acc: 0.4375 [  560/204800] time: 7.390590667724609
loss: 0.434583 acc: 0.5249999761581421 [  640/204800] time: 7.3381571769714355
loss: 0.443591 acc: 0.48750001192092896 [  720/204800] time: 7.45112681388855
loss: 0.512557 acc: 0.4625000059604645 [  800/204800] time: 7.633774757385254
loss: 0.582507 acc: 0.38749998

KeyboardInterrupt: 