In [None]:
import torch
import sys
sys.path.append('/home/alexabades/recsys')
from src.models.CNCF.cncf import ContextualNeuralCollavorativeFiltering
from src.models.NCF.nfc import NeuralCollaborativeFiltering

def print_model_memory_usage(model, input_example):
    """Print parameter counts, parameter storage size and forward-pass peak GPU memory.

    Only parameters with ``requires_grad=True`` are listed and counted.

    Args:
        model: A ``torch.nn.Module`` to inspect.
        input_example: Tuple of positional inputs that is unpacked into
            ``model(*input_example)`` for a single forward pass.

    The peak figure is the maximum amount of CUDA memory allocated while the
    forward pass runs. It includes whatever is already resident on the device
    at that point (e.g. the model weights), but not earlier, unrelated peaks.
    When CUDA is not available nothing can be measured and 0.00 MB is printed.
    """
    total_params = 0
    model_size_bytes = 0
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        num_params = param.numel()
        total_params += num_params
        # Use the real element width instead of assuming 32-bit floats, so
        # half/bfloat16/double models are sized correctly.
        model_size_bytes += num_params * param.element_size()
        print(f"{name} has {num_params} parameters")

    model_size_mb = model_size_bytes / (1024 ** 2)
    print(f'Total parameters count: {total_params}')
    print(f"Total model size: {model_size_mb:.2f} MB")

    # Measure on the device that holds the model; None means "current device".
    cuda_device = next((p.device for p in model.parameters() if p.is_cuda), None)
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        # Without this reset the counter reports the highest allocation since
        # the process started (model construction, earlier runs, ...), not the
        # peak reached during the forward pass below.
        torch.cuda.reset_peak_memory_stats(cuda_device)

    # Forward pass to determine peak memory usage
    with torch.no_grad():
        model(*input_example)  # Assuming input_example is a tuple of inputs (user_input, item_input)

    if cuda_available:
        peak_memory = torch.cuda.max_memory_allocated(cuda_device) / (1024 ** 2)  # Convert bytes to MB
    else:
        peak_memory = 0.0
    print(f"Peak memory usage during forward pass: {peak_memory:.2f} MB")

# Example usage: profile the contextual NCF model on one random sample.
# To profile the plain NCF baseline instead, build the model with:
#   NeuralCollaborativeFiltering(num_users=215366, num_items=95895, mf_dim=8, layers=[32, 16, 8]).to('cuda')
model = ContextualNeuralCollavorativeFiltering(
    num_users=215366,
    num_items=95895,
    num_context=22,
    mf_dim=8,
    layers=[86, 32, 16, 8],
).to('cuda')

# A single random user id and item id, created directly on the GPU.
user_input = torch.randint(low=0, high=215366, size=(1,), device='cuda')
item_input = torch.randint(low=0, high=95895, size=(1,), device='cuda')
# NOTE(review): the upper bound 95895 equals num_items rather than anything
# derived from num_context=22 -- confirm this is the intended value range.
context_input = torch.randint(low=0, high=95895, size=(1, 22), device='cuda')

print_model_memory_usage(
    model,
    (user_input, item_input, context_input),
)