In [1]:
import PIL
import numpy as np
import torch
import torchvision
from torchvision.transforms.v2 import Compose, GaussianBlur, RandomEqualize, RandomSolarize, RandomApply

import wandb
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from Dataset.AerialDataset import AerialDataset
from tasks.SRDiffTrainer import SRDiffTrainer
from tasks.SR3Trainer import SR3Trainer
from utils.model_utils import load_model

- Generar imágenes bicúbicas
- Construir Dataset
- Construir Dataloader
- SR3
- SRDiff
- SR3+

## 64 -> 256

## Entrenamiento

In [2]:
# Sizes for the "64 -> 256" setting: low-resolution size 64, high-resolution
# size 256.
lr_size, hr_size = 64, 256

# Mini-batch size handed to the DataLoaders built from this configuration.
batch_size = 16

# NOTE(review): absolute, machine-specific Windows path -- adjust before
# running on another machine.
dataset_dir = 'E:\\TFG\\air_dataset'

In [3]:
# Augmentation pipeline handed to the dataset: a Gaussian blur (kernel size 7)
# applied with probability 0.5, followed by RandomEqualize with its default
# probability.
transforms = Compose(
    [
        RandomApply(transforms=[GaussianBlur(7)], p=0.5),
        RandomEqualize(),
    ]
)

# NOTE(review): the satellite dataset path is an absolute, machine-specific
# Windows path -- adjust before running on another machine.
dataset = AerialDataset(
    dataset_dir,
    lr_size,
    hr_size,
    data_augmentation=transforms,
    aux_sat_prob=0.5,
    sat_dataset_path="E:\\TFG\\satelite_dataset\\64_256",
)

# Seeded 60/20/20 split so the three subsets are the same on every run.
# NOTE(review): the split is taken from a single dataset object, so the
# validation and test subsets presumably go through the same augmentation and
# aux_sat_prob settings as training -- confirm this is intended.
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [0.6, 0.2, 0.2],
    generator=torch.Generator().manual_seed(420),
)

# Only the training loader is shuffled; evaluation loaders keep a fixed order
# so validation/test passes iterate the samples deterministically.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### SRDiff

In [None]:
# Imported in this cell rather than at the top of the notebook: the SR3 and
# SR3+ sections import a different class under the same GaussianDiffusion name.
from models.SRDiff.diffusion import GaussianDiffusion
from models.SRDiff.diffsr_modules import Unet, RRDBNet

# --- Model hyperparameters -------------------------------------------------
losstype = "l2"
model_name = f"SRDiff{losstype}auxsatdata"
hidden_size = 64
dim_mults = [1, 2, 2, 4]
rrdb_num_features = 32
rrdb_num_blocks = 8
timesteps = 100

# Denoising U-Net conditioned on features with `rrdb_num_features` channels.
denoise_fn = Unet(
    hidden_size,
    out_dim=3,
    cond_dim=rrdb_num_features,
    dim_mults=dim_mults,
    rrdb_num_block=rrdb_num_blocks,
    sr_scale=4,
)

# RRDB network passed to the diffusion wrapper as `rrdb_net`.
rrdb = RRDBNet(3, 3, rrdb_num_features, rrdb_num_blocks, rrdb_num_features // 2)

model = GaussianDiffusion(
    denoise_fn=denoise_fn,
    rrdb_net=rrdb,
    timesteps=timesteps,
    loss_type=losstype,
    aux_l1_loss=True,
    aux_perceptual_loss=True,
)
#model = load_model(model, f"SRDiff{losstype}.pt", "models_state_dic")
model.to(device)

# --- Optimisation ----------------------------------------------------------
lr = 0.0002
decay_steps = 100000
gamma = 0.5

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=decay_steps, gamma=gamma)

# Number of validate/train rounds executed by the loop below.
max_steps = 5

# Configuration logged to wandb. Every value is taken from the variables
# actually used above, so the logged run matches what was executed
# ("max_steps" previously held a stale literal that disagreed with the loop).
hyperparams = {
    "max_steps": max_steps,
    "model": "SRDifftrainedRRDB",
    "learning_rate": lr,
    "decay_steps": decay_steps,
    "gamma": gamma,
    "batch_size": batch_size,
    "hidden_size": hidden_size,
    "dim_mults": dim_mults,
    "rrdb_num_features": rrdb_num_features,
    "rrdb_num_blocks": rrdb_num_blocks,
    "loss_type": losstype,
    "epochs": max_steps,
    "use_rrdb": True,
    "fix_rrdb": False,
    "timesteps": timesteps,
}

project_name = "SR model benchmarking"
run_name = model_name
wandb.login()

# The run is used as a context manager so it is always closed, and is marked
# as failed if training raises (e.g. CUDA out of memory).
with wandb.init(project=project_name, config=hyperparams, name=run_name):
    trainer = SRDiffTrainer(metrics_used=["ssim", "psnr"], model_name=model_name)
    trainer.set_model(model)
    trainer.set_optimizer(optimizer)
    trainer.set_scheduler(scheduler)

    for step in range(max_steps):
        # Validate before training so the first logged point reflects the
        # model state at the start of the round.
        with torch.no_grad():
            val_loss = trainer.validate(val_dataloader)
        # NOTE(review): the meaning of the two positional True flags is not
        # visible here -- TODO pass them as keyword arguments.
        train_loss = trainer.train(train_dataloader, True, True)
        torch.cuda.empty_cache()
        trainer.save_model("models_state_dic")
        torch.cuda.empty_cache()
        wandb.log({"train_loss": train_loss, "validation_loss": val_loss})

    test_metrics = trainer.test(test_dataloader)
    wandb.log(test_metrics)

### SR3

In [4]:
# Imported in this cell because other sections of the notebook import a
# different class under the same GaussianDiffusion name.
from models.SR3.diffusion import GaussianDiffusion
from models.SR3.model import UNet

# Configuration logged to wandb for the SR3 run.
# "losstype" is the value actually handed to the model below; it previously
# said "l2" while the model was built with "l1", so the logged config did not
# describe the executed run.
# NOTE(review): "eta_min" is not read by any cell visible here (the scheduler
# used is StepLR) -- confirm whether it is still needed.
hyperparams = {
    "steps": 2000,
    "sample_steps": 100,
    "lr": 0.0002,
    "epochs": 10,
    "eta_min": 1e-7,
    "decay_steps": 100000,
    "gamma": 0.5,
    "model": "SR3",
    "losstype": "l1",
}

# Default values, since the base task is the same x4 upsampling.
model = UNet(3, hyperparams["steps"])
SR3_model = GaussianDiffusion(
    model,
    hyperparams["steps"],
    hyperparams["sample_steps"],
    losstype=hyperparams["losstype"],
)
SR3_model.to(device)

GaussianDiffusion(
  (model): UNet(
    (emb): GammaEmbedding(
      (linear1): Linear(
        (linear): Linear(in_features=3, out_features=12, bias=True)
      )
      (silu): SiLU()
      (linear2): Linear(
        (linear): Linear(in_features=12, out_features=12, bias=True)
      )
    )
    (conv1): Conv2d(
      (conv): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): ModuleList(
      (0-2): 3 x WideResNetBlock(
        (gn1): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu1): SiLU()
        (conv1): Conv2d(
          (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (silu2): SiLU()
        (linear1): Linear(
          (linear): Linear(in_features=12, out_features=3, bias=True)
        )
        (gn2): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu3): SiLU()
        (dropout): Dropout(p=0.0, inplace=False)

In [5]:
# Adam over every parameter of the SR3 diffusion model.
optimizer = torch.optim.Adam(
    SR3_model.parameters(),
    lr=hyperparams["lr"],
)

# StepLR multiplies the learning rate by `gamma` once every `decay_steps`
# calls to scheduler.step(); how often the trainer steps it is not visible
# in this notebook.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=hyperparams["decay_steps"],
    gamma=hyperparams["gamma"],
)

In [6]:
project_name = "SR model benchmarking"
run_name = "SR3 with aux satelite data"
wandb.login()

# NOTE(review): the recorded execution of this cell stopped with a CUDA
# out-of-memory error (a 256 GiB allocation request) during the first
# validation pass -- check the model's memory use or lower the batch size
# before re-running.
# The run is used as a context manager so it is always closed, and is marked
# as failed when an exception such as that one escapes.
with wandb.init(project=project_name, config=hyperparams, name=run_name):
    trainer = SR3Trainer(metrics_used=["ssim", "psnr"], model_name="SR3")
    trainer.set_model(SR3_model)
    trainer.set_optimizer(optimizer)
    trainer.set_scheduler(scheduler)

    # One iteration per configured epoch: validate first, then train, then
    # checkpoint and log both losses.
    for step in range(hyperparams["epochs"]):
        with torch.no_grad():
            val_loss = trainer.validate(val_dataloader)

        train_loss = trainer.train(train_dataloader)
        torch.cuda.empty_cache()
        trainer.save_model("models_state_dic")

        torch.cuda.empty_cache()
        wandb.log({"train_loss": train_loss, "validation_loss": val_loss})

    test_metrics = trainer.test(test_dataloader)
    wandb.log(test_metrics)

[34m[1mwandb[0m: Currently logged in as: [33madrianpereramoreno[0m ([33mladoscuro[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/50 [00:02<?, ?batch/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 GiB. GPU 0 has a total capacity of 11.99 GiB of which 10.60 GiB is free. Of the allocated memory 190.52 MiB is allocated by PyTorch, and 5.48 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

### SR3+

In [7]:
# Imported in this cell because other sections of the notebook import a
# different class under the same GaussianDiffusion / UNet names.
from models.SR3plus.diffusion import GaussianDiffusion
from models.SR3plus.model import UNet

# Configuration logged to wandb for the SR3+ run.
# NOTE(review): "losstype" is logged but not passed to GaussianDiffusion
# below, so the model uses whatever its own default is -- TODO confirm the
# two agree.
hyperparams = dict(
    steps=8000,
    sample_steps=100,
    lr=0.0002,
    epochs=10,
    eta_min=1e-7,
    decay_steps=100000,
    gamma=0.5,
    model="SR3+",
    losstype="l2",
)

# U-Net built with an explicit list of per-level channel expansion factors.
model = UNet(
    3,
    hyperparams["steps"],
    channel_expansions=[1, 2, 4, 4, 4, 8, 8, 8],
)
SR3plus_model = GaussianDiffusion(
    model,
    hyperparams["steps"],
    hyperparams["sample_steps"],
)
SR3plus_model.to(device)

GaussianDiffusion(
  (model): UNet(
    (emb): GammaEmbedding(
      (linear1): Linear(
        (linear): Linear(in_features=3, out_features=12, bias=True)
      )
      (silu): SiLU()
      (linear2): Linear(
        (linear): Linear(in_features=12, out_features=12, bias=True)
      )
    )
    (conv1): Conv2d(
      (conv): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): ModuleList(
      (0-2): 3 x WideResNetBlock(
        (gn1): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu1): SiLU()
        (conv1): Conv2d(
          (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (silu2): SiLU()
        (linear1): Linear(
          (linear): Linear(in_features=12, out_features=3, bias=True)
        )
        (gn2): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu3): SiLU()
        (dropout): Dropout(p=0.0, inplace=False)

In [8]:
# Adam over every parameter of the SR3+ diffusion model.
optimizer = torch.optim.Adam(
    SR3plus_model.parameters(),
    lr=hyperparams["lr"],
)

# StepLR multiplies the learning rate by `gamma` once every `decay_steps`
# calls to scheduler.step(); how often the trainer steps it is not visible
# in this notebook.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=hyperparams["decay_steps"],
    gamma=hyperparams["gamma"],
)

In [9]:
project_name = "SR model benchmarking"
run_name = "SR3+  with aux satelite data"
wandb.login()

# The run is used as a context manager so it is always closed, and is marked
# as failed if training raises (e.g. CUDA out of memory).
with wandb.init(project=project_name, config=hyperparams, name=run_name):
    trainer = SR3Trainer(metrics_used=["ssim", "psnr"], model_name="SR3+")
    trainer.set_model(SR3plus_model)
    trainer.set_optimizer(optimizer)
    trainer.set_scheduler(scheduler)

    # One iteration per configured epoch: validate first, then train, then
    # checkpoint and log both losses.
    for step in range(hyperparams["epochs"]):
        with torch.no_grad():
            val_loss = trainer.validate(val_dataloader)
        train_loss = trainer.train(train_dataloader)
        torch.cuda.empty_cache()
        trainer.save_model("models_state_dic")

        torch.cuda.empty_cache()
        wandb.log({"train_loss": train_loss, "validation_loss": val_loss})

    test_metrics = trainer.test(test_dataloader)
    wandb.log(test_metrics)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

100%|██████████| 50/50 [00:07<00:00,  6.35batch/s]
100%|██████████| 148/148 [00:33<00:00,  4.47batch/s]
100%|██████████| 50/50 [00:07<00:00,  6.82batch/s]
100%|██████████| 148/148 [00:32<00:00,  4.50batch/s]
100%|██████████| 50/50 [00:07<00:00,  6.70batch/s]
100%|██████████| 148/148 [00:32<00:00,  4.62batch/s]
100%|██████████| 50/50 [00:07<00:00,  7.09batch/s]
100%|██████████| 148/148 [00:32<00:00,  4.60batch/s]
100%|██████████| 50/50 [00:07<00:00,  7.13batch/s]
100%|██████████| 148/148 [00:32<00:00,  4.61batch/s]
100%|██████████| 50/50 [00:06<00:00,  7.45batch/s]
100%|██████████| 148/148 [00:30<00:00,  4.86batch/s]
100%|██████████| 50/50 [00:06<00:00,  7.62batch/s]
100%|██████████| 148/148 [00:29<00:00,  4.99batch/s]
100%|██████████| 50/50 [00:06<00:00,  7.96batch/s]
100%|██████████| 148/148 [00:29<00:00,  5.02batch/s]
100%|██████████| 50/50 [00:06<00:00,  7.96batch/s]
100%|██████████| 148/148 [00:30<00:00,  4.79batch/s]
100%|██████████| 50/50 [00:07<00:00,  6.63batch/s]
100%|████████

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,█▇▆▅▄▃▃▂▁▁
validation_loss,█▇▆▅▄▃▃▂▁▁

0,1
psnr,4.40912
ssim,0.00209
train_loss,0.34076
validation_loss,0.35272
