In [None]:
import os
import numpy as np
import math

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F

import optuna
# optuna.logging.set_verbosity(optuna.logging.WARNING)
from sklearn.metrics import r2_score

from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass, field
from skimage.util import view_as_windows
from collections import defaultdict

import matplotlib.pyplot as plt
import shutil

plt.rcParams['figure.dpi'] = 200
# MacTeX installs its binaries here. Prepend it only once so re-running this cell does
# not keep growing PATH.
_TEX_BIN = '/Library/TeX/texbin'
_path_entries = os.environ.get('PATH', '').split(os.pathsep)
if _TEX_BIN not in _path_entries:
    os.environ['PATH'] = os.pathsep.join([_TEX_BIN, *filter(None, _path_entries)])
# Enable LaTeX text rendering only when a `latex` executable is reachable; otherwise
# every figure would fail when it is drawn.
plt.rcParams['text.usetex'] = shutil.which('latex') is not None
plt.style.use('dark_background')

# old

#### read data

In [None]:
@dataclass
class SatelliteDataset(Dataset):
    """Patch dataset for the raw-raster pipeline: normalised images plus one-channel masks.

    Fields:
        images: array indexed by patch. Each item is converted to float32 and normalised
            per channel with the class-level ``mean``/``std`` (4 entries, so items are
            presumably (4, H, W) -- TODO confirm against the caller).
        masks: array indexed by patch. Each item has its leading singleton axes dropped
            down to (H, W) and is returned as (1, H, W).
        transform: accepted for API compatibility but NOT applied by ``__getitem__``.
    """

    images: np.ndarray
    masks: np.ndarray
    transform: Optional[callable] = None

    # Class attributes (no annotations, so not dataclass fields); broadcast over (C, H, W).
    mean = torch.tensor([0.485, 0.456, 0.406, 0.5], dtype=torch.float32)[:, None, None]
    std = torch.tensor([0.229, 0.224, 0.225, 0.25], dtype=torch.float32)[:, None, None]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(img, msk)`` for patch ``idx``; ``msk`` has shape (1, H, W).

        Raises:
            ValueError: if a mask has a non-singleton axis in front of (H, W).
                ``squeeze(0)`` is a no-op on such a mask, so the reduction loop
                would otherwise never terminate.
        """
        img = torch.from_numpy(self.images[idx]).float()
        img = (img - self.mean) / self.std
        msk = torch.from_numpy(self.masks[idx]).float()

        while msk.ndim > 2:
            if msk.shape[0] != 1:
                raise ValueError(
                    f"mask for item {idx} has shape {tuple(msk.shape)}; expected (H, W) "
                    "with only singleton leading axes"
                )
            msk = msk.squeeze(0)
        return img, msk.unsqueeze(0)

#### layers

In [None]:
class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages. ``padding=1`` keeps H and W unchanged.

    Deliberately a plain ``nn.Module``, not a dataclass. A dataclass-generated ``__eq__``
    makes all instances compare equal, which conflicts with how PyTorch tracks submodules.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

#### simple unet

In [None]:
class UNet(nn.Module):
    """U-Net: one encoder level per entry in ``features``, a bottleneck, and a mirrored decoder.

    Each encoder level is ``DoubleConv`` followed by 2x2 max-pooling. Each decoder level is
    a 2x2 transposed conv, a concatenation with the matching encoder output, and a
    ``DoubleConv``. A final 1x1 conv maps to ``out_channels``.

    Plain ``nn.Module`` (no ``@dataclass``), so instances keep identity equality and hashing.
    NOTE(review): the default ``in_channels=5`` differs from the 4-band data used in this
    notebook; callers here pass ``in_channels=4`` explicitly.
    """

    def __init__(self, in_channels=5, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature
            self.downs.append(nn.MaxPool2d(kernel_size=2, stride=2))

        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder: remember each pre-pool activation for the skip connections.
        skip = []
        for idx in range(0, len(self.downs), 2):
            conv = self.downs[idx](x)
            skip.append(conv)
            x = self.downs[idx+1](conv)

        x = self.bottleneck(x)
        skip = skip[::-1]

        # Decoder: upsample, fix odd-size mismatches, concatenate the skip, convolve.
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            sc = skip[idx//2]
            if x.shape[2:] != sc.shape[2:]:
                x = F.interpolate(x, size=sc.shape[2:], mode='nearest')
            x = torch.cat([sc, x], dim=1)
            x = self.ups[idx+1](x)

        return self.final_conv(x)

#### train

In [None]:
def train_model(model: nn.Module, train_loader: DataLoader,
                val_loader: DataLoader, device: torch.device,
                epochs: int = 50, lr: float = 1e-4):
    """Train ``model`` with MSE loss and Adam, evaluating on ``val_loader`` after every epoch.

    Args:
        model: network to train; moved to ``device`` in place.
        train_loader: yields ``(images, targets)`` batches used for optimisation.
        val_loader: yields ``(images, targets)`` batches evaluated without gradients.
        device: device the model and each batch are moved to.
        epochs: number of epochs to run.
        lr: Adam learning rate.

    Returns:
        ``(tl, tt)``: per-epoch, sample-weighted mean train and validation losses, each of
        length ``epochs``. A loader with an empty dataset (e.g. the validation set left by
        a ``math.ceil`` 80/20 split of a single patch) yields ``nan`` instead of raising
        ``ZeroDivisionError``.
    """
    criterion = nn.MSELoss()  # TODO(review): confirm MSE is the intended loss for these targets
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def _mean_loss(total, loader):
        # Per-sample mean; NaN when the loader's dataset is empty.
        n = len(loader.dataset)
        return total / n if n else float('nan')

    tl, tt = [], []
    for epoch in range(1, epochs + 1):
        model.train()
        train_total = 0.0
        for imgs, masks in train_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            preds = model(imgs)
            loss = criterion(preds, masks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_total += loss.item() * imgs.size(0)
        tl.append(_mean_loss(train_total, train_loader))

        model.eval()
        val_total = 0.0
        with torch.no_grad():
            for imgs, masks in val_loader:
                imgs, masks = imgs.to(device), masks.to(device)
                preds = model(imgs)
                val_total += criterion(preds, masks).item() * imgs.size(0)
        tt.append(_mean_loss(val_total, val_loader))

    return tl, tt

#### data

In [None]:
dataset = np.load('chesapeake_dataset.npz')

# Stack the four bands (B2, B3, B4, B8) along a new leading axis; the target is the
# in-situ layer.
X = np.array([dataset[band] for band in ('B2', 'B3', 'B4', 'B8')])
y = dataset['in_situ']

In [None]:
# Display the raw B2 band array read from the .npz archive (inspection only).
dataset['B2']

In [None]:
# Shapes of the band stack (bands on axis 0) and of the in-situ target.
X.shape, y.shape

----

In [None]:
# Bands on the last axis so that view_as_windows can tile (rows, cols, bands).
images = np.stack([dataset[band] for band in ('B2', 'B3', 'B4', 'B8')], axis=-1)
print("images.shape =", images.shape)

In [None]:
# Define the tiling parameters here: this cell ran before any cell assigned
# `patch_size`, and `step_spatial` is not assigned anywhere, so a fresh kernel raised
# NameError.
patch_size = 43
step_spatial = patch_size  # NOTE(review): assumes non-overlapping tiles, like the `step = 43` cell; confirm

window_shape = (patch_size, patch_size, images.shape[2])  # (43,43,4)
step = (step_spatial, step_spatial, images.shape[2])

patches = view_as_windows(
    images,
    window_shape=window_shape,
    step=step
)

In [None]:
# NOTE(review): `mask` is not assigned in any visible cell, so this raised NameError on a
# fresh kernel. Assumed to be the 2-D in-situ label raster -- confirm.
mask = dataset['in_situ']
images.shape, mask.shape

In [None]:
patch_size = 43
step = 43

# Non-overlapping 43x43 tiles spanning all 4 bands, flattened to (n_patches, 43, 43, 4).
img_patches = view_as_windows(images, (patch_size, patch_size, 4), step).reshape(
    -1, patch_size, patch_size, 4
)

In [None]:
# Same tiling as the images, applied to the 2-D mask: (n_patches, 43, 43).
mask_patches = view_as_windows(mask, (patch_size, patch_size), step).reshape(
    -1, patch_size, patch_size
)

In [None]:
# (N, H, W, C) -> (N, C, H, W) for PyTorch; masks gain a channel axis: (N, 1, H, W).
X = img_patches.transpose(0, 3, 1, 2)
y = np.expand_dims(mask_patches, 1)

In [None]:
N = X.shape[0]
split_idx = int(0.8 * N)
# Contiguous (unshuffled) split: first 80% of patches train, the rest validate.
train_imgs, val_imgs = np.split(X, [split_idx])
train_masks, val_masks = np.split(y, [split_idx])

In [None]:
# Sanity-check the shapes produced by the 80/20 split.
train_imgs.shape, val_imgs.shape, train_masks.shape, val_masks.shape

#### split

In [None]:
split_idx = int(len(X) * 0.8)
# Same contiguous 80/20 split as above, sized with len(X).
train_imgs, val_imgs = np.split(X, [split_idx])
train_masks, val_masks = np.split(y, [split_idx])

In [None]:
N = X.shape[0]
train_frac = 0.8
# Round up so a single patch (N == 1) still lands in training; validation is then empty.
split_idx = math.ceil(N * train_frac)

train_imgs, val_imgs = np.split(X, [split_idx])
train_masks, val_masks = np.split(y, [split_idx])

In [None]:
# Shapes after the ceil-based split (validation can be empty for tiny N).
train_imgs.shape, val_imgs.shape, train_masks.shape, val_masks.shape

In [None]:
# Per-channel normalisation for the 4 input bands (B2, B3, B4, B8), matching the mean/std
# used inside SatelliteDataset. ToTensor is omitted: it expects HWC images, whereas the
# patches here are already (C, H, W).
# NOTE(review): SatelliteDataset accepts `transform` but never applies it.
transform = transforms.Compose([
    transforms.Normalize(mean=[0.485, 0.456, 0.406, 0.5],
                         std=[0.229, 0.224, 0.225, 0.25]),
])

In [None]:
train_dataset, val_dataset = (
    SatelliteDataset(imgs, msks, transform=transform)
    for imgs, msks in ((train_imgs, train_masks), (val_imgs, val_masks))
)

In [None]:
# Batches of 8; training order is kept fixed (no shuffling).
val_loader = DataLoader(val_dataset, batch_size=8)
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=8)

Use MPS (Apple Metal) for GPU acceleration

In [None]:
# True when PyTorch can use Apple's Metal Performance Shaders (MPS) backend.
torch.backends.mps.is_available()

In [None]:
# Prefer Apple-Silicon Metal (MPS) acceleration, falling back to the CPU so the notebook
# still runs where MPS is unavailable.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device

#### try

In [None]:
model = UNet(out_channels=1, in_channels=4)
model = model.to(device, dtype=torch.float32)  # Module.to returns the same module
model

In [None]:
tl, tt = train_model(model, train_loader, val_loader, device, lr=1e-4, epochs=50)

In [None]:
# Derive the x-axis from the recorded history so it always matches the `epochs=` used in
# the training call.
epochs = np.arange(1, len(tl) + 1)
plt.plot(epochs, tl, label='train')
plt.plot(epochs, tt, label='val')
plt.legend()
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

----

In [None]:
@dataclass
class ModelOptimizer:
    """Optuna search over learning rate and epoch count for the UNet pipeline.

    Each trial trains a fresh ``UNet(in_channels=4, out_channels=1)`` on the stored training
    arrays and is scored by the final-epoch validation loss returned by ``train_model``.
    NOTE(review): relies on the module-level ``device``, ``UNet``, ``SatelliteDataset`` and
    ``train_model`` that are in scope when ``optimize`` runs.
    """

    x_train: np.ndarray
    y_train: np.ndarray
    x_val: np.ndarray
    y_val: np.ndarray
    best_params: Dict = field(default_factory=dict)

    def _make_loaders(self):
        """Build the (train, val) loaders: batch size 8, no shuffling."""
        train_ds = SatelliteDataset(self.x_train, self.y_train)
        val_ds = SatelliteDataset(self.x_val, self.y_val)
        return (DataLoader(train_ds, batch_size=8, shuffle=False),
                DataLoader(val_ds, batch_size=8))

    def objective(self, trial):
        """Optuna objective: train one model and return its last validation loss."""
        lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
        n_epochs = trial.suggest_int('epochs', 10, 500)

        net = UNet(in_channels=4, out_channels=1)
        net.to(device, dtype=torch.float32)
        train_dl, val_dl = self._make_loaders()

        _, val_history = train_model(net, train_dl, val_dl, device, epochs=n_epochs, lr=lr)
        return val_history[-1]

    def optimize(self, n_trials):
        """Run ``n_trials`` minimisation trials; store, print and return the best params."""
        search = optuna.create_study(direction='minimize')
        search.optimize(self.objective, n_trials=n_trials)
        self.best_params = search.best_params
        print(f'Best params: {self.best_params}')
        return self.best_params

In [None]:
# Field order: x_train, y_train, x_val, y_val.
study = ModelOptimizer(train_imgs, train_masks, val_imgs, val_masks)

In [None]:
# Expensive: 100 trials, each training for 10-500 epochs. The next cell overwrites `best`
# with hard-coded values from a previous run.
best = study.optimize(n_trials=100)

In [None]:
# Best hyper-parameters recorded from an earlier Optuna run (lets you skip the search).
best = dict(lr=0.002099322607285872, epochs=396)

In [None]:
model = UNet(in_channels=4, out_channels=1).to(device, dtype=torch.float32)
tl, tt = train_model(model, train_loader, val_loader, device,
                     lr=best['lr'], epochs=best['epochs'])

In [None]:
# Size the x-axis from the actual history, not from `best`, so the plot stays correct
# even if `best` is edited without retraining.
epochs = np.arange(1, len(tl) + 1)

plt.plot(epochs, tl, label='train')
plt.plot(epochs, tt, label='val')
plt.title(f"Best params: {best}")
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.xlim(1, 10)  # zooms on the first 10 epochs; remove to see the whole run
plt.grid(linewidth=0.1)
plt.legend()
plt.show()

# new

In [None]:
# Indexing an NpzFile reads each array fully into memory, so the archive can be closed
# right away by the context manager.
with np.load('chesapeake_patches.npz') as data:
    images = data['images']        # (N, C, H, W)
    masks = data['masks']          # (N, H, W)

assert images.ndim == 4 and masks.ndim == 3, (images.shape, masks.shape)
assert len(images) == len(masks)
masks = masks[:, None, :, :]   # (N, 1, H, W)

# Contiguous 80/20 split (no shuffling).
N = images.shape[0]
split = int(0.8 * N)
train_imgs, val_imgs   = images[:split], images[split:]
train_masks, val_masks = masks[:split], masks[split:]

In [None]:
class PairedDihedralAugment:
    """Random flips plus a random multiple-of-90-degree rotation of a (C, H, W) tensor.

    The image bands and the mask are stacked into one tensor before this is called, so
    both always receive the same geometric transform. Only exact flips and ``rot90`` are
    used: there is no interpolation and no padded corners, so a binary mask stays binary.
    For non-square inputs only 0 or 180 degree rotations are drawn, so the shape never
    changes and batches can still be collated.
    """

    def __init__(self, p_flip: float = 0.5):
        self.p_flip = p_flip

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        if torch.rand(1).item() < self.p_flip:
            t = t.flip(-1)  # horizontal flip
        if torch.rand(1).item() < self.p_flip:
            t = t.flip(-2)  # vertical flip
        if t.shape[-1] == t.shape[-2]:
            k = int(torch.randint(0, 4, (1,)).item())        # 0, 90, 180 or 270 degrees
        else:
            k = 2 * int(torch.randint(0, 2, (1,)).item())    # 0 or 180 degrees
        return torch.rot90(t, k, dims=(t.ndim - 2, t.ndim - 1)) if k else t


augment = PairedDihedralAugment()


class SatelliteDataset(Dataset):
    """In-memory patch dataset with joint image/mask augmentation.

    Args:
        imgs: array of images; each item is indexed as (C, H, W) (presumably C == 4 bands).
        msks: array of masks; each item is indexed as (1, H, W).
        use_augment: when True (default), apply the module-level ``augment`` to each item;
            pass False for deterministic evaluation data.
    """

    def __init__(self, imgs, msks, use_augment: bool = True):
        self.images  = torch.from_numpy(imgs).float()
        self.masks   = torch.from_numpy(msks).float()
        self.augment = augment if use_augment else None

        # Per-band statistics. Normalisation is currently disabled in __getitem__ (see the
        # commented line there), so these are unused.
        self.mean = torch.tensor([0.485, 0.456, 0.406, 0.5])[..., None, None]
        self.std  = torch.tensor([0.229, 0.224, 0.225, 0.25])[..., None, None]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        x = self.images[idx]   # (C, H, W)
        y = self.masks[idx]    # (1, H, W)
        # Stack so image and mask undergo the identical random transform.
        stacked = torch.cat([x, y], dim=0)
        if self.augment is not None:
            stacked = self.augment(stacked)
        x_aug, y_aug = stacked[:-1], stacked[-1:]
        # x_aug = (x_aug - self.mean) / self.std
        return x_aug, y_aug

In [None]:
train_ds = SatelliteDataset(train_imgs, train_masks)
val_ds   = SatelliteDataset(val_imgs, val_masks)
# Validation must be deterministic. Replace the random flip/rotation pipeline with a
# no-op so val loss (which drives LR scheduling and early stopping) is not computed on
# randomly augmented patches.
val_ds.augment = nn.Identity()
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=4)

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv, padding 1 -> BatchNorm -> ReLU) stages; H and W are preserved."""

    def __init__(self, in_c, out_c):
        super().__init__()
        layers = []
        for c_in in (in_c, out_c):
            layers += [
                nn.Conv2d(c_in, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
            ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

In [None]:
class UNet(nn.Module):
    """U-Net with one encoder level per entry of ``features`` and a mirrored decoder.

    Encoder level: DoubleConv -> 2x2 max-pool. Decoder level: 2x2 transposed conv ->
    (resize if odd sizes mismatch) -> concat with the matching encoder output ->
    DoubleConv. A final 1x1 conv produces ``out_channels`` maps.
    """

    def __init__(self, in_channels=4, out_channels=1, features=[64,128,256,512]):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        widths = [in_channels, *features]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.downs.extend([DoubleConv(c_in, c_out), nn.MaxPool2d(2)])
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        for f in features[::-1]:
            self.ups.extend([nn.ConvTranspose2d(f * 2, f, 2, stride=2), DoubleConv(f * 2, f)])
        self.final = nn.Conv2d(features[0], out_channels, 1)

    def forward(self, x):
        skips = []
        for block, pool in zip(self.downs[0::2], self.downs[1::2]):
            x = block(x)
            skips.append(x)
            x = pool(x)
        x = self.bottleneck(x)

        for up, block, skip in zip(self.ups[0::2], self.ups[1::2], reversed(skips)):
            x = up(x)
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:], mode='nearest')
            x = block(torch.cat([skip, x], dim=1))
        return self.final(x)

In [None]:
def focal_loss(logits, targets, alpha=0.25, gamma=2.0, eps=1e-6):
    """Binary focal loss (Lin et al., 2017), averaged over all elements.

    ``FL = alpha_t * (1 - p_t)**gamma * BCE``, where ``p_t`` is the predicted probability of
    the true class and ``alpha_t = alpha`` for positives and ``1 - alpha`` for negatives.
    The class-dependent ``alpha_t`` is what lets ``alpha`` rebalance the classes; a single
    ``alpha`` applied to every element would only rescale the loss.

    Args:
        logits: raw model outputs (any shape).
        targets: same shape as ``logits``, values in {0, 1}.
        alpha: weight of the positive class in [0, 1]; ``None`` disables class weighting.
        gamma: focusing parameter; 0 reduces to (weighted) BCE.
        eps: unused; kept for backward compatibility.

    Returns:
        Scalar tensor (differentiable).
    """
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    prob = torch.sigmoid(logits)
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = (1 - p_t).pow(gamma) * ce
    if alpha is not None:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean()

In [None]:
def train_model(model, train_loader, val_loader, device, epochs=30, lr=1e-4, wd=1e-5,
                patience=10, ckpt_path='best_model.pt'):
    """Train a binary-segmentation model with BCE-with-logits, LR-on-plateau and early stopping.

    Args:
        model: network producing one logit map per mask.
        train_loader, val_loader: loaders yielding ``(images, masks)`` batches.
        device: device the model and each batch are moved to.
        epochs: maximum number of epochs.
        lr, wd: Adam learning rate and weight decay.
        patience: stop after this many consecutive epochs without a new best val loss.
        ckpt_path: file the best ``state_dict`` is saved to whenever val loss improves.

    Returns:
        ``(model, tl, tt)``: the same model object, now holding the weights of the best
        validation epoch (whether training stopped early or ran all ``epochs``), plus the
        per-epoch, sample-weighted mean train and validation losses.

    Raises:
        ValueError: if either loader's dataset is empty (mean losses would divide by zero).
    """
    for name, loader in (('train_loader', train_loader), ('val_loader', val_loader)):
        if len(loader.dataset) == 0:
            raise ValueError(f"{name}.dataset is empty")

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    criterion = nn.BCEWithLogitsLoss()
    # Halve the LR after 3 epochs without validation improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3)

    best_loss, best_state, wait = float('inf'), None, 0
    tl, tt = [], []
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0.0
        for imgs, msks in train_loader:
            imgs, msks = imgs.to(device), msks.to(device)
            preds = model(imgs)
            loss = criterion(preds, msks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * imgs.size(0)
        train_loss /= len(train_loader.dataset)
        tl.append(train_loss)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for imgs, msks in val_loader:
                imgs, msks = imgs.to(device), msks.to(device)
                preds = model(imgs)
                val_loss += criterion(preds, msks).item() * imgs.size(0)
        val_loss /= len(val_loader.dataset)
        tt.append(val_loss)

        scheduler.step(val_loss)
        # Log every epoch, including the one that triggers early stopping.
        print(f"Epoch {epoch:02d} — train {train_loss:.4f}, val {val_loss:.4f}")

        if val_loss < best_loss:
            best_loss, wait = val_loss, 0
            # Keep an in-memory copy so the best weights can be restored without re-reading.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, ckpt_path)
        else:
            wait += 1
            if wait >= patience:
                # Message: "No improvement in {patience} epochs, stopping."
                print(f"No hubo mejora en {patience} épocas, deteniendo.")
                break

    # Always return the best-validation weights, not just when stopping early.
    if best_state is not None:
        model.load_state_dict(best_state)
    return model, tl, tt

In [None]:
# Prefer Apple-Silicon Metal (MPS) acceleration, falling back to the CPU when unavailable.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device

In [None]:
# Smaller network than the default: features [32, 64, 128, 256] instead of [64, 128, 256, 512].
model = UNet(in_channels=4, out_channels=1, features=[32, 64, 128, 256])

trained_model, tl, tt = train_model(
    model, train_loader, val_loader, device,
    epochs=200, lr=1e-4, wd=1e-6,
)

epochs = np.arange(1, len(tl) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, tl, label='train')
ax.plot(epochs, tt, label='val')
ax.legend()
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# Same small UNet, learning rate lowered to 1e-5.
model = UNet(in_channels=4, out_channels=1, features=[32, 64, 128, 256])

trained_model, tl, tt = train_model(
    model, train_loader, val_loader, device,
    epochs=200, lr=1e-5, wd=1e-6,
)

epochs = np.arange(1, len(tl) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, tl, label='train')
ax.plot(epochs, tt, label='val')
ax.legend()
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# lr 1e-5 with weight decay disabled.
model = UNet(in_channels=4, out_channels=1, features=[32, 64, 128, 256])

trained_model, tl, tt = train_model(
    model, train_loader, val_loader, device,
    epochs=200, lr=1e-5, wd=0,
)

epochs = np.arange(1, len(tl) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, tl, label='train')
ax.plot(epochs, tt, label='val')
ax.legend()
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# Higher learning rate (5e-4) with weight decay 5e-6.
model = UNet(in_channels=4, out_channels=1, features=[32, 64, 128, 256])

trained_model, tl, tt = train_model(
    model, train_loader, val_loader, device,
    epochs=200, lr=5e-4, wd=5e-6,
)

epochs = np.arange(1, len(tl) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, tl, label='train')
ax.plot(epochs, tt, label='val')
ax.legend()
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# NOTE(review): same hyper-parameters as the previous cell (lr=5e-4, wd=5e-6) -- a
# repeat run, which differs only through random initialisation and shuffling.
model = UNet(in_channels=4, out_channels=1, features=[32, 64, 128, 256])

trained_model, tl, tt = train_model(
    model, train_loader, val_loader, device,
    epochs=200, lr=5e-4, wd=5e-6,
)

epochs = np.arange(1, len(tl) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, tl, label='train')
ax.plot(epochs, tt, label='val')
ax.legend()
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# lr 1e-4 with much stronger weight decay (5e-3).
model = UNet(in_channels=4, out_channels=1, features=[32, 64, 128, 256])

trained_model, tl, tt = train_model(
    model, train_loader, val_loader, device,
    epochs=200, lr=1e-4, wd=5e-3,
)

epochs = np.arange(1, len(tl) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, tl, label='train')
ax.plot(epochs, tt, label='val')
ax.legend()
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
plt.show()

In [None]:
def iou_score(logits, mask, eps=1e-6):
    """Batch-mean Intersection over Union of thresholded predictions.

    Sigmoid probabilities above 0.5 count as positive. IoU is computed per (sample,
    channel) over dims 2 and 3 (assumes (B, C, H, W) inputs) and then averaged.
    Returns a Python float.
    """
    hard = torch.sigmoid(logits).gt(0.5).to(torch.float32)
    target = mask.to(torch.float32)
    overlap = torch.sum(hard * target, dim=(2, 3))
    combined = torch.sum(hard + target - hard * target, dim=(2, 3))
    ratio = (overlap + eps) / (combined + eps)
    return ratio.mean().item()

def dice_score(logits, mask, eps=1e-6):
    """Batch-mean soft Dice coefficient between sigmoid probabilities and the mask.

    Computed per (sample, channel) over dims 2 and 3, then averaged. Returns a Python
    float via ``.item()``, so the result is a metric and carries no gradient.
    """
    probs = logits.sigmoid()
    overlap = (probs * mask).sum(dim=(2, 3))
    cardinality = probs.sum(dim=(2, 3)) + mask.sum(dim=(2, 3))
    dice = (2 * overlap + eps) / (cardinality + eps)
    return dice.mean().item()

def train_and_evaluate(model, train_loader, val_loader, device,
                       epochs=30, lr=1e-4, weight_decay=1e-5):
    """Train with BCE + soft-Dice loss, tracking validation loss and IoU each epoch.

    The model is not moved by this function; callers place it on ``device`` first.
    The LR is halved after 3 epochs without validation improvement.

    Returns:
        ``(train_losses, val_losses, val_ious)``: per-epoch lists of length ``epochs``.
    """
    def soft_dice_loss(logits, target, eps=1e-6):
        # 1 - soft Dice, kept as a tensor so it contributes gradients. `dice_score`
        # returns a Python float via `.item()`, which dropped this term from backprop.
        prob = torch.sigmoid(logits)
        inter = (prob * target).sum(dim=(2, 3))
        total = prob.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
        return 1 - ((2 * inter + eps) / (total + eps)).mean()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    train_losses, val_losses, val_ious = [], [], []

    for epoch in range(1, epochs+1):
        model.train()
        running_train = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss   = criterion(logits, y) + soft_dice_loss(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_train += loss.item() * x.size(0)
        train_loss = running_train / len(train_loader.dataset)
        train_losses.append(train_loss)

        model.eval()
        running_val, running_iou = 0.0, 0.0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                loss   = criterion(logits, y) + soft_dice_loss(logits, y)
                running_val += loss.item() * x.size(0)
                running_iou += iou_score(logits, y) * x.size(0)
        val_loss = running_val / len(val_loader.dataset)
        val_iou  = running_iou  / len(val_loader.dataset)
        val_losses.append(val_loss)
        val_ious.append(val_iou)

        scheduler.step(val_loss)

        print(f"Epoch {epoch:02d} — "
              f"train: {train_loss:.4f}  "
              f"val: {val_loss:.4f}  "
              f"val IoU: {val_iou:.3f}")

    return train_losses, val_losses, val_ious

# Prefer MPS (Apple Metal); fall back to CPU so the cell also runs without it.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
model = UNet(in_channels=4, out_channels=1, features=[32,64,128,256])

model.to(device)
tl, vl, vi = train_and_evaluate(model, train_loader, val_loader, device,
                                epochs=30, lr=1e-4, weight_decay=1e-5)

plt.figure()
plt.plot(range(1, len(tl)+1), tl, label='train loss')
plt.plot(range(1, len(vl)+1), vl, label='val loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Train vs Val Loss')
plt.show()

plt.figure()
plt.plot(range(1, len(vi)+1), vi, label='val IoU')
plt.xlabel('Epochs')
plt.ylabel('IoU')
plt.legend()
plt.title('Validation IoU')
plt.show()

In [None]:
def evaluate_classification_metrics(model, loader, device, threshold=0.5):
    """Pixel-wise accuracy, precision, recall and F1 of thresholded predictions on ``loader``.

    Confusion counts are accumulated batch by batch, so per-pixel arrays are never
    concatenated. The metric functions previously called here (``accuracy_score``,
    ``precision_score``, ...) were never imported. Undefined ratios return 0.0, matching
    sklearn's ``zero_division=0``.

    Args:
        model: network returning one logit per pixel.
        loader: yields ``(images, masks)``; mask pixels equal to 1 are positives.
        device: device the images are moved to.
        threshold: probability above which a pixel is predicted positive (default 0.5).

    Returns:
        ``(accuracy, precision, recall, f1)`` as floats.

    Raises:
        ValueError: if the loader yields no pixels.
    """
    model.eval()
    tp = fp = fn = tn = 0

    with torch.no_grad():
        for imgs, masks in loader:
            logits = model(imgs.to(device))
            preds = (torch.sigmoid(logits) > threshold).cpu().numpy().ravel()
            labels = masks.cpu().numpy().ravel().astype(int) == 1

            tp += int(np.sum(preds & labels))
            fp += int(np.sum(preds & ~labels))
            fn += int(np.sum(~preds & labels))
            tn += int(np.sum(~preds & ~labels))

    total = tp + fp + fn + tn
    if total == 0:
        raise ValueError("loader yielded no pixels to evaluate")

    def _ratio(num, den):
        return num / den if den else 0.0

    acc  = (tp + tn) / total
    prec = _ratio(tp, tp + fp)
    rec  = _ratio(tp, tp + fn)
    f1   = _ratio(2 * tp, 2 * tp + fp + fn)
    return acc, prec, rec, f1

# Prefer MPS (Apple Metal); fall back to CPU so the cell also runs without it.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
acc, prec, rec, f1 = evaluate_classification_metrics(
    model, val_loader, device
)
print(f"Accuracy:  {acc:.3f}")
print(f"Precision: {prec:.3f}")
print(f"Recall:    {rec:.3f}")
print(f"F1-score:  {f1:.3f}")

In [None]:
def train_and_evaluate(model, train_loader, val_loader, device,
                       train_losses, val_losses, val_ious,
                       epochs=30, lr=1e-4, weight_decay=1e-5):
    """Continue training with BCE + soft-Dice loss, appending to existing history lists.

    ``train_losses``, ``val_losses`` and ``val_ious`` are mutated in place (one entry per
    epoch is appended) and the same list objects are returned. The optimizer and the
    ReduceLROnPlateau scheduler are created fresh on each call, so Adam state and the LR
    restart from ``lr``. The model is not moved; callers place it on ``device`` first.

    Returns:
        ``(train_losses, val_losses, val_ious)``: the same lists that were passed in.
    """
    def soft_dice_loss(logits, target, eps=1e-6):
        # 1 - soft Dice, kept as a tensor so it contributes gradients. `dice_score`
        # returns a Python float via `.item()`, which dropped this term from backprop.
        prob = torch.sigmoid(logits)
        inter = (prob * target).sum(dim=(2, 3))
        total = prob.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
        return 1 - ((2 * inter + eps) / (total + eps)).mean()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    for epoch in range(1, epochs+1):
        model.train()
        running_train = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss   = criterion(logits, y) + soft_dice_loss(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_train += loss.item() * x.size(0)
        train_loss = running_train / len(train_loader.dataset)
        train_losses.append(train_loss)

        model.eval()
        running_val, running_iou = 0.0, 0.0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                loss   = criterion(logits, y) + soft_dice_loss(logits, y)
                running_val += loss.item() * x.size(0)
                running_iou += iou_score(logits, y) * x.size(0)
        val_loss = running_val / len(val_loader.dataset)
        val_iou  = running_iou  / len(val_loader.dataset)
        val_losses.append(val_loss)
        val_ious.append(val_iou)

        scheduler.step(val_loss)

        print(f"Epoch {epoch:02d} — "
              f"train: {train_loss:.4f}  "
              f"val: {val_loss:.4f}  "
              f"val IoU: {val_iou:.3f}")

    return train_losses, val_losses, val_ious

In [None]:
# Continue training the same model for 30 more epochs. tl/vl/vi are extended in place,
# so tl_ is tl (and likewise for the others) and the plots show the full history.
model.to(device)
tl_, vl_, vi_ = train_and_evaluate(
    model, train_loader, val_loader, device,
    tl, vl, vi,
    epochs=30, lr=1e-4, weight_decay=1e-5,
)

fig, ax = plt.subplots()
ax.plot(range(1, len(tl_) + 1), tl_, label='train loss')
ax.plot(range(1, len(vl_) + 1), vl_, label='val loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Train vs Val Loss')
plt.show()

fig, ax = plt.subplots()
ax.plot(range(1, len(vi_) + 1), vi_, label='val IoU')
ax.set_xlabel('Epochs')
ax.set_ylabel('IoU')
ax.legend()
ax.set_title('Validation IoU')
plt.show()