In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import torchtry as tt
from torchtry.logging import tensor_0_1_to_0_255

In [2]:
# Configuration: everything a reader might want to tune lives here.
STORAGE_DIR = '.'  # root directory handed to torchtry for its storage
SEED = 0           # seeds torch's global RNG (weight init, random data, shuffling)

tt.set_storage_dir(STORAGE_DIR)
# The model init, the torch.rand dataset and DataLoader(shuffle=True) all draw
# from torch's default generator, so seeding here makes Restart & Run All
# reproducible. The trailing ';' suppresses the returned Generator's repr.
torch.manual_seed(SEED);

In [3]:
class ExampleModel(nn.Module):
    """Toy image-to-image model built from two fully connected layers.

    Each image is flattened, projected down to a
    ``channels * bottleneck_size * bottleneck_size`` code, projected back up
    to ``channels * image_size * image_size`` features, and reshaped to
    ``(channels, image_size, image_size)``.

    Note: there is no nonlinearity between the two ``nn.Linear`` layers, so
    the stack is equivalent to a single low-rank linear map, and its output
    is unbounded (not restricted to [0, 1]).

    Args:
        channels: Number of image channels (default 3).
        image_size: Height and width of the square input images (default 128).
        bottleneck_size: Side length of the square bottleneck code (default 16).
    """

    def __init__(self, channels=3, image_size=128, bottleneck_size=16):
        super().__init__()
        self.channels = channels
        self.image_size = image_size
        image_features = channels * image_size * image_size
        code_features = channels * bottleneck_size * bottleneck_size
        # Keep the attribute name `fc` and the Sequential layout so state_dict
        # keys (fc.0.*, fc.1.*) match checkpoints saved by the original model.
        self.fc = nn.Sequential(
            nn.Linear(image_features, code_features),
            nn.Linear(code_features, image_features),
        )

    def forward(self, x):
        """Map a batch of images to reconstructed images of the same layout.

        Args:
            x: Tensor shaped ``(batch, channels, image_size, image_size)``.

        Returns:
            Tensor shaped ``(batch, channels, image_size, image_size)``.
        """
        # flatten(1) also handles non-contiguous inputs, where .view() raises.
        x_flat = x.flatten(1)
        x_fc = self.fc(x_flat)
        # nn.Linear output is contiguous, so .view() is safe here.
        return x_fc.view(x.size(0), self.channels, self.image_size, self.image_size)

In [4]:
# Module-level ExampleModel instance with freshly initialized weights.
example_model: ExampleModel = ExampleModel()

In [5]:
def _make_example_train_dataset():
    """Build 100 ``(image, goal)`` pairs of float tensors shaped (3, 128, 128).

    Each image is uniform noise from ``torch.rand``. Every goal is the same
    pattern: ones in the left half of the width (columns 0-63) and zeros in
    the right half (columns 64-127).
    """
    samples = []
    for _ in range(100):
        noise_image = torch.rand((3, 128, 128))
        goal = torch.zeros((3, 128, 128))
        goal[..., :64] = 1.0
        samples.append((noise_image, goal))
    return samples


example_train_dataset = _make_example_train_dataset()

In [6]:
class ExampleExperiment(tt.Experiment):
    """Minimal torchtry experiment.

    Trains an ``ExampleModel`` on ``example_train_dataset`` with MSE loss and
    SGD (lr=0.1). The dataset is batched 4 at a time and shuffled.
    """

    def setup_experiment(self):
        """Create the model, optimizer and training dataloader.

        NOTE(review): presumably invoked by ``tt.Experiment`` itself (e.g. on
        construction) — confirm against torchtry.
        """
        # Build a fresh model per experiment rather than reusing the global
        # `example_model`. Otherwise every ExampleExperiment() instance would
        # silently share (possibly already trained) weights from earlier runs.
        self.model = ExampleModel()
        self.optimizer = optim.SGD(params=self.model.parameters(), lr=1e-1)
        self.train_dataloader = DataLoader(
            example_train_dataset,
            batch_size=4,
            shuffle=True,
        )

    def train_step(self, sample, step_number):
        """Run one SGD step on ``sample`` and return values to log.

        Args:
            sample: ``(images, goals)`` batch yielded by ``self.train_dataloader``.
            step_number: Current step index (not used by this implementation).

        Returns:
            dict with:
            - ``'loss'``: detached scalar MSE loss.
            - ``'images'``, ``'predictions'``, ``'goals'``: tensors converted
              by ``tensor_0_1_to_0_255``.
            - ``'predictions_brightness'``: per-pixel channel mean of the
              predictions, flattened to 1-D.
        """
        self.model.train()
        images, goals = sample

        predictions = self.model(images)

        loss_res = nn.functional.mse_loss(predictions, goals)
        self.optimizer.zero_grad()
        loss_res.backward()
        self.optimizer.step()

        # Logged values don't need the autograd graph; detach so the logger
        # never holds references into it.
        predictions_logged = predictions.detach()
        predictions_brightness = predictions_logged.mean(1).view(-1)

        return {
            'loss': loss_res.detach(),
            'images': tensor_0_1_to_0_255(images),
            # The model output is unbounded (purely linear), so clamp it into
            # the [0, 1] range the converter's name expects before logging.
            'predictions': tensor_0_1_to_0_255(predictions_logged.clamp(0, 1)),
            'goals': tensor_0_1_to_0_255(goals),
            'predictions_brightness': predictions_brightness,
        }

In [7]:
# Instantiate the experiment that the next cell trains.
example_experiment: ExampleExperiment = ExampleExperiment()

In [None]:
# First training run, up to step 1000.
# NOTE(review): semantics of train_type='careful' are defined by torchtry — confirm.
example_experiment.train(
    train_type='careful',
    finish_step=1000,
    # Save / logging cadence. NOTE(review): the units of these *_frequency
    # values (steps, minutes, ...) are defined by torchtry — confirm.
    save_frequency=3,
    scalars_log_frequency=0.1,
    images_log_frequency=0.5,
    histograms_log_frequency=0.5,
)

Step:   5%|▌         | 52/1000 [00:18<05:58,  2.64it/s]

In [None]:
# Re-create the experiment before continuing training in the next cell.
example_experiment: ExampleExperiment = ExampleExperiment()

In [None]:
# Continue training up to step 1500.
# NOTE(review): train_type='continue' presumably resumes from state saved in
# the torchtry storage dir — confirm.
example_experiment.train(
    train_type='continue',
    finish_step=1500,
    # Same save / logging cadence as the first run. NOTE(review): units of
    # these *_frequency values are defined by torchtry — confirm.
    save_frequency=3,
    scalars_log_frequency=0.1,
    images_log_frequency=0.5,
    histograms_log_frequency=0.5,
)