In [4]:
import PIL
import numpy as np
import torch
import wandb
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from Dataset.AerialDataset import AerialDataset
from tasks.SRDiffTrainer import SRDiffTrainer
from tasks.SR3Trainer import SR3Trainer
from utils.model_utils import load_model

- Generar imagenes bicubicas
- Construir Dataset
- Construir Dataloader
- SR3
- SRdiff
- SR3+

## 64 -> 256

## Entrenamiento

In [5]:
lr_size = 64
hr_size = 256
batch_size = 16
dataset_dir = 'E:\\TFG\\air_dataset'

In [6]:
dataset = AerialDataset(dataset_dir, lr_size, hr_size)
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2], generator=torch.Generator().manual_seed(420))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### SRDiff

In [8]:
from models.SRDiff.diffusion import GaussianDiffusion
from models.SRDiff.diffsr_modules import Unet, RRDBNet
losstypes = ["l2","ssim"]

for losstype in losstypes:
    model_name = f"SRDiff{losstype}"
    hidden_size = 64
    dim_mults = [1,2,2,4]
    rrdb_num_features = 32
    rrdb_num_blocks = 8
    timesteps = 100
    
    denoise_fn = Unet(
        hidden_size, out_dim=3, cond_dim=rrdb_num_features, dim_mults=dim_mults, rrdb_num_block=rrdb_num_blocks, sr_scale=4)
    
    rrdb = RRDBNet(3, 3, rrdb_num_features, rrdb_num_blocks, rrdb_num_features// 2)
    
    model = GaussianDiffusion(
        denoise_fn=denoise_fn,
        rrdb_net=rrdb,
        timesteps= timesteps,
        loss_type=losstype
    )
    #model = load_model(model, f"SRDiff{losstype}.pt", "models_state_dic")
    model.to(device)  
    
    lr= 0.0002
    decay_steps= 100000
    gamma = 0.5
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=decay_steps, gamma=gamma)
    max_steps = 20
    hyperparams = {
        "max_steps": 100,
        "model": "SRDiff",
        "learning_rate": lr,
        "decay_steps": decay_steps,
        "gamma": gamma,
        "batch_size": batch_size,
        "hidden_size": hidden_size,
        "dim_mults": dim_mults,
        "rrdb_num_features": rrdb_num_features,
        "rrdb_num_blocks": rrdb_num_blocks,
        "loss_type": losstype,
        "epochs": max_steps
    }
    project_name = "SR model benchmarking"
    run_name = "SRDiff"
    wandb.login()
    wandb.init(project=project_name, config=hyperparams, name=run_name)
    
    trainer = SRDiffTrainer(metrics_used=["ssim", "psnr"], model_name=model_name)
    trainer.set_model(model)
    trainer.set_optimizer(optimizer)
    trainer.set_scheduler(scheduler)
    for step in range(max_steps):
        train_loss = trainer.train(train_dataloader)
        torch.cuda.empty_cache()
        
        trainer.save_model("models_state_dic")
        with torch.no_grad():
            val_loss = trainer.validate(val_dataloader)
        torch.cuda.empty_cache()
        wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
        
    test_metrics = trainer.test(test_dataloader)
    wandb.log(test_metrics)
    wandb.finish()   

100%|██████████| 50/50 [12:48<00:00, 15.37s/batch, n_samples=1, psnr=26.6, ssim=0.752]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,█▅▃▃▄▂▂▂▃▃▂▂▂▂▂▁▂▁▁▁

0,1
psnr,22.75909
ssim,0.62807
train_loss,0.09014
validation_loss,0.08974


 ### SR3

In [9]:
from models.SR3.diffusion import GaussianDiffusion
from models.SR3.model import UNet
hyperparams = {
    "steps" : 2000,
    "sample_steps" : 100,
    "lr":0.0002,
    "epochs":100,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SR3"
}
model = UNet(3, hyperparams["steps"]) #Valores por defecto ya que la tarea base es la misma upsample por 4
SR3_model = GaussianDiffusion(model, hyperparams["steps"], hyperparams["sample_steps"])
SR3_model.to(device)

GaussianDiffusion(
  (model): UNet(
    (emb): GammaEmbedding(
      (linear1): Linear(
        (linear): Linear(in_features=3, out_features=12, bias=True)
      )
      (silu): SiLU()
      (linear2): Linear(
        (linear): Linear(in_features=12, out_features=12, bias=True)
      )
    )
    (conv1): Conv2d(
      (conv): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): ModuleList(
      (0-2): 3 x WideResNetBlock(
        (gn1): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu1): SiLU()
        (conv1): Conv2d(
          (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (silu2): SiLU()
        (linear1): Linear(
          (linear): Linear(in_features=12, out_features=3, bias=True)
        )
        (gn2): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu3): SiLU()
        (dropout): Dropout(p=0.0, inplace=False)

In [10]:
optimizer = torch.optim.Adam(SR3_model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

In [11]:
project_name = "SR model benchmarking"
run_name = "SR3"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=["ssim", "psnr"], model_name="SR3")
trainer.set_model(SR3_model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    print(train_loss)
    trainer.save_model("models_state_dic")
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()  

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

100%|██████████| 148/148 [00:20<00:00,  7.30batch/s]


tensor(0.7701, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 11.98batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.41batch/s]


tensor(0.7091, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.23batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.55batch/s]


tensor(0.6470, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.31batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.56batch/s]


tensor(0.5869, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.31batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.52batch/s]


tensor(0.5341, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.18batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.55batch/s]


tensor(0.4859, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.30batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.54batch/s]


tensor(0.4423, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.26batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.54batch/s]


tensor(0.4069, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.21batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.53batch/s]


tensor(0.3779, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.21batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.46batch/s]


tensor(0.3560, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.28batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.59batch/s]


tensor(0.3391, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.23batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.58batch/s]


tensor(0.3247, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.31batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.62batch/s]


tensor(0.3117, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.20batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.61batch/s]


tensor(0.2974, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.29batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]


tensor(0.2830, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.21batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.59batch/s]


tensor(0.2717, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.21batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.57batch/s]


tensor(0.2566, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.27batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.65batch/s]


tensor(0.2433, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.32batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.49batch/s]


tensor(0.2341, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.22batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.56batch/s]


tensor(0.2251, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.29batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.56batch/s]


tensor(0.2160, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.18batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.53batch/s]


tensor(0.2106, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.24batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.56batch/s]


tensor(0.1995, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.26batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.55batch/s]


tensor(0.1941, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.19batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.52batch/s]


tensor(0.1885, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.14batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.56batch/s]


tensor(0.1815, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.17batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.54batch/s]


tensor(0.1777, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.22batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.49batch/s]


tensor(0.1744, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.23batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.50batch/s]


tensor(0.1706, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.23batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.50batch/s]


tensor(0.1664, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.33batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.52batch/s]


tensor(0.1615, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.05batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.60batch/s]


tensor(0.1597, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.33batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.58batch/s]


tensor(0.1547, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.27batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.59batch/s]


tensor(0.1537, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.35batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.58batch/s]


tensor(0.1509, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.29batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.64batch/s]


tensor(0.1481, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.29batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.62batch/s]


tensor(0.1478, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.28batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.60batch/s]


tensor(0.1439, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.31batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.61batch/s]


tensor(0.1447, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.30batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.56batch/s]


tensor(0.1422, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.27batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.55batch/s]


tensor(0.1416, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.26batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.54batch/s]


tensor(0.1440, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.35batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.65batch/s]


tensor(0.1381, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.12batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.61batch/s]


tensor(0.1336, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.08batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.60batch/s]


tensor(0.1375, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.19batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.61batch/s]


tensor(0.1330, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.28batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]


tensor(0.1331, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.28batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.57batch/s]


tensor(0.1300, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.30batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.62batch/s]


tensor(0.1318, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.38batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.56batch/s]


tensor(0.1312, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.34batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.62batch/s]


tensor(0.1306, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.25batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.61batch/s]


tensor(0.1260, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.30batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.62batch/s]


tensor(0.1271, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.41batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.60batch/s]


tensor(0.1275, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.29batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.62batch/s]


tensor(0.1278, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.34batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.53batch/s]


tensor(0.1240, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.32batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.62batch/s]


tensor(0.1248, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.43batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.56batch/s]


tensor(0.1242, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.33batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.58batch/s]


tensor(0.1239, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.32batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.60batch/s]


tensor(0.1215, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.36batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.62batch/s]


tensor(0.1237, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.39batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.60batch/s]


tensor(0.1226, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.18batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.58batch/s]


tensor(0.1201, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.24batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.59batch/s]


tensor(0.1239, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.30batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.57batch/s]


tensor(0.1194, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.26batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.59batch/s]


tensor(0.1191, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.23batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.60batch/s]


tensor(0.1218, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.28batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]


tensor(0.1184, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.41batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.58batch/s]


tensor(0.1193, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.32batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.61batch/s]


tensor(0.1172, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.30batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.56batch/s]


tensor(0.1166, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.35batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.58batch/s]


tensor(0.1197, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.34batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.57batch/s]


tensor(0.1159, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.37batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.65batch/s]


tensor(0.1186, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.21batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.60batch/s]


tensor(0.1196, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.38batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.59batch/s]


tensor(0.1147, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.24batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.64batch/s]


tensor(0.1144, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.37batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]


tensor(0.1172, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.41batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.57batch/s]


tensor(0.1123, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.37batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]


tensor(0.1131, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.28batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]


tensor(0.1118, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.40batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.69batch/s]


tensor(0.1116, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.37batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.65batch/s]


tensor(0.1111, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.34batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.65batch/s]


tensor(0.1135, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.28batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.66batch/s]


tensor(0.1116, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.32batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.58batch/s]


tensor(0.1132, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.37batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.68batch/s]


tensor(0.1096, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.31batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.64batch/s]


tensor(0.1101, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.31batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.61batch/s]


tensor(0.1096, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.25batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.65batch/s]


tensor(0.1105, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.29batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]


tensor(0.1082, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.46batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.66batch/s]


tensor(0.1123, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.20batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.64batch/s]


tensor(0.1090, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.35batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.57batch/s]


tensor(0.1112, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.26batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.64batch/s]


tensor(0.1086, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.27batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.62batch/s]


tensor(0.1074, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.38batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]


tensor(0.1076, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.30batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]


tensor(0.1082, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.36batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.65batch/s]


tensor(0.1074, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.15batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]


tensor(0.1070, device='cuda:0', grad_fn=<DivBackward0>)


100%|██████████| 50/50 [00:04<00:00, 12.36batch/s]
100%|██████████| 50/50 [02:39<00:00,  3.20s/batch]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,█▇▅▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,█▇▅▄▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
psnr,7.38325
ssim,0.02394
train_loss,0.10705
validation_loss,0.10943


#### SR3+

In [12]:
from models.SR3plus.diffusion import GaussianDiffusion
from models.SR3plus.model import UNet
hyperparams = {
    "steps" : 8000,
    "sample_steps" : 100,
    "lr":0.0002,
    "epochs":100,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,
    "model" : "SR3+"
}
model = UNet(3, hyperparams["steps"], channel_expansions= [1, 2, 4, 4, 4, 8, 8, 8])
SR3plus_model = GaussianDiffusion(model, hyperparams["steps"], hyperparams["sample_steps"])
SR3plus_model.to(device)

GaussianDiffusion(
  (model): UNet(
    (emb): GammaEmbedding(
      (linear1): Linear(
        (linear): Linear(in_features=3, out_features=12, bias=True)
      )
      (silu): SiLU()
      (linear2): Linear(
        (linear): Linear(in_features=12, out_features=12, bias=True)
      )
    )
    (conv1): Conv2d(
      (conv): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): ModuleList(
      (0-2): 3 x WideResNetBlock(
        (gn1): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu1): SiLU()
        (conv1): Conv2d(
          (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (silu2): SiLU()
        (linear1): Linear(
          (linear): Linear(in_features=12, out_features=3, bias=True)
        )
        (gn2): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu3): SiLU()
        (dropout): Dropout(p=0.0, inplace=False)

In [13]:
optimizer = torch.optim.Adam(SR3plus_model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

In [14]:
project_name = "SR model benchmarking"
run_name = "SR3+"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=["ssim", "psnr"], model_name="SR3+")
trainer.set_model(SR3plus_model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    trainer.save_model("models_state_dic") 
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

100%|██████████| 148/148 [00:21<00:00,  6.83batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.69batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.97batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.92batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.93batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.87batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.96batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.84batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.95batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.79batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.98batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.84batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.95batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.81batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.98batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.97batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.92batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.91batch/s]
100%|██████████| 148/148 [00:21<00:00,  6.92batch/s]
100%|██████

VBox(children=(Label(value='0.001 MB of 0.054 MB uploaded\r'), FloatProgress(value=0.023535003256526254, max=1…

0,1
psnr,▁
ssim,▁
train_loss,█▆▄▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,█▆▅▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
psnr,6.91966
ssim,0.05656
train_loss,0.28559
validation_loss,0.2854
