# Setup Environment and Paths

This cell imports necessary libraries and sets the device (GPU or CPU).  
It also defines the file paths to the three pre-trained model checkpoints that will be loaded later.

In [1]:
import torch
from torchvision import transforms
from PIL import Image
import os

# --- Paths to model checkpoints ---
# All checkpoints live under one weights directory; the two DSID encoder
# checkpoints share a sub-folder, so the paths are assembled from common prefixes.
_weights_root = "DeepExPo/DeepExPo_Weights"
_dsid_checkpoint_dir = f"{_weights_root}/DSID_Checkpoints"

identity_encoder_path = f"{_dsid_checkpoint_dir}/dsid-source-identity-encoder-cpk.pth"
# NOTE(review): "sementic" is the spelling used in the file name here — verify against the file on disk.
semantic_encoder_path = f"{_dsid_checkpoint_dir}/dsid-target-sementic-encoder-cpk.pth"
MEAF_model_path = f"{_weights_root}/DeepExPo_MEAF_weights.pth"

# --- Device ---
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")


Using device: cuda


# Define or Import Model Architectures

Here we define a placeholder PyTorch model class, `DSID`, which wraps the Identity Encoder and the Semantic Encoder; no class for the MEAF Model is defined in this cell.  
You should replace the placeholders with your actual model definitions or imports.


In [None]:
class DSID(torch.nn.Module):
    """Pair an identity encoder with a semantic encoder.

    This is a placeholder wrapper: the real encoder architectures are not
    defined in this notebook and must be supplied by the user.

    Args:
        identity_encoder: Module applied to the identity input. Optional so that
            ``DSID()`` keeps working; it can also be assigned to the attribute
            of the same name after construction.
        semantic_encoder: Module applied to the semantic input. Optional for the
            same reason.
    """

    def __init__(self, identity_encoder=None, semantic_encoder=None):
        super().__init__()
        # Assigning an nn.Module here registers it as a sub-module; None is kept
        # as a plain "not configured yet" marker.
        self.identity_encoder = identity_encoder
        self.semantic_encoder = semantic_encoder

    def forward(self, x_id, x_sem):
        """Encode both inputs independently.

        Args:
            x_id: Input passed to the identity encoder.
            x_sem: Input passed to the semantic encoder.

        Returns:
            Tuple ``(id_emb, sem_emb)`` with the two encoder outputs.

        Raises:
            NotImplementedError: If either encoder has not been provided.
        """
        if self.identity_encoder is None or self.semantic_encoder is None:
            # Fail with an actionable message instead of trying to call a placeholder.
            raise NotImplementedError(
                "DSID encoders are placeholders: pass identity_encoder and "
                "semantic_encoder to DSID(...) or assign them before calling forward()."
            )
        id_emb = self.identity_encoder(x_id)
        sem_emb = self.semantic_encoder(x_sem)
        return id_emb, sem_emb


# Load Pre-trained Model Weights

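This cell loads the saved checkpoint weights into the model instances and sets the models to evaluation mode. The instances `identity_encoder`, `semantic_encoder`, and `third_model` must be created before this cell is run.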


In [None]:
# NOTE(review): `identity_encoder`, `semantic_encoder` and `third_model` must already be
# instantiated (with architectures matching the checkpoints) before this cell runs;
# no cell visible in this notebook creates them.
# SECURITY: torch.load unpickles the checkpoint file, so only load files from a trusted source.
identity_encoder.load_state_dict(torch.load(identity_encoder_path, map_location=device))
semantic_encoder.load_state_dict(torch.load(semantic_encoder_path, map_location=device))
# The third checkpoint path is defined as `MEAF_model_path` in the setup cell
# (the previously used name `third_model_path` does not exist).
third_model.load_state_dict(torch.load(MEAF_model_path, map_location=device))

# Put every model into evaluation mode for inference.
identity_encoder.eval()
semantic_encoder.eval()
third_model.eval()

print("Models loaded and set to eval mode.")


# Define Image Preprocessing Pipeline

Defines the transformations applied to input images before feeding them into the models, including resizing, converting to tensors, and normalizing pixel values.


# Per-channel normalisation constants.
# NOTE(review): these match the commonly used ImageNet statistics — confirm the
# encoders were trained with the same normalisation.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Ordered preprocessing steps: resize -> tensor -> normalise.
_preprocess_steps = [
    transforms.Resize((256, 256)),  # adjust if the models expect another input size
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
preprocess = transforms.Compose(_preprocess_steps)


# Load and Preprocess Input Images

Prompts for the identity and reference semantic image paths, loads the images, converts them to RGB, and applies the preprocessing pipeline to prepare them for model input.


In [None]:
# Ask the user for the two image locations (surrounding whitespace is stripped).
identity_img_path = input("Enter path to identity image: ").strip()
semantic_img_path = input("Enter path to reference semantic image: ").strip()

# Validate both paths up front. An explicit exception is used instead of `assert`
# (asserts are removed under `python -O`), and `isfile` also rejects directories.
for _label, _path in (("Identity", identity_img_path), ("Semantic", semantic_img_path)):
    if not os.path.isfile(_path):
        raise FileNotFoundError(f"{_label} image not found: {_path}")

# Open inside a context manager so the underlying file is closed as soon as the
# RGB copy has been made; `convert` returns an independent in-memory image.
with Image.open(identity_img_path) as _raw_image:
    identity_img = _raw_image.convert("RGB")
with Image.open(semantic_img_path) as _raw_image:
    semantic_img = _raw_image.convert("RGB")

# Apply the preprocessing pipeline, add a leading batch dimension and move to `device`.
identity_tensor = preprocess(identity_img).unsqueeze(0).to(device)
semantic_tensor = preprocess(semantic_img).unsqueeze(0).to(device)

print("Images loaded and preprocessed.")


# Extract Identity and Semantic Embeddings

Runs the identity encoder and semantic encoder on the preprocessed images to obtain feature embeddings.


In [None]:
# Inference only: run both encoders with gradient tracking switched off.
with torch.no_grad():
    identity_emb = identity_encoder(identity_tensor)
    semantic_emb = semantic_encoder(semantic_tensor)

# Report the shape of each embedding, one per line.
print(
    f"Identity embedding shape: {identity_emb.shape}",
    f"Semantic embedding shape: {semantic_emb.shape}",
    sep="\n",
)


In [None]:
from PIL import Image
from torchvision import transforms

# Resize + tensor conversion only; unlike `preprocess`, no normalisation is applied here.
to_tensor = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),            # converts to [0,1] float tensor, shape [C,H,W]
])

# NOTE(review): machine-specific absolute path; the same file is used for both the
# identity and the semantic image. Point this at your own image before running.
_sample_image_path = "/home/afgan_server/Important_Code/New_ID_Images/fifth_Id/Screenshot 2025-02-14 155404.jpg"

# Open inside a context manager so the file is closed once the RGB copy exists.
with Image.open(_sample_image_path) as _raw_image:
    img_id = _raw_image.convert("RGB")
with Image.open(_sample_image_path) as _raw_image:
    img_sem = _raw_image.convert("RGB")

# NOTE(review): `x_id` / `x_sem` are not referenced by any other cell visible in this notebook.
x_id  = to_tensor(img_id ).unsqueeze(0).to(device)   # add batch dim -> [1,3,256,256]
x_sem = to_tensor(img_sem).unsqueeze(0).to(device)


# Run Inference with the Third Model

Feeds the identity and semantic embeddings into the third model to produce the output tensor. If the model expects a single input, the embeddings can optionally be combined first (for example, by concatenation).


In [None]:
# Run the third model on the two embeddings with gradient tracking disabled.
with torch.no_grad():
    output = third_model(identity_emb, semantic_emb)  # NOTE(review): `combined_emb` is not defined in this notebook; build it first if the model expects a single input

# NOTE(review): assumes the model returns a single tensor (its `.shape` is read here) — confirm.
print(f"Output shape: {output.shape}")


# Post-process and Save the Output

Rescales the output generated by the third model (e.g., a synthesized image tensor) to the [0, 1] range, converts it to a PIL image, saves it to disk, and displays it.

In [None]:
def tensor_to_pil(tensor):
    """Convert an image tensor with values in [0, 1] to a PIL image.

    Args:
        tensor: Image tensor shaped [1, C, H, W] or [C, H, W]. A single-channel
            image (C == 1, or a plain [H, W] map after the leading dimension is
            squeezed) is converted to a 2-D array. Values outside [0, 1] are clipped.

    Returns:
        The ``PIL.Image.Image`` built from the 8-bit pixel array.

    Raises:
        ValueError: If the tensor is not 2-D or 3-D after dropping a leading
            dimension of size 1.
    """
    # detach() makes tensors that still require grad convertible to NumPy.
    image = tensor.detach().squeeze(0).cpu()
    image = image.clamp(0, 1)  # clip values to [0,1]

    if image.dim() == 3:
        image = image.permute(1, 2, 0)  # CHW -> HWC, the layout fromarray expects
        if image.shape[-1] == 1:
            # fromarray cannot handle an (H, W, 1) array; use (H, W) instead.
            image = image.squeeze(-1)
    elif image.dim() != 2:
        raise ValueError(
            f"Expected an image tensor of shape [1, C, H, W] or [C, H, W], got {tuple(tensor.shape)}"
        )

    # Scale to 0..255; the uint8 cast truncates the fractional part.
    pixels = (image.numpy() * 255).astype('uint8')
    return Image.fromarray(pixels)

# Min-max rescale the model output to [0, 1] so it can be converted to 8-bit pixels.
# The value range is floored at the dtype's machine epsilon so that a constant
# output maps to 0 instead of producing NaN from 0/0.
_output_min, _output_max = output.min(), output.max()
_output_range = (_output_max - _output_min).clamp(min=torch.finfo(output.dtype).eps)
output = (output - _output_min) / _output_range

generated_img = tensor_to_pil(output)
output_path = "generated_output.png"
generated_img.save(output_path)

print(f"Generated image saved to {output_path}")

# Display inline in the notebook: a bare last expression uses the image's rich
# representation, whereas Image.show() launches an external viewer program.
generated_img


# DeepExPo Output


In [None]:
import torch
from diffusers import DiffusionPipeline, AutoencoderKL
# NOTE(review): machine-specific absolute locations — adjust for your environment.
_code_root = "/home/afgan_server/Important_Code"
_sd_weights_dir = f"{_code_root}/stable_diffusion_base_weights"
lora_weights_path = f"{_code_root}/Old_ID_Fine_Tune/ID_2_fine_tune"

# VAE loaded in half precision from a local directory.
vae = AutoencoderKL.from_pretrained(
    f"{_sd_weights_dir}/madebyollinsdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

# Base pipeline loaded from local weights, using the VAE created above.
_pipeline_options = dict(
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipe = DiffusionPipeline.from_pretrained(
    f"{_sd_weights_dir}/stabilityaistable-diffusion-xl-base-1.0",
    **_pipeline_options,
)

# Attach the fine-tuned LoRA weights, then move the whole pipeline to the GPU.
pipe.load_lora_weights(lora_weights_path)
_ = pipe.to("cuda")  # assigned to `_` so the pipeline repr is not echoed as cell output