# train_labels.py
# Forked from i-deal/MLR-2.0.
import os

# NOTE: the star import stays first on purpose; the explicit imports below
# must win if label_network happens to export the same names.
from label_network import *
import torch
from mVAE import vae, load_checkpoint
from torch.utils.data import DataLoader, ConcatDataset
from dataset_builder import Dataset
# Load the epoch-150 checkpoint through mVAE.load_checkpoint before training.
# NOTE(review): presumably this restores the pretrained mVAE weights in place;
# confirm against mVAE.load_checkpoint.
load_checkpoint('output_emnist_recurr/checkpoint_150.pth')

# Optional warm start from an earlier label-network checkpoint (disabled).
#load_checkpoint_shapelabels('output_label_net/checkpoint_shapelabels5.pth')

bs = 50  # mini-batch size handed to the DataLoader
transforms = {'colorize': True}

emnist_dataset = Dataset('emnist', transforms)
mnist_dataset = Dataset('mnist', transforms)

# mnist_dataset is listed twice, so each of its samples appears twice in the
# concatenated training set.
combined_dataset = ConcatDataset([emnist_dataset, mnist_dataset, mnist_dataset])

train_loader_noSkip = DataLoader(
    dataset=combined_dataset,
    batch_size=bs,
    shuffle=True,
    drop_last=True,  # drop the trailing batch when it is smaller than bs
)
# Train the label networks for 20 epochs and snapshot them at selected epochs.
#
# train_labels, vae_shape_labels, vae_color_labels, optimizer_shapelabels and
# optimizer_colorlabels are not defined in this file.
# NOTE(review): presumably they come from `from label_network import *`; confirm.

# torch.save does not create missing parent directories; make sure the output
# directory exists up front so the first checkpoint write cannot fail on it.
os.makedirs('output_label_net', exist_ok=True)

for epoch in range(1, 21):
    train_labels(epoch, train_loader_noSkip)

    if epoch in (1, 5, 10, 20):
        # One snapshot holding both label networks and both optimizer states.
        checkpoint = {
            'state_dict_shape_labels': vae_shape_labels.state_dict(),
            'state_dict_color_labels': vae_color_labels.state_dict(),
            'optimizer_shape': optimizer_shapelabels.state_dict(),
            'optimizer_color': optimizer_colorlabels.state_dict(),
        }
        # The identical dict is written under both file names, shape first,
        # so either file can be used to restore either network.
        for prefix in ('checkpoint_shapelabels', 'checkpoint_colorlabels'):
            torch.save(checkpoint, f'output_label_net/{prefix}{epoch}.pth')