In [None]:


import os
import math
import random
import argparse 
from pathlib import Path
from PIL import Image
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from sklearn.model_selection import train_test_split
from scipy.optimize import linear_sum_assignment


# Dataset (Updated)
class JigsawDataset(Dataset):
    """Dataset of 3x3 jigsaw puzzles described by a CSV index.

    The CSV must contain an ``image`` column (file name relative to
    ``images_dir``) and a ``label`` column (whitespace-separated integers).

    Each item is ``(tiles, label_tensor, image_name)``:

    * ``tiles`` - the nine transformed tiles stacked in row-major order
      (``(9, C, H, W)`` with the default transform),
    * ``label_tensor`` - the parsed label as a ``torch.long`` tensor,
    * ``image_name`` - the value of the ``image`` column.

    Args:
        csv_path: Path of the CSV index.
        images_dir: Directory that contains the image files.
        image_size: Side length the image is resized to before it is cut
            into a 3x3 grid; must be divisible by 3.
        transform: Optional callable applied to every tile (a PIL image).
            When ``None``, ``ToTensor`` plus ImageNet normalisation is used.

    Raises:
        ValueError: If ``image_size`` is not divisible by 3.
    """

    def __init__(self, csv_path, images_dir, image_size=201, transform=None):
        # A real exception (unlike ``assert``) is not stripped by ``python -O``,
        # and checking first avoids reading the CSV for an invalid config.
        if image_size % 3 != 0:
            raise ValueError("image_size must be divisible by 3")
        self.df = pd.read_csv(csv_path)
        self.images_dir = images_dir
        self.transform = transform
        self.image_size = image_size

    @staticmethod
    def _default_transform():
        """Build the fallback per-tile transform (tensor + ImageNet stats)."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        img_name = row['image']
        label = row['label']

        img_path = os.path.join(self.images_dir, img_name)
        # ``convert`` fully loads the pixels, so the file can be closed right away.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        img = img.resize((self.image_size, self.image_size), Image.BILINEAR)
        # Stay in uint8: this avoids a float round-trip with a truncating cast
        # and mirrors the tile extraction used by infer_single_image.
        pixels = np.array(img)
        cut = self.image_size // 3

        # Resolve the transform once per item instead of once per tile.
        if self.transform is not None:
            tile_transform = self.transform
        else:
            tile_transform = self._default_transform()

        tiles = []
        for i in range(3):
            for j in range(3):
                tile = pixels[i*cut:(i+1)*cut, j*cut:(j+1)*cut, :]
                tiles.append(tile_transform(Image.fromarray(tile)))
        tiles = torch.stack(tiles)  # nine tiles, row-major

        label_list = [int(x) for x in str(label).split()]
        label_tensor = torch.tensor(label_list, dtype=torch.long)
        return tiles, label_tensor, img_name
    
# Model: per-tile encoder -> transformer -> per-tile logits over positions

class TileTransformer(nn.Module):
    """Predict a grid position for every tile of a shuffled puzzle.

    A ResNet-18 backbone (final ``fc`` replaced by identity) embeds each tile
    independently, a Transformer encoder lets the tile embeddings attend to
    one another, and a linear head produces 9 position logits per tile.

    Args:
        tile_encoder: Accepted for interface compatibility; currently unused,
            the backbone is always ResNet-18.
        embed_dim: Width of the transformer. A linear projection is inserted
            only when it differs from the backbone's 512-d output.
        nhead: Number of attention heads per encoder layer.
        nlayers: Number of stacked encoder layers.
        dropout: Dropout rate inside the encoder layers.
    """

    def __init__(self, tile_encoder='resnet18', embed_dim=512, nhead=8, nlayers=3, dropout=0.1):
        super().__init__()
        res = models.resnet18(pretrained=True)
        res.fc = nn.Identity()
        self.backbone = res  # 512-d features for resnet18
        self.embed_dim = 512  # backbone output width, not the requested embed_dim
        if self.embed_dim != embed_dim:
            self.proj = nn.Linear(self.embed_dim, embed_dim)
            final_dim = embed_dim
        else:
            self.proj = None
            final_dim = self.embed_dim
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=final_dim,
            nhead=nhead,
            dropout=dropout,
            dim_feedforward=final_dim*4,
            activation='gelu',
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=nlayers)
        self.head = nn.Linear(final_dim, 9)

    def forward(self, tiles):
        """Return per-tile position logits.

        Args:
            tiles: Tensor of shape ``(B, n, C, H, W)``.

        Returns:
            Tensor of shape ``(B, n, 9)``.
        """
        batch, n_tiles = tiles.shape[0], tiles.shape[1]
        # ``reshape`` (unlike ``view``) also accepts inputs whose batch/tile
        # dimensions are not laid out contiguously, e.g. after a transpose.
        tiles_flat = tiles.reshape(batch * n_tiles, *tiles.shape[2:])
        # Run the backbone on every tile independently.
        feats = self.backbone(tiles_flat)
        if self.proj is not None:
            feats = self.proj(feats)
        feats = feats.reshape(batch, n_tiles, -1)  # (B, n, D)
        # The encoder is built without batch_first, so it expects (seq, batch, dim).
        encoded = self.transformer(feats.permute(1, 0, 2))
        encoded = encoded.permute(1, 0, 2)  # back to (B, n, D)
        return self.head(encoded)

def hungarian_from_logits(logits):
    """Turn a tile-by-position score matrix into a one-to-one assignment.

    ``logits`` may be a torch tensor or anything ``np.array`` accepts. The
    Hungarian algorithm is run on the negated scores, so the returned
    assignment maximises the sum of the selected logits.

    Returns a list ``p`` with one entry per row where ``p[tile]`` is the
    position assigned to that tile; rows left unassigned keep ``-1``.
    """
    scores = logits.detach().cpu().numpy() if isinstance(logits, torch.Tensor) else np.array(logits)
    # linear_sum_assignment minimises, hence the sign flip.
    tiles, positions = linear_sum_assignment(-scores)
    assignment = [-1] * scores.shape[0]
    for tile, position in zip(tiles.tolist(), positions.tolist()):
        assignment[tile] = position
    return assignment

def compute_pra(pred_perm, true_perm):
    """Exact-match score for a single puzzle.

    Both arguments are lists where ``perm[i]`` is the position of tile ``i``.
    Returns 1 when the two lists compare equal, otherwise 0.
    """
    is_exact_match = pred_perm == true_perm
    return int(is_exact_match)

def compute_paa(pred_perm, true_perm):
    """Pairwise Adjacency Accuracy (4-neighbour) on a 3x3 grid.

    For every pair of horizontally or vertically adjacent grid positions, the
    two tiles that occupy them in ``true_perm`` are looked up and the pair
    counts as correct when their predicted positions are also at Manhattan
    distance 1. The check is orientation-agnostic: any adjacent placement
    counts, not only the same relative direction.

    Args:
        pred_perm: List of length 9, ``pred_perm[i]`` = predicted position of tile ``i``.
        true_perm: List of length 9, ``true_perm[i]`` = true position of tile ``i``;
            must cover every position 0..8.

    Returns:
        Fraction of the 12 adjacent pairs that are preserved, in ``[0, 1]``.
    """
    # Inverse of the ground truth: grid position -> tile index.
    tile_at = {pos: tile for tile, pos in enumerate(true_perm)}

    # 6 horizontal neighbours (not in the last column) + 6 vertical (not in the last row).
    neighbor_pairs = [(pos, pos + 1) for pos in range(9) if pos % 3 < 2]
    neighbor_pairs += [(pos, pos + 3) for pos in range(6)]

    correct = 0
    for a, b in neighbor_pairs:
        ra, ca = divmod(pred_perm[tile_at[a]], 3)
        rb, cb = divmod(pred_perm[tile_at[b]], 3)
        if abs(ra - rb) + abs(ca - cb) == 1:
            correct += 1
    return correct / len(neighbor_pairs)

# Training & eval loops
def train_one_epoch(model, dataloader, opt, device):
    """Run one optimisation pass over ``dataloader``.

    Every batch yields ``(tiles, labels, names)``. The model's per-tile logits
    are flattened to ``(B*9, 9)`` and scored with cross-entropy against the
    flattened labels. Returns the loss averaged over
    ``len(dataloader.dataset)`` samples.
    """
    model.train()
    running_loss = 0.0
    progress = tqdm(dataloader, desc="train")
    for batch_tiles, batch_labels, _names in progress:
        batch_tiles = batch_tiles.to(device)
        batch_labels = batch_labels.to(device)

        opt.zero_grad()
        scores = model(batch_tiles)
        batch = scores.shape[0]
        # Treat each tile as an independent 9-way classification example.
        loss = F.cross_entropy(scores.view(batch * 9, 9), batch_labels.view(batch * 9))
        loss.backward()
        opt.step()

        # Weight by batch size so the final value is a per-sample mean.
        running_loss += loss.item() * batch_tiles.shape[0]
    return running_loss / len(dataloader.dataset)

def evaluate(model, dataloader, device):
    """Compute mean PRA and PAA of ``model`` over ``dataloader``.

    For each sample the (9, 9) logits are converted to a permutation with
    ``hungarian_from_logits`` and compared to the label with ``compute_pra``
    and ``compute_paa``.

    Args:
        model: Network returning logits of shape ``(B, 9, 9)``.
        dataloader: Yields ``(tiles, labels, names)`` batches.
        device: Device the tiles are moved to before the forward pass.

    Returns:
        ``(pra, paa)`` averaged over all samples; ``(0.0, 0.0)`` when the
        loader yields no samples.
    """
    model.eval()
    total_pra = 0
    total_paa = 0.0
    n = 0
    with torch.no_grad():
        for tiles, labels, _ in tqdm(dataloader, desc="eval"):
            # One device-to-host copy per batch rather than one per sample.
            logits = model(tiles.to(device)).cpu()  # (B,9,9)
            # Labels are only consumed by the Python-side metrics, so they
            # never need to visit the device.
            targets = labels.tolist()
            for sample_logits, true in zip(logits, targets):
                perm = hungarian_from_logits(sample_logits)
                total_pra += compute_pra(perm, true)
                total_paa += compute_paa(perm, true)
                n += 1
    if n == 0:
        # Empty validation split: report zeros instead of dividing by zero.
        return 0.0, 0.0
    return total_pra / n, total_paa / n

# Inference helper for single image -> permutation string
def infer_single_image(model, image_path, image_size=201, device='cpu'):
    """Predict the tile permutation for one puzzle image.

    The image is loaded as RGB, resized to ``image_size`` x ``image_size``,
    cut into a 3x3 grid (row-major), and every tile is converted to a tensor
    and normalised with ImageNet statistics before being fed to ``model``.

    Args:
        model: Network returning logits of shape ``(1, 9, 9)`` for a
            ``(1, 9, C, H, W)`` input; it is switched to eval mode.
        image_path: Path of the puzzle image.
        image_size: Side length used for resizing; tiles are
            ``image_size // 3`` pixels wide.
        device: Device the tile batch is moved to.

    Returns:
        List ``perm`` of length 9 where ``perm[i]`` is the position assigned
        to tile ``i`` by ``hungarian_from_logits``.
    """
    # ``convert`` fully loads the pixels, so the file can be closed right away.
    with Image.open(image_path) as raw:
        img = raw.convert('RGB')
    img = img.resize((image_size, image_size), Image.BILINEAR)
    img = np.array(img).astype(np.uint8)
    cut = image_size // 3
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    tiles = []
    for i in range(3):
        for j in range(3):
            tile = img[i*cut:(i+1)*cut, j*cut:(j+1)*cut, :]
            tiles.append(to_tensor(Image.fromarray(tile)))
    tiles = torch.stack(tiles).unsqueeze(0).to(device)  # (1,9,C,H,W)
    model.eval()
    with torch.no_grad():
        logits = model(tiles)[0]  # (9,9)
        perm = hungarian_from_logits(logits)
    return perm

# Main & arg parsing
def main(args):
    """Train the tile transformer, keep the best checkpoint, optionally infer.

    Reads these attributes from ``args``: ``csv``, ``images_dir``,
    ``tile_size``, ``val_frac``, ``batch_size``, ``lr``, ``epochs``,
    ``save_model``, ``test_image`` and ``force_cpu``.

    The dataset is split into train/validation subsets with a fixed seed.
    After every epoch the model is evaluated and its weights are written to
    ``args.save_model`` whenever validation PRA improves. When
    ``args.test_image`` is non-empty, the best checkpoint is reloaded and the
    predicted permutation for that image is printed.
    """
    use_cuda = torch.cuda.is_available() and not args.force_cpu
    device = torch.device('cuda' if use_cuda else 'cpu')
    print("Using device:", device)

    image_size = args.tile_size * 3

    # transform=None selects the dataset's built-in ToTensor + normalisation;
    # no per-tile augmentation is applied.
    ds = JigsawDataset(args.csv, args.images_dir, image_size=image_size, transform=None)

    idxs = list(range(len(ds)))
    train_idx, val_idx = train_test_split(idxs, test_size=args.val_frac, random_state=42)
    train_ds = torch.utils.data.Subset(ds, train_idx)
    val_ds = torch.utils.data.Subset(ds, val_idx)

    # The default collate already stacks the tile/label tensors and keeps the
    # image names as a list, and, unlike a function nested in main(), it can be
    # pickled by spawn-based worker processes.
    # Pinned memory only helps host-to-GPU copies, so skip it on CPU.
    loader_kwargs = dict(batch_size=args.batch_size, num_workers=2, pin_memory=use_cuda)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    model = TileTransformer(tile_encoder='resnet18', embed_dim=512, nhead=8, nlayers=3).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    # Halve the learning rate every 10 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)

    best_pra = -1.0
    for epoch in range(args.epochs):
        print(f"Epoch {epoch+1}/{args.epochs}")
        train_loss = train_one_epoch(model, train_loader, opt, device)
        pra, paa = evaluate(model, val_loader, device)
        print(f"Train loss: {train_loss:.4f}  Val PRA: {pra:.4f}  Val PAA: {paa:.4f}")
        if pra > best_pra:
            best_pra = pra
            torch.save(model.state_dict(), args.save_model)
            print("Saved best model:", args.save_model)
        scheduler.step()

    print(f"Training complete. Best model saved to {args.save_model} with PRA: {best_pra:.4f}")
    if args.test_image:
        print(f"Running inference on: {args.test_image}")
        # Reload the best checkpoint; the in-memory weights are from the last epoch.
        model.load_state_dict(torch.load(args.save_model, map_location=device))
        perm = infer_single_image(model, args.test_image, image_size=image_size, device=device)
        print("Predicted permutation string:", " ".join(str(x) for x in perm))

In [None]:

import argparse 

# Notebook-style configuration: build the namespace main() expects in one go.
args = argparse.Namespace(
    # Data and checkpoint locations.
    # NOTE(review): the recorded cell output shows images_dir='data/train/' - confirm the intended path.
    csv="train.csv",
    images_dir="train/",
    save_model="best_jigsaw.pth",
    test_image="",  # empty string skips the post-training inference step
    # Side length of one tile in pixels; the full image is 3x this value.
    tile_size=67,
    # Optimisation settings.
    batch_size=16,
    epochs=100,
    lr=1e-4,
    val_frac=0.1,
    # Set to True to ignore an available GPU.
    force_cpu=False,
)

print("Configuration loaded:")
print(vars(args))

Configuration loaded:
{'csv': 'train.csv', 'images_dir': 'data/train/', 'save_model': 'best_jigsaw.pth', 'test_image': '', 'tile_size': 67, 'batch_size': 16, 'epochs': 100, 'lr': 0.0001, 'val_frac': 0.1, 'force_cpu': False}


In [None]:
# Launch training and per-epoch validation with the `args` namespace configured in the previous cell.
main(args)

Using device: cuda




Epoch 1/100


train:   0%|          | 13/5240 [00:00<02:18, 37.66it/s]

train: 100%|██████████| 5240/5240 [01:52<00:00, 46.76it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 45.77it/s]


Train loss: 1.0811  Val PRA: 0.5077  Val PAA: 0.8042
Saved best model: best_jigsaw.pth
Epoch 2/100


train: 100%|██████████| 5240/5240 [01:55<00:00, 45.26it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 45.73it/s]


Train loss: 0.5463  Val PRA: 0.7588  Val PAA: 0.9110
Saved best model: best_jigsaw.pth
Epoch 3/100


train: 100%|██████████| 5240/5240 [01:54<00:00, 45.61it/s]
eval: 100%|██████████| 583/583 [00:13<00:00, 44.62it/s]


Train loss: 0.3111  Val PRA: 0.8308  Val PAA: 0.9385
Saved best model: best_jigsaw.pth
Epoch 4/100


train: 100%|██████████| 5240/5240 [01:56<00:00, 45.04it/s]
eval: 100%|██████████| 583/583 [00:13<00:00, 44.19it/s]


Train loss: 0.2126  Val PRA: 0.8696  Val PAA: 0.9543
Saved best model: best_jigsaw.pth
Epoch 5/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.22it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.81it/s]


Train loss: 0.1591  Val PRA: 0.8930  Val PAA: 0.9621
Saved best model: best_jigsaw.pth
Epoch 6/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.74it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.45it/s]


Train loss: 0.1289  Val PRA: 0.9067  Val PAA: 0.9681
Saved best model: best_jigsaw.pth
Epoch 7/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.27it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.70it/s]


Train loss: 0.1098  Val PRA: 0.9093  Val PAA: 0.9686
Saved best model: best_jigsaw.pth
Epoch 8/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.27it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.79it/s]


Train loss: 0.0949  Val PRA: 0.9167  Val PAA: 0.9713
Saved best model: best_jigsaw.pth
Epoch 9/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.44it/s]
eval: 100%|██████████| 583/583 [00:11<00:00, 49.03it/s]


Train loss: 0.0857  Val PRA: 0.9259  Val PAA: 0.9752
Saved best model: best_jigsaw.pth
Epoch 10/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 48.01it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.32it/s]


Train loss: 0.0750  Val PRA: 0.9258  Val PAA: 0.9740
Epoch 11/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.35it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.97it/s]


Train loss: 0.0372  Val PRA: 0.9509  Val PAA: 0.9839
Saved best model: best_jigsaw.pth
Epoch 12/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.76it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.44it/s]


Train loss: 0.0305  Val PRA: 0.9478  Val PAA: 0.9828
Epoch 13/100


train: 100%|██████████| 5240/5240 [01:47<00:00, 48.88it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.96it/s]


Train loss: 0.0285  Val PRA: 0.9494  Val PAA: 0.9829
Epoch 14/100


train: 100%|██████████| 5240/5240 [01:47<00:00, 48.79it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.47it/s]


Train loss: 0.0260  Val PRA: 0.9485  Val PAA: 0.9830
Epoch 15/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.20it/s]
eval: 100%|██████████| 583/583 [00:11<00:00, 48.77it/s]


Train loss: 0.0237  Val PRA: 0.9529  Val PAA: 0.9844
Saved best model: best_jigsaw.pth
Epoch 16/100


train: 100%|██████████| 5240/5240 [01:52<00:00, 46.75it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 45.48it/s]


Train loss: 0.0230  Val PRA: 0.9522  Val PAA: 0.9840
Epoch 17/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.32it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.25it/s]


Train loss: 0.0225  Val PRA: 0.9530  Val PAA: 0.9844
Saved best model: best_jigsaw.pth
Epoch 18/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.73it/s]
eval: 100%|██████████| 583/583 [00:11<00:00, 48.62it/s]


Train loss: 0.0210  Val PRA: 0.9510  Val PAA: 0.9836
Epoch 19/100


train: 100%|██████████| 5240/5240 [01:47<00:00, 48.61it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 46.99it/s]


Train loss: 0.0206  Val PRA: 0.9533  Val PAA: 0.9845
Saved best model: best_jigsaw.pth
Epoch 20/100


train: 100%|██████████| 5240/5240 [01:51<00:00, 47.20it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.32it/s]


Train loss: 0.0195  Val PRA: 0.9498  Val PAA: 0.9828
Epoch 21/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 48.06it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.21it/s]


Train loss: 0.0121  Val PRA: 0.9617  Val PAA: 0.9870
Saved best model: best_jigsaw.pth
Epoch 22/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.19it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.03it/s]


Train loss: 0.0100  Val PRA: 0.9554  Val PAA: 0.9853
Epoch 23/100


train: 100%|██████████| 5240/5240 [01:47<00:00, 48.62it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.83it/s]


Train loss: 0.0094  Val PRA: 0.9581  Val PAA: 0.9859
Epoch 24/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.37it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.30it/s]


Train loss: 0.0087  Val PRA: 0.9590  Val PAA: 0.9865
Epoch 25/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.48it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 46.86it/s]


Train loss: 0.0083  Val PRA: 0.9578  Val PAA: 0.9856
Epoch 26/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.35it/s]
eval: 100%|██████████| 583/583 [00:11<00:00, 49.00it/s]


Train loss: 0.0085  Val PRA: 0.9603  Val PAA: 0.9865
Epoch 27/100


train: 100%|██████████| 5240/5240 [01:46<00:00, 49.13it/s]
eval: 100%|██████████| 583/583 [00:11<00:00, 48.98it/s]


Train loss: 0.0077  Val PRA: 0.9621  Val PAA: 0.9873
Saved best model: best_jigsaw.pth
Epoch 28/100


train: 100%|██████████| 5240/5240 [01:51<00:00, 47.19it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.50it/s]


Train loss: 0.0074  Val PRA: 0.9608  Val PAA: 0.9868
Epoch 29/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.60it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.23it/s]


Train loss: 0.0073  Val PRA: 0.9574  Val PAA: 0.9857
Epoch 30/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.40it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.04it/s]


Train loss: 0.0074  Val PRA: 0.9589  Val PAA: 0.9862
Epoch 31/100


train: 100%|██████████| 5240/5240 [01:51<00:00, 46.88it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.55it/s]


Train loss: 0.0050  Val PRA: 0.9610  Val PAA: 0.9869
Epoch 32/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.17it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.32it/s]


Train loss: 0.0043  Val PRA: 0.9631  Val PAA: 0.9879
Saved best model: best_jigsaw.pth
Epoch 33/100


train: 100%|██████████| 5240/5240 [01:52<00:00, 46.66it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.50it/s]


Train loss: 0.0041  Val PRA: 0.9614  Val PAA: 0.9873
Epoch 34/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.27it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.13it/s]


Train loss: 0.0041  Val PRA: 0.9636  Val PAA: 0.9880
Saved best model: best_jigsaw.pth
Epoch 35/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.56it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 46.70it/s]


Train loss: 0.0036  Val PRA: 0.9594  Val PAA: 0.9867
Epoch 36/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.11it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 46.94it/s]


Train loss: 0.0034  Val PRA: 0.9626  Val PAA: 0.9874
Epoch 37/100


train: 100%|██████████| 5240/5240 [01:47<00:00, 48.82it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 46.68it/s]


Train loss: 0.0032  Val PRA: 0.9639  Val PAA: 0.9880
Saved best model: best_jigsaw.pth
Epoch 38/100


train: 100%|██████████| 5240/5240 [01:51<00:00, 47.12it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 46.79it/s]


Train loss: 0.0031  Val PRA: 0.9641  Val PAA: 0.9879
Saved best model: best_jigsaw.pth
Epoch 39/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.34it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.09it/s]


Train loss: 0.0030  Val PRA: 0.9608  Val PAA: 0.9873
Epoch 40/100


train: 100%|██████████| 5240/5240 [01:51<00:00, 46.81it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 46.67it/s]


Train loss: 0.0032  Val PRA: 0.9641  Val PAA: 0.9881
Epoch 41/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.86it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.89it/s]


Train loss: 0.0024  Val PRA: 0.9653  Val PAA: 0.9885
Saved best model: best_jigsaw.pth
Epoch 42/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.67it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.34it/s]


Train loss: 0.0020  Val PRA: 0.9653  Val PAA: 0.9885
Epoch 43/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 48.02it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.22it/s]


Train loss: 0.0020  Val PRA: 0.9641  Val PAA: 0.9880
Epoch 44/100


train: 100%|██████████| 5240/5240 [01:47<00:00, 48.72it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.94it/s]


Train loss: 0.0019  Val PRA: 0.9638  Val PAA: 0.9881
Epoch 45/100


train: 100%|██████████| 5240/5240 [01:47<00:00, 48.97it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.93it/s]


Train loss: 0.0018  Val PRA: 0.9647  Val PAA: 0.9881
Epoch 46/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.99it/s]
eval: 100%|██████████| 583/583 [00:11<00:00, 48.63it/s]


Train loss: 0.0017  Val PRA: 0.9650  Val PAA: 0.9882
Epoch 47/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.47it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.09it/s]


Train loss: 0.0016  Val PRA: 0.9650  Val PAA: 0.9884
Epoch 48/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.89it/s]
eval: 100%|██████████| 583/583 [00:11<00:00, 49.09it/s]


Train loss: 0.0015  Val PRA: 0.9627  Val PAA: 0.9875
Epoch 49/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.71it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 46.49it/s]


Train loss: 0.0015  Val PRA: 0.9627  Val PAA: 0.9879
Epoch 50/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.37it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.34it/s]


Train loss: 0.0015  Val PRA: 0.9641  Val PAA: 0.9882
Epoch 51/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.60it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.11it/s]


Train loss: 0.0013  Val PRA: 0.9644  Val PAA: 0.9880
Epoch 52/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.36it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.24it/s]


Train loss: 0.0012  Val PRA: 0.9650  Val PAA: 0.9885
Epoch 53/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.69it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.30it/s]


Train loss: 0.0010  Val PRA: 0.9652  Val PAA: 0.9885
Epoch 54/100


train: 100%|██████████| 5240/5240 [01:51<00:00, 47.08it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.17it/s]


Train loss: 0.0010  Val PRA: 0.9662  Val PAA: 0.9887
Saved best model: best_jigsaw.pth
Epoch 55/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.09it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.39it/s]


Train loss: 0.0010  Val PRA: 0.9666  Val PAA: 0.9888
Saved best model: best_jigsaw.pth
Epoch 56/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.20it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.26it/s]


Train loss: 0.0008  Val PRA: 0.9649  Val PAA: 0.9883
Epoch 57/100


train: 100%|██████████| 5240/5240 [01:47<00:00, 48.74it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.02it/s]


Train loss: 0.0009  Val PRA: 0.9649  Val PAA: 0.9884
Epoch 58/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.65it/s]
eval: 100%|██████████| 583/583 [00:11<00:00, 48.60it/s]


Train loss: 0.0009  Val PRA: 0.9659  Val PAA: 0.9885
Epoch 59/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.18it/s]
eval: 100%|██████████| 583/583 [00:11<00:00, 48.70it/s]


Train loss: 0.0008  Val PRA: 0.9643  Val PAA: 0.9882
Epoch 60/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.25it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.17it/s]


Train loss: 0.0008  Val PRA: 0.9647  Val PAA: 0.9881
Epoch 61/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.37it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.21it/s]


Train loss: 0.0007  Val PRA: 0.9663  Val PAA: 0.9885
Epoch 62/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.85it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.46it/s]


Train loss: 0.0008  Val PRA: 0.9664  Val PAA: 0.9886
Epoch 63/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.33it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.29it/s]


Train loss: 0.0007  Val PRA: 0.9659  Val PAA: 0.9886
Epoch 64/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.62it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.65it/s]


Train loss: 0.0007  Val PRA: 0.9653  Val PAA: 0.9884
Epoch 65/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 48.03it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.70it/s]


Train loss: 0.0007  Val PRA: 0.9658  Val PAA: 0.9884
Epoch 66/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.53it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 46.60it/s]


Train loss: 0.0007  Val PRA: 0.9648  Val PAA: 0.9884
Epoch 67/100


train: 100%|██████████| 5240/5240 [01:52<00:00, 46.53it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.80it/s]


Train loss: 0.0006  Val PRA: 0.9660  Val PAA: 0.9887
Epoch 68/100


train: 100%|██████████| 5240/5240 [01:47<00:00, 48.65it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 45.83it/s]


Train loss: 0.0006  Val PRA: 0.9656  Val PAA: 0.9886
Epoch 69/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.46it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.21it/s]


Train loss: 0.0006  Val PRA: 0.9656  Val PAA: 0.9885
Epoch 70/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 48.02it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.05it/s]


Train loss: 0.0006  Val PRA: 0.9656  Val PAA: 0.9886
Epoch 71/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.41it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.00it/s]


Train loss: 0.0005  Val PRA: 0.9656  Val PAA: 0.9886
Epoch 72/100


train: 100%|██████████| 5240/5240 [01:47<00:00, 48.82it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.71it/s]


Train loss: 0.0006  Val PRA: 0.9659  Val PAA: 0.9887
Epoch 73/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.35it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 46.94it/s]


Train loss: 0.0006  Val PRA: 0.9659  Val PAA: 0.9886
Epoch 74/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 48.03it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.11it/s]


Train loss: 0.0006  Val PRA: 0.9654  Val PAA: 0.9885
Epoch 75/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.41it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 46.61it/s]


Train loss: 0.0005  Val PRA: 0.9664  Val PAA: 0.9887
Epoch 76/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.95it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.13it/s]


Train loss: 0.0005  Val PRA: 0.9669  Val PAA: 0.9889
Saved best model: best_jigsaw.pth
Epoch 77/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.41it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.30it/s]


Train loss: 0.0005  Val PRA: 0.9661  Val PAA: 0.9888
Epoch 78/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.19it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.75it/s]


Train loss: 0.0005  Val PRA: 0.9659  Val PAA: 0.9888
Epoch 79/100


train: 100%|██████████| 5240/5240 [01:51<00:00, 47.10it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.51it/s]


Train loss: 0.0005  Val PRA: 0.9680  Val PAA: 0.9892
Saved best model: best_jigsaw.pth
Epoch 80/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 48.05it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.82it/s]


Train loss: 0.0005  Val PRA: 0.9658  Val PAA: 0.9885
Epoch 81/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.33it/s]
eval: 100%|██████████| 583/583 [00:11<00:00, 48.62it/s]


Train loss: 0.0004  Val PRA: 0.9660  Val PAA: 0.9886
Epoch 82/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.49it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.10it/s]


Train loss: 0.0005  Val PRA: 0.9662  Val PAA: 0.9887
Epoch 83/100


train: 100%|██████████| 5240/5240 [01:48<00:00, 48.39it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 46.49it/s]


Train loss: 0.0004  Val PRA: 0.9659  Val PAA: 0.9886
Epoch 84/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.29it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.23it/s]


Train loss: 0.0005  Val PRA: 0.9651  Val PAA: 0.9884
Epoch 85/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.99it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.47it/s]


Train loss: 0.0005  Val PRA: 0.9666  Val PAA: 0.9888
Epoch 86/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.96it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.58it/s]


Train loss: 0.0005  Val PRA: 0.9663  Val PAA: 0.9887
Epoch 87/100


train: 100%|██████████| 5240/5240 [01:51<00:00, 46.92it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.27it/s]


Train loss: 0.0005  Val PRA: 0.9669  Val PAA: 0.9888
Epoch 88/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.54it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.00it/s]


Train loss: 0.0004  Val PRA: 0.9661  Val PAA: 0.9885
Epoch 89/100


train: 100%|██████████| 5240/5240 [01:51<00:00, 47.17it/s]
eval: 100%|██████████| 583/583 [00:11<00:00, 48.91it/s]


Train loss: 0.0005  Val PRA: 0.9660  Val PAA: 0.9887
Epoch 90/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.26it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.07it/s]


Train loss: 0.0004  Val PRA: 0.9669  Val PAA: 0.9890
Epoch 91/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.49it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.41it/s]


Train loss: 0.0005  Val PRA: 0.9652  Val PAA: 0.9884
Epoch 92/100


train: 100%|██████████| 5240/5240 [01:47<00:00, 48.64it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.02it/s]


Train loss: 0.0004  Val PRA: 0.9656  Val PAA: 0.9885
Epoch 93/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.33it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.36it/s]


Train loss: 0.0004  Val PRA: 0.9659  Val PAA: 0.9884
Epoch 94/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.39it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.75it/s]


Train loss: 0.0004  Val PRA: 0.9661  Val PAA: 0.9887
Epoch 95/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.58it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.12it/s]


Train loss: 0.0004  Val PRA: 0.9653  Val PAA: 0.9884
Epoch 96/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 48.07it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.07it/s]


Train loss: 0.0005  Val PRA: 0.9673  Val PAA: 0.9891
Epoch 97/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.87it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.77it/s]


Train loss: 0.0004  Val PRA: 0.9662  Val PAA: 0.9887
Epoch 98/100


train: 100%|██████████| 5240/5240 [01:50<00:00, 47.64it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 46.70it/s]


Train loss: 0.0005  Val PRA: 0.9663  Val PAA: 0.9888
Epoch 99/100


train: 100%|██████████| 5240/5240 [01:49<00:00, 47.77it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 47.89it/s]


Train loss: 0.0004  Val PRA: 0.9669  Val PAA: 0.9889
Epoch 100/100


train: 100%|██████████| 5240/5240 [01:47<00:00, 48.58it/s]
eval: 100%|██████████| 583/583 [00:12<00:00, 48.02it/s]

Train loss: 0.0004  Val PRA: 0.9669  Val PAA: 0.9889
Training complete. Best model saved to best_jigsaw.pth with PRA: 0.9680





In [None]:
# To predict the output run the following command:
"""
python code.py --image_dir valid \
                  --model_path best_jigsaw.pth \
                  --csv_output predictions.csv \
                  --json_output predictions.json
"""
# OR 
              
# python code.py --image_dir valid --model_path best_jigsaw.pth --csv_output predictions.csv --json_output predictions.json


import os
import csv
import json
import argparse
from pathlib import Path
from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from scipy.optimize import linear_sum_assignment


class TileTransformer(nn.Module):
    """
    Per-tile position classifier: a ResNet-18 encoder applied to every tile,
    followed by a Transformer encoder over the tiles of one image and a linear
    head that scores 9 positions for each tile.

    The architecture (attribute names ``backbone``, ``proj``, ``transformer``,
    ``head`` and their hyper-parameters) must be identical to the class used
    during training so that the saved state_dict can be loaded.

    Args:
        tile_encoder: Kept for signature compatibility only; a ResNet-18
            backbone is always built regardless of this value.
        embed_dim: Feature width fed to the transformer. A linear projection
            is added only when it differs from the backbone's 512-d output.
        nhead: Number of attention heads per transformer layer.
        nlayers: Number of stacked transformer encoder layers.
        dropout: Dropout probability used inside the transformer layers.
    """
    def __init__(self, tile_encoder='resnet18', embed_dim=512, nhead=8, nlayers=3, dropout=0.1):
        super().__init__()
        # Weights come from the checkpoint, so no pretrained download is needed here.
        res = models.resnet18(pretrained=False)
        res.fc = nn.Identity()  # drop the classifier; keep the pooled features
        self.backbone = res  # outputs 512-d for resnet18
        self.embed_dim = 512
        if self.embed_dim != embed_dim:
            self.proj = nn.Linear(self.embed_dim, embed_dim)
            final_dim = embed_dim
        else:
            self.proj = None
            final_dim = self.embed_dim
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=final_dim,
            nhead=nhead,
            dropout=dropout,
            dim_feedforward=final_dim*4,
            activation='gelu',
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=nlayers)
        self.head = nn.Linear(final_dim, 9)

    def forward(self, tiles):
        """
        Score every tile against the 9 candidate positions.

        Args:
            tiles: 5-D tensor laid out as (B, n, C, H, W), where n is the
                number of tiles per image (9).

        Returns:
            Tensor of shape (B, n, 9) with per-tile logits over positions.
        """
        B = tiles.shape[0]
        n = tiles.shape[1]  # 9
        # reshape (rather than view) so non-contiguous inputs, e.g. the result
        # of a permute/transpose, are accepted instead of raising.
        tiles_flat = tiles.reshape(B*n, tiles.shape[2], tiles.shape[3], tiles.shape[4])
        x = self.backbone(tiles_flat)  # (B*n, 512)
        if self.proj is not None:
            x = self.proj(x)  # (B*n, embed_dim)
        x = x.reshape(B, n, -1)  # (B, 9, D)
        # transformer expects seq_len, batch, dim
        x_t = x.permute(1, 0, 2)
        x_out = self.transformer(x_t)
        x_out = x_out.permute(1, 0, 2)  # back to (B, 9, D)
        logits = self.head(x_out)  # per tile logits over positions
        return logits

def hungarian_from_logits(logits):
    """
    Turn a score matrix into a one-to-one assignment via the Hungarian method.

    The assignment maximises the summed score (the matrix is negated before
    being handed to ``linear_sum_assignment``, which minimises cost).

    Args:
        logits: 2-D scores as a torch tensor or anything ``np.array`` accepts;
            entry ``[r, c]`` is the score for giving row ``r`` column ``c``.

    Returns:
        A list with one entry per row holding the assigned column index as a
        Python int, or -1 for rows that received no column.
    """
    scores = (
        logits.detach().cpu().numpy()
        if isinstance(logits, torch.Tensor)
        else np.array(logits)
    )
    rows, cols = linear_sum_assignment(-scores)
    chosen = dict(zip(rows.tolist(), cols.tolist()))
    return [chosen.get(r, -1) for r in range(scores.shape[0])]

def infer_single_image(model, image_path, image_size=201, device='cpu'):
    """
    Predict the tile permutation for one image file.

    The image is loaded as RGB, resized to ``image_size`` x ``image_size``,
    cut into a 3x3 grid in row-major order, normalised per tile and passed
    through ``model``; the resulting logits are resolved into a one-to-one
    assignment with ``hungarian_from_logits``.

    Args:
        model: Network called with a (1, 9, C, H, W) batch of tiles.
        image_path: Path of the image to read.
        image_size: Side length the image is resized to before tiling.
        device: Device the tile batch is moved to.

    Returns:
        The predicted permutation as a list, or None when the image could not
        be opened or resized (the error is printed).
    """
    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406],
                             std=[0.229,0.224,0.225])
    ])

    try:
        picture = Image.open(image_path).convert('RGB')
        picture = picture.resize((image_size, image_size), Image.BILINEAR)
        pixels = np.array(picture).astype(np.uint8)
    except Exception as e:
        print(f"Error opening or resizing image {image_path}: {e}")
        return None

    step = image_size // 3

    # Row-major walk over the 3x3 grid; each crop becomes a normalised tensor.
    tile_tensors = [
        normalize(Image.fromarray(
            pixels[r * step:(r + 1) * step, c * step:(c + 1) * step, :]
        ))
        for r in range(3)
        for c in range(3)
    ]

    batch = torch.stack(tile_tensors).unsqueeze(0).to(device)  # (1, 9, C, H, W)

    model.eval()
    with torch.no_grad():
        scores = model(batch)[0]  # (9, 9)
        return hungarian_from_logits(scores)


def _save_predictions_csv(csv_output_path, results):
    """Write one ``filename,sequence`` row per result; return True on success."""
    try:
        with open(csv_output_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(["filename", "sequence"])
            writer.writerows(
                [res["filename"], res["sequence_string"]] for res in results
            )
    except Exception as e:
        # Best-effort: a failed CSV write must not prevent the JSON write.
        print(f"Error saving CSV file: {e}")
        return False
    return True


def _save_predictions_json(json_output_path, results):
    """Write results as ``{"images": [...]}`` JSON; return True on success."""
    json_data = {
        "images": [
            {
                "filename": res["filename"],
                "sequence": res["sequence_list"]
            } for res in results
        ]
    }
    try:
        with open(json_output_path, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, indent=2)
    except Exception as e:
        print(f"Error saving JSON file: {e}")
        return False
    return True


def run_inference(model_path, image_dir, csv_output_path, json_output_path, tile_size, force_cpu):
    """
    Run inference on every image in a directory and save the predictions.

    Loads the checkpoint into a ``TileTransformer``, predicts a permutation
    for each image file found directly inside ``image_dir`` (sorted by name),
    and writes the results to a CSV file and a JSON file. Problems (missing
    model, unreadable directory, no images, failed writes) are reported with
    a printed message rather than raised.

    Args:
        model_path: Path to the saved state_dict.
        image_dir: Directory whose image files are processed.
        csv_output_path: Destination of the CSV with ``filename,sequence`` rows.
        json_output_path: Destination of the JSON with an ``images`` list.
        tile_size: Side length of one tile; images are resized to 3x this.
        force_cpu: When true, use the CPU even if CUDA is available.

    Returns:
        None.
    """
    device = torch.device('cuda' if torch.cuda.is_available() and not force_cpu else 'cpu')
    print(f"Using device: {device}")

    image_size = tile_size * 3
    print(f"Using tile size: {tile_size}, full image size: {image_size}x{image_size}")

    print("Loading model...")
    model = TileTransformer(
        tile_encoder='resnet18',
        embed_dim=512,
        nhead=8,
        nlayers=3
    ).to(device)

    try:
        model.load_state_dict(torch.load(model_path, map_location=device))
    except FileNotFoundError:
        print(f"Error: Model file not found at {model_path}")
        return
    except Exception as e:
        print(f"Error loading model state_dict: {e}")
        print("Ensure the model architecture in this script matches the one used for training.")
        return

    model.eval()
    print("Model loaded successfully.")

    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff'}
    try:
        entries = os.listdir(image_dir)
    except OSError as e:
        # Report a missing/unreadable directory the same way as a missing model.
        print(f"Error: cannot list image directory {image_dir}: {e}")
        return
    image_files = sorted(
        f for f in entries
        if os.path.splitext(f)[1].lower() in image_extensions
    )

    if not image_files:
        print(f"No images found in directory: {image_dir}")
        return

    print(f"Found {len(image_files)} images to process...")

    all_results = []
    for filename in tqdm(image_files, desc="Predicting"):
        image_path = os.path.join(image_dir, filename)

        perm_list = infer_single_image(model, image_path, image_size=image_size, device=device)

        # None is the failure signal from infer_single_image.
        if perm_list is None:
            print(f"Skipping corrupt or unreadable file: {filename}")
            continue

        all_results.append({
            "filename": filename,
            "sequence_list": perm_list,
            "sequence_string": " ".join(str(p) for p in perm_list)
        })

    print(f"Saving CSV results to {csv_output_path}...")
    csv_saved = _save_predictions_csv(csv_output_path, all_results)

    print(f"Saving JSON results to {json_output_path}...")
    json_saved = _save_predictions_json(json_output_path, all_results)

    print("\nInference complete.")
    # Only claim a file was saved when its write actually succeeded.
    if csv_saved:
        print(f"CSV saved to: {csv_output_path}")
    if json_saved:
        print(f"JSON saved to: {json_output_path}")

# -----------------------------------------------------------

if __name__ == "__main__":
    # Command-line entry point: declare the options in one table, register
    # them in order, then hand the parsed values to run_inference.
    parser = argparse.ArgumentParser(description="Jigsaw Puzzle Inference Script")

    option_table = (
        ("--image_dir", dict(
            type=str,
            required=True,
            help="Path to the directory containing images for inference.",
        )),
        ("--model_path", dict(
            type=str,
            required=True,
            help="Path to the saved model file (e.g., 'best_jigsaw.pth').",
        )),
        ("--csv_output", dict(
            type=str,
            default="predictions.csv",
            help="Path to save the output CSV file.",
        )),
        ("--json_output", dict(
            type=str,
            default="predictions.json",
            help="Path to save the output JSON file.",
        )),
        ("--tile_size", dict(
            type=int,
            default=67,
            help="Tile size used during training. (Default: 67, as in your script)",
        )),
        ("--force_cpu", dict(
            action='store_true',
            help="Force use of CPU even if CUDA is available.",
        )),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    args = parser.parse_args()

    run_inference(
        model_path=args.model_path,
        image_dir=args.image_dir,
        csv_output_path=args.csv_output,
        json_output_path=args.json_output,
        tile_size=args.tile_size,
        force_cpu=args.force_cpu,
    )