In [1]:
from continual_learning import * 


if __name__ == "__main__":
    # Centralized configuration dictionary
    config = {
        "seed": 42,
        "sample_classes": [0, 1, 2,],
        "batch_size": 128,
        "learning_rate": 0.001,
        "num_epochs": 50,
        "metrics_frequency": 100,
        "dead_threshold": 0.95, 
        "corr_threshold": 0.99, 
        "saturation_threshold": 1e-4, 
        "saturation_percentage": 0.99,
        "model_type": "MLP",  # Options: "MLP", "CNN", "VisionTransformer",
        "model": {
            "input_size": 3 * 32 * 32,
            "hidden_sizes": [1024] * 10,
            "activation": "relu",
            "normalization": "layer",
            "norm_after_activation": False,
            "normalization_affine": False,
            "dropout_p": 0,
            # output_size is set dynamically based on number of sample classes
        }
    }
    
    # Initialize wandb run (specify project and optionally entity)
    wandb.init(project="CL-plasticity", config=config)
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Set random seed for reproducibility
    def set_seed(seed):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    set_seed(config["seed"])
    
    num_classes = len(config["sample_classes"])
    
    print("Loading CIFAR10 dataset...")
    trainloader, testloader, fixed_trainloader, fixed_valloader = get_cifar10_data_with_class_selection(
        batch_size=config["batch_size"],
        sample_classes=config["sample_classes"]
    )
    
    print("Creating model...")
    if config["model_type"] == "MLP":
        config["model"]["output_size"] = num_classes
        model = MLP(**config["model"])
    model = VisionTransformer()
    # Options for CNN or VisionTransformer can be added here similarly.
    
    model = model.to(device)
    
    def module_filter(name):
        return name[-4:] == '.mlp' or 'linear' in name
    train_monitor = NetworkMonitor(model, module_filter)
    val_monitor = NetworkMonitor(model, module_filter)
    
    print("\nModel Architecture:")
    for name, module in model.named_modules():
        if len(name) > 0:
            print(f"{name}: {module.__class__.__name__}")
    
    print("\nStarting training with separate monitors...")
    history = train_with_separate_monitors(
        model, trainloader, testloader, fixed_trainloader, fixed_valloader,
        train_monitor, val_monitor,config,
        device=device
    )
    
    # Log final metrics to wandb
    wandb.log({
        "final_train_loss": history["train_losses"][-1],
        "final_test_acc": history["test_accs"][-1]
    })
    
    results_dir = './results'
    os.makedirs(results_dir, exist_ok=True)
    
    print("\nPlotting results...")
    plot_training_curves(history, save_path=results_dir)
    plot_all_metrics(history, save_path=results_dir)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mamirjoudaki[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using device: cuda
Loading CIFAR10 dataset...
Files already downloaded and verified
Files already downloaded and verified
Creating model...

Model Architecture:
layers: ModuleDict
layers.patch_embed: PatchEmbedding
layers.patch_embed.layers: ModuleDict
layers.patch_embed.layers.proj: Conv2d
layers.pos_drop: Dropout
layers.block_0: TransformerBlock
layers.block_0.layers: ModuleDict
layers.block_0.layers.norm1: LayerNorm
layers.block_0.layers.attn: Attention
layers.block_0.layers.attn.layers: ModuleDict
layers.block_0.layers.attn.layers.qkv: Linear
layers.block_0.layers.attn.layers.attn_drop: Dropout
layers.block_0.layers.attn.layers.proj: Linear
layers.block_0.layers.attn.layers.proj_drop: Dropout
layers.block_0.layers.norm2: LayerNorm
layers.block_0.layers.mlp: TransformerMLP
layers.block_0.layers.mlp.layers: ModuleDict
layers.block_0.layers.mlp.layers.fc1: Linear
layers.block_0.layers.mlp.layers.act: GELU
layers.block_0.layers.mlp.layers.drop1: Dropout
layers.block_0.layers.mlp.layers

KeyboardInterrupt: 

In [None]:
for inputs, targets in fixed_trainloader:
    break
    
# Get the indices that would sort the targets
_, sorted_indices = torch.sort(targets)

# Use these indices to sort both inputs and targets
inputs = inputs[sorted_indices]
targets = targets[sorted_indices]

def module_filter(name):
    return True
monitor = NetworkMonitor(model, module_filter)
inputs.shape, targets.shape
monitor.register_hooks()
output = model(inputs.to('cuda'));
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
optimizer.zero_grad()
loss = criterion(output, targets.to('cuda'))
loss.backward()

monitor.activations.keys(), monitor.gradients.keys()

In [None]:
k = 8
X = monitor.activations['layers.linear_8'][0]
# X = X.flatten(0,1).t()
print(X.shape)
# plt.figure(figsize=(12,6))
# plt.imshow(X[:,:]); 
# X = X[::k]
Y = (X/X.norm(dim=1,keepdim=True))
C = Y @ Y.t()
# # C[C<0.95] = 0
plt.figure()
plt.imshow(C)
plt.colorbar()
X.shape, targets[::k], C.shape

In [None]:
layers = list(monitor.activations.keys())
for layer in layers[1::3]:
    A = monitor.activations[layer][0]
    B = monitor.gradients[layer][0]
    M = B.abs()/A.abs().mean(dim=0,keepdim=True)
    M = (M < 1e-4).float()
    M = M.mean(dim=0)
    print(f'{layer}  saturated = {len(M[M>.99]) / len(M):.2f}')
    # A.shape, G.shape

In [None]:
li = 8
l = model.layers[f'linear_{li}']
l2 = model.layers[f'linear_{li+1}']
in_norms = l.weight.norm(dim=1).detach().cpu()
in_grad_norms = l.weight.grad.norm(dim=1).detach().cpu()
in_bw = monitor.gradients[f'layers.linear_{li}'][0].norm(dim=0).detach().cpu()
out_norms = l2.weight.norm(dim=0).detach().cpu()
bias = l.bias.detach().cpu()

print(in_bw.shape)
plt.scatter(in_norms,out_norms);
plt.figure()
plt.scatter(in_norms,in_grad_norms);
plt.figure()
for li in range(1,10,2):
    in_bw = monitor.gradients[f'layers.linear_{li}'][0].norm(dim=0).detach().cpu()
    print((in_bw.log() - in_bw.log().mean() < -1).sum())
    sns.kdeplot(in_bw.log(), label=f'layers.linear_{li}')
plt.legend()

In [None]:
W = l.weight.detach().cpu()
G = l.weight.grad.detach().cpu()
M = G.abs()/W.abs().mean()
M = (M < 1e-3).float()
sns.kdeplot(M.mean(dim=1))
(M>0.95).float().mean()

In [None]:
A = monitor.activations[layer][0]
B = monitor.gradients[layer][0]
M = B.abs()/A.abs().mean(dim=0,keepdim=True)
M = (M < 1e-6).float()
M.mean(dim=0).shape, layer, M.mean(dim=0)[:5]

In [None]:
import seaborn as sns

L = list(monitor.gradients.keys())[::-1][2:30:3]

for l in L:
    G = monitor.gradients[l][0]
    feature_grad_norms = G.norm(dim=0)
    batch_grad_norms = G.norm(dim=1)
    print(l, G.shape, feature_grad_norms.shape, batch_grad_norms.shape)
    sns.kdeplot(feature_grad_norms.log(),label=l,)
plt.legend()

In [None]:
X = train_monitor.gradients['layers.norm_9'][0]
# plt.hist(X.flatten())
for f in range(10):
    plt.hist(X[:,f],label=f'feature {f}',alpha=0.5)
plt.legend()

plt.figure()
for b in [0,1]:
    plt.hist(X[b,:],label=f'batch {b}',alpha=0.5)
plt.legend()
X.shape