In [None]:
# ! pip install scikit-image
# ! pip install torch torchmetrics torchvision
# ! nvidia-smi

# Import Packages

In [None]:
import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from tqdm import tqdm

sys.path.append("../../src")
from utils.NOAHminiDataset import *
from utils.NOAHModelUNetFiLM import *

# Set-up Environment

In [None]:
# --- Model input sizes ---
num_channels = 5  # channels of the static_rs image fed to the model as x_img
conditioning_dim = (24 * 6) + 4  # 24hrs for 6 features + lat, long, elev, cloud cover

# --- Optimisation hyper-parameters ---
learning_rate = 0.00001
batch_size = 2
num_epochs = 100

# Seed torch's global RNG (used for weight initialisation and DataLoader
# shuffling) so that Restart & Run All reproduces the same models.
seed = 42
torch.manual_seed(seed)

# Compute device; alternatives: 'cpu' or 'cuda:0'.
device = 'cuda:1'

In [None]:
# Dataset class for each band key. Every band gets the same three splits, so the
# dict is built from this mapping instead of repeating the constructors by hand.
band_dataset_classes = {
    'B1': NOAHminiB1Dataset,
    'B2': NOAHminiB2Dataset,
    'B3': NOAHminiB3Dataset,
    'B4': NOAHminiB4Dataset,
    'B5': NOAHminiB5Dataset,
    'B6': NOAHminiB6Dataset,
    'B7': NOAHminiB7Dataset,
    'B8': NOAHminiB8Dataset,
    'B9': NOAHminiB9Dataset,
    'B10': NOAHminiB10Dataset,
    'B11': NOAHminiB11Dataset,
}

# band -> {'train_dataset', 'test_dataset', 'val_dataset'}.
# The train split relies on the class's default ``dataset_csv``; test/val pass
# their CSV explicitly. Datasets are constructed eagerly, in band order.
dataset_dict = {
    band: {
        'train_dataset': dataset_cls(device = device),
        'test_dataset': dataset_cls(
            device = device,
            dataset_csv = 'test.csv'
        ),
        'val_dataset': dataset_cls(
            device = device,
            dataset_csv = 'val.csv'
        ),
    }
    for band, dataset_cls in band_dataset_classes.items()
}

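# Train Models

For each band in `dataset_dict`, train a fresh UNet-FiLM with an SSIM loss. After every epoch, report the test-split MSE, SSIM and PSNR, then save the final weights.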

In [None]:
def train_one_epoch(model, dataloader, optimizer, criterion):
    """Run one optimisation pass of ``model`` over ``dataloader``.

    Each batch is unpacked as ``(gbs, static_rs, targets)``; ``static_rs`` is
    passed to the model as ``x_img`` and ``gbs`` as ``x_weather``. The latest
    batch loss is shown on the tqdm progress bar.
    """
    model.train()
    progress = tqdm(dataloader, desc = "Batch")
    for gbs, static_rs, targets in progress:
        optimizer.zero_grad()
        outputs = model(x_img = static_rs, x_weather = gbs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        progress.set_postfix({'loss': loss.item()})


@torch.no_grad()  # evaluation must not build autograd graphs (wastes GPU memory)
def evaluate(model, dataloader, ssim, psnr, mse):
    """Evaluate ``model`` on ``dataloader``.

    Returns:
        ``(mean_mse, mean_ssim, mean_psnr)``: per-batch metric values averaged
        with weights equal to each batch's size. All three are NaN when the
        loader yields no samples (instead of raising ZeroDivisionError).
    """
    model.eval()
    # Module metrics accumulate internal state on every call; start each pass fresh.
    ssim.reset()
    psnr.reset()

    running_loss = running_ssim = running_psnr = 0.0
    data_count = 0
    progress = tqdm(dataloader, desc = "Batch test")
    for gbs, static_rs, targets in progress:
        outputs = model(x_img = static_rs, x_weather = gbs)

        # torchmetrics expects (preds, target). PSNR without an explicit
        # data_range derives the range from `target`, so the order matters.
        ssim_val = ssim(outputs, targets).item()
        psnr_val = psnr(outputs, targets).item()
        loss = mse(outputs, targets).item()

        n = gbs.size(0)
        running_loss += loss * n
        running_ssim += ssim_val * n
        running_psnr += psnr_val * n
        data_count += n

        progress.set_postfix({
            'mse': loss,
            'ssim': ssim_val,
            'psnr': psnr_val
        })

    if data_count == 0:
        nan = float('nan')
        return nan, nan, nan
    return (
        running_loss / data_count,
        running_ssim / data_count,
        running_psnr / data_count,
    )


# CUDA-only housekeeping is skipped when running on 'cpu'.
is_cuda = torch.device(device).type == 'cuda'

for band in dataset_dict:
    print(f"Working on {band}....")
    # build a fresh model per band
    model = UNetFiLM(
        in_channels = num_channels,
        conditioning_dim = conditioning_dim
    ).to(device)

    # get dataset
    train_dataset = dataset_dict[band]['train_dataset']
    train_dataloader = DataLoader(
        train_dataset,
        batch_size = batch_size,
        shuffle = True,
        drop_last = True
    )
    test_dataset = dataset_dict[band]['test_dataset']
    # Evaluation must see every test sample: no shuffling, keep the final
    # partial batch. NOTE(review): assumes the model accepts a batch of size 1.
    test_dataloader = DataLoader(
        test_dataset,
        batch_size = batch_size,
        shuffle = False,
        drop_last = False
    )

    optimizer = optim.Adam(
        model.parameters(),
        lr = learning_rate
    )

    # build metrics: SSIM loss for training; MSE / SSIM / PSNR reported on test
    criterion = SSIMLoss(data_range=1.0).to(device)
    ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    psnr = PeakSignalNoiseRatio().to(device)
    mse = nn.MSELoss()

    for epoch in range(num_epochs):
        train_one_epoch(model, train_dataloader, optimizer, criterion)

        # release cached GPU memory between training and testing
        if is_cuda:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats(device)

        test_mse, test_ssim, test_psnr = evaluate(
            model, test_dataloader, ssim, psnr, mse
        )
        print(f"Epoch: {epoch+1}/{num_epochs} (Test)-> Loss:{test_mse:.6f}  SSIM:{test_ssim:.6f}  PSNR:{test_psnr:.6f}")

    # save model (weights after the final epoch)
    model_path = f"../../cache/saved_models/UNet_FiLM_{band}_ssim_all.pth"
    # create the directory up front so a long run does not fail at save time
    os.makedirs(os.path.dirname(model_path), exist_ok = True)
    print(f"Saving Model... -> {model_path}")
    model = model.cpu()
    torch.save(model.state_dict(), model_path)

    # clear cache
    if is_cuda:
        torch.cuda.empty_cache()
            