In [2]:
import PIL
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio
from torchvision import transforms

def get_test_img(file, folder='test pictures'):
    """Load a test image from disk as a tensor.

    Args:
        file: File name of the image inside ``folder``.
        folder: Directory holding the test images, relative to the current
            working directory. Defaults to the original ``'test pictures'``.

    Returns:
        The image converted by ``torchvision.transforms.ToTensor``: a
        (C, H, W) float tensor, scaled to [0, 1] for 8-bit images.
    """
    # Local imports keep this helper independent of global names. `import PIL`
    # alone does not guarantee that `PIL.Image` is loaded, and the global name
    # `transforms` can be rebound by other notebook cells.
    import os
    from PIL import Image
    from torchvision.transforms import ToTensor

    # os.path.join instead of a hard-coded '\\' so the path works on any OS.
    path = os.path.join(folder, file)
    # The context manager closes the file handle. ToTensor reads the pixel
    # data, so the image is fully loaded before the file is closed.
    with Image.open(path) as test_img:
        return ToTensor()(test_img)

In [4]:
import PIL
import numpy as np
import torch
import torchvision
from torchvision.transforms.v2 import Compose, GaussianBlur, RandomEqualize, RandomSolarize, RandomApply
import wandb
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from Dataset.AerialDataset import AerialDataset
from tasks.SRDiffTrainer import SRDiffTrainer
from models.SRDIFFBuilder import SRDiffBuilder
from utils.model_utils import load_model

# Data configuration
lr_size = 64
hr_size = 256
batch_size = 20
dataset_dir = 'E:\\TFG\\dataset_tfg'
sat_dataset_dir = 'E:\\TFG\\dataset_tfg\\satelite_dataset\\64_256'
split_seed = 420  # fixed seed so the train/val/test split is identical across runs

# Augmentation pipeline. It is deliberately NOT named `transforms`, which would
# shadow the `torchvision.transforms` module imported in an earlier cell.
# NOTE: currently unused, because the dataset below is built with
# data_augmentation=None.
train_augmentations = Compose(
    [RandomApply(transforms=[GaussianBlur(7)], p=0.5),
     RandomEqualize()]
)

dataset = AerialDataset(
    dataset_dir,
    lr_size,
    hr_size,
    data_augmentation=None,
    aux_sat_prob=0.4,
    sat_dataset_path=sat_dataset_dir,
)
# 60/20/20 split, made deterministic by the seeded generator.
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [0.6, 0.2, 0.2], generator=torch.Generator().manual_seed(split_seed)
)

# Only training benefits from shuffling. A fixed order for val/test keeps
# evaluation runs reproducible and does not change the aggregate metrics.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device(0)  # integer ordinal -> CUDA device 0

In [6]:
from tqdm import tqdm
import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionUpscalePipeline
import torch
# Zero-shot evaluation of the pretrained Stable Diffusion x4 upscaler on the
# test split. Each batch is upscaled, PSNR/SSIM are accumulated against the HR
# targets, and the aggregate is logged to Weights & Biases.
device = torch.device(0)

# load model and scheduler
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
    model_id, revision="fp16", torch_dtype=torch.float16
)
wandb.login()
hyperparams = {
    "pretrained": True,
    "fine tunning": False,
    "batch_size": batch_size
}
wandb.init(project="Stable diffusion", config=hyperparams, name=f"{model_id} no tuning")
pipeline = pipeline.to(device)

# Build the metrics once and accumulate with `update`. `compute` then aggregates
# over the whole test set, instead of an unweighted mean of per-batch values
# that would over-weight a smaller final batch.
ssim = StructuralSimilarityIndexMeasure().to(device=device)
psnr = PeakSignalNoiseRatio().to(device=device)
# A seeded generator makes the diffusion sampling reproducible across runs.
generator = torch.Generator(device=device).manual_seed(420)

try:
    for batch in tqdm(test_dataloader):
        bicubic = batch["bicubic"]
        # Targets must live on the same device as the metric states.
        hr = batch["hr"].to(device)
        # One prompt per image actually present: the last batch may be
        # smaller than `batch_size`.
        prompt = ["Satelite imagery"] * len(bicubic)
        # NOTE(review): the x4 upscaler multiplies the input resolution by 4.
        # If "bicubic" is already HR-sized, feeding the LR image instead is
        # probably what was intended. Also confirm that `bicubic` is in the
        # value range the pipeline expects for tensor inputs.
        output = pipeline(prompt=prompt, image=bicubic, generator=generator, output_type="np")
        # output_type="np" yields a float array of shape (B, H, W, C) in [0, 1].
        # torchmetrics expects (B, C, H, W) tensors.
        sr = torch.from_numpy(output.images).permute(0, 3, 1, 2).to(device=device, dtype=hr.dtype)
        if sr.shape[-2:] != hr.shape[-2:]:
            # Bring the prediction to the target resolution so the metrics are
            # comparable. Clamp to undo bicubic overshoot.
            sr = torch.nn.functional.interpolate(
                sr, size=tuple(hr.shape[-2:]), mode="bicubic", align_corners=False
            ).clamp(0, 1)
        # NOTE(review): assumes `hr` uses the same [0, 1] range as `sr`.
        # Confirm the dataset's normalization.
        psnr.update(sr, hr)
        ssim.update(sr, hr)

    metrics = {"psnr": psnr.compute().item(), "ssim": ssim.compute().item()}
    wandb.log(metrics)
finally:
    # Close the W&B run even if inference or metric computation fails.
    wandb.finish()