In [1]:
from torch.utils.data import DataLoader, TensorDataset, random_split

import numpy as np
import torch
import h5py
import sys
sys.path.append('..')
sys.path.append('../stylegan3')

from utils.L2FPipeline import L2FPipeline
from stylegan_generator import StyleGANGenerator
from models.MultilabelResnetClassifier import MultilabelResnetClassifier
from models.LatentFeatureExtractor import LatentFeatureExtractor

In [3]:
torch.manual_seed(0)

<torch._C.Generator at 0x7f1860548650>

In [10]:
data_path = '/home/robert/data/diploma-thesis/datasets/stylegan3/tpsi_1/latents/sample_z.h5'
n_classes = 10

data = None
with h5py.File(data_path, 'r') as f:
    data = f['z'][:]

dataset = TensorDataset(torch.Tensor(data),torch.randint(0,2,(len(data), n_classes)))
train_data, valid_data, test_data = random_split(dataset, [0.8, 0.1, 0.1])
print(dataset[0])

(tensor([ 0.0863, -0.2087, -0.0529,  1.9157, -0.7995, -1.2684, -0.7082, -0.4558,
        -0.4521,  0.9718,  0.8288,  0.5547, -0.8654, -0.0424, -1.6860, -0.4464,
        -0.1741,  1.5128,  0.0182, -1.3632, -0.5289,  1.5866,  2.2250, -0.5918,
         0.6909, -0.7142,  0.4625, -1.2063,  0.6286,  1.0684,  0.0952, -0.8414,
        -0.0596,  0.2289,  0.2742, -0.4522,  2.3054,  1.0486, -1.0632,  1.8720,
        -0.2277,  0.7040,  0.9490,  0.5992,  0.6675, -0.8961,  2.5029, -0.9136,
         0.6681, -2.6288, -0.3084,  0.6764,  1.1316, -0.5700, -0.6827,  0.6213,
        -1.0184,  0.0625,  0.3167,  0.5707, -0.6916, -2.1364,  2.3244,  1.0641,
         0.2128, -0.0191,  0.2481, -1.8346, -0.1900, -1.5825, -0.5059,  2.1864,
        -0.7260,  1.1937,  0.6262, -2.6140,  0.6539, -1.5208,  0.0103, -1.4978,
        -0.5254, -2.6559, -1.9320,  1.3012,  0.1731, -1.4431, -2.0771,  0.3184,
         0.6851, -1.2064,  0.1635,  0.6987,  0.4779, -0.8020,  1.0947, -0.2922,
         1.3414, -1.4261, -1.7614, -0.4

In [9]:
data.shape

(256000, 512)

In [None]:
network_pkl = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-256x256.pkl'
classifier_weights = '/home/robert/data/diploma-thesis/weights/classfier/resnet34_celeba10attr_10e.pt'

generator = StyleGANGenerator(network_pkl)
classifier = MultilabelResnetClassifier(n_classes=10)
classifier.load_state_dict(torch.load(classifier_weights))

pipeline = L2FPipeline(generator = generator, classifier = classifier, tpsi=0.7)

In [None]:
batch_size = 64
num_epochs = 10
learning_rate = 1e-4
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

model = LatentFeatureExtractor(n_classes=n_classes).cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate)

In [None]:
for epoch in range(num_epochs):
    model.train()
    for data in train_dataloader:
        x, y = data
        x = x.cuda()
               
        z = model(x, y)
        preds, _ = pipeline.transform(z)
        
        train_loss = criterion(preds, y)
       
        train_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
    loss['train'].append(train_loss.data.item())
    
    model.eval()
    with torch.no_grad():
        for data in valid_dataloader:
            x, y = data
            x = x.cuda()
            y = y.cuda()
            
            output,latent = model(x, y)            
            val_loss = criterion(output, x)
            
    loss['valid'].append(val_loss.data.item())
    
    if loss['valid'][-1] < best_valid_loss:
        torch.save(model.state_dict(), save_filename)
        best_valid_loss = loss['valid'][-1]
   
    print(f'epoch [{epoch + 1}/{num_epochs}], loss:{loss["train"][-1]}, valid_loss:{loss["valid"][-1]}')