# Stable Diffusion Counterfactual Image generation
Generates counterfactual images using Hugging Face libraries and the open-source Stable Diffusion 3.5 model (large or medium).

In [None]:
# Authenticate this session with the Hugging Face Hub by calling interpreter_login().
# NOTE(review): presumably required because the model repository loaded later
# in the notebook needs an access token — confirm.
from huggingface_hub import interpreter_login
interpreter_login()


## Loading the model
Load the model using Hugging Face's `diffusers` library.

In [None]:
import torch
from diffusers import StableDiffusion3Img2ImgPipeline

# Hub repository of the checkpoint to load. The notebook targets Stable
# Diffusion 3.5, and "stabilityai/stable-diffusion-3-large" does not appear to
# be a published Hub repository, so the 3.5 large checkpoint is used instead.
# "stabilityai/stable-diffusion-3.5-medium" is the smaller variant mentioned
# in the notebook title.
MODEL_ID = "stabilityai/stable-diffusion-3.5-large"

# Build the image-to-image pipeline with bfloat16 weights and move it to the GPU.
pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
).to("cuda")

## Loading the dataset
Load the generated SNLI subset that contains only sentence pairs with a neutral relationship.

In [1]:
# Load the dataset
import json

# Read the JSON Lines file: one JSON object per line.
# The encoding is pinned so parsing does not depend on the platform default,
# and blank lines (e.g. a trailing empty line) are skipped instead of making
# json.loads raise.
with open("./dataset/snli_1.0_train_neutral.jsonl", "r", encoding="utf-8") as f:
    data = [json.loads(line) for line in f if line.strip()]

# Display the first record to inspect the available fields.
data[0]

{'annotator_labels': ['neutral'],
 'captionID': '3416050480.jpg#4',
 'gold_label': 'neutral',
 'pairID': '3416050480.jpg#4r1n',
 'sentence1': 'A person on a horse jumps over a broken down airplane.',
 'sentence1_binary_parse': '( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )',
 'sentence1_parse': '(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))',
 'sentence2': 'A person is training his horse for a competition.',
 'sentence2_binary_parse': '( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) )',
 'sentence2_parse': '(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (VP (VBG training) (NP (PRP$ his) (NN horse)) (PP (IN for) (NP (DT a) (NN competition))))) (. .)))'}

#### Test drive first entry

In [2]:
from dataset import get_url
import requests
import torch
from PIL import Image
from io import BytesIO

# Keep only the part of the captionID before "#" (e.g. "3416050480.jpg").
caption_id: str = data[0]['captionID'].split("#")[0]

# Resolve the image location through the project helper.
# NOTE(review): presumably returns a remote URL when local=False — confirm.
url: str = get_url(caption_id, local=False)

# Use a timeout so a stalled connection cannot hang the notebook, and fail
# loudly on HTTP errors instead of handing an error page to PIL.
response = requests.get(url, timeout=30)
response.raise_for_status()

# Decode the downloaded bytes and normalise to three-channel RGB.
image = Image.open(BytesIO(response.content)).convert("RGB")

# The hypothesis sentence of the pair is used as the text prompt.
prompt = data[0]['sentence2']

In [None]:
from pathlib import Path

# Run the image-to-image pipeline on the source image, conditioned on the
# prompt, and keep the first returned image.
genImage = pipe(
    prompt=prompt,
    image=image,
    num_inference_steps=30,
    guidance_scale=7.5,
).images[0]

# Create the output directory first: saving raises FileNotFoundError on a
# fresh checkout where "output/" does not exist yet.
Path("output").mkdir(parents=True, exist_ok=True)
genImage.save(f"output/{data[0]['captionID']}.png")

#### Evaluating test image

##### CLIP score

In [None]:
from torchmetrics.multimodal import CLIPScore
from torchvision.transforms.functional import pil_to_tensor

clip = CLIPScore(model_name_or_path="openai/clip-vit-base-patch32")

# CLIPScore expects image tensors rather than PIL images, so convert the
# generated image to a uint8 (C, H, W) tensor before scoring it.
gen_image_tensor = pil_to_tensor(genImage)

# Score how well the generated image matches the prompt text.
score = clip(gen_image_tensor, prompt)

# Display the score as the cell output.
score

##### Fréchet inception distance

In [4]:
from fid import compute_fid_between_images
import torchvision.transforms as transforms

# Keep only the part of the second entry's captionID before "#".
caption_id2: str = data[1]['captionID'].split("#")[0]

url: str = get_url(caption_id2, local=False)

# Use a timeout so a stalled connection cannot hang the notebook, and fail
# loudly on HTTP errors instead of handing an error page to PIL.
response = requests.get(url, timeout=30)
response.raise_for_status()

image2 = Image.open(BytesIO(response.content)).convert("RGB")

# Example usage:
# NOTE(review): this compares the source images of entries 0 and 1, not the
# generated image. FID is defined over sets of images, so a score from a
# single pair is presumably only a smoke test of the helper — confirm.
fid_score = compute_fid_between_images(image, image2)
print(f'FID score: {fid_score}')




FID score: 610.7946166992188
