In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!mkdir DATA
!unzip -qq {'/content/drive/MyDrive/DACON/open.zip'} -d /content/DATA

In [None]:
pip install timm

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.io import read_image

import timm
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

import pandas as pd
import numpy as np
from PIL import Image
from tqdm.auto import tqdm

import os
import time

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
def calc_puzzle(answer_df, submission_df):
    # Check for missing values in submission_df
    if submission_df.isnull().values.any():
        raise ValueError("The submission dataframe contains missing values.")

    # Public or Private answer Sample and Sorting by 'ID'
    submission_df = submission_df[submission_df.iloc[:, 0].isin(answer_df.iloc[:, 0])]
    submission_df = submission_df.sort_values(by='ID').reset_index(drop=True)

    # Check for length in submission_df
    if len(submission_df) != len(answer_df):
        raise ValueError("The submission dataframe wrong length.")

    # Convert position data to numpy arrays for efficient computation
    answer_positions = answer_df.iloc[:, 2:].to_numpy()  # Excluding ID, img_path, and type columns
    submission_positions = submission_df.iloc[:, 1:].to_numpy()  # Excluding ID column

    # Initialize the dictionary to hold accuracies
    accuracies = {}

    # Define combinations for 2x2 and 3x3 puzzles
    combinations_2x2 = [(i, j) for i in range(3) for j in range(3)]
    combinations_3x3 = [(i, j) for i in range(2) for j in range(2)]

    # 1x1 Puzzle Accuracy
    accuracies['1x1'] = np.mean(answer_positions == submission_positions)

    # Calculate accuracies for 2x2, 3x3, and 4x4 puzzles
    for size in range(2, 5):  # Loop through sizes 2, 3, 4
        correct_count = 0  # Initialize counter for correct full sub-puzzles
        total_subpuzzles = 0

        # Iterate through each sample's puzzle
        for i in range(len(answer_df)):
            puzzle_a = answer_positions[i].reshape(4, 4)
            puzzle_s = submission_positions[i].reshape(4, 4)
            combinations = combinations_2x2 if size == 2 else combinations_3x3 if size == 3 else [(0, 0)]

            # Calculate the number of correct sub-puzzles for this size within a 4x4
            for start_row, start_col in combinations:
                rows = slice(start_row, start_row + size)
                cols = slice(start_col, start_col + size)
                if np.array_equal(puzzle_a[rows, cols], puzzle_s[rows, cols]):
                    correct_count += 1
                total_subpuzzles += 1

        accuracies[f'{size}x{size}'] = correct_count / total_subpuzzles

    score = (accuracies['1x1'] + accuracies['2x2'] + accuracies['3x3'] + accuracies['4x4']) / 4.
    return score

In [3]:
class Model(nn.Module):
    def __init__(self, mask_ratio = 0.0, pretrained = True):
        super().__init__()

        self.mask_ratio = mask_ratio
        self.pretrained = pretrained

        deit3 = timm.create_model('deit3_base_patch16_384', pretrained = pretrained)

        self.patch_embed = deit3.patch_embed
        self.cls_token = deit3.cls_token
        self.blocks = deit3.blocks
        self.norm = deit3.norm

        self.jigsaw = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Linear(768, 24*24)
        )

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        # target = einops.repeat(self.target, 'L -> N L', N=N)
        # target = target.to(x.device)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep] # N, len_keep
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        target_masked = ids_keep

        return x_masked, target_masked

    def forward(self, x):
        x = self.patch_embed(x)
        x, target = self.random_masking(x, self.mask_ratio)

        # append cls token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        x = self.blocks(x)
        x = self.norm(x)
        x = self.jigsaw(x[:, 1:])
        return x.reshape(-1, 24*24), target.reshape(-1)

    def forward_test(self, x):
        x = self.patch_embed(x)

        # append cls token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        x = self.blocks(x)
        x = self.norm(x)
        x = self.jigsaw(x[:, 1:])
        return x

In [4]:
class JigsawDataset(Dataset):
    def __init__(self, df, data_path, mode='train', transform=None):
        self.df = df
        self.data_path = data_path
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if self.mode == 'train':
            row = self.df.iloc[idx]
            image = read_image(os.path.join(self.data_path, row['img_path']))
            shuffle_order = row[[str(i) for i in range(1, 17)]].values-1
            image = self.reset_image(image, shuffle_order)
            image = Image.fromarray(image)
            if self.transform:
                image = self.transform(image)
            return image
        elif self.mode == 'test':
            row = self.df.iloc[idx]
            image = Image.open(os.path.join(self.data_path, row['img_path']))
            if self.transform:
                image = self.transform(image)
            return image

    def reset_image(self, image, shuffle_order):
        c, h, w = image.shape
        block_h, block_w = h//4, w//4
        image_src = [[0 for _ in range(4)] for _ in range(4)]
        for idx, order in enumerate(shuffle_order):
            h_idx, w_idx = divmod(order,4)
            h_idx_shuffle, w_idx_shuffle = divmod(idx, 4)
            image_src[h_idx][w_idx] = image[:, block_h * h_idx_shuffle : block_h * (h_idx_shuffle+1), block_w * w_idx_shuffle : block_w * (w_idx_shuffle+1)]
        image_src = np.concatenate([np.concatenate(image_row, -1) for image_row in image_src], -2)
        return image_src.transpose(1, 2, 0)

In [5]:
def build_transform(is_train):
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size = (384, 384),
            is_training = True,
            color_jitter = 0.3,
            auto_augment = 'rand-m9-mstd0.5-inc1',
            interpolation= 'bicubic',
            re_prob= 0.25,
            re_mode= 'pixel',
            re_count= 1,
        )
        return transform

    t = []
    t.append(transforms.Resize((384,384), interpolation=3))
    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)

In [6]:
df = pd.read_csv('./DATA/train.csv')
train_df = df.iloc[:-6000]
valid_df = df.iloc[-6000:]

train_transform = build_transform(is_train = True)
valid_transform = build_transform(is_train = False)

train_dataset = JigsawDataset(df = train_df,
                              data_path = './DATA',
                              mode = 'train',
                              transform = train_transform)
valid_dataset = JigsawDataset(df = valid_df,
                              data_path = './DATA',
                              mode = 'test',
                              transform = valid_transform)

train_dataloader = DataLoader(
    train_dataset,
    batch_size = 64,
    shuffle = True
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size = 64,
    shuffle = False
)

In [7]:
model = Model(mask_ratio = 0.5)
model.to(device)
optimizer = optim.AdamW(model.parameters(),
                        lr=3e-5,
                        weight_decay = 0.05)

In [None]:
device = 'cuda'
for epoch in range(1, 11):
    print('Epoch ', epoch)
    st = time.time()
    model.train()
    for i, x in enumerate(train_dataloader):
        x = x.to(device)

        optimizer.zero_grad()

        preds, targets = model(x)

        loss = F.cross_entropy(preds, targets)

        loss.backward()
        optimizer.step()

        if i % 1 == 0:
            print(f'[{i} / {len(train_dataloader)}] loss:', loss.item())
    et = time.time()
    print('Time elapsed: ', et-st)

Epoch  1
[0 / 1000] loss: 6.369416236877441
[1 / 1000] loss: 6.365030288696289
[2 / 1000] loss: 6.363351821899414
[3 / 1000] loss: 6.361357688903809
[4 / 1000] loss: 6.35945463180542
[5 / 1000] loss: 6.359017848968506
[6 / 1000] loss: 6.357776641845703
[7 / 1000] loss: 6.358035087585449
[8 / 1000] loss: 6.357987403869629
[9 / 1000] loss: 6.357332706451416
[10 / 1000] loss: 6.357483863830566
[11 / 1000] loss: 6.356811046600342
[12 / 1000] loss: 6.356691360473633
[13 / 1000] loss: 6.3570098876953125
[14 / 1000] loss: 6.356698989868164
[15 / 1000] loss: 6.356221675872803
[16 / 1000] loss: 6.356624603271484
[17 / 1000] loss: 6.356431484222412
[18 / 1000] loss: 6.356326580047607
[19 / 1000] loss: 6.356296539306641
[20 / 1000] loss: 6.355912685394287
[21 / 1000] loss: 6.356122016906738
[22 / 1000] loss: 6.355814456939697
[23 / 1000] loss: 6.35596227645874
[24 / 1000] loss: 6.355783939361572
[25 / 1000] loss: 6.355687141418457
[26 / 1000] loss: 6.3552021980285645
[27 / 1000] loss: 6.355058670

[224 / 1000] loss: 6.293306350708008
[225 / 1000] loss: 6.267543792724609
[226 / 1000] loss: 6.254526615142822
[227 / 1000] loss: 6.291685581207275
[228 / 1000] loss: 6.282882213592529
[229 / 1000] loss: 6.262851238250732
[230 / 1000] loss: 6.258598327636719
[231 / 1000] loss: 6.260290145874023
[232 / 1000] loss: 6.2812113761901855
[233 / 1000] loss: 6.281691551208496
[234 / 1000] loss: 6.286880970001221
[235 / 1000] loss: 6.2635498046875
[236 / 1000] loss: 6.265571117401123
[237 / 1000] loss: 6.263326168060303
[238 / 1000] loss: 6.27601957321167
[239 / 1000] loss: 6.279647350311279
[240 / 1000] loss: 6.279733657836914
[241 / 1000] loss: 6.279229164123535
[242 / 1000] loss: 6.272929668426514
[243 / 1000] loss: 6.2557220458984375
[244 / 1000] loss: 6.269204139709473
[245 / 1000] loss: 6.271265029907227
[246 / 1000] loss: 6.276095867156982
[247 / 1000] loss: 6.249558925628662
[248 / 1000] loss: 6.277849197387695
[249 / 1000] loss: 6.262361526489258
[250 / 1000] loss: 6.2845659255981445
[

[446 / 1000] loss: 6.125204086303711
[447 / 1000] loss: 6.123839378356934
[448 / 1000] loss: 6.1168365478515625
[449 / 1000] loss: 6.133732318878174
[450 / 1000] loss: 6.171934127807617
[451 / 1000] loss: 6.138761520385742
[452 / 1000] loss: 6.151305675506592
[453 / 1000] loss: 6.13554048538208
[454 / 1000] loss: 6.127178192138672
[455 / 1000] loss: 6.118402481079102
[456 / 1000] loss: 6.132478713989258
[457 / 1000] loss: 6.123964309692383
[458 / 1000] loss: 6.194113731384277
[459 / 1000] loss: 6.169597148895264
[460 / 1000] loss: 6.146810054779053
[461 / 1000] loss: 6.162473201751709
[462 / 1000] loss: 6.09674072265625
[463 / 1000] loss: 6.137644290924072
[464 / 1000] loss: 6.082513809204102
[465 / 1000] loss: 6.1098313331604
[466 / 1000] loss: 6.080120086669922
[467 / 1000] loss: 6.1160759925842285
[468 / 1000] loss: 6.121611595153809
[469 / 1000] loss: 6.126533508300781
[470 / 1000] loss: 6.154226303100586
[471 / 1000] loss: 6.132238864898682
[472 / 1000] loss: 6.150615215301514
[47

[668 / 1000] loss: 6.023766994476318
[669 / 1000] loss: 6.0206451416015625
[670 / 1000] loss: 6.063134670257568
[671 / 1000] loss: 6.003284454345703
[672 / 1000] loss: 6.042110443115234
[673 / 1000] loss: 6.087695121765137
[674 / 1000] loss: 6.012162685394287
[675 / 1000] loss: 6.03833532333374
[676 / 1000] loss: 6.078914165496826
[677 / 1000] loss: 5.99901819229126
[678 / 1000] loss: 6.0675811767578125
[679 / 1000] loss: 5.9497246742248535
[680 / 1000] loss: 6.018007278442383
[681 / 1000] loss: 5.955126762390137
[682 / 1000] loss: 5.974175453186035
[683 / 1000] loss: 5.96959114074707
[684 / 1000] loss: 5.929963111877441
[685 / 1000] loss: 6.079584121704102
[686 / 1000] loss: 6.0216965675354
[687 / 1000] loss: 5.979204177856445
[688 / 1000] loss: 5.9456353187561035
[689 / 1000] loss: 5.966953277587891
[690 / 1000] loss: 6.010131359100342
[691 / 1000] loss: 5.966424942016602
[692 / 1000] loss: 5.979798793792725
[693 / 1000] loss: 5.962350845336914
[694 / 1000] loss: 5.962478160858154
[6

[890 / 1000] loss: 5.94336462020874
[891 / 1000] loss: 5.903137683868408
[892 / 1000] loss: 5.925815582275391
[893 / 1000] loss: 5.858716011047363
[894 / 1000] loss: 5.844268798828125
[895 / 1000] loss: 5.922425746917725
[896 / 1000] loss: 5.931849002838135
[897 / 1000] loss: 5.886892318725586
[898 / 1000] loss: 5.943939685821533
[899 / 1000] loss: 5.896613121032715
[900 / 1000] loss: 5.923663139343262
[901 / 1000] loss: 5.936526298522949
[902 / 1000] loss: 5.905036926269531
[903 / 1000] loss: 5.864856243133545
[904 / 1000] loss: 5.964412212371826
[905 / 1000] loss: 5.912222385406494
[906 / 1000] loss: 5.834840774536133
[907 / 1000] loss: 5.94644832611084
[908 / 1000] loss: 5.849674224853516
[909 / 1000] loss: 5.963377475738525
[910 / 1000] loss: 5.926199436187744
[911 / 1000] loss: 5.912779331207275
[912 / 1000] loss: 5.802475929260254
[913 / 1000] loss: 5.949518203735352
[914 / 1000] loss: 5.826175689697266
[915 / 1000] loss: 5.848706245422363
[916 / 1000] loss: 5.843328952789307
[91

[114 / 1000] loss: 5.885475158691406
[115 / 1000] loss: 5.830850601196289
[116 / 1000] loss: 5.932831287384033
[117 / 1000] loss: 5.877423286437988
[118 / 1000] loss: 5.852375030517578
[119 / 1000] loss: 5.7755937576293945
[120 / 1000] loss: 5.874673366546631
[121 / 1000] loss: 5.836808681488037
[122 / 1000] loss: 5.764790058135986
[123 / 1000] loss: 5.812707901000977
[124 / 1000] loss: 5.867384910583496
[125 / 1000] loss: 5.840973854064941
[126 / 1000] loss: 5.820528507232666
[127 / 1000] loss: 5.803206443786621
[128 / 1000] loss: 5.795552730560303
[129 / 1000] loss: 5.796236038208008
[130 / 1000] loss: 5.8794145584106445
[131 / 1000] loss: 5.724178791046143
[132 / 1000] loss: 5.810998916625977
[133 / 1000] loss: 5.853448867797852
[134 / 1000] loss: 5.874540328979492
[135 / 1000] loss: 5.870401859283447
[136 / 1000] loss: 5.795689582824707
[137 / 1000] loss: 5.8290252685546875
[138 / 1000] loss: 5.757518768310547
[139 / 1000] loss: 5.8284220695495605
[140 / 1000] loss: 5.8979644775390

[336 / 1000] loss: 5.781209945678711
[337 / 1000] loss: 5.72142219543457
[338 / 1000] loss: 5.825334548950195
[339 / 1000] loss: 5.7291460037231445
[340 / 1000] loss: 5.784964084625244
[341 / 1000] loss: 5.74423360824585
[342 / 1000] loss: 5.874297142028809
[343 / 1000] loss: 5.818099498748779
[344 / 1000] loss: 5.786969184875488
[345 / 1000] loss: 5.834120273590088
[346 / 1000] loss: 5.742020606994629
[347 / 1000] loss: 5.7715277671813965
[348 / 1000] loss: 5.720001220703125
[349 / 1000] loss: 5.881773471832275
[350 / 1000] loss: 5.83884334564209
[351 / 1000] loss: 5.713202953338623
[352 / 1000] loss: 5.829568862915039
[353 / 1000] loss: 5.811051368713379
[354 / 1000] loss: 5.806196212768555
[355 / 1000] loss: 5.773443698883057
[356 / 1000] loss: 5.825904369354248
[357 / 1000] loss: 5.797574996948242
[358 / 1000] loss: 5.731424808502197
[359 / 1000] loss: 5.645011901855469
[360 / 1000] loss: 5.820802688598633
[361 / 1000] loss: 5.766984939575195
[362 / 1000] loss: 5.804952621459961
[3

[558 / 1000] loss: 5.769659519195557
[559 / 1000] loss: 5.785407066345215
[560 / 1000] loss: 5.7672529220581055
[561 / 1000] loss: 5.762307167053223
[562 / 1000] loss: 5.789980888366699
[563 / 1000] loss: 5.688490390777588
[564 / 1000] loss: 5.712503433227539
[565 / 1000] loss: 5.787868022918701
[566 / 1000] loss: 5.744716167449951
[567 / 1000] loss: 5.7843017578125
[568 / 1000] loss: 5.740514755249023
[569 / 1000] loss: 5.763274192810059
[570 / 1000] loss: 5.661851406097412
[571 / 1000] loss: 5.743999481201172
[572 / 1000] loss: 5.746822834014893
[573 / 1000] loss: 5.719452857971191
[574 / 1000] loss: 5.694503307342529
[575 / 1000] loss: 5.749443054199219
[576 / 1000] loss: 5.758082389831543
[577 / 1000] loss: 5.658321857452393
[578 / 1000] loss: 5.755247116088867
[579 / 1000] loss: 5.648279666900635
[580 / 1000] loss: 5.622764587402344
[581 / 1000] loss: 5.755499839782715
[582 / 1000] loss: 5.674375057220459
[583 / 1000] loss: 5.702480792999268
[584 / 1000] loss: 5.812337398529053
[5

[780 / 1000] loss: 5.6429338455200195
[781 / 1000] loss: 5.6982197761535645
[782 / 1000] loss: 5.627161502838135
[783 / 1000] loss: 5.717919826507568
[784 / 1000] loss: 5.627058506011963
[785 / 1000] loss: 5.675539970397949
[786 / 1000] loss: 5.7342729568481445
[787 / 1000] loss: 5.642313480377197
[788 / 1000] loss: 5.665995121002197
[789 / 1000] loss: 5.750470161437988
[790 / 1000] loss: 5.686150074005127
[791 / 1000] loss: 5.819521903991699
[792 / 1000] loss: 5.75718879699707
[793 / 1000] loss: 5.689653396606445
[794 / 1000] loss: 5.800256252288818
[795 / 1000] loss: 5.724425792694092
[796 / 1000] loss: 5.722959041595459
[797 / 1000] loss: 5.747457027435303
[798 / 1000] loss: 5.651965141296387
[799 / 1000] loss: 5.760153293609619
[800 / 1000] loss: 5.684304237365723
[801 / 1000] loss: 5.636205196380615
[802 / 1000] loss: 5.701035499572754
[803 / 1000] loss: 5.76417350769043
[804 / 1000] loss: 5.690104007720947
[805 / 1000] loss: 5.68082857131958
[806 / 1000] loss: 5.680644512176514
[

[1 / 1000] loss: 5.586784839630127
[2 / 1000] loss: 5.721348285675049
[3 / 1000] loss: 5.662965774536133
[4 / 1000] loss: 5.734747886657715
[5 / 1000] loss: 5.785597324371338
[6 / 1000] loss: 5.617464065551758
[7 / 1000] loss: 5.631441593170166
[8 / 1000] loss: 5.6274542808532715
[9 / 1000] loss: 5.598264217376709
[10 / 1000] loss: 5.625783920288086
[11 / 1000] loss: 5.694779396057129
[12 / 1000] loss: 5.758180618286133
[13 / 1000] loss: 5.64273738861084
[14 / 1000] loss: 5.728546619415283
[15 / 1000] loss: 5.728189468383789
[16 / 1000] loss: 5.643178939819336
[17 / 1000] loss: 5.708621025085449
[18 / 1000] loss: 5.677109241485596
[19 / 1000] loss: 5.712515830993652
[20 / 1000] loss: 5.701814651489258
[21 / 1000] loss: 5.7295966148376465
[22 / 1000] loss: 5.741855144500732
[23 / 1000] loss: 5.691953182220459
[24 / 1000] loss: 5.603532791137695
[25 / 1000] loss: 5.653430938720703
[26 / 1000] loss: 5.645552635192871
[27 / 1000] loss: 5.67903995513916
[28 / 1000] loss: 5.685791969299316
[

[226 / 1000] loss: 5.669903755187988
[227 / 1000] loss: 5.592763900756836
[228 / 1000] loss: 5.588376998901367
[229 / 1000] loss: 5.703458309173584
[230 / 1000] loss: 5.556839466094971
[231 / 1000] loss: 5.65504264831543
[232 / 1000] loss: 5.650009632110596
[233 / 1000] loss: 5.584856033325195
[234 / 1000] loss: 5.755302906036377
[235 / 1000] loss: 5.665627479553223
[236 / 1000] loss: 5.686534881591797
[237 / 1000] loss: 5.626214504241943
[238 / 1000] loss: 5.611289024353027
[239 / 1000] loss: 5.52072811126709
[240 / 1000] loss: 5.720132827758789
[241 / 1000] loss: 5.665929794311523
[242 / 1000] loss: 5.638678550720215
[243 / 1000] loss: 5.5767621994018555
[244 / 1000] loss: 5.6783881187438965
[245 / 1000] loss: 5.674442768096924
[246 / 1000] loss: 5.674225330352783
[247 / 1000] loss: 5.61853551864624
[248 / 1000] loss: 5.65753173828125
[249 / 1000] loss: 5.622576713562012
[250 / 1000] loss: 5.585100173950195
[251 / 1000] loss: 5.660658359527588
[252 / 1000] loss: 5.638288497924805
[25

[448 / 1000] loss: 5.608668804168701
[449 / 1000] loss: 5.596620559692383
[450 / 1000] loss: 5.604048728942871
[451 / 1000] loss: 5.525672435760498
[452 / 1000] loss: 5.558074474334717
[453 / 1000] loss: 5.5220489501953125
[454 / 1000] loss: 5.585840702056885
[455 / 1000] loss: 5.568005561828613
[456 / 1000] loss: 5.631828308105469
[457 / 1000] loss: 5.573237895965576
[458 / 1000] loss: 5.4645795822143555
[459 / 1000] loss: 5.571810722351074
[460 / 1000] loss: 5.497619152069092
[461 / 1000] loss: 5.62588357925415
[462 / 1000] loss: 5.63482141494751
[463 / 1000] loss: 5.5861101150512695
[464 / 1000] loss: 5.619303226470947
[465 / 1000] loss: 5.5561418533325195
[466 / 1000] loss: 5.525867462158203
[467 / 1000] loss: 5.54698371887207
[468 / 1000] loss: 5.576183795928955
[469 / 1000] loss: 5.617574691772461
[470 / 1000] loss: 5.590603828430176
[471 / 1000] loss: 5.612412929534912
[472 / 1000] loss: 5.681035995483398
[473 / 1000] loss: 5.687472343444824
[474 / 1000] loss: 5.556065559387207


[670 / 1000] loss: 5.533397674560547
[671 / 1000] loss: 5.71996545791626
[672 / 1000] loss: 5.500356674194336
[673 / 1000] loss: 5.524527549743652
[674 / 1000] loss: 5.512121200561523
[675 / 1000] loss: 5.5793328285217285
[676 / 1000] loss: 5.6270623207092285
[677 / 1000] loss: 5.560177803039551
[678 / 1000] loss: 5.5501933097839355
[679 / 1000] loss: 5.621738910675049
[680 / 1000] loss: 5.587961196899414
[681 / 1000] loss: 5.5715532302856445
[682 / 1000] loss: 5.585237503051758
[683 / 1000] loss: 5.553598880767822
[684 / 1000] loss: 5.532594680786133
[685 / 1000] loss: 5.532007217407227
[686 / 1000] loss: 5.547874450683594
[687 / 1000] loss: 5.614226818084717
[688 / 1000] loss: 5.586885929107666
[689 / 1000] loss: 5.499849796295166
[690 / 1000] loss: 5.6296868324279785
[691 / 1000] loss: 5.5967912673950195
[692 / 1000] loss: 5.631076335906982
[693 / 1000] loss: 5.515897750854492
[694 / 1000] loss: 5.51191520690918
[695 / 1000] loss: 5.611217021942139
[696 / 1000] loss: 5.5370144844055

[892 / 1000] loss: 5.594653129577637
[893 / 1000] loss: 5.541507244110107
[894 / 1000] loss: 5.536957740783691
[895 / 1000] loss: 5.613935470581055
[896 / 1000] loss: 5.643822193145752
[897 / 1000] loss: 5.639052391052246
[898 / 1000] loss: 5.569070816040039
[899 / 1000] loss: 5.595461368560791
[900 / 1000] loss: 5.634298801422119
[901 / 1000] loss: 5.59263277053833
[902 / 1000] loss: 5.532313346862793
[903 / 1000] loss: 5.65046501159668
[904 / 1000] loss: 5.69536018371582
[905 / 1000] loss: 5.682943344116211
[906 / 1000] loss: 5.497639179229736
[907 / 1000] loss: 5.5748162269592285
[908 / 1000] loss: 5.442249298095703
[909 / 1000] loss: 5.496419906616211
[910 / 1000] loss: 5.456437110900879
[911 / 1000] loss: 5.491948127746582
[912 / 1000] loss: 5.437686920166016
[913 / 1000] loss: 5.49531364440918
[914 / 1000] loss: 5.5230255126953125
[915 / 1000] loss: 5.55265998840332
[916 / 1000] loss: 5.593854904174805
[917 / 1000] loss: 5.614222049713135
[918 / 1000] loss: 5.498509883880615
[919

[116 / 1000] loss: 5.586987018585205
[117 / 1000] loss: 5.587159633636475
[118 / 1000] loss: 5.486969470977783
[119 / 1000] loss: 5.527318954467773
[120 / 1000] loss: 5.483011722564697
[121 / 1000] loss: 5.45863676071167
[122 / 1000] loss: 5.6077375411987305
[123 / 1000] loss: 5.461974620819092
[124 / 1000] loss: 5.600221633911133
[125 / 1000] loss: 5.568180561065674
[126 / 1000] loss: 5.510513782501221
[127 / 1000] loss: 5.537588119506836
[128 / 1000] loss: 5.671854019165039
[129 / 1000] loss: 5.646459102630615
[130 / 1000] loss: 5.593327045440674
[131 / 1000] loss: 5.551420211791992
[132 / 1000] loss: 5.570422172546387
[133 / 1000] loss: 5.516937732696533
[134 / 1000] loss: 5.567124366760254
[135 / 1000] loss: 5.52203893661499
[136 / 1000] loss: 5.584905624389648
[137 / 1000] loss: 5.46038818359375
[138 / 1000] loss: 5.525182723999023
[139 / 1000] loss: 5.498125076293945
[140 / 1000] loss: 5.535850524902344
[141 / 1000] loss: 5.414788246154785
[142 / 1000] loss: 5.5146355628967285
[1

[338 / 1000] loss: 5.512601852416992
[339 / 1000] loss: 5.592321395874023
[340 / 1000] loss: 5.611116886138916
[341 / 1000] loss: 5.451309680938721
[342 / 1000] loss: 5.5637431144714355
[343 / 1000] loss: 5.539496898651123
[344 / 1000] loss: 5.5405120849609375
[345 / 1000] loss: 5.6571855545043945
[346 / 1000] loss: 5.636633396148682
[347 / 1000] loss: 5.503559112548828
[348 / 1000] loss: 5.5772294998168945
[349 / 1000] loss: 5.56986665725708
[350 / 1000] loss: 5.5357985496521
[351 / 1000] loss: 5.469385147094727
[352 / 1000] loss: 5.517498016357422
[353 / 1000] loss: 5.501655101776123
[354 / 1000] loss: 5.361494064331055
[355 / 1000] loss: 5.495475769042969
[356 / 1000] loss: 5.5176801681518555
[357 / 1000] loss: 5.61914587020874
[358 / 1000] loss: 5.562780380249023
[359 / 1000] loss: 5.540101528167725
[360 / 1000] loss: 5.6190009117126465
[361 / 1000] loss: 5.535973072052002
[362 / 1000] loss: 5.467594623565674
[363 / 1000] loss: 5.436497211456299
[364 / 1000] loss: 5.485365867614746

In [None]:
outs = []
model.eval()
with torch.no_grad():
    for x in tqdm(valid_dataloader):
        x = x.to('cuda')
        out = model.forward_test(x)
        out = out.argmax(dim=2).cpu().numpy()
        outs.append(out)

outs = np.vstack(outs)
valid_pred_df = valid_df.copy().drop(columns=['img_path'])

for I, (idx, row) in enumerate(tqdm(valid_pred_df.iterrows(), total=len(valid_df))):
    w = outs[I].reshape(24,24)
    CNT_ROW = np.zeros((4,4,4), dtype=np.int32)
    CNT_COL = np.zeros((4,4,4), dtype=np.int32)
    for i in range(24):
        for j in range(24):
            ROW = i // 6
            COL = j // 6
            v = w[i][j]
            CNT_ROW[ROW][COL][v // 24 // 6] += 1
            CNT_COL[ROW][COL][v % 24 // 6] += 1
    ans = CNT_ROW.argmax(2) * 4 + CNT_COL.argmax(2) + 1
    ans = ans.reshape(16)
    ans = list(map(str, ans))
    valid_pred_df.loc[idx, '1':'16'] = ans
score = calc_puzzle(valid_df, valid_pred_df)
print(score)