In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR


import numpy as np
import matplotlib.pyplot as plt




After `transforms.ToTensor()`, MNIST pixel values lie in [0, 1]. Normalizing with `transforms.Normalize(mean=(0.5,), std=(0.5,))`, i.e. $x' = (x - 0.5) / 0.5$, maps them to [-1, 1].

For this reason, the generator's last layer typically uses the hyperbolic tangent (tanh) activation, whose output range is also [-1, 1].



In [None]:
BATCH_SIZE = 32

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1],
# matching the range of a tanh-activated generator output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST splits stored under ./data (download=True fetches them if missing).
train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

#### CustomMNIST

Use this class to cap the number of training images per digit, e.g. to create a dataset with far fewer instances of the digit '1'.

This is useful for making specific digits harder for the generator to learn.

In [None]:
import torch
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CustomMNIST(Dataset):
    """MNIST wrapper that keeps at most ``digit_limits[d]`` samples of each digit ``d``.

    The first ``limit`` occurrences (in dataset order) of each capped digit are kept and
    later ones are dropped. Digits missing from ``digit_limits`` -- or every digit when
    ``digit_limits`` is None -- are kept in full.

    Args:
        root, train, transform, download: forwarded to ``torchvision.datasets.MNIST``.
        digit_limits: optional mapping ``{digit: max_number_of_samples}``.
    """

    def __init__(self, root, train=True, transform=None, download=False, digit_limits=None):
        self.mnist = MNIST(root=root, train=train, transform=transform, download=download)
        # Stored for reference only; the wrapped MNIST dataset already applies the transform.
        self.transform = transform
        self.digit_limits = digit_limits
        self.indices = self._get_indices()

    def _labels(self):
        """Return every sample's label as an int, in dataset order.

        Reads ``self.mnist.targets`` when available so images are not loaded and
        transformed just to read their labels; otherwise iterates the dataset.
        """
        targets = getattr(self.mnist, 'targets', None)
        # Raw targets are only the returned labels when no target_transform rewrites them.
        if targets is None or getattr(self.mnist, 'target_transform', None) is not None:
            return [int(label) for _, label in self.mnist]
        if hasattr(targets, 'tolist'):
            targets = targets.tolist()
        return [int(t) for t in targets]

    def _get_indices(self):
        """Return the indices (into ``self.mnist``) of the samples kept under the limits."""
        # digit_limits defaults to None, which previously crashed on `.get`.
        limits = self.digit_limits or {}
        digit_count = {}
        indices = []
        for i, label in enumerate(self._labels()):
            limit = limits.get(label)
            if limit is None or digit_count.get(label, 0) < limit:
                indices.append(i)
                digit_count[label] = digit_count.get(label, 0) + 1
        return indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the ``idx``-th kept sample as the wrapped dataset's ``(image, label)``."""
        return self.mnist[self.indices[idx]]

    
# Cap digits 2, 8 and 7 at 1000 training images each; all other digits are kept in full.
# NOTE: this rebinds train_dataset / train_loader, replacing the full-MNIST ones above.
digit_limits = {2: 1000, 8: 1000, 7: 1000}

train_dataset = CustomMNIST(root='./data', train=True, transform=transform,
                            download=True, digit_limits=digit_limits)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Display the test split summary; it is the plain MNIST test set (not digit-limited).
test_dataset

##### MNIST samples
Show one training image's label and a 5×5 patch of its pixel values.

In [None]:
# Inspect the first training sample: its label and a 5x5 pixel patch near the image centre.
img, label = train_dataset[0]
print('Label: ', label)
print(img[:, 10:15, 10:15])
# After Normalize(0.5, 0.5) the values should lie within [-1, 1].
img.min(), img.max()

In [None]:
def denorm(x):
    """Map a tensor from the normalized [-1, 1] range back to [0, 1] for display.

    Inverts ``Normalize(mean=0.5, std=0.5)`` and clips values that fall outside [0, 1].
    """
    shifted = (x + 1) / 2
    return shifted.clamp(min=0, max=1)

In [None]:
# Preview six images from one shuffled training batch together with their labels.
inputs, classes = next(iter(train_loader))
# el[0] takes the single (grayscale) channel so imshow receives a 2-D image.
inputs = [el[0] for el in inputs[:6]]
classes = classes[:6]

fig = plt.figure()
for i in range(6):
    ax = fig.add_subplot(2, 3, i + 1)
    ax.imshow(denorm(inputs[i]), cmap='gray', interpolation='none')
    ax.set_title("Ground Truth: {}".format(classes[i]))

fig.tight_layout()
plt.show()

In [None]:
IMAGE_SIZE = 28 * 28  # flattened 28x28 MNIST image
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

### Load pretrained classifier

It is used to compute FID and classification accuracy.

In [None]:
from modules.mnist_classifier import eval_model
from modules.mnist_models import CNN, CNN2, CNN3
CLASSIFIER = CNN3().to(device=device)
PATH = './mnist_tests_classifier/classifier__CNN3__1_9_32/classifier_model.pt'
# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
CLASSIFIER.load_state_dict(torch.load(PATH, map_location=device))
# Inference mode (affects dropout / batch-norm layers, if the model has any).
CLASSIFIER.eval()

In [None]:
# Sanity check: loss and accuracy of the pretrained classifier on the real MNIST test split.
loss_test, accuracy_test = eval_model(
    test_loader,
    CLASSIFIER,
    criterion=nn.CrossEntropyLoss(),
    device=device,
)
print(loss_test, accuracy_test)

### Train


In [None]:
from modules.mnist import Generator, Discriminator, train
from modules.weighted_bce import WeightedVarianceBCE

In [None]:
import os

# Root of all experiment outputs; override with the GAN_UNCERTAINTY_RESULTS env var on
# other machines instead of editing the absolute path.
RESULTS_ROOT = os.environ.get(
    'GAN_UNCERTAINTY_RESULTS',
    '/Users/serafim/Desktop/Job/projects/science/hse/GAN-Estimation-Uncertainty/uncertainty',
)
# Experiment folder for this training session. Other folders used previously:
# 'mnist_tests', 'custom_mnist_tests', 'custom_mnist_tests_v',
# 'custom_mnist_tests_metric1', 'custom_mnist_tests_D'.
save_path = os.path.join(RESULTS_ROOT, 'mnist_tests_D')

In [None]:
# Train G/D with WeightedVarianceBCE and weights_bce=True (run 'base_wgan__lr_0,0001__1_3').
G = Generator().to(device=device)
D = Discriminator().to(device=device)
loss_function = WeightedVarianceBCE()
learning_rate_G = 0.0001
learning_rate_D = 0.0001
NUM_EPOCHS = 50

# len(DataLoader) is the number of batches per epoch (ceil(len(dataset) / batch_size)
# when drop_last=False), i.e. the number of optimizer steps per epoch.
NUM_BATCHES = len(train_loader)
print(f'NUM_BATCHES: {NUM_BATCHES}')
# Each optimizer uses its own network's learning rate (the two were previously swapped).
D_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate_D)
G_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate_G)

# NOTE: OneCycleLR replaces each optimizer's starting lr with max_lr / div_factor
# (div_factor defaults to 25), so learning_rate_G / learning_rate_D are effectively
# ignored while these schedulers are attached.
# total_steps assumes `train` steps each scheduler once per batch -- TODO confirm.
max_lr = 0.00015
TOTAL_STEPS = NUM_EPOCHS * NUM_BATCHES
scheduler_D = OneCycleLR(D_optimizer, max_lr=max_lr, total_steps=TOTAL_STEPS)
scheduler_G = OneCycleLR(G_optimizer, max_lr=max_lr, total_steps=TOTAL_STEPS)

NAME = 'base_wgan__lr_0,0001__1_3'

D_losses_final, G_losses_final = train(
    num_epochs=NUM_EPOCHS,
    data_loader=train_loader,
    D=D,
    G=G,
    D_optimizer=D_optimizer,
    G_optimizer=G_optimizer,
    criterion=loss_function,
    device=device,
    name=NAME,
    save_path=save_path,
    progress_generator=True,
    plot_process=True,
    info_n=1,
    classifier=CLASSIFIER,
    fid=True,
    fid_dataset=train_loader,
    weights_bce=True,
    test_fid=True,
    scheduler_D=scheduler_D,
    scheduler_G=scheduler_G,
)


In [None]:
# Display the discriminator architecture. The summary below is the module repr recorded
# for this model (it was previously pasted as code, which is a SyntaxError on "Run All"):
#
# Discriminator(
#   (label_embedding): Embedding(10, 10)
#   (model): Sequential(
#     (0): Linear(in_features=794, out_features=512, bias=True)
#     (1): LeakyReLU(negative_slope=0.2, inplace=True)
#     (2): Dropout(p=0.3, inplace=False)
#     (3): Linear(in_features=512, out_features=512, bias=True)
#     (4): LeakyReLU(negative_slope=0.2, inplace=True)
#     (5): Dropout(p=0.3, inplace=False)
#     (6): Linear(in_features=512, out_features=512, bias=True)
#     (7): LeakyReLU(negative_slope=0.2, inplace=True)
#     (8): Dropout(p=0.3, inplace=False)
#     (9): Linear(in_features=512, out_features=1, bias=True)
#     (10): Sigmoid()
#   )
# )
D

Save the trained Generator and Discriminator

In [None]:
import os
def save_model(model, save_path, name, name2):
    """Save ``model.state_dict()`` to ``<save_path>/<name>/<name2>``.

    The run directory is created if it does not exist yet. As a side effect the
    model is switched to eval mode (kept for backward compatibility).

    Args:
        model: torch.nn.Module whose parameters are saved.
        save_path: root directory for experiment outputs.
        name: run/experiment sub-directory name.
        name2: checkpoint file name (e.g. 'Generator').

    Returns:
        The path of the written checkpoint file.
    """
    run_dir = os.path.join(save_path, name)
    # torch.save does not create missing parent directories.
    os.makedirs(run_dir, exist_ok=True)
    filepath = os.path.join(run_dir, name2)
    model.eval()
    torch.save(model.state_dict(), filepath)
    print(f"Model saved at: {filepath}")
    return filepath
    


In [None]:
# Persist both networks of run NAME under <save_path>/<NAME>/.
for net, file_name in ((G, 'Generator'), (D, 'Discriminator')):
    save_model(net, save_path, NAME, name2=file_name)

In [None]:
# Train G/D with WeightedVarianceBCE and weights_bce=True.
# NOTE(review): NAME is identical to the run in the earlier training cell, so this run
# writes into the same directory and overwrites its outputs. Presumably this is an
# intentional re-run; rename NAME if both runs should be kept.
G = Generator().to(device=device)
D = Discriminator().to(device=device)
loss_function = WeightedVarianceBCE()
learning_rate_G = 0.0001
learning_rate_D = 0.0001
NUM_EPOCHS = 50

# len(DataLoader) is the number of batches per epoch (ceil(len(dataset) / batch_size)
# when drop_last=False), i.e. the number of optimizer steps per epoch.
NUM_BATCHES = len(train_loader)
print(f'NUM_BATCHES: {NUM_BATCHES}')
# Each optimizer uses its own network's learning rate (the two were previously swapped).
D_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate_D)
G_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate_G)

# NOTE: OneCycleLR replaces each optimizer's starting lr with max_lr / div_factor
# (div_factor defaults to 25), so learning_rate_G / learning_rate_D are effectively
# ignored while these schedulers are attached.
# total_steps assumes `train` steps each scheduler once per batch -- TODO confirm.
max_lr = 0.00015
TOTAL_STEPS = NUM_EPOCHS * NUM_BATCHES
scheduler_D = OneCycleLR(D_optimizer, max_lr=max_lr, total_steps=TOTAL_STEPS)
scheduler_G = OneCycleLR(G_optimizer, max_lr=max_lr, total_steps=TOTAL_STEPS)

NAME = 'base_wgan__lr_0,0001__1_3'

D_losses_final, G_losses_final = train(
    num_epochs=NUM_EPOCHS,
    data_loader=train_loader,
    D=D,
    G=G,
    D_optimizer=D_optimizer,
    G_optimizer=G_optimizer,
    criterion=loss_function,
    device=device,
    name=NAME,
    save_path=save_path,
    progress_generator=True,
    plot_process=True,
    info_n=1,
    classifier=CLASSIFIER,
    fid=True,
    fid_dataset=train_loader,
    weights_bce=True,
    test_fid=True,
    scheduler_D=scheduler_D,
    scheduler_G=scheduler_G,
)


In [None]:
# Persist both networks of run NAME under <save_path>/<NAME>/.
for net, file_name in ((G, 'Generator'), (D, 'Discriminator')):
    save_model(net, save_path, NAME, name2=file_name)

### Training without weights

In [None]:
# Train G/D with weights_bce=False (run 'base_gan__lr_0,0001__1_6').
G = Generator().to(device=device)
D = Discriminator().to(device=device)
loss_function = WeightedVarianceBCE()
learning_rate_G = 0.0001
learning_rate_D = 0.0001
NUM_EPOCHS = 50

# len(DataLoader) is the number of batches per epoch (ceil(len(dataset) / batch_size)
# when drop_last=False), i.e. the number of optimizer steps per epoch.
NUM_BATCHES = len(train_loader)
print(f'NUM_BATCHES: {NUM_BATCHES}')
# Each optimizer uses its own network's learning rate (the two were previously swapped).
D_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate_D)
G_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate_G)

# NOTE: OneCycleLR replaces each optimizer's starting lr with max_lr / div_factor
# (div_factor defaults to 25), so learning_rate_G / learning_rate_D are effectively
# ignored while these schedulers are attached.
# total_steps assumes `train` steps each scheduler once per batch -- TODO confirm.
max_lr = 0.00015
TOTAL_STEPS = NUM_EPOCHS * NUM_BATCHES
scheduler_D = OneCycleLR(D_optimizer, max_lr=max_lr, total_steps=TOTAL_STEPS)
scheduler_G = OneCycleLR(G_optimizer, max_lr=max_lr, total_steps=TOTAL_STEPS)

NAME = 'base_gan__lr_0,0001__1_6'

D_losses_final, G_losses_final = train(
    num_epochs=NUM_EPOCHS,
    data_loader=train_loader,
    D=D,
    G=G,
    D_optimizer=D_optimizer,
    G_optimizer=G_optimizer,
    criterion=loss_function,
    device=device,
    name=NAME,
    save_path=save_path,
    progress_generator=True,
    plot_process=True,
    info_n=1,
    classifier=CLASSIFIER,
    fid=True,
    fid_dataset=train_loader,
    weights_bce=False,
    test_fid=True,
    scheduler_D=scheduler_D,
    scheduler_G=scheduler_G,
)


Save model

In [None]:
# Persist both networks of run NAME under <save_path>/<NAME>/.
for net, file_name in ((G, 'Generator'), (D, 'Discriminator')):
    save_model(net, save_path, NAME, name2=file_name)

In [None]:
# Train G/D with weights_bce=False (run 'base_gan__lr_0,0001__1_8').
G = Generator().to(device=device)
D = Discriminator().to(device=device)
loss_function = WeightedVarianceBCE()
learning_rate_G = 0.0001
learning_rate_D = 0.0001
NUM_EPOCHS = 50

# len(DataLoader) is the number of batches per epoch (ceil(len(dataset) / batch_size)
# when drop_last=False), i.e. the number of optimizer steps per epoch.
NUM_BATCHES = len(train_loader)
print(f'NUM_BATCHES: {NUM_BATCHES}')
# Each optimizer uses its own network's learning rate (the two were previously swapped).
D_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate_D)
G_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate_G)

# NOTE: OneCycleLR replaces each optimizer's starting lr with max_lr / div_factor
# (div_factor defaults to 25), so learning_rate_G / learning_rate_D are effectively
# ignored while these schedulers are attached.
# total_steps assumes `train` steps each scheduler once per batch -- TODO confirm.
max_lr = 0.00015
TOTAL_STEPS = NUM_EPOCHS * NUM_BATCHES
scheduler_D = OneCycleLR(D_optimizer, max_lr=max_lr, total_steps=TOTAL_STEPS)
scheduler_G = OneCycleLR(G_optimizer, max_lr=max_lr, total_steps=TOTAL_STEPS)

NAME = 'base_gan__lr_0,0001__1_8'

D_losses_final, G_losses_final = train(
    num_epochs=NUM_EPOCHS,
    data_loader=train_loader,
    D=D,
    G=G,
    D_optimizer=D_optimizer,
    G_optimizer=G_optimizer,
    criterion=loss_function,
    device=device,
    name=NAME,
    save_path=save_path,
    progress_generator=True,
    plot_process=True,
    info_n=1,
    classifier=CLASSIFIER,
    fid=True,
    fid_dataset=train_loader,
    weights_bce=False,
    test_fid=True,
    scheduler_D=scheduler_D,
    scheduler_G=scheduler_G,
)


In [None]:
# Persist both networks of run NAME under <save_path>/<NAME>/.
for net, file_name in ((G, 'Generator'), (D, 'Discriminator')):
    save_model(net, save_path, NAME, name2=file_name)

In [None]:
# Train G/D with weights_bce=False and a larger one-cycle peak lr (max_lr=0.0005).
G = Generator().to(device=device)
D = Discriminator().to(device=device)
loss_function = WeightedVarianceBCE()
learning_rate_G = 0.0001
learning_rate_D = 0.0001
NUM_EPOCHS = 50

# len(DataLoader) is the number of batches per epoch (ceil(len(dataset) / batch_size)
# when drop_last=False), i.e. the number of optimizer steps per epoch.
NUM_BATCHES = len(train_loader)
print(f'NUM_BATCHES: {NUM_BATCHES}')
# Each optimizer uses its own network's learning rate (the two were previously swapped).
D_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate_D)
G_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate_G)

# NOTE: OneCycleLR replaces each optimizer's starting lr with max_lr / div_factor
# (div_factor defaults to 25), so learning_rate_G / learning_rate_D are effectively
# ignored while these schedulers are attached.
# total_steps assumes `train` steps each scheduler once per batch -- TODO confirm.
max_lr = 0.0005
TOTAL_STEPS = NUM_EPOCHS * NUM_BATCHES
scheduler_D = OneCycleLR(D_optimizer, max_lr=max_lr, total_steps=TOTAL_STEPS)
scheduler_G = OneCycleLR(G_optimizer, max_lr=max_lr, total_steps=TOTAL_STEPS)

# Distinct run name. The previous name 'base_gan__lr_0,0001__1_6' collided with the
# max_lr=0.00015 run in the "Training without weights" section, so this run silently
# overwrote that run's outputs and the name misreported the schedule.
NAME = 'base_gan__max_lr_0,0005__1_6'

D_losses_final, G_losses_final = train(
    num_epochs=NUM_EPOCHS,
    data_loader=train_loader,
    D=D,
    G=G,
    D_optimizer=D_optimizer,
    G_optimizer=G_optimizer,
    criterion=loss_function,
    device=device,
    name=NAME,
    save_path=save_path,
    progress_generator=True,
    plot_process=True,
    info_n=1,
    classifier=CLASSIFIER,
    fid=True,
    fid_dataset=train_loader,
    weights_bce=False,
    test_fid=True,
    scheduler_D=scheduler_D,
    scheduler_G=scheduler_G,
)


In [None]:
# Persist both networks of run NAME under <save_path>/<NAME>/.
for net, file_name in ((G, 'Generator'), (D, 'Discriminator')):
    save_model(net, save_path, NAME, name2=file_name)

### Compare models (with and without weights)

> Load the per-digit FID and vFID lists of the models trained with and without weights, and compare their values.


In [None]:
import pickle
import os

In [None]:


# Runs to compare: m1 = standard GAN run, m2 = weighted GAN run
# (matching the ('Standard', 'Weighted') legend order used by plot_grouped_bar).
m1_path_base = 'mnist_tests/base_gan__lr_0,0001__1_8'
m2_path_base = 'mnist_tests/base_wgan__lr_0,0001__1_3'

m1_fid_path = os.path.join(m1_path_base, 'fid_test.pickle')
m2_fid_path = os.path.join(m2_path_base, 'fid_test.pickle')
m1_vfid_path = os.path.join(m1_path_base, 'vfid_test.pickle')
m2_vfid_path = os.path.join(m2_path_base, 'vfid_test.pickle')


def _read_pickle(path):
    """Unpickle and return the object stored at ``path``.

    NOTE: unpickling can execute arbitrary code; only load files produced by this project.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


m1_fid = _read_pickle(m1_fid_path)
m2_fid = _read_pickle(m2_fid_path)
m1_vfid = _read_pickle(m1_vfid_path)
m2_vfid = _read_pickle(m2_vfid_path)

In [None]:
def plot_grouped_bar(m1, m2, title, y_label, x_label = 'digits', bar_width=0.35, legend_labels=('Standard', 'Weighted'),
                    save_path = None, save_name = 'fid4digits'):
    """
    Plot a grouped bar chart comparing two sets of per-category values.

    Bars are paired by x value (key), not by position, so ``m1`` and ``m2`` may list
    their categories in different orders. The x axis follows ``m1``'s order.

    Args:
    - m1 (list | dict): (x, y) tuples, or a mapping x -> y, for the first set of bars.
    - m2 (list | dict): Same for the second set of bars; must cover the same x values.
    - title (str): Title of the plot.
    - y_label (str): Label for the y-axis.
    - x_label (str): Label for the x-axis. Default is 'digits'.
    - bar_width (float): Width of each bar. Default is 0.35.
    - legend_labels (tuple): Legend labels for m1 and m2. Default is ('Standard', 'Weighted').
    - save_path (str | None): If given, save the figure into this directory (created if
      missing) and close it instead of showing it.
    - save_name (str): File name used when saving. Default is 'fid4digits' (the previously
      hard-coded name); pass another name so e.g. FID and vFID plots don't overwrite each other.

    Raises:
    - ValueError: if m1 and m2 do not cover the same x values.
    """

    def _as_mapping(values):
        # Accept dicts as well as sequences whose items are (x, y, ...) tuples.
        pairs = values.items() if isinstance(values, dict) else values
        return {item[0]: item[1] for item in pairs}

    values_m1 = _as_mapping(m1)
    values_m2 = _as_mapping(m2)
    if set(values_m1) != set(values_m2):
        raise ValueError(
            f"m1 and m2 must cover the same x values, got {list(values_m1)} and {list(values_m2)}"
        )

    x = list(values_m1)
    y1 = [values_m1[k] for k in x]
    y2 = [values_m2[k] for k in x]

    # Offset the two series to either side of each x position.
    x_pos_m1 = np.array(x) - bar_width / 2
    x_pos_m2 = np.array(x) + bar_width / 2

    plt.bar(x_pos_m1, y1, width=bar_width, label=legend_labels[0], color='r')
    plt.bar(x_pos_m2, y2, width=bar_width, label=legend_labels[1], color='black')

    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    plt.xticks(x)  # ticks at the centre of each group
    plt.legend()
    if save_path:
        # savefig does not create missing directories.
        os.makedirs(save_path, exist_ok=True)
        plt.savefig(os.path.join(save_path, save_name), dpi=300)
        plt.close()
        return

    plt.show()

In [None]:
# Per-digit FID: m1 (standard run, red) vs m2 (weighted run, black).
plot_grouped_bar(m1_fid, m2_fid, title = 'FID for digits', y_label = 'FID')

In [None]:
# Per-digit vFID for the same two runs.
plot_grouped_bar(m1_vfid, m2_vfid, title = 'vFID for digits', y_label = 'vFID')

In [None]:
# NOTE(review): leftover inspection cell (only displays the os.path module); safe to delete.
os.path

Load standard and weighted Generator and Discriminator

In [None]:
import os

# Run directories (under save_path) of the two models being compared.
# NOTE(review): the pickle comparison above used 'base_gan__lr_0,0001__1_8' as the
# standard run; confirm that '..._1_4' is the intended one here.
standard_model_path = 'base_gan__lr_0,0001__1_4'
w_model_path = 'base_wgan__lr_0,0001__1_3'

# Root of all experiment outputs; override with the GAN_UNCERTAINTY_RESULTS env var.
RESULTS_ROOT = os.environ.get(
    'GAN_UNCERTAINTY_RESULTS',
    '/Users/serafim/Desktop/Job/projects/science/hse/GAN-Estimation-Uncertainty/uncertainty',
)
# Other experiment folders used previously: 'custom_mnist_tests', 'custom_mnist_tests_metric1',
# 'custom_mnist_tests_v', 'custom_mnist_tests_D', 'mnist_tests_D'.
# NOTE(review): the training section above saves into 'mnist_tests_D', not 'mnist_tests'.
save_path = os.path.join(RESULTS_ROOT, 'mnist_tests')

In [None]:
standard_G = Generator().to(device=device)
standard_D = Discriminator().to(device=device)

# map_location lets checkpoints saved on GPU load on a CPU-only machine (and vice versa).
path = os.path.join(save_path, standard_model_path, 'Generator')
standard_G.load_state_dict(torch.load(path, map_location=device))
standard_G.eval();

path = os.path.join(save_path, standard_model_path, 'Discriminator')
standard_D.load_state_dict(torch.load(path, map_location=device))
standard_D.eval();

In [None]:
# NOTE(review): redundant -- eval() was already called when loading; this only displays the module.
standard_D.eval()

In [None]:
Weighted_G = Generator().to(device=device)
Weighted_D = Discriminator().to(device=device)

# map_location lets checkpoints saved on GPU load on a CPU-only machine (and vice versa).
path = os.path.join(save_path, w_model_path, 'Generator')
Weighted_G.load_state_dict(torch.load(path, map_location=device))
Weighted_G.eval();

path = os.path.join(save_path, w_model_path, 'Discriminator')
Weighted_D.load_state_dict(torch.load(path, map_location=device))
Weighted_D.eval();

Load the pretrained classifier; it is wrapped below to serve as the feature-extraction model for FID.

In [None]:
from modules.mnist_classifier import eval_model
from modules.mnist_models import CNN, CNN2, CNN3
CLASSIFIER = CNN3().to(device=device)
PATH = './mnist_tests_classifier/classifier__CNN3__1_9_32/classifier_model.pt'
# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
CLASSIFIER.load_state_dict(torch.load(PATH, map_location=device))
# Inference mode (affects dropout / batch-norm layers, if the model has any).
CLASSIFIER.eval()

In [None]:
from modules.mnist_models import CNNClassifierWrapper

# Wrap the classifier so one of its layers provides the features used for FID.
# NOTE(review): layer_index=-6 presumably counts layers from the end of CNN3 -- confirm it is the intended feature layer.
FEATURE_EXTRACTOR = CNNClassifierWrapper(CLASSIFIER, layer_index = -6, use_global_pooling = False)
# Registers the hook on the selected layer (presumably required before calling FEATURE_EXTRACTOR).
FEATURE_EXTRACTOR.register_hook()

Check that FEATURE_EXTRACTOR works on a random MNIST-shaped batch

In [None]:
# Sanity check: a random batch shaped like MNIST (32 images, 1 channel, 28x28) should pass
# through the feature extractor; the printed size shows the feature dimensions.
# The batch is created on `device` so it matches the classifier, which was moved there
# (a CPU tensor fails against a CUDA model).
img_example = torch.rand(32, 1, 28, 28, device=device)
output = FEATURE_EXTRACTOR(img_example)
print(output.size())

In [None]:
from modules.fid import split_mnist_loader_cats, calculate_multiple_fid
from modules.mnist_classifier import calculate_confusion_matrix
from modules.mnist import get_fake_dataloader, FakeDataset

In [None]:
# Loader over samples drawn from the weighted generator, used to evaluate CLASSIFIER on fakes.
# NOTE(review): noise_dim=100 presumably matches Generator's latent size -- confirm.
fake_loader_eval = get_fake_dataloader(
    Weighted_G,
    device,
    batch_size=32,
    num_examples_per_class=10000,
    noise_dim=100,
    shuffle=True,
)

In [None]:
# fake_loader_eval[0]

In [None]:
# ls
# ls /Users/serafim/Desktop

In [None]:
# Confusion matrix and metrics of CLASSIFIER on Weighted_G samples from fake_loader_eval.
# NOTE(review): hard-coded Desktop save_path; the next cell passes the same name 'cf' and presumably overwrites this output.
a = calculate_confusion_matrix(CLASSIFIER, fake_loader_eval, device, epoch = 50, save_path = '/Users/serafim/Desktop', name = 'cf', metrics = True)


In [None]:
# NOTE(review): identical call to the previous cell; a and b differ only if fake_loader_eval
# was rebuilt (e.g. from another generator) in between -- hidden notebook state.
b = calculate_confusion_matrix(CLASSIFIER, fake_loader_eval, device, epoch = 50, save_path = '/Users/serafim/Desktop', name = 'cf', metrics = True)


In [None]:
# Display the second confusion-matrix result.
b

In [None]:
# NOTE(review): duplicate of the previous display cell.
b

In [None]:
# Display the first confusion-matrix result.
a

In [None]:
# Recorded classification report from one of the `calculate_confusion_matrix(..., metrics=True)`
# cells above. It is kept as a comment because the pasted dict fragment is not valid Python
# and broke "Run All".
#
# 'accuracy': 0.98388,
#  'macro avg': {'precision': 0.9839122704969807,
#   'recall': 0.9838800000000001,
#   'f1-score': 0.9838827441740126,
#   'support': 100000},
#  'weighted avg': {'precision': 0.9839122704969808,
#   'recall': 0.98388,
#   'f1-score': 0.9838827441740123,
#   'support': 100000}}




In [None]:

# category_data_real = split_mnist_loader_cats(train_loader, max_images = None)
# fid_cats, vfid_cats = calculate_multiple_fid(standard_G, FEATURE_EXTRACTOR, category_data_real, device)


In [None]:
# Split the real MNIST test loader into per-digit data for per-class FID
# (max_images=None presumably means no cap per digit -- confirm).
category_data_real = split_mnist_loader_cats(test_loader, max_images = None)

In [None]:
# Per-digit FID / vFID against the real test split for both generators.
# Both results are used by the comparison plots below; fid_cats / vfid_cats were
# previously commented out, so those cells raised NameError on a fresh kernel.
fid_cats, vfid_cats = calculate_multiple_fid(standard_G, FEATURE_EXTRACTOR, category_data_real, device)
w_fid_cats, w_vfid_cats = calculate_multiple_fid(Weighted_G, FEATURE_EXTRACTOR, category_data_real, device)


In [None]:
# w_fid_cats, w_vfid_cats

Results on the custom dataset built with `digit_limits = {2: 1000, 8: 1000, 7: 1000}` (at most 1000 training images each of digits 2, 7 and 8):

In [None]:
# Per-digit FID, standard (fid_cats) vs weighted (w_fid_cats). Because save_path is set, the
# figure is written to that folder as 'fid4digits' and closed instead of being shown.
plot_grouped_bar(list(fid_cats.items()), list(w_fid_cats.items()), title = 'FID for digits', y_label = 'FID', save_path = '/Users/serafim/Desktop/cf')

In [None]:
# Per-digit vFID, standard vs weighted (shown inline).
plot_grouped_bar(list(vfid_cats.items()), list(w_vfid_cats.items()), title = 'vFID for digits', y_label = 'vFID')

In [None]:
# Same FID comparison as the saved figure above, shown inline.
plot_grouped_bar(list(fid_cats.items()), list(w_fid_cats.items()), title = 'FID for digits', y_label = 'FID')

In [None]:
# Mean FID over digits as (standard, weighted).
sum(fid_cats.values())/len(fid_cats), sum(w_fid_cats.values())/len(w_fid_cats)

In [None]:
# Recorded result of the mean-FID cell above on the custom dataset, as (standard, weighted).
# Presumably pasted output -- evaluating it only redisplays these numbers.
(34.62795008633528, 36.0239711077637)

Results on the standard (full) MNIST dataset:

In [None]:
# NOTE(review): same call as in the custom-dataset section; it shows full-MNIST results only if
# fid_cats / w_fid_cats were recomputed with full-MNIST models (hidden notebook state).
plot_grouped_bar(list(fid_cats.items()), list(w_fid_cats.items()), title = 'FID for digits', y_label = 'FID')


In [None]:
# NOTE(review): same call as in the custom-dataset section; relies on vfid_cats / w_vfid_cats having been recomputed.
plot_grouped_bar(list(vfid_cats.items()), list(w_vfid_cats.items()), title = 'vFID for digits', y_label = 'vFID')

In [None]:
# Mean FID over digits for the standard model.
sum(fid_cats.values())/len(fid_cats)

In [None]:
# Mean FID over digits for the weighted model.
sum(w_fid_cats.values())/len(w_fid_cats)

In [None]:
# plot_grouped_bar(list(fid_cats.items()), list(w_fid_cats.items()), title = 'FID for digits', y_label = 'FID', save_path = '/Users/serafim/Desktop/cf')

In [None]:
# NOTE(review): rebinds a / b (previously the confusion-matrix results) to Weighted_G's
# per-digit (FID, vFID) dicts, mirroring the w_fid_cats / w_vfid_cats call above.
a, b = calculate_multiple_fid(Weighted_G, FEATURE_EXTRACTOR, category_data_real, device)

In [None]:
# Display the per-digit dict returned as the first value above.
a

In [None]:
# b

In [None]:
# `a` is a dict {digit: value}. Plot the values against their digit keys; passing the
# dict itself to plt.plot does not produce a line of its values.
plt.plot(list(a.keys()), list(a.values()), marker='o');

In [None]:
# Bar chart of the per-digit values in `a`, saved to <save_path>/var4digits (not shown inline).
# NOTE(review): `a` currently holds Weighted_G's per-digit FID from the cell above, which does
# not obviously match the title below -- confirm which quantity is meant.
keys, values = list(a.keys()), list(a.values())

fig, ax = plt.subplots(figsize=(8, 5))
ax.bar(keys, values, color='red')
ax.set_title('Trace of Covariance Matrix for Real Objects in MNIST')
ax.set_xlabel('Digits')
ax.set_ylabel('Trace Value')

save_filename = os.path.join(save_path, 'var4digits')
fig.savefig(save_filename, dpi=300)
plt.close(fig)

In [None]:
# Path the var4digits figure was saved to.
save_filename