In [1]:
import torch
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import xarray as xr
import seaborn as sb
import torch.nn.functional as F
import pytorch_lightning as pl
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss
from torchmetrics import F1Score
import wandb
from pytorch_lightning.loggers import WandbLogger
import OSSE_DataLoader

import warnings
warnings.filterwarnings('ignore')

In [4]:
# files = []
# import os
# for dirname, _, filenames in os.walk('./data'): #/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))
#         files.append(os.path.join(dirname, filename))

./data/OSSE_U_V_SLA_SST_train.nc
./data/eddies_train.nc
./data/OSSE_U_V_SLA_SST_test.nc


In [2]:
#3992370031028f22c45bf7c639d150a61cd79a84
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mdouxthibault[0m ([33mdoux[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

# Load data

In [3]:
from OSSE_DataLoader import get_xarray, get_data_loaders

OSSE_train, eddies_train, OSSE_test =  get_xarray() # parent_dir = "./data"
batch_size = 8
train_dataloader, val_dataloader = get_data_loaders(
    batch_size, 
    OSSE_train, 
    eddies_train, 
    osse_nan_value=3,
    eddies_nan_value=3,
    )

In [7]:
# WARNING flatten le résultat en IdxTemps_IdLatitude_IdLongitude
# Keep only needed variable 

## Kaggle path
# eddies_train = xr.open_dataset("/kaggle/input/ocean-eddy-detection/eddies_train.nc")
# OSSE_test = xr.open_dataset("/kaggle/input/ocean-eddy-detection/OSSE_U_V_SLA_SST_test.nc")
# OSSE_train = xr.open_dataset("/kaggle/input/ocean-eddy-detection/OSSE_U_V_SLA_SST_train.nc")

## Local path
# eddies_train = xr.open_dataset("./data/eddies_train.nc")
# OSSE_test = xr.open_dataset("./data/OSSE_U_V_SLA_SST_test.nc")
# OSSE_train = xr.open_dataset("./data/OSSE_U_V_SLA_SST_train.nc")
# OSSE_train = OSSE_train.rename({"time_counter":"time"})

# selected_var = ['vomecrtyT', 'vozocrtxT','sossheig','votemper']
# OSSE_train = OSSE_train.fillna(0.0)
# X_full = torch.tensor(OSSE_train.get(selected_var).to_array().to_numpy()) # Features x Nb x Latitude x Longitude
# X_full = X_full.permute(1, 0, 2, 3) 

# na_value = 3.0
# eddies_train = eddies_train.fillna(na_value) # fill coast with 999.
# y_full = torch.tensor(eddies_train.to_array().to_numpy()) # Target x Times x Latitude x Longitude 
# y_full = y_full.permute(1, 0, 2, 3).long()

In [8]:
# nb_val = X_full.shape[0]
# idx_split = int(0.8*nb_val)
# X_train = torch.tensor(X_full[:idx_split,: , :, :]).clone().detach()
# y_train = torch.tensor(y_full[:idx_split,: , :, :]).clone().detach()
# X_val = torch.tensor(X_full[idx_split:, : , :, :]).clone().detach()
# y_val = torch.tensor(y_full[idx_split:,: , :, :]).clone().detach()

In [9]:
# class OSSE_Dataset(Dataset):
#     def __init__(self, OSSE_tensor,eddie_tensor):
#         self.OSSE_tensor = OSSE_tensor
#         self.eddie_tensor = eddie_tensor
        
#     def __len__(self):
#         return self.OSSE_tensor.shape[0]

#     def __getitem__(self, idx):
#         features = self.OSSE_tensor[idx, :, :, :]
#         label = self.eddie_tensor[idx, :, :, :]
        
#         return features, label

In [10]:
# ds_train = OSSE_Dataset(X_train,y_train)
# train_dataloader = DataLoader(ds_train, batch_size = 8)

# ds_val = OSSE_Dataset(X_val,y_val)
# val_dataloader = DataLoader(ds_val, batch_size = 8)

In [4]:
def conv_size_out(hin, win, padding, dilation, kernel_size, stride):
    hout = np.floor(((hin + 2*padding - (dilation * (kernel_size -1 )) -1) / stride ) + 1)
    wout = np.floor(((hin + 2*padding - (dilation * (kernel_size -1 )) -1) / stride ) + 1)
    return hout, wout

def conv_transpose_out(hin, win, padding, dilation, kernel_size, stride):
    hout = np.floor(((hin + 2*padding - (dilation * (kernel_size -1 )) -1) / stride ) + 1)
    wout = np.floor(((hin + 2*padding - (dilation * (kernel_size -1 )) -1) / stride ) + 1)
    return hout, wout

# Model

In [5]:
class classic_CNN(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Input 4 x 357 x 717
        self.classes = 4
        self.features = 4 
        self.kernel_size = 3
        self.accuracy = F1Score(task ="multiclass",  num_classes=4)
        # Encoder
        self.encoder = nn.Sequential( 
            nn.Conv2d(self.features, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 16, 6))
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(16, 32, 6),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 64, 3, stride=2, padding=1, output_padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.classes, 3, stride=2, padding=1, output_padding=0))

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def loss_fn(self,ypred,ytrue):
        return DiceLoss(mode = 'multiclass', ignore_index = 999)(ypred, ytrue)   # Ypred = NxCxHxW Ytrue NxHxW
    
    def configure_optimizers(self):
        LR = 1e-3
        optimizer = torch.optim.Adam(self.parameters(),lr=LR)
        return optimizer

    def training_step(self,batch,batch_idx):
        x,y = batch[0],batch[1]
        out = self(x)
        out = nn.Softmax(-1)(out)
        loss = self.loss_fn(out,y)
        self.log('train_loss', loss)
        return loss      

    def validation_step(self,batch,batch_idx):
        x,y = batch[0],batch[1]
        label = y
        out = self(x)
        out = nn.Softmax(-1)(out)
        loss = self.loss_fn(out,label)
        logits = torch.argmax(out,dim=1, keepdim=True)
        accu = self.accuracy(logits, label)        
        self.log('valid_loss', loss)
        self.log('train_acc_step', accu)
        return loss, accu

In [6]:
wandb_logger = WandbLogger(project='cnn-lightning', 
                           group='CNN', 
                           job_type='train')

# Training 

In [7]:
model = classic_CNN()
trainer = pl.Trainer(max_epochs = 10, logger=wandb_logger, enable_progress_bar=True)
#trainer = pl.Trainer(accelerator='mps',devices=1, max_epochs = 10, logger=wandb_logger, enable_progress_bar=True)

trainer.fit(model, train_dataloader, val_dataloader)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name     | Type              | Params
-----------------------------------------------
0 | accuracy | MulticlassF1Score | 0     
1 | encoder  | Sequential        | 39.3 K
2 | decoder  | Sequential        | 39.3 K
-----------------------------------------------
78.5 K    Trainable params
0         Non-trainable params
78.5 K    Total params
0.314     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [44]:
wandb.finish()