In [1]:
import torch
from torchvision import transforms

In [2]:
# Experiment configuration -- every tunable value for this notebook lives here.

# Reproducibility / bookkeeping
seed = 2453466
checkpt_dir = 'checkpoints'  # directory is created further down if missing

# Data
dataset = 'MNIST'
input_size = 28 * 28  # presumably one flattened 28x28 image -- TODO confirm

# Models
num_experts = 4

# Expert-initialization phase
load_initialized_experts = False  # True: load saved init weights instead of running initialization
model_for_initialized_experts = 'blockmodel'  # used in checkpoint file names and as architecture_name
optimizer_initialize = 'adam'  # the initialization cell accepts 'adam' or 'sgd'
learning_rate_initialize = 0.1
weight_decay = 0.1
epochs_init = 10
# NOTE(review): the output recorded in this notebook shows the initialization
# loss barely moving between epochs 1 and 2 (7238.9906 -> 7238.6349); the
# learning rate / weight decay above may be too aggressive -- confirm.

In [3]:
# load dataset
from torch.utils.data import DataLoader
from dataset import MNISTDataset

# Both splits are built with the same transformer names; "rotate_left" is listed twice.
# NOTE(review): presumably the repetition applies the rotation two times --
# confirm that this is intentional and not a copy-paste slip.
split_transformers = ("rotate_left", "rotate_left")
loader_batch_size = 64

# Each dataset receives its own list so neither can affect the other's argument.
train_dataset = MNISTDataset(train=True, transformer_names=list(split_transformers))
test_dataset = MNISTDataset(train=False, transformer_names=list(split_transformers))

# Only the training split is shuffled; the test split keeps its stored order.
train_loader = DataLoader(dataset=train_dataset, batch_size=loader_batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=loader_batch_size, shuffle=False)

In [4]:
# Select the training device and seed the random number generators.
import random

# Use the GPU when one is available, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Seed Python's built-in RNG as well as torch's, so that any code drawing from
# `random` is reproducible too. torch.manual_seed only needs to be called once.
random.seed(seed)
torch.manual_seed(seed)
if cuda:
    # Explicitly seed every visible CUDA device.
    torch.cuda.manual_seed_all(seed)

In [5]:
import os

In [6]:
# Directory for checkpoints.
# makedirs(exist_ok=True) is a no-op when the directory already exists, also
# creates missing parent directories, and avoids the check-then-create race of
# os.path.exists() followed by os.mkdir(). It raises FileExistsError if the
# path exists but is not a directory, surfacing a misconfigured checkpt_dir early.
os.makedirs(checkpt_dir, exist_ok=True)

In [7]:
# initialize models
from model import Expert, Discriminator
# The experts and the discriminator are constructed with the same arguments.
model_kwargs = dict(dataset=dataset, input_size=input_size)

# num_experts separately constructed Expert instances, each moved to `device`.
experts = [Expert(**model_kwargs).to(device) for _ in range(num_experts)]

# A single discriminator, moved to the same device.
discriminator = Discriminator(**model_kwargs).to(device)

In [8]:
# Loss functions; both are averaged over all elements they are given.
loss_reduction = 'mean'

# Binary cross-entropy.
criterion = torch.nn.BCELoss(reduction=loss_reduction)

# Mean squared error; this is the loss handed to initialize_expert.
loss_initial = torch.nn.MSELoss(reduction=loss_reduction)

In [9]:
# Initialize Experts as approximately Identity on Transformed Data
from trainer import initialize_expert

# Each expert is either restored from a saved initialization checkpoint or
# passed to initialize_expert together with its own freshly built optimizer.

# Supported values of optimizer_initialize and the optimizer class each selects.
optimizer_classes = {'adam': torch.optim.Adam, 'sgd': torch.optim.SGD}

for i, expert in enumerate(experts):
    if load_initialized_experts:
        # Checkpoint files are numbered from 1, hence i+1.
        path = os.path.join(checkpt_dir, f'{model_for_initialized_experts}_E_{i+1}_init.pth')
        # NOTE(review): init_weights is neither imported nor defined in any
        # cell above, so this branch raises NameError until it is brought into
        # scope -- confirm which module provides it and import it.
        init_weights(expert, path)
    else:
        if optimizer_initialize not in optimizer_classes:
            # Same exception type as before, now naming the offending value.
            raise NotImplementedError(
                f"optimizer_initialize={optimizer_initialize!r} is not supported; "
                f"expected one of {sorted(optimizer_classes)}"
            )

        # One optimizer per expert, bound to that expert's parameters only.
        optimizer_E = optimizer_classes[optimizer_initialize](
            expert.parameters(),
            lr=learning_rate_initialize,
            weight_decay=weight_decay,
        )

        initialize_expert(
            epochs=epochs_init,
            architecture_name=model_for_initialized_experts,
            expert=expert,
            i=i,
            optimizer=optimizer_E,
            loss=loss_initial,
            data_train=train_loader,
            device=device,
        )

Initializing expert [1] as identity on preturbed data


Batch: 100%|██████████| 782/782 [00:38<00:00, 20.53it/s]
Epoch:  10%|█         | 1/10 [00:38<05:42, 38.10s/it]

initialization epoch [1] expert [1] loss 7238.9906


Batch: 100%|██████████| 782/782 [00:45<00:00, 17.29it/s]
Epoch:  20%|██        | 2/10 [01:23<05:38, 42.29s/it]

initialization epoch [2] expert [1] loss 7238.6349


