In [1]:
import torch
import pickle
import os
from PIL import Image
from torch.nn.functional import softmax
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
import torchvision.transforms as transforms
from IPython.utils import io

## General Setup

In [2]:
import platform
if platform.system() == 'Darwin':
    DATA_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Data.nosync"
    ROOT_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Thesis"
elif platform.system() == 'Linux':
    DATA_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync"
    ROOT_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Thesis"

current_wd = os.getcwd()

In [3]:
os.chdir(f"{ROOT_PATH}/4_Assessor/Category_Assessor/DinoV2")
from helpers_pipeline import *
from helper_DinoV2_Embeddings import *
id2label = pickle.load(open("id2label_dicts/category_id2label.pkl", "rb"))
label2id = {v: k for k, v in id2label.items()}
os.chdir(current_wd)

Using devices: DinoV2 device: cuda | SG2 device: cuda | General device: cuda


In [4]:
set_seed(42)
dino_device, sg2_device, device = set_device()

Using devices: DinoV2 device: cuda | SG2 device: cuda | General device: cuda


## Fixed Models Setup

In [5]:
# Products data and latents
target_feature = 'category'
df, latents = load_latents(target_feature)
latents = latents.to(sg2_device)

# SG2-Ada Generator
G = setup_generator()
G = G.to(sg2_device)

# DinoV2 Model 
dino_processor, dino_model = setup_dinov2()
dino_model = dino_model.to(dino_device)

# Attribute Classifier Model
classifier = load_classifier()
classifier = classifier.to(device)

In [6]:
# Freeeze all non-relevant model weights and set to eval mode
for param in G.parameters():
    param.requires_grad = False
G.eval()
for param in dino_model.parameters():
    param.requires_grad = False
dino_model.eval()
for param in classifier.parameters():
    param.requires_grad = False
classifier.eval();

In [7]:
# Print devoce for each model: 
print(f"Generator: {next(G.parameters()).device} | Requires Grad: {next(G.parameters()).requires_grad}")
print(f"DinoV2: {next(dino_model.parameters()).device} | Requires Grad: {next(dino_model.parameters()).requires_grad}")
print(f"Classifier: {next(classifier.parameters()).device} | Requires Grad: {next(classifier.parameters()).requires_grad}")

Generator: cuda:0 | Requires Grad: False
DinoV2: cuda:0 | Requires Grad: False
Classifier: cuda:0 | Requires Grad: False


### Check gradient flow outside of Model Class

In [None]:
from IPython.utils import io
torch.set_printoptions(sci_mode=False)

fixed_alpha = 0.9
fixed_class_idx = 7
latent = latents[0]

directions = torch.randn([8,16,512], device=latent.device, requires_grad=True)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam([directions], lr=0.1)


In [None]:
for i in range(5):
    optimizer.zero_grad()
    with io.capture_output() as captured:
        real_img = G.synthesis(latent, noise_mode='const')
    real_dino_input = dino_processor(real_img)
    real_dino_embedding = dino_model(real_dino_input)['pooler_output']
    real_scores = classifier(real_dino_embedding)
    real_probs = softmax(real_scores, dim = 1).squeeze(0)
    real_class_prob  = real_probs[fixed_class_idx]

    transformed_latent = latent + fixed_alpha * directions[fixed_class_idx]
    with io.capture_output() as captured:
        trans_img = G.synthesis(transformed_latent, noise_mode='const')
    trans_dino_input = dino_processor(trans_img)
    trans_dino_embedding = dino_model(trans_dino_input)['pooler_output']
    trans_scores = classifier(trans_dino_embedding)
    trans_probs = softmax(trans_scores, dim = 1).squeeze(0)
    trans_class_prob  = trans_probs[fixed_class_idx]

    loss = criterion((real_class_prob + fixed_alpha), trans_class_prob)
    loss.backward()
    optimizer.step()
    print(f"Step {i}: Loss: {loss.item()} | Directions sum: {directions.sum()}")


### Manipulation Model

In [12]:
class Editor(torch.nn.Module):

    def __init__(self, generator, dino_model, dino_processor, classifier, id2label, label2id, ):
        super(Editor, self).__init__()

        self.generator = generator
        self.dino_model = dino_model
        self.dino_processor = dino_processor
        self.classifier = classifier

        self.id2label = id2label
        self.label2id = label2id   
        self.num_classes = len(self.id2label)
        self.directions_dimension = [generator.mapping.num_ws, generator.mapping.w_dim]

        self.directions = nn.Parameter(torch.randn(self.num_classes,self.directions_dimension[0], self.directions_dimension[1]), requires_grad=True)
        self.alphas = np.arange(0,1,0.1)
    
    def forward(self, latent, class_idx=None, alpha=None):

        if class_idx == None:
            class_idx = torch.randint(0, self.num_classes, (1,))
        
        if alpha == None:
            alpha = torch.tensor(np.round(np.random.choice(self.alphas),2))

        # Get scores for original image
        with io.capture_output() as captured:
            real_img = self.generator.synthesis(latent, noise_mode='const')
        real_dino_input = self.dino_processor(real_img)
        real_dino_embedding = self.dino_model(real_dino_input)['pooler_output']
        real_scores = self.classifier(real_dino_embedding)
        real_probs = softmax(real_scores, dim=0).squeeze(0)
        real_class_prob  = real_probs[class_idx]

        # Get scores for transformed image
        transformed_latent = latent + alpha * self.directions[class_idx].to(latent.device)
        with io.capture_output() as captured:
            transformed_img = self.generator.synthesis(transformed_latent, noise_mode='const')
        transformed_dino_input = self.dino_processor(transformed_img)
        transformed_dino_embedding = self.dino_model(transformed_dino_input)['pooler_output']
        transformed_scores = self.classifier(transformed_dino_embedding)
        transformed_probs = softmax(transformed_scores, dim=0).squeeze(0)
        transformed_class_prob = transformed_probs[class_idx]

        return real_class_prob, transformed_class_prob, class_idx, alpha
        

In [13]:
model = Editor(G, dino_model, dino_processor, classifier, id2label, label2id)

### Test on one example only

- First latent code in the dataset
- Real Label: Day Dress: class_idx = 0
- Target Label: Denim Dress: class_idx = 7

In [14]:
fixed_alpha = 0.1
fixed_class_idx = 7

model = Editor(G, dino_model, dino_processor, classifier, id2label, label2id) 
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5)

for i in range(5):
    optimizer.zero_grad()
    # Forward:
    real_class_prob, transformed_class_prob, class_idx, alpha = model(latents[0], 7, 0.1)
    # Loss:
    loss = criterion(transformed_class_prob, (real_class_prob + fixed_alpha))
    # Backward:
    loss.backward()
    optimizer.step()

    # Print updates: 
    print(f"Step: {i} | Loss: {np.round(loss.item(),4)} | Real Class Prob: {real_class_prob.item()} | Transformed Class Prob: {transformed_class_prob.item()}")
    print(f"Sum of walking direction: {torch.sum(model.directions[class_idx])} | Alpha: {alpha}")


Step: 0 | Loss: 0.01 | Real Class Prob: 1.0 | Transformed Class Prob: 1.0
Sum of walking direction: -39.65654754638672 | Alpha: 0.1
Step: 1 | Loss: 0.01 | Real Class Prob: 1.0 | Transformed Class Prob: 1.0
Sum of walking direction: -39.65654754638672 | Alpha: 0.1
Step: 2 | Loss: 0.01 | Real Class Prob: 1.0 | Transformed Class Prob: 1.0
Sum of walking direction: -39.65654754638672 | Alpha: 0.1
Step: 3 | Loss: 0.01 | Real Class Prob: 1.0 | Transformed Class Prob: 1.0
Sum of walking direction: -39.65654754638672 | Alpha: 0.1
Step: 4 | Loss: 0.01 | Real Class Prob: 1.0 | Transformed Class Prob: 1.0
Sum of walking direction: -39.65654754638672 | Alpha: 0.1


In [None]:
latents.requires_grad = True

In [None]:
# Assuming `latents` is a tensor with requires_grad=True
latent = latents[0]
generated_image = model.generator.synthesis(latent, noise_mode='const')
# Simulate the operation
out = model.gan_output_to_image(generated_image)

# Perform a dummy operation and check gradients
output = pil_to_tensor.sum()
output.backward()

print("Gradient to latents after image conversion:", latents.grad)  # This will likely show None or zero
