In [1]:
%load_ext autoreload
%autoreload 2

import gym
import iglu
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output
from PIL import Image
import numpy as np
import torch
from torch.nn.functional import one_hot

from iglu.tasks import RandomTasks



In [2]:
TRAINING_DATASET_SIZE = int(5e3)
VALIDATION_DATASET_SIZE = int(1e3)

MAX_BLOCKS = 30
MAX_COLORS = 6
MAX_DIST = 5

rt = RandomTasks(max_blocks=MAX_BLOCKS,
                 max_dist=MAX_DIST,
                 num_colors=MAX_COLORS)


def sample_target_grids(task_sampler, n_samples):
    """Sample ``n_samples`` tasks and collect their ``target_grid`` values.

    Args:
        task_sampler: object with a ``sample()`` method returning a task that
            exposes ``target_grid`` (here a ``RandomTasks`` instance).
        n_samples: number of tasks to draw.

    Returns:
        list of ``target_grid`` values, in sampling order.

    NOTE(review): train and validation are drawn independently from the same
    sampler, so identical grids may appear in both splits.
    """
    return [task_sampler.sample().target_grid for _ in tqdm(range(n_samples))]


print("Generation of training dataset")
training_target_list = sample_target_grids(rt, TRAINING_DATASET_SIZE)

print("Generation of validation dataset")
validation_target_list = sample_target_grids(rt, VALIDATION_DATASET_SIZE)

Generation of training dataset


  0%|          | 0/5000 [00:00<?, ?it/s]

Generation of validation dataset


  0%|          | 0/1000 [00:00<?, ?it/s]

In [7]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class TargetDataset(Dataset):
    """Map-style dataset of integer voxel target grids for autoencoder training.

    Each item is ``(one_hot_input, target)``:
      * ``target`` is the grid as a ``torch.long`` tensor (the label for
        ``nn.CrossEntropyLoss``);
      * ``one_hot_input`` is its one-hot encoding with the class axis moved to
        the front, i.e. channels-first as ``nn.Conv3d`` expects.

    Args:
        target_list: sequence of integer grids (anything ``torch.tensor``
            accepts). Every value must lie in ``[0, num_classes)`` or
            ``one_hot`` raises.
        num_classes: number of one-hot classes. Defaults to 7, presumably
            ``MAX_COLORS + 1`` (empty cell plus 6 colors) -- TODO confirm.
    """

    def __init__(self, target_list, num_classes=7):
        self.target_list = target_list
        self.num_classes = num_classes

    def __len__(self):
        return len(self.target_list)

    def __getitem__(self, idx):
        # Build the long tensor once; it is both the label and the one_hot source.
        target_tensor_target = torch.tensor(self.target_list[idx], dtype=torch.long)
        # one_hot appends the class axis last; movedim brings it to the front for
        # grids of any rank (same as permute(3, 0, 1, 2) for 3-D grids).
        target_tensor_input = one_hot(target_tensor_target, num_classes=self.num_classes).movedim(-1, 0)
        return target_tensor_input, target_tensor_target
    
# Wrap the sampled target grids into map-style datasets (train, then validation).
training_dataset, validation_dataset = (
    TargetDataset(training_target_list),
    TargetDataset(validation_target_list),
)

In [8]:
# Batch size shared by both loaders.
BATCH_SIZE = 64

# Only the training loader is shuffled. Evaluation order does not change the
# validation loss, and a fixed order keeps evaluation deterministic.
train_dataloader = DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [15]:
# Model definition (encoder, decoder, autoencoder), optimizer and loss
from tqdm.notebook import tqdm

import torch
from torch import nn
from torch import optim

# Use the GPU when present; fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class TargetEncoder(nn.Module):
    """3-D CNN encoder mapping a one-hot voxel grid to a flat feature vector.

    Args:
        features_dim: size of the output feature vector.
        in_channels: number of one-hot channels in the input (default 7).
        input_shape: spatial size ``(D, H, W)`` of the input grid. Defaults to
            ``(9, 11, 11)``, the shape this network was designed for.

    Input: float tensor of shape ``(batch, in_channels, *input_shape)``.
    Output: tensor of shape ``(batch, features_dim)``, non-negative because
    of the final ReLU.
    """

    def __init__(self, features_dim=512, in_channels=7, input_shape=(9, 11, 11)):
        super().__init__()

        # The two unpadded 3x3x3 convs each shrink every spatial dim by 2, and the
        # padded one keeps it, so the conv output is (64, D-4, H-4, W-4).
        self.cnn = nn.Sequential(
            nn.Conv3d(in_channels, 32, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv3d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.Flatten()
        )

        # Derive the flattened size from a dummy pass instead of hard-coding it
        # (64 * 5 * 7 * 7 = 15680 for the default input shape).
        with torch.no_grad():
            n_flat = self.cnn(torch.zeros(1, in_channels, *input_shape)).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flat, features_dim), nn.ReLU())

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, in_channels, D, H, W)`` into ``(batch, features_dim)``."""
        return self.linear(self.cnn(x))


class TargetDecoder(nn.Module):
    """3-D transposed-CNN decoder mapping a feature vector to per-voxel class logits.

    Args:
        features_dim: size of the input feature vector.
        num_classes: number of output channels, one logit per class per voxel
            (default 7).

    Output: raw (unbounded) logits of shape ``(batch, num_classes, 9, 11, 11)``,
    suitable for ``nn.CrossEntropyLoss``, which applies log-softmax itself.
    """

    # Shape the linear output is unflattened to: (channels, D, H, W).
    # 64 * 5 * 7 * 7 = 15680 matches TargetEncoder's conv output for a (9, 11, 11) grid.
    LATENT_SHAPE = (64, 5, 7, 7)

    def __init__(self, features_dim=512, num_classes=7):
        super().__init__()
        channels, depth, height, width = self.LATENT_SHAPE
        self.linear = nn.Sequential(nn.Linear(features_dim, channels * depth * height * width))

        # Spatial growth: +2 (k=3, p=0), +0 (k=3, p=1), +2 (k=3, p=0), so (5, 7, 7) -> (9, 11, 11).
        # No output activation: CrossEntropyLoss expects unbounded logits. A Tanh here
        # capped every logit to [-1, 1], so the predicted probability of a voxel's class
        # could never exceed ~0.55 with 7 classes.
        self.cnn = nn.Sequential(
            nn.ConvTranspose3d(channels, 64, kernel_size=3),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(32, num_classes, kernel_size=3),
        )

    def forward(self, x):
        """Decode ``x`` of shape ``(batch, features_dim)`` into ``(batch, num_classes, 9, 11, 11)`` logits."""
        x = self.linear(x)
        x = x.reshape(x.shape[0], *self.LATENT_SHAPE)
        return self.cnn(x)

class TargetAutoencoder(nn.Module):
    """Convolutional autoencoder over one-hot target grids.

    ``TargetEncoder`` compresses the input to a ``features_dim`` vector, and
    ``TargetDecoder`` expands it back to a grid with one output channel per class.

    Args:
        features_dim: size of the latent vector. Defaults to 2024, the value
            previously hard-coded here.
    """

    def __init__(self, features_dim=2024):
        super().__init__()
        self.encoder = TargetEncoder(features_dim)
        self.decoder = TargetDecoder(features_dim)

    def forward(self, x):
        """Encode ``x`` to the latent vector and decode it back."""
        return self.decoder(self.encoder(x))
    

# Build the model and move it to the selected device before handing its
# parameters to the optimizer. CrossEntropyLoss compares the decoder's
# 7-channel output against the integer target grid.
target_autoencoder = TargetAutoencoder()
target_autoencoder = target_autoencoder.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(target_autoencoder.parameters(), lr=1e-4)

In [16]:
import wandb

# Pass the run name by keyword: the first positional parameter of wandb.init is
# not the run name, so the original call never set it.
wandb.init(name="Target Autoencoder")


def run_epoch(model, loader, loss_fn, optimizer=None):
    """Make one pass over ``loader`` and return the sample-weighted mean loss.

    With an ``optimizer`` the model is put in train mode and updated after each
    batch. Without one, it is put in eval mode with autograd disabled, so no
    graph is built during evaluation.

    Each batch loss is weighted by its batch size, so a short final batch does
    not skew the epoch mean. This assumes ``loss_fn`` returns a per-batch mean,
    as ``nn.CrossEntropyLoss`` does by default. Returns ``nan`` for an empty loader.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, n_samples = 0.0, 0
    with torch.set_grad_enabled(training):
        for target_tensor_input, target_tensor_target in loader:
            target_tensor_input = target_tensor_input.float().to(device)
            target_tensor_target = target_tensor_target.to(device)

            predict = model(target_tensor_input)
            loss = loss_fn(predict, target_tensor_target)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = target_tensor_target.shape[0]
            loss_sum += loss.item() * batch_size
            n_samples += batch_size
    return loss_sum / n_samples if n_samples else float('nan')


EPOCHS = 50
for epoch in tqdm(range(EPOCHS)):
    train_loss = run_epoch(target_autoencoder, train_dataloader, loss_function, optimizer)
    val_loss = run_epoch(target_autoencoder, val_dataloader, loss_function)
    wandb.log({'train_loss': train_loss, 'val_loss': val_loss})

# Close the run so re-running this cell starts a fresh one.
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

  0%|          | 0/50 [00:00<?, ?it/s]