In [None]:
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from tqdm import tqdm

from model import NeuralNetwork, Trainer
from model.datasets import UIEB, UIEBChallenging
from model.util import save_batch

In [None]:
# Every image is resized to one fixed square size so samples can be stacked
# into batches, then converted to a tensor.
IMAGE_SIZE = (512, 512)
preprocessing_steps = [transforms.Resize(IMAGE_SIZE), transforms.ToTensor()]
transform = transforms.Compose(preprocessing_steps)
dataset = UIEB(root_dir='data/UIEB', transform=transform)

In [None]:
# Hold out the last TEST_SIZE images as the Test-90 split and train on the rest.
batch_size = 4
TEST_SIZE = 90
n_train = len(dataset) - TEST_SIZE
assert n_train > 0, "dataset is smaller than the held-out split"
# Subset needs only len() and integer indexing from the dataset, so this works
# whether or not UIEB supports slice indexing (dataset[:-90]).
train_subset = Subset(dataset, range(n_train))
validate_subset = Subset(dataset, range(n_train, len(dataset)))
train_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
validate_dataloader = DataLoader(validate_subset, batch_size=batch_size, shuffle=False)

In [None]:
# Pick the compute device once, as a device-type string (e.g. "cuda" or "cpu").
# `torch.accelerator` only exists in newer PyTorch releases; fall back to the
# CUDA/CPU check when it is missing or reports no accelerator.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Seed torch's global RNG before building the model. This makes the weight
# initialisation reproducible, and also the shuffled training order, which
# the DataLoader draws from the global RNG when training iterates it.
SEED = 42
torch.manual_seed(SEED)

model = NeuralNetwork().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# Path to a checkpoint to resume training from; leave as None to train from scratch.
RESUME_CHECKPOINT = None

trainer = Trainer(model, optimizer, device, checkpoint_dir="checkpoints/uieb")
if RESUME_CHECKPOINT is not None:
    trainer.load_checkpoint(RESUME_CHECKPOINT)

# NOTE(review): val_loader is the same 90-image split reported later as
# Test-90. If Trainer selects or saves checkpoints based on validation metrics,
# the Test-90 scores are optimistic — confirm against Trainer.train.
trainer.train(
    train_loader=train_dataloader,
    val_loader=validate_dataloader,
    num_epochs=50,
    use_supervision=True,
    save_interval=10,
)

In [None]:
# Run the trained model over the held-out Test-90 split and write both the
# model outputs and the paired reference images to disk for scoring.
model.eval()
with torch.no_grad():
    for inputs, references in tqdm(validate_dataloader, desc='Test-90'):
        inputs, references = inputs.to(device), references.to(device)
        # The model returns a pair; only the first element is saved.
        outputs = model(inputs)[0]
        save_batch(outputs, "tests/mdnet-uieb/test-90/outputs")
        save_batch(references, "tests/mdnet-uieb/test-90/references")

In [None]:
!python evaluate.py "tests/mdnet-uieb/test-90/output" "tests/mdnet-uieb/test-90/references"

In [None]:
# Enhance the UIEB challenging set. Its loader yields inputs only (no
# references), so only the model outputs are saved.
# Distinct names keep the paired UIEB `dataset` intact, so re-running the
# train/test split cell never silently splits this dataset instead.
# NOTE(review): unlike UIEB above, no `transform` is passed here — confirm
# UIEBChallenging yields equally-sized tensors by default, otherwise batching
# with batch_size > 1 will fail.
challenging_dataset = UIEBChallenging(root_dir='data/UIEB')
challenging_dataloader = DataLoader(challenging_dataset, batch_size=batch_size)
with torch.no_grad():
    model.eval()
    for batch in tqdm(challenging_dataloader, desc='UIEB Challenging Test'):
        inputs = batch.to(device)

        outputs, _ = model(inputs)

        save_batch(outputs, "tests/mdnet-uieb/challenging")

In [None]:
# NOTE(review): this calls `nevaluate.py`, whereas the Test-90 cell calls
# `evaluate.py`. Confirm the different script is intended: only outputs (no
# references) are saved for this set, so presumably it is a no-reference
# evaluator.
!python nevaluate.py "tests/mdnet-uieb/challenging"