# Debugging GPU Memory Usage in PyTorch

In [1]:
import sys
from pathlib import Path

# Make the project root importable so the `src.*` imports below resolve.
# ROOT is two directory levels above the current working directory.
ROOT = Path().resolve().parents[1]
# Guard the append so re-running this cell does not pile up duplicate entries.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

In [2]:
%load_ext autoreload

In [3]:
%autoreload 2
from src.utils.variables.var_names import *
from src.utils.variables.coord_names import *
from src.data_processing.conversions.scalar_conversions import *
from src.config.env_loader import get_env_var
import src.learning.model_diagnostics as model_diagnostics
from src.learning.model_training import batch_data_by_num_stations, compute_val_loss

from src.data_processing.station_processor import ProcessStations
from src.data_processing.topography_processor import ProcessTopography
from src.data_processing.era5_processor import ProcessERA5

In [4]:
%autoreload 2
import deepsensor.torch
from deepsensor.train.train import train_epoch, set_gpu_default_device
from deepsensor.data.loader import TaskLoader
from deepsensor.data.processor import DataProcessor
from deepsensor.model.convnp import ConvNP
from deepsensor.data.utils import construct_x1x2_ds

In [5]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
from mpl_toolkits.basemap import Basemap
import torch
from torch import optim
import os
import lab as B
from tqdm import tqdm
import cartopy.crs as ccrs
import cartopy.feature as cf

## Set up a Demo Dataset

In [6]:
# --- Experiment configuration ---
# Variable under study and the full span of years loaded from disk.
var = TEMPERATURE
years = list(range(2010, 2015))

# Year splits (alternative training split: [2010, 2011, 2012, 2013]).
train_years = [2010]
validation_years = [2014]

# Toggle for visualisations of the data.
DEBUG_PLOTS = True

# GPU settings: when enabled, call DeepSensor's set_gpu_default_device().
use_gpu = True
if use_gpu:
    set_gpu_default_device()

In [7]:
# One processor object per data source: stations, topography and ERA5.
station_processor, topography_processor, era5_processor = (
    ProcessStations(),
    ProcessTopography(),
    ProcessERA5(),
)

In [8]:
# Both loaders are asked to standardise variable and coordinate names.
standardise_names = dict(standardise_var_names=True, standardise_coord_names=True)

topography_ds = topography_processor.load_ds(**standardise_names)
era5_ds = era5_processor.load_ds(mode="surface", years=years, **standardise_names)

In [9]:
# Extract the experiment variable (e.g. "temperature") from ERA5, convert it
# from Kelvin to Celsius, and store the converted field back on the dataset.
# NOTE(review): the result overwrites era5_ds[var]; if get_variable reads that
# same entry, re-running this cell would convert twice - confirm.
era5_var = kelvin_to_celsius(era5_processor.get_variable(era5_ds, var))
era5_ds[var] = era5_var

In [10]:
# Optional spatial crop; the bounds are only applied where `crop` is checked.
crop = False

# Longitude bounds (left/right) and latitude bounds (top/bottom) of the crop box.
crop_left, crop_right = 166, 176
crop_top, crop_bottom = -38, -48

In [11]:
if crop:
    lon_window = slice(crop_left, crop_right)
    # The latitude slice runs top -> bottom for ERA5 and bottom -> top for the
    # topography (presumably the two latitude axes are stored in opposite
    # order - verify against the datasets).
    era5_ds = era5_ds.sel(lat=slice(crop_top, crop_bottom), lon=lon_window)
    topography_ds = topography_ds.sel(lat=slice(crop_bottom, crop_top), lon=lon_window)

# Block-average ERA5 over 5x5 cells, trimming any incomplete edge blocks.
era5_ds_coarsen = era5_ds.coarsen(lat=5, lon=5, boundary='trim').mean()

In [12]:
# Auxiliary topography dataset with TPI computed for a single window size.
tpi_ds = topography_processor.compute_tpi(topography_ds, window_sizes=[0.1])

# Coarse copy: 200x200 block means taken *before* missing values are filled,
# then remaining gaps set to 0.
ds_aux_coarse = tpi_ds.coarsen(lat=200, lon=200, boundary='trim').mean().fillna(0)

# Full-resolution copy with missing values set to 0.
ds_aux = tpi_ds.fillna(0)

In [13]:
# Load station observations for the experiment period. The year bounds come
# from `years` so this stays in step with the ERA5 load.
stations_df = station_processor.load_df(vars=[var], year_start=min(years), year_end=max(years))

# Drop the station identifier without mutating in place, so the cell builds
# its result purely from `stations_df`.
stations_reset = stations_df.reset_index().drop(columns=['station'])

# Average to 6-hourly values per (lat, lon) location, then index the result
# by (time, lat, lon) in sorted order.
stations_resample = (
    stations_reset
    .groupby(['lat', 'lon'])
    .resample("6h", on='time')
    .mean()[['temperature']]
    .reset_index()
    .set_index(['time', 'lat', 'lon'])
    .sort_index()
)

if crop:
    # Keep only stations strictly inside the crop box.
    station_lat = stations_resample.index.get_level_values('lat')
    station_lon = stations_resample.index.get_level_values('lon')
    inside_crop = (
        (station_lat > crop_bottom) & (station_lat < crop_top)
        & (station_lon > crop_left) & (station_lon < crop_right)
    )
    stations_resample = stations_resample[inside_crop]

  ds_comb = xr.concat([first, *station_iter], dim="station")
  stations_resample = stations_reset.groupby(['lat', 'lon']).resample("6h", on='time').mean()[['temperature']]


In [14]:
# Restrict ERA5 to the latitude/longitude extent of the coarse auxiliary data.
# The latitude slice runs max -> min, the longitude slice min -> max.
aux_lat = ds_aux_coarse[LATITUDE]
aux_lon = ds_aux_coarse[LONGITUDE]
era5_da = era5_ds.sel(
    lat=slice(aux_lat.max(), aux_lat.min()),
    lon=slice(aux_lon.min(), aux_lon.max()),
)

In [15]:
# Keep only the experiment variable `var` in the coarsened ERA5 data; the
# list-style selection keeps the result a Dataset rather than a DataArray.
era5_ds_coarsen = era5_ds_coarsen[[var]]

In [16]:
# Coordinate ranges for the processor's x1 (latitude) / x2 (longitude) mapping,
# taken from the full-resolution ERA5 grid.
lat_range = (era5_ds[LATITUDE].min(), era5_ds[LATITUDE].max())
lon_range = (era5_ds[LONGITUDE].min(), era5_ds[LONGITUDE].max())
data_processor = DataProcessor(
    x1_name=LATITUDE,
    x1_map=lat_range,
    x2_name=LONGITUDE,
    x2_map=lon_range,
)

# ERA5 and station data use the processor's default method; the auxiliary
# topography datasets are processed with method='min_max'.
era5_processed, station_processed = data_processor([era5_ds_coarsen, stations_resample])
ds_aux_processed, ds_aux_coarse_processed = data_processor([ds_aux, ds_aux_coarse], method='min_max')

# Attach the x1/x2 coordinate arrays as extra variables on the coarse auxiliary data.
x1x2_ds = construct_x1x2_ds(ds_aux_coarse_processed)
for coord_name in ('x1_arr', 'x2_arr'):
    ds_aux_coarse_processed[coord_name] = x1x2_ds[coord_name]

  f"x1_map={x1_map} and x2_map={x2_map} have different ranges ({float(np.diff(x1_map))} "
  f"and {float(np.diff(x2_map))}, respectively). "


## Model Training

In [17]:
# Context sets, in order: stations, coarsened ERA5, coarse auxiliary topography.
context_sets = [station_processed, era5_processed, ds_aux_coarse_processed]

# Stations are also the target; full-resolution auxiliary data is supplied at
# the target locations. links=[(0, 0)] pairs context set 0 with target set 0
# (presumably required for the "split" sampling used later - see DeepSensor docs).
task_loader = TaskLoader(
    context=context_sets,
    target=station_processed,
    aux_at_targets=ds_aux_processed,
    links=[(0, 0)],
)

In [18]:
# ConvNP with a five-level U-Net of 64 channels each and a "gnp" likelihood.
model = ConvNP(
    data_processor,
    task_loader,
    unet_channels=(64,) * 5,
    likelihood="gnp",
    internal_density=50,
)

dim_yc inferred from TaskLoader: (1, 1, 4)
dim_yt inferred from TaskLoader: 1
dim_aux_t inferred from TaskLoader: 2
Setting aux_t_mlp_layers: (64, 64, 64)
encoder_scales inferred from TaskLoader: [0.01, 0.0347222238779068, 0.007111109793186188]
decoder_scale inferred from TaskLoader: 0.02


## Training Loop

In [19]:
# Presumably loads any lazily-evaluated (dask-backed) data into memory up front
# so that sampling tasks below is fast - confirm against the TaskLoader docs.
task_loader.load_dask()

In [20]:
# Time windows for the two splits; the dates come from the ERA5 time axis.
# NOTE(review): these ranges are hard-coded and do not use `train_years` /
# `validation_years` from the configuration cell - confirm which is intended.
train_period = slice("2010-01-01", "2011-12-31")
val_period = slice("2012-01-01", "2012-06-30")

train_dates = era5_ds.sel(time=train_period).time.values
val_dates = era5_ds.sel(time=val_period).time.values

In [21]:
def _sample_task(date):
    """Sample one task for `date`.

    The first context set (stations) is split 50/50 with the target set; the
    other two context sets are used in full.
    """
    return task_loader(
        date,
        context_sampling=["split", "all", "all"],
        target_sampling=["split"],
        split_frac=0.5,
    )


# `task` stays bound at notebook scope after each loop, as in the original cell.
train_tasks = []
for date in tqdm(train_dates):
    task = _sample_task(date)
    train_tasks.append(task)

val_tasks = []
for date in tqdm(val_dates):
    task = _sample_task(date)
    val_tasks.append(task)


100%|██████████| 2920/2920 [00:42<00:00, 69.04it/s]
100%|██████████| 728/728 [00:10<00:00, 71.04it/s]


In [22]:
# Inspect one sampled training task (index 1) to check the shapes of its
# context and target sets.
print(train_tasks[1])

time: 2010-01-01 06:00:00
ops: []
X_c: [(2, 5), ((1, 14), (1, 12)), ((1, 54), (1, 54))]
Y_c: [(1, 5), (1, 14, 12), (4, 54, 54)]
X_t: [(2, 6)]
Y_t: [(1, 6)]
Y_t_aux: (2, 6)



- three context sets
- X_c is the coordinates for the context sets
- Y_c is the values for the context sets
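- for the task printed above, Y_c holds 5 station observations, a 14x12 ERA5 grid and a 4-channel 54x54 auxiliary grid (these sizes change with the crop and coarsening settings)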
- X_t is the target sensor coordinates
- Y_t is the target sensor values

In [23]:
# Group the training tasks for batching with batch_size=16. The result is
# looked up by key in the training loop (presumably one group per number of
# stations - confirm in src.learning.model_training).
task_batched = batch_data_by_num_stations(train_tasks, batch_size=16)

In [None]:
# Number of entries in the batched-task container.
len(task_batched)

186

: 

In [None]:
# --- Training configuration ---
n_epochs = 3
lr = 5e-5
output_model = False  # set True to checkpoint the model whenever validation loss improves

# Loss history, one entry per epoch, and the best validation loss seen so far.
train_losses = []
val_losses = []
val_loss_best = np.inf

opt = optim.Adam(model.model.parameters(), lr=lr)

# Profile the whole training run, recording memory use, input shapes and stacks,
# so the memory table can be inspected afterwards via `prof`.
with torch.profiler.profile(profile_memory=True, record_shapes=True, with_stack=True) as prof:
    for epoch in tqdm(range(n_epochs)):
        # One train_epoch call per task group; batch_size equals the group size.
        batch_losses = []
        for num_stations in task_batched.keys():
            tasks = task_batched[f'{num_stations}']
            batch_losses.append(train_epoch(model, tasks, batch_size=len(tasks), lr=lr, opt=opt))

        train_loss = np.mean(batch_losses)
        train_losses.append(train_loss)

        # Validation needs no gradients.
        with torch.no_grad():
            val_loss = compute_val_loss(model, val_tasks)
        val_losses.append(val_loss)

        if val_loss < val_loss_best:
            val_loss_best = val_loss
            if output_model:
                folder = os.path.join(get_env_var("OUTPUT_HOME"), "models", "downscaling", "temperature", "convcnp")
                os.makedirs(folder, exist_ok=True)
                # Join with a path separator: plain concatenation would write
                # ".../convcnpmodel.pt" next to the folder instead of inside it.
                torch.save(model.model.state_dict(), os.path.join(folder, "model.pt"))

        # Release cached GPU memory between epochs.
        torch.cuda.empty_cache()

        print(f"Epoch {epoch} train_loss: {train_loss:.2f}, val_loss: {val_loss:.2f}")

  for name in np.core.numerictypes.__all__ + ["bool"]:
 33%|███▎      | 1/3 [03:35<07:11, 215.96s/it]

Epoch 0 train_loss: 1.73, val_loss: 1.51


 67%|██████▋   | 2/3 [07:10<03:35, 215.03s/it]

Epoch 1 train_loss: 1.64, val_loss: 1.47


 67%|██████▋   | 2/3 [17:43<08:51, 531.55s/it]

In [None]:
# Summarise the profiler events grouped by input shape and print them as a
# table ordered by self CUDA memory usage.
profile_summary = prof.key_averages(group_by_input_shape=True)
print(profile_summary.table(sort_by="self_cuda_memory_usage"))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls                                                                      Input Shapes  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  -------------------------------------------------------------------------

In [None]:
# Release cached, currently-unused GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

tensor multiplication: [[16, 16000, 6], [16, 6, 8128]]

In [None]:
# Shape of the model's encoding tensor for a single task. The task is named
# explicitly (the last validation task) rather than relying on the loop
# variable `task` left over from the task-generation cell.
deepsensor.model.nps.compute_encoding_tensor(model, val_tasks[-1]).shape

(1, 4, 32, 32)