In [None]:
import torch

# Locations of the two checkpoints to merge (relative to the working directory).
path_A, path_B = 'model_checkpoint_A.pt', 'model_checkpoint_B.pt'

# Deserialize both checkpoints onto the CPU so the merge does not need a GPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source (consider weights_only=True if the
# installed torch version supports it).
# NOTE(review): assumes each file stores a raw state dict (name -> tensor) rather
# than a wrapper such as {"state_dict": ...} — confirm against the files.
cpu_load_options = dict(map_location='cpu')
state_dict_A = torch.load(path_A, **cpu_load_options)
state_dict_B = torch.load(path_B, **cpu_load_options)

# Show the entry names of each checkpoint for a quick visual comparison.
print(f"Keys in A: {state_dict_A.keys()}")
print(f"Keys in B: {state_dict_B.keys()}")

In [None]:
def interpolate_state_dicts(state_dict_a, state_dict_b, alpha=0.5):
    """Linearly interpolate two state dictionaries, entry by entry.

    Args:
        state_dict_a: Mapping from entry name to tensor; weighted by ``1 - alpha``.
        state_dict_b: Mapping with exactly the same keys and tensor shapes as
            ``state_dict_a``; weighted by ``alpha``.
        alpha: Interpolation factor. ``0.0`` reproduces A, ``1.0`` reproduces B,
            ``0.5`` is the plain average.

    Returns:
        A new ``dict`` with one tensor per key. Floating-point (and complex)
        tensors hold ``(1 - alpha) * a + alpha * b``; every other tensor
        (integer/bool buffers) is a copy of the entry from ``state_dict_a``.

    Raises:
        ValueError: If the two mappings do not have identical key sets, or if
            the tensors stored under a shared key differ in shape.
    """
    # Validate the structure up front: iterating over A alone would raise a bare
    # KeyError for keys missing in B and silently drop keys that only B has.
    only_in_a = sorted(state_dict_a.keys() - state_dict_b.keys(), key=str)
    only_in_b = sorted(state_dict_b.keys() - state_dict_a.keys(), key=str)
    if only_in_a or only_in_b:
        raise ValueError(
            "Checkpoints do not share the same keys: "
            f"only in A: {only_in_a}; only in B: {only_in_b}"
        )

    merged = {}
    for key, tensor_a in state_dict_a.items():
        tensor_b = state_dict_b[key]

        # Element-wise interpolation requires identical shapes.
        if tensor_a.shape != tensor_b.shape:
            raise ValueError(f"Shape mismatch for key '{key}': {tensor_a.shape} vs {tensor_b.shape}")

        if tensor_a.is_floating_point() or tensor_a.is_complex():
            merged[key] = (1 - alpha) * tensor_a + alpha * tensor_b
        else:
            # Multiplying an integer/bool tensor by a float alpha would change
            # its dtype (and can lose precision for large integers), so such
            # entries are carried over from A unchanged.
            merged[key] = tensor_a.clone()

    return merged


# Choose an interpolation factor (e.g., 0.5 for simple average)
alpha = 0.5
new_state_dict = interpolate_state_dicts(state_dict_A, state_dict_B, alpha)

print("New state dictionary created with interpolated weights.")

In [None]:
# NOTE(review): `model` is not defined in the visible cells. It must already be
# an instance of the architecture that produced the checkpoints before this
# cell runs — confirm where it is created.

# Copy the interpolated weights into the model. Without this call the merged
# state dict is never used and the message below would be untrue. With the
# default strict=True, load_state_dict raises if keys or shapes do not match.
model.load_state_dict(new_state_dict)

# Move the model to the GPU when one is available, otherwise keep it on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print("Interpolated weights successfully loaded into the new model instance.")