# Bias data, unbias trained receiver, different sender networks
bg_unbiased means digit and bg_colour are unrelated, but the model is trained and tested on the background.
bg_colour will be varied by 10% in each channel
In 2b the stitch was trained against the dataset of the receive network.
Now we will 
only use unbias-trained network as receiver
only use bias data to train the stitch
Try the different sender networks at all different stitch levels
## Rank
Also perform rank analysis on the stitched networks based on exp1e
## 4 Epochs
Only do 4 epochs of training (keep 10 epochs of stitch training) so that the initial models are weaker

In [1]:
# Packages
%matplotlib inline

import argparse
import gc
import os.path

import pandas as pd
from torch.linalg import LinAlgError

import matplotlib.pyplot as plt
import torchvision
import torch
from torch import optim

from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import datetime

import random
import numpy as np

import sys
import os
# add the path to find colour_mnist
sys.path.append(os.path.abspath('../ReferenceCode'))
import colour_mnist
from stitch_utils import train_model, RcvResNet18, StitchedResNet18, get_layer_output_shape
from stitch_utils import generate_activations, SyntheticDataset
import stitch_utils

# add the path to find the rank analysis code
# https://github.com/DHLSmith/jons-tunnel-effect/tree/NeurIPSPaper
sys.path.append(os.path.abspath('../../jons-tunnel-effect/'))
from utils.modelfitting import evaluate_model, set_seed
from extract_weight_rank import install_hooks, perform_analysis

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

# To track memory usage
import psutil
process = psutil.Process()
            

def logtofile(log_text, verbose=True):
    if verbose:
        print(log_text)
    with open(save_log_as, "a") as f:    
        print(log_text, file=f)

In [2]:
# Set Parameters

# fix random seed for reproducibility
seed = 61
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

train_all = True

# MIX is 1/3 bgonly, 1/3 mnist only, 1/3 biased data
train_mix_mnist_model = train_all  # when False, automatically loads a trained model
mix_mnist_model_to_load = "./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_mix_mnist.weights"

# BW is greyscale mnist with no colour added
train_bw_mnist_model = train_all  # when False, automatically loads a trained model
bw_mnist_model_to_load = './results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_bw_mnist.weights'

# BG_ONLY contains no mnist data, just a coloured background
train_bg_only_colour_mnist_model = train_all  # when False, automatically loads a trained model
bg_only_colour_mnist_model_to_load =  "./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_bg_only_colour_mnist.weights"

# BG_UNBIASED is digits with randomly selected colour background. Targets represent the colour
train_bg_unbiased_colour_mnist_model = train_all  # when False, automatically loads a trained model
bg_unbiased_colour_mnist_model_to_load = "./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_bg_unbiased_colour_mnist.weights"

# BIASED is digits with consistent per-class colour background. 
train_biased_colour_mnist_model = train_all  # when False, automatically loads a trained model
biased_colour_mnist_model_to_load = "./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_biased_colour_mnist.weights"

# UNBIASED is digits with randoly selected colour background. Targets are digit values
train_unbiased_colour_mnist_model = train_all  # when False, automatically loads a trained model
unbiased_colour_mnist_model_to_load = "./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist.weights"

original_train_epochs = 4
bg_noise = 0.1

stitch_train_epochs = 10

batch_size = 128

In [3]:
# Generate filenames and log the setup details
formatted_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
filename_prefix = f"./results_4_epochs/{formatted_time}_SEED{seed}_EPOCHS{original_train_epochs}_BGN{bg_noise}_exp2e_ResNet18"
save_mix_mnist_model_as = f"{filename_prefix}_mix_mnist.weights"
save_bw_mnist_model_as = f"{filename_prefix}_bw_mnist.weights"
save_bg_only_colour_mnist_model_as = f"{filename_prefix}_bg_only_colour_mnist.weights"
save_bg_unbiased_colour_mnist_model_as = f"{filename_prefix}_bg_unbiased_colour_mnist.weights"
save_biased_colour_mnist_model_as = f"{filename_prefix}_biased_colour_mnist.weights"
save_unbiased_colour_mnist_model_as = f"{filename_prefix}_unbiased_colour_mnist.weights"
save_log_as = f"{filename_prefix}_log.txt"

colour_mnist_shape = (3,28,28)


logtofile(f"Executed at {formatted_time}")
logtofile(f"logging to {save_log_as}")
logtofile(f"{seed=}")
logtofile(f"{bg_noise=}")

logtofile(f"{train_mix_mnist_model=}")
if train_mix_mnist_model:
    logtofile(f"{save_mix_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{mix_mnist_model_to_load=}")

logtofile(f"{train_bw_mnist_model=}")
if train_bw_mnist_model:
    logtofile(f"{save_bw_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{bw_mnist_model_to_load=}")
    
logtofile(f"{train_bg_only_colour_mnist_model=}")
if train_bg_only_colour_mnist_model:
    logtofile(f"{save_bg_only_colour_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{bg_only_colour_mnist_model_to_load=}")
    
logtofile(f"{train_bg_unbiased_colour_mnist_model=}")
if train_bg_unbiased_colour_mnist_model:
    logtofile(f"{save_bg_unbiased_colour_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{bg_unbiased_colour_mnist_model_to_load=}")

logtofile(f"{train_biased_colour_mnist_model=}")
if train_biased_colour_mnist_model:
    logtofile(f"{save_biased_colour_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{biased_colour_mnist_model_to_load=}")

logtofile(f"{train_unbiased_colour_mnist_model=}")
if train_unbiased_colour_mnist_model:
    logtofile(f"{save_unbiased_colour_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{unbiased_colour_mnist_model_to_load=}")

logtofile(f"{stitch_train_epochs=}")
logtofile(f"================================================")

Executed at 2025-01-29_20-01-26
logging to ./results_4_epochs/2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_log.txt
seed=61
bg_noise=0.1
train_mix_mnist_model=True
save_mix_mnist_model_as='./results_4_epochs/2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_mix_mnist.weights'
original_train_epochs=4
train_bw_mnist_model=True
save_bw_mnist_model_as='./results_4_epochs/2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_bw_mnist.weights'
original_train_epochs=4
train_bg_only_colour_mnist_model=True
save_bg_only_colour_mnist_model_as='./results_4_epochs/2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_bg_only_colour_mnist.weights'
original_train_epochs=4
train_bg_unbiased_colour_mnist_model=True
save_bg_unbiased_colour_mnist_model_as='./results_4_epochs/2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_bg_unbiased_colour_mnist.weights'
original_train_epochs=4
train_biased_colour_mnist_model=True
save_biased_colour_mnist_model_as='./results_4_epochs/20

mnist and cifar-10 both use 10-classes, with 60_000 train samples and 10_000 test samples. 

In [4]:
# Set up dataloaders
transform_bw = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert to 3 channels    
    transforms.ToTensor(),  # convert to tensor. We always do this one    
    transforms.Normalize((0.1307,) * 3, (0.3081,) * 3)     
])

mnist_train = MNIST("./MNIST", train=True, download=True, transform=transform_bw)
mnist_test = MNIST("./MNIST", train=False, download=True, transform=transform_bw)

bw_train_dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
bw_test_dataloader  = DataLoader(mnist_test,  batch_size=batch_size, shuffle=True, drop_last=False)

# mix dataloader
mix_train_dataloader = colour_mnist.get_mixed_mnist_dataloader(root="./MNIST", batch_size=batch_size, train=True, bg_noise_level=bg_noise, standard_getitem=True)
mix_test_dataloader = colour_mnist.get_mixed_mnist_dataloader(root="./MNIST", batch_size=batch_size,  train=False, bg_noise_level=bg_noise, standard_getitem=True)

# bg_only means no digits - we will use colour as label
bg_only_train_dataloader = colour_mnist.get_biased_mnist_dataloader(root="./MNIST", batch_size=batch_size, data_label_correlation=1.0, train=True, bg_noise_level=bg_noise, bg_only=True, standard_getitem=True)
bg_only_test_dataloader = colour_mnist.get_biased_mnist_dataloader(root="./MNIST", batch_size=batch_size, data_label_correlation=1.0, train=False, bg_noise_level=bg_noise, bg_only=True, standard_getitem=True)

# unbiased means each digit has correct label and random colour - but bg means we will use colour as label (i.e. the bias_target will be the target)
bg_unbiased_train_dataloader = colour_mnist.get_biased_mnist_dataloader(root="./MNIST", batch_size=batch_size, data_label_correlation=0.1, train=True, bg_noise_level=bg_noise, bias_targets_as_targets=True)
bg_unbiased_test_dataloader = colour_mnist.get_biased_mnist_dataloader(root="./MNIST", batch_size=batch_size, data_label_correlation=0.1, train=False, bg_noise_level=bg_noise, bias_targets_as_targets=True)

# biased means each digit has correct label and consistent colour - Expect network to learn the colours only
biased_train_dataloader = colour_mnist.get_biased_mnist_dataloader(root="./MNIST", batch_size=batch_size, data_label_correlation=1.0, train=True, bg_noise_level=bg_noise, standard_getitem=True)
biased_test_dataloader = colour_mnist.get_biased_mnist_dataloader(root="./MNIST", batch_size=batch_size, data_label_correlation=1.0, train=False, bg_noise_level=bg_noise, standard_getitem=True)

# unbiased means each digit has correct label and random colour - Expect network to disregard colours?
unbiased_train_dataloader = colour_mnist.get_biased_mnist_dataloader(root="./MNIST", batch_size=batch_size, data_label_correlation=0.1, train=True, bg_noise_level=bg_noise, standard_getitem=True)
unbiased_test_dataloader = colour_mnist.get_biased_mnist_dataloader(root="./MNIST", batch_size=batch_size, data_label_correlation=0.1, train=False, bg_noise_level=bg_noise, standard_getitem=True)

## Set up resnet18 models and train it on versions of MNIST

In [5]:

process_structure = dict()
device = 'cuda:0'


process_structure["mix"] = dict()
process_structure["bw"] = dict()
process_structure["bgonly"] = dict()
process_structure["bg"] = dict()
process_structure["bias"]      = dict()
process_structure["unbias"]    = dict()

# "mix"
process_structure["mix"]["model"] = torchvision.models.resnet18(num_classes=10).to(device) # Untrained model
process_structure["mix"]["train"] = train_mix_mnist_model 
process_structure["mix"]["train_loader"] = mix_train_dataloader
process_structure["mix"]["test_loader"] = mix_test_dataloader
process_structure["mix"]["saveas"] = save_mix_mnist_model_as
process_structure["mix"]["loadfrom"] = mix_mnist_model_to_load

# "bw"
process_structure["bw"]["model"] = torchvision.models.resnet18(num_classes=10).to(device) # Untrained model
process_structure["bw"]["train"] = train_bw_mnist_model 
process_structure["bw"]["train_loader"] = bw_train_dataloader
process_structure["bw"]["test_loader"] = bw_test_dataloader
process_structure["bw"]["saveas"] = save_bw_mnist_model_as
process_structure["bw"]["loadfrom"] = bw_mnist_model_to_load

# "bg_only_colour"
process_structure["bgonly"]["model"] = torchvision.models.resnet18(num_classes=10).to(device) # Untrained model
process_structure["bgonly"]["train"] = train_bg_only_colour_mnist_model 
process_structure["bgonly"]["train_loader"] = bg_only_train_dataloader
process_structure["bgonly"]["test_loader"] = bg_only_test_dataloader
process_structure["bgonly"]["saveas"] = save_bg_only_colour_mnist_model_as
process_structure["bgonly"]["loadfrom"] = bg_only_colour_mnist_model_to_load

# "bg_unbiased_colour"
process_structure["bg"]["model"] = torchvision.models.resnet18(num_classes=10).to(device) # Untrained model
process_structure["bg"]["train"] = train_bg_unbiased_colour_mnist_model 
process_structure["bg"]["train_loader"] = bg_unbiased_train_dataloader
process_structure["bg"]["test_loader"] = bg_unbiased_test_dataloader
process_structure["bg"]["saveas"] = save_bg_unbiased_colour_mnist_model_as
process_structure["bg"]["loadfrom"] = bg_unbiased_colour_mnist_model_to_load

# "biased_colour_mnist"
process_structure["bias"]["model"] = torchvision.models.resnet18(num_classes=10).to(device) # Untrained model
process_structure["bias"]["train"] = train_biased_colour_mnist_model
process_structure["bias"]["train_loader"] = biased_train_dataloader
process_structure["bias"]["test_loader"] = biased_test_dataloader
process_structure["bias"]["saveas"] = save_biased_colour_mnist_model_as
process_structure["bias"]["loadfrom"] =  biased_colour_mnist_model_to_load

# "unbiased_colour_mnist"
process_structure["unbias"]["model"] = torchvision.models.resnet18(num_classes=10).to(device) # Untrained model
process_structure["unbias"]["train"] = train_unbiased_colour_mnist_model
process_structure["unbias"]["train_loader"] = unbiased_train_dataloader
process_structure["unbias"]["test_loader"] = unbiased_test_dataloader
process_structure["unbias"]["saveas"] = save_unbiased_colour_mnist_model_as
process_structure["unbias"]["loadfrom"] =  unbiased_colour_mnist_model_to_load

for key, val in process_structure.items():
    print(f"Processing for {key=}")
    if val["train"]:
        train_model(model=val["model"], train_loader=val["train_loader"], 
                    epochs=original_train_epochs, saveas=val["saveas"], 
                    description=key, device=device, logtofile=logtofile)
    else:
        logtofile(f"{val['loadfrom']=}")
        val["model"].load_state_dict(torch.load(val["loadfrom"], map_location=torch.device(device)))
    val["model"].eval()


Processing for key='mix'
Train Model on mix
Epoch 0, loss 291.71
Epoch 1, loss 75.00
Epoch 2, loss 37.22
Epoch 3, loss 15.44
**** Finished Training ****
saveas='./results_4_epochs/2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_mix_mnist.weights'
Processing for key='bw'
Train Model on bw
Epoch 0, loss 137.39
Epoch 1, loss 26.91
Epoch 2, loss 19.86
Epoch 0, loss 66.28
Epoch 1, loss 4.52
Epoch 2, loss 1.49
Epoch 3, loss 1.58
**** Finished Training ****
saveas='./results_4_epochs/2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_bg_unbiased_colour_mnist.weights'
Processing for key='bias'
Train Model on bias
Epoch 0, loss 63.10
Epoch 1, loss 3.65
Epoch 2, loss 1.01
Epoch 3, loss 1.73
**** Finished Training ****
saveas='./results_4_epochs/2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_biased_colour_mnist.weights'
Processing for key='unbias'
Train Model on unbias
Epoch 0, loss 267.71
Epoch 1, loss 70.37
Epoch 2, loss 47.43
Epoch 3, loss 31.44
**** Finished Training 

## Measure the Accuracy, Record the Confusion Matrix


In [6]:
logtofile("Entering Confusion")
# logtofile(process.memory_info().rss)  # in bytes 

original_accuracy = dict()
for key, val in process_structure.items():
    logtofile(f"Accuracy Calculation for ResNet18 with {key=}")
    model = val["model"]
    model.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS
    
    # Compute the model accuracy on the test set
    correct = 0
    total = 0
    
    # assuming 10 classes
    # rows represent actual class, columns are predicted
    confusion_matrix = torch.zeros(10,10, dtype=torch.int)
    
    TDL = biased_test_dataloader  # In test 2D - ALWAYS use biased dataset to measure/train stitch
    for data in TDL:
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        predictions = torch.argmax(model(inputs),1)
        
        matches = predictions == labels
        correct += matches.sum().item()
        total += len(labels)
        for idx, l in enumerate(labels):
            confusion_matrix[l, predictions[idx]] = 1 + confusion_matrix[l, predictions[idx]] 
    
    logtofile("Test the Trained Resnet18 against BIASED TEST DATALOADER")
    acc = ((100.0 * correct) / total)
    logtofile('Test Accuracy: %2.2f %%' % acc)
    original_accuracy[key] = acc
    logtofile('Confusion Matrix')
    logtofile(confusion_matrix)
    logtofile(confusion_matrix.sum())
    # logtofile(process.memory_info().rss)  # in bytes 


logtofile(f"{original_accuracy=}")

Entering Confusion
Accuracy Calculation for ResNet18 with key='mix'
Test the Trained Resnet18 against BIASED TEST DATALOADER
Test Accuracy: 100.00 %
Confusion Matrix
tensor([[ 980,    0,    0,    0,    0,    0,    0,    0,    0,    0],
        [   0, 1135,    0,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0, 1032,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0,    0, 1010,    0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,  982,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,  892,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0,  958,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0, 1028,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0,  974,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0,    0, 1009]],
       dtype=torch.int32)
tensor(10000)
Accuracy Calculation for ResNet18 with key='bw'
Test the Trained Resnet18 against BIASED TEST

## Measure Rank with __biased__ dataloader (test) before cutting and stitching

In [7]:
logtofile("Entering whole model check")
# logtofile(process.memory_info().rss)  # in bytes 

# For the Whole Model - but we will pass it through the RcvResNet18 function to get matching feature names
for key, val in process_structure.items():
    
    TDL = biased_test_dataloader  # ALWAYS use biased dataloader for this test
    
    if val["train"]:
        filename = val["saveas"] 
    else:    
        filename = val["loadfrom"] 
    assert os.path.exists(filename)
    mdl = torchvision.models.resnet18(num_classes=10) # Untrained model
    state = torch.load(filename, map_location=torch.device("cpu"))
    mdl.load_state_dict(state, assign=True)
    mdl=mdl.to(device)
    mdl = RcvResNet18(mdl, -1, colour_mnist_shape, device).to(device)

    out_filename = filename.split('/')[-1].replace('.weights', '-test.csv')
    
    outpath = f"./results_4_epochs_rank/{key}-bias-{seed}_{out_filename}"  # denote output name as <model_training_type>-dataset-<name>
    
    if os.path.exists(f"{outpath}"):
        logtofile(f"Already evaluated for {outpath}")
        continue
    logtofile(f"Measure Rank for {key=}")
    print(f"output to {outpath}")
            
    params = {}
    params["model"] = key
    params["dataset"] = "bias"
    params["seed"] = seed
    if val["train"]: # as only one network used, record its filename as both send and receive files
        params["send_file"] = val["saveas"] 
        params["rcv_file"] = val["saveas"] 
    else:    
        params["send_file"] = val["loadfrom"] 
        params["rcv_file"] = val["loadfrom"]     
    
    with torch.no_grad():
        layers, features, handles = install_hooks(mdl)
        
        metrics = evaluate_model(mdl, TDL, 'acc', verbose=2)
        params.update(metrics)
        classes = None
        df = perform_analysis(features, classes, layers, params, n=-1)
        df.to_csv(f"{outpath}")
    for h in handles:
        h.remove()
    del mdl, layers, features, metrics, params, df, handles
    gc.collect()
    # logtofile(process.memory_info().rss)  # in bytes 



Entering whole model check
Measure Rank for key='mix'
output to ./results_4_epochs_rank/mix-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_mix_mnist-test.csv


0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 21/21 [00:35<00:00,  1.67s/it]


Measure Rank for key='bw'
output to ./results_4_epochs_rank/bw-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_bw_mnist-test.csv


0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 21/21 [00:35<00:00,  1.69s/it]


Measure Rank for key='bgonly'
output to ./results_4_epochs_rank/bgonly-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_bg_only_colour_mnist-test.csv


0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 21/21 [00:36<00:00,  1.72s/it]


Measure Rank for key='bg'
output to ./results_4_epochs_rank/bg-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_bg_unbiased_colour_mnist-test.csv


0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 21/21 [00:36<00:00,  1.73s/it]


Measure Rank for key='bias'
output to ./results_4_epochs_rank/bias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_biased_colour_mnist-test.csv


0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 21/21 [00:36<00:00,  1.72s/it]


Measure Rank for key='unbias'
output to ./results_4_epochs_rank/unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv


0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 21/21 [00:33<00:00,  1.57s/it]


# Stitch at a given layer


## Train the stitch layer and check rank

In [8]:
logtofile("Entering Stitch/Rank")
# logtofile(process.memory_info().rss)  # in bytes 

logtofile(f"{device=}")
stitching_accuracies = dict()
stitching_penalties = dict()
# NOTE this is only valid as all models are the same architecture
num_layers_in_model = len(list(process_structure["bias"]["model"].children()))  
for send_key, send_val in process_structure.items():
    stitching_accuracies[send_key] = dict()
    stitching_penalties[send_key] = dict()
    for rcv_key, rcv_val in process_structure.items():        
        if (rcv_key != "unbias"):
            logtofile(f"NOTE: Only running stitch with unbias receive model")
            continue
        stitching_accuracies[send_key][rcv_key] = dict()
        stitching_penalties[send_key][rcv_key] = dict()
        for layer_to_cut_after in range(3,num_layers_in_model - 1):
            # for consistency, use the rcv network for the filename stem.
            if rcv_val["train"]:
                filename = rcv_val["saveas"] 
            else:    
                filename = rcv_val["loadfrom"] 
            
            rank_filename = filename.split('/')[-1].replace('.weights', '-test.csv')        
            # denote output name as <model_training_type>-dataset-<name>
            # where <model_training_type> is [sender_model or X][layer_to_cut_after][Receiver_model]
            model_training_type = f"{send_key}{layer_to_cut_after}{rcv_key}"
            dataset_type = "bias"  # ALWAYS use bias dataset in this test
            outpath = f"./results_4_epochs_rank/{model_training_type}-{dataset_type}-{seed}_{rank_filename}"  
                            
            if os.path.exists(f"{outpath}"):
                logtofile(f"Already evaluated for {outpath}")
                continue
            ####################################################################################
            logtofile(f"Evaluate ranks and output to {outpath}")
            # logtofile(process.memory_info().rss)  # in bytes 

            logtofile(f"Train the stitch to a model stitched after layer {layer_to_cut_after} from {send_key} to {rcv_key}")    
            logtofile(f"Use the biased data loader (train and test) regardless of what {rcv_key} was trained on")
            
            # train a stitch on the unbiased_colour dataset to compare receiver network performance with stitched
            model_stitched = StitchedResNet18(send_model=send_val["model"], 
                                              after_layer_index=layer_to_cut_after, 
                                              rcv_model=rcv_val["model"],
                                              input_image_shape=colour_mnist_shape, device=device  ).to(device)
                        
            #############################################################
            # store the initial stitch state
            initial_stitch_weight = model_stitched.stitch.s_conv1.weight.clone()
            initial_stitch_bias   = model_stitched.stitch.s_conv1.bias.clone()
            stitch_initial_weight_outpath    = f"./results_4_epochs/STITCH_initial_weight_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"  
            stitch_initial_bias_outpath      = f"./results_4_epochs/STITCH_initial_bias_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"  
            torch.save(initial_stitch_weight, stitch_initial_weight_outpath)
            torch.save(initial_stitch_bias, stitch_initial_bias_outpath)
            ############################################################
                    
            # define the loss function and the optimiser
            loss_function = nn.CrossEntropyLoss()
            # Hernandez said : momentum 0.9, batch size 256, weight decay 0.01, learning rate 0.01, and a post-warmup cosine learning rate scheduler.
            # optimiser = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
            optimiser = optim.SGD(model_stitched.parameters(), lr=1e-4, momentum=0.9, weight_decay=0.01)
            
            # Put top model into train mode so that bn and dropout perform in training mode
            model_stitched.train()
            # Freeze the whole model
            model_stitched.requires_grad_(False)
            # Un-Freeze the stitch layer
            for name, param in model_stitched.stitch.named_parameters():
                param.requires_grad_(True)
            # the epoch loop: note that we're training the whole network
            for epoch in range(stitch_train_epochs):
                running_loss = 0.0
                for data in biased_train_dataloader:
                    # data is (representations, labels) tuple
                    # get the inputs and put them on the GPU
                    inputs, labels = data
                    inputs = inputs.to(device)
                    labels = labels.to(device)
            
                    # zero the parameter gradients
                    optimiser.zero_grad()
            
                    # forward + loss + backward + optimise (update weights)
                    outputs = model_stitched(inputs)
                    loss = loss_function(outputs, labels)
                    loss.backward()
                    optimiser.step()
            
                    # keep track of the loss this epoch
                    running_loss += loss.item()
                logtofile("Epoch %d, loss %4.2f" % (epoch, running_loss))
                # logtofile(process.memory_info().rss)  # in bytes 

            logtofile('**** Finished Training ****')
            
            model_stitched.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS

            ############################################################
            # store the trained stitch
            trained_stitch_weight = model_stitched.stitch.s_conv1.weight.clone()
            trained_stitch_bias   = model_stitched.stitch.s_conv1.bias.clone()
            stitch_trained_weight_outpath    = f"./results_4_epochs/STITCH_trained_weight_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"  
            stitch_trained_bias_outpath      = f"./results_4_epochs/STITCH_trained_bias_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"  
            torch.save(trained_stitch_weight, stitch_trained_weight_outpath)
            torch.save(trained_stitch_bias, stitch_trained_bias_outpath)
                       
            stitch_weight_diff = trained_stitch_weight - initial_stitch_weight
            stitch_weight_delta = torch.linalg.norm(stitch_weight_diff).item()
            logtofile(f"Change in stitch weights: {stitch_weight_delta}")
            maxabsweight =  torch.max(stitch_weight_diff.abs()).item()
            logtofile(f"Largest abs weight change: {maxabsweight}")
            stitch_weight_number = torch.sum(torch.where(stitch_weight_diff.abs() > 0.1*maxabsweight, True, False)).item()
            logtofile(f"Number of weights changing > 0.1 of that: {stitch_weight_number}")

            
            print(f"Number of weight / bias in stitch layer is {len(initial_stitch_weight)}")
            stitch_bias_diff = trained_stitch_bias - initial_stitch_bias
            stitch_bias_delta = torch.linalg.norm(stitch_bias_diff).item()
            logtofile(f"Change in stitch bias: {stitch_bias_delta}")
            maxabsbias =  torch.max(stitch_bias_diff.abs()).item()
            logtofile(f"Largest abs bias change: {maxabsbias}")
            stitch_bias_number = torch.sum(torch.where(stitch_bias_diff.abs() > 0.1*maxabsbias, True, False)).item()
            logtofile(f"Number of bias changing > 0.1 of that: {stitch_bias_number}")
            ##############################################################

            
            # Compute the model accuracy on the test set
            correct = 0
            total = 0
            
            # assuming 10 classes
            # rows represent actual class, columns are predicted
            confusion_matrix = torch.zeros(10,10, dtype=torch.int)
            
            for data in biased_test_dataloader:  # Only use biased test data
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                predictions = torch.argmax(model_stitched(inputs),1)
                matches = predictions == labels.to(device)
                correct += matches.sum().item()
                total += len(labels)
            
                for idx, l in enumerate(labels):
                    confusion_matrix[l, predictions[idx]] = 1 + confusion_matrix[l, predictions[idx]] 
            logtofile("Test the trained stitch against biased data")    
            acc =  ((100.0 * correct) / total)
            logtofile('Test Accuracy: %2.2f %%' % acc)
            logtofile('Confusion Matrix')
            logtofile(confusion_matrix)
            logtofile("===================================================================")
            # logtofile(process.memory_info().rss)  # in bytes 

            # Stitching penalty should be negative if there is an improvement, and is relative to the original receiver network
            stitching_accuracies[send_key][rcv_key][layer_to_cut_after] = acc
            stitching_penalties[send_key][rcv_key][layer_to_cut_after] = original_accuracy[rcv_key] - acc

            TDL = biased_test_dataloader
            params = {}
            params["model"] = model_training_type # a mnemonic
            params["dataset"] = dataset_type
            params["seed"] = seed
            if send_val["train"]:
                params["send_file"] = send_val["saveas"] 
            else:    
                params["send_file"] = send_val["loadfrom"] 
            if rcv_val["train"]:
                params["rcv_file"] = rcv_val["saveas"] 
            else:    
                params["rcv_file"] = rcv_val["loadfrom"] 
            params["stitch_weight_delta"] = stitch_weight_delta
            params["stitch_bias_delta"] = stitch_bias_delta        
            params["stitch_weight_number"] = stitch_weight_number
            params["stitch_bias_number"] = stitch_bias_number
            # logtofile(process.memory_info().rss)  # in bytes 
            with torch.no_grad():
                layers, features, handles = install_hooks(model_stitched)                
                metrics = evaluate_model(model_stitched, TDL, 'acc', verbose=2)
                params.update(metrics)
                classes = None
                df = perform_analysis(features, classes, layers, params, n=-1)
                df.to_csv(f"{outpath}")
                
            for h in handles:
                h.remove()
            del model_stitched, layers, features, metrics, params, df, handles
            gc.collect()
            # logtofile(process.memory_info().rss)  # in bytes 

            

Entering Stitch/Rank
device='cuda:0'
NOTE: Only running stitch with unbias receive model
NOTE: Only running stitch with unbias receive model
NOTE: Only running stitch with unbias receive model
NOTE: Only running stitch with unbias receive model
NOTE: Only running stitch with unbias receive model
Evaluate ranks and output to ./results_4_epochs_rank/mix3unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 3 from mix to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 3 of send_model is: torch.Size([1, 64, 7, 7])
Epoch 0, loss 227.34
Epoch 1, loss 50.47
Epoch 2, loss 35.49
Epoch 3, loss 28.83
Epoch 4, loss 24.23
Epoch 5, loss 21.10
Epoch 6, loss 18.37
Epoch 7, loss 17.37
Epoch 8, loss 15.23
Epoch 9, loss 13.93
**** Finished Training ****
Change in stitch weights: 1.0903089

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:36<00:00,  1.67s/it]


Evaluate ranks and output to ./results_4_epochs_rank/mix4unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 4 from mix to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 4 of send_model is: torch.Size([1, 64, 7, 7])
Epoch 0, loss 367.69
Epoch 1, loss 32.62
Epoch 2, loss 19.90
Epoch 3, loss 15.64
Epoch 4, loss 12.33
Epoch 5, loss 9.97
Epoch 6, loss 9.15
Epoch 7, loss 8.69
Epoch 8, loss 6.91
Epoch 9, loss 7.11
**** Finished Training ****
Change in stitch weights: 1.114902138710022
Largest abs weight change: 0.12319251894950867
Number of weights changing > 0.1 of that: 1556
Number of weight / bias in stitch layer is 64
Change in stitch bias: 0.027516931295394897
Largest abs bias change: 0.005611389875411987
Number of bias changing > 0.1 of that: 60
Test the trained stit

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:33<00:00,  1.53s/it]


Evaluate ranks and output to ./results_4_epochs_rank/mix5unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 5 from mix to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 5 of send_model is: torch.Size([1, 128, 4, 4])
Epoch 0, loss 227.15
Epoch 1, loss 14.16
Epoch 2, loss 8.85
Epoch 3, loss 6.50
Epoch 4, loss 5.02
Epoch 5, loss 4.52
Epoch 6, loss 3.54
Epoch 7, loss 3.14
Epoch 8, loss 3.18
Epoch 9, loss 2.87
**** Finished Training ****
Change in stitch weights: 0.9085882306098938
Largest abs weight change: 0.06845062226057053
Number of weights changing > 0.1 of that: 4210
Number of weight / bias in stitch layer is 128
Change in stitch bias: 0.027867073193192482
Largest abs bias change: 0.003968223929405212
Number of bias changing > 0.1 of that: 111
Test the trained sti

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:35<00:00,  1.62s/it]


Evaluate ranks and output to ./results_4_epochs_rank/mix6unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 6 from mix to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 6 of send_model is: torch.Size([1, 256, 2, 2])
Epoch 0, loss 76.39
Epoch 1, loss 4.78
Epoch 2, loss 2.89
Epoch 3, loss 2.29
Epoch 4, loss 2.15
Epoch 5, loss 1.74
Epoch 6, loss 1.68
Epoch 7, loss 1.67
Epoch 8, loss 1.22
Epoch 9, loss 1.50
**** Finished Training ****
Change in stitch weights: 0.7377018332481384
Largest abs weight change: 0.027315039187669754
Number of weights changing > 0.1 of that: 18825
Number of weight / bias in stitch layer is 256
Change in stitch bias: 0.02506132423877716
Largest abs bias change: 0.002848513424396515
Number of bias changing > 0.1 of that: 228
Test the trained stit

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:33<00:00,  1.52s/it]


Evaluate ranks and output to ./results_4_epochs_rank/mix7unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 7 from mix to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 7 of send_model is: torch.Size([1, 512, 1, 1])
Epoch 0, loss 51.14
Epoch 1, loss 3.03
Epoch 2, loss 2.67
Epoch 3, loss 2.18
Epoch 4, loss 1.76
Epoch 5, loss 1.43
Epoch 6, loss 1.49
Epoch 7, loss 1.32
Epoch 8, loss 1.26
Epoch 9, loss 1.37
**** Finished Training ****
Change in stitch weights: 0.8016335964202881
Largest abs weight change: 0.013892661780118942
Number of weights changing > 0.1 of that: 100703
Number of weight / bias in stitch layer is 512
Change in stitch bias: 0.02583794668316841
Largest abs bias change: 0.002016685903072357
Number of bias changing > 0.1 of that: 453
Test the trained sti

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:37<00:00,  1.70s/it]


Evaluate ranks and output to ./results_4_epochs_rank/mix8unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 8 from mix to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 8 of send_model is: torch.Size([1, 512, 1, 1])
Epoch 0, loss 50.72
Epoch 1, loss 3.27
Epoch 2, loss 2.37
Epoch 3, loss 2.09
Epoch 4, loss 1.80
Epoch 5, loss 1.53
Epoch 6, loss 1.45
Epoch 7, loss 1.42
Epoch 8, loss 1.19
Epoch 9, loss 1.19
**** Finished Training ****
Change in stitch weights: 0.8050686717033386
Largest abs weight change: 0.013270174153149128
Number of weights changing > 0.1 of that: 107688
Number of weight / bias in stitch layer is 512
Change in stitch bias: 0.026144079864025116
Largest abs bias change: 0.002020888030529022
Number of bias changing > 0.1 of that: 463
Test the trained st

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:36<00:00,  1.66s/it]


NOTE: Only running stitch with unbias receive model
NOTE: Only running stitch with unbias receive model
NOTE: Only running stitch with unbias receive model
NOTE: Only running stitch with unbias receive model
NOTE: Only running stitch with unbias receive model
Evaluate ranks and output to ./results_4_epochs_rank/bw3unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 3 from bw to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 3 of send_model is: torch.Size([1, 64, 7, 7])
Epoch 0, loss 308.65
Epoch 1, loss 62.76
Epoch 2, loss 50.75
Epoch 3, loss 45.37
Epoch 4, loss 40.94
Epoch 5, loss 38.52
Epoch 6, loss 36.92
Epoch 7, loss 34.82
Epoch 8, loss 33.45
Epoch 9, loss 31.78
**** Finished Training ****
Change in stitch weights: 1.0683857202529907
Largest abs weight change: 0.

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:37<00:00,  1.69s/it]


Evaluate ranks and output to ./results_4_epochs_rank/bw4unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 4 from bw to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 4 of send_model is: torch.Size([1, 64, 7, 7])
Epoch 0, loss 226.61
Epoch 1, loss 60.60
Epoch 2, loss 49.44
Epoch 3, loss 43.64
Epoch 4, loss 40.14
Epoch 5, loss 37.35
Epoch 6, loss 35.39
Epoch 7, loss 33.93
Epoch 8, loss 32.28
Epoch 9, loss 31.54
**** Finished Training ****
Change in stitch weights: 1.0824449062347412
Largest abs weight change: 0.10259586572647095
Number of weights changing > 0.1 of that: 1686
Number of weight / bias in stitch layer is 64
Change in stitch bias: 0.02606078051030636
Largest abs bias change: 0.0057085007429122925
Number of bias changing > 0.1 of that: 58
Test the trained 

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:39<00:00,  1.77s/it]


Evaluate ranks and output to ./results_4_epochs_rank/bw5unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 5 from bw to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 5 of send_model is: torch.Size([1, 128, 4, 4])
Epoch 0, loss 210.63
Epoch 1, loss 50.79
Epoch 2, loss 39.37
Epoch 3, loss 34.15
Epoch 4, loss 30.64
Epoch 5, loss 28.09
Epoch 6, loss 25.77
Epoch 7, loss 24.05
Epoch 8, loss 22.87
Epoch 9, loss 21.40
**** Finished Training ****
Change in stitch weights: 1.0879610776901245
Largest abs weight change: 0.053218014538288116
Number of weights changing > 0.1 of that: 7822
Number of weight / bias in stitch layer is 128
Change in stitch bias: 0.026817502453923225
Largest abs bias change: 0.0040097832679748535
Number of bias changing > 0.1 of that: 117
Test the tra

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:37<00:00,  1.71s/it]


Evaluate ranks and output to ./results_4_epochs_rank/bw6unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 6 from bw to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 6 of send_model is: torch.Size([1, 256, 2, 2])
Epoch 0, loss 117.40
Epoch 1, loss 29.11
Epoch 2, loss 23.57
Epoch 3, loss 20.42
Epoch 4, loss 18.87
Epoch 5, loss 17.66
Epoch 6, loss 16.75
Epoch 7, loss 15.60
Epoch 8, loss 15.30
Epoch 9, loss 14.80
**** Finished Training ****
Change in stitch weights: 0.8934941291809082
Largest abs weight change: 0.04855835810303688
Number of weights changing > 0.1 of that: 8180
Number of weight / bias in stitch layer is 256
Change in stitch bias: 0.02542518638074398
Largest abs bias change: 0.002824343740940094
Number of bias changing > 0.1 of that: 229
Test the traine

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:33<00:00,  1.51s/it]


Evaluate ranks and output to ./results_4_epochs_rank/bw7unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 7 from bw to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 7 of send_model is: torch.Size([1, 512, 1, 1])
Epoch 0, loss 95.20
Epoch 1, loss 35.04
Epoch 2, loss 30.94
Epoch 3, loss 28.57
Epoch 4, loss 26.78
Epoch 5, loss 25.47
Epoch 6, loss 24.36
Epoch 7, loss 23.92
Epoch 8, loss 22.65
Epoch 9, loss 22.25
**** Finished Training ****
Change in stitch weights: 0.9160462617874146
Largest abs weight change: 0.02072807028889656
Number of weights changing > 0.1 of that: 53430
Number of weight / bias in stitch layer is 512
Change in stitch bias: 0.026117218658328056
Largest abs bias change: 0.0020145252346992493
Number of bias changing > 0.1 of that: 461
Test the trai

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 5, loss 26.20
Epoch 6, loss 23.71
Epoch 7, loss 22.03
Epoch 8, loss 20.85
Epoch 9, loss 19.31
**** Finished Training ****
Change in stitch weights: 1.0073761940002441
Largest abs weight change: 0.13200369477272034
Number of weights changing > 0.1 of that: 1147
Number of weight / bias in stitch layer is 64
Change in stitch bias: 0.028249530121684074
Largest abs bias change: 0.005637906491756439
Number of bias changing > 0.1 of that: 59
Test the trained stitch against biased data
Test Accuracy: 98.74 %
Confusion Matrix
tensor([[ 976,    0,    1,    0,    0,    0,    0,    1,    2,    0],
        [   0, 1129,    0,    3,    0,    0,    2,    0,    1,    0],
        [   1,    0, 1025,    1,    2,    0,    0,    1,    2,    0],
        [   0,    0,    0, 1005,    0,    2,    0,    0,    2,    1],
        [   0,    1,    0,    0,  980,    0,    0,    0,    1,    0],
        [   1,    0,    1,    7,    0,  881,    0,    0,    2,    0],
        [   9,    0,    4,    1,    3,    5,  935, 

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:39<00:00,  1.78s/it]


Evaluate ranks and output to ./results_4_epochs_rank/unbias5unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 5 from unbias to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 5 of send_model is: torch.Size([1, 128, 4, 4])
Epoch 0, loss 192.52
Epoch 1, loss 47.86
Epoch 2, loss 36.55
Epoch 3, loss 30.12
Epoch 4, loss 27.30
Epoch 5, loss 25.80
Epoch 6, loss 24.26
Epoch 7, loss 22.75
Epoch 8, loss 21.29
Epoch 9, loss 21.05
**** Finished Training ****
Change in stitch weights: 1.0372215509414673
Largest abs weight change: 0.06327448040246964
Number of weights changing > 0.1 of that: 6012
Number of weight / bias in stitch layer is 128
Change in stitch bias: 0.025303145870566368
Largest abs bias change: 0.0040357038378715515
Number of bias changing > 0.1 of that: 114
Test 

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:38<00:00,  1.76s/it]


Evaluate ranks and output to ./results_4_epochs_rank/unbias6unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 6 from unbias to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 6 of send_model is: torch.Size([1, 256, 2, 2])
Epoch 0, loss 128.26
Epoch 1, loss 36.40
Epoch 2, loss 30.59
Epoch 3, loss 27.18
Epoch 4, loss 25.39
Epoch 5, loss 24.16
Epoch 6, loss 22.30
Epoch 7, loss 22.06
Epoch 8, loss 20.91
Epoch 9, loss 20.38
**** Finished Training ****
Change in stitch weights: 0.9123812913894653
Largest abs weight change: 0.032476507127285004
Number of weights changing > 0.1 of that: 19739
Number of weight / bias in stitch layer is 256
Change in stitch bias: 0.0259136613458395
Largest abs bias change: 0.00284421443939209
Number of bias changing > 0.1 of that: 221
Test th

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:35<00:00,  1.59s/it]


Evaluate ranks and output to ./results_4_epochs_rank/unbias7unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 7 from unbias to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 7 of send_model is: torch.Size([1, 512, 1, 1])
Epoch 0, loss 86.15
Epoch 1, loss 26.93
Epoch 2, loss 24.78
Epoch 3, loss 23.31
Epoch 4, loss 22.66
Epoch 5, loss 21.61
Epoch 6, loss 20.71
Epoch 7, loss 20.88
Epoch 8, loss 20.48
Epoch 9, loss 20.26
**** Finished Training ****
Change in stitch weights: 0.8723613619804382
Largest abs weight change: 0.011980276554822922
Number of weights changing > 0.1 of that: 127580
Number of weight / bias in stitch layer is 512
Change in stitch bias: 0.025867940858006477
Largest abs bias change: 0.0020201317965984344
Number of bias changing > 0.1 of that: 471
Tes

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:37<00:00,  1.70s/it]


Evaluate ranks and output to ./results_4_epochs_rank/unbias8unbias-bias-61_2025-01-29_20-01-26_SEED61_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 8 from unbias to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 8 of send_model is: torch.Size([1, 512, 1, 1])
Epoch 0, loss 89.82
Epoch 1, loss 27.10
Epoch 2, loss 24.50
Epoch 3, loss 23.34
Epoch 4, loss 23.31
Epoch 5, loss 21.94
Epoch 6, loss 21.40
Epoch 7, loss 21.06
Epoch 8, loss 20.42
Epoch 9, loss 19.96
**** Finished Training ****
Change in stitch weights: 0.8706745505332947
Largest abs weight change: 0.011597005650401115
Number of weights changing > 0.1 of that: 131605
Number of weight / bias in stitch layer is 512
Change in stitch bias: 0.025922397151589394
Largest abs bias change: 0.0020193085074424744
Number of bias changing > 0.1 of that: 466
Tes

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|███████████████████████████████████████████████████████████████████████████████████| 22/22 [00:37<00:00,  1.70s/it]


In [9]:
logtofile(f"{stitching_accuracies=}")
logtofile(f"{stitching_penalties=}")

stitching_accuracies={'mix': {'unbias': {3: 99.36, 4: 99.71, 5: 99.96, 6: 100.0, 7: 100.0, 8: 100.0}}, 'bw': {'unbias': {3: 97.79, 4: 97.99, 5: 98.66, 6: 99.06, 7: 98.63, 8: 98.65}}, 'bgonly': {'unbias': {3: 99.87, 4: 99.99, 5: 100.0, 6: 100.0, 7: 100.0, 8: 100.0}}, 'bg': {'unbias': {3: 99.88, 4: 99.94, 5: 100.0, 6: 100.0, 7: 100.0, 8: 100.0}}, 'bias': {'unbias': {3: 99.82, 4: 99.84, 5: 99.99, 6: 100.0, 7: 100.0, 8: 100.0}}, 'unbias': {'unbias': {3: 99.38, 4: 98.74, 5: 98.45, 6: 98.65, 7: 98.42, 8: 98.48}}}
stitching_penalties={'mix': {'unbias': {3: -1.6700000000000017, 4: -2.019999999999996, 5: -2.269999999999996, 6: -2.3100000000000023, 7: -2.3100000000000023, 8: -2.3100000000000023}}, 'bw': {'unbias': {3: -0.10000000000000853, 4: -0.29999999999999716, 5: -0.9699999999999989, 6: -1.3700000000000045, 7: -0.9399999999999977, 8: -0.960000000000008}}, 'bgonly': {'unbias': {3: -2.180000000000007, 4: -2.299999999999997, 5: -2.3100000000000023, 6: -2.3100000000000023, 7: -2.3100000000000023

In [10]:
for r_key in stitching_accuracies:
    logtofile(f"synth-{r_key}")
    logtofile(original_accuracy[r_key])
    logtofile("Stitch Accuracy")
    for layer in stitching_accuracies[r_key]:
        logtofile(stitching_accuracies[r_key][layer])
    logtofile("--------------------------")

synth-mix
100.0
Stitch Accuracy
{3: 99.36, 4: 99.71, 5: 99.96, 6: 100.0, 7: 100.0, 8: 100.0}
--------------------------
synth-bw
62.73
Stitch Accuracy
{3: 97.79, 4: 97.99, 5: 98.66, 6: 99.06, 7: 98.63, 8: 98.65}
--------------------------
synth-bgonly
90.71
Stitch Accuracy
{3: 99.87, 4: 99.99, 5: 100.0, 6: 100.0, 7: 100.0, 8: 100.0}
--------------------------
synth-bg
100.0
Stitch Accuracy
{3: 99.88, 4: 99.94, 5: 100.0, 6: 100.0, 7: 100.0, 8: 100.0}
--------------------------
synth-bias
99.95
Stitch Accuracy
{3: 99.82, 4: 99.84, 5: 99.99, 6: 100.0, 7: 100.0, 8: 100.0}
--------------------------
synth-unbias
97.69
Stitch Accuracy
{3: 99.38, 4: 98.74, 5: 98.45, 6: 98.65, 7: 98.42, 8: 98.48}
--------------------------
