In [1]:
import torch
import pickle
import os
from PIL import Image
from torch.nn.functional import softmax
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
import torchvision.transforms as transforms

## General Setup

In [2]:
import platform
if platform.system() == 'Darwin':
    DATA_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Data.nosync"
    ROOT_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Thesis"
elif platform.system() == 'Linux':
    DATA_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync"
    ROOT_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Thesis"

current_wd = os.getcwd()

In [3]:
os.chdir(f"{ROOT_PATH}/4_Assessor/Category_Assessor/DinoV2")
from helpers_pipeline import *
id2label = pickle.load(open("id2label_dicts/category_id2label.pkl", "rb"))
label2id = {v: k for k, v in id2label.items()}
os.chdir(current_wd)

Using device: mps


In [4]:
set_seed(42)
device = set_device()

Using device: mps


## Fixed Models Setup

In [5]:
# Latent Codes for all images
target_feature = 'category'
df, latents = load_latents(target_feature)
# Generator
G = setup_generator()
#G = G.to(device)
# DinoV2 Backbone
dino_processor_old, dino_model = setup_dinov2()
# Trained Classifier model 
classifier = load_classifier() 
#classifier = classifier.to(device)

In [6]:
# Freeeze all non-relevant model weights and set to eval mode
for param in G.parameters():
    param.requires_grad = False
G.eval()
for param in dino_model.parameters():
    param.requires_grad = False
dino_model.eval()
for param in classifier.parameters():
    param.requires_grad = False
classifier.eval();

In [7]:
# Print devoce for each model: 
print(f"Generator: {next(G.parameters()).device} | Requires Grad: {next(G.parameters()).requires_grad}")
print(f"DinoV2: {next(dino_model.parameters()).device} | Requires Grad: {next(dino_model.parameters()).requires_grad}")
print(f"Classifier: {next(classifier.parameters()).device} | Requires Grad: {next(classifier.parameters()).requires_grad}")

Generator: mps:0 | Requires Grad: False
DinoV2: cpu | Requires Grad: False
Classifier: mps:0 | Requires Grad: False


### Change Classifier Function to be only Tensors and differentiable functions
-> Most important function to change: processor function of DinoV2 model. 

In [8]:
transform_pipeline = transforms.Compose([
    #transforms.Resize(256),  # Resize so the shortest side is 256
    #transforms.CenterCrop((224, 224)),  # Center crop to 224x224
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
])

In [9]:
gen_img = G.synthesis(latents[0], noise_mode='const')
gen_img = gen_img.to('cpu')
gen_processed = transform_pipeline(gen_img)
gen_img.shape

torch.Size([1, 3, 512, 512])

In [10]:
img = Image.open('/Users/maltegenschow/Documents/Uni/Thesis/Data.nosync/Zalando_Germany_Dataset/dresses/images/square_images/0FB21C03E-I11.jpg').convert('RGB')
img = transforms.ToTensor()(img.resize([512,512]))
img = img.unsqueeze(0)
img.shape

torch.Size([1, 3, 512, 512])

In [11]:
def dino_processor(input):
    if isinstance(input, str):
        img = Image.open(input).convert('RGB')
        img = transforms.ToTensor()(img.resize([512,512]))
        img = img.unsqueeze(0)

        processed_img = transform_pipeline(img)
    elif isinstance(input, torch.Tensor):
        processed_img = transform_pipeline(input)
    else:
        raise ValueError("Input must be either a string or a torch.Tensor")
    return processed_img

### Test the gradients flow: 

In [31]:
latent = latents[0]
directions = torch.randn([8, 16,512], device=device, requires_grad=True)

target = torch.tensor(1, device=device, dtype=torch.float32)

In [33]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam([directions], lr=0.01)

for i in range(20):
    transformed = latent + 0.1 * directions[0]
    gen_img = G.synthesis(transformed, noise_mode='const')
    gen_img = gen_img.to('cpu')
    gen_processed = dino_processor(gen_img)
    embedding = dino_model(gen_processed)['pooler_output']
    embedding = embedding.to(device)
    scores = classifier(embedding)
    scores = softmax(scores, dim=1)
    scores = scores[0][0]

    

    loss = criterion(scores, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Loss: {loss.item()} | Directions Sum: {directions.sum()}")

Loss: 0.21093101799488068 | Directions Sum: 181.53134155273438
Loss: 0.12418025732040405 | Directions Sum: 181.4152069091797
Loss: 0.11543690413236618 | Directions Sum: 183.40005493164062
Loss: 0.10273753106594086 | Directions Sum: 185.85565185546875
Loss: 0.0880921259522438 | Directions Sum: 188.21725463867188
Loss: 0.0784483253955841 | Directions Sum: 190.4754638671875
Loss: 0.07277145981788635 | Directions Sum: 192.72296142578125
Loss: 0.06608559936285019 | Directions Sum: 194.8636932373047
Loss: 0.06189819797873497 | Directions Sum: 196.83029174804688
Loss: 0.06023789942264557 | Directions Sum: 198.9141082763672
Loss: 0.05609115585684776 | Directions Sum: 200.9251708984375
Loss: 0.053432293236255646 | Directions Sum: 202.85855102539062
Loss: 0.049008943140506744 | Directions Sum: 204.70999145507812
Loss: 0.04497193545103073 | Directions Sum: 206.62374877929688
Loss: 0.042290929704904556 | Directions Sum: 208.54299926757812
Loss: 0.03941873461008072 | Directions Sum: 210.34414672851

### Manipulation Model

In [None]:
class Editor(torch.nn.Module):

    def __init__(self, generator, dino_model, dino_processor, classifier, get_attribute_scores, id2label, label2id, ):
        super(Editor, self).__init__()

        self.generator = generator
        self.dino_model = dino_model
        self.dino_processor = dino_processor
        self.classifier = classifier
        self.get_attribute_score = get_attribute_scores

        self.id2label = id2label
        self.label2id = label2id   
        self.num_classes = len(self.id2label)
        self.directions_dimension = [generator.mapping.num_ws, generator.mapping.w_dim]

        self.directions = nn.Parameter(torch.randn(self.num_classes,self.directions_dimension[0], self.directions_dimension[1]), requires_grad=True)
        self.alphas = np.arange(0,1,0.1)
    
    def gan_output_to_image(self, output):
        img_perm = (output.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        img = Image.fromarray(img_perm[0].cpu().numpy(), 'RGB')
        return img


    def forward(self, latents, class_idx=None, alpha=None):

        if class_idx == None:
            class_idx = torch.randint(0, self.num_classes, (1,))
        
        if alpha == None:
            alpha = torch.tensor(np.round(np.random.choice(self.alphas),2))

        # Get scores for original image
        real_img = self.generator.synthesis(latents, noise_mode='const')
        real_img = self.gan_output_to_image(real_img)
        real_scores = self.get_attribute_score(self.dino_model, self.dino_processor, self.classifier, real_img)
        real_probs = softmax(real_scores, dim=0)
        real_class_prob  = real_probs[class_idx]

        # Get scores for transformed image
        transformed_latent = latents.clone() + alpha * self.directions[class_idx].to(latents.device)
        transformed_img = self.generator.synthesis(transformed_latent, noise_mode='const')
        transformed_img = self.gan_output_to_image(transformed_img)
        transformed_scores = self.get_attribute_score(self.dino_model, self.dino_processor, self.classifier, transformed_img)
        transformed_probs = softmax(transformed_scores, dim=0)
        transformed_class_prob = transformed_probs[class_idx]

        real_class_prob.requires_grad = True
        transformed_class_prob.requires_grad = True

        return real_class_prob, transformed_class_prob, class_idx, alpha
        

In [None]:
model = Editor(G, dino_model, dino_processor, classifier, get_attribute_scores, id2label, label2id)

### Test on one example only

- First latent code in the dataset
- Real Label: Day Dress: class_idx = 0
- Target Label: Denim Dress: class_idx = 7

In [None]:
fixed_alpha = 0.1
fixed_class_idx = 7

model = Editor(G, dino_model, dino_processor, classifier, get_attribute_scores, id2label, label2id) 
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5)

for i in range(5):
    optimizer.zero_grad()
    # Forward:
    real_class_prob, transformed_class_prob, class_idx, alpha = model(latents[0], 7, 0.1)
    # Loss:
    loss = criterion(transformed_class_prob, (real_class_prob + fixed_alpha))
    # Backward:
    loss.backward()
    optimizer.step()

    # Print updates: 
    print(f"Step: {i} | Loss: {np.round(loss.item(),4)} | Real Class Prob: {real_class_prob.item()} | Transformed Class Prob: {transformed_class_prob.item()}")
    print(f"Sum of walking direction: {torch.sum(model.directions[class_idx])} | Alpha: {alpha}")


In [None]:
latents.requires_grad = True

In [None]:
# Assuming `latents` is a tensor with requires_grad=True
latent = latents[0]
generated_image = model.generator.synthesis(latent, noise_mode='const')
# Simulate the operation
out = model.gan_output_to_image(generated_image)

# Perform a dummy operation and check gradients
output = pil_to_tensor.sum()
output.backward()

print("Gradient to latents after image conversion:", latents.grad)  # This will likely show None or zero
