In [2]:
import os

import numpy as np
import PIL
import torch
import torchvision
import wandb
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import random_split
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchvision.transforms.v2 import Compose, GaussianBlur, RandomEqualize, RandomSolarize, RandomApply

from Dataset.AerialDataset import AerialDataset
from models.SR3Builder import SR3Builder
from tasks.SR3Trainer import SR3Trainer
from utils.DDP_utils import ddp_setup
from utils.model_utils import load_model

In [11]:
# Initialization: distributed setup, dataset splits and one DataLoader per split.
# ddp_setup() is a project helper; presumably it initializes the process group
# that DistributedSampler relies on below -- TODO confirm.
ddp_setup()

lr_size, hr_size = 64, 256
batch_size = 64
dataset_dir = 'C:\\Users\\adrianperera\\Desktop\\dataset_tfg'
sat_dataset_dir = "C:\\Users\\adrianperera\\Desktop\\dataset_tfg\\satelite_dataset\\64_256"

# Augmentation pipeline. It is built here but NOT handed to the dataset in
# this cell (data_augmentation=None below).
transforms = Compose([
    RandomApply(transforms=[GaussianBlur(7)], p=0.5),
    RandomEqualize(),
])

dataset = AerialDataset(
    dataset_dir,
    lr_size,
    hr_size,
    data_augmentation=None,
    aux_sat_prob=0,
    sat_dataset_path=sat_dataset_dir,
)

# A fixed generator seed makes the 60/20/20 partition repeatable.
split_rng = torch.Generator().manual_seed(420)
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2], generator=split_rng)

# Each split gets its own DistributedSampler-backed loader.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(subset, batch_size=batch_size, sampler=DistributedSampler(subset))
    for subset in (train_dataset, val_dataset, test_dataset)
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [10]:
# Build the standard SR3 network and wrap it in DistributedDataParallel.
model_builder = SR3Builder()
model_builder = model_builder.set_standart()

# DDP's device_ids needs the GPU index owned by this process. The original
# cell used an undefined name `rank`. torchrun-style launchers export the
# index as LOCAL_RANK; fall back to 0 for a single-process run.
local_rank = int(os.environ.get("LOCAL_RANK", 0))

if device.type == "cuda":
    # The module must already live on its target GPU when it is wrapped.
    model = DDP(model_builder.build().to(torch.device("cuda", local_rank)), device_ids=[local_rank])
else:
    # CPU modules are wrapped without device_ids.
    model = DDP(model_builder.build().to(device))

# Last expression: show the wrapped module in the cell output.
model

In [11]:
# Run configuration for the DDP SR3 baseline (also passed to wandb as config
# by the training cell).
hyperparams = dict(
    lr=0.0002,
    epochs=100,
    eta_min=1e-7,  # recorded only: nothing in this cell consumes it
    decay_steps=100000,
    gamma=0.5,
    model="SR3",
    DDP=True,
)
# Merge in the settings reported by the model builder.
hyperparams.update(model_builder.get_hyperparameters())

optimizer = torch.optim.Adam(params=model.parameters(), lr=hyperparams["lr"])
# StepLR multiplies the learning rate by `gamma` every `step_size` scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=hyperparams["decay_steps"],
    gamma=hyperparams["gamma"],
)

In [12]:
# Train, validate and test the DDP SR3 baseline, logging to Weights & Biases.
project_name = "SR model benchmarking"
run_name = "SR3 standart"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3 Standart")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)

last_epoch = hyperparams["epochs"] - 1
for step in range(hyperparams["epochs"]):
    # DistributedSampler seeds its shuffle from the epoch number; without
    # set_epoch() every epoch iterates the samples in the same order.
    if isinstance(train_dataloader.sampler, DistributedSampler):
        train_dataloader.sampler.set_epoch(step)

    # Validation runs BEFORE this epoch's training pass, so the logged
    # validation loss belongs to the weights left by the previous epoch.
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)

    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    # Checkpoint every 10 epochs and on the final epoch, so the weights that
    # are tested below are also persisted.
    if step % 10 == 0 or step == last_epoch:
        trainer.save_model("saved models\\SR3")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})

test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()

In [13]:
# Standard SR3 configuration with the loss type switched to "l1".
model_builder = SR3Builder().set_standart().set_losstype("l1")
model = model_builder.build()
# Kept as the last expression so the notebook displays the module.
model.to(device)

In [14]:
# Run configuration for the SR3 / l1 experiment.
hyperparams = dict(
    lr=0.0002,
    epochs=100,
    eta_min=1e-7,  # recorded only: nothing in this cell consumes it
    decay_steps=100000,
    gamma=0.5,
    model="SR3",
)
# Merge in the settings reported by the model builder.
hyperparams.update(model_builder.get_hyperparameters())

optimizer = torch.optim.Adam(params=model.parameters(), lr=hyperparams["lr"])
# StepLR multiplies the learning rate by `gamma` every `step_size` scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=hyperparams["decay_steps"],
    gamma=hyperparams["gamma"],
)

In [15]:
# Train, validate and test the SR3 model built with the "l1" loss type.
project_name = "SR model benchmarking"
run_name = "SR3 standart with l1 loss"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

# The model name now says l1, matching the builder setting and the run name
# (it previously read "l2 loss").
trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3 Standart l1 loss")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)

last_epoch = hyperparams["epochs"] - 1
for step in range(hyperparams["epochs"]):
    # DistributedSampler seeds its shuffle from the epoch number; without
    # set_epoch() every epoch iterates the samples in the same order.
    if isinstance(train_dataloader.sampler, DistributedSampler):
        train_dataloader.sampler.set_epoch(step)

    # Validation runs BEFORE this epoch's training pass.
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)

    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    # Checkpoint every 10 epochs and on the final epoch, so the weights that
    # are tested below are also persisted.
    if step % 10 == 0 or step == last_epoch:
        trainer.save_model("saved models\\SR3")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})

test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()

In [16]:
# SR3+ configuration with the builder's default loss type.
model_builder = SR3Builder().set_sr3plus()
model = model_builder.build()
# Kept as the last expression so the notebook displays the module.
model.to(device)

In [17]:
# Run configuration for the SR3+ experiment.
hyperparams = dict(
    lr=0.0002,
    epochs=100,
    eta_min=1e-7,  # recorded only: nothing in this cell consumes it
    decay_steps=100000,
    gamma=0.5,
    model="SR3+",
)
# Merge in the settings reported by the model builder.
hyperparams.update(model_builder.get_hyperparameters())

optimizer = torch.optim.Adam(params=model.parameters(), lr=hyperparams["lr"])
# StepLR multiplies the learning rate by `gamma` every `step_size` scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=hyperparams["decay_steps"],
    gamma=hyperparams["gamma"],
)

In [18]:
# Train, validate and test the SR3+ model, logging to Weights & Biases.
project_name = "SR model benchmarking"
run_name = "SR3+ standart"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3+ standart")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)

last_epoch = hyperparams["epochs"] - 1
for step in range(hyperparams["epochs"]):
    # DistributedSampler seeds its shuffle from the epoch number; without
    # set_epoch() every epoch iterates the samples in the same order.
    if isinstance(train_dataloader.sampler, DistributedSampler):
        train_dataloader.sampler.set_epoch(step)

    # Validation runs BEFORE this epoch's training pass.
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)

    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    # Checkpoint every 10 epochs and on the final epoch, so the weights that
    # are tested below are also persisted.
    if step % 10 == 0 or step == last_epoch:
        trainer.save_model("saved models\\SR3+")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})

test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()

In [19]:
# SR3+ configuration with the loss type switched to "l1".
model_builder = SR3Builder().set_sr3plus().set_losstype("l1")
model = model_builder.build()
# Kept as the last expression so the notebook displays the module.
model.to(device)

In [20]:
# Run configuration for the SR3+ / l1 experiment.
hyperparams = dict(
    lr=0.0002,
    epochs=100,
    eta_min=1e-7,  # recorded only: nothing in this cell consumes it
    decay_steps=100000,
    gamma=0.5,
    model="SR3+",
)
# Merge in the settings reported by the model builder.
hyperparams.update(model_builder.get_hyperparameters())

optimizer = torch.optim.Adam(params=model.parameters(), lr=hyperparams["lr"])
# StepLR multiplies the learning rate by `gamma` every `step_size` scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=hyperparams["decay_steps"],
    gamma=hyperparams["gamma"],
)

In [21]:
# Train, validate and test the SR3+ model built with the "l1" loss type.
project_name = "SR model benchmarking"
run_name = "SR3+ standart with l1 loss"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3+ standart l1 loss")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)

last_epoch = hyperparams["epochs"] - 1
for step in range(hyperparams["epochs"]):
    # DistributedSampler seeds its shuffle from the epoch number; without
    # set_epoch() every epoch iterates the samples in the same order.
    if isinstance(train_dataloader.sampler, DistributedSampler):
        train_dataloader.sampler.set_epoch(step)

    # Validation runs BEFORE this epoch's training pass.
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)

    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    # Checkpoint every 10 epochs and on the final epoch, so the weights that
    # are tested below are also persisted.
    if step % 10 == 0 or step == last_epoch:
        trainer.save_model("saved models\\SR3+")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})

test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()

In [2]:
# Data: augmented dataset served through plain shuffled DataLoaders
# (no DistributedSampler in this cell).
lr_size, hr_size = 64, 256
batch_size = 16
dataset_dir = 'E:\\TFG\\air_dataset'
sat_dataset_dir = "E:\\TFG\\satelite_dataset\\64_256"

# Augmentation: Gaussian blur (kernel size 7) applied with probability 0.5,
# followed by RandomEqualize with its default probability.
transforms = Compose([
    RandomApply(transforms=[GaussianBlur(7)], p=0.5),
    RandomEqualize(),
])

# NOTE(review): the split below is taken from this single augmented dataset,
# so the validation/test subsets presumably see the same augmentation and
# aux_sat_prob settings as training -- confirm this is intended.
dataset = AerialDataset(
    dataset_dir,
    lr_size,
    hr_size,
    data_augmentation=transforms,
    aux_sat_prob=0.6,
    sat_dataset_path=sat_dataset_dir,
)

# A fixed generator seed makes the 60/20/20 partition repeatable.
split_rng = torch.Generator().manual_seed(420)
train_dataset, val_dataset, test_dataset = random_split(dataset, [0.6, 0.2, 0.2], generator=split_rng)

train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(subset, batch_size=batch_size, shuffle=True)
    for subset in (train_dataset, val_dataset, test_dataset)
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
# With data augmentation: standard SR3 with the builder's default loss type.

model_builder = SR3Builder()
model_builder = model_builder.set_standart()
model = model_builder.build()
model.to(device)

# Run configuration; the whole dict is sent to wandb as the run config.
hyperparams = {
    "lr":0.0002,
    "epochs":200,
    "eta_min":1e-7,  # recorded only: StepLR below does not take it
    "decay_steps": 100000,
    "gamma" : 0.5,
    "model" : "SR3"
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

project_name = "SR model benchmarking"
run_name = "SR3 standart with data augmentation"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3 Standart DA")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)

last_epoch = hyperparams["epochs"] - 1
for step in range(hyperparams["epochs"]):
    # Validation runs BEFORE this epoch's training pass.
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)

    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    # Checkpoint every 10 epochs and on the final epoch, so the weights that
    # are tested below are also persisted.
    if step % 10 == 0 or step == last_epoch:
        trainer.save_model("saved models\\SR3")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})

test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()

In [10]:
# With data augmentation: standard SR3, "l1" loss type, gradient accumulation.

model_builder = SR3Builder()
model_builder = model_builder.set_standart()
model_builder = model_builder.set_losstype("l1")
model = model_builder.build()
model.to(device)

# Run configuration; the whole dict is sent to wandb as the run config.
hyperparams = {
    "lr":0.0002,
    "epochs":200,
    "eta_min":1e-7,  # recorded only: StepLR below does not take it
    "decay_steps": 100000,
    "gamma" : 0.5,
    "model" : "SR3",
    "grad_acum": 2  # forwarded to SR3Trainer(grad_acum=...)
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

project_name = "SR model benchmarking"
run_name = "SR3 standart l1, DA, GA"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3 Standart l1 DA GA2", grad_acum=hyperparams["grad_acum"])
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)

last_epoch = hyperparams["epochs"] - 1
for epoch in range(hyperparams["epochs"]):
    # Validation runs BEFORE this epoch's training pass.
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)

    train_loss = trainer.train(train_dataloader, epoch)
    torch.cuda.empty_cache()
    # Checkpoint every 10 epochs and on the final epoch, so the weights that
    # are tested below are also persisted.
    if epoch % 10 == 0 or epoch == last_epoch:
        trainer.save_model("saved models\\SR3")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})

test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()

In [12]:
# With data augmentation: SR3+ with the builder's default loss type and
# gradient accumulation.

model_builder = SR3Builder()
model_builder = model_builder.set_sr3plus()
model = model_builder.build()
model.to(device)

# Run configuration; the whole dict is sent to wandb as the run config.
hyperparams = {
    "lr":0.0002,
    "epochs":200,
    "eta_min":1e-7,  # recorded only: StepLR below does not take it
    "decay_steps": 100000,
    "gamma" : 0.5,
    "model" : "SR3+",
    "grad_acum": 2  # forwarded to SR3Trainer(grad_acum=...)
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

project_name = "SR model benchmarking"
run_name = "SR3+ standart DA GA"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3+ Standart DA GA 2", grad_acum= hyperparams["grad_acum"])
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)

last_epoch = hyperparams["epochs"] - 1
for epoch in range(hyperparams["epochs"]):
    # Validation runs BEFORE this epoch's training pass.
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)

    train_loss = trainer.train(train_dataloader, epoch )
    torch.cuda.empty_cache()
    # Checkpoint every 10 epochs and on the final epoch, so the weights that
    # are tested below are also persisted. SR3+ checkpoints go to the SR3+
    # directory, as in the other SR3+ cells of this notebook (this cell used
    # to write into the SR3 directory).
    if epoch % 10 == 0 or epoch == last_epoch:
        trainer.save_model("saved models\\SR3+")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})

test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()

In [27]:
# With data augmentation: SR3+ with the "l1" loss type.

model_builder = SR3Builder()
model_builder = model_builder.set_sr3plus()
model_builder = model_builder.set_losstype("l1")
model = model_builder.build()
model.to(device)

# Run configuration; the whole dict is sent to wandb as the run config.
hyperparams = {
    "lr":0.0002,
    "epochs":60,
    "eta_min":1e-7,  # recorded only: StepLR below does not take it
    "decay_steps": 100000,
    "gamma" : 0.5,
    "model" : "SR3+"
}
hyperparams.update(model_builder.get_hyperparameters())
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams["decay_steps"], gamma=hyperparams["gamma"])

project_name = "SR model benchmarking"
run_name = "SR3+ standart l1 with data augmentation"
wandb.login()
wandb.init(project=project_name, config=hyperparams, name=run_name)

trainer = SR3Trainer(metrics_used=("ssim", "psnr"), model_name="SR3+ Standart l1 DA")
trainer.set_model(model)
trainer.set_optimizer(optimizer)
trainer.set_scheduler(scheduler)

last_epoch = hyperparams["epochs"] - 1
for step in range(hyperparams["epochs"]):
    # Validation runs BEFORE this epoch's training pass.
    with torch.no_grad():
        val_loss = trainer.validate(val_dataloader)

    train_loss = trainer.train(train_dataloader)
    torch.cuda.empty_cache()
    # Checkpoint every 10 epochs and on the final epoch, so the weights that
    # are tested below are also persisted. SR3+ checkpoints go to the SR3+
    # directory, as in the other SR3+ cells of this notebook (this cell used
    # to write into the SR3 directory).
    if step % 10 == 0 or step == last_epoch:
        trainer.save_model("saved models\\SR3+")

    torch.cuda.empty_cache()
    wandb.log({"train_loss": train_loss, "validation_loss": val_loss})

test_metrics = trainer.test(test_dataloader)
wandb.log(test_metrics)
wandb.finish()