# Exp 2k: Bias data, unbias trained receiver, bg_unbiased sender network
bg_unbiased means digit and bg_colour are unrelated, but the model is trained and tested on the background.

In this experiment we will:
- use only the unbiased-trained network as the receiver
- use only the biased data to train the stitch
- try the sender network at every stitch level
## Rank
Also perform rank analysis on the stitched networks based on exp1e
## 4 Epochs
Only do 4 epochs of training (keep 10 epochs of stitch training) so that the initial models are weaker

In [1]:
# Packages
%matplotlib inline

import argparse
import gc
import os.path

import pandas as pd
from torch.linalg import LinAlgError

import matplotlib.pyplot as plt
import torchvision
import torch
from torch import optim

from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import datetime

import random
import numpy as np

import sys
import os
# add the path to find colour_mnist
sys.path.append(os.path.abspath('../ReferenceCode'))
import colour_mnist
from stitch_utils import train_model, RcvResNet18, StitchedResNet18, get_layer_output_shape
from stitch_utils import generate_activations, SyntheticDataset
import stitch_utils

# add the path to find the rank analysis code
# https://github.com/DHLSmith/jons-tunnel-effect/tree/NeurIPSPaper
sys.path.append(os.path.abspath('../../jons-tunnel-effect/'))
from utils.modelfitting import evaluate_model, set_seed
from extract_weight_rank import install_hooks, perform_analysis

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

# To track memory usage
import psutil
process = psutil.Process()
            

def logtofile(log_text, verbose=True):
    """Append ``log_text`` to the run log, optionally echoing it to stdout.

    Parameters
    ----------
    log_text : object
        Value to record. It is converted with ``str()``, exactly as ``print`` would.
    verbose : bool, default True
        When true, the value is also printed to the notebook output.

    Notes
    -----
    The destination is the module-level ``save_log_as`` path, which must be
    defined before the first call. The file is opened in append mode per call.
    """
    if verbose:
        print(log_text)
    with open(save_log_as, "a") as log_file:
        # str() rather than an f-string: 0-dim tensors format differently from how they print
        log_file.write(str(log_text) + "\n")

In [2]:
# Experiment configuration

# Seed every random number generator in use so that runs are repeatable
seed = 57
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Request deterministic cuDNN kernels and switch off its auto-tuner
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

results_root = "results_2k"
train_all = False  # False: reuse the pretrained weights listed below

# BG_UNBIASED: digits on a randomly selected colour background; the target is the colour
train_bg_unbiased_colour_mnist_model = train_all  # when False, the weights file below is loaded
bg_unbiased_colour_mnist_model_to_load = "./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_bg_unbiased_colour_mnist.weights"

# UNBIASED: digits on a randomly selected colour background; the target is the digit value
train_unbiased_colour_mnist_model = train_all  # when False, the weights file below is loaded
unbiased_colour_mnist_model_to_load = "./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist.weights"

# Training schedule: epochs for the base networks and for the stitch layer
original_train_epochs = 4
stitch_train_epochs = 10

# Data settings
bg_noise = 0.1
batch_size = 128

In [4]:
# Generate filenames and log the setup details
formatted_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Create the output folders up front. The log file and the saved stitch tensors are
# written under results_root and the rank CSVs under "<results_root>_rank"; without
# these folders the very first logtofile() call fails with FileNotFoundError.
for output_dir in (f"./{results_root}", f"./{results_root}_rank"):
    os.makedirs(output_dir, exist_ok=True)

# NOTE(review): the tag below says "exp2e" although this notebook is Exp 2k and writes
# to results_2k - presumably a leftover from the notebook this was copied from. It is
# left unchanged so that output filenames stay comparable with earlier runs; confirm.
filename_prefix = f"./{results_root}/{formatted_time}_SEED{seed}_EPOCHS{original_train_epochs}_BGN{bg_noise}_exp2e_ResNet18"
save_bg_unbiased_colour_mnist_model_as = f"{filename_prefix}_bg_unbiased_colour_mnist.weights"
save_unbiased_colour_mnist_model_as = f"{filename_prefix}_unbiased_colour_mnist.weights"
save_log_as = f"{filename_prefix}_log.txt"

# Input image shape as (channels, height, width)
colour_mnist_shape = (3,28,28)


logtofile(f"Executed at {formatted_time}")
logtofile(f"logging to {save_log_as}")
logtofile(f"{seed=}")
logtofile(f"{bg_noise=}")

# For each base network, record either where it will be saved (when training)
# or which weights file it will be loaded from.
logtofile(f"{train_bg_unbiased_colour_mnist_model=}")
if train_bg_unbiased_colour_mnist_model:
    logtofile(f"{save_bg_unbiased_colour_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{bg_unbiased_colour_mnist_model_to_load=}")

logtofile(f"{train_unbiased_colour_mnist_model=}")
if train_unbiased_colour_mnist_model:
    logtofile(f"{save_unbiased_colour_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{unbiased_colour_mnist_model_to_load=}")

logtofile(f"{stitch_train_epochs=}")
logtofile("================================================")

Executed at 2025-03-26_10-38-34
logging to ./results_2k/2025-03-26_10-38-34_SEED57_EPOCHS4_BGN0.1_exp2e_ResNet18_log.txt
seed=57
bg_noise=0.1
train_bg_unbiased_colour_mnist_model=False
bg_unbiased_colour_mnist_model_to_load='./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_bg_unbiased_colour_mnist.weights'
train_unbiased_colour_mnist_model=False
unbiased_colour_mnist_model_to_load='./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist.weights'
stitch_train_epochs=10


MNIST and CIFAR-10 both have 10 classes and 10,000 test samples; MNIST has 60,000 training samples, whereas CIFAR-10 has 50,000.

In [8]:

# Keyword arguments shared by every colour-MNIST loader built in this cell
common_loader_args = dict(root="./MNIST", batch_size=batch_size, bg_noise_level=bg_noise)

# bg_unbiased: correct digit with a random colour, but the colour is used as the label
# (bias_targets_as_targets=True makes the bias target the training target)
bg_unbiased_train_dataloader = colour_mnist.get_biased_mnist_dataloader(
    data_label_correlation=0.1, train=True, bias_targets_as_targets=True, **common_loader_args)
bg_unbiased_test_dataloader = colour_mnist.get_biased_mnist_dataloader(
    data_label_correlation=0.1, train=False, bias_targets_as_targets=True, **common_loader_args)

# biased: correct digit label with a consistent colour per digit - a network is expected to learn the colours only
biased_train_dataloader = colour_mnist.get_biased_mnist_dataloader(
    data_label_correlation=1.0, train=True, standard_getitem=True, **common_loader_args)
biased_test_dataloader = colour_mnist.get_biased_mnist_dataloader(
    data_label_correlation=1.0, train=False, standard_getitem=True, **common_loader_args)

# unbiased: correct digit label with a random colour - a network is expected to disregard the colours
unbiased_train_dataloader = colour_mnist.get_biased_mnist_dataloader(
    data_label_correlation=0.1, train=True, standard_getitem=True, **common_loader_args)
unbiased_test_dataloader = colour_mnist.get_biased_mnist_dataloader(
    data_label_correlation=0.1, train=False, standard_getitem=True, **common_loader_args)

## Set up ResNet18 models and train them (or load pretrained weights) on versions of MNIST

In [6]:

# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing on .to().
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# One entry per base network. Each entry holds the (initially untrained) model,
# whether to train it, its data loaders, and where to save / load its weights.
# The "bg" model is built first, matching the original construction order.
process_structure = {
    # "bg_unbiased_colour": trained to predict the background colour
    "bg": {
        "model": torchvision.models.resnet18(num_classes=10).to(device),  # Untrained model
        "train": train_bg_unbiased_colour_mnist_model,
        "train_loader": bg_unbiased_train_dataloader,
        "test_loader": bg_unbiased_test_dataloader,
        "saveas": save_bg_unbiased_colour_mnist_model_as,
        "loadfrom": bg_unbiased_colour_mnist_model_to_load,
    },
    # "unbiased_colour_mnist": trained to predict the digit
    "unbias": {
        "model": torchvision.models.resnet18(num_classes=10).to(device),  # Untrained model
        "train": train_unbiased_colour_mnist_model,
        "train_loader": unbiased_train_dataloader,
        "test_loader": unbiased_test_dataloader,
        "saveas": save_unbiased_colour_mnist_model_as,
        "loadfrom": unbiased_colour_mnist_model_to_load,
    },
}

for key, val in process_structure.items():
    print(f"Processing for {key=}")
    if val["train"]:
        train_model(model=val["model"], train_loader=val["train_loader"],
                    epochs=original_train_epochs, saveas=val["saveas"],
                    description=key, device=device, logtofile=logtofile)
    else:
        logtofile(f"{val['loadfrom']=}")
        val["model"].load_state_dict(torch.load(val["loadfrom"], map_location=torch.device(device)))
    # Leave every model in evaluation mode, whether it was trained or loaded
    val["model"].eval()


Processing for key='bg'
val['loadfrom']='./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_bg_unbiased_colour_mnist.weights'
Processing for key='unbias'
val['loadfrom']='./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist.weights'


## Measure the Accuracy, Record the Confusion Matrix


In [9]:
logtofile("Entering Confusion")

# Test accuracy of each base network on the BIASED test set, keyed like process_structure.
# These values are the reference against which stitching penalties are computed later.
original_accuracy = dict()
for key, val in process_structure.items():
    logtofile(f"Accuracy Calculation for ResNet18 with {key=}")
    model = val["model"]
    model.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS

    # Compute the model accuracy on the test set
    correct = 0
    total = 0

    # assuming 10 classes
    # rows represent actual class, columns are predicted
    confusion_matrix = torch.zeros(10,10, dtype=torch.int)

    TDL = biased_test_dataloader  # In this experiment - ALWAYS use biased dataset to measure/train stitch
    # No gradients are needed for evaluation; skipping them saves memory and time
    with torch.no_grad():
        for inputs, labels in TDL:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predictions = torch.argmax(model(inputs),1)

            correct += (predictions == labels).sum().item()
            total += len(labels)

            # Count the (actual, predicted) pairs of the whole batch in one go:
            # pair (a, p) maps to flat cell a*10 + p of the 10x10 matrix.
            flat_cells = (labels * 10 + predictions).cpu()
            confusion_matrix += torch.bincount(flat_cells, minlength=100).view(10, 10).to(torch.int)

    logtofile("Test the Trained Resnet18 against BIASED TEST DATALOADER")
    acc = ((100.0 * correct) / total)
    logtofile('Test Accuracy: %2.2f %%' % acc)
    original_accuracy[key] = acc
    logtofile('Confusion Matrix')
    logtofile(confusion_matrix)
    logtofile(confusion_matrix.sum())


logtofile(f"{original_accuracy=}")

Entering Confusion
Accuracy Calculation for ResNet18 with key='bg'
Test the Trained Resnet18 against BIASED TEST DATALOADER
Test Accuracy: 100.00 %
Confusion Matrix
tensor([[ 980,    0,    0,    0,    0,    0,    0,    0,    0,    0],
        [   0, 1135,    0,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0, 1032,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0,    0, 1010,    0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,  982,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,  892,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0,  958,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0, 1028,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0,  974,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0,    0, 1009]],
       dtype=torch.int32)
tensor(10000)
Accuracy Calculation for ResNet18 with key='unbias'
Test the Trained Resnet18 against BIASED T

## Measure Rank with __biased__ dataloader (test) before cutting and stitching

In [10]:
logtofile("Entering whole model check")

# For the Whole Model - but we will pass it through the RcvResNet18 function to get matching feature names
for key, val in process_structure.items():

    TDL = biased_test_dataloader  # ALWAYS use biased dataloader for this test

    # The weights file is the freshly saved one when the model was trained in this
    # run, otherwise the pretrained file it was loaded from.
    filename = val["saveas"] if val["train"] else val["loadfrom"]
    assert os.path.exists(filename)
    mdl = torchvision.models.resnet18(num_classes=10) # Untrained model
    state = torch.load(filename, map_location=torch.device("cpu"))
    mdl.load_state_dict(state, assign=True)
    mdl=mdl.to(device)
    mdl = RcvResNet18(mdl, -1, colour_mnist_shape, device).to(device)
    # Evaluation mode, consistent with the other evaluation cells: a freshly built
    # module starts in training mode, where BatchNorm uses batch statistics.
    mdl.eval()

    out_filename = filename.split('/')[-1].replace('.weights', '-test.csv')

    outpath = f"./{results_root}_rank/{key}-bias-{seed}_{out_filename}"  # denote output name as <model_training_type>-dataset-<name>

    if os.path.exists(f"{outpath}"):
        logtofile(f"Already evaluated for {outpath}")
        continue
    logtofile(f"Measure Rank for {key=}")
    print(f"output to {outpath}")
    # Make sure the CSV can be written once the (slow) analysis has finished
    os.makedirs(os.path.dirname(outpath), exist_ok=True)

    params = {}
    params["model"] = key
    params["dataset"] = "bias"
    params["seed"] = seed
    # as only one network used, record its filename as both send and receive files
    params["send_file"] = filename
    params["rcv_file"] = filename

    with torch.no_grad():
        layers, features, handles = install_hooks(mdl)
        try:
            metrics = evaluate_model(mdl, TDL, 'acc', verbose=2)
            params.update(metrics)
            classes = None
            df = perform_analysis(features, classes, layers, params, n=-1)
            df.to_csv(f"{outpath}")
        finally:
            # Always detach the hooks, even if the analysis raised
            for h in handles:
                h.remove()
    # Release the large intermediates before the next model is processed
    del mdl, state, layers, features, metrics, params, df, handles
    gc.collect()



Entering whole model check
Measure Rank for key='bg'
output to ./results_2k_rank/bg-bias-57_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_bg_unbiased_colour_mnist-test.csv


0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 21/21 [00:45<00:00,  2.18s/it]


Measure Rank for key='unbias'
output to ./results_2k_rank/unbias-bias-57_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist-test.csv


0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 21/21 [00:40<00:00,  1.93s/it]


# Stitch at a given layer


## Train the stitch layer and check rank

In [12]:
logtofile("Entering Stitch/Rank")

logtofile(f"{device=}")

# Stitch tensors are saved under results_root and rank CSVs under "<results_root>_rank";
# make sure both exist before any (slow) training starts.
os.makedirs(f"./{results_root}", exist_ok=True)
os.makedirs(f"./{results_root}_rank", exist_ok=True)

# Results are indexed as [sender key][receiver key][layer index]
stitching_accuracies = dict()
stitching_penalties = dict()
# NOTE this is only valid as all models are the same architecture
num_layers_in_model = len(list(process_structure["bg"]["model"].children()))
for send_key, send_val in process_structure.items():
    if (send_key != "bg"):
        logtofile("NOTE: Only running stitch with bg send model")
        continue
    stitching_accuracies[send_key] = dict()
    stitching_penalties[send_key] = dict()
    # Weights file of the sender: the saved file when trained in this run, else the loaded one
    send_file = send_val["saveas"] if send_val["train"] else send_val["loadfrom"]

    for rcv_key, rcv_val in process_structure.items():
        if (rcv_key != "unbias"):
            logtofile("NOTE: Only running stitch with unbias receive model")
            continue

        stitching_accuracies[send_key][rcv_key] = dict()
        stitching_penalties[send_key][rcv_key] = dict()

        # for consistency, use the rcv network for the filename stem.
        # (Does not depend on the layer, so it is computed once per receiver.)
        filename = rcv_val["saveas"] if rcv_val["train"] else rcv_val["loadfrom"]
        weights_basename = filename.split('/')[-1]
        rank_filename = weights_basename.replace('.weights', '-test.csv')

        for layer_to_cut_after in range(3,num_layers_in_model - 1):
            # denote output name as <model_training_type>-dataset-<name>
            # where <model_training_type> is [sender_model or X][layer_to_cut_after][Receiver_model]
            model_training_type = f"{send_key}{layer_to_cut_after}{rcv_key}"
            dataset_type = "bias"  # ALWAYS use bias dataset in this test
            outpath = f"./{results_root}_rank/{model_training_type}-{dataset_type}-{seed}_{rank_filename}"

            if os.path.exists(f"{outpath}"):
                logtofile(f"Already evaluated for {outpath}")
                continue
            ####################################################################################
            logtofile(f"Evaluate ranks and output to {outpath}")

            logtofile(f"Train the stitch to a model stitched after layer {layer_to_cut_after} from {send_key} to {rcv_key}")
            logtofile(f"Use the biased data loader (train and test) regardless of what {rcv_key} was trained on")

            # train a stitch on the biased dataset to compare receiver network performance with stitched
            model_stitched = StitchedResNet18(send_model=send_val["model"],
                                              after_layer_index=layer_to_cut_after,
                                              rcv_model=rcv_val["model"],
                                              input_image_shape=colour_mnist_shape, device=device  ).to(device)

            #############################################################
            # store the initial stitch state
            # (detached copies: plain value snapshots that carry no autograd history)
            initial_stitch_weight = model_stitched.stitch.s_conv1.weight.detach().clone()
            initial_stitch_bias   = model_stitched.stitch.s_conv1.bias.detach().clone()
            stitch_file_stem = f"{model_training_type}-{dataset_type}-{seed}_{weights_basename}"
            stitch_initial_weight_outpath    = f"./{results_root}/STITCH_initial_weight_{stitch_file_stem}"
            stitch_initial_bias_outpath      = f"./{results_root}/STITCH_initial_bias_{stitch_file_stem}"
            torch.save(initial_stitch_weight, stitch_initial_weight_outpath)
            torch.save(initial_stitch_bias, stitch_initial_bias_outpath)
            ############################################################

            # define the loss function and the optimiser
            loss_function = nn.CrossEntropyLoss()
            # Hernandez said : momentum 0.9, batch size 256, weight decay 0.01, learning rate 0.01, and a post-warmup cosine learning rate scheduler.
            # Every parameter is handed to the optimiser, but only the stitch parameters
            # receive gradients (everything else is frozen below), so only they are updated.
            optimiser = optim.SGD(model_stitched.parameters(), lr=1e-4, momentum=0.9, weight_decay=0.01)

            # Put top model into train mode so that bn and dropout perform in training mode
            # NOTE(review): train() applies to the frozen sender/receiver parts too, so their
            # BatchNorm running statistics are presumably updated during stitch training, and
            # the same base models are reused for every layer - confirm whether
            # StitchedResNet18 copies the models or whether this is intended.
            model_stitched.train()
            # Freeze the whole model
            model_stitched.requires_grad_(False)
            # Un-Freeze the stitch layer
            for name, param in model_stitched.stitch.named_parameters():
                param.requires_grad_(True)
            # the epoch loop: only the stitch layer is being trained
            for epoch in range(stitch_train_epochs):
                running_loss = 0.0
                for data in biased_train_dataloader:
                    # data is (images, labels) tuple
                    # get the inputs and put them on the device
                    inputs, labels = data
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimiser.zero_grad()

                    # forward + loss + backward + optimise (update weights)
                    outputs = model_stitched(inputs)
                    loss = loss_function(outputs, labels)
                    loss.backward()
                    optimiser.step()

                    # keep track of the loss this epoch
                    running_loss += loss.item()
                # running_loss is the sum of the batch losses, not their mean
                logtofile("Epoch %d, loss %4.2f" % (epoch, running_loss))

            logtofile('**** Finished Training ****')

            model_stitched.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS

            ############################################################
            # store the trained stitch
            trained_stitch_weight = model_stitched.stitch.s_conv1.weight.detach().clone()
            trained_stitch_bias   = model_stitched.stitch.s_conv1.bias.detach().clone()
            stitch_trained_weight_outpath    = f"./{results_root}/STITCH_trained_weight_{stitch_file_stem}"
            stitch_trained_bias_outpath      = f"./{results_root}/STITCH_trained_bias_{stitch_file_stem}"
            torch.save(trained_stitch_weight, stitch_trained_weight_outpath)
            torch.save(trained_stitch_bias, stitch_trained_bias_outpath)

            # Summarise how far training moved the stitch: norm of the change, largest
            # single change, and how many entries moved by more than 10% of that largest change.
            stitch_weight_diff = trained_stitch_weight - initial_stitch_weight
            stitch_weight_delta = torch.linalg.norm(stitch_weight_diff).item()
            logtofile(f"Change in stitch weights: {stitch_weight_delta}")
            maxabsweight =  torch.max(stitch_weight_diff.abs()).item()
            logtofile(f"Largest abs weight change: {maxabsweight}")
            stitch_weight_number = (stitch_weight_diff.abs() > 0.1*maxabsweight).sum().item()
            logtofile(f"Number of weights changing > 0.1 of that: {stitch_weight_number}")


            # len() of the weight tensor is the size of its first dimension
            print(f"Number of weight / bias in stitch layer is {len(initial_stitch_weight)}")
            stitch_bias_diff = trained_stitch_bias - initial_stitch_bias
            stitch_bias_delta = torch.linalg.norm(stitch_bias_diff).item()
            logtofile(f"Change in stitch bias: {stitch_bias_delta}")
            maxabsbias =  torch.max(stitch_bias_diff.abs()).item()
            logtofile(f"Largest abs bias change: {maxabsbias}")
            stitch_bias_number = (stitch_bias_diff.abs() > 0.1*maxabsbias).sum().item()
            logtofile(f"Number of bias changing > 0.1 of that: {stitch_bias_number}")
            ##############################################################


            # Compute the model accuracy on the test set
            correct = 0
            total = 0

            # assuming 10 classes
            # rows represent actual class, columns are predicted
            confusion_matrix = torch.zeros(10,10, dtype=torch.int)

            # No gradients are needed for evaluation; skipping them saves memory and time
            with torch.no_grad():
                for inputs, labels in biased_test_dataloader:  # Only use biased test data
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    predictions = torch.argmax(model_stitched(inputs),1)
                    correct += (predictions == labels).sum().item()
                    total += len(labels)

                    # Count the (actual, predicted) pairs of the whole batch in one go:
                    # pair (a, p) maps to flat cell a*10 + p of the 10x10 matrix.
                    flat_cells = (labels * 10 + predictions).cpu()
                    confusion_matrix += torch.bincount(flat_cells, minlength=100).view(10, 10).to(torch.int)
            logtofile("Test the trained stitch against biased data")
            acc =  ((100.0 * correct) / total)
            logtofile('Test Accuracy: %2.2f %%' % acc)
            logtofile('Confusion Matrix')
            logtofile(confusion_matrix)
            logtofile("===================================================================")

            # Stitching penalty should be negative if there is an improvement, and is relative to the original receiver network
            stitching_accuracies[send_key][rcv_key][layer_to_cut_after] = acc
            stitching_penalties[send_key][rcv_key][layer_to_cut_after] = original_accuracy[rcv_key] - acc

            TDL = biased_test_dataloader
            # Metadata passed to the rank analysis alongside the evaluation metrics
            params = {}
            params["model"] = model_training_type # a mnemonic
            params["dataset"] = dataset_type
            params["seed"] = seed
            params["send_file"] = send_file
            params["rcv_file"] = filename
            params["stitch_weight_delta"] = stitch_weight_delta
            params["stitch_bias_delta"] = stitch_bias_delta
            params["stitch_weight_number"] = stitch_weight_number
            params["stitch_bias_number"] = stitch_bias_number
            with torch.no_grad():
                layers, features, handles = install_hooks(model_stitched)
                try:
                    metrics = evaluate_model(model_stitched, TDL, 'acc', verbose=2)
                    params.update(metrics)
                    classes = None
                    df = perform_analysis(features, classes, layers, params, n=-1)
                    df.to_csv(f"{outpath}")
                finally:
                    # Always detach the hooks, even if the analysis raised
                    for h in handles:
                        h.remove()

            # Release the large intermediates before the next stitch is built
            del model_stitched, layers, features, metrics, params, df, handles
            gc.collect()

            

Entering Stitch/Rank
device='cuda:0'
NOTE: Only running stitch with unbias receive model
Evaluate ranks and output to ./results_2k_rank/bg3unbias-bias-57_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 3 from bg to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 3 of send_model is: torch.Size([1, 64, 7, 7])
Epoch 0, loss 348.47
Epoch 1, loss 24.81
Epoch 2, loss 15.78
Epoch 3, loss 11.58
Epoch 4, loss 9.20
Epoch 5, loss 8.04
Epoch 6, loss 7.07
Epoch 7, loss 5.81
Epoch 8, loss 5.22
Epoch 9, loss 5.50
**** Finished Training ****
Change in stitch weights: 1.0196986198425293
Largest abs weight change: 0.11292095482349396
Number of weights changing > 0.1 of that: 1697
Number of weight / bias in stitch layer is 64
Change in stitch bias: 0.026380406692624092
Largest abs bias change: 0.0

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:46<00:00,  2.10s/it]


Evaluate ranks and output to ./results_2k_rank/bg4unbias-bias-57_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 4 from bg to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 4 of send_model is: torch.Size([1, 64, 7, 7])
Epoch 0, loss 452.58
Epoch 1, loss 29.05
Epoch 2, loss 16.21
Epoch 3, loss 11.48
Epoch 4, loss 10.02
Epoch 5, loss 7.24
Epoch 6, loss 6.27
Epoch 7, loss 5.40
Epoch 8, loss 5.07
Epoch 9, loss 3.99
**** Finished Training ****
Change in stitch weights: 1.1344859600067139
Largest abs weight change: 0.08438502997159958
Number of weights changing > 0.1 of that: 2393
Number of weight / bias in stitch layer is 64
Change in stitch bias: 0.025598231703042984
Largest abs bias change: 0.0056830719113349915
Number of bias changing > 0.1 of that: 60
Test the trained stitch aga

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [00:47<00:00,  2.17s/it]


Evaluate ranks and output to ./results_2k_rank/bg5unbias-bias-57_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 5 from bg to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 5 of send_model is: torch.Size([1, 128, 4, 4])
Epoch 0, loss 156.19
Epoch 1, loss 12.15
Epoch 2, loss 6.75
Epoch 3, loss 5.40
Epoch 4, loss 3.57
Epoch 5, loss 3.44
Epoch 6, loss 2.86
Epoch 7, loss 3.48
Epoch 8, loss 2.13
Epoch 9, loss 2.21
**** Finished Training ****
Change in stitch weights: 0.8525775671005249
Largest abs weight change: 0.042431335896253586
Number of weights changing > 0.1 of that: 7563
Number of weight / bias in stitch layer is 128
Change in stitch bias: 0.027369199320673943
Largest abs bias change: 0.004016287624835968
Number of bias changing > 0.1 of that: 117
Test the trained stitch aga

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [01:10<00:00,  3.21s/it]


Evaluate ranks and output to ./results_2k_rank/bg6unbias-bias-57_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 6 from bg to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 6 of send_model is: torch.Size([1, 256, 2, 2])
Epoch 0, loss 73.44
Epoch 1, loss 5.03
Epoch 2, loss 3.22
Epoch 3, loss 2.76
Epoch 4, loss 2.07
Epoch 5, loss 2.07
Epoch 6, loss 1.79
Epoch 7, loss 1.56
Epoch 8, loss 1.55
Epoch 1, loss 1.59
Epoch 2, loss 2.07
Epoch 3, loss 1.40
Epoch 4, loss 0.98
Epoch 5, loss 1.52
Epoch 6, loss 0.89
Epoch 7, loss 0.80
Epoch 8, loss 1.08
Epoch 9, loss 0.94
**** Finished Training ****
Change in stitch weights: 0.7299830317497253
Largest abs weight change: 0.008423078805208206
Number of weights changing > 0.1 of that: 157864
Number of weight / bias in stitch layer is 512
Change i

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [01:24<00:00,  3.85s/it]


Evaluate ranks and output to ./results_2k_rank/bg8unbias-bias-57_2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist-test.csv
Train the stitch to a model stitched after layer 8 from bg to unbias
Use the biased data loader (train and test) regardless of what unbias was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 8 of send_model is: torch.Size([1, 512, 1, 1])
Epoch 0, loss 35.09
Epoch 1, loss 1.95
Epoch 2, loss 1.58
Epoch 3, loss 1.57
Epoch 4, loss 0.96
Epoch 5, loss 1.00
Epoch 6, loss 0.95
Epoch 7, loss 1.33
Epoch 8, loss 1.04
Epoch 9, loss 0.97
**** Finished Training ****
Change in stitch weights: 0.7378382682800293
Largest abs weight change: 0.009735164232552052
Number of weights changing > 0.1 of that: 142574
Number of weight / bias in stitch layer is 512
Change in stitch bias: 0.026206182315945625
Largest abs bias change: 0.0020092539489269257
Number of bias changing > 0.1 of that: 464
Test the trained stitch ag

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████| 22/22 [01:34<00:00,  4.30s/it]


NOTE: Only running stitch with bg send model


In [13]:
# Record the per-layer stitch accuracies and penalties gathered by the stitch cell,
# each written as "<name>=<repr of the dict>"
for table_name, table in (("stitching_accuracies", stitching_accuracies),
                          ("stitching_penalties", stitching_penalties)):
    logtofile(f"{table_name}={table!r}")

stitching_accuracies={'bg': {'unbias': {3: 99.89, 4: 99.91, 5: 100.0, 6: 100.0, 7: 100.0, 8: 100.0}}}
stitching_penalties={'bg': {'unbias': {3: -2.1200000000000045, 4: -2.1400000000000006, 5: -2.230000000000004, 6: -2.230000000000004, 7: -2.230000000000004, 8: -2.230000000000004}}}


In [20]:
# Human-readable summary: for every sender/receiver pair, log the receiver's
# original accuracy followed by the stitched accuracy at each cut layer.
for s_key, per_receiver in stitching_accuracies.items():
    for r_key, per_layer in per_receiver.items():
        logtofile(f"{s_key}-{r_key}")
        logtofile(f"{original_accuracy[r_key]=}")
        logtofile("Stitch Accuracy")
        for layer, layer_accuracy in per_layer.items():
            logtofile(f"L{layer}: {layer_accuracy}")
        logtofile("--------------------------")

bg-unbias
original_accuracy[r_key]=97.77
Stitch Accuracy
L3: 99.89
L4: 99.91
L5: 100.0
L6: 100.0
L7: 100.0
L8: 100.0
--------------------------
