In [194]:
import torch
import torch.nn as nn
from lightning.pytorch.utilities.model_summary import LayerSummary, ModelSummary
from torch.nn import HuberLoss
import sys
sys.path.append("/home/griffingoodwin/2025-HL-Flaring-MEGS-AI/") # go to parent dir
print(sys.path)
from flaring.forecasting.models.base_model import BaseModel
from torchvision.models import resnet18

class LinearIrradianceModel(BaseModel):
    def __init__(self, d_input, d_output, loss_func=HuberLoss(), lr=1e-4):
        self.n_channels = d_input
        self.outSize = d_output
        model = nn.Linear(2 * self.n_channels, self.outSize)
        super().__init__(model=model, loss_func=loss_func, lr=lr)

    def forward(self, x, **kwargs):


        # Debug: Print input shape
        #print(f"Input shape to LinearIrradianceModel.forward: {x.shape}")

        # Expect x shape: (batch_size, H, W, C)
        if len(x.shape) != 4:
            raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
        if x.shape[-1] != self.n_channels:
            raise ValueError(f"AIA image has {x.shape[-1]} channels, expected {self.n_channels}")

        # Calculate mean and std across spatial dimensions (H,W)
        # First permute to (batch_size, C, H, W)
        x = x.permute(0, 3, 1, 2)

        # Now calculate mean/std across dimensions 2 and 3 (H,W)
        mean_irradiance = torch.mean(x, dim=(2, 3))  # Shape: (batch_size, n_channels)
        std_irradiance = torch.std(x, dim=(2, 3))    # Shape: (batch_size, n_channels)

        # Debug: Print shapes after mean and std
        #print(f"mean_irradiance shape: {mean_irradiance.shape}, std_irradiance shape: {std_irradiance.shape}")

        input_features = torch.cat((mean_irradiance, std_irradiance), dim=1)  # Shape: (batch_size, 2 * n_channels)
        #print(f"Input features shape to linear layer: {input_features.shape}")

        if input_features.shape[1] != 2 * self.n_channels:
            raise ValueError(f"Expected {2 * self.n_channels} features, got {input_features.shape[1]}")

        return self.model(input_features)

class HybridIrradianceModel(BaseModel):
    def __init__(self, d_input, d_output, cnn_model='resnet', ln_model=True, ln_params=None, lr=1e-4, cnn_dp=0.75, loss_func=HuberLoss()):
        super().__init__(model=None, loss_func=loss_func, lr=lr)
        self.n_channels = d_input
        self.outSize = d_output
        self.ln_params = ln_params
        self.ln_model = None
        if ln_model:
            self.ln_model = LinearIrradianceModel(d_input, d_output, loss_func=loss_func, lr=lr)
        if self.ln_params is not None and self.ln_model is not None:
            self.ln_model.model.weight = nn.Parameter(self.ln_params['weight'])
            self.ln_model.model.bias = nn.Parameter(self.ln_params['bias'])
        self.cnn_model = None
        self.cnn_lambda = 1.
        if cnn_model == 'resnet':
            #deeper model
            self.cnn_model = nn.Sequential(

                nn.Conv2d(d_input, 64, kernel_size=7, stride=1, padding=1),
                nn.BatchNorm2d(64),  # Add batch normalization
                nn.ReLU(),
                nn.Conv2d(64, 64, kernel_size=7, stride=1, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),

                nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                nn.Conv2d(128, 128, kernel_size=5, stride=1, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),

                nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),

                nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(),
                nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),

                nn.AdaptiveAvgPool2d((2, 2)),
                nn.Linear(2048, 2048),
                nn.ReLU(),
                nn.Dropout(cnn_dp),
                nn.Linear(2048, 1024),
                nn.ReLU(),
                nn.Dropout(cnn_dp),
                nn.Linear(1024, 512),
                nn.ReLU(),
                nn.Dropout(cnn_dp),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Dropout(cnn_dp),
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Dropout(cnn_dp),
                nn.Linear(128, d_output),
            )

        elif cnn_model.startswith('efficientnet'):
            raise NotImplementedError("EfficientNet requires timm; replace with custom CNN or install timm")
        if self.ln_model is None and self.cnn_model is None:
            raise ValueError('Please pass at least one model.')

    def forward(self, x, sxr=None, **kwargs):
        # If x is a tuple (aia_img, sxr_val), extract the AIA image tensor
        if isinstance(x, (list, tuple)):
            x = x[0]

        # Expect x shape: (batch_size, H, W, C)
        if len(x.shape) != 4:
            raise ValueError(f"Expected 4D input tensor (batch_size, H, W, C), got shape {x.shape}")
        if x.shape[-1] != self.n_channels:
            raise ValueError(f"AIA image has {x.shape[-1]} channels, expected {self.n_channels}")

        # Convert to (batch_size, C, H, W) for CNN
        x_cnn = x.permute(0, 3, 1, 2)

        if self.ln_model is not None and self.cnn_model is not None:
            # For linear model, keep original (B,H,W,C) format
            return self.ln_model(x) + self.cnn_lambda * self.cnn_model(x_cnn)
        elif self.ln_model is not None:
            return self.ln_model(x)
        elif self.cnn_model is not None:
            return self.cnn_model(x_cnn)

    # def configure_optimizers(self):
    #     return torch.optim.Adam(self.parameters(), lr=self.lr)

    def set_train_mode(self, mode):
        if mode == 'linear':
            self.cnn_lambda = 0
            if self.cnn_model: self.cnn_model.eval()
            if self.ln_model: self.ln_model.train()
        elif mode == 'cnn':
            self.cnn_lambda = 0.01
            if self.cnn_model: self.cnn_model.train()
            if self.ln_model: self.ln_model.eval()
        elif mode == 'both':
            self.cnn_lambda = 0.01
            if self.cnn_model: self.cnn_model.train()
            if self.ln_model: self.ln_model.train()
        else:
            raise NotImplementedError(f'Mode not supported: {mode}')

['/Applications/PyCharm.app/Contents/plugins/python-ce/helpers/jupyter_debug', '/Applications/PyCharm.app/Contents/plugins/python-ce/helpers/pydev', '/tmp/w4scFYbshp', '/home/griffingoodwin/.pycharm_helpers/pydev', '/home/griffingoodwin/.pycharm_helpers/jupyter_debug', '/opt/conda/envs/Flare_detection/lib/python310.zip', '/opt/conda/envs/Flare_detection/lib/python3.10', '/opt/conda/envs/Flare_detection/lib/python3.10/lib-dynload', '', '/home/griffingoodwin/.local/lib/python3.10/site-packages', '/opt/conda/envs/Flare_detection/lib/python3.10/site-packages', '/tmp/tmpotxgyzav', '/opt/conda/envs/Flare_detection/lib/python3.10/site-packages/setuptools/_vendor', '/home/griffingoodwin/2025-HL-Flaring-MEGS-AI/', '/home/griffingoodwin/2025-HL-Flaring-MEGS-AI/', '/home/griffingoodwin/2025-HL-Flaring-MEGS-AI/', '/home/griffingoodwin/2025-HL-Flaring-MEGS-AI/', '/home/griffingoodwin/2025-HL-Flaring-MEGS-AI/', '/home/griffingoodwin/2025-HL-Flaring-MEGS-AI/', '/home/griffingoodwin/2025-HL-Flaring-ME

In [195]:
from pytorch_lightning.utilities.model_summary import ModelSummary


In [196]:
Model = HybridIrradianceModel(6,1,cnn_dp=.2)



In [197]:
Model

HybridIrradianceModel(
  (loss_func): HuberLoss()
  (ln_model): LinearIrradianceModel(
    (model): Linear(in_features=12, out_features=1, bias=True)
    (loss_func): HuberLoss()
  )
  (cnn_model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(7, 7), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU()
    (10): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tr

In [198]:
ModelSummary(Model)

  | Name      | Type                  | Params | Mode 
------------------------------------------------------------
0 | loss_func | HuberLoss             | 0      | train
1 | ln_model  | LinearIrradianceModel | 13     | train
2 | cnn_model | Sequential            | 12.2 M | train
------------------------------------------------------------
12.2 M    Trainable params
0         Non-trainable params
12.2 M    Total params
48.988    Total estimated model params size (MB)
49        Modules in train mode
0         Modules in eval mode

In [108]:
import flaring.forecasting.models.vision_transformer_custom as vit_custom

In [8]:
kwarg = {
    'embed_dim': 512,
    'num_channels': 6,
    'num_classes': 1,
    'patch_size': 16,
    'num_patches': 1024,
    'hidden_dim': 1024,
    'num_heads': 8,
    'num_layers': 6,
    'dropout': 0.2,
    'lr': 0.0001}

In [11]:
v = vit_custom.ViT(kwarg, sxr_norm=None)

In [12]:
ModelSummary(v)

  | Name  | Type              | Params | Mode 
----------------------------------------------------
0 | model | VisionTransformer | 13.9 M | train
----------------------------------------------------
13.9 M    Trainable params
0         Non-trainable params
13.9 M    Total params
55.722    Total estimated model params size (MB)
73        Modules in train mode
0         Modules in eval mode