In [None]:
!git clone https://github.com/iglu-contest/iglu
!apt-get -qq install openjdk-8-jdk xvfb > /dev/null
!update-alternatives --config java
!sudo add-apt-repository ppa:openjdk-r/ppa
!sudo apt-get update
!sudo apt-get install openjdk-8-jdk
!sudo apt-get install xvfb
!pip uninstall -y iglu && pip install git+https://github.com/iglu-contest/iglu.git
!pip install gym==0.18.3

In [None]:
# exec this cell ONLY in colab
!wget -q https://raw.githubusercontent.com/iglu-contest/tutorials/main/env/colab_setup.sh -O - | sh > /dev/null 2>&1
!pip install -q pyvirtualdisplay
from pyvirtualdisplay import Display
disp = Display(backend="xvnc", size=(800, 600))
disp.start();
# for local notebooks instead launch jupyter as: xvfb-run -s "-screen 0 640x480x24" jupyter ...

In [2]:
import iglu
import gym
from iglu.tasks import RandomTasks
from iglu.tasks.task_set import TaskSet

#env = gym.make('IGLUSilentBuilder-v0')
#obs = env.reset()

In [4]:
!pip install sentence-transformers
import nltk
from sentence_transformers import SentenceTransformer

nltk.download('punkt')
bert_sentence = SentenceTransformer('all-distilroberta-v1')

Collecting sentence-transformers
  Downloading sentence-transformers-2.1.0.tar.gz (78 kB)
[K     |████████████████████████████████| 78 kB 1.8 MB/s eta 0:00:01
[?25hCollecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.11.3-py3-none-any.whl (2.9 MB)
[K     |████████████████████████████████| 2.9 MB 4.0 MB/s eta 0:00:01
[?25hCollecting tokenizers>=0.10.3
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 103.6 MB/s eta 0:00:01
Collecting scikit-learn
  Downloading scikit_learn-1.0.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (23.2 MB)
[K     |████████████████████████████████| 23.2 MB 8.5 MB/s eta 0:00:01
Collecting nltk
  Downloading nltk-3.6.5-py3-none-any.whl (1.5 MB)
[K     |████████████████████████████████| 1.5 MB 11.2 MB/s eta 0:00:01
[?25hCollecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manyl

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


Downloading:   0%|          | 0.00/737 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/9.86k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/653 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/15.7k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/349 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/329M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/333 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/13.1k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [5]:
targets = []
chats = []
for i in range(1,156):
  if ('C'+str(i)) == 'C38': continue
  targets.append(TaskSet(preset=['C'+str(i)]).sample().target_grid)
  chats.append(bert_sentence.encode(TaskSet(preset=['C'+str(i)]).sample().chat))

In [6]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch
from torch.nn.functional import one_hot

class TargetDataset(Dataset):
    def __init__(self, target_list, chat_list):
        self.target_list = target_list
        self.chat_list = chat_list

    def __len__(self):
        return len(self.target_list)

    def __getitem__(self, idx):
        target = self.target_list[idx]
        target_tensor_target = torch.tensor(target, dtype=torch.long)
        # target_tensor_input = one_hot(torch.tensor(target, dtype=torch.long), num_classes=7).permute(3, 0, 1, 2)
        chat_tensor = self.chat_list[idx]
        return chat_tensor, target_tensor_target
    
training_dataset = TargetDataset(targets, chats)

In [7]:
train_dataloader = DataLoader(training_dataset, batch_size=8, shuffle=True)

In [12]:
import torch
from torch import nn
from torch import optim
import numpy as np

device = torch.device('cuda')

class TargetDecoder(nn.Module):
    def __init__(self, features_dim=768):
        super(TargetDecoder, self).__init__()

        self.linear = nn.Sequential(nn.Linear(features_dim, 15680))

        self.cnn = nn.Sequential(
            nn.ConvTranspose3d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(), 
            nn.ConvTranspose3d(32, 7, kernel_size=3),
        )

    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], 64, 5, 7, 7)
        x = self.cnn(x)
        return x

    

target_decoder = TargetDecoder().to(device)
optimizer = optim.Adam(target_decoder.parameters(), lr=1e-3)
loss_function = nn.CrossEntropyLoss()

In [14]:
EPOCHS = 500
for epoch in range(EPOCHS):
  target_decoder.train()
  train_loss = []
  for target_tensor_input, target_tensor_target in train_dataloader:
    target_tensor_input = target_tensor_input.float().to(device)
    target_tensor_target = target_tensor_target.to(device)
    optimizer.zero_grad()
    predict = target_decoder(target_tensor_input)
    loss = loss_function(predict, target_tensor_target)
    train_loss.append(loss.item())
    loss.backward()
    optimizer.step()
  train_loss = np.array(train_loss).mean()
  print(f"epoch: {epoch} | loss: {train_loss}")

epoch: 0 | loss: 0.043094551749527456
epoch: 1 | loss: 0.04088730327785015
epoch: 2 | loss: 0.04016119558364153
epoch: 3 | loss: 0.039527260884642604
epoch: 4 | loss: 0.03710999926552176
epoch: 5 | loss: 0.033643096871674064
epoch: 6 | loss: 0.029962214455008507
epoch: 7 | loss: 0.02785949329845607
epoch: 8 | loss: 0.025902612414211035
epoch: 9 | loss: 0.024425185285508633
epoch: 10 | loss: 0.025513942074030637
epoch: 11 | loss: 0.02567363395355642
epoch: 12 | loss: 0.023767703352496027
epoch: 13 | loss: 0.02086513042449951
epoch: 14 | loss: 0.01964312852360308
epoch: 15 | loss: 0.01705207610502839
epoch: 16 | loss: 0.015394671354442835
epoch: 17 | loss: 0.01437988739926368
epoch: 18 | loss: 0.01568834511563182
epoch: 19 | loss: 0.013960746768862008
epoch: 20 | loss: 0.012519727065227925
epoch: 21 | loss: 0.012677266029641032
epoch: 22 | loss: 0.012790787499397993
epoch: 23 | loss: 0.011986225680448114
epoch: 24 | loss: 0.010217207274399698
epoch: 25 | loss: 0.008862713351845741
epoch:

epoch: 204 | loss: 1.4505299202483001e-05
epoch: 205 | loss: 1.4353677249800966e-05
epoch: 206 | loss: 1.4191325772117124e-05
epoch: 207 | loss: 1.4032018430043536e-05
epoch: 208 | loss: 1.3759782655142772e-05
epoch: 209 | loss: 1.3266514406495844e-05
epoch: 210 | loss: 1.3077803487249184e-05
epoch: 211 | loss: 1.2755791067320387e-05
epoch: 212 | loss: 1.2759171011111902e-05
epoch: 213 | loss: 1.311027438077872e-05
epoch: 214 | loss: 1.2023996930565772e-05
epoch: 215 | loss: 1.1779747654827589e-05
epoch: 216 | loss: 1.1685820845741546e-05
epoch: 217 | loss: 1.1456596098469163e-05
epoch: 218 | loss: 1.2457481693672889e-05
epoch: 219 | loss: 1.1070737537011155e-05
epoch: 220 | loss: 1.0835468930281423e-05
epoch: 221 | loss: 1.0635637158884492e-05
epoch: 222 | loss: 1.0442969804103086e-05
epoch: 223 | loss: 1.2521979476787237e-05
epoch: 224 | loss: 1.0212398990461224e-05
epoch: 225 | loss: 1.103255696079941e-05
epoch: 226 | loss: 9.768943107246741e-06
epoch: 227 | loss: 9.735521894072008e

epoch: 402 | loss: 8.343290630818956e-07
epoch: 403 | loss: 8.346944071035978e-07
epoch: 404 | loss: 8.752803978495649e-07
epoch: 405 | loss: 8.059441498176057e-07
epoch: 406 | loss: 8.075945778784899e-07
epoch: 407 | loss: 7.875361120568414e-07
epoch: 408 | loss: 7.607723450320236e-07
epoch: 409 | loss: 7.684085673531626e-07
epoch: 410 | loss: 7.534090876504252e-07
epoch: 411 | loss: 7.522830316020191e-07
epoch: 412 | loss: 7.359233890724682e-07
epoch: 413 | loss: 7.29234344021279e-07
epoch: 414 | loss: 7.229541743924983e-07
epoch: 415 | loss: 7.057143534439092e-07
epoch: 416 | loss: 7.147055740119868e-07
epoch: 417 | loss: 7.152756325012888e-07
epoch: 418 | loss: 7.159894096275821e-07
epoch: 419 | loss: 7.193107393277387e-07
epoch: 420 | loss: 6.718433830599224e-07
epoch: 421 | loss: 6.551442687907639e-07
epoch: 422 | loss: 6.63968984326857e-07
epoch: 423 | loss: 6.871180261214249e-07
epoch: 424 | loss: 6.275549733203434e-07
epoch: 425 | loss: 6.383589749248131e-07
epoch: 426 | loss: