In [27]:
import os
import gc
import cv2
import json
import time

import numpy as np
import pandas as pd
from pathlib import Path
from keras.utils import to_categorical

import seaborn as sns
#import plotly.express as px
from matplotlib import colors
import matplotlib.pyplot as plt
#import plotly.figure_factory as ff

import torch
T = torch.Tensor
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

In [11]:
SIZE = 1000
EPOCHS = 50
CONV_OUT_1 = 50
CONV_OUT_2 = 100
BATCH_SIZE = 128
TEST_PATH = Path('./data')
SUBMISSION_PATH = Path('../input/abstraction-and-reasoning-challenge/')

TEST_PATH = TEST_PATH / 'test'
SUBMISSION_PATH = SUBMISSION_PATH / 'sample_submission.csv'

In [17]:
test_task_files = sorted(os.listdir(TEST_PATH))

test_tasks = []
for task_file in test_task_files:
    with open(str(TEST_PATH / task_file), 'r') as f:
        task = json.load(f)
        test_tasks.append(task)

In [66]:
Xs_test, Xs_train, ys_train = [], [], []

for task in test_tasks:
    X_test, X_train, y_train = [], [], []

    for pair in task["test"]:
        X_test.append(pair["input"])

    for pair in task["train"]:
        X_train.append(pair["input"])
        y_train.append(pair["output"])
    
    Xs_test.append(X_test)
    Xs_train.append(X_train)
    ys_train.append(y_train)

In [45]:
def replace_values(a, d):
    return np.array([d.get(i, -1) for i in range(a.min(), a.max() + 1)])[a - a.min()]

def repeat_matrix(a):
    return np.concatenate([a]*((SIZE // len(a)) + 1))[:SIZE]

def get_new_matrix(X):
    if len(set([np.array(x).shape for x in X])) > 1:
        X = np.array([X[0]])
    return X

def get_outp(outp, dictionary=None, replace=True):
    if replace:
        outp = replace_values(outp, dictionary)

    outp_matrix_dims = outp.shape
    outp_probs_len = outp.shape[0]*outp.shape[1]*10
    outp = to_categorical(outp.flatten(),
                          num_classes=10).flatten()

    return outp, outp_probs_len, outp_matrix_dims

In [64]:
class ARCDataset(Dataset):
    def __init__(self, X, y, stage="train"):
        self.X = get_new_matrix(X)
        self.X = repeat_matrix(self.X)
        
        self.stage = stage
        if self.stage == "train":
            self.y = get_new_matrix(y)
            self.y = repeat_matrix(self.y)
        
    def __len__(self):
        return SIZE
    
    def __getitem__(self, idx):
        inp = self.X[idx]
        if self.stage == "train":
            outp = self.y[idx]

        if idx != 0:
            rep = np.arange(10)
            orig = np.arange(10)
            np.random.shuffle(rep)
            dictionary = dict(zip(orig, rep))
            inp = replace_values(inp, dictionary)
            #print("inp",inp)
            if self.stage == "train":
                outp, outp_probs_len, outp_matrix_dims = get_outp(outp, dictionary)
                
        if idx == 0:
            if self.stage == "train":
                outp, outp_probs_len, outp_matrix_dims = get_outp(outp, None, False)
        #print(inp, outp)
        return inp, outp, outp_probs_len, outp_matrix_dims, self.y

In [33]:
class BasicCNNModel(nn.Module):
    def __init__(self, inp_dim=(10, 10), outp_dim=(10, 10)):
        super(BasicCNNModel, self).__init__()
        
        CONV_IN = 3
        KERNEL_SIZE = 3
        DENSE_IN = CONV_OUT_2
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.dense_1 = nn.Linear(DENSE_IN, outp_dim[0]*outp_dim[1]*10)
        
        if inp_dim[0] < 5 or inp_dim[1] < 5:
            KERNEL_SIZE = 1

        self.conv2d_1 = nn.Conv2d(CONV_IN, CONV_OUT_1, kernel_size=KERNEL_SIZE)
        self.conv2d_2 = nn.Conv2d(CONV_OUT_1, CONV_OUT_2, kernel_size=KERNEL_SIZE)

    def forward(self, x, outp_dim):
        x = torch.cat([x.unsqueeze(0)]*3)
        x = x.permute((1, 0, 2, 3)).float()
        self.conv2d_1.in_features = x.shape[1]
        conv_1_out = self.relu(self.conv2d_1(x))
        self.conv2d_2.in_features = conv_1_out.shape[1]
        conv_2_out = self.relu(self.conv2d_2(conv_1_out))
        
        self.dense_1.out_features = outp_dim
        feature_vector, _ = torch.max(conv_2_out, 2)
        feature_vector, _ = torch.max(feature_vector, 2)
        logit_outputs = self.dense_1(feature_vector)
        
        out = []
        for idx in range(logit_outputs.shape[1]//10):
            out.append(self.softmax(logit_outputs[:, idx*10: (idx+1)*10]))
        return torch.cat(out, axis=1)

In [49]:
def transform_dim(inp_dim, outp_dim, test_dim):
    return (test_dim[0]*outp_dim[0]/inp_dim[0],
            test_dim[1]*outp_dim[1]/inp_dim[1])

def resize(x, test_dim, inp_dim):
    if inp_dim == test_dim:
        return x
    else:
        return cv2.resize(flt(x), inp_dim,
                          interpolation=cv2.INTER_AREA)

def flt(x): return np.float32(x)
def npy(x): return x.cpu().detach().numpy()
def itg(x): return np.int32(np.round(x))

In [65]:
idx = 0
start = time.time()
test_predictions = []

for X_train, y_train in zip(Xs_train, ys_train):
    print("TASK " + str(idx + 1))
    train_set = ARCDataset(X_train, y_train, stage="train")
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=False)

    inp_dim = np.array(X_train[0]).shape
    outp_dim = np.array(y_train[0]).shape
    network = BasicCNNModel(inp_dim, outp_dim).cuda()
    optimizer = Adam(network.parameters(), lr=0.01)
    
    for epoch in range(EPOCHS):
        for train_batch in train_loader:
            print(train_batch)
            train_X, train_y, out_d, d, out = train_batch
            train_preds = network.forward(train_X.cuda(), out_d.cuda())
            #print(train_X.size(), train_y.size())
            train_loss = nn.MSELoss()(train_preds, train_y.cuda())
            
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

    end = time.time()        
    print("Train loss: " + str(np.round(train_loss.item(), 3)) + "   " +\
          "Total time: " + str(np.round(end - start, 1)) + " s" + "\n")
    
    X_test = np.array([resize(flt(X), np.shape(X), inp_dim) for X in Xs_test[idx-1]])
    for X in X_test:
        test_dim = np.array(T(X)).shape
        test_preds = npy(network.forward(T(X).unsqueeze(0).cuda(), out_d.cuda()))
        test_preds = np.argmax(test_preds.reshape((10, *outp_dim)), axis=0)
        test_predictions.append(itg(resize(test_preds, np.shape(test_preds),
                                           tuple(itg(transform_dim(inp_dim,
                                                                   outp_dim,
                                                                   test_dim))))))
    idx += 1

TASK 1
[tensor([[[8, 6],
         [6, 4]],

        [[3, 2],
         [8, 5]],

        [[6, 8],
         [8, 0]],

        [[4, 0],
         [5, 2]],

        [[5, 9],
         [9, 6]],

        [[4, 1],
         [2, 3]],

        [[8, 1],
         [1, 5]],

        [[2, 3],
         [8, 6]],

        [[9, 0],
         [0, 5]],

        [[5, 7],
         [1, 4]],

        [[6, 9],
         [9, 7]],

        [[1, 7],
         [6, 5]],

        [[5, 2],
         [2, 6]],

        [[8, 1],
         [3, 0]],

        [[2, 1],
         [1, 0]],

        [[4, 6],
         [2, 3]],

        [[7, 1],
         [1, 9]],

        [[6, 3],
         [9, 2]],

        [[6, 1],
         [1, 9]],

        [[6, 5],
         [7, 2]],

        [[1, 2],
         [2, 3]],

        [[0, 1],
         [5, 8]],

        [[3, 5],
         [5, 4]],

        [[5, 3],
         [2, 4]],

        [[4, 5],
         [5, 3]],

        [[2, 9],
         [5, 7]],

        [[3, 1],
         [1, 5]],

        [[0, 5],
   

[tensor([[[5, 3],
         [3, 8]],

        [[9, 1],
         [8, 0]],

        [[1, 9],
         [9, 8]],

        [[2, 9],
         [1, 5]],

        [[0, 9],
         [9, 3]],

        [[4, 8],
         [1, 9]],

        [[6, 1],
         [1, 3]],

        [[9, 5],
         [3, 0]],

        [[0, 5],
         [5, 2]],

        [[0, 4],
         [5, 2]],

        [[0, 5],
         [5, 1]],

        [[9, 4],
         [8, 1]],

        [[3, 4],
         [4, 2]],

        [[9, 7],
         [2, 6]],

        [[0, 3],
         [3, 6]],

        [[1, 9],
         [3, 4]],

        [[7, 6],
         [6, 1]],

        [[4, 8],
         [2, 7]],

        [[2, 4],
         [4, 6]],

        [[4, 5],
         [2, 8]],

        [[5, 3],
         [3, 2]],

        [[7, 9],
         [0, 4]],

        [[6, 9],
         [9, 5]],

        [[9, 5],
         [4, 1]],

        [[0, 1],
         [1, 9]],

        [[2, 9],
         [6, 8]],

        [[0, 2],
         [2, 8]],

        [[6, 4],
         [

[tensor([[[5, 7],
         [7, 0]],

        [[8, 9],
         [0, 4]],

        [[6, 1],
         [1, 0]],

        [[3, 0],
         [5, 7]],

        [[6, 2],
         [2, 1]],

        [[9, 6],
         [8, 3]],

        [[8, 0],
         [0, 9]],

        [[9, 7],
         [8, 3]],

        [[8, 9],
         [9, 2]],

        [[5, 2],
         [9, 6]],

        [[0, 2],
         [2, 3]],

        [[4, 7],
         [2, 3]],

        [[7, 2],
         [2, 8]],

        [[4, 3],
         [7, 6]],

        [[1, 6],
         [6, 2]],

        [[7, 0],
         [2, 6]],

        [[8, 4],
         [4, 5]],

        [[0, 5],
         [9, 7]],

        [[4, 7],
         [7, 3]],

        [[2, 9],
         [7, 0]],

        [[9, 3],
         [3, 0]],

        [[4, 7],
         [3, 6]],

        [[3, 1],
         [1, 5]],

        [[8, 5],
         [0, 1]],

        [[9, 0],
         [0, 7]],

        [[6, 7],
         [9, 8]],

        [[8, 2],
         [2, 4]],

        [[6, 3],
         [

[tensor([[[4, 3],
         [3, 7]],

        [[6, 0],
         [4, 9]],

        [[1, 6],
         [6, 7]],

        [[5, 9],
         [4, 2]],

        [[9, 2],
         [2, 4]],

        [[2, 1],
         [0, 7]],

        [[4, 7],
         [7, 0]],

        [[6, 2],
         [0, 9]],

        [[4, 0],
         [0, 9]],

        [[5, 6],
         [4, 7]],

        [[8, 7],
         [7, 1]],

        [[4, 8],
         [9, 1]],

        [[3, 0],
         [0, 2]],

        [[0, 5],
         [4, 2]],

        [[3, 1],
         [1, 6]],

        [[2, 4],
         [1, 5]],

        [[1, 9],
         [9, 2]],

        [[1, 4],
         [2, 0]],

        [[1, 0],
         [0, 4]],

        [[4, 8],
         [3, 1]],

        [[5, 2],
         [2, 6]],

        [[5, 9],
         [0, 4]],

        [[1, 3],
         [3, 5]],

        [[7, 6],
         [2, 1]],

        [[9, 8],
         [8, 5]],

        [[1, 9],
         [4, 8]],

        [[4, 5],
         [5, 9]],

        [[2, 0],
         [

[tensor([[[0, 6],
         [6, 3]],

        [[0, 5],
         [2, 8]],

        [[9, 5],
         [5, 6]],

        [[7, 8],
         [9, 0]],

        [[5, 0],
         [0, 7]],

        [[9, 5],
         [3, 8]],

        [[9, 8],
         [8, 5]],

        [[1, 3],
         [4, 9]],

        [[3, 5],
         [5, 0]],

        [[9, 8],
         [3, 7]],

        [[0, 3],
         [3, 8]],

        [[7, 8],
         [1, 0]],

        [[3, 0],
         [0, 8]],

        [[7, 1],
         [6, 3]],

        [[5, 0],
         [0, 3]],

        [[7, 5],
         [9, 1]],

        [[0, 8],
         [8, 1]],

        [[4, 2],
         [9, 7]],

        [[1, 5],
         [5, 7]],

        [[4, 1],
         [3, 9]],

        [[7, 4],
         [4, 3]],

        [[2, 1],
         [5, 3]],

        [[0, 1],
         [1, 6]],

        [[1, 5],
         [0, 8]],

        [[1, 6],
         [6, 8]],

        [[3, 4],
         [9, 1]],

        [[7, 1],
         [1, 3]],

        [[7, 9],
         [

[tensor([[[2, 6],
         [6, 7]],

        [[0, 9],
         [1, 2]],

        [[2, 3],
         [3, 1]],

        [[3, 4],
         [8, 7]],

        [[6, 1],
         [1, 2]],

        [[7, 4],
         [2, 0]],

        [[2, 3],
         [3, 0]],

        [[3, 7],
         [8, 6]],

        [[5, 1],
         [1, 2]],

        [[1, 0],
         [5, 8]],

        [[4, 0],
         [0, 8]],

        [[8, 3],
         [5, 7]],

        [[6, 1],
         [1, 8]],

        [[6, 7],
         [4, 5]],

        [[6, 5],
         [5, 1]],

        [[7, 9],
         [5, 6]],

        [[3, 6],
         [6, 0]],

        [[4, 6],
         [3, 5]],

        [[7, 1],
         [1, 9]],

        [[5, 6],
         [0, 1]],

        [[7, 4],
         [4, 0]],

        [[7, 2],
         [1, 9]],

        [[4, 9],
         [9, 8]],

        [[4, 3],
         [0, 2]],

        [[0, 2],
         [2, 4]],

        [[6, 9],
         [3, 5]],

        [[2, 9],
         [9, 5]],

        [[7, 5],
         [

[tensor([[[3, 0],
         [0, 7]],

        [[8, 9],
         [2, 4]],

        [[1, 5],
         [5, 6]],

        [[3, 8],
         [6, 5]],

        [[3, 2],
         [2, 0]],

        [[9, 5],
         [0, 6]],

        [[2, 6],
         [6, 1]],

        [[2, 8],
         [9, 0]],

        [[9, 7],
         [7, 8]],

        [[0, 5],
         [1, 6]],

        [[5, 6],
         [6, 3]],

        [[3, 4],
         [6, 7]],

        [[4, 2],
         [2, 0]],

        [[4, 0],
         [7, 8]],

        [[1, 7],
         [7, 0]],

        [[0, 4],
         [5, 2]],

        [[9, 7],
         [7, 8]],

        [[9, 0],
         [8, 2]],

        [[3, 7],
         [7, 9]],

        [[1, 5],
         [8, 4]],

        [[4, 2],
         [2, 1]],

        [[4, 9],
         [0, 7]],

        [[2, 1],
         [1, 6]],

        [[5, 3],
         [2, 0]],

        [[9, 3],
         [3, 1]],

        [[3, 5],
         [7, 0]],

        [[1, 6],
         [6, 8]],

        [[0, 6],
         [

[tensor([[[0, 4],
         [4, 3]],

        [[5, 8],
         [3, 9]],

        [[2, 9],
         [9, 0]],

        [[5, 1],
         [2, 4]],

        [[4, 7],
         [7, 1]],

        [[1, 5],
         [8, 2]],

        [[7, 2],
         [2, 1]],

        [[2, 3],
         [6, 0]],

        [[1, 4],
         [4, 8]],

        [[2, 4],
         [9, 1]],

        [[4, 1],
         [1, 2]],

        [[9, 7],
         [1, 2]],

        [[4, 5],
         [5, 8]],

        [[5, 0],
         [4, 7]],

        [[5, 0],
         [0, 7]],

        [[4, 0],
         [5, 8]],

        [[7, 3],
         [3, 8]],

        [[8, 4],
         [2, 3]],

        [[8, 6],
         [6, 3]],

        [[9, 0],
         [4, 3]],

        [[5, 3],
         [3, 9]],

        [[8, 0],
         [2, 1]],

        [[5, 9],
         [9, 0]],

        [[2, 7],
         [1, 4]],

        [[2, 3],
         [3, 8]],

        [[4, 1],
         [5, 9]],

        [[8, 2],
         [2, 1]],

        [[1, 6],
         [

[tensor([[[1, 2],
         [2, 3]],

        [[8, 6],
         [4, 9]],

        [[6, 1],
         [1, 5]],

        [[9, 2],
         [7, 3]],

        [[6, 0],
         [0, 2]],

        [[7, 8],
         [0, 2]],

        [[3, 6],
         [6, 4]],

        [[3, 4],
         [7, 6]],

        [[6, 0],
         [0, 2]],

        [[0, 3],
         [8, 7]],

        [[4, 2],
         [2, 1]],

        [[1, 4],
         [2, 6]],

        [[7, 0],
         [0, 1]],

        [[4, 8],
         [1, 7]],

        [[2, 3],
         [3, 4]],

        [[4, 8],
         [5, 9]],

        [[0, 1],
         [1, 3]],

        [[7, 1],
         [3, 5]],

        [[3, 0],
         [0, 5]],

        [[7, 4],
         [8, 6]],

        [[3, 2],
         [2, 9]],

        [[2, 6],
         [1, 9]],

        [[6, 0],
         [0, 4]],

        [[2, 0],
         [6, 8]],

        [[7, 5],
         [5, 4]],

        [[0, 2],
         [8, 9]],

        [[0, 7],
         [7, 2]],

        [[8, 5],
         [

[tensor([[[9, 1],
         [1, 7]],

        [[0, 6],
         [5, 1]],

        [[2, 4],
         [4, 7]],

        [[1, 7],
         [9, 6]],

        [[0, 3],
         [3, 2]],

        [[5, 6],
         [8, 0]],

        [[7, 2],
         [2, 1]],

        [[6, 2],
         [8, 4]],

        [[2, 4],
         [4, 5]],

        [[2, 3],
         [1, 9]],

        [[5, 0],
         [0, 8]],

        [[8, 5],
         [0, 6]],

        [[6, 4],
         [4, 1]],

        [[8, 9],
         [4, 6]],

        [[5, 6],
         [6, 9]],

        [[3, 5],
         [8, 4]],

        [[2, 7],
         [7, 3]],

        [[4, 7],
         [6, 2]],

        [[8, 5],
         [5, 7]],

        [[0, 3],
         [8, 2]],

        [[9, 8],
         [8, 6]],

        [[0, 8],
         [1, 2]],

        [[5, 1],
         [1, 6]],

        [[7, 0],
         [2, 8]],

        [[8, 6],
         [6, 5]],

        [[9, 7],
         [0, 8]],

        [[4, 1],
         [1, 3]],

        [[0, 8],
         [

KeyboardInterrupt: 