In [1]:
%%writefile Imagetoimagetranlator.py
import streamlit as st
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms

def transform_image(image):
    """Turn a PIL image into a normalized single-item batch tensor.

    Steps: resize to 256x256, convert to a float tensor via ``ToTensor``,
    then normalize with mean 0.5 / std 0.5 (a one-element mean/std is applied
    to every channel). Finally a leading batch axis of size 1 is prepended,
    so the result can be fed straight into a generator.

    Args:
        image: A PIL image (the caller converts uploads to RGB first).

    Returns:
        A tensor with a leading batch dimension of 1.
    """
    pipeline = transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    normalized = pipeline(image)
    # Prepend the batch axis expected by the networks.
    return torch.unsqueeze(normalized, 0)

@st.cache_resource
def load_models():
    """Load the four CycleGAN networks from local ``.pth`` files onto the CPU.

    The result is cached with ``st.cache_resource``, so the checkpoints are
    read once and shared across Streamlit reruns instead of reloading on
    every interaction. (``st.cache`` is deprecated.)

    Returns:
        Tuple ``(generator_G, generator_F, discriminator_G, discriminator_F)``,
        each being whatever ``torch.load`` deserializes from the matching file.
        The app calls the two generators directly, so those files presumably
        hold whole pickled ``nn.Module`` objects rather than state dicts.
        TODO: confirm this for the discriminator files as well.
    """
    cpu = torch.device('cpu')

    def _load(path):
        # weights_only=False: full pickled modules are rejected by the
        # weights_only=True default introduced in torch 2.6.
        # SECURITY: this unpickles arbitrary objects -- only load trusted files.
        model = torch.load(path, map_location=cpu, weights_only=False)
        if isinstance(model, torch.nn.Module):
            # Inference mode: disables Dropout and uses stored BatchNorm stats, if any.
            model.eval()
        return model

    return tuple(
        _load(path)
        for path in (
            'generator_G.pth',
            'generator_F.pth',
            'discriminator_G.pth',
            'discriminator_F.pth',
        )
    )

def _to_hwc_array(batch):
    """Drop the leading batch axis of a (1, C, H, W) tensor and return an (H, W, C) numpy array."""
    # squeeze(0) only removes the batch axis; a bare squeeze() would also
    # remove a size-1 channel axis and break the transpose below.
    return batch.squeeze(0).cpu().numpy().transpose(1, 2, 0)


def postprocess(array):
    """Convert a generator output array (H, W, C) back into a displayable PIL image.

    Inverts the ``Normalize((0.5,), (0.5,))`` applied in ``transform_image``:
    values are mapped from [-1, 1] to [0, 255] and clipped to that range.
    This assumes the generator emits values in [-1, 1] (e.g. a tanh output);
    TODO confirm. A trailing single-channel axis is dropped so that grayscale
    outputs can also be turned into a PIL image.
    """
    array = array * 0.5 + 0.5
    array = np.clip(array * 255, 0, 255).astype(np.uint8)
    if array.ndim == 3 and array.shape[-1] == 1:
        array = array[..., 0]
    return Image.fromarray(array)


generator_G, generator_F, discriminator_G, discriminator_F = load_models()
st.title('CycleGAN Image to Image Translation')
st.write('Upload an ultrasound image and a breast piece image of chicken for processing with CycleGAN.')
ultrasound_image = st.file_uploader('Upload Ultrasound Image', type=['png', 'jpg', 'jpeg'])
chicken_image = st.file_uploader('Upload Chicken Image', type=['png', 'jpg', 'jpeg'])

if ultrasound_image and chicken_image:
    try:
        ultrasound_pil = Image.open(ultrasound_image).convert('RGB')
        chicken_pil = Image.open(chicken_image).convert('RGB')
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        st.error(f'Could not read the uploaded image: {err}')
        st.stop()

    ultrasound_tensor = transform_image(ultrasound_pil)
    chicken_tensor = transform_image(chicken_pil)
    with torch.no_grad():
        # generator_G maps ultrasound -> chicken; generator_F maps chicken -> ultrasound.
        fake_chicken = _to_hwc_array(generator_G(ultrasound_tensor))
        fake_ultrasound = _to_hwc_array(generator_F(chicken_tensor))

    fake_chicken_image = postprocess(fake_chicken)
    fake_ultrasound_image = postprocess(fake_ultrasound)

    # One row per direction: uploaded input on the left, its translation on the right.
    for source, source_caption, result, result_caption in (
        (ultrasound_pil, 'Uploaded Ultrasound Image',
         fake_chicken_image, 'Generated Chicken Image (from ultrasound)'),
        (chicken_pil, 'Uploaded Chicken Breast Image',
         fake_ultrasound_image, 'Generated Ultrasound Image (from chicken)'),
    ):
        left, right = st.columns(2)
        with left:
            st.image(source, caption=source_caption, use_column_width=True)
        with right:
            st.image(result, caption=result_caption, use_column_width=True)

Overwriting Imagetoimagetranlator.py


In [None]:
!streamlit run Imagetoimagetranlator.py