In [2]:
import PIL
import numpy as np
import torch
import torchvision
from torchvision.transforms.v2 import Compose, GaussianBlur, RandomEqualize, RandomSolarize, RandomApply
import wandb
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from Dataset.AerialDataset import AerialDataset
from tasks.SRDiffTrainer import SRDiffTrainer
from models.SRDIFFBuilder import SRDiffBuilder
from utils.model_utils import load_model

#Data
lr_size = 64
hr_size = 256
batch_size = 20
dataset_dir = 'C:\\Users\\adrianperera\\Desktop\\dataset_tfg'

transforms = Compose(
    [RandomApply(transforms= [GaussianBlur(7)], p = 0.5),
    RandomEqualize()]
)

dataset = AerialDataset(dataset_dir, lr_size, hr_size, data_augmentation = None, aux_sat_prob= 0.4, sat_dataset_path= "C:\\Users\\adrianperera\\Desktop\\dataset_tfg\\satelite_dataset\\64_256")
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2], generator=torch.Generator().manual_seed(420))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    
device = torch.device(0)

In [3]:
model_builder = SRDiffBuilder()
model = model_builder.set_standart().build()
model.to(device)

GaussianDiffusion(
  (denoise_fn): Unet(
    (cond_proj): ConvTranspose2d(96, 64, kernel_size=(8, 8), stride=(4, 4), padding=(2, 2))
    (time_pos_emb): SinusoidalPosEmb()
    (mlp): Sequential(
      (0): Linear(in_features=64, out_features=256, bias=True)
      (1): Mish()
      (2): Linear(in_features=256, out_features=64, bias=True)
    )
    (downs): ModuleList(
      (0): ModuleList(
        (0): ResnetBlock(
          (mlp): Sequential(
            (0): Mish()
            (1): Linear(in_features=64, out_features=64, bias=True)
          )
          (block1): Block(
            (block): Sequential(
              (0): ReflectionPad2d((1, 1, 1, 1))
              (1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
              (2): Mish()
            )
          )
          (block2): Block(
            (block): Sequential(
              (0): ReflectionPad2d((1, 1, 1, 1))
              (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
              (2): Mish()
            )
 

In [4]:
model_name = "SRDIFF standart aux SSIM"
hyperparams = {
    "lr":0.002,
    "epochs":100,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SRDiff",
    "batch_size" : batch_size,
    "ddp": False,
    "grad_acum": 1,
    "use_rrdb":True,
    "fix_rrdb":False,
    "aux_l1_loss":True,
    "aux_perceptual_loss":False,
    "aux_ssim_loss":False
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

project_name = "SR model benchmarking"
run_name = model_name
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)
torch.backends.cudnn.benchmark = True

trainer = SRDiffTrainer(metrics_used=("ssim", "psnr"), model_name=model_name, device=device, 
                        use_rrdb=hyperparams["use_rrdb"], fix_rrdb=hyperparams["fix_rrdb"], aux_ssim_loss=hyperparams["aux_ssim_loss"],
                        aux_l1_loss=hyperparams["aux_l1_loss"], aux_perceptual_loss=hyperparams["aux_perceptual_loss"],
                        grad_acum=hyperparams["grad_acum"])

trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    torch.cuda.empty_cache()
    train_loss = trainer.train(train_dataloader, step)
    torch.cuda.empty_cache()
    if step % 10 == 0:
        trainer.save_model("saved models\\SRDiff\\large")
    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33madrianpereramoreno[0m ([33mladoscuro[0m). Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 40/40 [00:14<00:00,  2.74batch/s, aux_l1=1.19, aux_ssim=0.936, q=0.801]
100%|██████████| 118/118 [01:20<00:00,  1.46batch/s, aux_l1=1.16, aux_ssim=0.932, q=0.8]  
100%|██████████| 40/40 [00:10<00:00,  3.92batch/s, aux_l1=1.15, aux_ssim=0.931, q=0.801]
100%|██████████| 118/118 [01:10<00:00,  1.67batch/s, aux_l1=1.11, aux_ssim=0.931, q=0.801]
100%|██████████| 40/40 [00:10<00:00,  3.91batch/s, aux_l1=1.09, aux_ssim=0.94, q=0.8]   
100%|██████████| 118/118 [01:10<00:00,  1.66batch/s, aux_l1=1.11, aux_ssim=0.932, q=0.8]  
100%|██████████| 40/40 [00:10<00:00,  3.89batch/s, aux_l1=1.05, aux_ssim=0.936, q=0.802]
100%|██████████| 118/118 [01:10<00:00,  1.66batch/s, aux_l1=1.09, aux_ssim=0.929, q=0.8]  
100%|██████████| 40/40 [00:10<00:00,  3.87batch/s, aux_l1=1.14, aux_ssim=0.931, q=0.802]
100%|██████████| 118/118 [01:11<00:00,  1.66batch/s, aux_l1=1.13, aux_ssim=0.932, q=0.801]
100%|██████████| 40/40 [00:10<00:00,  3.87batch/s, aux_l1=1.18, aux_ssim=0.925, q=0.802]
100%|██████

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,▆▂▇▂▆▄▆▆▅▄▃▆▆▄▃▄▇▃▂▄▇▃▂█▅▁▃▁▆▂▄▃▃▅▄▅▅▂▅▄
validation_loss,█▃▇▂▁▇▅█▆█▆▂▄▂▄▅▂▅▅▆▄▂▂▄▃▆▂▅▇▁█▄▅▅▂▂▂▂▃▅

0,1
psnr,7.43433
ssim,0.0186
train_loss,2.85192
validation_loss,2.85112


In [9]:
from models.SR3Builder import SR3Builder
from tasks.SR3Trainer import SR3Trainer


lr_size = 64
hr_size = 256
batch_size = 128
dataset_dir = 'C:\\Users\\adrianperera\\Desktop\\dataset_tfg'

dataset = AerialDataset(dataset_dir, lr_size, hr_size, data_augmentation = None, aux_sat_prob= 0.5, sat_dataset_path= "C:\\Users\\adrianperera\\Desktop\\dataset_tfg\\satelite_dataset\\64_256")
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2], generator=torch.Generator().manual_seed(420))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    
device = torch.device(0)

model_builder = SR3Builder()
model_builder = model_builder.set_standart()
model_builder = model_builder.set_sample_steps(1000)
model_builder = model_builder.set_losstype("l1")
model = model_builder.build()
model.to(device)

hyperparams = {
    "lr":0.0002,
    "epochs":100,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SR3",
    "grad_acum": 0,
    "ddp": False,
    "batch_size":batch_size
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

project_name = "SR model benchmarking"
run_name = "SR3 standart l1, 128Batch, 100k sampling"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3 Standart l1 DA 128b 100k sampling", grad_acum=hyperparams["grad_acum"])
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for epoch in range(hyperparams["epochs"]):  
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    
    train_loss = trainer.train(train_dataloader, epoch)
    torch.cuda.empty_cache()
    if epoch % 10 == 0:
        trainer.save_model("saved models\\SR3")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   

100%|██████████| 7/7 [00:05<00:00,  1.29batch/s]
100%|██████████| 19/19 [00:18<00:00,  1.06batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.80batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.54batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.82batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.55batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.84batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.55batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.82batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.54batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.80batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.54batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.79batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.54batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.79batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.54batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.80batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.54batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.80batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.55batch/s]


VBox(children=(Label(value='0.007 MB of 0.042 MB uploaded (0.002 MB deduped)\r'), FloatProgress(value=0.169956…

0,1
psnr,▁
ssim,▁
train_loss,██▇▇▇▇▆▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
validation_loss,██▇▇▇▇▆▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
psnr,3.679
ssim,0.00189
train_loss,0.27829
validation_loss,0.28036
