# Plotting results

In [None]:
import os
import csv
from datetime import datetime, timedelta
from collections import OrderedDict

import numpy as np
import plotly.graph_objects as go
os.chdir('..')
import breaching

In [None]:
fig = go.Figure()

# Number-of-bins values shared by every batch-size series in this figure.
x = [32, 64, 128, 256, 512, 1024]
curves = [
    # (legend label, line colour, IIP values per bin count)
    ('Batch size 64', 'royalblue', [46.88, 48.44, 68.75, 79.69, 89.06, 95.31]),
    ('Batch size 128', 'firebrick', [26.56, 29.69, 50.78, 70.31, 80.47, 89.06]),
    ('Batch size 256', 'rgba(34,139,34,1.0)', [14.45, 12.89, 23.83, 47.27, 72.27, 83.59]),
]
for label, colour, y in curves:
    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        name=label,
        mode='lines+markers+text',
        line=dict(width=5, dash='solid', color=colour),
        marker=dict(size=20, symbol='circle'),
        showlegend=True,
        textposition="top center",
        text='',
    ))
fig.update_traces(cliponaxis=False, textfont=dict(color='black'))

# Identical styling for both axes: black axis line, faint grid and zero line.
axis_style = dict(
    showline=True,
    showgrid=True,
    gridwidth=0.1,
    gridcolor='rgba(1,1,1,0.25)',
    linecolor='black',
    zerolinecolor='rgba(1,1,1,0.25)',
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.update_layout(
    xaxis_type="log",
    yaxis_type="linear",
    xaxis_title="<b>Number of bins</b>",
    yaxis_title="<b>IIP (pixel-based)</b>",
    font=dict(family="Computer Modern Bold", size=24),
    legend=dict(
        x=0.95,
        y=0.35,
        xanchor="right",
        yanchor="top",
        font=dict(size=28, color="black"),
        bgcolor="white",
        bordercolor="Black",
        borderwidth=1,
    ),
    uniformtext_minsize=24,
    uniformtext_mode='hide',
    # Tight margins and a transparent background.
    margin=dict(l=20, r=20, t=20, b=0),
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    width=1000,
    height=500,
)

fig.show()
fig.write_image("bins_vs_iip_standard.pdf", scale=1)

In [None]:
fig = go.Figure()

# Number-of-bins values for the fully measured series.
bins = [32, 64, 128, 256, 512, 1024]
series = [
    # (legend label, colour = batch size, marker = position in model, x, y)
    ('Batch size 64, Start of Model', 'royalblue', 'circle',
     bins, [46.88, 48.44, 68.75, 79.69, 89.06, 95.31]),
    ('Batch size 64, Middle of Model', 'royalblue', 'square',
     bins, [35.94, 34.38, 56.25, 84.38, 84.38, 90.62]),
    ('Batch size 64, Late in the Model', 'royalblue', 'pentagon',
     bins, [39.06, 31.25, 60.94, 84.38, 85.94, 82.81]),
    ('Batch size 256, Start of Model', 'firebrick', 'circle',
     bins, [14.45, 12.89, 23.83, 47.27, 72.27, 83.59]),
    # NOTE(review): this series used to list x=[32, 64, 256, 512, 1024] against only
    # four y values. Plotly pairs by position and drops the surplus, so exactly these
    # four points were plotted. Confirm which bin counts the measurements belong to.
    ('Batch size 256, Middle of Model', 'firebrick', 'square',
     [32, 64, 256, 512], [10.55, 7.81, 42.19, 59.38]),
    # NOTE(review): x used to have six entries but y only five (no value for 1024
    # bins). Plotly dropped the unmatched 1024 point; the pairing is now explicit.
    ('Batch size 256, Late in the Model', 'firebrick', 'pentagon',
     [32, 64, 128, 256, 512], [9.38, 11.33, 32.42, 53.12, 72.66]),
]
for label, colour, symbol, x, y in series:
    # Plotly silently truncates mismatched x/y arrays; fail loudly instead.
    if len(x) != len(y):
        raise ValueError(f"{label}: {len(x)} x values but {len(y)} y values")
    fig.add_trace(go.Scatter(
        y=y,
        x=x,
        name=label,
        mode='lines+markers+text',
        line=dict(width=5, dash='solid', color=colour),
        marker=dict(size=20, symbol=symbol),
        showlegend=True,
        textposition="top center",
        text='',
    ))
fig.update_traces(cliponaxis=False, textfont=dict(color='black'))

# Identical styling for both axes: black axis line, faint grid and zero line.
axis_style = dict(
    showline=True,
    showgrid=True,
    gridwidth=0.1,
    gridcolor='rgba(1,1,1,0.25)',
    linecolor='black',
    zerolinecolor='rgba(1,1,1,0.25)',
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.update_layout(
    xaxis_type="log",
    yaxis_type="linear",
    xaxis_title="<b>Number of bins</b>",
    yaxis_title="<b>IIP (pixel-based)</b>",
    font=dict(family="Computer Modern Bold", size=24),
    legend=dict(
        x=1.0,
        y=0.51,
        xanchor="right",
        yanchor="top",
        font=dict(size=22, color="black"),
        bgcolor="white",
        bordercolor="Black",
        borderwidth=1,
    ),
    uniformtext_minsize=24,
    uniformtext_mode='hide',
    # Tight margins and a transparent background.
    margin=dict(l=20, r=20, t=20, b=0),
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    width=1000,
    height=500,
)

fig.show()
fig.write_image("iip_vs_bs_vs_position.pdf", scale=1)

In [None]:
fig = go.Figure()

# Gradient-noise levels (powers of ten) shared by both series.
x = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
noise_curves = [
    # (legend label, line colour, IIP values per noise level)
    ('Batch size 256, 256 bins, Start of Model', 'royalblue',
     [46.88, 53.12, 50.39, 49.61, 16.41, 1.56]),
    ('Batch size 256, 256 bins, Middle of Model', 'firebrick',
     [44.53, 38.67, 39.06, 41.41, 17.58, 2.34]),
]
for label, colour, y in noise_curves:
    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        name=label,
        mode='lines+markers+text',
        line=dict(width=5, dash='solid', color=colour),
        marker=dict(size=20, symbol='circle'),
        showlegend=True,
        textposition="top center",
        text='',
    ))
fig.update_traces(cliponaxis=False, textfont=dict(color='black'))

# Identical styling for both axes: black axis line, faint grid and zero line.
axis_style = dict(
    showline=True,
    showgrid=True,
    gridwidth=0.1,
    gridcolor='rgba(1,1,1,0.25)',
    linecolor='black',
    zerolinecolor='rgba(1,1,1,0.25)',
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.update_layout(
    xaxis_type="log",
    yaxis_type="linear",
    xaxis_title="<b>Gradient Noise</b>",
    yaxis_title="<b>IIP (pixel-based)</b>",
    font=dict(family="Computer Modern Bold", size=24),
    legend=dict(
        x=0.6,
        y=0.25,
        xanchor="right",
        yanchor="top",
        font=dict(size=28, color="black"),
        bgcolor="white",
        bordercolor="Black",
        borderwidth=1,
    ),
    uniformtext_minsize=32,
    uniformtext_mode='hide',
    # Tight margins and a transparent background.
    margin=dict(l=20, r=20, t=20, b=0),
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    width=1000,
    height=500,
)

fig.show()
fig.write_image("gradient_noise_vs_iip.pdf", scale=1)

In [None]:
from math import comb as nCr
import numpy as np
import matplotlib.pyplot as plt

def expected_amount(k, n):
    """
    k number of bins, n batch size

    Every way of distributing n indistinguishable items over k bins
    (C(k + n - 1, k - 1) "stars and bars" configurations) is weighted equally.
    The summed weight counts, over all configurations, the bins holding exactly
    one item:
      * for i singleton bins (1 <= i <= n - 2) the remaining n - i items fill
        j further bins with at least two items each; there are
        C(k - i, j) * C(n - i - j - 1, j - 1) such placements;
      * i = n - 1 is impossible (one leftover item cannot share a bin with a
        second one), so the range stops at n - 2;
      * i = n (all singletons) contributes n * C(k, n).
    The result is that expectation minus n / k (second adjustment term in r(n,k)).
    """
    configurations = nCr(k + n - 1, k - 1)  # Total number of configs

    def _placements_of_rest(singletons):
        # Ways to put the non-singleton items into bins of size >= 2.
        rest = n - singletons
        return sum(
            nCr(k - singletons, groups) * nCr(rest - groups - 1, groups - 1)
            for groups in range(1, rest // 2 + 1)
        )

    weighted = sum(
        i * nCr(k, i) * _placements_of_rest(i) for i in range(1, n - 1)
    )
    weighted += n * nCr(k, n)  # First term in r(n,k): every item alone
    return weighted / configurations - n / k

def one_shot_guarantee(k, n):
    """
    k number of bins, n batch size

    Fraction of the C(k + n - 1, k - 1) equally weighted configurations
    (n indistinguishable items over k bins) in which one fixed bin holds
    exactly one item. Those configurations correspond to placing the other
    n - 1 items into the other k - 1 bins: C(n + k - 3, k - 2) ways.
    """
    configurations = nCr(k + n - 1, k - 1)  # Total number of configs
    favourable = nCr(n + k - 3, k - 2)
    return favourable / configurations


# Here we produce Figure 2a
# x-values are bins * params / total (plotted as the proportion of added parameters);
# params is presumably one 28x28x3 input's worth of weights per bin -- confirm.
params = 28 * 28 * 3
# total = 11689512 # resnet18
total = 25_557_032 # resnet50


def _sweep_batch_size(batch_size, max_bins=5000):
    """Return ``(param_fractions, expected_counts)`` for one batch size.

    Bin counts run over the multiples of ``batch_size`` up to ``max_bins``.
    ``param_fractions[i]`` is ``bins * params / total`` and
    ``expected_counts[i]`` is ``breaching.analysis.expected_amount(bins, batch_size)``.
    """
    bin_counts = range(batch_size, max_bins + 1, batch_size)
    param_fractions = [bins * params / total for bins in bin_counts]
    expected_counts = [
        breaching.analysis.expected_amount(bins, batch_size) for bins in bin_counts
    ]
    return param_fractions, expected_counts


# Each batch size gets its own (x, y) pair of lists; the plotting cell reads these
# globals. The 1024 y-values must land in proportion_1024, not param_count_1024.
param_count, proportion = _sweep_batch_size(256)
param_count_64, proportion_64 = _sweep_batch_size(64)
param_count_512, proportion_512 = _sweep_batch_size(512)
param_count_1024, proportion_1024 = _sweep_batch_size(1024)

In [None]:
fig = go.Figure()

# Default colours are assigned in trace order, so the order 256, 512, 64, 1024 is kept.
sweeps = [
    ('Batch Size 256', param_count, proportion),
    ('Batch Size 512', param_count_512, proportion_512),
    ('Batch Size 64', param_count_64, proportion_64),
    ('Batch Size 1024', param_count_1024, proportion_1024),
]
for label, xs, ys in sweeps:
    fig.add_trace(go.Scatter(
        x=xs,
        y=ys,
        name=label,
        mode='lines+markers',
        line=dict(width=5, dash='solid'),
        marker=dict(size=20, symbol='circle'),
        showlegend=True,
        textposition="top center",
        text='',
    ))
fig.update_traces(cliponaxis=False, textfont=dict(color='black'))

# Identical styling for both axes: black axis line, faint grid and zero line.
axis_style = dict(
    showline=True,
    showgrid=True,
    gridwidth=0.1,
    gridcolor='rgba(1,1,1,0.25)',
    linecolor='black',
    zerolinecolor='rgba(1,1,1,0.25)',
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.update_layout(
    xaxis_type="log",
    yaxis_type="linear",
    xaxis_title="<b>Proportion of added Parameters</b>",
    yaxis_title="<b>Expected Data Points Reconstructed</b>",
    font=dict(family="Computer Modern Bold", size=24),
    legend=dict(
        x=0.25,
        y=1,
        xanchor="right",
        yanchor="top",
        font=dict(size=28, color="black"),
        bgcolor="white",
        bordercolor="Black",
        borderwidth=1,
    ),
    uniformtext_minsize=24,
    uniformtext_mode='hide',
    # Tight margins and a transparent background.
    margin=dict(l=20, r=20, t=20, b=0),
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    width=1000,
    height=500,
    template='seaborn',
)

fig.show()

In [None]:
# Export the figure held in the notebook-global ``fig`` (built by the preceding plot cell) to PDF.
fig.write_image("expected_reconstruction_detailed.pdf", scale=1)

In [None]:
fig = go.Figure()

# Fraction of a batch expected to be reconstructed vs. the raw number of bins.
# Dedicated names are used so the param_count/proportion globals that feed the
# detailed Figure 2a plot are not overwritten by this cell.
batch_size = 64
bin_counts = list(range(65, 500))
fraction_reconstructed = [
    breaching.analysis.expected_amount(bins, batch_size) / batch_size
    for bins in bin_counts
]

fig.add_trace(go.Scatter(
    y=fraction_reconstructed,
    x=bin_counts,
    name=f'Batch Size {batch_size}',  # label must match the data actually swept
    mode='lines',
    line=dict(width=5, dash='solid'),
    marker=dict(size=20, symbol='circle'),
    showlegend=False,
    textposition="top center",
    text='',
))
fig.update_traces(cliponaxis=False, textfont=dict(color='black'))

# Identical styling for both axes: black axis line, faint grid and zero line.
axis_style = dict(
    showline=True,
    showgrid=True,
    gridwidth=0.1,
    gridcolor='rgba(1,1,1,0.25)',
    linecolor='black',
    zerolinecolor='rgba(1,1,1,0.25)',
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.update_layout(
    xaxis_type="linear",
    yaxis_type="linear",
    xaxis_title="<b>Number of Bins</b>",
    yaxis_title="<b>Proportion of Data Reconstructed</b>",
    font=dict(family="Computer Modern Bold", size=24),
    legend=dict(
        x=0.25,
        y=1,
        xanchor="right",
        yanchor="top",
        font=dict(size=28, color="black"),
        bgcolor="white",
        bordercolor="Black",
        borderwidth=1,
    ),
    uniformtext_minsize=24,
    uniformtext_mode='hide',
    # Tight margins and a transparent background.
    margin=dict(l=20, r=20, t=20, b=0),
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    width=1000,
    height=500,
    template='seaborn',
)

fig.show()
fig.write_image("expected_reconstruction.pdf", scale=1)

In [None]:


# Here we produce Figure 2b
n = 4096  # number of datapoints

# Bin counts to sweep: some range including n. Might need a step smaller than
# 100 for smaller n.
bin_counts = range(600, 24000, 100)
param_count = [1 / bins for bins in bin_counts]
proportion = [one_shot_guarantee(bins, n) for bins in bin_counts]

# Report the best one-shot probability and the 1/bins value that attains it.
best = np.argmax(np.array(proportion))
print(proportion[best])
print(param_count[best])

In [None]:
fig = go.Figure()
fig.add_trace(go.Scatter(
    y=proportion,
    x=param_count,
    # Hover label: the batch size ``n`` used by the one-shot sweep cell that
    # produced ``proportion``/``param_count`` (the old 'Batch Size 128' was wrong).
    name=f'Batch Size {n}',
    mode='lines+markers',
    line=dict(width=5, dash='solid'),
    marker=dict(size=20, symbol='circle'),
    showlegend=False,
    textposition="top center",
    text='',
))
fig.update_traces(cliponaxis=False, textfont=dict(color='black'))

# Identical styling for both axes: black axis line, faint grid and zero line.
axis_style = dict(
    showline=True,
    showgrid=True,
    gridwidth=0.1,
    gridcolor='rgba(1,1,1,0.25)',
    linecolor='black',
    zerolinecolor='rgba(1,1,1,0.25)',
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.update_layout(
    xaxis_type="linear",
    yaxis_type="linear",
    xaxis_title="<b>Mass captured in One-Shot Bin</b>",
    yaxis_title="<b>Probability of One-Shot Success</b>",
    font=dict(family="Computer Modern Bold", size=24),
    legend=dict(
        x=0.25,
        y=1,
        xanchor="right",
        yanchor="top",
        font=dict(size=28, color="black"),
        bgcolor="white",
        bordercolor="Black",
        borderwidth=1,
    ),
    uniformtext_minsize=24,
    uniformtext_mode='hide',
    # Tight margins and a transparent background.
    margin=dict(l=20, r=20, t=20, b=0),
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    width=1000,
    height=500,
    template='seaborn',
)

fig.show()
fig.write_image("oneshot_success.pdf", scale=1)