# To investigate Hernandez 2023 model-stitching analysis
Hernandez found that a later layer of a ResNet stitched into an earlier layer was still stitching-connected.
I would like to make a synthetic dataset in which every image of a class is replaced by the same randomly generated activations.
Use the ImageNet-pretrained ResNet50 model.


In [1]:
# Packages
%matplotlib inline

import argparse
import gc
import os.path

import pandas as pd
from torch.linalg import LinAlgError

import matplotlib.pyplot as plt
import torchvision
import torch
from torch import optim

from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import datetime

import random
import numpy as np

import sys
import os
# add the path to find colour_mnist
sys.path.append(os.path.abspath('../ReferenceCode'))
import colour_mnist
from stitch_utils import train_model, RcvResNet50, get_layer_output_shape
from stitch_utils import generate_activations, DynamicSyntheticDataset
import stitch_utils

# add the path to find the rank analysis code
# https://github.com/DHLSmith/jons-tunnel-effect/tree/NeurIPSPaper

sys.path.append(os.path.abspath('../../jons-tunnel-effect/'))
from utils.modelfitting import evaluate_model, set_seed
from extract_weight_rank import install_hooks, perform_analysis

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

def logtofile(log_text, verbose=True):
    """Append ``log_text`` to the run's log file, optionally echoing it to stdout.

    The destination is the module-level ``save_log_as`` path, which must be
    assigned before the first call.

    Args:
        log_text: Object to log; it is rendered with ``str()``.
        verbose: When true, also print ``log_text`` to stdout.
    """
    if verbose:
        print(log_text)
    # str() (not an f-string) keeps the same rendering that print() would produce.
    with open(save_log_as, "a") as log_file:
        log_file.write(str(log_text) + "\n")

In [2]:
# Experiment configuration and reproducibility setup

# --- reproducibility -------------------------------------------------------
seed = 102
# Seed every random number generator used by the notebook with the same value.
for _seed_fn in (torch.manual_seed, random.seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(seed)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# --- synthetic dataset -----------------------------------------------------
num_classes = 1000
num_train_samples = 1_200_000
num_test_samples = 120_000
synthetic_dataset_noise = 0.1

# --- training / output -----------------------------------------------------
device = 'cuda:3'
stitch_train_epochs = 3
results_root = "results_1m"

# --- optional analyses -----------------------------------------------------
measure_rank = False
save_stitch_delta = False



In [3]:
# Generate filenames and log the setup details
formatted_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
filename_prefix = f"./{results_root}/{formatted_time}_SEED{seed}_exp1m_ResNet50"

save_log_as = f"{filename_prefix}_log.txt"

# logtofile() opens save_log_as in append mode, which raises if the results
# directory does not exist yet, so create it before the first log call.
os.makedirs(os.path.dirname(save_log_as), exist_ok=True)

#colour_mnist_shape = (3,28,28)
# (channels, height, width) passed to get_layer_output_shape / RcvResNet50 below
imagenet_shape = (3,224,224)

logtofile(f"Executed at {formatted_time}")
logtofile(f"logging to {save_log_as}")
logtofile(f"{seed=}")
logtofile(f"{synthetic_dataset_noise=}")
logtofile(f"{num_classes=} {num_train_samples=} {num_test_samples=}")

# Record the weights that are actually loaded in the model-setup cell
# (IMAGENET1K_V2), not the legacy pretrained=True shorthand.
logtofile("model will be model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)")

logtofile(f"{stitch_train_epochs=}")
logtofile("================================================")

Executed at 2025-01-21_09-20-17
logging to ./results_1m/2025-01-21_09-20-17_SEED102_exp1m_ResNet50_log.txt
seed=102
synthetic_dataset_noise=0.1
num_classes=1000 num_train_samples=1200000 num_test_samples=120000
model will be model = torchvision.models.resnet50(pretrained=True)
stitch_train_epochs=3


MNIST and CIFAR-10 both have 10 classes. MNIST has 60,000 training samples and 10,000 test samples (CIFAR-10 has 50,000 and 10,000).
ImageNet has 1000 classes.

In [4]:
if False:
    # Disabled cell: MNIST data loaders, not used by the synthetic-activation experiment.
    # NOTE(review): batch_size is not defined in the visible cells; define it before re-enabling.
    mnist_mean, mnist_std = 0.1307, 0.3081

    # Replicate the single grey channel three times so the images match a 3-channel network input.
    transform_bw = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize((mnist_mean,) * 3, (mnist_std,) * 3),
    ])

    # Resize / random-crop pipeline; built here but not passed to any dataset in this cell.
    transform_list = [
        transforms.Resize(size=(512, 512)),
        transforms.RandomCrop(256),
        transforms.ToTensor(),
    ]

    mnist_train, mnist_test = (
        MNIST("./MNIST", train=is_train, download=True, transform=transform_bw)
        for is_train in (True, False)
    )

    # Only the training loader drops the final incomplete batch.
    bw_train_dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
    bw_test_dataloader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)


## Set up the ImageNet-pretrained ResNet50 model

In [5]:

from torchvision.models import resnet50, ResNet50_Weights

# Per-model settings keyed by model name. Only the ImageNet-pretrained
# ResNet50 (IMAGENET1K_V2 weights) is registered; it is used as downloaded.
process_structure = {
    "imagenet": {
        "model": resnet50(weights=ResNet50_Weights.IMAGENET1K_V2),
    },
}

# Put the network in inference mode, then move it to the chosen device.
# The last line is left as a bare expression so the notebook displays the module.
process_structure["imagenet"]["model"].eval()
process_structure["imagenet"]["model"].to(device)


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

## Measure the Accuracy, Record the Confusion Matrix


In [6]:
logtofile("not measuring initial accuracy as I assume this is available")
if False:
    # Disabled cell: measures test accuracy and a confusion matrix for every
    # registered model.
    # NOTE(review): val["test_loader"] is not set in the visible setup cell;
    # provide a loader before re-enabling.
    original_accuracy = dict()
    for key, val in process_structure.items():
        logtofile(f"Accuracy Calculation for ResNet50 with {key=}")
        model = val["model"]
        model.eval()  # inference mode for batch-norm / dropout before evaluating

        # Compute the model accuracy on the test set
        correct = 0
        total = 0

        # num_classes x num_classes; rows represent actual class, columns are predicted
        confusion_matrix = torch.zeros(num_classes, num_classes, dtype=torch.int)

        # No gradients are needed for evaluation, so do not build an autograd graph.
        with torch.no_grad():
            for inputs, labels in val["test_loader"]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                predictions = torch.argmax(model(inputs), 1)

                correct += (predictions == labels).sum().item()
                total += len(labels)

                # Add the whole batch to the confusion matrix in one call instead of
                # one element at a time; accumulate=True sums repeated (row, col) pairs.
                batch_labels = labels.cpu().long()
                confusion_matrix.index_put_(
                    (batch_labels, predictions.cpu()),
                    torch.ones_like(batch_labels, dtype=confusion_matrix.dtype),
                    accumulate=True,
                )

        logtofile("Test the Trained Resnet50")
        acc = ((100.0 * correct) / total)
        logtofile('Test Accuracy: %2.2f %%' % acc)
        original_accuracy[key] = acc
        logtofile('Confusion Matrix')
        logtofile(confusion_matrix)
        logtofile(confusion_matrix.sum())
        # Rank measurement is the only path handled after this point in the notebook.
        if not measure_rank:
            raise ValueError("This needs to be implemented")

    logtofile(f"{original_accuracy=}")

not measuring initial accuracy as I assume this is available


## Measure Rank with original dataloader (test) before cutting and stitching

In [7]:



# Count the direct child modules of the ResNet50; the stitching loop derives its cut point from this.
num_layers_in_model = sum(1 for _ in process_structure["imagenet"]["model"].children())

# Report the activation shape produced after one example layer index.
print(device)
layer_index = 4
output_shape = get_layer_output_shape(
    process_structure["imagenet"]["model"],
    layer_index,
    imagenet_shape,
    device=device,
    type="ResNet50",
)
print(f"The shape of the output from layer {layer_index} is: {output_shape}")
print(f"Output Size is : {output_shape.numel()}")




cuda:3
get_layer_output_shape for type='ResNet50'
The shape of the output from layer 4 is: torch.Size([1, 256, 56, 56])
Output Size is : 802816


## Train the stitch layer

In [8]:
# Optional rank analysis of the un-stitched model(s): for each entry in
# process_structure, reload saved weights, wrap the model, install hooks,
# evaluate on the entry's test loader and write the analysis to a CSV.
# Skipped entirely while measure_rank is False.
#
# NOTE(review): this cell looks carried over from an earlier experiment and will
# not run as-is with the visible setup:
#   * val["test_loader"], val["train"], val["saveas"] and val["loadfrom"] are only
#     present as commented-out lines in the model-setup cell;
#   * RcvResNet18 and colour_mnist_shape are not imported/defined in the visible code;
#   * it builds a resnet18 with 10 classes, whereas the rest of the notebook uses a
#     1000-class ResNet50.
# Confirm and update before setting measure_rank = True.
if measure_rank:
    # For the Whole Model - but we will pass it through the RcvResNet18 wrapper to get matching feature names
    for key, val in process_structure.items():
        
        dl = val["test_loader"]
        
        # Pick the weights file: the one written by training, or the one loaded from disk
        if val["train"]:
            filename = val["saveas"] 
        else:    
            filename = val["loadfrom"] 
        assert os.path.exists(filename)
        mdl = torchvision.models.resnet18(num_classes=10) # Untrained model
        # Load on CPU first, then move the model to the target device
        state = torch.load(filename, map_location=torch.device("cpu"))
        mdl.load_state_dict(state, assign=True)
        mdl=mdl.to(device)
        mdl = RcvResNet18(mdl, -1, colour_mnist_shape, device).to(device)
    
        out_filename = filename.split('/')[-1].replace('.weights', '-test.csv')
        
        outpath = f"./results_4_epochs_rank/{key}-{key}-{seed}_{out_filename}"  # denote output name as <model_training_type>-dataset-<name>
        
        # Skip models whose results file already exists
        if os.path.exists(f"{outpath}"):
            logtofile(f"Already evaluated for {outpath}")
            continue
        logtofile(f"Measure Rank for {key=}")
        print(f"output to {outpath}")
                
        # Metadata passed to perform_analysis alongside the evaluation metrics
        params = {}
        params["model"] = key
        params["dataset"] = key
        params["seed"] = seed
        if val["train"]: # as only one network used, record its filename as both send and receive files
            params["send_file"] = val["saveas"] 
            params["rcv_file"] = val["saveas"] 
        else:    
            params["send_file"] = val["loadfrom"] 
            params["rcv_file"] = val["loadfrom"]
        with torch.no_grad():
            # Hooks are installed before evaluation; presumably they capture per-layer
            # features during the forward passes - confirm in extract_weight_rank.
            layers, features, handles = install_hooks(mdl)
            
            metrics = evaluate_model(mdl, dl, 'acc', verbose=2)
            params.update(metrics)
            classes = None
            df = perform_analysis(features, classes, layers, params, n=-1)
            df.to_csv(f"{outpath}")
                    
        # Detach the hooks and drop large objects before the next model
        for h in handles:
            h.remove()
        del mdl, layers, features, metrics, params, df, handles
        gc.collect()

In [None]:
# For every registered receiver model: cut it after a chosen layer, train a
# stitch layer on synthetic per-class activations, evaluate the stitched model
# on held-out synthetic samples and write the result to a CSV file.
# Accuracies (in percent) are collected as stitching_accuracies[model_key][cut_layer].
stitching_accuracies = dict()
#stitching_penalties = dict()

for key, val in process_structure.items():
    stitching_accuracies[key] = dict()
    #stitching_penalties[key] = dict()
    for layer_to_cut_after in [num_layers_in_model - 6]:   #range(3,num_layers_in_model - 1):
        ###################### Don't bother to stitch and train if we've already analysed it

        filename = 'resnet50_imagenet1k-test.weights'
        rank_filename = 'resnet50_imagenet1k-test.csv'
        # denote output name as <model_training_type>-dataset-<name>
        # where <model_training_type> is [sender_model or X][layer_to_cut_after][Receiver_model]
        model_training_type = f"X{layer_to_cut_after}{key}"
        dataset_type = "synth"
        outpath = f"./{results_root}_rank/{model_training_type}-{dataset_type}-{seed}_{rank_filename}"

        if os.path.exists(f"{outpath}"):
            logtofile(f"Already evaluated for {outpath}")
            continue
        # Create the output directory now, so a missing folder is found before
        # the long stitch training rather than when the CSV is finally written.
        os.makedirs(os.path.dirname(outpath), exist_ok=True)
        ####################################################################################
        logtofile(f"Evaluate ranks and output to {outpath}")
        logtofile(f"stitch into model {key}")

        # Use the model belonging to the entry being iterated (previously hard-coded to "imagenet").
        model = val["model"]
        cut_layer_output_size = get_layer_output_shape(model, layer_to_cut_after, imagenet_shape, device, type="ResNet50")
        model_cut = RcvResNet50(model, layer_to_cut_after, imagenet_shape, device).to(device)

        if save_stitch_delta:
            #############################################################
            # store the initial stitch state
            # detach() so the snapshots are plain tensors without autograd history
            initial_stitch_weight = model_cut.stitch.s_conv1.weight.detach().clone()
            initial_stitch_bias   = model_cut.stitch.s_conv1.bias.detach().clone()
            os.makedirs(f"./{results_root}", exist_ok=True)
            stitch_initial_weight_outpath    = f"./{results_root}/STITCH_initial_weight_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"
            stitch_initial_bias_outpath      = f"./{results_root}/STITCH_initial_bias_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"
            torch.save(initial_stitch_weight, stitch_initial_weight_outpath)
            torch.save(initial_stitch_bias, stitch_initial_bias_outpath)
            ############################################################
        # Drop the leading batch dimension of the probed output shape
        representation_shape = cut_layer_output_size[1:]

        logtofile("About to generate activations")
        syn_activations = generate_activations(num_classes=num_classes, representation_shape=representation_shape)
        logtofile(f"after layer {layer_to_cut_after}, activations shape is {syn_activations.shape}")

        logtofile("About to create datasets")
        # Train and test sets share the same activations but use different local seeds.
        syn_train_set = DynamicSyntheticDataset(num_samples=num_train_samples, activations=syn_activations, noise=synthetic_dataset_noise, local_seed=seed)
        syn_trainloader = DataLoader(syn_train_set, batch_size=128, shuffle=True, drop_last=True)
        syn_test_set = DynamicSyntheticDataset(num_samples=num_test_samples, activations=syn_activations, noise=synthetic_dataset_noise, local_seed=(seed + 1))
        syn_testloader = DataLoader(syn_test_set, batch_size=128, shuffle=False, drop_last=False)

        # define the loss function and the optimiser
        loss_function = nn.CrossEntropyLoss()
        optimiser = optim.SGD(model_cut.parameters(), lr=1e-4, momentum=0.9, weight_decay=0.01)

        # Put top model into train mode so that bn and dropout perform in training mode
        model_cut.train()
        # Freeze the top model
        model_cut.requires_grad_(False)
        # Un-Freeze the stitch layer
        for name, param in model_cut.stitch.named_parameters():
            param.requires_grad_(True)
        logtofile(f"Train the stitch to a top model cut after layer {layer_to_cut_after}")
        # the epoch loop: only the stitch parameters have requires_grad=True
        for epoch in range(stitch_train_epochs):
            running_loss = 0.0
            for inputs, labels in syn_trainloader:  # (representations, labels) batches of synthetic data
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimiser.zero_grad()

                # forward + loss + backward + optimise (update weights)
                outputs = model_cut(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimiser.step()

                # keep track of the (summed) batch losses this epoch
                running_loss += loss.item()
            logtofile("Epoch %d, loss %4.2f" % (epoch, running_loss))
        logtofile('**** Finished Training ****')

        model_cut.eval()  # inference mode for batch-norm / dropout before evaluating

        if save_stitch_delta:
            ############################################################
            # store the trained stitch
            trained_stitch_weight = model_cut.stitch.s_conv1.weight.detach().clone()
            trained_stitch_bias   = model_cut.stitch.s_conv1.bias.detach().clone()
            stitch_trained_weight_outpath    = f"./{results_root}/STITCH_trained_weight_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"
            stitch_trained_bias_outpath      = f"./{results_root}/STITCH_trained_bias_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"
            torch.save(trained_stitch_weight, stitch_trained_weight_outpath)
            torch.save(trained_stitch_bias, stitch_trained_bias_outpath)

            # len() of a tensor is the size of its first dimension
            print(f"Number of weight / bias in stitch layer is {len(initial_stitch_weight)}")
            stitch_weight_diff = trained_stitch_weight - initial_stitch_weight
            stitch_weight_delta = torch.linalg.norm(stitch_weight_diff).item()
            logtofile(f"Change in stitch weights: {stitch_weight_delta}")
            maxabsweight = torch.max(stitch_weight_diff.abs()).item()
            logtofile(f"Largest abs weight change: {maxabsweight}")
            # count entries whose change exceeds 10% of the largest change
            stitch_weight_number = torch.sum(torch.where(stitch_weight_diff.abs() > 0.1*maxabsweight, True, False)).item()
            logtofile(f"Number of weights changing > 0.1 of that: {stitch_weight_number}")

            stitch_bias_diff = trained_stitch_bias - initial_stitch_bias
            stitch_bias_delta = torch.linalg.norm(stitch_bias_diff).item()
            logtofile(f"Change in stitch bias: {stitch_bias_delta}")
            maxabsbias = torch.max(stitch_bias_diff.abs()).item()
            logtofile(f"Largest abs bias change: {maxabsbias}")
            stitch_bias_number = torch.sum(torch.where(stitch_bias_diff.abs() > 0.1*maxabsbias, True, False)).item()
            logtofile(f"Number of bias changing > 0.1 of that: {stitch_bias_number}")

        ##############################################################

        logtofile("Test the trained stitch")
        # Compute the model accuracy on the test set
        correct = 0
        total = 0

        # num_classes x num_classes; rows represent actual class, columns are predicted
        confusion_matrix = torch.zeros(num_classes, num_classes, dtype=torch.int)

        # No gradients are needed for evaluation, so do not build an autograd graph.
        with torch.no_grad():
            for inputs, labels in syn_testloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                predictions = torch.argmax(model_cut(inputs), 1)
                correct += (predictions == labels).sum().item()
                total += len(labels)

                # Add the whole batch to the confusion matrix in one call instead of
                # one element at a time; accumulate=True sums repeated (row, col) pairs.
                batch_labels = labels.cpu().long()
                confusion_matrix.index_put_(
                    (batch_labels, predictions.cpu()),
                    torch.ones_like(batch_labels, dtype=confusion_matrix.dtype),
                    accumulate=True,
                )
        acc = ((100.0 * correct) / total)
        logtofile('Test Accuracy: %2.2f %%' % acc)
        logtofile('Confusion Matrix')
        logtofile(confusion_matrix)
        logtofile("===================================================================")
        stitching_accuracies[key][layer_to_cut_after] = acc
        #stitching_penalties[key][layer_to_cut_after] = original_accuracy[key] - acc

        dl = syn_testloader
        if measure_rank:
            params = {}
            params["model"] = model_training_type
            params["dataset"] = dataset_type
            params["seed"] = seed

            if save_stitch_delta:
                params["stitch_weight_delta"] = stitch_weight_delta
                params["stitch_bias_delta"] = stitch_bias_delta
                params["stitch_weight_number"] = stitch_weight_number
                params["stitch_bias_number"] = stitch_bias_number

            with torch.no_grad():
                layers, features, handles = install_hooks(model_cut)

                metrics = evaluate_model(model_cut, dl, 'acc', verbose=2)
                params.update(metrics)
                classes = None
                df = perform_analysis(features, classes, layers, params, n=-1)
                df.to_csv(f"{outpath}")

            # Detach the hooks and drop large objects before the next cut point
            for h in handles:
                h.remove()
            del model_cut, layers, features, metrics, params, df, handles
            gc.collect()
        else:  # not measuring rank
            logtofile(f"output to {outpath}")

            # Single-row results file holding the stitch accuracy as a fraction
            params = {}
            # params["offset"] = target_offset
            params["model"] = model_training_type
            params["dataset"] = dataset_type
            params["seed"] = seed
            params["val_acc"] = acc / 100
            params["name"] = "only"

            results = []
            results.append(params)
            df = pd.DataFrame.from_records(results)
            df.to_csv(f"{outpath}")

            del params, df
            gc.collect()

Evaluate ranks and output to ./results_1m_rank/X4imagenet-synth-102_resnet50_imagenet1k-test.csv
stitch into model imagenet
get_layer_output_shape for type='ResNet50'
get_layer_output_shape for type='ResNet50'
The shape of the output from layer 4 is: torch.Size([1, 256, 56, 56]), with 802816 elements
About to generate activations
after layer 4, activations shape is torch.Size([1000, 256, 56, 56])
About to create datasets
Train the stitch to a top model cut after layer 4


In [None]:
if False:
    # Disabled cell: re-runs the evaluation of the most recently trained stitch
    # (model_cut / syn_testloader / key / layer_to_cut_after left over from the training cell).
    model_cut.eval()  # inference mode for batch-norm / dropout before evaluating

    logtofile("Test the trained stitch")
    # Compute the model accuracy on the test set
    correct = 0
    total = 0

    # num_classes x num_classes; rows represent actual class, columns are predicted
    confusion_matrix = torch.zeros(num_classes, num_classes, dtype=torch.int)

    # No gradients are needed for evaluation, so do not build an autograd graph.
    with torch.no_grad():
        for inputs, labels in syn_testloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            predictions = torch.argmax(model_cut(inputs), 1)
            correct += (predictions == labels).sum().item()
            total += len(labels)

            # Add the whole batch to the confusion matrix in one call instead of
            # one element at a time; accumulate=True sums repeated (row, col) pairs.
            batch_labels = labels.cpu().long()
            confusion_matrix.index_put_(
                (batch_labels, predictions.cpu()),
                torch.ones_like(batch_labels, dtype=confusion_matrix.dtype),
                accumulate=True,
            )
    acc = ((100.0 * correct) / total)
    logtofile('Test Accuracy: %2.2f %%' % acc)
    logtofile('Confusion Matrix')
    logtofile(confusion_matrix)
    logtofile("===================================================================")
    stitching_accuracies[key][layer_to_cut_after] = acc

In [None]:
if save_stitch_delta:
    # Recompute the stitch weight-change summary from the snapshots taken in the training cell.
    stitch_weight_diff = trained_stitch_weight - initial_stitch_weight
    _abs_change = stitch_weight_diff.abs()

    stitch_weight_delta = torch.linalg.norm(stitch_weight_diff).item()  # overall size of the change
    maxabsweight = _abs_change.max().item()                             # largest single change
    # Number of entries that moved by more than 10% of the largest change.
    _moved_mask = torch.where(_abs_change > 0.1*maxabsweight, True, False)
    stitch_weight_number = torch.sum(_moved_mask).item()

    print(f"Change in stitch weights: {stitch_weight_delta}")
    print(f"Largest abs weight change: {maxabsweight}")
    print(f"Number of weights changing > 0.1 of that: {stitch_weight_number}")

    print(stitch_weight_delta)
    print(stitch_weight_number)

In [None]:
# Write the full {model key: {cut layer: test accuracy in percent}} dictionary to the log.
logtofile(f"{stitching_accuracies=}")
#logtofile(f"{stitching_penalties=}")

In [None]:
# Log one section per receiver model, listing the stitch accuracy of each cut layer in turn.
for r_key, per_layer_accuracy in stitching_accuracies.items():
    logtofile(f"synth-{r_key}")
    #logtofile(original_accuracy[r_key])
    logtofile("Stitch Accuracy")
    for layer_accuracy in per_layer_accuracy.values():
        logtofile(layer_accuracy)
    logtofile("--------------------------")