In [None]:
# @title 环境
# Environment setup: refresh the apt package index. The Python dependencies are
# installed by the next cell, driven by cog.yaml.
# NOTE(review): `os` is not used by any of the visible cells.
import os

!apt -y update -qq

# The commented-out commands below appear to be an older manual install of the
# dependencies (torch 2.2.2 CUDA 12.1 wheels, libraries, cog) plus the `pget`
# binary from replicate/pget; they are kept for reference and are not executed.
# !pip install -q torch==2.2.2+cu121 torchvision==0.17.2+cu121 torchaudio==2.2.2+cu121 torchsde --extra-index-url https://download.pytorch.org/whl/cu121 -U
# !pip install -q einops transformers>=4.25.1 safetensors>=0.3.0
# !pip install -q aiohttp accelerate pyyaml Pillow scipy tqdm psutil kornia>=0.7.1 websocket-client==1.6.3 diffusers>=0.25.0 albumentations==1.4.3
# !pip install -q cog

# !curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget

In [None]:
import yaml

torch_versions = {
    "torch": "2.2.2",
    "torchvision": "0.17.2",
    "torchaudio":"2.2.2",
}

with open('cog.yaml', 'r') as file:
    config = yaml.safe_load(file).get('build')

cuda_version=config['cuda'].replace('.', '')

cuda_pkgs = []
pkgs=[]
for pkg in config['python_packages']:
  if pkg.startswith('torch'):
    if pkg in torch_versions:
      cuda_pkgs.append(f"{pkg}=={torch_versions[pkg]}+cu{cuda_version}")
    else:
      cuda_pkgs.append(pkg)
  else:
    pkgs.append(pkg)
cuda_pkgs.append(f'--extra-index-url https://download.pytorch.org/whl/cu{cuda_version} -U')
pkgs.append('cog')

print(f"pip install -q {' '.join(cuda_pkgs)}")
!pip install -q {' '.join(cuda_pkgs)}

print(f"pip install -q {' '.join(pkgs)}")
!pip install -q {' '.join(pkgs)}

for r in config['run']:
  print(r)
  !{r}

In [None]:
# @title 本体
!git clone https://github.com/3LOCats/cog-stickers.git /content/cog-stickers
%cd /content/cog-stickers
!./scripts/reset.sh

### 下载模型

In [None]:
import json

from scripts.weights_from_workflow import handle_weights
with open('sticker_maker_api.json', 'r') as f:
    workflow = json.load(f)
    weights = handle_weights(workflow)

print('Installing weights:')
print('\n'.join(weights))
!python ./scripts/get_weights.py {' '.join(weights)}


### 异步启动 `cog.server.http` 服务

In [None]:
import asyncio
import subprocess

async def create_command(cmd: str):
    proc = await asyncio.create_subprocess_shell(
        cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE
    )
    return proc

async def wait_command(proc):
    stdout, stderr = await proc.communicate()
    print(f"stdout: {stdout.decode()}")
    print(f"stderr: {stderr.decode()}")

    return_code = await proc.wait()
    print(f"Return code: {return_code}")

http_server = await create_command("python -m cog.server.http")
print(f'http server pid: {http_server.pid}')
wait_command(http_server)

### 生成图片

In [None]:
from predict import Predictor

# Load the predictor in this kernel.
# NOTE(review): if the HTTP server cell is still running it presumably loaded its own
# Predictor as well, so the models may be resident twice — confirm GPU memory headroom.
p = Predictor()
p.setup()

# Generation parameters, gathered in one place so they are easy to tweak.
predict_kwargs = dict(
    prompt='rabbit',
    negative_prompt='',
    width=1024,
    height=1024,
    steps=20,
    number_of_images=1,
    output_format='png',
    output_quality=100,
    sticker_type='Stickersheet',  # 'Sticker' or 'Stickersheet'
    seed=None,  # presumably None means a random seed; set an int to reproduce — TODO confirm
)
files = p.predict(**predict_kwargs)


### 显示图片

In [None]:
from IPython.display import Image

# Display the first generated image.
assert files, "predict() returned no output files"
file = files[0]
# For webp output, prefer a .png next to it (presumably the pre-conversion original
# kept by the predictor — confirm in predict.py), since older IPython versions cannot
# embed webp. Only switch when that file actually exists.
png_sibling = file.with_suffix('.png')
if file.suffix == '.webp' and png_sibling.exists():
    file = png_sibling
Image(filename=str(file))

### 停止 HTTP 服务

In [None]:
# List TCP sockets in LISTEN state owned by python processes (numeric ports, with
# process info) to check the HTTP server's port and pid before stopping it.
!ss -antp | grep LISTEN | grep python

In [None]:
# Stop the server through the Process handle from the startup cell instead of shelling
# out; skip it if the process has already exited (terminate() would then raise).
if http_server.returncode is None:
    print(f'killing {http_server.pid}')
    http_server.terminate()  # sends SIGTERM, the same signal a plain `kill` sends
else:
    print(f'http server {http_server.pid} already exited with code {http_server.returncode}')