In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
class SoftmaxWeight(nn.Module):
    def __init__(self, K, p, hidden_dimensions =[]):
        super().__init__()
        self.K = K
        self.p = p
        self.network_dimensions = [self.p] + hidden_dimensions + [self.K]
        network = []
        for h0, h1 in zip(self.network_dimensions, self.network_dimensions[1:]):
            network.extend([nn.Linear(h0, h1),nn.Tanh(),])
        network.pop()
        self.f = nn.Sequential(*network)
        self.f = nn.Sequential(*network)
        self.f[-1].bias = nn.Parameter(torch.ones(self.K))
        self.f[-1].weight = nn.Parameter(torch.zeros_like(self.f[-1].weight))

    def log_prob(self, z):
        unormalized_log_w = self.f.forward(z)
        return unormalized_log_w - torch.logsumexp(unormalized_log_w, dim=-1, keepdim=True)
    
class ConvNetWeight(nn.Module):
    def __init__(self, K):
        super(ConvNetWeight, self).__init__()
        self.K = K
        self.conv1 = nn.Conv2d(3, 8, 3, 1)
        #premier argument = couleur vs gris 3 vs 1
        self.conv2 = nn.Conv2d(8, 8, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(14*14*8, 128)
        self.fc2 = nn.Linear(128, K)

        # x represents our data
    def forward(self, x):
        # Pass data through conv1
        x = self.conv1(x)
        # Use the rectified-linear activation function over x
        x = F.relu(x)

        x = self.conv2(x)
        x = F.relu(x)

        # Run max pooling over x
        x = F.max_pool2d(x, 2)
        # Pass data through dropout1
        x = self.dropout1(x)
        # Flatten x with start_dim=1
        x = torch.flatten(x, 1)
        # Pass data through fc1
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

    def log_prob(self, z):
        unormalized_log_w = self.forward(z)
        return unormalized_log_w - torch.logsumexp(unormalized_log_w, dim=-1, keepdim=True)

In [2]:
from IPython.display import clear_output

In [9]:
###MNIST###
import torch
import torchvision
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
fmnist_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=None)
fmnist_testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=None)
randperm_train = torch.randperm(torch.tensor(fmnist_trainset.targets).shape[0])
randperm_test_val = torch.randperm(torch.tensor(fmnist_testset.targets).shape[0])

num_train = 5000

train_labels = torch.tensor(fmnist_trainset.targets)[randperm_train][:num_train]
test_labels = torch.tensor(fmnist_testset.targets)[randperm_test_val][:5000]
val_labels = torch.tensor(fmnist_testset.targets)[randperm_test_val][5000:]

extracted_train = torch.tensor(fmnist_trainset.data).float()[randperm_train][:num_train].reshape(num_train,3,32,32)
train_samples = (extracted_train + torch.rand(extracted_train.shape))/256
extracted_test = torch.tensor(fmnist_testset.data).float()[randperm_test_val][:5000].reshape(5000,3,32,32)
test_samples = (extracted_test + torch.rand(extracted_test.shape))/256
extracted_val = torch.tensor(fmnist_testset.data).float()[randperm_test_val][5000:].reshape(5000,3,32,32)
val_samples = (extracted_val + torch.rand(extracted_val.shape))/256
torch.cuda.empty_cache()

Files already downloaded and verified
Files already downloaded and verified


In [10]:
from tqdm import tqdm

In [12]:
w = SoftmaxWeight(10,train_samples.shape[-1],[128,128,128])
w = ConvNetWeight(10)
optim = torch.optim.Adam(w.parameters(), lr = 5e-4)
list_accuracy = []
device = torch.device('cpu' if torch.cuda.is_available() else 'cpu')

train_samples = train_samples.to(device)
train_labels = train_labels.to(device)

test_samples = test_samples.to(device)
test_labels = test_labels.to(device)

val_samples = val_samples.to(device)
val_labels = val_labels.to(device)

counts = torch.unique(train_labels, return_counts = True)[1]/train_labels.shape[0]
weights = torch.distributions.Dirichlet(torch.ones(train_labels.shape[0])).sample().to(device)
list_accuracy_train = []
list_accuracy_test = []
list_accuracy_val = []

w.to(device)
pbar = tqdm(range(1000))
for t in pbar:
    optim.zero_grad()
    loss_train = -torch.sum(weights*(w.log_prob(train_samples)*counts)[range(train_samples.shape[0]),train_labels])
    with torch.no_grad():
        accuracy_train = torch.mean((torch.max(w.log_prob(train_samples), dim = 1)[1] == train_labels).float())
        list_accuracy_train.append(accuracy_train.cpu().item())
        accuracy_test = torch.mean((torch.max(w.log_prob(test_samples), dim = 1)[1] == test_labels).float())
        list_accuracy_test.append(accuracy_test.cpu().item())
        accuracy_val= torch.mean((torch.max(w.log_prob(val_samples), dim = 1)[1] == val_labels).float())
        list_accuracy_val.append(accuracy_val.cpu().item())
    loss_train.backward()
    optim.step()
    pbar.set_postfix_str('loss_train = ' + str(round(loss_train.item(),4)) +'; acc_train =' + str(round(accuracy_train.item(),4)) + '; acc_test =' + str(round(accuracy_test.item(),4)) + '; acc_validation =' + str(round(accuracy_val.item(),4)))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [27:08<00:00,  1.63s/it, loss_train = 0.1407; acc_train =0.3994; acc_test =0.2048; acc_validation =0.2132]


In [None]:
plt.figure(figsize = (15,8))
plt.plot(list_accuracy_train, label = 'train')
plt.plot(list_accuracy_test, label = 'test')
plt.plot(list_accuracy_val, label = 'validation')
plt.legend()
plt.show()

In [None]:
list_accuracy_current_gibbs = []
list_accuracy_train_gibbs = []
list_accuracy_test_gibbs = []
list_accuracy_val_gibbs = []
pbar = tqdm(range(100))
log_prob_train_gibbs = []
log_prob_test_gibbs = []
log_prob_val_gibbs = []

w = SoftmaxWeight(10,train_samples.shape[-1],[128,128,128]).to(device)
current_samples = torch.cat([train_samples.to(device), test_samples.to(device)], dim = 0)
for i in pbar:
    fake_labels= torch.distributions.Categorical(torch.exp(w.to(device).log_prob(test_samples.to(device)))).sample()
    current_labels = torch.cat([train_labels.to(device), fake_labels], dim =0)
    counts = torch.unique(current_labels, return_counts = True)[1]/current_labels.shape[0]
    weights = torch.distributions.Dirichlet(torch.ones(current_labels.shape[0])).sample().to(device)
    w = SoftmaxWeight(10,train_samples.shape[-1],[128,128,128]).to(device)
    optim = torch.optim.Adam(w.parameters(), lr = 5e-4)
    for t in range(500):
        optim.zero_grad()
        loss_train = -torch.sum(weights*(w.log_prob(current_samples))[torch.tensor(range(current_labels.shape[0])).to(device),current_labels])
        with torch.no_grad():
            accuracy_current = torch.mean((torch.max(w.log_prob(current_samples), dim = 1)[1] == current_labels).float()).cpu()
            accuracy_val = torch.mean((torch.max(w.log_prob(val_samples), dim = 1)[1] == val_labels).float()).cpu()
            accuracy_train = torch.mean((torch.max(w.log_prob(train_samples), dim = 1)[1] == train_labels).float()).cpu()
            accuracy_test = torch.mean((torch.max(w.log_prob(test_samples), dim = 1)[1] == test_labels).float()).cpu()
            list_accuracy_current_gibbs.append(accuracy_current.item())
            list_accuracy_train_gibbs.append(accuracy_train.item())
            list_accuracy_test_gibbs.append(accuracy_test.item())
            list_accuracy_val_gibbs.append(accuracy_val.item())
        loss_train.backward()
        optim.step()
        pbar.set_postfix_str('loss_train = ' + str(round(loss_train.item(),4))+'; acc_current =' + str(round(accuracy_current.item(),4)) +'; acc_train =' + str(round(accuracy_train.item(),4)) + '; acc_test =' + str(round(accuracy_test.item(),4)) + '; acc_val =' + str(round(accuracy_val.item(),4)))

    clear_output(wait = True)
    plt.hist(torch.prod(torch.exp(w.log_prob(test_samples)), dim = -1).cpu().detach().numpy(), bins = 50)
    plt.show()
    clear_output(wait = True)
    plt.figure(figsize = (15,8))
    plt.plot(list_accuracy_train_gibbs, label = 'Train')
    plt.plot(list_accuracy_current_gibbs, label ='Current')
    plt.plot(list_accuracy_val_gibbs, label = 'Validation')
    plt.plot(list_accuracy_test_gibbs, label = 'Test')
    plt.legend()
    plt.show()
    log_prob_train_gibbs.append(w.log_prob(train_samples))
    log_prob_test_gibbs.append(w.log_prob(test_samples))
    log_prob_val_gibbs.append(w.log_prob(val_samples))
bagging_log_prob_train_gibbs = torch.mean(torch.stack(log_prob_train_gibbs), dim =0)
bagging_log_prob_test_gibbs = torch.mean(torch.stack(log_prob_test_gibbs), dim =0)
bagging_log_prob_val_gibbs = torch.mean(torch.stack(log_prob_val_gibbs), dim =0)
bagging_accuracy_train= torch.mean((torch.max(bagging_log_prob_train_gibbs, dim = 1)[1] == train_labels).float())
print(bagging_accuracy_train)
bagging_accuracy_test = torch.mean((torch.max(bagging_log_prob_test_gibbs, dim = 1)[1] == test_labels).float())
print(bagging_accuracy_test)
bagging_accuracy_val= torch.mean((torch.max(bagging_log_prob_val_gibbs, dim = 1)[1] == val_labels).float())
print(bagging_accuracy_val)