# SAM: Inference Playground

In [None]:
import os
os.chdir('/content')
CODE_DIR = 'SAM'

In [2]:
!git clone https://github.com/yuval-alaluf/SAM.git $CODE_DIR

Cloning into 'SAM'...
remote: Enumerating objects: 228, done.[K
remote: Counting objects: 100% (48/48), done.[K
remote: Compressing objects: 100% (26/26), done.[K
remote: Total 228 (delta 32), reused 22 (delta 22), pack-reused 180 (from 1)[K
Receiving objects: 100% (228/228), 24.63 MiB | 15.66 MiB/s, done.
Resolving deltas: 100% (78/78), done.


In [3]:
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force

--2025-05-24 06:34:59--  https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
Resolving github.com (github.com)... 140.82.116.3
Connecting to github.com (github.com)|140.82.116.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/1335132/d2f252e2-9801-11e7-9fbf-bc7b4e4b5c83?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250524%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250524T063459Z&X-Amz-Expires=300&X-Amz-Signature=8b6ad38287c959bad86042e64ce995c63532057b969f7830d276590e70bd799b&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Dninja-linux.zip&response-content-type=application%2Foctet-stream [following]
--2025-05-24 06:34:59--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/1335132/d2f252e2-9801-11e7-9fbf-bc7b4e4b5c83?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credenti

In [4]:
os.chdir(f'./{CODE_DIR}')

In [5]:
from argparse import Namespace
import os
import sys
import pprint
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms

sys.path.append(".")
sys.path.append("..")

from datasets.augmentations import AgeTransformer
from utils.common import tensor2im
from models.psp import pSp

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [6]:
EXPERIMENT_TYPE = 'ffhq_aging'

## Step 1: Download Pretrained Model
As part of this repository, we provide our pretrained aging model.
We'll download the model for the selected experiments as save it to the folder `../pretrained_models`.

In [7]:
def get_download_model_command(file_id, file_name):
    """ Get wget download command for downloading the desired model and save to directory ../pretrained_models. """
    current_directory = os.getcwd()
    save_path = os.path.join(os.path.dirname(current_directory), "pretrained_models")
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)
    return url

In [8]:
MODEL_PATHS = {
    "ffhq_aging": {"id": "1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC", "name": "sam_ffhq_aging.pt"}
}

path = MODEL_PATHS[EXPERIMENT_TYPE]
download_command = get_download_model_command(file_id=path["id"], file_name=path["name"])

In [9]:
!wget {download_command}

--2025-05-24 06:36:57--  http://wget/
Resolving wget (wget)... failed: Name or service not known.
wget: unable to resolve host address ‘wget’
--2025-05-24 06:36:57--  https://docs.google.com/uc?export=download&confirm=&id=1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC
Resolving docs.google.com (docs.google.com)... 173.194.203.113, 173.194.203.138, 173.194.203.102, ...
Connecting to docs.google.com (docs.google.com)|173.194.203.113|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://drive.usercontent.google.com/download?id=1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC&export=download [following]
--2025-05-24 06:36:57--  https://drive.usercontent.google.com/download?id=1XyumF6_fdAxFmxpFcmPf-q84LU_22EMC&export=download
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 142.250.99.132, 2607:f8b0:400e:c05::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|142.250.99.132|:443... connected.
HTTP request sent, awaiting response... 

## Step 2: Define Inference Parameters

Below we have a dictionary defining parameters such as the path to the pretrained model to use and the path to the
image to perform inference on.
While we provide default values to run this script, feel free to change as needed.

In [10]:
EXPERIMENT_DATA_ARGS = {
    "ffhq_aging": {
        "model_path": "../pretrained_models/sam_ffhq_aging.pt",
        "image_path": "notebooks/images/866.jpg",
        "transform": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    }
}

In [11]:
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]

## Step 3: Load Pretrained Model
We assume that you have downloaded the pretrained aging model and placed it in the path defined above

In [16]:
model_path = EXPERIMENT_ARGS['model_path']
ckpt = torch.load(model_path, map_location='cpu')

UnpicklingError: Weights only load failed. In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Unsupported operand 60

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [89]:
opts = ckpt['opts']
pprint.pprint(opts)

NameError: name 'ckpt' is not defined

In [None]:
# update the training options
opts['checkpoint_path'] = model_path

In [None]:
opts = Namespace(**opts)
net = pSp(opts)
net.eval()
net.cuda()
print('Model successfully loaded!')

## Step 4: Visualize Input

In [None]:
image_path = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE]["image_path"]
original_image = Image.open(image_path).convert("RGB")

In [None]:
original_image.resize((256, 256))

## Step 5: Perform Inference

### Align Image

Before running inference we'll run alignment on the input image.

In [None]:
!wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
!bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2

In [None]:
def run_alignment(image_path):
    import dlib
    from scripts.align_all_parallel import align_face
    predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
    aligned_image = align_face(filepath=image_path, predictor=predictor)
    print("Aligned image has shape: {}".format(aligned_image.size))
    return aligned_image

In [None]:
aligned_image = run_alignment(image_path)

In [None]:
aligned_image.resize((256, 256))

### Run Inference

In [None]:
img_transforms = EXPERIMENT_ARGS['transform']
input_image = img_transforms(aligned_image)

In [None]:
# we'll run the image on multiple target ages
target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
age_transformers = [AgeTransformer(target_age=age) for age in target_ages]

In [None]:
def run_on_batch(inputs, net):
    result_batch = net(inputs.to("cuda").float(), randomize_noise=False, resize=False)
    return result_batch

In [None]:
# for each age transformed age, we'll concatenate the results to display them side-by-side
results = np.array(aligned_image.resize((1024, 1024)))
for age_transformer in age_transformers:
    print(f"Running on target age: {age_transformer.target_age}")
    with torch.no_grad():
        input_image_age = [age_transformer(input_image.cpu()).to('cuda')]
        input_image_age = torch.stack(input_image_age)
        result_tensor = run_on_batch(input_image_age, net)[0]
        result_image = tensor2im(result_tensor)
        results = np.concatenate([results, result_image], axis=1)

### Visualize Result

In [None]:
results = Image.fromarray(results)
results   # this is a very large image (11*1024 x 1024) so it may take some time to display!

In [None]:
# save image at full resolution
results.save("notebooks/images/age_transformed_image.jpg")