In [1]:
# basic imports
import os, sys, time
import argparse
import torch
import numpy as np
from utils.YParams import YParams
from ruamel.yaml import YAML
# torch optimizers and lr schedulers
import torch.optim as optim
from torch.optim import lr_scheduler

In [2]:
# some dummy model
import torch.nn as nn

def my_conv(in_channels, out_channels, kernel_size):
    """Build one convolutional stage: a 2D convolution followed by LeakyReLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size: kernel size forwarded to ``nn.Conv2d``.

    Returns:
        An ``nn.Sequential`` whose entry 0 is the convolution and whose
        entry 1 is the in-place LeakyReLU activation.
    """
    # 'same' padding is requested so the stage can be stacked freely
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
    activation = nn.LeakyReLU(inplace=True)
    return nn.Sequential(conv, activation)
           
class CNN(nn.Module):
    """Plain stack of convolution + LeakyReLU stages built with ``my_conv``.

    Layout: ``conv_in`` -> ``depth - 2`` hidden stages -> ``conv_out``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        depth: total number of stages. Values below 2 still produce the
            two-stage (input + output) network, because no hidden stage is
            created in that case.
        hidden_dim: channel width of every intermediate feature map.
        kernel_size: kernel size shared by all convolutions.
        dropout: dropout probability applied after the input stage and after
            each hidden stage while the module is in training mode. The
            default ``0.`` leaves activations untouched. Must lie in [0, 1]
            (``nn.Dropout`` raises ``ValueError`` otherwise).
    """

    def __init__(self, in_channels=2, out_channels=2, depth=5, hidden_dim=64, kernel_size=5, dropout=0.):
        super().__init__()
        self.depth = depth
        self.dropout = dropout
        self.conv_in = my_conv(in_channels, hidden_dim, kernel_size)
        self.conv_hidden = nn.ModuleList(
            [my_conv(hidden_dim, hidden_dim, kernel_size) for _ in range(self.depth - 2)]
        )
        self.conv_out = my_conv(hidden_dim, out_channels, kernel_size)
        # The dropout argument used to be stored but never applied.
        # nn.Dropout owns no parameters, so state_dict keys are unchanged.
        self.drop = nn.Dropout(p=dropout)

    def forward(self, x):
        """Run ``x`` through the input stage, all hidden stages and the output stage."""
        x = self.drop(self.conv_in(x))
        for layer in self.conv_hidden:
            x = self.drop(layer(x))
        # no dropout on the final stage: its activations are the prediction
        return self.conv_out(x)
       
def simple_cnn(params, **kwargs):
    """Build the ``CNN`` described by ``params``.

    Args:
        params: configuration object exposing ``in_chan``, ``out_chan`` and
            ``depth`` attributes.
        **kwargs: extra keyword arguments forwarded to ``CNN``.
            ``hidden_dim`` and ``kernel_size`` default to 64 and 3 and may be
            overridden here.

    Returns:
        The constructed ``CNN`` instance.
    """
    # setdefault keeps the historical defaults while letting callers override
    # them; passing them explicitly alongside **kwargs raised a duplicate
    # keyword TypeError whenever a caller supplied either one.
    kwargs.setdefault('hidden_dim', 64)
    kwargs.setdefault('kernel_size', 3)
    return CNN(in_channels=params.in_chan, out_channels=params.out_chan, depth=params.depth, **kwargs)


In [3]:
# some dummy dataloader
from torch.utils.data import DataLoader, Dataset

# see: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

def get_data_loader(params, location, train=True):
    """Create a ``DataLoader`` over ``TestDataSet``.

    Args:
        params: configuration object; reads ``local_batch_size`` (training) or
            ``local_valid_batch_size`` (validation) and ``num_data_workers``.
        location: data location handed to ``TestDataSet``.
        train: selects the training batch size and enables shuffling when
            true; validation batches are served in a fixed order.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    dataset = TestDataSet(params, location)
    batch_size = params.local_batch_size if train else params.local_valid_batch_size
    return DataLoader(dataset,
                      batch_size=int(batch_size),
                      num_workers=params.num_data_workers,
                      # Shuffle only for training: with drop_last=True a shuffled
                      # validation loader would score a different subset of
                      # samples every epoch.
                      shuffle=bool(train),
                      drop_last=True,
                      # pinned host memory is only requested when CUDA is present
                      pin_memory=torch.cuda.is_available())


class TestDataSet(Dataset):
    """Synthetic dataset serving random (input, target) image pairs.

    Args:
        params: configuration object. Its ``in_chan`` / ``out_chan``
            attributes, when present, set the channel counts of the generated
            input and target; both fall back to 1.
        location: input data location; stored but unused by this dummy dataset.
    """

    def __init__(self, params, location):
        self.params = params
        self.location = location # not used, but input data loc goes here
        self.n_samples = 128
        # Match the channel counts the model is built with (simple_cnn reads
        # the same params fields); previously both were hard-coded to 1.
        self.in_chan = getattr(params, 'in_chan', 1)
        self.out_chan = getattr(params, 'out_chan', 1)

    def __len__(self):
        """Number of samples the dataset reports."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return a fresh random ``(X, y)`` pair for a valid index.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Without this bound, iterating the dataset directly never terminates,
        # since Python's sequence protocol stops only on IndexError.
        if not -self.n_samples <= idx < self.n_samples:
            raise IndexError('index {} is out of range for {} samples'.format(idx, self.n_samples))
        X = torch.rand((self.in_chan, 128, 128))
        y = torch.rand((self.out_chan, 128, 128))
        return X, y


In [4]:
# Load the hyperparameters / run configuration from the YAML file
config_name = 'default'
config_file = os.path.abspath('./configs/default.yaml')
params = YParams(config_file, config_name)
# print the loaded configuration for the record
params.log()

------------------ Configuration ------------------
Configuration file: /global/u2/s/shas1693/codes/nersc-dl-multigpu/configs/default.yaml
Configuration name: default
num_data_workers 1
in_chan 1
out_chan 1
depth 5
lr 0.001
max_epochs 25
max_cosine_lr_epochs 25
batch_size 32
valid_batch_size 32
log_to_screen True
save_checkpoint True
train_path 
val_path 
---------------------------------------------------


In [5]:
# Pick the compute device: first GPU when CUDA is usable, otherwise the CPU
if not torch.cuda.is_available():
    device = torch.device('cpu')
else:
    torch.cuda.set_device(0)
    torch.backends.cudnn.benchmark = True
    # current_device() yields the device index, which .to() accepts directly
    device = torch.cuda.current_device()
print(device)

0


In [6]:
# Data setup.
# Only one process is used here, so the per-process ("local") batch sizes are
# simply the configured ("global") ones.
params['local_batch_size'] = params.batch_size
params['global_batch_size'] = params.batch_size
params['local_valid_batch_size'] = params.valid_batch_size
params['global_valid_batch_size'] = params.valid_batch_size

# One loader per split
train_data_loader = get_data_loader(
    params,
    location=params.train_path,
    train=True,
)
val_data_loader = get_data_loader(
    params,
    location=params.val_path,
    train=False,
)

In [7]:
# Model, optimizer, LR schedule and loss
model = simple_cnn(params)
model = model.to(device)  # place the weights on the selected device

# Adam with the configured learning rate, annealed along a cosine curve
# spanning max_cosine_lr_epochs scheduler steps
optimizer = optim.Adam(params=model.parameters(), lr=params.lr)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=params.max_cosine_lr_epochs)

# mean-squared-error objective
loss_func = nn.MSELoss()

In [8]:
# Main loop: one training pass and one validation pass per epoch
for epoch in range(params.max_epochs):
    start = time.time()

    # ---- training pass ----
    model.train()
    train_loss = 0.
    tr_start = time.time()
    for inputs, targets in train_data_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        # clear gradients accumulated by the previous step
        model.zero_grad()
        # forward
        u = model(inputs)
        loss = loss_func(u, targets)
        # running sum is kept detached from the autograd graph
        train_loss += loss.detach()
        # backward + parameter update
        loss.backward()
        optimizer.step()
    tr_time = time.time() - tr_start
    # average over the number of batches
    train_loss /= len(train_data_loader)

    # ---- validation pass (no gradients recorded) ----
    model.eval()
    val_loss = 0.
    val_start = time.time()
    with torch.no_grad():
        for inputs, targets in val_data_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            u = model(inputs)
            loss = loss_func(u, targets)
            val_loss += loss.detach()
    val_time = time.time() - val_start
    val_loss /= len(val_data_loader)

    # the LR schedule advances once per epoch
    scheduler.step()

    epoch_time = time.time() - start
    print('Time taken for epoch {} is {} sec; with {}/{} in tr/val'.format(epoch+1, epoch_time, tr_time, val_time))
    print('Loss = {}, Val loss = {}'.format(train_loss, val_loss))

Time taken for epoch 1 is 7.895143508911133 sec; with 7.710684061050415/0.1817774772644043 in tr/val
Loss = 0.2027476280927658, Val loss = 0.14298033714294434
Time taken for epoch 2 is 0.41819047927856445 sec; with 0.23258662223815918/0.18485164642333984 in tr/val
Loss = 0.10877930372953415, Val loss = 0.10980162024497986
Time taken for epoch 3 is 0.39698266983032227 sec; with 0.2169637680053711/0.17949652671813965 in tr/val
Loss = 0.09730541706085205, Val loss = 0.0945519506931305
Time taken for epoch 4 is 0.41175198554992676 sec; with 0.23145174980163574/0.17969727516174316 in tr/val
Loss = 0.09379316866397858, Val loss = 0.0875803604722023
Time taken for epoch 5 is 0.41367506980895996 sec; with 0.23390460014343262/0.17927885055541992 in tr/val
Loss = 0.09104160219430923, Val loss = 0.08907153457403183
Time taken for epoch 6 is 0.4241926670074463 sec; with 0.23750066757202148/0.18609142303466797 in tr/val
Loss = 0.08819828927516937, Val loss = 0.08918419480323792
Time taken for epoch