In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.utils.data as t_data

In [5]:
from utils import WeatherDatasetSimple
from models import InterpolationModel, SRCNN
from model_eval import eval_model_single_frame, eval_results

In [3]:
device = 'cuda'

In [5]:
train_dataset = WeatherDatasetSimple('data/train_T2_V10_U10_d02_new_lr_npy', 'data/train_T2_V10_U10_d02_new_hr_npy', normalization=False,)
valid_dataset = WeatherDatasetSimple('data/true_val_T2_V10_U10_d02_new_lr_npy', 'data/true_val_T2_V10_U10_d02_new_hr_npy', normalization=False,)
test_dataset = WeatherDatasetSimple('data/val_T2_V10_U10_d02_new_lr', 'data/val_T2_V10_U10_d02_new_hr', normalization=False,)

In [6]:
batch_size = 72
train_dataloader = t_data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataloader = t_data.DataLoader(valid_dataset, batch_size=1, shuffle=True)
test_dataloader = t_data.DataLoader(test_dataset, batch_size=1, shuffle=True)

In [6]:
mns_lr, mns_hr = [], []
for x in train_dataset:
    mns_lr.append( np.mean(x[0], axis=(1,2)) )
    mns_hr.append( np.mean(x[1], axis=(1,2)) )

print("High Res data : mean ", np.mean(mns_hr, 0), ' std: ', np.std(mns_hr, 0) )
print(" Low Res data : mean ", np.mean(mns_lr, 0), ' std: ', np.std(mns_lr, 0) )

High Res data : mean  [0.6636226  0.5109665  0.50087802]  std:  [0.11283001 0.06067654 0.06711711]
 Low Res data : mean  [0.66322656 0.51177405 0.50098086]  std:  [0.11245459 0.06100276 0.06742526]


Data at different channels have pretty similar distributions and is contained within [0; 1], so we decided not to standartidize it.

## Bicubic Interpolation

In [29]:
i_model = InterpolationModel().to(device)

In [30]:
result = eval_model_single_frame(i_model, test_dataloader, device=device)
print("Interpolation model results: ")
print("MSE:  ", result['mse'])
print("MAE:  ", result['mae'])
print("PSNR: ", result['psnr'])

Interpolation model results: 
MSE:   0.0002186244701581322
MAE:   0.009801893723107144
PSNR:  29.82986338738031


## SRCNN

In [6]:
import os
import copy
from tqdm import tqdm

In [7]:
srcnn_model = SRCNN(num_channels=3).to(device)

In [15]:
lr = 1e-4
num_epochs = 200
criterion = nn.MSELoss()
optimizer = torch.optim.Adam([
    {'params': srcnn_model.conv1.parameters()},
    {'params': srcnn_model.conv2.parameters()},
    {'params': srcnn_model.conv3.parameters(), 'lr': 0.1 * lr}
], lr=lr)
outputs_dir = 'Models/Baselines/TRAIN'

In [9]:
best_weights = copy.deepcopy(srcnn_model.state_dict())
best_epoch = 0
best_psnr = 0.0

In [None]:
for epoch in range(30, num_epochs):
    srcnn_model.train()

    with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
        t.set_description('epoch: {}/{}'.format(epoch, num_epochs - 1))
        cur_loss = 0.0
        for lr, hr in train_dataloader:
            lr = lr.to(device)
            hr = hr.to(device)

            hr_pred = srcnn_model(lr)

            loss = criterion(hr_pred, hr)
            optimizer.zero_grad()
            cur_loss += loss.item()  / (len(train_dataset) // batch_size)
            
            loss.backward()
            optimizer.step()

            t.set_postfix(loss='{:.6f}'.format(cur_loss))
            t.update(len(lr))

        #torch.save(model.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))
        srcnn_model.eval()
        epoch_result = eval_model_single_frame(srcnn_model, valid_dataloader,  device=device)
        print('Epoch ', epoch, ' train MSE loss ', cur_loss, 'eval psnr: {:.2f}'.format(epoch_result['psnr']))

        if epoch_result['psnr'] > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_result['psnr']
            best_weights = copy.deepcopy(srcnn_model.state_dict())

print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
torch.save(best_weights, os.path.join(outputs_dir, 'SRCNN_best.pth'))

In [26]:
srcnn_model = SRCNN(num_channels=3).to(device)
srcnn_model.load_state_dict(torch.load('Models/Baselines/TRAIN/SRCNN_best.pth'))
srcnn_model.eval()

SRCNN(
  (conv1): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (conv2): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (relu): ReLU(inplace=True)
)

In [27]:
result = eval_model_single_frame(srcnn_model, test_dataloader, device=device)
print("SRCNN model best results: ")
print("MSE:  ", result['mse'])
print("MAE:  ", result['mae'])
print("PSNR: ", result['psnr'])

SRCNN model best results: 
MSE:   0.00019476998141337054
MAE:   0.009489991709880274
PSNR:  30.034115645288075


## SRFBN

In [8]:
test_dataset = WeatherDatasetSimple('data/val_T2_V10_U10_d02_new_hr', 'Baseline_predicts/SRFBN')
check_dataloader = t_data.DataLoader(test_dataset, batch_size=1, shuffle=False)

result = eval_results(check_dataloader)
print("Interpolation model results: ")
print("MSE:  ", result['mse'])
print("MAE:  ", result['mae'])
print("PSNR: ", result['psnr'])

Interpolation model results: 
MSE:   0.00022177140316468976
MAE:   0.009790236856362954
PSNR:  29.71257377434382


## ESRGAN 

In [10]:
test_dataset = WeatherDatasetSimple('data/val_T2_V10_U10_d02_new_hr', 'Baseline_predicts/ESRGAN')
check_dataloader = t_data.DataLoader(test_dataset, batch_size=1, shuffle=False)

result = eval_results(check_dataloader)
print("Interpolation model results: ")
print("MSE:  ", result['mse'])
print("MAE:  ", result['mae'])
print("PSNR: ", result['psnr'])

Interpolation model results: 
MSE:   0.0013496031888309233
MAE:   0.02381261536051019
PSNR:  23.814576119438854
