In [1]:
import os

import PIL
import numpy as np
import torch
import torchvision
from torchvision.transforms.v2 import Compose, GaussianBlur, RandomEqualize, RandomSolarize, RandomApply
import wandb
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from Dataset.AerialDataset import AerialDataset
from tasks.SR3Trainer import SR3Trainer
from models.SR3Builder import SR3Builder
from utils.model_utils import load_model

In [13]:
#Data
lr_size = 64
hr_size = 256
batch_size = 128
dataset_dir = 'C:\\Users\\adrianperera\\Desktop\\dataset_tfg'

transforms = Compose(
    [RandomApply(transforms= [GaussianBlur(7)], p = 0.5),
    RandomEqualize()]
)

dataset = AerialDataset(dataset_dir, lr_size, hr_size, data_augmentation = None, aux_sat_prob= 0, sat_dataset_path= "C:\\Users\\adrianperera\\Desktop\\dataset_tfg\\satelite_dataset\\64_256")
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2], generator=torch.Generator().manual_seed(420))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
model_builder = SR3Builder()
model_builder = model_builder.set_standart()
model = model_builder.build()
model.to(device)

GaussianDiffusion(
  (model): UNet(
    (emb): GammaEmbedding(
      (linear1): Linear(
        (linear): Linear(in_features=3, out_features=12, bias=True)
      )
      (silu): SiLU()
      (linear2): Linear(
        (linear): Linear(in_features=12, out_features=12, bias=True)
      )
    )
    (conv1): Conv2d(
      (conv): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): ModuleList(
      (0-2): 3 x WideResNetBlock(
        (gn1): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu1): SiLU()
        (conv1): Conv2d(
          (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (silu2): SiLU()
        (linear1): Linear(
          (linear): Linear(in_features=12, out_features=3, bias=True)
        )
        (gn2): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu3): SiLU()
        (dropout): Dropout(p=0, inplace=False)
 

In [4]:

model_name =  f"SR3 version 3"
lr_size = 64
hr_size = 256
batch_size = 64
dataset_dir = 'C:\\Users\\adrianperera\\Desktop\\dataset_tfg'

transforms = Compose(
    [RandomApply(transforms= [GaussianBlur(7)], p = 0.5),
    RandomEqualize()]
)

dataset = AerialDataset(dataset_dir, lr_size, hr_size, data_augmentation = None, aux_sat_prob= 0.6, sat_dataset_path= "C:\\Users\\adrianperera\\Desktop\\dataset_tfg\\satelite_dataset\\64_256")
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2], generator=torch.Generator().manual_seed(420))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_builder = SR3Builder()
model_builder = model_builder.set_sr3plus()
model_builder = model_builder.set_steps(4000)
model_builder = model_builder.set_sample_steps(200)
model = model_builder.build()
model.to(device)

hyperparams = {
    "lr":0.002,
    "epochs":100,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SR3",
    "batch_size" : batch_size,
    "ddp": False,
    "grad_acum": 0
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])
project_name = "SR model benchmarking DA"
print(model_name)
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=model_name)
torch.backends.cudnn.benchmark = True

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name=model_name)
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    train_loss = trainer.train(train_dataloader, step)
    torch.cuda.empty_cache()
    if step % 10 == 0 and step > 0:
        trainer.save_model("saved models\\SR3+")
    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
     
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()

SR3 version 3


100%|██████████| 13/13 [00:04<00:00,  3.04batch/s]
100%|██████████| 37/37 [00:14<00:00,  2.53batch/s]
100%|██████████| 13/13 [00:04<00:00,  3.10batch/s]
100%|██████████| 37/37 [00:14<00:00,  2.61batch/s]
100%|██████████| 13/13 [00:04<00:00,  3.18batch/s]
100%|██████████| 37/37 [00:14<00:00,  2.61batch/s]
100%|██████████| 13/13 [00:04<00:00,  3.22batch/s]
100%|██████████| 37/37 [00:14<00:00,  2.59batch/s]
100%|██████████| 13/13 [00:04<00:00,  3.19batch/s]
100%|██████████| 37/37 [00:14<00:00,  2.60batch/s]
100%|██████████| 13/13 [00:04<00:00,  3.18batch/s]
100%|██████████| 37/37 [00:14<00:00,  2.60batch/s]
100%|██████████| 13/13 [00:04<00:00,  3.17batch/s]
100%|██████████| 37/37 [00:14<00:00,  2.62batch/s]
100%|██████████| 13/13 [00:04<00:00,  3.21batch/s]
100%|██████████| 37/37 [00:14<00:00,  2.63batch/s]
100%|██████████| 13/13 [00:04<00:00,  3.19batch/s]
100%|██████████| 37/37 [00:14<00:00,  2.61batch/s]
100%|██████████| 13/13 [00:04<00:00,  3.19batch/s]
100%|██████████| 37/37 [00:14<0

VBox(children=(Label(value='0.001 MB of 0.040 MB uploaded\r'), FloatProgress(value=0.034810275748291306, max=1…

0,1
psnr,▁
ssim,▁
train_loss,█▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,█▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
psnr,4.09738
ssim,0.00215
train_loss,0.0177
validation_loss,0.01864


In [5]:
project_name = "SR model benchmarking"
run_name = "SR3 standart 128 bach"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)
torch.backends.cudnn.benchmark = True

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3 Standart 128 bach")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    train_loss = trainer.train(train_dataloader, step)
    torch.cuda.empty_cache()
    if step % 10 == 0:
        trainer.save_model("saved models\\SR3")
    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33madrianpereramoreno[0m ([33mladoscuro[0m). Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 7/7 [00:05<00:00,  1.27batch/s]
100%|██████████| 19/19 [00:18<00:00,  1.04batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.79batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.53batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.83batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.54batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.80batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.55batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.78batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.55batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.83batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.56batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.81batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.55batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.82batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.55batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.81batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.55batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.82batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.55batch/s]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,██▇▇▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
validation_loss,██▇▇▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
psnr,4.24513
ssim,0.00269
train_loss,0.16762
validation_loss,0.16789


In [8]:
hyperparams = {
    "lr":0.0002,
    "epochs":100,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SR3+",
    "batch_size" : batch_size,
    "ddp": False
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

In [9]:
project_name = "SR model benchmarking"
run_name = "SR3+ standart 128 bach"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3+ standart 128 bach")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    
    train_loss = trainer.train(train_dataloader,epoch=step)
    torch.cuda.empty_cache()
    if step % 10 == 0:
        trainer.save_model("saved models\\SR3+")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888957186, max=1.0)…

100%|██████████| 7/7 [00:04<00:00,  1.74batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.53batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.82batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.51batch/s]
100%|██████████| 7/7 [00:04<00:00,  1.72batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.56batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.80batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.57batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.80batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.56batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.84batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.56batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.81batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.56batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.82batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.57batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.83batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.56batch/s]
100%|██████████| 7/7 [00:03<00:00,  1.83batch/s]
100%|██████████| 19/19 [00:12<00:00,  1.56batch/s]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,██▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
validation_loss,██▇▇▇▇▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁

0,1
psnr,5.70733
ssim,0.00744
train_loss,0.04932
validation_loss,0.04944


In [10]:
#Data
lr_size = 64
hr_size = 256
batch_size = 128
dataset_dir = 'C:\\Users\\adrianperera\\Desktop\\dataset_tfg'

transforms = Compose(
    [RandomApply(transforms= [GaussianBlur(7)], p = 0.5),
    RandomEqualize()]
)

dataset = AerialDataset(dataset_dir, lr_size, hr_size, data_augmentation = transforms, aux_sat_prob= 0.6, sat_dataset_path= "C:\\Users\\adrianperera\\Desktop\\dataset_tfg\\satelite_dataset\\64_256")
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2], generator=torch.Generator().manual_seed(420))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
#Con data augmentation

model_builder = SR3Builder()
model_builder = model_builder.set_standart()
model_builder = model_builder.set_sample_steps(1000)
model_builder = model_builder.set_losstype("l1")
model = model_builder.build()
model.to(device)

hyperparams = {
    "lr":0.0002,
    "epochs":100,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SR3",
    "grad_acum": 0,
    "ddp": False,
    "batch_size":batch_size
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

project_name = "SR model benchmarking"
run_name = "SR3 standart l1, DA, 128Batch"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3 Standart l1 DA 128b", grad_acum=hyperparams["grad_acum"])
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for epoch in range(hyperparams["epochs"]):  
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    
    train_loss = trainer.train(train_dataloader, epoch)
    torch.cuda.empty_cache()
    if epoch % 10 == 0:
        trainer.save_model("saved models\\SR3")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   

100%|██████████| 7/7 [00:06<00:00,  1.09batch/s]
100%|██████████| 19/19 [00:19<00:00,  1.02s/batch]
100%|██████████| 7/7 [00:06<00:00,  1.15batch/s]
100%|██████████| 19/19 [00:19<00:00,  1.01s/batch]
100%|██████████| 7/7 [00:05<00:00,  1.17batch/s]
100%|██████████| 19/19 [00:18<00:00,  1.01batch/s]
100%|██████████| 7/7 [00:05<00:00,  1.18batch/s]
100%|██████████| 19/19 [00:18<00:00,  1.03batch/s]
100%|██████████| 7/7 [00:05<00:00,  1.19batch/s]
100%|██████████| 19/19 [00:18<00:00,  1.04batch/s]
100%|██████████| 7/7 [00:05<00:00,  1.25batch/s]
100%|██████████| 19/19 [00:17<00:00,  1.07batch/s]
100%|██████████| 7/7 [00:05<00:00,  1.26batch/s]
100%|██████████| 19/19 [00:17<00:00,  1.07batch/s]
100%|██████████| 7/7 [00:05<00:00,  1.28batch/s]
100%|██████████| 19/19 [00:17<00:00,  1.09batch/s]
100%|██████████| 7/7 [00:05<00:00,  1.28batch/s]
100%|██████████| 19/19 [00:17<00:00,  1.11batch/s]
100%|██████████| 7/7 [00:05<00:00,  1.28batch/s]
100%|██████████| 19/19 [00:16<00:00,  1.13batch/s]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,██▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
validation_loss,███▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁

0,1
psnr,4.33392
ssim,0.0022
train_loss,0.26871
validation_loss,0.27155
