In [42]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from matplotlib import pyplot as plt
from dreamer_init_weights import *

initial_model_path = "imagenet_pretrained.ckpt"
early_model = "autoencoder_transfer_early.ckpt"
later_model = "autoencoder_transfer_late.ckpt"


initial_model_state = torch.load(initial_model_path, map_location="cpu", weights_only=False)
early_model_state = torch.load(early_model, map_location="cpu", weights_only=False)
late_model_state = torch.load(later_model, map_location="cpu", weights_only=False)
world_model_enc_keys = [key for key in initial_model_state["world_model"].keys() if "encoder" in key]
world_model_dec_keys = [key for key in initial_model_state["world_model"].keys() if "decoder" in key]

In [43]:
# print("Encoder Difference")
# for i in range(len(world_model_enc_keys)):
#     if initial_model_state["world_model"][world_model_enc_keys[i]].dim() == 1:
#         print(f'Layernorm {"weight" if "weight" in world_model_enc_keys[i] else "bias"}:', end=' ')
#         print(F.mse_loss(initial_model_state["world_model"][world_model_enc_keys[i]], early_model_state["world_model"][world_model_enc_keys[i]]), end= ' ')
#         print(F.mse_loss(initial_model_state["world_model"][world_model_enc_keys[i]], late_model_state["world_model"][world_model_enc_keys[i]]), end=' ')
#     else:
#         print("Conv:", end=' ')
#         print(F.cosine_similarity(initial_model_state["world_model"][world_model_enc_keys[i]], early_model_state["world_model"][world_model_enc_keys[i]]).mean(), end= ' ')
#         print(F.cosine_similarity(initial_model_state["world_model"][world_model_enc_keys[i]], late_model_state["world_model"][world_model_enc_keys[i]]).mean(), end=' ')
#     print()
# print("Decoder Difference")
# for i in range(len(world_model_dec_keys)):
#     if initial_model_state["world_model"][world_model_dec_keys[i]].dim() == 1:
#         print(f'Layernorm {"weight" if "weight" in world_model_dec_keys[i] else "bias"}:', end=' ')
#         print(F.mse_loss(initial_model_state["world_model"][world_model_dec_keys[i]], early_model_state["world_model"][world_model_dec_keys[i]]), end= ' ')
#         print(F.mse_loss(initial_model_state["world_model"][world_model_dec_keys[i]], late_model_state["world_model"][world_model_dec_keys[i]]), end=' ')
#     else:
#         print("Conv:", end=' ')
#         print(F.cosine_similarity(initial_model_state["world_model"][world_model_dec_keys[i]], early_model_state["world_model"][world_model_dec_keys[i]]).mean(), end= ' ')
#         print(F.cosine_similarity(initial_model_state["world_model"][world_model_dec_keys[i]], late_model_state["world_model"][world_model_dec_keys[i]]).mean(), end=' ')
#     print()

In [44]:
import torch.nn.functional as F

def print_layer_diffs_table(keys, initial_state, early_state, late_state, name="Model"):
    # generated by chatgpt just for formatting purposes
    header = ["Layer", "Type", "Early Diff", "Late Diff"]
    print(f"\n{name} Differences")
    print("-" * 60)
    print(f"{header[0]:<30} {header[1]:<8} {header[2]:>12} {header[3]:>12}")
    print("-" * 60)
    
    for key in keys:
        layer = key.split('.')[-1][:30]  # truncate to max 30 chars
        if initial_state[key].dim() == 1:
            diff_type = "LN"
            early_diff = F.mse_loss(initial_state[key], early_state[key]).item()
            late_diff = F.mse_loss(initial_state[key], late_state[key]).item()
        else:
            diff_type = "Conv"
            early_diff = F.cosine_similarity(initial_state[key], early_state[key]).mean().item()
            late_diff = F.cosine_similarity(initial_state[key], late_state[key]).mean().item()
        
        print(f"{layer:<30} {diff_type:<8} {early_diff:12.6f} {late_diff:12.6f}")

    print("-" * 60)

print_layer_diffs_table(world_model_enc_keys, initial_model_state["world_model"], early_model_state["world_model"], late_model_state["world_model"], name="Encoder")
print_layer_diffs_table(world_model_dec_keys, initial_model_state["world_model"], early_model_state["world_model"], late_model_state["world_model"], name="Decoder")



Encoder Differences
------------------------------------------------------------
Layer                          Type       Early Diff    Late Diff
------------------------------------------------------------
weight                         Conv         0.996398     0.862410
weight                         LN           0.000015     0.012105
bias                           LN           0.000005     0.014555
weight                         Conv         0.993883     0.667888
weight                         LN           0.000018     0.015164
bias                           LN           0.000003     0.011438
weight                         Conv         0.988533     0.605790
weight                         LN           0.000031     0.013224
bias                           LN           0.000004     0.001908
weight                         Conv         0.989789     0.646106
weight                         LN           0.000036     0.029807
bias                           LN           0.000002     0.000063

In [45]:
initial_model_path = "imagenet_enc_only.ckpt"
early_model = "encoder_transfer_early.ckpt"
later_model = "encoder_transfer_late.ckpt"


initial_model_state = torch.load(initial_model_path, map_location="cpu", weights_only=False)
early_model_state = torch.load(early_model, map_location="cpu", weights_only=False)
late_model_state = torch.load(later_model, map_location="cpu", weights_only=False)
world_model_enc_keys = [key for key in initial_model_state["world_model"].keys() if "encoder" in key]
world_model_dec_keys = [key for key in initial_model_state["world_model"].keys() if "decoder" in key]

In [46]:
# print("encoder MSE")
# for i in range(len(world_model_enc_keys)):
#     if initial_model_state["world_model"][world_model_enc_keys[i]].dim() == 1:
#         continue
#     print(F.cosine_similarity(initial_model_state["world_model"][world_model_enc_keys[i]], early_model_state["world_model"][world_model_enc_keys[i]]).mean(), end= ' ')
#     print(F.cosine_similarity(initial_model_state["world_model"][world_model_enc_keys[i]], late_model_state["world_model"][world_model_enc_keys[i]]).mean(), end=' ')
#     scale1 = torch.mean(torch.abs(initial_model_state["world_model"][world_model_enc_keys[i]]))
#     scale2 = torch.mean(torch.abs(early_model_state["world_model"][world_model_enc_keys[i]]))
#     scale3 = torch.mean(torch.abs(late_model_state["world_model"][world_model_enc_keys[i]]))
#     print(f"scale: {scale1:.4f}, {scale2:.4f}, {scale3:.4f}")
# print("decoder MSE")
# for i in range(len(world_model_dec_keys)):
#     if initial_model_state["world_model"][world_model_dec_keys[i]].dim() == 1:
#         continue
#     print(F.cosine_similarity(initial_model_state["world_model"][world_model_dec_keys[i]], early_model_state["world_model"][world_model_dec_keys[i]]).mean(), end= ' ')
#     print(F.cosine_similarity(initial_model_state["world_model"][world_model_dec_keys[i]], late_model_state["world_model"][world_model_dec_keys[i]]).mean(), end=' ')
#     scale1 = torch.mean(torch.abs(initial_model_state["world_model"][world_model_dec_keys[i]]))
#     scale2 = torch.mean(torch.abs(early_model_state["world_model"][world_model_dec_keys[i]]))
#     scale3 = torch.mean(torch.abs(late_model_state["world_model"][world_model_dec_keys[i]]))
#     print(f"scale: {scale1:.4f}, {scale2:.4f}, {scale3:.4f}")


print_layer_diffs_table(world_model_enc_keys, initial_model_state["world_model"], early_model_state["world_model"], late_model_state["world_model"], name="Encoder")
print_layer_diffs_table(world_model_dec_keys, initial_model_state["world_model"], early_model_state["world_model"], late_model_state["world_model"], name="Decoder")


Encoder Differences
------------------------------------------------------------
Layer                          Type       Early Diff    Late Diff
------------------------------------------------------------
weight                         Conv         0.997452     0.867836
weight                         LN           0.000010     0.010277
bias                           LN           0.000009     0.011791
weight                         Conv         0.992937     0.658117
weight                         LN           0.000013     0.015325
bias                           LN           0.000003     0.020321
weight                         Conv         0.988854     0.606595
weight                         LN           0.000024     0.010294
bias                           LN           0.000003     0.002030
weight                         Conv         0.987292     0.599645
weight                         LN           0.000023     0.017456
bias                           LN           0.000011     0.007431