In [None]:
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from torch.cuda.amp import GradScaler, autocast

from model import (
    Discriminator1,
    HAT,
    weights_init_normal,
    SSIM,
    TVLoss,
    PerceptualLoss,
    LPIPSNet
)
from datasets import CustomDataset, load_data_with_augmentation

class ModelTrainer:
    """Adversarial trainer that pairs a ``HAT`` generator with a ``Discriminator1``.

    For every batch the generator input is built from ``lr_grace_05`` resampled by a
    factor of 0.5 and ``hr_aux`` resampled by a factor of 0.25, concatenated along the
    channel axis (optionally passed through ``senet_module``). The generator output is
    compared against ``lr_grace_025`` with pixel (MSE), perceptual, total-variation,
    adversarial and (optionally) LPIPS terms.

    Configuration keys (``config`` overrides the defaults):
        lr_D, lr_U: AdamW learning rates of the discriminator / generator.
        adversarial_weight_start, adversarial_weight_end: end points of the linear
            per-epoch ramp applied to the adversarial term of the generator loss.
        lpips_weight: multiplier of the LPIPS term.
        patience, delta: early-stopping patience (epochs) and minimum improvement.
        lpips_adapt_epochs, lpips_margin: settings used by ``adapt_lpips``.
        grad_clip: maximum gradient norm; values <= 0 disable clipping.
        use_lpips: include the LPIPS term in the generator/validation loss.
        use_mixed_precision: request AMP (only honoured on CUDA devices).
        checkpoint_path: file that receives the best generator weights.
    """

    def __init__(
        self,
        epochs,
        batch_size,
        smoothing_method=None,
        attention_type=None,
        senet_module=None,
        random_seed=42,
        lpips_weights_path=None,
        single_channel=True,
        config=None
    ):
        """Load the data, split it 70/15/15 and build models, optimizers and losses.

        Args:
            epochs: Maximum number of training epochs.
            batch_size: Batch size shared by all data loaders.
            smoothing_method: Optional callable applied to ``hr_aux`` before the
                dataset is built.
            attention_type: Stored on the instance; not used elsewhere in this class.
            senet_module: Optional module applied to the concatenated generator input.
                Its parameters are optimized together with the generator.
            random_seed: Seed of the generator used for the train/val/test split.
            lpips_weights_path: Optional path handed to ``LPIPSNet.load_lpips_weights``.
            single_channel: Forwarded to ``LPIPSNet``; also gates LPIPS adaptation in
                ``train``.
            config: Optional dict overriding entries of the default configuration.
        """
        # Default configuration
        default_config = {
            'lr_D': 0.0004,
            'lr_U': 0.0002,
            'adversarial_weight_start': 0.0,
            'adversarial_weight_end': 1.0,
            'lpips_weight': 0.5,
            'patience': 30,
            'delta': 0.0001,
            'lpips_adapt_epochs': 10,
            'lpips_margin': 0.5,
            'grad_clip': 1.0,
            'use_lpips': True,
            'use_mixed_precision': True,
            'checkpoint_path': 'best_model.pth'
        }

        if config is not None:
            default_config.update(config)
        self.cfg = default_config

        self.epochs = epochs
        self.batch_size = batch_size
        self.smoothing_method = smoothing_method
        self.attention_type = attention_type
        self.senet_module = senet_module
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.random_seed = random_seed
        self.single_channel = single_channel
        self.lpips_weights_path = lpips_weights_path

        # Load and prepare data
        (
            [self.lr_grace_05, self.trend05],
            [self.lr_grace_025, self.trend25],
            self.hr_aux,
            self.grace_scaler_05,
            self.grace_scaler_025,
            self.aux_scalers,
        ) = load_data_with_augmentation()

        # Apply smoothing if specified
        if self.smoothing_method:
            self.hr_aux = self.smoothing_method(self.hr_aux)

        dataset = CustomDataset(self.lr_grace_05, self.lr_grace_025, self.hr_aux)
        print(np.shape(self.hr_aux))

        # Split dataset into train/val/test (70/15/15); the test split absorbs the
        # rounding remainder so the three sizes always sum to len(dataset).
        train_size = int(0.7 * len(dataset))
        val_size = int(0.15 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            dataset,
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(self.random_seed),
        )

        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
        self.val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4)
        self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4)

        # Initialize models
        self.discriminator = Discriminator1().to(self.device)
        self.HAT = HAT(
            # One channel for the low-resolution field plus the size of hr_aux's last axis.
            in_channels=self.hr_aux.shape[-1] + 1,
            out_channels=1,
            channels=64,
            num_groups=4,
            num_habs=6,
            window_size=8,
            num_heads=8,
            upscale_factor=4,
            device=self.device
        ).to(self.device)

        if self.senet_module:
            self.senet_module = self.senet_module.to(self.device)
        else:
            self.senet_module = None

        # Initialize weights
        self.discriminator.apply(weights_init_normal)
        self.HAT.apply(weights_init_normal)
        if self.senet_module:
            self.senet_module.apply(weights_init_normal)

        # LPIPS model
        self.lpips = LPIPSNet(single_channel=self.single_channel, normalize_inputs=True).to(self.device)
        if self.lpips_weights_path:
            self.lpips.load_lpips_weights(self.lpips_weights_path)

        # Optimizers
        hat_parameters = list(self.HAT.parameters())
        if self.senet_module:
            hat_parameters += list(self.senet_module.parameters())

        self.optimizer_D = torch.optim.AdamW(
            self.discriminator.parameters(),
            lr=self.cfg['lr_D'],
            betas=(0.5, 0.999),
            weight_decay=1e-4,
        )
        self.optimizer_U = torch.optim.AdamW(
            hat_parameters,
            lr=self.cfg['lr_U'],
            betas=(0.5, 0.999),
            weight_decay=1e-4,
        )

        # Schedulers
        self.scheduler_D = CosineAnnealingWarmRestarts(self.optimizer_D, T_0=10, T_mult=2, eta_min=1e-6)
        self.scheduler_U = CosineAnnealingWarmRestarts(self.optimizer_U, T_0=10, T_mult=2, eta_min=1e-6)

        # Loss functions
        self.adversarial_loss = nn.BCEWithLogitsLoss()
        self.pixelwise_loss = nn.MSELoss()
        self.ssim_loss = SSIM(window_size=11, size_average=True).to(self.device)
        self.tv_loss = TVLoss(weight=1e-5).to(self.device)
        self.perceptual_loss = PerceptualLoss(use_gpu=torch.cuda.is_available())

        # Gradient scaler for mixed precision; only enabled when AMP can actually run.
        self.scaler = GradScaler(enabled=self._amp_enabled())

    def _amp_enabled(self):
        """Return True when mixed precision was requested and the device is CUDA."""
        return bool(self.cfg['use_mixed_precision']) and self.device.type == 'cuda'

    def _prepare_batch(self, lr_grace_05, lr_grace_025, hr_aux):
        """Build the generator input and the target for one batch.

        Args:
            lr_grace_05: Batch resampled by 0.5 to form the low-resolution input.
            lr_grace_025: Batch used as the reconstruction target.
            hr_aux: Auxiliary batch resampled by 0.25 and concatenated to the input.

        Returns:
            Tuple ``(combined_input, target)`` with both tensors on ``self.device``.
        """
        lr_grace = F.interpolate(
            lr_grace_05, scale_factor=0.5, mode='bicubic', align_corners=False
        ).to(self.device)
        target = lr_grace_025.to(self.device)
        hr_aux = hr_aux.to(self.device)

        downsampled_aux = F.interpolate(hr_aux, scale_factor=0.25, mode='bicubic', align_corners=False)
        combined_input = torch.cat([lr_grace, downsampled_aux], dim=1)
        if self.senet_module:
            combined_input = self.senet_module(combined_input)
        return combined_input, target

    def adapt_lpips(self):
        """
        Adapt LPIPS using lr_grace_025 data:
        Positive pairs: (x, x)
        Negative pairs: (x, degraded_x) where degraded_x is down-up sampled to original size.

        Only the parameters of ``self.lpips.lins`` are trained; afterwards the whole
        LPIPS network is switched to eval mode and frozen.
        """
        self.lpips.train()
        # Freeze backbone, only train lin layers
        for p in self.lpips.parameters():
            p.requires_grad = False
        for lin in self.lpips.lins:
            for p in lin.parameters():
                p.requires_grad = True

        lpips_optimizer = torch.optim.Adam([p for lin in self.lpips.lins for p in lin.parameters()], lr=1e-4)
        adapt_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

        margin = self.cfg['lpips_margin']
        for ep in range(self.cfg['lpips_adapt_epochs']):
            total_loss = 0.0
            count = 0
            # Only the lr_grace_025 element of each batch is used for adaptation.
            for _, lr_grace_025, _ in adapt_loader:
                x = lr_grace_025.to(self.device)

                # Positive pair: (x, x)
                # NOTE(review): the distance of a tensor to itself is presumably ~0 and
                # would then contribute no gradient - confirm against LPIPSNet.
                pos_dist = self.lpips(x, x)

                # Negative pair: create degraded_x by downscaling and then upscaling back to original shape
                height, width = x.shape[-2:]
                degraded_x_down = F.interpolate(x, scale_factor=0.5, mode='bicubic', align_corners=False)
                degraded_x = F.interpolate(degraded_x_down, size=(height, width), mode='bicubic', align_corners=False)

                neg_dist = self.lpips(x, degraded_x)

                # Loss: minimize pos_dist and ensure neg_dist > margin
                # L = pos_dist + max(0, margin - neg_dist)
                loss = pos_dist + F.relu(margin - neg_dist)

                lpips_optimizer.zero_grad()
                loss.backward()
                lpips_optimizer.step()

                total_loss += loss.item()
                count += 1

            avg_loss = total_loss / max(count, 1)
            print(f"LPIPS Adaptation Epoch [{ep+1}/{self.cfg['lpips_adapt_epochs']}], Loss: {avg_loss:.4f}")

        # Freeze LPIPS after adaptation
        self.lpips.eval()
        for p in self.lpips.parameters():
            p.requires_grad = False
        print("LPIPS adaptation complete. LPIPS is now domain-adjusted.")

    def train_main_model(self):
        """Run the adversarial training loop with validation and early stopping.

        The weights of the best validation epoch are written to
        ``cfg['checkpoint_path']`` and restored before returning. If no epoch of this
        run produced a checkpoint, the current weights are kept instead of loading a
        possibly stale file.

        Returns:
            Tuple ``(train_losses_G, train_losses_D, val_losses_G)`` of per-epoch
            average losses (one entry per completed epoch).
        """
        train_losses_G = []
        train_losses_D = []
        val_losses_G = []

        patience = self.cfg['patience']
        delta = self.cfg['delta']
        best_loss = float('inf')
        trigger_times = 0

        use_amp = self._amp_enabled()
        checkpoint_path = self.cfg.get('checkpoint_path', 'best_model.pth')
        best_saved = False  # True once this run has written a checkpoint
        # The checkpoint file stores HAT only, so the matching senet weights are
        # kept in memory to restore a consistent pair at the end.
        best_senet_state = None

        hat_parameters = list(self.HAT.parameters())
        if self.senet_module:
            hat_parameters += list(self.senet_module.parameters())

        adv_start = self.cfg['adversarial_weight_start']
        adv_end = self.cfg['adversarial_weight_end']
        # Dividing by (epochs - 1) lets the last epoch reach adv_end exactly.
        ramp_steps = max(self.epochs - 1, 1)

        for epoch in range(self.epochs):
            epoch_loss_G = 0.0
            epoch_loss_D = 0.0

            self.HAT.train()
            self.discriminator.train()
            if self.senet_module:
                self.senet_module.train()

            adv_weight = adv_start + (adv_end - adv_start) * (epoch / ramp_steps)

            for lr_grace_05, lr_grace_025, hr_aux in self.train_loader:
                combined_input, lr_grace_025 = self._prepare_batch(lr_grace_05, lr_grace_025, hr_aux)

                # Discriminator training
                self.optimizer_D.zero_grad()
                with autocast(enabled=use_amp):
                    hr_generated = self.HAT(combined_input)

                    real_output = self.discriminator(lr_grace_025)
                    # detach(): the discriminator step must not backpropagate into the generator.
                    fake_output = self.discriminator(hr_generated.detach())
                    real_labels = torch.ones_like(real_output, device=self.device)
                    fake_labels = torch.zeros_like(fake_output, device=self.device)

                    loss_D_real = self.adversarial_loss(real_output, real_labels)
                    loss_D_fake = self.adversarial_loss(fake_output, fake_labels)
                    loss_D = (loss_D_real + loss_D_fake) / 2

                self.scaler.scale(loss_D).backward()
                # Unscale before clipping so the threshold applies to the true gradients.
                self.scaler.unscale_(self.optimizer_D)
                if self.cfg['grad_clip'] > 0:
                    torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.cfg['grad_clip'])
                self.scaler.step(self.optimizer_D)

                # Generator training
                self.optimizer_U.zero_grad()
                with autocast(enabled=use_amp):
                    fake_output = self.discriminator(hr_generated)
                    loss_G_adv = self.adversarial_loss(fake_output, real_labels) * adv_weight
                    loss_G_pixel = self.pixelwise_loss(hr_generated, lr_grace_025)
                    loss_G_tv = self.tv_loss(hr_generated)
                    loss_G_perceptual = self.perceptual_loss(hr_generated, lr_grace_025)
                    loss_LPIPS = self.lpips(hr_generated, lr_grace_025) if self.cfg['use_lpips'] else torch.tensor(0.0, device=self.device)

                    loss_G = loss_G_pixel + loss_G_perceptual + loss_G_tv + loss_G_adv + self.cfg['lpips_weight'] * loss_LPIPS

                self.scaler.scale(loss_G).backward()
                self.scaler.unscale_(self.optimizer_U)
                if self.cfg['grad_clip'] > 0:
                    torch.nn.utils.clip_grad_norm_(hat_parameters, self.cfg['grad_clip'])
                self.scaler.step(self.optimizer_U)
                # A single update() per iteration, after both optimizers have stepped.
                self.scaler.update()

                epoch_loss_G += loss_G.item()
                epoch_loss_D += loss_D.item()

            avg_epoch_loss_G = epoch_loss_G / len(self.train_loader)
            avg_epoch_loss_D = epoch_loss_D / len(self.train_loader)
            train_losses_G.append(avg_epoch_loss_G)
            train_losses_D.append(avg_epoch_loss_D)

            # Validation phase (reconstruction terms only, no adversarial term)
            val_loss_G = 0.0
            self.HAT.eval()
            if self.senet_module:
                self.senet_module.eval()
            with torch.no_grad():
                for lr_grace_05, lr_grace_025, hr_aux in self.val_loader:
                    combined_input, lr_grace_025 = self._prepare_batch(lr_grace_05, lr_grace_025, hr_aux)
                    hr_generated = self.HAT(combined_input)

                    loss_G_pixel = self.pixelwise_loss(hr_generated, lr_grace_025)
                    loss_G_tv = self.tv_loss(hr_generated)
                    loss_G_perceptual = self.perceptual_loss(hr_generated, lr_grace_025)
                    val_loss = loss_G_pixel + loss_G_tv + loss_G_perceptual
                    if self.cfg['use_lpips']:
                        val_loss += self.cfg['lpips_weight'] * self.lpips(hr_generated, lr_grace_025)
                    val_loss_G += val_loss.item()

            val_loss_G /= len(self.val_loader)
            # Record the history that is returned to the caller.
            val_losses_G.append(val_loss_G)

            # Early Stopping Check
            if val_loss_G < best_loss - delta:
                best_loss = val_loss_G
                trigger_times = 0
                torch.save(self.HAT.state_dict(), checkpoint_path)
                best_saved = True
                if self.senet_module:
                    best_senet_state = {
                        name: tensor.detach().clone()
                        for name, tensor in self.senet_module.state_dict().items()
                    }
            else:
                trigger_times += 1
                print(f'EarlyStopping: {trigger_times}/{patience} epochs with no improvement.')
                if trigger_times >= patience:
                    print('Early stopping triggered.')
                    break

            # Update schedulers
            self.scheduler_D.step()
            self.scheduler_U.step()

            print(
                f"Epoch [{epoch+1}/{self.epochs}] | "
                f"Loss D: {avg_epoch_loss_D:.4f} | "
                f"Loss G: {avg_epoch_loss_G:.4f} | "
                f"Val Loss G: {val_loss_G:.4f} | "
                f"LR_D: {self.optimizer_D.param_groups[0]['lr']:.6f}, "
                f"LR_U: {self.optimizer_U.param_groups[0]['lr']:.6f}"
            )

        # Restore the best weights, whether training stopped early or ran to completion.
        if best_saved:
            self.HAT.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
            if self.senet_module and best_senet_state is not None:
                self.senet_module.load_state_dict(best_senet_state)

        return train_losses_G, train_losses_D, val_losses_G

    def train(self):
        """Optionally adapt LPIPS, then train the main model.

        Returns:
            The ``(train_losses_G, train_losses_D, val_losses_G)`` tuple produced by
            ``train_main_model``.
        """
        # Adapt LPIPS first if single-channel and using LPIPS
        if self.single_channel and self.cfg['use_lpips']:
            print("Adapting LPIPS to domain-specific single-channel lr_grace_025 data...")
            self.adapt_lpips()
        return self.train_main_model()

    def evaluate(self):
        """Run the generator on the test split and print regression metrics.

        Metrics are computed over all flattened values of the predictions and the
        ``lr_grace_025`` targets exactly as the loader yields them (no inverse
        scaling is applied here).

        Returns:
            Tuple ``(preds, trues, r2)`` where ``preds`` and ``trues`` are flattened
            1-D NumPy arrays and ``r2`` is the coefficient of determination.
        """
        self.HAT.eval()
        if self.senet_module:
            self.senet_module.eval()
        with torch.no_grad():
            preds = []
            trues = []
            for lr_grace_05, lr_grace_025, hr_aux in self.test_loader:
                combined_input, lr_grace_025 = self._prepare_batch(lr_grace_05, lr_grace_025, hr_aux)
                hr_generated = self.HAT(combined_input)

                preds.append(hr_generated.cpu().numpy())
                trues.append(lr_grace_025.cpu().numpy())

            preds = np.concatenate(preds, axis=0).reshape(-1)
            trues = np.concatenate(trues, axis=0).reshape(-1)

            cc = np.corrcoef(trues, preds)[0, 1]
            mse = mean_squared_error(trues, preds)
            mae = mean_absolute_error(trues, preds)
            r2 = r2_score(trues, preds)

            print(
                f"Test MSE: {mse:.6f}, Test MAE: {mae:.6f}, Test R²: {r2:.6f}, Test CC: {cc:.6f}"
            )

        return preds, trues, r2
    

In [2]:
# Run configuration: a single epoch, i.e. a quick smoke test rather than a full training run.
epochs = 1
batch_size = 16
# The trainer itself is instantiated further down in this cell.


# Optional smoothing callable applied to the HR auxiliary data; None disables smoothing.
# NOTE(review): `smooth_data_gaussian` is not defined on the ModelTrainer class in this
# notebook - confirm where it lives before re-enabling the line below.
#smoothing_method = ModelTrainer(epochs, batch_size, 1024).smooth_data_gaussian
smoothing_method = None
# Optional attention / squeeze-excitation modules (currently disabled).
#attention_module = AttentionModule(input_channels=40, output_channels=40)
#senet_module = SqueezeExcitation(input_channels=40, reduction_ratio=8)

# Train the baseline model: neither senet_module nor attention_type is passed,
# so no additional module is applied to the generator input.

# model1.train() returns (train_losses_G, train_losses_D, val_losses_G);
# model1.evaluate() returns (preds, trues, r2) for the test split.
model1 = ModelTrainer(epochs=epochs, batch_size=batch_size, smoothing_method=smoothing_method)
train_losses_G1, train_losses_D1, val_losses_G= model1.train()
preds1, trues1, r2_1 = model1.evaluate()
torch.cuda.empty_cache()

(181, 90, 44)
(181, 180, 88, 1)
[509.70157107]
[-32767.]
[-32767.]
[-32767.]
Combined HR Aux Data Shape: (181, 180, 88, 45)
0.0
65.5
Sliced HR Aux Data Shape: (181, 180, 88, 45)
-5.350948318234112
(180, 88, 7)
最大误差: 8.881784197001252e-16
最大误差: 8.881784197001252e-16
(543, 180, 88, 45)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Adapting LPIPS to domain-specific single-channel lr_grace_025 data...
LPIPS Adaptation Epoch [1/10], Loss: 0.4899
LPIPS Adaptation Epoch [2/10], Loss: 0.4852
LPIPS Adaptation Epoch [3/10], Loss: 0.4806
LPIPS Adaptation Epoch [4/10], Loss: 0.4759
LPIPS Adaptation Epoch [5/10], Loss: 0.4712
LPIPS Adaptation Epoch [6/10], Loss: 0.4665
LPIPS Adaptation Epoch [7/10], Loss: 0.4619
LPIPS Adaptation Epoch [8/10], Loss: 0.4572
LPIPS Adaptation Epoch [9/10], Loss: 0.4525
LPIPS Adaptation Epoch [10/10], Loss: 0.4478
LPIPS adaptation complete. LPIPS is now domain-adjusted.
Epoch [1/1] | Loss D: 0.6720 | Loss G: 2.9749 | Val Loss G: 2.7114 | LR_D: 0.000390, LR_U: 0.000195
(16, 1, 180, 88)


TypeError: Invalid shape (16, 1, 180, 88) for image data

In [None]:
# Persist the weights of the trained generator (state_dict only, not the full module).
# NOTE(review): the file name says "model11" while the variable is `model1` - confirm the intended name.
torch.save(model1.HAT.state_dict(), 'model11_HAT.pth')