In [1]:
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split

import numpy as np
import torch
import h5py
import time
import sys
sys.path.append('..')
sys.path.append('../stylegan3')

from utils.L2FPipeline import L2FPipeline
from stylegan_generator import StyleGANGenerator
from models.MultilabelResnetClassifier import MultilabelResnetClassifier
from models.LatentFeatureExtractor import LatentFeatureExtractor

In [2]:
torch.manual_seed(0)
data_path = '/home/robert/data/diploma-thesis/datasets/stylegan3/tpsi_1/latents/sample_z.h5'
n_classes = 10

data = None
with h5py.File(data_path, 'r') as f:
    data = f['z'][:]

dataset = TensorDataset(torch.Tensor(data),torch.randint(0,2,(len(data), n_classes)).to(torch.float32))
train_data, valid_data, test_data = random_split(dataset, [0.8, 0.1, 0.1])

In [3]:
data.shape

(256000, 512)

In [4]:
network_pkl = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-256x256.pkl'
classifier_weights = '/home/robert/data/diploma-thesis/weights/classfier/resnet34_celeba10attr_10e.pt'

generator = StyleGANGenerator(network_pkl)
classifier = MultilabelResnetClassifier(n_classes=10)
classifier.load_state_dict(torch.load(classifier_weights))

pipeline = L2FPipeline(generator = generator, classifier = classifier, tpsi=0.7)

In [5]:
batch_size = 16
num_epochs = 10
learning_rate = 1e-4
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

model = LatentFeatureExtractor(n_classes=n_classes).cuda()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate)

save_filename = 'latent_feature_extractor_a.pt'

loss = {}

loss['train'] = []
loss['valid'] = []
loss['test'] = []
best_valid_loss = np.inf

In [6]:
for epoch in range(num_epochs):
    model.train()
    start = time.time()
    for batch, data in enumerate(train_dataloader):
        x, y = data
        x, y = x.cuda(), y.cuda()
               
        z = model(x, y)
        preds, _ = pipeline.transform(z.cpu())
        
        train_loss = criterion(preds.cuda(), y)
       
        train_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        if batch % 100 == 0:
            end = time.time()
            loss, current = train_loss.item(), batch * len(x)
            print(f'loss: {loss:>7f}  [{current:>5d}/{batch_size*len(x):>5d}] time: {end-start}')
            start = time.time()

    loss['train'].append(train_loss.data.item())
    
    model.eval()
    with torch.no_grad():
        for data in valid_dataloader:
            x, y = data
            x, y = x.cuda(), y.cuda()
            
            z = model(x, y)
            preds, _ = pipeline.transform(z)
            
            val_loss = criterion(preds.cuda(), y)
            
    loss['valid'].append(val_loss.data.item())
    
    if loss['valid'][-1] < best_valid_loss:
        torch.save(model.state_dict(), save_filename)
        best_valid_loss = loss['valid'][-1]
   
    print(f'epoch [{epoch + 1}/{num_epochs}], loss:{loss["train"][-1]}, valid_loss:{loss["valid"][-1]}')

tensor([[ 0.0000,  1.0000,  0.0000,  ..., -1.0005, -0.2181,  1.5756],
        [ 1.0000,  1.0000,  0.0000,  ..., -0.5479, -0.9351, -0.2589],
        [ 1.0000,  0.0000,  1.0000,  ...,  0.5089,  1.1540, -0.5031],
        ...,
        [ 1.0000,  1.0000,  1.0000,  ..., -0.5668,  1.1148,  0.9745],
        [ 1.0000,  1.0000,  0.0000,  ...,  0.0659,  0.0696,  0.7243],
        [ 0.0000,  1.0000,  0.0000,  ...,  1.8869, -0.4247,  0.8510]],
       device='cuda:0')
Setting up PyTorch plugin "bias_act_plugin"... Done.
Setting up PyTorch plugin "filtered_lrelu_plugin"... Done.
tensor([[9.9999e-01, 5.0539e-10, 9.9902e-01, 6.2888e-08, 4.3842e-08, 1.7732e-07,
         8.2307e-09, 1.0000e+00, 9.6137e-01, 1.5383e-13],
        [7.2168e-01, 1.5548e-11, 1.0000e+00, 1.5212e-07, 4.2468e-02, 1.1659e-06,
         2.4406e-13, 9.9953e-01, 9.6359e-01, 3.6735e-10],
        [4.8042e-02, 1.9135e-02, 9.9973e-01, 9.3433e-02, 3.6703e-04, 6.6802e-05,
         2.3147e-04, 5.7330e-01, 1.3113e-05, 2.7768e-07],
        [4.38

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn