In [1]:
import PIL
import numpy as np
import torch
import wandb
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from Dataset.AerialDataset import AerialDataset
from tasks.SRDiffTrainer import SRDiffTrainer
from tasks.SR3Trainer import SR3Trainer

- Generar imagenes bicubicas
- Construir Dataset
- Construir Dataloader
- SR3
- SRdiff
- SR3+

## 64 -> 256

## Entrenamiento

In [2]:
lr_size = 64
hr_size = 256
batch_size = 16
dataset_dir = 'E:\\TFG\\air_dataset'

In [3]:
dataset = AerialDataset(dataset_dir, lr_size, hr_size)
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2])

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### SRDiff

#### Modelo

In [4]:
from models.SRDiff.diffusion import GaussianDiffusion
from models.SRDiff.diffsr_modules import Unet, RRDBNet

hidden_size = 64
dim_mults = [1,2,2,4]
rrdb_num_features = 32
rrdb_num_blocks = 8
timesteps = 100
losstype = 'l1'

denoise_fn = Unet(
    hidden_size, out_dim=3, cond_dim=rrdb_num_features, dim_mults=dim_mults, rrdb_num_block=rrdb_num_blocks, sr_scale=4)

rrdb = RRDBNet(3, 3, rrdb_num_features, rrdb_num_blocks, rrdb_num_features// 2)

model = GaussianDiffusion(
    denoise_fn=denoise_fn,
    rrdb_net=rrdb,
    timesteps= timesteps,
    loss_type=losstype
)

model.to(device)

GaussianDiffusion(
  (denoise_fn): Unet(
    (cond_proj): ConvTranspose2d(96, 64, kernel_size=(8, 8), stride=(4, 4), padding=(2, 2))
    (time_pos_emb): SinusoidalPosEmb()
    (mlp): Sequential(
      (0): Linear(in_features=64, out_features=256, bias=True)
      (1): Mish()
      (2): Linear(in_features=256, out_features=64, bias=True)
    )
    (downs): ModuleList(
      (0): ModuleList(
        (0): ResnetBlock(
          (mlp): Sequential(
            (0): Mish()
            (1): Linear(in_features=64, out_features=64, bias=True)
          )
          (block1): Block(
            (block): Sequential(
              (0): ReflectionPad2d((1, 1, 1, 1))
              (1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
              (2): Mish()
            )
          )
          (block2): Block(
            (block): Sequential(
              (0): ReflectionPad2d((1, 1, 1, 1))
              (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
              (2): Mish()
            )
 

#### Optimizador y scheduler

In [5]:
lr= 0.0002
decay_steps= 100000
gamma = 0.5

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=decay_steps, gamma=gamma)

In [6]:
max_steps = 5
hyperparams = {
    "max_steps": 100,
    "model": "SRDiff",
    "learning_rate": lr,
    "decay_steps": decay_steps,
    "gamma": gamma,
    "batch_size": batch_size,
    "hidden_size": hidden_size,
    "dim_mults": dim_mults,
    "rrdb_num_features": rrdb_num_features,
    "rrdb_num_blocks": rrdb_num_blocks,
    "loss_type": losstype
}
project_name = "SR model benchmarking"
run_name = "SRDiff Standart Params"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SRDiffTrainer(metrics_used=["ssim", "psnr"], model_name="SRDiff")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(max_steps):
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    
    trainer.save_model("models_state_dic")
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   

100%|██████████| 148/148 [07:13<00:00,  2.93s/batch, q=0.209]
100%|██████████| 50/50 [00:26<00:00,  1.87batch/s, q=0.17] 
100%|██████████| 148/148 [05:39<00:00,  2.30s/batch, q=0.0979]
100%|██████████| 50/50 [00:11<00:00,  4.31batch/s, q=0.102] 
100%|██████████| 148/148 [08:48<00:00,  3.57s/batch, q=0.416] 
100%|██████████| 50/50 [00:18<00:00,  2.73batch/s, q=0.0817]
100%|██████████| 148/148 [06:06<00:00,  2.47s/batch, q=0.162] 
100%|██████████| 50/50 [00:13<00:00,  3.62batch/s, q=0.0603]
100%|██████████| 148/148 [06:30<00:00,  2.64s/batch, q=0.0648]
100%|██████████| 50/50 [00:11<00:00,  4.27batch/s, q=0.0819]
100%|██████████| 50/50 [13:25<00:00, 16.12s/batch, n_samples=1, psnr=32.5, ssim=0.775]


VBox(children=(Label(value='0.006 MB of 0.012 MB uploaded (0.003 MB deduped)\r'), FloatProgress(value=0.539696…

0,1
psnr,▁
ssim,▁
train_loss,█▂▂▁▁
validation_loss,█▃▁▁▁

0,1
psnr,32.62196
ssim,0.77044
train_loss,0.09964
validation_loss,0.11607


### SR3

In [7]:
from models.SR3.diffusion import GaussianDiffusion
from models.SR3.model import UNet
hyperparams = {
    "steps" : 2000,
    "sample_steps" : 100,
    "lr":0.0002,
    "epochs":20,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5
}
model = UNet(3, hyperparams["steps"]) #Valores por defecto ya que la tarea base es la misma upsample por 4
SR3_model = GaussianDiffusion(model, hyperparams["steps"], hyperparams["sample_steps"])
SR3_model.to(device)

GaussianDiffusion(
  (model): UNet(
    (emb): GammaEmbedding(
      (linear1): Linear(
        (linear): Linear(in_features=3, out_features=12, bias=True)
      )
      (silu): SiLU()
      (linear2): Linear(
        (linear): Linear(in_features=12, out_features=12, bias=True)
      )
    )
    (conv1): Conv2d(
      (conv): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): ModuleList(
      (0-2): 3 x WideResNetBlock(
        (gn1): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu1): SiLU()
        (conv1): Conv2d(
          (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (silu2): SiLU()
        (linear1): Linear(
          (linear): Linear(in_features=12, out_features=3, bias=True)
        )
        (gn2): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu3): SiLU()
        (dropout): Dropout(p=0.0, inplace=False)

In [8]:
optimizer = torch.optim.Adam(SR3_model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

In [11]:
project_name = "SR model benchmarking"
run_name = "SR3 moddiffied lr and scheduler, long training"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=["ssim", "psnr"], model_name="SR3")
trainer.set_model(SR3_model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    print(train_loss)
    trainer.save_model("models_state_dic")
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()  

100%|██████████| 148/148 [00:20<00:00,  7.36batch/s]


tensor(0.1955, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.43batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.48batch/s]


tensor(0.1839, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.41batch/s]
100%|██████████| 50/50 [03:02<00:00,  3.65s/batch]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,█▇▇▆▅▅▄▄▄▃▃▃▃▂▂▂▂▁▁▁
validation_loss,█▇▇▆▅▅▄▄▄▃▃▃▃▂▂▂▂▁▁▁

0,1
psnr,11.57389
ssim,0.02501
train_loss,0.1839
validation_loss,0.18056


#### SR3+

In [4]:
from models.SR3plus.diffusion import GaussianDiffusion
from models.SR3plus.model import UNet
hyperparams = {
    "steps" : 8000,
    "sample_steps" : 100,
    "lr":0.0002,
    "epochs":20,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5
}
model = UNet(3, hyperparams["steps"], channel_expansions= [1, 2, 4, 4, 4, 8, 8, 8])
SR3plus_model = GaussianDiffusion(model, hyperparams["steps"], hyperparams["sample_steps"])
SR3plus_model.to(device)

GaussianDiffusion(
  (model): UNet(
    (emb): GammaEmbedding(
      (linear1): Linear(
        (linear): Linear(in_features=3, out_features=12, bias=True)
      )
      (silu): SiLU()
      (linear2): Linear(
        (linear): Linear(in_features=12, out_features=12, bias=True)
      )
    )
    (conv1): Conv2d(
      (conv): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): ModuleList(
      (0-2): 3 x WideResNetBlock(
        (gn1): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu1): SiLU()
        (conv1): Conv2d(
          (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (silu2): SiLU()
        (linear1): Linear(
          (linear): Linear(in_features=12, out_features=3, bias=True)
        )
        (gn2): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu3): SiLU()
        (dropout): Dropout(p=0.0, inplace=False)

In [5]:
optimizer = torch.optim.Adam(SR3plus_model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

In [6]:
project_name = "SR model benchmarking"
run_name = "SR3+ moddiffied lr and scheduler, long training"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=["ssim", "psnr"], model_name="SR3+")
trainer.set_model(SR3plus_model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    trainer.save_model("models_state_dic") 
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   

[34m[1mwandb[0m: Currently logged in as: [33madrianpereramoreno[0m ([33mladoscuro[0m). Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 148/148 [00:50<00:00,  2.95batch/s]
100%|██████████| 50/50 [00:11<00:00,  4.22batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.48batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.11batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.72batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.23batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.70batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.33batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.69batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.32batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.66batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.19batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.53batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.04batch/s]
100%|██████████| 148/148 [00:23<00:00,  6.41batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.89batch/s]
100%|██████████| 148/148 [00:23<00:00,  6.40batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.01batch/s]
100%|██████████| 148/148 [00:23<00:00,  6.40batch/s]
100%|██████

VBox(children=(Label(value='0.001 MB of 0.008 MB uploaded\r'), FloatProgress(value=0.16419010192803635, max=1.…

0,1
psnr,▁
ssim,▁
train_loss,█▇▆▆▅▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁
validation_loss,█▇▆▆▅▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
psnr,10.68461
ssim,0.01753
train_loss,0.23969
validation_loss,0.2371


Let's upscale some images and see the results.

In [4]:
#sr3
from models.SRDiff.diffusion import GaussianDiffusion
from models.SRDiff.diffsr_modules import Unet, RRDBNet
from utils.model_utils import load_model
from utils.tensor_utils import *

dir_checkpoint = ("checkpoints/SR3")


hidden_size = 64
dim_mults = [1,2,2,4]
rrdb_num_features = 32
rrdb_num_blocks = 8
timesteps = 100
losstype = 'l1'

denoise_fn = Unet(
    hidden_size, out_dim=3, cond_dim=rrdb_num_features, dim_mults=dim_mults, rrdb_num_block=rrdb_num_blocks, sr_scale=4)

rrdb = RRDBNet(3, 3, rrdb_num_features, rrdb_num_blocks, rrdb_num_features// 2)

model = GaussianDiffusion(
    denoise_fn=denoise_fn,
    rrdb_net=rrdb,
    timesteps= timesteps,
    loss_type=losstype
)
model = load_model(model, "SrDiff.pt", "models_state_dic")
model.to(device)

imgs = dataset.get_image_from_name("Amnesty POI-9-3-3")

hr = imgs['hr']
lr = imgs['lr']
bicubic = imgs['bicubic']
sr, _ = model.sample(lr.to(device).unsqueeze(0), bicubic.to(device).unsqueeze(0), hr.to(device).unsqueeze(0).shape, True)

The model weights have been loaded from 'models_state_dic\SrDiff.pt'


In [6]:
img_sr = tensor2img(sr.squeeze(), normalize=False)
img_bicubic = tensor2img(bicubic.squeeze(), normalize=False)
img_lr = tensor2img(lr.squeeze(), normalize=False)
img_hr = tensor2img(hr.squeeze(), normalize=False)