## Exp 2g_a: for the BW and BGONLY models and datasets (plus the biased-colour model in the current configuration), stitch each model to itself, but train the stitch on remapped (offset) class labels

Can the stitched model learn to treat a 3 as a 7, or blue as green?

## Rank
Maybe remove the rank analysis on the stitched networks to save time (it is currently disabled via `measure_rank = False`).
## 4 Epochs
Only train the original models for 4 epochs (and train the stitch for more epochs) so that the initial models are weaker.
## Consistency: set `train_all = False` so the originally trained networks are loaded rather than retrained

In [1]:
# Packages
%matplotlib inline

import argparse
import gc
import os.path

import pandas as pd
from torch.linalg import LinAlgError

import matplotlib.pyplot as plt
import torchvision
import torch
from torch import optim

from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import datetime

import random
import numpy as np

import sys
import os
# add the path to find colour_mnist
sys.path.append(os.path.abspath('../ReferenceCode'))
import colour_mnist
from stitch_utils import train_model, RcvResNet18, StitchedResNet18, get_layer_output_shape
from stitch_utils import generate_activations, SyntheticDataset
import stitch_utils

# add the path to find the rank analysis code
# https://github.com/DHLSmith/jons-tunnel-effect/tree/NeurIPSPaper

sys.path.append(os.path.abspath('../../jons-tunnel-effect/'))
from utils.modelfitting import evaluate_model, set_seed
from extract_weight_rank import install_hooks, perform_analysis

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

# To track memory usage
import psutil
process = psutil.Process()
            

def logtofile(log_text, verbose=True, logfile=None):
    """Append `log_text` to the experiment log, optionally echoing it to stdout.

    Args:
        log_text: anything printable (strings, tensors, ...); written with print().
        verbose: when True, also print `log_text` to stdout.
        logfile: path of the log file. When None (the default) the notebook-level
            ``save_log_as`` is used; it is looked up at call time because it is
            only defined in the "Generate filenames" cell that runs after this one.
    """
    if logfile is None:
        logfile = save_log_as
    if verbose:
        print(log_text)
    # explicit encoding so the log does not depend on the platform locale
    with open(logfile, "a", encoding="utf-8") as f:
        print(log_text, file=f)

In [2]:
# Set Parameters

# fix random seed for reproducibility
seed = 13
torch.manual_seed(seed)
# The models run on a non-default GPU (see `device`); torch.cuda.manual_seed only
# seeds the *current* CUDA device, so seed every GPU explicitly.
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(seed)
np.random.seed(seed)

results_root = "results_2g_a"
save_stitch_delta = False  # when True, save the stitch weights/bias before and after training
train_all = False          # master switch: False loads the previously trained weight files below
measure_rank = False       # rank analysis is skipped when False

# MIX is 1/3 bgonly, 1/3 mnist only, 1/3 biased data
train_mix_mnist_model = train_all  # when False, automatically loads a trained model
mix_mnist_model_to_load = "./results_4_epochs/2024-08-02_16-08-24_SEED58_EPOCHS4_BGN0.1_exp2e_ResNet18_mix_mnist.weights"

# BW is greyscale mnist with no colour added
train_bw_mnist_model = train_all  # when False, automatically loads a trained model
bw_mnist_model_to_load = './results_4_epochs/2024-08-06_12-57-58_SEED60_EPOCHS4_BGN0.1_exp2e_ResNet18_bw_mnist.weights'

# BG_ONLY contains no mnist data, just a coloured background
train_bg_only_colour_mnist_model = train_all  # when False, automatically loads a trained model
bg_only_colour_mnist_model_to_load =  "./results_4_epochs/2024-08-06_12-57-58_SEED60_EPOCHS4_BGN0.1_exp2e_ResNet18_bg_only_colour_mnist.weights"

# BG_UNBIASED is digits with randomly selected colour background. Targets represent the colour
train_bg_unbiased_colour_mnist_model = train_all  # when False, automatically loads a trained model
bg_unbiased_colour_mnist_model_to_load = "./results_4_epochs/2024-08-02_16-08-24_SEED58_EPOCHS4_BGN0.1_exp2e_ResNet18_bg_unbiased_colour_mnist.weights"

# BIASED is digits with consistent per-class colour background. 
train_biased_colour_mnist_model = train_all  # when False, automatically loads a trained model
biased_colour_mnist_model_to_load = "./results_4_epochs/2024-08-02_16-08-24_SEED58_EPOCHS4_BGN0.1_exp2e_ResNet18_biased_colour_mnist.weights"

# UNBIASED is digits with randomly selected colour background. Targets are digit values
train_unbiased_colour_mnist_model = train_all  # when False, automatically loads a trained model
unbiased_colour_mnist_model_to_load = "./results_4_epochs/2024-08-02_16-08-24_SEED58_EPOCHS4_BGN0.1_exp2e_ResNet18_unbiased_colour_mnist.weights"
original_train_epochs = 4
bg_noise = 0.1

stitch_train_epochs = 10

batch_size = 128

In [3]:
# Generate filenames and log the setup details
# Outputs go under ./{results_root}/ (weights, log, stitch tensors) and
# ./{results_root}_rank/ (per-model CSVs). Create both directories first: the
# first logtofile() call below appends to a file inside ./{results_root}/, and
# open(..., "a") does not create missing parent directories.
os.makedirs(f"./{results_root}", exist_ok=True)
os.makedirs(f"./{results_root}_rank", exist_ok=True)

formatted_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# NOTE(review): the prefix still carries the "exp2e" tag used by the loaded weight
# files, although this notebook is exp 2g_a.
filename_prefix = f"./{results_root}/{formatted_time}_SEED{seed}_EPOCHS{original_train_epochs}_BGN{bg_noise}_exp2e_ResNet18"
save_mix_mnist_model_as = f"{filename_prefix}_mix_mnist.weights"
save_bw_mnist_model_as = f"{filename_prefix}_bw_mnist.weights"
save_bg_only_colour_mnist_model_as = f"{filename_prefix}_bg_only_colour_mnist.weights"
save_bg_unbiased_colour_mnist_model_as = f"{filename_prefix}_bg_unbiased_colour_mnist.weights"
save_biased_colour_mnist_model_as = f"{filename_prefix}_biased_colour_mnist.weights"
save_unbiased_colour_mnist_model_as = f"{filename_prefix}_unbiased_colour_mnist.weights"
save_log_as = f"{filename_prefix}_log.txt"

colour_mnist_shape = (3,28,28)  # (channels, height, width) of every input image


logtofile(f"Executed at {formatted_time}")
logtofile(f"logging to {save_log_as}")
logtofile(f"{seed=}")
logtofile(f"{bg_noise=}")

# For each model: log whether it is trained here (and where it will be saved)
# or loaded from an existing weights file.
logtofile(f"{train_mix_mnist_model=}")
if train_mix_mnist_model:
    logtofile(f"{save_mix_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{mix_mnist_model_to_load=}")

logtofile(f"{train_bw_mnist_model=}")
if train_bw_mnist_model:
    logtofile(f"{save_bw_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{bw_mnist_model_to_load=}")

logtofile(f"{train_bg_only_colour_mnist_model=}")
if train_bg_only_colour_mnist_model:
    logtofile(f"{save_bg_only_colour_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{bg_only_colour_mnist_model_to_load=}")

logtofile(f"{train_bg_unbiased_colour_mnist_model=}")
if train_bg_unbiased_colour_mnist_model:
    logtofile(f"{save_bg_unbiased_colour_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{bg_unbiased_colour_mnist_model_to_load=}")

logtofile(f"{train_biased_colour_mnist_model=}")
if train_biased_colour_mnist_model:
    logtofile(f"{save_biased_colour_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{biased_colour_mnist_model_to_load=}")

logtofile(f"{train_unbiased_colour_mnist_model=}")
if train_unbiased_colour_mnist_model:
    logtofile(f"{save_unbiased_colour_mnist_model_as=}")
    logtofile(f"{original_train_epochs=}")
else:
    logtofile(f"{unbiased_colour_mnist_model_to_load=}")

logtofile(f"{stitch_train_epochs=}")
logtofile(f"================================================")

Executed at 2025-01-27_09-36-53
logging to ./results_2g_a/2025-01-27_09-36-53_SEED13_EPOCHS4_BGN0.1_exp2e_ResNet18_log.txt
seed=13
bg_noise=0.1
train_mix_mnist_model=False
mix_mnist_model_to_load='./results_4_epochs/2024-08-02_16-08-24_SEED58_EPOCHS4_BGN0.1_exp2e_ResNet18_mix_mnist.weights'
train_bw_mnist_model=False
bw_mnist_model_to_load='./results_4_epochs/2024-08-06_12-57-58_SEED60_EPOCHS4_BGN0.1_exp2e_ResNet18_bw_mnist.weights'
train_bg_only_colour_mnist_model=False
bg_only_colour_mnist_model_to_load='./results_4_epochs/2024-08-06_12-57-58_SEED60_EPOCHS4_BGN0.1_exp2e_ResNet18_bg_only_colour_mnist.weights'
train_bg_unbiased_colour_mnist_model=False
bg_unbiased_colour_mnist_model_to_load='./results_4_epochs/2024-08-02_16-08-24_SEED58_EPOCHS4_BGN0.1_exp2e_ResNet18_bg_unbiased_colour_mnist.weights'
train_biased_colour_mnist_model=False
biased_colour_mnist_model_to_load='./results_4_epochs/2024-08-02_16-08-24_SEED58_EPOCHS4_BGN0.1_exp2e_ResNet18_biased_colour_mnist.weights'
train_unbia

MNIST and CIFAR-10 both have 10 classes. MNIST has 60,000 training and 10,000 test samples (CIFAR-10 has 50,000 training and 10,000 test samples).

In [4]:
# Set up dataloaders

# Greyscale MNIST replicated to 3 channels, normalised per channel with the MNIST mean/std.
transform_bw = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert to 3 channels
    transforms.ToTensor(),                        # convert to tensor. We always do this one
    transforms.Normalize((0.1307,) * 3, (0.3081,) * 3),
])

mnist_train = MNIST("./MNIST", train=True, download=True, transform=transform_bw)
mnist_test = MNIST("./MNIST", train=False, download=True, transform=transform_bw)

bw_train_dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
bw_test_dataloader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)


def _colour_loader_pair(factory, **options):
    """Call `factory` for the train split, then the test split, with the shared options.

    Returns (train_loader, test_loader).
    """
    train_loader = factory(root="./MNIST", batch_size=batch_size, train=True, **options)
    test_loader = factory(root="./MNIST", batch_size=batch_size, train=False, **options)
    return train_loader, test_loader


# mix dataloader
mix_train_dataloader, mix_test_dataloader = _colour_loader_pair(
    colour_mnist.get_mixed_mnist_dataloader,
    bg_noise_level=bg_noise, standard_getitem=True)

# bg_only: no digits, the background colour is the label
bg_only_train_dataloader, bg_only_test_dataloader = _colour_loader_pair(
    colour_mnist.get_biased_mnist_dataloader,
    data_label_correlation=1.0, bg_noise_level=bg_noise, bg_only=True, standard_getitem=True)

# bg_unbiased: each digit has a random colour, and the colour (bias target) is used as the label
bg_unbiased_train_dataloader, bg_unbiased_test_dataloader = _colour_loader_pair(
    colour_mnist.get_biased_mnist_dataloader,
    data_label_correlation=0.1, bg_noise_level=bg_noise, bias_targets_as_targets=True)

# biased: each digit has its correct label and a consistent per-class colour - expect the network to learn the colours only
biased_train_dataloader, biased_test_dataloader = _colour_loader_pair(
    colour_mnist.get_biased_mnist_dataloader,
    data_label_correlation=1.0, bg_noise_level=bg_noise, standard_getitem=True)

# unbiased: each digit has its correct label and a random colour - expect the network to disregard colours?
unbiased_train_dataloader, unbiased_test_dataloader = _colour_loader_pair(
    colour_mnist.get_biased_mnist_dataloader,
    data_label_correlation=0.1, bg_noise_level=bg_noise, standard_getitem=True)

## Set up ResNet18 models and train (or load) them on versions of MNIST

In [5]:


# Each process_structure entry holds a ResNet18, the flag/paths deciding whether it
# is trained here or loaded from disk, and the dataloaders of its own dataset.
process_structure = dict()
device = 'cuda:2'  # NOTE(review): hard-coded GPU index; adjust on machines with fewer GPUs


def _make_model_entry(train, train_loader, test_loader, saveas, loadfrom):
    """Return one process_structure entry holding a fresh, untrained ResNet18 on `device`."""
    return {
        "model": torchvision.models.resnet18(num_classes=10).to(device),  # Untrained model
        "train": train,
        "train_loader": train_loader,
        "test_loader": test_loader,
        "saveas": saveas,
        "loadfrom": loadfrom,
    }


# Models are created in this order (bw, bgonly, bias); uncomment a line to add a
# disabled variant back in at its original position.
# process_structure["mix"] = _make_model_entry(train_mix_mnist_model, mix_train_dataloader, mix_test_dataloader, save_mix_mnist_model_as, mix_mnist_model_to_load)
process_structure["bw"] = _make_model_entry(
    train_bw_mnist_model, bw_train_dataloader, bw_test_dataloader,
    save_bw_mnist_model_as, bw_mnist_model_to_load)
process_structure["bgonly"] = _make_model_entry(
    train_bg_only_colour_mnist_model, bg_only_train_dataloader, bg_only_test_dataloader,
    save_bg_only_colour_mnist_model_as, bg_only_colour_mnist_model_to_load)
# process_structure["bg"] = _make_model_entry(train_bg_unbiased_colour_mnist_model, bg_unbiased_train_dataloader, bg_unbiased_test_dataloader, save_bg_unbiased_colour_mnist_model_as, bg_unbiased_colour_mnist_model_to_load)
process_structure["bias"] = _make_model_entry(
    train_biased_colour_mnist_model, biased_train_dataloader, biased_test_dataloader,
    save_biased_colour_mnist_model_as, biased_colour_mnist_model_to_load)
# process_structure["unbias"] = _make_model_entry(train_unbiased_colour_mnist_model, unbiased_train_dataloader, unbiased_test_dataloader, save_unbiased_colour_mnist_model_as, unbiased_colour_mnist_model_to_load)

for key, val in process_structure.items():
    logtofile(f"Processing for {key=}")
    if val["train"]:
        train_model(model=val["model"], train_loader=val["train_loader"], 
                    epochs=original_train_epochs, saveas=val["saveas"], 
                    description=key, device=device, logtofile=logtofile)
    else:
        logtofile(f"{val['loadfrom']=}")
        # The file holds a plain state_dict, so refuse arbitrary pickled objects
        # (unpickling can execute code).
        state = torch.load(val["loadfrom"], map_location=torch.device(device), weights_only=True)
        val["model"].load_state_dict(state)
    val["model"].eval()


Processing for key='bw'
val['loadfrom']='./results_4_epochs/2024-08-06_12-57-58_SEED60_EPOCHS4_BGN0.1_exp2e_ResNet18_bw_mnist.weights'
Processing for key='bgonly'
val['loadfrom']='./results_4_epochs/2024-08-06_12-57-58_SEED60_EPOCHS4_BGN0.1_exp2e_ResNet18_bg_only_colour_mnist.weights'
Processing for key='bias'
val['loadfrom']='./results_4_epochs/2024-08-02_16-08-24_SEED58_EPOCHS4_BGN0.1_exp2e_ResNet18_biased_colour_mnist.weights'


## Measure the Accuracy, Record the Confusion Matrix


In [6]:
logtofile("Entering Confusion")
# logtofile(process.memory_info().rss)  # in bytes


def _accuracy_and_confusion(model, loader, num_classes=10):
    """Run `model` over every batch of `loader` and tally its predictions.

    Args:
        model: classifier (already in eval mode); outputs are argmax-ed over dim 1.
        loader: iterable of (inputs, labels) batches; labels must be class indices
            in [0, num_classes).
        num_classes: size of the square confusion matrix.

    Returns:
        (correct, total, confusion): confusion is an int32 CPU tensor of shape
        (num_classes, num_classes) whose rows are the actual class and columns the
        predicted class.
    """
    correct = 0
    total = 0
    confusion = torch.zeros(num_classes, num_classes, dtype=torch.int)
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predictions = torch.argmax(model(inputs), 1)
            correct += (predictions == labels).sum().item()
            total += len(labels)
            # one bincount per batch instead of a Python loop over every sample
            cell_index = (labels * num_classes + predictions).cpu()
            counts = torch.bincount(cell_index, minlength=num_classes * num_classes)
            confusion += counts.reshape(num_classes, num_classes).to(torch.int)
    return correct, total, confusion


original_accuracy = dict()
for key, val in process_structure.items():
    logtofile(f"Accuracy Calculation for ResNet18 with {key=}")
    model = val["model"]
    model.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS

    # use the test loader for the dataset the model was trained on
    TDL = val["test_loader"]
    correct, total, confusion_matrix = _accuracy_and_confusion(model, TDL)

    logtofile(f"Test the Trained Resnet18 against OWN TEST LOADER: {key=}")
    acc = ((100.0 * correct) / total)
    logtofile('Test Accuracy: %2.2f %%' % acc)
    original_accuracy[key] = acc
    logtofile('Confusion Matrix')
    logtofile(confusion_matrix)
    logtofile(confusion_matrix.sum())
    # logtofile(process.memory_info().rss)  # in bytes

    if not measure_rank:
        # Without rank analysis, still record each model's own-dataset test accuracy
        # as a one-row CSV at the same path the rank analysis would write to.
        filename = val["saveas"] if val["train"] else val["loadfrom"]
        assert os.path.exists(filename)

        out_filename = filename.split('/')[-1].replace('.weights', '-test.csv')
        # denote output name as <model_training_type>-dataset-<name>
        outpath = f"./{results_root}_rank/{key}-{key}-{seed}_{out_filename}"

        if os.path.exists(f"{outpath}"):
            logtofile(f"Already evaluated for {outpath}")
            continue
        logtofile(f"output to {outpath}")
        os.makedirs(os.path.dirname(outpath), exist_ok=True)

        params = {}
        params["model"] = key
        params["dataset"] = key
        params["seed"] = seed
        # as only one network is used, record its filename as both send and receive files
        params["send_file"] = filename
        params["rcv_file"] = filename
        params["val_acc"] = acc / 100
        params["name"] = "only"

        df = pd.DataFrame.from_records([params])
        df.to_csv(f"{outpath}")

        del params, df
        gc.collect()

logtofile(f"{original_accuracy=}")

Entering Confusion
Accuracy Calculation for ResNet18 with key='bw'
Test the Trained Resnet18 against OWN TEST LOADER: key='bw'
Test Accuracy: 99.09 %
Confusion Matrix
tensor([[ 979,    0,    0,    0,    0,    1,    0,    0,    0,    0],
        [   0, 1132,    0,    0,    0,    2,    1,    0,    0,    0],
        [   1,    1, 1025,    3,    1,    0,    0,    1,    0,    0],
        [   1,    0,    0, 1000,    0,    5,    0,    0,    3,    1],
        [   0,    0,    0,    0,  973,    0,    0,    1,    0,    8],
        [   2,    0,    0,    1,    0,  881,    1,    0,    2,    5],
        [   3,    4,    0,    0,    0,    2,  948,    0,    1,    0],
        [   0,    5,    9,    2,    0,    0,    0, 1006,    1,    5],
        [   3,    0,    1,    0,    0,    1,    0,    2,  965,    2],
        [   1,    1,    0,    2,    3,    1,    0,    0,    1, 1000]],
       dtype=torch.int32)
tensor(10000)
Already evaluated for ./results_2g_a_rank/bw-bw-13_2024-08-06_12-57-58_SEED60_EPOCHS4_BGN0.1

## Measure Rank with __OWN__ dataloader (test) before cutting and stitching

In [7]:
if measure_rank:
    logtofile("Entering whole model check")
    # logtofile(process.memory_info().rss)  # in bytes 
    
    # For the Whole Model - but we will pass it through the RcvResNet18 function to get matching feature names
    for key, val in process_structure.items():
        
        TDL = val["test_loader"] # use the test loader for the dataset the model was trained on 
        filename = val["saveas"] if val["train"] else val["loadfrom"]
        assert os.path.exists(filename)
        # NOTE(review): the model is built before the "already evaluated" check, as in
        # the original; moving it after would likely change the RNG state seen by
        # later cells, because building an untrained ResNet18 initialises weights.
        mdl = torchvision.models.resnet18(num_classes=10) # Untrained model
        # The file holds a plain state_dict, so refuse arbitrary pickled objects.
        state = torch.load(filename, map_location=torch.device("cpu"), weights_only=True)
        mdl.load_state_dict(state, assign=True)
        mdl = mdl.to(device)
        mdl = RcvResNet18(mdl, -1, colour_mnist_shape, device).to(device)
    
        out_filename = filename.split('/')[-1].replace('.weights', '-test.csv')
        # denote output name as <model_training_type>-dataset-<name>
        outpath = f"./{results_root}_rank/{key}-{key}-{seed}_{out_filename}"
        
        if os.path.exists(f"{outpath}"):
            logtofile(f"Already evaluated for {outpath}")
            continue
        logtofile(f"Measure Rank for {key=}")
        logtofile(f"output to {outpath}")
        os.makedirs(os.path.dirname(outpath), exist_ok=True)
                
        params = {}
        params["model"] = key
        params["dataset"] = key
        params["seed"] = seed
        # as only one network is used, record its filename as both send and receive files
        params["send_file"] = filename
        params["rcv_file"] = filename
        
        with torch.no_grad():
            layers, features, handles = install_hooks(mdl)
            try:
                metrics = evaluate_model(mdl, TDL, 'acc', verbose=2, device=device)
                params.update(metrics)
                
                classes = None
                df = perform_analysis(features, classes, layers, params, n=-1)
                df.to_csv(f"{outpath}")
            finally:
                # always detach the hooks, even if evaluation/analysis fails
                for h in handles:
                    h.remove()
        del mdl, state, layers, features, metrics, params, df, handles
        gc.collect()
        # logtofile(process.memory_info().rss)  # in bytes 



# Stitch at a given layer


## Train the stitch layer and check rank

In [None]:
logtofile("Entering Stitch/Rank")
# logtofile(process.memory_info().rss)  # in bytes 

logtofile(f"{device=}")
#stitching_accuracies = dict()
#stitching_penalties = dict()
# NOTE this is only valid as all models are the same architecture
num_layers_in_model = len(list(process_structure["bw"]["model"].children()))  
for target_offset in [0,4,5,6,7,8,9]:  # This is the amount that will be added to labels when training and testing the stitch (modulo 10)

    logtofile(f"{target_offset=}")
    for send_key, send_val in process_structure.items():
        #stitching_accuracies[send_key] = dict()
        #stitching_penalties[send_key] = dict()
        for rcv_key, rcv_val in process_structure.items():        
            if (rcv_key != send_key):
                logtofile(f"NOTE: Only running stitch between same model: skipping")
                continue
            #stitching_accuracies[send_key][rcv_key] = dict()
            #stitching_penalties[send_key][rcv_key] = dict()
            for layer_to_cut_after in range(3,num_layers_in_model - 1):
                # for consistency, use the rcv network for the filename stem.
                if rcv_val["train"]:
                    filename = rcv_val["saveas"] 
                else:    
                    filename = rcv_val["loadfrom"] 
                
                rank_filename = filename.split('/')[-1].replace('.weights', '-test.csv')        
                # denote output name as <model_training_type>-dataset-<name>
                # where <model_training_type> is [sender_model or X][layer_to_cut_after][Receiver_model]
                model_training_type = f"{send_key}{layer_to_cut_after}{rcv_key}"
                #dataset_type = "bias"  # ALWAYS use bias dataset in this test
                dataset_type = send_key # ALWAYS use sender-matched dataset in this test
                outpath = f"./{results_root}_rank/{model_training_type}-offset{target_offset}-StitchEpochs{stitch_train_epochs}-{dataset_type}-{seed}_{rank_filename}"  
                                
                if os.path.exists(f"{outpath}"):
                    logtofile(f"Already evaluated for {outpath}")
                    continue
                ####################################################################################
                logtofile(f"Evaluate ranks and output to {outpath}")
                # logtofile(process.memory_info().rss)  # in bytes 
    
                logtofile(f"Train the stitch to a model stitched after layer {layer_to_cut_after} from {send_key} to {rcv_key}")    
                logtofile(f"Use the {send_key} data loader (train and test) regardless of what {rcv_key} was trained on")
                
                # train a stitch on the unbiased_colour dataset to compare receiver network performance with stitched
                model_stitched = StitchedResNet18(send_model=send_val["model"], 
                                                  after_layer_index=layer_to_cut_after, 
                                                  rcv_model=rcv_val["model"],
                                                  input_image_shape=colour_mnist_shape, device=device  ).to(device)
                            
                #############################################################
                # store the initial stitch state
                initial_stitch_weight = model_stitched.stitch.s_conv1.weight.clone()
                initial_stitch_bias   = model_stitched.stitch.s_conv1.bias.clone()
                stitch_initial_weight_outpath    = f"./{results_root}/STITCH_initial_weight_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"  
                stitch_initial_bias_outpath      = f"./{results_root}/STITCH_initial_bias_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"  
                if save_stitch_delta:
                    torch.save(initial_stitch_weight, stitch_initial_weight_outpath)
                    torch.save(initial_stitch_bias, stitch_initial_bias_outpath)
                ############################################################
                        
                # define the loss function and the optimiser
                loss_function = nn.CrossEntropyLoss()
                # Hernandez said : momentum 0.9, batch size 256, weight decay 0.01, learning rate 0.01, and a post-warmup cosine learning rate scheduler.
                # optimiser = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
                optimiser = optim.SGD(model_stitched.parameters(), lr=1e-4, momentum=0.9, weight_decay=0.01)
                
                # Put top model into train mode so that bn and dropout perform in training mode
                model_stitched.train()
                # Freeze the whole model
                model_stitched.requires_grad_(False)
                # Un-Freeze the stitch layer
                for name, param in model_stitched.stitch.named_parameters():
                    param.requires_grad_(True)
                # the epoch loop: note that we're training the whole network
                for epoch in range(stitch_train_epochs):
                    running_loss = 0.0
                    for data in send_val["train_loader"]:# biased_train_dataloader:
                        # data is (representations, labels) tuple
                        # get the inputs and put them on the GPU
                        inputs, labels = data
                        inputs = inputs.to(device)
                        labels = labels.to(device)
                        labels = (labels + target_offset) % 10
                        
                
                        # zero the parameter gradients
                        optimiser.zero_grad()
                
                        # forward + loss + backward + optimise (update weights)
                        outputs = model_stitched(inputs)
                        loss = loss_function(outputs, labels)
                        loss.backward()
                        optimiser.step()
                
                        # keep track of the loss this epoch
                        running_loss += loss.item()
                    logtofile("Epoch %d, loss %4.2f" % (epoch, running_loss))
                    # logtofile(process.memory_info().rss)  # in bytes 
    
                logtofile('**** Finished Training ****')
                
                model_stitched.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS
    
                ############################################################
                # store the trained stitch
                trained_stitch_weight = model_stitched.stitch.s_conv1.weight.clone()
                trained_stitch_bias   = model_stitched.stitch.s_conv1.bias.clone()
                stitch_trained_weight_outpath    = f"./{results_root}/STITCH_trained_weight_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"  
                stitch_trained_bias_outpath      = f"./{results_root}/STITCH_trained_bias_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}"  
                
                if save_stitch_delta:
                    torch.save(trained_stitch_weight, stitch_trained_weight_outpath)
                    torch.save(trained_stitch_bias, stitch_trained_bias_outpath)
                           
                stitch_weight_diff = trained_stitch_weight - initial_stitch_weight
                stitch_weight_delta = torch.linalg.norm(stitch_weight_diff).item()
                logtofile(f"Change in stitch weights: {stitch_weight_delta}")
                maxabsweight =  torch.max(stitch_weight_diff.abs()).item()
                logtofile(f"Largest abs weight change: {maxabsweight}")
                stitch_weight_number = torch.sum(torch.where(stitch_weight_diff.abs() > 0.1*maxabsweight, True, False)).item()
                logtofile(f"Number of weights changing > 0.1 of that: {stitch_weight_number}")
    
                
                logtofile(f"Number of weight / bias in stitch layer is {len(initial_stitch_weight)}")
                stitch_bias_diff = trained_stitch_bias - initial_stitch_bias
                stitch_bias_delta = torch.linalg.norm(stitch_bias_diff).item()
                logtofile(f"Change in stitch bias: {stitch_bias_delta}")
                maxabsbias =  torch.max(stitch_bias_diff.abs()).item()
                logtofile(f"Largest abs bias change: {maxabsbias}")
                stitch_bias_number = torch.sum(torch.where(stitch_bias_diff.abs() > 0.1*maxabsbias, True, False)).item()
                logtofile(f"Number of bias changing > 0.1 of that: {stitch_bias_number}")
                ##############################################################
    
                
                # Evaluate the trained stitched model on the sender's test loader.
                # Labels are remapped with target_offset so the stitch is scored against
                # the shifted classes it was trained to predict.
                # Confusion matrix: rows are the (remapped) actual class, columns the
                # predicted class. 10 classes are assumed, matching the "% 10" remap.
                num_classes = 10
                correct = 0
                total = 0
                confusion_matrix = torch.zeros(num_classes, num_classes, dtype=torch.int)
                TDL = send_val["test_loader"]

                # eval(): BatchNorm/Dropout use inference behaviour, so test batches do
                # not overwrite BN running statistics. no_grad(): no autograd graph is
                # built for pure inference, which saves GPU memory.
                model_stitched.eval()
                with torch.no_grad():
                    for inputs, labels in TDL:
                        inputs = inputs.to(device)
                        labels = (labels.to(device) + target_offset) % num_classes

                        predictions = torch.argmax(model_stitched(inputs), 1)
                        correct += (predictions == labels).sum().item()
                        total += len(labels)

                        # Vectorised confusion-matrix update: add 1 at every
                        # (actual, predicted) pair in one call, instead of a per-sample
                        # Python loop that syncs with the GPU on every element.
                        # Out-of-range class indices still raise IndexError.
                        actual_idx = labels.long().cpu()
                        predicted_idx = predictions.long().cpu()
                        confusion_matrix.index_put_(
                            (actual_idx, predicted_idx),
                            torch.ones(len(actual_idx), dtype=confusion_matrix.dtype),
                            accumulate=True,
                        )

                if total == 0:
                    raise ValueError(f"Test loader for {send_key=} yielded no samples; cannot compute accuracy")
                logtofile(f"Test the trained stitch against {send_key=} data")
                acc = (100.0 * correct) / total
                logtofile('Test Accuracy: %2.2f %%' % acc)
                logtofile('Confusion Matrix')
                logtofile(confusion_matrix)
                logtofile("===================================================================")

                # Stitching penalty should be negative if there is an improvement, and is relative to the original receiver network
                #stitching_accuracies[send_key][rcv_key][layer_to_cut_after] = acc
                #stitching_penalties[send_key][rcv_key][layer_to_cut_after] = original_accuracy[rcv_key] - acc

                # Run metadata shared by both output formats below.
                # Insertion order fixes the CSV column order.
                params = {
                    "offset": target_offset,
                    "model": model_training_type,  # a mnemonic
                    "dataset": dataset_type,
                    "seed": seed,
                    "send_file": send_val["saveas"] if send_val["train"] else send_val["loadfrom"],
                    "rcv_file": rcv_val["saveas"] if rcv_val["train"] else rcv_val["loadfrom"],
                }

                if measure_rank:
                    raise ValueError("FIXME: measure_rank won't work until dataloaders can offset labels")
                    # MEASURE RANK -- unreachable until the FIXME above is resolved.
                    params["stitch_weight_delta"] = stitch_weight_delta
                    params["stitch_bias_delta"] = stitch_bias_delta
                    params["stitch_weight_number"] = stitch_weight_number
                    params["stitch_bias_number"] = stitch_bias_number
                    with torch.no_grad():
                        layers, features, handles = install_hooks(model_stitched)
                        try:
                            metrics = evaluate_model(model_stitched, TDL, 'acc', verbose=2, device=device)
                            params.update(metrics)
                            classes = None
                            df = perform_analysis(features, classes, layers, params, n=-1)
                            df.to_csv(f"{outpath}")
                        finally:
                            # Always detach the hooks, even if evaluation/analysis fails.
                            for h in handles:
                                h.remove()
                    del model_stitched, layers, features, metrics, params, df, handles
                    gc.collect()
                else:
                    logtofile(f"output to {outpath}")
                    params["val_acc"] = acc / 100
                    params["name"] = "only"
                    # One-row CSV: this stitch's accuracy plus its run metadata.
                    df = pd.DataFrame.from_records([params])
                    df.to_csv(f"{outpath}")
                    del params, df
                    gc.collect()
            

Entering Stitch/Rank
device='cuda:2'
target_offset=0
Evaluate ranks and output to ./results_2g_a_rank/bw3bw-offset0-StitchEpochs10-bw-13_2024-08-06_12-57-58_SEED60_EPOCHS4_BGN0.1_exp2e_ResNet18_bw_mnist-test.csv
Train the stitch to a model stitched after layer 3 from bw to bw
Use the bw data loader (train and test) regardless of what bw was trained on
get_layer_output_shape for type='ResNet18'
The shape of the output from layer 3 of send_model is: torch.Size([1, 64, 7, 7])
Epoch 0, loss 142.88
Epoch 1, loss 49.64
Epoch 2, loss 39.46
Epoch 3, loss 33.95
Epoch 4, loss 31.23
Epoch 5, loss 28.93
Epoch 6, loss 27.29
Epoch 7, loss 26.08
Epoch 8, loss 24.89
Epoch 9, loss 23.84
**** Finished Training ****
Change in stitch weights: 0.9635379910469055
Largest abs weight change: 0.13469451665878296
Number of weights changing > 0.1 of that: 1280
Number of weight / bias in stitch layer is 64
Change in stitch bias: 0.027068743482232094
Largest abs bias change: 0.005700737237930298
Number of bias cha

In [None]:
# Disabled: the stitch loop does not currently record stitching_accuracies /
# stitching_penalties (those assignments are commented out there), so these
# names may be undefined in this run.
#logtofile(f"{stitching_accuracies=}")
#logtofile(f"{stitching_penalties=}")

In [None]:
# Disabled per-sender / per-receiver summary of stitch accuracies. It relies on
# stitching_accuracies and original_accuracy. stitching_accuracies is not
# populated by the stitch loop (its assignment there is commented out), so the
# code is parked inside a string literal.
# NOTE(review): as the last expression in the cell, this string literal is echoed
# as the cell's output; convert it to comments or delete it to keep output clean.
'''
for s_key in stitching_accuracies:
    logtofile(f"sender:-{s_key}")
    logtofile(original_accuracy[s_key])
    logtofile("Stitch Accuracy")
    for r_key in stitching_accuracies[s_key]:
        logtofile(f"receiver:-{r_key}")
        logtofile(stitching_accuracies[s_key][r_key])
    logtofile("--------------------------")
'''