In [1]:
import torch
from torch import Tensor
from torch.utils import data as torch_data
from torch.amp import autocast, GradScaler

from models import UNet
from datasets import PreprocessedOpenFWI

In [None]:
# Reproducibility: seed torch's RNG before weight init and DataLoader shuffling.
SEED = 0
BATCH_SIZE = 128
torch.manual_seed(SEED)

# model: UNet with 5 input channels and 1 output channel, moved to the GPU
model = UNet(
    in_channels=5,
    out_channels=1,
    start_features=32,
    depth=4
).cuda()
# data
train_dataset = PreprocessedOpenFWI(train=True, norm_output=True, nb_files_to_load=100)
train_loader = torch_data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = PreprocessedOpenFWI(train=False, norm_output=True)
# Build the test loader from the held-out split (it previously wrapped train_dataset);
# evaluation does not need shuffling.
test_loader = torch_data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Training: AdamW plus a GradScaler for mixed-precision (autocast) training on CUDA
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scaler = GradScaler(device="cuda")

Output()

Output()

Output()

Output()

In [None]:
from rich.progress import Task  # NOTE(review): currently unused in this cell

# Global optimizer-step counter, accumulated across all epochs.
step = 0
NB_EPOCHS = 100
for epoch in range(NB_EPOCHS):
    total_loss = 0.0
    total_accuracy = 0  # NOTE(review): never updated below — placeholder, remove or implement
    # Count batches/samples while iterating instead of relying on len(train_loader)
    # or len(train_loader.dataset), in case the data loader augments the number of samples.
    nb_batches = 0
    nb_samples = 0
    # Nothing in the loop switches the model to eval mode, so set train mode once per epoch.
    model.train()
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.cuda()
        batch_y = batch_y.cuda()
        nb_batches += 1
        nb_samples += len(batch_x)
        optimizer.zero_grad()
        # Mixed-precision forward pass.
        with autocast(device_type="cuda"):
            batch_y_pred = model(batch_x)

        # Mean absolute error (L1) between prediction and target.
        loss_value = (batch_y_pred - batch_y).abs().mean()
        # GradScaler scales the loss before backward; step() unscales the gradients and
        # skips the optimizer update when inf/NaN gradients are found.
        scaler.scale(loss_value).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss_value.item()
        step += 1
    # Average of per-batch mean losses, using the counted batches (guarded for an empty loader).
    print("loss:", total_loss / max(nb_batches, 1))


loss: 0.25165528566491024
loss: 0.18944680750469534
loss: 0.1700423829211881
loss: 0.15913278179561421
loss: 0.1517887614339872


In [None]:
import numpy as np
import plotly.express as px

# Show ground truth (first facet row) next to predictions (second facet row) for a
# slice of the last training batch. `batch_y` and `batch_y_pred` are left over from
# the training-loop cell, so this cell must run after it.
OFFSET = 50
IMGS_TO_SHOW = 10

# The last batch can be smaller than the others (DataLoader drop_last defaults to False),
# so clamp the display window to rows that actually exist.
nb_rows = batch_y.shape[0]
start = max(min(OFFSET, nb_rows - IMGS_TO_SHOW), 0)
end = min(start + IMGS_TO_SHOW, nb_rows)

y_true_to_display = (
    batch_y
    .detach()
    .cpu()
    .numpy()
    [start:end, 0, ...]
)

# detach() before cpu() so the device-to-host copy is not tracked by autograd.
y_pred_to_display = (
    batch_y_pred
    .detach()
    .float()  # autocast output may be half precision; upcast before plotting
    .cpu()
    .numpy()
    [start:end, 0, ...]
)

print(y_true_to_display.shape)

# Wrap on the number of images actually shown so that truths and predictions
# each occupy exactly one row and stay aligned column-by-column.
nb_shown = len(y_true_to_display)
px.imshow(
    np.concatenate(
        (y_true_to_display,
        y_pred_to_display)
    ),
    facet_col=0,
    facet_col_wrap=nb_shown,
)

(10, 70, 70)
