In [3]:
%load_ext autoreload
%autoreload 2

import gym
import iglu
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output
from PIL import Image
import numpy as np
import torch
from torch.nn.functional import one_hot

from iglu.tasks import RandomTasks
from iglu.tasks.task_set import TaskSet

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
# List every available task preset, one "<id>: <description>" line each.
print(*(f'{k}: {v}' for k, v in TaskSet.ALL.items()), sep='\n')

C1: bell (29 blocks, 5 colors)
C2: black-hole (28 blocks, 2 colors)
C3: blue-original-L (3 blocks, 1 colors)
C4: flower_new (17 blocks, 4 colors)
C5: overlapping-chain-links (20 blocks, 2 colors)
C6: rectangle-chain (40 blocks, 4 colors)
C7: scissors (23 blocks, 2 colors)
C8: table2 (12 blocks, 2 colors)
C9: asterisk (33 blocks, 2 colors)
C10: concentric_semicircles (21 blocks, 3 colors)
C11: broken_heart (21 blocks, 2 colors)
C12: diagonal-Ls (18 blocks, 6 colors)
C13: eye (29 blocks, 3 colors)
C14: diagonal-zigzag (16 blocks, 3 colors)
C15: double_stairs (36 blocks, 4 colors)
C16: bloody-sword (24 blocks, 4 colors)
C17: orange-flat-original-L (3 blocks, 1 colors)
C18: overlapping-reticles (12 blocks, 3 colors)
C19: slide_with_arch (48 blocks, 3 colors)
C20: rainbow-lasso (9 blocks, 6 colors)
C21: spectacles (29 blocks, 2 colors)
C22: smiley (6 blocks, 5 colors)
C23: suspension_bridge (68 blocks, 4 colors)
C24: cup (46 blocks, 4 colors)
C25: music-notes (18 blocks, 2 colors)
C26: rain

In [8]:
# Build the training set: one sampled target grid per task preset, labelled by its task id.
# NOTE(review): C38 is skipped for a reason not recorded in this notebook — presumably its
# preset cannot be sampled or yields an incompatible grid; confirm before relying on it.
EXCLUDED_TASKS = {'C38'}

target_list = []  # target grids from TaskSet(preset=[key]).sample().target_grid
label_list = []   # task ids, aligned index-for-index with target_list

print("Generation of training dataset")
for key in TaskSet.ALL:
    if key in EXCLUDED_TASKS:
        continue
    target = TaskSet(preset=[key]).sample().target_grid
    target_list.append(target)
    label_list.append(key)

# Every grid must share one shape, otherwise batching / the fixed-size encoder fails later.
assert len({np.shape(t) for t in target_list}) <= 1, "target grids have inconsistent shapes"

Generation of training dataset


In [12]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class TargetDataset(Dataset):
    """Map-style dataset of one-hot encoded target grids, one class per grid.

    Each item is ``(one_hot_grid, label)``:
      * ``one_hot_grid`` is the grid one-hot encoded over ``num_classes`` values
        and permuted to channels-first, i.e. ``(num_classes, D, H, W)`` for a
        ``(D, H, W)`` grid (integer dtype, as returned by ``one_hot``).
      * ``label`` is the item's index in ``target_list`` as a scalar long tensor.
        The ids in ``label_list`` are kept for reference only and are not returned.

    Args:
        target_list: integer-valued 3-D grids; every value must lie in
            ``[0, num_classes)`` or ``one_hot`` raises.
        label_list: ids aligned with ``target_list`` (same length).
        num_classes: number of distinct cell values used for one-hot encoding.

    Raises:
        ValueError: if ``target_list`` and ``label_list`` differ in length.
    """

    def __init__(self, target_list, label_list, num_classes=7):
        if len(target_list) != len(label_list):
            raise ValueError(
                f"target_list and label_list must have the same length, "
                f"got {len(target_list)} and {len(label_list)}"
            )
        self.target_list = target_list
        self.label_list = label_list
        self.num_classes = num_classes

    def __len__(self):
        return len(self.target_list)

    def __getitem__(self, idx):
        # Convert the grid once; one_hot appends the class axis last:
        # (D, H, W) -> (D, H, W, C), then permute to channels-first (C, D, H, W).
        grid = torch.as_tensor(self.target_list[idx], dtype=torch.long)
        one_hot_grid = one_hot(grid, num_classes=self.num_classes).permute(3, 0, 1, 2)
        # The class id is simply the item's position in the dataset.
        label_tensor = torch.tensor(idx, dtype=torch.long)
        return one_hot_grid, label_tensor
    
# Wrap the sampled grids and their task ids in a map-style dataset.
training_dataset = TargetDataset(target_list=target_list, label_list=label_list)

In [14]:
# NOTE(review): 156 matches TargetClassifier's output size, so this is presumably
# one full-dataset batch per epoch — confirm if the task count changes.
train_dataloader = DataLoader(
    training_dataset,
    batch_size=156,
    shuffle=True,
)

In [25]:
# Training procedure
from tqdm.notebook import tqdm

import torch
from torch import nn
from torch import optim

# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA
# (a hard-coded 'cuda' device fails as soon as a model is moved to it).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class TargetEncoder(nn.Module):
    """3-D CNN that encodes a one-hot target grid into a flat feature vector.

    Expects input of shape ``(batch, in_channels, D, H, W)``; the defaults
    match the original ``(7, 9, 11, 11)`` grids. The two unpadded 3x3x3
    convolutions each trim 2 voxels per spatial axis and the padded one keeps
    the size, so the CNN emits ``64 * (D-4) * (H-4) * (W-4)`` values per sample
    (15680 for the defaults), which the linear layer maps to ``features_dim``.

    Args:
        features_dim: size of the output feature vector.
        in_channels: number of one-hot channels in the input.
        grid_shape: spatial size ``(D, H, W)`` of the input; every side must be > 4.

    Raises:
        ValueError: if ``grid_shape`` is not 3-D or any side is <= 4.
    """

    def __init__(self, features_dim=512, in_channels=7, grid_shape=(9, 11, 11)):
        super(TargetEncoder, self).__init__()

        # Spatial size left after the CNN (see class docstring).
        bottleneck = tuple(side - 4 for side in grid_shape)
        if len(bottleneck) != 3 or min(bottleneck) < 1:
            raise ValueError(f"grid_shape must have 3 sides, each > 4; got {grid_shape!r}")

        self.cnn = nn.Sequential(
            nn.Conv3d(in_channels, 32, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv3d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.Flatten()
        )

        # Derived instead of hard-coding 15680, so it stays correct for other grid sizes.
        flat_dim = 64 * bottleneck[0] * bottleneck[1] * bottleneck[2]
        self.linear = nn.Sequential(nn.Linear(flat_dim, features_dim), nn.ReLU())

    def forward(self, x):
        return self.linear(self.cnn(x))


class TargetDecoder(nn.Module):
    """Inverse of the encoder: maps a feature vector back to a dense grid.

    A linear layer expands the features to ``64 * (D-4) * (H-4) * (W-4)``
    values, which are reshaped to ``(batch, 64, D-4, H-4, W-4)``. Two unpadded
    3x3x3 transposed convolutions each grow every spatial axis by 2, and the
    padded one keeps the size, giving ``(batch, out_channels, D, H, W)``.
    A final Tanh bounds every output value to ``[-1, 1]``.

    Args:
        features_dim: size of the input feature vector.
        out_channels: number of channels in the reconstructed grid.
        grid_shape: spatial size ``(D, H, W)`` of the output; every side must be > 4.

    Raises:
        ValueError: if ``grid_shape`` is not 3-D or any side is <= 4.
    """

    def __init__(self, features_dim=512, out_channels=7, grid_shape=(9, 11, 11)):
        super(TargetDecoder, self).__init__()

        # Spatial size fed into the transposed-conv stack; (5, 7, 7) for the defaults.
        self.bottleneck = tuple(side - 4 for side in grid_shape)
        if len(self.bottleneck) != 3 or min(self.bottleneck) < 1:
            raise ValueError(f"grid_shape must have 3 sides, each > 4; got {grid_shape!r}")

        flat_dim = 64 * self.bottleneck[0] * self.bottleneck[1] * self.bottleneck[2]
        self.linear = nn.Sequential(nn.Linear(features_dim, flat_dim))

        self.cnn = nn.Sequential(
            nn.ConvTranspose3d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(32, out_channels, kernel_size=3),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], 64, *self.bottleneck)
        return self.cnn(x)

class TargetAutoencoder(nn.Module):
    """Encoder/decoder pair that reconstructs a target grid through a 2024-d bottleneck."""

    # NOTE(review): 2024 may be a typo for 2048; kept unchanged here.

    def __init__(self):
        super().__init__()
        bottleneck_dim = 2024
        self.encoder = TargetEncoder(bottleneck_dim)
        self.decoder = TargetDecoder(bottleneck_dim)

    def forward(self, x):
        return self.decoder(self.encoder(x))
    
class TargetClassifier(nn.Module):
    """Classifies a one-hot target grid into one of ``num_classes`` ids.

    A ``TargetEncoder`` produces ``features_dim`` features, and a linear head
    maps them to raw logits, suitable for ``nn.CrossEntropyLoss``.

    Args:
        num_classes: number of output logits. The default 156 keeps the original
            behaviour; presumably it must equal the training dataset size, since
            each grid is its own class — confirm when the task set changes.
        features_dim: width of the encoder output / classifier input.
    """

    def __init__(self, num_classes=156, features_dim=1024):
        super(TargetClassifier, self).__init__()
        self.encoder = TargetEncoder(features_dim)
        self.linear = nn.Linear(features_dim, num_classes)

    def forward(self, x):
        return self.linear(self.encoder(x))
    

# Objects for the classification run (TargetAutoencoder is not instantiated in this notebook).
target_classifier = TargetClassifier().to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(target_classifier.parameters(), lr=1e-4)

In [26]:
import wandb

# Train the classifier and log per-epoch loss/accuracy to Weights & Biases.
# NOTE(review): in current wandb releases the first positional parameter of wandb.init
# is `job_type`, so "Target Classifier" is recorded as the job type, not the project or
# run name — pass project=... / name=... explicitly if that was the intent.
wandb.init("Target Classifier")

EPOCHS = 5000

if len(train_dataloader) == 0:
    raise RuntimeError("train_dataloader yields no batches; nothing to train on")

for epoch in tqdm(range(EPOCHS)):
    # Training mode for the model actually being optimised. The previous version called
    # .train() on `target_autoencoder`, which this notebook never instantiates.
    target_classifier.train()

    # Sample-weighted running totals, so a smaller final batch does not skew the epoch
    # averages the way an unweighted mean over per-batch values would.
    running_loss = 0.0
    running_correct = 0
    n_seen = 0
    for target_tensor_input, label_tensor in train_dataloader:
        # one_hot produces integer tensors; Conv3d needs floating-point input.
        target_tensor_input = target_tensor_input.float().to(device)
        label_tensor = label_tensor.to(device)

        optimizer.zero_grad()
        predict = target_classifier(target_tensor_input)
        loss = loss_function(predict, label_tensor)
        loss.backward()
        optimizer.step()

        batch_size = label_tensor.shape[0]
        # CrossEntropyLoss averages over the batch by default; re-weight to a sum.
        running_loss += loss.item() * batch_size
        running_correct += (predict.argmax(dim=1) == label_tensor).sum().item()
        n_seen += batch_size

    train_loss = running_loss / n_seen
    train_acc = running_correct / n_seen

    wandb.log({'train_loss': train_loss, 'train_acc': train_acc})

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_acc,1.0
train_loss,0.0


0,1
train_acc,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,██▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁


  0%|          | 0/5000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [10]:
from torch import nn

# Exploratory stack: five unpadded 3x3x3 convs (7 -> 32, then 32 -> 32 four times),
# ReLU between consecutive convs, flattened at the end. Layer order (and thus module
# indices 0..9) matches a hand-written Sequential of the same layers.
model = nn.Sequential(
    nn.Conv3d(7, 32, kernel_size=3),
    *(
        layer
        for _ in range(4)
        for layer in (nn.ReLU(), nn.Conv3d(32, 32, kernel_size=3))
    ),
    nn.Flatten(),
)

In [11]:
# Total number of trainable scalars in the exploratory model.
sum(param.numel() for param in model.parameters())

116800