In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from Glocal_IB import Glocal_IB

In [2]:
# 1. Define hyperparameters
# A single seed controls weight initialisation, the simulated data and the
# random mask below, so "Restart Kernel -> Run All" reproduces every output.
SEED = 42
torch.manual_seed(SEED)

BATCH_SIZE = 4
SEQ_LEN = 24
FEATURES = 10
EMBEDDING_DIM = 64

In [3]:
class DummyImputationModel(nn.Module):
    """A minimal stand-in imputation model used to exercise the Glocal_IB wrapper.

    Every time step is handled independently. A linear encoder maps the
    ``features`` values of a step to an ``embedding_dim`` vector, and a linear
    decoder maps that vector back to ``features`` values. There is no temporal
    mixing, so ``seq_len`` is stored as an attribute but not used in ``forward``.

    Args:
        seq_len: Sequence length; only stored on the instance.
        features: Number of input/output features per time step.
        embedding_dim: Size of the per-step embedding.
    """

    def __init__(self, seq_len, features, embedding_dim):
        super().__init__()
        self.seq_len = seq_len
        self.features = features
        self.embedding_dim = embedding_dim
        self.encoder = nn.Linear(features, embedding_dim)
        self.decoder = nn.Linear(embedding_dim, features)
        print(f"Base model initialized, input features: {features}, embedding dimension: {embedding_dim}")

    def forward(self, x, **kwargs):
        """Reconstruct ``x`` through the linear encoder/decoder bottleneck.

        Args:
            x: Tensor whose last dimension is ``features``. In this notebook it is
                ``(batch, seq_len, features)``.
            **kwargs: Accepted and ignored so a wrapper can pass extra keyword
                arguments (presumably what Glocal_IB relies on; TODO confirm).

        Returns:
            Tuple ``(output, embedding)``, the imputation result followed by the
            intermediate embedding. ``output`` has the same shape as ``x``.
            ``embedding`` has the same leading dimensions with last dimension
            ``embedding_dim``.
        """
        embedding = self.encoder(x)  # (..., embedding_dim)
        # nn.Linear acts only on the last dimension: the decoder maps each step's
        # embedding back to `features` values and never changes the sequence length.
        output = self.decoder(embedding)  # (..., features)
        return output, embedding

    def get_embedding_dim(self):
        """Return the embedding size (used to demonstrate attribute forwarding)."""
        return self.embedding_dim


In [4]:
# 2. Instantiate base model and Glocal_IB wrapper
base_model = DummyImputationModel(
    seq_len=SEQ_LEN,
    features=FEATURES,
    embedding_dim=EMBEDDING_DIM,
)

# Wrapper configuration:
#   - alignment loss: cosine ("cos_align"), weight 0.5
#   - align_model_type="self" with no external foundation embedding
#     (exact semantics live in Glocal_IB; TODO confirm there)
glocal_model = Glocal_IB(
    base_model=base_model,
    embedding_dim=EMBEDDING_DIM,
    foundation_embedding=None,
    align_model_type="self",
    align_loss_type="cos_align",
    align_weight=0.5,
)

Base model initialized, input features: 10, embedding dimension: 64


In [5]:
# 3. Prepare simulated data
x_complete = torch.randn(BATCH_SIZE, SEQ_LEN, FEATURES)  # fully observed ground truth

# Boolean mask: True marks an entry treated as missing (drawn with probability ~0.2)
mask = torch.rand(x_complete.shape) > 0.8

# Model input: a new tensor equal to the ground truth with the missing entries set to 0
x_masked = x_complete.masked_fill(mask, 0)

In [6]:
# 4. Simulate training process
print("🚀 Mode: Training")
glocal_model.train()  # Set model to training mode

# During training, pass both the masked input and the complete ground truth.
# In this mode the wrapper returns a dict rather than a bare tensor.
training_results = glocal_model(x_masked, x_complete)

print(f"Return result type: {type(training_results)}")
print(f"Dictionary keys: {training_results.keys()}")

# Get imputation results and alignment loss from dictionary
imputation = training_results['output']
alignment_loss = training_results['alignment_loss']
# Sanity check: the imputation must line up element-wise with the ground truth
assert imputation.shape == x_complete.shape, "imputation shape must match the input shape"

print(f"Imputation result shape: {imputation.shape}")
print(f"Alignment loss value: {alignment_loss.item():.4f}")

# Standard reconstruction loss, computed on observed entries only
# (`mask` is True where data was hidden, so `~mask` selects the observed values)
reconstruction_loss = F.mse_loss(imputation[~mask], x_complete[~mask])

# Total loss is the plain (unweighted) sum of both terms. align_weight=0.5 was
# passed to Glocal_IB, so alignment_loss is presumably already weighted there.
# TODO confirm in Glocal_IB before adding another factor here.
total_loss = reconstruction_loss + alignment_loss
print(f"Reconstruction loss: {reconstruction_loss.item():.4f}")
print(f"Total loss (for backpropagation): {total_loss.item():.4f}")

🚀 Mode: Training
Return result type: <class 'dict'>
Dictionary keys: dict_keys(['output', 'alignment_loss'])
Imputation result shape: torch.Size([4, 24, 10])
Alignment loss value: 0.5036
Reconstruction loss: 1.0334
Total loss (for backpropagation): 1.5371


In [7]:
# 5. Simulate evaluation/inference process
print("🔬 Mode: Evaluation")
glocal_model.eval()  # Switch to evaluation mode

# Inference only sees the masked input (no ground truth). Gradient tracking
# is turned off because nothing is back-propagated in this step.
with torch.no_grad():
    eval_output = glocal_model(x_masked)

# Unlike training mode, the wrapper hands back the imputation tensor directly
print("Return result type:", type(eval_output))
print("Imputation result shape:", eval_output.shape)

🔬 Mode: Evaluation
Return result type: <class 'torch.Tensor'>
Imputation result shape: torch.Size([4, 24, 10])


In [8]:
# 6. Demonstrate __getattr__ functionality
print("🎁 Demonstrate __getattr__ functionality")
# `get_embedding_dim` is a method of DummyImputationModel. Looking it up on the
# wrapper works because the wrapper forwards unknown attributes to the wrapped
# base model (presumably via a __getattr__ override in Glocal_IB).
emb_dim = glocal_model.get_embedding_dim()
print("Call base model method directly through Glocal_IB wrapper: glocal_model.get_embedding_dim() ->", emb_dim)
print("Consistent with direct call to base model: base_model.get_embedding_dim() ->", base_model.get_embedding_dim())

🎁 Demonstrate __getattr__ functionality
Call base model method directly through Glocal_IB wrapper: glocal_model.get_embedding_dim() -> 64
Consistent with direct call to base model: base_model.get_embedding_dim() -> 64
