In [None]:
import pickle
import glob
import numpy as np
import matplotlib.pyplot as plt
import tikzplotlib

# Typeset every text element with LaTeX, using the Computer Modern serif face.
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Computer Modern Roman"]


# Ten-step RGB palette running from blue to red. It is rebound by the
# named-colour list right below, which is what `colors` finally refers to.
colors = [
    (0.2298057, 0.298717966, 0.753683153),
    (0.3634607953411765, 0.4847836818509804, 0.9010188868941177),
    (0.5108243242509803, 0.6493966148235294, 0.9850787763764707),
    (0.6672529243333334, 0.7791764569999999, 0.992959213),
    (0.8049647588235295, 0.8516661605568627, 0.9261650744313725),
    (0.9193759889058823, 0.8312727235294118, 0.7828736304470588),
    (0.968203399, 0.7208441, 0.6122929913333334),
    (0.9440545734235294, 0.5531534787490197, 0.4355484903137255),
    (0.8523781350078431, 0.34649194649411763, 0.2803464686980392),
    (0.705673158, 0.01555616, 0.150232812),
]
# Active palette: matplotlib single-letter and "tab:" named colours.
colors = [
    'b', 'r', 'k', 'g', 'm', 'c',
    'tab:brown', 'tab:orange', 'tab:pink',
    'tab:gray', 'tab:olive', 'tab:purple',
]

# Marker symbols and the shared font size used by the plots.
markers = [
    "v", "o", "^", "1", "*", ">",
    "d", "<", "s", "P", "X",
]
fontsize = 22

### Gradients for different noise levels

In [None]:
def plot_hists_together(ss_grad_diffs,sup_grad_diffs,bins,log=True,
                        ss_label='Self-sup 3.0',sup_label='Sup'):
    """Overlay two histograms on the current pyplot figure and display it.

    Both samples are drawn semi-transparently with the same `bins`, so the
    two distributions can be compared directly.

    Parameters
    ----------
    ss_grad_diffs : sequence of float
        Values drawn in green and labelled `ss_label`.
    sup_grad_diffs : sequence of float
        Values drawn in red and labelled `sup_label`.
    bins : int or sequence of float
        Passed unchanged to `plt.hist` for both samples.
    log : bool, default True
        When true, the x-axis is switched to a logarithmic scale.
    ss_label, sup_label : str
        Legend entries for the two samples.

    Returns
    -------
    None
        The figure is shown with `plt.show()`.
    """
    # The sample that is drawn first ends up underneath the second one.
    plt.hist(sup_grad_diffs, bins=bins, alpha=0.5, label=sup_label, color='red')
    plt.hist(ss_grad_diffs, bins=bins, alpha=0.5, label=ss_label, color='green')

    plt.legend(loc='upper right')
    if log:
        plt.xscale('log')
        #plt.yscale('log')

    # Add labels and a title
    plt.xlabel('Normalized Difference')
    plt.ylabel('Frequency')
    plt.title('Histogram of Stochastic Gradients')

    # Optional TikZ export; has to run before plt.show().
    #tikzplotlib.save('../../txt_files/histogram_Den_SIDD_log_logbins.tex')
    plt.show()

In [None]:
epochs = [1]

experiment_name = "E004_t10000_l2c56_bs1_lr00032_sup/"# "E001_t10000_l2c56_bs1_lr00032_sup/" 


def _load_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(path, 'rb') as stream:
        return pickle.load(stream)


def _print_grad_diff_stats(name, values):
    """Print mean, std, max and min of `values` on one line tagged with `name`."""
    print(f"Stats {name} grad diffs: mean {np.mean(values):.5}, std {np.std(values):.5}, max {np.max(values):.5}, min {np.min(values):.5}")


for epoch in epochs:
    ss_diff_30_tracks_FixCenter = _load_pickle(f"./{experiment_name}/ss_diff_tracks_ep{epoch}.pkl")
    sup_diff_tracks_FixCenter = _load_pickle(f"./{experiment_name}/sup_diff_tracks_ep{epoch}.pkl")

    # Each loaded mapping holds, under this key, an object whose `.val`
    # attribute is the sequence of values analysed below.
    ss_grad_diffs_30 = ss_diff_30_tracks_FixCenter['divide_by_norm_of_risk_grad'].val
    sup_grad_diffs = sup_diff_tracks_FixCenter['divide_by_norm_of_risk_grad'].val
    ss_grad_diffs_30.sort()
    sup_grad_diffs.sort()

    print(f"\nEpoch {epoch}:")
    _print_grad_diff_stats("ss 3.0", ss_grad_diffs_30)
    _print_grad_diff_stats("sup", sup_grad_diffs)

    # NOTE(review): `cut_off` is not read anywhere in this cell; kept only in
    # case another cell refers to it.
    cut_off = 1
    max_diff_considered = 200 #ss_grad_diffs[-50]
    print(f"\nGrad diffs larger than {max_diff_considered} are not considered in histograms")
    # Convert once, then keep only the entries that do not exceed the threshold.
    ss_values = np.asarray(ss_grad_diffs_30)
    sup_values = np.asarray(sup_grad_diffs)
    ss_grad_diffs_30_cutoff = list(ss_values[ss_values <= max_diff_considered])
    sup_grad_diffs_cutoff = list(sup_values[sup_values <= max_diff_considered])

    print("Fraction of grad diffs that remain after cutoff:")
    print(f"ss 3.0: {len(ss_grad_diffs_30_cutoff)/len(ss_grad_diffs_30)}")
    print(f"sup: {len(sup_grad_diffs_cutoff)/len(sup_grad_diffs)}")
    ss_grad_diffs_30 = ss_grad_diffs_30_cutoff
    sup_grad_diffs = sup_grad_diffs_cutoff

    print(f"\nStats after cutoff grad diffs larger than {max_diff_considered}")
    _print_grad_diff_stats("ss 3.0", ss_grad_diffs_30)
    _print_grad_diff_stats("sup", sup_grad_diffs)

    # Range shared by both histograms. Logarithmic bins need a strictly
    # positive lower edge, so the minimum is taken over positive values only;
    # any value <= 0 would lie outside the bins and not be drawn.
    all_diffs = np.concatenate([ss_grad_diffs_30, sup_grad_diffs])
    positive_diffs = all_diffs[all_diffs > 0]
    if positive_diffs.size == 0:
        raise ValueError("No positive grad diffs left after the cutoff; cannot build logarithmic bins")
    min_all = np.min(positive_diffs)
    max_all = np.max(all_diffs)

    # logarithmically sized bins
    num_bins=100
    # num_bins bins require num_bins + 1 edges.
    bins=np.logspace(np.log10(min_all), np.log10(max_all), num=num_bins + 1, endpoint=True, base=10.0)
    # 10**log10(x) may round away from x; pin the outer edges so that the
    # smallest and largest samples always fall inside the histogram range.
    bins[0] = min_all
    bins[-1] = max_all

    # Sanity output: requested bin count, number of edges, and the largest
    # value next to the last bin edge.
    print(num_bins)
    print(len(bins))
    print(max_all)
    print(bins[-1])

    log=True
    plot_hists_together(ss_grad_diffs_30,sup_grad_diffs,bins=bins,log=log)
