In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# the project's `src` package take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import numpy as np
import torch

from torchvision import datasets, transforms
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, RandomAffine, RandomHorizontalFlip

# Get data

In [None]:
from src.data.datasets_class import SplitAndAugmentDataset
# Tag used in the (currently commented-out) output file names further down.
DATASET_NAME = 'dual_cifar10'

# CIFAR-10 root directory; a KeyError here means CIFAR10_PATH is not set.
dataset_path = os.environ['CIFAR10_PATH']

# Two separate but identical pipelines (tensor conversion only), one per view.
transform1, transform2 = (Compose([ToTensor()]) for _ in range(2))

In [None]:
# CIFAR-10 training split, downloaded into `dataset_path` if it is not there yet.
train_dataset = datasets.CIFAR10(
    dataset_path,
    train=True,
    download=True,
)

# Wrap it so that every item yields a (left, right) pair of views.
# NOTE(review): is_train=False is passed although this is the training split -- confirm.
train_dual_augment_dataset = SplitAndAugmentDataset(
    train_dataset,
    transform1,
    transform2,
    overlap=0.0,
    is_train=False,
)

In [None]:
def get_held_out_data(dataset, nb_samples=1000):
    """Draw a class-balanced random subset and return its two views as tensors.

    Args:
        dataset: Indexable wrapper exposing the labels as ``dataset.dataset.targets``
            and returning ``((left, right), label)`` for ``dataset[i]``.
        nb_samples: Total number of samples requested. ``nb_samples // n_classes``
            samples are drawn (without replacement) from every class, so the
            actual count is rounded down to a multiple of the class count.

    Returns:
        ``(x_data_left, x_data_right)``: the stacked left and right views, ordered
        class by class.

    Raises:
        ValueError: if there are no labels, if ``nb_samples`` is smaller than the
            number of classes, or (from ``np.random.choice``) if a class holds
            fewer samples than requested per class.

    Side effects:
        Creates the ``data`` directory in the working directory (the notebook's
        commented-out ``torch.save`` calls write there) and consumes NumPy's
        global random state.
    """
    labels = np.asarray(dataset.dataset.targets)
    # Use the label values that actually occur rather than assuming 0..C-1.
    classes = np.unique(labels)
    if len(classes) == 0:
        raise ValueError('dataset has no labels to sample from')

    per_class = nb_samples // len(classes)
    if per_class < 1:
        raise ValueError(
            f'nb_samples={nb_samples} is smaller than the number of classes ({len(classes)})'
        )

    picked = np.concatenate([
        np.random.choice(np.where(labels == cls)[0], size=per_class, replace=False)
        for cls in classes
    ])

    # Each item is ((left, right), label); only the two views are needed here.
    views = [dataset[i][0] for i in picked]
    lefts, rights = zip(*views)

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('data', exist_ok=True)

    return torch.stack(lefts), torch.stack(rights)

In [None]:
# Balanced held-out subset: 50 samples in total, split evenly across the classes.
x_data_left, x_data_right = get_held_out_data(
    train_dual_augment_dataset,
    nb_samples=50,
)

In [None]:
from torchvision.transforms import InterpolationMode
from src.data.transforms import SIDE_MAP_PROPER

# Individual augmentation ops. Only the rotation and the affine shift are used by
# transform_aug below; colour jitter and random erasing are defined but left out.
color_jitter_transform = transforms.ColorJitter(
    brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1,
)
rotation_transform = transforms.RandomRotation(
    degrees=10, interpolation=InterpolationMode.BILINEAR,
)
random_affine = RandomAffine(degrees=0, translate=(1/8, 1/8))
random_erasing_transform = transforms.RandomErasing(
    p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=True,
)

DATASET_NAME = 'dual_cifar10'


def transform_aug(side):
    """Build the augmenting pipeline for `side`: rotation, shift, then normalisation."""
    # presumably SIDE_MAP_PROPER[side][0.0] holds (mean, std) for overlap 0.0 -- confirm
    normalize = Normalize(*SIDE_MAP_PROPER[side][0.0])
    return transforms.Compose([rotation_transform, random_affine, normalize])


# Alternative augmentation policy; not referenced by the cells visible here.
transform_aug2 = transforms.TrivialAugmentWide(interpolation=InterpolationMode.BILINEAR)


def transform_proper(side):
    """Build the non-augmenting pipeline for `side`: normalisation only."""
    return Compose([Normalize(*SIDE_MAP_PROPER[side][0.0])])

import matplotlib.pyplot as plt
def show(img, figsize=(3,3)):
    """Display a channels-first image tensor in a new matplotlib figure.

    Args:
        img: Tensor laid out as (C, H, W). It may track gradients or live on a
            GPU; it is detached and copied to the CPU before conversion.
        figsize: Figure size handed to ``plt.figure``.
    """
    # Tensor.numpy() refuses tensors that require grad or sit on a CUDA device.
    npimg = img.detach().cpu().numpy()
    plt.figure(figsize=figsize)
    # imshow expects channels last, hence the (C, H, W) -> (H, W, C) transpose.
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

In [None]:
# Preview the normalised (non-augmented) left and right views of the first
# held-out sample. Written as two statements rather than one tuple expression so
# the cell does not echo a meaningless `(None, None)` as its output.
show(transform_proper('left')(x_data_left[0]))
show(transform_proper('right')(x_data_right[0]))

In [None]:
# Number of randomly augmented copies generated per sample and per side.
N_AUGMENTED = 100

# Build each pipeline once instead of once per appended tensor.
proper_left, proper_right = transform_proper('left'), transform_proper('right')
aug_left, aug_right = transform_aug('left'), transform_aug('right')

# Every held-out sample contributes one group of 2 * N_AUGMENTED + 2 = 202 rows:
#   N_AUGMENTED x (clean left, augmented right)
#   2           x (clean left, clean right)
#   N_AUGMENTED x (augmented left, clean right)
# The hooks callback configured further down uses group_size = 202, so keep the
# two values in sync when changing N_AUGMENTED.
left, right = [], []
# Iterate over whatever was drawn above rather than a hard-coded sample count.
for x_left, x_right in zip(x_data_left, x_data_right):
    # The clean views are deterministic, so compute them once per sample.
    clean_left, clean_right = proper_left(x_left), proper_right(x_right)

    for _ in range(N_AUGMENTED):
        left.append(clean_left)
        right.append(aug_right(x_right))

    for _ in range(2):
        left.append(clean_left)
        right.append(clean_right)

    for _ in range(N_AUGMENTED):
        left.append(aug_left(x_left))
        right.append(clean_right)

left = torch.stack(left)
right = torch.stack(right)

# torch.save(left, f'data/{DATASET_NAME}_held_out_rsv_x_left.pt')
# torch.save(right, f'data/{DATASET_NAME}_held_out_rsv_x_right.pt')

In [None]:
import torchvision

# Tile each stacked batch into a 16-column image grid: left views first, then right.
for batch in (left, right):
    grid = torchvision.utils.make_grid(batch, nrow=16)
    show(grid, figsize=(10, 10))


# Model

In [None]:
# The notebook runs on the CPU. The former CUDA auto-detection was a dead store:
# it was overwritten by this assignment on the very next line, and neither the
# model nor the inputs are moved to `device` in the cells below.
device = torch.device('cpu')
device

In [None]:
from src.utils.utils_trainer import load_model
from src.utils.prepare import prepare_model
from src.modules.hooks import Hooks

# Architecture settings handed to the project's prepare_model(); they are repeated
# verbatim inside get_data() later in the notebook.
model_config = {'backbone_type': 'resnet18',
                'only_features': False,
                'batchnorm_layers': True,
                'width_scale': 1.0,
                'skips': True,
                'modify_resnet': True,
                'wheter_concate': False,
                'overlap': 0.0,}
model_params = {'model_config': model_config, 'num_classes': 10, 'dataset_name': 'dual_cifar10'}

model = prepare_model('mm_resnet', model_params=model_params)

# Checkpoint of the phase1=200 run at epoch 200 (the same file as `path4` further down).
# NOTE(review): absolute path on a shared mount -- the notebook is not portable as is.
model = load_model(model, '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=200 and phase2=0 and phase3=0/2023-09-08_03-17-16/checkpoints/model_step_epoch_200.pth')

# Alternative checkpoints that were tried; swap one in to inspect a different run.
# model = load_model(model, '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.55_wd_0.0 overlap=0.0, phase2, trained with phase1=0/2023-09-06_21-48-17/checkpoints/model_step_epoch_200.pth')
# model = load_model(model, '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0_lr_lambda_1.0 overlap=0.0, phase1, tak naprawdę 2, phase1=0/2023-09-11_18-39-51/checkpoints/model_step_epoch_200.pth')
# model = load_model(model, '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=80 and phase2=0 and phase3=0/2023-09-07_19-44-48/checkpoints/model_step_epoch_200.pth')
# model = load_model(model, '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=80 and phase2=0 and phase3=40/2023-09-07_19-52-53/checkpoints/model_step_epoch_200.pth')
# model = load_model(model, '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=80 and phase2=0 and phase3=38/2023-09-11_10-23-35/checkpoints/model_step_epoch_160.pth')
# model = load_model(model, '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=80 and phase2=0 and phase3=37/2023-09-11_00-25-08/checkpoints/model_step_epoch_200.pth', device=device)

# Attach the project's 'rsv' hook callback to model.net3 for its Conv2d and Linear
# modules. The hooks must be registered and enabled before the forward pass below.
hooks_rsv = Hooks(model.net3, logger=None, callback_type='rsv')
hooks_rsv.register_hooks(model.net3, [torch.nn.Conv2d, torch.nn.Linear])
hooks_rsv.enable()
# 202 = 100 + 2 + 100, the number of rows built per held-out sample in the batch
# construction cell above -- presumably the callback processes the batch in
# groups of that size; confirm against the Hooks implementation.
hooks_rsv.callback.group_size = 202

# Single forward pass over the whole stacked batch with both branches enabled and
# no interventions; the output `y` is not used by the cells visible here.
# NOTE(review): runs without torch.no_grad() and without an explicit model.eval()
# -- confirm whether load_model already puts the model into the intended mode.
y = model(left, right,
          left_branch_intervention=None,
          right_branch_intervention=None,
          enable_left_branch=True,
          enable_right_branch=True)

# Collect what the hooks recorded, then clear their internal state.
data = hooks_rsv.gather_data()
hooks_rsv.reset()
len(data)



In [None]:
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Flatten everything gathered for the first hooked entry into one 1-D NumPy array.
first_entry = torch.stack(data[0])
data1 = first_entry.reshape(-1).detach().cpu().numpy()

# Location statistics used by the histogram in the next cell.
mean = data1.mean()
median = np.median(data1)

# z = 2.576 is the two-sided 99% normal quantile, so despite its name `std_dev`
# holds the half-width of a 99% confidence interval for the mean
# (z * sample std / sqrt(n)), not a standard deviation.
z = 2.576
n_values = len(data1)
std_dev = z * data1.std(ddof=1) / np.sqrt(n_values)

(mean, std_dev)

In [None]:

# Histogram of the flattened values with the mean and the median drawn as
# vertical lines.
bar_marker = dict(
    color='beige',
    line=dict(color='black', width=.2),  # thin outline between neighbouring bins
)

fig = go.Figure()
fig.add_trace(go.Histogram(
    x=data1,
    xbins=dict(start=-1, end=1),
    nbinsx=200,
    showlegend=False,
    marker=bar_marker,
))

# Height of the marker lines: tallest bin of a 200-bin NumPy histogram plus 5%.
# NOTE(review): NumPy bins over the data range while the plotted bins span
# [-1, 1], so this only approximates the tallest plotted bar.
y_max = max(np.histogram(data1, bins=200)[0]) * 1.05
for label, value, color in (('Mean', mean, 'chocolate'), ('Median', median, 'darkorange')):
    fig.add_trace(go.Scatter(
        x=[value, value],
        y=[0, y_max],
        mode='lines',
        name=f'{label}={value:.3f}',
        line=dict(color=color),
    ))

# Horizontal legend above the plot; its title doubles as the figure caption.
legend_style = dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="right",
    x=1,
    title_text="Phase 1 lasted 200 epochs",
    title_font=dict(size=20),
)
fig.update_layout(
    width=600,
    height=500,
    yaxis=dict(showticklabels=False),
    legend=legend_style,
)
fig.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt


# 2 x 5 grid of histograms, one per entry for the first ten gathered entries.
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(15, 5))

# axes.flat walks the grid row by row, matching entries 0..9.
for i, ax in enumerate(axes.flat):
    values = torch.stack(data[i]).reshape(-1).detach().cpu().numpy()
    sns.histplot(values, bins=100, ax=ax)
    ax.set_title(f'Histogram dla listy {i + 1}')

plt.tight_layout()
plt.show()


In [None]:
from src.utils.utils_trainer import load_model
from src.utils.prepare import prepare_model
from src.modules.hooks import Hooks

def get_data(model_path, x_data_left, x_data_right, group_size=202, layer_index=0):
    """Load a checkpoint, run one forward pass and return the hooked 'rsv' values.

    Args:
        model_path: Path of the checkpoint handed to ``load_model``.
        x_data_left: Batch fed to the model as its first (left) input.
        x_data_right: Batch fed to the model as its second (right) input.
        group_size: Value assigned to ``hooks_rsv.callback.group_size``. The
            default 202 equals the 100 + 2 + 100 rows the notebook builds per
            held-out sample; pass another value when the batch layout changes.
        layer_index: Which entry of the gathered hook data to return. The
            default 0 keeps the original behaviour of using the first entry.

    Returns:
        1-D NumPy array with the flattened values of the selected entry.
    """
    model_config = {
        'backbone_type': 'resnet18',
        'only_features': False,
        'batchnorm_layers': True,
        'width_scale': 1.0,
        'skips': True,
        'modify_resnet': True,
        'wheter_concate': False,
        'overlap': 0.0,
    }
    model_params = {'model_config': model_config, 'num_classes': 10, 'dataset_name': 'dual_cifar10'}
    model = prepare_model('mm_resnet', model_params=model_params)
    model = load_model(model, model_path)

    # Hooks must be registered and enabled before the forward pass.
    hooks_rsv = Hooks(model.net3, logger=None, callback_type='rsv')
    hooks_rsv.register_hooks(model.net3, [torch.nn.Conv2d, torch.nn.Linear])
    hooks_rsv.enable()
    hooks_rsv.callback.group_size = group_size

    # Only the hook side effects matter; the model output itself is discarded.
    # NOTE(review): runs without torch.no_grad() / model.eval() -- confirm that
    # load_model leaves the model in the intended mode.
    _ = model(x_data_left, x_data_right,
              left_branch_intervention=None,
              right_branch_intervention=None,
              enable_left_branch=True,
              enable_right_branch=True)

    gathered = hooks_rsv.gather_data()
    values = torch.stack(gathered[layer_index]).reshape(-1).detach().cpu().numpy()
    hooks_rsv.reset()
    return values

# Checkpoints compared below, all saved at epoch 200. The run directory names
# encode the schedule ("trained with phase1=.. and phase2=.. and phase3=..").
# NOTE(review): absolute paths on a shared mount -- consider a configurable root.

# lr 0.6 runs with phase1 = 0 / 40 / 120 / 200 and phase2 = phase3 = 0.
path1 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=0 and phase2=0 and phase3=0/2023-09-07_14-22-54/checkpoints/model_step_epoch_200.pth'
path2 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=40 and phase2=0 and phase3=0/2023-09-07_12-03-42/checkpoints/model_step_epoch_200.pth'
path3 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=120 and phase2=0 and phase3=0/2023-09-07_22-35-39/checkpoints/model_step_epoch_200.pth'
path4 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=200 and phase2=0 and phase3=0/2023-09-08_03-17-16/checkpoints/model_step_epoch_200.pth'


# lr 0.55 runs with phase2 = 200 and the shorter phase3 for each phase1 length.
path5 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.55_wd_0.0 overlap=0.0, phase4, trained with phase1=40 and phase2=200 and phase3=60/2023-09-08_19-27-35/checkpoints/model_step_epoch_200.pth'
path6 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.55_wd_0.0 overlap=0.0, phase4, trained with phase1=120 and phase2=200 and phase3=80/2023-09-09_02-40-19/checkpoints/model_step_epoch_200.pth'
path7 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.55_wd_0.0 overlap=0.0, phase4, trained with phase1=200 and phase2=200 and phase3=80/2023-09-09_09-41-02/checkpoints/model_step_epoch_200.pth'

# lr 0.55 runs with phase2 = 200 and the longer phase3 for each phase1 length.
# NOTE(review): path10 carries the same run timestamp (2023-09-09_09-41-02) as
# path7 although the phase3 values differ -- verify both directories exist.
path8 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.55_wd_0.0 overlap=0.0, phase4, trained with phase1=40 and phase2=200 and phase3=80/2023-09-08_19-53-33/checkpoints/model_step_epoch_200.pth'
path9 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.55_wd_0.0 overlap=0.0, phase4, trained with phase1=120 and phase2=200 and phase3=120/2023-09-09_02-49-34/checkpoints/model_step_epoch_200.pth'
path10 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.55_wd_0.0 overlap=0.0, phase4, trained with phase1=200 and phase2=200 and phase3=120/2023-09-09_09-41-02/checkpoints/model_step_epoch_200.pth'

# Earlier lr 0.6 / phase2 = 0 alternatives for path5..path7 (disabled).
# path5 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=40 and phase2=0 and phase3=20/2023-09-07_12-03-54/checkpoints/model_step_epoch_200.pth'
# path6 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=120 and phase2=0 and phase3=20/2023-09-07_22-35-42/checkpoints/model_step_epoch_200.pth'
# path7 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=200 and phase2=0 and phase3=40/2023-09-08_03-22-35/checkpoints/model_step_epoch_200.pth'

# NOTE(review): path11 is a phase1=80 / phase3=38 run, yet the grid figure below
# shows it under the title 'Phase 1 = 40 epochs, Phase 3 = 73 epochs' -- confirm
# which checkpoint was intended.
path11 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=80 and phase2=0 and phase3=38/2023-09-11_10-23-35/checkpoints/model_step_epoch_200.pth'
# path12 = '/raid/NFS_SHARE/home/b.krzepkowski/Github/CLPInterventions/reports2/just_run, sgd, dual_cifar10, mm_resnet_fp_0.0_lr_0.6_wd_0.0 overlap=0.0, phase4, trained with phase1=80 and phase2=0 and phase3=38/2023-09-11_10-23-35/checkpoints/model_step_epoch_200.pth'

# One flattened value array per checkpoint, in the order the figure below expects.
# Note: this rebinds `data`, which until now held the raw hook output of the
# single-model cell above.
paths = [
    path1, path2, path3, path4,
    path5, path6, path7,
    path8, path9, path10,
    path11,
]
data = [get_data(path, left, right) for path in paths]

In [None]:
# NOTE(review): this mixes 20000 synthetic values (Gaussian, std 0.2, shifted by
# -0.21, clipped to [-1, 1]) into the measured values of the first checkpoint.
# `data0` is only referenced by the disabled lines below; make sure it never
# ends up in a reported figure. The draw is unseeded.
noise = np.random.randn(20000) * 2e-1 - 0.21
data0 = np.concatenate([data[0], np.clip(noise, a_min=-1, a_max=1)])
# data[-1] = data0
# data.append(data0)
# plt.hist(np.clip(np.random.randn(100)*1e-1 - 0.03, a_min=-1, a_max=1), bins=100)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

# Subplot titles in row-major order for the 4 x 4 grid; '' marks a cell without a
# title. Every entry is comma-terminated on purpose: a missing comma makes Python
# silently concatenate two neighbouring string literals.
subplot_titles = [
    'Phase 1 = 0 epochs', 'Phase 1 = 40 epochs', 'Phase 1 = 120 epochs', 'Phase 1 = 200 epochs',
    '', 'Phase 1 = 40 epochs, Phase 3 = 60 epochs', 'Phase 1 = 120 epochs, Phase 3 = 80 epochs', 'Phase 1 = 200 epochs, Phase 3 = 80 epochs',
    '', 'Phase 1 = 40 epochs, Phase 3 = 80 epochs', 'Phase 1 = 120 epochs, Phase 3 = 120 epochs', 'Phase 1 = 200 epochs, Phase 3 = 120 epochs',
    '', 'Phase 1 = 40 epochs, Phase 3 = 73 epochs', 'Phase 1 = 120 epochs, Phase 3 = 96 epochs', '',
]

# Short schedule labels for the legend, in plotting order
# ("<phase1>" or "<phase1>-<phase3>").
# NOTE(review): the last rows of labels/titles do not obviously match the loaded
# checkpoints (e.g. the 11th checkpoint is a phase1=80 / phase3=38 run) -- confirm.
phase_text = [
    '0', '40', '120', '200',
    '40-60', '120-80', '200-80',
    '40-80', '120-120', '200-120',
    '40-73', '120-96', '200-40',
]

fig = make_subplots(rows=4, cols=4, subplot_titles=subplot_titles)

font_size = 22.5
for annotation in fig['layout']['annotations']:
    annotation['font'] = dict(size=font_size)

bar_marker = dict(
    color='beige',
    line=dict(color='black', width=.2),  # thin outline between neighbouring bins
)

# Data sets are consumed in row-major order over the cells that get a plot.
series = iter(zip(data, phase_text))

for r in range(1, 5):
    for c in range(1, 5):
        # The first column is only used in the first row, and the bottom-right
        # cell always stays empty.
        if (c == 1 and r != 1) or (c == 4 and r == 4):
            continue

        entry = next(series, None)
        if entry is None:
            # Fewer loaded checkpoints than plot cells: leave the cell blank
            # instead of failing with an IndexError.
            continue
        values, phase = entry

        mean_val = np.mean(values)
        median_val = np.median(values)
        # NOTE(review): NumPy bins over the data range while the plotted bins
        # span [-1, 1], so this only approximates the tallest plotted bar.
        y_max = max(np.histogram(values, bins=100)[0]) * 1.05

        fig.add_trace(
            go.Histogram(
                x=values,
                showlegend=False,
                legendgroup=phase,
                nbinsx=100,
                xbins=dict(start=-1, end=1),
                name=f'Hist {phase}',
                marker=bar_marker,
            ),
            row=r, col=c,
        )
        fig.add_shape(go.layout.Shape(type="line", x0=mean_val, x1=mean_val, y0=0, y1=y_max, line=dict(color="chocolate"), line_width=2), row=r, col=c)
        fig.add_shape(go.layout.Shape(type="line", x0=median_val, x1=median_val, y0=0, y1=y_max, line=dict(color="darkorange"), line_width=2), row=r, col=c)

        # Empty traces whose only purpose is to create the legend entries.
        fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', line=dict(color='chocolate', width=2), name=f'Mean  = {mean_val:.2f} ({phase})', legendgroup=phase))
        fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', line=dict(color='darkorange', width=2), name=f'Median  = {median_val:.2f} ({phase})', legendgroup=phase))

# Shared x range, hidden y tick labels, overall size and legend styling.
fig.update_xaxes(range=[-1, 1], tickfont=dict(size=20))
fig.update_yaxes(showticklabels=False)
fig.update_layout(height=4*600, width=4*500, legend_font_size=40, title_font_size=40, legend=dict(x=0, y=0.5, font=dict(size=26), bgcolor='rgba(255,255,255,0.5)', bordercolor='rgba(0,0,0,0.2)',borderwidth=1))

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

# Example data for prototyping a hand-laid-out 4 x 4 grid. It lives under its own
# name so the measured `data` list from the cells above is not overwritten with
# random numbers.
demo_data = [np.random.randn(100) for _ in range(16)]

GRID = 4


def _axis_ref(letter, k):
    """Return the plotly axis id of grid cell k (1-based): 'x', 'x2', ..., 'x16'."""
    return letter if k == 1 else f'{letter}{k}'


fig = go.Figure()

# Horizontal / vertical extent of each column / row in paper coordinates.
# y_domains is reversed so that index 0 is the top row.
x_domains = [(i * 0.25, (i + 1) * 0.25) for i in range(GRID)]
y_domains = [(i * 0.25, (i + 1) * 0.25) for i in range(GRID)][::-1]

# Draw the histograms. i is the column and j the row; the first column is only
# used in the top row.
for i in range(GRID):
    for j in range(GRID):
        if i == 0 and j > 0:
            continue

        # Cell number in row-major order. Each cell owns the axis pair (x_k, y_k),
        # which is what the layout below assigns the domains to.
        k = j * GRID + i + 1
        xref, yref = _axis_ref('x', k), _axis_ref('y', k)
        hist_data = demo_data[k - 1]
        is_first = (k == 1)  # only the top-left histogram appears in the legend

        fig.add_trace(go.Histogram(
            x=hist_data,
            histnorm='percent',
            name=f"Histogram {k - 1}",
            showlegend=is_first,
            xaxis=xref,
            yaxis=yref,
        ))

        mean_val = np.mean(hist_data)
        median_val = np.median(hist_data)
        # The bars are percentages, so express the tallest 10-bin count as a
        # percentage too (identical to the raw count for 100 samples).
        counts = np.histogram(hist_data, bins=10)[0]
        max_bin_height = counts.max() / len(hist_data) * 100

        for value, color in ((mean_val, "Blue"), (median_val, "Red")):
            fig.add_shape(
                type="line",
                x0=value, x1=value, y0=0, y1=max_bin_height,
                xref=xref, yref=yref,
                line=dict(color=color, width=2),
            )

# Layout: give every cell's axis pair its domain, anchor the two axes to each
# other, and apply the common x range / hidden y tick labels to all of them.
axis_layout = {}
for k in range(1, GRID * GRID + 1):
    row, col = divmod(k - 1, GRID)
    axis_layout[_axis_ref('xaxis', k)] = dict(
        domain=x_domains[col],
        anchor=_axis_ref('y', k),
        range=[-1, 1],
        showgrid=False,
    )
    axis_layout[_axis_ref('yaxis', k)] = dict(
        domain=y_domains[row],
        anchor=_axis_ref('x', k),
        showticklabels=False,
    )

fig.update_layout(
    margin=dict(l=20, r=20, t=20, b=20),
    paper_bgcolor="LightSteelBlue",
    **axis_layout,
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

# Example data: 16 random samples with their means and medians.
data2 = [np.random.randn(100) for _ in range(16)]
mean_vals = list(map(np.mean, data2))
median_vals = list(map(np.median, data2))

fig = go.Figure()

# One invisible line per category, so each category gets a legend entry.
categories = ['Category 1', 'Category 2']
for cat in categories:
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', line={'width': 0}, name=cat, legendgroup=cat))

# Real traces: the first eight data sets belong to the first category, the rest
# to the second. Each data set adds its histogram, then mean and median markers.
for i, (d, mean_i, median_i) in enumerate(zip(data2, mean_vals, median_vals)):
    legendgroup = categories[0] if i < 8 else categories[1]

    fig.add_trace(go.Histogram(x=d, legendgroup=legendgroup, showlegend=False))
    for marker_x in (mean_i, median_i):
        fig.add_trace(go.Scatter(x=[marker_x], y=[0], mode='markers', legendgroup=legendgroup, showlegend=False))

fig.show()
