In [1]:
import PIL
import torch
import wandb
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from Dataset.AerialDataset import AerialDataset
from tasks.SRDiffTrainer import SRDiffTrainer
from tasks.SR3Trainer import SR3Trainer

- Generar imagenes bicubicas
- Construir Dataset
- Construir Dataloader
- SR3
- SRdiff
- SR3+

## 64 -> 256

## Entrenamiento

In [2]:
lr_size = 64
hr_size = 256
batch_size = 16
dataset_dir = 'E:\\TFG\\air_dataset'

In [3]:
dataset = AerialDataset(dataset_dir, lr_size, hr_size)
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2])

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### SRDiff

#### Modelo

In [None]:
from models.SRDiff.diffusion import GaussianDiffusion
from models.SRDiff.diffsr_modules import Unet, RRDBNet

hidden_size = 64
dim_mults = [1,2,2,4]
rrdb_num_features = 32
rrdb_num_blocks = 8
timesteps = 100
losstype = 'l1'

denoise_fn = Unet(
    hidden_size, out_dim=3, cond_dim=rrdb_num_features, dim_mults=dim_mults, rrdb_num_block=rrdb_num_blocks, sr_scale=4)

rrdb = RRDBNet(3, 3, rrdb_num_features, rrdb_num_blocks, rrdb_num_features// 2)

model = GaussianDiffusion(
    denoise_fn=denoise_fn,
    rrdb_net=rrdb,
    timesteps= timesteps,
    loss_type=losstype
)

model.to(device)

#### Optimizador y scheduler

In [None]:
lr= 0.0002
decay_steps= 100000
gamma = 0.5

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=decay_steps, gamma=gamma)

In [None]:
max_steps = 5
hyperparams = {
    "max_steps": 100,
    "model": "SRDiff",
    "learning_rate": lr,
    "decay_steps": decay_steps,
    "gamma": gamma,
    "batch_size": batch_size,
    "hidden_size": hidden_size,
    "dim_mults": dim_mults,
    "rrdb_num_features": rrdb_num_features,
    "rrdb_num_blocks": rrdb_num_blocks,
    "loss_type": losstype
}
project_name = "SR model benchmarking"
run_name = "SRDiff Standart Params"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SRDiffTrainer(metrics_used=["ssim", "psnr"])
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(max_steps):
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    
    trainer.save_model("checkpoints\\SRDiff", step)
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish() 

### SR3

In [16]:
from models.SR3.diffusion import GaussianDiffusion
from models.SR3.model import UNet
hyperparams = {
    "steps" : 2000,
    "sample_steps" : 100,
    "lr":0.0002,
    "epochs":20,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5
}
model = UNet(3, hyperparams["steps"]) #Valores por defecto ya que la tarea base es la misma upsample por 4
SR3_model = GaussianDiffusion(model, hyperparams["steps"], hyperparams["sample_steps"])
SR3_model.to(device)

GaussianDiffusion(
  (model): UNet(
    (emb): GammaEmbedding(
      (linear1): Linear(
        (linear): Linear(in_features=3, out_features=12, bias=True)
      )
      (silu): SiLU()
      (linear2): Linear(
        (linear): Linear(in_features=12, out_features=12, bias=True)
      )
    )
    (conv1): Conv2d(
      (conv): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): ModuleList(
      (0-2): 3 x WideResNetBlock(
        (gn1): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu1): SiLU()
        (conv1): Conv2d(
          (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (silu2): SiLU()
        (linear1): Linear(
          (linear): Linear(in_features=12, out_features=3, bias=True)
        )
        (gn2): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu3): SiLU()
        (dropout): Dropout(p=0.0, inplace=False)

In [17]:
optimizer = torch.optim.Adam(SR3_model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

In [18]:
project_name = "SR model benchmarking"
run_name = "SR3 moddiffied lr and scheduler, long training"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=["ssim", "psnr"])
trainer.set_model(SR3_model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    print(train_loss)
    trainer.save_model("checkpoints\\SR3", step)
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()  

100%|██████████| 148/148 [00:20<00:00,  7.28batch/s]


tensor(0.7684, device='cuda:0', grad_fn=<DivBackward0>)
Step@0: saving model to checkpoints\SR3\model_ckpt_steps_0.ckpt


100%|██████████| 50/50 [00:04<00:00, 12.07batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.34batch/s]


tensor(0.7025, device='cuda:0', grad_fn=<DivBackward0>)
Step@1: saving model to checkpoints\SR3\model_ckpt_steps_1.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.67batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.20batch/s]


tensor(0.6391, device='cuda:0', grad_fn=<DivBackward0>)
Step@2: saving model to checkpoints\SR3\model_ckpt_steps_2.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.93batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.36batch/s]


tensor(0.5776, device='cuda:0', grad_fn=<DivBackward0>)
Step@3: saving model to checkpoints\SR3\model_ckpt_steps_3.ckpt


100%|██████████| 50/50 [00:04<00:00, 12.01batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.36batch/s]


tensor(0.5210, device='cuda:0', grad_fn=<DivBackward0>)
Step@4: saving model to checkpoints\SR3\model_ckpt_steps_4.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.90batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.39batch/s]


tensor(0.4692, device='cuda:0', grad_fn=<DivBackward0>)
Step@5: saving model to checkpoints\SR3\model_ckpt_steps_5.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.91batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.25batch/s]


tensor(0.4217, device='cuda:0', grad_fn=<DivBackward0>)
Step@6: saving model to checkpoints\SR3\model_ckpt_steps_6.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.94batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.33batch/s]


tensor(0.3847, device='cuda:0', grad_fn=<DivBackward0>)
Step@7: saving model to checkpoints\SR3\model_ckpt_steps_7.ckpt


100%|██████████| 50/50 [00:04<00:00, 12.05batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.37batch/s]


tensor(0.3552, device='cuda:0', grad_fn=<DivBackward0>)
Step@8: saving model to checkpoints\SR3\model_ckpt_steps_8.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.96batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.34batch/s]


tensor(0.3322, device='cuda:0', grad_fn=<DivBackward0>)
Step@9: saving model to checkpoints\SR3\model_ckpt_steps_9.ckpt


100%|██████████| 50/50 [00:04<00:00, 12.05batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.36batch/s]


tensor(0.3148, device='cuda:0', grad_fn=<DivBackward0>)
Step@10: saving model to checkpoints\SR3\model_ckpt_steps_10.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.98batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.36batch/s]


tensor(0.3001, device='cuda:0', grad_fn=<DivBackward0>)
Step@11: saving model to checkpoints\SR3\model_ckpt_steps_11.ckpt


100%|██████████| 50/50 [00:04<00:00, 12.04batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.39batch/s]


tensor(0.2844, device='cuda:0', grad_fn=<DivBackward0>)
Step@12: saving model to checkpoints\SR3\model_ckpt_steps_12.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.99batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.40batch/s]


tensor(0.2737, device='cuda:0', grad_fn=<DivBackward0>)
Step@13: saving model to checkpoints\SR3\model_ckpt_steps_13.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.92batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.31batch/s]


tensor(0.2567, device='cuda:0', grad_fn=<DivBackward0>)
Step@14: saving model to checkpoints\SR3\model_ckpt_steps_14.ckpt


100%|██████████| 50/50 [00:04<00:00, 12.05batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.41batch/s]


tensor(0.2431, device='cuda:0', grad_fn=<DivBackward0>)
Step@15: saving model to checkpoints\SR3\model_ckpt_steps_15.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.92batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.36batch/s]


tensor(0.2276, device='cuda:0', grad_fn=<DivBackward0>)
Step@16: saving model to checkpoints\SR3\model_ckpt_steps_16.ckpt


100%|██████████| 50/50 [00:04<00:00, 12.11batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.41batch/s]


tensor(0.2148, device='cuda:0', grad_fn=<DivBackward0>)
Step@17: saving model to checkpoints\SR3\model_ckpt_steps_17.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.98batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.38batch/s]


tensor(0.2018, device='cuda:0', grad_fn=<DivBackward0>)
Step@18: saving model to checkpoints\SR3\model_ckpt_steps_18.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.96batch/s]
100%|██████████| 148/148 [00:20<00:00,  7.32batch/s]


tensor(0.1945, device='cuda:0', grad_fn=<DivBackward0>)
Step@19: saving model to checkpoints\SR3\model_ckpt_steps_19.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.83batch/s]
100%|██████████| 50/50 [02:59<00:00,  3.59s/batch]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,█▇▆▆▅▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁
validation_loss,█▇▆▆▅▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁

0,1
psnr,12.37624
ssim,0.02595
train_loss,0.1945
validation_loss,0.19105


#### SR3+

In [19]:
from models.SR3plus.diffusion import GaussianDiffusion
from models.SR3plus.model import UNet
hyperparams = {
    "steps" : 2000,
    "sample_steps" : 100,
    "lr":0.0002,
    "epochs":20,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5
}
model = UNet(3, hyperparams["steps"], channel_expansions= [1, 2, 4, 4, 4, 8, 8, 8])
SR3plus_model = GaussianDiffusion(model, hyperparams["steps"], hyperparams["sample_steps"])
SR3plus_model.to(device)

GaussianDiffusion(
  (model): UNet(
    (emb): GammaEmbedding(
      (linear1): Linear(
        (linear): Linear(in_features=3, out_features=12, bias=True)
      )
      (silu): SiLU()
      (linear2): Linear(
        (linear): Linear(in_features=12, out_features=12, bias=True)
      )
    )
    (conv1): Conv2d(
      (conv): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): ModuleList(
      (0-2): 3 x WideResNetBlock(
        (gn1): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu1): SiLU()
        (conv1): Conv2d(
          (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (silu2): SiLU()
        (linear1): Linear(
          (linear): Linear(in_features=12, out_features=3, bias=True)
        )
        (gn2): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu3): SiLU()
        (dropout): Dropout(p=0.0, inplace=False)

In [20]:
optimizer = torch.optim.Adam(SR3plus_model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

In [21]:
project_name = "SR model benchmarking"
run_name = "SR3+ moddiffied lr and scheduler, long training"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=["ssim", "psnr"])
trainer.set_model(SR3plus_model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    print(train_loss)
    trainer.save_model("checkpoints\\SR3+", step) 
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

100%|██████████| 148/148 [00:22<00:00,  6.59batch/s]


tensor(0.7682, device='cuda:0', grad_fn=<DivBackward0>)
Step@0: saving model to checkpoints\SR3+\model_ckpt_steps_0.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.26batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.65batch/s]


tensor(0.7000, device='cuda:0', grad_fn=<DivBackward0>)
Step@1: saving model to checkpoints\SR3+\model_ckpt_steps_1.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.38batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.66batch/s]


tensor(0.6353, device='cuda:0', grad_fn=<DivBackward0>)
Step@2: saving model to checkpoints\SR3+\model_ckpt_steps_2.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.23batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.56batch/s]


tensor(0.5796, device='cuda:0', grad_fn=<DivBackward0>)
Step@3: saving model to checkpoints\SR3+\model_ckpt_steps_3.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.26batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.50batch/s]


tensor(0.5304, device='cuda:0', grad_fn=<DivBackward0>)
Step@4: saving model to checkpoints\SR3+\model_ckpt_steps_4.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.42batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.73batch/s]


tensor(0.4853, device='cuda:0', grad_fn=<DivBackward0>)
Step@5: saving model to checkpoints\SR3+\model_ckpt_steps_5.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.56batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.71batch/s]


tensor(0.4444, device='cuda:0', grad_fn=<DivBackward0>)
Step@6: saving model to checkpoints\SR3+\model_ckpt_steps_6.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.27batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.71batch/s]


tensor(0.4096, device='cuda:0', grad_fn=<DivBackward0>)
Step@7: saving model to checkpoints\SR3+\model_ckpt_steps_7.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.45batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.77batch/s]


tensor(0.3795, device='cuda:0', grad_fn=<DivBackward0>)
Step@8: saving model to checkpoints\SR3+\model_ckpt_steps_8.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.38batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.67batch/s]


tensor(0.3504, device='cuda:0', grad_fn=<DivBackward0>)
Step@9: saving model to checkpoints\SR3+\model_ckpt_steps_9.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.36batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.76batch/s]


tensor(0.3261, device='cuda:0', grad_fn=<DivBackward0>)
Step@10: saving model to checkpoints\SR3+\model_ckpt_steps_10.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.37batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.78batch/s]


tensor(0.3037, device='cuda:0', grad_fn=<DivBackward0>)
Step@11: saving model to checkpoints\SR3+\model_ckpt_steps_11.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.42batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.80batch/s]


tensor(0.2857, device='cuda:0', grad_fn=<DivBackward0>)
Step@12: saving model to checkpoints\SR3+\model_ckpt_steps_12.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.51batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.78batch/s]


tensor(0.2680, device='cuda:0', grad_fn=<DivBackward0>)
Step@13: saving model to checkpoints\SR3+\model_ckpt_steps_13.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.48batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.73batch/s]


tensor(0.2534, device='cuda:0', grad_fn=<DivBackward0>)
Step@14: saving model to checkpoints\SR3+\model_ckpt_steps_14.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.40batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.74batch/s]


tensor(0.2411, device='cuda:0', grad_fn=<DivBackward0>)
Step@15: saving model to checkpoints\SR3+\model_ckpt_steps_15.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.35batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.62batch/s]


tensor(0.2242, device='cuda:0', grad_fn=<DivBackward0>)
Step@16: saving model to checkpoints\SR3+\model_ckpt_steps_16.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.57batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.62batch/s]


tensor(0.2029, device='cuda:0', grad_fn=<DivBackward0>)
Step@17: saving model to checkpoints\SR3+\model_ckpt_steps_17.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.53batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.68batch/s]


tensor(0.1896, device='cuda:0', grad_fn=<DivBackward0>)
Step@18: saving model to checkpoints\SR3+\model_ckpt_steps_18.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.46batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.68batch/s]


tensor(0.1730, device='cuda:0', grad_fn=<DivBackward0>)
Step@19: saving model to checkpoints\SR3+\model_ckpt_steps_19.ckpt


100%|██████████| 50/50 [00:04<00:00, 11.12batch/s]
100%|██████████| 50/50 [03:23<00:00,  4.07s/batch]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,█▇▆▆▅▅▄▄▃▃▃▃▂▂▂▂▂▁▁▁
validation_loss,█▇▆▆▅▅▄▄▃▃▃▃▂▂▂▂▂▁▁▁

0,1
psnr,11.96721
ssim,0.02908
train_loss,0.17303
validation_loss,0.17087
