In [9]:
%matplotlib inline
from torchviz import make_dot
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from model import OriginalRelationshipLearner, Discriminator1, FlexibleUpsamplingModule, weights_init_normal, SSIM, TVLoss, PerceptualLoss
from datasets import CustomDataset, load_data_with_augmentation
import torch.nn.functional as F
from utils import plot_results
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
from scipy.ndimage import gaussian_filter, median_filter
from scipy.signal import savgol_filter
import matplotlib.pyplot as plt
import pandas as pd
from taylorDiagram import TaylorDiagram
from torchvision import models
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import copy
def visualize_model(model, filename,x):
    """Render the autograd graph of ``model(x)`` to ``<filename>.png``.

    Args:
        model: Callable model exposing ``named_parameters()``; its parameters
            are passed to torchviz's ``make_dot`` as the ``params`` mapping.
        filename: Output path without extension, passed to ``render``.
        x: Example input that is fed through ``model`` to build the graph.
    """
    # Run one forward pass so there is an output whose graph can be traced.
    output = model(x)
    named_params = {name: param for name, param in model.named_parameters()}
    graph = make_dot(output, params=named_params, show_attrs=True, show_saved=True)
    graph.format = 'png'
    graph.render(filename, cleanup=True)
    print(f"Model architecture saved as '{filename}.png'")
class ModelTrainer:
    """Adversarial trainer for an upsampling generator and a discriminator.

    The generator (``FlexibleUpsamplingModule``) receives a 2x-downsampled copy
    of ``lr_grace_05`` concatenated with a 4x-downsampled copy of ``hr_aux``
    and is trained to reproduce ``lr_grace_025``; ``Discriminator1`` is trained
    to tell generated fields from ``lr_grace_025``.

    Args:
        epochs: Maximum number of training epochs.
        batch_size: Batch size of the train and test ``DataLoader``s.
        relationship_learner: Accepted for backward compatibility; not used.
        relationship_output_channels: Stored on the instance; not used here.
        smoothing_method: Optional callable applied to ``hr_aux`` right after
            loading; it must return the smoothed array.
        attention: Value forwarded to ``FlexibleUpsamplingModule`` as
            ``attention_type``. It is kept in ``self.flag``; no separate
            attention module is created (``self.attention`` and
            ``self.attention_module`` are ``None``).
        senet: Optional ``torch.nn.Module`` applied to the generator input.
            Its parameters are optimised together with the generator.
        rand: ``random_state`` of the random 80/20 train/test split.
    """

    # On-disk checkpoint of the best generator weights seen during training.
    BEST_MODEL_PATH = 'best_model.pth'

    def __init__(self, epochs, batch_size, relationship_learner, relationship_output_channels, smoothing_method=None, attention=None, senet=None, rand=42):
        self.epochs = epochs
        self.batch_size = batch_size
        self.relationship_output_channels = relationship_output_channels
        self.smoothing_method = smoothing_method
        self.senet = senet
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.rand = rand

        # Load and prepare data
        [self.lr_grace_05, self.trend05], [self.lr_grace_025, self.trend25], self.hr_aux, self.grace_scaler_05, self.grace_scaler_025, self.aux_scalers = load_data_with_augmentation()

        # Apply data smoothing to hr_aux if a smoothing method is specified
        if self.smoothing_method:
            self.hr_aux = self.smoothing_method(self.hr_aux)

        # Random 80/20 split, reproducible through `rand`.
        # NOTE(review): this shuffles samples; if they are time-ordered the
        # test set is interleaved with the training period - confirm intended.
        (self.train_lr_grace_05, self.test_lr_grace_05,
         self.train_lr_grace_025, self.test_lr_grace_025,
         self.train_hr_aux, self.test_hr_aux) = train_test_split(
            self.lr_grace_05, self.lr_grace_025, self.hr_aux,
            test_size=0.2, random_state=self.rand)

        # Create datasets and dataloaders
        self.train_dataset = CustomDataset(self.train_lr_grace_05, self.train_lr_grace_025, self.train_hr_aux)
        self.test_dataset = CustomDataset(self.test_lr_grace_05, self.test_lr_grace_025, self.test_hr_aux)

        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size)
        self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size)

        # Initialize models. The generator input has one channel more than the
        # last axis of hr_aux (presumably aux channels + GRACE - TODO confirm).
        self.discriminator = Discriminator1().to(self.device)
        self.upsampling_module = FlexibleUpsamplingModule(input_channels=self.hr_aux.shape[-1]+1, attention_type=attention).to(self.device)

        # `attention` only selects the attention type inside the generator;
        # it is not a module, so no stand-alone attention module exists.
        self.flag = attention
        self.attention = None
        self.attention_module = None

        self.senet_module = self.senet.to(self.device) if self.senet is not None else None

        # Initialize weights
        self.discriminator.apply(weights_init_normal)
        self.upsampling_module.apply(weights_init_normal)
        for module in self._optional_modules():
            module.apply(weights_init_normal)

        # Generator-side parameters: upsampling module plus optional modules,
        # all driven by optimizer_U / scheduler_U.
        hat_parameters = list(self.upsampling_module.parameters())
        for module in self._optional_modules():
            hat_parameters += list(module.parameters())

        # Optimizers
        self.optimizer_D = optim.AdamW(self.discriminator.parameters(), lr=0.0004, betas=(0.5, 0.999), weight_decay=1e-4)
        self.optimizer_U = optim.AdamW(hat_parameters, lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-4)

        # Learning Rate Schedulers
        self.scheduler_D = CosineAnnealingWarmRestarts(self.optimizer_D, T_0=10, T_mult=2, eta_min=1e-6)
        self.scheduler_U = CosineAnnealingWarmRestarts(self.optimizer_U, T_0=10, T_mult=2, eta_min=1e-6)

        # Loss functions
        self.adversarial_loss = torch.nn.BCEWithLogitsLoss()
        self.pixelwise_loss = torch.nn.MSELoss()
        self.ssim_loss = SSIM(window_size=11, size_average=True).to(self.device)
        self.tv_loss = TVLoss(weight=1e-5).to(self.device)
        self.perceptual_loss = PerceptualLoss(use_gpu=torch.cuda.is_available())

    def smooth_data_gaussian(self, data, sigma=2):
        """Return ``data`` smoothed with ``scipy.ndimage.gaussian_filter``."""
        return gaussian_filter(data, sigma=sigma)

    def smooth_data_median(self, data, size=3):
        """Return ``data`` smoothed with ``scipy.ndimage.median_filter``."""
        return median_filter(data, size=size)

    def smooth_data_savitzky_golay(self, data, window_length=5, polyorder=2):
        """Return ``data`` smoothed with ``scipy.signal.savgol_filter``."""
        return savgol_filter(data, window_length, polyorder)

    def _optional_modules(self):
        """Return the optional pre-processing modules that are actually set."""
        candidates = (self.attention_module, self.senet_module)
        return [module for module in candidates if module is not None]

    def _generate(self, lr_grace_05, hr_aux):
        """Run the generator on one batch whose tensors are already on ``self.device``."""
        lr_grace = F.interpolate(lr_grace_05, scale_factor=0.5, mode='bicubic', align_corners=False)
        downsampled_aux = F.interpolate(hr_aux, scale_factor=0.25, mode='bicubic', align_corners=False)
        features = torch.cat([lr_grace, downsampled_aux], dim=1)

        # The attention module takes precedence over SENet when both are set.
        if self.attention_module is not None:
            features = self.attention_module(features)
        elif self.senet_module is not None:
            features = self.senet_module(features)

        return self.upsampling_module(features)

    def train(self):
        """Train generator and discriminator with early stopping.

        Early stopping monitors the mean generator training loss (patience of
        20 epochs). The best generator weights are kept in memory, written to
        ``BEST_MODEL_PATH`` and restored into ``self.upsampling_module`` before
        returning.

        Returns:
            ``(train_losses_G, train_losses_D)``: per-epoch mean generator and
            discriminator losses, including the epoch that triggered early
            stopping.

        Raises:
            ValueError: If ``self.train_loader`` yields no batches.
        """
        num_batches = len(self.train_loader)
        if num_batches == 0:
            raise ValueError('train_loader is empty; nothing to train on.')

        train_losses_G = []
        train_losses_D = []
        patience = 20  # Early stopping patience
        min_delta = 0  # Minimum change in monitored value to qualify as improvement
        trigger_times = 0  # Counter for early stopping
        best_loss = float('inf')
        best_state = None  # Stays None if no epoch ever improves (e.g. NaN loss)

        for epoch in range(self.epochs):
            epoch_loss_G = 0
            epoch_loss_D = 0

            # Training phase
            self.upsampling_module.train()
            self.discriminator.train()
            for module in self._optional_modules():
                module.train()

            # Linearly shift weight from the pixel loss to the adversarial loss.
            # NOTE(review): the monitored loss therefore changes definition
            # from epoch to epoch, which also affects early stopping.
            loss_weight = epoch / self.epochs

            for lr_grace_05, lr_grace_025, hr_aux in self.train_loader:
                lr_grace_05 = lr_grace_05.to(self.device)
                lr_grace_025 = lr_grace_025.to(self.device)
                hr_aux = hr_aux.to(self.device)

                hr_generated = self._generate(lr_grace_05, hr_aux)

                # Discriminator training (generator output detached)
                self.optimizer_D.zero_grad()
                real_output = self.discriminator(lr_grace_025)
                fake_output = self.discriminator(hr_generated.detach())
                real_labels = torch.ones_like(real_output, device=self.device)
                fake_labels = torch.zeros_like(fake_output, device=self.device)

                loss_D_real = self.adversarial_loss(real_output, real_labels)
                loss_D_fake = self.adversarial_loss(fake_output, fake_labels)
                loss_D = (loss_D_real + loss_D_fake) / 2
                loss_D.backward()
                self.optimizer_D.step()

                # Generator training
                self.optimizer_U.zero_grad()
                fake_output = self.discriminator(hr_generated)
                loss_G_adv = self.adversarial_loss(fake_output, real_labels)
                loss_G_pixel = self.pixelwise_loss(hr_generated, lr_grace_025)
                loss_G_tv = self.tv_loss(hr_generated)
                loss_G_perceptual = self.perceptual_loss(hr_generated, lr_grace_025)
                # NOTE(review): self.ssim_loss exists but was never part of the
                # objective; it is deliberately not added here.
                loss_G = (1 - loss_weight) * loss_G_pixel + loss_weight * loss_G_adv + loss_G_tv + loss_G_perceptual
                loss_G.backward()
                self.optimizer_U.step()

                epoch_loss_G += loss_G.item()
                epoch_loss_D += loss_D.item()

            # Average losses over the epoch
            avg_epoch_loss_G = epoch_loss_G / num_batches
            avg_epoch_loss_D = epoch_loss_D / num_batches
            train_losses_G.append(avg_epoch_loss_G)
            train_losses_D.append(avg_epoch_loss_D)

            print(f'Epoch [{epoch+1}/{self.epochs}], Loss D: {avg_epoch_loss_D:.4f}, Loss G: {avg_epoch_loss_G:.4f}')

            # Early Stopping Check
            if avg_epoch_loss_G < best_loss - min_delta:
                best_loss = avg_epoch_loss_G
                trigger_times = 0
                # Keep an in-memory snapshot so restoring never depends on a
                # (possibly missing or stale) file; also persist it to disk.
                best_state = copy.deepcopy(self.upsampling_module.state_dict())
                torch.save(best_state, self.BEST_MODEL_PATH)
            else:
                trigger_times += 1
                print(f'EarlyStopping: {trigger_times}/{patience} epochs with no improvement.')
                if trigger_times >= patience:
                    print('Early stopping triggered.')
                    break

            # Update the schedulers at the end of the epoch. Optional modules
            # share optimizer_U, so scheduler_U already covers them.
            self.scheduler_D.step()
            self.scheduler_U.step()

        # Restore the best generator weights seen during training
        if best_state is not None:
            self.upsampling_module.load_state_dict(best_state)
        return train_losses_G, train_losses_D

    def evaluate(self, plot_batch=None):
        """Evaluate the generator on the test set and print summary metrics.

        Args:
            plot_batch: Optional 1-based index of a test batch whose first
                sample is passed to ``plot_results``. ``None`` (default)
                disables plotting.

        Returns:
            ``(preds, trues, r2)``: flattened predictions, flattened targets
            (``lr_grace_025``) and the R² score between them.
        """
        self.upsampling_module.eval()
        for module in self._optional_modules():
            module.eval()

        preds = []
        trues = []
        with torch.no_grad():
            for batch_idx, (lr_grace_05, lr_grace_025, hr_aux) in enumerate(self.test_loader, start=1):
                lr_grace_05 = lr_grace_05.to(self.device)
                lr_grace_025 = lr_grace_025.to(self.device)
                hr_aux = hr_aux.to(self.device)

                hr_generated = self._generate(lr_grace_05, hr_aux)

                if plot_batch is not None and batch_idx == plot_batch:
                    plot_results(lr_grace_05[0,0].cpu(), hr_generated[0,0].cpu(), lr_grace_025[0,0].cpu(), True)

                # Save predictions and true values for metrics calculation
                preds.append(hr_generated.cpu().numpy())
                trues.append(lr_grace_025.cpu().numpy())

        # Compute evaluation metrics on the flattened fields
        preds = np.concatenate(preds, axis=0).reshape(-1)
        trues = np.concatenate(trues, axis=0).reshape(-1)

        # np.corrcoef returns the 2x2 correlation matrix; the off-diagonal
        # entry is the correlation between targets and predictions.
        cc = np.corrcoef(trues, preds)[0, 1]
        mse = mean_squared_error(trues, preds)
        mae = mean_absolute_error(trues, preds)
        r2 = r2_score(trues, preds)

        print(f"Test MSE: {mse}, Test MAE: {mae}, Test R²: {r2}, Test cc: {cc}")

        return preds, trues, r2


In [10]:
# Experiment configuration
epochs = 150
batch_size = 12

# Smoothing of the auxiliary data is disabled for these runs. To enable it,
# pass any callable that takes and returns the aux array, e.g.
# `lambda a: gaussian_filter(a, sigma=2)`. No ModelTrainer instance is needed
# for that; building one only to borrow a smoothing method would load the
# whole dataset and construct every model for nothing.
smoothing_method = None

# Settings shared by both runs; the runs differ only in the split seed `rand`.
trainer_kwargs = dict(
    epochs=epochs,
    batch_size=batch_size,
    relationship_output_channels=1024,
    smoothing_method=smoothing_method,
    attention='senet',
)

# Run 1: default train/test split seed (rand=42).
model1 = ModelTrainer(relationship_learner=OriginalRelationshipLearner(40), **trainer_kwargs)
train_losses_G1, train_losses_D1 = model1.train()
preds1, trues1, r2_1 = model1.evaluate()
torch.cuda.empty_cache()  # release cached GPU memory before the next run

# Run 2: same configuration with a different split seed.
model2 = ModelTrainer(relationship_learner=OriginalRelationshipLearner(40), rand=26, **trainer_kwargs)
train_losses_G2, train_losses_D2 = model2.train()
preds2, trues2, r2_2 = model2.evaluate()
torch.cuda.empty_cache()

(181, 90, 44)


FileNotFoundError: [Errno 2] No such file or directory: '/media/xy/data_op/ERA5/11/'

In [None]:
# Persist the trained upsampling-module weights of both runs.
for trainer, weights_path in (
    (model1, 'model11_upsampling_module.pth'),
    (model2, 'model12_upsampling_module.pth'),
):
    torch.save(trainer.upsampling_module.state_dict(), weights_path)