In [1]:
import random

import torch
from torchvision import transforms

from torch.utils.tensorboard import SummaryWriter

In [2]:
# Global settings
seed = 2453466               # passed to torch.manual_seed when the device is set up
checkpt_dir = "checkpoints"  # directory that gets created for checkpoints
dataset = "MNIST"            # forwarded to the Expert / Discriminator constructors
num_experts = 4              # how many Expert models are instantiated
input_size = 28 * 28         # presumably one flattened 28x28 image -- TODO confirm

# Hyper-parameters for the expert initialization phase
load_initialized_experts = False  # only referenced by the commented-out init code
init_expert_model = "blockmodel"  # only referenced by the commented-out init code
init_optimizer = "sgd"
init_learning_rate = 0.01
init_weight_decay = 0
init_epochs = 1

# Hyper-parameters for the main training phase
discriminator_optimizer = "sgd"
discriminator_learning_rate = 0.01
discriminator_weight_decay = 0

expert_optimizer = "sgd"
expert_learning_rate = 0.01
expert_weight_decay = 0

epochs = 1  # number of iterations of the training loop

In [3]:
# Build the MNIST train/test datasets and wrap each one in a DataLoader
from torch.utils.data import DataLoader
from dataset import MNISTDataset

# Both splits are given the same transformer name twice.
train_dataset = MNISTDataset(
    train=True,
    transformer_names=["rotate_left", "rotate_left"],
)
test_dataset = MNISTDataset(
    train=False,
    transformer_names=["rotate_left", "rotate_left"],
)

# Batches of 64 for both loaders; only the training loader shuffles.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)

In [4]:
# Select the training device and seed the random number generators.
# Each generator is seeded exactly once: Python's `random` module (for any
# stdlib-based randomness in helper code), torch's default generator, and,
# when CUDA is available, every CUDA device.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
random.seed(seed)
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed_all(seed)

In [5]:
import os

In [6]:
# Directory for checkpoints
# makedirs(..., exist_ok=True) creates any missing parent directories and does
# nothing when the directory is already there, so there is no window between
# an existence check and the creation in which another process could interfere.
os.makedirs(checkpt_dir, exist_ok=True)

In [7]:
# Instantiate `num_experts` Expert models, each moved to the training device.
# They are constructed with the init_* optimizer settings.
from model import Expert, Discriminator

experts = [
    Expert(
        dataset=dataset,
        input_size=input_size,
        optim=init_optimizer,
        lr=init_learning_rate,
        weight_decay=init_weight_decay,
    ).to(device)
    for _ in range(num_experts)
]

In [8]:
# Loss functions (both average over all elements)
# Mean-squared error; only referenced by the commented-out expert initialization.
loss_initial = torch.nn.MSELoss(reduction="mean")
# Binary cross-entropy; this is the criterion handed to train_system.
criterion = torch.nn.BCELoss(reduction="mean")

In [9]:
# Disabled TensorBoard logging setup: every line below is commented out, so
# this cell currently does nothing. If re-enabled it would create
# results/logs/test and open a SummaryWriter on that directory.
# NOTE(review): os.mkdir does not create missing parent directories, so
# 'results' would have to exist already, and the second os.mkdir would raise
# if results/logs/test is already present -- confirm before re-enabling.
# outdir = 'results'
# name = 'test'
# log_dir = os.path.join(outdir, 'logs')
# if not os.path.exists(log_dir):
#     os.mkdir(log_dir)
# log_dir_exp = os.path.join(log_dir, name)
# os.mkdir(log_dir_exp)
# writer = SummaryWriter(log_dir=log_dir_exp)

In [10]:
# Disabled smoke test for the SummaryWriter: would log the value i + 4 at
# step i under the tag "num" for i in 0..99. It needs `writer`, which is only
# created by the (also disabled) logging setup.
# for i in range(100):
#     writer.add_scalar(f"num", i+4, i)

In [11]:
# Disabled: would close the SummaryWriter created by the disabled logging setup.
# writer.close()

In [12]:
# Initialize Experts as approximately Identity on Transformed Data
# NOTE(review): this step is currently disabled. The loop body below is only
# `pass`, so the experts are left unchanged here and `initialize_expert` is
# imported but never called.
from trainer import initialize_expert


for i, expert in enumerate(experts):
    pass
    # Disabled logic, kept for reference: either load previously initialized
    # weights for expert i from checkpt_dir, or run initialize_expert on the
    # training loader with the MSE loss.
    # NOTE(review): `model_for_initialized_experts` and `init_weights` are not
    # defined or imported in the visible cells; they must be provided before
    # the load branch can be re-enabled.
    # if load_initialized_experts:
    #     path = os.path.join(checkpt_dir, f'{model_for_initialized_experts}_E_{i+1}_init.pth')
    #     init_weights(expert, path)
    # else:
    #     initialize_expert(
    #         epochs=init_epochs,
    #         architecture_name=init_expert_model,
    #         expert=expert,
    #         i=i,
    #         loss=loss_initial,
    #         data_train=train_loader,
    #         device=device,
    #         checkpt_dir=checkpt_dir,
    #     )

In [13]:
# Build the discriminator with its own optimizer settings and move it to the
# training device.
discriminator = Discriminator(
    dataset=dataset,
    input_size=input_size,
    optim=discriminator_optimizer,
    lr=discriminator_learning_rate,
    weight_decay=discriminator_weight_decay,
).to(device)

# The experts were constructed with the init_* settings; hand every expert the
# training-phase optimizer settings via set_optimizer.
for expert in experts:
    expert.set_optimizer(
        optim=expert_optimizer,
        lr=expert_learning_rate,
        weight_decay=expert_weight_decay,
    )



In [14]:
# NOTE(review): `train_system` is not imported by name anywhere in the visible
# cells; presumably it is provided by this wildcard import -- an explicit
# `from trainer import train_system` would make that dependency visible.
from trainer import *

# Training
# Calls train_system once per epoch with the 0-based epoch index, the list of
# experts, the discriminator, the BCE criterion, the training loader, the
# input size and the device.
for epoch in range(epochs):
    train_system(epoch, experts, discriminator, criterion, train_loader, input_size, device)
    # Disabled checkpointing: would save the discriminator and every expert's
    # state_dict into checkpt_dir at a logging interval and on the last epoch.
    # NOTE(review): it refers to `args` (log_interval, epochs, name,
    # num_experts), which is not defined in the visible cells; the notebook's
    # own settings would have to be substituted before re-enabling.
    # if epoch % args.log_interval == 0 or epoch == args.epochs-1:
        # torch.save(discriminator.state_dict(), checkpt_dir + '/{}_D.pth'.format(args.name))
        # for i in range(args.num_experts):
            # torch.save(experts[i].state_dict(), checkpt_dir + '/{}_E_{}.pth'.format(args.name, i+1))


epoch [1] loss_D_transformed 0.6813
epoch [1] loss_D_canon 0.7164
epoch [1] loss_D_transformed 0.6829
epoch [1] loss_D_canon 0.7020
epoch [1] loss_D_transformed 0.6850
epoch [1] loss_D_canon 0.6894
epoch [1] loss_D_transformed 0.6866
epoch [1] loss_D_canon 0.6769
epoch [1] loss_D_transformed 0.6884
epoch [1] loss_D_canon 0.6657
epoch [1] loss_D_transformed 0.6901
epoch [1] loss_D_canon 0.6545
epoch [1] loss_D_transformed 0.6918
epoch [1] loss_D_canon 0.6438
epoch [1] loss_D_transformed 0.6935
epoch [1] loss_D_canon 0.6327
epoch [1] loss_D_transformed 0.6953
epoch [1] loss_D_canon 0.6222
epoch [1] loss_D_transformed 0.6969
epoch [1] loss_D_canon 0.6128
epoch [1] loss_D_transformed 0.6983
epoch [1] loss_D_canon 0.6032
epoch [1] loss_D_transformed 0.6999
epoch [1] loss_D_canon 0.5931
epoch [1] loss_D_transformed 0.7017
epoch [1] loss_D_canon 0.5835
epoch [1] loss_D_transformed 0.7033
epoch [1] loss_D_canon 0.5739
epoch [1] loss_D_transformed 0.7050
epoch [1] loss_D_canon 0.5637
epoch [1] 

KeyboardInterrupt: 