In [4]:
%load_ext autoreload
%autoreload 2

import gym
import iglu
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output
from PIL import Image
import numpy as np
import torch
from torch.nn.functional import one_hot

from iglu.tasks import RandomTasks

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
TRAINING_DATASET_SIZE = int(5e3)
VALIDATION_DATASET_SIZE = int(1e3)

MAX_BLOCKS = 30
MAX_COLORS = 6
MAX_DIST = 5

training_target_list = []
validation_target_list = []

rt = RandomTasks(max_blocks=MAX_BLOCKS, 
                 max_dist=MAX_DIST,
                 num_colors=MAX_COLORS)

print("Generation of training dataset")
for i in tqdm(range(TRAINING_DATASET_SIZE)):
    target = rt.sample().target_grid
    training_target_list.append(target)
    
print("Generation of validation dataset")
for i in tqdm(range(VALIDATION_DATASET_SIZE)):
    target = rt.sample().target_grid
    validation_target_list.append(target)

Generation of training dataset


  0%|          | 0/5000 [00:00<?, ?it/s]

Generation of validation dataset


  0%|          | 0/1000 [00:00<?, ?it/s]

In [35]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class TargetDataset(Dataset):
    def __init__(self, target_list):
        self.target_list = target_list

    def __len__(self):
        return len(self.target_list)

    def __getitem__(self, idx):
        target = target_list[idx]
        target_tensor_target = torch.tensor(target, dtype=torch.long)
        target_tensor_input = one_hot(torch.tensor(target, dtype=torch.long), num_classes=7).permute(3, 0, 1, 2)
        return target_tensor_input, target_tensor_target
    
training_dataset = TargetDataset(training_target_list)
validation_dataset = TargetDataset(validation_target_list)

In [36]:
train_dataloader = DataLoader(training_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(validation_dataset, batch_size=64, shuffle=True)

In [37]:
# Training procedure
from tqdm.notebook import tqdm

from models import TargetEncoder, TargetDecoder
import torch
from torch import nn
from torch import optim

device = torch.device('cuda')

class TargetAutoencoder(nn.Module):
    def __init__(self):
        super(TargetAutoencoder, self).__init__()
        self.encoder = TargetEncoder()
        self.decoder = TargetDecoder()
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
    

target_autoencoder = TargetAutoencoder().to(device)
optimizer = optim.Adam(target_autoencoder.parameters(), lr=1e-4)
loss_function = nn.CrossEntropyLoss()

In [38]:
import wandb

wandb.init("Target Autoencoder")

EPOCHS = 50
for epoch in tqdm(range(EPOCHS)):
    # training
    target_autoencoder.train()
    train_loss = []
    for target_tensor_input, target_tensor_target in train_dataloader:
        target_tensor_input = target_tensor_input.float().to(device)
        target_tensor_target = target_tensor_target.to(device)
    
        optimizer.zero_grad()
        predict = target_autoencoder(target_tensor_input)
        loss = loss_function(predict, target_tensor_target)
        train_loss.append(loss.item())
        loss.backward()
        optimizer.step()
    train_loss = np.array(train_loss).mean()
        
    # evaluation
    target_autoencoder.eval()
    val_loss = []
    for target_tensor_input, target_tensor_target in val_dataloader:
        target_tensor_input = target_tensor_input.float().to(device)
        target_tensor_target = target_tensor_target.to(device)
    
        predict = target_autoencoder(target_tensor_input)
        loss = loss_function(predict, target_tensor_target)
        val_loss.append(loss.item())
    val_loss = np.array(val_loss).mean()

    wandb.log({'train_loss': train_loss, 'val_loss': val_loss})

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_loss,0.00667
val_loss,0.00667
_runtime,63.0
_timestamp,1630849125.0
_step,49.0


0,1
train_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


  0%|          | 0/50 [00:00<?, ?it/s]