In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import copy

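# NOTE: the existing imports and the ConfigurableMLP class are kept as-is and not
# repeated here; main() below also expects get_dataset, train and test to be
# defined before this code runs.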

def compute_gram_matrix(activations):
    """Return the Gram matrix of a 2-D tensor, scaled by its column count.

    The tensor is multiplied by its own transpose, so the result is square
    with one row and one column per row of ``activations``. The product is
    then divided by ``activations.size(1)``.

    Callers in this file pass the output of an ``nn.Linear`` applied to a
    flattened batch, i.e. rows are samples and columns are features.
    """
    num_columns = activations.shape[1]
    row_products = torch.matmul(activations, activations.t())
    return row_products / num_columns

def get_layer_activations(model, x):
    """Run ``x`` through ``model.layers`` and collect the Linear outputs.

    Every module in ``model.layers`` is applied in order. Only the outputs
    produced directly by ``nn.Linear`` modules are recorded, so the returned
    list has one entry per Linear layer (in forward order), not one entry
    per module.
    """
    linear_outputs = []
    hidden = x
    for module in model.layers:
        hidden = module(hidden)
        if not isinstance(module, nn.Linear):
            continue
        linear_outputs.append(hidden)
    return linear_outputs

def compute_gram_loss(M1, M2, data_loader, device, layer_idx):
    """Average Gram-matrix MSE between M1 and M2 at one recorded activation.

    Both models are put in eval mode (and left there). For every batch the
    inputs are moved to ``device`` and flattened to ``(batch, -1)``; the
    activation compared is the last element of
    ``get_layer_activations(...)[:layer_idx + 1]``, so ``layer_idx`` indexes
    the list of Linear-layer outputs and an index past the end falls back to
    the final recorded activation.

    Returns the sum of per-batch losses divided by ``len(data_loader)``.
    """
    for net in (M1, M2):
        net.eval()

    keep = slice(None, layer_idx + 1)
    running_loss = 0
    with torch.no_grad():
        for inputs, _ in data_loader:
            flat = inputs.to(device)
            flat = flat.view(flat.size(0), -1)

            reference_act = get_layer_activations(M1, flat)[keep][-1]
            candidate_act = get_layer_activations(M2, flat)[keep][-1]

            gram_reference = compute_gram_matrix(reference_act)
            gram_candidate = compute_gram_matrix(candidate_act)

            batch_loss = (gram_reference - gram_candidate).pow(2).mean()
            running_loss += batch_loss.item()

    return running_loss / len(data_loader)

def freeze_all_layers(model):
    """Turn off gradient tracking for every parameter of ``model``."""
    for weight in model.parameters():
        weight.requires_grad_(False)

def unfreeze_layer(model, layer_idx):
    """Turn gradient tracking back on for the module ``model.layers[layer_idx]``.

    ``layer_idx`` is a position in ``model.layers``; parameters of all other
    modules are left untouched.
    """
    target = model.layers[layer_idx]
    for weight in target.parameters():
        weight.requires_grad_(True)

def train_gram_matrix(M1, M2, train_loader, test_loader, optimizer, device, num_epochs=1, tolerance=1e-4, freeze_others=False):
    """Train M2 one Linear layer at a time to match M1's activations.

    For each ``nn.Linear`` module found in ``M1.layers`` (in order), M2 is
    optimised for up to ``num_epochs`` epochs so that the Gram matrix of its
    corresponding Linear output matches M1's. For the last Linear layer the
    raw outputs are matched directly instead of their Gram matrices.

    Args:
        M1: Reference model exposing ``.layers``; put in eval mode and only
            evaluated under ``torch.no_grad()``.
        M2: Model being trained. Module positions taken from ``M1.layers``
            are reused on ``M2.layers``, so both are expected to have the
            same layer layout.
        train_loader: Iterable of ``(data, _)`` batches used for optimisation.
        test_loader: Iterable passed to ``compute_gram_loss`` after each epoch.
        optimizer: Optimizer stepping M2. With ``freeze_others=True`` it is
            replaced per layer by a fresh Adam built with the learning rate
            of this optimizer's first parameter group.
        device: Device each batch is moved to.
        num_epochs: Maximum number of epochs per layer.
        tolerance: A layer stops early once its average train loss for an
            epoch drops below this value.
        freeze_others: If True, only the current layer's parameters have
            ``requires_grad`` enabled while that layer is trained.

    Returns:
        M2 (the same object, trained in place).

    Note:
        With ``freeze_others=True`` the ``requires_grad`` flags are not
        restored at the end: M2 is returned with only the last trained
        layer's parameters unfrozen.
    """
    M1.eval()  # Freeze M1
    M2.train()
    # Positions, within M1.layers, of the Linear modules.
    linear_layers = [li for li, l in enumerate(M1.layers) if isinstance(l, nn.Linear)]
    last_act_idx = len(linear_layers) - 1

    # get_layer_activations() returns one entry per Linear layer, so the
    # activation list must be indexed by the Linear layer's ordinal
    # (act_idx), while freezing/unfreezing needs the module position
    # (layer_idx). Using the module position for both picks the wrong
    # activation whenever non-Linear modules sit between the Linear ones.
    for act_idx, layer_idx in enumerate(linear_layers):
        print(f"Training layer {layer_idx + 1}")

        if freeze_others:
            # Freeze all layers
            freeze_all_layers(M2)
            # Unfreeze the current layer
            unfreeze_layer(M2, layer_idx)
            # Create a new optimizer for the unfrozen layer
            optimizer = optim.Adam(filter(lambda p: p.requires_grad, M2.parameters()), lr=optimizer.param_groups[0]['lr'])

        for epoch in range(num_epochs):
            # compute_gram_loss() below leaves M2 in eval mode; switch back
            # so every epoch (and every layer) trains in train mode.
            M2.train()
            total_loss = 0
            for data, _ in train_loader:
                data = data.to(device)
                data = data.view(data.size(0), -1)

                optimizer.zero_grad()

                # Activation of the current Linear layer in both models
                with torch.no_grad():
                    M1_activation = get_layer_activations(M1, data)[act_idx]
                M2_activation = get_layer_activations(M2, data)[act_idx]

                if act_idx == last_act_idx:
                    # Output layer: match the raw outputs themselves
                    G1 = M1_activation
                    G2 = M2_activation
                else:
                    G1 = compute_gram_matrix(M1_activation)
                    G2 = compute_gram_matrix(M2_activation)

                # Compute loss
                loss = torch.mean((G1 - G2) ** 2)
                total_loss += loss.item()

                # Backpropagate and optimize
                loss.backward()
                optimizer.step()

            avg_train_loss = total_loss / len(train_loader)

            # Compute test Gram loss (indexed by Linear-layer ordinal)
            test_gram_loss = compute_gram_loss(M1, M2, test_loader, device, act_idx)

            print(f"Epoch {epoch+1}, Layer {layer_idx+1}, Train Loss: {avg_train_loss:.6f}, Test Gram Loss: {test_gram_loss:.6f}")

            if avg_train_loss < tolerance:
                print(f"Converged at epoch {epoch+1} for layer {layer_idx+1}")
                break

    return M2

def compute_total_gram_loss(M1, M2, data_loader, device):
    """Average, over batches, the Gram-matrix MSE summed across Linear layers.

    Both models are put in eval mode (and left there). Each batch is moved
    to ``device`` and flattened to ``(batch, -1)``; for every activation
    recorded for M1 the Gram matrices of M1 and M2 at that position are
    compared with a mean squared error, and the per-layer errors are summed.

    Returns the sum of per-batch losses divided by the number of batches.
    """
    for net in (M1, M2):
        net.eval()

    loss_sum = 0
    batches_seen = 0
    with torch.no_grad():
        for inputs, _ in data_loader:
            flat = inputs.to(device)
            flat = flat.view(flat.size(0), -1)

            reference_acts = get_layer_activations(M1, flat)
            candidate_acts = get_layer_activations(M2, flat)

            per_batch = 0
            for depth, reference_act in enumerate(reference_acts):
                gram_reference = compute_gram_matrix(reference_act)
                gram_candidate = compute_gram_matrix(candidate_acts[depth])
                per_batch = per_batch + (gram_reference - gram_candidate).pow(2).mean()

            loss_sum += per_batch.item()
            batches_seen += 1

    return loss_sum / batches_seen

def train_gram_matrix_holistic(M1, M2, train_loader, test_loader, optimizer, device, num_epochs=1, tolerance=1e-4):
    """Train M2 to match M1's Gram matrices at all Linear layers at once.

    For every batch the loss is the sum, over all activations recorded by
    ``get_layer_activations`` for M1, of the mean squared difference between
    the Gram matrices of M1 and M2 at that position. The output layer is
    treated like every other layer (its Gram matrix is matched, not the raw
    outputs).

    Args:
        M1: Reference model; put in eval mode and evaluated under
            ``torch.no_grad()``.
        M2: Model being trained in place.
        train_loader: Iterable of ``(data, _)`` batches used for optimisation.
        test_loader: Iterable passed to ``compute_total_gram_loss`` after
            each epoch.
        optimizer: Optimizer stepping M2's parameters.
        device: Device each batch is moved to.
        num_epochs: Maximum number of epochs.
        tolerance: Training stops early once the average train loss of an
            epoch drops below this value.

    Returns:
        M2 (the same object).
    """
    M1.eval()  # Freeze M1
    M2.train()

    for epoch in range(num_epochs):
        # compute_total_gram_loss() at the end of the previous epoch leaves
        # M2 in eval mode; switch back so every epoch trains in train mode.
        M2.train()
        total_loss = 0
        num_batches = 0
        for data, _ in train_loader:
            data = data.to(device)
            data = data.view(data.size(0), -1)

            optimizer.zero_grad()

            # Get activations for all layers
            with torch.no_grad():
                M1_activations = get_layer_activations(M1, data)
            M2_activations = get_layer_activations(M2, data)

            # Sum the Gram-matrix MSE over all recorded layers
            batch_loss = 0
            for i, M1_activation in enumerate(M1_activations):
                G1 = compute_gram_matrix(M1_activation)
                G2 = compute_gram_matrix(M2_activations[i])
                batch_loss = batch_loss + torch.mean((G1 - G2) ** 2)

            # Backpropagate and optimize
            batch_loss.backward()
            optimizer.step()

            total_loss += batch_loss.item()
            num_batches += 1

        avg_train_loss = total_loss / num_batches

        # Compute test Gram loss
        test_gram_loss = compute_total_gram_loss(M1, M2, test_loader, device)

        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.6f}, Test Gram Loss: {test_gram_loss:.6f}")

        if avg_train_loss < tolerance:
            print(f"Converged at epoch {epoch+1}")
            break

    return M2

def main(config):
    """Run the Gram-matrix transfer experiment described by ``config``.

    Steps:
      1. Train M1 with cross-entropy for ``config['epochs']`` epochs.
      2. Build M2 with every hidden width doubled and fit it to M1 with
         ``train_gram_matrix_holistic``.
      3. Evaluate M2, then train both M2 and a freshly initialised M3 (same
         architecture as M2) with cross-entropy for comparison.

    Keys read from ``config``: 'dataset', 'batch_size', 'hidden_dims',
    'activation', 'norm_layer', 'lr', 'epochs', 'gram_lr', 'gram_epochs',
    'tolerance'. ``config`` itself is not modified.

    ``get_dataset``, ``ConfigurableMLP``, ``train`` and ``test`` are expected
    to be defined elsewhere in the file.

    Returns:
        Tuple ``(M1, M2, train_loader, test_loader)``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    train_dataset, input_dim, output_dim = get_dataset(config['dataset'], train=True)
    test_dataset, _, _ = get_dataset(config['dataset'], train=False)

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)

    # Create M1 (original model)
    M1 = ConfigurableMLP(input_dim, config['hidden_dims'], output_dim, config['activation'], config['norm_layer']).to(device)

    # Train M1 using the original training method
    optimizer_M1 = optim.Adam(M1.parameters(), lr=config['lr'])
    criterion = nn.CrossEntropyLoss()

    for epoch in range(config['epochs']):
        train(M1, train_loader, optimizer_M1, criterion, device)
        test_loss, accuracy = test(M1, test_loader, criterion, device)
        print(f'M1 - Epoch: {epoch+1}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

    # Create M2 (model to be trained based on M1) with doubled hidden widths.
    # Kept in a local so the caller's config is not overwritten; otherwise
    # re-running main(config) would double the widths again on every call.
    wide_hidden_dims = [2 * h for h in config['hidden_dims']]
    M2 = ConfigurableMLP(input_dim, wide_hidden_dims, output_dim, config['activation'], config['norm_layer']).to(device)

    # Train M2 using the new Gram matrix method (all layers at once).
    # train_gram_matrix() is the layer-by-layer alternative.
    optimizer_M2 = optim.Adam(M2.parameters(), lr=config['gram_lr'])
    M2 = train_gram_matrix_holistic(M1, M2, train_loader, test_loader, optimizer_M2, device,
                                    num_epochs=config['gram_epochs'], tolerance=config['tolerance'])

    # Test M2
    test_loss, accuracy = test(M2, test_loader, criterion, device)
    print(f'M2 - Final Test loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

    # M3: same architecture as M2 but freshly initialised (no Gram transfer)
    M3 = ConfigurableMLP(input_dim, wide_hidden_dims, output_dim, config['activation'], config['norm_layer']).to(device)

    print('Training after kernel transfer (M3 = M2 from scratch)')
    for model, name in [(M2, 'M2'), (M3, 'M3')]:
        # Train each model with the original (cross-entropy) training method
        optimizer = optim.Adam(model.parameters(), lr=config['lr'])

        for epoch in range(config['epochs']):
            train(model, train_loader, optimizer, criterion, device)
            test_loss, accuracy = test(model, test_loader, criterion, device)
            print(f'{name} - Epoch: {epoch+1}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

    return M1, M2, train_loader, test_loader

# if __name__ == '__main__':

# Experiment settings handed to main(). Of these keys, main() as written in
# this file does not read 'optimizer', and 'freeze_others' only appears in a
# commented-out call.
config = dict(
    dataset='CIFAR100',
    hidden_dims=[256 for _ in range(5)],
    activation='selu',
    norm_layer='rmsnorm',
    lr=0.001,
    gram_lr=0.001,
    epochs=7,
    gram_epochs=1,
    tolerance=0.1,
    freeze_others=False,
    optimizer='adam',
    batch_size=512,
)

# Runs the full experiment at import/cell-execution time.
M1, M2, train_loader, test_loader = main(config)