In [1]:
import PIL
import numpy as np
import torch
import torchvision
from torchvision.transforms.v2 import Compose, GaussianBlur, RandomEqualize, RandomSolarize, RandomApply
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torch.nn.parallel import DistributedDataParallel as DDP

import wandb
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import random_split

from Dataset.AerialDataset import AerialDataset
from tasks.SR3Trainer import SR3Trainer
from models.SR3Builder import SR3Builder
from utils.model_utils import load_model

  return torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count


ModuleNotFoundError: No module named 'torchmetrics'

In [29]:
#Data
lr_size = 64
hr_size = 256
batch_size = 64
dataset_dir = 'E:\\TFG\\air_dataset'

transforms = Compose(
    [RandomApply(transforms= [GaussianBlur(7)], p = 0.5),
    RandomEqualize()]
)

dataset = AerialDataset(dataset_dir, lr_size, hr_size, data_augmentation = None, aux_sat_prob= 0, sat_dataset_path= "E:\\TFG\\satelite_dataset\\64_256")
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2], generator=torch.Generator().manual_seed(420))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler= DistributedSampler(train_dataset))
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, sampler= DistributedSampler(val_dataset))
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, sampler= DistributedSampler(test_dataset))
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [10]:
model_builder = SR3Builder()
model_builder = model_builder.set_standart()
model = DDP(model_builder.build(), device_ids=[rank])#Definir q id tengo q poner
model.to(device)

GaussianDiffusion(
  (model): UNet(
    (emb): GammaEmbedding(
      (linear1): Linear(
        (linear): Linear(in_features=3, out_features=12, bias=True)
      )
      (silu): SiLU()
      (linear2): Linear(
        (linear): Linear(in_features=12, out_features=12, bias=True)
      )
    )
    (conv1): Conv2d(
      (conv): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): ModuleList(
      (0-2): 3 x WideResNetBlock(
        (gn1): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu1): SiLU()
        (conv1): Conv2d(
          (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (silu2): SiLU()
        (linear1): Linear(
          (linear): Linear(in_features=12, out_features=3, bias=True)
        )
        (gn2): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu3): SiLU()
        (dropout): Dropout(p=0, inplace=False)
 

In [11]:
hyperparams = {
    "lr":0.0002,
    "epochs":100,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SR3",
    "DDP": True
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

In [12]:
project_name = "SR model benchmarking"
run_name = "SR3 standart"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3 Standart")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    if step % 10 == 0:
        trainer.save_model("saved models\\SR3")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train_loss,▁
validation_loss,▁

0,1
train_loss,0.93045
validation_loss,1.00012


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888925108, max=1.0…

100%|██████████| 50/50 [00:04<00:00, 12.01batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.70batch/s]
100%|██████████| 50/50 [00:03<00:00, 12.65batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.70batch/s]
100%|██████████| 50/50 [00:03<00:00, 12.59batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.66batch/s]
100%|██████████| 50/50 [00:03<00:00, 12.58batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]
100%|██████████| 50/50 [00:04<00:00, 12.48batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.61batch/s]
100%|██████████| 50/50 [00:04<00:00, 12.44batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]
100%|██████████| 50/50 [00:04<00:00, 12.49batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.58batch/s]
100%|██████████| 50/50 [00:04<00:00, 12.40batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.60batch/s]
100%|██████████| 50/50 [00:03<00:00, 12.53batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]
100%|██████████| 50/50 [00:04<00:00, 12.34batch/s]
100%|████████

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,█▆▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,█▆▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
psnr,7.86421
ssim,0.02554
train_loss,0.03124
validation_loss,0.02627


In [13]:
model_builder = SR3Builder()
model_builder = model_builder.set_standart()
model_builder = model_builder.set_losstype("l1")
model = model_builder.build()
model.to(device)

GaussianDiffusion(
  (model): UNet(
    (emb): GammaEmbedding(
      (linear1): Linear(
        (linear): Linear(in_features=3, out_features=12, bias=True)
      )
      (silu): SiLU()
      (linear2): Linear(
        (linear): Linear(in_features=12, out_features=12, bias=True)
      )
    )
    (conv1): Conv2d(
      (conv): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): ModuleList(
      (0-2): 3 x WideResNetBlock(
        (gn1): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu1): SiLU()
        (conv1): Conv2d(
          (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (silu2): SiLU()
        (linear1): Linear(
          (linear): Linear(in_features=12, out_features=3, bias=True)
        )
        (gn2): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu3): SiLU()
        (dropout): Dropout(p=0, inplace=False)
 

In [14]:
hyperparams = {
    "lr":0.0002,
    "epochs":100,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SR3"
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

In [15]:
project_name = "SR model benchmarking"
run_name = "SR3 standart with l1 loss"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3 Standart l2 loss")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    if step % 10 == 0:
        trainer.save_model("saved models\\SR3")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

100%|██████████| 50/50 [00:04<00:00, 11.69batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.68batch/s]
100%|██████████| 50/50 [00:04<00:00, 12.46batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.68batch/s]
100%|██████████| 50/50 [00:03<00:00, 12.52batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.64batch/s]
100%|██████████| 50/50 [00:04<00:00, 12.34batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.58batch/s]
100%|██████████| 50/50 [00:03<00:00, 12.53batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]
100%|██████████| 50/50 [00:04<00:00, 12.43batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.66batch/s]
100%|██████████| 50/50 [00:04<00:00, 12.50batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.62batch/s]
100%|██████████| 50/50 [00:04<00:00, 12.48batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.64batch/s]
100%|██████████| 50/50 [00:04<00:00, 12.49batch/s]
100%|██████████| 148/148 [00:19<00:00,  7.63batch/s]
100%|██████████| 50/50 [00:04<00:00, 12.44batch/s]
100%|████████

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,█▇▅▄▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,█▇▅▄▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
psnr,7.98456
ssim,0.03851
train_loss,0.06455
validation_loss,0.05975


In [16]:
model_builder = SR3Builder()
model_builder = model_builder.set_sr3plus()
model = model_builder.build()
model.to(device)

GaussianDiffusion(
  (model): UNet(
    (emb): GammaEmbedding(
      (linear1): Linear(
        (linear): Linear(in_features=3, out_features=12, bias=True)
      )
      (silu): SiLU()
      (linear2): Linear(
        (linear): Linear(in_features=12, out_features=12, bias=True)
      )
    )
    (conv1): Conv2d(
      (conv): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): ModuleList(
      (0-4): 5 x WideResNetBlock(
        (gn1): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu1): SiLU()
        (conv1): Conv2d(
          (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (silu2): SiLU()
        (linear1): Linear(
          (linear): Linear(in_features=12, out_features=3, bias=True)
        )
        (gn2): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu3): SiLU()
        (dropout): Dropout(p=0, inplace=False)
 

In [17]:
hyperparams = {
    "lr":0.0002,
    "epochs":100,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SR3+"
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

In [18]:
project_name = "SR model benchmarking"
run_name = "SR3+ standart"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3+ standart")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    if step % 10 == 0:
        trainer.save_model("saved models\\SR3+")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01127777777777131, max=1.0)…

100%|██████████| 50/50 [00:04<00:00, 10.49batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.13batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.26batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.08batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.37batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.15batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.30batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.15batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.34batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.15batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.27batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.15batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.17batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.12batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.26batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.14batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.24batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.12batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.35batch/s]
100%|████████

VBox(children=(Label(value='0.001 MB of 0.060 MB uploaded\r'), FloatProgress(value=0.021472931334514743, max=1…

0,1
psnr,▁
ssim,▁
train_loss,█▆▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,█▆▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
psnr,4.26679
ssim,0.00247
train_loss,0.00922
validation_loss,0.00852


In [19]:
model_builder = SR3Builder()
model_builder = model_builder.set_sr3plus()
model_builder = model_builder.set_losstype("l1")
model = model_builder.build()
model.to(device)

GaussianDiffusion(
  (model): UNet(
    (emb): GammaEmbedding(
      (linear1): Linear(
        (linear): Linear(in_features=3, out_features=12, bias=True)
      )
      (silu): SiLU()
      (linear2): Linear(
        (linear): Linear(in_features=12, out_features=12, bias=True)
      )
    )
    (conv1): Conv2d(
      (conv): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (down): ModuleList(
      (0-4): 5 x WideResNetBlock(
        (gn1): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu1): SiLU()
        (conv1): Conv2d(
          (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (silu2): SiLU()
        (linear1): Linear(
          (linear): Linear(in_features=12, out_features=3, bias=True)
        )
        (gn2): GroupNorm(
          (group_norm): GroupNorm(3, 3, eps=1e-05, affine=True)
        )
        (silu3): SiLU()
        (dropout): Dropout(p=0, inplace=False)
 

In [20]:
hyperparams = {
    "lr":0.0002,
    "epochs":100,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SR3+"
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

In [21]:
project_name = "SR model benchmarking"
run_name = "SR3+ standart with l1 loss"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3+ standart l1 loss")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    if step % 10 == 0:
        trainer.save_model("saved models\\SR3+")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888925108, max=1.0…

100%|██████████| 50/50 [00:04<00:00, 10.62batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.05batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.04batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.11batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.19batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.09batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.22batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.11batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.22batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.07batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.08batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.06batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.06batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.12batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.17batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.11batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.11batch/s]
100%|██████████| 148/148 [00:24<00:00,  6.11batch/s]
100%|██████████| 50/50 [00:04<00:00, 11.21batch/s]
100%|████████

VBox(children=(Label(value='0.001 MB of 0.060 MB uploaded\r'), FloatProgress(value=0.02142299667107338, max=1.…

0,1
psnr,▁
ssim,▁
train_loss,█▇▅▄▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,█▇▆▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
psnr,4.68663
ssim,0.00405
train_loss,0.03587
validation_loss,0.03659


In [2]:
#Data
lr_size = 64
hr_size = 256
batch_size = 16
dataset_dir = 'E:\\TFG\\air_dataset'

transforms = Compose(
    [RandomApply(transforms= [GaussianBlur(7)], p = 0.5),
    RandomEqualize()]
)

dataset = AerialDataset(dataset_dir, lr_size, hr_size, data_augmentation = transforms, aux_sat_prob= 0.6, sat_dataset_path= "E:\\TFG\\satelite_dataset\\64_256")
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2], generator=torch.Generator().manual_seed(420))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
#Con data augmentation

model_builder = SR3Builder()
model_builder = model_builder.set_standart()
model = model_builder.build()
model.to(device)

hyperparams = {
    "lr":0.0002,
    "epochs":200,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SR3"
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

project_name = "SR model benchmarking"
run_name = "SR3 standart with data augmentation"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3 Standart DA")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    if step % 10 == 0:
        trainer.save_model("saved models\\SR3")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

100%|██████████| 50/50 [00:05<00:00,  9.50batch/s]


TypeError: SR3Trainer.train() missing 1 required positional argument: 'epoch'

In [10]:
#Con data augmentation

model_builder = SR3Builder()
model_builder = model_builder.set_standart()
model_builder = model_builder.set_losstype("l1")
model = model_builder.build()
model.to(device)

hyperparams = {
    "lr":0.0002,
    "epochs":200,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SR3",
    "grad_acum": 2
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

project_name = "SR model benchmarking"
run_name = "SR3 standart l1, DA, GA"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3 Standart l1 DA GA2", grad_acum=hyperparams["grad_acum"])
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for epoch in range(hyperparams["epochs"]):  
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    
    train_loss = trainer.train(train_dataloader, epoch)
    torch.cuda.empty_cache()
    if epoch % 10 == 0:
        trainer.save_model("saved models\\SR3")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777609622, max=1.0…

100%|██████████| 50/50 [00:04<00:00, 10.15batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.68batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.96batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.69batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.61batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.67batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.16batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.68batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.42batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.61batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.63batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.64batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.62batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.48batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.35batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.69batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.71batch/s]
100%|██████████| 148/148 [00:22<00:00,  6.71batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.76batch/s]
100%|████████

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,██▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁
validation_loss,██▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
psnr,5.33682
ssim,0.00347
train_loss,0.10685
validation_loss,0.21282


In [12]:
#Con data augmentation

model_builder = SR3Builder()
model_builder = model_builder.set_sr3plus()
model = model_builder.build()
model.to(device)

hyperparams = {
    "lr":0.0002,
    "epochs":200,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SR3+",
    "grad_acum": 2
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

project_name = "SR model benchmarking"
run_name = "SR3+ standart DA GA"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3+ Standart DA GA 2", grad_acum= hyperparams["grad_acum"])
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for epoch in range(hyperparams["epochs"]):
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    
    train_loss = trainer.train(train_dataloader, epoch )
    torch.cuda.empty_cache()
    if epoch % 10 == 0:
        trainer.save_model("saved models\\SR3")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

100%|██████████| 50/50 [00:05<00:00,  9.28batch/s]
100%|██████████| 148/148 [00:26<00:00,  5.61batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.14batch/s]
100%|██████████| 148/148 [00:25<00:00,  5.73batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.01batch/s]
100%|██████████| 148/148 [00:26<00:00,  5.66batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.15batch/s]
100%|██████████| 148/148 [00:25<00:00,  5.77batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.02batch/s]
100%|██████████| 148/148 [00:26<00:00,  5.59batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.05batch/s]
100%|██████████| 148/148 [00:25<00:00,  5.72batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.12batch/s]
100%|██████████| 148/148 [00:26<00:00,  5.65batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.04batch/s]
100%|██████████| 148/148 [00:25<00:00,  5.75batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.05batch/s]
100%|██████████| 148/148 [00:26<00:00,  5.59batch/s]
100%|██████████| 50/50 [00:04<00:00, 10.01batch/s]
100%|████████

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,██▇▇▆▆▅▅▅▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
validation_loss,██▇▇▆▆▅▅▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
psnr,4.34428
ssim,0.00203
train_loss,0.04324
validation_loss,0.08925


In [27]:
#Con data augmentation

model_builder = SR3Builder()
model_builder = model_builder.set_sr3plus()
model_builder = model_builder.set_losstype("l1")
model = model_builder.build()
model.to(device)

hyperparams = {
    "lr":0.0002,
    "epochs":60,
    "eta_min":1e-7,
    "decay_steps": 100000,
    "gamma" : 0.5,  
    "model" : "SR3+"
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

project_name = "SR model benchmarking"
run_name = "SR3+ standart l1 with data augmentation"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3+ Standart l1 DA")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)
for step in range(hyperparams["epochs"]):
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)
    
    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    if step % 10 == 0:
        trainer.save_model("saved models\\SR3")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})
    
test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()   

100%|██████████| 50/50 [25:34<00:00, 30.70s/batch]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
psnr,▁
ssim,▁
train_loss,█▇▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,█▇▆▆▅▅▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
psnr,3.9947
ssim,0.00195
train_loss,0.05619
validation_loss,0.056
