In [None]:
import torch
import torchvision.models as models
import torch.distributed as dist
import torch.optim as optim
import scipy.stats as stats

%matplotlib inline
import matplotlib.pyplot as plt

# Checkpoint files (epoch 299) of the SVD / low-rank run and of the baseline run.
# The commented-out paths point at earlier runs and are kept for reference.
# oialr_checkpoint_folder = "/hkfs/work/workspace/scratch/qv2382-madonna/madonna/models/svd-final/12/weights/epoch_299.pt"
oialr_checkpoint_folder = "/hkfs/work/workspace/scratch/qv2382-madonna-ddp/madonna/models/svd-nodroppath/2/weights/epoch_299.pt"
# baseline_checkpoint_folder = "/hkfs/work/workspace/scratch/qv2382-madonna/madonna/models/baseline-final/8/weights/epoch_299.pt"
baseline_checkpoint_folder = "/hkfs/work/workspace/scratch/qv2382-madonna-ddp/madonna/models/baseline-nodroppath/1/weights/epoch_299.pt"


def _load_on_cpu(checkpoint_path):
    """Deserialize a checkpoint file with all tensors mapped onto the CPU."""
    return torch.load(checkpoint_path, map_location="cpu")


svd_model_dict = _load_on_cpu(oialr_checkpoint_folder)
baseline_model_dict = _load_on_cpu(baseline_checkpoint_folder)

# List the parameter names stored under the baseline checkpoint's "model" entry.
print(list(baseline_model_dict["model"]))

In [None]:
# Build one 2-D weight matrix per layer for both checkpoints so that the
# factorized (u, s, vh) model can be compared layer by layer with the baseline.
# The matrices are written into the top-level checkpoint dicts, next to the
# existing "model" entry, keyed by the "module."-prefixed names in `layer_names`.


def _with_module_prefix(key):
    """Return `key` as-is if it already starts with 'module', else prepend 'module.'."""
    return key if key.startswith("module") else f"module.{key}"


def _as_tall_matrix(p):
    """Collapse all trailing dims of `p` into one and orient it with rows > columns.

    Tensors with more than two dims are reshaped to (shape[0], -1). The result
    is transposed unless it has strictly more rows than columns, so square
    matrices are transposed as well.
    """
    w = p.reshape((p.shape[0], -1)) if p.ndim > 2 else p
    return w if w.shape[0] > w.shape[1] else w.T


layer_names = []
svd_state = svd_model_dict["model"]

for key, p in svd_state.items():
    n = _with_module_prefix(key)
    if key.endswith(".u"):
        # Factorized layer: rebuild the full-rank weight as u @ s @ vh.
        # The sibling factors are fetched with the *raw* checkpoint key, so the
        # lookup works whether or not the stored keys carry a "module." prefix.
        raw_base = key[: -len(".u")]
        name = n[: -len(".u")]
        svd_model_dict[name] = p @ svd_state[f"{raw_base}.s"] @ svd_state[f"{raw_base}.vh"]
        layer_names.append(name)
    elif key.endswith((".s", ".vh")):
        # Consumed together with the matching ".u" entry above.
        continue
    elif p.ndim > 1:
        # Non-factorized weight: keep the full parameter name as the layer name.
        svd_model_dict[n] = _as_tall_matrix(p)
        layer_names.append(n)

# Set for O(1) membership tests while matching baseline parameters to layers.
known_layers = set(layer_names)

for key, p in baseline_model_dict["model"].items():
    print(key, p.ndim)
    if p.ndim < 2:
        continue
    n = _with_module_prefix(key)
    # A baseline parameter either matches a non-factorized layer by its full
    # name, or its last name component is dropped so that it lines up with the
    # name used for a factorized layer of the SVD model.
    name = n if n in known_layers else n.rpartition(".")[0]
    print(name)
    baseline_model_dict[name] = _as_tall_matrix(p)
print(layer_names)

In [None]:
# Magnitude below which a value (or a row/column mean) is counted as "zero".
cutoff = 1e-5
target_layer = layer_names[8]

# For every layer print, first for the baseline and then for the SVD model:
# element count, shape, number of entries with |value| < cutoff, and the number
# of dim-0 / dim-1 means whose magnitude is below the cutoff.
for n in layer_names:
    print(n)
    for weights in (baseline_model_dict, svd_model_dict):
        p = weights[n]
        print(f"{n} {p.numel()}\t{tuple(p.shape)}\tvals: {(p.abs() < cutoff).sum()} dim0: {(p.mean(0).abs() < cutoff).sum()} "
              f"dim1: {(p.mean(1).abs() < cutoff).sum()}")
    print()

In [None]:
# Overlay the weight-value histograms of the baseline and the SVD model for
# layers 10 and 11. One figure per layer, stored in `figs` / `axs` by loop index.
figs, axs = {}, {}
for i, n in enumerate(layer_names[10:12]):
    print(n)
    figs[i], axs[i] = plt.subplots()

    # Histogram the baseline weights exactly as stored in the checkpoint; they
    # must not be replaced by synthetic data, or the "baseline" label would lie.
    pb = baseline_model_dict[n]
    _, bin_edges, _ = axs[i].hist(pb.flatten().detach(), bins=100, label="baseline")

    # Reuse the baseline bin edges so both distributions share the same bins.
    ps = svd_model_dict[n]
    axs[i].hist(ps.flatten().detach(), bins=bin_edges, alpha=0.3, label="svd")

    axs[i].set_title(n)
    axs[i].legend(loc="upper right")
plt.show()


In [None]:
import seaborn as sns
import numpy as np
# Draw one heatmap per layer of the SVD model's weight matrix. Entries with a
# magnitude below 1e-5 are masked, so the black axes background shows through
# at those positions. Figures and axes are kept in `figs` / `axs` by loop index.
figs, axs = {}, {}
for i, n in enumerate(layer_names):
    print(n)
    pb = baseline_model_dict[n]  # fetched alongside, but not drawn in this cell
    ps = svd_model_dict[n]

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    figs[i], axs[i] = fig, ax

    mask = ps.abs() < 1e-5
    g = sns.heatmap(ps.numpy(), mask=mask.numpy(), ax=ax)
    g.set_facecolor("black")
    ax.set_title(n)
plt.show()

In [None]:
# Inspect a single baseline layer: print its name and shape, count the dim-0
# means with magnitude below 1e-4, and plot the distribution of those means.
probe_layer = layer_names[7]
print(probe_layer)
sample = baseline_model_dict[probe_layer]
print(sample.shape)
dim0_means = sample.mean(0)
print((dim0_means.abs() < 0.0001).sum())
sns.histplot(dim0_means)
plt.show()