In [1]:
import torch
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget, decode_latent_mesh

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
assert(device != 'cuda')

In [2]:
SAVE_LOCATION = 'models/'

In [3]:
xm = load_model('transmitter', device=device)
model = load_model('text300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))

In [4]:
def generateLatentFromPrompt(prompt, guidance_scale=15.0):

    latents = sample_latents(
        batch_size=1,
        model=model,
        diffusion=diffusion,
        guidance_scale=guidance_scale,
        model_kwargs=dict(texts=[prompt] * 1),
        progress=True,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )
    
    return latents[0]

def exportLatentToObj(latent, name):
    t = decode_latent_mesh(xm, latent).tri_mesh()
    with open(f'{SAVE_LOCATION}{name}.obj', 'w') as f:
        t.write_obj(f)

In [6]:
prompts = ['a shark', 'a goldfish']
latents = []
for prompt in prompts:
    latents.append(generateLatentFromPrompt(prompt))

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

In [7]:
linear_interpolation_increments = 10

z_interpolations = []
for i in range(linear_interpolation_increments):
    alpha = i * (1 / linear_interpolation_increments);
    z_interpolations.append((1 - alpha) * latents[0] + alpha * latents[1])

for i, z in enumerate(z_interpolations):
    exportLatentToObj(z, f'{i}_shark_goldfish')



In [7]:
base_prompt = 'a ``` car'
categories = ['red', 'blue']
latents_per_category = 5

samples = {category: [] for category in categories}
for category in categories:
    prompt = base_prompt.replace('```', category)
    print(f'Starting prompt: {prompt}...')
    for i in range(latents_per_category):
        samples[category].append(generateLatentFromPrompt(prompt))

Starting prompt: a red car...


  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

Starting prompt: a blue car...


  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

In [8]:
import torch
from torch import nn, optim
import random

color_epochs = 1000
color_batch_size = 30

class ColorClassifer(nn.Module):
    def __init__(self, latentSpaceDimensions):
        super(ColorClassifer, self).__init__()
        self.fc = nn.Linear(latentSpaceDimensions, 2).to(device)
    
    def forward(self, x):
        return self.fc(x)
    
latentSpaceDimensions = len(next(iter(samples.values()))[0])
classiferModel = ColorClassifer(latentSpaceDimensions).to(device)
criterionClassiferModel = nn.CrossEntropyLoss().to(device)
optimizerClassiferModel = optim.Adam(classiferModel.parameters(), lr=0.001)

def randomSample(batch_size):
    latent_vectors = random.sample(samples['red'], batch_size) + random.sample(samples['blue'], batch_size)
    latent_vectors = torch.stack(latent_vectors).to(device)
    labels = torch.tensor([0 for _ in range(batch_size)] + [1 for _ in range(batch_size)], dtype=torch.long).to(device)

    assert(len(latent_vectors) == batch_size * 2)
    assert(len(labels) == batch_size * 2)

    return latent_vectors, labels

for epoch in range(color_epochs):
    latent_vectors, labels = randomSample(color_batch_size)
    optimizerClassiferModel.zero_grad()
    output = classiferModel(latent_vectors)
    loss = criterionClassiferModel(output, labels)
    if (epoch % 100 == 0):
        print(f'Epoch {epoch}, Loss: {loss.item()}')
    loss.backward()
    optimizerClassiferModel.step()

red_weight = classiferModel.state_dict()['fc.weight'][0].to(device)
print(red_weight)


ValueError: Sample larger than population or is negative

In [24]:
weight = -60
sample_latent = samples['red'][0]
altered_latent = sample_latent + red_weight * weight
exportLatentToObj(altered_latent, 'color_test_output')

In [25]:
blue_shoe = generateLatentFromPrompt('a blue shoe')
altered_blue_shoe = blue_shoe + red_weight * weight
exportLatentToObj(altered_blue_shoe, 'color_blue_shoe_test_output')

  0%|          | 0/64 [00:00<?, ?it/s]

In [27]:
weight = 60
altered_blue_shoe = blue_shoe + red_weight * weight
exportLatentToObj(altered_blue_shoe, 'color_blue_shoe_test_output')
exportLatentToObj(blue_shoe, 'color_original_blue_shoe_test_output')

In [29]:
latentGuidance5 = generateLatentFromPrompt('a airplane', 5.0)
latentGuidance50 = generateLatentFromPrompt('a airplane', 50.0)
exportLatentToObj(latentGuidance5, 'latent_guidance_5')
exportLatentToObj(latentGuidance50, 'latent_guidance_50')

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

In [30]:
latentGuidance5 = generateLatentFromPrompt('a airplane', 1.0)
latentGuidance50 = generateLatentFromPrompt('a airplane', 20.0)
exportLatentToObj(latentGuidance5, 'latent_guidance_1')
exportLatentToObj(latentGuidance50, 'latent_guidance_20')

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

In [5]:
latent = generateLatentFromPrompt('a blue shoe')

  0%|          | 0/64 [00:00<?, ?it/s]

In [6]:
print(latent.shape)

torch.Size([1048576])


In [5]:
latent = generateLatentFromPrompt('a blue backpack')

  0%|          | 0/64 [00:00<?, ?it/s]

In [6]:
exportLatentToObj(latent, 'a blue backpack')

