# Exp2m: Bias data, unbias trained receiver, randomly initialised sender network
In this experiment we will:
- use only the unbias-trained network as the receiver
- use only the bias data to train the stitch
- use the randomly initialised network as the sender, stitched in at every available stitch level
## Rank
Also perform rank analysis on the stitched networks based on exp1e
## 4 Epochs
Train the original models for only 4 epochs (stitch training stays at 10 epochs) so that the initial models are weaker

In [1]:
# Packages
%matplotlib inline

import argparse
import gc
import os.path

import pandas as pd
from torch.linalg import LinAlgError

import matplotlib.pyplot as plt
import torchvision
import torch
from torch import optim

from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import datetime

import random
import numpy as np

import sys
import os
# add the path to find colour_mnist
sys.path.append(os.path.abspath('../ReferenceCode'))
import colour_mnist
from stitch_utils import train_model, RcvResNet18, StitchedResNet18, get_layer_output_shape
from stitch_utils import generate_activations, SyntheticDataset
import stitch_utils

# add the path to find the rank analysis code
# https://github.com/DHLSmith/jons-tunnel-effect/tree/NeurIPSPaper
sys.path.append(os.path.abspath('../../jons-tunnel-effect/'))
from utils.modelfitting import evaluate_model, set_seed
from extract_weight_rank import install_hooks, perform_analysis

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

# To track memory usage: psutil handle referenced by the (currently
# commented-out) `process.memory_info().rss` logging calls in later cells.
import psutil
process = psutil.Process()
            

def logtofile(log_text, verbose=True):
    """Append ``log_text`` to the run log and optionally echo it to stdout.

    The log path is read from the module-level ``save_log_as`` at call time,
    so that name must be assigned before the first call.

    Args:
        log_text: Object to record; it is rendered with ``str()``, i.e. the
            same way ``print`` would render it.
        verbose: When true, also print ``log_text`` to standard output.
    """
    if verbose:
        print(log_text)
    # str() (not an f-string) so objects with a custom __format__ are written
    # exactly as print() would show them.
    rendered = str(log_text) + "\n"
    with open(save_log_as, "a") as log_file:
        log_file.write(rendered)

In [2]:
# Experiment parameters

# Reproducibility: one seed shared by every random number generator in use,
# with cuDNN autotuning disabled and deterministic kernels requested.
seed = 10
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
for _seed_fn in (torch.manual_seed, random.seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(seed)

# Folder (relative to the notebook) that receives logs, weights and stitches
results_root = "results_2m"

# Sender: randomly initialised (never trained) network.
# When gen_randinit_model is False the weights below are loaded instead.
gen_randinit_model = True
randinit_model_to_load = "./results_2m/2025-03-26_12-06-31_SEED57_EPOCHS4_BGN0.1_exp2e_ResNet18_randinit.weights"

# Receiver: UNBIASED is digits with randomly selected colour background.
# Targets are digit values.
train_unbiased_colour_mnist_model = False  # when False, automatically loads a trained model
unbiased_colour_mnist_model_to_load = (
    "./results_4_epochs/"
    "2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist.weights"
)

# Training schedule and data settings
original_train_epochs, stitch_train_epochs = 4, 10
bg_noise = 0.1
batch_size = 128

In [3]:
# Generate filenames and log the setup details
formatted_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# NOTE(review): the prefix still carries the "exp2e" tag although this notebook
# is Exp2m; left unchanged so file names stay comparable with earlier runs - confirm.
filename_prefix = f"./{results_root}/{formatted_time}_SEED{seed}_EPOCHS{original_train_epochs}_BGN{bg_noise}_exp2e_ResNet18"
#save_mix_mnist_model_as = f"{filename_prefix}_mix_mnist.weights"
#save_bw_mnist_model_as = f"{filename_prefix}_bw_mnist.weights"
#save_bg_only_colour_mnist_model_as = f"{filename_prefix}_bg_only_colour_mnist.weights"
#save_bg_unbiased_colour_mnist_model_as = f"{filename_prefix}_bg_unbiased_colour_mnist.weights"
#save_biased_colour_mnist_model_as = f"{filename_prefix}_biased_colour_mnist.weights"
save_randinit_model_as = f"{filename_prefix}_randinit.weights"
save_unbiased_colour_mnist_model_as = f"{filename_prefix}_unbiased_colour_mnist.weights"
save_log_as = f"{filename_prefix}_log.txt"

# (channels, height, width) of one colour-MNIST image
colour_mnist_shape = (3,28,28)

# logtofile() opens save_log_as in append mode, which creates the file but not
# a missing parent folder - create the results folder before the first call.
os.makedirs(f"./{results_root}", exist_ok=True)

logtofile(f"Executed at {formatted_time}")
logtofile(f"logging to {save_log_as}")
logtofile(f"{seed=}")
logtofile(f"{bg_noise=}")

# For each model, record either where it will be saved (fresh model) or
# which existing weights file will be loaded.
logtofile(f"{gen_randinit_model=}")
if gen_randinit_model:
    logtofile(f"{save_randinit_model_as=}")
else:
    logtofile(f"{randinit_model_to_load=}")


logtofile(f"{train_unbiased_colour_mnist_model=}")
if train_unbiased_colour_mnist_model:
    logtofile(f"{save_unbiased_colour_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{unbiased_colour_mnist_model_to_load=}")

logtofile(f"{stitch_train_epochs=}")
logtofile("================================================")

Executed at 2025-03-26_14-24-26
logging to ./results_2m/2025-03-26_14-24-26_SEED10_EPOCHS4_BGN0.1_exp2e_ResNet18_log.txt
seed=10
bg_noise=0.1
gen_randinit_model=True
save_randinit_model_as='./results_2m/2025-03-26_14-24-26_SEED10_EPOCHS4_BGN0.1_exp2e_ResNet18_randinit.weights'
train_unbiased_colour_mnist_model=False
unbiased_colour_mnist_model_to_load='./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist.weights'
stitch_train_epochs=10


MNIST and CIFAR-10 both have 10 classes. MNIST has 60,000 training samples and 10,000 test samples; CIFAR-10 has 50,000 training samples and 10,000 test samples.

In [4]:

# Settings shared by every colour-MNIST loader built in this cell
_loader_settings = dict(
    root="./MNIST",
    batch_size=batch_size,
    bg_noise_level=bg_noise,
    standard_getitem=True,
)

# Biased: every digit keeps its correct label and always gets the same colour
# (data_label_correlation=1.0) - expect a network to learn the colours only.
biased_train_dataloader = colour_mnist.get_biased_mnist_dataloader(
    data_label_correlation=1.0, train=True, **_loader_settings)
biased_test_dataloader = colour_mnist.get_biased_mnist_dataloader(
    data_label_correlation=1.0, train=False, **_loader_settings)

# Unbiased: every digit keeps its correct label but gets a random colour
# (data_label_correlation=0.1) - expect a network to disregard colours?
unbiased_train_dataloader = colour_mnist.get_biased_mnist_dataloader(
    data_label_correlation=0.1, train=True, **_loader_settings)
unbiased_test_dataloader = colour_mnist.get_biased_mnist_dataloader(
    data_label_correlation=0.1, train=False, **_loader_settings)

## Set up the ResNet18 models and train or load them on versions of MNIST

In [5]:

device = 'cuda:0'

# One entry per network taking part in the experiment. Each entry holds the
# model, whether it is freshly generated/trained ("train"), its data loaders,
# and the weight files to save to / load from.
# The randinit model is constructed before the unbias model so the global
# torch RNG is consumed in the same order as before.
process_structure = {
    # "randinit": untrained network, only ever used as the sender
    "randinit": {
        "model": torchvision.models.resnet18(num_classes=10).to(device),
        "train": gen_randinit_model,
        "train_loader": None,
        "test_loader": None,
        "saveas": save_randinit_model_as,
        "loadfrom": randinit_model_to_load,
    },
    # "unbias": network trained on unbiased colour MNIST, used as the receiver
    "unbias": {
        "model": torchvision.models.resnet18(num_classes=10).to(device),
        "train": train_unbiased_colour_mnist_model,
        "train_loader": unbiased_train_dataloader,
        "test_loader": unbiased_test_dataloader,
        "saveas": save_unbiased_colour_mnist_model_as,
        "loadfrom": unbiased_colour_mnist_model_to_load,
    },
}

for key, val in process_structure.items():
    print(f"Processing for {key=}")
    net = val["model"]
    is_randinit = key == "randinit"
    if is_randinit and gen_randinit_model:
        # fresh random weights: nothing to train, just persist them
        logtofile(f"model has already been initialised: save it as {val['saveas']}")
        torch.save(net.state_dict(), val["saveas"])
    elif not is_randinit and val["train"]:
        train_model(model=net, train_loader=val["train_loader"],
                    epochs=original_train_epochs, saveas=val["saveas"],
                    description=key, device=device, logtofile=logtofile)
    else:
        # reuse weights produced by an earlier run
        logtofile(f"{val['loadfrom']=}")
        net.load_state_dict(torch.load(val["loadfrom"], map_location=torch.device(device)))
    net.eval()


Processing for key='randinit'
model has already been initialised: save it as ./results_2m/2025-03-26_14-24-26_SEED10_EPOCHS4_BGN0.1_exp2e_ResNet18_randinit.weights
Processing for key='unbias'
val['loadfrom']='./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist.weights'


## Measure the Accuracy, Record the Confusion Matrix


In [6]:
logtofile("Entering Confusion")
# logtofile(process.memory_info().rss)  # in bytes 

# Accuracy (in percent) of each un-stitched model on the biased test set,
# keyed like process_structure; later used as the stitching-penalty baseline.
original_accuracy = dict()
num_classes = 10  # assuming 10 classes
TDL = biased_test_dataloader  # In test 2M - ALWAYS use biased dataset to measure/train stitch
for key, val in process_structure.items():
    logtofile(f"Accuracy Calculation for ResNet18 with {key=}")
    model = val["model"]
    model.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS

    # Compute the model accuracy on the test set
    correct = 0
    total = 0

    # rows represent actual class, columns are predicted
    confusion_matrix = torch.zeros(num_classes, num_classes, dtype=torch.int)

    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for inputs, labels in TDL:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predictions = torch.argmax(model(inputs), 1)

            correct += (predictions == labels).sum().item()
            total += len(labels)

            # Count every (actual, predicted) pair of the batch in one go:
            # flat cell index = actual * num_classes + predicted.
            flat_cells = (labels.long() * num_classes + predictions).reshape(-1).cpu()
            batch_counts = torch.bincount(flat_cells, minlength=num_classes * num_classes)
            confusion_matrix += batch_counts.reshape(num_classes, num_classes).to(torch.int)

    logtofile("Test the Trained Resnet18 against BIASED TEST DATALOADER")
    acc = ((100.0 * correct) / total)
    logtofile('Test Accuracy: %2.2f %%' % acc)
    original_accuracy[key] = acc
    logtofile('Confusion Matrix')
    logtofile(confusion_matrix)
    logtofile(confusion_matrix.sum())
    # logtofile(process.memory_info().rss)  # in bytes 


logtofile(f"{original_accuracy=}")

Entering Confusion
Accuracy Calculation for ResNet18 with key='randinit'
Test the Trained Resnet18 against BIASED TEST DATALOADER
Test Accuracy: 9.82 %
Confusion Matrix
tensor([[   0,    0,    0,    0,  980,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0, 1135,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0, 1032,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0, 1010,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,  982,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,  892,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,  958,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0, 1028,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,  974,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0, 1009,    0,    0,    0,    0,    0]],
       dtype=torch.int32)
tensor(10000)
Accuracy Calculation for ResNet18 with key='unbias'
Test the Trained Resnet18 against BIAS

## Measure Rank with __biased__ dataloader (test) before cutting and stitching

In [7]:
logtofile("Entering whole model check")
# logtofile(process.memory_info().rss)  # in bytes 

# The rank CSV is written only after the (slow) analysis has finished, so make
# sure its folder exists up front instead of failing at the very end.
os.makedirs(f"./{results_root}_rank", exist_ok=True)

# For the Whole Model - but we will pass it through the RcvResNet18 function to get matching feature names
for key, val in process_structure.items():

    TDL = biased_test_dataloader  # ALWAYS use biased dataloader for this test
    logtofile(f"about to measure rank for {key}")
    # Weights file of this model: the one just saved, or the one that was loaded
    filename = val["saveas"] if val["train"] else val["loadfrom"]
    if not os.path.exists(filename):
        raise FileNotFoundError(f"weights file not found: {filename}")
    # The model is rebuilt even when the result already exists, so that the
    # global torch RNG is consumed identically on every run.
    mdl = torchvision.models.resnet18(num_classes=10) # Untrained model
    state = torch.load(filename, map_location=torch.device("cpu"))
    mdl.load_state_dict(state, assign=True)
    mdl=mdl.to(device)
    # NOTE(review): mdl is never explicitly switched to eval mode here;
    # presumably evaluate_model takes care of that - confirm.
    mdl = RcvResNet18(mdl, -1, colour_mnist_shape, device).to(device)

    out_filename = filename.split('/')[-1].replace('.weights', '-test.csv')

    outpath = f"./{results_root}_rank/{key}-bias-{seed}_{out_filename}"  # denote output name as <model_training_type>-dataset-<name>

    if os.path.exists(outpath):
        logtofile(f"Already evaluated for {outpath}")
        continue
    logtofile(f"Measure Rank for {key=}")
    print(f"output to {outpath}")

    # as only one network used, record its filename as both send and receive files
    params = {
        "model": key,
        "dataset": "bias",
        "seed": seed,
        "send_file": filename,
        "rcv_file": filename,
    }

    with torch.no_grad():
        layers, features, handles = install_hooks(mdl)
        try:
            metrics = evaluate_model(mdl, TDL, 'acc', verbose=2)
            params.update(metrics)
            classes = None
            df = perform_analysis(features, classes, layers, params, n=-1)
            df.to_csv(outpath)
        finally:
            # always detach the hooks, even if evaluation or analysis fails
            for h in handles:
                h.remove()
    del mdl, layers, features, metrics, params, df, handles
    gc.collect()
    # logtofile(process.memory_info().rss)  # in bytes 



Entering whole model check
about to measure rank for randinit
Measure Rank for key='randinit'
output to ./results_2m_rank/randinit-bias-10_2025-03-26_14-24-26_SEED10_EPOCHS4_BGN0.1_exp2e_ResNet18_randinit-test.csv


0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 21/21 [00:34<00:00,  1.63s/it]


about to measure rank for unbias
Already evaluated for ./results_2m_rank/unbias-bias-10_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist-test.csv


# Stitch at a given layer


## Train the stitch layer and check rank

In [8]:
def _snapshot_state(model):
    """Return an independent CPU copy of ``model.state_dict()``.

    ``state_dict()`` hands back tensors that alias the live parameters and
    buffers, and ``dict.copy()`` is shallow, so every tensor is cloned here to
    make a before/after comparison meaningful.
    """
    return {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}


logtofile("Entering Stitch/Rank")
# logtofile(process.memory_info().rss)  # in bytes 

logtofile(f"{device=}")
# Both result dicts are nested as [send_key][rcv_key][layer_to_cut_after]
stitching_accuracies = dict()
stitching_penalties = dict()
# NOTE this is only valid as all models are the same architecture
num_layers_in_model = len(list(process_structure["randinit"]["model"].children()))
num_classes = 10  # assuming 10 classes

# The rank CSV is written at the very end of each (slow) iteration, so make
# sure its folder exists up front instead of failing after training.
os.makedirs(f"./{results_root}_rank", exist_ok=True)

for send_key, send_val in process_structure.items():
    if send_key != "randinit":
        logtofile("NOTE: Only running stitch with randinit send model")
        continue
    stitching_accuracies[send_key] = dict()
    stitching_penalties[send_key] = dict()
    send_file = send_val["saveas"] if send_val["train"] else send_val["loadfrom"]

    for rcv_key, rcv_val in process_structure.items():
        if rcv_key != "unbias":
            logtofile("NOTE: Only running stitch with unbias receive model")
            continue

        stitching_accuracies[send_key][rcv_key] = dict()
        stitching_penalties[send_key][rcv_key] = dict()

        # for consistency, use the rcv network for the filename stem.
        filename = rcv_val["saveas"] if rcv_val["train"] else rcv_val["loadfrom"]
        weights_basename = filename.split('/')[-1]
        rank_filename = weights_basename.replace('.weights', '-test.csv')

        for layer_to_cut_after in range(3,num_layers_in_model - 1):
            # denote output name as <model_training_type>-dataset-<name>
            # where <model_training_type> is [sender_model or X][layer_to_cut_after][Receiver_model]
            model_training_type = f"{send_key}{layer_to_cut_after}{rcv_key}"
            dataset_type = "bias"  # ALWAYS use bias dataset in this test
            outpath = f"./{results_root}_rank/{model_training_type}-{dataset_type}-{seed}_{rank_filename}"

            if os.path.exists(outpath):
                logtofile(f"Already evaluated for {outpath}")
                continue
            ####################################################################################
            logtofile(f"Evaluate ranks and output to {outpath}")
            # logtofile(process.memory_info().rss)  # in bytes 

            logtofile(f"Train the stitch to a model stitched after layer {layer_to_cut_after} from {send_key} to {rcv_key}")
            logtofile(f"Use the biased data loader (train and test) regardless of what {rcv_key} was trained on")

            # Real (deep) snapshots, so they can be compared with the final
            # states in the debug cells below.
            initial_send_state = _snapshot_state(send_val["model"])
            initial_rcv_state = _snapshot_state(rcv_val["model"])
            # train a stitch on the biased colour dataset to compare receiver network performance with stitched
            model_stitched = StitchedResNet18(send_model=send_val["model"],
                                              after_layer_index=layer_to_cut_after,
                                              rcv_model=rcv_val["model"],
                                              input_image_shape=colour_mnist_shape, device=device  ).to(device)

            #############################################################
            # store the initial stitch state
            # detach() so the copies are plain tensors without autograd history
            stitch_tag = f"{model_training_type}-{dataset_type}-{seed}_{weights_basename}"
            initial_stitch_weight = model_stitched.stitch.s_conv1.weight.detach().clone()
            initial_stitch_bias   = model_stitched.stitch.s_conv1.bias.detach().clone()
            stitch_initial_weight_outpath    = f"./{results_root}/STITCH_initial_weight_{stitch_tag}"
            stitch_initial_bias_outpath      = f"./{results_root}/STITCH_initial_bias_{stitch_tag}"
            torch.save(initial_stitch_weight, stitch_initial_weight_outpath)
            torch.save(initial_stitch_bias, stitch_initial_bias_outpath)
            ############################################################

            # define the loss function and the optimiser
            loss_function = nn.CrossEntropyLoss()
            # Hernandez said : momentum 0.9, batch size 256, weight decay 0.01, learning rate 0.01, and a post-warmup cosine learning rate scheduler.
            # The optimiser is handed every parameter, but only those that
            # receive a gradient (the stitch, see below) are ever updated.
            optimiser = optim.SGD(model_stitched.parameters(), lr=1e-4, momentum=0.9, weight_decay=0.01)

            # Put top model into train mode so that bn and dropout perform in training mode
            # NOTE(review): if StitchedResNet18 holds the sender/receiver as
            # sub-modules (not copies), train() also puts their BatchNorm
            # layers in training mode, so their running statistics may drift
            # across stitch levels even though the weights are frozen - confirm.
            model_stitched.train()
            # Freeze the whole model
            model_stitched.requires_grad_(False)
            # Un-Freeze the stitch layer
            for param in model_stitched.stitch.parameters():
                param.requires_grad_(True)
            # the epoch loop: only the stitch layer is being trained
            for epoch in range(stitch_train_epochs):
                running_loss = 0.0
                for data in biased_train_dataloader:
                    # data is (images, labels) tuple
                    # get the inputs and put them on the GPU
                    inputs, labels = data
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimiser.zero_grad()

                    # forward + loss + backward + optimise (update weights)
                    outputs = model_stitched(inputs)
                    loss = loss_function(outputs, labels)
                    loss.backward()
                    optimiser.step()

                    # keep track of the loss this epoch
                    running_loss += loss.item()
                logtofile("Epoch %d, loss %4.2f" % (epoch, running_loss))
                # logtofile(process.memory_info().rss)  # in bytes 

            logtofile('**** Finished Training ****')

            model_stitched.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS
            final_send_state = _snapshot_state(send_val["model"])
            final_rcv_state = _snapshot_state(rcv_val["model"])

            ############################################################
            # store the trained stitch
            trained_stitch_weight = model_stitched.stitch.s_conv1.weight.detach().clone()
            trained_stitch_bias   = model_stitched.stitch.s_conv1.bias.detach().clone()
            stitch_trained_weight_outpath    = f"./{results_root}/STITCH_trained_weight_{stitch_tag}"
            stitch_trained_bias_outpath      = f"./{results_root}/STITCH_trained_bias_{stitch_tag}"
            torch.save(trained_stitch_weight, stitch_trained_weight_outpath)
            torch.save(trained_stitch_bias, stitch_trained_bias_outpath)

            # How far did training move the stitch? Norm of the change, the
            # largest single change, and how many entries moved by more than
            # a tenth of that largest change.
            stitch_weight_diff = trained_stitch_weight - initial_stitch_weight
            stitch_weight_delta = torch.linalg.norm(stitch_weight_diff).item()
            logtofile(f"Change in stitch weights: {stitch_weight_delta}")
            maxabsweight =  torch.max(stitch_weight_diff.abs()).item()
            logtofile(f"Largest abs weight change: {maxabsweight}")
            stitch_weight_number = (stitch_weight_diff.abs() > 0.1*maxabsweight).sum().item()
            logtofile(f"Number of weights changing > 0.1 of that: {stitch_weight_number}")


            print(f"Number of weight / bias in stitch layer is {len(initial_stitch_weight)}")
            stitch_bias_diff = trained_stitch_bias - initial_stitch_bias
            stitch_bias_delta = torch.linalg.norm(stitch_bias_diff).item()
            logtofile(f"Change in stitch bias: {stitch_bias_delta}")
            maxabsbias =  torch.max(stitch_bias_diff.abs()).item()
            logtofile(f"Largest abs bias change: {maxabsbias}")
            stitch_bias_number = (stitch_bias_diff.abs() > 0.1*maxabsbias).sum().item()
            logtofile(f"Number of bias changing > 0.1 of that: {stitch_bias_number}")
            ##############################################################


            # Compute the model accuracy on the test set
            correct = 0
            total = 0

            # rows represent actual class, columns are predicted
            confusion_matrix = torch.zeros(num_classes, num_classes, dtype=torch.int)

            # Inference only: the stitch parameters still require grad, so
            # without no_grad every batch would build an autograd graph.
            with torch.no_grad():
                for data in biased_test_dataloader:  # Only use biased test data
                    inputs, labels = data
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    predictions = torch.argmax(model_stitched(inputs),1)
                    correct += (predictions == labels).sum().item()
                    total += len(labels)

                    # Count every (actual, predicted) pair of the batch in one
                    # go: flat cell index = actual * num_classes + predicted.
                    flat_cells = (labels.long() * num_classes + predictions).reshape(-1).cpu()
                    batch_counts = torch.bincount(flat_cells, minlength=num_classes * num_classes)
                    confusion_matrix += batch_counts.reshape(num_classes, num_classes).to(torch.int)
            logtofile("Test the trained stitch against biased data")
            acc =  ((100.0 * correct) / total)
            logtofile('Test Accuracy: %2.2f %%' % acc)
            logtofile('Confusion Matrix')
            logtofile(confusion_matrix)
            logtofile("===================================================================")
            # logtofile(process.memory_info().rss)  # in bytes 

            # Stitching penalty should be negative if there is an improvement, and is relative to the original receiver network
            stitching_accuracies[send_key][rcv_key][layer_to_cut_after] = acc
            stitching_penalties[send_key][rcv_key][layer_to_cut_after] = original_accuracy[rcv_key] - acc

            TDL = biased_test_dataloader
            # Metadata stored alongside the rank measurements in the CSV
            params = {}
            params["model"] = model_training_type # a mnemonic
            params["dataset"] = dataset_type
            params["seed"] = seed
            params["send_file"] = send_file
            params["rcv_file"] = filename
            params["stitch_weight_delta"] = stitch_weight_delta
            params["stitch_bias_delta"] = stitch_bias_delta
            params["stitch_weight_number"] = stitch_weight_number
            params["stitch_bias_number"] = stitch_bias_number
            # logtofile(process.memory_info().rss)  # in bytes 
            with torch.no_grad():
                layers, features, handles = install_hooks(model_stitched)
                try:
                    metrics = evaluate_model(model_stitched, TDL, 'acc', verbose=2)
                    params.update(metrics)
                    classes = None
                    df = perform_analysis(features, classes, layers, params, n=-1)
                    df.to_csv(outpath)
                finally:
                    # always detach the hooks, even if evaluation or analysis fails
                    for h in handles:
                        h.remove()
            del model_stitched, layers, features, metrics, params, df, handles
            gc.collect()
            # logtofile(process.memory_info().rss)  # in bytes 

            

Entering Stitch/Rank
device='cuda:0'
NOTE: Only running stitch with unbias receive model
Evaluate ranks and output to ./results_2m_rank/randinit3unbias-bias-10_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 3 from randinit to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 3 of send_model is: torch.Size([1, 64, 7, 7])
Epoch 0, loss 341.21
Epoch 1, loss 62.31
Epoch 2, loss 45.08
Epoch 3, loss 34.28
Epoch 4, loss 29.02
Epoch 5, loss 24.80
Epoch 6, loss 21.88
Epoch 7, loss 19.52
Epoch 8, loss 17.44
Epoch 9, loss 16.26
**** Finished Training ****
Change in stitch weights: 1.2918787002563477
Largest abs weight change: 0.11091408878564835
Number of weights changing > 0.1 of that: 2036
Number of weight / bias in stitch layer is 64
Change in stitch bias: 0.02698933705687523
Largest abs

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:39<00:00,  1.80s/it]


Evaluate ranks and output to ./results_2m_rank/randinit4unbias-bias-10_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 4 from randinit to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 4 of send_model is: torch.Size([1, 64, 7, 7])
Epoch 0, loss 386.39
Epoch 1, loss 66.42
Epoch 2, loss 45.29
Epoch 3, loss 35.00
Epoch 4, loss 28.89
Epoch 5, loss 24.70
Epoch 6, loss 21.58
Epoch 7, loss 19.56
Epoch 8, loss 17.18
Epoch 9, loss 15.85
**** Finished Training ****
Change in stitch weights: 1.436583399772644
Largest abs weight change: 0.14218488335609436
Number of weights changing > 0.1 of that: 1935
Number of weight / bias in stitch layer is 64
Change in stitch bias: 0.024001406505703926
Largest abs bias change: 0.005690738558769226
Number of bias changing > 0.1 of that: 60
Test the tra

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:37<00:00,  1.72s/it]


Evaluate ranks and output to ./results_2m_rank/randinit5unbias-bias-10_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 5 from randinit to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 5 of send_model is: torch.Size([1, 128, 4, 4])
Epoch 0, loss 527.25
Epoch 1, loss 58.13
Epoch 2, loss 36.97
Epoch 3, loss 28.82
Epoch 4, loss 23.50
Epoch 5, loss 19.36
Epoch 6, loss 17.48
Epoch 7, loss 15.32
Epoch 8, loss 13.67
Epoch 9, loss 12.83
**** Finished Training ****
Change in stitch weights: 1.5029387474060059
Largest abs weight change: 0.06785206496715546
Number of weights changing > 0.1 of that: 8920
Number of weight / bias in stitch layer is 128
Change in stitch bias: 0.024144934490323067
Largest abs bias change: 0.003987401723861694
Number of bias changing > 0.1 of that: 113
Test the

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:36<00:00,  1.67s/it]


Evaluate ranks and output to ./results_2m_rank/randinit6unbias-bias-10_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 6 from randinit to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 6 of send_model is: torch.Size([1, 256, 2, 2])
Epoch 0, loss 296.21
Epoch 1, loss 55.22
Epoch 2, loss 37.36
Epoch 3, loss 28.99
Epoch 4, loss 23.56
Epoch 5, loss 20.43
Epoch 6, loss 18.16
Epoch 7, loss 16.70
Epoch 8, loss 14.95
Epoch 9, loss 14.03
**** Finished Training ****
Change in stitch weights: 1.465800166130066
Largest abs weight change: 0.04203174635767937
Number of weights changing > 0.1 of that: 26901
Number of weight / bias in stitch layer is 256
Change in stitch bias: 0.025760801509022713
Largest abs bias change: 0.002852499485015869
Number of bias changing > 0.1 of that: 233
Test the

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:38<00:00,  1.76s/it]


Evaluate ranks and output to ./results_2m_rank/randinit7unbias-bias-10_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 7 from randinit to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 7 of send_model is: torch.Size([1, 512, 1, 1])
Epoch 0, loss 262.68
Epoch 1, loss 51.03
Epoch 2, loss 36.29
Epoch 3, loss 29.35
Epoch 4, loss 25.41
Epoch 5, loss 22.24
Epoch 6, loss 20.43
Epoch 7, loss 19.82
Epoch 8, loss 17.96
Epoch 9, loss 16.70
**** Finished Training ****
Change in stitch weights: 1.471322774887085
Largest abs weight change: 0.018993109464645386
Number of weights changing > 0.1 of that: 125123
Number of weight / bias in stitch layer is 512
Change in stitch bias: 0.02649996243417263
Largest abs bias change: 0.0020172595977783203
Number of bias changing > 0.1 of that: 456
Test t

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:47<00:00,  2.15s/it]


Evaluate ranks and output to ./results_2m_rank/randinit8unbias-bias-10_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 8 from randinit to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 8 of send_model is: torch.Size([1, 512, 1, 1])
Epoch 0, loss 279.23
Epoch 1, loss 52.09
Epoch 2, loss 36.84
Epoch 3, loss 30.50
Epoch 4, loss 25.81
Epoch 5, loss 22.40
Epoch 6, loss 20.47
Epoch 7, loss 18.88
Epoch 8, loss 18.39
Epoch 9, loss 16.73
**** Finished Training ****
Change in stitch weights: 1.4870556592941284
Largest abs weight change: 0.023864634335041046
Number of weights changing > 0.1 of that: 98921
Number of weight / bias in stitch layer is 512
Change in stitch bias: 0.02722635120153427
Largest abs bias change: 0.0020109154284000397
Number of bias changing > 0.1 of that: 463
Test t

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:33<00:00,  1.54s/it]


NOTE: Only running stitch with randinit send model


In [9]:
#print(initial_rcv_state)


In [10]:
#print(final_rcv_state)

In [11]:
#rcv_val["model"].state_dict()

In [12]:
#initial_send_state = send_val["model"].state_dict().copy()
#initial_send_state

In [13]:
#final_send_state

In [14]:
# Dump both result dictionaries (nested as [sender][receiver][layer]) to the log
for _summary_line in (f"{stitching_accuracies=}", f"{stitching_penalties=}"):
    logtofile(_summary_line)

stitching_accuracies={'randinit': {'unbias': {3: 98.97, 4: 99.23, 5: 99.54, 6: 99.41, 7: 99.21, 8: 99.28}}}
stitching_penalties={'randinit': {'unbias': {3: -1.1700000000000017, 4: -1.4300000000000068, 5: -1.740000000000009, 6: -1.6099999999999994, 7: -1.4099999999999966, 8: -1.480000000000004}}}


In [15]:
# Readable report: for each sender/receiver pair, the receiver's original
# accuracy followed by the stitched accuracy at every cut layer.
for s_key, per_receiver in stitching_accuracies.items():
    for r_key, per_layer in per_receiver.items():
        logtofile(f"{s_key}-{r_key}")
        logtofile(f"{original_accuracy[r_key]=}")
        logtofile("Stitch Accuracy")
        for layer, layer_accuracy in per_layer.items():
            logtofile(f"L{layer}: {layer_accuracy}")
        logtofile("--------------------------")

randinit-unbias
original_accuracy[r_key]=97.8
Stitch Accuracy
L3: 98.97
L4: 99.23
L5: 99.54
L6: 99.41
L7: 99.21
L8: 99.28
--------------------------
