In [1]:
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

def get_modified_resnet(in_channels=4, model_name='resnet18'):
    """Loads a pre-trained ResNet and modifies the first conv layer for 4 input channels."""
    
    # 1. Load the pre-trained model weights
    # We use ResNet18_Weights.IMAGENET1K_V1 which contains ImageNet weights
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    
    # Get the existing Conv1 layer (which takes 3 input channels)
    original_conv1 = model.conv1
    
    # 2. Create a new Conv1 layer for 4 input channels
    # Keep all other parameters (out_channels=64, kernel_size, stride, padding, bias) the same
    new_conv1 = nn.Conv2d(
        in_channels, 
        original_conv1.out_channels, 
        kernel_size=original_conv1.kernel_size, 
        stride=original_conv1.stride, 
        padding=original_conv1.padding, 
        bias=original_conv1.bias
    )
    
    # 3. Transfer the pre-trained weights to the new layer
    # Since the new layer has 4 channels and the old one has 3, we copy the weights
    # for the first 3 channels and initialize the 4th channel's weights randomly (or to zeros).
    # We copy the weights using the first 3 channels of the original weights.
    new_conv1.weight.data[:, :3, :, :] = original_conv1.weight.data
    
    # Initialize the new 4th channel's weights (e.g., by averaging the first 3 or setting to zero)
    # Using the mean of the first three channels is a common heuristic:
    new_conv1.weight.data[:, 3, :, :] = original_conv1.weight.data[:, :3, :, :].mean(dim=1)

    # 4. Replace the original layer with the new one
    model.conv1 = new_conv1
    
    # 5. Modify the final layer (Fully Connected) to output the desired embedding size
    # We cut off the classification head and just use the feature extractor backbone.
    # The output size of the backbone is typically 512 for ResNet18.
    model = nn.Sequential(*list(model.children())[:-1]) # Remove the classification layer
    
    return model

In [2]:
# Assuming you have a working get_dataloader function and device setup

@torch.no_grad()
def extract_foundation_features(model, dataloader, device):
    """
    Extracts features using the modified pre-trained ResNet.
    """
    model.eval()
    all_features = []
    
    pbar = tqdm(dataloader, desc="Extracting ResNet Embeddings", unit="batch")
    
    for x in pbar:
        if isinstance(x, (list, tuple)):
            x = x[0]
            
        x = x.to(torch.float32).to(device)
        
        # Forward pass through the ResNet backbone
        features = model(x)
        
        # Features will be (Batch, 512, 1, 1). We squeeze out the 1x1 dimensions.
        features = features.squeeze()
        
        all_features.append(features.cpu())
        
    return torch.cat(all_features, dim=0).numpy()

if __name__ == '__main__':
    # --- Setup ---
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    
    # Load your existing data_loader (64x64 patches, 4 channels)
    # data_loader = get_dataloader(...) 
    
    # --- 1. Load and Modify Foundation Model ---
    foundation_model = get_modified_resnet(in_channels=IMAGE_CHANNELS).to(torch.float32).to(device)
    print("ResNet-18 model modified and loaded successfully.")

    # --- 2. Extract Embeddings ---
    # features_Z_VAE = extract_features(vae_model, data_loader, device) # Your VAE embeddings
    features_Z_ResNet = extract_foundation_features(foundation_model, data_loader, device)
    
    # --- 3. Comparison ---
    print(f"\nVAE Latent Dim (Example): 128")
    print(f"ResNet Embedding Dim: {features_Z_ResNet.shape[1]}")
    
    # Now you can compare the two sets of embeddings:
    # 1. Compare Clustering: Apply UMAP/t-SNE to both VAE and ResNet embeddings.
    #    (You may need to reduce the ResNet features from 512D to 128D first, 
    #     or let UMAP handle the reduction.)
    # 2. Compare Performance: Use the embeddings as input for a simple linear classifier 
    #    to predict a known property (e.g., redshift, if you have labels).

NameError: name 'IMAGE_CHANNELS' is not defined