In [1]:
import random

import torch
from torchvision import transforms

In [10]:
# settings
# --- Reproducibility and I/O ---
seed = 2453466                 # RNG seed applied in the seeding cell
checkpt_dir = 'checkpoints'    # directory for expert checkpoints (read and written)

# --- Data / model shape ---
dataset = 'MNIST'
input_size = 28 * 28           # flattened 28x28 image
num_experts = 4

# --- Expert initialization phase (approximate identity on transformed data) ---
load_initialized_experts = False              # True: load saved *_init.pth weights instead of training
model_for_initialized_experts = 'blockmodel'  # architecture tag used in init checkpoint file names
optimizer_initialize = 'sgd'                  # 'sgd' or 'adam'
learning_rate_initialize = 0.01
weight_decay = 0
epochs_init = 4

In [11]:
# load dataset
from torch.utils.data import DataLoader
from dataset import MNISTDataset

# Transform names applied by MNISTDataset, shared by both splits (each split gets its own list copy).
# NOTE(review): both entries are "rotate_left" — confirm the repeated rotation is intended
# and that a different second transform was not meant.
perturbation = ["rotate_left", "rotate_left"]

train_dataset = MNISTDataset(train=True, transformer_names=list(perturbation))
test_dataset = MNISTDataset(train=False, transformer_names=list(perturbation))

# Only the training stream is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)

In [12]:
# Init seed and training device
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Seed every RNG that training may draw from. A single torch.manual_seed call is sufficient.
# Python's `random` is seeded in case project code (e.g. dataset transforms) uses it —
# whether it does is not visible from this notebook.
# NOTE(review): the datasets/loaders are constructed in an earlier cell, so any randomness
# during their construction happens before this seeding; consider moving this cell up.
random.seed(seed)
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed_all(seed)

In [13]:
import os

In [14]:
# Directory for checkpoints.
# os.makedirs(..., exist_ok=True) is idempotent on re-run and also creates missing parent
# directories. If `checkpt_dir` exists but is not a directory, it raises FileExistsError
# here rather than being silently skipped.
os.makedirs(checkpt_dir, exist_ok=True)

In [15]:
# initialize models
from model import Expert, Discriminator

# `num_experts` independent Expert instances followed by one Discriminator,
# each moved to the training device selected in the seeding cell.
experts = []
for _ in range(num_experts):
    new_expert = Expert(dataset=dataset, input_size=input_size)
    experts.append(new_expert.to(device))
discriminator = Discriminator(dataset=dataset, input_size=input_size).to(device)

In [16]:
# Losses
#   loss_initial: mean squared error, passed to initialize_expert to fit experts to the identity.
#   criterion:    binary cross-entropy on probabilities; not used in the visible cells,
#                 presumably for the adversarial (discriminator) phase.
loss_initial, criterion = (
    loss_cls(reduction='mean') for loss_cls in (torch.nn.MSELoss, torch.nn.BCELoss)
)

In [17]:
# Initialize Experts as approximately Identity on Transformed Data
from trainer import initialize_expert

for i, expert in enumerate(experts):
    if load_initialized_experts:
        # Reuse weights from a previous initialization run instead of training.
        # NOTE(review): assumes the .pth file holds a plain state_dict for this Expert
        # architecture (the format written by `initialize_expert` is not visible here) — TODO confirm.
        path = os.path.join(checkpt_dir, f'{model_for_initialized_experts}_E_{i+1}_init.pth')
        expert.load_state_dict(torch.load(path, map_location=device))
        continue

    # A fresh optimizer per expert, so no optimizer state is shared between experts.
    if optimizer_initialize == 'adam':
        optimizer_E = torch.optim.Adam(expert.parameters(), lr=learning_rate_initialize,
                                       weight_decay=weight_decay)
    elif optimizer_initialize == 'sgd':
        optimizer_E = torch.optim.SGD(expert.parameters(), lr=learning_rate_initialize,
                                      weight_decay=weight_decay)
    else:
        raise NotImplementedError(
            f"optimizer_initialize={optimizer_initialize!r}; expected 'adam' or 'sgd'"
        )

    # Fit `expert` for `epochs_init` epochs on the training loader with `loss_initial`.
    # `checkpt_dir` is passed through; presumably used to save the initialized weights.
    initialize_expert(
        epochs=epochs_init,
        architecture_name=model_for_initialized_experts,
        expert=expert,
        i=i,
        optimizer=optimizer_E,
        loss=loss_initial,
        data_train=train_loader,
        device=device,
        checkpt_dir=checkpt_dir,
    )

Initializing expert [1] as identity on preturbed data


Batch: 100%|██████████| 782/782 [00:10<00:00, 73.42it/s]
Epoch:  25%|██▌       | 1/4 [00:10<00:31, 10.66s/it]

initialization epoch [1] expert [1] loss 0.7343


Batch: 100%|██████████| 782/782 [00:11<00:00, 65.99it/s]
Epoch:  50%|█████     | 2/4 [00:22<00:22, 11.36s/it]

initialization epoch [2] expert [1] loss 0.4716


Batch: 100%|██████████| 782/782 [00:11<00:00, 68.33it/s]
Epoch:  75%|███████▌  | 3/4 [00:33<00:11, 11.40s/it]

initialization epoch [3] expert [1] loss 0.3729


Batch: 100%|██████████| 782/782 [00:13<00:00, 59.32it/s]
Epoch: 100%|██████████| 4/4 [00:47<00:00, 11.79s/it]


initialization epoch [4] expert [1] loss 0.3157
Initializing expert [2] as identity on preturbed data


Batch: 100%|██████████| 782/782 [00:17<00:00, 45.30it/s]
Epoch:  25%|██▌       | 1/4 [00:17<00:51, 17.27s/it]

initialization epoch [1] expert [2] loss 0.7346


Batch: 100%|██████████| 782/782 [00:16<00:00, 48.39it/s]
Epoch:  50%|█████     | 2/4 [00:33<00:33, 16.62s/it]

initialization epoch [2] expert [2] loss 0.4700


Batch: 100%|██████████| 782/782 [00:15<00:00, 50.59it/s]
Epoch:  75%|███████▌  | 3/4 [00:48<00:16, 16.09s/it]

initialization epoch [3] expert [2] loss 0.3719


Batch: 100%|██████████| 782/782 [00:14<00:00, 55.11it/s]
Epoch: 100%|██████████| 4/4 [01:03<00:00, 15.78s/it]


initialization epoch [4] expert [2] loss 0.3151
Initializing expert [3] as identity on preturbed data


Batch: 100%|██████████| 782/782 [00:15<00:00, 51.57it/s]
Epoch:  25%|██▌       | 1/4 [00:15<00:45, 15.17s/it]

initialization epoch [1] expert [3] loss 0.7384


Batch: 100%|██████████| 782/782 [00:16<00:00, 47.59it/s]
Epoch:  50%|█████     | 2/4 [00:31<00:31, 15.92s/it]

initialization epoch [2] expert [3] loss 0.4696


Batch: 100%|██████████| 782/782 [00:16<00:00, 46.12it/s]
Epoch:  75%|███████▌  | 3/4 [00:48<00:16, 16.40s/it]

initialization epoch [3] expert [3] loss 0.3705


Batch: 100%|██████████| 782/782 [00:18<00:00, 43.34it/s]
Epoch: 100%|██████████| 4/4 [01:06<00:00, 16.66s/it]


initialization epoch [4] expert [3] loss 0.3135
Initializing expert [4] as identity on preturbed data


Batch: 100%|██████████| 782/782 [00:18<00:00, 42.79it/s]
Epoch:  25%|██▌       | 1/4 [00:18<00:54, 18.28s/it]

initialization epoch [1] expert [4] loss 0.7424


Batch: 100%|██████████| 782/782 [00:19<00:00, 40.81it/s]
Epoch:  50%|█████     | 2/4 [00:37<00:37, 18.80s/it]

initialization epoch [2] expert [4] loss 0.4719


Batch: 100%|██████████| 782/782 [00:14<00:00, 55.85it/s]
Epoch:  75%|███████▌  | 3/4 [00:51<00:16, 16.61s/it]

initialization epoch [3] expert [4] loss 0.3725


Batch: 100%|██████████| 782/782 [00:15<00:00, 50.48it/s]
Epoch: 100%|██████████| 4/4 [01:06<00:00, 16.74s/it]

initialization epoch [4] expert [4] loss 0.3152





In [None]:
# Optimizers for the expert / discriminator training phase.
#
# Settings for this phase (candidates to move into the settings cell).
# NOTE(review): placeholder values — TODO confirm against the configuration these
# hyper-parameters were originally taken from.
optimizer_experts = 'adam'          # 'adam' or 'sgd'
learning_rate_expert = 1e-3
optimizer_discriminator = 'adam'    # 'adam' or 'sgd'
learning_rate_discriminator = 1e-3


def _make_optimizer(name, params, lr, weight_decay):
    """Build a torch optimizer by name.

    Args:
        name: 'adam' or 'sgd'.
        params: iterable of parameters to optimize.
        lr: learning rate.
        weight_decay: L2 penalty passed to the optimizer.

    Returns:
        A torch.optim.Adam or torch.optim.SGD instance.

    Raises:
        NotImplementedError: if `name` is not a supported optimizer.
    """
    optimizer_classes = {'adam': torch.optim.Adam, 'sgd': torch.optim.SGD}
    if name not in optimizer_classes:
        raise NotImplementedError(f"unsupported optimizer {name!r}; expected 'adam' or 'sgd'")
    return optimizer_classes[name](params, lr=lr, weight_decay=weight_decay)


# One optimizer per expert, in the same order as `experts`.
optimizers_E = [
    _make_optimizer(optimizer_experts, expert.parameters(), learning_rate_expert, weight_decay)
    for expert in experts
]
optimizer_D = _make_optimizer(optimizer_discriminator, discriminator.parameters(),
                              learning_rate_discriminator, weight_decay)
