In [5]:
from torchfactor.factorization.svdnet import SVDNet
from torchfactor.experiment.experiment import Experiment
from polyu_dataset import PolyUDataset
import torch
import numpy as np

In [6]:
# Data configuration.
# IMAGE_TYPE is forwarded to PolyUDataset as `image_type`.
# IMAGE_SHAPE is the (height, width) passed as `downsample_shape` to both splits,
# so train and val are always downsampled identically.
IMAGE_TYPE = "mean"
IMAGE_SHAPE = (128, 128)
TRAIN_BATCH_SIZE = 1
VAL_BATCH_SIZE = 1
SHUFFLE_SEED = 0  # fixes the training shuffle order so Restart & Run All is reproducible

training_dataset = PolyUDataset(split_type="train", image_type=IMAGE_TYPE, downsample_shape=IMAGE_SHAPE)
training_dataloader = torch.utils.data.DataLoader(
    training_dataset, batch_size=TRAIN_BATCH_SIZE,
    shuffle=True, num_workers=0, drop_last=True,
    # A dedicated, seeded generator makes the shuffle order independent of
    # whatever else has consumed torch's global RNG.
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

# NOTE(review): only the validation split passes in_memory=False; the training split
# relies on PolyUDataset's default — confirm this asymmetry is intended.
validation_dataset = PolyUDataset(split_type='val', image_type=IMAGE_TYPE, in_memory=False, downsample_shape=IMAGE_SHAPE)
validation_dataloader = torch.utils.data.DataLoader(
    validation_dataset, batch_size=VAL_BATCH_SIZE,
    shuffle=False, num_workers=0, drop_last=False
)

In [9]:
# SVDNet as the full network for factorization can only learn a single image,
# so build a loader that yields exactly one randomly chosen training image.
SINGLE_IMAGE_SEED = 0  # change to fit a different image; fixed so reruns pick the same one

image_rng = np.random.default_rng(SINGLE_IMAGE_SEED)
# int() turns numpy's integer scalar into a plain Python int, the same type
# np.random.randint returns, so dataset indexing is unchanged.
indices = [int(image_rng.integers(len(training_dataset)))]

single_ele_dataloader = torch.utils.data.DataLoader(
    training_dataset, batch_size=TRAIN_BATCH_SIZE,
    num_workers=0, drop_last=False,
    # With a single index the sampler always yields that one index. This restricts
    # the loader to one image while `single_ele_dataloader.dataset` stays training_dataset.
    sampler=torch.utils.data.SubsetRandomSampler(indices)
)

In [11]:
# SVDNet learning to factor a single image.
INIT_SEED = 0            # seeds torch's global RNG so the network initialisation is reproducible
LEARNING_RATE = 2e-1     # Adam step size
TRAIN_EPOCHS = 500
VALIDATION_INTERVAL = 1  # forwarded to Experiment.run as train_validation_interval

torch.manual_seed(INIT_SEED)
# The (128, 128) arguments must agree with the downsample_shape used to build the datasets.
net = SVDNet(128, 128)
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

def loss(x, x_hat):
    """Sum of squared errors between ``x`` and ``x_hat``.

    Computes ``((x - x_hat) ** 2).sum()``: the squared differences are summed,
    not averaged, so the value scales with the number of pixels.

    Args:
        x: reference values (presumably the ground-truth image tensor — confirm).
        x_hat: reconstruction, broadcast-compatible with ``x``.

    Returns:
        The summed squared error (a 0-dim tensor when the inputs are tensors).
    """
    return ((x - x_hat)**2).sum()

experiment = Experiment(
    net=net, loss=loss, optimizer=optimizer,
    # The same one-image loader is used for training and validation: the goal is to
    # fit that single image, so "validation" loss tracks its reconstruction.
    train_dataloader=single_ele_dataloader, validation_dataloader=single_ele_dataloader,
    # NOTE(review): the semantics of these flags are defined by torchfactor's Experiment — confirm.
    use_eye_as_net_input=True, inputs_are_ground_truth=True
)

train_loss_over_epochs, val_loss_over_epochs = experiment.run(
    train_epochs=TRAIN_EPOCHS, train_validation_interval=VALIDATION_INTERVAL
)
# run() already logs every epoch; printing both full histories would repeat hundreds
# of values, so only the final entries are shown.
# NOTE(review): assumes run() returns indexable per-epoch loss histories — confirm.
print("final train loss:", train_loss_over_epochs[-1], "| final val loss:", val_loss_over_epochs[-1])

epoch 0: total loss is 14044.435546875, avg loss is 14044.435546875
epoch 0: val avg loss is 13878.0263671875
epoch 1: total loss is 13878.0263671875, avg loss is 13878.0263671875
epoch 1: val avg loss is 13786.6708984375
epoch 2: total loss is 13786.6708984375, avg loss is 13786.6708984375
epoch 2: val avg loss is 13707.44140625
epoch 3: total loss is 13707.44140625, avg loss is 13707.44140625
epoch 3: val avg loss is 13620.70703125
epoch 4: total loss is 13620.70703125, avg loss is 13620.70703125
epoch 4: val avg loss is 13564.42578125
epoch 5: total loss is 13564.42578125, avg loss is 13564.42578125
epoch 5: val avg loss is 13498.1572265625
epoch 6: total loss is 13498.1572265625, avg loss is 13498.1572265625
epoch 6: val avg loss is 13405.4990234375
epoch 7: total loss is 13405.4990234375, avg loss is 13405.4990234375
epoch 7: val avg loss is 13331.5419921875
epoch 8: total loss is 13331.5419921875, avg loss is 13331.5419921875
epoch 8: val avg loss is 13259.07421875
epoch 9: total