# To investigate Hernandez 2023 model-stitching analysis on VGG19
Use VGG19 models trained on different datasets and test stitching at various layers. 
Based on exp1c and exp1f

In [1]:
# Packages
%matplotlib inline

import argparse
import gc
import os.path

import pandas as pd
from torch.linalg import LinAlgError

import matplotlib.pyplot as plt
import torchvision
import torch
from torch import optim

from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import datetime

import random
import numpy as np

import sys
import os
# add the path to find colour_mnist
sys.path.append(os.path.abspath('../ReferenceCode'))
import colour_mnist
from stitch_utils import train_model, RcvVGG19, get_layer_output_shape
from stitch_utils import generate_activations, SyntheticDataset
import stitch_utils

# add the path to find the rank analysis code
# https://github.com/DHLSmith/jons-tunnel-effect/tree/NeurIPSPaper

sys.path.append(os.path.abspath('../../jons-tunnel-effect/'))
from utils.modelfitting import evaluate_model, set_seed
from extract_weight_rank import install_hooks, perform_analysis

import torchvision
from torchvision.datasets import MNIST

def logtofile(log_text, verbose=True, log_path=None):
    """Append ``log_text`` to the experiment log file and optionally echo it.

    Args:
        log_text: Any printable object (string, number, tensor, dict, ...).
            It is written with ``print``, so a trailing newline is added.
        verbose: When True (the default) the text is also printed to stdout.
        log_path: File to append to. When omitted, the notebook-level global
            ``save_log_as`` is used; it must be defined before the first call
            that relies on this default.

    Raises:
        NameError: if ``log_path`` is omitted and ``save_log_as`` is not yet defined.
    """
    if verbose:
        print(log_text)
    # Resolve the global lazily: this function is defined before the cell that
    # generates the log filename.
    target = save_log_as if log_path is None else log_path
    # A fresh checkout has no results directory yet; open(..., "a") would fail.
    parent_dir = os.path.dirname(target)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(target, "a") as f:
        print(log_text, file=f)

In [2]:
# Set Parameters
#
# Every tunable of the experiment lives in this cell: the random seed, which
# base models to (re)train versus load from disk, epoch counts, noise levels
# and the compute device.

# fix random seed for reproducibility
seed = 104
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

# Master switch: each per-model flag below defaults to this value.
train_all = True

# MIX is 1/3 bgonly, 1/3 mnist only, 1/3 biased data
train_mix_mnist_model = train_all
mix_mnist_model_to_load = './results_1g/2024-08-11_11-32-44_SEED104_EPOCHS50_BGN0.1_exp1g_VGG19_mix_mnist.weights'

# BW is greyscale mnist with no colour added (i.e. original mnist)
train_bw_mnist_model = train_all  # when False, automatically loads a trained model
bw_mnist_model_to_load = './results_1g/2024-08-10_22-57-13_SEED104_EPOCHS50_BGN0.1_exp1g_VGG19_bw_mnist.weights'

# BG_ONLY contains no mnist data, just a coloured background
train_bg_only_colour_mnist_model = train_all # when False, automatically loads a trained model
bg_only_colour_mnist_model_to_load =  './results_1g/2024-08-10_22-57-13_SEED104_EPOCHS50_BGN0.1_exp1g_VGG19_bg_only_colour_mnist.weights'

# BG_UNBIASED is digits with randomly selected colour background. Targets represent the colour
train_bg_unbiased_colour_mnist_model = train_all  # when False, automatically loads a trained model
bg_unbiased_colour_mnist_model_to_load = './results_1g/2024-08-11_11-32-44_SEED104_EPOCHS50_BGN0.1_exp1g_VGG19_bg_unbiased_colour_mnist.weights'

# BIASED is digits with consistent per-class colour background.
train_biased_colour_mnist_model = train_all  # when False, automatically loads a trained model
biased_colour_mnist_model_to_load = './results_1g/2024-08-11_11-32-44_SEED104_EPOCHS50_BGN0.1_exp1g_VGG19_biased_colour_mnist.weights'

# UNBIASED is digits with randomly selected colour background. Targets are digit values
train_unbiased_colour_mnist_model = train_all  # when False, automatically loads a trained model
unbiased_colour_mnist_model_to_load = './results_1g/2024-08-10_22-57-13_SEED104_EPOCHS50_BGN0.1_exp1g_VGG19_unbiased_colour_mnist.weights'

original_train_epochs = 50
bg_noise = 0.1
synthetic_dataset_noise = 0.1

stitch_train_epochs = 50
#stitch_train_loss_tol = 40

# Prefer the first GPU, but fall back to CPU so the notebook still runs
# (slowly) on a machine without CUDA instead of failing at the first .to(device).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


In [3]:
# Generate filenames and log the setup details
formatted_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
filename_prefix = f"./results_1g/{formatted_time}_SEED{seed}_EPOCHS{original_train_epochs}_BGN{bg_noise}_exp1g_VGG19"

# One weights file per training-data variant, plus a text log, all sharing
# the timestamped prefix above.
save_mix_mnist_model_as = filename_prefix + "_mix_mnist.weights"
save_bw_mnist_model_as = filename_prefix + "_bw_mnist.weights"
save_bg_only_colour_mnist_model_as = filename_prefix + "_bg_only_colour_mnist.weights"
save_bg_unbiased_colour_mnist_model_as = filename_prefix + "_bg_unbiased_colour_mnist.weights"
save_biased_colour_mnist_model_as = filename_prefix + "_biased_colour_mnist.weights"
save_unbiased_colour_mnist_model_as = filename_prefix + "_unbiased_colour_mnist.weights"
save_log_as = filename_prefix + "_log.txt"

colour_mnist_shape = (3, 28, 28)
vgg19_input_shape = (3, 32, 32)  # three channels; the last entry is the size the dataloader cell resizes images to

# Header of the log: when, where, and with which global settings.
for header_line in (
    f"Executed at {formatted_time}",
    f"logging to {save_log_as}",
    f"{seed=}",
    f"{bg_noise=}",
    f"{synthetic_dataset_noise=}",
):
    logtofile(header_line)

# (variant name, train flag, save path, load path), in the order they are logged.
# The variant name is the middle part of the corresponding variable names, so the
# lines written below read exactly like f"{variable=}" output.
model_setups = (
    ("mix_mnist", train_mix_mnist_model,
     save_mix_mnist_model_as, mix_mnist_model_to_load),
    ("bw_mnist", train_bw_mnist_model,
     save_bw_mnist_model_as, bw_mnist_model_to_load),
    ("bg_only_colour_mnist", train_bg_only_colour_mnist_model,
     save_bg_only_colour_mnist_model_as, bg_only_colour_mnist_model_to_load),
    ("bg_unbiased_colour_mnist", train_bg_unbiased_colour_mnist_model,
     save_bg_unbiased_colour_mnist_model_as, bg_unbiased_colour_mnist_model_to_load),
    ("biased_colour_mnist", train_biased_colour_mnist_model,
     save_biased_colour_mnist_model_as, biased_colour_mnist_model_to_load),
    ("unbiased_colour_mnist", train_unbiased_colour_mnist_model,
     save_unbiased_colour_mnist_model_as, unbiased_colour_mnist_model_to_load),
)
for variant, will_train, save_path, load_path in model_setups:
    logtofile(f"train_{variant}_model={will_train!r}")
    if will_train:
        # Training from scratch: record where the weights will go and for how long we train.
        logtofile(f"save_{variant}_model_as={save_path!r}")
        logtofile(f"{original_train_epochs=}")
    else:
        # Re-using an earlier run: record which weights file is loaded.
        logtofile(f"{variant}_model_to_load={load_path!r}")

logtofile(f"{stitch_train_epochs=}")
logtofile(f"================================================")

Executed at 2025-01-13_16-05-59
logging to ./results_1g/2025-01-13_16-05-59_SEED104_EPOCHS50_BGN0.1_exp1g_VGG19_log.txt
seed=104
bg_noise=0.1
synthetic_dataset_noise=0.1
train_mix_mnist_model=True
save_mix_mnist_model_as='./results_1g/2025-01-13_16-05-59_SEED104_EPOCHS50_BGN0.1_exp1g_VGG19_mix_mnist.weights'
original_train_epochs=50
train_bw_mnist_model=True
save_bw_mnist_model_as='./results_1g/2025-01-13_16-05-59_SEED104_EPOCHS50_BGN0.1_exp1g_VGG19_bw_mnist.weights'
original_train_epochs=50
train_bg_only_colour_mnist_model=True
save_bg_only_colour_mnist_model_as='./results_1g/2025-01-13_16-05-59_SEED104_EPOCHS50_BGN0.1_exp1g_VGG19_bg_only_colour_mnist.weights'
original_train_epochs=50
train_bg_unbiased_colour_mnist_model=True
save_bg_unbiased_colour_mnist_model_as='./results_1g/2025-01-13_16-05-59_SEED104_EPOCHS50_BGN0.1_exp1g_VGG19_bg_unbiased_colour_mnist.weights'
original_train_epochs=50
train_biased_colour_mnist_model=True
save_biased_colour_mnist_model_as='./results_1g/2025-01-13

MNIST and CIFAR-10 both have 10 classes and 10_000 test samples; MNIST has 60_000 train samples, whereas CIFAR-10 has 50_000.

In [4]:
# Set up dataloaders
transform_resize = transforms.Resize(vgg19_input_shape[2])

# Plain MNIST pipeline: replicate the grey channel three times, resize,
# convert to a tensor and normalise each channel with the same mean/std.
transform_bw = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transform_resize,
    transforms.ToTensor(),
    transforms.Normalize((0.1307,) * 3, (0.3081,) * 3),
])

mnist_train = MNIST("./MNIST", train=True, download=True, transform=transform_bw)
mnist_test = MNIST("./MNIST", train=False, download=True, transform=transform_bw)

bw_train_dataloader = DataLoader(mnist_train, batch_size=64, shuffle=True, drop_last=True)
bw_test_dataloader = DataLoader(mnist_test, batch_size=64, shuffle=True, drop_last=False)

# Keyword arguments that every colour_mnist loader below has in common.
shared_loader_kwargs = dict(
    root="./MNIST",
    batch_size=64,
    bg_noise_level=bg_noise,
    dl_transform=transform_resize,
)


def _train_then_test(factory, **variant_kwargs):
    """Build the (train, test) loader pair from one colour_mnist factory, train split first."""
    return tuple(
        factory(train=is_train, **shared_loader_kwargs, **variant_kwargs)
        for is_train in (True, False)
    )


# mix dataloader
mix_train_dataloader, mix_test_dataloader = _train_then_test(
    colour_mnist.get_mixed_mnist_dataloader,
    standard_getitem=True,
)

# bg_only means no digits - we will use colour as label
bg_only_train_dataloader, bg_only_test_dataloader = _train_then_test(
    colour_mnist.get_biased_mnist_dataloader,
    data_label_correlation=1.0,
    bg_only=True,
    standard_getitem=True,
)

# unbiased means each digit has correct label and random colour - but bg means
# we will use colour as label (i.e. the bias_target will be the target).
# Note: unlike the other loaders, standard_getitem is not passed here.
bg_unbiased_train_dataloader, bg_unbiased_test_dataloader = _train_then_test(
    colour_mnist.get_biased_mnist_dataloader,
    data_label_correlation=0.1,
    bias_targets_as_targets=True,
)

# biased means each digit has correct label and consistent colour - Expect network to learn the colours only
biased_train_dataloader, biased_test_dataloader = _train_then_test(
    colour_mnist.get_biased_mnist_dataloader,
    data_label_correlation=1.0,
    standard_getitem=True,
)

# unbiased means each digit has correct label and random colour - Expect network to disregard colours?
unbiased_train_dataloader, unbiased_test_dataloader = _train_then_test(
    colour_mnist.get_biased_mnist_dataloader,
    data_label_correlation=0.1,
    standard_getitem=True,
)

## Set up VGG19 models and train them on versions of MNIST

In [None]:

# One row per training-data variant:
#   (key, train flag, train loader, test loader, path to save to, path to load from)
# The row order is significant: it fixes both the iteration order of
# process_structure and the order in which the untrained models are created.
variant_table = [
    ("bw", train_bw_mnist_model,
     bw_train_dataloader, bw_test_dataloader,
     save_bw_mnist_model_as, bw_mnist_model_to_load),
    ("mix", train_mix_mnist_model,
     mix_train_dataloader, mix_test_dataloader,
     save_mix_mnist_model_as, mix_mnist_model_to_load),
    # "bg_only_colour"
    ("bgonly", train_bg_only_colour_mnist_model,
     bg_only_train_dataloader, bg_only_test_dataloader,
     save_bg_only_colour_mnist_model_as, bg_only_colour_mnist_model_to_load),
    # "bg_unbiased_colour"
    ("bg", train_bg_unbiased_colour_mnist_model,
     bg_unbiased_train_dataloader, bg_unbiased_test_dataloader,
     save_bg_unbiased_colour_mnist_model_as, bg_unbiased_colour_mnist_model_to_load),
    # "biased_colour_mnist"
    ("bias", train_biased_colour_mnist_model,
     biased_train_dataloader, biased_test_dataloader,
     save_biased_colour_mnist_model_as, biased_colour_mnist_model_to_load),
    # "unbiased_colour_mnist"
    ("unbias", train_unbiased_colour_mnist_model,
     unbiased_train_dataloader, unbiased_test_dataloader,
     save_unbiased_colour_mnist_model_as, unbiased_colour_mnist_model_to_load),
]

process_structure = {
    variant_key: {
        "model": torchvision.models.vgg19(weights=None, num_classes=10).to(device),  # Untrained model
        "train": train_flag,
        "train_loader": train_loader,
        "test_loader": test_loader,
        "saveas": save_path,
        "loadfrom": load_path,
    }
    for variant_key, train_flag, train_loader, test_loader, save_path, load_path in variant_table
}

# Train each model from scratch, or restore previously saved weights, then
# leave it in eval mode for the cells that follow.
for key, val in process_structure.items():
    print(f"Processing for {key=}")
    if not val["train"]:
        logtofile(f"{val['loadfrom']=}")
        val["model"].load_state_dict(torch.load(val["loadfrom"], map_location=torch.device(device)))
    else:
        # NOTE(review): milestones [100, 150] lie beyond original_train_epochs (50);
        # if train_model treats them as epoch milestones they are never reached - confirm.
        # NOTE(review): the recorded output for 'bw' shows the loss flat at ~2156.2
        # over 40+ epochs, i.e. that run does not appear to be learning - investigate.
        train_model(
            model=val["model"],
            train_loader=val["train_loader"],
            epochs=original_train_epochs,
            saveas=val["saveas"],
            description=key,
            device=device,
            logtofile=logtofile,
            milestones=[100,150],
            lr=0.01,  # 0.1 didn't work for some datasets
        )
    val["model"].eval()


Processing for key='bw'
Train Model on bw
Epoch 1, loss 2156.62
Epoch 2, loss 2156.53
Epoch 3, loss 2156.48
Epoch 4, loss 2156.42
Epoch 5, loss 2156.38
Epoch 6, loss 2156.34
Epoch 7, loss 2156.32
Epoch 8, loss 2156.30
Epoch 9, loss 2156.29
Epoch 10, loss 2156.26
Epoch 11, loss 2156.25
Epoch 12, loss 2156.24
Epoch 13, loss 2156.23
Epoch 14, loss 2156.23
Epoch 15, loss 2156.23
Epoch 16, loss 2156.22
Epoch 17, loss 2156.21
Epoch 18, loss 2156.21
Epoch 19, loss 2156.21
Epoch 20, loss 2156.20
Epoch 21, loss 2156.20
Epoch 22, loss 2156.19
Epoch 23, loss 2156.19
Epoch 24, loss 2156.19
Epoch 25, loss 2156.20
Epoch 26, loss 2156.19
Epoch 27, loss 2156.19
Epoch 28, loss 2156.19
Epoch 29, loss 2156.18
Epoch 30, loss 2156.18
Epoch 31, loss 2156.18
Epoch 32, loss 2156.20
Epoch 33, loss 2156.19
Epoch 34, loss 2156.19
Epoch 35, loss 2156.19
Epoch 36, loss 2156.19
Epoch 37, loss 2156.19
Epoch 38, loss 2156.19
Epoch 39, loss 2156.19
Epoch 40, loss 2156.19
Epoch 41, loss 2156.18
Epoch 42, loss 2156.19
E

## Measure the Accuracy, Record the Confusion Matrix


In [None]:


# Evaluate every base model on its own test loader and record, per model key,
# the test accuracy (in percent) plus a 10x10 confusion matrix in the log.
original_accuracy = dict()
for key, val in process_structure.items():
    logtofile(f"Accuracy Calculation for VGG19 with {key=}")
    model = val["model"]
    model.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS

    # Compute the model accuracy on the test set
    correct = 0
    total = 0

    # assuming 10 classes
    # rows represent actual class, columns are predicted
    confusion_matrix = torch.zeros(10,10, dtype=torch.int)

    # Pure inference: without no_grad every forward pass would build an
    # autograd graph that is never used.
    with torch.no_grad():
        for inputs, labels in val["test_loader"]:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predictions = torch.argmax(model(inputs),1)

            matches = predictions == labels
            correct += matches.sum().item()
            total += len(labels)

            # Transfer each batch to Python ints once rather than indexing the
            # CPU matrix with one device tensor element at a time.
            for actual, predicted in zip(labels.tolist(), predictions.tolist()):
                confusion_matrix[actual, predicted] += 1

    logtofile("Test the Trained VGG19")
    acc = ((100.0 * correct) / total)
    logtofile('Test Accuracy: %2.2f %%' % acc)
    original_accuracy[key] = acc
    logtofile('Confusion Matrix')
    logtofile(confusion_matrix)
    logtofile(confusion_matrix.sum())

logtofile(f"{original_accuracy=}")

## Measure Rank with original dataloader (test) before cutting and stitching

In [None]:
# For the whole model - but we pass it through RcvVGG19 (cut index -1) to get matching feature names.
# Each model is analysed on its own test loader and the result is written to a CSV in
# ./results_1g_rank; an existing CSV is treated as a cache hit and skipped.

# DataFrame.to_csv does not create missing directories.
os.makedirs("./results_1g_rank", exist_ok=True)

for key, val in process_structure.items():

    dl = val["test_loader"]

    # The weights on disk: the file just written by training, or the pre-trained one.
    filename = val["saveas"] if val["train"] else val["loadfrom"]
    assert os.path.exists(filename)

    model = torchvision.models.vgg19(weights=None, num_classes=10)  # Untrained model

    state = torch.load(filename, map_location=torch.device("cpu"))
    model.load_state_dict(state, assign=True)
    model = model.to(device)

    model = RcvVGG19(model, -1, vgg19_input_shape, device).to(device)
    out_filename = filename.split('/')[-1].replace('.weights', '-test.csv')

    outpath = f"./results_1g_rank/{key}-{key}-{seed}_{out_filename}"  # denote output name as <model_training_type>-dataset-<name>

    if os.path.exists(outpath):
        logtofile(f"Already evaluated for {outpath}")
        continue
    logtofile(f"Measure Rank for {key=}")
    print(f"output to {outpath}")

    params = {}
    params["model"] = key
    params["dataset"] = key
    params["seed"] = seed
    # as only one network used, record its filename as both send and receive files
    params["send_file"] = filename
    params["rcv_file"] = filename

    with torch.no_grad():
        layers, features, handles = install_hooks(model)
        try:
            metrics = evaluate_model(model, dl, 'acc', verbose=2)
            params.update(metrics)
            classes = None
            df = perform_analysis(features, classes, layers, params, n=8000)
            df.to_csv(outpath)
        finally:
            # Always detach the hooks, even if evaluation or analysis raises.
            for h in handles:
                h.remove()

    del model, layers, features, metrics, params, df, handles
    gc.collect()

# Stitch at a given layer


## Train the stitch layer

In [None]:
# Print the structure of a VGG19 cut after layer 1, purely for inspection.
# The model is untrained: no weights are loaded here.
layer_to_cut_after = 1
# Use a dedicated name instead of reassigning the global `device`: overwriting it
# here would silently move every later cell (stitch training included) to the CPU
# and break re-runs of earlier cells whose models live on the original device.
inspect_device = "cpu"
model = torchvision.models.vgg19(weights=None, num_classes=10).to(inspect_device)  # Untrained model
cut_layer_output_size = get_layer_output_shape(model, layer_to_cut_after, vgg19_input_shape, inspect_device, type="VGG19")
model_cut = RcvVGG19(model, layer_to_cut_after, vgg19_input_shape, inspect_device).to(inspect_device)
print(model_cut)

In [None]:
# For every base model and every cut layer: attach a stitch in front of the upper
# part of the model, train only the stitch on synthetic activations, report how much
# the stitch moved, measure accuracy on synthetic test data and run the rank analysis.
# Results per model/layer go into stitching_accuracies / stitching_penalties.
stitching_accuracies = dict()
stitching_penalties = dict()

# DataFrame.to_csv does not create missing directories.
os.makedirs("./results_1g_rank", exist_ok=True)

for key, val in process_structure.items():
    stitching_accuracies[key] = dict()
    stitching_penalties[key] = dict()

    # Weights of the receiving model: the file just trained, or the pre-trained one.
    # This does not depend on the cut layer, so it is resolved once per model.
    filename = val["saveas"] if val["train"] else val["loadfrom"]
    rank_filename = filename.split('/')[-1].replace('.weights', '-test.csv')

    for layer_to_cut_after in [1,8,22,29,35]:  #  [1,3,6,8,11,13,15,17,20,22,24,26,29,31,33,35]
        # denote output name as <model_training_type>-dataset-<name>
        # where <model_training_type> is [sender_model or X][layer_to_cut_after][Receiver_model]
        model_training_type = f"X{layer_to_cut_after}{key}"
        dataset_type = "synth"
        outpath = f"./results_1g_rank/{model_training_type}-{dataset_type}-{seed}_{rank_filename}"

        # Don't bother to stitch and train if we've already analysed it.
        # NOTE: a skipped layer gets no entry in stitching_accuracies / stitching_penalties.
        if os.path.exists(outpath):
            logtofile(f"Already evaluated for {outpath}")
            continue
        logtofile(f"Evaluate ranks and output to {outpath}")
        logtofile(f"stitch into model {key}")

        model = torchvision.models.vgg19(weights=None, num_classes=10).to(device)  # Untrained model
        model.load_state_dict(torch.load(filename, map_location=torch.device(device)))
        cut_layer_output_size = get_layer_output_shape(model, layer_to_cut_after, vgg19_input_shape, device, type="VGG19")
        model_cut = RcvVGG19(model, layer_to_cut_after, vgg19_input_shape, device).to(device)

        # Snapshot the initial stitch state. detach() keeps the snapshot out of the
        # autograd graph; clone() makes it independent of later optimiser updates.
        initial_stitch_weight = model_cut.stitch.s_conv1.weight.detach().clone()
        initial_stitch_bias   = model_cut.stitch.s_conv1.bias.detach().clone()
        stitch_initial_weight_outpath    = f"./results_1g/STITCH_initial_weight_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"
        stitch_initial_bias_outpath      = f"./results_1g/STITCH_initial_bias_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"
        # Saving of the snapshots is currently disabled:
        #torch.save(initial_stitch_weight, stitch_initial_weight_outpath)
        #torch.save(initial_stitch_bias, stitch_initial_bias_outpath)

        # Drop the leading (batch) dimension of the cut layer's output shape.
        representation_shape=(cut_layer_output_size[1:])

        syn_activations = generate_activations(num_classes=10, representation_shape=representation_shape)
        logtofile(f"after layer {layer_to_cut_after}, activations shape is {syn_activations.shape}")

        syn_train_set  = SyntheticDataset(train=True, activations=syn_activations, noise=synthetic_dataset_noise)
        syn_trainloader  = DataLoader(syn_train_set, batch_size=64, shuffle=True, drop_last=True)
        syn_test_set  = SyntheticDataset(train=False, activations=syn_activations, noise=synthetic_dataset_noise)
        syn_testloader  = DataLoader(syn_test_set, batch_size=64, shuffle=False, drop_last=False)

        # Put top model into train mode so that bn and dropout perform in training mode
        model_cut.train()
        # Freeze the top model
        model_cut.requires_grad_(False)
        # Un-Freeze the stitch layer
        for name, param in model_cut.stitch.named_parameters():
            param.requires_grad_(True)

        # define the loss function and the optimiser
        loss_function = nn.CrossEntropyLoss()
        # Only the stitch is trainable, so only its parameters are handed to the
        # optimiser. Frozen parameters never receive a gradient and SGD skips
        # parameters without one, so the updates are the same as when the whole
        # model_cut.parameters() list is passed.
        optimiser = optim.SGD(model_cut.stitch.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.01)

        logtofile(f"Train the stitch to a top model cut after layer {layer_to_cut_after}")
        # the epoch loop: only the stitch layer is updated, the rest of the network is frozen
        for epoch in range(stitch_train_epochs):
            running_loss = 0.0
            for inputs, labels in syn_trainloader:  # Use synthetic training data
                # each batch is a (representations, labels) pair; move it to the device
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimiser.zero_grad()

                # forward + loss + backward + optimise (update weights)
                outputs = model_cut(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimiser.step()

                # keep track of the loss this epoch
                running_loss += loss.item()
            logtofile("Epoch %d, loss %4.2f" % (epoch, running_loss))
        logtofile('**** Finished Training ****')

        model_cut.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS

        # Snapshot the trained stitch and report how far it moved from its initial state.
        trained_stitch_weight = model_cut.stitch.s_conv1.weight.detach().clone()
        trained_stitch_bias   = model_cut.stitch.s_conv1.bias.detach().clone()
        stitch_trained_weight_outpath    = f"./results_1g/STITCH_trained_weight_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"
        stitch_trained_bias_outpath      = f"./results_1g/STITCH_trained_bias_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"
        # Saving of the snapshots is currently disabled:
        #torch.save(trained_stitch_weight, stitch_trained_weight_outpath)
        #torch.save(trained_stitch_bias, stitch_trained_bias_outpath)

        # len() of the weight tensor is the size of its first dimension.
        print(f"Number of weight / bias in stitch layer is {len(initial_stitch_weight)}")
        stitch_weight_diff = trained_stitch_weight - initial_stitch_weight
        stitch_weight_delta = torch.linalg.norm(stitch_weight_diff).item()
        logtofile(f"Change in stitch weights: {stitch_weight_delta}")
        maxabsweight =  torch.max(stitch_weight_diff.abs()).item()
        logtofile(f"Largest abs weight change: {maxabsweight}")
        # Count the entries whose change exceeds 10% of the largest change.
        stitch_weight_number = (stitch_weight_diff.abs() > 0.1*maxabsweight).sum().item()
        logtofile(f"Number of weights changing > 0.1 of that: {stitch_weight_number}")

        stitch_bias_diff = trained_stitch_bias - initial_stitch_bias
        stitch_bias_delta = torch.linalg.norm(stitch_bias_diff).item()
        logtofile(f"Change in stitch bias: {stitch_bias_delta}")
        maxabsbias =  torch.max(stitch_bias_diff.abs()).item()
        logtofile(f"Largest abs bias change: {maxabsbias}")
        stitch_bias_number = (stitch_bias_diff.abs() > 0.1*maxabsbias).sum().item()
        logtofile(f"Number of bias changing > 0.1 of that: {stitch_bias_number}")

        logtofile("Test the trained stitch")
        # Compute the model accuracy on the test set
        correct = 0
        total = 0

        # assuming 10 classes
        # rows represent actual class, columns are predicted
        confusion_matrix = torch.zeros(10,10, dtype=torch.int)

        # Pure inference: no_grad avoids building an unused autograd graph
        # through the trainable stitch for every test batch.
        with torch.no_grad():
            for inputs, labels in syn_testloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                predictions = torch.argmax(model_cut(inputs),1)
                matches = predictions == labels
                correct += matches.sum().item()
                total += len(labels)

                # One transfer per batch instead of per-element device indexing.
                for actual, predicted in zip(labels.tolist(), predictions.tolist()):
                    confusion_matrix[actual, predicted] += 1
        acc = ((100.0 * correct) / total)
        logtofile('Test Accuracy: %2.2f %%' % acc)
        logtofile('Confusion Matrix')
        logtofile(confusion_matrix)
        logtofile("===================================================================")
        stitching_accuracies[key][layer_to_cut_after] = acc
        # Penalty = accuracy of the unmodified model on its own test set minus stitched accuracy.
        stitching_penalties[key][layer_to_cut_after] = original_accuracy[key] - acc

        dl = syn_testloader
        params = {}
        params["model"] = model_training_type
        params["dataset"] = dataset_type
        params["seed"] = seed
        # only a receiving network is involved, so only rcv_file is recorded
        params["rcv_file"] = filename
        params["stitch_weight_delta"] = stitch_weight_delta
        params["stitch_bias_delta"] = stitch_bias_delta
        params["stitch_weight_number"] = stitch_weight_number
        params["stitch_bias_number"] = stitch_bias_number

        with torch.no_grad():
            layers, features, handles = install_hooks(model_cut)
            try:
                metrics = evaluate_model(model_cut, dl, 'acc', verbose=2)
                params.update(metrics)
                classes = None
                df = perform_analysis(features, classes, layers, params, n=8000)
                df.to_csv(outpath)
            finally:
                # Always detach the hooks, even if evaluation or analysis raises.
                for h in handles:
                    h.remove()

        del model_cut, layers, features, metrics, params, df, handles
        gc.collect()

In [None]:

# Scratch re-check of the weight-change statistics for the most recent stitch.
# Relies on `trained_stitch_weight` and `initial_stitch_weight` still holding the
# values from the last stitch that was actually trained by the stitching loop.
stitch_weight_diff = trained_stitch_weight - initial_stitch_weight
abs_weight_change = stitch_weight_diff.abs()

stitch_weight_delta = torch.linalg.norm(stitch_weight_diff).item()
print(f"Change in stitch weights: {stitch_weight_delta}")

maxabsweight = torch.max(abs_weight_change).item()
print(f"Largest abs weight change: {maxabsweight}")

# Number of entries whose change exceeds 10% of the largest change.
stitch_weight_number = (abs_weight_change > 0.1*maxabsweight).sum().item()
print(f"Number of weights changing > 0.1 of that: {stitch_weight_number}")

for summary_value in (stitch_weight_delta, stitch_weight_number):
    print(summary_value)

In [None]:
# Record the per-model, per-layer stitching results in the log.
for results_line in (
    f"{stitching_accuracies=}",
    f"{stitching_penalties=}",
):
    logtofile(results_line)

In [None]:
# Per receiving model: the original test accuracy followed by the stitched
# accuracy for each cut layer that was evaluated, one value per line.
for r_key, per_layer_accuracy in stitching_accuracies.items():
    logtofile(f"synth-{r_key}")
    logtofile(original_accuracy[r_key])
    logtofile("Stitch Accuracy")
    for layer, layer_accuracy in per_layer_accuracy.items():
        logtofile(layer_accuracy)
    logtofile("--------------------------")