# Investigating the Hernandez 2023 model-stitching analysis
Hernandez found that stitching from a later layer of a ResNet into an earlier one still produced stitching-connected networks.
I would like to make a synthetic dataset in which every image of a class is replaced by the same randomly generated activations.
Load ResNet18 models that I have trained on different variants of the dataset, cut them at different layers, and stitch the synthetic (random-activation) dataloader in at the cut.
Add hooks to measure rank, activations, etc., and save the details as CSV files.
## Only train for 4 epochs
The original models were trained for only 4 epochs, so they reach lower accuracy. This leaves headroom for our experiments to detect changes and improvements.

In [1]:
# Packages
%matplotlib inline

import argparse
import gc
import os.path

import pandas as pd
from torch.linalg import LinAlgError

import matplotlib.pyplot as plt
import torchvision
import torch
from torch import optim

from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import datetime

import random
import numpy as np

import sys
import os
# add the path to find colour_mnist
sys.path.append(os.path.abspath('../ReferenceCode'))
import colour_mnist
from stitch_utils import train_model, RcvResNet18, get_layer_output_shape
from stitch_utils import generate_activations, SyntheticDataset
import stitch_utils

# add the path to find the rank analysis code
# https://github.com/DHLSmith/jons-tunnel-effect/tree/NeurIPSPaper

sys.path.append(os.path.abspath('../../jons-tunnel-effect/'))
from utils.modelfitting import evaluate_model, set_seed
from extract_weight_rank import install_hooks, perform_analysis

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

def logtofile(log_text, verbose=True, log_path=None):
    """Append ``log_text`` as one line to a log file, optionally echoing it to stdout.

    Args:
        log_text: anything printable (strings, tensors, dicts, ...). It is written with
            ``print`` so non-string objects are rendered via their ``str()``.
        verbose: when True (default) the text is also printed to stdout.
        log_path: file to append to. When omitted, the module-level ``save_log_as`` is
            used; that name is defined in the filename-setup cell, so callers relying on
            the default must run that cell first.
    """
    if verbose:
        print(log_text)
    # Resolve the default at call time so the global can be set after this def runs
    target_path = save_log_as if log_path is None else log_path
    with open(target_path, "a") as f:
        print(log_text, file=f)

In [2]:
# Set Parameters

# fix random seeds (Python, NumPy, torch CPU and CUDA) for reproducibility
seed = 102
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

train_all = False

# MIX is 1/3 bgonly, 1/3 mnist only, 1/3 biased data
train_mix_mnist_model = train_all
mix_mnist_model_to_load = './results_4_epochs/2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist.weights'

# BW is greyscale mnist with no colour added
train_bw_mnist_model = train_all  # when False, automatically loads a trained model
bw_mnist_model_to_load ='./results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bw_mnist.weights'

# BG_ONLY contains no mnist data, just a coloured background
train_bg_only_colour_mnist_model = train_all # when False, automatically loads a trained model
bg_only_colour_mnist_model_to_load =  './results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bg_only_colour_mnist.weights'

# BG_UNBIASED is digits with randomly selected colour background. Targets represent the colour
train_bg_unbiased_colour_mnist_model = train_all  # when False, automatically loads a trained model
bg_unbiased_colour_mnist_model_to_load = './results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bg_unbiased_colour_mnist.weights'

# BIASED is digits with consistent per-class colour background. 
train_biased_colour_mnist_model = train_all  # when False, automatically loads a trained model
biased_colour_mnist_model_to_load = './results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_biased_colour_mnist.weights'

# UNBIASED is digits with randomly selected colour background. Targets are digit values
train_unbiased_colour_mnist_model = train_all  # when False, automatically loads a trained model
unbiased_colour_mnist_model_to_load = './results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_unbiased_colour_mnist.weights'

original_train_epochs = 4
bg_noise = 0.1
synthetic_dataset_noise = 0.1

stitch_train_epochs = 4

# Preferred GPU for this experiment. Fall back rather than crash when the machine has fewer
# GPUs (or none): every later cell only uses `device` via .to(device) / torch.device(device).
preferred_gpu_index = 2
if torch.cuda.is_available() and torch.cuda.device_count() > preferred_gpu_index:
    device = f'cuda:{preferred_gpu_index}'
elif torch.cuda.is_available():
    device = 'cuda:0'
    print(f"cuda:{preferred_gpu_index} not available; using {device}")
else:
    device = 'cpu'
    print(f"CUDA not available; using {device}")


In [3]:
# Build this run's output filenames (all sharing one timestamped prefix) and record the setup in the log
formatted_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
filename_prefix = f"./results_4_epochs/{formatted_time}_SEED{seed}_EPOCHS{original_train_epochs}_BGN{bg_noise}_exp1f_ResNet18"

# Where each model is written if it gets (re)trained in this run
save_mix_mnist_model_as = filename_prefix + "_mix_mnist.weights"
save_bw_mnist_model_as = filename_prefix + "_bw_mnist.weights"
save_bg_only_colour_mnist_model_as = filename_prefix + "_bg_only_colour_mnist.weights"
save_bg_unbiased_colour_mnist_model_as = filename_prefix + "_bg_unbiased_colour_mnist.weights"
save_biased_colour_mnist_model_as = filename_prefix + "_biased_colour_mnist.weights"
save_unbiased_colour_mnist_model_as = filename_prefix + "_unbiased_colour_mnist.weights"
save_log_as = filename_prefix + "_log.txt"

colour_mnist_shape = (3, 28, 28)  # shape of one colour-MNIST input image

for _line in (
    f"Executed at {formatted_time}",
    f"logging to {save_log_as}",
    f"{seed=}",
    f"{bg_noise=}",
    f"{synthetic_dataset_noise=}",
):
    logtofile(_line)

# (variable-name stem, train flag, save path, load path) for each model, in logging order.
# The stem reproduces the original variable names, e.g. train_<stem>_model / save_<stem>_model_as.
_model_rows = (
    ("mix_mnist", train_mix_mnist_model, save_mix_mnist_model_as, mix_mnist_model_to_load),
    ("bw_mnist", train_bw_mnist_model, save_bw_mnist_model_as, bw_mnist_model_to_load),
    ("bg_only_colour_mnist", train_bg_only_colour_mnist_model, save_bg_only_colour_mnist_model_as, bg_only_colour_mnist_model_to_load),
    ("bg_unbiased_colour_mnist", train_bg_unbiased_colour_mnist_model, save_bg_unbiased_colour_mnist_model_as, bg_unbiased_colour_mnist_model_to_load),
    ("biased_colour_mnist", train_biased_colour_mnist_model, save_biased_colour_mnist_model_as, biased_colour_mnist_model_to_load),
    ("unbiased_colour_mnist", train_unbiased_colour_mnist_model, save_unbiased_colour_mnist_model_as, unbiased_colour_mnist_model_to_load),
)
for _stem, _train_flag, _save_path, _load_path in _model_rows:
    logtofile(f"train_{_stem}_model={_train_flag!r}")
    if not _train_flag:
        logtofile(f"{_stem}_model_to_load={_load_path!r}")
    else:
        logtofile(f"save_{_stem}_model_as={_save_path!r}")
        logtofile(f"{original_train_epochs=}")

logtofile(f"{stitch_train_epochs=}")
logtofile("================================================")
del _line, _model_rows, _stem, _train_flag, _save_path, _load_path

Executed at 2025-01-17_19-21-52
logging to ./results_4_epochs/2025-01-17_19-21-52_SEED102_EPOCHS4_BGN0.1_exp1f_ResNet18_log.txt
seed=102
bg_noise=0.1
synthetic_dataset_noise=0.1
train_mix_mnist_model=False
mix_mnist_model_to_load='./results_4_epochs/2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist.weights'
train_bw_mnist_model=False
bw_mnist_model_to_load='./results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bw_mnist.weights'
train_bg_only_colour_mnist_model=False
bg_only_colour_mnist_model_to_load='./results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bg_only_colour_mnist.weights'
train_bg_unbiased_colour_mnist_model=False
bg_unbiased_colour_mnist_model_to_load='./results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bg_unbiased_colour_mnist.weights'
train_biased_colour_mnist_model=False
biased_colour_mnist_model_to_load='./results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_b

MNIST and CIFAR-10 both have 10 classes. MNIST has 60,000 training and 10,000 test samples; CIFAR-10 has 50,000 training and 10,000 test samples.

In [4]:
# Set up dataloaders: plain greyscale MNIST from torchvision, colour variants from colour_mnist
transform_bw = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # replicate the grey channel so inputs match the 3-channel colour data
    transforms.ToTensor(),  # convert to tensor. We always do this one
    transforms.Normalize((0.1307,) * 3, (0.3081,) * 3),
])

mnist_train = MNIST("./MNIST", train=True, download=True, transform=transform_bw)
mnist_test = MNIST("./MNIST", train=False, download=True, transform=transform_bw)

bw_train_dataloader = DataLoader(mnist_train, batch_size=128, shuffle=True, drop_last=True)
bw_test_dataloader = DataLoader(mnist_test, batch_size=128, shuffle=True, drop_last=False)

# Keyword arguments shared by every colour_mnist loader below
_colour_common = dict(root="./MNIST", batch_size=128, bg_noise_level=bg_noise)


def _biased_pair(**variant_kwargs):
    """Return (train_loader, test_loader) from get_biased_mnist_dataloader; train is built first."""
    return tuple(
        colour_mnist.get_biased_mnist_dataloader(train=is_train, **_colour_common, **variant_kwargs)
        for is_train in (True, False)
    )


# mix dataloader
mix_train_dataloader, mix_test_dataloader = [
    colour_mnist.get_mixed_mnist_dataloader(train=is_train, standard_getitem=True, **_colour_common)
    for is_train in (True, False)
]

# bg_only means no digits - we will use colour as label
bg_only_train_dataloader, bg_only_test_dataloader = _biased_pair(
    data_label_correlation=1.0, bg_only=True, standard_getitem=True)

# unbiased means each digit has correct label and random colour - but bg means we will use colour as label (i.e. the bias_target will be the target)
bg_unbiased_train_dataloader, bg_unbiased_test_dataloader = _biased_pair(
    data_label_correlation=0.1, bias_targets_as_targets=True)

# biased means each digit has correct label and consistent colour - Expect network to learn the colours only
biased_train_dataloader, biased_test_dataloader = _biased_pair(
    data_label_correlation=1.0, standard_getitem=True)

# unbiased means each digit has correct label and random colour - Expect network to disregard colours?
unbiased_train_dataloader, unbiased_test_dataloader = _biased_pair(
    data_label_correlation=0.1, standard_getitem=True)

## Set up ResNet18 models and train (or load) them on versions of MNIST

In [5]:

# One row per training regime:
#   (key, train?, train loader, test loader, path to save to, path to load from)
# Row order is the model-creation order; each untrained ResNet18 initialisation
# draws on the global torch RNG, so keep this order stable.
_regimes = [
    ("mix",    train_mix_mnist_model,                mix_train_dataloader,         mix_test_dataloader,         save_mix_mnist_model_as,                mix_mnist_model_to_load),
    ("bw",     train_bw_mnist_model,                 bw_train_dataloader,          bw_test_dataloader,          save_bw_mnist_model_as,                 bw_mnist_model_to_load),
    ("bgonly", train_bg_only_colour_mnist_model,     bg_only_train_dataloader,     bg_only_test_dataloader,     save_bg_only_colour_mnist_model_as,     bg_only_colour_mnist_model_to_load),
    ("bg",     train_bg_unbiased_colour_mnist_model, bg_unbiased_train_dataloader, bg_unbiased_test_dataloader, save_bg_unbiased_colour_mnist_model_as, bg_unbiased_colour_mnist_model_to_load),
    ("bias",   train_biased_colour_mnist_model,      biased_train_dataloader,      biased_test_dataloader,      save_biased_colour_mnist_model_as,      biased_colour_mnist_model_to_load),
    ("unbias", train_unbiased_colour_mnist_model,    unbiased_train_dataloader,    unbiased_test_dataloader,    save_unbiased_colour_mnist_model_as,    unbiased_colour_mnist_model_to_load),
]

process_structure = {}
for _key, _train, _train_loader, _test_loader, _saveas, _loadfrom in _regimes:
    process_structure[_key] = {
        "model": torchvision.models.resnet18(num_classes=10).to(device),  # untrained; trained or loaded below
        "train": _train,
        "train_loader": _train_loader,
        "test_loader": _test_loader,
        "saveas": _saveas,
        "loadfrom": _loadfrom,
    }
del _regimes, _key, _train, _train_loader, _test_loader, _saveas, _loadfrom

# Either train each model from scratch (saving to "saveas") or load previously trained weights
for key, val in process_structure.items():
    print(f"Processing for {key=}")
    net = val["model"]
    if not val["train"]:
        logtofile(f"{val['loadfrom']=}")
        net.load_state_dict(torch.load(val["loadfrom"], map_location=torch.device(device)))
    else:
        train_model(model=net, train_loader=val["train_loader"],
                    epochs=original_train_epochs, saveas=val["saveas"],
                    description=key, device=device, logtofile=logtofile)
    net.eval()
del net


Processing for key='mix'
val['loadfrom']='./results_4_epochs/2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist.weights'
Processing for key='bw'
val['loadfrom']='./results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bw_mnist.weights'
Processing for key='bgonly'
val['loadfrom']='./results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bg_only_colour_mnist.weights'
Processing for key='bg'
val['loadfrom']='./results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bg_unbiased_colour_mnist.weights'
Processing for key='bias'
val['loadfrom']='./results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_biased_colour_mnist.weights'
Processing for key='unbias'
val['loadfrom']='./results_4_epochs/2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_unbiased_colour_mnist.weights'


## Measure the Accuracy, Record the Confusion Matrix


In [6]:
# Evaluate each trained model on its own test set: overall accuracy plus a confusion matrix
original_accuracy = dict()
for key, val in process_structure.items():
    logtofile(f"Accuracy Calculation for ResNet18 with {key=}")
    model = val["model"]
    model.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS

    # Compute the model accuracy on the test set
    correct = 0
    total = 0

    # assuming 10 classes
    # rows represent actual class, columns are predicted
    confusion_matrix = torch.zeros(10, 10, dtype=torch.int)

    # Inference only: no_grad avoids building (and holding) an autograd graph for every test batch
    with torch.no_grad():
        for inputs, labels in val["test_loader"]:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predictions = torch.argmax(model(inputs), 1)

            correct += (predictions == labels).sum().item()
            total += len(labels)
            # Count (actual, predicted) pairs in one vectorised call instead of a per-sample Python loop:
            # flat index = actual * 10 + predicted, reshaped back onto the 10x10 grid.
            pair_index = (labels * 10 + predictions).cpu()
            confusion_matrix += torch.bincount(pair_index, minlength=100).reshape(10, 10).to(confusion_matrix.dtype)

    logtofile("Test the Trained Resnet18")
    acc = (100.0 * correct) / total
    logtofile('Test Accuracy: %2.2f %%' % acc)
    original_accuracy[key] = acc
    logtofile('Confusion Matrix')
    logtofile(confusion_matrix)
    logtofile(confusion_matrix.sum())

logtofile(f"{original_accuracy=}")

Accuracy Calculation for ResNet18 with key='mix'
Test the Trained Resnet18
Test Accuracy: 99.26 %
Confusion Matrix
tensor([[ 946,    0,    2,    0,    0,    0,    0,    0,    1,    0],
        [   0, 1117,    2,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0, 1061,    1,    0,    0,    0,    0,    0,    0],
        [   0,    0,    1, 1003,    0,    4,    0,    2,    2,    0],
        [   0,    1,    1,    0,  998,    0,    0,    1,    0,    3],
        [   0,    0,    0,    2,    0,  856,    0,    0,    3,    1],
        [   4,    1,    1,    0,    1,    0,  977,    0,    5,    0],
        [   0,    0,    5,    0,    0,    0,    0, 1028,    0,    0],
        [   2,    0,    5,    0,    0,    1,    0,    2,  944,    1],
        [   1,    2,    2,    0,    5,    2,    0,    2,    5,  996]],
       dtype=torch.int32)
tensor(10000)
Accuracy Calculation for ResNet18 with key='bw'
Test the Trained Resnet18
Test Accuracy: 98.26 %
Confusion Matrix
tensor([[ 976,    0,    2,    

## Measure Rank with original dataloader (test) before cutting and stitching

In [7]:



# Number of top-level children (cut points) in the receiver networks. Derive it from the
# models in process_structure rather than from whichever `model` a previous cell left behind,
# and check that all receivers agree since one value is used for every model below.
_child_counts = {key: len(list(val["model"].children())) for key, val in process_structure.items()}
assert len(set(_child_counts.values())) == 1, f"models disagree on top-level layer count: {_child_counts}"
num_layers_in_model = next(iter(_child_counts.values()))
del _child_counts




In [8]:
# Whole-model rank baseline. Each trained network is wrapped in RcvResNet18 with cut index -1
# so the hooked layer/feature names match those of the stitched models analysed later.
for key, val in process_structure.items():
    dl = val["test_loader"]

    # Use the weights saved by this run if we trained, otherwise the pre-trained ones
    filename = val["saveas"] if val["train"] else val["loadfrom"]
    assert os.path.exists(filename)

    # The network is rebuilt before the cache check below, as in the original flow, so this
    # cell performs the same model construction whether or not the CSV already exists.
    mdl = torchvision.models.resnet18(num_classes=10)  # untrained skeleton
    mdl.load_state_dict(torch.load(filename, map_location=torch.device("cpu")), assign=True)
    mdl = RcvResNet18(mdl.to(device), -1, colour_mnist_shape, device).to(device)

    # denote output name as <model_training_type>-dataset-<name>
    out_filename = filename.split('/')[-1].replace('.weights', '-test.csv')
    outpath = f"./results_4_epochs_rank/{key}-{key}-{seed}_{out_filename}"

    if os.path.exists(outpath):
        logtofile(f"Already evaluated for {outpath}")
        continue
    logtofile(f"Measure Rank for {key=}")
    print(f"output to {outpath}")

    # Only one network is involved, so its weights file is recorded as both sender and receiver
    params = {
        "model": key,
        "dataset": key,
        "seed": seed,
        "send_file": filename,
        "rcv_file": filename,
    }
    with torch.no_grad():
        layers, features, handles = install_hooks(mdl)
        metrics = evaluate_model(mdl, dl, 'acc', verbose=2)
        params.update(metrics)
        df = perform_analysis(features, None, layers, params, n=-1)
        df.to_csv(outpath)

    for hook in handles:
        hook.remove()
    del mdl, layers, features, metrics, params, df, handles
    gc.collect()

Already evaluated for ./results_4_epochs_rank/mix-mix-102_2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist-test.csv
Already evaluated for ./results_4_epochs_rank/bw-bw-102_2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bw_mnist-test.csv
Already evaluated for ./results_4_epochs_rank/bgonly-bgonly-102_2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bg_only_colour_mnist-test.csv
Already evaluated for ./results_4_epochs_rank/bg-bg-102_2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bg_unbiased_colour_mnist-test.csv
Already evaluated for ./results_4_epochs_rank/bias-bias-102_2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_biased_colour_mnist-test.csv
Already evaluated for ./results_4_epochs_rank/unbias-unbias-102_2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_unbiased_colour_mnist-test.csv


## Train the stitch layer

In [None]:
# Stitch synthetic class-constant activations into each receiver at every cut point, train only
# the stitch layer, then log test accuracy, stitch-parameter changes and per-layer rank analysis.
stitching_accuracies = dict()
stitching_penalties = dict()

for key, val in process_structure.items():
    stitching_accuracies[key] = dict()
    stitching_penalties[key] = dict()

    # The receiver's weights file (and names derived from it) do not depend on the cut layer
    filename = val["saveas"] if val["train"] else val["loadfrom"]
    weights_basename = filename.split('/')[-1]
    rank_filename = weights_basename.replace('.weights', '-test.csv')

    for layer_to_cut_after in range(3, num_layers_in_model - 1):
        ###################### Don't bother to stitch and train if we've already analysed it
        # denote output name as <model_training_type>-dataset-<name>
        # where <model_training_type> is [sender_model or X][layer_to_cut_after][Receiver_model]
        model_training_type = f"X{layer_to_cut_after}{key}"
        dataset_type = "synth"
        outpath = f"./results_4_epochs_rank/{model_training_type}-{dataset_type}-{seed}_{rank_filename}"

        if os.path.exists(outpath):
            logtofile(f"Already evaluated for {outpath}")
            continue
        # NOTE(review): the stitch initialisation and synthetic data presumably draw from the global RNG,
        # so the values for this (model, layer) depend on which earlier configurations were skipped above.
        # Consider re-seeding per configuration if partial re-runs must be reproducible.
        ####################################################################################
        logtofile(f"Evaluate ranks and output to {outpath}")
        logtofile(f"stitch into model {key}")

        # Fresh copy of the trained receiver for every cut point
        model = torchvision.models.resnet18(num_classes=10).to(device)
        model.load_state_dict(torch.load(filename, map_location=torch.device(device)))  # filename is the saved or loaded weights, depending on val["train"]
        cut_layer_output_size = get_layer_output_shape(model, layer_to_cut_after, colour_mnist_shape, device)
        model_cut = RcvResNet18(model, layer_to_cut_after, colour_mnist_shape, device).to(device)

        #############################################################
        # store the initial stitch state; detach() so the snapshot is a plain tensor rather
        # than a node in the autograd graph of the live (trainable) parameter
        initial_stitch_weight = model_cut.stitch.s_conv1.weight.detach().clone()
        initial_stitch_bias   = model_cut.stitch.s_conv1.bias.detach().clone()
        stitch_initial_weight_outpath    = f"./results_4_epochs/STITCH_initial_weight_{model_training_type}-{dataset_type}-{seed}_{weights_basename}"
        stitch_initial_bias_outpath      = f"./results_4_epochs/STITCH_initial_bias_{model_training_type}-{dataset_type}-{seed}_{weights_basename}"
        torch.save(initial_stitch_weight, stitch_initial_weight_outpath)
        torch.save(initial_stitch_bias, stitch_initial_bias_outpath)
        ############################################################

        representation_shape = cut_layer_output_size[1:]  # drop the leading batch dimension

        syn_activations = generate_activations(num_classes=10, representation_shape=representation_shape)
        logtofile(f"after layer {layer_to_cut_after}, activations shape is {syn_activations.shape}")

        syn_train_set = SyntheticDataset(train=True, activations=syn_activations, noise=synthetic_dataset_noise)
        syn_trainloader = DataLoader(syn_train_set, batch_size=64, shuffle=True, drop_last=True)
        syn_test_set = SyntheticDataset(train=False, activations=syn_activations, noise=synthetic_dataset_noise)
        syn_testloader = DataLoader(syn_test_set, batch_size=64, shuffle=False, drop_last=False)

        # define the loss function and the optimiser; only the stitch is trainable (see freezing
        # below), so it is the only parameter group handed to SGD
        loss_function = nn.CrossEntropyLoss()
        optimiser = optim.SGD(model_cut.stitch.parameters(), lr=1e-4, momentum=0.9, weight_decay=0.01)

        # Put top model into train mode so that bn and dropout perform in training mode.
        # NOTE(review): this presumably also puts the receiver's remaining BatchNorm layers in training
        # mode, so their running_mean/running_var are updated from synthetic activations even though
        # no receiver weights change. Confirm this is intended; otherwise call eval() on those modules.
        model_cut.train()
        # Freeze the top model, then un-freeze only the stitch layer
        model_cut.requires_grad_(False)
        for param in model_cut.stitch.parameters():
            param.requires_grad_(True)
        logtofile(f"Train the stitch to a top model cut after layer {layer_to_cut_after}")
        # the epoch loop: only the stitch parameters have requires_grad=True, so only they are updated
        for epoch in range(stitch_train_epochs):
            running_loss = 0.0
            for inputs, labels in syn_trainloader:  # synthetic (representation, label) batches
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimiser.zero_grad()

                # forward + loss + backward + optimise (update weights)
                outputs = model_cut(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimiser.step()

                # keep track of the loss this epoch
                running_loss += loss.item()
            logtofile("Epoch %d, loss %4.2f" % (epoch, running_loss))
        logtofile('**** Finished Training ****')

        model_cut.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS

        ############################################################
        # store the trained stitch
        trained_stitch_weight = model_cut.stitch.s_conv1.weight.detach().clone()
        trained_stitch_bias   = model_cut.stitch.s_conv1.bias.detach().clone()
        stitch_trained_weight_outpath    = f"./results_4_epochs/STITCH_trained_weight_{model_training_type}-{dataset_type}-{seed}_{weights_basename}"
        stitch_trained_bias_outpath      = f"./results_4_epochs/STITCH_trained_bias_{model_training_type}-{dataset_type}-{seed}_{weights_basename}"
        torch.save(trained_stitch_weight, stitch_trained_weight_outpath)
        torch.save(trained_stitch_bias, stitch_trained_bias_outpath)

        # NOTE: len() is the size of the weight's first dimension, not its total element count
        print(f"Number of weight / bias in stitch layer is {len(initial_stitch_weight)}")
        stitch_weight_diff = trained_stitch_weight - initial_stitch_weight
        stitch_weight_delta = torch.linalg.norm(stitch_weight_diff).item()
        logtofile(f"Change in stitch weights: {stitch_weight_delta}")
        maxabsweight = torch.max(stitch_weight_diff.abs()).item()
        logtofile(f"Largest abs weight change: {maxabsweight}")
        stitch_weight_number = torch.sum(torch.where(stitch_weight_diff.abs() > 0.1*maxabsweight, True, False)).item()
        logtofile(f"Number of weights changing > 0.1 of that: {stitch_weight_number}")

        stitch_bias_diff = trained_stitch_bias - initial_stitch_bias
        stitch_bias_delta = torch.linalg.norm(stitch_bias_diff).item()
        logtofile(f"Change in stitch bias: {stitch_bias_delta}")
        maxabsbias = torch.max(stitch_bias_diff.abs()).item()
        logtofile(f"Largest abs bias change: {maxabsbias}")
        stitch_bias_number = torch.sum(torch.where(stitch_bias_diff.abs() > 0.1*maxabsbias, True, False)).item()
        logtofile(f"Number of bias changing > 0.1 of that: {stitch_bias_number}")
        ##############################################################

        logtofile("Test the trained stitch")
        # Compute the model accuracy on the test set
        correct = 0
        total = 0

        # assuming 10 classes
        # rows represent actual class, columns are predicted
        confusion_matrix = torch.zeros(10, 10, dtype=torch.int)

        # Evaluation only: don't build autograd graphs through the (trainable) stitch
        with torch.no_grad():
            for inputs, labels in syn_testloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                predictions = torch.argmax(model_cut(inputs), 1)
                correct += (predictions == labels).sum().item()
                total += len(labels)
                # Vectorised count of (actual, predicted) pairs: flat index = actual * 10 + predicted
                pair_index = (labels * 10 + predictions).cpu()
                confusion_matrix += torch.bincount(pair_index, minlength=100).reshape(10, 10).to(confusion_matrix.dtype)
        acc = (100.0 * correct) / total
        logtofile('Test Accuracy: %2.2f %%' % acc)
        logtofile('Confusion Matrix')
        logtofile(confusion_matrix)
        logtofile("===================================================================")
        stitching_accuracies[key][layer_to_cut_after] = acc
        stitching_penalties[key][layer_to_cut_after] = original_accuracy[key] - acc

        dl = syn_testloader
        params = {}
        params["model"] = model_training_type
        params["dataset"] = dataset_type
        params["seed"] = seed
        params["rcv_file"] = filename  # only the receiver network is real here; the sender is synthetic
        params["stitch_weight_delta"] = stitch_weight_delta
        params["stitch_bias_delta"] = stitch_bias_delta
        params["stitch_weight_number"] = stitch_weight_number
        params["stitch_bias_number"] = stitch_bias_number

        with torch.no_grad():
            layers, features, handles = install_hooks(model_cut)

            metrics = evaluate_model(model_cut, dl, 'acc', verbose=2)
            params.update(metrics)
            classes = None
            df = perform_analysis(features, classes, layers, params, n=-1)
            df.to_csv(outpath)

        for h in handles:
            h.remove()
        # optimiser holds references to the stitch parameters (and momentum buffers), so free it too
        del model_cut, optimiser, layers, features, metrics, params, df, handles
        gc.collect()

Already evaluated for ./results_4_epochs_rank/X3mix-synth-102_2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist-test.csv
Evaluate ranks and output to ./results_4_epochs_rank/X3mix-synth-102_2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist-test.csv
stitch into model mix
get_layer_output_shape for type='ResNet18'
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 3 is: torch.Size([1, 64, 7, 7]), with 3136 elements
after layer 3, activations shape is torch.Size([10, 64, 7, 7])
Train the stitch to a top model cut after layer 3
Epoch 0, loss 2278.09
Epoch 1, loss 38.18
Epoch 2, loss 8.82
Epoch 3, loss 43.67
**** Finished Training ****
Number of weight / bias in stitch layer is 64
Change in stitch weights: 1.457006573677063
Largest abs weight change: 0.1278412938117981
Number of weights changing > 0.1 of that: 2084
Change in stitch bias: 0.020712366327643394
Largest abs bias change: 0.00439857691526413
Number of bias changing > 

0/1(e):   0%|          | 0/157 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 21/21 [00:13<00:00,  1.53it/s]


Already evaluated for ./results_4_epochs_rank/X4mix-synth-102_2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist-test.csv
Evaluate ranks and output to ./results_4_epochs_rank/X4mix-synth-102_2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist-test.csv
stitch into model mix
get_layer_output_shape for type='ResNet18'
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 4 is: torch.Size([1, 64, 7, 7]), with 3136 elements
after layer 4, activations shape is torch.Size([10, 64, 7, 7])
Train the stitch to a top model cut after layer 4
Epoch 0, loss 470.76
Epoch 1, loss 101.95
Epoch 2, loss 7.95
Epoch 3, loss 4.37
**** Finished Training ****
Number of weight / bias in stitch layer is 64
Change in stitch weights: 1.2019702196121216
Largest abs weight change: 0.0830741822719574
Number of weights changing > 0.1 of that: 2526
Change in stitch bias: 0.017709776759147644
Largest abs bias change: 0.004533551633358002
Number of bias changing >

0/1(e):   0%|          | 0/157 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 17/17 [00:05<00:00,  2.84it/s]


Already evaluated for ./results_4_epochs_rank/X5mix-synth-102_2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist-test.csv
Evaluate ranks and output to ./results_4_epochs_rank/X5mix-synth-102_2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist-test.csv
stitch into model mix
get_layer_output_shape for type='ResNet18'
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 5 is: torch.Size([1, 128, 4, 4]), with 2048 elements
after layer 5, activations shape is torch.Size([10, 128, 4, 4])
Train the stitch to a top model cut after layer 5
Epoch 0, loss 1171.08
Epoch 1, loss 3.09
Epoch 2, loss 32.30
Epoch 3, loss 1.17
**** Finished Training ****
Number of weight / bias in stitch layer is 128
Change in stitch weights: 1.2826297283172607
Largest abs weight change: 0.05589667707681656
Number of weights changing > 0.1 of that: 8576
Change in stitch bias: 0.020607024431228638
Largest abs bias change: 0.0032383203506469727
Number of bias chang

0/1(e):   0%|          | 0/157 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  6.77it/s]


Already evaluated for ./results_4_epochs_rank/X6mix-synth-102_2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist-test.csv
Evaluate ranks and output to ./results_4_epochs_rank/X6mix-synth-102_2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist-test.csv
stitch into model mix
get_layer_output_shape for type='ResNet18'
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 6 is: torch.Size([1, 256, 2, 2]), with 1024 elements
after layer 6, activations shape is torch.Size([10, 256, 2, 2])
Train the stitch to a top model cut after layer 6
Epoch 0, loss 276.84
Epoch 1, loss 42.72
Epoch 2, loss 2.29
Epoch 3, loss 1.10
**** Finished Training ****
Number of weight / bias in stitch layer is 256
Change in stitch weights: 1.2064861059188843
Largest abs weight change: 0.027446232736110687
Number of weights changing > 0.1 of that: 34003
Change in stitch bias: 0.021246284246444702
Largest abs bias change: 0.0022921524941921234
Number of bias chan

0/1(e):   0%|          | 0/157 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 12.43it/s]


Already evaluated for ./results_4_epochs_rank/X7mix-synth-102_2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist-test.csv
Evaluate ranks and output to ./results_4_epochs_rank/X7mix-synth-102_2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist-test.csv
stitch into model mix
get_layer_output_shape for type='ResNet18'
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 7 is: torch.Size([1, 512, 1, 1]), with 512 elements
after layer 7, activations shape is torch.Size([10, 512, 1, 1])
Train the stitch to a top model cut after layer 7
Epoch 0, loss 60.27
Epoch 1, loss 0.38
Epoch 2, loss 0.25
Epoch 3, loss 0.19
**** Finished Training ****
Number of weight / bias in stitch layer is 512
Change in stitch weights: 0.7143695950508118
Largest abs weight change: 0.00882827490568161
Number of weights changing > 0.1 of that: 139718
Change in stitch bias: 0.021571150049567223
Largest abs bias change: 0.0016194134950637817
Number of bias changin

0/1(e):   0%|          | 0/157 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 21.92it/s]


Already evaluated for ./results_4_epochs_rank/X8mix-synth-102_2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist-test.csv
Evaluate ranks and output to ./results_4_epochs_rank/X8mix-synth-102_2024-08-01_07-29-36_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_mix_mnist-test.csv
stitch into model mix
get_layer_output_shape for type='ResNet18'
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 8 is: torch.Size([1, 512, 1, 1]), with 512 elements
after layer 8, activations shape is torch.Size([10, 512, 1, 1])
Train the stitch to a top model cut after layer 8
Epoch 0, loss 66.12
Epoch 1, loss 0.23
Epoch 2, loss 0.18
Epoch 3, loss 0.14
**** Finished Training ****
Number of weight / bias in stitch layer is 512
Change in stitch weights: 0.6895490288734436
Largest abs weight change: 0.007533423602581024
Number of weights changing > 0.1 of that: 154878
Change in stitch bias: 0.021245047450065613
Largest abs bias change: 0.0016216598451137543
Number of bias changi

0/1(e):   0%|          | 0/157 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 22.74it/s]


Already evaluated for ./results_4_epochs_rank/X3bw-synth-102_2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bw_mnist-test.csv
Evaluate ranks and output to ./results_4_epochs_rank/X3bw-synth-102_2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bw_mnist-test.csv
stitch into model bw
get_layer_output_shape for type='ResNet18'
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 3 is: torch.Size([1, 64, 7, 7]), with 3136 elements
after layer 3, activations shape is torch.Size([10, 64, 7, 7])
Train the stitch to a top model cut after layer 3
Epoch 0, loss 572.08
Epoch 1, loss 4.02
Epoch 2, loss 5.04
Epoch 3, loss 4.71
**** Finished Training ****
Number of weight / bias in stitch layer is 64
Change in stitch weights: 0.8689684271812439
Largest abs weight change: 0.06449682265520096
Number of weights changing > 0.1 of that: 2308
Change in stitch bias: 0.020167553797364235
Largest abs bias change: 0.004552669823169708
Number of bias changing > 0.1 o

0/1(e):   0%|          | 0/157 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 21/21 [00:14<00:00,  1.46it/s]


Already evaluated for ./results_4_epochs_rank/X4bw-synth-102_2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bw_mnist-test.csv
Evaluate ranks and output to ./results_4_epochs_rank/X4bw-synth-102_2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bw_mnist-test.csv
stitch into model bw
get_layer_output_shape for type='ResNet18'
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 4 is: torch.Size([1, 64, 7, 7]), with 3136 elements
after layer 4, activations shape is torch.Size([10, 64, 7, 7])
Train the stitch to a top model cut after layer 4
Epoch 0, loss 829.42
Epoch 1, loss 6.63
Epoch 2, loss 5.45
Epoch 3, loss 3.87
**** Finished Training ****
Number of weight / bias in stitch layer is 64
Change in stitch weights: 1.0920593738555908
Largest abs weight change: 0.09709145128726959
Number of weights changing > 0.1 of that: 2155
Change in stitch bias: 0.020634783431887627
Largest abs bias change: 0.004514455795288086
Number of bias changing > 0.1 o

0/1(e):   0%|          | 0/157 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 17/17 [00:06<00:00,  2.74it/s]


Already evaluated for ./results_4_epochs_rank/X5bw-synth-102_2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bw_mnist-test.csv
Evaluate ranks and output to ./results_4_epochs_rank/X5bw-synth-102_2024-08-01_11-00-22_SEED101_EPOCHS4_BGN0.1_exp1e_ResNet18_bw_mnist-test.csv
stitch into model bw
get_layer_output_shape for type='ResNet18'
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 5 is: torch.Size([1, 128, 4, 4]), with 2048 elements
after layer 5, activations shape is torch.Size([10, 128, 4, 4])
Train the stitch to a top model cut after layer 5


In [None]:

# Re-inspect how much the stitch layer's weights moved during training.
# NOTE(review): `initial_stitch_weight` and `trained_stitch_weight` are not defined in
# this cell; they are left over from the last iteration of the stitching loop above,
# so this only reports on the final stitch trained (hidden kernel state — re-run the
# loop cell first on a fresh kernel).
stitch_weight_diff = trained_stitch_weight - initial_stitch_weight
abs_weight_diff = stitch_weight_diff.abs()

# Overall magnitude of the change: 2-norm of the flattened difference tensor.
stitch_weight_delta = torch.linalg.norm(stitch_weight_diff).item()
print(f"Change in stitch weights: {stitch_weight_delta}")

maxabsweight = abs_weight_diff.max().item()
print(f"Largest abs weight change: {maxabsweight}")

# Count elements whose change exceeds 10% of the largest single-element change.
# (A boolean tensor sums directly to an integer count; no torch.where needed.)
stitch_weight_number = (abs_weight_diff > 0.1 * maxabsweight).sum().item()
print(f"Number of weights changing > 0.1 of that: {stitch_weight_number}")

In [None]:
# Record the raw per-key/per-layer result dictionaries in the run log as "name=repr".
for result_label, result_value in (
    ("stitching_accuracies", stitching_accuracies),
    ("stitching_penalties", stitching_penalties),
):
    logtofile(f"{result_label}={result_value!r}")

In [None]:
# Summarise, for each key of `stitching_accuracies`, the original model's accuracy and
# the accuracy obtained after stitching at each cut layer.
# NOTE(review): assumes stitching_accuracies[r_key] is a mapping layer -> accuracy
# (the previous version iterated it and indexed it by the iterated key) — confirm.
for r_key, layer_accuracies in stitching_accuracies.items():
    logtofile(f"synth-{r_key}")
    # Label the original accuracy so it cannot be confused with a per-layer value.
    logtofile("Original Accuracy")
    logtofile(original_accuracy[r_key])
    logtofile("Stitch Accuracy")
    for layer, accuracy in layer_accuracies.items():
        # Include the layer so each logged value can be attributed to its cut point.
        logtofile(f"layer {layer}: {accuracy}")
    logtofile("--------------------------")