In [None]:
# # Install dependencies

# !git clone https://github.com/openai/CLIP
# !git clone https://github.com/crowsonkb/guided-diffusion
# !pip install -e ./CLIP
# !pip install -e ./guided-diffusion
# !pip install lpips
# !pip install blobfile
# !mkdir ckpt_model
# ##RESTART KERNEL###
# Imports

import gc
import glob
import io
import json
import math
import os
import re
import sys

from IPython import display
import lpips
from PIL import Image
import requests
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

sys.path.append('./CLIP')
sys.path.append('./guided-diffusion')
sys.path.append('./improved-diffusion')

import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults

In [None]:
# Load the evaluation set. The cells below index it as
# test_summaries[key]['caption'] and test_summaries[key]['summary'].
import json

# Name the encoding explicitly: the default used by open() depends on the platform
# locale and is not UTF-8 everywhere, while JSON files are normally UTF-8.
with open('test_summaries.json', encoding='utf-8') as json_file:
    test_summaries = json.load(json_file)

In [16]:
# -*- coding: utf-8 -*-
import re

# Pattern fragments used by split_into_sentences. They are raw strings so that regex
# escapes such as \s are not treated as (invalid) Python string escapes.
alphabets = r"([A-Za-z])"
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
starters = r"(Mr|Mrs|Ms|Dr|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
acronyms = r"([A-Z][.][A-Z][.](?:[A-Z][.])?)"
websites = r"[.](com|net|org|io|gov)"


def split_into_sentences(text):
    """Split ``text`` into a list of sentences with a rule-based heuristic.

    Periods that do not end a sentence (titles, initials, acronyms, company suffixes,
    web domains, "Ph.D.") are temporarily replaced by a ``<prd>`` marker. Every
    remaining ``.``, ``?`` and ``!`` is followed by a ``<stop>`` marker, and the text
    is split on that marker.

    Args:
        text: the string to split; newlines are treated as spaces.

    Returns:
        A list of stripped sentences. Text after the last terminator is returned as a
        final sentence, so input without any ``.``/``?``/``!`` yields one sentence.
        Empty or whitespace-only input yields ``[]``.
    """
    text = " " + text + "  "
    text = text.replace("\n", " ")

    # Protect periods that are part of an abbreviation rather than a sentence end.
    text = re.sub(prefixes, r"\1<prd>", text)
    text = re.sub(websites, r"<prd>\1", text)
    text = text.replace("Ph.D.", "Ph<prd>D<prd>")
    text = re.sub(r"\s" + alphabets + "[.] ", r" \1<prd> ", text)
    # An acronym or company suffix followed by a typical sentence opener does end a sentence.
    text = re.sub(acronyms + " " + starters, r"\1<stop> \2", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]", r"\1<prd>\2<prd>\3<prd>", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]", r"\1<prd>\2<prd>", text)
    text = re.sub(" " + suffixes + "[.] " + starters, r" \1<stop> \2", text)
    text = re.sub(" " + suffixes + "[.]", r" \1<prd>", text)
    text = re.sub(" " + alphabets + "[.]", r" \1<prd>", text)

    # Move terminators outside closing quotes so the quote stays with its sentence.
    text = text.replace(".”", "”.")
    text = text.replace(".\"", "\".")
    text = text.replace("!\"", "\"!")
    text = text.replace("?\"", "\"?")

    text = text.replace(".", ".<stop>")
    text = text.replace("?", "?<stop>")
    text = text.replace("!", "!<stop>")
    text = text.replace("<prd>", ".")

    sentences = [s.strip() for s in text.split("<stop>")]
    # Only drop the last piece when it is empty. Dropping it unconditionally loses the
    # final sentence of any text that does not end with a terminator.
    if sentences and not sentences[-1]:
        sentences = sentences[:-1]
    return sentences


# Define necessary functions

def fetch(url_or_path, timeout=30):
    """Open an image source for binary reading.

    Args:
        url_or_path: an ``http://``/``https://`` URL, or a local file path.
        timeout: seconds to wait for the HTTP server; only used for URLs.

    Returns:
        For URLs, an in-memory ``io.BytesIO`` holding the downloaded body. For paths,
        a file object opened in ``'rb'`` mode, which the caller should close.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        # requests waits indefinitely unless a timeout is given, which would stall
        # the whole run on an unresponsive host.
        r = requests.get(url_or_path, timeout=timeout)
        r.raise_for_status()
        return io.BytesIO(r.content)
    return open(url_or_path, 'rb')


def parse_prompt(prompt):
    """Split a prompt of the form ``"text"`` or ``"text:weight"`` into its parts.

    Args:
        prompt: prompt text or an http(s) URL, optionally followed by ``:<number>``.

    Returns:
        A ``(text, weight)`` tuple. ``weight`` is a float and defaults to 1.0. If the
        part after the last colon is not a number (e.g. ``"Breaking: man bites dog"``
        or a URL with a port), the whole prompt is kept as text with weight 1.0
        instead of raising ``ValueError``.

    Note:
        A prompt that legitimately ends in ``:<number>`` (e.g. a score "3:2") is
        still read as text plus weight; the format cannot tell the two apart.
    """
    if prompt.startswith(('http://', 'https://')):
        # Split at most two colons from the right, then re-attach the colon that
        # belongs to the URL scheme so only a trailing ":weight" stays separate.
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)

    if len(vals) == 1:
        return vals[0], 1.0
    try:
        return vals[0], float(vals[1])
    except ValueError:
        # The colon was part of the text itself, not a weight separator.
        return prompt, 1.0


class MakeCutouts(nn.Module):
    """Take ``cutn`` random square crops of a batch and pool each to ``cut_size``.

    The crop edge is drawn between ``min(width, height, cut_size)`` and
    ``min(width, height)``; ``cut_pow`` is the exponent applied to the uniform sample
    that picks the edge length.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size, self.cutn, self.cut_pow = cut_size, cutn, cut_pow

    def forward(self, input):
        """Return all crops concatenated along dim 0 (``cutn`` crops per call).

        ``input`` is indexed as ``[:, :, y, x]``; dims 2 and 3 are read as height
        and width.
        """
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        pooled = []
        for _ in range(self.cutn):
            # Draw order (edge, x offset, y offset) is kept so seeded runs reproduce.
            edge = int(torch.rand([])**self.cut_pow * (largest - smallest) + smallest)
            left = torch.randint(0, width - edge + 1, ())
            top = torch.randint(0, height - edge + 1, ())
            crop = input[:, :, top:top + edge, left:left + edge]
            pooled.append(F.adaptive_avg_pool2d(crop, self.cut_size))
        return torch.cat(pooled)


def spherical_dist_loss(x, y):
    """Distance between ``x`` and ``y`` after L2-normalising both along the last dim.

    Computes ``2 * arcsin(||x_hat - y_hat|| / 2) ** 2`` over the last dimension;
    the other dimensions broadcast as usual.
    """
    unit_x = F.normalize(x, dim=-1)
    unit_y = F.normalize(y, dim=-1)
    half_chord = (unit_x - unit_y).norm(dim=-1).div(2)
    return half_chord.arcsin().pow(2).mul(2)


def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    Squared differences to the right and lower neighbours are summed per pixel and
    averaged over dims 1-3, giving one value per batch element.
    """
    # Replicate the last column/row so every pixel has a right and a lower neighbour.
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    anchor = padded[..., :-1, :-1]
    dx = padded[..., :-1, 1:] - anchor
    dy = padded[..., 1:, :-1] - anchor
    return (dx**2 + dy**2).mean([1, 2, 3])


def range_loss(input):
    """Mean squared amount by which values fall outside [-1, 1], per batch element."""
    overshoot = input - input.clamp(-1, 1)
    return overshoot.pow(2).mean([1, 2, 3])

NameError: name 'nn' is not defined

In [None]:
def gen_image(save_path, prompt, model_dir, timestep_respacing='250'):
    """Generate CLIP-guided diffusion samples for one text prompt and save them as PNGs.

    Args:
        save_path: output path prefix. Every image is written to
            ``{save_path}_{batch}_{step}_{index}.png``; the parent directory is
            created if it does not exist.
        prompt: text prompt, optionally suffixed with ``:weight`` (see parse_prompt).
        model_dir: path of the diffusion checkpoint passed to ``torch.load``.
        timestep_respacing: value stored under ``'timestep_respacing'`` in the model
            config, e.g. ``'250'`` or ``'ddim500'``. A value starting with ``'ddim'``
            selects the DDIM sampling loop. Defaults to the previously hard-coded
            ``'250'``.

    Raises:
        RuntimeError: if the prompt weights sum to (almost) zero.

    Note:
        The diffusion model, CLIP and LPIPS are loaded again on every call.
    """
    # Model settings

    model_config = model_and_diffusion_defaults()
    model_config.update({
        #'attention_resolutions': '32, 16, 8',
        'class_cond': False,
        'diffusion_steps': 4000,
        'rescale_timesteps': True,
        'timestep_respacing': timestep_respacing,
        'image_size': 64,
        'learn_sigma': True,
        'noise_schedule': 'linear',
        'num_channels': 128,
        #'num_head_channels': 64,
        'num_res_blocks': 3,
        #'resblock_updown': True,
        'use_checkpoint': False,
        'use_fp16': True,
        'use_scale_shift_norm': True,
    })

    # Load models

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    model, diffusion = create_model_and_diffusion(**model_config)
    model.load_state_dict(torch.load(model_dir, map_location='cpu'))
    model.requires_grad_(False).eval().to(device)
    if model_config['use_fp16']:
        model.convert_to_fp16()

    clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
    clip_size = clip_model.visual.input_resolution
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                    std=[0.26862954, 0.26130258, 0.27577711])
    lpips_model = lpips.LPIPS(net='vgg').to(device)

    # Settings for this run

    prompts = [prompt]
    image_prompts = []
    batch_size = 16
    clip_guidance_scale = 1000  # Controls how much the image should look like the prompt.
    tv_scale = 150              # Controls the smoothness of the final output.
    range_scale = 50            # Controls how far out of range RGB values are allowed to be.
    cutn = 4
    n_batches = 1
    init_image = None   # This can be an URL or Colab local path and must be in quotes.
    skip_timesteps = 0  # This needs to be between approx. 200 and 500 when using an init image.
                        # Higher values make the output look more like the init.
    init_scale = 0      # This enhances the effect of the init image, a good value is 1000.
    seed = 0

    # PIL does not create missing directories, so make sure the target folder exists.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    def do_run():
        if seed is not None:
            torch.manual_seed(seed)

        make_cutouts = MakeCutouts(clip_size, cutn)
        side_x = side_y = model_config['image_size']

        target_embeds, weights = [], []

        # One CLIP text embedding and one weight per text prompt.
        for text_prompt in prompts:
            txt, weight = parse_prompt(text_prompt)
            target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
            weights.append(weight)

        # Image prompts contribute cutn embeddings each, sharing the prompt's weight.
        for image_prompt in image_prompts:
            path, weight = parse_prompt(image_prompt)
            img = Image.open(fetch(path)).convert('RGB')
            img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
            batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
            embed = clip_model.encode_image(normalize(batch)).float()
            target_embeds.append(embed)
            weights.extend([weight / cutn] * cutn)

        target_embeds = torch.cat(target_embeds)
        weights = torch.tensor(weights, device=device)
        if weights.sum().abs() < 1e-3:
            raise RuntimeError('The weights must not sum to 0.')
        weights /= weights.sum().abs()

        init = None
        if init_image is not None:
            init = Image.open(fetch(init_image)).convert('RGB')
            init = init.resize((side_x, side_y), Image.LANCZOS)
            init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)

        # Index of the timestep being sampled; cond_fn reads it through the closure
        # and the sampling loop below decrements it after every step.
        cur_t = None

        def cond_fn(x, t, out, y=None):
            n = x.shape[0]
            fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
            # Blend the predicted clean image with the current noisy sample.
            x_in = out['pred_xstart'] * fac + x * (1 - fac)
            clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
            image_embeds = clip_model.encode_image(clip_in).float()
            dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
            dists = dists.view([cutn, n, -1])
            losses = dists.mul(weights).sum(2).mean(0)
            tv_losses = tv_loss(x_in)
            range_losses = range_loss(out['pred_xstart'])
            loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
            if init is not None and init_scale:
                init_losses = lpips_model(x_in, init)
                loss = loss + init_losses.sum() * init_scale
            # The sampler expects a gradient to follow, hence the negated loss gradient.
            return -torch.autograd.grad(loss, x)[0]

        if model_config['timestep_respacing'].startswith('ddim'):
            sample_fn = diffusion.ddim_sample_loop_progressive
        else:
            sample_fn = diffusion.p_sample_loop_progressive

        for i in range(n_batches):
            cur_t = diffusion.num_timesteps - skip_timesteps - 1

            samples = sample_fn(
                model,
                (batch_size, 3, side_y, side_x),
                clip_denoised=False,
                model_kwargs={},
                cond_fn=cond_fn,
                progress=True,
                skip_timesteps=skip_timesteps,
                init_image=init,
                randomize_class=True,
                cond_fn_with_grad=True,
            )

            for j, sample in enumerate(samples):
                cur_t -= 1
                # Save every 100th step and the final step.
                if j % 100 == 0 or cur_t == -1:
                    print()
                    for k, image in enumerate(sample['pred_xstart']):
                        out_file = f'{save_path}_{i}_{j}_{k}.png'
                        TF.to_pil_image(image.add(1).div(2).clamp(0, 1)).save(out_file)
                        tqdm.write(f'Batch {i}, step {j}, output {k}:')
                        # Show the file that was just written.
                        display.display(display.Image(out_file))

    gc.collect()
    do_run()




        

        

In [None]:
# For every test item, generate images with the fine-tuned and the base checkpoint:
# once from the caption and once from the first sentence of the summary.
finetuned = 'final_model/final/final.pt'
base = './imagenet64_uncond_100M_1500K.pt'

# Make sure the output folders exist before any image is written into them.
for out_dir in ('./fake/caption/ft', './fake/caption/base',
                './fake/summary/ft', './fake/summary/base'):
    os.makedirs(out_dir, exist_ok=True)

for key in test_summaries:
    caption = test_summaries[key]['caption']
    print('caption: ', caption)
    # gen_image takes the output path prefix as its first argument.
    gen_image(f'./fake/caption/ft/cf{key}', caption, finetuned)
    gen_image(f'./fake/caption/base/cb{key}', caption, base)

    summary = test_summaries[key]['summary']
    # Use the first sentence only; fall back to the whole summary if none is found.
    sentences = split_into_sentences(summary)
    first_sentence = sentences[0] if sentences else summary
    print('summ: ', first_sentence)
    gen_image(f'./fake/summary/ft/sf{key}', first_sentence, finetuned)
    gen_image(f'./fake/summary/base/sb{key}', first_sentence, base)

    print('\n')
    


In [None]:
#DOWNLOAD PRETRAINED WEIGHTS
#!curl -OL "https://openaipublic.blob.core.windows.net/diffusion/march-2021/imagenet64_uncond_100M_1500K.pt"

In [None]:
# For every test item, generate images with the base and the fine-tuned checkpoint:
# once from the caption (first 50 words) and once from the first sentence of the
# summary (first 75 words). Items that already have output are skipped, so the cell
# can be re-run after an interruption.
finetuned = 'final_model/final/final.pt'
base = './imagenet64_uncond_100M_1500K.pt'

# Make sure the output folders exist before any image is written into them.
for out_dir in ('./fake/caption/base', './fake/caption/ft',
                './fake/summary/ft', './fake/summary/base'):
    os.makedirs(out_dir, exist_ok=True)

for key in test_summaries:
    caption = test_summaries[key]['caption']

    # reduce caption to 50 words
    caption = ' '.join(caption.split(' ')[:50])
    print('caption: ', caption)

    save_path = f'./fake/caption/base/cb{key}'
    # gen_image appends "_<batch>_<step>_<index>.png" to the prefix, so any file that
    # matches "<prefix>_*.png" means this key has been processed before. The
    # underscore keeps key "1" from matching the files of key "10".
    if glob.glob(glob.escape(save_path) + '_*.png'):
        print('already exists')
        continue

    # The fourth argument is the sampler step spec (presumably gen_image's
    # timestep_respacing): 'ddim500' for the base model, '500' for the fine-tuned one.
    # NOTE(review): gen_image must accept this fourth positional argument.
    gen_image(save_path, caption, base, 'ddim500')
    save_path = f'./fake/caption/ft/cf{key}'
    gen_image(save_path, caption, finetuned, '500')

    summary = test_summaries[key]['summary']
    # reduce summary to 75 words
    summary = ' '.join(summary.split()[:75])
    # Use the first sentence only; fall back to the whole summary if none is found.
    sentences = split_into_sentences(summary)
    first_sentence = sentences[0] if sentences else summary

    # strip sentence of all non alphanumeric characters
    first_sentence = re.sub(r'[^\w\s]', '', first_sentence)
    print('summ: ', first_sentence)
    save_path = f'./fake/summary/ft/sf{key}'
    gen_image(save_path, first_sentence, finetuned, '500')
    save_path = f'./fake/summary/base/sb{key}'
    gen_image(save_path, first_sentence, base, 'ddim500')

    print('\n')
    


In [None]:
# Command-line flag groups for the guided-diffusion scripts. They are plain Python
# strings; `!` shell lines in later cells expand them as $MODEL_FLAGS, $TRAIN_FLAGS, ...
MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True"
# NOTE(review): no command visible in this notebook expands $DIFFUSION_FLAGS, and
# gen_image configures noise_schedule 'linear' rather than 'cosine' - confirm intent.
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine"
TRAIN_FLAGS="--lr 1e-4 --batch_size 32"
# Presumably the directory the training script logs and checkpoints to - TODO confirm.
%env OPENAI_LOGDIR = ./final_model

In [None]:
# Train: resume from ./final_model/final/final.pt using the images in train/data.
# Set --data_dir to the folder that contains all of your training images (*.jpg files).
!mpiexec -n 1 python guided-diffusion/scripts/image_train.py --resume_checkpoint ./final_model/final/final.pt --data_dir train/data --image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True --lr 1e-4 --batch_size 32

In [None]:
#Train MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine"
TRAIN_FLAGS="--lr 1e-4 --batch_size 32"
%env OPENAI_LOGDIR = ./ckpt_model2

#SET DATA DIR TO FOLDER OF ALL OF YOUR DATA WHICH IS *.JPG FILES
!mpiexec -n 1 python guided-diffusion/scripts/image_train.py --resume_checkpoint ./ckpt_model/final0.pt --data_dir data2 $MODEL_FLAGS $TRAIN_FLAGS

In [None]:
# Sample from the fine-tuned model.
# Change --model_path to the path of your trained checkpoint (probably inside ckpt_model).
# $MODEL_FLAGS must still hold the fine-tuned model's flags when this cell runs.
SAMPLE_FLAGS="--batch_size 4 --num_samples 1024 --timestep_respacing 250"
!python guided-diffusion/scripts/image_sample.py --model_path ./ckpt_model/final0.pt $SAMPLE_FLAGS $MODEL_FLAGS

In [None]:
# Download the base 64x64 diffusion model and its classifier into the current directory.
!curl -OL "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/64x64_classifier.pt"
!curl -OL "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/64x64_diffusion.pt"

In [None]:

# Sample from the base (class-conditional) 64x64 model.
SAMPLE_FLAGS="--batch_size 4 --num_samples 128 --timestep_respacing 250"


# Gotcha: this rebinds MODEL_FLAGS to the base model's architecture. Cells that expand
# $MODEL_FLAGS for the fine-tuned model will pick up these values if run afterwards.
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --dropout 0.1 --image_size 64 --learn_sigma True --noise_schedule cosine --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --resblock_updown True --use_new_attention_order True --use_fp16 True --use_scale_shift_norm True"
# --classifier_scale 0.0 presumably disables classifier guidance - TODO confirm.
!python guided-diffusion/scripts/classifier_sample.py $MODEL_FLAGS --classifier_scale 0.0 --classifier_path ./64x64_classifier.pt --classifier_depth 4 --model_path ./64x64_diffusion.pt $SAMPLE_FLAGS

# CLIP GUIDED DIFFUSION (WORK IN PROGRESS... )

In [None]:
# Define necessary functions

def fetch(url_or_path, timeout=30):
    """Open an image source for binary reading.

    Args:
        url_or_path: an ``http://``/``https://`` URL, or a local file path.
        timeout: seconds to wait for the HTTP server; only used for URLs.

    Returns:
        For URLs, an in-memory ``io.BytesIO`` holding the downloaded body. For paths,
        a file object opened in ``'rb'`` mode, which the caller should close.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        # requests waits indefinitely unless a timeout is given, which would stall
        # the whole run on an unresponsive host.
        r = requests.get(url_or_path, timeout=timeout)
        r.raise_for_status()
        return io.BytesIO(r.content)
    return open(url_or_path, 'rb')


def parse_prompt(prompt):
    """Split a prompt of the form ``"text"`` or ``"text:weight"`` into its parts.

    Args:
        prompt: prompt text or an http(s) URL, optionally followed by ``:<number>``.

    Returns:
        A ``(text, weight)`` tuple. ``weight`` is a float and defaults to 1.0. If the
        part after the last colon is not a number (e.g. ``"Breaking: man bites dog"``
        or a URL with a port), the whole prompt is kept as text with weight 1.0
        instead of raising ``ValueError``.

    Note:
        A prompt that legitimately ends in ``:<number>`` (e.g. a score "3:2") is
        still read as text plus weight; the format cannot tell the two apart.
    """
    if prompt.startswith(('http://', 'https://')):
        # Split at most two colons from the right, then re-attach the colon that
        # belongs to the URL scheme so only a trailing ":weight" stays separate.
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)

    if len(vals) == 1:
        return vals[0], 1.0
    try:
        return vals[0], float(vals[1])
    except ValueError:
        # The colon was part of the text itself, not a weight separator.
        return prompt, 1.0


class MakeCutouts(nn.Module):
    """Take ``cutn`` random square crops of a batch and pool each to ``cut_size``.

    The crop edge is drawn between ``min(width, height, cut_size)`` and
    ``min(width, height)``; ``cut_pow`` is the exponent applied to the uniform sample
    that picks the edge length.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size, self.cutn, self.cut_pow = cut_size, cutn, cut_pow

    def forward(self, input):
        """Return all crops concatenated along dim 0 (``cutn`` crops per call).

        ``input`` is indexed as ``[:, :, y, x]``; dims 2 and 3 are read as height
        and width.
        """
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        pooled = []
        for _ in range(self.cutn):
            # Draw order (edge, x offset, y offset) is kept so seeded runs reproduce.
            edge = int(torch.rand([])**self.cut_pow * (largest - smallest) + smallest)
            left = torch.randint(0, width - edge + 1, ())
            top = torch.randint(0, height - edge + 1, ())
            crop = input[:, :, top:top + edge, left:left + edge]
            pooled.append(F.adaptive_avg_pool2d(crop, self.cut_size))
        return torch.cat(pooled)


def spherical_dist_loss(x, y):
    """Distance between ``x`` and ``y`` after L2-normalising both along the last dim.

    Computes ``2 * arcsin(||x_hat - y_hat|| / 2) ** 2`` over the last dimension;
    the other dimensions broadcast as usual.
    """
    unit_x = F.normalize(x, dim=-1)
    unit_y = F.normalize(y, dim=-1)
    half_chord = (unit_x - unit_y).norm(dim=-1).div(2)
    return half_chord.arcsin().pow(2).mul(2)


def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    Squared differences to the right and lower neighbours are summed per pixel and
    averaged over dims 1-3, giving one value per batch element.
    """
    # Replicate the last column/row so every pixel has a right and a lower neighbour.
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    anchor = padded[..., :-1, :-1]
    dx = padded[..., :-1, 1:] - anchor
    dy = padded[..., 1:, :-1] - anchor
    return (dx**2 + dy**2).mean([1, 2, 3])


def range_loss(input):
    """Mean squared amount by which values fall outside [-1, 1], per batch element."""
    overshoot = input - input.clamp(-1, 1)
    return overshoot.pow(2).mean([1, 2, 3])

In [None]:
# Model settings: library defaults with the overrides for the 64x64 unconditional model.
# Keys kept for reference but not overridden: 'attention_resolutions' ('32, 16, 8'),
# 'num_head_channels' (64), 'resblock_updown' (True).
model_config = dict(
    model_and_diffusion_defaults(),
    class_cond=False,
    diffusion_steps=4000,
    rescale_timesteps=True,
    timestep_respacing='250',  # Modify this value to decrease the number of timesteps.
    image_size=64,
    learn_sigma=True,
    noise_schedule='linear',
    num_channels=128,
    num_res_blocks=3,
    use_checkpoint=False,
    use_fp16=True,
    use_scale_shift_norm=True,
)

In [None]:
# Load the diffusion model, CLIP and LPIPS onto the available device.

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

model, diffusion = create_model_and_diffusion(**model_config)
model.load_state_dict(torch.load('final_model/final/final.pt', map_location='cpu'))
# Inference only: freeze the weights, switch to eval mode, then move to the device.
model.requires_grad_(False)
model.eval()
model.to(device)
if model_config['use_fp16']:
    model.convert_to_fp16()

clip_model = clip.load('ViT-B/16', jit=False)[0]
clip_model = clip_model.eval().requires_grad_(False).to(device)
clip_size = clip_model.visual.input_resolution
# Per-channel mean/std applied to images before they are passed to CLIP.
normalize = transforms.Normalize(
    mean=[0.48145466, 0.4578275, 0.40821073],
    std=[0.26862954, 0.26130258, 0.27577711],
)
lpips_model = lpips.LPIPS(net='vgg').to(device)


## Settings for this run:

In [None]:
# What to generate.
prompts = ['CIA black site detainee served as training prop to teach interrogators torture techniques']
image_prompts = []
seed = 0

# How much to generate.
batch_size = 16
n_batches = 1
cutn = 4

# Loss weights.
clip_guidance_scale = 1000  # How strongly the image is pulled towards the prompt.
tv_scale = 150              # Smoothness of the final output.
range_scale = 50            # Penalty on RGB values that leave the valid range.

# Optional init image.
init_image = None   # A URL or a local (Colab) path, given as a quoted string.
skip_timesteps = 0  # Use roughly 200-500 with an init image; higher values keep the
                    # output closer to the init.
init_scale = 0      # Strengthens the effect of the init image; 1000 is a good value.

### Actually do the run...

In [None]:
def do_run():
    """Run CLIP-guided sampling with the notebook-level settings and save progress images.

    Reads the globals set in earlier cells (prompts, image_prompts, batch_size, the
    loss scales, cutn, n_batches, init_image, skip_timesteps, init_scale, seed) and
    the loaded model, diffusion, clip_model, clip_size, normalize, lpips_model,
    device and model_config. Writes ``progress_<index>.png`` files to the working
    directory and displays them.

    Raises:
        RuntimeError: if the prompt weights sum to (almost) zero.
    """
    if seed is not None:
        torch.manual_seed(seed)

    make_cutouts = MakeCutouts(clip_size, cutn)
    side_x = side_y = model_config['image_size']

    target_embeds, weights = [], []

    # One CLIP text embedding and one weight per text prompt.
    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
        weights.append(weight)

    # Image prompts contribute cutn embeddings each, sharing the prompt's weight.
    for prompt in image_prompts:
        path, weight = parse_prompt(prompt)
        img = Image.open(fetch(path)).convert('RGB')
        img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = clip_model.encode_image(normalize(batch)).float()
        target_embeds.append(embed)
        weights.extend([weight / cutn] * cutn)

    target_embeds = torch.cat(target_embeds)
    weights = torch.tensor(weights, device=device)
    if weights.sum().abs() < 1e-3:
        raise RuntimeError('The weights must not sum to 0.')
    # Normalise so the weights sum to +1 or -1.
    weights /= weights.sum().abs()

    init = None
    if init_image is not None:
        init = Image.open(fetch(init_image)).convert('RGB')
        init = init.resize((side_x, side_y), Image.LANCZOS)
        # to_tensor yields values in [0, 1]; rescale to [-1, 1].
        init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)

    # Index of the timestep being sampled; cond_fn reads it through the closure and
    # the sampling loop below decrements it after every step.
    cur_t = None

    def cond_fn(x, t, out, y=None):
        """Return the negated gradient of the combined guidance loss with respect to x."""
        n = x.shape[0]
        fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
        # Blend the predicted clean image with the current noisy sample.
        x_in = out['pred_xstart'] * fac + x * (1 - fac)
        clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
        image_embeds = clip_model.encode_image(clip_in).float()
        dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
        # Regroup the distances as (cutout, batch element, target).
        dists = dists.view([cutn, n, -1])
        losses = dists.mul(weights).sum(2).mean(0)
        tv_losses = tv_loss(x_in)
        range_losses = range_loss(out['pred_xstart'])
        loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
        if init is not None and init_scale:
            init_losses = lpips_model(x_in, init)
            loss = loss + init_losses.sum() * init_scale
        return -torch.autograd.grad(loss, x)[0]

    # A 'ddim...' respacing value selects the DDIM sampling loop.
    if model_config['timestep_respacing'].startswith('ddim'):
        sample_fn = diffusion.ddim_sample_loop_progressive
    else:
        sample_fn = diffusion.p_sample_loop_progressive

    for i in range(n_batches):
        cur_t = diffusion.num_timesteps - skip_timesteps - 1

        samples = sample_fn(
            model,
            (batch_size, 3, side_y, side_x),
            clip_denoised=False,
            model_kwargs={},
            cond_fn=cond_fn,
            progress=True,
            skip_timesteps=skip_timesteps,
            init_image=init,
            randomize_class=True,
            cond_fn_with_grad=True,
        )

        for j, sample in enumerate(samples):
            cur_t -= 1
            # Save and show every 100th step and the final step.
            if j % 100 == 0 or cur_t == -1:
                print()
                for k, image in enumerate(sample['pred_xstart']):
                    # The name has no step number, so each save overwrites the previous one.
                    filename = f'progress_{i * batch_size + k:05}.png'
                    TF.to_pil_image(image.add(1).div(2).clamp(0, 1)).save(filename)
                    tqdm.write(f'Batch {i}, step {j}, output {k}:')
                    display.display(display.Image(filename))

# Collect garbage left over from earlier cells before starting the run.
gc.collect()
do_run()

In [None]:
# Recursively zip everything under fake/ into fake.zip.
!zip -r fake.zip fake/