In [1]:
import logging
from pathlib import Path

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import wandb
from torch import optim
from tqdm import tqdm

from torchsummary import summary

from dice_score import dice_loss
from OSSE_DataLoader import get_data_loaders, get_xarray, normalize_osse
from unet import UNet

In [2]:

batch_size = 16


INPUT_IMAGE_HEIGHT = 357
INPUT_IMAGE_WIDTH = 717

augmentation = transforms.Compose([
    transforms.RandomRotation(180),
    transforms.RandomCrop((INPUT_IMAGE_HEIGHT//6, INPUT_IMAGE_WIDTH//6)),
])
    
OSSE_train, eddies_train, OSSE_test =  get_xarray()

OSSE_train, OSSE_test, *_ = normalize_osse(OSSE_train, OSSE_test)


train_dataloader, val_dataloader = get_data_loaders(batch_size, OSSE_train, eddies_train, 0, 0, augmentations=augmentation)

data_iter = iter(train_dataloader)

features, labels = next(data_iter)


In [3]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

from dice_score import multiclass_dice_coeff, dice_coeff


@torch.inference_mode()
def evaluate(net, dataloader, device, amp):
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # iterate over the validation set
    # with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
    for (image, mask_true) in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
        # image, mask_true = batch['image'], batch['mask']

        # move images and labels to correct device and type
        # image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
        image = image.to(device=device, dtype=torch.float32)
        mask_true = mask_true.to(device=device, dtype=torch.long)

        # predict the mask
        mask_pred = net(image)

        if net.n_classes == 1:
            assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]'
            mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
            # compute the Dice score
            dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
        else:
            assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes['
            # convert to one-hot format
            # mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
            # mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
            mask_true = F.one_hot(mask_true.squeeze(1), net.n_classes).float()
            mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
            # compute the Dice score, ignoring background
            dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)

    net.train()
    return dice_score / max(num_val_batches, 1)

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')

model = UNet(n_channels=4, n_classes=3, bilinear=False)
model =model.to(device)
# model.to("mps")

print(device)

# summary(model, input_size=(4, 357//4, 717//4))

mps


In [5]:
model.train()
loss = 0

criterion = nn.CrossEntropyLoss()

for features, labels in train_dataloader:
    print(features.shape)
    print(labels.shape)

    features = features.to(device)
    true_masks = labels.to(device)

    masks_pred = model(features)

    y_true = torch.squeeze(true_masks, 1).long()

    loss += criterion(masks_pred, y_true)
    loss += dice_loss(
        F.softmax(masks_pred, dim=1).float(),
        F.one_hot(y_true, model.n_classes).permute(0, 3, 1, 2).float(),
        multiclass=True
    )

    print(loss)


torch.Size([4, 4, 44, 89])
torch.Size([4, 1, 44, 89])


  nonzero_finite_vals = torch.masked_select(


tensor(1.8130, device='mps:0', grad_fn=<AddBackward0>)
torch.Size([4, 4, 44, 89])
torch.Size([4, 1, 44, 89])
tensor(3.5855, device='mps:0', grad_fn=<AddBackward0>)
torch.Size([4, 4, 44, 89])
torch.Size([4, 1, 44, 89])
tensor(5.3351, device='mps:0', grad_fn=<AddBackward0>)
torch.Size([4, 4, 44, 89])
torch.Size([4, 1, 44, 89])
tensor(7.1229, device='mps:0', grad_fn=<AddBackward0>)
torch.Size([4, 4, 44, 89])
torch.Size([4, 1, 44, 89])
tensor(8.9121, device='mps:0', grad_fn=<AddBackward0>)
torch.Size([4, 4, 44, 89])
torch.Size([4, 1, 44, 89])
tensor(10.7066, device='mps:0', grad_fn=<AddBackward0>)
torch.Size([4, 4, 44, 89])
torch.Size([4, 1, 44, 89])
tensor(12.5044, device='mps:0', grad_fn=<AddBackward0>)
torch.Size([4, 4, 44, 89])
torch.Size([4, 1, 44, 89])
tensor(14.3565, device='mps:0', grad_fn=<AddBackward0>)
torch.Size([4, 4, 44, 89])
torch.Size([4, 1, 44, 89])
tensor(16.1693, device='mps:0', grad_fn=<AddBackward0>)
torch.Size([4, 4, 44, 89])
torch.Size([4, 1, 44, 89])
tensor(17.9422,

In [6]:
features, labels = next(data_iter)

features = features.to(device, dtype=torch.float32)
labels= labels.to(device)

model.train()


UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [7]:
features.shape

torch.Size([4, 4, 89, 179])

In [8]:
model(features)

tensor([[[[ 1.1164e-01, -1.8179e-02, -3.9206e-03,  ..., -8.2241e-01,
           -1.8600e-01, -3.6633e-01],
          [ 2.0391e-01, -8.0261e-02,  1.2172e-01,  ..., -9.3781e-01,
           -5.8848e-01, -1.5908e-01],
          [ 1.6411e-01,  6.3528e-02,  2.6253e-02,  ..., -3.0476e-01,
            1.6268e-01,  6.7472e-04],
          ...,
          [ 2.9629e-01, -3.8887e-02,  1.3487e-01,  ..., -2.1646e-01,
            5.2838e-01,  5.0821e-01],
          [ 1.9403e-01, -4.2702e-01, -1.8601e-01,  ...,  3.6195e-01,
            2.9550e-01, -7.9762e-02],
          [ 2.2907e-01, -1.2006e-01, -3.0962e-01,  ...,  2.1625e-01,
            2.1107e-03,  2.0906e-01]],

         [[-1.9044e-01,  2.4912e-01,  5.2261e-02,  ..., -2.1875e-02,
            1.5649e-01,  4.6888e-02],
          [ 3.9567e-02,  1.0505e-01,  1.7893e-01,  ..., -2.2845e-01,
           -1.0090e-01, -4.7845e-02],
          [-5.8220e-02,  3.5523e-01,  2.9870e-01,  ..., -1.0247e-02,
            2.1321e-03,  7.1031e-02],
          ...,
     

In [6]:
dir_img = Path('./data/imgs/')
dir_mask = Path('./data/masks/')
dir_checkpoint = Path('./checkpoints/')

epochs = 2000
learning_rate = 1e-3
val_percent = 0.1
save_checkpoint = True
img_scale = 1
amp = False
weight_decay = 1e-8
momentum = 0.999
gradient_clipping = 1

train_loader, val_loader = train_dataloader, val_dataloader


n_train = int(len(OSSE_train) * 0.8)
n_val = int(len(OSSE_test) * 0.2)

# (Initialize logging)
experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
experiment.config.update(
    dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
            val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp)
)

logging.info(f'''Starting training:
    Epochs:          {epochs}
    Batch size:      {batch_size}
    Learning rate:   {learning_rate}
    Training size:   {n_train}
    Validation size: {n_val}
    Checkpoints:     {save_checkpoint}
    Device:          {device.type}
    Images scaling:  {img_scale}
    Mixed Precision: {amp}
''')

# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
optimizer = optim.RMSprop(model.parameters(),
                            # lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True)
                            lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score
# grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
global_step = 0

# print("là")

# 5. Begin training
for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0
    # print("ici")
    # with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:

    for (images, true_masks) in train_loader:
        # images, true_masks = batch['image'], batch['mask']

        # print(images.shape)
        # print(true_masks.shape)

        # assert images.shape[1] == model.n_channels, \
        #     f'Network has been defined with {model.n_channels} input channels, ' \
        #     f'but loaded images have {images.shape[1]} channels. Please check that ' \
        #     'the images are loaded correctly.'

        # images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
        images = images.to(device)
        # images = images.to(device=device, dtype=torch.float32)
        # true_masks = true_masks.to(device=device, dtype=torch.long)
        true_masks = true_masks.to(device)

        # print("jsuis là fréro")

        # with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        masks_pred = model(images)
        # print("juste ici")
        
        # if model.n_classes == 1:
        #     loss = criterion(masks_pred.squeeze(1), true_masks.float())
        #     loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
        # else:
        #     # loss = criterion(masks_pred, true_masks)
        #     # print(masks_pred.shape)
        #     # print(true_masks.shape)

        #     y_true = torch.squeeze(true_masks, 1)
        #     # print(y_true.shape)
        #     print("là bas")
        #     # loss = criterion(masks_pred, F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float())
        #     loss = criterion(masks_pred, y_true.float())
        #     loss += dice_loss(
        #         F.softmax(masks_pred, dim=1).float(),
        #         # F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
        #         F.one_hot(y_true, model.n_classes).float(),
        #         multiclass=True
        #     )


        y_true = torch.squeeze(true_masks, 1).long()

        loss = criterion(masks_pred, y_true)
        loss += dice_loss(
            F.softmax(masks_pred, dim=1).float(),
            F.one_hot(y_true, model.n_classes).permute(0, 3, 1, 2).float(),
            multiclass=True
        )


        # print("here")

        optimizer.zero_grad(set_to_none=True)
        # grad_scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        # grad_scaler.step(optimizer)
        # loss.step(optimizer)
        # grad_scaler.update()
        loss.backward()
        optimizer.step()
        
        # print("there")

        # pbar.update(images.shape[0])
        global_step += 1
        epoch_loss += loss.item()
        # experiment.log({
        #     'train loss': loss.item(),
        #     'step': global_step,
        #     'epoch': epoch
        # })
        experiment.log({
            'train loss': loss.item(),
            'step': global_step,
            'epoch': epoch
        })

        # pbar.set_postfix(**{'loss (batch)': loss.item()})

        # Evaluation round
        division_step = (n_train // (5 * batch_size))
        if division_step > 0:
            if global_step % division_step == 0:
                histograms = {}
                for tag, value in model.named_parameters():
                    tag = tag.replace('/', '.')
                    if not torch.isinf(value).any():
                        histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                    if not torch.isinf(value.grad).any():
                        histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                val_score = evaluate(model, val_loader, device, amp)
                scheduler.step(val_score)

                logging.info('Validation Dice score: {}'.format(val_score))
                try:
                    experiment.log({
                    # logging.info({
                        'learning rate': optimizer.param_groups[0]['lr'],
                        'validation Dice': val_score,
                        'images': wandb.Image(images[0].cpu()),
                        'masks': {
                            'true': wandb.Image(true_masks[0].float().cpu()),
                            'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
                        },
                        'step': global_step,
                        'epoch': epoch,
                        **histograms
                    })
                except:
                    pass

    if save_checkpoint:
        Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
        state_dict = model.state_dict()
        # state_dict['mask_values'] = dataset.mask_values
        torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
        logging.info(f'Checkpoint {epoch} saved!')



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016689087499999998, max=1.0…

KeyboardInterrupt: 